├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── README.md
└── src
    ├── llm
    │   ├── LLMInference.cpp
    │   ├── LLMInference.h
    │   └── main.cpp
    └── vlm
        ├── VLMInference.cpp
        └── VLMInference.h

/.gitignore:
--------------------------------------------------------------------------------
.vscode
build
models
.idea
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "llama.cpp"]
    path = llama.cpp
    url = https://github.com/ggml-org/llama.cpp
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10)
project(llama_inference)

set(CMAKE_CXX_STANDARD 17)
set(LLAMA_BUILD_COMMON On)

add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/llama.cpp")

add_executable(
    chat
    src/llm/LLMInference.cpp
    src/llm/main.cpp
)
target_link_libraries(
    chat
    PRIVATE
    common llama ggml
)

add_executable(
    chat-vision
    src/vlm/VLMInference.cpp
    src/vlm/main.cpp
)
target_link_libraries(
    chat-vision
    PRIVATE
    common llama ggml
)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple LLM Inference in C++ with llama.cpp

## Build

```bash
# clone repository
git clone --depth=1 https://github.com/shubham0204/llama.cpp-simple-chat-interface
cd llama.cpp-simple-chat-interface
# fetch the llama.cpp submodule referenced by CMakeLists.txt
git submodule update --init --recursive
# download model
wget https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct-GGUF/resolve/main/smollm2-360m-instruct-q8_0.gguf -P models
# build the executable
mkdir build
cd build
cmake ..
make chat
./chat
```
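The `chat` binary loads its model from a path hardcoded in `src/llm/main.cpp`. If you downloaded the SmolLM2 GGUF with the `wget` command above, point `modelPath` at that file before building; for example (the path is relative to the `build` directory the binary is run from):

```cpp
// src/llm/main.cpp: point this at the GGUF file you actually downloaded
std::string modelPath = "../models/smollm2-360m-instruct-q8_0.gguf";
```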
--------------------------------------------------------------------------------
/src/llm/LLMInference.cpp:
--------------------------------------------------------------------------------
#include "LLMInference.h"
#include <cstdlib>
#include <cstring>
#include <iostream>

void
LLMInference::loadModel(const std::string& model_path, float min_p, float temperature) {
    // create an instance of llama_model
    llama_model_params modelParams = llama_model_default_params();
    _model = llama_model_load_from_file(model_path.data(), modelParams);

    if (!_model) {
        throw std::runtime_error("load_model() failed");
    }

    // create an instance of llama_context
    llama_context_params ctxParams = llama_context_default_params();
    ctxParams.n_ctx = 0;      // take context size from the model GGUF file
    ctxParams.no_perf = true; // disable performance metrics
    _ctx = llama_init_from_model(_model, ctxParams);

    if (!_ctx) {
        throw std::runtime_error("llama_init_from_model() returned null");
    }

    // initialize sampler
    llama_sampler_chain_params samplerParams = llama_sampler_chain_default_params();
    samplerParams.no_perf = true; // disable performance metrics
    _sampler = llama_sampler_chain_init(samplerParams);
    llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(min_p, 1));
    llama_sampler_chain_add(_sampler, llama_sampler_init_temp(temperature));
    llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    _formattedMessages = std::vector<char>(llama_n_ctx(_ctx));
    _messages.clear();
}

void
LLMInference::addChatMessage(const std::string& message, const std::string& role) {
    // llama_chat_message only stores raw pointers, so keep malloc'ed copies
    // that remain valid for the lifetime of `_messages` (freed in the destructor)
    _messages.push_back({ strdup(role.data()), strdup(message.data()) });
}

void
LLMInference::startCompletion(const std::string& query) {
    addChatMessage(query, "user");

    // apply the chat-template
    const char* tmpl = llama_model_chat_template(_model, nullptr);
    int newLen = llama_chat_apply_template(tmpl, _messages.data(), _messages.size(), true,
                                           _formattedMessages.data(), _formattedMessages.size());
    if (newLen > static_cast<int>(_formattedMessages.size())) {
        // resize the output buffer `_formattedMessages`
        // and re-apply the chat template
        _formattedMessages.resize(newLen);
        newLen = llama_chat_apply_template(tmpl, _messages.data(), _messages.size(), true,
                                           _formattedMessages.data(), _formattedMessages.size());
    }
    if (newLen < 0) {
        throw std::runtime_error("llama_chat_apply_template() in "
                                 "LLMInference::startCompletion() failed");
    }
    // only tokenize the part of the formatted conversation that was added
    // since the previous turn (tracked by `_prevLen`)
    std::string prompt(_formattedMessages.begin() + _prevLen, _formattedMessages.begin() + newLen);
    _promptTokens = common_tokenize(llama_model_get_vocab(_model), prompt, true, true);

    // create a llama_batch containing a single sequence
    // see llama_batch_init for more details
    _batch = llama_batch_get_one(_promptTokens.data(), _promptTokens.size());
}

std::string
LLMInference::completionLoop() {
    // check if the length of the inputs to the model
    // has exceeded the context size of the model
    int contextSize = llama_n_ctx(_ctx);
    int nCtxUsed = llama_get_kv_cache_used_cells(_ctx);
    if (nCtxUsed + _batch.n_tokens > contextSize) {
        std::cerr << "context size exceeded" << '\n';
        exit(0);
    }
    // run the model
    if (llama_decode(_ctx, _batch) != 0) {
        throw std::runtime_error("llama_decode() failed");
    }

    // sample a token and check if it is an EOG (end of generation token)
    // convert the integer token to its corresponding word-piece
    _currToken = llama_sampler_sample(_sampler, _ctx, -1);
    if (llama_vocab_is_eog(llama_model_get_vocab(_model), _currToken)) {
        // addChatMessage() makes its own copy of the text
        addChatMessage(_response, "assistant");
        _response.clear();
        return "[EOG]";
    }
    std::string piece = common_token_to_piece(_ctx, _currToken, true);
    _response += piece;

    // re-init the batch with the newly predicted token
    // key, value pairs of all previous tokens have been cached
    // in the KV cache
    _batch = llama_batch_get_one(&_currToken, 1);

    return piece;
}

void
LLMInference::stopCompletion() {
    const char* tmpl = llama_model_chat_template(_model, nullptr);
    _prevLen = llama_chat_apply_template(tmpl, _messages.data(), _messages.size(), false, nullptr, 0);
    if (_prevLen < 0) {
        throw std::runtime_error("llama_chat_apply_template() in "
                                 "LLMInference::stopCompletion() failed");
    }
}

LLMInference::~LLMInference() {
    // free memory held by the message text in messages
    // (as we had used strdup() to create malloc'ed copies)
    for (llama_chat_message& message : _messages) {
        std::free(const_cast<char*>(message.role));
        std::free(const_cast<char*>(message.content));
    }
    llama_kv_cache_clear(_ctx);
    llama_sampler_free(_sampler);
    llama_free(_ctx);
    llama_model_free(_model);
}
--------------------------------------------------------------------------------
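A note on the sampler built in `loadModel()`: it is a chain in which each stage filters or reshapes the candidate token distribution before the final `dist` stage draws a sample (min-p filtering, then temperature scaling, then seeded random sampling). The chain is composable; as an illustrative sketch (not part of the project), a top-k stage could be added in front using the same `llama_sampler_chain_add()` calls already used above, with example values:

```cpp
// Sketch: an extended sampler chain with a top-k stage before min-p and temperature.
// Uses the same llama.cpp sampler API as LLMInference::loadModel(); the values are examples.
llama_sampler_chain_params params = llama_sampler_chain_default_params();
params.no_perf = true;
llama_sampler* sampler = llama_sampler_chain_init(params);
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(40));                 // keep the 40 most likely tokens
llama_sampler_chain_add(sampler, llama_sampler_init_min_p(0.05f, 1));           // drop tokens far below the top probability
llama_sampler_chain_add(sampler, llama_sampler_init_temp(0.8f));                // rescale the remaining logits
llama_sampler_chain_add(sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));  // draw the final token
```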
/src/llm/LLMInference.h:
--------------------------------------------------------------------------------
#ifndef LLMINFERENCE_H
#define LLMINFERENCE_H

#include "common.h"
#include "llama.h"
#include <string>
#include <vector>

class LLMInference {

    // llama.cpp-specific types
    llama_context* _ctx;
    llama_model* _model;
    llama_sampler* _sampler;
    llama_batch _batch;
    llama_token _currToken;

    // container to store user/assistant messages in the chat
    std::vector<llama_chat_message> _messages;
    // stores the string generated after applying
    // the chat-template to all messages in `_messages`
    std::vector<char> _formattedMessages;
    // stores the tokens for the last query
    // appended to `_messages`
    std::vector<llama_token> _promptTokens;
    int _prevLen = 0;

    // stores the complete response for the given query
    std::string _response = "";

  public:

    void loadModel(const std::string& modelPath, float minP, float temperature);

    void addChatMessage(const std::string& message, const std::string& role);

    void startCompletion(const std::string& query);

    std::string completionLoop();

    void stopCompletion();

    ~LLMInference();
};

#endif
--------------------------------------------------------------------------------
/src/llm/main.cpp:
--------------------------------------------------------------------------------
#include "LLMInference.h"
#include <cstdio>
#include <iostream>
#include <memory>

int main(int argc, char* argv[]) {
    std::string modelPath = "../models/DeepSeek-R1-Distill-Qwen-1.5B-Q8_0.gguf";
    float temperature = 1.0f;
    float minP = 0.05f;
    std::unique_ptr<LLMInference> llmInference = std::make_unique<LLMInference>();
    llmInference->loadModel(modelPath, minP, temperature);
    llmInference->addChatMessage("You are a helpful assistant", "system");
    while (true) {
        std::cout << "Enter query:\n";
        std::string query;
        std::getline(std::cin, query);
        if (query == "exit") {
            break;
        }
        llmInference->startCompletion(query);
        std::string predictedToken;
        while ((predictedToken = llmInference->completionLoop()) != "[EOG]") {
            std::cout << predictedToken;
            fflush(stdout);
        }
        // record the length of the formatted conversation so that the next
        // turn only feeds the newly added messages to the model
        llmInference->stopCompletion();
        std::cout << '\n';
    }
    return 0;
}
--------------------------------------------------------------------------------
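`main.cpp` streams each predicted token to stdout as soon as `completionLoop()` returns it. If a caller instead wants the whole reply as a single string, the same `LLMInference` interface can be wrapped in a small helper; the sketch below is an illustration, not part of the repository:

```cpp
// Sketch: accumulate a complete reply instead of streaming it to stdout.
// Relies only on the public interface declared in LLMInference.h.
#include "LLMInference.h"
#include <string>

std::string respond(LLMInference& llm, const std::string& query) {
    llm.startCompletion(query);
    std::string response;
    for (std::string piece = llm.completionLoop(); piece != "[EOG]"; piece = llm.completionLoop()) {
        response += piece;
    }
    // update the chat-template bookkeeping so the next turn only sends new messages
    llm.stopCompletion();
    return response;
}
```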
/src/vlm/VLMInference.cpp:
--------------------------------------------------------------------------------
//
// Created by equip on 21-02-2025.
//

#include "VLMInference.h"
--------------------------------------------------------------------------------
/src/vlm/VLMInference.h:
--------------------------------------------------------------------------------
#ifndef VLMINFERENCE_H
#define VLMINFERENCE_H

#include "llama.h"

class VLMInference {

};

#endif //VLMINFERENCE_H
--------------------------------------------------------------------------------