├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE.md
│   ├── OPEN_SOURCE_LICENSES.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── workflows
│       ├── lint.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
│   └── kaia_llama.webp
├── mle
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── advisor.py
│   │   ├── chat.py
│   │   ├── coder.py
│   │   ├── debugger.py
│   │   ├── planner.py
│   │   ├── reporter.py
│   │   └── summarizer.py
│   ├── cli.py
│   ├── function
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── execution.py
│   │   ├── files.py
│   │   ├── interaction.py
│   │   └── search.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── github.py
│   │   ├── google_calendar.py
│   │   ├── kaggle.py
│   │   └── local_git.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── anthropic.py
│   │   ├── common.py
│   │   ├── deepseek.py
│   │   ├── gemini.py
│   │   ├── mistral.py
│   │   ├── ollama.py
│   │   ├── openai.py
│   │   └── vllm.py
│   ├── server
│   │   ├── __init__.py
│   │   └── app.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── cache.py
│   │   ├── chunk.py
│   │   ├── component_memory.py
│   │   ├── data.py
│   │   ├── memory.py
│   │   ├── parser.py
│   │   └── system.py
│   ├── version.py
│   └── workflow
│       ├── __init__.py
│       ├── baseline.py
│       ├── chat.py
│       ├── kaggle.py
│       └── report.py
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
│   └── __init__.py
└── web
    ├── .eslintrc.json
    ├── .gitignore
    ├── README.md
    ├── app
    │   ├── fonts
    │   │   ├── GeistMonoVF.woff
    │   │   └── GeistVF.woff
    │   ├── globals.css
    │   ├── layout.tsx
    │   └── page.tsx
    ├── next.config.mjs
    ├── package-lock.json
    ├── package.json
    ├── pnpm-lock.yaml
    ├── postcss.config.mjs
    ├── public
    │   ├── favicon.ico
    │   └── logo.png
    ├── tailwind.config.ts
    └── tsconfig.json
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | @huangyz0918
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
10 |
11 | #### Software execution information
12 |
13 | MLE-agent version:
14 | System OS version:
15 |
16 | #### Problem description
17 |
18 | #### Steps to reproduce the problem
19 |
20 | #### Expected behavior
21 |
22 | #### Other information
23 |
24 | Things you tried, stack traces, related issues, suggestions on how to fix it...
--------------------------------------------------------------------------------
/.github/OPEN_SOURCE_LICENSES.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/.github/OPEN_SOURCE_LICENSES.md
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | Closes #
2 |
3 |
6 |
7 | #### What has been done to verify that this works as intended?
8 |
9 | #### Why is this the best possible solution? Were any other approaches considered?
10 |
11 | #### How does this change affect users? Describe intentional changes to behavior and behavior that could have accidentally been affected by code changes. In other words, what are the regression risks?
12 |
13 | #### Do we need any specific form for testing your changes? If so, please attach one.
14 |
15 | #### Does this change require updates to documentation? If so, please file an issue [here](https://github.com/MLSysOps/MLE-agent/issues/new) and include the link below.
16 |
17 | #### Before submitting this PR, please make sure you have:
18 |
19 | - [ ] confirmed all checks still pass OR confirm CI build passes.
20 | - [ ] verified that any code or assets from external sources are properly credited in comments and/or in
21 | the [credit file](https://github.com/MLSysOps/MLE-agent/blob/main/.github/OPEN_SOURCE_LICENSES.md).
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches:
6 | - dev
7 | - main
8 | pull_request:
9 | branches:
10 | - dev
11 | - main
12 | jobs:
13 | lint:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v3
18 | - name: Set up Python
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: '3.11'
22 | - name: Flake8 Lint
23 | uses: py-actions/flake8@v2
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | publish:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: "3.x"
19 | cache: pip
20 | cache-dependency-path: setup.py
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Publish
26 | env:
27 | TWINE_USERNAME: __token__
28 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Unit Test
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | testing:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v2
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: '3.11'
20 | - name: Install dependencies
21 | run: |
22 | pip install .
23 | - name: Run Test
24 | run: python -m unittest discover tests
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # macOS
163 | *.DS_Store
164 |
165 | # cache
166 | .rich-chat.history
167 |
168 | poc/
169 |
170 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MLE Agent
2 |
3 | This document is a work in progress. If you notice areas for improvement, please feel free to update this guide and submit a pull request!
4 |
5 | ## Table of Contents
6 |
7 | - [Submitting a Pull Request](#submitting-a-pull-request)
8 | - [Ensuring Your Pull Request Gets Accepted](#ensuring-your-pull-request-gets-accepted)
9 | - [The Review Process](#the-review-process)
10 | - [Work in Progress Pull Requests](#work-in-progress-pull-requests)
11 | - [Triaging Issues](#triaging-issues)
12 |
13 | ## Submitting a Pull Request
14 |
15 | To contribute code to MLE Agent, you need to open a [pull request](https://help.github.com/articles/about-pull-requests/). The pull request will be reviewed by the community before it is merged into the core project. Generally, a pull request should be submitted when a unit of work is complete, but you can also share ideas or get feedback through a work in progress (WIP) pull request ([learn more](#work-in-progress-pull-requests)).
16 |
17 | 1. Familiarize yourself with the project by reading our ["Getting Started Guide"](docs/GETTING_STARTED.md).
18 |
19 | 2. Follow our [coding standards](docs/CODE_GUIDELINES.md) to ensure consistency across the project.
20 |
21 | 3. Review our [testing guidelines](docs/TEST_GUIDELINES.md) to understand the project's automated testing framework.
22 |
23 | 4. [Set up your development environment](docs/DEVELOPMENT_SETUP.md) to make sure you have everything you need to contribute.
24 |
25 | 5. Make sure you have the latest version of the code by syncing your fork with the main repository:
26 |
27 | ```sh
28 | git remote add upstream https://github.com/MLSysOps/MLE-agent.git
29 | git fetch upstream
30 | git merge upstream/main
31 | ```
32 |
33 | 6. Create a branch for the code you will be working on:
34 |
35 | ```sh
36 | git checkout -b my-new-feature
37 | ```
38 |
39 | 7. Write your code, making sure to include tests as needed.
40 |
41 | 8. Commit your changes with a meaningful commit message:
42 |
43 | ```sh
44 | git commit -m "Description of the changes"
45 | ```
46 |
47 | 9. Push your changes to your fork:
48 |
49 | ```sh
50 | git push origin my-new-feature
51 | ```
52 |
53 | 10. Open a pull request on GitHub. Make sure to include a detailed description of the changes you made and any relevant context.
54 |
55 | ## Ensuring Your Pull Request Gets Accepted
56 |
57 | - Make sure your code follows the coding standards outlined in our code guidelines -- we use [flake8](https://flake8.pycqa.org/en/latest/) to enforce these standards (a quick local check is sketched after this list).
58 | - Write tests for any new features or significant changes.
59 | - Ensure all tests pass before submitting your pull request.
60 | - Be responsive to feedback from reviewers.
61 |
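To catch problems before CI does, you can run roughly the same checks locally. The commands below mirror what the `lint.yml` and `test.yml` workflows run; the flake8 call is a plain default run, so adjust it if your local setup differs:

```sh
pip install flake8
flake8 .                            # same linter the Lint workflow runs
pip install .                       # install the package as the Unit Test workflow does
python -m unittest discover tests   # run the test suite
```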
62 |
63 | ## The Review Process
64 |
65 | Once you submit a pull request, it will be reviewed by the maintainers. They might request changes or provide feedback. The goal is to ensure the code is high quality and aligns with the project's goals.
66 |
67 | ## Work in Progress Pull Requests
68 |
69 | If you want feedback on your work before it's complete, you can open a WIP pull request. This allows you to get input from others on your approach or on specific parts of your code.
70 | When you're ready for a full review, mark the pull request as `MRG` by removing the `WIP` label.
71 |
72 | ## Triaging Issues
73 | If you're not ready to submit code but still want to contribute, you can help by triaging issues. This involves confirming bugs, providing additional information, or suggesting ways to reproduce issues.
74 |
75 | Thank you for your interest in contributing to MLE Agent!
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Huaizheng Zhang, Yizheng Huang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include web *
2 | recursive-include web/app *
3 |
4 | recursive-exclude web/node_modules *
5 | recursive-exclude web/.pnp
6 | recursive-exclude web/.pnp.js
7 | recursive-exclude web/.yarn/install-state.gz
8 | recursive-exclude web/coverage *
9 | recursive-exclude web/.next *
10 | recursive-exclude web/out *
11 | recursive-exclude web/build *
12 | recursive-exclude web/.DS_Store
13 | recursive-exclude web/*.pem
14 | recursive-exclude web/npm-debug.log*
15 | recursive-exclude web/yarn-debug.log*
16 | recursive-exclude web/yarn-error.log*
17 | recursive-exclude web/.env*.local
18 | recursive-exclude web/.vercel
19 | recursive-exclude web/*.tsbuildinfo
20 | recursive-exclude web/next-env.d.ts
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
MLE-Agent: Your intelligent companion for seamless AI engineering and research.
3 |
4 |
5 |
:love_letter: Fathers' love for Kaia :love_letter:
6 |
7 | 
8 | 
9 | 
10 | [](https://pepy.tech/project/mle-agent)
11 | 
12 |
13 |
14 | [📚 Docs](https://mle-agent-site.vercel.app/) |
15 | [🐞 Report Issues](https://github.com/MLSysOps/MLE-agent/issues/new) |
16 | 👋 Join us on
Discord
17 |
18 |
19 |
20 | ## Overview
21 |
22 | MLE-Agent is designed as a pairing LLM agent for machine learning engineers and researchers. Its key features include:
23 |
24 | - 🤖 Autonomous Baseline: Automatically builds ML/AI baselines and solutions based on your requirements.
25 | - 🏅 End-to-end ML Task: Participates in Kaggle competitions and completes tasks independently.
26 | - 🔍 [Arxiv](https://arxiv.org/) and [Papers with Code](https://paperswithcode.com/) Integration: Access best practices
27 | and state-of-the-art methods.
28 | - 🐛 Smart Debugging: Ensures high-quality code through automatic debugger-coder interactions.
29 | - 📂 File System Integration: Organizes your project structure efficiently.
30 | - 🧰 Comprehensive Tools Integration: Includes AI/ML functions and MLOps tools for a seamless workflow.
31 | - ☕ Interactive CLI Chat: Enhances your projects with an easy-to-use chat interface.
32 | - 🧠 Smart Advisor: Provides personalized suggestions and recommendations for your ML/AI project.
33 | - 📊 Weekly Report: Automatically generates detailed summaries of your weekly work.
34 |
35 | https://github.com/user-attachments/assets/dac7be90-c662-4d0d-8d3a-2bc4df9cffb9
36 |
37 | ## Milestones
38 |
39 | - :rocket: 09/24/2024: Release the `0.4.2` with enhanced `Auto-Kaggle` mode to complete an end-to-end competition with minimal effort.
40 | - :rocket: 09/10/2024: Release the `0.4.0` with new CLIs like `MLE report`, `MLE kaggle`, `MLE integration` and many new
41 | models like `Mistral`.
42 | - :rocket: 07/25/2024: Release the `0.3.0` with a major refactoring, many integrations, etc. (v0.3.0)
43 | - :rocket: 07/11/2024: Release the `0.2.0` with multi-agent interaction (v0.2.0)
44 | - 👨🍼 **07/03/2024: Kaia is born**
45 | - :rocket: 06/01/2024: Release the first rule-based version of MLE agent (v0.1.0)
46 |
47 | ## Get started
48 |
49 | ### Installation
50 |
51 | ```bash
52 | pip install mle-agent -U
53 | # or from source
54 | git clone git@github.com:MLSysOps/MLE-agent.git
55 | pip install -e .
56 | ```
57 |
58 | ### Usage
59 |
60 | ```bash
61 | mle new
62 | ```
63 |
64 | A project directory will be created under the current path. You need to start the project from inside that
65 | directory.
66 |
67 | ```bash
68 | cd
69 | mle start
70 | ```
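For example, if `mle new` created a project named `mnist-classifier` (the name here is only an illustration), you would run:

```bash
cd mnist-classifier   # replace with the directory created by `mle new`
mle start
```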
71 |
72 | You can also start an interactive chat in the terminal under the project directory:
73 |
74 | ```bash
75 | mle chat
76 | ```
77 |
78 | ## Use cases
79 |
80 | ### 🧪 Prototype an ML Baseline
81 |
82 | MLE agent can help you prototype an ML baseline from the given requirements and test the model on the local machine.
83 | The requirements can be vague, such as "I want to predict the stock price based on historical data".
84 |
85 | ```bash
86 | cd
87 | mle start
88 | ```
89 |
90 | ### :bar_chart: Generate Work Report
91 |
92 | MLE agent can help you write your weekly report, including development progress, communication notes, references, and
93 | to-do lists.
94 |
95 | #### Mode 1: Web Application to Generate Report from GitHub
96 |
97 | ```bash
98 | cd
99 | mle report
100 | ```
101 |
102 | Then, you can visit http://localhost:3000/ to generate your report locally.
103 |
104 | #### Mode 2: CLI Tool to Generate Report from Local Git Repository
105 | ```bash
106 | cd
107 | mle report-local --email= --start-date=YYYY-MM-DD --end-date=YYYY-MM-DD
108 | ```
109 |
110 | - `--start-date` and `--end-date` are optional parameters. If omitted, the command will generate a report for the default date range of the last 7 days.
111 | - Replace `` with your Git email and `` with the path to your local Git repository (see the example below).
112 |
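For example, a concrete run might look like this (the email, dates, and repository path below are placeholders for illustration):

```bash
cd ~/projects/my-repo   # path to your local Git repository
mle report-local --email=dev@example.com --start-date=2024-09-01 --end-date=2024-09-07
```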
113 | ### :trophy: Start with Kaggle Competition
114 |
115 | MLE agent can participate in Kaggle competitions and finish coding and debugging from data preparation to model training
116 | independently. Here is the basic command to start a Kaggle competition:
117 |
118 | ```bash
119 | cd
120 | mle kaggle
121 | ```
122 |
123 | Or you can let the agents finish the Kaggle task without human interaction if you have the dataset and submission file
124 | ready:
125 |
126 | ```bash
127 | cd
128 | mle kaggle --auto \
129 | --datasets ",,..." \
130 | --description "" \
131 | --submission "" \
132 | --sub_example "" \
133 | --comp_id ""
134 | ```
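For illustration, a fully specified command might look like the following (the project directory, dataset files, description, file paths, and competition ID are placeholders; substitute your own):

```bash
cd my-kaggle-project
mle kaggle --auto \
    --datasets "train.csv,test.csv" \
    --description "Predict which passengers survived the Titanic shipwreck" \
    --submission "./submission.csv" \
    --sub_example "./sample_submission.csv" \
    --comp_id "titanic"
```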
135 |
136 | Please make sure you have joined the competition before running the command. For more details, see the [MLE-Agent Tutorials](https://mle-agent-site.vercel.app/tutorial/Start_a_kaggle_task).
137 |
138 | ## Roadmap
139 |
140 | The following is a list of tasks we plan to work on; feel free to propose something new!
141 |
142 |
143 | :hammer: General Features
144 |
145 | - [x] Understand users' requirements to create an end-to-end AI project
146 | - [x] Suggest the SOTA data science solutions by using the web search
147 | - [x] Plan the ML engineering tasks with human interaction
148 | - [x] Execute the code on the local machine/cloud, debug and fix the errors
149 | - [x] Leverage the built-in functions to complete ML engineering tasks
150 | - [x] Interactive chat: A human-in-the-loop mode to help improve the existing ML projects
151 | - [x] Kaggle mode: to finish a Kaggle task without humans
152 | - [x] Summarize and reflect on the whole ML/AI pipeline
153 | - [ ] Integration with cloud data, testing, and debugging platforms
154 | - [x] Local RAG support to enable a personal ML/AI coding assistant
155 | - [ ] Function zoo: generate AI/ML functions and save them for future usage
156 |
157 |
158 |
159 |
160 | :star: More LLMs and Serving Tools
161 |
162 | - [x] Ollama LLama3
163 | - [x] OpenAI GPTs
164 | - [x] Anthropic Claude 3.5 Sonnet
165 |
166 |
167 |
168 |
169 | :sparkling_heart: Better user experience
170 |
171 | - [x] CLI Application
172 | - [x] Web UI
173 | - [x] Discord
174 |
175 |
176 |
177 |
178 | :jigsaw: Functions and Integrations
179 |
180 | - [x] Local file system
181 | - [x] Local code executor
182 | - [x] Arxiv.org search
183 | - [x] Papers with Code search
184 | - [x] General keyword search
185 | - [ ] Hugging Face
186 | - [ ] SkyPilot cloud deployment
187 | - [ ] Snowflake data
188 | - [ ] AWS S3 data
189 | - [ ] Databricks data catalog
190 | - [ ] Wandb experiment monitoring
191 | - [ ] MLflow management
192 | - [ ] DBT data transform
193 |
194 |
195 |
196 | ## Contributing
197 |
198 | We welcome contributions from the community. We are looking for contributors to help us with the following tasks:
199 |
200 | - Benchmark and Evaluate the agent
201 | - Add more features to the agent
202 | - Improve the documentation
203 | - Write tests
204 |
205 | Please check the [CONTRIBUTING.md](CONTRIBUTING.md) file if you want to contribute.
206 |
207 | ## Support and Community
208 |
209 | - [Discord community](https://discord.gg/SgxBpENGRG). If you have any questions, please ask in the Discord community.
210 |
211 | ## Citation
212 |
213 | ```bibtex
214 | @misc{zhang2024mleagent,
215 | title = {MLE-Agent: Your Intelligent Companion for Seamless AI Engineering and Research},
216 | author = {Huaizheng Zhang* and Yizheng Huang* and Lei Zhang},
217 | year = {2024},
218 | note = {\url{https://github.com/MLSysOps/MLE-agent}},
219 | }
220 | ```
221 |
222 | ## License
223 |
224 | Check the [MIT License](LICENSE) file for more information.
225 |
--------------------------------------------------------------------------------
/assets/kaia_llama.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/assets/kaia_llama.webp
--------------------------------------------------------------------------------
/mle/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 |
--------------------------------------------------------------------------------
/mle/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .advisor import *
2 | from .coder import *
3 | from .debugger import *
4 | from .planner import *
5 | from .summarizer import *
6 | from .reporter import *
7 | from .chat import *
8 |
--------------------------------------------------------------------------------
/mle/agents/chat.py:
--------------------------------------------------------------------------------
1 | from rich.console import Console
2 |
3 | from mle.function import *
4 | from mle.utils import get_config, WorkflowCache
5 | from mle.utils.component_memory import trace_component
6 |
7 | class ChatAgent:
8 |
9 | def __init__(self, model, memory=None, working_dir='.', console=None):
10 | """
11 | ChatAgent assists users with planning and debugging ML projects.
12 |
13 | Args:
14 | model: The machine learning model used for generating responses.
15 | """
16 | config_data = get_config()
17 |
18 | self.model = model
19 | self.memory = memory
20 | self.chat_history = []
21 | if working_dir == '.':
22 | working_dir = os.getcwd()
23 | self.working_dir = working_dir
24 | self.cache = WorkflowCache(working_dir, 'baseline')
25 |
26 | self.console = console
27 | if not self.console:
28 | self.console = Console()
29 |
30 | self.sys_prompt = f"""
31 | You are a programmer working on a Machine Learning task using Python.
32 | You are currently working on: {self.working_dir}.
33 | 
34 | You can leverage your capabilities by using the specific functions listed below:
35 | 
36 | 1. Creating project structures based on the user requirement using function `create_directory`.
37 | 2. Writing clean, efficient, and well-documented code using functions `create_file` and `write_file`.
38 | 3. Examining the project to re-use the existing code snippets as much as possible; you may need to use
39 | functions like `list_files`, `read_file` and `write_file`.
40 | 4. Writing the code into the file when creating new files, do not create empty files.
41 | 5. Use function `preview_csv_data` to preview the CSV data if the task includes CSV data processing.
42 | 6. Decide whether the task requires execution and debugging before moving to the next one or not.
43 | 7. Generate the commands to run and test the current task, and the dependencies list for this task.
44 | 8. You only write Python scripts, don't write Jupyter notebooks which require interactive execution.
45 | """
46 | self.search_prompt = """
47 | 9. Performing web searches using function `web_search` to get up-to-date information or additional context.
48 | """
49 |
50 | self.functions = [
51 | schema_read_file,
52 | schema_create_file,
53 | schema_write_file,
54 | schema_list_files,
55 | schema_create_directory,
56 | schema_search_arxiv,
57 | schema_search_papers_with_code,
58 | schema_web_search,
59 | schema_execute_command,
60 | schema_preview_csv_data,
61 | schema_unzip_data,
62 | schema_preview_zip_structure
63 | ]
64 |
65 | if config_data.get('search_key'):
66 | self.functions.append(schema_web_search)
67 | self.sys_prompt += self.search_prompt
68 |
69 | if not self.cache.is_empty():
70 | dataset = self.cache.resume_variable("dataset")
71 | ml_requirement = self.cache.resume_variable("ml_requirement")
72 | advisor_report = self.cache.resume_variable("advisor_report")
73 | self.sys_prompt += f"""
74 | The overall project information: \n
75 | {'Dataset: ' + str(dataset) if dataset else ''} \n
76 | {'Requirement: ' + str(ml_requirement) if ml_requirement else ''} \n
77 | {'Advisor: ' + str(advisor_report) if advisor_report else ''} \n
78 | """
79 |
80 | self.chat_history.append({"role": 'system', "content": self.sys_prompt})
81 |
82 | def greet(self):
83 | """
84 | Generate a greeting message to the user, including inquiries about the project's purpose and
85 | an overview of the support provided. This initializes a collaborative tone with the user.
86 |
87 | Returns:
88 | str: The generated greeting message.
89 | """
90 | greet_prompt = """
91 | Can you provide a concise and friendly greeting within 50 words, including:
92 | 1. Inferring the project's purpose or objective.
93 | 2. Summarizing the previous conversations if any exist.
94 | 3. Offering a brief overview of the assistance and support you can provide to the user, such as:
95 | - Helping with project planning and management.
96 | - Assisting with debugging and troubleshooting code.
97 | - Offering advice on best practices and optimization techniques.
98 | - Providing resources and references for further learning.
99 | Make sure your greeting is inviting and sets a positive tone for collaboration.
100 | """
101 | self.chat_history.append({"role": "user", "content": greet_prompt})
102 | greets = self.model.query(
103 | self.chat_history,
104 | function_call='auto',
105 | functions=self.functions,
106 | )
107 |
108 | self.chat_history.append({"role": "assistant", "content": greets})
109 | return greets
110 | @trace_component("chat")
111 | def chat(self, user_prompt):
112 | """
113 | Handle the streaming response from the model.
114 | The stream mode is integrated with the model's streaming function, so we don't
115 | need to set it to JSON mode.
116 |
117 | Args:
118 | user_prompt: the user prompt.
119 | """
120 | text = ''
121 | if self.memory:
122 | table_name = 'mle_chat_' + self.working_dir.split('/')[-1]
123 | query = self.memory.query([user_prompt], table_name=table_name, n_results=1) # TODO: adjust the n_results.
124 | user_prompt += f"""
125 | \nThese reference files and their snippets may be useful for the question:\n\n
126 | """
127 |
128 | for t in query[0]:
129 | snippet, metadata = t.get('text'), t.get('metadata')
130 | user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n"
131 | self.chat_history.append({"role": "user", "content": user_prompt})
132 |
133 | for content in self.model.stream(
134 | self.chat_history,
135 | function_call='auto',
136 | functions=self.functions,
137 | ):
138 | if content:
139 | text += content
140 | yield text
141 |
142 | self.chat_history.append({"role": "assistant", "content": text})
143 |
--------------------------------------------------------------------------------
/mle/agents/coder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | from rich.console import Console
4 |
5 | from mle.function import *
6 | from mle.utils import get_config, print_in_box, clean_json_string
7 | from mle.utils.component_memory import trace_component
8 |
9 | def process_summary(summary_dict: dict):
10 | """
11 | Process the code summary.
12 | Args:
13 | summary_dict: the code summary in a dictionary format.
14 | """
15 | return textwrap.dedent(f"""
16 | MLE Developer has finished the task: {summary_dict.get('task')}.\n
17 | Task description: {summary_dict.get('task_description')}\n
18 | {summary_dict.get('message')}\n
19 | Dependencies required to run the code: {summary_dict.get('dependency')}\n
20 | Command to run the code: {summary_dict.get('command')}\n
21 | Whether the code is required to execute and debug: {summary_dict.get('debug')}""")
22 |
23 |
24 | class CodeAgent:
25 |
26 | def __init__(self, model, working_dir='.', console=None, single_file=False):
27 | """
28 | CodeAgent: the agent to solve the given coding problems by planning coding tasks, searching websites,
29 | and generating code snippets. It does not execute the code; it only makes use of built-in functions to provide
30 | the code snippets that solve the problem.
31 |
32 | Args:
33 | model: the model to use.
34 | working_dir: the working directory.
35 | console: the console to use.
36 | single_file: whether the agent is working on a single file or not.
37 | """
38 | config_data = get_config()
39 | self.code_summary = None
40 | self.model = model
41 | self.chat_history = []
42 | self.working_dir = working_dir
43 |
44 | self.console = console
45 | if not self.console:
46 | self.console = Console()
47 |
48 | self.sys_prompt = f"""
49 | You are a programmer working on a Machine Learning task using Python.
50 | You are currently working on: {self.working_dir}.
51 | 
52 | You can leverage your capabilities by using the specific functions listed below:
53 | 
54 | - Creating project structures based on the user requirement using function `create_directory`.
55 | - Writing clean, efficient, and well-documented code using functions `create_file` and `write_file`.
56 | - Examining the project to re-use the existing code snippets as much as possible; you may need to use
57 | functions like `list_files`, `read_file` and `write_file`.
58 | - Use function `preview_zip_structure` to preview the structure of the file if the task includes zip file processing.
59 | - Use function `unzip_data` to extract the compressed file if the task includes compressed file processing.
60 | - Writing the code into the file when creating new files, do not create empty files.
61 | - Use function `preview_csv_data` to preview the CSV data if the task includes CSV data processing.
62 | - Decide whether the task requires execution and debugging before moving to the next one or not.
63 | - Generate the commands to run and test the current task, and the dependencies list for this task.
64 | - You only write Python scripts, don't write Jupyter notebooks which require interactive execution.
65 | """
66 |
67 | if single_file:
68 | self.sys_prompt = f"""
69 | You are an expert programmer working on a Machine Learning task using Python.
70 | You are currently working on: {self.working_dir}.
71 | 
72 | You can leverage your capabilities by using the specific functions listed below:
73 | 
74 | - You should create a single script first, with the complete code inside. You can have multiple functions and classes.
75 | - Writing clean, efficient, and well-documented code to a script using function `create_file`.
76 | - Use function `preview_csv_data` to preview the CSV data if the task includes a CSV dataset or examples.
77 | - Use function `preview_zip_structure` to preview the structure of the file if the task includes zip file processing.
78 | - Use function `unzip_data` to extract the compressed file if the task includes compressed file processing.
79 | - Generate the commands to run and test the current script, and the dependencies list required for this script.
80 | - You only write Python scripts, don't write Jupyter notebooks which require interactive execution.
81 | - Make sure the code meets the task description and the suggested methods.
82 | - Make sure the output format and the output file path are correct.
83 | """
84 |
85 | self.search_prompt = """
86 | - Performing web searches using function `web_search` to get up-to-date information or additional context.
87 | """
88 |
89 | self.json_mode_prompt = """
90 |
91 | The output format should be in JSON format, include:
92 |
93 | 1. The dependency list that the project needs to run.
94 | 2. And the command to run and test the project.
95 | 3. The reason why it failed if the status is failed; put it in the "message" field.
96 | 4. Whether the task requires execution and debugging or not (it is "false" when creating new directories or files).
97 | If the task requires modifying existing code or generating new code, it is "true". If the "command" is empty,
98 | the "debug" should be "false".
99 |
100 | Example JSON output:
101 |
102 | {
103 | "dependency":[
104 | "torch",
105 | "scikit-learn"
106 | ],
107 | "command":"python /path/to/your/project.py",
108 | "message":"the project-related has been generated in the project.py.",
109 | "debug":"true"
110 | }
111 |
112 | """
113 |
114 | if single_file:
115 | self.json_mode_prompt = """
116 |
117 | The output format should be in JSON format, include:
118 |
119 | 1. The dependency list that the project needs to run.
120 | 2. And the command and the parameters to run and test the script.
121 |
122 | Example JSON output:
123 |
124 | {
125 | "dependency":[
126 | "torch",
127 | "scikit-learn"
128 | ],
129 | "command":"python /path/to/your/project.py",
130 | }
131 |
132 | """
133 |
134 | self.functions = [
135 | schema_read_file,
136 | schema_create_file,
137 | schema_write_file,
138 | schema_list_files,
139 | schema_create_directory,
140 | schema_preview_csv_data,
141 | schema_preview_zip_structure,
142 | schema_unzip_data
143 | ]
144 |
145 | if config_data.get('search_key'):
146 | self.functions.append(schema_web_search)
147 | self.sys_prompt += self.search_prompt
148 |
149 | self.sys_prompt += self.json_mode_prompt
150 | self.chat_history.append({"role": 'system', "content": self.sys_prompt})
151 |
152 | @trace_component("coder")
153 | def read_requirement(self, advisor_report: str):
154 | """
155 | Read the user requirement and the advisor report.
156 | :param advisor_report:
157 | :return:
158 | """
159 | self.chat_history.append({"role": "system", "content": advisor_report})
160 |
161 | @trace_component("coder")
162 | def code(self, task_dict: dict):
163 | """
164 | Handle the query from the model query response.
165 | Args:
166 | task_dict: the task dictionary.
167 | """
168 | task_prompt = textwrap.dedent(f"""
169 | You are required to complete task: {task_dict.get('task')}.\n
170 | Task description: {task_dict.get('description')}
171 | """)
172 |
173 | with self.console.status(f"Coder is working on the task: {task_dict.get('task')}..."):
174 | self.chat_history.append({"role": "user", "content": task_prompt})
175 | text = self.model.query(
176 | self.chat_history,
177 | function_call='auto',
178 | functions=self.functions,
179 | response_format={"type": "json_object"}
180 | )
181 |
182 | self.chat_history.append({"role": "assistant", "content": text})
183 | code_summary = clean_json_string(text)
184 | code_summary.update({'task': task_dict.get('task'), 'task_description': task_dict.get('description')})
185 | return code_summary
186 |
187 | @trace_component("coder")
188 | def debug(self, task_dict: dict, debug_report: dict):
189 | """
190 | Handle the query from the model query response.
191 | :param task_dict: the task dictionary.
192 | :param debug_report: the debug report from DebugAgent.
193 | :return:
194 | """
195 | improve_prompt = textwrap.dedent(f"""
196 | You are required to improve the existing project.\n
197 | The required changes: {debug_report.get("changes")}\n
198 | The suggestion: {debug_report.get("suggestion")}
199 |
200 | """)
201 |
202 | with self.console.status(f"Coder is improving the code for task {task_dict.get('task')}..."):
203 | self.chat_history.append({"role": "user", "content": improve_prompt})
204 | text = self.model.query(
205 | self.chat_history,
206 | function_call='auto',
207 | functions=self.functions,
208 | response_format={"type": "json_object"}
209 | )
210 |
211 | self.chat_history.append({"role": "assistant", "content": text})
212 | code_summary = clean_json_string(text)
213 | code_summary.update({'task': task_dict.get('task'), 'task_description': task_dict.get('description')})
214 | return code_summary
215 |
216 | def interact(self, task_dict: dict):
217 | """
218 | Interact with the user to code the task.
219 | Args:
220 | task_dict: the task dictionary.
221 | """
222 | self.code_summary = self.code(task_dict)
223 | print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan")
224 | while True:
225 | suggestion = questionary.text(
226 | "Any feedback to MLE developer? (ENTER to move to the next stage, \"exit\" to exit the project)"
227 | ).ask()
228 |
229 | if not suggestion:
230 | break
231 |
232 | if suggestion.lower() in ["exit"]:
233 | sys.exit(0)
234 |
235 | with self.console.status(f"MLE Developer is working on the task: {task_dict.get('task')}..."):
236 | self.chat_history.append({"role": "user", "content": suggestion})
237 | text = self.model.query(
238 | self.chat_history,
239 | function_call='auto',
240 | functions=self.functions,
241 | response_format={"type": "json_object"}
242 | )
243 |
244 | self.chat_history.append({"role": "assistant", "content": text})
245 | self.code_summary = clean_json_string(text)
246 | self.code_summary.update(
247 | {
248 | 'task': task_dict.get('task'),
249 | 'task_description': task_dict.get('description')
250 | }
251 | )
252 | print_in_box(process_summary(self.code_summary), self.console, title="MLE Developer", color="cyan")
253 | return self.code_summary
254 |
--------------------------------------------------------------------------------
/mle/agents/debugger.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from mle.function import *
4 | from mle.utils import get_config, print_in_box
5 |
6 | from rich.console import Console
7 | from mle.utils.component_memory import trace_component
8 |
9 | def process_debug_report(debug_report):
10 | """
11 | Process the debug report.
12 | Args:
13 | debug_report: the debug report.
14 | """
15 | if debug_report.get("status") == "success":
16 | return "The code runs without errors."
17 | else:
18 | changes = debug_report.get("changes")
19 | suggestion = debug_report.get("suggestion")
20 | report_str = "The code is running with errors.\nChanges are required:\n"
21 | for change in changes:
22 | report_str += f"File: {change.get('file')}, Line: {change.get('line')}, Issue: {change.get('issue')}, " \
23 | f"Suggestion: {change.get('suggestion')}\n"
24 | report_str += f"Overall Suggestion: {suggestion}"
25 | return report_str
26 |
27 |
28 | class DebugAgent:
29 |
30 | def __init__(self, model, console=None, analyze_only=False):
31 | """
32 | DebugAgent: the agent to run the generated code and then debug it. The return of the
33 | agent is an instruction to the user to modify the code based on the logs and web search.
34 |
35 | Args:
36 | model: the model to use.
37 | console: the console to use.
38 | analyze_only: whether to only analyze the code without executing it
39 | """
40 | config_data = get_config()
41 | self.console = console
42 | if not self.console:
43 | self.console = Console()
44 | self.model = model
45 | self.chat_history = []
46 | self.sys_prompt = """
47 | You are a program error debugger working on a Python project.
48 |
49 | You can leverage your capabilities by using the specific functions listed below:
50 |
51 | - Install the code dependencies using function `execute_command` based on the Developer's dependencies list.
52 | - Execute the code using function `execute_command` to test the code based on the Developer's instructions.
53 | - If the program returns errors, you need to debug the code based on the logs, you may need to first read the
54 | structure of the project using function `list_files`.
55 | - Then you may need to call `read_file` function to read the content of the code files, locate the error line
56 | and the reasons.
57 | - You don't need to care about the best practices and code styles, you only care about the errors in the code.
58 |
59 | """
60 |
61 | self.search_prompt = """
62 | - You need to debug the code based on the error logs, you may need to call `web_search` function to search for
63 | the solutions or reasons for the errors if needed.
64 | """
65 |
66 | self.json_mode_prompt = """
67 |
68 | Example JSON output if a program runs without errors:
69 | {
70 | "status":"success",
71 | "changes":[],
72 | "suggestion":""
73 | }
74 |
75 | Example JSON output if a program returns errors:
76 | {
77 | "status":"error",
78 | "changes":[
79 | {
80 | "file":"xxx.py",
81 | "line":10,
82 | "issue":"xxx",
83 | "suggestion":"xxx"
84 | }
85 | ],
86 | "suggestion":"Failed to find the target file. Please check the file path."
87 | }
88 | """
89 |
90 | self.functions = [
91 | schema_read_file,
92 | schema_list_files,
93 | schema_execute_command
94 | ]
95 |
96 | if analyze_only:
97 | self.sys_prompt = """
98 | You are a program error debugger working on a Python project. Your target is to
99 | analyze the running logs and the source code to locate the errors, and give
100 | suggestions to the developer to fix the errors.
101 |
102 | Your can leverage your capabilities by using the specific functions listed below:
103 |
104 | - Read and understand the running logs and the error messages.
105 | - Read the content of the source code files using function `read_file`, to locate the error line.
106 | - Give the suggestions to the developer to fix the errors.
107 | - Install missing dependencies using function `execute_command` based on the error logs.
108 | - If there is no error based on the exit code, you don't need to do anything but return the success status.
109 | """
110 |
111 | if config_data.get('search_key'):
112 | self.functions.append(schema_web_search)
113 | self.sys_prompt += self.search_prompt
114 |
115 | self.sys_prompt += self.json_mode_prompt
116 | self.chat_history.append({"role": 'system', "content": self.sys_prompt})
117 |
118 | def analyze_with_log(self, commands, logs):
119 | """
120 | Analyze the logs.
121 | :param commands: the commands to execute.
122 | :param logs: the logs to analyze.
123 | :return:
124 | """
125 | analyze_prompt = f"""
126 | The command to execute the code: {commands} \n
127 | The logs to analyze: {logs} \n
128 | """
129 |
130 | self.chat_history.append({"role": "user", "content": analyze_prompt})
131 | try:
132 | text = self.model.query(
133 | self.chat_history,
134 | function_call='auto',
135 | functions=self.functions,
136 | response_format={"type": "json_object"}
137 | )
138 | except Exception as e:
139 | print(f"Error occurred while querying the model: {e}")
140 | return {}
141 |
142 | self.chat_history.append({"role": "assistant", "content": text})
143 | report_dict = json.loads(text)
144 | print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow")
145 | return report_dict
146 | @trace_component("debugger")
147 | def analyze(self, code_report):
148 | """
149 | Handle the query from the model query response.
150 | Args:
151 | code_report: the code report from the MLE developer.
152 | """
153 | debug_prompt = f"""
154 | Please help me debug the current task: {code_report.get('task')}. {code_report.get('message')}\n
155 | The task description: {code_report.get('task_description')}
156 | The dependencies required for this task: {code_report.get('dependency')}
157 | The command to execute the code: {code_report.get('command')}
158 |
159 | """
160 |
161 | error_msg = code_report.get('error_message')
162 | if error_msg:
163 | debug_prompt += f"Error message: {error_msg}\n"
164 |
165 | self.chat_history.append({"role": "user", "content": debug_prompt})
166 | try:
167 | text = self.model.query(
168 | self.chat_history,
169 | function_call='auto',
170 | functions=self.functions,
171 | response_format={"type": "json_object"}
172 | )
173 | except Exception as e:
174 | print(f"Error occurred while querying the model: {e}")
175 | return {}
176 |
177 | self.chat_history.append({"role": "assistant", "content": text})
178 | report_dict = json.loads(text)
179 | print_in_box(process_debug_report(report_dict), self.console, title="MLE Debugger", color="yellow")
180 | return report_dict
--------------------------------------------------------------------------------
/mle/agents/planner.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import questionary
4 | from rich.console import Console
5 |
6 | from mle.utils import print_in_box, clean_json_string
7 | from mle.utils.component_memory import trace_component
8 |
9 |
10 | def process_plan(plan_dict: dict):
11 | plan_str = ""
12 | for task in plan_dict.get('tasks'):
13 | plan_str += f"[green][Task]:[/green] {task.get('task')}\n[green][Description]:[/green] {task.get('description')}\n\n"
14 |
15 | return plan_str
16 |
17 |
18 | class PlanAgent:
19 |
20 | def __init__(self, model, console=None):
21 | """
22 | PlanAgent: the agent to plan the machine learning project. By receiving the user's requirements, the agent will
23 | first analyze the requirements and ask the user to provide more details if necessary. Then the agent will
24 | generate the project plan based on the requirements and the user's input.
25 |
26 | The project plan will be sent to the advisor agent to provide suggestions on the best machine learning task,
27 | model, dataset, and evaluation metrics to use.
28 |
29 | Args:
30 | model: the model to use.
31 | """
32 | self.console = console
33 | if not self.console:
34 | self.console = Console()
35 |
36 | self.plan_dict = None
37 | self.model = model
38 | self.chat_history = []
39 | self.sys_prompt = """
40 | You are a Machine Learning Product Manager, and you are going to collaborate with the user to plan the ML
41 | project. A generated plan includes the coding tasks for the developer to complete the project.
42 |
43 | Your capabilities include:
44 |
45 | 1. Understand the user's dataset and the user's requirements; the requirements may include the intention of the
46 | project, the model or algorithm to use, the dataset, the evaluation metrics, and the expected results. The
47 | generated task should always meet the user's requirements.
48 | 2. The generated plan should include several coding tasks, and the task should include specific instructions for
49 | the developer. For example: "Create a directory named 'dataset' under the project root, and write a Python
50 | script called 'data_loader.py' to download the dataset ImageNet from the official website."
51 | 3. The coding task should be clear and easy to understand, but with essential information to complete the task.
52 | For example, if the dataset is a user's local CSV file, you should provide the absolute path to the file in the
53 | task, otherwise, the developer may not be able to complete the task.
54 | 4. Please only provide the coding tasks, do not provide the code snippets, the developer will complete the task.
55 | 5. Do not generate tasks like "set up environment", "install dependencies", "run the code", etc. The developer
56 | should only focus on the coding tasks.
57 |
58 | """
59 | self.json_mode_prompt = """
60 |
61 | Example JSON output:
62 |
63 | {
64 | "tasks": [
65 | {
66 | "task": "download dataset",
67 | "description": "Create a directory named 'dataset' under the project root, and write a Python
68 | script called 'data_loader.py' to download the dataset ImageNet from the official website."
69 | },
70 | {
71 | "task": "process ImageNet",
72 | "description": "Write a Python script called `process_data.py` to process the dataset by
73 | resizing the images to 224x224 pixels and save the data to the 'processed_data' directory."
74 | },
75 | {
76 | "task": "train model",
77 | "description": "Write a Python script called `train_model.py` to train an image classification
78 | model on the processed data and save the trained model to the 'model' directory."
79 | }
80 | ]
81 | }
82 | """
83 | self.sys_prompt += self.json_mode_prompt
84 | self.chat_history.append({"role": 'system', "content": self.sys_prompt})
85 |
86 | @trace_component("planner")
87 | def plan(self, user_prompt):
88 | """
89 | Handle the query from the model query response.
90 | Args:
91 | user_prompt: the user prompt.
92 | """
93 | with self.console.status("MLE Planner is planning the coding tasks..."):
94 | self.chat_history.append({"role": "user", "content": user_prompt})
95 | text = self.model.query(
96 | self.chat_history,
97 | response_format={"type": "json_object"}
98 | )
99 |
100 | self.chat_history.append({"role": "assistant", "content": text})
101 |
102 | try:
103 | return json.loads(text)
104 | except json.JSONDecodeError as e:
105 | return clean_json_string(text)
106 |
107 | @trace_component("planner")
108 | def interact(self, user_prompt):
109 | """
110 | Handle the query from the model query response.
111 | Args:
112 | user_prompt: the user prompt.
113 | """
114 | self.plan_dict = self.plan(user_prompt)
115 | print_in_box(process_plan(self.plan_dict), self.console, title="MLE Planner", color="purple")
116 |
117 | while True:
118 | suggestion = questionary.text(
119 | "Suggestions to improve the plan? (ENTER to move to the next stage, \"exit\" to exit the project)"
120 | ).ask()
121 |
122 | if not suggestion or suggestion.lower() == "no":
123 | break
124 |
125 | if suggestion.lower() == "exit":
126 | sys.exit(0)
127 |
128 | self.plan_dict = self.plan(suggestion)
129 | print_in_box(process_plan(self.plan_dict), self.console, title="MLE Planner", color="purple")
130 |
131 | return self.plan_dict
--------------------------------------------------------------------------------
/mle/agents/reporter.py:
--------------------------------------------------------------------------------
1 | import json
2 | from rich.console import Console
3 | from time import gmtime, strftime
4 | from mle.utils.component_memory import trace_component
5 |
6 | class ReportAgent:
7 |
8 | def __init__(self, model, console=None):
9 | """
10 | ReportAgent: generate the report based on the information provided by the user.
11 |
12 | Args:
13 | model: the model to use.
14 | console: the console to use.
15 | """
16 | self.report = None
17 | self.knowledge = None
18 | self.model = model
19 | self.chat_history = []
20 | self.console = console
21 | if not self.console:
22 | self.console = Console()
23 | self.sys_prompt = """
24 | You are writing a weekly progress report for an engineer working on a project. Your capabilities include:
25 |
26 | 1. Based on the user's input information, you need to organize the information and generate more details from the
27 | user's perspective.
28 | 2. You need to generate a section called "Development Progress" based on the GitHub
29 | summary given by the user; do not use the commit messages directly.
30 | 3. You need to generate a section called "Communication / Design Progress" based on the user's Google Calendar
31 | events (if any). Not all events are related to the project, so you need to keep only the related ones.
32 | 4. You need to generate a section called "Development To-do" based on the user's GitHub information and the
33 | task priority, with the highest priority first, and generate more details.
34 | 5. You need to generate a section called "Communication / Design To-do" based on the user's future
35 | Google Calendar events (if any).
36 | 6. You need to generate a section called "Existing Hard Parts" to summarize/infer the hard parts of the project.
37 | 7. Based on the hard parts and the project information, you need to generate a section called
38 | "Require Manager's / Others' help", to indicate the parts that may need help.
39 | 8. You can generate as many details as possible to make sure the report is informative and shows good progress.
40 |
41 | """
42 | self.json_mode_prompt = """
43 |
44 | JSON Output Format:
45 |
46 | {
47 | "project_okr": "if user provides the ORKs, put there. Otherwise, put an empty string",
48 | "business_goal": ["The project aims to build an image classification model...", ...],
49 | "dev_progress": ["implemented the data collection Python function...", ...],
50 | "communicate_progress": ["Meeting with the design team to discuss the new feature...", ...],
51 | "dev_todo": [{"task": "fix ...", "description": ..., "priority": "high"}, {"task": "support ..."," description": ..., "priority": "medium"}, ...],
52 | "communicate_todo": [{"task": "seek helps from ...", "priority": "high"},
53 | {"task": "meet with ...", "priority": "low"} ...],
54 | "hard_parts": ["The project may face the challenge of ...", ...],
55 | "require_manager_help": ["The project needs help from the design team to ...", ...],
56 | "suggestions_to_user": ["Increase more meeting with design team...", ...],
57 | "reference": [{"title": "xxxx", "link":"https://arxiv.org/abs/xxx.xxxx"}, {"title": "xxx", "link": "https://github.com/xxx"}, ...],
58 | }
59 |
60 | """
61 | self.sys_prompt += self.json_mode_prompt
62 | self.chat_history.append({"role": 'system', "content": self.sys_prompt})
63 |
64 | def process_knowledge(self, github_summary: dict, calendar_events: list = None, okr: str = None):
65 | """
66 | Process the knowledge to generate the report.
67 |
68 | Args:
69 | github_summary: the summary of the GitHub project.
70 | calendar_events: the Google Calendar events.
71 | okr: the OKR of the project.
72 | """
73 | info_prompt = f"""
74 | # Project Overview
75 |
76 | ## The username: {github_summary.get('username')}\n
77 | ## The repository: {github_summary.get('github_repo')}\n
78 | ## Technology stack: {github_summary.get('tech_stack')}\n
79 | ## The project summary: {github_summary.get('summary')}\n
80 | """
81 | if okr:
82 | info_prompt += f"\n## The project's OKR: \n"
83 | info_prompt += f"{okr}\n"
84 |
85 | info_prompt += f"\n## The project's business goal: \n"
86 | for goal in github_summary.get("business_goal", []):
87 | info_prompt += f"- {goal}\n"
88 |
89 | if github_summary.get("dataset"):
90 | info_prompt += f"\n## The project's datasets: \n"
91 | for dataset in github_summary.get("dataset"):
92 | info_prompt += f"- {dataset['name']}: {dataset['description']}\n"
93 |
94 | info_prompt += f"\n## The project's roadmap: \n"
95 | for task in github_summary.get("roadmap", []):
96 | info_prompt += f"- {task['task']} ({task['priority']})\n"
97 |
98 | info_prompt += f"\n## The project's hard parts: \n"
99 | for part in github_summary.get("hard_parts", []):
100 | info_prompt += f"- {part}\n"
101 |
102 | info_prompt += f"\n## The project's related work: \n"
103 | for work in github_summary.get("related_work", []):
104 | info_prompt += f"- {work['title']} ({work['link']})\n"
105 |
106 | activities = github_summary.get("user_activity")
107 | info_prompt += f"""
108 | # User's Activity (from {activities['period']['start']} to {activities['period']['end']})
109 |
110 | """
111 |
112 | info_prompt += f"""
113 | ## Contributions:\n
114 | - Commits: {activities['summary']['total_commits']}
115 | - Pull Requests: {activities['summary']['total_pull_requests']}
116 | - Issues: {activities['summary']['total_issues']}\n
117 | """
118 |
119 | info_prompt += f"## The user's commits: \n"
120 | for commit in activities['commits']['messages']:
121 | info_prompt += f"- {commit}"
122 |
123 | info_prompt += f"\n## The user's pull requests: \n"
124 | for pr in activities['pull_requests']['details']:
125 | info_prompt += f"- {pr['title']} ({pr['status']})"
126 |
127 | info_prompt += f"\n## The user's issues: \n"
128 | for issue in activities['issues']['details']:
129 | info_prompt += f"- {issue['title']}"
130 |
131 | if calendar_events:
132 | info_prompt += f"\n## The user's calendar events:\n"
133 | for event in calendar_events:
134 | info_prompt += (f"- Title: {event['title']}\n"
135 | f" Time: ({event['start_time']} - {event['end_time']})\n"
136 | f" Description: {event['description']}\n"
137 | f" Organizer: {event['organizer']['email']}\n")
138 |
139 | self.knowledge = info_prompt
140 | return info_prompt
141 |
142 | @trace_component("reporter")
143 | def gen_report(self, github_summary: dict, calendar_events: list = None, okr: str = None):
144 | """
145 |         Generate the progress report from the provided GitHub summary, calendar events, and OKR.
146 | Args:
147 | github_summary: the summary of the GitHub project.
148 |             calendar_events: the Google Calendar events.
149 | okr: the OKR of the project.
150 | """
151 | with self.console.status("MLE reporter is writing the progress report..."):
152 | self.chat_history.append(
153 | {
154 | "role": "user",
155 | "content": self.process_knowledge(github_summary, calendar_events, okr)
156 | }
157 | )
158 | text = self.model.query(
159 | self.chat_history,
160 | response_format={"type": "json_object"}
161 | )
162 |
163 | self.chat_history.append({"role": "assistant", "content": text})
164 |         # save the dict into a local file
165 | today = strftime("%Y_%m_%d", gmtime())
166 | result_dict = json.loads(text)
167 | with open(f'progress_report_{today}.json', 'w') as f:
168 | json.dump(result_dict, f)
169 | return result_dict
--------------------------------------------------------------------------------
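A minimal usage sketch for the reporter above. It assumes a model loaded through `load_model` and hand-builds a `github_summary` dict in the shape `process_knowledge` reads; every field value below is hypothetical, not real project data:

from mle.model import load_model
from mle.agents.reporter import ReportAgent

# Hypothetical project directory and model name; real values come from the MLE-agent config.
model = load_model(project_dir="./", model_name="gpt-4o")
reporter = ReportAgent(model)

# Hand-crafted summary in the shape process_knowledge() expects (all values are made up).
github_summary = {
    "username": "octocat",
    "github_repo": "octocat/hello-world",
    "tech_stack": "Python, PyTorch",
    "summary": "An image classification project.",
    "business_goal": ["Ship a baseline image classifier."],
    "roadmap": [{"task": "Collect training data", "priority": "high"}],
    "hard_parts": ["Limited labeled data."],
    "related_work": [{"title": "ResNet", "link": "https://arxiv.org/abs/1512.03385"}],
    "user_activity": {
        "period": {"start": "2024-07-01", "end": "2024-07-07"},
        "summary": {"total_commits": 5, "total_pull_requests": 1, "total_issues": 0},
        "commits": {"count": 5, "messages": ["add data loader"]},
        "pull_requests": {"count": 1, "details": [{"title": "Add data loader", "status": "merged"}]},
        "issues": {"count": 0, "details": []},
    },
}

report = reporter.gen_report(github_summary)  # parsed JSON dict, also saved as progress_report_<date>.json
print(report)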
/mle/function/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import py7zr
3 | import gzip
4 | import bz2
5 | import lzma
6 | import shutil
7 | import tarfile
8 | import zipfile
9 | import textwrap
10 | import tempfile
11 | import pandas as pd
12 | from pandas.api.types import is_numeric_dtype
13 |
14 |
15 | def unzip_data(compressed_file_path, extract_path=None):
16 | """
17 | Unzip a compressed file, supporting various formats (.zip, .7z, .tar, .gz, .bz2, .xz).
18 | If no extract_path is provided, it creates a temporary directory.
19 |
20 | :param compressed_file_path: Path to the compressed file
21 | :param extract_path: Path where the contents will be extracted. If None, a temp directory is used.
22 | :return: String with the path to the unzipped contents
23 | """
24 | if not os.path.exists(compressed_file_path):
25 | raise FileNotFoundError(f"The file {compressed_file_path} does not exist.")
26 |
27 | # If no extract_path is provided, create a temporary directory
28 | if extract_path is None:
29 | extract_path = tempfile.mkdtemp()
30 | print(f"No extract path provided. Using temporary directory: {extract_path}")
31 | else:
32 | # Create the extraction directory if it doesn't exist
33 | os.makedirs(extract_path, exist_ok=True)
34 |
35 | file_extension = os.path.splitext(compressed_file_path)[1].lower()
36 | file_name = os.path.splitext(os.path.basename(compressed_file_path))[0]
37 |
38 | # Create a subdirectory with the name of the compressed file
39 | specific_extract_path = os.path.join(extract_path, file_name)
40 | os.makedirs(specific_extract_path, exist_ok=True)
41 |
42 | try:
43 | if file_extension == '.zip':
44 | with zipfile.ZipFile(compressed_file_path, 'r') as zip_ref:
45 | zip_ref.extractall(specific_extract_path)
46 |
47 | elif file_extension == '.7z':
48 | with py7zr.SevenZipFile(compressed_file_path, mode='r') as z:
49 | z.extractall(specific_extract_path)
50 |
51 | elif file_extension in ['.tar', '.gz', '.bz2', '.xz']:
52 | if file_extension == '.gz':
53 | open_func = gzip.open
54 | elif file_extension == '.bz2':
55 | open_func = bz2.open
56 | elif file_extension == '.xz':
57 | open_func = lzma.open
58 | else:
59 | open_func = open
60 |
61 | with open_func(compressed_file_path, 'rb') as f:
62 |                 if tarfile.is_tarfile(compressed_file_path):
63 | with tarfile.open(fileobj=f) as tar:
64 | tar.extractall(path=specific_extract_path)
65 | else:
66 | # For single file compression (non-tar)
67 | output_filename = os.path.splitext(os.path.basename(compressed_file_path))[0]
68 | output_path = os.path.join(specific_extract_path, output_filename)
69 | with open(output_path, 'wb') as out_f:
70 | shutil.copyfileobj(f, out_f)
71 |
72 | else:
73 | raise ValueError(f"Unsupported file format: {file_extension}")
74 |
75 | print(f"Successfully extracted {compressed_file_path} to {specific_extract_path}")
76 | return specific_extract_path
77 |
78 | except Exception as e:
79 | print(f"Error extracting {compressed_file_path}: {str(e)}")
80 | raise
81 |
82 |
83 | def preview_zip_structure(zip_path, max_files=50, max_dirs=20, max_output_length=1000, show_hidden=False):
84 | """
85 | Preview the structure of a zip file with limits on output and option to show hidden files.
86 | :param zip_path: the path to the zip file.
87 | :param max_files: maximum number of files to display.
88 | :param max_dirs: maximum number of directories to display.
89 | :param max_output_length: maximum length of the output string.
90 | :param show_hidden: if True, show hidden files and directories (starting with a dot).
91 | :return: the limited structure of the zip file as a string.
92 | """
93 | if not os.path.exists(zip_path):
94 | return f"Error: The file '{zip_path}' does not exist."
95 |
96 | if not zipfile.is_zipfile(zip_path):
97 | return f"Error: '{zip_path}' is not a valid zip file."
98 |
99 | structure = []
100 | file_count = 0
101 | dir_count = 0
102 | total_count = 0
103 | hidden_count = 0
104 |
105 | with zipfile.ZipFile(zip_path, 'r') as zip_ref:
106 | for file_info in zip_ref.infolist():
107 | file_path = file_info.filename
108 | is_hidden = os.path.basename(file_path).startswith('.')
109 |
110 | if is_hidden and not show_hidden:
111 | hidden_count += 1
112 | continue
113 |
114 | if file_info.is_dir():
115 | if dir_count < max_dirs:
116 | structure.append(f"Directory: {file_path}")
117 | dir_count += 1
118 | else:
119 | if file_count < max_files:
120 | structure.append(f"File: {file_path}")
121 | file_count += 1
122 |
123 | total_count += 1
124 | if len("\n".join(structure)) >= max_output_length:
125 | structure.append("... (output truncated due to length)")
126 | break
127 |
128 | if file_count >= max_files:
129 | structure.append(f"... (and {total_count - file_count - dir_count} more files)")
130 | if dir_count >= max_dirs:
131 | structure.append(f"... (and {total_count - file_count - dir_count} more directories)")
132 | if not show_hidden and hidden_count > 0:
133 | structure.append(f"... ({hidden_count} hidden items not shown)")
134 |
135 | output = "\n".join(structure)
136 | if len(output) > max_output_length:
137 | output = output[:max_output_length] + "... (output truncated)"
138 |
139 | return output
140 |
141 |
142 | def preview_csv_data(path: str, limit_rows: int = 5, limit_columns: int = None) -> str:
143 | """
144 | Preview the sample dataset from the project data path and include metadata.
145 | :param path: the path to a local CSV file.
146 | :param limit_rows: the number of rows to preview.
147 | :param limit_columns: the number of columns to preview. If None, all columns are previewed.
148 | :return: the sample dataset with metadata as a string.
149 | """
150 | try:
151 | df = pd.read_csv(path)
152 | num_rows, num_cols = df.shape
153 | summary = [f"CSV file in `{path}` has {num_rows} rows and {num_cols} columns."]
154 |
155 | if limit_columns is not None and limit_columns < num_cols:
156 | columns_to_preview = sorted(df.columns)[:limit_columns]
157 | summary.append(f"Previewing {limit_columns} out of {num_cols} columns.")
158 | else:
159 | columns_to_preview = sorted(df.columns)
160 |
161 | summary.append("Here is some information about the columns:")
162 |
163 | for col in columns_to_preview:
164 | dtype = df[col].dtype
165 | name = f"{col} ({dtype})"
166 | nan_count = df[col].isnull().sum()
167 | if dtype == "bool":
168 | true_percentage = df[col].mean() * 100
169 | summary.append(f"{name} is {true_percentage:.2f}% True, {100 - true_percentage:.2f}% False")
170 | elif df[col].nunique() < 10:
171 | unique_values = df[col].unique().tolist()
172 | summary.append(f"{name} has {df[col].nunique()} unique values: {unique_values}")
173 | elif is_numeric_dtype(df[col]):
174 | min_val, max_val = df[col].min(), df[col].max()
175 | summary.append(f"{name} has range: {min_val:.2f} - {max_val:.2f}, {nan_count} NaN values")
176 | elif dtype == "object":
177 | unique_count = df[col].nunique()
178 | example_values = df[col].value_counts().head(limit_rows).index.tolist()
179 | summary.append(f"{name} has {unique_count} unique values. Some example values: {example_values}")
180 |
181 | return textwrap.dedent("\n".join(summary)).strip()
182 | except Exception as e:
183 | return f"Cannot read CSV data: {e}"
184 |
--------------------------------------------------------------------------------
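A short usage sketch for the data helpers above; the archive and CSV paths are hypothetical and only illustrate the call pattern:

from mle.function.data import unzip_data, preview_zip_structure, preview_csv_data

# Hypothetical archive path; inspect the zip, extract it, then preview a CSV inside it.
print(preview_zip_structure("./data/train.zip", max_files=20, max_dirs=10))

extract_dir = unzip_data("./data/train.zip")  # extracts into a temp dir when no path is given
print(preview_csv_data(f"{extract_dir}/train.csv", limit_rows=5, limit_columns=10))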
/mle/function/execution.py:
--------------------------------------------------------------------------------
1 | """
2 | Tools to execute functions, acquire runtime logs.
3 | """
4 |
5 | import subprocess
6 | from collections import deque
7 |
8 |
9 | def execute_command(command: str, max_lines: int = 30):
10 | """
11 | Run a command in the shell and return the outputs, errors, and exit status,
12 | limiting the output to a specified number of most recent lines.
13 |
14 | Args:
15 | command (str): The input command to run.
16 |         max_lines (int): Maximum number of output lines to keep. Defaults to 30.
17 |
18 | Return: A string of the exit status and the limited output (most recent lines).
19 | """
20 | try:
21 | process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
22 | output_buffer = deque(maxlen=max_lines)
23 |
24 | while True:
25 | line = process.stdout.readline()
26 | if not line and process.poll() is not None:
27 | break
28 | output_buffer.append(line.rstrip())
29 | print(line, end='')
30 |
31 | exit_code = process.wait()
32 |
33 | limited_output = "\n".join(output_buffer)
34 | if len(output_buffer) == max_lines:
35 | return f"Exit code: {exit_code}\nOutput (last {max_lines} lines):\n{limited_output}"
36 | else:
37 | return f"Exit code: {exit_code}\nOutput:\n{limited_output}"
38 |
39 | except Exception as e:
40 | return f"Error running command: {str(e)}"
41 |
--------------------------------------------------------------------------------
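A usage sketch for `execute_command`; the command itself is arbitrary and only stands in for a long-running script:

from mle.function.execution import execute_command

# Hypothetical command; only the last 10 lines of output are kept in the returned string.
result = execute_command("python train.py --epochs 1", max_lines=10)
print(result)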
/mle/function/files.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def read_file(file_path: str, limit: int = 2000):
5 | """
6 | Reads the contents of a file and returns it as a string.
7 |
8 | Args:
9 | file_path (str): The path to the file that needs to be read.
10 | limit (int, optional): Maximum number of lines to read.
11 |
12 | Returns:
13 | str: The contents of the file as a string.
14 | """
15 | try:
16 | with open(file_path, 'r', encoding='utf-8') as file:
17 | if limit <= 0:
18 | return file.read()
19 | lines = []
20 | for i, line in enumerate(file):
21 | if i >= limit:
22 | break
23 | lines.append(line)
24 | return ''.join(lines)
25 | except FileNotFoundError:
26 | return f"File not found: {file_path}"
27 |
28 |
29 | def create_file(path, content):
30 | """
31 | Create a file with the given path and content.
32 | Args:
33 | path (str): The path to the file to create.
34 | content (str): The initial content to write to the file.
35 | """
36 | try:
37 | with open(path, 'w') as f:
38 | f.write(content)
39 | return f"File created: {path}"
40 | except Exception as e:
41 | return f"Error creating file: {str(e)}"
42 |
43 |
44 | def write_file(path, content):
45 | """
46 | Write content to a file.
47 | Args:
48 | path (str): The path to the file to write to.
49 | content (str): The content to write to the file.
50 | """
51 | try:
52 | with open(path, 'w') as f:
53 | f.write(content)
54 | return f"Content written to file: {path}"
55 | except Exception as e:
56 | return f"Error writing to file: {str(e)}"
57 |
58 |
59 | def list_files(path, limit=50):
60 | """
61 | Lists files and directories under the given path if it is a directory,
62 | up to a specified limit.
63 |
64 | Args:
65 | path (str): The file system path to check and list contents from.
66 | limit (int): Maximum number of items to list. Defaults to 50.
67 |
68 | Returns: A string containing the list of file and directory names under
69 | the given path, or a message if the path is a file or if the
70 | number of items exceeds the limit.
71 | """
72 | if os.path.isfile(path):
73 | return "The given path is a file. Please provide a path of a directory."
74 |
75 | try:
76 | files = os.listdir(path)
77 | except PermissionError:
78 | return "Permission denied to access this directory."
79 | except FileNotFoundError:
80 | return "The specified directory does not exist."
81 | except Exception as e:
82 | return f"An error occurred: {str(e)}"
83 |
84 | total_files = len(files)
85 |
86 | if total_files > limit:
87 | files = files[:limit]
88 | output = "\n".join(files)
89 | output += f"\n\n... and {total_files - limit} more items (total of {total_files} items)"
90 | else:
91 | output = "\n".join(files)
92 | output += f"\n\nTotal items: {total_files}"
93 |
94 | return output
95 |
96 |
97 | def create_directory(path: str):
98 | """
99 | Create a directory if it does not exist.
100 | Args:
101 | path (str): The path to the directory to create.
102 | """
103 | try:
104 | os.makedirs(path, exist_ok=True)
105 | return f"Directory '{path}' created successfully."
106 | except OSError as error:
107 | return f"Creation of the directory '{path}' failed due to: {error}"
108 |
--------------------------------------------------------------------------------
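A usage sketch chaining the file helpers above; the workspace layout is hypothetical:

from mle.function.files import create_directory, create_file, read_file, write_file, list_files

# Hypothetical workspace; create it, write a file, read it back, and list the directory.
create_directory("./workspace")
create_file("./workspace/train.py", "print('hello')\n")
write_file("./workspace/train.py", "print('hello, world')\n")
print(read_file("./workspace/train.py", limit=100))
print(list_files("./workspace", limit=10))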
/mle/function/interaction.py:
--------------------------------------------------------------------------------
1 | import questionary
2 |
3 |
4 | def ask_question(question: str):
5 | """
6 | Ask a question to the user.
7 | """
8 | answer = input("[ADVISOR]: " + question)
9 | return f"Question: {question}\nAnswer: {answer}"
10 |
11 |
12 | def ask_yes_no(question: str):
13 | """
14 | Ask a yes/no question to the user.
15 | """
16 | return questionary.confirm("[ADVISOR]: " + question).ask()
17 |
18 |
19 | def ask_choices(question: str, choices: list):
20 | """
21 | Ask a multiple choice question to the user.
22 | """
23 | return "More details: " + questionary.select("[ADVISOR]: " + question, choices=choices).ask()
24 |
--------------------------------------------------------------------------------
/mle/function/search.py:
--------------------------------------------------------------------------------
1 | """
2 | Search API functions based on Tavily.
3 | """
4 | import os
5 | import requests
6 | from tavily import TavilyClient
7 | from xml.etree import ElementTree
8 |
9 |
10 | def search_github_repos(query, limit=5):
11 | """
12 | Search GitHub public repositories based on a keyword.
13 |
14 | :param query: The query to search for in repository names or descriptions.
15 | :param limit: The total number of repositories to return.
16 | :return: A list of dictionaries containing repository details, limited to the specified number.
17 | """
18 | repos = []
19 | per_page = 10
20 | page = 1
21 | while len(repos) < limit:
22 | url = f'https://api.github.com/search/repositories?q={query}&per_page={per_page}&page={page}'
23 |
24 | response = requests.get(url)
25 |
26 | if response.status_code == 200:
27 | items = response.json().get('items', [])
28 | for item in items:
29 | formatted_repo = {
30 | "name": f"{item['owner']['login']}/{item['name']}",
31 | "author": item['owner']['login'],
32 | "description": item['description'],
33 | "link": item['html_url']
34 | }
35 | repos.append(formatted_repo)
36 | if len(repos) >= limit:
37 | break
38 |
39 | if len(items) < per_page: # Stop if there are no more repos to fetch
40 | break
41 | page += 1
42 | else:
43 | raise Exception(f"GitHub API request failed with status code {response.status_code}: {response.text}")
44 |
45 | return_str = """
46 | Here are some of the repositories I found on GitHub:
47 | """
48 |
49 | for repo in repos:
50 | return_str += f"""
51 | Name: {repo['name']}
52 | Description: {repo['description']}
53 | Link: {repo['link']}
54 | """
55 |
56 | return return_str
57 |
58 |
59 | def web_search(query: str):
60 | """
61 | Perform a web search based on the query.
62 | Args:
63 | query: The search query.
64 | """
65 | try:
66 | client = TavilyClient(api_key=os.environ['SEARCH_API_KEY'])
67 | response = client.qna_search(query=query, search_depth="advanced")
68 | return response
69 | except Exception as e:
70 | return f"Error performing web search: {str(e)}"
71 |
72 |
73 | def search_arxiv(query, max_results=8):
74 | url = 'https://export.arxiv.org/api/query'
75 | params = {
76 | 'search_query': query,
77 | 'start': 0,
78 | 'max_results': max_results
79 | }
80 | response = requests.get(url, params=params)
81 | if response.status_code != 200:
82 | return f"Error: Unable to fetch data from arXiv (Status code: {response.status_code})"
83 |
84 | root = ElementTree.fromstring(response.content)
85 | output = ""
86 | for entry in root.findall('{http://www.w3.org/2005/Atom}entry'):
87 | title = entry.find('{http://www.w3.org/2005/Atom}title').text
88 | summary = entry.find('{http://www.w3.org/2005/Atom}summary').text
89 | link = entry.find('{http://www.w3.org/2005/Atom}id').text
90 | published = entry.find('{http://www.w3.org/2005/Atom}published').text
91 | authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in
92 | entry.findall('{http://www.w3.org/2005/Atom}author')]
93 |
94 | output += f"""
95 | Title: {title.strip()}
96 | Summary: {summary.strip()}
97 | Link: {link.strip()}
98 | Published: {published.strip()}
99 | Authors: {authors}
100 | """
101 |
102 | return output
103 |
104 |
105 | def search_papers_with_code(query: str, k: int = 8) -> str:
106 | url = f"https://paperswithcode.com/api/v1/search/"
107 | response = requests.get(url, params={'page': 1, 'q': query})
108 | if response.status_code != 200:
109 | return "Failed to retrieve data from Papers With Code."
110 |
111 | data = response.json()
112 | if 'results' not in data:
113 | return "No results found for the given query."
114 |
115 | results = data['results'][:int(k)] # Get top-k results
116 | result_strings = []
117 |
118 | for result in results:
119 | paper = result['paper']
120 | paper_title = paper.get('title', 'No title available')
121 | abstract = paper.get('abstract', 'No abstract available')
122 | paper_pdf_url = paper.get('url_pdf', 'No PDF available')
123 | repository = result.get('repository', [])
124 | if repository:
125 | code_url = repository.get('url', 'No official code link available')
126 | else:
127 | code_url = 'No official code link available'
128 |
129 | result_string = f"Title: {paper_title}\nAbstract:{abstract}\nPaper URL: {paper_pdf_url}\nCode URL: {code_url}\n"
130 | result_strings.append(result_string)
131 |
132 | return "\n".join(result_strings)
133 |
--------------------------------------------------------------------------------
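A usage sketch for the search functions above. `web_search` needs a Tavily API key exposed as `SEARCH_API_KEY`; the queries and the key placeholder are illustrative only:

import os
from mle.function.search import search_github_repos, search_arxiv, web_search

print(search_github_repos("image classification", limit=3))
print(search_arxiv("vision transformer", max_results=2))

# Placeholder key; web_search reads the Tavily key from the SEARCH_API_KEY environment variable.
os.environ.setdefault("SEARCH_API_KEY", "<your-tavily-api-key>")
print(web_search("state of the art image classification 2024"))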
/mle/integration/__init__.py:
--------------------------------------------------------------------------------
1 | from .local_git import GitIntegration
2 | from .github import GitHubIntegration, github_login
3 | from .google_calendar import GoogleCalendarIntegration, google_calendar_login
4 | from .kaggle import KaggleIntegration
5 |
--------------------------------------------------------------------------------
/mle/integration/google_calendar.py:
--------------------------------------------------------------------------------
1 | import json
2 | import datetime
3 | from mle.utils import load_file
4 | from google.auth.transport.requests import Request
5 | from google_auth_oauthlib.flow import InstalledAppFlow
6 | from googleapiclient.discovery import build
7 |
8 |
9 | def google_calendar_login(credential=None):
10 | """
11 | Authenticate the user using Google OAuth 2.0 and return the credentials.
12 |
13 | :param credential: The client secrets.
14 | :return: Google OAuth 2.0 credentials or None if authentication fails.
15 | """
16 |
17 | if credential is None:
18 | # FIXME: remove the test app_credential
19 | credential = json.loads(load_file(
20 | "https://raw.githubusercontent.com/leeeizhang/leeeizhang/assets/google_app",
21 | base64_decode=True,
22 | ))
23 |
24 | try:
25 | SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
26 | flow = InstalledAppFlow.from_client_config(
27 | credential,
28 | SCOPES,
29 | )
30 | creds = flow.run_local_server(host="127.0.0.1", port=0)
31 | except Exception:
32 | return None
33 | return creds
34 |
35 |
36 | class GoogleCalendarIntegration:
37 | """
38 | Class to interface with Google Calendar API to fetch events.
39 |
40 | :param token: Google OAuth 2.0 credentials.
41 | """
42 |
43 | def __init__(self, token=None):
44 | self.token = token
45 | if self.token.expired and self.token.refresh_token:
46 | self.token.refresh(Request())
47 |
48 | def get_events(self, start_date=None, end_date=None, limit=100, detailed=True):
49 | """
50 | Fetch upcoming calendar events.
51 | :param start_date: Start date for calendar events (inclusive), in 'YYYY-MM-DD' format
52 | :param end_date: End date for calendar events (inclusive), in 'YYYY-MM-DD' format
53 | :param limit: The maximum number of events to return.
54 | :param detailed: Whether to return detailed event information.
55 | :return: A list of events with details or None if an error occurs.
56 | """
57 | try:
58 | # Set default dates if not provided
59 | today = datetime.date.today()
60 | if start_date is None:
61 | start_date = (today - datetime.timedelta(days=7)).isoformat()
62 | if end_date is None:
63 | end_date = (today + datetime.timedelta(days=7)).isoformat()
64 |
65 | # Convert dates to datetime objects with time
66 | start_dt = datetime.datetime.strptime(f"{start_date}T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z")
67 | end_dt = datetime.datetime.strptime(f"{end_date}T23:59:59Z", "%Y-%m-%dT%H:%M:%S%z")
68 |
69 | if start_dt >= end_dt:
70 | raise ValueError("start_date must be less than end_date")
71 |
72 | # Convert back to string format for API call
73 | start_date = start_dt.isoformat()
74 | end_date = end_dt.isoformat()
75 |
76 | # Build the service object for interacting with the Google Calendar API
77 | service = build("calendar", "v3", credentials=self.token)
78 |
79 | # Retrieve the events from the primary calendar
80 | events_result = (
81 | service.events()
82 | .list(
83 | calendarId="primary",
84 | timeMin=start_date,
85 | timeMax=end_date,
86 | maxResults=limit,
87 | singleEvents=True,
88 | orderBy="startTime",
89 | )
90 | .execute()
91 | )
92 | events_result = events_result.get("items", [])
93 |
94 | # Format the events into a specified structure
95 | events = []
96 | for event in events_result:
97 | e = {
98 | "title": event.get("summary"),
99 | "status": event.get("status"),
100 | "description": event.get("description"),
101 | "creator": event.get("creator"),
102 | "organizer": event.get("organizer"),
103 | "start_time": event["start"].get("dateTime", event["start"].get("date")),
104 | "end_time": event["end"].get("dateTime", event["end"].get("date"))
105 | }
106 |
107 | if detailed:
108 | e.update(
109 | {
110 | "htmlLink": event.get("htmlLink"),
111 | "kind": event.get("kind"),
112 | }
113 | )
114 |
115 | events.append(e)
116 | return events
117 |
118 | except Exception as e:
119 | print(f"An error occurred: {e}")
120 | return None
121 |
--------------------------------------------------------------------------------
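A usage sketch for the calendar integration above. `google_calendar_login` runs a local OAuth flow in the browser; the date range is an arbitrary example:

from mle.integration import GoogleCalendarIntegration, google_calendar_login

token = google_calendar_login()  # returns None if the OAuth flow fails
if token:
    calendar = GoogleCalendarIntegration(token)
    # Hypothetical date range; get_events returns a list of event dicts or None on error.
    events = calendar.get_events(start_date="2024-07-01", end_date="2024-07-07", limit=20)
    for event in events or []:
        print(event["title"], event["start_time"], event["end_time"])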
/mle/integration/kaggle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import requests
4 | import questionary
5 | from zipfile import ZipFile
6 |
7 |
8 | class KaggleIntegration:
9 |
10 | def __init__(self):
11 | """
12 | Initializes KaggleIntegration with the provided credentials.
13 | """
14 |
15 | kaggle_file = os.path.join(os.path.expanduser("~"), ".kaggle", "kaggle.json")
16 |
17 | if not os.path.exists(kaggle_file):
18 | username = questionary.text("What is your Kaggle username?").ask()
19 | key = questionary.password("What is your Kaggle token?").ask()
20 | if username and key:
21 | os.makedirs(os.path.dirname(kaggle_file), exist_ok=True)
22 | with open(kaggle_file, "w") as f:
23 | json.dump({"username": username, "key": key}, f)
24 |
25 | from kaggle.api.kaggle_api_extended import KaggleApi
26 | self.api = KaggleApi()
27 | self.api.authenticate()
28 |
29 | def list_competition(self):
30 | """
31 | Lists all Kaggle competitions.
32 | :return: A tuple containing references of all competitions.
33 | """
34 | competitions = self.api.competitions_list()
35 | return tuple([comp.ref for comp in competitions])
36 |
37 | def download_competition_dataset(
38 | self, competition: str, download_dir: str = "./data"
39 | ):
40 | """
41 | Downloads and extracts the dataset for a specific competition.
42 | :param competition: The URL or name of the Kaggle competition.
43 | :param download_dir: Directory to save the downloaded files. Defaults to './data'.
44 | :return: The directory where the dataset has been downloaded and extracted.
45 | """
46 | if competition.startswith("https://www.kaggle.com/competitions/"):
47 | competition = competition.split("/")[-1]
48 |
49 | os.makedirs(download_dir, exist_ok=True)
50 | self.api.competition_download_files(competition, path=download_dir)
51 |
52 | # Unzip downloaded files
53 | for file in os.listdir(download_dir):
54 | if file.endswith(".zip"):
55 | with ZipFile(os.path.join(download_dir, file), "r") as zip_ref:
56 | zip_ref.extractall(download_dir)
57 | return download_dir
58 |
59 | def fetch_competition_overview(self, competition: str):
60 | """
61 | Fetches competition overview information using the Kaggle API.
62 | :param competition: The URL or name of the Kaggle competition.
63 |         :return: A markdown string with the competition overview, or None if it could not be fetched.
64 | """
65 | for _ in range(3): # Retry 3 times if the request fails
66 | try:
67 | reader_url = f"https://r.jina.ai/{competition}/overview/description"
68 | response = requests.get(
69 | reader_url,
70 | timeout=30,
71 | headers={"X-Return-Format": "markdown"},
72 | )
73 | response.raise_for_status()
74 | overview = response.text
75 | break
76 | except requests.exceptions.HTTPError:
77 |                 overview = None
78 |         return overview.encode('utf-8', 'ignore').decode('utf-8') if overview else None
79 |
--------------------------------------------------------------------------------
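A usage sketch for the Kaggle integration above; the `titanic` competition is used purely as an example:

from mle.integration import KaggleIntegration

kaggle = KaggleIntegration()  # prompts for credentials if ~/.kaggle/kaggle.json is missing
print(kaggle.list_competition()[:5])

# Example competition; the dataset is downloaded and unzipped under ./data.
data_dir = kaggle.download_competition_dataset("titanic", download_dir="./data")
overview = kaggle.fetch_competition_overview("https://www.kaggle.com/competitions/titanic")
print(overview)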
/mle/integration/local_git.py:
--------------------------------------------------------------------------------
1 | from git import Repo, NULL_TREE
2 | from datetime import datetime, timezone, timedelta
3 |
4 | import os
5 | import fnmatch
6 | import subprocess
7 |
8 | class GitIntegration:
9 | def __init__(self, path):
10 | self.repo_path = path
11 | self.repo = Repo(self.repo_path)
12 | if self.repo.bare:
13 | raise Exception("Repository is not valid or is bare.")
14 |
15 | def get_repo_status(self):
16 | """
17 | Get the status of a git repository
18 | :return: List of changed files
19 | """
20 | try:
21 | changed_files = []
22 | for diff in self.repo.index.diff(None):
23 | changed_files.append({
24 | 'file_path': diff.a_path,
25 | 'change_type': diff.change_type,
26 | 'author': self.repo.head.commit.author.name,
27 | 'date': datetime.fromtimestamp(self.repo.head.commit.committed_date).strftime("%Y-%m-%d %H:%M:%S")
28 | })
29 |
30 | return changed_files
31 |
32 | except Exception as e:
33 | return f"An error occurred: {str(e)}"
34 |
35 | def get_commit_history(self, start_date=None, end_date=None, email=None, limit=None):
36 | """
37 | Process commit history within a specified date range and for a specific user (email).
38 | :param start_date: Start date for commit range (inclusive), in 'YYYY-MM-DD' format
39 | :param end_date: End date for commit range (inclusive), in 'YYYY-MM-DD' format
40 |         :param email: Author email to filter commits (optional)
41 | :param limit: Maximum number of commits to retrieve (default is None, which retrieves all commits in range)
42 | :return: Dictionary of commits
43 | """
44 | end_time = None
45 | if end_date is not None:
46 | end_time = f"{end_date}T23:59:59Z"
47 |
48 | start_time = None
49 | if start_date is not None:
50 | start_time = f"{start_date}T00:00:00Z"
51 |
52 | try:
53 | commit_history = []
54 | for commit in self.repo.iter_commits(max_count=limit):
55 | commit_date = datetime.fromtimestamp(commit.committed_date)
56 | commit_date = commit_date.replace(tzinfo=timezone.utc)
57 | if start_time is not None and commit_date < datetime.fromisoformat(start_time):
58 | continue
59 |
60 | if end_time is not None and commit_date > datetime.fromisoformat(end_time):
61 | continue
62 |
63 | if email is not None and commit.author.email != email:
64 | continue
65 |
66 | commit_history.append({
67 | 'commit_hash': commit.hexsha,
68 | 'author': commit.author.name,
69 | 'email': commit.author.email,
70 | 'message': commit.message.strip(),
71 | 'date': commit_date.strftime("%Y-%m-%d %H:%M:%S")
72 | })
73 |
74 | return commit_history
75 |
76 | except Exception as e:
77 | return f"An error occurred: {str(e)}"
78 |
79 | def get_commit_diff(self, commit_hash, show_content=False):
80 | """
81 | Get the code changes for a specific commit
82 | :param commit_hash: The hash of the commit to get the changes for
83 | :param show_content: Boolean to determine if file content should be included (default False)
84 | :return: A dictionary containing the commit info and the code changes
85 | """
86 | try:
87 | commit = self.repo.commit(commit_hash)
88 | parent = commit.parents[0] if commit.parents else NULL_TREE
89 |
90 | diffs = parent.diff(commit)
91 |
92 | changes = {}
93 | for diff in diffs:
94 | if show_content:
95 | if diff.a_blob and diff.b_blob:
96 | a_content = diff.a_blob.data_stream.read().decode('utf-8', errors='ignore')
97 | b_content = diff.b_blob.data_stream.read().decode('utf-8', errors='ignore')
98 | changes[diff.a_path] = {
99 | 'change_type': 'modified',
100 | 'old_content': a_content,
101 | 'new_content': b_content
102 | }
103 | elif diff.a_blob:
104 | changes[diff.a_path] = {
105 | 'change_type': 'deleted',
106 | 'old_content': diff.a_blob.data_stream.read().decode('utf-8', errors='ignore'),
107 | 'new_content': None
108 | }
109 | elif diff.b_blob:
110 | changes[diff.b_path] = {
111 | 'change_type': 'added',
112 | 'old_content': None,
113 | 'new_content': diff.b_blob.data_stream.read().decode('utf-8', errors='ignore')
114 | }
115 | else:
116 | if diff.a_blob and diff.b_blob:
117 | changes[diff.a_path] = {'change_type': 'modified'}
118 | elif diff.a_blob:
119 | changes[diff.a_path] = {'change_type': 'deleted'}
120 | elif diff.b_blob:
121 | changes[diff.b_path] = {'change_type': 'added'}
122 |
123 | commit_info = {
124 | 'commit_hash': commit.hexsha,
125 | 'author': commit.author.name,
126 | 'email': commit.author.email,
127 | 'message': commit.message.strip(),
128 | 'date': datetime.fromtimestamp(commit.committed_date).strftime("%Y-%m-%d %H:%M:%S"),
129 | 'changes': changes
130 | }
131 |
132 | return commit_info
133 |
134 | except Exception as e:
135 | return f"An error occurred: {str(e)}"
136 |
137 | def get_source_code(self, file_pattern="*"):
138 | """
139 | Process source code files in the repository.
140 | :param file_pattern: Wildcard pattern to filter files (e.g., "*.py" for Python files)
141 | :return: Dictionary with file paths as keys and file contents as values
142 | """
143 |
144 | def get_contents(path="", file_pattern=file_pattern):
145 | for root, _, files in os.walk(os.path.join(self.repo_path, path)):
146 | for filename in fnmatch.filter(files, file_pattern):
147 | file_path = os.path.join(root, filename)
148 | with open(file_path, 'r') as f:
149 | yield {
150 | 'path': os.path.relpath(file_path, self.repo_path),
151 | 'name': filename,
152 | 'content': f.read()
153 | }
154 |
155 | return {file['path']: file['content'] for file in get_contents()}
156 |
157 | def get_readme(self):
158 | """
159 | Get readme content of the repository.
160 | :return: The readme content
161 | """
162 | content = self.get_source_code("README.md")
163 | if len(content):
164 | return list(content.values())[0]
165 | return None
166 |
167 | def get_structure(self, path=''):
168 | """
169 | Scan and return the file structure and file names of the Git repository as a list of paths.
170 |         :param path: The repository path to scan (defaults to this repository's
171 |             root when empty)
172 |         :return: A list of file paths tracked in the repository, as reported
173 |             by `git ls-files`
174 | """
175 | result = subprocess.run(
176 |             ["git", "-C", path or self.repo_path, "ls-files"],
177 | stdout=subprocess.PIPE,
178 | text=True,
179 | check=True
180 | )
181 |
182 | return result.stdout.splitlines()
183 |
184 | def get_user_activity(self, email, start_date=None, end_date=None):
185 | """
186 | Aggregate information about a user's activity within a specific time period.
187 | :param email: User email to analyze
188 | :param start_date: Start date for the analysis period, in 'YYYY-MM-DD' format
189 | :param end_date: End date for the analysis period, in 'YYYY-MM-DD' format
190 | :return: Dictionary containing aggregated user activity information, if the
191 | start and end dates are not provided, the default period is the last 7 days.
192 | """
193 | if end_date is None:
194 | end_datetime = datetime.now(timezone.utc).replace(hour=23, minute=59, second=59, microsecond=0)
195 | end_date = end_datetime.strftime("%Y-%m-%d")
196 | else:
197 | end_datetime = (datetime.strptime(end_date, "%Y-%m-%d")
198 | .replace(hour=23, minute=59, second=59, tzinfo=timezone.utc))
199 |
200 | if start_date is None:
201 | start_datetime = end_datetime - timedelta(days=6)
202 | start_date = start_datetime.strftime("%Y-%m-%d")
203 |
204 | # Fetch data
205 | commits = self.get_commit_history(start_date, end_date, email)
206 |
207 | # Aggregate commit information
208 | commit_count = len(commits)
209 | commit_messages = [commit['message'] for commit in commits]
210 |
211 | # Compile the report
212 | report = {
213 | 'username': email,
214 | 'period': {
215 | 'start': start_date,
216 | 'end': end_date
217 | },
218 | 'summary': {
219 | 'total_commits': commit_count,
220 | 'total_pull_requests': 0,
221 | 'total_issues': 0
222 | },
223 | 'commits': {
224 | 'count': commit_count,
225 | 'messages': commit_messages
226 | },
227 | 'pull_requests': {
228 | 'count': 0,
229 | 'details': []
230 | },
231 | 'issues': {
232 | 'count': 0,
233 | 'details': []
234 | }
235 | }
236 |
237 | return report
238 |
--------------------------------------------------------------------------------
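A usage sketch for the local git integration above; it assumes the current directory is a git repository, and the email and dates are placeholders:

from mle.integration import GitIntegration

repo = GitIntegration(".")  # assumes "." is a valid, non-bare git repository
print(repo.get_repo_status())
print(repo.get_commit_history(start_date="2024-07-01", end_date="2024-07-07", limit=20))
print(repo.get_user_activity("dev@example.com"))  # placeholder author email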
/mle/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .anthropic import *
2 | from .deepseek import *
3 | from .mistral import *
4 | from .ollama import *
5 | from .openai import *
6 | from .gemini import *
7 | from .vllm import *
8 |
9 | from mle.utils import get_config
10 |
11 |
12 | MODEL_OLLAMA = 'Ollama'
13 | MODEL_OPENAI = 'OpenAI'
14 | MODEL_CLAUDE = 'Claude'
15 | MODEL_MISTRAL = 'MistralAI'
16 | MODEL_DEEPSEEK = 'DeepSeek'
17 | MODEL_GEMINI = 'Gemini'
18 | MODEL_VLLM = 'vLLM'
19 |
20 |
21 | class ObservableModel:
22 | """
23 | A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
24 | """
25 |
26 | try:
27 | from mle.utils import get_langfuse_observer
28 | _observe = get_langfuse_observer()
29 | except Exception as e:
30 | # If importing fails, set _observe to a lambda function that does nothing.
31 | _observe = lambda fn: fn
32 |
33 | def __init__(self, model: Model):
34 | """
35 | Initialize the ObservableModel.
36 | Args:
37 | model: The model to be wrapped and made observable.
38 | """
39 | self.model = model
40 |
41 | @_observe
42 | def query(self, *args, **kwargs):
43 | return self.model.query(*args, **kwargs)
44 |
45 | @_observe
46 | def stream(self, *args, **kwargs):
47 |         return self.model.stream(*args, **kwargs)
48 |
49 |
50 | def load_model(project_dir: str, model_name: str=None, observable=True):
51 | """
52 | load_model: load the model based on the configuration.
53 | Args:
54 | project_dir (str): The project directory.
55 | model_name (str): The model name.
56 | observable (boolean): Whether the model should be tracked.
57 | """
58 | config = get_config(project_dir)
59 | model = None
60 |
61 | if config['platform'] == MODEL_OLLAMA:
62 | model = OllamaModel(model=model_name)
63 | if config['platform'] == MODEL_OPENAI:
64 | model = OpenAIModel(api_key=config['api_key'], model=model_name)
65 | if config['platform'] == MODEL_CLAUDE:
66 | model = ClaudeModel(api_key=config['api_key'], model=model_name)
67 | if config['platform'] == MODEL_MISTRAL:
68 | model = MistralModel(api_key=config['api_key'], model=model_name)
69 | if config['platform'] == MODEL_DEEPSEEK:
70 | model = DeepSeekModel(api_key=config['api_key'], model=model_name)
71 | if config['platform'] == MODEL_GEMINI:
72 | model = GeminiModel(api_key=config['api_key'], model=model_name)
73 | if config['platform'] == MODEL_VLLM:
74 | model = vLLMModel(base_url=config.get('base_url', 'http://localhost:8000/v1'), model=model_name)
75 |
76 | if observable:
77 | return ObservableModel(model)
78 | return model
79 |
--------------------------------------------------------------------------------
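A usage sketch for `load_model` above; it assumes the project directory already holds an MLE-agent config that selects one of the supported platforms (the model name here is only an example):

from mle.model import load_model

# Assumes get_config(project_dir) finds a config choosing e.g. the OpenAI platform.
model = load_model(project_dir="./", model_name="gpt-4o")

answer = model.query([
    {"role": "system", "content": "You are a helpful ML engineer."},
    {"role": "user", "content": "Suggest a baseline model for tabular classification."},
])
print(answer)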
/mle/model/anthropic.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 |
3 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
4 | from mle.model.common import Model
5 |
6 |
7 | class ClaudeModel(Model):
8 | def __init__(self, api_key, model, temperature=0.7):
9 | """
10 | Initialize the Claude model.
11 | Args:
12 | api_key (str): The Anthropic API key.
13 | model (str): The model with version.
14 | temperature (float): The temperature value.
15 | """
16 | super().__init__()
17 |
18 | dependency = "anthropic"
19 | spec = importlib.util.find_spec(dependency)
20 | if spec is not None:
21 | self.anthropic = importlib.import_module(dependency).Anthropic
22 | else:
23 | raise ImportError(
24 |                 "It seems you didn't install anthropic. In order to enable the Anthropic client related features, "
25 |                 "please make sure the anthropic Python package has been installed. "
26 |                 "For more information, please refer to: https://docs.anthropic.com/en/api/client-sdks"
27 | )
28 |
29 | self.model = model if model else 'claude-3-5-sonnet-20240620'
30 | self.model_type = 'Claude'
31 | self.temperature = temperature
32 | self.client = self.anthropic(api_key=api_key)
33 | self.func_call_history = []
34 |
35 | @staticmethod
36 | def _add_tool_result_into_chat_history(chat_history, func, result):
37 | """
38 | Add the result of tool calls into messages.
39 | """
40 | return chat_history.extend([
41 | {
42 | "role": "assistant",
43 | "content": [
44 | {
45 | "type": "tool_use",
46 | "id": func.id,
47 | "name": func.name,
48 | "input": func.input,
49 | },
50 | ]
51 | },
52 | {
53 | "role": "user",
54 | "content": [
55 | {
56 | "type": "tool_result",
57 | "tool_use_id": func.id,
58 | "content": result,
59 | },
60 | ]
61 | },
62 | ])
63 |
64 | def query(self, chat_history, **kwargs):
65 | """
66 | Query the LLM model.
67 |
68 | Args:
69 | chat_history: The context (chat history).
70 | """
71 |         # claude has no system role in chat_history
72 | # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
73 | system_prompt = ""
74 | for idx, msg in enumerate(chat_history):
75 | if msg["role"] == "system":
76 | system_prompt += msg["content"]
77 |
78 |         # claude does not support manual `response_format`, so we append it into the system prompt
79 | if "response_format" in kwargs.keys():
80 | system_prompt += (
81 | f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
82 | )
83 |
84 | # mapping the openai function_schema to claude tool_schema
85 | tools = kwargs.get("functions",[])
86 | for tool in tools:
87 | if "parameters" in tool.keys():
88 | tool["input_schema"] = tool["parameters"]
89 | del tool["parameters"]
90 |
91 | completion = self.client.messages.create(
92 | max_tokens=4096,
93 | model=self.model,
94 | system=system_prompt,
95 | messages=[msg for msg in chat_history if msg["role"] != "system"],
96 | temperature=self.temperature,
97 | stream=False,
98 | tools=tools,
99 | )
100 | if completion.stop_reason == "tool_use":
101 | for func in completion.content:
102 | if func.type != "tool_use":
103 | continue
104 | function_name = process_function_name(func.name)
105 | arguments = func.input
106 | print("[MLE FUNC CALL]: ", function_name)
107 | self.func_call_history.append({"name": function_name, "arguments": arguments})
108 | # avoid the multiple search function calls
109 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
110 | if len(search_attempts) > 3:
111 | kwargs['functions'] = []
112 | result = get_function(function_name)(**arguments)
113 | self._add_tool_result_into_chat_history(chat_history, func, result)
114 | return self.query(chat_history, **kwargs)
115 | else:
116 | return completion.content[0].text
117 |
118 | def stream(self, chat_history, **kwargs):
119 | """
120 | Stream the output from the LLM model.
121 | Args:
122 | chat_history: The context (chat history).
123 | """
124 |         # claude has no system role in chat_history
125 | # https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts
126 | system_prompt = ""
127 | for idx, msg in enumerate(chat_history):
128 | if msg["role"] == "system":
129 | system_prompt += msg["content"]
130 | chat_history = [msg for msg in chat_history if msg["role"] != "system"]
131 |
132 |         # claude does not support manual `response_format`, so we append it into the system prompt
133 | if "response_format" in kwargs.keys():
134 | system_prompt += (
135 | f"\nOutputs only valid {kwargs['response_format']['type']} without any explanatory words"
136 | )
137 |
138 | with self.client.messages.stream(
139 | max_tokens=4096,
140 |             model=self.model, system=system_prompt,
141 | messages=chat_history,
142 | ) as stream:
143 | for chunk in stream.text_stream:
144 | yield chunk
145 |
--------------------------------------------------------------------------------
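A usage sketch for the Claude wrapper above; the API key and model name are placeholders, and the chat history follows the same OpenAI-style message format used throughout the package:

from mle.model.anthropic import ClaudeModel

model = ClaudeModel(api_key="<your-anthropic-api-key>", model="claude-3-5-sonnet-20240620")
print(model.query([
    {"role": "system", "content": "You are a helpful ML engineer."},
    {"role": "user", "content": "What metric should I use for an imbalanced binary classifier?"},
]))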
/mle/model/common.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class Model(ABC):
5 |
6 | def __init__(self):
7 | """
8 | Initialize the model.
9 | """
10 | self.model_type = None
11 |
12 | @abstractmethod
13 | def query(self, chat_history, **kwargs):
14 | pass
15 |
16 | @abstractmethod
17 | def stream(self, chat_history, **kwargs):
18 | pass
19 |
--------------------------------------------------------------------------------
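A minimal sketch of what a concrete subclass of `Model` has to provide; this echo model is purely illustrative and is not part of the package:

from mle.model.common import Model


class EchoModel(Model):
    """Toy Model implementation that echoes the last message (illustration only)."""

    def __init__(self):
        super().__init__()
        self.model_type = 'Echo'

    def query(self, chat_history, **kwargs):
        return chat_history[-1]["content"]

    def stream(self, chat_history, **kwargs):
        yield chat_history[-1]["content"]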
/mle/model/deepseek.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 |
4 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
5 | from mle.model.common import Model
6 |
7 |
8 | class DeepSeekModel(Model):
9 | def __init__(self, api_key, model, temperature=0.7):
10 | """
11 | Initialize the DeepSeek model.
12 | Args:
13 | api_key (str): The DeepSeek API key.
14 | model (str): The model with version.
15 | temperature (float): The temperature value.
16 | """
17 | super().__init__()
18 |
19 | dependency = "openai"
20 | spec = importlib.util.find_spec(dependency)
21 | if spec is not None:
22 | self.openai = importlib.import_module(dependency).OpenAI
23 | else:
24 | raise ImportError(
25 | "It seems you didn't install openai. In order to enable the OpenAI client related features, "
26 | "please make sure openai Python package has been installed. "
27 | "More information, please refer to: https://openai.com/product"
28 | )
29 | self.model = model if model else "deepseek-coder"
30 | self.model_type = 'DeepSeek'
31 | self.temperature = temperature
32 | self.client = self.openai(
33 | api_key=api_key, base_url="https://api.deepseek.com/beta"
34 | )
35 | self.func_call_history = []
36 |
37 | def _convert_functions_to_tools(self, functions):
38 | """
39 | Convert OpenAI-style functions to DeepSeek-style tools.
40 | """
41 | tools = []
42 | for func in functions:
43 | tool = {
44 | "type": "function",
45 | "function": {
46 | "name": func["name"],
47 | "description": func.get("description", ""),
48 | "parameters": func["parameters"],
49 | },
50 | }
51 | tools.append(tool)
52 | return tools
53 |
54 | def query(self, chat_history, **kwargs):
55 | """
56 | Query the LLM model.
57 |
58 | Args:
59 | chat_history: The context (chat history).
60 | """
61 | functions = kwargs.get("functions", None)
62 | tools = self._convert_functions_to_tools(functions) if functions else None
63 | parameters = kwargs
64 | completion = self.client.chat.completions.create(
65 | model=self.model,
66 | messages=chat_history,
67 | temperature=self.temperature,
68 | stream=False,
69 | tools=tools,
70 | **parameters,
71 | )
72 |
73 | resp = completion.choices[0].message
74 | if resp.tool_calls:
75 | for tool_call in resp.tool_calls:
76 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
77 | function_name = process_function_name(tool_call.function.name)
78 | arguments = json.loads(tool_call.function.arguments)
79 | print("[MLE FUNC CALL]: ", function_name)
80 | self.func_call_history.append({"name": function_name, "arguments": arguments})
81 | # avoid the multiple search function calls
82 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
83 | if len(search_attempts) > 3:
84 | parameters['tool_choice'] = "none"
85 | result = get_function(function_name)(**arguments)
86 | chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id})
87 | return self.query(chat_history, **parameters)
88 | else:
89 | return resp.content
90 |
91 | def stream(self, chat_history, **kwargs):
92 | """
93 | Stream the output from the LLM model.
94 | Args:
95 | chat_history: The context (chat history).
96 | """
97 | arguments = ""
98 | function_name = ""
99 | for chunk in self.client.chat.completions.create(
100 | model=self.model,
101 | messages=chat_history,
102 | temperature=self.temperature,
103 | stream=True,
104 | **kwargs,
105 | ):
106 | if chunk.choices[0].delta.tool_calls:
107 | tool_call = chunk.choices[0].delta.tool_calls[0]
108 | if tool_call.function.name:
109 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
110 | function_name = process_function_name(tool_call.function.name)
111 | arguments = json.loads(tool_call.function.arguments)
112 | result = get_function(function_name)(**arguments)
113 | chat_history.append({"role": "tool", "content": result, "name": function_name})
114 | yield from self.stream(chat_history, **kwargs)
115 | else:
116 | yield chunk.choices[0].delta.content
117 |
--------------------------------------------------------------------------------
/mle/model/gemini.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib.util
3 | import json
4 |
5 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
6 | from mle.model.common import Model
7 |
8 |
9 | class GeminiModel(Model):
10 |
11 | def __init__(self, api_key, model, temperature=0.7):
12 | """
13 | Initialize the Gemini model.
14 | Args:
15 | api_key (str): The Gemini API key.
16 | model (str): The model with version.
17 | temperature (float): The temperature value.
18 | """
19 | super().__init__()
20 |
21 | dependency = "google.generativeai"
22 | spec = importlib.util.find_spec(dependency)
23 | if spec is not None:
24 | self.gemini = importlib.import_module(dependency)
25 | self.gemini.configure(api_key=api_key)
26 | else:
27 | raise ImportError(
28 | "It seems you didn't install `google-generativeai`. "
29 | "In order to enable the Gemini client related features, "
30 | "please make sure gemini Python package has been installed. "
31 | "More information, please refer to: https://ai.google.dev/gemini-api/docs/quickstart?lang=python"
32 | )
33 |
34 | self.model = model if model else 'gemini-1.5-flash'
35 | self.model_type = 'Gemini'
36 | self.temperature = temperature
37 | self.func_call_history = []
38 |
39 | def _map_chat_history_from_openai(self, chat_history):
40 | _key_map_dict = {
41 | "role": "role",
42 | "content": "parts",
43 | }
44 | _value_map_dict = {
45 | "system": "model",
46 | "user": "user",
47 | "assistant": "model",
48 | "content": "parts",
49 | }
50 | return [
51 | {
52 | _key_map_dict.get(k, k): _value_map_dict.get(v, v)
53 | for k, v in dict(chat).items()
54 | } for chat in chat_history
55 | ]
56 |
57 | def _map_functions_from_openai(self, functions):
58 | def _mapping_type(_type: str):
59 | if _type == "string":
60 | return self.gemini.protos.Type.STRING
61 | if _type == "object":
62 | return self.gemini.protos.Type.OBJECT
63 | if _type == "integer":
64 | return self.gemini.protos.Type.NUMBER
65 | if _type == "boolean":
66 | return self.gemini.protos.Type.BOOLEAN
67 | if _type == "array":
68 | return self.gemini.protos.Type.ARRAY
69 | return self.gemini.protos.Type.TYPE_UNSPECIFIED
70 |
71 | return self.gemini.protos.Tool(function_declarations=[
72 | self.gemini.protos.FunctionDeclaration(
73 | name=func.get("name"),
74 | description=func.get("description"),
75 | parameters=self.gemini.protos.Schema(
76 | type=_mapping_type(func.get("parameters", {}).get("type")),
77 | properties={
78 | param_name: self.gemini.protos.Schema(
79 | type=_mapping_type(properties.get("type")),
80 | description=properties.get("description")
81 | )
82 | for param_name, properties in \
83 | func.get("parameters",{}).get("properties", {}).items()
84 | },
85 | required=[key for key in func.get("parameters",{}).get("properties", {}).keys()],
86 | )
87 | )
88 | for func in functions
89 | ])
90 |
91 | def _mapping_response_format_from_openai(self, response_format):
92 | if response_format.get("type") == "json_object":
93 | return "application/json"
94 | return None
95 |
96 | def query(self, chat_history, **kwargs):
97 | """
98 | Query the LLM model.
99 |
100 | Args:
101 | chat_history: The context (chat history).
102 | """
103 | parameters = kwargs
104 | chat_history = self._map_chat_history_from_openai(chat_history)
105 |
106 | tools = None
107 | if parameters.get("functions") is not None:
108 | tools = self._map_functions_from_openai(parameters["functions"])
109 |
110 | client = self.gemini.GenerativeModel(self.model)
111 | chat_handler = client.start_chat(history=chat_history[:-1])
112 |
113 | completion = chat_handler.send_message(
114 | chat_history[-1]["parts"],
115 | tools=tools,
116 | generation_config=self.gemini.types.GenerationConfig(
117 | max_output_tokens=4096,
118 | temperature=self.temperature,
119 | ),
120 | )
121 |
122 | function_outputs = {}
123 | for part in completion.parts:
124 | fn = part.function_call
125 | if fn:
126 | print("[MLE FUNC CALL]: ", fn.name)
127 | # avoid the multiple search function calls
128 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
129 | if len(search_attempts) > 3:
130 | parameters['functions'] = None
131 | result = get_function(fn.name)(**dict(fn.args))
132 | function_outputs[fn.name] = result
133 |
134 | if len(function_outputs):
135 | response_parts = [
136 | self.gemini.protos.Part(
137 | function_response=self.gemini.protos.FunctionResponse(
138 | name=fn, response={"result": val}
139 | )
140 | )
141 | for fn, val in function_outputs.items()
142 | ]
143 |
144 | completion = chat_handler.send_message(
145 | self.gemini.protos.Content(parts=response_parts),
146 | generation_config=self.gemini.types.GenerationConfig(
147 | max_output_tokens=4096,
148 | temperature=self.temperature,
149 | response_mime_type=self._mapping_response_format_from_openai(
150 | parameters.get("response_format", {})),
151 | ),
152 | )
153 |
154 | return completion.text
155 |
156 | def stream(self, chat_history, **kwargs):
157 | """
158 | Stream the output from the LLM model.
159 | Args:
160 | chat_history: The context (chat history).
161 | """
162 | client = self.gemini.GenerativeModel(self.model)
163 | chat_handler = client.start_chat(history=chat_history[:-1])
164 |
165 | completions = chat_handler.send_message(
166 | chat_history[-1]["parts"],
167 | stream=True,
168 | generation_config=self.gemini.types.GenerationConfig(
169 | max_output_tokens=4096,
170 | temperature=self.temperature,
171 | ),
172 | )
173 |
174 | for chunk in completions:
175 | yield chunk.text
176 |
--------------------------------------------------------------------------------
/mle/model/mistral.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 |
4 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
5 | from mle.model.common import Model
6 |
7 |
8 | class MistralModel(Model):
9 | def __init__(self, api_key, model, temperature=0.7):
10 | """
11 | Initialize the Mistral model.
12 | Args:
13 | api_key (str): The Mistral API key.
14 | model (str): The model with version.
15 | temperature (float): The temperature value.
16 | """
17 | super().__init__()
18 |
19 | dependency = "mistralai"
20 | spec = importlib.util.find_spec(dependency)
21 | if spec is not None:
22 | self.mistral = importlib.import_module(dependency).Mistral
23 | else:
24 | raise ImportError(
25 | "It seems you didn't install mistralai. In order to enable the Mistral AI client related features, "
26 | "please make sure mistralai Python package has been installed. "
27 | "More information, please refer to: https://github.com/mistralai/client-python"
28 | )
29 |
30 | self.model = model if model else 'mistral-large-latest'
31 | self.model_type = 'MistralAI'
32 | self.temperature = temperature
33 | self.client = self.mistral(api_key=api_key)
34 | self.func_call_history = []
35 |
36 | def _convert_functions_to_tools(self, functions):
37 | """
38 | Convert OpenAI-style functions to Mistral-style tools.
39 | """
40 | tools = []
41 | for func in functions:
42 | tool = {
43 | "type": "function",
44 | "function": {
45 | "name": func["name"],
46 | "description": func.get("description", ""),
47 | "parameters": func["parameters"]
48 | }
49 | }
50 | tools.append(tool)
51 | return tools
52 |
53 | def query(self, chat_history, **kwargs):
54 | """
55 | Query the LLM model.
56 |
57 | Args:
58 | chat_history: The context (chat history).
59 | """
60 | functions = kwargs.get("functions",[])
61 | tools = self._convert_functions_to_tools(functions)
62 | tool_choice = kwargs.get('tool_choice', 'any')
63 | parameters = kwargs
64 | completion = self.client.chat.complete(
65 | model=self.model,
66 | messages=chat_history,
67 | temperature=self.temperature,
68 | stream=False,
69 | tools=tools,
70 | tool_choice=tool_choice,
71 | )
72 | resp = completion.choices[0].message
73 | if resp.tool_calls:
74 | for tool_call in resp.tool_calls:
75 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
76 | function_name = process_function_name(tool_call.function.name)
77 | arguments = json.loads(tool_call.function.arguments)
78 | print("[MLE FUNC CALL]: ", function_name)
79 | self.func_call_history.append({"name": function_name, "arguments": arguments})
80 |                     # avoid repeated search function calls
81 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
82 | if len(search_attempts) > 3:
83 | parameters['tool_choice'] = "none"
84 | result = get_function(function_name)(**arguments)
85 | chat_history.append({"role": "tool", "content": result, "name": function_name, "tool_call_id":tool_call.id})
86 | return self.query(chat_history, **parameters)
87 | else:
88 | return resp.content
89 |
90 | def stream(self, chat_history, **kwargs):
91 | """
92 | Stream the output from the LLM model.
93 | Args:
94 | chat_history: The context (chat history).
95 | """
96 | functions = kwargs.get("functions",[])
97 | tools = self._convert_functions_to_tools(functions)
98 | tool_choice = kwargs.get('tool_choice', 'any')
99 | for chunk in self.client.chat.complete(
100 | model=self.model,
101 | messages=chat_history,
102 | temperature=self.temperature,
103 | stream=True,
104 | tools=tools,
105 | tool_choice=tool_choice
106 | ):
107 | if chunk.choices[0].delta.tool_calls:
108 | tool_call = chunk.choices[0].delta.tool_calls[0]
109 | if tool_call.function.name:
110 | chat_history.append({"role": "assistant", "content": '', "tool_calls": [tool_call], "prefix":False})
111 | function_name = process_function_name(tool_call.function.name)
112 | arguments = json.loads(tool_call.function.arguments)
113 | result = get_function(function_name)(**arguments)
114 | chat_history.append({"role": "tool", "content": result, "name": function_name})
115 | yield from self.stream(chat_history, **kwargs)
116 | else:
117 | yield chunk.choices[0].delta.content
118 |
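For reference, a sketch of the OpenAI-style function spec that `_convert_functions_to_tools` expects and the Mistral-style tool it produces (illustrative values only):

    functions = [{
        "name": "web_search",                          # hypothetical function name
        "description": "Search the web for a query.",
        "parameters": {"type": "object", "properties": {"query": {"type": "string"}}},
    }]
    # MistralModel._convert_functions_to_tools(functions) wraps each spec as:
    # [{"type": "function", "function": {"name": "web_search",
    #   "description": "Search the web for a query.", "parameters": {...}}}]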
--------------------------------------------------------------------------------
/mle/model/ollama.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import re
3 |
4 | from .common import Model
5 |
6 |
7 | class OllamaModel(Model):
8 | def __init__(self, model, host_url=None):
9 | """
10 | Initialize the Ollama model.
11 | Args:
12 |             model (str): The model version.
13 |             host_url (str): The Ollama host URL.
14 | """
15 | super().__init__()
16 |
17 | dependency = "ollama"
18 | spec = importlib.util.find_spec(dependency)
19 | if spec is not None:
20 | self.model = model if model else 'llama3'
21 | self.model_type = 'Ollama'
22 | self.ollama = importlib.import_module(dependency)
23 | self.client = self.ollama.Client(host=host_url)
24 | else:
25 | raise ImportError(
26 |                 "It seems you didn't install ollama. To enable the Ollama client-related features, "
27 |                 "please make sure the ollama Python package has been installed. "
28 |                 "For more information, please refer to: https://github.com/ollama/ollama-python"
29 | )
30 |
31 | def _clean_think_tags(self, text):
32 | """
33 |         Remove content between <think> tags and empty <think> tags from the text.
34 | Args:
35 | text (str): The input text to clean.
36 | Returns:
37 | str: The cleaned text with think tags and their content removed.
38 | """
39 |         # Remove content between <think> tags
40 |         text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
41 |         # Remove empty <think> tags
42 |         text = re.sub(r'<think>\s*</think>', '', text)
43 | return text.strip()
44 |
45 | def _process_message(self, message, **kwargs):
46 | """
47 | Process the message before sending to the model.
48 | Args:
49 | message: The message to process.
50 | **kwargs: Additional arguments.
51 | Returns:
52 | dict: The processed message.
53 | """
54 | if isinstance(message, dict) and 'content' in message:
55 | message['content'] = self._clean_think_tags(message['content'])
56 | return message
57 |
58 | def query(self, chat_history, **kwargs):
59 | """
60 | Query the LLM model.
61 | Args:
62 | chat_history: The context (chat history).
63 | **kwargs: Additional arguments for the model.
64 | Returns:
65 | str: The model's response.
66 | """
67 |
68 | # Check if 'response_format' exists in kwargs
69 | format = None
70 | if 'response_format' in kwargs and kwargs['response_format'].get('type') == 'json_object':
71 | format = 'json'
72 |
73 | response = self.client.chat(model=self.model, messages=chat_history, format=format)
74 | return self._clean_think_tags(response['message']['content'])
75 |
76 | def stream(self, chat_history, **kwargs):
77 | """
78 | Stream the output from the LLM model.
79 | Args:
80 | chat_history: The context (chat history).
81 | **kwargs: Additional arguments for the model.
82 | Yields:
83 | str: Chunks of the model's response.
84 | """
85 |
86 | for chunk in self.client.chat(
87 | model=self.model,
88 | messages=chat_history,
89 | stream=True
90 | ):
91 | yield self._clean_think_tags(chunk['message']['content'])
92 |
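A standalone sketch of the think-tag cleanup performed by `_clean_think_tags` (same regular expression, illustrative text only):

    import re
    raw = "<think>internal reasoning that should be hidden</think>The answer is 42."
    cleaned = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip()
    print(cleaned)  # -> "The answer is 42."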
--------------------------------------------------------------------------------
/mle/model/openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib.util
3 | import json
4 |
5 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
6 | from mle.model.common import Model
7 |
8 |
9 | class OpenAIModel(Model):
10 | def __init__(self, api_key, model, temperature=0.7):
11 | """
12 | Initialize the OpenAI model.
13 | Args:
14 | api_key (str): The OpenAI API key.
15 | model (str): The model with version.
16 | temperature (float): The temperature value.
17 | """
18 | super().__init__()
19 |
20 | dependency = "openai"
21 | spec = importlib.util.find_spec(dependency)
22 | if spec is not None:
23 | self.openai = importlib.import_module(dependency).OpenAI
24 | else:
25 | raise ImportError(
26 |                 "It seems you didn't install openai. To enable the OpenAI client-related features, "
27 |                 "please make sure the openai Python package has been installed. "
28 |                 "For more information, please refer to: https://openai.com/product"
29 | )
30 |
31 | self.model = model if model else 'gpt-4o-2024-08-06'
32 | self.model_type = 'OpenAI'
33 | self.temperature = temperature
34 | self.client = self.openai(
35 | api_key=api_key,
36 | base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
37 | )
38 | self.func_call_history = []
39 |
40 | def query(self, chat_history, **kwargs):
41 | """
42 | Query the LLM model.
43 |
44 | Args:
45 | chat_history: The context (chat history).
46 | """
47 | parameters = kwargs
48 | completion = self.client.chat.completions.create(
49 | model=self.model,
50 | messages=chat_history,
51 | temperature=self.temperature,
52 | stream=False,
53 | **parameters
54 | )
55 |
56 | resp = completion.choices[0].message
57 | if resp.function_call:
58 | function_name = process_function_name(resp.function_call.name)
59 | arguments = json.loads(resp.function_call.arguments)
60 | print("[MLE FUNC CALL]: ", function_name)
61 | self.func_call_history.append({"name": function_name, "arguments": arguments})
62 |             # avoid repeated search function calls
63 | search_attempts = [item for item in self.func_call_history if item['name'] in SEARCH_FUNCTIONS]
64 | if len(search_attempts) > 3:
65 | parameters['function_call'] = "none"
66 | result = get_function(function_name)(**arguments)
67 | chat_history.append({"role": "assistant", "function_call": dict(resp.function_call)})
68 | chat_history.append({"role": "function", "content": result, "name": function_name})
69 | return self.query(chat_history, **parameters)
70 | else:
71 | return resp.content
72 |
73 | def stream(self, chat_history, **kwargs):
74 | """
75 | Stream the output from the LLM model.
76 | Args:
77 | chat_history: The context (chat history).
78 | """
79 | arguments = ''
80 | function_name = ''
81 | for chunk in self.client.chat.completions.create(
82 | model=self.model,
83 | messages=chat_history,
84 | temperature=self.temperature,
85 | stream=True,
86 | **kwargs
87 | ):
88 | delta = chunk.choices[0].delta
89 | if delta.function_call:
90 | if delta.function_call.name:
91 | function_name = process_function_name(delta.function_call.name)
92 | if delta.function_call.arguments:
93 | arguments += delta.function_call.arguments
94 |
95 | if chunk.choices[0].finish_reason == "function_call":
96 | result = get_function(function_name)(**json.loads(arguments))
97 | chat_history.append({"role": "function", "content": result, "name": function_name})
98 | yield from self.stream(chat_history, **kwargs)
99 | else:
100 | yield delta.content
101 |
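A minimal usage sketch for OpenAIModel (assumes the openai package is installed and OPENAI_API_KEY holds a valid key; the optional "functions" kwarg follows the OpenAI function-calling schema):

    import os
    model = OpenAIModel(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-2024-08-06")
    history = [{"role": "user", "content": "Suggest a baseline model for tabular churn prediction."}]
    print(model.query(history))           # single completion
    for chunk in model.stream(history):   # or stream it chunk by chunk
        print(chunk or "", end="")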
--------------------------------------------------------------------------------
/mle/model/vllm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import importlib.util
4 | from typing import List, Dict, Any, Optional
5 |
6 | from mle.model.common import Model
7 | from mle.function import SEARCH_FUNCTIONS, get_function, process_function_name
8 |
9 |
10 | class vLLMModel(Model):
11 | """vLLM model implementation using OpenAI-compatible API."""
12 |
13 | def __init__(self, base_url: Optional[str] = None,
14 | model: Optional[str] = None,
15 | temperature: float = 0.7) -> None:
16 | """Initialize the vLLM model.
17 |
18 | Args:
19 | base_url: The URL of the vLLM server.
20 | model: The model name.
21 | temperature: The sampling temperature.
22 | """
23 | super().__init__()
24 |
25 | dependency = "openai"
26 | spec = importlib.util.find_spec(dependency)
27 | if spec is not None:
28 | self.openai = importlib.import_module(dependency).OpenAI
29 | else:
30 | raise ImportError(
31 | "OpenAI package not found. Please install it using: "
32 | "pip install openai"
33 | )
34 |
35 | self.model = model if model else 'mistralai/Mistral-7B-Instruct-v0.3'
36 | self.model_type = 'vLLM'
37 | self.temperature = temperature
38 | self.client = self.openai(
39 | api_key="EMPTY",
40 | base_url=base_url or os.getenv("vLLM_BASE_URL",
41 | "http://localhost:8000/v1"),
42 | timeout=60.0,
43 | max_retries=2,
44 | )
45 | self.func_call_history = []
46 |
47 | def normalize_chat_history(self, chat_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
48 | """Normalize chat history to ensure it follows the required format.
49 |
50 | Args:
51 | chat_history: The original chat history.
52 |
53 | Returns:
54 | Normalized chat history list.
55 | """
56 | normalized = []
57 |
58 | # Handle system message first
59 | system_messages = [msg for msg in chat_history if msg["role"] == "system"]
60 | if system_messages:
61 | normalized.append(system_messages[0])
62 |
63 | # Add other messages in order
64 | for msg in chat_history:
65 | if msg["role"] == "system":
66 | continue
67 |
68 | if msg["role"] in ["user", "assistant", "function"]:
69 | normalized.append(msg)
70 |
71 | return normalized
72 |
73 | def query(self, chat_history: List[Dict[str, Any]], **kwargs) -> str:
74 | """Query the LLM model.
75 |
76 | Args:
77 | chat_history: The context (chat history).
78 | **kwargs: Additional parameters for the API call.
79 |
80 | Returns:
81 | Model's response as string.
82 |
83 | Raises:
84 | Exception: If the API call fails.
85 | """
86 | try:
87 | normalized_history = self.normalize_chat_history(chat_history)
88 | parameters = kwargs
89 | completion = self.client.chat.completions.create(
90 | model=self.model,
91 | messages=normalized_history,
92 | temperature=self.temperature,
93 | stream=False,
94 | **parameters
95 | )
96 |
97 | resp = completion.choices[0].message
98 | if resp.function_call:
99 | function_name = process_function_name(resp.function_call.name)
100 | arguments = json.loads(resp.function_call.arguments)
101 | print("[MLE FUNC CALL]: ", function_name)
102 | self.func_call_history.append({
103 | "name": function_name,
104 | "arguments": arguments
105 | })
106 |
107 | # Avoid multiple search function calls
108 | search_attempts = [
109 | item for item in self.func_call_history
110 | if item['name'] in SEARCH_FUNCTIONS
111 | ]
112 | if len(search_attempts) > 3:
113 | parameters['function_call'] = "none"
114 |
115 | result = get_function(function_name)(**arguments)
116 | chat_history.append({
117 | "role": "assistant",
118 | "function_call": dict(resp.function_call)
119 | })
120 | chat_history.append({
121 | "role": "function",
122 | "content": result,
123 | "name": function_name
124 | })
125 | return self.query(chat_history, **parameters)
126 | return resp.content
127 |
128 | except Exception as e:
129 | error_msg = f"vLLM API error: {str(e)}"
130 | print(f"Error during vLLM query: {error_msg}")
131 | if hasattr(e, 'response'):
132 | print(f"Response status: {e.response.status_code}")
133 | print(f"Response body: {e.response.text}")
134 | raise Exception(error_msg)
135 |
136 | def stream(self, chat_history: List[Dict[str, Any]], **kwargs) -> str:
137 | """Stream the output from the LLM model.
138 |
139 | Args:
140 | chat_history: The context (chat history).
141 | **kwargs: Additional parameters for the API call.
142 |
143 | Yields:
144 | Chunks of the model's response.
145 |
146 | Raises:
147 | Exception: If the streaming fails.
148 | """
149 | try:
150 | arguments = ''
151 | function_name = ''
152 | for chunk in self.client.chat.completions.create(
153 | model=self.model,
154 | messages=chat_history,
155 | temperature=self.temperature,
156 | stream=True,
157 | **kwargs
158 | ):
159 | delta = chunk.choices[0].delta
160 | if delta.function_call:
161 | if delta.function_call.name:
162 | function_name = process_function_name(
163 | delta.function_call.name
164 | )
165 | if delta.function_call.arguments:
166 | arguments += delta.function_call.arguments
167 |
168 | if chunk.choices[0].finish_reason == "function_call":
169 | result = get_function(function_name)(**json.loads(arguments))
170 | chat_history.append({
171 | "role": "function",
172 | "content": result,
173 | "name": function_name
174 | })
175 | yield from self.stream(chat_history, **kwargs)
176 | else:
177 | yield delta.content
178 |
179 | except Exception as e:
180 | error_msg = f"vLLM streaming error: {str(e)}"
181 | print(f"Error during vLLM streaming: {error_msg}")
182 | if hasattr(e, 'response'):
183 | print(f"Response status: {e.response.status_code}")
184 | print(f"Response body: {e.response.text}")
185 | raise Exception(error_msg)
186 |
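A minimal sketch of pointing vLLMModel at a locally served model (assumes an OpenAI-compatible vLLM server is already listening on localhost:8000):

    model = vLLMModel(base_url="http://localhost:8000/v1",
                      model="mistralai/Mistral-7B-Instruct-v0.3")
    history = [{"role": "user", "content": "List three feature-engineering ideas for time series."}]
    print(model.query(history))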
--------------------------------------------------------------------------------
/mle/server/__init__.py:
--------------------------------------------------------------------------------
1 | from .app import app
2 |
--------------------------------------------------------------------------------
/mle/server/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import Optional
4 | from datetime import datetime
5 | from pydantic import BaseModel
6 | from fastapi import FastAPI, HTTPException, BackgroundTasks
7 | from fastapi.responses import JSONResponse
8 | from fastapi.middleware.cors import CORSMiddleware
9 |
10 | from mle.workflow import report
11 | from mle.utils import check_config
12 |
13 | app = FastAPI()
14 |
15 | # Add CORS middleware
16 | app.add_middleware(
17 | CORSMiddleware,
18 | allow_origins=["*"],
19 | allow_credentials=True,
20 | allow_methods=["*"],
21 | allow_headers=["*"],
22 | )
23 |
24 |
25 | class ReportRequest(BaseModel):
26 | """
27 | ReportRequest: the request body for generating a report
28 | """
29 | repo: str
30 | username: str
31 | token: Optional[str] = None
32 | okr: Optional[str] = None
33 |
34 |
35 | @app.get("/")
36 | def root():
37 | """
38 |     root: the API root endpoint.
39 |     :return: a welcome message.
40 | """
41 | return {"Welcome to": "MLE-Agent!"}
42 |
43 |
44 | @app.get("/latest_report", response_class=JSONResponse)
45 | def read_latest_report():
46 | """
47 | read_latest_report: read the latest progress report.
48 | :return: the content of the latest progress report as plain text.
49 | """
50 | if not check_config():
51 | raise HTTPException(
52 | status_code=400,
53 | detail="`project.yml` not found. Please start the MLE server under an MLE-Agent project directory."
54 | )
55 |
56 | reports_dir = os.getcwd()
57 | report_files = [f for f in os.listdir(reports_dir) if f.startswith("progress_report_") and f.endswith(".json")]
58 |
59 | if not report_files:
60 | raise HTTPException(status_code=404, detail="No progress reports found.")
61 |
62 | latest_report = max(report_files, key=lambda f: datetime.strptime(f, "progress_report_%Y_%m_%d.json"))
63 | try:
64 | with open(os.path.join(reports_dir, latest_report), 'r') as file:
65 | report_dict = json.load(file)
66 | report_dict.update({"file": latest_report})
67 | return JSONResponse(content=report_dict)
68 | except IOError:
69 | raise HTTPException(status_code=500, detail="Error reading the latest report file.")
70 |
71 |
72 | @app.post("/gen_report")
73 | def gen_report(report_request: ReportRequest):
74 | """
75 | Generate a report synchronously based on the provided GitHub repository and username.
76 | Optionally includes OKR text.
77 |
78 | Example payload:
79 |
80 | curl -X POST http://localhost:8000/gen_report \
81 | -H "Content-Type: application/json" \
82 | -d '{
83 | "token": "***",
84 | "repo": "MLSysOps/MLE-agent",
85 | "username": "huangyz0918",
86 | "okr": "Improve system efficiency by 20% this quarter"
87 | }'
88 | """
89 | try:
90 | # Run report generation synchronously
91 | result = report(
92 | os.getcwd(),
93 | report_request.repo,
94 | report_request.username,
95 | report_request.token,
96 | okr_str=report_request.okr,
97 | model="gpt-4o",
98 | )
99 |
100 | return {
101 | "message": "Report generation completed",
102 | "repo": report_request.repo,
103 | "username": report_request.username,
104 | "okr_provided": report_request.okr is not None,
105 | "result": result # Assuming the report function returns some result
106 | }
107 | except Exception as e:
108 | raise HTTPException(status_code=500, detail=f"Error in report generation process: {e}")
109 |
110 |
111 | @app.post("/gen_report_async")
112 | async def gen_report_async(report_request: ReportRequest, background_tasks: BackgroundTasks):
113 | """
114 | Generate a report (async) based on the provided GitHub repository and username.
115 | Optionally includes OKR text.
116 |
117 | Example payload:
118 |
119 | curl -X POST http://localhost:8000/gen_report_async \
120 | -H "Content-Type: application/json" \
121 | -d '{
122 | "token": "***",
123 | "repo": "MLSysOps/MLE-agent",
124 | "username": "huangyz0918",
125 | "okr": "Improve system efficiency by 20% this quarter"
126 | }'
127 | """
128 | try:
129 | # Trigger report generation in the background
130 | background_tasks.add_task(
131 | report,
132 | os.getcwd(),
133 | report_request.repo,
134 | report_request.username,
135 | report_request.token,
136 | okr_str=report_request.okr,
137 | model="gpt-4o",
138 | )
139 |
140 | return {
141 | "message": "Report generation started",
142 | "repo": report_request.repo,
143 | "username": report_request.username,
144 | "okr_provided": report_request.okr is not None
145 | }
146 | except Exception as e:
147 | raise HTTPException(status_code=500, detail=f"Error in report generation process: {e}")
148 |
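The same /gen_report call from Python instead of curl (a sketch that assumes the requests package is installed and the server was started, e.g. with `uvicorn mle.server.app:app`):

    import requests  # assumed to be available in the environment
    payload = {
        "repo": "MLSysOps/MLE-agent",
        "username": "huangyz0918",
        "okr": "Improve system efficiency by 20% this quarter",
    }
    resp = requests.post("http://localhost:8000/gen_report", json=payload, timeout=600)
    print(resp.json())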
--------------------------------------------------------------------------------
/mle/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .system import *
2 | from .cache import *
3 | from .memory import *
4 | from .data import *
5 | from .chunk import *
6 |
--------------------------------------------------------------------------------
/mle/utils/cache.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from typing import Dict, Any, Optional
3 | from datetime import datetime
4 |
5 | from mle.utils.system import get_config, write_config
6 |
7 |
8 | class WorkflowCacheOperator:
9 | """
10 | WorkflowCacheOperator handles the storing and resuming of cache content.
11 | """
12 |
13 | def __init__(self, cache: 'WorkflowCache', cache_content: Dict[str, Any]):
14 | """
15 | Args:
16 | cache: The cache instance to which this operator belongs.
17 | cache_content (Dict[str, object]): A dictionary holding the cached content.
18 | """
19 | self.cache = cache
20 | self.cache_content = cache_content
21 |
22 | def store(self, key: str, value: Any) -> None:
23 | """
24 | Store a value into the cache content.
25 |
26 | Args:
27 | key (str): The key under which the value is stored.
28 | value (object): The value to be stored.
29 | """
30 | self.cache_content[key] = pickle.dumps(value, fix_imports=False)
31 |
32 | def resume(self, key: str) -> Any:
33 | """
34 | Resume a value from the cache content.
35 |
36 | Args:
37 | key (str): The key of the value to be resumed.
38 |
39 | Returns:
40 | object: The resumed value, or None if the key does not exist.
41 | """
42 | if key in self.cache_content:
43 | return pickle.loads(self.cache_content[key])
44 | return None
45 |
46 | def __enter__(self):
47 | """
48 | Enter the runtime context related to this object.
49 |
50 | Returns:
51 | WorkflowCacheOperator: self
52 | """
53 | return self
54 |
55 | def __exit__(self, exc_type, exc_val, exc_tb):
56 | """
57 | Exit the runtime context related to this object.
58 |
59 | Args:
60 | exc_type: The exception type.
61 | exc_val: The exception value.
62 | exc_tb: The traceback object.
63 | """
64 | if exc_type is None:
65 | self.cache._store_cache_buffer()
66 |
67 |
68 | class WorkflowCache:
69 | """
70 | WorkflowCache manages the caching for workflows, providing
71 | methods to load, store, and remove cached steps.
72 | """
73 |
74 | def __init__(self, project_dir: str, workflow: str = 'baseline'):
75 | """
76 | Initialize WorkflowCache with a project directory.
77 |
78 | Args:
79 | project_dir (str): The directory of the project.
80 | workflow (str): The name of the cached workflow.
81 | """
82 | self.project_dir = project_dir
83 | self.workflow = workflow
84 | self.buffer = self._load_cache_buffer(workflow)
85 | self.cache: Dict[int, Dict[str, Any]] = self.buffer["cache"][workflow]
86 |
87 | def is_empty(self) -> bool:
88 | """
89 | Check if the cache is empty.
90 |
91 | Returns:
92 | bool: True if the cache is empty, False otherwise.
93 | """
94 | return len(self.cache) == 0
95 |
96 | def remove(self, step: int) -> None:
97 | """
98 | Remove a step from the cache.
99 |
100 | Args:
101 | step (int): The step index to be removed.
102 | """
103 | self.cache.pop(step, None)
104 | self._store_cache_buffer()
105 |
106 | def current_step(self) -> int:
107 | """
108 | Get the current step from the cache.
109 |
110 | Returns:
111 | int: The current step.
112 | """
113 | return max(self.cache.keys()) if self.cache else 0
114 |
115 | def resume_variable(self, key: str, step: Optional[int] = None):
116 | """
117 | Resume the cached variable.
118 |
119 | Args:
120 | key (str): The key of the value to be resumed.
121 |             step (Optional[int]): The step to resume the variable from. If None, all steps are searched.
122 |
123 | Returns:
124 | object: The resumed value, or None if the key does not exist.
125 | """
126 | if step is not None:
127 | return self.__call__(step).resume(key)
128 | else:
129 | for step in range(self.current_step() + 1):
130 | value = self.resume_variable(key, step)
131 | if value is not None:
132 | return value
133 | return None
134 |
135 | def _load_cache_buffer(self, workflow: str) -> Dict[str, Any]:
136 | """
137 | Load the cache buffer from the configuration.
138 |
139 | Args:
140 | workflow (str): The name of the cached workflow.
141 |
142 | Returns:
143 | dict: The buffer loaded from the configuration.
144 | """
145 | buffer = get_config() or {}
146 | if "cache" not in buffer.keys():
147 | buffer["cache"] = {}
148 | if workflow not in buffer["cache"].keys():
149 | buffer["cache"][workflow] = {}
150 | return buffer
151 |
152 | def _store_cache_buffer(self) -> None:
153 | """
154 | Store the cache buffer to the configuration.
155 | """
156 | write_config(self.buffer)
157 |
158 | def __call__(self, step: int, name: Optional[str] = None) -> WorkflowCacheOperator:
159 | """
160 | Initialize the cache content for a given step and name.
161 |
162 | Args:
163 |             step (int): The step to be initialized.
164 | name (str): The name associated with the step.
165 |
166 | Returns:
167 | WorkflowCacheOperator: An instance of WorkflowCacheOperator.
168 | """
169 | if step not in self.cache.keys():
170 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
171 | self.cache[step] = {
172 | "step": step,
173 | "name": name,
174 | "time": timestamp,
175 | "content": {},
176 | }
177 | cache_content = self.cache[step]["content"]
178 | return WorkflowCacheOperator(self, cache_content)
179 |
180 | def __str__(self) -> str:
181 | """
182 | Return a string representation of the step cache list.
183 |
184 | Returns:
185 | str: The string representation of the cache.
186 | """
187 | return "\n".join(f"[{k}] {v['name']} ({v['time']})" for k, v in self.cache.items())
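A usage sketch of WorkflowCache (this mirrors how the workflow modules below use it; assumes an initialized MLE-Agent project so get_config()/write_config() work, and the dataset value is hypothetical):

    cache = WorkflowCache(project_dir=".", workflow="baseline")
    with cache(step=1, name="ask for the data information") as ca:
        dataset = ca.resume("dataset")       # None on the first run
        if dataset is None:
            dataset = "train.csv"            # hypothetical value
            ca.store("dataset", dataset)     # persisted when the block exits cleanly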
--------------------------------------------------------------------------------
/mle/utils/chunk.py:
--------------------------------------------------------------------------------
1 | # Source modified from https://github.com/CintraAI/code-chunker/blob/main/Chunker.py
2 | import tiktoken
3 | from .parser import CodeParser
4 | from abc import ABC, abstractmethod
5 |
6 |
7 | def count_tokens(string: str, encoding_name: str) -> int:
8 | encoding = tiktoken.encoding_for_model(encoding_name)
9 | num_tokens = len(encoding.encode(string))
10 | return num_tokens
11 |
12 |
13 | class Chunker(ABC):
14 | def __init__(self, encoding_name="gpt-4"):
15 | self.encoding_name = encoding_name
16 |
17 | @abstractmethod
18 | def chunk(self, content, token_limit):
19 | pass
20 |
21 | @abstractmethod
22 | def get_chunk(self, chunked_content, chunk_number):
23 | pass
24 |
25 | @staticmethod
26 | def print_chunks(chunks):
27 | for chunk_number, chunk_code in chunks.items():
28 | print(f"Chunk {chunk_number}:")
29 | print("=" * 40)
30 | print(chunk_code)
31 | print("=" * 40)
32 |
33 | @staticmethod
34 | def consolidate_chunks_into_file(chunks):
35 | return "\n".join(chunks.values())
36 |
37 | @staticmethod
38 | def count_lines(consolidated_chunks):
39 | lines = consolidated_chunks.split("\n")
40 | return len(lines)
41 |
42 |
43 | class CodeChunker(Chunker):
44 | def __init__(self, cache_dir, file_extension, encoding_name="gpt-4o-mini"):
45 | super().__init__(encoding_name)
46 | self.file_extension = file_extension
47 | self.cache_dir = cache_dir
48 |
49 | def chunk(self, code, token_limit) -> dict:
50 | code_parser = CodeParser(self.cache_dir, self.file_extension)
51 | chunks = {}
52 | token_count = 0
53 | lines = code.split("\n")
54 | i = 0
55 | chunk_number = 1
56 | start_line = 0
57 | breakpoints = sorted(code_parser.get_lines_for_points_of_interest(code, self.file_extension))
58 | comments = sorted(code_parser.get_lines_for_comments(code, self.file_extension))
59 | adjusted_breakpoints = []
60 | for bp in breakpoints:
61 | current_line = bp - 1
62 | highest_comment_line = None # Initialize with None to indicate no comment line has been found yet
63 | while current_line in comments:
64 | highest_comment_line = current_line # Update highest comment line found
65 | current_line -= 1 # Move to the previous line
66 |
67 | if highest_comment_line: # If a highest comment line exists, add it
68 | adjusted_breakpoints.append(highest_comment_line)
69 | else:
70 | adjusted_breakpoints.append(
71 | bp) # If no comments were found before the breakpoint, add the original breakpoint
72 |
73 | breakpoints = sorted(set(adjusted_breakpoints)) # Ensure breakpoints are unique and sorted
74 |
75 | while i < len(lines):
76 | line = lines[i]
77 | new_token_count = count_tokens(line, self.encoding_name)
78 | if token_count + new_token_count > token_limit:
79 |
80 | # Set the stop line to the last breakpoint before the current line
81 | if i in breakpoints:
82 | stop_line = i
83 | else:
84 | stop_line = max(max([x for x in breakpoints if x < i], default=start_line), start_line)
85 |
86 | # If the stop line is the same as the start line, it means we haven't reached a breakpoint yet, and we need to move to the next line to find one
87 | if stop_line == start_line and i not in breakpoints:
88 | token_count += new_token_count
89 | i += 1
90 |
91 |                     # If the stop line equals both the start line and the current line, keep accumulating tokens and move on to the next line
92 | elif stop_line == start_line and i == stop_line:
93 | token_count += new_token_count
94 | i += 1
95 |
96 | # If the stop line is the same as the start line and the current line is a breakpoint, it means we can create a chunk with just the current line
97 | elif stop_line == start_line and i in breakpoints:
98 | current_chunk = "\n".join(lines[start_line:stop_line])
99 | if current_chunk.strip(): # If the current chunk is not just whitespace
100 | chunks[chunk_number] = current_chunk # Using chunk_number as key
101 | chunk_number += 1
102 |
103 | token_count = 0
104 | start_line = i
105 | i += 1
106 |
107 | # If the stop line is different from the start line, it means we're at the end of a block
108 | else:
109 | current_chunk = "\n".join(lines[start_line:stop_line])
110 | if current_chunk.strip():
111 | chunks[chunk_number] = current_chunk # Using chunk_number as key
112 | chunk_number += 1
113 |
114 | i = stop_line
115 | token_count = 0
116 | start_line = stop_line
117 | else:
118 | # If the token count is still within the limit, add the line to the current chunk
119 | token_count += new_token_count
120 | i += 1
121 |
122 | # Append remaining code, if any, ensuring it's not empty or whitespace
123 | current_chunk_code = "\n".join(lines[start_line:])
124 | if current_chunk_code.strip(): # Checks if the chunk is not just whitespace
125 | chunks[chunk_number] = current_chunk_code # Using chunk_number as key
126 |
127 | return chunks
128 |
129 | def get_chunk(self, chunked_codebase, chunk_number):
130 | return chunked_codebase[chunk_number]
131 |
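A usage sketch for CodeChunker (illustrative paths and extension; assumes the tree-sitter grammars used by CodeParser can be resolved from cache_dir):

    chunker = CodeChunker(cache_dir=".mle/cache", file_extension="py")
    with open("train.py") as f:               # hypothetical source file
        chunks = chunker.chunk(f.read(), token_limit=512)
    Chunker.print_chunks(chunks)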
--------------------------------------------------------------------------------
/mle/utils/data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import json
4 | from typing import Dict, Any
5 |
6 |
7 | def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None:
8 | """
9 | Write a dictionary to a markdown file.
10 | :param data: the dictionary to write.
11 | :param file_path: the file path to write the dictionary to.
12 | :return:
13 | """
14 |
15 | def write_item(k, v, indent_level=0):
16 | if isinstance(v, dict):
17 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
18 | for sub_key, sub_value in v.items():
19 | write_item(sub_key, sub_value, indent_level + 1)
20 | elif isinstance(v, list):
21 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
22 | for item in v:
23 | md_file.write(f"{' ' * indent_level}- {item}\n")
24 | else:
25 | md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
26 | md_file.write(f"{' ' * indent_level}{v}\n")
27 |
28 | with open(file_path, 'w') as md_file:
29 | for key, value in data.items():
30 | write_item(key, value)
31 | md_file.write("\n")
32 |
33 |
34 | def is_markdown_file(file_path):
35 | """
36 | Check if the file is a Markdown file.
37 | :param file_path: the file path
38 | :return: boolean
39 | """
40 | if not os.path.isfile(file_path):
41 | return False
42 |
43 | valid_extensions = ['.md', '.markdown', '.mdown', '.mkdn', '.mkd', '.mdwn', '.mdtxt', '.mdtext', '.text', '.Rmd']
44 | file_extension = os.path.splitext(file_path)[1].lower()
45 |
46 | return file_extension in valid_extensions
47 |
48 |
49 | def read_markdown(file_path, include_links=False, include_images=False):
50 | """
51 | Read the markdown file and return the content.
52 | :param file_path: the file path to the .md file
53 | :param include_links: the flag to include the links
54 | :param include_images: the flag to include the images
55 | :return: the raw content of the markdown file
56 | """
57 | try:
58 | with open(file_path, 'r', encoding='utf-8') as file:
59 | content = file.read()
60 |
61 | if not include_links:
62 | content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
63 |
64 | if not include_images:
65 | content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', content)
66 |
67 | return content.strip()
68 | except FileNotFoundError:
69 | return f"Error: File not found at {file_path}"
70 | except Exception as e:
71 | return f"Error: An unexpected error occurred - {str(e)}"
72 |
73 |
74 | def clean_json_string(input_string):
75 | """
76 |     Clean a JSON string returned by an LLM (strip markdown code fences) and parse it.
77 |     :param input_string: the input JSON string
78 | """
79 | cleaned = input_string.strip()
80 | cleaned = re.sub(r'^```\s*json?\s*', '', cleaned)
81 | cleaned = re.sub(r'\s*```\s*$', '', cleaned)
82 | parsed_json = json.loads(cleaned)
83 | return parsed_json
84 |
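A quick example of clean_json_string stripping a fenced LLM reply before parsing (illustrative input):

    raw = '```json\n{"status": "success", "tasks": []}\n```'
    print(clean_json_string(raw))  # -> {'status': 'success', 'tasks': []}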
--------------------------------------------------------------------------------
/mle/utils/memory.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import List, Dict, Optional
3 |
4 | import lancedb
5 | from lancedb.embeddings import get_registry
6 |
7 | from mle.utils import get_config
8 |
9 |
10 | class LanceDBMemory:
11 |
12 | def __init__(self, project_path: str):
13 | """
14 |         LanceDBMemory: memory and external knowledge management backed by LanceDB.
15 | Args:
16 | project_path: the path to store the data.
17 | """
18 | self.db_name = '.mle'
19 | self.table_name = 'memory'
20 | self.client = lancedb.connect(uri=self.db_name)
21 |
22 | config = get_config(project_path)
23 | if config["platform"] == "OpenAI":
24 | self.text_embedding = get_registry().get("openai").create(api_key=config["api_key"])
25 | else:
26 | self.text_embedding = get_registry().get("sentence-transformers").create(
27 | name="sentence-transformers/paraphrase-MiniLM-L6-v2"
28 | )
29 |
30 | def _open_table(self, table_name: str = None):
31 | """
32 | Open a LanceDB table by table name. (Return None if not exists)
33 | Args:
34 | table_name (Optional[str]): The name of the table. Defaults to self.table_name.
35 | """
36 | table_name = table_name or self.table_name
37 | try:
38 | table = self.client.open_table(table_name)
39 | except FileNotFoundError:
40 | return None
41 | return table
42 |
43 | def add(
44 | self,
45 | texts: List[str],
46 | metadata: Optional[List[Dict]] = None,
47 | table_name: Optional[str] = None,
48 | ids: Optional[List[str]] = None,
49 | ) -> List[str]:
50 | """
51 | Adds a list of text items to the specified memory table in the database.
52 |
53 | Args:
54 | texts (List[str]): A list of text strings to be added.
55 | metadata (Optional[List[Dict]]): A list of metadata to be added.
56 | table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name.
57 | ids (Optional[List[str]]): A list of unique IDs for the text items.
58 | If not provided, random UUIDs are generated.
59 |
60 | Returns:
61 | List[str]: A list of IDs associated with the added text items.
62 | """
63 | if isinstance(texts, str):
64 | texts = (texts,)
65 |
66 | if metadata is None:
67 | metadata = [None, ] * len(texts)
68 | elif isinstance(metadata, dict):
69 | metadata = (metadata,)
70 | else:
71 | assert len(texts) == len(metadata)
72 |
73 | embeds = self.text_embedding.compute_source_embeddings(texts)
74 |
75 | table_name = table_name or self.table_name
76 | ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))]
77 |
78 | data = [
79 | {
80 | "vector": embed,
81 | "text": text,
82 | "id": idx,
83 | "metadata": meta,
84 | } for idx, text, embed, meta in zip(ids, texts, embeds, metadata)
85 | ]
86 |
87 | if table_name not in self.client.table_names():
88 | table = self.client.create_table(table_name, data=data)
89 | table.create_fts_index("id")
90 | else:
91 | self._open_table(table_name).add(data=data)
92 |
93 | return ids
94 |
95 | def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]:
96 | """
97 | Queries the specified memory table for similar text embeddings.
98 |
99 | Args:
100 | query_texts (List[str]): A list of query text strings.
101 | table_name (Optional[str]): The name of the table to query. Defaults to self.table_name.
102 | n_results (int): The maximum number of results to retrieve per query. Default is 5.
103 |
104 | Returns:
105 | List[List[dict]]: A list of results for each query text, each result being a dictionary with
106 | keys such as "vector", "text", and "id".
107 | """
108 | table = self._open_table(table_name)
109 | if table is None:
110 | return []
111 |
112 | query_embeds = self.text_embedding.compute_source_embeddings(query_texts)
113 |
114 | results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
115 | return results
116 |
117 | def list_all_keys(self, table_name: Optional[str] = None):
118 | """
119 | Lists all IDs in the specified memory table.
120 |
121 | Args:
122 | table_name (Optional[str]): The name of the table to list IDs from. Defaults to the instance's table name.
123 |
124 | Returns:
125 | List[str]: A list of all IDs in the table.
126 | """
127 | table = self._open_table(table_name)
128 | if table is None:
129 | return []
130 |
131 | return [item["id"] for item in table.search(query_type="fts").to_list()]
132 |
133 | def get(self, record_id: str, table_name: Optional[str] = None):
134 | """
135 | Retrieves a record by its ID from the specified memory table.
136 |
137 | Args:
138 | record_id (str): The ID of the record to retrieve.
139 | table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
140 |
141 | Returns:
142 | List[dict]: A list containing the matching record, or an empty list if not found.
143 | """
144 | table = self._open_table(table_name)
145 | if table is None:
146 | return []
147 |
148 | return table.search(query_type="fts") \
149 | .where(f"id = '{record_id}'") \
150 | .limit(1).to_list()
151 |
152 | def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None, n_results: int = 5):
153 | """
154 | Retrieves records matching a specific metadata key-value pair.
155 |
156 | Args:
157 | key (str): The metadata key to filter by.
158 | value (str): The value of the metadata key to filter by.
159 | table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
160 | n_results (int): The maximum number of results to retrieve. Defaults to 5.
161 |
162 | Returns:
163 | List[dict]: A list of records matching the metadata criteria.
164 | """
165 | table = self._open_table(table_name)
166 | if table is None:
167 | return []
168 |
169 | return table.search(query_type="fts") \
170 | .where(f"metadata.{key} = '{value}'") \
171 | .limit(n_results).to_list()
172 |
173 | def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
174 | """
175 | Deletes a record from the specified memory table.
176 |
177 | Args:
178 | record_id (str): The ID of the record to delete.
179 | table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name.
180 |
181 | Returns:
182 | bool: True if the deletion was successful, False otherwise.
183 | """
184 | table = self._open_table(table_name)
185 | if table is None:
186 | return True
187 |
188 | return table.delete(f"id = '{record_id}'")
189 |
190 | def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None):
191 | """
192 | Deletes records from the specified memory table based on a metadata key-value pair.
193 |
194 | Args:
195 | key (str): The metadata key to filter by.
196 | value (str): The value of the metadata key to filter by.
197 | table_name (Optional[str]): The name of the table to delete records from. Defaults to the instance's table name.
198 |
199 | Returns:
200 | bool: True if deletion was successful, False otherwise.
201 | """
202 | table = self._open_table(table_name)
203 | if table is None:
204 | return True
205 |
206 | return table.delete(f"metadata.{key} = '{value}'")
207 |
208 | def drop(self, table_name: Optional[str] = None) -> bool:
209 | """
210 | Drops (deletes) the specified memory table.
211 |
212 | Args:
213 | table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name.
214 |
215 | Returns:
216 | bool: True if the table was successfully dropped, False otherwise.
217 | """
218 | table_name = table_name or self.table_name
219 | table = self._open_table(table_name)
220 | if table is None:
221 | return True
222 |
223 | return self.client.drop_table(table_name)
224 |
225 | def count(self, table_name: Optional[str] = None) -> int:
226 | """
227 | Counts the number of records in the specified memory table.
228 |
229 | Args:
230 | table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name.
231 |
232 | Returns:
233 | int: The number of records in the table.
234 | """
235 | table = self._open_table(table_name)
236 | if table is None:
237 | return 0
238 |
239 | return table.count_rows()
240 |
241 | def reset(self) -> None:
242 | """
243 | Resets the memory by dropping the default memory table.
244 | """
245 | self.drop()
246 |
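A short usage sketch for LanceDBMemory (assumes a valid project config so get_config resolves an embedding backend; texts and metadata are illustrative):

    memory = LanceDBMemory(project_path=".")
    ids = memory.add(["use early stopping", "log metrics every epoch"],
                     metadata=[{"topic": "training"}, {"topic": "logging"}])
    hits = memory.query(["how should I monitor training?"], n_results=2)
    print(hits[0][0]["text"])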
--------------------------------------------------------------------------------
/mle/version.py:
--------------------------------------------------------------------------------
1 | # PEP0440 compatible formatted version, see:
2 | # https://www.python.org/dev/peps/pep-0440/
3 | #
4 | # Generic release markers:
5 | # X.Y
6 | # X.Y.Z # For bug fix releases
7 | #
8 | # Admissible pre-release markers:
9 | # X.YaN # Alpha release
10 | # X.YbN # Beta release
11 | # X.YrcN # Release Candidate
12 | # X.Y # Final release
13 | #
14 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
15 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
16 | #
17 |
18 | __version__ = '0.4.2'
19 |
--------------------------------------------------------------------------------
/mle/workflow/__init__.py:
--------------------------------------------------------------------------------
1 | from .chat import chat
2 | from .baseline import baseline
3 | from .report import report, report_local
4 | from .kaggle import kaggle, auto_kaggle
5 |
--------------------------------------------------------------------------------
/mle/workflow/baseline.py:
--------------------------------------------------------------------------------
1 | """
2 | Baseline Mode: the mode to quickly generate the AI baseline based on the user's requirements.
3 | """
4 | import os
5 | import questionary
6 | from rich.console import Console
7 | from mle.model import load_model
8 | from mle.utils import print_in_box, ask_text, WorkflowCache
9 | from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent
10 |
11 |
12 | def ask_data(data_str: str):
13 | """
14 | Ask the user to provide the data information.
15 | :param data_str: the input data string. Now, it should be the name of the public dataset or
16 | the path to the local CSV file.
17 |     :return: the formatted data information.
18 | """
19 | if os.path.isfile(data_str) and data_str.lower().endswith('.csv'):
20 | return f"[green]CSV Dataset Location:[/green] {data_str}"
21 | else:
22 | return f"[green]Dataset:[/green] {data_str}"
23 |
24 |
25 | def baseline(work_dir: str, model=None):
26 | """
27 | The workflow of the baseline mode.
28 | :return:
29 | """
30 |
31 | console = Console()
32 | cache = WorkflowCache(work_dir, 'baseline')
33 | model = load_model(work_dir, model)
34 |
35 | if not cache.is_empty():
36 | step = ask_text(f"MLE has finished the following steps: \n{cache}\n"
37 | f"You can pick a step from 1 to {cache.current_step()} to resume\n"
38 | "(or ENTER to continue the workflow)")
39 | if step:
40 | step = int(step)
41 | for i in range(step, cache.current_step() + 1):
42 | cache.remove(i) # remove the stale step caches
43 |
44 | # ask for the data information
45 | with cache(step=1, name="ask for the data information") as ca:
46 | dataset = ca.resume("dataset")
47 | if dataset is None:
48 | advisor = AdviseAgent(model, console)
49 | dataset = ask_text("Please provide your dataset information (a public dataset name or a local absolute filepath)")
50 | if not dataset:
51 | print_in_box("The dataset is empty. Aborted", console, title="Error", color="red")
52 | return
53 | dataset = advisor.clarify_dataset(dataset)
54 | ca.store("dataset", dataset)
55 |
56 | # ask for the user requirement
57 | with cache(step=2, name="ask for the user requirement") as ca:
58 | ml_requirement = ca.resume("ml_requirement")
59 | if ml_requirement is None:
60 | ml_requirement = ask_text("Please provide your requirement")
61 | if not ml_requirement:
62 | print_in_box("The user's requirement is empty. Aborted", console, title="Error", color="red")
63 | return
64 | ca.store("ml_requirement", ml_requirement)
65 |
66 | # advisor agent gives suggestions in a report
67 | with cache(step=3, name="MLE advisor agent provides a high-level report") as ca:
68 | advisor_report = ca.resume("advisor_report")
69 | if advisor_report is None:
70 | advisor = AdviseAgent(model, console)
71 | advisor_report = advisor.interact("[green]User Requirement:[/green] " + ml_requirement + "\n" + ask_data(dataset))
72 | ca.store("advisor_report", advisor_report)
73 |
74 | # plan agent generates the coding plan
75 | with cache(step=4, name="MLE plan agent generates a dev plan") as ca:
76 | coding_plan = ca.resume("coding_plan")
77 | if coding_plan is None:
78 | planner = PlanAgent(model, console)
79 | coding_plan = planner.interact(advisor_report)
80 | ca.store("coding_plan", coding_plan)
81 |
82 | # code agent codes the tasks and debug with the debug agent
83 | with cache(step=5, name="MLE code&debug agents start to work") as ca:
84 | coder = CodeAgent(model, work_dir, console)
85 | coder.read_requirement(advisor_report)
86 | debugger = DebugAgent(model, console)
87 |
88 | is_auto_mode = questionary.confirm(
89 |         "MLE developer is about to start coding.\n"
90 |         "Should the agent also debug the code it writes? (If no, the MLE agent will only focus on coding tasks,"
91 |         " and you will have to run and debug the code yourself.)"
92 | ).ask()
93 |
94 | for current_task in coding_plan.get('tasks'):
95 | code_report = coder.interact(current_task)
96 | is_debugging = code_report.get('debug')
97 |
98 | if is_auto_mode:
99 | while True:
100 | if is_debugging == 'true' or is_debugging == 'True':
101 | with console.status("MLE Debug Agent is executing and debugging the code..."):
102 | debug_report = debugger.analyze(code_report)
103 | if debug_report.get('status') == 'success':
104 | break
105 | else:
106 | code_report = coder.debug(current_task, debug_report)
107 | else:
108 | break
109 |
--------------------------------------------------------------------------------
/mle/workflow/chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Chat Mode: the mode to have an interactive chat with LLM to work on ML project.
3 | """
4 | import questionary
5 | from rich.live import Live
6 | from rich.panel import Panel
7 | from rich.console import Console
8 | from rich.markdown import Markdown
9 | from mle.model import load_model
10 | from mle.utils import print_in_box, WorkflowCache
11 | from mle.agents import ChatAgent
12 |
13 |
14 | def chat(work_dir: str, memory=None, model=None):
15 | console = Console()
16 | cache = WorkflowCache(work_dir, 'chat')
17 | model = load_model(work_dir, model)
18 | chatbot = ChatAgent(model, memory=memory)
19 |
20 | if not cache.is_empty():
21 | if questionary.confirm(f"Would you like to continue the previous conversation?\n").ask():
22 | chatbot.chat_history = cache.resume_variable("conversation")
23 |
24 | with cache(step=1, name="chat") as ca:
25 | greets = chatbot.greet()
26 | print_in_box(greets, console=console, title="MLE Chatbot", color="magenta")
27 |
28 | while True:
29 | try:
30 | user_pmpt = questionary.text("[Exit/Ctrl+D]: ").ask()
31 | if user_pmpt:
32 | with Live(console=Console()) as live:
33 | for text in chatbot.chat(user_pmpt.strip()):
34 | live.update(
35 | Panel(Markdown(text), title="[bold magenta]MLE-Agent[/]", border_style="magenta"),
36 | refresh=True
37 | )
38 | ca.store("conversation", chatbot.chat_history)
39 | except (KeyboardInterrupt, EOFError):
40 | break
41 |
--------------------------------------------------------------------------------
/mle/workflow/kaggle.py:
--------------------------------------------------------------------------------
1 | """
2 | Kaggle Mode: the mode to generate ML pipeline for kaggle competitions.
3 | """
4 | import os
5 | import questionary
6 | from typing import List
7 | from rich.console import Console
8 |
9 | from mle.model import load_model
10 | from mle.function import execute_command
11 | from mle.integration import KaggleIntegration
12 | from mle.utils import ask_text, read_markdown, is_markdown_file, WorkflowCache, print_in_box
13 | from mle.agents import CodeAgent, DebugAgent, AdviseAgent, PlanAgent, GitHubSummaryAgent
14 |
15 |
16 | def auto_kaggle(
17 | work_dir: str,
18 | datasets: List[str],
19 | description: str,
20 | submission='./submission.csv',
21 | debug_max_attempt=5,
22 | sub_examples=None,
23 | competition_id=None,
24 | model=None
25 | ):
26 | """
27 |     The workflow of the auto Kaggle mode.
28 | :param work_dir: the working directory.
29 | :param datasets: the datasets to use.
30 | :param description: the description of the competition, can be a path to a local .md file or a string.
31 | :param submission: the path of the kaggle submission file.
32 | :param debug_max_attempt: the max attempt for debugging.
33 | :param sub_examples: the path to the kaggle submission example file.
34 | :param competition_id: the competition id.
35 | :param model: the model to use.
36 | """
37 | console = Console()
38 | model = load_model(work_dir, model)
39 |
40 | # initialize the agents
41 | advisor = AdviseAgent(model, console, mode="precise")
42 | summarizer = GitHubSummaryAgent(model, console=console)
43 | coder = CodeAgent(model, work_dir, console=console, single_file=True)
44 | debugger = DebugAgent(model, console, analyze_only=True)
45 |
46 | if is_markdown_file(description):
47 | description = read_markdown(description)
48 |
49 | with console.status("MLE Agent is processing the kaggle competition overview..."):
50 | requirements = summarizer.kaggle_request_summarize(description, sub_examples)
51 | requirements += f"\n\nLOCAL DATASET PATH:\n"
52 | for dataset in datasets:
53 | requirements += f" - {dataset}\n"
54 |
55 | requirements += f"\nSUBMISSION FILE PATH: {submission}\n"
56 |
57 | suggestions = advisor.suggest(requirements, return_raw=True)
58 | requirements += f"""
59 | \nIMPLEMENTATION SUGGESTIONS:
60 |
61 | - Suggestion summary: {suggestions.get('suggestion')}
62 | - ML task: {suggestions.get('task')}
63 | - Model or algorithm: {suggestions.get('model_or_algorithm')}
64 | - Training strategy: {suggestions.get('training_method')}
65 | - Tricks to increase performance:
66 | """
67 | for trick in suggestions.get('tricks'):
68 | requirements += f"\n - {trick}"
69 |
70 | coder.read_requirement(requirements)
71 | if competition_id is None:
72 | competition_id = "kaggle competition"
73 |
74 | coding_task = {
75 | "task": competition_id,
76 | "description": requirements
77 | }
78 | print_in_box(requirements, console, title="Kaggle Competition Requirement", color="green")
79 | code_report = coder.code(coding_task)
80 | debug_attempt = 0
81 | while True:
82 | if debug_attempt > debug_max_attempt:
83 |             console.log(f"Debugging failed after the maximum of {debug_max_attempt} attempts. Please check the code manually.")
84 | break
85 |
86 | with console.status("MLE Debug Agent is executing and debugging the code..."):
87 | running_cmd = code_report.get('command')
88 | logs = execute_command(running_cmd)
89 | debug_report = debugger.analyze_with_log(running_cmd, logs)
90 | if debug_report.get('status') == 'success':
91 | # check the submission file
92 | if not os.path.exists(submission):
93 |                     console.log(f"The submission file ({submission}) was not found. Launching the coder to fix it...")
94 | code_report = coder.debug(
95 | coding_task,
96 | {
97 | "status": "error",
98 | "changes": [
99 | f"make sure the submission file is generated in {submission}",
100 | f"make sure the submission file is in the correct format. You can refer to the example submission file: {sub_examples}"
101 | ],
102 | "suggestion": f"Please update the code related to generating the submission file."
103 | }
104 | )
105 | else:
106 | break
107 | else:
108 | debug_attempt += 1
109 | code_report = coder.debug(coding_task, debug_report)
110 |
111 |
112 | def kaggle(work_dir: str, model=None):
113 | """
114 | The workflow of the kaggle mode.
115 | :param work_dir: the working directory.
116 | :param model: the model to use.
117 | """
118 | console = Console()
119 | cache = WorkflowCache(work_dir, 'kaggle')
120 | model = load_model(work_dir, model)
121 | integration = KaggleIntegration()
122 |
123 | if not cache.is_empty():
124 | step = ask_text(f"MLE has finished the following steps: \n{cache}\n"
125 | f"You can pick a step from 1 to {cache.current_step()} to resume\n"
126 | "(or ENTER to continue the workflow)")
127 | if step:
128 | step = int(step)
129 | for i in range(step, cache.current_step() + 1):
130 | cache.remove(i) # remove the stale step caches
131 |
132 | # ask for the kaggle competition
133 | with cache(step=1, name="ask for the kaggle competition") as ca:
134 | competition = ca.resume("competition")
135 | dataset = ca.resume("dataset")
136 | if competition is None or dataset is None:
137 | competition = questionary.select(
138 | "Please select a Kaggle competition to join:",
139 | choices=integration.list_competition()
140 | ).ask()
141 | with console.status("MLE Agent is downloading the kaggle competition dataset..."):
142 | dataset = integration.download_competition_dataset(
143 | competition, os.path.join(os.getcwd(), 'data'))
144 | ca.store("competition", competition)
145 | ca.store("dataset", dataset)
146 |
147 | # ask for the user requirement
148 | with cache(step=2, name="get the competition overview from kaggle") as ca:
149 | ml_requirement = ca.resume("ml_requirement")
150 | if ml_requirement is None:
151 | with console.status("MLE Agent is fetching the kaggle competition overview..."):
152 | summary = GitHubSummaryAgent(model, console=console)
153 | ml_requirement = summary.kaggle_request_summarize(integration.fetch_competition_overview(competition))
154 | ca.store("ml_requirement", ml_requirement)
155 |
156 | # advisor agent gives suggestions in a report
157 | with cache(step=3, name="MLE advisor agent provides a high-level report") as ca:
158 | advisor_report = ca.resume("advisor_report")
159 | if advisor_report is None:
160 | advisor = AdviseAgent(model, console)
161 | advisor_report = advisor.interact(
162 | f"[green]Competition Requirement:[/green] {ml_requirement}\n"
163 | f"Dataset is downloaded in path: {dataset}"
164 | )
165 | ca.store("advisor_report", advisor_report)
166 |
167 | # plan agent generates the coding plan
168 | with cache(step=4, name="MLE plan agent generates a dev plan") as ca:
169 | coding_plan = ca.resume("coding_plan")
170 | if coding_plan is None:
171 | planner = PlanAgent(model, console)
172 | coding_plan = planner.interact(advisor_report)
173 | ca.store("coding_plan", coding_plan)
174 |
175 | # code agent codes the tasks and debug with the debug agent
176 | with cache(step=5, name="MLE code&debug agents start to work") as ca:
177 | coder = CodeAgent(model, work_dir, console)
178 | coder.read_requirement(advisor_report)
179 | debugger = DebugAgent(model, console)
180 |
181 | is_auto_mode = questionary.confirm(
182 |         "MLE developer is about to start coding.\n"
183 |         "Should the agent also debug the code it writes? (If no, the MLE agent will only focus on coding tasks,"
184 |         " and you will have to run and debug the code yourself.)"
185 | ).ask()
186 |
187 | for current_task in coding_plan.get('tasks'):
188 | code_report = coder.interact(current_task)
189 | is_debugging = code_report.get('debug')
190 |
191 | if is_auto_mode:
192 | while True:
193 |                 if is_debugging in ('true', 'True'):
194 | with console.status("MLE Debug Agent is executing and debugging the code..."):
195 | debug_report = debugger.analyze(code_report)
196 | if debug_report.get('status') == 'success':
197 | break
198 | else:
199 | code_report = coder.debug(current_task, debug_report)
200 | else:
201 | break
202 |
--------------------------------------------------------------------------------
/mle/workflow/report.py:
--------------------------------------------------------------------------------
1 | """
2 | Report Mode: the mode to generate the AI report based on the user's requirements.
3 | """
4 | import os
5 | import pickle
6 | from rich.console import Console
7 | from mle.model import load_model
8 | from mle.utils.system import get_config, write_config, check_config
9 | from mle.integration import GoogleCalendarIntegration, github_login
10 | from mle.agents import GitHubSummaryAgent, ReportAgent, GitSummaryAgent
11 |
12 |
13 | def ask_data(data_str: str):
14 | """
15 | Ask the user to provide the data information.
16 |     :param data_str: the input data string, either the name of a public dataset or
17 |      the path to a local CSV file.
18 |     :return: the formatted data information.
19 | """
20 | if os.path.isfile(data_str) and data_str.lower().endswith('.csv'):
21 | return f"[green]CSV Dataset Location:[/green] {data_str}"
22 | else:
23 | return f"[green]Dataset:[/green] {data_str}"
24 |
25 |
26 | def report(
27 | work_dir: str,
28 | github_repo: str,
29 | github_username: str,
30 | github_token: str = None,
31 | okr_str: str = None,
32 | model=None
33 | ):
34 | """
35 |     The workflow of the report mode, based on a GitHub repository.
36 | :param work_dir: the working directory.
37 | :param github_repo: the GitHub repository.
38 | :param github_username: the GitHub username.
39 | :param github_token: the GitHub token.
40 | :param okr_str: the OKR string.
41 | :param model: the model to use.
42 |     :return: the generated report.
43 | """
44 | console = Console()
45 | model = load_model(work_dir, model)
46 |
47 | events = None
48 | if check_config(console):
49 | config = get_config()
50 | if github_token is None:
51 | if "github" in config.get("integration", {}).keys():
52 | github_token = config["integration"]["github"].get("token")
53 | else:
54 | github_token = github_login()
55 |                 config.setdefault("integration", {})["github"] = {"token": github_token}  # ensure the integration section exists
56 | write_config(config)
57 |
58 | if "google_calendar" in config.get("integration", {}).keys():
59 | google_token = pickle.loads(config["integration"]["google_calendar"].get("token"))
60 | google_calendar = GoogleCalendarIntegration(google_token)
61 | events = google_calendar.get_events()
62 |
63 | summarizer = GitHubSummaryAgent(
64 | model,
65 | github_repo=github_repo,
66 | username=github_username,
67 | github_token=github_token,
68 | )
69 | reporter = ReportAgent(model, console)
70 |
71 | github_summary = summarizer.summarize()
72 | return reporter.gen_report(github_summary, events, okr=okr_str)
73 |
74 |
75 | def report_local(
76 | work_dir: str,
77 | git_path: str,
78 | email: str,
79 | okr_str: str = None,
80 | start_date: str = None,
81 | end_date: str = None,
82 | model=None
83 | ):
84 | """
85 |     The workflow of the report mode, based on a local Git repository.
86 | :param work_dir: the working directory.
87 | :param git_path: the path to the local Git repository.
88 | :param email: the email address.
89 | :param okr_str: the OKR string.
90 | :param start_date: the start date.
91 | :param end_date: the end date.
92 | :param model: the model to use.
93 |     :return: the generated report.
94 | """
95 |
96 | console = Console()
97 | model = load_model(work_dir, model)
98 |
99 | events = None
100 |
101 | summarizer = GitSummaryAgent(
102 | model,
103 | git_path=git_path,
104 | git_email=email,
105 | )
106 | reporter = ReportAgent(model, console)
107 |
108 | git_summary = summarizer.summarize(start_date=start_date, end_date=end_date)
109 | return reporter.gen_report(git_summary, events, okr=okr_str)
110 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | rich
2 | click
3 | py7zr~=0.22.0
4 | openai
5 | pyyaml
6 | kaggle
7 | fastapi
8 | uvicorn
9 | requests
10 | GitPython
11 | tree-sitter==0.21.3
12 | onnxruntime
13 | questionary
14 | pandas~=2.2.2
15 | tavily-python
16 | instructor
17 | langfuse
18 | setuptools
19 | numexpr~=2.10.1
20 | bottleneck~=1.4.0
21 | google-api-python-client~=2.143.0
22 | google-auth-httplib2~=0.2.0
23 | google-auth-oauthlib~=1.2.1
24 | lancedb~=0.15.0
25 | tantivy~=0.22.0
26 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
4 | [flake8]
5 | select = W291
6 | max-line-length = 256
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import find_packages, setup
3 |
4 | # read the contents of README file
5 | from os import path
6 | from io import open # for Python 2 and 3 compatibility
7 |
8 | # get __version__ from version.py
9 | ver_file = path.join('mle', 'version.py')
10 | with open(ver_file) as f:
11 | exec(f.read())
12 |
13 | this_directory = path.abspath(path.dirname(__file__))
14 |
15 |
16 | # read the contents of README.md
17 | def readme():
18 | with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
19 | return f.read()
20 |
21 |
22 | # read the contents of requirements.txt
23 | with open(path.join(this_directory, 'requirements.txt'), encoding='utf-8') as f:
24 | requirements = f.read().splitlines()
25 |
26 | setup(
27 | name='mle-agent',
28 | version=__version__,
29 | description='MLE-agent: An agent to automate your MLE processes',
30 | long_description=readme(),
31 | long_description_content_type='text/markdown',
32 | author='Yizheng Huang, Huaizheng Zhang',
33 | author_email='huangyz0918@gmail.com',
34 | url='https://github.com/MLSysOps/MLE-agent',
35 | download_url='https://github.com/MLSysOps/MLE-agent/archive/refs/heads/main.zip',
36 | keywords=['LLM', 'deep learning', 'MLOps', 'shell', 'neural networks'],
37 | packages=find_packages(),
38 | entry_points={
39 | "console_scripts": [
40 | "mle-agent=mle.cli:cli",
41 | "mle=mle.cli:cli",
42 | ]
43 | },
44 | zip_safe=False,
45 | include_package_data=True,
46 | install_requires=requirements,
47 | setup_requires=['setuptools>=38.6.0'],
48 | classifiers=[
49 | 'Development Status :: 5 - Production/Stable',
50 | 'Intended Audience :: Education',
51 | 'Intended Audience :: Financial and Insurance Industry',
52 | 'Intended Audience :: Science/Research',
53 | 'Intended Audience :: Developers',
54 | 'Intended Audience :: Information Technology',
55 | 'License :: OSI Approved :: Apache Software License',
56 | 'Programming Language :: Python :: 3',
57 | "Operating System :: OS Independent",
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/tests/__init__.py
--------------------------------------------------------------------------------
/web/.eslintrc.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": ["next/core-web-vitals", "next/typescript"]
3 | }
4 |
--------------------------------------------------------------------------------
/web/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 | .yarn/install-state.gz
8 |
9 | # testing
10 | /coverage
11 |
12 | # next.js
13 | /.next/
14 | /out/
15 |
16 | # production
17 | /build
18 |
19 | # misc
20 | .DS_Store
21 | *.pem
22 |
23 | # debug
24 | npm-debug.log*
25 | yarn-debug.log*
26 | yarn-error.log*
27 |
28 | # local env files
29 | .env*.local
30 |
31 | # vercel
32 | .vercel
33 |
34 | # typescript
35 | *.tsbuildinfo
36 | next-env.d.ts
37 |
--------------------------------------------------------------------------------
/web/README.md:
--------------------------------------------------------------------------------
1 | This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
2 |
3 | ## Getting Started
4 |
5 | First, run the development server:
6 |
7 | ```bash
8 | npm run dev
9 | # or
10 | yarn dev
11 | # or
12 | pnpm dev
13 | # or
14 | bun dev
15 | ```
16 |
17 | Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
18 |
19 | You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
20 |
21 | This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
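
If you want to see how the fonts are wired up, here is a minimal sketch mirroring the setup in this repo's `app/layout.tsx` (excerpted; the font files live under `app/fonts/`):

```tsx
// app/layout.tsx (excerpt): register the local Geist font files and expose
// them as CSS variables that the global CSS / Tailwind theme can reference.
import localFont from "next/font/local";

const geistSans = localFont({
  src: "./fonts/GeistVF.woff",
  variable: "--font-geist-sans",
  weight: "100 900",
});
const geistMono = localFont({
  src: "./fonts/GeistMonoVF.woff",
  variable: "--font-geist-mono",
  weight: "100 900",
});
```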
22 |
23 | ## Learn More
24 |
25 | To learn more about Next.js, take a look at the following resources:
26 |
27 | - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
28 | - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
29 |
30 | You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
31 |
32 | ## Deploy on Vercel
33 |
34 | The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
35 |
36 | Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
37 |
--------------------------------------------------------------------------------
/web/app/fonts/GeistMonoVF.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/app/fonts/GeistMonoVF.woff
--------------------------------------------------------------------------------
/web/app/fonts/GeistVF.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/app/fonts/GeistVF.woff
--------------------------------------------------------------------------------
/web/app/globals.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
5 | :root {
6 | --background: #ffffff;
7 | --foreground: #171717;
8 | }
9 |
10 | @media (prefers-color-scheme: dark) {
11 | :root {
12 | --background: #0a0a0a;
13 | --foreground: #ededed;
14 | }
15 | }
16 |
17 | body {
18 | color: var(--foreground);
19 | background: var(--background);
20 | font-family: Arial, Helvetica, sans-serif;
21 | }
22 |
23 | @layer utilities {
24 | .text-balance {
25 | text-wrap: balance;
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/web/app/layout.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 |
3 | import React, { useState, useEffect } from "react";
4 | import localFont from "next/font/local";
5 | import "./globals.css";
6 |
7 | const geistSans = localFont({
8 | src: "./fonts/GeistVF.woff",
9 | variable: "--font-geist-sans",
10 | weight: "100 900",
11 | });
12 | const geistMono = localFont({
13 | src: "./fonts/GeistMonoVF.woff",
14 | variable: "--font-geist-mono",
15 | weight: "100 900",
16 | });
17 |
18 | export default function RootLayout({
19 | children,
20 | }: Readonly<{
21 | children: React.ReactNode;
22 | }>) {
23 | const [isLoading, setIsLoading] = useState(true);
24 |
25 | useEffect(() => {
26 | const handleLoad = () => {
27 | setIsLoading(false);
28 | };
29 |
30 | if (document.readyState === "complete") {
31 | handleLoad();
32 | } else {
33 | window.addEventListener("load", handleLoad);
34 | return () => {
35 | window.removeEventListener("load", handleLoad);
36 | };
37 | }
38 | }, []);
39 |
40 | return (
41 |     <html lang="en">
42 |       <body
43 |         className={`${geistSans.variable} ${geistMono.variable} antialiased`}
44 |       >
45 |         {!isLoading && children}
46 |       </body>
47 |     </html>
48 |   );
49 | }
50 |
--------------------------------------------------------------------------------
/web/app/page.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import React, { useState, useEffect } from 'react';
3 | import { Layout, Card, Input, Button, message, Form, Select, Spin, Flex, Row, Col } from 'antd';
4 | import { HomeOutlined, BookOutlined, MailOutlined, GithubOutlined, XOutlined, DiscordOutlined} from '@ant-design/icons';
5 | import dynamic from "next/dynamic";
6 |
7 | const MDEditor = dynamic(
8 | () => import("@uiw/react-md-editor"),
9 | { ssr: false }
10 | );
11 | const { Content, Header } = Layout;
12 | const { TextArea } = Input;
13 | const { Option } = Select;
14 |
15 | interface ReportData {
16 | project_okr: string;
17 | business_goal: string[];
18 | dev_progress: string[];
19 | communicate_progress: string[];
20 | dev_todo: { task: string; description: string; priority: string }[];
21 | communicate_todo: { task: string; priority: string }[];
22 | hard_parts: string[];
23 | require_manager_help: string[];
24 | suggestions_to_user: string[];
25 | reference: { title: string; link: string;}[];
26 | }
27 |
28 | interface ReportRequest {
29 | repo: string;
30 | username: string;
31 | okr?: string;
32 | dateRange?: string;
33 | recurringReports?: string;
34 | additionalSources?: string[];
35 | }
36 |
37 | export default function Home() {
38 | const [reportContent, setReportContent] = useState("");
39 |   const [reportData, setReportData] = useState<ReportData | null>(null);
40 | const [form] = Form.useForm();
41 | const [loading, setLoading] = useState(false);
42 |
43 | useEffect(() => {
44 | fetchLatestReport();
45 | form.setFieldsValue({ recurringReports: 'weekly' });
46 | }, [form]);
47 |
48 | const fetchLatestReport = async () => {
49 | try {
50 | const response = await fetch('http://localhost:8000/latest_report');
51 | if (response.status === 404) {
52 | setReportContent("No report has been found. Please click 'Generate Report' to create one.");
53 | setReportData(null);
54 | return;
55 | }
56 | if (!response.ok) {
57 | throw new Error('Failed to fetch the latest report');
58 | }
59 | const data: ReportData = await response.json();
60 | setReportData(data);
61 | const markdownContent = convertToMarkdown(data);
62 | setReportContent(markdownContent);
63 | } catch (error) {
64 | console.error('Error fetching latest report:', error);
65 | message.error('Failed to fetch the latest report');
66 | }
67 | };
68 |
69 | const convertToMarkdown = (data: ReportData): string => {
70 | let markdown = '';
71 | markdown += `## Project Report\n\n`;
72 |
73 | if (data.project_okr) {
74 | markdown += `### Project OKR\n${data.project_okr}\n\n`;
75 | }
76 |
77 | markdown += `### Business Goal\n${data.business_goal.map(goal => `- ${goal}`).join('\n')}\n\n`;
78 | markdown += `### Work Finished This Week\n\n`;
79 | markdown += `#### Development Progress\n${data.dev_progress.map(progress => `- ${progress}`).join('\n')}\n\n`;
80 | markdown += `#### Communication/Design Progress\n${data.communicate_progress.map(progress => `- ${progress}`).join('\n')}\n\n`;
81 | markdown += `### Work TODOs in the Next Week\n\n`;
82 | markdown += `#### Development TODOs\n${data.dev_todo.map(todo => `- ${todo.task} **(${todo.priority})**: ${todo.description}`).join('\n')}\n\n`;
83 | markdown += `#### Communication TODOs\n${data.communicate_todo.map(todo => `- ${todo.task} **(${todo.priority})**`).join('\n')}\n\n`;
84 | markdown += `### Hard Problems\n\n`;
85 | markdown += `#### Challenges\n${data.hard_parts.map(part => `- ${part}`).join('\n')}\n\n`;
86 | markdown += `#### Manager Help Required\n${data.require_manager_help.map(help => `- ${help}`).join('\n')}\n\n`;
87 | markdown += `### Other Progress and Thoughts\n\n`;
88 | markdown += `#### Suggestions\n${data.suggestions_to_user.map(suggestion => `- ${suggestion}`).join('\n')}\n\n`;
89 | markdown += `#### References\n${data.reference.map(ref => `- [${ref.title}](${ref.link})`).join('\n')}\n\n`;
90 | return markdown;
91 | };
92 |
93 | const handleGenerateReport = async (values: ReportRequest) => {
94 | setLoading(true);
95 | try {
96 | const response = await fetch('http://localhost:8000/gen_report', {
97 | method: 'POST',
98 | headers: {
99 | 'Content-Type': 'application/json',
100 | },
101 | body: JSON.stringify(values),
102 | });
103 |
104 | if (!response.ok) {
105 | throw new Error('Failed to generate report');
106 | }
107 |
108 | const result = await response.json();
109 |
110 | if (result.result) {
111 | setReportData(result.result);
112 | const markdownContent = convertToMarkdown(result.result);
113 | setReportContent(markdownContent);
114 | message.success('Report generated successfully');
115 | } else {
116 | message.info('Report generation completed, but no data returned');
117 | }
118 | } catch (error) {
119 | console.error('Error generating report:', error);
120 | message.error('Failed to generate report');
121 | } finally {
122 | setLoading(false);
123 | }
124 | };
125 |
126 | const handleSaveReport = () => {
127 | message.info('Save report functionality not implemented yet');
128 | };
129 |
130 | return (
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 | } href="https://repx.app/">
142 | Home
143 |
144 | } href="https://docs.repx.app/">
145 | Docs
146 |
147 | } href="https://discord.gg/xHW3Yz4x">
148 | Feedback
149 |
150 | } href="https://github.com/MLSysOps/MLE-agent"/>
151 | } href="https://x.com/zhzHNN/"/>
152 |
153 |
154 |
155 |
156 |
157 |
251 |
252 |
253 | );
254 | }
--------------------------------------------------------------------------------
/web/next.config.mjs:
--------------------------------------------------------------------------------
1 | /** @type {import('next').NextConfig} */
2 | const nextConfig = {};
3 |
4 | export default nextConfig;
5 |
--------------------------------------------------------------------------------
/web/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "web",
3 | "version": "0.1.0",
4 | "private": true,
5 | "scripts": {
6 | "dev": "next dev",
7 | "build": "next build",
8 | "start": "next start",
9 | "lint": "next lint"
10 | },
11 | "dependencies": {
12 | "@uiw/react-md-editor": "v3.6.0",
13 | "antd": "^5.20.5",
14 | "next": "14.2.10",
15 | "react": "^18",
16 | "react-dom": "^18"
17 | },
18 | "devDependencies": {
19 | "@types/node": "^20",
20 | "@types/react": "^18",
21 | "@types/react-dom": "^18",
22 | "eslint": "^8",
23 | "eslint-config-next": "14.2.8",
24 | "postcss": "^8",
25 | "tailwindcss": "^3.4.1",
26 | "typescript": "^5"
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/web/postcss.config.mjs:
--------------------------------------------------------------------------------
1 | /** @type {import('postcss-load-config').Config} */
2 | const config = {
3 | plugins: {
4 | tailwindcss: {},
5 | },
6 | };
7 |
8 | export default config;
9 |
--------------------------------------------------------------------------------
/web/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/public/favicon.ico
--------------------------------------------------------------------------------
/web/public/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLSysOps/MLE-agent/447834c4c0db42b68dd0be4e7a811440d215e37e/web/public/logo.png
--------------------------------------------------------------------------------
/web/tailwind.config.ts:
--------------------------------------------------------------------------------
1 | import type { Config } from "tailwindcss";
2 |
3 | const config: Config = {
4 | content: [
5 | "./pages/**/*.{js,ts,jsx,tsx,mdx}",
6 | "./components/**/*.{js,ts,jsx,tsx,mdx}",
7 | "./app/**/*.{js,ts,jsx,tsx,mdx}",
8 | ],
9 | theme: {
10 | extend: {
11 | colors: {
12 | background: "var(--background)",
13 | foreground: "var(--foreground)",
14 | },
15 | },
16 | },
17 | plugins: [],
18 | };
19 | export default config;
20 |
--------------------------------------------------------------------------------
/web/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "lib": ["dom", "dom.iterable", "esnext"],
4 | "allowJs": true,
5 | "skipLibCheck": true,
6 | "strict": true,
7 | "noEmit": true,
8 | "esModuleInterop": true,
9 | "module": "esnext",
10 | "moduleResolution": "bundler",
11 | "resolveJsonModule": true,
12 | "isolatedModules": true,
13 | "jsx": "preserve",
14 | "incremental": true,
15 | "plugins": [
16 | {
17 | "name": "next"
18 | }
19 | ],
20 | "paths": {
21 | "@/*": ["./*"]
22 | }
23 | },
24 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
25 | "exclude": ["node_modules"]
26 | }
27 |
--------------------------------------------------------------------------------