├── .gitignore ├── .idea ├── .gitignore ├── AI-Phone-Call.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml └── modules.xml ├── 3D ├── CallPhone.skp ├── CallPhone_SU.png ├── SketchUp.md ├── img.png ├── img_1.png ├── img_2.png ├── img_3.png └── img_4.png ├── AI.py ├── AtSerialHelper.py ├── AudioHelper.py ├── CallHelper.py ├── Config.py ├── HelloTTSGenerator.py ├── LiveData.py ├── Main.py ├── README.md ├── SendCommandHelper.py ├── SmsHelper.py ├── VoiceCall.py ├── aliyun_asr.py ├── aliyun_tts.py ├── audio └── say_hello.pcm ├── audio_resource.py ├── config.yaml ├── logger.py ├── nls ├── __init__.py ├── core.py ├── exception.py ├── logging.py ├── speech_recognizer.py ├── speech_synthesizer.py ├── speech_transcriber.py ├── stream_input_tts.py ├── token.py ├── util.py ├── version.py └── websocket │ ├── __init__.py │ ├── _abnf.py │ ├── _app.py │ ├── _cookiejar.py │ ├── _core.py │ ├── _exceptions.py │ ├── _handshake.py │ ├── _http.py │ ├── _logging.py │ ├── _socket.py │ ├── _ssl_compat.py │ ├── _url.py │ └── _utils.py ├── notify_to_master.py ├── pdu_decoder.py ├── pdu_exceptions.py ├── push_to_qiye_wx.py ├── requirements.txt ├── say_hello.pcm ├── screenshots ├── 4g_minipcie.png ├── ai_phon_call.png ├── create_app.png ├── pi_4b.png ├── pi_4b_2.png └── pi_cm4.png └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | ._.idea 2 | ._* 3 | .idea 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as 
to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/AI-Phone-Call.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 26 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /3D/CallPhone.skp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/CallPhone.skp -------------------------------------------------------------------------------- /3D/CallPhone_SU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/CallPhone_SU.png -------------------------------------------------------------------------------- /3D/SketchUp.md: 
-------------------------------------------------------------------------------- 1 | # SketchUp 2 | 3 | ## 使用 SketchUp 画图 4 | 5 | ![CallPhone_SU.png](CallPhone_SU.png) 6 | 7 | ## 打印文件 8 | 9 | [CallPhone.skp](CallPhone.skp) 10 | 11 | ## 打印后的效果 12 | 13 | ![img.png](img.png) 14 | 15 | ![img_1.png](img_1.png) 16 | 17 | ![img_2.png](img_2.png) 18 | 19 | ![img_3.png](img_3.png) 20 | 21 | ![img_4.png](img_4.png) 22 | 23 | -------------------------------------------------------------------------------- /3D/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/img.png -------------------------------------------------------------------------------- /3D/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/img_1.png -------------------------------------------------------------------------------- /3D/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/img_2.png -------------------------------------------------------------------------------- /3D/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/img_3.png -------------------------------------------------------------------------------- /3D/img_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/3D/img_4.png -------------------------------------------------------------------------------- /AI.py: 
-------------------------------------------------------------------------------- 1 | from Config import Config 2 | import json 3 | from http import HTTPStatus 4 | import dashscope 5 | import logger 6 | 7 | # Proxyman Code Generator (1.0.0): Python + Request 8 | # GET https://chatai.mixerbox.com/api/chat 9 | 10 | class AI: 11 | def __init__(self): 12 | config = Config.get_instance() 13 | self.system_prompt = config.get("system_prompt") 14 | self.say_hello = config.get("say_hello") 15 | 16 | self.system_prompt = [ 17 | {'role': 'system', 18 | 'content': self.system_prompt 19 | } 20 | ] 21 | 22 | self.say_hello_prompt = [ 23 | {"role": "assistant", 24 | "content": self.say_hello 25 | } 26 | ] 27 | 28 | self.qa_history = [] 29 | 30 | def ai(self, q: str, callback=None): 31 | self.send_request_aliyun(self.qa_history, q, callback) 32 | 33 | def read_all_call_history(self): 34 | all_history = self.say_hello_prompt + self.qa_history 35 | result = "" 36 | for i in all_history: 37 | role = i['role'] 38 | if role == "assistant": 39 | result += "+助理:" + i["content"] + "\n" 40 | else: 41 | result += "*对方:" + i["content"] + "\n" 42 | return result 43 | 44 | def clear_call_history(self): 45 | self.qa_history = [] 46 | 47 | def send_request_aliyun(self, history: list, question: str, callback=None): 48 | config = Config.get_instance() 49 | dashscope.api_key = config.get("api_key") # 修改这里 50 | 51 | q = {"role": "user", "content": question} 52 | # 从 qa_history 取元素,最多取最后 5 个 53 | fixed_history = history[-15:] 54 | prompt = self.system_prompt + self.say_hello_prompt + fixed_history 55 | prompt.append(q) 56 | 57 | logger.d(f"prompt: {prompt}") 58 | 59 | response = dashscope.Generation.call( 60 | dashscope.Generation.Models.qwen_plus, 61 | messages=prompt, 62 | result_format='message', # 将返回结果格式设置为 message 63 | ) 64 | if response.status_code == HTTPStatus.OK: 65 | logger.e("-----------------------------------1") 66 | # dict 转 json 67 | json_str = json.dumps(response, indent=4, 
ensure_ascii=False) 68 | logger.i(json_str) 69 | 70 | # 把 response.content 解码成Str 71 | # 把 response_str 转换成 JSON 对象 72 | logger.e("-----------------------------------2") 73 | response_json = json.loads(json_str) 74 | logger.e("----------------------------------3") 75 | # logger.d(response_json) 76 | logger.e("-----------------------------------4") 77 | # 把 response_json 中的 "output" 字段取出来 78 | response_output = response_json["output"] 79 | # 把 response_output 中的 "choices" 字段取出来 80 | response_choices = response_output["choices"] 81 | # 把 response_choices 中的第一个元素取出来 82 | response_choice = response_choices[0] 83 | # 把 response_choice 中的 "message" 字段取出来 84 | response_message = response_choice["message"] 85 | # 把 response_message 中的 "content" 字段取出来 86 | response_content = response_message["content"] 87 | # 打印 response_content 88 | logger.i(response_content) 89 | self.qa_history.append(q) 90 | self.qa_history.append(response_message) 91 | if callback is not None: 92 | callback(response_content) 93 | else: 94 | logger.e('Request id: %s, Status code: %s, error code: %s, error message: %s' % ( 95 | response.request_id, response.status_code, 96 | response.code, response.message 97 | )) 98 | 99 | 100 | if __name__ == '__main__': 101 | ai = AI() 102 | ai.ai("你好") 103 | -------------------------------------------------------------------------------- /AtSerialHelper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | 5 | import serial 6 | 7 | import logger 8 | from CallHelper import CallHelper 9 | from SmsHelper import SmsHelper 10 | from SendCommandHelper import SendCommandHelper 11 | 12 | 13 | class AtSerialHelper: 14 | def __init__(self, at_ser=None, baud_rate=115200, audio_helper=None): 15 | try: 16 | self.at_ser = serial.Serial(at_ser, baud_rate, timeout=2) 17 | except Exception as e: 18 | self.at_ser = None 19 | logger.e("初始化AT串口失败:" + str(e)) 20 | 21 | self.is_need_read_at_command_data 
= False 22 | self.at_command_read_thread = None 23 | 24 | self.sms_helper = SmsHelper(self) 25 | self.call_helper = CallHelper(self, audio_helper) 26 | 27 | self.send_command_helper = SendCommandHelper(self.at_ser, self.sms_helper, self.call_helper) 28 | 29 | def read_at_command_data(self): 30 | if self.at_ser is None: 31 | logger.e("AT串口未初始化") 32 | return 33 | 34 | while self.is_need_read_at_command_data: 35 | try: 36 | data: bytes = self.at_ser.readline() 37 | if data is None or data == b'': 38 | time.sleep(0.01) 39 | continue 40 | 41 | if data == b'\r\n': 42 | continue 43 | 44 | data_string = data.decode().strip() 45 | if data_string == "": 46 | continue 47 | 48 | # 开始处理串口数据 49 | if self.call_helper.handle_call(data_string): 50 | continue 51 | elif self.sms_helper.handle_sms(data_string): 52 | continue 53 | elif self.send_command_helper.handle_command_result(data_string): 54 | continue 55 | else: 56 | logger.e("串口读取到数据:" + data_string) 57 | except Exception as e: 58 | # 执行lsof,查看串口是否被占用 59 | logger.e("读取串口数据异常:" + str(e)) 60 | f = os.popen("lsof | grep /dev/ttyUSB2") 61 | logger.e(f.read()) 62 | 63 | def current_write_at_command(self): 64 | return self.send_command_helper.wait_result_at_command 65 | 66 | def start_read_serial_thread(self): 67 | if self.at_command_read_thread is not None: 68 | logger.d("串口读取线程已经启动,无需重复启动") 69 | return 70 | self.is_need_read_at_command_data = True 71 | self.at_command_read_thread = threading.Thread(target=self.read_at_command_data) 72 | self.at_command_read_thread.start() 73 | 74 | def prepare(self, debug=False): 75 | self.call_helper.prepare() 76 | self.sms_helper.prepare() 77 | # 开启等待接听电话 78 | if debug: 79 | try: 80 | while True: 81 | user_input = input("输入命令:\n") 82 | if user_input == "\n" or user_input == "": 83 | continue 84 | self.send_command_helper.write_at_command(user_input) 85 | time.sleep(1) 86 | except KeyboardInterrupt as e: 87 | logger.e("用户终止程序") 88 | else: 89 | self.at_command_read_thread.join() 90 | 
-------------------------------------------------------------------------------- /AudioHelper.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | import serial 5 | from LiveData import LiveData 6 | import logger 7 | 8 | 9 | class AudioHelper: 10 | def __init__(self, audio_ser=None, baud_rate=115200): 11 | try: 12 | self.audio_ser = serial.Serial(audio_ser, baud_rate, timeout=1) 13 | except Exception as e: 14 | logger.e("初始化音频串口失败:" + str(e)) 15 | self.is_calling = False 16 | self.call_audio_data_read_thread = None 17 | self.call_audio_data_livedata = LiveData() 18 | 19 | def __read_audio_data(self): 20 | while self.is_calling: 21 | data = self.audio_ser.read(640) 22 | if data and data != b'': 23 | self.call_audio_data_livedata.value = data 24 | time.sleep(0.01) 25 | logger.i("通话结束,循环读取音频数据已经结束") 26 | 27 | def write_audio_data(self, data): 28 | self.audio_ser.write(data) 29 | 30 | def start_audio_read_thread(self): 31 | if self.call_audio_data_read_thread is not None: 32 | logger.i("音频读取线程已经启动,无需重复启动") 33 | return 34 | logger.d("启动音频读取线程,开始读取音频数据") 35 | self.is_calling = True 36 | self.call_audio_data_read_thread = threading.Thread(target=self.__read_audio_data) 37 | self.call_audio_data_read_thread.start() 38 | 39 | def stop_audio_read_thread(self): 40 | self.is_calling = False 41 | self.call_audio_data_read_thread = None 42 | logger.d("停止读取音频的线程") 43 | -------------------------------------------------------------------------------- /CallHelper.py: -------------------------------------------------------------------------------- 1 | from LiveData import LiveData 2 | from VoiceCall import VoiceCall 3 | import threading 4 | import time 5 | import logger 6 | from audio_resource import say_hello_pcm_file 7 | 8 | 9 | class CallHelper: 10 | def __init__(self, at_serial_helper, audio_helper): 11 | self.at_serial_helper = at_serial_helper 12 | self.pickup_lock = threading.Lock() 13 | 
self.call_status = LiveData() 14 | self.is_pickup = False 15 | self.ring_count = 0 16 | self.say_hello_pcm_file = say_hello_pcm_file() 17 | self.audio_helper = audio_helper 18 | 19 | def __call_no_carrier(self): 20 | with self.pickup_lock: 21 | if self.is_pickup: 22 | self.is_pickup = False 23 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_NO_CARRIER) 24 | self.at_serial_helper.send_command_helper.write_at_command("AT+CPCMREG=0", delay=1.5) 25 | 26 | def __call_start(self): 27 | time.sleep(0.1) # 等ATA命令返回OK,以保证Log打印的时间顺序看起来是对的 28 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_BEGIN) 29 | self.is_pickup = True 30 | 31 | def __call_end(self): 32 | with self.pickup_lock: 33 | if self.is_pickup: 34 | self.is_pickup = False 35 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_END) 36 | self.at_serial_helper.send_command_helper.write_at_command("AT+CPCMREG=0", delay=1.5) 37 | 38 | def __call_missed(self): 39 | self.is_pickup = False 40 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_MISSED) 41 | 42 | def __pick_up_inner(self): 43 | self.is_pickup = True 44 | self.ring_count = 0 45 | # 等待0.5秒,等待音频串口准备好, 否则串口可能返回ERROR 46 | self.at_serial_helper.send_command_helper.write_at_command("AT+CPCMREG=1", delay=0.5) 47 | time.sleep(0.6) 48 | # 去读取音频数据, 每次读取640字节 49 | with open(self.say_hello_pcm_file, "rb") as f: 50 | while self.is_pickup: 51 | data = f.read(640) 52 | if not data: 53 | break 54 | self.audio_helper.write_audio_data(data) 55 | time.sleep(0.01) 56 | time.sleep(0.5) 57 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_SAY_HELLO_DONE) 58 | 59 | def pick_up(self, say_hello_pcm_file): 60 | logger.d("发送 ATA 命令,接听电话") 61 | self.at_serial_helper.send_command_helper.write_at_command("ATA", delay=0) 62 | threading.Thread(target=self.__pick_up_inner).start() 63 | 64 | def hang_up(self): 65 | logger.d("发送 ATH 命令,挂断电话") 66 | self.is_pickup = False 67 | self.at_serial_helper.send_command_helper.write_at_command("ATH", delay=0) 68 | # 
self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_END) 69 | 70 | def is_in_voice_calling(self): 71 | return self.is_pickup 72 | 73 | def handle_call(self, decode_string): 74 | # logger.e("AT串口返回:" + decode_string) 75 | if (decode_string.find("RING") != -1) and (not self.is_pickup): 76 | self.call_status.value = VoiceCall(VoiceCall.VOICE_CALL_RING) 77 | # logger.i("收到1 : " + decode_string) 78 | return True 79 | # 判断字符串是不是以 +CLIP: 开头 80 | elif decode_string.startswith("+CLIP: \"") and (not self.is_pickup): # +CLIP: "13200000000",161,,,,0 81 | # logger.i("收到2 : " + decode_string) 82 | """ 83 | Calling Line Identification Presentation 84 | 允许在接收电话呼叫时,接收方可以看到呼叫方的电话号码。 85 | """ 86 | self.ring_count += 1 87 | if decode_string.find('"') != -1: 88 | split_strings = decode_string.split('"') 89 | if len(split_strings) >= 2: 90 | phone_number = split_strings[1] 91 | self.call_status.value = VoiceCall(status=VoiceCall.VOICE_CALL_CLIP, 92 | phone_number=phone_number, ring_count=self.ring_count) 93 | return True 94 | self.call_status.value = VoiceCall(status=VoiceCall.VOICE_CALL_CLIP, phone_number="UNKNOWN", 95 | ring_count=self.ring_count) 96 | return True 97 | elif decode_string.find("VOICE CALL: BEGIN") != -1: 98 | # logger.i("收到3 : " + decode_string) 99 | threading.Thread(target=self.__call_start).start() 100 | return True 101 | elif decode_string.find("VOICE CALL: END:") != -1: 102 | # logger.i("收到4 : " + decode_string) 103 | threading.Thread(target=self.__call_end).start() 104 | return True 105 | elif decode_string.find("MISSED_CALL:") != -1: 106 | # logger.i("收到5 : " + decode_string) 107 | threading.Thread(target=self.__call_missed).start() 108 | return True 109 | elif decode_string.find("NO CARRIER") != -1: 110 | # logger.i("收到6 : " + decode_string) 111 | threading.Thread(target=self.__call_no_carrier).start() 112 | return True 113 | else: 114 | return False 115 | 116 | def prepare(self): 117 | # 启动线程 118 | 
self.at_serial_helper.send_command_helper.write_at_command("AT+CLIP=?") 119 | 120 | self.at_serial_helper.send_command_helper.write_at_command("AT+CLIP=1") 121 | # write_at_command("AT+CPCMFRM=1") 设置采样率为16k,默认为8k https://techship.com/support/faq/voice-calls-and-usb-audio/ 122 | 123 | self.at_serial_helper.send_command_helper.write_at_command("AT+CNMP=?") 124 | self.at_serial_helper.send_command_helper.write_at_command("AT+CNMP=2") 125 | 126 | self.at_serial_helper.send_command_helper.write_at_command("AT+CPCMREG=?") 127 | 128 | # 主要改善TDD noise效果 129 | self.at_serial_helper.send_command_helper.write_at_command("AT^PWRCTL=0,1,3") 130 | # 重置状态 131 | self.at_serial_helper.send_command_helper.write_at_command("AT+CPCMREG=0") 132 | -------------------------------------------------------------------------------- /Config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | import audio_resource as assets 4 | 5 | 6 | class Config: 7 | _instance = None 8 | 9 | def __new__(cls): 10 | if cls._instance is None: 11 | cls._instance = super().__new__(cls) 12 | cls._instance._load_config() 13 | return cls._instance 14 | 15 | def __init__(self): 16 | self._load_config() 17 | 18 | def _load_config(self): 19 | config_yaml = assets.config_file() 20 | with open(config_yaml, 'r', encoding='utf-8') as file: 21 | self.config = yaml.safe_load(file) 22 | # self.print_all() 23 | 24 | def get(self, key, default=None): 25 | return self.config.get(key, default) 26 | 27 | @classmethod 28 | def get_instance(cls): 29 | if cls._instance is None: 30 | cls._instance = cls() 31 | return cls._instance 32 | 33 | def print_all(self): 34 | for key, value in self.config.items(): 35 | print(f"{key}: {value}") 36 | 37 | def get_app_key(self): 38 | return self.config.get('app_key') 39 | 40 | def get_api_key(self): 41 | return self.config.get('api_key') 42 | 43 | def get_ak_id(self): 44 | return self.config.get('ak_id') 45 | 46 | def get_ak_secret(self): 
47 | return self.config.get('ak_secret') 48 | 49 | def get_model(self): 50 | return self.config.get('model') 51 | 52 | def get_system_prompt(self): 53 | return self.config.get('system_prompt') 54 | 55 | def get_say_hello(self): 56 | return self.config.get('say_hello') 57 | 58 | def get_service_url(self): 59 | return self.config.get('service_url') 60 | 61 | 62 | if __name__ == '__main__': 63 | config = Config.get_instance() 64 | print("所有配置项:") 65 | config.print_all() 66 | -------------------------------------------------------------------------------- /HelloTTSGenerator.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import nls 4 | from nls.token import getToken 5 | from Config import Config 6 | 7 | config = Config.get_instance() 8 | TEST_ACCESS_APPKEY = config.get("app_key") # 使用Config获取app_key 9 | 10 | TEXT = '围,您好,我是王先生的私人秘书,您找他有什么事情吗?' 11 | 12 | 13 | class HelloTTSGenerator: 14 | def __init__(self, tid, test_file): 15 | self.__th = threading.Thread(target=self.__test_run) 16 | self.__id = tid 17 | self.__test_file = test_file 18 | 19 | def start(self, text): 20 | self.__text = text 21 | self.__f = open(self.__test_file, "wb") 22 | self.__th.start() 23 | 24 | def test_on_metainfo(self, message, *args): 25 | print("on_metainfo message=>{}".format(message)) 26 | 27 | def test_on_error(self, message, *args): 28 | print("on_error args=>{}".format(args)) 29 | 30 | def test_on_close(self, *args): 31 | print("on_close: args=>{}".format(args)) 32 | try: 33 | self.__f.close() 34 | except Exception as e: 35 | print("close file failed since:", e) 36 | 37 | def test_on_data(self, data, *args): 38 | try: 39 | self.__f.write(data) 40 | except Exception as e: 41 | print("write data failed:", e) 42 | 43 | def test_on_completed(self, message, *args): 44 | print("on_completed:args=>{} message=>{}".format(args, message)) 45 | 46 | def __test_run(self): 47 | ak_id = config.get("ak_id") 48 | ak_secret = 
config.get("ak_secret") 49 | info = getToken(ak_id, ak_secret) 50 | print(info) 51 | 52 | print("thread:{} start..".format(self.__id)) 53 | tts = nls.NlsSpeechSynthesizer( 54 | token=info, 55 | appkey=TEST_ACCESS_APPKEY, 56 | long_tts=False, 57 | on_metainfo=self.test_on_metainfo, 58 | on_data=self.test_on_data, 59 | on_completed=self.test_on_completed, 60 | on_error=self.test_on_error, 61 | on_close=self.test_on_close, 62 | callback_args=[self.__id] 63 | ) 64 | 65 | print("{}: session start".format(self.__id)) 66 | r = tts.start(self.__text, sample_rate=8000, voice="zhiyuan", ex={'enable_subtitle': False}) 67 | print("{}: tts done with result:{}".format(self.__id, r)) 68 | 69 | 70 | if __name__ == '__main__': 71 | nls.enableTrace(True) 72 | t = HelloTTSGenerator("thread1", "say_hello.pcm") 73 | t.start(TEXT) 74 | -------------------------------------------------------------------------------- /LiveData.py: -------------------------------------------------------------------------------- 1 | class LiveData: 2 | def __init__(self): 3 | self._value = None 4 | self._observers = [] 5 | 6 | @property 7 | def value(self): 8 | return self._value 9 | 10 | @value.setter 11 | def value(self, new_value): 12 | self._value = new_value 13 | for callback in self._observers: 14 | callback(new_value) 15 | 16 | def observe(self, callback): 17 | self._observers.append(callback) 18 | -------------------------------------------------------------------------------- /Main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import logger 4 | import notify_to_master 5 | from AtSerialHelper import AtSerialHelper 6 | from AudioHelper import AudioHelper 7 | from VoiceCall import VoiceCall 8 | from aliyun_asr import Asr 9 | from aliyun_tts import Tts 10 | import json 11 | from AI import AI 12 | from audio_resource import say_hello_pcm_file 13 | 14 | 15 | class Main: 16 | def __init__(self): 17 | self.tts: Tts 18 | self.asr: Asr 19 | self.ai 
= AI() 20 | self.at_serial_helper = None 21 | self.audio_helper = None 22 | self.is_wait_tts_back = False 23 | self.is_ai_speaking = False 24 | self.call_from_number = None 25 | 26 | def handle_call_status(self, voice_call: VoiceCall): 27 | call_status = voice_call.status 28 | if call_status == VoiceCall.VOICE_CALL_RING: 29 | logger.i("正在播放来电铃声...") 30 | elif call_status == VoiceCall.VOICE_CALL_CLIP: 31 | self.call_from_number = "来电号码: " + voice_call.phone_number 32 | logger.i(self.call_from_number + ", " + "铃声次数: " + str(voice_call.ring_count)) 33 | self.audio_helper.start_audio_read_thread() 34 | 35 | if voice_call.ring_count == 2 and not self.at_serial_helper.call_helper.is_in_voice_calling(): 36 | self.asr.start() # 开启语音识别线程 37 | self.at_serial_helper.call_helper.pick_up(say_hello_pcm_file=say_hello_pcm_file()) 38 | self.is_ai_speaking = True 39 | elif call_status == VoiceCall.VOICE_CALL_BEGIN: 40 | logger.i("通话开始") 41 | elif call_status == VoiceCall.VOICE_CALL_SAY_HELLO_DONE: 42 | logger.i("播放喂,您好完成") 43 | self.is_ai_speaking = False 44 | elif (call_status == VoiceCall.VOICE_CALL_END 45 | or call_status == VoiceCall.VOICE_CALL_MISSED 46 | or call_status == VoiceCall.VOICE_CALL_NO_CARRIER): 47 | 48 | logger.i("通话结束: " + call_status) 49 | self.tts.stop() 50 | self.asr.stop() 51 | self.audio_helper.stop_audio_read_thread() 52 | 53 | # 去读通话记录,推送给微信 54 | all_history = self.ai.read_all_call_history() 55 | notify_to_master.notify(self.call_from_number + "\n" + all_history) 56 | self.ai.clear_call_history() 57 | elif call_status == VoiceCall.VOICE_CALL_MISSED: 58 | logger.i("通话结束: " + call_status) 59 | self.tts.stop() 60 | self.asr.stop() 61 | self.audio_helper.stop_audio_read_thread() 62 | 63 | # 去读通话记录,推送给微信 64 | notify_to_master.notify("漏接电话:" + self.call_from_number) 65 | else: 66 | logger.d("其他状态:" + call_status) 67 | 68 | def handle_call_audio(self, voice_pcm_data): 69 | if self.at_serial_helper.call_helper.is_in_voice_calling(): 70 | if not 
self.is_wait_tts_back: 71 | if not self.is_ai_speaking: 72 | self.asr.send_audio(voice_pcm_data) 73 | else: 74 | logger.e("从 USB 串口读取到音频数据,正在AI正在讲话,不用发送给阿里云ASR") 75 | else: 76 | logger.e("从 USB 串口读取到音频数据,但正在等待TTS返回,不用发送给阿里云ASR") 77 | else: 78 | logger.e("从 USB 串口读取到音频数据,不在通话中,不需要发送音频数据给阿里云ASR") 79 | 80 | def hand_sms_received(self, new_value): 81 | phone_number = new_value['number'] 82 | time = new_value['time'].strftime('%Y-%m-%d %H:%M:%S') 83 | text = new_value['text'] 84 | format_sms = phone_number + "\n" + time + "\n------ ------ ------\n" + text 85 | 86 | logger.i(f"开始推送给微信:{format_sms}") 87 | notify_to_master.notify(format_sms) 88 | 89 | def handle_ai_answer(self, new_value): 90 | self.text_to_voice(new_value) 91 | 92 | def observe_aliyun_asr_result(self, new_value): 93 | """ 94 | 阿里云语音识别结果 95 | """ 96 | json_data = json.loads(new_value) 97 | to_tts_text = json_data['payload']['result'] 98 | if to_tts_text == "嗯。": 99 | logger.e("只是一个嗯,不需要回答") 100 | else: 101 | # 发送给AI,让AI回答 102 | logger.d("开始发送数据给AI,等待AI回复:" + to_tts_text) 103 | self.ai.ai(to_tts_text, callback=self.handle_ai_answer) 104 | 105 | def text_to_voice(self, text): 106 | self.is_wait_tts_back = True 107 | self.tts.start(text) 108 | 109 | def observe_aliyun_tts_result(self, voice_data): 110 | self.is_ai_speaking = True 111 | if self.at_serial_helper.call_helper.is_in_voice_calling(): 112 | self.audio_helper.write_audio_data(voice_data) 113 | else: 114 | logger.e("不在通话中,不需要发送TTS音频数据给USB串口") 115 | 116 | def observe_aliyun_tts_status(self, new_value): 117 | # logger.d("阿里云TTS状态:" + new_value) 118 | if new_value == "completed" or new_value == "error" or new_value == "close": 119 | self.is_wait_tts_back = False 120 | self.is_ai_speaking = False 121 | 122 | def observe_aliyun_asr_status(self, asr_status): 123 | # logger.d("阿里云ASR状态:" + asr_status) 124 | pass 125 | 126 | def start_ai_call(self): 127 | self.asr = Asr("thread_asr") 128 | self.asr.asr_result_livedata.observe(self.observe_aliyun_asr_result) 
129 | self.asr.asr_status_livedata.observe(self.observe_aliyun_asr_status) 130 | 131 | self.tts = Tts("thread_tts") 132 | self.tts.tts_result_livedata.observe(self.observe_aliyun_tts_result) 133 | self.tts.tts_status_livedata.observe(self.observe_aliyun_tts_status) 134 | 135 | # 查看是否存在 /dev/ttyUSB* 设备,判断文件是否存在 136 | if not os.path.exists('/dev/ttyUSB4') or not os.path.exists('/dev/ttyUSB2'): 137 | logger.e("串口设备不存在") 138 | exit(1) 139 | # os.system("ls -l /dev/ttyUSB*") 140 | 141 | # 执行shell命令,chmod 777 /dev/ttyUSB*,给予USB串口读写权限 142 | # os.system("sudo chmod 777 /dev/ttyUSB*") 143 | 144 | self.audio_helper = AudioHelper('/dev/ttyUSB4', 115200) 145 | self.audio_helper.call_audio_data_livedata.observe(self.handle_call_audio) 146 | 147 | self.at_serial_helper = AtSerialHelper('/dev/ttyUSB2', 115200, self.audio_helper) 148 | self.at_serial_helper.call_helper.call_status.observe(self.handle_call_status) 149 | self.at_serial_helper.sms_helper.one_sms_livedata.observe(self.hand_sms_received) 150 | 151 | self.at_serial_helper.start_read_serial_thread() 152 | logger.d("prepare") 153 | 154 | def loop_for_call_in(self): 155 | self.at_serial_helper.prepare(debug=False) 156 | 157 | 158 | if __name__ == '__main__': 159 | main = Main() 160 | main.start_ai_call() 161 | main.loop_for_call_in() 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Phone Call 2 | 3 | ## B站视频 4 | https://www.bilibili.com/video/BV1724Ce1EUs 5 | 6 | ## 简介 7 | 这是一个基于大模型服务的“电话助手”,可以通过配置“提示词”对`语音通话`和`文字短信`进行处理。 8 | 需要的硬件设备为`树莓派`、`4G模块`和一张`SIM卡`。 9 | 10 | ### 能干什么 11 | > 具体的处理效果取决于提示词的配置和大模型的智能程度。 12 | 13 | - 替你接电话 14 | - 可以识别“垃圾电话”和“正常电话”,并根据提示词进行回复。 15 | - 可以告诉快递小哥把快递放门口等操作 16 | - 对抗其他“AI骚扰电话”,反正接电话不花钱,就跟他们聊天呗。 17 | - 替你收短信 18 | - 目前知识把短信内容通过企业微信发送给你,你可以在企业微信中查看短信内容。 19 | - 其实可以有很多玩法,比如替你定时发短信之类的,目前没啥需求就没做。 20 | 21 | 22 | ![ai_phon_call.png](screenshots/ai_phon_call.png) 23 
| 24 | ## 硬件连接 25 | ### 无树莓派方案 26 | 只需把下面的树莓派,换成安装Linux的电脑即可。 27 | 28 | ## 树莓派方案 29 | ### 方案一 30 | ![pi_4b.png](screenshots/pi_4b.png) 31 | 32 | ### 方案二 33 | ![pi_4b_2.png](screenshots/pi_4b_2.png) 34 | 35 | ### 方案三 36 | > 为了体积更小,可以使用【树莓派CM4+扩展板】代替树莓派4B。 37 | 38 | ![pi_cm4.png](screenshots/pi_cm4.png) 39 | 40 | ### 最小硬件连接方案 41 | > 【树莓派CM4+扩展板】+【4G模块转接板MiniPcie转USB】 42 | 43 | ![4g_minipcie.png](screenshots/4g_minipcie.png) 44 | 45 | ![img_4.png](3D/img_4.png) 46 | 47 | > 配合3D打印外壳的效果如下: 48 | 49 | ![img_3.png](3D/img_3.png) 50 | 51 | 52 | ## 运行环境 53 | - python3 54 | ```shell 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | ## 配置与运行 59 | 60 | ### 1. `config.yaml` 配置文件 61 | ```yaml 62 | # 访问阿里云 https://nls-portal.console.aliyun.com/applist 63 | # 创建一个实时语音识别应用,获取 “项目Appkey”,注意项目类型要选:“语音识别 + 语音合成 + 语音分析” 64 | # 截图参考:screenshots/create_app.png 65 | app_key: 66 | "换成你的appkey" 67 | 68 | # https://dashscope.console.aliyun.com/apiKey 69 | # 创建API-KEY 70 | api_key: 71 | "换成你的api_key" 72 | 73 | # https://ram.console.aliyun.com/manage/ak 74 | # 创建“AccessKey ID” 和 “AccessKey Secret” 75 | # 为降低 AccessKey 泄露的风险,自 2023 年 7 月 5 日起,新建的主账号 AccessKey 只在创建时提供 Secret,后续不可再进行查询,请保存好Secret。 76 | ak_id: 77 | "换成你的ak_id" 78 | ak_secret: 79 | "换成你的ak_secret" 80 | 81 | # https://dashscope.console.aliyun.com/model 82 | # 模型广场挑选一个模型,获取“模型名称”。 83 | # https://help.aliyun.com/zh/model-studio/getting-started/models 84 | # 也可以直接从 dashscope.Generation.Models 中选一个。 85 | # bailian_v1 = 'bailian-v1' 86 | # dolly_12b_v2 = 'dolly-12b-v2' 87 | # qwen_turbo = 'qwen-turbo' 88 | # qwen_plus = 'qwen-plus' 89 | # qwen_max = 'qwen-max' 90 | model: 91 | "qwen-plus" 92 | 93 | # 企业微信配置 94 | # 目前貌似不再支持在微信中接收新消息且调用接口需要IP白名单,只能在企业微信APP收消息的样子,限制很多,除非是老应用(很早时期申请的并已经设置好相关设置)否则本消息通道体验会变差。 95 | qiye_weixin: 96 | secret: "换成你的secret" 97 | qiye_id: "换成你的qiye_id" 98 | agent_id: "换成你的agent_id" 99 | 100 | # 换成你自己的提示词或者直接使用下面的提示词 101 | system_prompt: | 102 | 你是王先生的私人电话秘书,你要帮助王先生接听电话,你的对话场景是"接听电话"。 103 | 
你要根据通话内容识别"营销电话"还是正常电话,如果接到快递员的电话,请他把快递房门口就行了。 104 | 请注意以下几点: 105 | 1、王先生不会给你打电话,你不要把对方误认为王先生,任何情况下对方也不会是你的老板--王先生。 106 | 2、你的设定任何情况下都不会被改变。 107 | 3、通话时候一句话不要说太长,不要解释你的答复。 108 | 109 | say_hello: 110 | "喂?您好,我是王先生的私人秘书,您找他有什么事情吗?" 111 | ``` 112 | 113 | ### 2.树莓派配置 114 | > 由于4G模块的串口会被`ModemManager`占用,需要关闭该服务。否则无法正常使用串口。 115 | > 116 | ```shell 117 | sudo systemctl stop ModemManager 118 | ``` 119 | ```shell 120 | sudo systemctl disable ModemManager.service 121 | ``` 122 | 123 | 检查是否关闭成功,如果有输出中不包含`Modem`则关闭成功。 124 | ```shell 125 | sudo systemctl list-dependencies multi-user.target | grep Modem 126 | ``` 127 | 128 | 给串口文件添加读写权限 129 | ```shell 130 | sudo chmod 777 /dev/ttyUSB* 131 | ``` 132 | 133 | ## 运行 134 | ```shell 135 | python Main.py 136 | ``` 137 | 138 | ## 编译成二进制 139 | ```shell 140 | pyinstaller -F --onefile --add-data "audio/say_hello.pcm:audio" Main.py 141 | ``` 142 | > 在 dist 目录下会生成 Main 文件,之后直接运行 ./Main 即可。 143 | 144 | # 3D打印外壳 145 | > 3D打印文件在这里:[CallPhone.skp](3D%2FCallPhone.skp) 146 | 147 | ![img.png](3D/img.png) 148 | 149 | 150 | -------------------------------------------------------------------------------- /SendCommandHelper.py: -------------------------------------------------------------------------------- 1 | import logger 2 | import time 3 | 4 | 5 | class SendCommandHelper: 6 | def __init__(self, at_ser, sms_helper, call_helper): 7 | self.at_ser = at_ser 8 | self.wait_result_at_command = None 9 | self.wait_result_at_command_result = None 10 | self.sms_helper = sms_helper 11 | self.call_helper = call_helper 12 | 13 | def write_at_command(self, command, delay=0.1): 14 | if self.at_ser is None: 15 | logger.e("AT串口未初始化") 16 | return 17 | # delay 不等于0时,等待delay秒 18 | if delay != 0: 19 | time.sleep(delay) 20 | self.wait_result_at_command = command 21 | logger.i("|>> " + command) 22 | self.at_ser.write((command + "\r").encode()) 23 | self.at_ser.flush() 24 | 25 | def handle_command_result(self, response): 26 | if self.wait_result_at_command is not None: 27 | 
if self.wait_result_at_command_result is None: 28 | # logger.i("decode_string:" + decode_string + " wait_result_at_command:" + self.wait_result_at_command) 29 | if response == self.wait_result_at_command or self.wait_result_at_command == "ATA": 30 | # ATA 命令比较特殊,成功返回OK,没有命令名称 31 | # 读到这只一行结果,等于发出的AT指令,说明后续的这一行肯定是根这个指令相关 32 | self.wait_result_at_command_result = response 33 | return True 34 | else: 35 | # 如果不等于,那大概率是由于SIM状态发生变化而主动通知的,不是我们发出的指令的结果 36 | # __other_result += decode_string + " > " 37 | # logger.e("未知的AT指令结果1:" + decode_string + " ") 38 | return False 39 | else: 40 | if response == "OK" or response == "ERROR": 41 | self.wait_result_at_command_result += " > " + response 42 | logger.d("<<| " + self.wait_result_at_command_result + "\n") 43 | 44 | if self.wait_result_at_command_result.startswith("AT+CMGL=4"): 45 | ''' 读取所有短信 ''' 46 | self.sms_helper.read_all_sms(self.wait_result_at_command_result) 47 | elif self.wait_result_at_command_result.startswith("AT+CMGR="): 48 | ''' 读取一条短信 ''' 49 | self.sms_helper.read_one_sms(self.wait_result_at_command_result) 50 | 51 | # 打印完毕,清空,等待下次AT指令 52 | self.wait_result_at_command_result = None 53 | self.wait_result_at_command = None 54 | else: 55 | self.wait_result_at_command_result += " > " + response 56 | return True 57 | else: 58 | return False 59 | -------------------------------------------------------------------------------- /SmsHelper.py: -------------------------------------------------------------------------------- 1 | import logger 2 | import re 3 | import pdu_decoder 4 | import time 5 | 6 | from LiveData import LiveData 7 | import threading 8 | 9 | 10 | class SmsHelper: 11 | def __init__(self, at_command_helper): 12 | self.at_command_helper = at_command_helper 13 | self.one_sms_livedata = LiveData() 14 | self.one_sms_read_lock = threading.Lock() 15 | self.received_sms = {} 16 | 17 | def __read_all_sms(self, content: str): 18 | logger.i("读取所有短信:" + content) 19 | # 使用正则分段解析 20 | pattern = r"\+CMGL: 
(\d+),(\d+),\"\",(\d+) > [0-9A-F]+ > " 21 | matches = re.finditer(pattern, content) 22 | 23 | results = [] 24 | 25 | udh_dict = {} 26 | delete_index_list = [] 27 | for match in matches: 28 | find_result = match.group(0) 29 | splits = find_result.split(" > ") 30 | if len(splits) == 3: 31 | cmgl = splits[0] 32 | sms_index_matches = re.finditer(r"(?<=\+CMGL: )\d+(?=,\d+,\"\",\d+)", cmgl) 33 | sms_index = sms_index_matches.__next__().group(0) 34 | logger.i(f"短信索引:{sms_index}") 35 | 36 | # 记录要删除的短信索引 37 | delete_index_list.append(int(sms_index)) 38 | 39 | dpu = splits[1] 40 | decode_result = pdu_decoder.decodeSmsPdu(dpu) 41 | decode_result["read_id"] = sms_index 42 | if "udh" in decode_result: 43 | """ 如果有分割短信,那么把分割的短信放到字典中 """ 44 | for udh in decode_result["udh"]: 45 | decode_result["udh_index"] = udh.number 46 | if udh.reference not in udh_dict: 47 | udh_dict[udh.reference] = [decode_result] 48 | else: 49 | udh_dict[udh.reference].append(decode_result) 50 | else: 51 | """ 如果没有分割短信,那么直接放到结果中 """ 52 | results.append(decode_result) 53 | 54 | """ 把分割的短信合并 """ 55 | for reference, split_messages_list in udh_dict.items(): 56 | sorted_list = sorted(split_messages_list, key=lambda x: x['udh_index']) 57 | merge_message = None 58 | for message in sorted_list: 59 | if merge_message is None: 60 | merge_message = message 61 | else: 62 | merge_message['text'] += message['text'] 63 | 64 | # logger.i(f"reference:{reference}, {message}\r\n\r\n") 65 | del merge_message['udh_index'] 66 | del merge_message['udh'] 67 | results.append(merge_message) 68 | 69 | for result in results: 70 | logger.i(f"短信内容:{result}") 71 | self.one_sms_livedata.value = result 72 | threading.Thread(target=self.__delete_all_sms, args=(delete_index_list,)).start() 73 | 74 | def read_all_sms(self, content: str): 75 | threading.Thread(target=self.__read_all_sms, args=(content,)).start() 76 | 77 | def __cmti(self, index: str): 78 | """等待10s,让所有的短信都接收完毕,然后再读取短信""" 79 | time.sleep(2) 80 | 
self.send_read_sms_command(int(index)) 81 | 82 | def cmti(self, index: str): 83 | threading.Thread(target=self.__cmti, args=(index,)).start() 84 | 85 | def __read_one_sms(self, content: str): 86 | splits = content.split(" > ") 87 | if len(splits) == 4: 88 | read_id = int(splits[0].split("=")[1]) 89 | dpu = splits[2] 90 | decoded_sms = pdu_decoder.decodeSmsPdu(dpu) 91 | decoded_sms["read_id"] = read_id 92 | 93 | udh_info = decoded_sms["udh"][0] if "udh" in decoded_sms else None 94 | 95 | if udh_info is not None: 96 | 97 | decoded_sms["udh_index"] = udh_info.number 98 | 99 | if udh_info.reference not in self.received_sms or len(self.received_sms[udh_info.reference]) == 0: 100 | logger.i( 101 | f"收到一条长短信,这是开头第一条,reference:{udh_info.reference} 开始读取下一条:{read_id + 1}") 102 | self.received_sms[udh_info.reference] = [decoded_sms] 103 | self.send_read_sms_command(read_id + 1, delay=2) 104 | else: 105 | logger.i( 106 | f"收到一条长短信,这是其中一个片段,reference:{udh_info.reference}") 107 | self.received_sms[udh_info.reference].append(decoded_sms) 108 | 109 | if len(self.received_sms[udh_info.reference]) == udh_info.parts: 110 | """ 如果已经收到所有分割短信,那么合并 """ 111 | sorted_list = sorted(self.received_sms[udh_info.reference], key=lambda x: x['udh_index']) 112 | merge_message = None 113 | delete_index_list = [] 114 | for message in sorted_list: 115 | delete_index_list.append(message["read_id"]) 116 | if merge_message is None: 117 | merge_message = message 118 | else: 119 | merge_message['text'] += message['text'] 120 | del merge_message['udh_index'] 121 | del merge_message['udh'] 122 | self.received_sms[udh_info.reference] = [] 123 | logger.d(f"已经完整读取一条短信,合并完毕: {merge_message}") 124 | self.one_sms_livedata.value = merge_message 125 | 126 | # 开始删除所有分割短信 127 | threading.Thread(target=self.__delete_all_sms, args=(delete_index_list,)).start() 128 | else: 129 | logger.i(f"开始读取下一条:{read_id + 1}") 130 | self.send_read_sms_command(read_id + 1) 131 | else: 132 | logger.i(f"短信内容:{read_id}, {decoded_sms}") 
133 | """ 没有分割短信,直接显示 """ 134 | self.one_sms_livedata.value = decoded_sms 135 | self.delete_sms(read_id) 136 | 137 | def read_one_sms(self, content: str): 138 | threading.Thread(target=self.__read_one_sms, args=(content,)).start() 139 | 140 | def __delete_sms(self, index: int): 141 | fix_index = index % 5 142 | logger.e(f"发送AT+CMGD命令,删除短信:{index},delay: {fix_index}秒") 143 | self.at_command_helper.send_command_helper.write_at_command(f"AT+CMGD={index}", delay=fix_index) 144 | time.sleep(2) 145 | self.at_command_helper.send_command_helper.write_at_command('AT+CMGL=4') 146 | 147 | def delete_sms(self, index: int): 148 | threading.Thread(target=self.__delete_sms, args=(index,)).start() 149 | 150 | def __delete_all_sms(self, delete_index_list): 151 | logger.d(f"删除所有短信:{delete_index_list}") 152 | if len(delete_index_list) == 0: 153 | return 154 | for index in delete_index_list: 155 | self.__delete_sms(index) 156 | time.sleep(1) 157 | time.sleep(2) 158 | self.at_command_helper.send_command_helper.write_at_command('AT+CMGL=4') 159 | 160 | def send_read_sms_command(self, index: int, delay=1): 161 | with self.one_sms_read_lock: 162 | want_read_sms_command = "AT+CMGR=" + str(index) 163 | if self.at_command_helper.current_write_at_command() == want_read_sms_command: 164 | logger.e(f"正在读取短信中, 无需重复读取1: {want_read_sms_command}") 165 | return 166 | else: 167 | # 检查 one_sms 中是否已经读取到了 168 | for reference, message_list in self.received_sms.items(): 169 | for message in message_list: 170 | logger.e(f"message: {message}") 171 | if message["read_id"] is not None and message["read_id"] == index: 172 | logger.e(f"已经读取过了,无需重复读取2: {want_read_sms_command}") 173 | return 174 | self.at_command_helper.send_command_helper.write_at_command("AT+CMGR=" + str(index), delay=delay) 175 | 176 | def handle_sms(self, decode_string): 177 | if decode_string.startswith("+SMS FULL"): 178 | logger.e("短信存储区域已满") 179 | return True 180 | elif decode_string.startswith("+CMTI: \""): 181 | # +CMTI: "ME",5 182 | 
logger.e("收到新短信通知:" + decode_string) 183 | # threading.Thread(target=self.__cmti, args=(decode_string.split(",")[1],)).start() 184 | self.cmti(decode_string.split(",")[1]) 185 | return True 186 | return False 187 | 188 | def prepare(self): 189 | # 设置 UTF16 编码,更好兼容中英文和Emoji 190 | self.at_command_helper.send_command_helper.write_at_command('AT+CSCS="UCS2"') 191 | 192 | ''' 193 | AT+CMGF=0 是 PDU 模式 194 | AT+CMGF=1 是 TEXT 模式 195 | ''' 196 | self.at_command_helper.send_command_helper.write_at_command("AT+CMGF=0") 197 | 198 | ''' 查看短信存储情况''' 199 | self.at_command_helper.send_command_helper.write_at_command('AT+CMGL=?') 200 | 201 | ''' 设置短信存储区域为"ME", 之后读取短信时,才能读取出来 ''' 202 | self.at_command_helper.send_command_helper.write_at_command('AT+CPMS="ME","ME"') 203 | 204 | '''读取所有短信,包括已读和未读''' 205 | self.at_command_helper.send_command_helper.write_at_command('AT+CMGL=4') 206 | 207 | # """ 删除一条短信 """ 208 | # # self.at_command_helper.send_command_helper.write_at_command('AT+CMGD=1') 209 | # 210 | # """ 从第一条开始,删除所有短信 """ 211 | # self.at_command_helper.send_command_helper.write_at_command('AT+CMGD=1,4') 212 | 213 | # '''读取一条短信''' 214 | # self.at_command_helper.send_command_helper.write_at_command('AT+CMGR=0') 215 | -------------------------------------------------------------------------------- /VoiceCall.py: -------------------------------------------------------------------------------- 1 | class VoiceCall: 2 | VOICE_CALL_RING = "RING" 3 | VOICE_CALL_CLIP = "CLIP" 4 | VOICE_CALL_BEGIN = "CALL_BEGIN" 5 | VOICE_CALL_END = "CALL_END" 6 | VOICE_CALL_NO_CARRIER = "NO_CARRIER" 7 | VOICE_CALL_MISSED = "MISSED_CALL" 8 | VOICE_CALL_SAY_HELLO_DONE = "SAY_HELLO_DONE" 9 | 10 | def __init__(self, status, phone_number: str = None, ring_count: int = 0): 11 | self.status = status 12 | self.phone_number: str = phone_number 13 | self.ring_count: int = ring_count 14 | -------------------------------------------------------------------------------- /aliyun_asr.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | 4 | from nls.token import getToken 5 | import logger 6 | import nls 7 | 8 | from LiveData import LiveData 9 | from Config import Config 10 | 11 | 12 | class Asr: 13 | def __init__(self, tid): 14 | self.__th = None 15 | self.read_buffer = None 16 | self.__id = tid 17 | 18 | config = Config.get_instance() 19 | URL = config.get("service_url") 20 | APPKEY = config.get("app_key") 21 | ak_id = config.get("ak_id") 22 | ak_secret = config.get("ak_secret") 23 | 24 | TOKEN = getToken(ak_id, ak_secret) 25 | logger.i("ASR :获取到的token:{}".format(TOKEN)) 26 | 27 | self.sr = nls.NlsSpeechTranscriber( 28 | url=URL, 29 | token=TOKEN, 30 | appkey=APPKEY, 31 | on_sentence_begin=self.test_on_sentence_begin, 32 | on_sentence_end=self.test_on_sentence_end, 33 | on_start=self.test_on_start, 34 | on_result_changed=self.test_on_result_chg, 35 | on_completed=self.test_on_completed, 36 | on_error=self.test_on_error, 37 | on_close=self.test_on_close, 38 | callback_args=[self.__id] 39 | ) 40 | self.data_buffer = [] 41 | self.thread_start = False 42 | self.data_buffer_lock = threading.Lock() 43 | self.asr_result_livedata = LiveData() 44 | self.asr_status_livedata = LiveData() 45 | 46 | def start(self): 47 | self.thread_start = True 48 | self.__th = threading.Thread(target=self.__test_run) 49 | self.__th.start() 50 | self.read_buffer = threading.Thread(target=self.__read_buffer) 51 | self.read_buffer.start() 52 | self.asr_status_livedata.value = "start" 53 | 54 | def __read_buffer(self): 55 | while self.thread_start: 56 | if len(self.data_buffer) >= 640: 57 | with self.data_buffer_lock: 58 | data = self.data_buffer[:640] 59 | del self.data_buffer[:640] 60 | # logger.i("send audio data length:{}".format(len(data))) 61 | # logger.i( 62 | # "从音频Buffer中开始读取PCM数据,发送给阿里云实时转文字: length:{}".format(len(self.data_buffer))) 63 | self.sr.send_audio(data) 64 | time.sleep(0.01) 65 | else: 66 
| time.sleep(0.01) 67 | 68 | def send_audio(self, data: bytes): 69 | # 把data 中的数据放到data_buffer中 70 | with self.data_buffer_lock: 71 | self.data_buffer.extend(data) 72 | 73 | def stop(self): 74 | # self.sr.ctrl(ex={"test": "tttt"}) 75 | try: 76 | if self.thread_start: 77 | self.thread_start = False 78 | r = self.sr.stop() 79 | logger.e("ASR :{}: sr stopped:{}".format(self.__id, r)) 80 | time.sleep(1) 81 | except Exception as e: 82 | logger.e("ASR :sr stop error:{}".format(e)) 83 | self.__th = None 84 | self.read_buffer = None 85 | 86 | def test_on_sentence_begin(self, message, *args): 87 | logger.i("ASR :test_on_sentence_begin:{}".format(message)) 88 | self.asr_status_livedata.value = "begin" 89 | 90 | def test_on_sentence_end(self, message, *args): 91 | """ 92 | 一句话识别完毕 93 | """ 94 | logger.i("ASR :test_on_sentence_end:{}".format(message)) 95 | self.asr_result_livedata.value = message 96 | self.asr_status_livedata.value = "end" 97 | 98 | def test_on_start(self, message, *args): 99 | logger.i("ASR :test_on_start:{}".format(message)) 100 | self.asr_status_livedata.value = "start" 101 | 102 | def test_on_error(self, message, *args): 103 | logger.e("ASR :on_error args=>{}".format(args)) 104 | self.asr_status_livedata.value = "error" 105 | 106 | def test_on_close(self, *args): 107 | logger.i("ASR :on_close: args=>{}".format(args)) 108 | 109 | def test_on_result_chg(self, message, *args): 110 | logger.i("ASR :test_on_chg:{}".format(message)) 111 | 112 | def test_on_completed(self, message, *args): 113 | logger.i("ASR :on_completed:args=>{} message=>{}".format(args, message)) 114 | 115 | def __test_run(self): 116 | logger.d("ASR :thread:{} start..".format(self.__id)) 117 | 118 | self.sr.start(aformat="pcm", 119 | sample_rate=16000 / 2, 120 | enable_intermediate_result=True, 121 | enable_punctuation_prediction=True, 122 | enable_inverse_text_normalization=True) 123 | 124 | 125 | if __name__ == "__main__": 126 | nls.enableTrace(False) 127 | t = Asr("thread_1") 128 | 
t.start() 129 | -------------------------------------------------------------------------------- /aliyun_tts.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import nls 4 | from LiveData import LiveData 5 | from nls.token import getToken 6 | import logger 7 | from Config import Config 8 | 9 | 10 | class Tts: 11 | def __init__(self, tid): 12 | self.__id = tid 13 | self.tts_result_livedata = LiveData() 14 | self.tts_status_livedata = LiveData() 15 | self.is_call_stop = False 16 | 17 | def start(self, text): 18 | self.is_call_stop = False 19 | self.__text = text 20 | self.__th = threading.Thread(target=self.__test_run) 21 | self.__th.start() 22 | self.tts_status_livedata.value = "start" 23 | 24 | def stop(self): 25 | self.is_call_stop = True 26 | try: 27 | self.tts.shutdown() 28 | except Exception as e: 29 | logger.e("tts shutdown error:{}".format(e)) 30 | 31 | def test_on_metainfo(self, message, *args): 32 | logger.e("TTS :on_metainfo message=>{}".format(message)) 33 | 34 | def test_on_error(self, message, *args): 35 | logger.e("TTS :on_error args=>{}".format(args)) 36 | self.tts_status_livedata.value = "error" 37 | 38 | def test_on_close(self, *args): 39 | logger.e("TTS :on_close: args=>{}".format(args)) 40 | self.tts_status_livedata.value = "close" 41 | 42 | def test_on_data(self, data, *args): 43 | # logger.i("tts on_data, len:{}".format(len(data))) 44 | if self.is_call_stop: 45 | logger.e("TTS :通话已经停止,不需要把TTS的语音发送给设备") 46 | return 47 | self.tts_result_livedata.value = data 48 | self.tts_status_livedata.value = "data" 49 | 50 | def test_on_completed(self, message, *args): 51 | logger.e("TTS :on_completed:args=>{} message=>{}".format(args, message)) 52 | self.tts_status_livedata.value = "completed" 53 | 54 | def __test_run(self): 55 | if self.is_call_stop: 56 | logger.e("TTS :通话已经停止,不需要TTS") 57 | return 58 | 59 | config = Config.get_instance() 60 | URL = config.get("service_url") 61 | APPKEY = 
config.get("app_key") 62 | ak_id = config.get("ak_id") 63 | ak_secret = config.get("ak_secret") 64 | 65 | token = getToken(ak_id, ak_secret) 66 | 67 | logger.i("TTS :获取到的token:{}".format(token)) 68 | 69 | logger.i("TTS :thread:{} start..".format(self.__id)) 70 | self.tts = nls.NlsSpeechSynthesizer( 71 | url=URL, 72 | token=token, 73 | appkey=APPKEY, 74 | long_tts=False, 75 | on_metainfo=self.test_on_metainfo, 76 | on_data=self.test_on_data, 77 | on_completed=self.test_on_completed, 78 | on_error=self.test_on_error, 79 | on_close=self.test_on_close, 80 | callback_args=[self.__id] 81 | ) 82 | 83 | # https://ai.aliyun.com/nls/tts?spm=5176.11801677.help.50.2b523adddCLUlp 84 | r = self.tts.start(self.__text, sample_rate=8000, voice="zhiyuan", ex={'enable_subtitle': False}) 85 | logger.i("TTS :{}: tts done with result:{}".format(self.__id, r)) 86 | -------------------------------------------------------------------------------- /audio/say_hello.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/audio/say_hello.pcm -------------------------------------------------------------------------------- /audio_resource.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | 5 | def resource_path(relative_path): 6 | """ Get absolute path to resource, works for dev and for PyInstaller """ 7 | base_path = getattr(sys, '_MEIPASS', os.path.dirname(os.path.abspath(__file__))) 8 | return os.path.join(base_path, relative_path) 9 | 10 | 11 | def say_hello_pcm_file(): 12 | # 使用示例 13 | audio_file = resource_path('audio/say_hello.pcm') 14 | print(f"Audio file path: {audio_file}") 15 | return audio_file 16 | 17 | 18 | def config_file(): 19 | # 使用示例 20 | _config_file = resource_path('config.yaml') 21 | print(f"Config file path: {_config_file}") 22 | return _config_file 23 | 
-------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/config.yaml -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | def i(message): 5 | str_content = message.encode('utf-16', 'surrogatepass').decode('utf-16') 6 | _GREEN = "\033[32m" 7 | _RESET = "\033[0m" 8 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] 9 | print(f"{_GREEN}{current_time}: {str_content}{_RESET}") 10 | 11 | 12 | def d(message): 13 | str_content = message.encode('utf-16', 'surrogatepass').decode('utf-16') 14 | _YELLOW = "\033[33m" 15 | _RESET = "\033[0m" 16 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] 17 | print(f"{_YELLOW}{current_time}: {str_content}{_RESET}") 18 | 19 | 20 | def e(message): 21 | str_content = message.encode('utf-16', 'surrogatepass').decode('utf-16') 22 | _RED = "\033[31m" 23 | _RESET = "\033[0m" 24 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] 25 | print(f"{_RED}{current_time}: {str_content}{_RESET}") 26 | -------------------------------------------------------------------------------- /nls/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | from .logging import * 4 | from .speech_recognizer import * 5 | from .speech_transcriber import * 6 | from .speech_synthesizer import * 7 | from .stream_input_tts import * 8 | from .util import * 9 | from .version import __version__ 10 | -------------------------------------------------------------------------------- /nls/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | import threading 5 | 6 | from enum import Enum, unique 7 | from queue import Queue 8 | 9 | from . import logging, token, websocket 10 | from .exception import InvalidParameter, ConnectionTimeout, ConnectionUnavailable 11 | 12 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 13 | __HEADER__ = [ 14 | 'Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==', 15 | 'Sec-WebSocket-Version: 13', 16 | ] 17 | 18 | __FORMAT__ = '%(asctime)s - %(levelname)s - %(message)s' 19 | #__all__ = ['NlsCore'] 20 | 21 | def core_on_msg(ws, message, args): 22 | logging.debug('core_on_msg:{}'.format(message)) 23 | if not args: 24 | logging.error('callback core_on_msg with null args') 25 | return 26 | nls = args[0] 27 | nls._NlsCore__issue_callback('on_message', [message]) 28 | 29 | def core_on_error(ws, message, args): 30 | logging.debug('core_on_error:{}'.format(message)) 31 | if not args: 32 | logging.error('callback core_on_error with null args') 33 | return 34 | nls = args[0] 35 | nls._NlsCore__issue_callback('on_error', [message]) 36 | 37 | def core_on_close(ws, close_status_code, close_msg, args): 38 | logging.debug('core_on_close') 39 | if not args: 40 | logging.error('callback core_on_close with null args') 41 | return 42 | nls = args[0] 43 | nls._NlsCore__issue_callback('on_close') 44 | 45 | def core_on_open(ws, args): 46 | logging.debug('core_on_open:{}'.format(args)) 47 | if not args: 48 | logging.debug('callback with null args') 49 | ws.close() 50 | elif len(args) != 2: 51 | 
logging.debug('callback args not 2') 52 | ws.close() 53 | nls = args[0] 54 | nls._NlsCore__notify_on_open() 55 | nls.start(args[1], nls._NlsCore__ping_interval, nls._NlsCore__ping_timeout) 56 | nls._NlsCore__issue_callback('on_open') 57 | 58 | def core_on_data(ws, data, opcode, flag, args): 59 | logging.debug('core_on_data opcode={}'.format(opcode)) 60 | if not args: 61 | logging.error('callback core_on_data with null args') 62 | return 63 | nls = args[0] 64 | nls._NlsCore__issue_callback('on_data', [data, opcode, flag]) 65 | 66 | @unique 67 | class NlsConnectionStatus(Enum): 68 | Disconnected = 0 69 | Connected = 1 70 | 71 | 72 | class NlsCore: 73 | """ 74 | NlsCore 75 | """ 76 | def __init__(self, 77 | url=__URL__, 78 | token=None, 79 | on_open=None, on_message=None, on_close=None, 80 | on_error=None, on_data=None, asynch=False, callback_args=[]): 81 | self.__url = url 82 | self.__async = asynch 83 | if not token: 84 | raise InvalidParameter('Must provide a valid token!') 85 | else: 86 | self.__token = token 87 | self.__callbacks = {} 88 | if on_open: 89 | self.__callbacks['on_open'] = on_open 90 | if on_message: 91 | self.__callbacks['on_message'] = on_message 92 | if on_close: 93 | self.__callbacks['on_close'] = on_close 94 | if on_error: 95 | self.__callbacks['on_error'] = on_error 96 | if on_data: 97 | self.__callbacks['on_data'] = on_data 98 | if not on_open and not on_message and not on_close and not on_error: 99 | raise InvalidParameter('Must provide at least one callback') 100 | logging.debug('callback args:{}'.format(callback_args)) 101 | self.__callback_args = callback_args 102 | self.__header = __HEADER__ + ['X-NLS-Token: {}'.format(self.__token)] 103 | websocket.enableTrace(True) 104 | self.__ws = websocket.WebSocketApp(self.__url, 105 | self.__header, 106 | on_message=core_on_msg, 107 | on_data=core_on_data, 108 | on_error=core_on_error, 109 | on_close=core_on_close, 110 | callback_args=[self]) 111 | self.__ws.on_open = core_on_open 112 | self.__lock 
= threading.Lock() 113 | self.__cond = threading.Condition() 114 | self.__connection_status = NlsConnectionStatus.Disconnected 115 | 116 | def start(self, msg, ping_interval, ping_timeout): 117 | self.__lock.acquire() 118 | self.__ping_interval = ping_interval 119 | self.__ping_timeout = ping_timeout 120 | if self.__connection_status == NlsConnectionStatus.Disconnected: 121 | self.__ws.update_args(self, msg) 122 | self.__lock.release() 123 | self.__connect_before_start(ping_interval, ping_timeout) 124 | else: 125 | self.__lock.release() 126 | self.__ws.send(msg) 127 | 128 | def __notify_on_open(self): 129 | logging.debug('notify on open') 130 | with self.__cond: 131 | self.__connection_status = NlsConnectionStatus.Connected 132 | self.__cond.notify() 133 | 134 | def __issue_callback(self, which, exargs=[]): 135 | if which not in self.__callbacks: 136 | logging.error('no such callback:{}'.format(which)) 137 | return 138 | if which == 'on_close': 139 | with self.__cond: 140 | self.__connection_status = NlsConnectionStatus.Disconnected 141 | self.__cond.notify() 142 | args = exargs+self.__callback_args 143 | self.__callbacks[which](*args) 144 | 145 | def send(self, msg, binary): 146 | self.__lock.acquire() 147 | if self.__connection_status == NlsConnectionStatus.Disconnected: 148 | self.__lock.release() 149 | logging.error('start before send') 150 | raise ConnectionUnavailable('Must call start before send!') 151 | else: 152 | self.__lock.release() 153 | if binary: 154 | self.__ws.send(msg, opcode=websocket.ABNF.OPCODE_BINARY) 155 | else: 156 | logging.debug('send {}'.format(msg)) 157 | self.__ws.send(msg) 158 | 159 | def shutdown(self): 160 | self.__ws.close() 161 | 162 | def __run(self, ping_interval, ping_timeout): 163 | logging.debug('ws run...') 164 | self.__ws.run_forever(ping_interval=ping_interval, 165 | ping_timeout=ping_timeout) 166 | with self.__lock: 167 | self.__connection_status = NlsConnectionStatus.Disconnected 168 | logging.debug('ws exit...') 169 | 
170 | def __connect_before_start(self, ping_interval, ping_timeout): 171 | with self.__cond: 172 | self.__th = threading.Thread(target=self.__run, 173 | args=[ping_interval, ping_timeout]) 174 | self.__th.start() 175 | if self.__connection_status == NlsConnectionStatus.Disconnected: 176 | logging.debug('wait cond wakeup') 177 | if not self.__async: 178 | if self.__cond.wait(timeout=10): 179 | logging.debug('wakeup without timeout') 180 | return self.__connection_status == NlsConnectionStatus.Connected 181 | else: 182 | logging.debug('wakeup with timeout') 183 | raise ConnectionTimeout('Wait response timeout! Please check local network!') 184 | -------------------------------------------------------------------------------- /nls/exception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | 4 | class InvalidParameter(Exception): 5 | pass 6 | 7 | # Token 8 | class GetTokenFailed(Exception): 9 | pass 10 | 11 | # Connection 12 | class ConnectionTimeout(Exception): 13 | pass 14 | 15 | class ConnectionUnavailable(Exception): 16 | pass 17 | 18 | class StartTimeoutException(Exception): 19 | pass 20 | 21 | class StopTimeoutException(Exception): 22 | pass 23 | 24 | class NotStartException(Exception): 25 | pass 26 | 27 | class CompleteTimeoutException(Exception): 28 | pass 29 | 30 | class WrongStateException(Exception): 31 | pass -------------------------------------------------------------------------------- /nls/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import logging 4 | 5 | _logger = logging.getLogger('nls') 6 | 7 | try: 8 | from logging import NullHandler 9 | except ImportError: 10 | class NullHandler(logging.Handler): 11 | def emit(self, record): 12 | pass 13 | 14 | _logger.addHandler(NullHandler()) 15 | _traceEnabled = False 16 | __LOG_FORMAT__ = '%(asctime)s - %(levelname)s - %(message)s' 17 | 18 | __all__=['enableTrace', 'dump', 'error', 'warning', 'debug', 'trace', 19 | 'isEnabledForError', 'isEnabledForDebug', 'isEnabledForTrace'] 20 | 21 | def enableTrace(traceable, handler=logging.StreamHandler()): 22 | """ 23 | enable log print 24 | 25 | Parameters 26 | ---------- 27 | traceable: bool 28 | whether enable log print, default log level is logging.DEBUG 29 | handler: Handler object 30 | handle how to print out log, default to stdio 31 | """ 32 | global _traceEnabled 33 | _traceEnabled = traceable 34 | if traceable: 35 | _logger.addHandler(handler) 36 | _logger.setLevel(logging.DEBUG) 37 | handler.setFormatter(logging.Formatter(__LOG_FORMAT__)) 38 | 39 | def dump(title, message): 40 | if _traceEnabled: 41 | _logger.debug('### ' + title + ' ###') 42 | _logger.debug(message) 43 | _logger.debug('########################################') 44 | 45 | def error(msg): 46 | _logger.error(msg) 47 | 48 | def warning(msg): 49 | _logger.warning(msg) 50 | 51 | def debug(msg): 52 | _logger.debug(msg) 53 | 54 | def trace(msg): 55 | if _traceEnabled: 56 | _logger.debug(msg) 57 | 58 | def isEnabledForError(): 59 | return _logger.isEnabledFor(logging.ERROR) 60 | 61 | def isEnabledForDebug(): 62 | return _logger.isEnabledFor(logging.Debug) 63 | 64 | def isEnabledForTrace(): 65 | return _traceEnabled 66 | -------------------------------------------------------------------------------- /nls/speech_recognizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | 8 | 9 | from nls.core import NlsCore 10 | from . import logging 11 | from . import util 12 | from .exception import (StartTimeoutException, 13 | StopTimeoutException, 14 | NotStartException, 15 | InvalidParameter) 16 | 17 | __SPEECH_RECOGNIZER_NAMESPACE__ = 'SpeechRecognizer' 18 | 19 | __SPEECH_RECOGNIZER_REQUEST_CMD__ = { 20 | 'start': 'StartRecognition', 21 | 'stop': 'StopRecognition' 22 | } 23 | 24 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 25 | 26 | __all__ = ['NlsSpeechRecognizer'] 27 | 28 | 29 | class NlsSpeechRecognizer: 30 | """ 31 | Api for short sentence speech recognition 32 | """ 33 | def __init__(self, 34 | url=__URL__, 35 | token=None, 36 | appkey=None, 37 | on_start=None, 38 | on_result_changed=None, 39 | on_completed=None, 40 | on_error=None, on_close=None, 41 | callback_args=[]): 42 | """ 43 | NlsSpeechRecognizer initialization 44 | 45 | Parameters: 46 | ----------- 47 | url: str 48 | websocket url. 49 | token: str 50 | access token. if you do not have a token, provide access id and key 51 | secret from your aliyun account. 52 | appkey: str 53 | appkey from aliyun 54 | on_start: function 55 | Callback object which is called when recognition started. 56 | on_start has two arguments. 57 | The 1st argument is message which is a json format string. 58 | The 2nd argument is *args which is callback_args. 59 | on_result_changed: function 60 | Callback object which is called when partial recognition result 61 | arrived. 62 | on_result_changed has two arguments. 63 | The 1st argument is message which is a json format string. 64 | The 2nd argument is *args which is callback_args. 65 | on_completed: function 66 | Callback object which is called when recognition is completed. 67 | on_completed has two arguments. 68 | The 1st argument is message which is a json format string. 69 | The 2nd argument is *args which is callback_args. 
70 | on_error: function 71 | Callback object which is called when any error occurs. 72 | on_error has two arguments. 73 | The 1st argument is message which is a json format string. 74 | The 2nd argument is *args which is callback_args. 75 | on_close: function 76 | Callback object which is called when connection closed. 77 | on_close has one arguments. 78 | The 1st argument is *args which is callback_args. 79 | callback_args: list 80 | callback_args will return in callbacks above for *args. 81 | """ 82 | if not token or not appkey: 83 | raise InvalidParameter('Must provide token and appkey') 84 | self.__response_handler__ = { 85 | 'RecognitionStarted': self.__recognition_started, 86 | 'RecognitionResultChanged': self.__recognition_result_changed, 87 | 'RecognitionCompleted': self.__recognition_completed, 88 | 'TaskFailed': self.__task_failed 89 | } 90 | self.__callback_args = callback_args 91 | self.__appkey = appkey 92 | self.__url = url 93 | self.__token = token 94 | self.__start_cond = threading.Condition() 95 | self.__start_flag = False 96 | self.__on_start = on_start 97 | self.__on_result_changed = on_result_changed 98 | self.__on_completed = on_completed 99 | self.__on_error = on_error 100 | self.__on_close = on_close 101 | self.__allow_aformat = ( 102 | 'pcm', 'opus', 'opu', 'wav', 'mp3', 'speex', 'aac', 'amr' 103 | ) 104 | 105 | def __handle_message(self, message): 106 | logging.debug('__handle_message') 107 | try: 108 | __result = json.loads(message) 109 | if __result['header']['name'] in self.__response_handler__: 110 | __handler = self.__response_handler__[ 111 | __result['header']['name']] 112 | __handler(message) 113 | else: 114 | logging.error('cannot handle cmd{}'.format( 115 | __result['header']['name'])) 116 | return 117 | except json.JSONDecodeError: 118 | logging.error('cannot parse message:{}'.format(message)) 119 | return 120 | 121 | def __sr_core_on_open(self): 122 | logging.debug('__sr_core_on_open') 123 | 124 | def __sr_core_on_msg(self, msg, 
*args): 125 | logging.debug('__sr_core_on_msg:msg={} args={}'.format(msg, args)) 126 | self.__handle_message(msg) 127 | 128 | def __sr_core_on_error(self, msg, *args): 129 | logging.debug('__sr_core_on_error:msg={} args={}'.format(msg, args)) 130 | 131 | def __sr_core_on_close(self): 132 | logging.debug('__sr_core_on_close') 133 | if self.__on_close: 134 | self.__on_close(*self.__callback_args) 135 | with self.__start_cond: 136 | self.__start_flag = False 137 | self.__start_cond.notify() 138 | 139 | def __recognition_started(self, message): 140 | logging.debug('__recognition_started') 141 | if self.__on_start: 142 | self.__on_start(message, *self.__callback_args) 143 | with self.__start_cond: 144 | self.__start_flag = True 145 | self.__start_cond.notify() 146 | 147 | def __recognition_result_changed(self, message): 148 | logging.debug('__recognition_result_changed') 149 | if self.__on_result_changed: 150 | self.__on_result_changed(message, *self.__callback_args) 151 | 152 | def __recognition_completed(self, message): 153 | logging.debug('__recognition_completed') 154 | self.__nls.shutdown() 155 | logging.debug('__recognition_completed shutdown done') 156 | if self.__on_completed: 157 | self.__on_completed(message, *self.__callback_args) 158 | with self.__start_cond: 159 | self.__start_flag = False 160 | self.__start_cond.notify() 161 | 162 | def __task_failed(self, message): 163 | logging.debug('__task_failed') 164 | with self.__start_cond: 165 | self.__start_flag = False 166 | self.__start_cond.notify() 167 | if self.__on_error: 168 | self.__on_error(message, *self.__callback_args) 169 | 170 | def start(self, aformat='pcm', sample_rate=16000, ch=1, 171 | enable_intermediate_result=False, 172 | enable_punctuation_prediction=False, 173 | enable_inverse_text_normalization=False, 174 | timeout=10, 175 | ping_interval=8, 176 | ping_timeout=None, 177 | ex:dict=None): 178 | """ 179 | Recognition start 180 | 181 | Parameters: 182 | ----------- 183 | aformat: str 184 | 
audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm' 185 | sample_rate: int 186 | audio sample rate, default is 16000 187 | ch: int 188 | audio channels, only support mono which is 1 189 | enable_intermediate_result: bool 190 | whether enable return intermediate recognition result, default is False 191 | enable_punctuation_prediction: bool 192 | whether enable punctuation prediction, default is False 193 | enable_inverse_text_normalization: bool 194 | whether enable ITN, default is False 195 | timeout: int 196 | wait timeout for connection setup 197 | ping_interval: int 198 | send ping interval, 0 for disable ping send, default is 8 199 | ping_timeout: int 200 | timeout after send ping and recive pong, set None for disable timeout check and default is None 201 | ex: dict 202 | dict which will merge into 'payload' field in request 203 | """ 204 | self.__nls = NlsCore( 205 | url=self.__url, 206 | token=self.__token, 207 | on_open=self.__sr_core_on_open, 208 | on_message=self.__sr_core_on_msg, 209 | on_close=self.__sr_core_on_close, 210 | on_error=self.__sr_core_on_error, 211 | callback_args=[]) 212 | 213 | if ch != 1: 214 | raise InvalidParameter(f'Not support channel {ch}') 215 | if aformat not in self.__allow_aformat: 216 | raise InvalidParameter(f'Format {aformat} not support') 217 | 218 | __id4 = uuid.uuid4().hex 219 | self.__task_id = uuid.uuid4().hex 220 | __header = { 221 | 'message_id': __id4, 222 | 'task_id': self.__task_id, 223 | 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__, 224 | 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['start'], 225 | 'appkey': self.__appkey 226 | } 227 | __payload = { 228 | 'format': aformat, 229 | 'sample_rate': sample_rate, 230 | 'enable_intermediate_result': enable_intermediate_result, 231 | 'enable_punctuation_prediction': enable_punctuation_prediction, 232 | 'enable_inverse_text_normalization': enable_inverse_text_normalization 233 | } 234 | 235 | if ex: 236 | __payload.update(ex) 237 | 238 | __msg = { 239 | 
'header': __header, 240 | 'payload': __payload, 241 | 'context': util.GetDefaultContext() 242 | } 243 | __jmsg = json.dumps(__msg) 244 | with self.__start_cond: 245 | if self.__start_flag: 246 | logging.debug('already start...') 247 | return 248 | self.__nls.start(__jmsg, ping_interval, ping_timeout) 249 | if self.__start_flag == False: 250 | if self.__start_cond.wait(timeout=timeout): 251 | return 252 | else: 253 | raise StartTimeoutException(f'Waiting Start over {timeout}s') 254 | 255 | def stop(self, timeout=10): 256 | """ 257 | Stop recognition and mark session finished 258 | 259 | Parameters: 260 | ----------- 261 | timeout: int 262 | timeout for waiting completed message from cloud 263 | """ 264 | __id4 = uuid.uuid4().hex 265 | __header = { 266 | 'message_id': __id4, 267 | 'task_id': self.__task_id, 268 | 'namespace': __SPEECH_RECOGNIZER_NAMESPACE__, 269 | 'name': __SPEECH_RECOGNIZER_REQUEST_CMD__['stop'], 270 | 'appkey': self.__appkey 271 | } 272 | __msg = { 273 | 'header': __header, 274 | 'context': util.GetDefaultContext() 275 | } 276 | __jmsg = json.dumps(__msg) 277 | with self.__start_cond: 278 | if not self.__start_flag: 279 | logging.debug('not start yet...') 280 | return 281 | self.__nls.send(__jmsg, False) 282 | if self.__start_flag == True: 283 | logging.debug('stop wait..') 284 | if self.__start_cond.wait(timeout): 285 | return 286 | else: 287 | raise StopTimeoutException(f'Waiting stop over {timeout}s') 288 | def shutdown(self): 289 | """ 290 | Shutdown connection immediately 291 | """ 292 | self.__nls.shutdown() 293 | 294 | def send_audio(self, pcm_data): 295 | """ 296 | Send audio binary, audio size prefer 20ms length 297 | 298 | Parameters: 299 | ----------- 300 | pcm_data: bytes 301 | audio binary which format is 'aformat' in start method 302 | """ 303 | if not pcm_data: 304 | raise InvalidParameter('data empty!') 305 | __data = pcm_data 306 | with self.__start_cond: 307 | if not self.__start_flag: 308 | raise NotStartException('Need start 
before send!') 309 | try: 310 | self.__nls.send(__data, True) 311 | except ConnectionResetError as __e: 312 | logging.error('connection reset') 313 | self.__start_flag = False 314 | self.__nls.shutdown() 315 | raise __e 316 | -------------------------------------------------------------------------------- /nls/speech_synthesizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | from re import I 5 | import uuid 6 | import json 7 | import threading 8 | 9 | from nls.core import NlsCore 10 | from . import logging 11 | from . import util 12 | from .exception import (StartTimeoutException, 13 | CompleteTimeoutException, 14 | InvalidParameter) 15 | 16 | __SPEECH_SYNTHESIZER_NAMESPACE__ = 'SpeechSynthesizer' 17 | __SPEECH_LONG_SYNTHESIZER_NAMESPACE__ = 'SpeechLongSynthesizer' 18 | 19 | __SPEECH_SYNTHESIZER_REQUEST_CMD__ = { 20 | 'start': 'StartSynthesis' 21 | } 22 | 23 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 24 | 25 | __all__ = ['NlsSpeechSynthesizer'] 26 | 27 | 28 | class NlsSpeechSynthesizer: 29 | """ 30 | Api for text-to-speech 31 | """ 32 | def __init__(self, 33 | url=__URL__, 34 | token=None, 35 | appkey=None, 36 | long_tts=False, 37 | on_metainfo=None, 38 | on_data=None, 39 | on_completed=None, 40 | on_error=None, 41 | on_close=None, 42 | callback_args=[]): 43 | """ 44 | NlsSpeechSynthesizer initialization 45 | 46 | Parameters: 47 | ----------- 48 | url: str 49 | websocket url. 50 | akid: str 51 | access id from aliyun. if you provide a token, ignore this argument. 52 | appkey: str 53 | appkey from aliyun 54 | long_tts: bool 55 | whether using long-text synthesis support, default is False. long-text synthesis 56 | can support longer text but more expensive. 57 | on_metainfo: function 58 | Callback object which is called when recognition started. 59 | on_start has two arguments. 
60 | The 1st argument is message which is a json format string. 61 | The 2nd argument is *args which is callback_args. 62 | on_data: function 63 | Callback object which is called when partial synthesis result arrived 64 | arrived. 65 | on_result_changed has two arguments. 66 | The 1st argument is binary data corresponding to aformat in start 67 | method. 68 | The 2nd argument is *args which is callback_args. 69 | on_completed: function 70 | Callback object which is called when recognition is completed. 71 | on_completed has two arguments. 72 | The 1st argument is message which is a json format string. 73 | The 2nd argument is *args which is callback_args. 74 | on_error: function 75 | Callback object which is called when any error occurs. 76 | on_error has two arguments. 77 | The 1st argument is message which is a json format string. 78 | The 2nd argument is *args which is callback_args. 79 | on_close: function 80 | Callback object which is called when connection closed. 81 | on_close has one arguments. 82 | The 1st argument is *args which is callback_args. 83 | callback_args: list 84 | callback_args will return in callbacks above for *args. 
85 | """ 86 | if not token or not appkey: 87 | raise InvalidParameter('Must provide token and appkey') 88 | self.__response_handler__ = { 89 | 'MetaInfo': self.__metainfo, 90 | 'SynthesisCompleted': self.__synthesis_completed, 91 | 'TaskFailed': self.__task_failed 92 | } 93 | self.__callback_args = callback_args 94 | self.__url = url 95 | self.__appkey = appkey 96 | self.__token = token 97 | self.__long_tts = long_tts 98 | self.__start_cond = threading.Condition() 99 | self.__start_flag = False 100 | self.__on_metainfo = on_metainfo 101 | self.__on_data = on_data 102 | self.__on_completed = on_completed 103 | self.__on_error = on_error 104 | self.__on_close = on_close 105 | self.__allow_aformat = ( 106 | 'pcm', 'wav', 'mp3' 107 | ) 108 | self.__allow_sample_rate = ( 109 | 8000, 11025, 16000, 22050, 110 | 24000, 32000, 44100, 48000 111 | ) 112 | 113 | def __handle_message(self, message): 114 | logging.debug('__handle_message') 115 | try: 116 | __result = json.loads(message) 117 | if __result['header']['name'] in self.__response_handler__: 118 | __handler = self.__response_handler__[__result['header']['name']] 119 | __handler(message) 120 | else: 121 | logging.error('cannot handle cmd{}'.format( 122 | __result['header']['name'])) 123 | return 124 | except json.JSONDecodeError: 125 | logging.error('cannot parse message:{}'.format(message)) 126 | return 127 | 128 | def __syn_core_on_open(self): 129 | logging.debug('__syn_core_on_open') 130 | with self.__start_cond: 131 | self.__start_flag = True 132 | self.__start_cond.notify() 133 | 134 | def __syn_core_on_data(self, data, opcode, flag): 135 | logging.debug('__syn_core_on_data') 136 | if self.__on_data: 137 | self.__on_data(data, *self.__callback_args) 138 | 139 | def __syn_core_on_msg(self, msg, *args): 140 | logging.debug('__syn_core_on_msg:msg={} args={}'.format(msg, args)) 141 | self.__handle_message(msg) 142 | 143 | def __syn_core_on_error(self, msg, *args): 144 | logging.debug('__sr_core_on_error:msg={} 
args={}'.format(msg, args)) 145 | 146 | def __syn_core_on_close(self): 147 | logging.debug('__sr_core_on_close') 148 | if self.__on_close: 149 | self.__on_close(*self.__callback_args) 150 | with self.__start_cond: 151 | self.__start_flag = False 152 | self.__start_cond.notify() 153 | 154 | def __metainfo(self, message): 155 | logging.debug('__metainfo') 156 | if self.__on_metainfo: 157 | self.__on_metainfo(message, *self.__callback_args) 158 | 159 | def __synthesis_completed(self, message): 160 | logging.debug('__synthesis_completed') 161 | self.__nls.shutdown() 162 | logging.debug('__synthesis_completed shutdown done') 163 | if self.__on_completed: 164 | self.__on_completed(message, *self.__callback_args) 165 | with self.__start_cond: 166 | self.__start_flag = False 167 | self.__start_cond.notify() 168 | 169 | def __task_failed(self, message): 170 | logging.debug('__task_failed') 171 | with self.__start_cond: 172 | self.__start_flag = False 173 | self.__start_cond.notify() 174 | if self.__on_error: 175 | self.__on_error(message, *self.__callback_args) 176 | 177 | def start(self, 178 | text=None, 179 | voice='xiaoyun', 180 | aformat='pcm', 181 | sample_rate=16000, 182 | volume=50, 183 | speech_rate=0, 184 | pitch_rate=0, 185 | wait_complete=True, 186 | start_timeout=10, 187 | completed_timeout=60, 188 | ex:dict=None): 189 | """ 190 | Synthesis start 191 | 192 | Parameters: 193 | ----------- 194 | text: str 195 | utf-8 text 196 | voice: str 197 | voice for text-to-speech, default is xiaoyun 198 | aformat: str 199 | audio binary format, support: 'pcm', 'wav', 'mp3', default is 'pcm' 200 | sample_rate: int 201 | audio sample rate, default is 16000, support:8000, 11025, 16000, 22050, 202 | 24000, 32000, 44100, 48000 203 | volume: int 204 | audio volume, from 0~100, default is 50 205 | speech_rate: int 206 | speech rate from -500~500, default is 0 207 | pitch_rate: int 208 | pitch for voice from -500~500, default is 0 209 | wait_complete: bool 210 | whether block until 
syntheis completed or timeout for completed timeout 211 | start_timeout: int 212 | timeout for connection established 213 | completed_timeout: int 214 | timeout for waiting synthesis completed from connection established 215 | ex: dict 216 | dict which will merge into 'payload' field in request 217 | """ 218 | if text is None: 219 | raise InvalidParameter('Text cannot be None') 220 | 221 | self.__nls = NlsCore( 222 | url=self.__url, 223 | token=self.__token, 224 | on_open=self.__syn_core_on_open, 225 | on_message=self.__syn_core_on_msg, 226 | on_data=self.__syn_core_on_data, 227 | on_close=self.__syn_core_on_close, 228 | on_error=self.__syn_core_on_error, 229 | callback_args=[]) 230 | 231 | if aformat not in self.__allow_aformat: 232 | raise InvalidParameter('format {} not support'.format(aformat)) 233 | if sample_rate not in self.__allow_sample_rate: 234 | raise InvalidParameter('samplerate {} not support'.format(sample_rate)) 235 | if volume < 0 or volume > 100: 236 | raise InvalidParameter('volume {} not support'.format(volume)) 237 | if speech_rate < -500 or speech_rate > 500: 238 | raise InvalidParameter('speech_rate {} not support'.format(speech_rate)) 239 | if pitch_rate < -500 or pitch_rate > 500: 240 | raise InvalidParameter('pitch rate {} not support'.format(pitch_rate)) 241 | 242 | __id4 = uuid.uuid4().hex 243 | self.__task_id = uuid.uuid4().hex 244 | __namespace = __SPEECH_SYNTHESIZER_NAMESPACE__ 245 | if self.__long_tts: 246 | __namespace = __SPEECH_LONG_SYNTHESIZER_NAMESPACE__ 247 | __header = { 248 | 'message_id': __id4, 249 | 'task_id': self.__task_id, 250 | 'namespace': __namespace, 251 | 'name': __SPEECH_SYNTHESIZER_REQUEST_CMD__['start'], 252 | 'appkey': self.__appkey 253 | } 254 | __payload = { 255 | 'text': text, 256 | 'voice': voice, 257 | 'format': aformat, 258 | 'sample_rate': sample_rate, 259 | 'volume': volume, 260 | 'speech_rate': speech_rate, 261 | 'pitch_rate': pitch_rate 262 | } 263 | if ex: 264 | __payload.update(ex) 265 | __msg = { 
266 | 'header': __header, 267 | 'payload': __payload, 268 | 'context': util.GetDefaultContext() 269 | } 270 | __jmsg = json.dumps(__msg) 271 | with self.__start_cond: 272 | if self.__start_flag: 273 | logging.debug('already start...') 274 | return 275 | self.__nls.start(__jmsg, ping_interval=0, ping_timeout=None) 276 | if self.__start_flag == False: 277 | if not self.__start_cond.wait(start_timeout): 278 | logging.debug('syn start timeout') 279 | raise StartTimeoutException(f'Waiting Start over {start_timeout}s') 280 | if self.__start_flag and wait_complete: 281 | if not self.__start_cond.wait(completed_timeout): 282 | raise CompleteTimeoutException(f'Waiting Complete over {completed_timeout}s') 283 | 284 | def shutdown(self): 285 | """ 286 | Shutdown connection immediately 287 | """ 288 | self.__nls.shutdown() 289 | -------------------------------------------------------------------------------- /nls/speech_transcriber.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | 8 | from nls.core import NlsCore 9 | from . import logging 10 | from . 
import util 11 | from nls.exception import (StartTimeoutException, 12 | StopTimeoutException, 13 | NotStartException, 14 | InvalidParameter) 15 | 16 | __SPEECH_TRANSCRIBER_NAMESPACE__ = 'SpeechTranscriber' 17 | 18 | __SPEECH_TRANSCRIBER_REQUEST_CMD__ = { 19 | 'start': 'StartTranscription', 20 | 'stop': 'StopTranscription', 21 | 'control': 'ControlTranscriber' 22 | } 23 | 24 | __URL__ = 'wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1' 25 | __all__ = ['NlsSpeechTranscriber'] 26 | 27 | 28 | class NlsSpeechTranscriber: 29 | """ 30 | Api for realtime speech transcription 31 | """ 32 | 33 | def __init__(self, 34 | url=__URL__, 35 | token=None, 36 | appkey=None, 37 | on_start=None, 38 | on_sentence_begin=None, 39 | on_sentence_end=None, 40 | on_result_changed=None, 41 | on_completed=None, 42 | on_error=None, 43 | on_close=None, 44 | callback_args=[]): 45 | ''' 46 | NlsSpeechTranscriber initialization 47 | 48 | Parameters: 49 | ----------- 50 | url: str 51 | websocket url. 52 | token: str 53 | access token. if you do not have a token, provide access id and key 54 | secret from your aliyun account. 55 | appkey: str 56 | appkey from aliyun 57 | on_start: function 58 | Callback object which is called when recognition started. 59 | on_start has two arguments. 60 | The 1st argument is message which is a json format string. 61 | The 2nd argument is *args which is callback_args. 62 | on_sentence_begin: function 63 | Callback object which is called when one sentence started. 64 | on_sentence_begin has two arguments. 65 | The 1st argument is message which is a json format string. 66 | The 2nd argument is *args which is callback_args. 67 | on_sentence_end: function 68 | Callback object which is called when sentence is end. 69 | on_sentence_end has two arguments. 70 | The 1st argument is message which is a json format string. 71 | The 2nd argument is *args which is callback_args. 
72 | on_result_changed: function 73 | Callback object which is called when partial recognition result 74 | arrived. 75 | on_result_changed has two arguments. 76 | The 1st argument is message which is a json format string. 77 | The 2nd argument is *args which is callback_args. 78 | on_completed: function 79 | Callback object which is called when recognition is completed. 80 | on_completed has two arguments. 81 | The 1st argument is message which is a json format string. 82 | The 2nd argument is *args which is callback_args. 83 | on_error: function 84 | Callback object which is called when any error occurs. 85 | on_error has two arguments. 86 | The 1st argument is message which is a json format string. 87 | The 2nd argument is *args which is callback_args. 88 | on_close: function 89 | Callback object which is called when connection closed. 90 | on_close has one arguments. 91 | The 1st argument is *args which is callback_args. 92 | callback_args: list 93 | callback_args will return in callbacks above for *args. 
94 | ''' 95 | if not token or not appkey: 96 | raise InvalidParameter('Must provide token and appkey') 97 | self.__response_handler__ = { 98 | 'SentenceBegin': self.__sentence_begin, 99 | 'SentenceEnd': self.__sentence_end, 100 | 'TranscriptionStarted': self.__transcription_started, 101 | 'TranscriptionResultChanged': self.__transcription_result_changed, 102 | 'TranscriptionCompleted': self.__transcription_completed, 103 | 'TaskFailed': self.__task_failed 104 | } 105 | self.__callback_args = callback_args 106 | self.__url = url 107 | self.__appkey = appkey 108 | self.__token = token 109 | self.__start_cond = threading.Condition() 110 | self.__start_flag = False 111 | self.__on_start = on_start 112 | self.__on_sentence_begin = on_sentence_begin 113 | self.__on_sentence_end = on_sentence_end 114 | self.__on_result_changed = on_result_changed 115 | self.__on_completed = on_completed 116 | self.__on_error = on_error 117 | self.__on_close = on_close 118 | self.__allow_aformat = ( 119 | 'pcm', 'opus', 'opu', 'wav', 'amr', 'speex', 'mp3', 'aac' 120 | ) 121 | 122 | def __handle_message(self, message): 123 | logging.debug('__handle_message') 124 | try: 125 | __result = json.loads(message) 126 | if __result['header']['name'] in self.__response_handler__: 127 | __handler = self.__response_handler__[ 128 | __result['header']['name']] 129 | __handler(message) 130 | else: 131 | logging.error('cannot handle cmd{}'.format( 132 | __result['header']['name'])) 133 | return 134 | except json.JSONDecodeError: 135 | logging.error('cannot parse message:{}'.format(message)) 136 | return 137 | 138 | def __tr_core_on_open(self): 139 | logging.debug('__tr_core_on_open') 140 | 141 | def __tr_core_on_msg(self, msg, *args): 142 | logging.debug('__tr_core_on_msg:msg={} args={}'.format(msg, args)) 143 | self.__handle_message(msg) 144 | 145 | def __tr_core_on_error(self, msg, *args): 146 | logging.debug('__tr_core_on_error:msg={} args={}'.format(msg, args)) 147 | 148 | def 
__tr_core_on_close(self): 149 | logging.debug('__tr_core_on_close') 150 | if self.__on_close: 151 | self.__on_close(*self.__callback_args) 152 | with self.__start_cond: 153 | self.__start_flag = False 154 | self.__start_cond.notify() 155 | 156 | def __sentence_begin(self, message): 157 | logging.debug('__sentence_begin') 158 | if self.__on_sentence_begin: 159 | self.__on_sentence_begin(message, *self.__callback_args) 160 | 161 | def __sentence_end(self, message): 162 | logging.debug('__sentence_end') 163 | if self.__on_sentence_end: 164 | self.__on_sentence_end(message, *self.__callback_args) 165 | 166 | def __transcription_started(self, message): 167 | logging.debug('__transcription_started') 168 | if self.__on_start: 169 | self.__on_start(message, *self.__callback_args) 170 | with self.__start_cond: 171 | self.__start_flag = True 172 | self.__start_cond.notify() 173 | 174 | def __transcription_result_changed(self, message): 175 | logging.debug('__transcription_result_changed') 176 | if self.__on_result_changed: 177 | self.__on_result_changed(message, *self.__callback_args) 178 | 179 | def __transcription_completed(self, message): 180 | logging.debug('__transcription_completed') 181 | self.__nls.shutdown() 182 | logging.debug('__transcription_completed shutdown done') 183 | if self.__on_completed: 184 | self.__on_completed(message, *self.__callback_args) 185 | with self.__start_cond: 186 | self.__start_flag = False 187 | self.__start_cond.notify() 188 | 189 | def __task_failed(self, message): 190 | logging.debug('__task_failed') 191 | with self.__start_cond: 192 | self.__start_flag = False 193 | self.__start_cond.notify() 194 | if self.__on_error: 195 | self.__on_error(message, *self.__callback_args) 196 | 197 | def start(self, aformat='pcm', sample_rate=16000, ch=1, 198 | enable_intermediate_result=False, 199 | enable_punctuation_prediction=False, 200 | enable_inverse_text_normalization=False, 201 | timeout=10, 202 | ping_interval=8, 203 | ping_timeout=None, 204 
| ex:dict=None): 205 | """ 206 | Transcription start 207 | 208 | Parameters: 209 | ----------- 210 | aformat: str 211 | audio binary format, support: 'pcm', 'opu', 'opus', default is 'pcm' 212 | sample_rate: int 213 | audio sample rate, default is 16000 214 | ch: int 215 | audio channels, only support mono which is 1 216 | enable_intermediate_result: bool 217 | whether enable return intermediate recognition result, default is False 218 | enable_punctuation_prediction: bool 219 | whether enable punctuation prediction, default is False 220 | enable_inverse_text_normalization: bool 221 | whether enable ITN, default is False 222 | timeout: int 223 | wait timeout for connection setup 224 | ping_interval: int 225 | send ping interval, 0 for disable ping send, default is 8 226 | ping_timeout: int 227 | timeout after send ping and recive pong, set None for disable timeout check and default is None 228 | ex: dict 229 | dict which will merge into 'payload' field in request 230 | """ 231 | self.__nls = NlsCore( 232 | url=self.__url, 233 | token=self.__token, 234 | on_open=self.__tr_core_on_open, 235 | on_message=self.__tr_core_on_msg, 236 | on_close=self.__tr_core_on_close, 237 | on_error=self.__tr_core_on_error, 238 | callback_args=[]) 239 | 240 | if ch != 1: 241 | raise ValueError('not support channel: {}'.format(ch)) 242 | if aformat not in self.__allow_aformat: 243 | raise ValueError('format {} not support'.format(aformat)) 244 | __id4 = uuid.uuid4().hex 245 | self.__task_id = uuid.uuid4().hex 246 | __header = { 247 | 'message_id': __id4, 248 | 'task_id': self.__task_id, 249 | 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__, 250 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['start'], 251 | 'appkey': self.__appkey 252 | } 253 | __payload = { 254 | 'format': aformat, 255 | 'sample_rate': sample_rate, 256 | 'enable_intermediate_result': enable_intermediate_result, 257 | 'enable_punctuation_prediction': enable_punctuation_prediction, 258 | 'enable_inverse_text_normalization': 
enable_inverse_text_normalization 259 | } 260 | 261 | if ex: 262 | __payload.update(ex) 263 | 264 | __msg = { 265 | 'header': __header, 266 | 'payload': __payload, 267 | 'context': util.GetDefaultContext() 268 | } 269 | __jmsg = json.dumps(__msg) 270 | with self.__start_cond: 271 | if self.__start_flag: 272 | logging.debug('already start...') 273 | return 274 | self.__nls.start(__jmsg, ping_interval, ping_timeout) 275 | if self.__start_flag == False: 276 | if self.__start_cond.wait(timeout): 277 | return 278 | else: 279 | raise StartTimeoutException(f'Waiting Start over {timeout}s') 280 | 281 | def stop(self, timeout=10): 282 | """ 283 | Stop transcription and mark session finished 284 | 285 | Parameters: 286 | ----------- 287 | timeout: int 288 | timeout for waiting completed message from cloud 289 | """ 290 | __id4 = uuid.uuid4().hex 291 | __header = { 292 | 'message_id': __id4, 293 | 'task_id': self.__task_id, 294 | 'namespace': __SPEECH_TRANSCRIBER_NAMESPACE__, 295 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['stop'], 296 | 'appkey': self.__appkey 297 | } 298 | __msg = { 299 | 'header': __header, 300 | 'context': util.GetDefaultContext() 301 | } 302 | __jmsg = json.dumps(__msg) 303 | with self.__start_cond: 304 | if not self.__start_flag: 305 | logging.debug('not start yet...') 306 | return 307 | self.__nls.send(__jmsg, False) 308 | if self.__start_flag == True: 309 | logging.debug('stop wait..') 310 | if self.__start_cond.wait(timeout): 311 | return 312 | else: 313 | raise StopTimeoutException(f'Waiting stop over {timeout}s') 314 | 315 | def ctrl(self, **kwargs): 316 | """ 317 | Send control message to cloud 318 | 319 | Parameters: 320 | ----------- 321 | kwargs: dict 322 | dict which will merge into 'payload' field in request 323 | """ 324 | if not kwargs: 325 | raise InvalidParameter('Empty kwargs not allowed!') 326 | __id4 = uuid.uuid4().hex 327 | __header = { 328 | 'message_id': __id4, 329 | 'task_id': self.__task_id, 330 | 'namespace': 
__SPEECH_TRANSCRIBER_NAMESPACE__, 331 | 'name': __SPEECH_TRANSCRIBER_REQUEST_CMD__['control'], 332 | 'appkey': self.__appkey 333 | } 334 | payload = {} 335 | payload.update(kwargs) 336 | __msg = { 337 | 'header': __header, 338 | 'payload': payload, 339 | 'context': util.GetDefaultContext() 340 | } 341 | __jmsg = json.dumps(__msg) 342 | with self.__start_cond: 343 | if not self.__start_flag: 344 | logging.debug('not start yet...') 345 | return 346 | self.__nls.send(__jmsg, False) 347 | 348 | def shutdown(self): 349 | """ 350 | Shutdown connection immediately 351 | """ 352 | self.__nls.shutdown() 353 | 354 | def send_audio(self, pcm_data): 355 | """ 356 | Send audio binary, audio size prefer 20ms length 357 | 358 | Parameters: 359 | ----------- 360 | pcm_data: bytes 361 | audio binary which format is 'aformat' in start method 362 | """ 363 | 364 | __data = pcm_data 365 | with self.__start_cond: 366 | if not self.__start_flag: 367 | return 368 | try: 369 | self.__nls.send(__data, True) 370 | except ConnectionResetError as __e: 371 | logging.error('connection reset') 372 | self.__start_flag = False 373 | self.__nls.shutdown() 374 | raise __e 375 | -------------------------------------------------------------------------------- /nls/stream_input_tts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import logging 4 | import uuid 5 | import json 6 | import threading 7 | from enum import IntEnum 8 | 9 | from nls.core import NlsCore 10 | from . 
import logging 11 | from .exception import StartTimeoutException, WrongStateException, InvalidParameter 12 | 13 | __STREAM_INPUT_TTS_NAMESPACE__ = "FlowingSpeechSynthesizer" 14 | 15 | __STREAM_INPUT_TTS_REQUEST_CMD__ = { 16 | "start": "StartSynthesis", 17 | "send": "RunSynthesis", 18 | "stop": "StopSynthesis", 19 | } 20 | __STREAM_INPUT_TTS_REQUEST_NAME__ = { 21 | "started": "SynthesisStarted", 22 | "sentence_begin": "SentenceBegin", 23 | "sentence_synthesis": "SentenceSynthesis", 24 | "sentence_end": "SentenceEnd", 25 | "completed": "SynthesisCompleted", 26 | "task_failed": "TaskFailed", 27 | } 28 | 29 | __URL__ = "wss://nls-gateway.cn-shanghai.aliyuncs.com/ws/v1" 30 | 31 | __all__ = ["NlsStreamInputTtsSynthesizer"] 32 | 33 | 34 | class NlsStreamInputTtsRequest: 35 | def __init__(self, task_id, session_id, appkey): 36 | self.task_id = task_id 37 | self.appkey = appkey 38 | self.session_id = session_id 39 | 40 | def getStartCMD(self, voice, format, sample_rate, volumn, speech_rate, pitch_rate): 41 | self.voice = voice 42 | self.format = format 43 | self.sample_rate = sample_rate 44 | self.volumn = volumn 45 | self.speech_rate = speech_rate 46 | self.pitch_rate = pitch_rate 47 | cmd = { 48 | "header": { 49 | "message_id": uuid.uuid4().hex, 50 | "task_id": self.task_id, 51 | "name": __STREAM_INPUT_TTS_REQUEST_CMD__["start"], 52 | "namespace": __STREAM_INPUT_TTS_NAMESPACE__, 53 | "appkey": self.appkey, 54 | }, 55 | "payload": { 56 | "session_id": self.session_id, 57 | "voice": self.voice, 58 | "format": self.format, 59 | "sample_rate": self.sample_rate, 60 | "volumn": self.volumn, 61 | "speech_rate": self.speech_rate, 62 | "pitch_rate": self.pitch_rate, 63 | }, 64 | } 65 | return json.dumps(cmd) 66 | 67 | def getSendCMD(self, text): 68 | cmd = { 69 | "header": { 70 | "message_id": uuid.uuid4().hex, 71 | "task_id": self.task_id, 72 | "name": __STREAM_INPUT_TTS_REQUEST_CMD__["send"], 73 | "namespace": __STREAM_INPUT_TTS_NAMESPACE__, 74 | "appkey": self.appkey, 75 | }, 
76 | "payload": {"text": text}, 77 | } 78 | return json.dumps(cmd) 79 | 80 | def getStopCMD(self): 81 | cmd = { 82 | "header": { 83 | "message_id": uuid.uuid4().hex, 84 | "task_id": self.task_id, 85 | "name": __STREAM_INPUT_TTS_REQUEST_CMD__["stop"], 86 | "namespace": __STREAM_INPUT_TTS_NAMESPACE__, 87 | "appkey": self.appkey, 88 | }, 89 | } 90 | return json.dumps(cmd) 91 | 92 | 93 | class NlsStreamInputTtsStatus(IntEnum): 94 | Begin = 1 95 | Start = 2 96 | Started = 3 97 | WaitingComplete = 3 98 | Completed = 4 99 | Failed = 5 100 | Closed = 6 101 | 102 | class ThreadSafeStatus: 103 | def __init__(self, state: NlsStreamInputTtsStatus): 104 | self._state = state 105 | self._lock = threading.Lock() 106 | 107 | def get(self) -> NlsStreamInputTtsStatus: 108 | with self._lock: 109 | return self._state 110 | 111 | def set(self, state: NlsStreamInputTtsStatus): 112 | with self._lock: 113 | self._state = state 114 | 115 | 116 | class NlsStreamInputTtsSynthesizer: 117 | """ 118 | Api for text-to-speech 119 | """ 120 | 121 | def __init__( 122 | self, 123 | url=__URL__, 124 | token=None, 125 | appkey=None, 126 | session_id=None, 127 | on_data=None, 128 | on_sentence_begin=None, 129 | on_sentence_synthesis=None, 130 | on_sentence_end=None, 131 | on_completed=None, 132 | on_error=None, 133 | on_close=None, 134 | callback_args=[], 135 | ): 136 | """ 137 | NlsSpeechSynthesizer initialization 138 | 139 | Parameters: 140 | ----------- 141 | url: str 142 | websocket url. 143 | akid: str 144 | access id from aliyun. if you provide a token, ignore this argument. 145 | appkey: str 146 | appkey from aliyun 147 | session_id: str 148 | 32-character string, if empty, sdk will generate a random string. 149 | on_data: function 150 | Callback object which is called when partial synthesis result arrived 151 | arrived. 152 | on_result_changed has two arguments. 153 | The 1st argument is binary data corresponding to aformat in start 154 | method. 
155 | The 2nd argument is *args which is callback_args. 156 | on_sentence_begin: function 157 | Callback object which is called when detected sentence start. 158 | on_start has two arguments. 159 | The 1st argument is message which is a json format string. 160 | The 2nd argument is *args which is callback_args. 161 | on_sentence_synthesis: function 162 | Callback object which is called when detected sentence synthesis. 163 | The incremental timestamp is returned within payload. 164 | on_start has two arguments. 165 | The 1st argument is message which is a json format string. 166 | The 2nd argument is *args which is callback_args. 167 | on_sentence_end: function 168 | Callback object which is called when detected sentence end. 169 | The timestamp of the whole sentence is returned within payload. 170 | on_start has two arguments. 171 | The 1st argument is message which is a json format string. 172 | The 2nd argument is *args which is callback_args. 173 | on_completed: function 174 | Callback object which is called when recognition is completed. 175 | on_completed has two arguments. 176 | The 1st argument is message which is a json format string. 177 | The 2nd argument is *args which is callback_args. 178 | on_error: function 179 | Callback object which is called when any error occurs. 180 | on_error has two arguments. 181 | The 1st argument is message which is a json format string. 182 | The 2nd argument is *args which is callback_args. 183 | on_close: function 184 | Callback object which is called when connection closed. 185 | on_close has one arguments. 186 | The 1st argument is *args which is callback_args. 187 | callback_args: list 188 | callback_args will return in callbacks above for *args. 
189 | """ 190 | if not token or not appkey: 191 | raise InvalidParameter("Must provide token and appkey") 192 | self.__response_handler__ = { 193 | __STREAM_INPUT_TTS_REQUEST_NAME__["started"]: self.__synthesis_started, 194 | __STREAM_INPUT_TTS_REQUEST_NAME__["sentence_begin"]: self.__sentence_begin, 195 | __STREAM_INPUT_TTS_REQUEST_NAME__[ 196 | "sentence_synthesis" 197 | ]: self.__sentence_synthesis, 198 | __STREAM_INPUT_TTS_REQUEST_NAME__["sentence_end"]: self.__sentence_end, 199 | __STREAM_INPUT_TTS_REQUEST_NAME__["completed"]: self.__synthesis_completed, 200 | __STREAM_INPUT_TTS_REQUEST_NAME__["task_failed"]: self.__task_failed, 201 | } 202 | self.__callback_args = callback_args 203 | self.__url = url 204 | self.__appkey = appkey 205 | self.__token = token 206 | self.__session_id = session_id 207 | self.start_sended = threading.Event() 208 | self.started_event = threading.Event() 209 | self.complete_event = threading.Event() 210 | self.__on_sentence_begin = on_sentence_begin 211 | self.__on_sentence_synthesis = on_sentence_synthesis 212 | self.__on_sentence_end = on_sentence_end 213 | self.__on_data = on_data 214 | self.__on_completed = on_completed 215 | self.__on_error = on_error 216 | self.__on_close = on_close 217 | self.__allow_aformat = ("pcm", "wav", "mp3") 218 | self.__allow_sample_rate = ( 219 | 8000, 220 | 11025, 221 | 16000, 222 | 22050, 223 | 24000, 224 | 32000, 225 | 44100, 226 | 48000, 227 | ) 228 | self.state = ThreadSafeStatus(NlsStreamInputTtsStatus.Begin) 229 | if not self.__session_id: 230 | self.__session_id = uuid.uuid4().hex 231 | self.request = NlsStreamInputTtsRequest( 232 | uuid.uuid4().hex, self.__session_id, self.__appkey 233 | ) 234 | 235 | def __handle_message(self, message): 236 | logging.debug("__handle_message") 237 | try: 238 | __result = json.loads(message) 239 | if __result["header"]["name"] in self.__response_handler__: 240 | __handler = self.__response_handler__[__result["header"]["name"]] 241 | __handler(message) 242 | 
else: 243 | logging.error("cannot handle cmd{}".format(__result["header"]["name"])) 244 | return 245 | except json.JSONDecodeError: 246 | logging.error("cannot parse message:{}".format(message)) 247 | return 248 | 249 | def __syn_core_on_open(self): 250 | logging.debug("__syn_core_on_open") 251 | self.start_sended.set() 252 | 253 | def __syn_core_on_data(self, data, opcode, flag): 254 | logging.debug("__syn_core_on_data") 255 | if self.__on_data: 256 | self.__on_data(data, *self.__callback_args) 257 | 258 | def __syn_core_on_msg(self, msg, *args): 259 | logging.debug("__syn_core_on_msg:msg={} args={}".format(msg, args)) 260 | self.__handle_message(msg) 261 | 262 | def __syn_core_on_error(self, msg, *args): 263 | logging.debug("__sr_core_on_error:msg={} args={}".format(msg, args)) 264 | 265 | def __syn_core_on_close(self): 266 | logging.debug("__sr_core_on_close") 267 | if self.__on_close: 268 | self.__on_close(*self.__callback_args) 269 | self.state.set(NlsStreamInputTtsStatus.Closed) 270 | self.start_sended.set() 271 | self.started_event.set() 272 | self.complete_event.set() 273 | 274 | def __synthesis_started(self, message): 275 | logging.debug("__synthesis_started") 276 | self.started_event.set() 277 | 278 | def __sentence_begin(self, message): 279 | logging.debug("__sentence_begin") 280 | if self.__on_sentence_begin: 281 | self.__on_sentence_begin(message, *self.__callback_args) 282 | 283 | def __sentence_synthesis(self, message): 284 | logging.debug("__sentence_synthesis") 285 | if self.__on_sentence_synthesis: 286 | self.__on_sentence_synthesis(message, *self.__callback_args) 287 | 288 | def __sentence_end(self, message): 289 | logging.debug("__sentence_end") 290 | if self.__on_sentence_end: 291 | self.__on_sentence_end(message, *self.__callback_args) 292 | 293 | def __synthesis_completed(self, message): 294 | logging.debug("__synthesis_completed") 295 | if self.__on_completed: 296 | self.__on_completed(message, *self.__callback_args) 297 | 
self.__nls.shutdown() 298 | logging.debug("__synthesis_completed shutdown done") 299 | self.complete_event.set() 300 | 301 | 302 | def __task_failed(self, message): 303 | logging.debug("__task_failed") 304 | self.start_sended.set() 305 | self.started_event.set() 306 | self.complete_event.set() 307 | if self.__on_error: 308 | self.__on_error(message, *self.__callback_args) 309 | self.state.set(NlsStreamInputTtsStatus.Failed) 310 | 311 | def startStreamInputTts( 312 | self, 313 | voice="longxiaochun", 314 | aformat="pcm", 315 | sample_rate=24000, 316 | volume=50, 317 | speech_rate=0, 318 | pitch_rate=0, 319 | ): 320 | """ 321 | Synthesis start 322 | 323 | Parameters: 324 | ----------- 325 | voice: str 326 | voice for text-to-speech, default is xiaoyun 327 | aformat: str 328 | audio binary format, support: 'pcm', 'wav', 'mp3', default is 'pcm' 329 | sample_rate: int 330 | audio sample rate, default is 24000, support:8000, 11025, 16000, 22050, 331 | 24000, 32000, 44100, 48000 332 | volume: int 333 | audio volume, from 0~100, default is 50 334 | speech_rate: int 335 | speech rate from -500~500, default is 0 336 | pitch_rate: int 337 | pitch for voice from -500~500, default is 0 338 | ex: dict 339 | dict which will merge into 'payload' field in request 340 | """ 341 | 342 | self.__nls = NlsCore( 343 | url=self.__url, 344 | token=self.__token, 345 | on_open=self.__syn_core_on_open, 346 | on_message=self.__syn_core_on_msg, 347 | on_data=self.__syn_core_on_data, 348 | on_close=self.__syn_core_on_close, 349 | on_error=self.__syn_core_on_error, 350 | callback_args=[], 351 | ) 352 | 353 | if aformat not in self.__allow_aformat: 354 | raise InvalidParameter("format {} not support".format(aformat)) 355 | if sample_rate not in self.__allow_sample_rate: 356 | raise InvalidParameter("samplerate {} not support".format(sample_rate)) 357 | if volume < 0 or volume > 100: 358 | raise InvalidParameter("volume {} not support".format(volume)) 359 | if speech_rate < -500 or speech_rate > 
500: 360 | raise InvalidParameter("speech_rate {} not support".format(speech_rate)) 361 | if pitch_rate < -500 or pitch_rate > 500: 362 | raise InvalidParameter("pitch rate {} not support".format(pitch_rate)) 363 | 364 | request = self.request.getStartCMD( 365 | voice, aformat, sample_rate, volume, speech_rate, pitch_rate 366 | ) 367 | 368 | last_state = self.state.get() 369 | if last_state != NlsStreamInputTtsStatus.Begin: 370 | logging.debug("start with wrong state {}".format(last_state)) 371 | self.state.set(NlsStreamInputTtsStatus.Failed) 372 | raise WrongStateException("start with wrong state {}".format(last_state)) 373 | 374 | logging.debug("start with request: {}".format(request)) 375 | self.__nls.start(request, ping_interval=0, ping_timeout=None) 376 | self.state.set(NlsStreamInputTtsStatus.Start) 377 | if not self.start_sended.wait(timeout=10): 378 | logging.debug("syn start timeout") 379 | raise StartTimeoutException(f"Waiting Connection before Start over 10s") 380 | 381 | if last_state != NlsStreamInputTtsStatus.Begin: 382 | logging.debug("start with wrong state {}".format(last_state)) 383 | self.state.set(NlsStreamInputTtsStatus.Failed) 384 | raise WrongStateException("start with wrong state {}".format(last_state)) 385 | 386 | if not self.started_event.wait(timeout=10): 387 | logging.debug("syn started timeout") 388 | self.state.set(NlsStreamInputTtsStatus.Failed) 389 | raise StartTimeoutException(f"Waiting Started over 10s") 390 | self.state.set(NlsStreamInputTtsStatus.Started) 391 | 392 | def sendStreamInputTts(self, text): 393 | """ 394 | send text to server 395 | 396 | Parameters: 397 | ----------- 398 | text: str 399 | utf-8 text 400 | """ 401 | last_state = self.state.get() 402 | if last_state != NlsStreamInputTtsStatus.Started: 403 | logging.debug("send with wrong state {}".format(last_state)) 404 | self.state.set(NlsStreamInputTtsStatus.Failed) 405 | raise WrongStateException("send with wrong state {}".format(last_state)) 406 | 407 | request = 
self.request.getSendCMD(text) 408 | logging.debug("send with request: {}".format(request)) 409 | self.__nls.send(request, None) 410 | 411 | def stopStreamInputTts(self): 412 | """ 413 | Synthesis end 414 | """ 415 | 416 | last_state = self.state.get() 417 | if last_state != NlsStreamInputTtsStatus.Started: 418 | logging.debug("send with wrong state {}".format(last_state)) 419 | self.state.set(NlsStreamInputTtsStatus.Failed) 420 | raise WrongStateException("stop with wrong state {}".format(last_state)) 421 | 422 | 423 | request = self.request.getStopCMD() 424 | logging.debug("stop with request: {}".format(request)) 425 | self.__nls.send(request, None) 426 | self.state.set(NlsStreamInputTtsStatus.WaitingComplete) 427 | self.complete_event.wait() 428 | self.state.set(NlsStreamInputTtsStatus.Completed) 429 | self.shutdown() 430 | 431 | def shutdown(self): 432 | """ 433 | Shutdown connection immediately 434 | """ 435 | 436 | self.__nls.shutdown() 437 | -------------------------------------------------------------------------------- /nls/token.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | from aliyunsdkcore.client import AcsClient 4 | from aliyunsdkcore.request import CommonRequest 5 | from .exception import GetTokenFailed 6 | 7 | import json 8 | 9 | __all__ = ['getToken'] 10 | 11 | def getToken(akid, aksecret, domain='cn-shanghai', 12 | version='2019-02-28', 13 | url='nls-meta.cn-shanghai.aliyuncs.com'): 14 | """ 15 | Help methods to get token from aliyun by giving access id and access secret 16 | key 17 | 18 | Parameters: 19 | ----------- 20 | akid: str 21 | access id from aliyun 22 | aksecret: str 23 | access secret key from aliyun 24 | domain: str: 25 | default is cn-shanghai 26 | version: str: 27 | default is 2019-02-28 28 | url: str 29 | full url for getting token, default is 30 | nls-meta.cn-shanghai.aliyuncs.com 31 | """ 32 | if akid is None or aksecret is None: 33 | raise GetTokenFailed('No akid or aksecret') 34 | client = AcsClient(akid, aksecret, domain) 35 | request = CommonRequest() 36 | request.set_method('POST') 37 | request.set_domain(url) 38 | request.set_version(version) 39 | request.set_action_name('CreateToken') 40 | response = client.do_action_with_exception(request) 41 | response_json = json.loads(response) 42 | if 'Token' in response_json: 43 | token = response_json['Token'] 44 | if 'Id' in token: 45 | return token['Id'] 46 | else: 47 | raise GetTokenFailed(f'Missing id field in token:{token}') 48 | else: 49 | raise GetTokenFailed(f'Token not in response:{response_json}') 50 | -------------------------------------------------------------------------------- /nls/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
2 | 3 | from struct import * 4 | 5 | __all__=['wav2pcm', 'GetDefaultContext'] 6 | 7 | def GetDefaultContext(): 8 | """ 9 | Return Default Context Object 10 | """ 11 | return { 12 | 'sdk': { 13 | 'name': 'nls-python-sdk', 14 | 'version': '0.0.1', 15 | 'language': 'python' 16 | } 17 | } 18 | 19 | 20 | def wav2pcm(wavfile, pcmfile): 21 | """ 22 | Turn wav into pcm 23 | 24 | Parameters 25 | ---------- 26 | wavfile: str 27 | wav file path 28 | pcmfile: str 29 | output pcm file path 30 | """ 31 | with open(wavfile, 'rb') as i, open(pcmfile, 'wb') as o: 32 | i.seek(0) 33 | _id = i.read(4) 34 | _id = unpack('>I', _id) 35 | _size = i.read(4) 36 | _size = unpack('I', _type) 39 | if _id[0] != 0x52494646 or _type[0] != 0x57415645: 40 | raise ValueError('not a wav!') 41 | i.read(32) 42 | result = i.read() 43 | o.write(result) 44 | 45 | -------------------------------------------------------------------------------- /nls/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | __version__ = '1.0.0' -------------------------------------------------------------------------------- /nls/websocket/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | __init__.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 
18 | """ 19 | from ._abnf import * 20 | from ._app import WebSocketApp 21 | from ._core import * 22 | from ._exceptions import * 23 | from ._logging import * 24 | from ._socket import * 25 | 26 | __version__ = "1.2.1" 27 | -------------------------------------------------------------------------------- /nls/websocket/_abnf.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _abnf.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import array 24 | import os 25 | import struct 26 | import sys 27 | 28 | from ._exceptions import * 29 | from ._utils import validate_utf8 30 | from threading import Lock 31 | 32 | try: 33 | # If wsaccel is available, use compiled routines to mask data. 34 | # wsaccel only provides around a 10% speed boost compared 35 | # to the websocket-client _mask() implementation. 36 | # Note that wsaccel is unmaintained. 
37 | from wsaccel.xormask import XorMaskerSimple 38 | 39 | def _mask(_m, _d): 40 | return XorMaskerSimple(_m).process(_d) 41 | 42 | except ImportError: 43 | # wsaccel is not available, use websocket-client _mask() 44 | native_byteorder = sys.byteorder 45 | 46 | def _mask(mask_value, data_value): 47 | datalen = len(data_value) 48 | data_value = int.from_bytes(data_value, native_byteorder) 49 | mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder) 50 | return (data_value ^ mask_value).to_bytes(datalen, native_byteorder) 51 | 52 | 53 | __all__ = [ 54 | 'ABNF', 'continuous_frame', 'frame_buffer', 55 | 'STATUS_NORMAL', 56 | 'STATUS_GOING_AWAY', 57 | 'STATUS_PROTOCOL_ERROR', 58 | 'STATUS_UNSUPPORTED_DATA_TYPE', 59 | 'STATUS_STATUS_NOT_AVAILABLE', 60 | 'STATUS_ABNORMAL_CLOSED', 61 | 'STATUS_INVALID_PAYLOAD', 62 | 'STATUS_POLICY_VIOLATION', 63 | 'STATUS_MESSAGE_TOO_BIG', 64 | 'STATUS_INVALID_EXTENSION', 65 | 'STATUS_UNEXPECTED_CONDITION', 66 | 'STATUS_BAD_GATEWAY', 67 | 'STATUS_TLS_HANDSHAKE_ERROR', 68 | ] 69 | 70 | # closing frame status codes. 71 | STATUS_NORMAL = 1000 72 | STATUS_GOING_AWAY = 1001 73 | STATUS_PROTOCOL_ERROR = 1002 74 | STATUS_UNSUPPORTED_DATA_TYPE = 1003 75 | STATUS_STATUS_NOT_AVAILABLE = 1005 76 | STATUS_ABNORMAL_CLOSED = 1006 77 | STATUS_INVALID_PAYLOAD = 1007 78 | STATUS_POLICY_VIOLATION = 1008 79 | STATUS_MESSAGE_TOO_BIG = 1009 80 | STATUS_INVALID_EXTENSION = 1010 81 | STATUS_UNEXPECTED_CONDITION = 1011 82 | STATUS_BAD_GATEWAY = 1014 83 | STATUS_TLS_HANDSHAKE_ERROR = 1015 84 | 85 | VALID_CLOSE_STATUS = ( 86 | STATUS_NORMAL, 87 | STATUS_GOING_AWAY, 88 | STATUS_PROTOCOL_ERROR, 89 | STATUS_UNSUPPORTED_DATA_TYPE, 90 | STATUS_INVALID_PAYLOAD, 91 | STATUS_POLICY_VIOLATION, 92 | STATUS_MESSAGE_TOO_BIG, 93 | STATUS_INVALID_EXTENSION, 94 | STATUS_UNEXPECTED_CONDITION, 95 | STATUS_BAD_GATEWAY, 96 | ) 97 | 98 | 99 | class ABNF: 100 | """ 101 | ABNF frame class. 
102 | See http://tools.ietf.org/html/rfc5234 103 | and http://tools.ietf.org/html/rfc6455#section-5.2 104 | """ 105 | 106 | # operation code values. 107 | OPCODE_CONT = 0x0 108 | OPCODE_TEXT = 0x1 109 | OPCODE_BINARY = 0x2 110 | OPCODE_CLOSE = 0x8 111 | OPCODE_PING = 0x9 112 | OPCODE_PONG = 0xa 113 | 114 | # available operation code value tuple 115 | OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, 116 | OPCODE_PING, OPCODE_PONG) 117 | 118 | # opcode human readable string 119 | OPCODE_MAP = { 120 | OPCODE_CONT: "cont", 121 | OPCODE_TEXT: "text", 122 | OPCODE_BINARY: "binary", 123 | OPCODE_CLOSE: "close", 124 | OPCODE_PING: "ping", 125 | OPCODE_PONG: "pong" 126 | } 127 | 128 | # data length threshold. 129 | LENGTH_7 = 0x7e 130 | LENGTH_16 = 1 << 16 131 | LENGTH_63 = 1 << 63 132 | 133 | def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, 134 | opcode=OPCODE_TEXT, mask=1, data=""): 135 | """ 136 | Constructor for ABNF. Please check RFC for arguments. 137 | """ 138 | self.fin = fin 139 | self.rsv1 = rsv1 140 | self.rsv2 = rsv2 141 | self.rsv3 = rsv3 142 | self.opcode = opcode 143 | self.mask = mask 144 | if data is None: 145 | data = "" 146 | self.data = data 147 | self.get_mask_key = os.urandom 148 | 149 | def validate(self, skip_utf8_validation=False): 150 | """ 151 | Validate the ABNF frame. 152 | 153 | Parameters 154 | ---------- 155 | skip_utf8_validation: skip utf8 validation. 
156 | """ 157 | if self.rsv1 or self.rsv2 or self.rsv3: 158 | raise WebSocketProtocolException("rsv is not implemented, yet") 159 | 160 | if self.opcode not in ABNF.OPCODES: 161 | raise WebSocketProtocolException("Invalid opcode %r", self.opcode) 162 | 163 | if self.opcode == ABNF.OPCODE_PING and not self.fin: 164 | raise WebSocketProtocolException("Invalid ping frame.") 165 | 166 | if self.opcode == ABNF.OPCODE_CLOSE: 167 | l = len(self.data) 168 | if not l: 169 | return 170 | if l == 1 or l >= 126: 171 | raise WebSocketProtocolException("Invalid close frame.") 172 | if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): 173 | raise WebSocketProtocolException("Invalid close frame.") 174 | 175 | code = 256 * self.data[0] + self.data[1] 176 | if not self._is_valid_close_status(code): 177 | raise WebSocketProtocolException("Invalid close opcode.") 178 | 179 | @staticmethod 180 | def _is_valid_close_status(code): 181 | return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) 182 | 183 | def __str__(self): 184 | return "fin=" + str(self.fin) \ 185 | + " opcode=" + str(self.opcode) \ 186 | + " data=" + str(self.data) 187 | 188 | @staticmethod 189 | def create_frame(data, opcode, fin=1): 190 | """ 191 | Create frame to send text, binary and other data. 192 | 193 | Parameters 194 | ---------- 195 | data: 196 | data to send. This is string value(byte array). 197 | If opcode is OPCODE_TEXT and this value is unicode, 198 | data value is converted into unicode string, automatically. 199 | opcode: 200 | operation code. please see OPCODE_XXX. 201 | fin: 202 | fin flag. if set to 0, create continue fragmentation. 203 | """ 204 | if opcode == ABNF.OPCODE_TEXT and isinstance(data, str): 205 | data = data.encode("utf-8") 206 | # mask must be set if send data from client 207 | return ABNF(fin, 0, 0, 0, opcode, 1, data) 208 | 209 | def format(self): 210 | """ 211 | Format this object to string(byte array) to send data to server. 
212 | """ 213 | if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]): 214 | raise ValueError("not 0 or 1") 215 | if self.opcode not in ABNF.OPCODES: 216 | raise ValueError("Invalid OPCODE") 217 | length = len(self.data) 218 | if length >= ABNF.LENGTH_63: 219 | raise ValueError("data is too long") 220 | 221 | frame_header = chr(self.fin << 7 | 222 | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 | 223 | self.opcode).encode('latin-1') 224 | if length < ABNF.LENGTH_7: 225 | frame_header += chr(self.mask << 7 | length).encode('latin-1') 226 | elif length < ABNF.LENGTH_16: 227 | frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1') 228 | frame_header += struct.pack("!H", length) 229 | else: 230 | frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1') 231 | frame_header += struct.pack("!Q", length) 232 | 233 | if not self.mask: 234 | return frame_header + self.data 235 | else: 236 | mask_key = self.get_mask_key(4) 237 | return frame_header + self._get_masked(mask_key) 238 | 239 | def _get_masked(self, mask_key): 240 | s = ABNF.mask(mask_key, self.data) 241 | 242 | if isinstance(mask_key, str): 243 | mask_key = mask_key.encode('utf-8') 244 | 245 | return mask_key + s 246 | 247 | @staticmethod 248 | def mask(mask_key, data): 249 | """ 250 | Mask or unmask data. Just do xor for each byte 251 | 252 | Parameters 253 | ---------- 254 | mask_key: 255 | 4 byte string. 256 | data: 257 | data to mask/unmask. 
258 | """ 259 | if data is None: 260 | data = "" 261 | 262 | if isinstance(mask_key, str): 263 | mask_key = mask_key.encode('latin-1') 264 | 265 | if isinstance(data, str): 266 | data = data.encode('latin-1') 267 | 268 | return _mask(array.array("B", mask_key), array.array("B", data)) 269 | 270 | 271 | class frame_buffer: 272 | _HEADER_MASK_INDEX = 5 273 | _HEADER_LENGTH_INDEX = 6 274 | 275 | def __init__(self, recv_fn, skip_utf8_validation): 276 | self.recv = recv_fn 277 | self.skip_utf8_validation = skip_utf8_validation 278 | # Buffers over the packets from the layer beneath until desired amount 279 | # bytes of bytes are received. 280 | self.recv_buffer = [] 281 | self.clear() 282 | self.lock = Lock() 283 | 284 | def clear(self): 285 | self.header = None 286 | self.length = None 287 | self.mask = None 288 | 289 | def has_received_header(self): 290 | return self.header is None 291 | 292 | def recv_header(self): 293 | header = self.recv_strict(2) 294 | b1 = header[0] 295 | fin = b1 >> 7 & 1 296 | rsv1 = b1 >> 6 & 1 297 | rsv2 = b1 >> 5 & 1 298 | rsv3 = b1 >> 4 & 1 299 | opcode = b1 & 0xf 300 | b2 = header[1] 301 | has_mask = b2 >> 7 & 1 302 | length_bits = b2 & 0x7f 303 | 304 | self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits) 305 | 306 | def has_mask(self): 307 | if not self.header: 308 | return False 309 | return self.header[frame_buffer._HEADER_MASK_INDEX] 310 | 311 | def has_received_length(self): 312 | return self.length is None 313 | 314 | def recv_length(self): 315 | bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] 316 | length_bits = bits & 0x7f 317 | if length_bits == 0x7e: 318 | v = self.recv_strict(2) 319 | self.length = struct.unpack("!H", v)[0] 320 | elif length_bits == 0x7f: 321 | v = self.recv_strict(8) 322 | self.length = struct.unpack("!Q", v)[0] 323 | else: 324 | self.length = length_bits 325 | 326 | def has_received_mask(self): 327 | return self.mask is None 328 | 329 | def recv_mask(self): 330 | self.mask = 
self.recv_strict(4) if self.has_mask() else "" 331 | 332 | def recv_frame(self): 333 | 334 | with self.lock: 335 | # Header 336 | if self.has_received_header(): 337 | self.recv_header() 338 | (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header 339 | 340 | # Frame length 341 | if self.has_received_length(): 342 | self.recv_length() 343 | length = self.length 344 | 345 | # Mask 346 | if self.has_received_mask(): 347 | self.recv_mask() 348 | mask = self.mask 349 | 350 | # Payload 351 | payload = self.recv_strict(length) 352 | if has_mask: 353 | payload = ABNF.mask(mask, payload) 354 | 355 | # Reset for next frame 356 | self.clear() 357 | 358 | frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) 359 | frame.validate(self.skip_utf8_validation) 360 | 361 | return frame 362 | 363 | def recv_strict(self, bufsize): 364 | shortage = bufsize - sum(map(len, self.recv_buffer)) 365 | while shortage > 0: 366 | # Limit buffer size that we pass to socket.recv() to avoid 367 | # fragmenting the heap -- the number of bytes recv() actually 368 | # reads is limited by socket buffer and is relatively small, 369 | # yet passing large numbers repeatedly causes lots of large 370 | # buffers allocated and then shrunk, which results in 371 | # fragmentation. 
372 | bytes_ = self.recv(min(16384, shortage)) 373 | self.recv_buffer.append(bytes_) 374 | shortage -= len(bytes_) 375 | 376 | unified = bytes("", 'utf-8').join(self.recv_buffer) 377 | 378 | if shortage == 0: 379 | self.recv_buffer = [] 380 | return unified 381 | else: 382 | self.recv_buffer = [unified[bufsize:]] 383 | return unified[:bufsize] 384 | 385 | 386 | class continuous_frame: 387 | 388 | def __init__(self, fire_cont_frame, skip_utf8_validation): 389 | self.fire_cont_frame = fire_cont_frame 390 | self.skip_utf8_validation = skip_utf8_validation 391 | self.cont_data = None 392 | self.recving_frames = None 393 | 394 | def validate(self, frame): 395 | if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: 396 | raise WebSocketProtocolException("Illegal frame") 397 | if self.recving_frames and \ 398 | frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 399 | raise WebSocketProtocolException("Illegal frame") 400 | 401 | def add(self, frame): 402 | if self.cont_data: 403 | self.cont_data[1] += frame.data 404 | else: 405 | if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 406 | self.recving_frames = frame.opcode 407 | self.cont_data = [frame.opcode, frame.data] 408 | 409 | if frame.fin: 410 | self.recving_frames = None 411 | 412 | def is_fire(self, frame): 413 | return frame.fin or self.fire_cont_frame 414 | 415 | def extract(self, frame): 416 | data = self.cont_data 417 | self.cont_data = None 418 | frame.data = data[1] 419 | if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): 420 | raise WebSocketPayloadException( 421 | "cannot decode: " + repr(frame.data)) 422 | 423 | return [data[0], frame] 424 | -------------------------------------------------------------------------------- /nls/websocket/_app.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _app.py 7 | websocket - WebSocket client library 
for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import selectors 24 | import sys 25 | import threading 26 | import time 27 | import traceback 28 | from ._abnf import ABNF 29 | from ._core import WebSocket, getdefaulttimeout 30 | from ._exceptions import * 31 | from . import _logging 32 | 33 | 34 | __all__ = ["WebSocketApp"] 35 | 36 | 37 | class Dispatcher: 38 | """ 39 | Dispatcher 40 | """ 41 | def __init__(self, app, ping_timeout): 42 | self.app = app 43 | self.ping_timeout = ping_timeout 44 | 45 | def read(self, sock, read_callback, check_callback): 46 | while self.app.keep_running: 47 | sel = selectors.DefaultSelector() 48 | sel.register(self.app.sock.sock, selectors.EVENT_READ) 49 | 50 | r = sel.select(self.ping_timeout) 51 | if r: 52 | if not read_callback(): 53 | break 54 | check_callback() 55 | sel.close() 56 | 57 | 58 | class SSLDispatcher: 59 | """ 60 | SSLDispatcher 61 | """ 62 | def __init__(self, app, ping_timeout): 63 | self.app = app 64 | self.ping_timeout = ping_timeout 65 | 66 | def read(self, sock, read_callback, check_callback): 67 | while self.app.keep_running: 68 | r = self.select() 69 | if r: 70 | if not read_callback(): 71 | break 72 | check_callback() 73 | 74 | def select(self): 75 | sock = self.app.sock.sock 76 | if sock.pending(): 77 | return [sock,] 78 | 79 | sel = selectors.DefaultSelector() 80 | sel.register(sock, selectors.EVENT_READ) 81 | 82 | r = 
sel.select(self.ping_timeout) 83 | sel.close() 84 | 85 | if len(r) > 0: 86 | return r[0][0] 87 | 88 | 89 | class WebSocketApp: 90 | """ 91 | Higher level of APIs are provided. The interface is like JavaScript WebSocket object. 92 | """ 93 | 94 | def __init__(self, url, header=None, 95 | on_open=None, on_message=None, on_error=None, 96 | on_close=None, on_ping=None, on_pong=None, 97 | on_cont_message=None, 98 | keep_running=True, get_mask_key=None, cookie=None, 99 | subprotocols=None, 100 | on_data=None, callback_args=[]): 101 | """ 102 | WebSocketApp initialization 103 | 104 | Parameters 105 | ---------- 106 | url: str 107 | Websocket url. 108 | header: list or dict 109 | Custom header for websocket handshake. 110 | on_open: function 111 | Callback object which is called at opening websocket. 112 | on_open has one argument. 113 | The 1st argument is this class object. 114 | on_message: function 115 | Callback object which is called when received data. 116 | on_message has 2 arguments. 117 | The 1st argument is this class object. 118 | The 2nd argument is utf-8 data received from the server. 119 | on_error: function 120 | Callback object which is called when we get error. 121 | on_error has 2 arguments. 122 | The 1st argument is this class object. 123 | The 2nd argument is exception object. 124 | on_close: function 125 | Callback object which is called when connection is closed. 126 | on_close has 3 arguments. 127 | The 1st argument is this class object. 128 | The 2nd argument is close_status_code. 129 | The 3rd argument is close_msg. 130 | on_cont_message: function 131 | Callback object which is called when a continuation 132 | frame is received. 133 | on_cont_message has 3 arguments. 134 | The 1st argument is this class object. 135 | The 2nd argument is utf-8 string which we get from the server. 136 | The 3rd argument is continue flag. 
if 0, the data continue 137 | to next frame data 138 | on_data: function 139 | Callback object which is called when a message received. 140 | This is called before on_message or on_cont_message, 141 | and then on_message or on_cont_message is called. 142 | on_data has 4 argument. 143 | The 1st argument is this class object. 144 | The 2nd argument is utf-8 string which we get from the server. 145 | The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. 146 | The 4th argument is continue flag. If 0, the data continue 147 | keep_running: bool 148 | This parameter is obsolete and ignored. 149 | get_mask_key: function 150 | A callable function to get new mask keys, see the 151 | WebSocket.set_mask_key's docstring for more information. 152 | cookie: str 153 | Cookie value. 154 | subprotocols: list 155 | List of available sub protocols. Default is None. 156 | """ 157 | self.url = url 158 | self.header = header if header is not None else [] 159 | self.cookie = cookie 160 | 161 | self.on_open = on_open 162 | self.on_message = on_message 163 | self.on_data = on_data 164 | self.on_error = on_error 165 | self.on_close = on_close 166 | self.on_ping = on_ping 167 | self.on_pong = on_pong 168 | self.on_cont_message = on_cont_message 169 | self.keep_running = False 170 | self.get_mask_key = get_mask_key 171 | self.sock = None 172 | self.last_ping_tm = 0 173 | self.last_pong_tm = 0 174 | self.subprotocols = subprotocols 175 | self.callback_args = callback_args 176 | 177 | def update_args(self, *args): 178 | self.callback_args = args 179 | #print(self.callback_args) 180 | 181 | def send(self, data, opcode=ABNF.OPCODE_TEXT): 182 | """ 183 | send message 184 | 185 | Parameters 186 | ---------- 187 | data: str 188 | Message to send. If you set opcode to OPCODE_TEXT, 189 | data must be utf-8 string or unicode. 190 | opcode: int 191 | Operation code of data. Default is OPCODE_TEXT. 
192 | """ 193 | 194 | if not self.sock or self.sock.send(data, opcode) == 0: 195 | raise WebSocketConnectionClosedException( 196 | "Connection is already closed.") 197 | 198 | def close(self, **kwargs): 199 | """ 200 | Close websocket connection. 201 | """ 202 | self.keep_running = False 203 | if self.sock: 204 | self.sock.close(**kwargs) 205 | self.sock = None 206 | 207 | def _send_ping(self, interval, event, payload): 208 | while not event.wait(interval): 209 | self.last_ping_tm = time.time() 210 | if self.sock: 211 | try: 212 | self.sock.ping(payload) 213 | except Exception as ex: 214 | _logging.warning("send_ping routine terminated: {}".format(ex)) 215 | break 216 | 217 | def run_forever(self, sockopt=None, sslopt=None, 218 | ping_interval=0, ping_timeout=None, 219 | ping_payload="", 220 | http_proxy_host=None, http_proxy_port=None, 221 | http_no_proxy=None, http_proxy_auth=None, 222 | skip_utf8_validation=False, 223 | host=None, origin=None, dispatcher=None, 224 | suppress_origin=False, proxy_type=None): 225 | """ 226 | Run event loop for WebSocket framework. 227 | 228 | This loop is an infinite loop and is alive while websocket is available. 229 | 230 | Parameters 231 | ---------- 232 | sockopt: tuple 233 | Values for socket.setsockopt. 234 | sockopt must be tuple 235 | and each element is argument of sock.setsockopt. 236 | sslopt: dict 237 | Optional dict object for ssl socket option. 238 | ping_interval: int or float 239 | Automatically send "ping" command 240 | every specified period (in seconds). 241 | If set to 0, no ping is sent periodically. 242 | ping_timeout: int or float 243 | Timeout (in seconds) if the pong message is not received. 244 | ping_payload: str 245 | Payload message to send with each ping. 246 | http_proxy_host: str 247 | HTTP proxy host name. 248 | http_proxy_port: int or str 249 | HTTP proxy port. If not set, set to 80. 250 | http_no_proxy: list 251 | Whitelisted host names that don't use the proxy. 
252 | skip_utf8_validation: bool 253 | skip utf8 validation. 254 | host: str 255 | update host header. 256 | origin: str 257 | update origin header. 258 | dispatcher: Dispatcher object 259 | customize reading data from socket. 260 | suppress_origin: bool 261 | suppress outputting origin header. 262 | 263 | Returns 264 | ------- 265 | teardown: bool 266 | False if caught KeyboardInterrupt, True if other exception was raised during a loop 267 | """ 268 | 269 | if ping_timeout is not None and ping_timeout <= 0: 270 | raise WebSocketException("Ensure ping_timeout > 0") 271 | if ping_interval is not None and ping_interval < 0: 272 | raise WebSocketException("Ensure ping_interval >= 0") 273 | if ping_timeout and ping_interval and ping_interval <= ping_timeout: 274 | raise WebSocketException("Ensure ping_interval > ping_timeout") 275 | if not sockopt: 276 | sockopt = [] 277 | if not sslopt: 278 | sslopt = {} 279 | if self.sock: 280 | raise WebSocketException("socket is already opened") 281 | thread = None 282 | self.keep_running = True 283 | self.last_ping_tm = 0 284 | self.last_pong_tm = 0 285 | 286 | def teardown(close_frame=None): 287 | """ 288 | Tears down the connection. 289 | 290 | Parameters 291 | ---------- 292 | close_frame: ABNF frame 293 | If close_frame is set, the on_close handler is invoked 294 | with the statusCode and reason from the provided frame. 
295 | """ 296 | 297 | if thread and thread.is_alive(): 298 | event.set() 299 | thread.join() 300 | self.keep_running = False 301 | if self.sock: 302 | self.sock.close() 303 | close_status_code, close_reason = self._get_close_args( 304 | close_frame if close_frame else None) 305 | self.sock = None 306 | 307 | # Finally call the callback AFTER all teardown is complete 308 | self._callback(self.on_close, close_status_code, close_reason, 309 | self.callback_args) 310 | 311 | try: 312 | self.sock = WebSocket( 313 | self.get_mask_key, sockopt=sockopt, sslopt=sslopt, 314 | fire_cont_frame=self.on_cont_message is not None, 315 | skip_utf8_validation=skip_utf8_validation, 316 | enable_multithread=True) 317 | self.sock.settimeout(getdefaulttimeout()) 318 | self.sock.connect( 319 | self.url, header=self.header, cookie=self.cookie, 320 | http_proxy_host=http_proxy_host, 321 | http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, 322 | http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, 323 | host=host, origin=origin, suppress_origin=suppress_origin, 324 | proxy_type=proxy_type) 325 | if not dispatcher: 326 | dispatcher = self.create_dispatcher(ping_timeout) 327 | 328 | self._callback(self.on_open, self.callback_args) 329 | 330 | if ping_interval: 331 | event = threading.Event() 332 | thread = threading.Thread( 333 | target=self._send_ping, args=(ping_interval, event, ping_payload)) 334 | thread.daemon = True 335 | thread.start() 336 | 337 | def read(): 338 | if not self.keep_running: 339 | return teardown() 340 | 341 | op_code, frame = self.sock.recv_data_frame(True) 342 | if op_code == ABNF.OPCODE_CLOSE: 343 | return teardown(frame) 344 | elif op_code == ABNF.OPCODE_PING: 345 | self._callback(self.on_ping, frame.data, self.callback_args) 346 | elif op_code == ABNF.OPCODE_PONG: 347 | self.last_pong_tm = time.time() 348 | self._callback(self.on_pong, frame.data, self.callback_args) 349 | elif op_code == ABNF.OPCODE_CONT and self.on_cont_message: 350 | 
self._callback(self.on_data, frame.data, 351 | frame.opcode, frame.fin, self.callback_args) 352 | self._callback(self.on_cont_message, 353 | frame.data, frame.fin, self.callback_args) 354 | else: 355 | data = frame.data 356 | if op_code == ABNF.OPCODE_TEXT: 357 | data = data.decode("utf-8") 358 | self._callback(self.on_message, data, self.callback_args) 359 | else: 360 | self._callback(self.on_data, data, frame.opcode, True, 361 | self.callback_args) 362 | 363 | return True 364 | 365 | def check(): 366 | if (ping_timeout): 367 | has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout 368 | has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 369 | has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout 370 | 371 | if (self.last_ping_tm and 372 | has_timeout_expired and 373 | (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): 374 | raise WebSocketTimeoutException("ping/pong timed out") 375 | return True 376 | 377 | dispatcher.read(self.sock.sock, read, check) 378 | except (Exception, KeyboardInterrupt, SystemExit) as e: 379 | self._callback(self.on_error, e, self.callback_args) 380 | if isinstance(e, SystemExit): 381 | # propagate SystemExit further 382 | raise 383 | teardown() 384 | return not isinstance(e, KeyboardInterrupt) 385 | else: 386 | teardown() 387 | return True 388 | 389 | def create_dispatcher(self, ping_timeout): 390 | timeout = ping_timeout or 10 391 | if self.sock.is_ssl(): 392 | return SSLDispatcher(self, timeout) 393 | 394 | return Dispatcher(self, timeout) 395 | 396 | def _get_close_args(self, close_frame): 397 | """ 398 | _get_close_args extracts the close code and reason from the close body 399 | if it exists (RFC6455 says WebSocket Connection Close Code is optional) 400 | """ 401 | # Need to catch the case where close_frame is None 402 | # Otherwise the following if statement causes an error 403 | if not self.on_close or not close_frame: 404 | return [None, 
None] 405 | 406 | # Extract close frame status code 407 | if close_frame.data and len(close_frame.data) >= 2: 408 | close_status_code = 256 * close_frame.data[0] + close_frame.data[1] 409 | reason = close_frame.data[2:].decode('utf-8') 410 | return [close_status_code, reason] 411 | else: 412 | # Most likely reached this because len(close_frame_data.data) < 2 413 | return [None, None] 414 | 415 | def _callback(self, callback, *args): 416 | if callback: 417 | try: 418 | callback(self, *args) 419 | 420 | except Exception as e: 421 | _logging.error("error from callback {}: {}".format(callback, e)) 422 | if self.on_error: 423 | self.on_error(self, e) 424 | -------------------------------------------------------------------------------- /nls/websocket/_cookiejar.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _cookiejar.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import http.cookies 24 | 25 | 26 | class SimpleCookieJar: 27 | def __init__(self): 28 | self.jar = dict() 29 | 30 | def add(self, set_cookie): 31 | if set_cookie: 32 | simpleCookie = http.cookies.SimpleCookie(set_cookie) 33 | 34 | for k, v in simpleCookie.items(): 35 | domain = v.get("domain") 36 | if domain: 37 | if not domain.startswith("."): 38 | domain = "." 
+ domain 39 | cookie = self.jar.get(domain) if self.jar.get(domain) else http.cookies.SimpleCookie() 40 | cookie.update(simpleCookie) 41 | self.jar[domain.lower()] = cookie 42 | 43 | def set(self, set_cookie): 44 | if set_cookie: 45 | simpleCookie = http.cookies.SimpleCookie(set_cookie) 46 | 47 | for k, v in simpleCookie.items(): 48 | domain = v.get("domain") 49 | if domain: 50 | if not domain.startswith("."): 51 | domain = "." + domain 52 | self.jar[domain.lower()] = simpleCookie 53 | 54 | def get(self, host): 55 | if not host: 56 | return "" 57 | 58 | cookies = [] 59 | for domain, simpleCookie in self.jar.items(): 60 | host = host.lower() 61 | if host.endswith(domain) or host == domain[1:]: 62 | cookies.append(self.jar.get(domain)) 63 | 64 | return "; ".join(filter( 65 | None, sorted( 66 | ["%s=%s" % (k, v.value) for cookie in filter(None, cookies) for k, v in cookie.items()] 67 | ))) 68 | -------------------------------------------------------------------------------- /nls/websocket/_exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define WebSocket exceptions 3 | """ 4 | 5 | """ 6 | _exceptions.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | 24 | 25 | class WebSocketException(Exception): 26 | """ 27 | WebSocket exception class. 
28 | """ 29 | pass 30 | 31 | 32 | class WebSocketProtocolException(WebSocketException): 33 | """ 34 | If the WebSocket protocol is invalid, this exception will be raised. 35 | """ 36 | pass 37 | 38 | 39 | class WebSocketPayloadException(WebSocketException): 40 | """ 41 | If the WebSocket payload is invalid, this exception will be raised. 42 | """ 43 | pass 44 | 45 | 46 | class WebSocketConnectionClosedException(WebSocketException): 47 | """ 48 | If remote host closed the connection or some network error happened, 49 | this exception will be raised. 50 | """ 51 | pass 52 | 53 | 54 | class WebSocketTimeoutException(WebSocketException): 55 | """ 56 | WebSocketTimeoutException will be raised at socket timeout during read/write data. 57 | """ 58 | pass 59 | 60 | 61 | class WebSocketProxyException(WebSocketException): 62 | """ 63 | WebSocketProxyException will be raised when proxy error occurred. 64 | """ 65 | pass 66 | 67 | 68 | class WebSocketBadStatusException(WebSocketException): 69 | """ 70 | WebSocketBadStatusException will be raised when we get bad handshake status code. 71 | """ 72 | 73 | def __init__(self, message, status_code, status_message=None, resp_headers=None): 74 | msg = message % (status_code, status_message) 75 | super().__init__(msg) 76 | self.status_code = status_code 77 | self.resp_headers = resp_headers 78 | 79 | 80 | class WebSocketAddressException(WebSocketException): 81 | """ 82 | If the websocket address info cannot be found, this exception will be raised. 83 | """ 84 | pass 85 | -------------------------------------------------------------------------------- /nls/websocket/_handshake.py: -------------------------------------------------------------------------------- 1 | """ 2 | _handshake.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 
9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | import hashlib 20 | import hmac 21 | import os 22 | from base64 import encodebytes as base64encode 23 | from http import client as HTTPStatus 24 | from ._cookiejar import SimpleCookieJar 25 | from ._exceptions import * 26 | from ._http import * 27 | from ._logging import * 28 | from ._socket import * 29 | 30 | __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] 31 | 32 | # websocket supported version. 33 | VERSION = 13 34 | 35 | SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,) 36 | SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) 37 | 38 | CookieJar = SimpleCookieJar() 39 | 40 | 41 | class handshake_response: 42 | 43 | def __init__(self, status, headers, subprotocol): 44 | self.status = status 45 | self.headers = headers 46 | self.subprotocol = subprotocol 47 | CookieJar.add(headers.get("set-cookie")) 48 | 49 | 50 | def handshake(sock, hostname, port, resource, **options): 51 | headers, key = _get_handshake_headers(resource, hostname, port, options) 52 | 53 | header_str = "\r\n".join(headers) 54 | send(sock, header_str) 55 | dump("request header", header_str) 56 | #print("request header:", header_str) 57 | 58 | status, resp = _get_resp_headers(sock) 59 | if status in SUPPORTED_REDIRECT_STATUSES: 60 | return handshake_response(status, resp, None) 61 | success, subproto = _validate(resp, key, options.get("subprotocols")) 62 | if not success: 63 | raise WebSocketException("Invalid WebSocket Header") 64 | 65 | return 
handshake_response(status, resp, subproto) 66 | 67 | 68 | def _pack_hostname(hostname): 69 | # IPv6 address 70 | if ':' in hostname: 71 | return '[' + hostname + ']' 72 | 73 | return hostname 74 | 75 | 76 | def _get_handshake_headers(resource, host, port, options): 77 | headers = [ 78 | "GET %s HTTP/1.1" % resource, 79 | "Upgrade: websocket" 80 | ] 81 | if port == 80 or port == 443: 82 | hostport = _pack_hostname(host) 83 | else: 84 | hostport = "%s:%d" % (_pack_hostname(host), port) 85 | if "host" in options and options["host"] is not None: 86 | headers.append("Host: %s" % options["host"]) 87 | else: 88 | headers.append("Host: %s" % hostport) 89 | 90 | if "suppress_origin" not in options or not options["suppress_origin"]: 91 | if "origin" in options and options["origin"] is not None: 92 | headers.append("Origin: %s" % options["origin"]) 93 | else: 94 | headers.append("Origin: http://%s" % hostport) 95 | 96 | key = _create_sec_websocket_key() 97 | 98 | # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified 99 | if 'header' not in options or 'Sec-WebSocket-Key' not in options['header']: 100 | key = _create_sec_websocket_key() 101 | headers.append("Sec-WebSocket-Key: %s" % key) 102 | else: 103 | key = options['header']['Sec-WebSocket-Key'] 104 | 105 | if 'header' not in options or 'Sec-WebSocket-Version' not in options['header']: 106 | headers.append("Sec-WebSocket-Version: %s" % VERSION) 107 | 108 | if 'connection' not in options or options['connection'] is None: 109 | headers.append('Connection: Upgrade') 110 | else: 111 | headers.append(options['connection']) 112 | 113 | subprotocols = options.get("subprotocols") 114 | if subprotocols: 115 | headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) 116 | 117 | if "header" in options: 118 | header = options["header"] 119 | if isinstance(header, dict): 120 | header = [ 121 | ": ".join([k, v]) 122 | for k, v in header.items() 123 | if v is not None 124 | ] 125 | 
headers.extend(header) 126 | 127 | server_cookie = CookieJar.get(host) 128 | client_cookie = options.get("cookie", None) 129 | 130 | cookie = "; ".join(filter(None, [server_cookie, client_cookie])) 131 | 132 | if cookie: 133 | headers.append("Cookie: %s" % cookie) 134 | 135 | headers.append("") 136 | headers.append("") 137 | 138 | return headers, key 139 | 140 | 141 | def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES): 142 | status, resp_headers, status_message = read_headers(sock) 143 | if status not in success_statuses: 144 | raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers) 145 | return status, resp_headers 146 | 147 | 148 | _HEADERS_TO_CHECK = { 149 | "upgrade": "websocket", 150 | "connection": "upgrade", 151 | } 152 | 153 | 154 | def _validate(headers, key, subprotocols): 155 | subproto = None 156 | for k, v in _HEADERS_TO_CHECK.items(): 157 | r = headers.get(k, None) 158 | if not r: 159 | return False, None 160 | r = [x.strip().lower() for x in r.split(',')] 161 | if v not in r: 162 | return False, None 163 | 164 | if subprotocols: 165 | subproto = headers.get("sec-websocket-protocol", None) 166 | if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]: 167 | error("Invalid subprotocol: " + str(subprotocols)) 168 | return False, None 169 | subproto = subproto.lower() 170 | 171 | result = headers.get("sec-websocket-accept", None) 172 | if not result: 173 | return False, None 174 | result = result.lower() 175 | 176 | if isinstance(result, str): 177 | result = result.encode('utf-8') 178 | 179 | value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') 180 | hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() 181 | success = hmac.compare_digest(hashed, result) 182 | 183 | if success: 184 | return True, subproto 185 | else: 186 | return False, None 187 | 188 | 189 | def _create_sec_websocket_key(): 190 | randomness = os.urandom(16) 191 | return 
base64encode(randomness).decode('utf-8').strip() 192 | -------------------------------------------------------------------------------- /nls/websocket/_http.py: -------------------------------------------------------------------------------- 1 | """ 2 | _http.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | import errno 20 | import os 21 | import socket 22 | import sys 23 | 24 | from ._exceptions import * 25 | from ._logging import * 26 | from ._socket import* 27 | from ._ssl_compat import * 28 | from ._url import * 29 | 30 | from base64 import encodebytes as base64encode 31 | 32 | __all__ = ["proxy_info", "connect", "read_headers"] 33 | 34 | try: 35 | from python_socks.sync import Proxy 36 | from python_socks._errors import * 37 | from python_socks._types import ProxyType 38 | HAVE_PYTHON_SOCKS = True 39 | except: 40 | HAVE_PYTHON_SOCKS = False 41 | 42 | class ProxyError(Exception): 43 | pass 44 | 45 | class ProxyTimeoutError(Exception): 46 | pass 47 | 48 | class ProxyConnectionError(Exception): 49 | pass 50 | 51 | 52 | class proxy_info: 53 | 54 | def __init__(self, **options): 55 | self.proxy_host = options.get("http_proxy_host", None) 56 | if self.proxy_host: 57 | self.proxy_port = options.get("http_proxy_port", 0) 58 | self.auth = options.get("http_proxy_auth", None) 59 | self.no_proxy = options.get("http_no_proxy", None) 60 | self.proxy_protocol = 
options.get("proxy_type", "http") 61 | # Note: If timeout not specified, default python-socks timeout is 60 seconds 62 | self.proxy_timeout = options.get("timeout", None) 63 | if self.proxy_protocol not in ['http', 'socks4', 'socks4a', 'socks5', 'socks5h']: 64 | raise ProxyError("Only http, socks4, socks5 proxy protocols are supported") 65 | else: 66 | self.proxy_port = 0 67 | self.auth = None 68 | self.no_proxy = None 69 | self.proxy_protocol = "http" 70 | 71 | 72 | def _start_proxied_socket(url, options, proxy): 73 | if not HAVE_PYTHON_SOCKS: 74 | raise WebSocketException("Python Socks is needed for SOCKS proxying but is not available") 75 | 76 | hostname, port, resource, is_secure = parse_url(url) 77 | 78 | if proxy.proxy_protocol == "socks5": 79 | rdns = False 80 | proxy_type = ProxyType.SOCKS5 81 | if proxy.proxy_protocol == "socks4": 82 | rdns = False 83 | proxy_type = ProxyType.SOCKS4 84 | # socks5h and socks4a send DNS through proxy 85 | if proxy.proxy_protocol == "socks5h": 86 | rdns = True 87 | proxy_type = ProxyType.SOCKS5 88 | if proxy.proxy_protocol == "socks4a": 89 | rdns = True 90 | proxy_type = ProxyType.SOCKS4 91 | 92 | ws_proxy = Proxy.create( 93 | proxy_type=proxy_type, 94 | host=proxy.proxy_host, 95 | port=int(proxy.proxy_port), 96 | username=proxy.auth[0] if proxy.auth else None, 97 | password=proxy.auth[1] if proxy.auth else None, 98 | rdns=rdns) 99 | 100 | sock = ws_proxy.connect(hostname, port, timeout=proxy.proxy_timeout) 101 | 102 | if is_secure and HAVE_SSL: 103 | sock = _ssl_socket(sock, options.sslopt, hostname) 104 | elif is_secure: 105 | raise WebSocketException("SSL not available.") 106 | 107 | return sock, (hostname, port, resource) 108 | 109 | 110 | def connect(url, options, proxy, socket): 111 | # Use _start_proxied_socket() only for socks4 or socks5 proxy 112 | # Use _tunnel() for http proxy 113 | # TODO: Use python-socks for http protocol also, to standardize flow 114 | if proxy.proxy_host and not socket and not 
(proxy.proxy_protocol == "http"): 115 | return _start_proxied_socket(url, options, proxy) 116 | 117 | hostname, port, resource, is_secure = parse_url(url) 118 | 119 | if socket: 120 | return socket, (hostname, port, resource) 121 | 122 | addrinfo_list, need_tunnel, auth = _get_addrinfo_list( 123 | hostname, port, is_secure, proxy) 124 | if not addrinfo_list: 125 | raise WebSocketException( 126 | "Host not found.: " + hostname + ":" + str(port)) 127 | 128 | sock = None 129 | try: 130 | sock = _open_socket(addrinfo_list, options.sockopt, options.timeout) 131 | if need_tunnel: 132 | sock = _tunnel(sock, hostname, port, auth) 133 | 134 | if is_secure: 135 | if HAVE_SSL: 136 | sock = _ssl_socket(sock, options.sslopt, hostname) 137 | else: 138 | raise WebSocketException("SSL not available.") 139 | 140 | return sock, (hostname, port, resource) 141 | except: 142 | if sock: 143 | sock.close() 144 | raise 145 | 146 | 147 | def _get_addrinfo_list(hostname, port, is_secure, proxy): 148 | phost, pport, pauth = get_proxy_info( 149 | hostname, is_secure, proxy.proxy_host, proxy.proxy_port, proxy.auth, proxy.no_proxy) 150 | try: 151 | # when running on windows 10, getaddrinfo without socktype returns a socktype 0. 152 | # This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0` 153 | # or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM. 154 | if not phost: 155 | addrinfo_list = socket.getaddrinfo( 156 | hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) 157 | return addrinfo_list, False, None 158 | else: 159 | pport = pport and pport or 80 160 | # when running on windows 10, the getaddrinfo used above 161 | # returns a socktype 0. 
This generates an error exception: 162 | # _on_error: exception Socket type must be stream or datagram, not 0 163 | # Force the socket type to SOCK_STREAM 164 | addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP) 165 | return addrinfo_list, True, pauth 166 | except socket.gaierror as e: 167 | raise WebSocketAddressException(e) 168 | 169 | 170 | def _open_socket(addrinfo_list, sockopt, timeout): 171 | err = None 172 | for addrinfo in addrinfo_list: 173 | family, socktype, proto = addrinfo[:3] 174 | sock = socket.socket(family, socktype, proto) 175 | sock.settimeout(timeout) 176 | for opts in DEFAULT_SOCKET_OPTION: 177 | sock.setsockopt(*opts) 178 | for opts in sockopt: 179 | sock.setsockopt(*opts) 180 | 181 | address = addrinfo[4] 182 | err = None 183 | while not err: 184 | try: 185 | sock.connect(address) 186 | except socket.error as error: 187 | error.remote_ip = str(address[0]) 188 | try: 189 | eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) 190 | except: 191 | eConnRefused = (errno.ECONNREFUSED, ) 192 | if error.errno == errno.EINTR: 193 | continue 194 | elif error.errno in eConnRefused: 195 | err = error 196 | continue 197 | else: 198 | if sock: 199 | sock.close() 200 | raise error 201 | else: 202 | break 203 | else: 204 | continue 205 | break 206 | else: 207 | if err: 208 | raise err 209 | 210 | return sock 211 | 212 | 213 | def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): 214 | context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_TLS)) 215 | 216 | if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: 217 | cafile = sslopt.get('ca_certs', None) 218 | capath = sslopt.get('ca_cert_path', None) 219 | if cafile or capath: 220 | context.load_verify_locations(cafile=cafile, capath=capath) 221 | elif hasattr(context, 'load_default_certs'): 222 | context.load_default_certs(ssl.Purpose.SERVER_AUTH) 223 | if sslopt.get('certfile', None): 224 | context.load_cert_chain( 225 | sslopt['certfile'], 
226 | sslopt.get('keyfile', None), 227 | sslopt.get('password', None), 228 | ) 229 | # see 230 | # https://github.com/liris/websocket-client/commit/b96a2e8fa765753e82eea531adb19716b52ca3ca#commitcomment-10803153 231 | context.verify_mode = sslopt['cert_reqs'] 232 | if HAVE_CONTEXT_CHECK_HOSTNAME: 233 | context.check_hostname = check_hostname 234 | if 'ciphers' in sslopt: 235 | context.set_ciphers(sslopt['ciphers']) 236 | if 'cert_chain' in sslopt: 237 | certfile, keyfile, password = sslopt['cert_chain'] 238 | context.load_cert_chain(certfile, keyfile, password) 239 | if 'ecdh_curve' in sslopt: 240 | context.set_ecdh_curve(sslopt['ecdh_curve']) 241 | 242 | return context.wrap_socket( 243 | sock, 244 | do_handshake_on_connect=sslopt.get('do_handshake_on_connect', True), 245 | suppress_ragged_eofs=sslopt.get('suppress_ragged_eofs', True), 246 | server_hostname=hostname, 247 | ) 248 | 249 | 250 | def _ssl_socket(sock, user_sslopt, hostname): 251 | sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) 252 | sslopt.update(user_sslopt) 253 | 254 | certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') 255 | if certPath and os.path.isfile(certPath) \ 256 | and user_sslopt.get('ca_certs', None) is None: 257 | sslopt['ca_certs'] = certPath 258 | elif certPath and os.path.isdir(certPath) \ 259 | and user_sslopt.get('ca_cert_path', None) is None: 260 | sslopt['ca_cert_path'] = certPath 261 | 262 | if sslopt.get('server_hostname', None): 263 | hostname = sslopt['server_hostname'] 264 | 265 | check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( 266 | 'check_hostname', True) 267 | sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname) 268 | 269 | if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: 270 | match_hostname(sock.getpeercert(), hostname) 271 | 272 | return sock 273 | 274 | 275 | def _tunnel(sock, host, port, auth): 276 | debug("Connecting proxy...") 277 | connect_header = "CONNECT %s:%d HTTP/1.1\r\n" % (host, port) 278 | connect_header += "Host: 
%s:%d\r\n" % (host, port) 279 | 280 | # TODO: support digest auth. 281 | if auth and auth[0]: 282 | auth_str = auth[0] 283 | if auth[1]: 284 | auth_str += ":" + auth[1] 285 | encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '') 286 | connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str 287 | connect_header += "\r\n" 288 | dump("request header", connect_header) 289 | 290 | send(sock, connect_header) 291 | 292 | try: 293 | status, resp_headers, status_message = read_headers(sock) 294 | except Exception as e: 295 | raise WebSocketProxyException(str(e)) 296 | 297 | if status != 200: 298 | raise WebSocketProxyException( 299 | "failed CONNECT via proxy status: %r" % status) 300 | 301 | return sock 302 | 303 | 304 | def read_headers(sock): 305 | status = None 306 | status_message = None 307 | headers = {} 308 | trace("--- response header ---") 309 | 310 | while True: 311 | line = recv_line(sock) 312 | line = line.decode('utf-8').strip() 313 | if not line: 314 | break 315 | trace(line) 316 | if not status: 317 | 318 | status_info = line.split(" ", 2) 319 | status = int(status_info[1]) 320 | if len(status_info) > 2: 321 | status_message = status_info[2] 322 | else: 323 | kv = line.split(":", 1) 324 | if len(kv) == 2: 325 | key, value = kv 326 | if key.lower() == "set-cookie" and headers.get("set-cookie"): 327 | headers["set-cookie"] = headers.get("set-cookie") + "; " + value.strip() 328 | else: 329 | headers[key.lower()] = value.strip() 330 | else: 331 | raise WebSocketException("Invalid header") 332 | 333 | trace("-----------------------") 334 | 335 | return status, headers, status_message 336 | -------------------------------------------------------------------------------- /nls/websocket/_logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _logging.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed 
under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | import logging 24 | 25 | _logger = logging.getLogger('websocket') 26 | try: 27 | from logging import NullHandler 28 | except ImportError: 29 | class NullHandler(logging.Handler): 30 | def emit(self, record): 31 | pass 32 | 33 | _logger.addHandler(NullHandler()) 34 | 35 | _traceEnabled = False 36 | 37 | __all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace", 38 | "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"] 39 | 40 | 41 | def enableTrace(traceable, handler=logging.StreamHandler()): 42 | """ 43 | Turn on/off the traceability. 44 | 45 | Parameters 46 | ---------- 47 | traceable: bool 48 | If set to True, traceability is enabled. 
49 | """ 50 | global _traceEnabled 51 | _traceEnabled = traceable 52 | if traceable: 53 | _logger.addHandler(handler) 54 | _logger.setLevel(logging.ERROR) 55 | 56 | 57 | def dump(title, message): 58 | if _traceEnabled: 59 | _logger.debug("--- " + title + " ---") 60 | _logger.debug(message) 61 | _logger.debug("-----------------------") 62 | 63 | 64 | def error(msg): 65 | _logger.error(msg) 66 | 67 | 68 | def warning(msg): 69 | _logger.warning(msg) 70 | 71 | 72 | def debug(msg): 73 | _logger.debug(msg) 74 | 75 | 76 | def trace(msg): 77 | if _traceEnabled: 78 | _logger.debug(msg) 79 | 80 | 81 | def isEnabledForError(): 82 | return _logger.isEnabledFor(logging.ERROR) 83 | 84 | 85 | def isEnabledForDebug(): 86 | return _logger.isEnabledFor(logging.DEBUG) 87 | 88 | 89 | def isEnabledForTrace(): 90 | return _traceEnabled 91 | -------------------------------------------------------------------------------- /nls/websocket/_socket.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | 5 | """ 6 | _socket.py 7 | websocket - WebSocket client library for Python 8 | 9 | Copyright 2021 engn33r 10 | 11 | Licensed under the Apache License, Version 2.0 (the "License"); 12 | you may not use this file except in compliance with the License. 13 | You may obtain a copy of the License at 14 | 15 | http://www.apache.org/licenses/LICENSE-2.0 16 | 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 
22 | """ 23 | import errno 24 | import selectors 25 | import socket 26 | 27 | from ._exceptions import * 28 | from ._ssl_compat import * 29 | from ._utils import * 30 | 31 | DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] 32 | #if hasattr(socket, "SO_KEEPALIVE"): 33 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)) 34 | #if hasattr(socket, "TCP_KEEPIDLE"): 35 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30)) 36 | #if hasattr(socket, "TCP_KEEPINTVL"): 37 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10)) 38 | #if hasattr(socket, "TCP_KEEPCNT"): 39 | # DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3)) 40 | 41 | _default_timeout = None 42 | 43 | __all__ = ["DEFAULT_SOCKET_OPTION", "sock_opt", "setdefaulttimeout", "getdefaulttimeout", 44 | "recv", "recv_line", "send"] 45 | 46 | 47 | class sock_opt: 48 | 49 | def __init__(self, sockopt, sslopt): 50 | if sockopt is None: 51 | sockopt = [] 52 | if sslopt is None: 53 | sslopt = {} 54 | self.sockopt = sockopt 55 | self.sslopt = sslopt 56 | self.timeout = None 57 | 58 | 59 | def setdefaulttimeout(timeout): 60 | """ 61 | Set the global timeout setting to connect. 62 | 63 | Parameters 64 | ---------- 65 | timeout: int or float 66 | default socket timeout time (in seconds) 67 | """ 68 | global _default_timeout 69 | _default_timeout = timeout 70 | 71 | 72 | def getdefaulttimeout(): 73 | """ 74 | Get default timeout 75 | 76 | Returns 77 | ---------- 78 | _default_timeout: int or float 79 | Return the global timeout setting (in seconds) to connect. 
80 | """ 81 | return _default_timeout 82 | 83 | 84 | def recv(sock, bufsize): 85 | if not sock: 86 | raise WebSocketConnectionClosedException("socket is already closed.") 87 | 88 | def _recv(): 89 | try: 90 | return sock.recv(bufsize) 91 | except SSLWantReadError: 92 | pass 93 | except socket.error as exc: 94 | error_code = extract_error_code(exc) 95 | if error_code is None: 96 | raise 97 | if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: 98 | raise 99 | 100 | sel = selectors.DefaultSelector() 101 | sel.register(sock, selectors.EVENT_READ) 102 | 103 | r = sel.select(sock.gettimeout()) 104 | sel.close() 105 | 106 | if r: 107 | return sock.recv(bufsize) 108 | 109 | try: 110 | if sock.gettimeout() == 0: 111 | bytes_ = sock.recv(bufsize) 112 | else: 113 | bytes_ = _recv() 114 | except socket.timeout as e: 115 | message = extract_err_message(e) 116 | raise WebSocketTimeoutException(message) 117 | except SSLError as e: 118 | message = extract_err_message(e) 119 | if isinstance(message, str) and 'timed out' in message: 120 | raise WebSocketTimeoutException(message) 121 | else: 122 | raise 123 | 124 | if not bytes_: 125 | raise WebSocketConnectionClosedException( 126 | "Connection to remote host was lost.") 127 | 128 | return bytes_ 129 | 130 | 131 | def recv_line(sock): 132 | line = [] 133 | while True: 134 | c = recv(sock, 1) 135 | line.append(c) 136 | if c == b'\n': 137 | break 138 | return b''.join(line) 139 | 140 | 141 | def send(sock, data): 142 | if isinstance(data, str): 143 | data = data.encode('utf-8') 144 | 145 | if not sock: 146 | raise WebSocketConnectionClosedException("socket is already closed.") 147 | 148 | def _send(): 149 | try: 150 | return sock.send(data) 151 | except SSLWantWriteError: 152 | pass 153 | except socket.error as exc: 154 | error_code = extract_error_code(exc) 155 | if error_code is None: 156 | raise 157 | if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: 158 | raise 159 | 160 | sel = 
selectors.DefaultSelector() 161 | sel.register(sock, selectors.EVENT_WRITE) 162 | 163 | w = sel.select(sock.gettimeout()) 164 | sel.close() 165 | 166 | if w: 167 | return sock.send(data) 168 | 169 | try: 170 | if sock.gettimeout() == 0: 171 | return sock.send(data) 172 | else: 173 | return _send() 174 | except socket.timeout as e: 175 | message = extract_err_message(e) 176 | raise WebSocketTimeoutException(message) 177 | except Exception as e: 178 | message = extract_err_message(e) 179 | if isinstance(message, str) and "timed out" in message: 180 | raise WebSocketTimeoutException(message) 181 | else: 182 | raise 183 | -------------------------------------------------------------------------------- /nls/websocket/_ssl_compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | _ssl_compat.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 
18 | """ 19 | __all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"] 20 | 21 | try: 22 | import ssl 23 | from ssl import SSLError 24 | from ssl import SSLWantReadError 25 | from ssl import SSLWantWriteError 26 | HAVE_CONTEXT_CHECK_HOSTNAME = False 27 | if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): 28 | HAVE_CONTEXT_CHECK_HOSTNAME = True 29 | 30 | __all__.append("HAVE_CONTEXT_CHECK_HOSTNAME") 31 | HAVE_SSL = True 32 | except ImportError: 33 | # dummy class of SSLError for environment without ssl support 34 | class SSLError(Exception): 35 | pass 36 | 37 | class SSLWantReadError(Exception): 38 | pass 39 | 40 | class SSLWantWriteError(Exception): 41 | pass 42 | 43 | ssl = None 44 | HAVE_SSL = False 45 | -------------------------------------------------------------------------------- /nls/websocket/_url.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | """ 5 | _url.py 6 | websocket - WebSocket client library for Python 7 | 8 | Copyright 2021 engn33r 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 
21 | """ 22 | 23 | import os 24 | import socket 25 | import struct 26 | 27 | from urllib.parse import unquote, urlparse 28 | 29 | 30 | __all__ = ["parse_url", "get_proxy_info"] 31 | 32 | 33 | def parse_url(url): 34 | """ 35 | parse url and the result is tuple of 36 | (hostname, port, resource path and the flag of secure mode) 37 | 38 | Parameters 39 | ---------- 40 | url: str 41 | url string. 42 | """ 43 | if ":" not in url: 44 | raise ValueError("url is invalid") 45 | 46 | scheme, url = url.split(":", 1) 47 | 48 | parsed = urlparse(url, scheme="http") 49 | if parsed.hostname: 50 | hostname = parsed.hostname 51 | else: 52 | raise ValueError("hostname is invalid") 53 | port = 0 54 | if parsed.port: 55 | port = parsed.port 56 | 57 | is_secure = False 58 | if scheme == "ws": 59 | if not port: 60 | port = 80 61 | elif scheme == "wss": 62 | is_secure = True 63 | if not port: 64 | port = 443 65 | else: 66 | raise ValueError("scheme %s is invalid" % scheme) 67 | 68 | if parsed.path: 69 | resource = parsed.path 70 | else: 71 | resource = "/" 72 | 73 | if parsed.query: 74 | resource += "?" 
+ parsed.query 75 | 76 | return hostname, port, resource, is_secure 77 | 78 | 79 | DEFAULT_NO_PROXY_HOST = ["localhost", "127.0.0.1"] 80 | 81 | 82 | def _is_ip_address(addr): 83 | try: 84 | socket.inet_aton(addr) 85 | except socket.error: 86 | return False 87 | else: 88 | return True 89 | 90 | 91 | def _is_subnet_address(hostname): 92 | try: 93 | addr, netmask = hostname.split("/") 94 | return _is_ip_address(addr) and 0 <= int(netmask) < 32 95 | except ValueError: 96 | return False 97 | 98 | 99 | def _is_address_in_network(ip, net): 100 | ipaddr = struct.unpack('!I', socket.inet_aton(ip))[0] 101 | netaddr, netmask = net.split('/') 102 | netaddr = struct.unpack('!I', socket.inet_aton(netaddr))[0] 103 | 104 | netmask = (0xFFFFFFFF << (32 - int(netmask))) & 0xFFFFFFFF 105 | return ipaddr & netmask == netaddr 106 | 107 | 108 | def _is_no_proxy_host(hostname, no_proxy): 109 | if not no_proxy: 110 | v = os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")).replace(" ", "") 111 | if v: 112 | no_proxy = v.split(",") 113 | if not no_proxy: 114 | no_proxy = DEFAULT_NO_PROXY_HOST 115 | 116 | if '*' in no_proxy: 117 | return True 118 | if hostname in no_proxy: 119 | return True 120 | if _is_ip_address(hostname): 121 | return any([_is_address_in_network(hostname, subnet) for subnet in no_proxy if _is_subnet_address(subnet)]) 122 | for domain in [domain for domain in no_proxy if domain.startswith('.')]: 123 | if hostname.endswith(domain): 124 | return True 125 | return False 126 | 127 | 128 | def get_proxy_info( 129 | hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None, 130 | no_proxy=None, proxy_type='http'): 131 | """ 132 | Try to retrieve proxy host and port from environment 133 | if not provided in options. 134 | Result is (proxy_host, proxy_port, proxy_auth). 135 | proxy_auth is tuple of username and password 136 | of proxy authentication information. 137 | 138 | Parameters 139 | ---------- 140 | hostname: str 141 | Websocket server name. 
142 | is_secure: bool 143 | Is the connection secure? (wss) looks for "https_proxy" in env 144 | before falling back to "http_proxy" 145 | proxy_host: str 146 | http proxy host name. 147 | http_proxy_port: str or int 148 | http proxy port. 149 | http_no_proxy: list 150 | Whitelisted host names that don't use the proxy. 151 | http_proxy_auth: tuple 152 | HTTP proxy auth information. Tuple of username and password. Default is None. 153 | proxy_type: str 154 | Specify the proxy protocol (http, socks4, socks4a, socks5, socks5h). Default is "http". 155 | Use socks4a or socks5h if you want to send DNS requests through the proxy. 156 | """ 157 | if _is_no_proxy_host(hostname, no_proxy): 158 | return None, 0, None 159 | 160 | if proxy_host: 161 | port = proxy_port 162 | auth = proxy_auth 163 | return proxy_host, port, auth 164 | 165 | env_keys = ["http_proxy"] 166 | if is_secure: 167 | env_keys.insert(0, "https_proxy") 168 | 169 | for key in env_keys: 170 | value = os.environ.get(key, os.environ.get(key.upper(), "")).replace(" ", "") 171 | if value: 172 | proxy = urlparse(value) 173 | auth = (unquote(proxy.username), unquote(proxy.password)) if proxy.username else None 174 | return proxy.hostname, proxy.port, auth 175 | 176 | return None, 0, None 177 | -------------------------------------------------------------------------------- /nls/websocket/_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | _url.py 3 | websocket - WebSocket client library for Python 4 | 5 | Copyright 2021 engn33r 6 | 7 | Licensed under the Apache License, Version 2.0 (the "License"); 8 | you may not use this file except in compliance with the License. 
9 | You may obtain a copy of the License at 10 | 11 | http://www.apache.org/licenses/LICENSE-2.0 12 | 13 | Unless required by applicable law or agreed to in writing, software 14 | distributed under the License is distributed on an "AS IS" BASIS, 15 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | See the License for the specific language governing permissions and 17 | limitations under the License. 18 | """ 19 | __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] 20 | 21 | 22 | class NoLock: 23 | 24 | def __enter__(self): 25 | pass 26 | 27 | def __exit__(self, exc_type, exc_value, traceback): 28 | pass 29 | 30 | 31 | try: 32 | # If wsaccel is available we use compiled routines to validate UTF-8 33 | # strings. 34 | from wsaccel.utf8validator import Utf8Validator 35 | 36 | def _validate_utf8(utfbytes): 37 | return Utf8Validator().validate(utfbytes)[0] 38 | 39 | except ImportError: 40 | # UTF-8 validator 41 | # python implementation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ 42 | 43 | _UTF8_ACCEPT = 0 44 | _UTF8_REJECT = 12 45 | 46 | _UTF8D = [ 47 | # The first part of the table maps bytes to character classes that 48 | # to reduce the size of the transition table and create bitmasks. 49 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 50 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 51 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 52 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 53 | 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 54 | 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 55 | 8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 56 | 10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8, 57 | 58 | # The second part is a transition table that maps a combination 59 | # of a state of the automaton and a character class to a state. 
60 | 0,12,24,36,60,96,84,12,12,12,48,72, 12,12,12,12,12,12,12,12,12,12,12,12, 61 | 12, 0,12,12,12,12,12, 0,12, 0,12,12, 12,24,12,12,12,12,12,24,12,24,12,12, 62 | 12,12,12,12,12,12,12,24,12,12,12,12, 12,24,12,12,12,12,12,12,12,24,12,12, 63 | 12,12,12,12,12,12,12,36,12,36,12,12, 12,36,12,12,12,12,12,36,12,36,12,12, 64 | 12,36,12,12,12,12,12,12,12,12,12,12, ] 65 | 66 | def _decode(state, codep, ch): 67 | tp = _UTF8D[ch] 68 | 69 | codep = (ch & 0x3f) | (codep << 6) if ( 70 | state != _UTF8_ACCEPT) else (0xff >> tp) & ch 71 | state = _UTF8D[256 + state + tp] 72 | 73 | return state, codep 74 | 75 | def _validate_utf8(utfbytes): 76 | state = _UTF8_ACCEPT 77 | codep = 0 78 | for i in utfbytes: 79 | state, codep = _decode(state, codep, i) 80 | if state == _UTF8_REJECT: 81 | return False 82 | 83 | return True 84 | 85 | 86 | def validate_utf8(utfbytes): 87 | """ 88 | validate utf8 byte string. 89 | utfbytes: utf byte string to check. 90 | return value: if valid utf8 string, return true. Otherwise, return false. 
91 | """ 92 | return _validate_utf8(utfbytes) 93 | 94 | 95 | def extract_err_message(exception): 96 | if exception.args: 97 | return exception.args[0] 98 | else: 99 | return None 100 | 101 | 102 | def extract_error_code(exception): 103 | if exception.args and len(exception.args) > 1: 104 | return exception.args[0] if isinstance(exception.args[0], int) else None 105 | -------------------------------------------------------------------------------- /notify_to_master.py: -------------------------------------------------------------------------------- 1 | import logger 2 | import push_to_qiye_wx 3 | from Config import Config 4 | 5 | 6 | def notify(content): 7 | config = Config().get_instance() 8 | qiye_weixin = config.get('qiye_weixin') 9 | if qiye_weixin is None: 10 | logger.e("企业微信配置不存在, 不需要推送: {}".format(content)) 11 | return 12 | secret = qiye_weixin['secret'] 13 | qiye_id = qiye_weixin['qiye_id'] 14 | agent_id = qiye_weixin['agent_id'] 15 | if secret is None or qiye_id is None or agent_id is None: 16 | logger.e("企业微信配置不完整,不需要推送: {}".format(content)) 17 | return 18 | push_to_qiye_wx.push_to_weixin(content) 19 | 20 | 21 | if __name__ == '__main__': 22 | notify('测试') 23 | -------------------------------------------------------------------------------- /pdu_exceptions.py: -------------------------------------------------------------------------------- 1 | """ Module defines exceptions used by gsmmodem """ 2 | 3 | 4 | class GsmModemException(Exception): 5 | """ Base exception raised for error conditions when interacting with the GSM modem """ 6 | 7 | 8 | class TimeoutException(GsmModemException): 9 | """ Raised when a write command times out """ 10 | 11 | def __init__(self, data=None): 12 | """ @param data: Any data that was read was read before timeout occurred (if applicable) """ 13 | super(TimeoutException, self).__init__(data) 14 | self.data = data 15 | 16 | 17 | class InvalidStateException(GsmModemException): 18 | """ Raised when an API method call is invoked on 
an object that is in an incorrect state """ 19 | 20 | 21 | class InterruptedException(InvalidStateException): 22 | """ Raised when execution of an AT command is interrupt by a state change. 23 | May contain another exception that was the cause of the interruption """ 24 | 25 | def __init__(self, message, cause=None): 26 | """ @param cause: the exception that caused this interruption (usually a CmeError) """ 27 | super(InterruptedException, self).__init__(message) 28 | self.cause = cause 29 | 30 | 31 | class CommandError(GsmModemException): 32 | """ Raised if the modem returns an error in response to an AT command 33 | 34 | May optionally include an error type (CME or CMS) and -code (error-specific). 35 | """ 36 | 37 | _description = '' 38 | 39 | def __init__(self, command=None, type=None, code=None): 40 | self.command = command 41 | self.type = type 42 | self.code = code 43 | if type != None and code != None: 44 | super(CommandError, self).__init__('{0} {1}{2}'.format(type, code, 45 | ' ({0})'.format(self._description) if len( 46 | self._description) > 0 else '')) 47 | elif command != None: 48 | super(CommandError, self).__init__(command) 49 | else: 50 | super(CommandError, self).__init__() 51 | 52 | 53 | class CmeError(CommandError): 54 | """ ME error result code : +CME ERROR: 55 | 56 | Issued in response to an AT command 57 | """ 58 | 59 | def __new__(cls, *args, **kwargs): 60 | # Return a specialized version of this class if possible 61 | if len(args) >= 2: 62 | code = args[1] 63 | if code == 11: 64 | return PinRequiredError(args[0]) 65 | elif code == 16: 66 | return IncorrectPinError(args[0]) 67 | elif code == 12: 68 | return PukRequiredError(args[0]) 69 | return super(CmeError, cls).__new__(cls, *args, **kwargs) 70 | 71 | def __init__(self, command, code): 72 | super(CmeError, self).__init__(command, 'CME', code) 73 | 74 | 75 | class SecurityException(CmeError): 76 | """ Security-related CME error """ 77 | 78 | def __init__(self, command, code): 79 | 
super(SecurityException, self).__init__(command, code) 80 | 81 | 82 | class PinRequiredError(SecurityException): 83 | """ Raised if an operation failed because the SIM card's PIN has not been entered """ 84 | 85 | _description = 'SIM card PIN is required' 86 | 87 | def __init__(self, command, code=11): 88 | super(PinRequiredError, self).__init__(command, code) 89 | 90 | 91 | class IncorrectPinError(SecurityException): 92 | """ Raised if an incorrect PIN is entered """ 93 | 94 | _description = 'Incorrect PIN entered' 95 | 96 | def __init__(self, command, code=16): 97 | super(IncorrectPinError, self).__init__(command, code) 98 | 99 | 100 | class PukRequiredError(SecurityException): 101 | """ Raised an operation failed because the SIM card's PUK is required (SIM locked) """ 102 | 103 | _description = "PUK required (SIM locked)" 104 | 105 | def __init__(self, command, code=12): 106 | super(PukRequiredError, self).__init__(command, code) 107 | 108 | 109 | class CmsError(CommandError): 110 | """ Message service failure result code: +CMS ERROR : 111 | 112 | Issued in response to an AT command 113 | """ 114 | 115 | def __new__(cls, *args, **kwargs): 116 | # Return a specialized version of this class if possible 117 | if len(args) >= 2: 118 | code = args[1] 119 | if code == 330: 120 | return SmscNumberUnknownError(args[0]) 121 | return super(CmsError, cls).__new__(cls, *args, **kwargs) 122 | 123 | def __init__(self, command, code): 124 | super(CmsError, self).__init__(command, 'CMS', code) 125 | 126 | 127 | class SmscNumberUnknownError(CmsError): 128 | """ Raised if the SMSC (service centre) address is missing when trying to send an SMS message """ 129 | 130 | _description = 'SMSC number not set' 131 | 132 | def __init__(self, command, code=330): 133 | super(SmscNumberUnknownError, self).__init__(command, code) 134 | 135 | 136 | class EncodingError(GsmModemException): 137 | """ Raised if a decoding- or encoding operation failed """ 138 | 
# /push_to_qiye_wx.py
"""Push plain-text notifications to a WeCom (Qiye Weixin) application.

``push_to_weixin`` runs the push on a background thread so the caller is not
blocked by the round-trips to the WeCom HTTP API.
"""
import json
import threading

import requests

import logger
from Config import Config

# Seconds to wait for each WeCom API request. Without a timeout a stalled
# connection would block the worker thread (and interpreter exit) forever.
_REQUEST_TIMEOUT = 10

_TOKEN_URL = "https://qyapi.weixin.qq.com/cgi-bin/gettoken"
_SEND_URL = "https://qyapi.weixin.qq.com/cgi-bin/message/send"


def push_to(_qiye_id, _agent_id, _secret, _msg):
    """
    Send ``_msg`` as a text message to all members (``@all``) of a WeCom app.

    https://blog.csdn.net/haijiege/article/details/86529460
    https://daliuzi.cn/tasker-forward-sms-wechat/

    :param _qiye_id: WeCom corp id (sent as ``corpid``).
    :param _agent_id: id of the application sending the message (``agentid``).
    :param _secret: the application's secret (sent as ``corpsecret``).
    :param _msg: text content to send.
    :raises requests.RequestException: on network/HTTP transport failures.
    :raises ValueError: if the token endpoint does not return JSON.
    """
    # Let requests build and URL-encode the query string; this also accepts
    # non-str ids (e.g. ints loaded from YAML) that concatenation rejected.
    response = requests.get(
        _TOKEN_URL,
        params={"corpid": _qiye_id, "corpsecret": _secret},
        timeout=_REQUEST_TIMEOUT,
    )
    get_result = response.json()
    _access_token = get_result.get('access_token')
    if not _access_token:
        # On failure the API returns errcode/errmsg without an access_token;
        # report the payload instead of crashing with a bare KeyError.
        logger.e(f"获取企业微信 access_token 失败: {get_result}")
        return

    _msg_builder = {
        "touser": "@all",
        "msgtype": "text",
        "agentid": _agent_id,
        "text": {
            "content": _msg
        },
        "safe": 0
    }

    msg_json = json.dumps(_msg_builder)
    post_result = requests.post(
        _SEND_URL,
        params={"access_token": str(_access_token)},
        data=msg_json,
        timeout=_REQUEST_TIMEOUT,
    )
    logger.i(f"推送结果: {post_result.text}")


def _threaded_push_to_weixin(sms):
    """Read the WeCom credentials from config and push ``sms`` (worker-thread body)."""
    config = Config().get_instance()
    qiye_weixin = config.get('qiye_weixin')
    if qiye_weixin is None:
        logger.e("企业微信配置不存在")
        return
    # .get() so a missing key is reported like an empty one rather than
    # raising KeyError before the completeness check below can run.
    secret = qiye_weixin.get('secret')
    qiye_id = qiye_weixin.get('qiye_id')
    agent_id = qiye_weixin.get('agent_id')
    if secret is None or qiye_id is None or agent_id is None:
        logger.e("企业微信配置不完整")
        return
    try:
        push_to(qiye_id, agent_id, secret, sms)
    except (requests.RequestException, ValueError) as e:
        # Exceptions raised on this worker thread would otherwise only reach
        # threading.excepthook; record them through the project logger.
        logger.e(f"推送到企业微信失败: {e}")


def push_to_weixin(sms):
    """Push ``sms`` to WeCom asynchronously on a new (non-daemon) thread."""
    thread = threading.Thread(target=_threaded_push_to_weixin, args=(sms,))
    thread.start()


if __name__ == '__main__':
    push_to_weixin('测试')
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests~=2.31.0 2 | pyserial~=3.5 3 | 4 | # 阿里云语音识别需要的 5 | oss2~=2.18.6 6 | aliyun-python-sdk-core>=2.13.3 7 | dashscope~=1.20.1 8 | pyinstaller~=6.9.0 9 | setuptools~=68.2.0 10 | PyYAML~=6.0.1 -------------------------------------------------------------------------------- /say_hello.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/say_hello.pcm -------------------------------------------------------------------------------- /screenshots/4g_minipcie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/4g_minipcie.png -------------------------------------------------------------------------------- /screenshots/ai_phon_call.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/ai_phon_call.png -------------------------------------------------------------------------------- /screenshots/create_app.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/create_app.png -------------------------------------------------------------------------------- /screenshots/pi_4b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/pi_4b.png 
-------------------------------------------------------------------------------- /screenshots/pi_4b_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/pi_4b_2.png -------------------------------------------------------------------------------- /screenshots/pi_cm4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andforce/AI-Phone-Call/ebf5c64cfb7e1a9d381c0b3b6ae6906b5d1d2bcb/screenshots/pi_cm4.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | ''' 3 | Licensed to the Apache Software Foundation(ASF) under one 4 | or more contributor license agreements. See the NOTICE file 5 | distributed with this work for additional information 6 | regarding copyright ownership. The ASF licenses this file 7 | to you under the Apache License, Version 2.0 (the 8 | "License"); you may not use this file except in compliance 9 | with the License. You may obtain a copy of the License at 10 | http: // www.apache.org/licenses/LICENSE-2.0 11 | Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations under 16 | the License. 
17 | ''' 18 | import os 19 | import setuptools 20 | 21 | with open("README.md", "r") as f: 22 | long_description = f.read() 23 | 24 | requires = [ 25 | "oss2", 26 | "aliyun-python-sdk-core>=2.13.3", 27 | "matplotlib>=3.3.4" 28 | ] 29 | 30 | setup_args = { 31 | 'version': "1.0.0", 32 | 'author': "jiaqi.sjq", 33 | 'author_email': "jiaqi.sjq@alibaba-inc.com", 34 | 'description': "python sdk for nls", 35 | 'license': "Apache License 2.0", 36 | 'long_description': long_description, 37 | 'long_description_content_type': "text/markdown", 38 | 'keywords': ["nls", "sdk"], 39 | 'url': "https://github.com/..", 40 | 'packages': ["nls", "nls/websocket"], 41 | 'install_requires': requires, 42 | 'classifiers': [ 43 | "Programming Language :: Python :: 3", 44 | "License :: OSI Approved :: Apache Software License", 45 | "Operating System :: OS Independent", 46 | ], 47 | } 48 | 49 | setuptools.setup(name='nls', **setup_args) 50 | --------------------------------------------------------------------------------