├── .gitignore
├── .gitattributes
├── client
    ├── requirements.txt
    ├── openai_completion.py
    ├── langchain_chat_openai.py
    ├── openai_completion_stream.py
    └── langchain_openai.py
├── src
    ├── routes
    │   ├── health_check.rs
    │   ├── mod.rs
    │   ├── completions.rs
    │   └── chat.rs
    ├── triton
    │   ├── mod.rs
    │   ├── telemetry.rs
    │   └── request.rs
    ├── lib.rs
    ├── state.rs
    ├── error.rs
    ├── main.rs
    ├── config.rs
    ├── utils.rs
    ├── startup.rs
    ├── telemetry.rs
    └── history
    │   └── mod.rs
├── images
    ├── demo.gif
    └── trace.png
├── .gitmodules
├── templates
    ├── history_template.liquid
    ├── history_template_llama3.liquid
    ├── history_template_custom_roles.liquid
    └── history_template_baichuan.liquid
├── .dockerignore
├── LICENSE
├── docker-compose.yml
├── Dockerfile
├── Cargo.toml
├── README.md
└── Cargo.lock

/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | .idea/
3 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.liquid eol=lf
2 | 
--------------------------------------------------------------------------------
/client/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | langchain
3 | 
--------------------------------------------------------------------------------
/src/routes/health_check.rs:
--------------------------------------------------------------------------------
1 | pub async fn health_check() {}
2 | 
--------------------------------------------------------------------------------
/images/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/npuichigo/openai_trtllm/HEAD/images/demo.gif
--------------------------------------------------------------------------------
/images/trace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/npuichigo/openai_trtllm/HEAD/images/trace.png
--------------------------------------------------------------------------------
/src/triton/mod.rs:
--------------------------------------------------------------------------------
1 | tonic::include_proto!("inference");
2 | 
3 | pub(crate) mod request;
4 | pub(crate) mod telemetry;
5 | 
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "triton_common"]
2 |     path = triton_common
3 |     url = https://github.com/triton-inference-server/common.git
4 | 
--------------------------------------------------------------------------------
/templates/history_template.liquid:
--------------------------------------------------------------------------------
1 | {% for item in items -%}
2 | {{ item.identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
3 | {% endfor -%}
4 | ASSISTANT:
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod config;
2 | mod error;
3 | pub mod history;
4 | pub mod routes;
5 | pub mod startup;
6 | pub mod state;
7 | pub mod telemetry;
8 | mod utils;
9 | 
10 | mod triton;
11 | 
--------------------------------------------------------------------------------
/templates/history_template_llama3.liquid:
--------------------------------------------------------------------------------
1 | {% for item in items -%}
2 | <|start_header_id|>{{ item.identity }}<|end_header_id|>
3 | {{ item.content }}<|eot_id|>
4 | {% endfor -%}
5 | <|start_header_id|>assistant<|end_header_id|>
--------------------------------------------------------------------------------
/src/routes/mod.rs:
--------------------------------------------------------------------------------
1 | pub(crate) use chat::compat_chat_completions;
2 | pub(crate) use completions::compat_completions;
3 | pub(crate) use health_check::health_check;
4 | 
5 | pub(crate) mod chat;
6 | mod completions;
7 | mod health_check;
8 | 
--------------------------------------------------------------------------------
/client/openai_completion.py:
--------------------------------------------------------------------------------
1 | import pprint
2 | 
3 | from openai import OpenAI
4 | 
5 | client = OpenAI(base_url="http://localhost:3000/v1", api_key="test")
6 | 
7 | result = client.completions.create(
8 |     model="ensemble",
9 |     prompt="Say this is a test",
10 | )
11 | pprint.pprint(result)
12 | 
--------------------------------------------------------------------------------
/src/state.rs:
--------------------------------------------------------------------------------
1 | use crate::history::HistoryBuilder;
2 | use crate::triton::grpc_inference_service_client::GrpcInferenceServiceClient;
3 | use tonic::transport::Channel;
4 | 
5 | #[derive(Clone)]
6 | pub struct AppState {
7 |     pub grpc_client: GrpcInferenceServiceClient<Channel>,
8 |     pub history_builder: HistoryBuilder,
9 | }
10 | 
--------------------------------------------------------------------------------
/templates/history_template_custom_roles.liquid:
--------------------------------------------------------------------------------
1 | {% for item in items -%}
2 | {%- capture identity -%}
3 | {%- case item.identity -%}
4 | {%- when "System", "Tool" -%}
5 | Robot
6 | {%- when "User" -%}
7 | Customer
8 | {%- when "Assistant" -%}
9 | Support
10 | {%- else -%}
11 | {%- endcase -%}
12 | {%- endcapture -%}
13 | 
14 | {{- identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
15 | {% endfor -%}
16 | ASSISTANT:
--------------------------------------------------------------------------------
/templates/history_template_baichuan.liquid:
--------------------------------------------------------------------------------
1 | {% for item in items -%}
2 | {%- capture identity -%}
3 | {%- case item.identity -%}
4 | {%- when "System", "Tool" -%}
5 | System
6 | {%- when "User" -%}
7 | 
8 | {%- when "Assistant" -%}
9 | 
10 | {%- else -%}
11 | {%- endcase -%}
12 | {%- endcapture -%}
13 | 
14 | {{- identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
15 | {% endfor -%}
16 | :
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Ignore IntelliJ IDEA files
2 | .idea
3 | 
4 | # Generated by Cargo
5 | # will have compiled files and executables
6 | target/
7 | 
8 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
9 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
10 | Cargo.lock
11 | 
12 | # These are backup files generated by rustfmt
13 | **/*.rs.bk
14 | 
15 | # MSVC Windows builds of rustc generate these, which store debugging information
16 | *.pdb
17 | 
18 | /models
19 | 
--------------------------------------------------------------------------------
/client/langchain_chat_openai.py:
--------------------------------------------------------------------------------
1 | from langchain.chat_models import ChatOpenAI
2 | from langchain.schema.messages import HumanMessage, SystemMessage
3 | 
4 | 
5 | chat = ChatOpenAI(openai_api_base="http://localhost:3000/v1",
6 |                   openai_api_key="test", model_name="ensemble",
7 |                   max_tokens=100)
8 | 
9 | messages = [
10 |     SystemMessage(content="You're a helpful assistant"),
11 |     HumanMessage(content="What is the purpose of model regularization?"),
12 | ]
13 | 
14 | result = chat.invoke(messages)
15 | print(result.content)
16 | 
--------------------------------------------------------------------------------
/client/openai_completion_stream.py:
--------------------------------------------------------------------------------
1 | from sys import stdout
2 | 
3 | from openai import OpenAI
4 | 
5 | client = OpenAI(base_url="http://localhost:3000/v1", api_key="test")
6 | 
7 | response = client.completions.create(
8 |     model="ensemble",
9 |     prompt="This is a story of a hero who went",
10 |     stream=True,
11 |     max_tokens=50,
12 | )
13 | for event in response:
14 |     if not isinstance(event, dict):
15 |         event = event.model_dump()
16 |     event_text = event["choices"][0]["text"]
17 |     stdout.write(event_text)
18 |     stdout.flush()
19 | 
--------------------------------------------------------------------------------
/src/error.rs:
--------------------------------------------------------------------------------
1 | use axum::{
2 |     http::StatusCode,
3 |     response::{IntoResponse, Response},
4 | };
5 | 
6 | #[derive(Debug)]
7 | pub struct AppError(anyhow::Error);
8 | 
9 | impl IntoResponse for AppError {
10 |     fn into_response(self) -> Response {
11 |         (
12 |             StatusCode::INTERNAL_SERVER_ERROR,
13 |             "An error occurred while trying to fulfill your request.",
14 |         )
15 |             .into_response()
16 |     }
17 | }
18 | 
19 | impl<E> From<E> for AppError
20 | where
21 |     E: Into<anyhow::Error>,
22 | {
23 |     fn from(err: E) -> Self {
24 |         Self(err.into())
25 |     }
26 | }
27 | 
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | use clap::Parser;
2 | use figment::providers::{Env, Serialized};
3 | use figment::Figment;
4 | 
5 | use openai_trtllm::config::Config;
6 | use openai_trtllm::startup;
7 | use openai_trtllm::telemetry;
8 | 
9 | #[tokio::main]
10 | async fn main() -> anyhow::Result<()> {
11 |     let config: Config = Figment::new()
12 |         .merge(Env::prefixed("OPENAI_TRTLLM_"))
13 |         .merge(Serialized::defaults(Config::parse()))
14 |         .extract()
15 |         .unwrap();
16 | 
17 |     telemetry::init_subscriber("openai_trtllm", "info", config.otlp_endpoint.clone())?;
18 | 
19 |     startup::run_server(config).await
20 | }
21 | 
--------------------------------------------------------------------------------
/client/langchain_openai.py:
--------------------------------------------------------------------------------
1 | from langchain.llms import OpenAI
2 | from langchain.prompts import PromptTemplate
3 | from langchain.chains import LLMChain
4 | 
5 | 
6 | template = """
7 | USER: You are a helpful, medical specialist. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. You always answer medical questions based on facts.
8 | ASSISTANT: Ok great ! I am a medical expert!
9 | USER: {question}
10 | ASSISTANT:
11 | """
12 | 
13 | prompt = PromptTemplate(template=template, input_variables=["question"])
14 | llm = OpenAI(openai_api_base="http://localhost:3000/v1", openai_api_key="test",
15 |              model_name="ensemble")
16 | llm_chain = LLMChain(prompt=prompt, llm=llm)
17 | 
18 | question = "What can I do about glenoid cavity injury ?"
19 | 
20 | result = llm_chain.run(question)
21 | print(result)
22 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any person obtaining a copy
2 | of this software and associated documentation files (the "Software"), to deal
3 | in the Software without restriction, including without limitation the rights
4 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
5 | copies of the Software, and to permit persons to whom the Software is
6 | furnished to do so, subject to the following conditions:
7 | 
8 | The above copyright notice and this permission notice shall be included in all
9 | copies or substantial portions of the Software.
10 | 
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17 | SOFTWARE.
18 | 
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | 
3 | services:
4 |   openai_trtllm:
5 |     image: openai_trtllm
6 |     build:
7 |       context: .
8 |       dockerfile: Dockerfile
9 |     command:
10 |       - "--host"
11 |       - "0.0.0.0"
12 |       - "--port"
13 |       - "3000"
14 |       - "--triton-endpoint"
15 |       - "http://tensorrtllm_backend:8001"
16 |     ports:
17 |       - "3000:3000"
18 |     depends_on:
19 |       - tensorrtllm_backend
20 |     restart: on-failure
21 | 
22 |   # Triton backend for TensorRT LLM
23 |   tensorrtllm_backend:
24 |     image: nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3
25 |     command:
26 |       - "tritonserver"
27 |       - "--model-repository=/models"
28 |     volumes:
29 |       - /path/to/model_repository:/models
30 |     ports:
31 |       - "8000:8000"
32 |       - "8001:8001"
33 |       - "8002:8002"
34 |     deploy:
35 |       replicas: 1
36 |       resources:
37 |         reservations:
38 |           devices:
39 |             - driver: nvidia
40 |               count: 1
41 |               capabilities: [ gpu ]
42 |     shm_size: '2g'
43 |     ulimits:
44 |       memlock: -1
45 |       stack: 67108864
46 |     restart: on-failure
47 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rust:1.74.0-bookworm as chef
2 | 
3 | WORKDIR /app
4 | 
5 | RUN apt-get update && apt-get install lld clang protobuf-compiler -y
6 | 
7 | RUN cargo install cargo-chef --locked
8 | 
9 | FROM chef as planner
10 | 
11 | COPY . .
12 | 
13 | # Compute a lock-like file for our project
14 | RUN cargo chef prepare --recipe-path recipe.json
15 | 
16 | FROM chef as builder
17 | 
18 | COPY --from=planner /app/recipe.json recipe.json
19 | 
20 | # Build our project dependencies, not our application!
21 | RUN cargo chef cook --release --recipe-path recipe.json
22 | 
23 | # Up to this point, if our dependency tree stays the same,
24 | # all layers should be cached.
25 | COPY . .
26 | 
27 | # Build our project
28 | RUN cargo build --release --bin openai_trtllm
29 | 
30 | FROM debian:bookworm-slim AS runtime
31 | 
32 | WORKDIR /app
33 | 
34 | RUN apt-get update -y \
35 |     && apt-get install -y --no-install-recommends openssl ca-certificates \
36 |     # Clean up
37 |     && apt-get autoremove -y \
38 |     && apt-get clean -y \
39 |     && rm -rf /var/lib/apt/lists/*
40 | 
41 | COPY --from=builder /app/target/release/openai_trtllm openai_trtllm
42 | 
43 | ENTRYPOINT ["./openai_trtllm"]
44 | 
--------------------------------------------------------------------------------
/src/config.rs:
--------------------------------------------------------------------------------
1 | use clap::Parser;
2 | use serde::{Deserialize, Serialize};
3 | 
4 | #[derive(Parser, Debug, Serialize, Deserialize)]
5 | pub struct Config {
6 |     /// Host to bind to
7 |     #[arg(long, short = 'H', default_value_t = String::from("0.0.0.0"))]
8 |     pub host: String,
9 | 
10 |     /// Port to bind to
11 |     #[arg(long, short, default_value_t = 3000)]
12 |     pub port: usize,
13 | 
14 |     /// Triton gRPC endpoint
15 |     #[arg(long, short, default_value_t = String::from("http://localhost:8001"))]
16 |     pub triton_endpoint: String,
17 | 
18 |     /// Endpoint of OpenTelemetry collector
19 |     #[arg(long, short)]
20 |     #[serde(skip_serializing_if = "Option::is_none")]
21 |     pub otlp_endpoint: Option<String>,
22 | 
23 |     /// Template for converting OpenAI message history to prompt
24 |     #[arg(long)]
25 |     #[serde(skip_serializing_if = "Option::is_none")]
26 |     pub history_template: Option<String>,
27 | 
28 |     /// File containing the history template string
29 |     #[arg(long)]
30 |     #[serde(skip_serializing_if = "Option::is_none")]
31 |     pub history_template_file: Option<String>,
32 | 
33 |     /// Api Key to access the server
34 |     #[arg(long)]
35 |     #[serde(skip_serializing_if = "Option::is_none")]
36 |     pub api_key: Option<String>,
37 | }
38 | 
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "openai_trtllm"
3 | version = "0.2.1"
4 | edition = "2021"
5 | 
6 | [dependencies]
7 | axum = { version = "0.7" }
8 | tokio = { version = "1.33.0", features = ["full"] }
9 | async-stream = "0.3.5"
10 | tonic = "0.10.2"
11 | prost = "0.12.1"
12 | prost-types = "0.12.1"
13 | anyhow = { version = "1.0.75", features = ["backtrace"] }
14 | clap = { version = "4.4.7", features = ["derive"] }
15 | figment = { version = "0.10.12", features = ["env"] }
16 | serde = { version = "1.0.190", features = ["derive"] }
17 | serde_json = "1.0.108"
18 | bytes = "1.5.0"
19 | uuid = { version = "1.6.1", features = ["v4"] }
20 | tracing = { version = "0.1", features = ["log"] }
21 | tracing-subscriber = { version = "0.3", features = ["registry", "env-filter", "json"] }
22 | tracing-opentelemetry = { version = "0.22" }
23 | tower-http = { version = "0.5", features = ["trace"] }
24 | opentelemetry = { version = "0.21.0", features = ["metrics", "logs"] }
25 | opentelemetry_sdk = { version = "0.21.1", features = ["rt-tokio", "logs"] }
26 | opentelemetry-otlp = { version = "0.14.0", features = ["tonic", "metrics", "logs"] }
27 | opentelemetry-semantic-conventions = { version = "0.13.0" }
28 | axum-tracing-opentelemetry = "0.16.0"
29 | liquid = "0.26.4"
30 | 
31 | [build-dependencies]
32 | anyhow = "1.0.75"
33 | tonic-build = "0.10.2"
34 | walkdir = "2.4.0"
35 | 
--------------------------------------------------------------------------------
/src/triton/telemetry.rs:
--------------------------------------------------------------------------------
1 | use axum::http::HeaderMap;
2 | use opentelemetry::global;
3 | use opentelemetry::propagation::Injector;
4 | use tonic::Request;
5 | use tracing::Span;
6 | use tracing_opentelemetry::OpenTelemetrySpanExt;
7 | 
8 | pub struct MetadataMap<'a>(&'a mut tonic::metadata::MetadataMap);
9 | 
10 | impl<'a> Injector for MetadataMap<'a> {
11 |     /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
12 |     fn set(&mut self, key: &str, value: String) {
13 |         if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
14 |             if let Ok(val) = tonic::metadata::MetadataValue::try_from(&value) {
15 |                 self.0.insert(key, val);
16 |             }
17 |         }
18 |     }
19 | }
20 | 
21 | pub(crate) fn propagate_context<T>(request: &mut Request<T>, header: &HeaderMap) {
22 |     let mut metadata_map = MetadataMap(request.metadata_mut());
23 | 
24 |     // Propagate the current opentelemetry context
25 |     let cx = Span::current().context();
26 |     global::get_text_map_propagator(|propagator| propagator.inject_context(&cx, &mut metadata_map));
27 | 
28 |     // Propagate x-request-id header to the request if it exists
29 |     if let Some(x_request_id) = header.get("x-request-id") {
30 |         if let Ok(x_request_id) = x_request_id.to_str() {
31 |             metadata_map.set("x-request-id", x_request_id.to_string());
32 |         }
33 |     }
34 | }
35 | 
--------------------------------------------------------------------------------
/src/utils.rs:
--------------------------------------------------------------------------------
1 | use std::fmt;
2 | use std::marker::PhantomData;
3 | use std::str;
4 | use std::str::Utf8Error;
5 | 
6 | use bytes::{Buf, Bytes};
7 | use serde::{de, Deserialize, Deserializer};
8 | 
9 | pub(crate) fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
10 | where
11 |     D: Deserializer<'de>,
12 | {
13 |     struct StringOrVec(PhantomData<Vec<String>>);
14 | 
15 |     impl<'de> de::Visitor<'de> for StringOrVec {
16 |         type Value = Vec<String>;
17 | 
18 |         fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
19 |             formatter.write_str("string or list of strings")
20 |         }
21 | 
22 |         fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
23 |         where
24 |             E: de::Error,
25 |         {
26 |             Ok(vec![value.to_owned()])
27 |         }
28 | 
29 |         fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
30 |         where
31 |             S: de::SeqAccess<'de>,
32 |         {
33 |             Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
34 |         }
35 |     }
36 | 
37 |     deserializer.deserialize_any(StringOrVec(PhantomData))
38 | }
39 | 
40 | pub(crate) fn deserialize_bytes_tensor(encoded_tensor: Vec<u8>) -> Result<Vec<String>, Utf8Error> {
41 |     let mut bytes = Bytes::from(encoded_tensor);
42 |     let mut strs = Vec::new();
43 |     while bytes.has_remaining() {
44 |         let len = bytes.get_u32_le() as usize;
45 |         if len <= bytes.remaining() {
46 |             let slice = bytes.split_to(len);
47 |             let s = str::from_utf8(&slice)?;
48 |             strs.push(s.to_string());
49 |         }
50 |     }
51 |     Ok(strs)
52 | }
53 | 
--------------------------------------------------------------------------------
/src/startup.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Context;
2 | use axum::routing::{get, post};
3 | use axum::Router;
4 | use axum::middleware::{self, Next};
5 | use axum::http::{Request, StatusCode};
6 | use axum::response::Response;
7 | use axum::body::Body;
8 | use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
9 | 
10 | use crate::config::Config;
11 | use crate::history::HistoryBuilder;
12 | use crate::routes;
13 | use crate::state::AppState;
14 | use crate::triton::grpc_inference_service_client::GrpcInferenceServiceClient;
15 | 
16 | async fn auth_middleware(
17 |     req: Request<Body>,
18 |     next: Next,
19 |     api_key: Option<String>,
20 | ) -> Result<Response, StatusCode> {
21 |     if let Some(ref key) = api_key {
22 |         if let Some(auth_header) = req.headers().get("Authorization") {
23 |             if let Ok(auth_str) = auth_header.to_str() {
24 |                 if auth_str == format!("Bearer {}", key) {
25 |                     return Ok(next.run(req).await);
26 |                 }
27 |             }
28 |         }
29 |         Err(StatusCode::UNAUTHORIZED)
30 |     } else {
31 |         Ok(next.run(req).await)
32 |     }
33 | }
34 | 
35 | pub async fn run_server(config: Config) -> anyhow::Result<()> {
36 |     tracing::info!("Connecting to triton endpoint: {}", config.triton_endpoint);
37 |     let grpc_client = GrpcInferenceServiceClient::connect(config.triton_endpoint)
38 |         .await
39 |         .context("failed to connect triton endpoint")?;
40 | 
41 |     let history_builder =
42 |         HistoryBuilder::new(&config.history_template, &config.history_template_file)?;
43 |     let state = AppState {
44 |         grpc_client,
45 |         history_builder,
46 |     };
47 | 
48 |     let api_key = config.api_key.clone();
49 | 
50 |     let app = Router::new()
51 |         .route("/v1/completions", post(routes::compat_completions))
52 |         .route(
53 |             "/v1/chat/completions",
54 |             post(routes::compat_chat_completions),
55 |         )
56 |         .route("/health_check", get(routes::health_check))
57 |         .with_state(state)
58 |         .layer(OtelAxumLayer::default())
59 |         .layer(middleware::from_fn(move |req, next| {
60 |             auth_middleware(req, next, api_key.clone())
61 |         }));
62 | 
63 |     let address = format!("{}:{}", config.host, config.port);
64 |     tracing::info!("Starting server at {}", address);
65 | 
66 |     let listener = tokio::net::TcpListener::bind(address).await.unwrap();
67 |     axum::serve(listener, app)
68 |         .with_graceful_shutdown(shutdown_signal())
69 |         .await?;
70 | 
71 |     Ok(())
72 | }
73 | 
74 | async fn shutdown_signal() {
75 |     tokio::signal::ctrl_c()
76 |         .await
77 |         .expect("failed to install CTRL+C signal handler");
78 | 
79 |     opentelemetry::global::shutdown_tracer_provider();
80 | }
81 | 
--------------------------------------------------------------------------------
/src/telemetry.rs:
--------------------------------------------------------------------------------
1 | use opentelemetry::trace::TraceError;
2 | use opentelemetry::{global, KeyValue};
3 | use opentelemetry_otlp::WithExportConfig;
4 | use opentelemetry_sdk::propagation::TraceContextPropagator;
5 | use opentelemetry_sdk::trace as sdktrace;
6 | use opentelemetry_sdk::{runtime, Resource};
7 | use tracing_subscriber::util::SubscriberInitExt;
8 | use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer};
9 | 
10 | fn init_tracer(name: &str, otlp_endpoint: &str) -> Result<sdktrace::Tracer, TraceError> {
11 |     opentelemetry_otlp::new_pipeline()
12 |         .tracing()
13 |         .with_exporter(
14 |             opentelemetry_otlp::new_exporter()
15 |                 .tonic()
16 |                 .with_endpoint(otlp_endpoint),
17 |         )
18 |         .with_trace_config(
19 |             sdktrace::config()
20 |                 .with_resource(Resource::new(vec![KeyValue::new(
21 |                     "service.name",
22 |                     name.to_owned(),
23 |                 )]))
24 |                 .with_sampler(sdktrace::Sampler::AlwaysOn),
25 |         )
26 |         .install_batch(runtime::Tokio)
27 | }
28 | 
29 | /// Compose multiple layers into a `tracing`'s subscriber.
30 | ///
31 | /// # Implementation Notes
32 | ///
33 | /// The subscriber is assembled from an env filter, an optional OpenTelemetry export layer
34 | /// and a fmt layer, and is installed as the global default.
35 | pub fn init_subscriber(
36 |     name: &str,
37 |     env_filter: &str,
38 |     otlp_endpoint: Option<String>,
39 | ) -> anyhow::Result<()> {
40 |     global::set_text_map_propagator(TraceContextPropagator::new());
41 | 
42 |     let env_filter = EnvFilter::try_from_default_env()
43 |         .unwrap_or_else(|_| EnvFilter::new(env_filter))
44 |         .add_directive("otel::tracing=trace".parse()?)
45 |         .add_directive("otel=debug".parse()?);
46 | 
47 |     let telemetry_layer = if let Some(otlp_endpoint) = otlp_endpoint {
48 |         let tracer = init_tracer(name, &otlp_endpoint)?;
49 | 
50 |         Some(
51 |             tracing_opentelemetry::layer()
52 |                 .with_error_records_to_exceptions(true)
53 |                 .with_tracer(tracer),
54 |         )
55 |     } else {
56 |         None
57 |     };
58 | 
59 |     let fmt_layer = if cfg!(debug_assertions) {
60 |         tracing_subscriber::fmt::layer()
61 |             .pretty()
62 |             .with_line_number(true)
63 |             .with_thread_names(true)
64 |             .boxed()
65 |     } else {
66 |         tracing_subscriber::fmt::layer()
67 |             .json()
68 |             .flatten_event(true)
69 |             .boxed()
70 |     };
71 | 
72 |     tracing_subscriber::registry()
73 |         .with(env_filter)
74 |         .with(telemetry_layer)
75 |         .with(fmt_layer)
76 |         .init();
77 | 
78 |     Ok(())
79 | }
80 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # openai_trtllm - OpenAI-compatible API for TensorRT-LLM
2 | 
3 | Provides an OpenAI-compatible API for [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM)
4 | and [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/tensorrtllm_backend), which allows you
5 | to integrate with [langchain](https://github.com/langchain-ai/langchain).
6 | 
7 | ## Quick overview
8 | 
9 | ![demo](images/demo.gif)
10 | 
11 | ## Get started
12 | 
13 | ### Prerequisites
14 | 
15 | Make sure you have built your own TensorRT LLM engine following
16 | the [tensorrtllm_backend tutorial](https://github.com/triton-inference-server/tensorrtllm_backend#using-the-tensorrt-llm-backend).
17 | The final model repository should look like
18 | the [official example](https://github.com/triton-inference-server/tensorrtllm_backend/tree/v0.9.0/all_models/inflight_batcher_llm).
19 | 
20 | **Notice: to enable streaming, set `decoupled` to `true` in `triton_model_repo/tensorrt_llm/config.pbtxt`, per the
21 | tutorial.**
22 | 
23 | Remember to clone the submodules as well; they are required to build the project.
24 | 
25 | ```bash
26 | git clone --recursive https://github.com/npuichigo/openai_trtllm.git
27 | ```
28 | 
29 | ### Build locally
30 | 
31 | Make sure you have [Rust](https://www.rust-lang.org/tools/install) installed.
32 | 
33 | ```bash
34 | cargo run --release
35 | ```
36 | 
37 | The executable arguments can be set from environment variables (prefixed with `OPENAI_TRTLLM_`) or on the command line:
38 | 
39 | **Notice: `openai_trtllm` communicates with Triton over gRPC, so `--triton-endpoint` must point to Triton's gRPC port.**
40 | 
41 | ```bash
42 | ./target/release/openai_trtllm --help
43 | Usage: openai_trtllm [OPTIONS]
44 | 
45 | Options:
46 |   -H, --host <HOST>
47 |           Host to bind to [default: 0.0.0.0]
48 |   -p, --port <PORT>
49 |           Port to bind to [default: 3000]
50 |   -t, --triton-endpoint <TRITON_ENDPOINT>
51 |           Triton gRPC endpoint [default: http://localhost:8001]
52 |   -o, --otlp-endpoint <OTLP_ENDPOINT>
53 |           Endpoint of OpenTelemetry collector
54 |       --history-template <HISTORY_TEMPLATE>
55 |           Template for converting OpenAI message history to prompt
56 |       --history-template-file <HISTORY_TEMPLATE_FILE>
57 |           File containing the history template string
58 |       --api-key <API_KEY>
59 |           Api Key to access the server
60 |   -h, --help
61 |           Print help
62 | ```
63 | 
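If you only want to check that the server is wired up to Triton correctly, the bundled `client/openai_completion.py` boils down to the sketch below. The `ensemble` model name and the `test` API key are simply the values used by the bundled examples; substitute whatever your model repository and `--api-key` setting actually expect.

```python
# Minimal smoke test against a locally running openai_trtllm.
# Mirrors client/openai_completion.py; the model name and api_key are
# assumptions that depend on your Triton model repository and server config.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="test")

result = client.completions.create(
    model="ensemble",
    prompt="Say this is a test",
)
print(result.choices[0].text)
```
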
64 | ### Build with Docker
65 | 
66 | Make sure you have [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/)
67 | installed.
68 | 
69 | ```bash
70 | docker compose build openai_trtllm
71 | docker compose up
72 | ```
73 | 
74 | ## Chat template
75 | 
76 | `openai_trtllm` supports custom history templates to convert message history to a prompt for chat models. The template
77 | engine used here is [liquid](https://shopify.github.io/liquid/). Follow the syntax to create your own template.
78 | 
79 | For examples of history templates, see the [templates](templates) folder.
80 | 
81 | Here's an example for llama3:
82 | 
83 | ```
84 | {% for item in items -%}
85 | <|start_header_id|>{{ item.identity }}<|end_header_id|>
86 | {{ item.content }}<|eot_id|>
87 | {% endfor -%}
88 | <|start_header_id|>assistant<|end_header_id|>
89 | ```
90 | 
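When a request hits `/v1/chat/completions`, each message in the OpenAI-style history is exposed to the template as an `item` with an `identity`, an optional `name` and a `content` field, and the rendered string is what gets sent to Triton as the prompt. As a rough sketch (again, the `ensemble` model name and the `test` API key are only the values used by the bundled examples), a chat request like the following is rendered through the configured history template:

```python
# Hypothetical chat request; the server converts the message history into a
# prompt via the history template before calling Triton. The model name and
# api_key below are assumptions taken from the bundled client examples.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="test")

completion = client.chat.completions.create(
    model="ensemble",
    messages=[
        {"role": "system", "content": "You're a helpful assistant"},
        {"role": "user", "content": "What is the purpose of model regularization?"},
    ],
)
print(completion.choices[0].message.content)
```
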
91 | ## LangChain integration
92 | 
93 | Since `openai_trtllm` is compatible with the OpenAI API, you can easily integrate it with LangChain as an alternative to
94 | [`OpenAI`](https://api.python.langchain.com/en/latest/llms/langchain_openai.llms.base.OpenAI.html#langchain_openai.llms.base.OpenAI)
95 | or [`ChatOpenAI`](https://api.python.langchain.com/en/latest/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html#langchain_openai.chat_models.base.ChatOpenAI).
96 | 
97 | Although you can use the
98 | [TensorRT LLM integration](https://api.python.langchain.com/en/latest/llms/langchain_nvidia_trt.llms.TritonTensorRTLLM.html#langchain_nvidia_trt.llms.TritonTensorRTLLM)
99 | published recently, it has no support for chat models yet, not to mention user-defined templates.
100 | 
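The bundled `client/langchain_chat_openai.py` shows the integration end to end; in essence it just points `ChatOpenAI` at the server (the `ensemble` model name and the `test` API key are only the values used by the bundled examples):

```python
# Use LangChain's ChatOpenAI against openai_trtllm instead of api.openai.com.
# Mirrors client/langchain_chat_openai.py; model_name and the API key depend
# on your deployment.
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage, SystemMessage

chat = ChatOpenAI(openai_api_base="http://localhost:3000/v1",
                  openai_api_key="test", model_name="ensemble",
                  max_tokens=100)

messages = [
    SystemMessage(content="You're a helpful assistant"),
    HumanMessage(content="What is the purpose of model regularization?"),
]

print(chat.invoke(messages).content)
```
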
101 | ## Tracing
102 | 
103 | Tracing is supported via the `tracing`, `tracing-opentelemetry` and `opentelemetry-otlp` crates.
104 | 
105 | Here is an example of tracing with Tempo on a k8s cluster:
106 | ![trace](images/trace.png)
107 | 
108 | To test tracing locally, let's say you use the Jaeger backend.
109 | 
110 | ```bash
111 | docker run --rm --name jaeger \
112 |   -p 6831:6831/udp \
113 |   -p 6832:6832/udp \
114 |   -p 5778:5778 \
115 |   -p 16686:16686 \
116 |   -p 4317:4317 \
117 |   -p 4318:4318 \
118 |   -p 14250:14250 \
119 |   -p 14268:14268 \
120 |   -p 14269:14269 \
121 |   -p 9411:9411 \
122 |   jaegertracing/all-in-one:1.51
123 | 
124 | ```
125 | 
126 | To enable tracing, set the `OPENAI_TRTLLM_OTLP_ENDPOINT` environment variable or the `--otlp-endpoint` command line
127 | argument to the endpoint of your OpenTelemetry collector.
128 | 
129 | ```bash
130 | OPENAI_TRTLLM_OTLP_ENDPOINT=http://localhost:4317 cargo run --release
131 | ```
132 | 
133 | ## References
134 | 
135 | - [cria](https://github.com/AmineDiro/cria)
136 | 
--------------------------------------------------------------------------------
/src/triton/request.rs:
--------------------------------------------------------------------------------
1 | #![allow(dead_code)]
2 | 
3 | use crate::triton::model_infer_request::{InferInputTensor, InferRequestedOutputTensor};
4 | 
5 | use super::{InferTensorContents, ModelInferRequest};
6 | 
7 | pub(crate) struct Builder {
8 |     inner: anyhow::Result<ModelInferRequest>,
9 | }
10 | 
11 | impl Builder {
12 |     pub(crate) fn new() -> Self {
13 |         Self::default()
14 |     }
15 | 
16 |     pub(crate) fn build(self) -> anyhow::Result<ModelInferRequest> {
17 |         self.inner
18 |     }
19 | 
20 |     pub(crate) fn model_name<S>(self, model_name: S) -> Self
21 |     where
22 |         S: Into<String>,
23 |     {
24 |         self.and_then(|mut request| {
25 |             request.model_name = model_name.into();
26 |             Ok(request)
27 |         })
28 |     }
29 | 
30 |     fn model_version<S>(self, model_version: S) -> Self
31 |     where
32 |         S: Into<String>,
33 |     {
34 |         self.and_then(|mut request| {
35 |             request.model_version = model_version.into();
36 |             Ok(request)
37 |         })
38 |     }
39 | 
40 |     fn id<S>(self, id: S) -> Self
41 |     where
42 |         S: Into<String>,
43 |     {
44 |         self.and_then(|mut request| {
45 |             request.id = id.into();
46 |             Ok(request)
47 |         })
48 |     }
49 | 
50 |     pub(crate) fn input<S, V>(self, name: S, shape: V, data: InferTensorData) -> Self
51 |     where
52 |         S: Into<String>,
53 |         V: Into<Vec<i64>>,
54 |     {
55 |         self.and_then(|mut request| {
56 |             request.inputs.push(InferInputTensor {
57 |                 name: name.into(),
58 |                 shape: shape.into(),
59 |                 datatype: data.as_ref().into(),
60 |                 contents: Some(data.into()),
61 |                 ..Default::default()
62 |             });
63 |             Ok(request)
64 |         })
65 |     }
66 | 
67 |     pub(crate) fn output<S>(self, name: S) -> Self
68 |     where
69 |         S: Into<String>,
70 |     {
71 |         self.and_then(|mut request| {
72 |             request.outputs.push(InferRequestedOutputTensor {
73 |                 name: name.into(),
74 |                 ..Default::default()
75 |             });
76 |             Ok(request)
77 |         })
78 |     }
79 | 
80 |     fn and_then<F>(self, f: F) -> Self
81 |     where
82 |         F: FnOnce(ModelInferRequest) -> anyhow::Result<ModelInferRequest>,
83 |     {
84 |         Self {
85 |             inner: self.inner.and_then(f),
86 |         }
87 |     }
88 | }
89 | 
90 | impl Default for Builder {
91 |     fn default() -> Self {
92 |         Self {
93 |             inner: Ok(ModelInferRequest::default()),
94 |         }
95 |     }
96 | }
97 | 
98 | pub(crate) enum InferTensorData {
99 |     Bool(Vec<bool>),
100 |     Int32(Vec<i32>),
101 |     Int64(Vec<i64>),
102 |     UInt32(Vec<u32>),
103 |     UInt64(Vec<u64>),
104 |     FP32(Vec<f32>),
105 |     FP64(Vec<f64>),
106 |     Bytes(Vec<Vec<u8>>),
107 | }
108 | 
109 | /// View `InferTensorData` as triton datatype
110 | impl AsRef<str> for InferTensorData {
111 |     fn as_ref(&self) -> &str {
112 |         match self {
113 |             InferTensorData::Bool(_) => "BOOL",
114 |             InferTensorData::Int32(_) => "INT32",
115 |             InferTensorData::Int64(_) => "INT64",
116 |             InferTensorData::UInt32(_) => "UINT32",
117 |             InferTensorData::UInt64(_) => "UINT64",
118 |             InferTensorData::FP32(_) => "FP32",
119 |             InferTensorData::FP64(_) => "FP64",
120 |             InferTensorData::Bytes(_) => "BYTES",
121 |         }
122 |     }
123 | }
124 | 
125 | impl From<InferTensorData> for InferTensorContents {
126 |     fn from(data: InferTensorData) -> Self {
127 |         match data {
128 |             InferTensorData::Bool(data) => InferTensorContents {
129 |                 bool_contents: data,
130 |                 ..Default::default()
131 |             },
132 |             InferTensorData::Int32(data) => InferTensorContents {
133 |                 int_contents: data,
134 |                 ..Default::default()
135 |             },
136 |             InferTensorData::Int64(data) => InferTensorContents {
137 |                 int64_contents: data,
138 |                 ..Default::default()
139 |             },
140 |             InferTensorData::UInt32(data) => InferTensorContents {
141 |                 uint_contents: data,
142 |                 ..Default::default()
143 |             },
144 |             InferTensorData::UInt64(data) => InferTensorContents {
145 |                 uint64_contents: data,
146 |                 ..Default::default()
147 |             },
148 |             InferTensorData::FP32(data) => InferTensorContents {
149 |                 fp32_contents: data,
150 |                 ..Default::default()
151 |             },
152 |             InferTensorData::FP64(data) => InferTensorContents {
153 |                 fp64_contents: data,
154 |                 ..Default::default()
155 |             },
156 |             InferTensorData::Bytes(data) => InferTensorContents {
157 |                 bytes_contents: data,
158 |                 ..Default::default()
159 |             },
160 |         }
161 |     }
162 | }
163 | 
--------------------------------------------------------------------------------
/src/history/mod.rs:
--------------------------------------------------------------------------------
1 | use crate::routes::chat::ChatCompletionMessageParams;
2 | use anyhow::bail;
3 | use liquid::{ParserBuilder, Template};
4 | use serde::Serialize;
5 | use std::fs::File;
6 | use std::io::Read;
7 | use std::sync::Arc;
8 | 
9 | const DEFAULT_TEMPLATE: &str = "{% for item in items %}\
10 | {{ item.identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
11 | {% endfor %}\
12 | ASSISTANT:";
13 | 
14 | #[derive(Clone)]
15 | pub struct HistoryBuilder {
16 |     history_template: Arc