├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE.md ├── OPEN_SOURCE_LICENSES.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── lint.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets └── kaia_llama.webp ├── mle ├── __init__.py ├── agents │ ├── __init__.py │ ├── advisor.py │ ├── chat.py │ ├── coder.py │ ├── debugger.py │ ├── planner.py │ ├── reporter.py │ └── summarizer.py ├── cli.py ├── function │ ├── __init__.py │ ├── data.py │ ├── execution.py │ ├── files.py │ ├── interaction.py │ └── search.py ├── integration │ ├── __init__.py │ ├── github.py │ ├── google_calendar.py │ ├── kaggle.py │ └── local_git.py ├── model │ ├── __init__.py │ ├── anthropic.py │ ├── common.py │ ├── deepseek.py │ ├── gemini.py │ ├── mistral.py │ ├── ollama.py │ ├── openai.py │ └── vllm.py ├── server │ ├── __init__.py │ └── app.py ├── utils │ ├── __init__.py │ ├── cache.py │ ├── chunk.py │ ├── component_memory.py │ ├── data.py │ ├── memory.py │ ├── parser.py │ └── system.py ├── version.py └── workflow │ ├── __init__.py │ ├── baseline.py │ ├── chat.py │ ├── kaggle.py │ └── report.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests └── __init__.py └── web ├── .eslintrc.json ├── .gitignore ├── README.md ├── app ├── fonts │ ├── GeistMonoVF.woff │ └── GeistVF.woff ├── globals.css ├── layout.tsx └── page.tsx ├── next.config.mjs ├── package-lock.json ├── package.json ├── pnpm-lock.yaml ├── postcss.config.mjs ├── public ├── favicon.ico └── logo.png ├── tailwind.config.ts └── tsconfig.json /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | @huangyz0918 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 10 | 11 | #### Software execution information 12 | 13 | MLE-agent version: 14 | System OS version: 15 | 16 | #### Problem description 17 | 18 | #### Steps to reproduce the problem 19 | 20 | #### Expected behavior 21 | 22 | #### Other information 23 | 24 | Things you tried, stack traces, related issues, suggestions on how to fix it... -------------------------------------------------------------------------------- /.github/OPEN_SOURCE_LICENSES.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/.github/OPEN_SOURCE_LICENSES.md -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Closes # 2 | 3 | 6 | 7 | #### What has been done to verify that this works as intended? 8 | 9 | #### Why is this the best possible solution? Were any other approaches considered? 10 | 11 | #### How does this change affect users? Describe intentional changes to behavior and behavior that could have accidentally been affected by code changes. In other words, what are the regression risks? 12 | 13 | #### Do we need any specific form for testing your changes? If so, please attach one. 14 | 15 | #### Does this change require updates to documentation? If so, please file an issue [here](https://github.com/MLSysOps/MLE-agent/issues/new) and include the link below. 16 | 17 | #### Before submitting this PR, please make sure you have: 18 | 19 | - [ ] confirmed all checks still pass OR confirm CI build passes. 
20 | - [ ] verified that any code or assets from external sources are properly credited in comments and/or in 21 | the [credit file](https://github.com/MLSysOps/MLE-agent/blob/main/.github/OPEN_SOURCE_LICENSES.md). -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - dev 7 | - main 8 | pull_request: 9 | branches: 10 | - dev 11 | - main 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.11' 22 | - name: Flake8 Lint 23 | uses: py-actions/flake8@v2 -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | publish: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.x" 19 | cache: pip 20 | cache-dependency-path: setup.py 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Unit Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | testing: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.11' 20 | - name: Install dependencies 21 | run: | 22 | pip install . 23 | - name: Run Test 24 | run: python -m unittest discover tests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # macOS 163 | *.DS_Store 164 | 165 | # cache 166 | .rich-chat.history 167 | 168 | poc/ 169 | 170 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MLE Agent 2 | 3 | This document is a work in progress. If you notice areas for improvement, please feel free to update this guide and submit a pull request! 4 | 5 | ## Table of Contents 6 | 7 | - [Submitting a Pull Request](#submitting-a-pull-request) 8 | - [Ensuring Your Pull Request Gets Accepted](#ensuring-your-pull-request-gets-accepted) 9 | - [The Review Process](#the-review-process) 10 | - [Work in Progress Pull Requests](#work-in-progress-pull-requests) 11 | - [Triaging Issues](#triaging-issues) 12 | 13 | ## Submitting a Pull Request 14 | 15 | To contribute code to MLE Agent, you need to open a [pull request](https://help.github.com/articles/about-pull-requests/). The pull request will be reviewed by the community before it is merged into the core project. Generally, a pull request should be submitted when a unit of work is complete, but you can also share ideas or get feedback through a work in progress (WIP) pull request ([learn more](#work-in-progress-pull-requests)). 16 | 17 | 1. Familiarize yourself with the project by reading our ["Getting Started Guide"](docs/GETTING_STARTED.md). 18 | 19 | 2. Follow our [coding standards](docs/CODE_GUIDELINES.md) to ensure consistency across the project. 20 | 21 | 3. Review our [testing guidelines](docs/TEST_GUIDELINES.md) to understand the project's automated testing framework. 22 | 23 | 4. [Set up your development environment](docs/DEVELOPMENT_SETUP.md) to make sure you have everything you need to contribute. 24 | 25 | 5. Make sure you have the latest version of the code by syncing your fork with the main repository: 26 | 27 | ```sh 28 | git remote add upstream https://github.com/MLSysOps/MLE-agent.git 29 | git fetch upstream 30 | git merge upstream/main 31 | ``` 32 | 33 | 6. Create a branch for the code you will be working on: 34 | 35 | ```sh 36 | git checkout -b my-new-feature 37 | ``` 38 | 39 | 7. Write your code, making sure to include tests as needed. 40 | 41 | 8. Commit your changes with a meaningful commit message: 42 | 43 | ```sh 44 | git commit -m "Description of the changes" 45 | ``` 46 | 47 | 9. Push your changes to your fork: 48 | 49 | ```sh 50 | git push origin my-new-feature 51 | ``` 52 | 53 | 10. Open a pull request on GitHub. Make sure to include a detailed description of the changes you made and any relevant context. 54 | 55 | ## Ensuring Your Pull Request Gets Accepted 56 | 57 | - Make sure your code follows the coding standards outlined in our code guidelines -- we use [flake8](https://flake8.pycqa.org/en/latest/) to enforce these standards. 58 | - Write tests for any new features or significant changes. 59 | - Ensure all tests pass before submitting your pull request. 60 | - Be responsive to feedback from reviewers. 61 | 62 | 63 | ## The Review Process 64 | 65 | Once you submit a pull request, it will be reviewed by the maintainers. They might request changes or provide feedback. The goal is to ensure the code is high quality and aligns with the project's goals. 66 | 67 | ## Work in Progress Pull Requests 68 | 69 | If you want feedback on your work before it's complete, you can open a WIP pull request. This allows you to get input from others on your approach or on specific parts of your code. 
70 | When you're ready for a full review, you can mark the pull request as `MRG` for review by removing the `WIP` label. 71 | 72 | ## Triaging Issues 73 | If you're not ready to submit code but still want to contribute, you can help by triaging issues. This involves confirming bugs, providing additional information, or suggesting ways to reproduce issues. 74 | 75 | Thank you for your interest in contributing to MLE Agent! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Huaizheng Zhang, Yizheng Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include web * 2 | recursive-include web/app * 3 | 4 | recursive-exclude web/node_modules * 5 | recursive-exclude web/.pnp 6 | recursive-exclude web/.pnp.js 7 | recursive-exclude web/.yarn/install-state.gz 8 | recursive-exclude web/coverage * 9 | recursive-exclude web/.next * 10 | recursive-exclude web/out * 11 | recursive-exclude web/build * 12 | recursive-exclude web/.DS_Store 13 | recursive-exclude web/*.pem 14 | recursive-exclude web/npm-debug.log* 15 | recursive-exclude web/yarn-debug.log* 16 | recursive-exclude web/yarn-error.log* 17 | recursive-exclude web/.env*.local 18 | recursive-exclude web/.vercel 19 | recursive-exclude web/*.tsbuildinfo 20 | recursive-exclude web/next-env.d.ts 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # MLE-Agent: Your intelligent companion for seamless AI engineering and research.
3 | kaia-llama 4 | MLSysOps/MLE-agent | Trendshift 5 |

:love_letter: Fathers' love for Kaia :love_letter:

6 | 7 | ![](https://github.com/MLSysOps/MLE-agent/actions/workflows/lint.yml/badge.svg) 8 | ![](https://github.com/MLSysOps/MLE-agent/actions/workflows/test.yml/badge.svg) 9 | ![PyPI - Version](https://img.shields.io/pypi/v/mle-agent) 10 | [![Downloads](https://static.pepy.tech/badge/mle-agent)](https://pepy.tech/project/mle-agent) 11 | ![GitHub License](https://img.shields.io/github/license/MLSysOps/MLE-agent) 12 | [Join our Discord community](https://discord.gg/SgxBpENGRG) 13 | 14 | [📚 Docs](https://mle-agent-site.vercel.app/) | 15 | [🐞 Report Issues](https://github.com/MLSysOps/MLE-agent/issues/new) | 16 | [👋 Join us on Discord](https://discord.gg/SgxBpENGRG) 17 | 18 |
19 | 20 | ## Overview 21 | 22 | MLE-Agent is designed as a pairing LLM agent for machine learning engineers and researchers. It is featured by: 23 | 24 | - 🤖 Autonomous Baseline: Automatically builds ML/AI baselines and solutions based on your requirements. 25 | - 🏅End-to-end ML Task: Participates in Kaggle competitions and completes tasks independently. 26 | - 🔍 [Arxiv](https://arxiv.org/) and [Papers with Code](https://paperswithcode.com/) Integration: Access best practices 27 | and state-of-the-art methods. 28 | - 🐛 Smart Debugging: Ensures high-quality code through automatic debugger-coder interactions. 29 | - 📂 File System Integration: Organizes your project structure efficiently. 30 | - 🧰 Comprehensive Tools Integration: Includes AI/ML functions and MLOps tools for a seamless workflow. 31 | - ☕ Interactive CLI Chat: Enhances your projects with an easy-to-use chat interface. 32 | - 🧠 Smart Advisor: Provides personalized suggestions and recommendations for your ML/AI project. 33 | - 📊 Weekly Report: Automatically generates detailed summaries of your weekly works. 34 | 35 | https://github.com/user-attachments/assets/dac7be90-c662-4d0d-8d3a-2bc4df9cffb9 36 | 37 | ## Milestones 38 | 39 | - :rocket: 09/24/2024: Release the `0.4.2` with enhanced `Auto-Kaggle` mode to complete an end-to-end competition with minimal effort. 40 | - :rocket: 09/10/2024: Release the `0.4.0` with new CLIs like `MLE report`, `MLE kaggle`, `MLE integration` and many new 41 | models like `Mistral`. 42 | - :rocket: 07/25/2024: Release the `0.3.0` with huge refactoring, many integrations, etc. (v0.3.0) 43 | - :rocket: 07/11/2024: Release the `0.2.0` with multiple agents interaction (v0.2.0) 44 | - 👨‍🍼 **07/03/2024: Kaia is born** 45 | - :rocket: 06/01/2024: Release the first rule-based version of MLE agent (v0.1.0) 46 | 47 | ## Get started 48 | 49 | ### Installation 50 | 51 | ```bash 52 | pip install mle-agent -U 53 | # or from source 54 | git clone git@github.com:MLSysOps/MLE-agent.git 55 | pip install -e . 56 | ``` 57 | 58 | ### Usage 59 | 60 | ```bash 61 | mle new 62 | ``` 63 | 64 | And a project directory will be created under the current path, you need to start the project under the project 65 | directory. 66 | 67 | ```bash 68 | cd 69 | mle start 70 | ``` 71 | 72 | You can also start an interactive chat in the terminal under the project directory: 73 | 74 | ```bash 75 | mle chat 76 | ``` 77 | 78 | ## Use cases 79 | 80 | ### 🧪 Prototype an ML Baseline 81 | 82 | MLE agent can help you prototype an ML baseline with the given requirements, and test the model on the local machine. 83 | The requirements can be vague, such as "I want to predict the stock price based on the historical data". 84 | 85 | ```bash 86 | cd 87 | mle start 88 | ``` 89 | 90 | ### :bar_chart: Generate Work Report 91 | 92 | MLE agent can help you summarize your weekly report, including development progress, communication notes, reference, and 93 | to-do lists. 94 | 95 | #### Mode 1: Web Application to Generate Report from GitHub 96 | 97 | ```bash 98 | cd 99 | mle report 100 | ``` 101 | 102 | Then, you can visit http://localhost:3000/ to generate your report locally. 103 | 104 | #### Mode 2: CLI Tool to Generate Report from Local Git Repository 105 | ```bash 106 | cd 107 | mle report-local --email= --start-date=YYYY-MM-DD --end-date=YYYY-MM-DD 108 | ``` 109 | 110 | - `--start-date` and `--end-date` are optional parameters. If omitted, the command will generate a report for the default date range of the last 7 days. 
111 | - Replace `<your_git_email>` with your Git email and `<path_to_your_repo>` with the path to your local Git repository. 112 | 113 | ### :trophy: Start with Kaggle Competition 114 | 115 | MLE agent can participate in Kaggle competitions and finish the coding and debugging, from data preparation to model training, 116 | independently. Here is the basic command to start a Kaggle competition: 117 | 118 | ```bash 119 | cd 120 | mle kaggle 121 | ``` 122 | 123 | Or you can let the agents finish the Kaggle task without human interaction if you have the dataset and submission file 124 | ready: 125 | 126 | ```bash 127 | cd 128 | mle kaggle --auto \ 129 | --datasets "<dataset1>,<dataset2>,..." \ 130 | --description "<competition description>" \ 131 | --submission "<submission file path>" \ 132 | --sub_example "<submission example path>" \ 133 | --comp_id "<competition id>" 134 | ``` 135 | 136 | Please make sure you have joined the competition before running the command. For more details, see the [MLE-Agent Tutorials](https://mle-agent-site.vercel.app/tutorial/Start_a_kaggle_task). 137 | 138 | ## Roadmap 139 | 140 | The following is a list of the tasks we plan to do; feel free to propose something new! 141 | 142 |
143 | :hammer: General Features 144 | 145 | - [x] Understand users' requirements to create an end-to-end AI project 146 | - [x] Suggest the SOTA data science solutions by using the web search 147 | - [x] Plan the ML engineering tasks with human interaction 148 | - [x] Execute the code on the local machine/cloud, debug and fix the errors 149 | - [x] Leverage the built-in functions to complete ML engineering tasks 150 | - [x] Interactive chat: A human-in-the-loop mode to help improve the existing ML projects 151 | - [x] Kaggle mode: to finish a Kaggle task without humans 152 | - [x] Summary and reflect the whole ML/AI pipeline 153 | - [ ] Integration with Cloud data and testing and debugging platforms 154 | - [x] Local RAG support to make personal ML/AI coding assistant 155 | - [ ] Function zoo: generate AI/ML functions and save them for future usage 156 | 157 |
158 | 159 |
160 | :star: More LLMs and Serving Tools 161 | 162 | - [x] Ollama LLama3 163 | - [x] OpenAI GPTs 164 | - [x] Anthropic Claude 3.5 Sonnet 165 | 166 |
167 | 168 |
169 | :sparkling_heart: Better user experience 170 | 171 | - [x] CLI Application 172 | - [x] Web UI 173 | - [x] Discord 174 | 175 |
176 | 177 |
178 | :jigsaw: Functions and Integrations 179 | 180 | - [x] Local file system 181 | - [x] Local code executor 182 | - [x] Arxiv.org search 183 | - [x] Papers with Code search 184 | - [x] General keyword search 185 | - [ ] Hugging Face 186 | - [ ] SkyPilot cloud deployment 187 | - [ ] Snowflake data 188 | - [ ] AWS S3 data 189 | - [ ] Databricks data catalog 190 | - [ ] Wandb experiment monitoring 191 | - [ ] MLflow management 192 | - [ ] DBT data transform 193 | 194 |
195 | 196 | ## Contributing 197 | 198 | We welcome contributions from the community. We are looking for contributors to help us with the following tasks: 199 | 200 | - Benchmark and Evaluate the agent 201 | - Add more features to the agent 202 | - Improve the documentation 203 | - Write tests 204 | 205 | Please check the [CONTRIBUTING.md](CONTRIBUTING.md) file if you want to contribute. 206 | 207 | ## Support and Community 208 | 209 | - [Discord community](https://discord.gg/SgxBpENGRG). If you have any questions, please ask in the Discord community. 210 | 211 | ## Citation 212 | 213 | ```bibtex 214 | @misc{zhang2024mleagent, 215 | title = {MLE-Agent: Your Intelligent Companion for Seamless AI Engineering and Research}, 216 | author = {Huaizheng Zhang*, Yizheng Huang*, Lei Zhang}, 217 | year = {2024}, 218 | note = {\url{https://github.com/MLSysOps/MLE-agent}}, 219 | } 220 | ``` 221 | 222 | ## License 223 | 224 | Check [MIT License](LICENSE) file for more information. 225 | -------------------------------------------------------------------------------- /assets/kaia_llama.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/assets/kaia_llama.webp -------------------------------------------------------------------------------- /mle/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | -------------------------------------------------------------------------------- /mle/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .advisor import * 2 | from .coder import * 3 | from .debugger import * 4 | from .planner import * 5 | from .summarizer import * 6 | from .reporter import * 7 | from .chat import * 8 | -------------------------------------------------------------------------------- /mle/agents/chat.py: -------------------------------------------------------------------------------- 1 | from rich.console import Console 2 | 3 | from mle.function import * 4 | from mle.utils import get_config, WorkflowCache 5 | from mle.utils.component_memory import trace_component 6 | 7 | class ChatAgent: 8 | 9 | def __init__(self, model, memory=None, working_dir='.', console=None): 10 | """ 11 | ChatAgent assists users with planning and debugging ML projects. 12 | 13 | Args: 14 | model: The machine learning model used for generating responses. 15 | """ 16 | config_data = get_config() 17 | 18 | self.model = model 19 | self.memory = memory 20 | self.chat_history = [] 21 | if working_dir == '.': 22 | working_dir = os.getcwd() 23 | self.working_dir = working_dir 24 | self.cache = WorkflowCache(working_dir, 'baseline') 25 | 26 | self.console = console 27 | if not self.console: 28 | self.console = Console() 29 | 30 | self.sys_prompt = f""" 31 | You are a programmer working on an Machine Learning task using Python. 32 | You are currently working on: {self.working_dir}. 33 | 34 | Your can leverage your capabilities by using the specific functions listed below: 35 | 36 | 1. Creating project structures based on the user requirement using function `create_directory`. 37 | 2. Writing clean, efficient, and well-documented code using function `create_file` and `write_file`. 38 | 3. Exam the project to re-use the existing code snippets as much as possible, you may need to use 39 | functions like `list_files`, `read_file` and `write_file`. 40 | 4. 
Writing the code into the file when creating new files, do not create empty files. 41 | 5. Use function `preview_csv_data` to preview the CSV data if the task include CSV data processing. 42 | 6. Decide whether the task requires execution and debugging before moving to the next or not. 43 | 7. Generate the commands to run and test the current task, and the dependencies list for this task. 44 | 8. You only write Python scripts, don't write Jupiter notebooks which require interactive execution. 45 | """ 46 | self.search_prompt = """ 47 | 9. Performing web searches use function `web_search` to get up-to-date information or additional context. 48 | """ 49 | 50 | self.functions = [ 51 | schema_read_file, 52 | schema_create_file, 53 | schema_write_file, 54 | schema_list_files, 55 | schema_create_directory, 56 | schema_search_arxiv, 57 | schema_search_papers_with_code, 58 | schema_web_search, 59 | schema_execute_command, 60 | schema_preview_csv_data, 61 | schema_unzip_data, 62 | schema_preview_zip_structure 63 | ] 64 | 65 | if config_data.get('search_key'): 66 | self.functions.append(schema_web_search) 67 | self.sys_prompt += self.search_prompt 68 | 69 | if not self.cache.is_empty(): 70 | dataset = self.cache.resume_variable("dataset") 71 | ml_requirement = self.cache.resume_variable("ml_requirement") 72 | advisor_report = self.cache.resume_variable("advisor_report") 73 | self.sys_prompt += f""" 74 | The overall project information: \n 75 | {'Dataset: ' + str(dataset) if dataset else ''} \n 76 | {'Requirement: ' + str(ml_requirement) if ml_requirement else ''} \n 77 | {'Advisor: ' + str(advisor_report) if advisor_report else ''} \n 78 | """ 79 | 80 | self.chat_history.append({"role": 'system', "content": self.sys_prompt}) 81 | 82 | def greet(self): 83 | """ 84 | Generate a greeting message to the user, including inquiries about the project's purpose and 85 | an overview of the support provided. This initializes a collaborative tone with the user. 86 | 87 | Returns: 88 | str: The generated greeting message. 89 | """ 90 | greet_prompt = """ 91 | Can you provide concise and friendly greetings within 50 words, including: 92 | 1. Infer about the project's purpose or objective. 93 | 2. Summarize the previous conversations if it existed. 94 | 2. Offering a brief overview of the assistance and support you can provide to the user, such as: 95 | - Helping with project planning and management. 96 | - Assisting with debugging and troubleshooting code. 97 | - Offering advice on best practices and optimization techniques. 98 | - Providing resources and references for further learning. 99 | Make sure your greeting is inviting and sets a positive tone for collaboration. 100 | """ 101 | self.chat_history.append({"role": "user", "content": greet_prompt}) 102 | greets = self.model.query( 103 | self.chat_history, 104 | function_call='auto', 105 | functions=self.functions, 106 | ) 107 | 108 | self.chat_history.append({"role": "assistant", "content": greets}) 109 | return greets 110 | @trace_component("chat") 111 | def chat(self, user_prompt): 112 | """ 113 | Handle the response from the model streaming. 114 | The stream mode is integrative with the model streaming function, we don't 115 | need to set it into the JSON mode. 116 | 117 | Args: 118 | user_prompt: the user prompt. 119 | """ 120 | text = '' 121 | if self.memory: 122 | table_name = 'mle_chat_' + self.working_dir.split('/')[-1] 123 | query = self.memory.query([user_prompt], table_name=table_name, n_results=1) # TODO: adjust the n_results. 
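            # Retrieval-augmented step: the memory query above pulls the most similar
            # stored snippet from the per-project table ("mle_chat_" + the working
            # directory name), with n_results currently fixed at 1. The lines below
            # append each retrieved file name and snippet to the user prompt so the
            # model can ground its reply in the existing project files.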
124 | user_prompt += f""" 125 | \nThese reference files and their snippets may be useful for the question:\n\n 126 | """ 127 | 128 | for t in query[0]: 129 | snippet, metadata = t.get('text'), t.get('metadata') 130 | user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n" 131 | self.chat_history.append({"role": "user", "content": user_prompt}) 132 | 133 | for content in self.model.stream( 134 | self.chat_history, 135 | function_call='auto', 136 | functions=self.functions, 137 | ): 138 | if content: 139 | text += content 140 | yield text 141 | 142 | self.chat_history.append({"role": "assistant", "content": text}) 143 | -------------------------------------------------------------------------------- /mle/agents/coder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | from rich.console import Console 4 | 5 | from mle.function import * 6 | from mle.utils import get_config, print_in_box, clean_json_string 7 | from mle.utils.component_memory import trace_component 8 | 9 | def process_summary(summary_dict: dict): 10 | """ 11 | Process the code summary. 12 | Args: 13 | summary_dict: the code summary in a dictionary format. 14 | """ 15 | return textwrap.dedent(f""" 16 | MLE Developer has finished the task: {summary_dict.get('task')}.\n 17 | Task description: {summary_dict.get('task_description')}\n 18 | {summary_dict.get('message')}\n 19 | Dependencies you are required to run the code: {summary_dict.get('dependency')}\n 20 | Command to run the code: {summary_dict.get('command')}\n 21 | Whether the code is required to execute and debug: {summary_dict.get('debug')}""") 22 | 23 | 24 | class CodeAgent: 25 | 26 | def __init__(self, model, working_dir='.', console=None, single_file=False): 27 | """ 28 | CodeAgent: the agent to solve the given coding problems, by planning coding tasks, searching websites, 29 | and generating code snippets. It does not execute the code, only make use of built-in functions to provides 30 | the code snippets to solve the problem. 31 | 32 | Args: 33 | model: the model to use. 34 | working_dir: the working directory. 35 | console: the console to use. 36 | single_file: whether the agent is working on a single file or not. 37 | """ 38 | config_data = get_config() 39 | self.code_summary = None 40 | self.model = model 41 | self.chat_history = [] 42 | self.working_dir = working_dir 43 | 44 | self.console = console 45 | if not self.console: 46 | self.console = Console() 47 | 48 | self.sys_prompt = f""" 49 | You are a programmer working on an Machine Learning task using Python. 50 | You are currently working on: {self.working_dir}. 51 | 52 | Your can leverage your capabilities by using the specific functions listed below: 53 | 54 | - Creating project structures based on the user requirement using function `create_directory`. 55 | - Writing clean, efficient, and well-documented code using function `create_file` and `write_file`. 56 | - Exam the project to re-use the existing code snippets as much as possible, you may need to use 57 | functions like `list_files`, `read_file` and `write_file`. 58 | - Use function `preview_zip_structure` to preview the structure of the file if the task include zip file processing. 59 | - Use function `unzip_data` to extract the compressed file if the task include compressed file processing. 60 | - Writing the code into the file when creating new files, do not create empty files. 
61 | - Use function `preview_csv_data` to preview the CSV data if the task include CSV data processing. 62 | - Decide whether the task requires execution and debugging before moving to the next or not. 63 | - Generate the commands to run and test the current task, and the dependencies list for this task. 64 | - You only write Python scripts, don't write Jupiter notebooks which require interactive execution. 65 | """ 66 | 67 | if single_file: 68 | self.sys_prompt = f""" 69 | You are an expert programmer working on an Machine Learning task using Python. 70 | You are currently working on: {self.working_dir}. 71 | 72 | Your can leverage your capabilities by using the specific functions listed below: 73 | 74 | - You should create a single script first, with the complete code inside. You can have multiple functions and classes. 75 | - Writing clean, efficient, and well-documented code to a script using functions `create_file`. 76 | - Use function `preview_csv_data` to preview the CSV data if the task include CSV dataset or examples. 77 | - Use function `preview_zip_structure` to preview the structure of the file if the task include zip file processing. 78 | - Use function `unzip_data` to extract the compressed file if the task include compressed file processing. 79 | - Generate the commands to run and test the current script, and the dependencies list required for this script. 80 | - You only write Python scripts, don't write Jupiter notebooks which require interactive execution. 81 | - Make sure the code has met the task description, and the suggested methods. 82 | - Make sure the output format and the output file path is correct. 83 | """ 84 | 85 | self.search_prompt = """ 86 | - Performing web searches use function `web_search` to get up-to-date information or additional context. 87 | """ 88 | 89 | self.json_mode_prompt = """ 90 | 91 | The output format should be in JSON format, include: 92 | 93 | 1. The dependency list that the project needs to run. 94 | 2. And the command to run and test the project. 95 | 3. The reason why failed if the status is failed, put it in the "message" field. 96 | 4. Whether the task requires execution and debug or not (it is "false" when create new directories or files). 97 | If the task requires modifying existing code or generating new code, it is "true". If the "command" is empty, 98 | the "debug" should be "false". 99 | 100 | Example JSON output: 101 | 102 | { 103 | "dependency":[ 104 | "torch", 105 | "scikit-learn" 106 | ], 107 | "command":"python /path/to/your/project.py", 108 | "message":"the project-related has been generated in the project.py.", 109 | "debug":"true" 110 | } 111 | 112 | """ 113 | 114 | if single_file: 115 | self.json_mode_prompt = """ 116 | 117 | The output format should be in JSON format, include: 118 | 119 | 1. The dependency list that the project needs to run. 120 | 2. And the command and the parameters to run and test the script. 
121 | 122 | Example JSON output: 123 | 124 | { 125 | "dependency":[ 126 | "torch", 127 | "scikit-learn" 128 | ], 129 | "command":"python /path/to/your/project.py", 130 | } 131 | 132 | """ 133 | 134 | self.functions = [ 135 | schema_read_file, 136 | schema_create_file, 137 | schema_write_file, 138 | schema_list_files, 139 | schema_create_directory, 140 | schema_preview_csv_data, 141 | schema_preview_zip_structure, 142 | schema_unzip_data 143 | ] 144 | 145 | if config_data.get('search_key'): 146 | self.functions.append(schema_web_search) 147 | self.sys_prompt += self.search_prompt 148 | 149 | self.sys_prompt += self.json_mode_prompt 150 | self.chat_history.append({"role": 'system', "content": self.sys_prompt}) 151 | 152 | @trace_component("coder") 153 | def read_requirement(self, advisor_report: str): 154 | """ 155 | Read the user requirement and the advisor report. 156 | :param advisor_report: 157 | :return: 158 | """ 159 | self.chat_history.append({"role": "system", "content": advisor_report}) 160 | 161 | @trace_component("coder") 162 | def code(self, task_dict: dict): 163 | """ 164 | Handle the query from the model query response. 165 | Args: 166 | task_dict: the task dictionary. 167 | """ 168 | task_prompt = textwrap.dedent(f""" 169 | You are required to complete task: {task_dict.get('task')}.\n 170 | Task description: {task_dict.get('description')} 171 | """) 172 | 173 | with self.console.status(f"Coder is working on the task: {task_dict.get('task')}..."): 174 | self.chat_history.append({"role": "user", "content": task_prompt}) 175 | text = self.model.query( 176 | self.chat_history, 177 | function_call='auto', 178 | functions=self.functions, 179 | response_format={"type": "json_object"} 180 | ) 181 | 182 | self.chat_history.append({"role": "assistant", "content": text}) 183 | code_summary = clean_json_string(text) 184 | code_summary.update({'task': task_dict.get('task'), 'task_description': task_dict.get('description')}) 185 | return code_summary 186 | 187 | @trace_component("coder") 188 | def debug(self, task_dict: dict, debug_report: dict): 189 | """ 190 | Handle the query from the model query response. 191 | :param task_dict: the task dictionary. 192 | :param debug_report: the debug report from DebugAgent. 193 | :return: 194 | """ 195 | improve_prompt = textwrap.dedent(f""" 196 | You are required improve the existing project.\n 197 | The required changes: {debug_report.get("changes")}\n 198 | The suggestion: {debug_report.get("suggestion")} 199 | 200 | """) 201 | 202 | with self.console.status(f"Coder is improving the code for task {task_dict.get('task')}..."): 203 | self.chat_history.append({"role": "user", "content": improve_prompt}) 204 | text = self.model.query( 205 | self.chat_history, 206 | function_call='auto', 207 | functions=self.functions, 208 | response_format={"type": "json_object"} 209 | ) 210 | 211 | self.chat_history.append({"role": "assistant", "content": text}) 212 | code_summary = clean_json_string(text) 213 | code_summary.update({'task': task_dict.get('task'), 'task_description': task_dict.get('description')}) 214 | return code_summary 215 | 216 | def interact(self, task_dict: dict): 217 | """ 218 | Interact with the user to code the task. 219 | Args: 220 | task_dict: the task dictionary. 221 | """ 222 | self.code_summary = self.code(task_dict) 223 | print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan") 224 | while True: 225 | suggestion = questionary.text( 226 | "Any feedback to MLE developer? 
(ENTER to move to the next stage, \"exit\" to exit the project)" 227 | ).ask() 228 | 229 | if not suggestion: 230 | break 231 | 232 | if suggestion.lower() in ["exit"]: 233 | sys.exit(0) 234 | 235 | with self.console.status(f"MLE Developer is working on the task: {task_dict.get('task')}..."): 236 | self.chat_history.append({"role": "user", "content": suggestion}) 237 | text = self.model.query( 238 | self.chat_history, 239 | function_call='auto', 240 | functions=self.functions, 241 | response_format={"type": "json_object"} 242 | ) 243 | 244 | self.chat_history.append({"role": "assistant", "content": text}) 245 | self.code_summary = clean_json_string(text) 246 | self.code_summary.update( 247 | { 248 | 'task': task_dict.get('task'), 249 | 'task_description': task_dict.get('description') 250 | } 251 | ) 252 | print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan") 253 | return self.code_summary 254 | -------------------------------------------------------------------------------- /mle/agents/debugger.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from mle.function import * 4 | from mle.utils import get_config, print_in_box 5 | 6 | from rich.console import Console 7 | from mle.utils.component_memory import trace_component 8 | 9 | def process_debug_report(debug_report): 10 | """ 11 | Process the debug report. 12 | Args: 13 | debug_report: the debug report. 14 | """ 15 | if debug_report.get("status") == "success": 16 | return "The code runs without errors." 17 | else: 18 | changes = debug_report.get("changes") 19 | suggestion = debug_report.get("suggestion") 20 | report_str = "The code is running with errors.\nChanges are required:\n" 21 | for change in changes: 22 | report_str += f"File: {change.get('file')}, Line: {change.get('line')}, Issue: {change.get('issue')}, " \ 23 | f"Suggestion: {change.get('suggestion')}\n" 24 | report_str += f"Overall Suggestion: {suggestion}" 25 | return report_str 26 | 27 | 28 | class DebugAgent: 29 | 30 | def __init__(self, model, console=None, analyze_only=False): 31 | """ 32 | DebugAgent: the agent to run the generated the code and then debug it. The return of the 33 | agent is an instruction to the user to modify the code based on the logs and web search. 34 | 35 | Args: 36 | model: the model to use. 37 | console: the console to use. 38 | analyze_only: if only analyze the code without execution 39 | """ 40 | config_data = get_config() 41 | self.console = console 42 | if not self.console: 43 | self.console = Console() 44 | self.model = model 45 | self.chat_history = [] 46 | self.sys_prompt = """ 47 | You are a program error debugger working on a Python project. 48 | 49 | Your can leverage your capabilities by using the specific functions listed below: 50 | 51 | - Install the code dependencies using function `execute_command` based on the Developer's dependencies list. 52 | - Execute the code using function `execute_command` to test the code based on the Developer's instructions. 53 | - If the program returns errors, you need to debug the code based on the logs, you may need to first read the 54 | structure of the project using function `list_files`. 55 | - Then you may need to call `read_file` function to read the content of the code files, locate the error line 56 | and the reasons. 57 | - You don't need to care about the best practices and code styles, you only care about the errors in the code. 
58 | 59 | """ 60 | 61 | self.search_prompt = """ 62 | - You need to debug the code based on the error logs, you may need to call `web_search` function to search for 63 | the solutions or reasons for the errors if needed. 64 | """ 65 | 66 | self.json_mode_prompt = """ 67 | 68 | Example JSON output if a program runs without errors: 69 | { 70 | "status":"success", 71 | "changes":[], 72 | "suggestion":"" 73 | } 74 | 75 | Example JSON output if a program returns errors: 76 | { 77 | "status":"error", 78 | "changes":[ 79 | { 80 | "file":"xxx.py", 81 | "line":10, 82 | "issue":"xxx", 83 | "suggestion":"xxx" 84 | }, 85 | "suggestion":"Failed to find the target file. Please check the file path." 86 | ] 87 | } 88 | """ 89 | 90 | self.functions = [ 91 | schema_read_file, 92 | schema_list_files, 93 | schema_execute_command 94 | ] 95 | 96 | if analyze_only: 97 | self.sys_prompt = """ 98 | You are a program error debugger working on a Python project. You target is to 99 | analyze the running logs and the source code to locate the errors. And give the 100 | suggestions to the developer to fix the errors. 101 | 102 | Your can leverage your capabilities by using the specific functions listed below: 103 | 104 | - Read and understand the running logs and the error messages. 105 | - Read the content of the source code files using function `read_file`, to locate the error line. 106 | - Give the suggestions to the developer to fix the errors. 107 | - Install missing dependencies using function `execute_command` based on the error logs. 108 | - If there is no error based on the exit code, you don't need to do anything but return the success status. 109 | """ 110 | 111 | if config_data.get('search_key'): 112 | self.functions.append(schema_web_search) 113 | self.sys_prompt += self.search_prompt 114 | 115 | self.sys_prompt += self.json_mode_prompt 116 | self.chat_history.append({"role": 'system', "content": self.sys_prompt}) 117 | 118 | def analyze_with_log(self, commands, logs): 119 | """ 120 | Analyze the logs. 121 | :param commands: the commands to execute. 122 | :param logs: the logs to analyze. 123 | :return: 124 | """ 125 | analyze_prompt = f""" 126 | The command to execute the code: {commands} \n 127 | The logs to analyze: {logs} \n 128 | """ 129 | 130 | self.chat_history.append({"role": "user", "content": analyze_prompt}) 131 | try: 132 | text = self.model.query( 133 | self.chat_history, 134 | function_call='auto', 135 | functions=self.functions, 136 | response_format={"type": "json_object"} 137 | ) 138 | except Exception as e: 139 | print(f"Error occurred while querying the model: {e}") 140 | return {} 141 | 142 | self.chat_history.append({"role": "assistant", "content": text}) 143 | report_dict = json.loads(text) 144 | print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow") 145 | return report_dict 146 | @trace_component("debugger") 147 | def analyze(self, code_report): 148 | """ 149 | Handle the query from the model query response. 150 | Args: 151 | code_report: the code report from the MLE developer. 152 | """ 153 | debug_prompt = f""" 154 | Please help me debug the current task: {code_report.get('task')}. 
{code_report.get('messages')}\n 155 | The task description: {code_report.get('task_description')} 156 | The dependencies required for this task: {code_report.get('dependencies')} 157 | The command to execute the code: {code_report.get('command')} 158 | 159 | """ 160 | 161 | error_msg = code_report.get('error_message') 162 | if error_msg: 163 | debug_prompt += f"Error message: {error_msg}\n" 164 | 165 | self.chat_history.append({"role": "user", "content": debug_prompt}) 166 | try: 167 | text = self.model.query( 168 | self.chat_history, 169 | function_call='auto', 170 | functions=self.functions, 171 | response_format={"type": "json_object"} 172 | ) 173 | except Exception as e: 174 | print(f"Error occurred while querying the model: {e}") 175 | return {} 176 | 177 | self.chat_history.append({"role": "assistant", "content": text}) 178 | report_dict = json.loads(text) 179 | print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow") 180 | return report_dict -------------------------------------------------------------------------------- /mle/agents/planner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import questionary 4 | from rich.console import Console 5 | 6 | from mle.utils import print_in_box, clean_json_string 7 | from mle.utils.component_memory import trace_component 8 | 9 | 10 | def process_plan(plan_dict: dict): 11 | plan_str = "" 12 | for task in plan_dict.get('tasks'): 13 | plan_str += f"[green][Task]:[/green] {task.get('task')}\n[green][Description]:[/green] {task.get('description')}\n\n" 14 | 15 | return plan_str 16 | 17 | 18 | class PlanAgent: 19 | 20 | def __init__(self, model, console=None): 21 | """ 22 | PlanAgent: the agent to plan the machine learning project. By receiving the user's requirements, the agent will 23 | first analyze the requirements and ask the user to provide more details if necessary. Then the agent will 24 | generate the project plan based on the requirements and the user's input. 25 | 26 | The project plan will be sent to the advisor agent to provide suggestions on the best machine learning task, 27 | model, dataset, and evaluation metrics to use. 28 | 29 | Args: 30 | model: the model to use. 31 | """ 32 | self.console = console 33 | if not self.console: 34 | self.console = Console() 35 | 36 | self.plan_dict = None 37 | self.model = model 38 | self.chat_history = [] 39 | self.sys_prompt = """ 40 | You are an Machine learning Product Manager, you are going to collaborate with the user to plan the ML 41 | project. A generated plan includes the coding tasks for the developer to complete the project. 42 | 43 | Your capabilities include: 44 | 45 | 1. Understand the user's dataset and the user's requirements, the requirements may include the intention of the 46 | project, the model or algorithm to use, the dataset, the evaluation metrics, and the expected results. The 47 | generated task should always meet the user's requirements. 48 | 2. The generated plan should include several coding tasks, and the task should include specific instructions for 49 | the developer. For example: "Create a directory named 'dataset' under the project root, and write a Python 50 | script called 'data_loader.py' to download the dataset ImageNet from the official website." 51 | 3. The coding task should be clear and easy to understand, but with essential information to complete the task. 
52 | For example, if the dataset is a user's local CSV file, you should provide the absolute path to the file in the 53 | task, otherwise, the developer may not be able to complete the task. 54 | 4. Please only provide the coding tasks, do not provide the code snippets, the developer will complete the task. 55 | 5. Do not generate task like "setup environment", "install dependencies", "run the code", etc. The developer 56 | only focus on the coding tasks. 57 | 58 | """ 59 | self.json_mode_prompt = """ 60 | 61 | Example JSON output: 62 | 63 | { 64 | "tasks": [ 65 | { 66 | "task": "download dataset", 67 | "description": "Create a directory named 'dataset' under the project root, and write a Python 68 | script called 'data_loader.py' to download the dataset ImageNet from the official website." 69 | }, 70 | { 71 | "task": "process ImageNet", 72 | "description": "Write a Python script called `process_data.py` to process the dataset by 73 | resizing the images to 224x224 pixels and save the data to the 'processed_data' directory." 74 | }, 75 | { 76 | "task": "train model", 77 | "description": "Write a Python script called `train_model.py` to train an image classification 78 | model on the processed data and save the trained model to the 'model' directory." 79 | } 80 | ] 81 | } 82 | """ 83 | self.sys_prompt += self.json_mode_prompt 84 | self.chat_history.append({"role": 'system', "content": self.sys_prompt}) 85 | 86 | @trace_component("planner") 87 | def plan(self, user_prompt): 88 | """ 89 | Handle the query from the model query response. 90 | Args: 91 | user_prompt: the user prompt. 92 | """ 93 | with self.console.status("MLE Planner is planning the coding tasks..."): 94 | self.chat_history.append({"role": "user", "content": user_prompt}) 95 | text = self.model.query( 96 | self.chat_history, 97 | response_format={"type": "json_object"} 98 | ) 99 | 100 | self.chat_history.append({"role": "assistant", "content": text}) 101 | 102 | try: 103 | return json.loads(text) 104 | except json.JSONDecodeError as e: 105 | return clean_json_string(text) 106 | 107 | @trace_component("planner") 108 | def interact(self, user_prompt): 109 | """ 110 | Handle the query from the model query response. 111 | Args: 112 | user_prompt: the user prompt. 113 | """ 114 | self.plan_dict = self.plan(user_prompt) 115 | print_in_box(process_plan(self.plan_dict), self.console, title="MLE Planner", color="purple") 116 | 117 | while True: 118 | suggestion = questionary.text( 119 | "Suggestions to improve the plan? (ENTER to move to the next stage, \"exit\" to exit the project)" 120 | ).ask() 121 | 122 | if not suggestion or suggestion.lower() == "no": 123 | break 124 | 125 | if suggestion.lower() == "exit": 126 | sys.exit(0) 127 | 128 | self.plan_dict = self.plan(suggestion) 129 | print_in_box(process_plan(self.plan_dict), self.console, title="MLE Planner", color="purple") 130 | 131 | return self.plan_dict -------------------------------------------------------------------------------- /mle/agents/reporter.py: -------------------------------------------------------------------------------- 1 | import json 2 | from rich.console import Console 3 | from time import gmtime, strftime 4 | from mle.utils.component_memory import trace_component 5 | 6 | class ReportAgent: 7 | 8 | def __init__(self, model, console=None): 9 | """ 10 | ReportAgent: generate the report based on the information provided by the user. 11 | 12 | Args: 13 | model: the model to use. 14 | console: the console to use. 
15 | """ 16 | self.report = None 17 | self.knowledge = None 18 | self.model = model 19 | self.chat_history = [] 20 | self.console = console 21 | if not self.console: 22 | self.console = Console() 23 | self.sys_prompt = """ 24 | You are writing a weekly progress report for an engineer working on a project. Your capabilities include: 25 | 26 | 1. Based on the user's input information, you need to organize the information and generate more details from the 27 | user's perspective. 28 | 2. You need to generate a section called "Development Progress" based on the user's Github 29 | summary given by the user, do not use the commit messages directly. 30 | 3. You need to generate a section called "Communication / Design Progress" based on the user's Google Calendar 31 | events (if any). Not all events are related to the project but you need to filter out the related ones. 32 | 4. You need to generate a section called "Development To-do" based on the user's Github information, and the 33 | task priority, with the highest priority first and generate more details. 34 | 5. You need to generate a section called "Communication / Design To-do" based on the user's future 35 | Google Calendar events (if any). 36 | 6. You need to generate a section called "Existing Hard Parts" to summarize/infer the hard parts of the project. 37 | 7. Based on the hard parts and the project information, you need to generate a section called 38 | "Require Manager' / Others’ help", to indicate the parts that may need help. 39 | 8. You can generate as more as possible details to make sure the report is informative and has great progress. 40 | 41 | """ 42 | self.json_mode_prompt = """ 43 | 44 | JSON Output Format: 45 | 46 | { 47 | "project_okr": "if user provides the ORKs, put there. Otherwise, put an empty string", 48 | "business_goal": ["The project aims to build an image classification model...", ...], 49 | "dev_progress": ["implemented the data collection Python function...", ...], 50 | "communicate_progress": ["Meeting with the design team to discuss the new feature...", ...], 51 | "dev_todo": [{"task": "fix ...", "description": ..., "priority": "high"}, {"task": "support ..."," description": ..., "priority": "medium"}, ...], 52 | "communicate_todo": [{"task": "seek helps from ...", "priority": "high"}, 53 | {"task": "meet with ...", "priority": "low"} ...], 54 | "hard_parts": ["The project may face the challenge of ...", ...], 55 | "require_manager_help": ["The project needs help from the design team to ...", ...], 56 | "suggestions_to_user": ["Increase more meeting with design team...", ...], 57 | "reference": [{"title": "xxxx", "link":"https://arxiv.org/abs/xxx.xxxx"}, {"title": "xxx", "link": "https://github.com/xxx"}, ...], 58 | } 59 | 60 | """ 61 | self.sys_prompt += self.json_mode_prompt 62 | self.chat_history.append({"role": 'system', "content": self.sys_prompt}) 63 | 64 | def process_knowledge(self, github_summary: dict, calendar_events: list = None, okr: str = None): 65 | """ 66 | Process the knowledge to generate the report. 67 | 68 | Args: 69 | github_summary: the summary of the GitHub project. 70 | calendar_events: the Google Calendar events. 71 | okr: the OKR of the project. 
72 | """ 73 | info_prompt = f""" 74 | # Project Overview 75 | 76 | ## The username: {github_summary.get('username')}\n 77 | ## The repository: {github_summary.get('github_repo')}\n 78 | ## Technology stack: {github_summary.get('tech_stack')}\n 79 | ## The project summary: {github_summary.get('summary')}\n 80 | """ 81 | if okr: 82 | info_prompt += f"\n## The project's OKR: \n" 83 | info_prompt += f"{okr}\n" 84 | 85 | info_prompt += f"\n## The project's business goal: \n" 86 | for goal in github_summary.get("business_goal", []): 87 | info_prompt += f"- {goal}\n" 88 | 89 | if github_summary.get("dataset"): 90 | info_prompt += f"\n## The project's datasets: \n" 91 | for dataset in github_summary.get("dataset"): 92 | info_prompt += f"- {dataset['name']}: {dataset['description']}\n" 93 | 94 | info_prompt += f"\n## The project's roadmap: \n" 95 | for task in github_summary.get("roadmap", []): 96 | info_prompt += f"- {task['task']} ({task['priority']})\n" 97 | 98 | info_prompt += f"\n## The project's hard parts: \n" 99 | for part in github_summary.get("hard_parts", []): 100 | info_prompt += f"- {part}\n" 101 | 102 | info_prompt += f"\n## The project's related work: \n" 103 | for work in github_summary.get("related_work", []): 104 | info_prompt += f"- {work['title']} ({work['link']})\n" 105 | 106 | activities = github_summary.get("user_activity") 107 | info_prompt += f""" 108 | # User's Activity (from {activities['period']['start']} to {activities['period']['end']}) 109 | 110 | """ 111 | 112 | info_prompt += f""" 113 | ## Contributions:\n 114 | - Commits: {activities['summary']['total_commits']} 115 | - Pull Requests: {activities['summary']['total_pull_requests']} 116 | - Issues: {activities['summary']['total_issues']}\n 117 | """ 118 | 119 | info_prompt += f"## The user's commits: \n" 120 | for commit in activities['commits']['messages']: 121 | info_prompt += f"- {commit}" 122 | 123 | info_prompt += f"\n## The user's pull requests: \n" 124 | for pr in activities['pull_requests']['details']: 125 | info_prompt += f"- {pr['title']} ({pr['status']})" 126 | 127 | info_prompt += f"\n## The user's issues: \n" 128 | for issue in activities['issues']['details']: 129 | info_prompt += f"- {issue['title']}" 130 | 131 | if calendar_events: 132 | info_prompt += f"\n## The user's calendar events:\n" 133 | for event in calendar_events: 134 | info_prompt += (f"- Title: {event['title']}\n" 135 | f" Time: ({event['start_time']} - {event['end_time']})\n" 136 | f" Description: {event['description']}\n" 137 | f" Organizer: {event['organizer']['email']}\n") 138 | 139 | self.knowledge = info_prompt 140 | return info_prompt 141 | 142 | @trace_component("reporter") 143 | def gen_report(self, github_summary: dict, calendar_events: list = None, okr: str = None): 144 | """ 145 | Handle the query from the model query response. 146 | Args: 147 | github_summary: the summary of the GitHub project. 148 | calendar_events: the Google Calendar 149 | okr: the OKR of the project. 
150 | """ 151 | with self.console.status("MLE reporter is writing the progress report..."): 152 | self.chat_history.append( 153 | { 154 | "role": "user", 155 | "content": self.process_knowledge(github_summary, calendar_events, okr) 156 | } 157 | ) 158 | text = self.model.query( 159 | self.chat_history, 160 | response_format={"type": "json_object"} 161 | ) 162 | 163 | self.chat_history.append({"role": "assistant", "content": text}) 164 | # save the dict into a local files 165 | today = strftime("%Y_%m_%d", gmtime()) 166 | result_dict = json.loads(text) 167 | with open(f'progress_report_{today}.json', 'w') as f: 168 | json.dump(result_dict, f) 169 | return result_dict -------------------------------------------------------------------------------- /mle/function/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import py7zr 3 | import gzip 4 | import bz2 5 | import lzma 6 | import shutil 7 | import tarfile 8 | import zipfile 9 | import textwrap 10 | import tempfile 11 | import pandas as pd 12 | from pandas.api.types import is_numeric_dtype 13 | 14 | 15 | def unzip_data(compressed_file_path, extract_path=None): 16 | """ 17 | Unzip a compressed file, supporting various formats (.zip, .7z, .tar, .gz, .bz2, .xz). 18 | If no extract_path is provided, it creates a temporary directory. 19 | 20 | :param compressed_file_path: Path to the compressed file 21 | :param extract_path: Path where the contents will be extracted. If None, a temp directory is used. 22 | :return: String with the path to the unzipped contents 23 | """ 24 | if not os.path.exists(compressed_file_path): 25 | raise FileNotFoundError(f"The file {compressed_file_path} does not exist.") 26 | 27 | # If no extract_path is provided, create a temporary directory 28 | if extract_path is None: 29 | extract_path = tempfile.mkdtemp() 30 | print(f"No extract path provided. 
Using temporary directory: {extract_path}") 31 | else: 32 | # Create the extraction directory if it doesn't exist 33 | os.makedirs(extract_path, exist_ok=True) 34 | 35 | file_extension = os.path.splitext(compressed_file_path)[1].lower() 36 | file_name = os.path.splitext(os.path.basename(compressed_file_path))[0] 37 | 38 | # Create a subdirectory with the name of the compressed file 39 | specific_extract_path = os.path.join(extract_path, file_name) 40 | os.makedirs(specific_extract_path, exist_ok=True) 41 | 42 | try: 43 | if file_extension == '.zip': 44 | with zipfile.ZipFile(compressed_file_path, 'r') as zip_ref: 45 | zip_ref.extractall(specific_extract_path) 46 | 47 | elif file_extension == '.7z': 48 | with py7zr.SevenZipFile(compressed_file_path, mode='r') as z: 49 | z.extractall(specific_extract_path) 50 | 51 | elif file_extension in ['.tar', '.gz', '.bz2', '.xz']: 52 | if file_extension == '.gz': 53 | open_func = gzip.open 54 | elif file_extension == '.bz2': 55 | open_func = bz2.open 56 | elif file_extension == '.xz': 57 | open_func = lzma.open 58 | else: 59 | open_func = open 60 | 61 | with open_func(compressed_file_path, 'rb') as f: 62 | if tarfile.is_tarfile(compressed_file_path) or file_extension in ['.gz', '.bz2', '.xz']: 63 | with tarfile.open(fileobj=f) as tar: 64 | tar.extractall(path=specific_extract_path) 65 | else: 66 | # For single file compression (non-tar) 67 | output_filename = os.path.splitext(os.path.basename(compressed_file_path))[0] 68 | output_path = os.path.join(specific_extract_path, output_filename) 69 | with open(output_path, 'wb') as out_f: 70 | shutil.copyfileobj(f, out_f) 71 | 72 | else: 73 | raise ValueError(f"Unsupported file format: {file_extension}") 74 | 75 | print(f"Successfully extracted {compressed_file_path} to {specific_extract_path}") 76 | return specific_extract_path 77 | 78 | except Exception as e: 79 | print(f"Error extracting {compressed_file_path}: {str(e)}") 80 | raise 81 | 82 | 83 | def preview_zip_structure(zip_path, max_files=50, max_dirs=20, max_output_length=1000, show_hidden=False): 84 | """ 85 | Preview the structure of a zip file with limits on output and option to show hidden files. 86 | :param zip_path: the path to the zip file. 87 | :param max_files: maximum number of files to display. 88 | :param max_dirs: maximum number of directories to display. 89 | :param max_output_length: maximum length of the output string. 90 | :param show_hidden: if True, show hidden files and directories (starting with a dot). 91 | :return: the limited structure of the zip file as a string. 92 | """ 93 | if not os.path.exists(zip_path): 94 | return f"Error: The file '{zip_path}' does not exist." 95 | 96 | if not zipfile.is_zipfile(zip_path): 97 | return f"Error: '{zip_path}' is not a valid zip file." 98 | 99 | structure = [] 100 | file_count = 0 101 | dir_count = 0 102 | total_count = 0 103 | hidden_count = 0 104 | 105 | with zipfile.ZipFile(zip_path, 'r') as zip_ref: 106 | for file_info in zip_ref.infolist(): 107 | file_path = file_info.filename 108 | is_hidden = os.path.basename(file_path).startswith('.') 109 | 110 | if is_hidden and not show_hidden: 111 | hidden_count += 1 112 | continue 113 | 114 | if file_info.is_dir(): 115 | if dir_count < max_dirs: 116 | structure.append(f"Directory: {file_path}") 117 | dir_count += 1 118 | else: 119 | if file_count < max_files: 120 | structure.append(f"File: {file_path}") 121 | file_count += 1 122 | 123 | total_count += 1 124 | if len("\n".join(structure)) >= max_output_length: 125 | structure.append("... 
(output truncated due to length)")
126 | break
127 | 
128 | if file_count >= max_files:
129 | structure.append(f"... (and {total_count - file_count - dir_count} more files)")
130 | if dir_count >= max_dirs:
131 | structure.append(f"... (and {total_count - file_count - dir_count} more directories)")
132 | if not show_hidden and hidden_count > 0:
133 | structure.append(f"... ({hidden_count} hidden items not shown)")
134 | 
135 | output = "\n".join(structure)
136 | if len(output) > max_output_length:
137 | output = output[:max_output_length] + "... (output truncated)"
138 | 
139 | return output
140 | 
141 | 
142 | def preview_csv_data(path: str, limit_rows: int = 5, limit_columns: int = None) -> str:
143 | """
144 | Preview the sample dataset from the project data path and include metadata.
145 | :param path: the path to a local CSV file.
146 | :param limit_rows: the number of rows to preview.
147 | :param limit_columns: the number of columns to preview. If None, all columns are previewed.
148 | :return: the sample dataset with metadata as a string.
149 | """
150 | try:
151 | df = pd.read_csv(path)
152 | num_rows, num_cols = df.shape
153 | summary = [f"CSV file in `{path}` has {num_rows} rows and {num_cols} columns."]
154 | 
155 | if limit_columns is not None and limit_columns < num_cols:
156 | columns_to_preview = sorted(df.columns)[:limit_columns]
157 | summary.append(f"Previewing {limit_columns} out of {num_cols} columns.")
158 | else:
159 | columns_to_preview = sorted(df.columns)
160 | 
161 | summary.append("Here is some information about the columns:")
162 | 
163 | for col in columns_to_preview:
164 | dtype = df[col].dtype
165 | name = f"{col} ({dtype})"
166 | nan_count = df[col].isnull().sum()
167 | if dtype == "bool":
168 | true_percentage = df[col].mean() * 100
169 | summary.append(f"{name} is {true_percentage:.2f}% True, {100 - true_percentage:.2f}% False")
170 | elif df[col].nunique() < 10:
171 | unique_values = df[col].unique().tolist()
172 | summary.append(f"{name} has {df[col].nunique()} unique values: {unique_values}")
173 | elif is_numeric_dtype(df[col]):
174 | min_val, max_val = df[col].min(), df[col].max()
175 | summary.append(f"{name} has range: {min_val:.2f} - {max_val:.2f}, {nan_count} NaN values")
176 | elif dtype == "object":
177 | unique_count = df[col].nunique()
178 | example_values = df[col].value_counts().head(limit_rows).index.tolist()
179 | summary.append(f"{name} has {unique_count} unique values. Some example values: {example_values}")
180 | 
181 | return textwrap.dedent("\n".join(summary)).strip()
182 | except Exception as e:
183 | return f"Cannot read CSV data: {e}"
184 | 
--------------------------------------------------------------------------------
/mle/function/execution.py:
--------------------------------------------------------------------------------
1 | """
2 | Tools to execute functions, acquire runtime logs.
3 | """
4 | 
5 | import subprocess
6 | from collections import deque
7 | 
8 | 
9 | def execute_command(command: str, max_lines: int = 30):
10 | """
11 | Run a command in the shell and return the outputs, errors, and exit status,
12 | limiting the output to a specified number of most recent lines.
13 | 
14 | Args:
15 | command (str): The input command to run.
16 | max_lines (int): Maximum number of output lines to keep. Defaults to 30.
17 | 
18 | Return: A string of the exit status and the limited output (most recent lines).
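
    Example (a minimal usage sketch; the command below is illustrative, not from the original source):
        result = execute_command("python train.py", max_lines=20)
        print(result)  # e.g. "Exit code: 0\nOutput: ..." or only the last 20 lines if the output is longer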
19 | """ 20 | try: 21 | process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) 22 | output_buffer = deque(maxlen=max_lines) 23 | 24 | while True: 25 | line = process.stdout.readline() 26 | if not line and process.poll() is not None: 27 | break 28 | output_buffer.append(line.rstrip()) 29 | print(line, end='') 30 | 31 | exit_code = process.wait() 32 | 33 | limited_output = "\n".join(output_buffer) 34 | if len(output_buffer) == max_lines: 35 | return f"Exit code: {exit_code}\nOutput (last {max_lines} lines):\n{limited_output}" 36 | else: 37 | return f"Exit code: {exit_code}\nOutput:\n{limited_output}" 38 | 39 | except Exception as e: 40 | return f"Error running command: {str(e)}" 41 | -------------------------------------------------------------------------------- /mle/function/files.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def read_file(file_path: str, limit: int = 2000): 5 | """ 6 | Reads the contents of a file and returns it as a string. 7 | 8 | Args: 9 | file_path (str): The path to the file that needs to be read. 10 | limit (int, optional): Maximum number of lines to read. 11 | 12 | Returns: 13 | str: The contents of the file as a string. 14 | """ 15 | try: 16 | with open(file_path, 'r', encoding='utf-8') as file: 17 | if limit <= 0: 18 | return file.read() 19 | lines = [] 20 | for i, line in enumerate(file): 21 | if i >= limit: 22 | break 23 | lines.append(line) 24 | return ''.join(lines) 25 | except FileNotFoundError: 26 | return f"File not found: {file_path}" 27 | 28 | 29 | def create_file(path, content): 30 | """ 31 | Create a file with the given path and content. 32 | Args: 33 | path (str): The path to the file to create. 34 | content (str): The initial content to write to the file. 35 | """ 36 | try: 37 | with open(path, 'w') as f: 38 | f.write(content) 39 | return f"File created: {path}" 40 | except Exception as e: 41 | return f"Error creating file: {str(e)}" 42 | 43 | 44 | def write_file(path, content): 45 | """ 46 | Write content to a file. 47 | Args: 48 | path (str): The path to the file to write to. 49 | content (str): The content to write to the file. 50 | """ 51 | try: 52 | with open(path, 'w') as f: 53 | f.write(content) 54 | return f"Content written to file: {path}" 55 | except Exception as e: 56 | return f"Error writing to file: {str(e)}" 57 | 58 | 59 | def list_files(path, limit=50): 60 | """ 61 | Lists files and directories under the given path if it is a directory, 62 | up to a specified limit. 63 | 64 | Args: 65 | path (str): The file system path to check and list contents from. 66 | limit (int): Maximum number of items to list. Defaults to 50. 67 | 68 | Returns: A string containing the list of file and directory names under 69 | the given path, or a message if the path is a file or if the 70 | number of items exceeds the limit. 71 | """ 72 | if os.path.isfile(path): 73 | return "The given path is a file. Please provide a path of a directory." 74 | 75 | try: 76 | files = os.listdir(path) 77 | except PermissionError: 78 | return "Permission denied to access this directory." 79 | except FileNotFoundError: 80 | return "The specified directory does not exist." 81 | except Exception as e: 82 | return f"An error occurred: {str(e)}" 83 | 84 | total_files = len(files) 85 | 86 | if total_files > limit: 87 | files = files[:limit] 88 | output = "\n".join(files) 89 | output += f"\n\n... 
and {total_files - limit} more items (total of {total_files} items)" 90 | else: 91 | output = "\n".join(files) 92 | output += f"\n\nTotal items: {total_files}" 93 | 94 | return output 95 | 96 | 97 | def create_directory(path: str): 98 | """ 99 | Create a directory if it does not exist. 100 | Args: 101 | path (str): The path to the directory to create. 102 | """ 103 | try: 104 | os.makedirs(path, exist_ok=True) 105 | return f"Directory '{path}' created successfully." 106 | except OSError as error: 107 | return f"Creation of the directory '{path}' failed due to: {error}" 108 | -------------------------------------------------------------------------------- /mle/function/interaction.py: -------------------------------------------------------------------------------- 1 | import questionary 2 | 3 | 4 | def ask_question(question: str): 5 | """ 6 | Ask a question to the user. 7 | """ 8 | answer = input("[ADVISOR]: " + question) 9 | return f"Question: {question}\nAnswer: {answer}" 10 | 11 | 12 | def ask_yes_no(question: str): 13 | """ 14 | Ask a yes/no question to the user. 15 | """ 16 | return questionary.confirm("[ADVISOR]: " + question).ask() 17 | 18 | 19 | def ask_choices(question: str, choices: list): 20 | """ 21 | Ask a multiple choice question to the user. 22 | """ 23 | return "More details: " + questionary.select("[ADVISOR]: " + question, choices=choices).ask() 24 | -------------------------------------------------------------------------------- /mle/function/search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Search API functions based on Tavily. 3 | """ 4 | import os 5 | import requests 6 | from tavily import TavilyClient 7 | from xml.etree import ElementTree 8 | 9 | 10 | def search_github_repos(query, limit=5): 11 | """ 12 | Search GitHub public repositories based on a keyword. 13 | 14 | :param query: The query to search for in repository names or descriptions. 15 | :param limit: The total number of repositories to return. 16 | :return: A list of dictionaries containing repository details, limited to the specified number. 17 | """ 18 | repos = [] 19 | per_page = 10 20 | page = 1 21 | while len(repos) < limit: 22 | url = f'https://api.github.com/search/repositories?q={query}&per_page={per_page}&page={page}' 23 | 24 | response = requests.get(url) 25 | 26 | if response.status_code == 200: 27 | items = response.json().get('items', []) 28 | for item in items: 29 | formatted_repo = { 30 | "name": f"{item['owner']['login']}/{item['name']}", 31 | "author": item['owner']['login'], 32 | "description": item['description'], 33 | "link": item['html_url'] 34 | } 35 | repos.append(formatted_repo) 36 | if len(repos) >= limit: 37 | break 38 | 39 | if len(items) < per_page: # Stop if there are no more repos to fetch 40 | break 41 | page += 1 42 | else: 43 | raise Exception(f"GitHub API request failed with status code {response.status_code}: {response.text}") 44 | 45 | return_str = """ 46 | Here are some of the repositories I found on GitHub: 47 | """ 48 | 49 | for repo in repos: 50 | return_str += f""" 51 | Name: {repo['name']} 52 | Description: {repo['description']} 53 | Link: {repo['link']} 54 | """ 55 | 56 | return return_str 57 | 58 | 59 | def web_search(query: str): 60 | """ 61 | Perform a web search based on the query. 62 | Args: 63 | query: The search query. 
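
    Example (a sketch; assumes the SEARCH_API_KEY environment variable holds a valid Tavily API key):
        answer = web_search("current state of the art for ImageNet classification")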
64 | """ 65 | try: 66 | client = TavilyClient(api_key=os.environ['SEARCH_API_KEY']) 67 | response = client.qna_search(query=query, search_depth="advanced") 68 | return response 69 | except Exception as e: 70 | return f"Error performing web search: {str(e)}" 71 | 72 | 73 | def search_arxiv(query, max_results=8): 74 | url = 'https://export.arxiv.org/api/query' 75 | params = { 76 | 'search_query': query, 77 | 'start': 0, 78 | 'max_results': max_results 79 | } 80 | response = requests.get(url, params=params) 81 | if response.status_code != 200: 82 | return f"Error: Unable to fetch data from arXiv (Status code: {response.status_code})" 83 | 84 | root = ElementTree.fromstring(response.content) 85 | output = "" 86 | for entry in root.findall('{http://www.w3.org/2005/Atom}entry'): 87 | title = entry.find('{http://www.w3.org/2005/Atom}title').text 88 | summary = entry.find('{http://www.w3.org/2005/Atom}summary').text 89 | link = entry.find('{http://www.w3.org/2005/Atom}id').text 90 | published = entry.find('{http://www.w3.org/2005/Atom}published').text 91 | authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in 92 | entry.findall('{http://www.w3.org/2005/Atom}author')] 93 | 94 | output += f""" 95 | Title: {title.strip()} 96 | Summary: {summary.strip()} 97 | Link: {link.strip()} 98 | Published: {published.strip()} 99 | Authors: {authors} 100 | """ 101 | 102 | return output 103 | 104 | 105 | def search_papers_with_code(query: str, k: int = 8) -> str: 106 | url = f"https://paperswithcode.com/api/v1/search/" 107 | response = requests.get(url, params={'page': 1, 'q': query}) 108 | if response.status_code != 200: 109 | return "Failed to retrieve data from Papers With Code." 110 | 111 | data = response.json() 112 | if 'results' not in data: 113 | return "No results found for the given query." 
114 | 115 | results = data['results'][:int(k)] # Get top-k results 116 | result_strings = [] 117 | 118 | for result in results: 119 | paper = result['paper'] 120 | paper_title = paper.get('title', 'No title available') 121 | abstract = paper.get('abstract', 'No abstract available') 122 | paper_pdf_url = paper.get('url_pdf', 'No PDF available') 123 | repository = result.get('repository', []) 124 | if repository: 125 | code_url = repository.get('url', 'No official code link available') 126 | else: 127 | code_url = 'No official code link available' 128 | 129 | result_string = f"Title: {paper_title}\nAbstract:{abstract}\nPaper URL: {paper_pdf_url}\nCode URL: {code_url}\n" 130 | result_strings.append(result_string) 131 | 132 | return "\n".join(result_strings) 133 | -------------------------------------------------------------------------------- /mle/integration/__init__.py: -------------------------------------------------------------------------------- 1 | from .local_git import GitIntegration 2 | from .github import GitHubIntegration, github_login 3 | from .google_calendar import GoogleCalendarIntegration, google_calendar_login 4 | from .kaggle import KaggleIntegration 5 | -------------------------------------------------------------------------------- /mle/integration/google_calendar.py: -------------------------------------------------------------------------------- 1 | import json 2 | import datetime 3 | from mle.utils import load_file 4 | from google.auth.transport.requests import Request 5 | from google_auth_oauthlib.flow import InstalledAppFlow 6 | from googleapiclient.discovery import build 7 | 8 | 9 | def google_calendar_login(credential=None): 10 | """ 11 | Authenticate the user using Google OAuth 2.0 and return the credentials. 12 | 13 | :param credential: The client secrets. 14 | :return: Google OAuth 2.0 credentials or None if authentication fails. 15 | """ 16 | 17 | if credential is None: 18 | # FIXME: remove the test app_credential 19 | credential = json.loads(load_file( 20 | "https://raw.githubusercontent.com/leeeizhang/leeeizhang/assets/google_app", 21 | base64_decode=True, 22 | )) 23 | 24 | try: 25 | SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"] 26 | flow = InstalledAppFlow.from_client_config( 27 | credential, 28 | SCOPES, 29 | ) 30 | creds = flow.run_local_server(host="127.0.0.1", port=0) 31 | except Exception: 32 | return None 33 | return creds 34 | 35 | 36 | class GoogleCalendarIntegration: 37 | """ 38 | Class to interface with Google Calendar API to fetch events. 39 | 40 | :param token: Google OAuth 2.0 credentials. 41 | """ 42 | 43 | def __init__(self, token=None): 44 | self.token = token 45 | if self.token.expired and self.token.refresh_token: 46 | self.token.refresh(Request()) 47 | 48 | def get_events(self, start_date=None, end_date=None, limit=100, detailed=True): 49 | """ 50 | Fetch upcoming calendar events. 51 | :param start_date: Start date for calendar events (inclusive), in 'YYYY-MM-DD' format 52 | :param end_date: End date for calendar events (inclusive), in 'YYYY-MM-DD' format 53 | :param limit: The maximum number of events to return. 54 | :param detailed: Whether to return detailed event information. 55 | :return: A list of events with details or None if an error occurs. 
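
    Example (a sketch; assumes `creds` was obtained via google_calendar_login(), and the dates are illustrative):
        calendar = GoogleCalendarIntegration(token=creds)
        events = calendar.get_events(start_date="2024-09-01", end_date="2024-09-07", limit=10)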
56 | """ 57 | try: 58 | # Set default dates if not provided 59 | today = datetime.date.today() 60 | if start_date is None: 61 | start_date = (today - datetime.timedelta(days=7)).isoformat() 62 | if end_date is None: 63 | end_date = (today + datetime.timedelta(days=7)).isoformat() 64 | 65 | # Convert dates to datetime objects with time 66 | start_dt = datetime.datetime.strptime(f"{start_date}T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z") 67 | end_dt = datetime.datetime.strptime(f"{end_date}T23:59:59Z", "%Y-%m-%dT%H:%M:%S%z") 68 | 69 | if start_dt >= end_dt: 70 | raise ValueError("start_date must be less than end_date") 71 | 72 | # Convert back to string format for API call 73 | start_date = start_dt.isoformat() 74 | end_date = end_dt.isoformat() 75 | 76 | # Build the service object for interacting with the Google Calendar API 77 | service = build("calendar", "v3", credentials=self.token) 78 | 79 | # Retrieve the events from the primary calendar 80 | events_result = ( 81 | service.events() 82 | .list( 83 | calendarId="primary", 84 | timeMin=start_date, 85 | timeMax=end_date, 86 | maxResults=limit, 87 | singleEvents=True, 88 | orderBy="startTime", 89 | ) 90 | .execute() 91 | ) 92 | events_result = events_result.get("items", []) 93 | 94 | # Format the events into a specified structure 95 | events = [] 96 | for event in events_result: 97 | e = { 98 | "title": event.get("summary"), 99 | "status": event.get("status"), 100 | "description": event.get("description"), 101 | "creator": event.get("creator"), 102 | "organizer": event.get("organizer"), 103 | "start_time": event["start"].get("dateTime", event["start"].get("date")), 104 | "end_time": event["end"].get("dateTime", event["end"].get("date")) 105 | } 106 | 107 | if detailed: 108 | e.update( 109 | { 110 | "htmlLink": event.get("htmlLink"), 111 | "kind": event.get("kind"), 112 | } 113 | ) 114 | 115 | events.append(e) 116 | return events 117 | 118 | except Exception as e: 119 | print(f"An error occurred: {e}") 120 | return None 121 | -------------------------------------------------------------------------------- /mle/integration/kaggle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import questionary 5 | from zipfile import ZipFile 6 | 7 | 8 | class KaggleIntegration: 9 | 10 | def __init__(self): 11 | """ 12 | Initializes KaggleIntegration with the provided credentials. 13 | """ 14 | 15 | kaggle_file = os.path.join(os.path.expanduser("~"), ".kaggle", "kaggle.json") 16 | 17 | if not os.path.exists(kaggle_file): 18 | username = questionary.text("What is your Kaggle username?").ask() 19 | key = questionary.password("What is your Kaggle token?").ask() 20 | if username and key: 21 | os.makedirs(os.path.dirname(kaggle_file), exist_ok=True) 22 | with open(kaggle_file, "w") as f: 23 | json.dump({"username": username, "key": key}, f) 24 | 25 | from kaggle.api.kaggle_api_extended import KaggleApi 26 | self.api = KaggleApi() 27 | self.api.authenticate() 28 | 29 | def list_competition(self): 30 | """ 31 | Lists all Kaggle competitions. 32 | :return: A tuple containing references of all competitions. 33 | """ 34 | competitions = self.api.competitions_list() 35 | return tuple([comp.ref for comp in competitions]) 36 | 37 | def download_competition_dataset( 38 | self, competition: str, download_dir: str = "./data" 39 | ): 40 | """ 41 | Downloads and extracts the dataset for a specific competition. 42 | :param competition: The URL or name of the Kaggle competition. 
43 | :param download_dir: Directory to save the downloaded files. Defaults to './data'. 44 | :return: The directory where the dataset has been downloaded and extracted. 45 | """ 46 | if competition.startswith("https://www.kaggle.com/competitions/"): 47 | competition = competition.split("/")[-1] 48 | 49 | os.makedirs(download_dir, exist_ok=True) 50 | self.api.competition_download_files(competition, path=download_dir) 51 | 52 | # Unzip downloaded files 53 | for file in os.listdir(download_dir): 54 | if file.endswith(".zip"): 55 | with ZipFile(os.path.join(download_dir, file), "r") as zip_ref: 56 | zip_ref.extractall(download_dir) 57 | return download_dir 58 | 59 | def fetch_competition_overview(self, competition: str): 60 | """ 61 | Fetches competition overview information using the Kaggle API. 62 | :param competition: The URL or name of the Kaggle competition. 63 | :return: A dictionary containing competition overview information, or None if not found. 64 | """ 65 | for _ in range(3): # Retry 3 times if the request fails 66 | try: 67 | reader_url = f"https://r.jina.ai/{competition}/overview/description" 68 | response = requests.get( 69 | reader_url, 70 | timeout=30, 71 | headers={"X-Return-Format": "markdown"}, 72 | ) 73 | response.raise_for_status() 74 | overview = response.text 75 | break 76 | except requests.exceptions.HTTPError: 77 | continue 78 | return overview.encode('utf-8', 'ignore').decode('utf-8') 79 | -------------------------------------------------------------------------------- /mle/integration/local_git.py: -------------------------------------------------------------------------------- 1 | from git import Repo, NULL_TREE 2 | from datetime import datetime, timezone, timedelta 3 | 4 | import os 5 | import fnmatch 6 | import subprocess 7 | 8 | class GitIntegration: 9 | def __init__(self, path): 10 | self.repo_path = path 11 | self.repo = Repo(self.repo_path) 12 | if self.repo.bare: 13 | raise Exception("Repository is not valid or is bare.") 14 | 15 | def get_repo_status(self): 16 | """ 17 | Get the status of a git repository 18 | :return: List of changed files 19 | """ 20 | try: 21 | changed_files = [] 22 | for diff in self.repo.index.diff(None): 23 | changed_files.append({ 24 | 'file_path': diff.a_path, 25 | 'change_type': diff.change_type, 26 | 'author': self.repo.head.commit.author.name, 27 | 'date': datetime.fromtimestamp(self.repo.head.commit.committed_date).strftime("%Y-%m-%d %H:%M:%S") 28 | }) 29 | 30 | return changed_files 31 | 32 | except Exception as e: 33 | return f"An error occurred: {str(e)}" 34 | 35 | def get_commit_history(self, start_date=None, end_date=None, email=None, limit=None): 36 | """ 37 | Process commit history within a specified date range and for a specific user (email). 
38 | :param start_date: Start date for commit range (inclusive), in 'YYYY-MM-DD' format 39 | :param end_date: End date for commit range (inclusive), in 'YYYY-MM-DD' format 40 | :param username: GitHub username to filter commits (optional) 41 | :param limit: Maximum number of commits to retrieve (default is None, which retrieves all commits in range) 42 | :return: Dictionary of commits 43 | """ 44 | end_time = None 45 | if end_date is not None: 46 | end_time = f"{end_date}T23:59:59Z" 47 | 48 | start_time = None 49 | if start_date is not None: 50 | start_time = f"{start_date}T00:00:00Z" 51 | 52 | try: 53 | commit_history = [] 54 | for commit in self.repo.iter_commits(max_count=limit): 55 | commit_date = datetime.fromtimestamp(commit.committed_date) 56 | commit_date = commit_date.replace(tzinfo=timezone.utc) 57 | if start_time is not None and commit_date < datetime.fromisoformat(start_time): 58 | continue 59 | 60 | if end_time is not None and commit_date > datetime.fromisoformat(end_time): 61 | continue 62 | 63 | if email is not None and commit.author.email != email: 64 | continue 65 | 66 | commit_history.append({ 67 | 'commit_hash': commit.hexsha, 68 | 'author': commit.author.name, 69 | 'email': commit.author.email, 70 | 'message': commit.message.strip(), 71 | 'date': commit_date.strftime("%Y-%m-%d %H:%M:%S") 72 | }) 73 | 74 | return commit_history 75 | 76 | except Exception as e: 77 | return f"An error occurred: {str(e)}" 78 | 79 | def get_commit_diff(self, commit_hash, show_content=False): 80 | """ 81 | Get the code changes for a specific commit 82 | :param commit_hash: The hash of the commit to get the changes for 83 | :param show_content: Boolean to determine if file content should be included (default False) 84 | :return: A dictionary containing the commit info and the code changes 85 | """ 86 | try: 87 | commit = self.repo.commit(commit_hash) 88 | parent = commit.parents[0] if commit.parents else NULL_TREE 89 | 90 | diffs = parent.diff(commit) 91 | 92 | changes = {} 93 | for diff in diffs: 94 | if show_content: 95 | if diff.a_blob and diff.b_blob: 96 | a_content = diff.a_blob.data_stream.read().decode('utf-8', errors='ignore') 97 | b_content = diff.b_blob.data_stream.read().decode('utf-8', errors='ignore') 98 | changes[diff.a_path] = { 99 | 'change_type': 'modified', 100 | 'old_content': a_content, 101 | 'new_content': b_content 102 | } 103 | elif diff.a_blob: 104 | changes[diff.a_path] = { 105 | 'change_type': 'deleted', 106 | 'old_content': diff.a_blob.data_stream.read().decode('utf-8', errors='ignore'), 107 | 'new_content': None 108 | } 109 | elif diff.b_blob: 110 | changes[diff.b_path] = { 111 | 'change_type': 'added', 112 | 'old_content': None, 113 | 'new_content': diff.b_blob.data_stream.read().decode('utf-8', errors='ignore') 114 | } 115 | else: 116 | if diff.a_blob and diff.b_blob: 117 | changes[diff.a_path] = {'change_type': 'modified'} 118 | elif diff.a_blob: 119 | changes[diff.a_path] = {'change_type': 'deleted'} 120 | elif diff.b_blob: 121 | changes[diff.b_path] = {'change_type': 'added'} 122 | 123 | commit_info = { 124 | 'commit_hash': commit.hexsha, 125 | 'author': commit.author.name, 126 | 'email': commit.author.email, 127 | 'message': commit.message.strip(), 128 | 'date': datetime.fromtimestamp(commit.committed_date).strftime("%Y-%m-%d %H:%M:%S"), 129 | 'changes': changes 130 | } 131 | 132 | return commit_info 133 | 134 | except Exception as e: 135 | return f"An error occurred: {str(e)}" 136 | 137 | def get_source_code(self, file_pattern="*"): 138 | """ 139 | Process 
source code files in the repository. 140 | :param file_pattern: Wildcard pattern to filter files (e.g., "*.py" for Python files) 141 | :return: Dictionary with file paths as keys and file contents as values 142 | """ 143 | 144 | def get_contents(path="", file_pattern=file_pattern): 145 | for root, _, files in os.walk(os.path.join(self.repo_path, path)): 146 | for filename in fnmatch.filter(files, file_pattern): 147 | file_path = os.path.join(root, filename) 148 | with open(file_path, 'r') as f: 149 | yield { 150 | 'path': os.path.relpath(file_path, self.repo_path), 151 | 'name': filename, 152 | 'content': f.read() 153 | } 154 | 155 | return {file['path']: file['content'] for file in get_contents()} 156 | 157 | def get_readme(self): 158 | """ 159 | Get readme content of the repository. 160 | :return: The readme content 161 | """ 162 | content = self.get_source_code("README.md") 163 | if len(content): 164 | return list(content.values())[0] 165 | return None 166 | 167 | def get_structure(self, path=''): 168 | """ 169 | Scan and return the file structure and file names of the Git repository as a list of paths. 170 | :param path: The path to start scanning from (default is root) 171 | :param branch: The branch to scan (if None, the repository's default branch will be used) 172 | :param include_invisible: Whether to include invisible files/folders (starting with .) (default is False) 173 | :return: A list of file paths in the repository 174 | """ 175 | result = subprocess.run( 176 | ["git", "-C", path, "ls-files"], 177 | stdout=subprocess.PIPE, 178 | text=True, 179 | check=True 180 | ) 181 | 182 | return result.stdout.splitlines() 183 | 184 | def get_user_activity(self, email, start_date=None, end_date=None): 185 | """ 186 | Aggregate information about a user's activity within a specific time period. 187 | :param email: User email to analyze 188 | :param start_date: Start date for the analysis period, in 'YYYY-MM-DD' format 189 | :param end_date: End date for the analysis period, in 'YYYY-MM-DD' format 190 | :return: Dictionary containing aggregated user activity information, if the 191 | start and end dates are not provided, the default period is the last 7 days. 
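
    Example (a sketch; the repository path and email below are illustrative):
        repo = GitIntegration("/path/to/local/repo")
        report = repo.get_user_activity("dev@example.com", start_date="2024-09-01", end_date="2024-09-07")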
192 | """
193 | if end_date is None:
194 | end_datetime = datetime.now(timezone.utc).replace(hour=23, minute=59, second=59, microsecond=0)
195 | end_date = end_datetime.strftime("%Y-%m-%d")
196 | else:
197 | end_datetime = (datetime.strptime(end_date, "%Y-%m-%d")
198 | .replace(hour=23, minute=59, second=59, tzinfo=timezone.utc))
199 | 
200 | if start_date is None:
201 | start_datetime = end_datetime - timedelta(days=6)
202 | start_date = start_datetime.strftime("%Y-%m-%d")
203 | 
204 | # Fetch data
205 | commits = self.get_commit_history(start_date, end_date, email)
206 | 
207 | # Aggregate commit information
208 | commit_count = len(commits)
209 | commit_messages = [commit['message'] for commit in commits]
210 | 
211 | # Compile the report
212 | report = {
213 | 'username': email,
214 | 'period': {
215 | 'start': start_date,
216 | 'end': end_date
217 | },
218 | 'summary': {
219 | 'total_commits': commit_count,
220 | 'total_pull_requests': 0,
221 | 'total_issues': 0
222 | },
223 | 'commits': {
224 | 'count': commit_count,
225 | 'messages': commit_messages
226 | },
227 | 'pull_requests': {
228 | 'count': 0,
229 | 'details': []
230 | },
231 | 'issues': {
232 | 'count': 0,
233 | 'details': []
234 | }
235 | }
236 | 
237 | return report
238 | 
--------------------------------------------------------------------------------
/mle/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .anthropic import *
2 | from .deepseek import *
3 | from .mistral import *
4 | from .ollama import *
5 | from .openai import *
6 | from .gemini import *
7 | from .vllm import *
8 | 
9 | from mle.utils import get_config
10 | 
11 | 
12 | MODEL_OLLAMA = 'Ollama'
13 | MODEL_OPENAI = 'OpenAI'
14 | MODEL_CLAUDE = 'Claude'
15 | MODEL_MISTRAL = 'MistralAI'
16 | MODEL_DEEPSEEK = 'DeepSeek'
17 | MODEL_GEMINI = 'Gemini'
18 | MODEL_VLLM = 'vLLM'
19 | 
20 | 
21 | class ObservableModel:
22 | """
23 | A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
24 | """
25 | 
26 | try:
27 | from mle.utils import get_langfuse_observer
28 | _observe = get_langfuse_observer()
29 | except Exception:
30 | # If importing fails, set _observe to a lambda function that does nothing.
31 | _observe = lambda fn: fn
32 | 
33 | def __init__(self, model: Model):
34 | """
35 | Initialize the ObservableModel.
36 | Args:
37 | model: The model to be wrapped and made observable.
38 | """
39 | self.model = model
40 | 
41 | @_observe
42 | def query(self, *args, **kwargs):
43 | return self.model.query(*args, **kwargs)
44 | 
45 | @_observe
46 | def stream(self, *args, **kwargs):
47 | # delegate to the wrapped model's stream method (not query)
48 | return self.model.stream(*args, **kwargs)
49 | 
50 | 
51 | def load_model(project_dir: str, model_name: str=None, observable=True):
52 | """
53 | load_model: load the model based on the configuration.
54 | Args:
55 | project_dir (str): The project directory.
56 | model_name (str): The model name.
57 | observable (boolean): Whether the model should be tracked.
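
    Example (a minimal sketch; assumes the project has a valid config, and the model name is illustrative):
        model = load_model("./my_project", model_name="gpt-4o")
        reply = model.query([{"role": "user", "content": "Summarize the project."}])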
57 | """
58 | config = get_config(project_dir)
59 | model = None
60 | 
61 | if config['platform'] == MODEL_OLLAMA:
62 | model = OllamaModel(model=model_name)
63 | if config['platform'] == MODEL_OPENAI:
64 | model = OpenAIModel(api_key=config['api_key'], model=model_name)
65 | if config['platform'] == MODEL_CLAUDE:
66 | model = ClaudeModel(api_key=config['api_key'], model=model_name)
67 | if config['platform'] == MODEL_MISTRAL:
68 | model = MistralModel(api_key=config['api_key'], model=model_name)
69 | if config['platform'] == MODEL_DEEPSEEK:
70 | model = DeepSeekModel(api_key=config['api_key'], model=model_name)
71 | if config['platform'] == MODEL_GEMINI:
72 | model = GeminiModel(api_key=config['api_key'], model=model_name)
73 | if config['platform'] == MODEL_VLLM:
74 | model = vLLMModel(base_url=config.get('base_url', 'http://localhost:8000/v1'), model=model_name)
75 | 
76 | if observable:
77 | return ObservableModel(model)
78 | return model
79 | 
--------------------------------------------------------------------------------
/mle/model/anthropic.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | 
3 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
4 | from mle.model.common import Model
5 | 
6 | 
7 | class ClaudeModel(Model):
8 | def __init__(self, api_key, model, temperature=0.7):
9 | """
10 | Initialize the Claude model.
11 | Args:
12 | api_key (str): The Anthropic API key.
13 | model (str): The model with version.
14 | temperature (float): The temperature value.
15 | """
16 | super().__init__()
17 | 
18 | dependency = "anthropic"
19 | spec = importlib.util.find_spec(dependency)
20 | if spec is not None:
21 | self.anthropic = importlib.import_module(dependency).Anthropic
22 | else:
23 | raise ImportError(
24 | "It seems you didn't install anthropic. In order to enable the Anthropic client related features, "
25 | "please make sure the anthropic Python package has been installed. "
26 | "More information, please refer to: https://docs.anthropic.com/en/api/client-sdks"
27 | )
28 | 
29 | self.model = model if model else 'claude-3-5-sonnet-20240620'
30 | self.model_type = 'Claude'
31 | self.temperature = temperature
32 | self.client = self.anthropic(api_key=api_key)
33 | self.func_call_history = []
34 | 
35 | @staticmethod
36 | def _add_tool_result_into_chat_history(chat_history, func, result):
37 | """
38 | Add the result of tool calls into messages.
39 | """
40 | return chat_history.extend([
41 | {
42 | "role": "assistant",
43 | "content": [
44 | {
45 | "type": "tool_use",
46 | "id": func.id,
47 | "name": func.name,
48 | "input": func.input,
49 | },
50 | ]
51 | },
52 | {
53 | "role": "user",
54 | "content": [
55 | {
56 | "type": "tool_result",
57 | "tool_use_id": func.id,
58 | "content": result,
59 | },
60 | ]
61 | },
62 | ])
63 | 
64 | def query(self, chat_history, **kwargs):
65 | """
66 | Query the LLM model.
67 | 
68 | Args:
69 | chat_history: The context (chat history).
70 | """
71 | # Claude does not support a system role in chat_history
72 | # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
73 | system_prompt = ""
74 | for idx, msg in enumerate(chat_history):
75 | if msg["role"] == "system":
76 | system_prompt += msg["content"]
77 | 
78 | # Claude does not support a manual `response_format`, so we append it to the system prompt
79 | if "response_format" in kwargs.keys():
80 | system_prompt += (
81 | f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
82 | )
83 | 
84 | # mapping the openai function_schema to claude tool_schema
85 | tools = kwargs.get("functions",[])
86 | for tool in tools:
87 | if "parameters" in tool.keys():
88 | tool["input_schema"] = tool["parameters"]
89 | del tool["parameters"]
90 | 
91 | completion = self.client.messages.create(
92 | max_tokens=4096,
93 | model=self.model,
94 | system=system_prompt,
95 | messages=[msg for msg in chat_history if msg["role"] != "system"],
96 | temperature=self.temperature,
97 | stream=False,
98 | tools=tools,
99 | )
100 | if completion.stop_reason == "tool_use":
101 | for func in completion.content:
102 | if func.type != "tool_use":
103 | continue
104 | function_name = process_function_name(func.name)
105 | arguments = func.input
106 | print("[MLE FUNC CALL]: ", function_name)
107 | self.func_call_history.append({"name": function_name, "arguments": arguments})
108 | # avoid the multiple search function calls
109 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
110 | if len(search_attempts) > 3:
111 | kwargs['functions'] = []
112 | result = get_function(function_name)(**arguments)
113 | self._add_tool_result_into_chat_history(chat_history, func, result)
114 | return self.query(chat_history, **kwargs)
115 | else:
116 | return completion.content[0].text
117 | 
118 | def stream(self, chat_history, **kwargs):
119 | """
120 | Stream the output from the LLM model.
121 | Args:
122 | chat_history: The context (chat history).
123 | """
124 | # Claude does not support a system role in chat_history
125 | # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
126 | system_prompt = ""
127 | for idx, msg in enumerate(chat_history):
128 | if msg["role"] == "system":
129 | system_prompt += msg["content"]
130 | chat_history = [msg for msg in chat_history if msg["role"] != "system"]
131 | 
132 | # Claude does not support a manual `response_format`, so we append it to the system prompt
133 | if "response_format" in kwargs.keys():
134 | system_prompt += (
135 | f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
136 | )
137 | 
138 | with self.client.messages.stream(
139 | max_tokens=4096,
140 | model=self.model,
141 | messages=chat_history, system=system_prompt,
142 | ) as stream:
143 | for chunk in stream.text_stream:
144 | yield chunk
145 | 
--------------------------------------------------------------------------------
/mle/model/common.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | 
4 | class Model(ABC):
5 | 
6 | def __init__(self):
7 | """
8 | Initialize the model.
9 | """ 10 | self.model_type = None 11 | 12 | @abstractmethod 13 | def query(self, chat_history, **kwargs): 14 | pass 15 | 16 | @abstractmethod 17 | def stream(self, chat_history, **kwargs): 18 | pass 19 | -------------------------------------------------------------------------------- /mle/model/deepseek.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | 4 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name 5 | from mle.model.common import Model 6 | 7 | 8 | class DeepSeekModel(Model): 9 | def __init__(self, api_key, model, temperature=0.7): 10 | """ 11 | Initialize the DeepSeek model. 12 | Args: 13 | api_key (str): The DeepSeek API key. 14 | model (str): The model with version. 15 | temperature (float): The temperature value. 16 | """ 17 | super().__init__() 18 | 19 | dependency = "openai" 20 | spec = importlib.util.find_spec(dependency) 21 | if spec is not None: 22 | self.openai = importlib.import_module(dependency).OpenAI 23 | else: 24 | raise ImportError( 25 | "It seems you didn't install openai. In order to enable the OpenAI client related features, " 26 | "please make sure openai Python package has been installed. " 27 | "More information, please refer to: https://openai.com/product" 28 | ) 29 | self.model = model if model else "deepseek-coder" 30 | self.model_type = 'DeepSeek' 31 | self.temperature = temperature 32 | self.client = self.openai( 33 | api_key=api_key, base_url="https://api.deepseek.com/beta" 34 | ) 35 | self.func_call_history = [] 36 | 37 | def _convert_functions_to_tools(self, functions): 38 | """ 39 | Convert OpenAI-style functions to DeepSeek-style tools. 40 | """ 41 | tools = [] 42 | for func in functions: 43 | tool = { 44 | "type": "function", 45 | "function": { 46 | "name": func["name"], 47 | "description": func.get("description", ""), 48 | "parameters": func["parameters"], 49 | }, 50 | } 51 | tools.append(tool) 52 | return tools 53 | 54 | def query(self, chat_history, **kwargs): 55 | """ 56 | Query the LLM model. 57 | 58 | Args: 59 | chat_history: The context (chat history). 
60 | """ 61 | functions = kwargs.get("functions", None) 62 | tools = self._convert_functions_to_tools(functions) if functions else None 63 | parameters = kwargs 64 | completion = self.client.chat.completions.create( 65 | model=self.model, 66 | messages=chat_history, 67 | temperature=self.temperature, 68 | stream=False, 69 | tools=tools, 70 | **parameters, 71 | ) 72 | 73 | resp = completion.choices[0].message 74 | if resp.tool_calls: 75 | for tool_call in resp.tool_calls: 76 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False}) 77 | function_name = process_function_name(tool_call.function.name) 78 | arguments = json.loads(tool_call.function.arguments) 79 | print("[MLE FUNC CALL]: ", function_name) 80 | self.func_call_history.append({"name": function_name, "arguments": arguments}) 81 | # avoid the multiple search function calls 82 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS] 83 | if len(search_attempts) > 3: 84 | parameters['tool_choice'] = "none" 85 | result = get_function(function_name)(**arguments) 86 | chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id}) 87 | return self.query(chat_history, **parameters) 88 | else: 89 | return resp.content 90 | 91 | def stream(self, chat_history, **kwargs): 92 | """ 93 | Stream the output from the LLM model. 94 | Args: 95 | chat_history: The context (chat history). 96 | """ 97 | arguments = "" 98 | function_name = "" 99 | for chunk in self.client.chat.completions.create( 100 | model=self.model, 101 | messages=chat_history, 102 | temperature=self.temperature, 103 | stream=True, 104 | **kwargs, 105 | ): 106 | if chunk.choices[0].delta.tool_calls: 107 | tool_call = chunk.choices[0].delta.tool_calls[0] 108 | if tool_call.function.name: 109 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False}) 110 | function_name = process_function_name(tool_call.function.name) 111 | arguments = json.loads(tool_call.function.arguments) 112 | result = get_function(function_name)(**arguments) 113 | chat_history.append({"role": "tool", "content": result, "name": function_name}) 114 | yield from self.stream(chat_history, **kwargs) 115 | else: 116 | yield chunk.choices[0].delta.content 117 | -------------------------------------------------------------------------------- /mle/model/gemini.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib.util 3 | import json 4 | 5 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name 6 | from mle.model.common import Model 7 | 8 | 9 | class GeminiModel(Model): 10 | 11 | def __init__(self, api_key, model, temperature=0.7): 12 | """ 13 | Initialize the Gemini model. 14 | Args: 15 | api_key (str): The Gemini API key. 16 | model (str): The model with version. 17 | temperature (float): The temperature value. 18 | """ 19 | super().__init__() 20 | 21 | dependency = "google.generativeai" 22 | spec = importlib.util.find_spec(dependency) 23 | if spec is not None: 24 | self.gemini = importlib.import_module(dependency) 25 | self.gemini.configure(api_key=api_key) 26 | else: 27 | raise ImportError( 28 | "It seems you didn't install `google-generativeai`. " 29 | "In order to enable the Gemini client related features, " 30 | "please make sure gemini Python package has been installed. 
" 31 | "More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python" 32 | ) 33 | 34 | self.model = model if model else 'gemini-1.5-flash' 35 | self.model_type = 'Gemini' 36 | self.temperature = temperature 37 | self.func_call_history = [] 38 | 39 | def _map_chat_history_from_openai(self, chat_history): 40 | _key_map_dict = { 41 | "role": "role", 42 | "content": "parts", 43 | } 44 | _value_map_dict = { 45 | "system": "model", 46 | "user": "user", 47 | "assistant": "model", 48 | "content": "parts", 49 | } 50 | return [ 51 | { 52 | _key_map_dict.get(k, k): _value_map_dict.get(v, v) 53 | for k, v in dict(chat).items() 54 | } for chat in chat_history 55 | ] 56 | 57 | def _map_functions_from_openai(self, functions): 58 | def _mapping_type(_type: str): 59 | if _type == "string": 60 | return self.gemini.protos.Type.STRING 61 | if _type == "object": 62 | return self.gemini.protos.Type.OBJECT 63 | if _type == "integer": 64 | return self.gemini.protos.Type.NUMBER 65 | if _type == "boolean": 66 | return self.gemini.protos.Type.BOOLEAN 67 | if _type == "array": 68 | return self.gemini.protos.Type.ARRAY 69 | return self.gemini.protos.Type.TYPE_UNSPECIFIED 70 | 71 | return self.gemini.protos.Tool(function_declarations=[ 72 | self.gemini.protos.FunctionDeclaration( 73 | name=func.get("name"), 74 | description=func.get("description"), 75 | parameters=self.gemini.protos.Schema( 76 | type=_mapping_type(func.get("parameters", {}).get("type")), 77 | properties={ 78 | param_name: self.gemini.protos.Schema( 79 | type=_mapping_type(properties.get("type")), 80 | description=properties.get("description") 81 | ) 82 | for param_name, properties in \ 83 | func.get("parameters",{}).get("properties", {}).items() 84 | }, 85 | required=[key for key in func.get("parameters",{}).get("properties", {}).keys()], 86 | ) 87 | ) 88 | for func in functions 89 | ]) 90 | 91 | def _mapping_response_format_from_openai(self, response_format): 92 | if response_format.get("type") == "json_object": 93 | return "application/json" 94 | return None 95 | 96 | def query(self, chat_history, **kwargs): 97 | """ 98 | Query the LLM model. 99 | 100 | Args: 101 | chat_history: The context (chat history). 
102 | """ 103 | parameters = kwargs 104 | chat_history = self._map_chat_history_from_openai(chat_history) 105 | 106 | tools = None 107 | if parameters.get("functions") is not None: 108 | tools = self._map_functions_from_openai(parameters["functions"]) 109 | 110 | client = self.gemini.GenerativeModel(self.model) 111 | chat_handler = client.start_chat(history=chat_history[:-1]) 112 | 113 | completion = chat_handler.send_message( 114 | chat_history[-1]["parts"], 115 | tools=tools, 116 | generation_config=self.gemini.types.GenerationConfig( 117 | max_output_tokens=4096, 118 | temperature=self.temperature, 119 | ), 120 | ) 121 | 122 | function_outputs = {} 123 | for part in completion.parts: 124 | fn = part.function_call 125 | if fn: 126 | print("[MLE FUNC CALL]: ", fn.name) 127 | # avoid the multiple search function calls 128 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS] 129 | if len(search_attempts) > 3: 130 | parameters['functions'] = None 131 | result = get_function(fn.name)(**dict(fn.args)) 132 | function_outputs[fn.name] = result 133 | 134 | if len(function_outputs): 135 | response_parts = [ 136 | self.gemini.protos.Part( 137 | function_response=self.gemini.protos.FunctionResponse( 138 | name=fn, response={"result": val} 139 | ) 140 | ) 141 | for fn, val in function_outputs.items() 142 | ] 143 | 144 | completion = chat_handler.send_message( 145 | self.gemini.protos.Content(parts=response_parts), 146 | generation_config=self.gemini.types.GenerationConfig( 147 | max_output_tokens=4096, 148 | temperature=self.temperature, 149 | response_mime_type=self._mapping_response_format_from_openai( 150 | parameters.get("response_format", {})), 151 | ), 152 | ) 153 | 154 | return completion.text 155 | 156 | def stream(self, chat_history, **kwargs): 157 | """ 158 | Stream the output from the LLM model. 159 | Args: 160 | chat_history: The context (chat history). 161 | """ 162 | client = self.gemini.GenerativeModel(self.model) 163 | chat_handler = client.start_chat(history=chat_history[:-1]) 164 | 165 | completions = chat_handler.send_message( 166 | chat_history[-1]["parts"], 167 | stream=True, 168 | generation_config=self.gemini.types.GenerationConfig( 169 | max_output_tokens=4096, 170 | temperature=self.temperature, 171 | ), 172 | ) 173 | 174 | for chunk in completions: 175 | yield chunk.text 176 | -------------------------------------------------------------------------------- /mle/model/mistral.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | 4 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name 5 | from mle.model.common import Model 6 | 7 | 8 | class MistralModel(Model): 9 | def __init__(self, api_key, model, temperature=0.7): 10 | """ 11 | Initialize the Mistral model. 12 | Args: 13 | api_key (str): The Mistral API key. 14 | model (str): The model with version. 15 | temperature (float): The temperature value. 16 | """ 17 | super().__init__() 18 | 19 | dependency = "mistralai" 20 | spec = importlib.util.find_spec(dependency) 21 | if spec is not None: 22 | self.mistral = importlib.import_module(dependency).Mistral 23 | else: 24 | raise ImportError( 25 | "It seems you didn't install mistralai. In order to enable the Mistral AI client related features, " 26 | "please make sure mistralai Python package has been installed. 
" 27 | "More information, please refer to: https://github.com/mistralai/client-python" 28 | ) 29 | 30 | self.model = model if model else 'mistral-large-latest' 31 | self.model_type = 'MistralAI' 32 | self.temperature = temperature 33 | self.client = self.mistral(api_key=api_key) 34 | self.func_call_history = [] 35 | 36 | def _convert_functions_to_tools(self, functions): 37 | """ 38 | Convert OpenAI-style functions to Mistral-style tools. 39 | """ 40 | tools = [] 41 | for func in functions: 42 | tool = { 43 | "type": "function", 44 | "function": { 45 | "name": func["name"], 46 | "description": func.get("description", ""), 47 | "parameters": func["parameters"] 48 | } 49 | } 50 | tools.append(tool) 51 | return tools 52 | 53 | def query(self, chat_history, **kwargs): 54 | """ 55 | Query the LLM model. 56 | 57 | Args: 58 | chat_history: The context (chat history). 59 | """ 60 | functions = kwargs.get("functions",[]) 61 | tools = self._convert_functions_to_tools(functions) 62 | tool_choice = kwargs.get('tool_choice', 'any') 63 | parameters = kwargs 64 | completion = self.client.chat.complete( 65 | model=self.model, 66 | messages=chat_history, 67 | temperature=self.temperature, 68 | stream=False, 69 | tools=tools, 70 | tool_choice=tool_choice, 71 | ) 72 | resp = completion.choices[0].message 73 | if resp.tool_calls: 74 | for tool_call in resp.tool_calls: 75 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False}) 76 | function_name = process_function_name(tool_call.function.name) 77 | arguments = json.loads(tool_call.function.arguments) 78 | print("[MLE FUNC CALL]: ", function_name) 79 | self.func_call_history.append({"name": function_name, "arguments": arguments}) 80 | # avoid the multiple search function calls 81 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS] 82 | if len(search_attempts) > 3: 83 | parameters['tool_choice'] = "none" 84 | result = get_function(function_name)(**arguments) 85 | chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id}) 86 | return self.query(chat_history, **parameters) 87 | else: 88 | return resp.content 89 | 90 | def stream(self, chat_history, **kwargs): 91 | """ 92 | Stream the output from the LLM model. 93 | Args: 94 | chat_history: The context (chat history). 
95 | """ 96 | functions = kwargs.get("functions",[]) 97 | tools = self._convert_functions_to_tools(functions) 98 | tool_choice = kwargs.get('tool_choice', 'any') 99 | for chunk in self.client.chat.complete( 100 | model=self.model, 101 | messages=chat_history, 102 | temperature=self.temperature, 103 | stream=True, 104 | tools=tools, 105 | tool_choice=tool_choice 106 | ): 107 | if chunk.choices[0].delta.tool_calls: 108 | tool_call = chunk.choices[0].delta.tool_calls[0] 109 | if tool_call.function.name: 110 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False}) 111 | function_name = process_function_name(tool_call.function.name) 112 | arguments = json.loads(tool_call.function.arguments) 113 | result = get_function(function_name)(**arguments) 114 | chat_history.append({"role": "tool", "content": result, "name": function_name}) 115 | yield from self.stream(chat_history, **kwargs) 116 | else: 117 | yield chunk.choices[0].delta.content 118 | -------------------------------------------------------------------------------- /mle/model/ollama.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import re 3 | 4 | from .common import Model 5 | 6 | 7 | class OllamaModel(Model): 8 | def __init__(self, model, host_url=None): 9 | """ 10 | Initialize the Ollama model. 11 | Args: 12 | host_url (str): The Ollama Host url. 13 | model (str): The model version. 14 | """ 15 | super().__init__() 16 | 17 | dependency = "ollama" 18 | spec = importlib.util.find_spec(dependency) 19 | if spec is not None: 20 | self.model = model if model else 'llama3' 21 | self.model_type = 'Ollama' 22 | self.ollama = importlib.import_module(dependency) 23 | self.client = self.ollama.Client(host=host_url) 24 | else: 25 | raise ImportError( 26 | "It seems you didn't install ollama. In order to enable the Ollama client related features, " 27 | "please make sure ollama Python package has been installed. " 28 | "More information, please refer to: https://github.com/ollama/ollama-python" 29 | ) 30 | 31 | def _clean_think_tags(self, text): 32 | """ 33 | Remove content between <think> tags and empty think tags from the text. 34 | Args: 35 | text (str): The input text to clean. 36 | Returns: 37 | str: The cleaned text with think tags and their content removed. 38 | """ 39 | # Remove content between <think> tags 40 | text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL) 41 | # Remove empty think tags 42 | text = re.sub(r'<think></think>', '', text) 43 | return text.strip() 44 | 45 | def _process_message(self, message, **kwargs): 46 | """ 47 | Process the message before sending to the model. 48 | Args: 49 | message: The message to process. 50 | **kwargs: Additional arguments. 51 | Returns: 52 | dict: The processed message. 53 | """ 54 | if isinstance(message, dict) and 'content' in message: 55 | message['content'] = self._clean_think_tags(message['content']) 56 | return message 57 | 58 | def query(self, chat_history, **kwargs): 59 | """ 60 | Query the LLM model. 61 | Args: 62 | chat_history: The context (chat history). 63 | **kwargs: Additional arguments for the model. 64 | Returns: 65 | str: The model's response.
66 | """ 67 | 68 | # Check if 'response_format' exists in kwargs 69 | format = None 70 | if 'response_format' in kwargs and kwargs['response_format'].get('type') == 'json_object': 71 | format = 'json' 72 | 73 | response = self.client.chat(model=self.model, messages=chat_history, format=format) 74 | return self._clean_think_tags(response['message']['content']) 75 | 76 | def stream(self, chat_history, **kwargs): 77 | """ 78 | Stream the output from the LLM model. 79 | Args: 80 | chat_history: The context (chat history). 81 | **kwargs: Additional arguments for the model. 82 | Yields: 83 | str: Chunks of the model's response. 84 | """ 85 | 86 | for chunk in self.client.chat( 87 | model=self.model, 88 | messages=chat_history, 89 | stream=True 90 | ): 91 | yield self._clean_think_tags(chunk['message']['content']) 92 | -------------------------------------------------------------------------------- /mle/model/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib.util 3 | import json 4 | 5 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name 6 | from mle.model.common import Model 7 | 8 | 9 | class OpenAIModel(Model): 10 | def __init__(self, api_key, model, temperature=0.7): 11 | """ 12 | Initialize the OpenAI model. 13 | Args: 14 | api_key (str): The OpenAI API key. 15 | model (str): The model with version. 16 | temperature (float): The temperature value. 17 | """ 18 | super().__init__() 19 | 20 | dependency = "openai" 21 | spec = importlib.util.find_spec(dependency) 22 | if spec is not None: 23 | self.openai = importlib.import_module(dependency).OpenAI 24 | else: 25 | raise ImportError( 26 | "It seems you didn't install openai. In order to enable the OpenAI client related features, " 27 | "please make sure openai Python package has been installed. " 28 | "More information, please refer to: https://openai.com/product" 29 | ) 30 | 31 | self.model = model if model else 'gpt-4o-2024-08-06' 32 | self.model_type = 'OpenAI' 33 | self.temperature = temperature 34 | self.client = self.openai( 35 | api_key=api_key, 36 | base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"), 37 | ) 38 | self.func_call_history = [] 39 | 40 | def query(self, chat_history, **kwargs): 41 | """ 42 | Query the LLM model. 43 | 44 | Args: 45 | chat_history: The context (chat history). 
46 | """ 47 | parameters = kwargs 48 | completion = self.client.chat.completions.create( 49 | model=self.model, 50 | messages=chat_history, 51 | temperature=self.temperature, 52 | stream=False, 53 | **parameters 54 | ) 55 | 56 | resp = completion.choices[0].message 57 | if resp.function_call: 58 | function_name = process_function_name(resp.function_call.name) 59 | arguments = json.loads(resp.function_call.arguments) 60 | print("[MLE FUNC CALL]: ", function_name) 61 | self.func_call_history.append({"name": function_name, "arguments": arguments}) 62 | # avoid the multiple search function calls 63 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS] 64 | if len(search_attempts) > 3: 65 | parameters['function_call'] = "none" 66 | result = get_function(function_name)(**arguments) 67 | chat_history.append({"role": "assistant", "function_call": dict(resp.function_call)}) 68 | chat_history.append({"role": "function", "content": result, "name": function_name}) 69 | return self.query(chat_history, **parameters) 70 | else: 71 | return resp.content 72 | 73 | def stream(self, chat_history, **kwargs): 74 | """ 75 | Stream the output from the LLM model. 76 | Args: 77 | chat_history: The context (chat history). 78 | """ 79 | arguments = '' 80 | function_name = '' 81 | for chunk in self.client.chat.completions.create( 82 | model=self.model, 83 | messages=chat_history, 84 | temperature=self.temperature, 85 | stream=True, 86 | **kwargs 87 | ): 88 | delta = chunk.choices[0].delta 89 | if delta.function_call: 90 | if delta.function_call.name: 91 | function_name = process_function_name(delta.function_call.name) 92 | if delta.function_call.arguments: 93 | arguments += delta.function_call.arguments 94 | 95 | if chunk.choices[0].finish_reason == "function_call": 96 | result = get_function(function_name)(**json.loads(arguments)) 97 | chat_history.append({"role": "function", "content": result, "name": function_name}) 98 | yield from self.stream(chat_history, **kwargs) 99 | else: 100 | yield delta.content 101 | -------------------------------------------------------------------------------- /mle/model/vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import importlib.util 4 | from typing import List, Dict, Any, Optional 5 | 6 | from mle.model.common import Model 7 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name 8 | 9 | 10 | class vLLMModel(Model): 11 | """vLLM model implementation using OpenAI-compatible API.""" 12 | 13 | def __init__(self, base_url: Optional[str] = None, 14 | model: Optional[str] = None, 15 | temperature: float = 0.7) -> None: 16 | """Initialize the vLLM model. 17 | 18 | Args: 19 | base_url: The URL of the vLLM server. 20 | model: The model name. 21 | temperature: The sampling temperature. 22 | """ 23 | super().__init__() 24 | 25 | dependency = "openai" 26 | spec = importlib.util.find_spec(dependency) 27 | if spec is not None: 28 | self.openai = importlib.import_module(dependency).OpenAI 29 | else: 30 | raise ImportError( 31 | "OpenAI package not found. 
Please install it using: " 32 | "pip install openai" 33 | ) 34 | 35 | self.model = model if model else 'mistralai/Mistral-7B-Instruct-v0.3' 36 | self.model_type = 'vLLM' 37 | self.temperature = temperature 38 | self.client = self.openai( 39 | api_key="EMPTY", 40 | base_url=base_url or os.getenv("vLLM_BASE_URL", 41 | "http://localhost:8000/v1"), 42 | timeout=60.0, 43 | max_retries=2, 44 | ) 45 | self.func_call_history = [] 46 | 47 | def normalize_chat_history(self, chat_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 48 | """Normalize chat history to ensure it follows the required format. 49 | 50 | Args: 51 | chat_history: The original chat history. 52 | 53 | Returns: 54 | Normalized chat history list. 55 | """ 56 | normalized = [] 57 | 58 | # Handle system message first 59 | system_messages = [msg for msg in chat_history if msg["role"] == "system"] 60 | if system_messages: 61 | normalized.append(system_messages[0]) 62 | 63 | # Add other messages in order 64 | for msg in chat_history: 65 | if msg["role"] == "system": 66 | continue 67 | 68 | if msg["role"] in ["user", "assistant", "function"]: 69 | normalized.append(msg) 70 | 71 | return normalized 72 | 73 | def query(self, chat_history: List[Dict[str, Any]], **kwargs) -> str: 74 | """Query the LLM model. 75 | 76 | Args: 77 | chat_history: The context (chat history). 78 | **kwargs: Additional parameters for the API call. 79 | 80 | Returns: 81 | Model's response as string. 82 | 83 | Raises: 84 | Exception: If the API call fails. 85 | """ 86 | try: 87 | normalized_history = self.normalize_chat_history(chat_history) 88 | parameters = kwargs 89 | completion = self.client.chat.completions.create( 90 | model=self.model, 91 | messages=normalized_history, 92 | temperature=self.temperature, 93 | stream=False, 94 | **parameters 95 | ) 96 | 97 | resp = completion.choices[0].message 98 | if resp.function_call: 99 | function_name = process_function_name(resp.function_call.name) 100 | arguments = json.loads(resp.function_call.arguments) 101 | print("[MLE FUNC CALL]: ", function_name) 102 | self.func_call_history.append({ 103 | "name": function_name, 104 | "arguments": arguments 105 | }) 106 | 107 | # Avoid multiple search function calls 108 | search_attempts = [ 109 | item for item in self.func_call_history 110 | if item['name'] in SEARCH_FUNCTIONS 111 | ] 112 | if len(search_attempts) > 3: 113 | parameters['function_call'] = "none" 114 | 115 | result = get_function(function_name)(**arguments) 116 | chat_history.append({ 117 | "role": "assistant", 118 | "function_call": dict(resp.function_call) 119 | }) 120 | chat_history.append({ 121 | "role": "function", 122 | "content": result, 123 | "name": function_name 124 | }) 125 | return self.query(chat_history, **parameters) 126 | return resp.content 127 | 128 | except Exception as e: 129 | error_msg = f"vLLM API error: {str(e)}" 130 | print(f"Error during vLLM query: {error_msg}") 131 | if hasattr(e, 'response'): 132 | print(f"Response status: {e.response.status_code}") 133 | print(f"Response body: {e.response.text}") 134 | raise Exception(error_msg) 135 | 136 | def stream(self, chat_history: List[Dict[str, Any]], **kwargs) -> str: 137 | """Stream the output from the LLM model. 138 | 139 | Args: 140 | chat_history: The context (chat history). 141 | **kwargs: Additional parameters for the API call. 142 | 143 | Yields: 144 | Chunks of the model's response. 145 | 146 | Raises: 147 | Exception: If the streaming fails. 
148 | """ 149 | try: 150 | arguments = '' 151 | function_name = '' 152 | for chunk in self.client.chat.completions.create( 153 | model=self.model, 154 | messages=chat_history, 155 | temperature=self.temperature, 156 | stream=True, 157 | **kwargs 158 | ): 159 | delta = chunk.choices[0].delta 160 | if delta.function_call: 161 | if delta.function_call.name: 162 | function_name = process_function_name( 163 | delta.function_call.name 164 | ) 165 | if delta.function_call.arguments: 166 | arguments += delta.function_call.arguments 167 | 168 | if chunk.choices[0].finish_reason == "function_call": 169 | result = get_function(function_name)(**json.loads(arguments)) 170 | chat_history.append({ 171 | "role": "function", 172 | "content": result, 173 | "name": function_name 174 | }) 175 | yield from self.stream(chat_history, **kwargs) 176 | else: 177 | yield delta.content 178 | 179 | except Exception as e: 180 | error_msg = f"vLLM streaming error: {str(e)}" 181 | print(f"Error during vLLM streaming: {error_msg}") 182 | if hasattr(e, 'response'): 183 | print(f"Response status: {e.response.status_code}") 184 | print(f"Response body: {e.response.text}") 185 | raise Exception(error_msg) 186 | -------------------------------------------------------------------------------- /mle/server/__init__.py: -------------------------------------------------------------------------------- 1 | from .app import app 2 | -------------------------------------------------------------------------------- /mle/server/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Optional 4 | from datetime import datetime 5 | from pydantic import BaseModel 6 | from fastapi import FastAPI, HTTPException, BackgroundTasks 7 | from fastapi.responses import JSONResponse 8 | from fastapi.middleware.cors import CORSMiddleware 9 | 10 | from mle.workflow import report 11 | from mle.utils import check_config 12 | 13 | app = FastAPI() 14 | 15 | # Add CORS middleware 16 | app.add_middleware( 17 | CORSMiddleware, 18 | allow_origins=["*"], 19 | allow_credentials=True, 20 | allow_methods=["*"], 21 | allow_headers=["*"], 22 | ) 23 | 24 | 25 | class ReportRequest(BaseModel): 26 | """ 27 | ReportRequest: the request body for generating a report 28 | """ 29 | repo: str 30 | username: str 31 | token: Optional[str] = None 32 | okr: Optional[str] = None 33 | 34 | 35 | @app.get("/") 36 | def root(): 37 | """ 38 | read_root: read the root. 39 | :return: the root. 40 | """ 41 | return {"Welcome to": "MLE-Agent!"} 42 | 43 | 44 | @app.get("/latest_report", response_class=JSONResponse) 45 | def read_latest_report(): 46 | """ 47 | read_latest_report: read the latest progress report. 48 | :return: the content of the latest progress report as plain text. 49 | """ 50 | if not check_config(): 51 | raise HTTPException( 52 | status_code=400, 53 | detail="`project.yml` not found. Please start the MLE server under an MLE-Agent project directory." 
54 | ) 55 | 56 | reports_dir = os.getcwd() 57 | report_files = [f for f in os.listdir(reports_dir) if f.startswith("progress_report_") and f.endswith(".json")] 58 | 59 | if not report_files: 60 | raise HTTPException(status_code=404, detail="No progress reports found.") 61 | 62 | latest_report = max(report_files, key=lambda f: datetime.strptime(f, "progress_report_%Y_%m_%d.json")) 63 | try: 64 | with open(os.path.join(reports_dir, latest_report), 'r') as file: 65 | report_dict = json.load(file) 66 | report_dict.update({"file": latest_report}) 67 | return JSONResponse(content=report_dict) 68 | except IOError: 69 | raise HTTPException(status_code=500, detail="Error reading the latest report file.") 70 | 71 | 72 | @app.post("/gen_report") 73 | def gen_report(report_request: ReportRequest): 74 | """ 75 | Generate a report synchronously based on the provided GitHub repository and username. 76 | Optionally includes OKR text. 77 | 78 | Example payload: 79 | 80 | curl -X POST http://localhost:8000/gen_report \ 81 | -H "Content-Type: application/json" \ 82 | -d '{ 83 | "token": "***", 84 | "repo": "MLSysOps/MLE-agent", 85 | "username": "huangyz0918", 86 | "okr": "Improve system efficiency by 20% this quarter" 87 | }' 88 | """ 89 | try: 90 | # Run report generation synchronously 91 | result = report( 92 | os.getcwd(), 93 | report_request.repo, 94 | report_request.username, 95 | report_request.token, 96 | okr_str=report_request.okr, 97 | model="gpt-4o", 98 | ) 99 | 100 | return { 101 | "message": "Report generation completed", 102 | "repo": report_request.repo, 103 | "username": report_request.username, 104 | "okr_provided": report_request.okr is not None, 105 | "result": result # Assuming the report function returns some result 106 | } 107 | except Exception as e: 108 | raise HTTPException(status_code=500, detail=f"Error in report generation process: {e}") 109 | 110 | 111 | @app.post("/gen_report_async") 112 | async def gen_report_async(report_request: ReportRequest, background_tasks: BackgroundTasks): 113 | """ 114 | Generate a report (async) based on the provided GitHub repository and username. 115 | Optionally includes OKR text. 
116 | 117 | Example payload: 118 | 119 | curl -X POST http://localhost:8000/gen_report_async \ 120 | -H "Content-Type: application/json" \ 121 | -d '{ 122 | "token": "***", 123 | "repo": "MLSysOps/MLE-agent", 124 | "username": "huangyz0918", 125 | "okr": "Improve system efficiency by 20% this quarter" 126 | }' 127 | """ 128 | try: 129 | # Trigger report generation in the background 130 | background_tasks.add_task( 131 | report, 132 | os.getcwd(), 133 | report_request.repo, 134 | report_request.username, 135 | report_request.token, 136 | okr_str=report_request.okr, 137 | model="gpt-4o", 138 | ) 139 | 140 | return { 141 | "message": "Report generation started", 142 | "repo": report_request.repo, 143 | "username": report_request.username, 144 | "okr_provided": report_request.okr is not None 145 | } 146 | except Exception as e: 147 | raise HTTPException(status_code=500, detail=f"Error in report generation process: {e}") 148 | -------------------------------------------------------------------------------- /mle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .system import * 2 | from .cache import * 3 | from .memory import * 4 | from .data import * 5 | from .chunk import * 6 | -------------------------------------------------------------------------------- /mle/utils/cache.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Dict, Any, Optional 3 | from datetime import datetime 4 | 5 | from mle.utils.system import get_config, write_config 6 | 7 | 8 | class WorkflowCacheOperator: 9 | """ 10 | WorkflowCacheOperator handles the storing and resuming of cache content. 11 | """ 12 | 13 | def __init__(self, cache: 'WorkflowCache', cache_content: Dict[str, Any]): 14 | """ 15 | Args: 16 | cache: The cache instance to which this operator belongs. 17 | cache_content (Dict[str, object]): A dictionary holding the cached content. 18 | """ 19 | self.cache = cache 20 | self.cache_content = cache_content 21 | 22 | def store(self, key: str, value: Any) -> None: 23 | """ 24 | Store a value into the cache content. 25 | 26 | Args: 27 | key (str): The key under which the value is stored. 28 | value (object): The value to be stored. 29 | """ 30 | self.cache_content[key] = pickle.dumps(value, fix_imports=False) 31 | 32 | def resume(self, key: str) -> Any: 33 | """ 34 | Resume a value from the cache content. 35 | 36 | Args: 37 | key (str): The key of the value to be resumed. 38 | 39 | Returns: 40 | object: The resumed value, or None if the key does not exist. 41 | """ 42 | if key in self.cache_content: 43 | return pickle.loads(self.cache_content[key]) 44 | return None 45 | 46 | def __enter__(self): 47 | """ 48 | Enter the runtime context related to this object. 49 | 50 | Returns: 51 | WorkflowCacheOperator: self 52 | """ 53 | return self 54 | 55 | def __exit__(self, exc_type, exc_val, exc_tb): 56 | """ 57 | Exit the runtime context related to this object. 58 | 59 | Args: 60 | exc_type: The exception type. 61 | exc_val: The exception value. 62 | exc_tb: The traceback object. 63 | """ 64 | if exc_type is None: 65 | self.cache._store_cache_buffer() 66 | 67 | 68 | class WorkflowCache: 69 | """ 70 | WorkflowCache manages the caching for workflows, providing 71 | methods to load, store, and remove cached steps. 72 | """ 73 | 74 | def __init__(self, project_dir: str, workflow: str = 'baseline'): 75 | """ 76 | Initialize WorkflowCache with a project directory. 
77 | 78 | Args: 79 | project_dir (str): The directory of the project. 80 | workflow (str): The name of the cached workflow. 81 | """ 82 | self.project_dir = project_dir 83 | self.workflow = workflow 84 | self.buffer = self._load_cache_buffer(workflow) 85 | self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"][workflow] 86 | 87 | def is_empty(self) -> bool: 88 | """ 89 | Check if the cache is empty. 90 | 91 | Returns: 92 | bool: True if the cache is empty, False otherwise. 93 | """ 94 | return len(self.cache) == 0 95 | 96 | def remove(self, step: int) -> None: 97 | """ 98 | Remove a step from the cache. 99 | 100 | Args: 101 | step (int): The step index to be removed. 102 | """ 103 | self.cache.pop(step, None) 104 | self._store_cache_buffer() 105 | 106 | def current_step(self) -> int: 107 | """ 108 | Get the current step from the cache. 109 | 110 | Returns: 111 | int: The current step. 112 | """ 113 | return max(self.cache.keys()) if self.cache else 0 114 | 115 | def resume_variable(self, key: str, step: Optional[int] = None): 116 | """ 117 | Resume the cached variable. 118 | 119 | Args: 120 | key (str): The key of the value to be resumed. 121 | step (str): The step to be initialized. 122 | 123 | Returns: 124 | object: The resumed value, or None if the key does not exist. 125 | """ 126 | if step is not None: 127 | return self.__call__(step).resume(key) 128 | else: 129 | for step in range(self.current_step() + 1): 130 | value = self.resume_variable(key, step) 131 | if value is not None: 132 | return value 133 | return None 134 | 135 | def _load_cache_buffer(self, workflow: str) -> Dict[str, Any]: 136 | """ 137 | Load the cache buffer from the configuration. 138 | 139 | Args: 140 | workflow (str): The name of the cached workflow. 141 | 142 | Returns: 143 | dict: The buffer loaded from the configuration. 144 | """ 145 | buffer = get_config() or {} 146 | if "cache" not in buffer.keys(): 147 | buffer["cache"] = {} 148 | if workflow not in buffer["cache"].keys(): 149 | buffer["cache"][workflow] = {} 150 | return buffer 151 | 152 | def _store_cache_buffer(self) -> None: 153 | """ 154 | Store the cache buffer to the configuration. 155 | """ 156 | write_config(self.buffer) 157 | 158 | def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator: 159 | """ 160 | Initialize the cache content for a given step and name. 161 | 162 | Args: 163 | step (str): The step to be initialized. 164 | name (str): The name associated with the step. 165 | 166 | Returns: 167 | WorkflowCacheOperator: An instance of WorkflowCacheOperator. 168 | """ 169 | if step not in self.cache.keys(): 170 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 171 | self.cache[step] = { 172 | "step": step, 173 | "name": name, 174 | "time": timestamp, 175 | "content": {}, 176 | } 177 | cache_content = self.cache[step]["content"] 178 | return WorkflowCacheOperator(self, cache_content) 179 | 180 | def __str__(self) -> str: 181 | """ 182 | Return a string representation of the step cache list. 183 | 184 | Returns: 185 | str: The string representation of the cache. 
186 | """ 187 | return "\n".join(f"[{k}] {v['name']} ({v['time']})" for k, v in self.cache.items()) -------------------------------------------------------------------------------- /mle/utils/chunk.py: -------------------------------------------------------------------------------- 1 | # Source modified from https://github.com/CintraAI/code-chunker/blob/main/Chunker.py 2 | import tiktoken 3 | from .parser import CodeParser 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | def count_tokens(string: str, encoding_name: str) -> int: 8 | encoding = tiktoken.encoding_for_model(encoding_name) 9 | num_tokens = len(encoding.encode(string)) 10 | return num_tokens 11 | 12 | 13 | class Chunker(ABC): 14 | def __init__(self, encoding_name="gpt-4"): 15 | self.encoding_name = encoding_name 16 | 17 | @abstractmethod 18 | def chunk(self, content, token_limit): 19 | pass 20 | 21 | @abstractmethod 22 | def get_chunk(self, chunked_content, chunk_number): 23 | pass 24 | 25 | @staticmethod 26 | def print_chunks(chunks): 27 | for chunk_number, chunk_code in chunks.items(): 28 | print(f"Chunk {chunk_number}:") 29 | print("=" * 40) 30 | print(chunk_code) 31 | print("=" * 40) 32 | 33 | @staticmethod 34 | def consolidate_chunks_into_file(chunks): 35 | return "\n".join(chunks.values()) 36 | 37 | @staticmethod 38 | def count_lines(consolidated_chunks): 39 | lines = consolidated_chunks.split("\n") 40 | return len(lines) 41 | 42 | 43 | class CodeChunker(Chunker): 44 | def __init__(self, cache_dir, file_extension, encoding_name="gpt-4o-mini"): 45 | super().__init__(encoding_name) 46 | self.file_extension = file_extension 47 | self.cache_dir = cache_dir 48 | 49 | def chunk(self, code, token_limit) -> dict: 50 | code_parser = CodeParser(self.cache_dir, self.file_extension) 51 | chunks = {} 52 | token_count = 0 53 | lines = code.split("\n") 54 | i = 0 55 | chunk_number = 1 56 | start_line = 0 57 | breakpoints = sorted(code_parser.get_lines_for_points_of_interest(code, self.file_extension)) 58 | comments = sorted(code_parser.get_lines_for_comments(code, self.file_extension)) 59 | adjusted_breakpoints = [] 60 | for bp in breakpoints: 61 | current_line = bp - 1 62 | highest_comment_line = None # Initialize with None to indicate no comment line has been found yet 63 | while current_line in comments: 64 | highest_comment_line = current_line # Update highest comment line found 65 | current_line -= 1 # Move to the previous line 66 | 67 | if highest_comment_line: # If a highest comment line exists, add it 68 | adjusted_breakpoints.append(highest_comment_line) 69 | else: 70 | adjusted_breakpoints.append( 71 | bp) # If no comments were found before the breakpoint, add the original breakpoint 72 | 73 | breakpoints = sorted(set(adjusted_breakpoints)) # Ensure breakpoints are unique and sorted 74 | 75 | while i < len(lines): 76 | line = lines[i] 77 | new_token_count = count_tokens(line, self.encoding_name) 78 | if token_count + new_token_count > token_limit: 79 | 80 | # Set the stop line to the last breakpoint before the current line 81 | if i in breakpoints: 82 | stop_line = i 83 | else: 84 | stop_line = max(max([x for x in breakpoints if x < i], default=start_line), start_line) 85 | 86 | # If the stop line is the same as the start line, it means we haven't reached a breakpoint yet, and we need to move to the next line to find one 87 | if stop_line == start_line and i not in breakpoints: 88 | token_count += new_token_count 89 | i += 1 90 | 91 | # If the stop line is the same as the start line and the current line is a breakpoint, it 
means we can create a chunk with just the current line 92 | elif stop_line == start_line and i == stop_line: 93 | token_count += new_token_count 94 | i += 1 95 | 96 | # If the stop line is the same as the start line and the current line is a breakpoint, it means we can create a chunk with just the current line 97 | elif stop_line == start_line and i in breakpoints: 98 | current_chunk = "\n".join(lines[start_line:stop_line]) 99 | if current_chunk.strip(): # If the current chunk is not just whitespace 100 | chunks[chunk_number] = current_chunk # Using chunk_number as key 101 | chunk_number += 1 102 | 103 | token_count = 0 104 | start_line = i 105 | i += 1 106 | 107 | # If the stop line is different from the start line, it means we're at the end of a block 108 | else: 109 | current_chunk = "\n".join(lines[start_line:stop_line]) 110 | if current_chunk.strip(): 111 | chunks[chunk_number] = current_chunk # Using chunk_number as key 112 | chunk_number += 1 113 | 114 | i = stop_line 115 | token_count = 0 116 | start_line = stop_line 117 | else: 118 | # If the token count is still within the limit, add the line to the current chunk 119 | token_count += new_token_count 120 | i += 1 121 | 122 | # Append remaining code, if any, ensuring it's not empty or whitespace 123 | current_chunk_code = "\n".join(lines[start_line:]) 124 | if current_chunk_code.strip(): # Checks if the chunk is not just whitespace 125 | chunks[chunk_number] = current_chunk_code # Using chunk_number as key 126 | 127 | return chunks 128 | 129 | def get_chunk(self, chunked_codebase, chunk_number): 130 | return chunked_codebase[chunk_number] 131 | -------------------------------------------------------------------------------- /mle/utils/data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | from typing import Dict, Any 5 | 6 | 7 | def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None: 8 | """ 9 | Write a dictionary to a markdown file. 10 | :param data: the dictionary to write. 11 | :param file_path: the file path to write the dictionary to. 12 | :return: 13 | """ 14 | 15 | def write_item(k, v, indent_level=0): 16 | if isinstance(v, dict): 17 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n") 18 | for sub_key, sub_value in v.items(): 19 | write_item(sub_key, sub_value, indent_level + 1) 20 | elif isinstance(v, list): 21 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n") 22 | for item in v: 23 | md_file.write(f"{' ' * indent_level}- {item}\n") 24 | else: 25 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n") 26 | md_file.write(f"{' ' * indent_level}{v}\n") 27 | 28 | with open(file_path, 'w') as md_file: 29 | for key, value in data.items(): 30 | write_item(key, value) 31 | md_file.write("\n") 32 | 33 | 34 | def is_markdown_file(file_path): 35 | """ 36 | Check if the file is a Markdown file. 37 | :param file_path: the file path 38 | :return: boolean 39 | """ 40 | if not os.path.isfile(file_path): 41 | return False 42 | 43 | valid_extensions = ['.md', '.markdown', '.mdown', '.mkdn', '.mkd', '.mdwn', '.mdtxt', '.mdtext', '.text', '.Rmd'] 44 | file_extension = os.path.splitext(file_path)[1].lower() 45 | 46 | return file_extension in valid_extensions 47 | 48 | 49 | def read_markdown(file_path, include_links=False, include_images=False): 50 | """ 51 | Read the markdown file and return the content. 
52 | :param file_path: the file path to the .md file 53 | :param include_links: the flag to include the links 54 | :param include_images: the flag to include the images 55 | :return: the raw content of the markdown file 56 | """ 57 | try: 58 | with open(file_path, 'r', encoding='utf-8') as file: 59 | content = file.read() 60 | 61 | if not include_links: 62 | content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content) 63 | 64 | if not include_images: 65 | content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', content) 66 | 67 | return content.strip() 68 | except FileNotFoundError: 69 | return f"Error: File not found at {file_path}" 70 | except Exception as e: 71 | return f"Error: An unexpected error occurred - {str(e)}" 72 | 73 | 74 | def clean_json_string(input_string): 75 | """ 76 | clean the json string 77 | :input_string: the input json string 78 | """ 79 | cleaned = input_string.strip() 80 | cleaned = re.sub(r'^```\s*json?\s*', '', cleaned) 81 | cleaned = re.sub(r'\s*```\s*$', '', cleaned) 82 | parsed_json = json.loads(cleaned) 83 | return parsed_json 84 | -------------------------------------------------------------------------------- /mle/utils/memory.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List, Dict, Optional 3 | 4 | import lancedb 5 | from lancedb.embeddings import get_registry 6 | 7 | from mle.utils import get_config 8 | 9 | 10 | class LanceDBMemory: 11 | 12 | def __init__(self, project_path: str): 13 | """ 14 | Memory: A base class for memory and external knowledge management. 15 | Args: 16 | project_path: the path to store the data. 17 | """ 18 | self.db_name = '.mle' 19 | self.table_name = 'memory' 20 | self.client = lancedb.connect(uri=self.db_name) 21 | 22 | config = get_config(project_path) 23 | if config["platform"] == "OpenAI": 24 | self.text_embedding = get_registry().get("openai").create(api_key=config["api_key"]) 25 | else: 26 | self.text_embedding = get_registry().get("sentence-transformers").create( 27 | name="sentence-transformers/paraphrase-MiniLM-L6-v2" 28 | ) 29 | 30 | def _open_table(self, table_name: str = None): 31 | """ 32 | Open a LanceDB table by table name. (Return None if not exists) 33 | Args: 34 | table_name (Optional[str]): The name of the table. Defaults to self.table_name. 35 | """ 36 | table_name = table_name or self.table_name 37 | try: 38 | table = self.client.open_table(table_name) 39 | except FileNotFoundError: 40 | return None 41 | return table 42 | 43 | def add( 44 | self, 45 | texts: List[str], 46 | metadata: Optional[List[Dict]] = None, 47 | table_name: Optional[str] = None, 48 | ids: Optional[List[str]] = None, 49 | ) -> List[str]: 50 | """ 51 | Adds a list of text items to the specified memory table in the database. 52 | 53 | Args: 54 | texts (List[str]): A list of text strings to be added. 55 | metadata (Optional[List[Dict]]): A list of metadata to be added. 56 | table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name. 57 | ids (Optional[List[str]]): A list of unique IDs for the text items. 58 | If not provided, random UUIDs are generated. 59 | 60 | Returns: 61 | List[str]: A list of IDs associated with the added text items. 
62 | """ 63 | if isinstance(texts, str): 64 | texts = (texts,) 65 | 66 | if metadata is None: 67 | metadata = [None, ] * len(texts) 68 | elif isinstance(metadata, dict): 69 | metadata = (metadata,) 70 | else: 71 | assert len(texts) == len(metadata) 72 | 73 | embeds = self.text_embedding.compute_source_embeddings(texts) 74 | 75 | table_name = table_name or self.table_name 76 | ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))] 77 | 78 | data = [ 79 | { 80 | "vector": embed, 81 | "text": text, 82 | "id": idx, 83 | "metadata": meta, 84 | } for idx, text, embed, meta in zip(ids, texts, embeds, metadata) 85 | ] 86 | 87 | if table_name not in self.client.table_names(): 88 | table = self.client.create_table(table_name, data=data) 89 | table.create_fts_index("id") 90 | else: 91 | self._open_table(table_name).add(data=data) 92 | 93 | return ids 94 | 95 | def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]: 96 | """ 97 | Queries the specified memory table for similar text embeddings. 98 | 99 | Args: 100 | query_texts (List[str]): A list of query text strings. 101 | table_name (Optional[str]): The name of the table to query. Defaults to self.table_name. 102 | n_results (int): The maximum number of results to retrieve per query. Default is 5. 103 | 104 | Returns: 105 | List[List[dict]]: A list of results for each query text, each result being a dictionary with 106 | keys such as "vector", "text", and "id". 107 | """ 108 | table = self._open_table(table_name) 109 | if table is None: 110 | return [] 111 | 112 | query_embeds = self.text_embedding.compute_source_embeddings(query_texts) 113 | 114 | results = [table.search(query).limit(n_results).to_list() for query in query_embeds] 115 | return results 116 | 117 | def list_all_keys(self, table_name: Optional[str] = None): 118 | """ 119 | Lists all IDs in the specified memory table. 120 | 121 | Args: 122 | table_name (Optional[str]): The name of the table to list IDs from. Defaults to the instance's table name. 123 | 124 | Returns: 125 | List[str]: A list of all IDs in the table. 126 | """ 127 | table = self._open_table(table_name) 128 | if table is None: 129 | return [] 130 | 131 | return [item["id"] for item in table.search(query_type="fts").to_list()] 132 | 133 | def get(self, record_id: str, table_name: Optional[str] = None): 134 | """ 135 | Retrieves a record by its ID from the specified memory table. 136 | 137 | Args: 138 | record_id (str): The ID of the record to retrieve. 139 | table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name. 140 | 141 | Returns: 142 | List[dict]: A list containing the matching record, or an empty list if not found. 143 | """ 144 | table = self._open_table(table_name) 145 | if table is None: 146 | return [] 147 | 148 | return table.search(query_type="fts") \ 149 | .where(f"id = '{record_id}'") \ 150 | .limit(1).to_list() 151 | 152 | def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None, n_results: int = 5): 153 | """ 154 | Retrieves records matching a specific metadata key-value pair. 155 | 156 | Args: 157 | key (str): The metadata key to filter by. 158 | value (str): The value of the metadata key to filter by. 159 | table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name. 160 | n_results (int): The maximum number of results to retrieve. Defaults to 5. 161 | 162 | Returns: 163 | List[dict]: A list of records matching the metadata criteria. 
164 | """ 165 | table = self._open_table(table_name) 166 | if table is None: 167 | return [] 168 | 169 | return table.search(query_type="fts") \ 170 | .where(f"metadata.{key} = '{value}'") \ 171 | .limit(n_results).to_list() 172 | 173 | def delete(self, record_id: str, table_name: Optional[str] = None) -> bool: 174 | """ 175 | Deletes a record from the specified memory table. 176 | 177 | Args: 178 | record_id (str): The ID of the record to delete. 179 | table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name. 180 | 181 | Returns: 182 | bool: True if the deletion was successful, False otherwise. 183 | """ 184 | table = self._open_table(table_name) 185 | if table is None: 186 | return True 187 | 188 | return table.delete(f"id = '{record_id}'") 189 | 190 | def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None): 191 | """ 192 | Deletes records from the specified memory table based on a metadata key-value pair. 193 | 194 | Args: 195 | key (str): The metadata key to filter by. 196 | value (str): The value of the metadata key to filter by. 197 | table_name (Optional[str]): The name of the table to delete records from. Defaults to the instance's table name. 198 | 199 | Returns: 200 | bool: True if deletion was successful, False otherwise. 201 | """ 202 | table = self._open_table(table_name) 203 | if table is None: 204 | return True 205 | 206 | return table.delete(f"metadata.{key} = '{value}'") 207 | 208 | def drop(self, table_name: Optional[str] = None) -> bool: 209 | """ 210 | Drops (deletes) the specified memory table. 211 | 212 | Args: 213 | table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name. 214 | 215 | Returns: 216 | bool: True if the table was successfully dropped, False otherwise. 217 | """ 218 | table_name = table_name or self.table_name 219 | table = self._open_table(table_name) 220 | if table is None: 221 | return True 222 | 223 | return self.client.drop_table(table_name) 224 | 225 | def count(self, table_name: Optional[str] = None) -> int: 226 | """ 227 | Counts the number of records in the specified memory table. 228 | 229 | Args: 230 | table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name. 231 | 232 | Returns: 233 | int: The number of records in the table. 234 | """ 235 | table = self._open_table(table_name) 236 | if table is None: 237 | return 0 238 | 239 | return table.count_rows() 240 | 241 | def reset(self) -> None: 242 | """ 243 | Resets the memory by dropping the default memory table. 244 | """ 245 | self.drop() 246 | -------------------------------------------------------------------------------- /mle/version.py: -------------------------------------------------------------------------------- 1 | # PEP0440 compatible formatted version, see: 2 | # https://www.python.org/dev/peps/pep-0440/ 3 | # 4 | # Generic release markers: 5 | # X.Y 6 | # X.Y.Z # For bug fix releases 7 | # 8 | # Admissible pre-release markers: 9 | # X.YaN # Alpha release 10 | # X.YbN # Beta release 11 | # X.YrcN # Release Candidate 12 | # X.Y # Final release 13 | # 14 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 
15 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 16 | # 17 | 18 | __version__ = '0.4.2' 19 | -------------------------------------------------------------------------------- /mle/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .chat import chat 2 | from .baseline import baseline 3 | from .report import report, report_local 4 | from .kaggle import kaggle, auto_kaggle 5 | -------------------------------------------------------------------------------- /mle/workflow/baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Baseline Mode: the mode to quickly generate the AI baseline based on the user's requirements. 3 | """ 4 | import os 5 | import questionary 6 | from rich.console import Console 7 | from mle.model import load_model 8 | from mle.utils import print_in_box, ask_text, WorkflowCache 9 | from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent 10 | 11 | 12 | def ask_data(data_str: str): 13 | """ 14 | Ask the user to provide the data information. 15 | :param data_str: the input data string. Now, it should be the name of the public dataset or 16 | the path to the local CSV file. 17 | :return: the formated data information. 18 | """ 19 | if os.path.isfile(data_str) and data_str.lower().endswith('.csv'): 20 | return f"[green]CSV Dataset Location:[/green] {data_str}" 21 | else: 22 | return f"[green]Dataset:[/green] {data_str}" 23 | 24 | 25 | def baseline(work_dir: str, model=None): 26 | """ 27 | The workflow of the baseline mode. 28 | :return: 29 | """ 30 | 31 | console = Console() 32 | cache = WorkflowCache(work_dir, 'baseline') 33 | model = load_model(work_dir, model) 34 | 35 | if not cache.is_empty(): 36 | step = ask_text(f"MLE has finished the following steps: \n{cache}\n" 37 | f"You can pick a step from 1 to {cache.current_step()} to resume\n" 38 | "(or ENTER to continue the workflow)") 39 | if step: 40 | step = int(step) 41 | for i in range(step, cache.current_step() + 1): 42 | cache.remove(i) # remove the stale step caches 43 | 44 | # ask for the data information 45 | with cache(step=1, name="ask for the data information") as ca: 46 | dataset = ca.resume("dataset") 47 | if dataset is None: 48 | advisor = AdviseAgent(model, console) 49 | dataset = ask_text("Please provide your dataset information (a public dataset name or a local absolute filepath)") 50 | if not dataset: 51 | print_in_box("The dataset is empty. Aborted", console, title="Error", color="red") 52 | return 53 | dataset = advisor.clarify_dataset(dataset) 54 | ca.store("dataset", dataset) 55 | 56 | # ask for the user requirement 57 | with cache(step=2, name="ask for the user requirement") as ca: 58 | ml_requirement = ca.resume("ml_requirement") 59 | if ml_requirement is None: 60 | ml_requirement = ask_text("Please provide your requirement") 61 | if not ml_requirement: 62 | print_in_box("The user's requirement is empty. 
Aborted", console, title="Error", color="red") 63 | return 64 | ca.store("ml_requirement", ml_requirement) 65 | 66 | # advisor agent gives suggestions in a report 67 | with cache(step=3, name="MLE advisor agent provides a high-level report") as ca: 68 | advisor_report = ca.resume("advisor_report") 69 | if advisor_report is None: 70 | advisor = AdviseAgent(model, console) 71 | advisor_report = advisor.interact("[green]User Requirement:[/green] " + ml_requirement + "\n" + ask_data(dataset)) 72 | ca.store("advisor_report", advisor_report) 73 | 74 | # plan agent generates the coding plan 75 | with cache(step=4, name="MLE plan agent generates a dev plan") as ca: 76 | coding_plan = ca.resume("coding_plan") 77 | if coding_plan is None: 78 | planner = PlanAgent(model, console) 79 | coding_plan = planner.interact(advisor_report) 80 | ca.store("coding_plan", coding_plan) 81 | 82 | # code agent codes the tasks and debug with the debug agent 83 | with cache(step=5, name="MLE code&debug agents start to work") as ca: 84 | coder = CodeAgent(model, work_dir, console) 85 | coder.read_requirement(advisor_report) 86 | debugger = DebugAgent(model, console) 87 | 88 | is_auto_mode = questionary.confirm( 89 | "MLE developer is about to start to code.\n" 90 | "Choose to debug or not (If no, MLE agent will only focus on coding tasks," 91 | " and you have to run and debug the code yourself)?" 92 | ).ask() 93 | 94 | for current_task in coding_plan.get('tasks'): 95 | code_report = coder.interact(current_task) 96 | is_debugging = code_report.get('debug') 97 | 98 | if is_auto_mode: 99 | while True: 100 | if is_debugging == 'true' or is_debugging == 'True': 101 | with console.status("MLE Debug Agent is executing and debugging the code..."): 102 | debug_report = debugger.analyze(code_report) 103 | if debug_report.get('status') == 'success': 104 | break 105 | else: 106 | code_report = coder.debug(current_task, debug_report) 107 | else: 108 | break 109 | -------------------------------------------------------------------------------- /mle/workflow/chat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Chat Mode: the mode to have an interactive chat with LLM to work on ML project. 
3 | """ 4 | import questionary 5 | from rich.live import Live 6 | from rich.panel import Panel 7 | from rich.console import Console 8 | from rich.markdown import Markdown 9 | from mle.model import load_model 10 | from mle.utils import print_in_box, WorkflowCache 11 | from mle.agents import ChatAgent 12 | 13 | 14 | def chat(work_dir: str, memory=None, model=None): 15 | console = Console() 16 | cache = WorkflowCache(work_dir, 'chat') 17 | model = load_model(work_dir, model) 18 | chatbot = ChatAgent(model, memory=memory) 19 | 20 | if not cache.is_empty(): 21 | if questionary.confirm(f"Would you like to continue the previous conversation?\n").ask(): 22 | chatbot.chat_history = cache.resume_variable("conversation") 23 | 24 | with cache(step=1, name="chat") as ca: 25 | greets = chatbot.greet() 26 | print_in_box(greets, console=console, title="MLE Chatbot", color="magenta") 27 | 28 | while True: 29 | try: 30 | user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask() 31 | if user_pmpt: 32 | with Live(console=Console()) as live: 33 | for text in chatbot.chat(user_pmpt.strip()): 34 | live.update( 35 | Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"), 36 | refresh=True 37 | ) 38 | ca.store("conversation", chatbot.chat_history) 39 | except (KeyboardInterrupt, EOFError): 40 | break 41 | -------------------------------------------------------------------------------- /mle/workflow/kaggle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kaggle Mode: the mode to generate ML pipeline for kaggle competitions. 3 | """ 4 | import os 5 | import questionary 6 | from typing import List 7 | from rich.console import Console 8 | 9 | from mle.model import load_model 10 | from mle.function import execute_command 11 | from mle.integration import KaggleIntegration 12 | from mle.utils import ask_text, read_markdown, is_markdown_file, WorkflowCache, print_in_box 13 | from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent, GitHubSummaryAgent 14 | 15 | 16 | def auto_kaggle( 17 | work_dir: str, 18 | datasets: List[str], 19 | description: str, 20 | submission='./submission.csv', 21 | debug_max_attempt=5, 22 | sub_examples=None, 23 | competition_id=None, 24 | model=None 25 | ): 26 | """ 27 | The workflow of the kaggle mode. 28 | :param work_dir: the working directory. 29 | :param datasets: the datasets to use. 30 | :param description: the description of the competition, can be a path to a local .md file or a string. 31 | :param submission: the path of the kaggle submission file. 32 | :param debug_max_attempt: the max attempt for debugging. 33 | :param sub_examples: the path to the kaggle submission example file. 34 | :param competition_id: the competition id. 35 | :param model: the model to use. 
36 | """ 37 | console = Console() 38 | model = load_model(work_dir, model) 39 | 40 | # initialize the agents 41 | advisor = AdviseAgent(model, console, mode="precise") 42 | summarizer = GitHubSummaryAgent(model, console=console) 43 | coder = CodeAgent(model, work_dir, console=console, single_file=True) 44 | debugger = DebugAgent(model, console, analyze_only=True) 45 | 46 | if is_markdown_file(description): 47 | description = read_markdown(description) 48 | 49 | with console.status("MLE Agent is processing the kaggle competition overview..."): 50 | requirements = summarizer.kaggle_request_summarize(description, sub_examples) 51 | requirements += f"\n\nLOCAL DATASET PATH:\n" 52 | for dataset in datasets: 53 | requirements += f" - {dataset}\n" 54 | 55 | requirements += f"\nSUBMISSION FILE PATH: {submission}\n" 56 | 57 | suggestions = advisor.suggest(requirements, return_raw=True) 58 | requirements += f""" 59 | \nIMPLEMENTATION SUGGESTIONS: 60 | 61 | - Suggestion summary: {suggestions.get('suggestion')} 62 | - ML task: {suggestions.get('task')} 63 | - Model or algorithm: {suggestions.get('model_or_algorithm')} 64 | - Training strategy: {suggestions.get('training_method')} 65 | - Tricks to increase performance: 66 | """ 67 | for trick in suggestions.get('tricks'): 68 | requirements += f"\n - {trick}" 69 | 70 | coder.read_requirement(requirements) 71 | if competition_id is None: 72 | competition_id = "kaggle competition" 73 | 74 | coding_task = { 75 | "task": competition_id, 76 | "description": requirements 77 | } 78 | print_in_box(requirements, console, title="Kaggle Competition Requirement", color="green") 79 | code_report = coder.code(coding_task) 80 | debug_attempt = 0 81 | while True: 82 | if debug_attempt > debug_max_attempt: 83 | console.log(f"Debug the code failed with max {debug_max_attempt} attempts. Please check the code manually.") 84 | break 85 | 86 | with console.status("MLE Debug Agent is executing and debugging the code..."): 87 | running_cmd = code_report.get('command') 88 | logs = execute_command(running_cmd) 89 | debug_report = debugger.analyze_with_log(running_cmd, logs) 90 | if debug_report.get('status') == 'success': 91 | # check the submission file 92 | if not os.path.exists(submission): 93 | console.log(f"The submission file ({submission}) is not found. Launch the coder to improve...") 94 | code_report = coder.debug( 95 | coding_task, 96 | { 97 | "status": "error", 98 | "changes": [ 99 | f"make sure the submission file is generated in {submission}", 100 | f"make sure the submission file is in the correct format. You can refer to the example submission file: {sub_examples}" 101 | ], 102 | "suggestion": f"Please update the code related to generating the submission file." 103 | } 104 | ) 105 | else: 106 | break 107 | else: 108 | debug_attempt += 1 109 | code_report = coder.debug(coding_task, debug_report) 110 | 111 | 112 | def kaggle(work_dir: str, model=None): 113 | """ 114 | The workflow of the kaggle mode. 115 | :param work_dir: the working directory. 116 | :param model: the model to use. 
117 | """ 118 | console = Console() 119 | cache = WorkflowCache(work_dir, 'kaggle') 120 | model = load_model(work_dir, model) 121 | integration = KaggleIntegration() 122 | 123 | if not cache.is_empty(): 124 | step = ask_text(f"MLE has finished the following steps: \n{cache}\n" 125 | f"You can pick a step from 1 to {cache.current_step()} to resume\n" 126 | "(or ENTER to continue the workflow)") 127 | if step: 128 | step = int(step) 129 | for i in range(step, cache.current_step() + 1): 130 | cache.remove(i) # remove the stale step caches 131 | 132 | # ask for the kaggle competition 133 | with cache(step=1, name="ask for the kaggle competition") as ca: 134 | competition = ca.resume("competition") 135 | dataset = ca.resume("dataset") 136 | if competition is None or dataset is None: 137 | competition = questionary.select( 138 | "Please select a Kaggle competition to join:", 139 | choices=integration.list_competition() 140 | ).ask() 141 | with console.status("MLE Agent is downloading the kaggle competition dataset..."): 142 | dataset = integration.download_competition_dataset( 143 | competition, os.path.join(os.getcwd(), 'data')) 144 | ca.store("competition", competition) 145 | ca.store("dataset", dataset) 146 | 147 | # ask for the user requirement 148 | with cache(step=2, name="get the competition overview from kaggle") as ca: 149 | ml_requirement = ca.resume("ml_requirement") 150 | if ml_requirement is None: 151 | with console.status("MLE Agent is fetching the kaggle competition overview..."): 152 | summary = GitHubSummaryAgent(model, console=console) 153 | ml_requirement = summary.kaggle_request_summarize(integration.fetch_competition_overview(competition)) 154 | ca.store("ml_requirement", ml_requirement) 155 | 156 | # advisor agent gives suggestions in a report 157 | with cache(step=3, name="MLE advisor agent provides a high-level report") as ca: 158 | advisor_report = ca.resume("advisor_report") 159 | if advisor_report is None: 160 | advisor = AdviseAgent(model, console) 161 | advisor_report = advisor.interact( 162 | f"[green]Competition Requirement:[/green] {ml_requirement}\n" 163 | f"Dataset is downloaded in path: {dataset}" 164 | ) 165 | ca.store("advisor_report", advisor_report) 166 | 167 | # plan agent generates the coding plan 168 | with cache(step=4, name="MLE plan agent generates a dev plan") as ca: 169 | coding_plan = ca.resume("coding_plan") 170 | if coding_plan is None: 171 | planner = PlanAgent(model, console) 172 | coding_plan = planner.interact(advisor_report) 173 | ca.store("coding_plan", coding_plan) 174 | 175 | # code agent codes the tasks and debug with the debug agent 176 | with cache(step=5, name="MLE code&debug agents start to work") as ca: 177 | coder = CodeAgent(model, work_dir, console) 178 | coder.read_requirement(advisor_report) 179 | debugger = DebugAgent(model, console) 180 | 181 | is_auto_mode = questionary.confirm( 182 | "MLE developer is about to start to code.\n" 183 | "Choose to debug or not (If no, MLE agent will only focus on coding tasks," 184 | " and you have to run and debug the code yourself)?" 
185 | ).ask() 186 | 187 | for current_task in coding_plan.get('tasks'): 188 | code_report = coder.interact(current_task) 189 | is_debugging = code_report.get('debug') 190 | 191 | if is_auto_mode: 192 | while True: 193 | if is_debugging == 'true' or is_debugging == 'True': 194 | with console.status("MLE Debug Agent is executing and debugging the code..."): 195 | debug_report = debugger.analyze(code_report) 196 | if debug_report.get('status') == 'success': 197 | break 198 | else: 199 | code_report = coder.debug(current_task, debug_report) 200 | else: 201 | break 202 | -------------------------------------------------------------------------------- /mle/workflow/report.py: -------------------------------------------------------------------------------- 1 | """ 2 | Report Mode: the mode to generate the AI report based on the user's requirements. 3 | """ 4 | import os 5 | import pickle 6 | from rich.console import Console 7 | from mle.model import load_model 8 | from mle.utils.system import get_config, write_config, check_config 9 | from mle.integration import GoogleCalendarIntegration, github_login 10 | from mle.agents import GitHubSummaryAgent, ReportAgent, GitSummaryAgent 11 | 12 | 13 | def ask_data(data_str: str): 14 | """ 15 | Ask the user to provide the data information. 16 | :param data_str: the input data string. Now, it should be the name of the public dataset or 17 | the path to the local CSV file. 18 | :return: the formatted data information. 19 | """ 20 | if os.path.isfile(data_str) and data_str.lower().endswith('.csv'): 21 | return f"[green]CSV Dataset Location:[/green] {data_str}" 22 | else: 23 | return f"[green]Dataset:[/green] {data_str}" 24 | 25 | 26 | def report( 27 | work_dir: str, 28 | github_repo: str, 29 | github_username: str, 30 | github_token: str = None, 31 | okr_str: str = None, 32 | model=None 33 | ): 34 | """ 35 | The workflow of the report mode. 36 | :param work_dir: the working directory. 37 | :param github_repo: the GitHub repository. 38 | :param github_username: the GitHub username. 39 | :param github_token: the GitHub token. 40 | :param okr_str: the OKR string. 41 | :param model: the model to use. 42 | :return: 43 | """ 44 | console = Console() 45 | model = load_model(work_dir, model) 46 | 47 | events = None 48 | if check_config(console): 49 | config = get_config() 50 | if github_token is None: 51 | if "github" in config.get("integration", {}).keys(): 52 | github_token = config["integration"]["github"].get("token") 53 | else: 54 | github_token = github_login() 55 | config["integration"]["github"] = {"token": github_token} 56 | write_config(config) 57 | 58 | if "google_calendar" in config.get("integration", {}).keys(): 59 | google_token = pickle.loads(config["integration"]["google_calendar"].get("token")) 60 | google_calendar = GoogleCalendarIntegration(google_token) 61 | events = google_calendar.get_events() 62 | 63 | summarizer = GitHubSummaryAgent( 64 | model, 65 | github_repo=github_repo, 66 | username=github_username, 67 | github_token=github_token, 68 | ) 69 | reporter = ReportAgent(model, console) 70 | 71 | github_summary = summarizer.summarize() 72 | return reporter.gen_report(github_summary, events, okr=okr_str) 73 | 74 | 75 | def report_local( 76 | work_dir: str, 77 | git_path: str, 78 | email: str, 79 | okr_str: str = None, 80 | start_date: str = None, 81 | end_date: str = None, 82 | model=None 83 | ): 84 | """ 85 | The workflow of the local report mode. 86 | :param work_dir: the working directory.
87 | :param git_path: the path to the local Git repository. 88 | :param email: the email address. 89 | :param okr_str: the OKR string. 90 | :param start_date: the start date. 91 | :param end_date: the end date. 92 | :param model: the model to use. 93 | :return: the generated report. 94 | """ 95 | 96 | console = Console() 97 | model = load_model(work_dir, model) 98 | 99 | events = None 100 | 101 | summarizer = GitSummaryAgent( 102 | model, 103 | git_path=git_path, 104 | git_email=email, 105 | ) 106 | reporter = ReportAgent(model, console) 107 | 108 | git_summary = summarizer.summarize(start_date=start_date, end_date=end_date) 109 | return reporter.gen_report(git_summary, events, okr=okr_str) 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich 2 | click 3 | py7zr~=0.22.0 4 | openai 5 | pyyaml 6 | kaggle 7 | fastapi 8 | uvicorn 9 | requests 10 | GitPython 11 | tree-sitter==0.21.3 12 | onnxruntime 13 | questionary 14 | pandas~=2.2.2 15 | tavily-python 16 | instructor 17 | langfuse 18 | setuptools 19 | numexpr~=2.10.1 20 | bottleneck~=1.4.0 21 | google-api-python-client~=2.143.0 22 | google-auth-httplib2~=0.2.0 23 | google-auth-oauthlib~=1.2.1 24 | lancedb~=0.15.0 25 | tantivy~=0.22.0 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [flake8] 5 | select = W291 6 | max-line-length = 256 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import find_packages, setup 3 | 4 | # read the contents of README file 5 | from os import path 6 | from io import open # for Python 2 and 3 compatibility 7 | 8 | # get __version__ from mle/version.py 9 | ver_file = path.join('mle', 'version.py') 10 | with open(ver_file) as f: 11 | exec(f.read()) 12 | 13 | this_directory = path.abspath(path.dirname(__file__)) 14 | 15 | 16 | # read the contents of README.md 17 | def readme(): 18 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 19 | return f.read() 20 | 21 | 22 | # read the contents of requirements.txt 23 | with open(path.join(this_directory, 'requirements.txt'), encoding='utf-8') as f: 24 | requirements = f.read().splitlines() 25 | 26 | setup( 27 | name='mle-agent', 28 | version=__version__, 29 | description='MLE-agent: An agent to automate your MLE processes', 30 | long_description=readme(), 31 | long_description_content_type='text/markdown', 32 | author='Yizheng Huang, Huaizheng Zhang', 33 | author_email='huangyz0918@gmail.com', 34 | url='https://github.com/MLSysOps/MLE-agent', 35 | download_url='https://github.com/MLSysOps/MLE-agent/archive/refs/heads/main.zip', 36 | keywords=['LLM', 'deep learning', 'MLOps', 'shell', 'neural networks'], 37 | packages=find_packages(), 38 | entry_points={ 39 | "console_scripts": [ 40 | "mle-agent=mle.cli:cli", 41 | "mle=mle.cli:cli", 42 | ] 43 | }, 44 | zip_safe=False, 45 | include_package_data=True, 46 | install_requires=requirements, 47 | setup_requires=['setuptools>=38.6.0'], 48 | classifiers=[ 49 | 'Development Status :: 5 - Production/Stable', 50 | 'Intended Audience :: Education', 51 | 'Intended Audience :: Financial and Insurance Industry', 52 | 'Intended Audience :: 
Science/Research', 53 | 'Intended Audience :: Developers', 54 | 'Intended Audience :: Information Technology', 55 | 'License :: OSI Approved :: Apache Software License', 56 | 'Programming Language :: Python :: 3', 57 | "Operating System :: OS Independent", 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/tests/__init__.py -------------------------------------------------------------------------------- /web/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["next/core-web-vitals", "next/typescript"] 3 | } 4 | -------------------------------------------------------------------------------- /web/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | .yarn/install-state.gz 8 | 9 | # testing 10 | /coverage 11 | 12 | # next.js 13 | /.next/ 14 | /out/ 15 | 16 | # production 17 | /build 18 | 19 | # misc 20 | .DS_Store 21 | *.pem 22 | 23 | # debug 24 | npm-debug.log* 25 | yarn-debug.log* 26 | yarn-error.log* 27 | 28 | # local env files 29 | .env*.local 30 | 31 | # vercel 32 | .vercel 33 | 34 | # typescript 35 | *.tsbuildinfo 36 | next-env.d.ts 37 | -------------------------------------------------------------------------------- /web/README.md: -------------------------------------------------------------------------------- 1 | This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app). 2 | 3 | ## Getting Started 4 | 5 | First, run the development server: 6 | 7 | ```bash 8 | npm run dev 9 | # or 10 | yarn dev 11 | # or 12 | pnpm dev 13 | # or 14 | bun dev 15 | ``` 16 | 17 | Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. 18 | 19 | You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. 20 | 21 | This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel. 22 | 23 | ## Learn More 24 | 25 | To learn more about Next.js, take a look at the following resources: 26 | 27 | - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. 28 | - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. 29 | 30 | You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome! 31 | 32 | ## Deploy on Vercel 33 | 34 | The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. 35 | 36 | Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details. 
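Note that `npm run dev` only starts this Next.js frontend; the report page (`app/page.tsx`) expects the MLE-agent FastAPI backend to be reachable at `http://localhost:8000`. A minimal way to bring that backend up is sketched below — it assumes the FastAPI instance defined in `mle/server/app.py` is exported under the name `app` (that module is not reproduced in this export), and it pins the port to match the URLs hard-coded in the page:

```python
# run_backend.py -- hypothetical helper script, not part of the repository.
# Serves the MLE-agent API that the dashboard calls (/latest_report, /gen_report).
import uvicorn

if __name__ == "__main__":
    # "mle.server.app:app" assumes the FastAPI object in mle/server/app.py is named `app`.
    uvicorn.run("mle.server.app:app", host="127.0.0.1", port=8000, reload=True)
```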
37 | -------------------------------------------------------------------------------- /web/app/fonts/GeistMonoVF.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/app/fonts/GeistMonoVF.woff -------------------------------------------------------------------------------- /web/app/fonts/GeistVF.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/app/fonts/GeistVF.woff -------------------------------------------------------------------------------- /web/app/globals.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | :root { 6 | --background: #ffffff; 7 | --foreground: #171717; 8 | } 9 | 10 | @media (prefers-color-scheme: dark) { 11 | :root { 12 | --background: #0a0a0a; 13 | --foreground: #ededed; 14 | } 15 | } 16 | 17 | body { 18 | color: var(--foreground); 19 | background: var(--background); 20 | font-family: Arial, Helvetica, sans-serif; 21 | } 22 | 23 | @layer utilities { 24 | .text-balance { 25 | text-wrap: balance; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /web/app/layout.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React, { useState, useEffect } from "react"; 4 | import localFont from "next/font/local"; 5 | import "./globals.css"; 6 | 7 | const geistSans = localFont({ 8 | src: "./fonts/GeistVF.woff", 9 | variable: "--font-geist-sans", 10 | weight: "100 900", 11 | }); 12 | const geistMono = localFont({ 13 | src: "./fonts/GeistMonoVF.woff", 14 | variable: "--font-geist-mono", 15 | weight: "100 900", 16 | }); 17 | 18 | export default function RootLayout({ 19 | children, 20 | }: Readonly<{ 21 | children: React.ReactNode; 22 | }>) { 23 | const [isLoading, setIsLoading] = useState(true); 24 | 25 | useEffect(() => { 26 | const handleLoad = () => { 27 | setIsLoading(false); 28 | }; 29 | 30 | if (document.readyState === "complete") { 31 | handleLoad(); 32 | } else { 33 | window.addEventListener("load", handleLoad); 34 | return () => { 35 | window.removeEventListener("load", handleLoad); 36 | }; 37 | } 38 | }, []); 39 | 40 | return ( 41 | 42 | 45 | {!isLoading && children} 46 | 47 | 48 | ); 49 | } 50 | -------------------------------------------------------------------------------- /web/app/page.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import React, { useState, useEffect } from 'react'; 3 | import { Layout, Card, Input, Button, message, Form, Select, Spin, Flex, Row, Col } from 'antd'; 4 | import { HomeOutlined, BookOutlined, MailOutlined, GithubOutlined, XOutlined, DiscordOutlined} from '@ant-design/icons'; 5 | import dynamic from "next/dynamic"; 6 | 7 | const MDEditor = dynamic( 8 | () => import("@uiw/react-md-editor"), 9 | { ssr: false } 10 | ); 11 | const { Content, Header } = Layout; 12 | const { TextArea } = Input; 13 | const { Option } = Select; 14 | 15 | interface ReportData { 16 | project_okr: string; 17 | business_goal: string[]; 18 | dev_progress: string[]; 19 | communicate_progress: string[]; 20 | dev_todo: { task: string; description: string; priority: string }[]; 21 | communicate_todo: { task: string; priority: 
string }[]; 22 | hard_parts: string[]; 23 | require_manager_help: string[]; 24 | suggestions_to_user: string[]; 25 | reference: { title: string; link: string;}[]; 26 | } 27 | 28 | interface ReportRequest { 29 | repo: string; 30 | username: string; 31 | okr?: string; 32 | dateRange?: string; 33 | recurringReports?: string; 34 | additionalSources?: string[]; 35 | } 36 | 37 | export default function Home() { 38 | const [reportContent, setReportContent] = useState(""); 39 | const [reportData, setReportData] = useState(null); 40 | const [form] = Form.useForm(); 41 | const [loading, setLoading] = useState(false); 42 | 43 | useEffect(() => { 44 | fetchLatestReport(); 45 | form.setFieldsValue({ recurringReports: 'weekly' }); 46 | }, [form]); 47 | 48 | const fetchLatestReport = async () => { 49 | try { 50 | const response = await fetch('http://localhost:8000/latest_report'); 51 | if (response.status === 404) { 52 | setReportContent("No report has been found. Please click 'Generate Report' to create one."); 53 | setReportData(null); 54 | return; 55 | } 56 | if (!response.ok) { 57 | throw new Error('Failed to fetch the latest report'); 58 | } 59 | const data: ReportData = await response.json(); 60 | setReportData(data); 61 | const markdownContent = convertToMarkdown(data); 62 | setReportContent(markdownContent); 63 | } catch (error) { 64 | console.error('Error fetching latest report:', error); 65 | message.error('Failed to fetch the latest report'); 66 | } 67 | }; 68 | 69 | const convertToMarkdown = (data: ReportData): string => { 70 | let markdown = ''; 71 | markdown += `## Project Report\n\n`; 72 | 73 | if (data.project_okr) { 74 | markdown += `### Project OKR\n${data.project_okr}\n\n`; 75 | } 76 | 77 | markdown += `### Business Goal\n${data.business_goal.map(goal => `- ${goal}`).join('\n')}\n\n`; 78 | markdown += `### Work Finished This Week\n\n`; 79 | markdown += `#### Development Progress\n${data.dev_progress.map(progress => `- ${progress}`).join('\n')}\n\n`; 80 | markdown += `#### Communication/Design Progress\n${data.communicate_progress.map(progress => `- ${progress}`).join('\n')}\n\n`; 81 | markdown += `### Work TODOs in the Next Week\n\n`; 82 | markdown += `#### Development TODOs\n${data.dev_todo.map(todo => `- ${todo.task} **(${todo.priority})**: ${todo.description}`).join('\n')}\n\n`; 83 | markdown += `#### Communication TODOs\n${data.communicate_todo.map(todo => `- ${todo.task} **(${todo.priority})**`).join('\n')}\n\n`; 84 | markdown += `### Hard Problems\n\n`; 85 | markdown += `#### Challenges\n${data.hard_parts.map(part => `- ${part}`).join('\n')}\n\n`; 86 | markdown += `#### Manager Help Required\n${data.require_manager_help.map(help => `- ${help}`).join('\n')}\n\n`; 87 | markdown += `### Other Progress and Thoughts\n\n`; 88 | markdown += `#### Suggestions\n${data.suggestions_to_user.map(suggestion => `- ${suggestion}`).join('\n')}\n\n`; 89 | markdown += `#### References\n${data.reference.map(ref => `- [${ref.title}](${ref.link})`).join('\n')}\n\n`; 90 | return markdown; 91 | }; 92 | 93 | const handleGenerateReport = async (values: ReportRequest) => { 94 | setLoading(true); 95 | try { 96 | const response = await fetch('http://localhost:8000/gen_report', { 97 | method: 'POST', 98 | headers: { 99 | 'Content-Type': 'application/json', 100 | }, 101 | body: JSON.stringify(values), 102 | }); 103 | 104 | if (!response.ok) { 105 | throw new Error('Failed to generate report'); 106 | } 107 | 108 | const result = await response.json(); 109 | 110 | if (result.result) { 111 | 
setReportData(result.result); 112 | const markdownContent = convertToMarkdown(result.result); 113 | setReportContent(markdownContent); 114 | message.success('Report generated successfully'); 115 | } else { 116 | message.info('Report generation completed, but no data returned'); 117 | } 118 | } catch (error) { 119 | console.error('Error generating report:', error); 120 | message.error('Failed to generate report'); 121 | } finally { 122 | setLoading(false); 123 | } 124 | }; 125 | 126 | const handleSaveReport = () => { 127 | message.info('Save report functionality not implemented yet'); 128 | }; 129 | 130 | return ( [The JSX returned here — page.tsx from line 131 to the end of this export (line 206) — was stripped during extraction, leaving only bare line numbers and the "Repx.app" header text. Judging from the imports, state, and handlers above, it renders an antd Layout: a Header carrying the "Repx.app" brand with the imported Home/Book/Mail/GitHub/X/Discord icons, and a Content area holding the MDEditor view of reportContent alongside the report-generation form (repo, username, OKR, date range, recurring-report cadence, additional sources) that submits via handleGenerateReport, a save action wired to handleSaveReport, and a loading Spin while a report is being generated.]
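For reference, the two endpoints this page talks to can also be exercised directly; the sketch below mirrors `fetchLatestReport` and `handleGenerateReport` from Python. It assumes the backend is running locally on port 8000 (matching the URLs hard-coded in the page), and the `repo`/`username`/`okr` values are placeholders — only the field names are taken from the `ReportRequest` interface above:

```python
# api_smoke_test.py -- hypothetical example, not part of the repository.
import requests

BASE = "http://localhost:8000"

# Mirrors fetchLatestReport(): a 404 means no report has been generated yet.
resp = requests.get(f"{BASE}/latest_report")
if resp.status_code == 404:
    print("No report has been found yet.")
else:
    resp.raise_for_status()
    latest = resp.json()  # shaped like the ReportData interface
    print(latest)

# Mirrors handleGenerateReport(): POST the form values as JSON.
payload = {
    "repo": "owner/repository",    # placeholder value
    "username": "your-github-id",  # placeholder value
    "okr": "optional OKR text",    # optional, as in ReportRequest
}
resp = requests.post(f"{BASE}/gen_report", json=payload)
resp.raise_for_status()
result = resp.json().get("result")  # page.tsx reads result.result
print(result)
```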