├── .coveragerc ├── .dockerignore ├── .env.example ├── .flake8 ├── .github └── workflows │ ├── pypi-publish.yml │ └── pytest.yml ├── .gitignore ├── .gitleaks-report.json ├── .isort.cfg ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README-USECASE.md ├── README.md ├── README_v2.md ├── WIP-Custom-BERT-Model.ipynb ├── assets ├── brainstorm-examples.png └── displacy-sentiment-example.png ├── examples ├── Entity Recognition Evaluation.ipynb ├── Exploring-DS-Graphs.ipynb ├── Extract structured data from text.ipynb ├── SRL-Coref.ipynb ├── WIP-QA-pipeline.ipynb ├── brainstorming_examples.ipynb ├── data_graph.py ├── de_chatintents.py ├── de_nested_objects.py ├── ee │ └── sports │ │ └── README.md ├── er_travel.py ├── homographs.ipynb └── sentiment_analysis_of_reviews.ipynb ├── poetry.lock ├── promptedgraphs ├── __init__.py ├── __main__.py ├── cli.py ├── code_execution │ ├── __init__.py │ └── safer_python_exec.py ├── config.py ├── data_extraction.py ├── data_modeling │ ├── __init__.py │ ├── conceptual.py │ ├── logical.py │ └── physical.py ├── entity_linking │ ├── README.md │ ├── __init__.py │ ├── link.py │ └── upsertion.py ├── entity_recognition.py ├── entity_resolution │ ├── __init__.py │ └── resolve.py ├── extraction │ ├── __init__.py │ ├── data_from_text.py │ └── entities_from_text.py ├── generation │ ├── __init__.py │ ├── data_from_model.py │ ├── schema_from_data.py │ └── schema_from_model.py ├── helpers.py ├── llms │ ├── __init__.py │ ├── anthropic_chat.py │ ├── chat.py │ ├── coding.py │ ├── helpers.py │ ├── openai_chat.py │ ├── openai_streaming.py │ └── usage.py ├── models.py ├── normalization │ ├── __init__.py │ ├── object_to_data.py │ ├── schema_to_schema.py │ └── vis_graphs.py ├── parsers.py ├── sources │ ├── __init__.py │ ├── datagraph_from_class.py │ └── datagraph_from_pydantic.py ├── statistical │ ├── __init__.py │ └── data_analysis.py ├── utils │ └── __init__.py ├── validation │ ├── __init__.py │ ├── validate_data.py │ └── validate_schema.py └── vis.py ├── pyproject.toml ├── run_security_check.sh ├── run_tests_with_coverage.sh └── tests ├── __init__.py ├── all.py ├── generation ├── __init__.py └── test_schema_from_data.py ├── test_cli.py ├── test_config.py └── test_install.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # https://coveralls-python.readthedocs.io/en/latest/tips/coveragerc.html 2 | 3 | [run] 4 | source = promptedgraphs 5 | branch = True 6 | 7 | [report] 8 | exclude_lines = 9 | pragma: no cover 10 | def __repr__ 11 | raise AssertionError 12 | raise NotImplementedError 13 | if __name__ == .__main__.: 14 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | env 3 | venv 4 | htmlcov 5 | *.egg-info 6 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # An example configuration file for the application. 2 | 3 | # OPENAI_API_KEY="sk-.............." 4 | 5 | # For scraping api documentation 6 | OGTAGS_API_KEY=".............." 
# https://ogtags.com 7 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | max-complexity = 18 4 | select = B,C,E,F,W,T4,B9 5 | ignore = E203, E266, E501, W503, F403, E743, E741 6 | exclude = 7 | # list of ignores 8 | .git, 9 | __pycache__, 10 | .pytest_cache, 11 | .vscode, 12 | prompt_to_code.egg-info, 13 | env, 14 | dist, 15 | data_models 16 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | #---------------------------------------------- 13 | # check-out repo and set-up python 14 | #---------------------------------------------- 15 | - name: Check out repository 16 | uses: actions/checkout@v3 17 | - name: Set up python 18 | id: setup-python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.10' 22 | #---------------------------------------------- 23 | # ----- install & configure poetry ----- 24 | #---------------------------------------------- 25 | - name: Install Poetry 26 | uses: snok/install-poetry@v1 27 | with: 28 | virtualenvs-create: true 29 | virtualenvs-in-project: true 30 | installer-parallel: true 31 | #---------------------------------------------- 32 | # install dependencies if cache does not exist 33 | #---------------------------------------------- 34 | - name: Install dependencies 35 | run: poetry install --no-interaction --no-root 36 | #---------------------------------------------- 37 | # install your root project, if required 38 | #---------------------------------------------- 39 | - name: Install project 40 | run: poetry install --no-interaction 41 | #---------------------------------------------- 42 | # build and publish the package 43 | #---------------------------------------------- 44 | - name: Build package 45 | run: poetry build 46 | - name: package info 47 | run: head pyproject.toml 48 | - name: Publish package 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | with: 51 | user: __token__ 52 | password: ${{ secrets.PYPI_API_TOKEN }} 53 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: test 5 | 6 | on: ["push", "pull_request"] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | test: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | #---------------------------------------------- 17 | # check-out repo and set-up python 18 | #---------------------------------------------- 19 | - name: Check out repository 20 | uses: actions/checkout@v3 21 | - name: Set up python 22 | id: setup-python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: '3.10' 26 | #---------------------------------------------- 27 | # ----- install & configure poetry ----- 28 | #---------------------------------------------- 29 | - name: Install Poetry 30 |
uses: snok/install-poetry@v1 31 | with: 32 | virtualenvs-create: true 33 | virtualenvs-in-project: true 34 | installer-parallel: true 35 | 36 | #---------------------------------------------- 37 | # load cached venv if cache exists 38 | #---------------------------------------------- 39 | - name: Load cached venv 40 | id: cached-poetry-dependencies 41 | uses: actions/cache@v3 42 | with: 43 | path: .venv 44 | key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 45 | #---------------------------------------------- 46 | # install dependencies if cache does not exist 47 | #---------------------------------------------- 48 | - name: Install dependencies 49 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 50 | run: poetry install --no-interaction --no-root 51 | #---------------------------------------------- 52 | # install your root project, if required 53 | #---------------------------------------------- 54 | - name: Install project 55 | run: poetry install --no-interaction 56 | #---------------------------------------------- 57 | # run test suite 58 | #---------------------------------------------- 59 | - name: Run tests 60 | run: | 61 | source .venv/bin/activate 62 | pytest tests/ --cov=promptedgraphs 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | .vscode 162 | .gitleaks-report.json 163 | 164 | 165 | data_models 166 | examples/ee -------------------------------------------------------------------------------- /.gitleaks-report.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length = 88 3 | multi_line_output = 3 4 | include_trailing_comma = True 5 | force_grid_wrap = 0 6 | use_parentheses = True 7 | ensure_newline_before_comments = True 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | minimum_pre_commit_version: '2.9.0' 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.4.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | - id: mixed-line-ending 9 | args: ['--fix=lf'] 10 | - id: check-added-large-files 11 | args: ['--maxkb=30000'] 12 | # - id: no-commit-to-branch 13 | - repo: https://github.com/asottile/pyupgrade 14 | rev: v3.3.1 15 | hooks: 16 | - id: pyupgrade 17 | args: [--py310-plus] 18 | files: \.py$ 19 | - repo: https://github.com/PyCQA/isort 20 | rev: 5.12.0 21 | hooks: 22 | - id: isort 23 | - repo: https://github.com/ambv/black 24 | rev: 23.3.0 25 | hooks: 26 | - id: black 27 | - repo: https://github.com/myint/eradicate 28 | rev: v2.1.0 29 | hooks: 30 | - id: eradicate 31 | - repo: https://github.com/PyCQA/flake8 32 | rev: 5.0.4 33 | hooks: 34 | - id: flake8 35 | # - repo: local 36 | # hooks: 37 | # - id: vulture 38 | # name: vulture 39 | # description: Find dead Python code 40 | # entry: vulture 41 | # args: ["--min-confidence", "80", "--exclude", "env,venv,examples,data_models", "."] 42 | # language: system 43 | # types: [python] 44 | - repo: https://github.com/PyCQA/autoflake 45 | rev: v2.0.1 46 | hooks: 47 | - id: autoflake 48 | args: [--in-place, --remove-all-unused-imports, --remove-unused-variables] 49 | types_or: [python, pyi] 50 | - repo: https://github.com/adamchainz/blacken-docs 51 | rev: 1.13.0 52 | hooks: 53 | - id: blacken-docs 54 | additional_dependencies: 55 | - black 56 | # - repo: https://github.com/gitleaks/gitleaks 57 | # rev: v8.16.1 58 | # hooks: 59 | # - id: gitleaks 60 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG APP_NAME=promptedgraphs 2 | 3 | # Base image 4 | FROM python:3.10-slim-buster as staging 5 | 6 | # Install necessary system packages 7 | RUN apt-get update && apt-get install -y --no-install-recommends \ 8 | build-essential \ 9 | curl \ 10 | && rm -rf /var/lib/apt/lists/* 11 | 12 | # Set environment variables 13 | ENV PYTHONDONTWRITEBYTECODE 1 14 | ENV PYTHONUNBUFFERED 1 15 | ENV POETRY_HOME="/opt/poetry" 16 | ENV PATH="$POETRY_HOME/bin:$PATH" 17 | 18 | # Install Poetry 19 | RUN curl -sSL https://install.python-poetry.org | python - && chmod +x $POETRY_HOME/bin/poetry 20 | 21 | # Set work directory 22 | WORKDIR /app 23 | 24 | # Copy only requirements to cache them in docker layer 25 | COPY poetry.lock pyproject.toml /app/ 26 | 27 | # # Project initialization: 28 | RUN poetry install --no-interaction --no-root 29 | 30 | # # Copying the project files into the container 31 | COPY . 
/app/ 32 | 33 | # # Install the project 34 | RUN poetry install --no-interaction 35 | RUN pip install https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl 36 | 37 | # # Command to run tests 38 | RUN poetry run pytest --cov promptedgraphs && poetry run coverage report 39 | 40 | ENTRYPOINT [ "poetry", "run" ] 41 | CMD [ "python", "-m", "promptedgraphs", "info" ] 42 | 43 | 44 | FROM staging as build 45 | ARG APP_NAME 46 | 47 | WORKDIR /app 48 | RUN poetry build --format wheel 49 | RUN poetry export --format requirements.txt --output constraints.txt --without-hashes 50 | 51 | 52 | FROM python:3.10-slim-buster as production 53 | ARG APP_NAME 54 | 55 | # Set environment variables 56 | ENV \ 57 | PYTHONDONTWRITEBYTECODE=1 \ 58 | PYTHONUNBUFFERED=1 \ 59 | PYTHONFAULTHANDLER=1 60 | 61 | ENV \ 62 | PIP_NO_CACHE_DIR=off \ 63 | PIP_DISABLE_PIP_VERSION_CHECK=on \ 64 | PIP_DEFAULT_TIMEOUT=100 65 | 66 | # Get build artifact wheel and install it respecting dependency versions 67 | WORKDIR /app 68 | COPY --from=build /app/dist/*.whl ./ 69 | COPY --from=build /app/constraints.txt ./ 70 | RUN pip install ./$APP_NAME*.whl --constraint constraints.txt 71 | ENTRYPOINT [ "python"] 72 | CMD [ "-m", "promptedgraphs", "info" ] 73 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ClosedLoop Technologies 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-USECASE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Steps 4 | 5 | 1. Given data from an API endpoint 6 | * Description of the endpoint (url, method, parameters, etc.) 7 | * Example **Raw Data** from the endpoint 8 | 2. Generate a Pydantic **DataModel** from the example data 9 | 3. Repeat for two other endpoints 10 | 4. Construct a **DataGraph** from the **DataModels** to represent the relationships between the data 11 | 5. Generate a **PropertyGraph-Schema** from the **DataGraph** and represent as an ER-Diagram. 12 | 6. Create a schema alignment between the **PropertyGraph-Schema** and the properties of the **DataGraph**. 13 | a. Indicate how the data models should be transformed to fit the schema. 14 | b. 
TODO handle cases where keys and values need to be 'pivoted' to fit the schema. 15 | 7. Generate a **Database-Schema** from the **PropertyGraph-Schema**. The data should be in third normal form by default. 16 | 8. Create a **Database** from the **Database-Schema** 17 | 9. Implement ETL tasks to transform and load data from the API endpoints into the database. 18 | 1. Transform should take the example data and convert it into in-memory tables reflecting the database schema. These are the **staging** tables. 19 | 2. Load should first `resolve` any existing data and update primary keys with the existing keys. 20 | 3. Human in the loop for any ambiguous keys. 21 | 4. Insert the data into the database. 22 | 5. Optionally update any existing data. 23 | 24 | In this library, we should be able to manually chain together these functions and generate the necessary code. 25 | 26 | The library **AutoETL** should be able to automatically chain these functions together and run them in a pipeline. It should also be able to generate the code for the pipeline. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PromptedGraphs 2 | 3 | **From Dataset Labeling to Deployment: The Power of NLP and LLMs Combined.** 4 | 5 | ## Description 6 | 7 | PromptedGraphs is a Python library that aims to seamlessly integrate traditional NLP methods with the capabilities of modern Large Language Models (LLMs) in the realm of knowledge graphs. Our library offers tools tailored for dataset labeling, model training, and smooth deployment to production environments. We leverage the strengths of [spacy](https://github.com/explosion/spaCy) for core NLP tasks, [snorkel](https://github.com/closedloop-technologies/snorkel) for effective data labeling, and `async` to ensure enhanced performance. Our mission is to provide a harmonized solution to knowledge graph development when you have to merge traditional and LLM-driven approaches, squarely addressing the challenges associated with accuracy, efficiency, and affordability. 8 | 9 | ## ✨ Features 10 | 11 | - **Named Entity Recognition (NER)**: Customize entity labels based on your domain. 12 | - **Structured Data Extraction**: Extract structured data from unstructured text. 13 | - **Entity Resolution**: Deduplication and normalization 14 | - **Relationship Extraction**: Either open-ended labels or constrained to your domain 15 | - **Entity Linking**: Link references in text to entities in a graph 16 | - **Graph Construction**: Create or update knowledge graphs 17 | 18 | ## Core Functions 19 | 20 | - **Dataset Labeling**: Efficient tools for labeling datasets, powered by `haystack`. 21 | - **Model Training**: Combine the reliability of NLP and the prowess of LLMs. 22 | - **Deployment**: Streamlined processes to ensure a smooth transition to production. 23 | 24 | ## Requirements 25 | 26 | - Python 3.10 or newer.
27 | 28 | ## 📦 Installation 29 | 30 | To install `PromptedGraphs` via pip: 31 | 32 | ```bash 33 | pip install promptedgraphs 34 | # or 35 | poetry add promptedgraphs 36 | ``` 37 | 38 | ## Usage 39 | ### Entity Recognition 40 | 41 | from [examples/sentiment_analysis_of_reviews.ipynb](https://github.com/closedloop-technologies/PromptedGraphs/blob/main/examples/sentiment_analysis_of_reviews.ipynb) 42 | 43 | ```python 44 | from promptedgraphs.vis import render_entities 45 | from promptedgraphs.config import Config 46 | from promptedgraphs.extraction.entities_from_text import entities_from_text 47 | 48 | labels = { 49 | "POSITIVE": "A positive review of a product or service.", 50 | "NEGATIVE": "A negative review of a product or service.", 51 | "NEUTRAL": "A neutral review of a product or service.", 52 | } 53 | 54 | text_of_reviews = """ 55 | 1. "I absolutely love this product. It's been a game changer!" 56 | 2. "The service was quite poor and the staff was rude." 57 | 3. "The item is okay. Nothing special, but it gets the job done." 58 | """.strip() 59 | 60 | 61 | # Label Sentiment 62 | ents = [] 63 | async for msg in entities_from_text( 64 | name="sentiment", 65 | description="Sentiment Analysis of Customer Reviews", 66 | text=text_of_reviews, 67 | labels=labels, 68 | config=Config(), # Reads `OPENAI_API_KEY` from .env file or environment 69 | ): 70 | ents.append(msg) 71 | 72 | # Show results with render_entities (a spacy.displacy-based visualization) 73 | render_entities( 74 | text=text_of_reviews, 75 | entities=ents, 76 | labels=labels, 77 | colors = {"POSITIVE": "#7aecec", "NEGATIVE": "#f44336", "NEUTRAL": "#f4f442"} 78 | ) 79 | ``` 80 | ![displacy-sentiment-example](./assets/displacy-sentiment-example.png?raw=true) 81 | 82 | ### Brainstorming Data 83 | Generate a list of data that fits a given data model. 84 | 85 | from [examples/brainstorming_examples.ipynb](https://github.com/closedloop-technologies/PromptedGraphs/blob/main/examples/brainstorming_examples.ipynb) 86 | 87 | ```python 88 | from pydantic import BaseModel, Field 89 | 90 | from promptedgraphs.config import Config 91 | from promptedgraphs.ideation import brainstorm 92 | from promptedgraphs.vis import render_entities 93 | 94 | 95 | class BusinessIdea(BaseModel): 96 | """A business idea generated using the Jobs-to-be-done framework 97 | For example "We help [adj] [target_audience] [action] so they can [benefit or do something else]" 98 | """ 99 | 100 | target_audience: str = Field(title="Target Audience") 101 | action: str = Field(title="Action") 102 | benefit: str = Field(title="Benefit or next action") 103 | adj: str | None = Field( 104 | title="Adjective", 105 | description="Optional adjective describing the target audience's condition", 106 | ) 107 | 108 | 109 | ideas = [] 110 | async for idea in brainstorm( 111 | text=BusinessIdea.__doc__, 112 | output_type=list[BusinessIdea], 113 | config=Config(), 114 | n=10, 115 | max_workers=2, 116 | ): 117 | ideas.append(idea) 118 | render_entities( 119 | f"We help {idea.adj} {idea.target_audience} {idea.action} so they can {idea.benefit}", 120 | idea, 121 | ) 122 | ``` 123 | ![brainstorm-examples](./assets/brainstorm-examples.png?raw=true) 124 | 125 | ### Structured Data Extraction 126 | 127 | from [examples/de_chatintents.py](https://github.com/closedloop-technologies/PromptedGraphs/blob/main/examples/de_chatintents.py) 128 | 129 | ```python 130 | from pydantic import BaseModel, Field 131 | 132 | from promptedgraphs.config import Config 133 | from promptedgraphs.extraction.data_from_text import data_from_text 134 | 135 | class UserIntent(BaseModel): 136 | """The UserIntent entity, representing the canonical description of what a user desires to achieve in a given conversation.""" 137 | 138 | intent_name: str = Field( 139 | title="Intent Name", 140 | description="Canonical name of the user's intent", 141 | examples=[ 142 | "question", 143 | "command", 144 | "clarification", 145 | "chit_chat", 146 | "greeting", 147 | "feedback", 148 | "nonsensical", 149 | "closing", 150 | "harassment", 151 | "unknown", 152 | ], 153 | ) 154 | description: str | None = Field( 155 | title="Intent Description", 156 | description="A detailed explanation of the user's intent", 157 | ) 158 | 159 | 160 | msg = """It's a busy day, I need to send an email and to buy groceries""" 161 | 162 | async for intent in data_from_text( 163 | text=msg, output_type=UserIntent, config=Config() 164 | ): 165 | print(intent) 166 | ``` 167 | ```bash 168 | intent_name='task' description='User wants to complete a task' 169 | intent_name='communication' description='User wants to send an email' 170 | intent_name='shopping' description='User wants to buy groceries' 171 | ``` 172 |
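### Entity Extraction with Explanations

Entity extraction can also return a short rationale for each span. The sketch below is adapted from [examples/er_travel.py](https://github.com/closedloop-technologies/PromptedGraphs/blob/main/examples/er_travel.py), which passes `include_reason=True` so each returned `EntityReference` carries a `reason` field; the label set and sample text here are illustrative.

```python
import asyncio

from promptedgraphs.config import Config, load_config
from promptedgraphs.extraction.entities_from_text import entities_from_text

# Illustrative labels; define whatever fits your domain
labels = {
    "DATE": "Absolute or relative dates or periods",
    "LOCATIONS": "All destinations and places mentioned in the travel request.",
}

text = "We're venturing to New Zealand from March 5-18th, primarily around Auckland."


async def main():
    load_config()  # reads `OPENAI_API_KEY` from .env file or environment
    async for ent in entities_from_text(
        text, labels=labels, config=Config(), include_reason=True
    ):
        # Each `ent` is an EntityReference with start/end offsets, label, text, and reason
        print(ent)


asyncio.run(main())
```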
173 | ## 📚 Resources 174 | 175 | * [Awesome-LLM-KG](https://github.com/RManLuo/Awesome-LLM-KG) 176 | * [KG-LLM-Papers](https://github.com/zjukg/KG-LLM-Papers) 177 | 178 | ### Related Libraries 179 | * [instructor](https://jxnl.github.io/instructor/) 180 | * [marvin](https://github.com/PrefectHQ/marvin) 181 | 182 | ## Contributing 183 | 184 | We welcome contributions! Please DM me [@seankruzel](https://twitter.com/seankruzel) or create issues or pull requests. 185 | 186 | ## 📝 License 187 | 188 | This project is licensed under the terms of the [MIT license](/LICENSE). 189 | 190 | Built with [quantready](https://github.com/closedloop-technologies/quantready) from the template [https://github.com/closedloop-technologies/quantready-api](https://github.com/closedloop-technologies/quantready-api) 191 | -------------------------------------------------------------------------------- /README_v2.md: -------------------------------------------------------------------------------- 1 | # PromptedGraphs 2 | 3 | Do you have data coming in from multiple sources? Do you need to extract structured data from unstructured text? Do you need to update a database or knowledge graph with this data? If so, PromptedGraphs is for you! 4 | 5 | ## Key Concepts 6 | 7 | ### Data Modeling Concepts 8 | * **Raw Data**: Any data coming from a source, such as text or arbitrary objects. 9 | * **DataModel**: A Pydantic DataModel is a schema for structured data. It is used to validate and serialize data. This can be represented as a JSON Schema or a Pydantic DataModel. 10 | * **DataGraph**: This is a taxonomy-like structure used to organize multiple DataModels. It connects properties from different DataModels via `same_as`, `part_of` and `is_in` relationships. In this case we are using the term `graph` to refer to a network of nodes and edges where each node is a property of a DataModel and each edge is a relationship between two properties. 11 | 12 | ### Storage Concepts 13 | Data can be stored in a Knowledge Graph (which follows a PropertyGraph-Schema) or in a Database (which follows a Database-Schema). 14 | 15 | * **PropertyGraph-Schema**: This is a more formalized version of a DataGraph. It is used to define the relationships between DataModels and their properties. Much like an Ontology, it typically includes a hierarchy of classes and properties and can merge or split DataModels as well as define new properties. For ontology purists, this might feel a bit like a lightweight version of [OWL](https://www.w3.org/TR/owl-guide/).
16 | * **Database-Schema**: This represents a relational database schema. It is used to define tables and columns and their relationships, such as foreign keys and indexes. 17 | 18 | 19 | ## ETL-Related Tasks 20 | 21 | * **Extract**: Take raw data (such as text or arbitrary objects) and convert it into structured data. 22 | 23 | * **Transform**: Reshape this structured data into a format that can be used to update a database. 24 | 25 | * **Load**: Use the transformed data to update the database or knowledge graph. 26 | 27 | ## NLP-Related Tasks 28 | Implementing the above ETL tasks for arbitrary data is easier said than done. This is where NLP comes in. We can use NLP to extract structured data from unstructured text. This can be done in a variety of ways: 29 | 30 | * **Extraction** 31 | * `data_from_text`: Extract structured data from unstructured text. 32 | * `entities_from_text`: Extract entities from unstructured text. 33 | * **Relationship Extraction**: Either open-ended labels or constrained to your domain 34 | 35 | * **Transformation** 36 | * `schema_from_data`: Generate schemas from data samples. 37 | * We represent schemas as Pydantic DataModels and as JSON Schema. 38 | * `taxonomy`: 39 | * `data_from_schema`: Generate data samples from schemas. 40 | * `normalize`: Convert data to fit schema specifications with light reformats. 41 | 42 | * **Loading** 43 | * `update_database`: Update a database with structured data. 44 | * `update_graph`: Update a knowledge graph with structured data. 45 | * **Entity Linking**: Link references in text to entities in a graph 46 | * **Graph Construction**: Create or update knowledge graphs 47 | * **Entity Resolution**: Deduplication and normalization 48 | 49 | * **Testing** 50 | * `generate`: Generate sample data from Pydantic DataModels. This is the reverse direction from `schema_from_data`. 51 | * **Validation**: Validate schemas and data against specifications or models. 52 | 53 | 54 | ## Installation 55 | 56 | Install PromptedGraphs using pip: 57 | 58 | ```bash 59 | pip install promptedgraphs 60 | ``` 61 | 62 | ## Usage 63 | 64 | Here's a quick example to get you started (`data_from_text` is an async generator, so it is consumed with `async for`): 65 | 66 | ```python 67 | from promptedgraphs.config import Config 68 | from promptedgraphs.extraction.data_from_text import data_from_text 69 | from myproject.models import MyDataModel 70 | 71 | # Example text 72 | text = "Your example text goes here." 73 | 74 | # Extract data from text 75 | async for data in data_from_text(text=text, output_type=MyDataModel, config=Config()): 76 | print(data) 77 | ``` 78 | 79 | Replace `MyDataModel` with your Pydantic `DataModel` tailored to your specific needs. A second sketch, extracting a list of nested objects, appears in the appendix below. 80 | 81 | ## Documentation 82 | 83 | For detailed documentation, visit [Link to Documentation]. 84 | 85 | ## Contributing 86 | 87 | Contributions are welcome! If you have ideas for new features or improvements, feel free to open an issue or submit a pull request. 88 | 89 | ## Support 90 | 91 | If you need help or have any questions, please open an issue in the GitHub repository. 92 | 93 | ## License 94 | 95 | PromptedGraphs is released under the [MIT License](LICENSE).
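## Appendix: Extracting Nested Objects

`output_type` can also be a model that nests other models, as in [examples/de_nested_objects.py](https://github.com/closedloop-technologies/PromptedGraphs/blob/main/examples/de_nested_objects.py). This is a trimmed sketch of that example; `TripRequest` below is an illustrative stand-in for the example's much richer `WorkflowState` model.

```python
import asyncio

from pydantic import BaseModel, Field

from promptedgraphs.config import Config, load_config
from promptedgraphs.extraction.data_from_text import data_from_text


class Traveler(BaseModel):
    name: str = Field(title="Traveler Name")
    age: int | None = Field(None, title="Age of Traveler")


class TripRequest(BaseModel):
    """A travel request parsed from a forum post."""

    # An illustrative, trimmed-down version of the example's WorkflowState
    travelers: list[Traveler] = Field(
        title="Travelers", description="Who is traveling and their ages"
    )
    locations: list[str] | None = Field(
        None, title="Locations", description="All locations mentioned in the request"
    )


async def main():
    load_config()
    msg = "Jane (34) and her husband Tom are heading to Auckland and Rotorua in March."
    async for state in data_from_text(text=msg, output_type=TripRequest, config=Config()):
        print(state)


asyncio.run(main())
```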
96 | -------------------------------------------------------------------------------- /WIP-Custom-BERT-Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4ac1a5db-0cc8-40b5-bbe6-8b23f06ddcee", 6 | "metadata": {}, 7 | "source": [ 8 | "# BERTModel " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 83, 14 | "id": "5c8b5697-3488-426c-8982-b1b4034a333b", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", 22 | "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 23 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "from transformers import AutoModelForTokenClassification, AutoTokenizer\n", 29 | "import torch\n", 30 | "from promptedgraphs.models import EntityReference\n", 31 | "from typing import Dict, List\n", 32 | "from collections import defaultdict\n", 33 | "import re\n", 34 | "import re\n", 35 | "from promptedgraphs.vis import render_entities\n", 36 | "\n", 37 | "# Load the BERT model and tokenizer\n", 38 | "model_name = \"dbmdz/bert-large-cased-finetuned-conll03-english\"\n", 39 | "model = AutoModelForTokenClassification.from_pretrained(model_name)\n", 40 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 41 | "\n", 42 | "\n", 43 | "async def extract_entities_bert(\n", 44 | " text: str, labels: Dict[str, str]\n", 45 | ") -> List[EntityReference]:\n", 46 | " # Tokenize the text and convert to tensor\n", 47 | " inputs = tokenizer(text, return_tensors=\"pt\", padding=True, truncation=True)\n", 48 | "\n", 49 | " # Predict entities using BERT\n", 50 | " with torch.no_grad():\n", 51 | " outputs = model(**inputs).logits\n", 52 | " predictions = torch.argmax(outputs, dim=2)\n", 53 | "\n", 54 | " # Map predictions to entity labels\n", 55 | " tokenized_text = tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0])\n", 56 | " entities = defaultdict(list)\n", 57 | "\n", 58 | " prev_token_label = \"0\"\n", 59 | " start_char = 0\n", 60 | " for token, prediction in zip(tokenized_text, predictions[0].numpy()):\n", 61 | " label = model.config.id2label[prediction]\n", 62 | " text_span = tokenizer.convert_tokens_to_string([token])\n", 63 | " if label != \"O\": # O means no entity\n", 64 | " if prev_token_label == label: # combine continuous labels\n", 65 | " entities[label][-1][\"text\"] += \" \" + token\n", 66 | " entities[label][-1][\"end\"] += len(text_span) + 1\n", 67 | " else:\n", 68 | " entities[label].append(\n", 69 | " {\n", 70 | " \"text\": token,\n", 71 | " \"start\": start_char,\n", 72 | " \"end\": start_char + len(text_span),\n", 73 | " }\n", 74 | " )\n", 75 | " prev_token_label = label\n", 76 | " if token not in {\"[CLS]\", \"[SEP]\"}:\n", 77 | " start_char += len(text_span) + 1\n", 78 | "\n", 79 | " # convert to EntityReference\n", 80 | " 
processed_entities = []\n", 81 | " for label, tokens in entities.items():\n", 82 | " for m in tokens:\n", 83 | " for match in re.finditer(m[\"text\"], text):\n", 84 | " entity = EntityReference(\n", 85 | " start=match.start(),\n", 86 | " end=match.end(),\n", 87 | " text=m[\"text\"],\n", 88 | " label=label,\n", 89 | " )\n", 90 | " processed_entities.append(entity)\n", 91 | "\n", 92 | " return processed_entities" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 84, 98 | "id": "3263b6e1-913c-4425-bf85-8108ce04a76e", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "text = \"I am in Kansas, my favorite actor is Matt Damon and I live in North Dakota\"\n", 103 | "entities = await extract_entities_bert(text, labels=labels)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 85, 109 | "id": "c8104cf9-3d15-41d5-b603-0bcab34b3188", 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/html": [ 115 | "
I am in \n", 116 | "\n", 117 | " Kansas\n", 118 | " I-LOC\n", 119 | "\n", 120 | ", my favorite actor is \n", 121 | "\n", 122 | " Matt Damon\n", 123 | " I-PER\n", 124 | "\n", 125 | " and I live in \n", 126 | "\n", 127 | " North Dakota\n", 128 | " I-LOC\n", 129 | "\n", 130 | "
" 131 | ], 132 | "text/plain": [ 133 | "" 134 | ] 135 | }, 136 | "metadata": {}, 137 | "output_type": "display_data" 138 | } 139 | ], 140 | "source": [ 141 | "render_entities(text, entities)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 86, 147 | "id": "4cb7f178-bd63-4fca-9348-22ff6e2d4264", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "[EntityReference(start=8, end=14, label='I-LOC', text='Kansas', reason=None),\n", 154 | " EntityReference(start=62, end=74, label='I-LOC', text='North Dakota', reason=None),\n", 155 | " EntityReference(start=37, end=47, label='I-PER', text='Matt Damon', reason=None)]" 156 | ] 157 | }, 158 | "execution_count": 86, 159 | "metadata": {}, 160 | "output_type": "execute_result" 161 | } 162 | ], 163 | "source": [ 164 | "entities" 165 | ] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "Python 3 (ipykernel)", 171 | "language": "python", 172 | "name": "python3" 173 | }, 174 | "language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.10.12" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 5 189 | } 190 | -------------------------------------------------------------------------------- /assets/brainstorm-examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/assets/brainstorm-examples.png -------------------------------------------------------------------------------- /assets/displacy-sentiment-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/assets/displacy-sentiment-example.png -------------------------------------------------------------------------------- /examples/Extract structured data from text.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Extract structured data from text" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from pydantic import BaseModel, Field\n", 17 | "\n", 18 | "from promptedgraphs.config import Config, load_config\n", 19 | "from promptedgraphs.extraction.data_from_text import data_from_text\n", 20 | "\n", 21 | "\n", 22 | "class UserIntent(BaseModel):\n", 23 | " \"\"\"The UserIntent entity, representing the canonical description of what a user desires to achieve in a given conversation.\"\"\"\n", 24 | "\n", 25 | " intent_name: str = Field(\n", 26 | " title=\"Intent Name\",\n", 27 | " description=\"Canonical name of the user's intent\",\n", 28 | " examples=[\n", 29 | " \"question\",\n", 30 | " \"command\",\n", 31 | " \"clarification\",\n", 32 | " \"chit_chat\",\n", 33 | " \"greeting\",\n", 34 | " \"feedback\",\n", 35 | " \"nonsensical\",\n", 36 | " \"closing\",\n", 37 | " \"harrassment\",\n", 38 | " \"unknown\",\n", 39 | " ],\n", 40 | " )\n", 41 | " description: str | None = Field(\n", 42 | " title=\"Intent Description\",\n", 43 | " description=\"A detailed explanation 
of the user's intent\",\n", 44 | " )" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "intent_name='question' description='How can I learn more about your product?'\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "load_config()\n", 62 | "\n", 63 | "msg = \"\"\"How can I learn more about your product?\"\"\"\n", 64 | "async for intent in data_from_text(text=msg, output_type=UserIntent, config=Config()):\n", 65 | " print(intent)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 6, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "class UserTask(BaseModel):\n", 75 | " \"\"\"A specific TODO item\"\"\"\n", 76 | "\n", 77 | " task_name: str = Field(\n", 78 | " title=\"Task Name\",\n", 79 | " description=\"Canonical name of the user's task, usually a verb\",\n", 80 | " )" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 7, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "task_name='send an email'\n", 93 | "task_name='buy groceries'\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "msg = \"\"\"It's a busy day, I need to send an email and to buy groceries\"\"\"\n", 99 | "async for intent in data_from_text(text=msg, output_type=UserTask, config=Config()):\n", 100 | " print(intent)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python 3 (ipykernel)", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.10.12" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 4 132 | } 133 | -------------------------------------------------------------------------------- /examples/data_graph.py: -------------------------------------------------------------------------------- 1 | """The goal of this example is to construct a data graph from an OpenAPI spec. 2 | 3 | 1. Ensure all outputs of each API call are typed using pydantic 2 | 2. Ensure all inputs of each API call are typed using pydantic 5 | 3. Ensure all endpoints are typed using pydantic and contain all inputs and outputs 6 | 4. Map common fields between inputs and outputs to create a SourceGraph of (object -> property) pairs 7 | 5. Label property nodes of the SourceGraph as {'auth', 'query', 'data', 'is_unique_id', ...} 8 | 6. Create a DataGraph by normalizing the objects and properties of the SourceGraph 9 | 1. what are unique ids? 10 | 2. core properties 11 | 7. Create a DataGraph from the SourceGraph 12 | """ 13 | -------------------------------------------------------------------------------- /examples/de_chatintents.py: -------------------------------------------------------------------------------- 1 | """ 2 | Here are some messages a chatbot might receive and their categories 3 | 4 | 1. **Questions and Requests:** 5 | - "How can I help you?" 6 | - "Tell me a joke." 7 | - "What's the weather like tomorrow?" 8 | - "How does this work?" 9 | - "Where can I find more information?" 10 | 11 | 2.
**Greetings and Salutations:** 12 | - "Hello!" 13 | - "Hi there!" 14 | - "Good morning." 15 | - "Hey, Chatbot!" 16 | 17 | 3. **Feedback and Complaints:** 18 | - "You're not very helpful." 19 | - "Thanks for the assistance." 20 | - "That wasn't the answer I was looking for." 21 | 22 | 4. **Technical Queries:** 23 | - "Can you integrate with my CRM?" 24 | - "Do you support multi-language queries?" 25 | - "How do I reset my password?" 26 | 27 | 5. **Personal Interaction:** 28 | - "How are you today?" 29 | - "Tell me something about yourself." 30 | - "Do you like pizza?" 31 | 32 | 6. **Random Messages and Tests:** 33 | - "asdfghjkl" 34 | - "Testing... 1, 2, 3." 35 | - "Is anyone there?" 36 | 37 | 7. **Commands:** 38 | - "Show me my account balance." 39 | - "Start a new session." 40 | - "Cancel my order." 41 | 42 | 8. **Seeking Recommendations:** 43 | - "Can you suggest a good book?" 44 | - "Where should I go for dinner?" 45 | 46 | 9. **Closing the Conversation:** 47 | - "Goodbye." 48 | - "Thanks for your help!" 49 | - "Talk to you later." 50 | 51 | 10. **Emotional or Philosophical Queries:** 52 | - "Do you have feelings?" 53 | - "What's the meaning of life?" 54 | - "Can you fall in love?" 55 | 56 | 11. **Contextual or Continued Conversations:** 57 | - "Based on that, what should I do next?" 58 | - "Tell me more." 59 | 60 | 12. **Interactive Actions:** 61 | - "[User sends an image]" 62 | - "Can you play music?" 63 | - "Translate this for me." 64 | 65 | 13. **Clarifications:** 66 | - "I didn't understand that." 67 | - "Can you explain that in simpler terms?" 68 | - "Did you mean...?" 69 | 70 | 14. **Challenging the Chatbot:** 71 | - "Are you smarter than a human?" 72 | - "I bet you can't answer this!" 73 | 74 | 15. **Feedback on Answers:** 75 | - "That was helpful, thanks!" 76 | - "That's not right. Try again." 
77 | 78 | """ 79 | import asyncio 80 | from pprint import pprint 81 | 82 | import tqdm 83 | from pydantic import BaseModel, Field 84 | 85 | from promptedgraphs.config import Config, load_config 86 | from promptedgraphs.extraction.data_from_text import data_from_text 87 | 88 | 89 | async def main(): 90 | load_config() 91 | 92 | ### Information Extraction of Chat intent 93 | messages = [ 94 | "How can I help you?", # Questions and Requests 95 | "Hello!", # Greetings and Salutations 96 | "You're not very helpful.", # Feedback and Complaints 97 | "Can you integrate with my CRM?", # Technical Queries 98 | "How are you today?", # Personal Interaction 99 | "asdfghjkl", # Random Messages and Tests 100 | "Show me my account balance.", # Commands 101 | "Can you suggest a good book?", # Seeking Recommendations 102 | "Goodbye.", # Closing the Conversation 103 | "Do you have feelings?", # Emotional or Philosophical Queries 104 | "Based on that, what should I do next?", # Contextual or Continued Conversations 105 | "[User sends an image]", # Interactive Actions 106 | "I didn't understand that.", # Clarifications 107 | "Are you smarter than a human?", # Challenging the Chatbot 108 | "That was helpful, thanks!", # Feedback on Answers 109 | ] 110 | 111 | class UserIntent(BaseModel): 112 | """The UserIntent entity, representing the canonical description of what a user desires to achieve in a given conversation.""" 113 | 114 | intent_name: str = Field( 115 | title="Intent Name", 116 | description="Canonical name of the user's intent", 117 | examples=[ 118 | "question", 119 | "command", 120 | "clarification", 121 | "chit_chat", 122 | "greeting", 123 | "feedback", 124 | "nonsensical", 125 | "closing", 126 | "harassment", 127 | "unknown", 128 | ], 129 | ) 130 | description: str | None = Field( 131 | title="Intent Description", 132 | description="A detailed explanation of the user's intent", 133 | ) 134 | 135 | intents = [] 136 | 137 | # TODO move to parallel processing across messages 138 | # Make it an async generator 139 | async def parallel_data_from_text(text): 140 | results = [] 141 | async for result in data_from_text( 142 | text=text, output_type=UserIntent, config=Config() 143 | ): 144 | results.append((result, text)) 145 | return results 146 | 147 | # Create a list of coroutine objects 148 | tasks = [parallel_data_from_text(msg) for msg in messages] 149 | 150 | # Use as_completed to yield from tasks as they complete 151 | for future in asyncio.as_completed(tasks): 152 | results = await future 153 | intents.extend(results) 154 | 155 | pprint(intents) 156 | 157 | 158 | if __name__ == "__main__": 159 | asyncio.run(main()) 160 | -------------------------------------------------------------------------------- /examples/de_nested_objects.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from promptedgraphs.config import Config, load_config 6 | from promptedgraphs.extraction.data_from_text import data_from_text 7 | from promptedgraphs.llms.openai_chat import LanguageModel 8 | 9 | 10 | class Traveler(BaseModel): 11 | name: str = Field(title="Traveler Name") 12 | age: int | None = Field(None, title="Age of Traveler") 13 | relationship: str | None = Field(None, title="Relationship with other travelers") 14 | 15 | 16 | class WorkflowState(BaseModel): 17 | """Defines the workflow state types for the AI Agent.""" 18 | 19 | travelers: list[Traveler] = Field( 20 | title="TRAVELERS", 21 | description="A list
of who is traveling, their ages, and relationships to each other", 22 | ) 23 | departure_date: str | None = Field(title="Departure Date") 24 | return_date: str | None = Field(title="Return Date") 25 | expected_length_of_trip: str | None = Field(title="Expected Length of Trip") 26 | departure_location: str | None = Field( 27 | title="DEPARTURE LOCATION", description="Where the travelers are departing from" 28 | ) 29 | budget: str | None = Field( 30 | title="BUDGET", 31 | description="Any information related to the expected costs or budget limitations for the trip", 32 | ) 33 | trip_reason: str | None = Field( 34 | title="TRIP REASON", 35 | description="The purpose of the trip, e.g., fun, work, etc.", 36 | ) 37 | locations: list[str] | None = Field( 38 | title="LOCATIONS", 39 | description="A list of all locations mentioned in the request", 40 | ) 41 | accomodations: list[str] | None = Field( 42 | title="ACCOMODATIONS", 43 | description="A list of all hotels and other accommodations mentioned in the request", 44 | ) 45 | activities: list[str] | None = Field( 46 | title="ACTIVITIES", 47 | description="A list of all activities and travel interests mentioned in the request", 48 | ) 49 | needs: str | None = Field( 50 | title="NEEDS", 51 | description="Accessibility, dietary restrictions, or medical considerations", 52 | ) 53 | accomodation_preferences: str | None = Field( 54 | title="ACCOMODATION PREFERENCES", 55 | description="Preference for types of lodging such as hotels, Airbnbs, or opinions on places to stay", 56 | ) 57 | transportation_preferences: str | None = Field( 58 | title="TRANSPORTATION PREFERENCES", 59 | description="Any preference related to how they like to travel between destinations and generally on the trip", 60 | ) 61 | interests: list[str] | None = Field( 62 | title="INTERESTS", 63 | description="Other interests mentioned that the travelers would likely enjoy doing", 64 | ) 65 | 66 | 67 | async def main(): 68 | load_config() 69 | 70 | msg = """Message from Jane, Hello fellow travelers! We're venturing to New Zealand from March 5-18th as a couple. We'll primarily be in Auckland visiting my sister, but we're hoping to explore more of the North Island. Since we're big fans of adventure sports and nature, we're thinking of places like the Tongariro Alpine Crossing or maybe Waitomo Caves. However, we're unsure about the best routes or if there are any hidden gems nearby. Any tips or suggestions? Has anyone been around those areas in March? Recommendations for cozy accommodations, local eateries, or any must-visit spots around Auckland would be greatly appreciated. Cheers!""" 71 | 72 | async for state in data_from_text( 73 | text=msg, output_type=WorkflowState, config=Config(), model=LanguageModel.GPT4 74 | ): 75 | print(state) 76 | 77 | 78 | if __name__ == "__main__": 79 | asyncio.run(main()) 80 | -------------------------------------------------------------------------------- /examples/ee/sports/README.md: -------------------------------------------------------------------------------- 1 | # A full working example for the Sports domain 2 | 3 | ## Data Sources 4 | 1. OddsAPI 5 | 2. Q4 Sports 6 | 3. Genius Sports 7 | 8 | 9 | ## Ingestion Workflow 10 | 11 | 1. Define each DataSource as a 'Tool' 12 | - properties such as name, url, openschema_url, docs, PyPI and npm repos, etc. 13 | 2. Build DataSource DAG 14 | - endpoints, models 15 | 3.
Generate OpenAPI schema 16 | - Similar to `python -m openapi_python_client.generate -i ./openapi.yaml -o ./generated` 17 | - Generate api.py and models.py for each endpoint and datasource so that each endpoint has typing and validation 18 | - Might require building `class BaseAPI` to handle auth, retries, monitoring, etc. 19 | 4. Generator 20 | - Build query plan - a dry run that iterates through each endpoint and respects dependencies and internal links 21 | - Build partial prisma.schema to save results with consistent api_endpoint linking 22 | - Update prisma.schema with new models and deploy to server 23 | 5. Deploy 24 | - Build prisma client 25 | - Build prisma schema 26 | - Deploy prisma schema 27 | 6. Runner 28 | - Run query plan and save results to prisma 29 | - On insert, send new ids to a queue for downstream processing 30 | 31 | ## Tool Linking 32 | 33 | 1. Start from a given set of tools and a target domain (and potentially other materials that define the uses of the data). 34 | 2. Propose an ontology of core entities, properties, and relationships that looks to unify the data and provides a common language for the tools to communicate. 35 | 3. Map the Tool -> DB (Table,Column) -> DSO (Entity,Property) 36 | 1. Can also have a mapping of DB:Table -> DSO:Entity 37 | 2. Can also have a mapping of DB:Column -> DSO:Property or DSO:Relationship 38 | 39 | What if the column's value specifies the type of relationship? Or if the simple co-occurrence of entity-reference columns in a query specifies a relationship? 40 | 41 | 42 | Other issues: 43 | * two columns could represent (key, value) pairs such that the key is a property of the entity and the value is the value of the property. 44 | -------------------------------------------------------------------------------- /examples/er_travel.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from promptedgraphs.config import Config, load_config 4 | from promptedgraphs.extraction.entities_from_text import entities_from_text 5 | 6 | 7 | async def main(): 8 | load_config() 9 | 10 | travel_text = """Hello fellow travelers! We're venturing to New Zealand from March 5-18th as a couple. We'll primarily be in Auckland visiting my sister, but we're hoping to explore more of the North Island. Since we're big fans of adventure sports and nature, we're thinking of places like the Tongariro Alpine Crossing or maybe Waitomo Caves. However, we're unsure about the best routes or if there are any hidden gems nearby. Any tips or suggestions? Has anyone been around those areas in March? Recommendations for cozy accommodations, local eateries, or any must-visit spots around Auckland would be greatly appreciated.
Cheers!""" 11 | itinerary_entities = { 12 | "TRAVELER": "List of travelers with their ages and relationship details.", 13 | "DATE": "Absolute or relative dates or periods", 14 | "DEPARTURE LOCATION": "Location from where the travelers will start their journey.", 15 | "LOCATIONS": "All destinations and places mentioned in the travel request.", 16 | "ACCOMODATIONS": "Details of hotels and other lodging options mentioned.", 17 | "ACTIVITIES": "List of activities and attractions highlighted in the travel plan.", 18 | "NEEDS": "Special requirements such as accessibility, dietary needs, or medical considerations.", 19 | "INTERESTS": "Additional activities or attractions the travelers might enjoy.", 20 | } 21 | 22 | async for msg in entities_from_text( 23 | travel_text, labels=itinerary_entities, config=Config(), include_reason=True 24 | ): 25 | print(msg) 26 | 27 | 28 | if __name__ == "__main__": 29 | asyncio.run(main()) 30 | -------------------------------------------------------------------------------- /examples/sentiment_analysis_of_reviews.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Text Labeling Task: Sentiment Analysis of Customer Reviews" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 10, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "Config(name=Prompted Graphs, description=From Dataset Labeling to Deployment: The Power of NLP and LLMs Combined., version=0.3.1, openai_api_key=***************************************************)" 19 | ] 20 | }, 21 | "execution_count": 10, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "from promptedgraphs.vis import render_entities\n", 28 | "from promptedgraphs.config import Config, load_config\n", 29 | "\n", 30 | "import asyncio\n", 31 | "import datetime\n", 32 | "from typing import Dict, List, NamedTuple\n", 33 | "\n", 34 | "import spacy\n", 35 | "from nltk.sentiment import SentimentIntensityAnalyzer\n", 36 | "from pydantic import BaseModel, Field\n", 37 | "\n", 38 | "from promptedgraphs.config import Config, load_config\n", 39 | "from promptedgraphs.extraction.entities_from_text import entities_from_text\n", 40 | "from promptedgraphs.generation.data_from_model import generate\n", 41 | "from promptedgraphs.llms.openai_chat import LanguageModel\n", 42 | "from promptedgraphs.llms.usage import Usage\n", 43 | "from promptedgraphs.models import EntityReference\n", 44 | "from promptedgraphs.vis import ensure_entities, render_entities\n", 45 | "\n", 46 | "# Load Config to read OPENAI_API_KEY from .env file or environment variable\n", 47 | "config = load_config()\n", 48 | "config" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 12, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "async def label_sentiment(text_of_reviews):\n", 58 | " labels = {\n", 59 | " \"POSITIVE\": \"A postive review of a product or service.\",\n", 60 | " \"NEGATIVE\": \"A negative review of a product or service.\",\n", 61 | " \"NEUTRAL\": \"A neutral review of a product or service.\",\n", 62 | " }\n", 63 | "\n", 64 | " ents = []\n", 65 | " async for msg in entities_from_text(\n", 66 | " name=\"sentiment\",\n", 67 | " description=\"Sentiment Analysis of Customer Reviews\",\n", 68 | " text=text_of_reviews,\n", 69 | " labels=labels,\n", 70 | " config=config,\n", 71 | " include_reason=True,\n", 72 | " 
):\n", 73 | " ents.append(msg)\n", 74 | " return ents\n", 75 | "\n", 76 | "\n", 77 | "text_of_reviews = \"\"\"\n", 78 | "1. \"I absolutely love this product. It's been a game changer!\"\n", 79 | "2. \"The service was quite poor and the staff was rude.\"\n", 80 | "3. \"The item is okay. Nothing special, but it gets the job done.\"\n", 81 | "\"\"\".strip()\n", 82 | "\n", 83 | "ents = await label_sentiment(text_of_reviews)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 13, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/html": [ 94 | "
1. "\n", 95 | "\n", 96 | " I absolutely love this product. It's been a game changer!\n", 97 | " POSITIVE\n", 98 | "\n", 99 | ""
2. "\n", 100 | "\n", 101 | " The service was quite poor and the staff was rude.\n", 102 | " NEGATIVE\n", 103 | "\n", 104 | ""
3. "\n", 105 | "\n", 106 | " The item is okay. Nothing special, but it gets the job done.\n", 107 | " NEUTRAL\n", 108 | "\n", 109 | ""
" 110 | ], 111 | "text/plain": [ 112 | "" 113 | ] 114 | }, 115 | "metadata": {}, 116 | "output_type": "display_data" 117 | } 118 | ], 119 | "source": [ 120 | "render_entities(\n", 121 | " text_of_reviews,\n", 122 | " ents,\n", 123 | " colors={\"POSITIVE\": \"#7aecec\", \"NEGATIVE\": \"#f44336\", \"NEUTRAL\": \"#f4f442\"},\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 14, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "[EntityReference(start=4, end=61, label='POSITIVE', text=\"I absolutely love this product. It's been a game changer!\", reason='The text expresses love for the product and describes it as a game changer.'),\n", 136 | " EntityReference(start=67, end=117, label='NEGATIVE', text='The service was quite poor and the staff was rude.', reason='The text indicates poor service and rude staff.'),\n", 137 | " EntityReference(start=123, end=183, label='NEUTRAL', text='The item is okay. Nothing special, but it gets the job done.', reason='The text describes the item as okay and mentions it gets the job done.')]" 138 | ] 139 | }, 140 | "execution_count": 14, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "# Here you can see the reasons\n", 147 | "ents" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python 3 (ipykernel)", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.10.12" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 4 172 | } 173 | -------------------------------------------------------------------------------- /promptedgraphs/__init__.py: -------------------------------------------------------------------------------- 1 | __title__ = "Prompted Graphs" 2 | __version__ = "0.4.3" 3 | __description__ = ( 4 | "From Dataset Labeling to Deployment: The Power of NLP and LLMs Combined." 
5 | ) 6 | -------------------------------------------------------------------------------- /promptedgraphs/__main__.py: -------------------------------------------------------------------------------- 1 | """This is the main entrypoint for the cli""" 2 | from promptedgraphs.cli import app 3 | 4 | if __name__ == "__main__": 5 | app() 6 | -------------------------------------------------------------------------------- /promptedgraphs/cli.py: -------------------------------------------------------------------------------- 1 | import pyfiglet 2 | from rich import print 3 | from typer import Typer 4 | 5 | from promptedgraphs import __description__ as DESCRIPTION 6 | from promptedgraphs import __title__ as NAME 7 | from promptedgraphs import __version__ as VERSION 8 | from promptedgraphs.config import load_config 9 | 10 | 11 | def banner(): 12 | return pyfiglet.figlet_format(NAME.replace("_", " ").title(), font="slant").rstrip() 13 | 14 | 15 | app = Typer(help=f"{(NAME or '').replace('_', ' ').title()} CLI") 16 | 17 | 18 | @app.command() 19 | def info(): 20 | """Prints info about the package""" 21 | print(f"{banner()}\n") 22 | print(f"{NAME}: {DESCRIPTION}") 23 | print(f"Version: {VERSION}\n") 24 | print(load_config()) 25 | 26 | 27 | @app.command() 28 | def main(): 29 | """Main Function""" 30 | print(f"{banner()}\n") 31 | print( 32 | "This is your default command-line interface. Feel free to customize it as you see fit.\n" 33 | ) 34 | -------------------------------------------------------------------------------- /promptedgraphs/code_execution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/code_execution/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/code_execution/safer_python_exec.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from logging import getLogger 3 | 4 | import black 5 | 6 | logger = getLogger(__name__) 7 | 8 | 9 | def is_potentially_unsafe(node): # sourcery skip: low-code-quality 10 | if isinstance(node, ast.Call): 11 | # Check for dangerous built-in functions 12 | unsafe_functions = {"eval", "exec", "open", "os.system"} 13 | if isinstance(node.func, ast.Name) and node.func.id in unsafe_functions: 14 | return True 15 | elif isinstance(node, (ast.Import, ast.ImportFrom)): 16 | # Disallow dynamic imports 17 | allowed_imports = {"json", "math", "__future__", "pydantic", "promptedgraphs"} 18 | if isinstance(node, ast.Import): 19 | for alias in node.names: 20 | if alias.name not in allowed_imports: 21 | return True 22 | elif isinstance(node, ast.ImportFrom): 23 | return node.module not in allowed_imports 24 | elif isinstance(node, ast.FunctionDef): 25 | # Disallow functions that use exec 26 | for stmt in node.body: 27 | if ( 28 | isinstance(stmt, ast.Expr) 29 | and isinstance(stmt.value, ast.Call) 30 | and ( 31 | isinstance(stmt.value.func, ast.Name) 32 | and stmt.value.func.id == "exec" 33 | ) 34 | ): 35 | return True 36 | elif isinstance(node, ast.Assign): 37 | # Disallow assignments of potentially unsafe values 38 | for target in node.targets: 39 | if isinstance(target, ast.Name) and target.id in { 40 | "__builtins__", 41 | "eval", 42 | "exec", 43 | "open", 44 | }: 45 | return True 46 | elif isinstance(node, ast.Attribute): 47 | # Disallow dynamic attribute access 48 | if isinstance(node.value, 
ast.Name) and node.value.id == "__import__": 49 | return True 50 | elif isinstance(node, ast.Expr): 51 | # Disallow potentially unsafe expressions 52 | if isinstance(node.value, ast.Call) and ( 53 | isinstance(node.value.func, ast.Name) and node.value.func.id == "eval" 54 | ): 55 | return True 56 | return False 57 | 58 | 59 | def format_code(code): 60 | return black.format_str(code, mode=black.FileMode()) 61 | 62 | 63 | def safer_exec(model_code): 64 | logger.warning("Executing generated data model code from datamodel_code_generator") 65 | exec_variables = {} 66 | 67 | # Parse the code and generate the AST 68 | parsed_code = ast.parse(model_code, filename="", mode="exec") 69 | 70 | # Filter out potentially unsafe constructs 71 | safe_code_body = [ 72 | node for node in parsed_code.body if not is_potentially_unsafe(node) 73 | ] 74 | safe_code_ast = ast.Module(body=safe_code_body, type_ignores=[]) 75 | 76 | # Compile the parsed code into a safe code object 77 | safe_code = compile(safe_code_ast, filename="data_model.py", mode="exec") 78 | 79 | # Execute the safe code within a restricted environment 80 | exec(safe_code, exec_variables) 81 | 82 | return exec_variables 83 | -------------------------------------------------------------------------------- /promptedgraphs/config.py: -------------------------------------------------------------------------------- 1 | """Loads the configuration file for the PromptedGraphs package.""" 2 | # Load the configuration file 3 | import os 4 | import re 5 | from dataclasses import dataclass, field 6 | from pathlib import Path 7 | 8 | from dotenv import load_dotenv 9 | 10 | from promptedgraphs import __description__ as description 11 | from promptedgraphs import __title__ as name 12 | from promptedgraphs import __version__ as version 13 | 14 | 15 | @dataclass 16 | class Config: 17 | """Configuration class for PromptedGraphs""" 18 | 19 | name: str = name 20 | description: str = description 21 | version: str = version 22 | openai_api_key: str | None = field( 23 | default_factory=lambda: os.getenv("OPENAI_API_KEY") 24 | ) 25 | ogtags_api_key: str | None = field( 26 | default_factory=lambda: os.getenv("OGTAGS_API_KEY") 27 | ) 28 | 29 | def __repr__(self): 30 | # Mask the values of the secret keys 31 | secret_keys = {"openai_api_key", "ogtags_api_key"} 32 | 33 | # Create a dictionary of all attributes to display 34 | attributes = { 35 | "name": self.name, 36 | "description": self.description, 37 | "version": self.version, 38 | } 39 | for key in secret_keys: 40 | value = getattr(self, key) 41 | attributes[key] = None if value is None else re.sub(r".", "*", value) 42 | # Generate string representation of the object 43 | attribute_strings = [ 44 | f"{key}={value}" for key, value in attributes.items() if value is not None 45 | ] 46 | return f"Config({', '.join(attribute_strings)})" 47 | 48 | 49 | def load_config() -> Config: 50 | load_dotenv( 51 | verbose=True, 52 | dotenv_path=Path(__file__).parent.joinpath(".env"), 53 | override=False, 54 | ) 55 | return Config( 56 | name=name, 57 | description=description, 58 | version=version, 59 | ) 60 | -------------------------------------------------------------------------------- /promptedgraphs/data_extraction.py: -------------------------------------------------------------------------------- 1 | """The purpose of this module is to extract data from text 2 | and return it as a structured object or a list of structured objects.
3 | 4 | Base case: 5 | * Provide text and specify the return type as a Pydantic model, and it returns an instance of that model 6 | 7 | """ 8 | 9 | 10 | import contextlib 11 | import json 12 | import re 13 | 14 | from pydantic import BaseModel 15 | 16 | from promptedgraphs.config import Config 17 | from promptedgraphs.llms.openai_streaming import ( 18 | GPT_MODEL, 19 | streaming_chat_completion_request, 20 | ) 21 | from promptedgraphs.models import ChatMessage 22 | from promptedgraphs.parsers import extract_partial_list 23 | 24 | SYSTEM_MESSAGE = """ 25 | You are a Qualitative User Researcher and Linguist. Your task is to extract structured data from text to be passed into python's `{name}(BaseModel)` pydantic class. 26 | Maintain as much verbatim text as possible; light edits are allowed, and feel free to remove any text that is not relevant to the label. 27 | If there is no information applicable for a particular field, use the value `==NA==`. 28 | 29 | In particular, look for the following fields: {label_list}. 30 | 31 | {label_definitions} 32 | """ 33 | 34 | 35 | def add_space_before_capital(text): 36 | return re.sub(r"(?<=[a-z])(?=[A-Z])", " ", text) 37 | 38 | 39 | def camelcase_to_words(text): 40 | if len(text) <= 1: 41 | return text.replace("_", " ") 42 | if text.lower() == text: 43 | return text.replace("_", " ") 44 | return add_space_before_capital(text[0].upper() + text[1:]).replace("_", " ") 45 | 46 | 47 | def format_fieldinfo(key, v: dict): 48 | l = f"`{key}`: {v.get('title','')} - {v.get('description','')}".strip() 49 | if v.get("annotation"): 50 | annotation = v.get("annotation") 51 | if annotation == "str": 52 | annotation = "string" 53 | l += f"\n\ttype: {annotation}" 54 | if v.get("examples"): 55 | l += f"\n\texamples: {v.get('examples')}" 56 | return l.rstrip() 57 | 58 | 59 | def create_messages(text, name, labels, custom_system_message=None): 60 | label_list = list(labels.keys()) 61 | 62 | label_definitions = """Below are definitions of each field to help you decide what kinds of structured data to extract for each label.
63 | Assume these definitions are written by an expert and follow them closely.\n\n""" + "\n".join( 64 | [f" * {format_fieldinfo(k, dict(labels[k]))}" for k in label_list] 65 | ) 66 | 67 | custom_system_message = custom_system_message or SYSTEM_MESSAGE.format( 68 | name=name, label_list=label_list, label_definitions=label_definitions 69 | ) 70 | 71 | messages = [ 72 | ChatMessage( 73 | role="system", 74 | content=custom_system_message, 75 | ) 76 | ] 77 | if text: 78 | messages.append(ChatMessage(role="user", content=text)) 79 | 80 | return messages 81 | 82 | 83 | def remove_nas(data: dict[str, any], nulls: list[str] = None): 84 | nulls = nulls or ["==NA==", "NA", "N/A", "n/a", "#N/A", "None", "none"] 85 | for k, v in data.items(): 86 | if isinstance(v, dict): 87 | remove_nas(v) 88 | elif isinstance(v, list): 89 | for i in range(len(v)): 90 | if isinstance(v[i], dict): 91 | remove_nas(v[i]) 92 | elif v[i] in nulls: 93 | v[i] = None 94 | # remove empty list items 95 | data[k] = [i for i in v if i is not None] 96 | elif v in nulls: 97 | data[k] = None 98 | return data 99 | 100 | 101 | def format_properties(properties, refs=None): 102 | refs = refs or {} 103 | d = {} 104 | for k, v in properties.items(): 105 | dtypes = None # reset per loop 106 | dtype = v.get("type") 107 | if dtype is None and v.get("anyOf"): 108 | dtypes = [t for t in v["anyOf"] if t["type"] != "null"] 109 | if len(dtypes): 110 | dtype = dtypes[0]["type"] 111 | d[k] = {"type": dtype or "string"} 112 | if d[k]["type"] == "array": 113 | d[k][ 114 | "description" 115 | ] = f"{v.get('title','')} - {v.get('description', '')}".strip() 116 | 117 | if reference := v.get("items", {}).get("$ref", "/").split("/")[-1]: 118 | v2 = refs[reference] 119 | elif dtypes: 120 | v2 = dtypes[0].get("items", v) 121 | else: 122 | v2 = v 123 | 124 | item_type = v2.get("type") 125 | if item_type == "object": 126 | d[k]["items"] = { 127 | "type": v2.get("type", "object"), 128 | "description": f"{v2.get('title','')} - {v2.get('description', '')}".strip(), 129 | "properties": format_properties(v2.get("properties", {})), 130 | "required": v2.get("required", []), 131 | } 132 | else: 133 | d[k]["items"] = {"type": item_type} 134 | return d 135 | 136 | 137 | def create_functions(is_parent_list, schema, fn_name: str = None): 138 | """Note the top level cannot be an array, only nested objects can be arrays""" 139 | name, description = ( 140 | schema["title"], 141 | schema["description"], 142 | ) 143 | fn_name = fn_name or camelcase_to_words(name).replace(" ", "_").lower() 144 | if is_parent_list: 145 | return [ 146 | { 147 | "name": name, 148 | "type": "function", 149 | "description": "Extract a non-empty list of structured data from text", 150 | "parameters": { 151 | "type": "object", 152 | "properties": { 153 | fn_name: { 154 | "type": "array", 155 | "items": { 156 | "type": "object", 157 | "description": description, 158 | "properties": format_properties( 159 | schema["properties"], refs=schema.get("$defs", {}) 160 | ), 161 | "required": schema.get("required", []), 162 | }, 163 | } 164 | }, 165 | }, 166 | } 167 | ] 168 | 169 | return [ 170 | { 171 | "name": name, 172 | "type": "function", 173 | "description": description, 174 | "parameters": { 175 | "type": "object", 176 | "properties": format_properties( 177 | schema["properties"], refs=schema.get("$defs", {}) 178 | ), 179 | "required": schema.get("required", []), 180 | }, 181 | } 182 | ] 183 | 184 | 185 | async def extract_data( 186 | text: str, 187 | output_type: list[BaseModel] | BaseModel, 188 | 
config: Config, 189 | model=GPT_MODEL, 190 | temperature=0.0, 191 | custom_system_message: str | None = None, 192 | force_type: bool = True, 193 | ): 194 | if is_parent_list := str(output_type).lower().startswith("list"): 195 | assert ( 196 | len(output_type.__args__) == 1 197 | ), "Only one type argument is allowed for list types" 198 | output_type = output_type.__args__[0] 199 | assert issubclass(output_type, BaseModel), "output_type must be a Pydantic model" 200 | 201 | schema = output_type.model_json_schema() 202 | name = schema["title"] 203 | fn_name = camelcase_to_words(schema["title"]).replace(" ", "_").lower() 204 | 205 | messages = create_messages( 206 | text, 207 | schema["title"], 208 | schema["properties"], 209 | custom_system_message=custom_system_message, 210 | ) 211 | functions = create_functions(is_parent_list, schema, fn_name) 212 | 213 | count = 0 214 | payload = "" 215 | async for msg in streaming_chat_completion_request( 216 | messages=messages, 217 | functions=functions, 218 | model=model, 219 | config=config, 220 | temperature=temperature, 221 | ): 222 | if msg.data is None or msg.data == "": 223 | continue 224 | 225 | if msg.data == "[DONE]": 226 | with contextlib.suppress(json.decoder.JSONDecodeError): 227 | if is_parent_list: 228 | s = json.loads(payload).get(name, []) 229 | for data in s[count:]: 230 | try: 231 | yield output_type(**remove_nas(data)) 232 | except Exception as e: 233 | if force_type: 234 | raise e 235 | else: 236 | yield remove_nas(data) 237 | else: 238 | data = json.loads(payload) 239 | try: 240 | yield output_type(**remove_nas(data)) 241 | except Exception as e: 242 | if force_type: 243 | raise e 244 | else: 245 | yield remove_nas(data) 246 | 247 | break 248 | 249 | # TODO try catch for malformed json 250 | if msg.event == "error": 251 | print("ERROR", msg.data) 252 | break 253 | 254 | data = json.loads(msg.data) 255 | 256 | choices = data.get("choices") 257 | if choices is None: 258 | continue 259 | 260 | delta = choices[0].get("delta") 261 | 262 | # TODO rewrite this to be more robust streaming json parser 263 | payload += delta.get("function_call", {}).get("arguments", "") 264 | 265 | if is_parent_list: 266 | s = extract_partial_list(payload, key=fn_name) 267 | if s is None or len(s) == 0 or len(s) <= count: 268 | continue 269 | 270 | for data in s[count:]: 271 | yield output_type(**remove_nas(data)) 272 | 273 | count = len(s) 274 | -------------------------------------------------------------------------------- /promptedgraphs/data_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/data_modeling/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/data_modeling/logical.py: -------------------------------------------------------------------------------- 1 | def logical(data_objects: list[dict], entity_relationship_graph: dict) -> dict: 2 | """Transforms a conceptual model into a logical schema with types and constraints. 3 | 4 | Args: 5 | data_objects (List[dict]): A list of data objects. 6 | entity_relationship_graph (dict): The entity-relationship graph derived from the conceptual model. 7 | 8 | Returns: 9 | Dict: A logical schema representing the data model with types and constraints. 
10 | """ 11 | -------------------------------------------------------------------------------- /promptedgraphs/data_modeling/physical.py: -------------------------------------------------------------------------------- 1 | def physical(entity_relationship_graph: dict, schema: dict) -> str: 2 | """Generates a physical database schema in third-normal form from an ER graph and logical schema. 3 | 4 | Args: 5 | entity_relationship_graph (dict): The entity-relationship graph. 6 | schema (dict): The logical schema. 7 | 8 | Returns: 9 | str: The physical database schema in third-normal form. 10 | """ 11 | -------------------------------------------------------------------------------- /promptedgraphs/entity_linking/README.md: -------------------------------------------------------------------------------- 1 | # Entity Linking 2 | 3 | Coreference resolution, entity linking, and entity extraction 4 | 5 | ``` 6 | pip install -U spacy-experimental 7 | ``` -------------------------------------------------------------------------------- /promptedgraphs/entity_linking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/entity_linking/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/entity_linking/link.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def link( 5 | er_graph: dict, data_to_schema_map: dict, db_connections: dict[str, Any] 6 | ) -> dict: 7 | """Generates likely links from an ER graph to specific data in the database. 8 | 9 | Args: 10 | er_graph (Dict): The Entity-Relationship Property graph. 11 | data_to_schema_map (Dict): The map created from the schema_to_schema function. 12 | db_connections (Dict[str, Any]): Database connections for querying candidates. 13 | 14 | Returns: 15 | Dict: A mapping of entities in the ER graph to database records. 16 | """ 17 | -------------------------------------------------------------------------------- /promptedgraphs/entity_linking/upsertion.py: -------------------------------------------------------------------------------- 1 | def upsertion(*args, **kwargs) -> None: 2 | """Performs an efficient upsertion step into the database using the links from the `link` step. 3 | 4 | This function's arguments and return type will be highly dependent on the specifics of your database schema and the way you've structured your entity linking outputs. 5 | """ 6 | -------------------------------------------------------------------------------- /promptedgraphs/entity_recognition.py: -------------------------------------------------------------------------------- 1 | """This module contains the entity recognition pipeines 2 | We can rely on the spacy library for this task as well 3 | as their spacy-llm library for the language model. 4 | 5 | We will be providing a custom implementation of spacy-llm 6 | in order to support the streaming responses from the API 7 | and will still make the results play well with spacy to enable 8 | custom training and other features. 9 | 10 | Ideally ER should be done locally with a transformer-style 11 | model trained on a given domain. Until we get there, 12 | we will use LLMs to label the data. 
13 | 14 | Base case: 15 | * Run entity recognition on the text, stream back spaCy-compatible entity spans 16 | 17 | Nice to have: 18 | * Run ER multiple times with noise and only choose a subset of labels 19 | * Aggregate the highest-confidence common labels 20 | 21 | Ideal case: 22 | * For labels with high disagreement, have a human label the data 23 | * Create a gold-star dataset for training a custom model 24 | * Train a custom model on the data 25 | * Evaluate the model and compare it to the LLM on performance, cost, and time 26 | 27 | """ 28 | 29 | 30 | import contextlib 31 | import json 32 | import re 33 | 34 | import tiktoken 35 | 36 | from promptedgraphs.config import Config 37 | from promptedgraphs.llms.openai_streaming import ( 38 | GPT_MODEL, 39 | streaming_chat_completion_request, 40 | ) 41 | from promptedgraphs.llms.usage import Usage 42 | from promptedgraphs.models import ChatMessage, EntityReference 43 | from promptedgraphs.parsers import extract_partial_list 44 | 45 | # Name and description of entity types 46 | 47 | 48 | SYSTEM_MESSAGE = """ 49 | You are an expert Named Entity Recognition (NER) system. 50 | Your task is to accept Text as input and extract named entities to be passed to the `{name}` function. 51 | 52 | Entities must have one of the following labels: {label_list}. 53 | If a span is not an entity, label it `==NONE==`. 54 | 55 | {label_definitions} 56 | """ 57 | 58 | 59 | def build_message_list( 60 | text: str, 61 | label_list: list[str], 62 | labels: dict[str, str], 63 | name="extract_entities", 64 | ): 65 | label_definitions = """Below are definitions of each label to help you decide what kinds of named entities to extract for each label. 66 | Assume these definitions are written by an expert and follow them closely.\n\n""" + "\n".join( 67 | [f" * {label}: {labels[label]}" for label in label_list] 68 | ) 69 | 70 | return [ 71 | ChatMessage( 72 | role="system", 73 | content=SYSTEM_MESSAGE.format( 74 | name=name, label_list=label_list, label_definitions=label_definitions 75 | ), 76 | ), 77 | ChatMessage(role="user", content=text), 78 | ] 79 | 80 | 81 | def build_function_spec( 82 | label_list: list[str], 83 | function_name="extract_entities", 84 | key="entities", 85 | description="", 86 | include_reason=True, 87 | ): 88 | spec = [ 89 | { 90 | "name": function_name, 91 | "type": "function", 92 | "description": description, 93 | "parameters": { 94 | "type": "object", 95 | "properties": { 96 | key: { 97 | "type": "array", 98 | "items": { 99 | "type": "object", 100 | "description": "The raw text and label for each entity occurrence", 101 | "properties": { 102 | "text_span": { 103 | "type": "string", 104 | "description": "The exact text of the referenced entity.", 105 | }, 106 | "is_entity": { 107 | "type": "boolean", 108 | "description": "A boolean indicating if the span is an entity label.", 109 | }, 110 | "label": { 111 | "type": "string", 112 | "enum": label_list + ["==NONE=="], 113 | "description": "The label of the entity.", 114 | }, 115 | "reason": { 116 | "type": "string", 117 | "description": "A short description of why that label was selected.", 118 | }, 119 | }, 120 | "required": ["text_span", "is_entity", "label", "reason"], 121 | }, 122 | } 123 | }, 124 | }, 125 | } 126 | ] 127 | if not include_reason: 128 | spec[0]["parameters"]["properties"][key]["items"]["properties"].pop("reason") 129 | spec[0]["parameters"]["properties"][key]["items"]["required"] = [ 130 | "text_span", 131 | "is_entity", 132 | "label", 133 | ] 134 | return spec 135 | 136
| 137 | def _format_entities(s, text): 138 | for entity in s: 139 | if entity["is_entity"]: 140 | for match in re.finditer(re.escape(entity["text_span"]), text): # escape the span so regex metacharacters match literally 141 | yield EntityReference( 142 | start=match.start(), 143 | end=match.end(), 144 | text=entity["text_span"], 145 | label=entity["label"], 146 | reason=entity.get("reason"), 147 | ) 148 | 149 | 150 | async def extract_entities( 151 | text: str, 152 | labels: dict[str, str], 153 | config: Config, 154 | name="entities", 155 | description: str | None = "Extract Entities from text", 156 | include_reason=True, 157 | model=GPT_MODEL, 158 | temperature=0.2, 159 | ): 160 | label_list = sorted(labels.keys()) 161 | messages = build_message_list( 162 | text, 163 | label_list, 164 | labels=labels, 165 | name=f"extract_{name}".lower().replace(" ", "_"), 166 | ) 167 | 168 | functions = build_function_spec( 169 | label_list, 170 | function_name=f"extract_{name}".lower().replace(" ", "_"), 171 | key=name, 172 | description=description, 173 | include_reason=include_reason, 174 | ) 175 | 176 | count = 0 177 | payload = "" 178 | usage = Usage(model=model) 179 | usage.start() 180 | async for msg in streaming_chat_completion_request( 181 | messages=messages, 182 | functions=functions, 183 | model=model, 184 | config=config, 185 | temperature=temperature, 186 | ): 187 | if msg.data is None or msg.data == "": 188 | continue 189 | 190 | if msg.data == "[DONE]": 191 | try: 192 | encoding = tiktoken.encoding_for_model(model) 193 | except KeyError: 194 | encoding = tiktoken.get_encoding("cl100k_base") 195 | usage.completion_tokens += len(encoding.encode(payload)) 196 | with contextlib.suppress(json.decoder.JSONDecodeError): 197 | s = json.loads(payload).get(name, []) 198 | for entity in _format_entities(s[count:], text): 199 | yield entity 200 | 201 | break 202 | 203 | # TODO try catch for malformed json 204 | data = json.loads(msg.data) 205 | 206 | msg_usage = data.get("usage") 207 | if msg_usage is not None: 208 | usage.prompt_tokens += msg_usage.get("prompt_tokens", 0) 209 | usage.completion_tokens += msg_usage.get("completion_tokens", 0) 210 | 211 | choices = data.get("choices") 212 | if choices is None: 213 | continue 214 | 215 | delta = choices[0].get("delta") 216 | 217 | # TODO rewrite this to be a more robust streaming json parser 218 | payload += delta.get("function_call", {}).get("arguments", "") 219 | s = extract_partial_list(payload, key=name) 220 | if s is None or len(s) == 0 or len(s) <= count: 221 | continue 222 | 223 | for entity in _format_entities(s[count:], text): 224 | yield entity 225 | count = len(s) 226 | 227 | usage.end() # calculates cost and time 228 | yield usage 229 | -------------------------------------------------------------------------------- /promptedgraphs/entity_resolution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/entity_resolution/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/entity_resolution/resolve.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | def resolve( 5 | text: str, data_models: set[BaseModel], entity_relationship_schema: dict 6 | ) -> dict: 7 | """Resolves entities and relationships from text based on a set of DataModels and an entity-relationship schema.
8 | 9 | Args: 10 | text (str): The text containing the entities to be resolved. 11 | data_models (Set[BaseModel]): The set of DataModels to use for resolving entities. 12 | entity_relationship_schema (Dict): The schema defining the relationships and properties of entities. 13 | 14 | Returns: 15 | Dict: An Entity-Relationship Property graph representing the resolved entities and their relationships. 16 | """ 17 | -------------------------------------------------------------------------------- /promptedgraphs/extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/extraction/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/extraction/data_from_text.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from logging import getLogger 4 | from typing import AsyncGenerator 5 | 6 | from pydantic import BaseModel 7 | 8 | from promptedgraphs.config import Config, load_config 9 | from promptedgraphs.generation.schema_from_model import schema_from_model 10 | from promptedgraphs.llms.chat import Chat 11 | from promptedgraphs.llms.openai_chat import LanguageModel 12 | from promptedgraphs.llms.usage import Usage 13 | from promptedgraphs.models import ChatMessage 14 | 15 | logger = getLogger(__name__) 16 | 17 | SYSTEM_MESSAGE = """ 18 | You are an information extraction expert. 19 | Your task is to extract structured information from the provided text. 20 | The extracted data should fit the provided schema to structure the data into a Pydantic BaseModel class. 21 | Do not invent data; lightly edit and reformat the data to match the schema. 22 | 23 | Assume this schema was written by an expert and follow it closely when extracting the data from the text. 24 | Return a list of extracted data that fits the schema. 25 | Always prefer a list of similar items over a single complex item. 26 | 27 | ## {name}(BaseModel) Schema: 28 | {description} 29 | 30 | {label_definitions} 31 | 32 | ## Schema of the list of data to extract 33 | ```json 34 | {schema} 35 | ``` 36 | 37 | Always just return a JSON list of the extracted data with no explanation.
38 | """ 39 | 40 | MESSAGE_TEMPLATE = """ 41 | ## Extract data from this text: 42 | 43 | {text} 44 | 45 | ## JSON list of extracted data 46 | """ 47 | 48 | 49 | async def extraction_chat( 50 | text: str, 51 | chat: Chat = None, 52 | system_message: str = SYSTEM_MESSAGE, 53 | message_template: str = MESSAGE_TEMPLATE, 54 | usage: Usage = None, 55 | message_history: list[ChatMessage] = None, 56 | force_json: bool = True, 57 | **chat_kwargs, 58 | ) -> list[BaseModel] | list[str]: 59 | usage = usage or Usage() 60 | 61 | messages = [{"role": "system", "content": system_message.strip()}] 62 | messages.extend(message_history or []) 63 | messages.append( 64 | {"role": "user", "content": message_template.format(text=text or "").strip()} 65 | ) 66 | # TODO replace with a tiktoken model and pad the message by 2x 67 | default_chat_args = { 68 | "max_tokens": 4_096, 69 | "temperature": 0.0, 70 | } 71 | if force_json: 72 | default_chat_args["response_format"] = {"type": "json_object"} 73 | 74 | response = await chat.chat_completion( 75 | messages=messages, **default_chat_args | chat_kwargs 76 | ) 77 | if hasattr(response, "usage"): 78 | usage.completion_tokens = getattr(response.usage, "completion_tokens") or 0 79 | usage.prompt_tokens = getattr(response.usage, "prompt_tokens") or 0 80 | 81 | # TODO check if the results stopped early, in that case, the list might be truncated 82 | if force_json: 83 | results = json.loads(response.choices[0].message.content) 84 | else: 85 | content = response.choices[0].message.content 86 | # TODO this will break if multiple json objects are in the response 87 | results = json.loads(content.split("```json")[1].split("```")[0]) 88 | 89 | if "items" in results: 90 | results = results["items"] 91 | return results if isinstance(results, list) else [results] 92 | 93 | 94 | async def _extract_data_from_text( 95 | text: str, 96 | output_type: BaseModel | None = None, 97 | temperature: float = 0.0, 98 | model: str = LanguageModel.GPT35_turbo, 99 | config: Config = None, 100 | usage: Usage = None, 101 | ) -> AsyncGenerator[BaseModel, BaseModel] | AsyncGenerator[str, str]: 102 | """Generate ideas using a text prompt""" 103 | 104 | if output_type is None: 105 | raise ValueError("output_type must be provided") 106 | 107 | usage = usage or Usage(model) 108 | 109 | # Make a list out of the output type 110 | class StructuredData(BaseModel): 111 | items: list[output_type] 112 | 113 | schema = schema_from_model(StructuredData) 114 | 115 | chat = Chat( 116 | config=config or Config(), 117 | model=model, 118 | ) 119 | 120 | # Format System Message 121 | item_schema = output_type.model_json_schema() 122 | labels = item_schema.get("properties", {}) 123 | label_list = sorted(labels.keys()) 124 | system_message = SYSTEM_MESSAGE.format( 125 | name=item_schema.get("title", "DataModel"), 126 | description=item_schema.get("description", ""), 127 | label_definitions="\n".join( 128 | [f" * {label}: {labels[label]}" for label in label_list] 129 | ), 130 | schema=json.dumps(schema, indent=4), 131 | ) 132 | 133 | # async def call_brainstorming_chat(): 134 | results = await extraction_chat( 135 | text=text, 136 | chat=chat, 137 | system_message=system_message, 138 | temperature=temperature, 139 | usage=usage, 140 | ) 141 | for result in results: 142 | yield result 143 | 144 | 145 | async def data_from_text( 146 | text: str, 147 | output_type: type[BaseModel] | BaseModel | str | None = None, 148 | temperature: float = 0.0, 149 | model: str = LanguageModel.GPT35_turbo, 150 | config: Config = None, 151 | 
usage: Usage = None, 152 | ) -> AsyncGenerator[BaseModel, BaseModel] | AsyncGenerator[str, str]: 153 | usage = usage or Usage(model) 154 | async for result in _extract_data_from_text( 155 | text, 156 | temperature=temperature, 157 | output_type=output_type, 158 | model=model, 159 | config=config, 160 | usage=usage, 161 | ): 162 | # TODO heal the data if it cannot be parsed 163 | if not result: 164 | continue 165 | yield output_type(**result) 166 | 167 | 168 | async def example(): 169 | from pydantic import BaseModel, Field 170 | 171 | class UserIntent(BaseModel): 172 | """The UserIntent entity, representing the canonical description of what a user desires to achieve in a given conversation.""" 173 | 174 | intent_name: str = Field( 175 | title="Intent Name", 176 | description="Canonical name of the user's intent", 177 | examples=[ 178 | "question", 179 | "command", 180 | "clarification", 181 | "chit_chat", 182 | "greeting", 183 | "feedback", 184 | "nonsensical", 185 | "closing", 186 | "harassment", 187 | "unknown", 188 | ], 189 | ) 190 | description: str | None = Field( 191 | title="Intent Description", 192 | description="A detailed explanation of the user's intent", 193 | ) 194 | 195 | load_config() 196 | 197 | msg = """How can I learn more about your product?""" 198 | async for intent in data_from_text( 199 | text=msg, output_type=UserIntent, config=Config() 200 | ): 201 | print(intent) 202 | 203 | 204 | if __name__ == "__main__": 205 | asyncio.run(example()) 206 | -------------------------------------------------------------------------------- /promptedgraphs/extraction/entities_from_text.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | from typing import AsyncGenerator, Iterator 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | from promptedgraphs.config import Config, load_config 8 | from promptedgraphs.extraction.data_from_text import data_from_text 9 | from promptedgraphs.llms.openai_chat import LanguageModel 10 | from promptedgraphs.llms.usage import Usage 11 | from promptedgraphs.models import EntityReference 12 | 13 | 14 | def _format_entities( 15 | entity: BaseModel, text: str, include_reason=False 16 | ) -> Iterator[EntityReference]: 17 | if isinstance(entity, Usage): 18 | yield entity 19 | return 20 | if entity.is_entity: 21 | for m in re.finditer(re.escape(entity.text_span), text): # escape the span so regex metacharacters match literally 22 | yield EntityReference( 23 | start=m.start(), 24 | end=m.end(), 25 | text=entity.text_span, 26 | label=entity.label, 27 | reason=entity.reason if include_reason else None, 28 | ) 29 | 30 | 31 | async def entities_from_text( 32 | text: str, 33 | labels: dict[str, str], 34 | config: Config, 35 | name="entities", 36 | description: str | None = "", 37 | include_reason=True, 38 | model=LanguageModel.GPT35_turbo, 39 | temperature=0.2, 40 | usage: Usage = None, 41 | ) -> AsyncGenerator[EntityReference | Usage, None]: 42 | usage = usage or Usage(model=model) 43 | label_list = sorted(labels.keys()) 44 | 45 | class EntityMention(BaseModel): 46 | """The raw text and label for each entity occurrence""" 47 | 48 | text_span: str = Field( 49 | title="Text Span", 50 | description="The exact text of the referenced entity.", 51 | ) 52 | is_entity: bool = Field( 53 | title="Is Entity", 54 | description="A boolean indicating if the span is an entity label.", 55 | ) 56 | label: str = Field( 57 | title="Label", 58 | description="The label of the entity.", 59 | examples=label_list + ["==NONE=="], 60 | ) 61 | if include_reason: 62 | reason: str = Field( 63
| title="Reason", 64 | description="A short description of why that label was selected.", 65 | ) 66 | 67 | EntityMention.__doc__ += f"""\nEntity Type: {name}""" if name else "" 68 | EntityMention.__doc__ += description or "" 69 | EntityMention.__doc__ += """\n 70 | Labels must be one of the following: {label_list}""".format( 71 | label_list=label_list + ["==NONE=="] 72 | ) 73 | EntityMention.__doc__ += """\n 74 | Label Definitions:\n""" + "\n".join( 75 | [f" * {label}: {labels[label]}" for label in label_list] 76 | ) 77 | usage.start() 78 | async for er in data_from_text( 79 | text=text, 80 | output_type=EntityMention, 81 | config=config or Config(), 82 | model=model, 83 | temperature=temperature, 84 | usage=usage, 85 | ): 86 | for ent in _format_entities(er, text, include_reason=include_reason): 87 | yield ent 88 | 89 | usage.end() 90 | 91 | 92 | async def example(): 93 | load_config() 94 | 95 | ### Canonical Text Labeling Task: Sentiment Analysis of Customer Reviews 96 | 97 | text_of_reviews = """ 98 | 1. "I absolutely love this product. It's been a game changer!" 99 | 2. "The service was quite poor and the staff was rude." 100 | 3. "The item is okay. Nothing special, but it gets the job done." 101 | """ 102 | 103 | labels = { 104 | "POSITIVE": "A postive review of a product or service.", 105 | "NEGATIVE": "A negative review of a product or service.", 106 | "NEUTRAL": "A neutral review of a product or service.", 107 | } 108 | 109 | ents = [] 110 | usage = Usage(model=LanguageModel.GPT35_turbo) 111 | async for msg in entities_from_text( 112 | name="sentiment", 113 | description="Sentiment Analysis of Customer Reviews", 114 | text=text_of_reviews, 115 | labels=labels, 116 | config=Config(), 117 | include_reason=False, 118 | usage=usage, 119 | ): 120 | ents.append(msg) 121 | 122 | print(ents) 123 | print("Usage:", usage) 124 | 125 | # displacy.render( 126 | # { 127 | # "text": text_of_reviews, 128 | # "ents": ents, 129 | # }, 130 | # style="ent", 131 | # jupyter=True, 132 | # manual=True, 133 | # ) 134 | 135 | 136 | if __name__ == "__main__": 137 | asyncio.run(example()) 138 | -------------------------------------------------------------------------------- /promptedgraphs/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/generation/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/generation/data_from_model.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from logging import getLogger 4 | from typing import AsyncGenerator 5 | 6 | import tqdm 7 | from pydantic import BaseModel, Field 8 | 9 | from promptedgraphs.config import Config 10 | from promptedgraphs.generation.schema_from_model import schema_from_model 11 | from promptedgraphs.llms.chat import Chat 12 | from promptedgraphs.llms.openai_chat import LanguageModel 13 | from promptedgraphs.llms.usage import Usage 14 | 15 | logger = getLogger(__name__) 16 | 17 | SYSTEM_MESSAGE = """ 18 | {role} 19 | Your task is brainstorm examples of objects to be passed into python's `{name}(BaseModel)` pydantic class. 20 | Generate a list of diverse and creative examples of objects with the following fields: {label_list} 21 | 22 | Below are definitions of each label to help aid you in creating correct examples to generate. 
23 | Assume these definitions are written by an expert and follow them closely when generating the list of examples 24 | 25 | {label_definitions} 26 | 27 | ## Schema of data models to generate 28 | ```json 29 | {schema} 30 | ``` 31 | """ 32 | 33 | MESSAGE_TEMPLATE = """ 34 | {text}{positive_examples}{negative_examples} 35 | Generate a list of {n} examples. 36 | 37 | YOU MUST RETURN A JSON LIST OF EXAMPLES! Not a single example! 38 | """ 39 | 40 | SINGLE_MESSAGE_TEMPLATE = """ 41 | {text}{positive_examples}{negative_examples} 42 | Generate a single example matching the schema. 43 | 44 | YOU MUST RETURN A JSON OBJECT! 45 | """ 46 | 47 | 48 | async def brainstorming_chat( 49 | text: str, 50 | chat: Chat = None, 51 | positive_examples: list[str | BaseModel] = None, 52 | negative_examples: list[str | BaseModel] = None, 53 | batch_size: int = 10, 54 | system_message: str = SYSTEM_MESSAGE, 55 | **chat_kwargs, 56 | ) -> list[BaseModel] | list[str]: 57 | if batch_size == 1: 58 | msg = SINGLE_MESSAGE_TEMPLATE.format( 59 | text=text or "", 60 | positive_examples=f"\n## Good Examples:\n{positive_examples}\n" 61 | if positive_examples 62 | else "", 63 | negative_examples=f"\n## Bad Examples:\n{negative_examples}\n" 64 | if negative_examples 65 | else "", 66 | ) 67 | else: 68 | msg = MESSAGE_TEMPLATE.format( 69 | text=text or "", 70 | n=batch_size, 71 | positive_examples=f"\n## Good Examples:\n{positive_examples}\n" 72 | if positive_examples 73 | else "", 74 | negative_examples=f"\n## Bad Examples:\n{negative_examples}\n" 75 | if negative_examples 76 | else "", 77 | ) 78 | 79 | # TODO replace with a tiktoken model and pad the message by 2x 80 | response = await chat.chat_completion( 81 | messages=[ 82 | {"role": "system", "content": system_message.strip()}, 83 | {"role": "user", "content": msg.strip()}, 84 | ], 85 | **{ 86 | **{ 87 | "max_tokens": 4_096, 88 | "temperature": 0.0, 89 | "response_format": {"type": "json_object"}, 90 | }, 91 | **chat_kwargs, 92 | }, 93 | ) 94 | # TODO check if the results stopped early, in that case, the list might be truncated 95 | results = json.loads(response.choices[0].message.content) 96 | if "items" in results: 97 | results = results["items"] 98 | return results if isinstance(results, list) else [results] 99 | 100 | 101 | async def _brainstorm( 102 | text: str, 103 | n: int = 10, 104 | output_type: BaseModel | None = None, 105 | positive_examples: list[str | BaseModel] = None, 106 | negative_examples: list[str | BaseModel] = None, 107 | batch_size: int = 10, 108 | max_workers: int = 5, 109 | temperature: float = 0.6, 110 | model: str = LanguageModel.GPT35_turbo, 111 | role: str = "You are a Creative Director and Ideation Specialist.", 112 | config: Config = None, 113 | usage: Usage = None, 114 | ) -> AsyncGenerator[BaseModel, BaseModel] | AsyncGenerator[str, str]: 115 | """Generate ideas using a text prompt""" 116 | config = config or Config() 117 | positive_examples = positive_examples or [] 118 | negative_examples = negative_examples or [] 119 | 120 | # Make a list out of the output type 121 | if n > 1: 122 | 123 | class Examples(BaseModel): 124 | items: list[output_type] 125 | 126 | schema = schema_from_model(Examples) 127 | else: 128 | schema = schema_from_model(output_type) 129 | 130 | batch_size = min(batch_size, n) 131 | 132 | chat = Chat( 133 | config=config or Config(), 134 | model=model, 135 | ) 136 | 137 | # Format System Message 138 | item_schema = output_type.model_json_schema() 139 | labels = item_schema.get("properties", {}) 140 | label_list = 
sorted(labels.keys()) 141 | system_message = SYSTEM_MESSAGE.format( 142 | role=role, 143 | name=item_schema.get("title", "DataModel"), 144 | label_list=label_list, 145 | label_definitions="\n".join( 146 | [f" * {label}: {labels[label]}" for label in label_list] 147 | ), 148 | schema=json.dumps(schema, indent=4), 149 | ) 150 | 151 | if n % batch_size != 0: 152 | max_workers = max(1, min(max_workers, 1 + n // batch_size)) 153 | else: 154 | max_workers = max(1, min(max_workers, n // batch_size)) 155 | 156 | if max_workers == 1: 157 | results = await brainstorming_chat( 158 | text=text, 159 | chat=chat, 160 | positive_examples=positive_examples, 161 | negative_examples=negative_examples, 162 | batch_size=n, 163 | system_message=system_message, 164 | temperature=temperature, 165 | ) 166 | for result in results: 167 | yield result 168 | return 169 | 170 | async def call_brainstorming_chat(): 171 | results = await brainstorming_chat( 172 | text=text, 173 | chat=chat, 174 | positive_examples=positive_examples, 175 | negative_examples=negative_examples, 176 | batch_size=batch_size, 177 | system_message=system_message, 178 | temperature=temperature, 179 | ) 180 | return results 181 | 182 | # Create a list of coroutine objects 183 | tasks = [call_brainstorming_chat() for _ in range(max_workers)] 184 | 185 | # Use as_completed to yield from tasks as they complete 186 | for future in asyncio.as_completed(tasks): 187 | results = await future 188 | for result in results: 189 | yield result 190 | 191 | 192 | async def generate( 193 | text: str, 194 | n: int = 10, 195 | output_type: type[BaseModel] | BaseModel | str | None = None, 196 | positive_examples: list[str | BaseModel] = None, 197 | negative_examples: list[str | BaseModel] = None, 198 | batch_size: int = 10, 199 | max_workers: int = 5, 200 | temperature: float = 0.6, 201 | model: str = LanguageModel.GPT35_turbo, 202 | role: str = "You are a Creative Director and Ideation Specialist.", 203 | config: Config = None, 204 | usage: Usage = None, 205 | ) -> AsyncGenerator[BaseModel, BaseModel] | AsyncGenerator[str, str]: 206 | yield_count = 0 207 | usage = usage or Usage(model=model) 208 | while yield_count < n: 209 | new_yield_count = 0 210 | async for result in _brainstorm( 211 | text, 212 | n=n - yield_count, 213 | temperature=temperature, 214 | output_type=output_type, 215 | positive_examples=positive_examples, 216 | negative_examples=negative_examples, 217 | batch_size=batch_size, 218 | max_workers=max_workers, 219 | model=model, 220 | role=role, 221 | config=config, 222 | usage=usage, 223 | ): 224 | yield output_type(**result) 225 | new_yield_count += 1 226 | yield_count += 1 227 | if yield_count >= n: 228 | break 229 | if new_yield_count == 0: 230 | print("No more results") 231 | break 232 | 233 | 234 | async def example(): 235 | class BusinessIdea(BaseModel): 236 | """A business idea generated using the Jobs-to-be-done framework 237 | For example "We help [adj] [target_audience] do [action] so they can [benefit or do something else]" 238 | """ 239 | 240 | target_audience: str = Field(title="Target Audience") 241 | action: str = Field(title="Action") 242 | benefit: str = Field(title="Benefit or next action") 243 | adj: str | None = Field( 244 | title="Adjective", 245 | description="Optional adjective describing the target audience's condition", 246 | ) 247 | 248 | ideas = [] 249 | ittr = tqdm.tqdm(total=100, desc="Brainstorming Ideas") 250 | async for idea in generate( 251 | text="Generate unique business ideas for a new startup in scuba diving 
sector", 252 | n=100, 253 | output_type=BusinessIdea, 254 | max_workers=20, 255 | batch_size=5, 256 | ): 257 | ideas.append(idea) 258 | ittr.update(1) 259 | print(f"{len(ideas)} {BusinessIdea.__name__} ideas generated") 260 | return ideas 261 | 262 | 263 | if __name__ == "__main__": 264 | asyncio.run(example()) 265 | -------------------------------------------------------------------------------- /promptedgraphs/generation/schema_from_data.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | from typing import Any, Dict, List 4 | 5 | import tqdm 6 | from pydantic import BaseModel, Field 7 | 8 | from promptedgraphs.generation.data_from_model import generate 9 | from promptedgraphs.llms.chat import Chat 10 | from promptedgraphs.statistical.data_analysis import ( 11 | can_cast_to_ints_without_losing_precision_np_updated, 12 | ) 13 | 14 | 15 | class JSONSchemaTitleDescription(BaseModel): 16 | title: str = Field( 17 | title="Title", 18 | description="A short description of the provided schema. PascalCase is recommended.", 19 | ) 20 | description: str = Field( 21 | title="Description", 22 | description="The concise description of the schema.", 23 | ) 24 | 25 | 26 | SYSTEM_MESSAGE = """You are an ontologist and JSON Schema expert. 27 | Your task is to return a partial JSON schema object with only the fields 'title' and 'description' that are best used to describe the data provided by the user. 28 | 29 | Only return a JSON object in the form 30 | { 31 | "title": "a short name describing the user's object" 32 | "description": "a concise description of the object" 33 | } 34 | """ 35 | 36 | MESSAGE_TEMPLATE = """ 37 | ## Schema of the data object 38 | ```json 39 | {schema} 40 | ``` 41 | 42 | ## Example of the data object 43 | ```json 44 | {example} 45 | ``` 46 | This object is nested within a parent object and is accessed using the following path: `{path}` 47 | 48 | This object has the following sibling properties: {sibling_properties} 49 | 50 | Use the path and sibling properties to help determine the best title and description for the object. 51 | You do not need to include the path or sibling properties in the description. 52 | """ 53 | 54 | 55 | async def add_schema_titles_and_descriptions( 56 | schema: dict, 57 | parent_keys: list[str] = None, 58 | chat: Chat | None = None, 59 | ittr=None, 60 | sibling_properties: list[str] = None, 61 | ): 62 | """Adds names and descriptions to a schema based a language model. 
63 | Recursively traverse the schema and add title and descriptions 64 | starting with the leaf nodes and working up to the root 65 | each time adding a name and description to the schema 66 | """ 67 | if ittr is None: 68 | ittr = tqdm.tqdm(desc="Adding Titles and Descriptions") 69 | 70 | chat = chat or Chat() 71 | parent_keys = parent_keys or [] 72 | if schema.get("type") == "array": 73 | await add_schema_titles_and_descriptions( 74 | schema.get("items", {}), 75 | parent_keys, 76 | chat=chat, 77 | ittr=ittr, 78 | sibling_properties=[], 79 | ) 80 | elif schema.get("type") != "object": 81 | # At the leaf node no need to add a title or description 82 | return 83 | 84 | properties = set(schema.get("properties", {}).keys()) 85 | for key, value in schema.get("properties", {}).items(): 86 | await add_schema_titles_and_descriptions( 87 | value, 88 | parent_keys + [key], 89 | chat=chat, 90 | ittr=ittr, 91 | sibling_properties=[parent_keys + [p] for p in properties - {key}], 92 | ) 93 | 94 | # Make sure the object has a title and description 95 | if schema.get("title") and schema.get("description"): 96 | return 97 | 98 | example = schema.get("example") or {} 99 | path = ".".join(parent_keys) or "base object" 100 | sibling_properties = ( 101 | "\n * " + "\n * ".join([".".join(sp) for sp in sibling_properties]) 102 | if sibling_properties 103 | else "none" 104 | ) 105 | 106 | response = await chat.chat_completion( 107 | [ 108 | { 109 | "role": "system", 110 | "content": SYSTEM_MESSAGE.strip(), 111 | }, 112 | { 113 | "role": "user", 114 | "content": MESSAGE_TEMPLATE.format( 115 | schema=json.dumps( 116 | { 117 | "type": schema.get("type"), 118 | "properties": schema.get("properties", {}), 119 | }, 120 | indent=4, 121 | ), 122 | example=json.dumps(example, indent=4) 123 | if example 124 | else "none available", 125 | path=path, 126 | sibling_properties="\n" + json.dumps(sibling_properties, indent=4), 127 | ), 128 | }, 129 | ], 130 | **{ 131 | "max_tokens": 4_096, 132 | "temperature": 0.0, 133 | "response_format": {"type": "json_object"}, 134 | }, 135 | ) 136 | meta_data = json.loads(response.choices[0].message.content) 137 | schema["title"] = schema.get("title") or meta_data.get("title") 138 | if schema["title"]: # to PascalCase 139 | schema["title"] = schema["title"].strip().replace(" ", "") 140 | schema["title"] = schema["title"][0].upper() + schema["title"][1:] 141 | schema["description"] = schema.get("description") or meta_data.get("description") 142 | ittr.update(1) 143 | 144 | 145 | def schema_from_data(data_samples: List[Dict[str, Any]]) -> Dict[str, Any]: 146 | """Generates a minimal schema based on provided data samples. 147 | 148 | Args: 149 | data_samples (List[dict]): A list of data samples to generate the schema from. 150 | 151 | Returns: 152 | dict: The generated minimal schema as a dictionary. 
153 | """ 154 | if not data_samples: 155 | return {} 156 | 157 | schema: Dict[str, Any] = {"type": "object", "properties": {}} 158 | required_keys = set(data_samples[0].keys()) 159 | 160 | for sample in data_samples: 161 | required_keys &= set(sample.keys()) 162 | for key, value in sample.items(): 163 | if key in schema["properties"]: 164 | schema["properties"][key] = merge_types( 165 | schema["properties"][key], infer_type(value) 166 | ) 167 | else: 168 | schema["properties"][key] = infer_type(value) 169 | 170 | schema["required"] = sorted(required_keys) 171 | 172 | return schema 173 | 174 | 175 | def infer_type(value: Any) -> Dict[str, Any]: 176 | """Infers the type of a value and returns the corresponding schema. 177 | 178 | Args: 179 | value: The value to infer the type from. 180 | 181 | Returns: 182 | dict: The inferred schema for the value. 183 | """ 184 | if isinstance(value, bool): 185 | return {"type": "boolean", "example": value} 186 | elif isinstance(value, int): 187 | return {"type": "integer", "example": value} 188 | elif isinstance(value, float): 189 | if can_cast_to_ints_without_losing_precision_np_updated([value]): 190 | return {"type": "integer", "example": value} 191 | return {"type": "number", "example": value} 192 | elif isinstance(value, str): 193 | with contextlib.suppress(ValueError): 194 | value_float = float(value) 195 | if can_cast_to_ints_without_losing_precision_np_updated([value]): 196 | return {"type": "integer", "example": int(value_float)} 197 | return {"type": "number", "example": value_float} 198 | return {"type": "string", "example": value} 199 | elif isinstance(value, list): 200 | if not value: 201 | return {"type": "array", "items": {}, "example": value} 202 | items_type = infer_type(value[0]) 203 | return {"type": "array", "items": items_type, "example": value} 204 | elif isinstance(value, dict): 205 | properties = {} 206 | example = {} 207 | for key, val in value.items(): 208 | properties[key] = infer_type(val) 209 | example[key] = val 210 | return {"type": "object", "properties": properties, "example": example} 211 | else: 212 | return {} 213 | 214 | 215 | def merge_types(type1: Dict[str, Any], type2: Dict[str, Any]) -> Dict[str, Any]: 216 | """Merges two types to ensure the resulting type is the minimal union of the types found. 217 | 218 | Args: 219 | type1 (dict): The first type to merge. 220 | type2 (dict): The second type to merge. 221 | 222 | Returns: 223 | dict: The merged type. 
224 | """ 225 | type1_types = type1.get("anyOf", [type1.get("type")] if type1.get("type") else []) 226 | type2_types = type2.get("anyOf", [type2.get("type")] if type2.get("type") else []) 227 | 228 | type1_types = { 229 | t if isinstance(t, str) else t.get("type") 230 | for t in type1_types 231 | if isinstance(t, str) or t.get("type") 232 | } 233 | type2_types = { 234 | t if isinstance(t, str) else t.get("type") 235 | for t in type2_types 236 | if isinstance(t, str) or t.get("type") 237 | } 238 | 239 | # if type1 is a subtype of type2, return type2 240 | if type1_types.issubset(type2_types): 241 | type1_types = type2_types 242 | elif type2_types.issubset(type1_types): 243 | type2_types = type1_types 244 | 245 | if type1_types != type2_types: 246 | # type_diff = (type1_types - type2_types) | (type2_types - type1_types) 247 | # type_intersection = type1_types & type2_types 248 | type_union = type1_types | type2_types 249 | merged_example = type1.get("example", type2.get("example")) 250 | return { 251 | "anyOf": [{"type": t} for t in sorted(type_union)], 252 | "example": merged_example, 253 | } 254 | 255 | if "object" in type1_types: 256 | merged_properties = { 257 | **type1.get("properties", {}), 258 | **type2.get("properties", {}), 259 | } 260 | merged_example = type1.get("example", type2.get("example", {})) 261 | return { 262 | "type": "object", 263 | "properties": merged_properties, 264 | "example": merged_example, 265 | } 266 | elif "array" in type1_types: 267 | merged_items = merge_types(type1["items"], type2["items"]) 268 | merged_example = type1.get("example", type2.get("example", [])) 269 | return {"type": "array", "items": merged_items, "example": merged_example} 270 | return type1 271 | 272 | 273 | def example(): 274 | data_samples = [ 275 | {"name": "Alice", "age": 30, "is_student": False}, 276 | {"name": "Bob", "is_student": True}, 277 | ] 278 | schema = schema_from_data(data_samples) 279 | print(schema) 280 | 281 | 282 | if __name__ == "__main__": 283 | example() 284 | -------------------------------------------------------------------------------- /promptedgraphs/generation/schema_from_model.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, RootModel 2 | 3 | def extract_references(schema: dict, references: set = None, max_depth=4) -> set: 4 | """Extracts all $ref references from a JSON schema.""" 5 | if max_depth <= 0: 6 | return references if references is not None else set() 7 | 8 | if references is None: 9 | references = set() 10 | 11 | if "$ref" in schema: 12 | references.add(schema["$ref"]) 13 | 14 | if "properties" in schema: 15 | for prop in schema["properties"].values(): 16 | extract_references(prop, references, max_depth-1) 17 | 18 | if "items" in schema: 19 | extract_references(schema["items"], references, max_depth-1) 20 | 21 | return references 22 | 23 | def resolve_references(schema: dict, schema_defs:dict=None, max_depth=4) -> dict: 24 | """Resolves $ref references from $def definitions in a JSONschema.""" 25 | if max_depth <= 0: 26 | return schema 27 | if schema_defs is None: 28 | schema_defs = schema.get('$defs') 29 | if "$ref" in schema: 30 | ref = schema["$ref"] 31 | if resolved_ref := find_ref(ref, schema_defs): 32 | schema = resolved_ref 33 | else: 34 | raise ValueError(f"Reference not found: {ref}") 35 | if "properties" in schema: 36 | schema["properties"] = { 37 | k: resolve_references(v, schema_defs=schema_defs, max_depth=max_depth-1) for k, v in schema["properties"].items() 38 | } 39 
| if "items" in schema: 40 | schema["items"] = resolve_references(schema["items"], schema_defs=schema_defs, max_depth=max_depth-1) 41 | return schema 42 | 43 | def find_ref(ref_path: str, schema_defs:dict) -> dict: 44 | ref_path = ref_path.split('/') 45 | definition_name = ref_path[-1] # Extract the last element as definition name 46 | if definition_name in schema_defs: 47 | return schema_defs[definition_name] 48 | raise ValueError(f"Definition not found: {definition_name}") 49 | 50 | 51 | def schema_from_model(data_model: list[BaseModel] | BaseModel, resolve_refs=False) -> dict: 52 | """Generates a schema from a Pydantic DataModel. 53 | 54 | Args: 55 | model (BaseModel): The Pydantic BaseModel instance to generate a schema for. 56 | 57 | Returns: 58 | dict: The generated schema as a dictionary. 59 | """ 60 | if isinstance(data_model, list) or str(data_model).startswith("list["): 61 | x = data_model.__args__[0] 62 | schema = RootModel[list[x]].model_json_schema() 63 | else: 64 | schema = data_model.model_json_schema() 65 | return resolve_references(schema) if resolve_refs else schema 66 | -------------------------------------------------------------------------------- /promptedgraphs/helpers.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def deep_merge(d1: dict[str, any], d2: dict[str, any]): 5 | for k, v in d2.items(): 6 | if k not in d1 or d1[k] is None: 7 | d1[k] = v 8 | elif isinstance(d1[k], dict) and isinstance(v, dict): 9 | deep_merge(d1[k], v) 10 | else: 11 | d1[k] += str(v) 12 | 13 | 14 | def add_space_before_capital(text): 15 | return re.sub(r"(?<=[a-z])(?=[A-Z])", " ", text) 16 | 17 | 18 | def camelcase_to_words(text): 19 | if len(text) <= 1: 20 | return text.replace("_", " ") 21 | if text.lower() == text: 22 | return text.replace("_", " ") 23 | return add_space_before_capital(text[0].upper() + text[1:]).replace("_", " ") 24 | 25 | 26 | def format_fieldinfo(key, v: dict): 27 | l = f"`{key}`: {v.get('title','')} - {v.get('description','')}".strip() 28 | if v.get("annotation"): 29 | annotation = v.get("annotation") 30 | if annotation == "str": 31 | annotation = "string" 32 | l += f"\n\ttype: {annotation}" 33 | if v.get("examples"): 34 | l += f"\n\texamples: {v.get('examples')}" 35 | return l.rstrip() 36 | 37 | 38 | def remove_nas(data: dict[str, any], nulls: list[str] = None): 39 | nulls = nulls or ["==NA==", "NA", "N/A", "n/a", "#N/A", "None", "none"] 40 | for k, v in data.items(): 41 | if isinstance(v, dict): 42 | remove_nas(v) 43 | elif isinstance(v, list): 44 | for i in range(len(v)): 45 | if isinstance(v[i], dict): 46 | remove_nas(v[i]) 47 | elif v[i] in nulls: 48 | v[i] = None 49 | # remove empty list items 50 | data[k] = [i for i in v if i is not None] 51 | elif v in nulls: 52 | data[k] = None 53 | return data 54 | -------------------------------------------------------------------------------- /promptedgraphs/llms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/llms/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/llms/anthropic_chat.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/llms/anthropic_chat.py 
-------------------------------------------------------------------------------- /promptedgraphs/llms/chat.py:
1 | # A thin Chat wrapper that selects a language model backend (currently OpenAI)
2 | 
3 | import asyncio
4 | 
5 | from promptedgraphs.config import Config
6 | from promptedgraphs.llms.openai_chat import LanguageModel as OpenAILanguageModel
7 | from promptedgraphs.llms.openai_chat import OpenAIChat
8 | 
9 | LanguageModels = OpenAILanguageModel
10 | 
11 | 
12 | class Chat:
13 |     def __init__(
14 |         self,
15 |         model: LanguageModels = LanguageModels.GPT35_turbo,
16 |         config: Config = None,
17 |         max_retries: int = 3,
18 |         timeout=60,
19 |         **kwargs,
20 |     ):
21 |         config = config or Config()
22 |         if model in OpenAILanguageModel:
23 |             self.chat = OpenAIChat(
24 |                 api_key=config.openai_api_key,
25 |                 model=model,
26 |                 max_retries=max_retries,
27 |                 timeout=timeout,
28 |                 **kwargs,
29 |             )
30 |         else:
31 |             raise NotImplementedError("Anthropic Model not implemented")
32 | 
33 |     async def chat_completion(self, messages: list[any] = None, **kwargs):
34 |         return await self.chat.chat_completion(messages=messages, **kwargs)
35 | 
36 | 
37 | async def usage_example():
38 |     chat = Chat()
39 |     messages = [
40 |         {
41 |             "role": "system",
42 |             "content": "You are an expert programmer working with a user to debug and fix code. You will write the corrected code and the user provides you with traceback information after running your code on their system. A history of errors and code changes is shown in the conversation. Please fix the code and respond with the corrected code in a code block",
43 |         },
44 |         {
45 |             "role": "assistant",
46 |             "content": "```python\nprint('Hello, World!')\n```",
47 |         },
48 |         {
49 |             "role": "user",
50 |             "content": "# ERROR and STACKTRACE\n```bash\nNameError: name 'print' is not defined\n```",
51 |         },
52 |     ]
53 |     print(await chat.chat_completion(messages=messages))
54 | 
55 | 
56 | if __name__ == "__main__":
57 |     asyncio.run(usage_example())
58 | 
-------------------------------------------------------------------------------- /promptedgraphs/llms/coding.py:
1 | # get traceback
2 | 
3 | from promptedgraphs.llms.helpers import _sync_wrapper, extract_code_blocks
4 | from promptedgraphs.models import ChatMessage
5 | 
6 | 
7 | async def fix_code(
8 |     code: str, error: str, tb: str, history: list[tuple[str, str]]
9 | ) -> tuple[str, list[tuple[str, str]]]:
10 |     """return new code plus history of code, tb pairs"""
11 | 
12 |     system_message = """You are an expert programmer working with a user to debug and fix code
13 |     you will write the corrected code and the user provides you with traceback information
14 |     after running your code on their system.
15 | 
16 |     A history of errors and code changes is shown in the conversation.
Please fix the code
17 |     and respond with the corrected code in a code block"""
18 | 
19 |     history = history or []
20 |     history.append((code, f"{error}\n\n{tb}" if error else tb or ""))
21 | 
22 |     messages = [
23 |         ChatMessage(
24 |             role="system",
25 |             content=system_message,
26 |         )
27 |     ]
28 |     for c, err in history:
29 |         messages.append(
30 |             ChatMessage(
31 |                 role="assistant",
32 |                 content=f"```python\n{c}\n```",
33 |             )
34 |         )
35 |         messages.append(
36 |             ChatMessage(
37 |                 role="user",
38 |                 content=f"# ERROR and STACKTRACE\n```bash\n{err}\n```",
39 |             )
40 |         )
41 |     data = await _sync_wrapper(messages, model="gpt-4-0613")
42 | 
43 |     # pull out the json and pydantic code blocks
44 |     content = data.get("choices")[0]["message"]["content"]
45 |     results = extract_code_blocks(content)
46 |     for r in results:
47 |         if r["block_type"] == "python":
48 |             return r["content"], history
49 |     return code, history  # fall back to the original code if the response contains no python block
-------------------------------------------------------------------------------- /promptedgraphs/llms/helpers.py:
1 | import json
2 | import re
3 | 
4 | from promptedgraphs.config import Config, load_config
5 | from promptedgraphs.llms.openai_streaming import (
6 |     GPT_MODEL,
7 |     streaming_chat_completion_request,
8 | )
9 | 
10 | 
11 | async def _sync_wrapper(messages, config: Config | None = None, model=GPT_MODEL):
12 |     payload = ""
13 |     async for event in streaming_chat_completion_request(
14 |         messages=messages,
15 |         functions=None,
16 |         config=config or load_config(),
17 |         stream=False,
18 |         model=model,
19 |     ):
20 |         if event.data:
21 |             payload += event.data
22 |         elif event.retry:
23 |             print(f"Retry: {event.retry}")
24 |     data = json.loads(payload)
25 |     return data
26 | 
27 | 
28 | def extract_code_blocks(text: str) -> list[dict[str, str]]:
29 |     pattern = re.compile(r"```(.*?)\n(.*?)```", re.DOTALL)
30 |     matches = pattern.findall(text)
31 | 
32 |     blocks = []
33 |     for match in matches:
34 |         block_type, content = match
35 |         blocks.append({"block_type": block_type.strip(), "content": content.strip()})
36 |     return blocks
37 | 
-------------------------------------------------------------------------------- /promptedgraphs/llms/openai_chat.py:
1 | import os
2 | from enum import Enum
3 | from logging import getLogger
4 | 
5 | import openai
6 | from openai import AsyncOpenAI
7 | 
8 | 
9 | # Language models that support json chat completions
10 | class LanguageModel(Enum):
11 |     GPT35_turbo = "gpt-3.5-turbo"
12 |     GPT4 = "gpt-4-0125-preview"
13 | 
14 | 
15 | class OpenAIChat:
16 |     def __init__(self, api_key: str, model: LanguageModel, **kwargs):
17 |         self.client = AsyncOpenAI(
18 |             api_key=api_key or os.environ.get("OPENAI_API_KEY"),
19 |             **kwargs,
20 |         )
21 |         self.model = model.value
22 |         self.logger = getLogger("openai_chat")
23 | 
24 |     async def chat_completion(self, messages: list[any] = None, **kwargs):  # returns the provider's chat completion response
25 |         try:
26 |             return await self.client.chat.completions.create(
27 |                 messages=messages or [],
28 |                 model=self.model,
29 |                 **kwargs,
30 |             )
31 |         except openai.APIConnectionError as e:
32 |             self.logger.error(f"The server could not be reached: {e.__cause__}")
33 |             raise
34 |         except openai.AuthenticationError as e:
35 |             self.logger.error(f"Authentication with the OpenAI API failed: {e}")
36 |             raise e
37 |         except openai.RateLimitError as e:
38 |             self.logger.error(f"Rate limit exceeded: {e}")
39 |             raise e
40 |         except openai.APIStatusError as e:
41 |             self.logger.error(f"An API
error occurred: {e.status_code} {e.response}") 42 | raise e 43 | -------------------------------------------------------------------------------- /promptedgraphs/llms/openai_streaming.py: -------------------------------------------------------------------------------- 1 | # https://github.com/openai/openai-cookbook/blob/60b12dfad1b6e7b32c4a6f1edff3b94c946b467d/examples/How_to_call_functions_with_chat_models.ipynb 2 | import json 3 | from collections.abc import AsyncGenerator 4 | 5 | from httpx import AsyncClient, ReadTimeout 6 | from sse_starlette import ServerSentEvent 7 | 8 | from promptedgraphs.config import Config 9 | from promptedgraphs.llms.openai_chat import LanguageModel 10 | from promptedgraphs.llms.usage import estimate_tokens 11 | from promptedgraphs.models import ChatFunction, ChatMessage 12 | 13 | GPT_MODEL = LanguageModel.GPT35_turbo.value 14 | GPT_MODEL_BIG_CONTEXT = LanguageModel.GPT35_turbo.value 15 | 16 | 17 | # @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3)) 18 | async def streaming_chat_completion_request( 19 | messages: list[ChatMessage] | None, 20 | functions: list[ChatFunction] | None = None, 21 | model=GPT_MODEL, 22 | config: Config | None = None, 23 | temperature=0.2, 24 | max_tokens=4000, 25 | stream=True, 26 | timeout=None, 27 | ) -> AsyncGenerator[bytes, None]: 28 | assert config and config.openai_api_key is not None, "OpenAI API Key not found" 29 | 30 | url = "https://api.openai.com/v1/chat/completions" 31 | headers = { 32 | "Content-Type": "application/json", 33 | "Authorization": f"Bearer {config.openai_api_key}", 34 | } 35 | 36 | json_data = { 37 | "model": model, 38 | "messages": [m.model_dump(exclude_none=True) for m in messages or []], 39 | "temperature": temperature, 40 | "max_tokens": max_tokens, 41 | "stream": stream, 42 | } 43 | if functions is not None and len(functions) > 0: 44 | json_data["functions"] = [ 45 | f if isinstance(f, dict) else f.model_dump(exclude_none=True) 46 | for f in functions 47 | ] 48 | 49 | token_count_approx = estimate_tokens(json_data, model=model) 50 | 51 | if token_count_approx >= 4_096: 52 | model = GPT_MODEL_BIG_CONTEXT 53 | json_data["max_tokens"] = 16_384 54 | json_data["model"] = model 55 | json_data["messages"][-1]["content"] = json_data["messages"][-1][ 56 | "content" 57 | ].strip()[:40_000] 58 | 59 | json_data["max_tokens"] = min( 60 | max(json_data["max_tokens"] - int(token_count_approx), 200), 16_384 61 | ) 62 | 63 | async with AsyncClient(timeout=timeout) as client: 64 | try: 65 | response = await client.post(url, headers=headers, json=json_data) 66 | if response.status_code != 200: 67 | raise ValueError(f"Failed to post to {url}. 
Response: {response.text}")
68 |             yield ServerSentEvent(
69 |                 data=json.dumps(
70 |                     {
71 |                         "usage": {
72 |                             "prompt_tokens": token_count_approx,
73 |                             "completion_tokens": 0,
74 |                         }
75 |                     }
76 |                 )
77 |             )
78 |             async for chunk in response.aiter_lines():
79 |                 if chunk.startswith("data:"):
80 |                     yield ServerSentEvent(data=chunk[5:].strip())
81 |                 elif chunk.startswith("event:"):
82 |                     yield ServerSentEvent(event=chunk[6:].strip())
83 |                 elif chunk.startswith("id:"):
84 |                     yield ServerSentEvent(id=chunk[3:].strip())
85 |                 elif chunk.startswith("retry:"):
86 |                     yield ServerSentEvent(retry=chunk[6:].strip())
87 |                 else:
88 |                     yield ServerSentEvent(data=chunk.strip())
89 |         except GeneratorExit:
90 |             pass  # Handle the generator being closed, if necessary
91 |         except ReadTimeout:
92 |             yield ServerSentEvent(
93 |                 data="timeout: Timeout reading the response", event="error"
94 |             )
95 |         except Exception as e:
96 |             yield ServerSentEvent(data=str(e), event="error")
97 | 
-------------------------------------------------------------------------------- /promptedgraphs/llms/usage.py:
1 | import json
2 | import time
3 | from logging import getLogger
4 | 
5 | import tiktoken
6 | 
7 | from promptedgraphs.llms.openai_chat import LanguageModel
8 | 
9 | logger = getLogger(__name__)
10 | 
11 | 
12 | class Usage:
13 |     prompt_tokens: int = 0
14 |     completion_tokens: int = 0
15 | 
16 |     def __init__(self, model: str, computer: str = "unknown") -> None:
17 |         self.model = model
18 |         self.computer = computer
19 |         self.duration = 0
20 |         self.prompt_tokens = 0
21 |         self.completion_tokens = 0
22 |         self.start_time = time.time()
23 | 
24 |     def start(self):
25 |         self.prompt_tokens = 0
26 |         self.completion_tokens = 0
27 |         self.start_time = time.time()
28 | 
29 |     def end(self):
30 |         self.end_time = time.time()
31 |         self.duration = self.end_time - self.start_time
32 | 
33 |     @property
34 |     def cost(self):
35 |         return self.llm_cost + self.compute_cost
36 | 
37 |     @property
38 |     def llm_cost(self):
39 |         return calculate_language_model_costs(self, model=self.model)
40 | 
41 |     @property
42 |     def compute_cost(self):
43 |         return calculate_compute_costs(self, computer=self.computer)
44 | 
45 |     def dict(self):
46 |         return {
47 |             "model": self.model,
48 |             "prompt_tokens": self.prompt_tokens,
49 |             "completion_tokens": self.completion_tokens,
50 |             "duration": self.duration,
51 |             "cost": self.cost,
52 |             "llm_cost": self.llm_cost,
53 |             "compute_cost": self.compute_cost,
54 |         }
55 | 
56 |     def __repr__(self) -> str:
57 |         return f"Usage(model={self.model}, prompt_tokens={self.prompt_tokens}, completion_tokens={self.completion_tokens}, duration={self.duration:.4f}, cost={self.cost:.6f}, compute_cost={self.compute_cost:.6f}, llm_cost={self.llm_cost:.6f})"
58 | 
59 | 
60 | def num_tokens_from_messages(messages, model="gpt-3.5-turbo-1106"):
61 |     """Returns the number of tokens used by a list of messages.
62 |     See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.
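    Note: these counts are approximations; the per-message overhead constants below
    match gpt-3.5-turbo-style chat formats and can drift across model versions, so
    treat the result as an estimate rather than an exact accounting.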
63 | """ 64 | try: 65 | encoding = tiktoken.encoding_for_model(model) 66 | except KeyError: 67 | encoding = tiktoken.get_encoding("cl100k_base") 68 | if model == "gpt-3.5-turbo-1106": 69 | num_tokens = 0 70 | for message in messages: 71 | num_tokens += ( 72 | 4 # every message follows {role/name}\n{content}\n 73 | ) 74 | for key, value in message.items(): 75 | num_tokens += len(encoding.encode(value)) 76 | if key == "name": # if there's a name, the role is omitted 77 | num_tokens += -1 # role is always required and always 1 token 78 | num_tokens += 2 # every reply is primed with assistant 79 | return num_tokens 80 | else: 81 | encoding = tiktoken.encoding_for_model(model) 82 | return len( 83 | encoding.encode( 84 | "<|endoftext|>".join( 85 | [l for m in messages for l in m.values() if l is not None] 86 | ), 87 | allowed_special=set(encoding._special_tokens.keys()), 88 | ) 89 | ) 90 | 91 | 92 | def estimate_tokens(json_data, model="gpt-3.5-turbo-1106"): 93 | # Adjust max_tokens based on message length and function length 94 | message_tokens = num_tokens_from_messages( 95 | json_data.get("messages", []), model=model 96 | ) 97 | 98 | encoding = tiktoken.encoding_for_model(model) 99 | function_tokens = len( 100 | encoding.encode( 101 | "<|endoftext|>".join( 102 | [ 103 | json.dumps(l) 104 | for m in json_data.get("functions", []) 105 | for l in m.values() 106 | if l is not None 107 | ] 108 | ), 109 | allowed_special=set(encoding._special_tokens.keys()), 110 | ) 111 | ) 112 | return function_tokens + message_tokens 113 | 114 | 115 | def calculate_compute_costs(usage: Usage, computer: str): 116 | # Pricing in dollars per 1000 tokens 117 | # Compute time calculated assuming AWS ondemand prices 118 | # t4g.medium $0.0336 2 4 GiB EBS Only Up to 5 Gigabit 119 | default_compute_cost_second = 0.0336 / 60 / 60 120 | pricing = {"t4g.medium": 0.0336 / 60 / 60} 121 | try: 122 | compute_pricing = pricing[computer] 123 | except KeyError as e: 124 | logger.debug( 125 | f"Server {computer} not found in pricing table, using default pricing of {default_compute_cost_second:.6f}" 126 | ) 127 | compute_pricing = default_compute_cost_second 128 | 129 | return round(usage.duration * compute_pricing, 6) 130 | 131 | 132 | def calculate_langage_model_costs(usage: Usage, model: LanguageModel): 133 | # Pricing in dollars per 1000 tokenspricing = { 134 | pricing = { 135 | LanguageModel.GPT4: { 136 | "prompt": 0.03, 137 | "completion": 0.06, 138 | }, 139 | LanguageModel.GPT35_turbo: { 140 | "prompt": 0.0010, 141 | "completion": 0.0020, 142 | }, 143 | "default": { 144 | "prompt": 0.0, 145 | "completion": 0.0, 146 | }, 147 | } 148 | 149 | try: 150 | model_pricing = pricing[model] 151 | except KeyError as e: 152 | if isinstance(model, LanguageModel): 153 | raise ValueError(f"Model {model} not found in pricing table") from e 154 | else: 155 | logger.warning( 156 | f"Model {model} not found in pricing table, using default pricing of 0" 157 | ) 158 | model_pricing = pricing["default"] 159 | 160 | prompt_cost = usage.prompt_tokens * model_pricing["prompt"] / 1000 161 | completion_cost = usage.completion_tokens * model_pricing["completion"] / 1000 162 | 163 | total_cost = prompt_cost + completion_cost 164 | # round to 6 decimals 165 | return round(total_cost, 6) 166 | -------------------------------------------------------------------------------- /promptedgraphs/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from 
dataclasses_json import dataclass_json 5 | from pydantic import BaseModel 6 | 7 | 8 | @dataclass_json 9 | @dataclass 10 | class EntityLabel: 11 | label: str 12 | name: str 13 | description: str | None 14 | 15 | 16 | @dataclass_json 17 | @dataclass 18 | class EntityReference: 19 | start: int 20 | end: int 21 | label: str 22 | text: str | None = None 23 | reason: str | None = None 24 | 25 | 26 | @dataclass_json 27 | @dataclass 28 | class TextwithEntities: 29 | text: str 30 | ents: list[EntityReference] 31 | 32 | 33 | class ChatMessage(BaseModel): 34 | role: str 35 | content: str 36 | name: str | None = None 37 | 38 | 39 | class FunctionParameter(BaseModel): 40 | type: str 41 | description: str | None 42 | 43 | 44 | class FunctionArrayParameter(BaseModel): 45 | type: str 46 | description: str | None 47 | items: Any | None 48 | 49 | 50 | class FunctionParameters(BaseModel): 51 | type: str 52 | properties: dict[str, FunctionParameter] 53 | required: list[str] 54 | 55 | 56 | class ChatFunction(BaseModel): 57 | name: str 58 | description: str 59 | dependencies: list[str] | None 60 | parameters: FunctionParameters 61 | -------------------------------------------------------------------------------- /promptedgraphs/normalization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/normalization/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/normalization/object_to_data.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import tempfile 4 | from logging import getLogger 5 | from pathlib import Path 6 | from string import Template 7 | from typing import Any 8 | 9 | import datamodel_code_generator as dcg 10 | from pydantic import BaseModel, EmailStr, Field, RootModel 11 | 12 | from promptedgraphs import __version__ as version 13 | from promptedgraphs.code_execution.safer_python_exec import safer_exec 14 | from promptedgraphs.generation.schema_from_model import ( 15 | extract_references, 16 | schema_from_model, 17 | ) 18 | from promptedgraphs.llms.chat import Chat 19 | 20 | logger = getLogger(__name__) 21 | 22 | SYSTEM_MESSAGE = """ 23 | We are a data entry system tasked with properly formatting data based on a provided schema. You will be given data that does not conform to the required schema and you are tasked with lightly editing the object to conform with the provided schema. 24 | 25 | If it is a 'value_error' only return the corrected value. 26 | If the corrected value is an object, return it in json format otherwise return just the value with no explanation. 
27 | """ 28 | 29 | MESSAGE_TEMPLATE = Template( 30 | """ 31 | # Validation Error: $error_type 32 | $error_msg 33 | 34 | ## Required schema 35 | ``` 36 | $schema 37 | ``` 38 | 39 | ## Object 40 | ``` 41 | $obj 42 | ``` 43 | """ 44 | ) 45 | 46 | 47 | async def correct_value_error( 48 | obj: dict, schema: dict, error_type: str, error_msg: str 49 | ) -> Any: 50 | """Corrects a value error in a data object.""" 51 | obj_str = json.dumps(obj, indent=4) 52 | msg = MESSAGE_TEMPLATE.substitute( 53 | obj=obj_str, 54 | schema=json.dumps(schema, indent=4), 55 | error_type=error_type, 56 | error_msg=error_msg, 57 | ).strip() 58 | chat = Chat() 59 | 60 | # TODO replace with a tiktoken model and pad the message by 2x 61 | max_tokens = len(obj_str) 62 | 63 | response = await chat.chat_completion( 64 | messages=[ 65 | {"role": "system", "content": SYSTEM_MESSAGE.strip()}, 66 | {"role": "system", "content": msg}, 67 | ], 68 | **{ 69 | "max_tokens": min(4_096, max_tokens), 70 | "temperature": 0.0, 71 | "response_format": {"type": "json_object"}, 72 | }, 73 | ) 74 | return json.loads(response.choices[0].message.content) 75 | 76 | 77 | def get_sub_object(data_object: dict, loc: list[str]) -> dict: 78 | """Gets a sub-object from a data object.""" 79 | if not loc: 80 | return data_object 81 | for key in loc: 82 | data_object = data_object[key] 83 | sub_obj = {loc[-1]: data_object} 84 | if len(loc) > 1: 85 | for key in loc[-2::-1]: # Iterate backwards building up nested objects 86 | sub_obj = {key: sub_obj} 87 | 88 | if defs := extract_references(sub_obj): 89 | raise NotImplementedError(f"References are not yet supported: {defs}") 90 | return sub_obj 91 | 92 | 93 | def get_subschema(schema_spec: dict, loc: list[str]) -> dict: 94 | """Gets a sub-schema from a schema specification.""" 95 | # TODO handle array items 96 | for key in loc: 97 | if "properties" in schema_spec: 98 | schema_spec = schema_spec["properties"][key] 99 | else: 100 | schema_spec = schema_spec[key] 101 | return schema_spec 102 | 103 | 104 | def set_data_object_value( 105 | data_object: dict, new_value: Any, old_value: Any, loc: list[str] 106 | ): 107 | """Sets a value in a data object.""" 108 | if not loc: 109 | raise ValueError("Cannot set a value at the root level.") 110 | for key in loc[:-1]: 111 | data_object = data_object[key] 112 | if loc[-1] in new_value: 113 | data_object[loc[-1]] = new_value[loc[-1]] 114 | else: 115 | data_object[loc[-1]] = new_value 116 | 117 | 118 | def data_model_to_schema(data_model: list[BaseModel] | BaseModel) -> dict: 119 | """Converts a Pydantic model to a JSON schema.""" 120 | if isinstance(data_model, list) or str(data_model).startswith("list["): 121 | x = data_model.__args__[0] 122 | return RootModel[list[x]].model_json_schema() 123 | return data_model.model_json_schema() 124 | 125 | 126 | def schema_to_data_model(schema_spec: dict) -> tuple[BaseModel, str]: 127 | """Converts a JSON schema to a Pydantic model. 128 | WARNING: This function uses the output of the datamodel-codegen and schema_spec 129 | and runs `exec` to execute the result. This is a potential security risk and should be used with caution. 130 | 131 | Returns the compiled datamodel and the generated code as a string. 
132 | """ 133 | input_text = json.dumps(schema_spec, indent=4) 134 | 135 | output_file = Path(tempfile.mkstemp(prefix="promptedgraphs_")[1]) 136 | 137 | # Get the class name from the schema specification 138 | class_name = schema_spec.get("title", "DataModel") 139 | 140 | dcg.generate( 141 | input_=input_text, 142 | class_name=class_name, 143 | target_python_version=dcg.PythonVersion.PY_310, 144 | output=output_file, 145 | custom_file_header=f"# PromptedGraphs {version}\n# generated by datamodel-codegen", 146 | ) 147 | model_code = Path(output_file).read_text() 148 | output_file.unlink() # Delete the temporary file 149 | 150 | # Get the constucted object from the model code in the exec environment 151 | exec_variable_scope = safer_exec(model_code) 152 | return exec_variable_scope, model_code, class_name 153 | 154 | 155 | async def update_data_object( 156 | data_object: dict, schema_spec: dict, errors: list[str] = None 157 | ): 158 | """Updates the data object with error information.""" 159 | logger.debug(f"Updating data object with error: {errors}") 160 | corrections = [] 161 | for error in errors: 162 | if error["type"] == "missing": 163 | loc: tuple = error["loc"][:-1] if len(error["loc"]) else () 164 | error_msg = f"{error['msg']}: '{error['loc']}' is missing. If possible rename a key to match the schema or add the missing key to the object." 165 | else: 166 | loc = error["loc"] 167 | error_msg = error["msg"] 168 | 169 | old_value = get_sub_object(data_object, loc=loc) 170 | subschema = get_subschema(schema_spec, loc=loc) 171 | new_value = await correct_value_error( 172 | old_value, 173 | subschema, 174 | error_type=error["type"], 175 | error_msg=error_msg, 176 | ) 177 | if len(loc): 178 | set_data_object_value( 179 | data_object, 180 | new_value=new_value, 181 | old_value=old_value, 182 | loc=loc, 183 | ) 184 | else: 185 | data_object = new_value 186 | corrections.append( 187 | (loc, error["type"], error_msg, old_value, new_value) 188 | ) 189 | return data_object, corrections 190 | 191 | 192 | async def object_to_data( 193 | data_object: dict | list, 194 | schema_spec: dict | None = None, 195 | data_model: BaseModel | None = None, 196 | coerce: bool = True, 197 | retry_count: int = 10, 198 | ) -> BaseModel | list[BaseModel]: 199 | """Converts data to fit a given schema, applying light reformatting like type casting and field renaming. 200 | 201 | Args: 202 | data_object (Union[dict, list]): The data to reformat. 203 | schema_spec (Optional[Dict], optional): The schema specification for reformatting. Defaults to None. 204 | data_model (Optional[BaseModel], optional): The Pydantic model for reformatting. Defaults to None. 205 | coerce (bool, optional): Whether to coerce data types. Defaults to True. 206 | 207 | Returns: 208 | Union[dict, list]: The reformatted data. 
209 | """ 210 | if isinstance(data_object, list): 211 | return [ 212 | await object_to_data(obj, schema_spec, data_model, coerce) 213 | for obj in data_object 214 | ] 215 | if schema_spec and not data_model: 216 | data_models, model_code, class_name = schema_to_data_model(schema_spec) 217 | # TODO load all of the data_models into local scope 218 | data_model = data_models[class_name] 219 | if not coerce: 220 | return data_model(**data_object) 221 | 222 | corrections = [] 223 | while retry_count > 0: 224 | try: 225 | if len(corrections): 226 | logger.info(f"Coercing data with {len(corrections)}corrections") 227 | for c in corrections: 228 | logger.info( 229 | f"Correcting {c[0]} with {c[1]}: {c[2]} - {c[3]} -> {c[4]}" 230 | ) 231 | return data_model(**data_object) 232 | except Exception as e: 233 | errors = e.errors 234 | if not isinstance(errors, list): 235 | errors = errors() 236 | 237 | # Ensure schema_spec is defined. This is needed for the update_data_object function. 238 | schema_spec = schema_spec or schema_from_model( 239 | data_model, resolve_refs=True 240 | ) 241 | 242 | # Update the data object with error information 243 | data_object, new_corrections = await update_data_object( 244 | data_object, schema_spec, errors=errors 245 | ) 246 | if new_corrections: 247 | corrections.extend(new_corrections) 248 | retry_count -= 1 249 | 250 | raise ValueError("Failed to coerce data to schema.") 251 | 252 | 253 | async def example(): 254 | data = { 255 | "name": "John Doe", 256 | "age": "10", 257 | "email": "john.doe@gmail", 258 | } 259 | 260 | class UserBioData(BaseModel): 261 | name: str 262 | age: int = Field(..., gt=0) 263 | email: EmailStr 264 | 265 | schema = schema_from_model(UserBioData) 266 | print(schema) 267 | 268 | data_model = await object_to_data(data, schema) 269 | print(data_model) 270 | 271 | 272 | if __name__ == "__main__": 273 | asyncio.run(example()) 274 | -------------------------------------------------------------------------------- /promptedgraphs/normalization/schema_to_schema.py: -------------------------------------------------------------------------------- 1 | def schema_to_schema(schema: dict, db_schema: dict) -> dict: 2 | """Generates a mapping from a high-level schema to a database schema. 3 | 4 | Args: 5 | schema (dict): The high-level schema. 6 | db_schema (dict): The database schema. 7 | 8 | Returns: 9 | dict: A Directed Acyclic Graph (DAG) mapping schema elements to database elements. 
10 | """ 11 | -------------------------------------------------------------------------------- /promptedgraphs/normalization/vis_graphs.py: -------------------------------------------------------------------------------- 1 | from pprint import pformat 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | from matplotlib.patches import FancyBboxPatch 7 | 8 | 9 | def graph_to_mermaid(g: nx.DiGraph | nx.MultiDiGraph, kind="graph"): 10 | """Converts a NetworkX graph to a Mermaid markdown string.""" 11 | if kind == "graph": 12 | md = "graph TD\n" 13 | elif kind == "entity_relationship": 14 | md = "erDiagram\n" 15 | else: 16 | raise ValueError(f"kind must be 'graph' or 'entity_relationship', not {kind}") 17 | 18 | for node in g.nodes: 19 | node_data = g.nodes[node] 20 | node_id = str(node) 21 | node_label = node_data.get("name", node_id) 22 | md += f" {node_id}[{node_label}]\n" 23 | 24 | for edge in g.edges: 25 | edge_data = g.edges[edge] 26 | source, target = edge 27 | md += f" {source} --> {target}\n" 28 | return md 29 | 30 | 31 | def visualize_data_graph(g): 32 | # Prepare a color map only for nodes where 'kind' == 'object' 33 | color_map = [] 34 | unique_types = {g.nodes[n]["kind"] for n in g.nodes if g.nodes[n].get("kind")} 35 | unique_types.add("_unknown") 36 | # matplotlib.colormaps.get_cmap(obj) 37 | cm = plt.colormaps.get_cmap("viridis") 38 | type_colors = { 39 | node_type: cm(float(i) / len(unique_types)) 40 | for i, node_type in enumerate(unique_types) 41 | } 42 | 43 | # Apply colors to nodes 44 | for node in g.nodes: 45 | node_data = g.nodes[node] 46 | color_map.append(type_colors[node_data.get("kind") or "_unknown"]) 47 | 48 | # Draw the graph 49 | fig, ax = plt.subplots() 50 | pos = nx.nx_agraph.graphviz_layout(g, prog="dot") 51 | nx.draw(g, pos, node_color=color_map, ax=ax) 52 | 53 | # Define base node size and scaling factor for width 54 | base_width = 75 55 | height = 25 56 | 57 | for node in g.nodes: 58 | node_data = g.nodes[node] 59 | x, y = pos[node] 60 | if node_data.get("kind") == "object": 61 | num_children = len(list(g.successors(node))) 62 | width = base_width * num_children # Width scaled by number of children 63 | color = type_colors[node_data["kind"]] 64 | box = FancyBboxPatch( 65 | (x - width / 2, y - height / 2), 66 | width, 67 | height, 68 | boxstyle="round,pad=0.1", 69 | color=color, 70 | ec="black", 71 | zorder=100, 72 | ) 73 | ax.add_patch(box) 74 | plt.text( 75 | x, 76 | y, 77 | str(node).split("_")[0], 78 | ha="center", 79 | va="center", 80 | color="black", 81 | zorder=101, 82 | fontsize=10, 83 | ) 84 | # Set limits and draw edges as normal 85 | ax.set_xlim( 86 | np.min([x for x, y in pos.values()]) - 50, 87 | np.max([x for x, y in pos.values()]) + 50, 88 | ) 89 | ax.set_ylim( 90 | np.min([y for x, y in pos.values()]) - 50, 91 | np.max([y for x, y in pos.values()]) + 50, 92 | ) 93 | 94 | plt.axis("off") 95 | return fig, ax 96 | 97 | # Now visualize the graph in markdown 98 | 99 | 100 | def _tree_as_markdown(g, node, level=0): 101 | node_data = g.nodes[node] 102 | node_id = str(node) 103 | data_type = node_data.get("kind", "object") 104 | label = f"ID[{node_id}]: {node_data.get('name', '')}({data_type})".strip() 105 | if description := node_data.get("description"): 106 | label += f" - {description}" 107 | md = [f"{' ' * level} * {label}"] 108 | 109 | if other_data := { 110 | k: v 111 | for k, v in dict(node_data).items() 112 | if k not in ("kind", "description", "schema_id", "parents", "name") 113 | }: 114 | prefix = f"{' 
' * (level+1)} * properties: " 115 | md.append(prefix + pformat(other_data, indent=len(prefix))) 116 | md.extend( 117 | _tree_as_markdown(g, child, level=level + 1) for child in g.successors(node) 118 | ) 119 | return "\n".join(md) 120 | 121 | 122 | def data_graph_as_markdown(g): 123 | root_candidates = [node for node, indegree in g.in_degree() if indegree == 0] 124 | md = ["## Data Graph"] 125 | md.extend(_tree_as_markdown(g, root) for root in root_candidates) 126 | return "\n".join(md) 127 | -------------------------------------------------------------------------------- /promptedgraphs/parsers.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def extract_partial_list(s, key="entities"): 5 | if not isinstance(s, str) or s == "": 6 | return [] 7 | 8 | # remove everything before the first occurrence of the key 9 | field = f'"{key}":' 10 | s = field.join(s.split(field)[1:]).strip() 11 | if len(s) == 0: 12 | return [] 13 | 14 | if not s.startswith("["): 15 | return [] 16 | 17 | if s.endswith("},"): 18 | s = s[:-1] 19 | 20 | if not s.rstrip().endswith("}"): 21 | return [] 22 | 23 | s += "]" 24 | try: 25 | return json.loads(s) 26 | except json.decoder.JSONDecodeError: 27 | return [] 28 | -------------------------------------------------------------------------------- /promptedgraphs/sources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/sources/__init__.py -------------------------------------------------------------------------------- /promptedgraphs/sources/datagraph_from_class.py: -------------------------------------------------------------------------------- 1 | """Bootstrap a datagraph from APIs and functions""" 2 | import asyncio 3 | import datetime 4 | import inspect 5 | import json 6 | import logging 7 | import os 8 | import re 9 | from pathlib import Path 10 | 11 | import googlemaps 12 | import tqdm 13 | from bs4 import BeautifulSoup 14 | from dotenv import load_dotenv 15 | 16 | from promptedgraphs.config import load_config 17 | from promptedgraphs.llms.helpers import _sync_wrapper, extract_code_blocks 18 | from promptedgraphs.llms.openai_streaming import streaming_chat_completion_request 19 | from promptedgraphs.models import ChatMessage 20 | from promptedgraphs.sources.rtfm import fetch_from_ogtags 21 | 22 | load_dotenv() 23 | load_config() 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | SYSTEM_MESSAGE_FN_TO_SCHEMAS = """You are a python developer tasked with building pydantic data models to enforce types for various function calls. 28 | Below is a function's call signature and doc string with optionally additional documentation 29 | 30 | # Steps: 31 | 1. Write the function as an openAPI spec in JSON assuming the endpoint is GET `/{fn_name}`. Make sure to enclose the JSON data in a code block 32 | 33 | 2. Then convert the openAPI json spec above into pydantic data models. Use pydantic.Field to keep types, title and description information. 
34 | 35 | Make sure to enclose the pydantic data models within a code block like 36 | ```python 37 | from typing import Dict, List 38 | from pydantic import BaseModel, Field 39 | 40 | 41 | class ExampleRequest(BaseModel): 42 | pass 43 | 44 | 45 | class ExampleResponse(BaseModel): 46 | pass 47 | ``` 48 | 49 | The response should only contain two code blocks (one with JSON data and one with python code) 50 | """ 51 | 52 | 53 | def get_functions_from_object(fn: classmethod): 54 | """Get functions from an object""" 55 | return [ 56 | getattr(fn, a) 57 | for a in dir(fn) 58 | if callable(getattr(fn, a)) and not a.startswith("_") 59 | ] 60 | 61 | 62 | def build_function_signature(fn: classmethod): 63 | """Build a function from a signature""" 64 | sig = inspect.signature(fn) 65 | doc = (fn.__doc__ or "").strip() 66 | return f'def {fn.__name__}{str(sig)}:\n """{doc}' + '\n """\n ....' 67 | 68 | 69 | def extract_urls_with_anchors(text): 70 | if not text: 71 | return [] 72 | url_pattern = re.compile( 73 | r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+(?:#[a-zA-Z0-9-_]*)?" 74 | ) 75 | return sorted(set(url_pattern.findall(text))) 76 | 77 | 78 | def slim_down_html(html: str, anchor: str = None) -> tuple[str, bool]: 79 | html = BeautifulSoup(html, features="html.parser") 80 | # remove all script and style and visual elements 81 | for script in html(["script", "style", "svg", "img"]): 82 | script.extract() 83 | 84 | if anchor: 85 | if anchor := html.find(id=anchor): 86 | return str(anchor), True 87 | 88 | if article := html.find("article"): 89 | return str(article), False 90 | elif main := html.find("main"): 91 | return str(main), False 92 | elif body := html.find("body"): 93 | return str(body), False 94 | return str(html), False 95 | 96 | 97 | async def html_to_markdown(html, url=None): 98 | """Convert html to markdown""" 99 | if url and "#" in url: 100 | anchor = url.split("#")[1] if url else None 101 | else: 102 | anchor = None 103 | html, used_anchor = slim_down_html(html, anchor=anchor) 104 | if anchor and not used_anchor: 105 | url = url.split("#")[0] 106 | 107 | SYSTEM_MESSAGE = """Format webpage as markdown""" 108 | payload = "" 109 | async for event in streaming_chat_completion_request( 110 | messages=[ 111 | ChatMessage( 112 | role="system", 113 | content=SYSTEM_MESSAGE, 114 | ), 115 | ChatMessage(role="user", content=html), 116 | ], 117 | functions=None, 118 | config=load_config(), 119 | stream=False, 120 | ): 121 | if event.data: 122 | payload += event.data 123 | elif event.retry: 124 | print(f"Retry: {event.retry}") 125 | 126 | data = json.loads(payload) 127 | return data["choices"][0]["message"]["content"], url 128 | 129 | 130 | async def build_function_schemas( 131 | fn_name: str, fn_signature: str, external_references: dict[str, str] = None 132 | ) -> tuple[dict]: 133 | """Convert function signatures into OpenAPI specs and pydantic datamodels""" 134 | 135 | external_references = external_references or {} 136 | 137 | messages = [ 138 | ChatMessage( 139 | role="system", 140 | content=SYSTEM_MESSAGE_FN_TO_SCHEMAS.format(fn_name=fn_name), 141 | ) 142 | ] 143 | messages.append(ChatMessage(role="user", content=fn_signature)) 144 | messages.extend( 145 | ChatMessage( 146 | role="assistant", 147 | content=f"Additional documentation loaded from {url}\n\n{html}", 148 | ) 149 | for url, html in external_references.items() 150 | ) 151 | data = await _sync_wrapper(messages) 152 | 153 | # pull out the json and pydantic code blocks 154 | content = 
data.get("choices")[0]["message"]["content"] 155 | results = extract_code_blocks(content) 156 | if len(results) != 2: # Try again 157 | missing_types = {"json", "python"} - {r["block_type"] for r in results} 158 | logger.warning(f"Missing code blocks of type: {missing_types}") 159 | messages.append( 160 | ChatMessage( 161 | role="assistant", 162 | content=content, 163 | ) 164 | ) 165 | messages.append( 166 | ChatMessage( 167 | role="user", 168 | content=f"Please provide the missing `{','.join(missing_types)}` code block", 169 | ), 170 | ) 171 | data = await _sync_wrapper(messages) 172 | content = data.get("choices")[0]["message"]["content"] 173 | results.extend(extract_code_blocks(content)) 174 | 175 | # Append the function signature documenation to the python code block 176 | docs = f"""# {fn_name}\n\n```python\n{fn_signature}\n\n```""" 177 | for url, html in external_references.items(): 178 | docs += f"\n\n## Additional documentation loaded from {url}\n\n{html}" 179 | results.append({"block_type": "md", "content": docs}) 180 | 181 | return results 182 | 183 | 184 | async def register_function_as_datasource( 185 | obj: object, 186 | name: str = None, 187 | datasource_registry: str = None, 188 | ): 189 | name = ( 190 | name or getattr(obj, "__name__", None) or str(obj).split()[0].replace("<", "") 191 | ) 192 | assert name, "You must provide a name for the datasource" 193 | 194 | datasource_registry = Path(datasource_registry or "./data_models/") 195 | output_dir = datasource_registry / name 196 | output_dir.mkdir(exist_ok=True, parents=True) 197 | 198 | meta = { 199 | "name": name, 200 | "description": obj.__doc__ or "", 201 | "as_of": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(), 202 | } 203 | 204 | fns = get_functions_from_object(obj) 205 | ittr = tqdm.tqdm(fns, desc=f"Building schemas for {name}") 206 | for fn in ittr: 207 | ittr.set_description(f"Building schema for {name}::{fn.__name__}") 208 | links = extract_urls_with_anchors(fn.__doc__) 209 | sig = build_function_signature(fn) 210 | 211 | external_references = {} 212 | for link in tqdm.tqdm(links, desc="Fetching external references"): 213 | if link in external_references: 214 | continue 215 | if html := fetch_from_ogtags(link): 216 | html, new_link = await html_to_markdown(html, url=link) 217 | external_references[new_link] = html 218 | 219 | fn_schemas = await build_function_schemas( 220 | fn.__name__, sig, external_references=external_references 221 | ) 222 | 223 | for fn_schema in fn_schemas: 224 | suffix = ( 225 | "py" if fn_schema["block_type"] == "python" else fn_schema["block_type"] 226 | ) 227 | with open(output_dir / f"{fn.__name__}.{suffix}", "w") as f: 228 | f.write(fn_schema["content"]) 229 | 230 | with open(datasource_registry / "meta.jsonl", "a") as f: 231 | f.write(json.dumps(meta) + "\n") 232 | 233 | # Aggregate all the schemas into one file 234 | 235 | 236 | if __name__ == "__main__": 237 | load_dotenv() 238 | gmaps = googlemaps.Client(key=os.environ["GOOGLEMAPS_API_KEY"]) 239 | asyncio.run(register_function_as_datasource(gmaps)) 240 | -------------------------------------------------------------------------------- /promptedgraphs/sources/datagraph_from_pydantic.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import re 5 | import traceback 6 | from pathlib import Path 7 | from typing import Any 8 | 9 | import black 10 | import isort 11 | import networkx as nx 12 | import tqdm 13 | from pydantic import 
BaseModel, Field
14 | 
15 | from promptedgraphs.code_execution.safer_python_exec import format_code
16 | from promptedgraphs.llms.coding import fix_code
17 | 
18 | logger = logging.getLogger(__name__)
19 | 
20 | # Restrictive environment
21 | allowed_globals = {
22 |     "__import__": __import__,
23 |     "BaseModel": BaseModel,
24 |     "Field": Field,
25 |     "bytes": bytes,
26 | }
27 | 
28 | 
29 | def kindofsafe_exec(code, dependencies: dict = None):
30 |     variables = {}
31 |     dependencies = dependencies or allowed_globals
32 | 
33 |     for snippet in code.split(
34 |         "\n\n"
35 |     ):  # Loading in code by chunks helps with dependency resolution
36 |         dependencies.update(variables)
37 |         exec(snippet, dependencies, variables)
38 |     logger.warning("kindofsafe_exec is deprecated: use safer_exec")
39 |     return variables
40 | 
41 | 
42 | def is_pydantic_base_model(cls: Any) -> bool:
43 |     return isinstance(cls, type) and issubclass(cls, BaseModel)
44 | 
45 | 
46 | def extract_fields_and_type(cls: type[BaseModel]) -> list[tuple[str, type, str]]:
47 |     results = []
48 |     if cls is BaseModel:  # skip the base class itself
49 |         return []
50 |     schema = cls.model_json_schema()
51 |     properties = schema.get("properties", {})
52 |     for field_name, field_info in properties.items():
53 |         if field_info.get("type") == "array":
54 |             results.append((field_name, cls, "item"))
55 |         else:
56 |             results.append((field_name, cls, "object"))
57 |     return results
58 | 
59 | 
60 | def process_file(file_path):
61 |     with open(file_path) as file:
62 |         file_content = file.read()
63 | 
64 |     # Things can throw weird errors if multiple definitions of the same name are in the same file
65 | 
66 | 
67 | 
68 |     g = nx.MultiDiGraph()
69 |     local_vars = kindofsafe_exec(file_content)
70 |     for name, cls in local_vars.items():
71 |         if (
72 |             isinstance(cls, type)
73 |             and name != "BaseModel"
74 |             and issubclass(cls, BaseModel)
75 |             and hasattr(cls, "model_json_schema")
76 |         ):
77 |             schema = cls.model_json_schema()
78 |             if schema.get("type") == "object":
79 |                 target = schema.get("title", cls.__name__)
80 |                 target_type = "entity"
81 |                 if target.endswith("Request"):
82 |                     target = f"fn({target[:-7]})"
83 |                     target_type = "fn"
84 |                 elif target.endswith("Response"):
85 |                     target = f"fn({target[:-8]})"
86 |                     target_type = "fn"
87 |                 if target not in g.nodes:
88 |                     g.add_node(target, kind=target_type)
89 |                 for p_name, p_field in schema.get("properties", {}).items():
90 |                     src = p_name
91 |                     if src not in g.nodes:
92 |                         g.add_node(src, kind="field")
93 |                     g.add_edge(src, target, kind="property", data=json.dumps(p_field))
94 | 
95 |     for n, d in g.nodes(data=True):
96 |         if d["kind"] == "fn" and n.startswith("fn("):
97 |             target = n[3:-1]
98 |             if target in g.nodes:
99 |                 g.add_edge(target, n, kind="reference")
100 | 
101 |     updates = add_entity_is_subset_edges(g)
102 |     print("Added edges:", updates)
103 | 
104 |     nx.write_graphml(g, file_path.parent / "schema.graphml")
105 | 
106 | 
107 | 
108 | 
109 | 
110 | 
111 | 
112 | 
113 |     # Give me shortest path count between entities and fields (those are their properties)
114 |     entities = [
115 |         n
116 |         for n in sorted(g.degree, key=lambda x: x[1], reverse=True)
117 |         if g.nodes[n[0]]["kind"] == "entity"
118 |     ]
119 |     print(schema)
120 |     print(entities)
121 | 
122 | 
123 | def add_entity_is_subset_edges(g: nx.DiGraph):
124 |     # Extract node attributes for 'kind'
125 |     node_kinds = nx.get_node_attributes(g,
"kind") 126 | 127 | # Filter out nodes with kind 'fn' and 'field' 128 | field_nodes = [node for node, kind in node_kinds.items() if kind == "field"] 129 | entity_nodes = [node for node, kind in node_kinds.items() if kind == "entity"] 130 | 131 | # Create the adjacency matrix using pandas 132 | adjacency_matrix = nx.to_pandas_adjacency(g, nodelist=field_nodes + entity_nodes) 133 | # Extract relevant columns 134 | adjacency_matrix = adjacency_matrix.loc[field_nodes, entity_nodes] 135 | 136 | # Check for column subsets and record parent columns 137 | transpose_matrix = adjacency_matrix.transpose() 138 | found = { 139 | "is_equal": 0, 140 | "is_subset": 0, 141 | } 142 | # There is a vectorized version of this but not necessary for now 143 | for i, col1 in transpose_matrix.iterrows(): 144 | if col1.sum() == 0: # Skip empty columns 145 | continue 146 | for j, col2 in transpose_matrix.iterrows(): 147 | if i != j and col2.sum() > 0: 148 | if all(col1 == col2): 149 | g.add_edge(j, i, kind="is_equal") 150 | found["is_equal"] += 1 151 | elif all(col1 <= col2) and any(col1 < col2): 152 | g.add_edge(j, i, kind="is_subset") 153 | found["is_subset"] += 1 154 | return found 155 | 156 | 157 | def get_mappings_of_functions_to_objects(fname, code): 158 | requests = [] 159 | responses = [] 160 | other = [] 161 | fn_name = fname.name.split(".")[0] 162 | 163 | obj = kindofsafe_exec(code) 164 | 165 | for k, v in obj.items(): 166 | if k.endswith("Response"): 167 | responses.append(k) 168 | elif k.endswith("Request"): 169 | requests.append(k) 170 | elif is_pydantic_base_model(v) and k != "BaseModel": 171 | other.append(k) 172 | return { 173 | "name": fn_name, 174 | "requests": requests, 175 | "responses": responses, 176 | "other": other, 177 | } 178 | 179 | 180 | def aggregate_fn_to_object_mappings(fdir): 181 | g = nx.MultiDiGraph() 182 | fdir = Path(fdir) 183 | for fname in fdir.glob("*.py"): 184 | if fname.name.startswith("_"): 185 | continue 186 | with open(fname) as f: 187 | m = get_mappings_of_functions_to_objects(fname, f.read()) 188 | g.add_node(m["name"], kind="fn") 189 | for r in m["requests"]: 190 | g.add_node(r, kind="request") 191 | g.add_edge(m["name"], r, kind="request") 192 | for r in m["responses"]: 193 | g.add_node(r, kind="response") 194 | g.add_edge(m["name"], r, kind="response") 195 | for r in m["other"]: 196 | g.add_node(r, kind="other") 197 | g.add_edge(m["name"], r, kind="other") 198 | return g 199 | 200 | 201 | def aggregate_python_files(fdir): 202 | fdir = Path(fdir) 203 | code = [] 204 | for fname in fdir.glob("*.py"): 205 | with open(fname) as f: 206 | c = format_code(f.read()) 207 | code.append(c) 208 | try: 209 | format_code("\n\n".join(code)) 210 | except Exception: 211 | logger.error(f"Static code error in: {fname}") 212 | return 213 | 214 | # move all imports to the top 215 | code = move_imports_to_top("\n\n".join(code)) 216 | config = isort.Config(profile="black") 217 | code = isort.code(code, config=config) 218 | 219 | # reformat one more time 220 | code = format_code(code) 221 | 222 | with open(fdir / "_all.py", "w") as f: 223 | f.write(code) 224 | 225 | return fdir / "_all.py" 226 | 227 | 228 | def move_imports_to_top(code): 229 | lines = code.split("\n") 230 | 231 | # Separating import lines and other lines 232 | import_lines = [ 233 | line for line in lines if re.match(r"^(import |from .+ import)", line) 234 | ] 235 | other_lines = [ 236 | line for line in lines if not re.match(r"^(import |from .+ import)", line) 237 | ] 238 | 239 | # Combining the lines and writing back to 
the file
240 |     return "\n".join(import_lines + other_lines)
241 | 
242 | 
243 | async def validate_python_files(fdir):
244 |     fnames = sorted(fdir.glob("*.py"))
245 |     for fname in tqdm.tqdm(fnames):
246 |         if fname.name == "_all.py":
247 |             continue
248 |         code = Path(fname).read_text()
249 | 
250 |         history = []
251 |         i = 0
252 |         while i < 4:
253 | 
254 | 
255 | 
256 |             try:
257 |                 code = format_code(code)
258 |                 kindofsafe_exec(code)
259 |                 break
260 |             except Exception as e:
261 |                 i += 1
262 |                 if i >= 4:
263 |                     logger.error(f"Code error in: {fname}")
264 |                     return
265 | 
266 |                 logger.warning(f"fixing code error in: {fname} - take {i} - {e}")
267 |                 code, history = await fix_code(
268 |                     code, error=e, tb=traceback.format_exc(), history=history
269 |                 )
270 |         if i > 0:
271 |             logger.warning(f"Fixed code error in: {fname}")
272 |             with open(fname, "w") as f:
273 |                 f.write(code)
274 | 
275 | 
276 | async def python_files_pipeline(fdir):
277 |     await validate_python_files(fdir)
278 |     graph_of_functions_and_objects = aggregate_fn_to_object_mappings(fdir)
279 |     with open(fdir / "_function_graph.json", "w") as f:
280 |         f.write(json.dumps(nx.node_link_data(graph_of_functions_and_objects)))
281 |     nx.write_graphml(graph_of_functions_and_objects, fdir / "_function_graph.graphml")
282 | 
283 |     file_path = aggregate_python_files(fdir)
284 |     results = process_file(file_path)
285 |     print(results)
286 |     return results
287 | 
288 | 
289 | if __name__ == "__main__":
290 |     fdir = Path("data_models/googlemaps.client.Client")
291 |     asyncio.run(python_files_pipeline(fdir))
292 | 
293 | 
294 | 
-------------------------------------------------------------------------------- /promptedgraphs/statistical/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/statistical/__init__.py
-------------------------------------------------------------------------------- /promptedgraphs/statistical/data_analysis.py:
1 | """This module helps understand
2 | the data types and distributions of various data sets."""
3 | 
4 | import concurrent.futures
5 | import os
6 | import time
7 | 
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import pandas as pd
11 | import scipy.stats
12 | import seaborn as sns
13 | import tqdm
14 | from scipy.stats import kstest
15 | 
16 | from promptedgraphs.vis import get_colors
17 | 
18 | # Distribution names from scipy.stats
19 | DISCRETE_DISTRIBUTIONS = [
20 |     "bernoulli",  # -- Bernoulli
21 |     "betabinom",  # -- Beta-Binomial
22 |     "binom",  # -- Binomial
23 |     "boltzmann",  # -- Boltzmann (Truncated Discrete Exponential)
24 |     "dlaplace",  # -- Discrete Laplacian
25 |     "geom",  # -- Geometric
26 |     "hypergeom",  # -- Hypergeometric
27 |     "logser",  # -- Logarithmic (Log-Series, Series)
28 |     "nbinom",  # -- Negative Binomial
29 |     "nchypergeom_fisher",  # Fisher's Noncentral Hypergeometric
30 |     "nchypergeom_wallenius",  # Wallenius's Noncentral Hypergeometric
31 |     "nhypergeom",  # -- Negative Hypergeometric
32 |     "planck",  # -- Planck (Discrete Exponential)
33 |     "poisson",  # -- Poisson
34 |     "randint",  # -- Discrete Uniform
35 |     "skellam",  # -- Skellam
36 |     "yulesimon",  # -- Yule-Simon
37 |     "zipf",  # -- Zipf (Zeta)
38 |     "zipfian",  # -- Zipfian
39 | ]
40 | CONTINUOUS_DISTRIBUTIONS = [
41 |     "alpha",  # -- Alpha
42 |     "anglit",  # -- Anglit
43 | "arcsine", # -- Arcsine 44 | "argus", # -- Argus 45 | "beta", # -- Beta 46 | "betaprime", # -- Beta Prime 47 | "bradford", # -- Bradford 48 | "burr", # -- Burr (Type III) 49 | "burr12", # -- Burr (Type XII) 50 | "cauchy", # -- Cauchy 51 | "chi", # -- Chi 52 | "chi2", # -- Chi-squared 53 | "cosine", # -- Cosine 54 | "crystalball", # Crystalball 55 | "dgamma", # -- Double Gamma 56 | "dweibull", # -- Double Weibull 57 | "erlang", # -- Erlang 58 | "expon", # -- Exponential 59 | "exponnorm", # -- Exponentially Modified Normal 60 | "exponweib", # -- Exponentiated Weibull 61 | "exponpow", # -- Exponential Power 62 | "f", # -- F (Snecdor F) 63 | "fatiguelife", # Fatigue Life (Birnbaum-Saunders) 64 | "fisk", # -- Fisk 65 | "foldcauchy", # -- Folded Cauchy 66 | "foldnorm", # -- Folded Normal 67 | "genlogistic", # Generalized Logistic 68 | "gennorm", # -- Generalized normal 69 | "genpareto", # -- Generalized Pareto 70 | "genexpon", # -- Generalized Exponential 71 | "genextreme", # -- Generalized Extreme Value 72 | "gausshyper", # -- Gauss Hypergeometric 73 | "gamma", # -- Gamma 74 | "gengamma", # -- Generalized gamma 75 | "genhalflogistic", # Generalized Half Logistic 76 | # "genhyperbolic", # Generalized Hyperbolic 77 | # "geninvgauss", # Generalized Inverse Gaussian 78 | "gibrat", # -- Gibrat 79 | "gompertz", # -- Gompertz (Truncated Gumbel) 80 | "gumbel_r", # -- Right Sided Gumbel, Log-Weibull, Fisher-Tippett, Extreme Value Type I 81 | "gumbel_l", # -- Left Sided Gumbel, etc. 82 | "halfcauchy", # -- Half Cauchy 83 | "halflogistic", # Half Logistic 84 | "halfnorm", # -- Half Normal 85 | "halfgennorm", # Generalized Half Normal 86 | "hypsecant", # -- Hyperbolic Secant 87 | "invgamma", # -- Inverse Gamma 88 | "invgauss", # -- Inverse Gaussian 89 | "invweibull", # -- Inverse Weibull 90 | "johnsonsb", # -- Johnson SB 91 | "johnsonsu", # -- Johnson SU 92 | "kappa4", # -- Kappa 4 parameter 93 | "kappa3", # -- Kappa 3 parameter 94 | "ksone", # -- Distribution of Kolmogorov-Smirnov one-sided test statistic 95 | # "kstwo", # -- Distribution of Kolmogorov-Smirnov two-sided test statistic 96 | "kstwobign", # -- Limiting Distribution of scaled Kolmogorov-Smirnov two-sided test statistic. 
97 | "laplace", # -- Laplace 98 | "laplace_asymmetric", # Asymmetric Laplace 99 | "levy", # -- Levy 100 | "levy_l", # 101 | # "levy_stable", # 102 | "logistic", # -- Logistic 103 | "loggamma", # -- Log-Gamma 104 | "loglaplace", # -- Log-Laplace (Log Double Exponential) 105 | "lognorm", # -- Log-Normal 106 | "loguniform", # -- Log-Uniform 107 | "lomax", # -- Lomax (Pareto of the second kind) 108 | "maxwell", # -- Maxwell 109 | "mielke", # -- Mielke's Beta-Kappa 110 | "moyal", # -- Moyal 111 | "nakagami", # -- Nakagami 112 | # "ncx2", # -- Non-central chi-squared 113 | # "ncf", # -- Non-central F 114 | # "nct", # -- Non-central Student's T 115 | "norm", # -- Normal (Gaussian) 116 | # "norminvgauss", # Normal Inverse Gaussian 117 | "pareto", # -- Pareto 118 | "pearson3", # -- Pearson type III 119 | "powerlaw", # -- Power-function 120 | "powerlognorm", # Power log normal 121 | "powernorm", # -- Power normal 122 | "rdist", # -- R-distribution 123 | "rayleigh", # -- Rayleigh 124 | "rel_breitwigner", # Relativistic Breit-Wigner 125 | "rice", # -- Rice 126 | "recipinvgauss", # Reciprocal Inverse Gaussian 127 | "semicircular", # Semicircular 128 | "skewcauchy", # -- Skew Cauchy 129 | "skewnorm", # -- Skew normal 130 | # "studentized_range", # Studentized Range 131 | "t", # -- Student's T 132 | "trapezoid", # -- Trapezoidal 133 | "triang", # -- Triangular 134 | "truncexpon", # -- Truncated Exponential 135 | "truncnorm", # -- Truncated Normal 136 | "truncpareto", # Truncated Pareto 137 | "truncweibull_min", # Truncated minimum Weibull distribution 138 | # "tukeylambda", # Tukey-Lambda 139 | "uniform", # -- Uniform 140 | "vonmises", # -- Von-Mises (Circular) 141 | "vonmises_line", # Von-Mises (Line) 142 | "wald", # -- Wald 143 | "weibull_min", # Minimum Weibull (see Frechet) 144 | "weibull_max", # Maximum Weibull (see Frechet) 145 | "wrapcauchy", # -- Wrapped Cauchy 146 | ] 147 | 148 | 149 | def can_cast_to_ints_without_losing_precision_np_updated( 150 | values: list[float | int], epsilon=1e-9 151 | ): 152 | """ 153 | Check if all values in the list can be safely cast to integers without losing precision using NumPy, 154 | allowing for a small epsilon deviation from integer numbers to handle floating point errors. 155 | 156 | Args: 157 | - values (list or numpy.ndarray): A list or numpy array of values to check. 158 | - epsilon (float): The tolerance for deviation from an exact integer, to handle floating point errors. 159 | 160 | Returns: 161 | - bool: True if all values can be safely cast to integers without significant loss of precision, False otherwise. 
162 |     """
163 |     np_values = np.array(
164 |         values
165 |     )  # Convert the list to a NumPy array if it's not already one
166 |     try:
167 |         fractional_parts = np.abs(np_values - np.round(np_values))
168 |     except TypeError:
169 |         return False
170 |     return np.all(fractional_parts <= epsilon)
171 | 
172 | 
173 | # Fit a single distribution and score it with a Kolmogorov-Smirnov goodness-of-fit test
174 | def _fit_and_test(data, dist):
175 |     try:
176 |         t = time.time()
177 |         dist_params = getattr(scipy.stats, dist).fit(data)
178 |         fit_time = time.time() - t
179 |         ks_statistic, p_value = kstest(data, dist, args=dist_params)
180 |         result = {
181 |             "Parameters": dist_params,
182 |             "KS-Test": ks_statistic,
183 |             "P-Value": p_value,
184 |             "fit_time": fit_time,
185 |         }
186 |         return (dist, result)
187 |     except Exception as e:
188 |         return (dist, f"Error fitting {dist}: {e}")
189 | 
190 | 
191 | def fit_distribution(
192 |     data: np.ndarray,
193 |     discrete_or_continuous: str | None = None,
194 |     max_workers: int | None = None,
195 | ):
196 |     # Default to half the available cores when unset or non-positive
197 |     if max_workers is None or max_workers < 1:
198 |         max_workers = max(1, (os.cpu_count() or 2) // 2)
199 |     assert discrete_or_continuous in ["discrete", "continuous", None], (  # noqa
200 |         "discrete_or_continuous must be one of the following: "
201 |         "['discrete', 'continuous', None]"
202 |     )
203 |     if discrete_or_continuous is None:
204 |         discrete_or_continuous = (
205 |             "discrete"
206 |             if can_cast_to_ints_without_losing_precision_np_updated(data)
207 |             else "continuous"
208 |         )
209 |     if discrete_or_continuous == "discrete":
210 |         dists = DISCRETE_DISTRIBUTIONS
211 |     else:
212 |         dists = CONTINUOUS_DISTRIBUTIONS
213 | 
214 |     dists = list(dists)  # shuffle a copy so the module-level list is not mutated
215 |     np.random.shuffle(dists)
216 | 
217 |     # Use ProcessPoolExecutor to parallelize fitting
218 |     fit_results = {}
219 |     with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
220 |         # Map the function over the distributions
221 |         futures = {executor.submit(_fit_and_test, data, dist): dist for dist in dists}
222 |         ittr = tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(dists))
223 |         for future in ittr:
224 |             dist = futures[future]
225 |             try:
226 |                 result = future.result(timeout=10)  # 10-second timeout for each task
227 |                 if isinstance(result[1], dict):
228 |                     fit_results[result[0]] = result[1]
229 |                     # Show which distributions are still pending
230 |                     print("Remaining:", sorted(set(dists) - set(fit_results.keys())))
231 |                 else:
232 |                     print(result[1])  # Print error message
233 |             except concurrent.futures.TimeoutError:
234 |                 print(f"TimeoutError: Fitting {dist} exceeded 10 seconds.")
235 | 
236 |     return fit_results
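
# Illustrative usage sketch (synthetic sample, not from the original source):
# fit a skewed continuous sample, then rank the candidates by KS statistic.
#
#     samples = scipy.stats.gamma(a=2.0).rvs(size=1000)
#     results = fit_distribution(samples, discrete_or_continuous="continuous")
#     print(pd.DataFrame(results).T.sort_values("KS-Test").head())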
237 | 
238 | 
239 | def plot_fitted(data, results: pd.DataFrame, top_n=5):
240 |     # Keep only the top_n best-fitting distributions by KS statistic
241 |     results = results.sort_values("KS-Test", ascending=True).head(top_n)
242 | 
243 |     # Create a range of values for plotting fitted distributions
244 |     x_values = np.linspace(min(data), max(data), 1000)
245 | 
246 |     # Plot the empirical distribution first
247 |     plt.figure(figsize=(14, 8))
248 |     sns.histplot(
249 |         data,
250 |         kde=True,
251 |         bins=30,
252 |         color="gray",
253 |         stat="density",
254 |         label="Empirical",
255 |         alpha=0.5,
256 |     )
257 | 
258 |     colors = get_colors(results.index.tolist())
259 |     for dist_name, row in results.iterrows():
260 |         print("plotting", dist_name)
261 |         dist_params = row["Parameters"]
262 |         pdf = getattr(scipy.stats, dist_name).pdf(x_values, *dist_params)
263 |         plt.plot(
264 |             x_values, pdf, label=dist_name, color=colors[dist_name][:7], linestyle="--"
265 |         )
266 | 
267 |     plt.title("Empirical Distribution with Fitted Distributions")
268 |     plt.xlabel("Value")
269 |     plt.ylabel("Density")
270 |     plt.legend()
271 |     plt.show()
272 | 
273 | 
274 | def get_posterior_weights(
275 |     data: np.ndarray, results: pd.DataFrame, priors: dict[str, float] | None = None
276 | ):
277 |     uniform_prior = 1 / len(results)
278 |     default_priors = {dist: uniform_prior for dist in results.index}
279 |     priors = priors or default_priors
280 |     log_pdfs = {}
281 |     for dist_name, row in results.iterrows():
282 |         dist_params = row["Parameters"]
283 |         pdf = getattr(scipy.stats, dist_name).pdf(data, *dist_params)
284 |         log_pdfs[dist_name] = np.log(pdf).sum() / len(data)
285 |         if np.isinf(log_pdfs[dist_name]) or np.isnan(log_pdfs[dist_name]):
286 |             log_pdfs[dist_name] = -np.inf
287 | 
288 |     Z = sum(priors[dist] * np.exp(log_pdf) for dist, log_pdf in log_pdfs.items())
289 |     return {
290 |         dist: priors[dist] * np.exp(log_pdf) / Z for dist, log_pdf in log_pdfs.items()
291 |     }
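
# Note on the weighting above: this is a simple Bayesian model comparison.
# Each candidate's mean log-likelihood is exponentiated, multiplied by its
# prior (uniform unless supplied), and normalized so the weights sum to 1.
# Illustrative call, reusing the names from the __main__ block below:
#
#     weights = get_posterior_weights(x, top_dists)
#     best_fit = max(weights, key=weights.get)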
292 | 
293 | 
294 | if __name__ == "__main__":
295 |     # Load the data
296 |     import json
297 | 
298 |     with open(
299 |         "/usr/local/repos/thecrowdsline/thecrowdsline-data/nfl_roster.jsonl"
300 |     ) as f:
301 |         data = pd.DataFrame(json.loads(line) for line in f.readlines())
302 | 
303 |     # Fit one numeric column to the candidate distributions
304 |     c = "draft_number"  # alternatives in this dataset: "age", "entry_year"
305 |     x = data[c].dropna().values
306 |     dists = fit_distribution(x, discrete_or_continuous="continuous", max_workers=4)
307 | 
308 |     df = pd.DataFrame(dists).T
309 |     top_dists = df.loc[df["KS-Test"] < 0.05]
310 |     if len(top_dists) > 0:
311 |         print(top_dists)
312 |         plot_fitted(x, top_dists, top_n=len(top_dists))
313 |         weights = get_posterior_weights(x, top_dists)
314 |         print(top_dists.join(pd.Series(weights, name="Posterior Weight")))
315 |     print(dists)
--------------------------------------------------------------------------------
/promptedgraphs/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/utils/__init__.py
--------------------------------------------------------------------------------
/promptedgraphs/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/promptedgraphs/validation/__init__.py
--------------------------------------------------------------------------------
/promptedgraphs/validation/validate_data.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | 
3 | from promptedgraphs.normalization.object_to_data import object_to_data
4 | 
5 | 
6 | def validate_data(
7 |     data_object: dict,
8 |     schema_spec: dict | None = None,
9 |     data_model: BaseModel | None = None,
10 | ) -> bool:
11 |     """Validates a data object against a schema specification or a Pydantic data model.
12 | 
13 |     Args:
14 |         data_object (dict): The data object to validate.
15 |         schema_spec (Optional[dict], optional): The schema specification to validate against. Defaults to None.
16 |         data_model (Optional[BaseModel], optional): The Pydantic BaseModel to validate against. Defaults to None.
17 | 
18 |     Returns:
19 |         bool: True if the data object is valid, False otherwise.
20 |     """
21 |     try:
22 |         object_to_data(
23 |             data_object=data_object,
24 |             schema_spec=schema_spec,
25 |             data_model=data_model,
26 |             coerce=False,
27 |         )
28 |         return True
29 |     except ValueError:
30 |         return False
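
# Illustrative usage (the schema here is made up for the example):
#
#     ok = validate_data(
#         {"age": 30},
#         schema_spec={"type": "object", "properties": {"age": {"type": "integer"}}},
#     )
#
# Because object_to_data is called with coerce=False, any value that cannot be
# built without coercion is reported as invalid rather than silently converted.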
--------------------------------------------------------------------------------
/promptedgraphs/validation/validate_schema.py:
--------------------------------------------------------------------------------
1 | import jsonschema
2 | 
3 | 
4 | def validate_schema(schema_object: dict, schema_spec: dict) -> bool:
5 |     """Validates a schema object against a schema specification.
6 | 
7 |     Args:
8 |         schema_object (dict): The schema object to validate.
9 |         schema_spec (dict): The schema specification to validate against.
10 | 
11 |     Returns:
12 |         bool: True if the schema object adheres to the schema specification, False otherwise.
13 |     """
14 |     # Minimal sketch: the original body was empty. Assumes the `jsonschema`
15 |     # package is available; it is not currently listed in pyproject.toml.
16 |     try:
17 |         jsonschema.validate(instance=schema_object, schema=schema_spec)
18 |         return True
19 |     except jsonschema.exceptions.ValidationError:
20 |         return False
--------------------------------------------------------------------------------
/promptedgraphs/vis.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | from pydantic import BaseModel
3 | from spacy import displacy
4 | 
5 | from promptedgraphs.models import EntityReference
6 | 
7 | 
8 | def rgb_to_hex(rgb):
9 |     # Scale by 255 (not 256) so that 1.0 maps to "ff" instead of overflowing two hex digits
10 |     return "#{:02x}{:02x}{:02x}".format(
11 |         int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)
12 |     )
13 | 
14 | 
15 | def get_fields(ents: list[EntityReference] | dict | BaseModel):
16 |     if hasattr(ents, "model_dump"):
17 |         data = ents.model_dump()
18 |         return sorted(data.keys())
19 |     elif isinstance(ents, dict):
20 |         return sorted(ents.keys())
21 |     elif ents and isinstance(ents[0], EntityReference):
22 |         return sorted({e.label for e in ents})
23 |     else:
24 |         raise ValueError("ents must be a list of EntityReference, BaseModel or dict")
25 | 
26 | 
27 | def get_colors(fields: list[str], color_palette: list[float] | None = None):
28 |     palette = color_palette or sns.color_palette("Set2", len(fields))
29 |     return {f: rgb_to_hex(color)[:7] for f, color in zip(list(fields), palette)}
30 | 
31 | 
32 | def ensure_entities(
33 |     ents: list[EntityReference] | dict | BaseModel, text: str
34 | ) -> list[EntityReference] | None:
35 |     if ents is None:
36 |         return None
37 |     elif hasattr(ents, "model_dump"):
38 |         data = ents.model_dump()
39 |         return [
40 |             EntityReference(
41 |                 start=text.find(v), end=text.find(v) + len(v), label=k, text=v
42 |             )
43 |             for k, v in data.items()
44 |             if isinstance(v, str) and v in text
45 |         ]
46 |     elif isinstance(ents, dict):
47 |         data = ents
48 |         return [
49 |             EntityReference(
50 |                 start=text.find(v), end=text.find(v) + len(v), label=k, text=v
51 |             )
52 |             for k, v in data.items()
53 |             if isinstance(v, str) and v in text
54 |         ]
55 |     return [e for e in ents if isinstance(e, EntityReference)]
56 | 
57 | 
58 | def render_entities(
59 |     text: str,
60 |     ents: BaseModel | list[EntityReference] | dict | None = None,
61 |     jupyter=True,
62 |     color_dict: dict | None = None,
63 |     color_palette: list[float] | None = None,
64 |     **options
65 | ):
66 |     """Renders entities using the displacy.render function"""
67 | 
68 |     if ents is None:
69 |         return displacy.render(
70 |             {"text": text}, style="ent", jupyter=jupyter, manual=True
71 |         )
72 | 
73 |     ents = ensure_entities(ents, text)
74 |     fields = get_fields(ents)
75 |     color_dict = color_dict or get_colors(fields, color_palette)
76 | 
77 |     return displacy.render(
78 |         {
79 |             "text": text,
80 |             "ents": [e.to_dict() for e in ents],
81 |         },
82 |         style="ent",
83 |         jupyter=jupyter,
84 |         manual=True,
85 |         options={"colors": color_dict, **options},
86 |     )
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "promptedgraphs"
3 | version = "0.4.3"
4 | description = "From Dataset Labeling to Deployment: The Power of NLP and LLMs Combined."
5 | authors = ["Sean Kruzel "]
6 | license = "MIT"
7 | readme = "README.md"
8 | packages = [{ include = "promptedgraphs" }]
9 | classifiers = [
10 |     "Topic :: Software Development :: Libraries :: Python Modules",
11 |     "Programming Language :: Python :: 3.10",
12 |     "Operating System :: OS Independent",
13 | ]
14 | 
15 | [tool.poetry.urls]
16 | "Repository" = "https://github.com/closedLoop-technologies/promptedgraphs"
17 | "Bug Tracker" = "https://github.com/closedLoop-technologies/promptedgraphs/issues"
18 | 
19 | [tool.poetry.dependencies]
20 | python = "<4.0,>=3.10"
21 | pre-commit = "^3.3.1"
22 | gcloud = "^0.18.3"
23 | typer = "^0.9.0"
24 | pyfiglet = "^1.0.2"
25 | termcolor = "^2.4.0"
26 | tabulate = "^0.9.0"
27 | colorama = "^0.4.6"
28 | pydantic = "^2.6.3"
29 | tiktoken = "^0.6.0"
30 | httpx = "^0.27.0"
31 | spacy = "^3.7.4"
32 | googlemaps = "^4.10.0"
33 | beautifulsoup4 = "^4.12.3"
34 | networkx = "^3.2.1"
35 | pandas = "^2.2.1"
36 | seaborn = "^0.13.2"
37 | nltk = "^3.8.1"
38 | transformers = "^4.37.2"
39 | matplotlib = "^3.8.0"
40 | jupyter = "^1.0.0"
41 | openai = "^1.12.0"
42 | scikit-learn = "^1.4.0"
43 | cohere = "^4.47"
44 | flair = "^0.13.1"
45 | spacy-experimental = "^0.6.4"
46 | ipykernel = "^6.29.2"
47 | amrlib = "^0.8.0"
48 | unidecode = "^1.3.8"
49 | penman = "^1.3.0"
50 | datamodel-code-generator = "^0.25.5"
51 | pyarrow = "^15.0.2"
52 | torch = "^2.2.2"
53 | pydot = "^2.0.0"
54 | jsonpath-ng = "^1.6.1"
55 | 
56 | [tool.poetry.group.dev.dependencies]
57 | black = "^23.3.0"
58 | coverage = "^7.2.5"
59 | eradicate = "^2.2.0"
60 | isort = "^5.12.0"
61 | mypy = "^1.3.0"
62 | pre-commit-hooks = "^4.4.0"
63 | pre-commit = "^3.6.2"
64 | pycodestyle = "^2.10.0"
65 | pyflakes = "^3.0.1"
66 | pytest-cov = "^4.0.0"
67 | pytest = "^7.3.1"
68 | radon = "^6.0.1"
69 | vulture = "^2.7"
70 | pip-upgrader = "^1.4.15"
71 | pyupgrade = "^3.4.0"
72 | ipykernel = "^6.25.2"
73 | twine = "^5.0.0"
74 | 
75 | [tool.poetry.scripts]
76 | promptedgraphs = "promptedgraphs.cli:app"
77 | 
78 | [build-system]
79 | requires = ["poetry-core"]
80 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/run_security_check.sh:
--------------------------------------------------------------------------------
1 | sudo docker run -v .:/path zricethezav/gitleaks:latest detect --source="/path" --report-path /path/.gitleaks-report.json
2 | 
--------------------------------------------------------------------------------
/run_tests_with_coverage.sh:
--------------------------------------------------------------------------------
1 | python -m coverage erase
2 | python -m coverage run -m unittest discover
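# The next two commands summarize the .coverage data collected above:
# a terminal report, then an HTML report (written to htmlcov/ by default).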
3 | python -m coverage report
4 | python -m coverage html
5 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/tests/__init__.py
--------------------------------------------------------------------------------
/tests/all.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | if __name__ == "__main__":
4 |     import pytest
5 | 
6 |     # run tests with code coverage (pytest-cov only collects it when --cov is passed)
7 |     pytest.main(["-s", "--tb=native", "--cov=promptedgraphs", "--cov-config=.coveragerc", "--maxfail=2"])
--------------------------------------------------------------------------------
/tests/generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/closedloop-technologies/PromptedGraphs/8a2b9c1c77cadcc89eb22c74500f114069c2ef1f/tests/generation/__init__.py
--------------------------------------------------------------------------------
/tests/generation/test_schema_from_data.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from typing import Any, Dict, List
3 | 
4 | from promptedgraphs.generation.schema_from_data import schema_from_data
5 | 
6 | 
7 | class TestSchemaFromData(unittest.TestCase):
8 |     def test_empty_data_samples(self):
9 |         data_samples: List[Dict[str, Any]] = []
10 |         expected_schema = {}
11 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
12 | 
13 |     def test_single_data_sample(self):
14 |         data_samples = [{"name": "John", "age": 30, "city": "New York"}]
15 |         expected_schema = {
16 |             "type": "object",
17 |             "properties": {
18 |                 "name": {"type": "string", "example": "John"},
19 |                 "age": {"type": "integer", "example": 30},
20 |                 "city": {"type": "string", "example": "New York"},
21 |             },
22 |             "required": sorted(["name", "age", "city"]),
23 |         }
24 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
25 | 
26 |     def test_multiple_data_samples(self):
27 |         data_samples = [
28 |             {"name": "John", "age": 30, "city": "New York"},
29 |             {"name": "Alice", "age": 25, "city": "London"},
30 |             {"name": "Bob", "age": 35, "country": "USA"},
31 |         ]
32 |         expected_schema = {
33 |             "type": "object",
34 |             "properties": {
35 |                 "name": {"type": "string", "example": "John"},
36 |                 "age": {"type": "integer", "example": 30},
37 |                 "city": {"type": "string", "example": "New York"},
38 |                 "country": {"type": "string", "example": "USA"},
39 |             },
40 |             "required": ["age", "name"],
41 |         }
42 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
43 | 
44 |     def test_nested_objects(self):
45 |         data_samples = [
46 |             {"person": {"name": "Alice", "age": 25}, "city": "London"},
47 |             {"person": {"name": "John", "age": 30}, "city": "New York"},
48 |         ]
49 |         expected_schema = {
50 |             "type": "object",
51 |             "properties": {
52 |                 "person": {
53 |                     "type": "object",
54 |                     "properties": {
55 |                         "name": {"type": "string", "example": "John"},
56 |                         "age": {"type": "integer", "example": 30},
57 |                     },
58 |                     # "required": ["age", "name"],
59 |                     "example": {"name": "Alice", "age": 25},
60 |                 },
61 |                 "city": {"type": "string", "example": "London"},
62 |             },
63 |             "required": ["city", "person"],
64 |             # "example": {"person": {"name": "John", "age": 30}, "city": "New York"}
65 |         }
66 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
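
    # Taken together, these tests pin down schema_from_data's merging contract:
    # properties present in every sample are listed in "required", conflicting
    # scalar types collapse into an "anyOf", and the first observed value is
    # kept as the "example".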
67 | 
68 |     def test_arrays(self):
69 |         data_samples = [
70 |             {"name": "John", "hobbies": ["reading", "swimming"]},
71 |             {"name": "Alice", "hobbies": ["painting", "dancing"]},
72 |         ]
73 |         expected_schema = {
74 |             "type": "object",
75 |             "properties": {
76 |                 "name": {"type": "string", "example": "John"},
77 |                 "hobbies": {
78 |                     "type": "array",
79 |                     "items": {"type": "string", "example": "reading"},
80 |                     "example": ["reading", "swimming"],
81 |                 },
82 |             },
83 |             "required": sorted(["name", "hobbies"]),
84 |         }
85 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
86 | 
87 |     def test_type_merging(self):
88 |         data_samples = [{"value": 10}, {"value": "20"}, {"value": 30.5}]
89 |         expected_schema = {
90 |             "type": "object",
91 |             "properties": {
92 |                 "value": {
93 |                     "anyOf": [{"type": "integer"}, {"type": "number"}],
94 |                     "example": 10,
95 |                 }
96 |             },
97 |             "required": ["value"],
98 |         }
99 |         self.assertEqual(schema_from_data(data_samples), expected_schema)
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | """Tests the basic CLI commands for promptedgraphs"""
2 | import unittest
3 | 
4 | from typer.testing import CliRunner
5 | 
6 | from promptedgraphs import cli
7 | 
8 | 
9 | class TestCLI(unittest.TestCase):
10 |     def setUp(self):
11 |         self.runner = CliRunner()
12 | 
13 |     def test_info(self):
14 |         result = self.runner.invoke(cli.app, ["info"])
15 |         self.assertEqual(result.exit_code, 0)
16 |         self.assertGreater(len(result.stdout), 0)
17 |         self.assertIn("version", result.stdout.lower().strip(), "version not in output")
18 | 
19 |     def test_main(self):
20 |         result = self.runner.invoke(cli.app, ["main"])
21 |         self.assertEqual(result.exit_code, 0)
22 | 
23 |     def test_help(self):
24 |         result = self.runner.invoke(cli.app, ["--help"])
25 |         self.assertEqual(result.exit_code, 0)
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from promptedgraphs import __title__ as name
4 | from promptedgraphs.config import Config
5 | 
6 | 
7 | class TestConfig(unittest.TestCase):
8 |     def test_configclass(self):
9 |         config = Config()
10 |         self.assertIsInstance(config, Config, msg="config is not a Config")
11 |         self.assertEqual(config.name, name)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_install.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from subprocess import run
3 | 
4 | 
5 | class TestInstall(unittest.TestCase):
6 |     def test_library_installed(self):
7 |         import promptedgraphs
8 | 
9 |         self.assertIsNotNone(promptedgraphs)
10 | 
11 |     def test_module(self):
12 |         # check=True makes the test fail if the module entry point exits non-zero
13 |         run(["python3", "-m", "promptedgraphs", "--help"], check=True)
14 | 
15 |     # def test_consolescript(self):
16 |     #     run(["promptedgraphs", "--help"])
--------------------------------------------------------------------------------