├── .gitignore ├── LICENSE ├── README.md ├── img ├── Flowchart.png └── screenshot.png ├── neo4j_primekg ├── Dockerfile ├── README.md ├── graph_preprocessing_scripts │ ├── __init__.py │ ├── allowed_preferred_name_mapping_categories.py │ ├── apply_preferred_term_mapping.py │ ├── drugbank_id_mapper.py │ ├── id_based_mapper.py │ ├── mapper.py │ ├── mondo_id_mapper.py │ ├── preferred_term_category_finder.py │ └── preferred_term_mapper.py ├── import_primekg.sh └── primekg_to_neo4j_csv.py ├── pyproject.toml ├── src ├── fact_finder │ ├── __init__.py │ ├── app.py │ ├── chains │ │ ├── __init__.py │ │ ├── answer_generation_chain.py │ │ ├── combined_graph_rag_qa_chain.py │ │ ├── cypher_query_generation_chain.py │ │ ├── cypher_query_preprocessors_chain.py │ │ ├── entity_detection_question_preprocessing_chain.py │ │ ├── filtered_primekg_question_preprocessing_chain.py │ │ ├── graph_chain.py │ │ ├── graph_qa_chain │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── early_stopping_chain.py │ │ │ ├── graph_qa_chain.py │ │ │ ├── output.py │ │ │ └── output_chain.py │ │ ├── graph_summary_chain.py │ │ ├── rag │ │ │ ├── __init__.py │ │ │ ├── semantic_scholar_chain.py │ │ │ └── text_search_qa_chain.py │ │ └── subgraph_extractor_chain.py │ ├── config │ │ ├── __init__.py │ │ ├── primekg_config.py │ │ ├── primekg_predicate_descriptions.py │ │ └── simple_config.py │ ├── evaluator │ │ ├── __init__.py │ │ ├── adv_result_gpt4_o.json │ │ ├── adv_result_gpt4_turbo.json │ │ ├── adversarial_attack_evaluation.py │ │ ├── evaluation.py │ │ ├── evaluation_sample.py │ │ ├── evaluation_samples.py │ │ ├── llm_judge │ │ │ └── llm_judge_evaluator.py │ │ ├── score │ │ │ ├── __init__.py │ │ │ ├── bleu_score.py │ │ │ ├── difflib_score.py │ │ │ ├── embedding_score.py │ │ │ ├── levenshtein_score.py │ │ │ └── score.py │ │ ├── set_evaluator │ │ │ ├── __init__.py │ │ │ └── set_evaluator.py │ │ └── util.py │ ├── prompt_templates.py │ ├── py.typed │ ├── tools │ │ ├── __init__.py │ │ ├── cypher_preprocessors │ │ │ ├── __init__.py │ │ │ ├── always_distinct_preprocessor.py │ │ │ ├── child_to_parent_preprocessor.py │ │ │ ├── cypher_query_preprocessor.py │ │ │ ├── format_preprocessor.py │ │ │ ├── lower_case_properties_cypher_query_preprocessor.py │ │ │ ├── property_string_preprocessor.py │ │ │ ├── size_to_count_preprocessor.py │ │ │ └── synonym_cypher_query_preprocessor.py │ │ ├── entity_detector.py │ │ ├── semantic_scholar_search_api_wrapper.py │ │ ├── sub_graph_extractor.py │ │ ├── subgraph_extension.py │ │ └── synonym_finder │ │ │ ├── __init__.py │ │ │ ├── aggregate_state_synonym_finder.py │ │ │ ├── preferred_term_finder.py │ │ │ ├── synonym_finder.py │ │ │ ├── wiki_data_synonym_finder.py │ │ │ └── word_net_synonym_finder.py │ ├── ui │ │ ├── __init__.py │ │ ├── graph_conversion.py │ │ └── util.py │ └── utils.py └── img │ └── logo.png └── tests ├── __init__.py ├── chains ├── __init__.py ├── graph_qa_chain │ ├── __init__.py │ ├── test_graph_qa_chain.py │ └── test_graph_qa_chain_e2e.py ├── helpers.py ├── test_answer_generation_chain.py ├── test_cypher_query_generation_chain.py ├── test_entity_detection_question_preprocessing_chain.py ├── test_filtered_primekg_question_preprocessing_chain.py ├── test_graph_chain.py ├── test_graph_summary_chain.py ├── test_preprocessors_chain.py ├── test_subgraph_extractor_chain.py └── test_text_search_qa_chain.py ├── evaluator ├── __init__.py ├── score │ ├── __init__.py │ ├── test_bleu_score.py │ ├── test_difflib_score.py │ └── test_embedding_score.py └── set_evaluator │ └── test_set_evaluator.py └── tools ├── __init__.py ├── cypher_preprocessors ├── test_always_distinct_preprocessor.py ├── test_child_to_parent_preprocessor.py ├── test_format_preprocessor.py ├── test_lower_case_property_names.py ├── test_size_to_count_preprocessor.py └── test_synonym_query_preprocessor.py ├── semantic_scholar_search_api_wrapper_test.py ├── synonym_finder └── test_word_net_synonym_finder.py ├── test_entity_detector.py ├── test_llm_subgraph_extractor.py ├── test_regex_subgraph_extractor.py └── test_subgraph_expansion.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | /.idea/inspectionProfiles/profiles_settings.xml 162 | /.idea/inspectionProfiles/Project_Default.xml 163 | /.idea/.gitignore 164 | /.idea/fact-finder.iml 165 | /.idea/jupyter-settings.xml 166 | /.idea/misc.xml 167 | /.idea/modules.xml 168 | /.idea/vcs.xml 169 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Christopher Schymura 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fact Finder 2 | 3 |

4 | 5 | ## Getting Started 6 | 7 | Set up PrimeKG Neo4j Instance, see [here](neo4j_primekg/README.md) 8 | 9 | Install Dependencies: 10 | 11 | ``` 12 | pip install -e . 13 | ``` 14 | 15 | Some features of FactFinder are based on external APIs. While an openai api key is required to run FactFinder, semantic scholar as well as Bayer's linnaeusannotate entity detection are optional. 16 | Set environment variables: 17 | 18 | ``` 19 | export LLM="gpt-4o" # "gpt-4-turbo" as an alternative 20 | export SEMANTIC_SCHOLAR_KEY="" # fill API key for semantic scholar 21 | export OPENAI_API_KEY="" # fill opanAI api key 22 | export SYNONYM_API_KEY="" # Bayer internal linnaeusannotate synonym API key 23 | export SYNONYM_API_URL="" # Bayer internal linnaeusannotate synonym API url 24 | ``` 25 | 26 | Run UI: 27 | 28 | ``` 29 | streamlit run src/fact_finder/app.py --browser.serverAddress localhost 30 | ``` 31 | 32 | Running with additional arguments (e.g. activating the normalized graph synonyms): 33 | 34 | ``` 35 | streamlit run src/fact_finder/app.py --browser.serverAddress localhost -- [args] 36 | streamlit run src/fact_finder/app.py --browser.serverAddress localhost -- --normalized_graph --use_entity_detection_preprocessing 37 | ``` 38 | 39 | The following flags are available: 40 | ``` 41 | --normalized_graph = Apply synonym replacement based on the normalized graph to the cypher queries before applying them to the graph. 42 | --use_entity_detection_preprocessing = Apply entity detection to the user question before generating the cypher query. The found entities will be replaced by their preferred terms and a string describing their category (e.g. "Psoriasis is a disease.") will be added to the query. This requires the corresponding api key ($SYNONYM_API_KEY) to be set. Also, the normalized graph should be used. 43 | --use_subgraph_expansion = The evidence graph gets expanded through the surrounding neighborhoods. 44 | ``` 45 | 46 | ## Process description 47 | 48 | The following steps are undertaken to get from the user question to the natural language answer and the provided evidence: 49 | 50 | 1. In the first step, a language model call is used to generate a cypher query to the knowledge graph. To achieve this, the prompt template contains the schema of the graph, i.e. information about all nodes and their properties. 51 | Additionally, the prompt template can be enriched with natural language descriptions for (some of) the relations in the graph allowing better understanding of their meaning for the language model. 52 | In case the model decides that the user question cannot be answered by a graph with the given schema, the model is instructed to return an error message starting with the marker string "SCHEMA_ERROR". This is then detected and the error message is directly forwarded to the user. 53 | 54 | 2. In the second step, the generated cypher query is preprocessed using regular expressions. 55 | - First, a formatting is applied in order to make subsequent regular expressions easier to design. This includes for example removal of unnecessary whitespaces and using double quotes for all strings. 56 | - Next, all property values are turned to lower case. This assumes that a similar preprocessing has been done for the property values in the graph and makes the query resistant to capitalization mismatches. 57 | - Finally, for some node types, any names used in the query, are replaced with a synonym that is actually used in the graph. This is (for example) done by looking up synonyms for the name and checking which one actually exists in the graph. 58 | 59 | 3. In the third step, the graph is queried with the final result of the cypher preprocessing. The graph answer together with the cypher query are part of the evidence presented in the interface, allowing transparency for the user. 60 | 61 | 4. With another language model call, the final natural language answer is generated from the result of querying the graph. 62 | 63 | 5. Additionally, a subgraph is generated from the graph query and result. This serves as visual evidence for the user. The subgraph can either be generated via a rule based approach or also with help of the language model. 64 | 65 | 66 | ## User Interface 67 | The following image shows the user interface of the application for the question *"Which drugs are used to treat ocular hypertension?"*. The answers of the standalone LLM and our graph-based hybrid system are compared as output. In addition, the relevant subgraph is displayed as evidence together with the generated Cypher query, the answer from the graph and the prompts used. 68 | 69 |

-------------------------------------------------------------------------------- /img/Flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/img/Flowchart.png -------------------------------------------------------------------------------- /img/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/img/screenshot.png -------------------------------------------------------------------------------- /neo4j_primekg/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM neo4j:5.15.0 2 | 3 | ENV IMPORT_DIR=/primekg_data 4 | ENV PRIMEKG_CSV=$IMPORT_DIR/kg.csv 5 | ENV PRIMEKG_DRUG_NODE_FEATURES_TSV=$IMPORT_DIR/drug_features.tab 6 | ENV PRIMEKG_DISEASE_NODE_FEATURES_TSV=$IMPORT_DIR/disease_features.tab 7 | ENV PRIMEKG_CSVS_FOR_NEO4J=$IMPORT_DIR/preprocessed_for_neo4j 8 | ENV NEO4J_PLUGINS='["apoc"]' 9 | 10 | COPY import_primekg.sh /startup/import_primekg.sh 11 | COPY primekg_to_neo4j_csv.py $IMPORT_DIR/primekg_to_neo4j_csv.py 12 | 13 | RUN apt-get update && \ 14 | apt-get install -y python3-pip && \ 15 | pip install pandas pyarrow && \ 16 | rm -rf /var/lib/apt/lists/* && \ 17 | mkdir -p $PRIMEKG_CSVS_FOR_NEO4J && \ 18 | chmod 777 /startup/import_primekg.sh 19 | 20 | ENTRYPOINT ["tini", "-g", "--"] 21 | CMD ["/startup/import_primekg.sh"] 22 | -------------------------------------------------------------------------------- /neo4j_primekg/README.md: -------------------------------------------------------------------------------- 1 | # Neo4j Server with PrimeKG 2 | 3 | This docker image downloads the PrimeKG data, imports it into a Neo4j database and runs the Neo4j service. 4 | 5 | ## Setup 6 | 7 | Build docker image: 8 | ``` 9 | docker build --pull --rm -f "neo4j_primekg/Dockerfile" -t neo4j_primekg:latest "neo4j_primekg" 10 | ``` 11 | 12 | Start up server: 13 | Note that this will take a while as the PrimeKG files are downloaded here. 14 | ``` 15 | docker run -d --restart=always \ 16 | --publish=7474:7474 --publish=7687:7687 \ 17 | --env NEO4J_AUTH=neo4j/opensesame \ 18 | --env NEO4J_server_databases_default__to__read__only=true \ 19 | --env NEO4J_apoc_export_file_enabled=true \ 20 | --env NEO4J_apoc_import_file_enabled=true \ 21 | --env NEO4J_apoc_import_file_use__neo4j__config=true \ 22 | --env NEO4JLABS_PLUGINS=\[\"apoc\"\] \ 23 | --name neo4j_primekg_service \ 24 | neo4j_primekg:latest 25 | ``` 26 | 27 | Test via Cypher shell: 28 | ``` 29 | docker exec -it cypher-shell -u neo4j -p opensesame 30 | ``` 31 | ``` 32 | MATCH (disease1:disease {name: "psoriasis"})-[:parent_child]->(disease2:disease {name: "scalp disease"}) 33 | RETURN disease1, disease2; 34 | ``` 35 | 36 | Alternative start-up using an already downloaded files: 37 | ``` 38 | docker run -d --restart=always \ 39 | --publish=7474:7474 --publish=7687:7687 \ 40 | --volume :/primekg_data/kg.csv:ro \ 41 | --volume :/primekg_data/drug_features.tab:ro \ 42 | --volume :/primekg_data/disease_features.tab:ro \ 43 | --env NEO4J_AUTH=neo4j/opensesame \ 44 | --env NEO4J_server_databases_default__to__read__only=true \ 45 | --env NEO4J_apoc_export_file_enabled=true \ 46 | --env NEO4J_apoc_import_file_enabled=true \ 47 | --env NEO4J_apoc_import_file_use__neo4j__config=true \ 48 | --env NEO4JLABS_PLUGINS=\[\"apoc\"\] \ 49 | --name neo4j_primekg_service \ 50 | neo4j_primekg:latest 51 | ``` 52 | 53 | ## Import and cleaning process for PrimeKG data 54 | 55 | These are the steps taken when starting a new Docker container for the image build from the given Dockerfile. These steps are executed by the scripts "import_primekg.sh" and "primekg_to_neo4j_csv.py". 56 | 57 | First, the data is downloaded from Harvard Dataverse unless it was linked into the container via volume. The data in this case is the kg.csv containing the actual graph (in form of all its edges) and the additional features for drug and disease nodes (drug_features.tab and disease_features.tab). 58 | 59 | In the next step, the data gets loaded via pandas and some general clean up steps are performed: 60 | - All string entries are converted to lower case in order to make queries to the graph more robust. 61 | - In the columns that encode node or relation types ("relation", "display_relation", "x_type", "y_type") spaces and - get replaced with _ since these symbols do not work in Cypher queries. 62 | - Similarly, / gets replaced by _ or _ because this symbol may cause problems in Cypher queries. 63 | - For the drug and disease data, the following replacements are executed: 64 | - "\r" -> "" 65 | - "\n" -> " " 66 | - ",," -> "," 67 | - '"' -> "" 68 | 69 | In the third step, the nodes are extracted from the graph data. Nodes get separated by type. For each node in each type the properties index, id, name and source are extracted. 70 | In this step, the additional features for the drug and disease nodes also get extracted. For the drug nodes, these are certain properties, like the aggregate state or the molecular weight. 71 | For the disease node, several textual descriptions get added to the graph. This includes a description of the disease, its symptoms or when to see a doctor. Note that for the disease description up to four possible candidates are available and they get prioritized as follows: 72 | 1. orphanet_clinical_description 73 | 2. mondo_definition 74 | 3. umls_description 75 | 4. orphanet_definition 76 | 77 | In the fourth step, the relation data from kg.csv gets extracted. They can either be extracted based on the display_relation column (default) or the relation column in the data. No properties are added. 78 | 79 | Subsequently, additional nodes and relations are extracted from the drug features. These nodes are category nodes and approval status nodes to which the drugs can be linked. 80 | 81 | Finally, CSV files for the Neo4j import are built. For each node type a CSV file is generated where the index column is named as "index:ID(< type >)". Similarly, for each relation a CSV file is generated where start and end column are marked via ":START_ID(< node_type >)" and ":END_ID(< node_type >)". Note, that the bidirectional relations are treated as two different relations. 82 | 83 | The script generates the Neo4j import command based on the generated files and the import gets executed. 84 | 85 | Note that the original files (kg.csv etc.) will be deleted if and only if they were downloaded in the beginning. 86 | 87 | ## Citation 88 | 89 | The PrimeKG data was made available by Chandak, Payal and Huang, Kexin and Zitnik, Marinka in Nature Scientific Data, 2023: [Building a knowledge graph to enable precision medicine](https://www.nature.com/articles/s41597-023-01960-3) 90 | 91 | The data is available in [Harvard Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/IXA7BM). 92 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/neo4j_primekg/graph_preprocessing_scripts/__init__.py -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/allowed_preferred_name_mapping_categories.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Set 2 | 3 | _allowed_categories_for_each_type: Dict[str, List[str]] = { 4 | "gene/protein": [ 5 | "Gene", 6 | # "Disease", 7 | # "CellLine", 8 | # "Drug", 9 | # "Organs" 10 | ], 11 | "drug": [ 12 | # "Gene", 13 | # "Antibody", 14 | # "Disease", 15 | # "CellLine", 16 | # "Cells", 17 | "Drug", 18 | # "Organs" 19 | ], 20 | "effect/phenotype": [ 21 | # "Gene", 22 | # "Antibody", 23 | # "Disease", 24 | # "CellLine", 25 | # "Cells", 26 | # "Drug", 27 | # "Organs" 28 | ], 29 | "disease": [ 30 | # "Gene", 31 | # "Antibody", 32 | "Disease", 33 | # "CellLine", 34 | # "Cells", 35 | # "Drug", 36 | # "Organs" 37 | ], 38 | "biological_process": [ 39 | # "Gene", 40 | # "Antibody", 41 | # "Disease", 42 | # "CellLine", 43 | # "Cells", 44 | # "Drug", 45 | # "Organs" 46 | ], 47 | "molecular_function": [ 48 | # "Gene", 49 | # "Antibody", 50 | # "CellLine", 51 | # "Disease", 52 | # "Cells", 53 | # "Drug", 54 | # "Organs" 55 | ], 56 | "cellular_component": [ 57 | # "Gene", 58 | # "Antibody", 59 | # "Disease", 60 | # "CellLine", 61 | # "Cells", 62 | # "Drug", 63 | # "Organs" 64 | ], 65 | "exposure": [ 66 | # "Gene", 67 | # "Disease", 68 | # "CellLine", 69 | # "Drug", 70 | # "Organs" 71 | ], 72 | "pathway": [ 73 | # "Gene", 74 | # "Antibody", 75 | # "CellLine", 76 | # "Disease", 77 | # "Cells", 78 | # "Drug", 79 | # "Organs" 80 | ], 81 | "anatomy": [ 82 | # "Gene", 83 | # "Disease", 84 | # "CellLine", 85 | # "Cells", 86 | # "Drug", 87 | "Organs" 88 | ], 89 | } 90 | 91 | allowed_categories_for_each_type: Dict[str, Set[str]] = { 92 | k: set(v) for k, v in _allowed_categories_for_each_type.items() if v 93 | } 94 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/apply_preferred_term_mapping.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List 3 | 4 | try: 5 | import polars as pd 6 | except: 7 | import pandas as pd 8 | 9 | from drugbank_id_mapper import DrugbankIdMapper 10 | from id_based_mapper import IdBasedMapper 11 | from mapper import _POLARS_AVAILABLE 12 | from mondo_id_mapper import MondoIdMapper 13 | from preferred_term_mapper import PreferredTermMapper 14 | 15 | 16 | def primekg_main(args: argparse.Namespace): 17 | graph_file = args.graph_file 18 | output_file = args.output_file 19 | mapping_dir = args.mapping_dir 20 | mondo_id_mapping = args.mondo_id_mapping 21 | drugbank_id_mapping = args.drugbank_id_mapping 22 | 23 | print(f'Loading graph from "{graph_file}...') 24 | graph = pd.read_csv(graph_file, low_memory=False, infer_schema_length=0) 25 | id_based_mappers: List[IdBasedMapper] = [] 26 | if mondo_id_mapping is not None: 27 | print(f'Loading Mondo id mapping from "{mondo_id_mapping}...') 28 | id_based_mappers.append(MondoIdMapper(graph, mondo_id_mapping)) 29 | if drugbank_id_mapping is not None: 30 | print(f'Loading DrugBank id mapping from "{drugbank_id_mapping}...') 31 | id_based_mappers.append(DrugbankIdMapper(graph, drugbank_id_mapping)) 32 | print(f"Applying id based mappings...") 33 | for mapping in id_based_mappers: 34 | graph = mapping.apply_to_graph(graph) 35 | print(f'Loading mappings from "{mapping_dir}...') 36 | mapper_fct = PreferredTermMapper(mapping_dir, graph) 37 | print(f"Applying mappings to graph...") 38 | graph = mapper_fct.apply_to_graph(graph) 39 | print(f'Writing transformed graph to "{output_file}"...') 40 | if _POLARS_AVAILABLE: 41 | graph.write_csv(output_file) 42 | else: 43 | graph.to_csv(output_file) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument("--graph_file", type=str, required=True) 50 | parser.add_argument("--output_file", type=str, required=True) 51 | parser.add_argument("--mapping_dir", type=str, required=True) 52 | parser.add_argument("--mondo_id_mapping", type=str, default=None) 53 | parser.add_argument("--drugbank_id_mapping", type=str, default=None) 54 | 55 | primekg_main(parser.parse_args()) 56 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/drugbank_id_mapper.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Set, Tuple 2 | 3 | import pandas as pd 4 | from id_based_mapper import IdBasedMapper 5 | from mapper import _POLARS_AVAILABLE 6 | 7 | 8 | class DrugbankIdMapper(IdBasedMapper): 9 | def __init__( 10 | self, 11 | graph: pd.DataFrame, 12 | drugbank_refs_file: str, 13 | drugbank_id_key: str = "DrugBank ID", 14 | new_id_key1: str = "SCI_DRUG ID", 15 | new_id_key2: str = "ID", 16 | name_key: str = "Preferred_Name", 17 | id_to_id_map_sheet: str = "concordance", 18 | id_to_pref_names_sheet: str = "csvdata", 19 | use_polars: bool = _POLARS_AVAILABLE, 20 | ) -> None: 21 | super().__init__(graph, use_polars=use_polars) 22 | self._id_to_id_map = pd.read_excel(drugbank_refs_file, sheet_name=id_to_id_map_sheet) 23 | self._id_to_pref_names = pd.read_excel(drugbank_refs_file, sheet_name=id_to_pref_names_sheet) 24 | self._drugbank_id_key = drugbank_id_key 25 | self._new_id_key1 = new_id_key1 26 | self._new_id_key2 = new_id_key2 27 | self._name_key = name_key 28 | 29 | def _get_relevant_node_types(self) -> Set[str]: 30 | return set(["drug"]) 31 | 32 | def _get_mapping_for_id(self, graph_id: str) -> Tuple[str, Optional[str]]: 33 | id_to_id_res = self._id_to_id_map[self._id_to_id_map[self._drugbank_id_key] == graph_id] 34 | if len(id_to_id_res) == 0: 35 | return graph_id, None 36 | new_id = id_to_id_res[self._new_id_key1].iloc[0] 37 | new_names = self._id_to_pref_names[self._id_to_pref_names[self._new_id_key2] == new_id][self._name_key] 38 | if len(new_names) == 0: 39 | return graph_id, None 40 | new_name = new_names.iloc[0] 41 | return new_id, new_name 42 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/id_based_mapper.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List, Optional, Set, Tuple, Union 3 | 4 | try: 5 | import polars as pd 6 | except: 7 | import pandas as pd 8 | 9 | from mapper import Mapper 10 | 11 | 12 | class IdBasedMapper(Mapper): 13 | @abstractmethod 14 | def _get_relevant_node_types(self) -> Set[str]: 15 | """Create a list of node types to which this mapping is applied. 16 | 17 | :return: List of ids. 18 | :rtype: List[str] 19 | """ 20 | 21 | @abstractmethod 22 | def _get_mapping_for_id(self, graph_id: str) -> Tuple[str, Optional[str]]: 23 | """Generates the corresponding id and preferred term/name for a given id. 24 | 25 | :param graph_id: Id in the base graph. 26 | :type graph_id: str 27 | :return: The new id and the new name. If no mapping exists, (graph_id, None) is returned for. 28 | :rtype: Tuple[str, Optional[str]] 29 | """ 30 | ... 31 | 32 | def _call_polars(self, row: Tuple[str, ...]) -> Tuple[str, ...]: 33 | row_lst = list(row) 34 | row_lst = self._apply_mapping(row_lst) 35 | return tuple(row_lst) 36 | 37 | def _call_pandas(self, row: "pd.Series") -> "pd.Series": 38 | return self._apply_mapping(row) 39 | 40 | def _apply_mapping(self, row: Union["pd.Series", List[str]]) -> Union["pd.Series", List[str]]: 41 | node_type, id_in_graph = row[self._x_type_idx], row[self._x_id_idx] 42 | if node_type in self._get_relevant_node_types(): 43 | new_id, new_name = self._get_mapping_for_id(id_in_graph) 44 | if new_name is not None: 45 | row[self._x_name_idx] = new_name 46 | row[self._x_id_idx] = new_id 47 | node_type, id_in_graph = row[self._y_type_idx], row[self._y_id_idx] 48 | if node_type in self._get_relevant_node_types(): 49 | new_id, new_name = self._get_mapping_for_id(id_in_graph) 50 | if new_name is not None: 51 | row[self._y_name_idx] = new_name 52 | row[self._y_id_idx] = new_id 53 | return row 54 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/mapper.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Union 3 | 4 | from tqdm import tqdm 5 | 6 | _POLARS_AVAILABLE = False 7 | try: 8 | import polars as pd 9 | 10 | _POLARS_AVAILABLE = True 11 | except: 12 | import logging 13 | 14 | import pandas as pd 15 | 16 | logging.warning("Polars not available for CSV processing. Using pandas which might be slow.") 17 | 18 | 19 | class Mapper(ABC): 20 | 21 | def __init__(self, graph: pd.DataFrame, use_polars: bool = _POLARS_AVAILABLE) -> None: 22 | self._use_polars = use_polars 23 | self._set_column_accessors(graph) 24 | 25 | def apply_to_graph(self, graph: pd.DataFrame) -> pd.DataFrame: 26 | if self._use_polars: 27 | col_names = graph.columns 28 | graph = graph.map_rows(self) 29 | graph.columns = col_names 30 | else: 31 | tqdm.pandas() 32 | graph = graph.progress_apply(self, axis="columns", result_type="broadcast") 33 | return graph 34 | 35 | def __call__(self, row: Union["pd.Series", Tuple[str, ...]]) -> Union["pd.Series", Tuple[str]]: 36 | if self._use_polars: 37 | return self._call_polars(row) 38 | return self._call_pandas(row) 39 | 40 | @abstractmethod 41 | def _call_polars(self, row: Tuple[str, ...]) -> Tuple[str, ...]: ... 42 | 43 | @abstractmethod 44 | def _call_pandas(self, row: "pd.Series") -> "pd.Series": ... 45 | 46 | def _set_column_accessors(self, df: pd.DataFrame): 47 | if self._use_polars: 48 | self._x_type_idx = df.get_column_index("x_type") 49 | self._x_name_idx = df.get_column_index("x_name") 50 | self._x_id_idx = df.get_column_index("x_id") 51 | self._y_type_idx = df.get_column_index("y_type") 52 | self._y_name_idx = df.get_column_index("y_name") 53 | self._y_id_idx = df.get_column_index("y_id") 54 | else: 55 | self._x_type_idx = "x_type" 56 | self._x_name_idx = "x_name" 57 | self._x_id_idx = "x_id" 58 | self._y_type_idx = "y_type" 59 | self._y_name_idx = "y_name" 60 | self._y_id_idx = "y_id" 61 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/mondo_id_mapper.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from typing import List, Optional, Set, Tuple, Union 3 | 4 | import pandas as pd 5 | from id_based_mapper import IdBasedMapper 6 | 7 | 8 | class MondoColumn(StrEnum): 9 | DIRECT_MATCHING = "MONDO-ID\nMATCHING ID(s)" 10 | CROSSREFERENCES_MATCHING = "MONDO-ID\nVIA CROSSREFS AND PREFERRED LABEL MATCH" 11 | LEVENSHTEIN_MATCHING = "cell2Vlookup_left_all[MONDO: AU→B]_A\nSYNONYMS MATCH (>75%Levenshtein)" 12 | 13 | 14 | class MondoIdMapper(IdBasedMapper): 15 | def __init__( 16 | self, 17 | graph: pd.DataFrame, 18 | mondo_refs_file: str, 19 | mondo_id_key: Union[MondoColumn, str] = MondoColumn.LEVENSHTEIN_MATCHING, 20 | new_id_key: str = "BAYER ID", 21 | name_key: str = "name", 22 | id_to_id_map_sheet: str = "concordance", 23 | id_to_pref_names_sheet: str = "SCI_DISEASE", 24 | ) -> None: 25 | super().__init__(graph) 26 | self._id_to_id_map = pd.read_excel(mondo_refs_file, sheet_name=id_to_id_map_sheet) 27 | self._id_to_pref_names = pd.read_excel(mondo_refs_file, sheet_name=id_to_pref_names_sheet) 28 | self._mondo_id_key = str(mondo_id_key) 29 | self._new_id_key = new_id_key 30 | self._name_key = name_key 31 | 32 | def _get_relevant_node_types(self) -> Set[str]: 33 | return set(["disease"]) 34 | 35 | def _get_mapping_for_id(self, graph_id: str) -> Tuple[str, Optional[str]]: 36 | graph_id = f"MONDO:{int(graph_id):07d}" 37 | id_to_id_res = self._id_to_id_map[self._id_to_id_map[self._mondo_id_key] == graph_id] 38 | if len(id_to_id_res) == 0: 39 | return graph_id, None 40 | new_id = id_to_id_res[self._new_id_key].iloc[0] 41 | new_name = self._id_to_pref_names[self._id_to_pref_names[self._new_id_key] == new_id][self._name_key].iloc[0] 42 | return new_id, new_name 43 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/preferred_term_category_finder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import sys 5 | from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union 6 | 7 | import pandas as pd 8 | from tqdm import tqdm 9 | 10 | 11 | def primekg_main(args: List[str]): 12 | from fact_finder.tools.entity_detector import EntityDetector 13 | 14 | if len(args) != 2: 15 | print(f"Usage: python {args[0]} ") 16 | exit(1) 17 | 18 | graph = pd.read_csv(args[1], low_memory=False) 19 | detector = EntityDetector() 20 | extractor = PreferredTermDataExtractor(graph=graph, api_tool=detector) 21 | 22 | graph_types = [ 23 | "gene/protein", 24 | "drug", 25 | "effect/phenotype", 26 | "disease", 27 | "biological_process", 28 | "molecular_function", 29 | "cellular_component", 30 | "exposure", 31 | "pathway", 32 | "anatomy", 33 | ] 34 | 35 | for graph_type in graph_types: 36 | print("---" * 20) 37 | print(f"PROCESSING {graph_type}") 38 | extractor.process(graph_type) 39 | print("---" * 20) 40 | 41 | 42 | class PreferredTermMapping(dict): 43 | def __init__(self, data: Iterable[Tuple[str, List[Tuple[str, str, str]]]]): 44 | super().__init__(data) 45 | 46 | @classmethod 47 | def from_dataframe( 48 | cls, 49 | graph_data: pd.DataFrame, 50 | graph_type: str, 51 | api_tool: Callable[[str], List[Dict[str, Any]]], 52 | type_key: str = "x_type", 53 | name_key: str = "x_name", 54 | ) -> PreferredTermMapping: 55 | return cls( 56 | cls._map_entity_to_prefered_terms_and_categories(graph_data, graph_type, api_tool, type_key, name_key) 57 | ) 58 | 59 | @staticmethod 60 | def _map_entity_to_prefered_terms_and_categories( 61 | graph_data: pd.DataFrame, 62 | graph_type: str, 63 | api_tool: Callable[[str], List[Dict[str, Any]]], 64 | type_key: str = "x_type", 65 | name_key: str = "x_name", 66 | ) -> Iterable[Tuple[str, List[Tuple[str, str, str]]]]: 67 | relevant_entries = graph_data[graph_data[type_key] == graph_type][name_key].unique() 68 | for entry in tqdm(relevant_entries): 69 | pref_term_results = PreferredTermMapping._extract_preferred_name_and_id_and_category(entry, api_tool) 70 | yield entry, pref_term_results 71 | 72 | @staticmethod 73 | def _extract_preferred_name_and_id_and_category( 74 | name: str, api_tool: Callable[[str], List[Dict[str, Any]]] 75 | ) -> List[Tuple[str, str, str]]: 76 | return [(r["pref_term"], r["id"], r["sem_type"]) for r in api_tool(name)] 77 | 78 | 79 | class PreferredTermDataExtractor: 80 | def __init__( 81 | self, 82 | graph: Union[str, pd.DataFrame], 83 | api_tool: Callable[[str], List[Dict[str, Any]]], 84 | type_key: str = "x_type", 85 | name_key: str = "x_name", 86 | ) -> None: 87 | self._api_tool = api_tool 88 | self._graph = pd.read_csv(graph, low_memory=False) if isinstance(graph, str) else graph 89 | self._type_key = type_key 90 | self._name_key = name_key 91 | 92 | def process(self, graph_type: str): 93 | print(f">> Parsing preferred terms for category {graph_type}...") 94 | data = PreferredTermMapping.from_dataframe( 95 | graph_data=self._graph, 96 | graph_type=graph_type, 97 | api_tool=self._api_tool, 98 | type_key=self._type_key, 99 | name_key=self._name_key, 100 | ) 101 | print("Data:", json.dumps(data)) 102 | self.store(data, graph_type) 103 | self.collect_categories(data) 104 | self.print_errors(data) 105 | 106 | @staticmethod 107 | def store(data: PreferredTermMapping, graph_type: str): 108 | fn = f"{graph_type}_data.json".replace("/", "_").replace(" ", "_") 109 | print(f">> Storing results as {fn}...") 110 | with open(fn, "w") as file: 111 | json.dump(data, file) 112 | 113 | @staticmethod 114 | def collect_categories(data: PreferredTermMapping) -> Set[str]: 115 | print(f">> Extracting unique categories...") 116 | categories = set() 117 | for v in data.values(): 118 | for entries in v: 119 | categories.add(entries[2]) 120 | print(f">> Categories found: {list(categories)}") 121 | return categories 122 | 123 | @staticmethod 124 | def print_errors(data: PreferredTermMapping) -> None: 125 | print(">> The following entries have duplicate category errors:") 126 | for k, v in data.items(): 127 | if len(v) > 1: 128 | unique_categories = set(entries[2] for entries in v) 129 | if len(v) != len(unique_categories): 130 | print(f">>>> Error for drug '{k}': Entries are {v}") 131 | 132 | 133 | if __name__ == "__main__": 134 | primekg_main(sys.argv) 135 | -------------------------------------------------------------------------------- /neo4j_primekg/graph_preprocessing_scripts/preferred_term_mapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, Iterable, List, Set, Tuple 4 | 5 | from tqdm import tqdm 6 | 7 | try: 8 | import polars as pd 9 | except: 10 | import pandas as pd 11 | 12 | from allowed_preferred_name_mapping_categories import allowed_categories_for_each_type 13 | from mapper import Mapper 14 | 15 | 16 | class PreferredTermMapper(Mapper): 17 | def __init__(self, directory: str, graph: pd.DataFrame) -> None: 18 | super().__init__(graph) 19 | self._mapping: Dict[Tuple[str, str], Tuple[str, str]] = dict(self._prepare_all_mappings(directory)) 20 | self._filter_substr = "BAY" 21 | 22 | def _call_polars(self, row: Tuple[str, ...]) -> Tuple[str, ...]: 23 | row_lst = list(row) 24 | node_type, val_in_graph = row_lst[self._x_type_idx], row_lst[self._x_name_idx] 25 | if pref_name_and_id := self._mapping.get((node_type, val_in_graph)): 26 | if self._filter_substr not in pref_name_and_id[1]: 27 | row_lst[self._x_name_idx] = pref_name_and_id[0] 28 | row_lst[self._x_id_idx] = pref_name_and_id[1] 29 | node_type, val_in_graph = row_lst[self._y_type_idx], row_lst[self._y_name_idx] 30 | if pref_name_and_id := self._mapping.get((node_type, val_in_graph)): 31 | if self._filter_substr not in pref_name_and_id[1]: 32 | row_lst[self._y_name_idx] = pref_name_and_id[0] 33 | row_lst[self._y_id_idx] = pref_name_and_id[1] 34 | return tuple(row_lst) 35 | 36 | def _call_pandas(self, row: "pd.Series") -> "pd.Series": 37 | node_type, val_in_graph = row[[self._x_type_idx, self._x_name_idx]] 38 | if pref_name_and_id := self._mapping.get((node_type, val_in_graph)): 39 | row[[self._x_name_idx, self._x_id_idx]] = pref_name_and_id 40 | node_type, val_in_graph = row[[self._y_type_idx, self._y_name_idx]] 41 | if pref_name_and_id := self._mapping.get((node_type, val_in_graph)): 42 | row[[self._y_name_idx, self._y_id_idx]] = pref_name_and_id 43 | return row 44 | 45 | def _prepare_all_mappings(self, directory: str) -> Iterable[Tuple[Tuple[str, str], Tuple[str, str]]]: 46 | desc = "Loading mappings for categories" 47 | for node_type, allowed_categories in tqdm(allowed_categories_for_each_type.items(), desc=desc): 48 | mappings = self._load_mappings(directory, node_type) 49 | yield from self._create_allowed_mapping(node_type, allowed_categories, mappings) 50 | 51 | def _load_mappings(self, directory: str, node_type: str) -> Dict[str, List[Tuple[str, str, str]]]: 52 | fn = os.path.join(directory, f"{node_type}_data.json".replace("/", "_").replace(" ", "_")) 53 | with open(fn, "r") as f: 54 | return json.load(f) 55 | 56 | def _create_allowed_mapping( 57 | self, 58 | node_type: str, 59 | ordered_allowed_categories: Set[str], 60 | mappings: Dict[str, List[Tuple[str, str, str]]], 61 | ) -> Iterable[Tuple[Tuple[str, str], Tuple[str, str]]]: 62 | for val_in_graph, pref_names in mappings.items(): 63 | pref_names = list(filter(lambda pn: pn[2] in ordered_allowed_categories, pref_names)) 64 | if len(pref_names) == 0: 65 | continue 66 | entry_per_category = {category: (pref_name, id) for pref_name, id, category in pref_names} 67 | if len(entry_per_category) != len(pref_names): 68 | continue # Two synonyms in the same category are considered an error. 69 | most_relevant_category = next( 70 | category for category in ordered_allowed_categories if category in entry_per_category 71 | ) 72 | pref_name, id = entry_per_category[most_relevant_category] 73 | yield (node_type, val_in_graph), (pref_name, id) 74 | -------------------------------------------------------------------------------- /neo4j_primekg/import_primekg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function download_data 4 | { 5 | if [ -f "$PRIMEKG_CSV" ]; then 6 | declare -g PRIMEKG_CSV_EXISTED=true 7 | else 8 | echo ">>> Downloading PrimeKG csv file..." 9 | declare -g PRIMEKG_CSV_EXISTED=false 10 | wget --no-clobber -O $PRIMEKG_CSV https://dataverse.harvard.edu/api/access/datafile/6180620 11 | fi 12 | 13 | if [ -f "$PRIMEKG_DRUG_NODE_FEATURES_TSV" ]; then 14 | declare -g PRIMEKG_DRUG_FEATURES_EXISTED=true 15 | else 16 | echo ">>> Downloading PrimeKG drug features file..." 17 | declare -g PRIMEKG_DRUG_FEATURES_EXISTED=false 18 | wget --no-clobber -O $PRIMEKG_DRUG_NODE_FEATURES_TSV https://dataverse.harvard.edu/api/access/datafile/6180619 19 | fi 20 | 21 | if [ -f "$PRIMEKG_DISEASE_NODE_FEATURES_TSV" ]; then 22 | declare -g PRIMEKG_DISEASE_FEATURES_EXISTED=true 23 | else 24 | echo ">>> Downloading PrimeKG disease features file..." 25 | declare -g PRIMEKG_DISEASE_FEATURES_EXISTED=false 26 | wget --no-clobber -O $PRIMEKG_DISEASE_NODE_FEATURES_TSV https://dataverse.harvard.edu/api/access/datafile/6180618 27 | fi 28 | } 29 | 30 | function cleanup_data 31 | { 32 | rm -r $PRIMEKG_CSVS_FOR_NEO4J 33 | 34 | if [ "${PRIMEKG_CSV_EXISTED}" == "false" ]; then 35 | echo ">>> Cleaning up PrimeKG csv file..." 36 | rm $PRIMEKG_CSV 37 | fi 38 | if [ "${PRIMEKG_DRUG_FEATURES_EXISTED}" == "false" ]; then 39 | echo ">>> Cleaning up PrimeKG csv file..." 40 | rm $PRIMEKG_DRUG_NODE_FEATURES_TSV 41 | fi 42 | if [ "${PRIMEKG_DISEASE_FEATURES_EXISTED}" == "false" ]; then 43 | echo ">>> Cleaning up PrimeKG csv file..." 44 | rm $PRIMEKG_DISEASE_NODE_FEATURES_TSV 45 | fi 46 | } 47 | 48 | function import_primekg 49 | { 50 | download_data 51 | 52 | echo ">>> Processing PrimeKG csv file..." 53 | IMPORT_CMD=$(python3 $IMPORT_DIR/primekg_to_neo4j_csv.py) 54 | 55 | if [ $? -ne 0 ]; then 56 | echo "Error in python script. Output:" 57 | echo $IMPORT_CMD 58 | exit 1 59 | fi 60 | 61 | echo ">>> Importing PrimeKG..." 62 | eval "$IMPORT_CMD" 63 | 64 | if [ $? -ne 0 ]; then 65 | echo "Error while importing!" 66 | exit 1 67 | fi 68 | 69 | cleanup_data 70 | } 71 | 72 | if [ -e $DELETE_PRIMEKG_CSV ]; then 73 | echo "WARNING: DELETE_PRIMEKG_CSV is deprecated. This now gets handled automatically." 74 | fi 75 | 76 | SETUP_DONE_MARKER="/data/prime_kg_is_imported_to_neo4j" 77 | if [ ! -e $SETUP_DONE_MARKER ]; then 78 | import_primekg 79 | touch $SETUP_DONE_MARKER 80 | fi 81 | 82 | echo ">>> Starting Neo4j..." 83 | bash /startup/docker-entrypoint.sh neo4j 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fact_finder" 3 | version = "0.1.0" 4 | requires-python = ">=3.8,<3.12" 5 | description = "FactFinder" 6 | dependencies = [ 7 | "chainlit", 8 | "langchain", 9 | "langchain-openai", 10 | "pandas", 11 | "pyvis", 12 | "streamlit", 13 | "nltk", 14 | "SPARQLWrapper", 15 | "neo4j", 16 | "regex" 17 | ] 18 | 19 | [project.optional-dependencies] 20 | linting = ["pre-commit"] 21 | tests = ["pytest"] 22 | evaluation = ["sentence-transformers"] 23 | 24 | [project.scripts] 25 | fact-finder = "fact_finder.__main__:main" 26 | 27 | [build-system] 28 | requires = ["setuptools >= 61.0.0"] 29 | build-backend = "setuptools.build_meta" 30 | 31 | [tool.black] 32 | target-version = ["py310"] 33 | line-length = 120 34 | 35 | [tool.isort] 36 | profile = "black" 37 | line_length = 120 -------------------------------------------------------------------------------- /src/fact_finder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/chains/__init__.py: -------------------------------------------------------------------------------- 1 | from .answer_generation_chain import AnswerGenerationChain 2 | from .cypher_query_generation_chain import CypherQueryGenerationChain 3 | from .cypher_query_preprocessors_chain import CypherQueryPreprocessorsChain 4 | from .graph_chain import GraphChain 5 | from .graph_qa_chain import GraphQAChain, GraphQAChainOutput 6 | from .subgraph_extractor_chain import SubgraphExtractorChain 7 | -------------------------------------------------------------------------------- /src/fact_finder/chains/answer_generation_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from langchain.chains import LLMChain 4 | from langchain.chains.base import Chain 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | from langchain_core.language_models import BaseLanguageModel 7 | from langchain_core.prompts import BasePromptTemplate 8 | 9 | from fact_finder.utils import fill_prompt_template 10 | 11 | 12 | class AnswerGenerationChain(Chain): 13 | llm_chain: LLMChain 14 | return_intermediate_steps: bool 15 | question_key: str = "question" #: :meta private: 16 | graph_result_key: str = "graph_result" #: :meta private: 17 | output_key: str = "answer" #: :meta private: 18 | intermediate_steps_key: str = "intermediate_steps" 19 | 20 | def __init__( 21 | self, llm: BaseLanguageModel, prompt_template: BasePromptTemplate, return_intermediate_steps: bool = True 22 | ): 23 | llm_chain = LLMChain(llm=llm, prompt=prompt_template) 24 | super().__init__(llm_chain=llm_chain, return_intermediate_steps=return_intermediate_steps) 25 | 26 | @property 27 | def input_keys(self) -> List[str]: 28 | """Return the input keys.""" 29 | return [self.question_key, self.graph_result_key] 30 | 31 | @property 32 | def output_keys(self) -> List[str]: 33 | """Return the output keys.""" 34 | return [self.output_key] 35 | 36 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 37 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 38 | answer = self._run_qa_chain(inputs, _run_manager) 39 | return self._prepare_chain_result(inputs, answer) 40 | 41 | def _run_qa_chain(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> str: 42 | inputs = self._prepare_chain_input(inputs) 43 | final_result = self.llm_chain( 44 | inputs=inputs, 45 | callbacks=run_manager.get_child(), 46 | )[self.llm_chain.output_key] 47 | self._log_it(run_manager, final_result) 48 | return final_result 49 | 50 | def _log_it(self, run_manager, graph_result): 51 | run_manager.on_text("QA Chain Result:", end="\n", verbose=self.verbose) 52 | run_manager.on_text(str(graph_result), color="green", end="\n", verbose=self.verbose) 53 | 54 | def _prepare_chain_result(self, inputs: Dict[str, Any], answer: str) -> Dict[str, Any]: 55 | chain_result = { 56 | self.output_key: answer, 57 | } 58 | if self.return_intermediate_steps: 59 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) + [{self.output_key: answer}] 60 | filled_prompt = fill_prompt_template( 61 | inputs=self._prepare_chain_input(inputs), 62 | llm_chain=self.llm_chain, 63 | ) 64 | intermediate_steps.append({f"{self.__class__.__name__}_filled_prompt": filled_prompt}) 65 | chain_result[self.intermediate_steps_key] = intermediate_steps 66 | return chain_result 67 | 68 | def _prepare_chain_input(self, inputs: Dict[str, Any]): 69 | return {"question": inputs[self.question_key], "context": inputs[self.graph_result_key]} 70 | -------------------------------------------------------------------------------- /src/fact_finder/chains/combined_graph_rag_qa_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from fact_finder.chains.rag.text_search_qa_chain import TextSearchQAChain 4 | from langchain_core.callbacks import CallbackManagerForChainRun 5 | from langchain_core.language_models import BaseLanguageModel 6 | from langchain_core.prompts import BasePromptTemplate 7 | 8 | from fact_finder.utils import fill_prompt_template 9 | 10 | 11 | class CombinedQAChain(TextSearchQAChain): 12 | cypher_query_key: str = "preprocessed_cypher_query" #: :meta private: 13 | graph_result_key: str = "graph_result" #: :meta private: 14 | output_key: str = "answer" #: :meta private: 15 | 16 | @property 17 | def input_keys(self) -> List[str]: 18 | return [self.question_key, self.graph_result_key, self.cypher_query_key] 19 | 20 | @property 21 | def output_keys(self) -> List[str]: 22 | return [self.output_key] 23 | 24 | def __init__( 25 | self, 26 | llm: BaseLanguageModel, 27 | combined_answer_generation_template: BasePromptTemplate, 28 | rag_output_key: str, 29 | return_intermediate_steps: bool = True, 30 | ): 31 | super().__init__( 32 | llm=llm, 33 | rag_answer_generation_template=combined_answer_generation_template, 34 | rag_output_key=rag_output_key, 35 | return_intermediate_steps=return_intermediate_steps, 36 | ) 37 | 38 | def _generate_answer(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> str: 39 | inputs = self._prepare_chain_input(inputs) 40 | result = self.rag_answer_generation_llm_chain( 41 | inputs=inputs, 42 | callbacks=run_manager.get_child(), 43 | )[self.rag_answer_generation_llm_chain.output_key] 44 | return result 45 | 46 | def _prepare_chain_input(self, inputs: Dict[str, Any]): 47 | return { 48 | "abstracts": inputs["semantic_scholar_result"], 49 | "cypher_query": inputs["cypher_query"], 50 | "graph_answer": inputs["graph_result"], 51 | "question": inputs["question"], 52 | } 53 | -------------------------------------------------------------------------------- /src/fact_finder/chains/cypher_query_generation_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from fact_finder.utils import fill_prompt_template 4 | from langchain.chains import LLMChain 5 | from langchain.chains.base import Chain 6 | from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher 7 | from langchain_core.callbacks import CallbackManagerForChainRun 8 | from langchain_core.language_models import BaseLanguageModel 9 | from langchain_core.prompts import BasePromptTemplate 10 | 11 | 12 | class CypherQueryGenerationChain(Chain): 13 | cypher_generation_chain: LLMChain 14 | graph_schema: str 15 | predicate_descriptions_text: str 16 | return_intermediate_steps: bool 17 | input_key: str = "question" #: :meta private: 18 | output_key: str = "cypher_query" #: :meta private: 19 | intermediate_steps_key: str = "intermediate_steps" 20 | 21 | def __init__( 22 | self, 23 | llm: BaseLanguageModel, 24 | prompt_template: BasePromptTemplate, 25 | graph_structured_schema: Dict[str, Any], 26 | predicate_descriptions: List[Dict[str, str]] = [], 27 | return_intermediate_steps: bool = True, 28 | exclude_types: List[str] = [], 29 | include_types: List[str] = [], 30 | ): 31 | cypher_generation_chain = LLMChain(llm=llm, prompt=prompt_template) 32 | if exclude_types and include_types: 33 | raise ValueError("Either `exclude_types` or `include_types` " "can be provided, but not both") 34 | graph_schema = construct_schema(graph_structured_schema, include_types, exclude_types) 35 | predicate_descriptions_text = self._construct_predicate_descriptions_text(predicate_descriptions) 36 | super().__init__( 37 | cypher_generation_chain=cypher_generation_chain, 38 | graph_schema=graph_schema, 39 | predicate_descriptions_text=predicate_descriptions_text, 40 | return_intermediate_steps=return_intermediate_steps, 41 | ) 42 | 43 | @property 44 | def input_keys(self) -> List[str]: 45 | """Return the input keys.""" 46 | return [self.input_key] 47 | 48 | @property 49 | def output_keys(self) -> List[str]: 50 | """Return the output keys.""" 51 | return [self.output_key] 52 | 53 | def _call( 54 | self, 55 | inputs: Dict[str, Any], 56 | run_manager: Optional[CallbackManagerForChainRun] = None, 57 | ) -> Dict[str, Any]: 58 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 59 | generated_cypher = self._generate_cypher(inputs, _run_manager) 60 | return self._prepare_chain_result(inputs, generated_cypher) 61 | 62 | def _construct_predicate_descriptions_text(self, predicate_descriptions: List[Dict[str, str]]) -> str: 63 | if len(predicate_descriptions) == 0: 64 | return "" 65 | result = ["Here are some descriptions to the most common relationships:"] 66 | for item in predicate_descriptions: 67 | item_as_text = f"({item['subject']})-[{item['predicate']}]->({item['object']}): {item['definition']}" 68 | result.append(item_as_text) 69 | return "\n".join(result) 70 | 71 | def _generate_cypher(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> str: 72 | prepared_inputs = self._prepare_chain_input(inputs) 73 | generated_cypher = self.cypher_generation_chain( 74 | inputs=prepared_inputs, 75 | callbacks=run_manager.get_child(), 76 | )[self.cypher_generation_chain.output_key] 77 | generated_cypher = extract_cypher(generated_cypher) 78 | if generated_cypher.startswith("cypher"): 79 | generated_cypher = generated_cypher[6:].strip() 80 | self._log_it(generated_cypher, run_manager) 81 | return generated_cypher 82 | 83 | def _log_it(self, generated_cypher: str, run_manager: CallbackManagerForChainRun): 84 | run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) 85 | run_manager.on_text(generated_cypher, color="green", end="\n", verbose=self.verbose) 86 | 87 | def _prepare_chain_result(self, inputs: Dict[str, Any], generated_cypher: str) -> Dict[str, Any]: 88 | chain_result = {self.output_key: generated_cypher} 89 | if self.return_intermediate_steps: 90 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) 91 | filled_prompt = fill_prompt_template( 92 | llm_chain=self.cypher_generation_chain, 93 | inputs=self._prepare_chain_input(inputs), 94 | ) 95 | intermediate_steps += [ 96 | {"question": inputs[self.input_key]}, 97 | {self.output_key: generated_cypher}, 98 | {f"{self.__class__.__name__}_filled_prompt": filled_prompt}, 99 | ] 100 | chain_result[self.intermediate_steps_key] = intermediate_steps 101 | return chain_result 102 | 103 | def _prepare_chain_input(self, inputs: Dict[str, Any]): 104 | return { 105 | "question": inputs[self.input_key], 106 | "schema": self.graph_schema, 107 | "predicate_descriptions": self.predicate_descriptions_text, 108 | } 109 | -------------------------------------------------------------------------------- /src/fact_finder/chains/cypher_query_preprocessors_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 4 | CypherQueryPreprocessor, 5 | ) 6 | from langchain.chains.base import Chain 7 | from langchain_core.callbacks import CallbackManagerForChainRun 8 | 9 | 10 | class CypherQueryPreprocessorsChain(Chain): 11 | cypher_query_preprocessors: List[CypherQueryPreprocessor] 12 | return_intermediate_steps: bool = True 13 | input_key: str = "cypher_query" #: :meta private: 14 | output_key: str = "preprocessed_cypher_query" #: :meta private: 15 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 16 | 17 | @property 18 | def input_keys(self) -> List[str]: 19 | """Return the input keys.""" 20 | return [self.input_key] 21 | 22 | @property 23 | def output_keys(self) -> List[str]: 24 | """Return the output keys.""" 25 | return [self.output_key] 26 | 27 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 28 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 29 | generated_cypher = inputs[self.input_key] 30 | preprocessed_cypher, intermediate_steps = self._run_preprocessors(_run_manager, generated_cypher) 31 | return self._prepare_chain_result(inputs, preprocessed_cypher, intermediate_steps) 32 | 33 | def _run_preprocessors( 34 | self, _run_manager: CallbackManagerForChainRun, generated_cypher: str 35 | ) -> Tuple[str, List[Dict[str, str]]]: 36 | intermediate_steps = [] 37 | for processor in self.cypher_query_preprocessors: 38 | generated_cypher = processor(generated_cypher) 39 | intermediate_steps.append({type(processor).__name__: generated_cypher}) 40 | self._log_it(_run_manager, generated_cypher) 41 | return generated_cypher, intermediate_steps 42 | 43 | def _log_it(self, _run_manager: CallbackManagerForChainRun, generated_cypher: str): 44 | _run_manager.on_text("Preprocessed Cypher:", end="\n", verbose=self.verbose) 45 | _run_manager.on_text(generated_cypher, color="green", end="\n", verbose=self.verbose) 46 | 47 | def _prepare_chain_result( 48 | self, inputs: Dict[str, Any], preprocessed_cypher: str, intermediate_steps: List[Dict[str, str]] 49 | ) -> Dict[str, Any]: 50 | chain_result: Dict[str, Any] = { 51 | self.output_key: preprocessed_cypher, 52 | } 53 | if self.return_intermediate_steps: 54 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) + intermediate_steps 55 | chain_result[self.intermediate_steps_key] = intermediate_steps 56 | return chain_result 57 | -------------------------------------------------------------------------------- /src/fact_finder/chains/filtered_primekg_question_preprocessing_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Collection, Dict, List, Set, Tuple 2 | 3 | from fact_finder.chains.entity_detection_question_preprocessing_chain import ( 4 | EntityDetectionQuestionPreprocessingChain, 5 | ) 6 | from fact_finder.tools.entity_detector import EntityDetector 7 | from langchain_community.graphs import Neo4jGraph 8 | 9 | 10 | class FilteredPrimeKGQuestionPreprocessingChain(EntityDetectionQuestionPreprocessingChain): 11 | graph: Neo4jGraph 12 | excluded_entities: Set[str] 13 | _side_effect_exists_cypher_query: str = ( 14 | 'MATCH(node:effect_or_phenotype {name: "{entity_name}"}) RETURN COUNT(node) > 0 AS exists' 15 | ) 16 | _general_exists_cypher_query: str = 'MATCH(node {name: "{entity_name}"}) RETURN COUNT(node) > 0 AS exists' 17 | 18 | def __init__( 19 | self, 20 | *, 21 | entity_detector: EntityDetector, 22 | allowed_types_and_description_templates: Dict[str, str], 23 | graph: Neo4jGraph, 24 | excluded_entities: Collection[str] = [], 25 | return_intermediate_steps: bool = True, 26 | ): 27 | allowed_types_and_description_templates["side_effect"] = "{entity} is a disease or a effect_or_phenotype." 28 | super().__init__( 29 | entity_detector=entity_detector, 30 | allowed_types_and_description_templates=allowed_types_and_description_templates, 31 | return_intermediate_steps=return_intermediate_steps, 32 | excluded_entities=set(e.strip().lower() for e in excluded_entities), 33 | graph=graph, 34 | ) 35 | 36 | def _extract_entity_data( 37 | self, question: str, entity_results: List[Tuple[int, int, str, str]] 38 | ) -> Tuple[str, List[str]]: 39 | entity_type_hints = [] 40 | new_question = "" 41 | last_index = 0 42 | for start, end, pref_name, type in entity_results: 43 | original_name = question[start:end] 44 | if original_name.strip().lower() in self.excluded_entities: 45 | continue 46 | pref_name, type = self._filter_pref_name_and_type(original_name, pref_name, type) 47 | new_question += question[last_index:start] + pref_name 48 | last_index = end 49 | entity_type_hints.append(self._create_type_hint(pref_name, type)) 50 | new_question += question[last_index:] 51 | return new_question, entity_type_hints 52 | 53 | def _filter_pref_name_and_type(self, original_name: str, pref_name: str, type: str) -> Tuple[str, str]: 54 | type = type.lower() 55 | if type == "disease": 56 | pref_name, type = self._filter_diseases_that_are_also_side_effects(original_name, pref_name, type) 57 | elif self._exists_as_side_effect_node(original_name): 58 | pref_name = original_name 59 | return pref_name, type 60 | 61 | def _filter_diseases_that_are_also_side_effects( 62 | self, original_name: str, pref_name: str, type: str 63 | ) -> Tuple[str, str]: 64 | if self._exists_as_side_effect_node(original_name): 65 | type = "side_effect" 66 | pref_name = original_name 67 | elif self._exists_as_side_effect_node(pref_name): 68 | type = "side_effect" 69 | return pref_name, type 70 | 71 | def _exists_as_side_effect_node(self, name: str) -> bool: 72 | cypher_query = self._side_effect_exists_cypher_query.replace("{entity_name}", name) 73 | nodes = self.graph.query(cypher_query) 74 | return nodes[0]["exists"] 75 | 76 | def _exists_in_graph(self, name: str) -> bool: 77 | cypher_query = self._general_exists_cypher_query.replace("{entity_name}", name) 78 | nodes = self.graph.query(cypher_query) 79 | return nodes[0]["exists"] 80 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from langchain.chains.base import Chain 4 | from langchain_community.graphs import Neo4jGraph 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | 7 | 8 | class GraphChain(Chain): 9 | graph: Neo4jGraph 10 | top_k: int = 50 11 | return_intermediate_steps: bool = True 12 | input_key: str = "preprocessed_cypher_query" #: :meta private: 13 | output_key: str = "graph_result" #: :meta private: 14 | intermediate_steps_key: str = "intermediate_steps" 15 | 16 | @property 17 | def input_keys(self) -> List[str]: 18 | """Return the input keys.""" 19 | return [self.input_key] 20 | 21 | @property 22 | def output_keys(self) -> List[str]: 23 | """Return the output keys.""" 24 | return [self.output_key] 25 | 26 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 27 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 28 | generated_cypher = inputs[self.input_key] 29 | graph_result = self._query_graph(generated_cypher) 30 | self._log_it(_run_manager, graph_result) 31 | return self._prepare_chain_result(inputs, graph_result) 32 | 33 | def _query_graph(self, generated_cypher: str) -> List[Dict[str, Any]]: 34 | if generated_cypher: 35 | return self.graph.query(generated_cypher)[: self.top_k] 36 | return [] 37 | 38 | def _log_it(self, run_manager, graph_result): 39 | run_manager.on_text("Graph Result:", end="\n", verbose=self.verbose) 40 | run_manager.on_text(str(graph_result), color="green", end="\n", verbose=self.verbose) 41 | 42 | def _prepare_chain_result(self, inputs: Dict[str, Any], graph_result: List[Dict[str, Any]]) -> Dict[str, Any]: 43 | chain_result = { 44 | self.output_key: graph_result, 45 | } 46 | if self.return_intermediate_steps: 47 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) + [{self.output_key: graph_result}] 48 | chain_result[self.intermediate_steps_key] = intermediate_steps 49 | return chain_result 50 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_qa_chain/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import GraphQAChainConfig 2 | from .early_stopping_chain import GraphQAChainEarlyStopping 3 | from .graph_qa_chain import GraphQAChain 4 | from .output import GraphQAChainOutput 5 | from .output_chain import GraphQAChainOutputPreparation 6 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_qa_chain/config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from fact_finder.chains.entity_detection_question_preprocessing_chain import ( 4 | EntityDetectionQuestionPreprocessingProtocol, 5 | ) 6 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 7 | CypherQueryPreprocessor, 8 | ) 9 | from fact_finder.tools.entity_detector import EntityDetector 10 | from langchain_community.graphs import Neo4jGraph 11 | from langchain_core.language_models import BaseLanguageModel 12 | from langchain_core.prompts import BasePromptTemplate 13 | from pydantic.v1 import BaseModel, root_validator 14 | 15 | 16 | class GraphQAChainConfig(BaseModel): 17 | 18 | class Config: 19 | arbitrary_types_allowed: bool = True 20 | 21 | llm: BaseLanguageModel 22 | graph: Neo4jGraph 23 | cypher_prompt: BasePromptTemplate 24 | answer_generation_prompt: BasePromptTemplate 25 | cypher_query_preprocessors: List[CypherQueryPreprocessor] 26 | predicate_descriptions: List[Dict[str, str]] = [] 27 | schema_error_string: str = "SCHEMA_ERROR" 28 | return_intermediate_steps: bool = True 29 | 30 | use_entity_detection_preprocessing: bool = False 31 | entity_detection_preprocessor_type: Optional[ 32 | EntityDetectionQuestionPreprocessingProtocol 33 | ] # = EntityDetectionQuestionPreprocessingChain 34 | entity_detector: Optional[EntityDetector] = None 35 | # The keys of this dict contain the (lower case) type names for which entities can be replaced. 36 | # They map to a template explaining the type of an entity (marked via {entity}) 37 | # Example: "chemical_compounds", "{entity} is a chemical compound." 38 | allowed_types_and_description_templates: Dict[str, str] = {} 39 | 40 | skip_subgraph_generation: bool = False 41 | use_subgraph_expansion: bool = False 42 | 43 | combine_output_with_sematic_scholar: bool = False 44 | semantic_scholar_keyword_prompt: Optional[BasePromptTemplate] = None 45 | combined_answer_generation_prompt: Optional[BasePromptTemplate] = None 46 | top_k: int = 2000 # todo change back when not evaluating 47 | 48 | @root_validator(allow_reuse=True) 49 | def check_entity_detection_preprocessing_settings(cls, values: Dict[str, Any]) -> Dict[str, Any]: 50 | if values["use_entity_detection_preprocessing"] and values.get("entity_detector") is None: 51 | raise ValueError("When setting use_entity_detection_preprocessing, an entity_detector has to be provided.") 52 | if values["use_entity_detection_preprocessing"] and values.get("entity_detection_preprocessor_type") is None: 53 | raise ValueError( 54 | "When setting use_entity_detection_preprocessing, " 55 | "an entity_detection_preprocessor_type has to be provided." 56 | ) 57 | if ( 58 | values["use_entity_detection_preprocessing"] 59 | and len(values.get("allowed_types_and_description_templates", {})) == 0 60 | ): 61 | raise ValueError( 62 | "When setting use_entity_detection_preprocessing, " 63 | "allowed_types_and_description_templates has to be provided." 64 | ) 65 | return values 66 | 67 | @root_validator(allow_reuse=True) 68 | def check_semantic_scholar_configured_correctly(cls, values: Dict[str, Any]) -> Dict[str, Any]: 69 | if values["combine_output_with_sematic_scholar"] and values.get("semantic_scholar_keyword_prompt") is None: 70 | raise ValueError( 71 | "When setting combine_output_with_sematic_scholar, " 72 | "a semantic_scholar_keyword_prompt has to be provided." 73 | ) 74 | if values["combine_output_with_sematic_scholar"] and values.get("combined_answer_generation_prompt") is None: 75 | raise ValueError( 76 | "When setting combine_output_with_sematic_scholar, " 77 | "combined_answer_generation_prompt has to be provided." 78 | ) 79 | return values 80 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_qa_chain/early_stopping_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from fact_finder.chains.graph_qa_chain.output import GraphQAChainOutput 4 | from langchain.chains.base import Chain 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | 7 | 8 | class GraphQAChainEarlyStopping(Chain): 9 | schema_error_string: str 10 | return_intermediate_steps: bool = True 11 | question_key: str = "question" #: :meta private: 12 | query_key: str = "cypher_query" #: :meta private: 13 | output_key: str = "graph_qa_output" #: :meta private: 14 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 15 | 16 | @property 17 | def input_keys(self) -> List[str]: 18 | """Return the input keys.""" 19 | return [self.question_key, self.query_key] 20 | 21 | @property 22 | def output_keys(self) -> List[str]: 23 | """Return the output keys.""" 24 | return [self.output_key] 25 | 26 | def _call( 27 | self, 28 | inputs: Dict[str, Any], 29 | run_manager: Optional[CallbackManagerForChainRun] = None, 30 | ) -> Dict[str, Any]: 31 | answer = inputs[self.query_key][len(self.schema_error_string) :].lstrip(": ") 32 | result: Dict[str, Any] = { 33 | self.output_key: GraphQAChainOutput( 34 | question=inputs[self.question_key], 35 | cypher_query="", 36 | graph_response=[], 37 | answer=answer, 38 | evidence_sub_graph=[], 39 | expanded_evidence_subgraph=[], 40 | ) 41 | } 42 | if self.return_intermediate_steps and self.intermediate_steps_key in inputs: 43 | result[self.intermediate_steps_key] = inputs[self.intermediate_steps_key] 44 | return result 45 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_qa_chain/output.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from pydantic.v1 import BaseModel 4 | 5 | 6 | class GraphQAChainOutput(BaseModel): 7 | question: str 8 | cypher_query: str 9 | graph_response: List[Dict[str, Any]] 10 | answer: str 11 | evidence_sub_graph: List[Dict[str, Any]] 12 | expanded_evidence_subgraph: List[Dict[str, Any]] 13 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_qa_chain/output_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, List, Optional 2 | 3 | from fact_finder.chains.graph_qa_chain.output import GraphQAChainOutput 4 | from langchain.chains.base import Chain 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | 7 | 8 | class GraphQAChainOutputPreparation(Chain): 9 | return_intermediate_steps: bool = True 10 | answer_key: str = "answer" #: :meta private: 11 | question_key: str = "question" #: :meta private: 12 | query_key: str = "preprocessed_cypher_query" #: :meta private: 13 | graph_key: str = "graph_result" #: :meta private: 14 | evidence_key: str = "extracted_nodes" #: :meta private: 15 | expanded_evidence_key: str = "expanded_nodes" #: :meta private: 16 | output_key: str = "graph_qa_output" #: :meta private: 17 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 18 | 19 | @property 20 | def input_keys(self) -> List[str]: 21 | """Return the input keys.""" 22 | return [self.answer_key, self.evidence_key] 23 | 24 | @property 25 | def output_keys(self) -> List[str]: 26 | """Return the output keys.""" 27 | return [self.output_key] 28 | 29 | def _call( 30 | self, 31 | inputs: Dict[str, Any], 32 | run_manager: Optional[CallbackManagerForChainRun] = None, 33 | ) -> Dict[str, Any]: 34 | result: Dict[str, Any] = { 35 | self.output_key: GraphQAChainOutput( 36 | question=inputs[self.answer_key][self.question_key], 37 | cypher_query=inputs[self.answer_key][self.query_key], 38 | graph_response=inputs[self.answer_key][self.graph_key], 39 | answer=inputs[self.answer_key][self.answer_key], 40 | evidence_sub_graph=inputs[self.evidence_key][self.evidence_key][self.evidence_key], 41 | expanded_evidence_subgraph=inputs[self.evidence_key][self.evidence_key][self.expanded_evidence_key], 42 | ) 43 | } 44 | if self.return_intermediate_steps: 45 | intermediate_steps = list(self._merge_intermediate_steps(inputs)) 46 | result[self.intermediate_steps_key] = intermediate_steps 47 | return result 48 | 49 | def _merge_intermediate_steps(self, inputs: Dict[str, Any]) -> Iterable[Dict[str, Any]]: 50 | i1 = iter(inputs[self.answer_key][self.intermediate_steps_key]) 51 | i2 = iter(inputs[self.evidence_key][self.intermediate_steps_key]) 52 | i2_store = None 53 | try: 54 | while True: 55 | i1_entry = next(i1) 56 | yield i1_entry 57 | i2_entry = next(i2) 58 | if i1_entry != i2_entry: 59 | i2_store = i2_entry 60 | break 61 | except StopIteration: 62 | pass 63 | yield from i1 64 | if i2_store: 65 | yield i2_store 66 | yield from i2 67 | -------------------------------------------------------------------------------- /src/fact_finder/chains/graph_summary_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from langchain.chains import LLMChain 4 | from langchain.chains.base import Chain 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | from langchain_core.language_models import BaseLanguageModel 7 | from langchain_core.prompts import PromptTemplate 8 | 9 | from fact_finder.prompt_templates import SUBGRAPH_SUMMARY_PROMPT 10 | from fact_finder.utils import fill_prompt_template 11 | 12 | 13 | class GraphSummaryChain(Chain): 14 | graph_summary_llm_chain: LLMChain 15 | return_intermediate_steps: bool = True 16 | input_key: str = "sub_graph" #: :meta private: 17 | output_key: str = "summary" #: :meta private: 18 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 19 | 20 | @property 21 | def input_keys(self) -> List[str]: 22 | return [self.input_key] 23 | 24 | @property 25 | def output_keys(self) -> List[str]: 26 | return [self.output_key] 27 | 28 | def __init__( 29 | self, 30 | llm: BaseLanguageModel, 31 | graph_summary_template: PromptTemplate, 32 | return_intermediate_steps: bool = True, 33 | ): 34 | graph_summary_llm_chain = LLMChain(llm=llm, prompt=graph_summary_template) 35 | super().__init__( 36 | llm=llm, 37 | graph_summary_llm_chain=graph_summary_llm_chain, 38 | return_intermediate_steps=return_intermediate_steps, 39 | ) 40 | 41 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 42 | answer = self.graph_summary_llm_chain( 43 | inputs=self._prepare_chain_input(inputs), 44 | callbacks=run_manager.get_child(), 45 | )[self.graph_summary_llm_chain.output_key] 46 | return self._prepare_chain_result(inputs, answer) 47 | 48 | def _prepare_chain_result(self, inputs, answer): 49 | chain_result: Dict[str, Any] = {self.output_key: answer} 50 | filled_prompt = fill_prompt_template( 51 | inputs=self._prepare_chain_input(inputs), llm_chain=self.graph_summary_llm_chain 52 | ) 53 | if self.return_intermediate_steps: 54 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) 55 | intermediate_steps += [ 56 | {"question": inputs[self.input_key]}, 57 | {self.output_key: answer}, 58 | {f"{self.__class__.__name__}_filled_prompt": filled_prompt}, 59 | ] 60 | chain_result[self.intermediate_steps_key] = intermediate_steps 61 | return chain_result 62 | 63 | def _prepare_chain_input(self, inputs: Dict[str, Any]): 64 | return {"sub_graph": inputs[self.input_key]} 65 | -------------------------------------------------------------------------------- /src/fact_finder/chains/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/chains/rag/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/chains/rag/semantic_scholar_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from fact_finder.tools.semantic_scholar_search_api_wrapper import ( 4 | SemanticScholarSearchApiWrapper, 5 | ) 6 | from langchain.chains import LLMChain 7 | from langchain.chains.base import Chain 8 | from langchain_core.callbacks import CallbackManagerForChainRun 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_core.prompts import BasePromptTemplate 11 | 12 | 13 | class SemanticScholarChain(Chain): 14 | semantic_scholar_search: SemanticScholarSearchApiWrapper 15 | keyword_generation_llm_chain: LLMChain 16 | return_intermediate_steps: bool = True 17 | question_key: str = "question" #: :meta private: 18 | output_key: str = "semantic_scholar_result" #: :meta private: 19 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 20 | 21 | @property 22 | def input_keys(self) -> List[str]: 23 | return [self.question_key] 24 | 25 | @property 26 | def output_keys(self) -> List[str]: 27 | return [self.output_key] 28 | 29 | def __init__( 30 | self, 31 | semantic_scholar_search: SemanticScholarSearchApiWrapper, 32 | llm: BaseLanguageModel, 33 | keyword_prompt_template: BasePromptTemplate, 34 | return_intermediate_steps: bool = True, 35 | ): 36 | keyword_generation_llm_chain = LLMChain(llm=llm, prompt=keyword_prompt_template) 37 | super().__init__( 38 | semantic_scholar_search=semantic_scholar_search, 39 | keyword_generation_llm_chain=keyword_generation_llm_chain, 40 | return_intermediate_steps=return_intermediate_steps, 41 | ) 42 | 43 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 44 | run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 45 | keywords = self._generate_search_keywords_for_question(inputs, run_manager) 46 | semantic_scholar_result = self._search_semantic_scholar(keywords) 47 | return self._build_result(inputs, keywords, semantic_scholar_result) 48 | 49 | def _generate_search_keywords_for_question( 50 | self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun 51 | ) -> str: 52 | return self.keyword_generation_llm_chain( 53 | {"question": inputs[self.question_key]}, 54 | callbacks=run_manager.get_child(), 55 | )[self.keyword_generation_llm_chain.output_key] 56 | 57 | def _search_semantic_scholar(self, keywords: str) -> str: 58 | search_result = self.semantic_scholar_search.search_by_abstracts(keywords=keywords) 59 | return "\n\n".join(search_result) 60 | 61 | def _build_result(self, inputs: Dict[str, Any], keywords: str, semantic_scholar_result: str) -> Dict[str, Any]: 62 | result: Dict[str, Any] = {self.output_key: semantic_scholar_result} 63 | if self.return_intermediate_steps: 64 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) 65 | intermediate_steps.append(("search_keywords", keywords)) 66 | intermediate_steps.append(("semantic_scholar_result", semantic_scholar_result)) 67 | result[self.intermediate_steps_key] = intermediate_steps 68 | return result 69 | -------------------------------------------------------------------------------- /src/fact_finder/chains/rag/text_search_qa_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from langchain.chains import LLMChain 4 | from langchain.chains.base import Chain 5 | from langchain_core.callbacks import CallbackManagerForChainRun 6 | from langchain_core.language_models import BaseLanguageModel 7 | from langchain_core.prompts import BasePromptTemplate 8 | 9 | from fact_finder.utils import fill_prompt_template 10 | 11 | 12 | class TextSearchQAChain(Chain): 13 | rag_answer_generation_llm_chain: LLMChain 14 | return_intermediate_steps: bool = True 15 | rag_output_key: str 16 | question_key: str = "question" #: :meta private: 17 | output_key: str = "rag_output" #: :meta private: 18 | intermediate_steps_key: str = "intermediate_steps" #: :meta private: 19 | 20 | @property 21 | def input_keys(self) -> List[str]: 22 | return [self.question_key] 23 | 24 | @property 25 | def output_keys(self) -> List[str]: 26 | return [self.output_key] 27 | 28 | def __init__( 29 | self, 30 | llm: BaseLanguageModel, 31 | rag_answer_generation_template: BasePromptTemplate, 32 | rag_output_key: str, 33 | return_intermediate_steps: bool = True, 34 | ): 35 | rag_answer_generation_llm_chain = LLMChain(llm=llm, prompt=rag_answer_generation_template) 36 | super().__init__( 37 | rag_answer_generation_llm_chain=rag_answer_generation_llm_chain, 38 | rag_output_key=rag_output_key, 39 | return_intermediate_steps=return_intermediate_steps, 40 | ) 41 | 42 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 43 | run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 44 | answer = self._generate_answer(inputs, run_manager) 45 | return self._prepare_chain_result(inputs, answer) 46 | 47 | def _generate_answer(self, inputs: Dict[str, Any], run_manager: CallbackManagerForChainRun) -> str: 48 | result = self.rag_answer_generation_llm_chain( 49 | inputs=self._prepare_chain_input(inputs), 50 | callbacks=run_manager.get_child(), 51 | )[self.rag_answer_generation_llm_chain.output_key] 52 | return result 53 | 54 | def _prepare_chain_result(self, inputs: Dict[str, Any], answer: str) -> Dict[str, Any]: 55 | result: Dict[str, Any] = {self.output_key: answer} 56 | if self.return_intermediate_steps: 57 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) 58 | intermediate_steps.append(("rag_answer", answer)) 59 | filled_prompt = fill_prompt_template( 60 | inputs=self._prepare_chain_input(inputs), 61 | llm_chain=self.rag_answer_generation_llm_chain, 62 | ) 63 | intermediate_steps.append({f"{self.__class__.__name__}_filled_prompt": filled_prompt}) 64 | result[self.intermediate_steps_key] = intermediate_steps 65 | return result 66 | 67 | def _prepare_chain_input(self, inputs: Dict[str, Any]): 68 | return {"context": inputs["semantic_scholar_result"], "question": inputs["question"]} 69 | -------------------------------------------------------------------------------- /src/fact_finder/chains/subgraph_extractor_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | from fact_finder.tools.sub_graph_extractor import LLMSubGraphExtractor 4 | from fact_finder.tools.subgraph_extension import SubgraphExpansion 5 | from langchain.chains.base import Chain 6 | from langchain_community.graphs import Neo4jGraph 7 | from langchain_core.callbacks import CallbackManagerForChainRun 8 | from langchain_core.language_models import BaseLanguageModel 9 | 10 | 11 | class SubgraphExtractorChain(Chain): 12 | graph: Neo4jGraph 13 | subgraph_extractor: LLMSubGraphExtractor 14 | subgraph_expansion: SubgraphExpansion 15 | use_subgraph_expansion: bool 16 | return_intermediate_steps: bool 17 | input_key: str = "preprocessed_cypher_query" #: :meta private: 18 | output_key: str = "extracted_nodes" #: :meta private: 19 | intermediate_steps_key: str = "intermediate_steps" 20 | 21 | def __init__( 22 | self, 23 | llm: BaseLanguageModel, 24 | graph: Neo4jGraph, 25 | subgraph_expansion: SubgraphExpansion, 26 | use_subgraph_expansion: bool, 27 | return_intermediate_steps: bool = True, 28 | ): 29 | subgraph_extractor = LLMSubGraphExtractor(llm) 30 | super().__init__( 31 | subgraph_extractor=subgraph_extractor, 32 | graph=graph, 33 | subgraph_expansion=subgraph_expansion, 34 | use_subgraph_expansion=use_subgraph_expansion, 35 | return_intermediate_steps=return_intermediate_steps, 36 | ) 37 | 38 | @property 39 | def input_keys(self) -> List[str]: 40 | """Return the input keys.""" 41 | return [self.input_key] 42 | 43 | @property 44 | def output_keys(self) -> List[str]: 45 | """Return the output keys.""" 46 | return [self.output_key] 47 | 48 | def _call(self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None) -> Dict[str, Any]: 49 | _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() 50 | try: 51 | subgraph_cypher, extracted_nodes, expanded_nodes = self._try_generating_subgraph( 52 | inputs[self.input_key], _run_manager 53 | ) 54 | except Exception as e: 55 | self._log_it(f"Error when creating subgraph cypher!", _run_manager, e) 56 | subgraph_cypher = subgraph_cypher if "subgraph_cypher" in locals() else "" 57 | extracted_nodes = extracted_nodes if "extracted_nodes" in locals() else [] 58 | expanded_nodes = [] 59 | return self._prepare_chain_result(inputs, subgraph_cypher, extracted_nodes, expanded_nodes) 60 | 61 | def _try_generating_subgraph( 62 | self, cypher_query: str, _run_manager: CallbackManagerForChainRun 63 | ) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]: 64 | subgraph_cypher = self.subgraph_extractor(cypher_query) 65 | self._log_it("Subgraph Cypher:", _run_manager, subgraph_cypher) 66 | 67 | extracted_nodes = self.graph.query(subgraph_cypher) 68 | expanded_nodes = self.subgraph_expansion.expand(nodes=extracted_nodes) if self.use_subgraph_expansion else [] 69 | self._log_it("Extracted Nodes:", _run_manager, extracted_nodes) 70 | 71 | return subgraph_cypher, extracted_nodes, expanded_nodes 72 | 73 | def _log_it(self, text: str, _run_manager: CallbackManagerForChainRun, entity: Any): 74 | _run_manager.on_text(text, end="\n", verbose=self.verbose) 75 | _run_manager.on_text(entity, color="green", end="\n", verbose=self.verbose) 76 | 77 | def _prepare_chain_result( 78 | self, 79 | inputs: Dict[str, Any], 80 | subgraph_cypher: str, 81 | extracted_nodes: List[Dict[str, Any]], 82 | expanded_nodes: List[Dict[str, Any]], 83 | ) -> Dict[str, Any]: 84 | chain_result = { 85 | self.output_key: {"extracted_nodes": extracted_nodes, "expanded_nodes": expanded_nodes}, 86 | } 87 | if self.return_intermediate_steps: 88 | intermediate_steps = inputs.get(self.intermediate_steps_key, []) + [{"subgraph_cypher": subgraph_cypher}] 89 | chain_result[self.intermediate_steps_key] = intermediate_steps 90 | return chain_result 91 | -------------------------------------------------------------------------------- /src/fact_finder/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/config/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/config/simple_config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fact_finder.chains.rag.semantic_scholar_chain import SemanticScholarChain 4 | from fact_finder.chains.rag.text_search_qa_chain import TextSearchQAChain 5 | from fact_finder.prompt_templates import KEYWORD_PROMPT, LLM_PROMPT, RAG_PROMPT 6 | from fact_finder.tools.semantic_scholar_search_api_wrapper import ( 7 | SemanticScholarSearchApiWrapper, 8 | ) 9 | from langchain.chains import LLMChain 10 | from langchain.chains.base import Chain 11 | from langchain_core.language_models import BaseLanguageModel 12 | from langchain_core.prompts.prompt import PromptTemplate 13 | 14 | 15 | def build_chain(model: BaseLanguageModel, args: List[str] = []) -> Chain: 16 | prompt_template = _get_llm_prompt_template() 17 | return LLMChain(llm=model, prompt=prompt_template, verbose=True) 18 | 19 | 20 | def build_rag_chain(model: BaseLanguageModel, args: List[str] = []) -> Chain: 21 | sematic_scholar_chain = SemanticScholarChain( 22 | semantic_scholar_search=SemanticScholarSearchApiWrapper(), 23 | llm=model, 24 | keyword_prompt_template=KEYWORD_PROMPT, 25 | ) 26 | rag_chain = TextSearchQAChain( 27 | llm=model, 28 | rag_answer_generation_template=RAG_PROMPT, 29 | rag_output_key=sematic_scholar_chain.output_keys[0], 30 | ) 31 | return sematic_scholar_chain | rag_chain 32 | 33 | 34 | def _get_llm_prompt_template() -> PromptTemplate: 35 | return LLM_PROMPT 36 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/evaluator/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/evaluator/evaluation.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Any, Dict, List 3 | 4 | import pandas as pd 5 | from langchain_openai import ChatOpenAI 6 | from tqdm import tqdm 7 | 8 | import fact_finder.config.primekg_config as graph_config 9 | from fact_finder.evaluator.evaluation_sample import EvaluationSample 10 | from fact_finder.evaluator.evaluation_samples import manual_samples 11 | from fact_finder.evaluator.set_evaluator.set_evaluator import SetEvaluator 12 | from fact_finder.evaluator.util import load_pickle, save_pickle 13 | 14 | 15 | class Evaluation: 16 | 17 | def __init__( 18 | self, 19 | run_name: str, 20 | model_name: str, 21 | chain_args: List[str] = None, 22 | limit_of_samples: int = None, 23 | idx_list_of_samples: List[int] = None, 24 | run_without_preprocessors: bool = False, 25 | ): 26 | if chain_args is None: 27 | chain_args = [ 28 | "--skip_subgraph_generation", 29 | "--normalized_graph", 30 | "--use_entity_detection_preprocessing", 31 | ] 32 | self.model_name = model_name 33 | self.run_name = run_name 34 | self.idx_list_of_samples = idx_list_of_samples 35 | if run_without_preprocessors: 36 | build_function = graph_config.build_chain_without_preprocessings_etc 37 | else: 38 | build_function = graph_config.build_chain 39 | self.chain = build_function(model=self.load_model(), combine_output_with_sematic_scholar=False, args=chain_args) 40 | self.eval_samples = self.eval_samples(limit_of_samples=limit_of_samples) 41 | 42 | def run(self, save_as_excel: bool = False): 43 | cache_path = "cached_results/" + self.run_name + ".pickle" 44 | chain_results = self.run_chain(cache_path) 45 | results = self.evaluate(chain_results) 46 | if save_as_excel: 47 | self.save_as_excel(results) 48 | return results 49 | 50 | def evaluate(self, chain_results: List) -> Dict[str, Any]: 51 | print("Evaluating...") 52 | evaluator = SetEvaluator() 53 | evaluation = evaluator.evaluate(evaluation_samples=self.eval_samples, chain_results=chain_results) 54 | return {"set_evaluator": evaluation} 55 | 56 | def run_chain(self, cache_path: str): 57 | results = [] 58 | print("Running Chains...") 59 | if os.path.exists(cache_path): 60 | return load_pickle(cache_path) 61 | eval_samples = self.eval_samples 62 | for eval_sample in tqdm(eval_samples): 63 | inputs = {"question": eval_sample.question} 64 | try: 65 | result = self.chain.invoke(inputs) 66 | except Exception as e: 67 | print(e) 68 | result = {} 69 | results.append(result) 70 | save_pickle(results, cache_path) 71 | return results 72 | 73 | def eval_samples(self, limit_of_samples: int = None): 74 | eval_samples = [] 75 | samples = manual_samples 76 | if self.idx_list_of_samples: 77 | samples = [samples[i] for i in self.idx_list_of_samples] 78 | for sample in samples[:limit_of_samples]: 79 | eval_sample = EvaluationSample( 80 | question=sample["question"], 81 | cypher_query=sample["expected_cypher"], 82 | expected_answer=sample["expected_answer"], 83 | nodes=sample["nodes"], 84 | ) 85 | eval_samples.append(eval_sample) 86 | return eval_samples 87 | 88 | def save_as_excel(self, results: Dict[str, list]): 89 | concat_results = [] 90 | for i in results.values(): 91 | concat_results += i 92 | df = pd.DataFrame(concat_results) 93 | path = self.run_name + ".xlsx" 94 | df.to_excel(path) 95 | 96 | def load_model(self): 97 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 98 | return ChatOpenAI(model=self.model_name, streaming=False, temperature=0, api_key=OPENAI_API_KEY) 99 | 100 | 101 | if __name__ == "__main__": 102 | models = ["gpt-4o", "gpt-4-turbo"] 103 | flags = [ 104 | [ 105 | "--skip_subgraph_generation", 106 | "--normalized_graph", 107 | ], 108 | ] 109 | for model in models: 110 | print(model) 111 | for flag in flags: 112 | print(flag) 113 | run_name = model + "_".join(flag) 114 | evaluation = Evaluation( 115 | run_name=run_name, chain_args=flag, model_name=model, run_without_preprocessors=False 116 | ) 117 | results = evaluation.run(save_as_excel=True) 118 | print(run_name) 119 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/evaluation_sample.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class EvaluationSample(BaseModel): 7 | question: str 8 | cypher_query: str 9 | expected_answer: str = "" 10 | nodes: List[Dict[str, Any]] = [] 11 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/evaluator/score/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/bleu_score.py: -------------------------------------------------------------------------------- 1 | from fact_finder.evaluator.score.score import Score 2 | from nltk import download, word_tokenize 3 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu 4 | 5 | 6 | class BleuScore(Score): 7 | """ 8 | 1. punish length differences 9 | 2. compute n-gram-overlap for n=1,...,4 --> how many bi-grams / tri-grams / ... match in the sequences 10 | e.g. a bi-gram in a text sequence is two adjacent tokens 11 | If any count for the different n-grams is 0, the BLEU score is also 0. The smoothing function avoids that. 12 | """ 13 | 14 | def __init__(self) -> None: 15 | download("punkt") 16 | 17 | def compare(self, text_a: str, text_b: str) -> float: 18 | tokens_a = word_tokenize(text_a) 19 | tokens_b = word_tokenize(text_b) 20 | smoothing_function = SmoothingFunction().method1 21 | bleu_score = sentence_bleu([tokens_a], tokens_b, smoothing_function=smoothing_function) 22 | return bleu_score 23 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/difflib_score.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | 3 | from fact_finder.evaluator.score.score import Score 4 | 5 | 6 | class DifflibScore(Score): 7 | """ 8 | Finds the longest matching subsequence. 9 | Then, to the left and right of the longest matching subsequence, it again finds the longest matching subsequence. 10 | Omits "junk" elements such as white spaces. 11 | """ 12 | 13 | def compare(self, text_a: str, text_b: str) -> float: 14 | # set autojunk=False if you don't want junk to be detected automatically. 15 | sequence_matcher = difflib.SequenceMatcher(None, text_a, text_b, autojunk=True) 16 | similarity_score = sequence_matcher.ratio() 17 | return similarity_score 18 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/embedding_score.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer, util 2 | 3 | from fact_finder.evaluator.score.score import Score 4 | 5 | 6 | class EmbeddingScore(Score): 7 | """ 8 | Use Sentence Transformers to 9 | 1. embed the text 10 | 2. compare the vectors (cosine similarity) 11 | """ 12 | 13 | def __init__(self): 14 | self._model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") 15 | 16 | def compare(self, text_a: str, text_b: str) -> float: 17 | embedding_1 = self._model.encode(text_a, convert_to_tensor=True) 18 | embedding_2 = self._model.encode(text_b, convert_to_tensor=True) 19 | similarity_tensor = util.pytorch_cos_sim(embedding_1, embedding_2)[0] 20 | return float(similarity_tensor) 21 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/levenshtein_score.py: -------------------------------------------------------------------------------- 1 | import textdistance as td 2 | 3 | from fact_finder.evaluator.score.score import Score 4 | 5 | 6 | class LevenshteinScore(Score): 7 | """ 8 | The Levenshtein distance between two words is the minimum number of single-character edits (insertions, deletions or substitutions) required to change one word into the other. 9 | """ 10 | 11 | def compare(self, text_a: str, text_b: str) -> float: 12 | return float(td.levenshtein.normalized_similarity(text_a, text_b)) 13 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/score/score.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Score(ABC): 5 | @abstractmethod 6 | def compare(self, text_a: str, text_b: str) -> float: ... 7 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/set_evaluator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/evaluator/set_evaluator/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/evaluator/set_evaluator/set_evaluator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple 3 | 4 | from fact_finder.chains.graph_qa_chain.output import GraphQAChainOutput 5 | from fact_finder.evaluator.evaluation_sample import EvaluationSample 6 | from fact_finder.utils import build_neo4j_graph 7 | from langchain_community.graphs import Neo4jGraph 8 | from tqdm import tqdm 9 | 10 | 11 | class SetEvaluator: 12 | CYPHER_QUERY_TEMPLATE: str = "MATCH (n) WHERE n.index ={idx} RETURN n.name" 13 | 14 | def __init__(self, graph: Optional[Neo4jGraph] = None): 15 | self.graph = graph or build_neo4j_graph() 16 | 17 | def evaluate( 18 | self, evaluation_samples: List[EvaluationSample], chain_results: List[Dict[str, Any]], **kwargs 19 | ) -> List[Dict[str, Any]]: 20 | eval_results = [] 21 | for sample, result in tqdm(zip(evaluation_samples, chain_results)): 22 | score, precision, recall = self.evaluate_sample(sample=sample, chain_result=result) 23 | eval_result = { 24 | "question": sample.question, 25 | "expected_cypher": sample.cypher_query, 26 | "expected_graph_response": sample.nodes, 27 | "expected_answer": sample.expected_answer, 28 | "score": score, 29 | "precision": precision, 30 | "recall": recall, 31 | } 32 | try: 33 | eval_result["actual_cypher"] = result["graph_qa_output"].cypher_query 34 | eval_result["actual_graph_response"] = result["graph_qa_output"].graph_response 35 | eval_result["actual_answer"] = result["graph_qa_output"].answer 36 | except KeyError: 37 | pass 38 | eval_results.append(eval_result) 39 | return eval_results 40 | 41 | def evaluate_sample( 42 | self, sample: EvaluationSample, chain_result: Dict[str, GraphQAChainOutput | Any] 43 | ) -> Tuple[float, float, float]: 44 | try: 45 | if not chain_result or not chain_result["graph_qa_output"].graph_response: 46 | print("No chain result or no nodes.") 47 | return 0.0, 0.0, 0.0 48 | if "index" in sample.nodes[0]: 49 | return self.evaluate_sample_with_single_index(sample=sample, result=chain_result) 50 | if "value" in sample.nodes[0]: 51 | return self.evaluate_sample_with_value_result(sample=sample, result=chain_result) 52 | return self.evaluate_sample_with_tuple_in_label(sample=sample, result=chain_result) 53 | except: 54 | print("Error") 55 | return 0.0, 0.0, 0.0 56 | 57 | def evaluate_sample_with_single_index( 58 | self, sample: EvaluationSample, result: Dict[str, GraphQAChainOutput | Any] 59 | ) -> Tuple[float, float, float]: 60 | ids = [node["index"] for node in sample.nodes] 61 | names = set(self.query_node_names(ids)) 62 | all_scores = self._get_scores_per_key(names, result) 63 | return max(all_scores.values(), key=lambda s: s[0]) 64 | 65 | def evaluate_sample_with_value_result( 66 | self, sample: EvaluationSample, result: Dict[str, GraphQAChainOutput | Any] 67 | ) -> Tuple[float, float, float]: 68 | value = [node["value"] for node in sample.nodes] 69 | all_scores = self._get_scores_per_key(value, result) 70 | return max(all_scores.values(), key=lambda s: s[0]) 71 | 72 | def evaluate_sample_with_tuple_in_label( 73 | self, sample: EvaluationSample, result: Dict[str, GraphQAChainOutput | Any] 74 | ) -> Tuple[float, float, float]: 75 | tuple_size = max(int(key[len("index") :]) for key in sample.nodes[0] if key.startswith("index")) + 1 76 | indices = [f"index{i}" for i in range(tuple_size)] 77 | ids_per_index = {i: tuple(self.query_node_names(node[i] for node in sample.nodes)) for i in indices} 78 | best_graph_result_keys = { 79 | k: max(spk := self._get_scores_per_key(v, result), key=lambda k: spk[k][0]) 80 | for k, v in ids_per_index.items() 81 | } 82 | expected_tuples = set(zip(*(ids_per_index[i] for i in indices))) 83 | result_graph_output = result["graph_qa_output"].graph_response 84 | result_tuples = set(tuple(res[best_graph_result_keys[i]] for i in indices) for res in result_graph_output) 85 | return ( 86 | intersection_over_union(result_tuples, expected_tuples), 87 | precision(result_tuples, expected_tuples), 88 | recall(result_tuples, expected_tuples), 89 | ) 90 | 91 | def query_node_names(self, ids: Iterable[int]) -> Iterable[str]: 92 | for number in ids: 93 | graph_return = self.graph.query(self.CYPHER_QUERY_TEMPLATE.replace("{idx}", f"{number}")) 94 | yield graph_return[0]["n.name"] 95 | 96 | def _get_scores_per_key( 97 | self, expected_values: Collection[str], result: Dict[str, GraphQAChainOutput | Any] 98 | ) -> Dict[str, Tuple[float, float, float]]: 99 | result_graph_output: List[Dict[str, Any]] = result["graph_qa_output"].graph_response 100 | if len(result_graph_output) == 0: 101 | return {"INVALID_KEY": 0.0} 102 | result_graph_output_keys = list(result_graph_output[0].keys()) 103 | return { 104 | key: self._compute_score_for_key(expected_values, result_graph_output, key) 105 | for key in result_graph_output_keys 106 | } 107 | 108 | def _compute_score_for_key( 109 | self, expected_values: Collection[str], result_graph_output: List[Dict[str, Any]], key: str 110 | ) -> Tuple[float, float, float]: 111 | if len(result_graph_output) == 0: 112 | return 0.0, 0.0, 0.0 113 | if isinstance(result_graph_output[0][key], dict) and "name" in result_graph_output[0][key]: 114 | # Handle the query returning whole nodes instead of specific properties. 115 | # TODO Other properties than "name"? 116 | entries = {i[key]["name"] for i in result_graph_output} 117 | else: 118 | entries = {i[key] for i in result_graph_output} 119 | return ( 120 | intersection_over_union(entries, expected_values), 121 | precision(entries, expected_values), 122 | recall(entries, expected_values), 123 | ) 124 | 125 | 126 | def intersection_over_union(result: Set[Any], expected: Collection[Any]) -> float: 127 | return len(result.intersection(expected)) / len(result.union(expected)) 128 | 129 | 130 | def precision(result: Set[Any], expected: Collection[Any]) -> float: 131 | return len(result.intersection(expected)) / len(result) 132 | 133 | 134 | def recall(result: Set[Any], expected: Collection[Any]) -> float: 135 | return len(result.intersection(expected)) / len(expected) 136 | -------------------------------------------------------------------------------- /src/fact_finder/evaluator/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional, List, Dict, Any 4 | 5 | import tqdm 6 | from dotenv import load_dotenv 7 | from langchain_community.graphs import Neo4jGraph 8 | 9 | from fact_finder.evaluator.evaluation_sample import EvaluationSample 10 | from fact_finder.evaluator.evaluation_samples import manual_samples 11 | from fact_finder.tools.sub_graph_extractor import LLMSubGraphExtractor 12 | from fact_finder.utils import build_neo4j_graph, load_chat_model 13 | 14 | import pickle 15 | 16 | 17 | def save_pickle(object: Any, path: str = "filename.pickle"): 18 | with open(path, "wb") as handle: 19 | pickle.dump(object, handle, protocol=pickle.HIGHEST_PROTOCOL) 20 | 21 | 22 | def load_pickle(path: str = "filename.pickle"): 23 | with open(path, "rb") as handle: 24 | return pickle.load(handle) 25 | 26 | 27 | class EvalSampleAddition: 28 | """ 29 | Class to add manually selected evaluation samples in our EvaluationSample format to a json file. 30 | TODO: the expected answers are rather random as of now. They may not correlate with whatever a LLM would verbalize 31 | """ 32 | 33 | def __init__(self, graph: Neo4jGraph, subgraph_extractor: LLMSubGraphExtractor, path_to_json: Path): 34 | self._subgraph_extractor = subgraph_extractor 35 | self._graph = graph 36 | self._path_to_json = path_to_json 37 | 38 | def add_to_evaluation_sample_json( 39 | self, 40 | question: str, 41 | cypher_query: str, 42 | source: str, 43 | expected_answer: str, 44 | nodes: List[Dict[str, Any]], 45 | is_answerable: Optional[bool] = True, 46 | subgraph_query: Optional[ 47 | str 48 | ] = None, # sometimes the subgraph query generation fails; we can set it manually then 49 | ): 50 | 51 | # TODO Note that running the subgraph query on my laptop crashes due to not enough RAM 52 | # subgraph_query = subgraph_query if subgraph_query else self._subgraph_extractor(expected_cypher) 53 | # try: 54 | # sub_graph = self._graph.query(query=subgraph_query) 55 | # except: 56 | # print(question) 57 | # print(expected_cypher) 58 | # print(subgraph_query) 59 | # raise 60 | sub_graph = [] 61 | sample = EvaluationSample( 62 | question=question, 63 | cypher_query=cypher_query, 64 | sub_graph=sub_graph, 65 | question_is_answerable=is_answerable, 66 | source=source, 67 | expected_answer=expected_answer, 68 | nodes=nodes, 69 | ) 70 | self._persist(sample=sample) 71 | 72 | def _persist(self, sample: EvaluationSample): 73 | with open(self._path_to_json, "r", encoding="utf8") as r: 74 | json_content = json.load(r) 75 | evaluation_samples = [EvaluationSample.model_validate(r) for r in json_content] 76 | # if an example for that question is already existing, it will be overwritten 77 | evaluation_samples = [r for r in evaluation_samples if not r.question == sample.question] 78 | evaluation_samples.append(sample) 79 | with open(self._path_to_json, "w", encoding="utf8") as w: 80 | json.dump([r.model_dump() for r in evaluation_samples], w, indent=4) 81 | 82 | 83 | if __name__ == "__main__": 84 | load_dotenv() 85 | 86 | eval_sample_addition = EvalSampleAddition( 87 | graph=build_neo4j_graph(), 88 | subgraph_extractor=LLMSubGraphExtractor( 89 | model=load_chat_model(), 90 | ), 91 | path_to_json=Path("src/fact_finder/evaluator/evaluation_samples.json"), 92 | ) 93 | 94 | for sample in tqdm.tqdm(manual_samples): 95 | ... 96 | eval_sample_addition.add_to_evaluation_sample_json( 97 | question=sample["question"], 98 | cypher_query=sample["expected_cypher"] if "expected_cypher" in sample.keys() else sample["cypher_query"], 99 | source="manual", 100 | expected_answer=sample["expected_answer"], 101 | is_answerable=True, 102 | nodes=sample["nodes"], 103 | subgraph_query=sample.get("subgraph_query"), 104 | ) 105 | -------------------------------------------------------------------------------- /src/fact_finder/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/py.typed -------------------------------------------------------------------------------- /src/fact_finder/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/tools/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .cypher_query_preprocessor import CypherQueryPreprocessor 2 | from .format_preprocessor import FormatPreprocessor 3 | from .lower_case_properties_cypher_query_preprocessor import LowerCasePropertiesCypherQueryPreprocessor 4 | from .synonym_cypher_query_preprocessor import SynonymCypherQueryPreprocessor 5 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/always_distinct_preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 4 | CypherQueryPreprocessor, 5 | ) 6 | 7 | 8 | class AlwaysDistinctCypherQueryPreprocessor(CypherQueryPreprocessor): 9 | def __call__(self, cypher_query: str) -> str: 10 | return re.sub(r"RETURN\s+(?!DISTINCT).*", self._replace_match, cypher_query) 11 | 12 | def _replace_match(self, matches: re.Match[str]) -> str: 13 | return matches.group(0).replace("RETURN", "RETURN DISTINCT") 14 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/child_to_parent_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import regex as re 4 | from fact_finder.tools.cypher_preprocessors.property_string_preprocessor import ( 5 | PropertyStringCypherQueryPreprocessor, 6 | ) 7 | from langchain_community.graphs import Neo4jGraph 8 | 9 | 10 | class ChildToParentPreprocessor(PropertyStringCypherQueryPreprocessor): 11 | _mapping_query = ( 12 | "MATCH (child)-[:{child_to_parent_relation}]->(parent)\n" 13 | "WITH child, parent, COUNT{(child)-->()} AS outdegree\n" 14 | "WHERE outdegree = 1\n" 15 | "RETURN DISTINCT child.{name}, parent.{name}" 16 | ) 17 | 18 | def __init__( 19 | self, 20 | graph: Neo4jGraph, 21 | child_to_parent_relation: str, 22 | name_property: str = "name", 23 | ) -> None: 24 | super().__init__(property_names=[name_property]) 25 | self._name_property = name_property 26 | graph_result = self._run_graph_query(graph, child_to_parent_relation) 27 | self._child_to_parent: Dict[str, str] = self._to_parent_child_dict(graph_result) 28 | 29 | def _run_graph_query(self, graph: Neo4jGraph, child_to_parent_relation: str) -> List[Dict[str, str]]: 30 | query = self._mapping_query.replace("{child_to_parent_relation}", child_to_parent_relation) 31 | query = query.replace("{name}", self._name_property) 32 | return graph.query(query) 33 | 34 | def _to_parent_child_dict(self, graph_result: List[Dict[str, str]]) -> Dict[str, str]: 35 | res: Dict[str, str] = dict() 36 | for r in graph_result: 37 | child_name: str = r[f"child.{self._name_property}"] 38 | parent_name: str = r[f"parent.{self._name_property}"] 39 | if child_name != parent_name: 40 | res[child_name] = parent_name 41 | return res 42 | 43 | def _replace_match(self, matches: re.Match[str]) -> str: 44 | assert matches.groups() 45 | block = matches.group(0) 46 | block_start, _ = matches.spans(0)[0] 47 | prev_end = 0 48 | new_block = "" 49 | for node_name, (start, end) in sorted(zip(matches.captures(1), matches.spans(1)), key=lambda x: x[1][0]): 50 | if parent_name := self._child_to_parent.get(node_name): 51 | new_block += block[prev_end : start - block_start] + parent_name 52 | prev_end = end - block_start 53 | new_block += block[prev_end:] 54 | return new_block 55 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/cypher_query_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | 4 | @runtime_checkable 5 | class CypherQueryPreprocessor(Protocol): 6 | def __call__(self, cypher_query: str) -> str: ... 7 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/format_preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | 4 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 5 | CypherQueryPreprocessor, 6 | ) 7 | 8 | 9 | class FormatPreprocessor(CypherQueryPreprocessor): 10 | """A preprocessor applying some regex based formating to cypher queries. 11 | This allows simpler regex expressions in subsequent preprocessors. 12 | Partially based on https://github.com/TristanPerry/cypher-query-formatter 13 | (see MIT License: https://github.com/TristanPerry/cypher-query-formatter/blob/master/LICENSE). 14 | 15 | :param CypherQueryPreprocessor: _description_ 16 | :type CypherQueryPreprocessor: _type_ 17 | """ 18 | 19 | def __call__(self, cypher_query: str) -> str: 20 | try: 21 | return self._try_formatting(cypher_query) 22 | except Exception as e: 23 | print(f"Cypher query formatting failed: {e}", file=sys.stderr) 24 | return cypher_query 25 | 26 | def _try_formatting(self, cypher_query: str) -> str: 27 | cypher_query = cypher_query.strip().lower() 28 | cypher_query = self._only_use_double_quotes(cypher_query) 29 | cypher_query = self._keywords_to_upper_case(cypher_query) 30 | cypher_query = self._null_and_boolean_literals_to_lower_case(cypher_query) 31 | cypher_query = self._ensure_main_keywords_on_newline(cypher_query) 32 | cypher_query = self._unix_style_newlines(cypher_query) 33 | cypher_query = self._remove_whitespace_from_start_of_lines(cypher_query) 34 | cypher_query = self._remove_whitespace_from_end_of_lines(cypher_query) 35 | cypher_query = self._add_spaces_after_comma(cypher_query) 36 | cypher_query = self._multiple_spaces_to_single_space(cypher_query) 37 | cypher_query = self._indent_on_create_and_on_match(cypher_query) 38 | cypher_query = self._remove_multiple_empty_newlines(cypher_query) 39 | cypher_query = self._remove_unnecessary_spaces(cypher_query) 40 | 41 | return cypher_query.strip() 42 | 43 | def _only_use_double_quotes(self, cypher_query: str) -> str: 44 | # Escape all single quotes in double quote strings. 45 | cypher_query = re.sub( 46 | r'"([^{}\(\)\[\]=]*?)"', lambda matches: matches.group(0).replace("'", r"\'"), cypher_query 47 | ) 48 | # Replace all not escaped single quotes. 49 | cypher_query = re.sub(r"(? str: 53 | return re.sub( 54 | r"\b(WHEN|CASE|AND|OR|XOR|DISTINCT|AS|IN|STARTS WITH|ENDS WITH|CONTAINS|NOT|SET|ORDER BY)\b", 55 | _keywords_to_upper_case, 56 | cypher_query, 57 | flags=re.IGNORECASE, 58 | ) 59 | 60 | def _null_and_boolean_literals_to_lower_case(self, cypher_query: str) -> str: 61 | return re.sub( 62 | r"\b(NULL|TRUE|FALSE)\b", 63 | _null_and_booleans_to_lower_case, 64 | cypher_query, 65 | flags=re.IGNORECASE, 66 | ) 67 | 68 | def _ensure_main_keywords_on_newline(self, cypher_query: str) -> str: 69 | return re.sub( 70 | r"\b(CASE|DETACH DELETE|DELETE|MATCH|MERGE|LIMIT|OPTIONAL MATCH|RETURN|UNWIND|UNION|WHERE|WITH|GROUP BY)\b", 71 | _main_keywords_on_newline, 72 | cypher_query, 73 | flags=re.IGNORECASE, 74 | ) 75 | 76 | def _unix_style_newlines(self, cypher_query: str) -> str: 77 | return re.sub(r"(\r\n|\r)", "\n", cypher_query) 78 | 79 | def _remove_whitespace_from_start_of_lines(self, cypher_query: str) -> str: 80 | return re.sub(r"^\s+", "", cypher_query, flags=re.MULTILINE) 81 | 82 | def _remove_whitespace_from_end_of_lines(self, cypher_query: str) -> str: 83 | return re.sub(r"\s+$", "", cypher_query, flags=re.MULTILINE) 84 | 85 | def _add_spaces_after_comma(self, cypher_query: str) -> str: 86 | return re.sub(r",([^\s])", lambda matches: ", " + matches.group(1), cypher_query) 87 | 88 | def _multiple_spaces_to_single_space(self, cypher_query: str) -> str: 89 | return re.sub(r"((?![\n])\s)+", " ", cypher_query) 90 | 91 | def _indent_on_create_and_on_match(self, cypher_query: str) -> str: 92 | return re.sub(r"\b(ON CREATE|ON MATCH)\b", _indent_on_create_and_on_match, cypher_query, flags=re.IGNORECASE) 93 | 94 | def _remove_multiple_empty_newlines(self, cypher_query: str) -> str: 95 | return re.sub(r"\n\s*?\n", "\n", cypher_query) 96 | 97 | def _remove_unnecessary_spaces(self, cypher_query: str) -> str: 98 | cypher_query = re.sub(r"(\(|{|\[])\s+", lambda matches: matches.group(1), cypher_query) 99 | cypher_query = re.sub(r"\s+(\)|}|\])", lambda matches: matches.group(1), cypher_query) 100 | cypher_query = re.sub(r"\s*(:|-|>|<)\s*", lambda matches: matches.group(1), cypher_query) 101 | # Retain spaces before property names 102 | cypher_query = re.sub(r':\s*"', ': "', cypher_query) 103 | # Also around equation signs 104 | cypher_query = re.sub(r"\s+=\s*", " = ", cypher_query) 105 | cypher_query = re.sub(r"\s*<=\s*", " <= ", cypher_query) 106 | cypher_query = re.sub(r"\s*>=\s*", " >= ", cypher_query) 107 | return cypher_query 108 | 109 | 110 | def _keywords_to_upper_case(matches: re.Match[str]) -> str: 111 | assert len(matches.groups()) 112 | return matches.group(0).replace(matches.group(1), " " + matches.group(1).upper().strip() + " ") 113 | 114 | 115 | def _null_and_booleans_to_lower_case(matches: re.Match[str]) -> str: 116 | assert len(matches.groups()) 117 | return matches.group(0).replace(matches.group(1), " " + matches.group(1).lower().strip() + " ") 118 | 119 | 120 | def _main_keywords_on_newline(matches: re.Match[str]) -> str: 121 | assert len(matches.groups()) 122 | return matches.group(0).replace(matches.group(1), "\n" + matches.group(1).upper().lstrip() + " ") 123 | 124 | 125 | def _indent_on_create_and_on_match(matches: re.Match[str]) -> str: 126 | assert len(matches.groups()) 127 | return matches.group(0).replace(matches.group(1), "\n " + matches.group(1).upper().lstrip() + " ") 128 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/lower_case_properties_cypher_query_preprocessor.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | from fact_finder.tools.cypher_preprocessors.property_string_preprocessor import ( 3 | PropertyStringCypherQueryPreprocessor, 4 | ) 5 | 6 | 7 | class LowerCasePropertiesCypherQueryPreprocessor(PropertyStringCypherQueryPreprocessor): 8 | 9 | def _replace_match(self, matches: re.Match[str]) -> str: 10 | assert len(matches.groups()) 11 | text = matches.group(0) 12 | for property in matches.captures(1): 13 | text = text.replace(property, property.lower()) 14 | return text 15 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/property_string_preprocessor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | 4 | import regex as re 5 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 6 | CypherQueryPreprocessor, 7 | ) 8 | 9 | 10 | class PropertyStringCypherQueryPreprocessor(CypherQueryPreprocessor, ABC): 11 | def __init__(self, property_names: List[str] = [r"[^{:\s]+"]) -> None: 12 | self._property_names = property_names 13 | 14 | def __call__(self, cypher_query: str) -> str: 15 | for name in self._property_names: 16 | cypher_query = self._replace_property_value_in_node(cypher_query, name) 17 | cypher_query = self._replace_property_value_in_comparison(cypher_query, name) 18 | cypher_query = self._replace_property_value_in_list(cypher_query, name) 19 | return cypher_query 20 | 21 | def _replace_property_value_in_node(self, cypher_query: str, name: str): 22 | return re.sub(r"\{" + name + r': "([^"}]+)"\}', self._replace_match, cypher_query) 23 | 24 | def _replace_property_value_in_comparison(self, cypher_query: str, name: str): 25 | return re.sub(r"[^\s=]+\." + name + r'\s*=\s*"([^"]+)"(\s|$)', self._replace_match, cypher_query) 26 | 27 | def _replace_property_value_in_list(self, cypher_query: str, name: str): 28 | return re.sub(r"\." + name + r'\s*IN\s*\[(?:(?:"([^"]+)"(?:,\s*)*)+?)\]', self._replace_match, cypher_query) 29 | 30 | @abstractmethod 31 | def _replace_match(self, matches: re.Match[str]) -> str: ... 32 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/size_to_count_preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 4 | CypherQueryPreprocessor, 5 | ) 6 | 7 | 8 | class SizeToCountPreprocessor(CypherQueryPreprocessor): 9 | def __call__(self, cypher_query: str) -> str: 10 | cypher_query = re.sub(r"\b(SIZE\b\s*)\(", _size_to_upper_case, cypher_query, flags=re.IGNORECASE) 11 | cypher_query = self._replace_search_word_with_count(cypher_query, "SIZE(") 12 | return cypher_query 13 | 14 | def _replace_search_word_with_count(self, cypher_query, search_term): 15 | len_term = len(search_term) 16 | while (start := cypher_query.find(search_term)) >= 0: 17 | end = self._find_closing_bracket(cypher_query, start + len_term) 18 | cypher_query = ( 19 | cypher_query[:start] + "COUNT{" + cypher_query[start + len_term : end] + "}" + cypher_query[end + 1 :] 20 | ) 21 | return cypher_query 22 | 23 | def _find_closing_bracket(self, cypher_query: str, search_start: int, bracket_values={"(": 1, ")": -1}) -> int: 24 | count = 1 25 | for i, c in enumerate(cypher_query[search_start:]): 26 | count += bracket_values.get(c, 0) 27 | if count == 0: 28 | return i + search_start 29 | raise ValueError("SIZE keyword in Cypher query without closing bracket.") 30 | 31 | 32 | def _size_to_upper_case(matches: re.Match[str]) -> str: 33 | assert len(matches.groups()) 34 | return matches.group(0).replace(matches.group(1), matches.group(1).upper().strip()) 35 | -------------------------------------------------------------------------------- /src/fact_finder/tools/cypher_preprocessors/synonym_cypher_query_preprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | from functools import partial 3 | from typing import Iterable, List, Set, Tuple 4 | 5 | from langchain_community.graphs import Neo4jGraph 6 | 7 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 8 | CypherQueryPreprocessor, 9 | ) 10 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 11 | 12 | 13 | class SynonymCypherQueryPreprocessor(CypherQueryPreprocessor): 14 | 15 | def __init__( 16 | self, 17 | graph: Neo4jGraph, 18 | synonym_finder: SynonymFinder, 19 | node_types: str | List[str], 20 | search_property_name: str = "name", 21 | replacement_property_name: str | None = None, 22 | ): 23 | self.__graph = graph 24 | self.__synonym_finder = synonym_finder 25 | self.__node_types = {node_types} if isinstance(node_types, str) else set(node_types) 26 | self.__search_property_name = search_property_name 27 | self.__replacement_property_name = ( 28 | replacement_property_name if replacement_property_name else self.__search_property_name 29 | ) 30 | self.__existing_node_properties = dict(self.__get_all_nodes()) 31 | 32 | def __call__(self, cypher_query: str) -> str: 33 | for node_type in self.__node_types: 34 | regex = self._build_match_clause_regex(node_type) 35 | cypher_query = re.sub(regex, partial(self._replace_match, node_type=node_type), cypher_query) 36 | for node_type in self.__node_types: 37 | regex = self._build_where_clause_regex(node_type) 38 | cypher_query = re.sub( 39 | regex, 40 | partial(self._replace_match, node_type=node_type, group_offset=1), 41 | cypher_query, 42 | flags=re.MULTILINE | re.DOTALL, 43 | ) 44 | return cypher_query 45 | 46 | def __get_all_nodes(self) -> Iterable[Tuple[str, Set[str]]]: 47 | for n_type in self.__node_types: 48 | cypher_query = f"MATCH(n:{n_type}) RETURN n" 49 | nodes = self.__graph.query(cypher_query) 50 | yield n_type, set(node["n"][self.__replacement_property_name].lower() for node in nodes) 51 | 52 | def _build_match_clause_regex(self, node_type) -> str: 53 | return r"\([^\s{:]+:" + node_type + r"\s*{(" + self.__search_property_name + r"): ['\"]([^}]+)['\"]}\)" 54 | 55 | def _build_where_clause_regex(self, node_type) -> str: 56 | return ( 57 | r"MATCH[^$]*\(([^\s{:]+):" + node_type + r"\).*?" 58 | r"WHERE[^$]*?\1\.(" + self.__search_property_name + r")\s*=\s*\"([^\"]+)\"" 59 | ) 60 | 61 | def _replace_match(self, match: re.Match[str], node_type: str, group_offset: int = 0) -> str: 62 | assert len(match.groups()) == 2 + group_offset 63 | text = match.group(0) 64 | extracted_property = match.group(2 + group_offset) 65 | replacement_property = self.__find_synonym(node_type, extracted_property) 66 | if replacement_property is None: 67 | return text 68 | return self._build_new_text(text, match, group_offset, replacement_property) 69 | 70 | def _build_new_text(self, text: str, match: re.Match[str], group_offset: int, replacement_property: str) -> str: 71 | hit_offset = match.start(0) 72 | start1 = match.start(1 + group_offset) - hit_offset 73 | end1 = match.end(1 + group_offset) - hit_offset 74 | start2 = match.start(2 + group_offset) - hit_offset 75 | end2 = match.end(2 + group_offset) - hit_offset 76 | new_text = ( 77 | text[:start1] + self.__replacement_property_name + text[end1:start2] + replacement_property + text[end2:] 78 | ) 79 | return new_text 80 | 81 | def __find_synonym(self, node_type: str, node_property: str) -> str | None: 82 | synonyms = self.__synonym_finder(node_property) 83 | for synonym in synonyms: 84 | if self.__exists_in_graph(node_type, synonym): 85 | return synonym 86 | return None 87 | 88 | def __exists_in_graph(self, node_type: str, property_value: str) -> bool: 89 | return property_value.lower() in self.__existing_node_properties[node_type] 90 | -------------------------------------------------------------------------------- /src/fact_finder/tools/entity_detector.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | import requests 6 | from dotenv import load_dotenv 7 | 8 | load_dotenv() 9 | 10 | 11 | _DEFAULT_ENTITY_FILENAMES = [ 12 | "sci_gene_human_dictionary.txt", # Human genes 13 | "sci_ab.dictionary.txt", # Antibodies 14 | "sci_disease_dictionary.txt", # Diseases 15 | "hubble_cells.dictionary.txt", # Cells 16 | "sci_drug_dictionary.txt", # Drugs 17 | "organ_thesarus.txt", # Organs 18 | "celline.dictionary.txt", # Celllines 19 | ] 20 | 21 | 22 | class EntityDetector: 23 | 24 | def __init__(self, filenames: List[str] = _DEFAULT_ENTITY_FILENAMES): 25 | self.__possible_filenames = filenames 26 | self.__url = os.getenv("SYNONYM_API_URL") 27 | self.__api_key = os.getenv("SYNONYM_API_KEY") 28 | if self.__api_key is None or self.__url is None: 29 | raise ValueError("For using EntityDetector, the env variable SYNONYM_API_KEY as well as SYNONYM_API_URL must be set.") 30 | 31 | def __call__(self, search_text: str) -> List[Dict[str, Any]]: 32 | filenames_as_single_string = ", ".join(self.__possible_filenames) 33 | payload = json.dumps({"public_dic": filenames_as_single_string, "text": search_text}) 34 | headers = {"x-api-key": self.__api_key, "Accept": "application/json", "Content-Type": "application/json"} 35 | response = requests.request("POST", self.__url, headers=headers, data=payload) 36 | if response.status_code == 200: 37 | return json.loads(response.text)["annotations"] 38 | response.raise_for_status() 39 | return [] 40 | -------------------------------------------------------------------------------- /src/fact_finder/tools/semantic_scholar_search_api_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import requests 5 | 6 | 7 | class SemanticScholarSearchApiWrapper: 8 | 9 | def __init__(self): 10 | SEMANTIC_SCHOLAR_KEY = os.getenv("SEMANTIC_SCHOLAR_KEY") 11 | self._session = requests.Session() 12 | self._semantic_scholar_endpoint = "https://api.semanticscholar.org/graph/v1/paper/search" 13 | self._header = {"x-api-key": SEMANTIC_SCHOLAR_KEY} 14 | 15 | def search_by_abstracts(self, keywords: str) -> List[str]: 16 | query_params = {"query": keywords, "limit": 5, "fields": "title,abstract"} 17 | response = self._session.get(self._semantic_scholar_endpoint, params=query_params, headers=self._header) 18 | if response.status_code != 200: 19 | raise ValueError(f"Semantic scholar API returned an error:\n{response}") 20 | papers = response.json()["data"] 21 | results = ["\n".join([paper["title"], paper["abstract"] or ""]) for paper in papers] 22 | return results 23 | 24 | def search_by_paper_content(self, keywords: str) -> List[str]: 25 | # Note that semantic scholar does only retrieve abstracts, and even they may be missing. 26 | # we could get the PDF links and try to access papers dynamically, run pdf extraction, etc. 27 | # Or: we decide on papers to download and preprocess and put in a local vector db. 28 | 29 | # Otherwise the idea here would be to get 1 or 2 top papers, dynamically embedd small chunks, run a retriever on that and the plug this 30 | # into a qa prompt. 31 | raise NotImplementedError 32 | 33 | 34 | if __name__ == "__main__": 35 | sem = SemanticScholarSearchApiWrapper() 36 | print(sem.search_by_abstracts("Alternative causes, fever, malaria infections")) 37 | -------------------------------------------------------------------------------- /src/fact_finder/tools/sub_graph_extractor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from abc import ABC, abstractmethod 4 | 5 | from langchain.chains import LLMChain 6 | from langchain_core.language_models import BaseLanguageModel 7 | 8 | from fact_finder.prompt_templates import SUBGRAPH_EXTRACTOR_PROMPT 9 | 10 | 11 | class SubGraphExtractor(ABC): 12 | @abstractmethod 13 | def __call__(self, cypher_query: str) -> str: ... 14 | 15 | 16 | class LLMSubGraphExtractor(SubGraphExtractor): 17 | 18 | def __init__(self, model: BaseLanguageModel): 19 | self.llm_chain = LLMChain(llm=model, prompt=SUBGRAPH_EXTRACTOR_PROMPT) 20 | 21 | def __call__(self, cypher_query: str) -> str: 22 | result = self.llm_chain(cypher_query) 23 | return result["text"] 24 | 25 | 26 | class RegexSubGraphExtractor(SubGraphExtractor): 27 | 28 | def __init__(self): 29 | self.__extract_group_regex = r"(\(.*\))[-|<|>]+(\[.*\])[-|<|>]+(\(.*\))" 30 | self.__extract_subject_regex = self.__extract_object_regex = r"(?<=\()([a-zA-Z0-9]*)" 31 | self.__extract_predicate_regex = r"(?<=\[)([a-zA-Z0-9]*)" 32 | 33 | def __call__(self, cypher_query: str) -> str: 34 | result = re.search(self.__extract_group_regex, cypher_query) 35 | subject, predicate, object = result.groups() 36 | subject_variable = re.search(self.__extract_subject_regex, subject).group(0) 37 | predicate_variable = re.search(self.__extract_predicate_regex, predicate).group(0) 38 | object_variable = re.search(self.__extract_object_regex, object).group(0) 39 | all_letters = list(string.ascii_lowercase) 40 | all_possible_letters = list(set(all_letters) - {object_variable, subject_variable, predicate_variable}) 41 | if len(subject_variable) == 0: 42 | subject_variable = all_possible_letters.pop() 43 | new_subject = self.__replace_with_variable(subject, subject_variable) 44 | cypher_query = cypher_query.replace(subject, new_subject) 45 | if len(predicate_variable) == 0: 46 | predicate_variable = all_possible_letters.pop() 47 | new_predicate = self.__replace_with_variable(predicate, predicate_variable) 48 | cypher_query = cypher_query.replace(predicate, new_predicate) 49 | if len(object_variable) == 0: 50 | object_variable = all_possible_letters.pop() 51 | new_object = self.__replace_with_variable(object, object_variable) 52 | cypher_query = cypher_query.replace(object, new_object) 53 | cypher_query = re.sub( 54 | r"return.*", 55 | f"return {subject_variable},{predicate_variable},{object_variable}", 56 | cypher_query, 57 | flags=re.IGNORECASE, 58 | ) 59 | return cypher_query 60 | 61 | def __replace_with_variable(self, text_without_variable: str, variable: str) -> str: 62 | result = text_without_variable[0] + variable + text_without_variable[1:] 63 | return result 64 | -------------------------------------------------------------------------------- /src/fact_finder/tools/subgraph_extension.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Dict, Any 3 | 4 | from langchain_community.graphs import Neo4jGraph 5 | 6 | from fact_finder.utils import graph_result_contains_triple, get_triples_from_graph_result 7 | 8 | 9 | class SubgraphExpansion: 10 | 11 | def __init__(self, graph: Neo4jGraph): 12 | self._graph: Neo4jGraph = graph 13 | self._extension_query_template = """ 14 | MATCH (a {{index: {head_index} }})-[r1]-(c)-[r2]-(b {{index: {tail_index} }}) 15 | RETURN a, r1, c, r2, b 16 | """ 17 | 18 | def expand(self, nodes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 19 | nodes = copy.deepcopy(nodes) 20 | result = [] 21 | for entry in nodes: 22 | if graph_result_contains_triple(graph_result_entry=entry): 23 | triples = get_triples_from_graph_result(graph_result_entry=entry) 24 | for triple in triples: 25 | result += self._enrich(triple=triple) 26 | return nodes + result 27 | 28 | def _enrich(self, triple) -> List[Dict[str, Any]]: 29 | head_index = triple[0]["index"] 30 | tail_index = triple[2]["index"] 31 | extension_query = self._extension_query_template.format(head_index=head_index, tail_index=tail_index) 32 | return self._query_graph(cypher=extension_query) 33 | 34 | def _query_graph(self, cypher) -> List[Dict[str, Any]]: 35 | try: 36 | return self._graph.query(cypher) 37 | except Exception as e: 38 | print(f"Sub Graph for {cypher} could not be extracted due to {e}") 39 | return [] 40 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/__init__.py: -------------------------------------------------------------------------------- 1 | from .preferred_term_finder import PreferredTermFinder 2 | from .synonym_finder import SynonymFinder 3 | from .wiki_data_synonym_finder import WikiDataSynonymFinder 4 | from .word_net_synonym_finder import WordNetSynonymFinder 5 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/aggregate_state_synonym_finder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 5 | 6 | 7 | class AggregateStateSynonymFinder(SynonymFinder): 8 | 9 | def __call__(self, name: str) -> List[str]: 10 | name = name.strip().lower() 11 | if name in ["gas", "gases", "gaseous", "gassy", "gasiform", "aerially", "aeriform", "vapor", "vapour"]: 12 | return ["gas"] 13 | if name in ["liquid", "liquids", "fluid", "fluids"]: 14 | return ["liquid"] 15 | if name in ["solid", "solids"]: 16 | return ["solid"] 17 | logging.error(f"Unknown aggregate state: {name}") 18 | return [name] 19 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/preferred_term_finder.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List 2 | 3 | from fact_finder.tools.entity_detector import EntityDetector 4 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 5 | 6 | 7 | class PreferredTermFinder(SynonymFinder): 8 | """Searches and returns preferred names using the EntityDetector. 9 | Returns only preferred names of entities from the allowed categories (semantic types). 10 | """ 11 | 12 | def __init__(self, allowed_categories: Iterable[str]) -> None: 13 | self._detector = EntityDetector() 14 | self._allowed_categories = set(map(str.lower, allowed_categories)) 15 | 16 | def __call__(self, name: str) -> Iterable[str]: 17 | yield name 18 | for r in self._detector(name): 19 | if r["sem_type"].lower() in self._allowed_categories: 20 | yield r["pref_term"] 21 | 22 | 23 | class PreferredTermIdFinder(SynonymFinder): 24 | """Searches preferred names using the EntityDetector. 25 | Returns the corresponding ids. 26 | Returns only ids of entities from the allowed categories (semantic types). 27 | """ 28 | 29 | def __init__(self, allowed_categories: Iterable[str]) -> None: 30 | self._detector = EntityDetector() 31 | self._allowed_categories = set(map(str.lower, allowed_categories)) 32 | 33 | def __call__(self, name: str) -> List[str]: 34 | return [r["id"] for r in self._detector(name) if r["sem_type"].lower() in self._allowed_categories] 35 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/synonym_finder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Iterable, List 3 | 4 | 5 | class SynonymFinder(ABC): 6 | 7 | @abstractmethod 8 | def __call__(self, name: str) -> Iterable[str]: 9 | pass 10 | 11 | 12 | class SimilaritySynonymFinder(SynonymFinder): 13 | 14 | def __call__(self, name: str) -> Iterable[str]: 15 | # TODO find most similar node from all nodes with vector cosine similarity search 16 | raise NotImplementedError 17 | 18 | 19 | class SubWordSynonymFinder(SynonymFinder): 20 | 21 | def __call__(self, name: str) -> Iterable[str]: 22 | # TODO 2010.11784.pdf (arxiv.org) 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/wiki_data_synonym_finder.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | from typing import List 3 | 4 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 5 | from SPARQLWrapper import JSON, SPARQLWrapper 6 | 7 | 8 | class WikiDataSynonymFinder(SynonymFinder): 9 | 10 | def __init__( 11 | self, 12 | endpoint_url: str = "https://query.wikidata.org/sparql", 13 | user_agent: str = "factfinder/1.0", 14 | ): 15 | self.__endpoint_url = endpoint_url 16 | self.__user_agent = user_agent 17 | 18 | def __call__(self, name: str) -> List[str]: 19 | query = self.__generate_sparql_forwards_query(name) 20 | results = self.__get_sparql_results(query) 21 | if not results: 22 | query = self.__generate_sparql_backwards_query(name) 23 | results = self.__get_sparql_results(query) 24 | if not results: 25 | return [name] 26 | return results 27 | 28 | def __get_sparql_results(self, query): 29 | ssl._create_default_https_context = ssl._create_unverified_context 30 | sparql = SPARQLWrapper(self.__endpoint_url, agent=self.__user_agent) 31 | sparql.setQuery(query) 32 | sparql.setReturnFormat(JSON) 33 | results = sparql.query().convert() 34 | results = [result["label"]["value"] for result in results["results"]["bindings"]] 35 | return results 36 | 37 | def __generate_sparql_forwards_query(self, name: str) -> str: 38 | query = """SELECT DISTINCT ?label WHERE { 39 | ?s skos:altLabel "%s"@en. 40 | ?s rdfs:label ?label . 41 | FILTER(str(lang(?label)) = "en") 42 | }""" % ( 43 | name 44 | ) 45 | return query 46 | 47 | def __generate_sparql_backwards_query(self, name: str) -> str: 48 | query = """SELECT DISTINCT ?alt_label WHERE { 49 | ?s skos:altLabel ?alt_label . 50 | ?s rdfs:label "%s"@en . 51 | FILTER(str(lang(?label)) = "en") 52 | }""" % ( 53 | name 54 | ) 55 | return query 56 | -------------------------------------------------------------------------------- /src/fact_finder/tools/synonym_finder/word_net_synonym_finder.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import List 3 | 4 | import nltk 5 | from nltk.corpus import wordnet as wn 6 | 7 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 8 | 9 | 10 | class WordNetSynonymFinder(SynonymFinder): 11 | 12 | def __init__(self) -> None: 13 | nltk.download("wordnet") 14 | 15 | def __call__(self, name: str) -> List[str]: 16 | result = list(itertools.chain(*wn.synonyms(word=name))) 17 | if name not in result: 18 | result.append(name) 19 | return result 20 | -------------------------------------------------------------------------------- /src/fact_finder/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/fact_finder/ui/__init__.py -------------------------------------------------------------------------------- /src/fact_finder/ui/graph_conversion.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import sys 3 | from typing import Any, Dict, List 4 | 5 | from pydantic import BaseModel 6 | 7 | from fact_finder.utils import get_triples_from_graph_result, graph_result_contains_triple 8 | 9 | 10 | class Node(BaseModel): 11 | id: int 12 | type: str 13 | name: str 14 | in_query: bool 15 | in_answer: bool 16 | 17 | 18 | class Edge(BaseModel): 19 | id: int 20 | type: str 21 | name: str 22 | source: int 23 | target: int 24 | in_query: bool 25 | in_answer: bool 26 | 27 | 28 | class Subgraph(BaseModel): 29 | nodes: List[Node] 30 | edges: List[Edge] 31 | 32 | 33 | def convert_subgraph(graph: List[Dict[str, Any]], result: List[Dict[str, Any]]) -> (Subgraph, str): 34 | graph_converted = Subgraph(nodes=[], edges=[]) 35 | graph_triplets = "" 36 | 37 | try: 38 | result_ents = [] 39 | for res in result: 40 | result_ents += res.values() 41 | 42 | idx_rel = 0 43 | for entry in graph: 44 | if graph_result_contains_triple(graph_result_entry=entry): 45 | graph_triplet, idx_rel = _process_triples(entry, graph_converted, result_ents, idx_rel) 46 | else: 47 | graph_triplet = _process_nodes_only(entry, graph_converted, result_ents) 48 | 49 | graph_triplets += graph_triplet 50 | 51 | except Exception as e: 52 | print(e) 53 | 54 | return (graph_converted, graph_triplets) 55 | 56 | 57 | def _process_triples(entry, graph_converted: Subgraph, result_ents: list, idx_rel: int) -> int: 58 | triples_as_string = "" 59 | triples = get_triples_from_graph_result(graph_result_entry=entry) 60 | for triple in triples: 61 | triples_as_string += _process_triple( 62 | entry=entry, graph_converted=graph_converted, result_ents=result_ents, idx_rel=idx_rel, triple=triple 63 | ) 64 | idx_rel += 1 65 | return triples_as_string, idx_rel 66 | 67 | 68 | def _process_triple(entry, graph_converted: Subgraph, result_ents: list, idx_rel: int, triple: dict) -> str: 69 | graph_triplet = "" 70 | 71 | head_type = [key for key, value in entry.items() if value == triple[0]] 72 | tail_type = [key for key, value in entry.items() if value == triple[2]] 73 | head_type = head_type[0] if len(head_type) > 0 else "" 74 | tail_type = tail_type[0] if len(tail_type) > 0 else "" 75 | node_head = triple[0] if "index" in triple[0] else list(entry.values())[0] 76 | node_tail = triple[2] if "index" in triple[2] else list(entry.values())[2] 77 | 78 | if "index" in node_head and node_head["index"] not in [node.id for node in graph_converted.nodes]: 79 | graph_converted.nodes.append( 80 | Node( 81 | id=node_head["index"], 82 | type=head_type, 83 | name=node_head["name"], 84 | in_query=False, 85 | in_answer=node_head["name"] in result_ents, 86 | ) 87 | ) 88 | if "index" in node_tail and node_tail["index"] not in [node.id for node in graph_converted.nodes]: 89 | graph_converted.nodes.append( 90 | Node( 91 | id=node_tail["index"], 92 | type=tail_type, 93 | name=node_tail["name"], 94 | in_query=False, 95 | in_answer=node_tail["name"] in result_ents, 96 | ) 97 | ) 98 | if "index" in node_head and "index" in node_tail: 99 | graph_converted.edges.append( 100 | Edge( 101 | id=idx_rel, 102 | type=triple[1], 103 | name=triple[1], 104 | source=node_head["index"], 105 | target=node_tail["index"], 106 | in_query=False, 107 | in_answer=node_tail["name"] in result_ents, 108 | ) 109 | ) 110 | 111 | try: 112 | graph_triplet = f'("{node_head["name"]}", "{triple[1]}", "{node_tail["name"]}"), ' 113 | except Exception as e: 114 | print(e) 115 | 116 | return graph_triplet 117 | 118 | 119 | def _process_nodes_only(entry, graph_converted: Subgraph, result_ents: list) -> None: 120 | graph_triplet = "" 121 | for variable_binding, possible_node in entry.items(): 122 | if not isinstance(possible_node, dict): 123 | continue 124 | if "index" in possible_node and possible_node["index"] not in [node.id for node in graph_converted.nodes]: 125 | graph_converted.nodes.append( 126 | Node( 127 | id=possible_node["index"], 128 | type=variable_binding, 129 | name=possible_node["name"], 130 | in_query=False, 131 | in_answer=possible_node["name"] in result_ents, 132 | ) 133 | ) 134 | try: 135 | graph_triplet += f'("{possible_node["name"]}"), ' 136 | except Exception as e: 137 | print(e) 138 | return graph_triplet 139 | -------------------------------------------------------------------------------- /src/fact_finder/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Any 3 | 4 | from langchain.chains import LLMChain 5 | from langchain_community.graphs import Neo4jGraph 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_openai import AzureChatOpenAI, ChatOpenAI 8 | 9 | 10 | def concatenate_with_headers(answers: List[Dict[str, str]]) -> str: 11 | result = "" 12 | for answer in answers: 13 | for header, text in answer.items(): 14 | result += header + "\n" + text + "\n\n" 15 | return result 16 | 17 | 18 | def build_neo4j_graph() -> Neo4jGraph: 19 | """ 20 | 21 | :rtype: object 22 | """ 23 | NEO4J_URL = os.getenv("NEO4J_URL", "bolt://localhost:7687") 24 | NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") 25 | NEO4J_PW = os.getenv("NEO4J_PW", "opensesame") 26 | return Neo4jGraph(url=NEO4J_URL, username=NEO4J_USER, password=NEO4J_PW) 27 | 28 | 29 | def get_model_from_env(): 30 | return os.getenv("LLM", "gpt-4o") 31 | 32 | 33 | def load_chat_model() -> BaseChatModel: 34 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 35 | assert OPENAI_API_KEY is not None, "An OpenAI API key has to be set as environment variable OPENAI_API_KEY." 36 | if os.getenv("AZURE_OPENAI_ENDPOINT") is not None: 37 | endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") 38 | assert endpoint is not None 39 | deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME") 40 | api_version = "2023-05-15" 41 | os.environ["AZURE_OPENAI_API_KEY"] = OPENAI_API_KEY 42 | os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint 43 | return AzureChatOpenAI(openai_api_version=api_version, azure_deployment=deployment_name) 44 | model = get_model_from_env() 45 | return ChatOpenAI(model=model, streaming=False, temperature=0, api_key=OPENAI_API_KEY) 46 | 47 | 48 | def graph_result_contains_triple(graph_result_entry): 49 | return len(get_triples_from_graph_result(graph_result_entry)) > 0 50 | 51 | 52 | def get_triples_from_graph_result(graph_result_entry) -> List[dict]: 53 | return [value for key, value in graph_result_entry.items() if type(value) is tuple] 54 | 55 | 56 | def fill_prompt_template(llm_chain: LLMChain, inputs: Dict[str, Any]) -> str: 57 | return llm_chain.prep_prompts([inputs])[0][0].text 58 | -------------------------------------------------------------------------------- /src/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/src/img/logo.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/__init__.py -------------------------------------------------------------------------------- /tests/chains/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/chains/__init__.py -------------------------------------------------------------------------------- /tests/chains/graph_qa_chain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/chains/graph_qa_chain/__init__.py -------------------------------------------------------------------------------- /tests/chains/graph_qa_chain/test_graph_qa_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Literal 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from fact_finder.chains.graph_qa_chain.config import GraphQAChainConfig 6 | from fact_finder.chains.graph_qa_chain.graph_qa_chain import GraphQAChain 7 | from langchain.chains.graph_qa.cypher import construct_schema 8 | from langchain_community.graphs import Neo4jGraph 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_core.outputs import Generation, LLMResult 11 | from langchain_core.prompt_values import PromptValue 12 | from langchain_core.prompts.prompt import PromptTemplate 13 | 14 | 15 | def test_cypher_generation_is_called_with_expected_arguments( 16 | question: str, chain: GraphQAChain, llm: BaseLanguageModel, cypher_prompt: str 17 | ): 18 | chain.invoke({"question": question}) 19 | all_params = [p.args[0][0] for p in llm.generate_prompt.mock_calls] 20 | assert cypher_prompt in all_params 21 | 22 | 23 | def test_graph_is_called_with_expected_cypher_query( 24 | question: str, chain: GraphQAChain, graph: Neo4jGraph, cypher_query: str 25 | ): 26 | chain.invoke({"question": question}) 27 | all_params = [p.args[0] for p in graph.query.mock_calls] 28 | assert cypher_query in all_params 29 | 30 | 31 | def test_qa_is_called_with_expected_arguments( 32 | question: str, chain: GraphQAChain, llm: BaseLanguageModel, answer_prompt: str 33 | ): 34 | chain.invoke({"question": question}) 35 | all_params = [p.args[0][0] for p in llm.generate_prompt.mock_calls] 36 | assert answer_prompt in all_params 37 | 38 | 39 | def test_returned_result_matches_model_output(question: str, chain: GraphQAChain, system_answer: str): 40 | assert chain.invoke({"question": question})["graph_qa_output"].answer == system_answer 41 | 42 | 43 | @pytest.mark.parametrize("cypher_query", ["SCHEMA_ERROR: This is not a cypher query!"], indirect=True) 44 | def test_invalid_cypher_query_is_returned_directly( 45 | question: str, chain: GraphQAChain, cypher_query: Literal["SCHEMA_ERROR: This is not a cypher query!"] 46 | ): 47 | assert chain.invoke({"question": question})["graph_qa_output"].answer == cypher_query[len("SCHEMA_ERROR: ") :] 48 | 49 | 50 | @pytest.fixture 51 | def chain( 52 | cypher_prompt_template: PromptTemplate, 53 | answer_generation_prompt_template: PromptTemplate, 54 | llm: BaseLanguageModel, 55 | graph: Neo4jGraph, 56 | ): 57 | config = GraphQAChainConfig.construct( 58 | llm=llm, 59 | graph=graph, 60 | cypher_prompt=cypher_prompt_template, 61 | answer_generation_prompt=answer_generation_prompt_template, 62 | cypher_query_preprocessors=[], 63 | return_intermediate_steps=True, 64 | ) 65 | return GraphQAChain(config) 66 | 67 | 68 | @pytest.fixture 69 | def cypher_prompt(cypher_prompt_template: PromptTemplate, schema: str, question: str) -> str: 70 | return cypher_prompt_template.format_prompt(schema=schema, question=question) 71 | 72 | 73 | @pytest.fixture 74 | def cypher_prompt_template() -> PromptTemplate: 75 | return PromptTemplate( 76 | input_variables=["schema", "question"], template="Generate some cypher with schema:\n{schema}\nFor {question}:" 77 | ) 78 | 79 | 80 | @pytest.fixture 81 | def answer_prompt( 82 | answer_generation_prompt_template: PromptTemplate, query_response: List[Dict[str, str]], question: str 83 | ) -> str: 84 | return answer_generation_prompt_template.format_prompt(context=query_response, question=question) 85 | 86 | 87 | @pytest.fixture 88 | def answer_generation_prompt_template() -> PromptTemplate: 89 | return PromptTemplate( 90 | input_variables=["context", "question"], template="Generate an answer to {question} using {context}:" 91 | ) 92 | 93 | 94 | @pytest.fixture 95 | def question() -> str: 96 | return "" 97 | 98 | 99 | @pytest.fixture 100 | def llm(cypher_query: str, system_answer: str) -> BaseLanguageModel: 101 | def llm_side_effect(prompts: List[PromptValue], *args, **kwargs) -> LLMResult: 102 | if "cypher" in prompts[0].to_string().lower(): 103 | text = cypher_query 104 | else: 105 | text = system_answer 106 | return LLMResult(generations=[[Generation(text=text)]]) 107 | 108 | llm = MagicMock(spec=BaseLanguageModel) 109 | llm.generate_prompt = MagicMock() 110 | llm.generate_prompt.side_effect = llm_side_effect 111 | return llm 112 | 113 | 114 | @pytest.fixture(params=[""]) 115 | def cypher_query(request: pytest.FixtureRequest) -> str: 116 | return request.param 117 | 118 | 119 | @pytest.fixture 120 | def system_answer() -> str: 121 | return "" 122 | 123 | 124 | @pytest.fixture 125 | def graph(structured_schema: Dict[str, Any], query_response: List[str]) -> Neo4jGraph: 126 | graph = MagicMock(spec=Neo4jGraph) 127 | graph.get_structured_schema = structured_schema 128 | graph.query = MagicMock() 129 | graph.query.return_value = query_response 130 | return graph 131 | 132 | 133 | @pytest.fixture 134 | def schema(structured_schema: Dict[str, Any]) -> str: 135 | return construct_schema(structured_schema, [], []) 136 | 137 | 138 | @pytest.fixture 139 | def structured_schema() -> Dict[str, Any]: 140 | return { 141 | "node_props": { 142 | "disease": [ 143 | {"property": "id", "type": "STRING"}, 144 | {"property": "name", "type": "STRING"}, 145 | {"property": "source", "type": "STRING"}, 146 | {"property": "index", "type": "INTEGER"}, 147 | ], 148 | "anatomy": [ 149 | {"property": "id", "type": "STRING"}, 150 | {"property": "name", "type": "STRING"}, 151 | {"property": "source", "type": "STRING"}, 152 | {"property": "index", "type": "INTEGER"}, 153 | ], 154 | }, 155 | "rel_props": {}, 156 | "relationships": [ 157 | {"start": "disease", "type": "associated_with", "end": "gene_or_protein"}, 158 | {"start": "disease", "type": "phenotype_present", "end": "effect_or_phenotype"}, 159 | {"start": "disease", "type": "phenotype_absent", "end": "effect_or_phenotype"}, 160 | {"start": "disease", "type": "parent_child", "end": "disease"}, 161 | ], 162 | } 163 | 164 | 165 | @pytest.fixture 166 | def query_response() -> List[Dict[str, str]]: 167 | return [{"node": ""}] 168 | -------------------------------------------------------------------------------- /tests/chains/graph_qa_chain/test_graph_qa_chain_e2e.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | from dotenv import dotenv_values 7 | from fact_finder.chains.graph_qa_chain.graph_qa_chain import GraphQAChain 8 | from fact_finder.prompt_templates import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT 9 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 10 | CypherQueryPreprocessor, 11 | ) 12 | from fact_finder.tools.cypher_preprocessors.format_preprocessor import ( 13 | FormatPreprocessor, 14 | ) 15 | from fact_finder.tools.cypher_preprocessors.lower_case_properties_cypher_query_preprocessor import ( 16 | LowerCasePropertiesCypherQueryPreprocessor, 17 | ) 18 | from langchain_community.graphs import Neo4jGraph 19 | from langchain_openai import ChatOpenAI 20 | 21 | 22 | @pytest.mark.skip(reason="end to end") 23 | @patch.dict(os.environ, {**dotenv_values(), **os.environ}) 24 | def test_e2e(e2e_chain: GraphQAChain): 25 | questions = [ 26 | "Which drugs are associated with epilepsy?", 27 | "Which drugs are associated with schizophrenia?", 28 | "Which medication has the most indications?", 29 | "What are the phenotypes associated with cardioacrofacial dysplasia?", 30 | ] 31 | for question in questions: 32 | result = run_e2e_chain(e2e_chain=e2e_chain, question=question) 33 | assert len(result) == 3 34 | assert e2e_chain.output_key in result.keys() and e2e_chain.intermediate_steps_key in result.keys() 35 | 36 | 37 | def run_e2e_chain(e2e_chain: GraphQAChain, question: str) -> Dict[str, Any]: 38 | message = {e2e_chain.input_key: question} 39 | return e2e_chain.invoke(input=message) 40 | 41 | 42 | @pytest.fixture(scope="module") 43 | def e2e_chain(model_e2e: ChatOpenAI, graph_e2e: Neo4jGraph, preprocessors_e2e: list) -> GraphQAChain: 44 | return GraphQAChain( 45 | llm=model_e2e, 46 | graph=graph_e2e, 47 | cypher_prompt=CYPHER_GENERATION_PROMPT, 48 | answer_generation_prompt=CYPHER_QA_PROMPT, 49 | cypher_query_preprocessors=preprocessors_e2e, 50 | return_intermediate_steps=True, 51 | ) 52 | 53 | 54 | @pytest.fixture(scope="module") 55 | def preprocessors_e2e() -> List[CypherQueryPreprocessor]: 56 | cypher_query_formatting_preprocessor = FormatPreprocessor() 57 | lower_case_preprocessor = LowerCasePropertiesCypherQueryPreprocessor() 58 | return [cypher_query_formatting_preprocessor, lower_case_preprocessor] 59 | 60 | 61 | @pytest.fixture(scope="module") 62 | def graph_e2e(neo4j_url: str, neo4j_user: str, neo4j_pw: str) -> Neo4jGraph: 63 | return Neo4jGraph(url=neo4j_url, username=neo4j_user, password=neo4j_pw) 64 | 65 | 66 | @pytest.fixture(scope="module") 67 | def model_e2e(open_ai_key: str) -> ChatOpenAI: 68 | return ChatOpenAI(model="gpt-4", streaming=False, temperature=0, api_key=open_ai_key) 69 | 70 | 71 | @pytest.fixture(scope="module") 72 | def neo4j_url() -> str: 73 | return os.getenv("NEO4J_URL", "bolt://localhost:7687") 74 | 75 | 76 | @pytest.fixture(scope="module") 77 | def neo4j_user() -> str: 78 | return os.getenv("NEO4J_USER", "neo4j") 79 | 80 | 81 | @pytest.fixture(scope="module") 82 | def neo4j_pw() -> str: 83 | return os.getenv("NEO4J_PW", "opensesame") 84 | 85 | 86 | @pytest.fixture(scope="module") 87 | def open_ai_key() -> str: 88 | return os.getenv("OPENAI_API_KEY") 89 | -------------------------------------------------------------------------------- /tests/chains/helpers.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | from langchain_core.language_models import BaseLanguageModel 4 | from langchain_core.outputs import Generation, LLMResult 5 | 6 | 7 | def build_llm_mock(output: str) -> BaseLanguageModel: 8 | llm = MagicMock(spec=BaseLanguageModel) 9 | llm.generate_prompt = MagicMock() 10 | llm.generate_prompt.return_value = LLMResult(generations=[[Generation(text=output)]]) 11 | return llm 12 | -------------------------------------------------------------------------------- /tests/chains/test_answer_generation_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from unittest.mock import ANY, MagicMock 3 | 4 | import pytest 5 | from fact_finder.chains.answer_generation_chain import AnswerGenerationChain 6 | from langchain_core.language_models import BaseChatModel, BaseLanguageModel 7 | from langchain_core.prompts import PromptTemplate 8 | from langchain_core.prompts.string import StringPromptValue 9 | from tests.chains.helpers import build_llm_mock 10 | 11 | 12 | def test_output_key_in_result( 13 | answer_generation_chain: AnswerGenerationChain, 14 | output_from_graph_chain: Dict[str, Any], 15 | ): 16 | result = answer_generation_chain.invoke(input=output_from_graph_chain) 17 | assert answer_generation_chain.output_key in result.keys() 18 | 19 | 20 | def test_returned_result_matches_model_output( 21 | answer_generation_chain: AnswerGenerationChain, output_from_graph_chain: Dict[str, Any], llm_answer: str 22 | ): 23 | result = answer_generation_chain.invoke(input=output_from_graph_chain) 24 | assert result[answer_generation_chain.output_key] == llm_answer 25 | 26 | 27 | def test_answer_generation_llm_is_called_with_expected_arguments( 28 | answer_generation_chain: AnswerGenerationChain, 29 | llm: BaseLanguageModel, 30 | generation_prompt: str, 31 | output_from_graph_chain: Dict[str, Any], 32 | ): 33 | answer_generation_chain.invoke(input=output_from_graph_chain) 34 | llm.generate_prompt.assert_called_once_with([generation_prompt], None, callbacks=ANY) 35 | 36 | 37 | @pytest.fixture 38 | def generation_prompt( 39 | prompt_template: PromptTemplate, graph_result: List[Dict[str, str]], question: str 40 | ) -> StringPromptValue: 41 | return prompt_template.format_prompt(context=graph_result, question=question) 42 | 43 | 44 | @pytest.fixture 45 | def answer_generation_chain(llm: BaseChatModel, prompt_template: PromptTemplate) -> AnswerGenerationChain: 46 | answer_generation_chain = AnswerGenerationChain(llm=llm, prompt_template=prompt_template) 47 | return answer_generation_chain 48 | 49 | 50 | @pytest.fixture 51 | def prompt_template() -> PromptTemplate: 52 | return PromptTemplate( 53 | input_variables=["context", "question"], template="Generate an answer to {question} using {context}:" 54 | ) 55 | 56 | 57 | @pytest.fixture 58 | def llm(llm_answer) -> BaseLanguageModel: 59 | return build_llm_mock(llm_answer) 60 | 61 | 62 | @pytest.fixture 63 | def llm_answer() -> str: 64 | return ( 65 | "The drugs associated with epilepsy are phenytoin, valproic acid, " 66 | "lamotrigine, diazepam, clonazepam, fosphenytoin, mephenytoin, " 67 | "neocitrullamon, carbamazepine, phenobarbital, secobarbital, " 68 | "primidone, and lorazepam." 69 | ) 70 | 71 | 72 | @pytest.fixture 73 | def prompt() -> PromptTemplate: 74 | return MagicMock(spec=PromptTemplate) 75 | 76 | 77 | @pytest.fixture 78 | def output_from_graph_chain(question: str, graph_result: List[Dict[str, str]]) -> Dict[str, Any]: 79 | return { 80 | "question": question, 81 | "cypher_query": "MATCH (d:drug)-[:indication]->(dis:disease) WHERE dis.name = 'epilepsy' RETURN d.name", 82 | "intermediate_steps": [ 83 | {"question": "Which drugs are associated with epilepsy?"}, 84 | { 85 | "FormatPreprocessor": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 86 | "LowerCasePropertiesCypherQueryPreprocessor": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 87 | }, 88 | {"graph_result": graph_result}, 89 | ], 90 | "preprocessed_cypher_query": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 91 | "graph_result": graph_result, 92 | } 93 | 94 | 95 | @pytest.fixture 96 | def question() -> str: 97 | return "Which drugs are associated with epilepsy?" 98 | 99 | 100 | @pytest.fixture 101 | def graph_result() -> List[Dict[str, str]]: 102 | return [ 103 | {"d.name": "phenytoin"}, 104 | {"d.name": "phenytoin"}, 105 | {"d.name": "phenytoin"}, 106 | {"d.name": "valproic acid"}, 107 | {"d.name": "lamotrigine"}, 108 | {"d.name": "lamotrigine"}, 109 | {"d.name": "diazepam"}, 110 | {"d.name": "clonazepam"}, 111 | {"d.name": "fosphenytoin"}, 112 | {"d.name": "mephenytoin"}, 113 | {"d.name": "mephenytoin"}, 114 | {"d.name": "neocitrullamon"}, 115 | {"d.name": "carbamazepine"}, 116 | {"d.name": "carbamazepine"}, 117 | {"d.name": "phenobarbital"}, 118 | {"d.name": "phenobarbital"}, 119 | {"d.name": "secobarbital"}, 120 | {"d.name": "primidone"}, 121 | {"d.name": "primidone"}, 122 | {"d.name": "lorazepam"}, 123 | ] 124 | -------------------------------------------------------------------------------- /tests/chains/test_entity_detection_question_preprocessing_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from fact_finder.chains.entity_detection_question_preprocessing_chain import ( 6 | EntityDetectionQuestionPreprocessingChain, 7 | ) 8 | from fact_finder.tools.entity_detector import EntityDetector 9 | 10 | _NON_OVERLAPPING_ENTITIES: List[Dict[str, int | str]] = [ 11 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 12 | {"start_span": 4, "end_span": 6, "pref_term": "pref_name2", "sem_type": "category2"}, 13 | {"start_span": 6, "end_span": 8, "pref_term": "pref_name3", "sem_type": "category1"}, 14 | ] 15 | 16 | _ILLEGAL_CATEGORIES_ENTITIES: List[Dict[str, int | str]] = [ 17 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 18 | {"start_span": 4, "end_span": 6, "pref_term": "pref_name2", "sem_type": "category4"}, 19 | {"start_span": 6, "end_span": 8, "pref_term": "pref_name3", "sem_type": "category5"}, 20 | ] 21 | 22 | _PARTIAL_OVERLAPPING_ENTITIES: List[Dict[str, int | str]] = [ 23 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 24 | {"start_span": 4, "end_span": 6, "pref_term": "pref_name2", "sem_type": "category1"}, 25 | {"start_span": 5, "end_span": 8, "pref_term": "pref_name3", "sem_type": "category1"}, 26 | ] 27 | 28 | _SUBSET_OVERLAP_ENTITIES: Tuple[List[Dict[str, int | str]], List[Dict[str, int | str]]] = ( 29 | [ 30 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 31 | {"start_span": 4, "end_span": 8, "pref_term": "pref_name2", "sem_type": "category1"}, 32 | {"start_span": 5, "end_span": 8, "pref_term": "pref_name3", "sem_type": "category1"}, 33 | ], 34 | [ 35 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 36 | {"start_span": 4, "end_span": 7, "pref_term": "pref_name3", "sem_type": "category1"}, 37 | {"start_span": 4, "end_span": 8, "pref_term": "pref_name2", "sem_type": "category1"}, 38 | ], 39 | ) 40 | 41 | 42 | def test_produces_expected_output_key(inputs: Dict[str, str], chain: EntityDetectionQuestionPreprocessingChain): 43 | res = chain.invoke(inputs) 44 | assert all(k in res for k in chain.output_keys) 45 | 46 | 47 | def test_result_contains_entity_replacements( 48 | inputs: Dict[str, str], chain: EntityDetectionQuestionPreprocessingChain, entities: List[Dict[str, int | str]] 49 | ): 50 | res = chain.invoke(inputs) 51 | assert all(e["pref_term"] in res[chain.output_keys[0]] for e in entities) 52 | 53 | 54 | def test_result_contains_entity_hints( 55 | inputs: Dict[str, str], 56 | chain: EntityDetectionQuestionPreprocessingChain, 57 | entities: List[Dict[str, int | str]], 58 | allowed_categories: Dict[str, str], 59 | ): 60 | res = chain.invoke(inputs) 61 | hints = [allowed_categories[e["sem_type"]].replace("{entity}", e["pref_term"]).capitalize() for e in entities] 62 | assert all(h in res[chain.output_keys[0]] for h in hints) 63 | 64 | 65 | @pytest.mark.parametrize("entities", (_ILLEGAL_CATEGORIES_ENTITIES,), indirect=True) 66 | def test_result_only_contains_entities_from_allowed_categories( 67 | inputs: Dict[str, str], 68 | chain: EntityDetectionQuestionPreprocessingChain, 69 | entities: List[Dict[str, int | str]], 70 | allowed_categories: Dict[str, str], 71 | ): 72 | res = chain.invoke(inputs) 73 | assert all((e["pref_term"] in res[chain.output_keys[0]]) == (e["sem_type"] in allowed_categories) for e in entities) 74 | 75 | 76 | @pytest.mark.parametrize("entities", (_PARTIAL_OVERLAPPING_ENTITIES,), indirect=True) 77 | def test_result_does_not_contain_replacements_for_overlapping_entities( 78 | inputs: Dict[str, str], chain: EntityDetectionQuestionPreprocessingChain 79 | ): 80 | res = chain.invoke(inputs) 81 | assert "pref_name2" not in res[chain.output_keys[0]] and "pref_name3" not in res[chain.output_keys[0]] 82 | 83 | 84 | @pytest.mark.parametrize("entities", _SUBSET_OVERLAP_ENTITIES, indirect=True) 85 | def test_result_does_not_contain_entities_that_are_contained_in_a_larger_entity( 86 | inputs: Dict[str, str], chain: EntityDetectionQuestionPreprocessingChain 87 | ): 88 | res = chain.invoke(inputs) 89 | assert "pref_name3" not in res[chain.output_keys[0]] 90 | 91 | 92 | @pytest.mark.parametrize("entities", _SUBSET_OVERLAP_ENTITIES, indirect=True) 93 | def test_result_contains_entity_that_contains_smaller_entity( 94 | inputs: Dict[str, str], chain: EntityDetectionQuestionPreprocessingChain 95 | ): 96 | res = chain.invoke(inputs) 97 | assert "pref_name2" in res[chain.output_keys[0]] 98 | 99 | 100 | @pytest.fixture 101 | def inputs(chain: EntityDetectionQuestionPreprocessingChain) -> Dict[str, str]: 102 | return {chain.input_keys[0]: "e1 e2e3 bla?"} 103 | 104 | 105 | @pytest.fixture 106 | def chain( 107 | entity_detector: EntityDetector, allowed_categories: Dict[str, str] 108 | ) -> EntityDetectionQuestionPreprocessingChain: 109 | return EntityDetectionQuestionPreprocessingChain( 110 | entity_detector=entity_detector, allowed_types_and_description_templates=allowed_categories 111 | ) 112 | 113 | 114 | @pytest.fixture 115 | def allowed_categories() -> Dict[str, str]: 116 | return { 117 | "category1": "{entity} is in category1.", 118 | "category2": "{entity} is in category2.", 119 | "category3": "{entity} is in category3.", 120 | } 121 | 122 | 123 | @pytest.fixture 124 | def entity_detector(entities: Dict[str, int | str]) -> EntityDetector: 125 | det = MagicMock(spec=EntityDetector) 126 | det.return_value = entities 127 | return det 128 | 129 | 130 | @pytest.fixture 131 | def entities(request: pytest.FixtureRequest) -> List[Dict[str, int | str]]: 132 | if not hasattr(request, "param") or request.param is None: 133 | return _NON_OVERLAPPING_ENTITIES 134 | return request.param 135 | -------------------------------------------------------------------------------- /tests/chains/test_filtered_primekg_question_preprocessing_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from fact_finder.chains.filtered_primekg_question_preprocessing_chain import ( 6 | FilteredPrimeKGQuestionPreprocessingChain, 7 | ) 8 | from fact_finder.tools.entity_detector import EntityDetector 9 | from langchain_community.graphs import Neo4jGraph 10 | from tests.chains.test_entity_detection_question_preprocessing_chain import ( 11 | _ILLEGAL_CATEGORIES_ENTITIES, 12 | _NON_OVERLAPPING_ENTITIES, 13 | _PARTIAL_OVERLAPPING_ENTITIES, 14 | _SUBSET_OVERLAP_ENTITIES, 15 | ) 16 | 17 | _DISEASE_ENTITY: List[Dict[str, int | str]] = [ 18 | {"start_span": 0, "end_span": 2, "pref_term": "pref_name1", "sem_type": "category1"}, 19 | {"start_span": 4, "end_span": 6, "pref_term": "disease_pref_name", "sem_type": "Disease"}, 20 | {"start_span": 6, "end_span": 8, "pref_term": "pref_name3", "sem_type": "category1"}, 21 | ] 22 | 23 | 24 | def test_produces_expected_output_key(inputs: Dict[str, str], chain: FilteredPrimeKGQuestionPreprocessingChain): 25 | res = chain.invoke(inputs) 26 | assert all(k in res for k in chain.output_keys) 27 | 28 | 29 | def test_result_contains_entity_replacements(chain_result: str, entities: List[Dict[str, int | str]]): 30 | assert all(e["pref_term"] in chain_result for e in entities) 31 | 32 | 33 | def test_result_contains_entity_hints( 34 | chain_result: str, entities: List[Dict[str, int | str]], allowed_categories: Dict[str, str] 35 | ): 36 | hints = [allowed_categories[e["sem_type"]].replace("{entity}", e["pref_term"]).capitalize() for e in entities] 37 | assert all(h in chain_result for h in hints) 38 | 39 | 40 | @pytest.mark.parametrize("entities", (_ILLEGAL_CATEGORIES_ENTITIES,), indirect=True) 41 | def test_result_only_contains_entities_from_allowed_categories( 42 | chain_result: str, entities: List[Dict[str, int | str]], allowed_categories: Dict[str, str] 43 | ): 44 | assert all((e["pref_term"] in chain_result) == (e["sem_type"] in allowed_categories) for e in entities) 45 | 46 | 47 | @pytest.mark.parametrize("entities", (_PARTIAL_OVERLAPPING_ENTITIES,), indirect=True) 48 | def test_result_does_not_contain_replacements_for_overlapping_entities(chain_result: str): 49 | assert "pref_name2" not in chain_result and "pref_name3" not in chain_result 50 | 51 | 52 | @pytest.mark.parametrize("entities", _SUBSET_OVERLAP_ENTITIES, indirect=True) 53 | def test_result_does_not_contain_entities_that_are_contained_in_a_larger_entity(chain_result: str): 54 | assert "pref_name3" not in chain_result 55 | 56 | 57 | @pytest.mark.parametrize("entities", _SUBSET_OVERLAP_ENTITIES, indirect=True) 58 | def test_result_contains_entity_that_contains_smaller_entity(chain_result: str): 59 | assert "pref_name2" in chain_result 60 | 61 | 62 | @pytest.mark.parametrize("entities", (_DISEASE_ENTITY,), indirect=True) 63 | def test_disease_replaced_normally_if_no_matching_side_effect_found(chain_result: str): 64 | assert "disease_pref_name" in chain_result 65 | 66 | 67 | @pytest.mark.parametrize("entities", (_DISEASE_ENTITY,), indirect=True) 68 | def test_normal_disease_entity_hint_if_no_matching_side_effect_found(chain_result: str): 69 | assert "Disease_pref_name is a disease." in chain_result 70 | 71 | 72 | @pytest.mark.parametrize("entities,graph_result", [(_DISEASE_ENTITY, True)], indirect=True) 73 | def test_disease_not_replaced_if_matching_side_effect_found_for_original_name(chain_result: str): 74 | assert "disease_pref_name" not in chain_result and "e2" in chain_result 75 | 76 | 77 | @pytest.mark.parametrize("entities,graph_result", [(_DISEASE_ENTITY, True)], indirect=True) 78 | def test_correct_entity_hint_if_matching_side_effect_found_for_original_name(chain_result: str): 79 | assert "E2 is a disease or a effect_or_phenotype." in chain_result 80 | 81 | 82 | @pytest.mark.parametrize("entities,graph_result", [(_DISEASE_ENTITY, (False, False, True, False))], indirect=True) 83 | def test_disease_replaced_if_matching_side_effect_found_for_preferred_name(chain_result: str): 84 | assert "disease_pref_name" in chain_result and "e2" not in chain_result 85 | 86 | 87 | @pytest.mark.parametrize("entities,graph_result", [(_DISEASE_ENTITY, (False, False, True, False))], indirect=True) 88 | def test_correct_entity_hint_if_matching_side_effect_found_for_preferred_name(chain_result: str): 89 | assert "Disease_pref_name is a disease or a effect_or_phenotype." in chain_result 90 | 91 | 92 | @pytest.fixture 93 | def chain_result(inputs: Dict[str, str], chain: FilteredPrimeKGQuestionPreprocessingChain) -> str: 94 | return chain.invoke(inputs)[chain.output_keys[0]] 95 | 96 | 97 | @pytest.fixture 98 | def inputs(chain: FilteredPrimeKGQuestionPreprocessingChain) -> Dict[str, str]: 99 | return {chain.input_keys[0]: "e1 e2e3 bla?"} 100 | 101 | 102 | @pytest.fixture 103 | def chain( 104 | entity_detector: EntityDetector, allowed_categories: Dict[str, str], graph 105 | ) -> FilteredPrimeKGQuestionPreprocessingChain: 106 | return FilteredPrimeKGQuestionPreprocessingChain( 107 | entity_detector=entity_detector, allowed_types_and_description_templates=allowed_categories, graph=graph 108 | ) 109 | 110 | 111 | @pytest.fixture 112 | def allowed_categories() -> Dict[str, str]: 113 | return { 114 | "category1": "{entity} is in category1.", 115 | "category2": "{entity} is in category2.", 116 | "category3": "{entity} is in category3.", 117 | "disease": "{entity} is a disease.", 118 | } 119 | 120 | 121 | @pytest.fixture 122 | def graph(graph_result: List[Dict[str, bool]] | List[List[Dict[str, bool]]]) -> Neo4jGraph: 123 | graph = MagicMock(spec=Neo4jGraph) 124 | graph.query = MagicMock() 125 | if isinstance(graph_result[0], dict): 126 | graph.query.return_value = graph_result 127 | else: 128 | graph.query.side_effect = graph_result 129 | return graph 130 | 131 | 132 | @pytest.fixture 133 | def graph_result(request: pytest.FixtureRequest) -> List[Dict[str, bool]] | List[List[Dict[str, bool]]]: 134 | if not hasattr(request, "param") or request.param is None: 135 | return [{"exists": False}] 136 | if isinstance(request.param, bool): 137 | return [{"exists": request.param}] 138 | return [[{"exists": p}] for p in request.param] 139 | 140 | 141 | @pytest.fixture 142 | def entity_detector(entities: Dict[str, int | str]) -> EntityDetector: 143 | det = MagicMock(spec=EntityDetector) 144 | det.return_value = entities 145 | return det 146 | 147 | 148 | @pytest.fixture 149 | def entities(request: pytest.FixtureRequest) -> List[Dict[str, int | str]]: 150 | if not hasattr(request, "param") or request.param is None: 151 | return _NON_OVERLAPPING_ENTITIES 152 | return request.param 153 | -------------------------------------------------------------------------------- /tests/chains/test_graph_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from langchain_community.graphs import Neo4jGraph 6 | 7 | from fact_finder.chains.graph_chain import GraphChain 8 | 9 | 10 | def test_graph_chain_returns_graph_result( 11 | graph_chain: GraphChain, preprocessors_chain_result: Dict[str, Any], graph_result: List[Dict[str, Any]] 12 | ): 13 | result = graph_chain(inputs=preprocessors_chain_result) 14 | assert result["graph_result"] == graph_result 15 | 16 | 17 | @pytest.mark.parametrize("top_k", (0, 5, 20, 50), indirect=True) 18 | def test_graph_result_length_is_at_most_k( 19 | graph_chain: GraphChain, top_k: int, preprocessors_chain_result: Dict[str, Any], graph_result: List[Dict[str, Any]] 20 | ): 21 | result = graph_chain(inputs=preprocessors_chain_result) 22 | if top_k <= len(graph_result): 23 | assert len(result["graph_result"]) == top_k 24 | else: 25 | assert len(result["graph_result"]) <= top_k 26 | 27 | 28 | def test_graph_result_is_added_to_intermediate_steps( 29 | graph_chain: GraphChain, preprocessors_chain_result: Dict[str, Any] 30 | ): 31 | result = graph_chain(inputs=preprocessors_chain_result) 32 | assert "graph_result" in result["intermediate_steps"][-1].keys() 33 | 34 | 35 | def test_graph_is_called_with_expected_cypher_query(graph_chain: GraphChain, graph: MagicMock): 36 | cypher_query = "" 37 | graph_chain.invoke(input={graph_chain.input_key: cypher_query}) 38 | graph.query.assert_called_once_with(cypher_query) 39 | 40 | 41 | @pytest.fixture 42 | def graph_chain(graph: Neo4jGraph, top_k: int) -> GraphChain: 43 | return GraphChain(graph=graph, top_k=top_k, return_intermediate_steps=True) 44 | 45 | 46 | @pytest.fixture 47 | def top_k(request) -> int: 48 | if hasattr(request, "param") and request.param: 49 | return request.param 50 | return 20 51 | 52 | 53 | @pytest.fixture 54 | def graph(graph_result) -> Neo4jGraph: 55 | graph = MagicMock(spec=Neo4jGraph) 56 | graph.query = MagicMock() 57 | graph.query.return_value = graph_result 58 | return graph 59 | 60 | 61 | @pytest.fixture() 62 | def graph_result() -> List[Dict[str, Any]]: 63 | return [ 64 | {"d.name": "phenytoin"}, 65 | {"d.name": "phenytoin"}, 66 | {"d.name": "phenytoin"}, 67 | {"d.name": "valproic acid"}, 68 | {"d.name": "lamotrigine"}, 69 | {"d.name": "lamotrigine"}, 70 | {"d.name": "diazepam"}, 71 | {"d.name": "clonazepam"}, 72 | {"d.name": "fosphenytoin"}, 73 | {"d.name": "mephenytoin"}, 74 | {"d.name": "mephenytoin"}, 75 | {"d.name": "neocitrullamon"}, 76 | {"d.name": "carbamazepine"}, 77 | {"d.name": "carbamazepine"}, 78 | {"d.name": "phenobarbital"}, 79 | {"d.name": "phenobarbital"}, 80 | {"d.name": "secobarbital"}, 81 | {"d.name": "primidone"}, 82 | {"d.name": "primidone"}, 83 | {"d.name": "lorazepam"}, 84 | ] 85 | 86 | 87 | @pytest.fixture 88 | def preprocessors_chain_result() -> Dict[str, Any]: 89 | return { 90 | "cypher_query": "MATCH (d:drug)-[:indication]->(dis:disease) WHERE dis.name = 'epilepsy' RETURN d.name", 91 | "intermediate_steps": [ 92 | {"question": "Which drugs are associated with epilepsy?"}, 93 | { 94 | "FormatPreprocessor": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 95 | "LowerCasePropertiesCypherQueryPreprocessor": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 96 | }, 97 | ], 98 | "preprocessed_cypher_query": 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name', 99 | } 100 | -------------------------------------------------------------------------------- /tests/chains/test_graph_summary_chain.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fact_finder.chains.graph_summary_chain import GraphSummaryChain 3 | from langchain_core.prompts import PromptTemplate 4 | from tests.chains.helpers import build_llm_mock 5 | 6 | 7 | def test_simple_question(graph_summary_chain: GraphSummaryChain): 8 | answer = graph_summary_chain({"sub_graph": "(psoriasis, is a, disease)"}) 9 | assert answer["summary"].startswith("Psoriasis is a disease") 10 | 11 | 12 | @pytest.fixture 13 | def graph_summary_chain(graph_summary_template: PromptTemplate) -> GraphSummaryChain: 14 | return GraphSummaryChain( 15 | llm=build_llm_mock("Psoriasis is a disease."), graph_summary_template=graph_summary_template 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def graph_summary_template() -> PromptTemplate: 21 | return PromptTemplate( 22 | input_variables=["sub_graph"], 23 | template=""" 24 | Verbalize the given triplets of a subgraph to natural text. Use all triplets for the verbalization. 25 | 26 | Triplets of the subgraph: 27 | {sub_graph} 28 | """, 29 | ) 30 | -------------------------------------------------------------------------------- /tests/chains/test_preprocessors_chain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import pytest 4 | from fact_finder.chains.cypher_query_preprocessors_chain import ( 5 | CypherQueryPreprocessorsChain, 6 | ) 7 | from fact_finder.tools.cypher_preprocessors.cypher_query_preprocessor import ( 8 | CypherQueryPreprocessor, 9 | ) 10 | from fact_finder.tools.cypher_preprocessors.format_preprocessor import ( 11 | FormatPreprocessor, 12 | ) 13 | from fact_finder.tools.cypher_preprocessors.lower_case_properties_cypher_query_preprocessor import ( 14 | LowerCasePropertiesCypherQueryPreprocessor, 15 | ) 16 | 17 | 18 | def test_produces_expected_output_key( 19 | cypher_query_generation_chain_result: Dict[str, Any], preprocessors: List[CypherQueryPreprocessor] 20 | ): 21 | chain = CypherQueryPreprocessorsChain(cypher_query_preprocessors=preprocessors) 22 | result = chain(cypher_query_generation_chain_result) 23 | assert chain.output_key in result.keys() 24 | 25 | 26 | def test_produces_expected_number_of_intermediate_steps( 27 | cypher_query_generation_chain_result: Dict[str, Any], preprocessors: List[CypherQueryPreprocessor] 28 | ): 29 | chain = CypherQueryPreprocessorsChain(cypher_query_preprocessors=preprocessors, return_intermediate_steps=True) 30 | result = chain(cypher_query_generation_chain_result) 31 | assert len(result["intermediate_steps"]) == len(preprocessors) + len( 32 | cypher_query_generation_chain_result["intermediate_steps"] 33 | ) 34 | 35 | 36 | def test_applies_expected_preprocessings( 37 | cypher_query_generation_chain_result: Dict[str, Any], preprocessors: List[CypherQueryPreprocessor] 38 | ): 39 | chain = CypherQueryPreprocessorsChain(cypher_query_preprocessors=preprocessors) 40 | result = chain(cypher_query_generation_chain_result) 41 | assert ( 42 | result["preprocessed_cypher_query"] 43 | == 'MATCH (d:drug)-[:indication]->(dis:disease)\nWHERE dis.name = "epilepsy"\nRETURN d.name' 44 | ) 45 | 46 | 47 | def test_preprocessors_are_called_in_order(cypher_query_generation_chain_result: Dict[str, Any]): 48 | 49 | class CypherQueryPreprocessorMock(CypherQueryPreprocessor): 50 | call_idx = 0 51 | 52 | def __init__(self, expected_idx: int) -> None: 53 | self._expected_idx = expected_idx 54 | 55 | def __call__(self, cypher_query: str) -> str: 56 | assert CypherQueryPreprocessorMock.call_idx == self._expected_idx 57 | CypherQueryPreprocessorMock.call_idx += 1 58 | return cypher_query 59 | 60 | num_preprocs = 3 61 | preprocs = [CypherQueryPreprocessorMock(expected_idx=i) for i in range(num_preprocs)] 62 | chain = CypherQueryPreprocessorsChain(cypher_query_preprocessors=preprocs) 63 | chain(cypher_query_generation_chain_result) 64 | assert CypherQueryPreprocessorMock.call_idx == num_preprocs 65 | 66 | 67 | @pytest.fixture 68 | def cypher_query_generation_chain_result() -> Dict[str, Any]: 69 | return { 70 | "cypher_query": "MATCH (d:drug)-[:indication]->(dis:disease) WHERE dis.name = 'epilepsy' RETURN d.name", 71 | "intermediate_steps": [{"question": "Which drugs are associated with epilepsy?"}], 72 | } 73 | 74 | 75 | @pytest.fixture 76 | def preprocessors() -> List[CypherQueryPreprocessor]: 77 | return [FormatPreprocessor(), LowerCasePropertiesCypherQueryPreprocessor()] 78 | -------------------------------------------------------------------------------- /tests/evaluator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/evaluator/__init__.py -------------------------------------------------------------------------------- /tests/evaluator/score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/evaluator/score/__init__.py -------------------------------------------------------------------------------- /tests/evaluator/score/test_bleu_score.py: -------------------------------------------------------------------------------- 1 | from nltk import word_tokenize 2 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 3 | 4 | from fact_finder.evaluator.score.bleu_score import BleuScore 5 | 6 | 7 | def test_compare_different_sentences(): 8 | score = BleuScore() 9 | text_a = "This is a test sentence." 10 | text_b = "This is another test sentence." 11 | smoothing_function = SmoothingFunction().method1 12 | expected_score = sentence_bleu( 13 | [word_tokenize(text_a)], word_tokenize(text_b), smoothing_function=smoothing_function 14 | ) 15 | assert score.compare(text_a, text_b) == expected_score 16 | 17 | 18 | def test_compare_same_sentences(): 19 | score = BleuScore() 20 | text_a = "This is a test sentence." 21 | text_b = "This is a test sentence." 22 | smoothing_function = SmoothingFunction().method1 23 | expected_score = sentence_bleu( 24 | [word_tokenize(text_a)], word_tokenize(text_b), smoothing_function=smoothing_function 25 | ) 26 | score = score.compare(text_a, text_b) 27 | assert score == expected_score 28 | assert score == 1.0 29 | -------------------------------------------------------------------------------- /tests/evaluator/score/test_difflib_score.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fact_finder.evaluator.score.difflib_score import DifflibScore 4 | 5 | 6 | @pytest.fixture 7 | def scorer(): 8 | return DifflibScore() 9 | 10 | 11 | def test_exact_match(scorer): 12 | text_a = "example" 13 | text_b = "example" 14 | score = scorer.compare(text_a, text_b) 15 | assert score == 1.0 16 | 17 | 18 | def test_no_match(scorer): 19 | text_a = "example" 20 | text_b = "different" 21 | score = scorer.compare(text_a, text_b) 22 | assert score == 0.25 23 | 24 | 25 | def test_partial_match(scorer): 26 | text_a = "example" 27 | text_b = "sample" 28 | score = scorer.compare(text_a, text_b) 29 | assert score == 0.7692307692307693 30 | 31 | 32 | def test_case_sensitivity(scorer): 33 | text_a = "Example" 34 | text_b = "example" 35 | score = scorer.compare(text_a, text_b) 36 | assert score == 0.8571428571428571 37 | -------------------------------------------------------------------------------- /tests/evaluator/score/test_embedding_score.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fact_finder.evaluator.score.embedding_score import EmbeddingScore 3 | 4 | # todo mock the model behaviour 5 | 6 | 7 | @pytest.fixture 8 | def scorer(): 9 | return EmbeddingScore() 10 | 11 | 12 | def test_exact_match(scorer): 13 | text_a = "This is a test." 14 | text_b = "This is a test." 15 | score = scorer.compare(text_a, text_b) 16 | assert score == pytest.approx(1.0) 17 | 18 | 19 | def test_no_match(scorer): 20 | text_a = "This is a test." 21 | text_b = "Hello" 22 | score = scorer.compare(text_a, text_b) 23 | assert score < 0.2 24 | 25 | 26 | def test_partial_match(scorer): 27 | text_a = "This is a test." 28 | text_b = "This is a different test." 29 | score = scorer.compare(text_a, text_b) 30 | assert score > 0.0 31 | assert score < 1.0 32 | 33 | 34 | def test_similarity(scorer): 35 | text_a = "I had a great vacation." 36 | text_b = "My holiday was fantastic." 37 | score = scorer.compare(text_a, text_b) 38 | assert score > 0.5 39 | -------------------------------------------------------------------------------- /tests/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrschy/fact-finder/ca57d1b9e3683d026ae52c3662b016781d26cc97/tests/tools/__init__.py -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_always_distinct_preprocessor.py: -------------------------------------------------------------------------------- 1 | from fact_finder.tools.cypher_preprocessors.always_distinct_preprocessor import ( 2 | AlwaysDistinctCypherQueryPreprocessor, 3 | ) 4 | 5 | 6 | def test_adds_distinct_keyword(): 7 | preproc = AlwaysDistinctCypherQueryPreprocessor() 8 | query1 = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 9 | query2 = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN DISTINCT d.name' 10 | processed_query = preproc(query1) 11 | assert processed_query == query2 12 | 13 | 14 | def test_does_nothing_if_distinct_keyword_already_present(): 15 | preproc = AlwaysDistinctCypherQueryPreprocessor() 16 | query1 = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN DISTINCT d.name' 17 | processed_query = preproc(query1) 18 | assert processed_query == query1 19 | -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_child_to_parent_preprocessor.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | from fact_finder.tools.cypher_preprocessors.child_to_parent_preprocessor import ( 5 | ChildToParentPreprocessor, 6 | ) 7 | from langchain_community.graphs import Neo4jGraph 8 | 9 | 10 | def test_replaces_property_name_in_node(preprocessor: ChildToParentPreprocessor): 11 | query1 = 'MATCH (e:exposure {name: "child1"})-[:linked_to]->(d:disease) RETURN d.name' 12 | query2 = 'MATCH (e:exposure {name: "parent1"})-[:linked_to]->(d:disease) RETURN d.name' 13 | processed_query = preprocessor(query1) 14 | assert processed_query == query2 15 | 16 | 17 | def test_replaces_property_name_in_where_clause(preprocessor: ChildToParentPreprocessor): 18 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "child2" RETURN d.name' 19 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "parent2" RETURN d.name' 20 | processed_query = preprocessor(query1) 21 | assert processed_query == query2 22 | 23 | 24 | def test_replaces_property_name_in_multi_element_list_in_where_clause(preprocessor: ChildToParentPreprocessor): 25 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["child3", "child1", "child1"] RETURN d' 26 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["parent3", "parent1", "parent1"] RETURN d' 27 | processed_query = preprocessor(query1) 28 | assert processed_query == query2 29 | 30 | 31 | @pytest.fixture 32 | def preprocessor() -> ChildToParentPreprocessor: 33 | graph = MagicMock(spec=Neo4jGraph) 34 | graph.query = MagicMock() 35 | graph.query.return_value = [ 36 | {"child.name": "child1", "parent.name": "parent1"}, 37 | {"child.name": "child2", "parent.name": "parent2"}, 38 | {"child.name": "child3", "parent.name": "parent3"}, 39 | ] 40 | return ChildToParentPreprocessor(graph, "parent_child", name_property="name") 41 | -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_format_preprocessor.py: -------------------------------------------------------------------------------- 1 | from fact_finder.tools.cypher_preprocessors.format_preprocessor import ( 2 | FormatPreprocessor, 3 | ) 4 | 5 | 6 | def test_on_empty_string(): 7 | preprocessor = FormatPreprocessor() 8 | assert preprocessor("") == "" 9 | 10 | 11 | def test_well_formed_query_remains_unchanged(): 12 | query = ( 13 | 'MATCH (disease1:disease {name: "psoriasis"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 14 | 'WHERE p.name = "phenotype"\n' 15 | "RETURN disease1, exposure, disease2" 16 | ) 17 | preprocessor = FormatPreprocessor() 18 | assert preprocessor(query) == query 19 | 20 | 21 | def test_newlines_get_added(): 22 | query = ( 23 | 'MATCH (disease1:disease {name: "psoriasis"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})' 24 | 'WHERE p.name = "phenotype"' 25 | "RETURN disease1, exposure, disease2" 26 | ) 27 | formated_query = ( 28 | 'MATCH (disease1:disease {name: "psoriasis"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 29 | 'WHERE p.name = "phenotype"\n' 30 | "RETURN disease1, exposure, disease2" 31 | ) 32 | preprocessor = FormatPreprocessor() 33 | assert preprocessor(query) == formated_query 34 | 35 | 36 | def test_spaces_in_node_get_removed(): 37 | query = ( 38 | 'MATCH ( disease1 : disease { name : "psoriasis" } )-[:linked_to]->( exposure : exposure )-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 39 | 'WHERE p.name = "phenotype"\n' 40 | "RETURN disease1, exposure, disease2" 41 | ) 42 | formated_query = ( 43 | 'MATCH (disease1:disease {name: "psoriasis"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 44 | 'WHERE p.name = "phenotype"\n' 45 | "RETURN disease1, exposure, disease2" 46 | ) 47 | preprocessor = FormatPreprocessor() 48 | assert preprocessor(query) == formated_query 49 | 50 | 51 | def test_spaces_in_edges_get_removed(): 52 | query = ( 53 | 'MATCH (disease1:disease {name: "psoriasis"}) - [ : linked_to ] - > (exposure:exposure) - [ : linked_to ] - > (disease2:disease {name: "scalp disease"})\n' 54 | 'WHERE p.name = "phenotype"\n' 55 | "RETURN disease1, exposure, disease2" 56 | ) 57 | formated_query = ( 58 | 'MATCH (disease1:disease {name: "psoriasis"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 59 | 'WHERE p.name = "phenotype"\n' 60 | "RETURN disease1, exposure, disease2" 61 | ) 62 | preprocessor = FormatPreprocessor() 63 | assert preprocessor(query) == formated_query 64 | 65 | 66 | def test_single_quotes_to_double_quotes(): 67 | query = ( 68 | "MATCH (disease1:disease {name: 'psoriasi\\'s'})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: 'scalp disease'})\n" 69 | "WHERE p.name = 'phenotype'\n" 70 | "RETURN disease1, exposure, disease2" 71 | ) 72 | formated_query = ( 73 | 'MATCH (disease1:disease {name: "psoriasi\\\'s"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 74 | 'WHERE p.name = "phenotype"\n' 75 | "RETURN disease1, exposure, disease2" 76 | ) 77 | preprocessor = FormatPreprocessor() 78 | assert preprocessor(query) == formated_query 79 | 80 | 81 | def test_match_formatings_also_work_in_exists_block(): 82 | query = ( 83 | 'MATCH (d:disease {name: "psoriasis"})-[:indication]->(drug:drug)\n' 84 | "WHERE EXISTS( ( : disease { name :'psoriatic arthriti\\'s' } ) - [ : indication ] - > ( drug ) )" 85 | "RETURN drug.name" 86 | ) 87 | formated_query = ( 88 | 'MATCH (d:disease {name: "psoriasis"})-[:indication]->(drug:drug)\n' 89 | 'WHERE exists((:disease {name: "psoriatic arthriti\\\'s"})-[:indication]->(drug))\n' 90 | "RETURN drug.name" 91 | ) 92 | preprocessor = FormatPreprocessor() 93 | assert preprocessor(query) == formated_query 94 | 95 | 96 | def test_single_quotes_in_double_quoted_string_get_escaped(): 97 | query = ( 98 | "MATCH (disease1:disease {name: \"pso'riasi's\"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: 'scalp disease'})\n" 99 | "WHERE p.name = \"phe'no'type\"\n" 100 | "RETURN disease1, exposure, disease2" 101 | ) 102 | formated_query = ( 103 | 'MATCH (disease1:disease {name: "pso\\\'riasi\\\'s"})-[:linked_to]->(exposure:exposure)-[:linked_to]->(disease2:disease {name: "scalp disease"})\n' 104 | "WHERE p.name = \"phe\\'no\\'type\"\n" 105 | "RETURN disease1, exposure, disease2" 106 | ) 107 | preprocessor = FormatPreprocessor() 108 | assert preprocessor(query) == formated_query 109 | -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_lower_case_property_names.py: -------------------------------------------------------------------------------- 1 | from fact_finder.tools.cypher_preprocessors.lower_case_properties_cypher_query_preprocessor import ( 2 | LowerCasePropertiesCypherQueryPreprocessor, 3 | ) 4 | 5 | 6 | def test_producing_lower_case_for_given_property(): 7 | preproc = LowerCasePropertiesCypherQueryPreprocessor(property_names=["name"]) 8 | query1 = 'MATCH (e:exposure {name: "Ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 9 | query2 = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 10 | processed_query = preproc(query1) 11 | assert processed_query == query2 12 | 13 | 14 | def test_producing_lower_case_with_any_property(): 15 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 16 | query1 = 'MATCH (e:exposure {any_property: "Ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 17 | query2 = 'MATCH (e:exposure {any_property: "ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 18 | processed_query = preproc(query1) 19 | assert processed_query == query2 20 | 21 | 22 | def test_producing_multiple_lower_case_with_any_property(): 23 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 24 | query1 = 'MATCH (e:exposure {any_property: "Ethanol"})-[:linked_to]->(d:disease {another_property: "HickUp"}) RETURN d.name' 25 | query2 = 'MATCH (e:exposure {any_property: "ethanol"})-[:linked_to]->(d:disease {another_property: "hickup"}) RETURN d.name' 26 | processed_query = preproc(query1) 27 | assert processed_query == query2 28 | 29 | 30 | def test_producing_multiple_lower_case_for_multiple_given_properties(): 31 | preproc = LowerCasePropertiesCypherQueryPreprocessor(property_names=["name", "disease_name"]) 32 | query1 = 'MATCH (e:exposure {name: "Ethanol"})-[:linked_to]->(d:disease {disease_name: "HickUp"}) RETURN d.name' 33 | query2 = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease {disease_name: "hickup"}) RETURN d.name' 34 | processed_query = preproc(query1) 35 | assert processed_query == query2 36 | 37 | 38 | def test_producing_lower_case_with_spaces_present(): 39 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 40 | query1 = 'MATCH (e:exposure)-[:linked_to]->(d:disease {disease_name: "Hick Up"}) RETURN d.name' 41 | query2 = 'MATCH (e:exposure)-[:linked_to]->(d:disease {disease_name: "hick up"}) RETURN d.name' 42 | processed_query = preproc(query1) 43 | assert processed_query == query2 44 | 45 | 46 | def test_producing_lower_case_with_special_characters_present(): 47 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 48 | query1 = 'MATCH (e:exposure)-[:linked_to]->(d:disease {disease_name: "Hick-_!?=/\\+#Up"}) RETURN d.name' 49 | query2 = 'MATCH (e:exposure)-[:linked_to]->(d:disease {disease_name: "hick-_!?=/\\+#up"}) RETURN d.name' 50 | processed_query = preproc(query1) 51 | assert processed_query == query2 52 | 53 | 54 | def test_producing_lower_case_for_assignment_in_where_clause(): 55 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 56 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "Ethanol" RETURN d.name' 57 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "ethanol" RETURN d.name' 58 | processed_query = preproc(query1) 59 | assert processed_query == query2 60 | 61 | 62 | def test_producing_lower_case_for_multiple_assignments_in_where_clause(): 63 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 64 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "Ethanol" AND d.name = "HickUp" RETURN d.name' 65 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name = "ethanol" AND d.name = "hickup" RETURN d.name' 66 | processed_query = preproc(query1) 67 | assert processed_query == query2 68 | 69 | 70 | def test_producing_lower_case_for_one_element_list_in_where_clause(): 71 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 72 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["Ethanol"] RETURN d.name' 73 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["ethanol"] RETURN d.name' 74 | processed_query = preproc(query1) 75 | assert processed_query == query2 76 | 77 | 78 | def test_producing_lower_case_for_multi_element_list_in_where_clause(): 79 | preproc = LowerCasePropertiesCypherQueryPreprocessor() 80 | query1 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["Ethanol", "Alcohol", "Wine"] RETURN d.name' 81 | query2 = 'MATCH (d:disease)-[:linked_to]-(e:exposure) WHERE e.name IN ["ethanol", "alcohol", "wine"] RETURN d.name' 82 | processed_query = preproc(query1) 83 | assert processed_query == query2 84 | -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_size_to_count_preprocessor.py: -------------------------------------------------------------------------------- 1 | from fact_finder.tools.cypher_preprocessors.size_to_count_preprocessor import ( 2 | SizeToCountPreprocessor, 3 | ) 4 | 5 | 6 | def test_size_to_count(): 7 | preproc = SizeToCountPreprocessor() 8 | cypher_query = "MATCH (n)\nWITH n, SIZE([(n)--()|1]) AS num_edges\nWHERE num_edges = 1\nRETURN n" 9 | processed_query = "MATCH (n)\nWITH n, COUNT{[(n)--()|1]} AS num_edges\nWHERE num_edges = 1\nRETURN n" 10 | assert preproc(cypher_query) == processed_query 11 | 12 | 13 | def test_size_to_count_with_whitespace(): 14 | preproc = SizeToCountPreprocessor() 15 | cypher_query = "MATCH (n)\nWITH n, SIZE ([(n)--()|1]) AS num_edges\nWHERE num_edges = 1\nRETURN n" 16 | processed_query = "MATCH (n)\nWITH n, COUNT{[(n)--()|1]} AS num_edges\nWHERE num_edges = 1\nRETURN n" 17 | assert preproc(cypher_query) == processed_query 18 | 19 | 20 | def test_size_to_count_with_multiple_levels_of_inner_brackets(): 21 | preproc = SizeToCountPreprocessor() 22 | cypher_query = "MATCH (n)\nWHERE SIZE ((())(foo(bar))) = 1\nRETURN n" 23 | processed_query = "MATCH (n)\nWHERE COUNT{(())(foo(bar))} = 1\nRETURN n" 24 | assert preproc(cypher_query) == processed_query 25 | -------------------------------------------------------------------------------- /tests/tools/cypher_preprocessors/test_synonym_query_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | from langchain_community.graphs import Neo4jGraph 6 | 7 | from fact_finder.tools.cypher_preprocessors.synonym_cypher_query_preprocessor import SynonymCypherQueryPreprocessor 8 | from fact_finder.tools.synonym_finder.synonym_finder import SynonymFinder 9 | 10 | 11 | _synonyms_with_match = ["ethanol", "alcohols", "alcohol by volume", "heat", "fever"] 12 | _synonyms_without_match = ["alcoholic_beverage", "alcoholic_drink", "inebriant", "intoxicant"] 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def graph() -> Neo4jGraph: 17 | 18 | def get_graph_nodes(query: str) -> List[Dict[str, Dict[str, str]]]: 19 | if "exposure" in query: 20 | return [{"n": {"name": "ethanol", "id": "1"}}] 21 | if "disease" in query: 22 | return [{"n": {"name": "fever", "id": "2"}}] 23 | return [] 24 | 25 | graph = MagicMock(spec=Neo4jGraph) 26 | graph.query = MagicMock() 27 | graph.query.side_effect = get_graph_nodes 28 | return graph 29 | 30 | 31 | @pytest.fixture() 32 | def synonyms(request) -> List[str]: 33 | if hasattr(request, "param") and request.param: 34 | return request.param 35 | return [] 36 | 37 | 38 | @pytest.fixture() 39 | def synonym_preprocessor(synonyms: str, graph: Neo4jGraph): 40 | synonym_finder = MagicMock(spec=SynonymFinder) 41 | synonym_finder = MagicMock() 42 | synonym_finder.return_value = synonyms 43 | return SynonymCypherQueryPreprocessor( 44 | graph=graph, synonym_finder=synonym_finder, node_types=["exposure", "disease"] 45 | ) 46 | 47 | 48 | @pytest.mark.parametrize("synonyms", (_synonyms_with_match,), indirect=True) 49 | def test_replaces_match_that_exists_in_graph(synonym_preprocessor): 50 | query = 'MATCH (e:exposure {name: "alcohol"})-[:linked_to]->(d:disease) RETURN d.name' 51 | selector_result = synonym_preprocessor(query) 52 | expected = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 53 | assert selector_result == expected 54 | 55 | 56 | @pytest.mark.parametrize("synonyms", (_synonyms_with_match,), indirect=True) 57 | def test_no_change_if_synonym_from_graph_is_already_used(synonym_preprocessor): 58 | query = 'MATCH (e:exposure {name: "ethanol"})-[:linked_to]->(d:disease) RETURN d.name' 59 | selector_result = synonym_preprocessor(query) 60 | assert selector_result == query 61 | 62 | 63 | @pytest.mark.parametrize("synonyms", (_synonyms_without_match,), indirect=True) 64 | def test_no_change_if_no_match_was_found(synonym_preprocessor): 65 | query = 'MATCH (e:exposure {name: "alcohol"})-[:linked_to]->(d:disease) RETURN d.name' 66 | selector_result = synonym_preprocessor(query) 67 | assert selector_result == query 68 | 69 | 70 | @pytest.mark.parametrize("synonyms", (_synonyms_with_match,), indirect=True) 71 | def test_replaces_multiple_matches_that_exists_in_graph(synonym_preprocessor): 72 | query = 'MATCH (e:exposure {name: "alcohol"})-[r]->(d:disease {name: "heat"}) RETURN r.name' 73 | selector_result = synonym_preprocessor(query) 74 | expected = 'MATCH (e:exposure {name: "ethanol"})-[r]->(d:disease {name: "fever"}) RETURN r.name' 75 | assert selector_result == expected 76 | 77 | 78 | @pytest.mark.parametrize("synonyms", (_synonyms_with_match,), indirect=True) 79 | def test_replaces_match_in_where_clause(synonym_preprocessor): 80 | query = 'MATCH (e:exposure)-[:linked_to]->(d:disease)\nWHERE d.name = "heat"\nRETURN e.name' 81 | selector_result = synonym_preprocessor(query) 82 | expected = 'MATCH (e:exposure)-[:linked_to]->(d:disease)\nWHERE d.name = "fever"\nRETURN e.name' 83 | assert selector_result == expected 84 | 85 | 86 | @pytest.mark.parametrize("synonyms", (_synonyms_with_match,), indirect=True) 87 | def test_replaces_match_in_where_clause_with_clutter(synonym_preprocessor): 88 | query = 'MATCH (d:disease)-[:linked_to]->(dr:drug)\nWHERE dr.name = "lollipop" and d.name = "heat" bla bla\nRETURN e.name' 89 | selector_result = synonym_preprocessor(query) 90 | expected = 'MATCH (d:disease)-[:linked_to]->(dr:drug)\nWHERE dr.name = "lollipop" and d.name = "fever" bla bla\nRETURN e.name' 91 | assert selector_result == expected 92 | 93 | 94 | def test_does_not_change_cypher_query_if_no_matches_found(graph: Neo4jGraph): 95 | preproc = SynonymCypherQueryPreprocessor(graph=graph, synonym_finder=lambda x: [], node_types=[r"[^\s\"'{]+"]) 96 | query = "MATCH (e:exposure)-[:linked_to]->(d:disease) RETURN e.name, d.name" 97 | selector_result = preproc(query) 98 | assert selector_result == query 99 | -------------------------------------------------------------------------------- /tests/tools/semantic_scholar_search_api_wrapper_test.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch, MagicMock 2 | 3 | import pytest 4 | 5 | from fact_finder.tools.semantic_scholar_search_api_wrapper import SemanticScholarSearchApiWrapper 6 | 7 | 8 | @pytest.fixture 9 | def semantic_scholar_search_api_wrapper(): 10 | with patch("requests.Session") as mock_session: 11 | mock_session.return_value.get.side_effect = _mock_get 12 | return SemanticScholarSearchApiWrapper() 13 | 14 | 15 | def test_search_by_abstract(semantic_scholar_search_api_wrapper): 16 | result = semantic_scholar_search_api_wrapper.search_by_abstracts(keywords="psoriasis, symptoms") 17 | assert 5 == len(result) 18 | assert result[0].startswith("Improve psoriasis") 19 | assert result[4].startswith("Prevalence and Odds") 20 | 21 | 22 | def _mock_get(url: str, params: dict, headers: dict): 23 | assert "https://api.semanticscholar.org/graph/v1/paper/search" == url 24 | assert {"fields": "title,abstract", "limit": 5, "query": "psoriasis, symptoms"} == params 25 | assert "x-api-key" in headers.keys() 26 | response = MagicMock() 27 | response.status_code = 200 28 | response.json = lambda: { 29 | "data": [ 30 | { 31 | "paperId": "160c6875586b9bc085ef4533daa50004d85665f8", 32 | "title": "Improve psoriasis symptoms with strategies to manage nutrition", 33 | "abstract": "Psoriasis is an inflammatory skin disease that has been linked to both genetic and environmental factors. From 0.09 to 11.43% of the world's population has this dermatosis; in industrialized nations, the prevalence is between 1.5% and 5%. Psoriasis is believed to be caused by a combination of adaptive and innate immune responses. The PASI scale measures the clinical severity of psoriasis on a scale from 0 to 100. This analysis was conducted to determine if existing nutrition interventions are effective in alleviating psoriasis symptoms. Science Direct, Google Scholar, Scopus, PubMed, and ClinicalTrials.gov were used to compile the data for this review. We used the following search terms to narrow our results: psoriasis, nutrition, diet treatment, vitamin, RCTs, and clinical trials. Ten studies were selected from the 63 articles for this review. Research designs are evaluated using the Risk of Bias 2 (RoB2), the Risk of Bias in Non-Randomized Studies of Interventions (ROBINS-I), and the Newcastle-Ottawa Scale (NOS). Studies concluded that a Mediterranean diet, vitamin D3 supplementation, the elimination of cadmium (Cd), lead (Pb), and mercury from the diet, as well as intermittent fasting and low-energy diets for weight loss in obese patients, can alleviate the symptoms of inflammatory diseases. Psoriasis patients undergoing treatment should adhere to dietary recommendations.", 34 | }, 35 | { 36 | "paperId": "e63da5fcf63a7086a4f9041dcc7d2ffde91046ea", 37 | "title": "Diurnal and seasonal variation in psoriasis symptoms", 38 | "abstract": "The frequency of itch in people with psoriasis varies between 64-97% and may exhibit time of day differences as well as interfere with sleep. Furthermore, psoriasis flares appear to exhibit seasonal variation. Whilst a North American study, using a physician-rating scale, showed a trend for winter flaring and summer clearing; a Japanese study, using proxy measures, found no difference between hot and cold months.", 39 | }, 40 | { 41 | "paperId": "53d1d11a83740eebe36bf660f085f1c4bb3bcaea", 42 | "title": "Development and Content Validation of the Psoriasis Symptoms and Impacts Measure (P-SIM) for Assessment of Plaque Psoriasis", 43 | "abstract": None, 44 | }, 45 | { 46 | "paperId": "3eee2d46fc672006703f3165340549e4706c2e48", 47 | "title": "Improvement in Patient-Reported Outcomes (Dermatology Life Quality Index and the Psoriasis Symptoms and Signs Diary) with Guselkumab in Moderate-to-Severe Plaque Psoriasis: Results from the Phase III VOYAGE 1 and VOYAGE 2 Studies", 48 | "abstract": None, 49 | }, 50 | { 51 | "paperId": "5232ae50c6bb4dddddc00c334e5da0202ba25e52", 52 | "title": "Prevalence and Odds of Anxiety Disorders and Anxiety Symptoms in Children and Adults with Psoriasis: Systematic Review and Meta-analysis", 53 | "abstract": "The magnitude of the association between psoriasis and depression has been evaluated, but not that between psoriasis and anxiety. The aim of this systematic review and meta-analysis was to examine the prevalence and odds of anxiety disorders and symptoms in patients with psoriasis. Five medical databases (Cochrane Database, EMBASE, PubMed, PsychINFO, ScienceDirect) were searched for relevant literature. A total of 101 eligible articles were identified. Meta-analysis revealed different prevalence rates depending on the type of anxiety disorder: 15% [95% confidence interval [CI] 9–21] for social anxiety disorder, 11% [9–14] for generalized anxiety disorder, and 9% [95% CI 8–10] for unspecified anxiety disorder. There were insufficient studies assessing other anxiety disorders to be able to draw any conclusions on their true prevalence. Meta-analysis also showed a high prevalence of anxiety symptoms (34% [95% CI 32–37]). Case-control studies showed a positive association between psoriasis and unspecified anxiety disorder (odds ratio 1.48 [1.18; 1.85]) and between psoriasis and anxiety symptoms (odds ratio 2.51 [2.02; 3.12]). All meta-analyses revealed an important heterogeneity, which could be explained in each case by methodological factors. The results of this study raise the necessity of screening for the presence of anxiety disorders, as previously recommended for depressive disorders, in patients with psoriasis and, if necessary, to refer such patients for evaluation by a mental health professional and appropriate treatment.", 54 | }, 55 | ] 56 | } 57 | return response 58 | 59 | 60 | @pytest.mark.skip("Not Implemented") 61 | def test_search_by_paper_content(semantic_scholar_search_api_wrapper): 62 | semantic_scholar_search_api_wrapper.search_by_paper_content(keywords="psoriasis, symptoms") 63 | -------------------------------------------------------------------------------- /tests/tools/synonym_finder/test_word_net_synonym_finder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from fact_finder.tools.synonym_finder.word_net_synonym_finder import WordNetSynonymFinder 4 | 5 | 6 | @pytest.mark.parametrize("query", ("table", "alcohol")) 7 | def test_query_is_returned_as_potential_synonym(query: str): 8 | finder = WordNetSynonymFinder() 9 | res = finder(query) 10 | assert query in res 11 | 12 | 13 | def test_finds_synonyms_for_all_meanings(): 14 | finder = WordNetSynonymFinder() 15 | result = finder("table") 16 | meanings = ["tabular_array", "postpone", "board"] 17 | assert all(m in result for m in meanings) 18 | -------------------------------------------------------------------------------- /tests/tools/test_entity_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import skipIf 3 | from unittest.mock import Mock, patch 4 | import pytest 5 | from dotenv import load_dotenv 6 | 7 | from fact_finder.tools.entity_detector import EntityDetector 8 | 9 | load_dotenv() 10 | 11 | 12 | @pytest.fixture() 13 | def entity_detector(): 14 | return EntityDetector() 15 | 16 | 17 | @patch.dict(os.environ, {"SYNONYM_API_KEY": "dummy_key"}) 18 | @patch("requests.request") 19 | def test_entity_detector(mock_request, entity_detector): 20 | mock_response = Mock() 21 | mock_response.status_code = 200 22 | mock_response.text = '{"annotations" : [1, 2]}' 23 | mock_request.return_value = mock_response 24 | 25 | result = entity_detector("What is pink1? Does it help with epilepsy?") 26 | assert len(result) == 2 27 | 28 | 29 | @skipIf(os.getenv("SYNONYM_API_KEY") is None, "Requires SYNONYM_API_KEY to be set.") 30 | def test_entity_detector_with_api(entity_detector): 31 | result = entity_detector("What is pink1? Does it help with epilepsy?") 32 | assert len(result) == 2 33 | result = entity_detector("What is pink1?") 34 | assert len(result) == 1 35 | result = entity_detector("atopic dermatitis") 36 | assert len(result) == 1 37 | assert result[0]["pref_term"] == "dermatitis, atopic" 38 | -------------------------------------------------------------------------------- /tests/tools/test_llm_subgraph_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | from dotenv import dotenv_values 6 | from fact_finder.tools.sub_graph_extractor import LLMSubGraphExtractor 7 | from langchain_openai import ChatOpenAI 8 | 9 | 10 | @pytest.mark.skip(reason="end to end") 11 | @patch.dict(os.environ, {**dotenv_values(), **os.environ}) 12 | def test_subgraph_extractor_e2e(llm_subgraph_extractor): 13 | assert_subgraph_extractor( 14 | llm_subgraph_extractor, 15 | "MATCH (d:disease {name: 'schizophrenia'})-[:indication]->(g:drug) RETURN g", 16 | [ 17 | "MATCH (d:disease {name: 'schizophrenia'})-[r:indication]->(g:drug) RETURN d, r, g", 18 | "MATCH (d:disease {name: 'schizophrenia'})-[r:indication]->(g:drug) RETURN r, d, g", 19 | ], 20 | ) 21 | assert_subgraph_extractor( 22 | llm_subgraph_extractor, 23 | "MATCH (d:disease {name: 'epilepsy'})-[:indication]->(g:drug) RETURN g", 24 | [ 25 | "MATCH (d:disease {name: 'epilepsy'})-[r:indication]->(g:drug) RETURN r, d, g", 26 | "MATCH (d:disease {name: 'epilepsy'})-[r:indication]->(g:drug) RETURN d, r, g", 27 | ], 28 | ) 29 | assert_subgraph_extractor( 30 | llm_subgraph_extractor, 31 | "MATCH (d:disease {name: 'epilepsy'})-[:indication]->(g:drug) RETURN g.name", 32 | [ 33 | "MATCH (d:disease {name: 'epilepsy'})-[r:indication]->(g:drug) RETURN r, d, g", 34 | "MATCH (d:disease {name: 'epilepsy'})-[r:indication]->(g:drug) RETURN d, r, g", 35 | ], 36 | ) 37 | 38 | 39 | @pytest.mark.skip(reason="end to end") 40 | @patch.dict(os.environ, {**dotenv_values(), **os.environ}) 41 | def test_more_complicated_query_subgraph_extractor_e2e(llm_subgraph_extractor): 42 | complicated_query = """MATCH (d:disease)-[:phenotype_present]-({name:"eczema"}) MATCH (d)-[:phenotype_present]-({ 43 | name:"neutropenia"}) MATCH (d)-[:phenotype_present]-({name:"high forehead"}) RETURN DISTINCT d.name""" 44 | result = llm_subgraph_extractor(complicated_query) 45 | assert len(result) > len(complicated_query) 46 | 47 | 48 | def assert_subgraph_extractor( 49 | llm_subgraph_extractor: LLMSubGraphExtractor, cypher_query: str, expected_results: list[str] 50 | ): 51 | result = llm_subgraph_extractor(cypher_query) 52 | assert result in expected_results 53 | 54 | 55 | @pytest.fixture 56 | def llm_subgraph_extractor(llm_e2e): 57 | return LLMSubGraphExtractor(model=llm_e2e) 58 | 59 | 60 | @pytest.fixture 61 | def llm_e2e(open_ai_key): 62 | return ChatOpenAI(model="gpt-4", streaming=False, temperature=0, api_key=open_ai_key) 63 | 64 | 65 | @pytest.fixture 66 | def open_ai_key(): 67 | return os.getenv("OPENAI_API_KEY") 68 | -------------------------------------------------------------------------------- /tests/tools/test_regex_subgraph_extractor.py: -------------------------------------------------------------------------------- 1 | from fact_finder.tools.sub_graph_extractor import RegexSubGraphExtractor 2 | 3 | 4 | def test_extract_subgraph_preprocessor(): 5 | extract_subgraph_preprocessor = RegexSubGraphExtractor() 6 | cypher_query = "MATCH (d:disease {name: 'schizophrenia'})-[:indication]->(g:drug) RETURN g" 7 | edited_cypher_query = extract_subgraph_preprocessor(cypher_query) 8 | expected_cypher_query = "MATCH (d:disease {name: 'schizophrenia'})-[l:indication]->(g:drug) RETURN d,l,g" 9 | assert isinstance(edited_cypher_query, str) 10 | assert len(cypher_query) + 5 == len(edited_cypher_query) 11 | assert len(expected_cypher_query) == len(edited_cypher_query) 12 | -------------------------------------------------------------------------------- /tests/tools/test_subgraph_expansion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any 2 | from unittest.mock import MagicMock 3 | 4 | import pytest 5 | 6 | from fact_finder.tools.subgraph_extension import SubgraphExpansion 7 | 8 | 9 | @pytest.fixture 10 | def neo4j_graph(): 11 | return MagicMock() 12 | 13 | 14 | @pytest.fixture 15 | def subgraph_expansion(neo4j_graph): 16 | return SubgraphExpansion(graph=neo4j_graph) 17 | 18 | 19 | @pytest.fixture 20 | def extracted_nodes() -> List[Dict[str, Any]]: 21 | return [ 22 | { 23 | "d": { 24 | "index": 27937, 25 | "source": "mondo_grouped", 26 | "causes": "nutrition", 27 | "name": "cardioacrofacial dysplasia", 28 | "id": "30876_30877_31386", 29 | }, 30 | "pp": ( 31 | { 32 | "index": 27937, 33 | "source": "mondo_grouped", 34 | "causes": "nutrition", 35 | "name": "cardioacrofacial dysplasia", 36 | "id": "30876_30877_31386", 37 | "management_and_treatment": "no information available", 38 | "prevention": "no information available", 39 | }, 40 | "phenotype_present", 41 | { 42 | "name": "mandibular prognathia", 43 | "index": 84579, 44 | "source": "hpo", 45 | }, 46 | ), 47 | "p": { 48 | "name": "mandibular prognathia", 49 | "index": 84579, 50 | "source": "hpo", 51 | }, 52 | "properties(pp)": {}, 53 | }, 54 | ] 55 | 56 | 57 | @pytest.fixture 58 | def graph_output() -> List[Dict[str, Any]]: 59 | return [ 60 | { 61 | "a": { 62 | "index": 27937, 63 | "source": "mondo_grouped", 64 | "causes": "nutrition", 65 | "name": "cardioacrofacial dysplasia", 66 | "id": "30876_30877_31386", 67 | }, 68 | "pp": ( 69 | { 70 | "index": 27937, 71 | "source": "mondo_grouped", 72 | "causes": "nutrition", 73 | "name": "cardioacrofacial dysplasia", 74 | "id": "30876_30877_31386", 75 | "management_and_treatment": "no information available", 76 | "prevention": "no information available", 77 | }, 78 | "a_new_relation", 79 | { 80 | "name": "intermediate node", 81 | "index": 4711, 82 | "source": "hpo", 83 | }, 84 | ), 85 | "c": { 86 | "name": "intermediate node", 87 | "index": 4711, 88 | "source": "hpo", 89 | }, 90 | "properties(pp)": {}, 91 | }, 92 | { 93 | "b": { 94 | "name": "mandibular prognathia", 95 | "index": 84579, 96 | "source": "hpo", 97 | }, 98 | "pp": ( 99 | { 100 | "name": "mandibular prognathia", 101 | "index": 84579, 102 | "source": "hpo", 103 | }, 104 | "phenotype_present", 105 | { 106 | "name": "intermediate node", 107 | "index": 4711, 108 | "source": "hpo", 109 | }, 110 | ), 111 | "c": { 112 | "name": "intermediate node", 113 | "index": 4711, 114 | "source": "hpo", 115 | }, 116 | "properties(pp)": {}, 117 | }, 118 | ] 119 | 120 | 121 | def test_expansion(subgraph_expansion, neo4j_graph, extracted_nodes, graph_output): 122 | neo4j_graph.query.return_value = graph_output 123 | enriched = subgraph_expansion.expand(nodes=extracted_nodes) 124 | assert len(enriched) == 3 125 | assert extracted_nodes[0] in enriched 126 | assert graph_output[0] in enriched 127 | assert graph_output[1] in enriched 128 | --------------------------------------------------------------------------------