├── start_server.bat
├── test_chat.py
├── readme.md
├── test.py
├── rtx_api_july_2024.py
├── rtx_api_2_11.py
├── rtx_api_3_5.py
├── rtx_server_july_2024.py
├── rtx_server_0_5_2025.py
└── rtx_api_4_24.py
/start_server.bat:
--------------------------------------------------------------------------------
1 | PowerShell -Command "Start-Process python3 rtx_server_0_5_2025.py -Verb RunAs"
2 |
--------------------------------------------------------------------------------
/test_chat.py:
--------------------------------------------------------------------------------
1 | import rtx_api_july_2024 as rtx_api
2 |
3 | while True:
4 | user_input = input("$ ")
5 | for data in rtx_api.send_message_streaming(user_input):
6 | print(data)
7 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Python API for Chat With RTX
2 |
3 | ### Usage
4 |
5 | `.\start_server.bat`
6 |
7 | ```python
8 | import rtx_api_july_2024 as rtx_api
9 |
10 | response = rtx_api.send_message("write fire emoji")
11 | print(response)
12 | ```
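
Responses can also be streamed as they are generated; `send_message_streaming` is what `test_chat.py` uses. A minimal sketch:

```python
import rtx_api_july_2024 as rtx_api

# prints each line of the answer as the server streams it back
for chunk in rtx_api.send_message_streaming("write fire emoji"):
    print(chunk)
```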
13 |
14 |
15 | ### Speed
16 | Chat With RTX builds int4 (W4A16 AWQ) TensorRT engines for its LLMs.
17 |
18 | | Model | Throughput on RTX 4090 |
19 | |-|-|
20 | | Mistral | 457 char/sec |
21 | | Llama2 | 315 char/sec |
22 | | ChatGLM3 | 385 char/sec |
23 | | Gemma | 407 char/sec |
24 |
25 |
26 |
27 |
28 |
29 | ```
30 | Update History of Chat With RTX
31 | 3.2024 Removed YouTube video transcript fetch
32 | 4.2024 Added Whisper speech-to-text model
33 | 7.2024 Electron app UI
34 | ```
35 | LICENSE: CC0
36 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import rtx_api_july_2024 as rtx_api
2 | from time import time_ns
3 |
4 | if __name__ == '__main__':
5 | current_time_ns = time_ns()
6 | start_sec = current_time_ns / 1_000_000_000
7 |
8 | out = ""
9 | for i in range(10):
10 | print('iteration', i)
11 | tmp = rtx_api.send_message("Write essay on: War, Famine, and Pestilence")
12 | print(tmp)
13 | out += tmp
14 |
15 | current_time_ns = time_ns()
16 | end_sec = current_time_ns / 1_000_000_000
17 |
18 | took = end_sec - start_sec
19 | char_per_second = len(out) / took
20 | print("char/sec:", int(char_per_second))
21 |
22 | # print(out + "\n")
23 |     # out = rtx_api.send_message("Write single sentence summary of: " + out)
24 | # print("SUMMARY: " + out)
25 |
--------------------------------------------------------------------------------
/rtx_api_july_2024.py:
--------------------------------------------------------------------------------
1 | from http.client import HTTPConnection
2 |
3 | def send_message_streaming(message):
4 | connection = HTTPConnection('localhost', 8000)
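    # build the POST by hand (plain-text body) so the chunked response can be streamed back line by line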
5 | connection.putrequest('POST', '/', skip_host=True)
6 | connection.putheader('Content-Type', 'text/plain')
7 | connection._http_vsn = 11
8 | connection._http_vsn_str = 'HTTP/1.1'
9 | encoded_message = message.encode('utf-8')
10 | connection.putheader('Content-Length', str(len(encoded_message)))
11 | connection.endheaders()
12 | connection.send(encoded_message)
13 | response = connection.getresponse()
14 |
15 | if response.status == 200:
16 | while True:
17 | chunk = response.readline().decode('utf-8')
18 | if not chunk:
19 | return
20 | yield chunk.strip()
21 | else:
22 | raise Exception(f"Error: Server responded with status {response.status}")
23 |
24 | def send_message(message):
25 | response = ""
26 | for data in send_message_streaming(message):
27 | response += data + '\n'
28 | return response
29 |
--------------------------------------------------------------------------------
/rtx_api_2_11.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import random
3 | import string
4 | import json
5 |
6 | def join_queue(session_hash, fn_index, port, chatdata):
7 | python_object = {
8 | "data": chatdata,
9 | "event_data": None,
10 | "fn_index": fn_index,
11 | "trigger_id": 46,
12 | "session_hash": session_hash
13 | }
14 | json_string = json.dumps(python_object)
15 |
16 | url = f"http://127.0.0.1:{port}/queue/join?__theme=dark"
17 |
18 | response = requests.post(url, data=json_string)
19 | # print("Join Queue Response:", response.json())
20 |
21 | def listen_for_updates(session_hash, port):
22 | url = f"http://127.0.0.1:{port}/queue/data?session_hash={session_hash}"
23 |
24 | response = requests.get(url, stream=True)
25 | for line in response.iter_lines():
26 | if line:
27 | try:
28 | data = json.loads(line[5:])
29 | # if data['msg'] == 'process_generating':
30 | # print(data['output']['data'][0][0][1])
31 | if data['msg'] == 'process_completed':
32 | return data['output']['data'][0][0][1]
33 | except Exception as e:
34 | pass
35 | return ""
36 |
37 | def send_message(message, port):
38 | session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
39 |
40 |     # fn_index values are Gradio-generated event indexes from rag/trt/ui/user_interface.py
41 |
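    # replay the UI's setup events (fn_index 36-39) before firing the actual chat event (fn_index 40)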
42 | join_queue(session_hash, 36, port, [])
43 | listen_for_updates(session_hash, port)
44 |
45 | join_queue(session_hash, 37, port, [])
46 | listen_for_updates(session_hash, port)
47 |
48 | chatdata = ["", [], "AI model default", None]
49 | join_queue(session_hash, 38, port, chatdata)
50 | listen_for_updates(session_hash, port)
51 |
52 | chatdata = ["", []]
53 | join_queue(session_hash, 39, port, chatdata)
54 | listen_for_updates(session_hash, port)
55 |
56 |     # add chat history here (the list of [user, assistant] pairs below)
57 | chatdata = [[[message, None]], None]
58 | join_queue(session_hash, 40, port, chatdata)
59 | return listen_for_updates(session_hash, port)
60 |
--------------------------------------------------------------------------------
/rtx_api_3_5.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import random
3 | import string
4 | import psutil
5 | import json
6 |
7 | port = None
8 |
9 | def find_chat_with_rtx_port():
10 | global port
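    # scan local sockets for the ChatWithRTX process and probe each of its ports;
    # a Gradio /queue/join endpoint answers an empty POST with HTTP 422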
11 | connections = psutil.net_connections(kind='inet')
12 | for host in connections:
13 | try:
14 | if host.pid:
15 | process = psutil.Process(host.pid)
16 | if "ChatWithRTX" in process.exe():
17 | test_port = host.laddr.port
18 | url = f"http://127.0.0.1:{test_port}/queue/join"
19 | response = requests.post(url, data="", timeout=0.05)
20 | if response.status_code == 422:
21 | port = test_port
22 | return
23 | except:
24 | pass
25 |
26 | def join_queue(session_hash, fn_index, port, chatdata):
27 |     # fn_index values are Gradio-generated event indexes from rag/trt/ui/user_interface.py
28 | python_object = {
29 | "data": chatdata,
30 | "event_data": None,
31 | "fn_index": fn_index,
32 | "session_hash": session_hash
33 | }
34 | json_string = json.dumps(python_object)
35 |
36 | url = f"http://127.0.0.1:{port}/queue/join"
37 | response = requests.post(url, data=json_string)
38 | # print("Join Queue Response:", response.json())
39 |
40 | def listen_for_updates(session_hash, port):
41 | url = f"http://127.0.0.1:{port}/queue/data?session_hash={session_hash}"
42 |
43 | response = requests.get(url, stream=True)
44 | for line in response.iter_lines():
45 | if line:
46 | try:
47 | data = json.loads(line[5:])
48 | # if data['msg'] == 'process_generating':
49 | # print(data['output']['data'][0][0][1])
50 | if data['msg'] == 'process_completed':
51 | return data['output']['data'][0][0][1]
52 | except Exception as e:
53 | pass
54 | return ""
55 |
56 | def send_message(message):
57 | if not port:
58 | find_chat_with_rtx_port()
59 | if not port:
60 | raise Exception("Failed to find a server port for 'Chat with RTX'. Ensure the server is running.")
61 |
62 | session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
63 |
64 |     # add chat history here (the list of [user, assistant] pairs below)
65 | chatdata = [[[message, None]], None]
66 | join_queue(session_hash, 34, port, chatdata)
67 | return listen_for_updates(session_hash, port)
68 |
--------------------------------------------------------------------------------
/rtx_server_july_2024.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import shutil
3 | import os
4 |
5 | def copy_self_to_another_directory(target_directory):
6 | current_script = os.path.realpath(__file__)
7 | target_file = os.path.join(target_directory, os.path.basename(current_script))
8 | if os.path.exists(target_file):
9 | os.remove(target_file)
10 | shutil.copy(current_script, target_file)
11 |
12 | # copy & run self in the nvidia virtual env
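# the `backend` and `configuration` modules imported below only exist inside the ChatRTX install,
# so the script copies itself next to them and re-executes with NVIDIA's bundled interpreter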
13 | python_bin = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\env_nvd_rag\\Scripts\\python.exe"
14 | target_directory = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\RAG\\trt-llm-rag-windows-ChatRTX_0.4.0\\ChatRTXUI\\engine"
15 | if os.getcwd() != target_directory:
16 | copy_self_to_another_directory(target_directory)
17 | os.chdir(target_directory)
18 | subprocess.run([python_bin, target_directory + "\\" + os.path.basename(__file__)])
19 | exit(0)
20 |
21 |
22 |
23 |
24 | from http.server import BaseHTTPRequestHandler, HTTPServer
25 | from configuration import Configuration
26 | from backend import Backend, Mode
27 |
28 | data_path = os.path.expandvars("%programdata%\\NVIDIA Corporation\\chatrtx")
29 | backend = Backend(model_setup_dir=data_path)
30 |
31 | backend.init_model(model_id="mistral_7b_AWQ_int4_chat")
32 | # backend.init_model(model_id="llama2_13b_AWQ_INT4_chat")
33 | # backend.init_model(model_id="chatglm3_6b_AWQ_int4")
34 | # backend.init_model(model_id="gemma_7b_int4")
35 | # backend.init_model(model_id="clip_model")
36 |
37 | status = backend.ChatRTX(chatrtx_mode=Mode.AI)
38 | # status = backend.ChatRTX(chatrtx_mode=Mode.RAG, data_dir=dataset_dir)
39 |
40 | class RequestHandler(BaseHTTPRequestHandler):
41 | def do_POST(self):
42 | content_length = int(self.headers['Content-Length'])
43 | post_data = self.rfile.read(content_length)
44 |
45 | input_string = post_data.decode('utf-8')
46 | answer_stream = backend.query_stream(query=input_string)
47 |
48 | self.send_response(200)
49 | self.send_header('Content-type', 'text/plain')
50 | self.send_header('Transfer-Encoding', 'chunked')
51 | self.end_headers()
52 |
53 |         for part in answer_stream:
54 |             print(part)
55 |             chunk = part.encode('utf-8')
56 |             if not chunk:
57 |                 continue
58 |             # HTTP/1.1 chunked transfer encoding: hex size, CRLF, data, CRLF
59 |             chunk_size = len(chunk)
60 |             self.wfile.write(f"{chunk_size:x}\r\n".encode('utf-8'))
61 |             self.wfile.write(chunk + b"\r\n")
62 |             self.wfile.flush()
63 |         self.wfile.write(b"0\r\n\r\n")  # zero-length chunk terminates the chunked response
64 |
65 | httpd = HTTPServer(('', 8000), RequestHandler)
66 | httpd.serve_forever()
67 |
--------------------------------------------------------------------------------
/rtx_server_0_5_2025.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import shutil
3 | import os
4 |
5 |
6 | def copy_self_to_another_directory(target_directory):
7 | current_script = os.path.realpath(__file__)
8 | target_file = os.path.join(target_directory, os.path.basename(current_script))
9 | if os.path.exists(target_file):
10 | os.remove(target_file)
11 | shutil.copy(current_script, target_file)
12 |
13 |
14 | # copy & run self in the nvidia virtual env
15 | python_bin = (
16 | "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\env_nvd_rag\\Scripts\\python.exe"
17 | )
18 | target_directory = "C:\\Program Files\\NVIDIA Corporation\\ChatRTX\\RAG\\trt-llm-rag-windows-ChatRTX_0.5.0\\ChatRTXUI\\engine"
19 | if os.getcwd() != target_directory:
20 | copy_self_to_another_directory(target_directory)
21 | os.chdir(target_directory)
22 | subprocess.run([python_bin, target_directory + "\\" + os.path.basename(__file__)])
23 | exit(0)
24 |
25 |
26 | from http.server import BaseHTTPRequestHandler, HTTPServer
27 | from configuration import Configuration
28 | from backend import Backend, Mode
29 |
30 | data_path = os.path.expandvars("%programdata%\\NVIDIA Corporation\\chatrtx")
31 | backend = Backend(model_setup_dir=data_path)
32 |
33 | backend.init_model(model_id="mistral_7b_AWQ_int4_chat")
34 | # backend.init_model(model_id="llama2_13b_AWQ_INT4_chat")
35 | # backend.init_model(model_id="chatglm3_6b_AWQ_int4")
36 | # backend.init_model(model_id="gemma_7b_int4")
37 | # backend.init_model(model_id="clip_model")
38 |
39 | # dataset_dir = '{directory_to_your_dataset}'
40 | # status = backend.ChatRTX(chatrtx_mode=Mode.RAG, data_dir=dataset_dir)
41 | status = backend.ChatRTX(chatrtx_mode=Mode.AI)
42 |
43 |
44 | class RequestHandler(BaseHTTPRequestHandler):
45 | def do_POST(self):
46 | content_length = int(self.headers["Content-Length"])
47 | post_data = self.rfile.read(content_length)
48 |
49 | input_string = post_data.decode("utf-8")
50 | answer_stream = backend.query_stream(query=input_string)
51 |
52 | self.send_response(200)
53 | self.send_header("Content-type", "text/plain")
54 | self.send_header("Transfer-Encoding", "chunked")
55 | self.end_headers()
56 |
57 |         for part in answer_stream:
58 |             print(part)
59 |             chunk = part.encode("utf-8")
60 |             if not chunk:
61 |                 continue
62 |             # HTTP/1.1 chunked transfer encoding: hex size, CRLF, data, CRLF
63 |             chunk_size = len(chunk)
64 |             self.wfile.write(f"{chunk_size:x}\r\n".encode("utf-8"))
65 |             self.wfile.write(chunk + b"\r\n")
66 |             self.wfile.flush()
67 |         self.wfile.write(b"0\r\n\r\n")  # zero-length chunk terminates the chunked response
68 |
69 |
70 | httpd = HTTPServer(("", 8000), RequestHandler)
71 | httpd.serve_forever()
72 |
--------------------------------------------------------------------------------
/rtx_api_4_24.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import random
3 | import string
4 | import psutil
5 | import json
6 | import os
7 |
8 | port = None
9 | fn_index = None
10 |
11 | # the server uses HTTPS because the speech-to-text JS microphone requires a secure origin
12 | appdata_folder = os.path.dirname(os.getenv('APPDATA')).replace('\\', '/')
13 | cert_path = appdata_folder + "/Local/NVIDIA/ChatRTX/RAG/trt-llm-rag-windows-ChatRTX_0.3/certs/servercert.pem"
14 | key_path = appdata_folder + "/Local/NVIDIA/ChatRTX/RAG/trt-llm-rag-windows-ChatRTX_0.3/certs/serverkey.pem"
15 | ca_bundle = appdata_folder + "/Local/NVIDIA/ChatRTX/env_nvd_rag/Library/ssl/cacert.pem"
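# the requests calls below present this cert/key pair and verify the server against ChatRTX's bundled CA file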
16 |
17 | def find_chat_with_rtx_port():
18 | global port
19 | connections = psutil.net_connections(kind='inet')
20 | for host in connections:
21 | try:
22 | if host.pid:
23 | process = psutil.Process(host.pid)
24 |
25 | if "chatrtx" in process.exe().lower():
26 | test_port = host.laddr.port
27 | url = f"https://127.0.0.1:{test_port}/queue/join"
28 |
29 | response = requests.post(url, data="", timeout=0.1, cert=(cert_path, key_path), verify=ca_bundle)
30 | if response.status_code == 422:
31 | port = test_port
32 | return
33 | except:
34 | pass
35 |
36 | def join_queue(session_hash, set_fn_index, port, chatdata):
37 |     # fn_index values are Gradio-generated event indexes from rag/trt/ui/user_interface.py
38 | python_object = {
39 | "data": chatdata,
40 | "event_data": None,
41 | "fn_index": set_fn_index,
42 | "session_hash": session_hash
43 | }
44 | json_string = json.dumps(python_object)
45 |
46 | url = f"https://127.0.0.1:{port}/queue/join"
47 | response = requests.post(url, data=json_string, cert=(cert_path, key_path), verify=ca_bundle)
48 | # print("Join Queue Response:", response)
49 |
50 | def listen_for_updates(session_hash, port):
51 | url = f"https://127.0.0.1:{port}/queue/data?session_hash={session_hash}"
52 |
53 | response = requests.get(url, stream=True, cert=(cert_path, key_path), verify=ca_bundle)
54 | # print("Listen Response:", response)
55 | try:
56 | for line in response.iter_lines():
57 | if line:
58 | data = json.loads(line[5:])
59 | # if data['msg'] == 'process_generating':
60 | # print(data['output']['data'][0][0][1])
61 | if data['msg'] == 'process_completed':
62 | return data['output']['data'][0][0][1]
63 | except Exception as e:
64 | pass
65 | return ""
66 |
67 | def auto_find_fn_index(port):
68 | global fn_index
69 |
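    # brute-force search: try fn_index values in order until one returns a completed response to a trivial prompt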
70 | print("Searching for llm streamed completion function. Takes about 30 seconds.")
71 | session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
72 | chatdata = [[["write a comma", None]], None]
73 | for i in range(10, 1000):
74 | join_queue(session_hash, i, port, chatdata)
75 | res = listen_for_updates(session_hash, port)
76 | if res:
77 | fn_index = i
78 | return
79 | raise Exception("Failed to find fn_index")
80 |
81 | def send_message(message):
82 | global fn_index
83 |
84 | if not port:
85 | find_chat_with_rtx_port()
86 | if not port:
87 | raise Exception("Failed to find a server port for 'Chat with RTX'. Ensure the server is running.")
88 | if not fn_index:
89 |         # comment out this call once fn_index has been hardcoded below
90 |         auto_find_fn_index(port)
91 |         print("To make initialization instant, hardcode:\nfn_index =", fn_index)
92 |         # fn_index =
93 |
94 | session_hash = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
95 |
96 | chatdata = [[[message, None]], None]
97 | join_queue(session_hash, fn_index, port, chatdata)
98 | return listen_for_updates(session_hash, port)
99 |
--------------------------------------------------------------------------------