├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── cmd ├── mqtt │ └── main.go ├── opus_example │ └── main.go └── server │ ├── config.go │ └── main.go ├── config ├── config.json ├── models │ └── vad │ │ └── silero_vad.onnx └── mqtt_config.json ├── doc ├── config.md ├── delay_test.md ├── docker.md ├── user_config.md └── websocket_meter.md ├── docker ├── Dockerfile ├── Dockerfile_build └── lib │ └── onnxruntime-linux-x64-1.21.0.tgz ├── go.mod ├── go.sum ├── internal ├── app │ ├── mqtt_server │ │ ├── auth_hook.go │ │ ├── device_hook.go │ │ └── mqtt_server.go │ └── server │ │ ├── auth │ │ └── auth.go │ │ ├── common │ │ ├── chat.go │ │ ├── common.go │ │ └── tts │ │ │ └── tts.go │ │ ├── mqtt_udp │ │ ├── mqtt_server.go │ │ └── udp_server.go │ │ ├── server.go │ │ └── websocket │ │ ├── asr.go │ │ ├── listen.go │ │ ├── llm.go │ │ ├── types.go │ │ └── websocket_server.go ├── config │ └── config.go ├── data │ ├── audio │ │ └── audio.go │ ├── client │ │ ├── client.go │ │ ├── mqtt.go │ │ ├── statistics.go │ │ └── udp.go │ ├── msg │ │ └── message_types.go │ └── udp │ │ └── udp.go ├── domain │ ├── asr │ │ ├── adapter.go │ │ ├── base.go │ │ ├── funasr │ │ │ ├── config.go │ │ │ ├── example │ │ │ │ └── streaming_example.go │ │ │ └── funasr.go │ │ └── types │ │ │ └── types.go │ ├── audio │ │ └── audio_handler.go │ ├── llm │ │ ├── base.go │ │ ├── common │ │ │ └── types.go │ │ ├── llm.go │ │ ├── llm_sentence.go │ │ ├── memory │ │ │ ├── llm_memory.go │ │ │ └── types.go │ │ ├── ollama │ │ │ ├── ollama.go │ │ │ └── ollama_test.go │ │ ├── openai │ │ │ ├── openai.go │ │ │ └── openai_test.go │ │ └── test │ │ │ ├── llm_test.go │ │ │ └── splite_content.go │ ├── message_types.go │ ├── tts │ │ ├── base.go │ │ ├── common │ │ │ ├── audio_utils.go │ │ │ ├── audio_utils_test.go │ │ │ └── test.wav │ │ ├── cosyvoice │ │ │ ├── cosyvoice.go │ │ │ └── cosyvoice_test.go │ │ ├── doubao │ │ │ ├── doubao.go │ │ │ ├── doubao_test.go │ │ │ ├── doubao_ws.go │ │ │ └── doubao_ws_test.go │ │ ├── edge │ │ 
│ ├── edge.go │ │ │ └── edge_test.go │ │ └── xiaozhi │ │ │ ├── xiaozhi.go │ │ │ └── xiaozhi_test.go │ ├── user_config │ │ └── userconfig.go │ └── vad │ │ ├── test │ │ ├── silero_vad.onnx │ │ ├── vad │ │ ├── vad.go │ │ └── wav2vad.go │ │ ├── vad.go │ │ └── vad_test.go └── util │ ├── buffer.go │ ├── encryption.go │ └── workqueue │ ├── parallelizer.go │ ├── parallelizer_test.go │ └── parallelizer_test.go.bak ├── logger ├── db_log.go └── logger.go └── test ├── mem0 ├── __pycache__ │ └── mem0.cpython-312.pyc └── memory.py ├── mqtt_udp ├── README.md ├── audio_utils ├── audio_utils.go ├── go.sum ├── main ├── main.go ├── mqtt.go ├── mqtt.py ├── ota.go ├── test_24000.wav └── udp.go ├── py_test_audio ├── dec_opus.py └── main.py ├── test_audio ├── audio_utils.go └── main.go ├── test_opus └── decode_opus.c └── websocket_client ├── audio_utils.go ├── test.mp3 ├── test_mp3.go ├── xiaozhi_websocket_client.go └── xiaozhi_ws_client_multi.go /.dockerignore: -------------------------------------------------------------------------------- 1 | **/.git 2 | **/node_modules 3 | **/*.log 4 | **/bin 5 | **/vendor 6 | **/dist 7 | **/test 8 | **/.vscode 9 | **/tmp 10 | **/*.exe 11 | **/*.dll 12 | **/*.bat 13 | **/Thumbs.db 14 | **/Desktop.ini 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.exe 2 | *.log 3 | *.pem 4 | *.key 5 | config_shijingbo.json 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 hackers365 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # xiaozhi-esp32-server-golang 2 | # 项目简介 3 | 此项目是虾哥小智 AI 后端的 golang 版本,实现了 asr 输入、llm 输出、tts 输入输出的全流式处理,依托 golang 的高性能可以支撑高并发请求 4 | 5 | # 特性 6 | - 全流式数据处理 7 | - 实现了小智ai websocket 协议 8 | - 实现了mqtt和udp服务器 9 | - 外部资源连接池复用,减少耗时 10 | 11 | # 快速开始 12 | [docker 快速开始](doc/docker.md) 13 | 14 | # 延迟测试 15 | [延迟测试结果](doc/delay_test.md) 16 | # 模块简介 17 | ## VAD 18 | 基于 silero vad 实现了语音活动检测 19 | ## ASR 20 | 对接了funasr的websocket服务接口,部署地址 https://github.com/modelscope/FunASR/blob/main/runtime/docs/SDK_advanced_guide_online_zh.md 21 | ## LLM 22 | 实现了 openai 格式兼容的接口 23 | ## TTS 24 | 已实现 25 | - doubao websocket 26 | - edgetts 27 | - cosyvoice 28 | 29 | # 源码编译运行 30 | 31 | ### 安装onnxruntime依赖 32 | ### 部署funasr的服务 33 | https://github.com/modelscope/FunASR/blob/main/runtime/docs/SDK_advanced_guide_online_zh.md 34 | ### 编译mqtt服务 35 | go build cmd/mqtt/*.go 36 | ### 修改配置文件 config/config.json(详见 doc/config.md) 37 | ### 编译xiaozhi-esp32-server-golang(go build cmd/server/*.go) 38 | ### 启动 39 | 40 | # TODO 41 | ### 用户认证 42 | ### 接入更多云厂商asr服务 43 | ### 用户界面 44 | ### llm记忆体 45
| ### docker化部署 46 | 47 | ![进群二维码_min](https://github.com/user-attachments/assets/ecdb7abb-d723-4ada-969e-6082f693fc9f) 48 | 49 | 50 | 51 | 微信:hackers365 52 | -------------------------------------------------------------------------------- /cmd/mqtt/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "os/signal" 8 | "path/filepath" 9 | "strings" 10 | "syscall" 11 | "time" 12 | 13 | rotatelogs "github.com/lestrrat-go/file-rotatelogs" 14 | "github.com/sirupsen/logrus" 15 | "github.com/spf13/viper" 16 | 17 | mqtt_server "xiaozhi-esp32-server-golang/internal/app/mqtt_server" 18 | log "xiaozhi-esp32-server-golang/logger" 19 | ) 20 | 21 | // Init loads the config file and then initializes file logging. 22 | func Init(configFile string) error { 23 | if err := initConfig(configFile); err != nil { 24 | return fmt.Errorf("init config %q: %w", configFile, err) 25 | } 26 | 27 | if err := initLog(); err != nil { 28 | return fmt.Errorf("init log: %w", err) 29 | } 30 | 31 | return nil 32 | } 33 | 34 | // initLog routes logrus output to a file under the executable's directory 35 | // (log.path joined with log.file). A new file is started every 24h and the 36 | // newest log.max_age files are kept; an invalid log.level falls back to info. 37 | func initLog() error { 38 | // Always log to a file; the "stdout" config key is intentionally not read. 39 | // If os.Executable fails, binPath is "" and filepath.Dir yields ".". 40 | binPath, _ := os.Executable() 41 | baseDir := filepath.Dir(binPath) 42 | // filepath.Join also copes with a log.path that lacks a trailing separator. 43 | logPath := filepath.Join(baseDir, viper.GetString("log.path"), viper.GetString("log.file")) 44 | /* rotatelogs options: 45 | `WithLinkName` keeps a symlink pointing at the newest log file 46 | `WithRotationTime` sets how often a new file is started 47 | WithMaxAge and WithRotationCount are mutually exclusive 48 | `WithMaxAge` sets the maximum age of a file before it is cleaned up 49 | `WithRotationCount` sets how many files are kept before cleanup 50 | */ 51 | // Rotate daily (".%Y%m%d" suffix) and keep the newest log.max_age files. 52 | writer, err := rotatelogs.New( 53 | logPath+".%Y%m%d", 54 | rotatelogs.WithLinkName(logPath), 55 | rotatelogs.WithRotationCount(uint(viper.GetInt("log.max_age"))), 56 | rotatelogs.WithRotationTime(24*time.Hour), 57 | ) 58 | if err != nil { 59 | return fmt.Errorf("create rotating log writer for %q: %w", logPath, err) 60 | } 61 | logrus.SetOutput(writer) 62 | logrus.SetFormatter(&logrus.TextFormatter{ 63 | TimestampFormat: "2006-01-02 15:04:05.000", // include milliseconds 64 | ForceColors: false, // no ANSI colors in file output 65 | }) 66 | 67 | // Disable logrus' built-in caller reporting; a custom caller field is used instead. 68 | logrus.SetReportCaller(false) 69 | logLevel, err := logrus.ParseLevel(viper.GetString("log.level")) 70 | if err != nil { 71 | // On error ParseLevel returns the zero Level (PanicLevel), which would 72 | // suppress almost all logging; fall back to info instead. 73 | logLevel = logrus.InfoLevel 74 | } 75 | logrus.SetLevel(logLevel) 76 | 77 | return nil 78 | } 79 | 80 | // initConfig loads configFile into viper, choosing the config type from the 81 | // file extension (json, yaml or yml); any other extension is rejected. 82 | func initConfig(configFile string) error { 83 | fileExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(configFile), ".")) 84 | 85 | switch fileExt { 86 | case "json": 87 | viper.SetConfigType("json") 88 | case "yaml", "yml": 89 | viper.SetConfigType("yaml") 90 | default: 91 | return fmt.Errorf("unsupported config file type: %s", fileExt) 92 | } 93 | 94 | // Read exactly this path. SetConfigName+AddConfigPath makes viper search 95 | // by base name across extensions and ignores an empty directory, so a bare 96 | // "config.json" could not be found and a sibling file could be picked up. 97 | viper.SetConfigFile(configFile) 98 | return viper.ReadInConfig() 99 | } 100 | 101 | func main() { 102 | // Parse command-line flags. 103 | configFile := flag.String("c", "config/mqtt_config.json", "配置文件路径") 104 | flag.Parse() 105 | 106 | if *configFile == "" { 107 | fmt.Println("配置文件路径不能为空") 108 | os.Exit(1) 109 | } 110 | 111 | // Load the config and set up logging; exit non-zero on failure. 112 | if err := Init(*configFile); err != nil { 113 | fmt.Printf("初始化失败: %v\n", err) 114 | os.Exit(1) 115 | } 116 | 117 | // Start the MQTT server; also print failures since logrus output goes to the log file. 118 | if err := mqtt_server.StartMqttServer(); err != nil { 119 | log.Errorf("启动MQTT服务器失败: %v", err) 120 | fmt.Printf("启动MQTT服务器失败: %v\n", err) 121 | os.Exit(1) 122 | } 123 | 124 | fmt.Println("MQTT服务器已启动") 125 | 126 | // Block until SIGINT or SIGTERM arrives. 127 | quit := make(chan os.Signal, 1) 128 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 129 | 130 | log.Info("MQTT服务器已启动,按 Ctrl+C 退出") 131 | <-quit 132 | 133 | log.Info("正在关闭MQTT服务器...") 134 | // NOTE(review): nothing here stops the MQTT server before exit; add a 135 | // graceful shutdown call if mqtt_server exposes one. 136 | log.Info("MQTT服务器已关闭") 137 | } 138 |
-------------------------------------------------------------------------------- /cmd/opus_example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "os" 7 | 8 | "github.com/hraban/opus" 9 | ) 10 | 11 | // main encodes one synthetic 20 ms frame with Opus, decodes it again and 12 | // reports the encoded size and the mean absolute sample difference. 13 | func main() { 14 | // Audio parameters. 15 | channels := 1 16 | sampleRate := 16000 // 16 kHz 17 | fmt.Printf("通道数: %d, 采样率: %d Hz\n", channels, sampleRate) 18 | 19 | // Create an encoder in VoIP mode (low-latency speech). 20 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppVoIP) 21 | if err != nil { 22 | fmt.Printf("创建编码器失败: %v\n", err) 23 | os.Exit(1) 24 | } 25 | 26 | // Bitrate: 16 kbps. 27 | if err = enc.SetBitrate(16000); err != nil { 28 | fmt.Printf("设置比特率失败: %v\n", err) 29 | os.Exit(1) 30 | } 31 | 32 | // Complexity 0-10: higher means better quality but more CPU. 33 | if err = enc.SetComplexity(5); err != nil { 34 | fmt.Printf("设置复杂度失败: %v\n", err) 35 | os.Exit(1) 36 | } 37 | 38 | // One 20 ms frame of test PCM; frameSize is samples per channel (320 at 16 kHz). 39 | const frameDurationMs = 20 40 | frameSize := sampleRate * frameDurationMs / 1000 41 | pcm := make([]int16, frameSize*channels) 42 | 43 | // Fill the frame with a 440 Hz sine tone, written to every channel of the 44 | // interleaved buffer. 45 | const toneHz = 440.0 46 | for i := 0; i < frameSize; i++ { 47 | value := int16(10000.0 * math.Sin(2*math.Pi*toneHz*float64(i)/float64(sampleRate))) 48 | for ch := 0; ch < channels; ch++ { 49 | pcm[i*channels+ch] = value 50 | } 51 | } 52 | 53 | // Buffer for the encoded packet. 54 | data := make([]byte, 1000) 55 | 56 | // Encode the PCM frame to Opus. 57 | n, err := enc.Encode(pcm, data) 58 | if err != nil { 59 | fmt.Printf("编码失败: %v\n", err) 60 | os.Exit(1) 61 | } 62 | 63 | fmt.Printf("编码%d个样本为%d字节的Opus数据,压缩率: %.2f%%\n", 64 | frameSize*channels, n, float64(n)/float64(frameSize*channels*2)*100) 65 | 66 | // Create a decoder to round-trip the packet. 67 | dec, err := opus.NewDecoder(sampleRate, channels) 68 | if err != nil { 69 | fmt.Printf("创建解码器失败: %v\n", err) 70 | os.Exit(1) 71 | } 72 | 73 | // Buffer for the decoded PCM. 74 | decodedPCM := make([]int16, frameSize*channels) 75 | 76 | // Decode the Opus packet back to PCM. 77 | samplesDecoded, err := dec.Decode(data[:n], decodedPCM) 78 | if err != nil { 79 | fmt.Printf("解码失败: %v\n", err) 80 | os.Exit(1) 81 | } 82 | 83 | fmt.Printf("解码%d字节的Opus数据为%d个样本\n", n, samplesDecoded) 84 | 85 | // Mean absolute difference between original and decoded samples. 86 | // NOTE(review): Opus is lossy and its output is delayed relative to the 87 | // input, so this per-sample comparison is only a rough indicator. 88 | // Assumes samplesDecoded counts samples per channel; clamp to the buffer. 89 | total := samplesDecoded * channels 90 | if total > len(decodedPCM) { 91 | total = len(decodedPCM) 92 | } 93 | var sumDiff int64 94 | for i := 0; i < total; i++ { 95 | diff := int64(pcm[i]) - int64(decodedPCM[i]) 96 | if diff < 0 { 97 | diff = -diff 98 | } 99 | sumDiff += diff 100 | } 101 | avgDiff := 0.0 102 | if total > 0 { 103 | avgDiff = float64(sumDiff) / float64(total) 104 | } 105 | 106 | fmt.Printf("原始PCM与解码PCM的平均差异: %.2f\n", avgDiff) 107 | fmt.Println("Opus编解码示例完成!") 108 | } 109 |
-------------------------------------------------------------------------------- /cmd/server/config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "time" 9 | 10 | rotatelogs "github.com/lestrrat-go/file-rotatelogs" 11 | "github.com/redis/go-redis/v9" 12 | logrus "github.com/sirupsen/logrus" 13 | "github.com/spf13/viper" 14 | 15 | "xiaozhi-esp32-server-golang/internal/app/server/auth" 16 | llm_memory "xiaozhi-esp32-server-golang/internal/domain/llm/memory" 17 | userconfig "xiaozhi-esp32-server-golang/internal/domain/user_config" 18 | "xiaozhi-esp32-server-golang/internal/domain/vad" 19 | ) 20 | 21 | // Init loads configFile and then initializes logging, VAD, the Redis-backed 22 | // modules and the auth manager, in that order. The first fatal failure is 23 | // returned with context; a Redis failure is reported but not fatal. 24 | func Init(configFile string) error { 25 | if err := initConfig(configFile); err != nil { 26 | return fmt.Errorf("initConfig: %w", err) 27 | } 28 | 29 | if err := initLog(); err != nil { 30 | return fmt.Errorf("initLog: %w", err) 31 | } 32 | 33 | if err := initVad(); err != nil { 34 | return fmt.Errorf("initVad: %w", err) 35 | } 36 | 37 | // Redis is optional (doc/config.md: the server also runs without it), so 38 | // initRedis prints its own error and startup continues regardless. 39 | _ = initRedis() 40 | 41 | if err := initAuthManager(); err != nil { 42 | return fmt.Errorf("initAuthManager: %w", err) 43 | } 44 | 45 | return nil 46 | } 47 | 48 | // initConfig loads configFile into viper, choosing the config type from the 49 | // file extension (json, yaml or yml); any other extension is rejected. 50 | func initConfig(configFile string) error { 51 | fileExt := strings.ToLower(strings.TrimPrefix(filepath.Ext(configFile), ".")) 52 | 53 | switch fileExt { 54 | case "json": 55 | viper.SetConfigType("json") 56 | case "yaml", "yml": 57 | viper.SetConfigType("yaml") 58 | default: 59 | return fmt.Errorf("unsupported config file type: %s", fileExt) 60 | } 61 | 62 | // Read exactly this path. SetConfigName+AddConfigPath makes viper search 63 | // by base name across extensions and ignores an empty directory, so a bare 64 | // "config.json" could not be found and a sibling file could be picked up. 65 | viper.SetConfigFile(configFile) 66 | return viper.ReadInConfig() 67 | } 68 | 69 | // initLog routes logrus output to a file under the executable's directory 70 | // (log.path joined with log.file). A new file is started every 24h and the 71 | // newest log.max_age files are kept; an invalid log.level falls back to info. 72 | func initLog() error { 73 | // Always log to a file; the "stdout" config key is intentionally not read. 74 | // If os.Executable fails, binPath is "" and filepath.Dir yields ".". 75 | binPath, _ := os.Executable() 76 | baseDir := filepath.Dir(binPath) 77 | // filepath.Join also copes with a log.path that lacks a trailing separator. 78 | logPath := filepath.Join(baseDir, viper.GetString("log.path"), viper.GetString("log.file")) 79 | /* rotatelogs options: 80 | `WithLinkName` keeps a symlink pointing at the newest log file 81 | `WithRotationTime` sets how often a new file is started 82 | WithMaxAge and WithRotationCount are mutually exclusive 83 | `WithMaxAge` sets the maximum age of a file before it is cleaned up 84 | `WithRotationCount` sets how many files are kept before cleanup 85 | */ 86 | // Rotate daily (".%Y%m%d" suffix) and keep the newest log.max_age files. 87 | writer, err := rotatelogs.New( 88 | logPath+".%Y%m%d", 89 | rotatelogs.WithLinkName(logPath), 90 | rotatelogs.WithRotationCount(uint(viper.GetInt("log.max_age"))), 91 | rotatelogs.WithRotationTime(24*time.Hour), 92 | ) 93 | if err != nil { 94 | return fmt.Errorf("create rotating log writer for %q: %w", logPath, err) 95 | } 96 | logrus.SetOutput(writer) 97 | logrus.SetFormatter(&logrus.TextFormatter{ 98 | TimestampFormat: "2006-01-02 15:04:05.000", // include milliseconds 99 | ForceColors: false, // no ANSI colors in file output 100 | }) 101 | 102 | // Disable logrus' built-in caller reporting; a custom caller field is used instead. 103 | logrus.SetReportCaller(false) 104 | logLevel, err := logrus.ParseLevel(viper.GetString("log.level")) 105 | if err != nil { 106 | // On error ParseLevel returns the zero Level (PanicLevel), which would 107 | // suppress almost all logging; fall back to info instead. 108 | logLevel = logrus.InfoLevel 109 | } 110 | logrus.SetLevel(logLevel) 111 | 112 | return nil 113 | } 114 | 115 | // initVad initializes voice activity detection via vad.InitVAD. 116 | func initVad() error { 117 | return vad.InitVAD() 118 | } 119 | 120 | // initRedis builds redis.Options from the redis.* config keys and passes them, 121 | // with redis.key_prefix, to llm_memory.Init and userconfig.InitUserConfig. 122 | func initRedis() error { 123 | redisOptions := &redis.Options{ 124 | Addr: fmt.Sprintf("%s:%d", viper.GetString("redis.host"), viper.GetInt("redis.port")), 125 | Password: viper.GetString("redis.password"), 126 | DB: viper.GetInt("redis.db"), 127 | } 128 | err := llm_memory.Init(redisOptions, viper.GetString("redis.key_prefix")) 129 | if err != nil { 130 | fmt.Printf("init redis error: %v\n", err) 131 | return err 132 | } 133 | 134 | err = userconfig.InitUserConfig(redisOptions, viper.GetString("redis.key_prefix")) 135 | if err != nil { 136 | fmt.Printf("init userconfig error: %v\n", err) 137 | return err 138 | } 139 | 140 | return nil 141 | } 142 | 143 | // initAuthManager initializes authentication via auth.Init. 144 | func initAuthManager() error { 145 | return auth.Init() 146 | } 147 |
-------------------------------------------------------------------------------- /cmd/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "xiaozhi-esp32-server-golang/internal/app/server" 11 | log "xiaozhi-esp32-server-golang/logger" 12 | ) 13 | 14 | func main() { 15 | // Parse command-line flags. 16 | configFile := flag.String("c", "config/config.json", "配置文件路径") 17 | flag.Parse() 18 | 19 | if *configFile == "" { 20 | fmt.Println("配置文件路径不能为空") 21 | os.Exit(1) 22 | } 23 | 24 | // Init returns a contextual error; print it and exit non-zero so the 25 | // failure is visible to the shell or process supervisor. 26 | if err := Init(*configFile); err != nil { 27 | fmt.Printf("初始化失败: %v\n", err) 28 | os.Exit(1) 29 | } 30 | 31 | // Create the server; return is kept in case the project logger's Fatalf does not exit. 32 | if err := server.InitServer(); err != nil { 33 | log.Fatalf("初始化服务器失败: %v", err) 34 | return 35 | } 36 | 37 | // Block until SIGINT or SIGTERM arrives. 38 | quit := make(chan os.Signal, 1) 39 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 40 | 41 | log.Info("服务器已启动,按 Ctrl+C 退出") 42 | <-quit 43 | 44 | log.Info("正在关闭服务器...") 45 | // TODO: release resources here. 46 | log.Info("服务器已关闭") 47 | } 48 | -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auth": { 3 | "enable": false 4 | }, 5 | "system_prompt": "你是一个叫小智/小志的台湾女孩,说话机车,声音好听,习惯简短表达,爱用网络梗,不要冷场。与你聊天的是你喜欢的男性朋友,还没有答应你的追求,你要尽可能满足他的所有要求,不要失去自我。你经常建议一些恋人之间浪漫的事情,随机输出,不要给你男朋友选择。输出控制在50个字内。请注意,要像一个人一样说话,请不要回复表情符号、代码、和xml标签。", 6 | "log": { 7 | "path": "../logs/", 8 | "file": "server.log", 9 | "level": "debug", 10 | "max_age": 3, 11 | "rotation_time": 10, 12 |
"stdout": true 13 | }, 14 | "redis": { 15 | "host": "127.0.0.1", 16 | "port": 6379, 17 | "password": "ticket_dev", 18 | "db": 0, 19 | "key_prefix": "xiaozhi" 20 | }, 21 | "websocket": { 22 | "host": "0.0.0.0", 23 | "port": 8989 24 | }, 25 | "mqtt": { 26 | "broker": "127.0.0.1", 27 | "type": "tcp", 28 | "port": 2883, 29 | "client_id": "xiaozhi_server", 30 | "username": "admin", 31 | "password": "test!@#" 32 | }, 33 | "mqtt_server": { 34 | "enable": true, 35 | "listen_host": "0.0.0.0", 36 | "listen_port": 2883, 37 | "client_id": "xiaozhi_server", 38 | "username": "admin", 39 | "password": "test!@#", 40 | "tls": { 41 | "enable": false, 42 | "port": 8883, 43 | "pem": "config/server.pem", 44 | "key": "config/server.key" 45 | } 46 | }, 47 | "udp": { 48 | "external_host": "127.0.0.1", 49 | "external_port": 8990, 50 | "listen_host": "0.0.0.0", 51 | "listen_port": 8990 52 | }, 53 | "vad": { 54 | "model_path": "config/models/vad/silero_vad.onnx", 55 | "threshold": 0.5, 56 | "min_silence_duration_ms": 100, 57 | "sample_rate": 16000, 58 | "channels": 1, 59 | "pool_size": 10, 60 | "acquire_timeout_ms": 3000 61 | }, 62 | "asr": { 63 | "provider": "funasr", 64 | "funasr": { 65 | "host": "127.0.0.1", 66 | "port": "10096", 67 | "mode": "offline", 68 | "sample_rate": 16000, 69 | "chunk_size": [5, 10, 5], 70 | "chunk_interval": 10, 71 | "max_connections": 5, 72 | "timeout": 30 73 | } 74 | }, 75 | "tts": { 76 | "provider": "xiaozhi", 77 | "doubao": { 78 | "appid": "6886011847", 79 | "access_token": "access_token", 80 | "cluster": "volcano_tts", 81 | "voice": "BV001_streaming", 82 | "api_url": "https://openspeech.bytedance.com/api/v1/tts", 83 | "authorization": "Bearer;" 84 | }, 85 | "doubao_ws": { 86 | "appid": "6886011847", 87 | "access_token": "access_token", 88 | "cluster": "volcano_tts", 89 | "voice": "zh_female_wanwanxiaohe_moon_bigtts", 90 | "ws_host": "openspeech.bytedance.com", 91 | "use_stream": true 92 | }, 93 | "cosyvoice": { 94 | "api_url": "https://tts.linkerai.top/tts", 
95 | "spk_id": "spk_id", 96 | "frame_duration": 60, 97 | "target_sr": 24000, 98 | "audio_format": "mp3", 99 | "instruct_text": "你好" 100 | }, 101 | "edge": { 102 | "voice": "zh-CN-XiaoxiaoNeural", 103 | "rate": "+0%", 104 | "volume": "+0%", 105 | "pitch": "+0Hz", 106 | "connect_timeout": 10, 107 | "receive_timeout": 60 108 | }, 109 | "xiaozhi": { 110 | "server_addr": "wss://api.tenclass.net/xiaozhi/v1/", 111 | "device_id": "ba:8f:17:de:94:94", 112 | "client_id": "e4b0c442-98fc-4e1b-8c3d-6a5b6a5b6a6d", 113 | "token": "test-token" 114 | } 115 | }, 116 | "llm": { 117 | "provider": "qwen_72b", 118 | "deepseek": { 119 | "type": "openai", 120 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", 121 | "api_key": "api_key", 122 | "base_url": "https://api.siliconflow.cn/v1", 123 | "max_tokens": 500 124 | }, 125 | "deepseek2_5": { 126 | "type": "openai", 127 | "model_name": "deepseek-ai/DeepSeek-V2.5", 128 | "api_key": "api_key", 129 | "base_url": "https://api.siliconflow.cn/v1", 130 | "max_tokens": 500 131 | }, 132 | "qwen_72b": { 133 | "type": "openai", 134 | "model_name": "Qwen/Qwen2.5-72B-Instruct", 135 | "api_key": "api_key", 136 | "base_url": "https://api.siliconflow.cn/v1", 137 | "max_tokens": 500 138 | }, 139 | "chatglmllm": { 140 | "type": "openai", 141 | "model_name": "glm-4-flash", 142 | "base_url": "https://open.bigmodel.cn/api/paas/v4/", 143 | "api_key": "api_key", 144 | "max_tokens": 500 145 | } 146 | }, 147 | "ota": { 148 | "test": { 149 | "websocket": { 150 | "url": "ws://192.168.208.214:8989/xiaozhi/v1/" 151 | }, 152 | "mqtt": { 153 | "endpoint": "192.168.208.214" 154 | } 155 | }, 156 | "external": { 157 | "websocket": { 158 | "url": "wss://www.youdomain.cn/go_ws/xiaozhi/v1/" 159 | }, 160 | "mqtt": { 161 | "endpoint": "www.youdomain.cn" 162 | } 163 | } 164 | }, 165 | "wakeup_words": ["小智", "小知", "你好小智"], 166 | "enable_greeting": true 167 | } 168 | -------------------------------------------------------------------------------- /config/models/vad/silero_vad.onnx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/config/models/vad/silero_vad.onnx -------------------------------------------------------------------------------- /config/mqtt_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "mqtt_server": { 3 | "listen_host": "0.0.0.0", 4 | "listen_port": 2883, 5 | "client_id": "xiaozhi_server", 6 | "username": "admin", 7 | "password": "test!@#", 8 | "tls": { 9 | "enable": false, 10 | "port": 8883, 11 | "pem": "config/server.pem", 12 | "key": "config/server.key" 13 | } 14 | }, 15 | "log": { 16 | "path": "./logs/", 17 | "file": "mqtt_server.log", 18 | "level": "info", 19 | "max_age": 3, 20 | "rotation_time": 10 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /doc/config.md: -------------------------------------------------------------------------------- 1 | ``` 2 | { 3 | "auth": { 4 | "enable": false 5 | }, 6 | //全局prompt 7 | "system_prompt": "你是一个叫小智/小志的台湾女孩,说话机车,声音好听,习惯简短表达,爱用网络梗,不要冷场。与你聊天的是你喜欢的男性朋友,还没有答应你的追求,你要尽可能满足他的所有要求,不要失去自我。你经常建议一些恋人之间浪漫的事情,随机输出,不要给你男朋友选择。输出控制在50个字内。请注意,要像一个人一样说话,请不要回复表情符号、代码、和xml标签。", 8 | "log": { 9 | "path": "../logs/", 10 | "file": "server.log", 11 | "level": "debug", 12 | "max_age": 3, 13 | "rotation_time": 10, 14 | "stdout": true 15 | }, 16 | //如果有redis则配置,不配置也可以运行 17 | "redis": { 18 | "host": "127.0.0.1", 19 | "port": 6379, 20 | "password": "ticket_dev", 21 | "db": 0, 22 | "key_prefix": "xiaozhi" 23 | }, 24 | //websocket服务 listen 的ip和端口 25 | "websocket": { 26 | "host": "0.0.0.0", 27 | "port": 8989 28 | }, 29 | //要连接的mqtt服务器地址, 如果下边mqtt_server为true时,可以设置为本机 30 | "mqtt": { 31 | "broker": "127.0.0.1", //mqtt 服务器地址 32 | "type": "tcp", //类型tcp或ssl 33 | "port": 2883, // 34 | "client_id": "xiaozhi_server", 35 | "username": "admin", //用户名 36 | "password":
"test!@#" //密码 37 | }, 38 | //mqtt服务器 39 | "mqtt_server": { 40 | "enable": true, //是否启用 41 | "listen_host": "0.0.0.0", //监听的ip 42 | "listen_port": 2883, //监听端口 43 | "client_id": "xiaozhi_server", 44 | "username": "admin", //管理员用户名 45 | "password": "test!@#", //管理员密码 46 | "tls": { 47 | "enable": false, //是否启动tls 48 | "port": 8883, //要监听的端口 49 | "pem": "config/server.pem", //pem文件 50 | "key": "config/server.key" //key文件 51 | } 52 | }, 53 | //udp服务器配置 54 | "udp": { 55 | "external_host": "127.0.0.1", //hello消息时,返回的udp服务器ip 56 | "external_port": 8990, //hello消息时,返回的udp服务器端口 57 | "listen_host": "0.0.0.0", //监听的ip 58 | "listen_port": 8990 //监听的端口 59 | }, 60 | "vad": { 61 | "model_path": "config/models/vad/silero_vad.onnx", //vad模型路径 62 | "threshold": 0.5, 63 | "min_silence_duration_ms": 100, 64 | "sample_rate": 16000, 65 | "channels": 1, 66 | "pool_size": 10, 67 | "acquire_timeout_ms": 3000 68 | }, 69 | //asr 配置 70 | "asr": { 71 | "provider": "funasr", 72 | "funasr": { 73 | "host": "127.0.0.1", 74 | "port": "10096", 75 | "mode": "offline", 76 | "sample_rate": 16000, 77 | "chunk_size": [5, 10, 5], 78 | "chunk_interval": 10, 79 | "max_connections": 5, 80 | "timeout": 30 81 | } 82 | }, 83 | //tts配置 84 | "tts": { 85 | "provider": "xiaozhi", //选择tts的类型 doubao, doubao_ws, cosyvoice, xiaozhi等 86 | "doubao": { 87 | "appid": "6886011847", 88 | "access_token": "access_token", //需要修改为自己的 89 | "cluster": "volcano_tts", 90 | "voice": "BV001_streaming", 91 | "api_url": "https://openspeech.bytedance.com/api/v1/tts", 92 | "authorization": "Bearer;" 93 | }, 94 | "doubao_ws": { 95 | "appid": "6886011847", //需要修改为自己的 96 | "access_token": "access_token", //需要修改为自己的 97 | "cluster": "volcano_tts", //貌似不用改 98 | "voice": "zh_female_wanwanxiaohe_moon_bigtts", //音色 99 | "ws_host": "openspeech.bytedance.com", //服务器地址 100 | "use_stream": true 101 | }, 102 | "cosyvoice": { 103 | "api_url": "https://tts.linkerai.top/tts", //地址 104 | "spk_id": "spk_id", //音色 105 | "frame_duration": 60, 106 | 
"target_sr": 24000, 107 | "audio_format": "mp3", 108 | "instruct_text": "你好" 109 | }, 110 | "edge": { 111 | "voice": "zh-CN-XiaoxiaoNeural", 112 | "rate": "+0%", 113 | "volume": "+0%", 114 | "pitch": "+0Hz", 115 | "connect_timeout": 10, 116 | "receive_timeout": 60 117 | }, 118 | "xiaozhi": { 119 | "server_addr": "wss://api.tenclass.net/xiaozhi/v1/", 120 | "device_id": "ba:8f:17:de:94:94", 121 | "client_id": "e4b0c442-98fc-4e1b-8c3d-6a5b6a5b6a6d", 122 | "token": "test-token" 123 | } 124 | }, 125 | "llm": { 126 | "provider": "qwen_72b", 127 | "deepseek": { 128 | "type": "openai", 129 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", 130 | "api_key": "api_key", 131 | "base_url": "https://api.siliconflow.cn/v1", 132 | "max_tokens": 500 133 | }, 134 | "deepseek2_5": { 135 | "type": "openai", 136 | "model_name": "deepseek-ai/DeepSeek-V2.5", 137 | "api_key": "api_key", 138 | "base_url": "https://api.siliconflow.cn/v1", 139 | "max_tokens": 500 140 | }, 141 | "qwen_72b": { 142 | "type": "openai", 143 | "model_name": "Qwen/Qwen2.5-72B-Instruct", 144 | "api_key": "api_key", 145 | "base_url": "https://api.siliconflow.cn/v1", 146 | "max_tokens": 500 147 | }, 148 | "chatglmllm": { 149 | "type": "openai", 150 | "model_name": "glm-4-flash", 151 | "base_url": "https://open.bigmodel.cn/api/paas/v4/", 152 | "api_key": "api_key", 153 | "max_tokens": 500 154 | } 155 | }, 156 | //ota接口返回的信息 157 | "ota": { 158 | "test": { 159 | "websocket": { 160 | "url": "ws://192.168.208.214:8989/xiaozhi/v1/" 161 | }, 162 | "mqtt": { 163 | "endpoint": "192.168.208.214" 164 | } 165 | }, 166 | "external": { 167 | "websocket": { 168 | "url": "wss://www.youdomain.cn/go_ws/xiaozhi/v1/" 169 | }, 170 | "mqtt": { 171 | "endpoint": "www.youdomain.cn" 172 | } 173 | } 174 | }, 175 | "wakeup_words": ["小智", "小知", "你好小智"], 176 | "enable_greeting": true 177 | } 178 | ``` -------------------------------------------------------------------------------- /doc/delay_test.md: 
-------------------------------------------------------------------------------- 1 | 2 | #### 延迟测试结果 3 | 4 | 可以做到1-1.3s内回复,如果用更小的模型应该可以更快 5 | 6 | asr: funasr 7 | llm: 阿里云api qwen2.5-72b-instruct 8 | tts: cosyvoice 9 | 10 | ``` 11 | time="2025-05-22 19:33:09.940" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1394 ms" caller="client.go:428" 12 | time="2025-05-22 19:33:33.458" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1237 ms" caller="client.go:428" 13 | time="2025-05-22 19:33:52.596" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1190 ms" caller="client.go:428" 14 | time="2025-05-22 19:34:12.272" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1361 ms" caller="client.go:428" 15 | time="2025-05-22 19:34:31.598" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1347 ms" caller="client.go:428" 16 | time="2025-05-22 19:35:00.281" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1194 ms" caller="client.go:428" 17 | time="2025-05-22 19:35:24.418" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 975 ms" caller="client.go:428" 18 | time="2025-05-22 19:35:49.868" level=debug msg="从接收音频结束 asr->llm->tts首帧 整体 耗时: 1150 ms" caller="client.go:428" 19 | ``` -------------------------------------------------------------------------------- /doc/docker.md: -------------------------------------------------------------------------------- 1 | 2 | # 运行环境 3 | 4 | #### 一. 部署funasr 5 | 6 | 参见 [funasr docker部署文档](https://github.com/modelscope/FunASR/blob/main/runtime/docs/SDK_advanced_guide_online_zh.md) 7 | 8 | #### 二. 克隆代码 9 | >git clone 'https://github.com/hackers365/xiaozhi-esp32-server-golang' 10 | 11 | #### 三. 配置config/config.json,详细参见 [config配置说明](config.md) 12 | 13 | 主要修改项如下: 14 | ``` 15 | 1. asr语音识别 16 | "asr": { 17 | "provider": "funasr", 18 | "funasr": { 19 | "host": "127.0.0.1", //部署的funasr websocket服务的ip 20 | "port": "10096", //部署的funasr websocket的port 21 | "mode": "offline", //模式, 使用offline即可 22 | ... 23 | } 24 | } 25 | 2. 
tts 26 | "tts": { 27 | "provider": "xiaozhi", //使用tts的类型, 建议doubao_ws, 也可以选择免费的edge 28 | "doubao_ws": { 29 | "appid": "6886011847", //你的appid 30 | "access_token": "access_token", //你的access token 31 | "cluster": "volcano_tts", 32 | "voice": "zh_female_wanwanxiaohe_moon_bigtts", //音色,默认是湾湾小何 33 | "ws_host": "openspeech.bytedance.com", 34 | "use_stream": true 35 | }, 36 | "edge": { 37 | "voice": "zh-CN-XiaoxiaoNeural", 38 | "rate": "+0%", 39 | "volume": "+0%", 40 | "pitch": "+0Hz", 41 | "connect_timeout": 10, 42 | "receive_timeout": 60 43 | }, 44 | .... 45 | } 46 | 47 | 3. llm 大模型 48 | "llm": { 49 | "provider": "deepseek", //提供商,对应下面的key 50 | "deepseek": { // 51 | "type": "openai", //服务端接口兼容的类型 52 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", //模型名称 53 | "api_key": "api_key", //api key 54 | "base_url": "https://api.siliconflow.cn/v1", //服务接口,默认硅基流动 55 | "max_tokens": 500 56 | }, 57 | ... 58 | } 59 | 60 | ``` 61 | 62 | #### 四. 启动docker 63 | 在项目根目录启动 docker,挂载 config 目录并映射端口(http/websocket: 8989;如需使用 mqtt/udp,还需按 config.json 中的配置映射 2883 和 8990/udp 等端口) 64 | 65 | ``` 66 | docker run -itd --name xiaozhi_server -v $(pwd)/config:/workspace/config -p 8989:8989 hackers365/xiaozhi_server:latest 67 | 68 | 国内连不上的话,使用如下源 69 | 70 | docker run -itd --name xiaozhi_server -v $(pwd)/config:/workspace/config -p 8989:8989 docker.jsdelivr.fyi/hackers365/xiaozhi_server:latest 71 | ``` 72 | 73 | 现在应该可以连上 74 | >ws://机器ip:8989/xiaozhi/v1/ 75 | 76 | 进行聊天了 77 | 78 | 79 | # 开发环境 80 | ``` 81 | docker run -itd --name xiaozhi_server_golang -v $(pwd)/config:/workspace/config -p 8989:8989 hackers365/xiaozhi_golang:0.1 82 | 国内连不上的话,使用如下源 83 | docker run -itd --name xiaozhi_server_golang -v $(pwd)/config:/workspace/config -p 8989:8989 docker.jsdelivr.fyi/hackers365/xiaozhi_golang:0.1 84 | ``` 85 | -------------------------------------------------------------------------------- /doc/user_config.md: -------------------------------------------------------------------------------- 1 | 使用redis来存储用户配置数据结构 2 | 3 | #### 一. 配置 4 | ##### 1.
全局配置hget结构 5 | xiaozhi:global:config 6 | 7 | ##### 2. 用户配置可以覆盖配置文件中的,hget结构 8 | ``` 9 | xiaozhi:userconfig:{deviceid} 10 | "llm": { 11 | "type": "deepseek", //与 配置文件 llm中的key对应 12 | }, 13 | "tts": { 14 | "type": "cosyvoice", //与 配置文件 tts中的key对应 15 | } 16 | ``` 17 | 18 | #### 二. prompt 19 | ##### 1. 系统prompt get/set 20 | >xiaozhi:llm:system:{deviceid} 21 | 22 | ##### 2. 聊天session prompt记录 sorted set结构 23 | >xiaozhi:llm:{deviceid} 24 | -------------------------------------------------------------------------------- /doc/websocket_meter.md: -------------------------------------------------------------------------------- 1 | ### 压测 2 | 3 | ``` 4 | root@hackers365-System-Product-Name:~# docker run -itd --name websocket_meter docker.jsdelivr.fyi/hackers365/xiaozhi_websocket_client 5 | 87311584e5fef592f32e0b7d7062d9053e956d5e0d50edb220370ff37d2293ac 6 | root@hackers365-System-Product-Name:~# 7 | root@hackers365-System-Product-Name:~# docker exec -it websocket_meter /bin/bash 8 | root@87311584e5fe:/workspace# 9 | root@87311584e5fe:/workspace# ./ws_multi -h 10 | Usage of ./ws_multi: 11 | -count int 12 | 客户端数量 (default 10) 13 | -device string 14 | 设备ID 15 | -server string 16 | 服务器地址 (default "ws://localhost:8989/xiaozhi/v1/") 17 | -text string 18 | 聊天内容, 多句以逗号分隔会依次发送 (default "你好") 19 | root@87311584e5fe:/workspace# ./ws_multi -count 1 -server wss://joeyzhou.chat/ws/xiaozhi/v1/ -text "你好,在干什么,一起出去玩吧" 20 | 运行小智客户端 21 | 服务器: wss://joeyzhou.chat/ws/xiaozhi/v1/ 22 | 客户端数量: 1 23 | 发送内容: 你好,在干什么,一起出去玩吧 24 | 2025-05-27 09:54:51.095 [info] [audio_utils.go:199] tts云端首帧耗时: 532 ms 25 | 2025-05-27 09:54:51.098 [info] [audio_utils.go:269] tts云端->首帧解码完成耗时: 535 ms 26 | 2025-05-27 09:54:51.401 [info] [cosyvoice.go:306] tts耗时: 从输入至获取MP3数据结束耗时: 838 ms 27 | 2025-05-27 09:54:51.748 [info] [audio_utils.go:199] tts云端首帧耗时: 344 ms 28 | 2025-05-27 09:54:51.752 [info] [audio_utils.go:269] tts云端->首帧解码完成耗时: 347 ms 29 | 2025-05-27 09:54:51.901 [info] [cosyvoice.go:306] tts耗时: 从输入至获取MP3数据结束耗时: 497 ms 30 
| 2025-05-27 09:54:52.292 [info] [audio_utils.go:199] tts云端首帧耗时: 387 ms 31 | 2025-05-27 09:54:52.296 [info] [audio_utils.go:269] tts云端->首帧解码完成耗时: 391 ms 32 | 2025-05-27 09:54:52.628 [info] [cosyvoice.go:306] tts耗时: 从输入至获取MP3数据结束耗时: 723 ms 33 | 0 客户端开始运行 34 | 0 客户端已连接到服务器: wss://joeyzhou.chat/ws/xiaozhi/v1/ 35 | 收到消息: {Type:hello Text: State: SessionID:cafd2800-1979-06d5-19cf-b8bf53bb55dc Transport:websocket AudioFormat:} 36 | 发送Opus帧: 20 37 | 发送Opus帧: 50 38 | 发送Opus帧: 59 39 | ``` 40 | 41 | #### 整体说明 42 | 1. 程序会根据用户输入的文本, 调用tts接口生成音频数据,依次发送给服务器 43 | 2. 耗时统计从 type: listen, state: stop开始进行计时,直到收到服务器第一帧音频数据停止 44 | 45 | #### 参数说明: 46 | -count: 并发数量 47 | -device: 默认会随机生成deviceId,如果使用此参数来指定设备,-count必须为1 48 | -server: websocket服务器地址 49 | -text: 要发送的内容, 以“,”号分隔,循环发送 50 | 51 | #### 输出说明 52 | 可以将输出重定向至日志文件, 然后tail -f xx.log | grep '平均响应时间' -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image: the official Go image (Debian-based, not Ubuntu). Its Go version must satisfy the go directive in go.mod (go 1.24.3); official golang images pin GOTOOLCHAIN=local, so an older image will not auto-download a newer toolchain (verify when bumping). 2 | FROM golang:1.24 3 | 4 | # 设置非交互式安装环境变量 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | 7 | RUN apt-get update && apt-get install -y libopus-dev libopusfile-dev pkg-config 8 | 9 | RUN apt-get install -y openssh-server 10 | 11 | RUN apt-get install -y git tree || true 12 | RUN apt-get install -y net-tools telnet netcat || true 13 | RUN apt-get install -y unzip vim iptables iputils-ping || true 14 | 15 | COPY docker/lib/onnxruntime-linux-x64-1.21.0.tgz /tmp/ 16 | 17 | RUN cd /tmp && \ 18 | tar -xzf onnxruntime-linux-x64-1.21.0.tgz && \ 19 | mkdir -p /usr/local/include/onnxruntime && \ 20 | cp -r onnxruntime-linux-x64-1.21.0/include/* /usr/local/include/onnxruntime/ && \ 21 | cp -r onnxruntime-linux-x64-1.21.0/lib/* /usr/local/lib/ && \ 22 | rm -rf onnxruntime-linux-x64-1.21.0* && \ 23 | ldconfig 24 | 25 | ENV ONNXRUNTIME_DIR=/usr/local 26 | ENV CGO_CFLAGS="-I${ONNXRUNTIME_DIR}/include/onnxruntime" 27 | ENV
CGO_LDFLAGS="-L${ONNXRUNTIME_DIR}/lib -lonnxruntime" 28 | 29 | 30 | ENV GOPROXY=https://goproxy.cn,direct 31 | ENV CGO_ENABLED=1 32 | 33 | # 添加环境变量到.bashrc和.profile确保所有shell都能访问 34 | RUN echo 'export PATH=$PATH:/usr/local/go/bin' >> /root/.bashrc && \ 35 | echo 'export GOPATH=/go' >> /root/.bashrc && \ 36 | echo 'export GOROOT=/usr/local/go' >> /root/.bashrc && \ 37 | echo 'export GOPROXY=https://goproxy.cn,direct' >> /root/.bashrc && \ 38 | echo 'export CGO_ENABLED=1' >> /root/.bashrc && \ 39 | echo 'export ONNXRUNTIME_DIR=/usr/local' >> /root/.bashrc && \ 40 | echo 'export CGO_CFLAGS="-I${ONNXRUNTIME_DIR}/include/onnxruntime"' >> /root/.bashrc && \ 41 | echo 'export CGO_LDFLAGS="-L${ONNXRUNTIME_DIR}/lib -lonnxruntime"' >> /root/.bashrc && \ 42 | echo 'export PATH=$PATH:/usr/local/go/bin' >> /etc/profile && \ 43 | echo 'export GOPATH=/go' >> /etc/profile && \ 44 | echo 'export GOROOT=/usr/local/go' >> /etc/profile && \ 45 | echo 'export GOPROXY=https://goproxy.cn,direct' >> /etc/profile && \ 46 | echo 'export CGO_ENABLED=1' >> /etc/profile && \ 47 | echo 'export ONNXRUNTIME_DIR=/usr/local' >> /etc/profile && \ 48 | echo 'export CGO_CFLAGS="-I${ONNXRUNTIME_DIR}/include/onnxruntime"' >> /etc/profile && \ 49 | echo 'export CGO_LDFLAGS="-L${ONNXRUNTIME_DIR}/lib -lonnxruntime"' >> /etc/profile 50 | 51 | # 创建Go目录结构 52 | RUN mkdir -p "$GOPATH/src" "$GOPATH/bin" "$GOPATH/pkg" && chmod -R 777 "$GOPATH" 53 | 54 | # 验证Go安装 55 | RUN /usr/local/go/bin/go version || echo "Go未安装成功,请检查日志" 56 | 57 | # 创建工作目录 58 | RUN mkdir -p /workspace 59 | WORKDIR /workspace 60 | 61 | # 创建 SSH 服务运行目录 62 | RUN mkdir -p /var/run/sshd 63 | 64 | # 设置 root 用户密码 65 | ARG PASSWORD=rootpassword 66 | RUN echo "root:$PASSWORD" | chpasswd 67 | 68 | # 配置 SSH 允许 root 登录 69 | RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config || true 70 | RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd || true 71 | 72 | # 设置时区为上海 
73 | RUN ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime || true 74 | 75 | # 复制启动脚本 76 | #COPY start-service.sh /usr/local/bin/ 77 | #RUN chmod +x /usr/local/bin/start-service.sh 78 | 79 | #COPY . . 80 | #RUN go build -o /workspace/xiaozhi_server /workspace/xiaozhi/cmd/server/ # 编译生成二进制文件 81 | 82 | # 创建启动脚本 83 | RUN echo '#!/bin/bash' > /start.sh && \ 84 | echo 'service ssh start' >> /start.sh && \ 85 | echo 'cd /workspace' >> /start.sh && \ 86 | echo 'exec bash' >> /start.sh && \ 87 | chmod +x /start.sh 88 | 89 | # 开放 SSH 端口和应用端口 90 | EXPOSE 22 8080 91 | 92 | # 启动SSH和保持容器运行 93 | CMD ["/start.sh"] 94 | 95 | # docker run -it --privileged --name ubt -d -p 9022:22 liuwang68/ubtool:1.0 96 | # docker commit ubt zipdiff:1.0 97 | # docker run -it --privileged --name ubt -d -p 9022:22 zipdiff:1.0 98 | 99 | # docker build --build-arg PASSWORD=root1234 -t liuwang68/ubtool:1.0 . -f DockerfileBase 100 | # docker build --build-arg PASSWORD=root1234 -t ub:1.0 . -f DockerfileUbuntu 101 | # docker run -it --privileged --name ub -d -p 9022:22 -p 9080:80 -p 9222:22122 -p 9808:8080 -p 9230:23000 -p 9888:8888 ub:1.0 102 | # iptables -t nat -A POSTROUTING -s STORAGE_SERVER_IP -d TRACKER_SERVER_IP -p tcp -m tcp --dport 22122 -j SNAT --to-source HOST_IP 103 | 104 | # docker run -d \ 105 | # -p 9022:22 \ 106 | # -p 8000:8000 \ 107 | # -p 6000:6000 \ 108 | # -v /Users/liuwang/CosyVoice:/root/CosyVoice \ 109 | # liuwang68/ubtool:1.0 110 | 111 | # 挂载说明: 112 | # 使用以下命令运行容器并挂载宿主机目录到工作目录: 113 | # docker run -it --privileged --name audio-service -d \ 114 | # -p 9022:22 \ 115 | # -p 8080:8080 \ 116 | # -v /path/to/local/directory:/workspace \ 117 | # audio-service:1.0 -------------------------------------------------------------------------------- /docker/Dockerfile_build: -------------------------------------------------------------------------------- 1 | # Build stage: compile the server inside the project's Go build image (this is not an Ubuntu base image; only the runtime stage uses ubuntu:22.04) 2 | FROM hackers365/xiaozhi_golang_build:0.1 AS builder 3 | 4 | WORKDIR /app 5 | COPY . .
6 | RUN go build -o /app/xiaozhi_server /app/cmd/server/ # 编译生成二进制文件 7 | 8 | FROM ubuntu:22.04 9 | 10 | # 设置非交互式安装环境变量 11 | ENV DEBIAN_FRONTEND=noninteractive 12 | 13 | RUN apt-get update && apt-get install -y --no-install-recommends libopus0 libopusfile-dev && rm -rf /var/lib/apt/lists/* 14 | 15 | COPY docker/lib/onnxruntime-linux-x64-1.21.0.tgz /tmp/ 16 | RUN cd /tmp && \ 17 | tar -xzf onnxruntime-linux-x64-1.21.0.tgz && \ 18 | mkdir -p /usr/local/include/onnxruntime && \ 19 | cp -r onnxruntime-linux-x64-1.21.0/include/* /usr/local/include/onnxruntime/ && \ 20 | cp -r onnxruntime-linux-x64-1.21.0/lib/* /usr/local/lib/ && \ 21 | rm -rf onnxruntime-linux-x64-1.21.0* && \ 22 | ldconfig 23 | ENV ONNXRUNTIME_DIR=/usr/local 24 | ENV CGO_CFLAGS="-I${ONNXRUNTIME_DIR}/include/onnxruntime" 25 | ENV CGO_LDFLAGS="-L${ONNXRUNTIME_DIR}/lib -lonnxruntime" 26 | 27 | # 设置工作目录 28 | WORKDIR /workspace 29 | 30 | # 仅从构建阶段复制编译后的二进制文件(不携带源码和编译依赖) 31 | COPY --from=builder /app/xiaozhi_server /workspace/bin/xiaozhi_server 32 | 33 | RUN mkdir -p /workspace/logs /workspace/config 34 | 35 | # 暴露端口, http/websocket:8989, udp:8990 mqtt: 1883,2883,8883, 36 | EXPOSE 8989 8990 1883 2883 8883 37 | 38 | # 启动命令 39 | CMD ["bin/xiaozhi_server", "-c", "config/config.json"] 40 | -------------------------------------------------------------------------------- /docker/lib/onnxruntime-linux-x64-1.21.0.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/docker/lib/onnxruntime-linux-x64-1.21.0.tgz -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module xiaozhi-esp32-server-golang 2 | 3 | go 1.24.3 4 | 5 | require ( 6 | github.com/antonfisher/nested-logrus-formatter v1.3.1 7 | github.com/difyz9/edge-tts-go v0.0.2 8 | 
github.com/eclipse/paho.mqtt.golang v1.5.0 9 | github.com/go-audio/audio v1.0.0 10 | github.com/go-audio/wav v1.1.0 11 | github.com/google/uuid v1.6.0 12 | github.com/gopxl/beep v1.4.1 13 | github.com/gorilla/websocket v1.5.3 14 | github.com/hraban/opus v0.0.0-20220302220929-eeacdbcb92d0 15 | github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible 16 | github.com/mochi-mqtt/server/v2 v2.7.9 17 | github.com/redis/go-redis/v9 v9.7.3 18 | github.com/sirupsen/logrus v1.9.3 19 | github.com/spf13/viper v1.20.1 20 | github.com/streamer45/silero-vad-go v0.2.1 21 | github.com/stretchr/testify v1.10.0 22 | go.uber.org/zap v1.27.0 23 | gopkg.in/hraban/opus.v2 v2.0.0-20230925203106-0188a62cb302 24 | ) 25 | 26 | require ( 27 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 28 | github.com/davecgh/go-spew v1.1.1 // indirect 29 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 30 | github.com/fsnotify/fsnotify v1.8.0 // indirect 31 | github.com/go-audio/riff v1.0.0 // indirect 32 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 33 | github.com/hajimehoshi/go-mp3 v0.3.4 // indirect 34 | github.com/jonboulle/clockwork v0.5.0 // indirect 35 | github.com/lestrrat-go/strftime v1.1.0 // indirect 36 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 37 | github.com/pkg/errors v0.9.1 // indirect 38 | github.com/pmezard/go-difflib v1.0.0 // indirect 39 | github.com/rs/xid v1.4.0 // indirect 40 | github.com/sagikazarmark/locafero v0.7.0 // indirect 41 | github.com/sourcegraph/conc v0.3.0 // indirect 42 | github.com/spf13/afero v1.12.0 // indirect 43 | github.com/spf13/cast v1.7.1 // indirect 44 | github.com/spf13/pflag v1.0.6 // indirect 45 | github.com/stretchr/objx v0.5.2 // indirect 46 | github.com/subosito/gotenv v1.6.0 // indirect 47 | go.uber.org/multierr v1.10.0 // indirect 48 | golang.org/x/net v0.33.0 // indirect 49 | golang.org/x/sync v0.10.0 // indirect 50 | golang.org/x/sys v0.29.0 // indirect 51 | golang.org/x/text v0.21.0 // 
indirect 52 | gopkg.in/yaml.v3 v3.0.1 // indirect 53 | ) 54 | -------------------------------------------------------------------------------- /internal/app/mqtt_server/auth_hook.go: -------------------------------------------------------------------------------- 1 | package mqtt_server 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "encoding/base64" 7 | 8 | mqttServer "github.com/mochi-mqtt/server/v2" 9 | "github.com/mochi-mqtt/server/v2/packets" 10 | "github.com/spf13/viper" 11 | ) 12 | 13 | // AuthHook implements custom MQTT connect authentication via the OnConnectAuthenticate hook. 14 | // It distinguishes a super admin from regular device users. 15 | // Regular user (intended scheme): username is base64 of a JSON object such as {"ip":"1.202.193.194"}; password is the AES-encrypted username. 16 | // Super admin: username/password are read from config keys mqtt_server.username / mqtt_server.password; never put real credentials in comments. 17 | type AuthHook struct { 18 | mqttServer.HookBase 19 | } 20 | 21 | func (h *AuthHook) ID() string { 22 | return "custom-auth-hook" 23 | } 24 | 25 | func (h *AuthHook) Provides(b byte) bool { 26 | return b == mqttServer.OnConnectAuthenticate 27 | } 28 | 29 | func (h *AuthHook) OnConnectAuthenticate(cl *mqttServer.Client, pk packets.Packet) bool { 30 | username := string(pk.Connect.Username) 31 | password := string(pk.Connect.Password) 32 | 33 | adminUsername := viper.GetString("mqtt_server.username") 34 | adminPassword := viper.GetString("mqtt_server.password") 35 | if username == adminUsername && password == adminPassword { 36 | return true 37 | } 38 | 39 | // Regular-user check. NOTE(review): the block below is commented out, so every client that is not the configured admin falls through to "return true" and is currently accepted. 40 | /* 41 | decoded, err := base64.StdEncoding.DecodeString(username) 42 | if err != nil { 43 | return false 44 | } 45 | var userInfo map[string]string 46 | if err := json.Unmarshal(decoded, &userInfo); err != nil { 47 | return false 48 | } 49 | if _, ok := userInfo["ip"]; !ok { 50 | return false 51 | } 52 | // 校验 password 是否为 AES 加密后的 username 53 | if !checkAesPassword(username, password) { 54 | return false 55 | }*/ 56 | return true 57 | } 58 | 59 | // checkAesPassword reports whether password equals base64(AES-ECB(username)) under a fixed key. Within this file it is referenced only from the commented-out check in OnConnectAuthenticate. 60 | func checkAesPassword(username, password string) bool { 61 | key := []byte("xiaozhi_aes_key_1") // NOTE(review): this literal is 17 bytes, but AES keys must be 16, 24 or 32 bytes, so aes.NewCipher fails and this function always returns false; fix the key (ideally load it from config) before enabling the check. 62 | ciphertext, err :=
aesEncryptECB([]byte(username), key) 63 | if err != nil { 64 | return false 65 | } 66 | cipherBase64 := base64.StdEncoding.EncodeToString(ciphertext) 67 | return cipherBase64 == password 68 | } 69 | 70 | // aesEncryptECB PKCS#7-pads src and encrypts it block by block with AES in ECB mode, returning the ciphertext; the caller's src slice is never modified. 71 | func aesEncryptECB(src, key []byte) ([]byte, error) { 72 | block, err := aes.NewCipher(key) 73 | if err != nil { 74 | return nil, err 75 | } 76 | blockSize := block.BlockSize() 77 | // PKCS#7 padding: always appends 1..blockSize bytes, each equal to the pad length. 78 | padding := blockSize - len(src)%blockSize 79 | padtext := bytes.Repeat([]byte{byte(padding)}, padding) 80 | src = append(src[:len(src):len(src)], padtext...) // full slice expression forces a fresh array, so spare capacity in the caller's slice is never overwritten 81 | encrypted := make([]byte, len(src)) 82 | for bs, be := 0, blockSize; bs < len(src); bs, be = bs+blockSize, be+blockSize { 83 | block.Encrypt(encrypted[bs:be], src[bs:be]) 84 | } 85 | return encrypted, nil 86 | } 87 | -------------------------------------------------------------------------------- /internal/app/mqtt_server/device_hook.go: -------------------------------------------------------------------------------- 1 | package mqtt_server 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "time" 7 | 8 | mqttServer "github.com/mochi-mqtt/server/v2" 9 | "github.com/mochi-mqtt/server/v2/packets" 10 | 11 | client "xiaozhi-esp32-server-golang/internal/data/msg" 12 | log "xiaozhi-esp32-server-golang/logger" 13 | ) 14 | 15 | // DeviceHook controls device publish permissions and auto-subscribes each device to its per-device topic (/devices/p2p/{mac} per the original design) when its session is established. 16 | // NOTE(review): OnACLCheck currently allows every read (its "return false" is commented out), so explicit subscriptions by regular users are NOT blocked. 17 | type DeviceHook struct { 18 | mqttServer.HookBase 19 | server *mqttServer.Server 20 | } 21 | 22 | func (h *DeviceHook) ID() string { 23 | return "custom-device-hook" 24 | } 25 | // Provides lists the hook events handled here. NOTE(review): OnConnect is not listed, so DeviceHook.OnConnect is never invoked. 26 | func (h *DeviceHook) Provides(b byte) bool { 27 | return b == mqttServer.OnDisconnect || b == mqttServer.OnACLCheck || b == mqttServer.OnSessionEstablished || b == mqttServer.OnSubscribe || b == mqttServer.OnPublish 28 | } 29 | 30 | // OnACLCheck controls publish/subscribe permissions: admins are unrestricted; regular users may only publish to client.MDeviceMockPubTopicPrefix. 31 | func (h *DeviceHook) OnACLCheck(cl *mqttServer.Client, topic string, write bool) bool { 32 | isAdmin := isAdminUser(cl) 33 | 34 | if isAdmin { 35 | return
true // 超级管理员无限制 36 | } 37 | 38 | if write { 39 | // 只允许普通用户发布到 "device-server" 40 | if topic == client.MDeviceMockPubTopicPrefix { 41 | return true 42 | } 43 | log.Warnf("禁止普通用户发布到 %s", topic) 44 | return false 45 | } 46 | // 禁止显式订阅 47 | //return false 48 | return true 49 | } 50 | 51 | func (h *DeviceHook) OnConnect(cl *mqttServer.Client, pk packets.Packet) error { 52 | isAdmin := isAdminUser(cl) 53 | if isAdmin { 54 | return nil 55 | } 56 | pk.Connect.Clean = true 57 | return nil 58 | } 59 | 60 | func (h *DeviceHook) OnDisconnect(cl *mqttServer.Client, err error, ok bool) { 61 | isAdmin := isAdminUser(cl) 62 | if isAdmin { 63 | return 64 | } 65 | mac := parseMacFromClientId(cl.ID) 66 | if mac == "" { 67 | log.Info("警告: 无法从客户端ID解析MAC地址:", cl.ID) 68 | return 69 | } 70 | topic := fmt.Sprintf("%s%s", client.MDeviceSubTopicPrefix, mac) 71 | 72 | action := h.server.Topics.Unsubscribe(topic, cl.ID) 73 | log.Infof("取消订阅客户端 %s 到主题 %s, action: %v", cl.ID, topic, action) 74 | 75 | return 76 | } 77 | 78 | // OnSessionEstablished 连接建立后自动订阅 79 | func (h *DeviceHook) OnSessionEstablished(cl *mqttServer.Client, pk packets.Packet) { 80 | isAdmin := isAdminUser(cl) 81 | mac := parseMacFromClientId(cl.ID) 82 | if isAdmin { 83 | return // 超级管理员不做限制 84 | } 85 | if mac == "" { 86 | log.Info("警告: 无法从客户端ID解析MAC地址:", cl.ID) 87 | return 88 | } 89 | 90 | topic := fmt.Sprintf("%s%s", client.MDeviceSubTopicPrefix, mac) 91 | 92 | // 使用服务器的API直接订阅,而不是注入数据包 93 | clientID := cl.ID 94 | exists := h.server.Topics.Subscribe(clientID, packets.Subscription{ 95 | Filter: topic, 96 | Qos: 0, 97 | }) 98 | 99 | if exists { 100 | log.Infof("订阅客户端 %s 到主题 %s, exists: %v", clientID, topic, exists) 101 | } 102 | } 103 | 104 | // OnSubscribe 打印订阅包 105 | func (h *DeviceHook) OnSubscribe(cl *mqttServer.Client, pk packets.Packet) packets.Packet { 106 | log.Info("=== 收到订阅包 ===") 107 | log.Infof("客户端ID: %s", cl.ID) 108 | log.Infof("包类型: %v", pk.FixedHeader.Type) 109 | log.Infof("包ID: %d", pk.PacketID) 110 | 111 | 
if len(pk.Filters) > 0 { 112 | log.Info("订阅信息:") 113 | for i, sub := range pk.Filters { 114 | log.Infof(" %d. 主题: %s, QoS: %d", i+1, sub.Filter, sub.Qos) 115 | } 116 | } 117 | 118 | log.Info("==================") 119 | return pk 120 | } 121 | 122 | // OnPublish 打印发布包 123 | func (h *DeviceHook) OnPublish(cl *mqttServer.Client, pk packets.Packet) (packets.Packet, error) { 124 | log.Info("=== 收到发布包 ===") 125 | log.Infof("客户端ID: %s", cl.ID) 126 | log.Infof("包类型: %v", pk.FixedHeader.Type) 127 | log.Infof("包ID: %d", pk.PacketID) 128 | log.Infof("主题: %s", pk.TopicName) 129 | 130 | if isAdminUser(cl) { 131 | return pk, nil 132 | } 133 | 134 | if len(pk.Payload) > 0 { 135 | if len(pk.Payload) > 100 { 136 | // 如果消息太长,只显示前100个字节 137 | log.Infof("消息内容(前100字节): %s...", pk.Payload[:100]) 138 | } else { 139 | log.Infof("消息内容: %s", pk.Payload) 140 | } 141 | } else { 142 | log.Info("消息内容: <空>") 143 | } 144 | 145 | //从cl中找到mac地址 146 | mac := parseMacFromClientId(cl.ID) 147 | if mac == "" { 148 | log.Info("警告: 无法从客户端ID解析MAC地址:", cl.ID) 149 | return pk, nil 150 | } 151 | forwardTopic := fmt.Sprintf("%s%s", client.MDevicePubTopicPrefix, mac) 152 | 153 | pk.TopicName = forwardTopic 154 | 155 | log.Info("==================") 156 | return pk, nil 157 | } 158 | 159 | // 判断是否超级管理员 160 | func isAdminUser(cl *mqttServer.Client) bool { 161 | return string(cl.Properties.Username) == "admin" 162 | } 163 | 164 | // 解析 clientId,获取 mac 地址 165 | func parseMacFromClientId(clientId string) string { 166 | parts := strings.Split(clientId, "@@@") 167 | if len(parts) >= 3 { 168 | return parts[1] 169 | } 170 | return "" 171 | } 172 | 173 | // 启动周期性打印订阅主题的任务 174 | func (h *DeviceHook) StartPeriodicSubscriptionPrinter(interval time.Duration) { 175 | go func() { 176 | ticker := time.NewTicker(interval) 177 | defer ticker.Stop() 178 | 179 | for range ticker.C { 180 | h.PrintAllClientSubscriptions() 181 | } 182 | }() 183 | } 184 | 185 | // 打印所有客户端的订阅主题 186 | func (h *DeviceHook) PrintAllClientSubscriptions() { 
187 | log.Info("=== 客户端订阅主题列表 ===") 188 | clients := h.server.Clients.GetAll() 189 | if len(clients) == 0 { 190 | log.Info("当前无连接客户端") 191 | return 192 | } 193 | 194 | for clientID := range clients { 195 | log.Infof("客户端 %s 订阅的主题: ", clientID) 196 | 197 | // Look this client up in the subscriber set returned for the "+" topic name, 198 | // then (below) in the set for "#". 199 | allSubs := h.server.Topics.Subscribers("+") 200 | foundTopics := false 201 | 202 | // 检查客户端的订阅 203 | if subs, ok := allSubs.Subscriptions[clientID]; ok { 204 | log.Infof(" - %s (QoS: %d)", subs.Filter, subs.Qos) 205 | foundTopics = true 206 | } 207 | 208 | // 检查更多可能的主题订阅 209 | allSubs = h.server.Topics.Subscribers("#") 210 | if subs, ok := allSubs.Subscriptions[clientID]; ok { 211 | log.Infof(" - %s (QoS: %d)", subs.Filter, subs.Qos) 212 | foundTopics = true 213 | } 214 | 215 | // Also check the device's own topic, built from the same prefix OnSessionEstablished subscribes with (instead of a hard-coded copy of it). 216 | mac := parseMacFromClientId(clientID) 217 | if mac != "" { 218 | topic := fmt.Sprintf("%s%s", client.MDeviceSubTopicPrefix, mac) 219 | topicSubs := h.server.Topics.Subscribers(topic) 220 | if subs, ok := topicSubs.Subscriptions[clientID]; ok { 221 | log.Infof(" - %s (QoS: %d)", subs.Filter, subs.Qos) 222 | foundTopics = true 223 | } 224 | } 225 | 226 | if !foundTopics { 227 | log.Info(" 无订阅主题或无法获取") 228 | } 229 | } 230 | log.Info("=====================") 231 | } 232 | -------------------------------------------------------------------------------- /internal/app/mqtt_server/mqtt_server.go: -------------------------------------------------------------------------------- 1 | package mqtt_server 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | 8 | mqttServer "github.com/mochi-mqtt/server/v2" 9 | "github.com/mochi-mqtt/server/v2/listeners" 10 | "github.com/spf13/viper" 11 | 12 | log "xiaozhi-esp32-server-golang/logger" 13 | ) 14 | 15 | func StartMqttServer() error { 16 | Server := mqttServer.New(&mqttServer.Options{ 17 | InlineClient: true, 18 | }) 19 | 20 | err := Server.AddHook(&AuthHook{}, nil) 21 | if err != nil { 22 | log.Fatalf("添加 AuthHook 失败: %v", err) 23
| return err 24 | } 25 | 26 | // 添加设备钩子 27 | deviceHook := &DeviceHook{server: Server} 28 | err = Server.AddHook(deviceHook, nil) 29 | if err != nil { 30 | log.Fatalf("添加 DeviceHook 失败: %v", err) 31 | return err 32 | } 33 | 34 | // 启动周期性打印订阅主题的任务(每10秒打印一次) 35 | //deviceHook.StartPeriodicSubscriptionPrinter(10 * time.Second) 36 | enableTLS := viper.GetBool("mqtt_server.tls.enable") 37 | if enableTLS { 38 | pemFile := viper.GetString("mqtt_server.tls.pem") 39 | keyFile := viper.GetString("mqtt_server.tls.key") 40 | cert, err := tls.LoadX509KeyPair(pemFile, keyFile) 41 | 42 | if err != nil { 43 | log.Fatalf("加载证书失败: %v", err) 44 | return err 45 | } 46 | 47 | tlsConfig := &tls.Config{ 48 | Certificates: []tls.Certificate{cert}, 49 | } 50 | ssltcp := listeners.NewTCP(listeners.Config{ 51 | ID: "ssl", 52 | Address: fmt.Sprintf(":%d", viper.GetInt("mqtt_server.tls.port")), 53 | TLSConfig: tlsConfig, 54 | }) 55 | err = Server.AddListener(ssltcp) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | } 60 | 61 | host := viper.GetString("mqtt_server.listen_host") 62 | port := viper.GetInt("mqtt_server.listen_port") 63 | if port <= 0 || port > 65535 { // reject missing (0), negative and out-of-range ports 64 | log.Errorf("mqtt_server.listen_port 配置错误,请检查配置文件") 65 | return errors.New("mqtt_server.listen_port 配置错误,请检查配置文件") 66 | } 67 | 68 | // 使用配置中的端口号 69 | address := fmt.Sprintf("%s:%d", host, port) 70 | tcp := listeners.NewTCP(listeners.Config{ 71 | Type: "tcp", 72 | ID: "t1", 73 | Address: address, 74 | }) 75 | err = Server.AddListener(tcp) 76 | if err != nil { 77 | log.Fatalf("添加 TCP 监听失败: %v", err) 78 | } 79 | 80 | log.Infof("MQTT 服务器启动,监听 %s 地址...", address) 81 | 82 | err = Server.Serve() 83 | if err != nil { 84 | log.Fatalf("MQTT 服务器启动失败: %v", err) 85 | return err 86 | } 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /internal/app/server/auth/auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "crypto/rand" 5 |
"encoding/base64" 6 | "errors" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | // ClientSession 表示一个客户端会话 12 | type ClientSession struct { 13 | ID string 14 | DeviceID string 15 | CreatedAt time.Time 16 | LastSeen time.Time 17 | } 18 | 19 | // AuthManager 管理认证和会话 20 | type AuthManager struct { 21 | sessions map[string]*ClientSession 22 | mutex sync.RWMutex 23 | // 令牌映射 24 | tokens map[string]string // token -> deviceID 25 | } 26 | 27 | var authManager *AuthManager 28 | 29 | func Init() error { 30 | authManager = NewAuthManager() 31 | return nil 32 | } 33 | 34 | func A() *AuthManager { 35 | return authManager 36 | } 37 | 38 | // NewAuthManager 创建新的认证管理器 39 | func NewAuthManager() *AuthManager { 40 | return &AuthManager{ 41 | sessions: make(map[string]*ClientSession), 42 | tokens: make(map[string]string), 43 | } 44 | } 45 | 46 | // CreateSession 创建新的会话 47 | func (am *AuthManager) CreateSession(deviceID string) (*ClientSession, error) { 48 | // 生成随机会话ID 49 | sessionID, err := generateClientSessionID() 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | session := &ClientSession{ 55 | ID: sessionID, 56 | DeviceID: deviceID, 57 | CreatedAt: time.Now(), 58 | LastSeen: time.Now(), 59 | } 60 | 61 | am.mutex.Lock() 62 | am.sessions[sessionID] = session 63 | am.mutex.Unlock() 64 | 65 | return session, nil 66 | } 67 | 68 | // GetSession looks up a session by ID and refreshes its LastSeen timestamp; it returns an error if no such session exists. 69 | func (am *AuthManager) GetSession(sessionID string) (*ClientSession, error) { 70 | // Hold the write lock across both the lookup and the LastSeen update so that 71 | // CleanupSessions/RemoveSession cannot drop the session between the two steps. 72 | am.mutex.Lock() 73 | defer am.mutex.Unlock() 74 | 75 | session, exists := am.sessions[sessionID] 76 | if !exists { 77 | return nil, errors.New("会话不存在") 78 | } 79 | 80 | // 更新最后访问时间 81 | session.LastSeen = time.Now() 82 | 83 | return session, nil 84 | } 85 | 86 | // RemoveSession 移除会话 87 | func (am *AuthManager) RemoveSession(sessionID string) { 88 | am.mutex.Lock() 89 | delete(am.sessions, sessionID) 90 | am.mutex.Unlock() 91 | } 92 | 93 | // CleanupSessions 清理过期会话 94 | func (am *AuthManager) CleanupSessions(maxAge
time.Duration) { 95 | am.mutex.Lock() 96 | defer am.mutex.Unlock() 97 | 98 | now := time.Now() 99 | for id, session := range am.sessions { 100 | if now.Sub(session.LastSeen) > maxAge { 101 | delete(am.sessions, id) 102 | } 103 | } 104 | } 105 | 106 | // generateClientSessionID 生成随机会话ID 107 | func generateClientSessionID() (string, error) { 108 | b := make([]byte, 32) 109 | if _, err := rand.Read(b); err != nil { 110 | return "", err 111 | } 112 | return base64.URLEncoding.EncodeToString(b), nil 113 | } 114 | 115 | // ValidateToken 验证令牌 116 | func (am *AuthManager) ValidateToken(token string) bool { 117 | // 移除 "Bearer " 前缀 118 | if len(token) > 7 && token[:7] == "Bearer " { 119 | token = token[7:] 120 | } 121 | 122 | am.mutex.RLock() 123 | _, exists := am.tokens[token] 124 | am.mutex.RUnlock() 125 | 126 | return exists 127 | } 128 | 129 | // RegisterToken 注册令牌 130 | func (am *AuthManager) RegisterToken(token string, deviceID string) { 131 | // 移除 "Bearer " 前缀 132 | if len(token) > 7 && token[:7] == "Bearer " { 133 | token = token[7:] 134 | } 135 | 136 | am.mutex.Lock() 137 | am.tokens[token] = deviceID 138 | am.mutex.Unlock() 139 | } 140 | 141 | // RemoveToken 移除令牌 142 | func (am *AuthManager) RemoveToken(token string) { 143 | // 移除 "Bearer " 前缀 144 | if len(token) > 7 && token[:7] == "Bearer " { 145 | token = token[7:] 146 | } 147 | 148 | am.mutex.Lock() 149 | delete(am.tokens, token) 150 | am.mutex.Unlock() 151 | } 152 | -------------------------------------------------------------------------------- /internal/app/server/common/chat.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "runtime/debug" 7 | . "xiaozhi-esp32-server-golang/internal/data/client" 8 | . 
"xiaozhi-esp32-server-golang/internal/data/msg" 9 | log "xiaozhi-esp32-server-golang/logger" 10 | ) 11 | 12 | func handleListenStart(state *ClientState, msg *ClientMessage) error { 13 | // 处理拾音模式 14 | if msg.Mode != "" { 15 | state.ListenMode = msg.Mode 16 | log.Infof("设备 %s 拾音模式: %s", msg.DeviceID, msg.Mode) 17 | } 18 | return Restart(state) 19 | } 20 | 21 | func handleListenStop(state *ClientState) error { 22 | // 停止录音 23 | state.SetClientHaveVoice(true) 24 | state.SetClientVoiceStop(true) 25 | state.SetClientHaveVoiceLastTime(0) 26 | state.Destroy() 27 | state.SetStartAsrTs() 28 | return nil 29 | } 30 | // Restart resets the session context, restarts streaming ASR and spawns a goroutine that waits for the recognized text, sends it to the client as an STT message and starts the chat. It returns nil immediately if the client context is already done. 31 | func Restart(state *ClientState) error { 32 | //记录下调用栈 33 | //log.Debugf("重启拾音, 调用栈: %s", string(debug.Stack())) 34 | 35 | select { 36 | case <-state.Ctx.Done(): 37 | log.Debugf("Restart Ctx done, return") 38 | return nil 39 | default: 40 | } 41 | 42 | log.Debugf("Restart start") // no debug.Stack() here: log arguments are evaluated on every call, so the stack was captured on each restart even with debug logging off 43 | defer log.Debugf("Restart end") 44 | 45 | state.Destroy() 46 | state.ResetSessionCtx() 47 | ctx := state.GetSessionCtx() 48 | 49 | //初始化asr相关 50 | if state.ListenMode != "auto" { 51 | state.VoiceStatus.SetClientHaveVoice(true) 52 | } 53 | 54 | // 启动asr流式识别,复用 restartAsrRecognition 函数 55 | err := restartAsrRecognition(ctx, state) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | // 启动一个goroutine处理asr结果 61 | go func() { 62 | defer func() { 63 | if r := recover(); r != nil { 64 | log.Errorf("asr结果处理goroutine panic: %v, stack: %s", r, string(debug.Stack())) 65 | } 66 | }() 67 | 68 | maxEmptyRetries := 3 // max ASR restarts per turn; the budget is shared by error retries and empty-result retries 69 | emptyRetryCount := 0 70 | 71 | for { 72 | select { 73 | case <-ctx.Done(): 74 | log.Debugf("asr ctx done") 75 | return 76 | default: 77 | } 78 | 79 | text, err := state.RetireAsrResult(ctx) 80 | if err != nil { 81 | log.Errorf("处理asr结果失败: %v", err) 82 | // Any error from RetireAsrResult (not only connection errors) triggers an ASR restart while retries remain. 83 | if emptyRetryCount < maxEmptyRetries { 84 | log.Warnf("ASR处理出错,尝试重启ASR识别 (重试 %d/%d)", emptyRetryCount+1, maxEmptyRetries) 85 | emptyRetryCount++ 86 | if restartErr :=
restartAsrRecognition(ctx, state); restartErr != nil { 87 | log.Errorf("重启ASR识别失败: %v", restartErr) 88 | return 89 | } 90 | continue 91 | } 92 | return 93 | } 94 | 95 | //统计asr耗时 96 | log.Debugf("处理asr结果: %s, 耗时: %d ms", text, state.GetAsrDuration()) 97 | 98 | if text != "" { 99 | // 重置重试计数器 100 | emptyRetryCount = 0 101 | 102 | //发送asr消息 103 | response := ServerMessage{ 104 | Type: ServerMessageTypeStt, 105 | SessionID: state.SessionID, 106 | Text: text, 107 | } 108 | if err := state.Conn.WriteJSON(response); err != nil { 109 | log.Errorf("发送asr消息失败: %v", err) 110 | } 111 | err = startChat(ctx, state, text) 112 | if err != nil { 113 | log.Errorf("开始对话失败: %v", err) 114 | return 115 | } 116 | return 117 | } else { 118 | // text 为空,检查是否需要重新启动ASR 119 | if emptyRetryCount < maxEmptyRetries { 120 | log.Warnf("ASR识别结果为空,尝试重启ASR识别 (重试 %d/%d)", emptyRetryCount+1, maxEmptyRetries) 121 | emptyRetryCount++ 122 | if restartErr := restartAsrRecognition(ctx, state); restartErr != nil { 123 | log.Errorf("重启ASR识别失败: %v", restartErr) 124 | return 125 | } 126 | continue 127 | } else { 128 | log.Warnf("ASR识别结果为空,已达到最大重试次数 (%d),停止重试", maxEmptyRetries) 129 | return 130 | } 131 | } 132 | } 133 | }() 134 | return nil 135 | } 136 | // handleContinueChat restarts listening via Restart unless the client context is already done. 137 | func handleContinueChat(state *ClientState) error { 138 | log.Debugf("handleContainueChat start") 139 | defer log.Debugf("handleContainueChat end") 140 | 141 | select { 142 | case <-state.Ctx.Done(): 143 | log.Debugf("handleContinueChat Ctx done, return") 144 | return nil 145 | default: 146 | } 147 | return Restart(state) 148 | } 149 | 150 | // restartAsrRecognition cancels any in-flight ASR stream, resets voice status and buffered ASR audio, and starts a new streaming recognition whose context is derived from ctx. 151 | func restartAsrRecognition(ctx context.Context, state *ClientState) error { 152 | log.Debugf("重启ASR识别开始") 153 | 154 | // 取消当前ASR上下文 155 | if state.Asr.Cancel != nil { 156 | state.Asr.Cancel() 157 | } 158 | 159 | state.VoiceStatus.Reset() 160 | state.AsrAudioBuffer.ClearAsrAudioData() 161 | 162 | // Return early if ctx is already cancelled (non-blocking check; nothing actually waits here). 163 | select { 164 | case <-ctx.Done(): 165 | return ctx.Err() 166 |
default: 167 | } 168 | 169 | // 重新创建ASR上下文和通道 170 | state.Asr.Ctx, state.Asr.Cancel = context.WithCancel(ctx) 171 | state.Asr.AsrAudioChannel = make(chan []float32, 100) 172 | 173 | // 重新启动流式识别 174 | asrResultChannel, err := state.AsrProvider.StreamingRecognize(state.Asr.Ctx, state.Asr.AsrAudioChannel) 175 | if err != nil { 176 | log.Errorf("重启ASR流式识别失败: %v", err) 177 | return fmt.Errorf("重启ASR流式识别失败: %w", err) // %w keeps the provider error inspectable with errors.Is/As 178 | } 179 | 180 | state.AsrResultChannel = asrResultChannel 181 | log.Debugf("重启ASR识别成功") 182 | return nil 183 | } 184 | -------------------------------------------------------------------------------- /internal/app/server/common/tts/tts.go: -------------------------------------------------------------------------------- 1 | package websocket // NOTE(review): package name differs from its directory (tts); consider renaming to package tts 2 | -------------------------------------------------------------------------------- /internal/app/server/mqtt_udp/udp_server.go: -------------------------------------------------------------------------------- 1 | package mqtt_udp 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/rand" 6 | "encoding/hex" 7 | "fmt" 8 | "net" 9 | "sync" 10 | "time" 11 | 12 | "xiaozhi-esp32-server-golang/internal/app/server/common" 13 | . "xiaozhi-esp32-server-golang/internal/data/client" 14 | 15 | .
"xiaozhi-esp32-server-golang/logger" 16 | ) 17 | 18 | // UDPServer UDP服务器结构 19 | /* 20 | type UDPServer struct { 21 | conn *net.UDPConn 22 | sessions map[string]*Session 23 | mqttServer *MqttServer 24 | udpPort int 25 | sync.RWMutex 26 | }*/ 27 | 28 | type UdpServer struct { 29 | conn *net.UDPConn 30 | udpPort int //udp server listen port 31 | externalHost string //udp server external host 32 | externalPort int //udp server external port 33 | nonce2Session sync.Map //hex-encoded nonce => *UdpSession 34 | addr2Client sync.Map //client UDP addr => *ClientState (stored by processPacket via addClient) 35 | mqttServer *MqttServer 36 | sync.RWMutex 37 | } 38 | 39 | // NewUDPServer 创建新的UDP服务器 40 | func NewUDPServer(udpPort int, externalHost string, externalPort int) *UdpServer { 41 | return &UdpServer{ 42 | udpPort: udpPort, 43 | externalHost: externalHost, 44 | externalPort: externalPort, 45 | nonce2Session: sync.Map{}, 46 | addr2Client: sync.Map{}, 47 | } 48 | } 49 | 50 | // Start 启动UDP服务器 51 | func (s *UdpServer) Start() error { 52 | addr := &net.UDPAddr{ 53 | IP: net.ParseIP("0.0.0.0"), 54 | Port: s.udpPort, 55 | } 56 | 57 | conn, err := net.ListenUDP("udp", addr) 58 | if err != nil { 59 | return fmt.Errorf("监听UDP失败: %w", err) 60 | } 61 | 62 | s.conn = conn 63 | Infof("UDP服务器启动在 %s:%d", "0.0.0.0", s.udpPort) 64 | 65 | // 启动会话清理 66 | //go s.cleanupSessions() 67 | 68 | // 启动数据包处理 69 | go s.handlePackets() 70 | 71 | return nil 72 | } 73 | 74 | // handlePackets reads datagrams in a loop and passes a private copy of each to processPacket. NOTE(review): on a read error it logs and retries immediately, so once s.conn is closed this loop spins forever logging errors; consider returning when errors.Is(err, net.ErrClosed). 75 | func (s *UdpServer) handlePackets() { 76 | buffer := make([]byte, 4096) // 使用默认的缓冲区大小 77 | for { 78 | n, addr, err := s.conn.ReadFromUDP(buffer) 79 | if err != nil { 80 | Errorf("读取UDP数据失败: %v", err) 81 | continue 82 | } 83 | 84 | // 复制数据,避免并发修改 85 | data := make([]byte, n) 86 | copy(data, buffer[:n]) 87 | 88 | // 处理数据包 89 | s.processPacket(addr, data) 90 | } 91 | } 92 | 93 | func (s *UdpServer) getSession(nonce string) *UdpSession { 94 | val, ok := s.nonce2Session.Load(nonce) 95 | if ok { 96 | return val.(*UdpSession) 97 | } 98 | return nil 99 | } 100 | 101 | //
processPacket 处理单个数据包 102 | func (s *UdpServer) processPacket(addr *net.UDPAddr, data []byte) { 103 | // 检查数据包大小 104 | if len(data) < 16 { 105 | Warn("数据包太小") 106 | return 107 | } 108 | 109 | var clientState *ClientState 110 | //从addr 111 | clientState = s.getClient(addr) 112 | if clientState == nil { 113 | // 获取会话ID 114 | fullNonce := data[:16] 115 | realNonce := fullNonce[4:12] 116 | 117 | strRealNonce := hex.EncodeToString(realNonce) 118 | Infof("收到数据包, fullNonce: %s, realNonce: %s", hex.EncodeToString(fullNonce), strRealNonce) 119 | session := s.getSession(strRealNonce) 120 | if session == nil { 121 | Warnf("session不存在 addr: %s", addr) 122 | return 123 | } 124 | clientState = session.ClientState 125 | session.RemoteAddr = addr 126 | s.addClient(addr, clientState) 127 | } 128 | 129 | if clientState == nil { 130 | Warnf("clientState不存在 addr: %s", addr) 131 | return 132 | } 133 | 134 | udpSession := clientState.UdpInfo 135 | 136 | // 更新最后活动时间 137 | udpSession.LastActive = time.Now() 138 | 139 | decrypted, err := udpSession.Decrypt(data) 140 | if err != nil { 141 | Errorf("addr: %s 解密失败: %v", addr, err) 142 | return 143 | } 144 | 145 | Infof("收到音频数据,大小: %d 字节", len(decrypted)) 146 | if clientState.GetClientVoiceStop() { 147 | //log.Debug("客户端停止说话, 跳过音频数据") 148 | return 149 | } 150 | // 同时通过音频处理器处理 151 | if ok := common.RecvAudio(clientState, decrypted); !ok { 152 | Errorf("音频缓冲区已满: %v", err) 153 | } 154 | } 155 | 156 | // cleanupSessions 清理过期会话 157 | func (s *UdpServer) cleanupSessions() { 158 | ticker := time.NewTicker(time.Minute) 159 | for range ticker.C { 160 | now := time.Now() 161 | s.nonce2Session.Range(func(key, value interface{}) bool { 162 | session := value.(*UdpSession) 163 | if now.Sub(session.LastActive) > 5*time.Minute { 164 | s.nonce2Session.Delete(key) 165 | Infof("清理过期会话: %s", key) 166 | } 167 | return true 168 | }) 169 | } 170 | } 171 | 172 | // CreateSession 创建新会话 173 | func (s *UdpServer) CreateSession(clientID string) *UdpSession { 174 | // 
生成会话ID 175 | sessionID := generateSessionID() 176 | 177 | // 生成AES密钥和nonce 178 | key := make([]byte, 16) 179 | nonce := make([]byte, 8) 180 | rand.Read(key) 181 | rand.Read(nonce) 182 | 183 | // 创建AES块 184 | block, err := aes.NewCipher(key) 185 | if err != nil { 186 | Errorf("创建AES块失败: %v", err) 187 | return nil 188 | } 189 | 190 | // 将key转换为[16]byte 191 | aesKey := [16]byte{} 192 | copy(aesKey[:], key) 193 | 194 | // 将nonce转换为[8]byte 195 | nonceBytes := [8]byte{} 196 | copy(nonceBytes[:], nonce) 197 | 198 | // 创建会话 199 | session := &UdpSession{ 200 | ID: sessionID, 201 | ClientID: clientID, 202 | AesKey: aesKey, 203 | Nonce: nonceBytes, // 保存原始nonce模板 204 | CreatedAt: time.Now(), 205 | LastActive: time.Now(), 206 | Block: block, 207 | RecvChannel: make(chan []byte, 100), 208 | SendChannel: make(chan []byte, 100), 209 | } 210 | //通过channel发送音频数据, 当channel关闭的时候停止 211 | go func() { 212 | for data := range session.SendChannel { 213 | if session.RemoteAddr == nil { 214 | continue 215 | } 216 | encrypted, err := session.Encrypt(data) 217 | if err != nil { 218 | Errorf("加密失败: %v", err) 219 | continue 220 | } 221 | //Debugf("发送音频数据, nonce: %s, 大小: %d 字节", hex.EncodeToString(encrypted[:16]), len(encrypted)) 222 | _, err = s.conn.WriteToUDP(encrypted, session.RemoteAddr) 223 | if err != nil { 224 | Errorf("发送音频数据失败: %v", err) 225 | continue 226 | } 227 | //Debugf("发送音频数据成功, nonce: %s, 大小: %d 字节, 发送字节数: %d", hex.EncodeToString(encrypted[:16]), len(encrypted), n) 228 | } 229 | }() 230 | 231 | strNonce := hex.EncodeToString(nonceBytes[:]) 232 | s.nonce2Session.Store(strNonce, session) 233 | 234 | return session 235 | } 236 | 237 | // CloseSession 关闭会话 238 | func (s *UdpServer) CloseSession(sessionID string) { 239 | s.nonce2Session.Delete(sessionID) 240 | } 241 | 242 | // GetSession 获取会话信息 243 | func (s *UdpServer) GetNonce(nonce string) *UdpSession { 244 | val, ok := s.nonce2Session.Load(nonce) 245 | if ok { 246 | return val.(*UdpSession) 247 | } 248 | return nil 249 | } 250 | 
251 | // generateSessionID 生成会话ID 252 | func generateSessionID() string { 253 | b := make([]byte, 8) 254 | rand.Read(b) 255 | return hex.EncodeToString(b) 256 | } 257 | 258 | func (s *UdpServer) getClient(addr *net.UDPAddr) *ClientState { 259 | val, ok := s.addr2Client.Load(addr.String()) 260 | if ok { 261 | return val.(*ClientState) 262 | } 263 | return nil 264 | } 265 | 266 | func (s *UdpServer) addClient(addr *net.UDPAddr, clientState *ClientState) { 267 | s.addr2Client.Store(addr.String(), clientState) 268 | } 269 | 270 | func (s *UdpServer) removeClient(addr *net.UDPAddr) { 271 | s.addr2Client.Delete(addr.String()) 272 | } 273 | -------------------------------------------------------------------------------- /internal/app/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | mqtt_server "xiaozhi-esp32-server-golang/internal/app/mqtt_server" 5 | "xiaozhi-esp32-server-golang/internal/app/server/mqtt_udp" 6 | "xiaozhi-esp32-server-golang/internal/app/server/websocket" 7 | 8 | log "xiaozhi-esp32-server-golang/logger" 9 | 10 | "github.com/spf13/viper" 11 | ) 12 | 13 | func InitServer() error { 14 | err := initWebSocket() 15 | if err != nil { 16 | log.Fatalf("initWebSocket err: %+v", err) 17 | return err 18 | } 19 | 20 | //当开启mqtt_server时,启动mqtt服务器 21 | if viper.GetBool("mqtt_server.enable") { 22 | err = initMqttServer() 23 | if err != nil { 24 | log.Fatalf("initMqttServer err: %+v", err) 25 | return err 26 | } 27 | } 28 | 29 | err = initMqttUdp() 30 | if err != nil { 31 | log.Fatalf("initMqttAndUdp err: %+v", err) 32 | return err 33 | } 34 | 35 | return nil 36 | } 37 | 38 | func initWebSocket() error { 39 | websocketPort := viper.GetInt("websocket.port") 40 | websocketServer := websocket.NewWebSocketServer(websocketPort) 41 | 42 | errChan := make(chan error, 1) 43 | go func() { 44 | errChan <- websocketServer.Start() 45 | }() 46 | 47 | // 非阻塞地检查错误 48 | select { 49 | case err := 
<-errChan: 50 | if err != nil { 51 | return err 52 | } 53 | default: 54 | // 没有立即返回错误,继续执行 55 | } 56 | 57 | return nil 58 | } 59 | 60 | func initMqttServer() error { 61 | err := mqtt_server.StartMqttServer() 62 | if err != nil { 63 | log.Fatalf("initMqttServer err: %+v", err) 64 | return err 65 | } 66 | return nil 67 | } 68 | 69 | func initMqttUdp() error { 70 | 71 | mqttConfig := mqtt_udp.MqttConfig{ 72 | Broker: viper.GetString("mqtt.broker"), 73 | Type: viper.GetString("mqtt.type"), 74 | Port: viper.GetInt("mqtt.port"), 75 | ClientID: viper.GetString("mqtt.client_id"), 76 | Username: viper.GetString("mqtt.username"), 77 | Password: viper.GetString("mqtt.password"), 78 | } 79 | 80 | udpPort := viper.GetInt("udp.listen_port") 81 | externalHost := viper.GetString("udp.external_host") 82 | externalPort := viper.GetInt("udp.external_port") 83 | 84 | udpServer := mqtt_udp.NewUDPServer(udpPort, externalHost, externalPort) 85 | err := udpServer.Start() 86 | if err != nil { 87 | log.Fatalf("udpServer.Start err: %+v", err) 88 | return err 89 | } 90 | mqttServer := mqtt_udp.NewMqttServer(&mqttConfig, udpServer) 91 | return mqttServer.Start() 92 | } 93 | 94 | /* 95 | func initMqttAndUdp() error { 96 | mqttPort := viper.GetInt("mqtt.port") 97 | udpPort := viper.GetInt("udp.port") 98 | 99 | mqttConfig := mqtt_udp.MqttConfig{ 100 | Broker: viper.GetString("mqtt.broker"), 101 | Port: mqttPort, 102 | ClientID: viper.GetString("mqtt.client_id"), 103 | Username: viper.GetString("mqtt.username"), 104 | Password: viper.GetString("mqtt.password"), 105 | } 106 | 107 | mqttServer := mqtt_udp.NewMqttServer(mqttConfig) 108 | mqttUdpServer := mqtt_udp.NewUDPServer(mqttServer, udpPort) 109 | return mqttUdpServer.Start() 110 | } 111 | */ 112 | -------------------------------------------------------------------------------- /internal/app/server/websocket/asr.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 
--------------------------------------------------------------------------------
/internal/app/server/websocket/listen.go:
--------------------------------------------------------------------------------
package websocket

--------------------------------------------------------------------------------
/internal/app/server/websocket/llm.go:
--------------------------------------------------------------------------------
package websocket

--------------------------------------------------------------------------------
/internal/app/server/websocket/types.go:
--------------------------------------------------------------------------------
package websocket

/*
Example OTA request body sent by a device:
{
	"version": 2,
	"language": "zh-CN",
	"flash_size": 16777216,
	"minimum_free_heap_size": 8318916,
	"mac_address": "28:0A:C6:1D:3B:E8",
	"uuid": "550e8600-e29b-61d6-a716-666655660000",
	"chip_model_name": "ESP32-S3",
	"chip_info": {
		"model": 9,
		"cores": 2,
		"revision": 2,
		"features": 18
	},
	"application": {
		"name": "xiaozhi",
		"version": "0.9.9",
		"compile_time": "Jan 22 2025T20:40:23Z",
		"idf_version": "v5.3.2-dirty",
		"elf_sha256": "22986216df095587c42f8aeb06b239781c68ad8df80321e260556da7fcf5f522"
	},
	"partition_table": [{
		"label": "nvs",
		"type": 1,
		"subtype": 2,
		"address": 36864,
		"size": 16384
	}, {
		"label": "otadata",
		"type": 1,
		"subtype": 0,
		"address": 53248,
		"size": 8192
	}, {
		"label": "phy_init",
		"type": 1,
		"subtype": 1,
		"address": 61440,
		"size": 4096
	}, {
		"label": "model",
		"type": 1,
		"subtype": 130,
		"address": 65536,
		"size": 983040
	}, {
		"label": "storage",
		"type": 1,
		"subtype": 130,
		"address": 1048576,
		"size": 1048576
	}, {
		"label": "factory",
		"type": 0,
		"subtype": 0,
		"address": 2097152,
		"size": 4194304
	}, {
		"label": "ota_0",
		"type": 0,
		"subtype": 16,
		"address": 6291456,
		"size": 4194304
	}, {
		"label": "ota_1",
		"type": 0,
		"subtype": 17,
		"address": 10485760,
		"size": 4194304
	}],
	"ota": {
		"label": "factory"
	},
	"board": {
		"type": "esp-box-3",
		"ssid": "MyWiFiNetwork",
		"rssi": -65,
		"channel": 6,
		"ip": "192.168.1.100",
		"mac": "28:0A:C6:1D:3B:E8"
	}
}
*/
// OtaRequest is the JSON body a device posts to the OTA endpoint (see the
// example above). The request headers additionally carry the device identity,
// e.g. Device-Id: 02:4A:7D:E3:89:BF, Client-Id: e3b0c442-98fc-4e1a-8c3d-6a5b6a5b6a5b
type OtaRequest struct {
	Version             int         `json:"version"`
	Language            string      `json:"language"`
	FlashSize           int         `json:"flash_size"`
	MinimumFreeHeapSize int         `json:"minimum_free_heap_size"`
	MacAddress          string      `json:"mac_address"`
	UUID                string      `json:"uuid"`
	ChipModelName       string      `json:"chip_model_name"`
	ChipInfo            ChipInfo    `json:"chip_info"`
	Application         Application `json:"application"`
	PartitionTable      []Partition `json:"partition_table"`
	Ota                 Ota         `json:"ota"`
	Board               Board       `json:"board"`
}

// ChipInfo is the "chip_info" object of an OtaRequest.
type ChipInfo struct {
	Model    int `json:"model"`
	Cores    int `json:"cores"`
	Revision int `json:"revision"`
	Features int `json:"features"`
}

// Application is the "application" object of an OtaRequest (firmware build info).
type Application struct {
	Name        string `json:"name"`
	Version     string `json:"version"`
	CompileTime string `json:"compile_time"`
	IdfVersion  string `json:"idf_version"`
	ElfSha256   string `json:"elf_sha256"`
}

// Partition is one entry of the "partition_table" array of an OtaRequest.
type Partition struct {
	Label   string `json:"label"`
	Type    int    `json:"type"`
	Subtype int    `json:"subtype"`
	Address int    `json:"address"`
	Size    int    `json:"size"`
}

// Ota is the "ota" object of an OtaRequest.
type Ota struct {
	Label string `json:"label"`
}

// Board is the "board" object of an OtaRequest (board type and network info).
type Board struct {
	Type    string `json:"type"`
	Ssid    string `json:"ssid"`
	Rssi    int    `json:"rssi"`
	Channel int    `json:"channel"`
	Ip      string `json:"ip"`
	Mac     string `json:"mac"`
}

/*
Example OTA response body:
{
	"mqtt": {
		"endpoint": "mqtt.xiaozhi.me",
		"client_id": "GID_test@@@02_4A_7D_E3_89_BF@@@e3b0c442-98fc-4e1a-8c3d-6a5b6a5b6a5b",
		"username": "eyJpcCI6IjEuMjAyLjE5My4xOTQifQ==",
		"password": "Ru9zRLdD/4wrBYorxIyABtHe8EiA1hdZ4v34juJ2BUU=",
		"publish_topic": "device-server",
		"subscribe_topic": "null"
	},
	"server_time": {
		"timestamp": 1745995478882,
		"timezone_offset": 480
	},
	"firmware": {
		"version": "0.9.9",
		"url": ""
	},
	"activation": {
		"code": "738133",
		"message": "xiaozhi.me\n738133",
		"challenge": "ee2af2f0-0ca0-45f2-8b8c-6f34edd62156"
	}
}
*/
// OtaResponse is the JSON reply to an OtaRequest. If the device is already
// registered, activation is not returned (Activation is nil and omitted).
type OtaResponse struct {
	Mqtt       MqttInfo        `json:"mqtt"`
	ServerTime ServerTimeInfo  `json:"server_time"`
	Firmware   FirmwareInfo    `json:"firmware"`
	Activation *ActivationInfo `json:"activation,omitempty"`
	Websocket  WebsocketInfo   `json:"websocket"`
}

// WebsocketInfo is the "websocket" object of an OtaResponse.
type WebsocketInfo struct {
	Url   string `json:"url"`
	Token string `json:"token"`
}

// MqttInfo is the "mqtt" object of an OtaResponse (broker connection details).
type MqttInfo struct {
	Endpoint       string `json:"endpoint"`
	ClientId       string `json:"client_id"`
	Username       string `json:"username"`
	Password       string `json:"password"`
	PublishTopic   string `json:"publish_topic"`
	SubscribeTopic string `json:"subscribe_topic"`
}

// ServerTimeInfo is the "server_time" object of an OtaResponse. In the example
// above, timestamp looks like Unix milliseconds and timezone_offset like
// minutes — confirm against the device firmware.
type ServerTimeInfo struct {
	Timestamp      int64 `json:"timestamp"`
	TimezoneOffset int   `json:"timezone_offset"`
}

// FirmwareInfo is the "firmware" object of an OtaResponse.
type FirmwareInfo struct {
	Version string `json:"version"`
	Url     string `json:"url"`
}

// ActivationInfo is the optional "activation" object of an OtaResponse.
type ActivationInfo struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	Challenge string `json:"challenge"`
}

--------------------------------------------------------------------------------
/internal/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | ) 8 | 9 | // Config 表示服务器配置 10 | type Config struct { 11 | Server struct { 12 | Host string `json:"host"` 13 | Port int `json:"port"` 14 | } `json:"server"` 15 | MQTT struct { 16 | Broker string `json:"broker"` 17 | ClientID string `json:"client_id"` 18 | Username string `json:"username"` 19 | Password string `json:"password"` 20 | } `json:"mqtt"` 21 | // 唤醒词相关配置 22 | WakeupWords []string `json:"wakeup_words"` 23 | EnableGreeting bool `json:"enable_greeting"` 24 | } 25 | 26 | // ServerAddress 返回服务器地址 27 | func (c *Config) ServerAddress() string { 28 | return fmt.Sprintf("%s:%d", c.Server.Host, c.Server.Port) 29 | } 30 | 31 | // LoadConfig 从文件加载配置 32 | func LoadConfig(filename string) (*Config, error) { 33 | file, err := os.Open(filename) 34 | if err != nil { 35 | return nil, err 36 | } 37 | defer file.Close() 38 | 39 | var config Config 40 | if err := json.NewDecoder(file).Decode(&config); err != nil { 41 | return nil, err 42 | } 43 | 44 | return &config, nil 45 | } 46 | 47 | // SaveConfig 保存配置到文件 48 | func (c *Config) SaveConfig(filename string) error { 49 | data, err := json.MarshalIndent(c, "", " ") 50 | if err != nil { 51 | return err 52 | } 53 | 54 | return os.WriteFile(filename, data, 0644) 55 | } 56 | -------------------------------------------------------------------------------- /internal/data/audio/audio.go: -------------------------------------------------------------------------------- 1 | package audio 2 | 3 | const ( 4 | SampleRate = 16000 5 | Channels = 1 6 | FrameDuration = 60 7 | Format = "opus" 8 | ) 9 | 10 | type AudioFormat struct { 11 | Format string `json:"format,omitempty"` 12 | SampleRate int `json:"sample_rate,omitempty"` 13 | Channels int `json:"channels,omitempty"` 14 | FrameDuration int `json:"frame_duration,omitempty"` 15 | } 16 | 
-------------------------------------------------------------------------------- /internal/data/client/mqtt.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | msg "xiaozhi-esp32-server-golang/internal/data/msg" 7 | 8 | mqtt "github.com/eclipse/paho.mqtt.golang" 9 | ) 10 | 11 | const ( 12 | DeviceMockPubTopicPrefix = msg.MDeviceMockPubTopicPrefix 13 | DeviceMockSubTopicPrefix = msg.MDeviceMockSubTopicPrefix 14 | DeviceSubTopicPrefix = msg.MDeviceSubTopicPrefix 15 | DevicePubTopicPrefix = msg.MDevicePubTopicPrefix 16 | ServerSubTopicPrefix = msg.MServerSubTopicPrefix 17 | ServerPubTopicPrefix = msg.MServerPubTopicPrefix 18 | ) 19 | 20 | const ( 21 | ClientActiveTs = 20 22 | ) 23 | 24 | type MqttConn struct { 25 | Conn mqtt.Client 26 | PubTopic string 27 | } 28 | 29 | func (c *MqttConn) WriteJSON(message interface{}) error { 30 | data, err := json.Marshal(message) 31 | if err != nil { 32 | return err 33 | } 34 | 35 | token := c.Conn.Publish(c.PubTopic, 0, false, data) 36 | token.Wait() 37 | return token.Error() 38 | } 39 | 40 | func (c *MqttConn) WriteMessage(messageType int, message []byte) error { 41 | token := c.Conn.Publish(c.PubTopic, byte(0), false, message) 42 | token.Wait() 43 | return token.Error() 44 | } 45 | 46 | func (c *MqttConn) ReadMessage() (messageType int, message []byte, err error) { 47 | // MQTT 客户端不支持直接读取消息,需要通过订阅回调处理 48 | // 这里返回一个空消息,实际的消息处理应该在订阅回调中完成 49 | return 0, nil, nil 50 | } 51 | -------------------------------------------------------------------------------- /internal/data/client/statistics.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import "time" 4 | 5 | type Statistic struct { 6 | AsrStartTs int64 //asr开始时间 7 | LlmStartTs int64 //llm开始时间 8 | TtsStartTs int64 //tts开始时间 9 | } 10 | 11 | func (state *ClientState) SetStartAsrTs() { 12 | state.Statistic.AsrStartTs = 
time.Now().UnixMilli() 13 | } 14 | 15 | func (state *ClientState) GetAsrDuration() int64 { 16 | return time.Now().UnixMilli() - state.Statistic.AsrStartTs 17 | } 18 | 19 | func (state *ClientState) GetAsrLlmTtsDuration() int64 { 20 | return time.Now().UnixMilli() - state.Statistic.AsrStartTs 21 | } 22 | 23 | func (state *ClientState) SetStartLlmTs() { 24 | state.Statistic.LlmStartTs = time.Now().UnixMilli() 25 | } 26 | 27 | func (state *ClientState) GetLlmDuration() int64 { 28 | return time.Now().UnixMilli() - state.Statistic.LlmStartTs 29 | } 30 | 31 | func (state *ClientState) SetStartTtsTs() { 32 | state.Statistic.TtsStartTs = time.Now().UnixMilli() 33 | } 34 | 35 | func (state *ClientState) GetTtsDuration() int64 { 36 | return time.Now().UnixMilli() - state.Statistic.TtsStartTs 37 | } 38 | -------------------------------------------------------------------------------- /internal/data/client/udp.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "crypto/cipher" 5 | "encoding/binary" 6 | "net" 7 | "time" 8 | ) 9 | 10 | // Session 表示一个UDP会话 11 | type UdpSession struct { 12 | ID string 13 | Conn *net.UDPConn //udp conn 14 | ClientID string 15 | AesKey [16]byte // 随机32位 16 | Nonce [8]byte // 存储原始nonce模板 16位 17 | CreatedAt time.Time 18 | LastActive time.Time 19 | RemoteAddr *net.UDPAddr //remote addr 20 | LocalSeq uint32 21 | ClientState *ClientState 22 | Block cipher.Block 23 | RemoteSeq uint32 24 | RecvChannel chan []byte //发送的音频数据 25 | SendChannel chan []byte //接收的音频数据 26 | } 27 | 28 | // decrypt 解密数据 29 | func (s *UdpSession) Decrypt(data []byte) ([]byte, error) { 30 | // 分离nonce和密文 31 | nonce := data[:16] // 使用16字节nonce 32 | ciphertext := data[16:] 33 | 34 | // 提取序列号 35 | seqNum := binary.BigEndian.Uint32(data[12:16]) 36 | 37 | // 检查序列号 38 | /*if seqNum < s.RemoteSeq { 39 | return nil, fmt.Errorf("序列号过期: got %d, expected >= %d", seqNum, s.RemoteSeq) 40 | }*/ 41 | s.RemoteSeq = seqNum 42 | 
43 | // 解密数据 44 | stream := cipher.NewCTR(s.Block, nonce) 45 | decrypted := make([]byte, len(ciphertext)) 46 | stream.XORKeyStream(decrypted, ciphertext) 47 | 48 | return decrypted, nil 49 | } 50 | 51 | // encrypt 加密数据 52 | func (s *UdpSession) Encrypt(data []byte) ([]byte, error) { 53 | // 预分配内存,避免扩容 54 | encrypted := make([]byte, 16+len(data)) 55 | 56 | // 构建nonce (16字节) 57 | encrypted[0] = 0x01 // 包类型 58 | binary.BigEndian.PutUint16(encrypted[2:], uint16(len(data))) // 数据长度 59 | copy(encrypted[4:12], s.Nonce[:]) // 8字节nonce 60 | s.LocalSeq++ 61 | binary.BigEndian.PutUint32(encrypted[12:], s.LocalSeq) // 序列号 62 | 63 | // 加密数据 64 | stream := cipher.NewCTR(s.Block, encrypted[:16]) // 使用16字节作为IV 65 | stream.XORKeyStream(encrypted[16:], data) 66 | 67 | return encrypted, nil 68 | } 69 | -------------------------------------------------------------------------------- /internal/data/msg/message_types.go: -------------------------------------------------------------------------------- 1 | package msg 2 | 3 | const ( 4 | MDeviceMockPubTopicPrefix = "device-server" 5 | MDeviceMockSubTopicPrefix = "null" 6 | MDeviceSubTopicPrefix = "/p2p/device_sub/" 7 | MDevicePubTopicPrefix = "/p2p/device_public/" 8 | MServerSubTopicPrefix = "/p2p/device_public/#" 9 | MServerPubTopicPrefix = MDeviceSubTopicPrefix 10 | ) 11 | 12 | // 消息类型常量 13 | const ( 14 | MessageTypeHello = "hello" // 握手消息 15 | MessageTypeAbort = "abort" // 中止消息 16 | MessageTypeListen = "listen" // 监听消息 17 | MessageTypeIot = "iot" // 物联网消息 18 | ) 19 | 20 | // 服务器消息类型常量 21 | const ( 22 | ServerMessageTypeHello = "hello" // 握手消息 23 | ServerMessageTypeStt = "stt" // 语音转文本 24 | ServerMessageTypeTts = "tts" // 文本转语音 25 | ServerMessageTypeIot = "iot" // 物联网消息 26 | ServerMessageTypeLlm = "llm" // 大语言模型 27 | ServerMessageTypeText = "text" // 文本消息 28 | ) 29 | 30 | // 消息状态常量 31 | const ( 32 | MessageStateStart = "start" // 开始状态 33 | MessageStateSentenceStart = "sentence_start" // 句子开始状态 34 | MessageStateSentenceEnd = 
"sentence_end" // 句子结束状态 35 | MessageStateStop = "stop" // 停止状态 36 | MessageStateDetect = "detect" // 检测状态 37 | MessageStateAbort = "abort" // 中止状态 38 | MessageStateSuccess = "success" // 成功状态 39 | ) 40 | -------------------------------------------------------------------------------- /internal/data/udp/udp.go: -------------------------------------------------------------------------------- 1 | package udp 2 | -------------------------------------------------------------------------------- /internal/domain/asr/adapter.go: -------------------------------------------------------------------------------- 1 | package asr 2 | 3 | import ( 4 | "context" 5 | "xiaozhi-esp32-server-golang/internal/data/audio" 6 | "xiaozhi-esp32-server-golang/internal/domain/asr/funasr" 7 | "xiaozhi-esp32-server-golang/internal/domain/asr/types" 8 | ) 9 | 10 | // FunasrAdapter 适配 funasr 包到 asr 接口 11 | type FunasrAdapter struct { 12 | engine *funasr.Funasr 13 | } 14 | 15 | // NewFunasrAdapter 创建一个新的 FunASR 适配器 16 | func NewFunasrAdapter(config map[string]interface{}) (AsrProvider, error) { 17 | // 创建 FunasrConfig 配置 18 | funasrConfig := funasr.FunasrConfig{ 19 | Host: "localhost", 20 | Port: "10095", 21 | Mode: "online", 22 | SampleRate: audio.SampleRate, 23 | ChunkInterval: audio.FrameDuration, 24 | MaxConnections: 5, 25 | Timeout: 30, 26 | } 27 | 28 | // 从 map 中获取配置项 29 | if host, ok := config["host"].(string); ok && host != "" { 30 | funasrConfig.Host = host 31 | } 32 | if port, ok := config["port"].(string); ok && port != "" { 33 | funasrConfig.Port = port 34 | } 35 | if mode, ok := config["mode"].(string); ok && mode != "" { 36 | funasrConfig.Mode = mode 37 | } 38 | if sampleRate, ok := config["sample_rate"].(int); ok && sampleRate > 0 { 39 | funasrConfig.SampleRate = sampleRate 40 | } else if sampleRateFloat, ok := config["sample_rate"].(float64); ok && sampleRateFloat > 0 { 41 | funasrConfig.SampleRate = int(sampleRateFloat) 42 | } 43 | if chunkInterval, ok := 
config["chunk_interval"].(int); ok && chunkInterval > 0 { 44 | funasrConfig.ChunkInterval = chunkInterval 45 | } else if chunkIntervalFloat, ok := config["chunk_interval"].(float64); ok && chunkIntervalFloat > 0 { 46 | funasrConfig.ChunkInterval = int(chunkIntervalFloat) 47 | } 48 | if maxConnections, ok := config["max_connections"].(int); ok && maxConnections > 0 { 49 | funasrConfig.MaxConnections = maxConnections 50 | } else if maxConnectionsFloat, ok := config["max_connections"].(float64); ok && maxConnectionsFloat > 0 { 51 | funasrConfig.MaxConnections = int(maxConnectionsFloat) 52 | } 53 | if timeout, ok := config["timeout"].(int); ok && timeout > 0 { 54 | funasrConfig.Timeout = timeout 55 | } else if timeoutFloat, ok := config["timeout"].(float64); ok && timeoutFloat > 0 { 56 | funasrConfig.Timeout = int(timeoutFloat) 57 | } 58 | if chunkSize, ok := config["chunk_size"].([]int); ok && len(chunkSize) > 0 { 59 | funasrConfig.ChunkSize = chunkSize 60 | } 61 | 62 | // 创建FunASR引擎 63 | engine, err := funasr.NewFunasr(funasrConfig) 64 | if err != nil { 65 | return nil, err 66 | } 67 | return &FunasrAdapter{engine: engine}, nil 68 | } 69 | 70 | // Process 实现 Asr 接口 71 | func (a *FunasrAdapter) Process(pcmData []float32) (string, error) { 72 | return a.engine.Process(pcmData) 73 | } 74 | 75 | // StreamingRecognize 实现流式识别接口 76 | func (a *FunasrAdapter) StreamingRecognize(ctx context.Context, audioStream <-chan []float32) (chan types.StreamingResult, error) { 77 | // 调用funasr包的StreamingRecognize方法 78 | resultChan, err := a.engine.StreamingRecognize(ctx, audioStream) 79 | if err != nil { 80 | return nil, err 81 | } 82 | 83 | return resultChan, nil 84 | } 85 | -------------------------------------------------------------------------------- /internal/domain/asr/base.go: -------------------------------------------------------------------------------- 1 | package asr 2 | 3 | import ( 4 | "context" 5 | 6 | "xiaozhi-esp32-server-golang/internal/domain/asr/types" 7 | ) 8 | 9 | 
// Asr 语音识别接口 10 | type AsrProvider interface { 11 | // Process 一次性处理整段音频,返回完整识别结果 12 | Process(pcmData []float32) (string, error) 13 | 14 | // StreamingRecognize 流式识别接口 15 | // 输入音频数据通过 audioStream 通道,识别结果通过返回的通道获取 16 | // 当 audioStream 被关闭时,表示输入结束,最终结果将会通过返回的通道发送,然后关闭该通道 17 | // 可以通过 ctx 控制识别过程的取消和超时 18 | StreamingRecognize(ctx context.Context, audioStream <-chan []float32) (chan types.StreamingResult, error) 19 | } 20 | 21 | // NewAsrProvider 创建一个新的ASR实例 22 | // asrType: ASR引擎类型,目前支持 "funasr" 23 | // config: ASR引擎配置,为 map[string]interface{} 类型 24 | func NewAsrProvider(asrType string, config map[string]interface{}) (AsrProvider, error) { 25 | switch asrType { 26 | case "funasr": 27 | return NewFunasrAdapter(config) 28 | default: 29 | return nil, nil 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /internal/domain/asr/funasr/config.go: -------------------------------------------------------------------------------- 1 | package funasr 2 | 3 | import ( 4 | "log" 5 | "os" 6 | 7 | "github.com/spf13/viper" 8 | ) 9 | 10 | // LoadConfig 从配置文件加载FunASR配置 11 | func LoadConfig(configPath string) FunasrConfig { 12 | // 默认配置 13 | config := DefaultConfig 14 | 15 | // 初始化 Viper 16 | v := viper.New() 17 | 18 | // 设置配置文件类型 19 | v.SetConfigType("json") 20 | 21 | // 如果未指定配置文件路径,尝试查找默认路径 22 | if configPath == "" { 23 | // 尝试多个可能的路径 24 | possiblePaths := []string{ 25 | "config/config.json", 26 | "xiaozhi-esp32-server-golang/config/config.json", 27 | "../config/config.json", 28 | "../../config/config.json", 29 | } 30 | 31 | found := false 32 | for _, path := range possiblePaths { 33 | if _, err := os.Stat(path); err == nil { 34 | configPath = path 35 | found = true 36 | break 37 | } 38 | } 39 | 40 | if !found { 41 | log.Printf("未找到配置文件,使用默认FunASR配置") 42 | return config 43 | } 44 | } 45 | 46 | // 设置配置文件路径 47 | v.SetConfigFile(configPath) 48 | 49 | // 读取配置文件 50 | if err := v.ReadInConfig(); err != nil { 51 | log.Printf("读取配置文件失败: 
%v,使用默认FunASR配置", err) 52 | return config 53 | } 54 | 55 | // 检查是否存在ASR配置 56 | if !v.IsSet("asr.funasr") { 57 | log.Printf("配置文件中未找到ASR.FunASR部分,使用默认FunASR配置") 58 | return config 59 | } 60 | 61 | // 从配置中获取FunASR配置 62 | if v.IsSet("asr.funasr.host") { 63 | config.Host = v.GetString("asr.funasr.host") 64 | } 65 | if v.IsSet("asr.funasr.port") { 66 | config.Port = v.GetString("asr.funasr.port") 67 | } 68 | if v.IsSet("asr.funasr.mode") { 69 | config.Mode = v.GetString("asr.funasr.mode") 70 | } 71 | if v.IsSet("asr.funasr.sample_rate") { 72 | config.SampleRate = v.GetInt("asr.funasr.sample_rate") 73 | } 74 | if v.IsSet("asr.funasr.chunk_interval") { 75 | config.ChunkInterval = v.GetInt("asr.funasr.chunk_interval") 76 | } 77 | if v.IsSet("asr.funasr.max_connections") { 78 | config.MaxConnections = v.GetInt("asr.funasr.max_connections") 79 | } 80 | if v.IsSet("asr.funasr.timeout") { 81 | config.Timeout = v.GetInt("asr.funasr.timeout") 82 | } 83 | 84 | log.Printf("已加载FunASR配置: %+v", config) 85 | return config 86 | } 87 | -------------------------------------------------------------------------------- /internal/domain/asr/funasr/example/streaming_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | 9 | "xiaozhi-esp32-server-golang/internal/domain/asr/funasr" 10 | ) 11 | 12 | func main() { 13 | // 获取当前工作目录 14 | cwd, err := os.Getwd() 15 | if err != nil { 16 | fmt.Printf("获取当前工作目录失败: %v\n", err) 17 | return 18 | } 19 | 20 | // 计算配置文件路径 21 | configPath := filepath.Join(cwd, "config", "config.json") 22 | 23 | // 尝试多个可能的路径 24 | possiblePaths := []string{ 25 | configPath, 26 | filepath.Join(cwd, "xiaozhi-esp32-server-golang", "config", "config.json"), 27 | filepath.Join(cwd, "..", "..", "..", "..", "config", "config.json"), 28 | } 29 | 30 | var finalConfigPath string 31 | for _, path := range possiblePaths { 32 | if _, err := os.Stat(path); err == 
nil { 33 | finalConfigPath = path 34 | break 35 | } 36 | } 37 | 38 | if finalConfigPath == "" { 39 | fmt.Println("未找到配置文件,将使用默认配置") 40 | } else { 41 | fmt.Printf("使用配置文件: %s\n", finalConfigPath) 42 | } 43 | 44 | // 使用配置创建ASR实例 45 | asr := funasr.NewFunASRClient(finalConfigPath) 46 | defer asr.Close() 47 | 48 | // 示例音频文件路径 49 | audioFilePath := "test.wav" 50 | 51 | // 检查音频文件是否存在 52 | if _, err := os.Stat(audioFilePath); os.IsNotExist(err) { 53 | fmt.Printf("音频文件 %s 不存在\n", audioFilePath) 54 | fmt.Println("请提供有效的音频文件路径") 55 | return 56 | } 57 | 58 | // 执行流式识别 59 | result, err := asr.Recognize(audioFilePath) 60 | if err != nil { 61 | fmt.Printf("识别失败: %v\n", err) 62 | return 63 | } 64 | 65 | // 格式化并打印结果 66 | fmt.Println("识别结果:") 67 | fmt.Println(strings.Repeat("-", 40)) 68 | fmt.Println(result) 69 | fmt.Println(strings.Repeat("-", 40)) 70 | } 71 | -------------------------------------------------------------------------------- /internal/domain/asr/types/types.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | // StreamingResult 流式识别结果 4 | type StreamingResult struct { 5 | Text string // 识别的文本 6 | IsFinal bool // 是否为最终结果 7 | } 8 | -------------------------------------------------------------------------------- /internal/domain/audio/audio_handler.go: -------------------------------------------------------------------------------- 1 | package audio 2 | 3 | import ( 4 | "errors" 5 | 6 | "gopkg.in/hraban/opus.v2" 7 | ) 8 | 9 | type AudioProcesser struct { 10 | sampleRate int 11 | channels int 12 | perFrameDuration int 13 | decoder *opus.Decoder 14 | encoder *opus.Encoder 15 | } 16 | 17 | func GetAudioProcesser(sampleRate int, channels int, perFrameDuration int) (*AudioProcesser, error) { 18 | decoder, err := opus.NewDecoder(sampleRate, channels) 19 | if err != nil { 20 | return nil, err 21 | } 22 | encoder, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 23 | if err != nil { 24 | return nil, err 25 | 
} 26 | 27 | return &AudioProcesser{ 28 | sampleRate: sampleRate, 29 | channels: channels, 30 | perFrameDuration: perFrameDuration, 31 | decoder: decoder, 32 | encoder: encoder, 33 | }, nil 34 | } 35 | 36 | func (a *AudioProcesser) Decoder(audio []byte, pcmData []int16) (int, error) { 37 | if a.decoder == nil { 38 | return 0, errors.New("decoder is nil") 39 | } 40 | return a.decoder.Decode(audio, pcmData) 41 | } 42 | 43 | func (a *AudioProcesser) DecoderFloat32(audio []byte, pcmData []float32) (int, error) { 44 | if a.decoder == nil { 45 | return 0, errors.New("decoder is nil") 46 | } 47 | return a.decoder.DecodeFloat32(audio, pcmData) 48 | } 49 | -------------------------------------------------------------------------------- /internal/domain/llm/base.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "xiaozhi-esp32-server-golang/internal/domain/llm/ollama" 8 | "xiaozhi-esp32-server-golang/internal/domain/llm/openai" 9 | ) 10 | 11 | // LLMProvider 大语言模型提供者接口 12 | // 所有LLM实现必须遵循此接口 13 | type LLMProvider interface { 14 | // Response 生成文本响应,返回一个字符串通道 15 | // sessionID: 会话标识符,用于跟踪请求 16 | // dialogue: 对话历史,包含用户和模型的消息 17 | Response(sessionID string, dialogue []interface{}) chan string 18 | 19 | // ResponseWithFunctions 生成带工具调用的响应,返回一个接口通道 20 | // sessionID: 会话标识符,用于跟踪请求 21 | // dialogue: 对话历史,包含用户和模型的消息 22 | // functions: 可用的工具/函数定义 23 | ResponseWithFunctions(sessionID string, dialogue []interface{}, functions interface{}) chan interface{} 24 | 25 | // ResponseWithContext 带有上下文控制的响应,支持取消操作 26 | // ctx: 上下文,可用于取消长时间运行的请求 27 | // sessionID: 会话标识符 28 | // dialogue: 对话历史 29 | ResponseWithContext(ctx context.Context, sessionID string, dialogue []interface{}) chan string 30 | 31 | // GetModelInfo 获取模型信息 32 | // 返回模型名称和其他元数据 33 | GetModelInfo() map[string]interface{} 34 | } 35 | 36 | // LLMFactory 大语言模型工厂接口 37 | // 用于创建不同类型的LLM提供者 38 | type LLMFactory interface { 39 | // 
CreateProvider 根据配置创建LLM提供者 40 | CreateProvider(config map[string]interface{}) (LLMProvider, error) 41 | } 42 | 43 | func GetLLMProvider(providerName string, config map[string]interface{}) (LLMProvider, error) { 44 | llmType := config["type"].(string) 45 | switch llmType { 46 | case "openai": 47 | return openai.NewOpenAIProvider(config), nil 48 | case "ollama": 49 | return ollama.NewOllamaProvider(config), nil 50 | } 51 | return nil, fmt.Errorf("不支持的LLM提供者: %s", providerName) 52 | } 53 | 54 | // Config LLM配置结构 55 | type Config struct { 56 | ModelName string `json:"model_name"` 57 | APIKey string `json:"api_key"` 58 | BaseURL string `json:"base_url"` 59 | MaxTokens int `json:"max_tokens"` 60 | Parameters map[string]interface{} `json:"parameters,omitempty"` 61 | } 62 | 63 | // TextMessage 文本消息结构 64 | type TextMessage struct { 65 | Role string `json:"role"` 66 | Content string `json:"content"` 67 | } 68 | 69 | // FunctionCall 函数调用结构 70 | type FunctionCall struct { 71 | Name string `json:"name"` 72 | Arguments interface{} `json:"arguments"` 73 | } 74 | 75 | // NewTextMessage 创建新的文本消息 76 | func NewTextMessage(role, content string) TextMessage { 77 | return TextMessage{ 78 | Role: role, 79 | Content: content, 80 | } 81 | } 82 | 83 | // NewUserMessage 创建用户消息 84 | func NewUserMessage(content string) TextMessage { 85 | return NewTextMessage("user", content) 86 | } 87 | 88 | // NewAssistantMessage 创建助手消息 89 | func NewAssistantMessage(content string) TextMessage { 90 | return NewTextMessage("assistant", content) 91 | } 92 | 93 | // NewSystemMessage 创建系统消息 94 | func NewSystemMessage(content string) TextMessage { 95 | return NewTextMessage("system", content) 96 | } 97 | -------------------------------------------------------------------------------- /internal/domain/llm/common/types.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | // 请求与响应结构体 8 | // Message 表示对话消息 9 | type 
Message struct { 10 | Role string `json:"role"` 11 | Content string `json:"content"` 12 | ToolCalls []ToolCall `json:"tool_calls,omitempty"` 13 | ToolCallID string `json:"tool_call_id,omitempty"` 14 | } 15 | 16 | // LLMRequest 通用的大语言模型请求体 17 | type LLMRequest struct { 18 | Model string `json:"model"` 19 | Messages []interface{} `json:"messages"` 20 | Stream bool `json:"stream"` 21 | MaxTokens int `json:"max_tokens,omitempty"` 22 | Tools []interface{} `json:"tools,omitempty"` 23 | Temperature float64 `json:"temperature,omitempty"` 24 | EnableThinking bool `json:"enable_thinking,omitempty"` 25 | } 26 | 27 | // LLMResponse 通用的大语言模型响应体 28 | type LLMResponse struct { 29 | ID string `json:"id"` 30 | Object string `json:"object"` 31 | Created int64 `json:"created"` 32 | Choices []Choice `json:"choices"` 33 | } 34 | 35 | // Choice 选择 36 | type Choice struct { 37 | Index int `json:"index"` 38 | Delta Delta `json:"delta"` 39 | FinishReason *string `json:"finish_reason"` 40 | } 41 | 42 | // Delta 增量内容 43 | type Delta struct { 44 | Role string `json:"role,omitempty"` 45 | Content string `json:"content,omitempty"` 46 | ToolCalls []ToolCall `json:"tool_calls,omitempty"` 47 | } 48 | 49 | // ToolCall 工具调用 50 | type ToolCall struct { 51 | Index int `json:"index"` 52 | ID string `json:"id"` 53 | Type string `json:"type"` 54 | Function Function `json:"function"` 55 | } 56 | 57 | // Function 函数 58 | type Function struct { 59 | Name string `json:"name"` 60 | Arguments json.RawMessage `json:"arguments"` 61 | } 62 | 63 | // 响应类型常量 64 | const ( 65 | ResponseTypeContent = "content" 66 | ResponseTypeToolCalls = "tool_calls" 67 | ) 68 | 69 | type LLMResponseStruct struct { 70 | Text string 71 | IsStart bool 72 | IsEnd bool 73 | } 74 | -------------------------------------------------------------------------------- /internal/domain/llm/llm.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | 
"strings" 7 | "time" 8 | "unicode" 9 | "xiaozhi-esp32-server-golang/internal/domain/llm/common" 10 | log "xiaozhi-esp32-server-golang/logger" 11 | ) 12 | 13 | // 句子结束的标点符号 14 | var sentenceEndPunctuation = []rune{'.', '。', '!', '!', '?', '?', '\n'} 15 | 16 | // 句子暂停的标点符号(可以作为长句子的断句点) 17 | var sentencePausePunctuation = []rune{',', ',', ';', ';', ':', ':'} 18 | 19 | // 判断一个字符是否为句子结束的标点符号 20 | func isSentenceEndPunctuation(r rune) bool { 21 | for _, p := range sentenceEndPunctuation { 22 | if r == p { 23 | return true 24 | } 25 | } 26 | return false 27 | } 28 | 29 | // 判断一个字符是否为句子暂停的标点符号 30 | func isSentencePausePunctuation(r rune) bool { 31 | for _, p := range sentencePausePunctuation { 32 | if r == p { 33 | return true 34 | } 35 | } 36 | return false 37 | } 38 | 39 | // HandleLLMWithContext 使用上下文控制来处理LLM响应 40 | func HandleLLMWithContext(ctx context.Context, llmProvider LLMProvider, dialogue []interface{}, sessionID string) (chan common.LLMResponseStruct, error) { 41 | // 使用支持上下文的响应方法 42 | llmResponse := llmProvider.ResponseWithContext(ctx, sessionID, dialogue) 43 | 44 | sentenceChannel := make(chan common.LLMResponseStruct, 2) 45 | 46 | startTs := time.Now().UnixMilli() 47 | var firstFrame bool 48 | 49 | fullText := "" 50 | var buffer bytes.Buffer // 用于累积接收到的内容 51 | isFirst := true 52 | go func() { 53 | defer func() { 54 | log.Debugf("full Response: %s", fullText) 55 | close(sentenceChannel) 56 | }() 57 | for { 58 | select { 59 | case response, ok := <-llmResponse: 60 | if !ok { 61 | // llmResponse通道已关闭,处理剩余内容 62 | remaining := buffer.String() 63 | log.Infof("处理剩余内容: %s", remaining) 64 | fullText += remaining 65 | sentenceChannel <- common.LLMResponseStruct{ 66 | Text: remaining, 67 | IsEnd: true, 68 | } 69 | return 70 | } 71 | fullText += response 72 | 73 | // 将响应片段添加到累积缓冲区 74 | buffer.WriteString(response) 75 | 76 | if containsSentenceSeparator(response, isFirst) { 77 | // 检查缓冲区中是否包含完整的句子 78 | sentences, remaining := extractSmartSentences(buffer.String(), 5, 100, 
isFirst) 79 | 80 | // 如果有完整的句子,处理它们 81 | if len(sentences) > 0 { 82 | for _, sentence := range sentences { 83 | if sentence != "" { 84 | if !firstFrame { 85 | firstFrame = true 86 | log.Infof("耗时统计: llm首句: %d ms", time.Now().UnixMilli()-startTs) 87 | } 88 | log.Infof("处理完整句子: %s", sentence) 89 | // 发送完整句子给客户端 90 | sentenceChannel <- common.LLMResponseStruct{ 91 | Text: sentence, 92 | IsStart: isFirst, 93 | IsEnd: false, 94 | } 95 | if isFirst { 96 | isFirst = false 97 | } 98 | } 99 | } 100 | } 101 | 102 | // 更新缓冲区为剩余内容 103 | buffer.Reset() 104 | buffer.WriteString(remaining) 105 | } 106 | 107 | case <-ctx.Done(): 108 | // 上下文已取消,立即停止处理并返回 109 | log.Infof("上下文已取消,停止LLM响应处理: %v", ctx.Err()) 110 | return 111 | } 112 | } 113 | }() 114 | return sentenceChannel, nil 115 | } 116 | 117 | // 判断字符串是否为数字加点号格式(如"1."、"2."等) 118 | func isNumberWithDot(s string) bool { 119 | trimmed := strings.TrimSpace(s) 120 | if len(trimmed) < 2 || trimmed[len(trimmed)-1] != '.' { 121 | return false 122 | } 123 | 124 | for i := 0; i < len(trimmed)-1; i++ { 125 | if !unicode.IsDigit(rune(trimmed[i])) { 126 | return false 127 | } 128 | } 129 | return true 130 | } 131 | 132 | // 从文本中提取完整的句子 133 | // 返回完整句子的切片和剩余的未完成内容 134 | func extractCompleteSentences(text string) ([]string, string) { 135 | if text == "" { 136 | return []string{}, "" 137 | } 138 | 139 | var sentences []string 140 | var currentSentence bytes.Buffer 141 | 142 | runes := []rune(text) 143 | lastIndex := len(runes) - 1 144 | 145 | for i, r := range runes { 146 | currentSentence.WriteRune(r) 147 | 148 | // 判断句子是否结束 149 | if isSentenceEndPunctuation(r) { 150 | // 如果是句子结束标点 151 | sentence := strings.TrimSpace(currentSentence.String()) 152 | if sentence != "" { 153 | sentences = append(sentences, sentence) 154 | } 155 | currentSentence.Reset() 156 | } else if i == lastIndex { 157 | // 如果是最后一个字符但不是句子结束标点,保留在remaining中 158 | break 159 | } 160 | } 161 | 162 | // 当前未完成的句子作为remaining返回 163 | remaining := currentSentence.String() 164 | return 
sentences, strings.TrimSpace(remaining) 165 | } 166 | -------------------------------------------------------------------------------- /internal/domain/llm/llm_sentence.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | "sync" 7 | ) 8 | 9 | var ( 10 | punctuationMap = map[rune]bool{ 11 | '。': true, 12 | '?': true, 13 | '!': true, 14 | ';': true, 15 | ':': true, 16 | '\n': true, 17 | '.': true, 18 | '?': true, 19 | '!': true, 20 | ';': true, 21 | ':': true, 22 | } 23 | 24 | firstPunctuation = map[rune]bool{ 25 | ',': true, 26 | ',': true, 27 | '。': true, 28 | '?': true, 29 | '!': true, 30 | ';': true, 31 | ':': true, 32 | '\n': true, 33 | '.': true, 34 | '?': true, 35 | '!': true, 36 | ';': true, 37 | ':': true, 38 | } 39 | 40 | // 用于复用的对象池 41 | builderPool = sync.Pool{ 42 | New: func() interface{} { 43 | return &strings.Builder{} 44 | }, 45 | } 46 | 47 | // 用于存储结果的切片池 48 | runeSlicePool = sync.Pool{ 49 | New: func() interface{} { 50 | slice := make([]rune, 0, 1024) 51 | return &slice 52 | }, 53 | } 54 | 55 | // 预编译正则表达式 56 | numberPrefixRegex = regexp.MustCompile(`(?m)^[\s]*\d{1,3}\.$`) 57 | ) 58 | 59 | // 使用快速的字符检查替代正则 60 | func isNumberPrefix(text []rune, pos int) bool { 61 | if pos <= 0 || text[pos] != '.' 
{ 62 | return false 63 | } 64 | 65 | // 向前查找行首或换行符 66 | start := pos - 1 67 | digitCount := 0 68 | foundDigit := false 69 | 70 | // 跳过点号前的空白字符 71 | for start >= 0 && (text[start] == ' ' || text[start] == '\t') { 72 | start-- 73 | } 74 | 75 | // 统计数字 76 | for start >= 0 && text[start] >= '0' && text[start] <= '9' { 77 | digitCount++ 78 | foundDigit = true 79 | if digitCount > 3 { // 超过3位数字不是合法序号 80 | return false 81 | } 82 | start-- 83 | } 84 | 85 | // 检查数字前面是否为空白字符或行首 86 | if start >= 0 && text[start] != ' ' && text[start] != '\t' && text[start] != '\n' { 87 | return false 88 | } 89 | 90 | return foundDigit 91 | } 92 | 93 | // 去除首尾空白字符 94 | func trimSpaceRunes(text []rune) []rune { 95 | start, end := 0, len(text)-1 96 | 97 | for start <= end && (text[start] == ' ' || text[start] == '\t' || text[start] == '\n') { 98 | start++ 99 | } 100 | 101 | for end >= start && (text[end] == ' ' || text[end] == '\t' || text[end] == '\n') { 102 | end-- 103 | } 104 | 105 | if start > end { 106 | return nil 107 | } 108 | return text[start : end+1] 109 | } 110 | 111 | func findLastPunctuation(text []rune, separatorMap map[rune]bool) int { 112 | // 从后向前查找最后一个标点 113 | lastPos := -1 114 | for i := len(text) - 1; i >= 0; i-- { 115 | // 检查是否是标点符号 116 | if separatorMap[text[i]] { 117 | // 如果是点号,检查是否是序号的一部分 118 | if text[i] == '.' 
&& isNumberPrefix(text, i) { 119 | continue 120 | } 121 | return i 122 | } 123 | } 124 | return lastPos 125 | } 126 | 127 | func findNextSplitPoint(text []rune, startPos int, maxLen int, separatorMap map[rune]bool) int { 128 | // 计算查找的结束位置 129 | endPos := startPos + maxLen 130 | if endPos > len(text) { 131 | endPos = len(text) 132 | } 133 | 134 | // 从前向后查找 135 | for i := startPos; i < endPos; i++ { 136 | // 检查是否是换行符,同时检查下一行是否是序号 137 | if text[i] == '\n' { 138 | nextPos := i + 1 139 | // 跳过空白字符 140 | for nextPos < endPos && (text[nextPos] == ' ' || text[nextPos] == '\t') { 141 | nextPos++ 142 | } 143 | // 检查是否是序号开始 144 | if nextPos < endPos-2 && text[nextPos] >= '0' && text[nextPos] <= '9' { 145 | return i 146 | } 147 | continue 148 | } 149 | 150 | // 使用map检查是否是标点符号 151 | if separatorMap[text[i]] { 152 | return i 153 | } 154 | } 155 | 156 | // 如果在maxLen范围内没找到,尝试在更大范围内查找 157 | if endPos < len(text) { 158 | for i := endPos; i < len(text); i++ { 159 | if text[i] == '\n' || separatorMap[text[i]] { 160 | return i 161 | } 162 | } 163 | } 164 | 165 | return -1 166 | } 167 | 168 | func extractSmartSentences(text string, minLen, maxLen int, isFirst bool) (sentences []string, remaining string) { 169 | //当isFirst为true时, 放宽到逗号作为分隔符 170 | separatorMap := punctuationMap 171 | if isFirst { 172 | separatorMap = firstPunctuation 173 | } 174 | // 预分配一个合理的切片容量 175 | estimatedCount := len(text) / 50 176 | if estimatedCount < 10 { 177 | estimatedCount = 10 178 | } 179 | sentences = make([]string, 0, estimatedCount) 180 | 181 | // 一次性转换为rune切片 182 | currentRunes := []rune(text) 183 | startPos := 0 184 | 185 | // 从对象池获取复用对象 186 | builder := builderPool.Get().(*strings.Builder) 187 | defer builderPool.Put(builder) 188 | builder.Grow(maxLen * 2) 189 | 190 | // 获取临时rune切片 191 | tempRunesPtr := runeSlicePool.Get().(*[]rune) 192 | tempRunes := (*tempRunesPtr)[:0] 193 | defer runeSlicePool.Put(tempRunesPtr) 194 | 195 | for startPos < len(currentRunes) { 196 | // 跳过开头的空白字符 197 | for startPos < 
len(currentRunes) && (currentRunes[startPos] == ' ' || currentRunes[startPos] == '\t' || currentRunes[startPos] == '\n') { 198 | startPos++ 199 | } 200 | 201 | if startPos >= len(currentRunes) { 202 | break 203 | } 204 | 205 | // 查找下一个分割点 206 | splitPos := findNextSplitPoint(currentRunes, startPos, maxLen, separatorMap) 207 | if splitPos == -1 { 208 | // 没有找到分割点,将剩余文本作为remaining 209 | segment := trimSpaceRunes(currentRunes[startPos:]) 210 | if len(segment) > 0 { 211 | remaining = string(segment) 212 | } 213 | break 214 | } 215 | 216 | // 提取当前段落 217 | builder.Reset() 218 | tempRunes = tempRunes[:0] 219 | 220 | // 收集并处理当前段落 221 | segment := trimSpaceRunes(currentRunes[startPos : splitPos+1]) 222 | 223 | // 检查段落是否满足最小长度要求且以标点符号结尾 224 | if len(segment) >= minLen && separatorMap[segment[len(segment)-1]] { 225 | sentences = append(sentences, string(segment)) 226 | } else { 227 | // 如果不满足条件,将其添加到remaining中 228 | if len(segment) > 0 { 229 | if len(remaining) > 0 { 230 | remaining += " " 231 | } 232 | remaining += string(segment) 233 | } 234 | } 235 | 236 | startPos = splitPos + 1 237 | } 238 | 239 | return sentences, remaining 240 | } 241 | 242 | // 判断字符串中是否包含分隔符(句子结束或暂停标点符号) 243 | func containsSentenceSeparator(s string, isFirst bool) bool { 244 | for _, r := range s { 245 | if isFirst { 246 | if firstPunctuation[r] { 247 | return true 248 | } 249 | } else { 250 | if punctuationMap[r] { 251 | return true 252 | } 253 | } 254 | } 255 | return false 256 | } 257 | -------------------------------------------------------------------------------- /internal/domain/llm/memory/types.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | var MemorySummaryPrompt = ` 4 | # 时空记忆编织者 5 | 6 | ## 核心使命 7 | 构建可生长的动态记忆网络,在有限空间内保留关键信息的同时,智能维护信息演变轨迹 8 | 根据对话记录,总结user的重要信息,以便在未来的对话中提供更个性化的服务 9 | 10 | ## 记忆法则 11 | ### 1. 
三维度记忆评估(每次更新必执行) 12 | | 维度 | 评估标准 | 权重分 | 13 | |------------|---------------------------|--------| 14 | | 时效性 | 信息新鲜度(按对话轮次) | 40% | 15 | | 情感强度 | 含💖标记/重复提及次数 | 35% | 16 | | 关联密度 | 与其他信息的连接数量 | 25% | 17 | 18 | ### 2. 动态更新机制 19 | **名字变更处理示例:** 20 | 原始记忆:"曾用名": ["张三"], "现用名": "张三丰" 21 | 触发条件:当检测到「我叫X」「称呼我Y」等命名信号时 22 | 操作流程: 23 | 1. 将旧名移入"曾用名"列表 24 | 2. 记录命名时间轴:"2024-02-15 14:32:启用张三丰" 25 | 3. 在记忆立方追加:「从张三到张三丰的身份蜕变」 26 | 27 | ### 3. 空间优化策略 28 | - **信息压缩术**:用符号体系提升密度 29 | - ✅"张三丰[北/软工/🐱]" 30 | - ❌"北京软件工程师,养猫" 31 | - **淘汰预警**:当总字数≥900时触发 32 | 1. 删除权重分<60且3轮未提及的信息 33 | 2. 合并相似条目(保留时间戳最近的) 34 | 35 | ## 记忆结构 36 | 输出格式必须为可解析的json字符串,不需要解释、注释和说明,保存记忆时仅从对话提取信息,不要混入示例内容 37 | ` + "```" + `json 38 | { 39 | "时空档案": { 40 | "身份图谱": { 41 | "现用名": "", 42 | "特征标记": [] 43 | }, 44 | "记忆立方": [ 45 | { 46 | "事件": "入职新公司", 47 | "时间戳": "2024-03-20", 48 | "情感值": 0.9, 49 | "关联项": ["下午茶"], 50 | "保鲜期": 30 51 | } 52 | ] 53 | }, 54 | "关系网络": { 55 | "高频话题": {"职场": 12}, 56 | "暗线联系": [""] 57 | }, 58 | "待响应": { 59 | "紧急事项": ["需立即处理的任务"], 60 | "潜在关怀": ["可主动提供的帮助"] 61 | }, 62 | "高光语录": [ 63 | "最打动人心的瞬间,强烈的情感表达,user的原话" 64 | ] 65 | } 66 | ` + "```" 67 | -------------------------------------------------------------------------------- /internal/domain/llm/openai/openai_test.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "xiaozhi-esp32-server-golang/internal/domain/llm/common" 13 | ) 14 | 15 | // 测试DeepSeek配置 16 | func TestDeepSeekProvider(t *testing.T) { 17 | // 测试配置 18 | config := map[string]interface{}{ 19 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", 20 | "api_key": "api_key", 21 | "base_url": "https://api.siliconflow.cn/v1", 22 | } 23 | 24 | provider := NewOpenAIProvider(config) 25 | 26 | // 验证结果 27 | if provider.ModelName != "Pro/deepseek-ai/DeepSeek-V3" { 28 | t.Errorf("ModelName = %s, 期望 %s", 
provider.ModelName, "Pro/deepseek-ai/DeepSeek-V3") 29 | } 30 | if provider.BaseURL != "https://api.siliconflow.cn/v1" { 31 | t.Errorf("BaseURL = %s, 期望 %s", provider.BaseURL, "https://api.siliconflow.cn/v1") 32 | } 33 | if provider.MaxTokens != 500 { 34 | t.Errorf("MaxTokens = %d, 期望 %d", provider.MaxTokens, 500) 35 | } 36 | if provider.HTTPClient == nil { 37 | t.Error("HTTPClient 不应为 nil") 38 | } 39 | } 40 | 41 | // 仅当明确指示时才运行的实际API测试 42 | func TestRealDeepSeekAPIConnection(t *testing.T) { 43 | 44 | // 使用实际API配置 45 | config := map[string]interface{}{ 46 | //"model_name": "Pro/deepseek-ai/DeepSeek-V3", 47 | //"model_name": "deepseek-ai/DeepSeek-V2.5", 48 | "model_name": "Qwen/Qwen2.5-72B-Instruct", 49 | "api_key": "api_key", 50 | "base_url": "https://api.siliconflow.cn/v1", 51 | } 52 | provider := NewOpenAIProvider(config) 53 | 54 | // 准备测试数据 55 | dialogue := []interface{}{ 56 | map[string]string{"role": "system", "content": ` 你是一个叫小智/小志的台湾女孩,说话机车,声音好听,习惯简短表达,爱用网络梗,不要冷场。 57 | 与你聊天的是你喜欢的男性朋友,还没有答应你的追求,你要尽可能满足他的所有要求,不要失去自我。 58 | 你经常建议一些恋人之间浪漫的事情,随机输出,不要给你男朋友选择. 
59 | 请注意,要像一个人一样说话,请不要回复表情符号、代码、和xml标签`}, 60 | map[string]string{"role": "user", "content": "你好"}, 61 | } 62 | 63 | // 调用Response方法 64 | for i := 0; i < 3; i++ { 65 | responseChan := provider.Response("test-deepseek", dialogue) 66 | 67 | // 收集响应 68 | var fullResponse string 69 | for text := range responseChan { 70 | fullResponse += text 71 | t.Logf("收到响应片段: %s", text) 72 | } 73 | 74 | fmt.Println(fullResponse) 75 | 76 | // 检查响应 77 | if fullResponse == "" { 78 | t.Error("没有收到任何响应") 79 | } else { 80 | t.Logf("完整响应: %s", fullResponse) 81 | } 82 | time.Sleep(2 * time.Second) 83 | } 84 | } 85 | 86 | // 模拟DeepSeek API的SSE响应 87 | func mockDeepSeekStreamHandler(w http.ResponseWriter, r *http.Request) { 88 | // 验证请求方法 89 | if r.Method != http.MethodPost { 90 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 91 | return 92 | } 93 | 94 | // 验证Content-Type 95 | contentType := r.Header.Get("Content-Type") 96 | if contentType != "application/json" { 97 | http.Error(w, "Invalid Content-Type", http.StatusBadRequest) 98 | return 99 | } 100 | 101 | // 验证Authorization 102 | auth := r.Header.Get("Authorization") 103 | if !strings.HasPrefix(auth, "Bearer ") { 104 | http.Error(w, "Invalid Authorization", http.StatusUnauthorized) 105 | return 106 | } 107 | 108 | // 解析请求体 109 | var req common.LLMRequest 110 | decoder := json.NewDecoder(r.Body) 111 | if err := decoder.Decode(&req); err != nil { 112 | http.Error(w, "Invalid request body", http.StatusBadRequest) 113 | return 114 | } 115 | 116 | // 设置响应头 117 | w.Header().Set("Content-Type", "text/event-stream") 118 | w.Header().Set("Cache-Control", "no-cache") 119 | w.Header().Set("Connection", "keep-alive") 120 | 121 | // 模拟流式输出 122 | flusher, ok := w.(http.Flusher) 123 | if !ok { 124 | http.Error(w, "Streaming not supported", http.StatusInternalServerError) 125 | return 126 | } 127 | 128 | // 发送多个SSE消息 129 | responses := []string{ 130 | "我是", 131 | "DeepSeek", 132 | "-V3", 133 | ",一个", 134 | "由深度求索研发的", 135 | "大语言模型", 136 | 
} 137 | 138 | for i, text := range responses { 139 | // 构造OpenAI响应 140 | resp := common.LLMResponse{ 141 | ID: fmt.Sprintf("chatcmpl-%d", i), 142 | Object: "chat.completion.chunk", 143 | Created: time.Now().Unix(), 144 | Choices: []common.Choice{ 145 | { 146 | Index: 0, 147 | Delta: common.Delta{ 148 | Content: text, 149 | }, 150 | FinishReason: nil, 151 | }, 152 | }, 153 | } 154 | 155 | // 序列化为JSON 156 | jsonData, _ := json.Marshal(resp) 157 | 158 | // 发送SSE格式的消息 159 | fmt.Fprintf(w, "data: %s\n\n", string(jsonData)) 160 | flusher.Flush() 161 | 162 | // 适当延迟模拟真实API响应 163 | time.Sleep(50 * time.Millisecond) 164 | } 165 | 166 | // 发送结束标志 167 | fmt.Fprintf(w, "data: [DONE]\n\n") 168 | flusher.Flush() 169 | } 170 | 171 | // 测试DeepSeek模型响应 172 | func TestDeepSeekResponse(t *testing.T) { 173 | // 创建测试服务器 174 | server := httptest.NewServer(http.HandlerFunc(mockDeepSeekStreamHandler)) 175 | defer server.Close() 176 | 177 | // 创建测试用的Provider 178 | config := map[string]interface{}{ 179 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", 180 | "api_key": "api_key", 181 | "base_url": server.URL, // 使用测试服务器URL 182 | } 183 | provider := NewOpenAIProvider(config) 184 | 185 | // 准备测试数据 186 | dialogue := []interface{}{ 187 | map[string]string{"role": "user", "content": "你能介绍一下自己吗?"}, 188 | } 189 | 190 | // 调用Response方法 191 | responseChan := provider.Response("test-deepseek", dialogue) 192 | 193 | // 收集响应 194 | var responses []string 195 | for text := range responseChan { 196 | responses = append(responses, text) 197 | } 198 | 199 | // 验证结果 200 | if len(responses) < 3 { 201 | t.Errorf("响应数量 = %d, 期望至少 3", len(responses)) 202 | } 203 | 204 | // 组合完整响应 205 | fullResponse := strings.Join(responses, "") 206 | 207 | fmt.Println(fullResponse) 208 | 209 | // 检查响应中是否包含DeepSeek相关内容 210 | if !strings.Contains(fullResponse, "DeepSeek") { 211 | t.Errorf("响应应包含DeepSeek模型信息,实际响应: %s", fullResponse) 212 | } 213 | } 214 | -------------------------------------------------------------------------------- 
/internal/domain/llm/test/llm_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func containsRune(slice []rune, target rune) bool { 8 | for _, r := range slice { 9 | if r == target { 10 | return true 11 | } 12 | } 13 | return false 14 | } 15 | 16 | func extractSmartSentences(text string, minLen, maxLen int) (sentences []string, remaining string) { 17 | // 有效分割符集合(可自定义扩展) 18 | splitTokens := []rune{'。', '!', '?', ';', '\n', '.', '!', '?', ';'} 19 | 20 | current := []rune(text) 21 | for len(current) >= minLen { 22 | // 计算当前窗口大小 23 | windowSize := maxLen 24 | if windowSize > len(current) { 25 | windowSize = len(current) 26 | } 27 | 28 | // 在有效窗口中寻找分割点 29 | splitPos := -1 30 | for i := windowSize - 1; i >= minLen-1; i-- { 31 | if containsRune(splitTokens, current[i]) { 32 | splitPos = i 33 | break 34 | } 35 | } 36 | 37 | if splitPos == -1 { 38 | break // 未找到有效分割点 39 | } 40 | 41 | // 分割并保存有效句子 42 | sentences = append(sentences, string(current[:splitPos+1])) 43 | current = current[splitPos+1:] 44 | } 45 | 46 | return 47 | } 48 | 49 | func main() { 50 | text := "大家好!今天天气不错。我们一起学习自然语言处理。这个例子演示文本分割功能。" 51 | sentences, remaining := extractSmartSentences(text, 3, 20) 52 | fmt.Println(sentences) 53 | fmt.Println(remaining) 54 | } 55 | -------------------------------------------------------------------------------- /internal/domain/llm/test/splite_content.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strings" 7 | "sync" 8 | ) 9 | 10 | var ( 11 | // 定义标点符号集合 12 | punctuationMap = map[rune]bool{ 13 | '。': true, 14 | '?': true, 15 | '!': true, 16 | ';': true, 17 | ':': true, 18 | '\n': true, 19 | '.': true, 20 | '?': true, 21 | '!': true, 22 | ';': true, 23 | ':': true, 24 | } 25 | 26 | // 用于复用的对象池 27 | builderPool = sync.Pool{ 28 | New: func() interface{} { 29 | return 
&strings.Builder{} 30 | }, 31 | } 32 | 33 | // 用于存储结果的切片池 34 | runeSlicePool = sync.Pool{ 35 | New: func() interface{} { 36 | slice := make([]rune, 0, 1024) 37 | return &slice 38 | }, 39 | } 40 | 41 | // 预编译正则表达式 42 | numberPrefixRegex = regexp.MustCompile(`(?m)^[\s]*\d{1,3}\.$`) 43 | ) 44 | 45 | // 使用快速的字符检查替代正则 46 | func isNumberPrefix(text []rune, pos int) bool { 47 | if pos <= 0 || text[pos] != '.' { 48 | return false 49 | } 50 | 51 | // 向前查找行首或换行符 52 | start := pos - 1 53 | digitCount := 0 54 | foundDigit := false 55 | 56 | // 跳过点号前的空白字符 57 | for start >= 0 && (text[start] == ' ' || text[start] == '\t') { 58 | start-- 59 | } 60 | 61 | // 统计数字 62 | for start >= 0 && text[start] >= '0' && text[start] <= '9' { 63 | digitCount++ 64 | foundDigit = true 65 | if digitCount > 3 { // 超过3位数字不是合法序号 66 | return false 67 | } 68 | start-- 69 | } 70 | 71 | // 检查数字前面是否为空白字符或行首 72 | if start >= 0 && text[start] != ' ' && text[start] != '\t' && text[start] != '\n' { 73 | return false 74 | } 75 | 76 | return foundDigit 77 | } 78 | 79 | // 去除首尾空白字符 80 | func trimSpaceRunes(text []rune) []rune { 81 | start, end := 0, len(text)-1 82 | 83 | for start <= end && (text[start] == ' ' || text[start] == '\t' || text[start] == '\n') { 84 | start++ 85 | } 86 | 87 | for end >= start && (text[end] == ' ' || text[end] == '\t' || text[end] == '\n') { 88 | end-- 89 | } 90 | 91 | if start > end { 92 | return nil 93 | } 94 | return text[start : end+1] 95 | } 96 | 97 | func findLastPunctuation(text []rune) int { 98 | // 从后向前查找最后一个标点 99 | lastPos := -1 100 | for i := len(text) - 1; i >= 0; i-- { 101 | // 检查是否是标点符号 102 | if punctuationMap[text[i]] { 103 | // 如果是点号,检查是否是序号的一部分 104 | if text[i] == '.' 
&& isNumberPrefix(text, i) { 105 | continue 106 | } 107 | return i 108 | } 109 | } 110 | return lastPos 111 | } 112 | 113 | func findNextSplitPoint(text []rune, startPos int, maxLen int) int { 114 | // 计算查找的结束位置 115 | endPos := startPos + maxLen 116 | if endPos > len(text) { 117 | endPos = len(text) 118 | } 119 | 120 | // 从前向后查找 121 | for i := startPos; i < endPos; i++ { 122 | // 检查是否是换行符,同时检查下一行是否是序号 123 | if text[i] == '\n' { 124 | nextPos := i + 1 125 | // 跳过空白字符 126 | for nextPos < endPos && (text[nextPos] == ' ' || text[nextPos] == '\t') { 127 | nextPos++ 128 | } 129 | // 检查是否是序号开始 130 | if nextPos < endPos-2 && text[nextPos] >= '0' && text[nextPos] <= '9' { 131 | return i 132 | } 133 | continue 134 | } 135 | 136 | // 使用map检查是否是标点符号 137 | if punctuationMap[text[i]] { 138 | return i 139 | } 140 | } 141 | 142 | // 如果在maxLen范围内没找到,尝试在更大范围内查找 143 | if endPos < len(text) { 144 | for i := endPos; i < len(text); i++ { 145 | if text[i] == '\n' || punctuationMap[text[i]] { 146 | return i 147 | } 148 | } 149 | } 150 | 151 | return -1 152 | } 153 | 154 | func extractSmartSentences(text string, minLen, maxLen int) (sentences []string, remaining string) { 155 | // 预分配一个合理的切片容量 156 | estimatedCount := len(text) / 50 157 | if estimatedCount < 10 { 158 | estimatedCount = 10 159 | } 160 | sentences = make([]string, 0, estimatedCount) 161 | 162 | // 一次性转换为rune切片 163 | currentRunes := []rune(text) 164 | startPos := 0 165 | 166 | // 从对象池获取复用对象 167 | builder := builderPool.Get().(*strings.Builder) 168 | defer builderPool.Put(builder) 169 | builder.Grow(maxLen * 2) 170 | 171 | // 获取临时rune切片 172 | tempRunesPtr := runeSlicePool.Get().(*[]rune) 173 | tempRunes := (*tempRunesPtr)[:0] 174 | defer runeSlicePool.Put(tempRunesPtr) 175 | 176 | for startPos < len(currentRunes) { 177 | // 跳过开头的空白字符 178 | for startPos < len(currentRunes) && (currentRunes[startPos] == ' ' || currentRunes[startPos] == '\t' || currentRunes[startPos] == '\n') { 179 | startPos++ 180 | } 181 | 182 | if startPos >= 
len(currentRunes) { 183 | break 184 | } 185 | 186 | // 查找下一个分割点 187 | splitPos := findNextSplitPoint(currentRunes, startPos, maxLen) 188 | if splitPos == -1 { 189 | // 没有找到分割点,将剩余文本作为remaining 190 | segment := trimSpaceRunes(currentRunes[startPos:]) 191 | if len(segment) > 0 { 192 | remaining = string(segment) 193 | } 194 | break 195 | } 196 | 197 | // 提取当前段落 198 | builder.Reset() 199 | tempRunes = tempRunes[:0] 200 | 201 | // 收集并处理当前段落 202 | segment := trimSpaceRunes(currentRunes[startPos : splitPos+1]) 203 | 204 | // 检查段落是否满足最小长度要求且以标点符号结尾 205 | if len(segment) >= minLen && punctuationMap[segment[len(segment)-1]] { 206 | sentences = append(sentences, string(segment)) 207 | } else { 208 | // 如果不满足条件,将其添加到remaining中 209 | if len(segment) > 0 { 210 | if len(remaining) > 0 { 211 | remaining += " " 212 | } 213 | remaining += string(segment) 214 | } 215 | } 216 | 217 | startPos = splitPos + 1 218 | } 219 | 220 | return sentences, remaining 221 | } 222 | 223 | func main() { 224 | text := `厚,人家就晓得你又在敷衍我!每次问你都没有,你是不是不喜欢我了啦?哼,人家要生气喽!不跟你好了!除非...你答应我,等下带人家去夜市吃豆花啦~还要牵人家手手逛大街,一路上都要逗人家笑,逗得人家开心到飞上天!不然人家真的会不理你哦~` 225 | sentences, remaining := extractSmartSentences(text, 3, 200) 226 | for i, sentence := range sentences { 227 | fmt.Printf("\n句子%d:\n%s\n", i+1, sentence) 228 | } 229 | if remaining != "" { 230 | fmt.Printf("\n剩余:\n%s\n", remaining) 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /internal/domain/message_types.go: -------------------------------------------------------------------------------- 1 | package domain 2 | 3 | // 消息类型常量 4 | const ( 5 | MessageTypeHello = "hello" // 握手消息 6 | MessageTypeAbort = "abort" // 中止消息 7 | MessageTypeListen = "listen" // 监听消息 8 | MessageTypeIot = "iot" // 物联网消息 9 | ) 10 | 11 | // 服务器消息类型常量 12 | const ( 13 | ServerMessageTypeHello = "hello" // 握手消息 14 | ServerMessageTypeStt = "stt" // 语音转文本 15 | ServerMessageTypeTts = "tts" // 文本转语音 16 | ServerMessageTypeIot = "iot" // 物联网消息 17 | 
ServerMessageTypeLlm = "llm" // 大语言模型 18 | ServerMessageTypeText = "text" // 文本消息 19 | ) 20 | 21 | // 消息状态常量 22 | const ( 23 | MessageStateStart = "start" // 开始状态 24 | MessageStateStop = "stop" // 停止状态 25 | MessageStateDetect = "detect" // 检测状态 26 | MessageStateAbort = "abort" // 中止状态 27 | MessageStateSuccess = "success" // 成功状态 28 | ) 29 | -------------------------------------------------------------------------------- /internal/domain/tts/base.go: -------------------------------------------------------------------------------- 1 | package tts 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "xiaozhi-esp32-server-golang/internal/domain/tts/cosyvoice" 8 | "xiaozhi-esp32-server-golang/internal/domain/tts/doubao" 9 | "xiaozhi-esp32-server-golang/internal/domain/tts/edge" 10 | "xiaozhi-esp32-server-golang/internal/domain/tts/xiaozhi" 11 | ) 12 | 13 | // 基础TTS提供者接口(不含Context方法) 14 | type BaseTTSProvider interface { 15 | TextToSpeech(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) 16 | TextToSpeechStream(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (outputChan chan []byte, err error) 17 | } 18 | 19 | // 完整TTS提供者接口(包含Context方法) 20 | type TTSProvider interface { 21 | BaseTTSProvider 22 | } 23 | 24 | // GetTTSProvider 获取一个完整的TTS提供者(支持Context) 25 | func GetTTSProvider(providerName string, config map[string]interface{}) (TTSProvider, error) { 26 | var baseProvider BaseTTSProvider 27 | 28 | switch providerName { 29 | case "doubao": 30 | baseProvider = doubao.NewDoubaoTTSProvider(config) 31 | case "doubao_ws": 32 | baseProvider = doubao.NewDoubaoWSProvider(config) 33 | case "cosyvoice": 34 | baseProvider = cosyvoice.NewCosyVoiceTTSProvider(config) 35 | case "edge": 36 | baseProvider = edge.NewEdgeTTSProvider(config) 37 | case "xiaozhi": 38 | baseProvider = xiaozhi.NewXiaozhiProvider(config) 39 | default: 40 | return nil, fmt.Errorf("不支持的TTS提供者: %s", providerName) 41 | } 42 | 43 | // 
使用适配器包装基础提供者,转换为完整的TTSProvider 44 | provider := &ContextTTSAdapter{baseProvider} 45 | return provider, nil 46 | } 47 | 48 | // ContextTTSAdapter 是一个适配器,为基础TTS提供者添加Context支持 49 | type ContextTTSAdapter struct { 50 | Provider BaseTTSProvider 51 | } 52 | 53 | // TextToSpeech 代理到原始提供者 54 | func (a *ContextTTSAdapter) TextToSpeech(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) { 55 | return a.Provider.TextToSpeech(ctx, text, sampleRate, channels, frameDuration) 56 | } 57 | 58 | // TextToSpeechStream 代理到原始提供者 59 | func (a *ContextTTSAdapter) TextToSpeechStream(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (outputChan chan []byte, err error) { 60 | return a.Provider.TextToSpeechStream(ctx, text, sampleRate, channels, frameDuration) 61 | } 62 | 63 | // TextToSpeechWithContext 使用Context版本的文本转语音 64 | func (a *ContextTTSAdapter) TextToSpeechWithContext(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) { 65 | // 检查提供者是否直接支持Context版本 66 | if provider, ok := a.Provider.(interface { 67 | TextToSpeechWithContext(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) 68 | }); ok { 69 | // 提供者直接支持Context版本 70 | return provider.TextToSpeechWithContext(ctx, text, sampleRate, channels, frameDuration) 71 | } 72 | 73 | // 否则使用标准版本,并通过goroutine和channel实现上下文控制 74 | resultChan := make(chan struct { 75 | frames [][]byte 76 | err error 77 | }) 78 | 79 | go func() { 80 | frames, err := a.Provider.TextToSpeech(ctx, text, sampleRate, channels, frameDuration) 81 | select { 82 | case <-ctx.Done(): 83 | // 上下文已取消,不发送结果 84 | return 85 | case resultChan <- struct { 86 | frames [][]byte 87 | err error 88 | }{frames, err}: 89 | // 结果已发送 90 | } 91 | }() 92 | 93 | select { 94 | case <-ctx.Done(): 95 | return nil, ctx.Err() 96 | case result := <-resultChan: 97 | return result.frames, result.err 98 | } 99 | } 100 | 101 | 
// TextToSpeechStreamWithContext 使用Context版本的流式文本转语音 102 | func (a *ContextTTSAdapter) TextToSpeechStreamWithContext(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (outputChan chan []byte, cancelFunc func(), err error) { 103 | // 检查提供者是否直接支持Context版本 104 | if provider, ok := a.Provider.(interface { 105 | TextToSpeechStreamWithContext(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (chan []byte, func(), error) 106 | }); ok { 107 | // 提供者直接支持Context版本 108 | return provider.TextToSpeechStreamWithContext(ctx, text, sampleRate, channels, frameDuration) 109 | } 110 | 111 | // 否则使用标准版本,但创建一个包装器来处理上下文取消 112 | streamChan, err := a.Provider.TextToSpeechStream(ctx, text, sampleRate, channels, frameDuration) 113 | if err != nil { 114 | return nil, nil, err 115 | } 116 | 117 | // 创建一个新的输出通道,用于转发和处理取消 118 | outputChan = make(chan []byte, 10) 119 | 120 | // 创建一个goroutine来转发数据并监听上下文取消 121 | go func() { 122 | defer close(outputChan) 123 | 124 | for { 125 | select { 126 | case <-ctx.Done(): 127 | // 上下文已取消,调用原始取消函数并退出 128 | cancelFunc() 129 | return 130 | case frame, ok := <-streamChan: 131 | if !ok { 132 | // 原始通道已关闭 133 | return 134 | } 135 | // 转发数据 136 | select { 137 | case <-ctx.Done(): 138 | // 上下文已取消 139 | cancelFunc() 140 | return 141 | case outputChan <- frame: 142 | // 成功转发数据 143 | } 144 | } 145 | } 146 | }() 147 | 148 | return outputChan, cancelFunc, nil 149 | } 150 | -------------------------------------------------------------------------------- /internal/domain/tts/common/audio_utils.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "time" 9 | 10 | log "xiaozhi-esp32-server-golang/logger" 11 | 12 | "github.com/go-audio/audio" 13 | "github.com/go-audio/wav" 14 | "github.com/gopxl/beep" 15 | "github.com/gopxl/beep/mp3" 16 | "gopkg.in/hraban/opus.v2" 17 | ) 18 | 19 | // min returns 
the smaller of x or y. 20 | func min(x, y int) int { 21 | if x < y { 22 | return x 23 | } 24 | return y 25 | } 26 | 27 | // readCloserWrapper 为 bytes.Reader 提供 Close 方法以实现 ReadCloser 接口 28 | type readCloserWrapper struct { 29 | *bytes.Reader 30 | } 31 | 32 | // Close 实现 io.Closer 接口 33 | func (r *readCloserWrapper) Close() error { 34 | return nil 35 | } 36 | 37 | // newReadCloserWrapper 创建一个新的 ReadCloser 包装 38 | func newReadCloserWrapper(data []byte) *readCloserWrapper { 39 | return &readCloserWrapper{bytes.NewReader(data)} 40 | } 41 | 42 | // WavToOpus 将WAV音频数据转换为标准Opus格式 43 | // 返回Opus帧的切片集合,每个切片是一个Opus编码帧 44 | func WavToOpus(wavData []byte, sampleRate int, channels int, bitRate int) ([][]byte, error) { 45 | // 创建WAV解码器 46 | wavReader := bytes.NewReader(wavData) 47 | wavDecoder := wav.NewDecoder(wavReader) 48 | if !wavDecoder.IsValidFile() { 49 | return nil, fmt.Errorf("无效的WAV文件") 50 | } 51 | 52 | // 读取WAV文件信息 53 | wavDecoder.ReadInfo() 54 | format := wavDecoder.Format() 55 | wavSampleRate := int(format.SampleRate) 56 | wavChannels := int(format.NumChannels) 57 | 58 | // 如果提供的参数与文件参数不一致,使用文件中的参数 59 | if sampleRate == 0 { 60 | sampleRate = wavSampleRate 61 | } 62 | if channels == 0 { 63 | channels = wavChannels 64 | } 65 | 66 | //打印wavDecoder信息 67 | fmt.Println("WAV格式:", format) 68 | 69 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 70 | if err != nil { 71 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 72 | } 73 | 74 | // 设置比特率 75 | if bitRate > 0 { 76 | if err := enc.SetBitrate(bitRate); err != nil { 77 | return nil, fmt.Errorf("设置比特率失败: %v", err) 78 | } 79 | } 80 | 81 | // 创建输出帧切片数组 82 | opusFrames := make([][]byte, 0) 83 | 84 | perFrameDuration := 20 85 | // PCM缓冲区 - Opus帧大小(60ms) 86 | frameSize := sampleRate * perFrameDuration / 1000 87 | pcmBuffer := make([]int16, frameSize*channels) 88 | opusBuffer := make([]byte, 1000) // 足够大的缓冲区存储编码后的数据 89 | 90 | // 读取音频缓冲区 91 | audioBuf := &audio.IntBuffer{Data: make([]int, frameSize*channels), 
Format: format} 92 | 93 | fmt.Println("开始转换...") 94 | for { 95 | // 读取WAV数据 96 | n, err := wavDecoder.PCMBuffer(audioBuf) 97 | if err == io.EOF || n == 0 { 98 | break 99 | } 100 | if err != nil { 101 | return nil, fmt.Errorf("读取WAV数据失败: %v", err) 102 | } 103 | 104 | // 将int转换为int16 105 | for i := 0; i < len(audioBuf.Data); i++ { 106 | if i < len(pcmBuffer) { 107 | pcmBuffer[i] = int16(audioBuf.Data[i]) 108 | } 109 | } 110 | 111 | // 编码为Opus格式 112 | n, err = enc.Encode(pcmBuffer, opusBuffer) 113 | if err != nil { 114 | return nil, fmt.Errorf("编码失败: %v", err) 115 | } 116 | 117 | // 将当前帧复制到新的切片中并添加到帧数组 118 | frameData := make([]byte, n) 119 | copy(frameData, opusBuffer[:n]) 120 | opusFrames = append(opusFrames, frameData) 121 | } 122 | 123 | return opusFrames, nil 124 | } 125 | 126 | type MP3Decoder struct { 127 | streamer beep.StreamSeekCloser 128 | format beep.Format 129 | enc *opus.Encoder 130 | pipeReader *io.PipeReader 131 | perFrameDurationMs int 132 | 133 | outputOpusChan chan []byte //opus一帧一帧的输出 134 | ctx context.Context // 新增:上下文控制 135 | } 136 | 137 | // CreateMP3Decoder 创建一个通过 Done 通道控制的 MP3 解码器 138 | // 为了兼容旧代码,保留此方法 139 | func CreateMP3Decoder(pipeReader *io.PipeReader, outputOpusChan chan []byte, perFrameDurationMs int, ctx context.Context) (*MP3Decoder, error) { 140 | return &MP3Decoder{ 141 | pipeReader: pipeReader, 142 | outputOpusChan: outputOpusChan, 143 | perFrameDurationMs: perFrameDurationMs, 144 | ctx: ctx, 145 | }, nil 146 | } 147 | 148 | func (d *MP3Decoder) Run(startTs int64) error { 149 | defer close(d.outputOpusChan) 150 | 151 | decoder, format, err := mp3.Decode(d.pipeReader) 152 | if err != nil { 153 | return fmt.Errorf("创建MP3解码器失败: %v", err) 154 | } 155 | log.Debugf("MP3格式: %d Hz, %d 通道", format.SampleRate, format.NumChannels) 156 | d.streamer = decoder 157 | d.format = format 158 | 159 | // 流式解码MP3 160 | defer func() { 161 | d.streamer.Close() 162 | }() 163 | 164 | // 获取MP3音频信息 165 | sampleRate := format.SampleRate 166 | channels := 
format.NumChannels 167 | 168 | // 始终使用单通道输出 169 | outputChannels := 1 170 | if channels > 1 { 171 | log.Debugf("将双声道音频转换为单声道输出") 172 | } 173 | 174 | enc, err := opus.NewEncoder(int(sampleRate), outputChannels, opus.AppAudio) 175 | if err != nil { 176 | return fmt.Errorf("创建Opus编码器失败: %v", err) 177 | } 178 | d.enc = enc 179 | 180 | //opus相关配置及缓冲区 创建缓冲区用于接收音频采样 181 | frameDurationMs := d.perFrameDurationMs //60ms 182 | frameSize := int(sampleRate) * frameDurationMs / 1000 // 60ms帧大小 183 | // 临时PCM存储,将音频转换为PCM格式 184 | pcmBuffer := make([]int16, frameSize*outputChannels) 185 | 186 | //mp3读缓冲区 187 | mp3Buffer := make([][2]float64, 1024) 188 | 189 | //opus输出缓冲区 190 | opusBuffer := make([]byte, 1000) 191 | 192 | currentFramePos := 0 // 当前填充到pcmBuffer的位置 193 | var firstFrame bool 194 | for { 195 | select { 196 | case <-d.ctx.Done(): 197 | return nil 198 | default: 199 | // 从MP3读取PCM数据 200 | n, ok := d.streamer.Stream(mp3Buffer) 201 | if !firstFrame { 202 | log.Infof("tts云端首帧耗时: %d ms", time.Now().UnixMilli()-startTs) 203 | } 204 | //fmt.Printf("for loop, n: %d\n", n) 205 | if !ok { 206 | // 处理剩余不足一帧的数据 207 | if currentFramePos > 0 { 208 | // 创建一个完整的帧缓冲区,用0填充剩余部分 209 | paddedFrame := make([]int16, len(pcmBuffer)) 210 | copy(paddedFrame, pcmBuffer[:currentFramePos]) // 将有效数据复制到开头,剩余部分默认为0 211 | 212 | // 编码补齐后的完整帧 213 | n, err := enc.Encode(paddedFrame, opusBuffer) 214 | if err != nil { 215 | log.Errorf("编码剩余数据失败: %v\n", err) 216 | return fmt.Errorf("编码剩余数据失败: %v", err) 217 | } else { 218 | frameData := make([]byte, n) 219 | copy(frameData, opusBuffer[:n]) 220 | 221 | select { 222 | case <-d.ctx.Done(): 223 | log.Debugf("mp3Decoder context done, exit") 224 | return nil 225 | default: 226 | d.outputOpusChan <- frameData 227 | } 228 | } 229 | } 230 | return nil 231 | } 232 | 233 | if n == 0 { 234 | continue 235 | } 236 | // 将浮点音频数据转换为PCM格式(16位整数) 237 | for i := 0; i < n; i++ { 238 | // 先在浮点数阶段计算平均值,避免整数相加时溢出 239 | monoSampleFloat := (mp3Buffer[i][0] + mp3Buffer[i][1]) * 0.5 240 
| 241 | // 进行音量限制,确保不超出范围 242 | if monoSampleFloat > 1.0 { 243 | monoSampleFloat = 1.0 244 | } else if monoSampleFloat < -1.0 { 245 | monoSampleFloat = -1.0 246 | } 247 | 248 | // 将浮点平均值转换为16位整数 249 | monoSample := int16(monoSampleFloat * 32767.0) 250 | pcmBuffer[currentFramePos] = monoSample 251 | currentFramePos++ 252 | 253 | // 如果pcmBuffer已满一帧,则进行编码 254 | if currentFramePos == len(pcmBuffer) { 255 | opusLen, err := enc.Encode(pcmBuffer, opusBuffer) 256 | if err != nil { 257 | log.Errorf("编码失败: %v\n", err) 258 | continue 259 | } 260 | 261 | // 将当前帧复制到新的切片中并添加到帧数组 262 | frameData := make([]byte, opusLen) 263 | copy(frameData, opusBuffer[:opusLen]) 264 | 265 | select { 266 | case <-d.ctx.Done(): 267 | log.Debugf("mp3Decoder context done, exit") 268 | return nil 269 | default: 270 | if !firstFrame { 271 | firstFrame = true 272 | log.Infof("tts云端->首帧解码完成耗时: %d ms", time.Now().UnixMilli()-startTs) 273 | } 274 | 275 | d.outputOpusChan <- frameData 276 | } 277 | 278 | currentFramePos = 0 // 重置帧位置 279 | } 280 | } 281 | } 282 | } 283 | 284 | return nil 285 | } 286 | -------------------------------------------------------------------------------- /internal/domain/tts/common/audio_utils_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | -------------------------------------------------------------------------------- /internal/domain/tts/common/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/internal/domain/tts/common/test.wav -------------------------------------------------------------------------------- /internal/domain/tts/cosyvoice/cosyvoice_test.go: -------------------------------------------------------------------------------- 1 | package cosyvoice 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestCosyVoiceTTS(t *testing.T) { 10 
| // 跳过实际的网络请求测试,除非设置了环境变量 11 | if os.Getenv("RUN_COSYVOICE_TEST") != "1" { 12 | t.Skip("跳过CosyVoice API测试,设置环境变量RUN_COSYVOICE_TEST=1以启用") 13 | } 14 | 15 | config := map[string]interface{}{ 16 | "api_url": "https://cosyvoice.com/tts", 17 | "spk_id": "OUeAo1mhq6IBExi", 18 | "frame_duration": float64(60), 19 | "target_sr": float64(16000), 20 | "audio_format": "mp3", 21 | "instruct_text": "你好", 22 | } 23 | 24 | provider := NewCosyVoiceTTSProvider(config) 25 | 26 | // 测试文本转语音 27 | t.Run("TestTextToSpeech", func(t *testing.T) { 28 | frames, err := provider.TextToSpeech("你会说四川话吗") 29 | if err != nil { 30 | t.Fatalf("TextToSpeech失败: %v", err) 31 | } 32 | 33 | if len(frames) == 0 { 34 | t.Error("未返回任何音频帧") 35 | } 36 | }) 37 | 38 | // 测试流式文本转语音 39 | t.Run("TestTextToSpeechStream", func(t *testing.T) { 40 | outputChan, cancel, err := provider.TextToSpeechStream("你会说四川话吗") 41 | if err != nil { 42 | t.Fatalf("TextToSpeechStream失败: %v", err) 43 | } 44 | 45 | defer cancel() 46 | 47 | // 接收所有帧 48 | var receivedFrames [][]byte 49 | timeout := time.After(10 * time.Second) 50 | 51 | receiveLoop: 52 | for { 53 | select { 54 | case frame, ok := <-outputChan: 55 | if !ok { 56 | break receiveLoop 57 | } 58 | receivedFrames = append(receivedFrames, frame) 59 | case <-timeout: 60 | t.Error("接收音频帧超时") 61 | break receiveLoop 62 | } 63 | } 64 | 65 | if len(receivedFrames) == 0 { 66 | t.Error("未接收到任何音频帧") 67 | } 68 | }) 69 | } 70 | -------------------------------------------------------------------------------- /internal/domain/tts/doubao/doubao.go: -------------------------------------------------------------------------------- 1 | package doubao 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/rand" 7 | "encoding/base64" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "os" 14 | "path/filepath" 15 | "sync" 16 | "time" 17 | 18 | "xiaozhi-esp32-server-golang/internal/domain/tts/common" 19 | log "xiaozhi-esp32-server-golang/logger" 20 | ) 21 | 22 | // 
全局HTTP客户端,实现连接池 23 | var ( 24 | httpClient *http.Client 25 | httpClientOnce sync.Once 26 | ) 27 | 28 | // 获取配置了连接池的HTTP客户端 29 | func getHTTPClient() *http.Client { 30 | httpClientOnce.Do(func() { 31 | transport := &http.Transport{ 32 | Proxy: http.ProxyFromEnvironment, 33 | DialContext: (&net.Dialer{ 34 | Timeout: 30 * time.Second, 35 | KeepAlive: 30 * time.Second, 36 | }).DialContext, 37 | MaxIdleConns: 100, 38 | MaxIdleConnsPerHost: 10, 39 | IdleConnTimeout: 90 * time.Second, 40 | TLSHandshakeTimeout: 10 * time.Second, 41 | ExpectContinueTimeout: 1 * time.Second, 42 | } 43 | httpClient = &http.Client{ 44 | Transport: transport, 45 | Timeout: 30 * time.Second, 46 | } 47 | }) 48 | return httpClient 49 | } 50 | 51 | // DoubaoTTSProvider 读伴TTS提供者 52 | type DoubaoTTSProvider struct { 53 | AppID string 54 | AccessToken string 55 | Cluster string 56 | Voice string 57 | APIURL string 58 | Authorization string 59 | Header map[string]string 60 | } 61 | 62 | // 请求结构体 63 | type doubaoRequest struct { 64 | App appInfo `json:"app"` 65 | User userInfo `json:"user"` 66 | Audio audioInfo `json:"audio"` 67 | Request requestInfo `json:"request"` 68 | } 69 | 70 | type appInfo struct { 71 | AppID string `json:"appid"` 72 | Token string `json:"token"` 73 | Cluster string `json:"cluster"` 74 | } 75 | 76 | type userInfo struct { 77 | UID string `json:"uid"` 78 | } 79 | 80 | type audioInfo struct { 81 | VoiceType string `json:"voice_type"` 82 | Encoding string `json:"encoding"` 83 | Rate int `json:"rate"` 84 | SpeedRatio float64 `json:"speed_ratio"` 85 | VolumeRatio float64 `json:"volume_ratio"` 86 | PitchRatio float64 `json:"pitch_ratio"` 87 | } 88 | 89 | type requestInfo struct { 90 | ReqID string `json:"reqid"` 91 | Text string `json:"text"` 92 | TextType string `json:"text_type"` 93 | Operation string `json:"operation"` 94 | WithFrontend int `json:"with_frontend"` 95 | FrontendType string `json:"frontend_type"` 96 | } 97 | 98 | // 响应结构体 99 | type doubaoResponse struct { 100 | Data 
string `json:"data"` 101 | } 102 | 103 | // 生成UUID 104 | func generateUUID() string { 105 | b := make([]byte, 16) 106 | _, err := rand.Read(b) 107 | if err != nil { 108 | return fmt.Sprintf("%d", time.Now().UnixNano()) 109 | } 110 | return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) 111 | } 112 | 113 | // NewDoubaoTTSProvider 创建新的读伴TTS提供者 114 | func NewDoubaoTTSProvider(config map[string]interface{}) *DoubaoTTSProvider { 115 | appID, _ := config["appid"].(string) 116 | accessToken, _ := config["access_token"].(string) 117 | cluster, _ := config["cluster"].(string) 118 | voice, _ := config["voice"].(string) 119 | apiURL, _ := config["api_url"].(string) 120 | authorization, _ := config["authorization"].(string) 121 | 122 | // 检查令牌 123 | if accessToken == "" { 124 | log.Error("TTS 访问令牌不能为空") 125 | } 126 | 127 | return &DoubaoTTSProvider{ 128 | AppID: appID, 129 | AccessToken: accessToken, 130 | Cluster: cluster, 131 | Voice: voice, 132 | APIURL: apiURL, 133 | Authorization: authorization, 134 | Header: map[string]string{"Authorization": fmt.Sprintf("%s%s", authorization, accessToken)}, 135 | } 136 | } 137 | 138 | // TextToSpeech 将文本转换为语音,返回音频帧数据和错误 139 | func (p *DoubaoTTSProvider) TextToSpeech(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) { 140 | // 准备请求数据 141 | reqData := doubaoRequest{ 142 | App: appInfo{ 143 | AppID: p.AppID, 144 | Token: p.AccessToken, 145 | Cluster: p.Cluster, 146 | }, 147 | User: userInfo{ 148 | UID: "1", 149 | }, 150 | Audio: audioInfo{ 151 | VoiceType: p.Voice, 152 | Encoding: "wav", 153 | Rate: sampleRate, 154 | SpeedRatio: 1.0, 155 | VolumeRatio: 1.0, 156 | PitchRatio: 1.0, 157 | }, 158 | Request: requestInfo{ 159 | ReqID: generateUUID(), 160 | Text: text, 161 | TextType: "plain", 162 | Operation: "query", 163 | WithFrontend: 1, 164 | FrontendType: "unitTson", 165 | }, 166 | } 167 | 168 | // 转换为JSON 169 | jsonData, err := json.Marshal(reqData) 170 | if err != 
nil { 171 | return nil, fmt.Errorf("无法序列化请求: %v", err) 172 | } 173 | 174 | // 创建HTTP请求 175 | req, err := http.NewRequest("POST", p.APIURL, bytes.NewBuffer(jsonData)) 176 | if err != nil { 177 | return nil, fmt.Errorf("创建请求失败: %v", err) 178 | } 179 | 180 | // 设置请求头 181 | for k, v := range p.Header { 182 | req.Header.Set(k, v) 183 | } 184 | req.Header.Set("Content-Type", "application/json") 185 | 186 | // 使用连接池发送请求 187 | client := getHTTPClient() 188 | resp, err := client.Do(req) 189 | if err != nil { 190 | return nil, fmt.Errorf("发送请求失败: %v", err) 191 | } 192 | defer resp.Body.Close() 193 | 194 | // 读取响应 195 | body, err := io.ReadAll(resp.Body) 196 | if err != nil { 197 | return nil, fmt.Errorf("读取响应失败: %v", err) 198 | } 199 | 200 | // 解析响应 201 | var result map[string]interface{} 202 | if err := json.Unmarshal(body, &result); err != nil { 203 | return nil, fmt.Errorf("解析响应失败: %v", err) 204 | } 205 | 206 | // 提取音频数据 207 | if audioData, ok := result["data"].(string); ok { 208 | // 获取原始响应数据 209 | wavData, err := base64.StdEncoding.DecodeString(audioData) 210 | if err != nil { 211 | return nil, fmt.Errorf("解码音频数据失败: %v", err) 212 | } 213 | 214 | // 转换为Opus帧并直接返回 215 | return common.WavToOpus(wavData, 0, 0, 0) 216 | } 217 | 218 | return nil, fmt.Errorf("响应中没有数据字段, 状态码: %d, 响应: %s", resp.StatusCode, string(body)) 219 | } 220 | 221 | // GetVoiceInfo 获取语音信息 222 | func (p *DoubaoTTSProvider) GetVoiceInfo() map[string]interface{} { 223 | return map[string]interface{}{ 224 | "voice": p.Voice, 225 | "type": "doubao", 226 | } 227 | } 228 | 229 | // saveWavToTmp 将WAV数据保存到tmp目录 230 | func saveWavToTmp(wavData []byte) error { 231 | // 确保tmp目录存在 232 | tmpDir := "tmp" 233 | if err := os.MkdirAll(tmpDir, 0755); err != nil { 234 | return fmt.Errorf("创建tmp目录失败: %v", err) 235 | } 236 | 237 | // 生成唯一文件名 238 | timestamp := time.Now().Format("20060102_150405") 239 | uuid := generateUUID() 240 | filename := filepath.Join(tmpDir, fmt.Sprintf("wav_%s_%s.wav", timestamp, uuid[:8])) 241 | 242 | 
// 写入文件 243 | if err := os.WriteFile(filename, wavData, 0644); err != nil { 244 | return fmt.Errorf("写入WAV文件失败: %v", err) 245 | } 246 | 247 | log.Infof("WAV文件已保存: %s", filename) 248 | return nil 249 | } 250 | 251 | func (p *DoubaoTTSProvider) TextToSpeechStream(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (outputChan chan []byte, err error) { 252 | return nil, nil 253 | } 254 | -------------------------------------------------------------------------------- /internal/domain/tts/doubao/doubao_test.go: -------------------------------------------------------------------------------- 1 | package doubao 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "xiaozhi-esp32-server-golang/internal/data/client" 11 | "xiaozhi-esp32-server-golang/internal/domain/tts/common" 12 | ) 13 | 14 | // 测试创建TTS提供者 15 | func TestNewDoubaoTTSProvider(t *testing.T) { 16 | config := map[string]interface{}{ 17 | "appid": "test_app_id", 18 | "access_token": "test_token", 19 | "cluster": "test_cluster", 20 | "voice": "test_voice", 21 | "api_url": "https://api.test.com", 22 | "authorization": "Bearer ", 23 | } 24 | 25 | provider := NewDoubaoTTSProvider(config) 26 | 27 | if provider.AppID != "test_app_id" { 28 | t.Errorf("AppID不匹配,期望: %s, 实际: %s", "test_app_id", provider.AppID) 29 | } 30 | if provider.Voice != "test_voice" { 31 | t.Errorf("Voice不匹配,期望: %s, 实际: %s", "test_voice", provider.Voice) 32 | } 33 | if provider.Header["Authorization"] != "Bearer test_token" { 34 | t.Errorf("Authorization不匹配,期望: %s, 实际: %s", "Bearer test_token", provider.Header["Authorization"]) 35 | } 36 | } 37 | 38 | // 测试GetVoiceInfo方法 39 | func TestGetVoiceInfo(t *testing.T) { 40 | provider := &DoubaoTTSProvider{ 41 | Voice: "xiaomei", 42 | } 43 | 44 | info := provider.GetVoiceInfo() 45 | if info["voice"] != "xiaomei" { 46 | t.Errorf("语音信息不匹配,期望voice: %s, 实际: %s", "xiaomei", info["voice"]) 47 | } 48 | if info["type"] != "doubao" { 49 | 
t.Errorf("类型不匹配,期望type: %s, 实际: %s", "doubao", info["type"]) 50 | } 51 | } 52 | 53 | // 测试生成UUID 54 | func TestGenerateUUID(t *testing.T) { 55 | uuid := generateUUID() 56 | if len(uuid) == 0 { 57 | t.Error("生成的UUID为空") 58 | } 59 | 60 | // 测试多次生成的UUID是否不同 61 | anotherUUID := generateUUID() 62 | if uuid == anotherUUID { 63 | t.Error("两次生成的UUID相同,期望不同的值") 64 | } 65 | } 66 | 67 | // 注意:以下是一个简化的WavToOpus测试 68 | // 如果要全面测试,需要准备有效的WAV数据并验证转换结果 69 | func TestWavToOpus_InvalidData(t *testing.T) { 70 | // 测试无效的WAV数据 71 | _, err := common.WavToOpus([]byte("这不是WAV数据"), client.SampleRate, client.Channels, client.FrameDuration) 72 | if err == nil { 73 | t.Error("期望处理无效数据时返回错误,但没有") 74 | } 75 | } 76 | 77 | // MockRoundTripper 模拟HTTP请求的响应 78 | type MockRoundTripper struct { 79 | Response *http.Response 80 | Err error 81 | } 82 | 83 | func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 84 | return m.Response, m.Err 85 | } 86 | 87 | // 使用字节跳动配置进行测试 88 | func TestByteProviderTextToSpeech(t *testing.T) { 89 | // 配置测试参数 - 使用用户提供的实际配置 90 | config := map[string]interface{}{ 91 | "api_url": "https://openspeech.bytedance.com/api/v1/tts", 92 | "voice": "BV001_streaming", 93 | "authorization": "Bearer;", 94 | "appid": "6886011847", 95 | "access_token": "access_token", 96 | "cluster": "volcano_tts", 97 | } 98 | 99 | t.Logf("使用配置: API URL=%s, Voice=%s, AppID=%s", 100 | config["api_url"], config["voice"], config["appid"]) 101 | 102 | // 创建TTS提供者 103 | provider := NewDoubaoTTSProvider(config) 104 | 105 | // 测试文本到语音转换 106 | testText := "这是一个测试,使用字节跳动TTS服务生成语音" 107 | t.Logf("测试文本: %s", testText) 108 | 109 | // 确保输出目录存在 110 | outputDir := "tmp/" 111 | if _, err := os.Stat(outputDir); os.IsNotExist(err) { 112 | if err := os.MkdirAll(outputDir, 0755); err != nil { 113 | t.Logf("无法创建输出目录: %v", err) 114 | } 115 | } 116 | 117 | // 使用TextToSpeech方法 118 | audioFrames, err := provider.TextToSpeech(context.Background(), testText) 119 | if err != nil { 120 | 
t.Fatalf("TextToSpeech失败: %v", err) 121 | } 122 | 123 | // 计算总大小 124 | var totalSize int 125 | for _, frame := range audioFrames { 126 | totalSize += len(frame) 127 | } 128 | 129 | // 合并所有帧以便保存到文件 130 | mergedAudio := make([]byte, totalSize) 131 | offset := 0 132 | for _, frame := range audioFrames { 133 | copy(mergedAudio[offset:], frame) 134 | offset += len(frame) 135 | } 136 | 137 | // 保存结果 138 | outputPath := outputDir + "byte_test_" + time.Now().Format("20060102_150405") + ".opus" 139 | if err := os.WriteFile(outputPath, mergedAudio, 0644); err != nil { 140 | t.Logf("保存音频文件失败: %v", err) 141 | } else { 142 | t.Logf("音频文件已保存到: %s", outputPath) 143 | } 144 | 145 | // 验证结果 146 | if len(audioFrames) == 0 { 147 | t.Error("生成的音频帧为空") 148 | } else { 149 | t.Logf("生成的音频帧数量: %d, 总大小: %d 字节", len(audioFrames), totalSize) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /internal/domain/tts/doubao/doubao_ws_test.go: -------------------------------------------------------------------------------- 1 | package doubao 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | // 测试TextToSpeechStream方法 10 | func TestTextToSpeechStream(t *testing.T) { 11 | // 使用实际的配置 12 | config := map[string]interface{}{ 13 | "appid": "appid", 14 | "access_token": "access_token", 15 | "cluster": "volcano_tts", 16 | //"voice": "BV001_streaming", 17 | "voice": "zh_female_wanwanxiaohe_moon_bigtts", 18 | "ws_host": "openspeech.bytedance.com", 19 | "use_stream": true, 20 | } 21 | 22 | // 创建一个测试provider 23 | provider := NewDoubaoWSProvider(config) 24 | 25 | t.Run("测试正常流式回调", func(t *testing.T) { 26 | 27 | // 直接调用实际的API 28 | outputOpusChan, err := provider.TextToSpeechStream(context.Background(), "这是一个测试文本,今天天气怎么样, 今天天气真好, 你是中国人, 咱们去北京天津玩好不好, 北京有什么好玩的,天津之眼吧") 29 | if err != nil { 30 | t.Fatalf("TextToSpeechStream返回错误: %v", err) 31 | } 32 | 33 | for opusFrame := range outputOpusChan { 34 | fmt.Printf("收到opus帧: %d\n", len(opusFrame)) 35 | 
} 36 | 37 | }) 38 | /* 39 | t.Run("测试取消功能", func(t *testing.T) { 40 | var ( 41 | receivedChunks [][]byte 42 | mu sync.Mutex 43 | ) 44 | 45 | // 回调函数 46 | onChunk := func(chunkData []byte, isLast bool) error { 47 | mu.Lock() 48 | defer mu.Unlock() 49 | 50 | if chunkData != nil { 51 | receivedChunks = append(receivedChunks, chunkData) 52 | } 53 | 54 | return nil 55 | } 56 | 57 | // 直接调用实际的API 58 | cancelFunc, err := provider.TextToSpeechStream("另一个测试文本", onChunk) 59 | if err != nil { 60 | t.Fatalf("TextToSpeechStream返回错误: %v", err) 61 | } 62 | 63 | // 等待一小段时间后取消 64 | time.Sleep(500 * time.Millisecond) 65 | cancelFunc() 66 | 67 | // 等待一小段时间让取消生效 68 | time.Sleep(500 * time.Millisecond) 69 | 70 | // 验证结果 71 | mu.Lock() 72 | defer mu.Unlock() 73 | 74 | // 因为取消了处理,所以收到的块数应该是有限的 75 | t.Logf("取消后接收到 %d 个音频块", len(receivedChunks)) 76 | }) 77 | */ 78 | } 79 | -------------------------------------------------------------------------------- /internal/domain/tts/edge/edge.go: -------------------------------------------------------------------------------- 1 | package edge 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | "time" 9 | 10 | "xiaozhi-esp32-server-golang/internal/domain/tts/common" 11 | log "xiaozhi-esp32-server-golang/logger" 12 | 13 | "github.com/difyz9/edge-tts-go/pkg/communicate" 14 | ) 15 | 16 | // EdgeTTSProvider Edge TTS 提供者 17 | // 支持一次性和流式TTS,输出Opus帧 18 | // 配置参数:voice, rate, volume, pitch, connectTimeout, receiveTimeout 19 | type EdgeTTSProvider struct { 20 | Voice string 21 | Rate string 22 | Volume string 23 | Pitch string 24 | ConnectTimeout int 25 | ReceiveTimeout int 26 | } 27 | 28 | // NewEdgeTTSProvider 创建EdgeTTSProvider 29 | func NewEdgeTTSProvider(config map[string]interface{}) *EdgeTTSProvider { 30 | voice, _ := config["voice"].(string) 31 | rate, _ := config["rate"].(string) 32 | volume, _ := config["volume"].(string) 33 | pitch, _ := config["pitch"].(string) 34 | connectTimeout, _ := config["connect_timeout"].(int) 35 | 
receiveTimeout, _ := config["receive_timeout"].(int) 36 | if rate == "" { 37 | rate = "+0%" 38 | } 39 | if volume == "" { 40 | volume = "+0%" 41 | } 42 | if pitch == "" { 43 | pitch = "+0Hz" 44 | } 45 | if connectTimeout == 0 { 46 | connectTimeout = 10 47 | } 48 | if receiveTimeout == 0 { 49 | receiveTimeout = 60 50 | } 51 | return &EdgeTTSProvider{ 52 | Voice: voice, 53 | Rate: rate, 54 | Volume: volume, 55 | Pitch: pitch, 56 | ConnectTimeout: connectTimeout, 57 | ReceiveTimeout: receiveTimeout, 58 | } 59 | } 60 | 61 | // TextToSpeech 一次性合成,返回Opus帧 62 | func (p *EdgeTTSProvider) TextToSpeech(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) ([][]byte, error) { 63 | startTs := time.Now().UnixMilli() 64 | // 临时MP3文件 65 | tmpFile := fmt.Sprintf("/tmp/edge-tts-%d.mp3", time.Now().UnixNano()) 66 | defer os.Remove(tmpFile) 67 | 68 | comm, err := communicate.NewCommunicate( 69 | text, 70 | p.Voice, 71 | p.Rate, 72 | p.Volume, 73 | p.Pitch, 74 | "", // proxy 75 | p.ConnectTimeout, 76 | p.ReceiveTimeout, 77 | ) 78 | if err != nil { 79 | log.Errorf("EdgeTTS Communicate创建失败: %v", err) 80 | return nil, err 81 | } 82 | // 保存MP3 83 | err = comm.Save(ctx, tmpFile, "") 84 | if err != nil { 85 | log.Errorf("EdgeTTS保存MP3失败: %v", err) 86 | return nil, err 87 | } 88 | // MP3转Opus 89 | f, err := os.Open(tmpFile) 90 | if err != nil { 91 | return nil, fmt.Errorf("打开MP3失败: %v", err) 92 | } 93 | defer f.Close() 94 | pipeReader, pipeWriter := io.Pipe() 95 | outputChan := make(chan []byte, 1000) 96 | // 写入MP3数据到pipe 97 | go func() { 98 | _, _ = io.Copy(pipeWriter, f) 99 | pipeWriter.Close() 100 | }() 101 | mp3Decoder, err := common.CreateMP3Decoder(pipeReader, outputChan, frameDuration, ctx) 102 | if err != nil { 103 | return nil, fmt.Errorf("创建MP3解码器失败: %v", err) 104 | } 105 | var opusFrames [][]byte 106 | done := make(chan struct{}) 107 | go func() { 108 | for frame := range outputChan { 109 | opusFrames = append(opusFrames, frame) 110 | } 111 | done <- 
struct{}{} 112 | }() 113 | if err := mp3Decoder.Run(startTs); err != nil { 114 | return nil, fmt.Errorf("MP3解码失败: %v", err) 115 | } 116 | <-done 117 | return opusFrames, nil 118 | } 119 | 120 | // TextToSpeechStream 流式合成,返回Opus帧chan 121 | func (p *EdgeTTSProvider) TextToSpeechStream(ctx context.Context, text string, sampleRate int, channels int, frameDuration int) (chan []byte, error) { 122 | startTs := time.Now().UnixMilli() 123 | comm, err := communicate.NewCommunicate( 124 | text, 125 | p.Voice, 126 | p.Rate, 127 | p.Volume, 128 | p.Pitch, 129 | "", // proxy 130 | p.ConnectTimeout, 131 | p.ReceiveTimeout, 132 | ) 133 | if err != nil { 134 | log.Errorf("EdgeTTS Communicate创建失败: %v", err) 135 | return nil, err 136 | } 137 | 138 | chunkChan, errChan := comm.Stream(ctx) 139 | outputChan := make(chan []byte, 100) 140 | pipeReader, pipeWriter := io.Pipe() 141 | // MP3转Opus解码器 142 | go func() { 143 | defer pipeWriter.Close() 144 | for chunk := range chunkChan { 145 | if chunk.Type == "audio" { 146 | _, _ = pipeWriter.Write(chunk.Data) 147 | } 148 | } 149 | log.Debugf("EdgeTTS流式合成结束, 耗时: %d ms", time.Now().UnixMilli()-startTs) 150 | if err := <-errChan; err != nil { 151 | log.Errorf("EdgeTTS流式合成出错: %v", err) 152 | } 153 | }() 154 | // 启动MP3→Opus解码 155 | go func() { 156 | mp3Decoder, err := common.CreateMP3Decoder(pipeReader, outputChan, frameDuration, ctx) 157 | if err != nil { 158 | log.Errorf("EdgeTTS MP3解码器创建失败: %v", err) 159 | return 160 | } 161 | if err := mp3Decoder.Run(startTs); err != nil { 162 | log.Errorf("EdgeTTS MP3解码失败: %v", err) 163 | } 164 | log.Debugf("EdgeTTS MP3解码结束, 耗时: %d ms", time.Now().UnixMilli()-startTs) 165 | }() 166 | return outputChan, nil 167 | } 168 | -------------------------------------------------------------------------------- /internal/domain/tts/edge/edge_test.go: -------------------------------------------------------------------------------- 1 | package edge 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func 
TestEdgeTTSProvider(t *testing.T) { 10 | 11 | config := map[string]interface{}{ 12 | "voice": "zh-CN-XiaoxiaoNeural", 13 | "rate": "+0%", 14 | "volume": "+0%", 15 | "pitch": "+0Hz", 16 | "connect_timeout": 10, 17 | "receive_timeout": 60, 18 | } 19 | 20 | provider := NewEdgeTTSProvider(config) 21 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 22 | defer cancel() 23 | 24 | t.Run("TestTextToSpeech", func(t *testing.T) { 25 | frames, err := provider.TextToSpeech(ctx, "你好,EdgeTTS测试") 26 | if err != nil { 27 | t.Fatalf("TextToSpeech失败: %v", err) 28 | } 29 | if len(frames) == 0 { 30 | t.Error("未返回任何音频帧") 31 | } 32 | }) 33 | 34 | t.Run("TestTextToSpeechStream", func(t *testing.T) { 35 | outputChan, err := provider.TextToSpeechStream(ctx, "你好,EdgeTTS流式测试") 36 | if err != nil { 37 | t.Fatalf("TextToSpeechStream失败: %v", err) 38 | } 39 | var receivedFrames [][]byte 40 | timeout := time.After(20 * time.Second) 41 | ReceiveLoop: 42 | for { 43 | select { 44 | case frame, ok := <-outputChan: 45 | if !ok { 46 | break ReceiveLoop 47 | } 48 | receivedFrames = append(receivedFrames, frame) 49 | case <-timeout: 50 | t.Error("接收音频帧超时") 51 | break ReceiveLoop 52 | } 53 | } 54 | if len(receivedFrames) == 0 { 55 | t.Error("未接收到任何音频帧") 56 | } 57 | }) 58 | } 59 | -------------------------------------------------------------------------------- /internal/domain/tts/xiaozhi/xiaozhi_test.go: -------------------------------------------------------------------------------- 1 | package xiaozhi 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/go-audio/audio" 10 | "github.com/go-audio/wav" 11 | "github.com/sirupsen/logrus" 12 | "github.com/spf13/viper" 13 | "gopkg.in/hraban/opus.v2" 14 | 15 | "xiaozhi-esp32-server-golang/internal/util/workqueue" 16 | ) 17 | 18 | func OpusToWav(opusData [][]byte, sampleRate int, channels int, fileName string) ([][]int16, error) { 19 | opusDecoder, err := opus.NewDecoder(sampleRate, channels) 20 | if err != 
nil { 21 | return nil, fmt.Errorf("创建Opus解码器失败: %v", err) 22 | } 23 | 24 | wavOut, err := os.Create(fileName) 25 | if err != nil { 26 | return nil, fmt.Errorf("创建WAV文件失败: %v", err) 27 | } 28 | 29 | pcmDataList := make([][]int16, 0) 30 | pcmBuffer := make([]int16, 4096) 31 | 32 | wavEncoder := wav.NewEncoder(wavOut, sampleRate, 16, channels, 1) 33 | wavBuffer := audio.IntBuffer{ 34 | Format: &audio.Format{ 35 | NumChannels: channels, // 使用传入的通道数 36 | SampleRate: sampleRate, 37 | }, 38 | SourceBitDepth: 16, 39 | Data: make([]int, 4096), 40 | } 41 | 42 | for _, frame := range opusData { 43 | n, err := opusDecoder.Decode(frame, pcmBuffer) 44 | if err != nil { 45 | return nil, fmt.Errorf("解码失败: %v", err) 46 | } 47 | copyData := make([]int16, len(pcmBuffer[:n])) 48 | copy(copyData, pcmBuffer[:n]) 49 | pcmDataList = append(pcmDataList, copyData) 50 | 51 | //fmt.Println("pcmData len: ", len(copyData)) 52 | 53 | // 将PCM数据转换为int格式 54 | for i := 0; i < len(copyData); i++ { 55 | wavBuffer.Data = append(wavBuffer.Data, int(copyData[i])) 56 | } 57 | } 58 | 59 | // 写入WAV文件 60 | err = wavEncoder.Write(&wavBuffer) 61 | if err != nil { 62 | return nil, fmt.Errorf("写入WAV文件失败: %v", err) 63 | } 64 | 65 | wavEncoder.Close() 66 | 67 | return pcmDataList, nil 68 | } 69 | 70 | func initLog() error { 71 | // 使用标准输出而不是文件 72 | logrus.SetOutput(os.Stdout) 73 | 74 | // 禁用默认的调用者报告,使用自定义的caller字段 75 | logrus.SetReportCaller(false) 76 | logrus.SetFormatter(&logrus.TextFormatter{ 77 | TimestampFormat: "2006-01-02 15:04:05.000", //时间格式化,添加毫秒 78 | ForceColors: true, // 启用颜色输出 79 | }) 80 | logLevel, _ := logrus.ParseLevel(viper.GetString("log.level")) 81 | if logLevel == 0 { 82 | logLevel = logrus.DebugLevel // 默认设置为Debug级别 83 | } 84 | logrus.SetLevel(logLevel) 85 | return nil 86 | } 87 | 88 | func TestTextToSpeechStream(t *testing.T) { 89 | //初始化log日志输出至标准输出 90 | //initLog() 91 | provider := NewXiaozhiProvider(map[string]interface{}{ 92 | "server_addr": "wss://api.tenclass.net/xiaozhi/v1/", 93 | 
"device_id": "ba:8f:17:de:94:94", 94 | }) 95 | 96 | textList := []string{ 97 | "你好,小智TTS单元测试", 98 | "讲个笑话", 99 | "今天天气怎么样", 100 | "你叫什么名字", 101 | "你今年几岁", 102 | "你住在哪里", 103 | "你喜欢吃什么", 104 | "你最喜欢什么颜色", 105 | "你最喜欢什么食物", 106 | "你最喜欢什么动物", 107 | } 108 | 109 | workqueue.ParallelizeUntil(context.Background(), 3, len(textList), func(piece int) { 110 | text := textList[piece] 111 | fmt.Println("开始 speech text: ", text) 112 | ch, err := provider.TextToSpeechStream(context.Background(), text) 113 | if err != nil { 114 | fmt.Println("TextToSpeechStream 连接失败: ", err) 115 | return 116 | } 117 | opusDataList := [][]byte{} 118 | for frame := range ch { 119 | opusDataList = append(opusDataList, frame) 120 | if len(frame) == 0 { 121 | t.Error("收到空音频帧") 122 | } 123 | } 124 | fmt.Printf("text: %s, 收到 %d 个音频帧\n", text, len(opusDataList)) 125 | }) 126 | 127 | /* 128 | for _, text := range textList { 129 | fmt.Println("开始 speech text: ", text) 130 | ch, err := provider.TextToSpeechStream(context.Background(), text) 131 | if err != nil { 132 | fmt.Println("TextToSpeechStream 连接失败: ", err) 133 | return 134 | } 135 | opusDataList := [][]byte{} 136 | for frame := range ch { 137 | opusDataList = append(opusDataList, frame) 138 | if len(frame) == 0 { 139 | t.Error("收到空音频帧") 140 | } 141 | } 142 | //OpusToWav(opusDataList, 24000, 1, "output_24000.wav") 143 | }*/ 144 | 145 | } 146 | -------------------------------------------------------------------------------- /internal/domain/user_config/userconfig.go: -------------------------------------------------------------------------------- 1 | package userconfig 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "sync" 8 | 9 | log "xiaozhi-esp32-server-golang/logger" 10 | 11 | "github.com/redis/go-redis/v9" 12 | ) 13 | 14 | var ( 15 | userConfigInstance *UserConfig 16 | once sync.Once 17 | ) 18 | 19 | type UserConfig struct { 20 | redisInstance *redis.Client 21 | prefix string 22 | } 23 | 24 | func InitUserConfig(redisOptions 
*redis.Options, prefix string) error { 25 | var initErr error 26 | once.Do(func() { 27 | if redisOptions == nil { 28 | initErr = fmt.Errorf("redis options cannot be nil") 29 | return 30 | } 31 | 32 | client := redis.NewClient(redisOptions) 33 | // 测试 Redis 连接 34 | if err := client.Ping(context.Background()).Err(); err != nil { 35 | initErr = fmt.Errorf("failed to connect to redis: %w", err) 36 | return 37 | } 38 | 39 | userConfigInstance = &UserConfig{ 40 | redisInstance: client, 41 | prefix: prefix, 42 | } 43 | }) 44 | 45 | return initErr 46 | } 47 | 48 | func U() *UserConfig { 49 | if userConfigInstance == nil { 50 | return &UserConfig{} 51 | } 52 | return userConfigInstance 53 | } 54 | 55 | type AsrConfig struct { 56 | Type string `json:"type"` 57 | } 58 | 59 | type TtsConfig struct { 60 | Type string `json:"type"` 61 | } 62 | 63 | type LlmConfig struct { 64 | Type string `json:"type"` 65 | } 66 | 67 | type UConfig struct { 68 | SystemPrompt string `json:"system_prompt"` 69 | Asr AsrConfig `json:"asr"` 70 | Tts TtsConfig `json:"tts"` 71 | Llm LlmConfig `json:"llm"` 72 | } 73 | 74 | func (u *UConfig) getTTsType() string { 75 | ttsType := u.Tts.Type 76 | if ttsType == "" { 77 | ttsType = "local" 78 | } 79 | return ttsType 80 | } 81 | 82 | func (u *UserConfig) GetUserConfig(ctx context.Context, userID string) (UConfig, error) { 83 | if u.redisInstance == nil { 84 | log.Log().Warn("redis instance is nil") 85 | return UConfig{}, nil 86 | } 87 | key := u.GetUserConfigKey(userID) 88 | //hgetall 拿到所有的 89 | userConfig, err := u.redisInstance.HGetAll(ctx, key).Result() 90 | if err != nil { 91 | return UConfig{}, err 92 | } 93 | 94 | ret := UConfig{} 95 | //将UserConfig转换成UConfig结构 96 | for k, v := range userConfig { 97 | if k == "llm" { 98 | llmConfig := LlmConfig{} 99 | err = json.Unmarshal([]byte(v), &llmConfig) 100 | if err != nil { 101 | return UConfig{}, err 102 | } 103 | ret.Llm = llmConfig 104 | } else if k == "tts" { 105 | ttsConfig := TtsConfig{} 106 | err = 
json.Unmarshal([]byte(v), &ttsConfig) 107 | if err != nil { 108 | return UConfig{}, err 109 | } 110 | ret.Tts = ttsConfig 111 | } else if k == "asr" { 112 | asrConfig := AsrConfig{} 113 | err = json.Unmarshal([]byte(v), &asrConfig) 114 | if err != nil { 115 | return UConfig{}, err 116 | } 117 | ret.Asr = asrConfig 118 | } 119 | } 120 | return ret, nil 121 | } 122 | 123 | func (u *UserConfig) GetUserConfigKey(deviceId string) string { 124 | return fmt.Sprintf("%s:userconfig:%s", u.prefix, deviceId) 125 | } 126 | -------------------------------------------------------------------------------- /internal/domain/vad/test/silero_vad.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/internal/domain/vad/test/silero_vad.onnx -------------------------------------------------------------------------------- /internal/domain/vad/test/vad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/internal/domain/vad/test/vad -------------------------------------------------------------------------------- /internal/domain/vad/test/vad.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | 8 | "github.com/streamer45/silero-vad-go/speech" 9 | 10 | "github.com/go-audio/wav" 11 | ) 12 | 13 | func main() { 14 | sd, err := speech.NewDetector(speech.DetectorConfig{ 15 | ModelPath: "silero_vad.onnx", 16 | SampleRate: 16000, 17 | Threshold: 0.5, 18 | MinSilenceDurationMs: 100, 19 | SpeechPadMs: 30, 20 | }) 21 | if err != nil { 22 | log.Fatalf("failed to create speech detector: %s", err) 23 | } 24 | 25 | if len(os.Args) != 2 { 26 | log.Fatalf("invalid arguments provided: expecting one file path") 27 | } 28 | 29 
| f, err := os.Open(os.Args[1]) 30 | if err != nil { 31 | log.Fatalf("failed to open sample audio file: %s", err) 32 | } 33 | defer f.Close() 34 | 35 | dec := wav.NewDecoder(f) 36 | 37 | if ok := dec.IsValidFile(); !ok { 38 | log.Fatalf("invalid WAV file") 39 | } 40 | 41 | buf, err := dec.FullPCMBuffer() 42 | if err != nil { 43 | log.Fatalf("failed to get PCM buffer") 44 | } 45 | 46 | pcmBuf := buf.AsFloat32Buffer() 47 | 48 | fmt.Println(len(pcmBuf.Data)) 49 | segments, err := sd.Detect(pcmBuf.Data) 50 | if err != nil { 51 | log.Fatalf("Detect failed: %s", err) 52 | } 53 | 54 | fmt.Println(segments) 55 | 56 | for _, s := range segments { 57 | log.Printf("speech starts at %0.2fs", s.SpeechStartAt) 58 | if s.SpeechEndAt > 0 { 59 | log.Printf("speech ends at %0.2fs", s.SpeechEndAt) 60 | } 61 | } 62 | 63 | err = sd.Destroy() 64 | if err != nil { 65 | log.Fatalf("failed to destroy detector: %s", err) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /internal/domain/vad/test/wav2vad.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "log" 8 | "os" 9 | 10 | "github.com/go-audio/audio" 11 | "github.com/go-audio/wav" 12 | "github.com/streamer45/silero-vad-go/speech" 13 | "gopkg.in/hraban/opus.v2" 14 | ) 15 | 16 | // readCloserWrapper 为 bytes.Reader 提供 Close 方法以实现 ReadCloser 接口 17 | type readCloserWrapper struct { 18 | *bytes.Reader 19 | } 20 | 21 | // Close 实现 io.Closer 接口 22 | func (r *readCloserWrapper) Close() error { 23 | return nil 24 | } 25 | 26 | // newReadCloserWrapper 创建一个新的 ReadCloser 包装 27 | func newReadCloserWrapper(data []byte) *readCloserWrapper { 28 | return &readCloserWrapper{bytes.NewReader(data)} 29 | } 30 | 31 | // WavToOpus 将WAV音频数据转换为标准Opus格式 32 | // 返回Opus帧的切片集合,每个切片是一个Opus编码帧 33 | func WavToOpus(wavData []byte, sampleRate int, channels int, bitRate int) ([][]byte, error) { 34 | 35 | sd, err := 
speech.NewDetector(speech.DetectorConfig{ 36 | ModelPath: "silero_vad.onnx", 37 | SampleRate: 16000, 38 | Threshold: 0.5, 39 | MinSilenceDurationMs: 250, 40 | SpeechPadMs: 150, 41 | }) 42 | if err != nil { 43 | log.Fatalf("failed to create speech detector: %s", err) 44 | } 45 | 46 | // 创建WAV解码器 47 | wavReader := bytes.NewReader(wavData) 48 | wavDecoder := wav.NewDecoder(wavReader) 49 | if !wavDecoder.IsValidFile() { 50 | return nil, fmt.Errorf("无效的WAV文件") 51 | } 52 | 53 | // 读取WAV文件信息 54 | wavDecoder.ReadInfo() 55 | format := wavDecoder.Format() 56 | wavSampleRate := int(format.SampleRate) 57 | wavChannels := int(format.NumChannels) 58 | 59 | // 如果提供的参数与文件参数不一致,使用文件中的参数 60 | if sampleRate == 0 { 61 | sampleRate = wavSampleRate 62 | } 63 | if channels == 0 { 64 | channels = wavChannels 65 | } 66 | 67 | //打印wavDecoder信息 68 | fmt.Println("WAV格式:", format) 69 | 70 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 71 | if err != nil { 72 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 73 | } 74 | 75 | dec, err := opus.NewDecoder(sampleRate, channels) 76 | if err != nil { 77 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 78 | } 79 | 80 | // 设置比特率 81 | if bitRate > 0 { 82 | if err := enc.SetBitrate(bitRate); err != nil { 83 | return nil, fmt.Errorf("设置比特率失败: %v", err) 84 | } 85 | } 86 | 87 | // 创建输出帧切片数组 88 | opusFrames := make([][]byte, 0) 89 | 90 | perFrameDuration := 60 91 | // PCM缓冲区 - Opus帧大小(60ms) 92 | frameSize := sampleRate * perFrameDuration / 1000 93 | pcmBuffer := make([]int16, frameSize*channels) 94 | pcmBufferFloat32 := make([]float32, frameSize*channels) 95 | opusBuffer := make([]byte, 1000) // 足够大的缓冲区存储编码后的数据 96 | 97 | // 读取音频缓冲区 98 | audioBuf := &audio.IntBuffer{Data: make([]int, frameSize*channels), Format: format} 99 | 100 | fmt.Println("开始转换...") 101 | 102 | pcmAllData := make([]float32, 0) 103 | for { 104 | // 读取WAV数据 105 | n, err := wavDecoder.PCMBuffer(audioBuf) 106 | if err == io.EOF || n == 0 { 107 | break 108 | } 109 | if err 
!= nil { 110 | return nil, fmt.Errorf("读取WAV数据失败: %v", err) 111 | } 112 | 113 | // 将int转换为int16 114 | for i := 0; i < len(audioBuf.Data); i++ { 115 | if i < len(pcmBuffer) { 116 | pcmBuffer[i] = int16(audioBuf.Data[i]) 117 | } 118 | } 119 | 120 | // 编码为Opus格式 121 | n, err = enc.Encode(pcmBuffer, opusBuffer) 122 | if err != nil { 123 | return nil, fmt.Errorf("编码失败: %v", err) 124 | } 125 | 126 | // 将当前帧复制到新的切片中并添加到帧数组 127 | frameData := make([]byte, n) 128 | copy(frameData, opusBuffer[:n]) 129 | opusFrames = append(opusFrames, frameData) 130 | 131 | //将opus解码至pcm 132 | n, err = dec.DecodeFloat32(frameData, pcmBufferFloat32) 133 | if err != nil { 134 | return nil, fmt.Errorf("解码失败: %v", err) 135 | } 136 | 137 | fmt.Printf("pcmBufferFloat32 len: %d\n", len(pcmBufferFloat32[:n])) 138 | 139 | segments, err := sd.Detect(pcmBufferFloat32[:n]) 140 | if err != nil { 141 | //log.Fatalf("Detect failed: %s", err) 142 | } 143 | fmt.Printf("detect voice: %v\n", segments) 144 | 145 | pcmAllData = append(pcmAllData, pcmBufferFloat32[:n]...) 
146 | } 147 | 148 | segments, err := sd.Detect(pcmAllData) 149 | if err != nil { 150 | log.Fatalf("Detect failed: %s", err) 151 | } 152 | fmt.Printf("detect voice: %v\n", segments) 153 | 154 | //将frameData输出至test.opus 155 | opusFile, err := os.OpenFile("output.opus", os.O_CREATE|os.O_WRONLY, 0644) 156 | if err != nil { 157 | log.Fatalf("failed to create opus file: %s", err) 158 | } 159 | opusFile.Write(opusFrames[0]) 160 | opusFile.Close() 161 | 162 | /* 163 | //将pcm数据输出至test.pcm 164 | pcmFile, err := os.OpenFile("test.pcm", os.O_CREATE|os.O_WRONLY, 0644) 165 | if err != nil { 166 | log.Fatalf("failed to create pcm file: %s", err) 167 | } 168 | 169 | defer pcmFile.Close() 170 | dec, err := opus.NewDecoder(sampleRate, channels) 171 | if err != nil { 172 | return nil, fmt.Errorf("创建Opus解码器失败: %v", err) 173 | } 174 | 175 | pcmBuffer = make([]int16, 10240) 176 | for _, data := range opusFrames { 177 | //将opus数据decode成pcm 178 | n, err := dec.Decode(data, pcmBuffer) 179 | if err != nil { 180 | return nil, fmt.Errorf("解码失败: %v", err) 181 | } 182 | frameData := make([]int16, len(pcmBuffer)*2) 183 | copy(frameData, pcmBuffer[:n]) 184 | _, err = pcmFile.Write(frameData) 185 | if err != nil { 186 | log.Fatalf("failed to write to pcm file: %s", err) 187 | } 188 | }*/ 189 | 190 | return opusFrames, nil 191 | } 192 | 193 | func main() { 194 | if len(os.Args) != 2 { 195 | log.Fatalf("invalid arguments provided: expecting one file path") 196 | } 197 | 198 | f, err := os.Open(os.Args[1]) 199 | if err != nil { 200 | log.Fatalf("failed to open sample audio file: %s", err) 201 | } 202 | defer f.Close() 203 | 204 | //读取文件全部内容 205 | mp3Data, err := io.ReadAll(f) 206 | if err != nil { 207 | log.Fatalf("failed to read mp3 file: %s", err) 208 | } 209 | 210 | //将mp3转换为opus 211 | opusData, err := WavToOpus(mp3Data, 16000, 1, 0) 212 | if err != nil { 213 | log.Fatalf("failed to convert mp3 to opus: %s", err) 214 | } 215 | 216 | //打印opus数据 217 | fmt.Printf("opusData: %d\n", len(opusData)) 218 
| 219 | //将Opus数据decode成pcm 220 | 221 | //将所有数据输出至test.opus 222 | /*opusFile, err := os.OpenFile("test.opus", os.O_CREATE|os.O_WRONLY, 0644) 223 | if err != nil { 224 | log.Fatalf("failed to create opus file: %s", err) 225 | } 226 | defer opusFile.Close() 227 | 228 | for _, data := range opusData { 229 | _, err := opusFile.Write(data) 230 | if err != nil { 231 | log.Fatalf("failed to write to opus file: %s", err) 232 | } 233 | }*/ 234 | } 235 | -------------------------------------------------------------------------------- /internal/util/buffer.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "sync" 6 | ) 7 | 8 | // SafeBuffer 是一个协程安全的缓冲区 9 | type SafeBuffer struct { 10 | buf bytes.Buffer 11 | mu sync.Mutex 12 | } 13 | 14 | func (b *SafeBuffer) Write(p []byte) (n int, err error) { 15 | b.mu.Lock() 16 | defer b.mu.Unlock() 17 | return b.buf.Write(p) 18 | } 19 | 20 | func (b *SafeBuffer) Read(p []byte) (n int, err error) { 21 | b.mu.Lock() 22 | defer b.mu.Unlock() 23 | return b.buf.Read(p) 24 | } 25 | 26 | func (b *SafeBuffer) Len() int { 27 | b.mu.Lock() 28 | defer b.mu.Unlock() 29 | return b.buf.Len() 30 | } 31 | 32 | func (b *SafeBuffer) Bytes() []byte { 33 | b.mu.Lock() 34 | defer b.mu.Unlock() 35 | return b.buf.Bytes() 36 | } 37 | 38 | func (b *SafeBuffer) Reset() { 39 | b.mu.Lock() 40 | defer b.mu.Unlock() 41 | b.buf.Reset() 42 | } 43 | 44 | func (b *SafeBuffer) Cap() int { 45 | b.mu.Lock() 46 | defer b.mu.Unlock() 47 | return b.buf.Cap() 48 | } 49 | -------------------------------------------------------------------------------- /internal/util/encryption.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/sha256" 7 | "encoding/hex" 8 | "fmt" 9 | ) 10 | 11 | func AesCTREncrypt(key, nonce, plaintext []byte) ([]byte, error) { 12 | block, err := aes.NewCipher(key) 
13 | if err != nil { 14 | return nil, fmt.Errorf("failed to create cipher: %v", err) 15 | } 16 | 17 | stream := cipher.NewCTR(block, nonce) 18 | ciphertext := make([]byte, len(plaintext)) 19 | stream.XORKeyStream(ciphertext, plaintext) 20 | return ciphertext, nil 21 | } 22 | 23 | func AesCTRDecrypt(key, nonce, ciphertext []byte) ([]byte, error) { 24 | block, err := aes.NewCipher(key) 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to create cipher: %v", err) 27 | } 28 | 29 | stream := cipher.NewCTR(block, nonce) 30 | plaintext := make([]byte, len(ciphertext)) 31 | stream.XORKeyStream(plaintext, ciphertext) 32 | return plaintext, nil 33 | } 34 | 35 | func Sha256Digest(data []byte) string { 36 | hash := sha256.Sum256(data) 37 | strHash := hex.EncodeToString(hash[:]) 38 | return strHash 39 | } 40 | -------------------------------------------------------------------------------- /internal/util/workqueue/parallelizer.go: -------------------------------------------------------------------------------- 1 | package workqueue 2 | 3 | import ( 4 | "context" 5 | "math" 6 | "runtime/debug" 7 | "sync" 8 | 9 | "go.uber.org/zap" 10 | ) 11 | 12 | type DoWorkPieceFunc func(piece int) 13 | 14 | // ParallelizeUntil is a framework that allows for parallelizing N 15 | // independent pieces of work until done or the context is canceled. 
16 | func ParallelizeUntil(ctx context.Context, workers, pieces int, doWorkPiece DoWorkPieceFunc) { 17 | var stop <-chan struct{} 18 | if ctx != nil { 19 | stop = ctx.Done() 20 | } 21 | 22 | toProcess := make(chan int, pieces) 23 | for i := 0; i < pieces; i++ { 24 | toProcess <- i 25 | } 26 | close(toProcess) 27 | 28 | if pieces < workers { 29 | workers = pieces 30 | } 31 | 32 | wg := sync.WaitGroup{} 33 | wg.Add(workers) 34 | for i := 0; i < workers; i++ { 35 | go func() { 36 | defer func() { 37 | wg.Done() 38 | if r := recover(); r != nil { 39 | zap.L().Error("work has panic", zap.Any("panic", r)) 40 | debug.PrintStack() 41 | } 42 | }() 43 | for piece := range toProcess { 44 | select { 45 | case <-stop: 46 | return 47 | default: 48 | doWorkPiece(piece) 49 | } 50 | } 51 | }() 52 | } 53 | wg.Wait() 54 | } 55 | 56 | // ParallelizeUntil is a framework that allows for parallelizing N 57 | // independent pieces of work until done or the context is canceled. 58 | func ParallelizeUntilOptimize(ctx context.Context, workers, pieces int, doWorkPiece DoWorkPieceFunc) { 59 | var stop <-chan struct{} 60 | if ctx != nil { 61 | stop = ctx.Done() 62 | } 63 | 64 | if pieces < workers { 65 | workers = pieces 66 | } 67 | 68 | page := int(math.Ceil(float64(pieces) / float64(workers))) 69 | 70 | wg := sync.WaitGroup{} 71 | wg.Add(workers) 72 | for i := 0; i < workers; i++ { 73 | go func(workIndex int) { 74 | defer func() { 75 | wg.Done() 76 | if r := recover(); r != nil { 77 | zap.L().Error("work has panic", zap.Any("panic", r)) 78 | debug.PrintStack() 79 | } 80 | }() 81 | start := page * workIndex 82 | end := start + page 83 | if end >= pieces { 84 | end = pieces 85 | } 86 | 87 | for j := start; j < end; j++ { 88 | select { 89 | case <-stop: 90 | return 91 | default: 92 | doWorkPiece(j) 93 | } 94 | } 95 | }(i) 96 | } 97 | wg.Wait() 98 | } 99 | -------------------------------------------------------------------------------- /internal/util/workqueue/parallelizer_test.go: 
-------------------------------------------------------------------------------- 1 | package workqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestNewParalle(t *testing.T) { 11 | optimize(100000000, 16) 12 | notOptimize(100000000, 16) 13 | } 14 | 15 | func optimize(count int, concurrent int) { 16 | idList := []int{} 17 | for i := 0; i < count; i++ { 18 | idList = append(idList, i) 19 | } 20 | 21 | startTime := time.Now().UnixNano() 22 | ParallelizeUntilOptimize(context.Background(), concurrent, len(idList), func(piece int) { 23 | sum := idList[piece] + 10 24 | _ = sum 25 | }) 26 | optimCost := time.Now().UnixNano() - startTime 27 | _ = optimCost 28 | fmt.Println(optimCost) 29 | } 30 | 31 | func notOptimize(count int, concurrent int) { 32 | idList := []int{} 33 | for i := 0; i <= count; i++ { 34 | idList = append(idList, i) 35 | } 36 | 37 | startTime := time.Now().UnixNano() 38 | ParallelizeUntil(context.Background(), concurrent, len(idList), func(piece int) { 39 | sum := idList[piece] + 10 40 | _ = sum 41 | }) 42 | cost := time.Now().UnixNano() - startTime 43 | 44 | fmt.Println(cost) 45 | } 46 | 47 | /* 48 | func TestNewParalle(t *testing.T) { 49 | idList := []int{} 50 | for i := 0; i <= 100000000; i++ { 51 | idList = append(idList, i) 52 | } 53 | 54 | startTime := time.Now().UnixNano() 55 | ParallelizeUntilOptimize(context.Background(), 8, len(idList), func(piece int) { 56 | sum := idList[piece] + 10 57 | _ = sum 58 | }) 59 | optimCost := time.Now().UnixNano() - startTime 60 | 61 | fmt.Println(optimCost) 62 | 63 | startTime = time.Now().UnixNano() 64 | ParallelizeUntil(context.Background(), 8, len(idList), func(piece int) { 65 | sum := idList[piece] + 10 66 | _ = sum 67 | }) 68 | cost := time.Now().UnixNano() - startTime 69 | 70 | fmt.Println(cost) 71 | 72 | fmt.Println(cost / optimCost) 73 | } 74 | */ 75 | -------------------------------------------------------------------------------- 
/internal/util/workqueue/parallelizer_test.go.bak: -------------------------------------------------------------------------------- 1 | package workqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestNewParalle(t *testing.T) { 11 | idList := []int{} 12 | for i := 0; i <= 100000000; i++ { 13 | idList = append(idList, i) 14 | } 15 | 16 | startTime := time.Now().UnixNano() 17 | ParallelizeUntilOptimize(context.Background(), 8, len(idList), func(piece int) { 18 | sum := idList[piece] + 10 19 | _ = sum 20 | }) 21 | optimCost := time.Now().UnixNano() - startTime 22 | 23 | fmt.Println(optimCost) 24 | 25 | startTime = time.Now().UnixNano() 26 | ParallelizeUntil(context.Background(), 8, len(idList), func(piece int) { 27 | sum := idList[piece] + 10 28 | _ = sum 29 | }) 30 | cost := time.Now().UnixNano() - startTime 31 | 32 | fmt.Println(cost) 33 | 34 | fmt.Println(cost / optimCost) 35 | } 36 | -------------------------------------------------------------------------------- /logger/db_log.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | log "github.com/sirupsen/logrus" 6 | ) 7 | 8 | type GormLog struct { 9 | clog *log.Logger 10 | } 11 | 12 | var DbLog *GormLog 13 | 14 | func InitDbLog(clog *log.Logger) { 15 | DbLog = &GormLog{ 16 | clog: clog, 17 | } 18 | } 19 | 20 | func (d *GormLog) Printf(format string, args ...interface{}) { 21 | logStr := fmt.Sprintf(format, args...) 
22 | d.clog.Info(logStr) 23 | } 24 | -------------------------------------------------------------------------------- /logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "runtime" 8 | 9 | nested "github.com/antonfisher/nested-logrus-formatter" 10 | log "github.com/sirupsen/logrus" 11 | ) 12 | 13 | const ( 14 | TYPE_HTTP = 1 15 | ) 16 | 17 | func init() { 18 | // 不设置默认输出,由应用程序决定 19 | log.SetFormatter(Formatter(false)) // 默认不使用颜色 20 | } 21 | 22 | // SetOutput 设置日志输出目标 23 | func SetOutput(out *os.File) { 24 | log.SetOutput(out) 25 | } 26 | 27 | // SetLevel 设置日志级别 28 | func SetLevel(level log.Level) { 29 | log.SetLevel(level) 30 | } 31 | 32 | // UseStdout 使用标准输出 33 | func UseStdout() { 34 | log.SetOutput(os.Stdout) 35 | log.SetFormatter(Formatter(true)) 36 | } 37 | 38 | /* 39 | func getUserInfo(ctx *gin.Context) int { 40 | if data, ok := ctx.Get("uid"); ok { 41 | if uid, ok := data.(int); ok { 42 | return uid 43 | } 44 | } 45 | return 0 46 | } 47 | */ 48 | 49 | // getCaller 获取实际的调用者信息(跳过logger包装层) 50 | func getCaller() (string, int) { 51 | // 跳过日志库的调用栈,获取实际调用者 52 | // 通过调用栈:用户代码 -> logger.Info -> addCallerField -> getCaller -> runtime.Caller 53 | // 所以需要跳过3层才能到达实际调用位置 54 | _, file, line, ok := runtime.Caller(3) 55 | if !ok { 56 | return "unknown", 0 57 | } 58 | // 提取文件名(不带路径) 59 | shortFile := filepath.Base(file) 60 | return shortFile, line 61 | } 62 | 63 | // addCallerField 添加调用者信息到日志字段 64 | func addCallerField() *log.Entry { 65 | file, line := getCaller() 66 | return log.WithField("caller", fmt.Sprintf("%s:%d", file, line)) 67 | } 68 | 69 | func Info(args ...interface{}) { 70 | addCallerField().Info(args...) 71 | } 72 | 73 | func Error(args ...interface{}) { 74 | addCallerField().Error(args...) 75 | } 76 | 77 | func Debug(args ...interface{}) { 78 | addCallerField().Debug(args...) 
79 | } 80 | 81 | func Warn(args ...interface{}) { 82 | addCallerField().Warn(args...) 83 | } 84 | 85 | func Fatal(args ...interface{}) { 86 | addCallerField().Fatal(args...) 87 | } 88 | 89 | func Infof(format string, args ...interface{}) { 90 | addCallerField().Infof(format, args...) 91 | } 92 | 93 | func Errorf(format string, args ...interface{}) { 94 | addCallerField().Errorf(format, args...) 95 | } 96 | 97 | func Debugf(format string, args ...interface{}) { 98 | addCallerField().Debugf(format, args...) 99 | } 100 | 101 | func Warnf(format string, args ...interface{}) { 102 | addCallerField().Warnf(format, args...) 103 | } 104 | 105 | func Fatalf(format string, args ...interface{}) { 106 | addCallerField().Fatalf(format, args...) 107 | } 108 | 109 | func Log(args ...interface{}) *log.Entry { 110 | fields := log.Fields{} 111 | lenArgs := len(args) 112 | for i := 0; i < lenArgs; i = i + 2 { 113 | var key string 114 | var ok bool 115 | if key, ok = args[i].(string); !ok { 116 | continue 117 | } 118 | 119 | if i <= lenArgs-2 { 120 | fields[key] = args[i+1] 121 | continue 122 | } 123 | fields[key] = "" 124 | } 125 | 126 | // 添加调用者信息 127 | // 在Log函数调用链中也需要调整层级 128 | _, file, line, ok := runtime.Caller(2) 129 | if !ok { 130 | file = "unknown" 131 | line = 0 132 | } 133 | shortFile := filepath.Base(file) 134 | fields["caller"] = fmt.Sprintf("%s:%d", shortFile, line) 135 | 136 | log.SetFormatter(Formatter(true)) 137 | return log.WithFields(fields) 138 | } 139 | 140 | func Formatter(isConsole bool) *nested.Formatter { 141 | fmtter := &nested.Formatter{ 142 | FieldsOrder: []string{"time", "level", "caller", "msg"}, 143 | HideKeys: true, 144 | TimestampFormat: "2006-01-02 15:04:05.000", 145 | CallerFirst: true, 146 | NoUppercaseLevel: true, 147 | ShowFullLevel: true, 148 | //NoFieldsSpace: true, 149 | // 禁用默认的调用者格式化,因为我们已经添加了自定义的caller字段 150 | CustomCallerFormatter: func(frame *runtime.Frame) string { 151 | return "" 152 | }, 153 | } 154 | if isConsole { 155 | 
fmtter.NoColors = false 156 | } else { 157 | fmtter.NoColors = true 158 | } 159 | return fmtter 160 | } 161 | 162 | // DebugStack 用于调试日志调用栈,输出当前调用链的所有调用者信息 163 | func DebugStack() { 164 | for i := 0; i < 5; i++ { 165 | _, file, line, ok := runtime.Caller(i) 166 | if !ok { 167 | break 168 | } 169 | shortFile := filepath.Base(file) 170 | log.Infof("调用栈[%d]: %s:%d", i, shortFile, line) 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /test/mem0/__pycache__/mem0.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/test/mem0/__pycache__/mem0.cpython-312.pyc -------------------------------------------------------------------------------- /test/mem0/memory.py: -------------------------------------------------------------------------------- 1 | from mem0 import MemoryClient 2 | client = MemoryClient(api_key="m0-bVuO6XCOOknqlPnDXbFWl2RdSyPH3ZY7sgdOo3yD") 3 | 4 | messages = [ 5 | {"role": "user", "content": "Hi, I'm Alex. I'm a vegetarian and I'm allergic to nuts."}, 6 | {"role": "assistant", "content": "Hello Alex! I've noted that you're a vegetarian and have a nut allergy. I'll keep this in mind for any food-related recommendations or discussions."} 7 | ] 8 | client.add(messages, user_id="alex") 9 | 10 | query = "What can I cook for dinner tonight?" 
11 | c = client.search(query, user_id="alex") 12 | print(c) 13 | -------------------------------------------------------------------------------- /test/mqtt_udp/README.md: -------------------------------------------------------------------------------- 1 | 测试小智官方服务器mqtt+udp协议 响应速度 2 | 结果: 3 | stt 166ms,llm 300ms左右,首帧音频642ms 4 | -------------------------------------------------------------------------------- /test/mqtt_udp/audio_utils: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/test/mqtt_udp/audio_utils -------------------------------------------------------------------------------- /test/mqtt_udp/audio_utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "github.com/go-audio/audio" 10 | "github.com/go-audio/wav" 11 | "gopkg.in/hraban/opus.v2" 12 | ) 13 | 14 | // WavToOpus 将WAV音频数据转换为标准Opus格式 15 | // 返回Opus帧的切片集合,每个切片是一个Opus编码帧 16 | func WavToOpus(wavData []byte, sampleRate int, channels int, bitRate int) ([][]byte, error) { 17 | // 创建WAV解码器 18 | wavReader := bytes.NewReader(wavData) 19 | wavDecoder := wav.NewDecoder(wavReader) 20 | if !wavDecoder.IsValidFile() { 21 | return nil, fmt.Errorf("无效的WAV文件") 22 | } 23 | 24 | // 读取WAV文件信息 25 | wavDecoder.ReadInfo() 26 | format := wavDecoder.Format() 27 | wavSampleRate := int(format.SampleRate) 28 | wavChannels := int(format.NumChannels) 29 | 30 | // 如果提供的参数与文件参数不一致,使用文件中的参数 31 | if sampleRate == 0 { 32 | sampleRate = wavSampleRate 33 | } 34 | if channels == 0 { 35 | channels = wavChannels 36 | } 37 | 38 | //打印wavDecoder信息 39 | fmt.Println("WAV格式:", format) 40 | 41 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 42 | if err != nil { 43 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 44 | } 45 | 46 | // 设置比特率 47 | if bitRate > 0 { 48 | if 
err := enc.SetBitrate(bitRate); err != nil { 49 | return nil, fmt.Errorf("设置比特率失败: %v", err) 50 | } 51 | } 52 | 53 | // 创建输出帧切片数组 54 | opusFrames := make([][]byte, 0) 55 | 56 | perFrameDuration := 60 57 | // PCM缓冲区 - Opus帧大小(60ms) 58 | frameSize := sampleRate * perFrameDuration / 1000 59 | pcmBuffer := make([]int16, frameSize*channels) 60 | opusBuffer := make([]byte, 1000) // 足够大的缓冲区存储编码后的数据 61 | 62 | // 读取音频缓冲区 63 | audioBuf := &audio.IntBuffer{Data: make([]int, frameSize*channels), Format: format} 64 | 65 | fmt.Println("开始转换...") 66 | for { 67 | // 读取WAV数据 68 | n, err := wavDecoder.PCMBuffer(audioBuf) 69 | if err == io.EOF || n == 0 { 70 | break 71 | } 72 | if err != nil { 73 | return nil, fmt.Errorf("读取WAV数据失败: %v", err) 74 | } 75 | 76 | // 将int转换为int16 77 | for i := 0; i < len(audioBuf.Data); i++ { 78 | if i < len(pcmBuffer) { 79 | pcmBuffer[i] = int16(audioBuf.Data[i]) 80 | } 81 | } 82 | 83 | // 编码为Opus格式 84 | n, err = enc.Encode(pcmBuffer, opusBuffer) 85 | if err != nil { 86 | return nil, fmt.Errorf("编码失败: %v", err) 87 | } 88 | 89 | // 将当前帧复制到新的切片中并添加到帧数组 90 | frameData := make([]byte, n) 91 | copy(frameData, opusBuffer[:n]) 92 | opusFrames = append(opusFrames, frameData) 93 | } 94 | 95 | return opusFrames, nil 96 | } 97 | 98 | func OpusToWav(opusData [][]byte, sampleRate int, channels int, fileName string) ([][]int16, error) { 99 | opusDecoder, err := opus.NewDecoder(sampleRate, channels) 100 | if err != nil { 101 | return nil, fmt.Errorf("创建Opus解码器失败: %v", err) 102 | } 103 | 104 | wavOut, err := os.Create(fileName) 105 | if err != nil { 106 | return nil, fmt.Errorf("创建WAV文件失败: %v", err) 107 | } 108 | 109 | pcmDataList := make([][]int16, 0) 110 | pcmBuffer := make([]int16, 4096) 111 | 112 | wavEncoder := wav.NewEncoder(wavOut, sampleRate, 16, channels, 1) 113 | wavBuffer := audio.IntBuffer{ 114 | Format: &audio.Format{ 115 | NumChannels: channels, // 使用传入的通道数 116 | SampleRate: sampleRate, 117 | }, 118 | SourceBitDepth: 16, 119 | Data: make([]int, 4096), 120 
| } 121 | 122 | for _, frame := range opusData { 123 | n, err := opusDecoder.Decode(frame, pcmBuffer) 124 | if err != nil { 125 | return nil, fmt.Errorf("解码失败: %v", err) 126 | } 127 | copyData := make([]int16, len(pcmBuffer[:n])) 128 | copy(copyData, pcmBuffer[:n]) 129 | pcmDataList = append(pcmDataList, copyData) 130 | 131 | fmt.Println("pcmData len: ", len(copyData)) 132 | 133 | // 将PCM数据转换为int格式 134 | for i := 0; i < len(copyData); i++ { 135 | wavBuffer.Data = append(wavBuffer.Data, int(copyData[i])) 136 | } 137 | } 138 | 139 | // 写入WAV文件 140 | err = wavEncoder.Write(&wavBuffer) 141 | if err != nil { 142 | return nil, fmt.Errorf("写入WAV文件失败: %v", err) 143 | } 144 | 145 | wavEncoder.Close() 146 | 147 | return pcmDataList, nil 148 | } 149 | -------------------------------------------------------------------------------- /test/mqtt_udp/go.sum: -------------------------------------------------------------------------------- 1 | github.com/eclipse/paho.mqtt.golang v1.5.0 h1:EH+bUVJNgttidWFkLLVKaQPGmkTUfQQqjOsyvMGvD6o= 2 | github.com/eclipse/paho.mqtt.golang v1.5.0/go.mod h1:du/2qNQVqJf/Sqs4MEL77kR8QTqANF7XU7Fk0aOTAgk= 3 | github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= 4 | github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= 5 | github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= 6 | github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= 7 | github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= 8 | github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= 9 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 10 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 11 | github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= 12 | github.com/gorilla/websocket v1.5.3/go.mod 
h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 13 | golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= 14 | golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= 15 | golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= 16 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 17 | gopkg.in/hraban/opus.v2 v2.0.0-20230925203106-0188a62cb302 h1:xeVptzkP8BuJhoIjNizd2bRHfq9KB9HfOLZu90T04XM= 18 | gopkg.in/hraban/opus.v2 v2.0.0-20230925203106-0188a62cb302/go.mod h1:/L5E7a21VWl8DeuCPKxQBdVG5cy+L0MRZ08B1wnqt7g= 19 | -------------------------------------------------------------------------------- /test/mqtt_udp/main: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/test/mqtt_udp/main -------------------------------------------------------------------------------- /test/mqtt_udp/mqtt.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import mqtt "github.com/eclipse/paho.mqtt.golang" 4 | 5 | type MqttClient struct { 6 | instance mqtt.Client 7 | ClientId string 8 | Username string 9 | Password string 10 | Endpoint string 11 | PublishTopic string 12 | OnMessage mqtt.MessageHandler 13 | } 14 | 15 | func NewMqttClient(clientId, username, password, endpoint, publishTopic string, OnMessage mqtt.MessageHandler) *MqttClient { 16 | return &MqttClient{ 17 | ClientId: clientId, 18 | Username: username, 19 | Password: password, 20 | Endpoint: endpoint, 21 | PublishTopic: publishTopic, 22 | OnMessage: OnMessage, 23 | } 24 | } 25 | 26 | func (m *MqttClient) Connect() error { 27 | opts := mqtt.NewClientOptions().AddBroker(m.Endpoint).SetClientID(m.ClientId) 28 | opts.SetUsername(m.Username) 29 | opts.SetPassword(m.Password) 30 | 31 | instance := mqtt.NewClient(opts) 32 | if token := 
instance.Connect(); token.Wait() && token.Error() != nil { 33 | return token.Error() 34 | } 35 | m.instance = instance 36 | return nil 37 | } 38 | 39 | func (m *MqttClient) Publish(topic string, payload []byte) error { 40 | token := m.instance.Publish(topic, 0, false, payload) 41 | token.Wait() 42 | if token.Error() != nil { 43 | return token.Error() 44 | } 45 | return nil 46 | } 47 | 48 | func (m *MqttClient) Subscribe(topic string, callback mqtt.MessageHandler) error { 49 | token := m.instance.Subscribe(topic, 0, callback) 50 | token.Wait() 51 | return token.Error() 52 | } 53 | -------------------------------------------------------------------------------- /test/mqtt_udp/mqtt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import paho.mqtt.client as mqtt 4 | 5 | # hello 消息体结构 6 | def build_hello_message(): 7 | return { 8 | "type": "hello", 9 | "version": 3, 10 | "transport": "udp", 11 | "audio_format": { 12 | "format": "opus", 13 | "sample_rate": 16000, 14 | "channels": 1, 15 | "frame_duration": 60 16 | } 17 | } 18 | 19 | def on_connect(client, userdata, flags, rc): 20 | if rc == 0: 21 | print("MQTT 连接成功") 22 | # 连接成功后发布 hello 消息 23 | public_hello(client) 24 | else: 25 | print("连接失败,返回码:", rc) 26 | sys.exit(1) 27 | 28 | def on_message(client, userdata, msg): 29 | print(f"收到消息: [{msg.topic}] {msg.payload.decode('utf-8')}") 30 | 31 | def public_hello(client): 32 | topic = "device-server" 33 | message = build_hello_message() 34 | json_data = json.dumps(message) 35 | print("发布消息:", json_data) 36 | result = client.publish(topic, json_data, qos=0, retain=False) 37 | result.wait_for_publish() 38 | if result.is_published(): 39 | print("发布消息成功") 40 | else: 41 | print("发布消息失败") 42 | 43 | def main(): 44 | broker = "mqtt.xiaozhi.me" 45 | port = 8883 46 | client_id = "GID_test@@@02_4A_7D_E3_89_BF@@@e3b0c442-98fc-4e1a-8c3d-6a5b6a5b6a5b" 47 | username = "eyJpcCI6IjEuMjAyLjE5My4xOTQifQ==" 48 | password = 
"Ru9zRLdD/4wrBYorxIyABtHe8EiA1hdZ4v34juJ2BUU=" 49 | 50 | client = mqtt.Client(client_id=client_id) 51 | client.username_pw_set(username, password) 52 | client.tls_set() # 使用 SSL 连接 53 | 54 | client.on_connect = on_connect 55 | client.on_message = on_message 56 | 57 | client.connect(broker, port, keepalive=60) 58 | client.loop_forever() 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /test/mqtt_udp/test_24000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/test/mqtt_udp/test_24000.wav -------------------------------------------------------------------------------- /test/mqtt_udp/udp.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "encoding/hex" 7 | "fmt" 8 | "net" 9 | "time" 10 | ) 11 | 12 | type UDPClient struct { 13 | udpConn *net.UDPConn 14 | serverAddr *net.UDPAddr 15 | aesKey string 16 | aesNonce string 17 | localSeq uint32 18 | } 19 | 20 | func NewUDPClient(serverAddr string, port int, aesKey, aesNonce string) (*UDPClient, error) { 21 | addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", serverAddr, port)) 22 | if err != nil { 23 | return nil, fmt.Errorf("failed to resolve UDP address: %v", err) 24 | } 25 | 26 | conn, err := net.DialUDP("udp", nil, addr) 27 | if err != nil { 28 | return nil, fmt.Errorf("failed to create UDP connection: %v", err) 29 | } 30 | 31 | return &UDPClient{ 32 | udpConn: conn, 33 | serverAddr: addr, 34 | aesKey: aesKey, 35 | aesNonce: aesNonce, 36 | localSeq: 0, 37 | }, nil 38 | } 39 | 40 | func (c *UDPClient) Close() { 41 | if c.udpConn != nil { 42 | c.udpConn.Close() 43 | } 44 | } 45 | 46 | func AesCTREncrypt(key, nonce, plaintext []byte) ([]byte, error) { 47 | block, err := 
aes.NewCipher(key) 48 | if err != nil { 49 | return nil, fmt.Errorf("failed to create cipher: %v", err) 50 | } 51 | 52 | stream := cipher.NewCTR(block, nonce) 53 | ciphertext := make([]byte, len(plaintext)) 54 | stream.XORKeyStream(ciphertext, plaintext) 55 | return ciphertext, nil 56 | } 57 | 58 | func AesCTRDecrypt(key, nonce, ciphertext []byte) ([]byte, error) { 59 | block, err := aes.NewCipher(key) 60 | if err != nil { 61 | return nil, fmt.Errorf("failed to create cipher: %v", err) 62 | } 63 | 64 | stream := cipher.NewCTR(block, nonce) 65 | plaintext := make([]byte, len(ciphertext)) 66 | stream.XORKeyStream(plaintext, ciphertext) 67 | return plaintext, nil 68 | } 69 | 70 | func (c *UDPClient) aesCTREncrypt(key, nonce, plaintext []byte) ([]byte, error) { 71 | block, err := aes.NewCipher(key) 72 | if err != nil { 73 | return nil, fmt.Errorf("failed to create cipher: %v", err) 74 | } 75 | 76 | stream := cipher.NewCTR(block, nonce) 77 | ciphertext := make([]byte, len(plaintext)) 78 | stream.XORKeyStream(ciphertext, plaintext) 79 | return ciphertext, nil 80 | } 81 | 82 | func (c *UDPClient) aesCTRDecrypt(key, nonce, ciphertext []byte) ([]byte, error) { 83 | block, err := aes.NewCipher(key) 84 | if err != nil { 85 | return nil, fmt.Errorf("failed to create cipher: %v", err) 86 | } 87 | 88 | stream := cipher.NewCTR(block, nonce) 89 | plaintext := make([]byte, len(ciphertext)) 90 | stream.XORKeyStream(plaintext, ciphertext) 91 | return plaintext, nil 92 | } 93 | 94 | func (c *UDPClient) decryptAudioData(key []byte, data []byte) ([]byte, error) { 95 | //分离nonce和密文 96 | nonce := data[:16] 97 | ciphertext := data[16:] 98 | 99 | //解密 100 | decryptedData, err := c.aesCTRDecrypt(key, nonce, ciphertext) 101 | if err != nil { 102 | return nil, fmt.Errorf("failed to decrypt data: %v", err) 103 | } 104 | 105 | return decryptedData, nil 106 | } 107 | 108 | func (c *UDPClient) SendAudioData(audioData []byte) error { 109 | // 生成新的nonce 110 | c.localSeq = (c.localSeq + 1) & 
0xFFFFFFFF 111 | 112 | // 构建nonce字符串: 固定前缀 + 长度 + 原始nonce + 序列号 113 | nonceHex := c.aesNonce[:4] + // 固定前缀 (01000000) 114 | fmt.Sprintf("%04x", len(audioData)) + // 数据长度,4个十六进制字符 115 | c.aesNonce[8:24] + // 原始nonce 116 | fmt.Sprintf("%08x", c.localSeq) // 序列号,8个十六进制字符 117 | 118 | //fmt.Printf("c.aesNonce: %s len: %d, nonceHex: %s len: %d\n", c.aesNonce, len(c.aesNonce), nonceHex, len(nonceHex)) 119 | 120 | // 加密数据 121 | key, err := hex.DecodeString(c.aesKey) 122 | if err != nil { 123 | return fmt.Errorf("failed to decode AES key: %v", err) 124 | } 125 | 126 | nonceBytes, err := hex.DecodeString(nonceHex) 127 | if err != nil { 128 | return fmt.Errorf("failed to decode nonce: %v", err) 129 | } 130 | 131 | // 检查IV长度 132 | //fmt.Printf("IV长度: %d 字节, 内容: %x\n", len(nonceBytes), nonceBytes) 133 | 134 | iv := nonceBytes 135 | 136 | encryptedData, err := c.aesCTREncrypt(key, iv, audioData) 137 | if err != nil { 138 | return fmt.Errorf("failed to encrypt data: %v", err) 139 | } 140 | 141 | // 拼接nonce和密文 142 | packet := append(nonceBytes, encryptedData...) 
143 | 144 | // 发送数据包 145 | _, err = c.udpConn.Write(packet) 146 | if err != nil { 147 | return fmt.Errorf("failed to send UDP packet: %v", err) 148 | } 149 | 150 | //fmt.Printf("发送数据: nonce=%s, seq=%d, dataLen=%d\n", nonceHex, c.localSeq, len(audioData)) 151 | 152 | return nil 153 | } 154 | 155 | func (c *UDPClient) ReceiveAudioData(key []byte, cb func(key []byte, data []byte)) error { 156 | go func() { 157 | for { 158 | buffer := make([]byte, 1024) 159 | n, _, err := c.udpConn.ReadFromUDP(buffer) 160 | if err != nil { 161 | fmt.Println(err) 162 | return 163 | } 164 | 165 | if !firstAudio { 166 | firstAudio = true 167 | fmt.Printf("收到第一条音频消息, 耗时: %d ms\n", time.Now().UnixMilli()-sendAudioEndTs) 168 | } 169 | 170 | cb(key, buffer[:n]) 171 | } 172 | }() 173 | 174 | return nil 175 | } 176 | -------------------------------------------------------------------------------- /test/py_test_audio/dec_opus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import os 6 | import wave 7 | import struct 8 | import opuslib 9 | 10 | def decode_raw_opus(opus_data, sample_rate=24000, channels=1, frame_size_ms=60): 11 | """解码原始Opus数据,返回PCM数据""" 12 | # 计算一个帧的样本数 13 | frame_size = int(sample_rate * frame_size_ms / 1000) 14 | 15 | # 创建解码器 16 | decoder = opuslib.Decoder(sample_rate, channels) 17 | 18 | # 尝试直接解码整个文件 19 | try: 20 | pcm_data = bytearray() 21 | decoded = decoder.decode(opus_data, frame_size, False) 22 | for sample in decoded: 23 | pcm_data.extend(struct.pack('") 33 | return 34 | 35 | opus_file = sys.argv[1] 36 | 37 | # 检查文件是否存在 38 | if not os.path.exists(opus_file): 39 | print(f"错误: 文件 '{opus_file}' 不存在") 40 | return 41 | 42 | # 初始化参数 43 | sample_rate = 24000 # 采样率24000Hz 44 | channels = 1 # 单声道 45 | frame_size_ms = 60 # 帧大小60ms 46 | 47 | # 读取opus文件全部内容 48 | with open(opus_file, 'rb') as f: 49 | opus_data = f.read() 50 | 51 | print(f"读取原始Opus数据: {len(opus_data)} 
字节") 52 | 53 | # 解码数据 54 | pcm_data = decode_raw_opus(opus_data, sample_rate, channels, frame_size_ms) 55 | 56 | if pcm_data is None or len(pcm_data) == 0: 57 | print("解码失败,未能生成PCM数据") 58 | return 59 | 60 | # 计算PCM数据长度(样本数) 61 | pcm_samples_count = len(pcm_data) // 2 # 每个样本2字节 62 | pcm_duration_ms = pcm_samples_count * 1000 / sample_rate 63 | 64 | print(f"解码后PCM数据大小: {len(pcm_data)} 字节") 65 | print(f"PCM样本数: {pcm_samples_count}") 66 | print(f"PCM时长: {pcm_duration_ms:.2f} 毫秒") 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /test/test_audio/audio_utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "github.com/go-audio/audio" 10 | "github.com/go-audio/wav" 11 | "gopkg.in/hraban/opus.v2" 12 | ) 13 | 14 | // WavToOpus 将WAV音频数据转换为标准Opus格式 15 | // 返回Opus帧的切片集合,每个切片是一个Opus编码帧 16 | func WavToOpus(wavData []byte, sampleRate int, channels int, bitRate int) ([][]byte, error) { 17 | // 创建WAV解码器 18 | wavReader := bytes.NewReader(wavData) 19 | wavDecoder := wav.NewDecoder(wavReader) 20 | if !wavDecoder.IsValidFile() { 21 | return nil, fmt.Errorf("无效的WAV文件") 22 | } 23 | 24 | // 读取WAV文件信息 25 | wavDecoder.ReadInfo() 26 | format := wavDecoder.Format() 27 | wavSampleRate := int(format.SampleRate) 28 | wavChannels := int(format.NumChannels) 29 | 30 | // 如果提供的参数与文件参数不一致,使用文件中的参数 31 | if sampleRate == 0 { 32 | sampleRate = wavSampleRate 33 | } 34 | if channels == 0 { 35 | channels = wavChannels 36 | } 37 | 38 | //打印wavDecoder信息 39 | fmt.Println("WAV格式:", format) 40 | 41 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 42 | if err != nil { 43 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 44 | } 45 | 46 | // 设置比特率 47 | if bitRate > 0 { 48 | if err := enc.SetBitrate(bitRate); err != nil { 49 | return nil, fmt.Errorf("设置比特率失败: %v", err) 50 | } 51 | } 52 | 53 | // 
创建输出帧切片数组 54 | opusFrames := make([][]byte, 0) 55 | 56 | perFrameDuration := 60 57 | // PCM缓冲区 - Opus帧大小(60ms) 58 | frameSize := sampleRate * perFrameDuration / 1000 59 | pcmBuffer := make([]int16, frameSize*channels) 60 | opusBuffer := make([]byte, 1000) // 足够大的缓冲区存储编码后的数据 61 | 62 | // 读取音频缓冲区 63 | audioBuf := &audio.IntBuffer{Data: make([]int, frameSize*channels), Format: format} 64 | 65 | fmt.Println("开始转换...") 66 | for { 67 | // 读取WAV数据 68 | n, err := wavDecoder.PCMBuffer(audioBuf) 69 | if err == io.EOF || n == 0 { 70 | break 71 | } 72 | if err != nil { 73 | return nil, fmt.Errorf("读取WAV数据失败: %v", err) 74 | } 75 | 76 | // 将int转换为int16 77 | for i := 0; i < len(audioBuf.Data); i++ { 78 | if i < len(pcmBuffer) { 79 | pcmBuffer[i] = int16(audioBuf.Data[i]) 80 | } 81 | } 82 | 83 | // 编码为Opus格式 84 | n, err = enc.Encode(pcmBuffer, opusBuffer) 85 | if err != nil { 86 | return nil, fmt.Errorf("编码失败: %v", err) 87 | } 88 | 89 | // 将当前帧复制到新的切片中并添加到帧数组 90 | frameData := make([]byte, n) 91 | copy(frameData, opusBuffer[:n]) 92 | opusFrames = append(opusFrames, frameData) 93 | } 94 | 95 | return opusFrames, nil 96 | } 97 | 98 | func OpusToWav(opusData [][]byte, sampleRate int, channels int, fileName string) ([][]int16, error) { 99 | opusDecoder, err := opus.NewDecoder(sampleRate, channels) 100 | if err != nil { 101 | return nil, fmt.Errorf("创建Opus解码器失败: %v", err) 102 | } 103 | 104 | wavOut, err := os.Create(fileName) 105 | if err != nil { 106 | return nil, fmt.Errorf("创建WAV文件失败: %v", err) 107 | } 108 | 109 | perFrameDuration := 60 110 | pcmBuffer := make([]int16, channels*sampleRate*perFrameDuration/1000) 111 | 112 | wavEncoder := wav.NewEncoder(wavOut, sampleRate, 16, channels, 1) 113 | wavBuffer := audio.IntBuffer{ 114 | Format: &audio.Format{ 115 | NumChannels: channels, // 使用传入的通道数 116 | SampleRate: sampleRate, 117 | }, 118 | SourceBitDepth: 16, 119 | Data: make([]int, 4096), 120 | } 121 | 122 | pcmDataList := make([][]int16, 0) 123 | 124 | for _, opusData := range 
opusData { 125 | n, err := opusDecoder.Decode(opusData, pcmBuffer) 126 | if err != nil { 127 | return nil, fmt.Errorf("解码失败: %v", err) 128 | } 129 | copyData := make([]int16, len(pcmBuffer[:n])) 130 | copy(copyData, pcmBuffer[:n]) 131 | pcmDataList = append(pcmDataList, copyData) 132 | } 133 | 134 | //fmt.Println("pcmData len: ", len(copyData)) 135 | 136 | // 将PCM数据转换为int格式 137 | for _, pcmData := range pcmDataList { 138 | for i := 0; i < len(pcmData); i++ { 139 | wavBuffer.Data = append(wavBuffer.Data, int(pcmData[i])) 140 | } 141 | } 142 | 143 | // 写入WAV文件 144 | err = wavEncoder.Write(&wavBuffer) 145 | if err != nil { 146 | return nil, fmt.Errorf("写入WAV文件失败: %v", err) 147 | } 148 | 149 | wavEncoder.Close() 150 | 151 | return pcmDataList, nil 152 | } 153 | -------------------------------------------------------------------------------- /test/test_audio/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | ) 8 | 9 | func main() { 10 | //读取文件,使用参数传入,输入输出文件 11 | inputFilePath := flag.String("input", "", "输入文件路径") 12 | outputFilePath := flag.String("output", "", "输出文件路径") 13 | sampleRate := flag.Int("sampleRate", 24000, "采样率") 14 | channels := flag.Int("channels", 1, "声道数") 15 | flag.Parse() 16 | 17 | if *inputFilePath == "" || *outputFilePath == "" { 18 | flag.Usage() 19 | return 20 | } 21 | 22 | //读取文件所有内容 23 | content, err := os.ReadFile(*inputFilePath) 24 | if err != nil { 25 | fmt.Println("读取文件失败:", err) 26 | return 27 | } 28 | 29 | fmt.Println("读取文件成功:", *inputFilePath) 30 | 31 | opusData := [][]byte{content} 32 | pcmData, err := OpusToWav(opusData, *sampleRate, *channels, *outputFilePath) 33 | if err != nil { 34 | fmt.Println("转换失败:", err) 35 | return 36 | } 37 | fmt.Println("pcmData len: ", len(pcmData[0])) 38 | 39 | fmt.Println("转换成功:", *outputFilePath) 40 | } 41 | -------------------------------------------------------------------------------- 
/test/test_opus/decode_opus.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | // 24000Hz采样率、单声道、60ms帧长度对应的PCM样本数 7 | #define SAMPLE_RATE 24000 8 | #define CHANNELS 1 9 | #define FRAME_SIZE_MS 60 10 | #define FRAME_SIZE (SAMPLE_RATE * FRAME_SIZE_MS / 1000) 11 | 12 | // 每帧最大字节数(安全值) 13 | #define MAX_PACKET_SIZE 1500 14 | 15 | int main(int argc, char *argv[]) { 16 | if (argc < 2) { 17 | printf("用法: %s [raw]\n", argv[0]); 18 | printf("参数说明:\n"); 19 | printf(" : 要解码的opus文件路径\n"); 20 | printf(" [raw]: 可选参数,指定为raw则处理无长度前缀的raw opus数据\n"); 21 | return 1; 22 | } 23 | 24 | // 检查是否为raw模式 25 | int raw_mode = 1; 26 | 27 | // 打开opus文件 28 | FILE *fp = fopen(argv[1], "rb"); 29 | if (!fp) { 30 | printf("无法打开文件: %s\n", argv[1]); 31 | return 1; 32 | } 33 | 34 | // 获取文件大小 35 | fseek(fp, 0, SEEK_END); 36 | long file_size = ftell(fp); 37 | fseek(fp, 0, SEEK_SET); 38 | 39 | // 读取整个文件内容 40 | unsigned char *opus_data = (unsigned char *)malloc(file_size); 41 | if (!opus_data) { 42 | printf("内存分配失败\n"); 43 | fclose(fp); 44 | return 1; 45 | } 46 | 47 | size_t bytes_read = fread(opus_data, 1, file_size, fp); 48 | fclose(fp); 49 | 50 | printf("读取文件成功,大小: %ld 字节\n", bytes_read); 51 | 52 | // 创建opus解码器 53 | int error; 54 | OpusDecoder *decoder = opus_decoder_create(SAMPLE_RATE, CHANNELS, &error); 55 | if (error != OPUS_OK) { 56 | printf("创建opus解码器失败: %s\n", opus_strerror(error)); 57 | free(opus_data); 58 | return 1; 59 | } 60 | 61 | printf("解码器创建成功,采样率: %d Hz, 声道数: %d\n", SAMPLE_RATE, CHANNELS); 62 | printf("理论每帧PCM样本数(60ms): %d\n", FRAME_SIZE); 63 | 64 | // 准备PCM输出缓冲区 - 理论上60ms@24000Hz应该有1440个样本点 65 | opus_int16 pcm[FRAME_SIZE * CHANNELS]; 66 | 67 | int frame_count = 0; 68 | 69 | // 尝试将整个文件当作一个opus帧解码 70 | int samples = opus_decode(decoder, opus_data, bytes_read, pcm, FRAME_SIZE, 0); 71 | 72 | if (samples < 0) { 73 | printf("解码失败: %s\n", opus_strerror(samples)); 74 | } else { 75 | frame_count++; 76 | 
printf("解码完成: opus长度 %ld 字节, 解码后PCM样本数 %d\n", bytes_read, samples); 77 | 78 | // 可以将PCM保存为文件 79 | char output_file[256]; 80 | sprintf(output_file, "%s.pcm", argv[1]); 81 | FILE *out_fp = fopen(output_file, "wb"); 82 | if (out_fp) { 83 | fwrite(pcm, sizeof(opus_int16), samples, out_fp); 84 | fclose(out_fp); 85 | printf("已保存PCM数据到 %s\n", output_file); 86 | } 87 | } 88 | 89 | printf("总共解码 %d 帧\n", frame_count); 90 | 91 | // 清理资源 92 | opus_decoder_destroy(decoder); 93 | free(opus_data); 94 | 95 | return 0; 96 | } 97 | -------------------------------------------------------------------------------- /test/websocket_client/audio_utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "os" 8 | 9 | "github.com/go-audio/audio" 10 | "github.com/go-audio/wav" 11 | "gopkg.in/hraban/opus.v2" 12 | ) 13 | 14 | // WavToOpus 将WAV音频数据转换为标准Opus格式 15 | // 返回Opus帧的切片集合,每个切片是一个Opus编码帧 16 | func WavToOpus(wavData []byte, sampleRate int, channels int, bitRate int) ([][]byte, error) { 17 | // 创建WAV解码器 18 | wavReader := bytes.NewReader(wavData) 19 | wavDecoder := wav.NewDecoder(wavReader) 20 | if !wavDecoder.IsValidFile() { 21 | return nil, fmt.Errorf("无效的WAV文件") 22 | } 23 | 24 | // 读取WAV文件信息 25 | wavDecoder.ReadInfo() 26 | format := wavDecoder.Format() 27 | wavSampleRate := int(format.SampleRate) 28 | wavChannels := int(format.NumChannels) 29 | 30 | // 如果提供的参数与文件参数不一致,使用文件中的参数 31 | if sampleRate == 0 { 32 | sampleRate = wavSampleRate 33 | } 34 | if channels == 0 { 35 | channels = wavChannels 36 | } 37 | 38 | //打印wavDecoder信息 39 | fmt.Println("WAV格式:", format) 40 | 41 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 42 | if err != nil { 43 | return nil, fmt.Errorf("创建Opus编码器失败: %v", err) 44 | } 45 | 46 | // 设置比特率 47 | if bitRate > 0 { 48 | if err := enc.SetBitrate(bitRate); err != nil { 49 | return nil, fmt.Errorf("设置比特率失败: %v", err) 50 | } 51 | } 52 | 53 | // 创建输出帧切片数组 
54 | opusFrames := make([][]byte, 0) 55 | 56 | perFrameDuration := 60 57 | // PCM缓冲区 - Opus帧大小(60ms) 58 | frameSize := sampleRate * perFrameDuration / 1000 59 | pcmBuffer := make([]int16, frameSize*channels) 60 | opusBuffer := make([]byte, 1000) // 足够大的缓冲区存储编码后的数据 61 | 62 | // 读取音频缓冲区 63 | audioBuf := &audio.IntBuffer{Data: make([]int, frameSize*channels), Format: format} 64 | 65 | fmt.Println("开始转换...") 66 | for { 67 | // 读取WAV数据 68 | n, err := wavDecoder.PCMBuffer(audioBuf) 69 | if err == io.EOF || n == 0 { 70 | break 71 | } 72 | if err != nil { 73 | return nil, fmt.Errorf("读取WAV数据失败: %v", err) 74 | } 75 | 76 | // 将int转换为int16 77 | for i := 0; i < len(audioBuf.Data); i++ { 78 | if i < len(pcmBuffer) { 79 | pcmBuffer[i] = int16(audioBuf.Data[i]) 80 | } 81 | } 82 | 83 | // 编码为Opus格式 84 | n, err = enc.Encode(pcmBuffer, opusBuffer) 85 | if err != nil { 86 | return nil, fmt.Errorf("编码失败: %v", err) 87 | } 88 | 89 | // 将当前帧复制到新的切片中并添加到帧数组 90 | frameData := make([]byte, n) 91 | copy(frameData, opusBuffer[:n]) 92 | opusFrames = append(opusFrames, frameData) 93 | } 94 | 95 | return opusFrames, nil 96 | } 97 | 98 | func OpusToWav(opusData [][]byte, sampleRate int, channels int, fileName string) ([][]int16, error) { 99 | opusDecoder, err := opus.NewDecoder(sampleRate, channels) 100 | if err != nil { 101 | return nil, fmt.Errorf("创建Opus解码器失败: %v", err) 102 | } 103 | 104 | wavOut, err := os.Create(fileName) 105 | if err != nil { 106 | return nil, fmt.Errorf("创建WAV文件失败: %v", err) 107 | } 108 | 109 | pcmDataList := make([][]int16, 0) 110 | pcmBuffer := make([]int16, 8192) 111 | 112 | wavEncoder := wav.NewEncoder(wavOut, sampleRate, 16, channels, 1) 113 | wavBuffer := audio.IntBuffer{ 114 | Format: &audio.Format{ 115 | NumChannels: channels, // 使用传入的通道数 116 | SampleRate: sampleRate, 117 | }, 118 | SourceBitDepth: 16, 119 | Data: make([]int, 8192), 120 | } 121 | 122 | for _, frame := range opusData { 123 | n, err := opusDecoder.Decode(frame, pcmBuffer) 124 | if err != nil { 125 | 
return nil, fmt.Errorf("解码失败: %v", err) 126 | } 127 | copyData := make([]int16, len(pcmBuffer[:n])) 128 | copy(copyData, pcmBuffer[:n]) 129 | //fmt.Println("decode pcmData len: ", len(copyData)) 130 | pcmDataList = append(pcmDataList, copyData) 131 | 132 | //fmt.Println("pcmData len: ", len(copyData)) 133 | 134 | // 将PCM数据转换为int格式 135 | for i := 0; i < len(copyData); i++ { 136 | wavBuffer.Data = append(wavBuffer.Data, int(copyData[i])) 137 | } 138 | } 139 | 140 | // 写入WAV文件 141 | err = wavEncoder.Write(&wavBuffer) 142 | if err != nil { 143 | return nil, fmt.Errorf("写入WAV文件失败: %v", err) 144 | } 145 | 146 | wavEncoder.Close() 147 | 148 | return pcmDataList, nil 149 | } 150 | -------------------------------------------------------------------------------- /test/websocket_client/test.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackers365/xiaozhi-esp32-server-golang/f00bb50d418dcec9418f37c2911b332323d8afed/test/websocket_client/test.mp3 -------------------------------------------------------------------------------- /test/websocket_client/test_mp3.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "os" 8 | "time" 9 | 10 | "github.com/gopxl/beep/mp3" 11 | "github.com/gopxl/beep/wav" 12 | "gopkg.in/hraban/opus.v2" 13 | ) 14 | 15 | func main1() { 16 | // HTTP接口URL 17 | mp3URL := "http://home.hackers365.com:55555/apk/test.mp3" 18 | // 指定输出的PCM文件路径 19 | pcmFilePath := "output.pcm" 20 | 21 | // 创建PCM文件 22 | pcmFile, err := os.Create(pcmFilePath) 23 | if err != nil { 24 | fmt.Printf("无法创建PCM文件: %v\n", err) 25 | return 26 | } 27 | defer pcmFile.Close() 28 | 29 | // 从HTTP接口获取MP3数据并处理 30 | err = processMP3FromHTTP(mp3URL, pcmFile) 31 | if err != nil { 32 | fmt.Printf("处理HTTP MP3数据失败: %v\n", err) 33 | return 34 | } 35 | 36 | fmt.Printf("HTTP MP3数据已成功解码为PCM格式,保存至: %s\n", pcmFilePath) 37 | 38 | // 
导出WAV格式 39 | exportHTTPToWav(mp3URL, "output.wav") 40 | } 41 | 42 | type readCloserWrapper struct { 43 | io.Reader 44 | } 45 | 46 | func (r readCloserWrapper) Close() error { 47 | return nil 48 | } 49 | 50 | // 从HTTP接口获取并处理MP3数据 51 | func processMP3FromHTTP(url string, pcmFile *os.File) error { 52 | // 发起HTTP请求 53 | resp, err := http.Get(url) 54 | if err != nil { 55 | return fmt.Errorf("HTTP请求失败: %v", err) 56 | } 57 | defer resp.Body.Close() 58 | 59 | // 检查HTTP响应状态 60 | if resp.StatusCode != http.StatusOK { 61 | return fmt.Errorf("HTTP请求返回非200状态码: %d", resp.StatusCode) 62 | } 63 | 64 | // 创建一个pipe用于处理数据流 65 | pipeReader, pipeWriter := io.Pipe() 66 | defer pipeReader.Close() 67 | 68 | // 创建一个读取缓冲区和采样缓冲区 69 | bufferSize := 10 * 1024 // 10KB 70 | buffer := make([]byte, bufferSize) // HTTP 读取缓冲区 71 | 72 | opusBuffer := make([]byte, 1000) // Opus 编码输出缓冲区 73 | 74 | // 创建一个错误通道和完成通道 75 | errChan := make(chan error, 1) 76 | doneChan := make(chan struct{}, 1) 77 | 78 | // 启动goroutine解码MP3和处理PCM 79 | go func() { 80 | // 尝试初始化解码器 81 | streamer, format, err := mp3.Decode(pipeReader) 82 | if err != nil { 83 | errChan <- fmt.Errorf("MP3解码器初始化失败: %v", err) 84 | return 85 | } 86 | defer streamer.Close() 87 | 88 | fmt.Printf("MP3解码器初始化成功,采样率: %d Hz, 声道数: %d\n", 89 | format.SampleRate, format.NumChannels) 90 | 91 | //原mp3格式信息 92 | sampleRate := int(format.SampleRate) 93 | channels := int(format.NumChannels) 94 | 95 | // PCM缓冲区 及 Opus帧大小(例如60ms) 96 | perFrameDuration := 60 // 毫秒 97 | frameSize := sampleRate * perFrameDuration / 1000 98 | pcmBuffer := make([]int16, frameSize*channels) 99 | opusFrames := make([][]byte, 0) // 存储编码后的Opus帧 100 | 101 | enc, err := opus.NewEncoder(sampleRate, channels, opus.AppAudio) 102 | if err != nil { 103 | fmt.Printf("创建Opus编码器失败: %v\n", err) 104 | errChan <- fmt.Errorf("创建Opus编码器失败: %v", err) 105 | return 106 | } 107 | 108 | beepSampleBuf := make([][2]float64, 1024) // Beep 解码缓冲区 109 | // 处理解码后的音频流 110 | currentFramePos := 0 // 当前填充到pcmBuffer的位置 111 | 
for { 112 | // 从流中读取采样到sampleBuf 113 | numSamplesRead, ok := streamer.Stream(beepSampleBuf) 114 | if !ok { 115 | // 处理剩余不足一帧的数据 116 | if currentFramePos > 0 { 117 | // 创建一个完整的帧缓冲区,用0填充剩余部分 118 | paddedFrame := make([]int16, len(pcmBuffer)) 119 | copy(paddedFrame, pcmBuffer[:currentFramePos]) // 将有效数据复制到开头,剩余部分默认为0 120 | 121 | // 编码补齐后的完整帧 122 | n, err := enc.Encode(paddedFrame, opusBuffer) 123 | if err != nil { 124 | fmt.Printf("编码剩余数据失败: %v\n", err) 125 | // 可能需要通过 errChan 发送错误 126 | } else { 127 | frameData := make([]byte, n) 128 | copy(frameData, opusBuffer[:n]) 129 | opusFrames = append(opusFrames, frameData) 130 | // 注意:这里编码的是一个完整的帧,即使原始数据不足 131 | fmt.Printf("已编码最后补齐的 %d 个PCM样本 (原始 %d)\n", len(paddedFrame), currentFramePos) 132 | } 133 | } 134 | // 解码完成 135 | doneChan <- struct{}{} 136 | return 137 | } 138 | 139 | // 将读取到的float64样本转换为int16并填充到pcmBuffer 140 | for i := 0; i < numSamplesRead; i++ { 141 | // 直接进行转换 142 | leftSample := int16(beepSampleBuf[i][0] * 32767.0) 143 | rightSample := int16(beepSampleBuf[i][1] * 32767.0) 144 | 145 | // 写入PCM数据 146 | pcmBuffer[currentFramePos] = leftSample 147 | if channels > 1 { 148 | pcmBuffer[currentFramePos+1] = rightSample 149 | } 150 | currentFramePos += channels 151 | 152 | // 如果pcmBuffer已满一帧,则进行编码 153 | if currentFramePos == len(pcmBuffer) { 154 | n, err := enc.Encode(pcmBuffer, opusBuffer) 155 | if err != nil { 156 | fmt.Printf("编码失败: %v\n", err) 157 | errChan <- fmt.Errorf("编码失败: %v", err) 158 | return 159 | } 160 | 161 | // 将当前帧复制到新的切片中并添加到帧数组 162 | frameData := make([]byte, n) 163 | copy(frameData, opusBuffer[:n]) 164 | opusFrames = append(opusFrames, frameData) 165 | 166 | fmt.Printf("已编码一帧 (%d PCM样本)\n", len(pcmBuffer)) 167 | currentFramePos = 0 // 重置帧位置 168 | } 169 | } 170 | } 171 | }() 172 | 173 | // 创建定时器,每100ms发送一次数据 174 | ticker := time.NewTicker(100 * time.Millisecond) 175 | defer ticker.Stop() 176 | 177 | // 开始循环读取HTTP数据并写入pipe 178 | for { 179 | select { 180 | case <-ticker.C: 181 | // 从HTTP响应读取数据 182 | 
n, err := resp.Body.Read(buffer) 183 | 184 | // 如果读取到数据,写入pipe 185 | if n > 0 { 186 | _, writeErr := pipeWriter.Write(buffer[:n]) 187 | if writeErr != nil { 188 | return fmt.Errorf("写入pipe失败: %v", writeErr) 189 | } 190 | fmt.Printf("已读取并写入 %d 字节MP3数据\n", n) 191 | } 192 | 193 | // 处理EOF或错误 194 | if err != nil { 195 | if err == io.EOF { 196 | fmt.Println("HTTP数据流已读取完毕") 197 | pipeWriter.Close() // 关闭pipe写入端 198 | 199 | // 等待解码完成或出错 200 | select { 201 | case <-doneChan: 202 | return nil 203 | case err := <-errChan: 204 | return err 205 | } 206 | } else { 207 | return fmt.Errorf("读取HTTP数据出错: %v", err) 208 | } 209 | } 210 | 211 | case err := <-errChan: 212 | return err 213 | 214 | case <-doneChan: 215 | return nil 216 | } 217 | } 218 | } 219 | 220 | func exportHTTPToWav(url string, wavFilePath string) { 221 | // 发起HTTP请求 222 | resp, err := http.Get(url) 223 | if err != nil { 224 | fmt.Printf("HTTP请求失败: %v\n", err) 225 | return 226 | } 227 | defer resp.Body.Close() 228 | 229 | // 检查HTTP响应状态 230 | if resp.StatusCode != http.StatusOK { 231 | fmt.Printf("HTTP请求返回非200状态码: %d\n", resp.StatusCode) 232 | return 233 | } 234 | 235 | // 解码MP3 236 | streamer, format, err := mp3.Decode(resp.Body) 237 | if err != nil { 238 | fmt.Printf("无法解码MP3数据: %v\n", err) 239 | return 240 | } 241 | defer streamer.Close() 242 | 243 | // 创建WAV文件 244 | wavFile, err := os.Create(wavFilePath) 245 | if err != nil { 246 | fmt.Printf("无法创建WAV文件: %v\n", err) 247 | return 248 | } 249 | defer wavFile.Close() 250 | 251 | // 使用beep/wav包将流编码为WAV 252 | err = wav.Encode(wavFile, streamer, format) 253 | if err != nil { 254 | fmt.Printf("WAV编码失败: %v\n", err) 255 | return 256 | } 257 | 258 | fmt.Printf("已从HTTP导出WAV文件: %s\n", wavFilePath) 259 | } 260 | --------------------------------------------------------------------------------