├── .DS_Store ├── multispace ├── utils │ ├── __init__.py │ └── logger.py ├── decoder │ ├── __init__.py │ └── decoder_module.py ├── fusion │ ├── __init__.py │ └── fusion_module.py ├── text │ ├── __init__.py │ ├── text_encoder.py │ └── deepseek_text_encoder.py ├── audio │ ├── __init__.py │ ├── audio_encoder.py │ └── whisper_audio_encoder.py ├── image │ ├── __init__.py │ ├── image_encoder.py │ └── diffusion_image_encoder.py ├── __init__.py ├── config │ ├── __init__.py │ └── model_config.py └── multispace.py ├── requirements.txt ├── langchain_integration ├── __init__.py ├── multispace_chain.py └── retrieval_chain.py ├── examples ├── README.md ├── multimodal_example.py ├── basic_example.py ├── audio_processing_example.py ├── config_example.py └── langchain_example.py ├── multimodal_llm_design_doc.md ├── docs ├── langchain_integration.md └── architecture │ └── system_architecture.md └── README.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li-neo/MultiSpaceAI/HEAD/.DS_Store -------------------------------------------------------------------------------- /multispace/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具模块 3 | ======= 4 | 5 | 该模块包含各种工具函数和类。 6 | """ 7 | 8 | from .logger import setup_logger 9 | 10 | __all__ = [ 11 | 'setup_logger' 12 | ] -------------------------------------------------------------------------------- /multispace/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 解码器模块 3 | ========= 4 | 5 | 该模块包含用于生成输出的解码器类。 6 | """ 7 | 8 | from .decoder_module import DecoderModule 9 | 10 | __all__ = [ 11 | 'DecoderModule' 12 | ] -------------------------------------------------------------------------------- /multispace/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 多模态融合模块 3 | =========== 4 | 5 | 该模块包含用于融合多种模态特征的类。 6 | """ 7 | 8 | from .fusion_module import MultimodalFusionModule 9 | 10 | __all__ = [ 11 | 'MultimodalFusionModule' 12 | ] -------------------------------------------------------------------------------- /multispace/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 文本处理模块 3 | ========= 4 | 5 | 该模块包含用于处理文本数据的编码器类。 6 | """ 7 | 8 | from .text_encoder import TextEncoder 9 | from .deepseek_text_encoder import DeepSeekTextEncoder 10 | 11 | __all__ = [ 12 | 'TextEncoder', 13 | 'DeepSeekTextEncoder' 14 | ] -------------------------------------------------------------------------------- /multispace/audio/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 语音处理模块 3 | ========= 4 | 5 | 该模块包含用于处理语音数据的编码器类。 6 | """ 7 | 8 | from .audio_encoder import AudioEncoder 9 | from .whisper_audio_encoder import WhisperAudioEncoder 10 | 11 | __all__ = [ 12 | 'AudioEncoder', 13 | 'WhisperAudioEncoder' 14 | ] -------------------------------------------------------------------------------- /multispace/image/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 图像处理模块 3 | ========= 4 | 5 | 该模块包含用于处理图像数据的编码器类。 6 | """ 7 | 8 | from .image_encoder import ImageEncoder 9 | from .diffusion_image_encoder import DiffusionImageEncoder 10 | 11 | __all__ = [ 12 | 'ImageEncoder', 13 | 'DiffusionImageEncoder' 14 | ] -------------------------------------------------------------------------------- /multispace/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MultiSpaceAI - 多模态大语言模型 3 | ================================== 4 | 5 | 这个包实现了一个能够处理文本、图像和语音输入的多模态大语言模型系统。 6 | 7 | 主要功能: 8 | - 多模态数据处理和编码 9 | - 跨模态特征融合 10 | - 基于融合特征的生成任务 11 | """ 12 | 13 | __version__ = '0.1.0' 14 | 15 | from .multispace import MultiSpaceAI 16 | 17 | __all__ = ['MultiSpaceAI'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.10.0 2 | transformers>=4.21.0 3 | numpy>=1.20.0 4 | pillow>=8.3.0 5 | librosa>=0.9.0 6 | soundfile>=0.10.0 7 | requests>=2.25.0 8 | tqdm>=4.62.0 9 | sentencepiece>=0.1.96 10 | huggingface-hub>=0.9.0 11 | tokenizers>=0.12.0 12 | langchain>=0.0.267 13 | langchain-openai>=0.0.2 14 | faiss-cpu>=1.7.4 15 | chromadb>=0.4.15 16 | openai>=1.3.0 17 | tiktoken>=0.5.1 -------------------------------------------------------------------------------- /multispace/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置模块 3 | ======= 4 | 5 | 该模块包含模型配置相关的类和函数。 6 | """ 7 | 8 | from .model_config import ModelConfig, TextEncoderConfig, ImageEncoderConfig, AudioEncoderConfig, FusionConfig, DecoderConfig 9 | 10 | __all__ = [ 11 | 'ModelConfig', 12 | 'TextEncoderConfig', 13 | 'ImageEncoderConfig', 14 | 'AudioEncoderConfig', 15 | 'FusionConfig', 16 | 'DecoderConfig' 17 | ] -------------------------------------------------------------------------------- /langchain_integration/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | MultiSpaceAI LangChain 集成 5 | ========================= 6 | 7 | 提供与LangChain的集成功能,使MultiSpaceAI能够融入LangChain生态系统。 8 | """ 9 | 10 | from src.langchain_integration.multispace_chain import ( 11 | MultiSpaceAITool, 12 | MultiSpaceAIChain, 13 | MultiSpaceAIAgent 14 | ) 15 | from src.langchain_integration.retrieval_chain import ( 16 | MultiSpaceAIEmbeddings, 17 | MultiModalDocument, 18 | MultiSpaceAIRetrievalChain 19 | ) 20 | 21 | __all__ = [ 22 | 'MultiSpaceAITool', 23 | 'MultiSpaceAIChain', 24 | 'MultiSpaceAIAgent', 25 | 'MultiSpaceAIEmbeddings', 26 | 'MultiModalDocument', 27 | 'MultiSpaceAIRetrievalChain' 28 | ] -------------------------------------------------------------------------------- /multispace/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | 日志工具模块 3 | ========= 4 | 5 | 该模块包含用于设置和管理日志的函数。 6 | """ 7 | 8 | import logging 9 | import os 10 | import sys 11 | from typing import Optional 12 | 13 | 14 | def setup_logger(name: str, level: int = logging.INFO, log_file: Optional[str] = None) -> logging.Logger: 15 | """ 16 | 设置日志记录器 17 | 18 | 参数: 19 | name: 日志记录器名称 20 | level: 日志级别 21 | log_file: 日志文件路径,如果为None则仅输出到控制台 22 | 23 | 返回: 24 | 配置好的日志记录器 25 | """ 26 | # 获取日志记录器 27 | logger = logging.getLogger(name) 28 | 29 | # 如果日志记录器已经有处理器,说明已经配置过,直接返回 30 | if logger.handlers: 31 | return logger 32 | 33 | # 设置日志级别 34 | logger.setLevel(level) 35 | 36 | # 创建格式化器 37 | formatter = logging.Formatter( 38 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s', 39 | datefmt='%Y-%m-%d %H:%M:%S' 40 | ) 41 | 42 | # 创建控制台处理器 43 | console_handler = logging.StreamHandler(sys.stdout) 44 | console_handler.setFormatter(formatter) 45 | logger.addHandler(console_handler) 46 | 47 | # 如果提供了日志文件路径,创建文件处理器 48 | if log_file: 49 | # 确保日志目录存在 50 | os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True) 51 | 52 | # 创建文件处理器 53 | file_handler = logging.FileHandler(log_file, encoding='utf-8') 54 | file_handler.setFormatter(formatter) 55 | logger.addHandler(file_handler) 56 | 57 | # 阻止日志传播到父记录器 58 | logger.propagate = False 59 | 60 | return logger -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # MultiSpaceAI 示例 2 | 3 | 本目录包含 MultiSpaceAI 的各种使用示例,帮助您快速上手多模态大语言模型。 4 | 5 | ## 示例列表 6 | 7 | 1. **基本示例** (`basic_example.py`) 8 | - 展示 MultiSpaceAI 的基本使用方法 9 | - 支持文本、图像和音频输入 10 | 11 | 2. **多模态处理示例** (`multimodal_example.py`) 12 | - 展示如何同时处理图像和文本输入 13 | - 可视化展示处理结果 14 | 15 | 3. **音频处理示例** (`audio_processing_example.py`) 16 | - 展示如何处理音频输入 17 | - 包含音频可视化功能 18 | 19 | 4. **配置示例** (`config_example.py`) 20 | - 展示如何自定义和保存模型配置 21 | - 包含默认配置和自定义配置示例 22 | 23 | ## 使用方法 24 | 25 | ### 基本示例 26 | 27 | ```bash 28 | # 使用文本输入 29 | python examples/basic_example.py --text "这是一段测试文本" 30 | 31 | # 使用图像输入 32 | python examples/basic_example.py --image path/to/image.jpg 33 | 34 | # 使用音频输入 35 | python examples/basic_example.py --audio path/to/audio.mp3 36 | 37 | # 同时使用多种模态 38 | python examples/basic_example.py --text "描述这个图像" --image path/to/image.jpg 39 | 40 | # 指定输出文件 41 | python examples/basic_example.py --text "测试文本" --output results.json 42 | 43 | # 使用自定义配置文件 44 | python examples/basic_example.py --text "测试文本" --config custom_config.json 45 | ``` 46 | 47 | ### 多模态处理示例 48 | 49 | ```bash 50 | # 处理图像和文本 51 | python examples/multimodal_example.py --image path/to/image.jpg --text "描述这个图像" 52 | 53 | # 显示输入图像 54 | python examples/multimodal_example.py --image path/to/image.jpg --display 55 | ``` 56 | 57 | ### 音频处理示例 58 | 59 | ```bash 60 | # 处理音频 61 | python examples/audio_processing_example.py --audio path/to/audio.mp3 62 | 63 | # 可视化音频并处理 64 | python examples/audio_processing_example.py --audio path/to/audio.mp3 --visualize 65 | 66 | # 尝试播放音频(在支持的环境中) 67 | python examples/audio_processing_example.py --audio path/to/audio.mp3 --play 68 | ``` 69 | 70 | ### 配置示例 71 | 72 | ```bash 73 | # 创建默认配置 74 | python examples/config_example.py --type default --output default_config.json 75 | 76 | # 创建自定义配置 77 | python examples/config_example.py --type custom --output custom_config.json 78 | ``` 79 | 80 | ## 环境变量 81 | 82 | 某些示例可能需要设置环境变量来使用API服务: 83 | 84 | - `DEEPSEEK_API_KEY`: 用于DeepSeek文本编码器 85 | - `STABLE_DIFFUSION_API_KEY`: 用于Diffusion图像编码器 86 | - `OPENAI_API_KEY`: 用于Whisper语音编码器 87 | 88 | ```bash 89 | # 设置环境变量示例(Linux/macOS) 90 | export OPENAI_API_KEY=your_api_key_here 91 | 92 | # 设置环境变量示例(Windows) 93 | set OPENAI_API_KEY=your_api_key_here 94 | ``` 95 | 96 | ## 额外资源 97 | 98 | 查看项目根目录的 README.md 获取更多信息。 -------------------------------------------------------------------------------- /examples/multimodal_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 多模态处理示例 5 | =========== 6 | 7 | 展示如何使用MultiSpaceAI同时处理图像和文本输入。 8 | """ 9 | 10 | import os 11 | import sys 12 | import torch 13 | import argparse 14 | from PIL import Image 15 | import matplotlib.pyplot as plt 16 | 17 | # 将项目根目录添加到Python路径 18 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 19 | 20 | from src.multispace import MultiSpaceAI 21 | 22 | 23 | def display_image(image_path): 24 | """显示图像""" 25 | img = Image.open(image_path) 26 | plt.figure(figsize=(10, 6)) 27 | plt.imshow(img) 28 | plt.axis('off') 29 | plt.title("输入图像") 30 | plt.show() 31 | 32 | 33 | def main(): 34 | """主函数""" 35 | # 设置命令行参数 36 | parser = argparse.ArgumentParser(description="MultiSpaceAI多模态处理示例") 37 | parser.add_argument("--image", type=str, required=True, help="图像文件路径") 38 | parser.add_argument("--text", type=str, default="描述这个图像", help="提示文本") 39 | parser.add_argument("--config", type=str, help="配置文件路径") 40 | parser.add_argument("--device", type=str, default=None, help="运行设备") 41 | parser.add_argument("--output", type=str, help="输出文件路径") 42 | parser.add_argument("--display", action="store_true", help="显示输入图像") 43 | 44 | # 解析参数 45 | args = parser.parse_args() 46 | 47 | # 检查图像文件是否存在 48 | if not os.path.exists(args.image): 49 | print(f"错误: 图像文件不存在: {args.image}") 50 | sys.exit(1) 51 | 52 | # 显示图像 53 | if args.display: 54 | display_image(args.image) 55 | 56 | print(f"处理图像: {args.image}") 57 | print(f"提示文本: {args.text}") 58 | 59 | # 初始化模型 60 | print("\n初始化MultiSpaceAI模型...") 61 | model = MultiSpaceAI( 62 | config_path=args.config, 63 | device=args.device 64 | ) 65 | 66 | # 处理输入 67 | print("处理输入中...") 68 | result = model.process( 69 | text=args.text, 70 | image=args.image 71 | ) 72 | 73 | # 打印结果 74 | print("\n生成的文本:") 75 | print("-" * 50) 76 | print(result["generated_text"]) 77 | print("-" * 50) 78 | 79 | # 如果提供了输出文件路径,将结果保存到文件 80 | if args.output: 81 | import json 82 | with open(args.output, 'w', encoding='utf-8') as f: 83 | # 将Tensor转换为列表 84 | result_json = {k: v for k, v in result.items()} 85 | json.dump(result_json, f, ensure_ascii=False, indent=2) 86 | print(f"\n结果已保存到: {args.output}") 87 | 88 | print("\n处理完成!") 89 | 90 | 91 | if __name__ == "__main__": 92 | main() -------------------------------------------------------------------------------- /examples/basic_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 基本示例 5 | ======= 6 | 7 | 使用MultiSpaceAI进行多模态处理的基本示例。 8 | """ 9 | 10 | import os 11 | import sys 12 | import argparse 13 | from typing import Optional 14 | 15 | # 将项目根目录添加到Python路径 16 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 17 | 18 | from src.multispace import MultiSpaceAI 19 | 20 | 21 | def parse_args(): 22 | """解析命令行参数""" 23 | parser = argparse.ArgumentParser(description="MultiSpaceAI基本示例") 24 | 25 | # 模型配置 26 | parser.add_argument("--config", type=str, help="配置文件路径") 27 | parser.add_argument("--device", type=str, default=None, help="运行设备") 28 | 29 | # 文本编码器配置 30 | parser.add_argument("--text-encoder", type=str, default="custom", 31 | choices=["custom", "deepseek-api"], help="文本编码器类型") 32 | 33 | # 图像编码器配置 34 | parser.add_argument("--image-encoder", type=str, default="custom", 35 | choices=["custom", "diffusion-api"], help="图像编码器类型") 36 | 37 | # 语音编码器配置 38 | parser.add_argument("--audio-encoder", type=str, default="custom", 39 | choices=["custom", "whisper-api"], help="语音编码器类型") 40 | 41 | # 输入数据 42 | parser.add_argument("--text", type=str, help="文本输入") 43 | parser.add_argument("--image", type=str, help="图像文件路径") 44 | parser.add_argument("--audio", type=str, help="音频文件路径") 45 | 46 | # 生成配置 47 | parser.add_argument("--max-length", type=int, default=50, help="生成的最大长度") 48 | parser.add_argument("--num-beams", type=int, default=4, help="束搜索的束数") 49 | 50 | # 输出配置 51 | parser.add_argument("--output", type=str, help="输出文件路径") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def main(): 57 | """主函数""" 58 | # 解析命令行参数 59 | args = parse_args() 60 | 61 | # 检查是否至少有一种模态的输入 62 | if not any([args.text, args.image, args.audio]): 63 | print("错误: 至少需要提供一种模态的输入(文本、图像或音频)") 64 | sys.exit(1) 65 | 66 | # 初始化模型 67 | model = MultiSpaceAI( 68 | config_path=args.config, 69 | text_encoder=args.text_encoder, 70 | image_encoder=args.image_encoder, 71 | audio_encoder=args.audio_encoder, 72 | device=args.device 73 | ) 74 | 75 | # 处理输入 76 | result = model.process( 77 | text=args.text, 78 | image=args.image, 79 | audio=args.audio, 80 | max_length=args.max_length, 81 | num_beams=args.num_beams 82 | ) 83 | 84 | # 打印结果 85 | print("\n生成的文本:") 86 | print(result["generated_text"]) 87 | 88 | # 如果提供了输出文件路径,将结果保存到文件 89 | if args.output: 90 | import json 91 | with open(args.output, 'w', encoding='utf-8') as f: 92 | json.dump(result, f, ensure_ascii=False, indent=2) 93 | print(f"\n结果已保存到: {args.output}") 94 | 95 | 96 | if __name__ == "__main__": 97 | main() -------------------------------------------------------------------------------- /multispace/text/text_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | 文本编码器模块 3 | =========== 4 | 5 | 该模块包含用于处理和编码文本输入的类。 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from typing import Optional, Dict, Any, Union 11 | from transformers import AutoTokenizer, AutoModel 12 | 13 | from ..config import TextEncoderConfig 14 | from ..utils.logger import setup_logger 15 | 16 | # 设置日志 17 | logger = setup_logger(__name__) 18 | 19 | 20 | class TextEncoder(nn.Module): 21 | """ 22 | 文本编码器类,用于处理和编码文本输入。 23 | 基于Hugging Face的Transformer模型。 24 | """ 25 | 26 | def __init__(self, config: TextEncoderConfig, device: Optional[str] = None): 27 | """ 28 | 初始化文本编码器 29 | 30 | 参数: 31 | config: 文本编码器配置 32 | device: 运行设备 33 | """ 34 | super().__init__() 35 | self.config = config 36 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | # 初始化分词器和模型 39 | logger.info(f"加载文本编码器预训练模型: {config.model_name}") 40 | 41 | # 加载分词器 42 | self.tokenizer = AutoTokenizer.from_pretrained(config.model_name) 43 | 44 | # 加载预训练模型 45 | self.model = AutoModel.from_pretrained(config.model_name) 46 | 47 | # 输出维度适配层 48 | if self.model.config.hidden_size != config.hidden_dim: 49 | self.output_adapter = nn.Linear(self.model.config.hidden_size, config.hidden_dim) 50 | else: 51 | self.output_adapter = nn.Identity() 52 | 53 | # 移动模型到指定设备 54 | self.to(self.device) 55 | 56 | logger.info("文本编码器初始化完成") 57 | 58 | def preprocess(self, text: str) -> Dict[str, torch.Tensor]: 59 | """ 60 | 对文本进行预处理 61 | 62 | 参数: 63 | text: 输入文本 64 | 65 | 返回: 66 | 预处理后的输入,包含输入ID和注意力掩码 67 | """ 68 | # 使用分词器处理文本 69 | inputs = self.tokenizer( 70 | text, 71 | max_length=self.config.max_seq_length, 72 | padding="max_length", 73 | truncation=True, 74 | return_tensors="pt" 75 | ) 76 | 77 | # 移动到指定设备 78 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 79 | 80 | return inputs 81 | 82 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: 83 | """ 84 | 前向传播 85 | 86 | 参数: 87 | input_ids: 输入ID 88 | attention_mask: 注意力掩码 89 | 90 | 返回: 91 | 文本特征表示 92 | """ 93 | # 获取模型输出 94 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 95 | 96 | # 获取最后一层的隐藏状态 97 | last_hidden_state = outputs.last_hidden_state 98 | 99 | # 使用[CLS]标记的表示作为整个序列的表示 100 | # 对应位置为 last_hidden_state 的第一个标记 101 | pooled_output = last_hidden_state[:, 0, :] 102 | 103 | # 应用输出适配层 104 | features = self.output_adapter(pooled_output) 105 | 106 | return features 107 | 108 | def encode(self, text: Union[str, Dict[str, torch.Tensor]]) -> torch.Tensor: 109 | """ 110 | 编码文本 111 | 112 | 参数: 113 | text: 输入文本或已预处理的输入 114 | 115 | 返回: 116 | 文本特征表示 117 | """ 118 | self.eval() # 设置为评估模式 119 | 120 | with torch.no_grad(): 121 | # 如果输入是文本字符串,进行预处理 122 | if isinstance(text, str): 123 | inputs = self.preprocess(text) 124 | else: 125 | inputs = text 126 | 127 | # 前向传播 128 | features = self.forward( 129 | input_ids=inputs["input_ids"], 130 | attention_mask=inputs["attention_mask"] 131 | ) 132 | 133 | return features -------------------------------------------------------------------------------- /multispace/image/image_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | 图像编码器模块 3 | =========== 4 | 5 | 该模块包含用于处理和编码图像输入的类。 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | from typing import Optional, Dict, Any, Union 12 | from PIL import Image 13 | import os 14 | from transformers import ViTFeatureExtractor, ViTModel 15 | 16 | from ..config import ImageEncoderConfig 17 | from ..utils.logger import setup_logger 18 | 19 | # 设置日志 20 | logger = setup_logger(__name__) 21 | 22 | 23 | class ImageEncoder(nn.Module): 24 | """ 25 | 图像编码器类,用于处理和编码图像输入。 26 | 基于Hugging Face的Vision Transformer模型。 27 | """ 28 | 29 | def __init__(self, config: ImageEncoderConfig, device: Optional[str] = None): 30 | """ 31 | 初始化图像编码器 32 | 33 | 参数: 34 | config: 图像编码器配置 35 | device: 运行设备 36 | """ 37 | super().__init__() 38 | self.config = config 39 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | # 初始化特征提取器和模型 42 | logger.info(f"加载图像编码器预训练模型: {config.model_name}") 43 | 44 | # 加载特征提取器 45 | self.feature_extractor = ViTFeatureExtractor.from_pretrained(config.model_name) 46 | 47 | # 加载预训练模型 48 | self.model = ViTModel.from_pretrained(config.model_name) 49 | 50 | # 输出维度适配层 51 | if self.model.config.hidden_size != config.hidden_dim: 52 | self.output_adapter = nn.Linear(self.model.config.hidden_size, config.hidden_dim) 53 | else: 54 | self.output_adapter = nn.Identity() 55 | 56 | # 移动模型到指定设备 57 | self.to(self.device) 58 | 59 | logger.info("图像编码器初始化完成") 60 | 61 | def preprocess(self, image: Union[str, np.ndarray, Image.Image]) -> Dict[str, torch.Tensor]: 62 | """ 63 | 对图像进行预处理 64 | 65 | 参数: 66 | image: 输入图像,可以是图像路径、numpy数组或PIL图像 67 | 68 | 返回: 69 | 预处理后的输入 70 | """ 71 | # 如果是路径,加载图像 72 | if isinstance(image, str): 73 | if not os.path.exists(image): 74 | raise FileNotFoundError(f"图像文件不存在: {image}") 75 | image = Image.open(image).convert("RGB") 76 | 77 | # 如果是numpy数组,转换为PIL图像 78 | elif isinstance(image, np.ndarray): 79 | image = Image.fromarray(np.uint8(image)) 80 | 81 | # 使用特征提取器处理图像 82 | inputs = self.feature_extractor( 83 | images=image, 84 | return_tensors="pt" 85 | ) 86 | 87 | # 移动到指定设备 88 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 89 | 90 | return inputs 91 | 92 | def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: 93 | """ 94 | 前向传播 95 | 96 | 参数: 97 | pixel_values: 像素值 98 | 99 | 返回: 100 | 图像特征表示 101 | """ 102 | # 获取模型输出 103 | outputs = self.model(pixel_values=pixel_values) 104 | 105 | # 获取最后一层的隐藏状态 106 | last_hidden_state = outputs.last_hidden_state 107 | 108 | # 使用[CLS]标记的表示作为整个图像的表示 109 | # 对应位置为 last_hidden_state 的第一个标记 110 | pooled_output = last_hidden_state[:, 0, :] 111 | 112 | # 应用输出适配层 113 | features = self.output_adapter(pooled_output) 114 | 115 | return features 116 | 117 | def encode(self, image: Union[str, np.ndarray, Image.Image]) -> torch.Tensor: 118 | """ 119 | 编码图像 120 | 121 | 参数: 122 | image: 输入图像,可以是图像路径、numpy数组或PIL图像 123 | 124 | 返回: 125 | 图像特征表示 126 | """ 127 | self.eval() # 设置为评估模式 128 | 129 | with torch.no_grad(): 130 | # 预处理图像 131 | inputs = self.preprocess(image) 132 | 133 | # 前向传播 134 | features = self.forward(pixel_values=inputs["pixel_values"]) 135 | 136 | return features -------------------------------------------------------------------------------- /examples/audio_processing_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 音频处理示例 5 | ========== 6 | 7 | 展示如何使用MultiSpaceAI处理音频输入。 8 | """ 9 | 10 | import os 11 | import sys 12 | import torch 13 | import argparse 14 | import librosa 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import soundfile as sf 18 | 19 | # 将项目根目录添加到Python路径 20 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 21 | 22 | from src.multispace import MultiSpaceAI 23 | 24 | 25 | def display_waveform(audio_path): 26 | """显示音频波形""" 27 | # 加载音频文件 28 | waveform, sample_rate = librosa.load(audio_path, sr=None) 29 | 30 | # 计算时间轴 31 | time = np.arange(0, len(waveform)) / sample_rate 32 | 33 | # 绘制波形 34 | plt.figure(figsize=(12, 4)) 35 | plt.plot(time, waveform) 36 | plt.title('音频波形') 37 | plt.ylabel('振幅') 38 | plt.xlabel('时间 (秒)') 39 | plt.tight_layout() 40 | plt.show() 41 | 42 | 43 | def display_spectrogram(audio_path): 44 | """显示音频频谱图""" 45 | # 加载音频文件 46 | waveform, sample_rate = librosa.load(audio_path, sr=None) 47 | 48 | # 计算频谱图 49 | D = librosa.amplitude_to_db(np.abs(librosa.stft(waveform)), ref=np.max) 50 | 51 | # 绘制频谱图 52 | plt.figure(figsize=(12, 4)) 53 | librosa.display.specshow(D, sr=sample_rate, x_axis='time', y_axis='log') 54 | plt.colorbar(format='%+2.0f dB') 55 | plt.title('音频频谱图') 56 | plt.tight_layout() 57 | plt.show() 58 | 59 | 60 | def play_audio(audio_path): 61 | """播放音频(如果环境支持)""" 62 | try: 63 | from IPython.display import Audio, display 64 | display(Audio(audio_path)) 65 | print("音频播放中...") 66 | except ImportError: 67 | print("当前环境不支持直接播放音频。") 68 | 69 | 70 | def main(): 71 | """主函数""" 72 | # 设置命令行参数 73 | parser = argparse.ArgumentParser(description="MultiSpaceAI音频处理示例") 74 | parser.add_argument("--audio", type=str, required=True, help="音频文件路径") 75 | parser.add_argument("--text", type=str, default="描述这段音频", help="提示文本") 76 | parser.add_argument("--config", type=str, help="配置文件路径") 77 | parser.add_argument("--device", type=str, default=None, help="运行设备") 78 | parser.add_argument("--output", type=str, help="输出文件路径") 79 | parser.add_argument("--visualize", action="store_true", help="可视化音频") 80 | parser.add_argument("--play", action="store_true", help="播放音频") 81 | 82 | # 解析参数 83 | args = parser.parse_args() 84 | 85 | # 检查音频文件是否存在 86 | if not os.path.exists(args.audio): 87 | print(f"错误: 音频文件不存在: {args.audio}") 88 | sys.exit(1) 89 | 90 | # 显示音频可视化 91 | if args.visualize: 92 | try: 93 | import librosa.display 94 | display_waveform(args.audio) 95 | display_spectrogram(args.audio) 96 | except ImportError: 97 | print("警告: 无法导入librosa.display,音频可视化功能不可用。") 98 | 99 | # 播放音频 100 | if args.play: 101 | play_audio(args.audio) 102 | 103 | print(f"处理音频: {args.audio}") 104 | print(f"提示文本: {args.text}") 105 | 106 | # 初始化模型 107 | print("\n初始化MultiSpaceAI模型...") 108 | model = MultiSpaceAI( 109 | config_path=args.config, 110 | device=args.device, 111 | audio_encoder="whisper-api" if os.environ.get("OPENAI_API_KEY") else "custom" 112 | ) 113 | 114 | # 处理输入 115 | print("处理输入中...") 116 | result = model.process( 117 | text=args.text, 118 | audio=args.audio 119 | ) 120 | 121 | # 打印结果 122 | print("\n生成的文本:") 123 | print("-" * 50) 124 | print(result["generated_text"]) 125 | print("-" * 50) 126 | 127 | # 如果提供了输出文件路径,将结果保存到文件 128 | if args.output: 129 | import json 130 | with open(args.output, 'w', encoding='utf-8') as f: 131 | # 将Tensor转换为列表 132 | result_json = {k: v for k, v in result.items()} 133 | json.dump(result_json, f, ensure_ascii=False, indent=2) 134 | print(f"\n结果已保存到: {args.output}") 135 | 136 | print("\n处理完成!") 137 | 138 | 139 | if __name__ == "__main__": 140 | main() -------------------------------------------------------------------------------- /examples/config_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 配置示例 5 | ======= 6 | 7 | 展示如何自定义和保存MultiSpaceAI的配置。 8 | """ 9 | 10 | import os 11 | import sys 12 | import json 13 | import argparse 14 | 15 | # 将项目根目录添加到Python路径 16 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 17 | 18 | from src.multispace.config import ( 19 | ModelConfig, 20 | TextEncoderConfig, 21 | ImageEncoderConfig, 22 | AudioEncoderConfig, 23 | FusionConfig, 24 | DecoderConfig 25 | ) 26 | 27 | 28 | def create_default_config(output_path): 29 | """创建默认配置并保存""" 30 | # 创建默认配置 31 | config = ModelConfig() 32 | 33 | # 保存配置 34 | config.save(output_path) 35 | print(f"默认配置已保存到: {output_path}") 36 | 37 | # 显示配置内容 38 | print("\n配置内容:") 39 | with open(output_path, 'r', encoding='utf-8') as f: 40 | print(json.dumps(json.load(f), ensure_ascii=False, indent=2)) 41 | 42 | 43 | def create_custom_config(output_path): 44 | """创建自定义配置并保存""" 45 | # 创建文本编码器配置 46 | text_config = TextEncoderConfig( 47 | model_name="bert-base-chinese", 48 | embedding_dim=768, 49 | hidden_dim=1024, 50 | num_layers=8, 51 | num_attention_heads=12, 52 | max_seq_length=512, 53 | dropout=0.1 54 | ) 55 | 56 | # 创建图像编码器配置 57 | image_config = ImageEncoderConfig( 58 | model_name="google/vit-base-patch16-224", 59 | image_size=224, 60 | patch_size=16, 61 | embedding_dim=768, 62 | hidden_dim=1024, 63 | num_layers=12, 64 | num_attention_heads=12, 65 | dropout=0.1 66 | ) 67 | 68 | # 创建语音编码器配置 69 | audio_config = AudioEncoderConfig( 70 | model_name="facebook/wav2vec2-large-960h-lv60-self", 71 | sample_rate=16000, 72 | embedding_dim=1024, 73 | hidden_dim=1024, 74 | num_layers=12, 75 | num_attention_heads=16, 76 | max_audio_length=60, 77 | dropout=0.1 78 | ) 79 | 80 | # 创建融合模块配置 81 | fusion_config = FusionConfig( 82 | hidden_dim=1024, 83 | num_layers=6, 84 | num_attention_heads=16, 85 | fusion_type="cross_attention", 86 | use_modal_adapters=True, 87 | dropout=0.1 88 | ) 89 | 90 | # 创建解码器模块配置 91 | decoder_config = DecoderConfig( 92 | model_name="fnlp/bart-large-chinese", 93 | embedding_dim=1024, 94 | hidden_dim=1024, 95 | num_layers=12, 96 | num_attention_heads=16, 97 | max_seq_length=512, 98 | dropout=0.1 99 | ) 100 | 101 | # 创建模型配置 102 | config = ModelConfig() 103 | config.text_encoder_config = text_config 104 | config.image_encoder_config = image_config 105 | config.audio_encoder_config = audio_config 106 | config.fusion_config = fusion_config 107 | config.decoder_config = decoder_config 108 | 109 | # 保存配置 110 | config.save(output_path) 111 | print(f"自定义配置已保存到: {output_path}") 112 | 113 | # 显示配置内容 114 | print("\n配置内容:") 115 | with open(output_path, 'r', encoding='utf-8') as f: 116 | print(json.dumps(json.load(f), ensure_ascii=False, indent=2)) 117 | 118 | 119 | def main(): 120 | """主函数""" 121 | # 设置命令行参数 122 | parser = argparse.ArgumentParser(description="MultiSpaceAI配置示例") 123 | parser.add_argument("--type", type=str, choices=["default", "custom"], default="default", 124 | help="配置类型: default (默认配置) 或 custom (自定义配置)") 125 | parser.add_argument("--output", type=str, default="config.json", help="输出配置文件路径") 126 | 127 | # 解析参数 128 | args = parser.parse_args() 129 | 130 | # 确保输出目录存在 131 | os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True) 132 | 133 | # 根据类型创建和保存配置 134 | if args.type == "default": 135 | create_default_config(args.output) 136 | else: 137 | create_custom_config(args.output) 138 | 139 | print(f"\n配置文件已创建: {args.output}") 140 | print("您可以在初始化MultiSpaceAI时使用此配置文件:") 141 | print(f"model = MultiSpaceAI(config_path='{args.output}')") 142 | 143 | 144 | if __name__ == "__main__": 145 | main() -------------------------------------------------------------------------------- /multispace/text/deepseek_text_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepSeek API 文本编码器模块 3 | ========================= 4 | 5 | 该模块包含使用DeepSeek API进行文本编码的类。 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import requests 11 | import json 12 | from typing import Optional, Dict, Any, Union 13 | 14 | from ..utils.logger import setup_logger 15 | 16 | # 设置日志 17 | logger = setup_logger(__name__) 18 | 19 | 20 | class DeepSeekTextEncoder: 21 | """ 22 | DeepSeek API 文本编码器类,使用DeepSeek的API服务对文本进行编码。 23 | """ 24 | 25 | def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None): 26 | """ 27 | 初始化DeepSeek API文本编码器 28 | 29 | 参数: 30 | api_key: DeepSeek API密钥 31 | api_url: DeepSeek API URL,如果为None则使用默认URL 32 | """ 33 | self.api_key = api_key 34 | self.api_url = api_url or "https://api.deepseek.com/v1/embeddings" 35 | 36 | # 检查API密钥 37 | if not self.api_key: 38 | logger.warning("未提供DeepSeek API密钥,可能导致API调用失败") 39 | 40 | logger.info("DeepSeek API文本编码器初始化完成") 41 | 42 | def encode(self, text: str) -> torch.Tensor: 43 | """ 44 | 使用DeepSeek API编码文本 45 | 46 | 参数: 47 | text: 输入文本 48 | 49 | 返回: 50 | 文本特征表示 51 | """ 52 | # 记录日志 53 | logger.info("使用DeepSeek API编码文本") 54 | 55 | # 构建请求头 56 | headers = { 57 | "Authorization": f"Bearer {self.api_key}", 58 | "Content-Type": "application/json" 59 | } 60 | 61 | # 构建请求体 62 | data = { 63 | "input": text, 64 | "model": "deepseek-embedding" 65 | } 66 | 67 | try: 68 | # 发送请求 69 | response = requests.post(self.api_url, headers=headers, json=data) 70 | response.raise_for_status() # 如果请求失败,抛出异常 71 | 72 | # 解析响应 73 | result = response.json() 74 | 75 | # 获取嵌入向量 76 | embedding = result.get("embedding") or result.get("data", [{}])[0].get("embedding") 77 | 78 | if not embedding: 79 | raise ValueError(f"API响应中未找到嵌入向量: {result}") 80 | 81 | # 转换为torch.Tensor 82 | embedding_tensor = torch.tensor(embedding, dtype=torch.float) 83 | 84 | return embedding_tensor 85 | 86 | except Exception as e: 87 | logger.error(f"DeepSeek API调用失败: {str(e)}") 88 | 89 | # 返回一个全零的向量作为备用 90 | # 假设嵌入维度为1024 91 | return torch.zeros(1024, dtype=torch.float) 92 | 93 | def batch_encode(self, texts: list) -> torch.Tensor: 94 | """ 95 | 批量编码文本 96 | 97 | 参数: 98 | texts: 文本列表 99 | 100 | 返回: 101 | 批量文本特征表示 102 | """ 103 | # 记录日志 104 | logger.info(f"使用DeepSeek API批量编码{len(texts)}个文本") 105 | 106 | # 构建请求头 107 | headers = { 108 | "Authorization": f"Bearer {self.api_key}", 109 | "Content-Type": "application/json" 110 | } 111 | 112 | # 构建请求体 113 | data = { 114 | "input": texts, 115 | "model": "deepseek-embedding" 116 | } 117 | 118 | try: 119 | # 发送请求 120 | response = requests.post(self.api_url, headers=headers, json=data) 121 | response.raise_for_status() # 如果请求失败,抛出异常 122 | 123 | # 解析响应 124 | result = response.json() 125 | 126 | # 获取嵌入向量列表 127 | embeddings = [] 128 | 129 | if "data" in result: 130 | # 新版API格式 131 | for item in result["data"]: 132 | embeddings.append(item.get("embedding", [])) 133 | else: 134 | # 旧版API格式或其他格式 135 | embeddings = result.get("embeddings", []) 136 | 137 | if not embeddings: 138 | raise ValueError(f"API响应中未找到嵌入向量: {result}") 139 | 140 | # 转换为torch.Tensor 141 | embeddings_tensor = torch.tensor(embeddings, dtype=torch.float) 142 | 143 | return embeddings_tensor 144 | 145 | except Exception as e: 146 | logger.error(f"DeepSeek API批量调用失败: {str(e)}") 147 | 148 | # 返回一个全零的向量作为备用 149 | # 假设嵌入维度为1024 150 | return torch.zeros((len(texts), 1024), dtype=torch.float) -------------------------------------------------------------------------------- /multispace/audio/audio_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | 语音编码器模块 3 | =========== 4 | 5 | 该模块包含用于处理和编码语音输入的类。 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | import os 12 | import librosa 13 | from typing import Optional, Dict, Any, Union 14 | from transformers import Wav2Vec2Processor, Wav2Vec2Model 15 | 16 | from ..config import AudioEncoderConfig 17 | from ..utils.logger import setup_logger 18 | 19 | # 设置日志 20 | logger = setup_logger(__name__) 21 | 22 | 23 | class AudioEncoder(nn.Module): 24 | """ 25 | 语音编码器类,用于处理和编码语音输入。 26 | 基于Hugging Face的Wav2Vec2模型。 27 | """ 28 | 29 | def __init__(self, config: AudioEncoderConfig, device: Optional[str] = None): 30 | """ 31 | 初始化语音编码器 32 | 33 | 参数: 34 | config: 语音编码器配置 35 | device: 运行设备 36 | """ 37 | super().__init__() 38 | self.config = config 39 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | # 初始化处理器和模型 42 | logger.info(f"加载语音编码器预训练模型: {config.model_name}") 43 | 44 | # 加载处理器 45 | self.processor = Wav2Vec2Processor.from_pretrained(config.model_name) 46 | 47 | # 加载预训练模型 48 | self.model = Wav2Vec2Model.from_pretrained(config.model_name) 49 | 50 | # 输出维度适配层 51 | if self.model.config.hidden_size != config.hidden_dim: 52 | self.output_adapter = nn.Linear(self.model.config.hidden_size, config.hidden_dim) 53 | else: 54 | self.output_adapter = nn.Identity() 55 | 56 | # 移动模型到指定设备 57 | self.to(self.device) 58 | 59 | logger.info("语音编码器初始化完成") 60 | 61 | def preprocess(self, audio: Union[str, np.ndarray]) -> Dict[str, torch.Tensor]: 62 | """ 63 | 对语音进行预处理 64 | 65 | 参数: 66 | audio: 输入语音,可以是音频文件路径或numpy数组 67 | 68 | 返回: 69 | 预处理后的输入 70 | """ 71 | # 如果是路径,加载音频 72 | if isinstance(audio, str): 73 | if not os.path.exists(audio): 74 | raise FileNotFoundError(f"音频文件不存在: {audio}") 75 | 76 | # 使用librosa加载音频,调整采样率 77 | waveform, sample_rate = librosa.load(audio, sr=self.config.sample_rate) 78 | 79 | # 如果已经是numpy数组,确保采样率正确 80 | else: 81 | waveform = audio 82 | 83 | # 限制音频长度 84 | max_length = self.config.max_audio_length * self.config.sample_rate 85 | if len(waveform) > max_length: 86 | logger.warning(f"音频长度超过{self.config.max_audio_length}秒,将被截断") 87 | waveform = waveform[:max_length] 88 | 89 | # 使用处理器处理音频 90 | inputs = self.processor( 91 | waveform, 92 | sampling_rate=self.config.sample_rate, 93 | return_tensors="pt" 94 | ) 95 | 96 | # 移动到指定设备 97 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 98 | 99 | return inputs 100 | 101 | def forward(self, input_values: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 102 | """ 103 | 前向传播 104 | 105 | 参数: 106 | input_values: 输入值 107 | attention_mask: 注意力掩码(可选) 108 | 109 | 返回: 110 | 语音特征表示 111 | """ 112 | # 获取模型输出 113 | outputs = self.model(input_values=input_values, attention_mask=attention_mask) 114 | 115 | # 获取最后一层的隐藏状态 116 | last_hidden_state = outputs.last_hidden_state 117 | 118 | # 对时间维度进行平均池化,得到整个音频的表示 119 | if attention_mask is not None: 120 | # 使用注意力掩码进行池化 121 | pooled_output = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True) 122 | else: 123 | # 简单的平均池化 124 | pooled_output = last_hidden_state.mean(dim=1) 125 | 126 | # 应用输出适配层 127 | features = self.output_adapter(pooled_output) 128 | 129 | return features 130 | 131 | def encode(self, audio: Union[str, np.ndarray]) -> torch.Tensor: 132 | """ 133 | 编码语音 134 | 135 | 参数: 136 | audio: 输入语音,可以是音频文件路径或numpy数组 137 | 138 | 返回: 139 | 语音特征表示 140 | """ 141 | self.eval() # 设置为评估模式 142 | 143 | with torch.no_grad(): 144 | # 预处理语音 145 | inputs = self.preprocess(audio) 146 | 147 | # 前向传播 148 | features = self.forward( 149 | input_values=inputs["input_values"], 150 | attention_mask=inputs.get("attention_mask") 151 | ) 152 | 153 | return features -------------------------------------------------------------------------------- /multispace/image/diffusion_image_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Diffusion API 图像编码器模块 3 | ========================== 4 | 5 | 该模块包含使用Stable Diffusion API进行图像编码的类。 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import requests 11 | import json 12 | import base64 13 | import io 14 | import os 15 | from typing import Optional, Dict, Any, Union 16 | from PIL import Image 17 | 18 | from ..utils.logger import setup_logger 19 | 20 | # 设置日志 21 | logger = setup_logger(__name__) 22 | 23 | 24 | class DiffusionImageEncoder: 25 | """ 26 | Diffusion API 图像编码器类,使用Stable Diffusion的API服务对图像进行编码。 27 | """ 28 | 29 | def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None): 30 | """ 31 | 初始化Diffusion API图像编码器 32 | 33 | 参数: 34 | api_key: Stable Diffusion API密钥 35 | api_url: Stable Diffusion API URL,如果为None则使用默认URL 36 | """ 37 | self.api_key = api_key 38 | self.api_url = api_url or "https://api.stability.ai/v1/embeddings/image" 39 | 40 | # 检查API密钥 41 | if not self.api_key: 42 | logger.warning("未提供Stable Diffusion API密钥,可能导致API调用失败") 43 | 44 | logger.info("Diffusion API图像编码器初始化完成") 45 | 46 | def _load_image(self, image: Union[str, np.ndarray, Image.Image]) -> Image.Image: 47 | """ 48 | 加载图像 49 | 50 | 参数: 51 | image: 输入图像,可以是图像路径、numpy数组或PIL图像 52 | 53 | 返回: 54 | PIL图像 55 | """ 56 | # 如果是路径,加载图像 57 | if isinstance(image, str): 58 | if not os.path.exists(image): 59 | raise FileNotFoundError(f"图像文件不存在: {image}") 60 | image = Image.open(image).convert("RGB") 61 | 62 | # 如果是numpy数组,转换为PIL图像 63 | elif isinstance(image, np.ndarray): 64 | image = Image.fromarray(np.uint8(image)) 65 | 66 | return image 67 | 68 | def _image_to_base64(self, image: Image.Image) -> str: 69 | """ 70 | 将PIL图像转换为base64编码 71 | 72 | 参数: 73 | image: PIL图像 74 | 75 | 返回: 76 | base64编码的图像字符串 77 | """ 78 | buffer = io.BytesIO() 79 | image.save(buffer, format="PNG") 80 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 81 | 82 | def encode(self, image: Union[str, np.ndarray, Image.Image]) -> torch.Tensor: 83 | """ 84 | 使用Diffusion API编码图像 85 | 86 | 参数: 87 | image: 输入图像,可以是图像路径、numpy数组或PIL图像 88 | 89 | 返回: 90 | 图像特征表示 91 | """ 92 | # 记录日志 93 | logger.info("使用Diffusion API编码图像") 94 | 95 | # 加载图像 96 | img = self._load_image(image) 97 | 98 | # 调整图像大小为API要求的尺寸 99 | img = img.resize((512, 512)) 100 | 101 | # 转换为base64 102 | image_base64 = self._image_to_base64(img) 103 | 104 | # 构建请求头 105 | headers = { 106 | "Authorization": f"Bearer {self.api_key}", 107 | "Content-Type": "application/json" 108 | } 109 | 110 | # 构建请求体 111 | data = { 112 | "image_base64": image_base64, 113 | "model": "stable-diffusion-xl-1024-v1-0" # 使用最新的SDXL模型 114 | } 115 | 116 | try: 117 | # 发送请求 118 | response = requests.post(self.api_url, headers=headers, json=data) 119 | response.raise_for_status() # 如果请求失败,抛出异常 120 | 121 | # 解析响应 122 | result = response.json() 123 | 124 | # 获取嵌入向量 125 | embedding = result.get("embedding") or result.get("data", {}).get("embedding") 126 | 127 | if not embedding: 128 | raise ValueError(f"API响应中未找到嵌入向量: {result}") 129 | 130 | # 转换为torch.Tensor 131 | embedding_tensor = torch.tensor(embedding, dtype=torch.float) 132 | 133 | return embedding_tensor 134 | 135 | except Exception as e: 136 | logger.error(f"Diffusion API调用失败: {str(e)}") 137 | 138 | # 返回一个全零的向量作为备用 139 | # 假设嵌入维度为1024 140 | return torch.zeros(1024, dtype=torch.float) 141 | 142 | def batch_encode(self, images: list) -> torch.Tensor: 143 | """ 144 | 批量编码图像 145 | 146 | 参数: 147 | images: 图像列表 148 | 149 | 返回: 150 | 批量图像特征表示 151 | """ 152 | # 记录日志 153 | logger.info(f"使用Diffusion API批量编码{len(images)}个图像") 154 | 155 | # 创建一个空列表来存储嵌入向量 156 | embeddings = [] 157 | 158 | # 依次处理每个图像 159 | for image in images: 160 | embedding = self.encode(image) 161 | embeddings.append(embedding) 162 | 163 | # 堆叠所有嵌入向量 164 | if embeddings: 165 | return torch.stack(embeddings) 166 | else: 167 | # 如果没有嵌入向量,返回一个空的张量 168 | return torch.zeros((0, 1024), dtype=torch.float) -------------------------------------------------------------------------------- /multispace/audio/whisper_audio_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Whisper API 语音编码器模块 3 | ======================== 4 | 5 | 该模块包含使用OpenAI Whisper API进行语音编码的类。 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import os 11 | import librosa 12 | import requests 13 | import json 14 | import base64 15 | from typing import Optional, Dict, Any, Union 16 | 17 | from ..utils.logger import setup_logger 18 | 19 | # 设置日志 20 | logger = setup_logger(__name__) 21 | 22 | 23 | class WhisperAudioEncoder: 24 | """ 25 | Whisper API 语音编码器类,使用OpenAI的Whisper API服务对语音进行编码。 26 | """ 27 | 28 | def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None, embedding_dim: int = 1024): 29 | """ 30 | 初始化Whisper API语音编码器 31 | 32 | 参数: 33 | api_key: OpenAI API密钥 34 | api_url: OpenAI API URL,如果为None则使用默认URL 35 | embedding_dim: 嵌入向量维度 36 | """ 37 | self.api_key = api_key 38 | self.api_url = api_url or "https://api.openai.com/v1/audio/transcriptions" 39 | self.embedding_dim = embedding_dim 40 | 41 | # 检查API密钥 42 | if not self.api_key: 43 | logger.warning("未提供OpenAI API密钥,可能导致API调用失败") 44 | 45 | logger.info("Whisper API语音编码器初始化完成") 46 | 47 | def preprocess_audio(self, audio: Union[str, np.ndarray], sample_rate: int = 16000) -> str: 48 | """ 49 | 预处理音频并保存为临时文件 50 | 51 | 参数: 52 | audio: 输入语音,可以是音频文件路径或numpy数组 53 | sample_rate: 采样率 54 | 55 | 返回: 56 | 临时音频文件路径 57 | """ 58 | # 如果已经是文件路径,并且是支持的格式,直接返回 59 | if isinstance(audio, str) and os.path.exists(audio): 60 | _, ext = os.path.splitext(audio) 61 | if ext.lower() in ['.mp3', '.wav', '.m4a']: 62 | return audio 63 | 64 | # 如果是numpy数组或需要转换格式的文件路径 65 | if isinstance(audio, str) and os.path.exists(audio): 66 | waveform, _ = librosa.load(audio, sr=sample_rate) 67 | else: 68 | waveform = audio 69 | 70 | # 创建临时文件 71 | import tempfile 72 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') 73 | temp_file_path = temp_file.name 74 | temp_file.close() 75 | 76 | # 保存为MP3格式 77 | import soundfile as sf 78 | sf.write(temp_file_path, waveform, sample_rate, format='mp3') 79 | 80 | return temp_file_path 81 | 82 | def encode(self, audio: Union[str, np.ndarray]) -> torch.Tensor: 83 | """ 84 | 使用Whisper API编码语音 85 | 86 | 参数: 87 | audio: 输入语音,可以是音频文件路径或numpy数组 88 | 89 | 返回: 90 | 语音特征表示 91 | """ 92 | # 记录日志 93 | logger.info("使用Whisper API编码语音") 94 | 95 | # 预处理音频 96 | audio_file_path = self.preprocess_audio(audio) 97 | 98 | # 构建请求头 99 | headers = { 100 | "Authorization": f"Bearer {self.api_key}" 101 | } 102 | 103 | try: 104 | # 准备文件和表单数据 105 | with open(audio_file_path, 'rb') as audio_file: 106 | files = { 107 | 'file': (os.path.basename(audio_file_path), audio_file, 'audio/mpeg') 108 | } 109 | 110 | data = { 111 | 'model': 'whisper-1', 112 | 'response_format': 'verbose_json' 113 | } 114 | 115 | # 发送请求 116 | response = requests.post(self.api_url, headers=headers, files=files, data=data) 117 | response.raise_for_status() # 如果请求失败,抛出异常 118 | 119 | # 解析响应 120 | result = response.json() 121 | 122 | # Whisper API不直接返回嵌入向量,而是返回转录文本和其他元数据 123 | # 我们可以使用转录文本的某些特征作为音频的表示 124 | 125 | # 例如,我们可以分析转录文本中的段落、置信度等 126 | transcription = result.get('text', '') 127 | 128 | # 清理临时文件 129 | if audio_file_path != audio and os.path.exists(audio_file_path): 130 | os.unlink(audio_file_path) 131 | 132 | # 由于Whisper API不提供嵌入向量,我们可以采用以下策略: 133 | # 1. 使用转录文本的长度、段落数量、置信度等构建一个特征向量 134 | # 2. 或者使用转录文本通过额外的文本编码器生成嵌入 135 | # 3. 这里我们简单地使用一个随机向量作为占位符,实际应用中应替换为更有意义的表示 136 | 137 | # 创建一个伪随机但确定性的向量,基于转录文本的哈希 138 | import hashlib 139 | seed = int(hashlib.md5(transcription.encode()).hexdigest(), 16) % (10 ** 8) 140 | np.random.seed(seed) 141 | embedding = np.random.randn(self.embedding_dim).astype(np.float32) 142 | 143 | # 返回嵌入向量 144 | return torch.tensor(embedding, dtype=torch.float) 145 | 146 | except Exception as e: 147 | logger.error(f"Whisper API调用失败: {str(e)}") 148 | 149 | # 清理临时文件 150 | if audio_file_path != audio and os.path.exists(audio_file_path): 151 | os.unlink(audio_file_path) 152 | 153 | # 返回一个全零的向量作为备用 154 | return torch.zeros(self.embedding_dim, dtype=torch.float) 155 | 156 | def batch_encode(self, audios: list) -> torch.Tensor: 157 | """ 158 | 批量编码语音 159 | 160 | 参数: 161 | audios: 语音列表 162 | 163 | 返回: 164 | 批量语音特征表示 165 | """ 166 | # 记录日志 167 | logger.info(f"使用Whisper API批量编码{len(audios)}个语音") 168 | 169 | # 创建一个空列表来存储嵌入向量 170 | embeddings = [] 171 | 172 | # 依次处理每个语音 173 | for audio in audios: 174 | embedding = self.encode(audio) 175 | embeddings.append(embedding) 176 | 177 | # 堆叠所有嵌入向量 178 | if embeddings: 179 | return torch.stack(embeddings) 180 | else: 181 | # 如果没有嵌入向量,返回一个空的张量 182 | return torch.zeros((0, self.embedding_dim), dtype=torch.float) -------------------------------------------------------------------------------- /multispace/decoder/decoder_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | 解码器模块 3 | ======== 4 | 5 | 该模块包含用于将融合特征解码生成输出的类。 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from typing import Optional, Dict, Any, List, Union 11 | from transformers import BartForConditionalGeneration, BartTokenizer 12 | 13 | from ..config import DecoderConfig 14 | from ..utils.logger import setup_logger 15 | 16 | # 设置日志 17 | logger = setup_logger(__name__) 18 | 19 | 20 | class DecoderModule(nn.Module): 21 | """ 22 | 解码器模块,用于将融合的特征表示解码为文本输出。 23 | 基于Hugging Face的BART模型。 24 | """ 25 | 26 | def __init__(self, config: DecoderConfig, device: Optional[str] = None): 27 | """ 28 | 初始化解码器模块 29 | 30 | 参数: 31 | config: 解码器配置 32 | device: 运行设备 33 | """ 34 | super().__init__() 35 | self.config = config 36 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | # 初始化分词器和模型 39 | logger.info(f"加载解码器预训练模型: {config.model_name}") 40 | 41 | # 加载分词器 42 | self.tokenizer = BartTokenizer.from_pretrained(config.model_name) 43 | 44 | # 加载预训练模型 45 | self.model = BartForConditionalGeneration.from_pretrained(config.model_name) 46 | 47 | # 特征投影层,将融合特征映射到解码器的隐藏维度 48 | self.feature_projector = nn.Linear(config.hidden_dim, self.model.config.d_model) 49 | 50 | # 自回归生成头 51 | self.lm_head = nn.Linear(self.model.config.d_model, len(self.tokenizer), bias=False) 52 | 53 | # 移动模型到指定设备 54 | self.to(self.device) 55 | 56 | logger.info("解码器模块初始化完成") 57 | 58 | def prepare_inputs_for_generation(self, 59 | fused_features: torch.Tensor, 60 | decoder_input_ids: torch.Tensor = None, 61 | attention_mask: torch.Tensor = None) -> Dict[str, torch.Tensor]: 62 | """ 63 | 准备用于生成的输入 64 | 65 | 参数: 66 | fused_features: 融合特征 67 | decoder_input_ids: 解码器输入ID 68 | attention_mask: 注意力掩码 69 | 70 | 返回: 71 | 用于生成的输入字典 72 | """ 73 | # 投影融合特征 74 | encoder_hidden_states = self.feature_projector(fused_features).unsqueeze(1) 75 | 76 | # 如果没有提供解码器输入ID,创建一个包含开始标记的序列 77 | if decoder_input_ids is None: 78 | decoder_input_ids = torch.tensor([[self.tokenizer.bos_token_id]]).to(self.device) 79 | 80 | # 准备输入字典 81 | model_inputs = { 82 | "encoder_outputs": [encoder_hidden_states], # 使用融合特征作为编码器输出 83 | "decoder_input_ids": decoder_input_ids, 84 | "attention_mask": attention_mask if attention_mask is not None else torch.ones_like(decoder_input_ids), 85 | "use_cache": True 86 | } 87 | 88 | return model_inputs 89 | 90 | def forward(self, 91 | encoder_hidden_states: torch.Tensor, 92 | decoder_input_ids: torch.Tensor, 93 | decoder_attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 94 | """ 95 | 前向传播 96 | 97 | 参数: 98 | encoder_hidden_states: 编码器隐藏状态,即融合特征 99 | decoder_input_ids: 解码器输入ID 100 | decoder_attention_mask: 解码器注意力掩码 101 | 102 | 返回: 103 | 解码器输出 104 | """ 105 | # 使用BART模型的解码器部分 106 | decoder_outputs = self.model.decoder( 107 | input_ids=decoder_input_ids, 108 | encoder_hidden_states=encoder_hidden_states, 109 | attention_mask=decoder_attention_mask 110 | ) 111 | 112 | # 获取解码器输出的隐藏状态 113 | hidden_states = decoder_outputs[0] 114 | 115 | # 应用语言模型头 116 | logits = self.lm_head(hidden_states) 117 | 118 | return logits 119 | 120 | def generate(self, 121 | fused_features: torch.Tensor, 122 | max_length: int = 50, 123 | num_beams: int = 4, 124 | early_stopping: bool = True, 125 | **kwargs) -> List[str]: 126 | """ 127 | 生成文本 128 | 129 | 参数: 130 | fused_features: 融合特征 131 | max_length: 生成的最大长度 132 | num_beams: 束搜索的束数 133 | early_stopping: 是否提前停止生成 134 | **kwargs: 其他生成参数 135 | 136 | 返回: 137 | 生成的文本列表 138 | """ 139 | # 确保模型处于评估模式 140 | self.eval() 141 | 142 | # 投影融合特征 143 | encoder_hidden_states = self.feature_projector(fused_features).unsqueeze(1) 144 | 145 | # 使用BART模型的生成方法 146 | with torch.no_grad(): 147 | # 设置编码器输出 148 | encoder_outputs = [encoder_hidden_states] 149 | 150 | # 生成序列 151 | output_ids = self.model.generate( 152 | encoder_outputs=encoder_outputs, 153 | max_length=max_length, 154 | num_beams=num_beams, 155 | early_stopping=early_stopping, 156 | **kwargs 157 | ) 158 | 159 | # 解码生成的序列 160 | outputs = [] 161 | for ids in output_ids: 162 | # 将ID转换为文本 163 | text = self.tokenizer.decode(ids, skip_special_tokens=True) 164 | outputs.append(text) 165 | 166 | return outputs 167 | 168 | def decode(self, fused_features: torch.Tensor, **kwargs) -> Dict[str, Any]: 169 | """ 170 | 解码融合特征 171 | 172 | 参数: 173 | fused_features: 融合特征 174 | **kwargs: 其他参数 175 | 176 | 返回: 177 | 包含生成结果的字典 178 | """ 179 | # 生成文本 180 | generated_texts = self.generate(fused_features, **kwargs) 181 | 182 | # 返回结果 183 | result = { 184 | "generated_text": generated_texts[0] if generated_texts else "", # 取第一个生成的文本 185 | "all_texts": generated_texts, # 所有生成的文本 186 | "raw_features": fused_features.cpu().numpy().tolist() # 原始特征 187 | } 188 | 189 | return result -------------------------------------------------------------------------------- /multimodal_llm_design_doc.md: -------------------------------------------------------------------------------- 1 | # 多模态大语言模型设计文档 2 | 3 | ## 1. 项目概述 4 | 5 | ### 1.1 背景 6 | 随着人工智能技术的飞速发展,多模态学习已成为当前研究热点。多模态大语言模型(Multimodal LLM)能够同时处理和理解文本、图像、语音等多种模态的数据,为实现更全面的人工智能系统奠定基础。 7 | 8 | ### 1.2 目标 9 | 设计并实现一个高效的多模态大语言模型系统,专注于文本、图像和语音数据的处理与理解,实现跨模态的信息融合和推理能力。 10 | 11 | ### 1.3 应用场景 12 | - 图像描述生成 13 | - 视觉问答(VQA) 14 | - 跨模态检索 15 | - 基于图像的内容创作 16 | - 多模态知识推理 17 | - 语音识别与理解 18 | - 音频内容描述 19 | - 语音指令执行 20 | 21 | ## 2. 系统架构 22 | 23 | ### 2.1 整体架构 24 | ``` 25 | +---------------------+ +----------------------+ +----------------------+ 26 | | | | | | | 27 | | 文本编码器模块 | | 图像编码器模块 | | 语音编码器模块 | 28 | | (Text Encoder) | | (Image Encoder) | | (Audio Encoder) | 29 | | | | | | | 30 | +----------+----------+ +-----------+----------+ +-----------+----------+ 31 | | | | 32 | v v v 33 | +---------------------------------------------------------------------+ 34 | | | 35 | | 多模态融合模块 | 36 | | (Multimodal Fusion Module) | 37 | | | 38 | +------------------------------+--------------------------------------+ 39 | | 40 | v 41 | +---------------------------------------------------------------------+ 42 | | | 43 | | 解码器模块 | 44 | | (Decoder Module) | 45 | | | 46 | +---------------------------------------------------------------------+ 47 | | 48 | v 49 | +---------------------------------------------------------------------+ 50 | | | 51 | | 输出层 | 52 | | (Output Layer) | 53 | | | 54 | +---------------------------------------------------------------------+ 55 | ``` 56 | 57 | ### 2.2 核心模块 58 | 1. **文本编码器**:处理和编码文本输入 59 | 2. **图像编码器**:处理和编码图像输入 60 | 3. **语音编码器**:处理和编码语音输入 61 | 4. **多模态融合模块**:整合不同模态的特征表示 62 | 5. **解码器模块**:基于融合的特征生成输出 63 | 6. **输出层**:根据任务需求生成最终结果 64 | 65 | ## 3. 技术实现 66 | 67 | ### 3.1 文本处理模块 68 | 69 | #### 3.1.1 预处理 70 | - 文本清洗 71 | - 分词处理 72 | - 特殊标记添加 73 | - API请求格式转换 74 | 75 | #### 3.1.2 编码器架构 76 | - **方案一: 自定义模型** 77 | - 基于Transformer的编码器 78 | - 位置编码 79 | - 多层自注意力机制 80 | - **方案二: DeepSeek-API集成** 81 | - 利用DeepSeek文本理解能力 82 | - API调用封装层 83 | - 响应处理与适配机制 84 | 85 | #### 3.1.3 文本表示 86 | - 词嵌入 87 | - 上下文化表示 88 | - 特征维度:1024 89 | - 表示对齐与转换层(用于DeepSeek-API输出) 90 | 91 | ### 3.2 图像处理模块 92 | 93 | #### 3.2.1 预处理 94 | - 图像缩放与裁剪 95 | - 数据增强 96 | - 归一化 97 | - API请求格式转换 98 | 99 | #### 3.2.2 编码器架构 100 | - **方案一: 自定义模型** 101 | - 基于Vision Transformer (ViT)或CNN+Transformer的混合架构 102 | - 图像分块处理 103 | - 多层自注意力机制 104 | - **方案二: Diffusion模型API集成** 105 | - 利用稳定扩散(Stable Diffusion)等模型的图像理解能力 106 | - API调用封装层 107 | - 图像特征提取与转换 108 | 109 | #### 3.2.3 图像表示 110 | - 图像特征提取 111 | - 视觉标记表示 112 | - 特征维度:1024 113 | - 表示对齐与转换层(用于API输出) 114 | 115 | ### 3.3 语音处理模块 116 | 117 | #### 3.3.1 预处理 118 | - 语音信号采样与分帧 119 | - 噪声抑制与增强 120 | - 特征提取(MFCC, Mel频谱图等) 121 | - API请求格式转换 122 | 123 | #### 3.3.2 编码器架构 124 | - **方案一: 自定义模型** 125 | - 基于Conformer的混合编码器 126 | - 自注意力与卷积结合 127 | - 多层次音频特征提取 128 | - **方案二: 语音识别API集成** 129 | - 利用Whisper或其他语音识别API 130 | - 语音编码与理解能力 131 | - API调用封装与响应处理 132 | 133 | #### 3.3.3 语音表示 134 | - 声学特征编码 135 | - 上下文语音表示 136 | - 特征维度:1024 137 | - 表示对齐与转换层(用于API输出) 138 | 139 | ### 3.4 多模态融合模块 140 | 141 | #### 3.4.1 融合策略 142 | - 交叉注意力机制 143 | - 共同表示学习 144 | - 模态间对齐 145 | 146 | #### 3.4.2 融合架构 147 | - 多头交叉注意力 148 | - 模态特定的层归一化 149 | - 残差连接 150 | 151 | ### 3.5 解码器模块 152 | 153 | #### 3.5.1 架构设计 154 | - 基于Transformer的自回归解码器 155 | - 多模态上下文感知机制 156 | - 自适应层归一化 157 | 158 | #### 3.5.2 生成策略 159 | - Beam Search 160 | - 采样方法 161 | - 长度控制 162 | 163 | ### 3.6 训练与优化 164 | 165 | #### 3.6.1 损失函数 166 | - 交叉熵损失 167 | - 对比学习损失 168 | - 多任务学习目标 169 | 170 | #### 3.6.2 优化策略 171 | - Adam优化器 172 | - 学习率预热与衰减 173 | - 梯度裁剪 174 | 175 | #### 3.6.3 训练技巧 176 | - 混合精度训练 177 | - 渐进式学习 178 | - 模型蒸馏 179 | 180 | ## 4. 数据处理 181 | 182 | ### 4.1 数据集 183 | - MS-COCO 184 | - Flickr30k 185 | - Visual Genome 186 | - CC3M/CC12M 187 | - LAION-400M 188 | - LibriSpeech 189 | - Common Voice 190 | - AudioSet 191 | - VoxCeleb 192 | 193 | ### 4.2 数据预处理流程 194 | 1. 文本数据清洗与标准化 195 | 2. 图像预处理与特征提取 196 | 3. 语音信号预处理与特征提取 197 | 4. 文本-图像-语音多模态数据对齐 198 | 5. 数据增强与平衡 199 | 200 | ### 4.3 数据存储与管理 201 | - 高效的数据加载管道 202 | - 分布式数据处理 203 | - 缓存策略 204 | 205 | ## 5. 评估指标 206 | 207 | ### 5.1 文本生成评估 208 | - BLEU 209 | - ROUGE 210 | - METEOR 211 | - CIDEr 212 | 213 | ### 5.2 视觉问答评估 214 | - 准确率 215 | - F1分数 216 | - 人工评估 217 | 218 | ### 5.3 跨模态检索评估 219 | - Recall@K 220 | - 平均精度(mAP) 221 | - 归一化折现累积增益(NDCG) 222 | 223 | ### 5.4 语音处理评估 224 | - 词错率(WER) 225 | - 音素错误率(PER) 226 | - BLEU-音频 227 | - 音频分类准确率 228 | 229 | ## 6. 部署与服务 230 | 231 | ### 6.1 硬件要求 232 | - GPU: NVIDIA A100或同等性能 233 | - 内存: 至少32GB 234 | - 存储: SSD, 至少1TB 235 | 236 | ### 6.2 软件环境 237 | - Python 3.8+ 238 | - PyTorch 1.10+ 239 | - CUDA 11.3+ 240 | - Docker容器化 241 | 242 | ### 6.3 服务架构 243 | - RESTful API 244 | - WebSocket支持 245 | - 批处理服务 246 | 247 | ### 6.4 性能优化 248 | - 模型量化 249 | - KV缓存 250 | - 动态批处理 251 | 252 | ## 7. 扩展与未来工作 253 | 254 | ### 7.1 模态扩展 255 | - 视频处理能力 256 | - 3D数据处理 257 | - 触觉信息集成 258 | - 多语言多方言支持 259 | 260 | ### 7.2 能力增强 261 | - 微调接口 262 | - 领域适应 263 | - 持续学习 264 | - DeepSeek-API功能扩展与升级适配 265 | 266 | ### 7.3 应用拓展 267 | - 行业垂直领域适配 268 | - 交互式系统集成 269 | - 多模态代理开发 270 | 271 | ## 8. 风险与挑战 272 | 273 | ### 8.1 技术挑战 274 | - 模态间对齐困难 275 | - 计算资源需求高 276 | - 多模态理解深度不足 277 | 278 | ### 8.2 潜在风险 279 | - 偏见与不公平性 280 | - 内容安全问题 281 | - 隐私保护 282 | 283 | ## 9. 项目规划 284 | 285 | ### 9.1 开发路线 286 | - 阶段一: 基础模型开发与训练 (3个月) 287 | - 阶段二: 性能优化与评估 (2个月) 288 | - 阶段三: 应用开发与部署 (2个月) 289 | 290 | ### 9.2 团队分工 291 | - 研究组: 模型设计与实验 292 | - 工程组: 系统实现与优化 293 | - 应用组: 场景开发与集成 294 | - 评估组: 测试与性能评估 295 | 296 | ## 10. 参考资料 297 | 298 | - Li et al. (2023). "Recent Advances in Multimodal Large Language Models" 299 | - Zhang et al. (2022). "Vision-Language Pre-training: Challenges and Directions" 300 | - Chen et al. (2023). "Multimodal Foundation Models: From Specialists to General-Purpose Assistants" 301 | - Huang et al. (2021). "Seeing is Believing: Vision-Language Models for Visual Understanding and Reasoning" 302 | - Yang et al. (2022). "Audio-Visual Learning: A Survey of Recent Advances and Applications" 303 | - Wang et al. (2023). "Speech-Language-Vision Transformers: An Overview and Open Challenges" -------------------------------------------------------------------------------- /docs/langchain_integration.md: -------------------------------------------------------------------------------- 1 | # MultiSpaceAI LangChain 集成 2 | 3 | 本文档介绍如何使用 MultiSpaceAI 与 LangChain 集成,实现更强大的多模态处理能力。 4 | 5 | ## 概述 6 | 7 | [LangChain](https://www.langchain.com/) 是一个用于构建基于大语言模型应用的框架,提供了一系列组件和工具,使开发人员能够创建复杂的基于 LLM 的应用程序。 8 | 9 | MultiSpaceAI 与 LangChain 的集成允许您: 10 | 11 | 1. 将 MultiSpaceAI 作为 LangChain 工具使用 12 | 2. 创建包含多模态处理能力的 LangChain 链 13 | 3. 构建可以使用 MultiSpaceAI 处理多模态输入的代理 14 | 4. 实现基于向量数据库的多模态内容检索 15 | 16 | ## 安装依赖 17 | 18 | 要使用 MultiSpaceAI 的 LangChain 集成,需要安装以下依赖: 19 | 20 | ```bash 21 | pip install langchain langchain-openai faiss-cpu chromadb openai tiktoken 22 | ``` 23 | 24 | 或者直接使用项目的 requirements.txt: 25 | 26 | ```bash 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## 基本用法 31 | 32 | ### 将 MultiSpaceAI 作为工具使用 33 | 34 | ```python 35 | from langchain.chat_models import ChatOpenAI 36 | from src.langchain_integration import MultiSpaceAITool 37 | 38 | # 初始化 LLM 39 | llm = ChatOpenAI(model_name="gpt-3.5-turbo") 40 | 41 | # 初始化 MultiSpaceAI 工具 42 | multispace_tool = MultiSpaceAITool( 43 | config_path="path/to/config.json", 44 | text_encoder="custom", 45 | image_encoder="custom", 46 | audio_encoder="custom" 47 | ) 48 | 49 | # 使用工具处理多模态输入 50 | result = multispace_tool.run( 51 | query="描述这个图像", 52 | image="path/to/image.jpg" 53 | ) 54 | 55 | print(result) 56 | ``` 57 | 58 | ### 使用 MultiSpaceAI 链 59 | 60 | ```python 61 | from langchain.chat_models import ChatOpenAI 62 | from langchain.prompts import PromptTemplate 63 | from src.langchain_integration import MultiSpaceAIChain 64 | 65 | # 初始化 LLM 66 | llm = ChatOpenAI(model_name="gpt-3.5-turbo") 67 | 68 | # 创建提示模板 69 | template = """你是一个多模态助手,能够处理文本、图像和音频输入。 70 | 71 | {multimodal_context} 72 | 73 | 用户问题: {query} 74 | 75 | 请提供详细且有用的回答:""" 76 | 77 | prompt = PromptTemplate( 78 | template=template, 79 | input_variables=["multimodal_context", "query"] 80 | ) 81 | 82 | # 初始化 MultiSpaceAI 链 83 | multispace_chain = MultiSpaceAIChain( 84 | llm=llm, 85 | prompt=prompt, 86 | config_path="path/to/config.json" 87 | ) 88 | 89 | # 运行链 90 | result = multispace_chain({ 91 | "query": "这个图像中有什么?", 92 | "image": "path/to/image.jpg" 93 | }) 94 | 95 | print(result["text"]) 96 | ``` 97 | 98 | ### 使用 MultiSpaceAI 代理 99 | 100 | ```python 101 | from langchain.chat_models import ChatOpenAI 102 | from src.langchain_integration import MultiSpaceAIAgent 103 | 104 | # 初始化 LLM 105 | llm = ChatOpenAI(model_name="gpt-3.5-turbo") 106 | 107 | # 初始化 MultiSpaceAI 代理 108 | agent = MultiSpaceAIAgent( 109 | llm=llm, 110 | config_path="path/to/config.json", 111 | verbose=True 112 | ) 113 | 114 | # 运行代理 115 | result = agent.run("分析这个图像并告诉我其中的主要内容 图像路径: path/to/image.jpg") 116 | 117 | print(result) 118 | ``` 119 | 120 | ### 使用 MultiSpaceAI 检索链 121 | 122 | ```python 123 | from langchain.chat_models import ChatOpenAI 124 | from src.langchain_integration import MultiSpaceAIRetrievalChain 125 | 126 | # 初始化 LLM 127 | llm = ChatOpenAI(model_name="gpt-3.5-turbo") 128 | 129 | # 初始化检索链 130 | retrieval_chain = MultiSpaceAIRetrievalChain( 131 | llm=llm, 132 | config_path="path/to/config.json", 133 | return_source_documents=True 134 | ) 135 | 136 | # 导入文档 137 | retrieval_chain.ingest_from_directory( 138 | directory="path/to/data", 139 | chunk_size=1000, 140 | chunk_overlap=0 141 | ) 142 | 143 | # 运行检索链 144 | result = retrieval_chain.run( 145 | query="关于这个图像,我们之前的数据库中有什么相关信息?", 146 | image="path/to/image.jpg" 147 | ) 148 | 149 | print(result["result"]) 150 | 151 | # 打印源文档 152 | for doc in result["source_documents"]: 153 | print(f"来源: {doc.metadata.get('source')}") 154 | print(f"内容: {doc.page_content[:100]}...") 155 | ``` 156 | 157 | ## 组件详解 158 | 159 | ### MultiSpaceAITool 160 | 161 | `MultiSpaceAITool` 类将 MultiSpaceAI 封装为 LangChain 工具,可以在代理中使用。 162 | 163 | - 主要功能: 处理多模态输入并返回生成文本 164 | - 支持异步调用 (`arun` 方法) 165 | - 可作为代理的工具使用 166 | 167 | ### MultiSpaceAIChain 168 | 169 | `MultiSpaceAIChain` 类扩展了 LangChain 的 `LLMChain`,添加了多模态处理能力。 170 | 171 | - 主要功能: 处理多模态输入并将处理结果添加到链的上下文中 172 | - 支持内存组件,可以保存对话历史 173 | - 可以自定义提示模板 174 | 175 | ### MultiSpaceAIAgent 176 | 177 | `MultiSpaceAIAgent` 类创建一个可以使用 MultiSpaceAI 处理多模态输入的代理。 178 | 179 | - 主要功能: 创建一个可以规划和执行多模态任务的代理 180 | - 基于 LangChain 的结构化代理实现 181 | - 支持对话内存和回调 182 | 183 | ### MultiSpaceAIEmbeddings 184 | 185 | `MultiSpaceAIEmbeddings` 类实现了 LangChain 的 `Embeddings` 接口,用于生成多模态内容的嵌入向量。 186 | 187 | - 主要功能: 生成文本、图像和音频的嵌入向量 188 | - 可以与 LangChain 的向量存储集成 189 | - 支持文档和查询的嵌入生成 190 | 191 | ### MultiModalDocument 192 | 193 | `MultiModalDocument` 类扩展了 LangChain 的 `Document` 类,添加了对图像和音频内容的支持。 194 | 195 | - 主要功能: 表示包含多模态内容的文档 196 | - 存储图像和音频文件路径到元数据中 197 | - 与标准 LangChain 文档兼容 198 | 199 | ### MultiSpaceAIRetrievalChain 200 | 201 | `MultiSpaceAIRetrievalChain` 类实现了一个基于向量数据库的多模态内容检索和问答系统。 202 | 203 | - 主要功能: 根据多模态查询检索相关内容并回答问题 204 | - 支持从目录导入文档 205 | - 使用 LLM 进行上下文压缩和回答生成 206 | 207 | ## 示例 208 | 209 | 查看 `examples/langchain_example.py` 获取完整示例。 210 | 211 | 运行示例: 212 | 213 | ```bash 214 | # 工具示例 215 | python examples/langchain_example.py --example tool --image path/to/image.jpg --query "描述这个图像" 216 | 217 | # 链示例 218 | python examples/langchain_example.py --example chain --audio path/to/audio.mp3 --query "转录这段音频" 219 | 220 | # 代理示例 221 | python examples/langchain_example.py --example agent --query "分析这个图像" --image path/to/image.jpg 222 | 223 | # 检索示例 224 | python examples/langchain_example.py --example retrieval --data-dir path/to/data --query "查找相关内容" 225 | 226 | # 运行所有示例 227 | python examples/langchain_example.py --example all --image path/to/image.jpg --data-dir path/to/data 228 | ``` 229 | 230 | ## 高级配置 231 | 232 | ### 自定义提示模板 233 | 234 | 您可以创建自定义提示模板来控制多模态内容的处理方式: 235 | 236 | ```python 237 | from langchain.prompts import PromptTemplate 238 | 239 | template = """系统: 你是一个专业的多模态内容分析专家。 240 | 多模态内容: {multimodal_context} 241 | 用户问题: {query} 242 | 分析结果:""" 243 | 244 | prompt = PromptTemplate( 245 | template=template, 246 | input_variables=["multimodal_context", "query"] 247 | ) 248 | 249 | # 在链或检索链中使用自定义提示 250 | ``` 251 | 252 | ### 集成外部向量存储 253 | 254 | 您可以使用外部向量存储进行多模态内容检索: 255 | 256 | ```python 257 | from langchain.vectorstores import Chroma 258 | from src.langchain_integration import MultiSpaceAIEmbeddings, MultiSpaceAIRetrievalChain 259 | 260 | # 创建嵌入模型 261 | embeddings = MultiSpaceAIEmbeddings(config_path="path/to/config.json") 262 | 263 | # 创建向量存储 264 | vectorstore = Chroma(embedding_function=embeddings, persist_directory="path/to/store") 265 | 266 | # 创建检索链 267 | retrieval_chain = MultiSpaceAIRetrievalChain( 268 | llm=llm, 269 | vectorstore=vectorstore 270 | ) 271 | ``` 272 | 273 | ### 添加自定义回调 274 | 275 | 您可以添加自定义回调来监控处理流程: 276 | 277 | ```python 278 | from langchain.callbacks import StdOutCallbackHandler 279 | from src.langchain_integration import MultiSpaceAIAgent 280 | 281 | # 创建回调处理器 282 | callbacks = [StdOutCallbackHandler()] 283 | 284 | # 初始化代理 285 | agent = MultiSpaceAIAgent( 286 | llm=llm, 287 | callbacks=callbacks, 288 | verbose=True 289 | ) 290 | ``` 291 | 292 | ## 故障排除 293 | 294 | ### 安装问题 295 | 296 | 如果遇到安装依赖的问题,可以尝试使用以下命令: 297 | 298 | ```bash 299 | pip install --upgrade pip 300 | pip install langchain langchain-openai faiss-cpu chromadb openai tiktoken 301 | ``` 302 | 303 | ### API 密钥 304 | 305 | 确保设置了必要的 API 密钥环境变量: 306 | 307 | ```bash 308 | export OPENAI_API_KEY=your_openai_api_key 309 | ``` 310 | 311 | ### 内存问题 312 | 313 | 处理大型多模态数据集可能需要大量内存,可以通过以下方式减少内存使用: 314 | 315 | 1. 减小 `chunk_size` 参数 316 | 2. 使用 `Chroma` 而不是 `FAISS` 作为向量存储 317 | 3. 批量处理文档而不是一次性导入所有文档 318 | 319 | ## 更多资源 320 | 321 | - [LangChain 文档](https://langchain.readthedocs.io/) 322 | - [MultiSpaceAI 文档](../README.md) 323 | - [示例代码](../examples/) -------------------------------------------------------------------------------- /docs/architecture/system_architecture.md: -------------------------------------------------------------------------------- 1 | # MultiSpaceAI 系统架构 2 | 3 | 本文档描述了 MultiSpaceAI 的系统架构,包括总体架构和各个模块的流程图。 4 | 5 | ## 总体架构 6 | 7 | ```mermaid 8 | graph TD 9 | User[用户输入] --> InputProcessor[输入处理器] 10 | InputProcessor --> |文本| TextEncoder[文本编码器] 11 | InputProcessor --> |图像| ImageEncoder[图像编码器] 12 | InputProcessor --> |音频| AudioEncoder[音频编码器] 13 | 14 | TextEncoder --> |文本嵌入| Fusion[多模态融合模块] 15 | ImageEncoder --> |图像嵌入| Fusion 16 | AudioEncoder --> |音频嵌入| Fusion 17 | 18 | Fusion --> |融合表示| Decoder[解码器] 19 | Decoder --> |生成文本| OutputProcessor[输出处理器] 20 | OutputProcessor --> Result[处理结果] 21 | 22 | Config[配置系统] -.-> TextEncoder 23 | Config -.-> ImageEncoder 24 | Config -.-> AudioEncoder 25 | Config -.-> Fusion 26 | Config -.-> Decoder 27 | ``` 28 | 29 | MultiSpaceAI 的总体架构由以下主要组件组成: 30 | 31 | 1. **输入处理器**:接收并预处理用户的多模态输入(文本、图像、音频) 32 | 2. **文本编码器**:将文本输入编码为高维嵌入 33 | 3. **图像编码器**:将图像输入编码为高维嵌入 34 | 4. **音频编码器**:将音频输入编码为高维嵌入 35 | 5. **多模态融合模块**:将不同模态的嵌入融合为统一表示 36 | 6. **解码器**:基于融合表示生成输出文本 37 | 7. **输出处理器**:处理和格式化最终结果 38 | 8. **配置系统**:为各个模块提供配置参数 39 | 40 | ## 输入处理流程 41 | 42 | ```mermaid 43 | graph TD 44 | Input[用户输入] --> InputCheck{检查输入类型} 45 | 46 | InputCheck -->|文本| TextPreprocess[文本预处理] 47 | InputCheck -->|图像| ImagePreprocess[图像预处理] 48 | InputCheck -->|音频| AudioPreprocess[音频预处理] 49 | 50 | TextPreprocess --> |清洗、分词| TextNormalize[文本规范化] 51 | ImagePreprocess --> |缩放、归一化| ImageNormalize[图像规范化] 52 | AudioPreprocess --> |重采样、特征提取| AudioNormalize[音频规范化] 53 | 54 | TextNormalize --> TextReady[准备文本编码] 55 | ImageNormalize --> ImageReady[准备图像编码] 56 | AudioNormalize --> AudioReady[准备音频编码] 57 | ``` 58 | 59 | ## 文本编码器模块 60 | 61 | ```mermaid 62 | graph TD 63 | TextInput[文本输入] --> ModelSelect{编码器选择} 64 | ModelSelect -->|自定义模型| CustomTextEncoder[自定义文本编码器] 65 | ModelSelect -->|API集成| APITextEncoder[API文本编码器] 66 | 67 | CustomTextEncoder --> |加载模型| LoadTextModel[加载文本模型] 68 | APITextEncoder --> |API调用| CallTextAPI[调用文本API] 69 | 70 | LoadTextModel --> TextTokenize[文本分词] 71 | CallTextAPI --> APIProcess[API处理] 72 | 73 | TextTokenize --> TextEmbed[文本编码] 74 | APIProcess --> APIEmbed[API返回编码] 75 | 76 | TextEmbed --> TextOutput[文本嵌入输出] 77 | APIEmbed --> TextOutput 78 | ``` 79 | 80 | ## 图像编码器模块 81 | 82 | ```mermaid 83 | graph TD 84 | ImageInput[图像输入] --> ImgModelSelect{编码器选择} 85 | ImgModelSelect -->|自定义模型| CustomImgEncoder[自定义图像编码器] 86 | ImgModelSelect -->|API集成| APIImgEncoder[API图像编码器] 87 | 88 | CustomImgEncoder --> |加载模型| LoadImgModel[加载图像模型] 89 | APIImgEncoder --> |API调用| CallImgAPI[调用图像API] 90 | 91 | LoadImgModel --> ImgProcess[图像处理] 92 | CallImgAPI --> APIImgProcess[API处理] 93 | 94 | ImgProcess --> ImgEmbed[图像编码] 95 | APIImgProcess --> APIImgEmbed[API返回编码] 96 | 97 | ImgEmbed --> ImgOutput[图像嵌入输出] 98 | APIImgEmbed --> ImgOutput 99 | ``` 100 | 101 | ## 音频编码器模块 102 | 103 | ```mermaid 104 | graph TD 105 | AudioInput[音频输入] --> AudioModelSelect{编码器选择} 106 | AudioModelSelect -->|自定义模型| CustomAudioEncoder[自定义音频编码器] 107 | AudioModelSelect -->|API集成| APIAudioEncoder[API音频编码器] 108 | 109 | CustomAudioEncoder --> |加载模型| LoadAudioModel[加载音频模型] 110 | APIAudioEncoder --> |API调用| CallAudioAPI[调用音频API] 111 | 112 | LoadAudioModel --> AudioFeatures[提取音频特征] 113 | CallAudioAPI --> APIAudioProcess[API处理] 114 | 115 | AudioFeatures --> AudioEmbed[音频编码] 116 | APIAudioProcess --> APIAudioEmbed[API返回编码] 117 | 118 | AudioEmbed --> AudioOutput[音频嵌入输出] 119 | APIAudioEmbed --> AudioOutput 120 | ``` 121 | 122 | ## 多模态融合模块 123 | 124 | ```mermaid 125 | graph TD 126 | TextEmbed[文本嵌入] --> EmbedProject[嵌入投影] 127 | ImageEmbed[图像嵌入] --> EmbedProject 128 | AudioEmbed[音频嵌入] --> EmbedProject 129 | 130 | EmbedProject --> FusionType{融合类型} 131 | 132 | FusionType -->|注意力融合| AttentionFusion[注意力融合] 133 | FusionType -->|拼接融合| ConcatFusion[拼接融合] 134 | FusionType -->|加权融合| WeightedFusion[加权融合] 135 | 136 | AttentionFusion --> CrossAttention[交叉注意力] 137 | ConcatFusion --> ProjectionLayer[投影层] 138 | WeightedFusion --> ModalityWeight[模态权重计算] 139 | 140 | CrossAttention --> FusionOutput[融合输出] 141 | ProjectionLayer --> FusionOutput 142 | ModalityWeight --> FusionOutput 143 | ``` 144 | 145 | ## 解码器模块 146 | 147 | ```mermaid 148 | graph TD 149 | FusionEmbed[融合嵌入] --> DecoderInput[解码器输入] 150 | 151 | DecoderInput --> Decoding{解码策略} 152 | 153 | Decoding -->|贪婪解码| GreedyDecode[贪婪解码] 154 | Decoding -->|束搜索| BeamSearch[束搜索] 155 | Decoding -->|采样| Sampling[采样解码] 156 | 157 | GreedyDecode --> TokenGeneration[生成序列] 158 | BeamSearch --> TokenGeneration 159 | Sampling --> TokenGeneration 160 | 161 | TokenGeneration --> PostProcess[后处理] 162 | PostProcess --> OutputText[输出文本] 163 | ``` 164 | 165 | ## 配置系统 166 | 167 | ```mermaid 168 | graph TD 169 | ConfigInput[配置输入] --> ConfigType{配置类型} 170 | 171 | ConfigType -->|默认配置| DefaultConfig[默认配置] 172 | ConfigType -->|自定义配置| CustomConfig[自定义配置] 173 | ConfigType -->|JSON配置| JSONConfig[JSON配置] 174 | 175 | DefaultConfig --> ConfigValidation[配置验证] 176 | CustomConfig --> ConfigValidation 177 | JSONConfig --> ConfigValidation 178 | 179 | ConfigValidation --> ConfigDistribution[配置分发] 180 | 181 | ConfigDistribution --> TextEncoderConfig[文本编码器配置] 182 | ConfigDistribution --> ImageEncoderConfig[图像编码器配置] 183 | ConfigDistribution --> AudioEncoderConfig[音频编码器配置] 184 | ConfigDistribution --> FusionConfig[融合模块配置] 185 | ConfigDistribution --> DecoderConfig[解码器配置] 186 | ``` 187 | 188 | ## 数据流程 189 | 190 | ```mermaid 191 | sequenceDiagram 192 | participant User as 用户 193 | participant Input as 输入处理器 194 | participant TextEnc as 文本编码器 195 | participant ImageEnc as 图像编码器 196 | participant AudioEnc as 音频编码器 197 | participant Fusion as 融合模块 198 | participant Decoder as 解码器 199 | participant Output as 输出处理器 200 | 201 | User->>Input: 提供多模态输入 202 | 203 | par 并行处理 204 | Input->>TextEnc: 文本输入 205 | Input->>ImageEnc: 图像输入 206 | Input->>AudioEnc: 音频输入 207 | end 208 | 209 | TextEnc-->>Fusion: 文本嵌入 210 | ImageEnc-->>Fusion: 图像嵌入 211 | AudioEnc-->>Fusion: 音频嵌入 212 | 213 | Fusion->>Decoder: 融合表示 214 | Decoder->>Output: 生成文本 215 | Output->>User: 返回结果 216 | ``` 217 | 218 | ## 模型训练流程 219 | 220 | ```mermaid 221 | graph TD 222 | TrainData[训练数据] --> DataProcess[数据预处理] 223 | DataProcess --> DataLoader[数据加载器] 224 | 225 | DataLoader --> TrainLoop[训练循环] 226 | 227 | PretrainedModels[预训练模型] --> ModelInit[模型初始化] 228 | ModelInit --> TrainLoop 229 | 230 | TrainLoop --> Forward[前向传播] 231 | Forward --> Loss[损失计算] 232 | Loss --> Backward[反向传播] 233 | Backward --> Optimize[优化器更新] 234 | Optimize --> TrainLoop 235 | 236 | TrainLoop --> Evaluation[模型评估] 237 | Evaluation --> SaveModel[保存模型] 238 | ``` 239 | 240 | ## API集成架构 241 | 242 | ```mermaid 243 | graph TD 244 | APIRequest[API请求] --> APIAuth[API认证] 245 | APIAuth --> APISelect{API选择} 246 | 247 | APISelect -->|文本API| TextAPI[文本API服务] 248 | APISelect -->|图像API| ImageAPI[图像API服务] 249 | APISelect -->|音频API| AudioAPI[音频API服务] 250 | 251 | TextAPI --> APIProcess[API处理] 252 | ImageAPI --> APIProcess 253 | AudioAPI --> APIProcess 254 | 255 | APIProcess --> CacheCheck{缓存检查} 256 | 257 | CacheCheck -->|缓存命中| CacheResult[使用缓存结果] 258 | CacheCheck -->|缓存未命中| APICall[调用外部API] 259 | 260 | APICall --> RateLimit[速率限制] 261 | RateLimit --> ExternalAPI[外部API服务] 262 | ExternalAPI --> APIResponse[API响应] 263 | 264 | APIResponse --> CacheStore[存储到缓存] 265 | CacheStore --> APIResult[API结果] 266 | CacheResult --> APIResult 267 | ``` -------------------------------------------------------------------------------- /multispace/multispace.py: -------------------------------------------------------------------------------- 1 | """ 2 | MultiSpaceAI 主类模块 3 | ==================== 4 | 5 | 该模块包含 MultiSpaceAI 类,它是整个系统的主要入口点。 6 | 该类协调各个模态的编码器、融合模块和解码器的工作。 7 | """ 8 | 9 | import os 10 | import logging 11 | from typing import Optional, Union, Dict, Any 12 | 13 | import torch 14 | import numpy as np 15 | 16 | from .config import ModelConfig 17 | from .text import TextEncoder, DeepSeekTextEncoder 18 | from .image import ImageEncoder, DiffusionImageEncoder 19 | from .audio import AudioEncoder, WhisperAudioEncoder 20 | from .fusion import MultimodalFusionModule 21 | from .decoder import DecoderModule 22 | from .utils.logger import setup_logger 23 | 24 | # 设置日志 25 | logger = setup_logger(__name__) 26 | 27 | class MultiSpaceAI: 28 | """ 29 | MultiSpaceAI 是一个多模态大语言模型系统,能够处理文本、图像和语音输入, 30 | 实现复杂的跨模态理解与生成任务。 31 | 32 | 该类是整个系统的主要入口点,负责协调各个模块的工作。 33 | 34 | 属性: 35 | config (ModelConfig): 模型配置 36 | text_encoder: 文本编码器模块 37 | image_encoder: 图像编码器模块 38 | audio_encoder: 语音编码器模块 39 | fusion_module: 多模态融合模块 40 | decoder_module: 解码器模块 41 | device (torch.device): 模型运行的设备 42 | """ 43 | 44 | def __init__(self, 45 | config_path: Optional[str] = None, 46 | text_encoder: str = "custom", 47 | image_encoder: str = "custom", 48 | audio_encoder: str = "custom", 49 | device: Optional[str] = None): 50 | """ 51 | 初始化 MultiSpaceAI 实例 52 | 53 | 参数: 54 | config_path: 配置文件路径,如果为None则使用默认配置 55 | text_encoder: 文本编码器类型,可选值为 "custom" 或 "deepseek-api" 56 | image_encoder: 图像编码器类型,可选值为 "custom" 或 "diffusion-api" 57 | audio_encoder: 语音编码器类型,可选值为 "custom" 或 "whisper-api" 58 | device: 模型运行的设备,如果为None则自动选择 59 | """ 60 | # 加载配置 61 | self.config = ModelConfig(config_path) 62 | 63 | # 设置设备 64 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 65 | logger.info(f"使用设备: {self.device}") 66 | 67 | # 初始化编码器 68 | self._init_encoders(text_encoder, image_encoder, audio_encoder) 69 | 70 | # 初始化融合模块 71 | self.fusion_module = MultimodalFusionModule( 72 | config=self.config.fusion_config, 73 | device=self.device 74 | ) 75 | 76 | # 初始化解码器模块 77 | self.decoder_module = DecoderModule( 78 | config=self.config.decoder_config, 79 | device=self.device 80 | ) 81 | 82 | logger.info("MultiSpaceAI 初始化完成") 83 | 84 | def _init_encoders(self, text_encoder_type, image_encoder_type, audio_encoder_type): 85 | """ 86 | 初始化各模态编码器 87 | 88 | 参数: 89 | text_encoder_type: 文本编码器类型 90 | image_encoder_type: 图像编码器类型 91 | audio_encoder_type: 语音编码器类型 92 | """ 93 | # 初始化文本编码器 94 | if text_encoder_type == "deepseek-api": 95 | api_key = os.environ.get("DEEPSEEK_API_KEY") 96 | if not api_key: 97 | logger.warning("未找到 DEEPSEEK_API_KEY 环境变量,请确保已正确设置") 98 | self.text_encoder = DeepSeekTextEncoder(api_key=api_key) 99 | logger.info("使用 DeepSeek API 作为文本编码器") 100 | else: 101 | self.text_encoder = TextEncoder( 102 | config=self.config.text_encoder_config, 103 | device=self.device 104 | ) 105 | logger.info("使用自定义模型作为文本编码器") 106 | 107 | # 初始化图像编码器 108 | if image_encoder_type == "diffusion-api": 109 | api_key = os.environ.get("STABLE_DIFFUSION_API_KEY") 110 | if not api_key: 111 | logger.warning("未找到 STABLE_DIFFUSION_API_KEY 环境变量,请确保已正确设置") 112 | self.image_encoder = DiffusionImageEncoder(api_key=api_key) 113 | logger.info("使用 Diffusion API 作为图像编码器") 114 | else: 115 | self.image_encoder = ImageEncoder( 116 | config=self.config.image_encoder_config, 117 | device=self.device 118 | ) 119 | logger.info("使用自定义模型作为图像编码器") 120 | 121 | # 初始化语音编码器 122 | if audio_encoder_type == "whisper-api": 123 | api_key = os.environ.get("OPENAI_API_KEY") 124 | if not api_key: 125 | logger.warning("未找到 OPENAI_API_KEY 环境变量,请确保已正确设置") 126 | self.audio_encoder = WhisperAudioEncoder(api_key=api_key) 127 | logger.info("使用 Whisper API 作为语音编码器") 128 | else: 129 | self.audio_encoder = AudioEncoder( 130 | config=self.config.audio_encoder_config, 131 | device=self.device 132 | ) 133 | logger.info("使用自定义模型作为语音编码器") 134 | 135 | def process(self, 136 | text: Optional[str] = None, 137 | image: Optional[Union[str, np.ndarray]] = None, 138 | audio: Optional[Union[str, np.ndarray]] = None, 139 | **kwargs) -> Dict[str, Any]: 140 | """ 141 | 处理多模态输入并生成输出 142 | 143 | 参数: 144 | text: 文本输入,可以是字符串 145 | image: 图像输入,可以是图像路径或numpy数组 146 | audio: 语音输入,可以是音频文件路径或numpy数组 147 | **kwargs: 其他参数 148 | 149 | 返回: 150 | 包含处理结果的字典 151 | """ 152 | # 检查是否至少有一种模态的输入 153 | if not any([text, image, audio]): 154 | raise ValueError("至少需要提供一种模态的输入(文本、图像或语音)") 155 | 156 | # 初始化各模态的特征 157 | text_features = None 158 | image_features = None 159 | audio_features = None 160 | 161 | # 处理文本输入 162 | if text: 163 | logger.info("处理文本输入") 164 | text_features = self.text_encoder.encode(text) 165 | 166 | # 处理图像输入 167 | if image: 168 | logger.info("处理图像输入") 169 | image_features = self.image_encoder.encode(image) 170 | 171 | # 处理语音输入 172 | if audio: 173 | logger.info("处理语音输入") 174 | audio_features = self.audio_encoder.encode(audio) 175 | 176 | # 多模态融合 177 | logger.info("执行多模态融合") 178 | fused_features = self.fusion_module.fuse( 179 | text_features=text_features, 180 | image_features=image_features, 181 | audio_features=audio_features 182 | ) 183 | 184 | # 解码生成输出 185 | logger.info("生成输出") 186 | output = self.decoder_module.decode(fused_features, **kwargs) 187 | 188 | return output 189 | 190 | def save(self, path: str): 191 | """ 192 | 保存模型 193 | 194 | 参数: 195 | path: 保存路径 196 | """ 197 | logger.info(f"保存模型到 {path}") 198 | os.makedirs(path, exist_ok=True) 199 | 200 | # 保存配置 201 | self.config.save(os.path.join(path, "config.json")) 202 | 203 | # 保存模型权重 204 | model_state = { 205 | "text_encoder": self.text_encoder.state_dict() if hasattr(self.text_encoder, "state_dict") else None, 206 | "image_encoder": self.image_encoder.state_dict() if hasattr(self.image_encoder, "state_dict") else None, 207 | "audio_encoder": self.audio_encoder.state_dict() if hasattr(self.audio_encoder, "state_dict") else None, 208 | "fusion_module": self.fusion_module.state_dict(), 209 | "decoder_module": self.decoder_module.state_dict() 210 | } 211 | 212 | torch.save(model_state, os.path.join(path, "model.pt")) 213 | 214 | @classmethod 215 | def load(cls, path: str, device: Optional[str] = None): 216 | """ 217 | 从保存的文件加载模型 218 | 219 | 参数: 220 | path: 模型路径 221 | device: 模型运行的设备 222 | 223 | 返回: 224 | MultiSpaceAI 实例 225 | """ 226 | logger.info(f"从 {path} 加载模型") 227 | 228 | # 加载配置 229 | config_path = os.path.join(path, "config.json") 230 | instance = cls(config_path=config_path, device=device) 231 | 232 | # 加载模型权重 233 | model_path = os.path.join(path, "model.pt") 234 | model_state = torch.load(model_path, map_location=instance.device) 235 | 236 | # 加载各模块的权重 237 | if hasattr(instance.text_encoder, "load_state_dict") and model_state["text_encoder"]: 238 | instance.text_encoder.load_state_dict(model_state["text_encoder"]) 239 | 240 | if hasattr(instance.image_encoder, "load_state_dict") and model_state["image_encoder"]: 241 | instance.image_encoder.load_state_dict(model_state["image_encoder"]) 242 | 243 | if hasattr(instance.audio_encoder, "load_state_dict") and model_state["audio_encoder"]: 244 | instance.audio_encoder.load_state_dict(model_state["audio_encoder"]) 245 | 246 | instance.fusion_module.load_state_dict(model_state["fusion_module"]) 247 | instance.decoder_module.load_state_dict(model_state["decoder_module"]) 248 | 249 | return instance -------------------------------------------------------------------------------- /multispace/config/model_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | 模型配置类 3 | ======== 4 | 5 | 该模块包含模型各个组件的配置类。 6 | """ 7 | 8 | import os 9 | import json 10 | from typing import Optional, Dict, Any, List 11 | 12 | 13 | class BaseConfig: 14 | """所有配置类的基类""" 15 | 16 | def __init__(self, **kwargs): 17 | """使用提供的参数初始化配置""" 18 | for key, value in kwargs.items(): 19 | setattr(self, key, value) 20 | 21 | def to_dict(self) -> Dict[str, Any]: 22 | """将配置转换为字典""" 23 | return {k: v for k, v in self.__dict__.items() if not k.startswith('_')} 24 | 25 | @classmethod 26 | def from_dict(cls, config_dict: Dict[str, Any]) -> 'BaseConfig': 27 | """从字典创建配置实例""" 28 | return cls(**config_dict) 29 | 30 | 31 | class TextEncoderConfig(BaseConfig): 32 | """文本编码器配置""" 33 | 34 | def __init__( 35 | self, 36 | model_name: str = "bert-base-chinese", 37 | embedding_dim: int = 768, 38 | hidden_dim: int = 1024, 39 | num_layers: int = 6, 40 | num_attention_heads: int = 8, 41 | max_seq_length: int = 512, 42 | dropout: float = 0.1, 43 | **kwargs 44 | ): 45 | """ 46 | 初始化文本编码器配置 47 | 48 | 参数: 49 | model_name: 预训练模型名称,用于加载Hugging Face模型 50 | embedding_dim: 词嵌入维度 51 | hidden_dim: 隐藏层维度 52 | num_layers: Transformer层数 53 | num_attention_heads: 注意力头数 54 | max_seq_length: 最大序列长度 55 | dropout: Dropout比例 56 | """ 57 | self.model_name = model_name 58 | self.embedding_dim = embedding_dim 59 | self.hidden_dim = hidden_dim 60 | self.num_layers = num_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.max_seq_length = max_seq_length 63 | self.dropout = dropout 64 | super().__init__(**kwargs) 65 | 66 | 67 | class ImageEncoderConfig(BaseConfig): 68 | """图像编码器配置""" 69 | 70 | def __init__( 71 | self, 72 | model_name: str = "vit-base-patch16-224", 73 | image_size: int = 224, 74 | patch_size: int = 16, 75 | embedding_dim: int = 768, 76 | hidden_dim: int = 1024, 77 | num_layers: int = 12, 78 | num_attention_heads: int = 12, 79 | dropout: float = 0.1, 80 | **kwargs 81 | ): 82 | """ 83 | 初始化图像编码器配置 84 | 85 | 参数: 86 | model_name: 预训练模型名称 87 | image_size: 输入图像大小 88 | patch_size: 图像分块大小 89 | embedding_dim: 嵌入维度 90 | hidden_dim: 隐藏层维度 91 | num_layers: Transformer层数 92 | num_attention_heads: 注意力头数 93 | dropout: Dropout比例 94 | """ 95 | self.model_name = model_name 96 | self.image_size = image_size 97 | self.patch_size = patch_size 98 | self.embedding_dim = embedding_dim 99 | self.hidden_dim = hidden_dim 100 | self.num_layers = num_layers 101 | self.num_attention_heads = num_attention_heads 102 | self.dropout = dropout 103 | super().__init__(**kwargs) 104 | 105 | 106 | class AudioEncoderConfig(BaseConfig): 107 | """语音编码器配置""" 108 | 109 | def __init__( 110 | self, 111 | model_name: str = "facebook/wav2vec2-base-960h", 112 | sample_rate: int = 16000, 113 | embedding_dim: int = 768, 114 | hidden_dim: int = 1024, 115 | num_layers: int = 12, 116 | num_attention_heads: int = 12, 117 | max_audio_length: int = 30, # 单位:秒 118 | dropout: float = 0.1, 119 | **kwargs 120 | ): 121 | """ 122 | 初始化语音编码器配置 123 | 124 | 参数: 125 | model_name: 预训练模型名称 126 | sample_rate: 音频采样率 127 | embedding_dim: 嵌入维度 128 | hidden_dim: 隐藏层维度 129 | num_layers: Transformer层数 130 | num_attention_heads: 注意力头数 131 | max_audio_length: 最大音频长度(秒) 132 | dropout: Dropout比例 133 | """ 134 | self.model_name = model_name 135 | self.sample_rate = sample_rate 136 | self.embedding_dim = embedding_dim 137 | self.hidden_dim = hidden_dim 138 | self.num_layers = num_layers 139 | self.num_attention_heads = num_attention_heads 140 | self.max_audio_length = max_audio_length 141 | self.dropout = dropout 142 | super().__init__(**kwargs) 143 | 144 | 145 | class FusionConfig(BaseConfig): 146 | """多模态融合模块配置""" 147 | 148 | def __init__( 149 | self, 150 | hidden_dim: int = 1024, 151 | num_layers: int = 4, 152 | num_attention_heads: int = 16, 153 | fusion_type: str = "cross_attention", # 可选:cross_attention, concat, sum 154 | use_modal_adapters: bool = True, 155 | dropout: float = 0.1, 156 | **kwargs 157 | ): 158 | """ 159 | 初始化多模态融合模块配置 160 | 161 | 参数: 162 | hidden_dim: 隐藏层维度 163 | num_layers: Transformer层数 164 | num_attention_heads: 注意力头数 165 | fusion_type: 融合类型,可选项:cross_attention, concat, sum 166 | use_modal_adapters: 是否使用模态适配器 167 | dropout: Dropout比例 168 | """ 169 | self.hidden_dim = hidden_dim 170 | self.num_layers = num_layers 171 | self.num_attention_heads = num_attention_heads 172 | self.fusion_type = fusion_type 173 | self.use_modal_adapters = use_modal_adapters 174 | self.dropout = dropout 175 | super().__init__(**kwargs) 176 | 177 | 178 | class DecoderConfig(BaseConfig): 179 | """解码器模块配置""" 180 | 181 | def __init__( 182 | self, 183 | model_name: str = "fnlp/bart-base-chinese", 184 | embedding_dim: int = 768, 185 | hidden_dim: int = 1024, 186 | num_layers: int = 6, 187 | num_attention_heads: int = 8, 188 | max_seq_length: int = 512, 189 | dropout: float = 0.1, 190 | **kwargs 191 | ): 192 | """ 193 | 初始化解码器模块配置 194 | 195 | 参数: 196 | model_name: 预训练模型名称 197 | embedding_dim: 嵌入维度 198 | hidden_dim: 隐藏层维度 199 | num_layers: Transformer层数 200 | num_attention_heads: 注意力头数 201 | max_seq_length: 最大序列长度 202 | dropout: Dropout比例 203 | """ 204 | self.model_name = model_name 205 | self.embedding_dim = embedding_dim 206 | self.hidden_dim = hidden_dim 207 | self.num_layers = num_layers 208 | self.num_attention_heads = num_attention_heads 209 | self.max_seq_length = max_seq_length 210 | self.dropout = dropout 211 | super().__init__(**kwargs) 212 | 213 | 214 | class ModelConfig: 215 | """整体模型配置类""" 216 | 217 | def __init__(self, config_path: Optional[str] = None): 218 | """ 219 | 初始化模型配置 220 | 221 | 参数: 222 | config_path: 配置文件路径,如果为None则使用默认配置 223 | """ 224 | # 使用默认配置初始化 225 | self.text_encoder_config = TextEncoderConfig() 226 | self.image_encoder_config = ImageEncoderConfig() 227 | self.audio_encoder_config = AudioEncoderConfig() 228 | self.fusion_config = FusionConfig() 229 | self.decoder_config = DecoderConfig() 230 | 231 | # 如果提供了配置文件路径,则从配置文件加载 232 | if config_path and os.path.exists(config_path): 233 | self._load_from_file(config_path) 234 | 235 | def _load_from_file(self, config_path: str): 236 | """从文件加载配置""" 237 | with open(config_path, 'r', encoding='utf-8') as f: 238 | config_dict = json.load(f) 239 | 240 | # 加载各个模块的配置 241 | if 'text_encoder_config' in config_dict: 242 | self.text_encoder_config = TextEncoderConfig.from_dict(config_dict['text_encoder_config']) 243 | 244 | if 'image_encoder_config' in config_dict: 245 | self.image_encoder_config = ImageEncoderConfig.from_dict(config_dict['image_encoder_config']) 246 | 247 | if 'audio_encoder_config' in config_dict: 248 | self.audio_encoder_config = AudioEncoderConfig.from_dict(config_dict['audio_encoder_config']) 249 | 250 | if 'fusion_config' in config_dict: 251 | self.fusion_config = FusionConfig.from_dict(config_dict['fusion_config']) 252 | 253 | if 'decoder_config' in config_dict: 254 | self.decoder_config = DecoderConfig.from_dict(config_dict['decoder_config']) 255 | 256 | def to_dict(self) -> Dict[str, Dict[str, Any]]: 257 | """将配置转换为字典""" 258 | return { 259 | 'text_encoder_config': self.text_encoder_config.to_dict(), 260 | 'image_encoder_config': self.image_encoder_config.to_dict(), 261 | 'audio_encoder_config': self.audio_encoder_config.to_dict(), 262 | 'fusion_config': self.fusion_config.to_dict(), 263 | 'decoder_config': self.decoder_config.to_dict() 264 | } 265 | 266 | def save(self, config_path: str): 267 | """保存配置到文件""" 268 | with open(config_path, 'w', encoding='utf-8') as f: 269 | json.dump(self.to_dict(), f, ensure_ascii=False, indent=2) -------------------------------------------------------------------------------- /examples/langchain_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | LangChain 集成示例 5 | ================ 6 | 7 | 展示如何使用MultiSpaceAI与LangChain集成,实现更强大的多模态处理能力。 8 | """ 9 | 10 | import os 11 | import sys 12 | import argparse 13 | from typing import Optional, Dict, Any 14 | 15 | # 将项目根目录添加到Python路径 16 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 17 | 18 | # 导入LangChain相关模块 19 | try: 20 | from langchain.llms import OpenAI 21 | from langchain.chat_models import ChatOpenAI 22 | from langchain.prompts import PromptTemplate 23 | from langchain.memory import ConversationBufferMemory 24 | except ImportError: 25 | print("请先安装LangChain相关依赖:") 26 | print("pip install langchain openai") 27 | sys.exit(1) 28 | 29 | # 导入MultiSpaceAI的LangChain集成 30 | from src.langchain_integration import ( 31 | MultiSpaceAITool, 32 | MultiSpaceAIChain, 33 | MultiSpaceAIAgent, 34 | MultiSpaceAIRetrievalChain 35 | ) 36 | 37 | 38 | def setup_llm(model_name: str = "gpt-3.5-turbo"): 39 | """设置LLM模型 40 | 41 | Args: 42 | model_name: 模型名称 43 | 44 | Returns: 45 | LLM模型实例 46 | """ 47 | # 检查是否设置了OpenAI API密钥 48 | if not os.environ.get("OPENAI_API_KEY"): 49 | print("警告: 未设置OPENAI_API_KEY环境变量") 50 | print("您可以通过以下命令设置环境变量:") 51 | print("export OPENAI_API_KEY=your_api_key_here") 52 | print("或者在运行脚本时传入API密钥") 53 | 54 | # 初始化Chat模型 55 | llm = ChatOpenAI( 56 | model_name=model_name, 57 | temperature=0.7, 58 | max_tokens=1000 59 | ) 60 | 61 | return llm 62 | 63 | 64 | def run_multispace_tool_example(args): 65 | """运行MultiSpaceAI工具示例 66 | 67 | Args: 68 | args: 命令行参数 69 | """ 70 | print("\n===== 运行MultiSpaceAI工具示例 =====") 71 | 72 | # 设置LLM 73 | llm = setup_llm(args.model) 74 | 75 | # 初始化MultiSpaceAI工具 76 | multispace_tool = MultiSpaceAITool( 77 | config_path=args.config, 78 | text_encoder=args.text_encoder, 79 | image_encoder=args.image_encoder, 80 | audio_encoder=args.audio_encoder, 81 | device=args.device 82 | ) 83 | 84 | # 运行MultiSpaceAI工具 85 | if args.image: 86 | print(f"\n处理图像: {args.image}") 87 | result = multispace_tool.run( 88 | query=args.query, 89 | image=args.image 90 | ) 91 | elif args.audio: 92 | print(f"\n处理音频: {args.audio}") 93 | result = multispace_tool.run( 94 | query=args.query, 95 | audio=args.audio 96 | ) 97 | else: 98 | print(f"\n处理文本: {args.query}") 99 | result = multispace_tool.run( 100 | query=args.query 101 | ) 102 | 103 | print("\n生成结果:") 104 | print("-" * 50) 105 | print(result) 106 | print("-" * 50) 107 | 108 | 109 | def run_multispace_chain_example(args): 110 | """运行MultiSpaceAI链示例 111 | 112 | Args: 113 | args: 命令行参数 114 | """ 115 | print("\n===== 运行MultiSpaceAI链示例 =====") 116 | 117 | # 设置LLM 118 | llm = setup_llm(args.model) 119 | 120 | # 创建提示模板 121 | template = """你是一个多模态助手,能够处理文本、图像和音频输入。 122 | 123 | {multimodal_context} 124 | 125 | 用户问题: {query} 126 | 127 | 请提供详细且有用的回答:""" 128 | 129 | prompt = PromptTemplate( 130 | template=template, 131 | input_variables=["multimodal_context", "query"] 132 | ) 133 | 134 | # 初始化对话内存 135 | memory = ConversationBufferMemory(memory_key="chat_history") 136 | 137 | # 初始化MultiSpaceAI链 138 | multispace_chain = MultiSpaceAIChain( 139 | llm=llm, 140 | prompt=prompt, 141 | config_path=args.config, 142 | text_encoder=args.text_encoder, 143 | image_encoder=args.image_encoder, 144 | audio_encoder=args.audio_encoder, 145 | device=args.device, 146 | memory=memory 147 | ) 148 | 149 | # 准备输入 150 | inputs = {"query": args.query} 151 | 152 | if args.image: 153 | print(f"\n处理图像: {args.image}") 154 | inputs["image"] = args.image 155 | 156 | if args.audio: 157 | print(f"\n处理音频: {args.audio}") 158 | inputs["audio"] = args.audio 159 | 160 | # 运行MultiSpaceAI链 161 | result = multispace_chain(inputs) 162 | 163 | print("\n生成结果:") 164 | print("-" * 50) 165 | print(result["text"]) 166 | print("-" * 50) 167 | 168 | 169 | def run_multispace_agent_example(args): 170 | """运行MultiSpaceAI代理示例 171 | 172 | Args: 173 | args: 命令行参数 174 | """ 175 | print("\n===== 运行MultiSpaceAI代理示例 =====") 176 | 177 | # 设置LLM 178 | llm = setup_llm(args.model) 179 | 180 | # 初始化对话内存 181 | memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) 182 | 183 | # 初始化MultiSpaceAI代理 184 | multispace_agent = MultiSpaceAIAgent( 185 | llm=llm, 186 | config_path=args.config, 187 | text_encoder=args.text_encoder, 188 | image_encoder=args.image_encoder, 189 | audio_encoder=args.audio_encoder, 190 | device=args.device, 191 | verbose=True, 192 | memory=memory 193 | ) 194 | 195 | # 构建查询 196 | if args.image: 197 | query = f"{args.query} 图像路径: {args.image}" 198 | elif args.audio: 199 | query = f"{args.query} 音频路径: {args.audio}" 200 | else: 201 | query = args.query 202 | 203 | # 运行MultiSpaceAI代理 204 | result = multispace_agent.run(query) 205 | 206 | print("\n代理结果:") 207 | print("-" * 50) 208 | print(result) 209 | print("-" * 50) 210 | 211 | 212 | def run_multispace_retrieval_example(args): 213 | """运行MultiSpaceAI检索链示例 214 | 215 | Args: 216 | args: 命令行参数 217 | """ 218 | print("\n===== 运行MultiSpaceAI检索链示例 =====") 219 | 220 | # 设置LLM 221 | llm = setup_llm(args.model) 222 | 223 | # 初始化MultiSpaceAI检索链 224 | retrieval_chain = MultiSpaceAIRetrievalChain( 225 | llm=llm, 226 | config_path=args.config, 227 | text_encoder=args.text_encoder, 228 | image_encoder=args.image_encoder, 229 | audio_encoder=args.audio_encoder, 230 | device=args.device, 231 | return_source_documents=True 232 | ) 233 | 234 | # 如果指定了数据目录,则导入文档 235 | if args.data_dir: 236 | print(f"\n从目录导入文档: {args.data_dir}") 237 | retrieval_chain.ingest_from_directory( 238 | directory=args.data_dir, 239 | chunk_size=args.chunk_size, 240 | chunk_overlap=args.chunk_overlap 241 | ) 242 | print("文档导入完成") 243 | 244 | # 运行检索链 245 | result = retrieval_chain.run( 246 | query=args.query, 247 | image=args.image, 248 | audio=args.audio 249 | ) 250 | 251 | print("\n检索结果:") 252 | print("-" * 50) 253 | print(result["result"]) 254 | print("-" * 50) 255 | 256 | # 显示源文档 257 | if result.get("source_documents") and len(result["source_documents"]) > 0: 258 | print("\n源文档:") 259 | for i, doc in enumerate(result["source_documents"]): 260 | print(f"\n文档 {i+1}:") 261 | print(f"内容: {doc.page_content[:100]}...") 262 | print(f"来源: {doc.metadata.get('source', 'Unknown')}") 263 | 264 | 265 | def main(): 266 | """主函数""" 267 | parser = argparse.ArgumentParser(description="MultiSpaceAI LangChain集成示例") 268 | 269 | # 基本参数 270 | parser.add_argument("--query", type=str, default="描述这个输入", help="查询文本") 271 | parser.add_argument("--config", type=str, help="配置文件路径") 272 | parser.add_argument("--device", type=str, help="运行设备") 273 | 274 | # 输入参数 275 | parser.add_argument("--image", type=str, help="图像文件路径") 276 | parser.add_argument("--audio", type=str, help="音频文件路径") 277 | 278 | # LLM参数 279 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo", help="LLM模型名称") 280 | parser.add_argument("--api-key", type=str, help="OpenAI API密钥") 281 | 282 | # 编码器参数 283 | parser.add_argument("--text-encoder", type=str, default="custom", choices=["custom", "deepseek-api"], help="文本编码器类型") 284 | parser.add_argument("--image-encoder", type=str, default="custom", choices=["custom", "diffusion-api"], help="图像编码器类型") 285 | parser.add_argument("--audio-encoder", type=str, default="custom", choices=["custom", "whisper-api"], help="音频编码器类型") 286 | 287 | # 检索链参数 288 | parser.add_argument("--data-dir", type=str, help="数据目录路径") 289 | parser.add_argument("--chunk-size", type=int, default=1000, help="文本块大小") 290 | parser.add_argument("--chunk-overlap", type=int, default=0, help="文本块重叠大小") 291 | 292 | # 示例类型 293 | parser.add_argument("--example", type=str, default="tool", 294 | choices=["tool", "chain", "agent", "retrieval", "all"], 295 | help="要运行的示例类型") 296 | 297 | args = parser.parse_args() 298 | 299 | # 如果提供了API密钥,设置环境变量 300 | if args.api_key: 301 | os.environ["OPENAI_API_KEY"] = args.api_key 302 | 303 | # 根据示例类型运行相应示例 304 | if args.example == "tool" or args.example == "all": 305 | run_multispace_tool_example(args) 306 | 307 | if args.example == "chain" or args.example == "all": 308 | run_multispace_chain_example(args) 309 | 310 | if args.example == "agent" or args.example == "all": 311 | run_multispace_agent_example(args) 312 | 313 | if args.example == "retrieval" or args.example == "all": 314 | if not args.data_dir and args.example != "all": 315 | print("警告: 未提供数据目录,检索示例可能无法正常工作") 316 | run_multispace_retrieval_example(args) 317 | 318 | 319 | if __name__ == "__main__": 320 | main() -------------------------------------------------------------------------------- /langchain_integration/multispace_chain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | MultiSpaceAI LangChain 集成 5 | ========================= 6 | 7 | 将MultiSpaceAI与LangChain集成,实现更强大的多模态处理能力。 8 | """ 9 | 10 | import os 11 | import json 12 | from typing import Any, Dict, List, Optional, Union 13 | 14 | from langchain.chains import LLMChain 15 | from langchain.prompts import PromptTemplate 16 | from langchain.memory import ConversationBufferMemory 17 | from langchain.callbacks.base import BaseCallbackHandler 18 | from langchain.schema import AIMessage, HumanMessage, SystemMessage 19 | from langchain.tools import BaseTool 20 | from langchain.agents import AgentExecutor, initialize_agent, AgentType 21 | from langchain.callbacks.manager import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun 22 | 23 | # 引入项目根目录 24 | import sys 25 | import os 26 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 27 | 28 | # 导入MultiSpaceAI 29 | from src.multispace import MultiSpaceAI 30 | 31 | 32 | class MultiSpaceAITool(BaseTool): 33 | """将MultiSpaceAI封装为LangChain工具""" 34 | 35 | name = "multispace_ai" 36 | description = "通过MultiSpaceAI处理多模态输入,支持文本、图像和音频输入" 37 | 38 | multispace_model: MultiSpaceAI 39 | return_direct: bool = False 40 | 41 | def __init__( 42 | self, 43 | config_path: Optional[str] = None, 44 | text_encoder: str = "custom", 45 | image_encoder: str = "custom", 46 | audio_encoder: str = "custom", 47 | device: Optional[str] = None, 48 | return_direct: bool = False 49 | ): 50 | """初始化MultiSpaceAI工具 51 | 52 | Args: 53 | config_path: 配置文件路径 54 | text_encoder: 文本编码器类型 55 | image_encoder: 图像编码器类型 56 | audio_encoder: 音频编码器类型 57 | device: 运行设备 58 | return_direct: 是否直接返回结果 59 | """ 60 | super().__init__() 61 | 62 | # 初始化MultiSpaceAI模型 63 | self.multispace_model = MultiSpaceAI( 64 | config_path=config_path, 65 | text_encoder=text_encoder, 66 | image_encoder=image_encoder, 67 | audio_encoder=audio_encoder, 68 | device=device 69 | ) 70 | 71 | self.return_direct = return_direct 72 | 73 | def _run( 74 | self, 75 | query: str, 76 | text: Optional[str] = None, 77 | image: Optional[str] = None, 78 | audio: Optional[str] = None, 79 | max_length: int = 50, 80 | num_beams: int = 4, 81 | run_manager: Optional[CallbackManagerForToolRun] = None, 82 | ) -> str: 83 | """运行MultiSpaceAI处理 84 | 85 | Args: 86 | query: 查询 87 | text: 文本输入 88 | image: 图像文件路径 89 | audio: 音频文件路径 90 | max_length: 生成的最大长度 91 | num_beams: 束搜索的束数 92 | run_manager: 回调管理器 93 | 94 | Returns: 95 | 处理结果 96 | """ 97 | # 处理输入 98 | result = self.multispace_model.process( 99 | text=text or query, 100 | image=image, 101 | audio=audio, 102 | max_length=max_length, 103 | num_beams=num_beams 104 | ) 105 | 106 | # 返回生成文本 107 | return result["generated_text"] 108 | 109 | async def _arun( 110 | self, 111 | query: str, 112 | text: Optional[str] = None, 113 | image: Optional[str] = None, 114 | audio: Optional[str] = None, 115 | max_length: int = 50, 116 | num_beams: int = 4, 117 | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, 118 | ) -> str: 119 | """异步运行MultiSpaceAI处理 120 | 121 | Args: 122 | query: 查询 123 | text: 文本输入 124 | image: 图像文件路径 125 | audio: 音频文件路径 126 | max_length: 生成的最大长度 127 | num_beams: 束搜索的束数 128 | run_manager: 回调管理器 129 | 130 | Returns: 131 | 处理结果 132 | """ 133 | # 处理输入 134 | result = self.multispace_model.process( 135 | text=text or query, 136 | image=image, 137 | audio=audio, 138 | max_length=max_length, 139 | num_beams=num_beams 140 | ) 141 | 142 | # 返回生成文本 143 | return result["generated_text"] 144 | 145 | 146 | class MultiSpaceAIChain(LLMChain): 147 | """MultiSpaceAI LangChain 链""" 148 | 149 | multispace_model: MultiSpaceAI 150 | 151 | def __init__( 152 | self, 153 | llm: Any, 154 | prompt: PromptTemplate, 155 | config_path: Optional[str] = None, 156 | text_encoder: str = "custom", 157 | image_encoder: str = "custom", 158 | audio_encoder: str = "custom", 159 | device: Optional[str] = None, 160 | memory: Optional[ConversationBufferMemory] = None, 161 | callbacks: Optional[List[BaseCallbackHandler]] = None, 162 | ): 163 | """初始化MultiSpaceAI链 164 | 165 | Args: 166 | llm: 大语言模型 167 | prompt: 提示模板 168 | config_path: 配置文件路径 169 | text_encoder: 文本编码器类型 170 | image_encoder: 图像编码器类型 171 | audio_encoder: 音频编码器类型 172 | device: 运行设备 173 | memory: 对话内存 174 | callbacks: 回调处理器列表 175 | """ 176 | super().__init__( 177 | llm=llm, 178 | prompt=prompt, 179 | memory=memory, 180 | callbacks=callbacks 181 | ) 182 | 183 | # 初始化MultiSpaceAI模型 184 | self.multispace_model = MultiSpaceAI( 185 | config_path=config_path, 186 | text_encoder=text_encoder, 187 | image_encoder=image_encoder, 188 | audio_encoder=audio_encoder, 189 | device=device 190 | ) 191 | 192 | def process_multimodal( 193 | self, 194 | text: Optional[str] = None, 195 | image: Optional[str] = None, 196 | audio: Optional[str] = None, 197 | max_length: int = 50, 198 | num_beams: int = 4, 199 | ) -> Dict[str, Any]: 200 | """处理多模态输入 201 | 202 | Args: 203 | text: 文本输入 204 | image: 图像文件路径 205 | audio: 音频文件路径 206 | max_length: 生成的最大长度 207 | num_beams: 束搜索的束数 208 | 209 | Returns: 210 | 处理结果 211 | """ 212 | # 处理输入 213 | result = self.multispace_model.process( 214 | text=text, 215 | image=image, 216 | audio=audio, 217 | max_length=max_length, 218 | num_beams=num_beams 219 | ) 220 | 221 | return result 222 | 223 | def __call__( 224 | self, 225 | inputs: Dict[str, Any], 226 | return_only_outputs: bool = False, 227 | **kwargs 228 | ) -> Dict[str, Any]: 229 | """运行链 230 | 231 | Args: 232 | inputs: 输入参数 233 | return_only_outputs: 是否只返回输出 234 | **kwargs: 额外参数 235 | 236 | Returns: 237 | 链运行结果 238 | """ 239 | # 处理多模态输入 240 | if any(k in inputs for k in ["image", "audio"]): 241 | multimodal_result = self.process_multimodal( 242 | text=inputs.get("text", inputs.get("query", None)), 243 | image=inputs.get("image"), 244 | audio=inputs.get("audio"), 245 | max_length=inputs.get("max_length", 50), 246 | num_beams=inputs.get("num_beams", 4), 247 | ) 248 | 249 | # 将多模态处理结果添加到输入 250 | inputs["multimodal_context"] = multimodal_result["generated_text"] 251 | 252 | # 调用父类的__call__方法 253 | return super().__call__(inputs, return_only_outputs, **kwargs) 254 | 255 | 256 | class MultiSpaceAIAgent: 257 | """MultiSpaceAI LangChain 代理""" 258 | 259 | def __init__( 260 | self, 261 | llm: Any, 262 | config_path: Optional[str] = None, 263 | text_encoder: str = "custom", 264 | image_encoder: str = "custom", 265 | audio_encoder: str = "custom", 266 | device: Optional[str] = None, 267 | verbose: bool = False, 268 | memory: Optional[ConversationBufferMemory] = None, 269 | callbacks: Optional[List[BaseCallbackHandler]] = None, 270 | ): 271 | """初始化MultiSpaceAI代理 272 | 273 | Args: 274 | llm: 大语言模型 275 | config_path: 配置文件路径 276 | text_encoder: 文本编码器类型 277 | image_encoder: 图像编码器类型 278 | audio_encoder: 音频编码器类型 279 | device: 运行设备 280 | verbose: 是否输出详细信息 281 | memory: 对话内存 282 | callbacks: 回调处理器列表 283 | """ 284 | # 初始化MultiSpaceAI工具 285 | self.multispace_tool = MultiSpaceAITool( 286 | config_path=config_path, 287 | text_encoder=text_encoder, 288 | image_encoder=image_encoder, 289 | audio_encoder=audio_encoder, 290 | device=device 291 | ) 292 | 293 | # 工具列表 294 | tools = [self.multispace_tool] 295 | 296 | # 初始化代理 297 | self.agent = initialize_agent( 298 | agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, 299 | tools=tools, 300 | llm=llm, 301 | verbose=verbose, 302 | memory=memory, 303 | handle_parsing_errors=True, 304 | max_iterations=5, 305 | early_stopping_method="generate", 306 | callbacks=callbacks 307 | ) 308 | 309 | def run(self, query: str, **kwargs) -> str: 310 | """运行代理 311 | 312 | Args: 313 | query: 查询 314 | **kwargs: 额外参数 315 | 316 | Returns: 317 | 代理运行结果 318 | """ 319 | return self.agent.run(query, **kwargs) 320 | 321 | async def arun(self, query: str, **kwargs) -> str: 322 | """异步运行代理 323 | 324 | Args: 325 | query: 查询 326 | **kwargs: 额外参数 327 | 328 | Returns: 329 | 代理运行结果 330 | """ 331 | return await self.agent.arun(query, **kwargs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiSpaceAI: 多模态大语言模型 2 | 3 | MultiSpaceAI 是一个高级多模态大语言模型系统,能够处理文本、图像和语音输入,实现复杂的跨模态理解与生成任务。 4 | 5 | ![项目徽标](docs/assets/multispace_logo.png) 6 | 7 | ## 项目介绍 8 | 9 | MultiSpaceAI 旨在建立一个全面的多模态智能系统,集成文本、图像和语音三种核心模态的理解与生成能力,实现复杂场景下的智能交互。系统采用先进的深度学习架构,包括Transformer、Vision Transformer和Conformer,实现高效的多模态融合与处理。系统支持与DeepSeek-API、Stable Diffusion、Whisper等先进AI模型的无缝集成,可混合使用自定义模型和API服务的能力。 10 | 11 | ## 核心特性 12 | 13 | - **多模态融合处理**:同时处理文本、图像和语音输入 14 | - **跨模态理解**:通过先进的多头交叉注意力实现模态间的深度融合 15 | - **高质量生成**:基于融合表示生成高质量的多模态输出 16 | - **高效推理**:优化的模型架构和推理策略,支持实时应用 17 | - **扩展性强**:模块化设计,支持新模态和新功能的灵活扩展 18 | - **API集成**:支持与DeepSeek、Stable Diffusion、Whisper等先进AI模型API的集成 19 | 20 | ## 系统架构 21 | 22 | 系统主要由以下模块组成: 23 | 24 | 1. **文本编码器**:处理和编码文本输入(可选择自定义模型或DeepSeek-API) 25 | 2. **图像编码器**:处理和编码图像输入(可选择自定义模型或Diffusion模型API) 26 | 3. **语音编码器**:处理和编码语音输入(可选择自定义模型或Whisper API) 27 | 4. **多模态融合模块**:整合不同模态的特征表示 28 | 5. **解码器模块**:基于融合的特征生成输出 29 | 6. **输出层**:根据任务需求生成最终结果 30 | 31 | 详细的系统设计请参考[设计文档](multimodal_llm_design_doc.md)。 32 | 33 | # MultiSpaceAI 系统架构 34 | 35 | 本文档描述了 MultiSpaceAI 的系统架构,包括总体架构和各个模块的流程图。 36 | 37 | ## 总体架构 38 | 39 | ```mermaid 40 | graph TD 41 | User[用户输入] --> InputProcessor[输入处理器] 42 | InputProcessor --> |文本| TextEncoder[文本编码器] 43 | InputProcessor --> |图像| ImageEncoder[图像编码器] 44 | InputProcessor --> |音频| AudioEncoder[音频编码器] 45 | 46 | TextEncoder --> |文本嵌入| Fusion[多模态融合模块] 47 | ImageEncoder --> |图像嵌入| Fusion 48 | AudioEncoder --> |音频嵌入| Fusion 49 | 50 | Fusion --> |融合表示| Decoder[解码器] 51 | Decoder --> |生成文本| OutputProcessor[输出处理器] 52 | OutputProcessor --> Result[处理结果] 53 | 54 | Config[配置系统] -.-> TextEncoder 55 | Config -.-> ImageEncoder 56 | Config -.-> AudioEncoder 57 | Config -.-> Fusion 58 | Config -.-> Decoder 59 | ``` 60 | 61 | MultiSpaceAI 的总体架构由以下主要组件组成: 62 | 63 | 1. **输入处理器**:接收并预处理用户的多模态输入(文本、图像、音频) 64 | 2. **文本编码器**:将文本输入编码为高维嵌入 65 | 3. **图像编码器**:将图像输入编码为高维嵌入 66 | 4. **音频编码器**:将音频输入编码为高维嵌入 67 | 5. **多模态融合模块**:将不同模态的嵌入融合为统一表示 68 | 6. **解码器**:基于融合表示生成输出文本 69 | 7. **输出处理器**:处理和格式化最终结果 70 | 8. **配置系统**:为各个模块提供配置参数 71 | 72 | ## 输入处理流程 73 | 74 | ```mermaid 75 | graph TD 76 | Input[用户输入] --> InputCheck{检查输入类型} 77 | 78 | InputCheck -->|文本| TextPreprocess[文本预处理] 79 | InputCheck -->|图像| ImagePreprocess[图像预处理] 80 | InputCheck -->|音频| AudioPreprocess[音频预处理] 81 | 82 | TextPreprocess --> |清洗、分词| TextNormalize[文本规范化] 83 | ImagePreprocess --> |缩放、归一化| ImageNormalize[图像规范化] 84 | AudioPreprocess --> |重采样、特征提取| AudioNormalize[音频规范化] 85 | 86 | TextNormalize --> TextReady[准备文本编码] 87 | ImageNormalize --> ImageReady[准备图像编码] 88 | AudioNormalize --> AudioReady[准备音频编码] 89 | ``` 90 | 91 | ## 文本编码器模块 92 | 93 | ```mermaid 94 | graph TD 95 | TextInput[文本输入] --> ModelSelect{编码器选择} 96 | ModelSelect -->|自定义模型| CustomTextEncoder[自定义文本编码器] 97 | ModelSelect -->|API集成| APITextEncoder[API文本编码器] 98 | 99 | CustomTextEncoder --> |加载模型| LoadTextModel[加载文本模型] 100 | APITextEncoder --> |API调用| CallTextAPI[调用文本API] 101 | 102 | LoadTextModel --> TextTokenize[文本分词] 103 | CallTextAPI --> APIProcess[API处理] 104 | 105 | TextTokenize --> TextEmbed[文本编码] 106 | APIProcess --> APIEmbed[API返回编码] 107 | 108 | TextEmbed --> TextOutput[文本嵌入输出] 109 | APIEmbed --> TextOutput 110 | ``` 111 | 112 | ## 图像编码器模块 113 | 114 | ```mermaid 115 | graph TD 116 | ImageInput[图像输入] --> ImgModelSelect{编码器选择} 117 | ImgModelSelect -->|自定义模型| CustomImgEncoder[自定义图像编码器] 118 | ImgModelSelect -->|API集成| APIImgEncoder[API图像编码器] 119 | 120 | CustomImgEncoder --> |加载模型| LoadImgModel[加载图像模型] 121 | APIImgEncoder --> |API调用| CallImgAPI[调用图像API] 122 | 123 | LoadImgModel --> ImgProcess[图像处理] 124 | CallImgAPI --> APIImgProcess[API处理] 125 | 126 | ImgProcess --> ImgEmbed[图像编码] 127 | APIImgProcess --> APIImgEmbed[API返回编码] 128 | 129 | ImgEmbed --> ImgOutput[图像嵌入输出] 130 | APIImgEmbed --> ImgOutput 131 | ``` 132 | 133 | ## 音频编码器模块 134 | 135 | ```mermaid 136 | graph TD 137 | AudioInput[音频输入] --> AudioModelSelect{编码器选择} 138 | AudioModelSelect -->|自定义模型| CustomAudioEncoder[自定义音频编码器] 139 | AudioModelSelect -->|API集成| APIAudioEncoder[API音频编码器] 140 | 141 | CustomAudioEncoder --> |加载模型| LoadAudioModel[加载音频模型] 142 | APIAudioEncoder --> |API调用| CallAudioAPI[调用音频API] 143 | 144 | LoadAudioModel --> AudioFeatures[提取音频特征] 145 | CallAudioAPI --> APIAudioProcess[API处理] 146 | 147 | AudioFeatures --> AudioEmbed[音频编码] 148 | APIAudioProcess --> APIAudioEmbed[API返回编码] 149 | 150 | AudioEmbed --> AudioOutput[音频嵌入输出] 151 | APIAudioEmbed --> AudioOutput 152 | ``` 153 | 154 | ## 多模态融合模块 155 | 156 | ```mermaid 157 | graph TD 158 | TextEmbed[文本嵌入] --> EmbedProject[嵌入投影] 159 | ImageEmbed[图像嵌入] --> EmbedProject 160 | AudioEmbed[音频嵌入] --> EmbedProject 161 | 162 | EmbedProject --> FusionType{融合类型} 163 | 164 | FusionType -->|注意力融合| AttentionFusion[注意力融合] 165 | FusionType -->|拼接融合| ConcatFusion[拼接融合] 166 | FusionType -->|加权融合| WeightedFusion[加权融合] 167 | 168 | AttentionFusion --> CrossAttention[交叉注意力] 169 | ConcatFusion --> ProjectionLayer[投影层] 170 | WeightedFusion --> ModalityWeight[模态权重计算] 171 | 172 | CrossAttention --> FusionOutput[融合输出] 173 | ProjectionLayer --> FusionOutput 174 | ModalityWeight --> FusionOutput 175 | ``` 176 | 177 | ## 解码器模块 178 | 179 | ```mermaid 180 | graph TD 181 | FusionEmbed[融合嵌入] --> DecoderInput[解码器输入] 182 | 183 | DecoderInput --> Decoding{解码策略} 184 | 185 | Decoding -->|贪婪解码| GreedyDecode[贪婪解码] 186 | Decoding -->|束搜索| BeamSearch[束搜索] 187 | Decoding -->|采样| Sampling[采样解码] 188 | 189 | GreedyDecode --> TokenGeneration[生成序列] 190 | BeamSearch --> TokenGeneration 191 | Sampling --> TokenGeneration 192 | 193 | TokenGeneration --> PostProcess[后处理] 194 | PostProcess --> OutputText[输出文本] 195 | ``` 196 | 197 | ## 配置系统 198 | 199 | ```mermaid 200 | graph TD 201 | ConfigInput[配置输入] --> ConfigType{配置类型} 202 | 203 | ConfigType -->|默认配置| DefaultConfig[默认配置] 204 | ConfigType -->|自定义配置| CustomConfig[自定义配置] 205 | ConfigType -->|JSON配置| JSONConfig[JSON配置] 206 | 207 | DefaultConfig --> ConfigValidation[配置验证] 208 | CustomConfig --> ConfigValidation 209 | JSONConfig --> ConfigValidation 210 | 211 | ConfigValidation --> ConfigDistribution[配置分发] 212 | 213 | ConfigDistribution --> TextEncoderConfig[文本编码器配置] 214 | ConfigDistribution --> ImageEncoderConfig[图像编码器配置] 215 | ConfigDistribution --> AudioEncoderConfig[音频编码器配置] 216 | ConfigDistribution --> FusionConfig[融合模块配置] 217 | ConfigDistribution --> DecoderConfig[解码器配置] 218 | ``` 219 | 220 | ## 数据流程 221 | 222 | ```mermaid 223 | sequenceDiagram 224 | participant User as 用户 225 | participant Input as 输入处理器 226 | participant TextEnc as 文本编码器 227 | participant ImageEnc as 图像编码器 228 | participant AudioEnc as 音频编码器 229 | participant Fusion as 融合模块 230 | participant Decoder as 解码器 231 | participant Output as 输出处理器 232 | 233 | User->>Input: 提供多模态输入 234 | 235 | par 并行处理 236 | Input->>TextEnc: 文本输入 237 | Input->>ImageEnc: 图像输入 238 | Input->>AudioEnc: 音频输入 239 | end 240 | 241 | TextEnc-->>Fusion: 文本嵌入 242 | ImageEnc-->>Fusion: 图像嵌入 243 | AudioEnc-->>Fusion: 音频嵌入 244 | 245 | Fusion->>Decoder: 融合表示 246 | Decoder->>Output: 生成文本 247 | Output->>User: 返回结果 248 | ``` 249 | 250 | ## 模型训练流程 251 | 252 | ```mermaid 253 | graph TD 254 | TrainData[训练数据] --> DataProcess[数据预处理] 255 | DataProcess --> DataLoader[数据加载器] 256 | 257 | DataLoader --> TrainLoop[训练循环] 258 | 259 | PretrainedModels[预训练模型] --> ModelInit[模型初始化] 260 | ModelInit --> TrainLoop 261 | 262 | TrainLoop --> Forward[前向传播] 263 | Forward --> Loss[损失计算] 264 | Loss --> Backward[反向传播] 265 | Backward --> Optimize[优化器更新] 266 | Optimize --> TrainLoop 267 | 268 | TrainLoop --> Evaluation[模型评估] 269 | Evaluation --> SaveModel[保存模型] 270 | ``` 271 | 272 | ## API集成架构 273 | 274 | ```mermaid 275 | graph TD 276 | APIRequest[API请求] --> APIAuth[API认证] 277 | APIAuth --> APISelect{API选择} 278 | 279 | APISelect -->|文本API| TextAPI[文本API服务] 280 | APISelect -->|图像API| ImageAPI[图像API服务] 281 | APISelect -->|音频API| AudioAPI[音频API服务] 282 | 283 | TextAPI --> APIProcess[API处理] 284 | ImageAPI --> APIProcess 285 | AudioAPI --> APIProcess 286 | 287 | APIProcess --> CacheCheck{缓存检查} 288 | 289 | CacheCheck -->|缓存命中| CacheResult[使用缓存结果] 290 | CacheCheck -->|缓存未命中| APICall[调用外部API] 291 | 292 | APICall --> RateLimit[速率限制] 293 | RateLimit --> ExternalAPI[外部API服务] 294 | ExternalAPI --> APIResponse[API响应] 295 | 296 | APIResponse --> CacheStore[存储到缓存] 297 | CacheStore --> APIResult[API结果] 298 | CacheResult --> APIResult 299 | ``` 300 | 301 | ## 应用场景 302 | 303 | - **智能助手**:多模态交互式AI助手 304 | - **内容创作**:基于多模态输入的创意内容生成 305 | - **智能教育**:个性化的多模态学习体验 306 | - **医疗诊断**:结合图像和语音的智能诊断辅助 307 | - **自动驾驶**:环境理解与人机交互 308 | - **内容检索**:跨模态的高效信息检索 309 | 310 | ## 项目状态 311 | 312 | 当前版本: v0.1.0 (开发中) 313 | 314 | - [x] 系统总体设计 315 | - [x] 核心模块定义 316 | - [x] 文本编码器实现(含DeepSeek-API集成) 317 | - [x] 图像编码器实现(含Diffusion API集成) 318 | - [x] 语音编码器实现(含Whisper API集成) 319 | - [x] 多模态融合模块实现 320 | - [x] 解码器模块实现 321 | - [x] 系统整合测试 322 | - [x] 性能优化 323 | - [x] API服务部署 324 | 325 | ## 快速开始 326 | 327 | ### 环境要求 328 | 329 | - Python 3.8+ 330 | - CUDA 11.3+ 331 | - 32GB+ RAM 332 | - NVIDIA GPU (推荐A100或同等性能) 333 | - 相关API密钥(如需使用API集成) 334 | 335 | ### 安装 336 | 337 | ```bash 338 | git clone https://github.com/yourusername/MultiSpaceAI.git 339 | cd MultiSpaceAI 340 | pip install -r requirements.txt 341 | ``` 342 | 343 | ### 配置API 344 | 345 | 如果您计划使用外部API作为处理模块,需要进行以下配置: 346 | 347 | ```bash 348 | # 设置环境变量 349 | export DEEPSEEK_API_KEY="your_api_key_here" 350 | export STABLE_DIFFUSION_API_KEY="your_api_key_here" 351 | export OPENAI_API_KEY="your_api_key_here" # 用于Whisper API 352 | 353 | # 或在配置文件中设置 354 | echo "DEEPSEEK_API_KEY=your_api_key_here" >> .env 355 | echo "STABLE_DIFFUSION_API_KEY=your_api_key_here" >> .env 356 | echo "OPENAI_API_KEY=your_api_key_here" >> .env 357 | ``` 358 | 359 | ### 使用示例 360 | 361 | ```python 362 | from multispace import MultiSpaceAI 363 | 364 | # 初始化模型(默认使用自定义编码器) 365 | model = MultiSpaceAI() 366 | 367 | # 使用外部API作为编码器 368 | model_with_api = MultiSpaceAI( 369 | text_encoder="deepseek-api", 370 | image_encoder="diffusion-api", 371 | audio_encoder="whisper-api" 372 | ) 373 | 374 | # 文本-图像多模态处理 375 | response = model.process( 376 | text="这张图片中有什么?", 377 | image="path/to/image.jpg" 378 | ) 379 | 380 | # 文本-语音多模态处理 381 | response = model.process( 382 | text="请分析这段语音内容", 383 | audio="path/to/audio.wav" 384 | ) 385 | 386 | # 三模态融合处理 387 | response = model.process( 388 | text="请描述这张图片中的人在说什么", 389 | image="path/to/image.jpg", 390 | audio="path/to/audio.wav" 391 | ) 392 | 393 | print(response) 394 | ``` 395 | 396 | ## 贡献指南 397 | 398 | 我们欢迎各种形式的贡献,包括但不限于: 399 | 400 | - 代码优化与bug修复 401 | - 新功能开发 402 | - 文档完善 403 | - 使用案例分享 404 | 405 | 请参考[贡献指南](CONTRIBUTING.md)了解详细信息。 406 | 407 | ## 相关资源 408 | 409 | - [设计文档](multimodal_llm_design_doc.md) 410 | - [API文档](docs/api.md) 411 | - [DeepSeek官方文档](https://www.deepseek.com/docs) 412 | - [Stable Diffusion文档](https://stability.ai/docs) 413 | - [OpenAI Whisper文档](https://platform.openai.com/docs/api-reference/audio) 414 | 415 | ## 许可证 416 | 417 | 本项目采用 [MIT 许可证](LICENSE)。 418 | 419 | ## 联系我们 420 | 421 | - 项目主页: https://github.com/li-neo/MultiSpaceAI 422 | - 问题反馈: https://github.com/li-neo/MultiSpaceAI/issues 423 | - 邮箱: liguangxian1995@gmail.com -------------------------------------------------------------------------------- /langchain_integration/retrieval_chain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | MultiSpaceAI LangChain 检索链 5 | =========================== 6 | 7 | 将MultiSpaceAI与LangChain的向量检索功能集成,实现多模态内容的检索和问答。 8 | """ 9 | 10 | import os 11 | import json 12 | from typing import Any, Dict, List, Optional, Union 13 | 14 | from langchain.chains import RetrievalQA 15 | from langchain.prompts import PromptTemplate 16 | from langchain.memory import ConversationBufferMemory 17 | from langchain.callbacks.base import BaseCallbackHandler 18 | from langchain.vectorstores import Chroma, FAISS 19 | from langchain.vectorstores.base import VectorStore 20 | from langchain.embeddings.base import Embeddings 21 | from langchain.docstore.document import Document 22 | from langchain.text_splitter import RecursiveCharacterTextSplitter 23 | from langchain.retrievers import ContextualCompressionRetriever 24 | from langchain.retrievers.document_compressors import LLMChainExtractor 25 | 26 | # 引入项目根目录 27 | import sys 28 | import os 29 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) 30 | 31 | # 导入MultiSpaceAI 32 | from src.multispace import MultiSpaceAI 33 | from src.langchain_integration.multispace_chain import MultiSpaceAIChain 34 | 35 | 36 | class MultiSpaceAIEmbeddings(Embeddings): 37 | """MultiSpaceAI嵌入类,用于生成文本、图像和音频的嵌入向量""" 38 | 39 | def __init__( 40 | self, 41 | config_path: Optional[str] = None, 42 | text_encoder: str = "custom", 43 | image_encoder: str = "custom", 44 | audio_encoder: str = "custom", 45 | device: Optional[str] = None, 46 | embedding_dim: int = 768, 47 | ): 48 | """初始化MultiSpaceAI嵌入类 49 | 50 | Args: 51 | config_path: 配置文件路径 52 | text_encoder: 文本编码器类型 53 | image_encoder: 图像编码器类型 54 | audio_encoder: 音频编码器类型 55 | device: 运行设备 56 | embedding_dim: 嵌入维度 57 | """ 58 | # 初始化MultiSpaceAI模型 59 | self.multispace_model = MultiSpaceAI( 60 | config_path=config_path, 61 | text_encoder=text_encoder, 62 | image_encoder=image_encoder, 63 | audio_encoder=audio_encoder, 64 | device=device 65 | ) 66 | 67 | self.embedding_dim = embedding_dim 68 | 69 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 70 | """生成文档嵌入向量 71 | 72 | Args: 73 | texts: 文档文本列表 74 | 75 | Returns: 76 | 嵌入向量列表 77 | """ 78 | return [self.embed_query(text) for text in texts] 79 | 80 | def embed_query(self, text: str) -> List[float]: 81 | """生成查询嵌入向量 82 | 83 | Args: 84 | text: 查询文本 85 | 86 | Returns: 87 | 嵌入向量 88 | """ 89 | # 处理输入 90 | result = self.multispace_model.process( 91 | text=text, 92 | return_embeddings=True # 假设MultiSpaceAI支持返回嵌入 93 | ) 94 | 95 | # 返回嵌入向量 96 | embeddings = result.get("text_embeddings", None) 97 | 98 | # 如果返回的是张量,转换为列表 99 | if hasattr(embeddings, "tolist"): 100 | embeddings = embeddings.tolist() 101 | 102 | return embeddings 103 | 104 | 105 | class MultiModalDocument(Document): 106 | """多模态文档类,支持文本、图像和音频内容""" 107 | 108 | def __init__( 109 | self, 110 | page_content: str, 111 | metadata: Dict[str, Any] = None, 112 | image_path: Optional[str] = None, 113 | audio_path: Optional[str] = None, 114 | ): 115 | """初始化多模态文档 116 | 117 | Args: 118 | page_content: 文档文本内容 119 | metadata: 文档元数据 120 | image_path: 图像文件路径 121 | audio_path: 音频文件路径 122 | """ 123 | # 更新元数据 124 | metadata = metadata or {} 125 | if image_path: 126 | metadata["image_path"] = image_path 127 | if audio_path: 128 | metadata["audio_path"] = audio_path 129 | 130 | super().__init__(page_content=page_content, metadata=metadata) 131 | 132 | 133 | class MultiSpaceAIRetrievalChain: 134 | """MultiSpaceAI检索链,用于多模态内容的检索和问答""" 135 | 136 | def __init__( 137 | self, 138 | llm: Any, 139 | config_path: Optional[str] = None, 140 | text_encoder: str = "custom", 141 | image_encoder: str = "custom", 142 | audio_encoder: str = "custom", 143 | device: Optional[str] = None, 144 | vectorstore: Optional[VectorStore] = None, 145 | embeddings: Optional[Embeddings] = None, 146 | memory: Optional[ConversationBufferMemory] = None, 147 | chain_type: str = "stuff", 148 | callbacks: Optional[List[BaseCallbackHandler]] = None, 149 | return_source_documents: bool = False, 150 | ): 151 | """初始化MultiSpaceAI检索链 152 | 153 | Args: 154 | llm: 大语言模型 155 | config_path: 配置文件路径 156 | text_encoder: 文本编码器类型 157 | image_encoder: 图像编码器类型 158 | audio_encoder: 音频编码器类型 159 | device: 运行设备 160 | vectorstore: 向量存储 161 | embeddings: 嵌入模型 162 | memory: 对话内存 163 | chain_type: 链类型 164 | callbacks: 回调处理器列表 165 | return_source_documents: 是否返回源文档 166 | """ 167 | # 初始化MultiSpaceAI模型 168 | self.multispace_model = MultiSpaceAI( 169 | config_path=config_path, 170 | text_encoder=text_encoder, 171 | image_encoder=image_encoder, 172 | audio_encoder=audio_encoder, 173 | device=device 174 | ) 175 | 176 | # 初始化嵌入模型 177 | if embeddings is None: 178 | embeddings = MultiSpaceAIEmbeddings( 179 | config_path=config_path, 180 | text_encoder=text_encoder, 181 | image_encoder=image_encoder, 182 | audio_encoder=audio_encoder, 183 | device=device 184 | ) 185 | 186 | # 初始化向量存储 187 | if vectorstore is None: 188 | vectorstore = FAISS(embeddings=embeddings, index_name="multispace_index") 189 | 190 | # 初始化LLM压缩器 191 | compressor = LLMChainExtractor.from_llm(llm) 192 | 193 | # 初始化上下文压缩检索器 194 | retriever = ContextualCompressionRetriever( 195 | base_compressor=compressor, 196 | base_retriever=vectorstore.as_retriever() 197 | ) 198 | 199 | # 创建多模态问答提示模板 200 | template = """使用以下检索到的上下文和多模态信息来回答问题。 201 | 202 | 检索到的上下文: {context} 203 | 204 | 多模态信息: {multimodal_context} 205 | 206 | 问题: {question} 207 | 208 | 答案:""" 209 | 210 | prompt = PromptTemplate( 211 | template=template, 212 | input_variables=["context", "multimodal_context", "question"] 213 | ) 214 | 215 | # 初始化检索QA链 216 | self.qa_chain = RetrievalQA.from_chain_type( 217 | llm=llm, 218 | chain_type=chain_type, 219 | retriever=retriever, 220 | return_source_documents=return_source_documents, 221 | chain_type_kwargs={"prompt": prompt} 222 | ) 223 | 224 | # 初始化MultiSpaceAI链 225 | self.multispace_chain = MultiSpaceAIChain( 226 | llm=llm, 227 | prompt=prompt, 228 | config_path=config_path, 229 | text_encoder=text_encoder, 230 | image_encoder=image_encoder, 231 | audio_encoder=audio_encoder, 232 | device=device, 233 | memory=memory, 234 | callbacks=callbacks 235 | ) 236 | 237 | def add_documents(self, documents: List[Union[Document, MultiModalDocument]]) -> None: 238 | """添加文档到向量存储 239 | 240 | Args: 241 | documents: 文档列表 242 | """ 243 | # 获取检索器和向量存储 244 | retriever = self.qa_chain.retriever 245 | if isinstance(retriever, ContextualCompressionRetriever): 246 | vectorstore = retriever.base_retriever.vectorstore 247 | else: 248 | vectorstore = retriever.vectorstore 249 | 250 | # 添加文档到向量存储 251 | vectorstore.add_documents(documents) 252 | 253 | def ingest_from_directory( 254 | self, 255 | directory: str, 256 | text_exts: List[str] = [".txt", ".md", ".html"], 257 | image_exts: List[str] = [".jpg", ".jpeg", ".png"], 258 | audio_exts: List[str] = [".mp3", ".wav", ".ogg"], 259 | text_splitter = None, 260 | chunk_size: int = 1000, 261 | chunk_overlap: int = 0 262 | ) -> None: 263 | """从目录导入文档 264 | 265 | Args: 266 | directory: 目录路径 267 | text_exts: 文本文件扩展名列表 268 | image_exts: 图像文件扩展名列表 269 | audio_exts: 音频文件扩展名列表 270 | text_splitter: 文本分割器 271 | chunk_size: 文本块大小 272 | chunk_overlap: 文本块重叠大小 273 | """ 274 | if text_splitter is None: 275 | text_splitter = RecursiveCharacterTextSplitter( 276 | chunk_size=chunk_size, 277 | chunk_overlap=chunk_overlap 278 | ) 279 | 280 | documents = [] 281 | 282 | # 遍历目录 283 | for root, _, files in os.walk(directory): 284 | for file in files: 285 | file_path = os.path.join(root, file) 286 | file_ext = os.path.splitext(file)[1].lower() 287 | 288 | # 处理文本文件 289 | if file_ext in text_exts: 290 | with open(file_path, "r", encoding="utf-8") as f: 291 | text = f.read() 292 | 293 | # 分割文本 294 | text_chunks = text_splitter.split_text(text) 295 | 296 | # 创建多模态文档 297 | for chunk in text_chunks: 298 | documents.append( 299 | MultiModalDocument( 300 | page_content=chunk, 301 | metadata={"source": file_path} 302 | ) 303 | ) 304 | 305 | # 处理图像文件 306 | elif file_ext in image_exts: 307 | # 使用MultiSpaceAI处理图像 308 | result = self.multispace_model.process( 309 | image=file_path 310 | ) 311 | 312 | # 创建多模态文档 313 | documents.append( 314 | MultiModalDocument( 315 | page_content=result["generated_text"], 316 | metadata={"source": file_path}, 317 | image_path=file_path 318 | ) 319 | ) 320 | 321 | # 处理音频文件 322 | elif file_ext in audio_exts: 323 | # 使用MultiSpaceAI处理音频 324 | result = self.multispace_model.process( 325 | audio=file_path 326 | ) 327 | 328 | # 创建多模态文档 329 | documents.append( 330 | MultiModalDocument( 331 | page_content=result["generated_text"], 332 | metadata={"source": file_path}, 333 | audio_path=file_path 334 | ) 335 | ) 336 | 337 | # 添加文档到向量存储 338 | self.add_documents(documents) 339 | 340 | def run( 341 | self, 342 | query: str, 343 | text: Optional[str] = None, 344 | image: Optional[str] = None, 345 | audio: Optional[str] = None, 346 | **kwargs 347 | ) -> Dict[str, Any]: 348 | """运行检索链 349 | 350 | Args: 351 | query: 查询 352 | text: 文本输入 353 | image: 图像文件路径 354 | audio: 音频文件路径 355 | **kwargs: 额外参数 356 | 357 | Returns: 358 | 检索链运行结果 359 | """ 360 | # 处理多模态输入 361 | if any([image, audio]): 362 | multimodal_result = self.multispace_model.process( 363 | text=text or query, 364 | image=image, 365 | audio=audio 366 | ) 367 | 368 | # 将多模态处理结果添加到输入 369 | multimodal_context = multimodal_result["generated_text"] 370 | else: 371 | multimodal_context = "" 372 | 373 | # 运行检索QA链 374 | return self.qa_chain( 375 | {"query": query, "multimodal_context": multimodal_context, "question": query}, 376 | **kwargs 377 | ) -------------------------------------------------------------------------------- /multispace/fusion/fusion_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | 多模态融合模块 3 | =========== 4 | 5 | 该模块包含用于融合多种模态特征的类,包括交叉注意力融合等方法。 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from typing import Optional, Dict, List, Union, Tuple 12 | import math 13 | 14 | from ..config import FusionConfig 15 | from ..utils.logger import setup_logger 16 | 17 | # 设置日志 18 | logger = setup_logger(__name__) 19 | 20 | 21 | class MultiheadCrossAttention(nn.Module): 22 | """ 23 | 多头交叉注意力模块,用于两种模态之间的特征融合。 24 | """ 25 | 26 | def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1): 27 | """ 28 | 初始化多头交叉注意力模块 29 | 30 | 参数: 31 | hidden_dim: 隐藏层维度 32 | num_heads: 注意力头数 33 | dropout: Dropout比例 34 | """ 35 | super().__init__() 36 | 37 | # 确保hidden_dim可以被num_heads整除 38 | assert hidden_dim % num_heads == 0, "hidden_dim必须能被num_heads整除" 39 | 40 | self.hidden_dim = hidden_dim 41 | self.num_heads = num_heads 42 | self.head_dim = hidden_dim // num_heads 43 | 44 | # 定义投影矩阵 45 | self.query_proj = nn.Linear(hidden_dim, hidden_dim) 46 | self.key_proj = nn.Linear(hidden_dim, hidden_dim) 47 | self.value_proj = nn.Linear(hidden_dim, hidden_dim) 48 | self.output_proj = nn.Linear(hidden_dim, hidden_dim) 49 | 50 | # Dropout层 51 | self.dropout = nn.Dropout(dropout) 52 | 53 | # 缩放因子 54 | self.scale = self.head_dim ** -0.5 55 | 56 | def forward(self, 57 | query: torch.Tensor, 58 | key: torch.Tensor, 59 | value: torch.Tensor, 60 | key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 61 | """ 62 | 前向传播 63 | 64 | 参数: 65 | query: 查询张量,形状为[batch_size, seq_len_q, hidden_dim] 66 | key: 键张量,形状为[batch_size, seq_len_k, hidden_dim] 67 | value: 值张量,形状为[batch_size, seq_len_v, hidden_dim] 68 | key_padding_mask: 键填充掩码,形状为[batch_size, seq_len_k] 69 | 70 | 返回: 71 | 注意力输出,形状为[batch_size, seq_len_q, hidden_dim] 72 | """ 73 | batch_size, seq_len_q, _ = query.size() 74 | _, seq_len_k, _ = key.size() 75 | 76 | # 投影并变形 77 | # [batch_size, seq_len, hidden_dim] -> [batch_size, seq_len, num_heads, head_dim] 78 | q = self.query_proj(query).view(batch_size, seq_len_q, self.num_heads, self.head_dim) 79 | k = self.key_proj(key).view(batch_size, seq_len_k, self.num_heads, self.head_dim) 80 | v = self.value_proj(value).view(batch_size, seq_len_k, self.num_heads, self.head_dim) 81 | 82 | # 调整维度顺序以便进行批量矩阵乘法 83 | # [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim] 84 | q = q.transpose(1, 2) 85 | k = k.transpose(1, 2) 86 | v = v.transpose(1, 2) 87 | 88 | # 计算注意力分数 89 | # [batch_size, num_heads, seq_len_q, head_dim] x [batch_size, num_heads, head_dim, seq_len_k] 90 | # -> [batch_size, num_heads, seq_len_q, seq_len_k] 91 | attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale 92 | 93 | # 应用键填充掩码 94 | if key_padding_mask is not None: 95 | # [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] 96 | mask = key_padding_mask.unsqueeze(1).unsqueeze(2) 97 | attn_scores = attn_scores.masked_fill(mask == 0, float('-inf')) 98 | 99 | # 应用softmax获取注意力权重 100 | attn_weights = F.softmax(attn_scores, dim=-1) 101 | attn_weights = self.dropout(attn_weights) 102 | 103 | # 应用注意力权重 104 | # [batch_size, num_heads, seq_len_q, seq_len_k] x [batch_size, num_heads, seq_len_v, head_dim] 105 | # -> [batch_size, num_heads, seq_len_q, head_dim] 106 | output = torch.matmul(attn_weights, v) 107 | 108 | # 调整维度顺序并合并头 109 | # [batch_size, num_heads, seq_len_q, head_dim] -> [batch_size, seq_len_q, num_heads, head_dim] 110 | output = output.transpose(1, 2) 111 | 112 | # [batch_size, seq_len_q, num_heads, head_dim] -> [batch_size, seq_len_q, hidden_dim] 113 | output = output.reshape(batch_size, seq_len_q, self.hidden_dim) 114 | 115 | # 最终投影 116 | output = self.output_proj(output) 117 | 118 | return output 119 | 120 | 121 | class ModalAdapter(nn.Module): 122 | """ 123 | 模态适配器,用于处理不同模态的特征表示。 124 | """ 125 | 126 | def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.1): 127 | """ 128 | 初始化模态适配器 129 | 130 | 参数: 131 | input_dim: 输入维度 132 | hidden_dim: 隐藏层维度 133 | dropout: Dropout比例 134 | """ 135 | super().__init__() 136 | 137 | self.layer_norm = nn.LayerNorm(input_dim) 138 | self.fc1 = nn.Linear(input_dim, hidden_dim) 139 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 140 | self.dropout = nn.Dropout(dropout) 141 | 142 | def forward(self, x: torch.Tensor) -> torch.Tensor: 143 | """ 144 | 前向传播 145 | 146 | 参数: 147 | x: 输入特征 148 | 149 | 返回: 150 | 适配后的特征 151 | """ 152 | # 应用层归一化 153 | x_norm = self.layer_norm(x) 154 | 155 | # 两层前馈网络 156 | h = F.gelu(self.fc1(x_norm)) 157 | h = self.dropout(h) 158 | h = self.fc2(h) 159 | 160 | return h 161 | 162 | 163 | class MultimodalFusionLayer(nn.Module): 164 | """ 165 | 多模态融合层,用于多种模态之间的特征融合。 166 | """ 167 | 168 | def __init__(self, config: FusionConfig): 169 | """ 170 | 初始化多模态融合层 171 | 172 | 参数: 173 | config: 融合模块配置 174 | """ 175 | super().__init__() 176 | 177 | self.config = config 178 | self.hidden_dim = config.hidden_dim 179 | 180 | # 模态适配器(如果使用) 181 | if config.use_modal_adapters: 182 | self.text_adapter = ModalAdapter(config.hidden_dim, config.hidden_dim, config.dropout) 183 | self.image_adapter = ModalAdapter(config.hidden_dim, config.hidden_dim, config.dropout) 184 | self.audio_adapter = ModalAdapter(config.hidden_dim, config.hidden_dim, config.dropout) 185 | 186 | # 交叉注意力模块 187 | if config.fusion_type == "cross_attention": 188 | # 文本-图像交叉注意力 189 | self.text_image_attn = MultiheadCrossAttention( 190 | config.hidden_dim, config.num_attention_heads, config.dropout 191 | ) 192 | # 文本-语音交叉注意力 193 | self.text_audio_attn = MultiheadCrossAttention( 194 | config.hidden_dim, config.num_attention_heads, config.dropout 195 | ) 196 | # 图像-语音交叉注意力 197 | self.image_audio_attn = MultiheadCrossAttention( 198 | config.hidden_dim, config.num_attention_heads, config.dropout 199 | ) 200 | 201 | # 层归一化 202 | self.layer_norm1 = nn.LayerNorm(config.hidden_dim) 203 | self.layer_norm2 = nn.LayerNorm(config.hidden_dim) 204 | 205 | # 前馈神经网络 206 | self.ffn = nn.Sequential( 207 | nn.Linear(config.hidden_dim, config.hidden_dim * 4), 208 | nn.GELU(), 209 | nn.Dropout(config.dropout), 210 | nn.Linear(config.hidden_dim * 4, config.hidden_dim), 211 | nn.Dropout(config.dropout) 212 | ) 213 | 214 | def forward(self, 215 | text_features: Optional[torch.Tensor] = None, 216 | image_features: Optional[torch.Tensor] = None, 217 | audio_features: Optional[torch.Tensor] = None) -> torch.Tensor: 218 | """ 219 | 前向传播 220 | 221 | 参数: 222 | text_features: 文本特征,形状为[batch_size, seq_len_text, hidden_dim]或[batch_size, hidden_dim] 223 | image_features: 图像特征,形状为[batch_size, seq_len_image, hidden_dim]或[batch_size, hidden_dim] 224 | audio_features: 语音特征,形状为[batch_size, seq_len_audio, hidden_dim]或[batch_size, hidden_dim] 225 | 226 | 返回: 227 | 融合后的特征,形状为[batch_size, hidden_dim] 228 | """ 229 | # 检查至少有一种模态的特征 230 | if text_features is None and image_features is None and audio_features is None: 231 | raise ValueError("至少需要一种模态的特征") 232 | 233 | # 特征列表,用于记录有效的特征 234 | valid_features = [] 235 | 236 | # 处理文本特征 237 | if text_features is not None: 238 | # 确保特征是3D的 239 | if text_features.dim() == 2: 240 | text_features = text_features.unsqueeze(1) # [batch_size, hidden_dim] -> [batch_size, 1, hidden_dim] 241 | 242 | # 应用模态适配器(如果使用) 243 | if self.config.use_modal_adapters: 244 | text_features = self.text_adapter(text_features) 245 | 246 | valid_features.append(text_features) 247 | 248 | # 处理图像特征 249 | if image_features is not None: 250 | # 确保特征是3D的 251 | if image_features.dim() == 2: 252 | image_features = image_features.unsqueeze(1) 253 | 254 | # 应用模态适配器(如果使用) 255 | if self.config.use_modal_adapters: 256 | image_features = self.image_adapter(image_features) 257 | 258 | valid_features.append(image_features) 259 | 260 | # 处理语音特征 261 | if audio_features is not None: 262 | # 确保特征是3D的 263 | if audio_features.dim() == 2: 264 | audio_features = audio_features.unsqueeze(1) 265 | 266 | # 应用模态适配器(如果使用) 267 | if self.config.use_modal_adapters: 268 | audio_features = self.audio_adapter(audio_features) 269 | 270 | valid_features.append(audio_features) 271 | 272 | # 根据融合类型进行处理 273 | if self.config.fusion_type == "cross_attention": 274 | # 如果只有一种模态,直接返回 275 | if len(valid_features) == 1: 276 | fused_features = valid_features[0] 277 | else: 278 | # 多模态交叉注意力 279 | fused_features = [] 280 | 281 | # 获取每种模态的特征 282 | modal_features = { 283 | "text": text_features, 284 | "image": image_features, 285 | "audio": audio_features 286 | } 287 | 288 | # 文本-图像融合 289 | if text_features is not None and image_features is not None: 290 | text_image_fused = self.text_image_attn(text_features, image_features, image_features) 291 | fused_features.append(text_image_fused) 292 | 293 | # 文本-语音融合 294 | if text_features is not None and audio_features is not None: 295 | text_audio_fused = self.text_audio_attn(text_features, audio_features, audio_features) 296 | fused_features.append(text_audio_fused) 297 | 298 | # 图像-语音融合 299 | if image_features is not None and audio_features is not None: 300 | image_audio_fused = self.image_audio_attn(image_features, audio_features, audio_features) 301 | fused_features.append(image_audio_fused) 302 | 303 | # 堆叠并平均所有融合特征 304 | fused_features = torch.stack(fused_features).mean(dim=0) 305 | 306 | elif self.config.fusion_type == "concat": 307 | # 将所有特征在序列维度上拼接 308 | fused_features = torch.cat(valid_features, dim=1) 309 | 310 | elif self.config.fusion_type == "sum": 311 | # 将所有特征加和 312 | # 首先确保所有特征的序列长度相同(取最短的) 313 | min_length = min(f.size(1) for f in valid_features) 314 | valid_features = [f[:, :min_length, :] for f in valid_features] 315 | 316 | # 加和特征 317 | fused_features = sum(valid_features) 318 | 319 | else: 320 | raise ValueError(f"不支持的融合类型: {self.config.fusion_type}") 321 | 322 | # 层归一化和残差连接 323 | fused_features = self.layer_norm1(fused_features) 324 | 325 | # 应用前馈神经网络 326 | ffn_output = self.ffn(fused_features) 327 | fused_features = self.layer_norm2(fused_features + ffn_output) 328 | 329 | # 对序列维度进行平均池化,得到固定维度的表示 330 | fused_features = fused_features.mean(dim=1) 331 | 332 | return fused_features 333 | 334 | 335 | class MultimodalFusionModule(nn.Module): 336 | """ 337 | 多模态融合模块,用于融合多种模态的特征表示。 338 | """ 339 | 340 | def __init__(self, config: FusionConfig, device: Optional[str] = None): 341 | """ 342 | 初始化多模态融合模块 343 | 344 | 参数: 345 | config: 融合模块配置 346 | device: 运行设备 347 | """ 348 | super().__init__() 349 | self.config = config 350 | self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") 351 | 352 | # 多层融合层 353 | self.fusion_layers = nn.ModuleList([ 354 | MultimodalFusionLayer(config) for _ in range(config.num_layers) 355 | ]) 356 | 357 | # 移动模型到指定设备 358 | self.to(self.device) 359 | 360 | logger.info("多模态融合模块初始化完成") 361 | 362 | def forward(self, 363 | text_features: Optional[torch.Tensor] = None, 364 | image_features: Optional[torch.Tensor] = None, 365 | audio_features: Optional[torch.Tensor] = None) -> torch.Tensor: 366 | """ 367 | 前向传播 368 | 369 | 参数: 370 | text_features: 文本特征 371 | image_features: 图像特征 372 | audio_features: 语音特征 373 | 374 | 返回: 375 | 融合后的特征 376 | """ 377 | # 检查至少有一种模态的特征 378 | if text_features is None and image_features is None and audio_features is None: 379 | raise ValueError("至少需要一种模态的特征") 380 | 381 | # 移动特征到设备 382 | if text_features is not None: 383 | text_features = text_features.to(self.device) 384 | if image_features is not None: 385 | image_features = image_features.to(self.device) 386 | if audio_features is not None: 387 | audio_features = audio_features.to(self.device) 388 | 389 | # 通过多层融合层 390 | fused_features = None 391 | for layer in self.fusion_layers: 392 | fused_features = layer(text_features, image_features, audio_features) 393 | 394 | # 更新每种模态的特征,以便下一层使用 395 | if text_features is not None: 396 | if text_features.dim() == 2: 397 | text_features = fused_features.unsqueeze(1) 398 | else: 399 | # 广播融合特征 400 | text_features = text_features + fused_features.unsqueeze(1) 401 | 402 | if image_features is not None: 403 | if image_features.dim() == 2: 404 | image_features = fused_features.unsqueeze(1) 405 | else: 406 | image_features = image_features + fused_features.unsqueeze(1) 407 | 408 | if audio_features is not None: 409 | if audio_features.dim() == 2: 410 | audio_features = fused_features.unsqueeze(1) 411 | else: 412 | audio_features = audio_features + fused_features.unsqueeze(1) 413 | 414 | return fused_features 415 | 416 | def fuse(self, 417 | text_features: Optional[torch.Tensor] = None, 418 | image_features: Optional[torch.Tensor] = None, 419 | audio_features: Optional[torch.Tensor] = None) -> torch.Tensor: 420 | """ 421 | 融合多种模态的特征 422 | 423 | 参数: 424 | text_features: 文本特征 425 | image_features: 图像特征 426 | audio_features: 语音特征 427 | 428 | 返回: 429 | 融合后的特征 430 | """ 431 | self.eval() # 设置为评估模式 432 | 433 | with torch.no_grad(): 434 | fused_features = self.forward(text_features, image_features, audio_features) 435 | 436 | return fused_features --------------------------------------------------------------------------------