├── .gitignore ├── Cargo.toml ├── .circleci └── config.yml ├── src ├── main.rs ├── http.rs ├── inference.rs └── args.rs ├── README.md └── Dockerfile.gpu /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /target 3 | **/*.rs.bk 4 | **/.*.swp 5 | 6 | version.json 7 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ds-srv" 3 | version = "0.7.4" 4 | authors = ["Alexandre Lissy "] 5 | 6 | [features] 7 | default = [] 8 | dump_debug_stream = [] 9 | 10 | [dependencies] 11 | deepspeech = "0.7.0" 12 | audrey = "0.2" 13 | clap = "2.31.2" 14 | log = "0.4.1" 15 | simplelog = "0.5.2" 16 | hyper = "0.12.1" 17 | futures = "0.1.21" 18 | bytes = "0.4.8" 19 | byte-slice-cast = "0.2.0" 20 | serde = "1.0.66" 21 | serde_derive = "1.0.66" 22 | serde_json = "1.0.19" 23 | mkstemp-rs = "1.0.0" 24 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | build-job: 5 | docker: 6 | - image: cimg/base:2020.01 7 | steps: 8 | - checkout 9 | - setup_remote_docker 10 | - run: 11 | name: Create a version.json 12 | command: | 13 | # create a version.json per https://github.com/mozilla-services/Dockerflow/blob/master/docs/version_object.md 14 | printf '{"commit":"%s","version":"%s","source":"https://github.com/%s/%s","build":"%s"}\n' \ 15 | "$CIRCLE_SHA1" \ 16 | "$CIRCLE_TAG" \ 17 | "$CIRCLE_PROJECT_USERNAME" \ 18 | "$CIRCLE_PROJECT_REPONAME" \ 19 | "$CIRCLE_BUILD_URL" > version.json 20 | - run: 21 | name: Build Docker image 22 | command: docker build -f Dockerfile.gpu -t app:build . 23 | - run: 24 | name: Push Docker image to Dockerhub 25 | command: | 26 | echo $DOCKER_PASS | docker login -u $DOCKER_USER --password-stdin 27 | 28 | if [ "${CIRCLE_BRANCH}" == "master" ]; then 29 | docker tag app:build ${DOCKERHUB_REPO}:latest 30 | docker push ${DOCKERHUB_REPO}:latest 31 | elif [ ! -z "${CIRCLE_TAG}" ]; then 32 | docker tag app:build "${DOCKERHUB_REPO}:${CIRCLE_TAG}" 33 | docker push "${DOCKERHUB_REPO}:${CIRCLE_TAG}" 34 | fi 35 | 36 | workflows: 37 | version: 2 38 | build-workflow: 39 | jobs: 40 | - build-job: 41 | filters: 42 | tags: 43 | only: /.*/ 44 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate log; 3 | #[macro_use] 4 | extern crate serde_derive; 5 | 6 | extern crate simplelog; 7 | 8 | use std::sync::mpsc::channel; 9 | use std::thread; 10 | 11 | mod args; 12 | use args::ArgsParser; 13 | 14 | mod http; 15 | use http::th_http_listener; 16 | 17 | mod inference; 18 | use inference::th_inference; 19 | 20 | fn main() { 21 | let rc = ArgsParser::from_cli(); 22 | 23 | let log_level = rc.verbosity_level.into(); 24 | let _ = simplelog::TermLogger::init(log_level, simplelog::Config::default()); 25 | 26 | debug!("Parsed all CLI args: {:?}", rc); 27 | 28 | let (tx_audio, rx_audio) = channel(); 29 | 30 | let mut threads = Vec::new(); 31 | let rc_inference = rc.clone(); 32 | let thread_inference = thread::Builder::new() 33 | .name("InferenceService".to_string()) 34 | .spawn(move || { 35 | th_inference( 36 | rc_inference.model, 37 | rc_inference.scorer, 38 | rx_audio, 39 | rc_inference.dump_dir, 40 | rc_inference.warmup_dir, 41 | rc_inference.warmup_cycles, 42 | ); 43 | }); 44 | threads.push(thread_inference); 45 | 46 | let rc_http = rc.clone(); 47 | let thread_http = thread::Builder::new() 48 | .name("HttpService".to_string()) 49 | .spawn(move || { 50 | th_http_listener(rc_http.http_ip, rc_http.http_port, tx_audio); 51 | }); 52 | threads.push(thread_http); 53 | 54 | println!("Started all thread."); 55 | 56 | for hdl in threads { 57 | if hdl.is_ok() { 58 | hdl.unwrap().join().unwrap(); 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Build 2 | ===== 3 | - Download compatible version of `native_client.tar.xz` (check `deepspeech-rs`) 4 | - `LB_LIBRARY_PATH=... LIBRARY_PATH=... cargo build` with both path pointing to the extracted `native_client.tar.xz` 5 | 6 | Run 7 | === 8 | - Download compatible DeepSpeech model and extract 9 | - 10 | ``` 11 | $ LD_LIBRARY_PATH=...: ./target/debug/ds-srv --model models/output_graph.pbmm --lm models/lm.binary --trie models/trie -vvvvv 12 | 05:16:57 [DEBUG] ds_srv: Parsed all CLI args: RuntimeConfig { http_ip: V6(::), http_port: 8080, dump_dir: "/tmp", warmup_dir: "", warmup_cycles: 10, model: "models/output_graph.pbmm", lm: "models/lm.binary", trie: "models/trie", verbosity_level: DEBUG } 13 | Started all thread. 14 | 05:16:57 [INFO] Inference thread started 15 | TensorFlow: v1.11.0-rc2-4-g77b7b17 16 | 05:16:57 [INFO] Building server http://[::]:8080 17 | DeepSpeech: v0.2.1-alpha.1-0-gae2cfe0 18 | 05:16:57 [INFO] Listening on http://[::]:8080 19 | 2018-09-27 07:16:57.157376: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA 20 | 05:16:57 [DEBUG] tokio_reactor::background: starting background reactor 21 | 05:17:02 [INFO] Model ready and waiting for data to infer ... 22 | ``` 23 | 24 | Test 25 | ==== 26 | 27 | Using `4507-16021-0012.wav` from DeepSpeech's release: 28 | 29 | ``` 30 | $ curl -v -H 'Content-Type: application/octet-stream' --data-binary @"./audio/4507-16021-0012.wav" http://127.0.0.1:8080 31 | * Trying 127.0.0.1... 32 | * TCP_NODELAY set 33 | * Connected to 127.0.0.1 (127.0.0.1) port 8080 (#0) 34 | > POST / HTTP/1.1 35 | > Host: 127.0.0.1:8080 36 | > User-Agent: curl/7.58.0 37 | > Accept: */* 38 | > Content-Type: application/octet-stream 39 | > Content-Length: 87564 40 | > Expect: 100-continue 41 | > 42 | < HTTP/1.1 100 Continue 43 | * We are completely uploaded and fine 44 | < HTTP/1.1 200 OK 45 | < content-type: application/json 46 | < content-length: 84 47 | < date: Thu, 27 Sep 2018 05:12:36 GMT 48 | < 49 | * Connection #0 to host 127.0.0.1 left intact 50 | {"status":"ok","data":[{"text":"why should one hall on the way ","confidence":1.0}]} 51 | ``` 52 | -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-runtime-ubuntu18.04 2 | 3 | ARG DEEPSPEECH_VERSION=0.7.4 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | clang-5.0 \ 9 | sudo \ 10 | curl 11 | 12 | RUN useradd -c 'ds-srv' -m -d /home/ds -s /bin/bash ds 13 | 14 | ENV CUDA_ROOT /usr/local/cuda-10.0/ 15 | ENV HOME /home/ds 16 | ENV DS_VER $DEEPSPEECH_VERSION 17 | ENV LD_LIBRARY_PATH $HOME/lib/:$CUDA_ROOT/lib64/:$LD_LIBRARY_PATH 18 | ENV LIBRARY_PATH $LD_LIBRARY_PATH 19 | ENV PATH $HOME/.cargo/bin/:$HOME/bin/:$PATH 20 | 21 | RUN mkdir /app && chown ds:ds /app 22 | 23 | COPY --chown=ds:ds version.json /app/version.json 24 | 25 | # required for ldconfig call to fix libnvidia-ml.so issue 26 | # Workaround libnvidia-ml.so: https://github.com/NVIDIA/nvidia-docker/issues/854#issuecomment-451464721 27 | RUN echo "ds ALL=(root) NOPASSWD: /sbin/ldconfig" > /etc/sudoers.d/ds && \ 28 | chmod 0440 /etc/sudoers.d/ds 29 | 30 | USER ds 31 | 32 | EXPOSE 8080 33 | 34 | WORKDIR /home/ds 35 | 36 | RUN mkdir -p ${HOME}/lib/ ${HOME}/bin/ ${HOME}/data/models/ ${HOME}/src/ds-srv/ 37 | 38 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable 39 | 40 | RUN curl https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.v${DS_VER}.gpu/artifacts/public/native_client.tar.xz -sSL | xz -d | tar -C ${HOME}/lib/ -xf - 41 | 42 | RUN curl https://github.com/mozilla/DeepSpeech/releases/download/v${DS_VER}/deepspeech-${DS_VER}-models.pbmm -sSL > ${HOME}/data/models/output_graph.pbmm 43 | 44 | RUN curl https://github.com/mozilla/DeepSpeech/releases/download/v${DS_VER}/deepspeech-${DS_VER}-models.scorer -sSL > ${HOME}/data/models/kenlm.scorer 45 | 46 | COPY Cargo.toml ${HOME}/src/ds-srv/ 47 | 48 | COPY src ${HOME}/src/ds-srv/src/ 49 | 50 | # Force stubs required for building, but breaking runtime 51 | RUN cargo install --force --path ${HOME}/src/ds-srv/ 52 | 53 | ENTRYPOINT sudo /sbin/ldconfig && nvidia-smi && ds-srv \ 54 | -vvvv \ 55 | --model $HOME/data/models/output_graph.pbmm \ 56 | --scorer $HOME/data/models/kenlm.scorer \ 57 | --http_ip ::0 \ 58 | --http_port 8080 59 | -------------------------------------------------------------------------------- /src/http.rs: -------------------------------------------------------------------------------- 1 | extern crate futures; 2 | extern crate hyper; 3 | extern crate serde_json; 4 | 5 | use args::TcpPort; 6 | 7 | use self::futures::{future, Future, Stream}; 8 | use self::hyper::header::{HeaderValue, CONTENT_TYPE}; 9 | use self::hyper::service::service_fn; 10 | use self::hyper::{Body, Method, Request, Response, Server, StatusCode}; 11 | 12 | use std::net::{IpAddr, SocketAddr}; 13 | use std::sync::mpsc::channel; 14 | use std::sync::mpsc::Sender; 15 | 16 | use std::fs::File; 17 | use std::io::Read; 18 | 19 | type ResponseFuture = Box, Error = hyper::Error> + Send>; 20 | 21 | use inference::InferenceResult; 22 | use inference::RawAudioPCM; 23 | 24 | static mut tx_audio: Option)>> = None; 25 | 26 | fn http_handler(req: Request) -> ResponseFuture { 27 | debug!("Received HTTP: {} {}", req.method(), req.uri()); 28 | match (req.method(), req.uri().path()) { 29 | (&Method::GET, "/__version__") => { 30 | debug!("Reading version JSON from /app/version.json"); 31 | let mut json_version = String::new(); 32 | let mut file = File::open("/app/version.json").unwrap(); 33 | file.read_to_string(&mut json_version).unwrap(); 34 | Box::new(future::ok( 35 | Response::builder() 36 | .status(StatusCode::OK) 37 | .header(CONTENT_TYPE, "application/json") 38 | .body(Body::from(json_version)) 39 | .unwrap() 40 | )) 41 | }, 42 | (&Method::GET, "/__heartbeat__") => { 43 | debug!("App heatbeat checks"); 44 | Box::new(future::ok( 45 | Response::builder() 46 | .status(StatusCode::OK) 47 | .body(Body::from("")) 48 | .unwrap() 49 | )) 50 | }, 51 | (&Method::GET, "/__lbheartbeat__") => { 52 | debug!("Load-Balancer heatbeat checks"); 53 | Box::new(future::ok( 54 | Response::builder() 55 | .status(StatusCode::OK) 56 | .body(Body::from("")) 57 | .unwrap() 58 | )) 59 | }, 60 | (&Method::POST, "/") => { 61 | debug!("POST connection accepted"); 62 | let (parts, body) = req.into_parts(); 63 | match parts.headers.get(CONTENT_TYPE) { 64 | Some(h) if h == HeaderValue::from_static("application/octet-stream") => { 65 | debug!("This is valid: {:?}", h); 66 | Box::new(body.concat2().map(|audio_content| { 67 | let raw_pcm = audio_content.into_bytes(); 68 | debug!("RAW PCM is {:?} bytes", raw_pcm.len()); 69 | let inference_result = raw_pcm.len(); 70 | let infer = format!("inference: {}", inference_result); 71 | 72 | let pcm = RawAudioPCM { 73 | content: raw_pcm.clone(), 74 | }; 75 | 76 | let (tx_string, rx_string) = channel(); 77 | 78 | unsafe { 79 | match tx_audio { 80 | Some(ref tx_audio_ok) => match tx_audio_ok 81 | .clone() 82 | .send((pcm, tx_string)) 83 | { 84 | Ok(_) => { 85 | debug!("Successfully sent message to thread"); 86 | match rx_string.recv() { 87 | Ok(decoded_audio) => { 88 | info!("Received reply: {:?}", decoded_audio); 89 | Response::builder() 90 | .status(StatusCode::OK) 91 | .header(CONTENT_TYPE, "application/json") 92 | .body(Body::from( 93 | serde_json::to_string(&decoded_audio) 94 | .unwrap(), 95 | )) 96 | .unwrap() 97 | } 98 | Err(err_recv) => { 99 | error!("Error trying to rx.recv(): {:?}", err_recv); 100 | Response::builder() 101 | .status(StatusCode::NOT_FOUND) 102 | .body(infer.into()) 103 | .unwrap() 104 | } 105 | } 106 | } 107 | Err(err) => { 108 | error!("Error while sending message to thread: {:?}", err); 109 | Response::builder() 110 | .status(StatusCode::NOT_FOUND) 111 | .body(infer.into()) 112 | .unwrap() 113 | } 114 | }, 115 | None => { 116 | error!("Unable to tx.send()"); 117 | Response::builder() 118 | .status(StatusCode::NOT_FOUND) 119 | .body(infer.into()) 120 | .unwrap() 121 | } 122 | } 123 | } 124 | })) 125 | } 126 | _ => Box::new(future::ok( 127 | Response::builder() 128 | .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) 129 | .body(Body::empty()) 130 | .unwrap(), 131 | )), 132 | } 133 | } 134 | _ => Box::new(future::ok( 135 | Response::builder() 136 | .status(StatusCode::METHOD_NOT_ALLOWED) 137 | .body(Body::empty()) 138 | .unwrap(), 139 | )), 140 | } 141 | } 142 | 143 | pub fn th_http_listener( 144 | http_ip: IpAddr, 145 | http_port: TcpPort, 146 | _tx_audio: Sender<(RawAudioPCM, Sender)>, 147 | ) { 148 | unsafe { 149 | tx_audio = Some(_tx_audio); 150 | } 151 | 152 | let socket = SocketAddr::new(http_ip, http_port); 153 | info!("Building server http://{}", &socket); 154 | let server = Server::bind(&socket) 155 | .serve(|| service_fn(http_handler)) 156 | .map_err(|e| eprintln!("server error: {}", e)); 157 | info!("Listening on http://{}", socket); 158 | hyper::rt::run(server); 159 | } 160 | -------------------------------------------------------------------------------- /src/inference.rs: -------------------------------------------------------------------------------- 1 | extern crate serde; 2 | 3 | extern crate audrey; 4 | extern crate deepspeech; 5 | extern crate futures; 6 | 7 | extern crate mkstemp; 8 | 9 | extern crate byte_slice_cast; 10 | extern crate bytes; 11 | 12 | use self::audrey::read::Description; 13 | use self::audrey::read::Reader; 14 | use self::audrey::Format; 15 | use self::byte_slice_cast::*; 16 | use self::bytes::Bytes; 17 | use self::deepspeech::Model; 18 | 19 | use std::fs::File; 20 | use std::io::Cursor; 21 | use std::path::Path; 22 | use std::sync::mpsc::{Receiver, Sender}; 23 | use std::time::Instant; 24 | use std::vec::Vec; 25 | 26 | #[derive(Debug)] 27 | pub struct RawAudioPCM { 28 | pub content: Bytes, 29 | } 30 | 31 | #[derive(Debug, Serialize, Deserialize)] 32 | pub struct InferenceData { 33 | text: String, 34 | confidence: f32, 35 | } 36 | 37 | #[derive(Debug, Serialize, Deserialize)] 38 | pub struct InferenceResult { 39 | status: String, 40 | data: Vec, 41 | } 42 | 43 | // The model has been trained on this specific 44 | // sample rate. 45 | const AUDIO_SAMPLE_RATE: u32 = 16000; 46 | const AUDIO_CHANNELS: u32 = 1; 47 | const AUDIO_FORMAT: Format = Format::Wav; 48 | 49 | fn start_model(model: String, scorer: String) -> Model { 50 | let mut m = Model::load_from_files( 51 | Path::new(&model) 52 | ).unwrap(); 53 | 54 | m.enable_external_scorer( 55 | Path::new(&scorer) 56 | ); 57 | 58 | m 59 | } 60 | 61 | fn ensure_valid_audio(desc: Description) -> bool { 62 | let rv_format = if desc.format() != AUDIO_FORMAT { 63 | error!("Invalid audio format: {:?}", desc.format()); 64 | false 65 | } else { 66 | true 67 | }; 68 | 69 | let rv_channels = if desc.channel_count() != AUDIO_CHANNELS { 70 | error!("Invalid number of channels: {}", desc.channel_count()); 71 | false 72 | } else { 73 | true 74 | }; 75 | 76 | let rv_rate = if desc.sample_rate() != AUDIO_SAMPLE_RATE { 77 | error!("Invalid sample rate: {}", desc.sample_rate()); 78 | false 79 | } else { 80 | true 81 | }; 82 | 83 | rv_format && rv_channels && rv_rate 84 | } 85 | 86 | fn inference_result(result: String, status: bool) -> InferenceResult { 87 | let confidence_value = match status { 88 | true => 1.0, 89 | false => 0.0, 90 | }; 91 | 92 | let status_value = match status { 93 | true => "ok".to_string(), 94 | false => "ko".to_string(), 95 | }; 96 | 97 | let mut inf_data: Vec = Vec::new(); 98 | inf_data.push(InferenceData { 99 | confidence: confidence_value, 100 | text: result, 101 | }); 102 | 103 | let inf_result = InferenceResult { 104 | status: status_value, 105 | data: inf_data, 106 | }; 107 | 108 | inf_result 109 | } 110 | 111 | fn inference_error() -> InferenceResult { 112 | inference_result("".to_string(), false) 113 | } 114 | 115 | fn inference(m: &mut Model, buffer: &[i16]) -> InferenceResult { 116 | let start = Instant::now(); 117 | 118 | let rv = match m.speech_to_text(buffer) { 119 | Ok(result) => inference_result(result, true), 120 | Err(err) => { 121 | error!("Error while running inference: {:?}", err); 122 | inference_error() 123 | } 124 | }; 125 | 126 | let duration = start.elapsed(); 127 | info!("Inference took: {:?}", duration); 128 | 129 | rv 130 | } 131 | 132 | fn maybe_dump_debug(stream: Bytes, directory: String) { 133 | use self::mkstemp::TempFile; 134 | use std::io::Write; 135 | 136 | let temp_root = Path::new(&directory); 137 | let temp_file_name = temp_root.join("ds-debug-wav-XXXXXX"); 138 | 139 | debug!( 140 | "Dumping RAW PCM content to {:?} => {:?}", 141 | temp_root, temp_file_name 142 | ); 143 | 144 | match TempFile::new(temp_file_name.to_str().unwrap(), false) { 145 | Ok(mut file) => match file.write(&*stream) { 146 | Ok(_) => debug!("Sucessfully write debug file"), 147 | Err(err) => error!("Error writing content of debug file {:?}", err), 148 | }, 149 | Err(err) => error!("Error creating debug file: {:?}", err), 150 | } 151 | } 152 | 153 | fn maybe_warmup_model(mut m: &mut Model, directory: String, cycles: i32) { 154 | let warmup_dir = Path::new(&directory); 155 | let mut allwaves = Vec::new(); 156 | 157 | for entry in warmup_dir.read_dir().expect("read_dir call failed") { 158 | if let Ok(entry) = entry { 159 | match entry.path().extension() { 160 | Some(ext) if ext == "wav" => { 161 | debug!("Found one more WAV file: {:?}", entry.path()); 162 | allwaves.push(entry.path()); 163 | } 164 | Some(_) => {} 165 | None => {} 166 | } 167 | } 168 | } 169 | 170 | for wave in allwaves.iter() { 171 | debug!("Warmup with {:?}", wave); 172 | if let Ok(audio_file) = File::open(wave) { 173 | if let Ok(mut reader) = Reader::new(audio_file) { 174 | let audio_buf: Vec<_> = reader.samples().map(|s| s.unwrap()).collect::>(); 175 | for i in 0..cycles { 176 | info!("Warmup cycle {} of {}", i + 1, cycles); 177 | inference(&mut m, &*audio_buf); 178 | } 179 | } 180 | } 181 | } 182 | } 183 | 184 | pub fn th_inference( 185 | model: String, 186 | scorer: String, 187 | rx_audio: Receiver<(RawAudioPCM, Sender)>, 188 | dump_dir: String, 189 | warmup_dir: String, 190 | warmup_cycles: i32, 191 | ) { 192 | info!("Inference thread started"); 193 | let mut model_instance = start_model(model, scorer); 194 | 195 | if warmup_dir.len() > 0 { 196 | maybe_warmup_model(&mut model_instance, warmup_dir.clone(), warmup_cycles); 197 | } 198 | 199 | loop { 200 | info!("Model ready and waiting for data to infer ..."); 201 | match rx_audio.recv() { 202 | Ok((audio, tx_string)) => { 203 | info!("Received message: {:?} bytes", audio.content.len()); 204 | 205 | #[cfg(feature = "dump_debug_stream")] 206 | maybe_dump_debug(audio.content.clone(), dump_dir.clone()); 207 | 208 | let inf = match Reader::new(Cursor::new(&*audio.content)) { 209 | Ok(mut reader) => { 210 | let desc = reader.description(); 211 | 212 | match ensure_valid_audio(desc) { 213 | true => { 214 | let audio_buf: Vec<_> = 215 | reader.samples().map(|s| s.unwrap()).collect::>(); 216 | inference(&mut model_instance, &*audio_buf) 217 | } 218 | 219 | false => inference_error(), 220 | } 221 | } 222 | 223 | Err(err) => { 224 | error!("Audrey read error: {:?}", err); 225 | let mut audio_u8 = audio.content.to_vec(); 226 | match audio_u8.as_mut_slice_of::() { 227 | Ok(audio_i16) => { 228 | info!("Trying with RAW PCM {:?} bytes", audio_i16.len()); 229 | inference(&mut model_instance, &*audio_i16) 230 | } 231 | Err(err) => { 232 | error!("Unable to make u8 -> i16: {:?}", err); 233 | inference_error() 234 | } 235 | } 236 | } 237 | }; 238 | 239 | match tx_string.send(inf) { 240 | Ok(_) => {} 241 | Err(err) => error!("Error sending inference result: {:?}", err), 242 | } 243 | } 244 | 245 | Err(err_recv) => error!("Error trying to rx.recv(): {:?}", err_recv), 246 | } 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /src/args.rs: -------------------------------------------------------------------------------- 1 | extern crate clap; 2 | extern crate simplelog; 3 | 4 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; 5 | use std::str::FromStr; 6 | 7 | pub type TcpPort = u16; 8 | 9 | #[derive(Debug, Eq, PartialEq, Clone, Copy)] 10 | pub enum VerbosityLevel { 11 | DEBUG = 0, 12 | INFO = 1, 13 | WARN = 2, 14 | ERROR = 3, 15 | } 16 | 17 | impl Into for VerbosityLevel { 18 | fn into(self) -> simplelog::LevelFilter { 19 | match self { 20 | VerbosityLevel::DEBUG => simplelog::LevelFilter::Debug, 21 | VerbosityLevel::INFO => simplelog::LevelFilter::Info, 22 | VerbosityLevel::WARN => simplelog::LevelFilter::Warn, 23 | VerbosityLevel::ERROR => simplelog::LevelFilter::Error, 24 | } 25 | } 26 | } 27 | 28 | #[derive(Debug, Clone)] 29 | /// Holds the program's runtime configuration 30 | pub struct RuntimeConfig { 31 | pub http_ip: IpAddr, 32 | pub http_port: TcpPort, 33 | pub dump_dir: String, 34 | pub warmup_dir: String, 35 | pub warmup_cycles: i32, 36 | pub model: String, 37 | pub scorer: String, 38 | pub verbosity_level: VerbosityLevel, 39 | } 40 | 41 | pub struct ArgsParser; 42 | 43 | impl ArgsParser { 44 | fn to_ip_addr(o: Option<&str>) -> IpAddr { 45 | let default_ip = IpAddr::V6(Ipv6Addr::from_str("::0").unwrap()); 46 | match o { 47 | Some(ip_str) => { 48 | if Ipv6Addr::from_str(ip_str).is_ok() { 49 | IpAddr::V6(Ipv6Addr::from_str(ip_str).unwrap()) 50 | } else if Ipv4Addr::from_str(ip_str).is_ok() { 51 | IpAddr::V4(Ipv4Addr::from_str(ip_str).unwrap()) 52 | } else { 53 | default_ip 54 | } 55 | } 56 | None => default_ip, 57 | } 58 | } 59 | 60 | fn to_port(o: Option<&str>) -> TcpPort { 61 | let default_port = 8080; 62 | match o.unwrap_or(default_port.to_string().as_str()) 63 | .parse::() 64 | { 65 | Ok(rv) => rv, 66 | Err(_) => default_port, 67 | } 68 | } 69 | 70 | fn to_verbosity_level(occ: u64) -> VerbosityLevel { 71 | match occ { 72 | 0 => VerbosityLevel::ERROR, 73 | 1 => VerbosityLevel::WARN, 74 | 2 => VerbosityLevel::INFO, 75 | 3 => VerbosityLevel::DEBUG, 76 | _ => VerbosityLevel::DEBUG, 77 | } 78 | } 79 | 80 | pub fn from_cli() -> RuntimeConfig { 81 | let matches = clap::App::new("DeepSpeech Inference Server") 82 | .version("0.1") 83 | .author("") 84 | .about("Running inference from POST-ed RAW PCM.") 85 | .arg( 86 | clap::Arg::with_name("http_ip") 87 | .short("h") 88 | .long("http_ip") 89 | .value_name("HTTP_IP") 90 | .help("IP address to listen on for HTTP") 91 | .takes_value(true) 92 | .required(false), 93 | ) 94 | .arg( 95 | clap::Arg::with_name("http_port") 96 | .short("p") 97 | .long("http_port") 98 | .value_name("HTTP_PORT") 99 | .help("TCP port to listen on for HTTP") 100 | .takes_value(true) 101 | .required(false), 102 | ) 103 | .arg( 104 | clap::Arg::with_name("dump_dir") 105 | .short("d") 106 | .long("dump_dir") 107 | .value_name("DUMP_DIR") 108 | .help("Directory to use to dump debug WAV streams") 109 | .takes_value(true) 110 | .required(false), 111 | ) 112 | .arg( 113 | clap::Arg::with_name("warmup_dir") 114 | .short("w") 115 | .long("warmup_dir") 116 | .value_name("WARMUP_DIR") 117 | .help("Directory to use to warmup model") 118 | .takes_value(true) 119 | .required(false), 120 | ) 121 | .arg( 122 | clap::Arg::with_name("warmup_cycles") 123 | .short("c") 124 | .long("warmup_cycles") 125 | .value_name("WARMUP_CYCLES") 126 | .help("How many warmup cycles to perform for each WAVE in WARMUP_DIR") 127 | .takes_value(true) 128 | .required(false), 129 | ) 130 | .arg( 131 | clap::Arg::with_name("model") 132 | .short("m") 133 | .long("model") 134 | .value_name("MODEL") 135 | .help("TensorFlow model to use") 136 | .takes_value(true) 137 | .required(true), 138 | ) 139 | .arg( 140 | clap::Arg::with_name("scorer") 141 | .short("scorer") 142 | .long("scorer") 143 | .value_name("Scorer") 144 | .help("External scorer to use") 145 | .takes_value(true) 146 | .required(true), 147 | ) 148 | .arg( 149 | clap::Arg::with_name("v") 150 | .short("v") 151 | .multiple(true) 152 | .help("Sets the level of verbosity"), 153 | ) 154 | .get_matches(); 155 | 156 | RuntimeConfig { 157 | http_ip: ArgsParser::to_ip_addr(matches.value_of("http_ip")), 158 | http_port: ArgsParser::to_port(matches.value_of("http_port")), 159 | dump_dir: String::from(matches.value_of("dump_dir").unwrap_or("/tmp")), 160 | warmup_dir: String::from(matches.value_of("warmup_dir").unwrap_or("")), 161 | warmup_cycles: matches 162 | .value_of("warmup_cycles") 163 | .unwrap_or("10") 164 | .parse::() 165 | .unwrap(), 166 | model: String::from(matches.value_of("model").unwrap()), 167 | scorer: String::from(matches.value_of("scorer").unwrap()), 168 | verbosity_level: ArgsParser::to_verbosity_level(matches.occurrences_of("v")), 169 | } 170 | } 171 | } 172 | 173 | #[test] 174 | fn test_to_ip_addr() { 175 | assert_eq!( 176 | ArgsParser::to_ip_addr(Some("")), 177 | IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)) 178 | ); 179 | assert_eq!( 180 | ArgsParser::to_ip_addr(Some("239.255.0.1")), 181 | IpAddr::V4(Ipv4Addr::new(239, 255, 0, 1)) 182 | ); 183 | assert_eq!( 184 | ArgsParser::to_ip_addr(Some("1.2.3.4")), 185 | IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)) 186 | ); 187 | assert_eq!( 188 | ArgsParser::to_ip_addr(Some("::1")), 189 | IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)) 190 | ); 191 | assert_eq!( 192 | ArgsParser::to_ip_addr(Some("ffx3::1")), 193 | IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)) 194 | ); 195 | assert_eq!( 196 | ArgsParser::to_ip_addr(Some("ff03::1")), 197 | IpAddr::V6(Ipv6Addr::new(0xff03, 0, 0, 0, 0, 0, 0, 1)) 198 | ); 199 | } 200 | 201 | #[test] 202 | fn test_to_port() { 203 | assert_eq!(ArgsParser::to_port(Some("xxx")), 8080); 204 | assert_eq!(ArgsParser::to_port(Some("8080")), 8080); 205 | assert_eq!(ArgsParser::to_port(Some("1234")), 1234); 206 | } 207 | 208 | #[test] 209 | fn test_to_verbosity_level() { 210 | assert_eq!(ArgsParser::to_verbosity_level(0), VerbosityLevel::ERROR); 211 | assert_eq!(ArgsParser::to_verbosity_level(1), VerbosityLevel::WARN); 212 | assert_eq!(ArgsParser::to_verbosity_level(2), VerbosityLevel::INFO); 213 | assert_eq!(ArgsParser::to_verbosity_level(3), VerbosityLevel::DEBUG); 214 | assert_eq!(ArgsParser::to_verbosity_level(4), VerbosityLevel::DEBUG); 215 | assert_eq!(ArgsParser::to_verbosity_level(42), VerbosityLevel::DEBUG); 216 | } 217 | 218 | #[test] 219 | fn test_args() { 220 | let rc = ArgsParser::from_cli(); 221 | 222 | assert_eq!(rc.http_ip.to_string(), "0.0.0.0"); 223 | assert_eq!(rc.http_port.to_string(), "8080"); 224 | assert_eq!(rc.verbosity_level, VerbosityLevel::ERROR); 225 | } 226 | --------------------------------------------------------------------------------