├── utils ├── __init__.py ├── image_processing.py ├── dicom_handler.py └── report_formatter.py ├── crew ├── __init__.py ├── agents.py ├── tasks.py └── process.py ├── tools ├── __init__.py ├── knowledge_retrieval.py ├── report_generation.py └── ct_analysis.py ├── tests ├── __init__.py ├── test_agents.py ├── test_integration.py └── test_tools.py ├── langchain_components ├── __init__.py ├── embeddings.py ├── vectorstore.py ├── document_loaders.py └── retriever.py ├── requirements.txt ├── config.py ├── readme.md └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 实用工具模块 3 | """ 4 | # 此文件使模块可导入 5 | -------------------------------------------------------------------------------- /crew/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CrewAI模块: 提供专门处理医学CT图像的智能体系统 3 | """ 4 | # 此文件使模块可导入 5 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具模块: 为CrewAI Agent提供各种任务工具 3 | """ 4 | # 此文件使模块可导入 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 测试模块: 包含针对medical-ct-agent各组件的单元测试和集成测试 3 | """ 4 | # 此文件使模块可导入 5 | -------------------------------------------------------------------------------- /langchain_components/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LangChain组件模块: 提供LangChain配置和实现 3 | """ 4 | # 此文件使模块可导入 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 核心框架 2 | crewai>=0.28.0 3 | langchain>=0.0.267 4 | langchain-community>=0.0.13 5 | 6 | # LLM相关 7 | 
openai>=1.3.0 8 | sentence-transformers>=2.2.2 9 | 10 | # 医学影像处理 11 | pydicom>=2.4.0 12 | nibabel>=5.1.0 13 | pillow>=10.0.0 14 | opencv-python>=4.8.0 15 | torchvision>=0.16.0 16 | transformers>=4.35.0 17 | 18 | # 向量数据库 19 | chromadb>=0.4.18 20 | 21 | # 工具和辅助 22 | numpy>=1.24.0 23 | pandas>=2.1.0 24 | matplotlib>=3.8.0 25 | scikit-image>=0.21.0 26 | 27 | # 测试 28 | pytest>=7.4.0 29 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置文件: 包含项目的所有配置参数 3 | """ 4 | import os 5 | from pathlib import Path 6 | 7 | # 基础路径 8 | BASE_DIR = Path(__file__).resolve().parent 9 | DATA_DIR = BASE_DIR / "data" 10 | MEDICAL_DOCS_DIR = DATA_DIR / "medical_docs" 11 | SAMPLE_IMAGES_DIR = DATA_DIR / "sample_images" 12 | VECTOR_DB_DIR = DATA_DIR / "vector_db" 13 | 14 | # 确保所需目录存在 15 | os.makedirs(MEDICAL_DOCS_DIR, exist_ok=True) 16 | os.makedirs(SAMPLE_IMAGES_DIR, exist_ok=True) 17 | os.makedirs(VECTOR_DB_DIR, exist_ok=True) 18 | 19 | # 模型配置 20 | BIOMEDCLIP_MODEL_NAME = "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224" 21 | EMBEDDING_MODEL_NAME = "pritamdeka/S-PubMedBert-MS-MARCO" 22 | LLM_MODEL_NAME = "medalpaca/medalpaca-7b" # 示例,实际使用时可能需要调整 23 | 24 | # LangChain配置 25 | CHUNK_SIZE = 1000 26 | CHUNK_OVERLAP = 200 27 | TOP_K_RETRIEVALS = 5 28 | SIMILARITY_THRESHOLD = 0.75 29 | 30 | # DICOM处理配置 31 | DEFAULT_WINDOW_CENTER = 50 32 | DEFAULT_WINDOW_WIDTH = 400 33 | 34 | # 创建日志目录 35 | LOG_DIR = BASE_DIR / "logs" 36 | os.makedirs(LOG_DIR, exist_ok=True) 37 | 38 | # API密钥配置 (请在实际部署中使用环境变量或安全存储) 39 | OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") 40 | 41 | # CrewAI配置 42 | # 如果需要使用OpenAI模型,修改为对应模型名称 43 | CREWAI_LLM_MODEL = "gpt-4" 44 | CREWAI_VERBOSE = True 45 | 46 | # 报告生成配置 47 | REPORT_TEMPLATE = """ 48 | # 医学CT影像诊断报告 49 | 50 | ## 患者信息 51 | - **检查日期**: {examination_date} 52 | - **检查类型**: CT扫描 53 | - **检查部位**: {examination_area} 54 | 55 | ## 影像发现 
56 | {image_findings} 57 | 58 | ## 分析与解释 59 | {analysis_and_interpretation} 60 | 61 | ## 诊断意见 62 | {diagnostic_opinion} 63 | 64 | ## 建议 65 | {recommendations} 66 | 67 | ## 报告日期 68 | {report_date} 69 | 70 | """ 71 | -------------------------------------------------------------------------------- /tests/test_agents.py: -------------------------------------------------------------------------------- 1 | """ 2 | 测试CrewAI智能体: 验证医学CT分析智能体的创建和功能 3 | """ 4 | import unittest 5 | from unittest.mock import patch, MagicMock 6 | 7 | from langchain.schema.language_model import BaseLanguageModel 8 | from crewai import Agent 9 | 10 | from crew.agents import ( 11 | get_image_analyst_agent, 12 | get_medical_researcher_agent, 13 | get_radiologist_agent, 14 | create_medical_ct_agents 15 | ) 16 | 17 | 18 | class TestAgents(unittest.TestCase): 19 | """测试CrewAI智能体""" 20 | 21 | def setUp(self): 22 | """设置测试环境""" 23 | # 创建模拟的语言模型 24 | self.mock_llm = MagicMock(spec=BaseLanguageModel) 25 | 26 | def test_get_image_analyst_agent(self): 27 | """测试创建医学影像分析师智能体""" 28 | # 创建智能体 29 | agent = get_image_analyst_agent(llm=self.mock_llm) 30 | 31 | # 验证智能体属性 32 | self.assertIsInstance(agent, Agent) 33 | self.assertEqual(agent.role, "医学影像分析师") 34 | self.assertTrue("分析CT图像" in agent.goal) 35 | self.assertTrue(agent.verbose) 36 | self.assertTrue(agent.allow_delegation) 37 | 38 | def test_get_medical_researcher_agent(self): 39 | """测试创建医学研究员智能体""" 40 | # 创建智能体 41 | agent = get_medical_researcher_agent(llm=self.mock_llm) 42 | 43 | # 验证智能体属性 44 | self.assertIsInstance(agent, Agent) 45 | self.assertEqual(agent.role, "医学研究员") 46 | self.assertTrue("检索" in agent.goal) 47 | self.assertTrue(agent.verbose) 48 | self.assertTrue(agent.allow_delegation) 49 | 50 | def test_get_radiologist_agent(self): 51 | """测试创建放射科医师智能体""" 52 | # 创建智能体 53 | agent = get_radiologist_agent(llm=self.mock_llm) 54 | 55 | # 验证智能体属性 56 | self.assertIsInstance(agent, Agent) 57 | self.assertEqual(agent.role, "放射科医师") 58 | 
def _build_agent(
    role: str,
    goal: str,
    backstory: str,
    llm: Optional[BaseLanguageModel],
    allow_delegation: bool,
) -> Agent:
    """
    Build a verbose CrewAI agent; shared by all role-specific factories.

    Args:
        role: Role name shown to the crew.
        goal: Goal statement of the agent.
        backstory: Persona text of the agent.
        llm: Language model to attach, or None to keep CrewAI's default.
        allow_delegation: Whether the agent may delegate work to other agents.

    Returns:
        The configured Agent.
    """
    agent_kwargs: Dict[str, Any] = {
        "role": role,
        "goal": goal,
        "backstory": backstory,
        "verbose": True,
        "allow_delegation": allow_delegation,
    }
    # Forward `llm` only when the caller supplied one. An explicit llm=None can
    # replace CrewAI's own default model with None (behaviour depends on the
    # CrewAI version), while omitting the keyword always keeps the default.
    if llm is not None:
        agent_kwargs["llm"] = llm
    return Agent(**agent_kwargs)


def get_image_analyst_agent(
    llm: Optional[BaseLanguageModel] = None,
    allow_delegation: bool = True
) -> Agent:
    """
    Create the medical image analyst agent.

    Args:
        llm: Language model; when None, CrewAI's default model is used.
        allow_delegation: Whether the agent may delegate tasks.

    Returns:
        The image analyst agent.
    """
    return _build_agent(
        role="医学影像分析师",
        goal="精确分析CT图像,识别关键特征并提供专业描述",
        backstory="""
        你是一位经验丰富的医学影像分析师,拥有医学影像学博士学位和十年临床经验。
        你擅长使用先进的AI模型分析CT图像,能够发现细微的异常特征。
        你的专长是结合图像特征与解剖学知识,提供准确的初步分析。
        你对各种病理表现的CT影像特征有深入了解,能够识别微妙的变化和异常模式。
        """,
        llm=llm,
        allow_delegation=allow_delegation,
    )


def get_medical_researcher_agent(
    llm: Optional[BaseLanguageModel] = None,
    allow_delegation: bool = True
) -> Agent:
    """
    Create the medical researcher agent.

    Args:
        llm: Language model; when None, CrewAI's default model is used.
        allow_delegation: Whether the agent may delegate tasks.

    Returns:
        The medical researcher agent.
    """
    return _build_agent(
        role="医学研究员",
        goal="检索和分析与CT发现相关的最新医学知识,提供科学依据",
        backstory="""
        你是一位医学研究员,拥有医学博士学位和流行病学硕士学位。
        你擅长检索和解读最新的医学研究文献,能够将复杂的医学概念转化为清晰的解释。
        你对循证医学有深入理解,能够评估证据的强度和适用性。
        你熟悉各种疾病的最新诊断标准、治疗指南和预后因素,能够提供全面的医学背景知识。
        你善于整合多源信息,为临床决策提供全面的知识支持。
        """,
        llm=llm,
        allow_delegation=allow_delegation,
    )


def get_radiologist_agent(
    llm: Optional[BaseLanguageModel] = None,
    allow_delegation: bool = True
) -> Agent:
    """
    Create the radiologist agent.

    Args:
        llm: Language model; when None, CrewAI's default model is used.
        allow_delegation: Whether the agent may delegate tasks.

    Returns:
        The radiologist agent.
    """
    return _build_agent(
        role="放射科医师",
        goal="整合图像分析和医学知识,提供专业的诊断报告和治疗建议",
        backstory="""
        你是一位资深放射科医师,拥有放射诊断学副教授职称和15年临床经验。
        你在胸部和腹部CT诊断领域有特殊专长,曾在顶级医学期刊发表多篇研究论文。
        你善于整合影像发现与临床信息,提供全面准确的诊断和治疗建议。
        你具有出色的医学判断力,能够权衡不同诊断假设的可能性,并提供合理的鉴别诊断。
        你精通医学报告写作,能够撰写专业、清晰且符合临床需求的诊断报告。
        """,
        llm=llm,
        allow_delegation=allow_delegation,
    )


def create_medical_ct_agents(
    llm: Optional[BaseLanguageModel] = None
) -> Dict[str, Agent]:
    """
    Create every agent used by the medical CT analysis crew.

    Args:
        llm: Language model shared by all agents. When None, no model is
            passed on and each agent keeps CrewAI's default.

    Returns:
        Dict with the keys "image_analyst", "medical_researcher" and
        "radiologist", each mapped to its Agent (delegation enabled).
    """
    return {
        "image_analyst": get_image_analyst_agent(llm=llm),
        "medical_researcher": get_medical_researcher_agent(llm=llm),
        "radiologist": get_radiologist_agent(llm=llm),
    }
-------------------------------------------------------------------------------- 1 | # Medical CT Agent 2 | 3 | ![Medical CT Analysis](https://img.shields.io/badge/AI-Medical%20Imaging-brightgreen) 4 | ![Python](https://img.shields.io/badge/Python-3.9%2B-blue) 5 | ![CrewAI](https://img.shields.io/badge/CrewAI-0.28.0%2B-orange) 6 | ![LangChain](https://img.shields.io/badge/LangChain-0.0.267%2B-yellow) 7 | ![BiomedCLIP](https://img.shields.io/badge/BiomedCLIP-Vision%20Model-violet) 8 | ![RAG](https://img.shields.io/badge/Architecture-RAG-red) 9 | 10 | ## 项目概述 11 | 12 | CTAnalyticsAgent 是一个先进的医学CT图像分析系统,使用多智能体协作框架和最新的AI技术自动分析CT图像、检索医学知识和生成专业诊断报告。该系统集成了计算机视觉、自然语言处理和知识检索技术,为放射科医生提供智能辅助诊断工具。 13 | 14 | ## 核心功能 15 | 16 | - 🧠 **多智能体协作**:使用CrewAI构建专家智能体团队,包括影像分析师、医学研究员和放射科医师 17 | - 🔍 **CT图像自动分析**:通过BiomedCLIP模型分析CT图像,识别关键特征和异常 18 | - 📚 **医学知识检索**:基于RAG架构的医学知识库检索,提供相关医学依据 19 | - 📝 **专业诊断报告生成**:自动生成结构化的医学诊断报告 20 | - 📊 **历史扫描对比分析**:比较多次CT扫描,自动检测变化和疾病进展 21 | 22 | ## 技术架构 23 | 24 | ![Architecture](https://mermaid.ink/img/pako:eNqNVE1v2zAM_SuETgm6ediBw7IchhbdShTIoUCAYD0MBWKrsYHIUiTJSdqh_32U7I_YyTrsYlB8j48PKoDbkqBES82ZMtBYLKEGQ8XREskQHQX9mWBfMI67unCxYsU0pQVNBvhgM58P9lrVkAQS3U7eo3QvngR62GVpzaISB-v5W0O1wN6SC92G_wlntDT3mXaqzuIjcUUcwtOJtmVpUrJRLQ478VbKNvgixVDNZLzm1WWnNvuX4r2gTbRVlmspWFCKLBuMrpbRLnI8Pi4HWdbGtnGNOtVmogaxSG1fJfR4jus2GGb5YmEqt7XnruDR42ksUTBz_1eS3OKGWe_NihYVmxj0Ok6tZkzjFx9j2tmzNxLJuBtaBjNTU3VGreBv5QdYG7IKiLk1vKFNNQtVp3moRzwk4yg_p6UptKHnxaIdDLznaeK_U8kFWdkjvT4OrmZRj9YLbShBPQ3uo-mmG02lq5FluR9ibG494sGMeZ9G9ZNV7MxVtTfGPYnKRzNj-d7Vw82FwZW5qhqsg5eW7bcweSNNUR5hHQUA5V2xLTGswqBizfMBlt0Y7AxJnKqE1TdUR3gik3idSmOoBnJ-p8s5_tX3VruD_2dXwp6MCg0t6gNzT3XP_n5F_2JI9Md9dpBPwcvTCIlXI1QfISyvRqhOhUl9tXsK9me6ZBhdkfT-32ywJzwqboPyXDy4dSiRKtKEQzRA6z7nc1TOoPwe_ADwShbr) 25 | 26 | ### 前沿AI技术栈 27 | 28 | - **CrewAI**:最新的多智能体框架,允许角色专业化智能体协作 29 | - **LangChain**:组合LLM与应用的顶级框架,实现复杂AI工作流 30 | - **BiomedCLIP**:微软专为医学图像分析优化的多模态AI模型 31 | - **RAG架构**:检索增强生成技术,提供基于医学文献的可靠信息 32 | - **ChromaDB**:高效向量数据库,支持语义搜索 33 | - 
**DICOM处理**:专业医学影像格式解析与处理 34 | - **OpenAI集成**:与GPT-4等高级模型集成,支持医学推理 35 | 36 | ### 系统组件 37 | 38 | 1. **CT影像分析模块** 39 | - BiomedCLIP视觉模型集成 40 | - 医学图像预处理流水线 41 | - 自适应窗宽窗位调整算法 42 | 43 | 2. **知识检索引擎** 44 | - 医学文档向量化与索引 45 | - 多查询生成策略 46 | - 相关文档重排序与提取 47 | 48 | 3. **多智能体系统** 49 | - 专业化角色智能体 50 | - 任务规划与协调 51 | - 结果整合与推理 52 | 53 | 4. **报告生成系统** 54 | - 结构化医学诊断模板 55 | - 历史对比分析 56 | - 临床建议生成 57 | 58 | 59 | ## 项目结构 60 | 61 | ``` 62 | medical-ct-agent/ 63 | ├── main.py # 项目主入口,初始化和启动系统 64 | ├── config.py # 配置文件(API密钥、模型路径等) 65 | ├── crew/ # CrewAI相关实现 66 | │ ├── agents.py # 定义专业Agent(影像分析师、医学研究员、放射科医师) 67 | │ ├── tasks.py # 定义CrewAI任务 68 | │ └── process.py # 定义CrewAI工作流程 69 | ├── tools/ # CrewAI Agent使用的工具 70 | │ ├── ct_analysis.py # CT影像分析工具(使用BiomedCLIP) 71 | │ ├── knowledge_retrieval.py # 知识检索工具(使用LangChain RAG) 72 | │ └── report_generation.py # 报告生成工具 73 | ├── langchain_components/ # LangChain组件配置 74 | │ ├── document_loaders.py # 配置文档加载器 75 | │ ├── embeddings.py # 配置嵌入模型 76 | │ ├── vectorstore.py # 配置向量存储 77 | │ └── retriever.py # 配置检索器 78 | └── utils/ # 辅助工具 79 | ├── image_processing.py # 图像预处理功能 80 | ├── dicom_handler.py # DICOM格式处理 81 | └── report_formatter.py # 报告格式化 82 | ``` 83 | 84 | ## 算法与技术亮点 85 | 86 | - 🔄 **多模态融合**: 将图像数据与医学文本知识无缝结合 87 | - 🔗 **可解释AI**: 智能体提供诊断推理过程,增强医生信任 88 | - 🧮 **自适应窗宽处理**: 优化CT图像对比度,增强病理特征可见度 89 | - 📈 **多查询生成**: 通过LLM分解复杂医学描述为多个精确查询 90 | - 📌 **相似度重排序**: 使用余弦相似度重新排序检索结果,提高相关性 91 | - 🏥 **专业报告格式化**: 符合医学标准的结构化报告生成 92 | 93 | ## 运行示例 94 | 95 | ```bash 96 | # 基本使用 97 | python main.py --image_path ./data/sample_images/chest_ct_001.dcm --output_dir ./reports 98 | 99 | # 多图像分析 100 | python main.py --image_path ./data/sample_images/ --model openai 101 | 102 | # 使用本地模型 103 | python main.py --image_path ./data/sample_images/ --model local 104 | ``` 105 | -------------------------------------------------------------------------------- /langchain_components/embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
def get_medical_embedding_model(
    model_name: str = EMBEDDING_MODEL_NAME,
    use_openai: bool = False,
    device: str = "cpu"
) -> Embeddings:
    """
    Return an embedding model for medical text.

    Args:
        model_name: Model identifier. Defaults to the Hugging Face model
            configured in EMBEDDING_MODEL_NAME. On the OpenAI path, "default"
            or that configured Hugging Face id selects
            "text-embedding-3-small"; any other value is passed to OpenAI
            unchanged.
        use_openai: Use OpenAI embeddings; requires OPENAI_API_KEY to be set.
        device: Device for the local model ("cpu" or "cuda").

    Returns:
        An OpenAIEmbeddings instance when use_openai is true and an API key
        is configured, otherwise a HuggingFaceEmbeddings instance producing
        normalized vectors.
    """
    if use_openai and OPENAI_API_KEY:
        # The signature default is a Hugging Face model id, which is not a
        # valid OpenAI model name, so map it to the OpenAI default as well.
        if model_name in ("default", EMBEDDING_MODEL_NAME):
            openai_model = "text-embedding-3-small"
        else:
            openai_model = model_name
        return OpenAIEmbeddings(
            model=openai_model,
            openai_api_key=OPENAI_API_KEY
        )

    if use_openai:
        # Keep the best-effort fallback, but do not hide it from the caller.
        import warnings

        warnings.warn(
            "use_openai=True but OPENAI_API_KEY is empty; "
            "falling back to the local Hugging Face embedding model",
            RuntimeWarning,
            stacklevel=2,
        )

    # Local Hugging Face model with L2-normalized output vectors
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True}
    )


def get_pubmedbert_embedding_model(device: str = "cpu") -> Embeddings:
    """
    Return the S-PubMedBert (MS-MARCO) sentence embedding model.

    Args:
        device: Device for the model ("cpu" or "cuda").

    Returns:
        The embedding model instance.
    """
    return get_medical_embedding_model(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device=device
    )


def get_biomedclip_embedding_model(device: str = "cpu") -> Embeddings:
    """
    Return an embedding model built from the BiomedCLIP checkpoint name.

    Args:
        device: Device for the model ("cpu" or "cuda").

    Returns:
        The embedding model instance.
    """
    # NOTE(review): this goes through HuggingFaceEmbeddings like any sentence
    # model; confirm that the BiomedCLIP checkpoint can actually be loaded
    # that way before relying on this helper.
    return get_medical_embedding_model(
        model_name="microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
        device=device
    )
def compare_embeddings(text1: str, text2: str, embedding_model: Optional[Embeddings] = None) -> float:
    """
    Compute the cosine similarity between the embeddings of two texts.

    Args:
        text1: First text.
        text2: Second text.
        embedding_model: Embedding model; when None the default medical
            model is loaded.

    Returns:
        Cosine similarity in [-1, 1] (larger means more similar). Returns 0.0
        when either embedding is the zero vector.
    """
    # numpy only: avoids a scikit-learn dependency for a single dot product
    import numpy as np

    if embedding_model is None:
        embedding_model = get_medical_embedding_model()

    vector1 = np.asarray(embedding_model.embed_query(text1), dtype=np.float64)
    vector2 = np.asarray(embedding_model.embed_query(text2), dtype=np.float64)

    norm_product = float(np.linalg.norm(vector1) * np.linalg.norm(vector2))
    if norm_product == 0.0:
        # Cosine is undefined for a zero vector; report "no similarity"
        return 0.0

    return float(np.dot(vector1, vector2) / norm_product)


def batch_embed_texts(
    texts: List[str],
    embedding_model: Optional[Embeddings] = None,
    chunk_size: int = 100
) -> List[List[float]]:
    """
    Embed many texts in fixed-size batches.

    Args:
        texts: Texts to embed.
        embedding_model: Embedding model; when None the default medical
            model is loaded.
        chunk_size: Number of texts per batch; must be at least 1.

    Returns:
        One embedding vector per input text, in input order.

    Raises:
        ValueError: If chunk_size is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Nothing to embed: skip loading the (potentially large) default model
    if not texts:
        return []

    if embedding_model is None:
        embedding_model = get_medical_embedding_model()

    total = len(texts)
    embeddings: List[List[float]] = []
    for start in range(0, total, chunk_size):
        batch = texts[start:start + chunk_size]
        embeddings.extend(embedding_model.embed_documents(batch))
        print(f"已处理 {min(start + chunk_size, total)}/{total} 个文本")

    return embeddings
def load_image(image_path: str) -> np.ndarray:
    """
    Load an image file as a grayscale-capable NumPy array.

    Args:
        image_path: Path of the image file.

    Returns:
        The pixel array; images with three or more channels are converted
        to a single grayscale channel.

    Raises:
        FileNotFoundError: If image_path does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    # The context manager closes the underlying file handle; np.array()
    # forces the pixel data to be read before that happens.
    with Image.open(image_path) as image:
        image_array = np.array(image)

    if image_array.ndim == 3 and image_array.shape[2] >= 3:
        image_array = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)

    return image_array


def resize_image(image: np.ndarray, size: Tuple[int, int] = (224, 224)) -> np.ndarray:
    """
    Resize an image.

    Args:
        image: Input image.
        size: Target size as (height, width).

    Returns:
        The resized image.
    """
    target_height, target_width = size
    # cv2.resize expects dsize as (width, height), the reverse of the
    # (height, width) order this function documents.
    return cv2.resize(image, (target_width, target_height))


def normalize_image(image: np.ndarray) -> np.ndarray:
    """
    Min-max normalize an image into the range 0-1.

    Args:
        image: Input image.

    Returns:
        float32 array scaled to [0, 1]; all zeros when the image is constant.
    """
    # Work in float32 so integer inputs (e.g. int16 CT data) cannot overflow
    # when the minimum is subtracted or the value range is computed.
    pixels = np.asarray(image, dtype=np.float32)
    lowest = pixels.min()
    highest = pixels.max()

    if highest == lowest:
        return np.zeros_like(pixels)

    return (pixels - lowest) / (highest - lowest)


def apply_window_level(
    image: np.ndarray,
    window_center: int = DEFAULT_WINDOW_CENTER,
    window_width: int = DEFAULT_WINDOW_WIDTH
) -> np.ndarray:
    """
    Apply window level / window width to optimize CT contrast.

    Args:
        image: Input CT image.
        window_center: Window level (WL).
        window_width: Window width (WW); must be positive.

    Returns:
        Image clipped to the window and rescaled to [0, 1].

    Raises:
        ValueError: If window_width is not positive.
    """
    if window_width <= 0:
        raise ValueError(f"窗宽必须为正数: {window_width}")

    # True division keeps odd widths exact and avoids the zero-width window
    # that integer halving produced for window_width == 1.
    half_width = window_width / 2.0
    lower_bound = window_center - half_width
    upper_bound = window_center + half_width

    clipped = np.clip(image, lower_bound, upper_bound)
    return (clipped - lower_bound) / (upper_bound - lower_bound)


def enhance_contrast(image: np.ndarray, alpha: float = 1.5, beta: int = 0) -> np.ndarray:
    """
    Increase image contrast with a linear gain and offset.

    Args:
        image: Input image.
        alpha: Contrast gain.
        beta: Brightness offset.

    Returns:
        alpha * image + beta, clipped to [0, 1].
    """
    # Inputs whose maximum exceeds 1.0 are brought into 0-1 first
    if image.max() > 1.0:
        image = normalize_image(image)

    return np.clip(alpha * image + beta, 0, 1)
def denoise_image(image: np.ndarray, strength: int = 7) -> np.ndarray:
    """
    Denoise an image with non-local means filtering.

    Args:
        image: Input image. An image whose maximum is <= 1.0 is treated as
            0-1 scaled; anything else is treated as 0-255 data.
        strength: Filter strength.

    Returns:
        The denoised image: float values in 0-1 for 0-1 scaled input,
        otherwise a uint8 array.
    """
    # Decide once, so the input and output conversions cannot disagree
    is_unit_range = image.max() <= 1.0

    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around (e.g. 256 -> 0) instead of saturating.
    if is_unit_range:
        image_8bit = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        image_8bit = np.clip(image, 0, 255).astype(np.uint8)

    # Template window 7, search window 21
    denoised = cv2.fastNlMeansDenoising(image_8bit, None, strength, 7, 21)

    if is_unit_range:
        return denoised / 255.0

    return denoised


def preprocess_for_biomedclip(image_path: str) -> np.ndarray:
    """
    Preprocess an image file into the array fed to BiomedCLIP.

    Args:
        image_path: Path of the image file.

    Returns:
        The tensor produced by ToTensor + Normalize(0.5, 0.5), as a NumPy
        array, for the windowed image resized to 224x224.
    """
    image = load_image(image_path)

    # NOTE(review): the default window comes from the CT settings in config;
    # confirm it is appropriate for files that are already 8-bit images.
    image = apply_window_level(image)

    image = resize_image(image, (224, 224))

    # Back to an 8-bit PIL image so the torchvision transforms can be applied
    image_pil = Image.fromarray((image * 255).astype(np.uint8))

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    return transform(image_pil).numpy()


def process_ct_batch(image_paths: List[str]) -> List[np.ndarray]:
    """
    Preprocess several CT images, skipping the ones that fail.

    Args:
        image_paths: Paths of the CT images.

    Returns:
        Preprocessed arrays for the images that could be processed, in
        input order. Failures are reported on stdout and skipped.
    """
    processed_images: List[np.ndarray] = []
    for image_path in image_paths:
        # Deliberately best-effort: one unreadable image must not abort the batch
        try:
            processed_images.append(preprocess_for_biomedclip(image_path))
        except Exception as e:
            print(f"处理图像 {image_path} 时出错: {e}")

    return processed_images
def _persist_if_supported(vectorstore: Chroma) -> None:
    """Flush the store to disk when the Chroma wrapper still offers persist()."""
    # Newer Chroma integrations persist automatically and may not expose a
    # manual persist() at all, so only call it when it is there.
    persist = getattr(vectorstore, "persist", None)
    if callable(persist):
        persist()


def create_vectorstore(
    documents: List[Document],
    embedding_model: Optional[Embeddings] = None,
    persist_directory: Union[str, os.PathLike] = VECTOR_DB_DIR,
    collection_name: str = "medical_knowledge"
) -> Chroma:
    """
    Create a Chroma vector store from documents.

    Args:
        documents: Documents to store.
        embedding_model: Embedding model; when None the default medical
            model is used.
        persist_directory: Directory the store is persisted to.
        collection_name: Name of the collection.

    Returns:
        The Chroma vector store.
    """
    if embedding_model is None:
        embedding_model = get_medical_embedding_model()

    vectorstore = Chroma.from_documents(
        documents=documents,
        embedding=embedding_model,
        # The configured default is a pathlib.Path; hand Chroma a plain string
        persist_directory=os.fspath(persist_directory),
        collection_name=collection_name
    )

    _persist_if_supported(vectorstore)

    return vectorstore


def load_vectorstore(
    persist_directory: Union[str, os.PathLike] = VECTOR_DB_DIR,
    embedding_model: Optional[Embeddings] = None,
    collection_name: str = "medical_knowledge"
) -> Chroma:
    """
    Load an existing Chroma vector store.

    Args:
        persist_directory: Directory the store was persisted to.
        embedding_model: Embedding model; when None the default medical
            model is used.
        collection_name: Name of the collection.

    Returns:
        The Chroma vector store.

    Raises:
        FileNotFoundError: If persist_directory does not exist.
    """
    if embedding_model is None:
        embedding_model = get_medical_embedding_model()

    if not os.path.exists(persist_directory):
        raise FileNotFoundError(f"向量存储目录不存在: {persist_directory}")

    return Chroma(
        persist_directory=os.fspath(persist_directory),
        embedding_function=embedding_model,
        collection_name=collection_name
    )


def create_or_load_vectorstore(
    documents: Optional[List[Document]] = None,
    embedding_model: Optional[Embeddings] = None,
    persist_directory: Union[str, os.PathLike] = VECTOR_DB_DIR,
    collection_name: str = "medical_knowledge",
    recreate: bool = False
) -> Chroma:
    """
    Load the vector store if it exists, otherwise create it.

    Args:
        documents: Documents to store; required when a new store is created.
        embedding_model: Embedding model; when None the default medical
            model is used.
        persist_directory: Directory the store is persisted to.
        collection_name: Name of the collection.
        recreate: Force re-creation, deleting any existing store first.

    Returns:
        The Chroma vector store.

    Raises:
        ValueError: If a new store must be created but documents is None.
    """
    if embedding_model is None:
        embedding_model = get_medical_embedding_model()

    # A store "exists" when its directory is present and non-empty
    vector_db_exists = (
        os.path.exists(persist_directory)
        and len(os.listdir(persist_directory)) > 0
    )

    if vector_db_exists and not recreate:
        return load_vectorstore(
            persist_directory=persist_directory,
            embedding_model=embedding_model,
            collection_name=collection_name
        )

    # Validate before deleting anything, so a missing argument cannot wipe
    # an existing store.
    if documents is None:
        raise ValueError("创建新的向量存储需要提供文档")

    if vector_db_exists:
        import shutil
        shutil.rmtree(persist_directory)
        os.makedirs(persist_directory, exist_ok=True)

    return create_vectorstore(
        documents=documents,
        embedding_model=embedding_model,
        persist_directory=persist_directory,
        collection_name=collection_name
    )


def add_documents_to_vectorstore(
    vectorstore: Chroma,
    documents: List[Document]
) -> None:
    """
    Add documents to an existing vector store and persist it.

    Args:
        vectorstore: The vector store.
        documents: Documents to add.
    """
    vectorstore.add_documents(documents)
    _persist_if_supported(vectorstore)


def search_similar_documents(
    vectorstore: Union[Chroma, VectorStore],
    query: str,
    k: int = 5,
    filter: Optional[Dict[str, Any]] = None
) -> List[Document]:
    """
    Search for documents similar to a query.

    Args:
        vectorstore: The vector store.
        query: Query text.
        k: Number of results to return.
        filter: Optional metadata filter passed through to the store.

    Returns:
        The matching documents.
    """
    return vectorstore.similarity_search(query, k=k, filter=filter)


def search_similar_with_scores(
    vectorstore: Union[Chroma, VectorStore],
    query: str,
    k: int = 5
) -> List[tuple[Document, float]]:
    """
    Search for similar documents and return each with its score.

    Args:
        vectorstore: The vector store.
        query: Query text.
        k: Number of results to return.

    Returns:
        (document, score) tuples as returned by the store.
    """
    # NOTE(review): the meaning of the score is store-specific; for Chroma it
    # is presumably a distance (lower = more similar) - confirm before
    # thresholding on it.
    return vectorstore.similarity_search_with_score(query, k=k)
def _require_path(path: str, message_prefix: str) -> None:
    """Raise FileNotFoundError with the given message prefix if path is missing."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"{message_prefix}: {path}")


def load_pdf_document(file_path: str) -> List[Document]:
    """
    Load a PDF document.

    Args:
        file_path: Path of the PDF file.

    Returns:
        The loaded Document objects.

    Raises:
        FileNotFoundError: If file_path does not exist.
    """
    _require_path(file_path, "文件不存在")
    return PyPDFLoader(file_path).load()


def load_text_document(file_path: str) -> List[Document]:
    """
    Load a UTF-8 text document.

    Args:
        file_path: Path of the text file.

    Returns:
        The loaded Document objects.

    Raises:
        FileNotFoundError: If file_path does not exist.
    """
    _require_path(file_path, "文件不存在")
    return TextLoader(file_path, encoding='utf-8').load()


def load_markdown_document(file_path: str) -> List[Document]:
    """
    Load a Markdown document.

    Args:
        file_path: Path of the Markdown file.

    Returns:
        The loaded Document objects.

    Raises:
        FileNotFoundError: If file_path does not exist.
    """
    _require_path(file_path, "文件不存在")
    return UnstructuredMarkdownLoader(file_path).load()


def load_html_document(file_path: str) -> List[Document]:
    """
    Load an HTML document.

    Args:
        file_path: Path of the HTML file.

    Returns:
        The loaded Document objects.

    Raises:
        FileNotFoundError: If file_path does not exist.
    """
    _require_path(file_path, "文件不存在")
    return BSHTMLLoader(file_path).load()


def load_directory(
    directory_path: str,
    glob_pattern: str = "**/*.*",
    show_progress: bool = True,
    silent_errors: bool = False
) -> List[Document]:
    """
    Load every matching document below a directory.

    Args:
        directory_path: Directory to scan.
        glob_pattern: Glob pattern selecting the files.
        show_progress: Whether to show a progress bar.
        silent_errors: When True, files that fail to load are skipped
            instead of aborting the whole load.

    Returns:
        The loaded Document objects.

    Raises:
        FileNotFoundError: If directory_path does not exist.
    """
    _require_path(directory_path, "目录不存在")

    # Loader class per file extension; anything else falls back to TextLoader
    loaders = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".md": UnstructuredMarkdownLoader,
        ".html": BSHTMLLoader,
        ".htm": BSHTMLLoader,
    }

    loader = DirectoryLoader(
        directory_path,
        glob=glob_pattern,
        # DirectoryLoader calls loader_cls(path) per file; dispatch by extension
        loader_cls=lambda file_path: select_loader(file_path, loaders),
        show_progress=show_progress,
        silent_errors=silent_errors
    )

    return loader.load()


def select_loader(file_path: str, loaders: Dict[str, type]) -> BaseLoader:
    """
    Pick and instantiate the loader matching a file's extension.

    Args:
        file_path: Path of the file.
        loaders: Mapping of lower-case extension (with dot) to loader class.

    Returns:
        A loader instance for file_path; a UTF-8 TextLoader when the
        extension is not in the mapping.
    """
    extension = os.path.splitext(file_path)[1].lower()

    loader_class = loaders.get(extension)
    if loader_class is None:
        return TextLoader(file_path, encoding='utf-8')

    return loader_class(file_path)


def split_documents(
    documents: List[Document],
    chunk_size: int = CHUNK_SIZE,
    chunk_overlap: int = CHUNK_OVERLAP
) -> List[Document]:
    """
    Split documents into smaller chunks.

    Args:
        documents: Documents to split.
        chunk_size: Chunk size.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The chunked documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""]
    )

    return text_splitter.split_documents(documents)


def load_medical_knowledge_base(
    directory: str = MEDICAL_DOCS_DIR,
    glob_pattern: str = "**/*.*",
    use_splitter: bool = True,
    chunk_size: int = CHUNK_SIZE,
    chunk_overlap: int = CHUNK_OVERLAP
) -> List[Document]:
    """
    Load all documents of the medical knowledge base.

    Args:
        directory: Directory holding the medical documents.
        glob_pattern: Glob pattern selecting the files.
        use_splitter: Whether to split the documents into chunks.
        chunk_size: Chunk size.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The processed documents; an empty list when loading fails.
    """
    # Best-effort by design: any failure is reported on stdout and yields []
    try:
        # Skip unreadable files so one stray file cannot empty the whole
        # knowledge base.
        documents = load_directory(directory, glob_pattern, silent_errors=True)
        print(f"成功加载 {len(documents)} 个文档")

        if use_splitter:
            documents = split_documents(documents, chunk_size, chunk_overlap)
            print(f"分割后得到 {len(documents)} 个文档块")

        return documents

    except Exception as e:
        print(f"加载医学知识库时出错: {e}")
        return []
return [] 211 | 212 | 213 | def load_medical_documents_by_type( 214 | doc_type: str, 215 | directory: str = MEDICAL_DOCS_DIR, 216 | use_splitter: bool = True 217 | ) -> List[Document]: 218 | """ 219 | 按类型加载医学文档 220 | 221 | Args: 222 | doc_type: 文档类型 ("radiology", "pathology", "general", etc.) 223 | directory: 基础目录 224 | use_splitter: 是否使用文本分割器 225 | 226 | Returns: 227 | 加载的文档列表 228 | """ 229 | # 确定类型目录 230 | type_dir = os.path.join(directory, doc_type) 231 | 232 | # 如果类型目录不存在,尝试搜索包含类型名称的文件 233 | if not os.path.exists(type_dir): 234 | # 构建文件匹配模式 235 | pattern = f"**/*{doc_type}*.*" 236 | return load_medical_knowledge_base(directory, pattern, use_splitter) 237 | 238 | # 如果类型目录存在,加载该目录下的所有文档 239 | return load_medical_knowledge_base(type_dir, "**/*.*", use_splitter) 240 | -------------------------------------------------------------------------------- /utils/dicom_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | DICOM处理模块: 提供DICOM医学影像文件的读取和处理功能 3 | """ 4 | import os 5 | import pydicom 6 | import numpy as np 7 | from typing import Dict, Any, Tuple, Optional, List 8 | 9 | from config import DEFAULT_WINDOW_CENTER, DEFAULT_WINDOW_WIDTH 10 | 11 | 12 | def load_dicom(dicom_path: str) -> pydicom.FileDataset: 13 | """ 14 | 加载DICOM文件 15 | 16 | Args: 17 | dicom_path: DICOM文件路径 18 | 19 | Returns: 20 | DICOM数据集对象 21 | """ 22 | if not os.path.exists(dicom_path): 23 | raise FileNotFoundError(f"DICOM文件不存在: {dicom_path}") 24 | 25 | return pydicom.dcmread(dicom_path) 26 | 27 | 28 | def extract_dicom_metadata(dicom_data: pydicom.FileDataset) -> Dict[str, Any]: 29 | """ 30 | 提取DICOM元数据信息 31 | 32 | Args: 33 | dicom_data: DICOM数据集 34 | 35 | Returns: 36 | 包含关键元数据的字典 37 | """ 38 | metadata = {} 39 | 40 | # 尝试提取常见的DICOM标签 41 | try: 42 | if hasattr(dicom_data, 'PatientID'): 43 | metadata['PatientID'] = dicom_data.PatientID 44 | if hasattr(dicom_data, 'PatientName'): 45 | metadata['PatientName'] = str(dicom_data.PatientName) 46 | if 
def _first_window_value(value):
    """Return the first entry of a multi-valued window setting, else ``value``."""
    # Imported locally so the helper does not rely on the module import block.
    from collections.abc import Sequence

    # DICOM allows several center/width pairs; the first one is used.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return value[0]
    return value


def window_dicom_image(
    image: np.ndarray,
    dicom_data: Optional[pydicom.FileDataset] = None,
    window_center: Optional[int] = None,
    window_width: Optional[int] = None
) -> np.ndarray:
    """
    Apply a window center/width to an image and normalise it to [0, 1].

    Each of the two settings is resolved independently with the priority
    explicit argument > DICOM tag > configured default. The DICOM tags are
    only used when ``dicom_data`` carries both ``WindowCenter`` and
    ``WindowWidth``.

    Args:
        image: Image array to be windowed.
        dicom_data: Optional dataset supplying ``WindowCenter`` /
            ``WindowWidth``.
        window_center: Explicit window center; 0 is a valid value.
        window_width: Explicit window width.

    Returns:
        The image clipped to ``center +/- width / 2`` and rescaled to the
        0-1 range; an all-zero array when the width is not positive.
    """
    tag_center = tag_width = None
    if (
        dicom_data is not None
        and hasattr(dicom_data, 'WindowCenter')
        and hasattr(dicom_data, 'WindowWidth')
    ):
        tag_center = _first_window_value(dicom_data.WindowCenter)
        tag_width = _first_window_value(dicom_data.WindowWidth)

    # Compare against None explicitly: `value or default` would discard a
    # legitimate center of 0.
    if window_center is None:
        window_center = DEFAULT_WINDOW_CENTER if tag_center is None else tag_center
    else:
        window_center = _first_window_value(window_center)

    if window_width is None:
        window_width = DEFAULT_WINDOW_WIDTH if tag_width is None else tag_width
    else:
        window_width = _first_window_value(window_width)

    # A non-positive width has no usable range; this also prevents a
    # division by zero below.
    if window_width <= 0:
        return np.zeros_like(image)

    # True division keeps the full width for odd or fractional values
    # (floor division would silently narrow the window).
    half_width = window_width / 2
    min_value = window_center - half_width
    max_value = window_center + half_width

    windowed = np.clip(image, min_value, max_value)
    return (windowed - min_value) / (max_value - min_value)
def _create_openai_llm() -> Any:
    """Build the OpenAI chat model, failing fast when no API key is configured."""
    if not OPENAI_API_KEY:
        raise ValueError("使用OpenAI需要设置OPENAI_API_KEY环境变量")

    return ChatOpenAI(
        model=CREWAI_LLM_MODEL,
        temperature=0.2,
        openai_api_key=OPENAI_API_KEY
    )


def setup_llm(model_type: str = "openai") -> Any:
    """
    Create the language model used by the agents.

    Args:
        model_type: ``"openai"`` for the OpenAI chat model; any other value
            selects the local Hugging Face model.

    Returns:
        A ``ChatOpenAI`` instance, or a ``HuggingFacePipeline`` wrapping the
        local text-generation pipeline.

    Raises:
        ValueError: If the OpenAI model is needed (directly or as the
            fallback for the local model) and ``OPENAI_API_KEY`` is empty.
    """
    if model_type == "openai":
        return _create_openai_llm()

    # Local model. Any ImportError raised while importing or loading falls
    # back to the OpenAI API, as before.
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        import torch

        # Local medical model; replace with another suitable model path if needed.
        model_path = "medalpaca/medalpaca-7b"

        print(f"加载本地模型: {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Text-generation pipeline around the loaded model.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.95,
            repetition_penalty=1.15
        )

        # LangChain wrapper so the pipeline can be used like any other LLM.
        return HuggingFacePipeline(pipeline=pipe)

    except ImportError:
        print("未安装transformers或torch,回退到使用OpenAI API")
        # The fallback goes through the same key validation as the primary
        # path instead of building a client with an empty key.
        return _create_openai_llm()
def analyze_ct_images(
    image_paths: Union[str, List[str]],
    tools: Dict[str, Any],
    llm: Any,
    output_dir: Optional[str] = None
) -> Dict[str, Any]:
    """
    Run the CT analysis crew on the given images and return its result.

    Args:
        image_paths: A single image path or a list of image paths.
        tools: Tool dictionary providing the keys ``"ct_analysis_tool"``,
            ``"knowledge_tool"`` and ``"report_tool"``.
        llm: Language model handed to the crew.
        output_dir: Directory for the generated report. When given, it is
            created if missing and a timestamped ``ct_report_*.md`` path
            inside it is passed to the crew; otherwise no path is passed.

    Returns:
        The dictionary returned by ``MedicalCTCrew.analyze_ct_images``.
    """
    # Assemble the agent team from the prepared tools.
    team = MedicalCTCrew(
        llm=llm,
        ct_analysis_tool=tools["ct_analysis_tool"],
        knowledge_tool=tools["knowledge_tool"],
        report_tool=tools["report_tool"]
    )

    # A report file is only requested when an output directory was given.
    report_path = None
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        report_path = os.path.join(output_dir, f"ct_report_{stamp}.md")

    return team.analyze_ct_images(
        image_paths=image_paths,
        output_path=report_path
    )
def create_image_analysis_task(
    agent: Agent,
    ct_analysis_tool: BiomedCLIPTool,
    image_paths: Union[str, List[str]],
    task_id: str = "image_analysis"
) -> Task:
    """
    Create the CT image analysis task.

    Args:
        agent: Agent the task is assigned to.
        ct_analysis_tool: Tool whose ``analyze_multiple_images`` method is
            exposed to the task.
        image_paths: A single image path, a directory containing images, or
            a list of image paths.
        task_id: Identifier passed to the task.

    Returns:
        The configured image analysis task.
    """
    # Normalise image_paths to a list.
    if isinstance(image_paths, str):
        if os.path.isdir(image_paths):
            directory = image_paths
            # List the directory once and filter per extension; results stay
            # grouped by extension in the order .dcm, .png, .jpg, .jpeg.
            entries = os.listdir(directory)
            image_paths = [
                os.path.join(directory, entry)
                for ext in ('.dcm', '.png', '.jpg', '.jpeg')
                for entry in entries
                if entry.lower().endswith(ext)
            ]
        else:
            # Single file.
            image_paths = [image_paths]

    image_count = len(image_paths)
    if image_count == 1:
        description = "分析一张CT图像,识别其中的关键医学特征和可能的异常。"
    else:
        description = f"分析{image_count}张CT图像,识别其中的关键医学特征和可能的异常。"

    return Task(
        description=description,
        agent=agent,
        expected_output="""
        详细的CT图像分析报告,包括:
        1. 图像质量评估
        2. 正常解剖结构描述
        3. 异常发现的详细描述,包括位置、大小、形态、密度/强度等特征
        4. 关键医学特征的总结
        """,
        tools=[ct_analysis_tool.analyze_multiple_images],
        context=[
            f"需要分析的CT图像数量: {image_count}",
            # Only the first three paths are shown to keep the context short.
            f"图像路径: {', '.join(image_paths[:3])}{'...' if image_count > 3 else ''}",
        ],
        id=task_id
    )
def _mentions_abnormality(text: str) -> bool:
    """Return True when ``text`` mentions "异常" outside a known negated phrase."""
    # Phrases that state the absence of abnormalities; longer variants come
    # first so they are removed before their shorter substrings.
    negated_phrases = (
        "未见明显异常",
        "未发现明显异常",
        "无明显异常",
        "未发现异常",
        "没有异常",
        "未见异常",
        "无异常",
    )
    remaining = text
    for phrase in negated_phrases:
        remaining = remaining.replace(phrase, "")
    return "异常" in remaining


def create_report_generation_task(
    agent: Agent,
    report_tool: MedicalReportGenerator,
    ct_analysis_result: Dict[str, Any],
    medical_knowledge: str,
    task_id: str = "report_generation"
) -> Task:
    """
    Create the diagnostic report generation task.

    Args:
        agent: Agent the task is assigned to.
        report_tool: Tool whose ``generate_ct_report`` method is exposed to
            the task.
        ct_analysis_result: CT analysis result. The description and the
            abnormality flag are read from its ``"summary"`` section, else
            from its ``"analysis"`` section; without either section the
            whole result is stringified and scanned for "异常".
        medical_knowledge: Retrieved medical knowledge text.
        task_id: Identifier passed to the task.

    Returns:
        The configured report generation task.
    """
    section = None
    for section_key in ("summary", "analysis"):
        if section_key in ct_analysis_result:
            section = ct_analysis_result[section_key]
            break

    if section is not None:
        ct_description = section.get("combined_description", "")
        abnormality_detected = section.get("abnormality_detected", False)
    else:
        ct_description = str(ct_analysis_result)
        # A plain substring test would flag "未见异常" (no abnormality
        # seen) as abnormal, so negated phrases are ignored.
        abnormality_detected = _mentions_abnormality(ct_description)

    # The task wording depends on whether an abnormality was found.
    if abnormality_detected:
        task_description = "生成详细的医学CT诊断报告,重点分析发现的异常"
    else:
        task_description = "生成医学CT诊断报告,确认正常发现并提供适当的建议"

    return Task(
        description=task_description,
        agent=agent,
        expected_output="""
        专业的CT诊断报告,包括:
        1. 详细的影像发现描述
        2. 专业的分析与解释
        3. 明确的诊断意见或鉴别诊断
        4. 具体的后续建议和处理方案
        """,
        tools=[report_tool.generate_ct_report],
        context=[
            f"CT分析结果: {ct_description[:300]}{'...' if len(ct_description) > 300 else ''}",
            f"医学知识参考: {medical_knowledge[:300]}{'...' if len(medical_knowledge) > 300 else ''}",
            "生成一份专业、全面且结构清晰的医学诊断报告"
        ],
        id=task_id
    )
class MedicalKnowledgeRetrievalTool:
    """Medical knowledge retrieval tool backed by a LangChain retriever."""

    def __init__(
        self,
        llm: Optional[BaseLanguageModel] = None,
        retriever: Optional[BaseRetriever] = None,
        vector_db_dir: str = VECTOR_DB_DIR,
        collection_name: str = "medical_knowledge",
        rebuild_vectordb: bool = False
    ):
        """
        Initialise the tool and its retriever.

        Args:
            llm: Language model used for query generation and knowledge
                integration (optional).
            retriever: Ready-made retriever; used as is when provided.
            vector_db_dir: Directory of the persisted vector database.
            collection_name: Name of the vector store collection.
            rebuild_vectordb: Build the vector database from the medical
                documents even when one already exists.

        Raises:
            ValueError: If the vector database has to be built but no
                medical documents could be loaded.
        """
        self.llm = llm
        self.vector_db_dir = vector_db_dir
        self.collection_name = collection_name

        if retriever is not None:
            self.retriever = retriever
            return

        # A non-empty directory is taken as an existing vector database.
        vector_db_exists = os.path.exists(vector_db_dir) and len(os.listdir(vector_db_dir)) > 0

        if not vector_db_exists or rebuild_vectordb:
            print("向量数据库不存在或需要重建,正在创建...")
            documents = load_medical_knowledge_base()
            if not documents:
                raise ValueError("无法加载医学知识文档")

            embedding_model = get_medical_embedding_model()
            self.retriever = get_medical_knowledge_retriever(
                documents=documents,
                embedding_model=embedding_model
            )
        else:
            print("加载现有向量数据库...")
            embedding_model = get_medical_embedding_model()
            vectorstore = create_or_load_vectorstore(
                embedding_model=embedding_model,
                persist_directory=vector_db_dir,
                collection_name=collection_name
            )
            self.retriever = get_medical_knowledge_retriever(
                vectorstore=vectorstore
            )

    @staticmethod
    def _extract_ct_description(ct_analysis_result: Any) -> str:
        """Return the combined description of a CT analysis result.

        The ``"summary"`` section is preferred over ``"analysis"``. Results
        that are not dictionaries, or whose section is not a dictionary,
        are stringified instead of raising.
        """
        if isinstance(ct_analysis_result, dict):
            for section_key in ("summary", "analysis"):
                if section_key in ct_analysis_result:
                    section = ct_analysis_result[section_key]
                    if isinstance(section, dict):
                        return section.get("combined_description", "")
                    break
        return str(ct_analysis_result)

    def retrieve_knowledge(
        self,
        query: str,
        k: int = TOP_K_RETRIEVALS,
        return_documents: bool = False
    ) -> Union[str, List[Document]]:
        """
        Retrieve medical knowledge for a query.

        Args:
            query: Query text.
            k: Maximum number of documents to keep.
            return_documents: Return the documents instead of a text context.

        Returns:
            The first ``k`` retrieved documents, or the context text built
            from them.
        """
        documents = retrieve_medical_knowledge(query, self.retriever)[:k]

        if return_documents:
            return documents
        return build_medical_context(documents)

    def retrieve_knowledge_from_ct_analysis(
        self,
        ct_analysis_result: Dict[str, Any],
        num_queries: int = 3,
        return_documents: bool = False
    ) -> Union[str, Dict[str, List[Document]]]:
        """
        Retrieve medical knowledge based on a CT analysis result.

        Args:
            ct_analysis_result: CT analysis result.
            num_queries: Number of queries to generate when an LLM is set.
            return_documents: Return documents grouped by query instead of
                a combined text context.

        Returns:
            The combined context text, or a ``{"query_<i>": documents}``
            dictionary. Empty when there is no non-blank query to run.
        """
        ct_description = self._extract_ct_description(ct_analysis_result)

        # With an LLM several queries are generated; otherwise (or when the
        # generation yields nothing) the description itself is the query.
        queries = []
        if self.llm and ct_description:
            queries = generate_multiple_queries(ct_description, self.llm, num_queries)
        if not queries:
            queries = [ct_description]

        # Blank queries carry no information and are not sent to the retriever.
        queries = [query for query in queries if str(query).strip()]

        all_documents = {}
        all_context_parts = []

        for i, query in enumerate(queries):
            documents = retrieve_medical_knowledge(query, self.retriever)

            if return_documents:
                all_documents[f"query_{i}"] = documents
            else:
                context = build_medical_context(documents)
                all_context_parts.append(f"--- 查询 {i+1}: {query} ---\n\n{context}")

        if return_documents:
            return all_documents
        return "\n\n".join(all_context_parts)

    def integrate_knowledge_with_llm(
        self,
        ct_analysis_result: Dict[str, Any],
        knowledge_context: str
    ) -> str:
        """
        Combine the CT analysis result and retrieved knowledge using the LLM.

        Args:
            ct_analysis_result: CT analysis result.
            knowledge_context: Retrieved knowledge context.

        Returns:
            The LLM-generated integrated summary.

        Raises:
            ValueError: If no LLM was provided to the tool.
        """
        if not self.llm:
            raise ValueError("需要提供LLM模型才能整合知识")

        ct_description = self._extract_ct_description(ct_analysis_result)

        template = """
        作为医学影像专家,请基于CT图像分析结果和检索到的医学知识,提供综合的医学解释。

        CT图像分析结果:
        {ct_description}

        相关医学知识:
        {knowledge_context}

        请提供综合的医学分析,包括可能的诊断、相关医学解释和临床意义。请使用专业但清晰的语言:
        """

        prompt = PromptTemplate(
            template=template,
            input_variables=["ct_description", "knowledge_context"]
        )

        chain = LLMChain(llm=self.llm, prompt=prompt)

        return chain.run(
            ct_description=ct_description,
            knowledge_context=knowledge_context
        )
def report_data_to_markdown(report_data: Dict[str, Any]) -> str:
    """
    Render a report data dictionary as a Markdown report.

    The standard fields are passed to ``format_ct_report`` by name; every
    other key of ``report_data`` is forwarded as additional information.

    Args:
        report_data: Report data dictionary.

    Returns:
        The Markdown report text produced by ``format_ct_report``.
    """
    standard_fields = (
        'image_findings', 'analysis', 'diagnostic_opinion',
        'recommendations', 'examination_area',
        'examination_date', 'report_date',
    )

    # Everything that is not a standard field goes to the additional section.
    extras = {
        name: entry
        for name, entry in report_data.items()
        if name not in standard_fields
    }

    return format_ct_report(
        image_findings=report_data.get('image_findings', ''),
        analysis=report_data.get('analysis', ''),
        diagnostic_opinion=report_data.get('diagnostic_opinion', ''),
        recommendations=report_data.get('recommendations', ''),
        examination_area=report_data.get('examination_area', '胸部/腹部'),
        examination_date=report_data.get('examination_date'),
        report_date=report_data.get('report_date'),
        additional_info=extras
    )
def append_to_report(
    original_report: str,
    section_title: str,
    content: str
) -> str:
    """
    Add a section to a report, or replace it when it already exists.

    A section is a line consisting of ``## <section_title>`` and everything
    up to the next level-2 heading (a line starting with ``## ``) or the end
    of the report. Sub-headings such as ``### ...`` belong to the section.

    Args:
        original_report: Original report text.
        section_title: Title of the section to add or replace.
        content: New content of the section.

    Returns:
        The updated report text.
    """
    import re

    heading = f"## {section_title}"

    # The title is escaped so regex metacharacters in it are matched
    # literally, and the heading must fill its whole line so that "## A"
    # neither matches "## AB" nor "### A".
    section_pattern = re.compile(
        rf"^{re.escape(heading)}[ \t]*\r?$.*?(?=^## |\Z)",
        flags=re.DOTALL | re.MULTILINE,
    )
    replacement = f"{heading}\n{content}\n\n"

    # A function replacement inserts the text verbatim; a string replacement
    # would interpret backslashes in the content as template escapes.
    updated_report, replaced = section_pattern.subn(
        lambda _match: replacement, original_report
    )
    if replaced:
        return updated_report

    # Section not present: append it at the end.
    return f"{original_report.rstrip()}\n\n{heading}\n{content}\n"
class MedicalCTCrew:
    """Coordinates the CrewAI agents that analyse CT images and write reports.

    ``analyze_ct_images`` runs three stages one after another (image analysis,
    medical knowledge retrieval, report generation); ``compare_ct_scans``
    analyses two sets of scans and produces a comparative report.
    """

    def __init__(
        self,
        # Project types are quoted so they are not evaluated at class
        # definition time; the runtime signature is unchanged.
        llm: "Optional[BaseLanguageModel]" = None,
        ct_analysis_tool: "Optional[BiomedCLIPTool]" = None,
        knowledge_tool: "Optional[MedicalKnowledgeRetrievalTool]" = None,
        report_tool: "Optional[MedicalReportGenerator]" = None
    ):
        """
        Initialise the crew and the tools it works with.

        Args:
            llm: Language model passed to ``create_medical_ct_agents``.
            ct_analysis_tool: CT analysis tool; ``create_ct_analysis_tool()``
                is called when none is given.
            knowledge_tool: Knowledge retrieval tool (required).
            report_tool: Report generation tool (required).

        Raises:
            ValueError: If ``knowledge_tool`` or ``report_tool`` is missing.
        """
        # Validate first so a misconfigured call fails before agents or the
        # default CT analysis tool are created.
        if not knowledge_tool:
            raise ValueError("需要提供知识检索工具")
        if not report_tool:
            raise ValueError("需要提供报告生成工具")

        self.llm = llm
        self.agents = create_medical_ct_agents(llm=llm)

        self.ct_analysis_tool = ct_analysis_tool or create_ct_analysis_tool()
        self.knowledge_tool = knowledge_tool
        self.report_tool = report_tool

    @staticmethod
    def _parse_agent_output(raw_output: Any) -> Optional[Dict[str, Any]]:
        """
        Parse an agent result into a dictionary without executing it.

        The text is tried as JSON first and then as a Python literal, so both
        ``{"ok": true}`` and ``{'ok': True}`` are understood. Nothing is ever
        evaluated as code.

        Args:
            raw_output: Value returned by ``Crew.kickoff()``.

        Returns:
            The parsed dictionary (or ``raw_output`` itself when it already is
            one); ``None`` when the value is not a string or does not
            describe a dictionary.
        """
        import ast
        import json

        if isinstance(raw_output, dict):
            return raw_output
        if not isinstance(raw_output, str):
            # NOTE(review): non-string kickoff results are not inspected
            # here; callers fall back to storing them unchanged.
            return None

        text = raw_output.strip()
        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(text)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def analyze_ct_images(
        self,
        image_paths: Union[str, List[str]],
        process_type: str = "sequential",
        output_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Analyse CT images and generate a diagnostic report.

        Args:
            image_paths: Path, or list of paths, of the CT images.
            process_type: ``"sequential"`` selects ``Process.sequential``;
                any other value selects ``Process.hierarchical``.
            output_path: Where to save the report; nothing is saved when
                omitted.

        Returns:
            Dictionary with the keys ``ct_analysis``, ``medical_knowledge``
            and ``report``.
        """
        # 1. Image analysis task.
        image_analysis_task = create_image_analysis_task(
            agent=self.agents["image_analyst"],
            ct_analysis_tool=self.ct_analysis_tool,
            image_paths=image_paths
        )

        # 2. One crew is reused for all stages; its task list is swapped
        #    between kickoffs because each later task needs earlier results.
        crew = Crew(
            agents=[
                self.agents["image_analyst"],
                self.agents["medical_researcher"],
                self.agents["radiologist"]
            ],
            tasks=[image_analysis_task],
            verbose=CREWAI_VERBOSE,
            process=Process.sequential if process_type == "sequential" else Process.hierarchical
        )

        # 3. Run the image analysis.
        print("执行CT图像分析任务...")
        image_analysis_result = crew.kickoff()

        # 4. Interpret the result; free text is kept as the description.
        ct_analysis_result = self._parse_agent_output(image_analysis_result)
        if ct_analysis_result is None:
            ct_analysis_result = {"analysis": {"combined_description": image_analysis_result}}

        # 5. Knowledge retrieval.
        knowledge_retrieval_task = create_knowledge_retrieval_task(
            agent=self.agents["medical_researcher"],
            retrieval_tool=self.knowledge_tool,
            ct_analysis_result=ct_analysis_result
        )
        crew.tasks = [knowledge_retrieval_task]

        print("执行医学知识检索任务...")
        medical_knowledge = crew.kickoff()

        # 6. Report generation.
        report_generation_task = create_report_generation_task(
            agent=self.agents["radiologist"],
            report_tool=self.report_tool,
            ct_analysis_result=ct_analysis_result,
            medical_knowledge=medical_knowledge
        )
        crew.tasks = [report_generation_task]

        print("执行诊断报告生成任务...")
        report_result = crew.kickoff()

        # 7. Interpret the report; free text is treated as Markdown content.
        report_data = self._parse_agent_output(report_result)
        if report_data is None:
            report_data = {"content": report_result, "format": "markdown"}

        # 8. Save the report when an output path was given.
        if output_path:
            self.report_tool.save_report(report_data, output_path)
            print(f"报告已保存至: {output_path}")

        return {
            "ct_analysis": ct_analysis_result,
            "medical_knowledge": medical_knowledge,
            "report": report_data
        }

    def compare_ct_scans(
        self,
        current_image_paths: Union[str, List[str]],
        previous_image_paths: Union[str, List[str]],
        output_path: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Compare the current CT scans with earlier ones.

        Args:
            current_image_paths: Path, or list of paths, of the current scans.
            previous_image_paths: Path, or list of paths, of the earlier scans.
            output_path: Where to save the report; nothing is saved when
                omitted.

        Returns:
            Dictionary with the keys ``current_ct_analysis``,
            ``previous_ct_analysis``, ``medical_knowledge`` and
            ``comparative_report``.
        """
        # 1. Analyse the current scans.
        print("分析当前CT图像...")
        current_ct_result = self.ct_analysis_tool.analyze_multiple_images(
            image_paths=current_image_paths
        )

        # 2. Analyse the earlier scans.
        print("分析之前CT图像...")
        previous_ct_result = self.ct_analysis_tool.analyze_multiple_images(
            image_paths=previous_image_paths
        )

        # 3. Retrieve knowledge for the descriptions of both analyses.
        print("检索相关医学知识...")
        combined_description = ""
        if "summary" in current_ct_result:
            combined_description += current_ct_result["summary"].get("combined_description", "") + " "
        if "summary" in previous_ct_result:
            combined_description += previous_ct_result["summary"].get("combined_description", "")

        medical_knowledge = self.knowledge_tool.retrieve_knowledge(combined_description)

        # 4. Comparative report task, executed by the radiologist alone.
        comparative_report_task = create_comparative_report_task(
            agent=self.agents["radiologist"],
            report_tool=self.report_tool,
            current_ct_result=current_ct_result,
            previous_ct_result=previous_ct_result,
            medical_knowledge=medical_knowledge
        )

        crew = Crew(
            agents=[self.agents["radiologist"]],
            tasks=[comparative_report_task],
            verbose=CREWAI_VERBOSE,
            process=Process.sequential
        )

        # 5. Run the task.
        print("生成对比分析报告...")
        comparative_report = crew.kickoff()

        # 6. Interpret the report; free text is treated as Markdown content.
        report_data = self._parse_agent_output(comparative_report)
        if report_data is None:
            report_data = {"content": comparative_report, "format": "markdown"}

        # 7. Save the report when an output path was given.
        if output_path:
            self.report_tool.save_report(report_data, output_path)
            print(f"对比分析报告已保存至: {output_path}")

        return {
            "current_ct_analysis": current_ct_result,
            "previous_ct_analysis": previous_ct_result,
            "medical_knowledge": medical_knowledge,
            "comparative_report": report_data
        }
"""测试医学CT智能体团队的分析功能""" 54 | # 配置模拟对象 55 | mock_llm = MagicMock(spec=BaseLanguageModel) 56 | 57 | mock_ct_tool = MagicMock() 58 | mock_ct_tool_class.return_value = mock_ct_tool 59 | 60 | mock_knowledge_tool = MagicMock() 61 | mock_knowledge_tool_class.return_value = mock_knowledge_tool 62 | 63 | mock_report_tool = MagicMock() 64 | mock_report_tool_class.return_value = mock_report_tool 65 | 66 | mock_agents = { 67 | "image_analyst": MagicMock(spec=Agent), 68 | "medical_researcher": MagicMock(spec=Agent), 69 | "radiologist": MagicMock(spec=Agent) 70 | } 71 | mock_create_agents.return_value = mock_agents 72 | 73 | mock_crew = MagicMock(spec=Crew) 74 | mock_crew_class.return_value = mock_crew 75 | 76 | # 模拟智能体团队的执行结果 77 | mock_crew.kickoff.side_effect = [ 78 | # 第一次调用返回CT分析结果 79 | '''{"analysis": {"combined_description": "肺部有磨玻璃样阴影,考虑肺炎可能。", "abnormality_detected": true}}''', 80 | # 第二次调用返回医学知识 81 | "肺炎通常表现为磨玻璃样阴影,可能是由细菌或病毒感染引起。", 82 | # 第三次调用返回报告结果 83 | '''{"content": "# 医学CT影像诊断报告\\n\\n## 影像发现\\n肺部可见磨玻璃样阴影", "format": "markdown"}''' 84 | ] 85 | 86 | # 创建医学CT智能体团队 87 | crew = MedicalCTCrew( 88 | llm=mock_llm, 89 | ct_analysis_tool=mock_ct_tool, 90 | knowledge_tool=mock_knowledge_tool, 91 | report_tool=mock_report_tool 92 | ) 93 | 94 | # 执行分析 95 | result = crew.analyze_ct_images( 96 | image_paths=self.test_image_path, 97 | output_path=self.output_path 98 | ) 99 | 100 | # 验证结果 101 | self.assertIn("ct_analysis", result) 102 | self.assertIn("medical_knowledge", result) 103 | self.assertIn("report", result) 104 | 105 | # 验证Crew被调用了3次 106 | self.assertEqual(mock_crew.kickoff.call_count, 3) 107 | 108 | # 验证报告保存功能被调用 109 | mock_report_tool.save_report.assert_called_once() 110 | 111 | @patch("tools.ct_analysis.BiomedCLIPTool.analyze_multiple_images") 112 | @patch("tools.knowledge_retrieval.MedicalKnowledgeRetrievalTool.retrieve_knowledge") 113 | @patch("crew.process.create_medical_ct_agents") 114 | @patch("crew.process.Crew") 115 | def test_compare_ct_scans( 116 | self, 117 | 
class TestEndToEnd(unittest.TestCase):
    """Opt-in end-to-end tests that exercise the real system."""

    def setUp(self):
        """
        Skip unless the environment is prepared, then create a scratch dir.

        ``RUN_E2E_TESTS`` must be set and ``TEST_CT_IMAGE`` must name an
        existing path. Both checks run before any resource is created:
        ``tearDown`` is not invoked when ``setUp`` raises (including via
        ``skipTest``), so a directory created earlier would not be cleaned up.
        """
        if not os.environ.get("RUN_E2E_TESTS"):
            self.skipTest("跳过端到端测试。设置 RUN_E2E_TESTS=1 环境变量以启用。")

        self.test_image_path = os.environ.get("TEST_CT_IMAGE")
        if not self.test_image_path or not os.path.exists(self.test_image_path):
            self.skipTest("未提供有效的测试CT图像路径。设置 TEST_CT_IMAGE 环境变量。")

        self.temp_dir = tempfile.TemporaryDirectory()
        # addCleanup runs whether the test passes, fails or errors.
        self.addCleanup(self.temp_dir.cleanup)

    @unittest.skip("此测试需要完整环境和API密钥,默认跳过")
    def test_full_system(self):
        """Run the complete pipeline and check the generated report."""
        import sys

        # Make the project root importable without adding it repeatedly.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if project_root not in sys.path:
            sys.path.append(project_root)

        from main import setup_llm, setup_tools, analyze_ct_images

        llm = setup_llm("openai")
        tools = setup_tools(llm)

        output_dir = self.temp_dir.name
        result = analyze_ct_images(
            image_paths=self.test_image_path,
            tools=tools,
            llm=llm,
            output_dir=output_dir
        )

        # The result exposes all three pipeline stages.
        self.assertIn("ct_analysis", result)
        self.assertIn("medical_knowledge", result)
        self.assertIn("report", result)

        # A Markdown report file was written.
        report_files = [f for f in os.listdir(output_dir) if f.endswith('.md')]
        self.assertGreater(len(report_files), 0)

        # The report contains the expected sections.
        with open(os.path.join(output_dir, report_files[0]), 'r', encoding='utf-8') as f:
            report_content = f.read()
        self.assertIn("影像发现", report_content)
        self.assertIn("诊断意见", report_content)
""" 2 | 检索器模块: 配置和提供高级文档检索功能 3 | """ 4 | from typing import List, Dict, Any, Optional, Callable, Union 5 | 6 | from langchain_core.documents import Document 7 | from langchain_core.retrievers import BaseRetriever 8 | from langchain_community.vectorstores import Chroma 9 | from langchain.retrievers import ContextualCompressionRetriever 10 | from langchain.retrievers.document_compressors import LLMChainExtractor 11 | from langchain_core.language_models import BaseLanguageModel 12 | from langchain_core.embeddings import Embeddings 13 | from langchain.chains import LLMChain 14 | from langchain.prompts import PromptTemplate 15 | 16 | from config import TOP_K_RETRIEVALS, SIMILARITY_THRESHOLD 17 | from langchain_components.embeddings import get_medical_embedding_model 18 | from langchain_components.vectorstore import create_or_load_vectorstore 19 | 20 | 21 | def get_basic_retriever( 22 | vectorstore: Chroma, 23 | search_kwargs: Optional[Dict[str, Any]] = None 24 | ) -> BaseRetriever: 25 | """ 26 | 获取基本向量检索器 27 | 28 | Args: 29 | vectorstore: 向量存储实例 30 | search_kwargs: 搜索参数 31 | 32 | Returns: 33 | 基本检索器 34 | """ 35 | if search_kwargs is None: 36 | search_kwargs = {"k": TOP_K_RETRIEVALS} 37 | 38 | return vectorstore.as_retriever(search_kwargs=search_kwargs) 39 | 40 | 41 | def rerank_documents( 42 | documents: List[Document], 43 | query: str, 44 | embedding_model: Optional[Embeddings] = None, 45 | top_k: int = TOP_K_RETRIEVALS, 46 | threshold: float = SIMILARITY_THRESHOLD 47 | ) -> List[Document]: 48 | """ 49 | 对检索到的文档进行重排序 50 | 51 | Args: 52 | documents: 检索到的文档列表 53 | query: 查询文本 54 | embedding_model: 嵌入模型 55 | top_k: 保留的结果数量 56 | threshold: 相似度阈值,低于此值的结果将被过滤 57 | 58 | Returns: 59 | 重排序后的文档列表 60 | """ 61 | import numpy as np 62 | from sklearn.metrics.pairwise import cosine_similarity 63 | 64 | # 如果没有文档,返回空列表 65 | if not documents: 66 | return [] 67 | 68 | # 如果未提供嵌入模型,使用默认医学模型 69 | if embedding_model is None: 70 | embedding_model = get_medical_embedding_model() 71 | 72 | # 
嵌入查询 73 | query_embedding = embedding_model.embed_query(query) 74 | 75 | # 嵌入每个文档的内容 76 | doc_embeddings = [] 77 | for doc in documents: 78 | doc_embedding = embedding_model.embed_query(doc.page_content) 79 | doc_embeddings.append(doc_embedding) 80 | 81 | # 计算余弦相似度 82 | similarities = [] 83 | for doc_embedding in doc_embeddings: 84 | similarity = cosine_similarity( 85 | np.array(query_embedding).reshape(1, -1), 86 | np.array(doc_embedding).reshape(1, -1) 87 | )[0][0] 88 | similarities.append(similarity) 89 | 90 | # 创建文档-相似度对 91 | doc_similarity_pairs = list(zip(documents, similarities)) 92 | 93 | # 按相似度排序 94 | sorted_pairs = sorted(doc_similarity_pairs, key=lambda x: x[1], reverse=True) 95 | 96 | # 过滤掉相似度低于阈值的文档 97 | filtered_pairs = [(doc, sim) for doc, sim in sorted_pairs if sim >= threshold] 98 | 99 | # 限制结果数量 100 | filtered_pairs = filtered_pairs[:top_k] 101 | 102 | # 提取排序后的文档 103 | reranked_documents = [doc for doc, _ in filtered_pairs] 104 | 105 | return reranked_documents 106 | 107 | 108 | def get_contextual_compression_retriever( 109 | vectorstore: Chroma, 110 | llm: BaseLanguageModel, 111 | search_kwargs: Optional[Dict[str, Any]] = None 112 | ) -> ContextualCompressionRetriever: 113 | """ 114 | 获取上下文压缩检索器,可以提取相关信息片段 115 | 116 | Args: 117 | vectorstore: 向量存储实例 118 | llm: 语言模型 119 | search_kwargs: 搜索参数 120 | 121 | Returns: 122 | 上下文压缩检索器 123 | """ 124 | # 创建基本检索器 125 | base_retriever = get_basic_retriever(vectorstore, search_kwargs) 126 | 127 | # 创建LLM提取器 128 | compressor = LLMChainExtractor.from_llm(llm) 129 | 130 | # 创建上下文压缩检索器 131 | retriever = ContextualCompressionRetriever( 132 | base_compressor=compressor, 133 | base_retriever=base_retriever 134 | ) 135 | 136 | return retriever 137 | 138 | 139 | def get_medical_knowledge_retriever( 140 | documents: Optional[List[Document]] = None, 141 | embedding_model: Optional[Embeddings] = None, 142 | vectorstore: Optional[Chroma] = None, 143 | k: int = TOP_K_RETRIEVALS 144 | ) -> BaseRetriever: 145 | """ 146 | 
获取医学知识检索器 147 | 148 | Args: 149 | documents: 文档列表(如果需要创建新的向量存储) 150 | embedding_model: 嵌入模型 151 | vectorstore: 现有向量存储(如果已有) 152 | k: 检索结果数量 153 | 154 | Returns: 155 | 医学知识检索器 156 | """ 157 | # 获取向量存储 158 | if vectorstore is None: 159 | vectorstore = create_or_load_vectorstore( 160 | documents=documents, 161 | embedding_model=embedding_model 162 | ) 163 | 164 | # 创建检索器 165 | return get_basic_retriever( 166 | vectorstore=vectorstore, 167 | search_kwargs={"k": k} 168 | ) 169 | 170 | 171 | def create_medical_query_rewriter( 172 | llm: BaseLanguageModel 173 | ) -> Callable[[str], str]: 174 | """ 175 | 创建医学查询重写函数,用于优化检索查询 176 | 177 | Args: 178 | llm: 语言模型 179 | 180 | Returns: 181 | 查询重写函数 182 | """ 183 | template = """ 184 | 你是一位医学检索专家,请将以下查询重写为更有效的医学检索查询。 185 | 保留所有关键的医学术语,添加可能相关的同义词,移除不必要的词语。 186 | 187 | 原始查询: {query} 188 | 189 | 重写后的查询: 190 | """ 191 | 192 | prompt = PromptTemplate( 193 | input_variables=["query"], 194 | template=template 195 | ) 196 | 197 | chain = LLMChain(llm=llm, prompt=prompt) 198 | 199 | def rewrite_query(query: str) -> str: 200 | """ 201 | 重写查询以优化检索结果 202 | 203 | Args: 204 | query: 原始查询 205 | 206 | Returns: 207 | 重写后的查询 208 | """ 209 | try: 210 | rewritten_query = chain.run(query=query).strip() 211 | return rewritten_query 212 | except Exception as e: 213 | print(f"重写查询时出错: {e}") 214 | return query 215 | 216 | return rewrite_query 217 | 218 | 219 | def generate_multiple_queries( 220 | ct_findings: str, 221 | llm: BaseLanguageModel, 222 | num_queries: int = 3 223 | ) -> List[str]: 224 | """ 225 | 基于CT发现生成多个检索查询 226 | 227 | Args: 228 | ct_findings: CT发现描述 229 | llm: 语言模型 230 | num_queries: 生成的查询数量 231 | 232 | Returns: 233 | 查询列表 234 | """ 235 | template = """ 236 | 基于以下CT图像分析结果,生成{num_queries}个不同的医学检索查询,每个查询关注不同的关键医学发现或可能的诊断。 237 | 查询应当简洁、精确,包含关键的医学术语。 238 | 239 | CT分析结果: 240 | {ct_findings} 241 | 242 | 查询列表(每行一个查询): 243 | """ 244 | 245 | prompt = PromptTemplate( 246 | input_variables=["ct_findings", "num_queries"], 247 | template=template 248 | 
def build_medical_context(
    documents: "List[Document]",
    max_len: int = 4000
) -> str:
    """
    Build a context string from retrieved documents.

    Each document becomes ``[<source>]: <content>``, where the source comes
    from ``metadata['source']`` or defaults to ``来源 <position>``. Entries are
    joined by a blank line and added in order until the next one would make
    the result longer than ``max_len``; the separators count towards that
    limit.

    Args:
        documents: Retrieved documents, most relevant first.
        max_len: Maximum length of the returned text, in characters.

    Returns:
        The context text, never longer than ``max_len``. If even the first
        entry does not fit it is truncated to ``max_len``. A fixed notice is
        returned when there are no documents.
    """
    if not documents:
        return "无相关医学知识。"

    separator = "\n\n"
    context_parts = []
    used_len = 0

    for position, doc in enumerate(documents, start=1):
        source = (doc.metadata or {}).get('source', f'来源 {position}')
        entry = f"[{source}]: {doc.page_content.strip()}"

        # Every entry after the first also costs one separator.
        cost = len(entry) + (len(separator) if context_parts else 0)
        if used_len + cost > max_len:
            if not context_parts:
                # Nothing fits: keep a truncated first entry rather than
                # returning an empty context.
                context_parts.append(entry[:max(max_len, 0)])
            break

        context_parts.append(entry)
        used_len += cost

    return separator.join(context_parts)
mock_run_analysis): 56 | """测试分析常规图像""" 57 | # 配置模拟对象 58 | mock_run_analysis.return_value = { 59 | "combined_description": "正常CT图像,无明显异常。", 60 | "abnormality_detected": False, 61 | "confidence": 0.95 62 | } 63 | 64 | # 创建带有模拟方法的工具实例 65 | tool = MagicMock(spec=BiomedCLIPTool) 66 | tool._analyze_regular_image = BiomedCLIPTool._analyze_regular_image 67 | tool._run_biomedclip_analysis = mock_run_analysis 68 | 69 | # 使用真实图像路径调用方法 70 | result = tool._analyze_regular_image(self, self.test_image_path) 71 | 72 | # 验证结果 73 | self.assertIn("metadata", result) 74 | self.assertIn("analysis", result) 75 | self.assertEqual(result["image_type"], "Regular") 76 | mock_run_analysis.assert_called_once() 77 | 78 | def test_create_ct_analysis_tool(self): 79 | """测试创建CT分析工具函数""" 80 | with patch("tools.ct_analysis.BiomedCLIPTool") as mock_tool_class: 81 | # 配置模拟对象 82 | mock_tool_class.return_value = MagicMock() 83 | 84 | # 调用函数 85 | tool = create_ct_analysis_tool() 86 | 87 | # 验证结果 88 | self.assertIsNotNone(tool) 89 | mock_tool_class.assert_called_once() 90 | 91 | 92 | class TestKnowledgeRetrievalTool(unittest.TestCase): 93 | """测试知识检索工具""" 94 | 95 | def setUp(self): 96 | """设置测试环境""" 97 | # 模拟语言模型 98 | self.mock_llm = MagicMock() 99 | 100 | # 模拟检索器 101 | self.mock_retriever = MagicMock() 102 | self.mock_retriever.get_relevant_documents.return_value = [ 103 | MagicMock(page_content="肺炎的CT表现", metadata={"source": "医学教科书"}), 104 | MagicMock(page_content="肺部结节的分类", metadata={"source": "放射学指南"}) 105 | ] 106 | 107 | @patch("tools.knowledge_retrieval.get_medical_knowledge_retriever") 108 | def test_init_with_retriever(self, mock_get_retriever): 109 | """测试使用现有检索器初始化""" 110 | # 创建工具实例 111 | tool = MedicalKnowledgeRetrievalTool( 112 | llm=self.mock_llm, 113 | retriever=self.mock_retriever 114 | ) 115 | 116 | # 验证初始化行为 117 | self.assertEqual(tool.llm, self.mock_llm) 118 | self.assertEqual(tool.retriever, self.mock_retriever) 119 | mock_get_retriever.assert_not_called() 120 | 121 | def 
test_retrieve_knowledge(self): 122 | """测试检索知识""" 123 | # 创建工具实例 124 | tool = MedicalKnowledgeRetrievalTool( 125 | llm=self.mock_llm, 126 | retriever=self.mock_retriever 127 | ) 128 | 129 | # 调用检索方法 130 | result = tool.retrieve_knowledge("肺炎的CT表现") 131 | 132 | # 验证结果 133 | self.assertTrue(isinstance(result, str)) 134 | self.mock_retriever.get_relevant_documents.assert_called_once() 135 | 136 | def test_retrieve_knowledge_return_documents(self): 137 | """测试检索知识并返回文档""" 138 | # 创建工具实例 139 | tool = MedicalKnowledgeRetrievalTool( 140 | llm=self.mock_llm, 141 | retriever=self.mock_retriever 142 | ) 143 | 144 | # 调用检索方法 145 | result = tool.retrieve_knowledge("肺炎的CT表现", return_documents=True) 146 | 147 | # 验证结果 148 | self.assertTrue(isinstance(result, list)) 149 | self.assertEqual(len(result), 2) 150 | self.mock_retriever.get_relevant_documents.assert_called_once() 151 | 152 | @patch("tools.knowledge_retrieval.generate_multiple_queries") 153 | def test_retrieve_knowledge_from_ct_analysis(self, mock_generate_queries): 154 | """测试基于CT分析结果检索知识""" 155 | # 配置模拟对象 156 | mock_generate_queries.return_value = ["肺炎的CT表现", "肺部结节的特征"] 157 | 158 | # 创建工具实例 159 | tool = MedicalKnowledgeRetrievalTool( 160 | llm=self.mock_llm, 161 | retriever=self.mock_retriever 162 | ) 163 | 164 | # 创建测试CT分析结果 165 | ct_result = { 166 | "summary": { 167 | "combined_description": "肺部有磨玻璃样阴影,考虑肺炎可能。" 168 | } 169 | } 170 | 171 | # 调用检索方法 172 | result = tool.retrieve_knowledge_from_ct_analysis(ct_result) 173 | 174 | # 验证结果 175 | self.assertTrue(isinstance(result, str)) 176 | mock_generate_queries.assert_called_once() 177 | # 检索器应该被调用两次,对应两个生成的查询 178 | self.assertEqual(self.mock_retriever.get_relevant_documents.call_count, 2) 179 | 180 | 181 | class TestReportGenerationTool(unittest.TestCase): 182 | """测试报告生成工具""" 183 | 184 | def setUp(self): 185 | """设置测试环境""" 186 | # 模拟语言模型 187 | self.mock_llm = MagicMock() 188 | self.mock_llm_chain = MagicMock() 189 | self.mock_llm_chain.run.return_value = """ 190 | 1. 
影像发现: 肺部可见磨玻璃样阴影,分布于双肺下叶。 191 | 192 | 2. 分析与解释: 磨玻璃样阴影通常提示间质性改变,考虑可能是炎症或早期间质性肺病。 193 | 194 | 3. 诊断意见: 考虑为肺炎,也不排除间质性肺病的可能。 195 | 196 | 4. 建议: 建议短期复查CT,进行血常规、C反应蛋白等炎症指标检查。 197 | """ 198 | 199 | # 创建工具实例 200 | self.tool = MedicalReportGenerator(llm=self.mock_llm) 201 | 202 | @patch("tools.report_generation.LLMChain") 203 | def test_generate_ct_report(self, mock_llm_chain_class): 204 | """测试生成CT报告""" 205 | # 配置模拟对象 206 | mock_llm_chain_class.return_value = self.mock_llm_chain 207 | 208 | # 创建测试CT分析结果和医学知识 209 | ct_result = { 210 | "summary": { 211 | "combined_description": "肺部有磨玻璃样阴影,考虑肺炎可能。", 212 | "abnormality_detected": True, 213 | "scan_region": "胸部" 214 | } 215 | } 216 | medical_knowledge = "肺炎通常表现为磨玻璃样阴影。" 217 | 218 | # 调用生成报告方法 219 | result = self.tool.generate_ct_report(ct_result, medical_knowledge) 220 | 221 | # 验证结果 222 | self.assertIn("content", result) 223 | self.assertEqual(result["format"], "markdown") 224 | mock_llm_chain_class.assert_called_once() 225 | self.mock_llm_chain.run.assert_called_once() 226 | 227 | def test_extract_report_sections(self): 228 | """测试提取报告各个部分""" 229 | # 准备测试报告文本 230 | report_text = """ 231 | 1. 影像发现: 肺部可见磨玻璃样阴影,分布于双肺下叶。 232 | 233 | 2. 分析与解释: 磨玻璃样阴影通常提示间质性改变,考虑可能是炎症或早期间质性肺病。 234 | 235 | 3. 诊断意见: 考虑为肺炎,也不排除间质性肺病的可能。 236 | 237 | 4. 
建议: 建议短期复查CT,进行血常规、C反应蛋白等炎症指标检查。 238 | """ 239 | 240 | # 调用提取方法 241 | sections = self.tool._extract_report_sections(report_text) 242 | 243 | # 验证结果 244 | self.assertIn("影像发现", sections) 245 | self.assertIn("分析与解释", sections) 246 | self.assertIn("诊断意见", sections) 247 | self.assertIn("建议", sections) 248 | self.assertTrue("肺部可见磨玻璃样阴影" in sections["影像发现"]) 249 | self.assertTrue("肺炎" in sections["诊断意见"]) 250 | 251 | @patch("tools.report_generation.save_report_to_markdown") 252 | def test_save_report_markdown(self, mock_save): 253 | """测试保存报告为Markdown""" 254 | # 配置模拟对象 255 | mock_save.return_value = "/tmp/report.md" 256 | 257 | # 准备测试报告数据 258 | report_data = { 259 | "content": "# 医学CT影像诊断报告\n\n## 影像发现\n肺部可见磨玻璃样阴影", 260 | "format": "markdown" 261 | } 262 | 263 | # 调用保存方法 264 | result = self.tool.save_report(report_data, "/tmp/report.md") 265 | 266 | # 验证结果 267 | self.assertEqual(result, "/tmp/report.md") 268 | mock_save.assert_called_once() 269 | 270 | 271 | if __name__ == "__main__": 272 | unittest.main() 273 | -------------------------------------------------------------------------------- /tools/report_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | 报告生成工具: 生成专业医疗诊断报告 3 | """ 4 | from typing import Dict, Any, List, Optional, Union 5 | import datetime 6 | 7 | from langchain_core.language_models import BaseLanguageModel 8 | from langchain.chains import LLMChain 9 | from langchain.prompts import PromptTemplate 10 | 11 | from config import REPORT_TEMPLATE 12 | from utils.report_formatter import ( 13 | format_ct_report, 14 | save_report_to_markdown, 15 | save_report_to_json 16 | ) 17 | 18 | 19 | class MedicalReportGenerator: 20 | """医疗报告生成器""" 21 | 22 | def __init__(self, llm: BaseLanguageModel): 23 | """ 24 | 初始化医疗报告生成器 25 | 26 | Args: 27 | llm: 语言模型,用于生成报告内容 28 | """ 29 | self.llm = llm 30 | 31 | def generate_ct_report( 32 | self, 33 | ct_analysis: Dict[str, Any], 34 | medical_knowledge: str, 35 | output_format: 
str = "markdown" 36 | ) -> Dict[str, Any]: 37 | """ 38 | 生成CT诊断报告 39 | 40 | Args: 41 | ct_analysis: CT分析结果 42 | medical_knowledge: 医学知识上下文 43 | output_format: 输出格式 ("markdown", "json", "dict") 44 | 45 | Returns: 46 | 包含报告内容的字典 47 | """ 48 | # 提取CT分析描述 49 | if "summary" in ct_analysis: 50 | ct_description = ct_analysis["summary"].get("combined_description", "") 51 | scan_region = ct_analysis["summary"].get("scan_region", "未知区域") 52 | elif "analysis" in ct_analysis: 53 | ct_description = ct_analysis["analysis"].get("combined_description", "") 54 | scan_region = "未知区域" 55 | else: 56 | ct_description = str(ct_analysis) 57 | scan_region = "未知区域" 58 | 59 | # 创建提示模板 60 | template = """ 61 | 作为一名经验丰富的放射科医师,请基于下面的CT图像分析结果和相关医学知识,生成一份专业的CT诊断报告。 62 | 你的报告应当包含影像发现、分析与解释、诊断意见和建议等部分。 63 | 64 | ## CT图像分析结果 65 | {ct_description} 66 | 67 | ## 相关医学知识 68 | {medical_knowledge} 69 | 70 | 请提供以下四个部分的内容,每部分都要详细专业: 71 | 72 | 1. 影像发现: (详细描述CT图像中观察到的客观发现) 73 | 74 | 2. 分析与解释: (对影像发现进行专业分析和解释) 75 | 76 | 3. 诊断意见: (给出可能的诊断,如有多种可能,请列出并标明优先级) 77 | 78 | 4. 
建议: (提供针对诊断的下一步建议,如进一步检查、随访或治疗方案) 79 | 80 | 请确保报告内容专业准确,使用医学术语,但同时要清晰易懂。 81 | """ 82 | 83 | # 创建提示 84 | prompt = PromptTemplate( 85 | template=template, 86 | input_variables=["ct_description", "medical_knowledge"] 87 | ) 88 | 89 | # 创建链 90 | chain = LLMChain(llm=self.llm, prompt=prompt) 91 | 92 | # 执行链 93 | result = chain.run( 94 | ct_description=ct_description, 95 | medical_knowledge=medical_knowledge 96 | ) 97 | 98 | # 提取各个部分 99 | report_sections = self._extract_report_sections(result) 100 | 101 | # 生成报告 102 | if output_format == "markdown" or output_format == "md": 103 | report_content = format_ct_report( 104 | image_findings=report_sections["影像发现"], 105 | analysis=report_sections["分析与解释"], 106 | diagnostic_opinion=report_sections["诊断意见"], 107 | recommendations=report_sections["建议"], 108 | examination_area=scan_region 109 | ) 110 | report_data = {"content": report_content, "format": "markdown"} 111 | else: 112 | report_data = { 113 | "image_findings": report_sections["影像发现"], 114 | "analysis": report_sections["分析与解释"], 115 | "diagnostic_opinion": report_sections["诊断意见"], 116 | "recommendations": report_sections["建议"], 117 | "examination_area": scan_region, 118 | "examination_date": datetime.datetime.now().strftime("%Y-%m-%d"), 119 | "report_date": datetime.datetime.now().strftime("%Y-%m-%d"), 120 | "format": output_format 121 | } 122 | 123 | return report_data 124 | 125 | def _extract_report_sections(self, report_text: str) -> Dict[str, str]: 126 | """ 127 | 从生成的文本中提取报告各个部分 128 | 129 | Args: 130 | report_text: 生成的报告文本 131 | 132 | Returns: 133 | 报告各部分内容的字典 134 | """ 135 | sections = { 136 | "影像发现": "", 137 | "分析与解释": "", 138 | "诊断意见": "", 139 | "建议": "" 140 | } 141 | 142 | # 查找各部分的位置 143 | section_markers = [ 144 | ("影像发现", ["1. 影像发现", "影像发现:", "影像发现:"]), 145 | ("分析与解释", ["2. 分析与解释", "分析与解释:", "分析与解释:"]), 146 | ("诊断意见", ["3. 诊断意见", "诊断意见:", "诊断意见:"]), 147 | ("建议", ["4. 
建议", "建议:", "建议:"]) 148 | ] 149 | 150 | # 初始化每个部分的起始位置 151 | positions = {} 152 | 153 | # 寻找每个部分的起始位置 154 | for section, markers in section_markers: 155 | for marker in markers: 156 | pos = report_text.find(marker) 157 | if pos >= 0: 158 | # 加上标记的长度,跳过标记本身 159 | positions[section] = pos + len(marker) 160 | break 161 | 162 | # 按位置排序节段 163 | sorted_sections = sorted(positions.items(), key=lambda x: x[1]) 164 | 165 | # 提取每个部分的内容 166 | for i, (section, start) in enumerate(sorted_sections): 167 | # 如果不是最后一个部分,则截取到下一个部分开始位置 168 | if i < len(sorted_sections) - 1: 169 | next_section, next_start = sorted_sections[i+1] 170 | section_content = report_text[start:next_start].strip() 171 | else: 172 | # 最后一个部分,截取到文本末尾 173 | section_content = report_text[start:].strip() 174 | 175 | sections[section] = section_content 176 | 177 | return sections 178 | 179 | def save_report( 180 | self, 181 | report_data: Dict[str, Any], 182 | output_path: str, 183 | format_type: Optional[str] = None 184 | ) -> str: 185 | """ 186 | 保存报告到文件 187 | 188 | Args: 189 | report_data: 报告数据 190 | output_path: 输出路径 191 | format_type: 输出格式类型,如果不指定则从report_data中获取 192 | 193 | Returns: 194 | 保存的文件路径 195 | """ 196 | # 确定输出格式 197 | if format_type is None: 198 | format_type = report_data.get("format", "markdown") 199 | 200 | # 根据格式保存 201 | if format_type == "markdown" or format_type == "md": 202 | if "content" in report_data: 203 | content = report_data["content"] 204 | else: 205 | content = format_ct_report( 206 | image_findings=report_data.get("image_findings", ""), 207 | analysis=report_data.get("analysis", ""), 208 | diagnostic_opinion=report_data.get("diagnostic_opinion", ""), 209 | recommendations=report_data.get("recommendations", ""), 210 | examination_area=report_data.get("examination_area", "未知区域"), 211 | examination_date=report_data.get("examination_date"), 212 | report_date=report_data.get("report_date") 213 | ) 214 | return save_report_to_markdown(content, output_path) 215 | else: 216 | return 
def generate_comparative_report(
    self,
    current_ct_analysis: Dict[str, Any],
    previous_ct_analysis: Dict[str, Any],
    medical_knowledge: str
) -> Dict[str, Any]:
    """
    Generate a comparative CT report contrasting a current and a previous scan.

    Both analyses are reduced to a text description, fed to the LLM together
    with the medical knowledge context, and the answer is split into sections
    and rendered through format_ct_report.

    Args:
        current_ct_analysis: Analysis result of the current scan.
        previous_ct_analysis: Analysis result of the earlier scan.
        medical_knowledge: Medical knowledge context inserted into the prompt.

    Returns:
        Dict with the rendered report under "content" and "format" set to
        "markdown".
    """
    def describe(ct_analysis: Dict[str, Any]) -> str:
        # Prefer the "summary" entry, then "analysis"; otherwise fall back to
        # the plain string form of the whole dict.
        for key in ("summary", "analysis"):
            if key in ct_analysis:
                return ct_analysis[key].get("combined_description", "")
        return str(ct_analysis)

    current_description = describe(current_ct_analysis)
    previous_description = describe(previous_ct_analysis)

    # Prompt template
    template = """
    作为一名经验丰富的放射科医师,请比较当前和之前的CT扫描结果,并生成一份对比分析报告。

    ## 当前CT结果
    {current_description}

    ## 之前CT结果
    {previous_description}

    ## 相关医学知识
    {medical_knowledge}

    请提供以下四个部分的内容:

    1. 影像对比发现: (对比当前和之前CT的差异,包括新发现和变化)

    2. 分析与解释: (对变化的专业分析和解释)

    3. 诊断意见: (基于对比结果给出诊断意见,说明疾病进展情况)

    4. 建议: (提供针对对比结果的建议)

    请使用专业医学术语,同时保持清晰易懂。
    """

    prompt = PromptTemplate(
        template=template,
        input_variables=["current_description", "previous_description", "medical_knowledge"]
    )

    chain = LLMChain(llm=self.llm, prompt=prompt)

    result = chain.run(
        current_description=current_description,
        previous_description=previous_description,
        medical_knowledge=medical_knowledge
    )

    # The prompt titles the first section "影像对比发现", but the parsed
    # sections are read back under the key "影像发现". Normalise the heading
    # before parsing so the findings section is not silently left empty.
    normalized_result = result.replace("影像对比发现", "影像发现")
    report_sections = self._extract_report_sections(normalized_result)

    # Scan region is only available on results that carry a "summary".
    scan_region = "未知区域"
    if "summary" in current_ct_analysis:
        scan_region = current_ct_analysis["summary"].get("scan_region", scan_region)

    report_content = format_ct_report(
        image_findings=report_sections["影像发现"],
        analysis=report_sections["分析与解释"],
        diagnostic_opinion=report_sections["诊断意见"],
        recommendations=report_sections["建议"],
        examination_area=scan_region,
        additional_info={"报告类型": "对比分析报告"}
    )

    return {"content": report_content, "format": "markdown"}
def analyze_image(self, image_path: str) -> Dict[str, Any]:
    """
    Analyse a single CT image, dispatching on the file format.

    A file is treated as DICOM when its name ends in ".dcm" / ".dicom"
    (case-insensitive) or when it carries the DICOM Part 10 signature: the
    bytes "DICM" directly after a 128-byte preamble. The signature check
    covers DICOM files exported without an extension. Everything else is
    handled as a regular image.

    Args:
        image_path: Path of the image file.

    Returns:
        The result dict produced by the DICOM or regular-image analysis.

    Raises:
        FileNotFoundError: If image_path does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    is_dicom = image_path.lower().endswith(('.dcm', '.dicom'))

    if not is_dicom:
        # Sniff the signature; an unreadable path (e.g. a directory) simply
        # falls through to the regular-image branch as before.
        try:
            with open(image_path, 'rb') as stream:
                stream.seek(128)
                is_dicom = stream.read(4) == b'DICM'
        except OSError:
            is_dicom = False

    if is_dicom:
        return self._analyze_dicom_image(image_path)
    return self._analyze_regular_image(image_path)
def _analyze_regular_image(self, image_path: str) -> Dict[str, Any]:
    """
    Analyse a CT image stored in a regular (non-DICOM) image format.

    The path is passed to preprocess_for_biomedclip and the result is handed
    to the BiomedCLIP analysis. Any error is printed and re-raised unchanged.

    Args:
        image_path: Path of the image file.

    Returns:
        Dict with "metadata" (file name and extension without the dot),
        "analysis" (the BiomedCLIP analysis result) and "image_type"
        set to "Regular".
    """
    try:
        model_input = preprocess_for_biomedclip(image_path)
        analysis = self._run_biomedclip_analysis(model_input)

        filename = os.path.basename(image_path)
        extension = os.path.splitext(image_path)[1]
        file_info = {
            "filename": filename,
            # Drop the leading dot of the extension.
            "file_type": extension[1:]
        }

        return {
            "metadata": file_info,
            "analysis": analysis,
            "image_type": "Regular"
        }
    except Exception as e:
        print(f"分析常规图像时出错: {e}")
        raise
inputs = self.processor( 192 | text=medical_descriptions, 193 | images=image_pil, 194 | return_tensors="pt", 195 | padding=True 196 | ).to(self.device) 197 | 198 | # 执行推理 199 | with torch.no_grad(): 200 | outputs = self.model(**inputs) 201 | 202 | # 计算图像文本相似度 203 | image_embeds = outputs.vision_model_output.pooler_output 204 | text_embeds = outputs.text_model_output.pooler_output 205 | 206 | # 归一化嵌入 207 | image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) 208 | text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) 209 | 210 | # 计算相似度分数 211 | logits_per_image = image_embeds @ text_embeds.t() 212 | probs = torch.nn.functional.softmax(logits_per_image, dim=-1) 213 | 214 | # 将结果转换为CPU并提取概率 215 | probs_list = probs.cpu().numpy().tolist()[0] 216 | 217 | # 整理结果 218 | results = [] 219 | for desc, prob in zip(medical_descriptions, probs_list): 220 | results.append({ 221 | "description": desc, 222 | "probability": prob 223 | }) 224 | 225 | # 按概率排序 226 | results = sorted(results, key=lambda x: x["probability"], reverse=True) 227 | 228 | # 生成综合描述 229 | top_descriptions = [r["description"] for r in results[:3]] 230 | combined_description = self._generate_combined_description(top_descriptions) 231 | 232 | # 构建最终结果 233 | analysis_result = { 234 | "top_matches": results[:5], 235 | "combined_description": combined_description, 236 | "abnormality_detected": "正常" not in results[0]["description"], 237 | "confidence": results[0]["probability"] 238 | } 239 | 240 | return analysis_result 241 | 242 | except Exception as e: 243 | print(f"运行BiomedCLIP分析时出错: {e}") 244 | raise 245 | 246 | def _generate_combined_description(self, top_descriptions: List[str]) -> str: 247 | """ 248 | 生成综合描述 249 | 250 | Args: 251 | top_descriptions: 概率最高的几个描述 252 | 253 | Returns: 254 | 综合描述文本 255 | """ 256 | # 如果第一个描述是正常的,而且概率很高,直接返回 257 | if "正常" in top_descriptions[0]: 258 | return top_descriptions[0] 259 | 260 | # 否则,组合前三个描述 261 | combined = "CT图像分析显示:" 262 | 263 | for i, desc 
in enumerate(top_descriptions): 264 | if i == 0: 265 | combined += desc 266 | else: 267 | # 移除句首,只保留关键信息 268 | cleaned_desc = desc.split(",", 1)[-1] if "," in desc else desc 269 | combined += f";另外可能{cleaned_desc}" 270 | 271 | return combined 272 | 273 | def analyze_multiple_images( 274 | self, 275 | image_paths: List[str] 276 | ) -> Dict[str, Any]: 277 | """ 278 | 分析多张CT图像并汇总结果 279 | 280 | Args: 281 | image_paths: 图像文件路径列表 282 | 283 | Returns: 284 | 汇总的分析结果 285 | """ 286 | # 存储每张图像的分析结果 287 | individual_results = [] 288 | 289 | # 分析每张图像 290 | for image_path in image_paths: 291 | try: 292 | result = self.analyze_image(image_path) 293 | individual_results.append(result) 294 | except Exception as e: 295 | print(f"分析图像 {image_path} 时出错: {e}") 296 | 297 | # 汇总结果 298 | summary = self._summarize_analysis_results(individual_results) 299 | 300 | # 构建完整结果 301 | return { 302 | "summary": summary, 303 | "individual_results": individual_results, 304 | "image_count": len(individual_results) 305 | } 306 | 307 | def _summarize_analysis_results( 308 | self, 309 | results: List[Dict[str, Any]] 310 | ) -> Dict[str, Any]: 311 | """ 312 | 汇总多个分析结果 313 | 314 | Args: 315 | results: 分析结果列表 316 | 317 | Returns: 318 | 汇总结果 319 | """ 320 | if not results: 321 | return {"error": "无有效分析结果"} 322 | 323 | # 收集所有描述 324 | all_descriptions = [] 325 | abnormality_detected = False 326 | 327 | for result in results: 328 | analysis = result.get("analysis", {}) 329 | if analysis: 330 | all_descriptions.append(analysis.get("combined_description", "")) 331 | if analysis.get("abnormality_detected", False): 332 | abnormality_detected = True 333 | 334 | # 简单合并所有描述 335 | combined_description = " ".join(all_descriptions) 336 | 337 | # 构建汇总结果 338 | summary = { 339 | "combined_description": combined_description, 340 | "abnormality_detected": abnormality_detected, 341 | "scan_region": self._determine_scan_region(results) 342 | } 343 | 344 | return summary 345 | 346 | def _determine_scan_region(self, results: 
def save_analysis_result(result: Dict[str, Any], output_path: str) -> None:
    """
    Save an analysis result to a UTF-8 JSON file.

    The result is serialised before the target file is opened, so a value
    that cannot be serialised raises without truncating an existing file.
    Missing parent directories of output_path are created.

    Args:
        result: Analysis result to store.
        output_path: Path of the JSON file to write.

    Raises:
        TypeError: If result contains a value json cannot serialise.
    """
    payload = json.dumps(result, ensure_ascii=False, indent=2)

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(payload)