├── .gitignore ├── .tm_properties ├── LICENSE ├── Makefile ├── README.md ├── bot.go ├── chat.go ├── cmd └── i18n │ ├── go.mod │ ├── main.go │ └── template.go ├── config.json.sample ├── db.go ├── deploy.sh ├── erased.ogg ├── function_calls.go ├── go.mod ├── go.sum ├── i18n └── i18n.go ├── image.go ├── llm.go ├── localizations_src └── ru.json ├── main.go ├── markdown.go ├── models.go ├── opus ├── AUTHORS.txt ├── callbacks.c ├── encoder.go ├── stream.go ├── stream_errors.go └── streams_map.go ├── state.go ├── tele_handlers.go ├── tools ├── duckduckgo.go ├── tool_cryptorate.go ├── tool_search_vector_db.go └── tool_websearch.go ├── types └── types.go ├── uploads └── .gitkeep ├── validation.go ├── vectordb ├── chroma.go ├── embedder.go ├── embedding.go ├── handler.go ├── openaiclient.go ├── split_documents.go ├── text.go ├── text_spliter.go ├── token_splitter.go └── tsoptions.go └── voice.go /.gitignore: -------------------------------------------------------------------------------- 1 | config.json 2 | chatgpt-bot 3 | .idea 4 | .DS_Store 5 | lib* 6 | bot.db 7 | .env 8 | cmd/i18n/i18n 9 | uploads/* -------------------------------------------------------------------------------- /.tm_properties: -------------------------------------------------------------------------------- 1 | TM_ENABLE_IS = true 2 | TM_ENABLE_FMT = true 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 tectiv3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject 
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
# VERSION is the nearest tag; --always falls back to the abbreviated commit
# hash when the repository has no tags, instead of yielding an empty string.
VERSION := $(shell git describe --tags --always)
BUILD_TIME := $(shell date +%FT%T%z)

.PHONY: build

# build compiles the bot, omitting the DWARF table (-w) and stamping the
# version and build time into main.Version and main.BuildTime.
build:
	go build -ldflags "-w -X main.BuildTime=${BUILD_TIME} -X main.Version=${VERSION}" .
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Telegram ChatGPT Bot 2 | 3 | A telegram bot which answers to messages with [ChatGPT API](https://platform.openai.com/docs/api-reference/chat). 4 | 5 | ## Configuration 6 | 7 | Copy example configuration file: 8 | 9 | ```bash 10 | $ cp config.json.sample config.json 11 | ``` 12 | 13 | and set your values: 14 | 15 | ```json 16 | { 17 | "telegram_bot_token": "123456:abcdefghijklmnop-QRSTUVWXYZ7890", 18 | "openai_api_key": "key-ABCDEFGHIJK1234567890", 19 | "openai_org_id": "org-1234567890abcdefghijk", 20 | "allowed_telegram_users": ["user1", "user2"], 21 | "openai_model": "gpt-4o-mini", 22 | "verbose": false 23 | } 24 | ``` 25 | 26 | ### Install dependencies 27 | 28 | `libmp3lame0` is required for mp3 encoding.
(macOS: `brew install lame`) 29 | 30 | `libopus0` is required for Ogg Opus decoding. (macOS: `brew install opusfile`) 31 | 32 | On macOS, you might also need `brew install pkg-config`. 33 | 34 | #### Ubuntu: 35 | ```bash 36 | $ sudo apt-get install libmp3lame0 libopus0 libopusfile0 libogg0 libmp3lame-dev libopusfile-dev 37 | ``` 38 | 39 | ## Build 40 | 41 | ```bash 42 | $ go build 43 | ``` 44 | 45 | ## Run 46 | 47 | Run the built binary with the config file's path: 48 | 49 | ```bash 50 | $ ./chatgpt-bot /path/to/config.json 51 | ``` 52 | 53 | ## Run as a systemd service 54 | 55 | Create a systemd service file: 56 | 57 | ``` 58 | [Unit] 59 | Description=Telegram ChatGPT Bot 60 | After=syslog.target 61 | After=network.target 62 | 63 | [Service] 64 | Type=simple 65 | User=pi 66 | Group=pi 67 | WorkingDirectory=/dir/to/chatgpt-bot 68 | ExecStart=/dir/to/chatgpt-bot/chatgpt-bot /path/to/config.json 69 | Restart=always 70 | RestartSec=5 71 | 72 | [Install] 73 | WantedBy=multi-user.target 74 | ``` 75 | 76 | ```bash 77 | $ systemctl daemon-reload 78 | ``` 79 | 80 | then use `systemctl` to enable, start, restart, or stop the service. 81 | 82 | ## License 83 | 84 | The MIT License (MIT) 85 | 86 | Permission is hereby granted, free of charge, to any person obtaining a copy 87 | of this software and associated documentation files (the "Software"), to deal 88 | in the Software without restriction, including without limitation the rights 89 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 90 | copies of the Software, and to permit persons to whom the Software is 91 | furnished to do so, subject to the following conditions: 92 | 93 | The above copyright notice and this permission notice shall be included in all 94 | copies or substantial portions of the Software. 95 | 96 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 97 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 98 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 99 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 100 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 101 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 102 | SOFTWARE. 103 | 104 | -------------------------------------------------------------------------------- /bot.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | "github.com/tectiv3/chatgpt-bot/i18n" 11 | 12 | tele "gopkg.in/telebot.v3" 13 | ) 14 | 15 | const ( 16 | cmdStart = "/start" 17 | cmdReset = "/reset" 18 | cmdModel = "/model" 19 | cmdTemp = "/temperature" 20 | cmdPrompt = "/prompt" 21 | cmdAge = "/age" 22 | cmdPromptCL = "/defaultprompt" 23 | cmdStream = "/stream" 24 | cmdStop = "/stop" 25 | cmdVoice = "/voice" 26 | cmdInfo = "/info" 27 | cmdLang = "/lang" 28 | cmdImage = "/image" 29 | cmdToJapanese = "/ja" 30 | cmdToEnglish = "/en" 31 | cmdToRussian = "/ru" 32 | cmdToItalian = "/it" 33 | cmdToSpanish = "/es" 34 | cmdToChinese = "/cn" 35 | cmdRoles = "/roles" 36 | cmdRole = "/role" 37 | cmdQA = "/qa" 38 | cmdUsers = "/users" 39 | cmdAddUser = "/add" 40 | cmdDelUser = "/del" 41 | msgStart = "This bot will answer your messages with ChatGPT API" 42 | masterPrompt = "You are a helpful assistant. You always try to answer truthfully. If you don't know the answer, just say that you don't know, don't try to make up an answer. Don't explain yourself. Do not introduce yourself, just answer the user concisely." 
43 | pOllama = "ollama" 44 | pGroq = "groq" 45 | pOpenAI = "openai" 46 | miniModel = "gpt-4o-mini" 47 | pAWS = "aws" 48 | pAnthropic = "anthropic" 49 | openAILatest = "openAILatest" 50 | ) 51 | 52 | var ( 53 | menu = &tele.ReplyMarkup{ResizeKeyboard: true} 54 | replyMenu = &tele.ReplyMarkup{ResizeKeyboard: true, OneTimeKeyboard: true} 55 | removeMenu = &tele.ReplyMarkup{RemoveKeyboard: true} 56 | btnModel = tele.Btn{Text: "Select Model", Unique: "btnModel", Data: ""} 57 | btnT0 = tele.Btn{Text: "0.0", Unique: "btntemp", Data: "0.0"} 58 | btnT2 = tele.Btn{Text: "0.2", Unique: "btntemp", Data: "0.2"} 59 | btnT4 = tele.Btn{Text: "0.4", Unique: "btntemp", Data: "0.4"} 60 | btnT6 = tele.Btn{Text: "0.6", Unique: "btntemp", Data: "0.6"} 61 | btnT8 = tele.Btn{Text: "0.8", Unique: "btntemp", Data: "0.8"} 62 | btnT10 = tele.Btn{Text: "1.0", Unique: "btntemp", Data: "1.0"} 63 | btnCreate = tele.Btn{Text: "New Role", Unique: "btnRole", Data: "create"} 64 | btnUpdate = tele.Btn{Text: "Update", Unique: "btnUpdate", Data: "update"} 65 | btnDelete = tele.Btn{Text: "Delete", Unique: "btnDelete", Data: "delete"} 66 | btnCancel = tele.Btn{Text: "Cancel", Unique: "btnRole", Data: "cancel"} 67 | 68 | btnReset = tele.Btn{Text: "New Conversation", Unique: "btnreset", Data: "r"} 69 | btnEmpty = tele.Btn{Text: "", Data: "no_data"} 70 | ) 71 | 72 | func init() { 73 | replyMenu.Inline(menu.Row(btnReset)) 74 | removeMenu.Inline(menu.Row(btnEmpty)) 75 | } 76 | 77 | // run will launch bot with given parameters 78 | func (s *Server) run() { 79 | b, err := tele.NewBot(tele.Settings{ 80 | Token: s.conf.TelegramBotToken, 81 | URL: s.conf.TelegramServerURL, 82 | Poller: &tele.LongPoller{ 83 | Timeout: 1 * time.Second, 84 | AllowedUpdates: []string{ 85 | "message", 86 | "edited_message", 87 | "inline_query", 88 | "callback_query", 89 | // "message_reaction", 90 | // "message_reaction_count", 91 | }, 92 | }, 93 | }) 94 | if err != nil { 95 | Log.Fatal(err) 96 | return 97 | } 98 | //if done, err := 
b.Logout(); err != nil { 99 | // log.Fatal(err) 100 | // return 101 | //} else { 102 | // log.Println("Logout: ", done) 103 | // return 104 | //} 105 | 106 | // b.Use(middleware.Logger()) 107 | s.loadUsers() 108 | 109 | s.Lock() 110 | b.Use(s.whitelist()) 111 | s.bot = b 112 | s.Unlock() 113 | 114 | b.Handle(cmdStart, func(c tele.Context) error { 115 | return c.Send( 116 | l.GetWithLocale(c.Sender().LanguageCode, msgStart), 117 | "text", 118 | &tele.SendOptions{ReplyTo: c.Message()}, 119 | ) 120 | }) 121 | 122 | b.Handle(cmdModel, func(c tele.Context) error { 123 | chat := s.getChat(c.Chat(), c.Sender()) 124 | model := strings.TrimSpace(c.Message().Payload) 125 | if model == "" { 126 | rows := []tele.Row{} 127 | row := []tele.Btn{} 128 | 129 | for _, m := range s.conf.Models { 130 | if len(row) == 3 { 131 | rows = append(rows, menu.Row(row...)) 132 | row = []tele.Btn{} 133 | } 134 | if m.Provider == pOpenAI || (m.Provider == pAnthropic && s.conf.AnthropicEnabled) || (m.Provider == pAWS && s.conf.AWSEnabled) { 135 | row = append(row, tele.Btn{Text: m.Name, Unique: "btnModel", Data: m.Name}) 136 | } 137 | } 138 | rows = append(rows, menu.Row(row...)) 139 | 140 | menu.Inline(rows...) 
141 | 142 | return c.Send(chat.t("Select model"), menu) 143 | } 144 | Log.WithField("user", c.Sender().Username).Info("Selected model ", model) 145 | chat.ModelName = model 146 | chat.Stream = true 147 | s.db.Save(&chat) 148 | 149 | return c.Send(chat.t("Model set to {{.model}}", &i18n.Replacements{"model": model})) 150 | }) 151 | 152 | b.Handle(cmdTemp, func(c tele.Context) error { 153 | menu.Inline(menu.Row(btnT0, btnT2, btnT4, btnT6, btnT8, btnT10)) 154 | chat := s.getChat(c.Chat(), c.Sender()) 155 | 156 | return c.Send( 157 | fmt.Sprintf( 158 | chat.t( 159 | "Set temperature from less random (0.0) to more random (1.0).\nCurrent: %0.2f (default: 0.8)", 160 | ), 161 | chat.Temperature, 162 | ), 163 | menu, 164 | ) 165 | }) 166 | 167 | b.Handle(cmdRole, func(c tele.Context) error { 168 | chat := s.getChat(c.Chat(), c.Sender()) 169 | name := strings.TrimSpace(c.Message().Payload) 170 | if err := ValidateRoleName(name); err != nil { 171 | return c.Send( 172 | chat.t("Invalid role name: {{.error}}", &i18n.Replacements{"error": err.Error()}), 173 | "text", 174 | &tele.SendOptions{ReplyTo: c.Message()}, 175 | ) 176 | } 177 | role := s.findRole(chat.UserID, name) 178 | if role == nil { 179 | return c.Send(chat.t("Role not found")) 180 | } 181 | s.setChatRole(&role.ID, chat.ChatID) 182 | 183 | return c.Send(chat.t("Role set to {{.role}}", &i18n.Replacements{"role": role.Name})) 184 | }) 185 | 186 | b.Handle(cmdRoles, func(c tele.Context) error { 187 | chat := s.getChat(c.Chat(), c.Sender()) 188 | roles := chat.User.Roles 189 | rows := []tele.Row{} 190 | // iterate over roles, add menu button with role name 3 buttons in a row 191 | row := []tele.Btn{ 192 | { 193 | Text: chat.t("default"), 194 | Unique: "btnRole", 195 | Data: "___default___", 196 | }, 197 | } 198 | for _, role := range roles { 199 | if len(row) == 3 { 200 | rows = append(rows, menu.Row(row...)) 201 | row = []tele.Btn{} 202 | } 203 | row = append( 204 | row, 205 | tele.Btn{Text: role.Name, Unique: 
"btnRole", Data: strconv.Itoa(int(role.ID))}, 206 | ) 207 | } 208 | if len(row) != 0 { 209 | rows = append(rows, menu.Row(row...)) 210 | row = []tele.Btn{} 211 | } 212 | row = append(row, btnCreate, btnUpdate, btnDelete) 213 | rows = append(rows, menu.Row(row...)) 214 | // Log.Info(rows) 215 | menu.Inline(rows...) 216 | 217 | return c.Send(chat.t("Select role"), menu) 218 | }) 219 | 220 | b.Handle(&btnCreate, func(c tele.Context) error { 221 | Log.WithField("user", c.Sender().Username).Info("Selected role ", c.Data()) 222 | chat := s.getChat(c.Chat(), c.Sender()) 223 | 224 | user := chat.User 225 | if c.Data() == "cancel" { 226 | s.db.Model(&user).Update("State", nil) 227 | 228 | return c.Edit(chat.t("Canceled"), removeMenu) 229 | } 230 | 231 | if c.Data() == "___default___" { 232 | chat.MasterPrompt = masterPrompt 233 | s.db.Save(&chat) 234 | s.setChatRole(nil, chat.ChatID) 235 | 236 | return c.Edit(chat.t("Default prompt set")) 237 | } 238 | 239 | if c.Data() != "create" { 240 | roleID := asUint(c.Data()) 241 | role := s.getRole(roleID) 242 | if role == nil { 243 | return c.Send(chat.t("Role not found")) 244 | } 245 | // s.db.Model(&chat).Update("RoleID", role.ID) // gorm is weird 246 | s.setChatRole(&role.ID, chat.ChatID) 247 | s.setChatLastMessageID(nil, chat.ChatID) 248 | 249 | return c.Edit(chat.t("Role set to {{.role}}", &i18n.Replacements{"role": role.Name})) 250 | } 251 | 252 | state := State{ 253 | Name: "RoleCreate", 254 | FirstStep: Step{ 255 | Field: "Name", 256 | Prompt: "Enter role name", 257 | Next: &Step{ 258 | Prompt: "Enter system prompt", 259 | Field: "Prompt", 260 | }, 261 | }, 262 | } 263 | s.db.Model(&user).Update("State", state) 264 | 265 | menu.Inline(menu.Row(btnCancel)) 266 | 267 | id := &([]string{strconv.Itoa(c.Message().ID)}[0]) 268 | s.setChatLastMessageID(id, chat.ChatID) 269 | 270 | return c.Edit(chat.t("Enter role name"), menu) 271 | }) 272 | 273 | b.Handle(&btnUpdate, func(c tele.Context) error { 274 | Log.WithField("user", 
c.Sender().Username).Info("Selected option ", c.Data()) 275 | chat := s.getChat(c.Chat(), c.Sender()) 276 | user := chat.User 277 | 278 | if c.Data() != "update" { 279 | roleID := asUint(c.Data()) 280 | role := s.getRole(roleID) 281 | if role == nil { 282 | return c.Edit(chat.t("Role not found")) 283 | } 284 | 285 | state := State{ 286 | Name: "RoleUpdate", 287 | ID: &roleID, 288 | FirstStep: Step{ 289 | Field: "Name", 290 | Prompt: "Enter role name", 291 | Next: &Step{ 292 | Prompt: "Enter system prompt", 293 | Field: "Prompt", 294 | }, 295 | }, 296 | } 297 | user.State = &state 298 | s.db.Save(&user) 299 | 300 | menu.Inline(menu.Row(btnCancel)) 301 | 302 | return c.Edit(chat.t(state.FirstStep.Prompt), menu) 303 | } 304 | 305 | roles := chat.User.Roles 306 | rows := []tele.Row{} 307 | // iterate over roles, add menu button with role name 3 buttons in a row 308 | row := []tele.Btn{} 309 | for _, role := range roles { 310 | if len(row) == 3 { 311 | rows = append(rows, menu.Row(row...)) 312 | row = []tele.Btn{} 313 | } 314 | row = append( 315 | row, 316 | tele.Btn{Text: role.Name, Unique: "btnUpdate", Data: strconv.Itoa(int(role.ID))}, 317 | ) 318 | } 319 | rows = append(rows, menu.Row(row...), menu.Row(btnCancel)) 320 | menu.Inline(rows...) 
321 | 322 | return c.Edit(chat.t("Select Role"), menu) 323 | }) 324 | 325 | b.Handle(&btnDelete, func(c tele.Context) error { 326 | Log.WithField("user", c.Sender().Username).Info("Selected option ", c.Data()) 327 | chat := s.getChat(c.Chat(), c.Sender()) 328 | 329 | if c.Data() != "delete" { 330 | roleID := asUint(c.Data()) 331 | role := s.getRole(roleID) 332 | if role == nil { 333 | return c.Send(chat.t("Role not found")) 334 | } 335 | // Log.WithField("roleID", roleID).WithField("chat", *chat.RoleID).Info("Role deleted") 336 | if chat.RoleID != nil { 337 | Log.WithField("roleID", roleID).WithField("chat", *chat.RoleID).Info("Role deleted") 338 | if *chat.RoleID == roleID { 339 | // s.db.Model(&chat).Update("RoleID", nil) // stupid gorm insert chat, roles, users and duplicates roleID 340 | s.setChatRole(nil, chat.ChatID) 341 | } 342 | } 343 | s.db.Unscoped().Delete(&Role{}, roleID) 344 | 345 | return c.Edit(chat.t("Role deleted")) 346 | } 347 | 348 | roles := chat.User.Roles 349 | rows := []tele.Row{} 350 | // iterate over roles, add menu button with role name, 3 buttons in a row 351 | // TODO: refactor to use native menu.Split(3, btns) 352 | row := []tele.Btn{} 353 | for _, role := range roles { 354 | if len(row) == 3 { 355 | rows = append(rows, menu.Row(row...)) 356 | row = []tele.Btn{} 357 | } 358 | row = append( 359 | row, 360 | tele.Btn{Text: role.Name, Unique: "btnDelete", Data: strconv.Itoa(int(role.ID))}, 361 | ) 362 | } 363 | rows = append(rows, menu.Row(row...), menu.Row(btnCancel)) 364 | menu.Inline(rows...) 
365 | 366 | return c.Edit(chat.t("Select Role"), menu) 367 | }) 368 | 369 | b.Handle(cmdAge, func(c tele.Context) error { 370 | chat := s.getChat(c.Chat(), c.Sender()) 371 | ageStr := strings.TrimSpace(c.Message().Payload) 372 | age, err := ValidateAge(ageStr) 373 | if err != nil { 374 | return c.Send( 375 | chat.t("Invalid age: {{.error}}", &i18n.Replacements{"error": err.Error()}), 376 | "text", 377 | &tele.SendOptions{ReplyTo: c.Message()}, 378 | ) 379 | } 380 | chat.ConversationAge = int64(age) 381 | s.db.Save(&chat) 382 | 383 | return c.Send( 384 | fmt.Sprintf(chat.t("Conversation age set to %d days"), age), 385 | "text", 386 | &tele.SendOptions{ReplyTo: c.Message()}, 387 | ) 388 | }) 389 | 390 | b.Handle(cmdPrompt, func(c tele.Context) error { 391 | chat := s.getChat(c.Chat(), c.Sender()) 392 | query := strings.TrimSpace(c.Message().Payload) 393 | if err := ValidatePrompt(query); err != nil { 394 | return c.Send( 395 | chat.t("Invalid prompt: {{.error}}", &i18n.Replacements{"error": err.Error()}), 396 | "text", 397 | &tele.SendOptions{ReplyTo: c.Message()}, 398 | ) 399 | } 400 | 401 | chat.MasterPrompt = query 402 | s.db.Save(&chat) 403 | 404 | return c.Send(chat.t("Prompt set"), "text", &tele.SendOptions{ReplyTo: c.Message()}) 405 | }) 406 | 407 | b.Handle(cmdPromptCL, func(c tele.Context) error { 408 | chat := s.getChat(c.Chat(), c.Sender()) 409 | chat.MasterPrompt = masterPrompt 410 | chat.RoleID = nil 411 | s.db.Save(&chat) 412 | 413 | return c.Send(chat.t("Default prompt set"), "text", &tele.SendOptions{ReplyTo: c.Message()}) 414 | }) 415 | 416 | b.Handle(cmdStream, func(c tele.Context) error { 417 | chat := s.getChat(c.Chat(), c.Sender()) 418 | chat.Stream = !chat.Stream 419 | s.db.Save(&chat) 420 | status := "disabled" 421 | if chat.Stream { 422 | status = "enabled" 423 | } 424 | text := chat.t("Stream is {{.status}}", &i18n.Replacements{"status": chat.t(status)}) 425 | 426 | return c.Send(text, "text", &tele.SendOptions{ReplyTo: c.Message()}) 427 | }) 
428 | 429 | b.Handle(cmdQA, func(c tele.Context) error { 430 | chat := s.getChat(c.Chat(), c.Sender()) 431 | chat.QA = !chat.QA 432 | s.db.Save(&chat) 433 | status := "disabled" 434 | if chat.QA { 435 | status = "enabled" 436 | } 437 | text := chat.t("Questions List is {{.status}}", &i18n.Replacements{"status": chat.t(status)}) 438 | 439 | return c.Send(text, "text", &tele.SendOptions{ReplyTo: c.Message()}) 440 | }) 441 | 442 | b.Handle(cmdVoice, func(c tele.Context) error { 443 | go s.pageToSpeech(c, c.Message().Payload) 444 | 445 | return c.Send("Downloading page", "text", &tele.SendOptions{ReplyTo: c.Message()}) 446 | }) 447 | 448 | b.Handle(cmdStop, func(c tele.Context) error { 449 | return nil 450 | }) 451 | 452 | b.Handle(cmdInfo, func(c tele.Context) error { 453 | chat := s.getChat(c.Chat(), c.Sender()) 454 | status := "disabled" 455 | if chat.Stream { 456 | status = "enabled" 457 | } 458 | status = chat.t(status) 459 | 460 | prompt := chat.MasterPrompt 461 | role := chat.t("default") 462 | if chat.RoleID != nil { 463 | prompt = chat.Role.Prompt 464 | role = chat.Role.Name 465 | } 466 | 467 | return c.Send( 468 | fmt.Sprintf( 469 | "Version: %s\nModel: %s\nTemperature: %0.2f\nPrompt: %s\nStreaming: %s\nConvesation Age (days): %d\nRole: %s", 470 | Version, 471 | s.getModel(chat.ModelName), 472 | chat.Temperature, 473 | prompt, 474 | status, 475 | chat.ConversationAge, 476 | role, 477 | ), 478 | "text", 479 | &tele.SendOptions{ReplyTo: c.Message()}, 480 | ) 481 | }) 482 | 483 | b.Handle(cmdToJapanese, func(c tele.Context) error { 484 | go s.onTranslate(c, "To Japanese: ") 485 | 486 | return nil 487 | }) 488 | 489 | b.Handle(cmdToEnglish, func(c tele.Context) error { 490 | go s.onTranslate(c, "To English: ") 491 | 492 | return nil 493 | }) 494 | 495 | b.Handle(cmdToRussian, func(c tele.Context) error { 496 | go s.onTranslate(c, "To Russian: ") 497 | 498 | return nil 499 | }) 500 | 501 | b.Handle(cmdToItalian, func(c tele.Context) error { 502 | go 
s.onTranslate(c, "To Italian: ") 503 | return nil 504 | }) 505 | 506 | b.Handle(cmdToSpanish, func(c tele.Context) error { 507 | go s.onTranslate(c, "To Spanish: ") 508 | 509 | return nil 510 | }) 511 | 512 | b.Handle(cmdToChinese, func(c tele.Context) error { 513 | go s.onTranslate(c, "To Chinese: ") 514 | 515 | return nil 516 | }) 517 | 518 | b.Handle(cmdImage, func(c tele.Context) error { 519 | chat := s.getChat(c.Chat(), c.Sender()) 520 | msg := chat.getSentMessage(c) 521 | msg, _ = c.Bot().Edit(msg, "Generating...") 522 | if err := s.textToImage(c, c.Message().Payload, true); err != nil { 523 | _, _ = c.Bot().Edit(msg, "Generating...") 524 | return c.Send("Error: " + err.Error()) 525 | } 526 | _ = c.Bot().Delete(msg) 527 | 528 | return nil 529 | }) 530 | 531 | b.Handle(cmdLang, func(c tele.Context) error { 532 | chat := s.getChat(c.Chat(), c.Sender()) 533 | langCode := strings.TrimSpace(c.Message().Payload) 534 | if err := ValidateLanguageCode(langCode); err != nil { 535 | return c.Send( 536 | chat.t("Invalid language code: {{.error}}", &i18n.Replacements{"error": err.Error()}), 537 | "text", 538 | &tele.SendOptions{ReplyTo: c.Message()}, 539 | ) 540 | } 541 | chat.Lang = langCode 542 | s.db.Save(&chat) 543 | return c.Send( 544 | fmt.Sprintf("Language set to %s", chat.Lang), 545 | "text", 546 | &tele.SendOptions{ReplyTo: c.Message()}, 547 | ) 548 | }) 549 | 550 | b.Handle(&btnModel, func(c tele.Context) error { 551 | Log.WithField("user", c.Sender().Username).Info("Selected model ", c.Data()) 552 | chat := s.getChat(c.Chat(), c.Sender()) 553 | chat.ModelName = c.Data() 554 | s.db.Save(&chat) 555 | 556 | return c.Edit(chat.t("Model set to {{.model}}", &i18n.Replacements{"model": c.Data()})) 557 | }) 558 | 559 | b.Handle(&btnT0, func(c tele.Context) error { 560 | Log.WithField("user", c.Sender().Username).Info("Selected temperature ", c.Data()) 561 | chat := s.getChat(c.Chat(), c.Sender()) 562 | temp, err := ValidateTemperature(c.Data()) 563 | if err != nil { 
564 | Log.WithField("error", err).Warn("Invalid temperature value") 565 | return c.Edit(chat.t("Invalid temperature value")) 566 | } 567 | chat.Temperature = temp 568 | s.db.Save(&chat) 569 | 570 | return c.Edit(chat.t("Temperature set to {{.temp}}", &i18n.Replacements{"temp": c.Data()})) 571 | }) 572 | 573 | b.Handle(&btnReset, func(c tele.Context) error { 574 | chat := s.getChat(c.Chat(), c.Sender()) 575 | 576 | s.deleteHistory(chat.ID) 577 | s.setChatLastMessageID(nil, chat.ChatID) 578 | 579 | return c.Edit(removeMenu) 580 | }) 581 | 582 | b.Handle(cmdReset, func(c tele.Context) error { 583 | chat := s.getChat(c.Chat(), c.Sender()) 584 | // Log.Info("Resetting chat") 585 | s.deleteHistory(chat.ID) 586 | if chat.MessageID != nil { 587 | id, _ := strconv.Atoi(*chat.MessageID) 588 | sentMessage := &tele.Message{ID: id, Chat: &tele.Chat{ID: chat.ChatID}} 589 | 590 | // Log.Infof("Resetting chat menu, sentMessage: %v", sentMessage) 591 | c.Bot().Edit(sentMessage, removeMenu) 592 | s.setChatLastMessageID(nil, chat.ChatID) 593 | 594 | return nil 595 | } 596 | s.setChatLastMessageID(nil, chat.ChatID) 597 | 598 | return nil 599 | }) 600 | 601 | b.Handle(tele.OnText, func(c tele.Context) error { 602 | chat := s.getChat(c.Chat(), c.Sender()) 603 | 604 | // not handling user input through stepper/state machine 605 | if chat.User.State == nil { 606 | // if e := b.React(c.Sender(), c.Message(), react.React(react.Eyes)); e != nil { 607 | // Log.Warn(e) 608 | // } 609 | 610 | go s.onText(c) 611 | } else { 612 | chat.removeMenu(c) 613 | // in the middle of stepper input 614 | go s.onState(c) 615 | } 616 | 617 | return nil 618 | }) 619 | 620 | b.Handle(tele.OnQuery, func(c tele.Context) error { 621 | query := c.Query().Text 622 | article := &tele.ArticleResult{Title: "N/A"} 623 | result, err := s.anonymousAnswer(c, query) 624 | if err != nil { 625 | article = &tele.ArticleResult{ 626 | Title: "Error!", 627 | Text: err.Error(), 628 | } 629 | } else { 630 | article = 
&tele.ArticleResult{ 631 | Title: query, 632 | Text: result, 633 | } 634 | } 635 | 636 | results := make(tele.Results, 1) 637 | results[0] = article 638 | // needed to set a unique string ID for each result 639 | id := uuid.New() 640 | results[0].SetResultID(id.String()) 641 | 642 | c.Answer(&tele.QueryResponse{Results: results, CacheTime: 100}) 643 | return nil 644 | }) 645 | 646 | b.Handle(tele.OnDocument, func(c tele.Context) error { 647 | chat := s.getChat(c.Chat(), c.Sender()) 648 | go s.onDocument(c) 649 | 650 | // b.React(c.Recipient(), c.Message(), react.React(react.Eyes)) 651 | 652 | return c.Send(chat.t("Processing document. Please wait...")) 653 | }) 654 | 655 | b.Handle(tele.OnVoice, func(c tele.Context) error { 656 | go s.onVoice(c) 657 | 658 | return nil 659 | }) 660 | 661 | b.Handle(tele.OnPhoto, func(c tele.Context) error { 662 | go s.onPhoto(c) 663 | 664 | return nil 665 | }) 666 | 667 | b.Handle(cmdUsers, func(c tele.Context) error { 668 | if !in_array(c.Sender().Username, s.conf.AllowedTelegramUsers) { 669 | return nil 670 | } 671 | go s.onGetUsers(c) 672 | 673 | return nil 674 | }) 675 | 676 | b.Handle(cmdAddUser, func(c tele.Context) error { 677 | if !in_array(c.Sender().Username, s.conf.AllowedTelegramUsers) { 678 | return nil 679 | } 680 | name := strings.TrimSpace(c.Message().Payload) 681 | if err := ValidateUsername(name); err != nil { 682 | return c.Send( 683 | fmt.Sprintf("Invalid username: %s", err.Error()), 684 | "text", 685 | &tele.SendOptions{ReplyTo: c.Message()}, 686 | ) 687 | } 688 | s.addUser(name) 689 | s.loadUsers() 690 | 691 | go s.onGetUsers(c) 692 | 693 | return nil 694 | }) 695 | 696 | b.Handle(cmdDelUser, func(c tele.Context) error { 697 | if !in_array(c.Sender().Username, s.conf.AllowedTelegramUsers) { 698 | return nil 699 | } 700 | name := strings.TrimSpace(c.Message().Payload) 701 | if err := ValidateUsername(name); err != nil { 702 | return c.Send( 703 | fmt.Sprintf("Invalid username: %s", err.Error()), 704 | "text", 
705 | &tele.SendOptions{ReplyTo: c.Message()}, 706 | ) 707 | } 708 | s.delUser(name) 709 | s.loadUsers() 710 | 711 | go s.onGetUsers(c) 712 | 713 | return nil 714 | }) 715 | 716 | b.Start() 717 | } 718 | 719 | // Restrict returns a middleware that handles a list of provided 720 | // usernames with the logic defined by In and Out functions. 721 | // If the username is found in the Usernames field, In function will be called, 722 | // otherwise Out function will be called. 723 | func Restrict(v RestrictConfig) tele.MiddlewareFunc { 724 | return func(next tele.HandlerFunc) tele.HandlerFunc { 725 | if v.In == nil { 726 | v.In = next 727 | } 728 | if v.Out == nil { 729 | v.Out = next 730 | } 731 | return func(c tele.Context) error { 732 | for _, username := range v.Usernames { 733 | if username == c.Sender().Username { 734 | return v.In(c) 735 | } 736 | } 737 | return v.Out(c) 738 | } 739 | } 740 | } 741 | 742 | // Whitelist returns a middleware that skips the update for users 743 | // NOT specified in the usernames field. 
744 | func (s *Server) whitelist() tele.MiddlewareFunc { 745 | return func(next tele.HandlerFunc) tele.HandlerFunc { 746 | return Restrict(RestrictConfig{ 747 | Usernames: s.users, 748 | In: next, 749 | Out: func(c tele.Context) error { 750 | return c.Send( 751 | fmt.Sprintf("not allowed: %s", c.Sender().Username), 752 | "text", 753 | &tele.SendOptions{ReplyTo: c.Message()}, 754 | ) 755 | }, 756 | })(next) 757 | } 758 | } 759 | -------------------------------------------------------------------------------- /chat.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "strconv" 7 | "time" 8 | 9 | "github.com/meinside/openai-go" 10 | "github.com/tectiv3/awsnova-go" 11 | "github.com/tectiv3/chatgpt-bot/i18n" 12 | tele "gopkg.in/telebot.v3" 13 | ) 14 | 15 | func (c *Chat) getSentMessage(context tele.Context) *tele.Message { 16 | c.mutex.Lock() 17 | defer c.mutex.Unlock() 18 | if c.MessageID != nil { 19 | id, _ := strconv.Atoi(*c.MessageID) 20 | 21 | return &tele.Message{ID: id, Chat: &tele.Chat{ID: c.ChatID}} 22 | } 23 | // if we already have a message ID, use it, otherwise create a new message 24 | if context.Get("reply") != nil { 25 | sentMessage := context.Get("reply").(tele.Message) 26 | c.MessageID = &([]string{strconv.Itoa(sentMessage.ID)}[0]) 27 | return &sentMessage 28 | } 29 | 30 | msgPointer, _ := context.Bot().Send(context.Recipient(), "...", "text", &tele.SendOptions{ReplyTo: context.Message()}) 31 | c.MessageID = &([]string{strconv.Itoa(msgPointer.ID)}[0]) 32 | 33 | return msgPointer 34 | } 35 | 36 | func (c *Chat) addToolResultToDialog(id, content string) { 37 | c.mutex.Lock() 38 | defer c.mutex.Unlock() 39 | msg := openai.NewChatToolMessage(id, content) 40 | // log.Printf("Adding tool message to history: %v\n", msg) 41 | c.History = append(c.History, 42 | ChatMessage{ 43 | Role: msg.Role, 44 | Content: &content, 45 | ChatID: c.ChatID, 46 | ToolCallID: &id, 47 | 
CreatedAt: time.Now(), 48 | }) 49 | } 50 | 51 | func (c *Chat) addImageToDialog(text, path string) { 52 | c.mutex.Lock() 53 | defer c.mutex.Unlock() 54 | 55 | c.History = append(c.History, 56 | ChatMessage{ 57 | Role: openai.ChatMessageRoleUser, 58 | Content: &text, 59 | ImagePath: &path, 60 | ChatID: c.ChatID, 61 | CreatedAt: time.Now(), 62 | }) 63 | } 64 | 65 | func (c *Chat) addMessageToDialog(msg openai.ChatMessage) { 66 | c.mutex.Lock() 67 | defer c.mutex.Unlock() 68 | // log.Printf("Adding message to history: %v\n", msg) 69 | toolCalls := make([]ToolCall, 0) 70 | for _, tc := range msg.ToolCalls { 71 | toolCalls = append(toolCalls, ToolCall{ 72 | ID: tc.ID, 73 | Type: tc.Type, 74 | Function: tc.Function, 75 | }) 76 | } 77 | content, err := msg.ContentString() 78 | if err != nil { 79 | if contentArr, err := msg.ContentArray(); err == nil { 80 | for _, c := range contentArr { 81 | if c.Type == "text" { 82 | content = *c.Text 83 | break 84 | } 85 | //if c.Type == "image_url" { 86 | // 87 | //} 88 | } 89 | } 90 | } 91 | c.History = append(c.History, 92 | ChatMessage{ 93 | Role: msg.Role, 94 | Content: &content, 95 | ToolCalls: toolCalls, 96 | ChatID: c.ChatID, 97 | CreatedAt: time.Now(), 98 | }) 99 | } 100 | 101 | func (c *Chat) getDialog(request *string) []openai.ChatMessage { 102 | prompt := c.MasterPrompt 103 | if c.RoleID != nil { 104 | prompt = c.Role.Prompt 105 | } 106 | 107 | system := openai.NewChatSystemMessage(prompt) 108 | if request != nil { 109 | c.addMessageToDialog(openai.NewChatUserMessage(*request)) 110 | } 111 | 112 | history := []openai.ChatMessage{system} 113 | for _, h := range c.History { 114 | if h.CreatedAt.Before( 115 | time.Now().AddDate(0, 0, -int(c.ConversationAge)), 116 | ) { 117 | continue 118 | } 119 | 120 | var message openai.ChatMessage 121 | 122 | if h.ImagePath != nil { 123 | reader, err := os.Open(*h.ImagePath) 124 | if err != nil { 125 | Log.Warn("Error opening image file", "error=", err) 126 | continue 127 | } 128 | defer 
reader.Close() 129 | 130 | image, err := io.ReadAll(reader) 131 | if err != nil { 132 | Log.Warn("Error reading file content", "error=", err) 133 | continue 134 | } 135 | content := []openai.ChatMessageContent{{Type: "text", Text: h.Content}} 136 | content = append(content, openai.NewChatMessageContentWithBytes(image)) 137 | message = openai.ChatMessage{Role: h.Role, Content: content} 138 | } else { 139 | message = openai.ChatMessage{Role: h.Role, Content: h.Content} 140 | } 141 | if h.Role == openai.ChatMessageRoleAssistant && h.ToolCalls != nil { 142 | message.ToolCalls = make([]openai.ToolCall, 0) 143 | for _, tc := range h.ToolCalls { 144 | message.ToolCalls = append(message.ToolCalls, openai.ToolCall{ 145 | ID: tc.ID, 146 | Type: tc.Type, 147 | Function: tc.Function, 148 | }) 149 | } 150 | } 151 | if h.ToolCallID != nil { 152 | message.ToolCallID = h.ToolCallID 153 | } 154 | history = append(history, message) 155 | } 156 | 157 | // Log.Infof("Dialog history: %v", history) 158 | 159 | return history 160 | } 161 | 162 | func (c *Chat) getNovaDialog(request *string) []awsnova.Message { 163 | if request != nil { 164 | c.addMessageToDialog(openai.NewChatUserMessage(*request)) 165 | } 166 | 167 | history := []awsnova.Message{} 168 | for _, h := range c.History { 169 | if h.CreatedAt.Before( 170 | time.Now().AddDate(0, 0, -int(c.ConversationAge)), 171 | ) { 172 | continue 173 | } 174 | 175 | var message awsnova.Message 176 | 177 | if h.ImagePath != nil { 178 | reader, err := os.Open(*h.ImagePath) 179 | if err != nil { 180 | Log.Warn("Error opening image file", "error=", err) 181 | continue 182 | } 183 | defer reader.Close() 184 | 185 | image, err := io.ReadAll(reader) 186 | if err != nil { 187 | Log.Warn("Error reading file content", "error=", err) 188 | continue 189 | } 190 | content := []awsnova.Content{{ 191 | Text: h.Content, 192 | Image: &awsnova.Image{Format: "png", Source: struct { 193 | Bytes string `json:"bytes"` 194 | }{Bytes: toBase64(image)}}, 195 | }} 
196 | message = awsnova.Message{Role: string(h.Role), Content: content} 197 | } else { 198 | message = awsnova.Message{Role: string(h.Role), Content: []awsnova.Content{{ 199 | Text: h.Content, 200 | }}} 201 | } 202 | // if h.Role == "assistant" && h.ToolCalls != nil { 203 | // message.ToolCalls = make([]openai.ToolCall, 0) 204 | // for _, tc := range h.ToolCalls { 205 | // message.ToolCalls = append(message.ToolCalls, openai.ToolCall{ 206 | // ID: tc.ID, 207 | // Type: tc.Type, 208 | // Function: tc.Function, 209 | // }) 210 | // } 211 | // } 212 | // if h.ToolCallID != nil { 213 | // message.ToolCallID = h.ToolCallID 214 | // } 215 | history = append(history, message) 216 | } 217 | 218 | // Log.Infof("Dialog history: %v", history) 219 | 220 | return history 221 | } 222 | 223 | func (c *Chat) t(key string, replacements ...*i18n.Replacements) string { 224 | return l.GetWithLocale(c.Lang, key, replacements...) 225 | } 226 | 227 | // Safe methods for updating chat properties 228 | func (c *Chat) updateTotalTokens(tokens int) { 229 | c.mutex.Lock() 230 | defer c.mutex.Unlock() 231 | c.TotalTokens += tokens 232 | } 233 | 234 | func (c *Chat) setMessageID(id *string) { 235 | c.mutex.Lock() 236 | defer c.mutex.Unlock() 237 | c.MessageID = id 238 | } 239 | 240 | func (c *Chat) getMessageID() *string { 241 | c.mutex.Lock() 242 | defer c.mutex.Unlock() 243 | return c.MessageID 244 | } 245 | 246 | func (c *Chat) removeMenu(context tele.Context) { 247 | c.mutex.Lock() 248 | if c.MessageID != nil { 249 | _, _ = context.Bot().EditReplyMarkup(tele.StoredMessage{MessageID: *c.MessageID, ChatID: c.ChatID}, removeMenu) 250 | c.MessageID = nil 251 | } 252 | c.mutex.Unlock() 253 | } 254 | -------------------------------------------------------------------------------- /cmd/i18n/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tectiv3/chatgpt-bot/i18n 2 | 3 | go 1.22.1 4 | 
-------------------------------------------------------------------------------- /cmd/i18n/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "log" 10 | "os" 11 | "path/filepath" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | const jsonFileExt = ".json" 17 | const defaultOutputDir = "localizations" 18 | 19 | type localizationFile map[string]string 20 | 21 | var ( 22 | input = flag.String("input", "", "input localizations folder") 23 | output = flag.String("output", "", "where to output the generated package") 24 | 25 | errFlagInputNotSet = errors.New("the flag -input must be set") 26 | ) 27 | 28 | func main() { 29 | flag.Parse() 30 | 31 | if err := run(input, output); err != nil { 32 | log.Fatal(err.Error()) 33 | } 34 | } 35 | 36 | func run(in, out *string) error { 37 | inputDir, outputDir, err := parseFlags(in, out) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | files, err := getLocalizationFiles(inputDir) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | localizations, err := generateLocalizations(files) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | return generateFile(outputDir, localizations) 53 | } 54 | 55 | func generateLocalizations(files []string) (map[string]string, error) { 56 | localizations := map[string]string{} 57 | for _, file := range files { 58 | newLocalizations, err := getLocalizationsFromFile(file) 59 | if err != nil { 60 | return nil, err 61 | } 62 | for key, value := range newLocalizations { 63 | localizations[key] = value 64 | } 65 | } 66 | return localizations, nil 67 | } 68 | 69 | func getLocalizationFiles(dir string) ([]string, error) { 70 | var files []string 71 | err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { 72 | if err != nil { 73 | return err 74 | } 75 | ext := filepath.Ext(path) 76 | if !info.IsDir() && (ext == jsonFileExt) { 77 | files = 
append(files, path) 78 | } 79 | return nil 80 | }) 81 | return files, err 82 | } 83 | 84 | func generateFile(output string, localizations map[string]string) error { 85 | dir := output 86 | parent := output 87 | if strings.Contains(output, string(filepath.Separator)) { 88 | parent = filepath.Base(dir) 89 | } 90 | 91 | err := os.MkdirAll(output, 0700) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | f, err := os.Create(fmt.Sprintf("%v/%v.go", dir, parent)) 97 | if err != nil { 98 | return err 99 | } 100 | 101 | maxWidth := 0 102 | for name := range localizations { 103 | if len(name) > maxWidth { 104 | maxWidth = len(name) 105 | } 106 | } 107 | 108 | lineUp := func(name string) string { 109 | return strings.Repeat(" ", maxWidth-len(name)) 110 | } 111 | 112 | return packageTemplate.Execute(f, struct { 113 | Timestamp time.Time 114 | Localizations map[string]string 115 | Package string 116 | LineUp func(string) string 117 | }{ 118 | Timestamp: time.Now(), 119 | Localizations: localizations, 120 | Package: parent, 121 | LineUp: lineUp, 122 | }) 123 | } 124 | 125 | func getLocalizationsFromFile(file string) (map[string]string, error) { 126 | newLocalizations := map[string]string{} 127 | 128 | openFile, err := os.Open(file) 129 | if err != nil { 130 | return nil, err 131 | } 132 | 133 | byteValue, err := io.ReadAll(openFile) 134 | if err != nil { 135 | return nil, err 136 | } 137 | 138 | localizationFile := localizationFile{} 139 | if err := json.Unmarshal(byteValue, &localizationFile); err != nil { 140 | return nil, err 141 | } 142 | 143 | slicePath := getSlicePath(file) 144 | for key, value := range localizationFile { 145 | newLocalizations[strings.Join(append(slicePath, key), ".")] = value 146 | } 147 | 148 | return newLocalizations, nil 149 | } 150 | 151 | func getSlicePath(file string) []string { 152 | dir, file := filepath.Split(file) 153 | 154 | paths := strings.Replace(dir, *input, "", -1) 155 | pathSlice := strings.Split(paths, string(filepath.Separator)) 156 
| 157 | var strs []string 158 | for _, part := range pathSlice { 159 | part := strings.TrimSpace(part) 160 | part = strings.Trim(part, "/") 161 | if part != "" { 162 | strs = append(strs, part) 163 | } 164 | } 165 | 166 | strs = append(strs, strings.Replace(file, filepath.Ext(file), "", -1)) 167 | return strs 168 | } 169 | 170 | func parseFlags(input *string, output *string) (string, string, error) { 171 | var inputDir, outputDir string 172 | 173 | if *input == "" { 174 | return "", "", errFlagInputNotSet 175 | } 176 | if *output == "" { 177 | outputDir = defaultOutputDir 178 | } else { 179 | outputDir = *output 180 | } 181 | 182 | inputDir = *input 183 | 184 | return inputDir, outputDir, nil 185 | } 186 | -------------------------------------------------------------------------------- /cmd/i18n/template.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "text/template" 5 | ) 6 | 7 | var packageTemplate = template.Must(template.New("").Parse(`// Code generated by go-localize; DO NOT EDIT. 
8 | // This file was generated by robots at 9 | // {{ .Timestamp }} 10 | 11 | package {{ .Package }} 12 | 13 | import ( 14 | "bytes" 15 | "fmt" 16 | "strings" 17 | "text/template" 18 | ) 19 | 20 | var localizations = map[string]string{ 21 | {{- range $key, $element := .Localizations }} 22 | "{{ $key }}":{{ call $.LineUp $key }} "{{ $element }}", 23 | {{- end }} 24 | } 25 | 26 | type Replacements map[string]interface{} 27 | 28 | type Localizer struct { 29 | Locale string 30 | FallbackLocale string 31 | Localizations map[string]string 32 | } 33 | 34 | func New(locale string, fallbackLocale string) *Localizer { 35 | t := &Localizer{Locale: locale, FallbackLocale: fallbackLocale} 36 | t.Localizations = localizations 37 | return t 38 | } 39 | 40 | func (t Localizer) SetLocales(locale, fallback string) Localizer { 41 | t.Locale = locale 42 | t.FallbackLocale = fallback 43 | return t 44 | } 45 | 46 | func (t Localizer) SetLocale(locale string) Localizer { 47 | t.Locale = locale 48 | return t 49 | } 50 | 51 | func (t Localizer) SetFallbackLocale(fallback string) Localizer { 52 | t.FallbackLocale = fallback 53 | return t 54 | } 55 | 56 | func (t Localizer) GetWithLocale(locale, key string, replacements ...*Replacements) string { 57 | str, ok := t.Localizations[t.getLocalizationKey(locale, key)] 58 | if !ok { 59 | str, ok = t.Localizations[t.getLocalizationKey(t.FallbackLocale, key)] 60 | if !ok { 61 | if strings.Index(key, "}}") == -1 { 62 | return key 63 | } 64 | return t.replace(key, replacements...) 65 | } 66 | } 67 | 68 | // If the str doesn't have any substitutions, no need to 69 | // template.Execute. 70 | if strings.Index(str, "}}") == -1 { 71 | return str 72 | } 73 | 74 | return t.replace(str, replacements...) 75 | } 76 | 77 | func (t Localizer) Get(key string, replacements ...*Replacements) string { 78 | str := t.GetWithLocale(t.Locale, key, replacements...) 
79 | return str 80 | } 81 | 82 | func (t Localizer) getLocalizationKey(locale string, key string) string { 83 | return fmt.Sprintf("%v.%v", locale, key) 84 | } 85 | 86 | func (t Localizer) replace(str string, replacements ...*Replacements) string { 87 | b := &bytes.Buffer{} 88 | tmpl, err := template.New("").Parse(str) 89 | if err != nil { 90 | return str 91 | } 92 | 93 | replacementsMerge := Replacements{} 94 | for _, replacement := range replacements { 95 | for k, v := range *replacement { 96 | replacementsMerge[k] = v 97 | } 98 | } 99 | 100 | err = template.Must(tmpl, err).Execute(b, replacementsMerge) 101 | if err != nil { 102 | return str 103 | } 104 | buff := b.String() 105 | return buff 106 | } 107 | `, 108 | )) 109 | -------------------------------------------------------------------------------- /config.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "telegram_bot_token": "", 3 | "openai_api_key": "", 4 | "openai_org_id": "", 5 | "allowed_telegram_users": [], 6 | "verbose": false, 7 | } -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import tele "gopkg.in/telebot.v3" 4 | 5 | // getChat returns chat from db or creates a new one 6 | func (s *Server) getChat(c *tele.Chat, u *tele.User) *Chat { 7 | var chat Chat 8 | 9 | s.db.Preload("User").Preload("User.Roles").Preload("Role").Preload("History").FirstOrCreate(&chat, Chat{ChatID: c.ID}) 10 | if len(chat.MasterPrompt) == 0 { 11 | chat.MasterPrompt = masterPrompt 12 | chat.ModelName = openAILatest 13 | chat.Temperature = 0.8 14 | chat.Stream = true 15 | chat.ConversationAge = 1 16 | s.db.Save(&chat) 17 | } 18 | 19 | if chat.UserID == 0 { 20 | user := s.getUser(u.Username) 21 | chat.UserID = user.ID 22 | s.db.Save(&chat) 23 | } 24 | 25 | if chat.ConversationAge == 0 { 26 | chat.ConversationAge = 1 27 | 
s.db.Save(&chat) 28 | } 29 | if chat.Lang == "" { 30 | chat.Lang = u.LanguageCode 31 | s.db.Save(&chat) 32 | } 33 | 34 | // s.db.Find(&chat.History, "chat_id = ?", chat.ID) 35 | // log.Printf("History %d, chatid %d\n", len(chat.History), chat.ID) 36 | 37 | return &chat 38 | } 39 | 40 | func (s *Server) getChatByID(chatID int64) *Chat { 41 | var chat Chat 42 | s.db.First(&chat, Chat{ChatID: chatID}) 43 | s.db.Find(&chat.History, "chat_id = ?", chat.ID) 44 | 45 | return &chat 46 | } 47 | 48 | // getUsers returns all users from db 49 | func (s *Server) getUsers() []User { 50 | var users []User 51 | s.db.Model(&User{}).Preload("Threads").Preload("Threads.Role").Find(&users) 52 | 53 | return users 54 | } 55 | 56 | // getUser returns user from db 57 | func (s *Server) getUser(username string) (user User) { 58 | s.db.First(&user, User{Username: username}) 59 | 60 | return user 61 | } 62 | 63 | func (s *Server) addUser(username string) { 64 | s.db.Create(&User{Username: username}) 65 | } 66 | 67 | func (s *Server) delUser(username string) { 68 | user := s.getUser(username) 69 | if user.ID == 0 { 70 | Log.Info("User not found: ", username) 71 | return 72 | } 73 | 74 | var chat Chat 75 | s.db.First(&chat, Chat{UserID: user.ID}) 76 | if chat.ID != 0 { 77 | s.deleteHistory(chat.ID) 78 | s.db.Unscoped().Delete(&Chat{}, chat.ID) 79 | } 80 | s.db.Unscoped().Delete(&User{}, user.ID) 81 | } 82 | 83 | func (s *Server) deleteHistory(chatID uint) { 84 | s.db.Where("chat_id = ?", chatID).Delete(&ChatMessage{}) 85 | } 86 | 87 | func (s *Server) loadUsers() { 88 | s.Lock() 89 | defer s.Unlock() 90 | admins := s.conf.AllowedTelegramUsers 91 | var usernames []string 92 | s.db.Model(&User{}).Pluck("username", &usernames) 93 | s.users = []string{} 94 | for _, username := range admins { 95 | if !in_array(username, usernames) { 96 | usernames = append(usernames, username) 97 | } 98 | } 99 | s.users = append(s.users, usernames...) 
100 | } 101 | 102 | func (s *Server) findRole(userID uint, name string) *Role { 103 | var r Role 104 | s.db.First(&r, Role{UserID: userID, Name: name}) 105 | 106 | return &r 107 | } 108 | 109 | func (s *Server) getModel(model string) *AiModel { 110 | for _, m := range s.conf.Models { 111 | if m.Name == model { 112 | return &m 113 | } 114 | if m.ModelID == model { 115 | return &m 116 | } 117 | if model == openAILatest { 118 | return &AiModel{s.conf.OpenAILatestModel, "OpenAI Latest", "openai"} 119 | } 120 | } 121 | 122 | return &AiModel{model, model, "openai"} 123 | } 124 | 125 | func (s *Server) getRole(id uint) *Role { 126 | var r Role 127 | s.db.First(&r, id) 128 | 129 | return &r 130 | } 131 | 132 | func (s *Server) setChatRole(id *uint, ChatID int64) { 133 | s.db.Model(&Chat{}).Where("chat_id", ChatID).Update("role_id", id) 134 | } 135 | 136 | func (s *Server) setChatLastMessageID(id *string, ChatID int64) { 137 | s.db.Model(&Chat{}).Where("chat_id", ChatID).Update("message_id", id) 138 | } 139 | 140 | // set user.State to null 141 | func (s *Server) resetUserState(user User) { 142 | s.db.Model(&User{}).Where("id", user.ID).Update("State", nil) 143 | } 144 | -------------------------------------------------------------------------------- /deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # check if number is greater or equal to 2 4 | if [ "$#" -lt 2 ]; then 5 | echo "Usage: $0 ssh_host deploy_path [arm|x86]" 6 | exit 1 7 | fi 8 | 9 | # Define the SSH host 10 | SSH_HOST=$1 11 | DEPLOY_PATH=$2 12 | ARCH="x86" 13 | 14 | if [ "$#" -eq 3 ]; then 15 | ARCH=$3 16 | fi 17 | if [ $ARCH == "arm" ]; then 18 | echo "Building for ARM" 19 | env CGO_LDFLAGS="-Llib_arm -lmp3lame -lopus -logg" GOOS=linux GOARCH=arm64 CGO_ENABLED=1 CC=aarch64-unknown-linux-gnu-gcc go build 20 | else 21 | echo "Building for x86" 22 | env CGO_LDFLAGS="-Llib_x86 -lmp3lame -lopus -logg" GOOS=linux GOARCH=amd64 CGO_ENABLED=1 
CC=x86_64-linux-gnu-gcc go build 23 | fi 24 | 25 | #env CGO_LDFLAGS="-Llib -lmp3lame -lopus -logg" GOOS=linux GOARCH=arm64 CGO_ENABLED=1 CC=x86_64-linux-gnu-gcc go build 26 | ssh $SSH_HOST "sudo service gptbot stop" > /dev/null 2>&1 27 | scp chatgpt-bot $SSH_HOST:$DEPLOY_PATH 28 | ssh $SSH_HOST "sudo service gptbot start" > /dev/null 2>&1 29 | -------------------------------------------------------------------------------- /erased.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tectiv3/chatgpt-bot/fa8d55c077f3e99bb4a6eb7bf4a2f1567fc3f6a8/erased.ogg -------------------------------------------------------------------------------- /function_calls.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "runtime/debug" 9 | "strconv" 10 | "strings" 11 | "sync" 12 | "time" 13 | 14 | "github.com/go-shiori/go-readability" 15 | "github.com/meinside/openai-go" 16 | "github.com/tectiv3/chatgpt-bot/i18n" 17 | "github.com/tectiv3/chatgpt-bot/tools" 18 | "github.com/tectiv3/chatgpt-bot/vectordb" 19 | tele "gopkg.in/telebot.v3" 20 | ) 21 | 22 | func (s *Server) getFunctionTools() []openai.ChatCompletionTool { 23 | availableTools := []openai.ChatCompletionTool{ 24 | /* 25 | openai.NewChatCompletionTool( 26 | "search_images", 27 | "Search image or GIFs for a given query", 28 | openai.NewToolFunctionParameters(). 29 | AddPropertyWithDescription("query", "string", "The query to search for"). 30 | AddPropertyWithEnums("type", "string", "The type of image to search for. Default to `photo` if not specified", []string{"photo", "gif"}). 31 | AddPropertyWithEnums("region", "string", 32 | "The region to use for the search. Infer this from the language used for the query. Default to `wt-wt` if not specified or can not be inferred. 
Do not leave it empty.", 33 | []string{ 34 | "xa-ar", "xa-en", "ar-es", "au-en", "at-de", "be-fr", "be-nl", "br-pt", "bg-bg", 35 | "ca-en", "ca-fr", "ct-ca", "cl-es", "cn-zh", "co-es", "hr-hr", "cz-cs", "dk-da", 36 | "ee-et", "fi-fi", "fr-fr", "de-de", "gr-el", "hk-tzh", "hu-hu", "in-en", "id-id", 37 | "id-en", "ie-en", "il-he", "it-it", "jp-jp", "kr-kr", "lv-lv", "lt-lt", "xl-es", 38 | "my-ms", "my-en", "mx-es", "nl-nl", "nz-en", "no-no", "pe-es", "ph-en", "ph-tl", 39 | "pl-pl", "pt-pt", "ro-ro", "ru-ru", "sg-en", "sk-sk", "sl-sl", "za-en", "es-es", 40 | "se-sv", "ch-de", "ch-fr", "ch-it", "tw-tzh", "th-th", "tr-tr", "ua-uk", "uk-en", 41 | "us-en", "ue-es", "ve-es", "vn-vi", "wt-wt", 42 | }). 43 | SetRequiredParameters([]string{"query", "type", "region"}), 44 | ), 45 | openai.NewChatCompletionTool( 46 | "web_search", 47 | "This is web search. Use this tool to search the internet. Use it when you need access to real time information. The top 10 results will be added to the vector db. The top 3 results are also getting returned to you directly. For more search queries through the same websites, use the vector_search tool. Input should be a string. Append sources to the response.", 48 | openai.NewToolFunctionParameters(). 49 | AddPropertyWithDescription("query", "string", "A query to search the web for"). 50 | // AddPropertyWithEnums("region", "string", 51 | // "The region to use for the search. Infer this from the language used for the query. Default to `wt-wt` if not specified or can not be inferred. 
Do not leave it empty.", 52 | // []string{"xa-ar", "xa-en", "ar-es", "au-en", "at-de", "be-fr", "be-nl", "br-pt", "bg-bg", 53 | // "ca-en", "ca-fr", "ct-ca", "cl-es", "cn-zh", "co-es", "hr-hr", "cz-cs", "dk-da", 54 | // "ee-et", "fi-fi", "fr-fr", "de-de", "gr-el", "hk-tzh", "hu-hu", "in-en", "id-id", 55 | // "id-en", "ie-en", "il-he", "it-it", "jp-jp", "kr-kr", "lv-lv", "lt-lt", "xl-es", 56 | // "my-ms", "my-en", "mx-es", "nl-nl", "nz-en", "no-no", "pe-es", "ph-en", "ph-tl", 57 | // "pl-pl", "pt-pt", "ro-ro", "ru-ru", "sg-en", "sk-sk", "sl-sl", "za-en", "es-es", 58 | // "se-sv", "ch-de", "ch-fr", "ch-it", "tw-tzh", "th-th", "tr-tr", "ua-uk", "uk-en", 59 | // "us-en", "ue-es", "ve-es", "vn-vi", "wt-wt"}). 60 | SetRequiredParameters([]string{"query"}), 61 | ), 62 | */ 63 | openai.NewChatCompletionTool( 64 | "text_to_speech", 65 | "Convert provided text to speech.", 66 | openai.NewToolFunctionParameters(). 67 | AddPropertyWithDescription("text", "string", "A text to use."). 68 | AddPropertyWithEnums("language", "string", 69 | "The language to use for the speech synthesis. Default to `en` if could not be detected.", 70 | []string{"fr", "ru", "en", "ja", "ua", "de", "es", "it", "tw"}). 71 | SetRequiredParameters([]string{"text", "language"}), 72 | ), 73 | /* 74 | openai.NewChatCompletionTool( 75 | "full_webpage_to_speech", 76 | "Download full web page and convert it to speech. Use ONLY when you need to pass the full content of a web page to the speech synthesiser.", 77 | openai.NewToolFunctionParameters(). 78 | AddPropertyWithDescription("url", "string", "A valid URL to a web page, should not end in PDF."). 79 | SetRequiredParameters([]string{"url"}), 80 | ), 81 | openai.NewChatCompletionTool( 82 | "vector_search", 83 | `Useful for searching through added files and websites. Search for keywords in the text not whole questions, avoid relative words like "yesterday" think about what could be in the text. The input to this tool will be run against a vector db. 
The top results will be returned as json.`, 84 | openai.NewToolFunctionParameters(). 85 | AddPropertyWithDescription("query", "string", "A query to search the vector db"). 86 | SetRequiredParameters([]string{"query"}), 87 | ), 88 | */ 89 | openai.NewChatCompletionTool( 90 | "set_reminder", 91 | "Set a reminder to do something at a specific time.", 92 | openai.NewToolFunctionParameters(). 93 | AddPropertyWithDescription("reminder", "string", "A reminder of what to do, e.g. 'buy groceries'"). 94 | AddPropertyWithDescription("time", "number", "A time at which to be reminded in minutes from now, e.g. 1440"). 95 | SetRequiredParameters([]string{"reminder", "time"}), 96 | ), 97 | openai.NewChatCompletionTool( 98 | "make_summary", 99 | "Make a summary of a web page.", 100 | openai.NewToolFunctionParameters(). 101 | AddPropertyWithDescription("url", "string", "A valid URL to a web page"). 102 | SetRequiredParameters([]string{"url"}), 103 | ), 104 | openai.NewChatCompletionTool( 105 | "get_crypto_rate", 106 | "Get the current rate of various crypto currencies", 107 | openai.NewToolFunctionParameters(). 108 | AddPropertyWithDescription("asset", "string", "Asset of the crypto"). 109 | SetRequiredParameters([]string{"asset"}), 110 | ), 111 | } 112 | 113 | // availableTools = append(availableTools, 114 | // openai.NewChatCompletionTool( 115 | // "generate_image", 116 | // "Generate an image based on the input text", 117 | // openai.NewToolFunctionParameters(). 118 | // AddPropertyWithDescription("text", "string", "The text to generate an image from"). 119 | // AddPropertyWithDescription("hd", "boolean", "Whether to generate an HD image. Default to false."). 
120 | // SetRequiredParameters([]string{"text", "hd"}), 121 | // ), 122 | // ) 123 | 124 | return availableTools 125 | } 126 | 127 | func (s *Server) handleResponseFunctionCalls(chat *Chat, c tele.Context, functions []openai.ResponseOutput) (string, error) { 128 | return "", nil 129 | } 130 | 131 | func (s *Server) handleFunctionCall(chat *Chat, c tele.Context, response openai.ChatMessage) (string, error) { 132 | // refactor to handle multiple function calls not just the first one 133 | result := "" 134 | var resultErr error 135 | var toolID string 136 | sentMessage := chat.getSentMessage(c) 137 | toolCallsCount := len(response.ToolCalls) 138 | reply := "" 139 | for i, toolCall := range response.ToolCalls { 140 | function := toolCall.Function 141 | if function.Name == "" { 142 | err := fmt.Sprint("there was no returned function call name") 143 | resultErr = fmt.Errorf(err) 144 | continue 145 | } 146 | Log.WithField("tools", toolCallsCount).WithField("tool", i).WithField("function", function.Name).WithField("user", c.Sender().Username).Info("Function call") 147 | 148 | switch function.Name { 149 | case "search_images": 150 | type parsed struct { 151 | Query string `json:"query"` 152 | Type string `json:"type"` 153 | Region string `json:"region"` 154 | } 155 | var arguments parsed 156 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 157 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 158 | continue 159 | } 160 | if s.conf.Verbose { 161 | Log.Info("Will call ", function.Name, "(", arguments.Query, ", ", arguments.Type, ", ", arguments.Region, ")") 162 | } 163 | _, _ = c.Bot().Edit(sentMessage, 164 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Query), 165 | ) 166 | param, err := tools.NewSearchImageParam(arguments.Query, arguments.Region, arguments.Type) 167 | if err != nil { 168 | resultErr = err 169 | continue 170 | } 171 | result := 
tools.SearchImages(param) 172 | if result.IsErr() { 173 | resultErr = result.Error() 174 | continue 175 | } 176 | res := *result.Unwrap() 177 | if len(res) == 0 { 178 | resultErr = fmt.Errorf("no results found") 179 | continue 180 | } 181 | img := tele.FromURL(res[0].Image) 182 | return "", c.Send(&tele.Photo{ 183 | File: img, 184 | Caption: fmt.Sprintf("%s\n%s", res[0].Title, res[0].Link), 185 | }) 186 | 187 | case "web_search": 188 | type parsed struct { 189 | Query string `json:"query"` 190 | // Region string `json:"region"` 191 | } 192 | var arguments parsed 193 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 194 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 195 | continue 196 | } 197 | if len(reply) > 0 { 198 | reply += "\n" 199 | } 200 | reply += fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Query) 201 | _, _ = c.Bot().Edit(sentMessage, reply) 202 | 203 | if s.conf.Verbose { 204 | Log.Info("Will call ", function.Name, "(", arguments.Query, ")") 205 | } 206 | var err error 207 | result, err = s.webSearchSearX(arguments.Query, "wt-wt", c.Sender().Username) 208 | if err != nil { 209 | Log.Warn("Failed to search web", "error=", err) 210 | continue 211 | } 212 | resultErr = nil 213 | toolID = toolCall.ID 214 | response.Role = openai.ChatMessageRoleAssistant 215 | chat.addMessageToDialog(response) 216 | 217 | case "vector_search": 218 | type parsed struct { 219 | Query string `json:"query"` 220 | } 221 | var arguments parsed 222 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 223 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 224 | continue 225 | } 226 | if s.conf.Verbose { 227 | Log.Info("Will call ", function.Name, "(", arguments.Query, ")") 228 | } 229 | if len(reply) > 0 { 230 | reply += "\n" 231 | } 232 | reply += fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": 
chat.t(function.Name)}), arguments.Query) 233 | _, _ = c.Bot().Edit(sentMessage, reply) 234 | var err error 235 | result, err = s.vectorSearch(arguments.Query, c.Sender().Username) 236 | if err != nil { 237 | Log.Warn("Failed to search vector", "error=", err) 238 | continue 239 | } 240 | resultErr = nil 241 | toolID = toolCall.ID 242 | response.Role = openai.ChatMessageRoleAssistant 243 | chat.addMessageToDialog(response) 244 | 245 | case "text_to_speech": 246 | type parsed struct { 247 | Text string `json:"text"` 248 | Language string `json:"language"` 249 | } 250 | var arguments parsed 251 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 252 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 253 | continue 254 | } 255 | if s.conf.Verbose { 256 | Log.Info("Will call ", function.Name, "(", arguments.Text, ", ", arguments.Language, ")") 257 | } 258 | if len(reply) > 0 { 259 | reply += "\n" 260 | } 261 | reply += fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Text) 262 | _, _ = c.Bot().Edit(sentMessage, reply) 263 | 264 | go s.textToSpeech(c, arguments.Text, arguments.Language) 265 | 266 | case "web_to_speech": 267 | type parsed struct { 268 | URL string `json:"url"` 269 | } 270 | var arguments parsed 271 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 272 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 273 | continue 274 | } 275 | if s.conf.Verbose { 276 | Log.Info("Will call ", function.Name, "(", arguments.URL, ")") 277 | } 278 | _, _ = c.Bot().Edit(sentMessage, 279 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.URL), 280 | ) 281 | 282 | go s.pageToSpeech(c, arguments.URL) 283 | 284 | return "", nil 285 | 286 | case "generate_image": 287 | type parsed struct { 288 | Text string `json:"text"` 289 | HD bool `json:"hd"` 290 | } 291 | var arguments 
parsed 292 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 293 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 294 | continue 295 | } 296 | if s.conf.Verbose { 297 | Log.WithField("user", c.Sender().Username).Info("Will call ", function.Name, "(", arguments.Text, ", ", arguments.HD, ")") 298 | } 299 | _, _ = c.Bot().Edit(sentMessage, 300 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Text), 301 | ) 302 | 303 | if err := s.textToImage(c, arguments.Text, arguments.HD); err != nil { 304 | Log.WithField("user", c.Sender().Username).Warn(err) 305 | } else { 306 | continue 307 | } 308 | 309 | case "set_reminder": 310 | type parsed struct { 311 | Reminder string `json:"reminder"` 312 | Minutes int64 `json:"time"` 313 | } 314 | var arguments parsed 315 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 316 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 317 | continue 318 | } 319 | if s.conf.Verbose { 320 | Log.Info("Will call ", function.Name, "(", arguments.Reminder, ", ", arguments.Minutes, ")") 321 | } 322 | _, _ = c.Bot().Edit(sentMessage, 323 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Reminder+","+strconv.Itoa(int(arguments.Minutes))), 324 | ) 325 | 326 | if err := s.setReminder(c.Chat().ID, arguments.Reminder, arguments.Minutes); err != nil { 327 | resultErr = fmt.Errorf("failed to set reminder: %s", err) 328 | continue 329 | } 330 | 331 | return fmt.Sprintf("Reminder set for %d minutes from now", arguments.Minutes), nil 332 | 333 | case "make_summary": 334 | type parsed struct { 335 | URL string `json:"url"` 336 | } 337 | var arguments parsed 338 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 339 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 340 | continue 341 | } 342 | if s.conf.Verbose { 343 | 
Log.Info("Will call ", function.Name, "(", arguments.URL, ")") 344 | } 345 | _, _ = c.Bot().Edit(sentMessage, 346 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.URL), 347 | ) 348 | go s.getPageSummary(chat, arguments.URL) 349 | continue 350 | 351 | case "get_crypto_rate": 352 | type parsed struct { 353 | Asset string `json:"asset"` 354 | } 355 | var arguments parsed 356 | if err := toolCall.ArgumentsInto(&arguments); err != nil { 357 | resultErr = fmt.Errorf("failed to parse arguments into struct: %s", err) 358 | continue 359 | } 360 | if s.conf.Verbose { 361 | Log.Info("Will call ", function.Name, "(", arguments.Asset, ")") 362 | } 363 | _, _ = c.Bot().Edit(sentMessage, 364 | fmt.Sprintf(chat.t("Action: {{.tool}}\nAction input: %s", &i18n.Replacements{"tool": chat.t(function.Name)}), arguments.Asset)) 365 | 366 | return s.getCryptoRate(arguments.Asset) 367 | } 368 | } 369 | 370 | if len(result) == 0 { 371 | s.saveHistory(chat) 372 | return "", resultErr 373 | } 374 | chat.addToolResultToDialog(toolID, result) 375 | 376 | if chat.Stream { 377 | _ = s.getStreamAnswer(chat, c, nil) 378 | return "", nil 379 | } 380 | 381 | err := s.getAnswer(chat, c, nil) 382 | return "", err 383 | } 384 | 385 | func (s *Server) setReminder(chatID int64, reminder string, minutes int64) error { 386 | timer := time.NewTimer(time.Minute * time.Duration(minutes)) 387 | go func() { 388 | <-timer.C 389 | _, _ = s.bot.Send(tele.ChatID(chatID), reminder) 390 | }() 391 | 392 | return nil 393 | } 394 | 395 | func (s *Server) pageToSpeech(c tele.Context, url string) { 396 | defer func() { 397 | if err := recover(); err != nil { 398 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 399 | } 400 | }() 401 | 402 | article, err := readability.FromURL(url, 30*time.Second) 403 | if err != nil { 404 | Log.Warnf("failed to parse %s, %v", url, err) 405 | return // Fatalf would os.Exit the whole bot; the deferred recover cannot intercept that 406 | } 407 | if s.conf.Verbose { 408 | Log.Info("Page 
title=", article.Title, ", content=", len(article.TextContent)) 409 | } 410 | _ = c.Notify(tele.Typing) 411 | 412 | s.sendAudio(c, article.TextContent) 413 | } 414 | 415 | func (s *Server) getPageSummary(chat *Chat, url string) { 416 | defer func() { 417 | if err := recover(); err != nil { 418 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 419 | } 420 | }() 421 | article, err := readability.FromURL(url, 30*time.Second) 422 | if err != nil { 423 | Log.Warnf("failed to parse %s, %v", url, err) 424 | return // log and give up instead of terminating the process via Fatalf 425 | } 426 | if s.conf.Verbose { 427 | Log.Info("Page title=", article.Title, ", content=", len(article.TextContent)) 428 | } 429 | 430 | msg := openai.NewChatUserMessage(article.TextContent) 431 | // You are acting as a summarization AI, and for the input text please summarize it to the most important 3 to 5 bullet points for brevity: 432 | system := openai.NewChatSystemMessage("Make a summary of the article. Try to be as brief as possible and highlight key points. Use markdown to annotate the summary.") 433 | 434 | history := []openai.ChatMessage{system, msg} 435 | 436 | response, err := s.openAI.CreateChatCompletion(miniModel, history, openai.ChatCompletionOptions{}.SetUser(userAgent(31337)).SetTemperature(0.2)) 437 | if err != nil { 438 | Log.Warn("failed to create chat completion", "error=", err) 439 | s.bot.Send(tele.ChatID(chat.ChatID), err.Error(), "text", replyMenu) 440 | return 441 | } 442 | 443 | chat.TotalTokens += response.Usage.TotalTokens 444 | str, _ := response.Choices[0].Message.ContentString() 445 | chat.addMessageToDialog(openai.NewChatAssistantMessage(str)) 446 | s.db.Save(&chat) 447 | 448 | if _, err := s.bot.Send(tele.ChatID(chat.ChatID), 449 | str, 450 | "text", 451 | &tele.SendOptions{ParseMode: tele.ModeMarkdown}, 452 | replyMenu, 453 | ); err != nil { 454 | Log.Error("Sending", "error=", err) 455 | } 456 | } 457 | 458 | func (s *Server) getCryptoRate(asset string) (string, error) { 459 | asset = strings.ToLower(asset) 
460 | format := "$%0.0f" 461 | switch asset { 462 | case "btc": 463 | asset = "bitcoin" 464 | case "eth": 465 | asset = "ethereum" 466 | case "ltc": 467 | asset = "litecoin" 468 | case "sol": 469 | asset = "solana" 470 | case "xrp": 471 | asset = "ripple" 472 | format = "$%0.3f" 473 | case "xlm": 474 | asset = "stellar" 475 | format = "$%0.3f" 476 | case "ada": 477 | asset = "cardano" 478 | format = "$%0.3f" 479 | } 480 | client := &http.Client{Timeout: 10 * time.Second} 481 | resp, err := client.Get(fmt.Sprintf("https://api.coincap.io/v2/assets/%s", asset)) 482 | if err != nil { 483 | return "", err 484 | } 485 | defer resp.Body.Close() 486 | if resp.StatusCode != http.StatusOK { 487 | return "", fmt.Errorf("coincap returned %s for asset %q", resp.Status, asset) 488 | } 489 | 490 | var symbol CoinCap 491 | if err := json.NewDecoder(resp.Body).Decode(&symbol); err != nil { 492 | return "", err 493 | } 494 | // a missing/empty price fails to parse; report it instead of answering "$0" 495 | price, err := strconv.ParseFloat(symbol.Data.PriceUsd, 64) 496 | if err != nil { 497 | return "", fmt.Errorf("no price available for %q: %w", asset, err) 498 | } 499 | return fmt.Sprintf(format, price), nil 500 | } 501 | 502 | func (s *Server) webSearchDDG(input, region, username string) (string, error) { 503 | param, err := tools.NewSearchParam(input, region) 504 | if err != nil { 505 | return "", err 506 | } 507 | result := tools.Search(param) 508 | if result.IsErr() { 509 | return "", result.Error() 510 | } 511 | res := *result.Unwrap() 512 | if len(res) == 0 { 513 | return "", fmt.Errorf("no results found") 514 | } 515 | Log.Info("Search results found=", len(res)) 516 | ctx := context.WithValue(context.Background(), "ollama", s.conf.OllamaEnabled) 517 | limit := 10 518 | 519 | wg := sync.WaitGroup{} 520 | counter := 0 521 | for i := range res { 522 | if counter >= limit { 523 | break 524 | } 525 | // if result link ends in .pdf, skip 526 | if strings.HasSuffix(res[i].Link, ".pdf") { 527 | continue 528 | } 529 | 530 | counter += 1 531 | wg.Add(1) 532 | go func(i int) { 533 | // Done is deferred so it still runs if the download panics; 534 | // otherwise wg.Wait below would block forever. 535 | defer wg.Done() 
536 | defer func() { 537 | if r := recover(); r != nil { 538 | Log.WithField("error", r).Error("panic: ", string(debug.Stack())) 539 | } 540 | }() 541 | err := vectordb.DownloadWebsiteToVectorDB(ctx, res[i].Link, username) 542 | if err != nil { 543 | Log.Warn("Error downloading website", "error=", err) 544 | } 545 | }(i) 546 | } 547 | wg.Wait() 548 | 549 | return s.vectorSearch(input, username) 550 | } 551 | 552 | func (s *Server) webSearchSearX(input, region, username string) (string, error) { 553 | input = strings.TrimSuffix(strings.TrimSuffix(strings.TrimPrefix(input, "\""), "\""), "?") 554 | //if region != "" && region != "wt-wt" { 555 | // input += ":" + region 556 | //} 557 | res, err := tools.SearchSearX(input) 558 | if err != nil { 559 | return "", err 560 | } 561 | 562 | Log.Info("Search results found=", len(res)) 563 | ctx := context.WithValue(context.Background(), "ollama", s.conf.OllamaEnabled) 564 | limit := 10 565 | 566 | wg := sync.WaitGroup{} 567 | counter := 0 568 | for i := range res { 569 | if counter >= limit { 570 | break 571 | } 572 | // if result link ends in .pdf, skip 573 | if strings.HasSuffix(res[i].URL, ".pdf") { 574 | continue 575 | } 576 | 577 | counter += 1 578 | wg.Add(1) 579 | go func(i int) { 580 | // Done is deferred so it still runs if the download panics; 581 | // otherwise wg.Wait below would block forever. 582 | defer wg.Done() 583 | defer func() { 584 | if r := recover(); r != nil { 585 | Log.WithField("error", r).Error("panic: ", string(debug.Stack())) 586 | } 587 | }() 588 | err := vectordb.DownloadWebsiteToVectorDB(ctx, res[i].URL, username) 589 | if err != nil { 590 | Log.Warn("Error downloading website", "error=", err) 591 | } 592 | }(i) 593 | } 594 | wg.Wait() 595 | 596 | return s.vectorSearch(input, username) 597 | } 598 | 599 | func (s *Server) vectorSearch(input string, username string) (string, error) { 600 | ctx := context.Background() 601 | ctx = context.WithValue(ctx, "ollama", s.conf.OllamaEnabled) 602 | docs, err := vectordb.SearchVectorDB(ctx, input, username) 603 | if err != nil { 604 | // keep the "no results" reply for the model, but don't hide the failure 605 | Log.Warn("vector search failed", "error=", err) 606 | } 
607 | type DocResult struct { 608 | Text string 609 | Source string 610 | } 611 | var results []DocResult 612 | 613 | for _, r := range docs { 614 | newResult := DocResult{Text: r.PageContent} 615 | source, ok := r.Metadata["url"].(string) 616 | if ok { 617 | newResult.Source = source 618 | } 619 | 620 | results = append(results, newResult) 621 | } 622 | 623 | if len(docs) == 0 { 624 | response := "no results found. Try other db search keywords or download more websites." 625 | Log.Warn("no results found", "input", input) 626 | results = append(results, DocResult{Text: response}) 627 | } else if len(results) == 0 { 628 | response := "No new results found, all returned results have been used already. Try other db search keywords or download more websites." 629 | results = append(results, DocResult{Text: response}) 630 | } 631 | 632 | resultJson, err := json.Marshal(results) 633 | if err != nil { 634 | return "", err 635 | } 636 | 637 | return string(resultJson), nil 638 | } 639 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tectiv3/chatgpt-bot 2 | 3 | go 1.23.3 4 | 5 | require ( 6 | github.com/PuerkitoBio/goquery v1.10.3 7 | github.com/amikos-tech/chroma-go v0.1.4 8 | github.com/eminarican/safetypes v0.0.8 9 | github.com/go-shiori/go-readability v0.0.0-20250217085726-9f5bf5ca7612 10 | github.com/google/uuid v1.6.0 11 | github.com/joho/godotenv v1.5.1 12 | github.com/meinside/openai-go v0.4.7 13 | github.com/pkoukk/tiktoken-go v0.1.7 14 | github.com/sirupsen/logrus v1.9.3 15 | github.com/tectiv3/awsnova-go v0.0.0-20250112173251-e2244ec0b117 16 | github.com/tectiv3/go-lame v0.0.0-20240321153525-da7c3c48f794 17 | golang.org/x/crypto v0.39.0 18 | golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 19 | golang.org/x/net v0.41.0 20 | gopkg.in/telebot.v3 v3.3.8 21 | gorm.io/driver/sqlite v1.6.0 22 | gorm.io/gorm v1.30.0 23 | ) 24 | 25 | require ( 26 | github.com/Masterminds/semver v1.5.0 // indirect 27 | 
github.com/andybalholm/cascadia v1.3.3 // indirect 28 | github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect 29 | github.com/dlclark/regexp2 v1.11.5 // indirect 30 | github.com/go-shiori/dom v0.0.0-20230515143342-73569d674e1c // indirect 31 | github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f // indirect 32 | github.com/imacks/aws-sigv4 v0.1.1 // indirect 33 | github.com/jinzhu/inflection v1.0.0 // indirect 34 | github.com/jinzhu/now v1.1.5 // indirect 35 | github.com/mattn/go-sqlite3 v1.14.28 // indirect 36 | github.com/oklog/ulid v1.3.1 // indirect 37 | golang.org/x/sys v0.33.0 // indirect 38 | golang.org/x/term v0.32.0 // indirect 39 | golang.org/x/text v0.26.0 // indirect 40 | ) 41 | 42 | replace github.com/meinside/openai-go => github.com/tectiv3/openai-go v0.6.2 43 | -------------------------------------------------------------------------------- /i18n/i18n.go: -------------------------------------------------------------------------------- 1 | // Code generated by go-localize; DO NOT EDIT. 
2 | // This file was generated by robots at 3 | // 2024-09-04 12:16:23.631291 +0100 WEST m=+0.002913251 4 | 5 | package i18n 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "strings" 11 | "text/template" 12 | ) 13 | 14 | var localizations = map[string]string{ 15 | "ru.Action: {{.tool}}\nAction input: %s": "Действие: {{.tool}}\nЗапрос: %s", 16 | "ru.Conversation age set to %d days": "Длительность хранения истории разговора установлена на %d дней", 17 | "ru.Enter role name": "Введите имя для этой роли", 18 | "ru.Enter system prompt": "Введите системный запрос который определит как будет вести себя ассистент", 19 | "ru.Model set to {{.model}}": "Языковая модель установлена на {{.model}}", 20 | "ru.New Conversation": "Новый Диалог", 21 | "ru.New Role": "Новая Роль", 22 | "ru.No response from API.": "Нет ответа от API.", 23 | "ru.Please provide a longer prompt": "Пожалуйста, введите более длинный запрос", 24 | "ru.Please provide a number": "Пожалуйста, введите число", 25 | "ru.Please provide a text file": "Пожалуйста, загрузите текстовый файл", 26 | "ru.Processing document. Please wait...": "Идет обработка документа. Пожалуйста, подождите...", 27 | "ru.Prompt set": "Новый системный запрос установлен", 28 | "ru.Role deleted": "Роль была успешно удалена", 29 | "ru.Role not found": "Роль не найдена", 30 | "ru.Select Role": "Выберите новую роль для ассистента", 31 | "ru.Select model": "Выберите языковую модель", 32 | "ru.Set temperature from less random (0.0) to more random (1.0). Current: %0.2f (default: 0.8)": "Установите креативность модели от менее случайной (0.0) до более случайной (1.0). 
Текущая: %0.2f (по умолчанию: 0.8)", 33 | "ru.Stream is {{.status}}": "Функция потоковой передачи сообщений: {{.status}}", 34 | "ru.Temperature set to {{.temp}}": "Креативность модели установлена на {{.temp}}", 35 | "ru.This bot will answer your messages with ChatGPT API": "Этот бот будет отвечать на ваши сообщения с помощью ChatGPT", 36 | "ru._Transcript:_\\n%s\\n\\n_Answer:_ \\n\\n\"": "_Транскрипт:_\n%s\n\n_Ответ:_ \n\n", 37 | "ru.default": "По умолчанию", 38 | "ru.disabled": "деактивировано", 39 | "ru.enabled": "активировано", 40 | "ru.get_crypto_rate": "Запрос курса криптовалюты", 41 | "ru.search_images": "Поиск изображений", 42 | "ru.set_reminder": "Установка напоминания", 43 | "ru.vector_search": "Поиск по векторной базе данных", 44 | "ru.web_search": "Поиск в интернете", 45 | } 46 | 47 | type Replacements map[string]interface{} 48 | 49 | type Localizer struct { 50 | Locale string 51 | FallbackLocale string 52 | Localizations map[string]string 53 | } 54 | 55 | func New(locale string, fallbackLocale string) *Localizer { 56 | t := &Localizer{Locale: locale, FallbackLocale: fallbackLocale} 57 | t.Localizations = localizations 58 | return t 59 | } 60 | 61 | func (t Localizer) SetLocales(locale, fallback string) Localizer { 62 | t.Locale = locale 63 | t.FallbackLocale = fallback 64 | return t 65 | } 66 | 67 | func (t Localizer) SetLocale(locale string) Localizer { 68 | t.Locale = locale 69 | return t 70 | } 71 | 72 | func (t Localizer) SetFallbackLocale(fallback string) Localizer { 73 | t.FallbackLocale = fallback 74 | return t 75 | } 76 | 77 | func (t Localizer) GetWithLocale(locale, key string, replacements ...*Replacements) string { 78 | str, ok := t.Localizations[t.getLocalizationKey(locale, key)] 79 | if !ok { 80 | str, ok = t.Localizations[t.getLocalizationKey(t.FallbackLocale, key)] 81 | if !ok { 82 | if strings.Index(key, "}}") == -1 { 83 | return key 84 | } 85 | return t.replace(key, replacements...) 
86 | } 87 | } 88 | 89 | // If the str doesn't have any substitutions, no need to 90 | // template.Execute. 91 | if strings.Index(str, "}}") == -1 { 92 | return str 93 | } 94 | 95 | return t.replace(str, replacements...) 96 | } 97 | 98 | func (t Localizer) Get(key string, replacements ...*Replacements) string { 99 | str := t.GetWithLocale(t.Locale, key, replacements...) 100 | return str 101 | } 102 | 103 | func (t Localizer) getLocalizationKey(locale string, key string) string { 104 | return fmt.Sprintf("%v.%v", locale, key) 105 | } 106 | 107 | func (t Localizer) replace(str string, replacements ...*Replacements) string { 108 | b := &bytes.Buffer{} 109 | tmpl, err := template.New("").Parse(str) 110 | if err != nil { 111 | return str 112 | } 113 | 114 | replacementsMerge := Replacements{} 115 | for _, replacement := range replacements { 116 | for k, v := range *replacement { 117 | replacementsMerge[k] = v 118 | } 119 | } 120 | 121 | err = template.Must(tmpl, err).Execute(b, replacementsMerge) 122 | if err != nil { 123 | return str 124 | } 125 | buff := b.String() 126 | return buff 127 | } 128 | -------------------------------------------------------------------------------- /image.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/meinside/openai-go" 8 | tele "gopkg.in/telebot.v3" 9 | ) 10 | 11 | func (s *Server) handleImage(c tele.Context) { 12 | photo := c.Message().Photo.File 13 | 14 | var fileName string 15 | // var err error 16 | // var reader io.ReadCloser 17 | 18 | if s.conf.TelegramServerURL != "" { 19 | f, err := c.Bot().FileByID(photo.FileID) 20 | if err != nil { 21 | Log.Warn("Error getting file ID", "error=", err) 22 | return 23 | } 24 | // start reader from f.FilePath 25 | //reader, err = os.Open(f.FilePath) 26 | //if err != nil { 27 | // Log.Warn("Error opening file", "error=", err) 28 | // return 29 | //} 30 | fileName = f.FilePath 31 | } else { 
32 | out, err := os.Create("uploads/" + photo.FileID + ".jpg") 33 | if err != nil { 34 | Log.Warn("Error creating file", "error=", err) 35 | return 36 | } 37 | _ = out.Close() // Download writes to this path by name; close our handle so it doesn't leak 38 | if err := c.Bot().Download(&photo, out.Name()); err != nil { 39 | Log.Warn("Error getting file content", "error=", err) 40 | return 41 | } 42 | fileName = out.Name() 43 | } 44 | //defer reader.Close() 45 | // 46 | //bytes, err := io.ReadAll(reader) 47 | //if err != nil { 48 | // Log.Warn("Error reading file content", "error=", err) 49 | // return 50 | //} 51 | // 52 | //var base64Encoding string 53 | // 54 | //// Determine the content type of the image file 55 | //mimeType := http.DetectContentType(bytes) 56 | // 57 | //// Prepend the appropriate URI scheme header depending 58 | //// on the MIME type 59 | //switch mimeType { 60 | //case "image/jpeg": 61 | // base64Encoding += "data:image/jpeg;base64," 62 | //case "image/png": 63 | // base64Encoding += "data:image/png;base64," 64 | //} 65 | // 66 | //// Append the base64 encoded output 67 | //encoded := base64Encoding + toBase64(bytes) 68 | 69 | chat := s.getChat(c.Chat(), c.Sender()) 70 | chat.addImageToDialog(c.Message().Caption, fileName) 71 | s.db.Save(&chat) 72 | 73 | s.complete(c, "", true) 74 | } 75 | 76 | func (s *Server) textToImage(c tele.Context, text string, hd bool) error { 77 | Log.WithField("user", c.Sender().Username).Info("generating image") 78 | options := openai.ImageOptions{}.SetResponseFormat(openai.IamgeResponseFormatURL). 79 | SetSize(openai.ImageSize1024x1024_DallE3). 80 | SetN(1). 
81 | SetModel("dall-e-3") 82 | if hd { 83 | options.SetQuality("hd") 84 | } 85 | 86 | created, err := s.openAI.CreateImage(text, options) 87 | if err != nil { 88 | return fmt.Errorf("failed to create image: %w", err) 89 | } 90 | 91 | if len(created.Data) <= 0 { 92 | return fmt.Errorf("no items returned") 93 | } 94 | 95 | Log.WithField("user", c.Sender().Username).WithField("results", len(created.Data)).Info("image generation complete") 96 | 97 | for _, item := range created.Data { 98 | if item.URL == nil { // nothing to send for this item; avoid a nil dereference 99 | continue 100 | } 101 | m := &tele.Photo{File: tele.FromURL(*item.URL)} 102 | _ = c.Send(m, "text", &tele.SendOptions{ReplyTo: c.Message()}) 103 | } 104 | 105 | return nil 106 | } 107 | -------------------------------------------------------------------------------- /localizations_src/ru.json: -------------------------------------------------------------------------------- 1 | { 2 | "New Conversation": "Новый Диалог", 3 | "New Role": "Новая Роль", 4 | "Select Role": "Выберите новую роль для ассистента", 5 | "Select model": "Выберите языковую модель", 6 | "Please provide a number": "Пожалуйста, введите число", 7 | "Please provide a longer prompt": "Пожалуйста, введите более длинный запрос", 8 | "enabled": "активировано", 9 | "disabled": "деактивировано", 10 | "Stream is {{.status}}": "Функция потоковой передачи сообщений: {{.status}}", 11 | "Model set to {{.model}}": "Языковая модель установлена на {{.model}}", 12 | "Temperature set to {{.temp}}": "Креативность модели установлена на {{.temp}}", 13 | "Prompt set": "Новый системный запрос установлен", 14 | "This bot will answer your messages with ChatGPT API": "Этот бот будет отвечать на ваши сообщения с помощью ChatGPT", 15 | "Set temperature from less random (0.0) to more random (1.0).\nCurrent: %0.2f (default: 0.8)": "Установите креативность модели от менее случайной (0.0) до более случайной (1.0).\nТекущая: %0.2f (по умолчанию: 0.8)", 16 | "Conversation age set to %d days": "Длительность хранения истории разговора установлена на %d дней", 17 | "Processing 
document. Please wait...": "Идет обработка документа. Пожалуйста, подождите...", 18 | "_Transcript:_\n%s\n\n_Answer:_ \n\n": "_Транскрипт:_\n%s\n\n_Ответ:_ \n\n", 19 | "Please provide a text file": "Пожалуйста, загрузите текстовый файл", 20 | "No response from API.": "Нет ответа от API.", 21 | "Action: {{.tool}}\nAction input: %s": "Действие: {{.tool}}\nЗапрос: %s", 22 | "get_crypto_rate": "Запрос курса криптовалюты", 23 | "set_reminder": "Установка напоминания", 24 | "vector_search": "Поиск по векторной базе данных", 25 | "search_images": "Поиск изображений", 26 | "web_search": "Поиск в интернете", 27 | "default": "По умолчанию", 28 | "Role deleted": "Роль была успешно удалена", 29 | "Enter role name": "Введите имя для этой роли", 30 | "Enter system prompt": "Введите системный запрос который определит как будет вести себя ассистент", 31 | "Role not found": "Роль не найдена" 32 | } -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // main.go 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | stdlog "log" 9 | "os" 10 | "path" 11 | "runtime" 12 | "strings" 13 | "time" 14 | 15 | "github.com/joho/godotenv" 16 | "github.com/meinside/openai-go" 17 | log "github.com/sirupsen/logrus" 18 | "github.com/tectiv3/awsnova-go" 19 | "github.com/tectiv3/chatgpt-bot/i18n" 20 | "golang.org/x/crypto/ssh/terminal" 21 | "gorm.io/driver/sqlite" 22 | "gorm.io/gorm" 23 | "gorm.io/gorm/logger" 24 | ) 25 | 26 | var ( 27 | Log *log.Entry 28 | l *i18n.Localizer 29 | // Version string will be set by linker 30 | Version = "dev" 31 | BuildTime = "unknown" 32 | ) 33 | 34 | func main() { 35 | logrus := log.New() 36 | // redirect Go standard log library calls to logrus writer 37 | stdlog.SetFlags(0) 38 | stdlog.SetFlags(stdlog.LstdFlags | stdlog.Lshortfile) 39 | logrus.Formatter = &log.TextFormatter{ 40 | FullTimestamp: true, 41 | DisableTimestamp: 
!terminal.IsTerminal(int(os.Stdout.Fd())), 42 | TimestampFormat: "Jan 2 15:04:05.000", 43 | CallerPrettyfier: func(f *runtime.Frame) (string, string) { 44 | return strings.Replace(f.Function, "(*Server).", "", -1), 45 | fmt.Sprintf("%s:%d", path.Base(f.File), f.Line) 46 | }, 47 | } 48 | logrus.SetFormatter(logrus.Formatter) 49 | logrus.SetReportCaller(true) 50 | stdlog.SetOutput(logrus.Writer()) 51 | logrus.SetOutput(os.Stdout) 52 | Log = logrus.WithFields(log.Fields{"ver": Version}) 53 | 54 | confFilepath := "config.json" 55 | if len(os.Args) == 2 { 56 | confFilepath = os.Args[1] 57 | } 58 | 59 | if conf, err := loadConfig(confFilepath); err == nil { 60 | apiKey := conf.OpenAIAPIKey 61 | orgID := conf.OpenAIOrganizationID 62 | level := logger.Error 63 | if conf.Verbose { 64 | // level = logger.Info 65 | logrus.SetLevel(log.DebugLevel) 66 | } 67 | newLogger := logger.New( 68 | stdlog.New(os.Stdout, "\r\n", stdlog.LstdFlags), // io writer 69 | logger.Config{ 70 | SlowThreshold: time.Second, // Slow SQL threshold 71 | LogLevel: level, // Log level 72 | IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger 73 | ParameterizedQueries: false, // Don't include params in the SQL log 74 | Colorful: true, // Enable color 75 | }, 76 | ) 77 | 78 | db, err := gorm.Open(sqlite.Open("bot.db"), &gorm.Config{Logger: newLogger}) 79 | if err != nil { 80 | panic("failed to connect database") 81 | } 82 | 83 | // Migrate the schema 84 | if err := db.AutoMigrate(&User{}); err != nil { 85 | panic("failed to migrate user") 86 | } 87 | if err := db.AutoMigrate(&Chat{}); err != nil { 88 | panic("failed to migrate chat") 89 | } 90 | if err := db.AutoMigrate(&ChatMessage{}); err != nil { 91 | panic("failed to migrate chat message") 92 | } 93 | if err := db.AutoMigrate(&Role{}); err != nil { 94 | panic("failed to migrate role") 95 | } 96 | 97 | Log.WithField("allowed_users", len(conf.AllowedTelegramUsers)).Info("Started") 98 | server := &Server{ 99 | conf: conf, 100 | 
db: db, 101 | openAI: openai.NewClient(apiKey, orgID), 102 | nova: awsnova.NewClient(conf.AWSRegion, conf.AWSModelID, awsnova.AWSCredentials{ 103 | AccessKeyID: conf.AWSAccessKeyID, 104 | SecretAccessKey: conf.AWSSecretAccessKey, 105 | }), 106 | } 107 | if conf.AnthropicEnabled { 108 | server.anthropic = openai.NewClient(conf.AnthropicAPIKey, "").SetBaseURL("https://api.anthropic.com") 109 | } 110 | l = i18n.New("ru", "en") 111 | 112 | server.run() 113 | } else { 114 | Log.Warn("failed to load config", "error=", err) 115 | } 116 | } 117 | 118 | // load config at given path 119 | func loadConfig(fpath string) (conf config, err error) { 120 | if err := godotenv.Load(); err != nil { 121 | log.WithField("error", err).Warn("Error loading .env file") 122 | } 123 | 124 | var bytes []byte 125 | if bytes, err = os.ReadFile(fpath); err == nil { 126 | if err = json.Unmarshal(bytes, &conf); err == nil { 127 | return conf, nil 128 | } 129 | } 130 | 131 | return config{}, err 132 | } 133 | -------------------------------------------------------------------------------- /markdown.go: -------------------------------------------------------------------------------- 1 | // original source github.com/zavitkov/tg-markdown 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "regexp" 7 | "strings" 8 | ) 9 | 10 | func escapeTelegramMarkdownV2(text string) string { 11 | telegramSpecialChars := "_*[]()~`>#+-=|{}.!" 
12 | for _, char := range telegramSpecialChars { 13 | text = strings.ReplaceAll(text, string(char), "\\"+string(char)) 14 | } 15 | return text 16 | } 17 | 18 | func escapeTelegramMarkdownV2CodeBlocks(text string) string { 19 | backslashRegex := regexp.MustCompile(`\\([^ ]|\\|$)`) 20 | 21 | text = backslashRegex.ReplaceAllStringFunc(text, func(match string) string { 22 | if len(match) > 1 { 23 | return fmt.Sprintf("\\\\%s", match[1:]) 24 | } 25 | 26 | return "\\\\" 27 | }) 28 | 29 | text = escapeTelegramMarkdownV2(text) 30 | 31 | return text 32 | } 33 | 34 | func wrapLinksInMarkdown(text string) string { 35 | re := regexp.MustCompile(`\bhttps?://[a-zA-Z0-9-.]+(.[a-zA-Z]{2,})(:[0-9]{1,5})?(/[a-zA-Z0-9-_.~%/?#=&+]*)?\b`) 36 | 37 | wrappedText := re.ReplaceAllStringFunc(text, func(link string) string { 38 | if strings.Contains(link, "](http") || strings.Contains(link, "[http") { 39 | return link 40 | } 41 | return fmt.Sprintf("[%s](%s)", link, link) 42 | }) 43 | 44 | return wrappedText 45 | } 46 | 47 | func removePlaceholders(text string) string { 48 | re := regexp.MustCompile(`ELEMENTPLACEHOLDER\d+|CODEBLOCKPLACEHOLDER\d+|INLINECODEPLACEHOLDER\d+|LINKPLACEHOLDER\d+`) 49 | // strip placeholder-looking tokens from the input so they cannot collide with generated keys 50 | result := re.ReplaceAllString(text, "") 51 | 52 | return result 53 | } 54 | 55 | func addPlaceholders(text string, elements map[string]string) (string, map[string]string) { 56 | codeBlockRegex := regexp.MustCompile("```(?s:.*?)```") 57 | inlineCodeRegex := regexp.MustCompile("`[^`]+`") 58 | linkRegex := regexp.MustCompile(`\[[^\]]+\]\([^\)]+\)`) 59 | 60 | elementsRegex := regexp.MustCompile(fmt.Sprintf("%s|%s|%s", codeBlockRegex.String(), inlineCodeRegex.String(), linkRegex.String())) 61 | 62 | matches := elementsRegex.FindAllString(text, -1) 63 | for i, match := range matches { 64 | placeholder := fmt.Sprintf("ELEMENTPLACEHOLDER%dX", len(elements)+i) 65 | // The trailing "X" ends the index, so key ...PLACEHOLDER1X can never match inside ...PLACEHOLDER10X. 66 | // Pick the marker that matches the element type 67 | switch { 68 | case codeBlockRegex.MatchString(match): 69 | placeholder 
= fmt.Sprintf("CODEBLOCKPLACEHOLDER%dX", len(elements)+i) 70 | case inlineCodeRegex.MatchString(match): 71 | placeholder = fmt.Sprintf("INLINECODEPLACEHOLDER%dX", len(elements)+i) 72 | case linkRegex.MatchString(match): 73 | placeholder = fmt.Sprintf("LINKPLACEHOLDER%dX", len(elements)+i) 74 | } 75 | 76 | elements[placeholder] = match 77 | text = strings.Replace(text, match, placeholder, 1) 78 | } 79 | return text, elements 80 | } 81 | 82 | func processPlaceholders(md string, elements map[string]string) string { 83 | for placeholder, element := range elements { 84 | if strings.HasPrefix(placeholder, "LINKPLACEHOLDER") { 85 | parts := strings.SplitN(element, "](", 2) 86 | linkText := parts[0][1:] 87 | linkURL := parts[1][:len(parts[1])-1] 88 | 89 | linkText = escapeTelegramMarkdownV2(linkText) 90 | element = "[" + linkText + "](" + linkURL + ")" 91 | } 92 | 93 | if strings.HasPrefix(placeholder, "CODEBLOCKPLACEHOLDER") { 94 | re := regexp.MustCompile("(?s)```([a-zA-Z]*)\\n(.*?)\\n```") 95 | 96 | element = re.ReplaceAllStringFunc(element, func(block string) string { 97 | matches := re.FindStringSubmatch(block) 98 | if len(matches) > 2 { 99 | language := matches[1] 100 | return fmt.Sprintf("```%s\n%s\n```", language, escapeTelegramMarkdownV2CodeBlocks(matches[2])) 101 | } 102 | return block 103 | }) 104 | } 105 | 106 | if strings.HasPrefix(placeholder, "INLINECODEPLACEHOLDER") { 107 | re := regexp.MustCompile("`([^`]+)`") 108 | 109 | element = re.ReplaceAllStringFunc(element, func(block string) string { 110 | matches := re.FindStringSubmatch(block) 111 | if len(matches) > 1 { 112 | return fmt.Sprintf("`%s`", escapeTelegramMarkdownV2CodeBlocks(matches[1])) 113 | } 114 | return block 115 | }) 116 | } 117 | 118 | md = strings.Replace(md, placeholder, element, 1) 119 | } 120 | 121 | return md 122 | } 123 | 124 | func processStyles(md string) string { 125 | replacements := []struct { 126 | regex *regexp.Regexp 127 | replacement string 128 | }{ 129 | // bold text 130 | 
{regexp.MustCompile(`\\\*\\\*(.*?)\\\*\\\*`), `*$1*`}, 131 | 132 | // italic text 133 | {regexp.MustCompile(`\\\*(.*?)\\\*`), `_${1}_`}, 134 | {regexp.MustCompile(`\\_(.*?)\\_`), `_${1}_`}, 135 | 136 | // strikethrough text 137 | {regexp.MustCompile(`\\~\\~(.*?)\\~\\~`), `~$1~`}, 138 | {regexp.MustCompile(`\\~(.*?)\\~`), `~$1~`}, 139 | 140 | // headings: only at line start, consuming every escaped '#' of the level marker 141 | {regexp.MustCompile(`(?m)^(?:\\\#)+[ \t]*(.*)`), `*$1*`}, 142 | } 143 | 144 | for _, r := range replacements { 145 | md = r.regex.ReplaceAllString(md, r.replacement) 146 | } 147 | 148 | return md 149 | } 150 | 151 | func ConvertMarkdownToTelegramMarkdownV2(md string) string { 152 | elements := make(map[string]string) 153 | 154 | md = removePlaceholders(md) 155 | md, elements = addPlaceholders(md, elements) 156 | 157 | md = wrapLinksInMarkdown(md) 158 | md, elements = addPlaceholders(md, elements) 159 | 160 | md = escapeTelegramMarkdownV2(md) 161 | 162 | md = processStyles(md) 163 | 164 | md = processPlaceholders(md, elements) 165 | 166 | return md 167 | } 168 | -------------------------------------------------------------------------------- /models.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "sync" 11 | "time" 12 | 13 | "github.com/meinside/openai-go" 14 | "github.com/tectiv3/awsnova-go" 15 | tele "gopkg.in/telebot.v3" 16 | "gorm.io/gorm" 17 | ) 18 | 19 | // config struct for loading a configuration file 20 | type config struct { 21 | // telegram bot api 22 | TelegramBotToken string `json:"telegram_bot_token"` 23 | TelegramServerURL string `json:"telegram_server_url"` 24 | 25 | Models []AiModel `json:"models"` 26 | 27 | // openai api 28 | OpenAIAPIKey string `json:"openai_api_key"` 29 | OpenAIOrganizationID string `json:"openai_org_id"` 30 | OpenAILatestModel string `json:"openai_latest_model"` 31 | 32 | OllamaURL string `json:"ollama_url"` 33 | 
OllamaEnabled bool `json:"ollama_enabled"` 34 | 35 | GroqAPIKey string `json:"groq_api_key"` 36 | 37 | AnthropicAPIKey string `json:"anthropic_api_key"` 38 | AnthropicEnabled bool `json:"anthropic_enabled"` 39 | 40 | AWSAccessKeyID string `json:"aws_access_key_id"` 41 | AWSSecretAccessKey string `json:"aws_secret_access_key"` 42 | AWSModelID string `json:"aws_model_id"` 43 | AWSRegion string `json:"aws_region"` 44 | AWSEnabled bool `json:"aws_enabled"` 45 | 46 | // other configurations 47 | AllowedTelegramUsers []string `json:"allowed_telegram_users"` 48 | Verbose bool `json:"verbose,omitempty"` 49 | PiperDir string `json:"piper_dir"` 50 | } 51 | 52 | type AiModel struct { 53 | ModelID string `json:"model_id"` 54 | Name string `json:"name"` 55 | Provider string `json:"provider"` // openai, ollama, groq, nova 56 | } 57 | 58 | type Server struct { 59 | sync.RWMutex 60 | conf config 61 | users []string 62 | openAI *openai.Client 63 | anthropic *openai.Client 64 | nova *awsnova.Client 65 | bot *tele.Bot 66 | db *gorm.DB 67 | } 68 | 69 | type User struct { 70 | gorm.Model 71 | TelegramID *int64 `gorm:"nullable:true"` 72 | Username string 73 | ApiKey *string `gorm:"nullable:true"` 74 | OrgID *string `gorm:"nullable:true"` 75 | Threads []Chat 76 | Roles []Role 77 | State *State `json:"state,omitempty" gorm:"type:text"` 78 | } 79 | 80 | type Role struct { 81 | gorm.Model 82 | UserID uint `json:"user_id"` 83 | Name string 84 | Prompt string 85 | } 86 | 87 | type Chat struct { 88 | gorm.Model 89 | mutex sync.Mutex `gorm:"-"` 90 | ChatID int64 `sql:"chat_id" json:"chat_id"` 91 | UserID uint `json:"user_id" gorm:"nullable:false"` 92 | RoleID *uint `json:"role_id" gorm:"nullable:true"` 93 | MessageID *string `json:"last_message_id" gorm:"nullable:true"` 94 | Lang string 95 | History []ChatMessage 96 | User User `gorm:"foreignKey:UserID;references:ID;fetch:join"` 97 | Role Role `gorm:"foreignKey:RoleID;references:ID;fetch:join"` 98 | Temperature float64 99 | ModelName string 100 
| MasterPrompt string 101 | Stream bool 102 | QA bool 103 | Voice bool 104 | ConversationAge int64 105 | TotalTokens int `json:"total_tokens"` 106 | } 107 | 108 | type ChatMessage struct { 109 | ID uint `gorm:"primarykey"` 110 | CreatedAt time.Time 111 | UpdatedAt time.Time 112 | ChatID int64 `sql:"chat_id" json:"chat_id"` 113 | 114 | Role openai.ChatMessageRole `json:"role"` 115 | ToolCallID *string `json:"tool_call_id,omitempty"` 116 | Content *string `json:"content,omitempty"` 117 | ImagePath *string `json:"image_path,omitempty"` 118 | 119 | // for function call 120 | ToolCalls ToolCalls `json:"tool_calls,omitempty" gorm:"type:text"` // when role == 'assistant' 121 | } 122 | 123 | // ToolCalls is a custom type that will allow us to implement 124 | // the driver.Valuer and sql.Scanner interfaces on a slice of ToolCall. 125 | type ToolCalls []ToolCall 126 | 127 | type ToolCall struct { 128 | ID string `json:"id"` 129 | Type string `json:"type"` // == 'function' 130 | Function openai.ToolCallFunction `json:"function"` 131 | } 132 | 133 | // Value implements the driver.Valuer interface, allowing 134 | // for converting the ToolCalls to a JSON string for database storage. 135 | func (tc ToolCalls) Value() (driver.Value, error) { 136 | if tc == nil { 137 | return nil, nil 138 | } 139 | 140 | return json.Marshal(tc) 141 | } 142 | 143 | // Scan implements the sql.Scanner interface, allowing for 144 | // converting a JSON string from the database back into the ToolCalls slice. 
145 | func (tc *ToolCalls) Scan(value interface{}) error { 146 | if value == nil { 147 | *tc = nil 148 | return nil 149 | } 150 | 151 | b, ok := value.([]byte) 152 | if !ok { 153 | return fmt.Errorf("type assertion to []byte failed") 154 | } 155 | 156 | return json.Unmarshal(b, &tc) 157 | } 158 | 159 | type GPTResponse interface { 160 | Type() string // direct, array, image, audio, async 161 | Value() interface{} // string, []string 162 | CanReply() bool // if true replyMenu need to be shown 163 | } 164 | 165 | // WAV writer struct 166 | type wavWriter struct { 167 | w io.Writer 168 | } 169 | 170 | // WAV file header struct 171 | type wavHeader struct { 172 | RIFFID [4]byte // RIFF header 173 | FileSize uint32 // file size - 8 174 | WAVEID [4]byte // WAVE header 175 | FMTID [4]byte // fmt header 176 | Subchunk1Size uint32 // size of the fmt chunk 177 | AudioFormat uint16 // audio format code 178 | NumChannels uint16 // number of channels 179 | SampleRate uint32 // sample rate 180 | ByteRate uint32 // bytes per second 181 | BlockAlign uint16 // block align 182 | BitsPerSample uint16 // bits per sample 183 | DataID [4]byte // data header 184 | Subchunk2Size uint32 // size of the data chunk 185 | } 186 | 187 | // RestrictConfig defines config for Restrict middleware. 188 | type RestrictConfig struct { 189 | // Chats is a list of chats that are going to be affected 190 | // by either In or Out function. 191 | Usernames []string 192 | 193 | // In defines a function that will be called if the chat 194 | // of an update will be found in the Chats list. 195 | In tele.HandlerFunc 196 | 197 | // Out defines a function that will be called if the chat 198 | // of an update will NOT be found in the Chats list. 
199 | Out tele.HandlerFunc 200 | } 201 | 202 | func in_array(needle string, haystack []string) bool { 203 | for _, v := range haystack { 204 | if needle == v { 205 | return true 206 | } 207 | } 208 | 209 | return false 210 | } 211 | 212 | type CoinCap struct { 213 | Data struct { 214 | Symbol string `json:"symbol"` 215 | PriceUsd string `json:"priceUsd"` 216 | } `json:"data"` 217 | Timestamp int64 `json:"timestamp"` 218 | } 219 | 220 | func toBase64(b []byte) string { 221 | return base64.StdEncoding.EncodeToString(b) 222 | } 223 | 224 | // Context handling utilities 225 | 226 | // WithTimeout creates a context with timeout for operations 227 | func WithTimeout(duration time.Duration) (context.Context, context.CancelFunc) { 228 | return context.WithTimeout(context.Background(), duration) 229 | } 230 | 231 | // WithCancel creates a cancellable context 232 | func WithCancel() (context.Context, context.CancelFunc) { 233 | return context.WithCancel(context.Background()) 234 | } 235 | 236 | // DefaultTimeout for operations 237 | const ( 238 | DefaultTimeout = 30 * time.Second 239 | LongTimeout = 5 * time.Minute 240 | ) 241 | -------------------------------------------------------------------------------- /opus/AUTHORS.txt: -------------------------------------------------------------------------------- 1 | All code and content in this project is Copyright © 2015-2022 Go Opus Authors 2 | 3 | Go Opus Authors and copyright holders of this package are listed below, in no 4 | particular order. By adding yourself to this list you agree to license your 5 | contributions under the relevant license (see the LICENSE file). 
6 | 7 | Hraban Luyat 8 | Dejian Xu 9 | Tobias Wellnitz 10 | Elinor Natanzon 11 | Victor Gaydov 12 | Randy Reddig 13 | -------------------------------------------------------------------------------- /opus/callbacks.c: -------------------------------------------------------------------------------- 1 | // +build !nolibopusfile 2 | 3 | // Copyright © Go Opus Authors (see AUTHORS file) 4 | // 5 | // License for use of this code is detailed in the LICENSE file 6 | 7 | // Allocate callback struct in C to ensure it's not managed by the Go GC. This 8 | // plays nice with the CGo rules and avoids any confusion. 9 | 10 | #include 11 | #include 12 | 13 | // Defined in Go. Uses the same signature as Go, no need for proxy function. 14 | int go_readcallback(void *p, unsigned char *buf, int nbytes); 15 | 16 | static struct OpusFileCallbacks callbacks = { 17 | .read = go_readcallback, 18 | }; 19 | 20 | // Proxy function for op_open_callbacks, because it takes a void * context but 21 | // we want to pass it non-pointer data, namely an arbitrary uintptr_t 22 | // value. This is legal C, but go test -race (-d=checkptr) complains anyway. So 23 | // we have this wrapper function to shush it. 
24 | // https://groups.google.com/g/golang-nuts/c/995uZyRPKlU 25 | OggOpusFile * 26 | my_open_callbacks(uintptr_t p, int *error) 27 | { 28 | return op_open_callbacks((void *)p, &callbacks, NULL, 0, error); 29 | } 30 | -------------------------------------------------------------------------------- /opus/encoder.go: -------------------------------------------------------------------------------- 1 | // Copyright © Go Opus Authors (see AUTHORS file) 2 | // 3 | // License for use of this code is detailed in the LICENSE file 4 | 5 | package opus 6 | 7 | import ( 8 | "fmt" 9 | "unsafe" 10 | ) 11 | 12 | /* 13 | #cgo pkg-config: opus 14 | #include 15 | 16 | int 17 | bridge_encoder_set_dtx(OpusEncoder *st, opus_int32 use_dtx) 18 | { 19 | return opus_encoder_ctl(st, OPUS_SET_DTX(use_dtx)); 20 | } 21 | 22 | int 23 | bridge_encoder_get_dtx(OpusEncoder *st, opus_int32 *dtx) 24 | { 25 | return opus_encoder_ctl(st, OPUS_GET_DTX(dtx)); 26 | } 27 | 28 | int 29 | bridge_encoder_get_sample_rate(OpusEncoder *st, opus_int32 *sample_rate) 30 | { 31 | return opus_encoder_ctl(st, OPUS_GET_SAMPLE_RATE(sample_rate)); 32 | } 33 | 34 | 35 | int 36 | bridge_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate) 37 | { 38 | return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate)); 39 | } 40 | 41 | int 42 | bridge_encoder_get_bitrate(OpusEncoder *st, opus_int32 *bitrate) 43 | { 44 | return opus_encoder_ctl(st, OPUS_GET_BITRATE(bitrate)); 45 | } 46 | 47 | int 48 | bridge_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity) 49 | { 50 | return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity)); 51 | } 52 | 53 | int 54 | bridge_encoder_get_complexity(OpusEncoder *st, opus_int32 *complexity) 55 | { 56 | return opus_encoder_ctl(st, OPUS_GET_COMPLEXITY(complexity)); 57 | } 58 | 59 | int 60 | bridge_encoder_set_max_bandwidth(OpusEncoder *st, opus_int32 max_bw) 61 | { 62 | return opus_encoder_ctl(st, OPUS_SET_MAX_BANDWIDTH(max_bw)); 63 | } 64 | 65 | int 66 | 
bridge_encoder_get_max_bandwidth(OpusEncoder *st, opus_int32 *max_bw) 67 | { 68 | return opus_encoder_ctl(st, OPUS_GET_MAX_BANDWIDTH(max_bw)); 69 | } 70 | 71 | int 72 | bridge_encoder_set_inband_fec(OpusEncoder *st, opus_int32 fec) 73 | { 74 | return opus_encoder_ctl(st, OPUS_SET_INBAND_FEC(fec)); 75 | } 76 | 77 | int 78 | bridge_encoder_get_inband_fec(OpusEncoder *st, opus_int32 *fec) 79 | { 80 | return opus_encoder_ctl(st, OPUS_GET_INBAND_FEC(fec)); 81 | } 82 | 83 | int 84 | bridge_encoder_set_packet_loss_perc(OpusEncoder *st, opus_int32 loss_perc) 85 | { 86 | return opus_encoder_ctl(st, OPUS_SET_PACKET_LOSS_PERC(loss_perc)); 87 | } 88 | 89 | int 90 | bridge_encoder_get_packet_loss_perc(OpusEncoder *st, opus_int32 *loss_perc) 91 | { 92 | return opus_encoder_ctl(st, OPUS_GET_PACKET_LOSS_PERC(loss_perc)); 93 | } 94 | 95 | */ 96 | import "C" 97 | 98 | type Bandwidth int 99 | 100 | type Application int 101 | 102 | const ( 103 | // Optimize encoding for VoIP 104 | AppVoIP = Application(C.OPUS_APPLICATION_VOIP) 105 | // 4 kHz passband 106 | Narrowband = Bandwidth(C.OPUS_BANDWIDTH_NARROWBAND) 107 | // 6 kHz passband 108 | Mediumband = Bandwidth(C.OPUS_BANDWIDTH_MEDIUMBAND) 109 | // 8 kHz passband 110 | Wideband = Bandwidth(C.OPUS_BANDWIDTH_WIDEBAND) 111 | // 12 kHz passband 112 | SuperWideband = Bandwidth(C.OPUS_BANDWIDTH_SUPERWIDEBAND) 113 | // 20 kHz passband 114 | Fullband = Bandwidth(C.OPUS_BANDWIDTH_FULLBAND) 115 | ) 116 | 117 | var errEncUninitialized = fmt.Errorf("opus encoder uninitialized") 118 | 119 | type Error int 120 | 121 | var _ error = Error(0) 122 | 123 | // Libopus errors 124 | const ( 125 | ErrOK = Error(C.OPUS_OK) 126 | ErrBadArg = Error(C.OPUS_BAD_ARG) 127 | ErrBufferTooSmall = Error(C.OPUS_BUFFER_TOO_SMALL) 128 | ErrInternalError = Error(C.OPUS_INTERNAL_ERROR) 129 | ErrInvalidPacket = Error(C.OPUS_INVALID_PACKET) 130 | ErrUnimplemented = Error(C.OPUS_UNIMPLEMENTED) 131 | ErrInvalidState = Error(C.OPUS_INVALID_STATE) 132 | ErrAllocFail = 
Error(C.OPUS_ALLOC_FAIL) 133 | ) 134 | 135 | // Error string (in human readable format) for libopus errors. 136 | func (e Error) Error() string { 137 | return fmt.Sprintf("opus: %s", C.GoString(C.opus_strerror(C.int(e)))) 138 | } 139 | 140 | // Encoder contains the state of an Opus encoder for libopus. 141 | type Encoder struct { 142 | p *C.struct_OpusEncoder 143 | channels int 144 | // Memory for the encoder struct allocated on the Go heap to allow Go GC to 145 | // manage it (and obviate need to free()) 146 | mem []byte 147 | } 148 | 149 | // NewEncoder allocates a new Opus encoder and initializes it with the 150 | // appropriate parameters. All related memory is managed by the Go GC. 151 | func NewEncoder(sample_rate int, channels int, application Application) (*Encoder, error) { 152 | var enc Encoder 153 | err := enc.Init(sample_rate, channels, application) 154 | if err != nil { 155 | return nil, err 156 | } 157 | return &enc, nil 158 | } 159 | 160 | // Init initializes a pre-allocated opus encoder. Unless the encoder has been 161 | // created using NewEncoder, this method must be called exactly once in the 162 | // life-time of this object, before calling any other methods. 163 | func (enc *Encoder) Init(sample_rate int, channels int, application Application) error { 164 | if enc.p != nil { 165 | return fmt.Errorf("opus encoder already initialized") 166 | } 167 | if channels != 1 && channels != 2 { 168 | return fmt.Errorf("Number of channels must be 1 or 2: %d", channels) 169 | } 170 | size := C.opus_encoder_get_size(C.int(channels)) 171 | enc.channels = channels 172 | enc.mem = make([]byte, size) 173 | enc.p = (*C.OpusEncoder)(unsafe.Pointer(&enc.mem[0])) 174 | errno := int(C.opus_encoder_init( 175 | enc.p, 176 | C.opus_int32(sample_rate), 177 | C.int(channels), 178 | C.int(application))) 179 | if errno != 0 { 180 | return Error(int(errno)) 181 | } 182 | return nil 183 | } 184 | 185 | // Encode raw PCM data and store the result in the supplied buffer. 
On success, 186 | // returns the number of bytes used up by the encoded data. 187 | func (enc *Encoder) Encode(pcm []int16, data []byte) (int, error) { 188 | if enc.p == nil { 189 | return 0, errEncUninitialized 190 | } 191 | if len(pcm) == 0 { 192 | return 0, fmt.Errorf("opus: no data supplied") 193 | } 194 | if len(data) == 0 { 195 | return 0, fmt.Errorf("opus: no target buffer") 196 | } 197 | // libopus talks about samples as 1 sample containing multiple channels. So 198 | // e.g. 20 samples of 2-channel data is actually 40 raw data points. 199 | if len(pcm)%enc.channels != 0 { 200 | return 0, fmt.Errorf("opus: input buffer length must be multiple of channels") 201 | } 202 | samples := len(pcm) / enc.channels 203 | n := int(C.opus_encode( 204 | enc.p, 205 | (*C.opus_int16)(&pcm[0]), 206 | C.int(samples), 207 | (*C.uchar)(&data[0]), 208 | C.opus_int32(cap(data)))) 209 | if n < 0 { 210 | return 0, Error(n) 211 | } 212 | return n, nil 213 | } 214 | 215 | // Encode raw PCM data and store the result in the supplied buffer. On success, 216 | // returns the number of bytes used up by the encoded data. 217 | func (enc *Encoder) EncodeFloat32(pcm []float32, data []byte) (int, error) { 218 | if enc.p == nil { 219 | return 0, errEncUninitialized 220 | } 221 | if len(pcm) == 0 { 222 | return 0, fmt.Errorf("opus: no data supplied") 223 | } 224 | if len(data) == 0 { 225 | return 0, fmt.Errorf("opus: no target buffer") 226 | } 227 | if len(pcm)%enc.channels != 0 { 228 | return 0, fmt.Errorf("opus: input buffer length must be multiple of channels") 229 | } 230 | samples := len(pcm) / enc.channels 231 | n := int(C.opus_encode_float( 232 | enc.p, 233 | (*C.float)(&pcm[0]), 234 | C.int(samples), 235 | (*C.uchar)(&data[0]), 236 | C.opus_int32(cap(data)))) 237 | if n < 0 { 238 | return 0, Error(n) 239 | } 240 | return n, nil 241 | } 242 | 243 | // SetDTX configures the encoder's use of discontinuous transmission (DTX). 
244 | func (enc *Encoder) SetDTX(dtx bool) error { 245 | i := 0 246 | if dtx { 247 | i = 1 248 | } 249 | res := C.bridge_encoder_set_dtx(enc.p, C.opus_int32(i)) 250 | if res != C.OPUS_OK { 251 | return Error(res) 252 | } 253 | return nil 254 | } 255 | 256 | // DTX reports whether this encoder is configured to use discontinuous 257 | // transmission (DTX). 258 | func (enc *Encoder) DTX() (bool, error) { 259 | var dtx C.opus_int32 260 | res := C.bridge_encoder_get_dtx(enc.p, &dtx) 261 | if res != C.OPUS_OK { 262 | return false, Error(res) 263 | } 264 | return dtx != 0, nil 265 | } 266 | 267 | // SampleRate returns the encoder sample rate in Hz. 268 | func (enc *Encoder) SampleRate() (int, error) { 269 | var sr C.opus_int32 270 | res := C.bridge_encoder_get_sample_rate(enc.p, &sr) 271 | if res != C.OPUS_OK { 272 | return 0, Error(res) 273 | } 274 | return int(sr), nil 275 | } 276 | 277 | // SetBitrate sets the bitrate of the Encoder 278 | func (enc *Encoder) SetBitrate(bitrate int) error { 279 | res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(bitrate)) 280 | if res != C.OPUS_OK { 281 | return Error(res) 282 | } 283 | return nil 284 | } 285 | 286 | // SetBitrateToAuto will allow the encoder to automatically set the bitrate 287 | func (enc *Encoder) SetBitrateToAuto() error { 288 | res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(C.OPUS_AUTO)) 289 | if res != C.OPUS_OK { 290 | return Error(res) 291 | } 292 | return nil 293 | } 294 | 295 | // SetBitrateToMax causes the encoder to use as much rate as it can. This can be 296 | // useful for controlling the rate by adjusting the output buffer size. 
297 | func (enc *Encoder) SetBitrateToMax() error { 298 | res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(C.OPUS_BITRATE_MAX)) 299 | if res != C.OPUS_OK { 300 | return Error(res) 301 | } 302 | return nil 303 | } 304 | 305 | // Bitrate returns the bitrate of the Encoder 306 | func (enc *Encoder) Bitrate() (int, error) { 307 | var bitrate C.opus_int32 308 | res := C.bridge_encoder_get_bitrate(enc.p, &bitrate) 309 | if res != C.OPUS_OK { 310 | return 0, Error(res) 311 | } 312 | return int(bitrate), nil 313 | } 314 | 315 | // SetComplexity sets the encoder's computational complexity 316 | func (enc *Encoder) SetComplexity(complexity int) error { 317 | res := C.bridge_encoder_set_complexity(enc.p, C.opus_int32(complexity)) 318 | if res != C.OPUS_OK { 319 | return Error(res) 320 | } 321 | return nil 322 | } 323 | 324 | // Complexity returns the computational complexity used by the encoder 325 | func (enc *Encoder) Complexity() (int, error) { 326 | var complexity C.opus_int32 327 | res := C.bridge_encoder_get_complexity(enc.p, &complexity) 328 | if res != C.OPUS_OK { 329 | return 0, Error(res) 330 | } 331 | return int(complexity), nil 332 | } 333 | 334 | // SetMaxBandwidth configures the maximum bandpass that the encoder will select 335 | // automatically 336 | func (enc *Encoder) SetMaxBandwidth(maxBw Bandwidth) error { 337 | res := C.bridge_encoder_set_max_bandwidth(enc.p, C.opus_int32(maxBw)) 338 | if res != C.OPUS_OK { 339 | return Error(res) 340 | } 341 | return nil 342 | } 343 | 344 | // MaxBandwidth gets the encoder's configured maximum allowed bandpass. 
345 | func (enc *Encoder) MaxBandwidth() (Bandwidth, error) { 346 | var maxBw C.opus_int32 347 | res := C.bridge_encoder_get_max_bandwidth(enc.p, &maxBw) 348 | if res != C.OPUS_OK { 349 | return 0, Error(res) 350 | } 351 | return Bandwidth(maxBw), nil 352 | } 353 | 354 | // SetInBandFEC configures the encoder's use of inband forward error 355 | // correction (FEC) 356 | func (enc *Encoder) SetInBandFEC(fec bool) error { 357 | i := 0 358 | if fec { 359 | i = 1 360 | } 361 | res := C.bridge_encoder_set_inband_fec(enc.p, C.opus_int32(i)) 362 | if res != C.OPUS_OK { 363 | return Error(res) 364 | } 365 | return nil 366 | } 367 | 368 | // InBandFEC gets the encoder's configured inband forward error correction (FEC) 369 | func (enc *Encoder) InBandFEC() (bool, error) { 370 | var fec C.opus_int32 371 | res := C.bridge_encoder_get_inband_fec(enc.p, &fec) 372 | if res != C.OPUS_OK { 373 | return false, Error(res) 374 | } 375 | return fec != 0, nil 376 | } 377 | 378 | // SetPacketLossPerc configures the encoder's expected packet loss percentage. 379 | func (enc *Encoder) SetPacketLossPerc(lossPerc int) error { 380 | res := C.bridge_encoder_set_packet_loss_perc(enc.p, C.opus_int32(lossPerc)) 381 | if res != C.OPUS_OK { 382 | return Error(res) 383 | } 384 | return nil 385 | } 386 | 387 | // PacketLossPerc gets the encoder's configured packet loss percentage. 
388 | func (enc *Encoder) PacketLossPerc() (int, error) { 389 | var lossPerc C.opus_int32 390 | res := C.bridge_encoder_get_packet_loss_perc(enc.p, &lossPerc) 391 | if res != C.OPUS_OK { 392 | return 0, Error(res) 393 | } 394 | return int(lossPerc), nil 395 | } 396 | -------------------------------------------------------------------------------- /opus/stream.go: -------------------------------------------------------------------------------- 1 | // Copyright © Go Opus Authors (see AUTHORS file) 2 | // 3 | // License for use of this code is detailed in the LICENSE file 4 | 5 | package opus 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "unsafe" 11 | ) 12 | 13 | /* 14 | #cgo pkg-config: opusfile 15 | #include 16 | #include 17 | #include 18 | 19 | OggOpusFile *my_open_callbacks(uintptr_t p, int *error); 20 | 21 | */ 22 | import "C" 23 | 24 | // Stream wraps a io.Reader in a decoding layer. It provides an API similar to 25 | // io.Reader, but it provides raw PCM data instead of the encoded Opus data. 26 | // 27 | // This is not the same as directly decoding the bytes on the io.Reader; opus 28 | // streams are Ogg Opus audio streams, which package raw Opus data. 29 | // 30 | // This wraps libopusfile. 
For more information, see the api docs on xiph.org: 31 | // 32 | // https://www.opus-codec.org/docs/opusfile_api-0.7/index.html 33 | type Stream struct { 34 | id uintptr 35 | oggfile *C.OggOpusFile 36 | read io.Reader 37 | // Preallocated buffer to pass to the reader 38 | buf []byte 39 | } 40 | 41 | const maxEncodedFrameSize = 10000 42 | 43 | var streams = newStreamsMap() 44 | 45 | //export go_readcallback 46 | func go_readcallback(p unsafe.Pointer, cbuf *C.uchar, cmaxbytes C.int) C.int { 47 | streamId := uintptr(p) 48 | stream := streams.Get(streamId) 49 | if stream == nil { 50 | // This is bad 51 | return -1 52 | } 53 | 54 | maxbytes := int(cmaxbytes) 55 | if maxbytes > cap(stream.buf) { 56 | maxbytes = cap(stream.buf) 57 | } 58 | // Don't bother cleaning up old data because that's not required by the 59 | // io.Reader API. 60 | n, err := stream.read.Read(stream.buf[:maxbytes]) 61 | // Go allows returning non-nil error (like EOF) and n>0, libopusfile doesn't 62 | // expect that. So return n first to indicate the valid bytes, let the 63 | // subsequent call (which will be n=0, same-error) handle the actual error. 64 | if n == 0 && err != nil { 65 | if err == io.EOF { 66 | return 0 67 | } else { 68 | return -1 69 | } 70 | } 71 | C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&stream.buf[0]), C.size_t(n)) 72 | return C.int(n) 73 | } 74 | 75 | // NewStream creates and initializes a new stream. Don't call .Init() on this. 76 | func NewStream(read io.Reader) (*Stream, error) { 77 | var s Stream 78 | err := s.Init(read) 79 | if err != nil { 80 | return nil, err 81 | } 82 | return &s, nil 83 | } 84 | 85 | // Init initializes a stream with an io.Reader to fetch opus encoded data from 86 | // on demand. Errors from the reader are all transformed to an EOF, any actual 87 | // error information is lost. The same happens when a read returns successfully, 88 | // but with zero bytes. 
89 | func (s *Stream) Init(read io.Reader) error { 90 | if s.oggfile != nil { 91 | return fmt.Errorf("opus stream is already initialized") 92 | } 93 | if read == nil { 94 | return fmt.Errorf("Reader must be non-nil") 95 | } 96 | 97 | s.read = read 98 | s.buf = make([]byte, maxEncodedFrameSize) 99 | s.id = streams.NextId() 100 | var errno C.int 101 | 102 | // Immediately delete the stream after .Init to avoid leaking if the 103 | // caller forgets to (/ doesn't want to) call .Close(). No need for that, 104 | // since the callback is only ever called during a .Read operation; just 105 | // Save and Delete from the map around that every time a reader function is 106 | // called. 107 | streams.Save(s) 108 | defer streams.Del(s) 109 | oggfile := C.my_open_callbacks(C.uintptr_t(s.id), &errno) 110 | if errno != 0 { 111 | return StreamError(errno) 112 | } 113 | s.oggfile = oggfile 114 | return nil 115 | } 116 | 117 | // Read a chunk of raw opus data from the stream and decode it. Returns the 118 | // number of decoded samples per channel. This means that a dual channel 119 | // (stereo) feed will have twice as many samples as the value returned. 120 | // 121 | // Read may successfully read less bytes than requested, but it will never read 122 | // exactly zero bytes successfully if a non-zero buffer is supplied. 123 | // 124 | // The number of channels in the output data must be known in advance. It is 125 | // possible to extract this information from the stream itself, but I'm not 126 | // motivated to do that. Feel free to send a pull request. 
127 | func (s *Stream) Read(pcm []int16) (int, error) { 128 | if s.oggfile == nil { 129 | return 0, fmt.Errorf("opus stream is uninitialized or already closed") 130 | } 131 | if len(pcm) == 0 { 132 | return 0, nil 133 | } 134 | streams.Save(s) 135 | defer streams.Del(s) 136 | n := C.op_read( 137 | s.oggfile, 138 | (*C.opus_int16)(&pcm[0]), 139 | C.int(len(pcm)), 140 | nil) 141 | if n < 0 { 142 | return 0, StreamError(n) 143 | } 144 | if n == 0 { 145 | return 0, io.EOF 146 | } 147 | return int(n), nil 148 | } 149 | 150 | // ReadFloat32 is the same as Read, but decodes to float32 instead of int16. 151 | func (s *Stream) ReadFloat32(pcm []float32) (int, error) { 152 | if s.oggfile == nil { 153 | return 0, fmt.Errorf("opus stream is uninitialized or already closed") 154 | } 155 | if len(pcm) == 0 { 156 | return 0, nil 157 | } 158 | streams.Save(s) 159 | defer streams.Del(s) 160 | n := C.op_read_float( 161 | s.oggfile, 162 | (*C.float)(&pcm[0]), 163 | C.int(len(pcm)), 164 | nil) 165 | if n < 0 { 166 | return 0, StreamError(n) 167 | } 168 | if n == 0 { 169 | return 0, io.EOF 170 | } 171 | return int(n), nil 172 | } 173 | 174 | func (s *Stream) Close() error { 175 | if s.oggfile == nil { 176 | return fmt.Errorf("opus stream is uninitialized or already closed") 177 | } 178 | C.op_free(s.oggfile) 179 | if closer, ok := s.read.(io.Closer); ok { 180 | return closer.Close() 181 | } 182 | return nil 183 | } 184 | -------------------------------------------------------------------------------- /opus/stream_errors.go: -------------------------------------------------------------------------------- 1 | // Copyright © 2015-2017 Go Opus Authors (see AUTHORS file) 2 | // 3 | // License for use of this code is detailed in the LICENSE file 4 | 5 | package opus 6 | 7 | /* 8 | #include 9 | */ 10 | import "C" 11 | 12 | // StreamError represents an error from libopusfile. 13 | type StreamError int 14 | 15 | var _ error = StreamError(0) 16 | 17 | // Libopusfile errors. 
The names are copied verbatim from the libopusfile 18 | // library. 19 | const ( 20 | ErrStreamFalse = StreamError(C.OP_FALSE) 21 | ErrStreamEOF = StreamError(C.OP_EOF) 22 | ErrStreamHole = StreamError(C.OP_HOLE) 23 | ErrStreamRead = StreamError(C.OP_EREAD) 24 | ErrStreamFault = StreamError(C.OP_EFAULT) 25 | ErrStreamImpl = StreamError(C.OP_EIMPL) 26 | ErrStreamInval = StreamError(C.OP_EINVAL) 27 | ErrStreamNotFormat = StreamError(C.OP_ENOTFORMAT) 28 | ErrStreamBadHeader = StreamError(C.OP_EBADHEADER) 29 | ErrStreamVersion = StreamError(C.OP_EVERSION) 30 | ErrStreamNotAudio = StreamError(C.OP_ENOTAUDIO) 31 | ErrStreamBadPacked = StreamError(C.OP_EBADPACKET) 32 | ErrStreamBadLink = StreamError(C.OP_EBADLINK) 33 | ErrStreamNoSeek = StreamError(C.OP_ENOSEEK) 34 | ErrStreamBadTimestamp = StreamError(C.OP_EBADTIMESTAMP) 35 | ) 36 | 37 | func (i StreamError) Error() string { 38 | switch i { 39 | case ErrStreamFalse: 40 | return "OP_FALSE" 41 | case ErrStreamEOF: 42 | return "OP_EOF" 43 | case ErrStreamHole: 44 | return "OP_HOLE" 45 | case ErrStreamRead: 46 | return "OP_EREAD" 47 | case ErrStreamFault: 48 | return "OP_EFAULT" 49 | case ErrStreamImpl: 50 | return "OP_EIMPL" 51 | case ErrStreamInval: 52 | return "OP_EINVAL" 53 | case ErrStreamNotFormat: 54 | return "OP_ENOTFORMAT" 55 | case ErrStreamBadHeader: 56 | return "OP_EBADHEADER" 57 | case ErrStreamVersion: 58 | return "OP_EVERSION" 59 | case ErrStreamNotAudio: 60 | return "OP_ENOTAUDIO" 61 | case ErrStreamBadPacked: 62 | return "OP_EBADPACKET" 63 | case ErrStreamBadLink: 64 | return "OP_EBADLINK" 65 | case ErrStreamNoSeek: 66 | return "OP_ENOSEEK" 67 | case ErrStreamBadTimestamp: 68 | return "OP_EBADTIMESTAMP" 69 | default: 70 | return "libopusfile error: %d (unknown code)" 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /opus/streams_map.go: -------------------------------------------------------------------------------- 1 | // Copyright © Go Opus Authors (see 
AUTHORS file) 2 | // 3 | // License for use of this code is detailed in the LICENSE file 4 | 5 | package opus 6 | 7 | import ( 8 | "sync" 9 | "sync/atomic" 10 | ) 11 | 12 | // A map of simple integers to the actual pointers to stream structs. Avoids 13 | // passing pointers into the Go heap to C. 14 | // 15 | // As per the CGo pointers design doc for go 1.6: 16 | // 17 | // A particular unsafe area is C code that wants to hold on to Go func and 18 | // pointer values for future callbacks from C to Go. This works today but is not 19 | // permitted by the invariant. It is hard to detect. One safe approach is: Go 20 | // code that wants to preserve funcs/pointers stores them into a map indexed by 21 | // an int. Go code calls the C code, passing the int, which the C code may store 22 | // freely. When the C code wants to call into Go, it passes the int to a Go 23 | // function that looks in the map and makes the call. An explicit call is 24 | // required to release the value from the map if it is no longer needed, but 25 | // that was already true before. 26 | // 27 | // - https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md 28 | type streamsMap struct { 29 | sync.RWMutex 30 | m map[uintptr]*Stream 31 | counter uintptr 32 | } 33 | 34 | func (sm *streamsMap) Get(id uintptr) *Stream { 35 | sm.RLock() 36 | defer sm.RUnlock() 37 | return sm.m[id] 38 | } 39 | 40 | func (sm *streamsMap) Del(s *Stream) { 41 | sm.Lock() 42 | defer sm.Unlock() 43 | delete(sm.m, s.id) 44 | } 45 | 46 | // NextId returns a unique ID for each call. 
47 | func (sm *streamsMap) NextId() uintptr { 48 | return atomic.AddUintptr(&sm.counter, 1) 49 | } 50 | 51 | func (sm *streamsMap) Save(s *Stream) { 52 | sm.Lock() 53 | defer sm.Unlock() 54 | sm.m[s.id] = s 55 | } 56 | 57 | func newStreamsMap() *streamsMap { 58 | return &streamsMap{ 59 | counter: 0, 60 | m: map[uintptr]*Stream{}, 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /state.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | "fmt" 7 | 8 | "strconv" 9 | ) 10 | 11 | type Step struct { 12 | Field string 13 | Prompt string 14 | Input *string 15 | Next *Step 16 | } 17 | 18 | type State struct { 19 | ID *uint 20 | Name string 21 | FirstStep Step 22 | } 23 | 24 | // Value implements the driver.Valuer interface, allowing 25 | // for converting the State to a JSON string for database storage. 26 | func (s State) Value() (driver.Value, error) { 27 | if s.ID == nil && s.Name == "" && s.FirstStep == (Step{}) { 28 | return nil, nil 29 | } 30 | return json.Marshal(s) 31 | } 32 | 33 | // Scan implements the sql.Scanner interface, allowing for 34 | // converting a JSON string from the database back into the State slice. 
35 | func (s *State) Scan(value interface{}) error { 36 | if value == nil { 37 | s = nil 38 | return nil 39 | } 40 | 41 | b, ok := value.([]byte) 42 | if !ok { 43 | return fmt.Errorf("type assertion to []byte failed") 44 | } 45 | 46 | return json.Unmarshal(b, &s) 47 | } 48 | 49 | func findEmptyStep(step *Step) *Step { 50 | if step.Input != nil { 51 | if step.Next == nil { 52 | return nil 53 | } 54 | 55 | return findEmptyStep(step.Next) 56 | } 57 | 58 | return step 59 | } 60 | 61 | func asUint(s string) uint { 62 | i, _ := strconv.Atoi(s) 63 | 64 | return uint(i) 65 | } 66 | -------------------------------------------------------------------------------- /tele_handlers.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "runtime/debug" 8 | "strconv" 9 | "strings" 10 | "time" 11 | 12 | "github.com/tectiv3/chatgpt-bot/i18n" 13 | tele "gopkg.in/telebot.v3" 14 | ) 15 | 16 | func (s *Server) onDocument(c tele.Context) { 17 | defer func() { 18 | if err := recover(); err != nil { 19 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 20 | } 21 | }() 22 | Log.WithField("user", c.Sender().Username). 23 | WithField("name", c.Message().Document.FileName). 24 | WithField("mime", c.Message().Document.MIME). 25 | WithField("size", c.Message().Document.FileSize). 
26 | Info("Got a file") 27 | 28 | // Validate file size 29 | if err := ValidateFileSize(c.Message().Document.FileSize); err != nil { 30 | chat := s.getChat(c.Chat(), c.Sender()) 31 | _ = c.Send( 32 | chat.t("File too large: {{.error}}", &i18n.Replacements{"error": err.Error()}), 33 | "text", 34 | &tele.SendOptions{ReplyTo: c.Message()}, 35 | ) 36 | return 37 | } 38 | 39 | // Validate file type 40 | if c.Message().Document.MIME != "text/plain" { 41 | chat := s.getChat(c.Chat(), c.Sender()) 42 | _ = c.Send( 43 | chat.t("Please provide a text file"), 44 | "text", 45 | &tele.SendOptions{ReplyTo: c.Message()}, 46 | ) 47 | return 48 | } 49 | var reader io.ReadCloser 50 | var err error 51 | if s.conf.TelegramServerURL != "" { 52 | f, err := c.Bot().FileByID(c.Message().Document.FileID) 53 | if err != nil { 54 | Log.Warn("Error getting file ID", "error=", err) 55 | return 56 | } 57 | // start reader from f.FilePath 58 | reader, err = os.Open(f.FilePath) 59 | if err != nil { 60 | Log.Warn("Error opening file", "error=", err) 61 | return 62 | } 63 | } else { 64 | reader, err = s.bot.File(&c.Message().Document.File) 65 | if err != nil { 66 | _ = c.Send(err.Error(), "text", &tele.SendOptions{ReplyTo: c.Message()}) 67 | return 68 | } 69 | } 70 | defer reader.Close() 71 | bytes, err := io.ReadAll(reader) 72 | if err != nil { 73 | _ = c.Send(err.Error(), "text", &tele.SendOptions{ReplyTo: c.Message()}) 74 | return 75 | } 76 | 77 | response, err := s.simpleAnswer(c, string(bytes)) 78 | if err != nil { 79 | _ = c.Send(response) 80 | return 81 | } 82 | Log.WithField("user", c.Sender().Username).Info("Response length=", len(response)) 83 | 84 | if len(response) == 0 { 85 | return 86 | } 87 | 88 | file := tele.FromReader(strings.NewReader(response)) 89 | _ = c.Send(&tele.Document{File: file, FileName: "answer.txt", MIME: "text/plain"}) 90 | } 91 | 92 | func (s *Server) onText(c tele.Context) { 93 | defer func() { 94 | if err := recover(); err != nil { 95 | Log.WithField("error", 
err).Error("panic: ", string(debug.Stack())) 96 | } 97 | }() 98 | 99 | message := strings.TrimSpace(c.Message().Payload) 100 | if len(message) == 0 { 101 | message = strings.TrimSpace(c.Message().Text) 102 | } 103 | 104 | // Basic validation for message length 105 | if len(message) == 0 { 106 | chat := s.getChat(c.Chat(), c.Sender()) 107 | _ = c.Send(chat.t("Please provide a message"), "text", &tele.SendOptions{ReplyTo: c.Message()}) 108 | return 109 | } 110 | 111 | if len(message) > MaxPromptLength { 112 | chat := s.getChat(c.Chat(), c.Sender()) 113 | _ = c.Send( 114 | chat.t("Message too long. Maximum length is {{.max}} characters", &i18n.Replacements{"max": fmt.Sprintf("%d", MaxPromptLength)}), 115 | "text", 116 | &tele.SendOptions{ReplyTo: c.Message()}, 117 | ) 118 | return 119 | } 120 | 121 | s.complete(c, message, true) 122 | } 123 | 124 | func (s *Server) onVoice(c tele.Context) { 125 | defer func() { 126 | if err := recover(); err != nil { 127 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 128 | } 129 | }() 130 | 131 | Log.WithField("user", c.Sender().Username). 132 | Info("Got a voice, filesize=", c.Message().Voice.FileSize) 133 | 134 | s.handleVoice(c) 135 | } 136 | 137 | func (s *Server) onPhoto(c tele.Context) { 138 | defer func() { 139 | if err := recover(); err != nil { 140 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 141 | } 142 | }() 143 | 144 | Log.WithField("user", c.Sender().Username). 
145 | Info("Got a photo, filesize=", c.Message().Photo.FileSize) 146 | 147 | if c.Message().Photo.FileSize == 0 { 148 | return 149 | } 150 | 151 | s.handleImage(c) 152 | } 153 | 154 | func (s *Server) onTranslate(c tele.Context, prefix string) { 155 | defer func() { 156 | if err := recover(); err != nil { 157 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 158 | } 159 | }() 160 | 161 | query := c.Message().Text 162 | if len(query) < 1 { 163 | _ = c.Send( 164 | "Please provide a longer prompt", 165 | "text", 166 | &tele.SendOptions{ReplyTo: c.Message()}, 167 | ) 168 | 169 | return 170 | } 171 | // get the text after the command 172 | if len(c.Message().Entities) > 0 { 173 | command := c.Message().EntityText(c.Message().Entities[0]) 174 | query = query[len(command):] 175 | } 176 | 177 | s.complete(c, fmt.Sprintf("%s\n%s", prefix, query), true) 178 | } 179 | 180 | func (s *Server) onGetUsers(c tele.Context) { 181 | defer func() { 182 | if err := recover(); err != nil { 183 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 184 | } 185 | }() 186 | 187 | users := s.getUsers() 188 | text := "Users:\n" 189 | for _, user := range users { 190 | threads := user.Threads 191 | var historyLen int64 192 | var updatedAt time.Time 193 | var totalTokens int 194 | var model string 195 | role := "default" 196 | 197 | if len(threads) > 0 { 198 | s.db.Model(&ChatMessage{}).Where("chat_id = ?", threads[0].ID).Count(&historyLen) 199 | updatedAt = threads[0].UpdatedAt 200 | totalTokens = threads[0].TotalTokens 201 | model = threads[0].ModelName 202 | if threads[0].RoleID != nil { 203 | role = threads[0].Role.Name 204 | } 205 | } 206 | 207 | text += fmt.Sprintf( 208 | "*%s*, last used: *%s*, history: *%d*, usage: *%d*, model: *%s*, role: *%s*\n", 209 | user.Username, 210 | updatedAt.Format("2006/01/02 15:04"), 211 | historyLen, 212 | totalTokens, 213 | model, 214 | role, 215 | ) 216 | } 217 | 218 | _ = c.Send(text, "text", &tele.SendOptions{ReplyTo: 
c.Message(), ParseMode: tele.ModeMarkdown}) 219 | } 220 | 221 | func (s *Server) onState(c tele.Context) { 222 | defer func() { 223 | if err := recover(); err != nil { 224 | Log.WithField("error", err).Error("panic: ", string(debug.Stack())) 225 | } 226 | }() 227 | 228 | chat := s.getChat(c.Chat(), c.Sender()) 229 | user := chat.User 230 | state := user.State 231 | step := findEmptyStep(&state.FirstStep) 232 | 233 | if step == nil { 234 | s.resetUserState(user) 235 | return 236 | } 237 | 238 | step.Input = &c.Message().Text 239 | s.db.Model(&user).Update("State", state) 240 | 241 | chat.removeMenu(c) 242 | 243 | next := findEmptyStep(step) 244 | if next != nil { 245 | menu.Inline(menu.Row(btnCancel)) 246 | sentMessage, err := c.Bot().Send(c.Recipient(), chat.t(next.Prompt), menu) 247 | if err != nil { 248 | Log.WithField("err", err).Error("Error sending message") 249 | return 250 | } 251 | id := &([]string{strconv.Itoa(sentMessage.ID)}[0]) 252 | s.setChatLastMessageID(id, chat.ChatID) 253 | 254 | return 255 | } 256 | 257 | s.setChatLastMessageID(nil, chat.ChatID) 258 | Log.WithField("State", state.Name).Info("State: Done!") 259 | switch state.Name { 260 | case "RoleCreate": 261 | role := Role{ 262 | UserID: user.ID, 263 | Name: *state.FirstStep.Input, 264 | Prompt: *state.FirstStep.Next.Input, 265 | } 266 | 267 | chat.mutex.Lock() 268 | defer chat.mutex.Unlock() 269 | user.Roles = append(user.Roles, role) 270 | s.db.Save(&user) 271 | 272 | if err := c.Send(chat.t("Role created")); err != nil { 273 | Log.WithField("err", err).Error("Error sending message") 274 | } 275 | case "RoleUpdate": 276 | role := s.getRole(*state.ID) 277 | if role == nil { 278 | Log.Warn("Role not found") 279 | return 280 | } 281 | role.Name = *state.FirstStep.Input 282 | role.Prompt = *state.FirstStep.Next.Input 283 | s.db.Save(role) 284 | 285 | if err := c.Send(chat.t("Role updated")); err != nil { 286 | Log.WithField("err", err).Error("Error sending message") 287 | } 288 | default: 289 | 
Log.Warn("Unknown state: ", state.Name) 290 | } 291 | s.resetUserState(user) 292 | } 293 | -------------------------------------------------------------------------------- /tools/duckduckgo.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "math/rand" 11 | "net/http" 12 | "net/url" 13 | "regexp" 14 | "strings" 15 | "time" 16 | 17 | "github.com/PuerkitoBio/goquery" 18 | "golang.org/x/net/html" 19 | 20 | safe "github.com/eminarican/safetypes" 21 | ) 22 | 23 | type SearchParam struct { 24 | Query string 25 | Region string 26 | ImageType string 27 | } 28 | 29 | type ClientOption struct { 30 | Referrer string 31 | UserAgent string 32 | Timeout time.Duration 33 | } 34 | 35 | type SearchResult struct { 36 | Title string 37 | Link string 38 | Snippet string 39 | Image string 40 | } 41 | 42 | type Result struct { 43 | Answer string `json:"Answer"` 44 | Results []struct { 45 | Height int `json:"height"` 46 | Image string `json:"image"` 47 | Source string `json:"source"` 48 | Thumbnail string `json:"thumbnail"` 49 | Title string `json:"title"` 50 | URL string `json:"url"` 51 | Width int `json:"width"` 52 | } `json:"results"` 53 | } 54 | 55 | var defaultClientOption = &ClientOption{ 56 | Referrer: "https://duckduckgo.com", 57 | UserAgent: `Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36`, 58 | Timeout: 5 * time.Second, 59 | } 60 | 61 | func NewClientOption(referrer, userAgent string, timeout time.Duration) *ClientOption { 62 | if referrer == "" { 63 | referrer = defaultClientOption.Referrer 64 | } 65 | if userAgent == "" { 66 | referrer = defaultClientOption.UserAgent 67 | } 68 | 69 | if timeout == 0 { 70 | timeout = defaultClientOption.Timeout 71 | } 72 | 73 | return &ClientOption{ 74 | Referrer: referrer, 75 | UserAgent: userAgent, 76 | Timeout: timeout, 77 | 
} 78 | } 79 | 80 | func NewSearchParam(query, region string) (*SearchParam, error) { 81 | q := strings.TrimSpace(query) 82 | if q == "" { 83 | return nil, errors.New("search query is empty") 84 | } 85 | 86 | return &SearchParam{Query: q, Region: region}, nil 87 | } 88 | 89 | func NewSearchImageParam(query, region, imageType string) (*SearchParam, error) { 90 | q := strings.TrimSpace(query) 91 | if q == "" { 92 | return nil, errors.New("search query is empty") 93 | } 94 | 95 | return &SearchParam{ 96 | Query: q, 97 | Region: region, 98 | ImageType: imageType, 99 | }, nil 100 | } 101 | 102 | func (param *SearchParam) buildURL() (*url.URL, error) { 103 | u := &url.URL{ 104 | Scheme: "https", 105 | Host: "html.duckduckgo.com", 106 | Path: "html", 107 | } 108 | q := u.Query() 109 | q.Add("q", param.Query) 110 | q.Add("l", param.Region) 111 | q.Add("v", "1") 112 | q.Add("o", "json") 113 | q.Add("api", "/d.js") 114 | u.RawQuery = q.Encode() 115 | 116 | return u, nil 117 | } 118 | 119 | func buildRequest(param *SearchParam, opt *ClientOption) (*http.Request, error) { 120 | u, err := param.buildURL() 121 | if err != nil { 122 | return nil, err 123 | } 124 | 125 | req, err := http.NewRequest(http.MethodGet, u.String(), nil) 126 | if err != nil { 127 | return req, err 128 | } 129 | 130 | req.Header.Add("Referrer", opt.Referrer) 131 | req.Header.Add("User-Agent", opt.UserAgent) 132 | req.Header.Add("Cookie", "kl="+param.Region) 133 | req.Header.Add("Content-Type", "x-www-form-urlencoded") 134 | 135 | return req, nil 136 | } 137 | 138 | var re = regexp.MustCompile(`vqd="([\d-]+)"`) 139 | 140 | func addParams(r *http.Request, p map[string]string) { 141 | q := r.URL.Query() 142 | 143 | for k, v := range p { 144 | q.Add(k, v) 145 | } 146 | 147 | r.URL.RawQuery = q.Encode() 148 | } 149 | 150 | func token(keywords string) (string, error) { 151 | var client = http.DefaultClient 152 | const URL = "https://duckduckgo.com/" 153 | 154 | r, _ := http.NewRequest("POST", URL, nil) 155 | 
addParams(r, map[string]string{"q": keywords}) 156 | 157 | res, err := client.Do(r) 158 | if err != nil { 159 | return "", err 160 | } 161 | 162 | defer res.Body.Close() 163 | 164 | body, err := io.ReadAll(res.Body) 165 | 166 | if err != nil { 167 | return "", err 168 | } 169 | 170 | token := re.Find(body) 171 | 172 | if token == nil { 173 | log.Println(string(body)) 174 | return "", errors.New("token parsing failed") 175 | } 176 | 177 | return strings.Trim(string(token)[4:len(token)-1], "\"&"), nil 178 | } 179 | 180 | func buildImagesRequest(param *SearchParam, opt *ClientOption) (*http.Request, error) { 181 | vqd, err := token(param.Query) 182 | if err != nil { 183 | return nil, err 184 | } 185 | log.Printf("vqd: %s", vqd) 186 | u := &url.URL{ 187 | Scheme: "https", 188 | Host: "duckduckgo.com", 189 | Path: "i.js", 190 | } 191 | 192 | q := u.Query() 193 | q.Add("l", param.Region) 194 | q.Add("o", "json") 195 | q.Add("q", param.Query) 196 | //q.Add("v", "1") 197 | q.Add("vqd", vqd) 198 | q.Add("f", ",,,type:"+param.ImageType) 199 | q.Add("p", "-1") 200 | q.Add("s", "0") 201 | //q.Add("v7exp", "a") 202 | //q.Add("api", "/i.js") 203 | u.RawQuery = q.Encode() 204 | 205 | req, err := http.NewRequest(http.MethodGet, u.String(), nil) 206 | if err != nil { 207 | return req, err 208 | } 209 | 210 | req.Header.Add("Referrer", opt.Referrer) 211 | req.Header.Add("User-Agent", opt.UserAgent) 212 | req.Header.Add("Cookie", "kl="+param.Region) 213 | req.Header.Add("Content-Type", "x-www-form-urlencoded") 214 | 215 | return req, nil 216 | } 217 | 218 | func parse(r io.Reader) safe.Result[*[]SearchResult] { 219 | doc, err := goquery.NewDocumentFromReader(r) 220 | if err != nil { 221 | return safe.Err[*[]SearchResult](err.Error()) 222 | } 223 | 224 | var ( 225 | result []SearchResult 226 | item SearchResult 227 | ) 228 | doc.Find(".result").Each(func(i int, s *goquery.Selection) { 229 | item.Title = s.Find(".result__title a").Text() 230 | 231 | item.Link = extractLink( 232 | 
s.Find(".result__url").AttrOr("href", ""), 233 | ) 234 | 235 | item.Snippet = removeHtmlTagsFromText( 236 | s.Find(".result__snippet").Text(), 237 | ) 238 | 239 | result = append(result, item) 240 | }) 241 | 242 | return safe.AsResult[*[]SearchResult](&result, nil) 243 | } 244 | 245 | func removeHtmlTags(node *html.Node, buf *bytes.Buffer) { 246 | if node.Type == html.TextNode { 247 | buf.WriteString(node.Data) 248 | } 249 | 250 | for child := node.FirstChild; child != nil; child = child.NextSibling { 251 | removeHtmlTags(child, buf) 252 | } 253 | } 254 | 255 | func removeHtmlTagsFromText(text string) string { 256 | node, err := html.Parse(strings.NewReader(text)) 257 | if err != nil { 258 | // If it cannot be parsed text as HTML, return the text as is. 259 | return text 260 | } 261 | 262 | buf := &bytes.Buffer{} 263 | removeHtmlTags(node, buf) 264 | 265 | return buf.String() 266 | } 267 | 268 | // Extract target URL from href included in search result 269 | // e.g. 270 | // 271 | // `//duckduckgo.com/l/?uddg=https%3A%2F%2Fwww.vim8.org%2Fdownload.php&rut=...` 272 | // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 273 | // --> `https://www.vim8.org/download.php` 274 | func extractLink(href string) string { 275 | u, err := url.Parse(fmt.Sprintf("https:%s", strings.TrimSpace(href))) 276 | if err != nil { 277 | return "" 278 | } 279 | 280 | q := u.Query() 281 | if !q.Has("uddg") { 282 | return "" 283 | } 284 | 285 | return q.Get("uddg") 286 | } 287 | 288 | func SearchWithOption(param *SearchParam, opt *ClientOption) safe.Result[*[]SearchResult] { 289 | c := &http.Client{ 290 | Timeout: opt.Timeout, 291 | } 292 | req, err := buildRequest(param, opt) 293 | if err != nil { 294 | return safe.Err[*[]SearchResult](err.Error()) 295 | } 296 | 297 | resp, err := c.Do(req) 298 | if err != nil { 299 | return safe.Err[*[]SearchResult](err.Error()) 300 | } 301 | defer resp.Body.Close() 302 | 303 | result := parse(resp.Body) 304 | if result.IsErr() { 305 | return result 306 | } 307 | 
308 | return result 309 | } 310 | 311 | func Search(param *SearchParam) safe.Result[*[]SearchResult] { 312 | return SearchWithOption(param, defaultClientOption) 313 | } 314 | 315 | func SearchImages(param *SearchParam) safe.Result[*[]SearchResult] { 316 | c := &http.Client{Timeout: defaultClientOption.Timeout} 317 | req, err := buildImagesRequest(param, defaultClientOption) 318 | if err != nil { 319 | return safe.Err[*[]SearchResult](err.Error()) 320 | } 321 | 322 | resp, err := c.Do(req) 323 | if err != nil { 324 | return safe.Err[*[]SearchResult](err.Error()) 325 | } 326 | defer resp.Body.Close() 327 | 328 | result := Result{} 329 | body, err := io.ReadAll(resp.Body) 330 | if err != nil { 331 | return safe.Err[*[]SearchResult](err.Error()) 332 | } 333 | if err := json.Unmarshal(body, &result); err != nil { 334 | log.Println(string(body)) 335 | return safe.Err[*[]SearchResult](string(body)) 336 | } 337 | var ( 338 | res []SearchResult 339 | item SearchResult 340 | ) 341 | for _, r := range result.Results { 342 | item.Title = r.Title 343 | item.Link = r.URL 344 | item.Snippet = "" 345 | item.Image = r.Image 346 | res = append(res, item) 347 | } 348 | 349 | rand.Shuffle(len(res), func(i, j int) { res[i], res[j] = res[j], res[i] }) 350 | 351 | return safe.AsResult[*[]SearchResult](&res, nil) 352 | } 353 | -------------------------------------------------------------------------------- /tools/tool_cryptorate.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "strconv" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | // GetCryptoRate is a tool that can do get crypto rate. 
14 | type GetCryptoRate struct { 15 | sessionString string 16 | } 17 | 18 | type CoinCap struct { 19 | Data struct { 20 | Symbol string `json:"symbol"` 21 | PriceUsd string `json:"priceUsd"` 22 | } `json:"data"` 23 | Timestamp int64 `json:"timestamp"` 24 | } 25 | 26 | func (t GetCryptoRate) Description() string { 27 | return `Usefull for getting the current rate of various crypto currencies.` 28 | } 29 | 30 | func (t GetCryptoRate) Name() string { 31 | return "DownloadWebsite" 32 | } 33 | 34 | func (t GetCryptoRate) Call(ctx context.Context, input string) (string, error) { 35 | result, err := getCryptoRate(input) 36 | if err != nil { 37 | return fmt.Sprintf("error from tool: %s", err.Error()), nil //nolint:nilerr 38 | } 39 | 40 | return result, nil 41 | } 42 | 43 | func getCryptoRate(asset string) (string, error) { 44 | asset = strings.ToLower(asset) 45 | format := "$%0.0f" 46 | switch asset { 47 | case "btc": 48 | asset = "bitcoin" 49 | case "eth": 50 | asset = "ethereum" 51 | case "ltc": 52 | asset = "litecoin" 53 | case "xrp": 54 | asset = "ripple" 55 | format = "$%0.3f" 56 | case "xlm": 57 | asset = "stellar" 58 | format = "$%0.3f" 59 | case "ada": 60 | asset = "cardano" 61 | format = "$%0.3f" 62 | } 63 | client := &http.Client{} 64 | client.Timeout = 10 * time.Second 65 | req, err := http.NewRequest("GET", fmt.Sprintf("https://api.coincap.io/v2/assets/%s", asset), nil) 66 | if err != nil { 67 | return "", err 68 | } 69 | resp, err := client.Do(req) 70 | if err != nil { 71 | return "", err 72 | } 73 | defer resp.Body.Close() 74 | 75 | var symbol CoinCap 76 | err = json.NewDecoder(resp.Body).Decode(&symbol) 77 | if err != nil { 78 | return "", err 79 | } 80 | price, _ := strconv.ParseFloat(symbol.Data.PriceUsd, 64) 81 | 82 | return fmt.Sprintf(format, price), nil 83 | } 84 | -------------------------------------------------------------------------------- /tools/tool_search_vector_db.go: 
-------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | log "github.com/sirupsen/logrus" 7 | "github.com/tectiv3/chatgpt-bot/vectordb" 8 | ) 9 | 10 | // SearchVectorDB is a tool that finds the most relevant documents in the vector db. 11 | type SearchVectorDB struct { 12 | SessionString string 13 | Ollama bool 14 | } 15 | 16 | type DocResult struct { 17 | Text string 18 | Source string 19 | } 20 | 21 | // usedResults is a map of used results for all users throughout the session. Potentially a memory leak. 22 | var usedResults = make(map[string][]string) 23 | 24 | func (t SearchVectorDB) Description() string { 25 | return `Useful for searching through added files and websites. Search for keywords in the text not whole questions, avoid relative words like "yesterday" think about what could be in the text. 26 | The input to this tool will be run against a vector db. The top results will be returned as json.` 27 | } 28 | 29 | func (t SearchVectorDB) Name() string { 30 | return "SearchVectorDB" 31 | } 32 | 33 | func (t SearchVectorDB) Call(ctx context.Context, input string) (string, error) { 34 | ctx = context.WithValue(ctx, "ollama", t.Ollama) 35 | docs, err := vectordb.SearchVectorDB(ctx, input, t.SessionString) 36 | 37 | var results []DocResult 38 | 39 | for _, r := range docs { 40 | newResult := DocResult{Text: r.PageContent} 41 | 42 | source, ok := r.Metadata["url"].(string) 43 | if ok { 44 | newResult.Source = source 45 | } 46 | 47 | for _, usedLink := range usedResults[t.SessionString] { 48 | if usedLink == newResult.Text { 49 | continue 50 | } 51 | } 52 | results = append(results, newResult) 53 | usedResults[t.SessionString] = append(usedResults[t.SessionString], newResult.Text) 54 | } 55 | 56 | if len(docs) == 0 { 57 | response := "no results found. Try other db search keywords or download more websites." 
58 | log.Warn("no results found", "input", input) 59 | results = append(results, DocResult{Text: response}) 60 | } else if len(results) == 0 { 61 | response := "No new results found, all returned results have been used already. Try other db search keywords or download more websites." 62 | results = append(results, DocResult{Text: response}) 63 | } 64 | 65 | resultJson, err := json.Marshal(results) 66 | if err != nil { 67 | return "", err 68 | } 69 | 70 | return string(resultJson), nil 71 | } 72 | -------------------------------------------------------------------------------- /tools/tool_websearch.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | log "github.com/sirupsen/logrus" 8 | "github.com/tectiv3/chatgpt-bot/vectordb" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "runtime/debug" 13 | "strings" 14 | "sync" 15 | ) 16 | 17 | type WebSearch struct { 18 | SessionString string 19 | Ollama bool 20 | } 21 | 22 | type seaXngResult struct { 23 | Query string `json:"query"` 24 | NumberOfResults int `json:"number_of_results"` 25 | Results []SearXResult `json:"results"` 26 | Answers []any `json:"answers"` 27 | Corrections []any `json:"corrections"` 28 | Infoboxes []any `json:"infoboxes"` 29 | Suggestions []string `json:"suggestions"` 30 | UnresponsiveEngines []any `json:"unresponsive_engines"` 31 | } 32 | 33 | type SearXResult struct { 34 | URL string `json:"url"` 35 | Title string `json:"title"` 36 | Content string `json:"content"` 37 | PublishedDate any `json:"publishedDate,omitempty"` 38 | ImgSrc any `json:"img_src,omitempty"` 39 | Engine string `json:"engine"` 40 | ParsedURL []string `json:"parsed_url"` 41 | Template string `json:"template"` 42 | Engines []string `json:"engines"` 43 | Positions []int `json:"positions"` 44 | Score float64 `json:"score"` 45 | Category string `json:"category"` 46 | } 47 | 48 | var usedLinks = make(map[string][]string) 49 | 
50 | func (t WebSearch) Description() string { 51 | return `Useful for searching the internet. You have to use this tool if you're not 100% certain. Do not append question mark to the query. The top 10 results will be added to the vector db. The top 3 results are also getting returned to you directly. For more search queries through the same websites, use the VectorDB tool. Append region info to the query. For example :en-us. Infer this from the language used for the query. Default to empty if not specified or can not be inferred. Possible regions: ` + strings.Join( 52 | []string{"xa-ar", "xa-en", "ar-es", "au-en", "at-de", "be-fr", "be-nl", "br-pt", "bg-bg", 53 | "ca-en", "ca-fr", "ct-ca", "cl-es", "cn-zh", "co-es", "hr-hr", "cz-cs", "dk-da", 54 | "ee-et", "fi-fi", "fr-fr", "de-de", "gr-el", "hk-tzh", "hu-hu", "in-en", "id-id", 55 | "id-en", "ie-en", "il-he", "it-it", "jp-jp", "kr-kr", "lv-lv", "lt-lt", "xl-es", 56 | "my-ms", "my-en", "mx-es", "nl-nl", "nz-en", "no-no", "pe-es", "ph-en", "ph-tl", 57 | "pl-pl", "pt-pt", "ro-ro", "ru-ru", "sg-en", "sk-sk", "sl-sl", "za-en", "es-es", 58 | "se-sv", "ch-de", "ch-fr", "ch-it", "tw-tzh", "th-th", "tr-tr", "ua-uk", "uk-en", 59 | "us-en", "ue-es", "ve-es", "vn-vi", "vn-en", "za-en"}, ", ") 60 | } 61 | 62 | func (t WebSearch) Name() string { 63 | return "WebSearch" 64 | } 65 | 66 | func (t WebSearch) Call(ctx context.Context, input string) (string, error) { 67 | ctx = context.WithValue(ctx, "ollama", t.Ollama) 68 | results, err := SearchSearX(input) 69 | 70 | wg := sync.WaitGroup{} 71 | counter := 0 72 | for i := range results { 73 | for _, usedLink := range usedLinks[t.SessionString] { 74 | if usedLink == results[i].URL { 75 | continue 76 | } 77 | } 78 | if results[i].Score <= 0.5 { 79 | continue 80 | } 81 | 82 | if counter > 10 { 83 | break 84 | } 85 | 86 | // if result link ends in .pdf, skip 87 | if strings.HasSuffix(results[i].URL, ".pdf") { 88 | continue 89 | } 90 | 91 | counter += 1 92 | wg.Add(1) 93 | go func(i int) 
{ 94 | defer func() { 95 | if r := recover(); r != nil { 96 | log.WithField("error", err).Error("panic: ", string(debug.Stack())) 97 | } 98 | }() 99 | ctx = context.WithValue(ctx, "ollama", t.Ollama) 100 | err := vectordb.DownloadWebsiteToVectorDB(ctx, results[i].URL, t.SessionString) 101 | if err != nil { 102 | log.Warn("Error downloading website", "error=", err) 103 | wg.Done() 104 | return 105 | } 106 | usedLinks[t.SessionString] = append(usedLinks[t.SessionString], results[i].URL) 107 | wg.Done() 108 | }(i) 109 | } 110 | wg.Wait() 111 | 112 | result, err := SearchVectorDB.Call( 113 | SearchVectorDB{SessionString: t.SessionString, Ollama: t.Ollama}, 114 | context.Background(), 115 | input, 116 | ) 117 | 118 | if err != nil { 119 | return fmt.Sprintf("error from vector db search: %s", err.Error()), nil //nolint:nilerr 120 | } 121 | 122 | return result, nil 123 | } 124 | 125 | func SearchSearX(input string) ([]SearXResult, error) { 126 | // remove quotes and question mark. Question mark even escaped still causes 404 in searx 127 | input = strings.TrimSuffix(strings.TrimSuffix(strings.TrimPrefix(input, "\""), "\""), "?") 128 | inputQuery := url.QueryEscape(input) 129 | searXNGDomain := os.Getenv("SEARXNG_DOMAIN") 130 | query := fmt.Sprintf("%s/?q=%s&format=json", searXNGDomain, inputQuery) 131 | //log.Info("Searching", "query", query) 132 | 133 | resp, err := http.Get(query) 134 | 135 | if err != nil { 136 | log.Warn("Error making the request", "error=", err) 137 | return []SearXResult{}, err 138 | } 139 | defer resp.Body.Close() 140 | 141 | if resp.StatusCode > 300 { 142 | log.Warn("Error with the response", "status", resp.Status) 143 | return []SearXResult{}, fmt.Errorf("error with the response: %s", resp.Status) 144 | } 145 | 146 | var apiResponse seaXngResult 147 | //body, err := io.ReadAll(resp.Body) 148 | //log.Info("Response", "body", string(body)) 149 | 150 | if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { 151 | //if err := 
json.Unmarshal(body, &apiResponse); err != nil { 152 | log.Warn("Error decoding the response", "error=", err) //, "body", string(body)) 153 | return []SearXResult{}, err 154 | } 155 | log.Info("Search results found=", len(apiResponse.Results)) 156 | 157 | if len(apiResponse.Results) == 0 { 158 | return []SearXResult{}, fmt.Errorf("no results found") 159 | } 160 | 161 | return apiResponse.Results, nil 162 | } 163 | -------------------------------------------------------------------------------- /types/types.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/base64" 6 | "encoding/json" 7 | "fmt" 8 | "github.com/meinside/openai-go" 9 | tele "gopkg.in/telebot.v3" 10 | "io" 11 | ) 12 | 13 | // ToolCalls is a custom type that will allow us to implement 14 | // the driver.Valuer and sql.Scanner interfaces on a slice of ToolCall. 15 | type ToolCalls []ToolCall 16 | 17 | type ToolCall struct { 18 | ID string `json:"id"` 19 | Type string `json:"type"` // == 'function' 20 | Function openai.ToolCallFunction `json:"function"` 21 | } 22 | 23 | // Value implements the driver.Valuer interface, allowing 24 | // for converting the ToolCalls to a JSON string for database storage. 25 | func (tc ToolCalls) Value() (driver.Value, error) { 26 | if tc == nil { 27 | return nil, nil 28 | } 29 | return json.Marshal(tc) 30 | } 31 | 32 | // Scan implements the sql.Scanner interface, allowing for 33 | // converting a JSON string from the database back into the ToolCalls slice. 
34 | func (tc *ToolCalls) Scan(value interface{}) error { 35 | if value == nil { 36 | *tc = nil 37 | return nil 38 | } 39 | 40 | b, ok := value.([]byte) 41 | if !ok { 42 | return fmt.Errorf("type assertion to []byte failed") 43 | } 44 | 45 | return json.Unmarshal(b, &tc) 46 | } 47 | 48 | type GPTResponse interface { 49 | Type() string // direct, array, image, audio, async 50 | Value() interface{} // string, []string 51 | CanReply() bool // if true replyMenu need to be shown 52 | } 53 | 54 | // WAV writer struct 55 | type wavWriter struct { 56 | w io.Writer 57 | } 58 | 59 | // WAV file header struct 60 | type wavHeader struct { 61 | RIFFID [4]byte // RIFF header 62 | FileSize uint32 // file size - 8 63 | WAVEID [4]byte // WAVE header 64 | FMTID [4]byte // fmt header 65 | Subchunk1Size uint32 // size of the fmt chunk 66 | AudioFormat uint16 // audio format code 67 | NumChannels uint16 // number of channels 68 | SampleRate uint32 // sample rate 69 | ByteRate uint32 // bytes per second 70 | BlockAlign uint16 // block align 71 | BitsPerSample uint16 // bits per sample 72 | DataID [4]byte // data header 73 | Subchunk2Size uint32 // size of the data chunk 74 | } 75 | 76 | // RestrictConfig defines config for Restrict middleware. 77 | type RestrictConfig struct { 78 | // Chats is a list of chats that are going to be affected 79 | // by either In or Out function. 80 | Usernames []string 81 | 82 | // In defines a function that will be called if the chat 83 | // of an update will be found in the Chats list. 84 | In tele.HandlerFunc 85 | 86 | // Out defines a function that will be called if the chat 87 | // of an update will NOT be found in the Chats list. 
88 | Out tele.HandlerFunc 89 | } 90 | 91 | func in_array(needle string, haystack []string) bool { 92 | for _, v := range haystack { 93 | if needle == v { 94 | return true 95 | } 96 | } 97 | 98 | return false 99 | } 100 | 101 | type CoinCap struct { 102 | Data struct { 103 | Symbol string `json:"symbol"` 104 | PriceUsd string `json:"priceUsd"` 105 | } `json:"data"` 106 | Timestamp int64 `json:"timestamp"` 107 | } 108 | 109 | type StepType string 110 | 111 | const ( 112 | StepHandleAgentAction StepType = "HandleAgentAction" 113 | StepHandleAgentFinish StepType = "HandleAgentFinish" 114 | StepHandleChainEnd StepType = "HandleChainEnd" 115 | StepHandleChainError StepType = "HandleChainError" 116 | StepHandleChainStart StepType = "HandleChainStart" 117 | StepHandleFinalAnswer StepType = "HandleFinalAnswer" 118 | StepHandleLLMGenerateContentEnd StepType = "HandleLLMGenerateContentEnd" 119 | StepHandleLLMGenerateContentStart StepType = "HandleLLMGenerateContentStart" 120 | StepHandleLlmEnd StepType = "HandleLlmEnd" 121 | StepHandleLlmError StepType = "HandleLlmError" 122 | StepHandleLlmStart StepType = "HandleLlmStart" 123 | StepHandleNewSession StepType = "HandleNewSession" 124 | StepHandleOllamaStart StepType = "HandleOllamaStart" 125 | StepHandleParseError StepType = "HandleParseError" 126 | StepHandleRetrieverEnd StepType = "HandleRetrieverEnd" 127 | StepHandleRetrieverStart StepType = "HandleRetrieverStart" 128 | StepHandleSourceAdded StepType = "HandleSourceAdded" 129 | StepHandleToolEnd StepType = "HandleToolEnd" 130 | StepHandleToolError StepType = "HandleToolError" 131 | StepHandleToolStart StepType = "HandleToolStart" 132 | StepHandleVectorFound StepType = "HandleVectorFound" 133 | ) 134 | 135 | type ClientQuery struct { 136 | Prompt string `json:"prompt"` 137 | MaxIterations int `json:"maxIterations"` 138 | ModelName string `json:"modelName"` 139 | Session string `json:"session"` 140 | } 141 | 142 | type Source struct { 143 | Name string `json:"name"` 144 | 
Link string `json:"link"` 145 | Summary string `json:"summary"` 146 | } 147 | 148 | type HttpJsonStreamElement struct { 149 | Close bool `json:"close"` 150 | Message string `json:"message"` 151 | Stream bool `json:"stream"` 152 | StepType StepType `json:"stepType"` 153 | Source Source `json:"source"` 154 | Session string `json:"session"` 155 | } 156 | 157 | func toBase64(b []byte) string { 158 | return base64.StdEncoding.EncodeToString(b) 159 | } 160 | -------------------------------------------------------------------------------- /uploads/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tectiv3/chatgpt-bot/fa8d55c077f3e99bb4a6eb7bf4a2f1567fc3f6a8/uploads/.gitkeep -------------------------------------------------------------------------------- /validation.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "regexp" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // Input validation errors 12 | var ( 13 | ErrInvalidInput = errors.New("invalid input") 14 | ErrInputTooShort = errors.New("input too short") 15 | ErrInputTooLong = errors.New("input too long") 16 | ErrInvalidFormat = errors.New("invalid format") 17 | ErrInvalidRange = errors.New("value out of range") 18 | ) 19 | 20 | // Validation constraints 21 | const ( 22 | MinUsernameLength = 3 23 | MaxUsernameLength = 32 24 | MinPromptLength = 3 25 | MaxPromptLength = 4000 26 | MinRoleNameLength = 1 27 | MaxRoleNameLength = 50 28 | MaxFileSize = 10 * 1024 * 1024 // 10MB 29 | MinAge = 1 30 | MaxAge = 365 31 | ) 32 | 33 | // Input validation functions 34 | 35 | // ValidateUsername validates telegram username 36 | func ValidateUsername(username string) error { 37 | if username == "" { 38 | return fmt.Errorf("%w: username cannot be empty", ErrInvalidInput) 39 | } 40 | 41 | if len(username) < MinUsernameLength { 42 | return fmt.Errorf("%w: username must be 
at least %d characters", ErrInputTooShort, MinUsernameLength) 43 | } 44 | 45 | if len(username) > MaxUsernameLength { 46 | return fmt.Errorf("%w: username must be at most %d characters", ErrInputTooLong, MaxUsernameLength) 47 | } 48 | 49 | // Username should contain only alphanumeric characters and underscores 50 | matched, _ := regexp.MatchString(`^[a-zA-Z0-9_]+$`, username) 51 | if !matched { 52 | return fmt.Errorf("%w: username can only contain letters, numbers, and underscores", ErrInvalidFormat) 53 | } 54 | 55 | return nil 56 | } 57 | 58 | // ValidatePrompt validates user prompt input 59 | func ValidatePrompt(prompt string) error { 60 | if strings.TrimSpace(prompt) == "" { 61 | return fmt.Errorf("%w: prompt cannot be empty", ErrInvalidInput) 62 | } 63 | 64 | if len(prompt) < MinPromptLength { 65 | return fmt.Errorf("%w: prompt must be at least %d characters", ErrInputTooShort, MinPromptLength) 66 | } 67 | 68 | if len(prompt) > MaxPromptLength { 69 | return fmt.Errorf("%w: prompt must be at most %d characters", ErrInputTooLong, MaxPromptLength) 70 | } 71 | 72 | return nil 73 | } 74 | 75 | // ValidateRoleName validates role name input 76 | func ValidateRoleName(name string) error { 77 | name = strings.TrimSpace(name) 78 | if name == "" { 79 | return fmt.Errorf("%w: role name cannot be empty", ErrInvalidInput) 80 | } 81 | 82 | if len(name) < MinRoleNameLength { 83 | return fmt.Errorf("%w: role name must be at least %d character", ErrInputTooShort, MinRoleNameLength) 84 | } 85 | 86 | if len(name) > MaxRoleNameLength { 87 | return fmt.Errorf("%w: role name must be at most %d characters", ErrInputTooLong, MaxRoleNameLength) 88 | } 89 | 90 | return nil 91 | } 92 | 93 | // ValidateAge validates conversation age input 94 | func ValidateAge(ageStr string) (int, error) { 95 | if ageStr == "" { 96 | return 0, fmt.Errorf("%w: age cannot be empty", ErrInvalidInput) 97 | } 98 | 99 | age, err := strconv.Atoi(ageStr) 100 | if err != nil { 101 | return 0, fmt.Errorf("%w: age 
must be a number", ErrInvalidFormat) 102 | } 103 | 104 | if age < MinAge || age > MaxAge { 105 | return 0, fmt.Errorf("%w: age must be between %d and %d days", ErrInvalidRange, MinAge, MaxAge) 106 | } 107 | 108 | return age, nil 109 | } 110 | 111 | // ValidateTemperature validates temperature input 112 | func ValidateTemperature(tempStr string) (float64, error) { 113 | if tempStr == "" { 114 | return 0, fmt.Errorf("%w: temperature cannot be empty", ErrInvalidInput) 115 | } 116 | 117 | temp, err := strconv.ParseFloat(tempStr, 64) 118 | if err != nil { 119 | return 0, fmt.Errorf("%w: temperature must be a number", ErrInvalidFormat) 120 | } 121 | 122 | if temp < 0.0 || temp > 1.0 { 123 | return 0, fmt.Errorf("%w: temperature must be between 0.0 and 1.0", ErrInvalidRange) 124 | } 125 | 126 | return temp, nil 127 | } 128 | 129 | // ValidateLanguageCode validates language code input 130 | func ValidateLanguageCode(lang string) error { 131 | if lang == "" { 132 | return fmt.Errorf("%w: language code cannot be empty", ErrInvalidInput) 133 | } 134 | 135 | // Language codes should be 2-5 characters (e.g., "en", "en-US") 136 | if len(lang) < 2 || len(lang) > 5 { 137 | return fmt.Errorf("%w: language code must be 2-5 characters", ErrInvalidFormat) 138 | } 139 | 140 | // Basic format validation for language codes 141 | matched, _ := regexp.MatchString(`^[a-z]{2}(-[A-Z]{2})?$`, lang) 142 | if !matched { 143 | return fmt.Errorf("%w: invalid language code format", ErrInvalidFormat) 144 | } 145 | 146 | return nil 147 | } 148 | 149 | // ValidateFileSize validates uploaded file size 150 | func ValidateFileSize(size int64) error { 151 | if size <= 0 { 152 | return fmt.Errorf("%w: file size must be greater than 0", ErrInvalidInput) 153 | } 154 | 155 | if size > MaxFileSize { 156 | return fmt.Errorf("%w: file size must be less than %d bytes", ErrInputTooLong, MaxFileSize) 157 | } 158 | 159 | return nil 160 | } 161 | 
-------------------------------------------------------------------------------- /vectordb/chroma.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/amikos-tech/chroma-go" 8 | chromatypes "github.com/amikos-tech/chroma-go/types" 9 | "github.com/google/uuid" 10 | log "github.com/sirupsen/logrus" 11 | "golang.org/x/exp/maps" 12 | ) 13 | 14 | var ( 15 | ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1") 16 | ErrUnexpectedResponseLength = errors.New("unexpected length of response") 17 | ErrNewClient = errors.New("error creating collection") 18 | ErrAddDocument = errors.New("error adding document") 19 | ErrRemoveCollection = errors.New("error resetting collection") 20 | ErrUnsupportedOptions = errors.New("unsupported options") 21 | ) 22 | 23 | // Option is a function that configures an Options. 24 | type Option func(*Options) 25 | 26 | // Options is a set of options for similarity search and add documents. 27 | type Options struct { 28 | NameSpace string 29 | ScoreThreshold float32 30 | Filters any 31 | Embedder Embedder 32 | Deduplicater func(context.Context, Document) bool 33 | } 34 | 35 | // WithNameSpace returns an Option for setting the name space. 36 | func WithNameSpace(nameSpace string) Option { 37 | return func(o *Options) { 38 | o.NameSpace = nameSpace 39 | } 40 | } 41 | 42 | func WithScoreThreshold(scoreThreshold float32) Option { 43 | return func(o *Options) { 44 | o.ScoreThreshold = scoreThreshold 45 | } 46 | } 47 | 48 | // WithFilters searches can be limited based on metadata filters. Searches with metadata 49 | // filters retrieve exactly the number of nearest-neighbors results that match the filters. 
In 50 | // most cases the search latency will be lower than unfiltered searches 51 | // See https://docs.pinecone.io/docs/metadata-filtering 52 | func WithFilters(filters any) Option { 53 | return func(o *Options) { 54 | o.Filters = filters 55 | } 56 | } 57 | 58 | // WithEmbedder returns an Option for setting the embedder that could be used when 59 | // adding documents or doing similarity search (instead the embedder from the Store context) 60 | // this is useful when we are using multiple LLMs with single vectorstore. 61 | func WithEmbedder(embedder Embedder) Option { 62 | return func(o *Options) { 63 | o.Embedder = embedder 64 | } 65 | } 66 | 67 | // WithDeduplicater returns an Option for setting the deduplicater that could be used 68 | // when adding documents. This is useful to prevent wasting time on creating an embedding 69 | // when one already exists. 70 | func WithDeduplicater(fn func(ctx context.Context, doc Document) bool) Option { 71 | return func(o *Options) { 72 | o.Deduplicater = fn 73 | } 74 | } 75 | 76 | // chromaStore is a wrapper around the chromaGo API and client. 77 | type chromaStore struct { 78 | client *chromago.Client 79 | collection *chromago.Collection 80 | distanceFunction chromatypes.DistanceFunction 81 | chromaURL string 82 | openaiAPIKey string 83 | openaiOrganization string 84 | 85 | nameSpace string 86 | nameSpaceKey string 87 | embedder Embedder 88 | includes []chromatypes.QueryEnum 89 | } 90 | 91 | // New creates an active client connection to the collection in the Chroma server 92 | // and returns the `chromaStore` object needed by the other accessors. 
93 | func newChroma(key, url, namespace string) (chromaStore, error) { 94 | s := chromaStore{} 95 | chromaClient, err := chromago.NewClient(url) 96 | if err != nil { 97 | return s, err 98 | } 99 | if _, errHb := chromaClient.Heartbeat(context.Background()); errHb != nil { 100 | return s, errHb 101 | } 102 | s.client = chromaClient 103 | 104 | // var embeddingFunction chromatypes.EmbeddingFunction 105 | // embeddingFunction, err = openai.NewOpenAIEmbeddingFunction(key) 106 | // if err != nil { 107 | // return s, err 108 | // } 109 | embedder := NewEmbedder(NewOpenAIClient(key)) 110 | embeddingFunction := chromaGoEmbedder{Embedder: embedder} 111 | 112 | col, errCc := s.client.CreateCollection(context.Background(), namespace, map[string]any{}, true, 113 | embeddingFunction, "cosine") 114 | if errCc != nil { 115 | return s, fmt.Errorf("%w: %w", ErrNewClient, errCc) 116 | } 117 | 118 | s.collection = col 119 | 120 | return s, nil 121 | } 122 | 123 | // AddDocuments adds the text and metadata from the documents to the Chroma collection associated with 'chromaStore'. 124 | // and returns the ids of the added documents. 125 | func (s chromaStore) AddDocuments(ctx context.Context, 126 | docs []Document, 127 | options ...Option, 128 | ) ([]string, error) { 129 | opts := s.getOptions(options...) 
130 | if opts.Embedder != nil || opts.ScoreThreshold != 0 || opts.Filters != nil { 131 | return nil, ErrUnsupportedOptions 132 | } 133 | 134 | nameSpace := s.getNameSpace(opts) 135 | if nameSpace != "" && s.nameSpaceKey == "" { 136 | return nil, fmt.Errorf("%w: nameSpace without nameSpaceKey", ErrUnsupportedOptions) 137 | } 138 | 139 | ids := make([]string, len(docs)) 140 | texts := make([]string, len(docs)) 141 | metadatas := make([]map[string]any, len(docs)) 142 | for docIdx, doc := range docs { 143 | ids[docIdx] = uuid.New().String() // TODO (noodnik2): find & use something more meaningful 144 | texts[docIdx] = doc.PageContent 145 | mc := make(map[string]any, 0) 146 | maps.Copy(mc, doc.Metadata) 147 | metadatas[docIdx] = mc 148 | if nameSpace != "" { 149 | metadatas[docIdx][s.nameSpaceKey] = nameSpace 150 | } 151 | } 152 | 153 | if _, addErr := s.collection.Add(ctx, nil, metadatas, texts, ids); addErr != nil { 154 | log.WithField("metadatas", metadatas).WithField("ids", ids).Warn("Collection add failed:", addErr) 155 | return nil, fmt.Errorf("%w: %w", ErrAddDocument, addErr) 156 | } 157 | 158 | return ids, nil 159 | } 160 | 161 | func (s chromaStore) SimilaritySearch(ctx context.Context, query string, numDocuments int, 162 | options ...Option, 163 | ) ([]Document, error) { 164 | opts := s.getOptions(options...) 
165 | 166 | if opts.Embedder != nil { 167 | // embedder is not used by this method, so shouldn't ever be specified 168 | return nil, fmt.Errorf("%w: Embedder", ErrUnsupportedOptions) 169 | } 170 | 171 | scoreThreshold, stErr := s.getScoreThreshold(opts) 172 | if stErr != nil { 173 | return nil, stErr 174 | } 175 | 176 | filter := s.getNamespacedFilter(opts) 177 | qr, queryErr := s.collection.Query(ctx, []string{query}, int32(numDocuments), filter, nil, s.includes) 178 | if queryErr != nil { 179 | return nil, queryErr 180 | } 181 | 182 | if len(qr.Documents) != len(qr.Metadatas) || len(qr.Metadatas) != len(qr.Distances) { 183 | return nil, fmt.Errorf("%w: qr.Documents[%d], qr.Metadatas[%d], qr.Distances[%d]", 184 | ErrUnexpectedResponseLength, len(qr.Documents), len(qr.Metadatas), len(qr.Distances)) 185 | } 186 | var sDocs []Document 187 | for docsI := range qr.Documents { 188 | for docI := range qr.Documents[docsI] { 189 | if score := 1.0 - qr.Distances[docsI][docI]; score >= scoreThreshold { 190 | sDocs = append(sDocs, Document{ 191 | Metadata: qr.Metadatas[docsI][docI], 192 | PageContent: qr.Documents[docsI][docI], 193 | Score: score, 194 | }) 195 | } 196 | } 197 | } 198 | 199 | return sDocs, nil 200 | } 201 | 202 | func (s chromaStore) RemoveCollection() error { 203 | if s.client == nil || s.collection == nil { 204 | return fmt.Errorf("%w: no collection", ErrRemoveCollection) 205 | } 206 | _, errDc := s.client.DeleteCollection(context.Background(), s.collection.Name) 207 | if errDc != nil { 208 | return fmt.Errorf("%w(%s): %w", ErrRemoveCollection, s.collection.Name, errDc) 209 | } 210 | return nil 211 | } 212 | 213 | func (s chromaStore) getOptions(options ...Option) Options { 214 | opts := Options{} 215 | for _, opt := range options { 216 | opt(&opts) 217 | } 218 | return opts 219 | } 220 | 221 | func (s chromaStore) getScoreThreshold(opts Options) (float32, error) { 222 | if opts.ScoreThreshold < 0 || opts.ScoreThreshold > 1 { 223 | return 0, 
ErrInvalidScoreThreshold 224 | } 225 | return opts.ScoreThreshold, nil 226 | } 227 | 228 | func (s chromaStore) getNameSpace(opts Options) string { 229 | if opts.NameSpace != "" { 230 | return opts.NameSpace 231 | } 232 | return s.nameSpace 233 | } 234 | 235 | func (s chromaStore) getNamespacedFilter(opts Options) map[string]any { 236 | filter, _ := opts.Filters.(map[string]any) 237 | 238 | nameSpace := s.getNameSpace(opts) 239 | if nameSpace == "" || s.nameSpaceKey == "" { 240 | return filter 241 | } 242 | 243 | nameSpaceFilter := map[string]any{s.nameSpaceKey: nameSpace} 244 | if filter == nil { 245 | return nameSpaceFilter 246 | } 247 | 248 | return map[string]any{"$and": []map[string]any{nameSpaceFilter, filter}} 249 | } 250 | -------------------------------------------------------------------------------- /vectordb/embedder.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "context" 5 | 6 | chromatypes "github.com/amikos-tech/chroma-go/types" 7 | ) 8 | 9 | var _ chromatypes.EmbeddingFunction = chromaGoEmbedder{} // compile-time check 10 | 11 | // chromaGoEmbedder adapts an 'embeddings.Embedder' to a 'chroma_go.EmbeddingFunction'. 
12 | type chromaGoEmbedder struct { 13 | Embedder 14 | } 15 | 16 | func (e chromaGoEmbedder) EmbedDocuments(ctx context.Context, texts []string) ([]*chromatypes.Embedding, error) { 17 | _embeddings, err := e.Embedder.EmbedDocuments(ctx, texts) 18 | if err != nil { 19 | return nil, err 20 | } 21 | _chrmembeddings := make([]*chromatypes.Embedding, len(_embeddings)) 22 | for i, emb := range _embeddings { 23 | _chrmembeddings[i] = chromatypes.NewEmbeddingFromFloat32(emb) 24 | } 25 | 26 | return _chrmembeddings, nil 27 | } 28 | 29 | func (e chromaGoEmbedder) EmbedQuery(ctx context.Context, text string) (*chromatypes.Embedding, error) { 30 | _embedding, err := e.Embedder.EmbedQuery(ctx, text) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | return chromatypes.NewEmbeddingFromFloat32(_embedding), nil 36 | } 37 | 38 | func (e chromaGoEmbedder) EmbedRecords(ctx context.Context, records []*chromatypes.Record, force bool) error { 39 | return chromatypes.EmbedRecordsDefaultImpl(e, ctx, records, force) 40 | } 41 | -------------------------------------------------------------------------------- /vectordb/embedding.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | defaultEmbeddingModel = "text-embedding-3-large" 11 | ) 12 | 13 | // NewEmbedder creates a new Embedder from the given EmbedderClient, with 14 | // some options that affect how embedding will be done. 15 | func NewEmbedder(client EmbedderClient, opts ...EOption) *EmbedderImpl { 16 | e := &EmbedderImpl{ 17 | client: client, 18 | StripNewLines: defaultStripNewLines, 19 | BatchSize: defaultBatchSize, 20 | } 21 | 22 | for _, opt := range opts { 23 | opt(e) 24 | } 25 | return e 26 | } 27 | 28 | // Embedder is the interface for creating vector embeddings from texts. 29 | type Embedder interface { 30 | // EmbedDocuments returns a vector for each text. 
31 | EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) 32 | // EmbedQuery embeds a single text. 33 | EmbedQuery(ctx context.Context, text string) ([]float32, error) 34 | } 35 | 36 | // EmbedderClient is the interface LLM clients implement for embeddings. 37 | type EmbedderClient interface { 38 | CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) 39 | } 40 | 41 | // EmbedderClientFunc is an adapter to allow the use of ordinary functions as Embedder Clients. If 42 | // `f` is a function with the appropriate signature, `EmbedderClientFunc(f)` is an `EmbedderClient` 43 | // that calls `f`. 44 | type EmbedderClientFunc func(ctx context.Context, texts []string) ([][]float32, error) 45 | 46 | func (e EmbedderClientFunc) CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error) { 47 | return e(ctx, texts) 48 | } 49 | 50 | type EmbedderImpl struct { 51 | client EmbedderClient 52 | 53 | StripNewLines bool 54 | BatchSize int 55 | } 56 | 57 | const ( 58 | defaultBatchSize = 512 59 | defaultStripNewLines = true 60 | ) 61 | 62 | type EOption func(p *EmbedderImpl) 63 | 64 | // WithStripNewLines is an option for specifying the should it strip new lines. 65 | func WithStripNewLines(stripNewLines bool) EOption { 66 | return func(p *EmbedderImpl) { 67 | p.StripNewLines = stripNewLines 68 | } 69 | } 70 | 71 | // WithBatchSize is an option for specifying the batch size. 72 | func WithBatchSize(batchSize int) EOption { 73 | return func(p *EmbedderImpl) { 74 | p.BatchSize = batchSize 75 | } 76 | } 77 | 78 | // EmbedQuery embeds a single text. 
79 | func (ei *EmbedderImpl) EmbedQuery(ctx context.Context, text string) ([]float32, error) { 80 | if ei.StripNewLines { 81 | text = strings.ReplaceAll(text, "\n", " ") 82 | } 83 | 84 | emb, err := ei.client.CreateEmbedding(ctx, []string{text}) 85 | if err != nil { 86 | return nil, fmt.Errorf("error embedding query: %w", err) 87 | } 88 | 89 | return emb[0], nil 90 | } 91 | 92 | // EmbedDocuments creates one vector embedding for each of the texts. 93 | func (ei *EmbedderImpl) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { 94 | texts = MaybeRemoveNewLines(texts, ei.StripNewLines) 95 | return BatchedEmbed(ctx, ei.client, texts, ei.BatchSize) 96 | } 97 | 98 | func MaybeRemoveNewLines(texts []string, removeNewLines bool) []string { 99 | if !removeNewLines { 100 | return texts 101 | } 102 | 103 | for i := 0; i < len(texts); i++ { 104 | texts[i] = strings.ReplaceAll(texts[i], "\n", " ") 105 | } 106 | 107 | return texts 108 | } 109 | 110 | // BatchTexts splits strings by the length batchSize. 111 | func BatchTexts(texts []string, batchSize int) [][]string { 112 | batchedTexts := make([][]string, 0, len(texts)/batchSize+1) 113 | 114 | for i := 0; i < len(texts); i += batchSize { 115 | batchedTexts = append(batchedTexts, texts[i:minInt([]int{i + batchSize, len(texts)})]) 116 | } 117 | 118 | return batchedTexts 119 | } 120 | 121 | // BatchedEmbed creates embeddings for the given input texts, batching them 122 | // into batches of batchSize if needed. 123 | func BatchedEmbed(ctx context.Context, embedder EmbedderClient, texts []string, batchSize int) ([][]float32, error) { 124 | batchedTexts := BatchTexts(texts, batchSize) 125 | 126 | emb := make([][]float32, 0, len(texts)) 127 | for _, batch := range batchedTexts { 128 | curBatchEmbeddings, err := embedder.CreateEmbedding(ctx, batch) 129 | if err != nil { 130 | return nil, fmt.Errorf("error embedding batch: %w", err) 131 | } 132 | emb = append(emb, curBatchEmbeddings...) 
133 | } 134 | 135 | return emb, nil 136 | } 137 | 138 | // MinInt returns the minimum value in nums. 139 | // If nums is empty, it returns 0. 140 | func minInt(nums []int) int { 141 | var min int 142 | for idx := 0; idx < len(nums); idx++ { 143 | item := nums[idx] 144 | if idx == 0 { 145 | min = item 146 | continue 147 | } 148 | if item < min { 149 | min = item 150 | } 151 | } 152 | return min 153 | } 154 | -------------------------------------------------------------------------------- /vectordb/handler.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "context" 5 | "github.com/go-shiori/go-readability" 6 | log "github.com/sirupsen/logrus" 7 | "os" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | // VectorStore is the interface for saving and querying documents in the 13 | // form of vector embeddings. 14 | type VectorStore interface { 15 | AddDocuments(ctx context.Context, docs []Document, options ...Option) ([]string, error) 16 | SimilaritySearch(ctx context.Context, query string, numDocuments int, options ...Option) ([]Document, error) //nolint:lll 17 | } 18 | 19 | type Document struct { 20 | PageContent string 21 | Metadata map[string]any 22 | Score float32 23 | } 24 | 25 | // Retriever is a retriever for vector stores. 26 | type Retriever struct { 27 | CallbacksHandler interface{} 28 | v VectorStore 29 | numDocs int 30 | options []Option 31 | } 32 | 33 | var _ Retriever = Retriever{} 34 | 35 | // GetRelevantDocuments returns documents using the vector store. 36 | func (r Retriever) GetRelevantDocuments(ctx context.Context, query string) ([]Document, error) { 37 | docs, err := r.v.SimilaritySearch(ctx, query, r.numDocs, r.options...) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | return docs, nil 43 | } 44 | 45 | // ToRetriever takes a vector store and returns a retriever using the 46 | // vector store to retrieve documents. 
47 | func ToRetriever(vectorStore VectorStore, numDocuments int, options ...Option) Retriever { 48 | return Retriever{ 49 | v: vectorStore, 50 | numDocs: numDocuments, 51 | options: options, 52 | } 53 | } 54 | 55 | func newStore(ctx context.Context, sessionString string) (*chromaStore, error) { 56 | store, err := newChroma(os.Getenv("OPENAI_API_KEY"), os.Getenv("CHROMA_DB_URL"), sessionString+"nodeps") 57 | 58 | return &store, err 59 | } 60 | 61 | func saveToVectorDb(timeoutCtx context.Context, docs []Document, sessionString string) error { 62 | store, err := newStore(timeoutCtx, sessionString) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | for i := range docs { 68 | if len(docs[i].PageContent) == 0 { 69 | // remove the document from the list 70 | docs = append(docs[:i], docs[i+1:]...) 71 | } 72 | } 73 | 74 | if _, err := store.AddDocuments(timeoutCtx, docs); err != nil { 75 | log.Warn(err) 76 | return err 77 | } 78 | //log.Info("Added documents, count=", len(docs)) 79 | 80 | return nil 81 | } 82 | 83 | func DownloadWebsiteToVectorDB(ctx context.Context, url string, sessionString string) error { 84 | article, err := readability.FromURL(url, 10*time.Second) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | vectorLoader := NewText(strings.NewReader(article.TextContent)) 90 | splitter := NewTokenSplitter(WithSeparators([]string{"\n\n", "\n"})) 91 | splitter.ChunkOverlap = 100 92 | splitter.ChunkSize = 300 93 | docs, err := vectorLoader.LoadAndSplit(ctx, splitter) 94 | 95 | for i := range docs { 96 | docs[i].Metadata = map[string]interface{}{"url": url} 97 | } 98 | timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 99 | defer cancel() 100 | 101 | return saveToVectorDb(timeoutCtx, docs, sessionString) 102 | } 103 | 104 | func SearchVectorDB(ctx context.Context, input string, sessionString string) ([]Document, error) { 105 | amountOfResults := 3 106 | scoreThreshold := 0.4 107 | store, err := newStore(ctx, sessionString) 108 | if 
err != nil { 109 | return []Document{}, err 110 | } 111 | 112 | options := []Option{WithScoreThreshold(float32(scoreThreshold))} 113 | retriever := ToRetriever(store, amountOfResults, options...) 114 | 115 | return retriever.GetRelevantDocuments(context.Background(), input) 116 | } 117 | -------------------------------------------------------------------------------- /vectordb/openaiclient.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | log "github.com/sirupsen/logrus" 10 | "io" 11 | "net/http" 12 | "strings" 13 | ) 14 | 15 | const ( 16 | defaultBaseURL = "https://api.openai.com/v1" 17 | ) 18 | 19 | // ErrEmptyResponse is returned when the OpenAI API returns an empty response. 20 | var ErrEmptyResponse = errors.New("empty response") 21 | 22 | // OpenAIClient is a client for the OpenAI API. 23 | type OpenAIClient struct { 24 | token string 25 | baseURL string 26 | organization string 27 | httpClient Doer 28 | 29 | EmbeddingModel string 30 | } 31 | 32 | // Doer performs a HTTP request. 33 | type Doer interface { 34 | Do(req *http.Request) (*http.Response, error) 35 | } 36 | 37 | // New returns a new OpenAI client. 38 | func NewOpenAIClient(token string) *OpenAIClient { 39 | return &OpenAIClient{ 40 | token: token, 41 | EmbeddingModel: defaultEmbeddingModel, 42 | baseURL: strings.TrimSuffix(defaultBaseURL, "/"), 43 | httpClient: http.DefaultClient, 44 | } 45 | } 46 | 47 | // Completion is a completion. 48 | type Completion struct { 49 | Text string `json:"text"` 50 | } 51 | 52 | // EmbeddingRequest is a request to create an embedding. 
53 | type EmbeddingRequest struct { 54 | Model string `json:"model"` 55 | Input []string `json:"input"` 56 | } 57 | 58 | type embeddingResponsePayload struct { 59 | Object string `json:"object"` 60 | Data []struct { 61 | Object string `json:"object"` 62 | Embedding []float32 `json:"embedding"` 63 | Index int `json:"index"` 64 | } `json:"data"` 65 | Model string `json:"model"` 66 | Usage struct { 67 | PromptTokens int `json:"prompt_tokens"` 68 | TotalTokens int `json:"total_tokens"` 69 | } `json:"usage"` 70 | } 71 | 72 | type errorMessage struct { 73 | Error struct { 74 | Message string `json:"message"` 75 | Type string `json:"type"` 76 | } `json:"error"` 77 | } 78 | 79 | // CreateEmbedding creates embeddings. 80 | func (c *OpenAIClient) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) { 81 | resp, err := c.createEmbedding(ctx, &EmbeddingRequest{ 82 | Input: inputTexts, 83 | Model: defaultEmbeddingModel, 84 | }) 85 | if err != nil { 86 | return nil, err 87 | } 88 | 89 | if len(resp.Data) == 0 { 90 | return nil, ErrEmptyResponse 91 | } 92 | 93 | embeddings := make([][]float32, 0) 94 | for i := 0; i < len(resp.Data); i++ { 95 | embeddings = append(embeddings, resp.Data[i].Embedding) 96 | } 97 | 98 | return embeddings, nil 99 | } 100 | 101 | func (c *OpenAIClient) setHeaders(req *http.Request) { 102 | req.Header.Set("Content-Type", "application/json") 103 | req.Header.Set("Authorization", "Bearer "+c.token) 104 | if c.organization != "" { 105 | req.Header.Set("OpenAI-Organization", c.organization) 106 | } 107 | } 108 | 109 | func (c *OpenAIClient) buildURL(suffix string, model string) string { 110 | // open ai implement: 111 | return fmt.Sprintf("%s%s", c.baseURL, suffix) 112 | } 113 | 114 | // nolint:lll 115 | func (c *OpenAIClient) createEmbedding(ctx context.Context, payload *EmbeddingRequest) (*embeddingResponsePayload, error) { 116 | if c.baseURL == "" { 117 | c.baseURL = defaultBaseURL 118 | } 119 | payload.Model = c.EmbeddingModel 
120 | 121 | payloadBytes, err := json.Marshal(payload) 122 | if err != nil { 123 | return nil, fmt.Errorf("marshal payload: %w", err) 124 | } 125 | 126 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.buildURL("/embeddings", c.EmbeddingModel), bytes.NewReader(payloadBytes)) 127 | if err != nil { 128 | return nil, fmt.Errorf("create request: %w", err) 129 | } 130 | c.setHeaders(req) 131 | 132 | r, err := c.httpClient.Do(req) 133 | if err != nil { 134 | return nil, fmt.Errorf("send request: %w", err) 135 | } 136 | defer r.Body.Close() 137 | 138 | if r.StatusCode != http.StatusOK { 139 | msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode) 140 | 141 | // No need to check the error here: if it fails, we'll just return the 142 | // status code. 143 | var errResp errorMessage 144 | 145 | // read all from r.Body 146 | body, err := io.ReadAll(r.Body) 147 | if err != nil { 148 | return nil, fmt.Errorf("read response body: %w", err) 149 | } 150 | 151 | if err := json.Unmarshal(body, &errResp); err != nil { 152 | log.Warn(string(body)) 153 | return nil, errors.New(msg) // nolint:goerr113 154 | } 155 | 156 | return nil, fmt.Errorf("%s: %s", msg, errResp.Error.Message) // nolint:goerr113 157 | } 158 | 159 | var response embeddingResponsePayload 160 | 161 | if err := json.NewDecoder(r.Body).Decode(&response); err != nil { 162 | return nil, fmt.Errorf("decode response: %w", err) 163 | } 164 | 165 | return &response, nil 166 | } 167 | -------------------------------------------------------------------------------- /vectordb/split_documents.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "strings" 7 | ) 8 | 9 | // ErrMismatchMetadatasAndText is returned when the number of texts and metadatas 10 | // given to CreateDocuments does not match. The function will not error if the 11 | // length of the metadatas slice is zero. 
12 | var ErrMismatchMetadatasAndText = errors.New("number of texts and metadatas does not match") 13 | 14 | // SplitDocuments splits documents using a textsplitter. 15 | func SplitDocuments(textSplitter TextSplitter, documents []Document) ([]Document, error) { 16 | texts := make([]string, 0) 17 | metadatas := make([]map[string]any, 0) 18 | for _, document := range documents { 19 | texts = append(texts, document.PageContent) 20 | metadatas = append(metadatas, document.Metadata) 21 | } 22 | 23 | return CreateDocuments(textSplitter, texts, metadatas) 24 | } 25 | 26 | // CreateDocuments creates documents from texts and metadatas with a text splitter. If 27 | // the length of the metadatas is zero, the result documents will contain no metadata. 28 | // Otherwise, the numbers of texts and metadatas must match. 29 | func CreateDocuments(textSplitter TextSplitter, texts []string, metadatas []map[string]any) ([]Document, error) { 30 | if len(metadatas) == 0 { 31 | metadatas = make([]map[string]any, len(texts)) 32 | } 33 | 34 | if len(texts) != len(metadatas) { 35 | return nil, ErrMismatchMetadatasAndText 36 | } 37 | 38 | documents := make([]Document, 0) 39 | 40 | for i := 0; i < len(texts); i++ { 41 | chunks, err := textSplitter.SplitText(texts[i]) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | for _, chunk := range chunks { 47 | // Copy the document metadata 48 | curMetadata := make(map[string]any, len(metadatas[i])) 49 | for key, value := range metadatas[i] { 50 | curMetadata[key] = value 51 | } 52 | 53 | documents = append(documents, Document{ 54 | PageContent: chunk, 55 | Metadata: curMetadata, 56 | }) 57 | } 58 | } 59 | 60 | return documents, nil 61 | } 62 | 63 | // joinDocs comines two documents with the separator used to split them. 64 | func joinDocs(docs []string, separator string) string { 65 | return strings.TrimSpace(strings.Join(docs, separator)) 66 | } 67 | 68 | // mergeSplits merges smaller splits into splits that are closer to the chunkSize. 
69 | func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int, lenFunc func(string) int) []string { //nolint:cyclop 70 | docs := make([]string, 0) 71 | currentDoc := make([]string, 0) 72 | total := 0 73 | 74 | for _, split := range splits { 75 | totalWithSplit := total + lenFunc(split) 76 | if len(currentDoc) != 0 { 77 | totalWithSplit += lenFunc(separator) 78 | } 79 | 80 | maybePrintWarning(total, chunkSize) 81 | if totalWithSplit > chunkSize && len(currentDoc) > 0 { 82 | doc := joinDocs(currentDoc, separator) 83 | if doc != "" { 84 | docs = append(docs, doc) 85 | } 86 | 87 | for shouldPop(chunkOverlap, chunkSize, total, lenFunc(split), lenFunc(separator), len(currentDoc)) { 88 | total -= lenFunc(currentDoc[0]) //nolint:gosec 89 | if len(currentDoc) > 1 { 90 | total -= lenFunc(separator) 91 | } 92 | currentDoc = currentDoc[1:] //nolint:gosec 93 | } 94 | } 95 | 96 | currentDoc = append(currentDoc, split) 97 | total += lenFunc(split) 98 | if len(currentDoc) > 1 { 99 | total += lenFunc(separator) 100 | } 101 | } 102 | 103 | doc := joinDocs(currentDoc, separator) 104 | if doc != "" { 105 | docs = append(docs, doc) 106 | } 107 | 108 | return docs 109 | } 110 | 111 | func maybePrintWarning(total, chunkSize int) { 112 | if total > chunkSize { 113 | log.Printf( 114 | "[WARN] created a chunk with size of %v, which is longer then the specified %v\n", 115 | total, 116 | chunkSize, 117 | ) 118 | } 119 | } 120 | 121 | // Keep popping if: 122 | // - the chunk is larger than the chunk overlap 123 | // - or if there are any chunks and the length is long 124 | func shouldPop(chunkOverlap, chunkSize, total, splitLen, separatorLen, currentDocLen int) bool { 125 | docsNeededToAddSep := 2 126 | if currentDocLen < docsNeededToAddSep { 127 | separatorLen = 0 128 | } 129 | 130 | return currentDocLen > 0 && (total > chunkOverlap || (total+splitLen+separatorLen > chunkSize && total > 0)) 131 | } 132 | 
-------------------------------------------------------------------------------- /vectordb/text.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | ) 8 | 9 | // Text loads text data from an io.Reader. 10 | type Text struct { 11 | r io.Reader 12 | } 13 | 14 | // Loader is the interface for loading and splitting documents from a source. 15 | type Loader interface { 16 | // Load loads from a source and returns documents. 17 | Load(ctx context.Context) ([]Document, error) 18 | // LoadAndSplit loads from a source and splits the documents using a text splitter. 19 | LoadAndSplit(ctx context.Context, splitter TextSplitter) ([]Document, error) 20 | } 21 | 22 | var _ Loader = Text{} 23 | 24 | // NewText creates a new text loader with an io.Reader. 25 | func NewText(r io.Reader) Text { 26 | return Text{ 27 | r: r, 28 | } 29 | } 30 | 31 | // Load reads from the io.Reader and returns a single document with the data. 32 | func (l Text) Load(_ context.Context) ([]Document, error) { 33 | buf := new(bytes.Buffer) 34 | _, err := io.Copy(buf, l.r) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | return []Document{ 40 | { 41 | PageContent: buf.String(), 42 | Metadata: map[string]any{}, 43 | }, 44 | }, nil 45 | } 46 | 47 | // LoadAndSplit reads text data from the io.Reader and splits it into multiple 48 | // documents using a text splitter. 49 | func (l Text) LoadAndSplit(ctx context.Context, splitter TextSplitter) ([]Document, error) { 50 | docs, err := l.Load(ctx) 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | return SplitDocuments(splitter, docs) 56 | } 57 | -------------------------------------------------------------------------------- /vectordb/text_spliter.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | // TextSplitter is the standard interface for splitting texts. 
4 | type TextSplitter interface { 5 | SplitText(text string) ([]string, error) 6 | } 7 | -------------------------------------------------------------------------------- /vectordb/token_splitter.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pkoukk/tiktoken-go" 7 | ) 8 | 9 | const ( 10 | // nolint:gosec 11 | _defaultTokenModelName = "gpt-4" 12 | _defaultTokenEncoding = "cl100k_base" 13 | _defaultTokenChunkSize = 512 14 | _defaultTokenChunkOverlap = 100 15 | ) 16 | 17 | // TokenSplitter is a text splitter that will split texts by tokens. 18 | type TokenSplitter struct { 19 | ChunkSize int 20 | ChunkOverlap int 21 | ModelName string 22 | EncodingName string 23 | AllowedSpecial []string 24 | DisallowedSpecial []string 25 | } 26 | 27 | func NewTokenSplitter(opts ...TsOption) TokenSplitter { 28 | options := DefaultTsOptions() 29 | for _, o := range opts { 30 | o(&options) 31 | } 32 | 33 | s := TokenSplitter{ 34 | ChunkSize: options.ChunkSize, 35 | ChunkOverlap: options.ChunkOverlap, 36 | ModelName: options.ModelName, 37 | EncodingName: options.EncodingName, 38 | AllowedSpecial: options.AllowedSpecial, 39 | DisallowedSpecial: options.DisallowedSpecial, 40 | } 41 | 42 | return s 43 | } 44 | 45 | // SplitText splits a text into multiple text. 
46 | func (s TokenSplitter) SplitText(text string) ([]string, error) { 47 | // Get the tokenizer 48 | var tk *tiktoken.Tiktoken 49 | var err error 50 | if s.EncodingName != "" { 51 | tk, err = tiktoken.GetEncoding(s.EncodingName) 52 | } else { 53 | tk, err = tiktoken.EncodingForModel(s.ModelName) 54 | } 55 | if err != nil { 56 | return nil, fmt.Errorf("tiktoken.GetEncoding: %w", err) 57 | } 58 | texts := s.splitText(text, tk) 59 | 60 | return texts, nil 61 | } 62 | 63 | func (s TokenSplitter) splitText(text string, tk *tiktoken.Tiktoken) []string { 64 | splits := make([]string, 0) 65 | inputIDs := tk.Encode(text, s.AllowedSpecial, s.DisallowedSpecial) 66 | 67 | startIdx := 0 68 | curIdx := len(inputIDs) 69 | if startIdx+s.ChunkSize < curIdx { 70 | curIdx = startIdx + s.ChunkSize 71 | } 72 | for startIdx < len(inputIDs) { 73 | chunkIDs := inputIDs[startIdx:curIdx] 74 | splits = append(splits, tk.Decode(chunkIDs)) 75 | startIdx += s.ChunkSize - s.ChunkOverlap 76 | curIdx = startIdx + s.ChunkSize 77 | if curIdx > len(inputIDs) { 78 | curIdx = len(inputIDs) 79 | } 80 | } 81 | return splits 82 | } 83 | -------------------------------------------------------------------------------- /vectordb/tsoptions.go: -------------------------------------------------------------------------------- 1 | package vectordb 2 | 3 | import "unicode/utf8" 4 | 5 | // TsOptions is a struct that contains options for a text splitter. 6 | type TsOptions struct { 7 | ChunkSize int 8 | ChunkOverlap int 9 | Separators []string 10 | LenFunc func(string) int 11 | ModelName string 12 | EncodingName string 13 | AllowedSpecial []string 14 | DisallowedSpecial []string 15 | SecondSplitter TextSplitter 16 | CodeBlocks bool 17 | ReferenceLinks bool 18 | } 19 | 20 | // DefaultTsOptions returns the default options for all text splitter. 
21 | func DefaultTsOptions() TsOptions { 22 | return TsOptions{ 23 | ChunkSize: _defaultTokenChunkSize, 24 | ChunkOverlap: _defaultTokenChunkOverlap, 25 | Separators: []string{"\n\n", "\n", " ", ""}, 26 | LenFunc: utf8.RuneCountInString, 27 | 28 | ModelName: _defaultTokenModelName, 29 | EncodingName: _defaultTokenEncoding, 30 | AllowedSpecial: []string{}, 31 | DisallowedSpecial: []string{"all"}, 32 | } 33 | } 34 | 35 | // TsOption is a function that can be used to set options for a text splitter. 36 | type TsOption func(*TsOptions) 37 | 38 | // WithChunkSize sets the chunk size for a text splitter. 39 | func WithChunkSize(chunkSize int) TsOption { 40 | return func(o *TsOptions) { 41 | o.ChunkSize = chunkSize 42 | } 43 | } 44 | 45 | // WithChunkOverlap sets the chunk overlap for a text splitter. 46 | func WithChunkOverlap(chunkOverlap int) TsOption { 47 | return func(o *TsOptions) { 48 | o.ChunkOverlap = chunkOverlap 49 | } 50 | } 51 | 52 | // WithSeparators sets the separators for a text splitter. 53 | func WithSeparators(separators []string) TsOption { 54 | return func(o *TsOptions) { 55 | o.Separators = separators 56 | } 57 | } 58 | 59 | // WithLenFunc sets the lenfunc for a text splitter. 60 | func WithLenFunc(lenFunc func(string) int) TsOption { 61 | return func(o *TsOptions) { 62 | o.LenFunc = lenFunc 63 | } 64 | } 65 | 66 | // WithModelName sets the model name for a text splitter. 67 | func WithModelName(modelName string) TsOption { 68 | return func(o *TsOptions) { 69 | o.ModelName = modelName 70 | } 71 | } 72 | 73 | // WithEncodingName sets the encoding name for a text splitter. 74 | func WithEncodingName(encodingName string) TsOption { 75 | return func(o *TsOptions) { 76 | o.EncodingName = encodingName 77 | } 78 | } 79 | 80 | // WithAllowedSpecial sets the allowed special tokens for a text splitter. 
81 | func WithAllowedSpecial(allowedSpecial []string) TsOption { 82 | return func(o *TsOptions) { 83 | o.AllowedSpecial = allowedSpecial 84 | } 85 | } 86 | 87 | // WithDisallowedSpecial sets the disallowed special tokens for a text splitter. 88 | func WithDisallowedSpecial(disallowedSpecial []string) TsOption { 89 | return func(o *TsOptions) { 90 | o.DisallowedSpecial = disallowedSpecial 91 | } 92 | } 93 | 94 | // WithSecondSplitter sets the second splitter for a text splitter. 95 | func WithSecondSplitter(secondSplitter TextSplitter) TsOption { 96 | return func(o *TsOptions) { 97 | o.SecondSplitter = secondSplitter 98 | } 99 | } 100 | 101 | // WithCodeBlocks sets whether indented and fenced codeblocks should be included 102 | // in the output. 103 | func WithCodeBlocks(renderCode bool) TsOption { 104 | return func(o *TsOptions) { 105 | o.CodeBlocks = renderCode 106 | } 107 | } 108 | 109 | // WithReferenceLinks sets whether reference links (i.e. `[text][label]`) 110 | // should be patched with the url and title from their definition. Note that 111 | // by default reference definitions are dropped from the output. 112 | // 113 | // Caution: this also affects how other inline elements are rendered, e.g. all 114 | // emphasis will use `*` even when another character (e.g. `_`) was used in the 115 | // input. 
116 | func WithReferenceLinks(referenceLinks bool) TsOption { 117 | return func(o *TsOptions) { 118 | o.ReferenceLinks = referenceLinks 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /voice.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "os" 10 | "os/exec" 11 | "strings" 12 | 13 | "github.com/meinside/openai-go" 14 | "github.com/tectiv3/chatgpt-bot/opus" 15 | "github.com/tectiv3/go-lame" 16 | tele "gopkg.in/telebot.v3" 17 | ) 18 | 19 | func convertToWav(r io.Reader) ([]byte, error) { 20 | output := new(bytes.Buffer) 21 | wavWriter, err := newWavWriter(output, 48000, 1, 16) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | s, err := opus.NewStream(r) 27 | if err != nil { 28 | return nil, err 29 | } 30 | defer s.Close() 31 | 32 | pcmbuf := make([]float32, 16384) 33 | for { 34 | n, err := s.ReadFloat32(pcmbuf) 35 | if err == io.EOF { 36 | break 37 | } else if err != nil { 38 | Log.Fatal(err) 39 | } 40 | pcm := pcmbuf[:n*1] 41 | 42 | err = wavWriter.WriteSamples(pcm) 43 | if err != nil { 44 | return nil, err 45 | } 46 | } 47 | 48 | return output.Bytes(), err 49 | } 50 | 51 | // Helper function to create a new WAV writer 52 | func newWavWriter(w io.Writer, sampleRate int, numChannels int, bitsPerSample int) (*wavWriter, error) { 53 | var header wavHeader 54 | 55 | // Set header values 56 | header.RIFFID = [4]byte{'R', 'I', 'F', 'F'} 57 | header.WAVEID = [4]byte{'W', 'A', 'V', 'E'} 58 | header.FMTID = [4]byte{'f', 'm', 't', ' '} 59 | header.Subchunk1Size = 16 60 | header.AudioFormat = 1 61 | header.NumChannels = uint16(numChannels) 62 | header.SampleRate = uint32(sampleRate) 63 | header.BitsPerSample = uint16(bitsPerSample) 64 | header.ByteRate = uint32(sampleRate * numChannels * bitsPerSample / 8) 65 | header.BlockAlign = uint16(numChannels * 
bitsPerSample / 8) 66 | header.DataID = [4]byte{'d', 'a', 't', 'a'} 67 | 68 | // Write header 69 | err := binary.Write(w, binary.LittleEndian, &header) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | return &wavWriter{w: w}, nil 75 | } 76 | 77 | // WriteSamples Write samples to the WAV file 78 | func (ww *wavWriter) WriteSamples(samples []float32) error { 79 | // Convert float32 samples to int16 samples 80 | int16Samples := make([]int16, len(samples)) 81 | for i, s := range samples { 82 | if s > 1.0 { 83 | s = 1.0 84 | } else if s < -1.0 { 85 | s = -1.0 86 | } 87 | int16Samples[i] = int16(s * 32767) 88 | } 89 | // Write int16 samples to the WAV file 90 | return binary.Write(ww.w, binary.LittleEndian, &int16Samples) 91 | } 92 | 93 | func wavToMp3(wav []byte) []byte { 94 | reader := bytes.NewReader(wav) 95 | wavHdr, err := lame.ReadWavHeader(reader) 96 | if err != nil { 97 | Log.Warn("not a wav file", "error=", err.Error()) 98 | return nil 99 | } 100 | output := new(bytes.Buffer) 101 | wr, _ := lame.NewWriter(output) 102 | defer wr.Close() 103 | 104 | wr.EncodeOptions = wavHdr.ToEncodeOptions() 105 | if _, err := io.Copy(wr, reader); err != nil { 106 | return nil 107 | } 108 | 109 | return output.Bytes() 110 | } 111 | 112 | func (s *Server) handleVoice(c tele.Context) { 113 | if c.Message().Voice.FileSize == 0 { 114 | return 115 | } 116 | audioFile := c.Message().Voice.File 117 | var reader io.ReadCloser 118 | var err error 119 | 120 | if s.conf.TelegramServerURL != "" { 121 | f, err := c.Bot().FileByID(audioFile.FileID) 122 | if err != nil { 123 | Log.Warn("Error getting file ID", "error=", err) 124 | return 125 | } 126 | // start reader from f.FilePath 127 | reader, err = os.Open(f.FilePath) 128 | if err != nil { 129 | Log.Warn("Error opening file", "error=", err) 130 | return 131 | } 132 | } else { 133 | reader, err = c.Bot().File(&audioFile) 134 | if err != nil { 135 | Log.Warn("Error getting file content", "error=", err) 136 | return 137 | } 138 | } 
139 | defer reader.Close() 140 | 141 | //body, err := ioutil.ReadAll(reader) 142 | //if err != nil { 143 | // fmt.Println("Error reading file content:", err) 144 | // return nil 145 | //} 146 | 147 | wav, err := convertToWav(reader) 148 | if err != nil { 149 | Log.Warn("failed to convert to wav", "error=", err) 150 | return 151 | } 152 | mp3 := wavToMp3(wav) 153 | if mp3 == nil { 154 | Log.Warn("failed to convert to mp3") 155 | return 156 | } 157 | audio := openai.NewFileParamFromBytes(mp3) 158 | transcript, err := s.openAI.CreateTranscription(audio, "whisper-1", nil) 159 | if err != nil { 160 | Log.Warn("failed to create transcription", "error=", err) 161 | return 162 | } 163 | if transcript.JSON == nil && 164 | transcript.Text == nil && 165 | transcript.SRT == nil && 166 | transcript.VerboseJSON == nil && 167 | transcript.VTT == nil { 168 | Log.Warn("There was no returned data") 169 | 170 | return 171 | } 172 | 173 | if strings.HasPrefix(strings.ToLower(*transcript.Text), "reset") { 174 | chat := s.getChat(c.Chat(), c.Sender()) 175 | s.deleteHistory(chat.ID) 176 | 177 | v := &tele.Voice{File: tele.FromDisk("erased.ogg")} 178 | _ = c.Send(v) 179 | 180 | return 181 | } 182 | 183 | s.complete(c, *transcript.Text, false) 184 | chat := s.getChat(c.Chat(), c.Sender()) 185 | sentMessage := chat.getSentMessage(c) 186 | response := sentMessage.Text 187 | 188 | Log.WithField("user", c.Sender().Username).Info("Response length=", len(response)) 189 | 190 | if len(response) == 0 { 191 | return 192 | } 193 | 194 | s.sendAudio(c, response) 195 | } 196 | 197 | func (s *Server) sendAudio(c tele.Context, text string) { 198 | url := "https://api.openai.com/v1/audio/speech" 199 | body := map[string]string{ 200 | "model": "tts-1", 201 | "input": text, 202 | "voice": "alloy", 203 | "response_format": "opus", 204 | "speed": "1", 205 | } 206 | jsonStr, _ := json.Marshal(body) 207 | req, _ := http.NewRequest("POST", url, bytes.NewBuffer(jsonStr)) 208 | req.Header.Set("Authorization", 
"Bearer "+s.conf.OpenAIAPIKey) 209 | req.Header.Set("Content-Type", "application/json") 210 | 211 | client := &http.Client{} 212 | resp, err := client.Do(req) 213 | if err != nil { 214 | Log.Warn("failed to send request", "error=", err) 215 | } 216 | defer resp.Body.Close() 217 | 218 | out, err := os.CreateTemp("", "chatbot") 219 | if err != nil { 220 | Log.Warn("failed to create temp file", "error=", err) 221 | return 222 | } 223 | 224 | _, err = io.Copy(out, resp.Body) 225 | if err := out.Close(); err != nil { 226 | return 227 | } 228 | 229 | v := &tele.Voice{File: tele.FromDisk(out.Name())} 230 | defer os.Remove(out.Name()) 231 | _ = c.Send(v) 232 | } 233 | 234 | func (s *Server) textToSpeech(c tele.Context, text, lang string) error { 235 | switch lang { 236 | case "en": 237 | case "fr": 238 | case "ru": 239 | break 240 | default: 241 | s.sendAudio(c, text) 242 | return nil 243 | } 244 | if len(s.conf.PiperDir) == 0 { 245 | return c.Send("PiperDir is not set") 246 | } 247 | cmd := exec.Command(s.conf.PiperDir+"piper", "-m", s.conf.PiperDir+lang+".onnx", "-f", "-") 248 | 249 | stdin, _ := cmd.StdinPipe() 250 | stdout, _ := cmd.StdoutPipe() 251 | stderr, _ := cmd.StderrPipe() 252 | go io.Copy(os.Stderr, stderr) 253 | 254 | out, err := os.CreateTemp("", "piper.wav") 255 | if err != nil { 256 | return c.Send("Error creating temp file: " + err.Error()) 257 | } 258 | defer out.Close() 259 | 260 | if err := cmd.Start(); err != nil { 261 | return c.Send("Error starting command: " + err.Error()) 262 | } 263 | if _, err := stdin.Write([]byte(text)); err != nil { 264 | return c.Send("Error writing to command: " + err.Error()) 265 | } 266 | stdin.Close() 267 | _, err = io.Copy(out, stdout) 268 | if err != nil { 269 | return c.Send("Error reading from the command: " + err.Error()) 270 | } 271 | if err := cmd.Wait(); err != nil { 272 | return c.Send("Error waiting for command: " + err.Error()) 273 | } 274 | 275 | Log.WithField("user", c.Sender().Username).Info("TTS done") 276 
| v := &tele.Voice{File: tele.FromDisk(out.Name())} 277 | defer os.Remove(out.Name()) 278 | 279 | return c.Send(v) 280 | } 281 | --------------------------------------------------------------------------------