├── tests ├── __init__.py ├── resources │ └── __init__.py ├── conftest.py ├── test_grobid_processors.py └── test_document_qa_engine.py ├── pytest.ini ├── docs └── images │ ├── screenshot1.png │ └── screenshot2.png ├── .streamlit └── config.toml ├── .gitignore ├── Dockerfile ├── requirements.txt ├── pyproject.toml ├── .devcontainer └── devcontainer.json ├── .gitattributes ├── document_qa ├── custom_embeddings.py ├── deployment │ ├── modal_inference_qwen.py │ ├── modal_inference_phi.py │ └── modal_embeddings.py ├── langchain.py ├── ner_client_generic.py ├── document_qa_engine.py └── grobid_processors.py ├── .github └── workflows │ ├── ci-build.yml │ └── ci-release.yml ├── CHANGELOG.md ├── README.md ├── LICENSE └── streamlit_app.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests -------------------------------------------------------------------------------- /tests/resources/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | TEST_DATA_PATH = os.path.dirname(__file__) 4 | -------------------------------------------------------------------------------- /docs/images/screenshot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfoppiano/document-qa/HEAD/docs/images/screenshot1.png -------------------------------------------------------------------------------- /docs/images/screenshot2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lfoppiano/document-qa/HEAD/docs/images/screenshot2.png -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [logger] 2 | level = "info" 3 | 4 | [browser] 5 | gatherUsageStats = true 6 | 7 | [ui] 8 | hideTopBar = true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .env 3 | .env.docker 4 | **/**/.chroma 5 | resources/db 6 | build 7 | dist 8 | __pycache__ 9 | document_qa/__pycache__ 10 | document_qa_engine.egg-info/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | WORKDIR /app 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | && rm -rf /var/lib/apt/lists/* 10 | 11 | COPY requirements.txt . 12 | 13 | RUN pip3 install -r requirements.txt 14 | 15 | COPY .streamlit ./.streamlit 16 | COPY document_qa ./document_qa 17 | COPY streamlit_app.py . 18 | 19 | # extract version 20 | COPY .git ./.git 21 | RUN git rev-parse --short HEAD > revision.txt 22 | RUN rm -rf ./.git 23 | 24 | EXPOSE 8501 25 | 26 | HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health 27 | 28 | ENV PYTHONPATH "${PYTHONPATH}:." 
29 | 30 | ENTRYPOINT ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"] 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Grobid 2 | grobid-quantities-client==0.4.0 3 | grobid-client-python==0.0.9 4 | grobid-tei-xml==0.1.3 5 | 6 | # Utils 7 | tqdm==4.66.3 8 | pyyaml==6.0.1 9 | pytest==8.1.1 10 | streamlit==1.45.1 11 | lxml==5.2.1 12 | beautifulsoup4==4.12.3 13 | python-dotenv==1.0.1 14 | watchdog==4.0.0 15 | dateparser==1.2.0 16 | requests>=2.31.0 17 | numpy==1.26.4 18 | 19 | # LLM 20 | chromadb==0.4.24 21 | tiktoken==0.9.0 22 | openai==1.82.0 23 | langchain==0.3.25 24 | langchain-core==0.3.61 25 | langchain-openai==0.3.18 26 | langchain-huggingface==0.2.0 27 | langchain-community==0.3.21 28 | typing-inspect==0.9.0 29 | typing_extensions==4.12.2 30 | pydantic==2.10.6 31 | sentence-transformers==2.6.1 32 | streamlit-pdf-viewer==0.0.25 33 | umap-learn==0.5.6 34 | plotly==5.20.0 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from unittest.mock import MagicMock 4 | 5 | import pytest 6 | from _pytest._py.path import LocalPath 7 | 8 | # derived from https://github.com/elifesciences/sciencebeam-trainer-delft/tree/develop/tests 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | @pytest.fixture(scope='session', autouse=True) 14 | def setup_logging(): 15 | logging.root.handlers = [] 16 | logging.basicConfig(level='INFO') 17 | logging.getLogger('tests').setLevel('DEBUG') 18 | # logging.getLogger('sciencebeam_trainer_delft').setLevel('DEBUG') 19 | 20 | 21 | def _backport_assert_called(mock: MagicMock): 22 | assert mock.called 23 | 24 | 25 | @pytest.fixture(scope='session', autouse=True) 26 | def patch_magicmock(): 27 | try: 28 | MagicMock.assert_called 29 | except AttributeError: 30 | MagicMock.assert_called = _backport_assert_called 31 | 32 | 33 | @pytest.fixture 34 | def temp_dir(tmpdir: LocalPath): 35 | # convert to standard Path 36 | return Path(str(tmpdir)) 37 | 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.bumpversion] 6 | current_version = "0.5.1" 7 | commit = "true" 8 | tag = "true" 9 | tag_name = "v{new_version}" 10 | 11 | #[[tool.bumpversion.files]] 12 | #filename = "version.txt" 13 | #search = "{current_version}" 14 | #replace = "{new_version}" 15 | 16 | [project] 17 | name = "document-qa-engine" 18 | license = { file = "LICENSE" } 19 | authors = [ 20 | { name = "Luca Foppiano", email = "lucanoro@duck.com" }, 21 | ] 22 | maintainers = [ 23 | { name = "Luca Foppiano", email = "lucanoro@duck.com" } 24 | ] 25 | description = "Scientific Document Insight Q/A" 26 | readme = "README.md" 27 | 28 | dynamic = ['version', "dependencies"] 29 | 30 | [tool.setuptools] 31 | license-files = [] 32 | 33 | [tool.setuptools.dynamic] 34 | dependencies = {file = ["requirements.txt"]} 35 | 36 | [tool.setuptools_scm] 37 | 38 | [project.urls] 39 | Homepage = "https://document-insights.streamlit.app" 40 | Repository = "https://github.com/lfoppiano/document-qa" 41 | Changelog = 
"https://github.com/lfoppiano/document-qa/blob/main/CHANGELOG.md" -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Python 3", 3 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 4 | "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", 5 | "customizations": { 6 | "codespaces": { 7 | "openFiles": [ 8 | "README.md", 9 | "streamlit_app.py" 10 | ] 11 | }, 12 | "vscode": { 13 | "settings": {}, 14 | "extensions": [ 15 | "ms-python.python", 16 | "ms-python.vscode-pylance" 17 | ] 18 | } 19 | }, 20 | "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y List[List[str]]: 13 | # We remove newlines from the text to avoid issues with the embedding model. 14 | cleaned_text = [t.replace("\n", " ") for t in text] 15 | 16 | payload = {'text': "\n".join(cleaned_text)} 17 | 18 | headers = {} 19 | if self.api_key: 20 | headers = {'x-api-key': self.api_key} 21 | 22 | response = requests.post( 23 | self.url, 24 | data=payload, 25 | files=[], 26 | headers=headers 27 | ) 28 | response.raise_for_status() 29 | 30 | # print(response.text) 31 | return response.json() 32 | 33 | def embed_documents(self, text: List[str]) -> List[List[str]]: 34 | """ 35 | Embed a list of documents using the embedding model. 36 | """ 37 | return self.embed(text) 38 | 39 | def embed_query(self, text: str) -> List[str]: 40 | """ 41 | Embed a query 42 | """ 43 | return self.embed([text])[0] 44 | 45 | def get_model_name(self) -> str: 46 | return self.model_name 47 | 48 | 49 | if __name__ == "__main__": 50 | embeds = ModalEmbeddings( 51 | url="https://lfoppiano--intfloat-multilingual-e5-large-instruct-embed-5da184.modal.run/", 52 | model_name="intfloat/multilingual-e5-large-instruct" 53 | ) 54 | 55 | print(embeds.embed( 56 | ["We are surrounded by stupid kids", 57 | "We are interested in the future of AI"] 58 | )) 59 | -------------------------------------------------------------------------------- /document_qa/deployment/modal_inference_qwen.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import modal 4 | 5 | vllm_image = ( 6 | modal.Image.debian_slim(python_version="3.11") 7 | .pip_install( 8 | "vllm", 9 | "transformers>=4.51.0", 10 | "huggingface_hub[hf_transfer]>=0.26.2", 11 | "flashinfer-python==0.2.0.post2", # pinning, very unstable 12 | extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", 13 | ) 14 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers 15 | ) 16 | 17 | MODELS_DIR = "/llamas" 18 | MODEL_NAME = "Qwen/Qwen3-0.6B" 19 | MODEL_REVISION = "e6de91484c29aa9480d55605af694f39b081c455" 20 | 21 | hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) 22 | vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) 23 | 24 | 25 | app = modal.App("gwen-0.6b-qa-vllm") 26 | 27 | N_GPU = 1 28 | MINUTES = 60 # seconds 29 | VLLM_PORT = 8000 30 | 31 | 32 | @app.function( 33 | image=vllm_image, 34 | # gpu=f"L40S:{N_GPU}", 35 | gpu=f"A10G:{N_GPU}", 36 | # how long should we stay up with no requests? 
37 | scaledown_window=5 * MINUTES, 38 | volumes={ 39 | "/root/.cache/huggingface": hf_cache_vol, 40 | "/root/.cache/vllm": vllm_cache_vol, 41 | }, 42 | secrets=[modal.Secret.from_name("document-qa-api-key")] 43 | ) 44 | @modal.concurrent( 45 | max_inputs=5 46 | ) # how many requests can one replica handle? tune carefully! 47 | @modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES) 48 | def serve(): 49 | import subprocess 50 | 51 | cmd = [ 52 | "vllm", 53 | "serve", 54 | "--uvicorn-log-level=info", 55 | MODEL_NAME, 56 | "--revision", 57 | MODEL_REVISION, 58 | "--enable-reasoning", 59 | "--reasoning-parser", 60 | "deepseek_r1", 61 | "--max-model-len", 62 | "32768", 63 | "--host", 64 | "0.0.0.0", 65 | "--port", 66 | str(VLLM_PORT), 67 | "--api-key", 68 | os.environ["API_KEY"], 69 | ] 70 | 71 | subprocess.Popen(" ".join(cmd), shell=True) -------------------------------------------------------------------------------- /.github/workflows/ci-build.yml: -------------------------------------------------------------------------------- 1 | name: Build unstable 2 | 3 | on: [push] 4 | 5 | concurrency: 6 | group: unstable 7 | # cancel-in-progress: true 8 | 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [ 3.8, 3.9, '3.10', '3.11' ] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | cache: 'pip' 24 | - name: Cleanup more disk space 25 | run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install --upgrade flake8 pytest pycodestyle pytest-cov 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | 41 | docker-build: 42 | needs: [build] 43 | runs-on: ubuntu-latest 44 | 45 | steps: 46 | - name: Create more disk space 47 | run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY" 48 | - uses: actions/checkout@v2 49 | - name: Build and push 50 | id: docker_build 51 | uses: mr-smithers-excellent/docker-build-push@v5 52 | with: 53 | username: ${{ secrets.DOCKERHUB_USERNAME }} 54 | password: ${{ secrets.DOCKERHUB_TOKEN }} 55 | image: lfoppiano/document-insights-qa 56 | registry: docker.io 57 | pushImage: ${{ github.event_name != 'pull_request' }} 58 | tags: latest-develop 59 | - name: Image digest 60 | run: echo ${{ steps.docker_build.outputs.digest }} -------------------------------------------------------------------------------- /document_qa/deployment/modal_inference_phi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import modal 4 | 5 | vllm_image = ( 6 | modal.Image.debian_slim(python_version="3.10") 7 | .pip_install( 8 | "vllm", 9 | "huggingface_hub[hf_transfer]==0.26.2", 10 | "flashinfer-python==0.2.0.post2", # pinning, very unstable 11 | extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", 12 | ) 13 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers 14 | ) 15 | 16 | MODELS_DIR = "/llamas" 17 | MODEL_NAME = "microsoft/Phi-4-mini-instruct" 18 | MODEL_REVISION = "c0fb9e74abda11b496b7907a9c6c9009a7a0488f" 19 | 20 | FAST_BOOT = True 21 | 22 | hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) 23 | vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) 24 | 25 | 26 | app = modal.App("phi-4-mini-instruct-qa-vllm") 27 | 28 | N_GPU = 1 29 | MINUTES = 60 # seconds 30 | VLLM_PORT = 8000 31 | 32 | 33 | @app.function( 34 | image=vllm_image, 35 | # gpu=f"L40S:{N_GPU}", 36 | gpu=f"A10G:{N_GPU}", 37 | # how long should we stay up with no requests? 38 | scaledown_window=5 * MINUTES, 39 | volumes={ 40 | "/root/.cache/huggingface": hf_cache_vol, 41 | "/root/.cache/vllm": vllm_cache_vol, 42 | }, 43 | secrets=[modal.Secret.from_name("document-qa-api-key")] 44 | ) 45 | @modal.concurrent( 46 | max_inputs=5 47 | ) # how many requests can one replica handle? tune carefully! 48 | @modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES) 49 | def serve(): 50 | import subprocess 51 | 52 | cmd = [ 53 | "vllm", 54 | "serve", 55 | "--uvicorn-log-level=info", 56 | MODEL_NAME, 57 | "--revision", 58 | MODEL_REVISION, 59 | "--max-model-len", 60 | "32768", 61 | "--host", 62 | "0.0.0.0", 63 | "--port", 64 | str(VLLM_PORT), 65 | "--api-key", 66 | os.environ["API_KEY"], 67 | ] 68 | 69 | # enforce-eager disables both Torch compilation and CUDA graph capture 70 | # default is no-enforce-eager. 
see the --compilation-config flag for tighter control 71 | cmd += ["--enforce-eager" if FAST_BOOT else "--no-enforce-eager"] 72 | 73 | # assume multiple GPUs are for splitting up large matrix multiplications 74 | cmd += ["--tensor-parallel-size", str(N_GPU)] 75 | 76 | subprocess.Popen(" ".join(cmd), shell=True) -------------------------------------------------------------------------------- /tests/test_document_qa_engine.py: -------------------------------------------------------------------------------- 1 | from document_qa.document_qa_engine import TextMerger 2 | 3 | 4 | def test_merge_passages_small_chunk(): 5 | merger = TextMerger() 6 | 7 | passages = [ 8 | { 9 | 'text': "The quick brown fox jumps over the tree", 10 | 'coordinates': '1' 11 | }, 12 | { 13 | 'text': "and went straight into the mouth of a bear.", 14 | 'coordinates': '2' 15 | }, 16 | { 17 | 'text': "The color of the colors is a color with colors", 18 | 'coordinates': '3' 19 | }, 20 | { 21 | 'text': "the main colors are not the colorw we show", 22 | 'coordinates': '4' 23 | } 24 | ] 25 | new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0) 26 | 27 | assert len(new_passages) == 4 28 | assert new_passages[0]['coordinates'] == "1" 29 | assert new_passages[0]['text'] == "The quick brown fox jumps over the tree" 30 | 31 | assert new_passages[1]['coordinates'] == "2" 32 | assert new_passages[1]['text'] == "and went straight into the mouth of a bear." 33 | 34 | assert new_passages[2]['coordinates'] == "3" 35 | assert new_passages[2]['text'] == "The color of the colors is a color with colors" 36 | 37 | assert new_passages[3]['coordinates'] == "4" 38 | assert new_passages[3]['text'] == "the main colors are not the colorw we show" 39 | 40 | 41 | def test_merge_passages_big_chunk(): 42 | merger = TextMerger() 43 | 44 | passages = [ 45 | { 46 | 'text': "The quick brown fox jumps over the tree", 47 | 'coordinates': '1' 48 | }, 49 | { 50 | 'text': "and went straight into the mouth of a bear.", 51 | 'coordinates': '2' 52 | }, 53 | { 54 | 'text': "The color of the colors is a color with colors", 55 | 'coordinates': '3' 56 | }, 57 | { 58 | 'text': "the main colors are not the colorw we show", 59 | 'coordinates': '4' 60 | } 61 | ] 62 | new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0) 63 | 64 | assert len(new_passages) == 2 65 | assert new_passages[0]['coordinates'] == "1;2" 66 | assert new_passages[0][ 67 | 'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear." 
68 | 69 | assert new_passages[1]['coordinates'] == "3;4" 70 | assert new_passages[1][ 71 | 'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show" 72 | -------------------------------------------------------------------------------- /.github/workflows/ci-release.yml: -------------------------------------------------------------------------------- 1 | name: Build release 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | tags: 7 | - 'v*' 8 | 9 | concurrency: 10 | group: docker 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.9" 22 | cache: 'pip' 23 | - name: Cleanup more disk space 24 | run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY" 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install --upgrade flake8 pytest pycodestyle 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | - name: Lint with flake8 31 | run: | 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 35 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | # - name: Test with pytest 37 | # run: | 38 | # pytest 39 | 40 | - name: Build and Publish to PyPI 41 | uses: conchylicultor/pypi-build-publish@v1 42 | with: 43 | pypi-token: ${{ secrets.PYPI_API_TOKEN }} 44 | 45 | 46 | docker-build: 47 | needs: [build] 48 | runs-on: ubuntu-latest 49 | 50 | steps: 51 | - name: Set tags 52 | id: set_tags 53 | run: | 54 | DOCKER_IMAGE=lfoppiano/document-insights-qa 55 | VERSION="" 56 | if [[ $GITHUB_REF == refs/tags/v* ]]; then 57 | VERSION=${GITHUB_REF#refs/tags/v} 58 | fi 59 | if [[ $VERSION =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then 60 | TAGS="${VERSION}" 61 | else 62 | TAGS="latest" 63 | fi 64 | echo "TAGS=${TAGS}" 65 | echo ::set-output name=tags::${TAGS} 66 | - name: Create more disk space 67 | run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY" 68 | - uses: actions/checkout@v2 69 | - name: Build and push 70 | id: docker_build 71 | uses: mr-smithers-excellent/docker-build-push@v5 72 | with: 73 | username: ${{ secrets.DOCKERHUB_USERNAME }} 74 | password: ${{ secrets.DOCKERHUB_TOKEN }} 75 | image: lfoppiano/document-insights-qa 76 | registry: docker.io 77 | pushImage: ${{ github.event_name != 'pull_request' }} 78 | tags: ${{ steps.set_tags.outputs.tags }} 79 | - name: Image digest 80 | run: echo ${{ steps.docker_build.outputs.digest }} -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 
6 | 7 | ## [0.4.2] - 2024-08-23 8 | 9 | ### Fixed 10 | + Correct invalid dependency of promptlayer slipped in the build 11 | 12 | ## [0.4.1] - 2024-08-23 13 | 14 | ### Added 15 | + Scroll to the first relevant context passage, if the most relevant context passage is at the end, it will scroll to the end of the document 16 | + Added Mistral NEMO as default model 17 | 18 | ### Changed 19 | + Rearranged the interface to get more space 20 | + Updated libraries to the latest versions 21 | 22 | ### Fixed 23 | + Fixed the chat messages sequence that were buggy 24 | + Updated the PDF viewer to the latest version 25 | 26 | ## [0.4.0] - 2024-06-24 27 | 28 | ### Added 29 | + Add selection of embedding functions 30 | + Add selection of text from the pdf viewer (provided by https://github.com/lfoppiano/streamlit-pdf-viewer) 31 | + Added an experimental feature for calculating the coefficient that relate the question and the embedding database 32 | + Added the data availability statement in the searchable text 33 | 34 | ### Changed 35 | + Removed obsolete and non-working models zephyr and mistral v0.1 36 | + The underlying library was refactored to make it easier to maintain 37 | + Removed the native PDF viewer 38 | + Updated langchain and streamlit to the latest versions 39 | + Removed conversational memory which was causing more problems than bringing benefits 40 | + Rearranged the interface to get more space 41 | 42 | ### Fixed 43 | + Updated and removed models that were not working 44 | + Fixed problems with langchain and other libraries 45 | 46 | ## [0.3.4] - 2023-12-26 47 | 48 | ### Added 49 | 50 | + Add gpt4 and gpt4-turbo 51 | 52 | ### Changed 53 | 54 | + improved UI: replace combo boxes with dropdown box 55 | 56 | ### Fixed 57 | 58 | + Fixed dependencies when installing as library 59 | 60 | ## [0.3.3] - 2023-12-14 61 | 62 | ### Added 63 | 64 | + Add experimental PDF rendering in the page 65 | 66 | ### Fixed 67 | 68 | + Fix GrobidProcessors API implementation 69 | 70 | ## [0.3.2] - 2023-12-01 71 | 72 | ### Fixed 73 | 74 | + Remove memory when using Zephyr-7b-beta, that easily hallucinate 75 | 76 | ## [0.3.1] - 2023-11-22 77 | 78 | ### Added 79 | 80 | + Include biblio in embeddings by @lfoppiano in #21 81 | 82 | ### Fixed 83 | 84 | + Fix conversational memory by @lfoppiano in #20 85 | 86 | ## [0.3.0] - 2023-11-18 87 | 88 | ### Added 89 | 90 | + add zephyr-7b by @lfoppiano in #15 91 | + add conversational memory in #18 92 | 93 | ## [0.2.1] - 2023-11-01 94 | 95 | ### Fixed 96 | 97 | + fix env variables by @lfoppiano in #9 98 | 99 | ## [0.2.0] – 2023-10-31 100 | 101 | ### Added 102 | 103 | + Selection of chunk size on which embeddings are created upon 104 | + Mistral model to be used freely via the Huggingface free API 105 | 106 | ### Changed 107 | 108 | + Improved documentation, adding privacy statement 109 | + Moved settings on the sidebar 110 | + Disable NER extraction by default, and allow user to activate it 111 | + Read API KEY from the environment variables and if present, avoid asking the user 112 | + Avoid changing model after update 113 | 114 | ## [0.1.3] – 2023-10-30 115 | 116 | ### Fixed 117 | 118 | + ChromaDb accumulating information even when new papers were uploaded 119 | 120 | ## [0.1.2] – 2023-10-26 121 | 122 | ### Fixed 123 | 124 | + docker build 125 | 126 | ## [0.1.1] – 2023-10-26 127 | 128 | ### Fixed 129 | 130 | + Github action build 131 | + dependencies of langchain and chromadb 132 | 133 | ## [0.1.0] – 2023-10-26 134 | 135 | ### Added 136 | 137 | + pypi package 138 | + docker 
package release 139 | 140 | ## [0.0.1] – 2023-10-26 141 | 142 | ### Added 143 | 144 | + Kick off application 145 | + Support for GPT-3.5 146 | + Support for Mistral + SentenceTransformer 147 | + Streamlit application 148 | + Docker image 149 | + pypi package 150 | 151 | 152 | -------------------------------------------------------------------------------- /document_qa/deployment/modal_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Annotated, List 3 | from fastapi import Request, HTTPException, Form 4 | 5 | import modal 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from transformers import AutoTokenizer, AutoModel 10 | 11 | image = ( 12 | modal.Image.debian_slim(python_version="3.11") 13 | .pip_install( 14 | "transformers", 15 | "huggingface_hub[hf_transfer]==0.26.2", 16 | "flashinfer-python==0.2.0.post2", # pinning, very unstable 17 | "fastapi[standard]", 18 | extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5", 19 | ) 20 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) # faster model transfers 21 | ) 22 | 23 | MODELS_DIR = "/llamas" 24 | MODEL_NAME = "intfloat/multilingual-e5-large-instruct" 25 | MODEL_REVISION = "84344a23ee1820ac951bc365f1e91d094a911763" 26 | 27 | hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True) 28 | vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True) 29 | 30 | app = modal.App("intfloat-multilingual-e5-large-instruct-embeddings") 31 | 32 | 33 | def get_device(): 34 | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 36 | def load_model(): 37 | print("Loading model...") 38 | device = get_device() 39 | print(f"Using device: {device}") 40 | 41 | tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large-instruct') 42 | model = AutoModel.from_pretrained('intfloat/multilingual-e5-large-instruct').to(device) 43 | print("Model loaded successfully.") 44 | 45 | return tokenizer, model, device 46 | 47 | 48 | N_GPU = 1 49 | MINUTES = 60 # seconds 50 | VLLM_PORT = 8000 51 | 52 | 53 | def average_pool(last_hidden_states: Tensor, 54 | attention_mask: Tensor) -> Tensor: 55 | last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) 56 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 57 | 58 | 59 | @app.function( 60 | image=image, 61 | gpu=f"L40S:{N_GPU}", 62 | # gpu=f"A10G:{N_GPU}", 63 | # how long should we stay up with no requests? 64 | scaledown_window=3 * MINUTES, 65 | volumes={ 66 | "/root/.cache/huggingface": hf_cache_vol, 67 | "/root/.cache/vllm": vllm_cache_vol, 68 | }, 69 | secrets=[modal.Secret.from_name("document-qa-embedding-key")] 70 | ) 71 | @modal.concurrent( 72 | max_inputs=5 73 | ) # how many requests can one replica handle? tune carefully! 
74 | @modal.fastapi_endpoint(method="POST") 75 | def embed(request: Request, text: Annotated[str, Form()]): 76 | api_key = request.headers.get("x-api-key") 77 | expected_key = os.environ["API_KEY"] 78 | 79 | if api_key != expected_key: 80 | raise HTTPException(status_code=401, detail="Unauthorized") 81 | 82 | 83 | texts = [t for t in text.split("\n") if t.strip()] 84 | if not texts: 85 | return [] 86 | 87 | tokenizer, model, device = load_model() 88 | model.eval() 89 | 90 | print(f"Start embedding {len(texts)} texts") 91 | try: 92 | with torch.no_grad(): 93 | # Move inputs to the same device as model 94 | batch_dict = tokenizer(texts, padding=True, truncation=True, return_tensors='pt') 95 | batch_dict = {k: v.to(device) for k, v in batch_dict.items()} 96 | 97 | # Forward pass 98 | outputs = model(**batch_dict) 99 | 100 | # Process embeddings 101 | embeddings = average_pool( 102 | outputs.last_hidden_state, 103 | batch_dict['attention_mask'] 104 | ) 105 | embeddings = F.normalize(embeddings, p=2, dim=1) 106 | 107 | # Move to CPU and convert to list for serialization 108 | embeddings = embeddings.cpu().numpy().tolist() 109 | 110 | print("Finished embedding texts.") 111 | return embeddings 112 | 113 | except RuntimeError as e: 114 | print(f"Error during embedding: {str(e)}") 115 | if "CUDA out of memory" in str(e): 116 | print("CUDA out of memory error. Try reducing batch size or using a smaller model.") 117 | raise 118 | -------------------------------------------------------------------------------- /document_qa/langchain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection 2 | 3 | from langchain.schema import Document 4 | from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K 5 | from langchain_core.callbacks import CallbackManagerForRetrieverRun 6 | from langchain_core.utils import xor_args 7 | from langchain_core.vectorstores import VectorStore, VectorStoreRetriever 8 | 9 | 10 | class AdvancedVectorStoreRetriever(VectorStoreRetriever): 11 | allowed_search_types: ClassVar[Collection[str]] = ( 12 | "similarity", 13 | "similarity_score_threshold", 14 | "mmr", 15 | "similarity_with_embeddings" 16 | ) 17 | 18 | def _get_relevant_documents( 19 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun 20 | ) -> List[Document]: 21 | 22 | if self.search_type == "similarity_with_embeddings": 23 | docs_scores_and_embeddings = ( 24 | self.vectorstore.advanced_similarity_search( 25 | query, **self.search_kwargs 26 | ) 27 | ) 28 | 29 | for doc, score, embeddings in docs_scores_and_embeddings: 30 | if '__embeddings' not in doc.metadata.keys(): 31 | doc.metadata['__embeddings'] = embeddings 32 | if '__similarity' not in doc.metadata.keys(): 33 | doc.metadata['__similarity'] = score 34 | 35 | docs = [doc for doc, _, _ in docs_scores_and_embeddings] 36 | elif self.search_type == "similarity_score_threshold": 37 | docs_and_similarities = ( 38 | self.vectorstore.similarity_search_with_relevance_scores( 39 | query, **self.search_kwargs 40 | ) 41 | ) 42 | for doc, similarity in docs_and_similarities: 43 | if '__similarity' not in doc.metadata.keys(): 44 | doc.metadata['__similarity'] = similarity 45 | 46 | docs = [doc for doc, _ in docs_and_similarities] 47 | else: 48 | docs = super()._get_relevant_documents(query, run_manager=run_manager) 49 | 50 | return docs 51 | 52 | 53 | class AdvancedVectorStore(VectorStore): 54 | def as_retriever(self, **kwargs: Any) -> 
AdvancedVectorStoreRetriever: 55 | tags = kwargs.pop("tags", None) or [] 56 | tags.extend(self._get_retriever_tags()) 57 | return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) 58 | 59 | 60 | class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore): 61 | def __init__(self, **kwargs): 62 | super().__init__(**kwargs) 63 | 64 | @xor_args(("query_texts", "query_embeddings")) 65 | def __query_collection( 66 | self, 67 | query_texts: Optional[List[str]] = None, 68 | query_embeddings: Optional[List[List[float]]] = None, 69 | n_results: int = 4, 70 | where: Optional[Dict[str, str]] = None, 71 | where_document: Optional[Dict[str, str]] = None, 72 | **kwargs: Any, 73 | ) -> List[Document]: 74 | """Query the chroma collection.""" 75 | try: 76 | import chromadb # noqa: F401 77 | except ImportError: 78 | raise ValueError( 79 | "Could not import chromadb python package. " 80 | "Please install it with `pip install chromadb`." 81 | ) 82 | return self._collection.query( 83 | query_texts=query_texts, 84 | query_embeddings=query_embeddings, 85 | n_results=n_results, 86 | where=where, 87 | where_document=where_document, 88 | **kwargs, 89 | ) 90 | 91 | def advanced_similarity_search( 92 | self, 93 | query: str, 94 | k: int = DEFAULT_K, 95 | filter: Optional[Dict[str, str]] = None, 96 | **kwargs: Any, 97 | ) -> [List[Document], float, List[float]]: 98 | docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter) 99 | return docs_scores_and_embeddings 100 | 101 | def similarity_search_with_scores_and_embeddings( 102 | self, 103 | query: str, 104 | k: int = DEFAULT_K, 105 | filter: Optional[Dict[str, str]] = None, 106 | where_document: Optional[Dict[str, str]] = None, 107 | **kwargs: Any, 108 | ) -> List[Tuple[Document, float, List[float]]]: 109 | 110 | if self._embedding_function is None: 111 | results = self.__query_collection( 112 | query_texts=[query], 113 | n_results=k, 114 | where=filter, 115 | where_document=where_document, 116 | include=['metadatas', 'documents', 'embeddings', 'distances'] 117 | ) 118 | else: 119 | query_embedding = self._embedding_function.embed_query(query) 120 | results = self.__query_collection( 121 | query_embeddings=[query_embedding], 122 | n_results=k, 123 | where=filter, 124 | where_document=where_document, 125 | include=['metadatas', 'documents', 'embeddings', 'distances'] 126 | ) 127 | 128 | return _results_to_docs_scores_and_embeddings(results) 129 | 130 | 131 | def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]: 132 | return [ 133 | (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3]) 134 | for result in zip( 135 | results["documents"][0], 136 | results["metadatas"][0], 137 | results["distances"][0], 138 | results["embeddings"][0], 139 | ) 140 | ] 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Scientific Document Insights Q/A 3 | emoji: 📝 4 | colorFrom: yellow 5 | colorTo: pink 6 | sdk: streamlit 7 | sdk_version: 1.37.1 8 | app_file: streamlit_app.py 9 | pinned: false 10 | license: apache-2.0 11 | app_port: 8501 12 | --- 13 | 14 | # DocumentIQA: Scientific Document Insights Q/A 15 | 16 | **Work in progress** :construction_worker: 17 | 18 | 19 | 20 | https://lfoppiano-document-qa.hf.space/ 21 | 22 | **NOTE**: The LLM API is kindly provided by [Modal.com](https://www.modal.com) which 
offers $30/month of computing credits. When these credits are exhausted, the app will stop answering. 😅 23 | 24 | ## Introduction 25 | 26 | Question/Answering on scientific documents using LLMs. The tool can be customized to use different types of LLM APIs. 27 | The Streamlit application demonstrates a RAG (Retrieval-Augmented Generation) implementation on scientific documents. 28 | **Unlike most similar projects**, we focus on scientific articles and extract text from a structured document. 29 | We target only the full text using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than a raw PDF-to-text converter (the approach used by most other solutions). 30 | 31 | Additionally, this frontend provides visualisation of named entities in LLM responses, extracting physical quantities and measurements (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and materials mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)). 32 | 33 | (The image on the right was generated with https://huggingface.co/spaces/stabilityai/stable-diffusion) 34 | 35 | [![screenshot1.png](docs/images/screenshot1.png)](https://www.youtube.com/embed/M4UaYs5WKGs) 37 | 38 | ## Getting started 39 | 40 | - Upload a scientific article as a PDF document. You will see a spinner or loading indicator while the processing is in progress. 41 | - Once the spinner disappears, you can proceed to ask your questions. 42 | 43 | ![screenshot2.png](docs%2Fimages%2Fscreenshot2.png) 44 | 45 | ## Documentation 46 | 47 | ### Embedding selection 48 | The latest version allows selecting both the embedding function and the LLM. There are some limitations: OpenAI embeddings cannot be used with open-source models, and vice versa. 49 | 50 | ### Context size 51 | Allows changing the number of blocks from the original document that are considered when responding. 52 | The default size of each block is 250 tokens (which can be changed before uploading the first document). 53 | With default settings, each question uses around 1000 tokens. 54 | 55 | **NOTE**: if the chat answers something like "the information is not provided in the given context", **changing the context size will likely help**. 56 | 57 | ### Chunks size 58 | When uploaded, each document is split into blocks of a determined size (250 tokens by default). 59 | This setting allows users to modify the size of such blocks. 60 | Smaller blocks will result in a smaller context, yielding more precise sections of the document. 61 | Larger blocks will result in a larger context that is less tightly focused on the question. 62 | 63 | ### Query mode 64 | Indicates whether the question is sent to the LLM (Large Language Model) or to the vector storage. 65 | - **LLM** (default) enables question/answering related to the document content. 66 | - **Embeddings**: the response will consist of the raw text from the document related to the question (based on the embeddings). This mode helps to investigate why some answers are unsatisfying or incomplete. 67 | - **Question coefficient** (experimental): provides a coefficient that indicates how close the question is to the retrieved context. 68 | 69 | ### NER (Named Entities Recognition) 70 | This feature is specifically crafted for people working with scientific documents in materials science. 71 | It enables running NER on the LLM response to identify materials mentions and properties (quantities, measurements). 72 | This feature leverages the external services [grobid-quantities](https://github.com/kermitt2/grobid-quantities) and [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors).
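
For reference, the snippet below is a minimal sketch of calling the same NER backend outside the user interface, through the repository's `NERClientGeneric` client (`document_qa/ner_client_generic.py`). The server URL, the endpoint path in `url_mapping`, and the config file name are assumptions/placeholders rather than values shipped with this project; adapt them to your own grobid-superconductors deployment.

```python
# Minimal sketch (assumed configuration): run materials NER on a piece of text
# through the generic client used by the app.
import yaml

from document_qa.ner_client_generic import NERClientGeneric

config = {
    "grobid": {
        "server": "http://localhost:8072",  # assumption: a local grobid-superconductors instance
        "url_mapping": {
            "superconductors": "/service/process/text",  # assumption: text-processing endpoint path
        },
    },
    "sleep_time": 5,  # seconds to wait before retrying when the server returns 503
    "max_retry": 5,
}

# NERClientGeneric reads its configuration from a YAML file.
with open("ner_config.yaml", "w") as fp:
    yaml.safe_dump(config, fp)

client = NERClientGeneric(config_path="ner_config.yaml")
status, entities = client.process_text(
    "The critical temperature of MgB2 is 39 K.",
    method_name="superconductors",
)
if status == 200:
    print(entities)
```

Within the application, the same call is applied to the LLM answer, so that the recognized entities can be highlighted in the response.
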
73 | 74 | ### Troubleshooting 75 | Error: `streamlit: Your system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0`. 76 | Here is the [solution on Linux](https://stackoverflow.com/questions/76958817/streamlit-your-system-has-an-unsupported-version-of-sqlite3-chroma-requires-sq). 77 | For more information, see the [details](https://docs.trychroma.com/troubleshooting#sqlite) on the Chroma website. 78 | 79 | ## Disclaimer on Data, Security, and Privacy ⚠️ 80 | 81 | Please read carefully: 82 | 83 | - Avoid uploading sensitive data. We temporarily store text from the uploaded PDF documents only for processing your request, and we disclaim any responsibility for subsequent use or handling of the submitted data by third-party LLMs. 84 | - Mistral and Zephyr are FREE to use and do not require any API key, but since we rely on the free API endpoint, there is no guarantee that all requests will go through. Use at your own risk. 85 | - We do not assume responsibility for how the data is used by the LLM API endpoints. 86 | 87 | ## Development notes 88 | 89 | To release a new version: 90 | 91 | - `bump-my-version bump patch` 92 | - `git push --tags` 93 | 94 | To use Docker: 95 | 96 | - docker run `lfoppiano/document-insights-qa:{latest_version}` 97 | 98 | - docker run `lfoppiano/document-insights-qa:latest-develop` for the latest development version 99 | 100 | To install the library from PyPI: 101 | 102 | - `pip install document-qa-engine` 103 | 104 | 105 | ## Acknowledgement 106 | 107 | The project was initiated at the [National Institute for Materials Science](https://www.nims.go.jp) (NIMS) in Japan. 108 | Currently, development is made possible thanks to [ScienciLAB](https://www.sciencialab.com). 109 | This project received contributions from [Guillaume Lambard](https://github.com/GLambard) and the [Lambard-ML-Team](https://github.com/Lambard-ML-Team), [Pedro Ortiz Suarez](https://github.com/pjox), and [Tomoya Mato](https://github.com/t29mato). 110 | Thanks also to [Patrice Lopez](https://www.science-miner.com), the author of [Grobid](https://github.com/kermitt2/grobid). 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /document_qa/ner_client_generic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import yaml 5 | 6 | ''' 7 | This client is a generic client for any Grobid application and sub-modules. 8 | At the moment, it supports only single document processing. 9 | 10 | Source: https://github.com/kermitt2/grobid-client-python 11 | ''' 12 | 13 | """ Generic API Client """ 14 | from copy import deepcopy 15 | import json 16 | import requests 17 | 18 | try: 19 | from urlparse import urljoin 20 | except ImportError: 21 | from urllib.parse import urljoin 22 | 23 | 24 | class ApiClient(object): 25 | """ Client to interact with a generic Rest API. 26 | 27 | Subclasses should implement functionality accordingly with the provided 28 | service methods, i.e. ``get``, ``post``, ``put`` and ``delete``. 29 | """ 30 | 31 | accept_type = 'application/xml' 32 | api_base = None 33 | 34 | def __init__( 35 | self, 36 | base_url, 37 | username=None, 38 | api_key=None, 39 | status_endpoint=None, 40 | timeout=60 41 | ): 42 | """ Initialise client. 43 | 44 | Args: 45 | base_url (str): The base URL to the service being used. 46 | username (str): The username to authenticate with. 47 | api_key (str): The API key to authenticate with. 48 | timeout (int): Maximum time before timing out. 49 | """ 50 | self.base_url = base_url 51 | self.username = username 52 | self.api_key = api_key 53 | self.status_endpoint = urljoin(self.base_url, status_endpoint) 54 | self.timeout = timeout 55 | 56 | @staticmethod 57 | def encode(request, data): 58 | """ Add request content data to request body, set Content-type header. 59 | 60 | Should be overridden by subclasses if not using JSON encoding. 61 | 62 | Args: 63 | request (HTTPRequest): The request object. 64 | data (dict, None): Data to be encoded. 65 | 66 | Returns: 67 | HTTPRequest: The request object. 68 | """ 69 | if data is None: 70 | return request 71 | 72 | request.add_header('Content-Type', 'application/json') 73 | request.extracted_data = json.dumps(data) 74 | 75 | return request 76 | 77 | @staticmethod 78 | def decode(response): 79 | """ Decode the returned data in the response. 80 | 81 | Should be overridden by subclasses if something else than JSON is 82 | expected. 83 | 84 | Args: 85 | response (HTTPResponse): The response object. 86 | 87 | Returns: 88 | dict or None. 89 | """ 90 | try: 91 | return response.json() 92 | except ValueError as e: 93 | return e.message 94 | 95 | def get_credentials(self): 96 | """ Returns parameters to be added to authenticate the request. 97 | 98 | This lives on its own to make it easier to re-implement it if needed. 99 | 100 | Returns: 101 | dict: A dictionary containing the credentials. 102 | """ 103 | return {"username": self.username, "api_key": self.api_key} 104 | 105 | def call_api( 106 | self, 107 | method, 108 | url, 109 | headers=None, 110 | params=None, 111 | data=None, 112 | files=None, 113 | timeout=None, 114 | ): 115 | """ Call API. 116 | 117 | This returns object containing data, with error details if applicable. 118 | 119 | Args: 120 | method (str): The HTTP method to use. 121 | url (str): Resource location relative to the base URL. 122 | headers (dict or None): Extra request headers to set. 123 | params (dict or None): Query-string parameters. 124 | data (dict or None): Request body contents for POST or PUT requests. 125 | files (dict or None: Files to be passed to the request. 
126 | timeout (int): Maximum time before timing out. 127 | 128 | Returns: 129 | ResultParser or ErrorParser. 130 | """ 131 | headers = deepcopy(headers) or {} 132 | headers['Accept'] = self.accept_type if 'Accept' not in headers else headers['Accept'] 133 | params = deepcopy(params) or {} 134 | data = data or {} 135 | files = files or {} 136 | # if self.username is not None and self.api_key is not None: 137 | # params.update(self.get_credentials()) 138 | r = requests.request( 139 | method, 140 | url, 141 | headers=headers, 142 | params=params, 143 | files=files, 144 | data=data, 145 | timeout=timeout, 146 | ) 147 | 148 | return r, r.status_code 149 | 150 | def get(self, url, params=None, **kwargs): 151 | """ Call the API with a GET request. 152 | 153 | Args: 154 | url (str): Resource location relative to the base URL. 155 | params (dict or None): Query-string parameters. 156 | 157 | Returns: 158 | ResultParser or ErrorParser. 159 | """ 160 | return self.call_api( 161 | "GET", 162 | url, 163 | params=params, 164 | **kwargs 165 | ) 166 | 167 | def delete(self, url, params=None, **kwargs): 168 | """ Call the API with a DELETE request. 169 | 170 | Args: 171 | url (str): Resource location relative to the base URL. 172 | params (dict or None): Query-string parameters. 173 | 174 | Returns: 175 | ResultParser or ErrorParser. 176 | """ 177 | return self.call_api( 178 | "DELETE", 179 | url, 180 | params=params, 181 | **kwargs 182 | ) 183 | 184 | def put(self, url, params=None, data=None, files=None, **kwargs): 185 | """ Call the API with a PUT request. 186 | 187 | Args: 188 | url (str): Resource location relative to the base URL. 189 | params (dict or None): Query-string parameters. 190 | data (dict or None): Request body contents. 191 | files (dict or None: Files to be passed to the request. 192 | 193 | Returns: 194 | An instance of ResultParser or ErrorParser. 195 | """ 196 | return self.call_api( 197 | "PUT", 198 | url, 199 | params=params, 200 | data=data, 201 | files=files, 202 | **kwargs 203 | ) 204 | 205 | def post(self, url, params=None, data=None, files=None, **kwargs): 206 | """ Call the API with a POST request. 207 | 208 | Args: 209 | url (str): Resource location relative to the base URL. 210 | params (dict or None): Query-string parameters. 211 | data (dict or None): Request body contents. 212 | files (dict or None: Files to be passed to the request. 213 | 214 | Returns: 215 | An instance of ResultParser or ErrorParser. 216 | """ 217 | return self.call_api( 218 | method="POST", 219 | url=url, 220 | params=params, 221 | data=data, 222 | files=files, 223 | **kwargs 224 | ) 225 | 226 | def service_status(self, **kwargs): 227 | """ Call the API to get the status of the service. 228 | 229 | Returns: 230 | An instance of ResultParser or ErrorParser. 
231 | """ 232 | return self.call_api( 233 | 'GET', 234 | self.status_endpoint, 235 | params={'format': 'json'}, 236 | **kwargs 237 | ) 238 | 239 | 240 | class NERClientGeneric(ApiClient): 241 | 242 | def __init__(self, config_path=None, ping=False): 243 | self.config = None 244 | if config_path is not None: 245 | self.config = self._load_yaml_config_from_file(path=config_path) 246 | super().__init__(self.config['grobid']['server']) 247 | 248 | if ping: 249 | result = self.ping_service() 250 | if not result: 251 | raise Exception("Grobid is down.") 252 | 253 | os.environ['NO_PROXY'] = "nims.go.jp" 254 | 255 | @staticmethod 256 | def _load_json_config_from_file(path='./config.json'): 257 | """ 258 | Load the json configuration 259 | """ 260 | config = {} 261 | with open(path, 'r') as fp: 262 | config = json.load(fp) 263 | 264 | return config 265 | 266 | @staticmethod 267 | def _load_yaml_config_from_file(path='./config.yaml'): 268 | """ 269 | Load the YAML configuration 270 | """ 271 | config = {} 272 | try: 273 | with open(path, 'r') as the_file: 274 | raw_configuration = the_file.read() 275 | 276 | config = yaml.safe_load(raw_configuration) 277 | except Exception as e: 278 | print("Configuration could not be loaded: ", str(e)) 279 | exit(1) 280 | 281 | return config 282 | 283 | def set_config(self, config, ping=False): 284 | self.config = config 285 | if ping: 286 | try: 287 | result = self.ping_service() 288 | if not result: 289 | raise Exception("Grobid is down.") 290 | except Exception as e: 291 | raise Exception("Grobid is down or other problems were encountered. ", e) 292 | 293 | def ping_service(self): 294 | # test if the server is up and running... 295 | ping_url = self.get_url("ping") 296 | 297 | r = requests.get(ping_url) 298 | status = r.status_code 299 | 300 | if status != 200: 301 | print('GROBID server does not appear up and running ' + str(status)) 302 | return False 303 | else: 304 | print("GROBID server is up and running") 305 | return True 306 | 307 | def get_url(self, action): 308 | grobid_config = self.config['grobid'] 309 | base_url = grobid_config['server'] 310 | action_url = base_url + grobid_config['url_mapping'][action] 311 | 312 | return action_url 313 | 314 | def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): 315 | 316 | files = { 317 | 'texts': input 318 | } 319 | 320 | the_url = self.get_url(method_name) 321 | params, the_url = self.get_params_from_url(the_url) 322 | 323 | res, status = self.post( 324 | url=the_url, 325 | files=files, 326 | data=params, 327 | headers=headers 328 | ) 329 | 330 | if status == 503: 331 | time.sleep(self.config['sleep_time']) 332 | return self.process_texts(input, method_name, params, headers) 333 | elif status != 200: 334 | print('Processing failed with error ' + str(status)) 335 | return status, None 336 | else: 337 | return status, json.loads(res.text) 338 | 339 | def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): 340 | 341 | files = { 342 | 'text': input 343 | } 344 | 345 | the_url = self.get_url(method_name) 346 | params, the_url = self.get_params_from_url(the_url) 347 | 348 | res, status = self.post( 349 | url=the_url, 350 | files=files, 351 | data=params, 352 | headers=headers 353 | ) 354 | 355 | if status == 503: 356 | time.sleep(self.config['sleep_time']) 357 | return self.process_text(input, method_name, params, headers) 358 | elif status != 200: 359 | print('Processing failed with error ' + 
str(status)) 360 | return status, None 361 | else: 362 | return status, json.loads(res.text) 363 | 364 | def process_pdf(self, 365 | form_data: dict, 366 | method_name='superconductors', 367 | params={}, 368 | headers={"Accept": "application/json"} 369 | ): 370 | 371 | the_url = self.get_url(method_name) 372 | params, the_url = self.get_params_from_url(the_url) 373 | 374 | res, status = self.post( 375 | url=the_url, 376 | files=form_data, 377 | data=params, 378 | headers=headers 379 | ) 380 | 381 | if status == 503: 382 | time.sleep(self.config['sleep_time']) 383 | return self.process_text(input, method_name, params, headers) 384 | elif status != 200: 385 | print('Processing failed with error ' + str(status)) 386 | else: 387 | return res.text 388 | 389 | def process_pdfs(self, pdf_files, params={}): 390 | pass 391 | 392 | def process_pdf( 393 | self, 394 | pdf_file, 395 | method_name, 396 | params={}, 397 | headers={"Accept": "application/json"}, 398 | verbose=False, 399 | retry=None 400 | ): 401 | 402 | files = { 403 | 'input': ( 404 | pdf_file, 405 | open(pdf_file, 'rb'), 406 | 'application/pdf', 407 | {'Expires': '0'} 408 | ) 409 | } 410 | 411 | the_url = self.get_url(method_name) 412 | 413 | params, the_url = self.get_params_from_url(the_url) 414 | 415 | res, status = self.post( 416 | url=the_url, 417 | files=files, 418 | data=params, 419 | headers=headers 420 | ) 421 | 422 | if status == 503 or status == 429: 423 | if retry is None: 424 | retry = self.config['max_retry'] - 1 425 | else: 426 | if retry - 1 == 0: 427 | if verbose: 428 | print("re-try exhausted. Aborting request") 429 | return None, status 430 | else: 431 | retry -= 1 432 | 433 | sleep_time = self.config['sleep_time'] 434 | if verbose: 435 | print("Server is saturated, waiting", sleep_time, "seconds and trying again. ") 436 | time.sleep(sleep_time) 437 | return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry) 438 | elif status != 200: 439 | desc = None 440 | if res.content: 441 | c = json.loads(res.text) 442 | desc = c['description'] if 'description' in c else None 443 | return desc, status 444 | elif status == 204: 445 | # print('No content returned. Moving on. ') 446 | return None, status 447 | else: 448 | return res.text, status 449 | 450 | def get_params_from_url(self, the_url): 451 | """ 452 | This method is used to pass to the URL predefined parameters, which are added in the URL format 453 | """ 454 | params = {} 455 | if "?" 
in the_url: 456 | split = the_url.split("?") 457 | the_url = split[0] 458 | params = split[1] 459 | 460 | params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")} 461 | return params, the_url 462 | -------------------------------------------------------------------------------- /streamlit_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from hashlib import blake2b 4 | from tempfile import NamedTemporaryFile 5 | 6 | import dotenv 7 | from grobid_quantities.quantities import QuantitiesAPI 8 | from langchain.memory import ConversationBufferMemory 9 | from langchain_openai import ChatOpenAI 10 | from streamlit_pdf_viewer import pdf_viewer 11 | 12 | from document_qa.custom_embeddings import ModalEmbeddings 13 | from document_qa.ner_client_generic import NERClientGeneric 14 | 15 | dotenv.load_dotenv(override=True) 16 | 17 | import streamlit as st 18 | from document_qa.document_qa_engine import DocumentQAEngine, DataStorage 19 | from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations 20 | 21 | API_MODELS = { 22 | "microsoft/Phi-4-mini-instruct": os.environ["PHI_URL"], 23 | "Qwen/Qwen3-0.6B": os.environ["QWEN_URL"] 24 | } 25 | 26 | API_EMBEDDINGS = { 27 | 'intfloat/multilingual-e5-large-instruct-modal': os.environ['EMBEDS_URL'] 28 | } 29 | 30 | if 'rqa' not in st.session_state: 31 | st.session_state['rqa'] = {} 32 | 33 | if 'model' not in st.session_state: 34 | st.session_state['model'] = None 35 | 36 | if 'api_keys' not in st.session_state: 37 | st.session_state['api_keys'] = {} 38 | 39 | if 'doc_id' not in st.session_state: 40 | st.session_state['doc_id'] = None 41 | 42 | if 'loaded_embeddings' not in st.session_state: 43 | st.session_state['loaded_embeddings'] = None 44 | 45 | if 'hash' not in st.session_state: 46 | st.session_state['hash'] = None 47 | 48 | if 'git_rev' not in st.session_state: 49 | st.session_state['git_rev'] = "unknown" 50 | if os.path.exists("revision.txt"): 51 | with open("revision.txt", 'r') as fr: 52 | from_file = fr.read() 53 | st.session_state['git_rev'] = from_file if len(from_file) > 0 else "unknown" 54 | 55 | if "messages" not in st.session_state: 56 | st.session_state.messages = [] 57 | 58 | if 'ner_processing' not in st.session_state: 59 | st.session_state['ner_processing'] = False 60 | 61 | if 'uploaded' not in st.session_state: 62 | st.session_state['uploaded'] = False 63 | 64 | if 'memory' not in st.session_state: 65 | st.session_state['memory'] = None 66 | 67 | if 'binary' not in st.session_state: 68 | st.session_state['binary'] = None 69 | 70 | if 'annotations' not in st.session_state: 71 | st.session_state['annotations'] = None 72 | 73 | if 'should_show_annotations' not in st.session_state: 74 | st.session_state['should_show_annotations'] = True 75 | 76 | if 'pdf' not in st.session_state: 77 | st.session_state['pdf'] = None 78 | 79 | if 'embeddings' not in st.session_state: 80 | st.session_state['embeddings'] = None 81 | 82 | if 'scroll_to_first_annotation' not in st.session_state: 83 | st.session_state['scroll_to_first_annotation'] = False 84 | 85 | st.set_page_config( 86 | page_title="Scientific Document Insights Q/A", 87 | page_icon="📝", 88 | initial_sidebar_state="expanded", 89 | layout="wide", 90 | menu_items={ 91 | 'Get Help': 'https://github.com/lfoppiano/document-qa', 92 | 'Report a bug': "https://github.com/lfoppiano/document-qa/issues", 93 | 'About': "Upload a scientific article in PDF, ask questions, get 
insights." 94 | } 95 | ) 96 | 97 | st.markdown( 98 | """ 99 | 107 | """, 108 | unsafe_allow_html=True 109 | ) 110 | 111 | 112 | def new_file(): 113 | st.session_state['loaded_embeddings'] = None 114 | st.session_state['doc_id'] = None 115 | st.session_state['uploaded'] = True 116 | st.session_state['annotations'] = [] 117 | if st.session_state['memory']: 118 | st.session_state['memory'].clear() 119 | 120 | 121 | def clear_memory(): 122 | st.session_state['memory'].clear() 123 | 124 | 125 | # @st.cache_resource 126 | def init_qa(model_name, embeddings_name): 127 | st.session_state['memory'] = ConversationBufferMemory( 128 | memory_key="chat_history", 129 | return_messages=True 130 | ) 131 | chat = ChatOpenAI( 132 | model=model_name, 133 | temperature=0.0, 134 | base_url=API_MODELS[model_name], 135 | api_key=os.environ.get('API_KEY') 136 | ) 137 | 138 | embeddings = ModalEmbeddings( 139 | url=API_EMBEDDINGS[embeddings_name], 140 | model_name=embeddings_name, 141 | api_key=os.environ.get('EMBEDS_API_KEY') 142 | ) 143 | 144 | storage = DataStorage(embeddings) 145 | return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) 146 | 147 | 148 | @st.cache_resource 149 | def init_ner(): 150 | quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True) 151 | 152 | materials_client = NERClientGeneric(ping=True) 153 | config_materials = { 154 | 'grobid': { 155 | "server": os.environ['GROBID_MATERIALS_URL'], 156 | 'sleep_time': 5, 157 | 'timeout': 60, 158 | 'url_mapping': { 159 | 'processText_disable_linking': "/service/process/text?disableLinking=True", 160 | # 'processText_disable_linking': "/service/process/text" 161 | } 162 | } 163 | } 164 | 165 | materials_client.set_config(config_materials) 166 | 167 | gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client, 168 | grobid_superconductors_client=materials_client) 169 | return gqa 170 | 171 | 172 | gqa = init_ner() 173 | 174 | 175 | def get_file_hash(fname): 176 | hash_md5 = blake2b() 177 | with open(fname, "rb") as f: 178 | for chunk in iter(lambda: f.read(4096), b""): 179 | hash_md5.update(chunk) 180 | return hash_md5.hexdigest() 181 | 182 | 183 | def play_old_messages(container): 184 | if st.session_state['messages']: 185 | for message in st.session_state['messages']: 186 | if message['role'] == 'user': 187 | container.chat_message("user").markdown(message['content']) 188 | elif message['role'] == 'assistant': 189 | if mode == "LLM": 190 | container.chat_message("assistant").markdown(message['content'], unsafe_allow_html=True) 191 | else: 192 | container.chat_message("assistant").write(message['content']) 193 | 194 | 195 | # is_api_key_provided = st.session_state['api_key'] 196 | 197 | with st.sidebar: 198 | st.title("📝 Document Q/A") 199 | st.markdown("Upload a scientific article in PDF, ask questions, get insights.") 200 | st.markdown( 201 | ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") 202 | st.markdown("LM and Embeddings are powered by [Modal.com](https://modal.com/)") 203 | 204 | st.divider() 205 | st.session_state['model'] = model = st.selectbox( 206 | "Model:", 207 | options=API_MODELS.keys(), 208 | index=(list(API_MODELS.keys())).index( 209 | os.environ["DEFAULT_MODEL"]) if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] else 0, 210 | placeholder="Select model", 211 | help="Select the LLM model:", 212 | 
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] 213 | ) 214 | 215 | st.session_state['embeddings'] = embedding_name = st.selectbox( 216 | "Embeddings:", 217 | options=API_EMBEDDINGS.keys(), 218 | index=(list(API_EMBEDDINGS.keys())).index( 219 | os.environ["DEFAULT_EMBEDDING"]) if "DEFAULT_EMBEDDING" in os.environ and os.environ[ 220 | "DEFAULT_EMBEDDING"] else 0, 221 | placeholder="Select embedding", 222 | help="Select the Embedding function:", 223 | disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] 224 | ) 225 | 226 | api_key = os.environ['API_KEY'] 227 | 228 | if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']: 229 | with st.spinner("Preparing environment"): 230 | st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings']) 231 | st.session_state['api_keys'][model] = api_key 232 | 233 | left_column, right_column = st.columns([5, 4]) 234 | right_column = right_column.container(border=True) 235 | left_column = left_column.container(border=True) 236 | 237 | with right_column: 238 | uploaded_file = st.file_uploader( 239 | "Upload a scientific article", 240 | type=("pdf"), 241 | on_change=new_file, 242 | disabled=st.session_state['model'] is not None and st.session_state['model'] not in 243 | st.session_state['api_keys'], 244 | help="The full-text is extracted using Grobid." 245 | ) 246 | 247 | placeholder = st.empty() 248 | messages = st.container(height=300) 249 | 250 | question = st.chat_input( 251 | "Ask something about the article", 252 | # placeholder="Can you give me a short summary?", 253 | disabled=not uploaded_file 254 | ) 255 | 256 | query_modes = { 257 | "llm": "LLM Q/A", 258 | "embeddings": "Embeddings", 259 | "question_coefficient": "Question coefficient" 260 | } 261 | 262 | with st.sidebar: 263 | st.header("Settings") 264 | mode = st.radio( 265 | "Query mode", 266 | ("llm", "embeddings", "question_coefficient"), 267 | disabled=not uploaded_file, 268 | index=0, 269 | horizontal=True, 270 | format_func=lambda x: query_modes[x], 271 | help="LLM will respond the question, Embedding will show the " 272 | "relevant paragraphs to the question in the paper. " 273 | "Question coefficient attempt to estimate how effective the question will be answered." 274 | ) 275 | st.session_state['scroll_to_first_annotation'] = st.checkbox( 276 | "Scroll to context", 277 | help='The PDF viewer will automatically scroll to the first relevant passage in the document.' 278 | ) 279 | st.session_state['ner_processing'] = st.checkbox( 280 | "Identify materials and properties.", 281 | help='The LLM responses undergo post-processing to extract physical quantities, measurements, and materials mentions.' 282 | ) 283 | 284 | # Add a checkbox for showing annotations 285 | # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True) 286 | # st.session_state['should_show_annotations'] = st.checkbox("Show annotations", value=True) 287 | 288 | chunk_size = st.slider("Text chunks size", -1, 2000, value=-1, 289 | help="Size of chunks in which split the document. 
-1: use paragraphs, > 0 paragraphs are aggregated.", 290 | disabled=uploaded_file is not None) 291 | if chunk_size == -1: 292 | context_size = st.slider("Context size (paragraphs)", 3, 20, value=10, 293 | help="Number of paragraphs to consider when answering a question", 294 | disabled=not uploaded_file) 295 | else: 296 | context_size = st.slider("Context size (chunks)", 3, 10, value=4, 297 | help="Number of chunks to consider when answering a question", 298 | disabled=not uploaded_file) 299 | 300 | st.divider() 301 | 302 | st.header("Documentation") 303 | st.markdown("https://github.com/lfoppiano/document-qa") 304 | st.markdown( 305 | """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions.""") 306 | 307 | if st.session_state['git_rev'] != "unknown": 308 | st.markdown("**Revision number**: [" + st.session_state[ 309 | 'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")") 310 | 311 | if uploaded_file and not st.session_state.loaded_embeddings: 312 | if model not in st.session_state['api_keys']: 313 | st.error("Before uploading a document, you must enter the API key. ") 314 | st.stop() 315 | 316 | with left_column: 317 | with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'): 318 | binary = uploaded_file.getvalue() 319 | tmp_file = NamedTemporaryFile() 320 | tmp_file.write(bytearray(binary)) 321 | st.session_state['binary'] = binary 322 | 323 | st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings( 324 | tmp_file.name, 325 | chunk_size=chunk_size, 326 | perc_overlap=0.1 327 | ) 328 | st.session_state['loaded_embeddings'] = True 329 | st.session_state.messages = [] 330 | 331 | 332 | def rgb_to_hex(rgb): 333 | return "#{:02x}{:02x}{:02x}".format(*rgb) 334 | 335 | 336 | def generate_color_gradient(num_elements): 337 | # Define warm and cold colors in RGB format 338 | warm_color = (255, 165, 0) # Orange 339 | cold_color = (0, 0, 255) # Blue 340 | 341 | # Generate a linear gradient of colors 342 | color_gradient = [ 343 | rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in 344 | zip(warm_color, cold_color))) 345 | for i in range(num_elements) 346 | ] 347 | 348 | return color_gradient 349 | 350 | 351 | with right_column: 352 | if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id: 353 | st.session_state.messages.append({"role": "user", "mode": mode, "content": question}) 354 | 355 | for message in st.session_state.messages: 356 | # with messages.chat_message(message["role"]): 357 | if message['mode'] == "llm": 358 | messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) 359 | elif message['mode'] == "embeddings": 360 | messages.chat_message(message["role"]).write(message["content"]) 361 | elif message['mode'] == "question_coefficient": 362 | messages.chat_message(message["role"]).markdown(message["content"], unsafe_allow_html=True) 363 | if model not in st.session_state['rqa']: 364 | st.error("The API Key for the " + model + " is missing. Please add it before sending any query. 
`") 365 | st.stop() 366 | 367 | text_response = None 368 | if mode == "embeddings": 369 | with placeholder: 370 | with st.spinner("Fetching the relevant context..."): 371 | text_response, coordinates = st.session_state['rqa'][model].query_storage( 372 | question, 373 | st.session_state.doc_id, 374 | context_size=context_size 375 | ) 376 | elif mode == "llm": 377 | with placeholder: 378 | with st.spinner("Generating LLM response..."): 379 | _, text_response, coordinates = st.session_state['rqa'][model].query_document( 380 | question, 381 | st.session_state.doc_id, 382 | context_size=context_size 383 | ) 384 | 385 | elif mode == "question_coefficient": 386 | with st.spinner("Estimate question/context relevancy..."): 387 | text_response, coordinates = st.session_state['rqa'][model].analyse_query( 388 | question, 389 | st.session_state.doc_id, 390 | context_size=context_size 391 | ) 392 | 393 | annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] 394 | for coord_doc in coordinates] 395 | gradients = generate_color_gradient(len(annotations)) 396 | for i, color in enumerate(gradients): 397 | for annotation in annotations[i]: 398 | annotation['color'] = color 399 | if i == 0: 400 | annotation['border'] = "dotted" 401 | 402 | st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in 403 | annotation_doc] 404 | 405 | if not text_response: 406 | st.error("Something went wrong. Contact info AT sciencialab.com to report the issue through GitHub.") 407 | 408 | if mode == "llm": 409 | if st.session_state['ner_processing']: 410 | with st.spinner("Processing NER on LLM response..."): 411 | entities = gqa.process_single_text(text_response) 412 | decorated_text = decorate_text_with_annotations(text_response.strip(), entities) 413 | decorated_text = decorated_text.replace('class="label material"', 'style="color:green"') 414 | decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text) 415 | text_response = decorated_text 416 | messages.chat_message("assistant").markdown(text_response, unsafe_allow_html=True) 417 | else: 418 | messages.chat_message("assistant").write(text_response) 419 | st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response}) 420 | 421 | elif st.session_state.loaded_embeddings and st.session_state.doc_id: 422 | play_old_messages(messages) 423 | 424 | with left_column: 425 | if st.session_state['binary']: 426 | with st.container(height=600): 427 | pdf_viewer( 428 | input=st.session_state['binary'], 429 | annotation_outline_size=2, 430 | annotations=st.session_state['annotations'] if st.session_state['annotations'] else [], 431 | render_text=True, 432 | scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state[ 433 | 'scroll_to_first_annotation']) else None 434 | ) 435 | -------------------------------------------------------------------------------- /document_qa/document_qa_engine.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from pathlib import Path 4 | from typing import Union, Any, List 5 | 6 | import tiktoken 7 | from langchain.chains import create_extraction_chain 8 | from langchain.chains.combine_documents import create_stuff_documents_chain 9 | from langchain.chains.question_answering import stuff_prompt, refine_prompts, map_reduce_prompt, \ 10 | map_rerank_prompt 11 | from langchain.prompts import SystemMessagePromptTemplate, 
HumanMessagePromptTemplate, ChatPromptTemplate 12 | from langchain.retrievers import MultiQueryRetriever 13 | from langchain.schema import Document 14 | from langchain_community.vectorstores.chroma import Chroma 15 | from langchain_core.vectorstores import VectorStore 16 | from tqdm import tqdm 17 | 18 | from document_qa.grobid_processors import GrobidProcessor 19 | from document_qa.langchain import ChromaAdvancedRetrieval 20 | 21 | 22 | class TextMerger: 23 | """ 24 | This class tries to replicate the RecursiveTextSplitter from LangChain, to preserve and merge the 25 | coordinate information from the PDF document. 26 | """ 27 | 28 | def __init__(self, model_name=None, encoding_name="gpt2"): 29 | if model_name is not None: 30 | self.enc = tiktoken.encoding_for_model(model_name) 31 | else: 32 | self.enc = tiktoken.get_encoding(encoding_name) 33 | 34 | def encode(self, text, allowed_special=set(), disallowed_special="all"): 35 | return self.enc.encode( 36 | text, 37 | allowed_special=allowed_special, 38 | disallowed_special=disallowed_special, 39 | ) 40 | 41 | def merge_passages(self, passages, chunk_size, tolerance=0.2): 42 | new_passages = [] 43 | new_coordinates = [] 44 | current_texts = [] 45 | current_coordinates = [] 46 | for idx, passage in enumerate(passages): 47 | text = passage['text'] 48 | coordinates = passage['coordinates'] 49 | current_texts.append(text) 50 | current_coordinates.append(coordinates) 51 | 52 | accumulated_text = " ".join(current_texts) 53 | 54 | encoded_accumulated_text = self.encode(accumulated_text) 55 | 56 | if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance: 57 | if len(current_texts) > 1: 58 | new_passages.append(current_texts[:-1]) 59 | new_coordinates.append(current_coordinates[:-1]) 60 | current_texts = [current_texts[-1]] 61 | current_coordinates = [current_coordinates[-1]] 62 | else: 63 | new_passages.append(current_texts) 64 | new_coordinates.append(current_coordinates) 65 | current_texts = [] 66 | current_coordinates = [] 67 | 68 | elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance: 69 | new_passages.append(current_texts) 70 | new_coordinates.append(current_coordinates) 71 | current_texts = [] 72 | current_coordinates = [] 73 | 74 | if len(current_texts) > 0: 75 | new_passages.append(current_texts) 76 | new_coordinates.append(current_coordinates) 77 | 78 | new_passages_struct = [] 79 | for i, passages in enumerate(new_passages): 80 | text = " ".join(passages) 81 | coordinates = ";".join(new_coordinates[i]) 82 | 83 | new_passages_struct.append( 84 | { 85 | "text": text, 86 | "coordinates": coordinates, 87 | "type": "aggregated chunks", 88 | "section": "mixed", 89 | "subSection": "mixed" 90 | } 91 | ) 92 | 93 | return new_passages_struct 94 | 95 | 96 | class BaseRetrieval: 97 | 98 | def __init__( 99 | self, 100 | persist_directory: Path, 101 | embedding_function 102 | ): 103 | self.embedding_function = embedding_function 104 | self.persist_directory = persist_directory 105 | 106 | 107 | class NER_Retrival(VectorStore): 108 | """ 109 | This class implement a retrieval based on NER models. 110 | This is an alternative retrieval to embeddings that relies on extracted entities. 
111 | """ 112 | pass 113 | 114 | 115 | engines = { 116 | 'chroma': ChromaAdvancedRetrieval, 117 | 'ner': NER_Retrival 118 | } 119 | 120 | 121 | class DataStorage: 122 | embeddings_dict = {} 123 | embeddings_map_from_md5 = {} 124 | embeddings_map_to_md5 = {} 125 | 126 | def __init__( 127 | self, 128 | embedding_function, 129 | root_path: Path = None, 130 | engine=ChromaAdvancedRetrieval, 131 | ) -> None: 132 | self.root_path = root_path 133 | self.engine = engine 134 | self.embedding_function = embedding_function 135 | 136 | if root_path is not None: 137 | self.embeddings_root_path = root_path 138 | if not os.path.exists(root_path): 139 | os.makedirs(root_path) 140 | else: 141 | self.load_embeddings(self.embeddings_root_path) 142 | 143 | def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None: 144 | """ 145 | Load the vector storage assuming they are all persisted and stored in a single directory. 146 | The root path of the embeddings containing one data store for each document in each subdirectory 147 | """ 148 | 149 | embeddings_directories = [f for f in os.scandir(embeddings_root_path) if f.is_dir()] 150 | 151 | if len(embeddings_directories) == 0: 152 | print("No available embeddings") 153 | return 154 | 155 | for embedding_document_dir in embeddings_directories: 156 | self.embeddings_dict[embedding_document_dir.name] = self.engine( 157 | persist_directory=embedding_document_dir.path, 158 | embedding_function=self.embedding_function 159 | ) 160 | 161 | filename_list = list(Path(embedding_document_dir).glob('*.storage_filename')) 162 | if filename_list: 163 | filenam = filename_list[0].name.replace(".storage_filename", "") 164 | self.embeddings_map_from_md5[embedding_document_dir.name] = filenam 165 | self.embeddings_map_to_md5[filenam] = embedding_document_dir.name 166 | 167 | print("Embedding loaded: ", len(self.embeddings_dict.keys())) 168 | 169 | def get_loaded_embeddings_ids(self): 170 | return list(self.embeddings_dict.keys()) 171 | 172 | def get_md5_from_filename(self, filename): 173 | return self.embeddings_map_to_md5[filename] 174 | 175 | def get_filename_from_md5(self, md5): 176 | return self.embeddings_map_from_md5[md5] 177 | 178 | def embed_document(self, doc_id, texts, metadatas): 179 | if doc_id not in self.embeddings_dict.keys(): 180 | self.embeddings_dict[doc_id] = self.engine.from_texts( 181 | texts, 182 | embedding=self.embedding_function, 183 | metadatas=metadatas, 184 | collection_name=doc_id) 185 | else: 186 | # Workaround Chroma (?) 
breaking change 187 | self.embeddings_dict[doc_id].delete_collection() 188 | self.embeddings_dict[doc_id] = self.engine.from_texts( 189 | texts, 190 | embedding=self.embedding_function, 191 | metadatas=metadatas, 192 | collection_name=doc_id) 193 | 194 | self.embeddings_root_path = None 195 | 196 | 197 | class DocumentQAEngine: 198 | llm = None 199 | qa_chain_type = None 200 | 201 | default_prompts = { 202 | 'stuff': stuff_prompt, 203 | 'refine': refine_prompts, 204 | "map_reduce": map_reduce_prompt, 205 | "map_rerank": map_rerank_prompt 206 | } 207 | 208 | def __init__(self, 209 | llm, 210 | data_storage: DataStorage, 211 | grobid_url=None, 212 | memory=None 213 | ): 214 | 215 | self.llm = llm 216 | self.memory = memory 217 | self.chain = create_stuff_documents_chain(llm, self.default_prompts['stuff'].PROMPT) 218 | self.text_merger = TextMerger() 219 | self.data_storage = data_storage 220 | 221 | if grobid_url: 222 | self.grobid_processor = GrobidProcessor(grobid_url) 223 | 224 | def query_document( 225 | self, 226 | query: str, 227 | doc_id, 228 | output_parser=None, 229 | context_size=4, 230 | extraction_schema=None, 231 | verbose=False 232 | ) -> (Any, str): 233 | # self.load_embeddings(self.embeddings_root_path) 234 | 235 | if verbose: 236 | print(query) 237 | 238 | response, coordinates = self._run_query(doc_id, query, context_size=context_size) 239 | response = response['output_text'] if 'output_text' in response else response 240 | 241 | if verbose: 242 | print(doc_id, "->", response) 243 | 244 | if output_parser: 245 | try: 246 | return self._parse_json(response, output_parser), response 247 | except Exception as oe: 248 | print("Failing to parse the response", oe) 249 | return None, response, coordinates 250 | elif extraction_schema: 251 | try: 252 | chain = create_extraction_chain(extraction_schema, self.llm) 253 | parsed = chain.run(response) 254 | return parsed, response, coordinates 255 | except Exception as oe: 256 | print("Failing to parse the response", oe) 257 | return None, response, coordinates 258 | else: 259 | return None, response, coordinates 260 | 261 | def query_storage(self, query: str, doc_id, context_size=4) -> (List[Document], list): 262 | """ 263 | Returns the context related to a given query 264 | """ 265 | documents, coordinates = self._get_context(doc_id, query, context_size) 266 | 267 | context_as_text = [doc.page_content for doc in documents] 268 | return context_as_text, coordinates 269 | 270 | def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> List[Document]: 271 | """ 272 | Returns both the context and the embedding information from a given query 273 | """ 274 | db = self.data_storage.embeddings_dict[doc_id] 275 | retriever = db.as_retriever( 276 | search_kwargs={"k": context_size}, 277 | search_type="similarity_with_embeddings" 278 | ) 279 | relevant_documents = retriever.invoke(query) 280 | 281 | return relevant_documents 282 | 283 | def analyse_query(self, query, doc_id, context_size=4): 284 | db = self.data_storage.embeddings_dict[doc_id] 285 | # retriever = db.as_retriever( 286 | # search_kwargs={"k": context_size, 'score_threshold': 0.0}, 287 | # search_type="similarity_score_threshold" 288 | # ) 289 | retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") 290 | relevant_documents = retriever.invoke(query) 291 | relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] 292 | for doc in 293 | relevant_documents] 294 
| all_documents = db.get(include=['documents', 'metadatas', 'embeddings']) 295 | # all_documents_embeddings = all_documents["embeddings"] 296 | # query_embedding = db._embedding_function.embed_query(query) 297 | 298 | # distance_evaluator = load_evaluator("pairwise_embedding_distance", 299 | # embeddings=db._embedding_function, 300 | # distance_metric=EmbeddingDistance.EUCLIDEAN) 301 | 302 | # distance_evaluator.evaluate_string_pairs(query=query_embedding, documents="") 303 | 304 | similarities = [doc.metadata['__similarity'] for doc in relevant_documents] 305 | min_similarity = min(similarities) 306 | mean_similarity = sum(similarities) / len(similarities) 307 | coefficient = min_similarity - mean_similarity 308 | 309 | return f"Coefficient: {coefficient}, (Min similarity {min_similarity}, Mean similarity: {mean_similarity})", relevant_document_coordinates 310 | 311 | def _parse_json(self, response, output_parser): 312 | system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \ 313 | "that can process text and transform it to JSON." 314 | human_message = """Transform the text between three double quotes in JSON.\n\n\n\n 315 | {format_instructions}\n\nText: \"\"\"{text}\"\"\"""" 316 | 317 | system_message_prompt = SystemMessagePromptTemplate.from_template(system_message) 318 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_message) 319 | 320 | prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) 321 | 322 | results = self.llm( 323 | prompt_template.format_prompt( 324 | text=response, 325 | format_instructions=output_parser.get_format_instructions() 326 | ).to_messages() 327 | ) 328 | parsed_output = output_parser.parse(results.content) 329 | 330 | return parsed_output 331 | 332 | def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list): 333 | relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size) 334 | response = self.chain.invoke({"context": relevant_documents, "question": query}) 335 | return response, relevant_document_coordinates 336 | 337 | def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list): 338 | db = self.data_storage.embeddings_dict[doc_id] 339 | retriever = db.as_retriever(search_kwargs={"k": context_size}) 340 | relevant_documents = retriever.invoke(query) 341 | relevant_document_coordinates = [ 342 | doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] 343 | for doc in 344 | relevant_documents 345 | ] 346 | if self.memory and len(self.memory.buffer_as_messages) > 0: 347 | relevant_documents.append( 348 | Document( 349 | page_content="""Following, the previous question and answers. 
Use these information only when in the question there are unspecified references:\n{}\n\n""".format( 350 | self.memory.buffer_as_str)) 351 | ) 352 | return relevant_documents, relevant_document_coordinates 353 | 354 | def get_full_context_by_document(self, doc_id): 355 | """ 356 | Return the full context from the document 357 | """ 358 | db = self.data_storage.embeddings_dict[doc_id] 359 | docs = db.get() 360 | return docs['documents'] 361 | 362 | def _get_context_multiquery(self, doc_id, query, context_size=4): 363 | db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size}) 364 | multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm) 365 | relevant_documents = multi_query_retriever.invoke(query) 366 | return relevant_documents 367 | 368 | def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False): 369 | """ 370 | Extract text from documents using Grobid. 371 | - if chunk_size is < 0, keeps each paragraph separately 372 | - if chunk_size > 0, aggregate all paragraphs and split them again using an approximate chunk size 373 | """ 374 | if verbose: 375 | print("File", pdf_file_path) 376 | filename = Path(pdf_file_path).stem 377 | coordinates = True # if chunk_size == -1 else False 378 | structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates) 379 | 380 | biblio = structure['biblio'] 381 | biblio['filename'] = filename.replace(" ", "_") 382 | 383 | if verbose: 384 | print("Generating embeddings for:", hash, ", filename: ", filename) 385 | 386 | texts = [] 387 | metadatas = [] 388 | ids = [] 389 | 390 | if chunk_size > 0: 391 | new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size) 392 | else: 393 | new_passages = structure['passages'] 394 | 395 | for passage in new_passages: 396 | biblio_copy = copy.copy(biblio) 397 | if len(str.strip(passage['text'])) > 0: 398 | texts.append(passage['text']) 399 | 400 | biblio_copy['type'] = passage['type'] 401 | biblio_copy['section'] = passage['section'] 402 | biblio_copy['subSection'] = passage['subSection'] 403 | biblio_copy['coordinates'] = passage['coordinates'] 404 | metadatas.append(biblio_copy) 405 | 406 | # ids.append(passage['passage_id']) 407 | 408 | ids = [id for id, t in enumerate(new_passages)] 409 | 410 | return texts, metadatas, ids 411 | 412 | def create_memory_embeddings( 413 | self, 414 | pdf_path, 415 | doc_id=None, 416 | chunk_size=500, 417 | perc_overlap=0.1 418 | ): 419 | texts, metadata, ids = self.get_text_from_document( 420 | pdf_path, 421 | chunk_size=chunk_size, 422 | perc_overlap=perc_overlap) 423 | if doc_id: 424 | hash = doc_id 425 | else: 426 | hash = metadata[0]['hash'] if len(metadata) > 0 and 'hash' in metadata[0] else "" 427 | 428 | self.data_storage.embed_document(hash, texts, metadata) 429 | 430 | return hash 431 | 432 | def create_embeddings( 433 | self, 434 | pdfs_dir_path: Path, 435 | chunk_size=500, 436 | perc_overlap=0.1, 437 | include_biblio=False 438 | ): 439 | input_files = [] 440 | for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False): 441 | for file_ in files: 442 | if not (file_.lower().endswith(".pdf")): 443 | continue 444 | input_files.append(os.path.join(root, file_)) 445 | 446 | for input_file in tqdm(input_files, total=len(input_files), unit='document', 447 | desc="Grobid + embeddings processing"): 448 | 449 | md5 = self.calculate_md5(input_file) 450 | data_path = os.path.join(self.data_storage.embeddings_root_path, md5) 
451 | 452 | if os.path.exists(data_path): 453 | print(data_path, "exists. Skipping it ") 454 | continue 455 | # include = ["biblio"] if include_biblio else [] 456 | texts, metadata, ids = self.get_text_from_document( 457 | input_file, 458 | chunk_size=chunk_size, 459 | perc_overlap=perc_overlap) 460 | filename = metadata[0]['filename'] 461 | 462 | vector_db_document = Chroma.from_texts(texts, 463 | metadatas=metadata, 464 | embedding=self.embedding_function, 465 | persist_directory=data_path) 466 | vector_db_document.persist() 467 | 468 | with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo: 469 | fo.write("") 470 | 471 | @staticmethod 472 | def calculate_md5(input_file: Union[Path, str]): 473 | import hashlib 474 | md5_hash = hashlib.md5() 475 | with open(input_file, 'rb') as fi: 476 | md5_hash.update(fi.read()) 477 | return md5_hash.hexdigest().upper() 478 | -------------------------------------------------------------------------------- /document_qa/grobid_processors.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import OrderedDict 3 | from html import escape 4 | from pathlib import Path 5 | 6 | import dateparser 7 | import grobid_tei_xml 8 | from bs4 import BeautifulSoup 9 | from grobid_client.grobid_client import GrobidClient 10 | 11 | 12 | def get_span_start(type, title=None): 13 | title_ = ' title="' + title + '"' if title is not None else "" 14 | return '' 15 | 16 | 17 | def get_span_end(): 18 | return '' 19 | 20 | 21 | def get_rs_start(type): 22 | return '' 23 | 24 | 25 | def get_rs_end(): 26 | return '' 27 | 28 | 29 | def has_space_between_value_and_unit(quantity): 30 | return quantity['offsetEnd'] < quantity['rawUnit']['offsetStart'] 31 | 32 | 33 | def decorate_text_with_annotations(text, spans, tag="span"): 34 | """ 35 | Decorate a text using spans, using two style defined by the tag: 36 | - "span" generated HTML like annotated text 37 | - "rs" generate XML like annotated text (format SuperMat) 38 | """ 39 | sorted_spans = list(sorted(spans, key=lambda item: item['offset_start'])) 40 | annotated_text = "" 41 | start = 0 42 | for span in sorted_spans: 43 | type = span['type'].replace("<", "").replace(">", "") 44 | if 'unit_type' in span and span['unit_type'] is not None: 45 | type = span['unit_type'].replace(" ", "_") 46 | annotated_text += escape(text[start: span['offset_start']]) 47 | title = span['quantified'] if 'quantified' in span else None 48 | annotated_text += get_span_start(type, title) if tag == "span" else get_rs_start(type) 49 | annotated_text += escape(text[span['offset_start']: span['offset_end']]) 50 | annotated_text += get_span_end() if tag == "span" else get_rs_end() 51 | 52 | start = span['offset_end'] 53 | annotated_text += escape(text[start: len(text)]) 54 | return annotated_text 55 | 56 | 57 | def get_parsed_value_type(quantity): 58 | if 'parsedValue' in quantity and 'structure' in quantity['parsedValue']: 59 | return quantity['parsedValue']['structure']['type'] 60 | 61 | 62 | class BaseProcessor(object): 63 | # def __init__(self, grobid_superconductors_client=None, grobid_quantities_client=None): 64 | # self.grobid_superconductors_client = grobid_superconductors_client 65 | # self.grobid_quantities_client = grobid_quantities_client 66 | 67 | patterns = [ 68 | r'\d+e\d+' 69 | ] 70 | 71 | def post_process(self, text): 72 | output = text.replace('À', '-') 73 | output = output.replace('¼', '=') 74 | output = output.replace('þ', '+') 75 | output = 
output.replace('Â', 'x') 76 | output = output.replace('$', '~') 77 | output = output.replace('−', '-') 78 | output = output.replace('–', '-') 79 | 80 | for pattern in self.patterns: 81 | output = re.sub(pattern, lambda match: match.group().replace('e', '-'), output) 82 | 83 | return output 84 | 85 | 86 | class GrobidProcessor(BaseProcessor): 87 | def __init__(self, grobid_url, ping_server=True): 88 | # super().__init__() 89 | grobid_client = GrobidClient( 90 | grobid_server=grobid_url, 91 | batch_size=5, 92 | coordinates=["p", "title", "persName"], 93 | sleep_time=5, 94 | timeout=60, 95 | check_server=ping_server 96 | ) 97 | self.grobid_client = grobid_client 98 | 99 | def process_structure(self, input_path, coordinates=False): 100 | pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument", 101 | input_path, 102 | consolidate_header=True, 103 | consolidate_citations=False, 104 | segment_sentences=False, 105 | tei_coordinates=coordinates, 106 | include_raw_citations=False, 107 | include_raw_affiliations=False, 108 | generateIDs=True) 109 | 110 | if status != 200: 111 | return 112 | 113 | document_object = self.parse_grobid_xml(text, coordinates=coordinates) 114 | document_object['filename'] = Path(pdf_file).stem.replace(".tei", "") 115 | 116 | return document_object 117 | 118 | def process_single(self, input_file): 119 | doc = self.process_structure(input_file) 120 | 121 | for paragraph in doc['passages']: 122 | entities = self.process_single_text(paragraph['text']) 123 | paragraph['spans'] = entities 124 | 125 | return doc 126 | 127 | def parse_grobid_xml(self, text, coordinates=False): 128 | output_data = OrderedDict() 129 | 130 | doc_biblio = grobid_tei_xml.parse_document_xml(text) 131 | biblio = { 132 | "doi": doc_biblio.header.doi if doc_biblio.header.doi is not None else "", 133 | "authors": ", ".join([author.full_name for author in doc_biblio.header.authors]), 134 | "title": doc_biblio.header.title, 135 | "hash": doc_biblio.pdf_md5 136 | } 137 | try: 138 | year = dateparser.parse(doc_biblio.header.date).year 139 | biblio["publication_year"] = year 140 | except: 141 | pass 142 | 143 | output_data['biblio'] = biblio 144 | passages = [] 145 | output_data['passages'] = passages 146 | passage_type = "paragraph" 147 | 148 | soup = BeautifulSoup(text, 'xml') 149 | blocks_header = get_xml_nodes_header(soup, use_paragraphs=True) 150 | 151 | # passages.append({ 152 | # "text": f"authors: {biblio['authors']}", 153 | # "type": passage_type, 154 | # "section": "
", 155 | # "subSection": "", 156 | # "passage_id": "hauthors", 157 | # "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in 158 | # blocks_header['authors']]) 159 | # }) 160 | 161 | passages.append({ 162 | "text": self.post_process(" ".join([node.text for node in blocks_header['title']])), 163 | "type": passage_type, 164 | "section": "
", 165 | "subSection": "", 166 | "passage_id": "htitle", 167 | "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in 168 | blocks_header['title']]) 169 | }) 170 | 171 | passages.append({ 172 | "text": self.post_process( 173 | ''.join(node.text for node in blocks_header['abstract'] for text in node.find_all(text=True) if 174 | text.parent.name != "ref" or ( 175 | text.parent.name == "ref" and text.parent.attrs[ 176 | 'type'] != 'bibr'))), 177 | "type": passage_type, 178 | "section": "<header>", 179 | "subSection": "<abstract>", 180 | "passage_id": "habstract", 181 | "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in 182 | blocks_header['abstract']]) 183 | }) 184 | 185 | text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True) 186 | text_blocks_body.extend(get_xml_nodes_back(soup, verbose=False, use_paragraphs=True)) 187 | 188 | use_paragraphs = True 189 | if not use_paragraphs: 190 | passages.extend([ 191 | { 192 | "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if 193 | text.parent.name != "ref" or ( 194 | text.parent.name == "ref" and text.parent.attrs[ 195 | 'type'] != 'bibr'))), 196 | "type": passage_type, 197 | "section": "<body>", 198 | "subSection": "<paragraph>", 199 | "passage_id": str(paragraph_id), 200 | "coordinates": paragraph['coords'] if coordinates and sentence.has_attr('coords') else "" 201 | } 202 | for paragraph_id, paragraph in enumerate(text_blocks_body) for 203 | sentence_id, sentence in enumerate(paragraph) 204 | ]) 205 | else: 206 | passages.extend([ 207 | { 208 | "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if 209 | text.parent.name != "ref" or ( 210 | text.parent.name == "ref" and text.parent.attrs[ 211 | 'type'] != 'bibr'))), 212 | "type": passage_type, 213 | "section": "<body>", 214 | "subSection": "<paragraph>", 215 | "passage_id": str(paragraph_id), 216 | "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else "" 217 | } 218 | for paragraph_id, paragraph in enumerate(text_blocks_body) 219 | ]) 220 | 221 | text_blocks_figures = get_xml_nodes_figures(soup, verbose=False) 222 | 223 | if not use_paragraphs: 224 | passages.extend([ 225 | { 226 | "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if 227 | text.parent.name != "ref" or ( 228 | text.parent.name == "ref" and text.parent.attrs[ 229 | 'type'] != 'bibr'))), 230 | "type": passage_type, 231 | "section": "<body>", 232 | "subSection": "<figure>", 233 | "passage_id": str(paragraph_id) + str(sentence_id), 234 | "coordinates": sentence['coords'] if coordinates and 'coords' in sentence else "" 235 | } 236 | for paragraph_id, paragraph in enumerate(text_blocks_figures) for 237 | sentence_id, sentence in enumerate(paragraph) 238 | ]) 239 | else: 240 | passages.extend([ 241 | { 242 | "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if 243 | text.parent.name != "ref" or ( 244 | text.parent.name == "ref" and text.parent.attrs[ 245 | 'type'] != 'bibr'))), 246 | "type": passage_type, 247 | "section": "<body>", 248 | "subSection": "<figure>", 249 | "passage_id": str(paragraph_id), 250 | "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else "" 251 | } 252 | for paragraph_id, paragraph in enumerate(text_blocks_figures) 253 | ]) 254 | 255 | return output_data 256 | 257 | 258 | class 
GrobidQuantitiesProcessor(BaseProcessor): 259 | def __init__(self, grobid_quantities_client): 260 | self.grobid_quantities_client = grobid_quantities_client 261 | 262 | def process(self, text) -> list: 263 | status, result = self.grobid_quantities_client.process_text(text.strip()) 264 | 265 | if status != 200: 266 | result = {} 267 | 268 | spans = [] 269 | 270 | if 'measurements' in result: 271 | found_measurements = self.parse_measurements_output(result) 272 | 273 | for m in found_measurements: 274 | item = { 275 | "text": text[m['offset_start']:m['offset_end']], 276 | 'offset_start': m['offset_start'], 277 | 'offset_end': m['offset_end'] 278 | } 279 | 280 | if 'raw' in m and m['raw'] != item['text']: 281 | item['text'] = m['raw'] 282 | 283 | if 'quantified_substance' in m: 284 | item['quantified'] = m['quantified_substance'] 285 | 286 | if 'type' in m: 287 | item["unit_type"] = m['type'] 288 | 289 | item['type'] = 'property' 290 | # if 'raw_value' in m: 291 | # item['raw_value'] = m['raw_value'] 292 | 293 | spans.append(item) 294 | 295 | return spans 296 | 297 | @staticmethod 298 | def parse_measurements_output(result): 299 | measurements_output = [] 300 | 301 | for measurement in result['measurements']: 302 | type = measurement['type'] 303 | measurement_output_object = {} 304 | quantity_type = None 305 | has_unit = False 306 | parsed_value_type = None 307 | 308 | if 'quantified' in measurement: 309 | if 'normalizedName' in measurement['quantified']: 310 | quantified_substance = measurement['quantified']['normalizedName'] 311 | measurement_output_object["quantified_substance"] = quantified_substance 312 | 313 | if 'measurementOffsets' in measurement: 314 | measurement_output_object["offset_start"] = measurement["measurementOffsets"]['start'] 315 | measurement_output_object["offset_end"] = measurement["measurementOffsets"]['end'] 316 | else: 317 | # If there are no offsets we skip the measurement 318 | continue 319 | 320 | # if 'measurementRaw' in measurement: 321 | # measurement_output_object['raw_value'] = measurement['measurementRaw'] 322 | 323 | if type == 'value': 324 | quantity = measurement['quantity'] 325 | 326 | parsed_value = GrobidQuantitiesProcessor.get_parsed(quantity) 327 | if parsed_value: 328 | measurement_output_object['parsed'] = parsed_value 329 | 330 | normalized_value = GrobidQuantitiesProcessor.get_normalized(quantity) 331 | if normalized_value: 332 | measurement_output_object['normalized'] = normalized_value 333 | 334 | raw_value = GrobidQuantitiesProcessor.get_raw(quantity) 335 | if raw_value: 336 | measurement_output_object['raw'] = raw_value 337 | 338 | if 'type' in quantity: 339 | quantity_type = quantity['type'] 340 | 341 | if 'rawUnit' in quantity: 342 | has_unit = True 343 | 344 | parsed_value_type = get_parsed_value_type(quantity) 345 | 346 | elif type == 'interval': 347 | if 'quantityMost' in measurement: 348 | quantityMost = measurement['quantityMost'] 349 | if 'type' in quantityMost: 350 | quantity_type = quantityMost['type'] 351 | 352 | if 'rawUnit' in quantityMost: 353 | has_unit = True 354 | 355 | parsed_value_type = get_parsed_value_type(quantityMost) 356 | 357 | if 'quantityLeast' in measurement: 358 | quantityLeast = measurement['quantityLeast'] 359 | 360 | if 'type' in quantityLeast: 361 | quantity_type = quantityLeast['type'] 362 | 363 | if 'rawUnit' in quantityLeast: 364 | has_unit = True 365 | 366 | parsed_value_type = get_parsed_value_type(quantityLeast) 367 | 368 | elif type == 'listc': 369 | quantities = measurement['quantities'] 370 | 371 | 
if 'type' in quantities[0]:
372 |                 quantity_type = quantities[0]['type']
373 |
374 |             if 'rawUnit' in quantities[0]:
375 |                 has_unit = True
376 |
377 |             parsed_value_type = get_parsed_value_type(quantities[0])
378 |
379 |             if quantity_type is not None or has_unit:
380 |                 measurement_output_object['type'] = quantity_type
381 |
382 |             if parsed_value_type is None or parsed_value_type not in ['ALPHABETIC', 'TIME']:
383 |                 measurements_output.append(measurement_output_object)
384 |
385 |         return measurements_output
386 |
387 |     @staticmethod
388 |     def get_parsed(quantity):
389 |         parsed_value = parsed_unit = None
390 |         if 'parsedValue' in quantity and 'parsed' in quantity['parsedValue']:
391 |             parsed_value = quantity['parsedValue']['parsed']
392 |         if 'parsedUnit' in quantity and 'name' in quantity['parsedUnit']:
393 |             parsed_unit = quantity['parsedUnit']['name']
394 |
395 |         if parsed_value and parsed_unit:
396 |             if has_space_between_value_and_unit(quantity):  # keep spacing consistent with get_raw() and get_normalized()
397 |                 return str(parsed_value) + " " + str(parsed_unit)
398 |             else:
399 |                 return str(parsed_value) + str(parsed_unit)
400 |
401 |     @staticmethod
402 |     def get_normalized(quantity):
403 |         normalized_value = normalized_unit = None
404 |         if 'normalizedQuantity' in quantity:
405 |             normalized_value = quantity['normalizedQuantity']
406 |         if 'normalizedUnit' in quantity and 'name' in quantity['normalizedUnit']:
407 |             normalized_unit = quantity['normalizedUnit']['name']
408 |
409 |         if normalized_value and normalized_unit:
410 |             if has_space_between_value_and_unit(quantity):
411 |                 return str(normalized_value) + " " + str(normalized_unit)
412 |             else:
413 |                 return str(normalized_value) + str(normalized_unit)
414 |
415 |     @staticmethod
416 |     def get_raw(quantity):
417 |         raw_value = raw_unit = None
418 |         if 'rawValue' in quantity:
419 |             raw_value = quantity['rawValue']
420 |         if 'rawUnit' in quantity and 'name' in quantity['rawUnit']:
421 |             raw_unit = quantity['rawUnit']['name']
422 |
423 |         if raw_value and raw_unit:
424 |             if has_space_between_value_and_unit(quantity):
425 |                 return str(raw_value) + " " + str(raw_unit)
426 |             else:
427 |                 return str(raw_value) + str(raw_unit)
428 |
429 |
430 | class GrobidMaterialsProcessor(BaseProcessor):
431 |     def __init__(self, grobid_superconductors_client):
432 |         self.grobid_superconductors_client = grobid_superconductors_client
433 |
434 |     def process(self, text):
435 |         preprocessed_text = text.strip()
436 |         status, result = self.grobid_superconductors_client.process_text(preprocessed_text,
437 |                                                                          "processText_disable_linking")
438 |
439 |         if status != 200:
440 |             result = {}
441 |
442 |         spans = []
443 |
444 |         if 'passages' in result:
445 |             materials = self.parse_superconductors_output(result, preprocessed_text)
446 |
447 |             for m in materials:
448 |                 item = {"text": preprocessed_text[m['offset_start']:m['offset_end']]}
449 |
450 |                 item['offset_start'] = m['offset_start']
451 |                 item['offset_end'] = m['offset_end']
452 |
453 |                 if 'formula' in m:
454 |                     item["formula"] = m['formula']
455 |
456 |                 item['type'] = 'material'
457 |                 item['raw_value'] = m['text']
458 |
459 |                 spans.append(item)
460 |
461 |         return spans
462 |
463 |     def parse_materials(self, text):
464 |         status, result = self.grobid_superconductors_client.process_texts(text.strip(), "parseMaterials")
465 |
466 |         if status != 200:
467 |             result = []
468 |
469 |         results = []
470 |         for position_material in result:
471 |             compositions = []
472 |             for material in position_material:
473 |                 if 'resolvedFormulas' in material:
474 |                     for resolved_formula in material['resolvedFormulas']:
475 |                         if 'formulaComposition' in resolved_formula:
476 |                             compositions.append(resolved_formula['formulaComposition'])
477 |                 elif 'formula' in material:
478 |                     if 'formulaComposition' in material['formula']:
479 |                         compositions.append(material['formula']['formulaComposition'])
480 |             results.append(compositions)
481 |
482 |         return results
483 |
484 |     def parse_material(self, text):
485 |         status, result = self.grobid_superconductors_client.process_text(text.strip(), "parseMaterial")
486 |
487 |         if status != 200:
488 |             result = []
489 |
490 |         compositions = self.output_info(result)
491 |
492 |         return compositions
493 |
494 |     def output_info(self, result):
495 |         compositions = []
496 |         for material in result:
497 |             if 'resolvedFormulas' in material:
498 |                 for resolved_formula in material['resolvedFormulas']:
499 |                     if 'formulaComposition' in resolved_formula:
500 |                         compositions.append(resolved_formula['formulaComposition'])
501 |             elif 'formula' in material:
502 |                 if 'formulaComposition' in material['formula']:
503 |                     compositions.append(material['formula']['formulaComposition'])
504 |             if 'name' in material:
505 |                 compositions.append(material['name'])
506 |         return compositions
507 |
508 |     @staticmethod
509 |     def parse_superconductors_output(result, original_text):
510 |         materials = []
511 |
512 |         for passage in result['passages']:
513 |             sentence_offset = original_text.index(passage['text'])
514 |             if 'spans' in passage:
515 |                 spans = passage['spans']
516 |                 for material_span in filter(lambda s: s['type'] == '<material>', spans):
517 |                     text_ = material_span['text']
518 |
519 |                     base_material_information = {
520 |                         "text": text_,
521 |                         "offset_start": sentence_offset + material_span['offset_start'],
522 |                         'offset_end': sentence_offset + material_span['offset_end']
523 |                     }
524 |
525 |                     materials.append(base_material_information)
526 |
527 |         return materials
528 |
529 |
530 | class GrobidAggregationProcessor(GrobidQuantitiesProcessor, GrobidMaterialsProcessor):
531 |     def __init__(self, grobid_quantities_client=None, grobid_superconductors_client=None):
532 |         # keep both attributes defined, so that process_properties() and process_materials()
533 |         # do not fail when only one of the two clients is configured
534 |         self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client) if grobid_quantities_client else None
535 |         self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client) if grobid_superconductors_client else None
536 |
537 |     def process_single_text(self, text):
538 |         extracted_quantities_spans = self.process_properties(text)
539 |         extracted_materials_spans = self.process_materials(text)
540 |         all_entities = extracted_quantities_spans + extracted_materials_spans
541 |         entities = self.prune_overlapping_annotations(all_entities)
542 |         return entities
543 |
544 |     def process_properties(self, text):
545 |         if self.gqp:
546 |             return self.gqp.process(text)
547 |         else:
548 |             return []
549 |
550 |     def process_materials(self, text):
551 |         if self.gmp:
552 |             return self.gmp.process(text)
553 |         else:
554 |             return []
555 |
556 |     @staticmethod
557 |     def box_to_dict(box, color=None, type=None, border=None):
558 |
559 |         if box is None or box == "" or len(box) < 5:
560 |             return {}
561 |
562 |         item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]}
563 |         if color:
564 |             item['color'] = color
565 |
566 |         if type:
567 |             item['type'] = type
568 |
569 |         if border:
570 |             item['border'] = border
571 |
572 |         return item
573 |
574 |     @staticmethod
575 |     def prune_overlapping_annotations(entities: list) -> list:
576 |         # Sorting by offsets
577 |         sorted_entities = sorted(entities, key=lambda d: d['offset_start'])
578 |
579 |         if len(entities) <= 1:
580 |             return sorted_entities
581 |
582 |         to_be_removed = []
583 |
584 |         previous = None
585 |         first = True
586 |
587 |         for current in sorted_entities:
588 |             if first:
589 |                 first = False
590 |                 previous = current
591 |                 continue
592 |
593 |             if previous['offset_start'] < current['offset_start'] \
594 |                     and previous['offset_end'] < current['offset_end'] \
595 |                     and (previous['offset_end'] < current['offset_start'] \
596 |                          and not (previous['text'] == "-" and current['text'][0].isdigit())):
597 |                 previous = current
598 |                 continue
599 |
600 |             if previous['offset_end'] < current['offset_end']:
601 |                 if current['type'] == previous['type']:
602 |                     # Type is the same
603 |                     if current['offset_start'] == previous['offset_end']:
604 |                         if current['type'] == 'property':
605 |                             if current['text'].startswith("."):
606 |                                 print(
607 |                                     f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
608 |                                 # current entity starts with a ".", suspiciously looks like a truncated value
609 |                                 to_be_removed.append(previous)
610 |                                 current['text'] = previous['text'] + current['text']
611 |                                 current['raw_value'] = current['text']
612 |                                 current['offset_start'] = previous['offset_start']
613 |                             elif previous['text'].endswith(".") and current['text'][0].isdigit():
614 |                                 print(
615 |                                     f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
616 |                                 # previous entity ends with ".", current entity starts with a number
617 |                                 to_be_removed.append(previous)
618 |                                 current['text'] = previous['text'] + current['text']
619 |                                 current['raw_value'] = current['text']
620 |                                 current['offset_start'] = previous['offset_start']
621 |                             elif previous['text'].startswith("-"):
622 |                                 print(
623 |                                     f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
624 |                                 # previous starts with a `-`, most likely another truncated value
625 |                                 current['text'] = previous['text'] + current['text']
626 |                                 current['raw_value'] = current['text']
627 |                                 current['offset_start'] = previous['offset_start']
628 |                                 to_be_removed.append(previous)
629 |                             else:
630 |                                 print("Other cases to be considered: ", previous, current)
631 |                         else:
632 |                             if current['text'].startswith("-"):
633 |                                 print(
634 |                                     f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
635 |                                 # current starts with a `-`, most likely another truncated value
636 |                                 current['text'] = previous['text'] + current['text']
637 |                                 current['raw_value'] = current['text']
638 |                                 current['offset_start'] = previous['offset_start']
639 |                                 to_be_removed.append(previous)
640 |                             else:
641 |                                 print("Other cases to be considered: ", previous, current)
642 |
643 |                     elif previous['text'] == "-" and current['text'][0].isdigit():
644 |                         print(
645 |                             f"Merging. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
646 |                         # previous is exactly `-` and current starts with a digit, most likely a truncated negative value
647 |                         current['text'] = previous['text'] + " " * (current['offset_start'] - previous['offset_end']) + \
648 |                                           current['text']
649 |                         current['raw_value'] = current['text']
650 |                         current['offset_start'] = previous['offset_start']
651 |                         to_be_removed.append(previous)
652 |                     else:
653 |                         print(
654 |                             f"Overlapping. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
655 |
656 |                         # take the largest one
657 |                         if len(previous['text']) > len(current['text']):
658 |                             to_be_removed.append(current)
659 |                         elif len(previous['text']) < len(current['text']):
660 |                             to_be_removed.append(previous)
661 |                         else:
662 |                             to_be_removed.append(previous)
663 |                 elif current['type'] != previous['type']:
664 |                     print(
665 |                         f"Overlapping. {current['text']} <{current['type']}> with {previous['text']} <{previous['type']}>")
666 |
667 |                     if len(previous['text']) > len(current['text']):
668 |                         to_be_removed.append(current)
669 |                     elif len(previous['text']) < len(current['text']):
670 |                         to_be_removed.append(previous)
671 |                     else:
672 |                         if current['type'] == "material":
673 |                             to_be_removed.append(previous)
674 |                         else:
675 |                             to_be_removed.append(current)
676 |                 previous = current
677 |
678 |             elif previous['offset_end'] > current['offset_end']:
679 |                 to_be_removed.append(current)
680 |                 # the previous goes after the current, so we keep the previous and we discard the current
681 |             else:
682 |                 if current['type'] == "material":
683 |                     to_be_removed.append(previous)
684 |                 else:
685 |                     to_be_removed.append(current)
686 |                 previous = current
687 |
688 |         new_sorted_entities = [e for e in sorted_entities if e not in to_be_removed]
689 |
690 |         return new_sorted_entities
691 |
692 |
693 | class XmlProcessor(BaseProcessor):
694 |     def __init__(self):
695 |         super().__init__()
696 |
697 |     def process_structure(self, input_file):
698 |         text = ""
699 |         with open(input_file, encoding='utf-8') as fi:
700 |             text = fi.read()
701 |
702 |         output_data = self.parse_xml(text)
703 |         output_data['filename'] = Path(input_file).stem.replace(".tei", "")
704 |
705 |         return output_data
706 |
707 |     # def process_single(self, input_file):
708 |     #     doc = self.process_structure(input_file)
709 |     #
710 |     #     for paragraph in doc['passages']:
711 |     #         entities = self.process_single_text(paragraph['text'])
712 |     #         paragraph['spans'] = entities
713 |     #
714 |     #     return doc
715 |
716 |     def process(self, text):
717 |         output_data = OrderedDict()
718 |         soup = BeautifulSoup(text, 'xml')
719 |         text_blocks_children = get_children_list_supermat(soup, verbose=False)
720 |
721 |         passages = []
722 |         output_data['passages'] = passages
723 |         passages.extend([
724 |             {
725 |                 "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
726 |                                                   text.parent.name != "ref" or (
727 |                                                           text.parent.name == "ref" and text.parent.attrs[
728 |                                                       'type'] != 'bibr'))),
729 |                 "type": "paragraph",
730 |                 "section": "<body>",
731 |                 "subSection": "<paragraph>",
732 |                 "passage_id": str(paragraph_id) + str(sentence_id)
733 |             }
734 |             for paragraph_id, paragraph in enumerate(text_blocks_children) for
735 |             sentence_id, sentence in enumerate(paragraph)
736 |         ])
737 |
738 |         return output_data
739 |
740 |
741 | def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
742 |     children = []
743 |
744 |     child_name = "p" if use_paragraphs else "s"
745 |     for child in soup.tei.children:
746 |         if child.name == 'teiHeader':
747 |             pass
748 |             children.append(child.find_all("title"))
749 |             children.extend([subchild.find_all(child_name) for subchild in child.find_all("abstract")])
750 |             children.extend([subchild.find_all(child_name) for subchild in child.find_all("ab", {"type": "keywords"})])
751 |         elif child.name == 'text':
752 |             children.extend([subchild.find_all(child_name) for subchild in child.find_all("body")])
753 |
754 |     if verbose:
755 |         print(str(children))
756 |
757 |     return children
758 |
759 |
760 | def get_children_list_grobid(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
761 |     children = []
762 |
763 |     child_name = "p" if use_paragraphs else "s"
764 |     for child in soup.TEI.children:
765 |         if child.name == 'teiHeader':
766 |             pass
767 |             # children.extend(child.find_all("title", attrs={"level": "a"}, limit=1))
768 |             # children.extend([subchild.find_all(child_name) for subchild in child.find_all("abstract")])
769 |         elif child.name == 'text':
770 |             children.extend([subchild.find_all(child_name) for subchild in child.find_all("body")])
771 |             children.extend([subchild.find_all("figDesc") for subchild in child.find_all("body")])
772 |
773 |     if verbose:
774 |         print(str(children))
775 |
776 |     return children
777 |
778 |
779 | def get_xml_nodes_header(soup: object, use_paragraphs: bool = True) -> dict:
780 |     sub_tag = "p" if use_paragraphs else "s"
781 |
782 |     header_elements = {
783 |         "authors": [persNameNode for persNameNode in soup.teiHeader.find_all("persName")],
784 |         "abstract": [p_in_abstract for abstractNodes in soup.teiHeader.find_all("abstract") for p_in_abstract in
785 |                      abstractNodes.find_all(sub_tag)],
786 |         "title": [soup.teiHeader.fileDesc.title]
787 |     }
788 |
789 |     return header_elements
790 |
791 |
792 | def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
793 |     nodes = []
794 |     tag_name = "p" if use_paragraphs else "s"
795 |     for child in soup.TEI.children:
796 |         if child.name == 'text':
797 |             # nodes.extend([subchild.find_all(tag_name) for subchild in child.find_all("body")])
798 |             nodes.extend(
799 |                 [subsubchild for subchild in child.find_all("body") for subsubchild in subchild.find_all(tag_name)])
800 |
801 |     if verbose:
802 |         print(str(nodes))
803 |
804 |     return nodes
805 |
806 |
807 | def get_xml_nodes_back(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
808 |     nodes = []
809 |     tag_name = "p" if use_paragraphs else "s"
810 |     for child in soup.TEI.children:
811 |         if child.name == 'text':
812 |             nodes.extend(
813 |                 [subsubchild for subchild in child.find_all("back") for subsubchild in subchild.find_all(tag_name)])
814 |
815 |     if verbose:
816 |         print(str(nodes))
817 |
818 |     return nodes
819 |
820 |
821 | def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list:
822 |     children = []
823 |     for child in soup.TEI.children:
824 |         if child.name == 'text':
825 |             children.extend(
826 |                 [subchild for subchilds in child.find_all("body") for subchild in subchilds.find_all("figDesc")])
827 |
828 |     if verbose:
829 |         print(str(children))
830 |
831 |     return children
832 |
--------------------------------------------------------------------------------
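A minimal usage sketch for the overlap pruning implemented in grobid_processors.py above: prune_overlapping_annotations is a @staticmethod, so it can be tried without configuring any Grobid client. The snippet below is illustrative only (the sample entities are made up) and assumes the dependencies from requirements.txt are installed so that document_qa.grobid_processors can be imported.

from document_qa.grobid_processors import GrobidAggregationProcessor

# Two overlapping spans: a <material> mention "MgB2" (offsets 10-14) and a
# quantity wrongly picked up inside it, the trailing "2" (offsets 13-14).
entities = [
    {"text": "MgB2", "type": "material", "raw_value": "MgB2", "offset_start": 10, "offset_end": 14},
    {"text": "2", "type": "property", "raw_value": "2", "offset_start": 13, "offset_end": 14},
]

pruned = GrobidAggregationProcessor.prune_overlapping_annotations(entities)

# The span fully covered by the material is discarded; only the <material> entity survives.
print(pruned)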