├── .env.example
├── .flake8
├── .github
└── workflows
│ └── codeql.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── docs
├── data.md
└── images
│ └── framework.png
├── examples
├── agent
│ ├── __init__.py
│ ├── chat2vis.py
│ ├── coml4vis.py
│ ├── lida.py
│ └── utils.py
├── evaluate.py
└── model
│ └── langchain_llama.py
├── pyproject.toml
├── tests
├── assets
│ ├── bar.svg
│ ├── bar_1.svg
│ ├── bar_1000_0.svg
│ ├── bar_1071_0.svg
│ ├── bar_1129_0.svg
│ ├── bar_1129_1.svg
│ ├── bar_1129_2.svg
│ ├── bar_186.svg
│ ├── bar_189.svg
│ ├── bar_2.svg
│ ├── bar_205.svg
│ ├── bar_2060_0.svg
│ ├── bar_2733_0.svg
│ ├── bar_3.svg
│ ├── bar_3137_1.svg
│ ├── bar_3269.svg
│ ├── bar_4.svg
│ ├── bar_465.svg
│ ├── bar_5.svg
│ ├── bar_550_0.svg
│ ├── bar_6.svg
│ ├── bar_68_0.svg
│ ├── bar_9.svg
│ ├── bar_error_3227_0.svg
│ ├── dual_1499_0.svg
│ ├── empty.svg
│ ├── empty_1.svg
│ ├── empty_2.svg
│ ├── empty_3.svg
│ ├── empty_4.svg
│ ├── grouping_bar.svg
│ ├── grouping_line.svg
│ ├── grouping_line_1.svg
│ ├── grouping_line_2.svg
│ ├── grouping_line_2781_0.svg
│ ├── grouping_line_3.svg
│ ├── grouping_line_4.svg
│ ├── grouping_scatter.svg
│ ├── grouping_scatter_3272_0.svg
│ ├── horizontal_stacked_bar.svg
│ ├── horizontal_stacked_bar_1.svg
│ ├── line.svg
│ ├── line_1.svg
│ ├── line_10.svg
│ ├── line_11.svg
│ ├── line_12.svg
│ ├── line_13.svg
│ ├── line_14.svg
│ ├── line_15.svg
│ ├── line_1746_1.svg
│ ├── line_2.svg
│ ├── line_3.svg
│ ├── line_3240_0.svg
│ ├── line_4.svg
│ ├── line_5.svg
│ ├── line_6.svg
│ ├── line_7.svg
│ ├── line_773_0.svg
│ ├── line_8.svg
│ ├── line_9.svg
│ ├── pie.svg
│ ├── pie_1.svg
│ ├── pie_2.svg
│ ├── pie_3.svg
│ ├── pie_4.svg
│ ├── pie_4_0.svg
│ ├── pie_5.svg
│ ├── readability
│ │ ├── 1008.svg
│ │ ├── 1024.svg
│ │ ├── 1071.svg
│ │ ├── 1188.svg
│ │ ├── 1237@y_name@ASC.svg
│ │ ├── 1285.svg
│ │ ├── 131.svg
│ │ ├── 1314.svg
│ │ ├── 134@x_name@DESC.svg
│ │ ├── 1363.svg
│ │ ├── 1380@y_name@ASC.svg
│ │ ├── 1392.svg
│ │ ├── 1434.svg
│ │ ├── 145@y_name@DESC.svg
│ │ ├── 1491@y_name@DESC.svg
│ │ ├── 1517.svg
│ │ ├── 1530.svg
│ │ ├── 1534.svg
│ │ ├── 1630.svg
│ │ ├── 173.svg
│ │ ├── 17@y_name@DESC.svg
│ │ ├── 1961.svg
│ │ ├── 1974.svg
│ │ ├── 1992.svg
│ │ ├── 2013.svg
│ │ ├── 2019.svg
│ │ ├── 2024@y_name@DESC.svg
│ │ ├── 2027.svg
│ │ ├── 2110.svg
│ │ ├── 2174.svg
│ │ ├── 219.svg
│ │ ├── 2303.svg
│ │ ├── 2350.svg
│ │ ├── 2416.svg
│ │ ├── 246.svg
│ │ ├── 2526.svg
│ │ ├── 2571.svg
│ │ ├── 2575@y_name@ASC.svg
│ │ ├── 2615.svg
│ │ ├── 2652.svg
│ │ ├── 267.svg
│ │ ├── 2724.svg
│ │ ├── 2756@x_name@DESC.svg
│ │ ├── 2760@y_name@DESC.svg
│ │ ├── 2765@y_name@ASC.svg
│ │ ├── 279@y_name@DESC.svg
│ │ ├── 2815@x_name@DESC.svg
│ │ ├── 2841@y_name@ASC.svg
│ │ ├── 286.svg
│ │ ├── 2941.svg
│ │ ├── 2943@y_name@ASC.svg
│ │ ├── 296.svg
│ │ ├── 3014@y_name@ASC.svg
│ │ ├── 3069@x_name@ASC.svg
│ │ ├── 312.svg
│ │ ├── 3134@y_name@ASC.svg
│ │ ├── 316.svg
│ │ ├── 32.svg
│ │ ├── 3207.svg
│ │ ├── 33.svg
│ │ ├── 338.svg
│ │ ├── 342.svg
│ │ ├── 368.svg
│ │ ├── 387.svg
│ │ ├── 399.svg
│ │ ├── 4.svg
│ │ ├── 403.svg
│ │ ├── 432.svg
│ │ ├── 461.svg
│ │ ├── 464.svg
│ │ ├── 470.svg
│ │ ├── 487.svg
│ │ ├── 513.svg
│ │ ├── 517.svg
│ │ ├── 533@y_name@DESC.svg
│ │ ├── 555@y_name@DESC.svg
│ │ ├── 567.svg
│ │ ├── 58@y_name@ASC.svg
│ │ ├── 611.svg
│ │ ├── 616.svg
│ │ ├── 622.svg
│ │ ├── 65.svg
│ │ ├── 659.svg
│ │ ├── 676.svg
│ │ ├── 69.svg
│ │ ├── 693.svg
│ │ ├── 708.svg
│ │ ├── 718.svg
│ │ ├── 72.svg
│ │ ├── 744.svg
│ │ ├── 75@y_name@ASC.svg
│ │ ├── 789.svg
│ │ ├── 803.svg
│ │ ├── 81.svg
│ │ ├── 847@y_name@DESC.svg
│ │ ├── 856.svg
│ │ ├── 880.svg
│ │ ├── 9.svg
│ │ ├── 914.svg
│ │ ├── 961.svg
│ │ └── readability_human_rating.csv
│ ├── samples.json
│ ├── scatter.svg
│ ├── scatter_1.svg
│ ├── scatter_2.svg
│ ├── scatter_292.svg
│ ├── scatter_3.svg
│ ├── scatter_4.svg
│ ├── scatter_400_0.svg
│ ├── scatter_400_1.svg
│ ├── scatter_5.svg
│ ├── scatter_6.svg
│ ├── scatter_62.svg
│ ├── scatter_674.svg
│ ├── scatter_7.svg
│ ├── stacked_bar.svg
│ ├── stacked_bar_1.svg
│ ├── stacked_bar_2.svg
│ ├── stacked_bar_2750_0.svg
│ ├── stacked_bar_2750_1.svg
│ ├── stacked_bar_2815.svg
│ ├── stacked_bar_2815_0.svg
│ ├── stacked_bar_3.svg
│ ├── stacked_bar_4.svg
│ └── stacked_bar_680_0.svg
├── test_chart_check.py
├── test_data_check.py
├── test_deconstruct.py
├── test_layout_check.py
├── test_order_check.py
└── test_surface_form_check.py
├── viseval
├── __init__.py
├── agent.py
├── check
│ ├── __init__.py
│ ├── chart_check.py
│ ├── data_check.py
│ ├── deconstruct.py
│ ├── layout_check.py
│ ├── order_check.py
│ ├── readability_check.py
│ ├── scale_and_ticks_check.py
│ ├── surface_form_check.py
│ └── time_utils.py
├── dataset.py
└── evaluate.py
└── viseval_dataset.zip
/.env.example:
--------------------------------------------------------------------------------
1 | # Azure OpenAI
2 | AZURE_OPENAI_ENDPOINT=https://xxx.azure.com/
3 | OPENAI_API_KEY=XXXXXXXXX
4 | OPENAI_API_VERSION=XXXXXXXXX
5 |
6 | # Google API
7 | GOOGLE_API_KEY=XXXXXXXXX
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | select = E303, W293, W291, W292, E305, E231, E302
4 | exclude =
5 | .tox,
6 | __pycache__,
7 | *.pyc,
8 |     .env,
9 | venv*/*,
10 | .venv/*,
11 | reports/*,
12 | dist/*,
13 | node_modules/*,
14 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ "main" ]
17 | pull_request:
18 | branches: [ "main" ]
19 | schedule:
20 | - cron: '22 2 * * 4'
21 |
22 | jobs:
23 | analyze:
24 | name: Analyze (${{ matrix.language }})
25 | # Runner size impacts CodeQL analysis time. To learn more, please see:
26 | # - https://gh.io/recommended-hardware-resources-for-running-codeql
27 | # - https://gh.io/supported-runners-and-hardware-resources
28 | # - https://gh.io/using-larger-runners (GitHub.com only)
29 | # Consider using larger runners or machines with greater resources for possible analysis time improvements.
30 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
31 | timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
32 | permissions:
33 | # required for all workflows
34 | security-events: write
35 |
36 | # required to fetch internal or private CodeQL packs
37 | packages: read
38 |
39 | # only required for workflows in private repositories
40 | actions: read
41 | contents: read
42 |
43 | strategy:
44 | fail-fast: false
45 | matrix:
46 | include:
47 |         - language: python
48 |           build-mode: none
49 | # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
50 | # Use `c-cpp` to analyze code written in C, C++ or both
51 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
52 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
53 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
54 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
55 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
56 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
57 | steps:
58 | - name: Checkout repository
59 | uses: actions/checkout@v4
60 |
61 | # Initializes the CodeQL tools for scanning.
62 | - name: Initialize CodeQL
63 | uses: github/codeql-action/init@v3
64 | with:
65 | languages: ${{ matrix.language }}
66 | build-mode: ${{ matrix.build-mode }}
67 | # If you wish to specify custom queries, you can do so here or in a config file.
68 | # By default, queries listed here will override any specified in a config file.
69 | # Prefix the list here with "+" to use these queries and those in the config file.
70 |
71 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
72 | # queries: security-extended,security-and-quality
73 |
74 | # If the analyze step fails for one of the languages you are analyzing with
75 | # "We were unable to automatically build your code", modify the matrix above
76 | # to set the build mode to "manual" for that language. Then modify this step
77 | # to build your code.
78 | # ℹ️ Command-line programs to run using the OS shell.
79 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
80 | - if: matrix.build-mode == 'manual'
81 | run: |
82 | echo 'If you are using a "manual" build mode for one or more of the' \
83 | 'languages you are analyzing, replace this with the commands to build' \
84 | 'your code, for example:'
85 | echo ' make bootstrap'
86 | echo ' make release'
87 | exit 1
88 |
89 | - name: Perform CodeQL Analysis
90 | uses: github/codeql-action/analyze@v3
91 | with:
92 | category: "/language:${{matrix.language}}"
93 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Customized
2 | logs
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | # pytype static type analyzer
138 | .pytype/
139 |
140 | # Cython debug symbols
141 | cython_debug/
142 |
143 | # security
144 | .secrets/*
145 | *cache*
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: check-byte-order-marker
6 | - id: check-case-conflict
7 | - id: check-merge-conflict
8 | - id: check-symlinks
9 | - id: debug-statements
10 |
11 | - repo: https://github.com/pycqa/isort
12 | rev: 5.12.0
13 | hooks:
14 | - id: isort
15 | types: [python]
16 | args: ["--profile=black"]
17 |
18 | - repo: https://github.com/psf/black
19 | rev: 23.3.0
20 | hooks:
21 | - id: black
22 | types: [python]
23 |
24 | - repo: local
25 | hooks:
26 | - id: run-pytest
27 | name: Run pytest
28 | entry: pytest
29 | language: system
30 | types: [python]
31 | pass_filenames: false
32 | always_run: true
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | This project welcomes contributions and suggestions. Most contributions require you to
4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5 | and actually do, grant us the rights to use your contribution. For details, visit
6 | https://cla.microsoft.com.
7 |
8 | When you submit a pull request, a CLA-bot will automatically determine whether you need
9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
11 |
12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VisEval: A NL2VIS Benchmark
2 | VisEval is a benchmark designed to evaluate visualization generation methods.
3 | In this repository, we provide both the toolkit to support the benchmarking, as well as the data used for benchmarks.
4 |
5 | ## What Can VisEval Evaluate
6 |
7 | 
8 |
9 | VisEval evaluates generated visualizations from three dimensions:
10 | 1. Whether the generated code can produce the visualization.
11 | 2. Whether the generated visualization meets the query.
12 | 3. Whether the generated visualization is easy to read.
13 |
14 |
15 | ## Get Started
16 | ### Install Benchmark Toolkit
17 |
18 | ```bash
19 | pip install --upgrade vis-evaluator
20 | # or `git clone https://github.com/microsoft/VisEval.git && cd VisEval && pip install --upgrade -e .`
21 | ```
22 |
23 | ### Download Benchmark Dataset
24 | To access the dataset, please follow these steps:
25 |
26 | 1. Download the dataset from [this link](https://github.com/microsoft/VisEval/blob/main/viseval_dataset.zip).
27 | 2. Once the download is complete, unzip the file to extract the dataset contents.
28 |
29 | For additional information about the dataset, please refer to the [dataset documentation](docs/data.md).
30 |
31 | ### Usage & Examples
32 | After installation, you can use VisEval by referring to `examples/evaluate.py` or as follows:
33 |
34 |
35 | 1. **Create your generation method** by inheriting from the `Agent` Class. You can find three examples in the `examples/agent` directory.
36 | ```python
37 | from viseval.agent import Agent, ChartExecutionResult
38 |
39 | class YourAgent(Agent):
40 | def __init__(self, llm):
41 | self.llm = llm
42 |
43 | def generate(
44 | self, nl_query: str, tables: list[str], config: dict
45 | ) -> Tuple[str, dict]:
46 | """Generate code for the given natural language query."""
47 | pass
48 |
49 | def execute(
50 | self, code: str, context: dict, log_name: str = None
51 | ) -> ChartExecutionResult:
52 | """Execute the given code with context and return the result"""
53 | pass
54 | ```
55 |
56 | 2. **Configure evaluator**.
57 | ```python
58 | evaluator = Evaluator(webdriver_path, vision_model)
59 | ```
60 | (You can configure the Evaluator without a webdriver and vision model, in which case the evaluation of the readability of the generated visualizations will be skipped.)
61 |
62 | - Install webdriver.
63 | ```bash
64 | # download
65 | wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
66 | # install
67 | apt install ./google-chrome-stable_current_amd64.deb
68 | # verify
69 | google-chrome --version
70 | ```
71 |
72 | - Load vision model (e.g., GPT4-v).
73 | ```python
74 | from langchain_openai import AzureChatOpenAI
75 |
76 | import dotenv
77 | # Copy .env.example to .env and put your API keys in the file.
78 | dotenv.load_dotenv()
79 |
80 | vision_model = AzureChatOpenAI(
81 | model_name="gpt-4-turbo-v",
82 | max_retries=999,
83 | temperature=0.0,
84 | request_timeout=20,
85 | max_tokens=4096,
86 | )
87 | ```
88 |
89 | 3. **Evaluate**
90 | ```python
91 | from viseval import Dataset
92 |
93 | # Configure dataset with the benchmark dataset folder path (`folder`),
94 | # specify the number of tables required to generate visualizations (`table_type`: all, single, or multiple),
95 | # and indicate whether to include irrelevant tables (`with_irrelevant_tables`).
96 | dataset = Dataset(folder, table_type, with_irrelevant_tables)
97 |
98 | config = {"library": args.library}
99 | result = evaluator.evaluate(agent, dataset, config)
100 | score = result.score()
101 | print(f"Score: {score}")
102 | ```
103 |
104 |
105 | ## Contributing
106 |
107 | This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
108 |
109 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
110 |
111 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
112 |
113 | ## Trademarks
114 |
115 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow
116 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
117 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
118 | Any use of third-party trademarks or logos are subject to those third-party's policies.
119 |
120 | ## Privacy Statement
121 |
122 | This project has adopted the [Microsoft Privacy Statement](https://go.microsoft.com/fwlink/?LinkId=521839).
123 |
124 | ## Citation
125 |
126 | If you find that VisEval helps your research, please consider citing it:
127 | ```
128 | @misc{chen2024viseval,
129 | title={VisEval: A Benchmark for Data Visualization in the Era of Large Language Models},
130 | author={Nan Chen and Yuge Zhang and Jiahang Xu and Kan Ren and Yuqing Yang},
131 | year={2024},
132 | eprint={2407.00981},
133 | archivePrefix={arXiv},
134 | primaryClass={cs.HC},
135 | }
136 | ```
137 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. For new issues, file your bug or feature request as a new Issue.
6 |
7 | For help and questions about using this project, please consult the project readme or open an issue.
8 |
9 | ## Microsoft Support Policy
10 |
11 | Support for this project is limited to the resources listed above.
12 |
--------------------------------------------------------------------------------
/docs/data.md:
--------------------------------------------------------------------------------
1 | # VisEval Dataset
2 |
3 | ## Introduction
4 |
5 | VisEval dataset is a large-scale and high-quality dataset for Natural Language to Visualization (NL2VIS) task. It contains 2,524 (NL, VIS) pairs, supports 7 common types of visualizations and covers 146 databases.
6 |
7 | - **VisEval.json** stores the JSON format of (NL, VIS) pairs. The natural language (NL) is a sentence that describes the desired visualization. The visualization (VIS) is represented in JSON format, including chart type, data for the x-axis, y-axis, and z-axis, as well as information such as sorting requirements.
8 | - **VisEval_single.json** and **VisEval_multiple.json** store the visualizations that can be generated from a single data table and those that require processing multiple data tables, respectively.
9 | - **databases** contains 146 databases, and each database has several data tables saved in CSV format.
10 |
11 |
12 | ## Important Notes / Caveats / FAQs
13 | - The primary objective of this dataset is to serve as a benchmark for evaluating LLMs-based methods in natural language to visualization generation. This dataset is intended for research purposes only and should not be relied upon as the sole benchmark for production scenarios.
14 | - How was the data collected? The dataset was constructed based on [nvBench](https://github.com/TsinghuaDatabaseGroup/nvBench). We selected high-quality queries from the original dataset, and we corrected and annotated the dataset ourselves. More details are provided in our paper's subsection 4.1.
--------------------------------------------------------------------------------
/docs/images/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/VisEval/619155231c476aaa05f1b3a5b6d79082a6bcf782/docs/images/framework.png
--------------------------------------------------------------------------------
/examples/agent/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .chat2vis import Chat2vis
5 | from .coml4vis import CoML4VIS
6 | from .lida import Lida
7 |
--------------------------------------------------------------------------------
/examples/agent/chat2vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import sys
5 | import warnings
6 |
7 | import pandas as pd
8 | from langchain.chat_models.base import BaseChatModel
9 | from langchain.schema import HumanMessage
10 |
11 | from viseval.agent import Agent, ChartExecutionResult
12 |
13 | from .utils import show_svg
14 |
15 | MAXIMUM_SAMPLES = 10
16 |
17 | TEMPLATE = '''"""Use a dataframe called df from data_file.csv with columns {columns}.
18 |
19 | {columns_description}
20 |
21 | Label the x and y axes appropriately. Add a title. Set the fig suptitle as empty.
22 |
23 | Using Python version 3.9.12 and library {library}, create a script using the dataframe df to graph the following: {nl_query}.
24 | """
25 |
26 | {pre_code}
27 | '''
28 |
29 |
class Chat2vis(Agent):
    """Chat2VIS-style agent.

    Builds a natural-language description of the input table, prompts the LLM
    with a code-completion template, and executes the returned matplotlib /
    seaborn code to render an SVG chart.
    """

    def __init__(self, llm: BaseChatModel):
        self.llm = llm

    def table_format(self, data: pd.DataFrame):
        """Return a one-sentence-per-column English description of *data*.

        Numeric/boolean columns are described by dtype; object columns are
        probed for datetime-ness and cardinality; date/category columns get
        up to MAXIMUM_SAMPLES example values appended.
        """
        descriptions = []
        for column in data.columns:
            dtype = data[column].dtype
            description = None
            if dtype in [int, float, complex]:
                description = f"The column '{column}' is type {dtype} and contains numeric values."
            elif dtype == bool:
                description = f"The column '{column}' is type {dtype} and contains boolean values."
            elif dtype == object:
                # Check if the string column can be cast to a valid datetime.
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        pd.to_datetime(data[column], errors="raise")
                        dtype = "date"
                # pd.to_datetime raises TypeError (not only ValueError) for
                # mixed-type object columns, so catch both.
                except (ValueError, TypeError):
                    # Check if the string column has a limited number of values
                    if data[column].nunique() / len(data[column]) < 0.5:
                        dtype = "category"
                    else:
                        dtype = "string"
            # isinstance check replaces the deprecated
            # pd.api.types.is_categorical_dtype.
            elif isinstance(data[column].dtype, pd.CategoricalDtype):
                dtype = "category"
            elif pd.api.types.is_datetime64_any_dtype(data[column]):
                dtype = "date"

            if dtype == "date" or dtype == "category":
                non_null_values = data[column][data[column].notnull()].unique()
                n_samples = min(MAXIMUM_SAMPLES, len(non_null_values))
                samples = (
                    pd.Series(non_null_values)
                    .sample(n_samples, random_state=42)
                    .tolist()
                )
                # str() guards against non-string samples (e.g. Timestamps
                # from datetime64 columns), which would make str.join raise
                # TypeError.
                values = "'" + "', '".join(str(sample) for sample in samples) + "'"
                description = f"The column '{column}' has {dtype} values {values}"

                if n_samples < len(non_null_values):
                    description += " etc."
                else:
                    description += "."
            elif description is None:
                description = f"The column '{column}' is {dtype} type."

            descriptions.append(description)

        return " ".join(descriptions)

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Prompt the LLM to produce plotting code for *nl_query*.

        Returns (code, context) on success, or (None, None) on failure.
        Only the first table in *tables* is used (Chat2VIS is single-table).
        """
        library = config["library"]

        if library == "seaborn":
            import_statements = "import seaborn as sns\n"
        else:
            import_statements = ""

        # Boilerplate prepended to the completion: figure setup plus a copy of
        # the dataframe injected at execution time by `execute`.
        pre_code = f"""import pandas as pd
import matplotlib.pyplot as plt
{import_statements}
fig,ax = plt.subplots(1,1,figsize=(10,4))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

df=df_nvBenchEval.copy()
"""
        data = pd.read_csv(tables[0], encoding="utf-8")
        columns = "'" + "', '".join(list(data.columns)) + "'"
        prompt = TEMPLATE.format(
            columns=columns,
            columns_description=self.table_format(data),
            library=library,
            nl_query=nl_query,
            pre_code=pre_code,
        )

        try:
            messages = [HumanMessage(content=prompt)]
            response = self.llm.invoke(messages)
            code = response.content
            # Drop any line that re-reads the placeholder CSV; the dataframe
            # is supplied via df_nvBenchEval instead.
            codes = code.split("\n")
            codes = list(filter(lambda row: "data_file.csv" not in row, codes))
            code = "\n".join(codes)
            # Ensure the figure is rendered so show_svg can capture it.
            if "plt.show()" not in code and ("plt." in code or "fig." in code):
                code += "\nplt.show()"

            context = {
                "tables": tables,
            }
            return pre_code + "\n" + code, context
        except Exception:
            # Best-effort: surface the failure as a warning and signal
            # "no code generated" to the caller.
            warnings.warn(str(sys.exc_info()))
            return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Run the generated *code* and capture the chart as an SVG string.

        Returns ChartExecutionResult with status=True and the SVG on success,
        or status=False and the formatted exception message on failure.
        """
        tables = context["tables"]
        df_nvBenchEval = pd.read_csv(tables[0])
        global_env = {
            "df_nvBenchEval": df_nvBenchEval,
            "svg_string": None,
            "show_svg": show_svg,
            "svg_name": log_name,
        }
        code += "\nsvg_string = show_svg(plt, svg_name)"
        try:
            # SECURITY: exec runs arbitrary LLM-generated code in-process.
            # Acceptable only in this trusted benchmark sandbox; do not use
            # with untrusted model output in production.
            exec(code, global_env)
            svg_string = global_env["svg_string"]
            return ChartExecutionResult(status=True, svg_string=svg_string)
        except Exception as exception_error:
            import traceback

            exception_info = traceback.format_exception_only(
                type(exception_error), exception_error
            )
            return ChartExecutionResult(status=False, error_msg=exception_info)
--------------------------------------------------------------------------------
/examples/agent/coml4vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import sys
5 | import warnings
6 |
7 | import pandas as pd
8 | from coml import CoMLAgent, describe_variable
9 | from langchain.chat_models.base import BaseChatModel
10 |
11 | from viseval.agent import Agent, ChartExecutionResult
12 |
13 | from .utils import show_svg
14 |
15 |
def read_table(name, url, format):
    """Load a CSV into a ``{name}_dataset`` variable and describe it.

    Returns the loading code (re-executed later alongside the generated
    visualization code) and a mapping from the variable name to its CoML
    variable description.
    """
    code = f"{name}_dataset = pd.read_csv('{url}')"
    # Execute in an explicit namespace instead of the caller's locals():
    # writing to a function's locals() via exec() relies on CPython frame
    # internals, and the second exec() is unnecessary — describe_variable
    # can simply be called directly.
    env = {"pd": pd}
    exec(code, env)
    variable_description = {
        f"{name}_dataset": describe_variable(
            env[f"{name}_dataset"],
            dataframe_format=format,
            pandas_description_config=dict(max_rows=10),
        )
    }
    return code, variable_description
24 |
25 |
class CoML4VIS(Agent):
    """VisEval agent that generates charts via the CoML code-generation agent."""

    def __init__(self, llm: BaseChatModel, config: dict = None):
        num_examples = 1
        prompt_version = "matplotlib"
        if config:
            if "num_examples" in config:
                # CoML supports at most 4 few-shot examples.
                num_examples = min(config["num_examples"], 4)
            if "library" in config and config["library"] in [
                "matplotlib",
                "seaborn",
            ]:
                prompt_version = config["library"]

        self.coml = CoMLAgent(
            llm, num_examples=num_examples, prompt_version=prompt_version
        )

    def pre_code(self, tables: list[dict], chart_lib: str, table_format: str = "coml"):
        """Build the data-loading code and variable descriptions for the prompt."""
        codes = ["import pandas as pd\nimport matplotlib.pyplot as plt\n"]
        variable_descriptions = {}
        if chart_lib == "seaborn":
            codes[-1] += "import seaborn as sns\n"

        for url in tables:
            # Variable name derived from the file name without extension.
            name = url.split("/")[-1].split(".")[0]
            code, variable_description = read_table(name, url, table_format)
            codes.append(code)
            variable_descriptions.update(variable_description)
        return codes, variable_descriptions

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Generate visualization code; returns (None, None) on failure."""
        library = config["library"]
        table_format = config["table_format"] if "table_format" in config else "coml"

        pre_codes, variable_descriptions = self.pre_code(tables, library, table_format)
        try:
            # Single call: the previous version issued this expensive LLM
            # request twice, discarding the first (unprotected) result.
            generating_context = self.coml.generate_code(
                nl_query, variable_descriptions, pre_codes
            )
            generate_code = generating_context["answer"]

            context = {"tables": tables, "library": library}
            return "\n".join(pre_codes) + "\n" + generate_code, context
        except Exception:
            warnings.warn(str(sys.exc_info()))
            return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Execute generated code, returning the chart as an SVG string."""
        tables = context["tables"]
        library = context["library"]

        global_env = {"svg_string": None, "show_svg": show_svg, "svg_name": log_name}
        code += "\nsvg_string = show_svg(plt, svg_name)"
        try:
            exec(code, global_env)
            svg_string = global_env["svg_string"]
            return ChartExecutionResult(status=True, svg_string=svg_string)
        except Exception:
            try:
                # handle old version: generated code may lack the data-loading
                # preamble, so retry with it prepended.
                codes, variable_descriptions = self.pre_code(tables, library)
                exec("\n".join(codes) + "\n" + code, global_env)
                svg_string = global_env["svg_string"]
                return ChartExecutionResult(status=True, svg_string=svg_string)
            except Exception as exception_error:
                error_msg = str(exception_error)
                return ChartExecutionResult(status=False, error_msg=error_msg)
96 |
--------------------------------------------------------------------------------
/examples/agent/lida.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import time
5 |
6 | import matplotlib.pyplot as plt
7 | import pandas as pd
8 | from lida import Manager
9 | from lida.components import get_globals_dict, preprocess_code
10 | from lida.datamodel import Goal
11 | from llmx import TextGenerator
12 |
13 | from viseval.agent import Agent, ChartExecutionResult
14 |
15 | from .utils import show_svg
16 |
17 | max_retries = 20
18 | retry_seconds = 20
19 |
20 |
class Lida(Agent):
    """VisEval agent wrapping the LIDA visualization manager."""

    def __init__(self, llm: TextGenerator):
        self.lida = Manager(text_gen=llm)

    def generate(self, nl_query: str, tables: list[str], config: dict):
        """Generate chart code with LIDA, retrying transient failures.

        Returns ``(code, context)``, or ``(None, None)`` once max_retries
        attempts are exhausted.
        """
        library = config["library"]
        summary = self.lida.summarize(tables[0])

        for attempt in range(max_retries):
            try:
                charts = self.lida.visualize(
                    summary=summary, goal=nl_query, library=library, return_error=True
                )

                code = charts[0].code
                # Ensure the figure is actually rendered when executed.
                code += "\nplt.show()"

                context = {"data": self.lida.data, "library": library}
                return code, context
            except Exception:
                if attempt < max_retries - 1:
                    print(f"Retrying in {retry_seconds} seconds...")
                    time.sleep(retry_seconds)

        return None, None

    def execute(self, code: str, context: dict, log_name: str = None):
        """Execute generated code and capture the chart as an SVG string."""
        data = context["data"]
        library = context["library"]

        code = preprocess_code(code)
        if library == "matplotlib" or library == "seaborn":
            try:
                ex_locals = get_globals_dict(code, data)
                exec(code, ex_locals)

                plt.box(False)
                plt.grid(color="lightgray", linestyle="dashed", zorder=-10)

                svg_string = show_svg(plt, log_name)
                return ChartExecutionResult(status=True, svg_string=svg_string)
            except Exception as exception_error:
                import traceback

                exception_info = traceback.format_exception_only(
                    type(exception_error), exception_error
                )
                return ChartExecutionResult(status=False, error_msg=exception_info)
        else:
            # Fix: previously this branch fell through and implicitly
            # returned None, violating the ChartExecutionResult contract.
            return ChartExecutionResult(
                status=False, error_msg=f"Unsupported library: {library}"
            )

    def evaluate(self, code: str, nl_query: str, library: str):
        """Score *code* against the query using LIDA's own evaluator."""
        goal = Goal(question=nl_query, visualization=nl_query, rationale="")

        result = self.lida.evaluate(code, goal, library=library)
        return result[0]
77 |
--------------------------------------------------------------------------------
/examples/agent/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
def show_svg(plt, svg_name: str):
    """Render the current figure to an SVG string, optionally saving a copy."""
    from io import StringIO

    buffer = StringIO()
    plt.savefig(buffer, format="svg")
    if svg_name:
        # Persist a copy on disk for logging/debugging.
        plt.savefig(f"{svg_name}")
    plt.close()

    return buffer.getvalue()
16 |
--------------------------------------------------------------------------------
/examples/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import argparse
5 | from pathlib import Path
6 |
7 | import dotenv
8 | from agent import Chat2vis, CoML4VIS, Lida
9 |
10 | from viseval import Dataset, Evaluator
11 |
12 | dotenv.load_dotenv()
13 |
14 |
def configure_llm(model: str, agent: str):
    """Instantiate the LLM backend appropriate for *agent* and *model*.

    LIDA needs an llmx TextGenerator; every other agent gets a LangChain
    chat model. Raises ValueError for unsupported combinations. Imports are
    local so only the selected backend's dependencies are required.
    """
    azure_openai_models = ["gpt-35-turbo", "gpt-4"]

    if agent == "lida":
        if model not in azure_openai_models:
            raise ValueError(f"Unknown model {model}")
        from llmx import llm

        return llm(
            provider="openai",
            api_type="azure",
            model=model,
            models={
                "max_tokens": 4096,
                "temperature": 0.0,
            },
        )

    if model == "gemini-pro":
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(
            model=model, temperature=0.0, convert_system_message_to_human=True
        )
    if model in azure_openai_models:
        from langchain_openai import AzureChatOpenAI

        return AzureChatOpenAI(
            model_name=model,
            max_retries=999,
            temperature=0.0,
            request_timeout=20,
        )
    if model == "codellama-7b":
        from model.langchain_llama import ChatLlama

        return ChatLlama("../llama_models/CodeLlama-7b-Instruct")
    raise ValueError(f"Unknown model {model}")
53 |
54 |
def config_agent(agent: str, model: str, config: dict):
    """Build the requested agent wired to its configured LLM backend."""
    llm = configure_llm(model, agent)
    if agent == "coml4vis":
        return CoML4VIS(llm, config)
    if agent == "chat2vis":
        return Chat2vis(llm)
    if agent == "lida":
        return Lida(llm)
    raise ValueError(f"Unknown agent {agent}")
65 |
66 |
def _main():
    """CLI entry point: evaluate a visualization agent on a benchmark."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--benchmark", type=Path)
    arg_parser.add_argument(
        "--type", type=str, choices=["all", "single", "multiple"], default="all"
    )
    arg_parser.add_argument("--irrelevant_tables", type=bool, default=False)

    arg_parser.add_argument(
        "--model",
        type=str,
        default="gpt-35-turbo",
        choices=["gpt-4", "gpt-35-turbo", "gemini-pro", "codellama-7b"],
    )
    arg_parser.add_argument(
        "--agent",
        type=str,
        default="coml4vis",
        choices=["coml4vis", "lida", "chat2vis"],
    )
    arg_parser.add_argument("--num_examples", type=int, default=1, choices=range(0, 4))
    arg_parser.add_argument(
        "--library", type=str, default="matplotlib", choices=["matplotlib", "seaborn"]
    )
    arg_parser.add_argument("--logs", type=Path, default="./logs")
    arg_parser.add_argument("--webdriver", type=Path, default="/usr/bin/chromedriver")

    args = arg_parser.parse_args()

    # Benchmark dataset to evaluate against.
    dataset = Dataset(args.benchmark, args.type, args.irrelevant_tables)

    # Agent under evaluation, wired to the chosen LLM backend.
    agent = config_agent(
        args.agent,
        args.model,
        {"num_examples": args.num_examples, "library": args.library},
    )

    from langchain_openai import AzureChatOpenAI

    # Vision-capable model used by the evaluator to judge rendered charts.
    vision_model = AzureChatOpenAI(
        model_name="gpt-4-turbo-v",
        max_retries=999,
        temperature=0.0,
        request_timeout=20,
        max_tokens=4096,
    )
    evaluator = Evaluator(webdriver_path=args.webdriver, vision_model=vision_model)

    # Run the evaluation and report the aggregate score.
    result = evaluator.evaluate(
        agent, dataset, {"library": args.library, "logs": args.logs}
    )
    score = result.score()
    print(f"Score: {score}")
124 |
125 | if __name__ == "__main__":
126 | _main()
127 |
--------------------------------------------------------------------------------
/examples/model/langchain_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | from typing import Any, List, Optional
6 |
7 | import torch
8 | from langchain.callbacks.manager import CallbackManagerForLLMRun
9 | from langchain.chat_models.base import SimpleChatModel
10 | from langchain.schema import (
11 | AIMessage,
12 | BaseMessage,
13 | ChatMessage,
14 | HumanMessage,
15 | SystemMessage,
16 | )
17 | from transformers import AutoTokenizer, LlamaForCausalLM
18 |
# Llama-2 chat prompt delimiters. NOTE(review): the checked-in values had the
# angle-bracket tag contents stripped (e.g. "<>" for the system markers),
# which would produce malformed prompts and a useless special-tag filter;
# restored to the official Llama-2 chat format — confirm against the
# original upstream file.
B_ROUND, E_ROUND = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
25 |
26 |
class ChatLlama(SimpleChatModel):
    """LangChain chat model wrapping a local Llama-2-style HF checkpoint.

    Assembles [INST]-delimited chat prompts and generates with transformers.
    NOTE(review): generation moves inputs to "cuda" — assumes a CUDA device
    is available; confirm before running on CPU-only hosts.
    """

    # read from config.json
    max_context_length: int = 2048
    max_new_tokens: int = 4096
    temperature: float = 0.0
    top_p: float = 1
    top_k: int = 50
    tokenizer: Any
    model: Any

    def __init__(self, model_path):
        super().__init__()

        # Load weights in fp16, sharded automatically across devices.
        self.model = LlamaForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # The real context window comes from the checkpoint's config.json,
        # overriding the class-level default.
        with open(f"{model_path}/config.json") as f:
            config = json.load(f)
            self.max_context_length = config["max_position_embeddings"]

    @property
    def _llm_type(self) -> str:
        return "llama2-chat"

    def _assemble_prompt(self, messages: List[BaseMessage]) -> str:
        """Flatten LangChain messages into a single Llama-2 chat prompt.

        A leading system message is folded into the first user turn (wrapped
        in B_SYS/E_SYS); the remaining turns must strictly alternate
        user/assistant and end with a user message.
        """
        prompt = ""
        temp = []
        # Normalize LangChain message objects into {"role", "content"} dicts.
        for message in messages:
            if isinstance(message, ChatMessage):
                role = message.role
            elif isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"
            else:
                raise ValueError(f"Got unknown type {message}")
            temp.append({"role": role, "content": message.content})
        messages = temp
        if messages[0]["role"] == "system":
            # Merge the system prompt into the first user message.
            messages = [
                {
                    "role": messages[1]["role"],
                    "content": B_SYS
                    + messages[0]["content"]
                    + E_SYS
                    + messages[1]["content"],
                }
            ] + messages[2:]
        assert all([msg["role"] == "user" for msg in messages[::2]]) and all(
            [msg["role"] == "assistant" for msg in messages[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
        # Each completed (user, assistant) pair becomes one delimited round.
        prompts: List[str] = [
            f"{B_ROUND}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {E_ROUND}"
            for prompt, answer in zip(
                messages[::2],
                messages[1::2],
            )
        ]
        prompt = "".join(prompts)
        assert (
            messages[-1]["role"] == "user"
        ), f"Last message must be from user, got {messages[-1]['role']}"
        # The trailing user turn is left open for the model to complete.
        prompt += f"{B_ROUND}{B_INST} {(messages[-1]['content']).strip()} {E_INST}"
        return prompt

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion; refuses prompts containing control tags."""
        unsafe = any(
            tag in message.content for message in messages for tag in SPECIAL_TAGS
        )
        if unsafe:
            return UNSAFE_ERROR
        prompt = self._assemble_prompt(messages)

        input_ids = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).input_ids.to("cuda")
        assert len(input_ids[0]) <= self.max_context_length, (
            f"Prompt is too long, got {len(input_ids[0])} tokens, "
            f"max is {self.max_context_length}"
        )
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": self.max_new_tokens,
            "repetition_penalty": 1.0,
            "eos_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if self.temperature == 0.0:
            # Greedy decoding for deterministic output.
            generate_input["do_sample"] = False
        else:
            generate_input["do_sample"] = True
            generate_input["temperature"] = self.temperature
            generate_input["top_p"] = self.top_p
            generate_input["top_k"] = self.top_k

        generate_ids = self.model.generate(**generate_input)
        # Strip the echoed prompt and the final token (presumably EOS —
        # TODO confirm generation always terminates with EOS, otherwise
        # this drops one real token).
        generate_ids = [item[len(input_ids[0]) : -1] for item in generate_ids]
        # output
        output = self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output
142 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "vis-evaluator"
3 | version = "0.0.3"
4 | requires-python = ">=3.10"
5 | dependencies = [
6 | "langchain",
7 | "python-dotenv",
8 | "numpy",
9 | "pandas",
10 | "matplotlib",
11 | "seaborn",
12 | "CairoSVG",
13 | "selenium",
14 | "llmx",
15 | ]
16 | readme = "README.md"
17 | license = {text = "MIT"}
18 |
19 | [tool.setuptools.packages.find]
20 | include = ["viseval*"]
21 |
22 | [project.urls]
23 | Documentation = "https://github.com/microsoft/VisEval#readme"
24 | Issues = "https://github.com/microsoft/VisEval/issues"
25 | Source = "https://github.com/microsoft/VisEval"
26 |
27 | [project.optional-dependencies]
28 | dev = [
29 | "pytest",
30 | "flake8",
31 | "pre-commit",
32 | "black",
33 | "isort",
34 | "hatch",
35 | "mlcopilot",
36 | "lida",
37 | ]
38 |
39 |
40 | [build-system]
41 | requires = ["setuptools", "wheel"]
42 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/tests/assets/empty_3.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
33 |
--------------------------------------------------------------------------------
/tests/assets/empty_4.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
665 |
--------------------------------------------------------------------------------
/tests/assets/pie_4.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
726 |
--------------------------------------------------------------------------------
/tests/assets/readability/2571.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
700 |
--------------------------------------------------------------------------------
/tests/assets/readability/readability_human_rating.csv:
--------------------------------------------------------------------------------
1 | visualization_id,query,expert1,expert2,expert3,has_layout_issues,has_scale&ticks_issues
2 | 2756@x_name@DESC,"Draw a bar chart showing the total number of every ship type by categorizing by nationality, and list x-axis in descending order.",1,1,2,TRUE,FALSE
3 | 2815@x_name@DESC,"Compute the total number in each competition type, and then split by country. Plot a bar chart and show in desc by the x-axis.",1,1,1,TRUE,FALSE
4 | 2841@y_name@ASC,"Bar graph to show the average of high temperature from different day of week, and could you show y axis in ascending order?",2,2,3,FALSE,TRUE
5 | 312,Give me the proportion of end stations for the trips with the three oldest id using a pie chart.,4,4,4,TRUE,FALSE
6 | 69,What is the number of booking start dates of the apartments with more than 2 bedrooms for each year? Return a bar chart.,1,2,1,TRUE,TRUE
7 | 4,Show all the faculty ranks and the number of students advised by each rank with a pie chart.,4,5,5,FALSE,FALSE
8 | 9,Show the number of faculty members for each rank in a bar chart.,4,3,1,TRUE,FALSE
9 | 32,List the names of aircrafts and the number of times it won matches by a bar chart.,3,3,2,TRUE,TRUE
10 | 286,What are the date and mean humidity for the top 3 days with the largest max wind speeds? Return me a bar chart.,2,3,1,TRUE,FALSE
11 | 296,"For the top 3 days with the largest max wind speeds, please bin the date into the of the week and then compute the average of mean humidity to visualize a bar chart.",3,3,2,TRUE,FALSE
12 | 33,"What are the descriptions for the aircrafts, and count them by a pie chart",4,4,5,FALSE,FALSE
13 | 65,"Show the number of start dates of all the apartment bookings made by guests with gender code ""Female"" for each year in a bar chart.",1,2,1,FALSE,TRUE
14 | 72,Find the number of booking start date for the apartments that have more than two bedrooms for each weekday in a bar chart.,4,3,4,FALSE,TRUE
15 | 81,A pie chart for showing the number of the facility codes of apartments with more than 4 bedrooms.,4,5,5,FALSE,FALSE
16 | 387,"A grouped scatter chart shows the correlation between Height and Weight , and group by attribute Sex.",5,5,4,FALSE,FALSE
17 | 399,"Plot a scatter chart, to show the correlation between support and consider rates for each candidate.",5,4,3,FALSE,FALSE
18 | 432,plot scatter on what is the maximum accelerate for different number of cylinders?,5,4,4,FALSE,FALSE
19 | 267,Return a line chart about the change of monthly_rental over date_address_from .,4,5,5,FALSE,FALSE
20 | 316,A line chart for giveing me the number of the dates when the max temperature was higher than 85.,4,3,4,FALSE,TRUE
21 | 470,What is the sum of capacity of cinemas open for each year? Return a line chart.,4,4,5,FALSE,FALSE
22 | 173,A pie chart showing the number of results of the battles when the bulgarian commander is not 'Boril'.,4,5,4,FALSE,FALSE
23 | 219,"Show me about the distribution of other_details and the average of monthly_rental , and group by attribute other_details in a bar chart.",5,5,5,FALSE,FALSE
24 | 246,Draw a bar chart about the distribution of date_address_from and the sum of monthly_rental bin date_address_from by year.,1,2,1,FALSE,TRUE
25 | 338,Compute the total number of stations across city as a pie chart.,4,4,5,FALSE,FALSE
26 | 342,"What are the dates in which the mean sea level pressure was between 30.3 and 31, and count them by a line chart",2,1,1,TRUE,TRUE
27 | 368,A bar chart showing the number of accelerators for each browser.,4,3,2,TRUE,TRUE
28 | 403,Count the number of people of each sex who have a weight higher than 85 by a bar chart.,4,4,3,FALSE,TRUE
29 | 1008,Show the number of products with price higher than 1000 or lower than 500 for each product name in a pie chart.,5,5,5,FALSE,FALSE
30 | 461,Give me a pie chart showing sum of price for each cinema.,4,4,3,FALSE,FALSE
31 | 464,Show each location and the number of cinemas there by a bar chart.,4,3,4,TRUE,TRUE
32 | 487,Show the number of climbers for each mountain in a pie chart.,4,3,4,TRUE,FALSE
33 | 513,Find the total credits of all classes offered by each department. Visualize by bar chart.,5,5,5,FALSE,FALSE
34 | 517,Find the number of students whose gpa is lower than the average gpa of all students for different first name in a pie chart.,4,4,5,FALSE,FALSE
35 | 567,What is the gpa of the top 5 students with highest gpa? Show me a bar chart with each student by first name.,5,5,5,FALSE,FALSE
36 | 611,Find the number of courses offered by Psychology department in each year with a line chart.,3,2,4,FALSE,TRUE
37 | 616,"Find dept_name and the sum of salary , and group by attribute dept_name, and visualize them by a bar chart.",1,2,1,TRUE,FALSE
38 | 622,Find the relationship between average and maximum capacity among rooms in each building with a scatter chart.,5,5,5,FALSE,FALSE
39 | 659,"Find the last name of female (sex is F) students in the descending order of age, and count them by a bar chart",4,3,4,TRUE,TRUE
40 | 676,What is the number of each course name that have at least five enrollments? Show me a bar chart.,2,3,3,TRUE,TRUE
41 | 693,Show the number of singers in each country with a bar chart.,4,4,3,FALSE,TRUE
42 | 708,Pie chart. how many counties correspond to each police force?,4,4,5,FALSE,FALSE
43 | 718,Show the number of courses each teacher is arranged to teach in a pie chart.,4,5,5,FALSE,FALSE
44 | 744,Show all template type codes and the number of documents using each type with a bar chart.,4,4,3,FALSE,FALSE
45 | 789,Show all calendar dates and day Numbers in a line chart.,4,3,4,FALSE,TRUE
46 | 803,Show budget type codes and the number of documents in each budget type with a bar chart.,5,5,4,FALSE,FALSE
47 | 856,"Which workshop groups have bookings with status code ""stop""? Give me the names, and count them by a pie chart",4,4,5,FALSE,FALSE
48 | 880,A line chart for what are the number of the actual delivery dates of orders with quantity 1?,4,3,4,FALSE,TRUE
49 | 914,"Return a bar chart showing the proportion of the number of orders that have the status ""Delivered"" for each customer name.",3,2,3,TRUE,TRUE
50 | 961,Show all product names and the total quantity ordered for each product name in a bar chart.,3,3,2,TRUE,FALSE
51 | 1008,Show the number of products with price higher than 1000 or lower than 500 for each product name in a pie chart.,4,4,5,FALSE,FALSE
52 | 1071,"List the venues of debates in ascending order of the number of audience, and count them by a bar chart",2,2,2,TRUE,FALSE
53 | 1188,How many dogs who have gone through a treatment departed in each day? Return a bar chart.,5,5,5,FALSE,FALSE
54 | 1285,"Which tests have ""Pass"" results? Return the dates when the tests were taken, and count them by a line chart",4,3,3,FALSE,TRUE
55 | 1314,"For each county, find the name of the county and the number of delegates from that county. Show the proportion by pie chart.",5,4,5,FALSE,FALSE
56 | 1363,Draw a bar chart for what is the number of employees from each city?,4,4,3,FALSE,TRUE
57 | 1392,Please show the number of films for each type in a bar chart.,3,3,2,TRUE,FALSE
58 | 1434,A pie chart for finding the number of the names of Japanese constructors that have once earned more than 5 points?,4,4,5,FALSE,FALSE
59 | 1517,Show the number of companies in each headquarter with a pie chart.,3,4,3,FALSE,FALSE
60 | 1530,A bar chart for listing the number of the names of patients who have made appointments.,4,4,3,FALSE,TRUE
61 | 1630,Provide the frequency of the last names of employees earning more than the employee with id 163 using a bar chart.,4,3,3,FALSE,TRUE
62 | 1961,"For all employees in the Finance department, show me the proportion of their job id using a pie chart.",5,4,5,FALSE,FALSE
63 | 1974,Find the number of rooms with king bed for each decor type. Plot them as pie chart.,4,4,4,FALSE,FALSE
64 | 1992,"Among all the claims, which claims have a claimed amount larger than the average? Please Bin the date it was settled into weekday interval and count them to show a bar chart.",4,3,3,FALSE,TRUE
65 | 2013,What about the proportion of the total amounts of payments by each method code? You can give me a pie chart.,5,4,4,FALSE,FALSE
66 | 2027,Sum the amount for all the payments processed with Visa by each year using a bar chart.,1,2,1,FALSE,TRUE
67 | 2110,"Group and count the move in date in a bar chart, and I want to bin the X into Year interval.",1,2,1,FALSE,TRUE
68 | 2174,Give me a bar chart to show the names and revenue of the company that earns the highest revenue in each headquarter city.,5,4,5,FALSE,FALSE
69 | 2303,Please show me a bar chart for visualizing the name and revenue of all manufacturers sorted by their revenue in the descending order.,4,4,4,TRUE,FALSE
70 | 2350,Group and count the color scheme for all the photos using a pie chart.,4,4,4,FALSE,FALSE
71 | 2416,Find the name and membership level of the visitors whose membership level is higher than 4. Plot them as pie chart.,5,4,4,FALSE,FALSE
72 | 2526,Show the proportion of all ministers using a pie chart.,5,4,5,FALSE,FALSE
73 | 2571,Give me a pie to show total number of memory in g from different carrier.,4,4,4,FALSE,FALSE
74 | 2615,"What are the payment date of the payment with amount paid higher than 300 or with payment type is 'Check', bin the payment date by month and count them by a bar chart",5,5,5,FALSE,FALSE
75 | 2652,A bar chart for listing the number of the builders of railways in ascending alphabetical order.,4,3,3,TRUE,TRUE
76 | 2724,"For each denomination, return the denomination and the count of schools with that denomination. Visualize by bar chart.",4,4,4,FALSE,TRUE
77 | 2941,Give me a bar chart about the number of countries in the artist table,5,5,5,FALSE,FALSE
78 | 3207,Show the name and age for all male people who don't have a wedding by a bar chart.,4,4,4,TRUE,FALSE
79 | 1491@y_name@DESC,"Show the number of games for each home team in a bar chart, sort by the y-axis in desc please.",5,5,4,FALSE,FALSE
80 | 1380@y_name@ASC,"Find each target user's name and average trust score Visualize by bar chart, rank y axis from low to high order.",5,5,4,FALSE,FALSE
81 | 58@y_name@ASC,"What are the average ages for male and female students Plot them as bar chart, and list Y in asc order.",3,4,4,FALSE,FALSE
82 | 847@y_name@DESC,"Show the number of documents created in each day and bin document date by weekday with a bar chart, list by the y axis in desc.",2,1,1,TRUE,TRUE
83 | 2024@y_name@DESC,"For those payments processed with Visa, bin the payment day into Year interval and count them for a bar chart, order from high to low by the y axis.",2,2,1,FALSE,TRUE
84 | 3014@y_name@ASC,"Find the data about the sale details and dates of transactions with amount smaller than 3000? Bin the date of the transaction into a weekday interval and compute the total number of each day with a bar chart, and display Y in ascending order.",4,4,3,FALSE,TRUE
85 | 17@y_name@DESC,"Bar chart of total number by each rank, sort from high to low by the Y.",5,5,5,FALSE,FALSE
86 | 2575@y_name@ASC,"For each phone, show its names and total number of stocks Visualize by bar chart, rank by the Y-axis in ascending please.",5,5,5,FALSE,FALSE
87 | 75@y_name@ASC,"What is the number of booking start dates of the apartments with type code ""Duplex"" in each year? Return a bar chart, show y axis in ascending order.",1,1,1,FALSE,TRUE
88 | 533@y_name@DESC,"What is the lowest student GPA for every department? Return a bar chart, display in desc by the Y.",4,5,5,FALSE,FALSE
89 | 555@y_name@DESC,"How many classes are held in each department Visualize by bar chart, and list by the total number from high to low.",5,5,5,FALSE,FALSE
90 | 1237@y_name@ASC,"Find the average age for students with different sex in a bar chart, could you list by the Y from low to high?",3,4,4,FALSE,FALSE
91 | 279@y_name@DESC,"A bar chart about the number of end dates for incidents with incident type code ""NOISE"" and bin by month.",1,1,1,FALSE,TRUE
92 | 3134@y_name@ASC,"Give me a bar chart for all_games_percent of each team name, could you list in asc by the y axis please?",5,5,5,FALSE,FALSE
93 | 2943@y_name@ASC,"Bar chart x axis country y axis the average of age, and I want to show by the y-axis in asc please.",5,5,4,FALSE,FALSE
94 | 2760@y_name@DESC,"Bar graph to show how many nationality from different nationality, list in desc by the total number.",5,5,5,FALSE,FALSE
95 | 2765@y_name@ASC,"Give me a bar chart for mean tonnage of each type, show by the Y in asc.",5,5,5,FALSE,FALSE
96 | 134@x_name@DESC,"List the number of enginners in a stacked bar chart The x-axis is last name and group by skill description, rank last_name in descending order.",2,2,2,TRUE,TRUE
97 | 3069@x_name@ASC,"Find the name of each user and number of tweets tweeted by each of them Visualize by bar chart, could you order X-axis in asc order?",4,3,3,TRUE,TRUE
98 | 145@y_name@DESC,"A stacked bar chart showing thfe number of faults for different fault short name and skills required to fix them The x-axis is skill description and group by fault short name, display by the y axis in descending.",3,4,4,TRUE,FALSE
99 | 131,Give me a pie chart about the number of engineers for different skill description.,5,4,5,FALSE,FALSE
100 | 2019,Bin the claim date into the Year interval and count them for visualizing a bar chart.,3,2,1,FALSE,TRUE
101 | 1534,Return a pie on how many patients do each physician take care of? List their names and number of patients they take care of.,4,4,5,FALSE,FALSE
102 |
--------------------------------------------------------------------------------
/tests/test_chart_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | from pathlib import Path
6 |
7 | import pytest
8 |
9 | from viseval.check import chart_check, deconstruct
10 |
folder = Path(__file__).resolve().parent

# Ground-truth benchmark entries shared by every test in this module.
benchmark = json.loads((folder / "assets/samples.json").read_text())
15 |
16 |
def test_chart_check_pie_4():
    """pie_4_0.svg should match the ground-truth chart type for sample 4."""
    svg = (folder / "assets/pie_4_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["4"]
    meta = entry["query_meta"][0]
    result = chart_check(info, entry["chart"], meta.get("stacked_bar"))
    assert result[0] is True
29 |
30 |
def test_chart_check_scatter_400():
    """Both scatter_400 renderings should match the ground-truth chart type."""
    entry = benchmark["400"]
    for i in range(2):
        svg = (folder / f"assets/scatter_400_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        meta = entry["query_meta"][i]
        result = chart_check(info, entry["chart"], meta.get("stacked_bar"))
        assert result[0] is True
48 |
49 |
def test_chart_check_bar_1129():
    """All three bar_1129 renderings should match the ground-truth chart type."""
    entry = benchmark["1129"]
    for i in range(3):
        svg = (folder / f"assets/bar_1129_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        meta = entry["query_meta"][i]
        result = chart_check(info, entry["chart"], meta.get("stacked_bar"))
        assert result[0] is True
67 |
68 |
def test_chart_check_bar_2750():
    """Both stacked_bar_2750 renderings should match the ground-truth chart type."""
    entry = benchmark["2750"]
    for i in range(2):
        svg = (folder / f"assets/stacked_bar_2750_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        meta = entry["query_meta"][i]
        result = chart_check(info, entry["chart"], meta.get("stacked_bar"))
        assert result[0] is True
86 |
87 |
def test_chart_check_line_3240():
    """line_3240_0.svg should match the ground-truth chart type for sample 3240."""
    # Fixed: the path literal had an f-string prefix with no placeholders.
    with open(folder / "assets/line_3240_0.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    ground_truth = benchmark["3240"]["chart"]
    query_meta = benchmark["3240"]["query_meta"]
    result = chart_check(
        chart_info,
        ground_truth,
        query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None,
    )
    assert result[0] is True
100 |
101 |
def test_chart_check_line_2781():
    """grouping_line_2781_0.svg should match the ground-truth chart type."""
    # Fixed: the path literal had an f-string prefix with no placeholders.
    with open(folder / "assets/grouping_line_2781_0.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    ground_truth = benchmark["2781"]["chart"]
    query_meta = benchmark["2781"]["query_meta"]
    result = chart_check(
        chart_info,
        ground_truth,
        query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None,
    )
    assert result[0] is True
114 |
115 |
def test_chart_check_bar_1071():
    """bar_1071_0.svg (a horizontal bar) should match the ground-truth chart type."""
    # Fixed: the path literal had an f-string prefix with no placeholders.
    with open(folder / "assets/bar_1071_0.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    ground_truth = benchmark["1071"]["chart"]
    query_meta = benchmark["1071"]["query_meta"]
    result = chart_check(
        chart_info,
        ground_truth,
        query_meta[0]["stacked_bar"] if "stacked_bar" in query_meta[0] else None,
    )
    assert result[0] is True
129 |
130 |
def test_stacked():
    """A grouping bar is rejected for a "Stacked Bar" ground truth when the
    query explicitly asked for stacking, and accepted otherwise."""
    info = {"chart": "grouping bar"}
    assert chart_check(info, "Stacked Bar", True)[0] is False
    assert chart_check(info, "Stacked Bar", False)[0] is True
137 |
--------------------------------------------------------------------------------
/tests/test_data_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | from pathlib import Path
6 |
7 | import pytest
8 |
9 | from viseval.check import data_check, deconstruct
10 |
folder = Path(__file__).resolve().parent

# Ground-truth benchmark entries shared by every test in this module.
benchmark = json.loads((folder / "assets/samples.json").read_text())
15 |
16 |
def test_data_check_pie_5():
    """pie_5.svg should pass the data check for benchmark entry 3264."""
    svg = (folder / "assets/pie_5.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["3264"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is True
27 |
def test_data_check_pie_4():
    """pie_4_0.svg should pass the data check for benchmark entry 4."""
    svg = (folder / "assets/pie_4_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["4"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is True
38 |
39 |
def test_data_check_scatter_400():
    """Both scatter_400 renderings should pass the data check."""
    entry = benchmark["400"]
    for i in range(2):
        svg = (folder / f"assets/scatter_400_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        result = data_check(info, entry["vis_obj"], entry["query_meta"][i]["channel_specified"])
        assert result[0] is True
51 |
52 |
def test_data_check_bar_1129():
    """All three bar_1129 renderings should pass the data check."""
    entry = benchmark["1129"]
    for i in range(3):
        svg = (folder / f"assets/bar_1129_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        result = data_check(info, entry["vis_obj"], entry["query_meta"][i]["channel_specified"])
        assert result[0] is True
64 |
65 |
def test_data_check_bar_2750():
    """Both stacked_bar_2750 renderings should pass the data check."""
    entry = benchmark["2750"]
    for i in range(2):
        svg = (folder / f"assets/stacked_bar_2750_{i}.svg").read_text()
        info, _ = deconstruct(svg)
        result = data_check(info, entry["vis_obj"], entry["query_meta"][i]["channel_specified"])
        assert result[0] is True
77 |
78 |
def test_data_check_line_3240():
    """line_3240_0.svg should pass the data check for benchmark entry 3240."""
    svg = (folder / "assets/line_3240_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["3240"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is True
89 |
90 |
def test_data_check_line_2781():
    """grouping_line_2781_0.svg is expected to FAIL the data check."""
    svg = (folder / "assets/grouping_line_2781_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["2781"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is False
101 |
102 |
def test_data_check_line_773():
    """line_773_0.svg (dates formatted yyyy-mm-dd) should pass the data check."""
    svg = (folder / "assets/line_773_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["773@x_name@DESC"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is True
114 |
115 |
def test_data_check_bar_68():
    """bar_68_0.svg ("monday" vs "mon" labels) should pass the data check."""
    svg = (folder / "assets/bar_68_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["68"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is True
127 |
128 |
def test_data_check_bar_1071():
    """Horizontal bar: passes with the recorded channel spec, fails when the
    x/y channels are forced as explicitly specified."""
    svg = (folder / "assets/bar_1071_0.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["1071"]
    truth = entry["vis_obj"]

    result = data_check(info, truth, entry["query_meta"][0]["channel_specified"])
    assert result[0] is True

    result = data_check(info, truth, ["x", "y"])
    assert result[0] is False
143 |
144 |
def test_data_check_bar_3137():
    """Known-bad rendering bar_3137_1.svg is expected to FAIL the data check."""
    svg = (folder / "assets/bar_3137_1.svg").read_text()
    info, _ = deconstruct(svg)
    entry = benchmark["3137@y_name@DESC"]
    result = data_check(info, entry["vis_obj"], entry["query_meta"][0]["channel_specified"])
    assert result[0] is False
156 |
--------------------------------------------------------------------------------
/tests/test_deconstruct.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from pathlib import Path
5 |
6 | import pytest
7 |
8 | from viseval.check import deconstruct
9 |
10 | folder = Path(__file__).resolve().parent
11 |
12 |
13 | # error case
def test_deconstruct_empty():
    """empty.svg deconstructs to a chart with zero data points."""
    svg = (folder / "assets/empty.svg").read_text()
    info, _ = deconstruct(svg)
    assert len(info["data"]) == 0
19 |
20 |
def test_deconstruct_empty1():
    """empty_1.svg deconstructs to a chart with zero data points."""
    svg = (folder / "assets/empty_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert len(info["data"]) == 0
26 |
27 |
def test_deconstruct_empty2():
    """empty_2.svg deconstructs to a chart with zero data points."""
    svg = (folder / "assets/empty_2.svg").read_text()
    info, _ = deconstruct(svg)
    assert len(info["data"]) == 0
33 |
34 |
def test_deconstruct_empty3():
    """empty_3.svg cannot be deconstructed at all (returns None)."""
    svg = (folder / "assets/empty_3.svg").read_text()
    info, _ = deconstruct(svg)
    assert info is None
40 |
41 |
def test_deconstruct_empty4():
    """empty_4.svg deconstructs to a chart with zero data points."""
    with open(folder / "assets/empty_4.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    # Removed leftover debug print(chart_info); pytest reports the value on failure.
    assert len(chart_info["data"]) == 0
48 |
49 |
def test_deconstruct_dual_axis():
    """Dual-axis charts are not supported: deconstruct returns None."""
    svg = (folder / "assets/dual_1499_0.svg").read_text()
    info, _ = deconstruct(svg)
    assert info is None
55 |
56 |
def test_deconstruct_bar_error():
    """bar_error_3227_0.svg is recognized as a vertical bar with errors."""
    svg = (folder / "assets/bar_error_3227_0.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical bar(error)"
63 |
64 |
def test_deconstruct_scatter():
    """scatter.svg: full metadata (title, axis titles) plus 309 points."""
    svg = (folder / "assets/scatter.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert (
        info["title"]
        == "Correlation between IMDB Rating and Rotten Tomatoes Rating in Action Movies"
    )
    assert info["encoding"]["x"]["title"] == "IMDB Rating"
    assert info["encoding"]["y"]["title"] == "Rotten Tomatoes Rating"
    assert len(info["data"]) == 309
79 |
80 |
def test_deconstruct_scatter1():
    """scatter_1.svg: two-channel scatter with five points."""
    svg = (folder / "assets/scatter_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 5
89 |
90 |
def test_deconstruct_scatter2():
    """scatter_2.svg: two-channel scatter with 100 points."""
    svg = (folder / "assets/scatter_2.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 100
99 |
100 |
def test_deconstruct_scatter4():
    """scatter_4.svg: two-channel scatter with three points."""
    svg = (folder / "assets/scatter_4.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 3
109 |
110 |
def test_deconstruct_scatter6():
    """scatter_6.svg: two-channel scatter with 69 points."""
    svg = (folder / "assets/scatter_6.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 69
119 |
120 |
def test_deconstruct_scatter_lida():
    """LIDA-generated scatter_62.svg: two channels, nine points."""
    svg = (folder / "assets/scatter_62.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 9
129 |
130 |
def test_deconstruct_scatter_lida2():
    """LIDA-generated scatter_292.svg: two channels, three points."""
    svg = (folder / "assets/scatter_292.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 3
139 |
140 |
def test_deconstruct_scatter_lida3():
    """LIDA-generated scatter_674.svg: two channels, 76 points."""
    with open(folder / "assets/scatter_674.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    # Removed leftover debug print(chart_info); pytest reports the value on failure.
    assert chart_info["mark"] == "circle"
    assert chart_info["chart"] == "scatter"
    assert len(chart_info["encoding"].keys()) == 2
    assert len(chart_info["data"]) == 76
150 |
151 |
152 | # black legend
def test_deconstruct_grouping_scatter7():
    """grouping_scatter_3272_0.svg (black legend): three channels, six points."""
    svg = (folder / "assets/grouping_scatter_3272_0.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "grouping scatter"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 6
161 |
162 |
def test_deconstruct_grouping_scatter():
    """grouping_scatter.svg: three channels, five points."""
    svg = (folder / "assets/grouping_scatter.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "grouping scatter"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 5
171 |
172 |
def test_deconstruct_line():
    """line.svg: full metadata (title, axis titles) and 12 points."""
    svg = (folder / "assets/line.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert info["title"] == "Maximum Temperature by Month"
    assert info["encoding"]["x"]["title"] == "Month"
    assert info["encoding"]["y"]["title"] == "Temperature (°C)"
    assert len(info["data"]) == 12
184 |
185 |
def test_deconstruct_line1():
    """line_1.svg: two-channel line with 20 points."""
    svg = (folder / "assets/line_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 20
194 |
195 |
def test_deconstruct_line2():
    """line_2.svg: two-channel line with 107 points."""
    svg = (folder / "assets/line_2.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 107
204 |
205 |
def test_deconstruct_line3():
    """line_3.svg: two-channel line with 10 points."""
    svg = (folder / "assets/line_3.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 10
214 |
215 |
def test_deconstruct_line4():
    """line_4.svg: two-channel line with 107 points."""
    svg = (folder / "assets/line_4.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 107
224 |
225 |
226 | # error case
def test_deconstruct_line5():
    """Error case: line_5.svg parses as a line chart but yields no data."""
    svg = (folder / "assets/line_5.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 0
235 |
236 |
237 | # error case
def test_deconstruct_line11():
    """Error case: line_11.svg yields a chart_info without a mark entry."""
    svg = (folder / "assets/line_11.svg").read_text()
    info, _ = deconstruct(svg)
    assert "mark" not in info
243 |
244 |
245 | # tick line
def test_deconstruct_line12():
    """line_12.svg (tick-line style): two-channel line with 13 points."""
    svg = (folder / "assets/line_12.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 13
254 |
255 |
def test_deconstruct_line13():
    """line_13.svg: two-channel line with 15 points."""
    svg = (folder / "assets/line_13.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "line"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 15
264 |
265 |
def test_deconstruct_line14():
    """line_1746_1.svg: a single point, so it is classified as a scatter."""
    svg = (folder / "assets/line_1746_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "circle"
    assert info["chart"] == "scatter"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 1
274 |
275 |
def test_deconstruct_grouping_line():
    """grouping_line.svg: three channels, nine points."""
    svg = (folder / "assets/grouping_line.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "grouping line"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 9
284 |
285 |
def test_deconstruct_grouping_line1():
    """grouping_line_1.svg: three channels, 20 points."""
    svg = (folder / "assets/grouping_line_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "grouping line"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 20
294 |
295 |
def test_deconstruct_grouping_line4():
    """grouping_line_4.svg resolves to only two encoding channels."""
    svg = (folder / "assets/grouping_line_4.svg").read_text()
    info, _ = deconstruct(svg)
    assert len(info["encoding"]) == 2
301 |
302 |
303 | # error case
def test_deconstruct_grouping_line2():
    """Error case: grouping_line_2.svg parses but yields no data."""
    svg = (folder / "assets/grouping_line_2.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "grouping line"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 0
312 |
313 |
def test_deconstruct_grouping_line3():
    """grouping_line_3.svg parses but yields no data."""
    svg = (folder / "assets/grouping_line_3.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "line"
    assert info["chart"] == "grouping line"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 0
322 |
323 |
def test_deconstruct_bar():
    """bar.svg: full metadata (title, axis titles) and three bars."""
    svg = (folder / "assets/bar.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical bar"
    assert len(info["encoding"]) == 2
    assert info["title"] == "Number of Students Advised by Faculty Rank"
    assert info["encoding"]["x"]["title"] == "Faculty Rank"
    assert info["encoding"]["y"]["title"] == "Number of Students Advised"
    assert len(info["data"]) == 3
335 |
336 |
def test_deconstruct_bar1():
    """bar_1.svg: horizontal bar with seven bars."""
    svg = (folder / "assets/bar_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "horizontal bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 7
345 |
346 |
def test_deconstruct_bar6():
    """bar_6.svg: bar chart with 49 bars."""
    svg = (folder / "assets/bar_6.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 49
354 |
355 |
356 | # 1 bar
def test_deconstruct_bar7():
    """bar_2733_0.svg: horizontal bar with a single bar."""
    svg = (folder / "assets/bar_2733_0.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "horizontal bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 1
365 |
366 |
367 | # bar legend error
def test_deconstruct_bar8():
    """stacked_bar_2815.svg (legend error): resolves to two channels only."""
    svg = (folder / "assets/stacked_bar_2815.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert len(info["encoding"]) == 2
374 |
375 |
def test_deconstruct_bar9():
    """bar_186.svg: bar chart with 14 bars."""
    svg = (folder / "assets/bar_186.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 14
383 |
384 |
def test_deconstruct_bar10():
    """bar_3269.svg: bar chart with six bars."""
    with open(folder / "assets/bar_3269.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    assert chart_info["mark"] == "bar"
    assert len(chart_info["encoding"].keys()) == 2
    # Removed leftover debug print(chart_info); pytest reports the value on failure.
    assert len(chart_info["data"]) == 6
393 |
394 |
395 | # lida bar
def test_deconstruct_bar_lida1():
    """LIDA-generated bar_9.svg: vertical bar with four bars."""
    svg = (folder / "assets/bar_9.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 4
404 |
405 |
406 | # lida bar
def test_deconstruct_bar_lida2():
    """LIDA-generated bar_205.svg: vertical bar with two bars."""
    svg = (folder / "assets/bar_205.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical bar"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 2
415 |
416 |
def test_deconstruct_stacked_bar():
    """stacked_bar.svg: three channels with full metadata and seven segments."""
    svg = (folder / "assets/stacked_bar.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical stacked bar"
    assert len(info["encoding"]) == 3
    assert info["title"] == "Faculty by Rank/Sex"
    assert info["encoding"]["x"]["title"] == "Rank"
    assert info["encoding"]["y"]["title"] == "Number of faculty"
    assert len(info["data"]) == 7
428 |
429 |
def test_deconstruct_stacked_bar1():
    """stacked_bar_1.svg: three channels, nine segments."""
    svg = (folder / "assets/stacked_bar_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical stacked bar"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 9
438 |
439 |
def test_deconstruct_grouping_bar():
    """bar_189.svg: vertical grouping bar, three channels, four bars."""
    svg = (folder / "assets/bar_189.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "bar"
    assert info["chart"] == "vertical grouping bar"
    assert len(info["encoding"]) == 3
    assert len(info["data"]) == 4
448 |
449 |
def test_deconstruct_pie():
    """pie.svg: pie chart with a title and five slices."""
    svg = (folder / "assets/pie.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "arc"
    assert info["chart"] == "pie"
    assert len(info["encoding"]) == 2
    assert info["title"] == "Maximum Price of Each Film"
    assert len(info["data"]) == 5
459 |
460 |
461 | # 180 degree
def test_deconstruct_pie1():
    """pie_1.svg: three slices, one spanning exactly half (50%) of the pie."""
    svg = (folder / "assets/pie_1.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "arc"
    assert info["chart"] == "pie"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 3
    assert any(datum["field_theta"] == 50 for datum in info["data"])
476 |
477 |
478 | # shadow and > 180 degree
def test_deconstruct_pie2():
    """pie_2.svg (shadow, >180°): three slices, one covering more than half."""
    svg = (folder / "assets/pie_2.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "arc"
    assert info["chart"] == "pie"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 3
    assert any(datum["field_theta"] > 50 for datum in info["data"])
493 |
494 |
495 | # shadow and > 180 degree
def test_deconstruct_pie3():
    """pie_3.svg (shadow, >180°): two slices, one covering more than half."""
    svg = (folder / "assets/pie_3.svg").read_text()
    info, _ = deconstruct(svg)
    assert info["mark"] == "arc"
    assert info["chart"] == "pie"
    assert len(info["encoding"]) == 2
    assert len(info["data"]) == 2
    assert any(datum["field_theta"] > 50 for datum in info["data"])
510 |
511 |
512 | # 360 degree
def test_deconstruct_pie4():
    """pie_4.svg: a single slice covering the full circle (100%).

    Renamed from ``test_deconstruct_pie3``: that name was already defined
    above, so this duplicate definition shadowed the earlier test and pytest
    only ever collected one of the two.
    """
    with open(folder / "assets/pie_4.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    assert chart_info["mark"] == "arc"
    assert chart_info["chart"] == "pie"
    assert len(chart_info["encoding"].keys()) == 2
    assert len(chart_info["data"]) == 1
    # The single slice must account for the whole pie (the original inline
    # comment said "> 50 percent" but the check is for exactly 100).
    assert any(datum["field_theta"] == 100 for datum in chart_info["data"])
527 |
--------------------------------------------------------------------------------
/tests/test_layout_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import csv
5 | from pathlib import Path
6 |
7 | import pytest
8 |
9 | from viseval.check import layout_check
10 |
# Directory containing this test file; used to resolve SVG fixtures under assets/.
folder = Path(__file__).resolve().parent

# Path to the Chrome webdriver binary consumed by layout_check; adjust if
# chromedriver is installed elsewhere on the local machine.
webdriver_path = "/usr/bin/chromedriver"
14 |
15 |
def test_layout_check():
    """145@y_name@DESC.svg is known to overflow its canvas."""
    svg = (folder / "assets/readability/145@y_name@DESC.svg").read_text()
    expected = (False, "Overflow detected.")
    assert layout_check({"svg_string": svg}, webdriver_path) == expected
23 |
24 |
def test_layout_check_2():
    """2350.svg is readable: no overflow or overlap."""
    svg = (folder / "assets/readability/2350.svg").read_text()
    expected = (True, "No overflow or overlap detected.")
    assert layout_check({"svg_string": svg}, webdriver_path) == expected
32 |
33 |
def test_layout_check_3():
    """2652.svg is known to overflow its canvas."""
    svg = (folder / "assets/readability/2652.svg").read_text()
    expected = (False, "Overflow detected.")
    assert layout_check({"svg_string": svg}, webdriver_path) == expected
41 |
42 |
def test_layout_check_batch():
    """Run layout_check over every SVG listed in the human-rating CSV and
    compare against the recorded judgement.

    Column 5 holds the human layout verdict; "FALSE" means no layout problem
    was observed, so layout_check should pass exactly for those rows.
    """
    with open(folder / "assets/readability/readability_human_rating.csv") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            svg_id = row[0]

            with open(folder / f"assets/readability/{svg_id}.svg", "r") as svg_file:
                svg_string = svg_file.read()
            # Rewritten from the original `assert not result == (False if ...)`
            # double negation into the equivalent direct comparison (assumes
            # layout_check returns a bool in position 0, as the tests above
            # demonstrate). The debug print became the assertion message.
            assert layout_check({"svg_string": svg_string}, webdriver_path)[0] == (
                row[5] == "FALSE"
            ), svg_id
56 |
--------------------------------------------------------------------------------
/tests/test_order_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | from pathlib import Path
6 |
7 | import pytest
8 |
9 | from viseval.check import data_check, deconstruct, order_check
10 |
folder = Path(__file__).resolve().parent

# Ground-truth benchmark entries shared by every test in this module.
benchmark = json.loads((folder / "assets/samples.json").read_text())
15 |
16 |
def test_order_check_line_773():
    """x-axis ascending order for sample 773 (yyyy-mm-dd dates); flipping the
    ground truth to descending must make the check fail."""
    svg = (folder / "assets/line_773_0.svg").read_text()
    chart_info, _ = deconstruct(svg)
    entry = benchmark["773@x_name@DESC"]
    ground_truth = entry["vis_obj"]
    sort_by = entry["query_meta"][0]["sort_by"]

    # Run data_check first, as the original test did; its result is discarded.
    data_check(chart_info, ground_truth, entry["query_meta"][0]["channel_specified"])

    answer, _ = order_check(chart_info, ground_truth, sort_by)
    assert answer is True

    ground_truth["sort"]["order"] = "descending"
    answer, _ = order_check(chart_info, ground_truth, sort_by)
    assert answer is False
39 |
40 |
def test_order_check_bar_680():
    # y descending
    svg_string = (folder / "assets/stacked_bar_680_0.svg").read_text()
    chart_info, msg = deconstruct(svg_string)
    entry = benchmark["680@y_name@DESC"]
    ground_truth = entry["vis_obj"]
    query_meta = entry["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs
    answer, rationale = data_check(
        chart_info, ground_truth, query_meta[0]["channel_specified"]
    )

    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is True

    ground_truth["sort"]["order"] = "ascending"
    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is False
62 |
63 |
def test_order_check_bar_2815():
    # x descending
    # swap channel x-z
    svg_string = (folder / "assets/stacked_bar_2815_0.svg").read_text()
    chart_info, msg = deconstruct(svg_string)
    entry = benchmark["2815@x_name@DESC"]
    ground_truth = entry["vis_obj"]
    query_meta = entry["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs
    answer, rationale = data_check(
        chart_info, ground_truth, query_meta[0]["channel_specified"]
    )

    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is True

    ground_truth["sort"]["order"] = "ascending"
    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is False
86 |
87 |
# horizontal bar
def test_order_check_bar_2060():
    # y descending
    # swap channel x-y
    svg_string = (folder / "assets/bar_2060_0.svg").read_text()
    chart_info, msg = deconstruct(svg_string)
    entry = benchmark["2060@y_name@DESC"]
    ground_truth = entry["vis_obj"]
    query_meta = entry["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs
    answer, rationale = data_check(chart_info, ground_truth, [])

    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is False

    # x quantitative descending
    answer, rationale = order_check(chart_info, ground_truth, "attribute")
    assert answer is True

    ground_truth["sort"]["order"] = "ascending"
    answer, rationale = order_check(chart_info, ground_truth, "attribute")
    assert answer is False
111 |
112 |
# horizontal bar
def test_order_check_bar_1000():
    # y descending
    # swap channel x-y
    with open(folder / "assets/bar_1000_0.svg", "r") as f:
        svg_string = f.read()
    chart_info, msg = deconstruct(svg_string)
    ground_truth = benchmark["1000@y_name@DESC"]["vis_obj"]
    query_meta = benchmark["1000@y_name@DESC"]["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs.
    # (A leftover debug print of its result was removed.)
    answer, rationale = data_check(chart_info, ground_truth, [])
    answer, rationale = order_check(
        chart_info, ground_truth, query_meta[0]["sort_by"]
    )
    assert answer is False

    # x quantitative descending
    ground_truth["sort"]["channel"] = "x"
    answer, rationale = order_check(
        chart_info, ground_truth, query_meta[0]["sort_by"]
    )
    assert answer is True

    # y nominal descending
    ground_truth["sort"]["channel"] = "y"
    chart_info["encoding"]["y"]["scale"]["domain"] = [
        "sony",
        "jcrew",
        "gucci",
        "apple",
    ]
    answer, rationale = order_check(
        chart_info, ground_truth, query_meta[0]["sort_by"]
    )
    assert answer is True

    # custom order
    ground_truth["sort"]["order"] = ["sony", "jcrew", "gucci", "apple"]
    answer, rationale = order_check(
        chart_info, ground_truth, query_meta[0]["sort_by"]
    )
    assert answer is True
155 |
156 |
# invert y axis, zero value
def test_order_check_bar_550():
    svg_string = (folder / "assets/bar_550_0.svg").read_text()
    chart_info, msg = deconstruct(svg_string)
    entry = benchmark["550@y_name@DESC"]
    ground_truth = entry["vis_obj"]
    query_meta = entry["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs
    answer, rationale = data_check(
        chart_info, ground_truth, query_meta[0]["channel_specified"]
    )

    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is True
172 |
173 |
def test_order_check_bar_465():
    svg_string = (folder / "assets/bar_465.svg").read_text()
    chart_info, msg = deconstruct(svg_string)
    entry = benchmark["465@y_name@ASC"]
    ground_truth = entry["vis_obj"]
    query_meta = entry["query_meta"]
    # data_check fills in chart_info["channel_map"], which order_check needs
    answer, rationale = data_check(
        chart_info, ground_truth, query_meta[0]["channel_specified"]
    )

    answer, rationale = order_check(chart_info, ground_truth, query_meta[0]["sort_by"])
    assert answer is False
188 |
--------------------------------------------------------------------------------
/tests/test_surface_form_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 | from pathlib import Path
6 |
7 | import pytest
8 |
9 | from viseval.check import surface_form_check
10 |
# Directory containing this test file; used to resolve asset paths.
folder = Path(__file__).resolve().parent
12 |
def test_surface_form_check_empty():
    # Renamed from the duplicated name `test_valid_empty`: two functions with
    # that name existed in this module, so the later definition shadowed this
    # one and pytest never collected or ran it.
    with open(folder / "assets/empty_3.svg", "r") as f:
        svg_string = f.read()
    answer, msg = surface_form_check(svg_string)
    # An (effectively) empty chart must fail the surface-form check.
    assert answer == False
18 |
19 |
def test_surface_form_check_pie():
    # Renamed from `test_valid_empty`, which collided with the identically
    # named test earlier in this module and silently disabled it.
    with open(folder / "assets/pie_4_0.svg", "r") as f:
        svg_string = f.read()
    answer, msg = surface_form_check(svg_string)
    # A real pie chart must pass the surface-form check.
    assert answer == True
--------------------------------------------------------------------------------
/viseval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .dataset import Dataset
5 | from .evaluate import Evaluator
6 |
--------------------------------------------------------------------------------
/viseval/agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Optional, Tuple, Union
6 |
7 | from attr import dataclass
8 | from langchain.chat_models.base import BaseChatModel
9 | from llmx import TextGenerator
10 |
11 |
@dataclass
class ChartExecutionResult:
    """Response from a visualization execution"""

    # True if the execution succeeded, False otherwise
    status: bool
    # Generated SVG string; only populated when status is True
    svg_string: Optional[str] = None
    # Error message; only populated when status is False
    error_msg: Optional[str] = None
22 |
23 |
class Agent(ABC):
    """Abstract base class for visualization-generation agents.

    Concrete agents wrap an LLM backend, turn a natural-language query over
    data tables into plotting code, and execute that code to obtain an SVG
    (see ChartExecutionResult).
    """

    @abstractmethod
    def __init__(
        self, llm: Union[BaseChatModel, TextGenerator], config: Optional[dict] = None
    ) -> None:
        """Create the agent around an LLM backend with optional configuration."""
        pass

    @abstractmethod
    def generate(
        self, nl_query: str, tables: list[str], config: dict
    ) -> Tuple[str, dict]:
        """Generate code for the given natural language query.

        Args:
            nl_query (str): Natural language query.
            tables (list[str]): List of table file paths.
            config (dict): Generation configuration.

        Returns:
            Tuple[str, dict]: Generated code and execution context.
        """
        pass

    @abstractmethod
    def execute(
        self, code: str, context: dict, log_name: Optional[str] = None
    ) -> ChartExecutionResult:
        """Execute the given code with context and return the result.

        Args:
            code (str): Code to execute.
            context (dict): Context for the code execution. Different agents require different contexts.
            log_name (str, optional): SVG file name. Defaults to None.

        Returns:
            ChartExecutionResult: Execution status plus the SVG string on
                success or an error message on failure.
        """
        pass
62 |
--------------------------------------------------------------------------------
/viseval/check/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .chart_check import chart_check
5 | from .data_check import data_check
6 | from .deconstruct import deconstruct
7 | from .layout_check import layout_check
8 | from .order_check import order_check
9 | from .readability_check import readability_check
10 | from .scale_and_ticks_check import scale_and_ticks_check
11 | from .surface_form_check import surface_form_check
12 |
--------------------------------------------------------------------------------
/viseval/check/chart_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
def chart_check(
    chart_info: dict, chart_type_ground_truth: str, stacked_bar: bool = False
):
    """Check whether the deconstructed chart's type matches the ground truth.

    Ground-truth types: pie, bar, line, scatter, stacked bar, grouping line,
    grouping scatter. A "grouping bar" chart is also accepted for a
    "Stacked Bar" ground truth when true stacking is not required
    (stacked_bar is False).
    """
    if "chart" not in chart_info:
        return False, "Cannot recognize the chart type."

    detected = chart_info["chart"].lower()
    expected = chart_type_ground_truth.lower()

    matches = expected in detected
    if not matches and chart_type_ground_truth == "Stacked Bar" and not stacked_bar:
        matches = "grouping bar" in detected

    if matches:
        return True, "Chart type is consistent with ground truth."
    return False, "Chart type is not consistent with ground truth."
22 |
--------------------------------------------------------------------------------
/viseval/check/data_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import copy
5 | import json
6 |
7 | from .time_utils import (
8 | compare_time_strings,
9 | convert_month_or_weekday_to_int,
10 | is_month_or_weekday,
11 | parse_number_to_time,
12 | parse_time_to_timestamp,
13 | )
14 |
# Tolerance used for absolute and relative floating-point comparisons.
PRECISION = 0.0005
16 |
17 |
def is_numeric(s):
    """Return True when *s* parses as a float, False on ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    return True
24 |
def compare_string(string, ground_truth):
    """Case-insensitive prefix match after trimming surrounding whitespace."""
    lhs = string.strip().lower()
    rhs = ground_truth.strip().lower()
    return lhs.startswith(rhs)
27 |
def convert_ground_truth_data(ground_truth):
    """Flatten the benchmark's columnar ground truth into a list of row dicts.

    Each row carries "field_x"/"field_y" and, when a classify series exists,
    "field_classify".
    """
    x_data = ground_truth["x_data"]
    y_data = ground_truth["y_data"]
    classify = ground_truth["classify"]

    rows = []
    if not classify:
        # Single series: parallel x/y arrays.
        xs, ys = x_data[0], y_data[0]
        for i in range(len(xs)):
            rows.append({"field_x": xs[i], "field_y": ys[i]})
    elif len(classify) == len(x_data) and len(classify) == len(y_data):
        # Scatter: one x array and one y array per class label.
        for i, label in enumerate(classify):
            for j in range(len(x_data[i])):
                rows.append(
                    {
                        "field_x": x_data[i][j],
                        "field_y": y_data[i][j],
                        "field_classify": label,
                    }
                )
    elif len(classify) == len(y_data):
        # Line / bar: one shared x array, one y array per class label.
        for i, label in enumerate(classify):
            for j in range(len(x_data[0])):
                rows.append(
                    {
                        "field_x": x_data[0][j],
                        "field_y": y_data[i][j],
                        "field_classify": label,
                    }
                )
    return rows
69 |
70 |
# return answer and rationale for the data check
def compare_data(data_ground_truth, chart_info):
    """Match every ground-truth row against the rows extracted from the chart.

    Pie charts ("arc" mark) are matched by slice proportion: a common scale
    factor is inferred from the first matched slice and all others must agree
    with it. Other marks are matched field by field, with tolerance PRECISION
    for quantitative/temporal values; line charts are additionally allowed to
    omit points when the neighbouring plotted points make the omission
    visually equivalent.

    Returns:
        (True, rationale) when all ground-truth rows are found, else
        (False, reason).
    """
    # deep copy: avoid to change origin data
    data = copy.deepcopy(chart_info["data"])
    encoding = chart_info["encoding"]
    # line chart may have different length of data because line chart might omit some values
    if (len(data) != len(data_ground_truth) and chart_info["mark"] != "line") or (
        len(data) > len(data_ground_truth)
    ):
        return (
            False,
            f"visualization data length {len(data)} != ground truth length {len(data_ground_truth)}.",
        )

    if chart_info["mark"] == "arc":
        # field_fill -> field_x
        field_x = "field_fill"
        # field_theta -> field_y relative
        field_y = "field_theta"

        # scale: ratio of slice angle to ground-truth value, fixed by the
        # first matched slice; every other slice must match it.
        scale = None
        for datum_ground_truth in data_ground_truth:
            datum = [
                x
                for x in data
                if compare_string(str(x[field_x]), str(datum_ground_truth["field_x"]))
            ]
            if len(datum) == 1:
                datum = datum[0]
                if scale is None:
                    scale = datum[field_y] / datum_ground_truth["field_y"]
                else:
                    if (
                        abs(datum[field_y] / datum_ground_truth["field_y"] - scale)
                        > PRECISION
                    ):
                        return False, f"{json.dumps(datum_ground_truth)} not found\n"
            elif len(datum) == 0:
                return False, f"{json.dumps(datum_ground_truth)} not found."
            elif len(datum) > 1:
                return False, f"{json.dumps(datum_ground_truth)} found more than one."
    else:
        if len(encoding.keys()) < len(data_ground_truth[0].keys()):
            return False, "too few channels\n"
        elif len(encoding.keys()) > 3:
            return False, "too many channels\n"
        elif len(encoding.keys()) > len(data_ground_truth[0].keys()):
            # only keep x and y for comparison
            encoding = {key: encoding[key] for key in ["x", "y"]}

        for datum_ground_truth in data_ground_truth:
            # progressively filter the chart rows down to those matching this
            # ground-truth row on every encoded channel
            datum = data
            for key in encoding:
                if key == "x" or key == "y":
                    field = f"field_{key}"
                else:
                    field = "field_classify"
                if (
                    encoding[key]["type"] == "quantitative"
                    or encoding[key]["type"] == "temporal"
                ):
                    # e.g., Monday -> 1
                    if is_month_or_weekday(datum_ground_truth[field]):
                        if encoding[key]["type"] == "temporal":
                            datum = [
                                x
                                for x in datum
                                if compare_time_strings(
                                    x["field_" + key].strip(),
                                    str(datum_ground_truth[field]).strip(),
                                )
                            ]
                            value_ground_truth = None
                        else:
                            value_ground_truth = convert_month_or_weekday_to_int(
                                datum_ground_truth[field]
                            )
                    else:
                        try:
                            value_ground_truth = float(datum_ground_truth[field])
                        except Exception:
                            # not a number
                            # try temporal
                            try:
                                value_ground_truth = parse_time_to_timestamp(
                                    datum_ground_truth[field]
                                )
                                if value_ground_truth is None:
                                    return (
                                        False,
                                        f"The data type of {key}({encoding[key]['type']}) is wrong.",
                                    )

                                if encoding[key]["type"] == "quantitative":
                                    for x in datum:
                                        x["field_" + key] = parse_number_to_time(
                                            x["field_" + key]
                                        )
                            except Exception:
                                return (
                                    False,
                                    f"The data type of {key}({encoding[key]['type']}) is wrong.",
                                )

                    # temporal values live in a parallel "field_<key>_origin"
                    field_vis = (
                        "field_" + key
                        if encoding[key]["type"] != "temporal"
                        else "field_" + key + "_origin"
                    )
                    if value_ground_truth is not None:
                        # avoid division by zero
                        datum = [
                            x
                            for x in datum
                            if abs((x[field_vis] - value_ground_truth)) <= PRECISION
                            or (
                                value_ground_truth != 0
                                and abs(
                                    (x[field_vis] - value_ground_truth)
                                    / value_ground_truth
                                )
                                <= PRECISION
                            )
                        ]
                elif encoding[key]["type"] == "nominal":
                    # exact match or time match
                    datum = [
                        x
                        for x in datum
                        if compare_string(x["field_" + key], str(datum_ground_truth[field]))
                        or compare_time_strings(
                            x["field_" + key].strip(),
                            str(datum_ground_truth[field]).strip(),
                        )
                    ]

            if len(datum) == 0:
                if (
                    chart_info["mark"] == "line"
                    and (
                        encoding["x"]["type"] == "quantitative"
                        or encoding["x"]["type"] == "temporal"
                    )
                    and (encoding["y"]["type"] == "quantitative")
                ):
                    # use all data
                    datum = chart_info["data"]
                    if len(data_ground_truth[0].keys()) == 3:
                        datum = [
                            x
                            for x in datum
                            if x["field_stroke"].strip()
                            == datum_ground_truth["field_classify"].strip()
                        ]
                    # line chart might omit some values
                    # find the nearest plotted x positions on either side of
                    # the missing ground-truth x
                    min_larger_index = -1
                    max_smaller_index = -1
                    # convert
                    field_vis = (
                        "field_x"
                        if encoding["x"]["type"] != "temporal"
                        else "field_x_origin"
                    )
                    value_ground_truth = datum_ground_truth["field_x"]
                    try:
                        value_ground_truth = float(value_ground_truth)
                    except Exception:
                        value_ground_truth = parse_time_to_timestamp(value_ground_truth)

                    for index in range(len(datum)):
                        if datum[index][field_vis] > value_ground_truth:
                            if (
                                min_larger_index == -1
                                or datum[index][field_vis]
                                < datum[min_larger_index][field_vis]
                            ):
                                min_larger_index = index
                        elif datum[index][field_vis] < value_ground_truth:
                            if (
                                max_smaller_index == -1
                                or datum[index][field_vis]
                                > datum[max_smaller_index][field_vis]
                            ):
                                max_smaller_index = index
                    if min_larger_index != -1 and max_smaller_index != -1:
                        # the omitted point is acceptable when the surrounding
                        # segment is flat, or the right neighbour already
                        # carries the expected y value
                        if (
                            abs(
                                (
                                    datum[min_larger_index]["field_y"]
                                    - datum[max_smaller_index]["field_y"]
                                )
                            )
                            <= PRECISION
                            or abs(
                                (
                                    datum[min_larger_index]["field_y"]
                                    - datum_ground_truth["field_y"]
                                )
                            )
                            <= PRECISION
                        ):
                            continue
                    else:
                        return (
                            False,
                            f"{json.dumps(datum_ground_truth)} not found\n",
                        )
                return False, f"{json.dumps(datum_ground_truth)} not found."
            else:
                # consume the matched row so duplicates must each match once
                data.remove(datum[0])

    return True, "The data on the charts is consistent with the ground truth."
283 |
284 |
def data_check(chart_info: dict, data: dict, channel_specified: list):
    """Check the chart's data against ground truth, trying channel swaps.

    Builds one candidate ground-truth table per permitted channel mapping
    (channels listed in channel_specified are never swapped) and accepts the
    first candidate that compare_data matches.

    Side effects: for bar charts, zero-valued rows are filtered out of
    chart_info["data"]; on success chart_info["channel_map"] is set to the
    chart-channel -> ground-truth-channel mapping that matched.

    Returns:
        (answer, rationale); on failure, the rationale from the identity
        mapping (first candidate) is reported.
    """
    if ("data" not in chart_info) or (len(chart_info["data"]) == 0):
        return False, "The data on the charts cannot be understood."

    data_ground_truth = convert_ground_truth_data(data)
    # filter zero data
    if chart_info["mark"] == "bar":
        data_ground_truth = list(
            filter(lambda x: x["field_x"] != 0 and x["field_y"] != 0, data_ground_truth)
        )
        chart_info["data"] = list(
            filter(
                lambda x: x["field_x"] != 0 and x["field_y"] != 0,
                chart_info["data"],
            )
        )
    candidates = [data_ground_truth]
    # ground truth channel -> chart channel
    channel_maps = []
    if len(data["classify"]) == 0:
        # 2 channels
        channel_maps.append({"x": "x", "y": "y"})
        if "x" not in channel_specified and "y" not in channel_specified:
            # swap x and y
            data_ground_truth_copy = copy.deepcopy(data_ground_truth)
            for datum in data_ground_truth_copy:
                datum["field_x"], datum["field_y"] = (
                    datum["field_y"],
                    datum["field_x"],
                )
            candidates.append(data_ground_truth_copy)
            channel_maps.append({"x": "y", "y": "x"})
    else:
        # 3 channels
        channel_maps.append({"x": "x", "y": "y", "classify": "classify"})
        channels = ["x", "y", "classify"]
        if len(channel_specified) <= 1:
            # swap the two channels that are NOT pinned down by the query
            # NOTE(review): with one channel specified the pair order comes
            # from a set difference — order is arbitrary but a swap is
            # symmetric, so the candidate is the same either way.
            swap_channels = list(set(channels) - set(channel_specified))
            data_ground_truth_copy = copy.deepcopy(data_ground_truth)
            for datum in data_ground_truth_copy:
                (
                    datum[f"field_{swap_channels[0]}"],
                    datum[f"field_{swap_channels[1]}"],
                ) = (
                    datum[f"field_{swap_channels[1]}"],
                    datum[f"field_{swap_channels[0]}"],
                )
            candidates.append(data_ground_truth_copy)
            channel_map = {"x": "x", "y": "y", "classify": "classify"}
            channel_map[swap_channels[0]], channel_map[swap_channels[1]] = (
                channel_map[swap_channels[1]],
                channel_map[swap_channels[0]],
            )
            channel_maps.append(channel_map)
            if len(channel_specified) == 0:
                # swap x and z
                data_ground_truth_copy = copy.deepcopy(candidates[0])
                for datum in data_ground_truth_copy:
                    datum["field_x"], datum["field_classify"] = (
                        datum["field_classify"],
                        datum["field_x"],
                    )
                candidates.append(data_ground_truth_copy)
                channel_maps.append({"x": "classify", "y": "y", "classify": "x"})
                # swap x and y, x and z
                data_ground_truth_copy = copy.deepcopy(candidates[1])
                for datum in data_ground_truth_copy:
                    datum["field_y"], datum["field_classify"] = (
                        datum["field_classify"],
                        datum["field_y"],
                    )
                candidates.append(data_ground_truth_copy)
                channel_maps.append({"x": "y", "y": "classify", "classify": "x"})
    # cache holds the identity-mapping result to report on total failure
    cache = None
    for i in range(len(candidates)):
        candidate = candidates[i]
        channel_map = channel_maps[i]
        answer, rationale = compare_data(candidate, chart_info)
        if answer:
            chart_info["channel_map"] = channel_map
            return answer, rationale
        if i == 0:
            cache = [answer, rationale]

    return cache[0], cache[1]
370 |
--------------------------------------------------------------------------------
/viseval/check/layout_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import os
5 | import uuid
6 |
# JavaScript run inside headless Chrome by layout_check: returns true when
# the main plot area (#axes_1) has a bounding box extending beyond the SVG's
# viewBox, i.e. chart content overflows the canvas.
overflowScript = """const isElementOutsideCanvas = (element) => {
    // Get the bounding box of the SVG element
    let svgRect = element.getBBox();

    // Get the viewBox attribute of the SVG element
    let svgElement = document.querySelector('svg');
    let viewBoxValue = svgElement.getAttribute('viewBox');
    let viewBoxArray = viewBoxValue.split(' ').map(Number);
    let minX = viewBoxArray[0];
    let minY = viewBoxArray[1];
    let width = viewBoxArray[2];
    let height = viewBoxArray[3];

    // Check if any part of the element is outside the canvas
    if (svgRect.x < minX || svgRect.y < minY || svgRect.x + svgRect.width > width+minX || svgRect.y + svgRect.height > height+minY) {
        return true;
    } else {
        return false;
    }
}

let svgElement = document.querySelector('#axes_1');
return isElementOutsideCanvas(svgElement);
"""

# JavaScript run inside headless Chrome by layout_check: collects bounding
# boxes of the #text_<i> labels under #axes_1 (skipping labels rotated well
# away from horizontal/vertical) and returns true when any two boxes
# intersect, i.e. text labels overlap.
overlapScript = """const isTextOverlap = (parentElement) => {
    let index = 1;
    let flag = true;
    let textArray = [];
    while (flag) {
        let textElement = parentElement.querySelector('#text_' + index);
        if (textElement) {
            index++;
            let gElement = textElement.querySelector('g');
            let transform = gElement.getAttribute('transform');
            if (transform) {
                const rotateMatch = transform.match(/rotate\(([^)]+)\)/);
                let notRotate = true;
                if (rotateMatch) {
                    const rotateValue = parseFloat(rotateMatch[1]) % 90;
                    if (Math.abs(rotateValue) > 10 && Math.abs(rotateValue) < 80) {
                        notRotate = false;
                    }
                }

                if (notRotate) {
                    let bbox = textElement.getBBox();
                    textArray.push(bbox);
                }
            }
        }
        else {
            flag = false;
        }
    }
    let textOverlap = false;
    let textArrayLength = textArray.length;
    for (let i = 0; i < textArrayLength; i++) {
        for (let j = i + 1; j < textArrayLength; j++) {
            let text1Rect = textArray[i];
            let text2Rect = textArray[j];
            if (text1Rect.x < text2Rect.x + text2Rect.width &&
                text1Rect.x + text1Rect.width > text2Rect.x &&
                text1Rect.y < text2Rect.y + text2Rect.height &&
                text1Rect.y + text1Rect.height > text2Rect.y) {
                textOverlap = true;
                break;
            }
        }
    }
    return textOverlap;
}

let svgElement = document.querySelector('#axes_1');
return isTextOverlap(svgElement)
"""
83 |
84 |
def layout_check(context: dict, webdriver_path):
    """Render the SVG in headless Chrome and test for overflow / overlap.

    Args:
        context: Dict with key "svg_string", the SVG markup to check.
        webdriver_path: Path to a chromedriver binary, or None to skip.

    Returns:
        (True, msg) when neither overflow nor text overlap is detected,
        (False, msg) naming the problem(s) found, or
        (None, msg) when no webdriver is available or the check errored.
    """
    svg_string = context["svg_string"]

    if webdriver_path is None:
        return None, "No webdriver path provided."

    from selenium import webdriver
    from selenium.webdriver.chrome.service import Service

    # Chrome headless mode
    options = webdriver.ChromeOptions()
    options.add_argument("--headless")
    options.add_argument("--disable-gpu")

    # Write the SVG to a uniquely named temp file so Chrome can load it.
    file_path = f"{os.getcwd()}/temp_{uuid.uuid1()}.svg"
    # Use a distinct name for the driver instance: the original code rebound
    # `webdriver` itself, shadowing the imported module.
    driver = None
    try:
        service = Service(webdriver_path)
        driver = webdriver.Chrome(service=service, options=options)

        with open(file_path, "w") as svg_file:
            svg_file.write(svg_string)

        driver.get(f"file://{file_path}")
        # The scripts return true when a problem exists; invert to "is OK".
        no_overflow = not driver.execute_script(overflowScript)
        no_overlap = not driver.execute_script(overlapScript)

        if no_overflow and no_overlap:
            msg = "No overflow or overlap detected."
        elif not no_overflow and not no_overlap:
            msg = "Overflow and overlap detected."
        elif not no_overflow:
            msg = "Overflow detected."
        else:
            msg = "Overlap detected."
        return no_overflow and no_overlap, msg
    except Exception as e:
        print(e)
        # Bug fix: previously this fell through to the misleading
        # "No webdriver path provided." message even when a path was given.
        return None, "Layout check failed with an exception."
    finally:
        # Always release the browser and the temp file, even on errors.
        if driver is not None:
            driver.quit()
        try:
            os.remove(file_path)
        except OSError as e:
            print(f"Remove file error: {e}")
128 |
--------------------------------------------------------------------------------
/viseval/check/order_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
def order_check(chart_info: dict, ground_truth: dict, sort_by: str):
    """Verify that the chart is sorted as ground_truth["sort"] requires.

    Requires chart_info["channel_map"], which data_check fills in on success.
    When sort_by is "axis", the sort channel refers to the chart's own axis;
    otherwise it is translated through the channel map.

    Returns:
        (True, ...) when no sorting is required or the chart is sorted as
        requested, else (False, "Doesn't sort.").
    """
    order = ground_truth["sort"]
    encoding = chart_info["encoding"]
    data = chart_info["data"]
    channel_map = chart_info["channel_map"]

    if order is not None:
        # bar, line
        if sort_by == "axis":
            order_channel = order["channel"]
        else:
            order_channel = channel_map[order["channel"]]

        other_channel = "y" if order_channel == "x" else "x"

        order_channel_scale = encoding[order_channel]["scale"]
        other_channel_scale = encoding[other_channel]["scale"]

        # origin channel
        if (
            channel_map[order_channel] == "x"
            or channel_map[order_channel] == "classify"
        ):
            # Categorical axis: pair each domain label with its screen
            # position (or a synthetic 1..n position when no range is given).
            arr = []
            if len(order_channel_scale["range"]) == 0:
                scale_range = range(1, 1 + len(order_channel_scale["domain"]))
            else:
                scale_range = order_channel_scale["range"]

            for index in range(len(order_channel_scale["domain"])):
                arr.append(
                    tuple(
                        [
                            order_channel_scale["domain"][index],
                            scale_range[index],
                        ]
                    )
                )
            # NOTE(review): the reverse flag accounts for axis direction
            # (x grows rightward, y pixel coordinates grow downward) so that
            # the strict ">" check below works for both axes — confirm
            # against deconstruct's range conventions.
            if order_channel == "x":
                reverse = True
            else:
                reverse = False
            if order["order"] == "ascending":
                arr.sort(key=lambda x: x[0], reverse=reverse)
            elif order["order"] == "descending":
                arr.sort(key=lambda x: x[0], reverse=not reverse)
            else:  # custom order
                sort_order = {}
                for index in range(len(order["order"])):
                    sort_order[order["order"][index]] = index
                arr.sort(key=lambda x: sort_order[x[0]], reverse=reverse)

            # After the reverse-aware sort, positions must strictly decrease.
            is_sorted = all([arr[i][1] > arr[i + 1][1] for i in range(len(arr) - 1)])
        # 'quantitative'
        else:
            # sort by other channel
            values_other = []
            if (
                "type" not in other_channel_scale
                or other_channel_scale["type"] == "ordinal"
            ):
                # Ordinal axis: order its categories by screen position.
                for index in range(len(other_channel_scale["domain"])):
                    values_other.append(
                        tuple(
                            [
                                other_channel_scale["domain"][index],
                                other_channel_scale["range"][index],
                            ]
                        )
                    )
                values_other.sort(key=lambda x: x[1])
                values_other = [item[0] for item in values_other]
            else:
                # Continuous axis: order the distinct data values, flipping
                # direction when the scale maps increasing domain to
                # decreasing range (inverted axis).
                values_other = list(
                    set([datum["field_" + other_channel] for datum in data])
                )
                values_other.sort(
                    reverse=True
                    if (
                        other_channel_scale["domain"][1]
                        - other_channel_scale["domain"][0]
                    )
                    / (
                        other_channel_scale["range"][1]
                        - other_channel_scale["range"][0]
                    )
                    < 0
                    else False
                )

            # cumulative
            values_order = []
            for value in values_other:
                data_filter = [
                    float(d["field_" + order_channel])
                    for d in data
                    if d["field_" + other_channel] == value
                ]
                values_order.append(sum(data_filter))

            # filter zero data
            if chart_info["mark"] == "bar":
                values_order = list(filter(lambda x: x != 0, values_order))

            is_sorted = True
            if order["order"] == "ascending":
                is_sorted = all(
                    [
                        values_order[i] <= values_order[i + 1]
                        for i in range(len(values_order) - 1)
                    ]
                )
            elif order["order"] == "descending":
                is_sorted = all(
                    [
                        values_order[i] >= values_order[i + 1]
                        for i in range(len(values_order) - 1)
                    ]
                )

        if not is_sorted:
            return False, "Doesn't sort."
        else:
            return True, "Sorted."
    else:
        return True, "No sort."
131 |
--------------------------------------------------------------------------------
/viseval/check/readability_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from __future__ import annotations
5 |
6 | import json
7 | import sys
8 | import warnings
9 |
10 | from langchain.chat_models.base import BaseChatModel
11 | from langchain.schema import HumanMessage, SystemMessage
12 |
# System prompt for the vision model used by readability_check: requests a
# readability rating from 1 (very hard to read) to 5 (very easy) plus a brief
# rationale, formatted as a JSON object with "Rationale" and "Score" keys.
INSTRUCTION = """Your task is to evaluate the readability of the visualization on a scale of 1 to 5, where 1 indicates very difficult to read and 5 indicates very easy to read. You will be given a visualization requirement and the corresponding visualization created based on that requirement. Additionally, reviews from others regarding this visualization will be provided for your reference. Please think carefully and provide your reasoning and score.
```
{
    "Rationale": "a brief reason",
    "Score": 1-5
}
```


Examples:
- If the visualization is clear and information can be easily interpreted, you might return:
```
{
    "Rationale": "The chart is well-organized, and the use of contrasting colors helps in distinguishing different data sets effectively. The labels are legible, and the key insights can be understood at a glance.",
    "Score": 5
}
```
- Conversely, if the visualization is cluttered or confusing, you might return:
```
{
    "Rationale": "While there is no overflow or overlap, the unconventional inverted y-axis and the use of decimal numbers for months on the x-axis deviate from the standard interpretation of bar charts, confusing readers and significantly affecting the chart's readability.",
    "Score": 1
}
```
"""
38 |
39 |
def readability_check(context: dict, query: str, vision_model: BaseChatModel):
    """Ask a vision-capable chat model to rate chart readability (1-5).

    Args:
        context: Must contain "base64" (data URL of the chart image); may
            contain "reviews", a list of {"aspect", "content"} dicts that are
            forwarded to the model as extra context.
        query: The natural-language visualization requirement.
        vision_model: Chat model that accepts image_url message parts.

    Returns:
        (score, rationale) on success, or (None, "Exception occurred.") when
        the model call or response parsing fails (a warning is emitted).
    """
    base64 = context["base64"]

    reviews = ""
    if "reviews" in context and len(context["reviews"]) > 0:
        reviews = "Other Reviews:\n"
        reviews += "\n".join(
            [
                f"""- {review["aspect"]}: {review["content"]}"""
                for review in context["reviews"]
            ]
        )
        reviews += "\n\n"

    try:
        messages = [
            SystemMessage(content=INSTRUCTION),
        ]
        messages.append(
            HumanMessage(
                content=[
                    {
                        "type": "text",
                        "text": f"""Visualization Requirement: {query}\n\n{reviews}Visualization image:""",
                    },
                    {
                        "type": "image_url",
                        "image_url": base64,
                    },
                    {
                        "type": "text",
                        "text": """Please assess the readability, taking into account factors such as layout, scale and ticks, title and labels, colors, and ease of extracting information. Do not consider the correctness of the data and order in the visualizations, as they have already been verified.""",
                    },
                ]
            )
        )

        response = vision_model.invoke(messages)

        # Strip markdown code fences the model may wrap the JSON in.
        json_string = (
            response.content.replace("```json\n", "").replace("```", "").strip()
        )
        try:
            result = json.loads(json_string)
        except Exception:
            # Fallback for Python-literal style output (e.g. single quotes).
            # Security fix: ast.literal_eval only parses literals, unlike the
            # previous eval(), which would execute arbitrary model-controlled
            # code.
            import ast

            result = ast.literal_eval(json_string)

        return result["Score"], result["Rationale"]
    except Exception:
        warnings.warn(str(sys.exc_info()))
        return None, "Exception occurred."
91 |
--------------------------------------------------------------------------------
/viseval/check/scale_and_ticks_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from __future__ import annotations
5 |
6 | import json
7 | import sys
8 | import warnings
9 |
10 | from langchain.chat_models.base import BaseChatModel
11 | from langchain.schema import HumanMessage, SystemMessage
12 |
13 | INSTRUCTION = """You will be provided with a visualization and its specifications. Consider the following aspect:
14 |
15 | - If the scale selected for the visualization is appropriate for accurate interpretation of values, avoid using unconventional scales, such as an inverted y-axis scale.
16 | - When axes are present, ensure that the selected ticks are appropriate for clarity, avoiding unconventional choices, such as representing counts of individual entities with floating-point ticks.
17 |
18 |
19 | Report your findings, focusing solely on scale and tick appropriateness without considering the order.
20 | ```
21 | {
22 | "Appropriate": true/false,
23 | "Rationale": "reason ..."
24 | }
25 | ```
26 | """
27 |
28 |
def scale_and_ticks_check(context: dict, query: str, vision_model: BaseChatModel):
    """Ask the vision model whether the chart's scale and ticks are appropriate.

    Args:
        context: must contain "base64" (chart image data-URI), "encoding"
            (deconstructed axis info with x/y scale ticks) and "chart"
            (chart-type string, e.g. "pie").
        query: the visualization specification text.
        vision_model: a chat model whose messages accept image_url parts.

    Returns:
        (appropriate, rationale) on success; (None, "Exception occurred.")
        on any failure.
    """
    import ast

    # Renamed from `base64` to avoid shadowing the stdlib module name.
    image_base64 = context["base64"]
    encoding = context["encoding"]
    chart = context["chart"]
    if chart == "pie":
        # Pie charts have no axes, hence no ticks to describe.
        ticks_desc = ""
    else:
        x_ticks = encoding["x"]["scale"]["ticks"]
        y_ticks = encoding["y"]["scale"]["ticks"]
        ticks_desc = (
            "Ticks extracted from the visualization:\n"
            f"- x axis ticks: {','.join(x_ticks)}\n"
            f"- y axis ticks: {','.join(y_ticks)}\n\n"
        )
    try:
        messages = [
            SystemMessage(content=INSTRUCTION),
            HumanMessage(
                content=[
                    {
                        "type": "text",
                        "text": f"Visualization specification:{query}\n\n{ticks_desc}Visualization image that has been verified for DATA and ORDER accuracy:",
                    },
                    {
                        "type": "image_url",
                        "image_url": image_base64,
                    },
                ]
            ),
        ]

        response = vision_model.invoke(messages)

        # Strip a possible markdown code fence around the JSON reply.
        json_string = (
            response.content.replace("```json\n", "").replace("```", "").strip()
        )
        try:
            result = json.loads(json_string)
        except Exception:
            # Safe fallback for Python-literal replies; never eval()
            # untrusted model output.
            result = ast.literal_eval(json_string)
        return result["Appropriate"], result["Rationale"]
    except Exception:
        warnings.warn(str(sys.exc_info()))
        return None, "Exception occurred."
71 |
--------------------------------------------------------------------------------
/viseval/check/surface_form_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | from xml.dom import minidom
4 |
def is_group_children(node):
    """Return True when *node* is an SVG <g> (group) element node."""
    if node.nodeType != node.ELEMENT_NODE:
        return False
    return node.tagName == "g"
7 |
def surface_form_check(svg_string):
    """
    Check if the code has plotted visualization.

    Heuristic: a real chart's SVG contains <g> group elements under the
    root (or, when there is a single wrapper group, at least two groups
    inside that wrapper).
    """
    doc = minidom.parseString(svg_string)
    root = doc.getElementsByTagName("svg")[0]

    def group_elements(node):
        # Direct <g> element children of *node*.
        return [
            child
            for child in node.childNodes
            if child.nodeType == child.ELEMENT_NODE and child.tagName == "g"
        ]

    groups = group_elements(root)
    if not groups:
        return False, "Did not plot visualization."

    if len(groups) == 1 and len(group_elements(groups[0])) < 2:
        return False, "Did not plot visualization."

    return True, "Plotted visualization."
--------------------------------------------------------------------------------
/viseval/check/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from datetime import datetime
5 |
6 | from dateutil import parser
7 |
# Canonical lowercase month/weekday names keyed by their accepted short
# forms. The original literal repeated the seven weekday keys a second
# time (duplicate dict keys are silently overwritten); the duplicates are
# removed and the extra "thur" spelling kept.
TIME_MAP = {
    "mon": "monday",
    "tue": "tuesday",
    "wed": "wednesday",
    "thu": "thursday",
    "thur": "thursday",
    "fri": "friday",
    "sat": "saturday",
    "sun": "sunday",
    "jan": "january",
    "feb": "february",
    "mar": "march",
    "apr": "april",
    "may": "may",
    "jun": "june",
    "jul": "july",
    "aug": "august",
    "sep": "september",
    "sept": "september",
    "oct": "october",
    "nov": "november",
    "dec": "december",
}
# Ordered three-letter abbreviations; position + 1 gives the 1-based index
# used by convert_month_or_weekday_to_int.
WEEKDAYS = ["mon", "tue", "wed", "thu", "fri", "sat", "sun"]
MONTHS = [
    "jan",
    "feb",
    "mar",
    "apr",
    "may",
    "jun",
    "jul",
    "aug",
    "sep",
    "oct",
    "nov",
    "dec",
]
53 |
54 |
def is_month_or_weekday(s: str):
    """Return True when *s* is a known month/weekday name or abbreviation."""
    if not isinstance(s, str):
        return False
    lowered = s.lower()
    # Either the abbreviation itself ("mon", "sept", ...) ...
    if lowered in TIME_MAP:
        return True
    # ... or the corresponding full name ("monday", "september", ...).
    prefix = lowered[:3]
    return prefix in TIME_MAP and lowered == TIME_MAP[prefix]
62 |
63 |
def convert_month_or_weekday_to_int(s: str) -> int:
    """Map a month/weekday name or abbreviation to its 1-based index.

    Weekdays map to 1-7 (Monday first), months to 1-12; unknown input
    yields -1.
    """
    if not is_month_or_weekday(s):
        return -1
    prefix = s[0:3].lower()
    if prefix in WEEKDAYS:
        return WEEKDAYS.index(prefix) + 1
    if prefix in MONTHS:
        return MONTHS.index(prefix) + 1
    return -1
71 |
72 |
def is_datetime(s):
    """Return True when *s* parses as a date/time string.

    Month and weekday names are deliberately treated as nominal values,
    not datetimes.
    """
    # consider month and weekday as nominal
    if is_month_or_weekday(s):
        return False
    try:
        parser.parse(s)
        return True
    except (ValueError, OverflowError, TypeError):
        # dateutil's parser can also raise OverflowError for out-of-range
        # values and TypeError for non-string input; the original caught
        # only ValueError and let those propagate.
        return False
82 |
83 |
def check_time_format(time_str, time_format):
    """Return True when *time_str* matches the strptime *time_format* exactly."""
    try:
        datetime.strptime(time_str, time_format)
    except ValueError:
        return False
    return True
90 |
91 |
def parse_time_to_timestamp(time_str):
    """Parse a time string into a POSIX timestamp; None when unparseable.

    Bare years/months/days are padded with a 00:00:10 time-of-day because
    exact midnight (0:00) is prone to bias.
    """
    for known_format, suffix in (
        ("%Y", "-01-01 00:00:10"),
        ("%Y-%m", "-01 00:00:10"),
        ("%Y-%m-%d", " 00:00:10"),
    ):
        if check_time_format(time_str, known_format):
            time_str = time_str + suffix
            break

    try:
        return parser.parse(time_str).timestamp()
    except Exception:
        return None
107 |
108 |
def parse_timestamp_to_time(timestamp):
    """Format a POSIX timestamp as a YYYY-MM-DD string (local time).

    Returns None when the input cannot be converted.
    """
    # todo: extract date format
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d")
    except Exception:
        return None
118 |
119 |
120 | # handle case like 2008.436089
121 | def parse_number_to_time(number):
122 | if number > 0 and number < 2999:
123 | timestamp = parse_time_to_timestamp(str(int(number)))
124 | timestamp += (number - int(number)) * 365 * 24 * 60 * 60
125 | return timestamp
126 | return number
127 |
128 |
def compare_time_strings(time_str1: str, time_str2: str):
    """Return True when the two strings plausibly denote the same time.

    Tries, in order: exact timestamp equality; case-insensitive name
    equality after normalizing month/weekday abbreviations; a name vs. its
    1-based index; and a month/weekday name vs. a concrete date that falls
    on it. Any parsing failure yields False.

    Note: the original body contained an exact duplicate of the
    name-vs-index comparison block; the duplicate has been removed.
    """
    try:
        if parser.parse(time_str1).timestamp() == parser.parse(time_str2).timestamp():
            return True
    except Exception:
        pass

    try:
        str1 = TIME_MAP.get(time_str1.lower(), time_str1)
        str2 = TIME_MAP.get(time_str2.lower(), time_str2)

        if str1.lower() == str2.lower():
            return True

        # month/weekday name vs. its 1-based index ("march" vs "3")
        if (
            is_month_or_weekday(str1)
            and str(convert_month_or_weekday_to_int(str1)) == str2
        ):
            return True
        if (
            is_month_or_weekday(str2)
            and str(convert_month_or_weekday_to_int(str2)) == str1
        ):
            return True

        # month/weekday name vs. a concrete date falling on it
        timestamp1 = parse_time_to_timestamp(time_str1)
        if is_month_or_weekday(str2) and timestamp1 is not None:
            parsed1 = datetime.fromtimestamp(timestamp1)
            # weekday
            if str2.lower() == parsed1.strftime("%A").lower():
                return True
            # month
            if str2.lower() == parsed1.strftime("%B").lower():
                return True
        timestamp2 = parse_time_to_timestamp(time_str2)
        if is_month_or_weekday(str1) and timestamp2 is not None:
            parsed2 = datetime.fromtimestamp(timestamp2)
            # weekday
            if str1.lower() == parsed2.strftime("%A").lower():
                return True
            # month
            if str1.lower() == parsed2.strftime("%B").lower():
                return True

        return False
    except Exception:
        return False
203 |
--------------------------------------------------------------------------------
/viseval/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import json
5 | from pathlib import Path
6 |
7 |
class Dataset:
    """Loads a VisEval benchmark folder and iterates its instances.

    The folder is expected to contain visEval[_single|_multiple].json plus
    a databases/ directory with db_tables.json and per-database csv files.
    """

    def __init__(
        self,
        folder: Path,
        table_type: str = "all",
        with_irrelevant_tables: bool = False,
    ):
        self.folder = folder

        # "single"/"multiple" select a filtered benchmark file; any other
        # value (e.g. "all") selects the full one.
        suffix = "_" + table_type if table_type in ("single", "multiple") else ""
        with open(folder / f"visEval{suffix}.json") as f:
            self.dict = json.load(f)

        with open(folder / "databases/db_tables.json") as f:
            self.db_tables = json.load(f)

        def benchmark():
            # Yield every benchmark instance, annotated with its id and
            # the csv paths of the tables it references.
            for key in list(self.dict.keys()):
                self.dict[key]["id"] = key
                self.dict[key]["tables"] = self.__get_tables(
                    key, with_irrelevant_tables
                )
                yield self.dict[key]

        # NOTE: this is a generator object — it can be iterated only once.
        self.benchmark = benchmark()

    def __get_tables(self, id: str, with_irrelevant_tables: bool = False):
        """Resolve the csv paths of the tables mentioned in the instance's VQL."""
        spec = self.dict[id]
        db_id = spec["db_id"]
        # A table is relevant when its name appears as a token of the VQL.
        vql_tokens = spec["vis_query"]["VQL"].lower().split()
        table_names = [
            name for name in self.db_tables[db_id] if name.lower() in vql_tokens
        ]

        if with_irrelevant_tables:
            table_names.extend(spec["irrelevant_tables"])

        return [
            f"{self.folder}/databases/{db_id}/{table_name}.csv"
            for table_name in table_names
        ]
59 |
--------------------------------------------------------------------------------
/viseval/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import base64
5 | import json
6 | import logging
7 | import os
8 | from typing import Union
9 |
10 | import cairosvg
11 | import pandas as pd
12 | from attr import dataclass
13 |
from .check import (
    chart_check,
    data_check,
    deconstruct,
    layout_check,
    order_check,
    readability_check,
    scale_and_ticks_check,
    surface_form_check,
)
23 | from .dataset import Dataset
24 |
25 |
@dataclass
class CheckResult:
    """Outcome of one evaluation aspect for one generated visualization."""

    answer: Union[bool, int]  # pass/fail flag, or a 1-5 readability score
    aspect: str  # name of the check that produced this result
    rationale: str  # human-readable explanation

    def get_json(self):
        """Return a JSON-serializable dict view of this result."""
        return dict(
            answer=self.answer,
            aspect=self.aspect,
            rationale=self.rationale,
        )
38 |
39 |
@dataclass
class EvaluationDetail:
    """Per-instance results: one list of CheckResults per NL query."""

    id: str
    results: list[list[CheckResult]]
44 |
45 |
# Aspect names grouped by evaluation dimension (must match the `aspect`
# strings produced by the Evaluator methods below).
VALID_ASPECTS = ["code execution", "surface-form check"]
LEGAL_ASPECTS = ["deconstruction", "chart type check", "data check", "order check"]
READABILITY_ASPECT = ["layout check", "scale and ticks check", "readability check"]

# Aspects that contribute a per-aspect fail rate; excludes the scored
# "readability check" (a 1-5 value rather than pass/fail).
FAIL_ASPECTS = VALID_ASPECTS + LEGAL_ASPECTS + ["layout check", "scale and ticks check"]
51 |
52 |
class EvaluationResult:
    """Aggregates per-instance CheckResults into rates and scores."""

    dataset: Dataset
    details: list[EvaluationDetail]

    def __init__(self, dataset: Dataset, details: list[EvaluationDetail]):
        self.dataset = dataset
        self.details = details

    def detail_records(self) -> pd.DataFrame:
        """Build one record per evaluated instance and return a DataFrame.

        Fix: the record was previously appended to `records` several times
        (inside the high-level-dimension pass and again afterwards), which
        produced duplicate rows; it is now appended exactly once.
        """
        records = []
        for detail in self.details:
            id = detail.id
            instance_results = detail.results
            count = len(instance_results)
            record = {
                "id": id,
                "chart": self.dataset.dict[id]["chart"],
                "hardness": self.dataset.dict[id]["hardness"],
            }

            # Per-aspect fail rate: a query fails an aspect when any of its
            # results for that aspect is falsy.
            for aspect in FAIL_ASPECTS:
                evaluate_result = [
                    all(
                        item.answer
                        for item in query_results
                        if item.aspect == aspect
                    )
                    for query_results in instance_results
                ]
                fail_count = sum(1 for passed in evaluate_result if not passed)
                record[f"{aspect}_fail_rate"] = fail_count / count

            # High-level rates: invalid (execution/surface-form) and
            # illegal (deconstruction/type/data/order).
            high_level_dimensions = [
                ["invalid_rate", VALID_ASPECTS],
                ["illegal rate", LEGAL_ASPECTS],
            ]
            pass_count = count
            for dimension in high_level_dimensions:
                evaluate_result = [
                    all(
                        item.answer
                        for item in query_results
                        if item.aspect in dimension[1]
                    )
                    for query_results in instance_results
                ]
                false_count = sum(1 for passed in evaluate_result if not passed)
                record[dimension[0]] = false_count / count
                pass_count -= false_count

            # pass rate
            record["pass_rate"] = pass_count / count

            # readability score: sum of 1-5 scores over each query's results
            evaluate_result = [
                sum(
                    item.answer
                    for item in query_results
                    if item.aspect == "readability check"
                )
                for query_results in instance_results
            ]
            if pass_count > 0:
                # averaged over passing queries only
                record["readability_score"] = sum(evaluate_result) / pass_count
            # quality averages over all queries (failures count as 0)
            record["quality_score"] = sum(evaluate_result) / count

            records.append(record)  # exactly once per instance

        return pd.DataFrame(records)

    def score(self):
        """Average each metric (and every per-aspect fail rate) over instances."""
        records = self.detail_records()
        metrics = [
            "invalid_rate",
            "illegal rate",
            "pass_rate",
            "readability_score",
            "quality_score",
        ]
        score = {}
        for metric in metrics:
            score[metric] = records[metric].mean()

        for key in records.keys():
            if key not in metrics and key not in ("id", "chart", "hardness"):
                score[key] = records[key].mean()

        return score
160 |
161 |
def convert_svg_to_base64(svg_string):
    """Rasterize an SVG string to PNG and wrap it as a base64 data-URI."""
    png_bytes = cairosvg.svg2png(bytestring=svg_string)
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
166 |
167 |
class Evaluator:
    """Drives the evaluation pipeline: generate code with an agent, then
    run validity, legality and readability checks on each result."""

    def __init__(self, webdriver_path=None, vision_model=None):
        # webdriver_path enables the browser-based layout check;
        # vision_model enables the LLM-vision checks. Either may be None,
        # in which case the corresponding checks are skipped.
        self.webdriver_path = webdriver_path
        self.vision_model = vision_model

    def evaluate(self, agent, dataset, config):
        """Evaluate *agent* over every benchmark instance of *dataset*.

        config keys used: "library" (plotting library passed to the
        checks) and optionally "logs" (folder for cached results; it is
        joined with `/`, so presumably a pathlib.Path — TODO confirm with
        callers). When logging is enabled, instances with an existing
        result.json are loaded from cache instead of re-evaluated.

        Returns an EvaluationResult.
        """
        use_logs = False
        evaluation_details = []
        if "logs" in config:
            log_folder = config["logs"]
            isExists = os.path.exists(log_folder)
            try:
                if not isExists:
                    os.makedirs(log_folder)
                logging.basicConfig(
                    level=logging.INFO,
                    filename=log_folder / "evaluation.log",
                    filemode="a",
                    format="%(levelname)s: %(message)s",
                )
                use_logs = True
            except Exception as e:
                # Logging is best-effort: fall back to no-log evaluation.
                print(e)

        for instance in dataset.benchmark:
            codes = []
            instance_results = []
            nl_queries = instance["nl_queries"]
            tables = instance["tables"]

            if use_logs:
                instanceFolder = log_folder / instance["id"]
                path = instanceFolder / "result.json"
                if os.path.exists(path):
                    # Reuse previously cached evaluations for this instance.
                    with open(path, "r") as f:
                        data = json.load(f)
                        if "codes" in data and "evaluations" in data:
                            instance_results = []
                            for query_result in data["evaluations"]:
                                results = [
                                    CheckResult(
                                        answer=result["answer"],
                                        aspect=result["aspect"],
                                        rationale=result["rationale"],
                                    )
                                    for result in query_result
                                ]
                                instance_results.append(results)
                            evaluation_details.append(
                                EvaluationDetail(instance["id"], instance_results)
                            )
                            continue
                else:
                    logging.info(f"Instance ({instance['id']}) evaluation began.")
                    isExists = os.path.exists(instanceFolder)
                    if not isExists:
                        os.makedirs(instanceFolder)

            for index in range(len(nl_queries)):
                nl_query = nl_queries[index]
                if index < len(codes):
                    # Reuse code generated for an earlier query of this
                    # instance.
                    code = codes[index]
                    context = {}
                    context["tables"] = tables
                else:
                    code, context = agent.generate(nl_query, tables, config)
                    codes.append(code)
                if code is None:
                    results = [
                        CheckResult(
                            answer=False,
                            aspect="generation",
                            rationale="Code generation failed.",
                        )
                    ]
                else:
                    context["library"] = config["library"]
                    if use_logs:
                        results = self.validity_check(
                            code, context, agent, instanceFolder / f"{index}.svg"
                        )
                    else:
                        results = self.validity_check(code, context, agent)

                # Checks are gated: legality runs only on valid output,
                # readability only on legal output.
                pass_validity = all([result.answer for result in results])
                if pass_validity:
                    ground_truth = {
                        "chart": instance["chart"],
                        "vis_obj": instance["vis_obj"],
                        "meta_info": instance["query_meta"][index],
                    }
                    results += self.legality_check(context, ground_truth)

                    pass_legality = all([result.answer for result in results])
                    if pass_legality:
                        results += self.readability_evaluate(context, nl_query)

                instance_results.append(results)

            evaluation_details.append(
                EvaluationDetail(instance["id"], instance_results)
            )
            if use_logs:
                logging.info(f"Instance ({instance['id']}) evaluation finished.")
                # convert CheckResult to json and persist so a later run
                # can reuse this instance's evaluations.
                instance_results = [
                    [result.get_json() for result in results]
                    for results in instance_results
                ]
                with open(log_folder / (instance["id"] + "/result.json"), "w") as f:
                    f.write(
                        json.dumps({"codes": codes, "evaluations": instance_results})
                    )
        return EvaluationResult(dataset, evaluation_details)

    def execute(self, code, context, agent, log_name=None) -> CheckResult:
        """Run the generated code via the agent; store the SVG on success."""
        result = agent.execute(code, context, log_name)
        if result.status is False:
            return CheckResult(
                answer=False, aspect="code execution", rationale=result.error_msg
            )

        context["svg_string"] = result.svg_string
        return CheckResult(
            answer=True,
            aspect="code execution",
            rationale="Code executed successfully.",
        )

    def surface_form_check(self, context) -> CheckResult:
        """Check that the SVG actually contains plotted content.

        Delegates to the module-level surface_form_check (imported from
        .check; the import was previously missing, causing a NameError).
        """
        svg_string = context["svg_string"]
        answer, rationale = surface_form_check(svg_string)
        return CheckResult(
            answer=answer,
            aspect="surface-form check",
            rationale=rationale,
        )

    def validity_check(self, code, context, agent, log_name=None) -> list[CheckResult]:
        """Execution check, then (only if it passed) surface-form check."""
        results = []
        result = self.execute(code, context, agent, log_name)
        results.append(result)
        if result.answer:
            results.append(self.surface_form_check(context))

        return results

    def deconstruction(self, context) -> CheckResult:
        """Deconstruct the SVG into chart info, merged into *context*."""
        svg_string = context["svg_string"]
        library = context["library"]
        if library == "seaborn":
            # seaborn renders through matplotlib, so parse it as matplotlib.
            library = "matplotlib"
        try:
            chart_info, msg = deconstruct(svg_string, library)
            if chart_info is None:
                return CheckResult(
                    answer=False,
                    aspect="deconstruction",
                    rationale=msg,
                )
            context.update(chart_info)
            return CheckResult(
                answer=True,
                aspect="deconstruction",
                rationale="Deconstructed the chart successfully.",
            )
        except Exception:
            # was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed
            return CheckResult(
                answer=False,
                aspect="deconstruction",
                rationale="Cannot parse the visualization.",
            )

    def chart_type_check(self, context, ground_truth) -> CheckResult:
        """Compare the deconstructed chart type against the ground truth."""
        answer, rationale = chart_check(
            context,
            ground_truth["chart"],
            ground_truth["meta_info"].get("stacked_bar"),
        )
        return CheckResult(
            answer=answer,
            aspect="chart type check",
            rationale=rationale,
        )

    def data_check(self, context, ground_truth) -> CheckResult:
        """Compare the plotted data against the ground-truth vis object."""
        answer, rationale = data_check(
            context,
            ground_truth["vis_obj"],
            ground_truth["meta_info"]["channel_specified"],
        )
        return CheckResult(
            answer=answer,
            aspect="data check",
            rationale=rationale,
        )

    def order_check(self, context, ground_truth) -> CheckResult:
        """Check the sort order of the plotted data when one is required."""
        answer, rationale = order_check(
            context,
            ground_truth["vis_obj"],
            ground_truth["meta_info"].get("sort_by"),
        )
        return CheckResult(
            answer=answer,
            aspect="order check",
            rationale=rationale,
        )

    def legality_check(self, context, ground_truth) -> list[CheckResult]:
        """Deconstruction, then chart-type/data checks, then order check."""
        results = []
        result = self.deconstruction(context)
        results.append(result)
        if result.answer:
            results.append(self.chart_type_check(context, ground_truth))
            data_check_result = self.data_check(context, ground_truth)
            results.append(data_check_result)
            if data_check_result.answer and ground_truth["vis_obj"]["sort"] is not None:
                # was called twice, with the first result discarded
                results.append(self.order_check(context, ground_truth))

        return results

    def layout_check(self, context) -> CheckResult:
        """Browser-based check for overflow/overlap problems."""
        assert "svg_string" in context
        assert self.webdriver_path is not None

        answer, rationale = layout_check(context, self.webdriver_path)
        return CheckResult(
            answer=answer,
            aspect="layout check",
            rationale=rationale,
        )

    def scale_and_ticks_check(self, context, query) -> CheckResult:
        """Vision-model check for scale/tick appropriateness."""
        assert "base64" in context and "encoding" in context and "chart" in context
        assert self.vision_model is not None

        answer, rationale = scale_and_ticks_check(context, query, self.vision_model)
        return CheckResult(
            answer=answer,
            aspect="scale and ticks check",
            rationale=rationale,
        )

    def readability_evaluate(self, context, query: str) -> list[CheckResult]:
        """Layout + scale/ticks checks, then an overall readability score.

        The earlier checks' rationales are forwarded to the readability
        check as "reviews" so the model can take them into account.
        """
        results = []
        if self.webdriver_path:
            layout_result = self.layout_check(context)
            if layout_result.answer is not None:
                results.append(layout_result)

        if self.vision_model:
            context["base64"] = convert_svg_to_base64(context["svg_string"])
            scale_and_ticks_result = self.scale_and_ticks_check(context, query)
            if scale_and_ticks_result.answer is not None:
                results.append(scale_and_ticks_result)

        aspect_format = {
            "layout check": "Overflow/Overlap",
            "scale and ticks check": "Scale/Ticks",
        }
        reviews = [
            {
                "aspect": aspect_format[result.aspect],
                "content": result.rationale,
            }
            for result in results
        ]
        context["reviews"] = reviews

        # NOTE(review): this still runs when vision_model is None and then
        # fails inside readability_check (answer=None, so nothing is
        # appended) — confirm whether it should be skipped entirely.
        answer, rationale = readability_check(context, query, self.vision_model)
        if answer is not None:
            results.append(
                CheckResult(
                    answer=answer,
                    aspect="readability check",
                    rationale=rationale,
                )
            )

        return results
460 |
--------------------------------------------------------------------------------
/viseval_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/VisEval/619155231c476aaa05f1b3a5b6d79082a6bcf782/viseval_dataset.zip
--------------------------------------------------------------------------------