├── .flake8 ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── .yarnrc.yml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── assets ├── .gitignore ├── coml.db ├── demo.gif ├── demo_coml.gif ├── demo_comlfix.gif ├── demo_comlinspire.gif ├── example_descriptions.json └── example_history.csv ├── coml ├── __init__.py ├── configagent │ ├── .env.template │ ├── __init__.py │ ├── __main__.py │ ├── cli.py │ ├── constants.py │ ├── experience.py │ ├── knowledge.py │ ├── orm.py │ ├── space.py │ ├── suggest.py │ ├── surrogate_utils.py │ └── utils.py ├── core.py ├── ipython_utils.py ├── js │ └── nbclassic_init.js ├── linter.py ├── magics.py ├── prompt_utils.py ├── prompts │ ├── dataset │ │ ├── coffee_shop │ │ │ └── member.csv │ │ ├── game_injury │ │ │ └── game.csv │ │ ├── pilot_record │ │ │ └── pilot.csv │ │ └── scientist_1 │ │ │ ├── AssignedTo.csv │ │ │ └── Projects.csv │ ├── fix_fewshots.json │ ├── fix_fewshots_v2.json │ ├── generate_fewshots.json │ ├── generate_fewshots_kaggle.json │ ├── generate_fewshots_leetcode.json │ ├── generate_fewshots_matplotlib.json │ ├── generate_fewshots_seaborn.json │ └── generate_fewshots_v2.json └── vis_utils │ ├── __init__.py │ ├── deconstruct.py │ └── verifier.py ├── install.json ├── package.json ├── pyproject.toml ├── src └── index.ts ├── test └── configagent │ ├── __init__.py │ ├── helper.py │ ├── test_experience.py │ ├── test_knowledge.py │ ├── test_space.py │ ├── test_suggest.py │ ├── test_surrogate_utils.py │ └── test_utils.py └── tsconfig.json /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = "E303, W293, W291, W292, E305, E231, E302" 4 | exclude = 5 | .tox, 6 | __pycache__, 7 | *.pyc, 8 | .env 9 | venv*/*, 10 | .venv/*, 11 | reports/*, 12 | dist/*, 13 | node_modules/*, 14 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Python CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | env: 13 | min-python-version: "3.10" 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python ${{ env.min-python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ env.min-python-version }} 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -e .[dev] 27 | 28 | - name: Lint with flake8 29 | run: flake8 30 | 31 | - name: Check black formatting 32 | run: black . --check 33 | if: success() || failure() 34 | 35 | - name: Check isort formatting 36 | run: isort . 
--check 37 | if: success() || failure() 38 | 39 | build: 40 | 41 | runs-on: ubuntu-latest 42 | strategy: 43 | matrix: 44 | python-version: ["3.8", "3.9", "3.10", "3.11"] 45 | 46 | steps: 47 | - uses: actions/checkout@v4 48 | 49 | - name: Set up Python ${{ matrix.python-version }} 50 | uses: actions/setup-python@v4 51 | with: 52 | python-version: ${{ matrix.python-version }} 53 | 54 | - name: List files 55 | run: ls -al coml 56 | 57 | - name: Install dependencies 58 | run: | 59 | python -m pip install --upgrade pip 60 | pip install -e .[dev] 61 | 62 | - name: Test with pytest 63 | run: | 64 | pytest 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Version file is handled by hatchling 36 | coml/_version.py 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # Jupyter Lab 85 | *.bundle.* 86 | node_modules/ 87 | .yarn/ 88 | yarn.lock 89 | *.log 90 | .eslintcache 91 | .stylelintcache 92 | *.tsbuildinfo 93 | labextension 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | # For a library or package, you might want to ignore these files since the code is 101 | # intended to run in multiple environments; otherwise, check them in: 102 | # .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # poetry 112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 113 | # This is especially recommended for binary packages to ensure reproducibility, and is more 114 | # commonly ignored for libraries. 115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 116 | #poetry.lock 117 | 118 | # pdm 119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
120 | #pdm.lock 121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 122 | # in version control. 123 | # https://pdm.fming.dev/#use-with-ide 124 | .pdm.toml 125 | 126 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 127 | __pypackages__/ 128 | 129 | # Celery stuff 130 | celerybeat-schedule 131 | celerybeat.pid 132 | 133 | # SageMath parsed files 134 | *.sage.py 135 | 136 | # Environments 137 | .env 138 | .venv 139 | env/ 140 | venv/ 141 | ENV/ 142 | env.bak/ 143 | venv.bak/ 144 | 145 | # Spyder project settings 146 | .spyderproject 147 | .spyproject 148 | 149 | # Rope project settings 150 | .ropeproject 151 | 152 | # mkdocs documentation 153 | /site 154 | 155 | # mypy 156 | .mypy_cache/ 157 | .dmypy.json 158 | dmypy.json 159 | 160 | # Pyre type checker 161 | .pyre/ 162 | 163 | # pytype static type analyzer 164 | .pytype/ 165 | 166 | # Cython debug symbols 167 | cython_debug/ 168 | 169 | # Code styling tools 170 | .isort.cfg 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 177 | #.idea/ 178 | .vscode/ 179 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | # - id: check-added-large-files 6 | # args: ['--maxkb=500'] 7 | - id: check-byte-order-marker 8 | - id: check-case-conflict 9 | - id: check-merge-conflict 10 | - id: check-symlinks 11 | - id: debug-statements 12 | 13 | - repo: https://github.com/pycqa/isort 14 | rev: 5.12.0 15 | hooks: 16 | - id: isort 17 | language_version: python3.8 18 | types: [python] 19 | 20 | - repo: https://github.com/psf/black 21 | rev: 23.3.0 22 | hooks: 23 | - id: black 24 | language_version: python3.8 25 | types: [python] -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | **/node_modules 3 | **/lib 4 | **/package.json 5 | !/package.json 6 | coml 7 | -------------------------------------------------------------------------------- /.yarnrc.yml: -------------------------------------------------------------------------------- 1 | enableImmutableInstalls: false 2 | 3 | nodeLinker: node-modules 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoML 2 | 3 | CoML (formerly MLCopilot) is an interactive coding assistant for data scientists and machine learning developers, powered by large language models. 4 | 5 | Highlight features: 6 | 7 | * Out-of-the-box interactive natural language programming interface for data mining and machine learning tasks. 8 | * Integration with Jupyter Lab and Jupyter Notebook. 9 | * Built-in large knowledge base of machine learning, enhancing its ability to solve complex tasks. 10 | 11 | ## Installation 12 | 13 | ```bash 14 | pip install mlcopilot 15 | ``` 16 | 17 | (We can't have the name `coml` on PyPI, so we use `mlcopilot` instead.) 18 | 19 | ## CoML in Jupyter Lab 20 | 21 | We recommend trying CoML in a Jupyter Lab environment. Before using CoML, please make sure that: 22 | 23 | 1. You have exported `OPENAI_API_KEY=sk-xxxx` in your environment. Alternatively, you can also use a `.env` file. 24 | 2. You have run `%load_ext coml` in your notebook to activate the CoML extension. 25 | 26 | We then provide several magic commands to assist your journey of interactive coding in Jupyter Lab. 27 | 28 | 1. `%coml <task>` to prompt CoML to write a cell for your task. 29 | 30 | ![](assets/demo_coml.gif) 31 | 32 | 2. `%comlfix` to fix the cell just above the current cell. You can also use `%comlfix <hint>` to provide details about what is wrong. 33 | 34 | ![](assets/demo_comlfix.gif) 35 | 36 | 3. `%comlinspire` to inspire you with a cell describing what to do next. 37 | 38 | ![](assets/demo_comlinspire.gif) 39 | 40 | **Limitations:** 41 | 42 | * Currently, CoML only supports Jupyter Lab and the classic Jupyter Notebook (nbclassic, and only on Linux platforms).
We are still working on support for the newer Jupyter Notebook, Jupyter in VS Code, and Google Colab. 43 | * CoML uses the gpt-3.5-turbo-16k model in its implementation. There is no way to change the model for now. The cost of using this model is around $0.04 per request. Please be aware of this cost. 44 | 45 | ## CoML Config Agent 46 | 47 | The CoML config agent is the implementation of [MLCopilot](https://arxiv.org/abs/2304.14979), which can suggest an ML configuration within a specific configuration space, for a specific task. Currently, it is an independent component residing in `coml.configagent`. In the future, we will integrate it into the CoML system. 48 | 49 | ![](assets/demo.gif) 50 | 51 | (TODO: The demo needs an update.) 52 | 53 | #### Extra preparation steps 54 | 55 | Please follow these steps to use the CoML config agent: 56 | 57 | 1. Clone this repo: `git clone REPO_URL; cd coml` 58 | 2. Put `assets/coml.db` in your home directory: `cp assets/coml.db ~/.coml/coml.db` 59 | 3. Copy `coml/.env.template` to `~/.coml/.env` and put your API keys in the file. 60 | 61 | #### Command line utility 62 | 63 | Currently, the config agent can only be invoked independently. You can use the following command line: 64 | 65 | ``` 66 | coml-configagent --space <space> --task <task> 67 | ``` 68 | 69 | If you feel uncertain about what to put into `<space>` or `<task>`, see the demo above, or try the interactive usage below: 70 | 71 | ``` 72 | coml-configagent --interactive 73 | ``` 74 | 75 | #### API Usage 76 | 77 | ```python 78 | from coml.configagent.space import import_space 79 | from coml.configagent.suggest import suggest 80 | space = import_space("YOUR_SPACE_ID") 81 | task_desc = "YOUR_TASK_DESCRIPTION_FOR_NEW_TASK" 82 | suggest_configs, knowledge = suggest(space, task_desc) 83 | ``` 84 | 85 | ## Development 86 | 87 | Development documentation stays here for now. It will be moved to a separate document later. 88 | 89 | ### Project structure 90 | 91 | Important files and folders: 92 | 93 | ``` 94 | CoML 95 | ├── assets # data, examples, demos 96 | ├── coml # Python package 97 | ├── examples # example scripts 98 | ├── install.json # Jupyter lab extension installation file 99 | ├── package.json # Jupyter lab extension package file 100 | ├── pyproject.toml # Python package configuration 101 | ├── src # Jupyter lab extension source code 102 | ├── test # Python package tests 103 | └── tsconfig.json # Jupyter lab extension TypeScript configuration 104 | ``` 105 | 106 | ### Installation and uninstallation 107 | 108 | You can use the following command for development installation: 109 | 110 | ``` 111 | pip install -e .[dev] 112 | ``` 113 | 114 | If you want to develop the Jupyter Lab extension, you will also need to install Node.js and npm, and run the following commands: 115 | 116 | ``` 117 | # Link your development version of the extension with JupyterLab 118 | jupyter labextension develop . --overwrite 119 | # Rebuild extension TypeScript source after making changes 120 | jlpm run build 121 | ``` 122 | 123 | To uninstall, you can run the following commands: 124 | 125 | ```bash 126 | # Server extension must be manually disabled in develop mode 127 | jupyter server extension disable coml 128 | 129 | # Uninstall the Python package 130 | pip uninstall mlcopilot 131 | ``` 132 | 133 | In development mode, you will also need to remove the symlink created by the `jupyter labextension develop` command. 134 | To find its location, you can run `jupyter labextension list` to figure out where the `labextensions` folder is located. 135 | Then you can remove the symlink named `coml` within that folder.
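For reference, the clean-up might look like the sketch below. This is only an illustration: the actual `labextensions` folder depends on your environment, so substitute the location reported by `jupyter labextension list`; the `/path/to/...` below is a placeholder, not a real path.

```bash
# Show installed lab extensions; the output header tells you which
# labextensions folder JupyterLab is currently using
jupyter labextension list

# Remove the "coml" symlink from that folder (placeholder path below --
# replace it with the folder reported by the command above)
rm /path/to/jupyter/labextensions/coml
```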
136 | 137 | ### Packaging 138 | 139 | ```bash 140 | hatch build 141 | ``` 142 | 143 | ## Citation 144 | 145 | If you find this work useful, please cite the paper as below: 146 | 147 | @article{zhang2023mlcopilot, 148 | title={MLCopilot: Unleashing the Power of Large Language Models in Solving Machine Learning Tasks}, 149 | author={Zhang, Lei and Zhang, Yuge and Ren, Kan and Li, Dongsheng and Yang, Yuqing}, 150 | journal={arXiv preprint arXiv:2304.14979}, 151 | year={2023} 152 | } 153 | 154 | ## License 155 | 156 | The entire codebase is under [MIT license](LICENSE). 157 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /assets/.gitignore: -------------------------------------------------------------------------------- 1 | /private 2 | -------------------------------------------------------------------------------- /assets/coml.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/assets/coml.db -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/assets/demo.gif -------------------------------------------------------------------------------- /assets/demo_coml.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/assets/demo_coml.gif -------------------------------------------------------------------------------- /assets/demo_comlfix.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/assets/demo_comlfix.gif -------------------------------------------------------------------------------- /assets/demo_comlinspire.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/assets/demo_comlinspire.gif -------------------------------------------------------------------------------- /assets/example_descriptions.json: -------------------------------------------------------------------------------- 1 | { 2 | "3950": "The 
dataset name is \"musk\". It contains 2 classes, 6598 instances, 168 features, 166 numeric features, 2 categorical features. The majority class size is 5581 and the minority class size is 1017.", 3 | "3492": "The dataset name is \"monks-problems-1\". It contains 2 classes, 556 instances, 7 features, 0 numeric features, 7 categorical features. The majority class size is 278 and the minority class size is 278.", 4 | "145834": "The dataset name is \"banknote-authentication\". It contains 2 classes, 1372 instances, 5 features, 4 numeric features, 1 categorical features. The majority class size is 762 and the minority class size is 610.", 5 | "9967": "The dataset name is \"steel-plates-fault\". It contains 2 classes, 1941 instances, 34 features, 33 numeric features, 1 categorical features. The majority class size is 1268 and the minority class size is 673.", 6 | "49": "The dataset name is \"tic-tac-toe\". It contains 2 classes, 958 instances, 10 features, 0 numeric features, 10 categorical features. The majority class size is 626 and the minority class size is 332.", 7 | "145953": "The dataset name is \"kr-vs-kp\". It contains 2 classes, 3196 instances, 37 features, 0 numeric features, 37 categorical features. The majority class size is 1669 and the minority class size is 1527.", 8 | "3889": "The dataset name is \"sylva_agnostic\". It contains 2 classes, 14395 instances, 217 features, 216 numeric features, 1 categorical features. The majority class size is 13509 and the minority class size is 886.", 9 | "3485": "The dataset name is \"scene\". It contains 2 classes, 2407 instances, 300 features, 294 numeric features, 6 categorical features. The majority class size is 1976 and the minority class size is 431.", 10 | "146066": "The dataset name is \"monks-problems-3\". It contains 2 classes, 554 instances, 7 features, 0 numeric features, 7 categorical features. The majority class size is 288 and the minority class size is 266.", 11 | "9889": "The dataset name is \"wilt\". It contains 2 classes, 4839 instances, 6 features, 5 numeric features, 1 categorical features. The majority class size is 4578 and the minority class size is 261.", 12 | "145878": "The dataset name is \"wdbc\". It contains 2 classes, 569 instances, 31 features, 30 numeric features, 1 categorical features. The majority class size is 357 and the minority class size is 212.", 13 | "145854": "The dataset name is \"nomao\". It contains 2 classes, 34465 instances, 119 features, 89 numeric features, 30 categorical features. The majority class size is 24621 and the minority class size is 9844.", 14 | "145855": "The dataset name is \"ozone-level-8hr\". It contains 2 classes, 2534 instances, 73 features, 72 numeric features, 1 categorical features. The majority class size is 2374 and the minority class size is 160.", 15 | "3918": "The dataset name is \"pc1\". It contains 2 classes, 1109 instances, 22 features, 21 numeric features, 1 categorical features. The majority class size is 1032 and the minority class size is 77.", 16 | "43": "The dataset name is \"spambase\". It contains 2 classes, 4601 instances, 58 features, 57 numeric features, 1 categorical features. The majority class size is 2788 and the minority class size is 1813.", 17 | "3891": "The dataset name is \"gina_agnostic\". It contains 2 classes, 3468 instances, 971 features, 970 numeric features, 1 categorical features. The majority class size is 1763 and the minority class size is 1705.", 18 | "9980": "The dataset name is \"climate-model-simulation-crashes\". 
It contains 2 classes, 540 instances, 21 features, 20 numeric features, 1 categorical features. The majority class size is 494 and the minority class size is 46.", 19 | "3902": "The dataset name is \"pc4\". It contains 2 classes, 1458 instances, 38 features, 37 numeric features, 1 categorical features. The majority class size is 1280 and the minority class size is 178.", 20 | "3899": "The dataset name is \"mozilla4\". It contains 2 classes, 15545 instances, 6 features, 5 numeric features, 1 categorical features. The majority class size is 10437 and the minority class size is 5108.", 21 | "145833": "The dataset name is \"bank-marketing\". It contains 2 classes, 45211 instances, 17 features, 7 numeric features, 10 categorical features. The majority class size is 39922 and the minority class size is 5289.", 22 | "3903": "The dataset name is \"pc3\". It contains 2 classes, 1563 instances, 38 features, 37 numeric features, 1 categorical features. The majority class size is 1403 and the minority class size is 160.", 23 | "9952": "The dataset name is \"phoneme\". It contains 2 classes, 5404 instances, 6 features, 5 numeric features, 1 categorical features. The majority class size is 3818 and the minority class size is 1586.", 24 | "14951": "The dataset name is \"eeg-eye-state\". It contains 2 classes, 14980 instances, 15 features, 14 numeric features, 1 categorical features. The majority class size is 8257 and the minority class size is 6723.", 25 | "9957": "The dataset name is \"qsar-biodeg\". It contains 2 classes, 1055 instances, 42 features, 41 numeric features, 1 categorical features. The majority class size is 699 and the minority class size is 356.", 26 | "3954": "The dataset name is \"MagicTelescope\". It contains 2 classes, 19020 instances, 12 features, 11 numeric features, 1 categorical features. The majority class size is 12332 and the minority class size is 6688.", 27 | "3917": "The dataset name is \"kc1\". It contains 2 classes, 2109 instances, 22 features, 21 numeric features, 1 categorical features. The majority class size is 1783 and the minority class size is 326.", 28 | "219": "The dataset name is \"electricity\". It contains 2 classes, 45312 instances, 9 features, 7 numeric features, 2 categorical features. The majority class size is 26075 and the minority class size is 19237.", 29 | "3896": "The dataset name is \"ada_agnostic\". It contains 2 classes, 4562 instances, 49 features, 48 numeric features, 1 categorical features. The majority class size is 3430 and the minority class size is 1132.", 30 | "7295": "The dataset name is \"Click_prediction_small\". It contains 2 classes, 39948 instances, 10 features, 9 numeric features, 1 categorical features. The majority class size is 33220 and the minority class size is 6728.", 31 | "145836": "The dataset name is \"blood-transfusion-service-center\". It contains 2 classes, 748 instances, 5 features, 4 numeric features, 1 categorical features. The majority class size is 570 and the minority class size is 178.", 32 | "145976": "The dataset name is \"diabetes\". It contains 2 classes, 768 instances, 9 features, 8 numeric features, 1 categorical features. The majority class size is 500 and the minority class size is 268.", 33 | "145972": "The dataset name is \"credit-g\". It contains 2 classes, 1000 instances, 21 features, 7 numeric features, 14 categorical features. The majority class size is 700 and the minority class size is 300.", 34 | "9970": "The dataset name is \"hill-valley\". 
It contains 2 classes, 1212 instances, 101 features, 100 numeric features, 1 categorical features. The majority class size is 606 and the minority class size is 606.", 35 | "9976": "The dataset name is \"madelon\". It contains 2 classes, 2600 instances, 501 features, 500 numeric features, 1 categorical features. The majority class size is 1300 and the minority class size is 1300." 36 | } -------------------------------------------------------------------------------- /coml/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from ._version import __version__ 4 | from .core import CoMLAgent 5 | from .prompt_utils import describe_variable, filter_variables 6 | 7 | 8 | def load_ipython_extension(ipython): 9 | from IPython.core.display import Javascript 10 | from IPython.display import display 11 | 12 | from .magics import CoMLMagics 13 | 14 | display( 15 | Javascript((Path(__file__).parent / "js" / "nbclassic_init.js").read_text()) 16 | ) 17 | 18 | ipython.register_magics(CoMLMagics) 19 | 20 | print(f"CoML {__version__} loaded.") 21 | 22 | 23 | def _jupyter_labextension_paths(): 24 | return [{"src": "labextension", "dest": "coml"}] 25 | -------------------------------------------------------------------------------- /coml/configagent/.env.template: -------------------------------------------------------------------------------- 1 | ### OPENAI 2 | ## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key) 3 | OPENAI_API_KEY=your-openai-api-key 4 | 5 | ### DB 6 | ## COML_DB_BACKEND - Database backend (Example: sqlite) 7 | COML_DB_BACKEND=sqlite 8 | ## COML_DB_PATH - Path to database file (Example: ~/.coml/coml.db) - Only for sqlite 9 | COML_DB_PATH=~/.coml/coml.db 10 | ## COML_DB_NAME - Database name (Example: coml) 11 | COML_DB_NAME=coml 12 | ## COML_DB_HOST - Database host (Example: localhost) 13 | COML_DB_HOST=localhost 14 | ## COML_DB_PORT - Database port (Example: 5432) 15 | COML_DB_PORT=5432 16 | ## COML_DB_USER - Database user (Example: postgres) 17 | COML_DB_USER=postgres 18 | ## COML_DB_PASSWORD - Database password (Example: '') 19 | COML_DB_PASSWORD='' -------------------------------------------------------------------------------- /coml/configagent/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | from dotenv import load_dotenv 6 | 7 | dotenv_dir = Path.home() / ".coml" 8 | dotenv_path = (dotenv_dir / ".env").resolve() 9 | 10 | if not os.path.exists(dotenv_dir): 11 | os.makedirs(dotenv_dir, exist_ok=True) 12 | if not os.path.exists(dotenv_path): 13 | # copy the default .env file 14 | shutil.copyfile(Path(__file__).parent / ".env.template", dotenv_path) 15 | 16 | # Load the users .env file into environment variables 17 | load_dotenv(dotenv_path, verbose=True, override=False) 18 | 19 | del load_dotenv 20 | -------------------------------------------------------------------------------- /coml/configagent/__main__.py: -------------------------------------------------------------------------------- 1 | from . 
import cli 2 | 3 | if __name__ == "__main__": 4 | cli.main() 5 | -------------------------------------------------------------------------------- /coml/configagent/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import click 4 | 5 | from .orm import database_proxy 6 | 7 | 8 | @click.group(invoke_without_command=True) 9 | @click.option("--space", help="Space ID.") 10 | @click.option("--task", help="Task description.") 11 | @click.option("--interactive", help="Interactive mode.", is_flag=True) 12 | @click.pass_context 13 | def main( 14 | ctx: click.Context, 15 | space: Optional[str] = None, 16 | task: Optional[str] = None, 17 | interactive: bool = False, 18 | ) -> None: 19 | if ctx.invoked_subcommand is None: 20 | if ctx.params["interactive"]: 21 | from .suggest import suggest_interactive 22 | 23 | suggest_interactive() 24 | database_proxy.close() 25 | else: 26 | if ctx.params["space"] is None or ctx.params["task"] is None: 27 | print("Please specify space ID and a task description.") 28 | return 29 | from .space import import_space 30 | from .suggest import print_suggested_configs, suggest 31 | 32 | results = suggest(import_space(ctx.params["space"]), ctx.params["task"]) 33 | print_suggested_configs(*results) 34 | database_proxy.close() 35 | 36 | 37 | @main.command() 38 | @click.argument("space", nargs=1) 39 | @click.argument("history", nargs=1) 40 | @click.argument("task_desc", nargs=1) 41 | @click.option("--space-desc", help="Space description path (optional).") 42 | @click.option("--no-knowledge", help="Do not generate knowledge.", is_flag=True) 43 | def create( 44 | space: str, 45 | history: str, 46 | task_desc: str, 47 | space_desc: str = None, 48 | no_knowledge: bool = False, 49 | ) -> None: 50 | """ 51 | Create a space from history csv file and task description json file. 52 | 53 | Parameters 54 | ---------- 55 | space: str 56 | The ID of the space to identify the space. 57 | history: str 58 | The path to the history of configurations. A csv file, format see `coml.experience.ingest_experience`. 59 | task_desc: str 60 | The JSON path to the task description. A json file, format see `coml.experience.ingest_experience`. 61 | space_desc: str 62 | The text path to the space description. Optional. 63 | no_knowledge: bool 64 | Whether to generate knowledge from history. 
65 | 66 | Returns 67 | ------- 68 | None 69 | """ 70 | from .space import create_space 71 | 72 | create_space(space, history, task_desc, space_desc, no_knowledge) 73 | database_proxy.close() 74 | 75 | 76 | @main.command() 77 | def list() -> None: 78 | from .space import print_space 79 | 80 | print_space() 81 | database_proxy.close() 82 | 83 | 84 | @main.command() 85 | @click.argument("space", nargs=1) 86 | def delete(space: str) -> None: 87 | from .space import delete_space 88 | 89 | delete_space(space) 90 | database_proxy.close() 91 | -------------------------------------------------------------------------------- /coml/configagent/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | __all__ = [ 5 | "COML_DB_PATH", 6 | "TOP_K", 7 | "EMBED_DIM", 8 | "bin_map", 9 | "inverse_bin_map", 10 | "q_num", 11 | "COML_DB_BACKEND", 12 | "COML_DB_NAME", 13 | "COML_DB_HOST", 14 | "COML_DB_PORT", 15 | "COML_DB_USER", 16 | "COML_DB_PASSWORD", 17 | "PROMPT_FORMATS", 18 | "DEFAULT_PROMPT_PREFIX", 19 | "DEFAULT_PROMPT_SUFFIX", 20 | "TOKEN_LIMIT", 21 | "TOKEN_COMPLETION_LIMIT", 22 | "RELAX_TOKEN", 23 | ] 24 | 25 | TOP_K = 3 26 | EMBED_DIM = 1536 27 | TOKEN_LIMIT = 4096 28 | TOKEN_COMPLETION_LIMIT = 800 29 | RELAX_TOKEN = 500 # RELAX_TOKEN is the number of tokens to void token limit 30 | 31 | COML_DB_BACKEND = os.environ.get("COML_DB_BACKEND", "sqlite") 32 | 33 | COML_DB_PATH = Path( 34 | os.environ.get("COML_DB_PATH", Path.home() / ".coml" / "coml.db") 35 | ).expanduser() 36 | 37 | COML_DB_NAME = os.environ.get("COML_DB_NAME", "coml") 38 | COML_DB_HOST = os.environ.get("COML_DB_HOST", "localhost") 39 | COML_DB_PORT = os.environ.get("COML_DB_PORT", 5432) 40 | COML_DB_USER = os.environ.get("COML_DB_USER", "postgres") 41 | COML_DB_PASSWORD = os.environ.get("COML_DB_PASSWORD", "") 42 | 43 | bin_map = { 44 | 0.1: "very small", 45 | 0.3: "small", 46 | 0.5: "medium", 47 | 0.7: "large", 48 | 0.9: "very large", 49 | } 50 | 51 | inverse_bin_map = {v: k for k, v in bin_map.items()} 52 | inverse_bin_map.update( 53 | { 54 | "very low": 0.1, 55 | "low": 0.3, 56 | "high": 0.7, 57 | "very high": 0.9, 58 | "extremely large": 0.9, 59 | "extremely small": 0.1, 60 | "extra large": 0.9, 61 | "extra small": 0.1, 62 | "very medium": 0.5, 63 | "very small": 0.1, 64 | "small": 0.3, 65 | "large": 0.7, 66 | "very large": 0.9, 67 | } 68 | ) 69 | 70 | q_num = sorted(list(bin_map.keys())) 71 | 72 | PROMPT_FORMATS = { 73 | "TOP_K", 74 | "knowledge", 75 | "space_desc", 76 | "new_task_desc", 77 | } 78 | 79 | DEFAULT_PROMPT_PREFIX = """{space_desc}\nRecommend best configurations to train a model for a new task. Format strictly follows this template: ```Configuration 1: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}. 80 | Configuration 2: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}. 81 | Configuration 3: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}. 
82 | ```\nHere are some tasks along with best hyper-parameter configurations to train a model on them.\n""" 83 | 84 | DEFAULT_PROMPT_SUFFIX = """\nGuidelines:{knowledge}\n\n\nBased on the examples(if provided) and guidelines(if provided) above, recommend {TOP_K} hyper-parameter configurations for a new classification dataset.\n\n{new_task_desc}""" 85 | -------------------------------------------------------------------------------- /coml/configagent/experience.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import OrderedDict 4 | from typing import Any, Dict, List, Optional, Tuple 5 | 6 | import langchain 7 | import numpy as np 8 | import orjson 9 | import pandas as pd 10 | from langchain.cache import InMemoryCache 11 | from langchain.globals import set_llm_cache 12 | from peewee import ModelSelect, fn 13 | 14 | from .constants import * 15 | from .orm import Knowledge, Solution, Space, Task, database_proxy 16 | from .utils import format_config, get_llm 17 | 18 | SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS 19 | 20 | 21 | set_llm_cache(InMemoryCache()) 22 | 23 | 24 | def ingest_experience( 25 | history_df: pd.DataFrame, 26 | task_desc: Optional[Dict[str, str]], 27 | space_desc: str, 28 | space_id: str, 29 | ) -> Space: 30 | """Ingest experience from history dataframe. 31 | 32 | Parameters 33 | ---------- 34 | history_df: pandas.DataFrame 35 | The history of configurations. 36 | task_desc 37 | The task descriptions. 38 | space_desc 39 | The space description. 40 | space_id 41 | The space id. 42 | 43 | Returns 44 | ------- 45 | None 46 | 47 | Notes 48 | ----- 49 | The history_df should be a dataframe with the following columns: 50 | - CONFIG_0 51 | - CONFIG_1 52 | - ... 53 | - CONFIG_N 54 | - METRIC 55 | The task_desc should be a dict with the following format: 56 | { 57 | "task_id_0": "task_description_0", 58 | "task_id_1": "task_description_1", 59 | ... 60 | "task_id_N": "task_description_N" 61 | } 62 | """ 63 | history_df = history_df.drop_duplicates() 64 | history_df["TASK_ID"] = history_df["TASK_ID"].astype(str) 65 | if task_desc is None: 66 | task_desc = {k: "." 
for k in history_df["TASK_ID"].unique()} 67 | experience_df = history_df[ 68 | history_df[["TASK_ID", "SCORE"]] 69 | .groupby("TASK_ID") 70 | .rank(method="first", ascending=False)["SCORE"] 71 | <= TOP_K 72 | ] 73 | quantile_info = get_quantile_stat(experience_df) 74 | 75 | space = _ingest_space(space_id, space_desc, quantile_info) 76 | 77 | _ingest_task(history_df, task_desc) 78 | 79 | _ingest_solution(history_df, space) 80 | 81 | print("Ingested experience into database.") 82 | # save db 83 | database_proxy.commit() 84 | return space 85 | 86 | 87 | def _ingest_space( 88 | space_id: str, space_desc: str, quantile_info: Dict[str, List[float]] 89 | ) -> Space: 90 | with database_proxy.atomic(): 91 | try: 92 | space = Space.get(Space.space_id == space_id) 93 | print(f"Space {space_id} already exists, skip ingestion.") 94 | return space 95 | except: 96 | space = Space.create( 97 | space_id=space_id, 98 | desc=space_desc, 99 | quantile_info=orjson.dumps(quantile_info, option=SAVE_OPTIONS), 100 | ) 101 | print("Ingested space into database.") 102 | # save db 103 | database_proxy.commit() 104 | return space 105 | 106 | 107 | def _ingest_task(history_df: pd.DataFrame, row_task_desc: Dict[str, str]) -> None: 108 | embeddings = get_llm("embedding")() 109 | with database_proxy.atomic(): 110 | for task_id in history_df["TASK_ID"].unique(): 111 | try: 112 | Task.get(Task.task_id == task_id) 113 | print(f"Task {task_id} already exists, skip ingestion.") 114 | continue 115 | except Task.DoesNotExist: 116 | task_desc = canonicalize_task(row_task_desc[task_id]) 117 | embedding = np.asarray( 118 | embeddings.embed_query(task_desc), dtype=np.float32 119 | ).tobytes() 120 | Task.create( 121 | task_id=task_id, 122 | embedding=embedding, 123 | desc=task_desc, 124 | row_desc=row_task_desc[task_id], 125 | ) 126 | print("Ingested task into database.") 127 | # save db 128 | database_proxy.commit() 129 | 130 | 131 | def _ingest_solution(history_df: pd.DataFrame, space: Space) -> None: 132 | with database_proxy.atomic(): 133 | solutions = [] 134 | 135 | for _, row in ( 136 | history_df.groupby("TASK_ID") 137 | .apply(lambda x: x.sort_values("SCORE", ascending=False)) 138 | .reset_index(drop=True) 139 | .iterrows() 140 | ): 141 | row_config_dict = { 142 | k[7:]: v for k, v in row.to_dict().items() if k.startswith("CONFIG_") 143 | } 144 | row_config = orjson.dumps((row_config_dict), option=SAVE_OPTIONS) 145 | try: 146 | Solution.get( 147 | Solution.task == row["TASK_ID"], 148 | Solution.space == space.space_id, 149 | Solution.row_config == row_config, 150 | ) 151 | print( 152 | f"Solution '{row_config}' for space {space.space_id} and task {row['TASK_ID']} already exists, skip ingestion." 153 | ) 154 | continue 155 | except Solution.DoesNotExist: 156 | solutions.append( 157 | { 158 | "task": row["TASK_ID"], 159 | "space": space, 160 | "metric": row["SCORE"], 161 | "row_config": row_config, 162 | "extra_metric": row.get("EXTRA_METRIC", ""), 163 | "demo": canonicalize_config( 164 | row_config_dict, orjson.loads(space.quantile_info) 165 | ), 166 | } 167 | ) 168 | Solution.insert_many(solutions).execute() 169 | print("Ingested solution into database.") 170 | # save db 171 | database_proxy.commit() 172 | 173 | 174 | def canonicalize_task(row_task_desc: str) -> str: 175 | """ 176 | Canonicalize task to a string. 177 | 178 | Parameters 179 | ---------- 180 | row_task_desc: str 181 | The task description. 182 | 183 | Returns 184 | ------- 185 | str 186 | The canonicalized task description. 
187 | """ 188 | task_desc = f"Task: {row_task_desc}" 189 | return task_desc 190 | 191 | 192 | def canonicalize_config( 193 | config_term: Dict[str, Any], quantile_info: Dict[str, List[float]] 194 | ) -> str: 195 | """ 196 | Canonicalize configuration to a string. 197 | 198 | Parameters 199 | ---------- 200 | config_term 201 | The configuration term. 202 | quantile_info 203 | The meta train info for stat. Record quantile information for each hyperparameter. 204 | 205 | Returns 206 | ------- 207 | str 208 | The canonicalized configuration. 209 | """ 210 | demo = format_config(config_term, quantile_info, bin_map=bin_map) 211 | return demo 212 | 213 | 214 | def get_quantile_stat(experience_df: pd.DataFrame) -> Dict[str, List[float]]: 215 | """ 216 | Get quantile stat from experience dataframe. 217 | 218 | Parameters 219 | ---------- 220 | experience_df: pandas.DataFrame 221 | The experience dataframe. 222 | 223 | Returns 224 | ------- 225 | Dict[str, List[float]] 226 | The quantile stat. 227 | """ 228 | # map column name 'CONFIG_{}' to '{}', rest of columns are deleted 229 | meta_train_df_for_stat = experience_df.rename( 230 | columns=lambda x: x[7:] if x.startswith("CONFIG_") else None 231 | ) 232 | del meta_train_df_for_stat[None] 233 | 234 | try: 235 | quantile_info = { 236 | col: np.quantile( 237 | meta_train_df_for_stat[col][~meta_train_df_for_stat[col].isna()], 238 | list(bin_map.keys()), 239 | method="nearest", 240 | ) 241 | .astype(np.float32) 242 | .tolist() 243 | for col in meta_train_df_for_stat.select_dtypes(include=[np.number]) 244 | } 245 | except: 246 | quantile_info = { 247 | col: np.quantile( 248 | meta_train_df_for_stat[col][~meta_train_df_for_stat[col].isna()], 249 | list(bin_map.keys()), 250 | interpolation="nearest", 251 | ) 252 | .astype(np.float32) 253 | .tolist() 254 | for col in meta_train_df_for_stat.select_dtypes(include=[np.number]) 255 | } 256 | 257 | return quantile_info 258 | 259 | 260 | def gen_experience_per_task(space: Space, task: Task) -> str: 261 | """ 262 | Generate experience content from space and task. 263 | 264 | Parameters 265 | ---------- 266 | task_desc: str 267 | The task description. 268 | demos: str 269 | The demos. 270 | example: 271 | Configurations 1: ... 272 | Configurations 2: ... 273 | Configurations 3: ... 274 | 275 | Returns 276 | ------- 277 | str 278 | The experience content. 279 | """ 280 | demos = _gen_experience_demos(space, task) 281 | content = f"{task.desc}\n{demos}" 282 | return content 283 | 284 | 285 | def _gen_experience_demos(space: Space, task: Task) -> str: 286 | """ 287 | Generate experience demonstrations 288 | 289 | Parameters 290 | ---------- 291 | space: Space 292 | The space. 293 | task: Task 294 | The task. 295 | 296 | Returns 297 | ------- 298 | str 299 | The experience demonstrations per task. 300 | """ 301 | solutions = ( 302 | Solution.select() 303 | .where(Solution.task_id == task.task_id, Solution.space_id == space.space_id) 304 | .order_by(Solution.metric.desc()) 305 | .limit(TOP_K) 306 | ) 307 | demos = "\n".join( 308 | [ 309 | f"Configuration {i+1}: {solution.demo}" 310 | for i, solution in enumerate(solutions) 311 | ] 312 | ) 313 | return demos 314 | 315 | 316 | def _get_best_relevant_solutions(space: Space, task_desc: str) -> ModelSelect: 317 | """ 318 | Get the best relevant solution for a task. 319 | The relevance is measured by cosine similarity between task description embeddings, which affects the order of results. 320 | 321 | Parameters 322 | ---------- 323 | space: Space 324 | The space. 
325 | task_desc: str 326 | The task description. 327 | 328 | Returns 329 | ------- 330 | ModelSelect 331 | The best relevant solution. 332 | """ 333 | SolutionAlias = Solution.alias() 334 | order_key = Task.embedding.cosine_distance(task_desc) 335 | subquery = ( 336 | SolutionAlias.select( 337 | SolutionAlias.demo, 338 | Task.task_id, 339 | Task.desc, 340 | Task.embedding, 341 | fn.ROW_NUMBER() 342 | .over( 343 | partition_by=[SolutionAlias.space, SolutionAlias.task], 344 | order_by=[SolutionAlias.metric.desc()], 345 | ) 346 | .alias("rnk"), 347 | ) 348 | .where(SolutionAlias.space == space) 349 | .join(Task, on=(SolutionAlias.task == Task.task_id)) 350 | .order_by(order_key) 351 | .alias("subq") 352 | ) 353 | query = ( 354 | Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc) 355 | .from_(subquery) 356 | .where(subquery.c.rnk <= TOP_K) 357 | ) 358 | return query 359 | 360 | 361 | def _get_best_solutions(space: Space) -> ModelSelect: 362 | """ 363 | Get the best solution for each task. 364 | 365 | Parameters 366 | ---------- 367 | space: Space 368 | The space. 369 | 370 | Returns 371 | ------- 372 | ModelSelect 373 | The best solution for each task. 374 | """ 375 | SolutionAlias = Solution.alias() 376 | subquery = ( 377 | SolutionAlias.select( 378 | SolutionAlias.demo, 379 | Task.task_id, 380 | Task.desc, 381 | Task.embedding, 382 | fn.ROW_NUMBER() 383 | .over( 384 | partition_by=[SolutionAlias.space, SolutionAlias.task], 385 | order_by=[SolutionAlias.metric.desc()], 386 | ) 387 | .alias("rnk"), 388 | ) 389 | .where(SolutionAlias.space == space) 390 | .join(Task, on=(SolutionAlias.task == Task.task_id)) 391 | .alias("subq") 392 | ) 393 | query = ( 394 | Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc) 395 | .from_(subquery) 396 | .where(subquery.c.rnk <= TOP_K) 397 | ) 398 | return query 399 | 400 | 401 | def gen_experience( 402 | space: Space, task_desc: Optional[str] = None 403 | ) -> Tuple[List[str], List[str]]: 404 | """ 405 | Generate experience content from space and optional task description. 406 | 407 | Parameters 408 | ---------- 409 | space: Space 410 | The space. 411 | task_desc 412 | The task description. 413 | 414 | Returns 415 | ------- 416 | List[str] 417 | The experience content. 
418 | """ 419 | if task_desc is None: 420 | query = _get_best_solutions(space) 421 | else: 422 | query = _get_best_relevant_solutions(space, task_desc) 423 | examples = OrderedDict() 424 | 425 | for solution in query: 426 | if solution.task_id not in examples: 427 | examples[solution.task_id] = [solution.desc] 428 | examples[solution.task_id].append( 429 | f"Configuration {len(examples[solution.task_id])}: {solution.demo}" 430 | ) 431 | return list(examples.keys()), ["\n".join(e) for e in examples.values()] 432 | -------------------------------------------------------------------------------- /coml/configagent/knowledge.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | from typing import Any, Callable, Dict, List, Optional 4 | 5 | import orjson 6 | from langchain.prompts import FewShotPromptTemplate, PromptTemplate 7 | from langchain.prompts.example_selector import LengthBasedExampleSelector 8 | 9 | from .constants import * 10 | from .constants import TOKEN_COMPLETION_LIMIT, TOKEN_LIMIT 11 | from .experience import gen_experience 12 | from .orm import Knowledge, Solution, Space, Task, database_proxy 13 | from .surrogate_utils import evaluate_configs 14 | from .utils import get_llm, get_token_count_func, parse_configs 15 | 16 | prefix_sep = "__DUMM_SEP__" 17 | 18 | 19 | def gen_knowledge_candidate(examples: List[str]) -> str: 20 | """ 21 | Generate knowledge candidate from examples. 22 | 23 | Parameters 24 | ---------- 25 | examples: list of strings 26 | The list of examples. 27 | 28 | Returns 29 | ------- 30 | str 31 | The generated knowledge candidate. 32 | """ 33 | prefix_token = get_token_count_func()( 34 | "Here are some tasks along with best hyper-parameter configurations to train a model on them.\n" 35 | ) 36 | suffix_token = get_token_count_func()( 37 | "\nQ: From the examples above, what patterns can we observe about the relationship between dataset characteristics and the best hyper-parameter configurations? (Answer MUST be concise, critical, point-by-point, line-by-line, and brief. Only include relevant observations without unnecessary elaboration.)\n\nA: 1." 38 | ) 39 | example_prompt = PromptTemplate( 40 | input_variables=["input"], 41 | template="{input}", 42 | ) 43 | 44 | example_selector = LengthBasedExampleSelector( 45 | examples=[{"input": example} for example in examples], 46 | example_prompt=example_prompt, 47 | max_length=TOKEN_LIMIT 48 | - prefix_token 49 | - suffix_token 50 | - TOKEN_COMPLETION_LIMIT 51 | - RELAX_TOKEN, 52 | get_text_length=get_token_count_func(), 53 | ) 54 | 55 | dynamic_prompt = FewShotPromptTemplate( 56 | # We provide an ExampleSelector instead of examples. 57 | example_selector=example_selector, 58 | example_prompt=example_prompt, 59 | prefix="Here are some tasks along with best hyper-parameter configurations to train a model on them.\n", 60 | suffix="\nQ: From the examples above, what patterns can we observe about the relationship between dataset characteristics and the best hyper-parameter configurations? (Answer MUST be concise, critical, point-by-point, line-by-line, and brief. Only include relevant observations without unnecessary elaboration.)\n\nA: 1.", 61 | input_variables=[], 62 | ) 63 | llm = get_llm("knowledge")() 64 | knowledge = "\n1." 
+ llm(dynamic_prompt.format()) 65 | return knowledge 66 | 67 | 68 | def suggest_with_knowledge( 69 | examples: List[str], 70 | knowledge: str, 71 | valid_example: str, 72 | quantile_infos: Dict[str, List[float]], 73 | ) -> List[Dict[str, Any]]: 74 | """ 75 | Suggest configurations with knowledge. 76 | 77 | Parameters 78 | ---------- 79 | examples 80 | The list of examples. 81 | knowledge: str 82 | The knowledge. 83 | valid_example: str 84 | The valid example. 85 | quantile_infos 86 | The meta train info for stats. Used to convert the text to config value. 87 | 88 | Returns 89 | ------- 90 | List[Dict[str, Any]] 91 | The list of suggested configurations. 92 | """ 93 | prefix_token = get_token_count_func()( 94 | "Here are some tasks along with best hyper-parameter configurations to train a model on them.\n" 95 | ) 96 | suffix_token = get_token_count_func()( 97 | "\nGuidelines:{knowledge}\n\n\nBased on the examples and guidelines above, recommend {TOP_K} hyper-parameter configurations for a new classification dataset.\n\n{output}".format( 98 | knowledge=knowledge, 99 | TOP_K=str(TOP_K), 100 | output=( 101 | valid_example[: valid_example.index("\nConfiguration 1:")] + "\n\n" 102 | ), 103 | ) 104 | ) 105 | example_prompt = PromptTemplate( 106 | input_variables=["input"], 107 | template="{input}", 108 | ) 109 | 110 | example_selector = LengthBasedExampleSelector( 111 | examples=[{"input": example} for example in examples], 112 | example_prompt=example_prompt, 113 | max_length=TOKEN_LIMIT 114 | - prefix_token 115 | - suffix_token 116 | - TOKEN_COMPLETION_LIMIT 117 | - RELAX_TOKEN, 118 | get_text_length=get_token_count_func(), 119 | ) 120 | 121 | dynamic_prompt = FewShotPromptTemplate( 122 | # We provide an ExampleSelector instead of examples. 123 | example_selector=example_selector, 124 | example_prompt=example_prompt, 125 | prefix="Here are some tasks along with best hyper-parameter configurations to train a model on them.\n", 126 | suffix="\nGuidelines:{knowledge}\n\n\nBased on the examples and guidelines above, recommend {TOP_K} hyper-parameter configurations for a new classification dataset.\n\n{output}", 127 | input_variables=["knowledge", "TOP_K", "output"], 128 | ) 129 | 130 | llm = get_llm("suggest")() 131 | 132 | response = llm( 133 | dynamic_prompt.format( 134 | knowledge=knowledge, 135 | TOP_K=str(TOP_K), 136 | output=( 137 | valid_example[: valid_example.index("\nConfiguration 1:")] + "\n\n" 138 | ), 139 | ) 140 | ) 141 | suggest_configs = parse_configs( 142 | response, 143 | TOP_K, 144 | inverse_bin_map, 145 | quantile_infos, 146 | ) 147 | return suggest_configs 148 | 149 | 150 | def post_validation( 151 | space: Space, surrogate_fn: Callable, config_names: List[str] 152 | ) -> List[str]: 153 | """ 154 | Post validation to generate knowledge. 155 | 156 | Parameters 157 | ---------- 158 | space: Space 159 | The space. 160 | surrogate_fn: Callable 161 | The surrogate function. 162 | config_names: list of str 163 | The list of configuration names. 164 | 165 | Returns 166 | ------- 167 | List[str] 168 | The list of generated knowledge. 
169 | """ 170 | knowledges = get_knowledge(space) 171 | if knowledges != "": 172 | print("Knowledge already exists.") 173 | return knowledges 174 | quantile_infos = orjson.loads(space.quantile_info) 175 | retrieved_tasks, examples = gen_experience(space) 176 | best_score = float("-inf") 177 | knowledges = None 178 | for _ in range(3): 179 | random.shuffle(examples) 180 | knowledge_candidate = gen_knowledge_candidate(examples) 181 | score = 0 182 | for val_step in range(3): 183 | random.shuffle(examples) 184 | assert len(examples) > 1, "Not enough examples in the memory." 185 | valid_example = examples[-1] 186 | train_example = examples[:-1] 187 | 188 | try: 189 | suggest_configs = suggest_with_knowledge( 190 | train_example, 191 | knowledge_candidate, 192 | valid_example, 193 | quantile_infos, 194 | ) 195 | _score = evaluate_configs( 196 | surrogate_fn, suggest_configs, config_names 197 | ).mean() 198 | except: 199 | _score = float("-inf") 200 | score += _score 201 | if best_score < score: 202 | best_score = score 203 | knowledges = knowledge_candidate 204 | assert knowledges is not None, "Knowledge is not generated." 205 | 206 | knowledges = split_knowledge(knowledges) 207 | return knowledges 208 | 209 | 210 | def get_knowledge(space: Space, task=None): 211 | try: 212 | knowledges = Knowledge.select().where( 213 | (Knowledge.space_id == space.space_id) 214 | & ((Knowledge.task == task) | (Knowledge.task == None)) 215 | ) 216 | knowledge_str = "" 217 | for i, knowledge in enumerate(knowledges): 218 | knowledge_str += f"{i+1}. {knowledge.knowledge}\n\n" 219 | return knowledge_str 220 | except: 221 | return "" 222 | 223 | 224 | def split_knowledge(knowledges: str) -> List[str]: 225 | """ 226 | Split the knowledge into a list of knowledge. 227 | 228 | Parameters 229 | ---------- 230 | knowledges: str 231 | The knowledge. 232 | 233 | Returns 234 | ------- 235 | List[str] 236 | The list of knowledge. 237 | 238 | Examples 239 | -------- 240 | >>> split_knowledge("1. A\n2. B\n3. C\n") 241 | ["A", "B", "C"] 242 | """ 243 | return [ 244 | k.strip() 245 | for k in re.findall( 246 | r"\n\d+\.([\s\S]+?)(?=\n+\d+\.)", "\n" + knowledges + "\n999." 
247 | ) 248 | ] 249 | -------------------------------------------------------------------------------- /coml/configagent/orm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import numpy as np 4 | from peewee import ( 5 | AutoField, 6 | BlobField, 7 | CompositeKey, 8 | DatabaseProxy, 9 | Expression, 10 | Field, 11 | FloatField, 12 | ForeignKeyField, 13 | Model, 14 | ModelBase, 15 | PrimaryKeyField, 16 | TextField, 17 | Value, 18 | fn, 19 | ) 20 | 21 | try: 22 | from pgvector.psycopg2 import VectorAdapter, register_vector 23 | from pgvector.utils import from_db, to_db 24 | except ImportError: 25 | from_db = to_db = None 26 | 27 | from .constants import * 28 | from .utils import get_llm 29 | 30 | 31 | class ArrayField(BlobField): 32 | field_type = "BLOB" 33 | 34 | def db_value(self, value): 35 | if isinstance(value, np.ndarray): 36 | value = value.tobytes() 37 | return super().db_value(value) 38 | 39 | def python_value(self, value): 40 | return np.frombuffer(value, dtype=np.float32) 41 | 42 | def cosine_distance(self, text: str): 43 | return fn.cosine_similarity(self, text).desc() 44 | 45 | 46 | class VectorField(Field): 47 | field_type = "VECTOR" 48 | 49 | def __init__(self, dim=None, *args, **kwargs): 50 | self.dim = dim 51 | super().__init__(*args, **kwargs) 52 | 53 | def get_modifiers(self) -> None: 54 | return [self.dim] 55 | 56 | def db_value(self, value): 57 | if isinstance(value, str): 58 | return value 59 | elif isinstance(value, bytes): 60 | value = np.frombuffer(value, dtype=np.float32) 61 | return to_db(value, self.dim) 62 | 63 | def python_value(self, value): 64 | return from_db(value) 65 | 66 | def cosine_distance(self, text: str): 67 | text_emb = np.asarray( 68 | get_llm("embedding")().embed_query(text), dtype=np.float32 69 | ) 70 | return Expression(self, "<=>", Value(to_db(text_emb, self.dim), unpack=False)) 71 | 72 | 73 | database_proxy = DatabaseProxy() 74 | 75 | if COML_DB_BACKEND == "sqlite": 76 | from peewee import SqliteDatabase 77 | 78 | init_db_func = lambda: SqliteDatabase(COML_DB_PATH) 79 | elif COML_DB_BACKEND == "postgres": 80 | from peewee import PostgresqlDatabase 81 | 82 | init_db_func = lambda: PostgresqlDatabase( 83 | COML_DB_NAME, 84 | host=COML_DB_HOST, 85 | port=COML_DB_PORT, 86 | user=COML_DB_USER, 87 | password=COML_DB_PASSWORD, 88 | ) 89 | else: 90 | raise NotImplementedError(f"COML_DB_BACKEND {COML_DB_BACKEND} not supported.") 91 | 92 | 93 | def init_db(): 94 | database_proxy.initialize(init_db_func()) 95 | conn = database_proxy.connection() 96 | if COML_DB_BACKEND == "postgres": 97 | register_vector(conn) 98 | database_proxy.create_tables([Space, Task, Solution, Knowledge]) 99 | 100 | if COML_DB_BACKEND == "sqlite": 101 | _cache = {} 102 | 103 | @database_proxy.func() 104 | def cosine_similarity(task_emb: BlobField, text: str) -> float: 105 | emb = np.frombuffer(task_emb, dtype=np.float32) 106 | if text not in _cache: 107 | _cache[text] = np.asarray( 108 | get_llm("embedding")().embed_query(text), dtype=np.float32 109 | ) 110 | text_emb = _cache[text] 111 | return np.dot(emb, text_emb).item() 112 | 113 | 114 | class BaseModel(Model): 115 | class Meta: 116 | database = database_proxy 117 | 118 | 119 | class Space(BaseModel): 120 | space_id: str = TextField(primary_key=True) 121 | desc = TextField() 122 | quantile_info = BlobField(null=True) 123 | prefix = TextField(default=DEFAULT_PROMPT_PREFIX) 124 | suffix = TextField(default=DEFAULT_PROMPT_SUFFIX) 125 | 126 | 127 | 
class Task(BaseModel): 128 | task_id: str = TextField(primary_key=True) 129 | embedding = ArrayField() if COML_DB_BACKEND == "sqlite" else VectorField(EMBED_DIM) 130 | desc = TextField() 131 | row_desc = TextField() 132 | 133 | 134 | class Solution(BaseModel): 135 | demo = TextField() 136 | task = ForeignKeyField(Task, backref="solutions") 137 | space = ForeignKeyField(Space, backref="solutions") 138 | metric = FloatField() 139 | row_config = TextField() 140 | extra_metric = TextField() 141 | 142 | class Meta: 143 | primary_key = CompositeKey("space", "task", "row_config") 144 | database = database_proxy 145 | 146 | 147 | class Knowledge(BaseModel): 148 | knowledge = TextField() 149 | space = ForeignKeyField(Space, backref="knowledge") 150 | task = ForeignKeyField(Task, backref="knowledge", null=True) 151 | 152 | 153 | def import_db(tables: Dict[ModelBase, List[Dict[str, Any]]]) -> None: 154 | """ 155 | Imports the contents of the database from a dictionary. 156 | 157 | Parameters: 158 | ----------- 159 | tables: Dict[BaseModel, List[Dict[str, Any]]] 160 | A dictionary with the tables as keys and a list of records as values. 161 | 162 | Returns: 163 | -------- 164 | None 165 | """ 166 | for table, records in tables.items(): 167 | # skip duplicate records 168 | with database_proxy.atomic(): 169 | for record in records: 170 | try: 171 | eval(table).insert(record).execute() 172 | except: 173 | pass 174 | 175 | 176 | init_db() 177 | -------------------------------------------------------------------------------- /coml/configagent/space.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | import pandas as pd 6 | 7 | from .experience import ingest_experience 8 | from .knowledge import get_knowledge 9 | from .orm import Knowledge, Solution, Space, Task, database_proxy 10 | 11 | 12 | def gen_space_description( 13 | history_df: pd.DataFrame, space_desc: Optional[str] = None 14 | ) -> str: 15 | """ 16 | Generate space description from history_df and space_desc. 17 | 18 | Parameters 19 | ---------- 20 | history_df: pandas.DataFrame 21 | The history of configurations. 22 | space_desc: str | None 23 | The path to the space description. 24 | 25 | Returns 26 | ------- 27 | str 28 | The generated space description. 29 | """ 30 | history_df = history_df.rename( 31 | columns=lambda x: x[7:] if x.startswith("CONFIG_") else None 32 | ) 33 | del history_df[None] 34 | history_df.columns 35 | descriptions = """Space has {} configurable hyper-parameters, i.e., {}.\n""".format( 36 | len(history_df.columns), 37 | ", ".join(["'{}'".format(x) for x in history_df.columns]), 38 | ) 39 | return descriptions + (space_desc if space_desc is not None else "") 40 | 41 | 42 | def create_tables(): 43 | with database_proxy: 44 | database_proxy.create_tables([Task, Solution, Space, Knowledge]) 45 | 46 | 47 | def drop_tables(): 48 | with database_proxy: 49 | database_proxy.drop_tables([Task, Solution, Space, Knowledge]) 50 | 51 | 52 | def create_space( 53 | space_id: str, 54 | history: str, 55 | task_desc: Optional[str] = None, 56 | space_desc: Optional[str] = None, 57 | no_knowledge: bool = False, 58 | ) -> Space: 59 | """ 60 | Create a space from history csv file and task description json file. 61 | 62 | Parameters 63 | ---------- 64 | space_id: str 65 | The ID of the space to identify the space. 66 | history: str 67 | The path to the history of configurations. 
A csv file, format see `coml.experience.ingest_experience`. 68 | task_desc: str 69 | The JSON path to the task description. A json file, format see `coml.experience.ingest_experience`. 70 | space_desc: str 71 | The text path to the space description. Optional. 72 | no_knowledge: bool 73 | Whether to generate knowledge from history. 74 | 75 | Returns 76 | ------- 77 | None 78 | 79 | Examples 80 | -------- 81 | >>> create_space( 82 | ... space_id="example_space", 83 | ... history="assets/example_history.csv", 84 | ... task_desc="assets/example_task_description.json", 85 | ... space_desc="assets/example_space_desc.txt", 86 | ... ) 87 | """ 88 | history_df = pd.read_csv(history) 89 | if task_desc is not None: 90 | task_desc = json.loads(Path(task_desc).read_text()) 91 | space_desc = Path(space_desc).read_text() if space_desc is not None else None 92 | space_desc = gen_space_description(history_df, space_desc) 93 | space = ingest_experience(history_df, task_desc, space_desc, space_id) 94 | 95 | if not no_knowledge and get_knowledge(space) == "": 96 | from .knowledge import post_validation 97 | from .surrogate_utils import process_history_df, train_surrogate 98 | 99 | history_df_processed, config_names = process_history_df(history_df) 100 | surrogate_fn = train_surrogate(history_df_processed) 101 | knowledges = post_validation(space, surrogate_fn, config_names) 102 | for knowledge in knowledges: 103 | Knowledge.create(space_id=space.space_id, knowledge=knowledge) 104 | database_proxy.commit() 105 | return space 106 | 107 | 108 | def list_available_spaces() -> List[Space]: 109 | """ 110 | List all available spaces. 111 | 112 | Returns 113 | ------- 114 | list 115 | A list of available spaces. 116 | """ 117 | return list(Space.select()) 118 | 119 | 120 | def print_space() -> list: 121 | """ 122 | Print all available spaces to the console. 123 | 124 | Returns 125 | ------- 126 | list 127 | A list of available spaces. 128 | """ 129 | print("Current space available: ") 130 | available_spaces = list_available_spaces() 131 | for i, space in enumerate(available_spaces): 132 | print(f"{i+1}. Design space: (Space ID = {space.space_id}){space.desc}" "\n\n") 133 | return available_spaces 134 | 135 | 136 | def delete_space(space_id: str): 137 | """ 138 | Delete a space. 139 | 140 | Parameters 141 | ---------- 142 | space_id: str 143 | The ID of the space to delete. 144 | 145 | Returns 146 | ------- 147 | None 148 | """ 149 | space = import_space(space_id) 150 | if space is None: 151 | print(f"Space {space_id} does not exist.") 152 | return 153 | else: 154 | print(f"Deleting space {space_id}...") 155 | 156 | Solution.delete().where(Solution.space == space.space_id).execute() 157 | Knowledge.delete().where(Knowledge.space == space.space_id).execute() 158 | # delete task if no other space is using it 159 | for task in Task.select(): 160 | if Solution.select().where(Solution.task == task.task_id).count() == 0: 161 | task.delete_instance() 162 | 163 | space.delete_instance() 164 | print(f"Space {space_id} deleted.") 165 | 166 | 167 | def import_space(space_id: str) -> Space: 168 | """ 169 | Import a space. 170 | 171 | Parameters 172 | ---------- 173 | space_id: str 174 | The ID of the space to import. 175 | 176 | Returns 177 | ------- 178 | Space 179 | The imported space. 
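    Examples
    --------
    Illustrative only; assumes a space with this ID was created earlier, e.g. via
    `create_space`:

    >>> space = import_space("example_space")
    >>> if space is None:
    ...     print("Space 'example_space' does not exist.")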
180 | """ 181 | try: 182 | space = Space.get(Space.space_id == space_id) 183 | except: 184 | space = None 185 | return space 186 | 187 | 188 | create_tables() 189 | -------------------------------------------------------------------------------- /coml/configagent/suggest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import orjson 7 | from langchain import FewShotPromptTemplate, PromptTemplate 8 | from langchain.prompts.example_selector import LengthBasedExampleSelector 9 | from peewee import fn 10 | 11 | from .constants import * 12 | from .experience import gen_experience 13 | from .knowledge import get_knowledge 14 | from .orm import Knowledge, Solution, Space, Task, database_proxy 15 | from .space import import_space, print_space 16 | from .utils import ( 17 | clean_input, 18 | escape, 19 | get_llm, 20 | get_token_count_func, 21 | parse_configs, 22 | set_llms, 23 | ) 24 | 25 | 26 | def print_suggested_configs(configurations: Any, knowledge: str | None) -> None: 27 | if knowledge: 28 | print(f"\n* Rationale: ") 29 | print(knowledge) 30 | print(f"\n* Recommended configurations: ") 31 | if isinstance(configurations, str): 32 | print(configurations) 33 | else: 34 | for i, suggest_config in enumerate(configurations): 35 | print(f"Suggested configuration {i+1}. {suggest_config}") 36 | 37 | 38 | def suggest_interactive() -> None: 39 | """ 40 | Suggest configurations interactively. 41 | 42 | Returns 43 | ------- 44 | None 45 | """ 46 | while True: 47 | print_space() 48 | 49 | select_space_id = clean_input(f"Please select a space(input space ID): ") 50 | space = import_space(select_space_id) 51 | if space is None: 52 | print(f"Space '{select_space_id}' not exist.") 53 | continue 54 | 55 | print( 56 | f"You selected space:\n (Space ID = {select_space_id}){space.desc}" "\n\n" 57 | ) 58 | task_desc = clean_input("Your description for new task: ").strip(".") + "." 59 | suggest_configs, knowledge = suggest(space, task_desc) 60 | print_suggested_configs(suggest_configs, knowledge) 61 | # press any key to continue, press 'q' to quit 62 | if clean_input('Press any key to continue, press "q" to quit: ') == "q": 63 | break 64 | 65 | 66 | def suggest(space: Space, task_desc: str) -> Tuple[Any, Union[str, None]]: 67 | """ 68 | Suggest configurations for a new task. 69 | 70 | Parameters 71 | ---------- 72 | space: Space 73 | The space to suggest configurations. 74 | task_desc: str 75 | The description of the new task. 76 | 77 | Returns 78 | ------- 79 | Tuple[Any, Union[str, None]] 80 | A tuple of suggested configurations and rationale. 
81 | """ 82 | task_desc = f"""Task: {task_desc}""" 83 | 84 | retrieved_tasks, examples = gen_experience(space, task_desc) 85 | knowledge = ( 86 | get_knowledge(space, retrieved_tasks[0]) 87 | if len(retrieved_tasks) 88 | else get_knowledge(space) 89 | ) 90 | 91 | llm = get_llm("suggest")() 92 | quantile_infos = orjson.loads(space.quantile_info) if space.quantile_info else None 93 | 94 | prompt_addition_info = { 95 | "new_task_desc": task_desc, 96 | "knowledge": knowledge, 97 | "TOP_K": str(TOP_K), 98 | "space_desc": space.desc, 99 | } 100 | 101 | prefix = format_prompt(space.prefix, prompt_addition_info) 102 | suffix = format_prompt(space.suffix, prompt_addition_info) 103 | 104 | token_count_func = get_token_count_func() 105 | prefix_token = token_count_func(prefix) 106 | suffix_token = token_count_func(suffix) 107 | 108 | example_prompt = PromptTemplate( 109 | input_variables=["input"], 110 | template="{input}", 111 | ) 112 | 113 | example_selector = LengthBasedExampleSelector( 114 | examples=[{"input": escape(example)} for example in examples], 115 | example_prompt=example_prompt, 116 | max_length=TOKEN_LIMIT 117 | - prefix_token 118 | - suffix_token 119 | - TOKEN_COMPLETION_LIMIT 120 | - RELAX_TOKEN, 121 | get_text_length=token_count_func, 122 | ) 123 | 124 | dynamic_prompt = FewShotPromptTemplate( 125 | example_selector=example_selector, 126 | example_prompt=example_prompt, 127 | prefix=prefix, 128 | suffix=suffix, 129 | input_variables=[], 130 | ) 131 | prompt = dynamic_prompt.format() 132 | 133 | response = llm(prompt) 134 | if quantile_infos: 135 | suggest_configs = parse_configs( 136 | response, 137 | TOP_K, 138 | inverse_bin_map, 139 | quantile_infos, 140 | ) 141 | else: 142 | suggest_configs = response 143 | return suggest_configs, knowledge 144 | 145 | 146 | def format_prompt(prompt: str, prompt_optional_info: Dict[str, str]): 147 | fs = [f for f in PROMPT_FORMATS if re.search("(? Callable: 43 | """ 44 | Train a surrogate model based on the history of the user. 45 | 46 | Parameters 47 | ---------- 48 | history_df_processed: pd.DataFrame 49 | The processed history dataframe. 50 | 51 | Returns 52 | ------- 53 | surrogate: Callable 54 | The surrogate model. 55 | """ 56 | train_df, test_df = train_test_split( 57 | history_df_processed, test_size=0.2, random_state=42 58 | ) 59 | # normalize the score 60 | train_df["SCORE"] = (train_df["SCORE"] - train_df["SCORE"].min()) / ( 61 | train_df["SCORE"].max() - train_df["SCORE"].min() + 1.0e-4 62 | ) 63 | train_X = train_df.drop(columns=["SCORE"]) 64 | train_y = train_df["SCORE"] 65 | test_X = test_df.drop(columns=["SCORE"]) 66 | test_y = test_df["SCORE"] 67 | surrogate = Pipeline( 68 | [ 69 | ("scaler", RobustScaler()), 70 | ("xgb", xgb.XGBRegressor(objective="reg:squarederror")), 71 | ] 72 | ) 73 | surrogate.fit(train_X, train_y) 74 | preds = surrogate.predict(test_X) 75 | return surrogate 76 | 77 | 78 | def evaluate_configs( 79 | surrogate_fn: Callable, configs: List[Dict[str, Any]], config_names: List[str] 80 | ) -> np.ndarray: 81 | """ 82 | Evaluate a list of configs using the surrogate model. 83 | 84 | Parameters 85 | ---------- 86 | surrogate_fn: Callable 87 | The surrogate model. 88 | configs 89 | The list of configs to be evaluated. 90 | config_names: List[str] 91 | The list of config names. 92 | 93 | Returns 94 | ------- 95 | scores: np.ndarray 96 | The scores of the configs. It should be 1-dimensional. 
97 | """ 98 | df = pd.DataFrame(columns=config_names) 99 | dumm_cols = set([n[: n.index(prefix_sep)] for n in config_names if prefix_sep in n]) 100 | for config in configs: 101 | # add one NaN row 102 | idx = len(df) 103 | df.loc[idx] = np.nan 104 | for k, v in config.items(): 105 | if k in dumm_cols: 106 | new_k = k + prefix_sep + v 107 | assert new_k in config_names, "Unknown config name: {}".format(new_k) 108 | df.loc[idx, new_k] = 1 109 | for _k in config_names: 110 | if _k != new_k and _k.startswith(k + prefix_sep): 111 | df.loc[idx, _k] = 0 112 | else: 113 | assert k in config_names, "Unknown config name: {}".format(k) 114 | df.loc[idx, k] = v 115 | return surrogate_fn.predict(df).ravel() 116 | -------------------------------------------------------------------------------- /coml/configagent/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from functools import lru_cache 4 | from pathlib import Path 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 6 | 7 | from langchain.embeddings import OpenAIEmbeddings 8 | from langchain_openai import OpenAI 9 | 10 | from .constants import * 11 | 12 | LLM_MODELS = { 13 | "suggest": lambda: OpenAI(model_name="text-davinci-003", temperature=0), 14 | "knowledge": lambda: OpenAI(model_name="text-davinci-003"), 15 | "embedding": lambda: OpenAIEmbeddings(), 16 | } 17 | _TOKEN_COUNT_FUNC = None 18 | 19 | 20 | def clean_input(prompt: str = ""): 21 | try: 22 | return input(prompt) 23 | except KeyboardInterrupt: 24 | print("You interrupted CoML") 25 | print("Quitting...") 26 | exit(0) 27 | 28 | 29 | pattern_0 = re.compile("Configuration(?: \d)*: (.*)\.\n") 30 | 31 | 32 | def parse_configs( 33 | response: str, 34 | TOP_k: int, 35 | inverse_bin_map: Dict = {}, 36 | quantile_info: Optional[Dict] = None, 37 | ) -> List[Dict]: 38 | """ 39 | Parse the response from the LLM API and return the suggested configurations. 40 | 41 | Parameters 42 | ---------- 43 | response : str 44 | The response from the LLM API. 45 | TOP_k : int 46 | The number of suggested configurations to return. 47 | inverse_bin_map : Dict 48 | The inverse bin map from the Space object. 49 | quantile_info : Optional[Dict] 50 | The meta train info for statistics from the Space object. 51 | 52 | Returns 53 | ------- 54 | suggest_configs : List[Dict] 55 | The suggested configurations. 56 | """ 57 | suggest_configs = [] 58 | groups = re.findall(pattern_0, response + "\n") 59 | for t in groups[:TOP_k]: 60 | kvs = t.split(". ") 61 | config = {} 62 | for kv in kvs: 63 | _k, v = kv.strip().split(" is ") 64 | if v in inverse_bin_map: 65 | config_col = list(quantile_info[_k]) 66 | value = config_col[q_num.index(inverse_bin_map[v])] 67 | elif v in ("True", "False"): 68 | value = eval(v) 69 | else: 70 | value = v 71 | config[_k] = value 72 | 73 | suggest_configs.append(config) 74 | return suggest_configs 75 | 76 | 77 | def format_config( 78 | config: Dict[str, Any], 79 | quantile_info: Optional[Dict[str, List[float]]] = None, 80 | bin_map: Dict[float, str] = {}, 81 | ) -> str: 82 | """ 83 | Format the configuration to a string which can be input to the LLM API. 84 | 85 | Parameters 86 | ---------- 87 | config 88 | The configuration to be formatted. 89 | quantile_info 90 | The meta train info for statistics from the Space object. 91 | bin_map 92 | The bin map is to map the bin value to the string. 93 | 94 | Returns 95 | ------- 96 | config_str : str 97 | The formatted configuration. 
98 | """ 99 | result = [] 100 | for k, v in config.items(): 101 | _k = k 102 | if v is None: 103 | continue 104 | elif v in ["TRUE", "FALSE"]: 105 | result.append(f"{_k} is {v.lower().capitalize()}") 106 | elif v is True or v is False: 107 | v = str(v) 108 | result.append(f"{_k} is {v.lower().capitalize()}") 109 | elif isinstance(v, str): 110 | result.append(f"{_k} is {v}") 111 | elif isinstance(v, (float, int)): 112 | assert quantile_info is not None, "quantile_info is None" 113 | config_col = list(quantile_info[k]) 114 | anchor = min(config_col, key=lambda x: abs(x - v)) 115 | value = bin_map[q_num[config_col.index(anchor)]] 116 | result.append(f"{_k} is {value}") 117 | 118 | else: 119 | assert False, f"{v}" 120 | return ". ".join(result) + "." 121 | 122 | 123 | @lru_cache(maxsize=1000) 124 | def _token_count(text): 125 | import tiktoken 126 | 127 | try: 128 | encoding = tiktoken.encoding_for_model("gpt-4") 129 | except KeyError: 130 | print("Warning: model not found. Using gpt-2 encoding.") 131 | encoding = tiktoken.get_encoding("gpt2") 132 | return len(encoding.encode(text)) 133 | 134 | 135 | def token_count(texts): 136 | if isinstance(texts, str): 137 | return _token_count(texts) 138 | l = 0 139 | for text in texts: 140 | l += _token_count(text) 141 | return l 142 | 143 | 144 | def set_token_count_func(func): 145 | global _TOKEN_COUNT_FUNC 146 | _TOKEN_COUNT_FUNC = func 147 | 148 | 149 | def get_token_count_func(): 150 | global _TOKEN_COUNT_FUNC 151 | return _TOKEN_COUNT_FUNC 152 | 153 | 154 | def get_llm(model_type: str): 155 | return LLM_MODELS[model_type] 156 | 157 | 158 | def set_llms( 159 | suggest_model: Optional[Callable] = None, 160 | knowledge_model: Optional[Callable] = None, 161 | embedding_model: Optional[Callable] = None, 162 | ): 163 | global LLM_MODELS 164 | if suggest_model is not None: 165 | LLM_MODELS["suggest"] = suggest_model 166 | if knowledge_model is not None: 167 | LLM_MODELS["knowledge"] = knowledge_model 168 | if embedding_model is not None: 169 | LLM_MODELS["embedding"] = embedding_model 170 | 171 | 172 | def escape(text: str) -> str: 173 | return re.sub("(? bool: 15 | if sys.platform == "win32": 16 | # FIXME: Assuming jupyter-lab on Windows 17 | return True 18 | else: 19 | # https://stackoverflow.com/q/57173235/6837658 20 | import psutil 21 | 22 | parent = psutil.Process().parent() 23 | return "jupyter-lab" in parent.name() 24 | 25 | 26 | def insert_cell_below(code: str, metadata: Any = None) -> None: 27 | if is_jupyter_lab_environ(): 28 | try: 29 | input( 30 | json.dumps( 31 | {"command": "insert_cell_below", "code": code, "metadata": metadata} 32 | ) 33 | ) 34 | except EOFError: 35 | # e.g., invoked from a widget callback. It will run in the log console. 
36 | from ipylab import JupyterFrontEnd 37 | 38 | app = JupyterFrontEnd() 39 | app.commands.execute("coml:insert_cell_below", {"code": code, "metadata": metadata}) # type: ignore 40 | else: 41 | encoded_code = base64.b64encode(code.encode()).decode() 42 | encoded_metadata = base64.b64encode(json.dumps(metadata).encode()).decode() 43 | display( 44 | Javascript( 45 | f""" 46 | const cell = IPython.notebook.insert_cell_below('code'); 47 | cell.set_text(atob("{encoded_code}")); 48 | cell.metadata.coml = JSON.parse(atob("{encoded_metadata}")); 49 | cell.focus_cell(); 50 | cell.focus_editor(); 51 | """ 52 | ) 53 | ) 54 | 55 | 56 | def run_code_in_next_cell(python_code: str, metadata: Any = None) -> None: 57 | if is_jupyter_lab_environ(): 58 | try: 59 | input( 60 | json.dumps( 61 | { 62 | "command": "insert_and_execute_cell_below", 63 | "code": python_code, 64 | "metadata": metadata, 65 | } 66 | ) 67 | ) 68 | except EOFError: 69 | # e.g., invoked from a widget callback 70 | from ipylab import JupyterFrontEnd 71 | 72 | app = JupyterFrontEnd() 73 | app.commands.execute("coml:insert_and_execute_cell_below", {"code": python_code, "metadata": metadata}) # type: ignore 74 | else: 75 | encoded_code = base64.b64encode(python_code.encode()).decode() 76 | encoded_metadata = base64.b64encode(json.dumps(metadata).encode()).decode() 77 | display( 78 | Javascript( 79 | f""" 80 | const cell = IPython.notebook.insert_cell_below('code'); 81 | cell.set_text(atob("{encoded_code}")); 82 | cell.metadata.coml = JSON.parse(atob("{encoded_metadata}")); 83 | cell.focus_cell(); 84 | cell.execute(); 85 | """ 86 | ) 87 | ) 88 | 89 | 90 | def update_running_cell_metadata(metadata: Any) -> None: 91 | if is_jupyter_lab_environ(): 92 | input( 93 | json.dumps( 94 | {"command": "update_running_cell_metadata", "metadata": metadata} 95 | ) 96 | ) 97 | else: 98 | encoded_metadata = base64.b64encode(json.dumps(metadata).encode()).decode() 99 | display( 100 | Javascript( 101 | """ 102 | const cell = comlGetCurrentCell(); 103 | cell.metadata.coml = Object.assign(cell.metadata.coml || {}, JSON.parse(atob(\"""" 104 | + encoded_metadata 105 | + """\"))); 106 | """ 107 | ) 108 | ) 109 | 110 | 111 | def get_ipython_history(ipython: InteractiveShell) -> list[str]: 112 | codes = [] 113 | for code in ipython.user_ns["In"]: 114 | if not code: 115 | continue 116 | if code.startswith("get_ipython().run_cell_magic('comlrun',"): 117 | # Whitelist 118 | code_match = re.match( 119 | r"get_ipython\(\).run_cell_magic\('comlrun', '', (.*)\)", code 120 | ) 121 | if code_match is not None: 122 | code = eval(code_match.group(1)) 123 | if code.startswith("get_ipython().run"): 124 | continue 125 | codes.append(code) 126 | return codes 127 | 128 | 129 | def get_running_cell() -> dict[str, Any] | None: 130 | """See `get_last_cell` for the output format.""" 131 | return json.loads(input(json.dumps({"command": "running_cell"}))) 132 | 133 | 134 | def get_last_cell() -> dict[str, Any] | None: 135 | """The implementation is in nbclassic_init.js. This is a *hacked* RPC channel. 
136 | 137 | Example output: 138 | 139 | { 140 | "metadata":{ 141 | "coml":{ 142 | "variables":{}, 143 | "codes":[], 144 | "request":"Import the dataset from this [address](https://raw.githubusercontent.com/datasets/investor-flow-of-funds-us/master/data/weekly.csv).\nAssign it to a variable called flow", 145 | "answer":"```python\nimport pandas as pd\n\nurl = 'https://raw.githubusercontent.com/datasets/investor-flow-of-funds-us/master/data/weekly.csv'\nflow = pd.read_csv(url)\n```" 146 | }, 147 | "trusted":true 148 | }, 149 | "cell_type":"code", 150 | "source":"import pandas as pd\n\nurl = 'https://raw.githubusercontent.com/datasets/investor-flow-of-funds-us/master/data/weekly.csv'\nflow = pd.read_csv(url)", 151 | "execution_count":3, 152 | "outputs":[ 153 | { 154 | "output_type":"stream", 155 | "text":"123\n", 156 | "name":"stdout" 157 | }, 158 | { 159 | "output_type":"execute_result", 160 | "execution_count":4, 161 | "data":{ 162 | "text/plain":"456" 163 | } 164 | }, 165 | { 166 | "output_type":"error", 167 | "ename":"NameError", 168 | "evalue":"123", 169 | "traceback":[ 170 | "-------", 171 | "NameError Traceback (most recent call last)" 172 | ] 173 | } 174 | ] 175 | } 176 | """ 177 | return json.loads(input(json.dumps({"command": "last_cell"}))) 178 | 179 | 180 | def parse_cell_outputs(outputs: list[dict]) -> tuple[str | None, str | None]: 181 | error = output = None 182 | for cell_out in outputs: 183 | if cell_out["output_type"] == "error": 184 | if "traceback" in cell_out: 185 | traceback = "\n".join(cell_out["traceback"]) 186 | # Remove color characters 187 | error = re.sub(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))", "", traceback) 188 | elif "ename" in cell_out: 189 | error = "Error name: " + cell_out["ename"] + "\n" 190 | if "evalue" in cell_out: 191 | error += "Error value: " + cell_out["evalue"] + "\n" 192 | break 193 | elif cell_out["output_type"] == "stream": 194 | if output is None: 195 | output = "" 196 | output += cell_out["text"] 197 | elif cell_out["output_type"] == "execute_result" and cell_out["data"]: 198 | if output is None: 199 | output = "" 200 | if "text/plain" in cell_out["data"]: 201 | output += cell_out["data"]["text/plain"] + "\n" 202 | else: 203 | output += list(cell_out["data"].values())[0] + "\n" 204 | elif cell_out["output_type"] == "display_data" and cell_out["data"]: 205 | if output is None: 206 | output = "" 207 | if "image/svg+xml" in cell_out["data"]: 208 | output = "" 209 | output += cell_out["data"]["image/svg+xml"] 210 | return error, output 211 | -------------------------------------------------------------------------------- /coml/js/nbclassic_init.js: -------------------------------------------------------------------------------- 1 | window.comlGetRunningCellIndex = function () { 2 | const runningCells = $(".running"); 3 | if (runningCells.length === 0) { 4 | console.warn("No running cell"); 5 | return null; 6 | } 7 | const cellIndex = Jupyter.notebook.get_cell_elements().index(runningCells[0]); 8 | if (cellIndex < 0) { 9 | console.error("Running cell not found in cell list."); 10 | return null; 11 | } 12 | return cellIndex; 13 | } 14 | 15 | window.comlGetCurrentCell = function () { 16 | const cell = comlGetRunningCellIndex(); 17 | if (cell === null) { 18 | return null; 19 | } 20 | return IPython.notebook.get_cell(cell); 21 | } 22 | 23 | window.comlGetLastCell = function () { 24 | const cellIndex = comlGetRunningCellIndex(); 25 | if (cellIndex === null) { 26 | return null; 27 | } 28 | return IPython.notebook.get_cell(comlGetRunningCellIndex() - 1); 
29 | } 30 | 31 | if (window.IPython && IPython.CodeCell) { 32 | window.IPythonAvailable = true; 33 | IPython.CodeCell.prototype.native_handle_input_request = IPython.CodeCell.prototype.native_handle_input_request || IPython.CodeCell.prototype._handle_input_request; 34 | IPython.CodeCell.prototype._handle_input_request = function (msg) { 35 | try { 36 | // only apply the hack if the command is valid JSON 37 | const command = JSON.parse(msg.content.prompt); 38 | const kernel = IPython.notebook.kernel; 39 | if (command["command"] === "last_cell") { 40 | kernel.send_input_reply(JSON.stringify(comlGetLastCell().toJSON())); 41 | } else if (command["command"] === "running_cell") { 42 | kernel.send_input_reply(JSON.stringify(comlGetCurrentCell().toJSON())); 43 | } else { 44 | console.log("Not a command", msg); 45 | this.native_handle_input_request(msg); 46 | } 47 | } catch(err) { 48 | console.log("Not a command", msg, err); 49 | this.native_handle_input_request(msg); 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /coml/linter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | from io import StringIO 4 | from typing import Literal, Tuple 5 | 6 | import pylint 7 | from pylint.lint import Run as PylintRun 8 | from pylint.reporters import JSONReporter 9 | 10 | LinterResult = Literal["error", "warning", "info", "ok"] 11 | 12 | 13 | def lint(previous_code: str, new_code: str) -> Tuple[LinterResult, str]: 14 | # https://stackoverflow.com/q/75507725/6837658 15 | pylint_options = [ 16 | "--disable=C0103", # Invalid name 17 | "--disable=C0114", # Missing module docstring 18 | "--disable=C0304", # Final new line missing 19 | ] 20 | previous_lines = previous_code.count("\n") + 1 21 | with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: 22 | f.write(previous_code + "\n" + new_code) 23 | f.flush() 24 | f.seek(0) 25 | 26 | reporter_buffer = StringIO() 27 | 28 | if pylint.__version__ < "3": 29 | results = PylintRun( 30 | [f.name] + pylint_options, 31 | reporter=JSONReporter(reporter_buffer), 32 | do_exit=False, 33 | ) 34 | else: 35 | results = PylintRun( 36 | [f.name] + pylint_options, 37 | reporter=JSONReporter(reporter_buffer), 38 | exit=False, 39 | ) 40 | # Score is here. 41 | # score = results.linter.stats.global_note 42 | file_results = json.loads(reporter_buffer.getvalue()) 43 | file_results = [e for e in file_results if e["line"] > previous_lines] 44 | 45 | details = [] 46 | for error in file_results: 47 | line = f"{error['line'] - previous_lines}:{error['column']}: {error['message-id']}: {error['message']}" 48 | details.append(line) 49 | details_joined = "\n".join(details) 50 | 51 | if any(e["type"] in ("fatal", "error") for e in file_results): 52 | return "error", details_joined 53 | elif any(e["type"] == "warning" for e in file_results): 54 | return "warning", details_joined 55 | elif file_results: 56 | return "info", details_joined 57 | else: 58 | return "ok", "No issues found." 
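# Usage sketch (illustrative only; the code strings below are placeholders):
#
#     previous = "import pandas as pd\ndf = pd.DataFrame({'a': [1, 2]})"
#     new = "print(undefined_name)"
#     status, details = lint(previous, new)
#     # `status` is one of "error", "warning", "info" or "ok"; `details` lists the
#     # pylint messages reported for the newly added lines only.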
59 | -------------------------------------------------------------------------------- /coml/magics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any 5 | 6 | import ipywidgets as widgets 7 | import markdown 8 | from IPython.core.magic import ( 9 | Magics, 10 | cell_magic, 11 | line_cell_magic, 12 | line_magic, 13 | magics_class, 14 | no_var_expand, 15 | ) 16 | from IPython.display import HTML, Code, clear_output, display 17 | from langchain_openai import ChatOpenAI 18 | 19 | from .core import CoMLAgent 20 | from .ipython_utils import ( 21 | get_ipython_history, 22 | get_last_cell, 23 | insert_cell_below, 24 | parse_cell_outputs, 25 | run_code_in_next_cell, 26 | update_running_cell_metadata, 27 | ) 28 | from .linter import lint 29 | from .prompt_utils import ( 30 | FixContext, 31 | GenerateContext, 32 | InteractionIncomplete, 33 | describe_variable, 34 | filter_variables, 35 | ) 36 | 37 | VERIFY_STYLE = """ 38 | 67 | """ 68 | 69 | VERIFY_STATUS_ICON = { 70 | "error": "❌", 71 | "warning": "⚠️", 72 | "info": "ℹ️", 73 | "ok": "✅", 74 | None: "❔", 75 | True: "✅", 76 | False: "❌", 77 | } 78 | 79 | 80 | @magics_class 81 | class CoMLMagics(Magics): 82 | def __init__(self, *args, **kwargs): 83 | super().__init__(*args, **kwargs) 84 | 85 | import dotenv 86 | 87 | dotenv.load_dotenv() 88 | llm = ChatOpenAI(temperature=0.0, model="gpt-3.5-turbo-16k") 89 | self.agent = CoMLAgent(llm) 90 | 91 | def _get_variable_context(self) -> dict[str, Any]: 92 | assert self.shell is not None 93 | return { 94 | key: describe_variable(value) 95 | for key, value in filter_variables(self.shell.user_ns).items() 96 | } 97 | 98 | def _get_code_context(self) -> list[str]: 99 | assert self.shell is not None 100 | return get_ipython_history(self.shell) 101 | 102 | def _post_generation( 103 | self, code: str, context: GenerateContext | FixContext 104 | ) -> None: 105 | def run_button_on_click(b): 106 | run_code_in_next_cell("%%comlrun\n" + code, {"action": "run", **context}) 107 | 108 | def edit_button_on_click(b): 109 | insert_cell_below(code, context) 110 | 111 | def explain_button_on_click(b): 112 | run_code_in_next_cell("%%comlexplain\n" + code) 113 | 114 | def verify_button_on_click(b): 115 | run_code_in_next_cell("%comlverify") 116 | 117 | run_button = widgets.Button( 118 | description="👍 Run it!", layout=widgets.Layout(width="24.5%") 119 | ) 120 | edit_button = widgets.Button( 121 | description="🤔 Let me edit.", layout=widgets.Layout(width="24.5%") 122 | ) 123 | explain_button = widgets.Button( 124 | description="🧐 Explain it.", layout=widgets.Layout(width="24.5%") 125 | ) 126 | verify_button = widgets.Button( 127 | description="🔍 Check yourself.", layout=widgets.Layout(width="24.5%") 128 | ) 129 | run_button.on_click(run_button_on_click) 130 | edit_button.on_click(edit_button_on_click) 131 | explain_button.on_click(explain_button_on_click) 132 | verify_button.on_click(verify_button_on_click) 133 | 134 | update_running_cell_metadata({"action": "generate", **context}) 135 | 136 | combined = widgets.HBox( 137 | [run_button, edit_button, explain_button, verify_button] 138 | ) 139 | display(Code(code, language="python")) 140 | display(combined) 141 | 142 | def _fix_context_from_cell(self, source: str, **kwargs: Any) -> FixContext: 143 | return FixContext( 144 | variables=self._get_variable_context(), 145 | codes=self._get_code_context(), 146 | request=None, 147 | first_attempt=source, 148 | 
interactions=[InteractionIncomplete(**kwargs)], 149 | ) 150 | 151 | @no_var_expand 152 | @line_cell_magic 153 | def coml(self, line, cell=None): 154 | request: str = line 155 | if cell is not None: 156 | request += "\n" + cell 157 | generate_context = self.agent.generate_code( 158 | request.strip(), self._get_variable_context(), self._get_code_context() 159 | ) 160 | return self._post_generation(generate_context["answer"], generate_context) 161 | 162 | @no_var_expand 163 | @line_magic 164 | def comlset(self, line): 165 | key, value = line.lower().strip().split() 166 | if key == "debug": 167 | from . import core 168 | 169 | if value == "on": 170 | core._debug_mode = True 171 | elif value == "off": 172 | core._debug_mode = False 173 | else: 174 | raise ValueError("Debug mode must be either on or off.") 175 | else: 176 | raise ValueError("Unknown setting.") 177 | 178 | @no_var_expand 179 | @line_magic 180 | def comlinspire(self, line): 181 | if line: 182 | warnings.warn(r"The argument of %comlinspire is ignored.") 183 | suggestions = self.agent.suggest(self._get_code_context()) 184 | 185 | def run_button_on_click(b): 186 | run_code_in_next_cell(r"%coml " + b.description) 187 | 188 | buttons = [ 189 | widgets.Button(description=s, layout=widgets.Layout(width="100%")) 190 | for s in suggestions 191 | ] 192 | for button in buttons: 193 | button.on_click(run_button_on_click) 194 | display(widgets.VBox(buttons)) 195 | 196 | @no_var_expand 197 | @line_magic 198 | def comlfix(self, line): 199 | hint: str | None = line.strip() 200 | if not hint: 201 | hint = None 202 | 203 | target_cell = get_last_cell() 204 | if target_cell is None: 205 | warnings.warn("No cell to fix!") 206 | return 207 | if target_cell["cell_type"] != "code": 208 | warnings.warn("Only code cells can be fixed.") 209 | return 210 | 211 | error, output = parse_cell_outputs(target_cell["outputs"]) 212 | if "coml" in target_cell["metadata"]: 213 | context = target_cell["metadata"]["coml"] 214 | else: 215 | # Last cell is created by user. 216 | print("This cell is not created by coml. 
Still trying to fix it though.") 217 | context = FixContext( 218 | variables=self._get_variable_context(), 219 | codes=self._get_code_context(), 220 | request=None, 221 | first_attempt=target_cell["source"], 222 | interactions=[ 223 | InteractionIncomplete(error=error, output=output, hint=hint) 224 | ], 225 | ) 226 | 227 | fix_context = self.agent.fix_code(error, output, hint, context) 228 | if fix_context is None: 229 | return 230 | assert "code" in fix_context["interactions"][-1] 231 | return self._post_generation( 232 | fix_context["interactions"][-1]["code"], fix_context 233 | ) 234 | 235 | @no_var_expand 236 | @cell_magic 237 | def comlexplain(self, line, cell): 238 | if line: 239 | warnings.warn(r"The argument of %%comlexplain is ignored.") 240 | explanation = self.agent.explain(cell) 241 | display(Code(explanation, language="markdown")) 242 | 243 | @no_var_expand 244 | @line_magic 245 | def comlverify(self, line): 246 | target_cell = get_last_cell() 247 | if target_cell is None: 248 | raise RuntimeError("No cell to verify!") 249 | if target_cell["cell_type"] != "code": 250 | raise RuntimeError("Only code cells can be verified.") 251 | if "coml" not in target_cell["metadata"]: 252 | raise RuntimeError("This cell is not created by coml.") 253 | 254 | context = target_cell["metadata"]["coml"] 255 | if context.get("interactions"): 256 | code = context["interactions"][-1]["code"] 257 | else: 258 | code = context["answer"] 259 | 260 | error = output = None 261 | generated_vis = False 262 | if context.get("action") == "run": 263 | error, output = parse_cell_outputs(target_cell["outputs"]) 264 | generated_vis = output and "" in output 265 | 266 | def display_statuses(statuses): 267 | clear_output(wait=True) 268 | html = VERIFY_STYLE + "\n" 269 | display_names = { 270 | "lint": "PyLint", 271 | "rubberduck": "Rubberduck", 272 | } 273 | if generated_vis: 274 | display_names["vis"] = "Visualization check" 275 | elif error or output: 276 | display_names["sanity"] = "Output sanity check" 277 | 278 | loading = "" 279 | message_template = "
{}: {}\n{}
" 280 | for name in display_names: 281 | detail_message = "Still loading..." 282 | if name in statuses: 283 | detail_message = markdown.markdown( 284 | statuses[name]["details"], extensions=["nl2br"] 285 | ) 286 | html += message_template.format( 287 | display_names[name], 288 | ( 289 | loading 290 | if name not in statuses 291 | else VERIFY_STATUS_ICON[statuses[name]["result"]] 292 | ), 293 | detail_message, 294 | ) 295 | 296 | display(HTML(html)) 297 | 298 | result = {} 299 | display_statuses(result) 300 | 301 | lint_result, lint_details = lint("\n".join(self._get_code_context()), code) 302 | result["lint"] = { 303 | "result": lint_result, 304 | "details": lint_details, 305 | } 306 | display_statuses(result) 307 | 308 | rubberduck_result, rubberduck_details = self.agent.static_check(code, context) 309 | result["rubberduck"] = { 310 | "result": rubberduck_result, 311 | "details": rubberduck_details, 312 | } 313 | display_statuses(result) 314 | 315 | if generated_vis: 316 | # verify generated visualization 317 | # Roughly judge the source of the visualization 318 | if "plt.show()" in code: 319 | vis_framework = "matplotlib" 320 | ( 321 | visualization_check_result, 322 | visualization_check_details, 323 | ) = self.agent.visualization_check( 324 | context["request"], 325 | "\n".join(context["codes"]), 326 | output.replace("", ""), 327 | context["variables"], 328 | vis_framework, 329 | ) 330 | details = "" 331 | for detail in visualization_check_details: 332 | details += VERIFY_STATUS_ICON[detail[0]] + " " + detail[1] + "\n" 333 | result["vis"] = { 334 | "result": visualization_check_result, 335 | "details": details, 336 | } 337 | display_statuses(result) 338 | 339 | elif error or output: 340 | sanity_result, sanity_details = self.agent.output_sanity_check( 341 | code, context, error, output 342 | ) 343 | result["sanity"] = { 344 | "result": sanity_result, 345 | "details": sanity_details, 346 | } 347 | display_statuses(result) 348 | 349 | @no_var_expand 350 | @cell_magic 351 | def comlrun(self, line, cell): 352 | if line: 353 | warnings.warn(r"The argument of %comlrun is ignored.") 354 | assert self.shell is not None 355 | output = None 356 | try: 357 | # show as svg 358 | if "plt.show()" in cell: 359 | cell = cell.replace( 360 | "plt.show()", 361 | "show_svg(plt)", 362 | ) 363 | import inspect 364 | 365 | from .vis_utils import show_svg 366 | 367 | self.shell.run_cell(inspect.getsource(show_svg)) 368 | output = self.shell.run_cell(cell) 369 | return output.result 370 | finally: 371 | 372 | def like_button_on_click(b): 373 | print("Thanks for your feedback! 
🤗") 374 | 375 | def fix_button_on_click(b): 376 | run_code_in_next_cell(r"%comlfix") 377 | 378 | def fix_with_comment_button_on_click(b): 379 | insert_cell_below(r"%comlfix ") 380 | 381 | def verify_button_on_click(b): 382 | run_code_in_next_cell("%comlverify") 383 | 384 | like_button = widgets.Button( 385 | description="🤗 Looks good!", layout=widgets.Layout(width="24.5%") 386 | ) 387 | retry_button = widgets.Button( 388 | description="🤬 Try again!", layout=widgets.Layout(width="24.5%") 389 | ) 390 | comment_button = widgets.Button( 391 | description="🤯 I'll show you what's wrong.", 392 | layout=widgets.Layout(width="24.5%"), 393 | ) 394 | verify_button = widgets.Button( 395 | description="🔍 Check yourself.", layout=widgets.Layout(width="24.5%") 396 | ) 397 | like_button.on_click(like_button_on_click) 398 | retry_button.on_click(fix_button_on_click) 399 | comment_button.on_click(fix_with_comment_button_on_click) 400 | verify_button.on_click(verify_button_on_click) 401 | 402 | combined = widgets.HBox( 403 | [like_button, retry_button, comment_button, verify_button] 404 | ) 405 | display(combined) 406 | -------------------------------------------------------------------------------- /coml/prompts/dataset/coffee_shop/member.csv: -------------------------------------------------------------------------------- 1 | Member_ID,Name,Membership_card,Age,Time_of_purchase,Level_of_membership,Address 2 | 1,"Ashby, Lazale",Black,29,18,5,Hartford 3 | 2,"Breton, Robert",White,67,41,4,Waterbury 4 | 3,"Campbell, Jessie",Black,34,20,6,Hartford 5 | 4,"Cobb, Sedrick",Black,51,27,2,Waterbury 6 | 5,"Hayes, Steven",White,50,44,3,Cheshire 7 | 6,"Komisarjevsky, Joshua",White,33,26,2,Cheshire 8 | 7,"Peeler, Russell",Black,42,26,6,Bridgeport 9 | 8,"Reynolds, Richard",Black,45,24,1,Waterbury 10 | 9,"Rizzo, Todd",White,35,18,4,Waterbury 11 | 10,"Webb, Daniel",Black,51,27,22,Hartford 12 | -------------------------------------------------------------------------------- /coml/prompts/dataset/game_injury/game.csv: -------------------------------------------------------------------------------- 1 | stadium_id,id,Season,Date,Home_team,Away_team,Score,Competition 2 | 1,1,2007,18 May 2007,Quruvchi,Pakhtakor,1–1,League 3 | 2,2,2007,22 September 2007,Pakhtakor,Quruvchi,0–0,League 4 | 3,3,2007,9 December 2007,Pakhtakor,Quruvchi,0–0 (7:6),Cup 5 | 4,4,2008,10 July 2008,Pakhtakor,Quruvchi,1–1,League 6 | 5,5,2008,16 August 2008,Bunyodkor,Pakhtakor,1–1,League 7 | 6,6,2008,31 October 2008,Bunyodkor,Pakhtakor,3–1,Cup 8 | 7,7,2009,12 July 2009,Bunyodkor,Pakhtakor,2–1,League 9 | 8,8,2009,14 October 2009,Pakhtakor,Bunyodkor,0–0,League 10 | 9,9,2009,8 August 2009,Pakhtakor,Bunyodkor,1–0,Cup 11 | 10,10,2010,14 March 2010,Bunyodkor,Pakhtakor,2–1,League 12 | 10,11,2010,31 October 2010,Pakhtakor,Bunyodkor,0–0,League 13 | 10,12,2011,7 July 2011,Pakhtakor,Bunyodkor,0–0,League 14 | 1,13,2011,21 August 2011,Bunyodkor,Pakhtakor,2–1,League 15 | 2,14,2012,11 March 2012,Bunyodkor,Pakhtakor,–,Supercup 16 | 3,15,2012,26 June 2012,Bunyodkor,Pakhtakor,2–0,League 17 | 4,16,2012,9 August 2012,Pakhtakor,Bunyodkor,1–1,League 18 | 5,17,2012,22 August 2012,Bunyodkor,Pakhtakor,1–1,Cup 19 | 11,18,2012,25 November 2012,Pakhtakor,Bunyodkor,1–3,Cup 20 | 12,19,2013,30 June 2013,Pakhtakor,Bunyodkor,0–2,League 21 | 7,20,2013,8 August 2013,Bunyodkor,Pakhtakor,1–2,League 22 | -------------------------------------------------------------------------------- /coml/prompts/dataset/pilot_record/pilot.csv: 
-------------------------------------------------------------------------------- 1 | Pilot_ID,Pilot_name,Rank,Age,Nationality,Position,Join_Year,Team 2 | 1,Patrick O'Bryant,13,33,United States,Center Team,2009,Bradley 3 | 2,Jermaine O'Neal,6,40,United States,Forward-Center Team,2008,Eau Claire High School 4 | 3,Dan O'Sullivan,45,37,United States,Center Team,1999,Fordham 5 | 4,Charles Oakley,34,22,United Kindom,Forward Team,2001,Virginia Union 6 | 5,Hakeem Olajuwon,34,32,Nigeria,Center Team,2010,Houston 7 | -------------------------------------------------------------------------------- /coml/prompts/dataset/scientist_1/AssignedTo.csv: -------------------------------------------------------------------------------- 1 | Scientist,Project 2 | 123234877,AeH1 3 | 152934485,AeH3 4 | 222364883,Ast3 5 | 326587417,Ast3 6 | 332154719,Bte1 7 | 546523478,Che1 8 | 631231482,Ast3 9 | 654873219,Che1 10 | 745685214,AeH3 11 | 845657245,Ast1 12 | 845657246,Ast2 13 | 332569843,AeH4 14 | -------------------------------------------------------------------------------- /coml/prompts/dataset/scientist_1/Projects.csv: -------------------------------------------------------------------------------- 1 | Code,Name,Hours 2 | AeH1,Winds: Studying Bernoullis Principle,156 3 | AeH2,Aerodynamics and Bridge Design,189 4 | AeH3,Aerodynamics and Gas Mileage,256 5 | AeH4,Aerodynamics and Ice Hockey,789 6 | AeH5,Aerodynamics of a Football,98 7 | AeH6,Aerodynamics of Air Hockey,89 8 | Ast1,A Matter of Time,112 9 | Ast2,A Puzzling Parallax,299 10 | Ast3,Build Your Own Telescope,6546 11 | Bte1,Juicy: Extracting Apple Juice with Pectinase,321 12 | Bte2,A Magnetic Primer Designer,9684 13 | Bte3,Bacterial Transformation Efficiency,321 14 | Che1,A Silver-Cleaning Battery,545 15 | Che2,A Soluble Separation Solution,778 16 | -------------------------------------------------------------------------------- /coml/prompts/fix_fewshots.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": { 4 | "url": "\"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"", 5 | "chipo": "pandas.DataFrame(shape=(4622, 6), columns=[\"order_id\", \"quantity\", \"item_name\", \"choice_description\", \"item_price\", ...])\n order_id quantity item_name choice_description item_price price_per_item\n 0 1 1 Chips an... NaN 2.39 2.39 \n 1 1 1 Izze [Clement... 3.39 3.39 \n ... ... ... ... ... ... ... \n 4620 1834 1 Chicken ... [Fresh T... 8.75 8.75 \n 4621 1834 1 Chicken ... [Fresh T... 8.75 8.75", 6 | "prices": "[2.39, 3.39, 3.39, 2.39, 16.98, ...]", 7 | "chipo_filtered": "pandas.DataFrame(shape=(1949, 5), columns=[\"order_id\", \"quantity\", \"item_name\", \"choice_description\", \"item_price\"])\n order_id quantity item_name choice_description item_price\n 0 1 1 Chips an... NaN 2.39\n 1 1 1 Izze [Clement... 3.39\n ... ... ... ... ... ...\n 4611 1830 1 Veggie B... [Tomatil... 11.25\n 4612 1831 1 Carnitas... [Fresh T... 
9.25" 8 | }, 9 | "codes": [ 10 | "import pandas as pd", 11 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv).\n# \n# Assign it to a variable called chipo.\nurl = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv'\n\nchipo = pd.read_csv(url, sep = '\\t')", 12 | "# How many products cost more than $10.00?\n# clean the item_price column and transform it in a float\nprices = [float(value[1 : -1]) for value in chipo.item_price]\n\n# reassign the column with the cleaned prices\nchipo.item_price = prices\n\n# delete the duplicates in item_name and quantity\nchipo_filtered = chipo.drop_duplicates(['item_name','quantity','choice_description'])\n\n# calculate the price per item and query the ones that are greater than $10.00\nchipo['price_per_item'] = chipo.item_price / chipo.quantity\nchipo.query('price_per_item > 10').item_name.nunique()" 13 | ], 14 | "request": "What is the price of each item? \nReturn a data frame with only two columns item_name and item_price, sorted from the most to less expensive", 15 | "first_attempt": "item_prices = chipo.groupby('item_name')['item_price'].max()\nitem_prices_df = pd.DataFrame({'item_name': item_prices.index, 'item_price': item_prices.values})\nitem_prices_df.sort_values(by='item_price', ascending=False)", 16 | "interactions": [ 17 | { 18 | "error": null, 19 | "output": " item_name item_price\n24 Chips and Fresh Tomato Salsa 44.25\n11 Carnitas Bowl 35.25\n.. ... ...\n26 Chips and Mild Fresh Tomato Salsa 3.00\n28 Chips and Roasted Chili-Corn Salsa 2.39", 20 | "hint": "The item_price column in chipo refers to the total price of order.", 21 | "explanation": "1. `item_prices = chipo.groupby('item_name')['item_price'].max()`: This calculates the highest price for each unique item using the 'item_name' column as groups.\n2. `item_prices_df = pd.DataFrame(...)`: This creates a DataFrame with 'item_name' and 'item_price' columns, using the calculated maximum prices.\n3. `item_prices_df.sort_values(by='item_price', ascending=False)`: This sorts the DataFrame by 'item_price' in descending order, showing items with the highest prices at the top.", 22 | "observation": "The code didn't filter out orders with quantity > 1. So the prices are not correct.", 23 | "code": "# Filter out orders with quantity > 1\nchipo_single_item = chipo[chipo['quantity'] == 1]\n# Calculate the total price for each item\nitem_prices = chipo_single_item.groupby('item_name')['item_price'].mean()\n# Create a DataFrame with item_name and item_price columns\nitem_prices_df = pd.DataFrame({'item_name': item_prices.index, 'item_price': item_prices.values})\n# Sort the DataFrame by item_price in descending order\nitem_prices_df.sort_values(by='item_price', ascending=False)" 24 | }, 25 | { 26 | "error": null, 27 | "output": " item_name item_price\n15 Carnitas Salad Bowl 11.056667\n42 Steak Salad Bowl 10.964074\n.. ... ...\n10 Canned Soft Drink 1.250000\n9 Canned Soda 1.090000", 28 | "hint": "Do not take the mean of item_price. Use the first price for each item instead.", 29 | "explanation": "1. `chipo_single_item = chipo[chipo['quantity'] == 1]`: Filters out orders with a quantity greater than 1, creating a new DataFrame.\n2. `item_prices = chipo_single_item.groupby('item_name')['item_price'].mean()`: Calculates the average price per unique item.\n3. `item_prices_df = pd.DataFrame(...)`: Creates a DataFrame with item names and their average prices.\n4. 
`item_prices_df.sort_values(by='item_price', ascending=False)`: Sorts the DataFrame by average item price in descending order, presenting the highest priced items first.", 30 | "observation": "The code computes the average of item_price, but the user wants the first price.", 31 | "code": "# delete the duplicates in item_name and quantity\nchipo_filtered = chipo.drop_duplicates(['item_name','quantity'])\n\n# select only the products with quantity equals to 1\nchipo_one_prod = chipo_filtered[chipo_filtered.quantity == 1]\n\n# select only the item_name and item_price columns\nchipo_one_prod[['item_name', 'item_price']].sort_values(by = \"item_price\", ascending = False)" 32 | } 33 | ] 34 | }, 35 | { 36 | "variables": { 37 | "users": "pandas.DataFrame(shape=(943, 4), columns=[\"age\", \"gender\", \"occupation\", \"zip_code\"])\n age gender occupation zip_code\n user_id \n 1 24 M technician 85711\n 2 53 F other 94043\n ... ... ... ... ...\n 942 48 F librarian 78209\n 943 22 M student 77841" 38 | }, 39 | "codes": [ 40 | "import pandas as pd", 41 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/u.user).\n# \n# Assign it to a variable called users.\nusers = pd.read_table('https://raw.githubusercontent.com/justmarkham/DAT8/master/data/u.user', \n sep='|', index_col='user_id')", 42 | "# Discover what is the mean age per occupation\nusers.groupby('occupation').age.mean()" 43 | ], 44 | "request": "Discover the Male ratio per occupation and sort it from the most to the least", 45 | "answer": "# create a function\ndef gender_to_numeric(x):\n if x == 'M':\n return 1\n if x == 'F':\n return 0\n\n# apply the function to the gender column and create a new column\nusers['gender_n'] = users['gender'].apply(gender_to_numeric)\n\n\na = users.groupby('occupation').gender_n.sum() / users.occupation.value_counts() * 100 \n\n# sort to the most male \na.sort_values(ascending = False)", 46 | "first_attempt": "(users[users['gender'] == 'M'].groupby('occupation').gender.count() / users.groupby('occupation').gender.count()).sort_values(ascending=False)", 47 | "interactions": [ 48 | { 49 | "error": null, 50 | "output": "occupation\ndoctor 1.000000\nengineer 0.970149\n... ...\nhealthcare 0.312500\nhomemaker 0.142857\nName: gender, dtype: float64", 51 | "hint": "Please show the ratio in percentage.", 52 | "explanation": "This code calculates the proportion of male users for each occupation in the 'users' DataFrame, then sorts and presents these proportions in descending order.", 53 | "observation": "The code is almost correct, but the user wants the ratio in percentage.", 54 | "code": "male_ratio_per_occupation = (users[users['gender'] == 'M'].groupby('occupation').gender.count() / users.groupby('occupation').gender.count() * 100).sort_values(ascending=False)" 55 | } 56 | ] 57 | }, 58 | { 59 | "variables": { 60 | "flow": "pandas.DataFrame(shape=(44, 8), columns=[\"Total Equity\", \"Domestic Equity\", \"World Equity\", \"Hybrid\", \"Total Bond\", ...])\n Total Equity Domestic Equity World Equity ... Taxable Bond Municipal Bond Total\n Date ... \n 2012-12-05 -7426 -6060 -1367 ... 4210 1107 -2183\n 2012-12-12 -8783 -7520 -1263 ... 1598 219 -6842\n ... ... ... ... ... ... ... ...\n 2015-04-01 -1527 -3307 1780 ... -1392 96 -2103\n 2015-04-08 1906 -1321 3227 ... 
1906 -187 3875" 61 | }, 62 | "codes": [ 63 | "import numpy as np\nimport pandas as pd", 64 | "# Import the dataset from this [address](https://raw.githubusercontent.com/datasets/investor-flow-of-funds-us/master/data/weekly.csv).\n# \n# Assign it to a variable called flow\nflow = pd.read_csv(\n \"https://raw.githubusercontent.com/datasets/investor-flow-of-funds-us/master/data/weekly.csv\",\n)", 65 | "# Set the column Date as the index and set the index to a DatetimeIndex type. Modify the flow dataframe in place.\nflow = flow.set_index('Date')\nflow.index = pd.to_datetime(flow.index)" 66 | ], 67 | "request": "What is the frequency of the dataset?", 68 | "first_attempt": "flow.index.freq", 69 | "interactions": [ 70 | { 71 | "error": null, 72 | "output": null, 73 | "hint": null, 74 | "explanation": "In time series data, the .index.freq in Pandas refers to the frequency of the DateTimeIndex. It helps identify how often data points occur, like daily, weekly, etc. You can use .index.freq to access or set the frequency for the DateTimeIndex in your code. If the data has a regular pattern, this attribute can be quite useful for analysis and manipulation.", 75 | "observation": "The returned value is None, which means the frequency is not set.", 76 | "code": "pd.infer_freq(flow.index)" 77 | }, 78 | { 79 | "error": null, 80 | "output": null, 81 | "hint": null, 82 | "explanation": "The pd.infer_freq() function is used to automatically determine the frequency of a DateTimeIndex based on the time intervals between consecutive dates in the index. It analyzes the time gaps between the dates and tries to identify the most likely frequency that fits the pattern of the data. This function is applied to the DateTimeIndex of the `flow` DataFrame. It attempts to figure out the most appropriate frequency for the time series data based on the pattern of dates.", 83 | "observation": "The returned value is still None. Notice that the gaps between dates are not regular. Some dates are skipped. Therefore the frequency needs to be inferred on a partial index.", 84 | "code": "pd.infer_freq(flow.index[-3:])" 85 | } 86 | ] 87 | }, 88 | { 89 | "variables": { 90 | "url": "\"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"", 91 | "chipo": "pandas.DataFrame(shape=(4622, 5), columns=[\"order_id\", \"quantity\", \"item_name\", \"choice_description\", \"item_price\"])\n order_id quantity item_name choice_description item_price\n 0 1 1 Chips an... NaN $2.39 \n 1 1 1 Izze [Clement... $3.39 \n ... ... ... ... ... ...\n 4620 1834 1 Chicken ... [Fresh T... $8.75 \n 4621 1834 1 Chicken ... [Fresh T... 
$8.75" 92 | }, 93 | "codes": [ 94 | "import pandas as pd\nimport numpy as np", 95 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv).\n# \n# Assign it to a variable called chipo.\nurl = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv'\n \nchipo = pd.read_csv(url, sep = '\\t')", 96 | "# See the first 10 entries\nchipo.head(10)" 97 | ], 98 | "request": "What is the number of observations in the dataset?", 99 | "first_attempt": "import pandas as pd\nurl = \"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"\nchipo = pd.DataFrame(shape=(4622, 5), columns=pd.Index(['order_id', 'quantity', 'item_name', 'choice_description', 'item_price'], dtype='object'))", 100 | "interactions": [ 101 | { 102 | "error": "Traceback (most recent call last):\n File \"\", line 3, in \nTypeError: DataFrame.__init__() got an unexpected keyword argument 'shape'\n", 103 | "output": null, 104 | "hint": "You should use existing DataFrames.", 105 | "explanation": "1. `import pandas as pd`: Imports the Pandas library and assigns it the alias `pd`.\n2. `url = \"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"`: Defines a URL string pointing to a dataset (in TSV format) related to Chipotle orders.\n3. `chipo = pd.DataFrame(shape=(4622, 5), columns=pd.Index(['order_id', 'quantity', 'item_name', 'choice_description', 'item_price'], dtype='object'))`: Creates a new DataFrame named `chipo` using the `pd.DataFrame()` constructor. Defines the column names using `pd.Index()`, which are 'order_id', 'quantity', 'item_name', 'choice_description', and 'item_price'. Sets the data type of the columns as 'object'. The error occurred in the original code because the shape parameter is not a valid parameter for initializing a DataFrame using the pd.DataFrame() constructor.", 106 | "observation": "The code creates a new DataFrame, but the user wants to use the existing DataFrame `chipo`. Also, the code uses the `shape` parameter, which is not a valid parameter.", 107 | "code": "chipo.shape[0] # entries <= 4622 observations" 108 | } 109 | ] 110 | } 111 | ] -------------------------------------------------------------------------------- /coml/prompts/fix_fewshots_v2.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": { 4 | "flowers": "pandas.DataFrame(shape=(5, 2), columns=[\"name\", \"count\"])\n name count\n 0 rose 10\n 1 tulip 20\n 2 daisy 30\n 3 lily 40\n 4 orchid 50" 5 | }, 6 | "codes": [ 7 | "flowers['count'].sum()", 8 | "flowers = flowers.sort_values('count', ascending=True)" 9 | ], 10 | "request": "Plot a bar chart of the number of flowers for each type.", 11 | "first_attempt": "import seaborn as sns\nsns.barplot(x='name', y='count', data=flowers)", 12 | "interactions": [ 13 | { 14 | "error": "Traceback (most recent call last):\n File \"/home/coml/coml.py\", line 36, in seaborn_example\n sns.barplot(x='name', y='count', data=flowers)", 15 | "output": null, 16 | "hint": null, 17 | "explanation": "`sns.barplot(x='name', y='count', data=flowers)`: This line of code is using the seaborn library's barplot function to create a bar chart. 
'name' and 'count' are the column names from the DataFrame 'flowers' that will be used for the x and y axes, respectively.", 18 | "observation": "The seaborn library, represented as `sns`, hasn't been imported, which is why the error is occurring.", 19 | "code": "import seaborn as sns\nsns.barplot(x='name', y='count', data=flowers)" 20 | } 21 | ] 22 | }, 23 | { 24 | "variables": { 25 | "X_train": "numpy.ndarray(shape=(2, 3), dtype=float32)", 26 | "y_train": "numpy.ndarray(shape=(2,), dtype=float32)" 27 | }, 28 | "codes": [ 29 | "from sklearn.linear_model import LinearRegression" 30 | ], 31 | "request": "Fit a linear regression model to the training data.", 32 | "first_attempt": "model = LinearRegression()\nmodel.fit(X_train, y_train)", 33 | "interactions": [ 34 | { 35 | "error": null, 36 | "output": "LinearRegression()", 37 | "hint": null, 38 | "explanation": "1. `model = LinearRegression()`: This line of code creates a new linear regression model.\n2. `model.fit(X_train, y_train)`: This line of code fits the model to the training data.", 39 | "observation": "THE CODE IS CORRECT.", 40 | "code": "" 41 | } 42 | ] 43 | }, 44 | { 45 | "variables": { 46 | "a": "numpy.ndarray(shape=(5,), dtype=int64)", 47 | "b": "numpy.ndarray(shape=(5,), dtype=int64)" 48 | }, 49 | "codes": [], 50 | "request": "Run a t-test to test whether the mean of a is significantly different from the mean of b. Output the p-value.", 51 | "first_attempt": "from scipy.stats import ttest_ind\nresult = ttest_ind(a, b)\nresult.pvalue", 52 | "interactions": [ 53 | { 54 | "error": null, 55 | "output": "0.5191867712416913", 56 | "hint": null, 57 | "explanation": "1. `from scipy.stats import ttest_ind`: This imports the function `ttest_ind` from the `scipy.stats` module. The function `ttest_ind` performs a two-sample t-test, which tests whether the means of two independent samples are significantly different.\n2. `result = ttest_ind(a, b)`: This line performs the two-sample t-test on the arrays `a` and `b`. The result is an object that contains the t-statistic and the p-value of the test.\n3. `print(result.pvalue)`: This prints the p-value of the test result.", 58 | "observation": "The code prints the result directly to the console. It should instead put the result directly in the last line of code.", 59 | "code": "from scipy.stats import ttest_ind\nresult = ttest_ind(a, b)\nresult.pvalue" 60 | } 61 | ] 62 | }, 63 | { 64 | "variables": { 65 | "fruits": "pandas.DataFrame(shape=(5, 3), columns=[\"name\", \"price\", \"color\"])\n name price color\n 0 apple $1.20 red\n 1 banana $0.50 yellow\n 2 cherry $0.75 red\n 3 durian $2.50 green\n 4 elderberry $1.00 purple" 66 | }, 67 | "codes": [ 68 | "fruits[\"name\"].nunique()", 69 | "fruits[\"color\"].value_counts()" 70 | ], 71 | "request": "What is the average price of red fruits?", 72 | "first_attempt": "fruits[fruits['color'] == 'red']['price'].apply(lambda x: float(x.replace('$', ''))).mean()", 73 | "interactions": [ 74 | { 75 | "error": null, 76 | "output": "0.975", 77 | "hint": null, 78 | "explanation": "1. `fruits[\"price_float\"] = fruits[\"price\"].apply(lambda x: float(x.replace(\"$\", \"\")))`: This line of code creates a new column \"price_float\" in the \"fruits\" DataFrame. This column is created by applying a function that removes the dollar sign from each value in the \"price\" column and converts the resulting string to a float.\n2. 
`fruits[fruits[\"color\"] == \"red\"][\"price_float\"].mean()`: This line of code first filters the \"fruits\" DataFrame to include only rows where the \"color\" column is \"red\". It then selects the \"price_float\" column from the filtered DataFrame and calculates the mean (average) of this column's values.", 79 | "observation": "The code creates a new column in the \"fruits\" DataFrame, which is actually not necessary and violating the principle that no overwriting or modification should be done unless specified.", 80 | "code": "fruits[fruits['color'] == 'red']['price'].apply(lambda x: float(x.replace('$', ''))).mean()" 81 | } 82 | ] 83 | }, 84 | { 85 | "variables": { 86 | "people": "pandas.DataFrame(shape=(6, 3), columns=[\"name\", \"gender\", \"age\"])\n name gender age\n 0 Alice Female 25\n 1 Bob Male 30\n 2 Charlie Male 20\n 3 David Male 35\n 4 Eve Female 40\n 5 Frank Other 45" 87 | }, 88 | "codes": [], 89 | "request": "What is the average age of people for each gender?", 90 | "first_attempt": "people.groupby('gender').age.mean().reset_index()", 91 | "interactions": [ 92 | { 93 | "error": null, 94 | "output": " gender age\n0 Male 28.333333\n1 Female 32.500000", 95 | "hint": null, 96 | "explanation": "1. The code first creates a dictionary with two keys: \"gender\" and \"age\".\n2. For each key, it assigns a list as value. The list for \"gender\" includes two strings: \"Male\" and \"Female\". The list for \"age\" includes two function calls which compute the mean age for male and female people, respectively.\n3. The code then creates a DataFrame using this dictionary.", 97 | "observation": "The code only computes the average ages for \"Male\" and \"Female\" genders. However, there is another gender \"Other\" in the dataset which is not included in the computation.", 98 | "code": "people.groupby('gender').age.mean().reset_index()" 99 | } 100 | ] 101 | }, 102 | { 103 | "variables": { 104 | "admissions": "pandas.DataFrame(shape=(5, 2), columns=[\"Date\", \"Applicants\"])\n Date Applicants\n 0 2016-9-30 1,536\n 1 2016-12-31 2,204\n 2 2017-3-15 11,385\n 3 2017-6-9 250\n 4 2017-6-15 7,359" 105 | }, 106 | "codes": [ 107 | "import pandas as pd", 108 | "admissions = pd.read_csv('admissions.csv')" 109 | ], 110 | "request": "Count the number of applicants between 2016-10-1 and 2017-6-10 per day.", 111 | "first_attempt": "admissions_cvt = pd.DataFrame({\n \"Date\": pd.to_datetime(admissions.Date),\n \"Applicants\": admissions.Applicants.str.replace(',', '').astype(int),\n})\ndays = pd.Timestamp(\"2017-6-10\") - pd.Timestamp(\"2016-10-1\") + pd.Timedelta(days=1)\nadmissions_cvt[(admissions_cvt.Date >= \"2016-10-1\") & (admissions_cvt.Date <= \"2017-6-10\")].Applicants.sum() / days.days", 112 | "interactions": [ 113 | { 114 | "error": "Traceback (most recent call last):\n File \"/home/coml/coml.py\", line 175, in days_example\n admissions[(admissions.Date >= \"2016-10-1\") & (admissions.Date <= \"2017-6-10\")].Applicants.sum() / days.days\nTypeError: unsupported operand type(s) for /: 'str' and 'int'", 115 | "output": null, 116 | "hint": null, 117 | "explanation": "1. `days = pd.Timestamp(\"2017-6-10\") - pd.Timestamp(\"2016-10-1\") + pd.Timedelta(days=1)`: calculates the total number of days between 2016-10-1 and 2017-6-10.\n2. 
`admissions[(admissions.Date >= \"2016-10-1\") & (admissions.Date <= \"2017-6-10\")].Applicants.sum() / days.days`: filters the admissions dataframe for the specified date range, sums the applicants, and divides it by the total number of days.", 118 | "observation": "The error message indicates that there's a type mismatch when performing an operation. Looking at the code, it seems that the 'Applicants' column is being treated as a string, not a number. This is probably due to the comma in the numbers, which needs to be removed. The date column should also be converted to a datetime object for comparison.", 119 | "code": "admissions.Date = pd.to_datetime(admissions.Date)\nadmissions.Applicants = admissions.Applicants.str.replace(',', '').astype(int)\ndays = pd.Timestamp(\"2017-6-10\") - pd.Timestamp(\"2016-10-1\") + pd.Timedelta(days=1)\nadmissions[(admissions.Date >= \"2016-10-1\") & (admissions.Date <= \"2017-6-10\")].Applicants.sum() / days.days" 120 | }, 121 | { 122 | "error": null, 123 | "output": "54.6996", 124 | "hint": null, 125 | "explanation": "1. `admissions.Date = pd.to_datetime(admissions.Date)`: This line converts the 'Date' column of the DataFrame to a datetime object, and writes it back to the DataFrame.\n2. `admissions.Applicants = admissions.Applicants.str.replace(',', '').astype(int)`: This line removes the comma from the 'Applicants' column values and converts the column to integer type.\n3. `days = pd.Timestamp(\"2017-6-10\") - pd.Timestamp(\"2016-10-1\") + pd.Timedelta(days=1)`: This line calculates the total number of days between the two given dates.\n4. `admissions[(admissions.Date >= \"2016-10-1\") & (admissions.Date <= \"2017-6-10\")].Applicants.sum() / days.days`: This line filters the DataFrame to include only the rows where the date is between the two given dates. It then sums up the 'Applicants' column values and divides it by the total number of days calculated in the previous step.", 126 | "observation": "The code contains in-place changes to the original DataFrame, but the request did not specify in-place modifications to the DataFrame. The request can be fulfilled by creating a new DataFrame instead of overwriting the original one.", 127 | "code": "admissions_cvt = pd.DataFrame({\n \"Date\": pd.to_datetime(admissions.Date),\n \"Applicants\": admissions.Applicants.str.replace(',', '').astype(int),\n})\ndays = pd.Timestamp(\"2017-6-10\") - pd.Timestamp(\"2016-10-1\") + pd.Timedelta(days=1)\nadmissions_cvt[(admissions_cvt.Date >= \"2016-10-1\") & (admissions_cvt.Date <= \"2017-6-10\")].Applicants.sum() / days.days" 128 | } 129 | ] 130 | } 131 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": { 4 | "url": "\"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"", 5 | "chipo": "pandas.DataFrame(shape=(4622, 5), columns=[\"order_id\", \"quantity\", \"item_name\", \"choice_description\", \"item_price\"])\n order_id quantity item_name choice_description item_price\n 0 1 1 Chips an... NaN $2.39 \n 1 1 1 Izze [Clement... $3.39 \n ... ... ... ... ... ...\n 4620 1834 1 Chicken ... [Fresh T... $8.75 \n 4621 1834 1 Chicken ... [Fresh T... 
$8.75" 6 | }, 7 | "codes": [ 8 | "import pandas as pd", 9 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv).\n# \n# Assign it to a variable called chipo.\nurl = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv'\n\nchipo = pd.read_csv(url, sep = '\\t')" 10 | ], 11 | "request": "How many products cost more than $10.00?", 12 | "answer": "# clean the item_price column and transform it in a float\nprices = [float(value[1 : -1]) for value in chipo.item_price]\n\n# reassign the column with the cleaned prices\nchipo.item_price = prices\n\n# delete the duplicates in item_name and quantity\nchipo_filtered = chipo.drop_duplicates(['item_name','quantity','choice_description'])\n\n# calculate the price per item and query the ones that are greater than $10.00\nchipo['price_per_item'] = chipo.item_price / chipo.quantity\nchipo.query('price_per_item > 10').item_name.nunique()" 13 | }, 14 | { 15 | "variables": { 16 | "drinks": "pandas.DataFrame(shape=(193, 6), columns=[\"country\", \"beer_servings\", \"spirit_servings\", \"wine_servings\", \"total_litres_of_pure_alcohol\", ...])\n country beer_servings spirit_servings wine_servings total_litres_of_pure_alcohol continent\n 0 Afghanistan 0 0 0 0.0 AS\n 1 Albania 89 132 54 4.9 EU\n .. ... ... ... ... ... ...\n 191 Zambia 32 19 4 2.5 AF\n 192 Zimbabwe 64 18 4 4.7 AF", 17 | "beer_avg_by_continent": "pandas.Series(shape=(5,))" 18 | }, 19 | "codes": [ 20 | "import pandas as pd", 21 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/drinks.csv).\n# \n# Assign it to a variable called drinks.\ndrinks = pd.read_csv('https://raw.githubusercontent.com/justmarkham/DAT8/master/data/drinks.csv')", 22 | "# Which continent drinks more beer on average?\n# Group by continent and calculate the mean of beer_servings\nbeer_avg_by_continent = drinks.groupby('continent')['beer_servings'].mean()\n\n# Find the continent with the highest average beer consumption\nbeer_avg_by_continent.idxmax()" 23 | ], 24 | "request": "For each continent show the statistics for wine consumption.", 25 | "answer": "drinks.groupby('continent').wine_servings.describe()" 26 | }, 27 | { 28 | "variables": { 29 | "cars1": "pandas.DataFrame(shape=(198, 14), columns=[\"mpg\", \"cylinders\", \"displacement\", \"horsepower\", \"weight\", ...])\n mpg cylinders displacement ... Unnamed: 11 Unnamed: 12 Unnamed: 13\n 0 18.0 8 307 ... NaN NaN NaN\n 1 15.0 8 350 ... NaN NaN NaN\n .. ... ... ... ... ... ... ...\n 196 24.5 4 98 ... NaN NaN NaN\n 197 29.0 4 90 ... NaN NaN NaN", 30 | "cars2": "pandas.DataFrame(shape=(200, 9), columns=[\"mpg\", \"cylinders\", \"displacement\", \"horsepower\", \"weight\", ...])\n mpg cylinders displacement ... model origin car\n 0 33.0 4 91 ... 76 3 honda civic\n 1 20.0 6 225 ... 76 1 dodge as...\n .. ... ... ... ... ... ... ...\n 198 28.0 4 120 ... 82 1 ford ranger\n 199 31.0 4 119 ... 
82 1 chevy s-10" 31 | }, 32 | "codes": [ 33 | "import pandas as pd\nimport numpy as np", 34 | "# Import the first dataset [cars1](https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/05_Merge/Auto_MPG/cars1.csv) and [cars2](https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/05_Merge/Auto_MPG/cars2.csv).\n# \n# Assign each to a to a variable called cars1 and cars2\ncars1 = pd.read_csv(\"https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/05_Merge/Auto_MPG/cars1.csv\")\ncars2 = pd.read_csv(\"https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/05_Merge/Auto_MPG/cars2.csv\")" 35 | ], 36 | "request": "Oops, it seems our first dataset has some unnamed blank columns, fix cars1", 37 | "answer": "cars1 = cars1.loc[:, \"mpg\":\"car\"]" 38 | }, 39 | { 40 | "variables": { 41 | "apple": "pandas.DataFrame(shape=(8465, 6), columns=[\"Open\", \"High\", \"Low\", \"Close\", \"Volume\", ...])\n Open High Low Close Volume Adj Close\n Date \n 1980-12-12 28.75 28.87 28.75 28.75 117258400 0.45\n 1980-12-15 27.38 27.38 27.25 27.25 43971200 0.42\n ... ... ... ... ... ... ...\n 2014-07-07 94.14 95.99 94.10 95.97 56305400 95.97\n 2014-07-08 96.27 96.80 93.92 95.35 65130000 95.35" 42 | }, 43 | "codes": [ 44 | "import pandas as pd\nimport numpy as np", 45 | "# Import the dataset from this [address](https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/09_Time_Series/Apple_Stock/appl_1980_2014.csv)\n# \n# Assign it to a variable apple\napple = pd.read_csv('https://raw.githubusercontent.com/guipsamora/pandas_exercises/master/09_Time_Series/Apple_Stock/appl_1980_2014.csv')", 46 | "# Check out the type of the columns\napple.dtypes", 47 | "# Transform the Date column as a datetime type\napple['Date'] = pd.to_datetime(apple['Date'])", 48 | "# Set the date as the index\napple = apple.set_index('Date')", 49 | "# Is there any duplicate dates?\napple.index.is_unique", 50 | "# Ops...it seems the index is from the most recent date. Make the first entry the oldest date.\napple = apple.sort_index(ascending = True)", 51 | "# Get the last business day of each month. Return a list of pandas timestamps.\napple.resample('BM').last().index.to_list()" 52 | ], 53 | "request": "What is the difference in days between the first day and the oldest", 54 | "answer": "(apple.index.max() - apple.index.min()).days" 55 | }, 56 | { 57 | "variables": { 58 | "Counter": "", 59 | "url": "\"https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv\"", 60 | "chipo": "pandas.DataFrame(shape=(4622, 5), columns=[\"order_id\", \"quantity\", \"item_name\", \"choice_description\", \"item_price\"])\n order_id quantity item_name choice_description item_price\n 0 1 1 Chips an... NaN $2.39 \n 1 1 1 Izze [Clement... $3.39 \n ... ... ... ... ... ...\n 4620 1834 1 Chicken ... [Fresh T... $8.75 \n 4621 1834 1 Chicken ... [Fresh T... 
$8.75" 61 | }, 62 | "codes": [ 63 | "import pandas as pd\nimport matplotlib.pyplot as plt\nfrom collections import Counter", 64 | "# Import the dataset from this [address](https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv).\n# \n# Assign it to a variable called chipo.\nurl = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/chipotle.tsv'\n \nchipo = pd.read_csv(url, sep = '\\t')", 65 | "# See the first 10 entries\nchipo.head(10)" 66 | ], 67 | "request": "Create a histogram of the top 5 items bought", 68 | "answer": "# get the Series of the names\nx = chipo.item_name\n\n# use the Counter class from collections to create a dictionary with keys(text) and frequency\nletter_counts = Counter(x)\n\n# convert the dictionary to a DataFrame\ndf = pd.DataFrame.from_dict(letter_counts, orient='index')\n\n# sort the values from the top to the least value and slice the first 5 items\ndf = df[0].sort_values(ascending = True)[45:50]\n\n# create the plot\ndf.plot(kind='bar')\n\n# Set the title and labels\nplt.xlabel('Items')\nplt.ylabel('Number of Times Ordered')\nplt.title('Most ordered Chipotle\\'s Items')\n\n# show the plot\nplt.show()" 69 | } 70 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots_kaggle.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": { 4 | "dataset": "pandas.DataFrame(shape=(5043, 28), columns=[\"color\", \"director_name\", \"num_critic_for_reviews\", \"duration\", \"director_facebook_likes\", \"actor_3_facebook_likes\", \"actor_2_name\", \"actor_1_facebook_likes\", \"gross\", \"genres\", \"actor_1_name\", \"movie_title\", \"num_voted_users\", \"cast_total_facebook_likes\", \"actor_3_name\", \"facenumber_in_poster\", \"plot_keywords\", \"movie_imdb_link\", \"num_user_for_reviews\", \"language\", \"country\", \"content_rating\", \"budget\", \"title_year\", \"actor_2_facebook_likes\", \"imdb_score\", \"aspect_ratio\", \"movie_facebook_likes\"])\n color director_name num_critic_for_reviews duration director_facebook_likes actor_3_facebook_likes actor_2_name ... content_rating budget title_year actor_2_facebook_likes imdb_score aspect_ratio movie_facebook_likes\n 0 Color James Cameron 723.0 178.0 0.0 855.0 Joel David ... ... PG-13 237000000.0 2009.0 936.0 7.9 1.78 33000 \n 1 Color Gore Verbinski 302.0 169.0 563.0 1000.0 Orlando Bloom ... PG-13 300000000.0 2007.0 5000.0 7.1 2.35 0 \n ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... \n 5041 Color Daniel Hsia 14.0 100.0 0.0 489.0 Daniel Henney ... PG-13 NaN 2012.0 719.0 6.3 2.35 660 \n 5042 Color Jon Gunn 43.0 90.0 16.0 16.0 Brian Herzl... ... PG 1100.0 2004.0 23.0 6.6 1.85 456" 5 | }, 6 | "codes": [ 7 | "import pandas as pd\n\ndataset = pd.read_csv(\"data/movie_metadata.csv\")", 8 | "dataset.isnull().sum()" 9 | ], 10 | "request": "List the major genres in the dataset. 
Show the result in a DataFrame with columns `genre` and `count` (number of movies of that genre).\n", 11 | "answer": "# Create a new variable to avoid unintentional modifications to raw DataFrame.\ngenres = pd.DataFrame({\n \"movie\": dataset[\"movie_title\"],\n \"genre\": dataset[\"genres\"].apply(lambda x: x.split(\"|\")[0]),\n})\n\n# Count the number of movies in each genre.\ngenres.groupby(\"genre\").count()[[\"movie\"]].rename(columns={\"movie\": \"count\"}).reset_index()" 12 | }, 13 | { 14 | "variables": { 15 | "iris": "pandas.DataFrame(shape=(150, 5), columns=[\"SepalLengthCm\", \"SepalWidthCm\", \"PetalLengthCm\", \"PetalWidthCm\", \"Species\"])\n SepalLengthCm SepalWidthCm PetalLengthCm PetalWidthCm Species\n Id \n 1 5.1 3.5 1.4 0.2 Iris-setosa\n 2 4.9 3.0 1.4 0.2 Iris-setosa\n .. ... ... ... ... ...\n 149 6.2 3.4 5.4 2.3 Iris-virginica\n 150 5.9 3.0 5.1 1.8 Iris-virginica", 16 | "z": "5.735584236235733", 17 | "pval": "0.05682424941067306" 18 | }, 19 | "codes": [ 20 | "import pandas as pd\nimport numpy as np", 21 | "iris = pd.read_csv(\"data/iris.csv\", index_col=\"Id\")", 22 | "from scipy.stats import normaltest\nz, pval = normaltest(iris[\"SepalLengthCm\"])\npval" 23 | ], 24 | "request": "Analyze the correlations of the dataset. Eliminate upper triangle for readability.\n", 25 | "answer": "corr = iris.select_dtypes(\"number\").corr(method=\"pearson\")\ncorr.where(np.tril(np.ones(corr.shape)).astype(bool))" 26 | }, 27 | { 28 | "variables": { 29 | "train_data": "pandas.DataFrame(shape=(891, 12), columns=[\"PassengerId\", \"Survived\", \"Pclass\", \"Name\", \"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Fare\", \"Cabin\", \"Embarked\"])\n PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked\n 0 1 0 3 Braund, Mr.... male 22.0 1 0 A/5 21171 7.2500 NaN S\n 1 2 1 1 Cumings, Mr... female 38.0 1 0 PC 17599 71.2833 C85 C\n .. ... ... ... ... ... ... ... ... ... ... ... ...\n 889 890 1 1 Behr, Mr. K... male 26.0 0 0 111369 30.0000 C148 C\n 890 891 0 3 Dooley, Mr.... male 32.0 0 0 370376 7.7500 NaN Q", 30 | "test_data": "pandas.DataFrame(shape=(418, 11), columns=[\"PassengerId\", \"Pclass\", \"Name\", \"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Fare\", \"Cabin\", \"Embarked\"])\n PassengerId Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked\n 0 892 3 Kelly, Mr. ... male 34.5 0 0 330911 7.8292 NaN Q\n 1 893 3 Wilkes, Mrs... female 47.0 1 0 363272 7.0000 NaN S\n .. ... ... ... ... ... ... ... ... ... ... ...\n 416 1308 3 Ware, Mr. F... male NaN 0 0 359309 8.0500 NaN S\n 417 1309 3 Peter, Mast... male NaN 1 1 2668 22.3583 NaN C", 31 | "women": "pandas.Series(shape=(314,))" 32 | }, 33 | "codes": [ 34 | "import numpy as np\nimport pandas as pd", 35 | "train_data = pd.read_csv(\"data/titanic_train.csv\")\ntrain_data.shape", 36 | "test_data = pd.read_csv(\"data/titanic_test.csv\")\ntest_data.shape", 37 | "women = train_data.loc[train_data.Sex == \"female\"][\"Survived\"]\nsum(women) / len(women) * 100" 38 | ], 39 | "request": "List out the names with age 42.\n", 40 | "answer": "train_data.loc[train_data.Age == 42, \"Name\"].tolist()" 41 | }, 42 | { 43 | "variables": { 44 | "twc": "pandas.DataFrame(shape=(93, 8), columns=[\"tweet_id\", \"author_id\", \"inbound\", \"created_at\", \"text\", \"response_tweet_id\", \"in_response_to_tweet_id\", \"text_wo_punct\"])\n tweet_id author_id inbound created_at text response_tweet_id in_response_to_tweet_id text_wo_punct\n 0 119237 105834 True Wed Oct 11 ... @AppleSuppo... 
119236 NaN applesuppor...\n 1 119238 ChaseSupport False Wed Oct 11 ... @105835 You... NaN 119239.0 105835 your...\n .. ... ... ... ... ... ... ... ...\n 91 119333 105861 True Wed Oct 11 ... @Tesco bit ... 119335,119336 119332.0 tesco bit o...\n 92 119335 Tesco False Wed Oct 11 ... @105861 If ... NaN 119333.0 105861 if t..." 45 | }, 46 | "codes": [ 47 | "import numpy as np\nimport pandas as pd", 48 | "twc = pd.read_csv(\"data/twc-sample.csv\")", 49 | "import string\n\ndef remove_punctuation(text: str):\n return text.lower().translate(str.maketrans(\"\", \"\", string.punctuation))\n\ntwc[\"text_wo_punct\"] = twc[\"text\"].apply(remove_punctuation)" 50 | ], 51 | "request": "Remove English stopwords from the `text_wo_punct` column. Save the result in a new column `text_wo_stopwords`.\n", 52 | "answer": "from nltk.corpus import stopwords\n\nSTOPWORDS = set(stopwords.words(\"english\"))\n\ntwc[\"text_wo_stopwords\"] = twc[\"text_wo_punct\"].apply(\n lambda text: \" \".join([word for word in str(text).split() if word not in STOPWORDS])\n)" 53 | }, 54 | { 55 | "variables": { 56 | "df": "pandas.DataFrame(shape=(1599, 12), columns=[\"fixed acidity\", \"volatile acidity\", \"citric acid\", \"residual sugar\", \"chlorides\", \"free sulfur dioxide\", \"total sulfur dioxide\", \"density\", \"pH\", \"sulphates\", \"alcohol\", \"quality\"])\n fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol quality\n 0 7.4 0.700 0.00 1.9 0.076 11.0 34.0 0.99780 3.51 0.56 9.4 5\n 1 7.8 0.880 0.00 2.6 0.098 25.0 67.0 0.99680 3.20 0.68 9.8 5\n ... ... ... ... ... ... ... ... ... ... ... ... ...\n 1597 5.9 0.645 0.12 2.0 0.075 32.0 44.0 0.99547 3.57 0.71 10.2 5\n 1598 6.0 0.310 0.47 3.6 0.067 18.0 42.0 0.99549 3.39 0.66 11.0 6", 57 | "z": "pandas.DataFrame(shape=(1599, 12), columns=[\"fixed acidity\", \"volatile acidity\", \"citric acid\", \"residual sugar\", \"chlorides\", \"free sulfur dioxide\", \"total sulfur dioxide\", \"density\", \"pH\", \"sulphates\", \"alcohol\", \"quality\"])\n fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol quality\n 0 0.528360 0.961877 1.391472 0.453218 0.243707 0.466193 0.379133 0.558274 1.288643 0.579207 0.960246 0.787823\n 1 0.298547 1.967442 1.391472 0.043416 0.223875 0.872638 0.624363 0.028261 0.719933 0.128950 0.584777 0.787823\n ... ... ... ... ... ... ... ... ... ... ... ... ...\n 1597 1.390155 0.654620 0.775267 0.382271 0.264960 1.542054 0.075043 0.676657 1.677400 0.305990 0.209308 0.787823\n 1598 1.332702 1.216849 1.021999 0.752894 0.434990 0.203223 0.135861 0.666057 0.511130 0.010924 0.541630 0.450848", 58 | "X": "pandas.DataFrame(shape=(1599, 11), columns=[\"fixed acidity\", \"volatile acidity\", \"citric acid\", \"residual sugar\", \"chlorides\", \"free sulfur dioxide\", \"total sulfur dioxide\", \"density\", \"pH\", \"sulphates\", \"alcohol\"])\n fixed acidity volatile acidity citric acid residual sugar chlorides free sulfur dioxide total sulfur dioxide density pH sulphates alcohol\n 0 7.4 0.700 0.00 1.9 0.076 11.0 34.0 0.99780 3.51 0.56 9.4\n 1 7.8 0.880 0.00 2.6 0.098 25.0 67.0 0.99680 3.20 0.68 9.8\n ... ... ... ... ... ... ... ... ... ... ... 
...\n 1597 5.9 0.645 0.12 2.0 0.075 32.0 44.0 0.99547 3.57 0.71 10.2\n 1598 6.0 0.310 0.47 3.6 0.067 18.0 42.0 0.99549 3.39 0.66 11.0", 59 | "y": "pandas.Series(shape=(1599,))" 60 | }, 61 | "codes": [ 62 | "import pandas as pd", 63 | "df = pd.read_csv(\"data/winequality-red.csv\")", 64 | "import numpy as np\nfrom scipy import stats\n\nz = np.abs(stats.zscore(df))\ndf[(z < 3).all(axis=1)]", 65 | "from sklearn.model_selection import train_test_split\n\nX = df.drop(columns=\"quality\")\ny = df[\"quality\"]" 66 | ], 67 | "request": "Split X and y into train and test sets. Use 20% of the data for testing.\nSet random state to 42.\nFit a RandomForestClassifier `clf` with 100 estimators on the training data, also using random state 42.\n", 68 | "answer": "from sklearn.ensemble import RandomForestClassifier\n\nX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\nclf = RandomForestClassifier(n_estimators=100, random_state=42)\nclf.fit(X_train, y_train)" 69 | } 70 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots_leetcode.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": {}, 4 | "codes": [ 5 | "import pandas as pd" 6 | ], 7 | "request": "Write a function `def capital_gainloss(stocks: pd.DataFrame) -> pd.DataFrame`.\n\n`stocks` is a DataFrame with the following columns:\n- stock_name: str\n- operation: str\n- operation_day: int\n- price: int\nEach record indicates that the stock which has `stock_name` had an operation on the day `operation_day` with the price. The `operation` must be one of `Sell`, `Buy`. It is guaranteed that each 'Sell' operation for a stock has a corresponding 'Buy' operation in a previous day. It is also guaranteed that each 'Buy' operation for a stock has a corresponding 'Sell' operation in an upcoming day.\n\nThe function should report the **Capital gain/loss** for each stock. The **Capital gain/loss** of a stock is the total gain or loss after buying and selling the stock one or many times. Return the result table in **any order**.\n\nThe result format is in the following example.\n\nExample input:\n```\nstocks:\n+---------------+-----------+---------------+--------+\n| stock_name | operation | operation_day | price |\n+---------------+-----------+---------------+--------+\n| Leetcode | Buy | 1 | 1000 |\n| Corona Masks | Buy | 2 | 10 |\n| Leetcode | Sell | 5 | 9000 |\n| Handbags | Buy | 17 | 30000 |\n| Corona Masks | Sell | 3 | 1010 |\n| Corona Masks | Buy | 4 | 1000 |\n| Corona Masks | Sell | 5 | 500 |\n| Corona Masks | Buy | 6 | 1000 |\n| Handbags | Sell | 29 | 7000 |\n| Corona Masks | Sell | 10 | 10000 |\n+---------------+-----------+---------------+--------+\n```\n\nExample output:\n```\n+---------------+-------------------+\n| stock_name | capital_gain_loss |\n+---------------+-------------------+\n| Corona Masks | 9500 |\n| Leetcode | 8000 |\n| Handbags | -23000 |\n+---------------+-------------------+\n```\n\nExample explanation:\n- Leetcode stock was bought at day 1 for 1000$ and was sold at day 5 for 9000$. Capital gain = 9000 - 1000 = 8000$.\n- Handbags stock was bought at day 17 for 30000$ and was sold at day 29 for 7000$. Capital loss = 7000 - 30000 = -23000$.\n- Corona Masks stock was bought at day 1 for 10$ and was sold at day 3 for 1010$. It was bought again at day 4 for 1000$ and was sold at day 5 for 500$. At last, it was bought at day 6 for 1000$ and was sold at day 10 for 10000$. 
Capital gain/loss is the sum of capital gains/losses for each ('Buy' --> 'Sell') operation = (1010 - 10) + (500 - 1000) + (10000 - 1000) = 1000 - 500 + 9000 = 9500$.\n", 8 | "answer": "def capital_gainloss(stocks: pd.DataFrame) -> pd.DataFrame:\n # Create a copy of the dataframe to avoid modifying the original one\n df = stocks.copy()\n\n # Calculate the capital gain/loss for each stock\n df[\"price\"] = df.apply(\n lambda x: -x[\"price\"] if x[\"operation\"] == \"Buy\" else x[\"price\"], axis=1\n )\n result = df.groupby(\"stock_name\")[\"price\"].sum().reset_index()\n result.columns = [\"stock_name\", \"capital_gain_loss\"]\n\n return result" 9 | }, 10 | { 11 | "variables": {}, 12 | "codes": [ 13 | "import pandas as pd" 14 | ], 15 | "request": "Write a function `def consecutive_numbers(logs: pd.DataFrame) -> pd.DataFrame`.\n\n`logs` is a DataFrame with the following columns:\n- id: int\n- num: int\n`id` is an autoincrement column and serves as the primary key.\n\nThe function should find all numbers that appear at least three times consecutively and return the result in a DataFrame with a single column `ConsecutiveNums`. The result should be returned in **any order**.\n\nExample input:\n```\nlogs:\n+----+-----+\n| id | num |\n+----+-----+\n| 1 | 1 |\n| 2 | 1 |\n| 3 | 1 |\n| 4 | 2 |\n| 5 | 1 |\n| 6 | 2 |\n| 7 | 2 |\n+----+-----+\n```\n\nExample output:\n```\n+-----------------+\n| ConsecutiveNums |\n+-----------------+\n| 1 |\n+-----------------+\n```\n\nExample explanation:\n- 1 is the only number that appears consecutively for at least three times.\n", 16 | "answer": "def consecutive_numbers(logs: pd.DataFrame) -> pd.DataFrame:\n # Create a new column to identify consecutive numbers\n logs[\"shifted_num\"] = logs[\"num\"].shift(2)\n logs[\"is_consecutive\"] = logs[\"num\"].eq(logs[\"num\"].shift()) & logs[\"num\"].eq(logs[\"shifted_num\"])\n\n # Filter out the rows where numbers are not consecutive\n consecutive = logs[logs[\"is_consecutive\"]]\n\n # Get the unique numbers that are consecutive\n result = pd.DataFrame(consecutive[\"num\"].unique(), columns=[\"ConsecutiveNums\"])\n\n return result" 17 | }, 18 | { 19 | "variables": {}, 20 | "codes": [ 21 | "import pandas as pd" 22 | ], 23 | "request": "Write a function `def employee_bonus(employee: pd.DataFrame, bonus: pd.DataFrame) -> pd.DataFrame`.\n\n`employee` is a DataFrame with the following columns:\n- empId: int\n- name: str\n- supervisor: int\n- salary: int\nEach record indicates the name and the ID of an employee in addition to their salary and the id of their manager.\n\n`bonus` is a DataFrame with the following columns:\n- empId: int\n- bonus: int\nEach record contains the id of an employee and their respective bonus.\n\nThe function should report the name and bonus amount of each employee with a bonus **less than** `1000`.\n\nReturn the result table in **any order**.\n\nThe result format is in the following example.\n\nExample input:\n```\nemployee:\n+-------+--------+------------+--------+\n| empId | name | supervisor | salary |\n+-------+--------+------------+--------+\n| 3 | Brad | null | 4000 |\n| 1 | John | 3 | 1000 |\n| 2 | Dan | 3 | 2000 |\n| 4 | Thomas | 3 | 4000 |\n+-------+--------+------------+--------+\nbonus:\n+-------+-------+\n| empId | bonus |\n+-------+-------+\n| 2 | 500 |\n| 4 | 2000 |\n+-------+-------+\n```\n\nExample output:\n```\n+------+-------+\n| name | bonus |\n+------+-------+\n| Brad | null |\n| John | null |\n| Dan | 500 |\n+------+-------+\n```\n", 24 | "answer": "def employee_bonus(employee: 
pd.DataFrame, bonus: pd.DataFrame) -> pd.DataFrame:\n # Merge the employee and bonus dataframes on empId.\n merged_df = pd.merge(employee[[\"empId\", \"name\"]], bonus, on=\"empId\", how=\"left\")\n\n # Filter out the employees with bonus >= 1000 or without a bonus.\n result_df = merged_df[merged_df[\"bonus\"].isna() | (merged_df[\"bonus\"] < 1000)]\n\n # Select only the name and bonus columns for the final output.\n return result_df[[\"name\", \"bonus\"]]" 25 | }, 26 | { 27 | "variables": {}, 28 | "codes": [ 29 | "import pandas as pd" 30 | ], 31 | "request": "Write a function `def price_at_given_date(products: pd.DataFrame) -> pd.DataFrame`.\n\n`products` is a DataFrame with the following columns:\n- product_id: int\n- new_price: int\n- change_date: datetime\nEach record indicates that the price of some product was changed to a new price at some date. Assume the price of all products before any change is `10`.\n\nThe function should return the prices of all products on `2019-08-16`. Return the result table in **any order**.\n\nThe result format is in the following example.\n\nExample input:\n```\nproducts:\n+------------+-----------+-------------+\n| product_id | new_price | change_date |\n+------------+-----------+-------------+\n| 1 | 20 | 2019-08-14 |\n| 2 | 50 | 2019-08-14 |\n| 1 | 30 | 2019-08-15 |\n| 1 | 35 | 2019-08-16 |\n| 2 | 65 | 2019-08-17 |\n| 3 | 20 | 2019-08-18 |\n+------------+-----------+-------------+\n```\n\nExample output:\n```\n+------------+-------+\n| product_id | price |\n+------------+-------+\n| 2 | 50 |\n| 1 | 35 |\n| 3 | 10 |\n+------------+-------+\n```\n", 32 | "answer": "def price_at_given_date(products: pd.DataFrame) -> pd.DataFrame:\n # Filter out the prices that are set after the given date\n filtered_products = products[products[\"change_date\"] <= \"2019-08-16\"]\n\n # Sort by product_id and change_date to get the latest price before or on the given date\n sorted_products = filtered_products.sort_values(\n by=[\"product_id\", \"change_date\"], ascending=[True, False]\n )\n\n # Drop duplicates to keep only the latest price for each product\n latest_prices = sorted_products.drop_duplicates(subset=[\"product_id\"])\n\n # If a product has no price change before the given date, its price is assumed to be 10\n all_products = pd.DataFrame({\"product_id\": products[\"product_id\"].unique()})\n result = pd.merge(\n all_products,\n latest_prices[[\"product_id\", \"new_price\"]],\n on=\"product_id\",\n how=\"left\",\n )\n result[\"price\"] = result[\"new_price\"].fillna(10).astype(int)\n\n # Select the required columns and return the result\n return result[[\"product_id\", \"price\"]]" 33 | }, 34 | { 35 | "variables": {}, 36 | "codes": [ 37 | "import pandas as pd" 38 | ], 39 | "request": "Write a function `def latest_login(logins: pd.DataFrame) -> pd.DataFrame`.\n\n`logins` is a DataFrame with the following columns:\n- user_id: int\n- time_stamp: datetime\nEach record contains information about the login time for the user with ID `user_id`. The combination of `user_id` and `time_stamp` is unique for this table.\n\nThe function should report the **latest** login for all users in the year `2020`. Do **not** include the users who did not login in `2020`. 
Return the result table **in any order**.\n\nThe result format is in the following example.\n\nExample input:\n```\nlogins:\n+---------+---------------------+\n| user_id | time_stamp |\n+---------+---------------------+\n| 6 | 2020-06-30 15:06:07 |\n| 6 | 2021-04-21 14:06:06 |\n| 6 | 2019-03-07 00:18:15 |\n| 8 | 2020-02-01 05:10:53 |\n| 8 | 2020-12-30 00:46:50 |\n| 2 | 2020-01-16 02:49:50 |\n| 2 | 2019-08-25 07:59:08 |\n| 14 | 2019-07-14 09:00:00 |\n| 14 | 2021-01-06 11:59:59 |\n+---------+---------------------+\n```\n\nExample output:\n```\n+---------+---------------------+\n| user_id | last_stamp |\n+---------+---------------------+\n| 6 | 2020-06-30 15:06:07 |\n| 8 | 2020-12-30 00:46:50 |\n| 2 | 2020-01-16 02:49:50 |\n+---------+---------------------+\n```\n\nExample explanation:\n- User 6 logged into their account 3 times but only once in 2020, so we include this login in the result table.\n- User 8 logged into their account 2 times in 2020, once in February and once in December. We include only the latest one (December) in the result table.\n- User 2 logged into their account 2 times but only once in 2020, so we include this login in the result table.\n- User 14 did not login in 2020, so we do not include them in the result table.\n", 40 | "answer": "def latest_login(logins: pd.DataFrame) -> pd.DataFrame:\n # Filter logins for the year 2020\n logins_2020 = logins[logins[\"time_stamp\"].dt.year == 2020]\n\n # Find the latest login for each user in 2020\n latest_logins = logins_2020.groupby(\"user_id\")[\"time_stamp\"].max().reset_index()\n\n # Rename the column for the output\n latest_logins = latest_logins.rename(columns={\"time_stamp\": \"last_stamp\"})\n\n return latest_logins" 41 | } 42 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots_matplotlib.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "datasets": [ 4 | "scientist_1/Projects", 5 | "scientist_1/AssignedTo" 6 | ], 7 | "codes": [ 8 | "import pandas as pd\nimport matplotlib.pyplot as plt", 9 | "Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')" 10 | ], 11 | "request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? 
Plot the result in a bar chart.", 12 | "answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nplt.bar(project_counts.index, project_counts.values)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()" 13 | }, 14 | { 15 | "datasets": [ 16 | "coffee_shop/member" 17 | ], 18 | "codes": [ 19 | "import pandas as pd\nimport matplotlib.pyplot as plt", 20 | "member_dataset = pd.read_csv('../dataset/coffee_shop/member.csv')" 21 | ], 22 | "request": "A scatter chart showing the correlation between the age of the customer and the time of purchase colored by membership level.", 23 | "answer": "# Group the dataset by Membership_card\ngroups = member_dataset.groupby('Membership_card')\n\n# Create a scatter chart for each Membership_card\nfor membership_card, group in groups:\n plt.scatter(group['Age'], group['Time_of_purchase'], label=membership_card)\n\n# Set the title and labels\nplt.title('Correlation between Age and Time of Purchase')\nplt.xlabel('Age')\nplt.ylabel('Time of Purchase')\nplt.legend(loc='upper left')\n\n# Show the plot\nplt.show()" 24 | }, 25 | { 26 | "datasets": [ 27 | "game_injury/game" 28 | ], 29 | "codes": [ 30 | "import pandas as pd\nimport matplotlib.pyplot as plt", 31 | "game_dataset = pd.read_csv('../dataset/game_injury/game.csv')" 32 | ], 33 | "request": "Show the number of games in each season and group by away team in a group line chart. The x-axis is season.", 34 | "answer": "# group the dataset by season and away team\ngrouped = game_dataset.groupby(['Season', 'Away_team']).size().reset_index(name='counts')\n\n# create a pivot table with season as index and away team as columns\npivot_table = pd.pivot_table(grouped, values='counts', index=['Season'], columns=['Away_team'], fill_value=0)\n\n# create the line chart\npivot_table.plot(kind='line')\n\n# set the title and labels\nplt.title('Number of Games in Each Season by Away Team')\nplt.xlabel('Season')\nplt.ylabel('Number of Games')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\n\n# show the plot\nplt.show()" 35 | }, 36 | { 37 | "datasets": [ 38 | "pilot_record/pilot" 39 | ], 40 | "codes": [ 41 | "import pandas as pd\nimport matplotlib.pyplot as plt", 42 | "pilot_dataset = pd.read_csv('../dataset/pilot_record/pilot.csv')" 43 | ], 44 | "request": "What is the proportion of positions of pilots? 
Show the result in a pie chart.", 45 | "answer": "# Count the number of each position\nposition_counts = pilot_dataset['Position'].value_counts()\n\n# Create a pie chart\nplt.figure(figsize=(8,6))\nplt.pie(position_counts, labels = position_counts.index, autopct='%1.1f%%')\n\n# Set the title\nplt.title('Proportion of Positions')\n\n# Show the plot\nplt.show()" 46 | } 47 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots_seaborn.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "datasets": [ 4 | "scientist_1/Projects", 5 | "scientist_1/AssignedTo" 6 | ], 7 | "codes": [ 8 | "import pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns", 9 | "Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')" 10 | ], 11 | "request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.", 12 | "answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nsns.barplot(project_counts)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()" 13 | } 14 | ] -------------------------------------------------------------------------------- /coml/prompts/generate_fewshots_v2.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "variables": { 4 | "df": "pandas.DataFrame(shape=(5, 2), columns=[\"width\", \"height\"])\n width height\n 0 35 20\n 1 42 12\n 2 38 30\n 3 45 25\n 4 50 15" 5 | }, 6 | "codes": [ 7 | "# Compute the area\ndf[\"area\"] = df[\"width\"] * df[\"height\"]" 8 | ], 9 | "request": "Output the correlation between width and height.", 10 | "answer": "df[\"width\"].corr(df[\"height\"])", 11 | "rationale": "To find the correlation between the 'width' and 'height' columns in the DataFrame `df`, we can use the `corr()` method provided by pandas, which computes pairwise correlations of columns, excluding NA/null values. Since we are only interested in the correlation between two specific columns, we can isolate those two columns and then call the `corr()` method.\n\nThe steps to do this are as follows:\n\n1. Select the 'width' and 'height' columns from the DataFrame `df`.\n2. Call the `corr()` method on 'width' column, with 'height' column as its parameter.\n\nHere is the code to do this:" 12 | }, 13 | { 14 | "variables": { 15 | "df": "pandas.DataFrame(shape=(10, 4), columns=[\"students\", \"score\", \"grade\", \"pass\"])\n students score grade pass\n 0 Alice 53 F False\n 1 Bob 68 D False\n .. ... ... ... 
...\n 8 Ivan 45 F False\n 9 Judy 80 B True" 16 | }, 17 | "codes": [ 18 | "# Add a column \"grade\" based on the score\ndf[\"grade\"] = df[\"score\"].apply(lambda x: \"A\" if x >= 90 else \"B\" if x >= 80 else \"C\" if x >= 70 else \"D\" if x >= 60 else \"F\")", 19 | "# Add a column \"pass\" based on the grade\ndf[\"pass\"] = df[\"grade\"].apply(lambda x: x in [\"A\", \"B\", \"C\"])" 20 | ], 21 | "request": "Who are the top 3 students in the class?", 22 | "answer": "df.sort_values(by=\"score\", ascending=False).head(3)[\"students\"].tolist()", 23 | "rationale": "To identify the top 3 students in the class, we need to follow these steps:\n\n1. **Sort the DataFrame:** We need to sort the DataFrame based on the 'score' column in descending order because the students with the highest scores are considered the top students.\n2. **Select the Top 3:** After sorting, we select the first three rows of the DataFrame, which will represent the top 3 students.\n3. **Extract Student Names:** Finally, we extract the 'students' column from these top 3 rows to get the names of the top students, and convert it to a list.\n\nGiven that the DataFrame is named `df`, here is how you can perform these steps in code:" 24 | }, 25 | { 26 | "variables": { 27 | "df": "pandas.DataFrame(shape=(6, 4), columns=[\"product\", \"price\", \"weight\", \"screen\"])\n product price weight screen\n 0 Macbook Pro $1,299 3.0 lbs 13.3\"\n 1 Macbook Air $999 2.8 lbs 13.3\"\n 2 Surface Pro $799 1.7 lbs 12.3\"\n 3 Surface Laptop $999 2.8 lbs 13.5\"\n 4 Surface Book $1,499 3.3 lbs 13.5\"\n 5 Surface Studio $3,499 21 lbs 28\"" 28 | }, 29 | "codes": [ 30 | "# Convert the price to float inplace\ndf[\"price\"] = df[\"price\"].apply(lambda x: float(x.replace(\"$\", \"\").replace(\",\", \"\")))" 31 | ], 32 | "request": "Compare the average weight of surface products and macbook products?", 33 | "answer": "\n# Import pandas as it seems unimported.\nimport pandas as pd\n\n# Create a new variable to avoid unintentional modifications to raw DataFrame.\ndf_product_weight = pd.DataFrame({\n \"product\": df[\"product\"].apply(lambda x: \"Macbook\" if \"Macbook\" in x else \"Surface\"),\n \"weight\": df[\"weight\"].apply(lambda x: float(x.replace(\" lbs\", \"\"))),\n})\ndf_product_weight.groupby(\"product\").mean()", 34 | "rationale": "To compare the average weight of Surface products and MacBook products, we'll need to:\n\n1. Extract the weight as a numerical value from the \"weight\" column.\n2. Create a new column to classify each product as either 'Surface' or 'Macbook'.\n3. Use the `groupby` function to calculate the average weight for each product group.\n\nWe need to avoid unintentional modifications to the raw DataFrame, as the request didn't ask for in-place changes to the data.", 35 | "answer_wo_intact": "# Import pandas as it seems unimported.\nimport pandas as pd\n\ndf[\"product\"] = df[\"product\"].apply(lambda x: \"Macbook\" if \"Macbook\" in x else \"Surface\")\ndf[\"weight\"] = df[\"weight\"].apply(lambda x: float(x.replace(\" lbs\", \"\")))\ndf.groupby(\"product\").mean()", 36 | "rationale_wo_intact": "To compare the average weight of Surface products and MacBook products, we'll need to:\n\n1. Extract the weight as a numerical value from the \"weight\" column.\n2. Classify each product as either 'Surface' or 'Macbook'.\n3. Use the `groupby` function to calculate the average weight for each product group." 
37 | }, 38 | { 39 | "variables": { 40 | "cars": "pandas.DataFrame(shape=(5, 3), columns=[\"brand\", \"model\", \"price\"])\n brand model price\n 0 BMW X5 60000\n 1 Toyota Camry 25000\n 2 Honda Civic 20000\n 3 Ford Mustang 30000\n 4 Tesla Model 3 40000", 41 | "owners": "pandas.DataFrame(shape=(6, 2), columns=[\"model\", \"owner\"])\n model owner\n 0 Civic Tom\n 1 X5 Tom\n 2 X5 Bob\n 3 Mustang Alice\n 4 Model 3 Alice\n 5 Model 3 Bob" 42 | }, 43 | "codes": [ 44 | "# Compute the average price of each brand\ncars.groupby(\"brand\").mean()" 45 | ], 46 | "request": "Compute the total price of each owner's cars. Return a series with \"Owner\" as index and \"Price\" as value.", 47 | "answer": "cars.merge(owners, on=\"model\").groupby(\"owner\").sum()[\"price\"].rename(\"Price\").rename_axis(\"Owner\")", 48 | "rationale": "To compute the total price of each owner's cars, we need to follow these steps:\n\n1. Merge the `cars` DataFrame with the `owners` DataFrame on the \"model\" column to link each car's price with its owner.\n2. Group the merged DataFrame by the \"owner\" column.\n3. Sum the \"price\" column within each group to get the total price per owner.\n4. Return the resulting series with \"owner\" as index and the total \"price\" as value. The `price` column should be renamed to `Price` and the index should be renamed to `Owner` as required.\n\nLet's write the code to perform these steps:" 49 | }, 50 | { 51 | "variables": { 52 | "df": "pandas.DataFrame(shape=(6, 2), columns=[\"year\", \"sales\"])\n year sales\n 0 2015 100\n 1 2016 120\n 2 2017 150\n 3 2018 180\n 4 2019 200\n 5 2020 250" 53 | }, 54 | "codes": [], 55 | "request": "Visualize the sales trend over years.", 56 | "answer": "import matplotlib.pyplot as plt\nplt.plot(df[\"year\"], df[\"sales\"])", 57 | "rationale": "To visualize the sales trend over the years, we will create a line plot using `matplotlib.pyplot`. This type of plot is ideal for showing how sales figures have changed over time. Here's how we can do it:\n\n1. Import `matplotlib.pyplot` for plotting.\n2. Use the `plot` function to create a line plot with 'year' on the x-axis and 'sales' on the y-axis.\n\nLet me write the code for you." 58 | } 59 | ] -------------------------------------------------------------------------------- /coml/vis_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .verifier import VisVerifier 2 | 3 | 4 | def show_svg(plt): 5 | """Show a plot as a SVG inline.""" 6 | from io import StringIO 7 | 8 | from IPython.display import SVG, display 9 | 10 | f = StringIO() 11 | plt.savefig(f, format="svg") 12 | plt.close() 13 | display(SVG(f.getvalue())) 14 | -------------------------------------------------------------------------------- /install.json: -------------------------------------------------------------------------------- 1 | { 2 | "packageManager": "python", 3 | "packageName": "mlcopilot", 4 | "uninstallInstructions": "Use your Python package manager (pip, conda, etc.) 
to uninstall the package mlcopilot" 5 | } 6 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "coml", 3 | "version": "0.0.8", 4 | "description": "JupyterLab extension for CoML.", 5 | "keywords": [ 6 | "jupyter", 7 | "jupyterlab", 8 | "jupyterlab-extension" 9 | ], 10 | "homepage": "https://github.com/microsoft/CoML", 11 | "bugs": { 12 | "url": "https://github.com/microsoft/CoML/issues" 13 | }, 14 | "license": "MIT", 15 | "author": { 16 | "name": "CoML team", 17 | "email": "coml@microsoft.com" 18 | }, 19 | "files": [ 20 | "lib/**/*.{d.ts,eot,gif,html,jpg,js,js.map,json,png,svg,woff2,ttf}", 21 | "style/**/*.{css,js,eot,gif,html,jpg,json,png,svg,woff2,ttf}" 22 | ], 23 | "main": "lib/index.js", 24 | "types": "lib/index.d.ts", 25 | "repository": { 26 | "type": "git", 27 | "url": "https://github.com/microsoft/CoML.git" 28 | }, 29 | "scripts": { 30 | "build": "jlpm build:lib && jlpm build:labextension:dev", 31 | "build:prod": "jlpm clean && jlpm build:lib:prod && jlpm build:labextension", 32 | "build:labextension": "jupyter labextension build .", 33 | "build:labextension:dev": "jupyter labextension build --development True .", 34 | "build:lib": "tsc --sourceMap", 35 | "build:lib:prod": "tsc", 36 | "clean": "jlpm clean:lib", 37 | "clean:lib": "rimraf lib tsconfig.tsbuildinfo", 38 | "clean:lintcache": "rimraf .eslintcache .stylelintcache", 39 | "clean:labextension": "rimraf jupyterlab_examples_hello_world/labextension jupyterlab_examples_hello_world/_version.py", 40 | "clean:all": "jlpm clean:lib && jlpm clean:labextension && jlpm clean:lintcache", 41 | "eslint": "jlpm eslint:check --fix", 42 | "eslint:check": "eslint . --cache --ext .ts,.tsx", 43 | "install:extension": "jlpm build", 44 | "lint": "jlpm prettier && jlpm eslint", 45 | "lint:check": "jlpm prettier:check && jlpm eslint:check", 46 | "prettier": "jlpm prettier:base --write --list-different", 47 | "prettier:base": "prettier \"**/*{.ts,.tsx,.js,.jsx,.css,.json,.md}\"", 48 | "prettier:check": "jlpm prettier:base --check", 49 | "test": "jest --coverage", 50 | "watch": "run-p watch:src watch:labextension", 51 | "watch:src": "tsc -w --sourceMap", 52 | "watch:labextension": "jupyter labextension watch ." 
53 | }, 54 | "dependencies": { 55 | "@jupyterlab/application": "^4.0.0", 56 | "@jupyterlab/coreutils": "^6.0.0", 57 | "@jupyterlab/services": "^7.0.0" 58 | }, 59 | "devDependencies": { 60 | "@jupyterlab/builder": "^4.0.0", 61 | "@jupyterlab/testutils": "^4.0.0", 62 | "@types/jest": "^29.2.0", 63 | "@types/json-schema": "^7.0.11", 64 | "@types/react": "^18.0.26", 65 | "@typescript-eslint/eslint-plugin": "^5.55.0", 66 | "@typescript-eslint/parser": "^5.55.0", 67 | "css-loader": "^6.7.1", 68 | "eslint": "^8.36.0", 69 | "eslint-config-prettier": "^8.7.0", 70 | "eslint-plugin-prettier": "^4.2.1", 71 | "jest": "^29.2.0", 72 | "npm-run-all": "^4.1.5", 73 | "prettier": "^2.8.7", 74 | "rimraf": "^4.4.1", 75 | "source-map-loader": "^1.0.2", 76 | "typescript": "~5.0.2", 77 | "yjs": "^13.5.0" 78 | }, 79 | "publishConfig": { 80 | "access": "public" 81 | }, 82 | "jupyterlab": { 83 | "extension": true, 84 | "outputDir": "coml/labextension" 85 | }, 86 | "eslintIgnore": [ 87 | "node_modules", 88 | "dist", 89 | "coverage", 90 | "**/*.d.ts", 91 | "tests", 92 | "**/__tests__", 93 | "ui-tests" 94 | ], 95 | "eslintConfig": { 96 | "extends": [ 97 | "eslint:recommended", 98 | "plugin:@typescript-eslint/eslint-recommended", 99 | "plugin:@typescript-eslint/recommended", 100 | "plugin:prettier/recommended" 101 | ], 102 | "parser": "@typescript-eslint/parser", 103 | "parserOptions": { 104 | "project": "tsconfig.json", 105 | "sourceType": "module" 106 | }, 107 | "plugins": [ 108 | "@typescript-eslint" 109 | ], 110 | "rules": { 111 | "@typescript-eslint/naming-convention": [ 112 | "error", 113 | { 114 | "selector": "interface", 115 | "format": [ 116 | "PascalCase" 117 | ], 118 | "custom": { 119 | "regex": "^I[A-Z]", 120 | "match": true 121 | } 122 | } 123 | ], 124 | "@typescript-eslint/no-unused-vars": [ 125 | "warn", 126 | { 127 | "args": "none" 128 | } 129 | ], 130 | "@typescript-eslint/no-explicit-any": "off", 131 | "@typescript-eslint/no-namespace": "off", 132 | "@typescript-eslint/no-use-before-define": "off", 133 | "@typescript-eslint/quotes": [ 134 | "error", 135 | "single", 136 | { 137 | "avoidEscape": true, 138 | "allowTemplateLiterals": false 139 | } 140 | ], 141 | "curly": [ 142 | "error", 143 | "all" 144 | ], 145 | "eqeqeq": "error", 146 | "prefer-arrow-callback": "error" 147 | } 148 | }, 149 | "prettier": { 150 | "singleQuote": true, 151 | "trailingComma": "none", 152 | "arrowParens": "avoid", 153 | "endOfLine": "auto" 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling>=1.5.0", "jupyterlab>=4.0.0,<5", "hatch-nodejs-version"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "mlcopilot" 7 | version = "0.0.8" 8 | dependencies = [ 9 | "click", 10 | "colorama", 11 | "langchain", 12 | "langchain-community", 13 | "langchain-openai", 14 | "numpy", 15 | "orjson", 16 | "pandas", 17 | "peewee", 18 | "python-dotenv", 19 | "psycopg2-binary", 20 | "scikit_learn", 21 | "tiktoken", 22 | "xgboost", 23 | "typing-inspect", 24 | "typing_extensions", 25 | "ipython", 26 | "ipywidgets", 27 | "ipylab", 28 | "pylint", 29 | "markdown", 30 | ] 31 | description = 'Assistant for data scientists and machine learning developers.' 
32 | readme = "README.md" 33 | license = { file = "LICENSE" } 34 | requires-python = ">= 3.8" 35 | authors = [ 36 | { name = "CoML team", email = "coml@microsoft.com" }, 37 | ] 38 | classifiers = [ 39 | "Development Status :: 4 - Beta", 40 | "Framework :: Jupyter", 41 | "Framework :: Jupyter :: JupyterLab", 42 | "Framework :: Jupyter :: JupyterLab :: 4", 43 | "Framework :: Jupyter :: JupyterLab :: Extensions", 44 | "Framework :: Jupyter :: JupyterLab :: Extensions :: Prebuilt", 45 | "Programming Language :: Python", 46 | "Programming Language :: Python :: 3.8", 47 | "Programming Language :: Python :: 3.9", 48 | "Programming Language :: Python :: 3.10", 49 | "Programming Language :: Python :: 3.11", 50 | ] 51 | 52 | [project.urls] 53 | Documentation = "https://github.com/microsoft/coml#readme" 54 | Issues = "https://github.com/microsoft/coml/issues" 55 | Source = "https://github.com/microsoft/coml" 56 | 57 | [project.scripts] 58 | coml-configagent = "coml.configagent.cli:main" 59 | 60 | [tool.hatch.version] 61 | source = "nodejs" 62 | 63 | [tool.hatch.build] 64 | include = [ 65 | "/coml" 66 | ] 67 | 68 | [tool.hatch.metadata.hooks.nodejs] 69 | fields = ["description"] 70 | 71 | [tool.hatch.build.targets.sdist] 72 | artifacts = ["coml/labextension"] 73 | exclude = [".github", "binder"] 74 | 75 | [tool.hatch.build.targets.wheel.shared-data] 76 | "coml/labextension" = "share/jupyter/labextensions/coml" 77 | "install.json" = "share/jupyter/labextensions/coml/install.json" 78 | 79 | [tool.hatch.build.hooks.version] 80 | path = "coml/_version.py" 81 | 82 | [tool.hatch.build.hooks.jupyter-builder] 83 | dependencies = ["hatch-jupyter-builder>=0.5"] 84 | build-function = "hatch_jupyter_builder.npm_builder" 85 | ensured-targets = [ 86 | "coml/labextension/package.json", 87 | ] 88 | 89 | [tool.hatch.build.hooks.jupyter-builder.build-kwargs] 90 | build_cmd = "build:prod" 91 | npm = ["jlpm"] 92 | 93 | [tool.hatch.build.hooks.jupyter-builder.editable-build-kwargs] 94 | build_cmd = "install:extension" 95 | npm = ["jlpm"] 96 | source_dir = "src" 97 | build_dir = "coml/labextension" 98 | 99 | [tool.jupyter-releaser.options] 100 | version_cmd = "hatch version" 101 | 102 | [tool.jupyter-releaser.hooks] 103 | before-build-npm = [ 104 | "python -m pip install 'jupyterlab>=4.0.0,<5'", 105 | "jlpm", 106 | "jlpm build:prod" 107 | ] 108 | before-build-python = ["jlpm clean:all"] 109 | 110 | [tool.black] 111 | line-length = 88 112 | target-version = ['py310'] 113 | include = '\.pyi?$' 114 | extend-exclude = '.+/(dist|.venv|venv|build)/.+' 115 | 116 | [tool.isort] 117 | profile = "black" 118 | multi_line_output = 3 119 | include_trailing_comma = true 120 | force_grid_wrap = 0 121 | use_parentheses = true 122 | ensure_newline_before_comments = true 123 | line_length = 88 124 | sections = [ 125 | "FUTURE", 126 | "STDLIB", 127 | "THIRDPARTY", 128 | "FIRSTPARTY", 129 | "LOCALFOLDER" 130 | ] 131 | skip = ''' 132 | .tox 133 | __pycache__ 134 | *.pyc 135 | .env 136 | venv*/* 137 | .venv/* 138 | reports/* 139 | dist/* 140 | node_modules 141 | ''' 142 | 143 | [project.optional-dependencies] 144 | dev = [ 145 | "pytest", 146 | "flake8", 147 | "pre-commit", 148 | "black", 149 | "isort", 150 | ] 151 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | import { 2 | JupyterFrontEnd, 3 | JupyterFrontEndPlugin 4 | } from '@jupyterlab/application'; 5 | 6 | import { ISessionContext } from 
'@jupyterlab/apputils'; 7 | import { 8 | INotebookTracker, 9 | Notebook, 10 | NotebookActions 11 | } from '@jupyterlab/notebook'; 12 | import { OutputArea } from '@jupyterlab/outputarea'; 13 | import { Kernel, KernelMessage } from '@jupyterlab/services'; 14 | import { CodeCell } from '@jupyterlab/cells'; 15 | 16 | function findCellByOutputArea( 17 | notebook: Notebook, 18 | outputArea: OutputArea 19 | ): number { 20 | // We handle the situation when Shift+Enter is pressed and the active cell is not the last one. 21 | const cells = notebook.widgets; 22 | 23 | const matches: number[] = []; 24 | for (let i = 0; i < cells.length; i++) { 25 | const cell = cells[i]; 26 | if ( 27 | cell.model.type === 'code' && 28 | (cell as CodeCell).outputArea === outputArea 29 | ) { 30 | matches.push(i); 31 | } 32 | } 33 | 34 | if (matches.length === 1) { 35 | return matches[0]; 36 | } else if (matches.length > 1) { 37 | // Multiple matches, we take the one closest to (but not exceeding) the active cell. 38 | let bestIndex = -1; 39 | for (const index of matches) { 40 | if (bestIndex === -1 || index <= notebook.activeCellIndex) { 41 | bestIndex = index; 42 | } else { 43 | break; 44 | } 45 | } 46 | return bestIndex; 47 | } else { 48 | // When something is wrong, we fall back to the active cell. 49 | return notebook.activeCell ? notebook.activeCellIndex : -1; 50 | } 51 | } 52 | 53 | function getNotebookContext( 54 | app: JupyterFrontEnd, 55 | notebookTracker: INotebookTracker | null 56 | ): ISessionContext | undefined { 57 | if (!notebookTracker?.currentWidget) { 58 | return undefined; 59 | } 60 | if (notebookTracker.currentWidget !== app.shell.currentWidget) { 61 | return undefined; 62 | } 63 | return notebookTracker.currentWidget.context.sessionContext; 64 | } 65 | 66 | function getNotebook( 67 | app: JupyterFrontEnd, 68 | notebookTracker: INotebookTracker | null 69 | ): Notebook | undefined { 70 | if (!notebookTracker?.currentWidget) { 71 | return undefined; 72 | } 73 | if (notebookTracker.currentWidget !== app.shell.currentWidget) { 74 | return undefined; 75 | } 76 | return notebookTracker.currentWidget.content; 77 | } 78 | 79 | function insertCellBelow( 80 | notebook: Notebook, 81 | activeCellIndex: number, // Active cell index from notebook is not necessarily reliable 82 | code: string, 83 | chatMetadata: any = {}, 84 | editorFocus = true 85 | ): void { 86 | notebook.model?.sharedModel.insertCell(activeCellIndex + 1, { 87 | cell_type: 'code', 88 | metadata: { 89 | coml: chatMetadata 90 | }, 91 | source: code 92 | }); 93 | notebook.activeCellIndex = activeCellIndex + 1; 94 | if (editorFocus) { 95 | notebook.activeCell?.editor?.focus(); 96 | } 97 | } 98 | 99 | function getLastCell(notebook: Notebook, currentCellIndex: number) { 100 | if (currentCellIndex <= 0) { 101 | console.warn( 102 | `Current cell index is ${currentCellIndex}.
No last cell found.` 103 | ); 104 | return null; 105 | } else { 106 | const lastCell = notebook.widgets[currentCellIndex - 1]; 107 | return lastCell.model; 108 | } 109 | } 110 | 111 | function getCurrentCell(notebook: Notebook, currentCellIndex: number) { 112 | if (currentCellIndex < 0) { 113 | console.warn(`Invalid current cell index: ${currentCellIndex}.`); 114 | return null; 115 | } else { 116 | const currentCell = notebook.widgets[currentCellIndex]; 117 | return currentCell.model; 118 | } 119 | } 120 | 121 | const plugin: JupyterFrontEndPlugin<void> = { 122 | id: 'coml:plugin', 123 | description: 'JupyterLab extension for CoML.', 124 | autoStart: true, 125 | optional: [INotebookTracker], 126 | activate: ( 127 | app: JupyterFrontEnd, 128 | notebookTracker: INotebookTracker | null 129 | ) => { 130 | function handleCommand( 131 | outputArea: OutputArea, 132 | command: any, 133 | sendCallback: (msg: string) => void 134 | ): any { 135 | if (command['command'] === 'insert_cell_below') { 136 | // Command format: { "command": "insert_cell_below", "code": "print('hello')", "metadata": { "request": ... } } 137 | const notebook = getNotebook(app, notebookTracker); 138 | if (notebook) { 139 | const insertIndex = findCellByOutputArea(notebook, outputArea); 140 | insertCellBelow( 141 | notebook, 142 | insertIndex, 143 | command['code'], 144 | command['metadata'] 145 | ); 146 | } else { 147 | console.warn('No notebook found'); 148 | } 149 | // Reply with empty string to indicate that the command is handled. 150 | return ''; 151 | } else if (command['command'] === 'insert_and_execute_cell_below') { 152 | // Command format: { "command": "insert_and_execute_cell_below", "code": "print('hello')", "metadata": { "request": ... } } 153 | const notebook = getNotebook(app, notebookTracker); 154 | const sessionContext = getNotebookContext(app, notebookTracker); 155 | if (notebook && sessionContext) { 156 | const insertIndex = findCellByOutputArea(notebook, outputArea); 157 | insertCellBelow( 158 | notebook, 159 | insertIndex, 160 | command['code'], 161 | command['metadata'], 162 | false 163 | ); 164 | // Reply must be sent before running the next cell to avoid the deadlock warning. 165 | sendCallback(''); 166 | // Run active cell 167 | NotebookActions.run(notebook, sessionContext); 168 | } else { 169 | console.warn('No notebook or session context found'); 170 | // Reply with empty string to indicate that the command is handled.
171 | return ''; 172 | } 173 | } else if (command['command'] === 'last_cell') { 174 | // Command format: { "command": "last_cell" } 175 | const notebook = getNotebook(app, notebookTracker); 176 | if (notebook) { 177 | const currentCellIndex = findCellByOutputArea(notebook, outputArea); 178 | const lastCell = getLastCell(notebook, currentCellIndex); 179 | if (lastCell) { 180 | return JSON.stringify(lastCell.toJSON()); 181 | } else { 182 | console.warn('No last cell found'); 183 | } 184 | } else { 185 | console.warn('No notebook found'); 186 | } 187 | } else if (command['command'] === 'running_cell') { 188 | // Command format: { "command": "running_cell" } 189 | const notebook = getNotebook(app, notebookTracker); 190 | if (notebook) { 191 | const currentCellIndex = findCellByOutputArea(notebook, outputArea); 192 | const cell = getCurrentCell(notebook, currentCellIndex); 193 | if (cell) { 194 | return JSON.stringify(cell.toJSON()); 195 | } else { 196 | console.warn('No running cell found'); 197 | } 198 | } else { 199 | console.warn('No notebook found'); 200 | } 201 | } else if (command['command'] === 'update_running_cell_metadata') { 202 | // Command format: { "command": "update_running_cell_metadata", "metadata": ... } 203 | const notebook = getNotebook(app, notebookTracker); 204 | if (notebook) { 205 | const currentCellIndex = findCellByOutputArea(notebook, outputArea); 206 | const cell = getCurrentCell(notebook, currentCellIndex); 207 | if (cell) { 208 | cell.setMetadata('coml', command['metadata']); 209 | return ''; 210 | } else { 211 | console.warn('No running cell found'); 212 | } 213 | } else { 214 | console.warn('No notebook found'); 215 | } 216 | } else { 217 | console.warn('Invalid command:', command); 218 | return undefined; 219 | } 220 | return ''; 221 | } 222 | 223 | function hackedOnInputRequest( 224 | this: OutputArea, 225 | msg: KernelMessage.IInputRequestMsg, 226 | future: Kernel.IShellFuture 227 | ): void { 228 | // This is the hacked version of the handler for `input()` requests issued at the kernel side. 229 | // Everything that needs to be done at the JS side is first sent here and routed within this method.
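// A typical round-trip, sketched from the command formats documented in handleCommand above:
//   kernel side:  input('{"command": "running_cell"}')  -- the JSON string arrives here as msg.content.prompt
//   JS side:      handleCommand() returns the serialized cell, and sendCallback() delivers it as the input reply
// Prompts that are not valid JSON fall through to the native onInputRequest handler (see the catch block below).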
230 | 231 | let sent = false; 232 | function sendCallback(reply: string) { 233 | if (sent) { 234 | console.warn('Reply already sent.'); 235 | } else { 236 | future.sendInputReply({ status: 'ok', value: reply }, msg.header); 237 | sent = true; 238 | } 239 | } 240 | 241 | try { 242 | // only apply the hack if the command is valid JSON 243 | const command = JSON.parse(msg.content.prompt); 244 | const result = handleCommand(this, command, sendCallback); 245 | if (result !== undefined) { 246 | if (!sent) { 247 | sendCallback(result); 248 | } 249 | } else { 250 | return (this as any).nativeOnInputRequest(msg, future); 251 | } 252 | } catch (err) { 253 | console.log('Not a JSON command', msg, err); 254 | return (this as any).nativeOnInputRequest(msg, future); 255 | } 256 | } 257 | 258 | (OutputArea.prototype as any).nativeOnInputRequest = ( 259 | OutputArea.prototype as any 260 | ).onInputRequest; 261 | (OutputArea.prototype as any).onInputRequest = hackedOnInputRequest; 262 | 263 | app.commands.addCommand('coml:insert_cell_below', { 264 | label: 'Execute coml:insert_cell_below Command', 265 | caption: 'Execute coml:insert_cell_below Command', 266 | execute: (args: any) => { 267 | const notebook = getNotebook(app, notebookTracker); 268 | if (!notebook) { 269 | console.warn('No notebook found'); 270 | return; 271 | } 272 | insertCellBelow( 273 | notebook, 274 | notebook.activeCellIndex, 275 | args['code'], 276 | args['metadata'] 277 | ); 278 | } 279 | }); 280 | 281 | app.commands.addCommand('coml:insert_and_execute_cell_below', { 282 | label: 'Execute coml:insert_and_execute_cell_below Command', 283 | caption: 'Execute coml:insert_and_execute_cell_below Command', 284 | execute: (args: any) => { 285 | const notebook = getNotebook(app, notebookTracker); 286 | const sessionContext = getNotebookContext(app, notebookTracker); 287 | if (!notebook || !sessionContext) { 288 | console.warn('Notebook or session context not found'); 289 | return; 290 | } 291 | insertCellBelow( 292 | notebook, 293 | notebook.activeCellIndex, 294 | args['code'], 295 | args['metadata'] 296 | ); 297 | NotebookActions.run(notebook, sessionContext); 298 | } 299 | }); 300 | } 301 | }; 302 | 303 | export default plugin; 304 | -------------------------------------------------------------------------------- /test/configagent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoML/46f7452d70021be7702d05343096cdc1adec09ef/test/configagent/__init__.py -------------------------------------------------------------------------------- /test/configagent/helper.py: -------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | from typing import Any, List, Optional 3 | 4 | from langchain.callbacks.manager import CallbackManagerForLLMRun 5 | from langchain.embeddings import FakeEmbeddings 6 | from langchain.llms.fake import FakeListLLM 7 | 8 | from coml.configagent.constants import EMBED_DIM 9 | 10 | 11 | class MockKnowledgeLLM(FakeListLLM): 12 | responses: List[str] = ["This is a mock knowledge."] 13 | 14 | def _call( 15 | self, 16 | prompt: str, 17 | stop: Optional[List[str]] = None, 18 | run_manager: Optional[CallbackManagerForLLMRun] = None, 19 | ) -> str: 20 | response = self.responses[0] 21 | return response 22 | 23 | 24 | class MockSuggestLLM(FakeListLLM): 25 | responses: List[str] = [ 26 | ( 27 | "Configuration 1: cost is very small. gamma is very small. kernel is linear. 
degree is very small.\n" 28 | "Configuration 2: cost is very small. gamma is very small. kernel is linear. degree is very small.\n" 29 | "Configuration 3: cost is very small. gamma is very small. kernel is linear. degree is very small." 30 | ) 31 | ] 32 | 33 | def _call( 34 | self, 35 | prompt: str, 36 | stop: Optional[List[str]] = None, 37 | run_manager: Optional[CallbackManagerForLLMRun] = None, 38 | ) -> str: 39 | response = self.responses[0] 40 | return response 41 | 42 | 43 | class MockEmbeddingModel(FakeEmbeddings): 44 | size: int = 1536 45 | 46 | def _get_embedding(self, text) -> List[float]: 47 | md5_10 = int(md5(text.encode("utf-8")).hexdigest(), 16) 48 | return [md5_10 // 10**i % 10 for i in range(10)] + [0.0] * (EMBED_DIM - 10) 49 | 50 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 51 | return [self._get_embedding(text) for text in texts] 52 | 53 | def embed_query(self, text: str) -> List[float]: 54 | return self._get_embedding(text) 55 | -------------------------------------------------------------------------------- /test/configagent/test_experience.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from peewee import JOIN, fn 5 | 6 | from coml.configagent.constants import COML_DB_BACKEND, TOP_K 7 | from coml.configagent.experience import ( 8 | _ingest_solution, 9 | _ingest_space, 10 | _ingest_task, 11 | canonicalize_config, 12 | gen_experience, 13 | gen_experience_per_task, 14 | get_quantile_stat, 15 | ingest_experience, 16 | ) 17 | from coml.configagent.orm import Solution, Task 18 | from coml.configagent.space import create_tables, drop_tables 19 | from coml.configagent.utils import set_llms 20 | 21 | from .helper import MockEmbeddingModel, MockKnowledgeLLM 22 | 23 | history_df = pd.read_csv("assets/example_history.csv") 24 | 25 | 26 | def test_get_quantile_stat(): 27 | experience_df = history_df[ 28 | history_df[["TASK_ID", "SCORE"]] 29 | .groupby("TASK_ID") 30 | .rank(method="first", ascending=False)["SCORE"] 31 | <= TOP_K 32 | ] 33 | quantile_info = get_quantile_stat(experience_df) 34 | assert quantile_info == { 35 | "cost": [ 36 | 0.04495582729578018, 37 | 2.966309070587158, 38 | 8.799302101135254, 39 | 53.32119369506836, 40 | 223.55349731445312, 41 | ], 42 | "gamma": [ 43 | 0.006136692129075527, 44 | 0.03009503148496151, 45 | 0.1399584263563156, 46 | 1.139962911605835, 47 | 2.557131290435791, 48 | ], 49 | "degree": [2.0, 2.0, 3.0, 3.0, 3.0], 50 | } 51 | return quantile_info 52 | 53 | 54 | def test_ingest_space(): 55 | space_id = "__test_space__" 56 | space_desc = "This is space description" 57 | quantile_info = test_get_quantile_stat() 58 | space = _ingest_space(space_id, space_desc, quantile_info) 59 | assert space is not None 60 | return space, quantile_info 61 | 62 | 63 | def test_ingest_task(): 64 | set_llms(embedding_model=MockEmbeddingModel) 65 | task_desc = {k: "." for k in history_df["TASK_ID"].unique()} 66 | _ingest_task(history_df, task_desc) 67 | 68 | 69 | def test_canonicalize_config(): 70 | config_term = { 71 | "cost": 62.72359795339274, 72 | "gamma": 0.0051534831825209, 73 | "kernel": "radial", 74 | "degree": np.nan, 75 | } 76 | _, quantile_info = test_ingest_space() 77 | demo = canonicalize_config(config_term, quantile_info) 78 | assert ( 79 | demo 80 | == "cost is large. gamma is very small. kernel is radial. degree is very small." 
81 | ) 82 | return demo 83 | 84 | 85 | def test_ingest_solution(): 86 | space, _ = test_ingest_space() 87 | _ingest_solution(history_df, space) 88 | 89 | 90 | def test_ingest_experience(): 91 | space_desc = "This is space description" 92 | space_id = "__test_space__" 93 | space = ingest_experience( 94 | history_df, 95 | {task_id: task_id for task_id in history_df.TASK_ID.unique().astype(str)}, 96 | space_desc, 97 | space_id, 98 | ) 99 | assert space is not None 100 | return space 101 | 102 | 103 | def test_gen_experience_per_task(): 104 | space, _ = test_ingest_space() 105 | test_ingest_task() 106 | task = Task.get(Task.task_id == "43") 107 | experience_per_task = gen_experience_per_task(space, task) 108 | assert isinstance(experience_per_task, str) 109 | return experience_per_task 110 | 111 | 112 | def test_gen_experience(): 113 | set_llms(embedding_model=MockEmbeddingModel) 114 | drop_tables() 115 | create_tables() 116 | task_desc = "test task description" 117 | space = test_ingest_experience() 118 | order_key = Task.embedding.cosine_distance(task_desc) 119 | subquery = ( 120 | Task.select(Task.task_id) 121 | .join(Solution) 122 | .where(Solution.space == space) 123 | .distinct() 124 | ) 125 | 126 | tasks_select = ( 127 | Task.select() 128 | .join(subquery, JOIN.LEFT_OUTER, on=(Task.task_id == subquery.c.task_id)) 129 | .order_by(order_key) 130 | ) 131 | examples_gt = [gen_experience_per_task(space, task) for task in tasks_select] 132 | 133 | retrieved_tasks, examples = gen_experience(space, task_desc) 134 | 135 | assert all(examples[i][:10] == examples_gt[i][:10] for i in range(len(examples))) 136 | -------------------------------------------------------------------------------- /test/configagent/test_knowledge.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pandas as pd 4 | import pytest 5 | 6 | from coml.configagent.knowledge import ( 7 | gen_knowledge_candidate, 8 | post_validation, 9 | split_knowledge, 10 | suggest_with_knowledge, 11 | ) 12 | from coml.configagent.space import create_space, delete_space 13 | from coml.configagent.surrogate_utils import process_history_df, train_surrogate 14 | from coml.configagent.utils import set_llms 15 | 16 | from .helper import MockEmbeddingModel, MockKnowledgeLLM, MockSuggestLLM 17 | 18 | examples = [ 19 | "Task: .\n" 20 | "Configuration 1: cost is large. gamma is small. kernel is radial. degree is very small.\n" 21 | "Configuration 2: cost is very large. gamma is small. kernel is radial. degree is very small.\n" 22 | "Configuration 3: cost is medium. gamma is medium. kernel is radial. degree is very small." 23 | ] 24 | 25 | 26 | def test_gen_knowledge_candidate(): 27 | set_llms(knowledge_model=MockKnowledgeLLM) 28 | knowledge = gen_knowledge_candidate(examples) 29 | assert knowledge == "\n1.This is a mock knowledge." 
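# The knowledge candidate is returned so that test_eval_knowledge below can chain off this test and reuse it as input.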
30 | return knowledge 31 | 32 | 33 | def test_eval_knowledge(): 34 | knowledge_candidate = test_gen_knowledge_candidate() 35 | quantile_info = { 36 | "cost": [ 37 | 0.04495582729578018, 38 | 2.966309070587158, 39 | 8.799302101135254, 40 | 53.32119369506836, 41 | 223.55349731445312, 42 | ], 43 | "gamma": [ 44 | 0.006136692129075527, 45 | 0.03009503148496151, 46 | 0.1399584263563156, 47 | 1.139962911605835, 48 | 2.557131290435791, 49 | ], 50 | "degree": [2.0, 2.0, 3.0, 3.0, 3.0], 51 | } 52 | set_llms(suggest_model=MockSuggestLLM) 53 | suggest_configs = suggest_with_knowledge( 54 | examples, knowledge_candidate, examples[0], quantile_info 55 | ) 56 | assert len(suggest_configs) == 3 57 | return suggest_configs 58 | 59 | 60 | def test_post_validation(): 61 | history_path = "assets/example_history.csv" 62 | space_id = "__test_space__" 63 | delete_space(space_id) 64 | space = create_space( 65 | space_id=space_id, 66 | history=history_path, 67 | task_desc="assets/example_descriptions.json", 68 | no_knowledge=True, 69 | ) 70 | history_df = pd.read_csv(history_path) 71 | history_df_processed, config_names = process_history_df(history_df) 72 | surrogate = train_surrogate(history_df_processed) 73 | set_llms(suggest_model=MockSuggestLLM, knowledge_model=MockKnowledgeLLM) 74 | space_id = "__test_space__" 75 | knowledges = post_validation(space, surrogate, config_names) 76 | delete_space(space_id) 77 | assert set(knowledges) == set(split_knowledge(f"1. {MockKnowledgeLLM()('')}")) 78 | return knowledges 79 | -------------------------------------------------------------------------------- /test/configagent/test_space.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | from coml.configagent.space import ( 5 | create_space, 6 | database_proxy, 7 | delete_space, 8 | drop_tables, 9 | gen_space_description, 10 | list_available_spaces, 11 | print_space, 12 | ) 13 | from coml.configagent.utils import set_llms 14 | 15 | from .helper import MockEmbeddingModel, MockKnowledgeLLM, MockSuggestLLM 16 | 17 | 18 | def test_gen_space_description(): 19 | space_desc = gen_space_description( 20 | pd.read_csv("assets/example_history.csv"), 21 | ) 22 | assert space_desc == ( 23 | "Space has 4 configurable hyper-parameters, i.e., 'cost', 'gamma', 'kernel', 'degree'.\n" 24 | ) 25 | return space_desc 26 | 27 | 28 | def test_delete_space(): 29 | space_id = "__test_space__" 30 | delete_space(space_id) 31 | available_spaces = list_available_spaces() 32 | assert space_id not in (space.space_id for space in available_spaces) 33 | 34 | 35 | def test_create_space(): 36 | space_id = "__test_space__" 37 | delete_space(space_id) 38 | available_spaces = list_available_spaces() 39 | set_llms(knowledge_model=MockKnowledgeLLM) 40 | create_space( 41 | space_id=space_id, 42 | history="assets/example_history.csv", 43 | task_desc="assets/example_descriptions.json", 44 | no_knowledge=True, 45 | ) 46 | assert len(list_available_spaces()) == 1 + len(available_spaces) 47 | assert space_id in (space.space_id for space in list_available_spaces()) 48 | delete_space(space_id) 49 | -------------------------------------------------------------------------------- /test/configagent/test_suggest.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import orjson 4 | import pytest 5 | 6 | from coml.configagent.constants import TOP_K, inverse_bin_map 7 | from coml.configagent.knowledge import split_knowledge 8 | from 
coml.configagent.orm import Knowledge 9 | from coml.configagent.space import create_space, delete_space 10 | from coml.configagent.suggest import suggest 11 | from coml.configagent.utils import parse_configs, set_llms 12 | 13 | from .helper import MockEmbeddingModel, MockKnowledgeLLM, MockSuggestLLM 14 | 15 | 16 | def _create_space(): 17 | space_id = "__test_space__" 18 | delete_space(space_id) 19 | set_llms(knowledge_model=MockKnowledgeLLM, embedding_model=MockEmbeddingModel) 20 | space = create_space( 21 | space_id=space_id, 22 | history="assets/example_history.csv", 23 | task_desc="assets/example_descriptions.json", 24 | no_knowledge=True, 25 | ) 26 | return space 27 | 28 | 29 | def test_suggest_with_few_shot_no_knowledge(): 30 | space = _create_space() 31 | set_llms(suggest_model=MockSuggestLLM, embedding_model=MockEmbeddingModel) 32 | suggest_configs, knowledge = suggest(space, "") 33 | delete_space(space.space_id) 34 | assert len(suggest_configs) == 3 35 | assert knowledge == "" 36 | 37 | 38 | def test_suggest_with_few_shot_with_knowledge(): 39 | space = _create_space() 40 | knowledges_ = split_knowledge(f"1. {MockKnowledgeLLM()('')}") 41 | try: 42 | for knowledge_ in knowledges_: 43 | Knowledge.create(space_id=space.space_id, knowledge=knowledge_) 44 | except: 45 | pass 46 | set_llms(suggest_model=MockSuggestLLM, embedding_model=MockEmbeddingModel) 47 | suggest_configs, knowledge = suggest(space, "") 48 | delete_space(space.space_id) 49 | assert len(suggest_configs) == 3 50 | knowledge_str = "" 51 | for i, knowledge_ in enumerate(knowledges_): 52 | knowledge_str += f"{i+1}. {knowledge_}\n\n" 53 | assert knowledge == knowledge_str 54 | -------------------------------------------------------------------------------- /test/configagent/test_surrogate_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from coml.configagent.surrogate_utils import ( 5 | evaluate_configs, 6 | process_history_df, 7 | train_surrogate, 8 | ) 9 | 10 | 11 | def test_process_history_df(): 12 | history_df = pd.read_csv("assets/example_history.csv") 13 | history_df_processed, config_names = process_history_df(history_df) 14 | assert history_df_processed.shape == (340, 7) 15 | assert set(config_names) == { 16 | "cost", 17 | "gamma", 18 | "degree", 19 | "kernel__DUMM_SEP__linear", 20 | "kernel__DUMM_SEP__polynomial", 21 | "kernel__DUMM_SEP__radial", 22 | } 23 | return history_df_processed, config_names 24 | 25 | 26 | def test_train_surrogate(): 27 | history_df_processed, config_names = test_process_history_df() 28 | surrogate = train_surrogate(history_df_processed) 29 | assert surrogate.predict(np.random.rand(10, 6)).shape == (10,) 30 | return surrogate, config_names 31 | 32 | 33 | def test_evaluate_configs(): 34 | surrogate, config_names = test_train_surrogate() 35 | configs = [ 36 | {"cost": 0.1, "gamma": 10, "kernel": "linear"}, 37 | {"cost": 10, "gamma": 0.1, "degree": 5, "kernel": "polynomial"}, 38 | {"cost": 5, "gamma": 2, "kernel": "radial"}, 39 | ] 40 | scores = evaluate_configs(surrogate, configs, config_names) 41 | assert scores.shape == (3,) 42 | return scores 43 | -------------------------------------------------------------------------------- /test/configagent/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from coml.configagent.utils import format_config, parse_configs 4 | 5 | 6 | def test_parser_configs(): 7 | configs = parse_configs( 8 
| "Configuration 1: learning rate is 0.01. batch size is 32.\n" 9 | "Configuration 2: learning rate is 0.02. batch size is 64.\n" 10 | "Configuration 3: learning rate is 0.03. batch size is 128.\n", 11 | 3, 12 | ) 13 | assert configs == [ 14 | {"learning rate": "0.01", "batch size": "32"}, 15 | {"learning rate": "0.02", "batch size": "64"}, 16 | {"learning rate": "0.03", "batch size": "128"}, 17 | ] 18 | 19 | 20 | def test_format_config(): 21 | config = format_config( 22 | {"learning rate": "0.01", "batch size": "32"}, 23 | ) 24 | assert config == "learning rate is 0.01. batch size is 32." 25 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "allowSyntheticDefaultImports": true, 4 | "composite": true, 5 | "declaration": true, 6 | "esModuleInterop": true, 7 | "incremental": true, 8 | "jsx": "react", 9 | "module": "esnext", 10 | "moduleResolution": "node", 11 | "noEmitOnError": true, 12 | "noImplicitAny": true, 13 | "noUnusedLocals": true, 14 | "preserveWatchOutput": true, 15 | "resolveJsonModule": true, 16 | "outDir": "lib", 17 | "rootDir": "src", 18 | "strict": true, 19 | "strictNullChecks": true, 20 | "target": "ES2018" 21 | }, 22 | "include": ["src/*"] 23 | } 24 | --------------------------------------------------------------------------------