├── .env.example ├── .flake8 ├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── docs ├── data.md └── images │ └── framework.png ├── examples ├── agent │ ├── __init__.py │ ├── chat2vis.py │ ├── coml4vis.py │ ├── lida.py │ └── utils.py ├── evaluate.py └── model │ └── langchain_llama.py ├── pyproject.toml ├── tests ├── assets │ ├── bar.svg │ ├── bar_1.svg │ ├── bar_1000_0.svg │ ├── bar_1071_0.svg │ ├── bar_1129_0.svg │ ├── bar_1129_1.svg │ ├── bar_1129_2.svg │ ├── bar_186.svg │ ├── bar_189.svg │ ├── bar_2.svg │ ├── bar_205.svg │ ├── bar_2060_0.svg │ ├── bar_2733_0.svg │ ├── bar_3.svg │ ├── bar_3137_1.svg │ ├── bar_3269.svg │ ├── bar_4.svg │ ├── bar_465.svg │ ├── bar_5.svg │ ├── bar_550_0.svg │ ├── bar_6.svg │ ├── bar_68_0.svg │ ├── bar_9.svg │ ├── bar_error_3227_0.svg │ ├── dual_1499_0.svg │ ├── empty.svg │ ├── empty_1.svg │ ├── empty_2.svg │ ├── empty_3.svg │ ├── empty_4.svg │ ├── grouping_bar.svg │ ├── grouping_line.svg │ ├── grouping_line_1.svg │ ├── grouping_line_2.svg │ ├── grouping_line_2781_0.svg │ ├── grouping_line_3.svg │ ├── grouping_line_4.svg │ ├── grouping_scatter.svg │ ├── grouping_scatter_3272_0.svg │ ├── horizontal_stacked_bar.svg │ ├── horizontal_stacked_bar_1.svg │ ├── line.svg │ ├── line_1.svg │ ├── line_10.svg │ ├── line_11.svg │ ├── line_12.svg │ ├── line_13.svg │ ├── line_14.svg │ ├── line_15.svg │ ├── line_1746_1.svg │ ├── line_2.svg │ ├── line_3.svg │ ├── line_3240_0.svg │ ├── line_4.svg │ ├── line_5.svg │ ├── line_6.svg │ ├── line_7.svg │ ├── line_773_0.svg │ ├── line_8.svg │ ├── line_9.svg │ ├── pie.svg │ ├── pie_1.svg │ ├── pie_2.svg │ ├── pie_3.svg │ ├── pie_4.svg │ ├── pie_4_0.svg │ ├── pie_5.svg │ ├── readability │ │ ├── 1008.svg │ │ ├── 1024.svg │ │ ├── 1071.svg │ │ ├── 1188.svg │ │ ├── 1237@y_name@ASC.svg │ │ ├── 1285.svg │ │ ├── 131.svg │ │ ├── 1314.svg │ │ ├── 134@x_name@DESC.svg │ │ ├── 1363.svg │ 
│ ├── 1380@y_name@ASC.svg │ │ ├── 1392.svg │ │ ├── 1434.svg │ │ ├── 145@y_name@DESC.svg │ │ ├── 1491@y_name@DESC.svg │ │ ├── 1517.svg │ │ ├── 1530.svg │ │ ├── 1534.svg │ │ ├── 1630.svg │ │ ├── 173.svg │ │ ├── 17@y_name@DESC.svg │ │ ├── 1961.svg │ │ ├── 1974.svg │ │ ├── 1992.svg │ │ ├── 2013.svg │ │ ├── 2019.svg │ │ ├── 2024@y_name@DESC.svg │ │ ├── 2027.svg │ │ ├── 2110.svg │ │ ├── 2174.svg │ │ ├── 219.svg │ │ ├── 2303.svg │ │ ├── 2350.svg │ │ ├── 2416.svg │ │ ├── 246.svg │ │ ├── 2526.svg │ │ ├── 2571.svg │ │ ├── 2575@y_name@ASC.svg │ │ ├── 2615.svg │ │ ├── 2652.svg │ │ ├── 267.svg │ │ ├── 2724.svg │ │ ├── 2756@x_name@DESC.svg │ │ ├── 2760@y_name@DESC.svg │ │ ├── 2765@y_name@ASC.svg │ │ ├── 279@y_name@DESC.svg │ │ ├── 2815@x_name@DESC.svg │ │ ├── 2841@y_name@ASC.svg │ │ ├── 286.svg │ │ ├── 2941.svg │ │ ├── 2943@y_name@ASC.svg │ │ ├── 296.svg │ │ ├── 3014@y_name@ASC.svg │ │ ├── 3069@x_name@ASC.svg │ │ ├── 312.svg │ │ ├── 3134@y_name@ASC.svg │ │ ├── 316.svg │ │ ├── 32.svg │ │ ├── 3207.svg │ │ ├── 33.svg │ │ ├── 338.svg │ │ ├── 342.svg │ │ ├── 368.svg │ │ ├── 387.svg │ │ ├── 399.svg │ │ ├── 4.svg │ │ ├── 403.svg │ │ ├── 432.svg │ │ ├── 461.svg │ │ ├── 464.svg │ │ ├── 470.svg │ │ ├── 487.svg │ │ ├── 513.svg │ │ ├── 517.svg │ │ ├── 533@y_name@DESC.svg │ │ ├── 555@y_name@DESC.svg │ │ ├── 567.svg │ │ ├── 58@y_name@ASC.svg │ │ ├── 611.svg │ │ ├── 616.svg │ │ ├── 622.svg │ │ ├── 65.svg │ │ ├── 659.svg │ │ ├── 676.svg │ │ ├── 69.svg │ │ ├── 693.svg │ │ ├── 708.svg │ │ ├── 718.svg │ │ ├── 72.svg │ │ ├── 744.svg │ │ ├── 75@y_name@ASC.svg │ │ ├── 789.svg │ │ ├── 803.svg │ │ ├── 81.svg │ │ ├── 847@y_name@DESC.svg │ │ ├── 856.svg │ │ ├── 880.svg │ │ ├── 9.svg │ │ ├── 914.svg │ │ ├── 961.svg │ │ └── readability_human_rating.csv │ ├── samples.json │ ├── scatter.svg │ ├── scatter_1.svg │ ├── scatter_2.svg │ ├── scatter_292.svg │ ├── scatter_3.svg │ ├── scatter_4.svg │ ├── scatter_400_0.svg │ ├── scatter_400_1.svg │ ├── scatter_5.svg │ ├── scatter_6.svg │ ├── scatter_62.svg │ ├── 
scatter_674.svg │ ├── scatter_7.svg │ ├── stacked_bar.svg │ ├── stacked_bar_1.svg │ ├── stacked_bar_2.svg │ ├── stacked_bar_2750_0.svg │ ├── stacked_bar_2750_1.svg │ ├── stacked_bar_2815.svg │ ├── stacked_bar_2815_0.svg │ ├── stacked_bar_3.svg │ ├── stacked_bar_4.svg │ └── stacked_bar_680_0.svg ├── test_chart_check.py ├── test_data_check.py ├── test_deconstruct.py ├── test_layout_check.py ├── test_order_check.py └── test_surface_form_check.py ├── viseval ├── __init__.py ├── agent.py ├── check │ ├── __init__.py │ ├── chart_check.py │ ├── data_check.py │ ├── deconstruct.py │ ├── layout_check.py │ ├── order_check.py │ ├── readability_check.py │ ├── scale_and_ticks_check.py │ ├── surface_form_check.py │ └── time_utils.py ├── dataset.py └── evaluate.py └── viseval_dataset.zip /.env.example: -------------------------------------------------------------------------------- 1 | # Azure OpenAI 2 | AZURE_OPENAI_ENDPOINT=https://xxx.azure.com/ 3 | OPENAI_API_KEY=XXXXXXXXX 4 | OPENAI_API_VERSION=XXXXXXXXX 5 | 6 | # Google API 7 | GOOGLE_API_KEY=XXXXXXXXX 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = "E303, W293, W291, W292, E305, E231, E302" 4 | exclude = 5 | .tox, 6 | __pycache__, 7 | *.pyc, 8 | .env 9 | venv*/*, 10 | .venv/*, 11 | reports/*, 12 | dist/*, 13 | node_modules/*, 14 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]
  schedule:
    - cron: '22 2 * * 4'

jobs:
  analyze:
    name: Analyze (${{ matrix.language }})
    # Runner size impacts CodeQL analysis time. To learn more, please see:
    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
    #   - https://gh.io/supported-runners-and-hardware-resources
    #   - https://gh.io/using-larger-runners (GitHub.com only)
    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
    timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
    permissions:
      # required for all workflows
      security-events: write

      # required to fetch internal or private CodeQL packs
      packages: read

      # only required for workflows in private repositories
      actions: read
      contents: read

    strategy:
      fail-fast: false
      matrix:
        include:
        # FIX: 'build-mode' must be a key of the same include entry as
        # 'language' (no leading dash). With the dash it became a second,
        # language-less matrix entry, and the python entry had no build mode.
        - language: python
          build-mode: none
        # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
        # Use `c-cpp` to analyze code written in C, C++ or both
        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
    steps:
    - name: Checkout repository
      uses: actions/checkout@v4

    # Initializes the CodeQL tools for scanning.
    - name: Initialize CodeQL
      uses: github/codeql-action/init@v3
      with:
        languages: ${{ matrix.language }}
        build-mode: ${{ matrix.build-mode }}
        # If you wish to specify custom queries, you can do so here or in a config file.
        # By default, queries listed here will override any specified in a config file.
        # Prefix the list here with "+" to use these queries and those in the config file.

        # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
        # queries: security-extended,security-and-quality

    # If the analyze step fails for one of the languages you are analyzing with
    # "We were unable to automatically build your code", modify the matrix above
    # to set the build mode to "manual" for that language. Then modify this step
    # to build your code.
    # ℹ️ Command-line programs to run using the OS shell.
79 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 80 | - if: matrix.build-mode == 'manual' 81 | run: | 82 | echo 'If you are using a "manual" build mode for one or more of the' \ 83 | 'languages you are analyzing, replace this with the commands to build' \ 84 | 'your code, for example:' 85 | echo ' make bootstrap' 86 | echo ' make release' 87 | exit 1 88 | 89 | - name: Perform CodeQL Analysis 90 | uses: github/codeql-action/analyze@v3 91 | with: 92 | category: "/language:${{matrix.language}}" 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Customized 2 | logs 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | 143 | # security 144 | .secrets/* 145 | *cache* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-byte-order-marker 6 | - id: check-case-conflict 7 | - id: check-merge-conflict 8 | - id: check-symlinks 9 | - id: debug-statements 10 | 11 | - repo: https://github.com/pycqa/isort 12 | rev: 5.12.0 13 | hooks: 14 | - id: isort 15 | types: [python] 16 | args: ["--profile=black"] 17 | 18 | - repo: https://github.com/psf/black 19 | rev: 23.3.0 20 | hooks: 21 | - id: black 22 | types: [python] 23 | 24 | - repo: local 25 | hooks: 26 | - id: run-pytest 27 | name: Run pytest 28 | entry: pytest 29 | language: system 30 | types: [python] 31 | pass_filenames: false 32 | always_run: true -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisEval: A NL2VIS Benchmark 2 | VisEval is a benchmark designed to evaluate visualization generation methods. 3 | In this repository, we provide both the toolkit to support the benchmarking, as well as the data used for benchmarks. 4 | 5 | ## What Can VisEval Evaluate 6 | 7 | ![The pipeline of VisEval includes three key modules: the validity checker, the legality checker, and the readability evaluator.](docs/images/framework.png) 8 | 9 | VisEval evaluates generated visualizations from three dimensions: 10 | 1. Whether the generated code can produce the visualization. 11 | 2. Whether the generated visualization meets the query. 12 | 3. Whether the generated visualization is easy to read. 
13 | 14 | 15 | ## Get Started 16 | ### Install Benchmark Toolkit 17 | 18 | ```bash 19 | pip install --upgrade vis-evaluator 20 | # or `git clone https://github.com/microsoft/VisEval.git && cd VisEval && pip install --upgrade -e .` 21 | ``` 22 | 23 | ### Download Benchmark Dataset 24 | To access the dataset, please follow these steps: 25 | 26 | 1. Download the dataset from [this link](https://github.com/microsoft/VisEval/blob/main/viseval_dataset.zip). 27 | 2. Once the download is complete, unzip the file to extract the dataset contents. 28 | 29 | For additional information about the dataset, please refer to the [dataset documentation](docs/data.md). 30 | 31 | ### Usage & Examples 32 | After installation, you can use VisEval by referring to `examples/evaluate.py` or a follow: 33 | 34 | 35 | 1. **Create your generation method** by inheriting from the `Agent` Class. You can find three examples in the `examples/agent` directory. 36 | ```python 37 | from viseval.agent import Agent, ChartExecutionResult 38 | 39 | class YourAgent(Agent): 40 | def __init__(self, llm): 41 | self.llm = llm 42 | 43 | def generate( 44 | self, nl_query: str, tables: list[str], config: dict 45 | ) -> Tuple[str, dict]: 46 | """Generate code for the given natural language query.""" 47 | pass 48 | 49 | def execute( 50 | self, code: str, context: dict, log_name: str = None 51 | ) -> ChartExecutionResult: 52 | """Execute the given code with context and return the result""" 53 | pass 54 | ``` 55 | 56 | 2. **Configure evaluator**. 57 | ```python 58 | evaluator = Evaluator(webdriver_path, vision_model) 59 | ``` 60 | (You can configure the Evaluator without a webdriver and vision model, in which case the evaluation of the readability of the generated visualizations will be skipped.) 61 | 62 | - Install webdriver. 
63 | ```bash 64 | # download 65 | wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb 66 | # install 67 | apt install google-chrome-stable_current_amd64.deb 68 | # verify 69 | google-chrome --version 70 | ``` 71 | 72 | - Load vision model (e.g., GPT4-v). 73 | ```python 74 | from langchain_openai import AzureChatOpenAI 75 | 76 | import dotenv 77 | # Copy .env.example to .env and put your API keys in the file. 78 | dotenv.load_dotenv() 79 | 80 | vision_model = AzureChatOpenAI( 81 | model_name="gpt-4-turbo-v", 82 | max_retries=999, 83 | temperature=0.0, 84 | request_timeout=20, 85 | max_tokens=4096, 86 | ) 87 | ``` 88 | 89 | 3. **Evaluate** 90 | ```python 91 | from viseval import Dataset 92 | 93 | # Configure dataset with the benchmark dataset folder path ( folder), 94 | # specify the number of tables required to generate visualizations (table_type`: all, single, or multiple), 95 | # and indicate whether to include irrelevant tables (`with_irrelevant_tables`). 96 | dataset = Dataset(folder, table_type, with_irrelevant_tables) 97 | 98 | config = {"library": args.library} 99 | result = evaluator.evaluate(agent, dataset, config) 100 | score = result.score() 101 | print(f"Score: {score}") 102 | ``` 103 | 104 | 105 | ## Contributing 106 | 107 | This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 108 | 109 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. 
110 | 111 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 112 | 113 | ## Trademarks 114 | 115 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow 116 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 117 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 118 | Any use of third-party trademarks or logos is subject to those third parties' policies. 119 | 120 | ## Privacy Statement 121 | 122 | This project has adopted the [Microsoft Privacy Statement](https://go.microsoft.com/fwlink/?LinkId=521839).
123 | 124 | ## Citation 125 | 126 | If you find that VisEval helps your research, please consider citing it: 127 | ``` 128 | @misc{chen2024viseval, 129 | title={VisEval: A Benchmark for Data Visualization in the Era of Large Language Models}, 130 | author={Nan Chen and Yuge Zhang and Jiahang Xu and Kan Ren and Yuqing Yang}, 131 | year={2024}, 132 | eprint={2407.00981}, 133 | archivePrefix={arXiv}, 134 | primaryClass={cs.HC}, 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue. 6 | 7 | For help and questions about using this project, please consult the project readme or open an issue. 
8 | 9 | ## Microsoft Support Policy 10 | 11 | Support for this project is limited to the resources listed above. 12 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # VisEval Dataset 2 | 3 | ## Introduction 4 | 5 | VisEval dataset is a large-scale and high-quality dataset for Nature Language to Visualization (NL2VI) task. It contains 2,524 (NL, VIS) pairs, supports 7 common types of visualizations and covers 146 databases. 6 | 7 | - **VisEval.json** stores the JSON format of (NL, VIS) pairs. The natural language (NL) is a sentence that describes the desired visualization. The visualization (VIS) is represented in JSON format, including chart type, data for the x-axis, y-axis, and z-axis, as well as information such as sorting requirements. 8 | - **VisEval_single.json** and **VisEval_multiple.json** store the visualizations that can be generated from a single data table and those that require processing multiple data tables, respectively. 9 | - **databases** contains 146 databases, and each database has several data tables saved in CSV format. 10 | 11 | 12 | ## Important Notes / Caveats / FAQs 13 | - The primary objective of this dataset is to serve as a benchmark for evaluating LLMs-based methods in natural language to visualization generation. This dataset is intended for research purposes only and should not be relied upon as the sole benchmark for production scenarios. 14 | - How was the data collected? The dataset was constructed based on [nvBench](https://github.com/TsinghuaDatabaseGroup/nvBench). We selected high-quality queries from the original dataset, and we corrected and annotated the dataset ourselves. More details are provided in our paper's subsection 4.1. 
-------------------------------------------------------------------------------- /docs/images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/VisEval/619155231c476aaa05f1b3a5b6d79082a6bcf782/docs/images/framework.png -------------------------------------------------------------------------------- /examples/agent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .chat2vis import Chat2vis 5 | from .coml4vis import CoML4VIS 6 | from .lida import Lida 7 | -------------------------------------------------------------------------------- /examples/agent/chat2vis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import sys 5 | import warnings 6 | 7 | import pandas as pd 8 | from langchain.chat_models.base import BaseChatModel 9 | from langchain.schema import HumanMessage 10 | 11 | from viseval.agent import Agent, ChartExecutionResult 12 | 13 | from .utils import show_svg 14 | 15 | MAXIMUM_SAMPLES = 10 16 | 17 | TEMPLATE = '''"""Use a dataframe called df from data_file.csv with columns {columns}. 18 | 19 | {columns_description} 20 | 21 | Label the x and y axes appropriately. Add a title. Set the fig suptitle as empty. 22 | 23 | Using Python version 3.9.12 and library {library}, create a script using the dataframe df to graph the following: {nl_query}. 
class Chat2vis(Agent):
    """Chat2VIS-style agent: prompts an LLM with a textual description of the
    table's columns, then executes the generated matplotlib/seaborn code."""

    def __init__(self, llm: BaseChatModel):
        # Chat model used to turn a natural-language query into plotting code.
        self.llm = llm

    def table_format(self, data: pd.DataFrame) -> str:
        """Return a one-sentence-per-column natural-language description of *data*.

        Columns are classified as numeric, boolean, date, category or string;
        date/category columns additionally list up to MAXIMUM_SAMPLES sample
        values so the LLM can ground its generated code.
        """
        descriptions = []
        for column in data.columns:
            dtype = data[column].dtype
            description = None
            if dtype in [int, float, complex]:
                description = f"The column '{column}' is type {dtype} and contains numeric values."
            elif dtype == bool:
                description = f"The column '{column}' is type {dtype} and contains boolean values."
            elif dtype == object:
                # Check if the string column can be cast to a valid datetime.
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        pd.to_datetime(data[column], errors="raise")
                    dtype = "date"
                # BUG FIX: pd.to_datetime may also raise TypeError (e.g. on
                # mixed object values); treat both as "not a date" instead of
                # letting the exception propagate.
                except (ValueError, TypeError):
                    # Check if the string column has a limited number of values.
                    if data[column].nunique() / len(data[column]) < 0.5:
                        dtype = "category"
                    else:
                        dtype = "string"
            elif isinstance(dtype, pd.CategoricalDtype):
                # isinstance replaces the deprecated
                # pd.api.types.is_categorical_dtype; behavior unchanged.
                dtype = "category"
            elif pd.api.types.is_datetime64_any_dtype(data[column]):
                dtype = "date"

            if dtype == "date" or dtype == "category":
                non_null_values = data[column][data[column].notnull()].unique()
                n_samples = min(MAXIMUM_SAMPLES, len(non_null_values))
                samples = (
                    pd.Series(non_null_values)
                    .sample(n_samples, random_state=42)
                    .tolist()
                )
                # BUG FIX: samples are not guaranteed to be strings (e.g.
                # Timestamps from datetime64 columns, numeric categories);
                # str.join would raise TypeError, so stringify each sample.
                values = "'" + "', '".join(str(sample) for sample in samples) + "'"
                description = f"The column '{column}' has {dtype} values {values}"

                if n_samples < len(non_null_values):
                    description += " etc."
                else:
                    description += "."
            elif description is None:
                description = f"The column '{column}' is {dtype} type."

            descriptions.append(description)

        return " ".join(descriptions)

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Generate plotting code for *nl_query* over the first table in *tables*.

        Returns (code, context) on success, or (None, None) if the LLM call fails.
        """
        library = config["library"]

        if library == "seaborn":
            import_statements = "import seaborn as sns\n"
        else:
            import_statements = ""

        # Boilerplate prepended to the LLM output: pre-created axes without
        # top/right spines, and the evaluation dataframe bound to `df`.
        pre_code = f"""import pandas as pd
import matplotlib.pyplot as plt
{import_statements}
fig,ax = plt.subplots(1,1,figsize=(10,4))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

df=df_nvBenchEval.copy()
"""
        data = pd.read_csv(tables[0], encoding="utf-8")
        columns = "'" + "', '".join(list(data.columns)) + "'"
        prompt = TEMPLATE.format(
            columns=columns,
            columns_description=self.table_format(data),
            library=library,
            nl_query=nl_query,
            pre_code=pre_code,
        )

        try:
            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)
            code = response.content
            # Drop any generated line that tries to re-load data from disk.
            codes = code.split("\n")
            codes = list(filter(lambda row: "data_file.csv" not in row, codes))
            code = "\n".join(codes)
            # Ensure the figure is rendered so execute() can capture the SVG.
            if "plt.show()" not in code and ("plt." in code or "fig." in code):
                code += "\nplt.show()"

            context = {
                "tables": tables,
            }
            return pre_code + "\n" + code, context
        except Exception:
            warnings.warn(str(sys.exc_info()))
            return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Execute generated *code* and capture the resulting chart as SVG.

        Returns a ChartExecutionResult with either the SVG string or the
        formatted exception message.
        """
        tables = context["tables"]
        df_nvBenchEval = pd.read_csv(tables[0])
        global_env = {
            "df_nvBenchEval": df_nvBenchEval,
            "svg_string": None,
            "show_svg": show_svg,
            "svg_name": log_name,
        }
        code += "\nsvg_string = show_svg(plt, svg_name)"
        try:
            # NOTE: exec of LLM-generated code is inherently unsafe; acceptable
            # only inside this benchmark's trusted evaluation environment.
            exec(code, global_env)
            svg_string = global_env["svg_string"]
            return ChartExecutionResult(status=True, svg_string=svg_string)
        except Exception as exception_error:
            import traceback

            exception_info = traceback.format_exception_only(
                type(exception_error), exception_error
            )
            return ChartExecutionResult(status=False, error_msg=exception_info)
import sys
import warnings

import pandas as pd
from coml import CoMLAgent, describe_variable
from langchain.chat_models.base import BaseChatModel

from viseval.agent import Agent, ChartExecutionResult

from .utils import show_svg


def read_table(name, url, format):
    """Load table *url* into a variable named `<name>_dataset` and build the
    matching CoML variable description.

    Returns (code, variable_description): the load statement as a string and
    a one-entry dict mapping the variable name to its CoML description.
    """
    code = f"{name}_dataset = pd.read_csv('{url}')"
    variable_description = {}
    exec(code)
    exec(
        f"variable_description['{name}_dataset'] = describe_variable({name}_dataset, dataframe_format='{format}', pandas_description_config=dict(max_rows=10))"
    )
    return code, variable_description


class CoML4VIS(Agent):
    """Agent that delegates visualization code generation to a CoMLAgent."""

    def __init__(self, llm: BaseChatModel, config: dict = None):
        num_examples = 1
        prompt_version = "matplotlib"
        if config:
            if "num_examples" in config:
                # CoML supports at most 4 few-shot examples.
                num_examples = min(config["num_examples"], 4)
            if "library" in config and config["library"] in [
                "matplotlib",
                "seaborn",
            ]:
                prompt_version = config["library"]

        self.coml = CoMLAgent(
            llm, num_examples=num_examples, prompt_version=prompt_version
        )

    def pre_code(self, tables: list[dict], chart_lib: str, table_format: str = "coml"):
        """Build the import/load preamble for *tables*.

        Returns (codes, variable_descriptions): the code lines that import the
        plotting stack and load each table, plus CoML descriptions of the
        loaded dataframes.
        """
        codes = ["import pandas as pd\nimport matplotlib.pyplot as plt\n"]
        variable_descriptions = {}
        if chart_lib == "seaborn":
            codes[-1] += "import seaborn as sns\n"

        for url in tables:
            # Derive the dataset variable name from the file name (sans extension).
            name = url.split("/")[-1].split(".")[0]
            code, variable_description = read_table(name, url, table_format)
            codes.append(code)
            variable_descriptions.update(variable_description)
        return codes, variable_descriptions

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Generate plotting code via CoML.

        Returns (code, context) on success, or (None, None) if generation fails.
        """
        library = config["library"]
        table_format = config["table_format"] if "table_format" in config else "coml"

        pre_codes, variable_descriptions = self.pre_code(tables, library, table_format)
        # BUG FIX: generate_code was previously also invoked once *outside*
        # this try block, doubling the LLM call and letting its exceptions
        # escape unhandled; only the guarded invocation is kept.
        try:
            generating_context = self.coml.generate_code(
                nl_query, variable_descriptions, pre_codes
            )
            generate_code = generating_context["answer"]

            context = {"tables": tables, "library": library}
            return "\n".join(pre_codes) + "\n" + generate_code, context
        except Exception:
            warnings.warn(str(sys.exc_info()))
            return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Run generated *code* and capture the chart as an SVG string."""
        tables = context["tables"]
        library = context["library"]

        global_env = {"svg_string": None, "show_svg": show_svg, "svg_name": log_name}
        code += "\nsvg_string = show_svg(plt, svg_name)"
        try:
            exec(code, global_env)
            svg_string = global_env["svg_string"]
            return ChartExecutionResult(status=True, svg_string=svg_string)
        except Exception:
            try:
                # Fallback for logs produced by an older version in which the
                # load preamble was not stored as part of the code.
                codes, variable_descriptions = self.pre_code(tables, library)
                exec("\n".join(codes) + "\n" + code, global_env)
                svg_string = global_env["svg_string"]
                return ChartExecutionResult(status=True, svg_string=svg_string)
            except Exception as exception_error:
                error_msg = str(exception_error)
                return ChartExecutionResult(status=False, error_msg=error_msg)
import time

import matplotlib.pyplot as plt
import pandas as pd
from lida import Manager
from lida.components import get_globals_dict, preprocess_code
from lida.datamodel import Goal
from llmx import TextGenerator

from viseval.agent import Agent, ChartExecutionResult

from .utils import show_svg

max_retries = 20
retry_seconds = 20


class Lida(Agent):
    """Agent that wraps the LIDA Manager for visualization generation."""

    def __init__(self, llm: TextGenerator):
        self.lida = Manager(text_gen=llm)

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Ask LIDA for a chart answering *nl_query* over the first table.

        Retries up to `max_retries` times on any error, sleeping
        `retry_seconds` between attempts; returns (code, context) on success
        or (None, None) once all attempts are exhausted.
        """
        library = config["library"]
        summary = self.lida.summarize(tables[0])

        for attempt in range(max_retries):
            try:
                charts = self.lida.visualize(
                    summary=summary, goal=nl_query, library=library, return_error=True
                )

                # Take the first candidate chart and make sure it renders.
                code = charts[0].code + "\nplt.show()"

                return code, {"data": self.lida.data, "library": library}
            except Exception:
                # Back off and retry unless this was the final attempt.
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_seconds} seconds...")
                    time.sleep(retry_seconds)

        return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Execute LIDA-generated *code* and capture the figure as SVG.

        Only matplotlib/seaborn output is supported; other libraries fall
        through without a result.
        """
        data = context["data"]
        library = context["library"]

        code = preprocess_code(code)
        if library not in ("matplotlib", "seaborn"):
            # Other rendering libraries are not handled here.
            return None

        try:
            env = get_globals_dict(code, data)
            exec(code, env)

            # Cosmetic defaults applied after the generated code has run.
            plt.box(False)
            plt.grid(color="lightgray", linestyle="dashed", zorder=-10)

            return ChartExecutionResult(
                status=True, svg_string=show_svg(plt, log_name)
            )
        except Exception as exception_error:
            import traceback

            exception_info = traceback.format_exception_only(
                type(exception_error), exception_error
            )
            return ChartExecutionResult(status=False, error_msg=exception_info)

    def evaluate(self, code: str, nl_query: str, library: str):
        """Score *code* with LIDA's own self-evaluation for the given query."""
        goal = Goal(question=nl_query, visualization=nl_query, rationale="")
        return self.lida.evaluate(code, goal, library=library)[0]
import argparse
from pathlib import Path

import dotenv
from agent import Chat2vis, CoML4VIS, Lida

from viseval import Dataset, Evaluator

dotenv.load_dotenv()


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean string.

    BUG FIX: the previous `type=bool` made argparse treat ANY non-empty
    string — including "False" — as True. This parser keeps the same CLI
    shape (`--irrelevant_tables True/False`) but interprets it correctly.
    """
    return value.strip().lower() in ("true", "1", "yes")


def configure_llm(model: str, agent: str):
    """Build the LLM backend for *agent*; raises ValueError on unknown models."""
    if agent == "lida":
        # LIDA uses the llmx generator interface rather than LangChain.
        if model in ["gpt-35-turbo", "gpt-4"]:
            from llmx import llm

            return llm(
                provider="openai",
                api_type="azure",
                model=model,
                models={
                    "max_tokens": 4096,
                    "temperature": 0.0,
                },
            )
        else:
            raise ValueError(f"Unknown model {model}")
    else:
        if model == "gemini-pro":
            from langchain_google_genai import ChatGoogleGenerativeAI

            return ChatGoogleGenerativeAI(
                model=model, temperature=0.0, convert_system_message_to_human=True
            )
        elif model in ["gpt-35-turbo", "gpt-4"]:
            from langchain_openai import AzureChatOpenAI

            return AzureChatOpenAI(
                model_name=model,
                max_retries=999,
                temperature=0.0,
                request_timeout=20,
            )
        elif model == "codellama-7b":
            from model.langchain_llama import ChatLlama

            # Local checkpoint path is fixed relative to the repo layout.
            return ChatLlama("../llama_models/CodeLlama-7b-Instruct")
        else:
            raise ValueError(f"Unknown model {model}")


def config_agent(agent: str, model: str, config: dict):
    """Instantiate the requested agent with an appropriately configured LLM."""
    llm = configure_llm(model, agent)
    if agent == "coml4vis":
        return CoML4VIS(llm, config)
    elif agent == "chat2vis":
        return Chat2vis(llm)
    elif agent == "lida":
        return Lida(llm)
    else:
        raise ValueError(f"Unknown agent {agent}")


def _main():
    """CLI entry point: evaluate an agent/model pair on the benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=Path)
    parser.add_argument(
        "--type", type=str, choices=["all", "single", "multiple"], default="all"
    )
    # Uses _str2bool instead of type=bool so "--irrelevant_tables False"
    # actually evaluates to False (see _str2bool docstring).
    parser.add_argument("--irrelevant_tables", type=_str2bool, default=False)

    parser.add_argument(
        "--model",
        type=str,
        default="gpt-35-turbo",
        choices=["gpt-4", "gpt-35-turbo", "gemini-pro", "codellama-7b"],
    )
    parser.add_argument(
        "--agent",
        type=str,
        default="coml4vis",
        choices=["coml4vis", "lida", "chat2vis"],
    )
    parser.add_argument("--num_examples", type=int, default=1, choices=range(0, 4))
    parser.add_argument(
        "--library", type=str, default="matplotlib", choices=["matplotlib", "seaborn"]
    )
    parser.add_argument("--logs", type=Path, default="./logs")
    parser.add_argument("--webdriver", type=Path, default="/usr/bin/chromedriver")

    args = parser.parse_args()

    # config dataset
    dataset = Dataset(args.benchmark, args.type, args.irrelevant_tables)

    # config agent
    agent = config_agent(
        args.agent,
        args.model,
        {"num_examples": args.num_examples, "library": args.library},
    )

    from langchain_openai import AzureChatOpenAI

    # Vision model used by the evaluator to judge rendered charts.
    vision_model = AzureChatOpenAI(
        model_name="gpt-4-turbo-v",
        max_retries=999,
        temperature=0.0,
        request_timeout=20,
        max_tokens=4096,
    )
    # config evaluator
    evaluator = Evaluator(webdriver_path=args.webdriver, vision_model=vision_model)

    # evaluate agent
    config = {"library": args.library, "logs": args.logs}
    result = evaluator.evaluate(agent, dataset, config)
    score = result.score()
    print(f"Score: {score}")


if __name__ == "__main__":
    _main()
import json
from typing import Any, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
)
from transformers import AutoTokenizer, LlamaForCausalLM

# Markers used to assemble a Llama-2 chat prompt out of message rounds.
# NOTE(review): these constant values look garbled (angle-bracketed tokens
# appear stripped, like the rest of this export). The upstream Llama-2 chat
# format uses "<s>"/"</s>" as round delimiters and "<<SYS>>"/"<</SYS>>" as
# the system-prompt wrapper — confirm against the original file before
# relying on the literal values below.
B_ROUND, E_ROUND = "", ""
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<>\n", "\n<>\n\n"

# Control tags that user-supplied content must never contain, to prevent
# injection into the chat template; _call refuses prompts containing them.
SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


class ChatLlama(SimpleChatModel):
    """LangChain chat-model wrapper around a local Llama-2-style checkpoint."""

    # read from config.json
    max_context_length: int = 2048  # overwritten in __init__ from config.json
    max_new_tokens: int = 4096
    temperature: float = 0.0  # 0.0 selects greedy decoding in _call
    top_p: float = 1
    top_k: int = 50
    tokenizer: Any
    model: Any

    def __init__(self, model_path: str):
        """Load model and tokenizer from *model_path*, a local HF checkpoint dir."""
        super().__init__()

        self.model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        # Llama ships without a pad token; reuse EOS so padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        with open(f"{model_path}/config.json") as f:
            config = json.load(f)
            self.max_context_length = config["max_position_embeddings"]

    @property
    def _llm_type(self) -> str:
        return "llama2-chat"

    def _assemble_prompt(self, messages: List[BaseMessage]) -> str:
        """Flatten LangChain messages into a single Llama-2 chat prompt string.

        Expects an optional leading system message, then strictly alternating
        user/assistant turns, ending with a user message; asserts otherwise.
        """
        prompt = ""
        temp = []
        for message in messages:
            # Map LangChain message classes onto Llama role strings.
            if isinstance(message, ChatMessage):
                role = message.role
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"
            else:
                raise ValueError(f"Got unknown type {message}")
            temp.append({"role": role, "content": message.content})
        messages = temp
        # Fold a leading system message into the first user turn, wrapped in
        # the B_SYS/E_SYS markers, as the Llama-2 chat template requires.
        if messages[0]["role"] == "system":
            messages = [
                {
                    "role": messages[1]["role"],
                    "content": B_SYS
                    + messages[0]["content"]
                    + E_SYS
                    + messages[1]["content"],
                }
            ] + messages[2:]
        # Enforce strict alternation: even indices user, odd indices assistant.
        assert all([msg["role"] == "user" for msg in messages[::2]]) and all(
            [msg["role"] == "assistant" for msg in messages[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
        # Each completed (user, assistant) pair becomes one delimited round.
        prompts: List[str] = [
            f"{B_ROUND}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {E_ROUND}"
            for prompt, answer in zip(
                messages[::2],
                messages[1::2],
            )
        ]
        prompt = "".join(prompts)
        assert (
            messages[-1]["role"] == "user"
        ), f"Last message must be from user, got {messages[-1]['role']}"
        # Open a final, unanswered round for the pending user message; the
        # model's generation completes it.
        prompt += f"{B_ROUND}{B_INST} {(messages[-1]['content']).strip()} {E_INST}"
        return prompt

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion for *messages* and return the decoded text.

        *stop* and *run_manager* are accepted for interface compatibility
        with SimpleChatModel but are not used.
        """
        # Refuse prompts whose content embeds the chat-template control tags.
        unsafe = any(
            tag in message.content for message in messages for tag in SPECIAL_TAGS
        )
        if unsafe:
            return UNSAFE_ERROR
        prompt = self._assemble_prompt(messages)

        # NOTE(review): ".to('cuda')" hard-codes a GPU even though the model
        # is loaded with device_map="auto" — confirm behavior on CPU-only hosts.
        input_ids = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).input_ids.to("cuda")
        assert len(input_ids[0]) <= self.max_context_length, (
            f"Prompt is too long, got {len(input_ids[0])} tokens, "
            f"max is {self.max_context_length}"
        )
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": self.max_new_tokens,
            "repetition_penalty": 1.0,
            "eos_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # temperature == 0.0 means deterministic greedy decoding; otherwise
        # sample with the configured temperature / top-p / top-k.
        if self.temperature == 0.0:
            generate_input["do_sample"] = False
        else:
            generate_input["do_sample"] = True
            generate_input["temperature"] = self.temperature
            generate_input["top_p"] = self.top_p
            generate_input["top_k"] = self.top_k

        generate_ids = self.model.generate(**generate_input)
        # Strip the echoed prompt tokens and the trailing token.
        # NOTE(review): the "-1" assumes the final generated token is EOS; if
        # generation stops at max_new_tokens instead, a real token is dropped.
        generate_ids = [item[len(input_ids[0]) : -1] for item in generate_ids]
        # output
        output = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output
17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /tests/assets/empty_4.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2024-03-01T03:47:43.000613 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.8.2, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 48 | 61 | 92 | 114 | 147 | 154 | 167 | 192 | 193 | 212 | 242 | 268 | 285 | 299 | 320 | 339 | 360 | 381 | 397 | 418 | 435 | 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 571 | 597 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 634 | 635 | 636 | 637 | 638 | 639 | 640 | 641 | 642 | 643 | 644 | 645 | 646 | 647 | 648 | 649 | 660 | 661 | 662 | 663 | 664 | 665 | -------------------------------------------------------------------------------- /tests/assets/pie_4.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2023-11-22T15:34:03.855731 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.7.2, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 92 | 114 | 127 | 134 | 160 | 161 | 173 | 194 | 211 | 232 | 248 | 267 | 278 | 303 | 334 | 355 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 
393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 433 | 452 | 473 | 489 | 523 | 554 | 573 | 586 | 616 | 629 | 662 | 663 | 664 | 665 | 666 | 667 | 668 | 669 | 670 | 671 | 672 | 673 | 674 | 675 | 676 | 677 | 678 | 679 | 680 | 681 | 682 | 683 | 684 | 685 | 686 | 687 | 688 | 689 | 690 | 691 | 692 | 693 | 694 | 695 | 696 | 697 | 698 | 699 | 700 | 701 | 702 | 703 | 704 | 705 | 706 | 707 | 708 | 709 | 710 | 711 | 712 | 713 | 714 | 715 | 716 | 717 | 718 | 719 | 720 | 721 | 722 | 723 | 724 | 725 | 726 | -------------------------------------------------------------------------------- /tests/assets/readability/2571.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2024-02-08T18:26:02.080912 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.8.2, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 41 | 42 | 43 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 93 | 119 | 136 | 149 | 168 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 221 | 253 | 260 | 285 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 355 | 371 | 392 | 418 | 425 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 539 | 540 | 570 | 587 | 612 | 644 | 665 | 666 | 667 | 668 | 669 | 670 | 671 | 672 | 673 | 674 | 675 | 676 | 677 | 678 | 679 | 680 | 681 | 682 | 683 | 684 | 685 | 686 | 687 | 688 | 689 | 690 | 691 | 692 | 693 | 694 | 695 | 696 | 697 | 698 | 699 | 700 | -------------------------------------------------------------------------------- /tests/assets/readability/readability_human_rating.csv: -------------------------------------------------------------------------------- 1 | 
visualization_id,query,expert1,expert2,expert3,has_layout_issues,has_scale&ticks_issues 2 | 2756@x_name@DESC,"Draw a bar chart showing the total number of every ship type by categorizing by nationality, and list x-axis in descending order.",1,1,2,TRUE,FALSE 3 | 2815@x_name@DESC,"Compute the total number in each competition type, and then split by country. Plot a bar chart and show in desc by the x-axis.",1,1,1,TRUE,FALSE 4 | 2841@y_name@ASC,"Bar graph to show the average of high temperature from different day of week, and could you show y axis in ascending order?",2,2,3,FALSE,TRUE 5 | 312,Give me the proportion of end stations for the trips with the three oldest id using a pie chart.,4,4,4,TRUE,FALSE 6 | 69,What is the number of booking start dates of the apartments with more than 2 bedrooms for each year? Return a bar chart.,1,2,1,TRUE,TRUE 7 | 4,Show all the faculty ranks and the number of students advised by each rank with a pie chart.,4,5,5,FALSE,FALSE 8 | 9,Show the number of faculty members for each rank in a bar chart.,4,3,1,TRUE,FALSE 9 | 32,List the names of aircrafts and the number of times it won matches by a bar chart.,3,3,2,TRUE,TRUE 10 | 286,What are the date and mean humidity for the top 3 days with the largest max wind speeds? 
Return me a bar chart.,2,3,1,TRUE,FALSE 11 | 296,"For the top 3 days with the largest max wind speeds, please bin the date into the of the week and then compute the average of mean humidity to visualize a bar chart.",3,3,2,TRUE,FALSE 12 | 33,"What are the descriptions for the aircrafts, and count them by a pie chart",4,4,5,FALSE,FALSE 13 | 65,"Show the number of start dates of all the apartment bookings made by guests with gender code ""Female"" for each year in a bar chart.",1,2,1,FALSE,TRUE 14 | 72,Find the number of booking start date for the apartments that have more than two bedrooms for each weekday in a bar chart.,4,3,4,FALSE,TRUE 15 | 81,A pie chart for showing the number of the facility codes of apartments with more than 4 bedrooms.,4,5,5,FALSE,FALSE 16 | 387,"A grouped scatter chart shows the correlation between Height and Weight , and group by attribute Sex.",5,5,4,FALSE,FALSE 17 | 399,"Plot a scatter chart, to show the correlation between support and consider rates for each candidate.",5,4,3,FALSE,FALSE 18 | 432,plot scatter on what is the maximum accelerate for different number of cylinders?,5,4,4,FALSE,FALSE 19 | 267,Return a line chart about the change of monthly_rental over date_address_from .,4,5,5,FALSE,FALSE 20 | 316,A line chart for giveing me the number of the dates when the max temperature was higher than 85.,4,3,4,FALSE,TRUE 21 | 470,What is the sum of capacity of cinemas open for each year? 
Return a line chart.,4,4,5,FALSE,FALSE 22 | 173,A pie chart showing the number of results of the battles when the bulgarian commander is not 'Boril'.,4,5,4,FALSE,FALSE 23 | 219,"Show me about the distribution of other_details and the average of monthly_rental , and group by attribute other_details in a bar chart.",5,5,5,FALSE,FALSE 24 | 246,Draw a bar chart about the distribution of date_address_from and the sum of monthly_rental bin date_address_from by year.,1,2,1,FALSE,TRUE 25 | 338,Compute the total number of stations across city as a pie chart.,4,4,5,FALSE,FALSE 26 | 342,"What are the dates in which the mean sea level pressure was between 30.3 and 31, and count them by a line chart",2,1,1,TRUE,TRUE 27 | 368,A bar chart showing the number of accelerators for each browser.,4,3,2,TRUE,TRUE 28 | 403,Count the number of people of each sex who have a weight higher than 85 by a bar chart.,4,4,3,FALSE,TRUE 29 | 1008,Show the number of products with price higher than 1000 or lower than 500 for each product name in a pie chart.,5,5,5,FALSE,FALSE 30 | 461,Give me a pie chart showing sum of price for each cinema.,4,4,3,FALSE,FALSE 31 | 464,Show each location and the number of cinemas there by a bar chart.,4,3,4,TRUE,TRUE 32 | 487,Show the number of climbers for each mountain in a pie chart.,4,3,4,TRUE,FALSE 33 | 513,Find the total credits of all classes offered by each department. Visualize by bar chart.,5,5,5,FALSE,FALSE 34 | 517,Find the number of students whose gpa is lower than the average gpa of all students for different first name in a pie chart.,4,4,5,FALSE,FALSE 35 | 567,What is the gpa of the top 5 students with highest gpa? 
Show me a bar chart with each student by first name.,5,5,5,FALSE,FALSE 36 | 611,Find the number of courses offered by Psychology department in each year with a line chart.,3,2,4,FALSE,TRUE 37 | 616,"Find dept_name and the sum of salary , and group by attribute dept_name, and visualize them by a bar chart.",1,2,1,TRUE,FALSE 38 | 622,Find the relationship between average and maximum capacity among rooms in each building with a scatter chart.,5,5,5,FALSE,FALSE 39 | 659,"Find the last name of female (sex is F) students in the descending order of age, and count them by a bar chart",4,3,4,TRUE,TRUE 40 | 676,What is the number of each course name that have at least five enrollments? Show me a bar chart.,2,3,3,TRUE,TRUE 41 | 693,Show the number of singers in each country with a bar chart.,4,4,3,FALSE,TRUE 42 | 708,Pie chart. how many counties correspond to each police force?,4,4,5,FALSE,FALSE 43 | 718,Show the number of courses each teacher is arranged to teach in a pie chart.,4,5,5,FALSE,FALSE 44 | 744,Show all template type codes and the number of documents using each type with a bar chart.,4,4,3,FALSE,FALSE 45 | 789,Show all calendar dates and day Numbers in a line chart.,4,3,4,FALSE,TRUE 46 | 803,Show budget type codes and the number of documents in each budget type with a bar chart.,5,5,4,FALSE,FALSE 47 | 856,"Which workshop groups have bookings with status code ""stop""? 
Give me the names, and count them by a pie chart",4,4,5,FALSE,FALSE 48 | 880,A line chart for what are the number of the actual delivery dates of orders with quantity 1?,4,3,4,FALSE,TRUE 49 | 914,"Return a bar chart showing the proportion of the number of orders that have the status ""Delivered"" for each customer name.",3,2,3,TRUE,TRUE 50 | 961,Show all product names and the total quantity ordered for each product name in a bar chart.,3,3,2,TRUE,FALSE 51 | 1008,Show the number of products with price higher than 1000 or lower than 500 for each product name in a pie chart.,4,4,5,FALSE,FALSE 52 | 1071,"List the venues of debates in ascending order of the number of audience, and count them by a bar chart",2,2,2,TRUE,FALSE 53 | 1188,How many dogs who have gone through a treatment departed in each day? Return a bar chart.,5,5,5,FALSE,FALSE 54 | 1285,"Which tests have ""Pass"" results? Return the dates when the tests were taken, and count them by a line chart",4,3,3,FALSE,TRUE 55 | 1314,"For each county, find the name of the county and the number of delegates from that county. 
Show the proportion by pie chart.",5,4,5,FALSE,FALSE 56 | 1363,Draw a bar chart for what is the number of employees from each city?,4,4,3,FALSE,TRUE 57 | 1392,Please show the number of films for each type in a bar chart.,3,3,2,TRUE,FALSE 58 | 1434,A pie chart for finding the number of the names of Japanese constructors that have once earned more than 5 points?,4,4,5,FALSE,FALSE 59 | 1517,Show the number of companies in each headquarter with a pie chart.,3,4,3,FALSE,FALSE 60 | 1530,A bar chart for listing the number of the names of patients who have made appointments.,4,4,3,FALSE,TRUE 61 | 1630,Provide the frequency of the last names of employees earning more than the employee with id 163 using a bar chart.,4,3,3,FALSE,TRUE 62 | 1961,"For all employees in the Finance department, show me the proportion of their job id using a pie chart.",5,4,5,FALSE,FALSE 63 | 1974,Find the number of rooms with king bed for each decor type. Plot them as pie chart.,4,4,4,FALSE,FALSE 64 | 1992,"Among all the claims, which claims have a claimed amount larger than the average? Please Bin the date it was settled into weekday interval and count them to show a bar chart.",4,3,3,FALSE,TRUE 65 | 2013,What about the proportion of the total amounts of payments by each method code? 
You can give me a pie chart.,5,4,4,FALSE,FALSE 66 | 2027,Sum the amount for all the payments processed with Visa by each year using a bar chart.,1,2,1,FALSE,TRUE 67 | 2110,"Group and count the move in date in a bar chart, and I want to bin the X into Year interval.",1,2,1,FALSE,TRUE 68 | 2174,Give me a bar chart to show the names and revenue of the company that earns the highest revenue in each headquarter city.,5,4,5,FALSE,FALSE 69 | 2303,Please show me a bar chart for visualizing the name and revenue of all manufacturers sorted by their revenue in the descending order.,4,4,4,TRUE,FALSE 70 | 2350,Group and count the color scheme for all the photos using a pie chart.,4,4,4,FALSE,FALSE 71 | 2416,Find the name and membership level of the visitors whose membership level is higher than 4. Plot them as pie chart.,5,4,4,FALSE,FALSE 72 | 2526,Show the proportion of all ministers using a pie chart.,5,4,5,FALSE,FALSE 73 | 2571,Give me a pie to show total number of memory in g from different carrier.,4,4,4,FALSE,FALSE 74 | 2615,"What are the payment date of the payment with amount paid higher than 300 or with payment type is 'Check', bin the payment date by month and count them by a bar chart",5,5,5,FALSE,FALSE 75 | 2652,A bar chart for listing the number of the builders of railways in ascending alphabetical order.,4,3,3,TRUE,TRUE 76 | 2724,"For each denomination, return the denomination and the count of schools with that denomination. 
Visualize by bar chart.",4,4,4,FALSE,TRUE 77 | 2941,Give me a bar chart about the number of countries in the artist table,5,5,5,FALSE,FALSE 78 | 3207,Show the name and age for all male people who don't have a wedding by a bar chart.,4,4,4,TRUE,FALSE 79 | 1491@y_name@DESC,"Show the number of games for each home team in a bar chart, sort by the y-axis in desc please.",5,5,4,FALSE,FALSE 80 | 1380@y_name@ASC,"Find each target user's name and average trust score Visualize by bar chart, rank y axis from low to high order.",5,5,4,FALSE,FALSE 81 | 58@y_name@ASC,"What are the average ages for male and female students Plot them as bar chart, and list Y in asc order.",3,4,4,FALSE,FALSE 82 | 847@y_name@DESC,"Show the number of documents created in each day and bin document date by weekday with a bar chart, list by the y axis in desc.",2,1,1,TRUE,TRUE 83 | 2024@y_name@DESC,"For those payments processed with Visa, bin the payment day into Year interval and count them for a bar chart, order from high to low by the y axis.",2,2,1,FALSE,TRUE 84 | 3014@y_name@ASC,"Find the data about the sale details and dates of transactions with amount smaller than 3000? Bin the date of the transaction into a weekday interval and compute the total number of each day with a bar chart, and display Y in ascending order.",4,4,3,FALSE,TRUE 85 | 17@y_name@DESC,"Bar chart of total number by each rank, sort from high to low by the Y.",5,5,5,FALSE,FALSE 86 | 2575@y_name@ASC,"For each phone, show its names and total number of stocks Visualize by bar chart, rank by the Y-axis in ascending please.",5,5,5,FALSE,FALSE 87 | 75@y_name@ASC,"What is the number of booking start dates of the apartments with type code ""Duplex"" in each year? Return a bar chart, show y axis in ascending order.",1,1,1,FALSE,TRUE 88 | 533@y_name@DESC,"What is the lowest student GPA for every department? 
Return a bar chart, display in desc by the Y.",4,5,5,FALSE,FALSE 89 | 555@y_name@DESC,"How many classes are held in each department Visualize by bar chart, and list by the total number from high to low.",5,5,5,FALSE,FALSE 90 | 1237@y_name@ASC,"Find the average age for students with different sex in a bar chart, could you list by the Y from low to high?",3,4,4,FALSE,FALSE 91 | 279@y_name@DESC,"A bar chart about the number of end dates for incidents with incident type code ""NOISE"" and bin by month.",1,1,1,FALSE,TRUE 92 | 3134@y_name@ASC,"Give me a bar chart for all_games_percent of each team name, could you list in asc by the y axis please?",5,5,5,FALSE,FALSE 93 | 2943@y_name@ASC,"Bar chart x axis country y axis the average of age, and I want to show by the y-axis in asc please.",5,5,4,FALSE,FALSE 94 | 2760@y_name@DESC,"Bar graph to show how many nationality from different nationality, list in desc by the total number.",5,5,5,FALSE,FALSE 95 | 2765@y_name@ASC,"Give me a bar chart for mean tonnage of each type, show by the Y in asc.",5,5,5,FALSE,FALSE 96 | 134@x_name@DESC,"List the number of enginners in a stacked bar chart The x-axis is last name and group by skill description, rank last_name in descending order.",2,2,2,TRUE,TRUE 97 | 3069@x_name@ASC,"Find the name of each user and number of tweets tweeted by each of them Visualize by bar chart, could you order X-axis in asc order?",4,3,3,TRUE,TRUE 98 | 145@y_name@DESC,"A stacked bar chart showing thfe number of faults for different fault short name and skills required to fix them The x-axis is skill description and group by fault short name, display by the y axis in descending.",3,4,4,TRUE,FALSE 99 | 131,Give me a pie chart about the number of engineers for different skill description.,5,4,5,FALSE,FALSE 100 | 2019,Bin the claim date into the Year interval and count them for visualizing a bar chart.,3,2,1,FALSE,TRUE 101 | 1534,Return a pie on how many patients do each physician take care of? 
List their names and number of patients they take care of.,4,4,5,FALSE,FALSE 102 | -------------------------------------------------------------------------------- /tests/test_chart_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import json 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from viseval.check import chart_check, deconstruct 10 | 11 | folder = Path(__file__).resolve().parent 12 | 13 | with open(folder / "assets/samples.json") as f: 14 | benchmark = json.load(f) 15 | 16 | 17 | def test_chart_check_pie_4(): 18 | with open(folder / "assets/pie_4_0.svg", "r") as f: 19 | svg_string = f.read() 20 | chart_info, msg = deconstruct(svg_string) 21 | ground_truth = benchmark["4"]["chart"] 22 | query_meta = benchmark["4"]["query_meta"] 23 | result = chart_check( 24 | chart_info, 25 | ground_truth, 26 | query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None, 27 | ) 28 | assert result[0] is True 29 | 30 | 31 | def test_chart_check_scatter_400(): 32 | for i in range(2): 33 | with open(folder / f"assets/scatter_400_{i}.svg", "r") as f: 34 | svg_string = f.read() 35 | chart_info, msg = deconstruct(svg_string) 36 | ground_truth = benchmark["400"]["chart"] 37 | query_meta = benchmark["400"]["query_meta"] 38 | result = chart_check( 39 | chart_info, 40 | ground_truth, 41 | ( 42 | query_meta[i]["stacked_bar"] 43 | if "stacked_bar" in query_meta[i] 44 | else None 45 | ), 46 | ) 47 | assert result[0] is True 48 | 49 | 50 | def test_chart_check_bar_1129(): 51 | for i in range(3): 52 | with open(folder / f"assets/bar_1129_{i}.svg", "r") as f: 53 | svg_string = f.read() 54 | chart_info, msg = deconstruct(svg_string) 55 | ground_truth = benchmark["1129"]["chart"] 56 | query_meta = benchmark["1129"]["query_meta"] 57 | result = chart_check( 58 | chart_info, 59 | ground_truth, 60 | ( 61 | query_meta[i]["stacked_bar"] 62 | if 
"stacked_bar" in query_meta[i] 63 | else None 64 | ), 65 | ) 66 | assert result[0] is True 67 | 68 | 69 | def test_chart_check_bar_2750(): 70 | for i in range(2): 71 | with open(folder / f"assets/stacked_bar_2750_{i}.svg", "r") as f: 72 | svg_string = f.read() 73 | chart_info, msg = deconstruct(svg_string) 74 | ground_truth = benchmark["2750"]["chart"] 75 | query_meta = benchmark["2750"]["query_meta"] 76 | result = chart_check( 77 | chart_info, 78 | ground_truth, 79 | ( 80 | query_meta[i]["stacked_bar"] 81 | if "stacked_bar" in query_meta[i] 82 | else None 83 | ), 84 | ) 85 | assert result[0] is True 86 | 87 | 88 | def test_chart_check_line_3240(): 89 | with open(folder / f"assets/line_3240_0.svg", "r") as f: 90 | svg_string = f.read() 91 | chart_info, msg = deconstruct(svg_string) 92 | ground_truth = benchmark["3240"]["chart"] 93 | query_meta = benchmark["3240"]["query_meta"] 94 | result = chart_check( 95 | chart_info, 96 | ground_truth, 97 | query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None, 98 | ) 99 | assert result[0] is True 100 | 101 | 102 | def test_chart_check_line_2781(): 103 | with open(folder / f"assets/grouping_line_2781_0.svg", "r") as f: 104 | svg_string = f.read() 105 | chart_info, msg = deconstruct(svg_string) 106 | ground_truth = benchmark["2781"]["chart"] 107 | query_meta = benchmark["2781"]["query_meta"] 108 | result = chart_check( 109 | chart_info, 110 | ground_truth, 111 | query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None, 112 | ) 113 | assert result[0] is True 114 | 115 | 116 | def test_chart_check_bar_1071(): 117 | # horizontal bar 118 | with open(folder / f"assets/bar_1071_0.svg", "r") as f: 119 | svg_string = f.read() 120 | chart_info, msg = deconstruct(svg_string) 121 | ground_truth = benchmark["1071"]["chart"] 122 | query_meta = benchmark["1071"]["query_meta"] 123 | result = chart_check( 124 | chart_info, 125 | ground_truth, 126 | query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] 
else None, 127 | ) 128 | assert result[0] is True 129 | 130 | 131 | def test_stacked(): 132 | result = chart_check({"chart": "grouping bar"}, "Stacked Bar", True) 133 | assert result[0] is False 134 | 135 | result = chart_check({"chart": "grouping bar"}, "Stacked Bar", False) 136 | assert result[0] is True 137 | -------------------------------------------------------------------------------- /tests/test_data_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import json 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from viseval.check import data_check, deconstruct 10 | 11 | folder = Path(__file__).resolve().parent 12 | 13 | with open(folder / "assets/samples.json") as f: 14 | benchmark = json.load(f) 15 | 16 | 17 | def test_data_check_pie_5(): 18 | with open(folder / "assets/pie_5.svg", "r") as f: 19 | svg_string = f.read() 20 | chart_info, msg = deconstruct(svg_string) 21 | ground_truth = benchmark["3264"]["vis_obj"] 22 | query_meta = benchmark["3264"]["query_meta"] 23 | result = data_check( 24 | chart_info, ground_truth, query_meta[0]["channel_specified"] 25 | ) 26 | assert result[0] is True 27 | 28 | def test_data_check_pie_4(): 29 | with open(folder / "assets/pie_4_0.svg", "r") as f: 30 | svg_string = f.read() 31 | chart_info, msg = deconstruct(svg_string) 32 | ground_truth = benchmark["4"]["vis_obj"] 33 | query_meta = benchmark["4"]["query_meta"] 34 | result = data_check( 35 | chart_info, ground_truth, query_meta[0]["channel_specified"] 36 | ) 37 | assert result[0] is True 38 | 39 | 40 | def test_data_check_scatter_400(): 41 | for i in range(2): 42 | with open(folder / f"assets/scatter_400_{i}.svg", "r") as f: 43 | svg_string = f.read() 44 | chart_info, msg = deconstruct(svg_string) 45 | ground_truth = benchmark["400"]["vis_obj"] 46 | query_meta = benchmark["400"]["query_meta"] 47 | result = data_check( 48 | chart_info, 
ground_truth, query_meta[i]["channel_specified"] 49 | ) 50 | assert result[0] is True 51 | 52 | 53 | def test_data_check_bar_1129(): 54 | for i in range(3): 55 | with open(folder / f"assets/bar_1129_{i}.svg", "r") as f: 56 | svg_string = f.read() 57 | chart_info, msg = deconstruct(svg_string) 58 | ground_truth = benchmark["1129"]["vis_obj"] 59 | query_meta = benchmark["1129"]["query_meta"] 60 | result = data_check( 61 | chart_info, ground_truth, query_meta[i]["channel_specified"] 62 | ) 63 | assert result[0] is True 64 | 65 | 66 | def test_data_check_bar_2750(): 67 | for i in range(2): 68 | with open(folder / f"assets/stacked_bar_2750_{i}.svg", "r") as f: 69 | svg_string = f.read() 70 | chart_info, msg = deconstruct(svg_string) 71 | ground_truth = benchmark["2750"]["vis_obj"] 72 | query_meta = benchmark["2750"]["query_meta"] 73 | result = data_check( 74 | chart_info, ground_truth, query_meta[i]["channel_specified"] 75 | ) 76 | assert result[0] is True 77 | 78 | 79 | def test_data_check_line_3240(): 80 | with open(folder / "assets/line_3240_0.svg", "r") as f: 81 | svg_string = f.read() 82 | chart_info, msg = deconstruct(svg_string) 83 | ground_truth = benchmark["3240"]["vis_obj"] 84 | query_meta = benchmark["3240"]["query_meta"] 85 | result = data_check( 86 | chart_info, ground_truth, query_meta[0]["channel_specified"] 87 | ) 88 | assert result[0] is True 89 | 90 | 91 | def test_data_check_line_2781(): 92 | with open(folder / "assets/grouping_line_2781_0.svg", "r") as f: 93 | svg_string = f.read() 94 | chart_info, msg = deconstruct(svg_string) 95 | ground_truth = benchmark["2781"]["vis_obj"] 96 | query_meta = benchmark["2781"]["query_meta"] 97 | result = data_check( 98 | chart_info, ground_truth, query_meta[0]["channel_specified"] 99 | ) 100 | assert result[0] is False 101 | 102 | 103 | def test_data_check_line_773(): 104 | # yyyy-mm-dd 105 | with open(folder / "assets/line_773_0.svg", "r") as f: 106 | svg_string = f.read() 107 | chart_info, msg = 
deconstruct(svg_string) 108 | ground_truth = benchmark["773@x_name@DESC"]["vis_obj"] 109 | query_meta = benchmark["773@x_name@DESC"]["query_meta"] 110 | result = data_check( 111 | chart_info, ground_truth, query_meta[0]["channel_specified"] 112 | ) 113 | assert result[0] is True 114 | 115 | 116 | def test_data_check_bar_68(): 117 | # monday - mon 118 | with open(folder / "assets/bar_68_0.svg", "r") as f: 119 | svg_string = f.read() 120 | chart_info, msg = deconstruct(svg_string) 121 | ground_truth = benchmark["68"]["vis_obj"] 122 | query_meta = benchmark["68"]["query_meta"] 123 | result = data_check( 124 | chart_info, ground_truth, query_meta[0]["channel_specified"] 125 | ) 126 | assert result[0] is True 127 | 128 | 129 | def test_data_check_bar_1071(): 130 | # horizontal bar 131 | with open(folder / "assets/bar_1071_0.svg", "r") as f: 132 | svg_string = f.read() 133 | chart_info, msg = deconstruct(svg_string) 134 | ground_truth = benchmark["1071"]["vis_obj"] 135 | query_meta = benchmark["1071"]["query_meta"] 136 | result = data_check( 137 | chart_info, ground_truth, query_meta[0]["channel_specified"] 138 | ) 139 | assert result[0] is True 140 | 141 | result = data_check(chart_info, ground_truth, ["x", "y"]) 142 | assert result[0] is False 143 | 144 | 145 | def test_data_check_bar_3137(): 146 | # error cases 147 | with open(folder / "assets/bar_3137_1.svg", "r") as f: 148 | svg_string = f.read() 149 | chart_info, msg = deconstruct(svg_string) 150 | ground_truth = benchmark["3137@y_name@DESC"]["vis_obj"] 151 | query_meta = benchmark["3137@y_name@DESC"]["query_meta"] 152 | result = data_check( 153 | chart_info, ground_truth, query_meta[0]["channel_specified"] 154 | ) 155 | assert result[0] is False 156 | -------------------------------------------------------------------------------- /tests/test_deconstruct.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | 4 | from pathlib import Path 5 | 6 | import pytest 7 | 8 | from viseval.check import deconstruct 9 | 10 | folder = Path(__file__).resolve().parent 11 | 12 | 13 | # error case 14 | def test_deconstruct_empty(): 15 | with open(folder / "assets/empty.svg", "r") as f: 16 | svg_string = f.read() 17 | chart_info, msg = deconstruct(svg_string) 18 | assert len(chart_info["data"]) == 0 19 | 20 | 21 | def test_deconstruct_empty1(): 22 | with open(folder / "assets/empty_1.svg", "r") as f: 23 | svg_string = f.read() 24 | chart_info, msg = deconstruct(svg_string) 25 | assert len(chart_info["data"]) == 0 26 | 27 | 28 | def test_deconstruct_empty2(): 29 | with open(folder / "assets/empty_2.svg", "r") as f: 30 | svg_string = f.read() 31 | chart_info, msg = deconstruct(svg_string) 32 | assert len(chart_info["data"]) == 0 33 | 34 | 35 | def test_deconstruct_empty3(): 36 | with open(folder / "assets/empty_3.svg", "r") as f: 37 | svg_string = f.read() 38 | chart_info, msg = deconstruct(svg_string) 39 | assert chart_info is None 40 | 41 | 42 | def test_deconstruct_empty4(): 43 | with open(folder / "assets/empty_4.svg", "r") as f: 44 | svg_string = f.read() 45 | chart_info, msg = deconstruct(svg_string) 46 | print(chart_info) 47 | assert len(chart_info["data"]) == 0 48 | 49 | 50 | def test_deconstruct_dual_axis(): 51 | with open(folder / "assets/dual_1499_0.svg", "r") as f: 52 | svg_string = f.read() 53 | chart_info, msg = deconstruct(svg_string) 54 | assert chart_info is None 55 | 56 | 57 | def test_deconstruct_bar_error(): 58 | with open(folder / "assets/bar_error_3227_0.svg", "r") as f: 59 | svg_string = f.read() 60 | chart_info, msg = deconstruct(svg_string) 61 | assert chart_info["mark"] == "bar" 62 | assert chart_info["chart"] == "vertical bar(error)" 63 | 64 | 65 | def test_deconstruct_scatter(): 66 | with open(folder / "assets/scatter.svg", "r") as f: 67 | svg_string = f.read() 68 | chart_info, msg = deconstruct(svg_string) 69 | assert 
chart_info["mark"] == "circle" 70 | assert chart_info["chart"] == "scatter" 71 | assert len(chart_info["encoding"].keys()) == 2 72 | assert ( 73 | chart_info["title"] 74 | == "Correlation between IMDB Rating and Rotten Tomatoes Rating in Action Movies" 75 | ) 76 | assert chart_info["encoding"]["x"]["title"] == "IMDB Rating" 77 | assert chart_info["encoding"]["y"]["title"] == "Rotten Tomatoes Rating" 78 | assert len(chart_info["data"]) == 309 79 | 80 | 81 | def test_deconstruct_scatter1(): 82 | with open(folder / "assets/scatter_1.svg", "r") as f: 83 | svg_string = f.read() 84 | chart_info, msg = deconstruct(svg_string) 85 | assert chart_info["mark"] == "circle" 86 | assert chart_info["chart"] == "scatter" 87 | assert len(chart_info["encoding"].keys()) == 2 88 | assert len(chart_info["data"]) == 5 89 | 90 | 91 | def test_deconstruct_scatter2(): 92 | with open(folder / "assets/scatter_2.svg", "r") as f: 93 | svg_string = f.read() 94 | chart_info, msg = deconstruct(svg_string) 95 | assert chart_info["mark"] == "circle" 96 | assert chart_info["chart"] == "scatter" 97 | assert len(chart_info["encoding"].keys()) == 2 98 | assert len(chart_info["data"]) == 100 99 | 100 | 101 | def test_deconstruct_scatter4(): 102 | with open(folder / "assets/scatter_4.svg", "r") as f: 103 | svg_string = f.read() 104 | chart_info, msg = deconstruct(svg_string) 105 | assert chart_info["mark"] == "circle" 106 | assert chart_info["chart"] == "scatter" 107 | assert len(chart_info["encoding"].keys()) == 2 108 | assert len(chart_info["data"]) == 3 109 | 110 | 111 | def test_deconstruct_scatter6(): 112 | with open(folder / "assets/scatter_6.svg", "r") as f: 113 | svg_string = f.read() 114 | chart_info, msg = deconstruct(svg_string) 115 | assert chart_info["mark"] == "circle" 116 | assert chart_info["chart"] == "scatter" 117 | assert len(chart_info["encoding"].keys()) == 2 118 | assert len(chart_info["data"]) == 69 119 | 120 | 121 | def test_deconstruct_scatter_lida(): 122 | with open(folder / 
"assets/scatter_62.svg", "r") as f: 123 | svg_string = f.read() 124 | chart_info, msg = deconstruct(svg_string) 125 | assert chart_info["mark"] == "circle" 126 | assert chart_info["chart"] == "scatter" 127 | assert len(chart_info["encoding"].keys()) == 2 128 | assert len(chart_info["data"]) == 9 129 | 130 | 131 | def test_deconstruct_scatter_lida2(): 132 | with open(folder / "assets/scatter_292.svg", "r") as f: 133 | svg_string = f.read() 134 | chart_info, msg = deconstruct(svg_string) 135 | assert chart_info["mark"] == "circle" 136 | assert chart_info["chart"] == "scatter" 137 | assert len(chart_info["encoding"].keys()) == 2 138 | assert len(chart_info["data"]) == 3 139 | 140 | 141 | def test_deconstruct_scatter_lida3(): 142 | with open(folder / "assets/scatter_674.svg", "r") as f: 143 | svg_string = f.read() 144 | chart_info, msg = deconstruct(svg_string) 145 | print(chart_info) 146 | assert chart_info["mark"] == "circle" 147 | assert chart_info["chart"] == "scatter" 148 | assert len(chart_info["encoding"].keys()) == 2 149 | assert len(chart_info["data"]) == 76 150 | 151 | 152 | # black legend 153 | def test_deconstruct_grouping_scatter7(): 154 | with open(folder / "assets/grouping_scatter_3272_0.svg", "r") as f: 155 | svg_string = f.read() 156 | chart_info, msg = deconstruct(svg_string) 157 | assert chart_info["mark"] == "circle" 158 | assert chart_info["chart"] == "grouping scatter" 159 | assert len(chart_info["encoding"].keys()) == 3 160 | assert len(chart_info["data"]) == 6 161 | 162 | 163 | def test_deconstruct_grouping_scatter(): 164 | with open(folder / "assets/grouping_scatter.svg", "r") as f: 165 | svg_string = f.read() 166 | chart_info, msg = deconstruct(svg_string) 167 | assert chart_info["mark"] == "circle" 168 | assert chart_info["chart"] == "grouping scatter" 169 | assert len(chart_info["encoding"].keys()) == 3 170 | assert len(chart_info["data"]) == 5 171 | 172 | 173 | def test_deconstruct_line(): 174 | with open(folder / "assets/line.svg", "r") as 
f: 175 | svg_string = f.read() 176 | chart_info, msg = deconstruct(svg_string) 177 | assert chart_info["mark"] == "line" 178 | assert chart_info["chart"] == "line" 179 | assert len(chart_info["encoding"].keys()) == 2 180 | assert chart_info["title"] == "Maximum Temperature by Month" 181 | assert chart_info["encoding"]["x"]["title"] == "Month" 182 | assert chart_info["encoding"]["y"]["title"] == "Temperature (°C)" 183 | assert len(chart_info["data"]) == 12 184 | 185 | 186 | def test_deconstruct_line1(): 187 | with open(folder / "assets/line_1.svg", "r") as f: 188 | svg_string = f.read() 189 | chart_info, msg = deconstruct(svg_string) 190 | assert chart_info["mark"] == "line" 191 | assert chart_info["chart"] == "line" 192 | assert len(chart_info["encoding"].keys()) == 2 193 | assert len(chart_info["data"]) == 20 194 | 195 | 196 | def test_deconstruct_line2(): 197 | with open(folder / "assets/line_2.svg", "r") as f: 198 | svg_string = f.read() 199 | chart_info, msg = deconstruct(svg_string) 200 | assert chart_info["mark"] == "line" 201 | assert chart_info["chart"] == "line" 202 | assert len(chart_info["encoding"].keys()) == 2 203 | assert len(chart_info["data"]) == 107 204 | 205 | 206 | def test_deconstruct_line3(): 207 | with open(folder / "assets/line_3.svg", "r") as f: 208 | svg_string = f.read() 209 | chart_info, msg = deconstruct(svg_string) 210 | assert chart_info["mark"] == "line" 211 | assert chart_info["chart"] == "line" 212 | assert len(chart_info["encoding"].keys()) == 2 213 | assert len(chart_info["data"]) == 10 214 | 215 | 216 | def test_deconstruct_line4(): 217 | with open(folder / "assets/line_4.svg", "r") as f: 218 | svg_string = f.read() 219 | chart_info, msg = deconstruct(svg_string) 220 | assert chart_info["mark"] == "line" 221 | assert chart_info["chart"] == "line" 222 | assert len(chart_info["encoding"].keys()) == 2 223 | assert len(chart_info["data"]) == 107 224 | 225 | 226 | # error case 227 | def test_deconstruct_line5(): 228 | with open(folder 
/ "assets/line_5.svg", "r") as f: 229 | svg_string = f.read() 230 | chart_info, msg = deconstruct(svg_string) 231 | assert chart_info["mark"] == "line" 232 | assert chart_info["chart"] == "line" 233 | assert len(chart_info["encoding"].keys()) == 2 234 | assert len(chart_info["data"]) == 0 235 | 236 | 237 | # error case 238 | def test_deconstruct_line11(): 239 | with open(folder / "assets/line_11.svg", "r") as f: 240 | svg_string = f.read() 241 | chart_info, msg = deconstruct(svg_string) 242 | assert "mark" not in chart_info 243 | 244 | 245 | # tick line 246 | def test_deconstruct_line12(): 247 | with open(folder / "assets/line_12.svg", "r") as f: 248 | svg_string = f.read() 249 | chart_info, msg = deconstruct(svg_string) 250 | assert chart_info["mark"] == "line" 251 | assert chart_info["chart"] == "line" 252 | assert len(chart_info["encoding"].keys()) == 2 253 | assert len(chart_info["data"]) == 13 254 | 255 | 256 | def test_deconstruct_line13(): 257 | with open(folder / "assets/line_13.svg", "r") as f: 258 | svg_string = f.read() 259 | chart_info, msg = deconstruct(svg_string) 260 | assert chart_info["mark"] == "line" 261 | assert chart_info["chart"] == "line" 262 | assert len(chart_info["encoding"].keys()) == 2 263 | assert len(chart_info["data"]) == 15 264 | 265 | 266 | def test_deconstruct_line14(): 267 | with open(folder / "assets/line_1746_1.svg", "r") as f: 268 | svg_string = f.read() 269 | chart_info, msg = deconstruct(svg_string) 270 | assert chart_info["mark"] == "circle" 271 | assert chart_info["chart"] == "scatter" 272 | assert len(chart_info["encoding"].keys()) == 2 273 | assert len(chart_info["data"]) == 1 274 | 275 | 276 | def test_deconstruct_grouping_line(): 277 | with open(folder / "assets/grouping_line.svg", "r") as f: 278 | svg_string = f.read() 279 | chart_info, msg = deconstruct(svg_string) 280 | assert chart_info["mark"] == "line" 281 | assert chart_info["chart"] == "grouping line" 282 | assert len(chart_info["encoding"].keys()) == 3 283 | 
assert len(chart_info["data"]) == 9 284 | 285 | 286 | def test_deconstruct_grouping_line1(): 287 | with open(folder / "assets/grouping_line_1.svg", "r") as f: 288 | svg_string = f.read() 289 | chart_info, msg = deconstruct(svg_string) 290 | assert chart_info["mark"] == "line" 291 | assert chart_info["chart"] == "grouping line" 292 | assert len(chart_info["encoding"].keys()) == 3 293 | assert len(chart_info["data"]) == 20 294 | 295 | 296 | def test_deconstruct_grouping_line4(): 297 | with open(folder / "assets/grouping_line_4.svg", "r") as f: 298 | svg_string = f.read() 299 | chart_info, msg = deconstruct(svg_string) 300 | assert len(chart_info["encoding"].keys()) == 2 301 | 302 | 303 | # error case 304 | def test_deconstruct_grouping_line2(): 305 | with open(folder / "assets/grouping_line_2.svg", "r") as f: 306 | svg_string = f.read() 307 | chart_info, msg = deconstruct(svg_string) 308 | assert chart_info["mark"] == "line" 309 | assert chart_info["chart"] == "grouping line" 310 | assert len(chart_info["encoding"].keys()) == 3 311 | assert len(chart_info["data"]) == 0 312 | 313 | 314 | def test_deconstruct_grouping_line3(): 315 | with open(folder / "assets/grouping_line_3.svg", "r") as f: 316 | svg_string = f.read() 317 | chart_info, msg = deconstruct(svg_string) 318 | assert chart_info["mark"] == "line" 319 | assert chart_info["chart"] == "grouping line" 320 | assert len(chart_info["encoding"].keys()) == 3 321 | assert len(chart_info["data"]) == 0 322 | 323 | 324 | def test_deconstruct_bar(): 325 | with open(folder / "assets/bar.svg", "r") as f: 326 | svg_string = f.read() 327 | chart_info, msg = deconstruct(svg_string) 328 | assert chart_info["mark"] == "bar" 329 | assert chart_info["chart"] == "vertical bar" 330 | assert len(chart_info["encoding"].keys()) == 2 331 | assert chart_info["title"] == "Number of Students Advised by Faculty Rank" 332 | assert chart_info["encoding"]["x"]["title"] == "Faculty Rank" 333 | assert chart_info["encoding"]["y"]["title"] == 
"Number of Students Advised" 334 | assert len(chart_info["data"]) == 3 335 | 336 | 337 | def test_deconstruct_bar1(): 338 | with open(folder / "assets/bar_1.svg", "r") as f: 339 | svg_string = f.read() 340 | chart_info, msg = deconstruct(svg_string) 341 | assert chart_info["mark"] == "bar" 342 | assert chart_info["chart"] == "horizontal bar" 343 | assert len(chart_info["encoding"].keys()) == 2 344 | assert len(chart_info["data"]) == 7 345 | 346 | 347 | def test_deconstruct_bar6(): 348 | with open(folder / "assets/bar_6.svg", "r") as f: 349 | svg_string = f.read() 350 | chart_info, msg = deconstruct(svg_string) 351 | assert chart_info["mark"] == "bar" 352 | assert len(chart_info["encoding"].keys()) == 2 353 | assert len(chart_info["data"]) == 49 354 | 355 | 356 | # 1 bar 357 | def test_deconstruct_bar7(): 358 | with open(folder / "assets/bar_2733_0.svg", "r") as f: 359 | svg_string = f.read() 360 | chart_info, msg = deconstruct(svg_string) 361 | assert chart_info["mark"] == "bar" 362 | assert chart_info["chart"] == "horizontal bar" 363 | assert len(chart_info["encoding"].keys()) == 2 364 | assert len(chart_info["data"]) == 1 365 | 366 | 367 | # bar legend error 368 | def test_deconstruct_bar8(): 369 | with open(folder / "assets/stacked_bar_2815.svg", "r") as f: 370 | svg_string = f.read() 371 | chart_info, msg = deconstruct(svg_string) 372 | assert chart_info["mark"] == "bar" 373 | assert len(chart_info["encoding"].keys()) == 2 374 | 375 | 376 | def test_deconstruct_bar9(): 377 | with open(folder / "assets/bar_186.svg", "r") as f: 378 | svg_string = f.read() 379 | chart_info, msg = deconstruct(svg_string) 380 | assert chart_info["mark"] == "bar" 381 | assert len(chart_info["encoding"].keys()) == 2 382 | assert len(chart_info["data"]) == 14 383 | 384 | 385 | def test_deconstruct_bar10(): 386 | with open(folder / "assets/bar_3269.svg", "r") as f: 387 | svg_string = f.read() 388 | chart_info, msg = deconstruct(svg_string) 389 | assert chart_info["mark"] == "bar" 390 | 
assert len(chart_info["encoding"].keys()) == 2 391 | print(chart_info) 392 | assert len(chart_info["data"]) == 6 393 | 394 | 395 | # lida bar 396 | def test_deconstruct_bar_lida1(): 397 | with open(folder / "assets/bar_9.svg", "r") as f: 398 | svg_string = f.read() 399 | chart_info, msg = deconstruct(svg_string) 400 | assert chart_info["mark"] == "bar" 401 | assert chart_info["chart"] == "vertical bar" 402 | assert len(chart_info["encoding"].keys()) == 2 403 | assert len(chart_info["data"]) == 4 404 | 405 | 406 | # lida bar 407 | def test_deconstruct_bar_lida2(): 408 | with open(folder / "assets/bar_205.svg", "r") as f: 409 | svg_string = f.read() 410 | chart_info, msg = deconstruct(svg_string) 411 | assert chart_info["mark"] == "bar" 412 | assert chart_info["chart"] == "vertical bar" 413 | assert len(chart_info["encoding"].keys()) == 2 414 | assert len(chart_info["data"]) == 2 415 | 416 | 417 | def test_deconstruct_stacked_bar(): 418 | with open(folder / "assets/stacked_bar.svg", "r") as f: 419 | svg_string = f.read() 420 | chart_info, msg = deconstruct(svg_string) 421 | assert chart_info["mark"] == "bar" 422 | assert chart_info["chart"] == "vertical stacked bar" 423 | assert len(chart_info["encoding"].keys()) == 3 424 | assert chart_info["title"] == "Faculty by Rank/Sex" 425 | assert chart_info["encoding"]["x"]["title"] == "Rank" 426 | assert chart_info["encoding"]["y"]["title"] == "Number of faculty" 427 | assert len(chart_info["data"]) == 7 428 | 429 | 430 | def test_deconstruct_stacked_bar1(): 431 | with open(folder / "assets/stacked_bar_1.svg", "r") as f: 432 | svg_string = f.read() 433 | chart_info, msg = deconstruct(svg_string) 434 | assert chart_info["mark"] == "bar" 435 | assert chart_info["chart"] == "vertical stacked bar" 436 | assert len(chart_info["encoding"].keys()) == 3 437 | assert len(chart_info["data"]) == 9 438 | 439 | 440 | def test_deconstruct_grouping_bar(): 441 | with open(folder / "assets/bar_189.svg", "r") as f: 442 | svg_string = f.read() 
443 | chart_info, msg = deconstruct(svg_string) 444 | assert chart_info["mark"] == "bar" 445 | assert chart_info["chart"] == "vertical grouping bar" 446 | assert len(chart_info["encoding"].keys()) == 3 447 | assert len(chart_info["data"]) == 4 448 | 449 | 450 | def test_deconstruct_pie(): 451 | with open(folder / "assets/pie.svg", "r") as f: 452 | svg_string = f.read() 453 | chart_info, msg = deconstruct(svg_string) 454 | assert chart_info["mark"] == "arc" 455 | assert chart_info["chart"] == "pie" 456 | assert len(chart_info["encoding"].keys()) == 2 457 | assert chart_info["title"] == "Maximum Price of Each Film" 458 | assert len(chart_info["data"]) == 5 459 | 460 | 461 | # 180 degree 462 | def test_deconstruct_pie1(): 463 | with open(folder / "assets/pie_1.svg", "r") as f: 464 | svg_string = f.read() 465 | chart_info, msg = deconstruct(svg_string) 466 | assert chart_info["mark"] == "arc" 467 | assert chart_info["chart"] == "pie" 468 | assert len(chart_info["encoding"].keys()) == 2 469 | assert len(chart_info["data"]) == 3 470 | flag = False 471 | for datum in chart_info["data"]: 472 | # = 50 percent 473 | if datum["field_theta"] == 50: 474 | flag = True 475 | assert flag 476 | 477 | 478 | # shadow and > 180 degree 479 | def test_deconstruct_pie2(): 480 | with open(folder / "assets/pie_2.svg", "r") as f: 481 | svg_string = f.read() 482 | chart_info, msg = deconstruct(svg_string) 483 | assert chart_info["mark"] == "arc" 484 | assert chart_info["chart"] == "pie" 485 | assert len(chart_info["encoding"].keys()) == 2 486 | assert len(chart_info["data"]) == 3 487 | flag = False 488 | for datum in chart_info["data"]: 489 | # > 50 percent 490 | if datum["field_theta"] > 50: 491 | flag = True 492 | assert flag 493 | 494 | 495 | # shadow and > 180 degree 496 | def test_deconstruct_pie3(): 497 | with open(folder / "assets/pie_3.svg", "r") as f: 498 | svg_string = f.read() 499 | chart_info, msg = deconstruct(svg_string) 500 | assert chart_info["mark"] == "arc" 501 | assert 
chart_info["chart"] == "pie" 502 | assert len(chart_info["encoding"].keys()) == 2 503 | assert len(chart_info["data"]) == 2 504 | flag = False 505 | for datum in chart_info["data"]: 506 | # > 50 percent 507 | if datum["field_theta"] > 50: 508 | flag = True 509 | assert flag 510 | 511 | 512 | # 360 degree 513 | def test_deconstruct_pie3(): 514 | with open(folder / "assets/pie_4.svg", "r") as f: 515 | svg_string = f.read() 516 | chart_info, msg = deconstruct(svg_string) 517 | assert chart_info["mark"] == "arc" 518 | assert chart_info["chart"] == "pie" 519 | assert len(chart_info["encoding"].keys()) == 2 520 | assert len(chart_info["data"]) == 1 521 | flag = False 522 | for datum in chart_info["data"]: 523 | # > 50 percent 524 | if datum["field_theta"] == 100: 525 | flag = True 526 | assert flag 527 | -------------------------------------------------------------------------------- /tests/test_layout_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import csv 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from viseval.check import layout_check 10 | 11 | folder = Path(__file__).resolve().parent 12 | 13 | webdriver_path = "/usr/bin/chromedriver" 14 | 15 | 16 | def test_layout_check(): 17 | with open(folder / "assets/readability/145@y_name@DESC.svg", "r") as svg_file: 18 | svg_string = svg_file.read() 19 | assert layout_check({"svg_string": svg_string}, webdriver_path) == ( 20 | False, 21 | "Overflow detected.", 22 | ) 23 | 24 | 25 | def test_layout_check_2(): 26 | with open(folder / "assets/readability/2350.svg", "r") as svg_file: 27 | svg_string = svg_file.read() 28 | assert layout_check({"svg_string": svg_string}, webdriver_path) == ( 29 | True, 30 | "No overflow or overlap detected.", 31 | ) 32 | 33 | 34 | def test_layout_check_3(): 35 | with open(folder / "assets/readability/2652.svg", "r") as svg_file: 36 | svg_string = svg_file.read() 37 | assert layout_check({"svg_string": svg_string}, webdriver_path) == ( 38 | False, 39 | "Overflow detected.", 40 | ) 41 | 42 | 43 | def test_layout_check_batch(): 44 | with open(folder / "assets/readability/readability_human_rating.csv") as f: 45 | reader = csv.reader(f) 46 | next(reader) 47 | for row in reader: 48 | svg_id = row[0] 49 | 50 | with open(folder / f"assets/readability/{svg_id}.svg", "r") as svg_file: 51 | svg_string = svg_file.read() 52 | print(svg_id) 53 | assert not layout_check({"svg_string": svg_string}, webdriver_path)[0] == ( 54 | False if row[5] == "FALSE" else True 55 | ) 56 | -------------------------------------------------------------------------------- /tests/test_order_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import json 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from viseval.check import data_check, deconstruct, order_check 10 | 11 | folder = Path(__file__).resolve().parent 12 | 13 | with open(folder / "assets/samples.json") as f: 14 | benchmark = json.load(f) 15 | 16 | 17 | def test_order_check_line_773(): 18 | # x ascending 19 | with open(folder / "assets/line_773_0.svg", "r") as f: 20 | svg_string = f.read() 21 | chart_info, msg = deconstruct(svg_string) 22 | # yyyy-mm-dd 23 | ground_truth = benchmark["773@x_name@DESC"]["vis_obj"] 24 | query_meta = benchmark["773@x_name@DESC"]["query_meta"] 25 | answer, rationale = data_check( 26 | chart_info, ground_truth, query_meta[0]["channel_specified"] 27 | ) 28 | 29 | answer, rationale = order_check( 30 | chart_info, ground_truth, query_meta[0]["sort_by"] 31 | ) 32 | assert answer is True 33 | 34 | ground_truth["sort"]["order"] = "descending" 35 | answer, rationale = order_check( 36 | chart_info, ground_truth, query_meta[0]["sort_by"] 37 | ) 38 | assert answer is False 39 | 40 | 41 | def test_order_check_bar_680(): 42 | # y descending 43 | with open(folder / "assets/stacked_bar_680_0.svg", "r") as f: 44 | svg_string = f.read() 45 | chart_info, msg = deconstruct(svg_string) 46 | ground_truth = benchmark["680@y_name@DESC"]["vis_obj"] 47 | query_meta = benchmark["680@y_name@DESC"]["query_meta"] 48 | answer, rationale = data_check( 49 | chart_info, ground_truth, query_meta[0]["channel_specified"] 50 | ) 51 | 52 | answer, rationale = order_check( 53 | chart_info, ground_truth, query_meta[0]["sort_by"] 54 | ) 55 | assert answer is True 56 | 57 | ground_truth["sort"]["order"] = "ascending" 58 | answer, rationale = order_check( 59 | chart_info, ground_truth, query_meta[0]["sort_by"] 60 | ) 61 | assert answer is False 62 | 63 | 64 | def test_order_check_bar_2815(): 65 | # x descending 66 | # swap channel x-z 67 | with open(folder / "assets/stacked_bar_2815_0.svg", "r") as f: 68 | svg_string = f.read() 69 | 
chart_info, msg = deconstruct(svg_string) 70 | ground_truth = benchmark["2815@x_name@DESC"]["vis_obj"] 71 | query_meta = benchmark["2815@x_name@DESC"]["query_meta"] 72 | answer, rationale = data_check( 73 | chart_info, ground_truth, query_meta[0]["channel_specified"] 74 | ) 75 | 76 | answer, rationale = order_check( 77 | chart_info, ground_truth, query_meta[0]["sort_by"] 78 | ) 79 | assert answer is True 80 | 81 | ground_truth["sort"]["order"] = "ascending" 82 | answer, rationale = order_check( 83 | chart_info, ground_truth, query_meta[0]["sort_by"] 84 | ) 85 | assert answer is False 86 | 87 | 88 | # horizontal bar 89 | def test_order_check_bar_2060(): 90 | # y descending 91 | # swap channel x-y 92 | with open(folder / "assets/bar_2060_0.svg", "r") as f: 93 | svg_string = f.read() 94 | chart_info, msg = deconstruct(svg_string) 95 | ground_truth = benchmark["2060@y_name@DESC"]["vis_obj"] 96 | query_meta = benchmark["2060@y_name@DESC"]["query_meta"] 97 | answer, rationale = data_check(chart_info, ground_truth, []) 98 | 99 | answer, rationale = order_check( 100 | chart_info, ground_truth, query_meta[0]["sort_by"] 101 | ) 102 | assert answer is False 103 | 104 | # x quantitative descending 105 | answer, rationale = order_check(chart_info, ground_truth, "attribute") 106 | assert answer is True 107 | 108 | ground_truth["sort"]["order"] = "ascending" 109 | answer, rationale = order_check(chart_info, ground_truth, "attribute") 110 | assert answer is False 111 | 112 | 113 | # horizontal bar 114 | def test_order_check_bar_1000(): 115 | # y descending 116 | # swap channel x-y 117 | with open(folder / "assets/bar_1000_0.svg", "r") as f: 118 | svg_string = f.read() 119 | chart_info, msg = deconstruct(svg_string) 120 | ground_truth = benchmark["1000@y_name@DESC"]["vis_obj"] 121 | query_meta = benchmark["1000@y_name@DESC"]["query_meta"] 122 | answer, rationale = data_check(chart_info, ground_truth, []) 123 | print(answer, rationale) 124 | answer, rationale = order_check( 125 | 
chart_info, ground_truth, query_meta[0]["sort_by"] 126 | ) 127 | assert answer is False 128 | 129 | # x quantitative descending 130 | ground_truth["sort"]["channel"] = "x" 131 | answer, rationale = order_check( 132 | chart_info, ground_truth, query_meta[0]["sort_by"] 133 | ) 134 | assert answer is True 135 | 136 | # y nominal descending 137 | ground_truth["sort"]["channel"] = "y" 138 | chart_info["encoding"]["y"]["scale"]["domain"] = [ 139 | "sony", 140 | "jcrew", 141 | "gucci", 142 | "apple", 143 | ] 144 | answer, rationale = order_check( 145 | chart_info, ground_truth, query_meta[0]["sort_by"] 146 | ) 147 | assert answer is True 148 | 149 | # custom order 150 | ground_truth["sort"]["order"] = ["sony", "jcrew", "gucci", "apple"] 151 | answer, rationale = order_check( 152 | chart_info, ground_truth, query_meta[0]["sort_by"] 153 | ) 154 | assert answer is True 155 | 156 | 157 | # invert y axis, zero value 158 | def test_order_check_bar_550(): 159 | with open(folder / "assets/bar_550_0.svg", "r") as f: 160 | svg_string = f.read() 161 | chart_info, msg = deconstruct(svg_string) 162 | ground_truth = benchmark["550@y_name@DESC"]["vis_obj"] 163 | query_meta = benchmark["550@y_name@DESC"]["query_meta"] 164 | answer, rationale = data_check( 165 | chart_info, ground_truth, query_meta[0]["channel_specified"] 166 | ) 167 | 168 | answer, rationale = order_check( 169 | chart_info, ground_truth, query_meta[0]["sort_by"] 170 | ) 171 | assert answer is True 172 | 173 | 174 | def test_order_check_bar_465(): 175 | with open(folder / "assets/bar_465.svg", "r") as f: 176 | svg_string = f.read() 177 | chart_info, msg = deconstruct(svg_string) 178 | ground_truth = benchmark["465@y_name@ASC"]["vis_obj"] 179 | query_meta = benchmark["465@y_name@ASC"]["query_meta"] 180 | answer, rationale = data_check( 181 | chart_info, ground_truth, query_meta[0]["channel_specified"] 182 | ) 183 | 184 | answer, rationale = order_check( 185 | chart_info, ground_truth, query_meta[0]["sort_by"] 186 | ) 187 | 
assert answer is False 188 | -------------------------------------------------------------------------------- /tests/test_surface_form_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from viseval.check import surface_form_check 10 | 11 | folder = Path(__file__).resolve().parent 12 | 13 | def test_valid_empty(): 14 | with open(folder / "assets/empty_3.svg", "r") as f: 15 | svg_string = f.read() 16 | answer, msg = surface_form_check(svg_string) 17 | assert answer == False 18 | 19 | 20 | def test_valid_empty(): 21 | with open(folder / "assets/pie_4_0.svg", "r") as f: 22 | svg_string = f.read() 23 | answer, msg = surface_form_check(svg_string) 24 | assert answer == True -------------------------------------------------------------------------------- /viseval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .dataset import Dataset 5 | from .evaluate import Evaluator 6 | -------------------------------------------------------------------------------- /viseval/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union

if TYPE_CHECKING:
    # Needed only for type annotations; with lazy annotations (PEP 563,
    # enabled by the __future__ import above) these heavyweight third-party
    # dependencies are no longer imported at runtime. This also replaces the
    # deprecated `from attr import dataclass` alias with the stdlib
    # dataclasses module (same field order, defaults, and constructor).
    from langchain.chat_models.base import BaseChatModel
    from llmx import TextGenerator


@dataclass
class ChartExecutionResult:
    """Response from a visualization execution."""

    # True if successful, False otherwise
    status: bool
    # Generated SVG string if status is True
    svg_string: Optional[str] = None
    # Error message if status is False
    error_msg: Optional[str] = None


class Agent(ABC):
    """Abstract interface for a code-generating visualization agent."""

    @abstractmethod
    def __init__(
        self, llm: Union[BaseChatModel, TextGenerator], config: dict = None
    ) -> None:
        pass

    @abstractmethod
    def generate(
        self, nl_query: str, tables: list[str], config: dict
    ) -> Tuple[str, dict]:
        """Generate code for the given natural language query.

        Args:
            nl_query (str): Natural language query.
            tables (list[str]): List of table file paths.
            config (dict): Generation configuration.

        Returns:
            Tuple[str, dict]: Generated code and execution context.
        """
        pass

    @abstractmethod
    def execute(
        self, code: str, context: dict, log_name: str = None
    ) -> ChartExecutionResult:
        """Execute the given code with context and return the result.

        Args:
            code (str): Code to execute.
            context (dict): Context for the code execution. Different agents
                require different contexts.
            log_name (str, optional): SVG file name. Defaults to None.

        Returns:
            ChartExecutionResult: Execution status plus SVG string or error
                message.
        """
        pass
# --------------------------------------------------------------------------------
# /viseval/check/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


def chart_check(
    chart_info: dict, chart_type_ground_truth: str, stacked_bar: bool = False
):
    """Compare the chart type detected in ``chart_info`` with the ground truth.

    Args:
        chart_info: Deconstructed chart description; the detected type is
            expected under the "chart" key.
        chart_type_ground_truth: Expected type, e.g. "pie", "bar", "line",
            "scatter", "stacked bar", "grouping line", "grouping scatter".
        stacked_bar: When the ground truth is "Stacked Bar" and this flag is
            False, a "grouping bar" rendering is also accepted.

    Returns:
        Tuple of (bool, str): verdict plus a short rationale.
    """
    if "chart" not in chart_info:
        return False, "Cannot recognize the chart type."

    detected = chart_info["chart"].lower()

    matches = chart_type_ground_truth.lower() in detected
    if not matches and chart_type_ground_truth == "Stacked Bar" and not stacked_bar:
        # A grouping bar is an acceptable substitute for a stacked bar
        # unless stacking was explicitly required.
        matches = "grouping bar" in detected

    if matches:
        return True, "Chart type is consistent with ground truth."
    return False, "Chart type is not consistent with ground truth."
3 | 4 | import copy 5 | import json 6 | 7 | from .time_utils import ( 8 | compare_time_strings, 9 | convert_month_or_weekday_to_int, 10 | is_month_or_weekday, 11 | parse_number_to_time, 12 | parse_time_to_timestamp, 13 | ) 14 | 15 | PRECISION = 0.0005 16 | 17 | 18 | def is_numeric(s): 19 | try: 20 | float(s) 21 | return True 22 | except ValueError: 23 | return False 24 | 25 | def compare_string(string, ground_truth): 26 | return string.strip().lower().startswith(ground_truth.strip().lower()) 27 | 28 | def convert_ground_truth_data(ground_truth): 29 | # ground truth data 30 | x_data = ground_truth["x_data"] 31 | y_data = ground_truth["y_data"] 32 | classify = ground_truth["classify"] 33 | 34 | data_ground_truth = [] 35 | if len(classify) == 0: 36 | x_data = x_data[0] 37 | y_data = y_data[0] 38 | for index in range(len(x_data)): 39 | data_ground_truth.append( 40 | { 41 | "field_x": x_data[index], 42 | "field_y": y_data[index], 43 | } 44 | ) 45 | else: 46 | # scatter 47 | if len(classify) == len(x_data) and len(classify) == len(y_data): 48 | for index in range(len(classify)): 49 | for index2 in range(len(x_data[index])): 50 | data_ground_truth.append( 51 | { 52 | "field_x": x_data[index][index2], 53 | "field_y": y_data[index][index2], 54 | "field_classify": classify[index], 55 | } 56 | ) 57 | # line / bar 58 | elif len(classify) == len(y_data): 59 | for index in range(len(classify)): 60 | for index2 in range(len(x_data[0])): 61 | data_ground_truth.append( 62 | { 63 | "field_x": x_data[0][index2], 64 | "field_y": y_data[index][index2], 65 | "field_classify": classify[index], 66 | } 67 | ) 68 | return data_ground_truth 69 | 70 | 71 | # return answer and rationale for the data check 72 | def compare_data(data_ground_truth, chart_info): 73 | # deep copy: avoid to change origin data 74 | data = copy.deepcopy(chart_info["data"]) 75 | encoding = chart_info["encoding"] 76 | # line chart may have different length of data because line chart might omit some values 77 | if 
(len(data) != len(data_ground_truth) and chart_info["mark"] != "line") or ( 78 | len(data) > len(data_ground_truth) 79 | ): 80 | return ( 81 | False, 82 | f"visualization data length {len(data)} != ground truth length {len(data_ground_truth)}.", 83 | ) 84 | 85 | if chart_info["mark"] == "arc": 86 | # field_fill -> field_x 87 | field_x = "field_fill" 88 | # field_theta -> field_y relative 89 | field_y = "field_theta" 90 | 91 | scale = None 92 | for datum_ground_truth in data_ground_truth: 93 | datum = [ 94 | x 95 | for x in data 96 | if compare_string(str(x[field_x]), str(datum_ground_truth["field_x"])) 97 | ] 98 | if len(datum) == 1: 99 | datum = datum[0] 100 | if scale is None: 101 | scale = datum[field_y] / datum_ground_truth["field_y"] 102 | else: 103 | if ( 104 | abs(datum[field_y] / datum_ground_truth["field_y"] - scale) 105 | > PRECISION 106 | ): 107 | return False, f"{json.dumps(datum_ground_truth)} not found\n" 108 | elif len(datum) == 0: 109 | return False, f"{json.dumps(datum_ground_truth)} not found." 110 | elif len(datum) > 1: 111 | return False, f"{json.dumps(datum_ground_truth)} found more than one." 
112 | else: 113 | if len(encoding.keys()) < len(data_ground_truth[0].keys()): 114 | return False, "too few channels\n" 115 | elif len(encoding.keys()) > 3: 116 | return False, "too many channels\n" 117 | elif len(encoding.keys()) > len(data_ground_truth[0].keys()): 118 | # only keep x and y for comparison 119 | encoding = {key: encoding[key] for key in ["x", "y"]} 120 | 121 | for datum_ground_truth in data_ground_truth: 122 | datum = data 123 | for key in encoding: 124 | if key == "x" or key == "y": 125 | field = f"field_{key}" 126 | else: 127 | field = "field_classify" 128 | if ( 129 | encoding[key]["type"] == "quantitative" 130 | or encoding[key]["type"] == "temporal" 131 | ): 132 | # e.g., Monday -> 1 133 | if is_month_or_weekday(datum_ground_truth[field]): 134 | if encoding[key]["type"] == "temporal": 135 | datum = [ 136 | x 137 | for x in datum 138 | if compare_time_strings( 139 | x["field_" + key].strip(), 140 | str(datum_ground_truth[field]).strip(), 141 | ) 142 | ] 143 | value_ground_truth = None 144 | else: 145 | value_ground_truth = convert_month_or_weekday_to_int( 146 | datum_ground_truth[field] 147 | ) 148 | else: 149 | try: 150 | value_ground_truth = float(datum_ground_truth[field]) 151 | except Exception: 152 | # not a number 153 | # try temporal 154 | try: 155 | value_ground_truth = parse_time_to_timestamp( 156 | datum_ground_truth[field] 157 | ) 158 | if value_ground_truth is None: 159 | return ( 160 | False, 161 | f"The data type of {key}({encoding[key]['type']}) is wrong.", 162 | ) 163 | 164 | if encoding[key]["type"] == "quantitative": 165 | for x in datum: 166 | x["field_" + key] = parse_number_to_time( 167 | x["field_" + key] 168 | ) 169 | except Exception: 170 | return ( 171 | False, 172 | f"The data type of {key}({encoding[key]['type']}) is wrong.", 173 | ) 174 | 175 | field_vis = ( 176 | "field_" + key 177 | if encoding[key]["type"] != "temporal" 178 | else "field_" + key + "_origin" 179 | ) 180 | if value_ground_truth is not None: 181 | # 
avoid division by zero 182 | datum = [ 183 | x 184 | for x in datum 185 | if abs((x[field_vis] - value_ground_truth)) <= PRECISION 186 | or ( 187 | value_ground_truth != 0 188 | and abs( 189 | (x[field_vis] - value_ground_truth) 190 | / value_ground_truth 191 | ) 192 | <= PRECISION 193 | ) 194 | ] 195 | elif encoding[key]["type"] == "nominal": 196 | # exact match or time match 197 | datum = [ 198 | x 199 | for x in datum 200 | if compare_string(x["field_" + key], str(datum_ground_truth[field])) 201 | or compare_time_strings( 202 | x["field_" + key].strip(), 203 | str(datum_ground_truth[field]).strip(), 204 | ) 205 | ] 206 | 207 | if len(datum) == 0: 208 | if ( 209 | chart_info["mark"] == "line" 210 | and ( 211 | encoding["x"]["type"] == "quantitative" 212 | or encoding["x"]["type"] == "temporal" 213 | ) 214 | and (encoding["y"]["type"] == "quantitative") 215 | ): 216 | # use all data 217 | datum = chart_info["data"] 218 | if len(data_ground_truth[0].keys()) == 3: 219 | datum = [ 220 | x 221 | for x in datum 222 | if x["field_stroke"].strip() 223 | == datum_ground_truth["field_classify"].strip() 224 | ] 225 | # line chart might omit some values 226 | min_larger_index = -1 227 | max_smaller_index = -1 228 | # convert 229 | field_vis = ( 230 | "field_x" 231 | if encoding["x"]["type"] != "temporal" 232 | else "field_x_origin" 233 | ) 234 | value_ground_truth = datum_ground_truth["field_x"] 235 | try: 236 | value_ground_truth = float(value_ground_truth) 237 | except Exception: 238 | value_ground_truth = parse_time_to_timestamp(value_ground_truth) 239 | 240 | for index in range(len(datum)): 241 | if datum[index][field_vis] > value_ground_truth: 242 | if ( 243 | min_larger_index == -1 244 | or datum[index][field_vis] 245 | < datum[min_larger_index][field_vis] 246 | ): 247 | min_larger_index = index 248 | elif datum[index][field_vis] < value_ground_truth: 249 | if ( 250 | max_smaller_index == -1 251 | or datum[index][field_vis] 252 | > datum[max_smaller_index][field_vis] 
253 | ): 254 | max_smaller_index = index 255 | if min_larger_index != -1 and max_smaller_index != -1: 256 | if ( 257 | abs( 258 | ( 259 | datum[min_larger_index]["field_y"] 260 | - datum[max_smaller_index]["field_y"] 261 | ) 262 | ) 263 | <= PRECISION 264 | or abs( 265 | ( 266 | datum[min_larger_index]["field_y"] 267 | - datum_ground_truth["field_y"] 268 | ) 269 | ) 270 | <= PRECISION 271 | ): 272 | continue 273 | else: 274 | return ( 275 | False, 276 | f"{json.dumps(datum_ground_truth)} not found\n", 277 | ) 278 | return False, f"{json.dumps(datum_ground_truth)} not found." 279 | else: 280 | data.remove(datum[0]) 281 | 282 | return True, "The data on the charts is consistent with the ground truth." 283 | 284 | 285 | def data_check(chart_info: dict, data: dict, channel_specified: list): 286 | if ("data" not in chart_info) or (len(chart_info["data"]) == 0): 287 | return False, "The data on the charts cannot be understood." 288 | 289 | data_ground_truth = convert_ground_truth_data(data) 290 | # filter zero data 291 | if chart_info["mark"] == "bar": 292 | data_ground_truth = list( 293 | filter(lambda x: x["field_x"] != 0 and x["field_y"] != 0, data_ground_truth) 294 | ) 295 | chart_info["data"] = list( 296 | filter( 297 | lambda x: x["field_x"] != 0 and x["field_y"] != 0, 298 | chart_info["data"], 299 | ) 300 | ) 301 | candidates = [data_ground_truth] 302 | # ground truth channel -> chart channel 303 | channel_maps = [] 304 | if len(data["classify"]) == 0: 305 | # 2 channels 306 | channel_maps.append({"x": "x", "y": "y"}) 307 | if "x" not in channel_specified and "y" not in channel_specified: 308 | # swap x and y 309 | data_ground_truth_copy = copy.deepcopy(data_ground_truth) 310 | for datum in data_ground_truth_copy: 311 | datum["field_x"], datum["field_y"] = ( 312 | datum["field_y"], 313 | datum["field_x"], 314 | ) 315 | candidates.append(data_ground_truth_copy) 316 | channel_maps.append({"x": "y", "y": "x"}) 317 | else: 318 | # 3 channels 319 | 
channel_maps.append({"x": "x", "y": "y", "classify": "classify"}) 320 | channels = ["x", "y", "classify"] 321 | if len(channel_specified) <= 1: 322 | swap_channels = list(set(channels) - set(channel_specified)) 323 | data_ground_truth_copy = copy.deepcopy(data_ground_truth) 324 | for datum in data_ground_truth_copy: 325 | ( 326 | datum[f"field_{swap_channels[0]}"], 327 | datum[f"field_{swap_channels[1]}"], 328 | ) = ( 329 | datum[f"field_{swap_channels[1]}"], 330 | datum[f"field_{swap_channels[0]}"], 331 | ) 332 | candidates.append(data_ground_truth_copy) 333 | channel_map = {"x": "x", "y": "y", "classify": "classify"} 334 | channel_map[swap_channels[0]], channel_map[swap_channels[1]] = ( 335 | channel_map[swap_channels[1]], 336 | channel_map[swap_channels[0]], 337 | ) 338 | channel_maps.append(channel_map) 339 | if len(channel_specified) == 0: 340 | # swap x and z 341 | data_ground_truth_copy = copy.deepcopy(candidates[0]) 342 | for datum in data_ground_truth_copy: 343 | datum["field_x"], datum["field_classify"] = ( 344 | datum["field_classify"], 345 | datum["field_x"], 346 | ) 347 | candidates.append(data_ground_truth_copy) 348 | channel_maps.append({"x": "classify", "y": "y", "classify": "x"}) 349 | # swap x and y, x and z 350 | data_ground_truth_copy = copy.deepcopy(candidates[1]) 351 | for datum in data_ground_truth_copy: 352 | datum["field_y"], datum["field_classify"] = ( 353 | datum["field_classify"], 354 | datum["field_y"], 355 | ) 356 | candidates.append(data_ground_truth_copy) 357 | channel_maps.append({"x": "y", "y": "classify", "classify": "x"}) 358 | cache = None 359 | for i in range(len(candidates)): 360 | candidate = candidates[i] 361 | channel_map = channel_maps[i] 362 | answer, rationale = compare_data(candidate, chart_info) 363 | if answer: 364 | chart_info["channel_map"] = channel_map 365 | return answer, rationale 366 | if i == 0: 367 | cache = [answer, rationale] 368 | 369 | return cache[0], cache[1] 370 | 
-------------------------------------------------------------------------------- /viseval/check/layout_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import uuid 6 | 7 | overflowScript = """const isElementOutsideCanvas = (element) => { 8 | // Get the bounding box of the SVG element 9 | let svgRect = element.getBBox(); 10 | 11 | // Get the viewBox attribute of the SVG element 12 | let svgElement = document.querySelector('svg'); 13 | let viewBoxValue = svgElement.getAttribute('viewBox'); 14 | let viewBoxArray = viewBoxValue.split(' ').map(Number); 15 | let minX = viewBoxArray[0]; 16 | let minY = viewBoxArray[1]; 17 | let width = viewBoxArray[2]; 18 | let height = viewBoxArray[3]; 19 | 20 | // Check if any part of the element is outside the canvas 21 | if (svgRect.x < minX || svgRect.y < minY || svgRect.x + svgRect.width > width+minX || svgRect.y + svgRect.height > height+minY) { 22 | return true; 23 | } else { 24 | return false; 25 | } 26 | } 27 | 28 | let svgElement = document.querySelector('#axes_1'); 29 | return isElementOutsideCanvas(svgElement); 30 | """ 31 | 32 | overlapScript = """const isTextOverlap = (parentElement) => { 33 | let index = 1; 34 | let flag = true; 35 | let textArray = []; 36 | while (flag) { 37 | let textElement = parentElement.querySelector('#text_' + index); 38 | if (textElement) { 39 | index++; 40 | let gElement = textElement.querySelector('g'); 41 | let transform = gElement.getAttribute('transform'); 42 | if (transform) { 43 | const rotateMatch = transform.match(/rotate\(([^)]+)\)/); 44 | let notRotate = true; 45 | if (rotateMatch) { 46 | const rotateValue = parseFloat(rotateMatch[1]) % 90; 47 | if (Math.abs(rotateValue) > 10 && Math.abs(rotateValue) < 80) { 48 | notRotate = false; 49 | } 50 | } 51 | 52 | if (notRotate) { 53 | let bbox = textElement.getBBox(); 54 | textArray.push(bbox); 55 | } 
56 | } 57 | } 58 | else { 59 | flag = false; 60 | } 61 | } 62 | let textOverlap = false; 63 | let textArrayLength = textArray.length; 64 | for (let i = 0; i < textArrayLength; i++) { 65 | for (let j = i + 1; j < textArrayLength; j++) { 66 | let text1Rect = textArray[i]; 67 | let text2Rect = textArray[j]; 68 | if (text1Rect.x < text2Rect.x + text2Rect.width && 69 | text1Rect.x + text1Rect.width > text2Rect.x && 70 | text1Rect.y < text2Rect.y + text2Rect.height && 71 | text1Rect.y + text1Rect.height > text2Rect.y) { 72 | textOverlap = true; 73 | break; 74 | } 75 | } 76 | } 77 | return textOverlap; 78 | } 79 | 80 | let svgElement = document.querySelector('#axes_1'); 81 | return isTextOverlap(svgElement) 82 | """ 83 | 84 | 85 | def layout_check(context: dict, webdriver_path): 86 | svg_string = context["svg_string"] 87 | 88 | if webdriver_path is not None: 89 | from selenium import webdriver 90 | from selenium.webdriver.chrome.service import Service 91 | 92 | # Chrome headless mode 93 | options = webdriver.ChromeOptions() 94 | options.add_argument("--headless") 95 | options.add_argument("--disable-gpu") 96 | try: 97 | service = Service(webdriver_path) 98 | webdriver = webdriver.Chrome(service=service, options=options) 99 | 100 | current_directory = os.getcwd() 101 | file_path = f"{current_directory}/temp_{uuid.uuid1()}.svg" 102 | 103 | with open(file_path, "w") as svg_file: 104 | svg_file.write(svg_string) 105 | 106 | webdriver.get(f"file://{file_path}") 107 | overflow_result = not webdriver.execute_script(overflowScript) 108 | overlap_result = not webdriver.execute_script(overlapScript) 109 | webdriver.close() 110 | try: 111 | os.remove(file_path) 112 | except Exception as e: 113 | print(f"Remove file error: {e}") 114 | 115 | if overflow_result and overlap_result: 116 | msg = "No overflow or overlap detected." 117 | elif not overflow_result and not overlap_result: 118 | msg = "Overflow and overlap detected." 
119 | elif not overflow_result: 120 | msg = "Overflow detected." 121 | else: 122 | msg = "Overlap detected." 123 | return overflow_result and overlap_result, msg 124 | except Exception as e: 125 | print(e) 126 | 127 | return None, "No webdriver path provided." 128 | -------------------------------------------------------------------------------- /viseval/check/order_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | def order_check(chart_info: dict, ground_truth: dict, sort_by: str): 6 | order = ground_truth["sort"] 7 | encoding = chart_info["encoding"] 8 | data = chart_info["data"] 9 | channel_map = chart_info["channel_map"] 10 | 11 | if order is not None: 12 | # bar, line 13 | if sort_by == "axis": 14 | order_channel = order["channel"] 15 | else: 16 | order_channel = channel_map[order["channel"]] 17 | 18 | other_channel = "y" if order_channel == "x" else "x" 19 | 20 | order_channel_scale = encoding[order_channel]["scale"] 21 | other_channel_scale = encoding[other_channel]["scale"] 22 | 23 | # origin channel 24 | if ( 25 | channel_map[order_channel] == "x" 26 | or channel_map[order_channel] == "classify" 27 | ): 28 | arr = [] 29 | if len(order_channel_scale["range"]) == 0: 30 | scale_range = range(1, 1 + len(order_channel_scale["domain"])) 31 | else: 32 | scale_range = order_channel_scale["range"] 33 | 34 | for index in range(len(order_channel_scale["domain"])): 35 | arr.append( 36 | tuple( 37 | [ 38 | order_channel_scale["domain"][index], 39 | scale_range[index], 40 | ] 41 | ) 42 | ) 43 | if order_channel == "x": 44 | reverse = True 45 | else: 46 | reverse = False 47 | if order["order"] == "ascending": 48 | arr.sort(key=lambda x: x[0], reverse=reverse) 49 | elif order["order"] == "descending": 50 | arr.sort(key=lambda x: x[0], reverse=not reverse) 51 | else: # custom order 52 | sort_order = {} 53 | for index in 
range(len(order["order"])): 54 | sort_order[order["order"][index]] = index 55 | arr.sort(key=lambda x: sort_order[x[0]], reverse=reverse) 56 | 57 | is_sorted = all([arr[i][1] > arr[i + 1][1] for i in range(len(arr) - 1)]) 58 | # 'quantitative' 59 | else: 60 | # sort by other channel 61 | values_other = [] 62 | if ( 63 | "type" not in other_channel_scale 64 | or other_channel_scale["type"] == "ordinal" 65 | ): 66 | for index in range(len(other_channel_scale["domain"])): 67 | values_other.append( 68 | tuple( 69 | [ 70 | other_channel_scale["domain"][index], 71 | other_channel_scale["range"][index], 72 | ] 73 | ) 74 | ) 75 | values_other.sort(key=lambda x: x[1]) 76 | values_other = [item[0] for item in values_other] 77 | else: 78 | values_other = list( 79 | set([datum["field_" + other_channel] for datum in data]) 80 | ) 81 | values_other.sort( 82 | reverse=True 83 | if ( 84 | other_channel_scale["domain"][1] 85 | - other_channel_scale["domain"][0] 86 | ) 87 | / ( 88 | other_channel_scale["range"][1] 89 | - other_channel_scale["range"][0] 90 | ) 91 | < 0 92 | else False 93 | ) 94 | 95 | # cumulative 96 | values_order = [] 97 | for value in values_other: 98 | data_filter = [ 99 | float(d["field_" + order_channel]) 100 | for d in data 101 | if d["field_" + other_channel] == value 102 | ] 103 | values_order.append(sum(data_filter)) 104 | 105 | # filter zero data 106 | if chart_info["mark"] == "bar": 107 | values_order = list(filter(lambda x: x != 0, values_order)) 108 | 109 | is_sorted = True 110 | if order["order"] == "ascending": 111 | is_sorted = all( 112 | [ 113 | values_order[i] <= values_order[i + 1] 114 | for i in range(len(values_order) - 1) 115 | ] 116 | ) 117 | elif order["order"] == "descending": 118 | is_sorted = all( 119 | [ 120 | values_order[i] >= values_order[i + 1] 121 | for i in range(len(values_order) - 1) 122 | ] 123 | ) 124 | 125 | if not is_sorted: 126 | return False, "Doesn't sort." 127 | else: 128 | return True, "Sorted." 
129 | else: 130 | return True, "No sort." 131 | -------------------------------------------------------------------------------- /viseval/check/readability_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from __future__ import annotations 5 | 6 | import json 7 | import sys 8 | import warnings 9 | 10 | from langchain.chat_models.base import BaseChatModel 11 | from langchain.schema import HumanMessage, SystemMessage 12 | 13 | INSTRUCTION = """Your task is to evaluate the readability of the visualization on a scale of 1 to 5, where 1 indicates very difficult to read and 5 indicates very easy to read. You will be given a visualization requirement and the corresponding visualization created based on that requirement. Additionally, reviews from others regarding this visualization will be provided for your reference. Please think carefully and provide your reasoning and score. 14 | ``` 15 | { 16 | "Rationale": "a brief reason", 17 | "Score": 1-5 18 | } 19 | ``` 20 | 21 | 22 | Examples: 23 | - If the visualization is clear and information can be easily interpreted, you might return: 24 | ``` 25 | { 26 | "Rationale": "The chart is well-organized, and the use of contrasting colors helps in distinguishing different data sets effectively. 
The labels are legible, and the key insights can be understood at a glance.", 27 | "Score": 5 28 | } 29 | ``` 30 | - Conversely, if the visualization is cluttered or confusing, you might return: 31 | ``` 32 | { 33 | "Rationale": "While there is no overflow or overlap, the unconventional inverted y-axis and the use of decimal numbers for months on the x-axis deviate from the standard interpretation of bar charts, confusing readers and significantly affecting the chart's readability.", 34 | "Score": 1 35 | } 36 | ``` 37 | """ 38 | 39 | 40 | def readability_check(context: dict, query: str, vision_model: BaseChatModel): 41 | base64 = context["base64"] 42 | 43 | reviews = "" 44 | if "reviews" in context and len(context["reviews"]) > 0: 45 | reviews = "Other Reviews:\n" 46 | reviews += "\n".join( 47 | [ 48 | f"""- {review["aspect"]}: {review["content"]}""" 49 | for review in context["reviews"] 50 | ] 51 | ) 52 | reviews += "\n\n" 53 | 54 | try: 55 | messages = [ 56 | SystemMessage(content=INSTRUCTION), 57 | ] 58 | messages.append( 59 | HumanMessage( 60 | content=[ 61 | { 62 | "type": "text", 63 | "text": f"""Visualization Requirement: {query}\n\n{reviews}Visualization image:""", 64 | }, 65 | { 66 | "type": "image_url", 67 | "image_url": base64, 68 | }, 69 | { 70 | "type": "text", 71 | "text": """Please assess the readability, taking into account factors such as layout, scale and ticks, title and labels, colors, and ease of extracting information. 
Do not consider the correctness of the data and order in the visualizations, as they have already been verified.""", 72 | }, 73 | ] 74 | ) 75 | ) 76 | 77 | response = vision_model.invoke(messages) 78 | 79 | json_string = ( 80 | response.content.replace("```json\n", "").replace("```", "").strip() 81 | ) 82 | try: 83 | result = json.loads(json_string) 84 | except Exception: 85 | result = eval(json_string) 86 | 87 | return result["Score"], result["Rationale"] 88 | except Exception: 89 | warnings.warn(str(sys.exc_info())) 90 | return None, "Exception occurred." 91 | -------------------------------------------------------------------------------- /viseval/check/scale_and_ticks_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from __future__ import annotations 5 | 6 | import json 7 | import sys 8 | import warnings 9 | 10 | from langchain.chat_models.base import BaseChatModel 11 | from langchain.schema import HumanMessage, SystemMessage 12 | 13 | INSTRUCTION = """You will be provided with a visualization and its specifications. Consider the following aspect: 14 | 15 | - If the scale selected for the visualization is appropriate for accurate interpretation of values, avoid using unconventional scales, such as an inverted y-axis scale. 16 | - When axes are present, ensure that the selected ticks are appropriate for clarity, avoiding unconventional choices, such as representing counts of individual entities with floating-point ticks. 17 | 18 | 19 | Report your findings, focusing solely on scale and tick appropriateness without considering the order. 20 | ``` 21 | { 22 | "Appropriate": true/false, 23 | "Rationale": "reason ..." 
24 | } 25 | ``` 26 | """ 27 | 28 | 29 | def scale_and_ticks_check(context: dict, query: str, vision_model: BaseChatModel): 30 | base64 = context["base64"] 31 | encoding = context["encoding"] 32 | chart = context["chart"] 33 | if chart == "pie": 34 | ticks_desc = "" 35 | else: 36 | x_ticks = encoding["x"]["scale"]["ticks"] 37 | y_ticks = encoding["y"]["scale"]["ticks"] 38 | ticks_desc = f"Ticks extracted from the visualization:\n- x axis ticks: {','.join(x_ticks)}\n- y axis ticks: {','.join(y_ticks)}\n\n" 39 | try: 40 | messages = [ 41 | SystemMessage(content=INSTRUCTION), 42 | ] 43 | messages.append( 44 | HumanMessage( 45 | content=[ 46 | { 47 | "type": "text", 48 | "text": f"Visualization specification:{query}\n\n{ticks_desc}Visualization image that has been verified for DATA and ORDER accuracy:", 49 | }, 50 | { 51 | "type": "image_url", 52 | "image_url": base64, 53 | }, 54 | ] 55 | ) 56 | ) 57 | 58 | response = vision_model.invoke(messages) 59 | 60 | json_string = ( 61 | response.content.replace("```json\n", "").replace("```", "").strip() 62 | ) 63 | try: 64 | result = json.loads(json_string) 65 | except Exception: 66 | result = eval(json_string) 67 | return result["Appropriate"], result["Rationale"] 68 | except Exception: 69 | warnings.warn(str(sys.exc_info())) 70 | return None, "Exception occurred." 71 | -------------------------------------------------------------------------------- /viseval/check/surface_form_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | from xml.dom import minidom 4 | 5 | def is_group_children(node): 6 | return node.nodeType == node.ELEMENT_NODE and node.tagName == "g" 7 | 8 | def surface_form_check(svg_string): 9 | """ 10 | Check if the code has plotted visualization. 
11 | """ 12 | doc = minidom.parseString(svg_string) 13 | svg = doc.getElementsByTagName("svg")[0] 14 | 15 | children = list(filter(lambda node: is_group_children(node), svg.childNodes)) 16 | if len(children) == 0: 17 | return False, "Did not plot visualization." 18 | 19 | if len(children) == 1: 20 | children = list(filter(lambda node: is_group_children(node), children[0].childNodes)) 21 | if len(children) < 2: 22 | return False, "Did not plot visualization." 23 | 24 | return True, "Plotted visualization." -------------------------------------------------------------------------------- /viseval/check/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from datetime import datetime 5 | 6 | from dateutil import parser 7 | 8 | TIME_MAP = { 9 | "mon": "monday", 10 | "tue": "tuesday", 11 | "wed": "wednesday", 12 | "thu": "thursday", 13 | "fri": "friday", 14 | "sat": "saturday", 15 | "sun": "sunday", 16 | "jan": "january", 17 | "feb": "february", 18 | "mar": "march", 19 | "apr": "april", 20 | "may": "may", 21 | "jun": "june", 22 | "jul": "july", 23 | "aug": "august", 24 | "sep": "september", 25 | "sept": "september", 26 | "oct": "october", 27 | "nov": "november", 28 | "dec": "december", 29 | "mon": "monday", 30 | "tue": "tuesday", 31 | "wed": "wednesday", 32 | "thu": "thursday", 33 | "thur": "thursday", 34 | "fri": "friday", 35 | "sat": "saturday", 36 | "sun": "sunday", 37 | } 38 | WEEKDAYS = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"] 39 | MONTHS = [ 40 | "jan", 41 | "feb", 42 | "mar", 43 | "apr", 44 | "may", 45 | "jun", 46 | "jul", 47 | "aug", 48 | "sep", 49 | "oct", 50 | "nov", 51 | "dec", 52 | ] 53 | 54 | 55 | def is_month_or_weekday(s: str): 56 | if isinstance(s, str): 57 | if s.lower() in TIME_MAP or ( 58 | s[0:3].lower() in TIME_MAP and s.lower() == TIME_MAP[s[0:3].lower()] 59 | ): 60 | return True 61 | return False 62 | 63 | 64 
def convert_month_or_weekday_to_int(s: str) -> int:
    """Map a month/weekday name or abbreviation to its 1-based index.

    Weekdays map to 1-7 (Monday=1), months to 1-12; returns -1 when *s* is
    not recognized as a month or weekday.
    """
    if is_month_or_weekday(s):
        prefix = s[0:3].lower()
        if prefix in WEEKDAYS:
            return WEEKDAYS.index(prefix) + 1
        if prefix in MONTHS:
            return MONTHS.index(prefix) + 1
    return -1


def is_datetime(s):
    """Return True if *s* parses as a date/time (months/weekdays excluded)."""
    # consider month and weekday as nominal
    if is_month_or_weekday(s):
        return False
    try:
        parser.parse(s)
        return True
    except ValueError:
        return False


def check_time_format(time_str, time_format):
    """Return True if *time_str* exactly matches the strptime *time_format*."""
    try:
        datetime.strptime(time_str, time_format)
        return True
    except ValueError:
        return False


def parse_time_to_timestamp(time_str):
    """Parse a (possibly partial) date string into a POSIX timestamp.

    Bare years/months/days are padded with 00:00:10 rather than midnight
    because 0:00 is prone to bias (e.g. rounding into the previous day).
    Returns None when the string cannot be parsed.
    """
    if check_time_format(time_str, "%Y"):
        time_str = time_str + "-01-01 00:00:10"
    elif check_time_format(time_str, "%Y-%m"):
        time_str = time_str + "-01 00:00:10"
    elif check_time_format(time_str, "%Y-%m-%d"):
        time_str = time_str + " 00:00:10"

    try:
        parsed_time = parser.parse(time_str)
        return parsed_time.timestamp()
    except Exception:
        return None


def parse_timestamp_to_time(timestamp):
    """Format a POSIX timestamp as YYYY-MM-DD, or None on failure."""
    # todo: extract date format
    date_format = "%Y-%m-%d"
    try:
        parsed_time = datetime.fromtimestamp(timestamp)
        return parsed_time.strftime(date_format)
    except Exception:
        return None


# handle case like 2008.436089 (a fractional year)
def parse_number_to_time(number):
    if number > 0 and number < 2999:
        timestamp = parse_time_to_timestamp(str(int(number)))
        timestamp += (number - int(number)) * 365 * 24 * 60 * 60
        return timestamp
    return number


def compare_time_strings(time_str1: str, time_str2: str):
    """Return True if two strings plausibly denote the same point in time.

    Tries, in order: exact timestamp equality, abbreviation expansion via
    TIME_MAP, name-vs-numeric-index equality (e.g. "March" == "3"), and
    date-vs-weekday/month-name equality (e.g. "2020-01-06" == "Monday").
    """
    try:
        if parser.parse(time_str1).timestamp() == parser.parse(time_str2).timestamp():
            return True
    except Exception:
        pass

    try:
        str1 = TIME_MAP.get(time_str1.lower(), time_str1)
        str2 = TIME_MAP.get(time_str2.lower(), time_str2)

        if str1.lower() == str2.lower():
            return True

        # Name vs numeric index, e.g. "March" == "3".
        # (The original repeated this pair of checks verbatim a second
        # time; the duplicate copy has been removed — one suffices.)
        if (
            is_month_or_weekday(str1)
            and str(convert_month_or_weekday_to_int(str1)) == str2
        ):
            return True
        if (
            is_month_or_weekday(str2)
            and str(convert_month_or_weekday_to_int(str2)) == str1
        ):
            return True

        if is_month_or_weekday(str2) and parse_time_to_timestamp(time_str1) is not None:
            # weekday, e.g. "2020-01-06" vs "Monday"
            if (
                str2.lower()
                == datetime.fromtimestamp(parse_time_to_timestamp(time_str1))
                .strftime("%A")
                .lower()
            ):
                return True
            # month, e.g. "2020-03-15" vs "March"
            if (
                str2.lower()
                == datetime.fromtimestamp(parse_time_to_timestamp(time_str1))
                .strftime("%B")
                .lower()
            ):
                return True
        if is_month_or_weekday(str1) and parse_time_to_timestamp(time_str2) is not None:
            # weekday
            if (
                str1.lower()
                == datetime.fromtimestamp(parse_time_to_timestamp(time_str2))
                .strftime("%A")
                .lower()
            ):
                return True
            # month
            if (
                str1.lower()
                == datetime.fromtimestamp(parse_time_to_timestamp(time_str2))
                .strftime("%B")
                .lower()
            ):
                return True

        return False
    except Exception:
        return False
# --------------------------------------------------------------------------------
# /viseval/dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from pathlib import Path


class Dataset:
    """Loads a VisEval benchmark folder and yields its instances.

    Reads the benchmark dictionary (optionally restricted to single- or
    multiple-table queries) plus the database-to-tables index, and exposes
    `self.benchmark`, a one-shot generator of instance dicts annotated with
    their "id" and the CSV paths of the tables they reference.
    """

    def __init__(
        self,
        folder: Path,
        table_type: str = "all",
        with_irrelevant_tables: bool = False,
    ):
        self.folder = folder
        # "single"/"multiple" select a filtered dictionary; anything else
        # (e.g. "all") falls back to the full visEval.json.
        suffix = "_" + table_type if table_type in ["single", "multiple"] else ""
        with open(folder / f"visEval{suffix}.json") as fp:
            self.dict = json.load(fp)

        with open(folder / "databases/db_tables.json") as fp:
            self.db_tables = json.load(fp)

        def _iter_instances():
            # Annotate each spec in place with its id and table paths,
            # then yield it.
            for instance_id in list(self.dict):
                entry = self.dict[instance_id]
                entry["id"] = instance_id
                entry["tables"] = self.__get_tables(
                    instance_id, with_irrelevant_tables
                )
                yield entry

        self.benchmark = _iter_instances()

    def __get_tables(self, id: str, with_irrelevant_tables: bool = False):
        """Return CSV paths for the tables referenced by instance *id*.

        A table counts as referenced when its (lower-cased) name appears as a
        whitespace-separated token of the instance's VQL query. Optionally
        appends the instance's pre-computed irrelevant tables.
        """
        spec = self.dict[id]
        db_id = spec["db_id"]
        vql_tokens = spec["vis_query"]["VQL"].lower().split()
        table_names = [
            name for name in self.db_tables[db_id] if name.lower() in vql_tokens
        ]

        if with_irrelevant_tables:
            table_names.extend(spec["irrelevant_tables"])

        return [
            f"{self.folder}/databases/{db_id}/{name}.csv" for name in table_names
        ]
# --------------------------------------------------------------------------------
# /viseval/evaluate.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3 | 4 | import base64 5 | import json 6 | import logging 7 | import os 8 | from typing import Union 9 | 10 | import cairosvg 11 | import pandas as pd 12 | from attr import dataclass 13 | 14 | from .check import ( 15 | chart_check, 16 | data_check, 17 | deconstruct, 18 | layout_check, 19 | order_check, 20 | readability_check, 21 | scale_and_ticks_check, 22 | ) 23 | from .dataset import Dataset 24 | 25 | 26 | @dataclass 27 | class CheckResult: 28 | answer: Union[bool, int] 29 | aspect: str 30 | rationale: str 31 | 32 | def get_json(self): 33 | return { 34 | "answer": self.answer, 35 | "aspect": self.aspect, 36 | "rationale": self.rationale, 37 | } 38 | 39 | 40 | @dataclass 41 | class EvaluationDetail: 42 | id: str 43 | results: list[list[CheckResult]] 44 | 45 | 46 | VALID_ASPECTS = ["code execution", "surface-form check"] 47 | LEGAL_ASPECTS = ["deconstruction", "chart type check", "data check", "order check"] 48 | READABILITY_ASPECT = ["layout check", "scale and ticks check", "readability check"] 49 | 50 | FAIL_ASPECTS = VALID_ASPECTS + LEGAL_ASPECTS + ["layout check", "scale and ticks check"] 51 | 52 | 53 | class EvaluationResult: 54 | dataset: Dataset 55 | details: list[EvaluationDetail] 56 | 57 | def __init__(self, dataset: Dataset, details: list[EvaluationDetail]): 58 | self.dataset = dataset 59 | self.details = details 60 | 61 | def detail_records(self) -> pd.DataFrame: 62 | records = [] 63 | for detail in self.details: 64 | id = detail.id 65 | instance_results = detail.results 66 | count = len(instance_results) 67 | record = { 68 | "id": id, 69 | "chart": self.dataset.dict[id]["chart"], 70 | "hardness": self.dataset.dict[id]["hardness"], 71 | } 72 | 73 | # fail rate 74 | for aspect in FAIL_ASPECTS: 75 | evaluate_result = [ 76 | ( 77 | all( 78 | [ 79 | item.answer 80 | for item in query_results 81 | if item.aspect == aspect 82 | ] 83 | ) 84 | ) 85 | for query_results in instance_results 86 | ] 87 | fail_result = [item for item in evaluate_result if not item] 88 
| record[f"{aspect}_fail_rate"] = len(fail_result) / count 89 | 90 | high_level_dimensions = [ 91 | ["invalid_rate", VALID_ASPECTS], 92 | ["illegal rate", LEGAL_ASPECTS], 93 | ] 94 | pass_count = count 95 | for dimension in high_level_dimensions: 96 | evaluate_result = [ 97 | ( 98 | all( 99 | [ 100 | item.answer 101 | for item in query_results 102 | if (item.aspect in dimension[1]) 103 | ] 104 | ) 105 | ) 106 | for query_results in instance_results 107 | ] 108 | false_count = len([item for item in evaluate_result if not item]) 109 | record[dimension[0]] = false_count / count 110 | pass_count -= false_count 111 | records.append(record) 112 | 113 | # pass rate 114 | record["pass_rate"] = pass_count / count 115 | records.append(record) 116 | 117 | # readability score 118 | evaluate_result = [ 119 | ( 120 | sum( 121 | [ 122 | item.answer 123 | for item in query_results 124 | if item.aspect == "readability check" 125 | ] 126 | ) 127 | ) 128 | for query_results in instance_results 129 | ] 130 | if pass_count > 0: 131 | record["readability_score"] = sum(evaluate_result) / pass_count 132 | 133 | record["quality_score"] = sum(evaluate_result) / count 134 | 135 | return pd.DataFrame(records) 136 | 137 | def score(self): 138 | records = self.detail_records() 139 | metrics = [ 140 | "invalid_rate", 141 | "illegal rate", 142 | "pass_rate", 143 | "readability_score", 144 | "quality_score", 145 | ] 146 | score = {} 147 | for metric in metrics: 148 | score[metric] = records[metric].mean() 149 | 150 | for key in records.keys(): 151 | if ( 152 | key not in metrics 153 | and key != "id" 154 | and key != "chart" 155 | and key != "hardness" 156 | ): 157 | score[key] = records[key].mean() 158 | 159 | return score 160 | 161 | 162 | def convert_svg_to_base64(svg_string): 163 | png_string = cairosvg.svg2png(bytestring=svg_string) 164 | base64_encoded = base64.b64encode(png_string).decode("utf-8") 165 | return f"data:image/png;base64,{base64_encoded}" 166 | 167 | 168 | class Evaluator: 169 
| def __init__(self, webdriver_path=None, vision_model=None): 170 | self.webdriver_path = webdriver_path 171 | self.vision_model = vision_model 172 | 173 | def evaluate(self, agent, dataset, config): 174 | use_logs = False 175 | evaluation_details = [] 176 | if "logs" in config: 177 | log_folder = config["logs"] 178 | isExists = os.path.exists(log_folder) 179 | try: 180 | if not isExists: 181 | os.makedirs(log_folder) 182 | logging.basicConfig( 183 | level=logging.INFO, 184 | filename=log_folder / "evaluation.log", 185 | filemode="a", 186 | format="%(levelname)s: %(message)s", 187 | ) 188 | use_logs = True 189 | except Exception as e: 190 | print(e) 191 | 192 | for instance in dataset.benchmark: 193 | codes = [] 194 | instance_results = [] 195 | nl_queries = instance["nl_queries"] 196 | tables = instance["tables"] 197 | 198 | if use_logs: 199 | instanceFolder = log_folder / instance["id"] 200 | path = instanceFolder / "result.json" 201 | if os.path.exists(path): 202 | with open(path, "r") as f: 203 | data = json.load(f) 204 | if "codes" in data and "evaluations" in data: 205 | instance_results = [] 206 | for query_result in data["evaluations"]: 207 | results = [ 208 | CheckResult( 209 | answer=result["answer"], 210 | aspect=result["aspect"], 211 | rationale=result["rationale"], 212 | ) 213 | for result in query_result 214 | ] 215 | instance_results.append(results) 216 | evaluation_details.append( 217 | EvaluationDetail(instance["id"], instance_results) 218 | ) 219 | continue 220 | else: 221 | logging.info(f"Instance ({instance['id']}) evaluation began.") 222 | isExists = os.path.exists(instanceFolder) 223 | if not isExists: 224 | os.makedirs(instanceFolder) 225 | 226 | for index in range(len(nl_queries)): 227 | nl_query = nl_queries[index] 228 | if index < len(codes): 229 | code = codes[index] 230 | context = {} 231 | context["tables"] = tables 232 | else: 233 | code, context = agent.generate(nl_query, tables, config) 234 | codes.append(code) 235 | if code is None: 
236 | results = [ 237 | CheckResult( 238 | answer=False, 239 | aspect="generation", 240 | rationale="Code generation failed.", 241 | ) 242 | ] 243 | else: 244 | context["library"] = config["library"] 245 | if use_logs: 246 | results = self.validity_check( 247 | code, context, agent, instanceFolder / f"{index}.svg" 248 | ) 249 | else: 250 | results = self.validity_check(code, context, agent) 251 | 252 | pass_validity = all([result.answer for result in results]) 253 | if pass_validity: 254 | ground_truth = { 255 | "chart": instance["chart"], 256 | "vis_obj": instance["vis_obj"], 257 | "meta_info": instance["query_meta"][index], 258 | } 259 | results += self.legality_check(context, ground_truth) 260 | 261 | pass_legality = all([result.answer for result in results]) 262 | if pass_legality: 263 | results += self.readability_evaluate(context, nl_query) 264 | 265 | instance_results.append(results) 266 | 267 | evaluation_details.append( 268 | EvaluationDetail(instance["id"], instance_results) 269 | ) 270 | if use_logs: 271 | logging.info(f"Instance ({instance['id']}) evaluation finished.") 272 | # convert CheckResult to json 273 | instance_results = [ 274 | [result.get_json() for result in results] 275 | for results in instance_results 276 | ] 277 | with open(log_folder / (instance["id"] + "/result.json"), "w") as f: 278 | f.write( 279 | json.dumps({"codes": codes, "evaluations": instance_results}) 280 | ) 281 | f.close() 282 | return EvaluationResult(dataset, evaluation_details) 283 | 284 | def execute(self, code, context, agent, log_name=None) -> CheckResult: 285 | result = agent.execute(code, context, log_name) 286 | if result.status is False: 287 | return CheckResult( 288 | answer=False, aspect="code execution", rationale=result.error_msg 289 | ) 290 | 291 | context["svg_string"] = result.svg_string 292 | return CheckResult( 293 | answer=True, 294 | aspect="code execution", 295 | rationale="Code executed successfully.", 296 | ) 297 | 298 | def surface_form_check(self, 
context) -> CheckResult: 299 | svg_string = context["svg_string"] 300 | answer, rationale = surface_form_check(svg_string) 301 | return CheckResult( 302 | answer=answer, 303 | aspect="surface-form check", 304 | rationale=rationale, 305 | ) 306 | 307 | def validity_check(self, code, context, agent, log_name=None) -> list[CheckResult]: 308 | results = [] 309 | result = self.execute(code, context, agent, log_name) 310 | results.append(result) 311 | if result.answer: 312 | result = self.surface_form_check(context) 313 | results.append(result) 314 | 315 | return results 316 | 317 | def deconstruction(self, context) -> CheckResult: 318 | svg_string = context["svg_string"] 319 | library = context["library"] 320 | if library == "seaborn": 321 | library = "matplotlib" 322 | try: 323 | chart_info, msg = deconstruct(svg_string, library) 324 | if chart_info is None: 325 | return CheckResult( 326 | answer=False, 327 | aspect="deconstruction", 328 | rationale=msg, 329 | ) 330 | context.update(chart_info) 331 | return CheckResult( 332 | answer=True, 333 | aspect="deconstruction", 334 | rationale="Deconstructed the chart successfully.", 335 | ) 336 | except: 337 | return CheckResult( 338 | answer=False, 339 | aspect="deconstruction", 340 | rationale="Cannot parse the visualization.", 341 | ) 342 | 343 | def chart_type_check(self, context, ground_truth) -> CheckResult: 344 | answer, rationale = chart_check( 345 | context, 346 | ground_truth["chart"], 347 | ( 348 | ground_truth["meta_info"]["stacked_bar"] 349 | if "stacked_bar" in ground_truth["meta_info"] 350 | else None 351 | ), 352 | ) 353 | return CheckResult( 354 | answer=answer, 355 | aspect="chart type check", 356 | rationale=rationale, 357 | ) 358 | 359 | def data_check(self, context, ground_truth) -> CheckResult: 360 | answer, rationale = data_check( 361 | context, 362 | ground_truth["vis_obj"], 363 | ground_truth["meta_info"]["channel_specified"], 364 | ) 365 | return CheckResult( 366 | answer=answer, 367 | aspect="data 
check", 368 | rationale=rationale, 369 | ) 370 | 371 | def order_check(self, context, ground_truth) -> CheckResult: 372 | answer, rationale = order_check( 373 | context, 374 | ground_truth["vis_obj"], 375 | ( 376 | ground_truth["meta_info"]["sort_by"] 377 | if "sort_by" in ground_truth["meta_info"] 378 | else None 379 | ), 380 | ) 381 | return CheckResult( 382 | answer=answer, 383 | aspect="order check", 384 | rationale=rationale, 385 | ) 386 | 387 | def legality_check(self, context, ground_truth) -> list[CheckResult]: 388 | results = [] 389 | result = self.deconstruction(context) 390 | results.append(result) 391 | if result.answer: 392 | chart_type_check_result = self.chart_type_check(context, ground_truth) 393 | data_check_result = self.data_check(context, ground_truth) 394 | results.append(chart_type_check_result) 395 | results.append(data_check_result) 396 | if data_check_result.answer and ground_truth["vis_obj"]["sort"] is not None: 397 | self.order_check(context, ground_truth) 398 | results.append(self.order_check(context, ground_truth)) 399 | 400 | return results 401 | 402 | def layout_check(self, context) -> CheckResult: 403 | assert "svg_string" in context 404 | assert self.webdriver_path is not None 405 | 406 | answer, rationale = layout_check(context, self.webdriver_path) 407 | return CheckResult( 408 | answer=answer, 409 | aspect="layout check", 410 | rationale=rationale, 411 | ) 412 | 413 | def scale_and_ticks_check(self, context, query) -> CheckResult: 414 | assert "base64" in context and "encoding" in context and "chart" in context 415 | assert self.vision_model is not None 416 | 417 | answer, rationale = scale_and_ticks_check(context, query, self.vision_model) 418 | return CheckResult( 419 | answer=answer, 420 | aspect="scale and ticks check", 421 | rationale=rationale, 422 | ) 423 | 424 | def readability_evaluate(self, context, query: str) -> list[CheckResult]: 425 | results = [] 426 | if self.webdriver_path: 427 | layout_result = 
self.layout_check(context) 428 | if layout_result.answer is not None: 429 | results.append(layout_result) 430 | 431 | if self.vision_model: 432 | context["base64"] = convert_svg_to_base64(context["svg_string"]) 433 | scale_and_ticks_result = self.scale_and_ticks_check(context, query) 434 | if scale_and_ticks_result.answer is not None: 435 | results.append(scale_and_ticks_result) 436 | 437 | aspect_format = { 438 | "layout check": "Overflow/Overlap", 439 | "scale and ticks check": "Scale/Ticks", 440 | } 441 | reviews = [ 442 | { 443 | "aspect": aspect_format[result.aspect], 444 | "content": result.rationale, 445 | } 446 | for result in results 447 | ] 448 | context["reviews"] = reviews 449 | 450 | answer, rationale = readability_check(context, query, self.vision_model) 451 | if answer is not None: 452 | readability_result = CheckResult( 453 | answer=answer, 454 | aspect="readability check", 455 | rationale=rationale, 456 | ) 457 | results.append(readability_result) 458 | 459 | return results 460 | -------------------------------------------------------------------------------- /viseval_dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/VisEval/619155231c476aaa05f1b3a5b6d79082a6bcf782/viseval_dataset.zip --------------------------------------------------------------------------------