├── .gitignore ├── LICENSE ├── README.md ├── assets └── chatpdf.png ├── chatpdf.py ├── complex_ui.py ├── logic.py ├── requirements.txt ├── sample.pdf └── simple_ui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .mindnlp/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 nate.river 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChatPDF 2 | 3 | The ChatPDF(PDF Chatbot) is an application that allows users to upload PDF files and interact with pdf using a chatbot. Users can ask questions or provide input, and the chatbot will generate responses based on the provided information. 4 | 5 | ## Technologies Used 6 | 7 | - MindSpore 8 | - MindNLP 9 | - ms2vec 10 | - msimilarities 11 | 12 | 13 | ## Demo Video 14 | 15 | [Demo Video]() 16 | 17 | [![ChatPDF](./assets/chatpdf.png)]() 18 | 19 | ## Installation 20 | 21 | 1. Clone the repository: 22 | 23 | ```bash 24 | git clone https://github.com/lvyufeng/ChatPDF.git 25 | ``` 26 | 27 | 2. Install the required dependencies: 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ## Usage 34 | 35 | 1. 
Run the application: 36 | 37 | ```bash 38 | # simple version 39 | python simple_ui.py 40 | # complex version 41 | python complex_ui.py 42 | ``` 43 | 44 | 2. Access the application in your web browser as specified in the console. 45 | 46 | 3. To preview a PDF file, click the "Upload PDF" button and select the PDF file from your local machine. The application will display a preview of the PDF file. 47 | 48 | 4. Use the chatbox to ask questions or have a conversation with the chatbot. The chatbot will generate responses based on the input. 49 | -------------------------------------------------------------------------------- /assets/chatpdf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lvyufeng/ChatPDF/fe0561d6b3a6245534d57ec487c8e0b7408a6b3e/assets/chatpdf.png -------------------------------------------------------------------------------- /chatpdf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import argparse 7 | import hashlib 8 | import os 9 | import re 10 | from threading import Thread 11 | from typing import Union, List 12 | 13 | import jieba 14 | from loguru import logger 15 | from mindnlp.peft import PeftModel 16 | from msimilarities import ( 17 | EnsembleSimilarity, 18 | BertSimilarity, 19 | BM25Similarity, 20 | ) 21 | from msimilarities.similarity import SimilarityABC 22 | from mindnlp.transformers import ( 23 | AutoModel, 24 | AutoModelForCausalLM, 25 | AutoTokenizer, 26 | BloomForCausalLM, 27 | BloomTokenizerFast, 28 | LlamaTokenizer, 29 | LlamaForCausalLM, 30 | TextIteratorStreamer, 31 | GenerationConfig, 32 | AutoModelForSequenceClassification, 33 | ) 34 | 35 | jieba.setLogLevel("ERROR") 36 | 37 | MODEL_CLASSES = { 38 | "bloom": (BloomForCausalLM, BloomTokenizerFast), 39 | "chatglm": (AutoModel, AutoTokenizer), 40 | "llama": (LlamaForCausalLM, LlamaTokenizer), 41 | "baichuan": (AutoModelForCausalLM, AutoTokenizer), 42 | "auto": (AutoModelForCausalLM, AutoTokenizer), 43 | } 44 | 45 | PROMPT_TEMPLATE = """基于以下已知信息,简洁和专业的来回答用户的问题。 46 | 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 47 | 48 | 已知内容: 49 | {context_str} 50 | 51 | 问题: 52 | {query_str} 53 | """ 54 | 55 | 56 | class SentenceSplitter: 57 | def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50): 58 | self.chunk_size = chunk_size 59 | self.chunk_overlap = chunk_overlap 60 | 61 | def split_text(self, text: str) -> List[str]: 62 | if self._is_has_chinese(text): 63 | return self._split_chinese_text(text) 64 | else: 65 | return self._split_english_text(text) 66 | 67 | def _split_chinese_text(self, text: str) -> List[str]: 68 | sentence_endings = {'\n', '。', '!', '?', ';', '…'} # sentence-ending punctuation 69 | chunks, current_chunk = [], '' 70 | for word in jieba.cut(text): 71 | if len(current_chunk) + len(word) > self.chunk_size: 72 | chunks.append(current_chunk.strip()) 73 | current_chunk = word 74 | else: 75 | current_chunk += word 76 | if word[-1] in sentence_endings and len(current_chunk) > self.chunk_size - self.chunk_overlap: 77 | chunks.append(current_chunk.strip()) 78 | current_chunk = '' 79 | if current_chunk: 80 | chunks.append(current_chunk.strip()) 81 | if self.chunk_overlap > 0 and len(chunks) > 1: 82 | chunks = self._handle_overlap(chunks) 83 | return chunks 84 | 85 | def _split_english_text(self, text: str) -> List[str]: 86 | # split English text into sentences using a regex 87 | sentences = re.split(r'(?<=[.!?])\s+',
text.replace('\n', ' ')) 88 | chunks, current_chunk = [], '' 89 | for sentence in sentences: 90 | if len(current_chunk) + len(sentence) <= self.chunk_size or not current_chunk: 91 | current_chunk += (' ' if current_chunk else '') + sentence 92 | else: 93 | chunks.append(current_chunk) 94 | current_chunk = sentence 95 | if current_chunk: # Add the last chunk 96 | chunks.append(current_chunk) 97 | 98 | if self.chunk_overlap > 0 and len(chunks) > 1: 99 | chunks = self._handle_overlap(chunks) 100 | 101 | return chunks 102 | 103 | def _is_has_chinese(self, text: str) -> bool: 104 | # check if the text contains Chinese characters 105 | if any("\u4e00" <= ch <= "\u9fff" for ch in text): 106 | return True 107 | else: 108 | return False 109 | 110 | def _handle_overlap(self, chunks: List[str]) -> List[str]: 111 | # create overlap between adjacent chunks 112 | overlapped_chunks = [] 113 | for i in range(len(chunks) - 1): 114 | chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap] 115 | overlapped_chunks.append(chunk.strip()) 116 | overlapped_chunks.append(chunks[-1]) 117 | return overlapped_chunks 118 | 119 | 120 | class ChatPDF: 121 | def __init__( 122 | self, 123 | similarity_model: SimilarityABC = None, 124 | generate_model_type: str = "auto", 125 | generate_model_name_or_path: str = "01ai/Yi-6B-Chat", 126 | lora_model_name_or_path: str = None, 127 | corpus_files: Union[str, List[str]] = None, 128 | save_corpus_emb_dir: str = "./corpus_embs/", 129 | int8: bool = False, 130 | int4: bool = False, 131 | chunk_size: int = 250, 132 | chunk_overlap: int = 0, 133 | rerank_model_name_or_path: str = None, 134 | enable_history: bool = False, 135 | num_expand_context_chunk: int = 2, 136 | similarity_top_k: int = 10, 137 | rerank_top_k: int = 3, 138 | ): 139 | """ 140 | Init RAG model. 141 | :param similarity_model: similarity model, default None, if set, will use it instead of EnsembleSimilarity 142 | :param generate_model_type: generate model type 143 | :param generate_model_name_or_path: generate model name or path 144 | :param lora_model_name_or_path: lora model name or path 145 | :param corpus_files: corpus files 146 | :param save_corpus_emb_dir: save corpus embeddings dir, default ./corpus_embs/ 147 | :param int8: use int8 quantization, default False 148 | :param int4: use int4 quantization, default False 149 | :param chunk_size: chunk size, default 250 150 | :param chunk_overlap: chunk overlap, default 0, cannot be set to > 0 if num_expand_context_chunk > 0 151 | :param rerank_model_name_or_path: rerank model name or path, default 'Xorbits/bge-reranker-base' 152 | :param enable_history: enable history, default False 153 | :param num_expand_context_chunk: num expand context chunk, default 2, if set to 0, will not expand context chunk 154 | :param similarity_top_k: similarity_top_k, default 10, similarity model searches top k corpus chunks 155 | :param rerank_top_k: rerank_top_k, default 3, rerank model searches top k corpus chunks 156 | """ 157 | if num_expand_context_chunk > 0 and chunk_overlap > 0: 158 | logger.warning(f" 'num_expand_context_chunk' and 'chunk_overlap' cannot both be greater than zero. 
" 159 | f" 'chunk_overlap' has been set to zero by default.") 160 | chunk_overlap = 0 161 | self.text_splitter = SentenceSplitter(chunk_size, chunk_overlap) 162 | if similarity_model is not None: 163 | self.sim_model = similarity_model 164 | else: 165 | m1 = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual") 166 | m2 = BM25Similarity() 167 | default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2) 168 | self.sim_model = default_sim_model 169 | self.gen_model, self.tokenizer = self._init_gen_model( 170 | generate_model_type, 171 | generate_model_name_or_path, 172 | peft_name=lora_model_name_or_path, 173 | int8=int8, 174 | int4=int4, 175 | ) 176 | self.history = [] 177 | self.corpus_files = corpus_files 178 | if corpus_files: 179 | self.add_corpus(corpus_files) 180 | self.save_corpus_emb_dir = save_corpus_emb_dir 181 | if rerank_model_name_or_path is None: 182 | rerank_model_name_or_path = "Xorbits/bge-reranker-base" 183 | if rerank_model_name_or_path: 184 | self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path, mirror='modelscope') 185 | self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path, mirror='modelscope') 186 | self.rerank_model.set_train(False) 187 | else: 188 | self.rerank_model = None 189 | self.rerank_tokenizer = None 190 | self.enable_history = enable_history 191 | self.similarity_top_k = similarity_top_k 192 | self.num_expand_context_chunk = num_expand_context_chunk 193 | self.rerank_top_k = rerank_top_k 194 | 195 | def __str__(self): 196 | return f"Similarity model: {self.sim_model}, Generate model: {self.gen_model}" 197 | 198 | def _init_gen_model( 199 | self, 200 | gen_model_type: str, 201 | gen_model_name_or_path: str, 202 | peft_name: str = None, 203 | int8: bool = False, 204 | int4: bool = False, 205 | ): 206 | """Init generate model.""" 207 | model_class, tokenizer_class = MODEL_CLASSES[gen_model_type] 208 | tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, mirror='modelscope') 209 | model = model_class.from_pretrained( 210 | gen_model_name_or_path, mirror='modelscope' 211 | ) 212 | try: 213 | model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, mirror='modelscope') 214 | except Exception as e: 215 | logger.warning(f"Failed to load generation config from {gen_model_name_or_path}, {e}") 216 | if peft_name: 217 | model = PeftModel.from_pretrained( 218 | model, 219 | peft_name, 220 | ) 221 | logger.info(f"Loaded peft model from {peft_name}") 222 | model.set_train(False) 223 | return model, tokenizer 224 | 225 | def _get_chat_input(self): 226 | messages = [] 227 | for conv in self.history: 228 | if conv and len(conv) > 0 and conv[0]: 229 | messages.append({'role': 'user', 'content': conv[0]}) 230 | if conv and len(conv) > 1 and conv[1]: 231 | messages.append({'role': 'assistant', 'content': conv[1]}) 232 | input_ids = self.tokenizer.apply_chat_template( 233 | conversation=messages, 234 | tokenize=True, 235 | add_generation_prompt=True, 236 | return_tensors='ms' 237 | ) 238 | return input_ids 239 | 240 | def stream_generate_answer( 241 | self, 242 | max_new_tokens=512, 243 | temperature=0.7, 244 | repetition_penalty=1.0, 245 | context_len=2048 246 | ): 247 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 248 | input_ids = self._get_chat_input() 249 | max_src_len = context_len - max_new_tokens - 8 250 | input_ids = input_ids[-max_src_len:] 251 | 
generation_kwargs = dict( 252 | input_ids=input_ids, 253 | max_new_tokens=max_new_tokens, 254 | temperature=temperature, 255 | do_sample=True, 256 | repetition_penalty=repetition_penalty, 257 | streamer=streamer, 258 | ) 259 | thread = Thread(target=self.gen_model.generate, kwargs=generation_kwargs) 260 | thread.start() 261 | 262 | yield from streamer 263 | 264 | def add_corpus(self, files: Union[str, List[str]]): 265 | """Load document files.""" 266 | if isinstance(files, str): 267 | files = [files] 268 | for doc_file in files: 269 | if doc_file.endswith('.pdf'): 270 | corpus = self.extract_text_from_pdf(doc_file) 271 | elif doc_file.endswith('.docx'): 272 | corpus = self.extract_text_from_docx(doc_file) 273 | elif doc_file.endswith('.md'): 274 | corpus = self.extract_text_from_markdown(doc_file) 275 | else: 276 | corpus = self.extract_text_from_txt(doc_file) 277 | full_text = '\n'.join(corpus) 278 | chunks = self.text_splitter.split_text(full_text) 279 | self.sim_model.add_corpus(chunks) 280 | self.corpus_files = files 281 | logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: " 282 | f"{list(self.sim_model.corpus.values())[:3]}") 283 | 284 | def reset_corpus(self, files: Union[str, List[str]]): 285 | """Load document files.""" 286 | if isinstance(files, str): 287 | files = [files] 288 | for doc_file in files: 289 | if doc_file.endswith('.pdf'): 290 | corpus = self.extract_text_from_pdf(doc_file) 291 | elif doc_file.endswith('.docx'): 292 | corpus = self.extract_text_from_docx(doc_file) 293 | elif doc_file.endswith('.md'): 294 | corpus = self.extract_text_from_markdown(doc_file) 295 | else: 296 | corpus = self.extract_text_from_txt(doc_file) 297 | full_text = '\n'.join(corpus) 298 | chunks = self.text_splitter.split_text(full_text) 299 | self.sim_model.reset_corpus(chunks) 300 | self.corpus_files = files 301 | logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: " 302 | f"{list(self.sim_model.corpus.values())[:3]}") 303 | 304 | 305 | @staticmethod 306 | def get_file_hash(fpaths): 307 | hasher = hashlib.md5() 308 | target_file_data = bytes() 309 | if isinstance(fpaths, str): 310 | fpaths = [fpaths] 311 | for fpath in fpaths: 312 | with open(fpath, 'rb') as file: 313 | chunk = file.read(1024 * 1024) # read only first 1MB 314 | hasher.update(chunk) 315 | target_file_data += chunk 316 | 317 | hash_name = hasher.hexdigest()[:32] 318 | return hash_name 319 | 320 | @staticmethod 321 | def extract_text_from_pdf(file_path: str): 322 | """Extract text content from a PDF file.""" 323 | import PyPDF2 324 | contents = [] 325 | with open(file_path, 'rb') as f: 326 | pdf_reader = PyPDF2.PdfReader(f) 327 | for page in pdf_reader.pages: 328 | page_text = page.extract_text().strip() 329 | raw_text = [text.strip() for text in page_text.splitlines() if text.strip()] 330 | new_text = '' 331 | for text in raw_text: 332 | new_text += text 333 | if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」', 334 | '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']: 335 | contents.append(new_text) 336 | new_text = '' 337 | if new_text: 338 | contents.append(new_text) 339 | return contents 340 | 341 | @staticmethod 342 | def extract_text_from_txt(file_path: str): 343 | """Extract text content from a TXT file.""" 344 | with open(file_path, 'r', encoding='utf-8') as f: 345 | contents = [text.strip() for text in f.readlines() if text.strip()] 346 | return contents 347 | 348 | @staticmethod 349 | def 
extract_text_from_docx(file_path: str): 350 | """Extract text content from a DOCX file.""" 351 | import docx 352 | document = docx.Document(file_path) 353 | contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()] 354 | return contents 355 | 356 | @staticmethod 357 | def extract_text_from_markdown(file_path: str): 358 | """Extract text content from a Markdown file.""" 359 | import markdown 360 | from bs4 import BeautifulSoup 361 | with open(file_path, 'r', encoding='utf-8') as f: 362 | markdown_text = f.read() 363 | html = markdown.markdown(markdown_text) 364 | soup = BeautifulSoup(html, 'html.parser') 365 | contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()] 366 | return contents 367 | 368 | @staticmethod 369 | def _add_source_numbers(lst): 370 | """Add source numbers to a list of strings.""" 371 | return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)] 372 | 373 | def _get_reranker_score(self, query: str, reference_results: List[str]): 374 | """Get reranker score.""" 375 | pairs = [] 376 | for reference in reference_results: 377 | pairs.append([query, reference]) 378 | inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='ms', max_length=512) 379 | scores = self.rerank_model(**inputs, return_dict=True).logits.view(-1, ).float() 380 | 381 | return scores 382 | 383 | def get_reference_results(self, query: str): 384 | """ 385 | Get reference results. 386 | 1. Similarity model get similar chunks 387 | 2. Rerank similar chunks 388 | 3. Expand reference context chunk 389 | :param query: 390 | :return: 391 | """ 392 | reference_results = [] 393 | sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k) 394 | # Get reference results from corpus 395 | hit_chunk_dict = dict() 396 | for query_id, id_score_dict in sim_contents.items(): 397 | for corpus_id, s in id_score_dict.items(): 398 | hit_chunk = self.sim_model.corpus[corpus_id] 399 | reference_results.append(hit_chunk) 400 | hit_chunk_dict[corpus_id] = hit_chunk 401 | 402 | if reference_results: 403 | if self.rerank_model is not None: 404 | # Rerank reference results 405 | rerank_scores = self._get_reranker_score(query, reference_results) 406 | logger.debug(f"rerank_scores: {rerank_scores}") 407 | # Get rerank top k chunks 408 | reference_results = [reference for reference, score in sorted( 409 | zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k] 410 | hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if 411 | hit_chunk in reference_results} 412 | # Expand reference context chunk 413 | if self.num_expand_context_chunk > 0: 414 | new_reference_results = [] 415 | for corpus_id, hit_chunk in hit_chunk_dict.items(): 416 | expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk 417 | for i in range(self.num_expand_context_chunk): 418 | expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '') 419 | new_reference_results.append(expanded_reference) 420 | reference_results = new_reference_results 421 | return reference_results 422 | 423 | def predict_stream( 424 | self, 425 | query: str, 426 | max_length: int = 512, 427 | context_len: int = 2048, 428 | temperature: float = 0.7, 429 | ): 430 | """Generate predictions stream.""" 431 | stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "" 432 | if not self.enable_history: 433 | self.history = [] 434 | if self.sim_model.corpus: 435 | 
reference_results = self.get_reference_results(query) 436 | if not reference_results: 437 | yield '没有提供足够的相关信息', reference_results 438 | reference_results = self._add_source_numbers(reference_results) 439 | context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))] 440 | prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query) 441 | logger.debug(f"prompt: {prompt}") 442 | else: 443 | prompt = query 444 | logger.debug(prompt) 445 | self.history.append([prompt, '']) 446 | response = "" 447 | for new_text in self.stream_generate_answer( 448 | max_new_tokens=max_length, 449 | temperature=temperature, 450 | context_len=context_len, 451 | ): 452 | if new_text != stop_str: 453 | response += new_text 454 | yield response 455 | 456 | def predict( 457 | self, 458 | query: str, 459 | max_length: int = 512, 460 | context_len: int = 2048, 461 | temperature: float = 0.7, 462 | ): 463 | """Query from corpus.""" 464 | reference_results = [] 465 | if not self.enable_history: 466 | self.history = [] 467 | if self.sim_model.corpus: 468 | reference_results = self.get_reference_results(query) 469 | 470 | if not reference_results: 471 | return '没有提供足够的相关信息', reference_results 472 | reference_results = self._add_source_numbers(reference_results) 473 | context_str = '\n'.join(reference_results)[:(context_len - len(PROMPT_TEMPLATE))] 474 | prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query) 475 | logger.debug(f"prompt: {prompt}") 476 | else: 477 | prompt = query 478 | self.history.append([prompt, '']) 479 | response = "" 480 | for new_text in self.stream_generate_answer( 481 | max_new_tokens=max_length, 482 | temperature=temperature, 483 | context_len=context_len, 484 | ): 485 | response += new_text 486 | response = response.strip() 487 | self.history[-1][1] = response 488 | return response, reference_results 489 | 490 | def save_corpus_emb(self): 491 | dir_name = self.get_file_hash(self.corpus_files) 492 | save_dir = os.path.join(self.save_corpus_emb_dir, dir_name) 493 | if hasattr(self.sim_model, 'save_corpus_embeddings'): 494 | self.sim_model.save_corpus_embeddings(save_dir) 495 | logger.debug(f"Saving corpus embeddings to {save_dir}") 496 | return save_dir 497 | 498 | def load_corpus_emb(self, emb_dir: str): 499 | if hasattr(self.sim_model, 'load_corpus_embeddings'): 500 | logger.debug(f"Loading corpus embeddings from {emb_dir}") 501 | self.sim_model.load_corpus_embeddings(emb_dir) 502 | 503 | 504 | if __name__ == "__main__": 505 | parser = argparse.ArgumentParser() 506 | parser.add_argument("--sim_model_name", type=str, default="shibing624/text2vec-base-multilingual") 507 | parser.add_argument("--gen_model_type", type=str, default="auto") 508 | parser.add_argument("--gen_model_name", type=str, default="01-ai/Yi-6B-Chat") 509 | parser.add_argument("--lora_model", type=str, default=None) 510 | parser.add_argument("--rerank_model_name", type=str, default="") 511 | parser.add_argument("--corpus_files", type=str, default="sample.pdf") 512 | parser.add_argument("--chunk_size", type=int, default=220) 513 | parser.add_argument("--chunk_overlap", type=int, default=0) 514 | parser.add_argument("--num_expand_context_chunk", type=int, default=1) 515 | args = parser.parse_args() 516 | print(args) 517 | sim_model = BertSimilarity(model_name_or_path=args.sim_model_name) 518 | m = ChatPDF( 519 | similarity_model=sim_model, 520 | generate_model_type=args.gen_model_type, 521 | generate_model_name_or_path=args.gen_model_name, 522 | 
lora_model_name_or_path=args.lora_model, 523 | chunk_size=args.chunk_size, 524 | chunk_overlap=args.chunk_overlap, 525 | corpus_files=args.corpus_files.split(','), 526 | num_expand_context_chunk=args.num_expand_context_chunk, 527 | rerank_model_name_or_path=args.rerank_model_name, 528 | ) 529 | r, refs = m.predict('自然语言中的非平行迁移是指什么?') 530 | print(r) 531 | print(refs) 532 | -------------------------------------------------------------------------------- /complex_ui.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logic import add_text, generate_response, render_file, clear_chatbot 3 | 4 | import gradio as gr 5 | 6 | # Gradio application setup 7 | def create_demo(): 8 | with gr.Blocks(title= " PDF Chatbot", 9 | theme = "Soft" # Change the theme here 10 | ) as demo: 11 | 12 | # Create a Gradio block 13 | 14 | with gr.Column(): 15 | with gr.Row(): 16 | chatbot = gr.Chatbot(value=[], elem_id='chatbot', height=600) 17 | show_img = gr.Image(label='PDF Preview', height=600) 18 | 19 | with gr.Row(): 20 | text_input = gr.Textbox( 21 | show_label=False, 22 | placeholder="Ask your pdf?", 23 | container=False, 24 | render=False) 25 | gr.Examples(["这篇论文试图解决什么问题?", "有哪些相关研究?", 26 | "论文如何解决这个问题?", "论文做了哪些实验?", 27 | "有什么可以进一步探索的点?", "总结一下本文的主要内容"], text_input) 28 | 29 | with gr.Row(): 30 | with gr.Column(scale=0.60): 31 | text_input.render() 32 | 33 | with gr.Column(scale=0.20): 34 | submit_btn = gr.Button('Send') 35 | 36 | with gr.Column(scale=0.20): 37 | upload_btn = gr.UploadButton("📁 Upload PDF", file_types=[".pdf"]) 38 | 39 | 40 | return demo, chatbot, show_img, text_input, submit_btn, upload_btn 41 | 42 | demo, chatbot, show_img, txt, submit_btn, btn = create_demo() 43 | 44 | # Set up event handlers 45 | with demo: 46 | # Event handler for uploading a PDF 47 | btn.upload(render_file, inputs=[btn], outputs=[show_img]).success(clear_chatbot, outputs=[chatbot]) 48 | 49 | # Event handler for submitting text and generating response 50 | submit_btn.click(add_text, inputs=[chatbot, txt], outputs=[chatbot], queue=False).\ 51 | success(generate_response, inputs=[chatbot, txt, btn], outputs=[chatbot, txt]) 52 | if __name__ == "__main__": 53 | demo.launch() 54 | -------------------------------------------------------------------------------- /logic.py: -------------------------------------------------------------------------------- 1 | import fitz 2 | from PIL import Image 3 | import gradio as gr 4 | from chatpdf import ChatPDF 5 | 6 | model = ChatPDF() 7 | # Function to add text to the chat history 8 | def add_text(history, text): 9 | """ 10 | Adds the user's input text to the chat history. 11 | 12 | Args: 13 | history (list): List of tuples representing the chat history. 14 | text (str): The user's input text. 15 | 16 | Returns: 17 | list: Updated chat history with the new user input. 18 | """ 19 | if not text: 20 | raise gr.Error('Enter text') 21 | history.append((text, '')) 22 | return history 23 | 24 | 25 | def predict_stream(message, history): 26 | history_format = [] 27 | for human, assistant in history: 28 | history_format.append([human, assistant]) 29 | model.history = history_format 30 | for chunk in model.predict_stream(message): 31 | yield chunk 32 | 33 | # Function to generate a response based on the chat history and query 34 | def generate_response(history, query, btn): 35 | """ 36 | Generates a response based on the chat history and user's query. 37 | 38 | Args: 39 | history (list): List of tuples representing the chat history. 
40 | query (str): The user's query. 41 | btn (FileStorage): The uploaded PDF file. 42 | 43 | Returns: 44 | tuple: Updated chat history with the generated response and the next page number. 45 | """ 46 | if not btn: 47 | raise gr.Error(message='Upload a PDF') 48 | 49 | history_format = [] 50 | for human, assistant in history: 51 | history_format.append([human, assistant]) 52 | model.history = history_format 53 | for chunk in model.predict_stream(query): 54 | history[-1][-1] = chunk 55 | yield history, " " 56 | 57 | # Function to render a specific page of a PDF file as an image 58 | def render_file(file): 59 | """ 60 | Renders a specific page of a PDF file as an image. 61 | 62 | Args: 63 | file (FileStorage): The PDF file. 64 | 65 | Returns: 66 | PIL.Image.Image: The rendered page as an image. 67 | """ 68 | # global n 69 | model.reset_corpus(file) 70 | doc = fitz.open(file.name) 71 | page = doc[0] 72 | # Render the page as a PNG image with a resolution of 300 DPI 73 | pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72)) 74 | image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples) 75 | return image 76 | 77 | def clear_chatbot(): 78 | return [] 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mindspore 2 | mindnlp 3 | PyMuPDF 4 | ms2vec 5 | msimilarities 6 | loguru 7 | jieba 8 | gradio 9 | PyPDF2 -------------------------------------------------------------------------------- /sample.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lvyufeng/ChatPDF/fe0561d6b3a6245534d57ec487c8e0b7408a6b3e/sample.pdf -------------------------------------------------------------------------------- /simple_ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author:XuMing(xuming624@qq.com) 4 | @description: 5 | """ 6 | import argparse 7 | import os 8 | 9 | import gradio as gr 10 | from loguru import logger 11 | 12 | from chatpdf import ChatPDF 13 | 14 | pwd_path = os.path.abspath(os.path.dirname(__file__)) 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--gen_model_type", type=str, default="auto") 19 | parser.add_argument("--gen_model_name", type=str, default="01-ai/Yi-6B-Chat") 20 | parser.add_argument("--lora_model", type=str, default=None) 21 | parser.add_argument("--rerank_model_name", type=str, default=None) 22 | parser.add_argument("--corpus_files", type=str, default="sample.pdf") 23 | parser.add_argument("--int4", action='store_true', help="use int4 quantization") 24 | parser.add_argument("--int8", action='store_true', help="use int8 quantization") 25 | parser.add_argument("--chunk_size", type=int, default=220) 26 | parser.add_argument("--chunk_overlap", type=int, default=0) 27 | parser.add_argument("--num_expand_context_chunk", type=int, default=1) 28 | parser.add_argument("--server_name", type=str, default="0.0.0.0") 29 | parser.add_argument("--server_port", type=int, default=8082) 30 | parser.add_argument("--share", action='store_true', help="share model") 31 | args = parser.parse_args() 32 | logger.info(args) 33 | 34 | model = ChatPDF( 35 | generate_model_type=args.gen_model_type, 36 | generate_model_name_or_path=args.gen_model_name, 37 | lora_model_name_or_path=args.lora_model, 38 | corpus_files=args.corpus_files.split(','), 39 | 
int4=args.int4, 40 | int8=args.int8, 41 | chunk_size=args.chunk_size, 42 | chunk_overlap=args.chunk_overlap, 43 | num_expand_context_chunk=args.num_expand_context_chunk, 44 | rerank_model_name_or_path=args.rerank_model_name, 45 | ) 46 | logger.info(f"chatpdf model: {model}") 47 | 48 | 49 | def predict_stream(message, history): 50 | history_format = [] 51 | for human, assistant in history: 52 | history_format.append([human, assistant]) 53 | model.history = history_format 54 | for chunk in model.predict_stream(message): 55 | yield chunk 56 | 57 | 58 | def predict(message, history): 59 | logger.debug(message) 60 | response, reference_results = model.predict(message) 61 | r = response + "\n\n" + '\n'.join(reference_results) 62 | logger.debug(r) 63 | return r 64 | 65 | 66 | chatbot_stream = gr.Chatbot( 67 | height=600, 68 | avatar_images=( 69 | os.path.join(pwd_path, "assets/user.png"), 70 | os.path.join(pwd_path, "assets/llama.png"), 71 | ), bubble_full_width=False) 72 | title = " 🎉ChatPDF WebUI🎉 " 73 | description = "Link in Github: [lvyufeng/ChatPDF](https://github.com/lvyufeng/ChatPDF)" 74 | css = """.toast-wrap { display: none !important } """ 75 | examples = ['Can you tell me about the NLP?', '介绍下NLP'] 76 | chat_interface_stream = gr.ChatInterface( 77 | predict_stream, 78 | textbox=gr.Textbox(lines=4, placeholder="Ask me question", scale=7), 79 | title=title, 80 | description=description, 81 | chatbot=chatbot_stream, 82 | css=css, 83 | examples=examples, 84 | theme='soft', 85 | ) 86 | 87 | with gr.Blocks() as demo: 88 | chat_interface_stream.render() 89 | demo.queue().launch( 90 | server_name=args.server_name, server_port=args.server_port, share=args.share 91 | ) 92 | --------------------------------------------------------------------------------
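For completeness, here is a minimal sketch (not part of the repository files above) of how the `ChatPDF` class defined in `chatpdf.py` can be driven programmatically, mirroring the `__main__` block of that file. The model name, chunk settings, and the example questions below are illustrative assumptions, not fixed requirements.

```python
# Minimal usage sketch: build the RAG pipeline from chatpdf.py over a single PDF.
# The model name, chunk settings and questions are illustrative assumptions; they
# follow the defaults used by the repository's own CLI arguments.
from chatpdf import ChatPDF

model = ChatPDF(
    generate_model_type="auto",                      # resolved via MODEL_CLASSES to AutoModelForCausalLM/AutoTokenizer
    generate_model_name_or_path="01-ai/Yi-6B-Chat",  # same default as the argparse options
    corpus_files="sample.pdf",                       # split into chunks and indexed at init time
    chunk_size=220,
    chunk_overlap=0,
    num_expand_context_chunk=1,
)

# Blocking call: retrieve similar chunks, rerank them, build the prompt, generate an answer.
response, references = model.predict("What problem does this paper try to solve?")
print(response)
print(references)

# Streaming variant: predict_stream yields the progressively accumulated answer text.
answer = ""
for answer in model.predict_stream("Summarize the main contributions of the paper."):
    pass
print(answer)
```

This is the same pattern that `simple_ui.py` and `logic.py` rely on: they construct one `ChatPDF` instance and stream answers from `predict_stream` into the Gradio chat view.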