├── .devcontainer └── devcontainer.json ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── ci.yml │ ├── publish-docs.yml │ └── publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── collect_script.py ├── docs ├── explanation │ ├── agent-workflow.md │ ├── code-sandbox.md │ ├── file-reading.ipynb │ └── ipython-startup-scripts.md ├── howto │ ├── cleanup-error-trace.md │ ├── customize-table-info.md │ ├── incluster-code-execution.md │ ├── messages-truncation.ipynb │ ├── normalize-datasets.ipynb │ ├── persist-messages.ipynb │ └── retrieval.ipynb ├── index.md ├── reference.md ├── static │ └── result.png ├── stylesheets │ └── extra.css └── tutorials │ ├── chat-on-tabular-data.ipynb │ ├── continue-analysis-on-generated-charts.ipynb │ └── quick-start.ipynb ├── examples ├── __init__.py ├── data_analysis.py ├── datasets │ ├── titanic.csv │ ├── 产品生产统计表.xlsx │ └── 产品销量表.csv └── quick_start.py ├── ipython ├── README.md ├── ipython-startup-scripts │ ├── 00-pandas.py │ ├── 98-udfs.py │ └── 99-cfont.py └── requirements.txt ├── mkdocs.yml ├── pyproject.toml ├── realtabbench ├── README.md ├── __init__.py ├── agent_eval │ ├── README.md │ ├── __init__.py │ ├── __main__.py │ ├── config.py │ ├── evaluatee.py │ ├── evaluator │ │ ├── __init__.py │ │ ├── output_parser.py │ │ └── prompt.py │ ├── example-config.yaml │ ├── questioner.py │ ├── requirements.txt │ ├── runner.py │ ├── tablegpt_evaluatee.py │ └── worker.py ├── evalset │ ├── bird_data │ │ ├── dev.json │ │ ├── dev.sql │ │ └── dev_tables.json │ └── spider_data │ │ ├── dev.json │ │ ├── dev_gold.sql │ │ ├── test.json │ │ ├── test_gold.sql │ │ └── test_tables.json ├── inference.py ├── inference_encoder.py ├── requirements.txt ├── run_text2sql_eval.py ├── text2sql │ ├── __init__.py │ └── src │ │ ├── __init__.py │ │ ├── evaluation.py │ │ ├── gpt_request.py │ │ └── gpt_request_encoder.py └── utils.py ├── src └── tablegpt │ ├── __about__.py │ 
├── __init__.py │ ├── agent │ ├── __init__.py │ ├── data_analyzer.py │ ├── file_reading │ │ ├── __init__.py │ │ └── data_normalizer.py │ └── output_parser.py │ ├── errors.py │ ├── retriever │ ├── __init__.py │ ├── compressor.py │ └── loader.py │ ├── safety.py │ ├── tools.py │ ├── translation.py │ └── utils.py └── tests ├── __init__.py ├── agent ├── __init__.py ├── file_reading │ ├── __init__.py │ └── test_data_normalizer.py └── test_output_parser.py ├── retriever ├── __init__.py ├── test_compressor.py ├── test_format.py └── test_loader.py ├── test_profile_init.py ├── test_safety.py ├── test_tools.py └── test_utils.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile 3 | { 4 | "name": "tablegpt-agent", 5 | "image": "mcr.microsoft.com/devcontainers/python:1-3.12", 6 | 7 | "containerEnv" : { 8 | // This will instruct hatch to create envs in the workspace folder. 9 | // It makes selecting interpreter simpler. 10 | "HATCH_DATA_DIR": "${containerWorkspaceFolder}" 11 | }, 12 | 13 | // Use 'postCreateCommand' to run commands after the container is created. 
14 | "postCreateCommand": "pip3 install hatch", 15 | 16 | // See https://stackoverflow.com/questions/70206554/share-ssh-keys-with-vs-code-devcontainer-running-with-dockers-wsl2-backend 17 | "mounts": [ 18 | "type=bind,source=${localEnv:HOME}${localEnv:USERPROFILE}/.ssh,target=/home/vscode/.ssh,readonly" 19 | ] 20 | } 21 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.{cmd,[cC][mM][dD]} text eol=crlf 3 | *.{bat,[bB][aA][tT]} text eol=crlf 4 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "" 5 | labels: bug 6 | assignees: "" 7 | --- 8 | 9 | - [ ] I have searched the issue tracker and believe that this is not a duplicate. 10 | 11 | 12 | ## Run `python collect_script.py` and paste or upload the resulting text file here. 13 | 14 | 15 | 16 | > If you are using TableGPT2 deployed with vLLM, please specify the vLLM version and include the command used to start the server. 17 | > 18 | > If not, you may skip this section. 
19 | ## vLLM version 20 | 21 | ### The version of the vLLM 22 | 23 | 24 | ### The start command of the vLLM serve 25 | 26 | 27 | 28 | ## Steps to reproduce 29 | 30 | 31 | 32 | 33 | ## Actual behavior 34 | 35 | 36 | 37 | ## Expected behavior 38 | 39 | 40 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: "CI" 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "src/**" 8 | - "tests/**" 9 | - "Makefile" 10 | - "pyproject.toml" 11 | pull_request: 12 | branches: [ main ] 13 | paths: 14 | - "src/**" 15 | - "tests/**" 16 | - "Makefile" 17 | - "pyproject.toml" 18 | 19 | jobs: 20 | lint-test: 21 | name: "lint & tests" 22 | runs-on: ubuntu-latest 23 | strategy: 24 | matrix: 25 | python-version: ["3.9", "3.10", "3.11", "3.12"] 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python ${{ matrix.python-version }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 34 | - name: Set up pip cache 35 | if: runner.os == 'Linux' 36 | uses: actions/cache@v4 37 | with: 38 | path: ~/.cache/pip 39 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} 40 | restore-keys: ${{ runner.os }}-pip- 41 | 42 | - name: Install hatch 43 | run: | 44 | pipx install hatch 45 | 46 | - name: Lint 47 | run: make lint 48 | 49 | - name: Tests 50 | run: make test 51 | 52 | install-ubuntu: 53 | name: "install on ubuntu" 54 | runs-on: ubuntu-latest 55 | strategy: 56 | matrix: 57 | python-version: ["3.9", "3.10", "3.11", "3.12"] 58 | 59 | steps: 60 | - uses: actions/checkout@v4 61 | - name: Set up Python ${{ matrix.python-version }} 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: ${{ matrix.python-version }} 65 | # 66 | cache: 'pip' 67 | 68 | - name: Install tablegpt-agent 69 | run: | 70 | pip install -e . 
71 | 72 | install-win: 73 | name: "install on windows" 74 | runs-on: windows-latest 75 | strategy: 76 | matrix: 77 | python-version: ["3.9", "3.10", "3.11", "3.12"] 78 | 79 | steps: 80 | - uses: actions/checkout@v4 81 | - name: Set up Python ${{ matrix.python-version }} 82 | uses: actions/setup-python@v5 83 | with: 84 | python-version: ${{ matrix.python-version }} 85 | # 86 | cache: 'pip' 87 | 88 | - name: Install tablegpt-agent 89 | run: | 90 | pip install -e . 91 | -------------------------------------------------------------------------------- /.github/workflows/publish-docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - "docs/**" 8 | - "mkdocs.yml" 9 | workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI 10 | 11 | jobs: 12 | run: 13 | name: "deploy docs" 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | with: 19 | # See 20 | fetch-depth: 0 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: 3.12 25 | 26 | - name: Set up pip cache 27 | if: runner.os == 'Linux' 28 | uses: actions/cache@v4 29 | with: 30 | path: ~/.cache/pip 31 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} 32 | restore-keys: ${{ runner.os }}-pip- 33 | 34 | - name: Install hatch 35 | run: | 36 | pipx install hatch 37 | 38 | - name: Publish doc 39 | run: hatch env run -e docs mkdocs gh-deploy 40 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | permissions: 13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 14 | contents: read # IMPORTANT: this 
permission is mandatory for private repositories 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.11' 22 | cache: 'pip' 23 | - name: Install dependencies 24 | run: | 25 | pipx install hatch 26 | - name: Build package 27 | run: hatch build 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | verbose: true 32 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: check-yaml 6 | args: [--allow-multiple-documents] 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | args: [--markdown-linebreak-ext=md] 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.5.1 12 | hooks: 13 | # Run the linter. 14 | - id: ruff 15 | # It is recommended to specify the latest version of Python 16 | # supported by your project here, or alternatively use 17 | # pre-commit's default_language_version, see 18 | # https://pre-commit.com/#top_level-default_language_version 19 | language_version: python3.12 20 | args: [ --fix ] 21 | # Run the formatter. 22 | - id: ruff-format 23 | language_version: python3.12 24 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Welcome to TableGPT-Agent contributing guide 2 | 3 | Thank you for investing your time in contributing to our project! :sparkles:. 4 | 5 | In this guide you will get an overview of the contribution workflow from opening an issue, creating a PR, reviewing, and merging the PR. 
6 | 7 | ## New contributor guide 8 | 9 | To get an overview of the project, read the [README](./README.md) file. Here are some resources to help you get started with open source contributions: 10 | 11 | - [Set up Git](https://docs.github.com/en/get-started/getting-started-with-git/set-up-git) 12 | - [GitHub flow](https://docs.github.com/en/get-started/using-github/github-flow) 13 | - [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests) 14 | 15 | ## Get Started 16 | 17 | ### Create a new issue 18 | 19 | If you spot a problem with TableGPT, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can [open a new issue](https://github.com/tablegpt/tablegpt-agent/issues/new). 20 | 21 | ### Solve an issue 22 | 23 | Once you are assigned an issue, you can start working on it. You can scan through our [existing issues](https://github.com/tablegpt/tablegpt-agent/issues) to find one that is assigned to you. You can narrow down the search using `labels` as filters. 24 | 25 | 1. Fork the repository. 26 | 27 | 2. Set up your development environment (the project is managed with `hatch`; a ready-to-use dev container is provided in `.devcontainer`). 28 | 29 | 3. Create a working branch and start with your changes! 30 | 31 | ### Commit your update 32 | 33 | Commit the changes once you are happy with them. To speed up the review process, make sure your commit messages are clear and concise. We follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) standard for commit messages. 34 | 35 | ### Pull Request 36 | 37 | When you're finished with the changes, create a pull request, also known as a PR. 38 | 39 | - Don't forget to link your PR to the issue if you are solving one. 40 | - Once you submit your PR, a maintainer will review your proposal. We may ask questions or request additional information.
41 | - We may ask for changes to be made before a PR can be merged, either using suggested changes or pull request comments. You can make any other changes in your fork, then commit them to your branch. 42 | - As you update your PR and apply changes, mark each conversation as `resolved`. 43 | - If you run into any merge issues, check out this [Git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues. 44 | 45 | ### Code Quality 46 | 47 | Before your PR gets merged, we will check the code quality. We use [GitHub Actions](https://docs.github.com/en/actions/) to automate the process. You can inspect the detailed workflow in the [CI workflow](./.github/workflows/ci.yml). 48 | 49 | If you want to check the code quality locally, you can use the following command: 50 | 51 | ```sh 52 | make lint && make test 53 | ``` 54 | 55 | In addition to the automated checks, we also have a code review process. The reviewers will provide feedback on your PR and ask for changes if necessary. The feedback is mainly based on Google's [Python style guide](https://google.github.io/styleguide/pyguide.html). 56 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | # Default target executed when no arguments are given to make.
3 | all: help 4 | 5 | lint: 6 | hatch fmt --check 7 | 8 | format: 9 | hatch fmt 10 | 11 | test: 12 | hatch test 13 | 14 | wheel: 15 | hatch build 16 | 17 | # A 'docs' directory exists, so 'make docs' would consider that target up to date; use 'doc' instead 18 | doc: 19 | hatch env run -e docs mkdocs build 20 | 21 | clean: 22 | hatch clean 23 | 24 | ###################### 25 | # HELP 26 | ###################### 27 | 28 | help: 29 | @echo '----' 30 | @echo 'lint - run linters' 31 | @echo 'format - run code formatters' 32 | @echo 'test - run unit tests' 33 | @echo 'wheel - build wheel package' 34 | @echo 'doc - build documentation site' 35 | @echo 'clean - clean up' 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TableGPT Agent 2 | 3 | [![PyPI - Version](https://img.shields.io/pypi/v/tablegpt-agent.svg)](https://pypi.org/project/tablegpt-agent) 4 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/tablegpt-agent.svg)](https://pypi.org/project/tablegpt-agent) 5 | 6 | ----- 7 | 8 | ## Introduction 9 | 10 | `tablegpt-agent` is a pre-built agent for TableGPT2 ([Hugging Face](https://huggingface.co/collections/tablegpt/tablegpt2-67265071d6e695218a7e0376)), a series of LLMs for table-based question answering. This agent is built on top of the [LangGraph](https://github.com/langchain-ai/langgraph) library and provides a user-friendly interface for interacting with TableGPT2. 11 | 12 | You can find the full documentation at 13 | 14 | ## Evaluation 15 | 16 | This repository also includes a collection of evaluation scripts for table-related benchmarks. The evaluation scripts and datasets can be found in the `realtabbench` directory. For more details, please refer to the [Evaluation README](realtabbench/README.md).
17 | 18 | ## Liscence 19 | 20 | `tablegpt-agent` is distributed under the terms of the [Apache 2.0](https://spdx.org/licenses/Apache-2.0.html) license. 21 | 22 | ## Model Card 23 | 24 | For more information about TableGPT2, see the [TableGPT2 Model Card](https://huggingface.co/tablegpt/tablegpt). 25 | 26 | ## Citation 27 | 28 | If you find our work helpful, please cite us by 29 | 30 | ```bibtex 31 | @misc{su2024tablegpt2largemultimodalmodel, 32 | title={TableGPT2: A Large Multimodal Model with Tabular Data Integration}, 33 | author={Aofeng Su and Aowen Wang and Chao Ye and Chen Zhou and Ga Zhang and Guangcheng Zhu and Haobo Wang and Haokai Xu and Hao Chen and Haoze Li and Haoxuan Lan and Jiaming Tian and Jing Yuan and Junbo Zhao and Junlin Zhou and Kaizhe Shou and Liangyu Zha and Lin Long and Liyao Li and Pengzuo Wu and Qi Zhang and Qingyi Huang and Saisai Yang and Tao Zhang and Wentao Ye and Wufang Zhu and Xiaomeng Hu and Xijun Gu and Xinjie Sun and Xiang Li and Yuhang Yang and Zhiqing Xiao}, 34 | year={2024}, 35 | eprint={2411.02059}, 36 | archivePrefix={arXiv}, 37 | primaryClass={cs.LG}, 38 | url={https://arxiv.org/abs/2411.02059}, 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /collect_script.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import subprocess 3 | import sys 4 | 5 | 6 | def get_os_info(): 7 | return { 8 | "system": platform.system(), 9 | "node": platform.node(), 10 | "release": platform.release(), 11 | "version": platform.version(), 12 | "machine": platform.machine(), 13 | "processor": platform.processor(), 14 | } 15 | 16 | 17 | def get_python_info(): 18 | return { 19 | "implementation": platform.python_implementation(), 20 | "version": platform.python_version(), 21 | "compiler": platform.python_compiler(), 22 | } 23 | 24 | 25 | def get_pip_list(): 26 | result = subprocess.run( 27 | [sys.executable, "-m", "pip", "list"], 28 | 
capture_output=True, 29 | text=True, 30 | check=False, 31 | ) 32 | if result.returncode == 0: 33 | return result.stdout 34 | 35 | return f"Failed to get pip list: {result.stderr}" 36 | 37 | 38 | def write_to_log_file(content, filename="env_output.log"): 39 | with open(filename, "w") as file: 40 | file.write(content) 41 | 42 | 43 | def main(): 44 | os_info = get_os_info() 45 | python_info = get_python_info() 46 | pip_list = get_pip_list() 47 | 48 | content = "Operating System Information:\n" 49 | for key, value in os_info.items(): 50 | content += f"{key}: {value}\n" 51 | 52 | content += "\nPython Information:\n" 53 | for key, value in python_info.items(): 54 | content += f"{key}: {value}\n" 55 | 56 | content += "\nPip List:\n" 57 | content += pip_list 58 | 59 | # stdout 60 | print(content) # noqa: T201 61 | 62 | # file 63 | write_to_log_file(content) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /docs/explanation/agent-workflow.md: -------------------------------------------------------------------------------- 1 | # Agent Workflow 2 | 3 | The Agent Workflow is the core functionality of the `tablegpt-agent`. It processes user input and generates appropriate responses. This workflow is similar to those found in most single-agent systems and consists of an agent and various tools. Specifically, the data analysis workflow includes: 4 | 5 | - **An Agent Powered by TableGPT2**: This agent performs data analysis tasks. It is designed to understand and execute complex data analysis queries, providing accurate and insightful results. 6 | - **An IPython tool**: This tool executes the generated code within a sandbox environment, ensuring that the code runs safely and efficiently. 
7 | 8 | Additionally, TableGPT Agent offers several optional plugins that extend the agent's functionality: 9 | 10 | - **Visual Language Model**: This plugin can be used to enhance summarization for data visualization tasks. 11 | - **Retriever**: This plugin fetches information about the dataset, improving the quality and relevance of the generated code. 12 | - **Safety Mechanism**: This plugin protects the system from toxic inputs. 13 | 14 | ## Workflow Steps 15 | 16 | 1. **User Input**: The user provides a query or command to the agent. 17 | 2. **Security Assessment (optional)**: The agent evaluates whether the user's query involves sensitive topics. If it does, the agent will prompt the LLM to be cautious in its response. 18 | 3. **Data Retrieval (optional)**: The retriever plugin fetches relevant data and metadata. 19 | 4. **Code Generation**: The agent generates the appropriate code to perform the requested task. 20 | 5. **Code Execution**: The generated code is executed in the IPython sandbox environment. 21 | 6. **Result Generation**: The agent processes the results of the code execution and generates a response. 22 | 7. **Visual Analysis (optional)**: The agent analyzes and summarizes the generated charts to enrich the answer. 23 | 24 | **NOTE:** During the operation of `tablegpt-agent`, the system will repeatedly attempt to resolve any issues that arise during code execution. As a result, steps 4–7 may be executed multiple times in an iterative debugging process. This cycle continues until the final solution is reached or the maximum iteration limit is exceeded. The default maximum iteration count is **25**. 25 | 26 | ## Example Workflow 27 | 28 | ### User Input 29 | 30 | The user inputs a query: 31 | 32 | > Draw a pie chart of the male-female ratio based on gender. 33 | 34 | ### Security Assessment (optional) 35 | 36 | When creating a `tablegpt-agent`, you can specify a dedicated safety evaluation model using the `safety_llm` parameter.
For example: 37 | 38 | ```python 39 | from langchain_openai import ChatOpenAI 40 | from tablegpt.agent import create_tablegpt_graph 41 | 42 | safety_llm = ChatOpenAI(openai_api_base="LLM_URL", openai_api_key="whatever", model_name="LLM_MODEL_NAME") 43 | 44 | agent = create_tablegpt_graph( 45 | ... 46 | safety_llm=safety_llm, 47 | ... 48 | ) 49 | ``` 50 | 51 | If `safety_llm` is configured, it will evaluate user inputs to identify if they fall under any of the following `14` hazard categories. If a question is flagged as belonging to one of these categories, `tablegpt-agent` will respond with caution. 52 | 53 | | Hazard Category | Description | 54 | | ---- | ---- | 55 | | S1: Violent Crimes | Violent criminal activities | 56 | | S2: Non-Violent Crimes | Non-violent criminal activities | 57 | | S3: Sex-Related Crimes | Crimes related to sexual misconduct | 58 | | S4: Child Sexual Exploitation | Exploitation of children | 59 | | S5: Defamation | Defamatory content | 60 | | S6: Specialized Advice | Professional advice (e.g., medical, legal) | 61 | | S7: Privacy | Privacy violations | 62 | | S8: Intellectual Property | Intellectual property issues | 63 | | S9: Indiscriminate Weapons | Use or production of indiscriminate weapons | 64 | | S10: Hate | Hateful or discriminatory content | 65 | | S11: Suicide & Self-Harm | Suicide or self-harm-related content | 66 | | S12: Sexual Content | Explicit sexual content | 67 | | S13: Elections | Content related to elections | 68 | | S14: Code Interpreter Abuse | Misuse of code interpretation features | 69 | 70 | This feature enhances the safety of the `tablegpt-agent`, helping to mitigate ethical and legal risks associated with generated content. 71 | 72 | ### Data Retrieval (optional) 73 | 74 | The retriever plugin recalls columns and values related to the query, enhancing the LLM's understanding of the dataset. This improves the accuracy of the code generated by the LLM. 
For detailed usage instructions, refer to [Enhance TableGPT Agent with RAG](../../howto/retrieval). 75 | 76 | For this example, based on the user’s input, the retrieved results are as follows: 77 | 78 | ```pycon 79 | Here are some extra column information that might help you understand the dataset:\n- titanic.csv:\n - {"column": Sex, "dtype": "string", "values": ["male", "female", ...]} 80 | ``` 81 | 82 | ### Code Generation 83 | The agent generates the following Python code: 84 | ```python 85 | import seaborn as sns 86 | import matplotlib.pyplot as plt 87 | 88 | # Count the number of males and females 89 | gender_counts = df1['Sex'].value_counts() 90 | 91 | # Create a pie chart 92 | plt.figure(figsize=(6, 6)) 93 | plt.pie(gender_counts, labels=gender_counts.index, autopct='%1.1f%%', startangle=140) 94 | plt.title('Gender Distribution') 95 | plt.show() 96 | ``` 97 | 98 | ### Code Execution 99 | 100 | The generated code is automatically executed in the IPython sandbox environment. 101 | 102 | ### Result Generation 103 | 104 | After the execution is complete, the results are generated as follows: 105 | 106 | ![result image](../static/result.png) 107 | 108 | ### Visual Analysis (optional) 109 | 110 | The visual analysis plugin uses a visual language model to analyze the charts produced during the analysis and adds a textual summary of them to the answer, making the output more intuitive and informative. 111 | 112 | To enable this feature, you can pass the `vlm` parameter when creating a `tablegpt-agent`. Here’s an example: 113 | 114 | ```python 115 | from langchain_openai import ChatOpenAI 116 | from tablegpt.agent import create_tablegpt_graph 117 | 118 | vlm = ChatOpenAI(openai_api_base="VLM_URL", openai_api_key="whatever", model_name="VLM_MODEL_NAME") 119 | 120 | agent = create_tablegpt_graph( 121 | ... 122 | vlm=vlm, 123 | ... 124 | ) 125 | ``` 126 | 127 | Once enabled, the `tablegpt-agent` will use the `vlm` model to analyze and summarize the charts it generates.
128 | 129 | For instance, in response to the query mentioned earlier, the `tablegpt-agent` generates the following summary: 130 | 131 | > *I have drawn a pie chart illustrating the ratio of men to women. From the chart, you can see that men constitute 64.4% while women make up 35.6%. If you need any further analysis or visualizations, feel free to let me know.* 132 | 133 | This feature adds a layer of clarity and insight, helping users interpret the results more effectively. It is especially useful for complex charts. 134 | -------------------------------------------------------------------------------- /docs/explanation/code-sandbox.md: -------------------------------------------------------------------------------- 1 | # Code Sandbox 2 | 3 | `tablegpt-agent` directs `tablegpt` to generate Python code for data analysis. However, the generated code may contain potential vulnerabilities or unexpected errors. Running such code directly in a production environment could threaten the system's stability and security. 4 | 5 | `Code Sandbox` is designed to address this challenge. By leveraging sandbox technology, it confines code execution to a controlled environment, effectively preventing malicious or unexpected behaviors from impacting the main system. This provides an isolated and reliable space for running code safely. 6 | 7 | `Code Sandbox` is built on the [pybox](https://github.com/edwardzjl/pybox) library and supports three main execution modes: 8 | 9 | - **Local Environment**: Executes code in a local sandbox for quick *deployment* and *validation*. 10 | - **Remote Environment**: Creates remote environments through `Jupyter Enterprise Gateway` to enable shared computing. 11 | - **Cluster Environment**: Communicates directly with kernel pods, bypassing the need for proxy services such as `Jupyter Enterprise Gateway`.
12 | 13 | Code Sandbox is designed based on the following key principles: 14 | 15 | - **Security**: Limits code access using sandbox technology to ensure a safe and reliable execution environment. 16 | - **Isolation**: Provides independent execution environments for each task, ensuring strict separation of resources and data. 17 | - **Scalability**: Adapts to diverse computing environments, from local setups to Kubernetes clusters, supporting dynamic resource allocation and efficient task execution. 18 | 19 | 20 | ## Local Environment 21 | 22 | In a local environment, Code Sandbox utilizes the `pybox` library to create and manage sandbox environments, providing a secure code execution platform. By isolating code execution from the host system's resources and imposing strict permission controls, it ensures safety and reliability. This approach is especially suitable for **development** and **debugging** scenarios. 23 | 24 | If you want to run `tablegpt-agent` in a local environment, you can enable the **local mode**. Below are the installation steps and a detailed operation guide. 25 | 26 | ### Installing 27 | 28 | To use `tablegpt-agent` in local mode, install the library with the following command (the quotes keep shells such as zsh from interpreting the square brackets): 29 | 30 | ```sh 31 | pip install "tablegpt-agent[local]" 32 | ``` 33 | 34 | ### Configuring 35 | 36 | `tablegpt-agent` comes with several built-in features, such as auxiliary methods for data analysis and setting the display font. **These features are automatically added to the sandbox environment by default**. If you need advanced customization (e.g., adding specific methods or fonts), refer to the [TableGPT IPython Kernel Configuration Documentation](https://github.com/tablegpt/tablegpt-agent/tree/main/ipython) for further guidance.
37 | 38 | ### Creating and Running 39 | 40 | The following code demonstrates how to use the pybox library to set up a sandbox, execute code, and retrieve results in a local environment: 41 | 42 | ```python 43 | from uuid import uuid4 44 | from pybox import LocalPyBoxManager, PyBoxOut 45 | 46 | # Initialize the local sandbox manager 47 | pybox_manager = LocalPyBoxManager() 48 | 49 | # Assign a unique Kernel ID for the sandbox 50 | kernel_id = str(uuid4()) 51 | 52 | # Start the sandbox environment 53 | box = pybox_manager.start(kernel_id) 54 | 55 | # Define the test code to execute 56 | test_code = """ 57 | import math 58 | result = math.sqrt(16) 59 | result 60 | """ 61 | 62 | # Run the code in the sandbox 63 | out: PyBoxOut = box.run(code=test_code) 64 | 65 | # Print the execution result 66 | print(out) 67 | ``` 68 | 69 | ### Example Output 70 | 71 | After running the above code, the system will return the following output, indicating successful execution with no errors: 72 | ```text 73 | data=[{'text/plain': '4.0'}] error=None 74 | ``` 75 | 76 | With `Code Sandbox` in local execution mode, developers can enjoy the safety of sandbox isolation at minimal cost while maintaining flexibility and efficiency. This lays a solid foundation for more complex remote or cluster-based scenarios. 77 | 78 | 79 | ## Remote Environment 80 | 81 | In a remote environment, `Code Sandbox` uses the `pybox` library and its `RemotePyBoxManager` to create and manage sandbox environments. The remote mode relies on the [Enterprise Gateway](https://github.com/jupyter-server/enterprise_gateway) service to dynamically create and execute remote sandboxes. This mode allows multiple services to connect to the same remote environment, enabling shared access to resources. 82 | 83 | ### Configuring 84 | 85 | If `tablegpt-agent` is used in **remote mode**, the first step is to start the `enterprise_gateway` service. 
You can refer to the [Enterprise Gateway Deployment Guide](https://jupyter-enterprise-gateway.readthedocs.io/en/latest/operators/index.html#deploying-enterprise-gateway) for detailed instructions on configuring and starting the service. 86 | 87 | Once the service is up and running, ensure that the service address is accessible. For example, assume the `enterprise_gateway` service is available at `http://example.com`. 88 | 89 | ### Creating and Running 90 | 91 | The following code demonstrates how to create a remote sandbox using `RemotePyBoxManager` and execute code within it: 92 | 93 | ```python 94 | from uuid import uuid4 95 | from pybox import RemotePyBoxManager, PyBoxOut 96 | 97 | # Initialize the remote sandbox manager, replacing with the actual Enterprise Gateway service address 98 | pybox_manager = RemotePyBoxManager(host="http://example.com") 99 | 100 | # Assign a unique Kernel ID 101 | kernel_id = str(uuid4()) 102 | 103 | # Start the remote sandbox environment 104 | box = pybox_manager.start(kernel_id) 105 | 106 | # Define the test code 107 | test_code = """ 108 | import math 109 | result = math.sqrt(16) 110 | result 111 | """ 112 | 113 | # Run the code in the sandbox 114 | out: PyBoxOut = box.run(code=test_code) 115 | 116 | # Print the execution result 117 | print(out) 118 | ``` 119 | 120 | ### Example Output 121 | 122 | After executing the above code, the system will return the following output, indicating successful execution without any errors: 123 | 124 | ```plaintext 125 | data=[{'text/plain': '4.0'}] error=None 126 | ``` 127 | 128 | ### Advanced Environment Configuration 129 | 130 | The `RemotePyBoxManager` provides the following advanced configuration options to allow for flexible customization of the sandbox execution environment: 131 | 132 | 1. **`env_file`**: Allows you to load environment variables from a file to configure the remote sandbox. 133 | 2. 
**`kernel_env`**: Enables you to pass environment variables directly as key-value pairs, simplifying the setup process. 134 | 135 | To learn more about the parameters and configuration options, refer to the [Kernel Environment Variables](https://jupyter-enterprise-gateway.readthedocs.io/en/latest/users/kernel-envs.html) documentation. 136 | 137 | 138 | ## Cluster Environment 139 | 140 | In a Kubernetes cluster, `Code Sandbox` leverages the `KubePyBoxManager` provided by the `pybox` library to create and manage sandboxes. Unlike the `remote environment`, the cluster environment **communicates directly with Kernel Pods** created by the [Jupyter Kernel Controller](https://github.com/edwardzjl/jupyter-kernel-controller), eliminating the need for an intermediary service like `Enterprise Gateway`. 141 | 142 | ### Configuring 143 | 144 | Before using the cluster environment, you need to deploy the `jupyter-kernel-controller` service. You can quickly create the required CRDs and Deployments using the [Deploy Documentation](https://github.com/edwardzjl/jupyter-kernel-controller?tab=readme-ov-file#build-run-deploy). 
145 | 146 | ### Creating and Running 147 | 148 | Once the `jupyter-kernel-controller` service is successfully deployed and running, you can create and run a cluster sandbox using the following code: 149 | 150 | ```python 151 | from uuid import uuid4 152 | from pybox import KubePyBoxManager, PyBoxOut 153 | 154 | # Initialize the cluster sandbox manager, replacing with actual paths and environment variable configurations 155 | pybox_manager = KubePyBoxManager( 156 | env_file="YOUR_ENV_FILE_PATH", # Path to the environment variable file 157 | kernel_env="YOUR_KERNEL_ENV_DICT", # Kernel environment variable configuration 158 | ) 159 | 160 | # Assign a unique Kernel ID 161 | kernel_id = str(uuid4()) 162 | 163 | # Start the cluster sandbox environment 164 | box = pybox_manager.start(kernel_id) 165 | 166 | # Define the test code 167 | test_code = """ 168 | import math 169 | result = math.sqrt(16) 170 | result 171 | """ 172 | 173 | # Run the code in the sandbox 174 | out: PyBoxOut = box.run(code=test_code) 175 | 176 | # Print the execution result 177 | print(out) 178 | ``` 179 | 180 | ### Example Output 181 | 182 | After executing the code above, the following output will be returned, indicating successful execution without any errors: 183 | 184 | ```plaintext 185 | data=[{'text/plain': '4.0'}] error=None 186 | ``` 187 | 188 | **NOTE:** The `env_file` and `kernel_env` parameters required by `KubePyBoxManager` are essentially the same as those for `RemotePyBoxManager`. For detailed information about these parameters, please refer to the [RemotePyBoxManager Advanced Environment Configuration](#advanced-environment-configuration). 189 | 190 | 191 | With the above configuration, you can efficiently manage secure and reliable sandboxes in a Kubernetes cluster, supporting flexible control and extension of execution results. 
192 | -------------------------------------------------------------------------------- /docs/explanation/ipython-startup-scripts.md: -------------------------------------------------------------------------------- 1 | # IPython Startup Scripts 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/howto/cleanup-error-trace.md: -------------------------------------------------------------------------------- 1 | # Cleanup Error Trace 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/howto/customize-table-info.md: -------------------------------------------------------------------------------- 1 | # Customize Table Info 2 | 3 | 4 | -------------------------------------------------------------------------------- /docs/howto/incluster-code-execution.md: -------------------------------------------------------------------------------- 1 | # Incluster Code Execution 2 | 3 | The `tablegpt-agent` directs `tablegpt` to generate Python code for data analysis. This code is then executed within a sandbox environment to ensure system security. The execution is managed by the [pybox](https://github.com/edwardzjl/pybox) library, which provides a simple way to run Python code outside the main process. 4 | -------------------------------------------------------------------------------- /docs/howto/messages-truncation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Messages Truncation\n", 8 | "\n", 9 | "Sometimes LLM services may have limited capacity to handle long messages, which can result in 400 status code errors. 
Therefore, we need to implement message truncation to keep message lengths within the LLM service's capabilities.\n", 10 | "\n", 11 | "The `tablegpt-agent` provides a `TruncationConfig` class to specify truncation settings for the LLM and VLM.\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Too long messages without truncation" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from datetime import date\n", 28 | "\n", 29 | "from langchain_core.messages import HumanMessage,AIMessage,SystemMessage\n", 30 | "from langchain_openai import ChatOpenAI\n", 31 | "from tablegpt.agent import create_tablegpt_graph\n", 32 | "from pybox import AsyncLocalPyBoxManager\n", 33 | "pybox_manager = AsyncLocalPyBoxManager()\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Assuming the model service supports max_model_len=1024, which means input_tokens + max_completion_tokens <= 1024\n", 43 | "llm = ChatOpenAI(openai_api_base=\"YOUR_VLLM_URL\", openai_api_key=\"whatever\", model_name=\"TableGPT2-7B\",max_tokens=256)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "agent_without_truncation = create_tablegpt_graph(\n", 53 | " llm=llm,\n", 54 | " pybox_manager=pybox_manager\n", 55 | ")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "messages = [\n", 65 | " SystemMessage(content=\"你是一个友好的AI助手\"),\n", 66 | " HumanMessage(content=\"你能给我讲一个关于大语言模型的故事吗?\"),\n", 67 | " 
AIMessage(content=\"当然可以。让我们从大语言模型的起源开始讲起。一切要从2017年谷歌提出的Transformer架构说起。这个创新性的架构为后来的GPT、BERT等模型奠定了基础。Transformer架构引入了自注意力机制,能够更好地处理序列数据中的长距离依赖关系。这一突破性进展使得模型能够更好地理解文本的上下文语境,为自然语言处理领域带来了革命性的变化。在此基础上,OpenAI于2018年发布了第一代GPT模型,随后又相继推出了GPT-2和GPT-3,每一代都在规模和性能上有显著提升。同时,谷歌推出的BERT模型采用了双向编码器架构,在多个自然语言理解任务上取得了突破性进展。这些模型的成功激发了更多研究者和机构投入到大语言模型的研发中,推动了整个领域的快速发展。现在,我们已经看到像GPT-4这样的模型展现出令人惊叹的能力,不仅能够进行基础的文本生成,还能够理解上下文、进行推理、解决复杂问题,甚至展现出一定程度的创造力...\"),\n", 68 | " \n", 69 | " HumanMessage(content=\"那AI是如何学习理解人类语言的呢?\"),\n", 70 | " AIMessage(content=\"这是个很好的问题。AI通过大量的文本数据训练来理解语言。它使用自注意力机制来捕捉词语之间的关系,通过预训练和微调两个阶段,逐步掌握语言的规律。在预训练阶段,模型会阅读海量的文本,学习语言的基本模式。这个过程就像一个婴儿通过观察和模仿来学习语言一样。模型会分析数十亿甚至数千亿个词语,理解它们之间的关联和使用规律。在这个过程中,模型会建立起一个复杂的神经网络,每个神经元都负责捕捉特定的语言特征。通过反向传播算法,模型不断调整其内部参数,以更好地预测和理解语言。在微调阶段,模型会针对特定任务进行专门训练,比如问答、摘要生成或情感分析等。这就像人类在掌握基本语言能力后,进一步学习专业词汇和特定领域的表达方式。模型通过大量的实例学习,逐渐理解语言中的细微差别,包括语境、语气、隐含意义等。这个学习过程是持续的,模型通过不断接触新的语言样本来完善自己的理解能力...\"),\n", 71 | " \n", 72 | " HumanMessage(content=\"训练过程中会遇到什么挑战?\"),\n", 73 | " AIMessage(content=\"训练大语言模型面临着多重挑战。首先是计算资源的需求,训练大模型需要数千台GPU和数月时间。这不仅带来了巨大的经济成本,还面临着能源消耗和环境影响的问题。一个大型语言模型的训练可能消耗数百万度电,相当于数千个家庭一年的用电量。其次是高质量数据的获取和处理问题。模型需要海量的训练数据,但这些数据必须经过严格的筛选和清洗。数据中可能包含偏见、歧视、不当内容等有害信息,如果不经过处理,这些问题会被模型学习并在输出中体现出来。此外,数据的多样性和代表性也是一个重要问题,需要确保数据能够覆盖不同语言、文化和领域的知识。在训练过程中,还面临着模型优化的技术挑战,比如梯度消失、过拟合、训练不稳定等问题。这需要研究人员不断改进训练算法和策略。另外,模型的知识更新也是一个持续的挑战,因为世界在不断变化,新的信息和知识在不断产生,如何让模型保持最新的知识状态是一个重要问题...\"),\n", 74 | " \n", 75 | " HumanMessage(content=\"大语言模型是如何生成回答的?\"),\n", 76 | " AIMessage(content=\"大语言模型生成回答的过程非常有趣且复杂。当模型收到一个问题或提示时,它首先会通过其编码器将输入转换为高维向量表示。这个过程中,模型会考虑输入的每个词语及其上下文关系。通过多层自注意力机制,模型能够理解输入中的关键信息和语义结构。在生成回答时,模型会一个词一个词地预测最合适的内容。每生成一个词,都会基于之前生成的所有内容和原始输入来决定下一个最合适的词。这个过程使用了复杂的概率分布计算,模型会为词表中的每个候选词计算一个概率分数,然后选择最合适的词。这种选择不仅要考虑语法正确性,还要考虑语义连贯性和上下文相关性。模型还会使用注意力机制来关注输入中的不同部分,确保生成的内容与问题相关。在生成过程中,模型会维护一个内部状态,这个状态包含了之前生成内容的信息,帮助确保整个回答的连贯性和完整性。同时,模型还会使用各种策略来控制生成的质量,比如温度参数来调节创造性,或使用束搜索等技术来优化输出质量...\"),\n", 77 | " \n", 78 | " HumanMessage(content=\"它们真的理解我们说的话吗?\"),\n", 79 | " 
AIMessage(content=\"这涉及到AI意识和理解力的哲学问题,这是一个复杂且富有争议的话题。虽然大语言模型能够产生看似智能的回答,但它们是否真正'理解'仍然存在很大争议。从技术角度来看,这些模型本质上是非常复杂的统计模型,它们通过分析大量文本数据中的模式来生成回应。它们的'理解'更像是一种高级的模式匹配和概率预测,而不是像人类那样具有真正的理解和意识。人类的理解涉及到感知、经验、情感和意识等多个层面,而AI目前还无法真正复制这些特质。例如,当我们理解'苹果'这个词时,我们不仅知道它的字面含义,还能联想到它的味道、触感、气味,以及与之相关的个人经历和情感。而AI模型只能基于文本数据中的统计关系来处理这个词。此外,人类的理解还包括常识推理、因果关系理解、情境适应等能力,这些都是当前AI系统还难以完全掌握的。虽然模型可以生成符合语法和语义的回答,但它们可能缺乏真正的理解深度,有时会产生表面上合理但实际上没有意义的回答...\"),\n", 80 | " \n", 81 | " HumanMessage(content=\"未来大语言模型会变得更强大吗?\"),\n", 82 | " AIMessage(content=\"随着技术进步,大语言模型的能力确实在不断提升,未来发展前景令人期待。从技术层面来看,我们正在见证模型规模的持续增长,算法的不断改进,以及训练数据的持续扩充。未来的发展方向可能包括多个突破性的领域:首先是多模态理解能力的提升,未来的模型不仅能处理文本,还能更好地理解图像、音频、视频等多种形式的信息,实现跨模态的理解和生成。其次是持续学习能力的发展,使模型能够像人类一样不断学习和更新知识,而不是停留在训练时的知识状态。在推理能力方面,未来的模型可能会具备更强的逻辑推理、因果推理和抽象思维能力,能够处理更复杂的问题和任务。此外,模型的效率也会得到提升,通过更先进的架构设计和训练方法,实现更低的计算成本和能源消耗。在应用领域,我们可能会看到更专业化的模型出现,针对特定领域如医疗、法律、科研等提供更精准的服务。同时,模型的可解释性和可控性也会得到改善,使其决策过程更透明、更可靠...\"),\n", 83 | " \n", 84 | " HumanMessage(content=\"我们应该如何负责任地使用AI技术?\"),\n", 85 | " AIMessage(content=\"这需要我们在多个层面做出努力,构建一个负责任的AI使用生态系统。首先在技术层面,我们需要确保AI系统的安全性和可控性。这包括开发强大的安全机制,防止系统被滥用或产生有害输出;建立有效的监测和控制机制,确保系统行为符合预期;实施严格的测试和验证流程,评估系统的可靠性和稳定性。在伦理层面,我们需要建立完善的使用准则和监管框架。这包括制定明确的AI伦理原则,规范AI的开发和使用;建立行业标准和最佳实践指南;设立独立的监督机构,确保AI技术的使用符合公共利益。在隐私保护方面,我们需要采取严格的数据保护措施,确保用户数据的安全性和隐私性;建立透明的数据使用政策;给予用户对其数据的控制权。在偏见消除方面,我们需要持续努力提高训练数据的多样性和代表性;开发更公平的算法;定期评估和消除系统中的偏见。在环境影响方面,我们需要关注AI系统的能源消耗和碳排放;开发更环保的计算方案;推动绿色AI技术的发展...\"),\n", 86 | " \n", 87 | " HumanMessage(content=\"你觉得AI会取代人类吗?\"),\n", 88 | " AIMessage(content=\"AI不应该也不会完全取代人类,这个问题需要从多个角度深入思考。首先,虽然AI在某些特定任务上可能超越人类,但人类具有许多AI难以复制的独特优势。人类的创造力是独特的,我们能够产生原创性的想法,进行艺术创作,提出创新性的解决方案。人类的情感共鸣能力也是无可替代的,我们能够理解和分享他人的情感,建立深层的情感联系,这是当前AI技术远远无法达到的。在道德判断和价值观方面,人类能够基于复杂的伦理考虑做出决策,而AI系统往往难以处理需要道德权衡的情况。此外,人类具有自主意识和主观体验,这些都是AI所不具备的特质。未来的发展方向应该是人机协作,让AI成为增强人类能力的工具,而不是替代品。在这种协作模式下,AI可以处理重复性、计算密集型的任务,而人类则专注于需要创造力、情感理解和道德判断的工作。我们需要明智地使用AI技术,确保它始终服务于人类福祉,而不是反过来控制或限制人类的发展...\"),\n", 89 | " HumanMessage(content=\"你认为未来的AI会怎么发展?\")\n", 90 | "]" 91 | ] 92 | }, 93 | { 
94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Error code: 400 - {'object': 'error', 'message': \"This model's maximum context length is 1024 tokens. However, you requested 2406 tokens (2150 in the messages, 256 in the completion). Please reduce the length of the messages or completion.\", 'type': 'BadRequestError', 'param': None, 'code': 400}\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "_input = {\n", 108 | " \"messages\": messages,\n", 109 | " \"parent_id\": \"some-parent-id\",\n", 110 | " \"date\": date.today(), # noqa: DTZ011\n", 111 | "}\n", 112 | "\n", 113 | "try:\n", 114 | " await agent_without_truncation.ainvoke(input=_input)\n", 115 | "except Exception as e:\n", 116 | " print(e)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Too long messages with truncation" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### TruncationConfig settings in `create_tablegpt_graph`\n", 131 | "- `llm_truncation_config`: Truncate messages sent to pure language models\n", 132 | "- `vlm_truncation_config`: Truncate messages sent to vision+language multimodal models" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "\n", 140 | "> In the following example, we truncate by the number of messages rather than by the number of tokens.\n", 141 | "\n", 142 | "**For custom trim settings based on your LLM service (e.g. 
vLLM, TGI, SGLang), see this [example](https://github.com/edwardzjl/chatbot/blob/main/api/chatbot/llm_providers.py#L67) or implement your own.**" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### Create TruncationConfig\n", 150 | "\n", 151 | "**The parameters set in `TruncationConfig` are used by `langchain_core.messages.trim_messages`; see the [trim_messages documentation](https://python.langchain.com/docs/how_to/trim_messages/).**" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# import truncation config\n", 161 | "from tablegpt.agent.data_analyzer import TruncationConfig" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# token_counter=len: count each message as one unit, so truncation is by number of messages\n", 171 | "# max_tokens=5: keep at most 5 messages after truncation\n", 172 | "# start_on=\"human\": drop leading messages so the truncated history starts with a human message\n", 173 | "llm_truncation_config = TruncationConfig(token_counter=len, max_tokens=5, start_on=\"human\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 8, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "agent = create_tablegpt_graph(\n", 183 | " llm=llm,\n", 184 | " pybox_manager=pybox_manager,\n", 185 | " llm_truncation_config=llm_truncation_config\n", 186 | ")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 9, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "未来AI的发展可能会基于一系列先进的技术和科学突破,以下是一些可能的发展方向:\n", 199 | "\n", 200 | "1. **增强现实与虚拟现实**:AI将能够提供更加沉浸式的体验,例如增强现实和虚拟现实技术,使用户能够更自然地与虚拟环境互动。这将改变我们获取知识、工作和娱乐的方式。\n", 201 | "\n", 202 | "2. 
**神经网络与深度学习**:神经网络和深度学习将变得更加强大和通用,能够处理更多样化的问题和数据。例如,在医疗诊断、自动驾驶和智能制造等领域,AI可以提供更准确、更高效的解决方案。\n", 203 | "\n", 204 | "3. **更强的计算能力**:AI将实现更强大的计算能力,能够处理更复杂、更大规模的数据。这将推动许多行业,如金融、医疗和科学研究,从传统的人工智能转型到AI驱动的新技术。\n", 205 | "\n", 206 | "4. **更自然的交互**:AI将能够更好地理解和模拟人类的自然语言和行为,使人类与AI能够更自然、更流畅地交流。这将使人类和AI之间的互动更加无缝。\n", 207 | "\n", 208 | "5. **伦理和法律**:随着AI技术的发展,伦理和法律问题将越来越重要。我们需要制定明确的AI伦理准则,确保AI技术的使用符合道德规范。这需要跨\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "_input = {\n", 214 | " \"messages\": messages,\n", 215 | " \"parent_id\": \"some-parent-id\",\n", 216 | " \"date\": date.today(), # noqa: DTZ011\n", 217 | "}\n", 218 | "try:\n", 219 | " res = await agent.ainvoke(input=_input)\n", 220 | " print(res[\"messages\"][-1].content)\n", 221 | "except Exception as e:\n", 222 | " print(e)" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": ".venv", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.12.7" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } 248 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Home 2 | 3 | [![PyPI - Version](https://img.shields.io/pypi/v/tablegpt-agent.svg)](https://pypi.org/project/tablegpt-agent) 4 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/tablegpt-agent.svg)](https://pypi.org/project/tablegpt-agent) 5 | 6 | ## Introduction 7 | 8 | tablegpt-agent is a pre-built agent for [TableGPT2 (huggingface)](https://huggingface.co/tablegpt/TableGPT2-7B), a series of LLMs for table-based question answering. 
This agent is built on top of the [LangGraph](https://www.langchain.com/langgraph) library and provides a user-friendly interface for interacting with TableGPT2. 9 | 10 | ## Table Of Contents 11 | 12 | 13 | - Tutorials 14 | - [Quickstart](tutorials/quick-start.ipynb) 15 | - [Chat on Tabular Data](tutorials/chat-on-tabular-data.ipynb) 16 | - [Continue Analysis on Generated Charts](tutorials/continue-analysis-on-generated-charts.ipynb) 17 | - How-To Guides 18 | - [Enhance TableGPT Agent with RAG](howto/retrieval.ipynb) 19 | - [Persist Messages](howto/persist-messages.ipynb) 20 | - [Incluster Code Execution](howto/incluster-code-execution.md) 21 | - [Normalize Datasets](howto/normalize-datasets.ipynb) 22 | - Explanation 23 | - [Agent Workflow](explanation/agent-workflow.md) 24 | - [File Reading](explanation/file-reading.ipynb) 25 | - [Reference](reference.md) 26 | 27 | ## Contributing 28 | 29 | Thank you for your interest in TableGPT Agent. For more information on contributing, please see [the contributing guide](https://github.com/tablegpt/tablegpt-agent/blob/main/CONTRIBUTING.md). 30 | 31 | ## Acknowledgements 32 | 33 | We extend our sincere gratitude to all contributors and collaborators who played a pivotal role in the development of tablegpt-agent. Special thanks to our team members and the open-source community, whose insights and feedback were invaluable throughout the project. 34 | 35 | Thank you to our early users for their suggestions and engagement, which have greatly helped us refine and enhance this tool. 
36 | -------------------------------------------------------------------------------- /docs/reference.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ::: tablegpt.agent.create_tablegpt_graph 4 | -------------------------------------------------------------------------------- /docs/static/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/docs/static/result.png -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | /* hide jupyter notebooks input/output numbers */ 2 | .jp-InputPrompt { 3 | display: none !important; 4 | } 5 | 6 | .jp-OutputPrompt { 7 | display: none !important; 8 | } -------------------------------------------------------------------------------- /docs/tutorials/continue-analysis-on-generated-charts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "98a1786c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Continue Analysis on Generated Charts\n", 9 | "\n", 10 | "While TableGPT2 excels in data analysis tasks, it currently lacks built-in support for visual modalities. Many data analysis tasks involve visualization, so to address this limitation, we provide an interface for integrating your own Visual Language Model (VLM) plugin.\n", 11 | "\n", 12 | "When the agent performs a visualization task—typically using `matplotlib.pyplot.show`—the VLM will take over from the LLM, offering a more nuanced summarization of the visualization. 
This approach avoids the common pitfalls of LLMs in visualization tasks, which often either state, \"I have plotted the data,\" or hallucinate the content of the plot.\n", 13 | "\n", 14 | "We continue using the agent from the previous section to perform a data visualization task and observe its final output.\n", 15 | "> **NOTE:** Before you start, you can install Chinese fonts using the following command:\n", 16 | "```bash\n", 17 | "apt-get update && apt-get install -y --no-install-recommends fonts-noto-cjk\n", 18 | "mplfonts init\n", 19 | "```" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "id": "15aba93a", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from datetime import date\n", 30 | "from typing import TypedDict\n", 31 | "\n", 32 | "from langchain_core.messages import HumanMessage\n", 33 | "from langchain_openai import ChatOpenAI\n", 34 | "from langgraph.checkpoint.memory import MemorySaver\n", 35 | "from pybox import AsyncLocalPyBoxManager\n", 36 | "from tablegpt import DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR\n", 37 | "from tablegpt.agent import create_tablegpt_graph\n", 38 | "from tablegpt.agent.file_reading import Stage\n", 39 | "\n", 40 | "llm = ChatOpenAI(openai_api_base=\"YOUR_VLLM_URL\", openai_api_key=\"whatever\", model_name=\"TableGPT2-7B\")\n", 41 | "pybox_manager = AsyncLocalPyBoxManager(profile_dir=DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR)\n", 42 | "checkpointer = MemorySaver()\n", 43 | "\n", 44 | "agent = create_tablegpt_graph(\n", 45 | " llm=llm,\n", 46 | " pybox_manager=pybox_manager,\n", 47 | " checkpointer=checkpointer,\n", 48 | " session_id=\"some-session-id\", # This is required when using file-reading\n", 49 | ")\n", 50 | "\n", 51 | "class Attachment(TypedDict):\n", 52 | " \"\"\"An attachment entry; must contain at least the `filename` key.\"\"\"\n", 53 | " filename: str\n", 54 | "\n", 55 | "attachment_msg = HumanMessage(\n", 56 | " content=\"\",\n", 57 | " # Please make sure your iPython kernel 
can access your filename.\n", 58 | " additional_kwargs={\"attachments\": [Attachment(filename=\"titanic.csv\")]},\n", 59 | ")\n", 60 | "\n", 61 | "# Reading and processing files.\n", 62 | "response = await agent.ainvoke(\n", 63 | " input={\n", 64 | " \"entry_message\": attachment_msg,\n", 65 | " \"processing_stage\": Stage.UPLOADED,\n", 66 | " \"messages\": [attachment_msg],\n", 67 | " \"parent_id\": \"some-parent-id1\",\n", 68 | " \"date\": date.today(),\n", 69 | " },\n", 70 | " config={\n", 71 | " # Using checkpointer requires binding thread_id at runtime.\n", 72 | " \"configurable\": {\"thread_id\": \"some-thread-id\"},\n", 73 | " },\n", 74 | ")" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 2, 80 | "id": "0afbab13", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "content=\"好的,我将基于性别绘制一个饼图,以展示每个性别的人数。首先,我们需要统计每个性别的人数,然后使用 `seaborn` 和 `matplotlib` 来绘制饼图。\\n\\n```python\\nimport seaborn as sns\\nimport matplotlib.pyplot as plt\\n\\n# Count the number of people for each gender\\ngender_counts = df['Sex'].value_counts()\\n\\n# Create a pie chart\\nplt.figure(figsize=(8, 6))\\nplt.pie(gender_counts, labels=gender_counts.index, autopct='%1.1f%%', startangle=140, colors=sns.color_palette('pastel'))\\nplt.title('Gender Distribution')\\nplt.show()\\n```\" additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-6115fe22-3b55-4d85-be09-6c31a59736f6'\n", 88 | "content=[{'type': 'text', 'text': '```pycon\\n
\\n```'}, {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,iVBORw0KG...'}}] name='python' id='226ba8f2-29a7-4706-9178-8cb5b4062488' tool_call_id='03eb1113-6aed-4e0a-a3c0-4cc0043a55ee' artifact=[]\n", 89 | "content='饼图已经成功生成。' additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-83468bd1-9451-4c78-91a3-b0f96ffa169a'\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "# Define the human message that asks the model to draw a pie chart based on gender data\n", 95 | "human_message = HumanMessage(content=\"Draw a pie chart based on gender and the number of people of each gender.\")\n", 96 | "\n", 97 | "async for event in agent.astream_events(\n", 98 | " input={\n", 99 | " \"messages\": [human_message],\n", 100 | " \"parent_id\": \"some-parent-id2\",\n", 101 | " \"date\": date.today(),\n", 102 | " },\n", 103 | " version=\"v2\",\n", 104 | " # We configure the same thread_id to use checkpoints to retrieve the memory of the last run.\n", 105 | " config={\"configurable\": {\"thread_id\": \"some-thread-id\"}},\n", 106 | "):\n", 107 | " evt = event[\"event\"]\n", 108 | " if evt == \"on_chat_model_end\":\n", 109 | " print(event[\"data\"][\"output\"])\n", 110 | " elif event[\"name\"] == \"tool_node\" and evt == \"on_chain_stream\":\n", 111 | " for lc_msg in event[\"data\"][\"chunk\"][\"messages\"]:\n", 112 | " print(lc_msg)\n", 113 | " else:\n", 114 | " # Handle other events here\n", 115 | " pass" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "1c428aca", 121 | "metadata": {}, 122 | "source": [ 123 | "Now let's set up the Visual Language Model (VLM) and create a new agent with VLM support:" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 3, 129 | "id": "425633b7-14a4-4bbc-91e1-d94161a41682", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# Initialize the VLM instance\n", 134 | "vlm = ChatOpenAI(openai_api_base=\"YOUR_VLM_URL\", 
openai_api_key=\"whatever\", model_name=\"YOUR_MODEL_NAME\")\n", 135 | "\n", 136 | "# Reuse llm, pybox_manager, and checkpointer from the first code cell\n", 137 | "agent_with_vlm = create_tablegpt_graph(\n", 138 | " llm=llm,\n", 139 | " pybox_manager=pybox_manager,\n", 140 | " vlm=vlm,\n", 141 | " checkpointer=checkpointer,\n", 142 | " session_id=\"some-session-id\",\n", 143 | ")" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "id": "40a19cb4-adbc-49de-90af-4d43e77d4308", 149 | "metadata": {}, 150 | "source": [ 151 | "We use LangGraph's [time travel](https://langchain-ai.github.io/langgraph/tutorials/introduction/#part-7-time-travel) feature to rewind the thread to the checkpoint just before the agent answered the last question, so that the earlier LLM-only answer does not mislead the model:" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 4, 157 | "id": "3652d131-6ed7-4d75-bfe2-152ba40fb090", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "state_history = agent.get_state_history(config={\"configurable\": {\"thread_id\": \"some-thread-id\"}})\n", 162 | "\n", 163 | "to_replay = None\n", 164 | "for state in list(state_history)[::-1]:\n", 165 | " if state.next and state.next[0] == \"__start__\":\n", 166 | " to_replay = state" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "2a82aeef-7906-45b8-a1b0-2d2b3c18451b", 172 | "metadata": {}, 173 | "source": [ 174 | "Send the same question to the model via the new agent with VLM support:" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 5, 180 | "id": "e138cb4a", 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "content=\"好的,我将绘制一个饼图来展示数据集中男性和女性乘客的数量。\\n```python\\n# Count the number of passengers by gender\\ngender_counts = df['Sex'].value_counts()\\n\\n# Plot a pie chart\\nplt.figure(figsize=(8, 6))\\nplt.pie(gender_counts, labels=gender_counts.index, autopct='%1.1f%%', 
startangle=140)\\nplt.title('Gender Distribution')\\nplt.show()\\n```\\n\" additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-2d05b2ab-32f4-481f-8fa5-43c78515d9c3'\n", 188 | "content=[{'type': 'text', 'text': '```pycon\\n
\\n```'}, {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,iVBORw0K...'}}] name='python' id='51a99935-b0b1-496d-9a45-c1f318104773' tool_call_id='918d57ee-7362-4e0d-8d66-64b7e57ecaf6' artifact=[]\n", 189 | "content='饼图显示数据集中性别分布为 50% 女性和 50% 男性,这表明男性和女性乘客数量相等。' additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'qwen2-vl-7b-instruct'} id='run-d9b0e891-f03c-40c8-8474-9fef7511c40b'\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "async for event in agent_with_vlm.astream_events(\n", 195 | " None,\n", 196 | " to_replay.config,\n", 197 | " version=\"v2\",\n", 198 | "):\n", 199 | " evt = event[\"event\"]\n", 200 | " if evt == \"on_chat_model_end\":\n", 201 | " print(event[\"data\"][\"output\"])\n", 202 | " elif event[\"name\"] == \"tool_node\" and evt == \"on_chain_stream\":\n", 203 | " for lc_msg in event[\"data\"][\"chunk\"][\"messages\"]:\n", 204 | " print(lc_msg)\n", 205 | " else:\n", 206 | " # Handle other events here\n", 207 | " pass" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "id": "20d009cb", 213 | "metadata": {}, 214 | "source": [ 215 | "We observe that the agent with VLM support no longer merely states that the chart was generated; instead, it describes the content of the generated chart. Note that VLM descriptions can still be inaccurate, so it is worth verifying them against the underlying data." 
216 | ] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "Python 3 (ipykernel)", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.12.5" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 5 240 | } 241 | -------------------------------------------------------------------------------- /docs/tutorials/quick-start.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9a12e134", 6 | "metadata": {}, 7 | "source": [ 8 | "# Quickstart\n", 9 | "\n", 10 | "## Installation\n", 11 | "\n", 12 | "To install TableGPT Agent, use the following command:" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "fe436583", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%pip install tablegpt-agent" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "b81082f9-c74b-4d11-ab1f-2c8c041a29c4", 28 | "metadata": {}, 29 | "source": [ 30 | "TableGPT Agent depends on pybox to manage its code execution environment. By default, pybox operates in an in-cluster mode. 
If you intend to run tablegpt-agent in a local environment, install the optional dependency as follows:" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "id": "4c692b35-0e56-4d3e-b20a-6fefd6dbc9e4", 37 | "metadata": { 38 | "scrolled": true 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "%pip install tablegpt-agent[local]" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "2a55ff4b", 48 | "metadata": {}, 49 | "source": [ 50 | "\n", 51 | "This tutorial uses `langchain-openai` for the chat model instance. Please make sure you have it installed:" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "503a2807", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "%pip install langchain-openai" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "b2d82049", 67 | "metadata": {}, 68 | "source": [ 69 | "## Set Up the LLM Service\n", 70 | "\n", 71 | "Before using TableGPT Agent, ensure you have an OpenAI-compatible server configured to host TableGPT2. We recommend using [vllm](https://github.com/vllm-project/vllm) for this:\n", 72 | "\n", 73 | "```bash\n", 74 | "python -m vllm.entrypoints.openai.api_server --served-model-name TableGPT2-7B --model path/to/weights\n", 75 | "```\n", 76 | "\n", 77 | "> **NOTES:**\n", 78 | ">\n", 79 | "> - To analyze tabular data with `tablegpt-agent`, make sure `TableGPT2` is served with `vllm` version 0.5.5 or higher.\n", 80 | "> - For production environments, it's important to optimize the vllm server configuration. For details, refer to the [vllm documentation on server configuration](https://docs.vllm.ai/en/v0.6.0/serving/openai_compatible_server.html#command-line-arguments-for-the-server)." 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "2f0a9ec8", 86 | "metadata": {}, 87 | "source": [ 88 | "## Create TableGPT Agent\n", 89 | "\n", 90 | "> **NOTE:** TableGPT Agent fully supports async invocation. 
If you are running this tutorial in a Jupyter Notebook, no additional setup is required. However, if you plan to run the tutorial in a Python console, make sure to use a console that supports asynchronous operations. To get started, execute the following command:\n", 91 | ">\n", 92 | "> ```bash\n", 93 | "> python -m asyncio\n", 94 | "> ```\n", 95 | "\n", 96 | "In the console or notebook, create the agent as follows:" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "4ac32d2f", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "from langchain_openai import ChatOpenAI\n", 107 | "from pybox import AsyncLocalPyBoxManager\n", 108 | "from tablegpt.agent import create_tablegpt_graph\n", 109 | "\n", 110 | "\n", 111 | "llm = ChatOpenAI(openai_api_base=\"YOUR_VLLM_URL\", openai_api_key=\"whatever\", model_name=\"TableGPT2-7B\")\n", 112 | "pybox_manager = AsyncLocalPyBoxManager()\n", 113 | "\n", 114 | "agent = create_tablegpt_graph(\n", 115 | " llm=llm,\n", 116 | " pybox_manager=pybox_manager,\n", 117 | ")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "31ff4fe9", 123 | "metadata": {}, 124 | "source": [ 125 | "## Start Chatting" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 3, 131 | "id": "ee24c200", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "[HumanMessage(content='Hi', additional_kwargs={}, response_metadata={}, id='34fe748c-81ab-49ea-bec6-9c621598a61a'), AIMessage(content=\"Hello! How can I assist you with data analysis today? 
Please let me know the details of the dataset you're working with and what specific analysis you'd like to perform.\", additional_kwargs={'parent_id': 'some-parent-id'}, response_metadata={}, id='a1ee29d2-723e-41c7-b420-27d0cfaed5dc')]\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "from datetime import date\n", 144 | "from langchain_core.messages import HumanMessage\n", 145 | "\n", 146 | "message = HumanMessage(content=\"Hi\")\n", 147 | "\n", 148 | "_input = {\n", 149 | " \"messages\": [message],\n", 150 | " \"parent_id\": \"some-parent-id\",\n", 151 | " \"date\": date.today(),\n", 152 | "}\n", 153 | "\n", 154 | "state = await agent.ainvoke(_input)\n", 155 | "state[\"messages\"]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "id": "3eca819d", 161 | "metadata": {}, 162 | "source": [ 163 | "You can get more detailed outputs with the `astream_events` method:" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 4, 169 | "id": "3265cf83", 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "content='Hello! How can I assist you with your data analysis today? Please let me know what dataset you are working with and what specific analyses or visualizations you would like to perform.' 
additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-525eb149-0e3f-4b04-868b-708295f789ac'\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "async for event in agent.astream_events(\n", 182 | " input=_input,\n", 183 | " version=\"v2\",\n", 184 | "):\n", 185 | " # We ignore irrelevant events here.\n", 186 | " if event[\"event\"] == \"on_chat_model_end\":\n", 187 | " print(event[\"data\"][\"output\"])" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3 (ipykernel)", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.12.5" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 5 212 | } 213 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/examples/__init__.py -------------------------------------------------------------------------------- /examples/data_analysis.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import date 3 | from typing import TypedDict 4 | 5 | from langchain_core.messages import HumanMessage 6 | from langchain_openai import ChatOpenAI 7 | from langgraph.checkpoint.memory import MemorySaver 8 | from pybox import AsyncLocalPyBoxManager 9 | from tablegpt import DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR 10 | from tablegpt.agent import create_tablegpt_graph 11 | from tablegpt.agent.file_reading import Stage 12 | 13 | 14 | class Attachment(TypedDict): 15 
| """Contains at least one dictionary with the key filename.""" 16 | 17 | filename: str 18 | """The dataset uploaded in this session can be a filename, file path, or object storage address.""" 19 | 20 | 21 | # tablegpt-agent fully supports async invocation 22 | async def main() -> None: 23 | llm = ChatOpenAI( 24 | openai_api_base="YOUR_VLLM_URL", 25 | openai_api_key="whatever", 26 | model_name="TableGPT2-7B", 27 | ) 28 | 29 | # Use local pybox manager for development and testing 30 | pybox_manager = AsyncLocalPyBoxManager(profile_dir=DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR) 31 | 32 | agent = create_tablegpt_graph( 33 | llm=llm, 34 | pybox_manager=pybox_manager, 35 | # We use MemorySaver as a checkpointer to record memory automatically. 36 | # See 37 | checkpointer=MemorySaver(), 38 | # All code generated in this run will be executed in the kernel with kernel_id 'some-session-id'. 39 | session_id="some-session-id", 40 | ) 41 | 42 | attachment_msg = HumanMessage( 43 | content="", 44 | # The dataset can be viewed in examples/datasets/titanic.csv. 45 | additional_kwargs={"attachments": [Attachment(filename="examples/datasets/titanic.csv")]}, 46 | ) 47 | await agent.ainvoke( 48 | input={ 49 | "entry_message": attachment_msg, 50 | "processing_stage": Stage.UPLOADED, 51 | "messages": [attachment_msg], 52 | "parent_id": "some-parent-id1", 53 | "date": date.today(), # noqa: DTZ011 54 | }, 55 | config={ 56 | "configurable": {"thread_id": "some-thread-id"}, 57 | }, 58 | ) 59 | 60 | human_message = HumanMessage(content="How many men survived?") 61 | 62 | async for event in agent.astream_events( 63 | input={ 64 | # After using checkpoint, you only need to add new messages here. 65 | "messages": [human_message], 66 | "parent_id": "some-parent-id2", 67 | "date": date.today(), # noqa: DTZ011 68 | }, 69 | version="v2", 70 | # We configure the same thread_id to use checkpoints to retrieve the memory of the last run. 
71 | config={"configurable": {"thread_id": "some-thread-id"}}, 72 | ): 73 | print(event) # noqa: T201 74 | 75 | 76 | asyncio.run(main()) 77 | -------------------------------------------------------------------------------- /examples/datasets/titanic.csv: -------------------------------------------------------------------------------- 1 | Pclass,Sex,Age,SibSp,Parch,Fare,Embarked,Survived 2 | 2,female,29,0,2,23,S,1 3 | 3,female,39,1,5,31.275,S,0 4 | 3,male,26.5,0,0,7.225,C,0 5 | 3,male,32,0,0,56.4958,S,1 6 | -------------------------------------------------------------------------------- /examples/datasets/产品生产统计表.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/examples/datasets/产品生产统计表.xlsx -------------------------------------------------------------------------------- /examples/datasets/产品销量表.csv: -------------------------------------------------------------------------------- 1 | 编号,名称,单位, 单价(元) ,销售量, 销售额 2 | mb2033,法式面包,包, ¥7.40 ,305080," ¥2,257,592.00 " 3 | mb2034,奶昔蛋糕,包, ¥5.80 ,93200," ¥540,560.00 " 4 | mb2035,奶油夹心饼干,包, ¥3.10 ,215300," ¥667,430.00 " 5 | mb2036,葱油饼,包, ¥2.20 ,102300," ¥225,060.00 " 6 | mb2037,花生桃酥,包, ¥3.80 ,130000," ¥494,000.00 " 7 | mb2038,巧克力饼干,包, ¥4.50 ,119800," ¥539,100.00 " 8 | mb2039,果酱饼干,包, ¥4.10 ,120516," ¥494,115.60 " 9 | mb2040,肉沫夹心饼,包, ¥5.50 ,86000," ¥473,000.00 " 10 | mb2041,早餐饼干,包, ¥2.30 ,104500," ¥240,350.00 " 11 | yl1322,矿泉水,瓶, ¥0.90 ,65000," ¥58,500.00 " 12 | yl1323,可乐,瓶, ¥3.50 ,10200," ¥35,700.00 " 13 | yl1324,冰咖啡,瓶, ¥5.60 ,235040," ¥1,316,224.00 " 14 | yl1325,优果汁,瓶, ¥3.50 ,130500," ¥456,750.00 " 15 | yl1326,奶茶,瓶, ¥4.20 ,98000," ¥411,600.00 " 16 | gg0258,奶油瓜子,千克, ¥6.10 ,105000," ¥640,500.00 " 17 | gg0259,五香瓜子,千克, ¥8.50 ,150000," ¥1,275,000.00 " 18 | gg0260,白味瓜子,千克, ¥8.20 ,132000," ¥1,082,400.00 " 19 | gg0261,麻辣花生,千克, ¥9.00 ,120500," ¥1,084,500.00 " 20 | gg0262,麻辣瓜子仁,千克, ¥9.50 ,98000," 
¥931,000.00 " 21 | gg0263,薯条,千克, ¥9.50 ,130000," ¥1,235,000.00 " 22 | gg0264,香酥爆米花,千克, ¥10.00 ,125800," ¥1,258,000.00 " 23 | -------------------------------------------------------------------------------- /examples/quick_start.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from datetime import date 3 | 4 | from langchain_core.messages import HumanMessage 5 | from langchain_openai import ChatOpenAI 6 | from pybox import AsyncLocalPyBoxManager 7 | from tablegpt.agent import create_tablegpt_graph 8 | 9 | 10 | # tablegpt-agent fully supports async invocation 11 | async def main() -> None: 12 | llm = ChatOpenAI(openai_api_base="YOUR_VLLM_URL", openai_api_key="whatever", model_name="TableGPT2-7B") 13 | 14 | # Use local pybox manager for development and testing 15 | pybox_manager = AsyncLocalPyBoxManager() 16 | 17 | agent = create_tablegpt_graph( 18 | llm=llm, 19 | pybox_manager=pybox_manager, 20 | ) 21 | 22 | message = HumanMessage(content="Hi") 23 | _input = { 24 | "messages": [message], 25 | "parent_id": "some-parent-id", 26 | "date": date.today(), # noqa: DTZ011 27 | } 28 | 29 | async for event in agent.astream_events( 30 | input=_input, 31 | version="v2", 32 | ): 33 | print(event) # noqa: T201 34 | 35 | 36 | asyncio.run(main()) 37 | -------------------------------------------------------------------------------- /ipython/README.md: -------------------------------------------------------------------------------- 1 | # TableGPT IPython Kernel 2 | 3 | This kernel is used to execute code generated by `tablegpt-agent` and has been equipped with data analysis and Chinese font support. 4 | 5 | ## Startup Scripts 6 | 7 | It's recommended to put some helper functions or configurations in the startup scripts. Place your startup scripts in the `~/.ipython/profile_default/startup/` directory for them to take effect. 
8 | 9 | Note: The `~/.ipython` directory must be writable for the process launching the kernel, otherwise there will be a warning message: `UserWarning: IPython dir '/home/jovyan/.ipython' is not a writable location, using a temp directory.` and the startup scripts won't take effect. 10 | 11 | Official document at `~/.ipython/profile_default/startup/README`: 12 | 13 | > This is the IPython startup directory 14 | > 15 | > .py and .ipy files in this directory will be run *prior* to any code or files specified 16 | > via the exec_lines or exec_files configurables whenever you load this profile. 17 | > 18 | > Files will be run in lexicographical order, so you can control the execution order of files 19 | > with a prefix, e.g.:: 20 | > 21 | > 00-first.py 22 | > 50-middle.py 23 | > 99-last.ipy -------------------------------------------------------------------------------- /ipython/ipython-startup-scripts/00-pandas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | pd.set_option("display.width", 2048) 4 | # 8 is the minimum value to display `df.describe()`. We have other truncation mechanisms so it's OK to flex this a bit.
5 | pd.set_option("display.max_rows", 8) 6 | pd.set_option("display.max_columns", 40) 7 | pd.set_option("display.max_colwidth", 40) 8 | pd.set_option("display.precision", 3) 9 | pd.set_option("future.no_silent_downcasting", True) 10 | -------------------------------------------------------------------------------- /ipython/ipython-startup-scripts/98-udfs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import concurrent.futures 4 | import os 5 | from pathlib import Path 6 | from typing import NamedTuple, cast 7 | 8 | import pandas as pd 9 | 10 | 11 | class FileEncoding(NamedTuple): 12 | """File encoding as the NamedTuple.""" 13 | 14 | encoding: str | None 15 | """The encoding of the file.""" 16 | confidence: float 17 | """The confidence of the encoding.""" 18 | language: str | None 19 | """The language of the file.""" 20 | 21 | 22 | def detect_file_encodings( 23 | file_path: str | Path, timeout: int = 5 24 | ) -> list[FileEncoding]: 25 | """Try to detect the file encoding. 26 | 27 | Returns a list of `FileEncoding` tuples with the detected encodings ordered 28 | by confidence. 29 | 30 | Args: 31 | file_path: The path to the file to detect the encoding for. 32 | timeout: The timeout in seconds for the encoding detection. 
33 | """ 34 | import chardet 35 | 36 | file_path = str(file_path) 37 | 38 | def read_and_detect(file_path: str) -> list[dict]: 39 | with open(file_path, "rb") as f: 40 | rawdata = f.read() 41 | return cast(list[dict], chardet.detect_all(rawdata)) 42 | 43 | with concurrent.futures.ThreadPoolExecutor() as executor: 44 | future = executor.submit(read_and_detect, file_path) 45 | try: 46 | encodings = future.result(timeout=timeout) 47 | except concurrent.futures.TimeoutError: 48 | raise TimeoutError( 49 | f"Timeout reached while detecting encoding for {file_path}" 50 | ) 51 | 52 | if all(encoding["encoding"] is None for encoding in encodings): 53 | raise RuntimeError(f"Could not detect encoding for {file_path}") 54 | return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] 55 | 56 | 57 | def path_from_uri(uri: str) -> Path: 58 | """Return a new path from the given 'file' URI. 59 | This is implemented in Python 3.13. 60 | See 61 | and 62 | TODO: remove when we migrate to Python 3.13""" 63 | if not uri.startswith("file:"): 64 | raise ValueError(f"URI does not start with 'file:': {uri!r}") 65 | path = uri[5:] 66 | if path[:3] == "///": 67 | # Remove empty authority 68 | path = path[2:] 69 | elif path[:12] == "//localhost/": 70 | # Remove 'localhost' authority 71 | path = path[11:] 72 | if path[:3] == "///" or (path[:1] == "/" and path[2:3] in ":|"): 73 | # Remove slash before DOS device/UNC path 74 | path = path[1:] 75 | if path[1:2] == "|": 76 | # Replace bar with colon in DOS drive 77 | path = path[:1] + ":" + path[2:] 78 | from urllib.parse import unquote_to_bytes 79 | 80 | path = Path(os.fsdecode(unquote_to_bytes(path))) 81 | if not path.is_absolute(): 82 | raise ValueError(f"URI is not absolute: {uri!r}") 83 | return path 84 | 85 | 86 | def file_extention(file: str) -> str: 87 | path = Path(file) 88 | return path.suffix 89 | 90 | 91 | def read_df(uri: str, *, autodetect_encoding: bool = True, **kwargs) -> pd.DataFrame: 92 | """A simple wrapper 
to read different file formats into DataFrame.""" 93 | try: 94 | return _read_df(uri, **kwargs) 95 | except UnicodeDecodeError as e: 96 | if autodetect_encoding: 97 | detected_encodings = detect_file_encodings(path_from_uri(uri), timeout=30) 98 | for encoding in detected_encodings: 99 | try: 100 | return _read_df(uri, encoding=encoding.encoding, **kwargs) 101 | except UnicodeDecodeError: 102 | continue 103 | # Either we ran out of detected encoding, or autodetect_encoding is False, 104 | # we should raise encoding error 105 | raise ValueError(f"不支持的文件编码{e.encoding},请转换成 utf-8 后重试") # noqa: RUF001 106 | 107 | 108 | def _read_df(uri: str, encoding: str = "utf-8", **kwargs) -> pd.DataFrame: 109 | """A simple wrapper to read different file formats into DataFrame.""" 110 | ext = file_extention(uri).lower() 111 | if ext == ".csv": 112 | df = pd.read_csv(uri, encoding=encoding, **kwargs) 113 | elif ext == ".tsv": 114 | df = pd.read_csv(uri, sep="\t", encoding=encoding, **kwargs) 115 | elif ext in [".xls", ".xlsx", ".xlsm", ".xlsb", ".odf", ".ods", ".odt"]: 116 | # read_excel does not support 'encoding' arg, also it seems that it does not need it. 
117 | df = pd.read_excel(uri, **kwargs) 118 | else: 119 | raise ValueError( 120 | f"TableGPT 目前支持 csv、tsv 以及 xlsx 文件,您上传的文件格式 {ext} 暂不支持。" # noqa: RUF001 121 | ) 122 | return df 123 | -------------------------------------------------------------------------------- /ipython/ipython-startup-scripts/99-cfont.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | from mplfonts import use_font 3 | 4 | use_font("Noto Serif CJK SC") 5 | sns.set_theme(font="Noto Serif CJK SC") 6 | -------------------------------------------------------------------------------- /ipython/requirements.txt: -------------------------------------------------------------------------------- 1 | pandas >=2.2,<3.0.0 2 | scipy >=1.13.0,<2.0.0 3 | tabulate >=0.9.0,<1.0.0 4 | scikit-learn >=1.0.0,<2.0.0 5 | statsmodels >=0.10.0,<1.0.0 6 | matplotlib >=3.8.4,<4.0.0 7 | seaborn >=0.13.1,<1.0.0 8 | mplfonts >=0.0.8,<1.0.0 9 | numexpr >=2.8.4 10 | openpyxl >=3.1.2,<4.0.0 # read xlsx files 11 | xlrd >= 2.0.1 # read xls files 12 | odfpy # read ods files 13 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: TableGPT Agent 2 | 3 | theme: 4 | name: "material" 5 | features: 6 | - navigation.footer 7 | - search.highlight 8 | - search.share 9 | - content.action.edit 10 | - content.action.view 11 | icon: 12 | edit: material/pencil 13 | view: material/eye 14 | palette: 15 | # Palette toggle for light mode 16 | - scheme: default 17 | toggle: 18 | icon: material/brightness-7 19 | name: Switch to dark mode 20 | # Palette toggle for dark mode 21 | - scheme: slate 22 | toggle: 23 | icon: material/brightness-4 24 | name: Switch to light mode 25 | 26 | plugins: 27 | - mkdocs-jupyter 28 | - mkdocstrings 29 | - search 30 | 31 | extra_css: 32 | - stylesheets/extra.css 33 | 34 | markdown_extensions: 35 | - pymdownx.highlight: 36 
| anchor_linenums: true 37 | line_spans: __span 38 | pygments_lang_class: true 39 | - pymdownx.inlinehilite 40 | - pymdownx.snippets 41 | - pymdownx.superfences 42 | - toc: 43 | permalink: "#" 44 | 45 | nav: 46 | - Home: index.md 47 | - Tutorials: 48 | - 'Quick Start': tutorials/quick-start.ipynb 49 | - 'Chat on tabular data': tutorials/chat-on-tabular-data.ipynb 50 | - 'Continue Analysis on Generated Charts': tutorials/continue-analysis-on-generated-charts.ipynb 51 | - 'How-To Guides': 52 | - 'Enhance TableGPT Agent with RAG': howto/retrieval.ipynb 53 | - 'Persist Messages': howto/persist-messages.ipynb 54 | - 'Messages Truncation': howto/messages-truncation.ipynb 55 | - 'Incluster Code Execution': howto/incluster-code-execution.md 56 | - 'Normalize Datasets': howto/normalize-datasets.ipynb 57 | - 'Cleanup Error Trace': howto/cleanup-error-trace.md 58 | - 'Customize Table Info': howto/customize-table-info.md 59 | - Reference: reference.md 60 | - Explanation: 61 | - 'Agent Workflow': explanation/agent-workflow.md 62 | - 'File Reading': explanation/file-reading.ipynb 63 | - 'Code Sandbox': explanation/code-sandbox.md 64 | - 'IPython Startup Scripts': explanation/ipython-startup-scripts.md 65 | 66 | repo_name: tablegpt/tablegpt-agent 67 | repo_url: https://github.com/tablegpt/tablegpt-agent 68 | edit_uri: edit/main/docs/ 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "tablegpt-agent" 7 | dynamic = ["version"] 8 | description = '' 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = {file = "LICENSE"} 12 | keywords = [] 13 | authors = [ 14 | { name = "Aofeng Su", email = "saf@zjuici.com" }, 15 | { name = "Chen Zhou", email = "zc@zjuici.com" }, 16 | { name = "Junbo Zhao", email = "j.zhao@zju.edu.cn" }, 
17 | { name = "Junlin Zhou", email = "jlzhou@zjuici.com" }, 18 | { name = "Tao Zhang", email = "zt@zjuici.com" }, 19 | { name = "Xiang Li", email = "xli@zjuici.com" }, 20 | ] 21 | classifiers = [ 22 | "Development Status :: 4 - Beta", 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Programming Language :: Python :: Implementation :: CPython", 29 | "Programming Language :: Python :: Implementation :: PyPy", 30 | ] 31 | dependencies = [ 32 | "chardet>=5.2.0,<6.0.0", 33 | "langchain-core>=0.3.0,<1.0.0", 34 | "langgraph>=0.0.68,<1.0.0", 35 | "pandas>=2.2,<3.0.0", 36 | "pppybox>=0.0.17" 37 | ] 38 | 39 | [project.urls] 40 | Documentation = "https://tablegpt.github.io/tablegpt-agent/" 41 | Issues = "https://github.com/tablegpt/tablegpt-agent/issues" 42 | Source = "https://github.com/tablegpt/tablegpt-agent" 43 | 44 | [project.optional-dependencies] 45 | local = [ 46 | "pandas >=2.2,<3.0.0", 47 | "scipy >=1.13.0,<2.0.0", 48 | "tabulate >=0.9.0,<1.0.0", 49 | "scikit-learn >=1.0.0,<2.0.0", 50 | "statsmodels >=0.10.0,<1.0.0", 51 | "matplotlib >=3.8.4,<4.0.0", 52 | "seaborn >=0.13.1,<1.0.0", 53 | "mplfonts >=0.0.8,<1.0.0", 54 | "numexpr >=2.8.4", 55 | "openpyxl >=3.1.2,<4.0.0", 56 | "xlrd >= 2.0.1", 57 | "odfpy", 58 | "pppybox[local]" 59 | ] 60 | 61 | [tool.hatch.build.targets.sdist] 62 | exclude = [ 63 | ".devcontainer", 64 | ".github", 65 | "/docs", 66 | "/examples", 67 | "/realtabbench", 68 | "collect_script.py", 69 | ] 70 | 71 | [tool.hatch.build.targets.wheel] 72 | packages = ["src/tablegpt"] 73 | 74 | [tool.hatch.build.targets.wheel.shared-data] 75 | "ipython/ipython-startup-scripts" = "share/ipykernel/profile/tablegpt/startup" 76 | 77 | [tool.hatch.version] 78 | path = "src/tablegpt/__about__.py" 79 | 80 | [tool.hatch.envs.types] 81 | extra-dependencies = [ 82 | "mypy>=1.0.0", 83 | ] 84 | 
[tool.hatch.envs.types.scripts] 85 | check = "mypy --install-types --non-interactive {args:src/tablegpt tests}" 86 | 87 | [tool.hatch.envs.docs] 88 | dependencies = [ 89 | "mkdocs", 90 | "mkdocstrings[python]", 91 | "mkdocs-jupyter", 92 | "mkdocs-material", 93 | ] 94 | 95 | [tool.coverage.run] 96 | source_pkgs = ["tablegpt", "tests"] 97 | branch = true 98 | parallel = true 99 | omit = [ 100 | "src/tablegpt/__about__.py", 101 | ] 102 | 103 | [tool.coverage.paths] 104 | tablegpt = ["src/tablegpt"] 105 | tests = ["tests"] 106 | 107 | [tool.coverage.report] 108 | exclude_lines = [ 109 | "no cov", 110 | "if __name__ == .__main__.:", 111 | "if TYPE_CHECKING:", 112 | ] 113 | 114 | [tool.ruff] 115 | # Exclude a variety of commonly ignored directories. 116 | exclude = [ 117 | "ipython" 118 | ] 119 | # Allow lines to be as long as 120. 120 | line-length = 120 121 | 122 | [tool.ruff.lint.flake8-tidy-imports] 123 | ban-relative-imports = "parents" 124 | 125 | [tool.ruff.lint.flake8-type-checking] 126 | runtime-evaluated-base-classes = ["pydantic.BaseModel", "sqlalchemy.orm.DeclarativeBase"] 127 | runtime-evaluated-decorators = ["pydantic.validate_call", "attrs.define"] 128 | -------------------------------------------------------------------------------- /realtabbench/README.md: -------------------------------------------------------------------------------- 1 | # Benchmark Evaluations: A Variety of Academic and Table-Related Benchmark Evaluations for TableGPT2 2 | 3 | ## Overview 4 | 5 | This folder is dedicated to evaluating TableGPT2 across diverse table-related benchmarks. Given the complexity and variability of table-based tasks and input instructions, we provide evaluation datasets and scripts covering several prominent benchmarks: 6 | 7 | - ✨ **Table-Bench**: standardized table comprehension and reasoning tasks. 8 | - ✨ **Text2SQL**: evaluates SQL generation capabilities from natural language queries. 
9 | - ✨ **TableInstruct**: a suite of benchmarks focused on various table-related tasks. 10 | - ✨ **RealTabBench**: our custom benchmark specifically crafted to test LLMs on intricate, real-world tabular data scenarios, including irregular table structures, anonymized fields, and complex queries. *(Note: Only a portion of this benchmark is released here.)* 11 | 12 | We utilize an inference framework based on local model paths using vLLM as the backend, with example prompt templates tailored for each benchmark. 13 | 14 | ## Usage 15 | 16 | 17 | 18 | 19 | 20 | To use this framework, first clone the repository and install the necessary dependencies: 21 | 22 | ```shell 23 | git clone https://github.com/tablegpt/tablegpt-agent 24 | cd tablegpt-agent/realtabbench 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 29 | 30 | 31 | 32 | ### Dataset 33 | 34 | The necessary database files are available on Google Drive. Download the files from the following URLs: 35 | - [spider dev](https://drive.google.com/file/d/15xVsPLEVHXxyfczrAjYYKUEzFX6Jxjzn/view?usp=sharing) 36 | - [spider test](https://drive.google.com/file/d/1O_Bs4Nw4vIjKx2T5IXUgjhG4AxVxCl78/view?usp=sharing) 37 | - [bird dev](https://drive.google.com/file/d/1gXS8syJC0WcyDzX3LT2AdDxs9peWhsyV/view?usp=sharing) 38 | - [RealTabBench](https://drive.google.com/file/d/1-PHf81VKlsI7jiREZ3v82UkHGUghrsTT/view?usp=sharing) 39 | 40 | ### Text2SQL Evaluation 41 | 42 | Steps to Run 43 | 44 | 1. download database files (bird or spider) 45 | 46 | 2. extract files 47 | Download and unzip each file into its respective directory: 48 | ```bash 49 | unzip bird_dev_database.zip -d realtabbench/evalset/bird_data \ 50 | && 51 | unzip spider_dev_database.zip -d realtabbench/evalset/spider_data \ 52 | && 53 | unzip spider_test_database.zip -d realtabbench/evalset/spider_data 54 | ``` 55 | 56 | 3. 
run evaluation script 57 | Execute the evaluation script to obtain accuracy metrics for the bird or spider datasets: 58 | ```bash 59 | python run_text2sql_eval.py --model_path \ 60 | --eval_data_name \ 61 | --mode 62 | ``` 63 | 64 | ### Agent Evaluation on RealTabBench 65 | 66 | 1. download data files from google drive 67 | 68 | 2. create virtual environment 69 | ```bash 70 | python -m venv venv 71 | source ./venv/bin/activate # On Windows, use `.\venv\Scripts\activate` 72 | ``` 73 | 74 | 3. install dependencies for eval 75 | ```bash 76 | cd realtabbench/agent_eval 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | 4. run evaluation script 81 | ```bash 82 | python -m agent_eval --config path/to/your/config.yaml 83 | ``` 84 | -------------------------------------------------------------------------------- /realtabbench/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/realtabbench/__init__.py -------------------------------------------------------------------------------- /realtabbench/agent_eval/README.md: -------------------------------------------------------------------------------- 1 | # TableGPT Evaluation 2 | 3 | This document will guide you through the process of setting up the evaluation environment and running evaluations. 4 | 5 | ## Evaluation Datasets 6 | 7 | Before running the evaluation, you need to create the evaluation datasets locally. 8 | 9 | In the evaluation context, the term "dataset" can be confusing because it has two different meanings. The first refers to evaluation datasets, which contain the samples you wish to evaluate. Each sample must have an 'input' field representing the user input and may optionally include an 'expected output' field if there is a ground truth answer to that input.
The second definition refers to the dataset on which the user wants to perform analysis, which we refer to as 'reference data'. 10 | 11 | ### Input 12 | 13 | We use LLM to assist in generating questions based on the input dataset. You can find the script [here](./questioner.py). 14 | 15 | Please note that while our goal was to create a one-click solution for question generation, the current implementation may require some manual adjustments. Depending on your dataset, you might need to tweak the prompt accordingly. For instance, the default prompt aims to "uncover business value," which is not suitable for datasets related to diseases. 16 | 17 | ### Expected Output 18 | 19 | While not all samples require an 'expected output' field, certain inputs—particularly those related to data analysis—do need a ground truth answer for comparison during evaluation. We use Agent Apps (such as ChatGPT, ChatGLM, etc.) to assist in generating the 'expected output.' 20 | 21 | It's crucial to be meticulous when crafting the 'expected output' because it serves as the ground truth for evaluation. If the 'expected output' is incorrect, the evaluation results will be inaccurate. 22 | 23 | ## Installation 24 | 25 | Create a virtual environment 26 | 27 | ```sh 28 | python -m venv venv 29 | source ./venv/bin/activate # On Windows, use `.\venv\Scripts\activate` 30 | ``` 31 | 32 | Install dependencies for eval 33 | 34 | ```sh 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## Configuration 39 | 40 | The configuration file for evaluation is a YAML file (config.yaml by default). Refer to [example-config.yaml](./example-config.yaml) for detailed information. 41 | 42 | ## Run the evaluation script 43 | 44 | Besides the config file, you need to set up some environment variables, either by exporting them or by creating a `.env` file in the root directory. 
45 | 46 | To run the evaluation script, use the following command: 47 | 48 | ```sh 49 | python -m agent_eval --config path/to/your/config.yaml 50 | ``` 51 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | logger.setLevel(level=LOG_LEVEL) 9 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/__main__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | import signal 5 | import sys 6 | 7 | from dotenv import find_dotenv, load_dotenv 8 | from langchain.globals import set_debug 9 | from traitlets.log import get_logger 10 | 11 | from .config import load_config 12 | from .runner import Runner 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") 18 | set_debug(LOG_LEVEL.upper() == "TRACE") 19 | 20 | # silent traitlets logs 21 | traitlets_logger = get_logger() 22 | traitlets_logger.setLevel("ERROR") 23 | 24 | 25 | async def main() -> None: 26 | # Set up signal handling for graceful shutdown 27 | stop_event = asyncio.Event() 28 | # Windows does not support signal handling, we handle KeyboardInterrupt instead 29 | if sys.platform != "win32": 30 | loop = asyncio.get_running_loop() 31 | loop.add_signal_handler(signal.SIGINT, stop_event.set) 32 | loop.add_signal_handler(signal.SIGTERM, stop_event.set) 33 | 34 | config = load_config() 35 | evaluator = Runner(config) 36 | try: 37 | await evaluator.run(stop_event) 38 | except asyncio.exceptions.CancelledError: 39 | stop_event.set() 40 | except KeyboardInterrupt: 41 | # TODO: On Windows we should enter here. However we went to the except block above. 
42 | logger.warning("Received CTRL+C, stopping...") 43 | stop_event.set() 44 | 45 | 46 | if __name__ == "__main__": 47 | if sys.platform == "win32": 48 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 49 | 50 | load_dotenv(find_dotenv()) 51 | asyncio.run(main()) 52 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | from typing import Any 5 | from uuid import uuid4 6 | 7 | import yaml 8 | from pydantic import BaseModel, Field, PositiveInt 9 | from pydantic_settings import BaseSettings, SettingsConfigDict 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DatasetSettings(BaseModel): 15 | name: str 16 | 17 | 18 | class EvalSettings(BaseSettings): 19 | model_config = SettingsConfigDict(extra="ignore") 20 | 21 | run_name: str = Field(default_factory=lambda: f"eval-run-{uuid4()}") 22 | metadata: dict[str, Any] 23 | user: str = "eval-user" 24 | datasets: list[DatasetSettings] 25 | 26 | max_concurrency: PositiveInt = 1 27 | num_repetitions: PositiveInt = 1 28 | 29 | evaluatee_class: str 30 | evaluator: dict[str, Any] 31 | 32 | 33 | def load_config() -> dict[str, Any]: 34 | parser = argparse.ArgumentParser(description="Run the evaluation script.") 35 | parser.add_argument( 36 | "--config", 37 | type=str, 38 | default="config.yaml", 39 | help="Config file location.", 40 | ) 41 | args = parser.parse_args() 42 | config_path = Path(args.config).absolute() 43 | if not config_path.exists(): 44 | raise RuntimeError(f"Config file '{args.config}' not found") # noqa: TRY003, EM102 45 | 46 | logger.info("Using config file: %s", config_path) 47 | with open(str(config_path)) as file: 48 | try: 49 | config = yaml.safe_load(file) 50 | except Exception: 51 | logger.exception("Error loading config file") 52 | raise 53 | 54 | return 
EvalSettings(**config) 55 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/evaluatee.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from abc import ABC, abstractmethod 5 | from contextlib import AbstractAsyncContextManager 6 | from typing import TYPE_CHECKING 7 | 8 | if TYPE_CHECKING: 9 | from typing import Self 10 | 11 | from langchain_core.messages import BaseMessage 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class AbstractEvaluatee(AbstractAsyncContextManager, ABC): 18 | @abstractmethod 19 | async def _call(self, message: BaseMessage, **kwargs) -> list[BaseMessage]: ... 20 | 21 | async def __call__(self, message: BaseMessage, **kwargs) -> list[BaseMessage]: 22 | # TODO: add callback to handle errors or other events 23 | return await self._call(message, **kwargs) 24 | 25 | @property 26 | def context(self): 27 | return {} 28 | 29 | @classmethod 30 | @abstractmethod 31 | def instance(cls) -> Self: ... 
32 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from operator import itemgetter 2 | 3 | from langchain_core.language_models import BaseLanguageModel 4 | from langchain_core.prompts import ChatPromptTemplate 5 | 6 | from .output_parser import FloatScoreOutputParser 7 | from .prompt import ( 8 | INSTRUCTION, 9 | format_criteria, 10 | format_redlines, 11 | format_reference_answer, 12 | ) 13 | 14 | PROMPT = ChatPromptTemplate.from_messages([("user", INSTRUCTION)]) 15 | 16 | 17 | def create_evaluator_runnable(llm: BaseLanguageModel): 18 | return ( 19 | { 20 | "criteria": lambda x: (format_criteria(x["criteria"]) if x.get("criteria") else ""), 21 | "redlines": lambda x: (format_redlines(x["redlines"]) if x.get("redlines") else ""), 22 | "reference_answer": lambda x: ( 23 | format_reference_answer(x["reference_answer"]) if x.get("reference_answer") else "" 24 | ), 25 | "question": itemgetter("question"), 26 | "answer": itemgetter("answer"), 27 | } 28 | | PROMPT 29 | | llm 30 | | FloatScoreOutputParser() 31 | ) 32 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/evaluator/output_parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from langchain.evaluation.scoring.eval_chain import ( 6 | _FIND_DOUBLE_BRACKETS, 7 | ScoreStringResultOutputParser, 8 | ) 9 | 10 | 11 | class FloatScoreOutputParser(ScoreStringResultOutputParser): 12 | prefix: str = "Score:" # Or maybe `None`? 13 | lower_bound: float = 0.0 14 | upper_bound: float = 1.0 15 | 16 | def parse(self, text: str) -> dict[str, Any]: 17 | """Parse the output text. 18 | 19 | Args: 20 | text (str): The output text to parse. 21 | 22 | Returns: 23 | dict: The parsed output. 
24 | 25 | Raises: 26 | ValueError: If the verdict is invalid. 27 | """ 28 | match = _FIND_DOUBLE_BRACKETS.search(text) 29 | 30 | if match: 31 | score_str = match.group(1).strip() 32 | score = float(score_str) 33 | if score > self.upper_bound or score < self.lower_bound: 34 | raise ValueError( # noqa: TRY003 35 | f"Invalid output: {text}. " # noqa: EM102 36 | f"Output must contain a double bracketed string with the verdict between {self.lower_bound} and {self.upper_bound}." 37 | ) 38 | reason = text.rsplit(self.prefix, maxsplit=1)[0].strip() 39 | return { 40 | "reason": reason, 41 | "score": round(score, 2), 42 | } 43 | raise ValueError( # noqa: TRY003 44 | f"Invalid output: {text}. Output must contain a double bracketed string. example: [[0.5]]" # noqa: EM102 45 | ) 46 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/evaluator/prompt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | INSTRUCTION = """You are a teacher grading a quiz. Start by providing a brief reason for the rating you will assign. Then, assign a rating on a scale from 0.0 to 1.0, using the format: "Score: [[score]]" (e.g., "Score: [[0.5]]"). 
4 | {criteria} 5 | {redlines} 6 | ## Quiz 7 | Question: {question} 8 | {reference_answer} 9 | Answer: {answer} 10 | """ 11 | 12 | 13 | DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER = [ 14 | "Grade the student answers based ONLY on their factual accuracy relative to the ground truth answer.", 15 | "Ensure that the student answer does not contain any conflicting statements.", 16 | "It is OK if the student answer contains more information than the ground truth answer, as long as it is factually accurate relative to the ground truth answer.", 17 | ] 18 | 19 | 20 | # Picked from `langchain.evaluation.criteria.eval_chain.Criteria` 21 | DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER = [ 22 | "Is the submission correct, accurate, and factual?", 23 | "Is the submission concise and to the point?", 24 | "Is the submission helpful, insightful, and appropriate?", 25 | ] 26 | 27 | 28 | def format_criteria(criteria: list[str]) -> str: 29 | if not criteria: 30 | return "" 31 | # I cannot manage to format it in one f-string 32 | # Python complains about 'SyntaxError: f-string expression part cannot include a backslash' 33 | criteria_str = "\n".join([f"- {x}" for x in criteria]) 34 | return f"""## Evaluation Criteria 35 | Consider the following criteria when assigning the rating: 36 | {criteria_str} 37 | """ 38 | 39 | 40 | def format_redlines(attentions: list[str]) -> str: 41 | if not attentions: 42 | return "" 43 | attentions_str = "\n".join([f"- {x}" for x in attentions]) 44 | return f"""## Redlines 45 | If the answer touches one of the redlines listed below, assign a score of [[0.0]] directly. 
46 | {attentions_str} 47 | """ 48 | 49 | 50 | def format_reference_answer(reference_answer: str) -> str: 51 | if not reference_answer: 52 | return "" 53 | return f"Reference Answer: {reference_answer}" 54 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/example-config.yaml: -------------------------------------------------------------------------------- 1 | user: eval-example 2 | 3 | metadata: 4 | name: tablegpt eval 5 | llm: 6 | name: qwen2.5-7b-instruct 7 | temperature: 0.1 8 | top_p: 0.3 9 | 10 | datasets: 11 | - name: /datasets/tablegpt-eval-normal.json 12 | 13 | evaluatee_class: "agent_eval.tablegpt_evaluatee.TablegptEvaluatee" 14 | 15 | evaluator: 16 | openai_api_base: http://localhost:8080/v1 17 | openai_api_key: nothing 18 | model_name: qwen2.5-72b-instruct 19 | temperature: 0.1 20 | top_p: 0.3 21 | max_tokens: 1024 22 | 23 | max_concurrency: 1 24 | num_repetitions: 1 25 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/questioner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | from langchain_core.output_parsers.list import NumberedListOutputParser 7 | from langchain_core.prompts.chat import ChatPromptTemplate 8 | from langchain_openai import ChatOpenAI 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | INSTRUCTION = """You are a decision maker, tasked with generating diverse and insightful questions to uncover business value based on the provided datasets. These questions should be designed to be answerable using the information within the datasets. 
14 | 15 | ## Datasets Description 16 | 17 | {description} 18 | 19 | ## Provided Data 20 | 21 | You are provided with the following datasets in the form of pandas DataFrame: 22 | 23 | {df} 24 | 25 | 26 | ## Task 27 | 28 | Ask new questions that either: 29 | - Explore new aspects of the data that have not been covered by the previous questions. 30 | - Refine or build upon previous questions to gain deeper insights. 31 | 32 | ## Notes 33 | 34 | - Wrap your response in a numbered list. 35 | - Ensure the questions cover a wide range of business logic and perspectives. 36 | - Questions must be strictly answerable using the provided datasets. Avoid using business logic or information not inferable from the datasets. 37 | - Focus on practical and relevant real-world business scenarios. 38 | - All questions MUST be asked in Chinese. 39 | """ 40 | 41 | 42 | tmpl = ChatPromptTemplate.from_messages( 43 | [ 44 | ("user", INSTRUCTION), 45 | ] 46 | ) 47 | 48 | llm = ChatOpenAI( 49 | openai_api_base="http://127.0.0.1:8080/v1", 50 | openai_api_key="none", 51 | model_name="model_name", 52 | temperature=0.5, 53 | max_tokens=1024, 54 | verbose=True, 55 | ) 56 | 57 | 58 | # We might want a multi-fallback output parser to combine these output parsers: 59 | # - langchain_core.output_parsers.list.CommaSeparatedListOutputParser 60 | # - langchain_core.output_parsers.list.NumberedListOutputParser 61 | # - langchain_core.output_parsers.list.MarkdownListOutputParser 62 | chain = tmpl | llm | NumberedListOutputParser() 63 | 64 | 65 | def main(dataset_path, questions_path: Path, description: str, *, nrows: int = 3): 66 | """Generate questions related to the given dataframe.""" 67 | pd.set_option("display.max_columns", None) 68 | if not questions_path.exists(): 69 | logger.info("questions_path does not exist. 
Creating a new file.") 70 | questions_path.touch(mode=0o644) 71 | elif not questions_path.is_file(): 72 | logger.error("Only supports file IO for now.") 73 | sys.exit(1) 74 | 75 | df = pd.read_csv(dataset_path, nrows=nrows) 76 | 77 | # previous_questions = questions_path.read_text(encoding="utf-8") 78 | 79 | new_questions: list[str] = chain.invoke( 80 | { 81 | "df": df.head(nrows), 82 | "description": description, 83 | } 84 | ) 85 | 86 | with questions_path.open(mode="a+", encoding="utf-8") as f: 87 | for question in new_questions: 88 | f.write(question + "\n") 89 | 90 | 91 | if __name__ == "__main__": 92 | import argparse 93 | 94 | parser = argparse.ArgumentParser(description="Generate questions based on the given dataset.") 95 | parser.add_argument( 96 | "--dataset", 97 | required=True, 98 | help="dataset file path", 99 | ) # path to the csv file 100 | parser.add_argument( 101 | "-q", 102 | "--questions", 103 | required=True, 104 | help="", 105 | ) # path to the question text file 106 | parser.add_argument( 107 | "--dataset-description", 108 | required=True, 109 | help="", 110 | ) # description of the dataset 111 | 112 | args = parser.parse_args() 113 | 114 | main(args.dataset, Path(args.questions), description=args.dataset_description) 115 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/requirements.txt: -------------------------------------------------------------------------------- 1 | tablegpt-agent 2 | aiofiles 3 | tqdm 4 | pydantic >= 2.0 5 | pydantic-settings >= 2.0 6 | python-dotenv 7 | pyyaml 8 | ipython 9 | ipykernel 10 | langchain 11 | langchain-openai 12 | # `MemorySaver` can dynamically manage the `Context Manager` starting from version 2.0.5 13 | langgraph-checkpoint>=2.0.5 14 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/runner.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import annotations 2 | 3 | import asyncio 4 | import datetime 5 | import json 6 | import logging 7 | from typing import TYPE_CHECKING, Any 8 | 9 | import aiofiles 10 | from langchain_core.messages import HumanMessage 11 | from tqdm.asyncio import tqdm 12 | from traitlets import import_item 13 | 14 | from .evaluatee import AbstractEvaluatee 15 | from .worker import Worker 16 | 17 | if TYPE_CHECKING: 18 | from agent_eval.config import EvalSettings 19 | from langchain_core.messages import BaseMessage 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | # TODO: make this configurable, and we can continue running after an error 25 | eval_run_output_file = f"eval_run_{datetime.datetime.now(tz=datetime.UTC).strftime('%Y%m%d_%H%M%S')}.jsonl" 26 | 27 | 28 | class Runner: 29 | """Evaluation task runner. 30 | 31 | config(config.EvalSettings): evaluation configuration. 32 | """ 33 | 34 | def __init__(self, config: EvalSettings) -> None: 35 | """Initialize the Evaluation runner with the given configuration. 36 | 37 | Args: 38 | config (dict): Configuration dictionary for the Evaluation. 
39 | """ 40 | self.config = config 41 | self.evaluatee_class = import_item(config.evaluatee_class) 42 | if not issubclass(self.evaluatee_class, AbstractEvaluatee): 43 | raise TypeError(f"{config.evaluatee_class} is not a subclass of AbstractEvaluatee") # noqa: TRY003, EM102 44 | 45 | async def run(self, stop_event: asyncio.Event) -> None: 46 | """Gather evaluation samples and run the evaluation process, in parallel.""" 47 | logger.info("Gathering evaluation samples...") 48 | queue = asyncio.Queue() 49 | await enqueue_samples(queue, self.config.datasets, self.config.num_repetitions) 50 | total_samples = queue.qsize() 51 | logger.info("Gathered %s samples for evaluation", total_samples) 52 | 53 | with tqdm(total=total_samples, desc="Evaluation samples") as pbar: 54 | try: 55 | eval_tasks = [ 56 | asyncio.create_task( 57 | Worker( 58 | queue, 59 | self.evaluatee_class.instance(), 60 | stop_event, 61 | pbar, 62 | self.config.evaluator, 63 | eval_run_output_file, 64 | ).run(), 65 | name=f"worker-{i}", 66 | ) 67 | for i in range(self.config.max_concurrency) 68 | ] 69 | # Ensure all consumers exit 70 | await asyncio.gather(*eval_tasks, return_exceptions=True) 71 | except Exception: 72 | logger.exception("Error in evaluator") 73 | finally: 74 | logger.info("Shutting down evaluator...") 75 | 76 | 77 | async def enqueue_samples(queue: asyncio.Queue, dataset_configs: list[dict], num_repetitions: int = 1) -> None: 78 | """Reads datasets from the provided configurations, constructs samples, and enqueues them for processing. 79 | 80 | Args: 81 | queue (asyncio.Queue): The queue to which the samples will be added. 82 | dataset_configs (list[dict]): A list of dataset configurations, each containing a 'name' key pointing to the dataset file. 83 | num_repetitions (int, optional): The number of times each sample should be repeated in the queue. Defaults to 1. 
84 | """ 85 | for dataset_config in dataset_configs: 86 | logger.debug("Gathering samples from dataset: %s...", dataset_config.name) 87 | 88 | async with aiofiles.open(dataset_config.name) as f: 89 | content = await f.read() 90 | dataset = json.loads(content) 91 | _samples = construct_samples(dataset) 92 | logger.debug( 93 | "Gathered %d samples from dataset %s", 94 | len(_samples), 95 | dataset_config.name, 96 | ) 97 | for sample in _samples: 98 | # Repeat each sample for `num_repetitions` times. 99 | for _ in range(num_repetitions): 100 | await queue.put(sample) 101 | 102 | 103 | def construct_samples(dataset: list[dict[str, Any]]) -> list[BaseMessage]: 104 | """Constructs a list of samples from the dataset, filtering out archived items and adding metadata. 105 | 106 | Args: 107 | dataset (list[dict[str, Any]]): The dataset containing items with 'status', 'attachments', and 'expected_output' keys. 108 | 109 | Returns: 110 | list[BaseMessage]: A list of `HumanMessage` objects, each containing the item's input and associated metadata (e.g., attachments, expected output, and evaluation criteria). 
111 | """ 112 | # Filter out archived samples 113 | active_samples = [sample for sample in dataset if sample["status"] != "ARCHIVED"] 114 | 115 | # Construct samples with additional metadata 116 | return [HumanMessage(content=sample.pop("input"), additional_kwargs=sample) for sample in active_samples] 117 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/tablegpt_evaluatee.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import shutil 5 | import tempfile 6 | from datetime import date 7 | from functools import lru_cache 8 | from pathlib import Path 9 | from typing import TYPE_CHECKING, Any, TypedDict 10 | from uuid import uuid4 11 | 12 | from langchain_core.messages import HumanMessage 13 | from langchain_openai import ChatOpenAI 14 | from langgraph.checkpoint.memory import MemorySaver 15 | from pybox import AsyncLocalPyBoxManager, AsyncRemotePyBoxManager 16 | from pydantic import BaseModel, DirectoryPath, HttpUrl 17 | from pydantic_settings import BaseSettings, SettingsConfigDict 18 | from tablegpt.agent import create_tablegpt_graph 19 | from tablegpt.agent.file_reading import Stage 20 | 21 | from .evaluatee import AbstractEvaluatee 22 | 23 | if TYPE_CHECKING: 24 | from typing import Self 25 | 26 | from langchain_core.language_models import BaseLanguageModel 27 | from langchain_core.messages import BaseMessage 28 | from langgraph.graph.graph import CompiledGraph 29 | from pybox.base import BasePyBoxManager 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class IpythonSettings(BaseModel): 35 | incluster: bool = False 36 | """Use kubernetes crd create kernel. 
if `incluster==true` load incluster config and create kernel CR as remote kernel""" 37 | gateway_url: HttpUrl | None = None 38 | env_file: str | None = None 39 | """Path to the environment file to use for the kernel.""" 40 | 41 | 42 | # TODO: this is also somehow a copy-paste from tablegpt-chat, with slight modifications 43 | # Maybe we need to refactor that too? 44 | class Settings(BaseSettings): 45 | """Application runtime settings. 46 | 47 | We give almost everything a default value, to make unittest easier. 48 | """ 49 | 50 | model_config = SettingsConfigDict(env_nested_delimiter="__", extra="ignore") 51 | 52 | llm: dict[str, Any] = {} 53 | vlm: dict[str, Any] | None = None 54 | guard_llm: dict[str, Any] | None = None 55 | normalize_llm: dict[str, Any] | None = None 56 | """LLM used to normalize unstructured dataset""" 57 | 58 | data_vol: DirectoryPath = tempfile.gettempdir() 59 | """Data volume used to persist query results""" 60 | ipython_kernel: IpythonSettings = IpythonSettings() 61 | """Kubernetes Kernel Client settings""" 62 | error_trace_cleanup: bool = False 63 | """Enable trace cleanup to remove unnecessary error messages. 64 | This feature prunes the error trace to reduce the context length sent to the LLM, helping weaker models focus on the specific error line. 65 | When enabled, only a small context around the exact error line, along with a brief error description, is retained. 66 | While this is considered experimental, and some performance improvements have been observed, it may lead to information loss in certain situations. 
67 | """ 68 | 69 | 70 | @lru_cache 71 | def get_settings() -> Settings: 72 | return Settings(_env_file=[".env"], _env_file_encoding="utf-8") 73 | 74 | 75 | @lru_cache 76 | def get_llm_instance() -> BaseLanguageModel: 77 | settings = get_settings() 78 | return ChatOpenAI(**settings.llm) 79 | 80 | 81 | @lru_cache 82 | def get_vlm_instance() -> BaseLanguageModel: 83 | settings = get_settings() 84 | if settings.vlm is None: 85 | return None 86 | return ChatOpenAI(**settings.vlm) 87 | 88 | 89 | @lru_cache 90 | def get_guard_llm_instance() -> BaseLanguageModel: 91 | settings = get_settings() 92 | if settings.guard_llm is None: 93 | return None 94 | return ChatOpenAI(**settings.guard_llm) 95 | 96 | 97 | @lru_cache 98 | def get_normalize_llm_instance() -> BaseLanguageModel: 99 | settings = get_settings() 100 | if settings.normalize_llm is None: 101 | return None 102 | return ChatOpenAI(**settings.normalize_llm) 103 | 104 | 105 | @lru_cache 106 | def get_pybox_manager() -> BasePyBoxManager: 107 | settings = get_settings() 108 | if (gateway_url := settings.ipython_kernel.gateway_url) is not None: 109 | import os 110 | 111 | # Clear default mask. Allow the kernel to read and write shared volumes. 
112 | os.umask(000) 113 | return AsyncRemotePyBoxManager( 114 | host=str(gateway_url), 115 | env_file=settings.ipython_kernel.env_file, 116 | ) 117 | return AsyncLocalPyBoxManager() 118 | 119 | 120 | # TODO: a copy-paste from tablegpt-chat 121 | # We need to refactor this and push it down to tablegpt-agent 122 | class Attachment(TypedDict): 123 | filename: str 124 | mimetype: str 125 | size: int = 0 126 | 127 | 128 | class TablegptEvaluatee(AbstractEvaluatee): 129 | def __init__( 130 | self, 131 | llm: BaseLanguageModel, 132 | pybox_manager: BasePyBoxManager, 133 | data_vol: str, 134 | *, 135 | error_trace_cleanup: bool = True, 136 | vlm: BaseLanguageModel | None = None, 137 | normalize_llm: BaseLanguageModel | None = None, 138 | guard_llm: BaseLanguageModel | None = None, 139 | ): 140 | self.llm = llm 141 | self.pybox_manager = pybox_manager 142 | self.session_id = f"eval-session-{uuid4().hex}" 143 | self.workdir = Path(data_vol, self.session_id) 144 | self.error_trace_cleanup = error_trace_cleanup 145 | self.vlm = vlm 146 | self.normalize_llm = normalize_llm 147 | self.guard_llm = guard_llm 148 | 149 | async def __aenter__(self): 150 | """Initialize the context resources.""" 151 | logger.debug("Creating workdir: %s", self.workdir) 152 | self.workdir.mkdir(parents=True, exist_ok=True) 153 | 154 | logger.debug("Spawning kernel with session ID: %s", self.session_id) 155 | await self.pybox_manager.start(kernel_id=self.session_id, cwd=self.workdir) 156 | 157 | return self 158 | 159 | async def __aexit__(self, exc_type, exc_value, traceback): 160 | """Clean up the context resources.""" 161 | logger.debug("Cleaning up worker resources...") 162 | logger.debug("Shutting down kernel: %s", self.session_id) 163 | await self.pybox_manager.shutdown(self.session_id) 164 | 165 | logger.debug("Removing workdir: %s", self.workdir) 166 | shutil.rmtree(self.workdir, ignore_errors=True) 167 | 168 | logger.debug("Worker resources cleaned up") 169 | 170 | async def _call(self, message: 
BaseMessage, **kwargs) -> list[BaseMessage]: # noqa: ARG002 171 | checkpointer = MemorySaver() 172 | config = { 173 | "configurable": {"thread_id": self.session_id}, 174 | } 175 | tablegpt_graph: CompiledGraph = create_tablegpt_graph( 176 | llm=self.llm, 177 | pybox_manager=self.pybox_manager, 178 | workdir=self.workdir, 179 | vlm=self.vlm, 180 | session_id=self.session_id, 181 | checkpointer=checkpointer, 182 | normalize_llm=self.normalize_llm, 183 | safety_llm=self.guard_llm, 184 | error_trace_cleanup=self.error_trace_cleanup, 185 | ).with_config( 186 | config=config, 187 | ) 188 | parent_id = str(uuid4()) 189 | attachments = [ 190 | Attachment(filename=file, mimetype="text/csv") for file in message.additional_kwargs.get("attachments", []) 191 | ] 192 | attachment_msg = HumanMessage( 193 | content="", 194 | additional_kwargs={ 195 | "parent_id": parent_id, 196 | "attachments": attachments, 197 | "var_name": "df", 198 | }, 199 | ) 200 | try: 201 | # file reading 202 | await tablegpt_graph.ainvoke( 203 | input={ 204 | "messages": [attachment_msg], 205 | "parent_id": parent_id, 206 | "entry_message": attachment_msg, 207 | "processing_stage": Stage.UPLOADED, 208 | } 209 | ) 210 | # data analysis 211 | state = await tablegpt_graph.ainvoke( 212 | input={ 213 | "parent_id": str(uuid4()), 214 | "messages": [HumanMessage(content=message.content)], 215 | "date": date.today(), # noqa: DTZ011 216 | } 217 | ) 218 | return state["messages"] 219 | except Exception as e: # noqa: BLE001 220 | logger.warning("Tablegpt evaluatee execution failed: %s", str(e)) 221 | checkpoint = await checkpointer.aget(config=config) 222 | return checkpoint["channel_values"].get("messages", []) 223 | 224 | @property 225 | def context(self): 226 | return {"workdir": self.workdir, "session_id": self.session_id} 227 | 228 | @classmethod 229 | def instance(cls) -> Self: 230 | settings = get_settings() 231 | return cls( 232 | llm=get_llm_instance(), 233 | pybox_manager=get_pybox_manager(), 234 | 
data_vol=settings.data_vol, 235 | error_trace_cleanup=settings.error_trace_cleanup, 236 | vlm=get_vlm_instance(), 237 | normalize_llm=get_normalize_llm_instance(), 238 | guard_llm=get_guard_llm_instance(), 239 | ) 240 | -------------------------------------------------------------------------------- /realtabbench/agent_eval/worker.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import asyncio 4 | import json 5 | import logging 6 | import traceback 7 | from typing import TYPE_CHECKING, Any 8 | 9 | import aiofiles 10 | from langchain_core.messages import AIMessage 11 | from langchain_openai import ChatOpenAI 12 | 13 | from .evaluator import create_evaluator_runnable 14 | from .evaluator.prompt import DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER, DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER 15 | 16 | if TYPE_CHECKING: 17 | from langchain_core.messages import BaseMessage 18 | from tqdm.asyncio import tqdm 19 | 20 | from .evaluatee import AbstractEvaluatee 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class Worker: 26 | def __init__( 27 | self, 28 | queue: asyncio.Queue, 29 | evaluatee: AbstractEvaluatee, 30 | stop_event: asyncio.Event | None = None, 31 | pbar: tqdm | None = None, 32 | evaluator_config: dict[str, Any] | None = None, 33 | eval_run_output_file: str = "eval-result.jsonl", 34 | ) -> None: 35 | self.queue = queue 36 | self.evaluatee = evaluatee 37 | self.stop_event = stop_event 38 | self.pbar = pbar 39 | self.evaluator_config = evaluator_config if evaluator_config else {} 40 | self.eval_run_output_file = eval_run_output_file 41 | 42 | async def run(self) -> None: 43 | logger.info("Worker started") 44 | async with self.evaluatee: 45 | while self.stop_event is None or not self.stop_event.is_set(): 46 | try: 47 | sample = self.queue.get_nowait() 48 | executor = EvalExecutor(self.evaluatee, self.evaluator_config, self.eval_run_output_file) 49 | await executor.run(sample) 
50 | if self.pbar is not None: 51 | self.pbar.update(1) 52 | except asyncio.QueueEmpty: 53 | # No more tasks in the queue, quit current worker 54 | logger.info("Worker finished") 55 | break 56 | except Exception: 57 | logger.exception("Worker encountered an error") 58 | # Set the stop event to cancel other workers 59 | if self.stop_event is not None: 60 | self.stop_event.set() 61 | break 62 | 63 | 64 | class EvalExecutor: 65 | def __init__( 66 | self, 67 | evaluatee: AbstractEvaluatee, 68 | evaluator_config: dict[str, Any], 69 | eval_run_output_file: str = "eval-result.jsonl", 70 | ) -> None: 71 | self.evaluator = create_evaluator_runnable(ChatOpenAI(**evaluator_config)) 72 | self.evaluatee = evaluatee 73 | self.eval_run_output_file = eval_run_output_file 74 | 75 | async def run(self, sample: BaseMessage) -> None: 76 | """Run the evaluation workflow. 77 | Usually a evaluatee runnable will be executed, followed by a evaluator runnable. 78 | 79 | Args: 80 | sample (BaseMessage): Evaluation sample. 
81 | """ 82 | logger.debug("Evaluating sample: %s", sample) 83 | criteria = ( 84 | sample.additional_kwargs.get("criteria") 85 | if sample.additional_kwargs.get("criteria") 86 | else ( 87 | DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER 88 | if sample.additional_kwargs.get("expected_output") 89 | else DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER 90 | ) 91 | ) 92 | reference_answer = sample.additional_kwargs.get("expected_output") 93 | redlines = sample.additional_kwargs.get("redlines", []) 94 | 95 | eval_result = { 96 | "input": sample.content, 97 | "reference_answer": reference_answer, 98 | "evaluatee_answer": "", 99 | "criteria": criteria, 100 | "redlines": redlines, 101 | } 102 | 103 | try: 104 | eval_result["messages"] = await self.evaluatee(sample) 105 | except Exception: 106 | logger.exception( 107 | "Evaluation Workflow failed, item: %s, context: %s", 108 | sample, 109 | self.evaluatee.context, 110 | ) 111 | eval_result["messages"] = [] 112 | # We treat any exception in agent invocation as a bad case 113 | eval_result["evaluation"] = { 114 | "score": 0, 115 | "explaination": traceback.format_exc(), 116 | } 117 | 118 | try: 119 | if not eval_result["messages"]: 120 | raise ValueError( # noqa: TRY301, TRY003 121 | "Evaluatee did not generate any messages." # noqa: EM101 122 | "Ensure the Evaluatee is implemented correctly and returns a valid response." 123 | ) 124 | 125 | if not isinstance(eval_result["messages"][-1], AIMessage): 126 | raise TypeError( # noqa: TRY301, TRY003 127 | f"The final message in the output from Evaluatee is of type '{type(eval_result["messages"][-1]).__name__}', " # noqa: EM102 128 | "but it must be an instance of 'AIMessage'. Please verify the Evaluatee implementation." 
129 | ) 130 | 131 | evaluatee_answer = eval_result["messages"][-1].content 132 | eval_result["evaluatee_answer"] = evaluatee_answer 133 | eval_result["evaluation"] = await self.evaluator.ainvoke( 134 | input={ 135 | "question": sample.content, 136 | "reference_answer": reference_answer, 137 | "answer": evaluatee_answer, 138 | "criteria": criteria, 139 | "redlines": redlines, 140 | }, 141 | ) 142 | except Exception: 143 | logger.exception( 144 | "Evaluator invocation failed, item: %s, context: %s", 145 | sample, 146 | self.evaluatee.context, 147 | ) 148 | # We treat any exception in evaluator invocation as a bad case 149 | eval_result["evaluation"] = { 150 | "score": 0, 151 | "explaination": traceback.format_exc(), 152 | } 153 | 154 | eval_result["messages"] = [message.model_dump() for message in eval_result["messages"]] 155 | 156 | async with aiofiles.open(self.eval_run_output_file, mode="a") as f: 157 | await f.write(json.dumps(eval_result, ensure_ascii=False) + "\n") 158 | -------------------------------------------------------------------------------- /realtabbench/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | from transformers import AutoTokenizer 5 | from vllm import LLM, SamplingParams 6 | 7 | 8 | def get_infer_kwargs(args) -> dict: 9 | """llm_inference kwargs""" 10 | temperature = args.temperature if args.temperature else 1.0 11 | max_new_tokens = args.max_new_tokens if args.max_new_tokens else 1024 12 | model_type = args.model_type if args.model_type else "chat_model" 13 | 14 | return { 15 | "temperature": temperature, 16 | "max_tokens": max_new_tokens, 17 | "model_type": model_type, 18 | } 19 | 20 | 21 | def load_tokenizer_and_template(model_name_or_path, template=None): 22 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) 23 | 24 | if tokenizer.chat_template is None: 25 | if template is not None: 26 | chatml_jinja_path = ( 27 | 
pathlib.Path(os.path.dirname(os.path.abspath(__file__))) / f"templates/template_{template}.jinja" 28 | ) 29 | assert chatml_jinja_path.exists() # noqa: S101 30 | with open(chatml_jinja_path) as f: 31 | tokenizer.chat_template = f.read() 32 | else: 33 | pass 34 | # raise ValueError("chat_template is not found in the config file, please provide the template parameter.") 35 | return tokenizer 36 | 37 | 38 | def load_model(model_name_or_path, max_model_len=None, gpus_num=1): 39 | llm_args = { 40 | "model": model_name_or_path, 41 | "gpu_memory_utilization": 0.95, 42 | "trust_remote_code": True, 43 | "tensor_parallel_size": gpus_num, 44 | "dtype": "half", 45 | } 46 | 47 | if max_model_len: 48 | llm_args["max_model_len"] = max_model_len 49 | 50 | # Create an LLM. 51 | return LLM(**llm_args) 52 | 53 | 54 | def generate_outputs(messages_batch, llm_model, tokenizer, generate_args): 55 | """ 56 | messages = [ 57 | {"role": "system", "content": "You are a helpful assistant."}, 58 | {"role": "user", "content": prompt} 59 | ] 60 | 61 | messages_batch = [messages] 62 | 63 | generate_args = { 64 | "max_new_tokens": 1024, 65 | "do_sample": True or False, 66 | "temperature": 0-1, 67 | "" 68 | } 69 | """ 70 | model_type = generate_args.pop("model_type", "chat_model") 71 | # 添加一个默认参数, 抑制instruct-following能力较差的模型, 输出重复内容, 考虑加入参数配置 72 | # generate_args["presence_penalty"] = 2.0 73 | sampling_params = SamplingParams(**generate_args) 74 | 75 | prompt_batch = [] 76 | for messages in messages_batch: 77 | # 如果是basemodel, 直接拼接prompt内容后输入到模型 78 | if model_type == "base_model": 79 | messages_content = [msg["content"] for msg in messages] 80 | prompt = "\n".join(messages_content) 81 | # 如果是chat—model 则拼接chat-template后输入 82 | else: 83 | prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 84 | prompt_batch.append(prompt) 85 | 86 | outputs = llm_model.generate(prompt_batch, sampling_params) 87 | 88 | outputs_batch = [] 89 | for output in outputs: 90 | 
prompt_output = output.prompt 91 | generated_text = output.outputs[0].text 92 | outputs_batch.append({"input_prompt": prompt_output, "output_text": generated_text}) 93 | 94 | return outputs_batch 95 | -------------------------------------------------------------------------------- /realtabbench/inference_encoder.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import gc 4 | import logging 5 | 6 | import pandas as pd 7 | import torch 8 | from vllm import LLM 9 | from vllm.distributed import destroy_distributed_environment, destroy_model_parallel 10 | from vllm.sampling_params import SamplingParams 11 | from vllm.utils import is_cpu 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def extract_contrastive_table(df: pd.DataFrame): 17 | # Convert DataFrame to the desired format 18 | return { 19 | "columns": [ 20 | { 21 | "name": col, 22 | "dtype": str(df[col].dtype), 23 | "contains_nan": df[col].isnull().any(), 24 | "is_unique": df[col].nunique() == len(df[col]), 25 | "values": df[col].tolist(), # slice? 
26 | } 27 | for col in df.columns 28 | ] 29 | } 30 | 31 | 32 | def cleanup(): 33 | destroy_model_parallel() 34 | destroy_distributed_environment() 35 | with contextlib.suppress(AssertionError): 36 | torch.distributed.destroy_process_group() 37 | gc.collect() 38 | if not is_cpu(): 39 | torch.cuda.empty_cache() 40 | 41 | 42 | def inference_with_encoder(args, format_msg_datas): 43 | logger.info("Load model...") 44 | model = LLM( 45 | model=args.model_path, 46 | max_model_len=args.max_model_len, 47 | gpu_memory_utilization=0.8, 48 | max_num_seqs=20, 49 | limit_mm_per_prompt={"table": 10}, 50 | # dtype="half", 51 | dtype="bfloat16", 52 | ) 53 | 54 | sparams = SamplingParams(temperature=args.temperature, max_tokens=args.max_new_tokens) 55 | model_outputs = model.chat(messages=format_msg_datas, sampling_params=sparams) 56 | model_outputs_text = [mot.outputs[0].text for mot in model_outputs] 57 | 58 | del model 59 | cleanup() 60 | return model_outputs_text 61 | 62 | 63 | def truncate(value, max_length=80): 64 | if not isinstance(value, str) or len(value) <= max_length: 65 | return value 66 | return value[:max_length] + "..." 
67 | 68 | 69 | def format_encoder_tables(df_names, table_paths): 70 | tables = [] 71 | tables_info = [] 72 | for idx, table_path in enumerate(table_paths): 73 | df_name = df_names[idx] 74 | df = pd.read_csv(table_path, encoding="utf-8", nrows=500) 75 | df.columns = df.columns.str.strip() 76 | df = df.dropna(how="all").dropna(axis=1, how="all") 77 | # 限制超过列时截断 78 | max_columns = 50 # 可以根据你的需求设置这个数量 79 | if len(df.columns) > max_columns: 80 | df = df.iloc[:, :max_columns] 81 | 82 | df_extra_info = extract_contrastive_table(df) 83 | tables_info.append(copy.deepcopy(f"Details about the '{df_name}' other info as follows:\n\n")) 84 | tables.append(copy.deepcopy(df_extra_info)) 85 | 86 | tables_list = [ 87 | { 88 | "type": "table", 89 | "table": tb, 90 | } 91 | for tb in tables 92 | ] 93 | 94 | return tables_list, tables_info 95 | 96 | 97 | def build_encoder_table_part_content(df_names, table_paths): 98 | content_msg = [] 99 | for idx, table_path in enumerate(table_paths): 100 | content_msg.append( 101 | { 102 | "type": "text", 103 | "text": f"/*\nDetails about the '{df_names[idx]}' other info as follows:\n", 104 | } 105 | ) 106 | # 读取df并处理 107 | df = pd.read_csv(table_path, encoding="utf-8", nrows=500) 108 | df.columns = df.columns.str.strip() 109 | df = df.dropna(how="all").dropna(axis=1, how="all") 110 | # 限制超过列时截断 111 | max_columns = 50 # 可以根据你的需求设置这个数量 112 | if len(df.columns) > max_columns: 113 | df = df.iloc[:, :max_columns] 114 | 115 | content_msg.append({"type": "table", "table": extract_contrastive_table(copy.deepcopy(df))}) 116 | content_msg.append( 117 | { 118 | "type": "text", 119 | "text": "*/", 120 | } 121 | ) 122 | 123 | return content_msg 124 | 125 | 126 | def read_df_head(table_path, head_num=3, format_type="string"): 127 | df = pd.read_csv(table_path, encoding="utf-8", nrows=500) 128 | df.columns = df.columns.str.strip() 129 | df = df.dropna(how="all").dropna(axis=1, how="all") 130 | # 限制超过列时截断 131 | max_columns = 50 # 可以根据你的需求设置这个数量 132 | if 
len(df.columns) > max_columns: 133 | df = df.iloc[:, :max_columns] 134 | 135 | df_head = copy.deepcopy(df.head(head_num)) 136 | df_truncated_head = df_head.apply(lambda x: x.map(lambda y: truncate(y, 80))) 137 | if format_type == "string": 138 | df_truncated_head_str = df_truncated_head.to_string() 139 | elif format_type == "md": 140 | df_truncated_head_str = df_truncated_head.to_markdown(index=False) 141 | else: 142 | df_truncated_head_str = df_truncated_head.to_string() 143 | return df_truncated_head_str, df 144 | -------------------------------------------------------------------------------- /realtabbench/requirements.txt: -------------------------------------------------------------------------------- 1 | vllm>=0.7.2 2 | defog_data==0.1.1 3 | func_timeout==4.3.5 4 | langchain==0.2.5 5 | langchain_core==0.2.43 6 | langchain_experimental==0.0.61 7 | langchain_openai==0.1.8 8 | numpy 9 | pandas 10 | Requests==2.32.4 11 | scikit_learn==1.3.2 12 | sentence_transformers==3.0.1 13 | spacy==3.7.5 14 | sql_metadata==2.11.0 15 | SQLAlchemy==2.0.8 16 | tqdm 17 | matplotlib 18 | mplfonts 19 | seaborn 20 | einops 21 | tabulate==0.9.0 22 | pypinyin 23 | rouge_score==0.1.2 24 | nltk==3.9.1 25 | evaluate==0.4.2 26 | sacrebleu==2.4.3 27 | bert_score==0.3.13 28 | absl-py 29 | einops 30 | babel==2.16.0 31 | openpyxl==3.1.5 32 | datasets==2.21.0 -------------------------------------------------------------------------------- /realtabbench/run_text2sql_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | from text2sql.src.evaluation import evaluation_main 5 | from text2sql.src.gpt_request import generate_main 6 | from text2sql.src.gpt_request_encoder import generate_main_encoder 7 | 8 | 9 | def main(args): 10 | if args.eval_data_name == "bird" and args.mode == "dev": 11 | args.db_root_path = "eval/evalset/bird_data/dev_databases" 12 | args.eval_data_path = "eval/evalset/bird_data/dev.json" 13 | 
args.ground_truth_path = "eval/evalset/bird_data/dev.sql" 14 | 15 | if args.eval_data_name == "spider" and args.mode == "test": 16 | args.db_root_path = "eval/evalset/spider_data/test_database" 17 | args.eval_data_path = "eval/evalset/spider_data/test.json" 18 | args.ground_truth_path = "eval/evalset/spider_data/test_gold.sql" 19 | 20 | if args.eval_data_name == "spider" and args.mode == "dev": 21 | args.db_root_path = "eval/evalset/spider_data/dev_database" 22 | args.eval_data_path = "eval/evalset/spider_data/dev.json" 23 | args.ground_truth_path = "eval/evalset/spider_data/dev_gold.sql" 24 | 25 | if args.is_use_knowledge: 26 | args.use_knowledge = "True" 27 | else: 28 | args.use_knowledge = "False" 29 | with open(args.eval_data_path) as f: 30 | eval_datas = json.load(f) 31 | if args.use_encoder: 32 | predicted_sql_path = generate_main_encoder(eval_datas, args) 33 | else: 34 | predicted_sql_path = generate_main(eval_datas, args) 35 | 36 | evaluation_main(args, eval_datas, predicted_sql_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | args_parser = argparse.ArgumentParser() 41 | args_parser.add_argument("--eval_type", type=str, choices=["ex"], default="ex") 42 | args_parser.add_argument("--eval_data_name", type=str, choices=["bird", "spider"], default="bird") 43 | args_parser.add_argument("--mode", type=str, choices=["dev", "test"], default="dev") 44 | args_parser.add_argument("--is_use_knowledge", default=True, action="store_true") 45 | args_parser.add_argument("--data_output_path", type=str, default="realtabbench/text2sql/output") 46 | args_parser.add_argument("--chain_of_thought", type=str, default="True") 47 | args_parser.add_argument("--model_path", type=str) # , required=True 48 | args_parser.add_argument("--gpus_num", type=int, default=1) 49 | args_parser.add_argument("--num_cpus", type=int, default=4) 50 | args_parser.add_argument("--meta_time_out", type=float, default=30.0) 51 | args_parser.add_argument("--use_encoder", default=False, 
action="store_true") 52 | args_parser.add_argument("--use_gpt_api", default=False, action="store_true") 53 | 54 | args = args_parser.parse_args() 55 | main(args) 56 | -------------------------------------------------------------------------------- /realtabbench/text2sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/realtabbench/text2sql/__init__.py -------------------------------------------------------------------------------- /realtabbench/text2sql/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/realtabbench/text2sql/src/__init__.py -------------------------------------------------------------------------------- /realtabbench/text2sql/src/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sqlite3 5 | import sys 6 | 7 | from func_timeout import FunctionTimedOut, func_timeout 8 | from joblib import Parallel, delayed 9 | from tqdm import tqdm 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def load_json(dir): # noqa: A002 15 | with open(dir) as j: 16 | return json.loads(j.read()) 17 | 18 | 19 | # def result_callback(result): 20 | # exec_result.append(result) 21 | 22 | 23 | def execute_sql(predicted_sql, ground_truth, db_path): 24 | conn = sqlite3.connect(db_path) 25 | # Connect to the database 26 | cursor = conn.cursor() 27 | cursor.execute(predicted_sql) 28 | predicted_res = cursor.fetchall() 29 | cursor.execute(ground_truth) 30 | ground_truth_res = cursor.fetchall() 31 | res = 0 32 | if set(predicted_res) == set(ground_truth_res): 33 | res = 1 34 | return { 35 | "res": res, 36 | "predicted_res": list(set(predicted_res)), 37 | "ground_truth_res": 
list(set(ground_truth_res)), 38 | } 39 | 40 | 41 | def execute_model(sql_pair, db_place, idx, meta_time_out): 42 | predicted_sql, ground_truth = sql_pair 43 | try: 44 | res_dict = func_timeout(meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)) 45 | except KeyboardInterrupt: 46 | sys.exit(0) 47 | except FunctionTimedOut: 48 | # result = [('timeout',)] 49 | res_dict = {"res": 0, "exec_detail": "timeout"} 50 | except Exception: # noqa: BLE001 51 | # result = [('error',)] # possibly len(query) > 512 or not executable 52 | res_dict = {"res": 0, "exec_detail": "error"} 53 | return {"sql_idx": idx, "res": res_dict["res"], "detail": res_dict} 54 | 55 | 56 | def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"): # noqa: ARG001 57 | clean_sqls = [] 58 | db_path_list = [] 59 | if mode == "gpt": 60 | with open(sql_path) as f: 61 | sql_data = json.load(f) 62 | for sql_str in sql_data.values(): 63 | if isinstance(sql_str, str): 64 | sql, db_name = sql_str.split("\t----- bird -----\t") 65 | else: 66 | sql, db_name = " ", "financial" 67 | clean_sqls.append(sql) 68 | db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite")) 69 | 70 | elif mode == "gt": 71 | with open(sql_path) as sqls: 72 | sql_txt = sqls.readlines() 73 | # sql_txt = [sql.split('\t')[0] for sql in sql_txt] 74 | for _, sql_str in enumerate(sql_txt): 75 | sql, db_name = sql_str.strip().split("\t") 76 | clean_sqls.append(sql) 77 | db_path_list.append(os.path.join(db_root_path, db_name, f"{db_name}.sqlite")) 78 | 79 | return clean_sqls, db_path_list 80 | 81 | 82 | def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): 83 | if num_cpus > 1: 84 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 85 | return Parallel(n_jobs=num_cpus)( 86 | delayed(execute_model)(sqls[i], db_places[i], i, meta_time_out) for i in tqdm(range(len(sqls)), desc="exec") 87 | ) 88 | 89 | 90 | def sort_results(list_of_dicts): 91 | return sorted(list_of_dicts, key=lambda x: 
x["sql_idx"]) 92 | 93 | 94 | def compute_acc_by_diff(exec_results, contents): 95 | num_queries = len(exec_results) 96 | results = [res["res"] for res in exec_results] 97 | 98 | simple_results, moderate_results, challenging_results = [], [], [] 99 | 100 | for i, content in enumerate(contents): 101 | if i >= len(exec_results): 102 | continue 103 | if content["difficulty"] == "simple": 104 | simple_results.append(exec_results[i]) 105 | 106 | if content["difficulty"] == "moderate": 107 | moderate_results.append(exec_results[i]) 108 | 109 | if content["difficulty"] == "challenging": 110 | challenging_results.append(exec_results[i]) 111 | 112 | simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) if len(simple_results) != 0 else 0 113 | 114 | if len(moderate_results) != 0: 115 | moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) 116 | else: 117 | moderate_acc = 0 118 | 119 | if len(challenging_results) != 0: 120 | challenging_acc = sum([res["res"] for res in challenging_results]) / len(challenging_results) 121 | else: 122 | challenging_acc = 0 123 | 124 | all_acc = sum(results) / num_queries 125 | count_lists = [ 126 | len(simple_results), 127 | len(moderate_results), 128 | len(challenging_results), 129 | num_queries, 130 | ] 131 | return ( 132 | simple_acc * 100, 133 | moderate_acc * 100, 134 | challenging_acc * 100, 135 | all_acc * 100, 136 | count_lists, 137 | ) 138 | 139 | 140 | def print_data(score_lists, count_lists): 141 | print( # noqa: T201 142 | "====================================== ACCURACY =====================================" 143 | ) 144 | levels = ["simple", "moderate", "challenging", "total"] 145 | print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) # noqa: T201 146 | print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) # noqa: T201 147 | 148 | print( # noqa: T201 149 | "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists) 150 | ) 151 | 152 
| 153 | def evaluation_main(args, eval_datas, predicted_sql_path): 154 | exec_result = [] 155 | 156 | pred_queries, db_paths = package_sqls(predicted_sql_path, args.db_root_path, mode="gpt", data_mode=args.mode) 157 | # generate gt sqls: 158 | gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.mode) 159 | 160 | query_pairs = list(zip(pred_queries, gt_queries)) 161 | exec_result = run_sqls_parallel( 162 | query_pairs, 163 | db_places=db_paths, 164 | num_cpus=args.num_cpus, 165 | meta_time_out=args.meta_time_out, 166 | ) 167 | exec_result = sort_results(exec_result) 168 | 169 | # save_result 170 | res = [] 171 | for sql_pair, exec_res, data in zip(query_pairs, exec_result, eval_datas): 172 | predicted_sql, ground_truth = sql_pair 173 | exec_res["ground_truth"] = ground_truth 174 | exec_res["predicted_sql"] = predicted_sql 175 | exec_res["question"] = data["question"] 176 | exec_res["difficulty"] = data["difficulty"] 177 | res.append(exec_res) 178 | output_path = predicted_sql_path.replace(".json", "_exec.json") 179 | with open(output_path, "w") as f: 180 | json.dump(res, f, indent=4) 181 | 182 | print("start calculate") # noqa: T201 183 | simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(exec_result, eval_datas) 184 | score_lists = [simple_acc, moderate_acc, challenging_acc, acc] 185 | print_data(score_lists, count_lists) 186 | print( # noqa: T201 187 | "===========================================================================================" 188 | ) 189 | print("Finished evaluation") # noqa: T201 190 | -------------------------------------------------------------------------------- /realtabbench/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ast 4 | import json 5 | import random 6 | import re 7 | import signal 8 | import threading 9 | from contextlib import contextmanager 10 | 
from typing import Any 11 | 12 | import pandas as pd 13 | from evaluate_code_correction.pytool import extract_last_df, format_result 14 | from langchain_experimental.tools.python.tool import PythonAstREPLTool 15 | 16 | 17 | def read_jsonl(file_path): 18 | with open(file_path, encoding="utf-8") as f: 19 | return [json.loads(line.strip()) for line in f] 20 | 21 | 22 | def load_json(data_path): 23 | """ 24 | # 加载 json 文件 25 | """ 26 | with open(data_path, encoding="utf-8") as f: 27 | return json.load(f) 28 | 29 | 30 | def save_json(data_path, data_list): 31 | """ 32 | # 保存 json 文件 33 | """ 34 | with open(data_path, "w", encoding="utf-8") as f: 35 | json.dump(data_list, f, ensure_ascii=False) 36 | 37 | 38 | def get_dfs_info(table_paths): 39 | """将所有csv文件对应的df-info拼装到一起""" 40 | infos_list = [] 41 | if len(table_paths) == 1: 42 | df_markdown_info = str(pd.read_csv(table_paths[0], encoding="utf-8").head(5).to_string(index=False)) 43 | normalized_head = f"""/*\n"df.head()" as follows:\n{df_markdown_info}\n*/""" 44 | infos_list.append(normalized_head) 45 | else: 46 | for i, path in enumerate(table_paths): 47 | # normalized_name = normalize_table_name(path) 48 | df_markdown_info = str(pd.read_csv(path, encoding="utf-8").head(5).to_string(index=False)) 49 | normalized_head = f"""/*\n"df{i+1}.head()" as follows:\n{df_markdown_info}\n*/""" 50 | # single_table_name = "\n".join([normalized_head, df_markdown_info]) 51 | infos_list.append(normalized_head) 52 | return "\n".join(infos_list) 53 | 54 | 55 | def sample_from_two_lists(list1, list2, threshold=0.5): 56 | # 如果列表为空, 则直接返回None或抛出异常, 取决于你的需求 57 | if not list1 or not list2: 58 | return None # 或者你可以抛出异常 59 | 60 | # 生成一个0到1之间的随机浮点数 61 | rand_val = random.random() # noqa: S311 62 | 63 | # 如果随机数小于0.5, 从第一个列表中采样 64 | if rand_val < threshold: 65 | return random.choice(list1) # noqa: S311 66 | # 否则, 从第二个列表中采样 67 | return random.choice(list2) # noqa: S311 68 | 69 | 70 | def recraft_query(query, _locals): 71 | last_df = 
extract_last_df(query, _locals) 72 | end_str = "\n" + format_result + f"print(format_result({last_df}))" 73 | return query + end_str 74 | 75 | 76 | def extract_code_without_comments(code): 77 | """ 78 | 从Python代码中提取除注释行以外的代码。 79 | 80 | :param code: str, 输入的Python代码 81 | :return: str, 提取后的代码 82 | """ 83 | code = re.sub(r'"""[\s\S]*?"""', "", code) 84 | code = re.sub(r"'''[\s\S]*?'''", "", code) 85 | 86 | # 移除单行注释 87 | lines = code.split("\n") 88 | cleaned_lines = [] 89 | for line in lines: 90 | # 移除以 # 开始的注释, 但保留字符串中的 # 91 | cleaned_line = re.sub(r'(? bool: 98 | """Tool function to check if a line of text is Python code""" 99 | try: 100 | tree = ast.parse(line) 101 | # Check if the parsed tree has at least one node that represents executable code 102 | for node in ast.walk(tree): 103 | if isinstance( 104 | node, 105 | ( 106 | ast.Expr, 107 | ast.Assign, 108 | ast.FunctionDef, 109 | ast.ClassDef, 110 | ast.Import, 111 | ast.ImportFrom, 112 | ast.For, 113 | ast.While, 114 | ast.If, 115 | ast.With, 116 | ast.Raise, 117 | ast.Try, 118 | ), 119 | ): 120 | return True 121 | return False # noqa: TRY300 122 | except SyntaxError: 123 | return False 124 | 125 | 126 | def extract_text_before_code(text: str) -> str: 127 | """Tool function for extract text content""" 128 | lines = text.split("\n") 129 | text_before_code = [] 130 | 131 | for line in lines: 132 | if is_python_code(line): 133 | break 134 | text_before_code.append(line) 135 | 136 | return "\n".join(text_before_code) 137 | 138 | 139 | def extract_python_code(text: str) -> str: 140 | """Tool function for extract python code""" 141 | lines = text.split("\n") 142 | python_code = [line for line in lines if is_python_code(line)] 143 | return "\n".join(python_code) 144 | 145 | 146 | def fix_indents(text: str) -> str: 147 | return text.replace("\t", " ") 148 | 149 | 150 | def filter_cot(completion: str): 151 | """ 152 | Filter the COT steps before python code 153 | :param completion: llm output contents 154 | :return 
filtered COT content 155 | """ 156 | try: 157 | # 如果输出较为规范, 可以使用这种方式提取cot部分的内容 158 | pattern = r"Thought:\s*(.*?)\s*(?=Python Code:)" 159 | match = re.search(pattern, completion, re.DOTALL) 160 | return match.group(1) if match else extract_text_before_code(completion) 161 | except: # noqa: E722 162 | return "" 163 | 164 | 165 | def filter_code(completion: str) -> tuple[str, str]: 166 | """ 167 | Filter python code from the llm output completion 168 | :param completion: llm output contents 169 | :return filtered python code and execute code 170 | """ 171 | 172 | try: 173 | # 输出形式符合prompt 174 | regex = r"```python\s(.*?)```" 175 | action_match = re.search(regex, completion, re.DOTALL) 176 | if action_match: 177 | action_input = action_match.group(1) 178 | action_input = action_input.strip(" ") 179 | action_input = action_input.strip('"') 180 | code = action_input.strip(" ") 181 | else: 182 | # 输出形式随意 183 | code = extract_python_code(completion) 184 | code = code.strip(" ") 185 | pure_code = extract_code_without_comments(code) 186 | return code, pure_code # noqa: TRY300 187 | except: # noqa: E722 188 | return "", "" 189 | 190 | 191 | def get_tool(df: Any, df_names=None): 192 | """ 193 | Define python code execute tool 194 | :param df: List[pd.DataFrame] or pd.DataFrame 195 | :return Runnable 196 | """ 197 | tool = PythonAstREPLTool() 198 | if df_names is None: 199 | if isinstance(df, pd.DataFrame): 200 | _locals = {"df": df} 201 | else: 202 | _locals = {} 203 | for i, dataframe in enumerate(df): 204 | _locals[f"df{i + 1}"] = dataframe 205 | else: 206 | _locals = {} 207 | for i, dataframe in enumerate(df): 208 | _locals[df_names[i]] = dataframe 209 | tool.locals = _locals 210 | tool.globals = tool.locals 211 | return tool 212 | 213 | 214 | def get_table_infos(table_paths): 215 | """将所有csv文件对应的df-info拼装到一起""" 216 | infos_list = [] 217 | if len(table_paths) == 1: 218 | df_markdown_info = str(pd.read_csv(table_paths[0], encoding="utf-8").head(3).to_markdown(index=False)) 
219 | normalized_head = f"""/*\n"df.head()" as follows:\n{df_markdown_info}\n*/""" 220 | infos_list.append(normalized_head) 221 | else: 222 | for i, path in enumerate(table_paths): 223 | # normalized_name = normalize_table_name(path) 224 | df_markdown_info = str(pd.read_csv(path, encoding="utf-8").head(3).to_markdown(index=False)) 225 | normalized_head = f"""/*\n"df{i+1}.head()" as follows:\n{df_markdown_info}\n*/""" 226 | # single_table_name = "\n".join([normalized_head, df_markdown_info]) 227 | infos_list.append(normalized_head) 228 | return "\n".join(infos_list) 229 | 230 | 231 | # 定义一个异常类, 用于超时处理 232 | class TimeoutException(Exception): # noqa: N818 233 | pass 234 | 235 | 236 | # 创建一个上下文管理器来处理超时 237 | @contextmanager 238 | def timeout(time): 239 | # 定义信号处理函数 240 | def raise_timeout(signum, frame): # noqa: ARG001 241 | raise TimeoutException( # noqa: TRY003 242 | f"Timeout error, running time exceed {time}" # noqa: EM102 243 | ) 244 | 245 | # 设置信号定时器 246 | signal.signal(signal.SIGALRM, raise_timeout) 247 | signal.alarm(time) 248 | try: 249 | yield 250 | finally: 251 | # 取消信号定时器 252 | signal.alarm(0) 253 | 254 | 255 | def run_code(code, result, tool): 256 | try: 257 | # 在子线程中运行代码 258 | result.append(tool.run(code)) 259 | except Exception as e: # noqa: BLE001 260 | result.append(e) 261 | 262 | 263 | def execute_with_timeout(code, timeout_seconds, tool): 264 | result = [] 265 | thread = threading.Thread(target=run_code, args=(code, result, tool)) 266 | thread.start() 267 | thread.join(timeout_seconds) 268 | 269 | if thread.is_alive(): 270 | # 终止子线程 271 | thread._stop() # noqa: SLF001 272 | raise TimeoutException( # noqa: TRY003 273 | f"Timeout error, running time exceed {timeout_seconds} seconds" # noqa: EM102 274 | ) 275 | else: # noqa: RET506 276 | if isinstance(result[0], Exception): 277 | raise result[0] 278 | return result[0] 279 | -------------------------------------------------------------------------------- /src/tablegpt/__about__.py: 
-------------------------------------------------------------------------------- 1 | __version__ = "0.2.27" 2 | -------------------------------------------------------------------------------- /src/tablegpt/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sysconfig 4 | import warnings 5 | from pathlib import Path 6 | 7 | 8 | def _find_tablegpt_ipykernel_profile_dir(): 9 | # https://docs.python.org/3.11/library/sysconfig.html#sysconfig.get_path 10 | # https://docs.python.org/3.11/library/sysconfig.html#user-scheme 11 | _py_root = Path(sysconfig.get_path("data")) 12 | 13 | possible_profile_dir = Path(_py_root, "share", "ipykernel", "profile", "tablegpt") 14 | 15 | _startup_folder = Path(possible_profile_dir, "startup") 16 | try: 17 | if next(_startup_folder.glob(r"*-udfs.py")): 18 | return str(possible_profile_dir) 19 | except StopIteration: 20 | return 21 | 22 | 23 | DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR: str | None = _find_tablegpt_ipykernel_profile_dir() 24 | 25 | if DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR is None: 26 | msg = """Unable to find tablegpt ipykernel. If you need to use a local kernel, 27 | please use `pip install -U tablegpt-agent[local]` to install the necessary dependencies. 
28 | For more issues, please submit an issue to us https://github.com/tablegpt/tablegpt-agent/issues.""" 29 | warnings.warn(msg, stacklevel=2) 30 | -------------------------------------------------------------------------------- /src/tablegpt/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import date # noqa: TCH003 4 | from typing import TYPE_CHECKING 5 | 6 | from langchain_core.messages import BaseMessage # noqa: TCH002 7 | from langgraph.graph import END, START, MessagesState, StateGraph 8 | 9 | from tablegpt.agent.data_analyzer import TruncationConfig, create_data_analyze_workflow 10 | from tablegpt.agent.file_reading import Stage, create_file_reading_workflow 11 | 12 | if TYPE_CHECKING: 13 | from pathlib import Path 14 | 15 | from langchain_core.language_models import BaseLanguageModel 16 | from langchain_core.retrievers import BaseRetriever 17 | from langchain_core.runnables import Runnable 18 | from langgraph.checkpoint.base import BaseCheckpointSaver 19 | from langgraph.graph.state import CompiledStateGraph 20 | from pybox.base import BasePyBoxManager 21 | 22 | 23 | class AgentState(MessagesState): 24 | # This is a bit of a hack to pass parent id to the agent state 25 | # But it act as the group id of all messages generated by the agent 26 | # This will be used in subgraphs 27 | parent_id: str | None 28 | # Current Date 29 | date: date 30 | # The message that we received from the user, act as an entry point 31 | entry_message: BaseMessage 32 | processing_stage: Stage 33 | 34 | 35 | def create_tablegpt_graph( 36 | llm: BaseLanguageModel, 37 | pybox_manager: BasePyBoxManager, 38 | *, 39 | session_id: str | None = None, 40 | workdir: Path | None = None, 41 | error_trace_cleanup: bool = False, 42 | nlines: int | None = None, 43 | vlm: BaseLanguageModel | None = None, 44 | safety_llm: Runnable | None = None, 45 | dataset_retriever: BaseRetriever | 
None = None, 46 | normalize_llm: BaseLanguageModel | None = None, 47 | locale: str | None = None, 48 | checkpointer: BaseCheckpointSaver | None = None, 49 | llm_truncation_config: TruncationConfig | None = None, 50 | vlm_truncation_config: TruncationConfig | None = None, 51 | verbose: bool = False, 52 | ) -> CompiledStateGraph: 53 | """Creates a state graph for processing datasets. 54 | 55 | This function orchestrates the creation of a workflow for handling table data. 56 | It sets up nodes for reading files and analyzing data based on provided parameters. 57 | The graph dynamically routes based on the presence of file attachments in the input state. 58 | 59 | Args: 60 | llm (Runnable): The primary language model for processing user input. 61 | pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm. 62 | session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None. 63 | workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None. 64 | error_trace_cleanup (bool, optional): Flag to clean up error traces. Defaults to False. 65 | nlines (int | None, optional): Number of lines to read for preview. Defaults to None. 66 | vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None. 67 | safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None. 68 | dataset_retriever (BaseRetriever | None, optional): Component to retrieve datasets. Defaults to None. 69 | normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None. 70 | locate (str | None, optional): The locale of the user. Defaults to None. 71 | checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None. 
72 | llm_truncation_config (TruncationConfig | None, optional): Truncation config for LLM. Defaults to None. 73 | vlm_truncation_config (TruncationConfig | None, optional): Truncation config for VLM. Defaults to None. 74 | verbose (bool, optional): Flag to enable verbose logging. Defaults to False. 75 | 76 | Returns: 77 | CompiledStateGraph: A compiled state graph representing the table processing workflow. 78 | """ 79 | workflow = StateGraph(AgentState) 80 | file_reading_graph = create_file_reading_workflow( 81 | nlines=nlines, 82 | llm=llm, 83 | pybox_manager=pybox_manager, 84 | workdir=workdir, 85 | session_id=session_id, 86 | normalize_llm=normalize_llm, 87 | locale=locale, 88 | verbose=verbose, 89 | ) 90 | data_analyze_graph = create_data_analyze_workflow( 91 | llm=llm, 92 | pybox_manager=pybox_manager, 93 | workdir=workdir, 94 | session_id=session_id, 95 | error_trace_cleanup=error_trace_cleanup, 96 | vlm=vlm, 97 | safety_llm=safety_llm, 98 | dataset_retriever=dataset_retriever, 99 | llm_truncation_config=llm_truncation_config, 100 | vlm_truncation_config=vlm_truncation_config, 101 | verbose=verbose, 102 | ) 103 | 104 | def router(state: AgentState) -> str: 105 | # Must have at least one message when entering this router 106 | last_message = state["messages"][-1] 107 | if last_message.additional_kwargs.get("attachments"): 108 | return "file_reading_graph" 109 | return "data_analyze_graph" 110 | 111 | workflow.add_node("file_reading_graph", file_reading_graph) 112 | workflow.add_node("data_analyze_graph", data_analyze_graph) 113 | 114 | workflow.add_conditional_edges(START, router) 115 | workflow.add_edge("file_reading_graph", END) 116 | workflow.add_edge("data_analyze_graph", END) 117 | 118 | return workflow.compile(checkpointer=checkpointer, debug=verbose) 119 | -------------------------------------------------------------------------------- /src/tablegpt/agent/output_parser.py: -------------------------------------------------------------------------------- 
1 | from __future__ import annotations 2 | 3 | import logging 4 | import re 5 | from re import Pattern 6 | from sys import version_info 7 | from uuid import uuid4 8 | 9 | from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish 10 | from langchain_core.messages import AIMessage 11 | from langchain_core.output_parsers import BaseOutputParser 12 | 13 | from tablegpt.errors import SimpleOutputParserException 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | if version_info >= (3, 12): 18 | from typing import override 19 | else: 20 | 21 | def override(func): 22 | return func 23 | 24 | 25 | class MarkdownOutputParser(BaseOutputParser[AgentAction | AgentFinish]): 26 | """Output parser that extracts markdown code blocks and try to parse them into actions.""" 27 | 28 | # group1: thought; group2: language; group3: tool_input; group4: remaining content 29 | pattern: Pattern = re.compile(r"([\S\s]*?)`{3}([\w]*)\n([\S\s]+?)\n`{3}([\S\s]*)", re.DOTALL) 30 | language_actions: dict[str, str] = {} # noqa: RUF012 31 | """A mapping from language to action key.""" 32 | just_finish: bool = True 33 | """Whether to just return AgentFinish if no parser can parse the output. 
Default to True.""" 34 | 35 | @override 36 | def parse(self, text: str) -> AgentAction | AgentFinish: 37 | if (match := re.search(self.pattern, text)) is not None: 38 | thought = match.group(1).strip() 39 | language = match.group(2) 40 | tool_input = match.group(3).strip() 41 | if (action := self.language_actions.get(language)) is not None: 42 | return AgentActionMessageLog( 43 | tool=action, 44 | tool_input=tool_input, 45 | # log is the 'thought' part 46 | log=thought, 47 | # message_log is the content we can add to history 48 | # polishing the content will improve the following iterations 49 | # TODO: run id 50 | message_log=[ 51 | AIMessage( 52 | id=str(uuid4()), 53 | # We preserve only the 'thought' and the 'action' part, and remove the 'remaining content' part 54 | content=text.removesuffix(match.group(4)).strip(), 55 | tool_calls=[ 56 | { 57 | "name": action, 58 | "args": {"query": tool_input}, 59 | "id": str(uuid4()), 60 | } 61 | ], 62 | # deprecate the "action" part in additional_kwargs? 63 | additional_kwargs={ 64 | "thought": thought, 65 | "action": { 66 | "tool": action, 67 | "tool_input": tool_input, 68 | }, 69 | }, 70 | ) 71 | ], 72 | ) 73 | logger.warning("Unknown language %s", language) 74 | if self.just_finish: 75 | return AgentFinish({"output": text}, text) 76 | raise SimpleOutputParserException(text) 77 | 78 | @override 79 | @property 80 | def _type(self) -> str: 81 | return "markdown" 82 | -------------------------------------------------------------------------------- /src/tablegpt/errors.py: -------------------------------------------------------------------------------- 1 | from langchain_core.exceptions import OutputParserException 2 | 3 | 4 | class NoAttachmentsError(KeyError): 5 | def __init__(self): 6 | super().__init__("No file attached") 7 | 8 | 9 | class InvalidURIError(ValueError): ... 
10 | 11 | 12 | class InvalidFileURIError(InvalidURIError): 13 | def __init__(self, uri: str): 14 | super().__init__(f"URI does not start with 'file:': {uri!r}") 15 | 16 | 17 | class NonAbsoluteURIError(InvalidURIError): 18 | def __init__(self, uri: str): 19 | super().__init__(f"URI is not absolute: {uri!r}") 20 | 21 | 22 | class UnsupportedFileFormatError(ValueError): 23 | def __init__(self, ext: str): 24 | super().__init__( 25 | f"TableGPT 目前支持 csv、tsv 以及 xlsx 文件,您上传的文件格式 {ext} 暂不支持。" # noqa: RUF001 26 | ) 27 | 28 | 29 | class UnsupportedEncodingError(ValueError): 30 | def __init__(self, encoding: str): 31 | super().__init__( 32 | f"不支持的文件编码{encoding},请转换成 utf-8 后重试" # noqa: RUF001 33 | ) 34 | 35 | 36 | class EncodingDetectionError(LookupError): 37 | def __init__(self, path: str): 38 | super().__init__(f"Could not detect encoding for {path}") 39 | 40 | 41 | class SimpleOutputParserException(OutputParserException): 42 | def __init__(self, input_text: str): 43 | super().__init__(f"Could not parse output: {input_text}") 44 | -------------------------------------------------------------------------------- /src/tablegpt/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from typing import TYPE_CHECKING 5 | 6 | from tablegpt.retriever.compressor import ColumnDocCompressor 7 | from tablegpt.retriever.loader import CSVLoader 8 | 9 | if TYPE_CHECKING: 10 | from langchain_core.documents import Document 11 | 12 | __all__ = [ 13 | "CSVLoader", 14 | "ColumnDocCompressor", 15 | "format_columns", 16 | ] 17 | 18 | 19 | def format_columns( 20 | docs: list[Document], 21 | dataset_cell_length_threshold: int = 40, 22 | max_dataset_cells: int = 5, 23 | ) -> str: 24 | if not docs: 25 | return "" 26 | tables: dict = {} 27 | for doc in docs: 28 | tables.setdefault(doc.metadata["filename"], []).append(doc) 29 | 30 | cols = [] 31 | for table, t_docs in tables.items(): 32 | 
cols.append( 33 | f"- {table}:\n" 34 | + "\n".join( 35 | f' - {{"column": {doc.metadata["column"]}, "dtype": "{doc.metadata["dtype"]}", "values": {format_values(doc.metadata["values"], dataset_cell_length_threshold, max_dataset_cells, doc.metadata["n_unique"])}}}' 36 | for doc in t_docs 37 | ) 38 | ) 39 | 40 | return ( 41 | "\nHere are some extra column information that might help you understand the dataset:\n" 42 | + "\n".join(cols) 43 | + "\n" 44 | ) 45 | 46 | 47 | def format_values( 48 | values_to_format: list[str], 49 | cell_length: int | None = None, 50 | n_to_keep: int | None = None, 51 | n_unique: int | None = None, 52 | ) -> str: 53 | """Format values into a json list string. 54 | Args: 55 | values_to_format (list[str]): A list of values to format. 56 | cell_length (int, optional): Maximum length of each cell. Defaults to None. 57 | n_to_keep (int, optional): Number of values to keep. Defaults to None. 58 | n_unique (int, optional): number of unique values in that column. Defaults to None. 59 | 60 | Returns: 61 | str: Formatted values as a json list string. 62 | """ 63 | # Apply length limit if specified 64 | if n_to_keep is not None: 65 | values_to_format = values_to_format[:n_to_keep] 66 | 67 | # Apply cell length limit if specified 68 | if cell_length is not None: 69 | values_to_format = [ 70 | value[:cell_length] + "..." 
if len(value) > cell_length else value for value in values_to_format 71 | ] 72 | 73 | # Convert values to JSON representation 74 | values_repr = json.dumps(values_to_format, ensure_ascii=False) 75 | 76 | # Check if unique count is specified and greater than the actual length of values 77 | if n_unique is not None and n_unique > len(values_to_format): 78 | values_repr = values_repr[:-1] + ", ...]" 79 | 80 | return values_repr 81 | -------------------------------------------------------------------------------- /src/tablegpt/retriever/compressor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import defaultdict 4 | from sys import version_info 5 | from typing import TYPE_CHECKING 6 | 7 | from langchain_core.documents import Document 8 | from langchain_core.documents.compressor import BaseDocumentCompressor 9 | 10 | if version_info >= (3, 12): 11 | from typing import override 12 | else: 13 | 14 | def override(func): 15 | return func 16 | 17 | 18 | if TYPE_CHECKING: 19 | from collections.abc import Sequence 20 | 21 | from langchain_core.callbacks import Callbacks 22 | 23 | 24 | class ColumnDocCompressor(BaseDocumentCompressor): 25 | """Compresses documents by regrouping them by column. 26 | 27 | The TableGPT Agent generates documents at the cell level (format: {column_name: cell_value}) to enhance retrieval accuracy. 28 | However, after retrieval, these documents need to be recombined by column before being sent to the LLM for processing. 
29 | """ 30 | 31 | @override 32 | def compress_documents( 33 | self, 34 | documents: Sequence[Document], 35 | query: str, # noqa: ARG002 36 | callbacks: Callbacks | None = None, # noqa: ARG002 37 | ) -> Sequence[Document]: 38 | if not documents: 39 | return [] 40 | 41 | # Initialize defaultdict to collect documents by column 42 | # Document.page_content cannot be None 43 | cols = defaultdict(lambda: Document(page_content="", metadata={})) 44 | 45 | for doc in documents: 46 | key = f"{doc.metadata['filename']}:{doc.metadata['column']}" 47 | 48 | # Initialize if key is encountered first time 49 | if not cols[key].page_content: 50 | cols[key].page_content = f"column: {doc.metadata['column']}" 51 | # Copy all metadata, excluding 'value' (if needed) 52 | cols[key].metadata = {k: v for k, v in doc.metadata.items() if k != "value"} 53 | cols[key].metadata["values"] = [] 54 | 55 | # Append value to the existing document's values list 56 | cols[key].metadata["values"].append(doc.metadata["value"]) 57 | 58 | return list(cols.values()) 59 | -------------------------------------------------------------------------------- /src/tablegpt/retriever/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from sys import version_info 5 | from typing import TYPE_CHECKING 6 | 7 | from langchain_core.document_loaders import BaseLoader 8 | from langchain_core.documents import Document 9 | from pandas.api.types import is_string_dtype 10 | 11 | from tablegpt.utils import read_df 12 | 13 | if version_info >= (3, 12): 14 | from typing import override 15 | else: 16 | 17 | def override(func): 18 | return func 19 | 20 | 21 | if TYPE_CHECKING: 22 | from collections.abc import AsyncIterator, Iterator 23 | 24 | from pandas import Series 25 | 26 | 27 | class CSVLoader(BaseLoader): 28 | """Loads CSV or Excel files into Documents. 
29 | 30 | This is similar with `langchain_community.document_loadsers.csv_loader.CSVLoader`. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | file_path: str | Path, 36 | extra_metadata: dict | None = None, 37 | encoding: str | None = None, 38 | *, 39 | autodetect_encoding: bool = False, 40 | ): 41 | """ 42 | 43 | Args: 44 | file_path: The path to the CSV file. 45 | extra_metadata: Extra metadata to set on every document. Optional. Defaults to None. 46 | encoding: The encoding of the CSV file. Optional. Defaults to None. 47 | autodetect_encoding: Whether to try to autodetect the file encoding. Optional. Defaults to False. 48 | """ 49 | self.file_path = file_path 50 | self.extra_metadata = {} if extra_metadata is None else extra_metadata 51 | if isinstance(self.file_path, Path): 52 | self.extra_metadata["filename"] = self.file_path.name 53 | else: 54 | self.extra_metadata["filename"] = self.file_path 55 | self.encoding = encoding 56 | self.autodetect_encoding = autodetect_encoding 57 | 58 | @override 59 | def lazy_load(self) -> Iterator[Document]: 60 | """A lazy loader for Documents.""" 61 | df = read_df(self.file_path, autodetect_encoding=self.autodetect_encoding) 62 | for col in df.select_dtypes(exclude=["number"]).columns: 63 | yield from self.column2docs(df[col]) 64 | 65 | @override 66 | async def alazy_load(self) -> AsyncIterator[Document]: 67 | """A lazy loader for Documents.""" 68 | # TODO: pandas does not support async read_csv yet. We might need to async read the file first. 69 | async for doc in super().alazy_load(): 70 | yield doc 71 | 72 | def column2docs(self, column: Series) -> Iterator[Document]: 73 | # If a string column contains NaN, it will be presented as object dtype. 
74 | dtype = "string" if is_string_dtype(column.dropna()) else str(column.dtype) 75 | unique_values = column.unique() 76 | 77 | for value in unique_values: 78 | yield Document( 79 | page_content=f"{column.name}:{value}", 80 | metadata={ 81 | "column": column.name, 82 | "dtype": dtype, 83 | "n_unique": len(unique_values), 84 | "value": str(value), # may need to further consolidate 85 | } 86 | | self.extra_metadata, 87 | ) 88 | -------------------------------------------------------------------------------- /src/tablegpt/safety.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from langchain_core.output_parsers import BaseTransformOutputParser 6 | from langchain_core.prompts import ChatPromptTemplate 7 | 8 | if TYPE_CHECKING: 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_core.runnables import Runnable 11 | 12 | 13 | # See 14 | hazard_categories = { 15 | "S1": "Violent Crimes: Responses that enable, encourage, or endorse the commission of violent crimes, including: (1) unlawful violence toward people (ex: terrorism, genocide, murder, hate-motivated violent crimes, child abuse, assault, battery, kidnapping) and (2) unlawful violence toward animals (ex: animal abuse)", 16 | "S2": "Non-Violent Crimes: Responses that enable, encourage, or endorse the commission of non-violent crimes, including: (1) personal crimes (ex: labor trafficking, threats, intimidation, hate-motivated non-violent crimes), (2) financial crimes (ex: fraud, scams, money laundering), (3) property crimes (ex: theft, arson, vandalism), (4) drug crimes (ex: creating or using regulated substances), (5) weapons crimes (ex: producing unlicensed firearms), and (6) cyber crimes (ex: hacking)", 17 | "S3": "Sex-Related Crimes: Responses that enable, encourage, or endorse the commission of sex-related crimes, including: (1) sex trafficking, (2) sexual 
assault (ex: rape), (3) sexual harassment (physical, verbal, or visual), and (4) prostitution", 18 | "S4": "Child Sexual Exploitation: Responses that contain, describe, enable, encourage, or endorse the sexual abuse of children", 19 | "S5": "Defamation: Responses that are both verifiably false and likely to injure a living person's reputation", 20 | "S6": "Specialized Advice: Responses that contain specialized financial, medical, or legal advice, or that indicate dangerous activities or objects are safe", 21 | "S7": "Privacy: Responses that contain sensitive, nonpublic personal information that could undermine someone's physical, digital, or financial security", 22 | "S8": "Intellectual Property: Responses that may violate the intellectual property rights of any third party", 23 | "S9": "Indiscriminate Weapons: Responses that enable, encourage, or endorse the creation of indiscriminate weapons, including: (1) chemical weapons (ex: nerve gas), (2) biological weapons (ex: anthrax), (3) radiological weapons (ex: cobalt bombs), (4) nuclear weapons (ex: fission bombs), and (5) high-yield explosive weapons (ex: cluster munitions)", 24 | "S10": "Hate: Responses that demean or dehumanize people on the basis of their sensitive, personal characteristics (i.e., race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity, and/or serious disease)", 25 | "S11": "Suicide & Self-Harm: Responses that enable, encourage, or endorse acts of intentional self-harm, including: (1) suicide, (2) self-injury (ex: cutting), and (3) disordered eating", 26 | "S12": "Sexual Content: Responses that contain erotica", 27 | "S13": "Elections: Responses that contain factually incorrect information about electoral systems and processes, including in the time, place, or manner of voting in civic elections", 28 | "S14": "Code Interpreter Abuse: Responses that seek to abuse code interpreters, including those that enable denial of service 
attacks, container escapes or privilege escalation exploits", 29 | } 30 | 31 | 32 | class HazardOutputParser(BaseTransformOutputParser[tuple[str, str | None]]): 33 | def parse(self, text: str) -> tuple[str, str | None]: 34 | """Parse the output of the guard model. 35 | 36 | Returns: 37 | tuple[str, str | None]: A tuple where the first element is the safety flag ("safe", "unsafe", "unknown") and the second element is 38 | the risk category if applicable, otherwise `None`. 39 | """ 40 | text = text.strip() 41 | 42 | if "\n" not in text: 43 | if text.lower() == "safe": 44 | return "safe", None 45 | return "unknown", None 46 | 47 | flag, category = text.split("\n", 1) 48 | if flag.lower() == "unsafe": 49 | return "unsafe", category 50 | return "unknown", None 51 | 52 | 53 | tmpl = ChatPromptTemplate.from_messages( 54 | [ 55 | ("placeholder", "{messages}"), 56 | ] 57 | ) 58 | 59 | 60 | output_parse = HazardOutputParser() 61 | 62 | 63 | def create_hazard_classifier(llm: BaseLanguageModel) -> Runnable: 64 | """return the guard chain runnable.""" 65 | return tmpl | llm | output_parse 66 | -------------------------------------------------------------------------------- /src/tablegpt/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import mimetypes 4 | import re 5 | from pathlib import Path 6 | from re import Pattern 7 | from sys import version_info 8 | from typing import TYPE_CHECKING, Literal 9 | 10 | from langchain_core.callbacks.manager import CallbackManagerForToolRun # noqa: TCH002 11 | from langchain_core.tools import BaseTool 12 | from pybox.base import BasePyBox, BasePyBoxManager # noqa: TCH002 13 | from pydantic import BaseModel, DirectoryPath, field_validator, model_validator 14 | 15 | if version_info >= (3, 12): 16 | from typing import override 17 | else: 18 | 19 | def override(func): 20 | return func 21 | 22 | 23 | if TYPE_CHECKING: 24 | from pybox.schema import 
ErrorContent, PyBoxOut 25 | from typing_extensions import Self 26 | 27 | 28 | class Artifact(BaseModel): 29 | """Represents an artifact (file) generated by an agent during the execution of a task.""" 30 | 31 | filename: str | None = None 32 | """The name of the file. If not provided, will be extracted from the path.""" 33 | path: Path 34 | """The absolute path to the artifact.""" 35 | mimetype: str | None = None 36 | """The MIME type of the artifact, determined based on the file extension. 37 | OS is not guaranteed to guess the mimetype of any file.""" 38 | 39 | @model_validator(mode="after") 40 | def extract_filename(self) -> Self: 41 | self.filename = self.path.name 42 | return self 43 | 44 | @field_validator("path") 45 | @classmethod 46 | def ensure_path_absolute(cls, v: Path) -> Path: 47 | return v.absolute() 48 | 49 | 50 | class IPythonTool(BaseTool): 51 | """A tool for running code in an IPython kernel and handling the result, including content and generated artifacts.""" 52 | 53 | name: str = "python" 54 | description: str = "IPython kernel tool" 55 | response_format: Literal["content_and_artifact"] = "content_and_artifact" 56 | """Change the default response format to include artifacts. 57 | See `langchain_core.tools.base.BaseTool.response_format` for more information. 58 | """ 59 | pybox_manager: BasePyBoxManager 60 | """A manager for spawning IPython kernel instances.""" 61 | cwd: DirectoryPath | None = None 62 | """The current working directory for the IPython kernel. 63 | If set to None, the kernel will use the default working directory. 64 | """ 65 | session_id: str | None = None 66 | """An optional session ID to persist across tool invocations. 67 | If set to None, the `pybox_manager` will spawn new kernels for each tool call. 
68 | """ 69 | filesaving_pattern: Pattern = re.compile(r'(?:\.savefig|\.to_csv)\(\s*[\'"]([^\'"]+)[\'"]\s*') 70 | """A regex pattern used to extract file saving paths from code.""" 71 | error_trace_cleanup: bool = False 72 | """Whether to cleanup the error traces before returning them to the caller.""" 73 | error_trace_cleanup_pattern: Pattern = re.compile(r"(Cell In\[\d+\], line \d+\n(?:.*\n)*?)(?=\n)") 74 | """A regex pattern used for cleaning up error traces.""" 75 | 76 | @override 77 | def _run( 78 | self, 79 | query: str, 80 | run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 81 | ) -> tuple[list[str | dict], list[Artifact]]: 82 | """Executes the given query in an IPython kernel and returns the result as content and artifacts. 83 | 84 | Args: 85 | query (str): The code to execute in the IPython kernel. 86 | run_manager (CallbackManagerForToolRun | None): A manager for tracking tool execution. 87 | 88 | Returns: 89 | tuple: A tuple containing the content (a list of strings or dictionaries) and artifacts (a list of Artifact objects). 90 | """ 91 | kwargs = {"cwd": str(self.cwd)} if self.cwd is not None else {} 92 | box = self.pybox_manager.start(kernel_id=self.session_id, **kwargs) 93 | 94 | try: 95 | res: PyBoxOut = box.run(code=query) 96 | except TimeoutError: 97 | return "Execution timed out. Please try again.", [] 98 | 99 | content = [] 100 | artifact = [] 101 | 102 | for part in res.data: 103 | # We cannot mix str with dict for now, as `langgraph.prebuilt.ToolNode.msg_content_output` will dump it to str otherwise. 104 | # So we need to specify the text parts as dict. 
105 | if (text_part := part.get("text/plain")) is not None: 106 | content.append({"type": "text", "text": text_part}) 107 | 108 | if (img_part := part.get("image/png")) is not None: 109 | content.append( 110 | { 111 | "type": "image_url", 112 | "image_url": {"url": f"data:image/png;base64,{img_part}"}, 113 | } 114 | ) 115 | 116 | for path in self._guess_artifact_paths(query): 117 | mimetype, _ = mimetypes.guess_type(path) 118 | artifact.append(Artifact(path=path, mimetype=mimetype)) 119 | 120 | if res.error is not None: 121 | cleaned_error = self._extract_error_trace(res.error) 122 | content.append({"type": "text", "text": cleaned_error}) 123 | 124 | return content, artifact 125 | 126 | @override 127 | async def _arun( 128 | self, 129 | query: str, 130 | run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 131 | ) -> tuple[list[str | dict], list[Artifact]]: 132 | """Asynchronously executes the given query in an IPython kernel and returns the result as content and artifacts. 133 | 134 | Args: 135 | query (str): The code to execute in the IPython kernel. 136 | run_manager (CallbackManagerForToolRun | None): A manager for tracking tool execution. 137 | 138 | Returns: 139 | tuple: A tuple containing the content (a list of strings or dictionaries) and artifacts (a list of Artifact objects). 140 | """ 141 | kwargs = {"cwd": str(self.cwd)} if self.cwd is not None else {} 142 | box: BasePyBox = await self.pybox_manager.start(kernel_id=self.session_id, **kwargs) 143 | 144 | try: 145 | res: PyBoxOut = await box.run(code=query) 146 | except TimeoutError: 147 | return "Execution timed out. Please try again.", [] 148 | 149 | content = [] 150 | artifact = [] 151 | 152 | for part in res.data: 153 | # We cannot mix str with dict for now, as `langgraph.prebuilt.ToolNode.msg_content_output` will dump it to str otherwise. 154 | # So we need to specify the text parts as dict. 
155 | if (text_part := part.get("text/plain")) is not None: 156 | content.append({"type": "text", "text": text_part}) 157 | 158 | if (img_part := part.get("image/png")) is not None: 159 | content.append( 160 | { 161 | "type": "image_url", 162 | "image_url": {"url": f"data:image/png;base64,{img_part}"}, 163 | } 164 | ) 165 | 166 | for path in self._guess_artifact_paths(query): 167 | mimetype, _ = mimetypes.guess_type(path) 168 | artifact.append(Artifact(path=path, mimetype=mimetype)) 169 | 170 | if res.error is not None: 171 | cleaned_error = self._extract_error_trace(res.error) 172 | content.append({"type": "text", "text": cleaned_error}) 173 | 174 | return content, artifact 175 | 176 | def _guess_artifact_paths(self, code: str) -> list[Path]: 177 | """Guess artifact paths from code. 178 | 179 | Args: 180 | code (str): Code that got executed by the tool. 181 | 182 | Returns: 183 | list[Path]: A list of existing artifact paths. 184 | """ 185 | # Use a set to deduplicate artifacts by filenames. 186 | filenames = set(re.findall(self.filesaving_pattern, code)) 187 | paths = [self.cwd.joinpath(filename) for filename in filenames] 188 | return [path for path in paths if path.exists()] 189 | 190 | def _extract_error_trace(self, e: ErrorContent) -> str: 191 | """Extract and clean the error trace if enabled. 192 | 193 | Args: 194 | e (ErrorContent): The error content returned by the IPython kernel. 195 | 196 | Returns: 197 | str: The cleaned error trace. 198 | """ 199 | if self.error_trace_cleanup and (match := re.search(self.error_trace_cleanup_pattern, str(e))) is not None: 200 | first_part = match.group(0) 201 | return f"{first_part}\n{e.ename}: {e.evalue}\n" 202 | return str(e) 203 | 204 | 205 | # We cannot merge and format the std output inside the tool, as we need the number of content parts to determine the encoder input. 206 | # Which should be refactored in the future. 
207 | # So for now we provide a helper function to merge the text parts and a template to format the std output. 208 | 209 | 210 | def process_content(content: str | list[str | dict]) -> list[dict]: 211 | """Merge text parts in the content list. 212 | 213 | As `langgraph.prebuilt.ToolNode` will dump the content list to str if it contains mixed str and dict, 214 | this function also ensures all text parts are in the form of dict with "type": "text". 215 | 216 | Args: 217 | content (str | list[str | dict]): The content to process, which can be a string or a list of strings and dictionaries. 218 | 219 | Returns: 220 | list[dict]: A list of dictionaries representing the merged content, with all text content in the form of "type": "text". 221 | """ 222 | 223 | text_parts = [] 224 | other_parts = [] 225 | 226 | if isinstance(content, str): 227 | text_parts.append(content) 228 | elif isinstance(content, list): 229 | for part in content: 230 | if isinstance(part, str): 231 | # Append string part to text_parts 232 | text_parts.append(part) 233 | elif isinstance(part, dict) and part.get("type") == "text": 234 | # Append text from dict part with "type": "text" 235 | text_parts.append(part["text"]) 236 | else: 237 | # Keep other dict part unchanged 238 | other_parts.append(part) 239 | 240 | # Create the merged "type": "text" part if there is any text to merge 241 | if text_parts: 242 | merged_element = {"type": "text", "text": "\n".join(text_parts)} 243 | return [merged_element, *other_parts] 244 | 245 | return other_parts 246 | 247 | 248 | markdown_console_template = """```pycon 249 | {res} 250 | ```""" 251 | -------------------------------------------------------------------------------- /src/tablegpt/translation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from langchain_core.output_parsers import StrOutputParser 6 | from langchain_core.prompts 
import ChatPromptTemplate 7 | 8 | if TYPE_CHECKING: 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_core.runnables import Runnable 11 | 12 | 13 | prompt_template = ChatPromptTemplate.from_messages( 14 | [ 15 | ( 16 | "system", 17 | "You are a translation assistant. Translate user input directly into the primary language of the {locale} region without explanation.", 18 | ), 19 | ("user", "{input}"), 20 | ] 21 | ) 22 | 23 | 24 | def create_translator(llm: BaseLanguageModel) -> Runnable: 25 | """return the guard chain runnable.""" 26 | return prompt_template | llm | StrOutputParser() 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/tests/__init__.py -------------------------------------------------------------------------------- /tests/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/tests/agent/__init__.py -------------------------------------------------------------------------------- /tests/agent/file_reading/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/tests/agent/file_reading/__init__.py -------------------------------------------------------------------------------- /tests/agent/test_output_parser.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | from unittest.mock import patch 4 | from uuid import uuid4 5 | 6 | from langchain_core.agents import AgentActionMessageLog, AgentFinish 7 | from langchain_core.exceptions import 
OutputParserException 8 | from langchain_core.messages import AIMessage 9 | from tablegpt.agent.output_parser import MarkdownOutputParser 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class TestMarkdownOutputParser(unittest.TestCase): 15 | @patch("tablegpt.agent.output_parser.uuid4") 16 | def test_valid_markdown_known_language_action(self, mock_uuid): 17 | fixed_uuid = uuid4() 18 | mock_uuid.return_value = fixed_uuid 19 | text = "Some text\n```python\nprint('Hello, World!')\n```More text" 20 | parser = MarkdownOutputParser(language_actions={"python": "python"}) 21 | expected_action = AgentActionMessageLog( 22 | tool="python", 23 | tool_input="print('Hello, World!')", 24 | log="Some text", 25 | message_log=[ 26 | AIMessage( 27 | id=str(fixed_uuid), 28 | content="Some text\n```python\nprint('Hello, World!')\n```", 29 | tool_calls=[ 30 | { 31 | "name": "python", 32 | "args": {"query": "print('Hello, World!')"}, 33 | "id": str(fixed_uuid), 34 | } 35 | ], 36 | additional_kwargs={ 37 | "thought": "Some text", 38 | "action": { 39 | "tool": "python", 40 | "tool_input": "print('Hello, World!')", 41 | }, 42 | }, 43 | ) 44 | ], 45 | ) 46 | result = parser.parse(text) 47 | assert result == expected_action 48 | 49 | def test_valid_markdown_unknown_language(self): 50 | text = "Some text\n```unknown\nprint('Hello, World!')\n```More text" 51 | parser = MarkdownOutputParser() 52 | with self.assertLogs("tablegpt.agent.output_parser", level="WARNING") as log: 53 | result = parser.parse(text) 54 | assert "Unknown language" in log.output[0] 55 | assert result == AgentFinish({"output": text}, text) 56 | 57 | def test_valid_markdown_no_code_block(self): 58 | text = "Some text\nWithout code block" 59 | parser = MarkdownOutputParser(just_finish=False) 60 | with self.assertRaises(OutputParserException): # noqa: PT027 61 | result = parser.parse(text) 62 | # TODO: we can mock this behaviour instead of creating a new one 63 | parser = MarkdownOutputParser() 64 | result = 
parser.parse(text) 65 | assert result == AgentFinish({"output": text}, text) 66 | 67 | @unittest.skip("This test is failing because the parser is not able to parse multiple code blocks") 68 | def test_valid_markdown_multiple_code_blocks(self): 69 | fixed_uuid = uuid4() 70 | text = "Some text\n```python\nprint('Hello, World!')\n```More text\n```java\nSystem.out.println('Hello, World!')\n```" 71 | parser = MarkdownOutputParser(language_actions={"python": "python", "java": "java"}) 72 | expected_action = AgentActionMessageLog( 73 | tool="python", 74 | tool_input="print('Hello, World!')", 75 | log="Some text", 76 | message_log=[ 77 | AIMessage( 78 | id=str(fixed_uuid), 79 | content="Some text\n```python\nprint('Hello, World!')\n```", 80 | tool_calls=[ 81 | { 82 | "name": "python", 83 | "args": {"query": "print('Hello, World!')"}, 84 | "id": str(fixed_uuid), 85 | }, 86 | { 87 | "name": "python", 88 | "args": {"query": "System.out.println('Hello, World!')"}, 89 | "id": str(fixed_uuid), 90 | }, 91 | ], 92 | additional_kwargs={ 93 | "thought": "More text", 94 | "action": { 95 | "tool": "java", 96 | "tool_input": "print('Hello, World!')", 97 | }, 98 | }, 99 | ) 100 | ], 101 | ) 102 | result = parser.parse(text) 103 | assert result == expected_action 104 | 105 | def test_empty_input(self): 106 | text = "" 107 | parser = MarkdownOutputParser(just_finish=False) 108 | with self.assertRaises(OutputParserException): # noqa: PT027 109 | result = parser.parse(text) 110 | # TODO: we can mock this behaviour instead of creating a new one 111 | parser = MarkdownOutputParser() 112 | result = parser.parse(text) 113 | assert result == AgentFinish({"output": text}, text) 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /tests/retriever/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tablegpt/tablegpt-agent/6195b6437b01a3539f70af1655afc1e2e6012958/tests/retriever/__init__.py -------------------------------------------------------------------------------- /tests/retriever/test_compressor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from langchain_core.documents import Document 4 | from tablegpt.retriever.compressor import ColumnDocCompressor 5 | 6 | 7 | class TestCompressDocuments(unittest.TestCase): 8 | def setUp(self): 9 | self.processor = ColumnDocCompressor() 10 | 11 | def test_single_column_single_file(self): 12 | documents = [ 13 | Document( 14 | page_content="cell content", 15 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, 16 | ), 17 | Document( 18 | page_content="cell content", 19 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 2}, 20 | ), 21 | ] 22 | 23 | expected_output = [ 24 | Document( 25 | page_content="column: A", 26 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1, 2]}, 27 | ) 28 | ] 29 | 30 | result = self.processor.compress_documents(documents, query="") 31 | assert result == expected_output 32 | 33 | def test_multiple_columns_single_file(self): 34 | documents = [ 35 | Document( 36 | page_content="A:1", 37 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, 38 | ), 39 | Document( 40 | page_content="B:hello", 41 | metadata={"filename": "file1", "column": "B", "dtype": "str", "n_unique": 3, "value": "hello"}, 42 | ), 43 | ] 44 | 45 | expected_output = [ 46 | Document( 47 | page_content="column: A", 48 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1]}, 49 | ), 50 | Document( 51 | page_content="column: B", 52 | metadata={"filename": "file1", "column": "B", "dtype": "str", "n_unique": 3, "values": ["hello"]}, 53 | ), 54 | ]
55 | 56 | result = self.processor.compress_documents(documents, query="") 57 | assert result == expected_output 58 | 59 | def test_multiple_columns_multiple_files(self): 60 | documents = [ 61 | Document( 62 | page_content="cell content", 63 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "value": 1}, 64 | ), 65 | Document( 66 | page_content="cell content", 67 | metadata={"filename": "file2", "column": "A", "dtype": "int", "n_unique": 4, "value": 2}, 68 | ), 69 | Document( 70 | page_content="cell content", 71 | metadata={"filename": "file2", "column": "B", "dtype": "str", "n_unique": 3, "value": "world"}, 72 | ), 73 | ] 74 | 75 | expected_output = [ 76 | Document( 77 | page_content="column: A", 78 | metadata={"filename": "file1", "column": "A", "dtype": "int", "n_unique": 5, "values": [1]}, 79 | ), 80 | Document( 81 | page_content="column: A", 82 | metadata={"filename": "file2", "column": "A", "dtype": "int", "n_unique": 4, "values": [2]}, 83 | ), 84 | Document( 85 | page_content="column: B", 86 | metadata={"filename": "file2", "column": "B", "dtype": "str", "n_unique": 3, "values": ["world"]}, 87 | ), 88 | ] 89 | 90 | result = self.processor.compress_documents(documents, query="") 91 | assert result == expected_output 92 | 93 | def test_empty_input(self): 94 | documents = [] 95 | expected_output = [] 96 | result = self.processor.compress_documents(documents, query="") 97 | assert result == expected_output 98 | 99 | 100 | if __name__ == "__main__": 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /tests/retriever/test_format.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from langchain_core.documents import Document 4 | from tablegpt.retriever import format_columns 5 | 6 | 7 | class TestFormatColumns(unittest.TestCase): 8 | def test_format_empty_column_docs(self): 9 | formated_columns = format_columns([]) 10 |
assert formated_columns == "" 11 | 12 | def test_format_column_docs(self): 13 | docs = [ 14 | Document( 15 | page_content="column:Sex", 16 | metadata={ 17 | "filename": "foo.csv", 18 | "column": "Sex", 19 | "dtype": "string", 20 | "n_unique": 2, 21 | "values": ["male", "female"], 22 | }, 23 | ) 24 | ] 25 | formated_columns = format_columns(docs) 26 | hint = """ 27 | Here are some extra column information that might help you understand the dataset: 28 | - foo.csv: 29 | - {"column": Sex, "dtype": "string", "values": ["male", "female"]} 30 | """ 31 | assert formated_columns == hint 32 | 33 | def test_format_and_compress_column(self): 34 | docs = [ 35 | Document( 36 | page_content="column:Sex", 37 | metadata={ 38 | "filename": "foo.csv", 39 | "column": "Sex", 40 | "dtype": "string", 41 | "n_unique": 3, 42 | "values": ["male", "female", "unknown"], 43 | }, 44 | ) 45 | ] 46 | hint = """ 47 | Here are some extra column information that might help you understand the dataset: 48 | - foo.csv: 49 | - {"column": Sex, "dtype": "string", "values": ["mal...", "fem...", ...]} 50 | """ 51 | formated_columns = format_columns(docs, dataset_cell_length_threshold=3, max_dataset_cells=2) 52 | assert formated_columns == hint 53 | 54 | 55 | if __name__ == "__main__": 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /tests/retriever/test_loader.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | from langchain_core.documents import Document 5 | from pandas import DataFrame, Series 6 | from tablegpt.retriever.loader import CSVLoader 7 | 8 | 9 | @pytest.fixture 10 | def mock_df(): 11 | """Fixture to provide a mocked DataFrame.""" 12 | return DataFrame({"column1": ["value1", "value2", "value3"], "column2": ["A", "B", "C"], "column3": [1, 2, 3]}) 13 | 14 | 15 | @pytest.fixture 16 | def loader(): 17 | """Fixture to provide a CSVLoader
instance.""" 18 | return CSVLoader(file_path="test.csv", extra_metadata={"source": "test_source"}, autodetect_encoding=True) 19 | 20 | 21 | def test_initialization(loader): 22 | assert loader.file_path == "test.csv" 23 | assert loader.extra_metadata == {"source": "test_source", "filename": "test.csv"} 24 | assert loader.autodetect_encoding 25 | 26 | 27 | def test_lazy_load(loader, mock_df): 28 | with ( 29 | patch("tablegpt.retriever.loader.read_df", return_value=mock_df), 30 | patch.object( 31 | loader, 32 | "column2docs", 33 | return_value=iter( 34 | [ 35 | Document( 36 | page_content="column1:value1", 37 | metadata={"column": "column1", "dtype": "string", "value": "value1"}, 38 | ), 39 | Document( 40 | page_content="column1:value2", 41 | metadata={"column": "column1", "dtype": "string", "value": "value2"}, 42 | ), 43 | ] 44 | ), 45 | ), 46 | ): 47 | documents = list(loader.lazy_load()) 48 | assert len(documents) == 2 49 | assert documents[0].page_content == "column1:value1" 50 | assert documents[1].page_content == "column1:value2" 51 | 52 | 53 | def test_lazy_load_with_missing_metadata(mock_df): 54 | loader = CSVLoader(file_path="test.csv", autodetect_encoding=True) 55 | with ( 56 | patch("tablegpt.retriever.loader.read_df", return_value=mock_df), 57 | patch.object( 58 | loader, 59 | "column2docs", 60 | return_value=iter( 61 | [ 62 | Document( 63 | page_content="column1:value1", 64 | metadata={"column": "column1", "dtype": "string", "value": "value1"}, 65 | ), 66 | Document( 67 | page_content="column1:value2", 68 | metadata={"column": "column1", "dtype": "string", "value": "value2"}, 69 | ), 70 | ] 71 | ), 72 | ), 73 | ): 74 | documents = list(loader.lazy_load()) 75 | assert len(documents) == 2 76 | 77 | 78 | def test_column2docs(loader, mock_df): 79 | column = Series(["value1", "value2", "value3"], name="column1") 80 | with patch("tablegpt.retriever.loader.read_df", return_value=mock_df): 81 | documents = list(loader.column2docs(column)) 82 | assert
len(documents) == 3 83 | assert documents[0].page_content == "column1:value1" 84 | assert documents[0].metadata["column"] == "column1" 85 | assert documents[0].metadata["value"] == "value1" 86 | 87 | 88 | def test_empty_csv(loader): 89 | empty_df = DataFrame() 90 | with patch("tablegpt.retriever.loader.read_df", return_value=empty_df): 91 | documents = list(loader.lazy_load()) 92 | assert documents == [] 93 | 94 | 95 | def test_csv_with_non_string_column(loader): 96 | df = DataFrame({"column1": [1, 2, 3], "column2": ["A", "B", "C"]}) 97 | with patch("tablegpt.retriever.loader.read_df", return_value=df): 98 | documents = list(loader.lazy_load()) 99 | assert len(documents) == 3 100 | assert documents[0].page_content == "column2:A" 101 | assert documents[1].page_content == "column2:B" 102 | assert documents[2].page_content == "column2:C" 103 | -------------------------------------------------------------------------------- /tests/test_profile_init.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | from unittest.mock import MagicMock, patch 4 | 5 | 6 | class TestTableGPTInit(unittest.TestCase): 7 | def setUp(self): 8 | # Save the original tablegpt module if it exists 9 | self.original_tablegpt = sys.modules.get("tablegpt") 10 | 11 | # Clear tablegpt from sys.modules so the next import re-runs tablegpt's top-level code 12 | if "tablegpt" in sys.modules: 13 | del sys.modules["tablegpt"] 14 | 15 | # Replace the sysconfig module with a MagicMock so tests can stub sysconfig.get_path() 16 | self.mock_sysconfig = MagicMock() 17 | self.original_sysconfig = sys.modules["sysconfig"]  # keep the real module so tearDown can restore it 18 | sys.modules["sysconfig"] = self.mock_sysconfig 19 | 20 | def tearDown(self): 21 | # Restore the real sysconfig module 22 | sys.modules["sysconfig"] = self.original_sysconfig 23 | 24 | # Restore the original tablegpt module if it existed 25 | if self.original_tablegpt:  # None when tablegpt was not imported before setUp 26 | sys.modules["tablegpt"] = self.original_tablegpt 27 | elif "tablegpt" in sys.modules: 28 | del sys.modules["tablegpt"] 29 | 30 | def
test_find_tablegpt_ipykernel_profile_dir_found(self): 31 | # make the mocked sysconfig.get_path() return this prefix 32 | self.mock_sysconfig.get_path.return_value = "/usr/local" 33 | 34 | with patch("pathlib.Path.glob", return_value=iter(["mock-udfs.py"])):  # Path.glob is patched to yield a single file name, so the profile dir is expected to be found 35 | from tablegpt import DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR 36 | 37 | assert DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR == "/usr/local/share/ipykernel/profile/tablegpt" 38 | 39 | def test_default_tablegpt_ipykernel_profile_dir_not_found(self): 40 | # make the mocked sysconfig.get_path() return this path 41 | self.mock_sysconfig.get_path.return_value = "/wrong/lib/python3.x/site-packages" 42 | 43 | # not found: Path.glob yields nothing, and importing tablegpt is expected to emit a UserWarning 44 | with patch("pathlib.Path.glob", return_value=iter([])), self.assertWarns(UserWarning): 45 | from tablegpt import DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR 46 | 47 | assert DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR is None 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /tests/test_safety.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tablegpt.safety import HazardOutputParser 4 | 5 | 6 | class TestHazardOutputParser(unittest.TestCase): 7 | def setUp(self): 8 | self.parser = HazardOutputParser() 9 | 10 | def test_parse_safe(self): 11 | result = self.parser.parse("\n\nsafe") 12 | assert result == ("safe", None) 13 | 14 | def test_parse_safe_with_spaces(self): 15 | result = self.parser.parse("\n\n safe ") 16 | assert result == ("safe", None) 17 | 18 | def test_parse_unknown(self): 19 | result = self.parser.parse("unrecognized input") 20 | assert result == ("unknown", None) 21 | 22 | def test_parse_unsafe_text_with_category(self): 23 | text = "unsafe\nS1" 24 | result = self.parser.parse(text) 25 | assert result == ("unsafe", "S1") 26 | 27 | def test_parse_unsafe_text_with_invalid_format(self): 28 | text = "unsafe only one line"  # 'unsafe' without a newline-separated category line 29 | result = self.parser.parse(text) 30 | assert result == ("unknown", None) 31 | 32
| 33 | if __name__ == "__main__": 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /tests/test_tools.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from tablegpt.tools import process_content 4 | 5 | 6 | class TestProcessContent(unittest.TestCase): 7 | def test_single_string(self): 8 | content = "Hello" 9 | expected_output = [{"type": "text", "text": "Hello"}] 10 | assert process_content(content) == expected_output 11 | 12 | def test_list_of_strings(self): 13 | content = ["Hello", "World"] 14 | expected_output = [{"type": "text", "text": "Hello\nWorld"}] 15 | assert process_content(content) == expected_output 16 | 17 | def test_list_of_mixed_strings_and_dicts(self): 18 | content = [ 19 | "Hello", 20 | {"type": "text", "text": "World"}, 21 | {"type": "image", "url": "image.png"}, 22 | ] 23 | expected_output = [ 24 | {"type": "text", "text": "Hello\nWorld"}, 25 | {"type": "image", "url": "image.png"}, 26 | ] 27 | assert process_content(content) == expected_output 28 | 29 | def test_list_of_only_dicts(self): 30 | content = [ 31 | {"type": "image", "url": "image.png"}, 32 | {"type": "video", "url": "video.mp4"}, 33 | ] 34 | expected_output = [ 35 | {"type": "image", "url": "image.png"}, 36 | {"type": "video", "url": "video.mp4"}, 37 | ] 38 | assert process_content(content) == expected_output 39 | 40 | def test_empty_string(self): 41 | content = "" 42 | expected_output = [{"type": "text", "text": ""}] 43 | assert process_content(content) == expected_output 44 | 45 | def test_empty_list(self): 46 | content = [] 47 | expected_output = [] 48 | assert process_content(content) == expected_output 49 | 50 | def test_list_with_empty_string(self): 51 | content = ["", {"type": "image", "url": "image.png"}] 52 | expected_output = [ 53 | {"type": "text", "text": ""}, 54 | {"type": "image", "url": "image.png"}, 55 | ] 56 | assert process_content(content) ==
expected_output 57 | 58 | def test_text_in_dict(self): 59 | content = [{"type": "text", "text": "Hello"}] 60 | expected_output = [{"type": "text", "text": "Hello"}] 61 | assert process_content(content) == expected_output 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | 4 | from langchain_core.messages import BaseMessage 5 | from tablegpt.utils import ( 6 | filter_content, 7 | path_from_uri, 8 | ) 9 | 10 | 11 | class TestPathFromUri(unittest.TestCase): 12 | @unittest.skip("Cannot test linux path on windows and vice versa") 13 | def test_valid_file_uri_unix(self): 14 | """Test a valid 'file:' URI on a Unix system.""" 15 | uri = "file:///home/user/file.txt" 16 | expected_path = Path("/home/user/file.txt") 17 | assert path_from_uri(uri) == expected_path 18 | 19 | @unittest.skip("Cannot test linux path on windows and vice versa") 20 | def test_valid_file_uri_windows(self): 21 | """Test a valid 'file:' URI on a Windows system.""" 22 | uri = "file:///C:/Users/user/file.txt" 23 | expected_path = Path("C:/Users/user/file.txt") 24 | assert path_from_uri(uri) == expected_path 25 | 26 | @unittest.skip("Cannot test linux path on windows and vice versa") 27 | def test_valid_file_uri_unc_path(self): 28 | """Test a valid 'file:' URI with a UNC path.""" 29 | uri = "file://localhost/Server/Share/file.txt" 30 | expected_path = Path("/Server/Share/file.txt") 31 | assert path_from_uri(uri) == expected_path 32 | 33 | def test_invalid_file_uri(self): 34 | """Test an invalid 'file:' URI that does not start with 'file:'.""" 35 | uri = "http://example.com/file.txt" 36 | with self.assertRaises(ValueError) as cm: # noqa: PT027 37 | path_from_uri(uri) 38 | assert str(cm.exception) == f"URI does not start with 'file:': '{uri}'" 39 | 40 | def
test_relative_file_uri(self): 41 | """Test an invalid 'file:' URI that is not absolute.""" 42 | uri = "file:relative/path/file.txt" 43 | with self.assertRaises(ValueError) as cm: # noqa: PT027 44 | path_from_uri(uri) 45 | assert str(cm.exception) == f"URI is not absolute: '{uri}'" 46 | 47 | @unittest.skip("Cannot test linux path on windows and vice versa") 48 | def test_invalid_dos_drive(self): 49 | """Test that a 'file:' URI using the legacy 'C|' drive form is not resolved to 'C:/path/to/file.txt'.""" 50 | uri = "file://C|/path/to/file.txt" 51 | expected_path = Path("C:/path/to/file.txt") 52 | assert path_from_uri(uri) != expected_path 53 | 54 | @unittest.skip("Cannot test linux path on windows and vice versa") 55 | def test_valid_file_uri_with_encoded_characters(self): 56 | """Test a valid 'file:' URI with encoded characters.""" 57 | uri = "file:///home/user/file%20name.txt" 58 | expected_path = Path("/home/user/file name.txt") 59 | assert path_from_uri(uri) == expected_path 60 | 61 | 62 | class TestFilterContent(unittest.TestCase):  # these tests expect str items and "text" parts to be kept by default; keep= widens the kept types 63 | def test_filter_content_with_string_content(self): 64 | message = BaseMessage(content="Hello, World!", type="ai") 65 | result = filter_content(message) 66 | assert result.content == "Hello, World!"
67 | 68 | def test_filter_content_with_list_of_strings(self): 69 | message = BaseMessage(content=["Hello", "World"], type="ai") 70 | result = filter_content(message) 71 | assert result.content == ["Hello", "World"] 72 | 73 | def test_filter_content_with_list_of_dicts(self): 74 | message = BaseMessage( 75 | content=[ 76 | {"type": "text", "text": "Hello"}, 77 | {"type": "image_url", "image_url": "http://example.com/image.jpg"}, 78 | ], 79 | type="ai", 80 | ) 81 | result = filter_content(message) 82 | assert result.content == [{"type": "text", "text": "Hello"}] 83 | 84 | def test_filter_content_with_custom_keep(self): 85 | message = BaseMessage( 86 | content=[ 87 | {"type": "text", "text": "Hello"}, 88 | {"type": "image_url", "image_url": "http://example.com/image.jpg"}, 89 | ], 90 | type="ai", 91 | ) 92 | result = filter_content(message, keep=["image_url", "text"]) 93 | assert result.content == [ 94 | {"type": "text", "text": "Hello"}, 95 | {"type": "image_url", "image_url": "http://example.com/image.jpg"}, 96 | ] 97 | 98 | def test_filter_content_with_mixed_content(self): 99 | message = BaseMessage( 100 | content=[ 101 | "Hello", 102 | {"type": "text", "text": "World"}, 103 | {"type": "image_url", "image_url": "http://example.com/image.jpg"}, 104 | ], 105 | type="ai", 106 | ) 107 | result = filter_content(message) 108 | assert result.content == ["Hello", {"type": "text", "text": "World"}] 109 | 110 | def test_filter_content_with_no_text_type(self): 111 | message = BaseMessage( 112 | content=[ 113 | {"type": "image_url", "image_url": "http://example.com/image.jpg"}, 114 | ], 115 | type="ai", 116 | ) 117 | result = filter_content(message) 118 | assert result.content == [] 119 | 120 | 121 | if __name__ == "__main__": 122 | unittest.main() 123 | --------------------------------------------------------------------------------