├── demo.png ├── images └── huahua.jpg ├── llamaCpp ├── llama.lib ├── ggml_shared.lib ├── ggml_static.lib ├── common │ ├── build-info.cpp │ ├── build-info.cpp.in │ ├── console.h │ ├── grammar-parser.h │ ├── CMakeLists.txt │ ├── sampling.h │ ├── train.h │ ├── sampling.cpp │ ├── common.h │ ├── base64.hpp │ ├── console.cpp │ ├── grammar-parser.cpp │ └── log.h ├── llamaCpp.pri ├── ggml-backend-impl.h ├── ggml-alloc.h ├── ggml-backend.h ├── ggml-impl.h ├── ggml-quants.h └── ggml-alloc.c ├── shader.qrc ├── images.qrc ├── README.md ├── MiniQwen.pro ├── humanassets.cpp ├── humanassets.h ├── llmmodel.h ├── main.cpp ├── sqlconversationmodel.h ├── main.qml ├── sqlconversationmodel.cpp ├── LICENSE.md └── llmmodel.cpp /demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/MiniQwen/HEAD/demo.png -------------------------------------------------------------------------------- /images/huahua.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/MiniQwen/HEAD/images/huahua.jpg -------------------------------------------------------------------------------- /llamaCpp/llama.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/MiniQwen/HEAD/llamaCpp/llama.lib -------------------------------------------------------------------------------- /llamaCpp/ggml_shared.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/MiniQwen/HEAD/llamaCpp/ggml_shared.lib -------------------------------------------------------------------------------- /llamaCpp/ggml_static.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kleinlee/MiniQwen/HEAD/llamaCpp/ggml_static.lib 
-------------------------------------------------------------------------------- /shader.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | main.qml 4 | 5 | 6 | -------------------------------------------------------------------------------- /images.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | images/huahua.jpg 4 | 5 | 6 | -------------------------------------------------------------------------------- /llamaCpp/common/build-info.cpp: -------------------------------------------------------------------------------- 1 | int LLAMA_BUILD_NUMBER = 1601; 2 | char const *LLAMA_COMMIT = "5a7d312"; 3 | char const *LLAMA_COMPILER = "MSVC 19.33.31629.0"; 4 | char const *LLAMA_BUILD_TARGET = "x64"; 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MiniQwen 2 | license: apache-2.0 3 | - llama.cpp+Qwen1.8B,并使用Qt搭建一个简易客户端。 4 | - 简化调用代码复杂度,专供Qwen使用 5 | - 内存占用约1.3GB 6 | - i5-12600K CPU上4线程可以达到14-15token/s。对终端设备非常友好。 7 | 8 | teaser_b 9 | -------------------------------------------------------------------------------- /llamaCpp/common/build-info.cpp.in: -------------------------------------------------------------------------------- 1 | int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@; 2 | char const *LLAMA_COMMIT = "@BUILD_COMMIT@"; 3 | char const *LLAMA_COMPILER = "@BUILD_COMPILER@"; 4 | char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@"; 5 | -------------------------------------------------------------------------------- /llamaCpp/common/console.h: -------------------------------------------------------------------------------- 1 | // Console functions 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace console { 8 | enum display_t { 9 | reset = 0, 10 | prompt, 11 | user_input, 12 | error 13 | }; 14 | 15 | void init(bool use_simple_io, bool 
use_advanced_display); 16 | void cleanup(); 17 | void set_display(display_t display); 18 | bool readline(std::string & line, bool multiline_input); 19 | } 20 | -------------------------------------------------------------------------------- /MiniQwen.pro: -------------------------------------------------------------------------------- 1 | QT += qml quick sql quickcontrols2 2 | 3 | CONFIG += qmltypes 4 | QML_IMPORT_NAME = AAAAA 5 | QML_IMPORT_MAJOR_VERSION = 1 6 | 7 | HEADERS += \ 8 | humanassets.h \ 9 | llmmodel.h \ 10 | sqlconversationmodel.h 11 | SOURCES += main.cpp \ 12 | humanassets.cpp \ 13 | llmmodel.cpp \ 14 | sqlconversationmodel.cpp 15 | 16 | RESOURCES += \ 17 | shader.qrc\ 18 | images.qrc 19 | 20 | include(llamaCpp/llamaCpp.pri) 21 | 22 | INSTALLS += target 23 | 24 | OTHER_FILES += \ 25 | main.qml 26 | -------------------------------------------------------------------------------- /humanassets.cpp: -------------------------------------------------------------------------------- 1 | #include "humanassets.h" 2 | 3 | HumanAssets* HumanAssets::instance_ = NULL; 4 | HumanAssets::HumanAssets() 5 | { 6 | // instance = NULL; 7 | m_chat_model = new LLMModel; 8 | m_chat_model->LoadModel(); 9 | // connect(ui->pushButton_stop, SIGNAL(clicked()), m_chat_model, SLOT(Reset())); 10 | } 11 | 12 | void HumanAssets::ChatUpdated() 13 | { 14 | emit ChatUpdatedSignal(); 15 | } 16 | 17 | void HumanAssets::SlotStopChat() 18 | { 19 | m_chat_model->Reset(); 20 | } 21 | void HumanAssets::SlotNewChat(QString question) 22 | { 23 | m_chat_model->Run(question); 24 | 25 | } 26 | 27 | void HumanAssets::SlotNewAnswer(QString str) 28 | { 29 | 30 | } 31 | -------------------------------------------------------------------------------- /llamaCpp/common/grammar-parser.h: -------------------------------------------------------------------------------- 1 | // Implements a parser for an extended Backus-Naur form (BNF), producing the 2 | // binary context-free grammar format specified by 
llama.h. Supports character 3 | // ranges, grouping, and repetition operators. As an example, a grammar for 4 | // arithmetic might look like: 5 | // 6 | // root ::= expr 7 | // expr ::= term ([-+*/] term)* 8 | // term ::= num | "(" space expr ")" space 9 | // num ::= [0-9]+ space 10 | // space ::= [ \t\n]* 11 | 12 | #pragma once 13 | #include "llama.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace grammar_parser { 20 | struct parse_state { 21 | std::map symbol_ids; 22 | std::vector> rules; 23 | 24 | std::vector c_rules(); 25 | }; 26 | 27 | parse_state parse(const char * src); 28 | void print_grammar(FILE * file, const parse_state & state); 29 | } 30 | -------------------------------------------------------------------------------- /humanassets.h: -------------------------------------------------------------------------------- 1 | #ifndef HUMANASSETS_H 2 | #define HUMANASSETS_H 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "llmmodel.h" 8 | class HumanAssets : public QObject 9 | { 10 | Q_OBJECT 11 | private: 12 | HumanAssets(); //构造函数是私有的,这样就不能在其它地方创建该实例 13 | 14 | static HumanAssets *instance_; //定义一个唯一指向实例的静态指针,并且是私有的。 15 | public: 16 | static HumanAssets* GetInstance() //定义一个公有函数,可以获取这个唯一的实例,并且在需要的时候创建该实例。 17 | { 18 | if(instance_ == NULL) //判断是否第一次调用 19 | instance_ = new HumanAssets(); 20 | return instance_; 21 | } 22 | 23 | void ChatUpdated(); 24 | Q_SIGNALS: 25 | void ChatUpdatedSignal(); 26 | public: 27 | QString m_text; 28 | QString m_task_id; 29 | QList m_blendshape; 30 | QDateTime m_start_time_ms; 31 | 32 | LLMModel* m_chat_model; 33 | 34 | public slots: 35 | void SlotNewChat(QString question); 36 | void SlotStopChat(); 37 | void SlotNewAnswer(QString str); 38 | }; 39 | 40 | #endif // HUMANASSETS_H 41 | -------------------------------------------------------------------------------- /llmmodel.h: -------------------------------------------------------------------------------- 1 | #ifndef LLMMODEL_H 2 | #define 
LLMMODEL_H 3 | 4 | #include "common.h" 5 | #include "llama.h" 6 | #include 7 | #include 8 | #include 9 | class LLMModel : public QObject 10 | { 11 | Q_OBJECT 12 | public: 13 | LLMModel(); 14 | ~LLMModel(); 15 | 16 | int LoadModel(); 17 | void Run(QString qstr_input); 18 | 19 | gpt_params params; 20 | // Input tokens 21 | std::vector embd_inp; 22 | std::vector embd; 23 | std::vector cml_pfx; 24 | std::vector cml_sfx; 25 | // Raw pointers / counters start as nullptr / 0 (presumably set up by LoadModel() -- confirm in llmmodel.cpp) so nothing ever reads an indeterminate value. 26 | struct llama_sampling_context * ctx_sampling = nullptr; 27 | int n_ctx = 0; 28 | 29 | int n_remain = 0; 30 | int n_consumed = 0; 31 | 32 | llama_model * model = nullptr; 33 | llama_context * ctx = nullptr; 34 | llama_context * ctx_guidance = nullptr; 35 | bool is_interacting = false; 36 | bool m_log_file = true; 37 | 38 | std::string m_input; 39 | QTimer* m_timer = nullptr; 40 | 41 | QStringList m_output; 42 | signals: 43 | void SignalNewAnswer(QString str, bool end = false); 44 | public slots: 45 | void Update(); 46 | void Reset(); 47 | }; 48 | 49 | #endif // LLMMODEL_H 50 | -------------------------------------------------------------------------------- /llamaCpp/llamaCpp.pri: -------------------------------------------------------------------------------- 1 | SOURCES += \ 2 | llamaCpp/common/build-info.cpp \ 3 | llamaCpp/common/common.cpp \ 4 | llamaCpp/common/console.cpp \ 5 | llamaCpp/common/grammar-parser.cpp \ 6 | llamaCpp/common/sampling.cpp \ 7 | llamaCpp/common/train.cpp 8 | 9 | HEADERS += \ 10 | llamaCpp/common/base64.hpp \ 11 | llamaCpp/common/common.h \ 12 | llamaCpp/common/console.h \ 13 | llamaCpp/common/grammar-parser.h \ 14 | llamaCpp/common/log.h \ 15 | llamaCpp/common/sampling.h \ 16 | llamaCpp/common/stb_image.h \ 17 | llamaCpp/common/train.h \ 18 | llamaCpp/ggml-alloc.h \ 19 | llamaCpp/ggml-backend-impl.h \ 20 | llamaCpp/ggml-backend.h \ 21 | llamaCpp/ggml-impl.h \ 22 | llamaCpp/ggml-quants.h \ 23 | llamaCpp/ggml.h \ 24 | llamaCpp/llama.h \ 25 | llamaCpp/unicode.h 26 | 27 | INCLUDEPATH += llamaCpp/ 28 | INCLUDEPATH += llamaCpp/common 29 | 30 | win32:CONFIG(release,
debug|release): LIBS += -L$$PWD/./ -lggml_shared 31 | else:win32:CONFIG(debug, debug|release): LIBS += -L$$PWD/./ -lggml_shared 32 | # $$PWD is the directory of this .pri file (llamaCpp/), which already holds the llama/ggml headers and .lib files. 33 | INCLUDEPATH += $$PWD 34 | DEPENDPATH += $$PWD 35 | 36 | win32:CONFIG(release, debug|release): LIBS += -L$$PWD/./ -lllama 37 | else:win32:CONFIG(debug, debug|release): LIBS += -L$$PWD/./ -lllamad 38 | 39 | win32:CONFIG(release, debug|release): LIBS += -L$$PWD/./ -lggml_static 40 | else:win32:CONFIG(debug, debug|release): LIBS += -L$$PWD/./ -lggml_staticd 41 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | // Opens (creating on first use) the SQLite chat database as the default connection; aborts via qFatal() if that fails. 9 | static void connectToDatabase() 10 | { 11 | QSqlDatabase database = QSqlDatabase::database(); 12 | if (!database.isValid()) { 13 | database = QSqlDatabase::addDatabase("QSQLITE"); 14 | if (!database.isValid()) 15 | qFatal("Cannot add database: %s", qPrintable(database.lastError().text())); 16 | } 17 | 18 | const QString fileName = "chat-database.sqlite3"; 19 | // When using the SQLite driver, open() will create the SQLite database if it doesn't exist. 20 | database.setDatabaseName(fileName);
21 | qDebug() << fileName; 22 | if (!database.open()) { 23 | // qFatal() aborts the process, so no cleanup code can usefully follow it in this branch. 24 | qFatal("Cannot open database: %s", qPrintable(database.lastError().text())); 25 | } 26 | } 27 | int main(int argc, char **argv) 28 | { 29 | QGuiApplication app(argc, argv); 30 | QQuickStyle::setStyle("Material"); 31 | connectToDatabase(); 32 | QQuickView view; 33 | view.setColor(QColor(0,0,0,0)); 34 | view.setResizeMode(QQuickView::SizeRootObjectToView); 35 | view.setSource(QUrl("qrc:///main.qml")); 36 | view.show(); 37 | 38 | return QGuiApplication::exec(); 39 | } 40 | 41 | -------------------------------------------------------------------------------- /sqlconversationmodel.h: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 The Qt Company Ltd. 2 | // SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause 3 | 4 | #ifndef SQLCONVERSATIONMODEL_H 5 | #define SQLCONVERSATIONMODEL_H 6 | 7 | #include 8 | #include 9 | #include "humanassets.h" 10 | #include 11 | class SqlConversationModel : public QSqlTableModel 12 | { 13 | Q_OBJECT 14 | QML_NAMED_ELEMENT(SqlConversationModel) 15 | Q_PROPERTY(QString recipient READ recipient WRITE setRecipient NOTIFY recipientChanged) 16 | 17 | public: 18 | SqlConversationModel(QObject *parent = nullptr); 19 | 20 | QString recipient() const; 21 | void setRecipient(const QString &recipient); 22 | 23 | QVariant data(const QModelIndex &index, int role) const override; 24 | QHash roleNames() const override; 25 | 26 | Q_INVOKABLE void sendMessage(const QString &recipient, const QString &message); 27 | Q_INVOKABLE void removeAllMessage(); 28 | 29 | HumanAssets* m_human_assets; 30 | 31 | signals: 32 | void recipientChanged(); 33 | public slots: 34 | void updateMsg(); 35 | void SlotNewAnswer(QString str, bool end = false); 36 | private: 37 | QString m_recipient; 38 | 39 | int m_current_msg_length; 40 | QString m_message; 41 | QString
m_message_receive; 42 | QString m_timestamp; 43 | 44 | int sub_text_index = 0; 45 | int m_chat_round; 46 | 47 | bool m_chatEnded = true; 48 | }; 49 | 50 | #endif // SQLCONVERSATIONMODEL_H 51 | -------------------------------------------------------------------------------- /llamaCpp/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # common 2 | 3 | 4 | # Build info header 5 | # 6 | 7 | if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git") 8 | set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../.git") 9 | 10 | # Is git submodule 11 | if(NOT IS_DIRECTORY "${GIT_DIR}") 12 | file(READ ${GIT_DIR} REAL_GIT_DIR_LINK) 13 | string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK}) 14 | set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}") 15 | endif() 16 | 17 | set(GIT_INDEX "${GIT_DIR}/index") 18 | else() 19 | message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.") 20 | set(GIT_INDEX "") 21 | endif() 22 | 23 | # Add a custom command to rebuild build-info.cpp when .git/index changes 24 | add_custom_command( 25 | OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp" 26 | COMMENT "Generating build details from Git" 27 | COMMAND ${CMAKE_COMMAND} -DMSVC=${MSVC} -DCMAKE_C_COMPILER_VERSION=${CMAKE_C_COMPILER_VERSION} 28 | -DCMAKE_C_COMPILER_ID=${CMAKE_C_COMPILER_ID} -DCMAKE_VS_PLATFORM_NAME=${CMAKE_VS_PLATFORM_NAME} 29 | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -P "${CMAKE_CURRENT_SOURCE_DIR}/../scripts/gen-build-info-cpp.cmake" 30 | WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/.." 
31 | DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in" ${GIT_INDEX} 32 | VERBATIM 33 | ) 34 | set(TARGET build_info) 35 | add_library(${TARGET} OBJECT build-info.cpp) 36 | if (BUILD_SHARED_LIBS) 37 | set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) 38 | endif() 39 | 40 | 41 | set(TARGET common) 42 | 43 | add_library(${TARGET} STATIC 44 | base64.hpp 45 | common.h 46 | common.cpp 47 | sampling.h 48 | sampling.cpp 49 | console.h 50 | console.cpp 51 | grammar-parser.h 52 | grammar-parser.cpp 53 | train.h 54 | train.cpp 55 | ) 56 | 57 | if (BUILD_SHARED_LIBS) 58 | set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) 59 | endif() 60 | 61 | target_include_directories(${TARGET} PUBLIC .) 62 | target_compile_features(${TARGET} PUBLIC cxx_std_11) 63 | target_link_libraries(${TARGET} PRIVATE llama build_info) 64 | -------------------------------------------------------------------------------- /llamaCpp/ggml-backend-impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // ggml-backend internal header 4 | 5 | #include "ggml-backend.h" 6 | 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | // 12 | // Backend buffer 13 | // 14 | 15 | typedef void * ggml_backend_buffer_context_t; 16 | 17 | struct ggml_backend_buffer_i { 18 | void (*free_buffer) (ggml_backend_buffer_t buffer); 19 | void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer 20 | size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback 21 | void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback 22 | void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback 23 | }; 24 | 25 | struct ggml_backend_buffer { 26 | struct ggml_backend_buffer_i iface; 27 | 28 | ggml_backend_t backend; 29 | ggml_backend_buffer_context_t context; 30 | 31 | size_t 
size; 32 | }; 33 | 34 | GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( 35 | struct ggml_backend * backend, 36 | struct ggml_backend_buffer_i iface, 37 | ggml_backend_buffer_context_t context, 38 | size_t size); 39 | 40 | // 41 | // Backend 42 | // 43 | 44 | typedef void * ggml_backend_context_t; 45 | 46 | struct ggml_backend_i { 47 | const char * (*get_name)(ggml_backend_t backend); 48 | 49 | void (*free)(ggml_backend_t backend); 50 | 51 | // buffer allocation 52 | ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); 53 | 54 | // get buffer alignment 55 | size_t (*get_alignment)(ggml_backend_t backend); 56 | 57 | // tensor data access 58 | // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize 59 | void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); 60 | void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); 61 | void (*synchronize) (ggml_backend_t backend); 62 | 63 | // (optional) copy tensor between different backends, allow for single-copy tranfers 64 | void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); 65 | void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); 66 | 67 | // compute graph with a plan 68 | ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); 69 | void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); 70 | void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); 71 | 72 | // compute graph without a plan 73 | void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); 74 | 75 | // check if the backend supports an operation 76 | bool (*supports_op)(ggml_backend_t backend, const struct 
ggml_tensor * op); 77 | }; 78 | 79 | struct ggml_backend { 80 | struct ggml_backend_i iface; 81 | 82 | ggml_backend_context_t context; 83 | }; 84 | 85 | #ifdef __cplusplus 86 | } 87 | #endif 88 | -------------------------------------------------------------------------------- /llamaCpp/ggml-alloc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | 5 | #ifdef __cplusplus 6 | extern "C" { 7 | #endif 8 | 9 | struct ggml_backend; 10 | struct ggml_backend_buffer; 11 | 12 | // 13 | // Legacy API 14 | // 15 | 16 | typedef struct ggml_allocr * ggml_allocr_t; 17 | 18 | // initialize allocator for use with CPU backend only 19 | GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment); 20 | GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment); 21 | 22 | // initialize allocator for use with ggml-backend 23 | GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer); 24 | GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer 25 | GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend); 26 | 27 | GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc); 28 | 29 | // tell the allocator to parse nodes following the order described in the list 30 | // you should call this if your graph are optimized to execute out-of-order 31 | GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n); 32 | 33 | GGML_API void ggml_allocr_free (ggml_allocr_t alloc); 34 | GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc); 35 | GGML_API void ggml_allocr_reset (ggml_allocr_t alloc); 36 | GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor); 37 | GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc); 38 | 39 | GGML_API size_t 
ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph); 40 | 41 | // 42 | // ggml-backend v2 API 43 | // 44 | 45 | // Seperate tensor and graph allocator objects 46 | // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators 47 | // The original API is kept as a wrapper around the new API 48 | 49 | // Tensor allocator 50 | typedef struct ggml_tallocr * ggml_tallocr_t; 51 | 52 | GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment); 53 | GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment); 54 | GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer); 55 | GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer 56 | GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend); 57 | 58 | GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc); 59 | 60 | GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc); 61 | GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc); 62 | GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc); 63 | GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor); 64 | GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc); 65 | 66 | 67 | // Graph allocator 68 | typedef struct ggml_gallocr * ggml_gallocr_t; 69 | 70 | GGML_API ggml_gallocr_t ggml_gallocr_new(void); 71 | GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); 72 | 73 | GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n); 74 | GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph); 75 | 76 | // Allocate tensors from the allocators given by the hash table 77 | GGML_API void ggml_gallocr_alloc_graph_n( 78 | ggml_gallocr_t galloc, 79 | struct ggml_cgraph * 
graph, 80 | struct ggml_hash_set hash_set, 81 | ggml_tallocr_t * hash_node_talloc); 82 | 83 | #ifdef __cplusplus 84 | } 85 | #endif 86 | -------------------------------------------------------------------------------- /llamaCpp/common/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "llama.h" 4 | 5 | #include "grammar-parser.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | // sampling parameters 12 | typedef struct llama_sampling_params { 13 | int32_t n_prev = 64; // number of previous tokens to remember 14 | int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 15 | int32_t top_k = 40; // <= 0 to use vocab size 16 | float top_p = 0.95f; // 1.0 = disabled 17 | float min_p = 0.05f; // 0.0 = disabled 18 | float tfs_z = 1.00f; // 1.0 = disabled 19 | float typical_p = 1.00f; // 1.0 = disabled 20 | float temp = 0.80f; // 1.0 = disabled 21 | int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) 22 | float penalty_repeat = 1.10f; // 1.0 = disabled 23 | float penalty_freq = 0.00f; // 0.0 = disabled 24 | float penalty_present = 0.00f; // 0.0 = disabled 25 | int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 26 | float mirostat_tau = 5.00f; // target entropy 27 | float mirostat_eta = 0.10f; // learning rate 28 | bool penalize_nl = true; // consider newlines as a repeatable token 29 | 30 | std::string grammar; // optional BNF-like grammar to constrain sampling 31 | 32 | // Classifier-Free Guidance 33 | // https://arxiv.org/abs/2306.17806 34 | std::string cfg_negative_prompt; // string to help guidance 35 | float cfg_scale = 1.f; // how strong is guidance 36 | 37 | std::unordered_map logit_bias; // logit bias for specific tokens 38 | } llama_sampling_params; 39 | 40 | // general sampler context 41 | // TODO: move to llama.h 42 | struct llama_sampling_context { 43 | // parameters that will be 
used for sampling 44 | llama_sampling_params params; 45 | 46 | // mirostat sampler state 47 | float mirostat_mu; 48 | 49 | llama_grammar * grammar; 50 | 51 | // internal 52 | grammar_parser::parse_state parsed_grammar; 53 | 54 | // TODO: replace with ring-buffer 55 | std::vector prev; 56 | std::vector cur; 57 | }; 58 | 59 | #include "common.h" 60 | 61 | // Create a new sampling context instance. 62 | struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); 63 | 64 | void llama_sampling_free(struct llama_sampling_context * ctx); 65 | 66 | // Reset the sampler context 67 | // - clear prev tokens 68 | // - reset grammar 69 | void llama_sampling_reset(llama_sampling_context * ctx); 70 | 71 | // Copy the sampler context 72 | void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); 73 | 74 | // Get the last sampled token 75 | llama_token llama_sampling_last(llama_sampling_context * ctx); 76 | 77 | // Get a string representation of the last sampled tokens 78 | std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); 79 | 80 | // Print sampling parameters into a string 81 | std::string llama_sampling_print(const llama_sampling_params & params); 82 | 83 | // this is a common sampling function used across the examples for convenience 84 | // it can serve as a starting point for implementing your own sampling function 85 | // Note: When using multiple sequences, it is the caller's responsibility to call 86 | // llama_sampling_reset when a sequence ends 87 | // 88 | // required: 89 | // - ctx_main: context to use for sampling 90 | // - ctx_sampling: sampling-specific context 91 | // 92 | // optional: 93 | // - ctx_cfg: context to use for classifier-free guidance 94 | // - idx: sample from llama_get_logits_ith(ctx, idx) 95 | // 96 | // returns: 97 | // - token: sampled token 98 | // - candidates: vector of candidate tokens 99 | // 100 | llama_token 
llama_sampling_sample( 101 | struct llama_sampling_context * ctx_sampling, 102 | struct llama_context * ctx_main, 103 | struct llama_context * ctx_cfg, 104 | int idx = 0); 105 | 106 | void llama_sampling_accept( 107 | struct llama_sampling_context * ctx_sampling, 108 | struct llama_context * ctx_main, 109 | llama_token id, 110 | bool apply_grammar); 111 | -------------------------------------------------------------------------------- /main.qml: -------------------------------------------------------------------------------- 1 | import QtQuick 2 | import QtQuick.Layouts 3 | import QtQuick.Controls 4 | import AAAAA 5 | 6 | Rectangle { 7 | width: 500 8 | height: 800 9 | color: "#000000FF" 10 | 11 | property string inConversationWith : "huahua" 12 | 13 | Rectangle { 14 | id: rect0 15 | anchors.fill: parent 16 | ColumnLayout { 17 | anchors.fill: parent 18 | 19 | ListView { 20 | id: listView 21 | Layout.fillWidth: true 22 | Layout.fillHeight: true 23 | Layout.margins: pane.leftPadding + messageField.leftPadding 24 | displayMarginBeginning: 40 25 | displayMarginEnd: 40 26 | verticalLayoutDirection: ListView.BottomToTop 27 | spacing: 12 28 | model: SqlConversationModel { 29 | recipient: inConversationWith 30 | } 31 | delegate: Column { 32 | anchors.right: sentByMe ? listView.contentItem.right : undefined 33 | spacing: 6 34 | 35 | readonly property bool sentByMe: model.recipient !== "Me" 36 | 37 | Row { 38 | id: messageRow 39 | spacing: 6 40 | anchors.right: sentByMe ? parent.right : undefined 41 | 42 | Image { 43 | id: avatar 44 | width: 50 45 | height: 50 46 | source: !sentByMe ? "images/" + model.author.replace(" ", "_") + ".jpg" : "" 47 | } 48 | 49 | Rectangle { 50 | width: Math.min(messageText.implicitWidth + 24, listView.width - avatar.width - messageRow.spacing) 51 | height: messageText.implicitHeight + 24 52 | radius: 10 53 | color: sentByMe ? 
"lightgrey" : "steelblue" 54 | 55 | TextEdit { 56 | id: messageText 57 | text: model.message 58 | font.pixelSize: 14 59 | color: sentByMe ? "black" : "white" 60 | anchors.fill: parent 61 | anchors.margins: 12 62 | wrapMode: Text.WordWrap 63 | selectByMouse: true 64 | readOnly: true 65 | } 66 | } 67 | } 68 | 69 | Label { 70 | id: timestampText 71 | text: Qt.formatDateTime(model.timestamp, "d MMM hh:mm") 72 | color: "lightgrey" 73 | anchors.right: sentByMe ? parent.right : undefined 74 | } 75 | } 76 | 77 | ScrollBar.vertical: ScrollBar {} 78 | } 79 | 80 | Pane { 81 | id: pane 82 | Layout.fillWidth: true 83 | 84 | RowLayout { 85 | width: parent.width 86 | 87 | TextArea { 88 | id: messageField 89 | Layout.fillWidth: true 90 | placeholderText: qsTr("Compose message") 91 | wrapMode: TextArea.Wrap 92 | Keys.onPressed: (event)=> { 93 | if (event.key === Qt.Key_Return) { 94 | listView.model.sendMessage(inConversationWith, messageField.text); 95 | messageField.text = ""; 96 | } 97 | } 98 | onTextChanged: { 99 | if (length > 30) remove(30, length); 100 | } 101 | } 102 | 103 | Button { 104 | id: sendButton 105 | text: qsTr("发送") 106 | enabled: messageField.length > 0 107 | onClicked: { 108 | listView.model.sendMessage(inConversationWith, messageField.text); 109 | messageField.text = ""; 110 | } 111 | } 112 | Button { 113 | id: clearButton 114 | text: qsTr("清空历史") 115 | onClicked: { 116 | listView.model.removeAllMessage(); 117 | } 118 | } 119 | } 120 | } 121 | } 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /sqlconversationmodel.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 The Qt Company Ltd. 
2 | // SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause 3 | 4 | #include "sqlconversationmodel.h" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | static const char *conversationsTableName = "Conversations"; 13 | 14 | static void createTable() 15 | { 16 | if (QSqlDatabase::database().tables().contains(conversationsTableName)) { 17 | // The table already exists; we don't need to do anything. 18 | return; 19 | } 20 | 21 | QSqlQuery query; 22 | if (!query.exec( 23 | "CREATE TABLE IF NOT EXISTS 'Conversations' (" 24 | // "'id' INTEGER PRIMARY KEY," 25 | "'author' TEXT NOT NULL," 26 | "'recipient' TEXT NOT NULL," 27 | "'timestamp' TEXT NOT NULL," 28 | "'message' TEXT NOT NULL," 29 | "FOREIGN KEY('author') REFERENCES Contacts ( name )," 30 | "FOREIGN KEY('recipient') REFERENCES Contacts ( name )" 31 | ")")) { 32 | qFatal("Failed to query database: %s", qPrintable(query.lastError().text())); 33 | } 34 | 35 | // query.exec("INSERT INTO Conversations VALUES(1, 'Me', 'huahua', '2016-01-01T11:24:53', 'Hi!')"); 36 | query.exec("INSERT INTO Conversations VALUES('huahua', 'Me', '2016-01-07T14:36:16', '我是Qwen。')"); 37 | } 38 | 39 | SqlConversationModel::SqlConversationModel(QObject *parent) : 40 | QSqlTableModel(parent) 41 | { 42 | m_human_assets = HumanAssets::GetInstance(); 43 | createTable(); 44 | setTable(conversationsTableName); 45 | setSort(2, Qt::DescendingOrder); 46 | // Ensures that the model is sorted correctly after submitting a new row. 
47 | setEditStrategy(QSqlTableModel::OnManualSubmit); 48 | 49 | // m_timer = new QTimer(this); 50 | // connect(m_timer, &QTimer::timeout, this, &SqlConversationModel::updateMsg); 51 | 52 | connect(this->m_human_assets,SIGNAL(ChatUpdatedSignal()),SLOT(updateMsg())); 53 | 54 | connect(this->m_human_assets->m_chat_model, SIGNAL(SignalNewAnswer(QString, bool)), this, SLOT(SlotNewAnswer(QString, bool))); 55 | 56 | } 57 | 58 | void SqlConversationModel::SlotNewAnswer(QString str, bool end) 59 | { 60 | qDebug() << "SlotNewAnswer end" << end; 61 | m_chatEnded = end; 62 | if (m_chatEnded) 63 | { 64 | return; 65 | } 66 | 67 | m_timestamp = QDateTime::currentDateTime().toString(Qt::ISODateWithMs); 68 | qDebug() << "SlotNewAnswer" << str << rowCount() << m_timestamp; 69 | QSqlRecord newRecord = record(); 70 | // newRecord.setValue("id", m_chat_round + 1); 71 | newRecord.setValue("author", "huahua"); 72 | newRecord.setValue("recipient", "Me"); 73 | newRecord.setValue("timestamp", m_timestamp); 74 | if (sub_text_index == 0) 75 | { 76 | m_message_receive = str; 77 | newRecord.setValue("message", m_message_receive); 78 | if (!insertRecord(rowCount(), newRecord)) { 79 | qWarning() << "Failed to send message:" << lastError().text(); 80 | return; 81 | } 82 | } 83 | else 84 | { 85 | m_message_receive += str; 86 | newRecord.setValue("message", m_message_receive); 87 | if (!setRecord(0, newRecord)) { 88 | qWarning() << "Failed to send message:" << lastError().text(); 89 | return; 90 | } 91 | } 92 | 93 | qDebug() << "sssss" << rowCount() << str ; 94 | submitAll(); 95 | sub_text_index ++; 96 | } 97 | 98 | QString SqlConversationModel::recipient() const 99 | { 100 | return m_recipient; 101 | } 102 | 103 | void SqlConversationModel::setRecipient(const QString &recipient) 104 | { 105 | if (recipient == m_recipient) 106 | return; 107 | 108 | m_recipient = recipient; 109 | 110 | const QString filterString = QString::fromLatin1( 111 | "(recipient = '%1' AND author = 'Me') OR (recipient = 'Me' 
AND author='%1')").arg(m_recipient); 112 | setFilter(filterString); 113 | select(); 114 | 115 | emit recipientChanged(); 116 | } 117 | 118 | QVariant SqlConversationModel::data(const QModelIndex &index, int role) const 119 | { 120 | if (role < Qt::UserRole) 121 | return QSqlTableModel::data(index, role); 122 | 123 | const QSqlRecord sqlRecord = record(index.row()); 124 | return sqlRecord.value(role - Qt::UserRole); 125 | } 126 | 127 | QHash SqlConversationModel::roleNames() const 128 | { 129 | QHash names; 130 | names[Qt::UserRole + 0] = "author"; 131 | names[Qt::UserRole + 1] = "recipient"; 132 | names[Qt::UserRole + 2] = "timestamp"; 133 | names[Qt::UserRole + 3] = "message"; 134 | return names; 135 | } 136 | 137 | void SqlConversationModel::sendMessage(const QString &recipient, const QString &message) 138 | { 139 | m_chatEnded = false; 140 | m_chat_round = rowCount() + 1; 141 | qDebug() << "sendMessage" << recipient << message << rowCount(); 142 | m_timestamp = QDateTime::currentDateTime().toString(Qt::ISODateWithMs); 143 | m_recipient = recipient; 144 | m_message = message; 145 | 146 | QSqlRecord newRecord = record(); 147 | // newRecord.setValue("id", m_chat_round); 148 | newRecord.setValue("author", "Me"); 149 | newRecord.setValue("recipient", recipient); 150 | newRecord.setValue("timestamp", m_timestamp); 151 | newRecord.setValue("message", m_message); 152 | if (!insertRecord(rowCount(), newRecord)) { 153 | qWarning() << "Failed to send message:" << lastError().text(); 154 | return; 155 | } 156 | qDebug() << "TTTTTTTT" << recipient << message; 157 | submitAll(); 158 | m_human_assets->SlotNewChat(m_message); 159 | sub_text_index = 0; 160 | } 161 | 162 | void SqlConversationModel::updateMsg() 163 | { 164 | } 165 | void SqlConversationModel::removeAllMessage() 166 | { 167 | qDebug() << "removeAllMessage" << m_chatEnded; 168 | if (m_chatEnded) 169 | { 170 | removeRows(0, rowCount() - 1); 171 | } 172 | submitAll(); 173 | } 174 | 
-------------------------------------------------------------------------------- /llamaCpp/ggml-backend.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-alloc.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | // 11 | // Backend buffer 12 | // 13 | 14 | struct ggml_backend_buffer; 15 | typedef struct ggml_backend_buffer * ggml_backend_buffer_t; 16 | 17 | // backend buffer functions 18 | GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); 19 | GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); 20 | GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); 21 | GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); 22 | GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); 23 | GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); 24 | GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); 25 | 26 | // 27 | // Backend 28 | // 29 | 30 | struct ggml_backend; 31 | typedef struct ggml_backend * ggml_backend_t; 32 | typedef void * ggml_backend_graph_plan_t; 33 | 34 | GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor); 35 | 36 | GGML_API const char * ggml_backend_name(ggml_backend_t backend); 37 | GGML_API void ggml_backend_free(ggml_backend_t backend); 38 | 39 | GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); 40 | 41 | GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); 42 | 43 | GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); 44 | GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); 45 
| 46 | GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); 47 | GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); 48 | 49 | GGML_API void ggml_backend_synchronize(ggml_backend_t backend); 50 | 51 | GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); 52 | 53 | GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); 54 | GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); 55 | GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); 56 | GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); 57 | 58 | // tensor copy between different backends 59 | GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); 60 | 61 | // 62 | // CPU backend 63 | // 64 | 65 | GGML_API ggml_backend_t ggml_backend_cpu_init(void); 66 | 67 | GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); 68 | GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); 69 | 70 | // Create a backend buffer from an existing pointer 71 | GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); 72 | 73 | 74 | // 75 | // Backend scheduler 76 | // 77 | 78 | // The backend scheduler allows for multiple backends to be used together 79 | // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends 80 | // The backends are selected based on: 81 | // - the backend that supports the operation 82 | // - the location of the pre-allocated tensors (e.g. 
the weights) 83 | /* 84 | Example usage: 85 | 86 | sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); 87 | // sched is initialized with measure allocators and cannot be used until allocated with a measure graph 88 | 89 | // initialize buffers from a measure graph 90 | measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed 91 | 92 | // in build_graph: 93 | build_graph(...) { 94 | // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) 95 | alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu); 96 | ggml_allocr_alloc(alloc_cpu, tensor); 97 | 98 | // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) 99 | struct ggml_tensor * node = ggml_mul_mat(ctx, ...); 100 | ggml_backend_sched_set_node_backend(sched, node, backend_gpu); 101 | } 102 | 103 | // allocate backend buffers from measure graph 104 | ggml_backend_sched_init_measure(sched, measure_graph); 105 | 106 | // the scheduler is now ready to compute graphs 107 | 108 | // compute 109 | graph = build_graph(sched); 110 | ggml_backend_sched_graph_compute(sched, graph); 111 | */ 112 | 113 | struct ggml_backend_sched; 114 | typedef struct ggml_backend_sched * ggml_backend_sched_t; 115 | 116 | // Initialize a backend scheduler 117 | GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends); 118 | 119 | GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); 120 | 121 | // Initialize backend buffers from a measure graph 122 | GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); 123 | 124 | GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); 125 | GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend); 126 | 127 | GGML_API void 
ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); 128 | 129 | // Allocate a graph on the backend scheduler 130 | GGML_API void ggml_backend_sched_graph_compute( 131 | ggml_backend_sched_t sched, 132 | struct ggml_cgraph * graph); 133 | 134 | #ifdef __cplusplus 135 | } 136 | #endif 137 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Tongyi Qianwen RESEARCH LICENSE AGREEMENT 2 | 3 | Tongyi Qianwen Release Date: November 30, 2023 4 | 5 | By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. 6 | 7 | 1. Definitions 8 | a. This Tongyi Qianwen RESEARCH LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement. 9 | b. "We"(or "Us") shall mean Alibaba Cloud. 10 | c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use. 11 | d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You. 12 | e. "Tongyi Qianwen" shall mean the large language models, and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us. 13 | f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement. 14 | g. 
"Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files. 15 | h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, 16 | and conversions to other media types. 17 | i. "Non-Commercial" shall mean for research or evaluation purposes only. 18 | 19 | 2. Grant of Rights 20 | a. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials FOR NON-COMMERCIAL PURPOSES ONLY. 21 | b. If you are commercially using the Materials, You shall request a license from Us. 22 | 23 | 3. Redistribution 24 | You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 25 | a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement; 26 | b. You shall cause any modified files to carry prominent notices stating that You changed the files; 27 | c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and 28 | d. 
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement. 29 | 30 | 4. Rules of use 31 | a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials. 32 | b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof). 33 | 34 | 5. Intellectual Property 35 | a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications. 36 | b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials. 37 | c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought. 38 | 39 | 6. Disclaimer of Warranty and Limitation of Liability 40 | a. 
We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto. 41 | b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM. 42 | c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED. 43 | d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials. 44 | 45 | 7. Survival and Termination. 46 | a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. 47 | b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 6 and 8 shall survive the termination of this Agreement. 48 | 49 | 8. Governing Law and Jurisdiction. 50 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. 51 | b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement. 52 | 53 | 9. Other Terms and Conditions. 54 | a. 
Any arrangements, understandings, or agreements regarding the Material not stated herein are separate from and independent of the terms and conditions of this Agreement. You shall request a seperate license from Us, if You use the Materials in ways not expressly agreed to in this Agreement. 55 | b. We shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 56 | -------------------------------------------------------------------------------- /llamaCpp/ggml-impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | 5 | // GGML internal header 6 | 7 | #include 8 | #include 9 | #include 10 | #include // memcpy 11 | #include // fabsf 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | // static_assert should be a #define, but if it's not, 18 | // fall back to the _Static_assert C11 keyword. 19 | // if C99 - static_assert is noop 20 | // ref: https://stackoverflow.com/a/53923785/4039976 21 | #ifndef static_assert 22 | #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) 23 | #define static_assert(cond, msg) _Static_assert(cond, msg) 24 | #else 25 | #define static_assert(cond, msg) struct global_scope_noop_trick 26 | #endif 27 | #endif 28 | 29 | // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 30 | #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) 31 | #ifndef __FMA__ 32 | #define __FMA__ 33 | #endif 34 | #ifndef __F16C__ 35 | #define __F16C__ 36 | #endif 37 | #ifndef __SSE3__ 38 | #define __SSE3__ 39 | #endif 40 | #endif 41 | 42 | // 16-bit float 43 | // on Arm, we use __fp16 44 | // on x86, we use uint16_t 45 | #if defined(__ARM_NEON) && !defined(_MSC_VER) 46 | 47 | // if YCM cannot find , make a symbolic link to it, for example: 48 | // 49 | // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ 50 | // 51 | 
#include 52 | 53 | #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) 54 | #define GGML_COMPUTE_FP32_TO_FP16(x) (x) 55 | 56 | #define GGML_FP16_TO_FP32(x) ((float) (x)) 57 | #define GGML_FP32_TO_FP16(x) (x) 58 | 59 | #else 60 | 61 | #ifdef __wasm_simd128__ 62 | #include 63 | #else 64 | #ifdef __POWER9_VECTOR__ 65 | #include 66 | #undef bool 67 | #define bool _Bool 68 | #else 69 | #if defined(_MSC_VER) || defined(__MINGW32__) 70 | #include 71 | #else 72 | #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) 73 | #if !defined(__riscv) 74 | #include 75 | #endif 76 | #endif 77 | #endif 78 | #endif 79 | #endif 80 | 81 | #ifdef __riscv_v_intrinsic 82 | #include 83 | #endif 84 | 85 | #ifdef __F16C__ 86 | 87 | #ifdef _MSC_VER 88 | #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) 89 | #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) 90 | #else 91 | #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) 92 | #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) 93 | #endif 94 | 95 | #elif defined(__POWER9_VECTOR__) 96 | 97 | #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) 98 | #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) 99 | /* the inline asm below is about 12% faster than the lookup method */ 100 | #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) 101 | #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) 102 | 103 | static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { 104 | register float f; 105 | register double d; 106 | __asm__( 107 | "mtfprd %0,%2\n" 108 | "xscvhpdp %0,%0\n" 109 | "frsp %1,%0\n" : 110 | /* temp */ "=d"(d), 111 | /* out */ "=f"(f): 112 | /* in */ "r"(h)); 113 | return f; 114 | } 115 | 116 | static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { 117 | register double d; 118 | register ggml_fp16_t r; 119 | __asm__( /* xscvdphp can work on double or single 
precision */ 120 | "xscvdphp %0,%2\n" 121 | "mffprd %1,%0\n" : 122 | /* temp */ "=d"(d), 123 | /* out */ "=r"(r): 124 | /* in */ "f"(f)); 125 | return r; 126 | } 127 | 128 | #else 129 | 130 | // FP16 <-> FP32 131 | // ref: https://github.com/Maratyszcza/FP16 132 | 133 | static inline float fp32_from_bits(uint32_t w) { 134 | union { 135 | uint32_t as_bits; 136 | float as_value; 137 | } fp32; 138 | fp32.as_bits = w; 139 | return fp32.as_value; 140 | } 141 | 142 | static inline uint32_t fp32_to_bits(float f) { 143 | union { 144 | float as_value; 145 | uint32_t as_bits; 146 | } fp32; 147 | fp32.as_value = f; 148 | return fp32.as_bits; 149 | } 150 | 151 | static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { 152 | const uint32_t w = (uint32_t) h << 16; 153 | const uint32_t sign = w & UINT32_C(0x80000000); 154 | const uint32_t two_w = w + w; 155 | 156 | const uint32_t exp_offset = UINT32_C(0xE0) << 23; 157 | #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) 158 | const float exp_scale = 0x1.0p-112f; 159 | #else 160 | const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); 161 | #endif 162 | const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; 163 | 164 | const uint32_t magic_mask = UINT32_C(126) << 23; 165 | const float magic_bias = 0.5f; 166 | const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; 167 | 168 | const uint32_t denormalized_cutoff = UINT32_C(1) << 27; 169 | const uint32_t result = sign | 170 | (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); 171 | return fp32_from_bits(result); 172 | } 173 | 174 | static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { 175 | #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) 176 | const float scale_to_inf = 0x1.0p+112f; 177 | const float scale_to_zero = 0x1.0p-110f; 178 | #else 179 | const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); 180 | const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); 181 | #endif 182 | float base = (fabsf(f) * scale_to_inf) * scale_to_zero; 183 | 184 | const uint32_t w = fp32_to_bits(f); 185 | const uint32_t shl1_w = w + w; 186 | const uint32_t sign = w & UINT32_C(0x80000000); 187 | uint32_t bias = shl1_w & UINT32_C(0xFF000000); 188 | if (bias < UINT32_C(0x71000000)) { 189 | bias = UINT32_C(0x71000000); 190 | } 191 | 192 | base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; 193 | const uint32_t bits = fp32_to_bits(base); 194 | const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); 195 | const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); 196 | const uint32_t nonsign = exp_bits + mantissa_bits; 197 | return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); 198 | } 199 | 200 | #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) 201 | #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) 202 | 203 | #endif // __F16C__ 204 | 205 | #endif // __ARM_NEON 206 | 207 | // precomputed f32 table for f16 (256 KB) 208 | // defined in ggml.c, initialized in ggml_init() 209 | extern float ggml_table_f32_f16[1 << 16]; 210 | 211 | // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, 212 | // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. 213 | // This is also true for POWER9. 
214 | #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) 215 | 216 | inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { 217 | uint16_t s; 218 | memcpy(&s, &f, sizeof(uint16_t)); 219 | return ggml_table_f32_f16[s]; 220 | } 221 | 222 | #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) 223 | #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) 224 | 225 | #endif 226 | 227 | #define GGML_HASHTABLE_FULL ((size_t)-1) 228 | #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) 229 | 230 | bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key); 231 | 232 | // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted 233 | size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); 234 | 235 | // returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full 236 | size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); 237 | 238 | // return index, asserts if table is full 239 | size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key); 240 | 241 | #ifdef __cplusplus 242 | } 243 | #endif 244 | -------------------------------------------------------------------------------- /llamaCpp/common/train.h: -------------------------------------------------------------------------------- 1 | // Various helper functions and utilities for training 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "ggml.h" 10 | #include "llama.h" 11 | 12 | #define LLAMA_TRAIN_MAX_NODES 16384 13 | 14 | typedef std::string mt19937_state; 15 | 16 | struct train_state { 17 | struct ggml_opt_context * opt; 18 | 19 | uint64_t train_its; 20 | uint64_t train_samples; 21 | uint64_t train_tokens; 22 | uint64_t train_epochs; 23 | 24 | size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) 25 | mt19937_state 
shuffle_rng_state_current; 26 | mt19937_state shuffle_rng_state_next; 27 | size_t shuffle_sample_count; 28 | size_t shuffle_next_sample; 29 | }; 30 | 31 | struct train_params_common { 32 | const char * fn_train_data; 33 | const char * fn_checkpoint_in; 34 | const char * fn_checkpoint_out; 35 | const char * pattern_fn_it; 36 | const char * fn_latest; 37 | 38 | bool print_usage; 39 | 40 | int save_every; 41 | 42 | uint32_t seed; 43 | 44 | int n_ctx; 45 | int n_threads; 46 | int n_batch; 47 | int n_gradient_accumulation; 48 | int n_epochs; 49 | int n_gpu_layers; 50 | 51 | bool custom_n_ctx; 52 | 53 | bool use_flash; 54 | bool use_checkpointing; 55 | 56 | std::string sample_start; 57 | bool include_sample_start; 58 | bool escape; 59 | bool overlapping_samples; 60 | bool fill_with_next_samples; 61 | bool separate_with_eos; 62 | bool separate_with_bos; 63 | bool sample_random_offsets; 64 | 65 | bool force_reshuffle; 66 | 67 | int warmup; 68 | int cos_decay_steps; 69 | float cos_decay_restart; 70 | float cos_decay_min; 71 | bool enable_restart; 72 | 73 | int opt_past; 74 | float opt_delta; 75 | int opt_max_no_improvement; 76 | 77 | int adam_n_iter; 78 | float adam_alpha; 79 | float adam_min_alpha; 80 | float adam_decay; 81 | int adam_decay_min_ndim; 82 | float adam_beta1; 83 | float adam_beta2; 84 | float adam_gclip; 85 | float adam_eps_f; 86 | }; 87 | 88 | typedef void (*save_train_files_callback)(void * data, struct train_state * train); 89 | 90 | struct train_opt_callback_data { 91 | struct train_params_common * params; 92 | struct train_state * train; 93 | save_train_files_callback save_cb; 94 | void * save_data; 95 | struct llama_context * lctx; 96 | int last_save_iter; 97 | llama_token * tokens_data; 98 | size_t tokens_size; 99 | size_t * samples_begin; 100 | size_t * samples_size; 101 | size_t * shuffled_samples_offs; 102 | size_t * shuffled_samples_begin; 103 | size_t * shuffled_samples_size; 104 | size_t samples_count; 105 | struct ggml_tensor * tokens_input; 106 
| struct ggml_tensor * target_probs; 107 | int first_iter; 108 | int first_epoch; 109 | int iter_at_last_epoch; 110 | int64_t last_time; 111 | double millis_per_iter; 112 | }; 113 | 114 | struct train_state * init_train_state(); 115 | void free_train_state(struct train_state * state); 116 | 117 | struct train_params_common get_default_train_params_common(); 118 | void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params); 119 | 120 | bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param); 121 | void finish_processing_train_args(struct train_params_common * params); 122 | 123 | struct random_normal_distribution; 124 | struct random_uniform_distribution; 125 | 126 | struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max); 127 | struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max); 128 | 129 | void free_random_normal_distribution (struct random_normal_distribution * rnd); 130 | void free_random_uniform_distribution(struct random_uniform_distribution * rnd); 131 | 132 | struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd); 133 | struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd); 134 | 135 | // generate random float in interval [0,1) 136 | float frand(); 137 | float frand_normal (struct random_normal_distribution * rnd); 138 | float frand_uniform(struct random_uniform_distribution * rnd); 139 | 140 | int clamp (const int v, const int min, const int max); 141 | float fclamp(const float v, const float min, const float max); 142 | 143 | void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0); 144 | void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1); 145 | void assert_shape_3d(struct 
ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2); 146 | void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3); 147 | 148 | size_t tokenize_file( 149 | struct llama_context * lctx, 150 | const char * filename, 151 | const std::string & sample_start, 152 | bool include_sample_start, 153 | bool overlapping_samples, 154 | unsigned context_length, 155 | std::vector & out_tokens, 156 | std::vector & out_samples_begin, 157 | std::vector & out_samples_size); 158 | 159 | int64_t get_example_targets_batch( 160 | struct llama_context * lctx, 161 | struct ggml_tensor * tokens_input, 162 | struct ggml_tensor * target_probs, 163 | int64_t example_id, 164 | const size_t * samples_offs, 165 | const size_t * samples_begin, 166 | const size_t * samples_size, 167 | size_t samples_count, 168 | const llama_token * train_data, 169 | size_t n_train_data, 170 | bool separate_with_eos, 171 | bool separate_with_bos, 172 | bool fill_with_next_samples, 173 | bool sample_random_offsets); 174 | 175 | 176 | void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); 177 | mt19937_state mt19937_get_state(const std::mt19937& rng); 178 | mt19937_state mt19937_seed_to_state(unsigned seed); 179 | 180 | mt19937_state shuffle_samples( 181 | const mt19937_state & rng_state, 182 | size_t * shuffled_offs, 183 | size_t * shuffled_begins, 184 | size_t * shuffled_sizes, 185 | const size_t * begins, 186 | const size_t * sizes, 187 | size_t count); 188 | 189 | size_t hash_combine(size_t h1, size_t h2); 190 | 191 | size_t compute_samples_hash( 192 | const char* fn, 193 | const size_t* samples_begin, 194 | const size_t* samples_size, 195 | size_t sample_count); 196 | 197 | 198 | std::string replace_str(const char * s, const char * needle, const char * replacement); 199 | 200 | void print_duration(double milliseconds); 201 | 202 | float cosine_decay( 203 | int64_t step, 204 | int64_t decay_steps, 205 | float minimum); 206 | 207 | float 
cosine_decay_restart( 208 | int64_t step, 209 | int64_t decay_steps, 210 | float minimum, 211 | float restart_step_mult); 212 | 213 | float learning_schedule( 214 | int64_t step, 215 | int64_t warmup_steps, 216 | int64_t decay_steps, 217 | float learning_rate, 218 | float overall_minimum, 219 | float cos_decay_minimum, 220 | float cos_decay_restart_step_mult, 221 | bool enable_restart); 222 | 223 | void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name); 224 | 225 | void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt); 226 | void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt); 227 | 228 | bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train); 229 | void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train); 230 | 231 | std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration); 232 | 233 | void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel); 234 | -------------------------------------------------------------------------------- /llamaCpp/common/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | 3 | struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { 4 | struct llama_sampling_context * result = new llama_sampling_context(); 5 | 6 | result->params = params; 7 | result->grammar = nullptr; 8 | 9 | // if there is a grammar, parse it 10 | if (!params.grammar.empty()) { 11 | result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); 12 | 13 | // will be empty (default) if there are parse errors 14 | if (result->parsed_grammar.rules.empty()) { 15 | fprintf(stderr, "%s: failed to parse grammar\n", __func__); 16 | return nullptr; 
17 | } 18 | 19 | std::vector grammar_rules(result->parsed_grammar.c_rules()); 20 | 21 | result->grammar = llama_grammar_init( 22 | grammar_rules.data(), 23 | grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); 24 | } 25 | 26 | result->prev.resize(params.n_prev); 27 | 28 | return result; 29 | } 30 | 31 | void llama_sampling_free(struct llama_sampling_context * ctx) { 32 | if (ctx->grammar != NULL) { 33 | llama_grammar_free(ctx->grammar); 34 | } 35 | 36 | delete ctx; 37 | } 38 | 39 | void llama_sampling_reset(llama_sampling_context * ctx) { 40 | if (ctx->grammar != NULL) { 41 | llama_grammar_free(ctx->grammar); 42 | ctx->grammar = NULL; 43 | } 44 | 45 | if (!ctx->parsed_grammar.rules.empty()) { 46 | std::vector grammar_rules(ctx->parsed_grammar.c_rules()); 47 | 48 | ctx->grammar = llama_grammar_init( 49 | grammar_rules.data(), 50 | grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); 51 | } 52 | 53 | std::fill(ctx->prev.begin(), ctx->prev.end(), 0); 54 | ctx->cur.clear(); 55 | } 56 | 57 | void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { 58 | if (dst->grammar) { 59 | llama_grammar_free(dst->grammar); 60 | dst->grammar = nullptr; 61 | } 62 | 63 | if (src->grammar) { 64 | dst->grammar = llama_grammar_copy(src->grammar); 65 | } 66 | 67 | dst->prev = src->prev; 68 | } 69 | 70 | llama_token llama_sampling_last(llama_sampling_context * ctx) { 71 | return ctx->prev.back(); 72 | } 73 | 74 | std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { 75 | const int size = ctx_sampling->prev.size(); 76 | 77 | n = std::min(n, size); 78 | 79 | std::string result; 80 | 81 | for (int i = size - n; i < size; i++) { 82 | result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); 83 | } 84 | 85 | return result; 86 | } 87 | 88 | std::string llama_sampling_print(const llama_sampling_params & params) { 89 | char result[1024]; 90 | 91 | snprintf(result, sizeof(result), 
92 | "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" 93 | "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" 94 | "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", 95 | params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, 96 | params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, 97 | params.mirostat, params.mirostat_eta, params.mirostat_tau); 98 | 99 | return std::string(result); 100 | } 101 | 102 | llama_token llama_sampling_sample( 103 | struct llama_sampling_context * ctx_sampling, 104 | struct llama_context * ctx_main, 105 | struct llama_context * ctx_cfg, 106 | const int idx) { 107 | const llama_sampling_params & params = ctx_sampling->params; 108 | 109 | const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); 110 | 111 | const float temp = params.temp; 112 | const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; 113 | const float top_p = params.top_p; 114 | const float min_p = params.min_p; 115 | const float tfs_z = params.tfs_z; 116 | const float typical_p = params.typical_p; 117 | const int32_t penalty_last_n = params.penalty_last_n < 0 ? 
params.n_prev : params.penalty_last_n; 118 | const float penalty_repeat = params.penalty_repeat; 119 | const float penalty_freq = params.penalty_freq; 120 | const float penalty_present = params.penalty_present; 121 | const int mirostat = params.mirostat; 122 | const float mirostat_tau = params.mirostat_tau; 123 | const float mirostat_eta = params.mirostat_eta; 124 | const bool penalize_nl = params.penalize_nl; 125 | 126 | auto & prev = ctx_sampling->prev; 127 | auto & cur = ctx_sampling->cur; 128 | 129 | llama_token id = 0; 130 | 131 | float * logits = llama_get_logits_ith(ctx_main, idx); 132 | 133 | // apply params.logit_bias map 134 | for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { 135 | logits[it->first] += it->second; 136 | } 137 | 138 | cur.clear(); 139 | 140 | for (llama_token token_id = 0; token_id < n_vocab; token_id++) { 141 | cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); 142 | } 143 | 144 | llama_token_data_array cur_p = { cur.data(), cur.size(), false }; 145 | 146 | if (ctx_cfg) { 147 | llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale); 148 | } 149 | 150 | // apply penalties 151 | if (!prev.empty()) { 152 | const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; 153 | 154 | llama_sample_repetition_penalties(ctx_main, &cur_p, 155 | prev.data() + prev.size() - penalty_last_n, 156 | penalty_last_n, penalty_repeat, penalty_freq, penalty_present); 157 | 158 | if (!penalize_nl) { 159 | for (size_t idx = 0; idx < cur_p.size; idx++) { 160 | if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { 161 | cur_p.data[idx].logit = nl_logit; 162 | break; 163 | } 164 | } 165 | } 166 | } 167 | 168 | if (ctx_sampling->grammar != NULL) { 169 | llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); 170 | } 171 | 172 | if (temp < 0.0) { 173 | // greedy sampling, with probs 174 | llama_sample_softmax(ctx_main, &cur_p); 175 | id = cur_p.data[0].id; 
176 | } else if (temp == 0.0) { 177 | // greedy sampling, no probs 178 | id = llama_sample_token_greedy(ctx_main, &cur_p); 179 | } else { 180 | if (mirostat == 1) { 181 | const int mirostat_m = 100; 182 | llama_sample_temp(ctx_main, &cur_p, temp); 183 | id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); 184 | } else if (mirostat == 2) { 185 | llama_sample_temp(ctx_main, &cur_p, temp); 186 | id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); 187 | } else { 188 | // temperature sampling 189 | size_t min_keep = std::max(1, params.n_probs); 190 | 191 | llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); 192 | llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); 193 | llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); 194 | llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); 195 | llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); 196 | llama_sample_temp (ctx_main, &cur_p, temp); 197 | 198 | id = llama_sample_token(ctx_main, &cur_p); 199 | 200 | //{ 201 | // const int n_top = 10; 202 | // LOG("top %d candidates:\n", n_top); 203 | 204 | // for (int i = 0; i < n_top; i++) { 205 | // const llama_token id = cur_p.data[i].id; 206 | // (void)id; // To avoid a warning that id is unused when logging is disabled. 
207 | // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); 208 | // } 209 | //} 210 | 211 | LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); 212 | } 213 | } 214 | 215 | return id; 216 | } 217 | 218 | void llama_sampling_accept( 219 | struct llama_sampling_context * ctx_sampling, 220 | struct llama_context * ctx_main, 221 | llama_token id, 222 | bool apply_grammar) { 223 | ctx_sampling->prev.erase(ctx_sampling->prev.begin()); 224 | ctx_sampling->prev.push_back(id); 225 | 226 | if (ctx_sampling->grammar != NULL && apply_grammar) { 227 | llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /llamaCpp/ggml-quants.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml-impl.h" 4 | 5 | // GGML internal header 6 | 7 | #include 8 | #include 9 | 10 | #define QK4_0 32 11 | typedef struct { 12 | ggml_fp16_t d; // delta 13 | uint8_t qs[QK4_0 / 2]; // nibbles / quants 14 | } block_q4_0; 15 | static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); 16 | 17 | #define QK4_1 32 18 | typedef struct { 19 | ggml_fp16_t d; // delta 20 | ggml_fp16_t m; // min 21 | uint8_t qs[QK4_1 / 2]; // nibbles / quants 22 | } block_q4_1; 23 | static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); 24 | 25 | #define QK5_0 32 26 | typedef struct { 27 | ggml_fp16_t d; // delta 28 | uint8_t qh[4]; // 5-th bit of quants 29 | uint8_t qs[QK5_0 / 2]; // nibbles / quants 30 | } block_q5_0; 31 | static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); 32 | 33 | #define QK5_1 32 34 | typedef struct { 35 | ggml_fp16_t d; // delta 36 | ggml_fp16_t m; // min 37 | uint8_t qh[4]; // 5-th bit of 
quants 38 | uint8_t qs[QK5_1 / 2]; // nibbles / quants 39 | } block_q5_1; 40 | static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); 41 | 42 | #define QK8_0 32 43 | typedef struct { 44 | ggml_fp16_t d; // delta 45 | int8_t qs[QK8_0]; // quants 46 | } block_q8_0; 47 | static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); 48 | 49 | #define QK8_1 32 50 | typedef struct { 51 | float d; // delta 52 | float s; // d * sum(qs[i]) 53 | int8_t qs[QK8_1]; // quants 54 | } block_q8_1; 55 | static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); 56 | 57 | // 58 | // Super-block quantization structures 59 | // 60 | 61 | // Super-block size 62 | #ifdef GGML_QKK_64 63 | #define QK_K 64 64 | #define K_SCALE_SIZE 4 65 | #else 66 | #define QK_K 256 67 | #define K_SCALE_SIZE 12 68 | #endif 69 | 70 | // 2-bit quantization 71 | // weight is represented as x = a * q + b 72 | // 16 blocks of 16 elements each 73 | // Effectively 2.5625 bits per weight 74 | typedef struct { 75 | uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits 76 | uint8_t qs[QK_K/4]; // quants 77 | ggml_fp16_t d; // super-block scale for quantized scales 78 | ggml_fp16_t dmin; // super-block scale for quantized mins 79 | } block_q2_K; 80 | static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); 81 | 82 | // 3-bit quantization 83 | // weight is represented as x = a * q 84 | // 16 blocks of 16 elements each 85 | // Effectively 3.4375 bits per weight 86 | #ifdef GGML_QKK_64 87 | typedef struct { 88 | uint8_t hmask[QK_K/8]; // quants - high bit 89 | uint8_t qs[QK_K/4]; // quants - low 2 bits 90 | uint8_t scales[2]; 91 | ggml_fp16_t d; // super-block scale 92 | } block_q3_K; 93 | static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); 94 | 
#else 95 | typedef struct { 96 | uint8_t hmask[QK_K/8]; // quants - high bit 97 | uint8_t qs[QK_K/4]; // quants - low 2 bits 98 | uint8_t scales[12]; // scales, quantized with 6 bits 99 | ggml_fp16_t d; // super-block scale 100 | } block_q3_K; 101 | static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); 102 | #endif 103 | 104 | // 4-bit quantization 105 | // 8 blocks of 32 elements each 106 | // weight is represented as x = a * q + b 107 | // Effectively 4.5 bits per weight 108 | #ifdef GGML_QKK_64 109 | typedef struct { 110 | ggml_fp16_t d[2]; // super-block scales/mins 111 | uint8_t scales[2]; // 4-bit block scales/mins 112 | uint8_t qs[QK_K/2]; // 4--bit quants 113 | } block_q4_K; 114 | static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); 115 | #else 116 | typedef struct { 117 | ggml_fp16_t d; // super-block scale for quantized scales 118 | ggml_fp16_t dmin; // super-block scale for quantized mins 119 | uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits 120 | uint8_t qs[QK_K/2]; // 4--bit quants 121 | } block_q4_K; 122 | static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); 123 | #endif 124 | 125 | // 5-bit quantization 126 | // 8 blocks of 32 elements each 127 | // weight is represented as x = a * q + b 128 | // Effectively 5.5 bits per weight 129 | #ifdef GGML_QKK_64 130 | typedef struct { 131 | ggml_fp16_t d; // super-block scale 132 | int8_t scales[QK_K/16]; // 8-bit block scales 133 | uint8_t qh[QK_K/8]; // quants, high bit 134 | uint8_t qs[QK_K/2]; // quants, low 4 bits 135 | } block_q5_K; 136 | static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); 137 | #else 138 | typedef struct { 139 | ggml_fp16_t d; // super-block scale for quantized scales 140 | ggml_fp16_t dmin; // super-block scale for 
quantized mins 141 | uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits 142 | uint8_t qh[QK_K/8]; // quants, high bit 143 | uint8_t qs[QK_K/2]; // quants, low 4 bits 144 | } block_q5_K; 145 | static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); 146 | #endif 147 | 148 | // 6-bit quantization 149 | // weight is represented as x = a * q 150 | // 16 blocks of 16 elements each 151 | // Effectively 6.5625 bits per weight 152 | typedef struct { 153 | uint8_t ql[QK_K/2]; // quants, lower 4 bits 154 | uint8_t qh[QK_K/4]; // quants, upper 2 bits 155 | int8_t scales[QK_K/16]; // scales, quantized with 8 bits 156 | ggml_fp16_t d; // super-block scale 157 | } block_q6_K; 158 | static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); 159 | 160 | // This is only used for intermediate quantization and dot products 161 | typedef struct { 162 | float d; // delta 163 | int8_t qs[QK_K]; // quants 164 | int16_t bsums[QK_K/16]; // sum of quants in groups of 16 165 | } block_q8_K; 166 | static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); 167 | 168 | 169 | // Quantization 170 | void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); 171 | void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); 172 | void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); 173 | void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); 174 | void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); 175 | void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); 176 | 177 | void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); 178 | void 
quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); 179 | void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); 180 | void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); 181 | void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); 182 | void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); 183 | 184 | void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); 185 | void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); 186 | void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); 187 | void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); 188 | void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); 189 | void quantize_row_q8_1(const float * restrict x, void * restrict y, int k); 190 | 191 | void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); 192 | void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); 193 | void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); 194 | void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); 195 | void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); 196 | void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); 197 | 198 | // Dequantization 199 | void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); 200 | void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); 201 | void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); 202 | void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); 203 | void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); 204 | //void dequantize_row_q8_1(const 
block_q8_1 * restrict x, float * restrict y, int k); 205 | 206 | void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); 207 | void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); 208 | void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); 209 | void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); 210 | void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); 211 | void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); 212 | 213 | // Dot product 214 | void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 215 | void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 216 | void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 217 | void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 218 | void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 219 | 220 | void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 221 | void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 222 | void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 223 | void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 224 | void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); 225 | -------------------------------------------------------------------------------- /llmmodel.cpp: -------------------------------------------------------------------------------- 1 | #include "llmmodel.h" 2 | #include 3 | 
#include 4 | LLMModel::LLMModel() 5 | { 6 | m_timer = new QTimer; 7 | connect(m_timer, &QTimer::timeout, this, &LLMModel::Update); 8 | n_remain = params.n_predict; 9 | } 10 | 11 | void LLMModel::Reset() 12 | { 13 | if (m_timer->isActive()) 14 | { 15 | m_timer->stop(); 16 | embd.clear(); 17 | } 18 | } 19 | 20 | void LLMModel::Update() 21 | { 22 | static int n_past = 0; 23 | static int n_past_guidance = 0; 24 | 25 | 26 | LOG("*********embd_inp: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); 27 | // predict 28 | if (!embd.empty()) { 29 | // infinite text generation via context swapping 30 | // if we run out of context: 31 | // - take the n_keep first tokens from the original prompt (via n_past) 32 | // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches 33 | // if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { 34 | if (n_past + (int) embd.size() > 480) { 35 | if (params.n_predict == -2) { 36 | LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); 37 | } 38 | 39 | const int n_left = n_past - params.n_keep - 1; 40 | const int n_discard = n_left/2; 41 | 42 | LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", 43 | n_past, n_left, n_ctx, params.n_keep, n_discard); 44 | 45 | llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); 46 | llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); 47 | 48 | n_past -= n_discard; 49 | 50 | LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); 51 | 52 | LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); 53 | } 54 | 55 | for (int i = 0; i < (int) embd.size(); i += params.n_batch) { 56 | int n_eval = (int) embd.size() - i; 57 | if (n_eval > params.n_batch) { 58 | n_eval = params.n_batch; 59 | } 60 | 61 | LOG("eval: %s, embd:%d, n_past:%d\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str(), 
embd.size(), n_past); 62 | if (n_eval > 1) 63 | { 64 | LOG(" XXXXXXXXX %d\n", embd.size() + n_past); 65 | } 66 | 67 | if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { 68 | LOG_TEE("%s : failed to eval\n", __func__); 69 | break; 70 | } 71 | 72 | n_past += n_eval; 73 | 74 | LOG("n_past = %d\n", n_past); 75 | } 76 | } 77 | 78 | embd.clear(); 79 | const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); 80 | 81 | llama_sampling_accept(ctx_sampling, ctx, id, true); 82 | 83 | LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); 84 | 85 | embd.push_back(id); 86 | 87 | // decrement remaining sampling budget 88 | --n_remain; 89 | LOG("n_remain: %d\n", n_remain); 90 | 91 | // display text 92 | for (auto id : embd) { 93 | const std::string token_str = llama_token_to_piece(ctx, id); 94 | QString new_str = QString::fromStdString(token_str); 95 | m_output << new_str; 96 | // qDebug() << "xxx" << new_str << embd.size(); 97 | emit SignalNewAnswer(new_str, false); 98 | } 99 | 100 | if (m_output.size() >5) 101 | { 102 | for (int iii = m_output.size() - 5; iii < m_output.size(); iii++) 103 | { 104 | if (m_output.at(iii) != "" && m_output.at(iii) != "\n" ) 105 | { 106 | break; 107 | } 108 | QString new_str = ""; 109 | emit SignalNewAnswer(new_str, true); 110 | m_timer->stop(); 111 | } 112 | } 113 | 114 | // end of text token 115 | if (!embd.empty() && embd.back() == llama_token_eos(model)) { 116 | LOG_TEE(" [end of text]\n"); 117 | QString new_str = ""; 118 | emit SignalNewAnswer(new_str, true); 119 | m_timer->stop(); 120 | } 121 | } 122 | LLMModel::~LLMModel() 123 | { 124 | if (ctx) 125 | { 126 | llama_free(ctx); 127 | llama_free_model(model); 128 | 129 | llama_sampling_free(ctx_sampling); 130 | llama_backend_free(); 131 | } 132 | if (m_log_file) 133 | { 134 | LOG_TEE("Log end\n"); 135 | } 136 | } 137 | 138 | int LLMModel::LoadModel() 139 | { 140 | params.interactive = true; 141 | params.chatml = true; 142 | 
params.prompt = "You are a helpful assistant."; 143 | params.model = "ggml-model-q5_k_m.gguf"; 144 | 145 | llama_sampling_params & sparams = params.sparams; 146 | 147 | if (m_log_file) 148 | { 149 | log_set_target(log_filename_generator("main", "log")); 150 | LOG_TEE("Log start\n"); 151 | } 152 | LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); 153 | LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); 154 | 155 | if (params.seed == LLAMA_DEFAULT_SEED) { 156 | params.seed = time(NULL); 157 | } 158 | LOG_TEE("%s: seed = %u\n", __func__, params.seed); 159 | LOG("%s: llama backend init\n", __func__); 160 | llama_backend_init(params.numa); 161 | // load the model and apply lora adapter, if any 162 | LOG("%s: load the model and apply lora adapter, if any\n", __func__); 163 | std::tie(model, ctx) = llama_init_from_gpt_params(params); 164 | 165 | if (model == NULL) { 166 | LOG_TEE("%s: error: unable to load model\n", __func__); 167 | return -1; 168 | } 169 | 170 | const int n_ctx = llama_n_ctx(ctx); 171 | LOG("n_ctx: %d\n", n_ctx); 172 | 173 | // print system information 174 | { 175 | LOG_TEE("\n"); 176 | LOG_TEE("%s\n", get_system_info(params).c_str()); 177 | } 178 | 179 | const bool add_bos = llama_should_add_bos_token(model); 180 | LOG("add_bos: %d\n", add_bos); 181 | LOG("tokenize the prompt\n"); 182 | if (params.chatml) { 183 | params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; 184 | } 185 | embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); 186 | // chatml prefix & suffix 187 | cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true); 188 | cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true); 189 | LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str()); 190 | LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str()); 191 | LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); 192 | LOG("tokens: %s\n", 
LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); 193 | // Ensure the input doesn't exceed the context size by truncating embd if necessary. 194 | if ((int) embd_inp.size() > n_ctx - 100) { 195 | LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); 196 | qDebug() << "Error"; 197 | return -1; 198 | } 199 | 200 | // number of tokens to keep when resetting context 201 | if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) { 202 | params.n_keep = (int)embd_inp.size(); 203 | } 204 | 205 | // similar for chatml mode 206 | if (params.chatml) { 207 | params.interactive_first = true; 208 | params.antiprompt.push_back("<|im_start|>user\n"); 209 | } 210 | 211 | // enable interactive mode if interactive start is specified 212 | if (params.interactive_first) { 213 | params.interactive = true; 214 | } 215 | 216 | if (params.verbose_prompt) { 217 | LOG_TEE("\n"); 218 | LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); 219 | LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); 220 | for (int i = 0; i < (int) embd_inp.size(); i++) { 221 | LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); 222 | } 223 | 224 | if (params.n_keep > 0) { 225 | LOG_TEE("%s: static prompt based on n_keep: '", __func__); 226 | for (int i = 0; i < params.n_keep; i++) { 227 | LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); 228 | } 229 | LOG_TEE("'\n"); 230 | } 231 | LOG_TEE("\n"); 232 | } 233 | 234 | LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); 235 | LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); 236 | LOG_TEE("\n\n"); 237 | 238 | ctx_sampling = llama_sampling_init(sparams); 239 | return 1; 240 | } 241 | void LLMModel::Run(QString qstr_input) 242 | { 243 | m_output.clear(); 244 | 
llama_sampling_reset(ctx_sampling); 245 | static int chat_round_index = 0; 246 | m_input = qstr_input.toStdString(); 247 | LOG("waiting for user input\n"); 248 | // qDebug() << chat_round_index << QDateTime::currentDateTime(); 249 | chat_round_index = chat_round_index + 1; 250 | 251 | if (m_input.length() > 1) { 252 | LOG("buffer: '%s'\n", m_input.c_str()); 253 | 254 | // chatml mode: insert user chat prefix 255 | if (params.chatml) { 256 | LOG("inserting chatml prefix\n"); 257 | n_consumed = embd_inp.size(); 258 | embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end()); 259 | } 260 | if (params.escape) { 261 | process_escapes(m_input); 262 | } 263 | 264 | const auto line_inp = ::llama_tokenize(ctx, m_input, false, false); 265 | LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); 266 | embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); 267 | 268 | // chatml mode: insert assistant chat suffix 269 | if (params.chatml) { 270 | LOG("inserting chatml suffix\n"); 271 | embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); 272 | } 273 | 274 | n_remain -= line_inp.size(); 275 | LOG("n_remain: %d\n", n_remain); 276 | } 277 | embd.clear(); 278 | LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); 279 | while ((int) embd_inp.size() > n_consumed) { 280 | embd.push_back(embd_inp[n_consumed]); 281 | 282 | // push the prompt in the sampling context in order to apply repetition penalties later 283 | // for the prompt, we don't apply grammar rules 284 | llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); 285 | 286 | ++n_consumed; 287 | if ((int) embd.size() >= params.n_batch) { 288 | break; 289 | } 290 | } 291 | 292 | m_timer->start(5); 293 | } 294 | -------------------------------------------------------------------------------- /llamaCpp/common/common.h: -------------------------------------------------------------------------------- 1 | // Various helper functions and 
utilities 2 | 3 | #pragma once 4 | 5 | #include "llama.h" 6 | 7 | #include "sampling.h" 8 | 9 | #define LOG_NO_FILE_LINE_FUNCTION 10 | #include "log.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #ifdef _WIN32 21 | #define DIRECTORY_SEPARATOR '\\' 22 | #else 23 | #define DIRECTORY_SEPARATOR '/' 24 | #endif // _WIN32 25 | 26 | #define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) 27 | #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) 28 | 29 | #define print_build_info() do { \ 30 | fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ 31 | fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ 32 | } while(0) 33 | 34 | // build info 35 | extern int LLAMA_BUILD_NUMBER; 36 | extern char const *LLAMA_COMMIT; 37 | extern char const *LLAMA_COMPILER; 38 | extern char const *LLAMA_BUILD_TARGET; 39 | 40 | // 41 | // CLI argument parsing 42 | // 43 | int32_t get_num_physical_cores(); 44 | 45 | struct gpt_params { 46 | uint32_t seed = -1; // RNG seed 47 | 48 | int32_t n_threads = get_num_physical_cores(); 49 | int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) 50 | int32_t n_predict = -1; // new tokens to predict 51 | int32_t n_ctx = 512; // context size 52 | int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) 53 | int32_t n_keep = 0; // number of tokens to keep from initial prompt 54 | int32_t n_draft = 16; // number of tokens to draft during speculative decoding 55 | int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) 56 | int32_t n_parallel = 1; // number of parallel sequences to decode 57 | int32_t n_sequences = 1; // number of sequences to decode 58 | float p_accept = 0.5f; // speculative decoding accept probability 59 | float p_split = 0.1f; // 
speculative decoding split probability 60 | int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) 61 | int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) 62 | int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors 63 | float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs 64 | int32_t n_beams = 0; // if non-zero then use beam search of given width. 65 | float rope_freq_base = 0.0f; // RoPE base frequency 66 | float rope_freq_scale = 0.0f; // RoPE frequency scaling factor 67 | float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor 68 | float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor 69 | float yarn_beta_fast = 32.0f; // YaRN low correction dim 70 | float yarn_beta_slow = 1.0f; // YaRN high correction dim 71 | int32_t yarn_orig_ctx = 0; // YaRN original context length 72 | int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment 73 | // pinging @cebtenzzre 74 | 75 | // // sampling parameters 76 | struct llama_sampling_params sparams; 77 | 78 | std::string model = "E:/Code/NLP/qwen.cpp/Qwen-1_8B-Chat/ggml-model-q5_k_m.gguf"; // model path 79 | std::string model_draft = ""; // draft model for speculative decoding 80 | std::string model_alias = "unknown"; // model alias 81 | std::string prompt = "You are a helpful assistant."; 82 | std::string prompt_file = ""; // store the external prompt file name 83 | std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state 84 | std::string input_prefix = ""; // string to prefix user inputs with 85 | std::string input_suffix = ""; // string to suffix user inputs with 86 | std::vector antiprompt; // string upon seeing which more user input is prompted 87 | std::string logdir = ""; // directory in which to save YAML log files 88 | 89 | // TODO: avoid tuple, use struct 90 | 
std::vector> lora_adapter; // lora adapter path with user defined scale 91 | std::string lora_base = ""; // base model path for the lora adapter 92 | 93 | int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. 94 | int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line 95 | // (which is more convenient to use for plotting) 96 | // 97 | bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt 98 | size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score 99 | 100 | bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS 101 | bool memory_f16 = true; // use f16 instead of f32 for memory kv 102 | bool random_prompt = false; // do not randomize prompt if none provided 103 | bool use_color = false; // use color to distinguish generations and inputs 104 | bool interactive = false; // interactive mode 105 | bool chatml = false; // chatml mode (used for models trained on chatml syntax) 106 | bool prompt_cache_all = false; // save user input and generations to prompt cache 107 | bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it 108 | 109 | bool embedding = false; // get only sentence embedding 110 | bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" 111 | bool interactive_first = false; // wait for user input immediately 112 | bool multiline_input = false; // reverse the usage of `\` 113 | bool simple_io = false; // improves compatibility with subprocesses and limited consoles 114 | bool cont_batching = false; // insert new sequences for decoding on-the-fly 115 | 116 | bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix 117 | bool ignore_eos = false; // ignore generated EOS tokens 118 | bool instruct = false; // instruction mode (used for Alpaca models) 119 | bool logits_all = 
false; // return logits for all tokens in the batch 120 | bool use_mmap = true; // use mmap for faster loads 121 | bool use_mlock = false; // use mlock to keep model in memory 122 | bool numa = false; // attempt optimizations that help on some NUMA systems 123 | bool verbose_prompt = false; // print prompt tokens before generation 124 | bool infill = false; // use infill mode 125 | bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes 126 | 127 | // multimodal models (see examples/llava) 128 | std::string mmproj = ""; // path to multimodal projector 129 | std::string image = ""; // path to an image file 130 | }; 131 | 132 | bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); 133 | 134 | bool gpt_params_parse(int argc, char ** argv, gpt_params & params); 135 | 136 | void gpt_print_usage(int argc, char ** argv, const gpt_params & params); 137 | 138 | std::string get_system_info(const gpt_params & params); 139 | 140 | std::string gpt_random_prompt(std::mt19937 & rng); 141 | 142 | void process_escapes(std::string& input); 143 | 144 | // 145 | // Model utils 146 | // 147 | 148 | // TODO: avoid tuple, use struct 149 | std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params); 150 | 151 | struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); 152 | struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); 153 | 154 | // Batch utils 155 | 156 | void llama_batch_clear(struct llama_batch & batch); 157 | 158 | void llama_batch_add( 159 | struct llama_batch & batch, 160 | llama_token id, 161 | llama_pos pos, 162 | const std::vector<llama_seq_id> & seq_ids, 163 | bool logits); 164 | 165 | // 166 | // Vocab utils 167 | // 168 | 169 | // tokenizes a string into a vector of tokens 170 | // should work similar to Python's `tokenizer.encode` 171 | std::vector<llama_token> llama_tokenize( 172 | const struct llama_context * ctx, 173 | const std::string & text, 174 | bool add_bos, 175 | bool
special = false); 176 | 177 | std::vector<llama_token> llama_tokenize( 178 | const struct llama_model * model, 179 | const std::string & text, 180 | bool add_bos, 181 | bool special = false); 182 | 183 | // tokenizes a token into a piece 184 | // should work similar to Python's `tokenizer.id_to_piece` 185 | std::string llama_token_to_piece( 186 | const struct llama_context * ctx, 187 | llama_token token); 188 | 189 | // TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function 190 | // that takes into account the tokenizer type and decides how to handle the leading space 191 | // 192 | // detokenizes a vector of tokens into a string 193 | // should work similar to Python's `tokenizer.decode` 194 | // removes the leading space from the first non-BOS token 195 | std::string llama_detokenize_spm( 196 | llama_context * ctx, 197 | const std::vector<llama_token> & tokens); 198 | 199 | // detokenizes a vector of tokens into a string 200 | // should work similar to Python's `tokenizer.decode` 201 | std::string llama_detokenize_bpe( 202 | llama_context * ctx, 203 | const std::vector<llama_token> & tokens); 204 | 205 | // Uses the value from the model metadata if possible, otherwise 206 | // defaults to true when model type is SPM, otherwise false.
207 | bool llama_should_add_bos_token(const llama_model * model); 208 | 209 | // 210 | // YAML utils 211 | // 212 | 213 | bool create_directory_with_parents(const std::string & path); 214 | void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data); 215 | void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data); 216 | void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); 217 | std::string get_sortable_timestamp(); 218 | 219 | void dump_non_result_info_yaml( 220 | FILE * stream, const gpt_params & params, const llama_context * lctx, 221 | const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc); 222 | 223 | // 224 | // KV cache utils 225 | // 226 | 227 | // Dump the KV cache view with the number of sequences per cell. 228 | void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); 229 | 230 | // Dump the KV cache view showing individual sequences in each cell (long output). 231 | void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); 232 | -------------------------------------------------------------------------------- /llamaCpp/common/base64.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | This is free and unencumbered software released into the public domain. 3 | 4 | Anyone is free to copy, modify, publish, use, compile, sell, or 5 | distribute this software, either in source code form or as a compiled 6 | binary, for any purpose, commercial or non-commercial, and by any 7 | means. 8 | 9 | In jurisdictions that recognize copyright laws, the author or authors 10 | of this software dedicate any and all copyright interest in the 11 | software to the public domain. We make this dedication for the benefit 12 | of the public at large and to the detriment of our heirs and 13 | successors.
We intend this dedication to be an overt act of 14 | relinquishment in perpetuity of all present and future rights to this 15 | software under copyright law. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 21 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 22 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | OTHER DEALINGS IN THE SOFTWARE. 24 | 25 | For more information, please refer to 26 | */ 27 | 28 | #ifndef PUBLIC_DOMAIN_BASE64_HPP_ 29 | #define PUBLIC_DOMAIN_BASE64_HPP_ 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | class base64_error : public std::runtime_error 37 | { 38 | public: 39 | using std::runtime_error::runtime_error; 40 | }; 41 | 42 | class base64 43 | { 44 | public: 45 | enum class alphabet 46 | { 47 | /** the alphabet is detected automatically */ 48 | auto_, 49 | /** the standard base64 alphabet is used */ 50 | standard, 51 | /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/ 52 | url_filename_safe 53 | }; 54 | 55 | enum class decoding_behavior 56 | { 57 | /** if the input is not padded, the remaining bits are ignored */ 58 | moderate, 59 | /** if a padding character is encounter decoding is finished */ 60 | loose 61 | }; 62 | 63 | /** 64 | Encodes all the elements from `in_begin` to `in_end` to `out`. 65 | 66 | @warning The source and destination cannot overlap. The destination must be able to hold at least 67 | `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator. 
68 | 69 | @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than 70 | 8 bits 71 | @tparam Output_iterator the destination; the elements written to it are from the type `char` 72 | @param in_begin the beginning of the source 73 | @param in_end the ending of the source 74 | @param out the destination iterator 75 | @param alphabet which alphabet should be used 76 | @returns the iterator to the next element past the last element copied 77 | @throws see `Input_iterator` and `Output_iterator` 78 | */ 79 | template 80 | static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, 81 | alphabet alphabet = alphabet::standard) 82 | { 83 | constexpr auto pad = '='; 84 | const char* alpha = alphabet == alphabet::url_filename_safe 85 | ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" 86 | : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; 87 | 88 | while (in_begin != in_end) { 89 | std::uint8_t i0 = 0, i1 = 0, i2 = 0; 90 | 91 | // first character 92 | i0 = static_cast(*in_begin); 93 | ++in_begin; 94 | 95 | *out = alpha[i0 >> 2 & 0x3f]; 96 | ++out; 97 | 98 | // part of first character and second 99 | if (in_begin != in_end) { 100 | i1 = static_cast(*in_begin); 101 | ++in_begin; 102 | 103 | *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)]; 104 | ++out; 105 | } else { 106 | *out = alpha[(i0 & 0x3) << 4]; 107 | ++out; 108 | 109 | // last padding 110 | *out = pad; 111 | ++out; 112 | 113 | // last padding 114 | *out = pad; 115 | ++out; 116 | 117 | break; 118 | } 119 | 120 | // part of second character and third 121 | if (in_begin != in_end) { 122 | i2 = static_cast(*in_begin); 123 | ++in_begin; 124 | 125 | *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)]; 126 | ++out; 127 | } else { 128 | *out = alpha[(i1 & 0xf) << 2]; 129 | ++out; 130 | 131 | // last padding 132 | *out = pad; 133 | ++out; 134 | 135 | break; 136 | } 137 | 138 | // rest of third 
139 | *out = alpha[i2 & 0x3f]; 140 | ++out; 141 | } 142 | 143 | return out; 144 | } 145 | /** 146 | Encodes a string. 147 | 148 | @param str the string that should be encoded 149 | @param alphabet which alphabet should be used 150 | @returns the encoded base64 string 151 | @throws see base64::encode() 152 | */ 153 | static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard) 154 | { 155 | std::string result; 156 | 157 | result.reserve(required_encode_size(str.length()) + 1); 158 | 159 | encode(str.begin(), str.end(), std::back_inserter(result), alphabet); 160 | 161 | return result; 162 | } 163 | /** 164 | Encodes a char array. 165 | 166 | @param buffer the char array 167 | @param size the size of the array 168 | @param alphabet which alphabet should be used 169 | @returns the encoded string 170 | */ 171 | static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard) 172 | { 173 | std::string result; 174 | 175 | result.reserve(required_encode_size(size) + 1); 176 | 177 | encode(buffer, buffer + size, std::back_inserter(result), alphabet); 178 | 179 | return result; 180 | } 181 | /** 182 | Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`, 183 | in other words: inplace decoding is possible. 184 | 185 | @warning The destination must be able to hold at least `max_decode_size(std::distance(in_begin, in_end))`, 186 | otherwise the behavior depends on the output iterator.
187 | 188 | @tparam Input_iterator the source; the returned elements are cast to `char` 189 | @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t` 190 | @param in_begin the beginning of the source 191 | @param in_end the ending of the source 192 | @param out the destination iterator 193 | @param alphabet which alphabet should be used 194 | @param behavior the behavior when an error was detected 195 | @returns the iterator to the next element past the last element copied 196 | @throws base64_error depending on the set behavior 197 | @throws see `Input_iterator` and `Output_iterator` 198 | */ 199 | template<typename Input_iterator, typename Output_iterator> 200 | static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, 201 | alphabet alphabet = alphabet::auto_, 202 | decoding_behavior behavior = decoding_behavior::moderate) 203 | { 204 | //constexpr auto pad = '='; 205 | std::uint8_t last = 0; 206 | auto bits = 0; 207 | 208 | while (in_begin != in_end) { 209 | auto c = *in_begin; 210 | ++in_begin; 211 | 212 | if (c == '=') { 213 | break; 214 | } 215 | 216 | auto part = _base64_value(alphabet, c); 217 | 218 | // enough bits for one byte 219 | if (bits + 6 >= 8) { 220 | *out = (last << (8 - bits)) | (part >> (bits - 2)); 221 | ++out; 222 | 223 | bits -= 2; 224 | } else { 225 | bits += 6; 226 | } 227 | 228 | last = part; 229 | } 230 | 231 | // check padding 232 | if (behavior != decoding_behavior::loose) { 233 | while (in_begin != in_end) { 234 | auto c = *in_begin; 235 | ++in_begin; 236 | 237 | if (c != '=') { 238 | throw base64_error("invalid base64 character."); 239 | } 240 | } 241 | } 242 | 243 | return out; 244 | } 245 | /** 246 | Decodes a string.
247 | 248 | @param str the base64 encoded string 249 | @param alphabet which alphabet should be used 250 | @param behavior the behavior when an error was detected 251 | @returns the decoded string 252 | @throws see base64::decode() 253 | */ 254 | static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_, 255 | decoding_behavior behavior = decoding_behavior::moderate) 256 | { 257 | std::string result; 258 | 259 | result.reserve(max_decode_size(str.length())); 260 | 261 | decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior); 262 | 263 | return result; 264 | } 265 | /** 266 | Decodes a char buffer. 267 | 268 | @param buffer the base64 encoded buffer 269 | @param size the size of the buffer 270 | @param alphabet which alphabet should be used 271 | @param behavior the behavior when an error was detected 272 | @returns the decoded string 273 | @throws see base64::decode() 274 | */ 275 | static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_, 276 | decoding_behavior behavior = decoding_behavior::moderate) 277 | { 278 | std::string result; 279 | 280 | result.reserve(max_decode_size(size)); 281 | 282 | decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior); 283 | 284 | return result; 285 | } 286 | /** 287 | Decodes a string inplace. 288 | 289 | @param[in,out] str the base64 encoded string 290 | @param alphabet which alphabet should be used 291 | @param behavior the behavior when an error was detected 292 | @throws see base64::decode() 293 | */ 294 | static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_, 295 | decoding_behavior behavior = decoding_behavior::moderate) 296 | { 297 | str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin()); 298 | } 299 | /** 300 | Decodes a char array inplace.
301 | 302 | @param[in,out] str the string array 303 | @param size the length of the array 304 | @param alphabet which alphabet should be used 305 | @param behavior the behavior when an error was detected 306 | @returns the pointer to the next element past the last element decoded 307 | @throws see base64::decode() 308 | */ 309 | static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_, 310 | decoding_behavior behavior = decoding_behavior::moderate) 311 | { 312 | return decode(str, str + size, str, alphabet, behavior); 313 | } 314 | /** 315 | Returns the required decoding size for a given size. The value is calculated with the following formula: 316 | 317 | $$ 318 | \lceil \frac{size}{4} \rceil \cdot 3 319 | $$ 320 | 321 | @param size the size of the encoded input 322 | @returns the size of the resulting decoded buffer; this is the absolute maximum 323 | */ 324 | static std::size_t max_decode_size(std::size_t size) noexcept 325 | { 326 | return (size / 4 + (size % 4 ? 1 : 0)) * 3; 327 | } 328 | /** 329 | Returns the required encoding size for a given size. The value is calculated with the following formula: 330 | 331 | $$ 332 | \lceil \frac{size}{3} \rceil \cdot 4 333 | $$ 334 | 335 | @param size the size of the decoded input 336 | @returns the size of the resulting encoded buffer 337 | */ 338 | static std::size_t required_encode_size(std::size_t size) noexcept 339 | { 340 | return (size / 3 + (size % 3 ?
1 : 0)) * 4; 341 | } 342 | 343 | private: 344 | static std::uint8_t _base64_value(alphabet& alphabet, char c) 345 | { 346 | if (c >= 'A' && c <= 'Z') { 347 | return c - 'A'; 348 | } else if (c >= 'a' && c <= 'z') { 349 | return c - 'a' + 26; 350 | } else if (c >= '0' && c <= '9') { 351 | return c - '0' + 52; 352 | } 353 | 354 | // comes down to alphabet 355 | if (alphabet == alphabet::standard) { 356 | if (c == '+') { 357 | return 62; 358 | } else if (c == '/') { 359 | return 63; 360 | } 361 | } else if (alphabet == alphabet::url_filename_safe) { 362 | if (c == '-') { 363 | return 62; 364 | } else if (c == '_') { 365 | return 63; 366 | } 367 | } // auto detect 368 | else { 369 | if (c == '+') { 370 | alphabet = alphabet::standard; 371 | 372 | return 62; 373 | } else if (c == '/') { 374 | alphabet = alphabet::standard; 375 | 376 | return 63; 377 | } else if (c == '-') { 378 | alphabet = alphabet::url_filename_safe; 379 | 380 | return 62; 381 | } else if (c == '_') { 382 | alphabet = alphabet::url_filename_safe; 383 | 384 | return 63; 385 | } 386 | } 387 | 388 | throw base64_error("invalid base64 character."); 389 | } 390 | }; 391 | 392 | #endif // !PUBLIC_DOMAIN_BASE64_HPP_ 393 | -------------------------------------------------------------------------------- /llamaCpp/common/console.cpp: -------------------------------------------------------------------------------- 1 | #include "console.h" 2 | #include <vector> 3 | #include <iostream> 4 | 5 | #if defined(_WIN32) 6 | #define WIN32_LEAN_AND_MEAN 7 | #ifndef NOMINMAX 8 | #define NOMINMAX 9 | #endif 10 | #include <windows.h> 11 | #include <fcntl.h> 12 | #include <io.h> 13 | #ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING 14 | #define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004 15 | #endif 16 | #else 17 | #include <climits> 18 | #include <sys/ioctl.h> 19 | #include <unistd.h> 20 | #include <wchar.h> 21 | #include <stdio.h> 22 | #include <stdlib.h> 23 | #include <signal.h> 24 | #include <termios.h> 25 | #endif 26 | 27 | #define ANSI_COLOR_RED "\x1b[31m" 28 | #define ANSI_COLOR_GREEN "\x1b[32m" 29 | #define ANSI_COLOR_YELLOW "\x1b[33m" 30 | #define
ANSI_COLOR_BLUE "\x1b[34m" 31 | #define ANSI_COLOR_MAGENTA "\x1b[35m" 32 | #define ANSI_COLOR_CYAN "\x1b[36m" 33 | #define ANSI_COLOR_RESET "\x1b[0m" 34 | #define ANSI_BOLD "\x1b[1m" 35 | 36 | namespace console { 37 | 38 | // 39 | // Console state 40 | // 41 | 42 | static bool advanced_display = false; 43 | static bool simple_io = true; 44 | static display_t current_display = reset; 45 | 46 | static FILE* out = stdout; 47 | 48 | #if defined (_WIN32) 49 | static void* hConsole; 50 | #else 51 | static FILE* tty = nullptr; 52 | static termios initial_state; 53 | #endif 54 | 55 | // 56 | // Init and cleanup 57 | // 58 | 59 | void init(bool use_simple_io, bool use_advanced_display) { 60 | advanced_display = use_advanced_display; 61 | simple_io = use_simple_io; 62 | #if defined(_WIN32) 63 | // Windows-specific console initialization 64 | DWORD dwMode = 0; 65 | hConsole = GetStdHandle(STD_OUTPUT_HANDLE); 66 | if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { 67 | hConsole = GetStdHandle(STD_ERROR_HANDLE); 68 | if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { // neither stdout nor stderr is a usable console: never keep INVALID_HANDLE_VALUE (it is non-null) 69 | hConsole = nullptr; 70 | simple_io = true; 71 | } 72 | } 73 | if (hConsole) { 74 | // Check conditions combined to reduce nesting 75 | if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) && 76 | !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { 77 | advanced_display = false; 78 | } 79 | // Set console output codepage to UTF8 80 | SetConsoleOutputCP(CP_UTF8); 81 | } 82 | HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); 83 | if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { 84 | // Set console input codepage to UTF16 85 | _setmode(_fileno(stdin), _O_WTEXT); 86 | 87 | // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) 88 | if (simple_io) { 89 | dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT; 90 | } else { 91 | dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); 92 | } 93 | if
(!SetConsoleMode(hConIn, dwMode)) { 94 | simple_io = true; 95 | } 96 | } 97 | #else 98 | // POSIX-specific console initialization 99 | if (!simple_io) { 100 | struct termios new_termios; 101 | tcgetattr(STDIN_FILENO, &initial_state); // NOTE(review): return value unchecked; if stdin is not a terminal this fails and cleanup() later restores whatever initial_state holds 102 | new_termios = initial_state; 103 | new_termios.c_lflag &= ~(ICANON | ECHO); 104 | new_termios.c_cc[VMIN] = 1; 105 | new_termios.c_cc[VTIME] = 0; 106 | tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); 107 | 108 | tty = fopen("/dev/tty", "w+"); 109 | if (tty != nullptr) { 110 | out = tty; 111 | } 112 | } 113 | 114 | setlocale(LC_ALL, ""); 115 | #endif 116 | } 117 | 118 | void cleanup() { 119 | // Reset console display 120 | set_display(reset); 121 | 122 | #if !defined(_WIN32) 123 | // Restore settings on POSIX systems 124 | if (!simple_io) { 125 | if (tty != nullptr) { 126 | out = stdout; 127 | fclose(tty); 128 | tty = nullptr; 129 | } 130 | tcsetattr(STDIN_FILENO, TCSANOW, &initial_state); 131 | } 132 | #endif 133 | } 134 | 135 | // 136 | // Display and IO 137 | // 138 | 139 | // Keep track of current display and only emit ANSI code if it changes 140 | void set_display(display_t display) { 141 | if (advanced_display && current_display != display) { 142 | fflush(stdout); 143 | switch(display) { 144 | case reset: 145 | fprintf(out, ANSI_COLOR_RESET); 146 | break; 147 | case prompt: 148 | fprintf(out, ANSI_COLOR_YELLOW); 149 | break; 150 | case user_input: 151 | fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); 152 | break; 153 | case error: 154 | fprintf(out, ANSI_BOLD ANSI_COLOR_RED); 155 | } 156 | current_display = display; 157 | fflush(out); 158 | } 159 | } 160 | 161 | static char32_t getchar32() { 162 | #if defined(_WIN32) 163 | HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); 164 | wchar_t high_surrogate = 0; 165 | 166 | while (true) { 167 | INPUT_RECORD record; 168 | DWORD count; 169 | if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { 170 | return WEOF; 171 | } 172 | 173 | if (record.EventType == KEY_EVENT &&
record.Event.KeyEvent.bKeyDown) { 174 | wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; 175 | if (wc == 0) { 176 | continue; 177 | } 178 | 179 | if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate 180 | high_surrogate = wc; 181 | continue; 182 | } 183 | if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate 184 | if (high_surrogate != 0) { // Check if we have a high surrogate 185 | return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; 186 | } 187 | } 188 | 189 | high_surrogate = 0; // Reset the high surrogate 190 | return static_cast<char32_t>(wc); 191 | } 192 | } 193 | #else 194 | wchar_t wc = getwchar(); 195 | if (static_cast<wint_t>(wc) == WEOF) { 196 | return WEOF; 197 | } 198 | 199 | #if WCHAR_MAX == 0xFFFF 200 | if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate 201 | wchar_t low_surrogate = getwchar(); 202 | if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate 203 | return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; 204 | } 205 | } 206 | if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair 207 | return 0xFFFD; // Return the replacement character U+FFFD 208 | } 209 | #endif 210 | 211 | return static_cast<char32_t>(wc); 212 | #endif 213 | } 214 | 215 | static void pop_cursor() { 216 | #if defined(_WIN32) 217 | if (hConsole != NULL) { 218 | CONSOLE_SCREEN_BUFFER_INFO bufferInfo; 219 | GetConsoleScreenBufferInfo(hConsole, &bufferInfo); 220 | 221 | COORD newCursorPosition = bufferInfo.dwCursorPosition; 222 | if (newCursorPosition.X == 0) { 223 | newCursorPosition.X = bufferInfo.dwSize.X - 1; 224 | newCursorPosition.Y -= 1; 225 | } else { 226 | newCursorPosition.X -= 1; 227 | } 228 | 229 | SetConsoleCursorPosition(hConsole, newCursorPosition); 230 | return; 231 | } 232 | #endif 233 | putc('\b', out); 234 | } 235 | 236 | static int estimateWidth(char32_t codepoint) { 237 | #if defined(_WIN32) 238 | (void)codepoint;
239 | return 1; 240 | #else 241 | return wcwidth(codepoint); 242 | #endif 243 | } 244 | 245 | static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) { 246 | #if defined(_WIN32) 247 | CONSOLE_SCREEN_BUFFER_INFO bufferInfo; 248 | if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { 249 | // go with the default 250 | return expectedWidth; 251 | } 252 | COORD initialPosition = bufferInfo.dwCursorPosition; 253 | DWORD nNumberOfChars = length; 254 | WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); 255 | 256 | CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; 257 | GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); 258 | 259 | // Figure out our real position if we're in the last column 260 | if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { 261 | DWORD nWritten; // separate name: do not shadow the outer nNumberOfChars 262 | WriteConsole(hConsole, &" \b", 2, &nWritten, NULL); 263 | GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); 264 | } 265 | 266 | int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; 267 | if (width < 0) { 268 | width += newBufferInfo.dwSize.X; 269 | } 270 | return width; 271 | #else 272 | // We can trust expectedWidth if we've got one 273 | if (expectedWidth >= 0 || tty == nullptr) { 274 | fwrite(utf8_codepoint, length, 1, out); 275 | return expectedWidth; 276 | } 277 | 278 | fputs("\033[6n", tty); // Query cursor position 279 | int x1; 280 | int y1; 281 | int x2; 282 | int y2; 283 | int results = 0; 284 | results = fscanf(tty, "\033[%d;%dR", &y1, &x1); 285 | 286 | fwrite(utf8_codepoint, length, 1, tty); 287 | 288 | fputs("\033[6n", tty); // Query cursor position 289 | results += fscanf(tty, "\033[%d;%dR", &y2, &x2); 290 | 291 | if (results != 4) { 292 | return expectedWidth; 293 | } 294 | 295 | int width = x2 - x1; 296 | if (width < 0) { 297 | // Calculate the width considering text wrapping 298 | struct winsize w = {}; // zero-init: w is not filled in if the ioctl below fails 299 | if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) == 0) // only trust ws_col when the terminal size query succeeded 300 | width += w.ws_col;
301 | } 302 | return width; 303 | #endif 304 | } 305 | 306 | static void replace_last(char ch) { 307 | #if defined(_WIN32) 308 | pop_cursor(); 309 | put_codepoint(&ch, 1, 1); 310 | #else 311 | fprintf(out, "\b%c", ch); 312 | #endif 313 | } 314 | 315 | static void append_utf8(char32_t ch, std::string & out) { 316 | if (ch <= 0x7F) { 317 | out.push_back(static_cast<char>(ch)); 318 | } else if (ch <= 0x7FF) { 319 | out.push_back(static_cast<char>(0xC0 | ((ch >> 6) & 0x1F))); 320 | out.push_back(static_cast<char>(0x80 | (ch & 0x3F))); 321 | } else if (ch <= 0xFFFF) { 322 | out.push_back(static_cast<char>(0xE0 | ((ch >> 12) & 0x0F))); 323 | out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F))); 324 | out.push_back(static_cast<char>(0x80 | (ch & 0x3F))); 325 | } else if (ch <= 0x10FFFF) { 326 | out.push_back(static_cast<char>(0xF0 | ((ch >> 18) & 0x07))); 327 | out.push_back(static_cast<char>(0x80 | ((ch >> 12) & 0x3F))); 328 | out.push_back(static_cast<char>(0x80 | ((ch >> 6) & 0x3F))); 329 | out.push_back(static_cast<char>(0x80 | (ch & 0x3F))); 330 | } else { 331 | // Invalid Unicode code point 332 | } 333 | } 334 | 335 | // Helper function to remove the last UTF-8 character from a string 336 | static void pop_back_utf8_char(std::string & line) { 337 | if (line.empty()) { 338 | return; 339 | } 340 | 341 | size_t pos = line.length() - 1; 342 | 343 | // Find the start of the last UTF-8 character (a sequence is at most 4 bytes, so step back over at most 3 continuation bytes) 344 | for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { 345 | if ((line[pos] & 0xC0) != 0x80) { 346 | break; // Found the start of the character 347 | } 348 | } 349 | line.erase(pos); 350 | } 351 | 352 | static bool readline_advanced(std::string & line, bool multiline_input) { 353 | if (out != stdout) { 354 | fflush(stdout); 355 | } 356 | 357 | line.clear(); 358 | std::vector<int> widths; 359 | bool is_special_char = false; 360 | bool end_of_stream = false; 361 | 362 | char32_t input_char; 363 | while (true) { 364 | fflush(out); // Ensure all output is displayed before waiting for input 365 | input_char =
getchar32(); 366 | 367 | if (input_char == '\r' || input_char == '\n') { 368 | break; 369 | } 370 | 371 | if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { 372 | end_of_stream = true; 373 | break; 374 | } 375 | 376 | if (is_special_char) { 377 | set_display(user_input); 378 | replace_last(line.back()); 379 | is_special_char = false; 380 | } 381 | 382 | if (input_char == '\033') { // Escape sequence 383 | char32_t code = getchar32(); 384 | if (code == '[' || code == 0x1B) { 385 | // Discard the rest of the escape sequence 386 | while ((code = getchar32()) != (char32_t) WEOF) { 387 | if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { 388 | break; 389 | } 390 | } 391 | } 392 | } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace (BS or DEL) 393 | if (!widths.empty()) { 394 | int count; 395 | do { 396 | count = widths.back(); 397 | widths.pop_back(); 398 | // Move cursor back, print space, and move cursor back again 399 | for (int i = 0; i < count; i++) { 400 | replace_last(' '); 401 | pop_cursor(); 402 | } 403 | pop_back_utf8_char(line); 404 | } while (count == 0 && !widths.empty()); 405 | } 406 | } else { 407 | int offset = line.length(); 408 | append_utf8(input_char, line); 409 | int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); 410 | if (width < 0) { 411 | width = 0; 412 | } 413 | widths.push_back(width); 414 | } 415 | 416 | if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { 417 | set_display(prompt); 418 | replace_last(line.back()); 419 | is_special_char = true; 420 | } 421 | } 422 | 423 | bool has_more = multiline_input; // return value: whether the caller should keep reading input lines 424 | if (is_special_char) { 425 | replace_last(' '); 426 | pop_cursor(); 427 | 428 | char last = line.back(); 429 | line.pop_back(); 430 | if (last == '\\') { 431 | line += '\n'; 432 | fputc('\n', out); 433 | has_more = !has_more; 434 | } else { 435 | // llama will just eat the single space, it won't act as a space 436 | if
(line.length() == 1 && line.back() == ' ') { 437 | line.clear(); 438 | pop_cursor(); 439 | } 440 | has_more = false; 441 | } 442 | } else { 443 | if (end_of_stream) { 444 | has_more = false; 445 | } else { 446 | line += '\n'; 447 | fputc('\n', out); 448 | } 449 | } 450 | 451 | fflush(out); 452 | return has_more; 453 | } 454 | 455 | static bool readline_simple(std::string & line, bool multiline_input) { // line-buffered reader built on std::getline 456 | #if defined(_WIN32) 457 | std::wstring wline; 458 | if (!std::getline(std::wcin, wline)) { 459 | // Input stream is bad or EOF received 460 | line.clear(); 461 | GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); // send a Ctrl+C event to the console process group on EOF 462 | return false; 463 | } 464 | 465 | int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); 466 | line.resize(size_needed); 467 | WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); 468 | #else 469 | if (!std::getline(std::cin, line)) { 470 | // Input stream is bad or EOF received 471 | line.clear(); 472 | return false; 473 | } 474 | #endif 475 | if (!line.empty()) { 476 | char last = line.back(); 477 | if (last == '/') { // Always return control on '/' symbol 478 | line.pop_back(); 479 | return false; 480 | } 481 | if (last == '\\') { // '\\' changes the default action 482 | line.pop_back(); 483 | multiline_input = !multiline_input; 484 | } 485 | } 486 | line += '\n'; 487 | 488 | // By default, continue input if multiline_input is set (possibly toggled by a trailing backslash above) 489 | return multiline_input; 490 | } 491 | 492 | bool readline(std::string & line, bool multiline_input) { // uses the getline-based reader when simple_io is set, otherwise the key-by-key reader 493 | set_display(user_input); 494 | 495 | if (simple_io) { 496 | return readline_simple(line, multiline_input); 497 | } 498 | return readline_advanced(line, multiline_input); 499 | } 500 | 501 | } 502 | -------------------------------------------------------------------------------- /llamaCpp/common/grammar-parser.cpp: -------------------------------------------------------------------------------- 1 | #include
"grammar-parser.h" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace grammar_parser { 10 | // NOTE: assumes valid utf8 (but checks for overrun) 11 | // copied from llama.cpp 12 | static std::pair decode_utf8(const char * src) { 13 | static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; 14 | uint8_t first_byte = static_cast(*src); 15 | uint8_t highbits = first_byte >> 4; 16 | int len = lookup[highbits]; 17 | uint8_t mask = (1 << (8 - len)) - 1; 18 | uint32_t value = first_byte & mask; 19 | const char * end = src + len; // may overrun! 20 | const char * pos = src + 1; 21 | for ( ; pos < end && *pos; pos++) { 22 | value = (value << 6) + (static_cast(*pos) & 0x3F); 23 | } 24 | return std::make_pair(value, pos); 25 | } 26 | 27 | static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { 28 | uint32_t next_id = static_cast(state.symbol_ids.size()); 29 | auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); 30 | return result.first->second; 31 | } 32 | 33 | static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { 34 | uint32_t next_id = static_cast(state.symbol_ids.size()); 35 | state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; 36 | return next_id; 37 | } 38 | 39 | static void add_rule( 40 | parse_state & state, 41 | uint32_t rule_id, 42 | const std::vector & rule) { 43 | if (state.rules.size() <= rule_id) { 44 | state.rules.resize(rule_id + 1); 45 | } 46 | state.rules[rule_id] = rule; 47 | } 48 | 49 | static bool is_word_char(char c) { 50 | return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); 51 | } 52 | 53 | static std::pair parse_hex(const char * src, int size) { 54 | const char * pos = src; 55 | const char * end = src + size; 56 | uint32_t value = 0; 57 | for ( ; pos < end && *pos; pos++) { 58 | value <<= 4; 59 | char c = *pos; 60 | if ('a' <= c && 
c <= 'f') { 61 | value += c - 'a' + 10; 62 | } else if ('A' <= c && c <= 'F') { 63 | value += c - 'A' + 10; 64 | } else if ('0' <= c && c <= '9') { 65 | value += c - '0'; 66 | } else { 67 | break; 68 | } 69 | } 70 | if (pos != end) { 71 | throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); 72 | } 73 | return std::make_pair(value, pos); 74 | } 75 | 76 | static const char * parse_space(const char * src, bool newline_ok) { 77 | const char * pos = src; 78 | while (*pos == ' ' || *pos == '\t' || *pos == '#' || 79 | (newline_ok && (*pos == '\r' || *pos == '\n'))) { 80 | if (*pos == '#') { 81 | while (*pos && *pos != '\r' && *pos != '\n') { 82 | pos++; 83 | } 84 | } else { 85 | pos++; 86 | } 87 | } 88 | return pos; 89 | } 90 | 91 | static const char * parse_name(const char * src) { 92 | const char * pos = src; 93 | while (is_word_char(*pos)) { 94 | pos++; 95 | } 96 | if (pos == src) { 97 | throw std::runtime_error(std::string("expecting name at ") + src); 98 | } 99 | return pos; 100 | } 101 | 102 | static std::pair parse_char(const char * src) { 103 | if (*src == '\\') { 104 | switch (src[1]) { 105 | case 'x': return parse_hex(src + 2, 2); 106 | case 'u': return parse_hex(src + 2, 4); 107 | case 'U': return parse_hex(src + 2, 8); 108 | case 't': return std::make_pair('\t', src + 2); 109 | case 'r': return std::make_pair('\r', src + 2); 110 | case 'n': return std::make_pair('\n', src + 2); 111 | case '\\': 112 | case '"': 113 | case '[': 114 | case ']': 115 | return std::make_pair(src[1], src + 2); 116 | default: 117 | throw std::runtime_error(std::string("unknown escape at ") + src); 118 | } 119 | } else if (*src) { 120 | return decode_utf8(src); 121 | } 122 | throw std::runtime_error("unexpected end of input"); 123 | } 124 | 125 | const char * parse_alternates( 126 | parse_state & state, 127 | const char * src, 128 | const std::string & rule_name, 129 | uint32_t rule_id, 130 | bool is_nested); 131 | 132 | static const char * 
parse_sequence( 133 | parse_state & state, 134 | const char * src, 135 | const std::string & rule_name, 136 | std::vector & out_elements, 137 | bool is_nested) { 138 | size_t last_sym_start = out_elements.size(); 139 | const char * pos = src; 140 | while (*pos) { 141 | if (*pos == '"') { // literal string 142 | pos++; 143 | last_sym_start = out_elements.size(); 144 | while (*pos != '"') { 145 | auto char_pair = parse_char(pos); 146 | pos = char_pair.second; 147 | out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); 148 | } 149 | pos = parse_space(pos + 1, is_nested); 150 | } else if (*pos == '[') { // char range(s) 151 | pos++; 152 | enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; 153 | if (*pos == '^') { 154 | pos++; 155 | start_type = LLAMA_GRETYPE_CHAR_NOT; 156 | } 157 | last_sym_start = out_elements.size(); 158 | while (*pos != ']') { 159 | auto char_pair = parse_char(pos); 160 | pos = char_pair.second; 161 | enum llama_gretype type = last_sym_start < out_elements.size() 162 | ? 
LLAMA_GRETYPE_CHAR_ALT 163 | : start_type; 164 | 165 | out_elements.push_back({type, char_pair.first}); 166 | if (pos[0] == '-' && pos[1] != ']') { 167 | auto endchar_pair = parse_char(pos + 1); 168 | pos = endchar_pair.second; 169 | out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); 170 | } 171 | } 172 | pos = parse_space(pos + 1, is_nested); 173 | } else if (is_word_char(*pos)) { // rule reference 174 | const char * name_end = parse_name(pos); 175 | uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); 176 | pos = parse_space(name_end, is_nested); 177 | last_sym_start = out_elements.size(); 178 | out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); 179 | } else if (*pos == '(') { // grouping 180 | // parse nested alternates into synthesized rule 181 | pos = parse_space(pos + 1, true); 182 | uint32_t sub_rule_id = generate_symbol_id(state, rule_name); 183 | pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); 184 | last_sym_start = out_elements.size(); 185 | // output reference to synthesized rule 186 | out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); 187 | if (*pos != ')') { 188 | throw std::runtime_error(std::string("expecting ')' at ") + pos); 189 | } 190 | pos = parse_space(pos + 1, is_nested); 191 | } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator 192 | if (last_sym_start == out_elements.size()) { 193 | throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); 194 | } 195 | 196 | // apply transformation to previous symbol (last_sym_start to end) according to 197 | // rewrite rules: 198 | // S* --> S' ::= S S' | 199 | // S+ --> S' ::= S S' | S 200 | // S? 
--> S' ::= S | 201 | uint32_t sub_rule_id = generate_symbol_id(state, rule_name); 202 | std::vector sub_rule; 203 | // add preceding symbol to generated rule 204 | sub_rule.insert( 205 | sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); 206 | if (*pos == '*' || *pos == '+') { 207 | // cause generated rule to recurse 208 | sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); 209 | } 210 | // mark start of alternate def 211 | sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); 212 | if (*pos == '+') { 213 | // add preceding symbol as alternate only for '+' (otherwise empty) 214 | sub_rule.insert( 215 | sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); 216 | } 217 | sub_rule.push_back({LLAMA_GRETYPE_END, 0}); 218 | add_rule(state, sub_rule_id, sub_rule); 219 | 220 | // in original rule, replace previous symbol with reference to generated rule 221 | out_elements.resize(last_sym_start); 222 | out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); 223 | 224 | pos = parse_space(pos + 1, is_nested); 225 | } else { 226 | break; 227 | } 228 | } 229 | return pos; 230 | } 231 | 232 | const char * parse_alternates( 233 | parse_state & state, 234 | const char * src, 235 | const std::string & rule_name, 236 | uint32_t rule_id, 237 | bool is_nested) { 238 | std::vector rule; 239 | const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); 240 | while (*pos == '|') { 241 | rule.push_back({LLAMA_GRETYPE_ALT, 0}); 242 | pos = parse_space(pos + 1, true); 243 | pos = parse_sequence(state, pos, rule_name, rule, is_nested); 244 | } 245 | rule.push_back({LLAMA_GRETYPE_END, 0}); 246 | add_rule(state, rule_id, rule); 247 | return pos; 248 | } 249 | 250 | static const char * parse_rule(parse_state & state, const char * src) { 251 | const char * name_end = parse_name(src); 252 | const char * pos = parse_space(name_end, false); 253 | size_t name_len = name_end - src; 254 | uint32_t rule_id = get_symbol_id(state, src, 
name_len); 255 | const std::string name(src, name_len); 256 | 257 | if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { 258 | throw std::runtime_error(std::string("expecting ::= at ") + pos); 259 | } 260 | pos = parse_space(pos + 3, true); 261 | 262 | pos = parse_alternates(state, pos, name, rule_id, false); 263 | 264 | if (*pos == '\r') { 265 | pos += pos[1] == '\n' ? 2 : 1; 266 | } else if (*pos == '\n') { 267 | pos++; 268 | } else if (*pos) { 269 | throw std::runtime_error(std::string("expecting newline or end at ") + pos); 270 | } 271 | return parse_space(pos, true); 272 | } 273 | 274 | parse_state parse(const char * src) { 275 | try { 276 | parse_state state; 277 | const char * pos = parse_space(src, true); 278 | while (*pos) { 279 | pos = parse_rule(state, pos); 280 | } 281 | return state; 282 | } catch (const std::exception & err) { 283 | fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); 284 | return parse_state(); 285 | } 286 | } 287 | 288 | static void print_grammar_char(FILE * file, uint32_t c) { 289 | if (0x20 <= c && c <= 0x7f) { 290 | fprintf(file, "%c", static_cast(c)); 291 | } else { 292 | // cop out of encoding UTF-8 293 | fprintf(file, "", c); 294 | } 295 | } 296 | 297 | static bool is_char_element(llama_grammar_element elem) { 298 | switch (elem.type) { 299 | case LLAMA_GRETYPE_CHAR: return true; 300 | case LLAMA_GRETYPE_CHAR_NOT: return true; 301 | case LLAMA_GRETYPE_CHAR_ALT: return true; 302 | case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; 303 | default: return false; 304 | } 305 | } 306 | 307 | static void print_rule_binary(FILE * file, const std::vector & rule) { 308 | for (auto elem : rule) { 309 | switch (elem.type) { 310 | case LLAMA_GRETYPE_END: fprintf(file, "END"); break; 311 | case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; 312 | case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; 313 | case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; 314 | case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, 
"CHAR_NOT"); break; 315 | case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; 316 | case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; 317 | } 318 | switch (elem.type) { 319 | case LLAMA_GRETYPE_END: 320 | case LLAMA_GRETYPE_ALT: 321 | case LLAMA_GRETYPE_RULE_REF: 322 | fprintf(file, "(%u) ", elem.value); 323 | break; 324 | case LLAMA_GRETYPE_CHAR: 325 | case LLAMA_GRETYPE_CHAR_NOT: 326 | case LLAMA_GRETYPE_CHAR_RNG_UPPER: 327 | case LLAMA_GRETYPE_CHAR_ALT: 328 | fprintf(file, "(\""); 329 | print_grammar_char(file, elem.value); 330 | fprintf(file, "\") "); 331 | break; 332 | } 333 | } 334 | fprintf(file, "\n"); 335 | } 336 | 337 | static void print_rule( 338 | FILE * file, 339 | uint32_t rule_id, 340 | const std::vector & rule, 341 | const std::map & symbol_id_names) { 342 | if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { 343 | throw std::runtime_error( 344 | "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); 345 | } 346 | fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); 347 | for (size_t i = 0, end = rule.size() - 1; i < end; i++) { 348 | llama_grammar_element elem = rule[i]; 349 | switch (elem.type) { 350 | case LLAMA_GRETYPE_END: 351 | throw std::runtime_error( 352 | "unexpected end of rule: " + std::to_string(rule_id) + "," + 353 | std::to_string(i)); 354 | case LLAMA_GRETYPE_ALT: 355 | fprintf(file, "| "); 356 | break; 357 | case LLAMA_GRETYPE_RULE_REF: 358 | fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); 359 | break; 360 | case LLAMA_GRETYPE_CHAR: 361 | fprintf(file, "["); 362 | print_grammar_char(file, elem.value); 363 | break; 364 | case LLAMA_GRETYPE_CHAR_NOT: 365 | fprintf(file, "[^"); 366 | print_grammar_char(file, elem.value); 367 | break; 368 | case LLAMA_GRETYPE_CHAR_RNG_UPPER: 369 | if (i == 0 || !is_char_element(rule[i - 1])) { 370 | throw std::runtime_error( 371 | "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + 372 | 
std::to_string(rule_id) + "," + std::to_string(i)); 373 | } 374 | fprintf(file, "-"); 375 | print_grammar_char(file, elem.value); 376 | break; 377 | case LLAMA_GRETYPE_CHAR_ALT: 378 | if (i == 0 || !is_char_element(rule[i - 1])) { 379 | throw std::runtime_error( 380 | "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + 381 | std::to_string(rule_id) + "," + std::to_string(i)); 382 | } 383 | print_grammar_char(file, elem.value); 384 | break; 385 | } 386 | if (is_char_element(elem)) { 387 | switch (rule[i + 1].type) { 388 | case LLAMA_GRETYPE_CHAR_ALT: 389 | case LLAMA_GRETYPE_CHAR_RNG_UPPER: 390 | break; 391 | default: 392 | fprintf(file, "] "); 393 | } 394 | } 395 | } 396 | fprintf(file, "\n"); 397 | } 398 | 399 | void print_grammar(FILE * file, const parse_state & state) { 400 | try { 401 | std::map symbol_id_names; 402 | for (const auto & kv : state.symbol_ids) { 403 | symbol_id_names[kv.second] = kv.first; 404 | } 405 | for (size_t i = 0, end = state.rules.size(); i < end; i++) { 406 | // fprintf(file, "%zu: ", i); 407 | // print_rule_binary(file, state.rules[i]); 408 | print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); 409 | // fprintf(file, "\n"); 410 | } 411 | } catch (const std::exception & err) { 412 | fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); 413 | } 414 | } 415 | 416 | std::vector parse_state::c_rules() { 417 | std::vector ret; 418 | ret.reserve(rules.size()); 419 | for (const auto & rule : rules) { 420 | ret.push_back(rule.data()); 421 | } 422 | return ret; 423 | } 424 | } 425 | -------------------------------------------------------------------------------- /llamaCpp/common/log.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | // -------------------------------- 13 | // 14 | // Basic usage: 15 | // 16 | // -------- 17 | // 18 | // The LOG() 
and LOG_TEE() macros are ready to go by default 19 | // they do not require any initialization. 20 | // 21 | // LOGLN() and LOG_TEELN() are variants which automatically 22 | // include \n character at the end of the log string. 23 | // 24 | // LOG() behaves exactly like printf, by default writing to a logfile. 25 | // LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ). 26 | // 27 | // Default logfile is named 28 | // "llama..log" 29 | // Default LOG_TEE() secondary output target is 30 | // stderr 31 | // 32 | // Logs can be dynamically disabled or enabled using functions: 33 | // log_disable() 34 | // and 35 | // log_enable() 36 | // 37 | // A log target can be changed with: 38 | // log_set_target( string ) 39 | // creating and opening, or re-opening a file by string filename 40 | // or 41 | // log_set_target( FILE* ) 42 | // allowing to point at stderr, stdout, or any valid FILE* file handler. 43 | // 44 | // -------- 45 | // 46 | // End of Basic usage. 47 | // 48 | // -------------------------------- 49 | 50 | // Specifies a log target. 51 | // default uses log_handler() with "llama.log" log file 52 | // this can be changed, by defining LOG_TARGET 53 | // like so: 54 | // 55 | // #define LOG_TARGET (a valid FILE*) 56 | // #include "log.h" 57 | // 58 | // or it can be simply redirected to stdout or stderr 59 | // like so: 60 | // 61 | // #define LOG_TARGET stderr 62 | // #include "log.h" 63 | // 64 | // The log target can also be redirected to a different function 65 | // like so: 66 | // 67 | // #define LOG_TARGET log_handler_different() 68 | // #include "log.h" 69 | // 70 | // FILE* log_handler_different() 71 | // { 72 | // return stderr; 73 | // } 74 | // 75 | // or: 76 | // 77 | // #define LOG_TARGET log_handler_another_one("somelog.log") 78 | // #include "log.h" 79 | // 80 | // FILE* log_handler_another_one(char*filename) 81 | // { 82 | // static FILE* logfile = nullptr; 83 | // (...) 84 | // if( !logfile ) 85 | // { 86 | // fopen(...)
//     }
//     (...)
//     return logfile
// }
//
#ifndef LOG_TARGET
    #define LOG_TARGET log_handler()
#endif

#ifndef LOG_TEE_TARGET
    #define LOG_TEE_TARGET stderr
#endif

// Utility for synchronizing log configuration state
//  since std::optional was introduced only in c++17
enum LogTriState
{
    LogTriStateSame,
    LogTriStateFalse,
    LogTriStateTrue
};

// Utility to obtain "pid" like unique process id and use it when creating log files.
// The value is computed on the first call and the same string is returned afterwards.
inline std::string log_get_pid()
{
    // std::this_thread::get_id() is the most portable way of obtaining a "process id"
    //  it's not the same as "pid" but is unique enough to solve multiple instances
    //  trying to write to the same log.
    // A function-local static initialized from a lambda is set up exactly once, and
    //  C++11 guarantees that initialization is thread-safe. The previous
    //  "if (pid.empty()) pid = ..." on a mutable static raced if two threads made
    //  their first call concurrently.
    static const std::string pid = []
    {
        std::stringstream ss;
        ss << std::this_thread::get_id();
        return ss.str();
    }();

    return pid;
}

// Utility function for generating log file names with unique id based on thread id.
//  invocation with log_filename_generator( "llama", "log" ) creates a string "llama.<number>.log"
//  (when multi-log mode is enabled) where the number is the runtime id of the thread that
//  first called log_get_pid().
129 | 130 | #define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(LogTriStateSame, log_file_basename, log_file_extension) 131 | 132 | // INTERNAL, DO NOT USE 133 | inline std::string log_filename_generator_impl(LogTriState multilog, const std::string & log_file_basename, const std::string & log_file_extension) 134 | { 135 | static bool _multilog = false; 136 | 137 | if (multilog != LogTriStateSame) 138 | { 139 | _multilog = multilog == LogTriStateTrue; 140 | } 141 | 142 | std::stringstream buf; 143 | 144 | buf << log_file_basename; 145 | if (_multilog) 146 | { 147 | buf << "."; 148 | buf << log_get_pid(); 149 | } 150 | buf << "."; 151 | buf << log_file_extension; 152 | 153 | return buf.str(); 154 | } 155 | 156 | #ifndef LOG_DEFAULT_FILE_NAME 157 | #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log") 158 | #endif 159 | 160 | // Utility for turning #define values into string literals 161 | // so we can have a define for stderr and 162 | // we can print "stderr" instead of literal stderr, etc. 163 | #define LOG_STRINGIZE1(s) #s 164 | #define LOG_STRINGIZE(s) LOG_STRINGIZE1(s) 165 | 166 | #define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET) 167 | 168 | // Allows disabling timestamps. 
169 | // in order to disable, define LOG_NO_TIMESTAMPS 170 | // like so: 171 | // 172 | // #define LOG_NO_TIMESTAMPS 173 | // #include "log.h" 174 | // 175 | #ifndef LOG_NO_TIMESTAMPS 176 | #ifndef _MSC_VER 177 | #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " 178 | #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() 179 | #else 180 | #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " 181 | #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() 182 | #endif 183 | #else 184 | #define LOG_TIMESTAMP_FMT "%s" 185 | #define LOG_TIMESTAMP_VAL ,"" 186 | #endif 187 | 188 | #ifdef LOG_TEE_TIMESTAMPS 189 | #ifndef _MSC_VER 190 | #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " 191 | #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() 192 | #else 193 | #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " 194 | #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() 195 | #endif 196 | #else 197 | #define LOG_TEE_TIMESTAMP_FMT "%s" 198 | #define LOG_TEE_TIMESTAMP_VAL ,"" 199 | #endif 200 | 201 | // Allows disabling file/line/function prefix 202 | // in order to disable, define LOG_NO_FILE_LINE_FUNCTION 203 | // like so: 204 | // 205 | // #define LOG_NO_FILE_LINE_FUNCTION 206 | // #include "log.h" 207 | // 208 | #ifndef LOG_NO_FILE_LINE_FUNCTION 209 | #ifndef _MSC_VER 210 | #define LOG_FLF_FMT "[%24s:%5d][%24s] " 211 | #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ 212 | #else 213 | #define LOG_FLF_FMT "[%24s:%5ld][%24s] " 214 | #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ 215 | #endif 216 | #else 217 | #define LOG_FLF_FMT "%s" 218 | #define LOG_FLF_VAL ,"" 219 | #endif 220 | 221 | #ifdef LOG_TEE_FILE_LINE_FUNCTION 222 | #ifndef _MSC_VER 223 | #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] " 224 | #define LOG_TEE_FLF_VAL , 
__FILE__, __LINE__, __FUNCTION__ 225 | #else 226 | #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] " 227 | #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ 228 | #endif 229 | #else 230 | #define LOG_TEE_FLF_FMT "%s" 231 | #define LOG_TEE_FLF_VAL ,"" 232 | #endif 233 | 234 | // INTERNAL, DO NOT USE 235 | // USE LOG() INSTEAD 236 | // 237 | #ifndef _MSC_VER 238 | #define LOG_IMPL(str, ...) \ 239 | do { \ 240 | if (LOG_TARGET != nullptr) \ 241 | { \ 242 | fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ 243 | fflush(LOG_TARGET); \ 244 | } \ 245 | } while (0) 246 | #else 247 | #define LOG_IMPL(str, ...) \ 248 | do { \ 249 | if (LOG_TARGET != nullptr) \ 250 | { \ 251 | fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ 252 | fflush(LOG_TARGET); \ 253 | } \ 254 | } while (0) 255 | #endif 256 | 257 | // INTERNAL, DO NOT USE 258 | // USE LOG_TEE() INSTEAD 259 | // 260 | #ifndef _MSC_VER 261 | #define LOG_TEE_IMPL(str, ...) \ 262 | do { \ 263 | if (LOG_TARGET != nullptr) \ 264 | { \ 265 | fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ 266 | fflush(LOG_TARGET); \ 267 | } \ 268 | if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ 269 | { \ 270 | fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ 271 | fflush(LOG_TEE_TARGET); \ 272 | } \ 273 | } while (0) 274 | #else 275 | #define LOG_TEE_IMPL(str, ...) 
\ 276 | do { \ 277 | if (LOG_TARGET != nullptr) \ 278 | { \ 279 | fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ 280 | fflush(LOG_TARGET); \ 281 | } \ 282 | if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ 283 | { \ 284 | fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ 285 | fflush(LOG_TEE_TARGET); \ 286 | } \ 287 | } while (0) 288 | #endif 289 | 290 | // The '\0' as a last argument, is a trick to bypass the silly 291 | // "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro" 292 | // so we can have a single macro which can be called just like printf. 293 | 294 | // Main LOG macro. 295 | // behaves like printf, and supports arguments the exact same way. 296 | // 297 | #ifndef _MSC_VER 298 | #define LOG(...) LOG_IMPL(__VA_ARGS__, "") 299 | #else 300 | #define LOG(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "") 301 | #endif 302 | 303 | // Main TEE macro. 304 | // does the same as LOG 305 | // and 306 | // simultaneously writes stderr. 307 | // 308 | // Secondary target can be changed just like LOG_TARGET 309 | // by defining LOG_TEE_TARGET 310 | // 311 | #ifndef _MSC_VER 312 | #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "") 313 | #else 314 | #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "") 315 | #endif 316 | 317 | // LOG macro variants with auto endline. 318 | #ifndef _MSC_VER 319 | #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n") 320 | #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n") 321 | #else 322 | #define LOGLN(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "\n") 323 | #define LOG_TEELN(str, ...) 
LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "\n") 324 | #endif 325 | 326 | // INTERNAL, DO NOT USE 327 | inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr) 328 | { 329 | static bool _initialized = false; 330 | static bool _append = false; 331 | static bool _disabled = filename.empty() && target == nullptr; 332 | static std::string log_current_filename{filename}; 333 | static FILE *log_current_target{target}; 334 | static FILE *logfile = nullptr; 335 | 336 | if (change) 337 | { 338 | if (append != LogTriStateSame) 339 | { 340 | _append = append == LogTriStateTrue; 341 | return logfile; 342 | } 343 | 344 | if (disable == LogTriStateTrue) 345 | { 346 | // Disable primary target 347 | _disabled = true; 348 | } 349 | // If previously disabled, only enable, and keep previous target 350 | else if (disable == LogTriStateFalse) 351 | { 352 | _disabled = false; 353 | } 354 | // Otherwise, process the arguments 355 | else if (log_current_filename != filename || log_current_target != target) 356 | { 357 | _initialized = false; 358 | } 359 | } 360 | 361 | if (_disabled) 362 | { 363 | // Log is disabled 364 | return nullptr; 365 | } 366 | 367 | if (_initialized) 368 | { 369 | // with fallback in case something went wrong 370 | return logfile ? logfile : stderr; 371 | } 372 | 373 | // do the (re)initialization 374 | if (target != nullptr) 375 | { 376 | if (logfile != nullptr && logfile != stdout && logfile != stderr) 377 | { 378 | fclose(logfile); 379 | } 380 | 381 | log_current_filename = LOG_DEFAULT_FILE_NAME; 382 | log_current_target = target; 383 | 384 | logfile = target; 385 | } 386 | else 387 | { 388 | if (log_current_filename != filename) 389 | { 390 | if (logfile != nullptr && logfile != stdout && logfile != stderr) 391 | { 392 | fclose(logfile); 393 | } 394 | } 395 | 396 | logfile = fopen(filename.c_str(), _append ? 
"a" : "w"); 397 | } 398 | 399 | if (!logfile) 400 | { 401 | // Verify whether the file was opened, otherwise fallback to stderr 402 | logfile = stderr; 403 | 404 | fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno)); 405 | fflush(stderr); 406 | 407 | // At this point we let the init flag be to true below, and let the target fallback to stderr 408 | // otherwise we would repeatedly fopen() which was already unsuccessful 409 | } 410 | 411 | _initialized = true; 412 | 413 | return logfile ? logfile : stderr; 414 | } 415 | 416 | // INTERNAL, DO NOT USE 417 | inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME) 418 | { 419 | return log_handler1_impl(change, append, disable, filename, target); 420 | } 421 | 422 | // Disables logs entirely at runtime. 423 | // Makes LOG() and LOG_TEE() produce no output, 424 | // untill enabled back. 425 | #define log_disable() log_disable_impl() 426 | 427 | // INTERNAL, DO NOT USE 428 | inline FILE *log_disable_impl() 429 | { 430 | return log_handler1_impl(true, LogTriStateSame, LogTriStateTrue); 431 | } 432 | 433 | // Enables logs at runtime. 
434 | #define log_enable() log_enable_impl() 435 | 436 | // INTERNAL, DO NOT USE 437 | inline FILE *log_enable_impl() 438 | { 439 | return log_handler1_impl(true, LogTriStateSame, LogTriStateFalse); 440 | } 441 | 442 | // Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*) 443 | #define log_set_target(target) log_set_target_impl(target) 444 | 445 | // INTERNAL, DO NOT USE 446 | inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, LogTriStateSame, filename); } 447 | inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, LogTriStateSame, target); } 448 | 449 | // INTERNAL, DO NOT USE 450 | inline FILE *log_handler() { return log_handler1_impl(); } 451 | 452 | // Enable or disable creating separate log files for each run. 453 | // can ONLY be invoked BEFORE first log use. 454 | #define log_multilog(enable) log_filename_generator_impl((enable) ? LogTriStateTrue : LogTriStateFalse, "", "") 455 | // Enable or disable append mode for log file. 456 | // can ONLY be invoked BEFORE first log use. 457 | #define log_append(enable) log_append_impl(enable) 458 | // INTERNAL, DO NOT USE 459 | inline FILE *log_append_impl(bool enable) 460 | { 461 | return log_handler1_impl(true, enable ? LogTriStateTrue : LogTriStateFalse, LogTriStateSame); 462 | } 463 | 464 | inline void log_test() 465 | { 466 | log_disable(); 467 | LOG("01 Hello World to nobody, because logs are disabled!\n"); 468 | log_enable(); 469 | LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! 
)!\n", LOG_STRINGIZE(LOG_TARGET)); 470 | LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n"); 471 | log_set_target(stderr); 472 | LOG("04 Hello World to stderr!\n"); 473 | LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n"); 474 | log_set_target(LOG_DEFAULT_FILE_NAME); 475 | LOG("06 Hello World to default log file!\n"); 476 | log_set_target(stdout); 477 | LOG("07 Hello World to stdout!\n"); 478 | log_set_target(LOG_DEFAULT_FILE_NAME); 479 | LOG("08 Hello World to default log file again!\n"); 480 | log_disable(); 481 | LOG("09 Hello World _1_ into the void!\n"); 482 | log_enable(); 483 | LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n"); 484 | log_disable(); 485 | log_set_target("llama.anotherlog.log"); 486 | LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n"); 487 | log_enable(); 488 | LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n"); 489 | log_set_target("llama.yetanotherlog.log"); 490 | LOG("13 Hello World this time in yet new file?\n"); 491 | log_set_target(log_filename_generator("llama_autonamed", "log")); 492 | LOG("14 Hello World in log with generated filename!\n"); 493 | #ifdef _MSC_VER 494 | LOG_TEE("15 Hello msvc TEE without arguments\n"); 495 | LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test"); 496 | LOG_TEELN("17 Hello msvc TEELN without arguments\n"); 497 | LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test"); 498 | LOG("19 Hello msvc LOG without arguments\n"); 499 | LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test"); 500 | LOGLN("21 Hello msvc LOGLN without arguments\n"); 501 | LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test"); 502 | #endif 503 | } 504 | 505 | inline bool log_param_single_parse(const std::string & param) 506 | { 507 | if ( param == "--log-test") 508 | { 509 | log_test(); 510 | return 
true; 511 | } 512 | 513 | if ( param == "--log-disable") 514 | { 515 | log_disable(); 516 | return true; 517 | } 518 | 519 | if ( param == "--log-enable") 520 | { 521 | log_enable(); 522 | return true; 523 | } 524 | 525 | if (param == "--log-new") 526 | { 527 | log_multilog(true); 528 | return true; 529 | } 530 | 531 | if (param == "--log-append") 532 | { 533 | log_append(true); 534 | return true; 535 | } 536 | 537 | return false; 538 | } 539 | 540 | inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string()) 541 | { 542 | if ( param == "--log-file") 543 | { 544 | if (!check_but_dont_parse) 545 | { 546 | log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log")); 547 | } 548 | 549 | return true; 550 | } 551 | 552 | return false; 553 | } 554 | 555 | inline void log_print_usage() 556 | { 557 | printf("log options:\n"); 558 | /* format 559 | printf(" -h, --help show this help message and exit\n");*/ 560 | /* spacing 561 | printf("__-param----------------Description\n");*/ 562 | printf(" --log-test Run simple logging test\n"); 563 | printf(" --log-disable Disable trace logs\n"); 564 | printf(" --log-enable Enable trace logs\n"); 565 | printf(" --log-file Specify a log filename (without extension)\n"); 566 | printf(" --log-new Create a separate new log file on start. 
" 567 | "Each log file will have unique name: \"..log\"\n"); 568 | printf(" --log-append Don't truncate the old log file.\n"); 569 | } 570 | 571 | #define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv) 572 | 573 | // INTERNAL, DO NOT USE 574 | inline void log_dump_cmdline_impl(int argc, char **argv) 575 | { 576 | std::stringstream buf; 577 | for (int i = 0; i < argc; ++i) 578 | { 579 | if (std::string(argv[i]).find(' ') != std::string::npos) 580 | { 581 | buf << " \"" << argv[i] <<"\""; 582 | } 583 | else 584 | { 585 | buf << " " << argv[i]; 586 | } 587 | } 588 | LOGLN("Cmd:%s", buf.str().c_str()); 589 | } 590 | 591 | #define log_tostr(var) log_var_to_string_impl(var).c_str() 592 | 593 | inline std::string log_var_to_string_impl(bool var) 594 | { 595 | return var ? "true" : "false"; 596 | } 597 | 598 | inline std::string log_var_to_string_impl(std::string var) 599 | { 600 | return var; 601 | } 602 | 603 | inline std::string log_var_to_string_impl(const std::vector & var) 604 | { 605 | std::stringstream buf; 606 | buf << "[ "; 607 | bool first = true; 608 | for (auto e : var) 609 | { 610 | if (first) 611 | { 612 | first = false; 613 | } 614 | else 615 | { 616 | buf << ", "; 617 | } 618 | buf << std::to_string(e); 619 | } 620 | buf << " ]"; 621 | 622 | return buf.str(); 623 | } 624 | 625 | template 626 | inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens) 627 | { 628 | std::stringstream buf; 629 | buf << "[ "; 630 | 631 | bool first = true; 632 | for (const auto &token : tokens) 633 | { 634 | if (!first) { 635 | buf << ", "; 636 | } else { 637 | first = false; 638 | } 639 | 640 | auto detokenized = llama_token_to_piece(ctx, token); 641 | 642 | detokenized.erase( 643 | std::remove_if( 644 | detokenized.begin(), 645 | detokenized.end(), 646 | [](const unsigned char c) { return !std::isprint(c); }), 647 | detokenized.end()); 648 | 649 | buf 650 | << "'" << detokenized << "'" 651 | << ":" << std::to_string(token); 652 | } 653 | buf 
<< " ]"; 654 | 655 | return buf.str(); 656 | } 657 | 658 | template 659 | inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) 660 | { 661 | std::stringstream buf; 662 | buf << "[ "; 663 | 664 | bool first = true; 665 | for (int i = 0; i < batch.n_tokens; ++i) 666 | { 667 | if (!first) { 668 | buf << ", "; 669 | } else { 670 | first = false; 671 | } 672 | 673 | auto detokenized = llama_token_to_piece(ctx, batch.token[i]); 674 | 675 | detokenized.erase( 676 | std::remove_if( 677 | detokenized.begin(), 678 | detokenized.end(), 679 | [](const unsigned char c) { return !std::isprint(c); }), 680 | detokenized.end()); 681 | 682 | buf 683 | << "\n" << std::to_string(i) 684 | << ":token '" << detokenized << "'" 685 | << ":pos " << std::to_string(batch.pos[i]) 686 | << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) 687 | << ":seq_id " << std::to_string(batch.seq_id[i][0]) 688 | << ":logits " << std::to_string(batch.logits[i]); 689 | } 690 | buf << " ]"; 691 | 692 | return buf.str(); 693 | } 694 | 695 | #ifdef LOG_DISABLE_LOGS 696 | 697 | #undef LOG 698 | #define LOG(...) // dummy stub 699 | #undef LOGLN 700 | #define LOGLN(...) // dummy stub 701 | 702 | #undef LOG_TEE 703 | #define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf 704 | 705 | #undef LOG_TEELN 706 | #define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf 707 | 708 | #undef LOG_DISABLE 709 | #define LOG_DISABLE() // dummy stub 710 | 711 | #undef LOG_ENABLE 712 | #define LOG_ENABLE() // dummy stub 713 | 714 | #undef LOG_ENABLE 715 | #define LOG_ENABLE() // dummy stub 716 | 717 | #undef LOG_SET_TARGET 718 | #define LOG_SET_TARGET(...) // dummy stub 719 | 720 | #undef LOG_DUMP_CMDLINE 721 | #define LOG_DUMP_CMDLINE(...) 
// dummy stub 722 | 723 | #endif // LOG_DISABLE_LOGS 724 | -------------------------------------------------------------------------------- /llamaCpp/ggml-alloc.c: -------------------------------------------------------------------------------- 1 | #include "ggml-alloc.h" 2 | #include "ggml-backend-impl.h" 3 | #include "ggml.h" 4 | #include "ggml-impl.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #define MAX(a, b) ((a) > (b) ? (a) : (b)) 13 | #define MAX_FREE_BLOCKS 256 14 | 15 | //#define GGML_ALLOCATOR_DEBUG 16 | 17 | //#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__) 18 | #define AT_PRINTF(...) 19 | 20 | // TODO: GGML_PAD ? 21 | static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { 22 | assert(alignment && !(alignment & (alignment - 1))); // power of 2 23 | size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; 24 | return offset + align; 25 | } 26 | 27 | struct free_block { 28 | void * addr; 29 | size_t size; 30 | }; 31 | 32 | struct ggml_tallocr { 33 | struct ggml_backend_buffer * buffer; 34 | bool buffer_owned; 35 | void * base; 36 | size_t alignment; 37 | 38 | int n_free_blocks; 39 | struct free_block free_blocks[MAX_FREE_BLOCKS]; 40 | 41 | size_t max_size; 42 | 43 | bool measure; 44 | 45 | #ifdef GGML_ALLOCATOR_DEBUG 46 | struct ggml_tensor * allocated_tensors[1024]; 47 | #endif 48 | }; 49 | 50 | #ifdef GGML_ALLOCATOR_DEBUG 51 | static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { 52 | for (int i = 0; i < 1024; i++) { 53 | if (alloc->allocated_tensors[i] == NULL) { 54 | alloc->allocated_tensors[i] = tensor; 55 | return; 56 | } 57 | } 58 | GGML_ASSERT(!"out of allocated_tensors"); 59 | } 60 | static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { 61 | for (int i = 0; i < 1024; i++) { 62 | if (alloc->allocated_tensors[i] == tensor || 63 | (alloc->allocated_tensors[i] != NULL && 
alloc->allocated_tensors[i]->data == tensor->data)) { 64 | alloc->allocated_tensors[i] = NULL; 65 | return; 66 | } 67 | } 68 | printf("tried to free tensor %s not found\n", tensor->name); 69 | GGML_ASSERT(!"tensor not found"); 70 | } 71 | #endif 72 | 73 | // check if a tensor is allocated by this buffer 74 | static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) { 75 | return tensor->buffer == alloc->buffer; 76 | } 77 | 78 | static bool ggml_is_view(struct ggml_tensor * t) { 79 | return t->view_src != NULL; 80 | } 81 | 82 | void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { 83 | GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources 84 | GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated 85 | 86 | size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); 87 | size = aligned_offset(NULL, size, alloc->alignment); 88 | 89 | AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); 90 | 91 | size_t max_avail = 0; 92 | 93 | // find the best fitting free block besides the last block 94 | int best_fit_block = -1; 95 | size_t best_fit_size = SIZE_MAX; 96 | for (int i = 0; i < alloc->n_free_blocks - 1; i++) { 97 | struct free_block * block = &alloc->free_blocks[i]; 98 | max_avail = MAX(max_avail, block->size); 99 | if (block->size >= size && block->size <= best_fit_size) { 100 | best_fit_block = i; 101 | best_fit_size = block->size; 102 | } 103 | } 104 | 105 | AT_PRINTF("block %d\n", best_fit_block); 106 | 107 | if (best_fit_block == -1) { 108 | // the last block is our last resort 109 | struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; 110 | max_avail = MAX(max_avail, block->size); 111 | if (block->size >= size) { 112 | best_fit_block = alloc->n_free_blocks - 1; 113 | } else { 114 | fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block 
available %zu)\n", 115 | __func__, size, max_avail); 116 | GGML_ASSERT(!"not enough space in the buffer"); 117 | return; 118 | } 119 | } 120 | struct free_block * block = &alloc->free_blocks[best_fit_block]; 121 | void * addr = block->addr; 122 | block->addr = (char*)block->addr + size; 123 | block->size -= size; 124 | if (block->size == 0) { 125 | // remove block if empty 126 | alloc->n_free_blocks--; 127 | for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { 128 | alloc->free_blocks[j] = alloc->free_blocks[j+1]; 129 | } 130 | } 131 | 132 | tensor->data = addr; 133 | tensor->buffer = alloc->buffer; 134 | if (!alloc->measure) { 135 | ggml_backend_buffer_init_tensor(alloc->buffer, tensor); 136 | } 137 | 138 | #ifdef GGML_ALLOCATOR_DEBUG 139 | add_allocated_tensor(alloc, tensor); 140 | size_t cur_max = (char*)addr - (char*)alloc->data + size; 141 | if (cur_max > alloc->max_size) { 142 | printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); 143 | for (int i = 0; i < 1024; i++) { 144 | if (alloc->allocated_tensors[i]) { 145 | printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0); 146 | } 147 | } 148 | printf("\n"); 149 | } 150 | #endif 151 | 152 | alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size); 153 | } 154 | 155 | // this is a very naive implementation, but for our case the number of free blocks should be very small 156 | static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) { 157 | if (ggml_tallocr_is_own(alloc, tensor) == false) { 158 | // the tensor was not allocated in this buffer 159 | // this can happen because the graph allocator will try to free weights and other tensors from different buffers 160 | // the easiest way to deal with this is just to ignore it 161 | // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer); 162 | return; 163 | } 
164 | 165 | void * ptr = tensor->data; 166 | 167 | size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); 168 | size = aligned_offset(NULL, size, alloc->alignment); 169 | AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); 170 | 171 | if (!alloc->measure) { 172 | ggml_backend_buffer_free_tensor(alloc->buffer, tensor); 173 | } 174 | 175 | #ifdef GGML_ALLOCATOR_DEBUG 176 | remove_allocated_tensor(alloc, tensor); 177 | #endif 178 | 179 | // see if we can merge with an existing block 180 | for (int i = 0; i < alloc->n_free_blocks; i++) { 181 | struct free_block * block = &alloc->free_blocks[i]; 182 | // check if ptr is at the end of the block 183 | if ((char*)block->addr + block->size == ptr) { 184 | block->size += size; 185 | // check if we can merge with the next block 186 | if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) { 187 | block->size += alloc->free_blocks[i+1].size; 188 | alloc->n_free_blocks--; 189 | for (int j = i+1; j < alloc->n_free_blocks; j++) { 190 | alloc->free_blocks[j] = alloc->free_blocks[j+1]; 191 | } 192 | } 193 | return; 194 | } 195 | // check if ptr is at the beginning of the block 196 | if ((char*)ptr + size == block->addr) { 197 | block->addr = ptr; 198 | block->size += size; 199 | // check if we can merge with the previous block 200 | if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) { 201 | alloc->free_blocks[i-1].size += block->size; 202 | alloc->n_free_blocks--; 203 | for (int j = i; j < alloc->n_free_blocks; j++) { 204 | alloc->free_blocks[j] = alloc->free_blocks[j+1]; 205 | } 206 | } 207 | return; 208 | } 209 | } 210 | // otherwise, add a new block 211 | GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); 212 | // insert the new block in the correct position to keep the array sorted by address (to make merging blocks 
faster) 213 | int insert_pos = 0; 214 | while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) { 215 | insert_pos++; 216 | } 217 | // shift all blocks from insert_pos onward to make room for the new block 218 | for (int i = alloc->n_free_blocks; i > insert_pos; i--) { 219 | alloc->free_blocks[i] = alloc->free_blocks[i-1]; 220 | } 221 | // insert the new block 222 | alloc->free_blocks[insert_pos].addr = ptr; 223 | alloc->free_blocks[insert_pos].size = size; 224 | alloc->n_free_blocks++; 225 | } 226 | 227 | void ggml_tallocr_reset(ggml_tallocr_t alloc) { 228 | alloc->n_free_blocks = 1; 229 | size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment); 230 | alloc->free_blocks[0].addr = (char *)alloc->base + align_offset; 231 | 232 | if (alloc->measure) { 233 | alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows 234 | } else { 235 | alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset; 236 | } 237 | } 238 | 239 | ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) { 240 | struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size); 241 | 242 | ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr)); 243 | 244 | *alloc = (struct ggml_tallocr) { 245 | /*.buffer = */ buffer, 246 | /*.buffer_owned = */ true, 247 | /*.base = */ ggml_backend_buffer_get_base(buffer), 248 | /*.alignment = */ alignment, 249 | /*.n_free_blocks = */ 0, 250 | /*.free_blocks = */ {{0}}, 251 | /*.max_size = */ 0, 252 | /*.measure = */ false, 253 | #ifdef GGML_ALLOCATOR_DEBUG 254 | /*.allocated_tensors = */ {0}, 255 | #endif 256 | }; 257 | 258 | ggml_tallocr_reset(alloc); 259 | 260 | return alloc; 261 | } 262 | 263 | ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) { 264 | ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment); 265 | 
alloc->measure = true; 266 | 267 | return alloc; 268 | } 269 | 270 | ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) { 271 | // create a backend buffer to get the correct tensor allocation sizes 272 | ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1); 273 | 274 | // TODO: move alloc initialization to a common ggml_tallocr_new_impl function 275 | ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer); 276 | alloc->buffer_owned = true; 277 | alloc->measure = true; 278 | ggml_tallocr_reset(alloc); 279 | return alloc; 280 | } 281 | 282 | ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) { 283 | ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size); 284 | ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer); 285 | alloc->buffer_owned = true; 286 | return alloc; 287 | } 288 | 289 | ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) { 290 | ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr)); 291 | 292 | *alloc = (struct ggml_tallocr) { 293 | /*.buffer = */ buffer, 294 | /*.buffer_owned = */ false, 295 | /*.base = */ ggml_backend_buffer_get_base(buffer), 296 | /*.alignment = */ ggml_backend_buffer_get_alignment(buffer), 297 | /*.n_free_blocks = */ 0, 298 | /*.free_blocks = */ {{0}}, 299 | /*.max_size = */ 0, 300 | /*.measure = */ false, 301 | #ifdef GGML_ALLOCATOR_DEBUG 302 | /*.allocated_tensors = */ {0}, 303 | #endif 304 | }; 305 | 306 | ggml_tallocr_reset(alloc); 307 | 308 | return alloc; 309 | } 310 | 311 | struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) { 312 | return alloc->buffer; 313 | } 314 | 315 | void ggml_tallocr_free(ggml_tallocr_t alloc) { 316 | if (alloc == NULL) { 317 | return; 318 | } 319 | 320 | if (alloc->buffer_owned) { 321 | ggml_backend_buffer_free(alloc->buffer); 322 | } 323 | free(alloc); 324 | } 325 | 326 | bool 
ggml_tallocr_is_measure(ggml_tallocr_t alloc) { 327 | return alloc->measure; 328 | } 329 | 330 | size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) { 331 | return alloc->max_size; 332 | } 333 | 334 | // graph allocator 335 | 336 | struct hash_node { 337 | int n_children; 338 | int n_views; 339 | }; 340 | 341 | struct ggml_gallocr { 342 | ggml_tallocr_t talloc; 343 | struct ggml_hash_set hash_set; 344 | struct hash_node * hash_values; 345 | size_t hash_values_size; 346 | ggml_tallocr_t * hash_allocs; 347 | int * parse_seq; 348 | int parse_seq_len; 349 | }; 350 | 351 | ggml_gallocr_t ggml_gallocr_new(void) { 352 | ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr)); 353 | 354 | *galloc = (struct ggml_gallocr) { 355 | /*.talloc = */ NULL, 356 | /*.hash_set = */ {0}, 357 | /*.hash_values = */ NULL, 358 | /*.hash_values_size = */ 0, 359 | /*.hash_allocs = */ NULL, 360 | /*.parse_seq = */ NULL, 361 | /*.parse_seq_len = */ 0, 362 | }; 363 | 364 | return galloc; 365 | } 366 | 367 | void ggml_gallocr_free(ggml_gallocr_t galloc) { 368 | if (galloc == NULL) { 369 | return; 370 | } 371 | 372 | if (galloc->hash_set.keys != NULL) { 373 | free(galloc->hash_set.keys); 374 | } 375 | if (galloc->hash_values != NULL) { 376 | free(galloc->hash_values); 377 | } 378 | if (galloc->hash_allocs != NULL) { 379 | free(galloc->hash_allocs); 380 | } 381 | if (galloc->parse_seq != NULL) { 382 | free(galloc->parse_seq); 383 | } 384 | free(galloc); 385 | } 386 | 387 | void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) { 388 | free(galloc->parse_seq); 389 | galloc->parse_seq = malloc(sizeof(int) * n); 390 | 391 | for (int i = 0; i < n; i++) { 392 | galloc->parse_seq[i] = list[i]; 393 | } 394 | galloc->parse_seq_len = n; 395 | } 396 | 397 | static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) { 398 | size_t i = ggml_hash_find_or_insert(galloc->hash_set, t); 399 | return &galloc->hash_values[i]; 400 | } 401 | 402 | 
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { 403 | if (a->type != b->type) { 404 | return false; 405 | } 406 | for (int i = 0; i < GGML_MAX_DIMS; i++) { 407 | if (a->ne[i] != b->ne[i]) { 408 | return false; 409 | } 410 | if (a->nb[i] != b->nb[i]) { 411 | return false; 412 | } 413 | } 414 | return true; 415 | } 416 | 417 | static bool ggml_op_can_inplace(enum ggml_op op) { 418 | switch (op) { 419 | case GGML_OP_SCALE: 420 | case GGML_OP_DIAG_MASK_ZERO: 421 | case GGML_OP_DIAG_MASK_INF: 422 | case GGML_OP_ADD: 423 | case GGML_OP_ADD1: 424 | case GGML_OP_SUB: 425 | case GGML_OP_MUL: 426 | case GGML_OP_DIV: 427 | case GGML_OP_SQR: 428 | case GGML_OP_SQRT: 429 | case GGML_OP_LOG: 430 | case GGML_OP_UNARY: 431 | case GGML_OP_ROPE: 432 | case GGML_OP_RMS_NORM: 433 | case GGML_OP_SOFT_MAX: 434 | return true; 435 | 436 | default: 437 | return false; 438 | } 439 | } 440 | 441 | static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) { 442 | if (galloc->talloc != NULL) { 443 | return galloc->talloc; 444 | } 445 | 446 | return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)]; 447 | } 448 | 449 | static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) { 450 | ggml_tallocr_t alloc = node_tallocr(galloc, view); 451 | 452 | //printf("init_view: %s from src %s\n", view->name, view->view_src->name); 453 | GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL); 454 | if (update_backend) { 455 | view->backend = view->view_src->backend; 456 | } 457 | view->buffer = view->view_src->buffer; 458 | view->data = (char *)view->view_src->data + view->view_offs; 459 | 460 | // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend 461 | // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras 462 | assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend 
== alloc->buffer->backend); 463 | 464 | if (!alloc->measure) { 465 | ggml_backend_buffer_init_tensor(alloc->buffer, view); 466 | } 467 | } 468 | 469 | static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { 470 | ggml_tallocr_t alloc = node_tallocr(galloc, node); 471 | 472 | if (node->data == NULL) { 473 | if (ggml_is_view(node)) { 474 | init_view(galloc, node, true); 475 | } else { 476 | // see if we can reuse a parent's buffer (inplace) 477 | if (ggml_op_can_inplace(node->op)) { 478 | for (int i = 0; i < GGML_MAX_SRC; i++) { 479 | struct ggml_tensor * parent = node->src[i]; 480 | if (parent == NULL) { 481 | break; 482 | } 483 | 484 | // if the node's data is external, then we cannot re-use it 485 | if (ggml_tallocr_is_own(alloc, parent) == false) { 486 | AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); 487 | continue; 488 | } 489 | 490 | struct hash_node * p_hn = hash_get(galloc, parent); 491 | if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) { 492 | if (ggml_is_view(parent)) { 493 | struct ggml_tensor * view_src = parent->view_src; 494 | struct hash_node * view_src_hn = hash_get(galloc, view_src); 495 | if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { 496 | // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite 497 | // the parent's data that it will need later (same layout requirement). the problem is that then 498 | // we cannot free the tensor because the original address of the allocation is lost. 
499 | // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views 500 | // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) 501 | AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); 502 | node->view_src = view_src; 503 | view_src_hn->n_views += 1; 504 | init_view(galloc, node, false); 505 | return; 506 | } 507 | } else { 508 | AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); 509 | node->view_src = parent; 510 | p_hn->n_views += 1; 511 | init_view(galloc, node, false); 512 | return; 513 | } 514 | } 515 | } 516 | } 517 | ggml_tallocr_alloc(alloc, node); 518 | } 519 | } 520 | } 521 | 522 | static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { 523 | ggml_tallocr_t alloc = node_tallocr(galloc, node); 524 | 525 | ggml_tallocr_free_tensor(alloc, node); 526 | } 527 | 528 | static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) { 529 | const int * parse_seq = galloc->parse_seq; 530 | int parse_seq_len = galloc->parse_seq_len; 531 | 532 | // count number of children and views 533 | for (int i = 0; i < gf->n_nodes; i++) { 534 | struct ggml_tensor * node = gf->nodes[i]; 535 | 536 | if (ggml_is_view(node)) { 537 | struct ggml_tensor * view_src = node->view_src; 538 | hash_get(galloc, view_src)->n_views += 1; 539 | if (node->buffer == NULL && node->data != NULL) { 540 | // view of a pre-allocated tensor, didn't call init_view() yet 541 | init_view(galloc, node, true); 542 | } 543 | } 544 | 545 | for (int j = 0; j < GGML_MAX_SRC; j++) { 546 | struct ggml_tensor * parent = node->src[j]; 547 | if (parent == NULL) { 548 | break; 549 | } 550 | hash_get(galloc, parent)->n_children += 1; 551 | if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { 552 | init_view(galloc, parent, true); 553 | } 554 | } 555 | } 556 | 557 | // allocate tensors 558 | // if we have 
parse_seq then we allocate nodes following the list, and we only free nodes at barriers 559 | int last_barrier_pos = 0; 560 | int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes; 561 | 562 | for (int ind = 0; ind < n_nodes; ind++) { 563 | // allocate a node if there is no parse_seq or this is not a barrier 564 | if (parse_seq_len == 0 || parse_seq[ind] != -1) { 565 | int i = parse_seq_len ? parse_seq[ind] : ind; 566 | struct ggml_tensor * node = gf->nodes[i]; 567 | 568 | // allocate parents (leafs) 569 | for (int j = 0; j < GGML_MAX_SRC; j++) { 570 | struct ggml_tensor * parent = node->src[j]; 571 | if (parent == NULL) { 572 | break; 573 | } 574 | allocate_node(galloc, parent); 575 | } 576 | 577 | // allocate node 578 | allocate_node(galloc, node); 579 | 580 | AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name); 581 | for (int j = 0; j < GGML_MAX_SRC; j++) { 582 | struct ggml_tensor * parent = node->src[j]; 583 | if (parent == NULL) { 584 | break; 585 | } 586 | AT_PRINTF("%s", parent->name); 587 | if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { 588 | AT_PRINTF(", "); 589 | } 590 | } 591 | AT_PRINTF("\n"); 592 | } 593 | 594 | // update parents 595 | // update immediately if there is no parse_seq 596 | // update only at barriers if there is parse_seq 597 | if ((parse_seq_len == 0) || parse_seq[ind] == -1) { 598 | int update_start = parse_seq_len ? last_barrier_pos : ind; 599 | int update_end = parse_seq_len ? ind : ind + 1; 600 | for (int i = update_start; i < update_end; i++) { 601 | int node_i = parse_seq_len ? 
parse_seq[i] : i; 602 | struct ggml_tensor * node = gf->nodes[node_i]; 603 | 604 | for (int j = 0; j < GGML_MAX_SRC; j++) { 605 | struct ggml_tensor * parent = node->src[j]; 606 | if (parent == NULL) { 607 | break; 608 | } 609 | struct hash_node * p_hn = hash_get(galloc, parent); 610 | p_hn->n_children -= 1; 611 | 612 | //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); 613 | 614 | if (p_hn->n_children == 0 && p_hn->n_views == 0) { 615 | if (ggml_is_view(parent)) { 616 | struct ggml_tensor * view_src = parent->view_src; 617 | struct hash_node * view_src_hn = hash_get(galloc, view_src); 618 | view_src_hn->n_views -= 1; 619 | AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); 620 | if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) { 621 | free_node(galloc, view_src); 622 | } 623 | } 624 | else { 625 | free_node(galloc, parent); 626 | } 627 | } 628 | } 629 | } 630 | AT_PRINTF("\n"); 631 | if (parse_seq_len) { 632 | last_barrier_pos = ind + 1; 633 | } 634 | } 635 | } 636 | } 637 | 638 | size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) { 639 | size_t hash_size = graph->visited_hash_table.size; 640 | 641 | // check if the hash table is initialized and large enough 642 | if (galloc->hash_set.size < hash_size) { 643 | if (galloc->hash_set.keys != NULL) { 644 | free(galloc->hash_set.keys); 645 | } 646 | if (galloc->hash_values != NULL) { 647 | free(galloc->hash_values); 648 | } 649 | galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size); 650 | galloc->hash_set.size = hash_size; 651 | galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); 652 | } 653 | 654 | // reset hash table 655 | memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size); 656 | memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); 657 | 658 | galloc->talloc = 
talloc; 659 | ggml_tallocr_alloc_graph_impl(galloc, graph); 660 | galloc->talloc = NULL; 661 | 662 | size_t max_size = ggml_tallocr_max_size(talloc); 663 | 664 | return max_size; 665 | } 666 | 667 | void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) { 668 | const size_t hash_size = hash_set.size; 669 | 670 | GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs)); 671 | 672 | galloc->talloc = NULL; 673 | 674 | // alloc hash_values if needed 675 | if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) { 676 | free(galloc->hash_values); 677 | galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); 678 | galloc->hash_values_size = hash_size; 679 | } 680 | 681 | // free hash_set.keys if needed 682 | if (galloc->hash_set.keys != NULL) { 683 | free(galloc->hash_set.keys); 684 | } 685 | galloc->hash_set = hash_set; 686 | 687 | // reset hash values 688 | memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); 689 | 690 | galloc->hash_allocs = hash_node_talloc; 691 | 692 | ggml_tallocr_alloc_graph_impl(galloc, graph); 693 | 694 | // remove unowned resources 695 | galloc->hash_set.keys = NULL; 696 | galloc->hash_allocs = NULL; 697 | } 698 | 699 | // legacy API wrapper 700 | 701 | struct ggml_allocr { 702 | ggml_tallocr_t talloc; 703 | ggml_gallocr_t galloc; 704 | }; 705 | 706 | static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) { 707 | ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr)); 708 | *alloc = (struct ggml_allocr) { 709 | /*.talloc = */ talloc, 710 | /*.galloc = */ ggml_gallocr_new(), 711 | }; 712 | return alloc; 713 | } 714 | 715 | ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) { 716 | return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment)); 717 | } 718 | 719 | ggml_allocr_t ggml_allocr_new_measure(size_t alignment) { 720 | return 
ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment)); 721 | } 722 | 723 | ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) { 724 | return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer)); 725 | } 726 | 727 | ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) { 728 | return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size)); 729 | } 730 | 731 | ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) { 732 | return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend)); 733 | } 734 | 735 | struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) { 736 | return ggml_tallocr_get_buffer(alloc->talloc); 737 | } 738 | 739 | void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) { 740 | ggml_gallocr_set_parse_seq(alloc->galloc, list, n); 741 | } 742 | 743 | void ggml_allocr_free(ggml_allocr_t alloc) { 744 | ggml_gallocr_free(alloc->galloc); 745 | ggml_tallocr_free(alloc->talloc); 746 | free(alloc); 747 | } 748 | 749 | bool ggml_allocr_is_measure(ggml_allocr_t alloc) { 750 | return ggml_tallocr_is_measure(alloc->talloc); 751 | } 752 | 753 | void ggml_allocr_reset(ggml_allocr_t alloc) { 754 | ggml_tallocr_reset(alloc->talloc); 755 | } 756 | 757 | void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) { 758 | ggml_tallocr_alloc(alloc->talloc, tensor); 759 | } 760 | 761 | size_t ggml_allocr_max_size(ggml_allocr_t alloc) { 762 | return ggml_tallocr_max_size(alloc->talloc); 763 | } 764 | 765 | size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) { 766 | return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph); 767 | } 768 | --------------------------------------------------------------------------------