├── start_server.bat
├── test_chat.py
├── readme.md
├── test.py
├── rtx_api_july_2024.py
├── rtx_api_2_11.py
├── rtx_api_3_5.py
├── rtx_server_july_2024.py
├── rtx_server_0_5_2025.py
└── rtx_api_4_24.py

/start_server.bat:
--------------------------------------------------------------------------------
PowerShell -Command "Start-Process python3 rtx_server_0_5_2025.py -Verb RunAs"
--------------------------------------------------------------------------------
/test_chat.py:
--------------------------------------------------------------------------------
import rtx_api_july_2024 as rtx_api

while True:
    user_input = input("$ ")
    for data in rtx_api.send_message_streaming(user_input):
        print(data)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
## Python API for Chat With RTX

### Usage

`.\start_server.bat`

```python
import rtx_api_july_2024 as rtx_api

response = rtx_api.send_message("write fire emoji")
print(response)
```

### Speed
Chat With RTX builds int4 (W4A16 AWQ) TensorRT engines for its LLMs.

| Model | On RTX 4090 |
|-|-|
| Mistral | 457 char/sec |
| Llama2 | 315 char/sec |
| ChatGLM3 | 385 char/sec |
| Gemma | 407 char/sec |

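### Streaming
`send_message` waits for the full reply; for chunk-by-chunk output, iterate over `send_message_streaming` the same way `test_chat.py` does. A minimal sketch:

```python
import rtx_api_july_2024 as rtx_api

# chunks arrive as the model generates them
for chunk in rtx_api.send_message_streaming("write fire emoji"):
    print(chunk)
```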

```
Update History of Chat With RTX
3.2024 Removed YouTube video transcript fetch
4.2024 Added Whisper speech-to-text model
7.2024 Electron app UI
```
LICENSE: CC0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import rtx_api_july_2024 as rtx_api
from time import time_ns

if __name__ == '__main__':
    current_time_ns = time_ns()
    start_sec = current_time_ns / 1_000_000_000

    out = ""
    for i in range(10):
        print('iteration', i)
        tmp = rtx_api.send_message("Write essay on: War, Famine, and Pestilence")
        print(tmp)
        out += tmp

    current_time_ns = time_ns()
    end_sec = current_time_ns / 1_000_000_000

    # characters generated per second across the 10 completions
    took = end_sec - start_sec
    char_per_second = len(out) / took
    print("char/sec:", int(char_per_second))

    # print(out + "\n")
    # out = rtx_api.send_message("Write single sentence summary of: " + out)
    # print("SUMMARY: " + out)
--------------------------------------------------------------------------------
/rtx_api_july_2024.py:
--------------------------------------------------------------------------------
from http.client import HTTPConnection

def send_message_streaming(message):
    # POST the prompt as plain text to the local server started by start_server.bat
    connection = HTTPConnection('localhost', 8000)
    connection.putrequest('POST', '/', skip_host=True)
    connection.putheader('Content-Type', 'text/plain')
    connection._http_vsn = 11
    connection._http_vsn_str = 'HTTP/1.1'
    encoded_message = message.encode('utf-8')
    connection.putheader('Content-Length', str(len(encoded_message)))
    connection.endheaders()
    connection.send(encoded_message)
    response = connection.getresponse()

    if response.status == 200:
        # the server streams the answer with chunked transfer encoding;
        # yield it line by line until the response is exhausted
        while True:
            chunk = response.readline().decode('utf-8')
            if not chunk:
                return
            yield chunk.strip()
    else:
        raise Exception(f"Error: Server responded with status {response.status}")

def send_message(message):
    response = ""
    for data in send_message_streaming(message):
        response += data + '\n'
    return response
--------------------------------------------------------------------------------
/rtx_api_2_11.py:
--------------------------------------------------------------------------------
import requests
import random
import string
import json

def join_queue(session_hash, fn_index, port, chatdata):
    python_object = {
        "data": chatdata,
        "event_data": None,
        "fn_index": fn_index,
        "trigger_id": 46,
        "session_hash": session_hash
    }
    json_string = json.dumps(python_object)

    url = f"http://127.0.0.1:{port}/queue/join?__theme=dark"

    response = requests.post(url, data=json_string)
    # print("Join Queue Response:", response.json())

def listen_for_updates(session_hash, port):
    url = f"http://127.0.0.1:{port}/queue/data?session_hash={session_hash}"

    response = requests.get(url, stream=True)
    for line in response.iter_lines():
        if line:
            try:
                # strip the "data:" SSE prefix before parsing the JSON payload
                data = json.loads(line[5:])
                # if data['msg'] == 'process_generating':
                #     print(data['output']['data'][0][0][1])
                if data['msg'] == 'process_completed':
                    return data['output']['data'][0][0][1]
            except Exception as e:
                pass
    return ""

def send_message(message, port):
    session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))

    # fn_indexes are Gradio-generated event indexes from rag/trt/ui/user_interface.py
    join_queue(session_hash, 36, port, [])
    listen_for_updates(session_hash, port)

    join_queue(session_hash, 37, port, [])
    listen_for_updates(session_hash, port)

    chatdata = ["", [], "AI model default", None]
    join_queue(session_hash, 38, port, chatdata)
    listen_for_updates(session_hash, port)

    chatdata = ["", []]
    join_queue(session_hash, 39, port, chatdata)
    listen_for_updates(session_hash, port)

    # add chat history here -v
    chatdata = [[[message, None]], None]
    join_queue(session_hash, 40, port, chatdata)
    return listen_for_updates(session_hash, port)
--------------------------------------------------------------------------------
/rtx_api_3_5.py:
--------------------------------------------------------------------------------
import requests
import random
import string
import psutil
import json

port = None

def find_chat_with_rtx_port():
    # probe local listening ports owned by the ChatWithRTX process;
    # an empty POST to /queue/join answers 422 on the Gradio server
    global port
    connections = psutil.net_connections(kind='inet')
    for host in connections:
        try:
            if host.pid:
                process = psutil.Process(host.pid)
                if "ChatWithRTX" in process.exe():
                    test_port = host.laddr.port
                    url = f"http://127.0.0.1:{test_port}/queue/join"
                    response = requests.post(url, data="", timeout=0.05)
                    if response.status_code == 422:
                        port = test_port
                        return
        except:
            pass

def join_queue(session_hash, fn_index, port, chatdata):
    # fn_indexes are Gradio-generated event indexes from rag/trt/ui/user_interface.py
    python_object = {
        "data": chatdata,
        "event_data": None,
        "fn_index": fn_index,
        "session_hash": session_hash
    }
    json_string = json.dumps(python_object)

    url = f"http://127.0.0.1:{port}/queue/join"
    response = requests.post(url, data=json_string)
    # print("Join Queue Response:", response.json())

def listen_for_updates(session_hash, port):
    url = f"http://127.0.0.1:{port}/queue/data?session_hash={session_hash}"

    response = requests.get(url, stream=True)
    for line in response.iter_lines():
        if line:
            try:
                data = json.loads(line[5:])
                # if data['msg'] == 'process_generating':
                #     print(data['output']['data'][0][0][1])
                if data['msg'] == 'process_completed':
                    return data['output']['data'][0][0][1]
            except Exception as e:
                pass
    return ""

def send_message(message):
    if not port:
        find_chat_with_rtx_port()
    if not port:
        raise Exception("Failed to find a server port for 'Chat with RTX'. Ensure the server is running.")

    session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))

    # add chat history here -v
    chatdata = [[[message, None]], None]
    join_queue(session_hash, 34, port, chatdata)
    return listen_for_updates(session_hash, port)
--------------------------------------------------------------------------------
/rtx_server_july_2024.py:
--------------------------------------------------------------------------------
import subprocess
import shutil
import os

def copy_self_to_another_directory(target_directory):
    current_script = os.path.realpath(__file__)
    target_file = os.path.join(target_directory, os.path.basename(current_script))
    if os.path.exists(target_file):
        os.remove(target_file)
    shutil.copy(current_script, target_file)

# copy & run self in the nvidia virtual env
python_bin = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\env_nvd_rag\\Scripts\\python.exe"
target_directory = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\RAG\\trt-llm-rag-windows-ChatRTX_0.4.0\\ChatRTXUI\\engine"
if os.getcwd() != target_directory:
    copy_self_to_another_directory(target_directory)
    os.chdir(target_directory)
    subprocess.run([python_bin, target_directory + "\\" + os.path.basename(__file__)])
    exit(0)


from http.server import BaseHTTPRequestHandler, HTTPServer
from configuration import Configuration
from backend import Backend, Mode

data_path = os.path.expandvars("%programdata%\\NVIDIA Corporation\\chatrtx")
backend = Backend(model_setup_dir=data_path)

backend.init_model(model_id="mistral_7b_AWQ_int4_chat")
# backend.init_model(model_id="llama2_13b_AWQ_INT4_chat")
# backend.init_model(model_id="chatglm3_6b_AWQ_int4")
# backend.init_model(model_id="gemma_7b_int4")
# backend.init_model(model_id="clip_model")

status = backend.ChatRTX(chatrtx_mode=Mode.AI)
# status = backend.ChatRTX(chatrtx_mode=Mode.RAG, data_dir=dataset_dir)

class RequestHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        post_data = self.rfile.read(content_length)

        input_string = post_data.decode('utf-8')
        answer_stream = backend.query_stream(query=input_string)

        self.send_response(200)
        self.send_header('Content-type', 'text/plain')
        self.send_header('Transfer-Encoding', 'chunked')
        self.end_headers()

        for part in answer_stream:
            print(part)
            chunk = part.encode('utf-8')
            if not chunk:
                continue
            # HTTP/1.1 chunked transfer encoding: hex size, CRLF, data, CRLF
            chunk_size = len(chunk)
            self.wfile.write(f"{chunk_size:x}\r\n".encode('utf-8'))
            self.wfile.write(chunk + b"\r\n")
            self.wfile.flush()

        # zero-length chunk terminates the chunked response
        self.wfile.write(b"0\r\n\r\n")
        self.wfile.flush()

httpd = HTTPServer(('', 8000), RequestHandler)
httpd.serve_forever()
--------------------------------------------------------------------------------
/rtx_server_0_5_2025.py:
--------------------------------------------------------------------------------
import subprocess
import shutil
import os


def copy_self_to_another_directory(target_directory):
    current_script = os.path.realpath(__file__)
    target_file = os.path.join(target_directory, os.path.basename(current_script))
    if os.path.exists(target_file):
        os.remove(target_file)
    shutil.copy(current_script, target_file)


# copy & run self in the nvidia virtual env
python_bin = (
    "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\env_nvd_rag\\Scripts\\python.exe"
)
target_directory = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\RAG\\trt-llm-rag-windows-ChatRTX_0.5.0\\ChatRTXUI\\engine"
if os.getcwd() != target_directory:
    copy_self_to_another_directory(target_directory)
    os.chdir(target_directory)
    subprocess.run([python_bin, target_directory + "\\" + os.path.basename(__file__)])
    exit(0)


from http.server import BaseHTTPRequestHandler, HTTPServer
from configuration import Configuration
from backend import Backend, Mode

data_path = os.path.expandvars("%programdata%\\NVIDIA Corporation\\chatrtx")
backend = Backend(model_setup_dir=data_path)

backend.init_model(model_id="mistral_7b_AWQ_int4_chat")
# backend.init_model(model_id="llama2_13b_AWQ_INT4_chat")
# backend.init_model(model_id="chatglm3_6b_AWQ_int4")
# backend.init_model(model_id="gemma_7b_int4")
# backend.init_model(model_id="clip_model")

# dataset_dir = '{directory_to_your_dataset}'
# status = backend.ChatRTX(chatrtx_mode=Mode.RAG, data_dir=dataset_dir)
status = backend.ChatRTX(chatrtx_mode=Mode.AI)


class RequestHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)

        input_string = post_data.decode("utf-8")
        answer_stream = backend.query_stream(query=input_string)

        self.send_response(200)
        self.send_header("Content-type", "text/plain")
        self.send_header("Transfer-Encoding", "chunked")
        self.end_headers()

        for part in answer_stream:
            print(part)
            chunk = part.encode("utf-8")
            if not chunk:
                continue
            # HTTP/1.1 chunked transfer encoding: hex size, CRLF, data, CRLF
            chunk_size = len(chunk)
            self.wfile.write(f"{chunk_size:x}\r\n".encode("utf-8"))
            self.wfile.write(chunk + b"\r\n")
            self.wfile.flush()

        # zero-length chunk terminates the chunked response
        self.wfile.write(b"0\r\n\r\n")
        self.wfile.flush()


httpd = HTTPServer(("", 8000), RequestHandler)
httpd.serve_forever()
--------------------------------------------------------------------------------
/rtx_api_4_24.py:
--------------------------------------------------------------------------------
import requests
import random
import string
import psutil
import json
import os

port = None
fn_index = None

# the server is https because the speech-to-text js microphone needs https
appdata_folder = os.path.dirname(os.getenv('APPDATA')).replace('\\', '/')
cert_path = appdata_folder + "/Local/NVIDIA/ChatRTX/RAG/trt-llm-rag-windows-ChatRTX_0.3/certs/servercert.pem"
key_path = appdata_folder + "/Local/NVIDIA/ChatRTX/RAG/trt-llm-rag-windows-ChatRTX_0.3/certs/serverkey.pem"
ca_bundle = appdata_folder + "/Local/NVIDIA/ChatRTX/env_nvd_rag/Library/ssl/cacert.pem"

def find_chat_with_rtx_port():
    global port
    connections = psutil.net_connections(kind='inet')
    for host in connections:
        try:
            if host.pid:
                process = psutil.Process(host.pid)

                if "chatrtx" in process.exe().lower():
                    test_port = host.laddr.port
                    url = f"https://127.0.0.1:{test_port}/queue/join"

                    response = requests.post(url, data="", timeout=0.1, cert=(cert_path, key_path), verify=ca_bundle)
                    if response.status_code == 422:
                        port = test_port
                        return
        except:
            pass

def join_queue(session_hash, set_fn_index, port, chatdata):
    # fn_indexes are Gradio-generated event indexes from rag/trt/ui/user_interface.py
    python_object = {
        "data": chatdata,
        "event_data": None,
        "fn_index": set_fn_index,
        "session_hash": session_hash
    }
    json_string = json.dumps(python_object)

    url = f"https://127.0.0.1:{port}/queue/join"
    response = requests.post(url, data=json_string, cert=(cert_path, key_path), verify=ca_bundle)
    # print("Join Queue Response:", response)

def listen_for_updates(session_hash, port):
    url = f"https://127.0.0.1:{port}/queue/data?session_hash={session_hash}"

    response = requests.get(url, stream=True, cert=(cert_path, key_path), verify=ca_bundle)
    # print("Listen Response:", response)
    try:
        for line in response.iter_lines():
            if line:
                data = json.loads(line[5:])
                # if data['msg'] == 'process_generating':
                #     print(data['output']['data'][0][0][1])
                if data['msg'] == 'process_completed':
                    return data['output']['data'][0][0][1]
    except Exception as e:
        pass
    return ""

def auto_find_fn_index(port):
    # brute-force the Gradio fn_index of the streamed-completion function
    global fn_index

    print("Searching for llm streamed completion function. Takes about 30 seconds.")
    session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
    chatdata = [[["write a comma", None]], None]
    for i in range(10, 1000):
        join_queue(session_hash, i, port, chatdata)
        res = listen_for_updates(session_hash, port)
        if res:
            fn_index = i
            return
    raise Exception("Failed to find fn_index")

def send_message(message):
    global fn_index

    if not port:
        find_chat_with_rtx_port()
    if not port:
        raise Exception("Failed to find a server port for 'Chat with RTX'. Ensure the server is running.")
    if not fn_index:
        # comment out this call once fn_index is hardcoded below
        auto_find_fn_index(port)
        print("To make initialization instant hardcode:\nfn_index =", fn_index)
        # fn_index =

    session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))

    chatdata = [[[message, None]], None]
    join_queue(session_hash, fn_index, port, chatdata)
    return listen_for_updates(session_hash, port)
--------------------------------------------------------------------------------