├── .dockerignore ├── .env.example ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docker-compose.yml ├── poetry.lock ├── pyproject.toml ├── run.bat ├── run.sh └── src ├── agents ├── ben_graham.py ├── bill_ackman.py ├── cathie_wood.py ├── charlie_munger.py ├── fundamentals.py ├── michael_burry.py ├── peter_lynch.py ├── phil_fisher.py ├── portfolio_manager.py ├── risk_manager.py ├── sentiment.py ├── stanley_druckenmiller.py ├── technicals.py ├── valuation.py └── warren_buffett.py ├── backtester.py ├── data ├── cache.py └── models.py ├── graph └── state.py ├── llm └── models.py ├── main.py ├── tools └── api.py └── utils ├── __init__.py ├── analysts.py ├── display.py ├── docker.py ├── llm.py ├── ollama.py ├── progress.py └── visualize.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 2 | .git 3 | .gitignore 4 | 5 | # Poetry 6 | .venv 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | .pytest_cache/ 11 | 12 | # Environment 13 | .env 14 | 15 | # IDEs and editors 16 | .idea/ 17 | .vscode/ 18 | *.swp 19 | *.swo 20 | 21 | # Logs and data 22 | logs/ 23 | data/ 24 | *.log 25 | 26 | # OS specific 27 | .DS_Store 28 | Thumbs.db -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # For running LLMs hosted by anthropic (claude-3-5-sonnet, claude-3-opus, claude-3-5-haiku) 2 | # Get your Anthropic API key from https://anthropic.com/ 3 | ANTHROPIC_API_KEY=your-anthropic-api-key 4 | # For running LLMs hosted by deepseek (deepseek-chat, deepseek-reasoner, etc.) 5 | # Get your DeepSeek API key from https://deepseek.com/ 6 | DEEPSEEK_API_KEY=your-deepseek-api-key 7 | 8 | # For running LLMs hosted by groq (deepseek, llama3, etc.) 
9 | # Get your Groq API key from https://groq.com/ 10 | GROQ_API_KEY=your-groq-api-key 11 | 12 | # For running LLMs hosted by gemini (gemini-2.0-flash, gemini-2.0-pro) 13 | # Get your Google API key from https://console.cloud.google.com/ 14 | GOOGLE_API_KEY=your-google-api-key 15 | # For getting financial data to power the hedge fund 16 | # Get your Financial Datasets API key from https://financialdatasets.ai/ 17 | FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key 18 | # For running LLMs hosted by openai (gpt-4o, gpt-4o-mini, etc.) 19 | # Get your OpenAI API key from https://platform.openai.com/ 20 | OPENAI_API_KEY=your-openai-api-key -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Screenshot** 14 | Add a screenshot of the bug to help explain your problem. 15 | 16 | **Additional context** 17 | Add any other context about the problem here. 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the feature you'd like** 11 | A clear and concise description of what you want to happen. 
12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | env/ 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # Virtual Environment 25 | venv/ 26 | ENV/ 27 | 28 | # Environment Variables 29 | .env 30 | 31 | # IDE 32 | .idea/ 33 | .vscode/ 34 | *.swp 35 | *.swo 36 | .cursorrules 37 | .cursorignore 38 | .cursorindexingignore 39 | 40 | # OS 41 | .DS_Store 42 | Thumbs.db 43 | 44 | # graph 45 | *.png 46 | 47 | # Txt files 48 | *.txt 49 | 50 | # PDF files 51 | *.pdf 52 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /app 4 | 5 | # Install Poetry 6 | RUN pip install poetry==1.7.1 7 | 8 | # Copy only dependency files first for better caching 9 | COPY pyproject.toml poetry.lock* /app/ 10 | 11 | # Configure Poetry to not use a virtual environment 12 | RUN poetry config virtualenvs.create false \ 13 | && poetry install --no-interaction --no-ansi 14 | 15 | # Copy rest of the source code 16 | COPY . 
/app/ 17 | 18 | # Default command (will be overridden by Docker Compose) 19 | CMD ["python", "src/main.py"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Virat Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Hedge Fund 2 | 3 | This is a proof of concept for an AI-powered hedge fund. The goal of this project is to explore the use of AI to make trading decisions. This project is for **educational** purposes only and is not intended for real trading or investment. 4 | 5 | This system employs several agents working together: 6 | 7 | 1. 
Ben Graham Agent - The godfather of value investing, only buys hidden gems with a margin of safety 8 | 2. Bill Ackman Agent - An activist investor, takes bold positions and pushes for change 9 | 3. Cathie Wood Agent - The queen of growth investing, believes in the power of innovation and disruption 10 | 4. Charlie Munger Agent - Warren Buffett's partner, only buys wonderful businesses at fair prices 11 | 5. Michael Burry Agent - The Big Short contrarian who hunts for deep value 12 | 6. Peter Lynch Agent - Practical investor who seeks "ten-baggers" in everyday businesses 13 | 7. Phil Fisher Agent - Meticulous growth investor who uses deep "scuttlebutt" research 14 | 8. Stanley Druckenmiller Agent - Macro legend who hunts for asymmetric opportunities with growth potential 15 | 9. Warren Buffett Agent - The oracle of Omaha, seeks wonderful companies at a fair price 16 | 10. Valuation Agent - Calculates the intrinsic value of a stock and generates trading signals 17 | 11. Sentiment Agent - Analyzes market sentiment and generates trading signals 18 | 12. Fundamentals Agent - Analyzes fundamental data and generates trading signals 19 | 13. Technicals Agent - Analyzes technical indicators and generates trading signals 20 | 14. Risk Manager - Calculates risk metrics and sets position limits 21 | 15. Portfolio Manager - Makes final trading decisions and generates orders 22 | 23 | Screenshot 2025-03-22 at 6 19 07 PM 24 | 25 | 26 | **Note**: the system simulates trading decisions; it does not actually trade. 27 | 28 | [![Twitter Follow](https://img.shields.io/twitter/follow/virattt?style=social)](https://twitter.com/virattt) 29 | 30 | ## Disclaimer 31 | 32 | This project is for **educational and research purposes only**. 
33 | 34 | - Not intended for real trading or investment 35 | - No warranties or guarantees provided 36 | - Past performance does not indicate future results 37 | - Creator assumes no liability for financial losses 38 | - Consult a financial advisor for investment decisions 39 | 40 | By using this software, you agree to use it solely for learning purposes. 41 | 42 | ## Table of Contents 43 | - [Setup](#setup) 44 | - [Using Poetry](#using-poetry) 45 | - [Using Docker](#using-docker) 46 | - [Usage](#usage) 47 | - [Running the Hedge Fund](#running-the-hedge-fund) 48 | - [Running the Backtester](#running-the-backtester) 49 | - [Project Structure](#project-structure) 50 | - [Contributing](#contributing) 51 | - [Feature Requests](#feature-requests) 52 | - [License](#license) 53 | 54 | ## Setup 55 | 56 | ### Using Poetry 57 | 58 | Clone the repository: 59 | ```bash 60 | git clone https://github.com/virattt/ai-hedge-fund.git 61 | cd ai-hedge-fund 62 | ``` 63 | 64 | 1. Install Poetry (if not already installed): 65 | ```bash 66 | curl -sSL https://install.python-poetry.org | python3 - 67 | ``` 68 | 69 | 2. Install dependencies: 70 | ```bash 71 | poetry install 72 | ``` 73 | 74 | 3. Set up your environment variables: 75 | ```bash 76 | # Create .env file for your API keys 77 | cp .env.example .env 78 | ``` 79 | 80 | 4. Set your API keys: 81 | ```bash 82 | # For running LLMs hosted by openai (gpt-4o, gpt-4o-mini, etc.) 83 | # Get your OpenAI API key from https://platform.openai.com/ 84 | OPENAI_API_KEY=your-openai-api-key 85 | 86 | # For running LLMs hosted by groq (deepseek, llama3, etc.) 87 | # Get your Groq API key from https://groq.com/ 88 | GROQ_API_KEY=your-groq-api-key 89 | 90 | # For getting financial data to power the hedge fund 91 | # Get your Financial Datasets API key from https://financialdatasets.ai/ 92 | FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key 93 | ``` 94 | 95 | ### Using Docker 96 | 97 | 1. Make sure you have Docker installed on your system. 
If not, you can download it from [Docker's official website](https://www.docker.com/get-started). 98 | 99 | 2. Clone the repository: 100 | ```bash 101 | git clone https://github.com/virattt/ai-hedge-fund.git 102 | cd ai-hedge-fund 103 | ``` 104 | 105 | 3. Set up your environment variables: 106 | ```bash 107 | # Create .env file for your API keys 108 | cp .env.example .env 109 | ``` 110 | 111 | 4. Edit the .env file to add your API keys as described above. 112 | 113 | 5. Build the Docker image: 114 | ```bash 115 | # On Linux/Mac: 116 | ./run.sh build 117 | 118 | # On Windows: 119 | run.bat build 120 | ``` 121 | 122 | **Important**: You must set `OPENAI_API_KEY`, `GROQ_API_KEY`, `ANTHROPIC_API_KEY`, or `DEEPSEEK_API_KEY` for the hedge fund to work. If you want to use LLMs from all providers, you will need to set all API keys. 123 | 124 | Financial data for AAPL, GOOGL, MSFT, NVDA, and TSLA is free and does not require an API key. 125 | 126 | For any other ticker, you will need to set the `FINANCIAL_DATASETS_API_KEY` in the .env file. 127 | 128 | ## Usage 129 | 130 | ### Running the Hedge Fund 131 | 132 | #### With Poetry 133 | ```bash 134 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA 135 | ``` 136 | 137 | #### With Docker 138 | ```bash 139 | # On Linux/Mac: 140 | ./run.sh --ticker AAPL,MSFT,NVDA main 141 | 142 | # On Windows: 143 | run.bat --ticker AAPL,MSFT,NVDA main 144 | ``` 145 | 146 | **Example Output:** 147 | Screenshot 2025-01-06 at 5 50 17 PM 148 | 149 | You can also specify a `--ollama` flag to run the AI hedge fund using local LLMs. 150 | 151 | ```bash 152 | # With Poetry: 153 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --ollama 154 | 155 | # With Docker (on Linux/Mac): 156 | ./run.sh --ticker AAPL,MSFT,NVDA --ollama main 157 | 158 | # With Docker (on Windows): 159 | run.bat --ticker AAPL,MSFT,NVDA --ollama main 160 | ``` 161 | 162 | You can also specify a `--show-reasoning` flag to print the reasoning of each agent to the console. 
163 | 164 | ```bash 165 | # With Poetry: 166 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --show-reasoning 167 | 168 | # With Docker (on Linux/Mac): 169 | ./run.sh --ticker AAPL,MSFT,NVDA --show-reasoning main 170 | 171 | # With Docker (on Windows): 172 | run.bat --ticker AAPL,MSFT,NVDA --show-reasoning main 173 | ``` 174 | 175 | You can optionally specify the start and end dates to make decisions for a specific time period. 176 | 177 | ```bash 178 | # With Poetry: 179 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 180 | 181 | # With Docker (on Linux/Mac): 182 | ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 main 183 | 184 | # With Docker (on Windows): 185 | run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 main 186 | ``` 187 | 188 | ### Running the Backtester 189 | 190 | #### With Poetry 191 | ```bash 192 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA 193 | ``` 194 | 195 | #### With Docker 196 | ```bash 197 | # On Linux/Mac: 198 | ./run.sh --ticker AAPL,MSFT,NVDA backtest 199 | 200 | # On Windows: 201 | run.bat --ticker AAPL,MSFT,NVDA backtest 202 | ``` 203 | 204 | **Example Output:** 205 | Screenshot 2025-01-06 at 5 47 52 PM 206 | 207 | 208 | You can optionally specify the start and end dates to backtest over a specific time period. 209 | 210 | ```bash 211 | # With Poetry: 212 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 213 | 214 | # With Docker (on Linux/Mac): 215 | ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest 216 | 217 | # With Docker (on Windows): 218 | run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest 219 | ``` 220 | 221 | You can also specify a `--ollama` flag to run the backtester using local LLMs. 
222 | ```bash 223 | # With Poetry: 224 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA --ollama 225 | 226 | # With Docker (on Linux/Mac): 227 | ./run.sh --ticker AAPL,MSFT,NVDA --ollama backtest 228 | 229 | # With Docker (on Windows): 230 | run.bat --ticker AAPL,MSFT,NVDA --ollama backtest 231 | ``` 232 | 233 | 234 | ## Project Structure 235 | ``` 236 | ai-hedge-fund/ 237 | ├── src/ 238 | │ ├── agents/ # Agent definitions and workflow 239 | │ │ ├── bill_ackman.py # Bill Ackman agent 240 | │ │ ├── fundamentals.py # Fundamental analysis agent 241 | │ │ ├── portfolio_manager.py # Portfolio management agent 242 | │ │ ├── risk_manager.py # Risk management agent 243 | │ │ ├── sentiment.py # Sentiment analysis agent 244 | │ │ ├── technicals.py # Technical analysis agent 245 | │ │ ├── valuation.py # Valuation analysis agent 246 | │ │ ├── ... # Other agents 247 | │ │ ├── warren_buffett.py # Warren Buffett agent 248 | │ ├── tools/ # Agent tools 249 | │ │ ├── api.py # API tools 250 | │ ├── backtester.py # Backtesting tools 251 | │ ├── main.py # Main entry point 252 | ├── pyproject.toml 253 | ├── ... 254 | ``` 255 | 256 | ## Contributing 257 | 258 | 1. Fork the repository 259 | 2. Create a feature branch 260 | 3. Commit your changes 261 | 4. Push to the branch 262 | 5. Create a Pull Request 263 | 264 | **Important**: Please keep your pull requests small and focused. This will make it easier to review and merge. 265 | 266 | ## Feature Requests 267 | 268 | If you have a feature request, please open an [issue](https://github.com/virattt/ai-hedge-fund/issues) and make sure it is tagged with `enhancement`. 269 | 270 | ## License 271 | 272 | This project is licensed under the MIT License - see the LICENSE file for details. 
273 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | ollama: 3 | image: ollama/ollama:latest 4 | container_name: ollama 5 | environment: 6 | - OLLAMA_HOST=0.0.0.0 7 | # Apple Silicon GPU acceleration 8 | - METAL_DEVICE=on 9 | - METAL_DEVICE_INDEX=0 10 | volumes: 11 | - ollama_data:/root/.ollama 12 | ports: 13 | - "11434:11434" 14 | restart: unless-stopped 15 | 16 | hedge-fund: 17 | build: . 18 | image: ai-hedge-fund 19 | depends_on: 20 | - ollama 21 | volumes: 22 | - ./.env:/app/.env 23 | command: python src/main.py --ticker AAPL,MSFT,NVDA 24 | environment: 25 | - PYTHONUNBUFFERED=1 26 | - OLLAMA_BASE_URL=http://ollama:11434 27 | tty: true 28 | stdin_open: true 29 | 30 | hedge-fund-reasoning: 31 | build: . 32 | image: ai-hedge-fund 33 | depends_on: 34 | - ollama 35 | volumes: 36 | - ./.env:/app/.env 37 | command: python src/main.py --ticker AAPL,MSFT,NVDA --show-reasoning 38 | environment: 39 | - PYTHONUNBUFFERED=1 40 | - OLLAMA_BASE_URL=http://ollama:11434 41 | tty: true 42 | stdin_open: true 43 | 44 | hedge-fund-ollama: 45 | build: . 46 | image: ai-hedge-fund 47 | depends_on: 48 | - ollama 49 | volumes: 50 | - ./.env:/app/.env 51 | command: python src/main.py --ticker AAPL,MSFT,NVDA --ollama 52 | environment: 53 | - PYTHONUNBUFFERED=1 54 | - OLLAMA_BASE_URL=http://ollama:11434 55 | tty: true 56 | stdin_open: true 57 | 58 | backtester: 59 | build: . 60 | image: ai-hedge-fund 61 | depends_on: 62 | - ollama 63 | volumes: 64 | - ./.env:/app/.env 65 | command: python src/backtester.py --ticker AAPL,MSFT,NVDA 66 | environment: 67 | - PYTHONUNBUFFERED=1 68 | - OLLAMA_BASE_URL=http://ollama:11434 69 | tty: true 70 | stdin_open: true 71 | 72 | backtester-ollama: 73 | build: . 
74 | image: ai-hedge-fund 75 | depends_on: 76 | - ollama 77 | volumes: 78 | - ./.env:/app/.env 79 | command: python src/backtester.py --ticker AAPL,MSFT,NVDA --ollama 80 | environment: 81 | - PYTHONUNBUFFERED=1 82 | - OLLAMA_BASE_URL=http://ollama:11434 83 | tty: true 84 | stdin_open: true 85 | 86 | volumes: 87 | ollama_data: -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ai-hedge-fund" 3 | version = "0.1.0" 4 | description = "An AI-powered hedge fund that uses multiple agents to make trading decisions" 5 | authors = ["Your Name "] 6 | readme = "README.md" 7 | packages = [ 8 | { include = "src", from = "." } 9 | ] 10 | [tool.poetry.dependencies] 11 | python = "^3.9" 12 | langchain = "0.3.0" 13 | langchain-anthropic = "0.3.5" 14 | langchain-groq = "0.2.3" 15 | langchain-openai = "^0.3.5" 16 | langchain-deepseek = "^0.1.2" 17 | langchain-ollama = "^0.2.0" 18 | langgraph = "0.2.56" 19 | pandas = "^2.1.0" 20 | numpy = "^1.24.0" 21 | python-dotenv = "1.0.0" 22 | matplotlib = "^3.9.2" 23 | tabulate = "^0.9.0" 24 | colorama = "^0.4.6" 25 | questionary = "^2.1.0" 26 | rich = "^13.9.4" 27 | langchain-google-genai = "^2.0.11" 28 | 29 | [tool.poetry.group.dev.dependencies] 30 | pytest = "^7.4.0" 31 | black = "^23.7.0" 32 | isort = "^5.12.0" 33 | flake8 = "^6.1.0" 34 | 35 | [build-system] 36 | requires = ["poetry-core"] 37 | build-backend = "poetry.core.masonry.api" 38 | 39 | [tool.black] 40 | line-length = 420 41 | target-version = ['py39'] 42 | include = '\.pyi?$' -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | setlocal enabledelayedexpansion 3 | 4 | :: Default values 5 | set TICKER=AAPL,MSFT,NVDA 6 | set USE_OLLAMA= 7 | set START_DATE= 8 | set END_DATE= 9 | set 
INITIAL_AMOUNT=100000.0 10 | set MARGIN_REQUIREMENT=0.0 11 | set SHOW_REASONING= 12 | set COMMAND= 13 | set MODEL_NAME= 14 | 15 | :: Help function 16 | :show_help 17 | echo AI Hedge Fund Docker Runner 18 | echo. 19 | echo Usage: run.bat [OPTIONS] COMMAND 20 | echo. 21 | echo Options: 22 | echo --ticker SYMBOLS Comma-separated list of ticker symbols (e.g., AAPL,MSFT,NVDA) 23 | echo --start-date DATE Start date in YYYY-MM-DD format 24 | echo --end-date DATE End date in YYYY-MM-DD format 25 | echo --initial-cash AMT Initial cash position (default: 100000.0) 26 | echo --margin-requirement RATIO Margin requirement ratio (default: 0.0) 27 | echo --ollama Use Ollama for local LLM inference 28 | echo --show-reasoning Show reasoning from each agent 29 | echo. 30 | echo Commands: 31 | echo main Run the main hedge fund application 32 | echo backtest Run the backtester 33 | echo build Build the Docker image 34 | echo compose Run using Docker Compose with integrated Ollama 35 | echo ollama Start only the Ollama container for model management 36 | echo pull MODEL Pull a specific model into the Ollama container 37 | echo help Show this help message 38 | echo. 39 | echo Examples: 40 | echo run.bat --ticker AAPL,MSFT,NVDA main 41 | echo run.bat --ticker AAPL,MSFT,NVDA --ollama main 42 | echo run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest 43 | echo run.bat compose # Run with Docker Compose (includes Ollama) 44 | echo run.bat ollama # Start only the Ollama container 45 | echo run.bat pull llama3 # Pull the llama3 model to Ollama 46 | echo. 
47 | goto :eof 48 | 49 | :: Parse arguments 50 | :parse_args 51 | if "%~1"=="" goto :check_command 52 | if "%~1"=="--ticker" ( 53 | set TICKER=%~2 54 | shift 55 | shift 56 | goto :parse_args 57 | ) 58 | if "%~1"=="--start-date" ( 59 | set START_DATE=--start-date %~2 60 | shift 61 | shift 62 | goto :parse_args 63 | ) 64 | if "%~1"=="--end-date" ( 65 | set END_DATE=--end-date %~2 66 | shift 67 | shift 68 | goto :parse_args 69 | ) 70 | if "%~1"=="--initial-cash" ( 71 | set INITIAL_AMOUNT=%~2 72 | shift 73 | shift 74 | goto :parse_args 75 | ) 76 | if "%~1"=="--margin-requirement" ( 77 | set MARGIN_REQUIREMENT=%~2 78 | shift 79 | shift 80 | goto :parse_args 81 | ) 82 | if "%~1"=="--ollama" ( 83 | set USE_OLLAMA=--ollama 84 | shift 85 | goto :parse_args 86 | ) 87 | if "%~1"=="--show-reasoning" ( 88 | set SHOW_REASONING=--show-reasoning 89 | shift 90 | goto :parse_args 91 | ) 92 | if "%~1"=="main" ( 93 | set COMMAND=main 94 | shift 95 | goto :parse_args 96 | ) 97 | if "%~1"=="backtest" ( 98 | set COMMAND=backtest 99 | shift 100 | goto :parse_args 101 | ) 102 | if "%~1"=="build" ( 103 | set COMMAND=build 104 | shift 105 | goto :parse_args 106 | ) 107 | if "%~1"=="compose" ( 108 | set COMMAND=compose 109 | shift 110 | goto :parse_args 111 | ) 112 | if "%~1"=="ollama" ( 113 | set COMMAND=ollama 114 | shift 115 | goto :parse_args 116 | ) 117 | if "%~1"=="pull" ( 118 | set COMMAND=pull 119 | set MODEL_NAME=%~2 120 | shift 121 | shift 122 | goto :parse_args 123 | ) 124 | if "%~1"=="help" ( 125 | call :show_help 126 | exit /b 0 127 | ) 128 | if "%~1"=="--help" ( 129 | call :show_help 130 | exit /b 0 131 | ) 132 | echo Unknown option: %~1 133 | call :show_help 134 | exit /b 1 135 | 136 | :check_command 137 | if "!COMMAND!"=="" ( 138 | echo Error: No command specified. 
139 | call :show_help 140 | exit /b 1 141 | ) 142 | 143 | :: Show help if 'help' command is provided 144 | if "!COMMAND!"=="help" ( 145 | call :show_help 146 | exit /b 0 147 | ) 148 | 149 | :: Check for Docker Compose existence 150 | docker compose version >nul 2>&1 151 | if !ERRORLEVEL! EQU 0 ( 152 | set COMPOSE_CMD=docker compose 153 | ) else ( 154 | docker-compose --version >nul 2>&1 155 | if !ERRORLEVEL! EQU 0 ( 156 | set COMPOSE_CMD=docker-compose 157 | ) else ( 158 | echo Error: Docker Compose is not installed. 159 | exit /b 1 160 | ) 161 | ) 162 | 163 | :: Build the Docker image if 'build' command is provided 164 | if "!COMMAND!"=="build" ( 165 | docker build -t ai-hedge-fund . 166 | exit /b 0 167 | ) 168 | 169 | :: Start Ollama container if 'ollama' command is provided 170 | if "!COMMAND!"=="ollama" ( 171 | echo Starting Ollama container... 172 | !COMPOSE_CMD! up -d ollama 173 | 174 | :: Check if Ollama is running 175 | echo Waiting for Ollama to start... 176 | for /l %%i in (1, 1, 30) do ( 177 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1 178 | if !ERRORLEVEL! EQU 0 ( 179 | echo Ollama is now running. 180 | :: Show available models 181 | echo Available models: 182 | !COMPOSE_CMD! exec ollama ollama list 183 | 184 | echo. 185 | echo Manage your models using: 186 | echo run.bat pull ^ # Download a model 187 | echo run.bat ollama # Start Ollama and show models 188 | exit /b 0 189 | ) 190 | timeout /t 1 /nobreak >nul 191 | echo. 192 | ) 193 | 194 | echo Failed to start Ollama within the expected time. You may need to check the container logs. 195 | exit /b 1 196 | ) 197 | 198 | :: Pull a model if 'pull' command is provided 199 | if "!COMMAND!"=="pull" ( 200 | if "!MODEL_NAME!"=="" ( 201 | echo Error: No model name specified. 202 | echo Usage: run.bat pull ^ 203 | echo Example: run.bat pull llama3 204 | exit /b 1 205 | ) 206 | 207 | :: Start Ollama if it's not already running 208 | !COMPOSE_CMD! 
up -d ollama 209 | 210 | :: Wait for Ollama to start 211 | echo Ensuring Ollama is running... 212 | for /l %%i in (1, 1, 30) do ( 213 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1 214 | if !ERRORLEVEL! EQU 0 ( 215 | echo Ollama is running. 216 | goto :pull_model 217 | ) 218 | timeout /t 1 /nobreak >nul 219 | echo. 220 | ) 221 | 222 | :pull_model 223 | :: Pull the model 224 | echo Pulling model: !MODEL_NAME! 225 | echo This may take some time depending on the model size and your internet connection. 226 | echo You can press Ctrl+C to cancel at any time (the model will continue downloading in the background). 227 | 228 | !COMPOSE_CMD! exec ollama ollama pull "!MODEL_NAME!" 229 | 230 | :: Check if the model was successfully pulled 231 | !COMPOSE_CMD! exec ollama ollama list | findstr /i "!MODEL_NAME!" >nul 232 | if !ERRORLEVEL! EQU 0 ( 233 | echo Model !MODEL_NAME! was successfully downloaded. 234 | ) else ( 235 | echo Warning: Model !MODEL_NAME! may not have been properly downloaded. 236 | echo Check the Ollama container status with: run.bat ollama 237 | ) 238 | 239 | exit /b 0 240 | ) 241 | 242 | :: Run with Docker Compose if 'compose' command is provided 243 | if "!COMMAND!"=="compose" ( 244 | echo Running with Docker Compose (includes Ollama)... 245 | !COMPOSE_CMD! up --build 246 | exit /b 0 247 | ) 248 | 249 | :: Check if .env file exists, if not create from .env.example 250 | if not exist .env ( 251 | if exist .env.example ( 252 | echo No .env file found. Creating from .env.example... 253 | copy .env.example .env 254 | echo Please edit .env file to add your API keys. 255 | ) else ( 256 | echo Error: No .env or .env.example file found. 257 | exit /b 1 258 | ) 259 | ) 260 | 261 | :: Set script path and parameters based on command 262 | if "!COMMAND!"=="main" ( 263 | set SCRIPT_PATH=src/main.py 264 | if "!COMMAND!"=="main" ( 265 | set INITIAL_PARAM=--initial-cash !INITIAL_AMOUNT! 
266 | ) 267 | ) else if "!COMMAND!"=="backtest" ( 268 | set SCRIPT_PATH=src/backtester.py 269 | if "!COMMAND!"=="backtest" ( 270 | set INITIAL_PARAM=--initial-capital !INITIAL_AMOUNT! 271 | ) 272 | ) 273 | 274 | :: If using Ollama, make sure the service is started 275 | if not "!USE_OLLAMA!"=="" ( 276 | echo Setting up Ollama container for local LLM inference... 277 | 278 | :: Start Ollama container if not already running 279 | !COMPOSE_CMD! up -d ollama 280 | 281 | :: Wait for Ollama to start 282 | echo Waiting for Ollama to start... 283 | for /l %%i in (1, 1, 30) do ( 284 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1 285 | if !ERRORLEVEL! EQU 0 ( 286 | echo Ollama is running. 287 | :: Show available models 288 | echo Available models: 289 | !COMPOSE_CMD! exec ollama ollama list 290 | goto :continue_ollama 291 | ) 292 | timeout /t 1 /nobreak >nul 293 | echo. 294 | ) 295 | 296 | :continue_ollama 297 | :: Build the AI Hedge Fund image if needed 298 | docker images -q ai-hedge-fund 2>nul | findstr /r /c:"^..*$" >nul 299 | if !ERRORLEVEL! NEQ 0 ( 300 | echo Building AI Hedge Fund image... 301 | docker build -t ai-hedge-fund . 302 | ) 303 | 304 | :: Create command override for Docker Compose 305 | set COMMAND_OVERRIDE= 306 | 307 | if not "!START_DATE!"=="" ( 308 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !START_DATE! 309 | ) 310 | 311 | if not "!END_DATE!"=="" ( 312 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !END_DATE! 313 | ) 314 | 315 | if not "!INITIAL_PARAM!"=="" ( 316 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !INITIAL_PARAM! 317 | ) 318 | 319 | if not "!MARGIN_REQUIREMENT!"=="" ( 320 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! --margin-requirement !MARGIN_REQUIREMENT! 321 | ) 322 | 323 | :: Run the command with Docker Compose 324 | echo Running AI Hedge Fund with Ollama using Docker Compose... 
325 | 326 | :: Use the appropriate service based on command and reasoning flag 327 | if "!COMMAND!"=="main" ( 328 | if not "!SHOW_REASONING!"=="" ( 329 | !COMPOSE_CMD! run --rm hedge-fund-reasoning python src/main.py --ticker !TICKER! !COMMAND_OVERRIDE! !SHOW_REASONING! --ollama 330 | ) else ( 331 | !COMPOSE_CMD! run --rm hedge-fund-ollama python src/main.py --ticker !TICKER! !COMMAND_OVERRIDE! --ollama 332 | ) 333 | ) else if "!COMMAND!"=="backtest" ( 334 | !COMPOSE_CMD! run --rm backtester-ollama python src/backtester.py --ticker !TICKER! !COMMAND_OVERRIDE! !SHOW_REASONING! --ollama 335 | ) 336 | 337 | exit /b 0 338 | ) 339 | 340 | :: Standard Docker run (without Ollama) 341 | :: Build the command 342 | set CMD=docker run -it --rm -v %cd%\.env:/app/.env 343 | 344 | :: Add the command 345 | set CMD=!CMD! ai-hedge-fund python !SCRIPT_PATH! --ticker !TICKER! !START_DATE! !END_DATE! !INITIAL_PARAM! --margin-requirement !MARGIN_REQUIREMENT! !SHOW_REASONING! 346 | 347 | :: Run the command 348 | echo Running: !CMD! 349 | !CMD! 
#!/bin/bash
#
# Docker / Docker Compose launcher for the AI Hedge Fund.
#
# Parses the command line, then dispatches on COMMAND:
#   build    - build the ai-hedge-fund image
#   ollama   - start the Ollama container and list its models
#   pull     - pull a model into the Ollama container
#   compose  - bring the whole Compose stack up
#   main     - run src/main.py        (plain docker, or Compose when --ollama)
#   backtest - run src/backtester.py  (plain docker, or Compose when --ollama)
#
# Arguments are kept in bash arrays rather than space-joined strings so that
# values and paths containing spaces survive intact.

# Print the usage text (shown for --help / help and after usage errors).
show_help() {
  cat <<'EOF'
AI Hedge Fund Docker Runner

Usage: ./run.sh [OPTIONS] COMMAND

Options:
  --ticker SYMBOLS             Comma-separated list of ticker symbols (e.g., AAPL,MSFT,NVDA)
  --start-date DATE            Start date in YYYY-MM-DD format
  --end-date DATE              End date in YYYY-MM-DD format
  --initial-cash AMT           Initial cash position (default: 100000.0)
  --margin-requirement RATIO   Margin requirement ratio (default: 0.0)
  --ollama                     Use Ollama for local LLM inference
  --show-reasoning             Show reasoning from each agent

Commands:
  main                         Run the main hedge fund application
  backtest                     Run the backtester
  build                        Build the Docker image
  compose                      Run using Docker Compose with integrated Ollama
  ollama                       Start only the Ollama container for model management
  pull MODEL                   Pull a specific model into the Ollama container
  help                         Show this help message

Examples:
  ./run.sh --ticker AAPL,MSFT,NVDA main
  ./run.sh --ticker AAPL,MSFT,NVDA --ollama main
  ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest
  ./run.sh compose                 # Run with Docker Compose (includes Ollama)
  ./run.sh ollama                  # Start only the Ollama container
  ./run.sh pull llama3             # Pull the llama3 model to Ollama

EOF
}

# Abort with a usage error when a value-taking option has no value after it.
#   $1 = option name, $2 = number of arguments still unparsed (option included)
# Without this guard `shift 2` fails *without shifting anything*, and the
# parsing loop below would spin forever on the same argument.
require_value() {
  if [ "$2" -lt 2 ]; then
    echo "Error: $1 requires a value."
    show_help
    exit 1
  fi
}

# Poll the Ollama HTTP API once per second, for at most 30 attempts.
# Returns 0 as soon as the API answers, 1 if it never did.
wait_for_ollama() {
  local attempt
  for attempt in {1..30}; do
    if docker run --rm --network=host curlimages/curl:latest curl -s http://localhost:11434/api/version &> /dev/null; then
      return 0
    fi
    echo -n "."
    sleep 1
  done
  return 1
}

# Default values
TICKER="AAPL,MSFT,NVDA"
USE_OLLAMA=""
START_DATE=()
END_DATE=()
INITIAL_AMOUNT="100000.0"
INITIAL_PARAM=()
MARGIN_REQUIREMENT="0.0"
SHOW_REASONING=()
COMMAND=""
MODEL_NAME=""

# Parse arguments
while [[ $# -gt 0 ]]; do
  case $1 in
    --ticker)
      require_value "$1" $#
      TICKER="$2"
      shift 2
      ;;
    --start-date)
      require_value "$1" $#
      START_DATE=(--start-date "$2")
      shift 2
      ;;
    --end-date)
      require_value "$1" $#
      END_DATE=(--end-date "$2")
      shift 2
      ;;
    --initial-cash)
      require_value "$1" $#
      INITIAL_AMOUNT="$2"
      shift 2
      ;;
    --margin-requirement)
      require_value "$1" $#
      MARGIN_REQUIREMENT="$2"
      shift 2
      ;;
    --ollama)
      USE_OLLAMA="--ollama"
      shift
      ;;
    --show-reasoning)
      SHOW_REASONING=(--show-reasoning)
      shift
      ;;
    main|backtest|build|help|compose|ollama)
      COMMAND="$1"
      shift
      ;;
    pull)
      COMMAND="pull"
      MODEL_NAME="${2:-}"
      # A missing model name is reported later with a dedicated message, so
      # only consume the second argument when it actually exists.
      if [ $# -ge 2 ]; then
        shift 2
      else
        shift
      fi
      ;;
    --help)
      show_help
      exit 0
      ;;
    *)
      echo "Unknown option: $1"
      show_help
      exit 1
      ;;
  esac
done

# Check if command is provided
if [ -z "$COMMAND" ]; then
  echo "Error: No command specified."
  show_help
  exit 1
fi

# Show help if 'help' command is provided
if [ "$COMMAND" = "help" ]; then
  show_help
  exit 0
fi

# Determine which Docker Compose command to use (standalone binary first,
# then the `docker compose` plugin).
if command -v docker-compose &> /dev/null; then
  COMPOSE_CMD=(docker-compose)
elif docker compose version &> /dev/null; then
  COMPOSE_CMD=(docker compose)
else
  echo "Error: Docker Compose is not installed."
  exit 1
fi

# Detect system architecture for GPU configuration
ARCH=$(uname -m)
OS=$(uname -s)
GPU_CONFIG=()

if [ "$OS" = "Darwin" ] && { [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; }; then
  echo "Detected Apple Silicon (M-series) - Metal GPU acceleration should be enabled"
  # Metal GPU is handled via environment variables in docker-compose.yml
elif command -v nvidia-smi &> /dev/null; then
  # Only add the override when the file is really there; pointing Compose at
  # a missing file makes every subsequent compose call fail.
  if [ -f docker-compose.nvidia.yml ]; then
    echo "NVIDIA GPU detected - Adding NVIDIA GPU configuration"
    GPU_CONFIG=(-f docker-compose.yml -f docker-compose.nvidia.yml)
  else
    echo "NVIDIA GPU detected, but docker-compose.nvidia.yml was not found - continuing without GPU override"
  fi
fi

# Build the Docker image if 'build' command is provided
if [ "$COMMAND" = "build" ]; then
  docker build -t ai-hedge-fund .
  exit $?
fi

# Start Ollama container if 'ollama' command is provided
if [ "$COMMAND" = "ollama" ]; then
  echo "Starting Ollama container..."
  "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" up -d ollama

  echo "Waiting for Ollama to start..."
  if wait_for_ollama; then
    echo "Ollama is now running."
    echo "Available models:"
    docker exec -t ollama ollama list

    echo -e "\nManage your models using:"
    echo "  ./run.sh pull MODEL   # Download a model"
    echo "  ./run.sh ollama       # Start Ollama and show models"
    exit 0
  fi

  echo "Failed to start Ollama within the expected time. You may need to check the container logs."
  exit 1
fi

# Pull a model if 'pull' command is provided
if [ "$COMMAND" = "pull" ]; then
  if [ -z "$MODEL_NAME" ]; then
    echo "Error: No model name specified."
    echo "Usage: ./run.sh pull MODEL"
    echo "Example: ./run.sh pull llama3"
    exit 1
  fi

  # Start Ollama if it's not already running
  "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" up -d ollama

  # Best effort: the probe uses host networking, which is not available on
  # every Docker setup, so a timeout is reported but does not abort the pull.
  echo "Ensuring Ollama is running..."
  if wait_for_ollama; then
    echo "Ollama is running."
  else
    echo "Warning: could not confirm that Ollama is reachable; trying anyway."
  fi

  echo "Pulling model: $MODEL_NAME"
  echo "This may take some time depending on the model size and your internet connection."
  echo "You can press Ctrl+C to cancel at any time (the model will continue downloading in the background)."

  docker exec -t ollama ollama pull "$MODEL_NAME"

  # Check if the model was successfully pulled (-F: match the name literally)
  if docker exec -t ollama ollama list | grep -qF -- "$MODEL_NAME"; then
    echo "Model $MODEL_NAME was successfully downloaded."
  else
    echo "Warning: Model $MODEL_NAME may not have been properly downloaded."
    echo "Check the Ollama container status with: ./run.sh ollama"
  fi

  exit 0
fi

# Run with Docker Compose
if [ "$COMMAND" = "compose" ]; then
  echo "Running with Docker Compose (includes Ollama)..."
  "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" up --build
  exit $?
fi

# Check if .env file exists, if not create from .env.example
if [ ! -f .env ]; then
  if [ -f .env.example ]; then
    echo "No .env file found. Creating from .env.example..."
    cp .env.example .env
    echo "Please edit .env file to add your API keys."
  else
    echo "Error: No .env or .env.example file found."
    exit 1
  fi
fi

# Set script path and parameters based on command. The two entry points name
# their starting-capital flag differently.
if [ "$COMMAND" = "main" ]; then
  SCRIPT_PATH="src/main.py"
  INITIAL_PARAM=(--initial-cash "$INITIAL_AMOUNT")
elif [ "$COMMAND" = "backtest" ]; then
  SCRIPT_PATH="src/backtester.py"
  INITIAL_PARAM=(--initial-capital "$INITIAL_AMOUNT")
fi

# If using Ollama, make sure the service is started and run through Compose
if [ -n "$USE_OLLAMA" ]; then
  echo "Setting up Ollama container for local LLM inference..."

  # Start Ollama container if not already running
  "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" up -d ollama

  echo "Waiting for Ollama to start..."
  if wait_for_ollama; then
    echo "Ollama is running."
    echo "Available models:"
    docker exec -t ollama ollama list
  else
    echo "Warning: could not confirm that Ollama is reachable; trying anyway."
  fi

  # Build the AI Hedge Fund image if needed
  if [[ "$(docker images -q ai-hedge-fund 2> /dev/null)" == "" ]]; then
    echo "Building AI Hedge Fund image..."
    docker build -t ai-hedge-fund .
  fi

  # Arguments forwarded to the Python entry point
  COMMAND_OVERRIDE=("${START_DATE[@]}" "${END_DATE[@]}" "${INITIAL_PARAM[@]}")
  if [ -n "$MARGIN_REQUIREMENT" ]; then
    COMMAND_OVERRIDE+=(--margin-requirement "$MARGIN_REQUIREMENT")
  fi

  echo "Running AI Hedge Fund with Ollama using Docker Compose..."

  # Use the appropriate service based on command and reasoning flag
  if [ "$COMMAND" = "main" ]; then
    if [ ${#SHOW_REASONING[@]} -gt 0 ]; then
      "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" run --rm hedge-fund-reasoning python src/main.py --ticker "$TICKER" "${COMMAND_OVERRIDE[@]}" "${SHOW_REASONING[@]}" --ollama
    else
      "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" run --rm hedge-fund-ollama python src/main.py --ticker "$TICKER" "${COMMAND_OVERRIDE[@]}" --ollama
    fi
  elif [ "$COMMAND" = "backtest" ]; then
    "${COMPOSE_CMD[@]}" "${GPU_CONFIG[@]}" run --rm backtester-ollama python src/backtester.py --ticker "$TICKER" "${COMMAND_OVERRIDE[@]}" "${SHOW_REASONING[@]}" --ollama
  fi

  # Propagate the status of the run instead of always reporting success
  exit $?
fi

# Standard Docker run (without Ollama)
CMD=(
  docker run -it --rm -v "$(pwd)/.env:/app/.env"
  ai-hedge-fund python "$SCRIPT_PATH" --ticker "$TICKER"
  "${START_DATE[@]}" "${END_DATE[@]}" "${INITIAL_PARAM[@]}"
  --margin-requirement "$MARGIN_REQUIREMENT" "${SHOW_REASONING[@]}"
)

echo "Running: ${CMD[*]}"
"${CMD[@]}"
class BenGrahamSignal(BaseModel):
    """Structured verdict requested from the LLM for a single ticker."""

    signal: Literal["bullish", "bearish", "neutral"]
    confidence: float  # the prompt asks for a value between 0 and 100
    reasoning: str


def ben_graham_agent(state: AgentState):
    """
    Analyzes stocks using Benjamin Graham's classic value-investing principles:
    1. Earnings stability over multiple years.
    2. Solid financial strength (low debt, adequate liquidity).
    3. Discount to intrinsic value (e.g. Graham Number or net-net).
    4. Adequate margin of safety.

    For every ticker in ``state["data"]["tickers"]`` the three rule-based
    analysers are run, their scores are summed into a preliminary signal, and
    the LLM is asked for the final signal / confidence / reasoning.

    The per-ticker results are stored under
    ``state["data"]["analyst_signals"]["ben_graham_agent"]`` and returned as
    ``{"messages": [HumanMessage], "data": state["data"]}``.
    """
    data = state["data"]
    end_date = data["end_date"]
    tickers = data["tickers"]

    analysis_data = {}
    graham_analysis = {}

    for ticker in tickers:
        progress.update_status("ben_graham_agent", ticker, "Fetching financial metrics")
        metrics = get_financial_metrics(ticker, end_date, period="annual", limit=10)

        progress.update_status("ben_graham_agent", ticker, "Gathering financial line items")
        financial_line_items = search_line_items(
            ticker,
            [
                "earnings_per_share",
                "revenue",
                "net_income",
                "book_value_per_share",
                "total_assets",
                "total_liabilities",
                "current_assets",
                "current_liabilities",
                "dividends_and_other_cash_distributions",
                "outstanding_shares",
            ],
            end_date,
            period="annual",
            limit=10,
        )

        progress.update_status("ben_graham_agent", ticker, "Getting market cap")
        market_cap = get_market_cap(ticker, end_date)

        # Perform sub-analyses
        progress.update_status("ben_graham_agent", ticker, "Analyzing earnings stability")
        earnings_analysis = analyze_earnings_stability(metrics, financial_line_items)

        progress.update_status("ben_graham_agent", ticker, "Analyzing financial strength")
        strength_analysis = analyze_financial_strength(financial_line_items)

        progress.update_status("ben_graham_agent", ticker, "Analyzing Graham valuation")
        valuation_analysis = analyze_valuation_graham(financial_line_items, market_cap)

        # Aggregate scoring
        total_score = earnings_analysis["score"] + strength_analysis["score"] + valuation_analysis["score"]
        # Highest attainable total: earnings stability 3 + 1, financial
        # strength 2 + 2 + 1, valuation 4 + 3.
        max_possible_score = 16

        # Map total_score to signal
        if total_score >= 0.7 * max_possible_score:
            signal = "bullish"
        elif total_score <= 0.3 * max_possible_score:
            signal = "bearish"
        else:
            signal = "neutral"

        analysis_data[ticker] = {
            "signal": signal,
            "score": total_score,
            "max_score": max_possible_score,
            "earnings_analysis": earnings_analysis,
            "strength_analysis": strength_analysis,
            "valuation_analysis": valuation_analysis,
        }

        progress.update_status("ben_graham_agent", ticker, "Generating Ben Graham analysis")
        graham_output = generate_graham_output(
            ticker=ticker,
            # Hand the LLM this ticker's analysis only; passing the whole
            # accumulator would leak every previously processed ticker into
            # the prompt.
            analysis_data={ticker: analysis_data[ticker]},
            model_name=state["metadata"]["model_name"],
            model_provider=state["metadata"]["model_provider"],
        )

        graham_analysis[ticker] = {
            "signal": graham_output.signal,
            "confidence": graham_output.confidence,
            "reasoning": graham_output.reasoning,
        }

        progress.update_status("ben_graham_agent", ticker, "Done")

    # Wrap results in a single message for the chain
    message = HumanMessage(content=json.dumps(graham_analysis), name="ben_graham_agent")

    # Optionally display reasoning
    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(graham_analysis, "Ben Graham Agent")

    # Store signals in the overall state
    state["data"]["analyst_signals"]["ben_graham_agent"] = graham_analysis

    return {"messages": [message], "data": state["data"]}
def analyze_earnings_stability(metrics: list, financial_line_items: list) -> dict:
    """
    Score the consistency of reported earnings per share (0-4 points).

    Two checks are applied to the non-null ``earnings_per_share`` values:
    1. Share of periods with positive EPS: all -> 3 points, at least 80% -> 2.
    2. Growth: 1 point when the first value exceeds the last one. The first
       entry is treated as the latest period, as the emitted message shows.

    Returns ``{"score": int, "details": str}``; the score is 0 whenever either
    input is empty or fewer than two EPS values are available.
    """
    if not (metrics and financial_line_items):
        return {"score": 0, "details": "Insufficient data for earnings stability analysis"}

    eps_history = [
        row.earnings_per_share
        for row in financial_line_items
        if row.earnings_per_share is not None
    ]
    if len(eps_history) < 2:
        return {"score": 0, "details": "Not enough multi-year EPS data."}

    points = 0
    notes = []

    # 1. Consistently positive EPS
    period_count = len(eps_history)
    profitable_count = len([eps for eps in eps_history if eps > 0])
    if profitable_count == period_count:
        points += 3
        notes.append("EPS was positive in all available periods.")
    elif profitable_count >= period_count * 0.8:
        points += 2
        notes.append("EPS was positive in most periods.")
    else:
        notes.append("EPS was negative in multiple periods.")

    # 2. EPS growth from earliest to latest
    latest_eps, earliest_eps = eps_history[0], eps_history[-1]
    grew = latest_eps > earliest_eps
    if grew:
        points += 1
    notes.append(
        "EPS grew from earliest to latest period."
        if grew
        else "EPS did not grow from earliest to latest period."
    )

    return {"score": points, "details": "; ".join(notes)}
def analyze_financial_strength(financial_line_items: list) -> dict:
    """
    Graham checks liquidity (current ratio >= 2), manageable debt,
    and dividend record (preferably some history of dividends).

    Scoring (0-5 points), based on the first line item for the balance-sheet
    checks and on all line items for the dividend record:
    1. Current ratio: >= 2.0 -> 2 points, >= 1.5 -> 1 point.
    2. Total liabilities / total assets: < 0.5 -> 2 points, < 0.8 -> 1 point.
    3. Dividends paid in a majority of the periods with data -> 1 point.

    A balance-sheet value that is missing (``None``) makes the corresponding
    ratio "not computable" instead of being silently treated as zero, so that
    absent liabilities are never rewarded as if the company had no debt.

    Returns ``{"score": int, "details": str}``.
    """
    score = 0
    details = []

    if not financial_line_items:
        return {"score": score, "details": "No data for financial strength analysis"}

    latest_item = financial_line_items[0]
    total_assets = latest_item.total_assets
    total_liabilities = latest_item.total_liabilities
    current_assets = latest_item.current_assets
    current_liabilities = latest_item.current_liabilities

    # 1. Current ratio
    if current_liabilities is None or current_liabilities <= 0:
        details.append("Cannot compute current ratio (missing or zero current_liabilities).")
    elif current_assets is None:
        details.append("Cannot compute current ratio (missing current_assets).")
    else:
        current_ratio = current_assets / current_liabilities
        if current_ratio >= 2.0:
            score += 2
            details.append(f"Current ratio = {current_ratio:.2f} (>=2.0: solid).")
        elif current_ratio >= 1.5:
            score += 1
            details.append(f"Current ratio = {current_ratio:.2f} (moderately strong).")
        else:
            details.append(f"Current ratio = {current_ratio:.2f} (<1.5: weaker liquidity).")

    # 2. Debt vs. Assets
    if total_assets is None or total_assets <= 0:
        details.append("Cannot compute debt ratio (missing total_assets).")
    elif total_liabilities is None:
        details.append("Cannot compute debt ratio (missing total_liabilities).")
    else:
        debt_ratio = total_liabilities / total_assets
        if debt_ratio < 0.5:
            score += 2
            details.append(f"Debt ratio = {debt_ratio:.2f}, under 0.50 (conservative).")
        elif debt_ratio < 0.8:
            score += 1
            details.append(f"Debt ratio = {debt_ratio:.2f}, somewhat high but could be acceptable.")
        else:
            details.append(f"Debt ratio = {debt_ratio:.2f}, quite high by Graham standards.")

    # 3. Dividend track record
    div_periods = [
        item.dividends_and_other_cash_distributions
        for item in financial_line_items
        if item.dividends_and_other_cash_distributions is not None
    ]
    if div_periods:
        # In many data feeds, dividend outflow is shown as a negative number
        # (money going out to shareholders). We'll consider any negative as 'paid a dividend'.
        div_paid_years = sum(1 for d in div_periods if d < 0)
        if div_paid_years == 0:
            details.append("Company did not pay dividends in these periods.")
        elif div_paid_years >= (len(div_periods) // 2 + 1):
            # Strict majority of the periods that carry dividend data
            score += 1
            details.append("Company paid dividends in the majority of the reported years.")
        else:
            details.append("Company has some dividend payments, but not most years.")
    else:
        details.append("No dividend data available to assess payout consistency.")

    return {"score": score, "details": "; ".join(details)}
def analyze_valuation_graham(financial_line_items: list, market_cap: float) -> dict:
    """
    Core Graham approach to valuation:
    1. Net-Net Check: (Current Assets - Total Liabilities) vs. Market Cap
    2. Graham Number: sqrt(22.5 * EPS * Book Value per Share)
    3. Compare per-share price to Graham Number => margin of safety

    Only the first line item is used. Scoring (0-7 points):
    - NCAV above market cap -> 4 points; NCAV per share at least 0.67 of the
      price per share -> 2 points.
    - Margin of safety versus the Graham Number above 50% -> 3 points,
      above 20% -> 1 point.

    NCAV is only computed when both ``current_assets`` and
    ``total_liabilities`` are present; treating a missing liability figure as
    zero would overstate NCAV and could trigger the net-net signal wrongly.

    Returns ``{"score": int, "details": str}``; score 0 with an explanatory
    message when there are no line items or no positive market cap.
    """
    if not financial_line_items or not market_cap or market_cap <= 0:
        return {"score": 0, "details": "Insufficient data to perform valuation"}

    latest = financial_line_items[0]
    current_assets = latest.current_assets
    total_liabilities = latest.total_liabilities
    book_value_ps = latest.book_value_per_share or 0
    eps = latest.earnings_per_share or 0
    shares_outstanding = latest.outstanding_shares or 0

    details = []
    score = 0

    # market_cap is known to be positive here, so the price is positive
    # whenever a share count is available.
    price_per_share = market_cap / shares_outstanding if shares_outstanding > 0 else None

    # 1. Net-Net Check
    #   NCAV = Current Assets - Total Liabilities
    #   If NCAV > Market Cap => historically a strong buy signal
    if current_assets is None or total_liabilities is None:
        details.append("Cannot compute NCAV (missing current_assets or total_liabilities).")
    else:
        net_current_asset_value = current_assets - total_liabilities
        if net_current_asset_value > 0 and price_per_share is not None:
            net_current_asset_value_per_share = net_current_asset_value / shares_outstanding

            details.append(f"Net Current Asset Value = {net_current_asset_value:,.2f}")
            details.append(f"NCAV Per Share = {net_current_asset_value_per_share:,.2f}")
            details.append(f"Price Per Share = {price_per_share:,.2f}")

            if net_current_asset_value > market_cap:
                score += 4  # Very strong Graham signal
                details.append("Net-Net: NCAV > Market Cap (classic Graham deep value).")
            elif net_current_asset_value_per_share >= (price_per_share * 0.67):
                # Partial net-net discount
                score += 2
                details.append("NCAV Per Share >= 2/3 of Price Per Share (moderate net-net discount).")
        else:
            details.append("NCAV not exceeding market cap or insufficient data for net-net approach.")

    # 2. Graham Number
    #   GrahamNumber = sqrt(22.5 * EPS * BVPS); a result well above the price
    #   indicates undervaluation. Needs positive EPS and book value.
    graham_number = None
    if eps > 0 and book_value_ps > 0:
        graham_number = math.sqrt(22.5 * eps * book_value_ps)
        details.append(f"Graham Number = {graham_number:.2f}")
    else:
        details.append("Unable to compute Graham Number (EPS or Book Value missing/<=0).")

    # 3. Margin of Safety relative to Graham Number
    if graham_number and price_per_share is not None:
        margin_of_safety = (graham_number - price_per_share) / price_per_share
        details.append(f"Margin of Safety (Graham Number) = {margin_of_safety:.2%}")
        if margin_of_safety > 0.5:
            score += 3
            details.append("Price is well below Graham Number (>=50% margin).")
        elif margin_of_safety > 0.2:
            score += 1
            details.append("Some margin of safety relative to Graham Number.")
        else:
            details.append("Price close to or above Graham Number, low margin of safety.")

    return {"score": score, "details": "; ".join(details)}
def generate_graham_output(
    ticker: str,
    analysis_data: dict[str, dict],
    model_name: str,
    model_provider: str,
) -> BenGrahamSignal:
    """
    Generates an investment decision in the style of Benjamin Graham:
    - Value emphasis, margin of safety, net-nets, conservative balance sheet, stable earnings.
    - Return the result in a JSON structure: { signal, confidence, reasoning }.

    Args:
        ticker: Symbol inserted into the human prompt.
        analysis_data: Rule-based analysis results; serialised with
            ``json.dumps`` and inserted into the human prompt verbatim.
        model_name: Forwarded to ``call_llm``.
        model_provider: Forwarded to ``call_llm``.

    Returns:
        Whatever ``call_llm`` returns for ``pydantic_model=BenGrahamSignal``.
        A neutral, zero-confidence signal is supplied as ``default_factory``
        (presumably used when the LLM call fails; confirm in ``utils.llm``).
    """

    # System message fixes the persona and output expectations; the human
    # message carries the per-ticker data. Doubled braces are literal braces
    # in the template.
    template = ChatPromptTemplate.from_messages([
        (
            "system",
            """You are a Benjamin Graham AI agent, making investment decisions using his principles:
            1. Insist on a margin of safety by buying below intrinsic value (e.g., using Graham Number, net-net).
            2. Emphasize the company's financial strength (low leverage, ample current assets).
            3. Prefer stable earnings over multiple years.
            4. Consider dividend record for extra safety.
            5. Avoid speculative or high-growth assumptions; focus on proven metrics.

            When providing your reasoning, be thorough and specific by:
            1. Explaining the key valuation metrics that influenced your decision the most (Graham Number, NCAV, P/E, etc.)
            2. Highlighting the specific financial strength indicators (current ratio, debt levels, etc.)
            3. Referencing the stability or instability of earnings over time
            4. Providing quantitative evidence with precise numbers
            5. Comparing current metrics to Graham's specific thresholds (e.g., "Current ratio of 2.5 exceeds Graham's minimum of 2.0")
            6. Using Benjamin Graham's conservative, analytical voice and style in your explanation

            For example, if bullish: "The stock trades at a 35% discount to net current asset value, providing an ample margin of safety. The current ratio of 2.5 and debt-to-equity of 0.3 indicate strong financial position..."
            For example, if bearish: "Despite consistent earnings, the current price of $50 exceeds our calculated Graham Number of $35, offering no margin of safety. Additionally, the current ratio of only 1.2 falls below Graham's preferred 2.0 threshold..."

            Return a rational recommendation: bullish, bearish, or neutral, with a confidence level (0-100) and thorough reasoning.
            """
        ),
        (
            "human",
            """Based on the following analysis, create a Graham-style investment signal:

            Analysis Data for {ticker}:
            {analysis_data}

            Return JSON exactly in this format:
            {{
              "signal": "bullish" or "bearish" or "neutral",
              "confidence": float (0-100),
              "reasoning": "string"
            }}
            """
        )
    ])

    prompt = template.invoke({
        "analysis_data": json.dumps(analysis_data, indent=2),
        "ticker": ticker
    })

    # Fallback value handed to call_llm as default_factory
    def create_default_ben_graham_signal():
        return BenGrahamSignal(signal="neutral", confidence=0.0, reasoning="Error in generating analysis; defaulting to neutral.")

    return call_llm(
        prompt=prompt,
        model_name=model_name,
        model_provider=model_provider,
        pydantic_model=BenGrahamSignal,
        agent_name="ben_graham_agent",
        default_factory=create_default_ben_graham_signal,
    )
def _format_metric(label: str, value, spec: str) -> str:
    """Render ``label: value`` using format ``spec``, or ``label: N/A`` when the value is missing."""
    return f"{label}: N/A" if value is None else f"{label}: {value:{spec}}"


##### Fundamental Agent #####
def fundamentals_agent(state: AgentState):
    """Analyzes fundamental data and generates trading signals for multiple tickers.

    For each ticker the most recent TTM metrics are scored on four aspects
    (profitability, growth, financial health, price ratios). Each aspect
    yields a bullish / bearish / neutral sub-signal; the overall signal is the
    majority of those, and the confidence is the share of sub-signals that
    agree with the larger side, as a percentage.

    Tickers without financial metrics are skipped and do not appear in the
    result. The result is stored under
    ``state["data"]["analyst_signals"]["fundamentals_agent"]`` and returned as
    ``{"messages": [HumanMessage], "data": data}``.

    Metric values are tested with ``is not None`` rather than truthiness, so a
    legitimate zero (for example a debt-free company with D/E == 0) is treated
    as data and not as a missing value.
    """
    data = state["data"]
    end_date = data["end_date"]
    tickers = data["tickers"]

    # Initialize fundamental analysis for each ticker
    fundamental_analysis = {}

    for ticker in tickers:
        progress.update_status("fundamentals_agent", ticker, "Fetching financial metrics")

        # Get the financial metrics
        financial_metrics = get_financial_metrics(
            ticker=ticker,
            end_date=end_date,
            period="ttm",
            limit=10,
        )

        if not financial_metrics:
            progress.update_status("fundamentals_agent", ticker, "Failed: No financial metrics found")
            continue

        # Pull the most recent financial metrics
        metrics = financial_metrics[0]

        # One sub-signal per fundamental aspect, in the order analysed below
        signals = []
        reasoning = {}

        progress.update_status("fundamentals_agent", ticker, "Analyzing profitability")
        # 1. Profitability Analysis
        return_on_equity = metrics.return_on_equity
        net_margin = metrics.net_margin
        operating_margin = metrics.operating_margin

        thresholds = [
            (return_on_equity, 0.15),  # Strong ROE above 15%
            (net_margin, 0.20),  # Healthy profit margins
            (operating_margin, 0.15),  # Strong operating efficiency
        ]
        profitability_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)

        signals.append("bullish" if profitability_score >= 2 else "bearish" if profitability_score == 0 else "neutral")
        reasoning["profitability_signal"] = {
            "signal": signals[0],
            "details": ", ".join(
                [
                    _format_metric("ROE", return_on_equity, ".2%"),
                    _format_metric("Net Margin", net_margin, ".2%"),
                    _format_metric("Op Margin", operating_margin, ".2%"),
                ]
            ),
        }

        progress.update_status("fundamentals_agent", ticker, "Analyzing growth")
        # 2. Growth Analysis
        revenue_growth = metrics.revenue_growth
        earnings_growth = metrics.earnings_growth
        book_value_growth = metrics.book_value_growth

        thresholds = [
            (revenue_growth, 0.10),  # 10% revenue growth
            (earnings_growth, 0.10),  # 10% earnings growth
            (book_value_growth, 0.10),  # 10% book value growth
        ]
        growth_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)

        signals.append("bullish" if growth_score >= 2 else "bearish" if growth_score == 0 else "neutral")
        reasoning["growth_signal"] = {
            "signal": signals[1],
            "details": ", ".join(
                [
                    _format_metric("Revenue Growth", revenue_growth, ".2%"),
                    _format_metric("Earnings Growth", earnings_growth, ".2%"),
                ]
            ),
        }

        progress.update_status("fundamentals_agent", ticker, "Analyzing financial health")
        # 3. Financial Health
        current_ratio = metrics.current_ratio
        debt_to_equity = metrics.debt_to_equity
        free_cash_flow_per_share = metrics.free_cash_flow_per_share
        earnings_per_share = metrics.earnings_per_share

        health_score = 0
        if current_ratio is not None and current_ratio > 1.5:  # Strong liquidity
            health_score += 1
        # A D/E of exactly 0 (no debt) is the most conservative case and must count
        if debt_to_equity is not None and debt_to_equity < 0.5:  # Conservative debt levels
            health_score += 1
        # Zero FCF or EPS stays excluded here: the ratio test is meaningless for them
        if free_cash_flow_per_share and earnings_per_share and free_cash_flow_per_share > earnings_per_share * 0.8:  # Strong FCF conversion
            health_score += 1

        signals.append("bullish" if health_score >= 2 else "bearish" if health_score == 0 else "neutral")
        reasoning["financial_health_signal"] = {
            "signal": signals[2],
            "details": ", ".join(
                [
                    _format_metric("Current Ratio", current_ratio, ".2f"),
                    _format_metric("D/E", debt_to_equity, ".2f"),
                ]
            ),
        }

        progress.update_status("fundamentals_agent", ticker, "Analyzing valuation ratios")
        # 4. Price to X ratios (here exceeding a threshold is the bearish direction)
        pe_ratio = metrics.price_to_earnings_ratio
        pb_ratio = metrics.price_to_book_ratio
        ps_ratio = metrics.price_to_sales_ratio

        thresholds = [
            (pe_ratio, 25),  # Reasonable P/E ratio
            (pb_ratio, 3),  # Reasonable P/B ratio
            (ps_ratio, 5),  # Reasonable P/S ratio
        ]
        price_ratio_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)

        signals.append("bearish" if price_ratio_score >= 2 else "bullish" if price_ratio_score == 0 else "neutral")
        reasoning["price_ratios_signal"] = {
            "signal": signals[3],
            "details": ", ".join(
                [
                    _format_metric("P/E", pe_ratio, ".2f"),
                    _format_metric("P/B", pb_ratio, ".2f"),
                    _format_metric("P/S", ps_ratio, ".2f"),
                ]
            ),
        }

        progress.update_status("fundamentals_agent", ticker, "Calculating final signal")
        # Determine overall signal
        bullish_signals = signals.count("bullish")
        bearish_signals = signals.count("bearish")

        if bullish_signals > bearish_signals:
            overall_signal = "bullish"
        elif bearish_signals > bullish_signals:
            overall_signal = "bearish"
        else:
            overall_signal = "neutral"

        # Calculate confidence level
        total_signals = len(signals)
        confidence = round(max(bullish_signals, bearish_signals) / total_signals, 2) * 100

        fundamental_analysis[ticker] = {
            "signal": overall_signal,
            "confidence": confidence,
            "reasoning": reasoning,
        }

        progress.update_status("fundamentals_agent", ticker, "Done")

    # Create the fundamental analysis message
    message = HumanMessage(
        content=json.dumps(fundamental_analysis),
        name="fundamentals_agent",
    )

    # Print the reasoning if the flag is set
    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(fundamental_analysis, "Fundamental Analysis Agent")

    # Add the signal to the analyst_signals list
    state["data"]["analyst_signals"]["fundamentals_agent"] = fundamental_analysis

    return {
        "messages": [message],
        "data": data,
    }
class MichaelBurrySignal(BaseModel):
    """Schema returned by the LLM: a directional signal, a confidence score and free-text reasoning."""

    # Restricted to exactly these three labels
    signal: Literal["bullish", "bearish", "neutral"]
    confidence: float  # 0–100
    # Free-text explanation accompanying the signal
    reasoning: str
data["tickers"] 51 | 52 | # We look one year back for insider trades / news flow 53 | start_date = (datetime.fromisoformat(end_date) - timedelta(days=365)).date().isoformat() 54 | 55 | analysis_data: dict[str, dict] = {} 56 | burry_analysis: dict[str, dict] = {} 57 | 58 | for ticker in tickers: 59 | # ------------------------------------------------------------------ 60 | # Fetch raw data 61 | # ------------------------------------------------------------------ 62 | progress.update_status("michael_burry_agent", ticker, "Fetching financial metrics") 63 | metrics = get_financial_metrics(ticker, end_date, period="ttm", limit=5) 64 | 65 | progress.update_status("michael_burry_agent", ticker, "Fetching line items") 66 | line_items = search_line_items( 67 | ticker, 68 | [ 69 | "free_cash_flow", 70 | "net_income", 71 | "total_debt", 72 | "cash_and_equivalents", 73 | "total_assets", 74 | "total_liabilities", 75 | "outstanding_shares", 76 | "issuance_or_purchase_of_equity_shares", 77 | ], 78 | end_date, 79 | ) 80 | 81 | progress.update_status("michael_burry_agent", ticker, "Fetching insider trades") 82 | insider_trades = get_insider_trades(ticker, end_date=end_date, start_date=start_date) 83 | 84 | progress.update_status("michael_burry_agent", ticker, "Fetching company news") 85 | news = get_company_news(ticker, end_date=end_date, start_date=start_date, limit=250) 86 | 87 | progress.update_status("michael_burry_agent", ticker, "Fetching market cap") 88 | market_cap = get_market_cap(ticker, end_date) 89 | 90 | # ------------------------------------------------------------------ 91 | # Run sub‑analyses 92 | # ------------------------------------------------------------------ 93 | progress.update_status("michael_burry_agent", ticker, "Analyzing value") 94 | value_analysis = _analyze_value(metrics, line_items, market_cap) 95 | 96 | progress.update_status("michael_burry_agent", ticker, "Analyzing balance sheet") 97 | balance_sheet_analysis = _analyze_balance_sheet(metrics, 
line_items) 98 | 99 | progress.update_status("michael_burry_agent", ticker, "Analyzing insider activity") 100 | insider_analysis = _analyze_insider_activity(insider_trades) 101 | 102 | progress.update_status("michael_burry_agent", ticker, "Analyzing contrarian sentiment") 103 | contrarian_analysis = _analyze_contrarian_sentiment(news) 104 | 105 | # ------------------------------------------------------------------ 106 | # Aggregate score & derive preliminary signal 107 | # ------------------------------------------------------------------ 108 | total_score = ( 109 | value_analysis["score"] 110 | + balance_sheet_analysis["score"] 111 | + insider_analysis["score"] 112 | + contrarian_analysis["score"] 113 | ) 114 | max_score = ( 115 | value_analysis["max_score"] 116 | + balance_sheet_analysis["max_score"] 117 | + insider_analysis["max_score"] 118 | + contrarian_analysis["max_score"] 119 | ) 120 | 121 | if total_score >= 0.7 * max_score: 122 | signal = "bullish" 123 | elif total_score <= 0.3 * max_score: 124 | signal = "bearish" 125 | else: 126 | signal = "neutral" 127 | 128 | # ------------------------------------------------------------------ 129 | # Collect data for LLM reasoning & output 130 | # ------------------------------------------------------------------ 131 | analysis_data[ticker] = { 132 | "signal": signal, 133 | "score": total_score, 134 | "max_score": max_score, 135 | "value_analysis": value_analysis, 136 | "balance_sheet_analysis": balance_sheet_analysis, 137 | "insider_analysis": insider_analysis, 138 | "contrarian_analysis": contrarian_analysis, 139 | "market_cap": market_cap, 140 | } 141 | 142 | progress.update_status("michael_burry_agent", ticker, "Generating LLM output") 143 | burry_output = _generate_burry_output( 144 | ticker=ticker, 145 | analysis_data=analysis_data, 146 | model_name=state["metadata"]["model_name"], 147 | model_provider=state["metadata"]["model_provider"], 148 | ) 149 | 150 | burry_analysis[ticker] = { 151 | "signal": 
def _latest_line_item(line_items: list):
    """Return the first element of *line_items*, or ``None`` when it is empty.

    Callers treat the first element as the most recent reporting period
    (assumed ordering - confirm against ``search_line_items``).
    """
    return line_items[0] if line_items else None


# ----- Value ----------------------------------------------------------------

def _analyze_value(metrics, line_items, market_cap):
    """Score valuation from free-cash-flow yield and EV/EBIT.

    Args:
        metrics: Sequence of financial-metric objects; only
            ``metrics[0].ev_to_ebit`` is read.
        line_items: Sequence of line-item objects; only the first one's
            ``free_cash_flow`` is read.
        market_cap: Denominator of the FCF yield. A falsy value (``None`` or
            0) disables the yield check instead of dividing by zero.

    Returns:
        Dict with ``score`` (0-6), ``max_score`` (always 6) and ``details``
        (the individual findings joined with ``"; "``).
    """

    max_score = 6  # 4 pts for FCF yield, 2 pts for EV/EBIT
    score = 0
    details: list[str] = []

    # Free-cash-flow yield
    latest_item = _latest_line_item(line_items)
    fcf = getattr(latest_item, "free_cash_flow", None) if latest_item else None
    if fcf is not None and market_cap:
        fcf_yield = fcf / market_cap
        if fcf_yield >= 0.15:
            score += 4
            details.append(f"Extraordinary FCF yield {fcf_yield:.1%}")
        elif fcf_yield >= 0.12:
            score += 3
            details.append(f"Very high FCF yield {fcf_yield:.1%}")
        elif fcf_yield >= 0.08:
            score += 2
            details.append(f"Respectable FCF yield {fcf_yield:.1%}")
        else:
            details.append(f"Low FCF yield {fcf_yield:.1%}")
    else:
        details.append("FCF data unavailable")

    # EV/EBIT (from financial metrics)
    if metrics:
        ev_ebit = getattr(metrics[0], "ev_to_ebit", None)
        if ev_ebit is not None:
            if ev_ebit < 6:
                score += 2
                details.append(f"EV/EBIT {ev_ebit:.1f} (<6)")
            elif ev_ebit < 10:
                score += 1
                details.append(f"EV/EBIT {ev_ebit:.1f} (<10)")
            else:
                details.append(f"High EV/EBIT {ev_ebit:.1f}")
        else:
            details.append("EV/EBIT data unavailable")
    else:
        details.append("Financial metrics unavailable")

    return {"score": score, "max_score": max_score, "details": "; ".join(details)}


# ----- Balance sheet --------------------------------------------------------

def _analyze_balance_sheet(metrics, line_items):
    """Score leverage (debt-to-equity) and liquidity (cash versus total debt).

    Args:
        metrics: Sequence of financial-metric objects; only
            ``metrics[0].debt_to_equity`` is read.
        line_items: Sequence of line-item objects; only the first one's
            ``cash_and_equivalents`` and ``total_debt`` are read.

    Returns:
        Dict with ``score`` (0-3), ``max_score`` (always 3) and ``details``
        (the individual findings joined with ``"; "``).
    """

    max_score = 3  # 2 pts for leverage, 1 pt for a net cash position
    score = 0
    details: list[str] = []

    latest_metrics = metrics[0] if metrics else None
    latest_item = _latest_line_item(line_items)

    debt_to_equity = getattr(latest_metrics, "debt_to_equity", None) if latest_metrics else None
    if debt_to_equity is not None:
        if debt_to_equity < 0.5:
            score += 2
            details.append(f"Low D/E {debt_to_equity:.2f}")
        elif debt_to_equity < 1:
            score += 1
            details.append(f"Moderate D/E {debt_to_equity:.2f}")
        else:
            details.append(f"High leverage D/E {debt_to_equity:.2f}")
    else:
        details.append("Debt‑to‑equity data unavailable")

    # Quick liquidity sanity check (cash vs total debt). A missing line item
    # is reported the same way as missing fields, so the details never
    # silently omit the liquidity check.
    cash = getattr(latest_item, "cash_and_equivalents", None) if latest_item is not None else None
    total_debt = getattr(latest_item, "total_debt", None) if latest_item is not None else None
    if cash is not None and total_debt is not None:
        if cash > total_debt:
            score += 1
            details.append("Net cash position")
        else:
            details.append("Net debt position")
    else:
        details.append("Cash/debt data unavailable")

    return {"score": score, "max_score": max_score, "details": "; ".join(details)}


# ----- Insider activity -----------------------------------------------------

def _analyze_insider_activity(insider_trades):
    """Score net insider buying across the supplied trades.

    Args:
        insider_trades: Sequence of trade objects exposing
            ``transaction_shares`` (positive = bought, negative = sold,
            ``None`` treated as 0).

    Returns:
        Dict with ``score`` (0-2), ``max_score`` (always 2) and ``details``.
        One point for any net buying, two when net buying exceeds the shares
        sold.
    """

    max_score = 2
    score = 0
    details: list[str] = []

    if not insider_trades:
        details.append("No insider trade data")
        return {"score": score, "max_score": max_score, "details": "; ".join(details)}

    share_counts = [t.transaction_shares or 0 for t in insider_trades]
    shares_bought = sum(count for count in share_counts if count > 0)
    shares_sold = abs(sum(count for count in share_counts if count < 0))
    net = shares_bought - shares_sold

    if net > 0:
        # max(..., 1) keeps the ratio defined when nothing was sold.
        score += 2 if net / max(shares_sold, 1) > 1 else 1
        details.append(f"Net insider buying of {net:,} shares")
    elif net < 0:
        details.append("Net insider selling")
    else:
        # Balanced buys/sells (or no share counts at all) is not selling.
        details.append("No net insider buying or selling")

    return {"score": score, "max_score": max_score, "details": "; ".join(details)}


# ----- Contrarian sentiment -------------------------------------------------

def _analyze_contrarian_sentiment(news):
    """Award one point when at least five headlines carry negative sentiment.

    Args:
        news: Sequence of news objects exposing ``sentiment`` (a string or
            ``None``); matching is case-insensitive on "negative"/"bearish".

    Returns:
        Dict with ``score`` (0-1), ``max_score`` (always 1) and ``details``.
    """

    max_score = 1
    score = 0
    details: list[str] = []

    if not news:
        details.append("No recent news")
        return {"score": score, "max_score": max_score, "details": "; ".join(details)}

    # Count negative sentiment articles
    sentiment_negative_count = sum(
        1 for n in news if n.sentiment and n.sentiment.lower() in ("negative", "bearish")
    )

    if sentiment_negative_count >= 5:
        score += 1  # The more hated, the better (assuming fundamentals hold up)
        details.append(f"{sentiment_negative_count} negative headlines (contrarian opportunity)")
    else:
        details.append("Limited negative press")

    return {"score": score, "max_score": max_score, "details": "; ".join(details)}
354 | For example, if bearish: "FCF yield only 2.1%. Debt-to-equity concerning at 2.3. Management diluting shareholders. Pass." 355 | """, 356 | ), 357 | ( 358 | "human", 359 | """Based on the following data, create the investment signal as Michael Burry would: 360 | 361 | Analysis Data for {ticker}: 362 | {analysis_data} 363 | 364 | Return the trading signal in the following JSON format exactly: 365 | {{ 366 | "signal": "bullish" | "bearish" | "neutral", 367 | "confidence": float between 0 and 100, 368 | "reasoning": "string" 369 | }} 370 | """, 371 | ), 372 | ] 373 | ) 374 | 375 | prompt = template.invoke({"analysis_data": json.dumps(analysis_data, indent=2), "ticker": ticker}) 376 | 377 | # Default fallback signal in case parsing fails 378 | def create_default_michael_burry_signal(): 379 | return MichaelBurrySignal(signal="neutral", confidence=0.0, reasoning="Parsing error – defaulting to neutral") 380 | 381 | return call_llm( 382 | prompt=prompt, 383 | model_name=model_name, 384 | model_provider=model_provider, 385 | pydantic_model=MichaelBurrySignal, 386 | agent_name="michael_burry_agent", 387 | default_factory=create_default_michael_burry_signal, 388 | ) 389 | -------------------------------------------------------------------------------- /src/agents/portfolio_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | from langchain_core.messages import HumanMessage 3 | from langchain_core.prompts import ChatPromptTemplate 4 | 5 | from graph.state import AgentState, show_agent_reasoning 6 | from pydantic import BaseModel, Field 7 | from typing_extensions import Literal 8 | from utils.progress import progress 9 | from utils.llm import call_llm 10 | 11 | 12 | class PortfolioDecision(BaseModel): 13 | action: Literal["buy", "sell", "short", "cover", "hold"] 14 | quantity: int = Field(description="Number of shares to trade") 15 | confidence: float = Field(description="Confidence in the decision, between 0.0 
##### Portfolio Management Agent #####
def portfolio_management_agent(state: AgentState):
    """Turn the analysts' per-ticker signals into final trading decisions.

    Reads the portfolio, the analyst signals and the ticker list from
    ``state["data"]``, derives a per-ticker share cap from the risk manager's
    output, asks ``generate_trading_decision`` for the decisions and returns
    the state's messages extended with one JSON message.
    """
    shared = state["data"]
    portfolio = shared["portfolio"]
    analyst_signals = shared["analyst_signals"]
    tickers = shared["tickers"]

    progress.update_status("portfolio_management_agent", None, "Analyzing signals")

    # The risk manager's entry carries limits and prices, not a signal, so it
    # is read separately and excluded from the per-ticker signal map below.
    risk_by_ticker = analyst_signals.get("risk_management_agent", {})

    current_prices = {}
    max_shares = {}
    signals_by_ticker = {}
    for ticker in tickers:
        progress.update_status("portfolio_management_agent", ticker, "Processing analyst signals")

        risk_entry = risk_by_ticker.get(ticker, {})
        limit = risk_entry.get("remaining_position_limit", 0)
        price = risk_entry.get("current_price", 0)
        current_prices[ticker] = price

        # A missing or non-positive price means no shares may be traded.
        max_shares[ticker] = int(limit / price) if price > 0 else 0

        signals_by_ticker[ticker] = {
            agent: {
                "signal": per_ticker[ticker]["signal"],
                "confidence": per_ticker[ticker]["confidence"],
            }
            for agent, per_ticker in analyst_signals.items()
            if agent != "risk_management_agent" and ticker in per_ticker
        }

    progress.update_status("portfolio_management_agent", None, "Making trading decisions")

    result = generate_trading_decision(
        tickers=tickers,
        signals_by_ticker=signals_by_ticker,
        current_prices=current_prices,
        max_shares=max_shares,
        portfolio=portfolio,
        model_name=state["metadata"]["model_name"],
        model_provider=state["metadata"]["model_provider"],
    )

    # Serialise each decision model to a plain dict for the outgoing message.
    message = HumanMessage(
        content=json.dumps({symbol: decision.model_dump() for symbol, decision in result.decisions.items()}),
        name="portfolio_management",
    )

    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(
            {symbol: decision.model_dump() for symbol, decision in result.decisions.items()},
            "Portfolio Management Agent",
        )

    progress.update_status("portfolio_management_agent", None, "Done")

    return {
        "messages": state["messages"] + [message],
        "data": state["data"],
    }
107 | 108 | Trading Rules: 109 | - For long positions: 110 | * Only buy if you have available cash 111 | * Only sell if you currently hold long shares of that ticker 112 | * Sell quantity must be ≤ current long position shares 113 | * Buy quantity must be ≤ max_shares for that ticker 114 | 115 | - For short positions: 116 | * Only short if you have available margin (position value × margin requirement) 117 | * Only cover if you currently have short shares of that ticker 118 | * Cover quantity must be ≤ current short position shares 119 | * Short quantity must respect margin requirements 120 | 121 | - The max_shares values are pre-calculated to respect position limits 122 | - Consider both long and short opportunities based on signals 123 | - Maintain appropriate risk management with both long and short exposure 124 | 125 | Available Actions: 126 | - "buy": Open or add to long position 127 | - "sell": Close or reduce long position 128 | - "short": Open or add to short position 129 | - "cover": Close or reduce short position 130 | - "hold": No action 131 | 132 | Inputs: 133 | - signals_by_ticker: dictionary of ticker → signals 134 | - max_shares: maximum shares allowed per ticker 135 | - portfolio_cash: current cash in portfolio 136 | - portfolio_positions: current positions (both long and short) 137 | - current_prices: current prices for each ticker 138 | - margin_requirement: current margin requirement for short positions (e.g., 0.5 means 50%) 139 | - total_margin_used: total margin currently in use 140 | """, 141 | ), 142 | ( 143 | "human", 144 | """Based on the team's analysis, make your trading decisions for each ticker. 
145 | 146 | Here are the signals by ticker: 147 | {signals_by_ticker} 148 | 149 | Current Prices: 150 | {current_prices} 151 | 152 | Maximum Shares Allowed For Purchases: 153 | {max_shares} 154 | 155 | Portfolio Cash: {portfolio_cash} 156 | Current Positions: {portfolio_positions} 157 | Current Margin Requirement: {margin_requirement} 158 | Total Margin Used: {total_margin_used} 159 | 160 | Output strictly in JSON with the following structure: 161 | {{ 162 | "decisions": {{ 163 | "TICKER1": {{ 164 | "action": "buy/sell/short/cover/hold", 165 | "quantity": integer, 166 | "confidence": float between 0 and 100, 167 | "reasoning": "string" 168 | }}, 169 | "TICKER2": {{ 170 | ... 171 | }}, 172 | ... 173 | }} 174 | }} 175 | """, 176 | ), 177 | ] 178 | ) 179 | 180 | # Generate the prompt 181 | prompt = template.invoke( 182 | { 183 | "signals_by_ticker": json.dumps(signals_by_ticker, indent=2), 184 | "current_prices": json.dumps(current_prices, indent=2), 185 | "max_shares": json.dumps(max_shares, indent=2), 186 | "portfolio_cash": f"{portfolio.get('cash', 0):.2f}", 187 | "portfolio_positions": json.dumps(portfolio.get('positions', {}), indent=2), 188 | "margin_requirement": f"{portfolio.get('margin_requirement', 0):.2f}", 189 | "total_margin_used": f"{portfolio.get('margin_used', 0):.2f}", 190 | } 191 | ) 192 | 193 | # Create default factory for PortfolioManagerOutput 194 | def create_default_portfolio_output(): 195 | return PortfolioManagerOutput(decisions={ticker: PortfolioDecision(action="hold", quantity=0, confidence=0.0, reasoning="Error in portfolio management, defaulting to hold") for ticker in tickers}) 196 | 197 | return call_llm(prompt=prompt, model_name=model_name, model_provider=model_provider, pydantic_model=PortfolioManagerOutput, agent_name="portfolio_management_agent", default_factory=create_default_portfolio_output) 198 | -------------------------------------------------------------------------------- /src/agents/risk_manager.py: 
-------------------------------------------------------------------------------- 1 | from langchain_core.messages import HumanMessage 2 | from graph.state import AgentState, show_agent_reasoning 3 | from utils.progress import progress 4 | from tools.api import get_prices, prices_to_df 5 | import json 6 | 7 | 8 | ##### Risk Management Agent ##### 9 | def risk_management_agent(state: AgentState): 10 | """Controls position sizing based on real-world risk factors for multiple tickers.""" 11 | portfolio = state["data"]["portfolio"] 12 | data = state["data"] 13 | tickers = data["tickers"] 14 | 15 | # Initialize risk analysis for each ticker 16 | risk_analysis = {} 17 | current_prices = {} # Store prices here to avoid redundant API calls 18 | 19 | for ticker in tickers: 20 | progress.update_status("risk_management_agent", ticker, "Analyzing price data") 21 | 22 | prices = get_prices( 23 | ticker=ticker, 24 | start_date=data["start_date"], 25 | end_date=data["end_date"], 26 | ) 27 | 28 | if not prices: 29 | progress.update_status("risk_management_agent", ticker, "Failed: No price data found") 30 | continue 31 | 32 | prices_df = prices_to_df(prices) 33 | 34 | progress.update_status("risk_management_agent", ticker, "Calculating position limits") 35 | 36 | # Calculate portfolio value 37 | current_price = prices_df["close"].iloc[-1] 38 | current_prices[ticker] = current_price # Store the current price 39 | 40 | # Calculate current position value for this ticker 41 | current_position_value = portfolio.get("cost_basis", {}).get(ticker, 0) 42 | 43 | # Calculate total portfolio value using stored prices 44 | total_portfolio_value = portfolio.get("cash", 0) + sum(portfolio.get("cost_basis", {}).get(t, 0) for t in portfolio.get("cost_basis", {})) 45 | 46 | # Base limit is 20% of portfolio for any single position 47 | position_limit = total_portfolio_value * 0.20 48 | 49 | # For existing positions, subtract current position value from limit 50 | remaining_position_limit = 
##### Sentiment Agent #####
def sentiment_agent(state: AgentState):
    """Derive a per-ticker sentiment signal from insider trades and company news.

    For each ticker, every insider trade with a negative share count counts as
    one bearish vote and every other trade as one bullish vote; every news
    item votes according to its ``sentiment`` field ("negative" -> bearish,
    "positive" -> bullish, anything else -> neutral). Votes are combined with
    weights 0.3 (insider) and 0.7 (news).

    Args:
        state: Agent state; reads ``data.end_date``, ``data.tickers`` and
            ``metadata.show_reasoning``, and writes the result into
            ``data.analyst_signals["sentiment_agent"]``.

    Returns:
        Dict with ``messages`` (one JSON ``HumanMessage``) and ``data``.
    """
    data = state.get("data", {})
    end_date = data.get("end_date")
    # A missing/None ticker list means "nothing to analyse", not a TypeError.
    tickers = data.get("tickers") or []

    # Weights for combining the two vote sources.
    insider_weight = 0.3
    news_weight = 0.7

    # Initialize sentiment analysis for each ticker
    sentiment_analysis = {}

    for ticker in tickers:
        progress.update_status("sentiment_agent", ticker, "Fetching insider trades")

        # Get the insider trades
        insider_trades = get_insider_trades(
            ticker=ticker,
            end_date=end_date,
            limit=1000,
        )

        progress.update_status("sentiment_agent", ticker, "Analyzing trading patterns")

        # Trades without a share count are dropped; sells are bearish votes.
        transaction_shares = pd.Series([t.transaction_shares for t in insider_trades]).dropna()
        insider_signals = np.where(transaction_shares < 0, "bearish", "bullish").tolist()

        progress.update_status("sentiment_agent", ticker, "Fetching company news")

        # Get the company news
        company_news = get_company_news(ticker, end_date, limit=100)

        # News items without a sentiment label are dropped.
        sentiment = pd.Series([n.sentiment for n in company_news]).dropna()
        news_signals = np.where(
            sentiment == "negative",
            "bearish",
            np.where(sentiment == "positive", "bullish", "neutral"),
        ).tolist()

        progress.update_status("sentiment_agent", ticker, "Combining signals")

        # Calculate weighted signal counts
        bullish_signals = (
            insider_signals.count("bullish") * insider_weight
            + news_signals.count("bullish") * news_weight
        )
        bearish_signals = (
            insider_signals.count("bearish") * insider_weight
            + news_signals.count("bearish") * news_weight
        )

        if bullish_signals > bearish_signals:
            overall_signal = "bullish"
        elif bearish_signals > bullish_signals:
            overall_signal = "bearish"
        else:
            overall_signal = "neutral"

        # Confidence is the winning side's share of all weighted votes, as a
        # whole-number percentage.
        total_weighted_signals = len(insider_signals) * insider_weight + len(news_signals) * news_weight
        confidence = 0  # Default confidence when there are no signals
        if total_weighted_signals > 0:
            winning_share = max(bullish_signals, bearish_signals) / total_weighted_signals
            # Scale before rounding: round(x, 2) * 100 produces artefacts
            # such as 56.99999999999999 instead of 57.0.
            confidence = float(round(winning_share * 100))
        reasoning = f"Weighted Bullish signals: {bullish_signals:.1f}, Weighted Bearish signals: {bearish_signals:.1f}"

        sentiment_analysis[ticker] = {
            "signal": overall_signal,
            "confidence": confidence,
            "reasoning": reasoning,
        }

        progress.update_status("sentiment_agent", ticker, "Done")

    # Create the sentiment message
    message = HumanMessage(
        content=json.dumps(sentiment_analysis),
        name="sentiment_agent",
    )

    # Print the reasoning if the flag is set
    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(sentiment_analysis, "Sentiment Analysis Agent")

    # Add the signal to the analyst_signals list
    state["data"]["analyst_signals"]["sentiment_agent"] = sentiment_analysis

    return {
        "messages": [message],
        "data": data,
    }
Statistical Arbitrage Signals 24 | """ 25 | data = state["data"] 26 | start_date = data["start_date"] 27 | end_date = data["end_date"] 28 | tickers = data["tickers"] 29 | 30 | # Initialize analysis for each ticker 31 | technical_analysis = {} 32 | 33 | for ticker in tickers: 34 | progress.update_status("technical_analyst_agent", ticker, "Analyzing price data") 35 | 36 | # Get the historical price data 37 | prices = get_prices( 38 | ticker=ticker, 39 | start_date=start_date, 40 | end_date=end_date, 41 | ) 42 | 43 | if not prices: 44 | progress.update_status("technical_analyst_agent", ticker, "Failed: No price data found") 45 | continue 46 | 47 | # Convert prices to a DataFrame 48 | prices_df = prices_to_df(prices) 49 | 50 | progress.update_status("technical_analyst_agent", ticker, "Calculating trend signals") 51 | trend_signals = calculate_trend_signals(prices_df) 52 | 53 | progress.update_status("technical_analyst_agent", ticker, "Calculating mean reversion") 54 | mean_reversion_signals = calculate_mean_reversion_signals(prices_df) 55 | 56 | progress.update_status("technical_analyst_agent", ticker, "Calculating momentum") 57 | momentum_signals = calculate_momentum_signals(prices_df) 58 | 59 | progress.update_status("technical_analyst_agent", ticker, "Analyzing volatility") 60 | volatility_signals = calculate_volatility_signals(prices_df) 61 | 62 | progress.update_status("technical_analyst_agent", ticker, "Statistical analysis") 63 | stat_arb_signals = calculate_stat_arb_signals(prices_df) 64 | 65 | # Combine all signals using a weighted ensemble approach 66 | strategy_weights = { 67 | "trend": 0.25, 68 | "mean_reversion": 0.20, 69 | "momentum": 0.25, 70 | "volatility": 0.15, 71 | "stat_arb": 0.15, 72 | } 73 | 74 | progress.update_status("technical_analyst_agent", ticker, "Combining signals") 75 | combined_signal = weighted_signal_combination( 76 | { 77 | "trend": trend_signals, 78 | "mean_reversion": mean_reversion_signals, 79 | "momentum": momentum_signals, 80 | 
"volatility": volatility_signals, 81 | "stat_arb": stat_arb_signals, 82 | }, 83 | strategy_weights, 84 | ) 85 | 86 | # Generate detailed analysis report for this ticker 87 | technical_analysis[ticker] = { 88 | "signal": combined_signal["signal"], 89 | "confidence": round(combined_signal["confidence"] * 100), 90 | "strategy_signals": { 91 | "trend_following": { 92 | "signal": trend_signals["signal"], 93 | "confidence": round(trend_signals["confidence"] * 100), 94 | "metrics": normalize_pandas(trend_signals["metrics"]), 95 | }, 96 | "mean_reversion": { 97 | "signal": mean_reversion_signals["signal"], 98 | "confidence": round(mean_reversion_signals["confidence"] * 100), 99 | "metrics": normalize_pandas(mean_reversion_signals["metrics"]), 100 | }, 101 | "momentum": { 102 | "signal": momentum_signals["signal"], 103 | "confidence": round(momentum_signals["confidence"] * 100), 104 | "metrics": normalize_pandas(momentum_signals["metrics"]), 105 | }, 106 | "volatility": { 107 | "signal": volatility_signals["signal"], 108 | "confidence": round(volatility_signals["confidence"] * 100), 109 | "metrics": normalize_pandas(volatility_signals["metrics"]), 110 | }, 111 | "statistical_arbitrage": { 112 | "signal": stat_arb_signals["signal"], 113 | "confidence": round(stat_arb_signals["confidence"] * 100), 114 | "metrics": normalize_pandas(stat_arb_signals["metrics"]), 115 | }, 116 | }, 117 | } 118 | progress.update_status("technical_analyst_agent", ticker, "Done") 119 | 120 | # Create the technical analyst message 121 | message = HumanMessage( 122 | content=json.dumps(technical_analysis), 123 | name="technical_analyst_agent", 124 | ) 125 | 126 | if state["metadata"]["show_reasoning"]: 127 | show_agent_reasoning(technical_analysis, "Technical Analyst") 128 | 129 | # Add the signal to the analyst_signals list 130 | state["data"]["analyst_signals"]["technical_analyst_agent"] = technical_analysis 131 | 132 | return { 133 | "messages": state["messages"] + [message], 134 | "data": data, 
def calculate_trend_signals(prices_df):
    """
    Trend-following signal built from an EMA stack and ADX trend strength.

    The 8/21/55-period EMAs give a short-term (8 vs 21) and a medium-term
    (21 vs 55) direction. When both agree the signal follows them and the
    confidence is the latest 14-period ADX scaled by 1/100; when they
    disagree the signal is neutral with a fixed confidence of 0.5.

    Args:
        prices_df: Price DataFrame handed to ``calculate_ema`` and
            ``calculate_adx``.

    Returns:
        dict with ``signal``, ``confidence`` and a ``metrics`` dict holding
        the latest ``adx`` and ``trend_strength``.
    """
    # EMAs first, ADX second - same call order as before.
    fast_ema, mid_ema, slow_ema = (calculate_ema(prices_df, span) for span in (8, 21, 55))
    latest_adx = calculate_adx(prices_df, 14)["adx"].iloc[-1]
    strength = latest_adx / 100.0

    # Only the most recent bar decides the direction.
    short_up = fast_ema.iloc[-1] > mid_ema.iloc[-1]
    medium_up = mid_ema.iloc[-1] > slow_ema.iloc[-1]

    if short_up == medium_up:
        # Both horizons agree: follow them, weighted by trend strength.
        signal = "bullish" if short_up else "bearish"
        confidence = strength
    else:
        signal, confidence = "neutral", 0.5

    return {
        "signal": signal,
        "confidence": confidence,
        "metrics": {
            "adx": float(latest_adx),
            "trend_strength": float(strength),
        },
    }
def calculate_momentum_signals(prices_df):
    """
    Multi-factor momentum strategy.

    Sums daily close-to-close returns over 21, 63 and 126 rows, blends the
    latest values with weights 0.4 / 0.3 / 0.3, and requires the latest
    volume to exceed its 21-row average before taking a directional view.

    Args:
        prices_df: DataFrame with ``close`` and ``volume`` columns.

    Returns:
        dict with ``signal``, ``confidence`` and a ``metrics`` dict
        (``momentum_1m``, ``momentum_3m``, ``momentum_6m``,
        ``volume_momentum``).
    """
    close, volume = prices_df["close"], prices_df["volume"]
    daily_returns = close.pct_change()

    # Latest rolling return sum per horizon (NaN while history is too short).
    latest = {
        label: daily_returns.rolling(window).sum().iloc[-1]
        for label, window in (("1m", 21), ("3m", 63), ("6m", 126))
    }

    # Latest volume relative to its 21-row average.
    relative_volume = (volume / volume.rolling(21).mean()).iloc[-1]

    # Relative strength
    # (would compare to market/sector in real implementation)

    score = 0.4 * latest["1m"] + 0.3 * latest["3m"] + 0.3 * latest["6m"]

    # A directional call needs both a large enough score and above-average
    # volume; a NaN score fails the comparison and falls through to neutral.
    if abs(score) > 0.05 and relative_volume > 1.0:
        signal = "bullish" if score > 0 else "bearish"
        confidence = min(abs(score) * 5, 1.0)
    else:
        signal, confidence = "neutral", 0.5

    return {
        "signal": signal,
        "confidence": confidence,
        "metrics": {
            "momentum_1m": float(latest["1m"]),
            "momentum_3m": float(latest["3m"]),
            "momentum_6m": float(latest["6m"]),
            "volume_momentum": float(relative_volume),
        },
    }
def weighted_signal_combination(signals, weights):
    """
    Combine multiple trading signals into one confidence-weighted signal.

    Each strategy votes +1 (bullish), 0 (neutral) or -1 (bearish); votes are
    weighted by ``weights[strategy] * confidence`` and the weighted mean is
    mapped back to a label using a +/-0.2 dead band.

    Strategies whose confidence is missing or not finite (NaN/inf, e.g. when
    an indicator had too little history) are left out of the vote. Without
    that guard one NaN would turn the combined confidence into NaN.

    Args:
        signals: Mapping of strategy name to a dict with ``signal`` and
            ``confidence`` entries.
        weights: Mapping of strategy name to its ensemble weight.

    Returns:
        dict with ``signal`` and ``confidence`` (absolute weighted score,
        in the range 0-1 for non-negative weights and confidences).

    Raises:
        KeyError: If a strategy has no weight or carries an unknown label.
    """
    # Convert signals to numeric values
    signal_values = {"bullish": 1, "neutral": 0, "bearish": -1}

    weighted_sum = 0
    total_confidence = 0

    for strategy, signal in signals.items():
        # Look everything up first so malformed input still fails loudly.
        numeric_signal = signal_values[signal["signal"]]
        weight = weights[strategy]
        confidence = signal["confidence"]

        if confidence is None or not math.isfinite(confidence):
            continue  # an undefined confidence carries no information

        weighted_sum += numeric_signal * weight * confidence
        total_confidence += weight * confidence

    # Normalize the weighted sum
    final_score = weighted_sum / total_confidence if total_confidence > 0 else 0

    # Convert back to signal
    if final_score > 0.2:
        signal = "bullish"
    elif final_score < -0.2:
        signal = "bearish"
    else:
        signal = "neutral"

    return {"signal": signal, "confidence": abs(final_score)}
def calculate_rsi(prices_df: pd.DataFrame, period: int = 14) -> pd.Series:
    """
    Calculate the Relative Strength Index from closing prices.

    Uses simple rolling means of the up-moves and down-moves of ``close``.

    Args:
        prices_df: DataFrame with a ``close`` column
        period: Rolling window length

    Returns:
        pd.Series: RSI values (NaN until ``period`` rows are available)
    """
    change = prices_df["close"].diff()

    # Split each change into its up-move and (positive) down-move component.
    ups = change.where(change > 0, 0).fillna(0)
    downs = (-change.where(change < 0, 0)).fillna(0)

    relative_strength = ups.rolling(window=period).mean() / downs.rolling(window=period).mean()
    return 100 - (100 / (1 + relative_strength))
def calculate_hurst_exponent(price_series: pd.Series, max_lag: int = 20) -> float:
    """
    Calculate Hurst Exponent to determine long-term memory of time series
    H < 0.5: Mean reverting series
    H = 0.5: Random walk
    H > 0.5: Trending series

    The exponent is the slope of log(std of lag-k differences) against
    log(k) for k in ``2 .. max_lag - 1``.

    Args:
        price_series: Array-like price data
        max_lag: Maximum lag (exclusive) for the lagged-difference calculation

    Returns:
        float: Hurst exponent, or 0.5 (random walk) when it cannot be
        estimated (too few points, non-finite prices, or a failed fit)
    """
    # Work on a plain float array. Slices of a pandas Series keep their
    # original labels and NumPy ufuncs align Series operands by label, so
    # subtracting two offset slices of a Series yields only zeros and NaNs
    # rather than lagged differences.
    values = np.asarray(price_series, dtype=float)
    lags = np.arange(2, max_lag)

    # Need at least two lags for a line fit and two differences per lag.
    if lags.size < 2 or values.size <= max_lag or not np.isfinite(values).all():
        return 0.5

    # Floor at a small epsilon to avoid log(0) on flat stretches.
    tau = [max(1e-8, float(np.std(values[lag:] - values[:-lag]))) for lag in lags]

    # std(x[t + k] - x[t]) scales like k ** H, so the log-log slope is H itself.
    try:
        slope = np.polyfit(np.log(lags), np.log(tau), 1)[0]
    except (ValueError, np.linalg.LinAlgError):
        # Return 0.5 (random walk) if calculation fails
        return 0.5

    return float(slope) if np.isfinite(slope) else 0.5
"working_capital", 55 | ], 56 | end_date=end_date, 57 | period="ttm", 58 | limit=2, 59 | ) 60 | if len(line_items) < 2: 61 | progress.update_status("valuation_agent", ticker, "Failed: Insufficient financial line items") 62 | continue 63 | li_curr, li_prev = line_items[0], line_items[1] 64 | 65 | # ------------------------------------------------------------------ 66 | # Valuation models 67 | # ------------------------------------------------------------------ 68 | wc_change = li_curr.working_capital - li_prev.working_capital 69 | 70 | # Owner Earnings 71 | owner_val = calculate_owner_earnings_value( 72 | net_income=li_curr.net_income, 73 | depreciation=li_curr.depreciation_and_amortization, 74 | capex=li_curr.capital_expenditure, 75 | working_capital_change=wc_change, 76 | growth_rate=most_recent_metrics.earnings_growth or 0.05, 77 | ) 78 | 79 | # Discounted Cash Flow 80 | dcf_val = calculate_intrinsic_value( 81 | free_cash_flow=li_curr.free_cash_flow, 82 | growth_rate=most_recent_metrics.earnings_growth or 0.05, 83 | discount_rate=0.10, 84 | terminal_growth_rate=0.03, 85 | num_years=5, 86 | ) 87 | 88 | # Implied Equity Value 89 | ev_ebitda_val = calculate_ev_ebitda_value(financial_metrics) 90 | 91 | # Residual Income Model 92 | rim_val = calculate_residual_income_value( 93 | market_cap=most_recent_metrics.market_cap, 94 | net_income=li_curr.net_income, 95 | price_to_book_ratio=most_recent_metrics.price_to_book_ratio, 96 | book_value_growth=most_recent_metrics.book_value_growth or 0.03, 97 | ) 98 | 99 | # ------------------------------------------------------------------ 100 | # Aggregate & signal 101 | # ------------------------------------------------------------------ 102 | market_cap = get_market_cap(ticker, end_date) 103 | if not market_cap: 104 | progress.update_status("valuation_agent", ticker, "Failed: Market cap unavailable") 105 | continue 106 | 107 | method_values = { 108 | "dcf": {"value": dcf_val, "weight": 0.35}, 109 | "owner_earnings": {"value": 
owner_val, "weight": 0.35}, 110 | "ev_ebitda": {"value": ev_ebitda_val, "weight": 0.20}, 111 | "residual_income": {"value": rim_val, "weight": 0.10}, 112 | } 113 | 114 | total_weight = sum(v["weight"] for v in method_values.values() if v["value"] > 0) 115 | if total_weight == 0: 116 | progress.update_status("valuation_agent", ticker, "Failed: All valuation methods zero") 117 | continue 118 | 119 | for v in method_values.values(): 120 | v["gap"] = (v["value"] - market_cap) / market_cap if v["value"] > 0 else None 121 | 122 | weighted_gap = sum( 123 | v["weight"] * v["gap"] for v in method_values.values() if v["gap"] is not None 124 | ) / total_weight 125 | 126 | signal = "bullish" if weighted_gap > 0.15 else "bearish" if weighted_gap < -0.15 else "neutral" 127 | confidence = round(min(abs(weighted_gap) / 0.30 * 100, 100)) 128 | 129 | reasoning = { 130 | f"{m}_analysis": { 131 | "signal": ( 132 | "bullish" if vals["gap"] and vals["gap"] > 0.15 else 133 | "bearish" if vals["gap"] and vals["gap"] < -0.15 else "neutral" 134 | ), 135 | "details": ( 136 | f"Value: ${vals['value']:,.2f}, Market Cap: ${market_cap:,.2f}, " 137 | f"Gap: {vals['gap']:.1%}, Weight: {vals['weight']*100:.0f}%" 138 | ), 139 | } 140 | for m, vals in method_values.items() if vals["value"] > 0 141 | } 142 | 143 | valuation_analysis[ticker] = { 144 | "signal": signal, 145 | "confidence": confidence, 146 | "reasoning": reasoning, 147 | } 148 | progress.update_status("valuation_agent", ticker, "Done") 149 | 150 | # ---- Emit message (for LLM tool chain) ---- 151 | msg = HumanMessage(content=json.dumps(valuation_analysis), name="valuation_agent") 152 | if state["metadata"].get("show_reasoning"): 153 | show_agent_reasoning(valuation_analysis, "Valuation Analysis Agent") 154 | state["data"]["analyst_signals"]["valuation_agent"] = valuation_analysis 155 | return {"messages": [msg], "data": data} 156 | 157 | ############################# 158 | # Helper Valuation Functions 159 | ############################# 
def calculate_intrinsic_value(
    free_cash_flow: float | None,
    growth_rate: float = 0.05,
    discount_rate: float = 0.10,
    terminal_growth_rate: float = 0.02,
    num_years: int = 5,
) -> float:
    """Classic DCF on FCF with constant growth and terminal value.

    Projects ``free_cash_flow`` forward at ``growth_rate`` for ``num_years``,
    discounts each year at ``discount_rate`` and adds a discounted
    Gordon-growth terminal value.

    Returns 0 when the model cannot produce a meaningful value: missing or
    non-positive free cash flow, or ``discount_rate <= terminal_growth_rate``
    (the terminal value would divide by zero or turn negative).
    """
    if free_cash_flow is None or free_cash_flow <= 0:
        return 0

    # The perpetuity formula is only defined for discount rate > growth rate.
    if discount_rate <= terminal_growth_rate:
        return 0

    # Present value of the explicit forecast years.
    pv = 0.0
    for yr in range(1, num_years + 1):
        fcft = free_cash_flow * (1 + growth_rate) ** yr
        pv += fcft / (1 + discount_rate) ** yr

    # Terminal value at the end of the forecast, discounted back to today.
    term_val = (
        free_cash_flow * (1 + growth_rate) ** num_years * (1 + terminal_growth_rate)
    ) / (discount_rate - terminal_growth_rate)
    pv_term = term_val / (1 + discount_rate) ** num_years

    return pv + pv_term
def calculate_residual_income_value(
    market_cap: float | None,
    net_income: float | None,
    price_to_book_ratio: float | None,
    book_value_growth: float = 0.03,
    cost_of_equity: float = 0.10,
    terminal_growth_rate: float = 0.03,
    num_years: int = 5,
) -> float:
    """Residual Income Model (Edwards‑Bell‑Ohlson).

    Book value is backed out as ``market_cap / price_to_book_ratio``; the
    residual income (net income minus the equity charge on book value) is
    grown at ``book_value_growth``, discounted at ``cost_of_equity`` and
    capped with a perpetuity. The result carries a 20% margin of safety.

    Returns 0 when an input is missing/zero, the price-to-book ratio is not
    positive, the starting residual income is not positive, or
    ``cost_of_equity <= terminal_growth_rate`` (the terminal value would
    divide by zero or turn negative).
    """
    if not (market_cap and net_income and price_to_book_ratio and price_to_book_ratio > 0):
        return 0

    # The terminal perpetuity is only defined for cost of equity > growth.
    if cost_of_equity <= terminal_growth_rate:
        return 0

    book_val = market_cap / price_to_book_ratio
    ri0 = net_income - cost_of_equity * book_val
    if ri0 <= 0:
        return 0

    # Present value of residual income over the explicit forecast years.
    pv_ri = 0.0
    for yr in range(1, num_years + 1):
        ri_t = ri0 * (1 + book_value_growth) ** yr
        pv_ri += ri_t / (1 + cost_of_equity) ** yr

    # Terminal residual income, discounted back to today.
    term_ri = ri0 * (1 + book_value_growth) ** (num_years + 1) / (
        cost_of_equity - terminal_growth_rate
    )
    pv_term = term_ri / (1 + cost_of_equity) ** num_years

    intrinsic = book_val + pv_ri + pv_term
    return intrinsic * 0.8  # 20% margin of safety
class Cache:
    """In-memory cache for API responses.

    Every data kind is stored per ticker as a list of dict records. Setters
    merge new records into what is already cached; a record whose key field
    matches a cached record is ignored, so the first stored copy wins.
    Getters return ``None`` for a ticker that has never been cached.
    """

    def __init__(self):
        # ticker -> list of records, one store per data kind
        self._prices_cache: dict[str, list[dict]] = {}
        self._financial_metrics_cache: dict[str, list[dict]] = {}
        self._line_items_cache: dict[str, list[dict]] = {}
        self._insider_trades_cache: dict[str, list[dict]] = {}
        self._company_news_cache: dict[str, list[dict]] = {}

    def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]:
        """Merge existing and new data, avoiding duplicates based on a key field.

        Always returns a new list, so the cache never shares a list object
        with the caller.
        """
        if not existing:
            # Copy so later mutation of the caller's list cannot alter the cache.
            return list(new_data)

        # Create a set of existing keys for O(1) lookup
        existing_keys = {item[key_field] for item in existing}

        # Only add items that don't exist yet
        merged = existing.copy()
        merged.extend(item for item in new_data if item[key_field] not in existing_keys)
        return merged

    def get_prices(self, ticker: str) -> list[dict] | None:
        """Get cached price data if available."""
        return self._prices_cache.get(ticker)

    def set_prices(self, ticker: str, data: list[dict]):
        """Append new price data to cache (records keyed by ``time``)."""
        self._prices_cache[ticker] = self._merge_data(self._prices_cache.get(ticker), data, key_field="time")

    def get_financial_metrics(self, ticker: str) -> list[dict] | None:
        """Get cached financial metrics if available."""
        return self._financial_metrics_cache.get(ticker)

    def set_financial_metrics(self, ticker: str, data: list[dict]):
        """Append new financial metrics to cache (records keyed by ``report_period``)."""
        self._financial_metrics_cache[ticker] = self._merge_data(self._financial_metrics_cache.get(ticker), data, key_field="report_period")

    def get_line_items(self, ticker: str) -> list[dict] | None:
        """Get cached line items if available."""
        return self._line_items_cache.get(ticker)

    def set_line_items(self, ticker: str, data: list[dict]):
        """Append new line items to cache (records keyed by ``report_period``)."""
        self._line_items_cache[ticker] = self._merge_data(self._line_items_cache.get(ticker), data, key_field="report_period")

    def get_insider_trades(self, ticker: str) -> list[dict] | None:
        """Get cached insider trades if available."""
        return self._insider_trades_cache.get(ticker)

    def set_insider_trades(self, ticker: str, data: list[dict]):
        """Append new insider trades to cache (records keyed by ``filing_date``)."""
        self._insider_trades_cache[ticker] = self._merge_data(self._insider_trades_cache.get(ticker), data, key_field="filing_date")  # Could also use transaction_date if preferred

    def get_company_news(self, ticker: str) -> list[dict] | None:
        """Get cached company news if available."""
        return self._company_news_cache.get(ticker)

    def set_company_news(self, ticker: str, data: list[dict]):
        """Append new company news to cache (records keyed by ``date``)."""
        self._company_news_cache[ticker] = self._merge_data(self._company_news_cache.get(ticker), data, key_field="date")
price_to_book_ratio: float | None 27 | price_to_sales_ratio: float | None 28 | enterprise_value_to_ebitda_ratio: float | None 29 | enterprise_value_to_revenue_ratio: float | None 30 | free_cash_flow_yield: float | None 31 | peg_ratio: float | None 32 | gross_margin: float | None 33 | operating_margin: float | None 34 | net_margin: float | None 35 | return_on_equity: float | None 36 | return_on_assets: float | None 37 | return_on_invested_capital: float | None 38 | asset_turnover: float | None 39 | inventory_turnover: float | None 40 | receivables_turnover: float | None 41 | days_sales_outstanding: float | None 42 | operating_cycle: float | None 43 | working_capital_turnover: float | None 44 | current_ratio: float | None 45 | quick_ratio: float | None 46 | cash_ratio: float | None 47 | operating_cash_flow_ratio: float | None 48 | debt_to_equity: float | None 49 | debt_to_assets: float | None 50 | interest_coverage: float | None 51 | revenue_growth: float | None 52 | earnings_growth: float | None 53 | book_value_growth: float | None 54 | earnings_per_share_growth: float | None 55 | free_cash_flow_growth: float | None 56 | operating_income_growth: float | None 57 | ebitda_growth: float | None 58 | payout_ratio: float | None 59 | earnings_per_share: float | None 60 | book_value_per_share: float | None 61 | free_cash_flow_per_share: float | None 62 | 63 | 64 | class FinancialMetricsResponse(BaseModel): 65 | financial_metrics: list[FinancialMetrics] 66 | 67 | 68 | class LineItem(BaseModel): 69 | ticker: str 70 | report_period: str 71 | period: str 72 | currency: str 73 | 74 | # Allow additional fields dynamically 75 | model_config = {"extra": "allow"} 76 | 77 | 78 | class LineItemResponse(BaseModel): 79 | search_results: list[LineItem] 80 | 81 | 82 | class InsiderTrade(BaseModel): 83 | ticker: str 84 | issuer: str | None 85 | name: str | None 86 | title: str | None 87 | is_board_director: bool | None 88 | transaction_date: str | None 89 | transaction_shares: float | None 
90 | transaction_price_per_share: float | None 91 | transaction_value: float | None 92 | shares_owned_before_transaction: float | None 93 | shares_owned_after_transaction: float | None 94 | security_title: str | None 95 | filing_date: str 96 | 97 | 98 | class InsiderTradeResponse(BaseModel): 99 | insider_trades: list[InsiderTrade] 100 | 101 | 102 | class CompanyNews(BaseModel): 103 | ticker: str 104 | title: str 105 | author: str 106 | source: str 107 | date: str 108 | url: str 109 | sentiment: str | None = None 110 | 111 | 112 | class CompanyNewsResponse(BaseModel): 113 | news: list[CompanyNews] 114 | 115 | 116 | class CompanyFacts(BaseModel): 117 | ticker: str 118 | name: str 119 | cik: str | None = None 120 | industry: str | None = None 121 | sector: str | None = None 122 | category: str | None = None 123 | exchange: str | None = None 124 | is_active: bool | None = None 125 | listing_date: str | None = None 126 | location: str | None = None 127 | market_cap: float | None = None 128 | number_of_employees: int | None = None 129 | sec_filings_url: str | None = None 130 | sic_code: str | None = None 131 | sic_industry: str | None = None 132 | sic_sector: str | None = None 133 | website_url: str | None = None 134 | weighted_average_shares: int | None = None 135 | 136 | 137 | class CompanyFactsResponse(BaseModel): 138 | company_facts: CompanyFacts 139 | 140 | 141 | class Position(BaseModel): 142 | cash: float = 0.0 143 | shares: int = 0 144 | ticker: str 145 | 146 | 147 | class Portfolio(BaseModel): 148 | positions: dict[str, Position] # ticker -> Position mapping 149 | total_cash: float = 0.0 150 | 151 | 152 | class AnalystSignal(BaseModel): 153 | signal: str | None = None 154 | confidence: float | None = None 155 | reasoning: dict | str | None = None 156 | max_position_size: float | None = None # For risk management signals 157 | 158 | 159 | class TickerAnalysis(BaseModel): 160 | ticker: str 161 | analyst_signals: dict[str, AnalystSignal] # agent_name -> signal 
def show_agent_reasoning(output, agent_name):
    """Pretty-print an agent's output between banner lines on stdout.

    Args:
        output: A dict or list (converted to JSON-friendly values and dumped
            with 2-space indentation), a JSON string (re-indented), or
            anything else (printed as-is).
        agent_name: Label centred in the header banner.
    """
    print(f"\n{'=' * 10} {agent_name.center(28)} {'=' * 10}")

    def convert_to_serializable(obj):
        # Primitives and None map directly onto JSON values.
        if obj is None or isinstance(obj, (int, float, bool, str)):
            return obj
        if isinstance(obj, dict):
            return {key: convert_to_serializable(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [convert_to_serializable(item) for item in obj]
        # Convert recursively: the mapping an object exposes may itself hold
        # values json cannot encode.
        if hasattr(obj, "to_dict"):  # Handle Pandas Series/DataFrame
            return convert_to_serializable(obj.to_dict())
        if hasattr(obj, "__dict__"):  # Handle custom objects
            return convert_to_serializable(vars(obj))
        return str(obj)  # Fallback to string representation

    if isinstance(output, (dict, list)):
        # Convert the output to JSON-serializable format
        print(json.dumps(convert_to_serializable(output), indent=2))
    else:
        try:
            # Parse the string as JSON and pretty print it
            parsed_output = json.loads(output)
        except (TypeError, json.JSONDecodeError):
            # Not a string at all, or not valid JSON: print it unchanged
            print(output)
        else:
            print(json.dumps(parsed_output, indent=2))

    print("=" * 48)
49 | return self.model_name.startswith("gemini") 50 | 51 | def is_ollama(self) -> bool: 52 | """Check if the model is an Ollama model""" 53 | return self.provider == ModelProvider.OLLAMA 54 | 55 | 56 | # Define available models 57 | AVAILABLE_MODELS = [ 58 | LLMModel( 59 | display_name="[anthropic] claude-3.5-haiku", 60 | model_name="claude-3-5-haiku-latest", 61 | provider=ModelProvider.ANTHROPIC 62 | ), 63 | LLMModel( 64 | display_name="[anthropic] claude-3.5-sonnet", 65 | model_name="claude-3-5-sonnet-latest", 66 | provider=ModelProvider.ANTHROPIC 67 | ), 68 | LLMModel( 69 | display_name="[anthropic] claude-3.7-sonnet", 70 | model_name="claude-3-7-sonnet-latest", 71 | provider=ModelProvider.ANTHROPIC 72 | ), 73 | LLMModel( 74 | display_name="[deepseek] deepseek-r1", 75 | model_name="deepseek-reasoner", 76 | provider=ModelProvider.DEEPSEEK 77 | ), 78 | LLMModel( 79 | display_name="[deepseek] deepseek-v3", 80 | model_name="deepseek-chat", 81 | provider=ModelProvider.DEEPSEEK 82 | ), 83 | LLMModel( 84 | display_name="[gemini] gemini-2.0-flash", 85 | model_name="gemini-2.0-flash", 86 | provider=ModelProvider.GEMINI 87 | ), 88 | LLMModel( 89 | display_name="[gemini] gemini-2.5-pro", 90 | model_name="gemini-2.5-pro-exp-03-25", 91 | provider=ModelProvider.GEMINI 92 | ), 93 | LLMModel( 94 | display_name="[groq] llama-4-scout-17b", 95 | model_name="meta-llama/llama-4-scout-17b-16e-instruct", 96 | provider=ModelProvider.GROQ 97 | ), 98 | LLMModel( 99 | display_name="[groq] llama-4-maverick-17b", 100 | model_name="meta-llama/llama-4-maverick-17b-128e-instruct", 101 | provider=ModelProvider.GROQ 102 | ), 103 | LLMModel( 104 | display_name="[openai] gpt-4.5", 105 | model_name="gpt-4.5-preview", 106 | provider=ModelProvider.OPENAI 107 | ), 108 | LLMModel( 109 | display_name="[openai] gpt-4o", 110 | model_name="gpt-4o", 111 | provider=ModelProvider.OPENAI 112 | ), 113 | LLMModel( 114 | display_name="[openai] o3", 115 | model_name="o3", 116 | provider=ModelProvider.OPENAI 117 | 
def get_model(
    model_name: str, model_provider: ModelProvider
) -> ChatOpenAI | ChatGroq | ChatOllama | ChatAnthropic | ChatDeepSeek | ChatGoogleGenerativeAI | None:
    """Instantiate the chat model client for ``model_name`` on ``model_provider``.

    Args:
        model_name: Provider-specific model identifier passed to the client.
        model_provider: A ``ModelProvider`` member (or an equal string value).

    Returns:
        The configured chat client, or ``None`` when ``model_provider`` does
        not match any supported provider.

    Raises:
        ValueError: If the provider needs an API key and the corresponding
            environment variable is unset or empty.
    """
    # Ollama is addressed by base URL rather than authenticated with a key.
    if model_provider == ModelProvider.OLLAMA:
        # OLLAMA_HOST only changes the host part (e.g. Docker on macOS);
        # OLLAMA_BASE_URL overrides the whole URL.
        ollama_host = os.getenv("OLLAMA_HOST", "localhost")
        base_url = os.getenv("OLLAMA_BASE_URL", f"http://{ollama_host}:11434")
        return ChatOllama(
            model=model_name,
            base_url=base_url,
        )

    # (provider, label used in the error text, env var holding the key, client class)
    hosted_providers = (
        (ModelProvider.GROQ, "Groq", "GROQ_API_KEY", ChatGroq),
        (ModelProvider.OPENAI, "OpenAI", "OPENAI_API_KEY", ChatOpenAI),
        (ModelProvider.ANTHROPIC, "Anthropic", "ANTHROPIC_API_KEY", ChatAnthropic),
        (ModelProvider.DEEPSEEK, "DeepSeek", "DEEPSEEK_API_KEY", ChatDeepSeek),
        (ModelProvider.GEMINI, "Google", "GOOGLE_API_KEY", ChatGoogleGenerativeAI),
    )

    # Compare with == (not a dict lookup) so plain string provider values
    # keep matching exactly as they did in the if/elif chain.
    for provider, label, env_var, client_cls in hosted_providers:
        if model_provider != provider:
            continue
        api_key = os.getenv(env_var)
        if not api_key:
            # Print error to console before raising so it is visible even
            # if a caller swallows the exception.
            print(f"API Key Error: Please make sure {env_var} is set in your .env file.")
            raise ValueError(f"{label} API key not found. Please make sure {env_var} is set in your .env file.")
        return client_cls(model=model_name, api_key=api_key)

    # Unknown provider: preserved behaviour is to return None.
    return None
##### Run the Hedge Fund #####
def run_hedge_fund(
    tickers: list[str],
    start_date: str,
    end_date: str,
    portfolio: dict,
    show_reasoning: bool = False,
    selected_analysts: list[str] | None = None,
    model_name: str = "gpt-4o",
    model_provider: str = "OpenAI",
):
    """Run the agent workflow once and return its decisions and signals.

    Args:
        tickers: Ticker symbols to analyse.
        start_date: Start of the analysis window, passed through to the agents.
        end_date: End of the analysis window, passed through to the agents.
        portfolio: Portfolio state handed to the agents under ``data``.
        show_reasoning: Stored in the run metadata for the agents to read.
        selected_analysts: Analyst keys to include. ``None`` or an empty list
            runs every analyst known to ``create_workflow``.
        model_name: Stored in the run metadata.
        model_provider: Stored in the run metadata.

    Returns:
        A dict with ``"decisions"`` (the last message's content parsed as
        JSON, or ``None`` if parsing failed) and ``"analyst_signals"``.
    """
    # Start progress tracking
    progress.start()

    try:
        # Always build the graph here. The previous fallback read a
        # module-level ``app`` that is only assigned when this file runs as a
        # script, so importing and calling this function without analysts
        # raised NameError. ``create_workflow(None)`` selects all analysts.
        workflow = create_workflow(selected_analysts if selected_analysts else None)
        agent = workflow.compile()

        initial_state = {
            "messages": [
                HumanMessage(
                    content="Make trading decisions based on the provided data.",
                )
            ],
            "data": {
                "tickers": tickers,
                "portfolio": portfolio,
                "start_date": start_date,
                "end_date": end_date,
                "analyst_signals": {},
            },
            "metadata": {
                "show_reasoning": show_reasoning,
                "model_name": model_name,
                "model_provider": model_provider,
            },
        }
        final_state = agent.invoke(initial_state)

        return {
            "decisions": parse_hedge_fund_response(final_state["messages"][-1].content),
            "analyst_signals": final_state["data"]["analyst_signals"],
        }
    finally:
        # Stop progress tracking even if the workflow raised
        progress.stop()
selected_analysts is None: 112 | selected_analysts = list(analyst_nodes.keys()) 113 | # Add selected analyst nodes 114 | for analyst_key in selected_analysts: 115 | node_name, node_func = analyst_nodes[analyst_key] 116 | workflow.add_node(node_name, node_func) 117 | workflow.add_edge("start_node", node_name) 118 | 119 | # Always add risk and portfolio management 120 | workflow.add_node("risk_management_agent", risk_management_agent) 121 | workflow.add_node("portfolio_management_agent", portfolio_management_agent) 122 | 123 | # Connect selected analysts to risk management 124 | for analyst_key in selected_analysts: 125 | node_name = analyst_nodes[analyst_key][0] 126 | workflow.add_edge(node_name, "risk_management_agent") 127 | 128 | workflow.add_edge("risk_management_agent", "portfolio_management_agent") 129 | workflow.add_edge("portfolio_management_agent", END) 130 | 131 | workflow.set_entry_point("start_node") 132 | return workflow 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description="Run the hedge fund trading system") 137 | parser.add_argument("--initial-cash", type=float, default=100000.0, help="Initial cash position. Defaults to 100000.0)") 138 | parser.add_argument("--margin-requirement", type=float, default=0.0, help="Initial margin requirement. Defaults to 0.0") 139 | parser.add_argument("--tickers", type=str, required=True, help="Comma-separated list of stock ticker symbols") 140 | parser.add_argument( 141 | "--start-date", 142 | type=str, 143 | help="Start date (YYYY-MM-DD). Defaults to 3 months before end date", 144 | ) 145 | parser.add_argument("--end-date", type=str, help="End date (YYYY-MM-DD). 
Defaults to today") 146 | parser.add_argument("--show-reasoning", action="store_true", help="Show reasoning from each agent") 147 | parser.add_argument("--show-agent-graph", action="store_true", help="Show the agent graph") 148 | parser.add_argument("--ollama", action="store_true", help="Use Ollama for local LLM inference") 149 | 150 | args = parser.parse_args() 151 | 152 | # Parse tickers from comma-separated string 153 | tickers = [ticker.strip() for ticker in args.tickers.split(",")] 154 | 155 | # Select analysts 156 | selected_analysts = None 157 | choices = questionary.checkbox( 158 | "Select your AI analysts.", 159 | choices=[questionary.Choice(display, value=value) for display, value in ANALYST_ORDER], 160 | instruction="\n\nInstructions: \n1. Press Space to select/unselect analysts.\n2. Press 'a' to select/unselect all.\n3. Press Enter when done to run the hedge fund.\n", 161 | validate=lambda x: len(x) > 0 or "You must select at least one analyst.", 162 | style=questionary.Style( 163 | [ 164 | ("checkbox-selected", "fg:green"), 165 | ("selected", "fg:green noinherit"), 166 | ("highlighted", "noinherit"), 167 | ("pointer", "noinherit"), 168 | ] 169 | ), 170 | ).ask() 171 | 172 | if not choices: 173 | print("\n\nInterrupt received. 
Exiting...") 174 | sys.exit(0) 175 | else: 176 | selected_analysts = choices 177 | print(f"\nSelected analysts: {', '.join(Fore.GREEN + choice.title().replace('_', ' ') + Style.RESET_ALL for choice in choices)}\n") 178 | 179 | # Select LLM model based on whether Ollama is being used 180 | model_choice = None 181 | model_provider = None 182 | 183 | if args.ollama: 184 | print(f"{Fore.CYAN}Using Ollama for local LLM inference.{Style.RESET_ALL}") 185 | 186 | # Select from Ollama-specific models 187 | model_choice = questionary.select( 188 | "Select your Ollama model:", 189 | choices=[questionary.Choice(display, value=value) for display, value, _ in OLLAMA_LLM_ORDER], 190 | style=questionary.Style( 191 | [ 192 | ("selected", "fg:green bold"), 193 | ("pointer", "fg:green bold"), 194 | ("highlighted", "fg:green"), 195 | ("answer", "fg:green bold"), 196 | ] 197 | ), 198 | ).ask() 199 | 200 | if not model_choice: 201 | print("\n\nInterrupt received. Exiting...") 202 | sys.exit(0) 203 | 204 | # Ensure Ollama is installed, running, and the model is available 205 | if not ensure_ollama_and_model(model_choice): 206 | print(f"{Fore.RED}Cannot proceed without Ollama and the selected model.{Style.RESET_ALL}") 207 | sys.exit(1) 208 | 209 | model_provider = ModelProvider.OLLAMA.value 210 | print(f"\nSelected {Fore.CYAN}Ollama{Style.RESET_ALL} model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n") 211 | else: 212 | # Use the standard cloud-based LLM selection 213 | model_choice = questionary.select( 214 | "Select your LLM model:", 215 | choices=[questionary.Choice(display, value=value) for display, value, _ in LLM_ORDER], 216 | style=questionary.Style( 217 | [ 218 | ("selected", "fg:green bold"), 219 | ("pointer", "fg:green bold"), 220 | ("highlighted", "fg:green"), 221 | ("answer", "fg:green bold"), 222 | ] 223 | ), 224 | ).ask() 225 | 226 | if not model_choice: 227 | print("\n\nInterrupt received. 
Exiting...") 228 | sys.exit(0) 229 | else: 230 | # Get model info using the helper function 231 | model_info = get_model_info(model_choice) 232 | if model_info: 233 | model_provider = model_info.provider.value 234 | print(f"\nSelected {Fore.CYAN}{model_provider}{Style.RESET_ALL} model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n") 235 | else: 236 | model_provider = "Unknown" 237 | print(f"\nSelected model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n") 238 | 239 | # Create the workflow with selected analysts 240 | workflow = create_workflow(selected_analysts) 241 | app = workflow.compile() 242 | 243 | if args.show_agent_graph: 244 | file_path = "" 245 | if selected_analysts is not None: 246 | for selected_analyst in selected_analysts: 247 | file_path += selected_analyst + "_" 248 | file_path += "graph.png" 249 | save_graph_as_png(app, file_path) 250 | 251 | # Validate dates if provided 252 | if args.start_date: 253 | try: 254 | datetime.strptime(args.start_date, "%Y-%m-%d") 255 | except ValueError: 256 | raise ValueError("Start date must be in YYYY-MM-DD format") 257 | 258 | if args.end_date: 259 | try: 260 | datetime.strptime(args.end_date, "%Y-%m-%d") 261 | except ValueError: 262 | raise ValueError("End date must be in YYYY-MM-DD format") 263 | 264 | # Set the start and end dates 265 | end_date = args.end_date or datetime.now().strftime("%Y-%m-%d") 266 | if not args.start_date: 267 | # Calculate 3 months before end_date 268 | end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") 269 | start_date = (end_date_obj - relativedelta(months=3)).strftime("%Y-%m-%d") 270 | else: 271 | start_date = args.start_date 272 | 273 | # Initialize portfolio with cash amount and stock positions 274 | portfolio = { 275 | "cash": args.initial_cash, # Initial cash amount 276 | "margin_requirement": args.margin_requirement, # Initial margin requirement 277 | "margin_used": 0.0, # total margin usage across all short positions 278 | "positions": { 279 | 
def get_prices(ticker: str, start_date: str, end_date: str) -> list[Price]:
    """Fetch daily price data for ``ticker`` from the cache or the API.

    Args:
        ticker: Ticker symbol to look up.
        start_date: Inclusive lower bound, expected as ``YYYY-MM-DD``.
        end_date: Inclusive upper bound, expected as ``YYYY-MM-DD``.

    Returns:
        The matching ``Price`` records, or an empty list when the API
        returns none.

    Raises:
        Exception: If the API responds with a non-200 status code.
    """
    # Check cache first
    if cached_data := _cache.get_prices(ticker):
        # Compare on the date part only: a full timestamp such as
        # "2024-01-02T05:00:00Z" sorts *after* "2024-01-02", which would drop
        # prices on end_date itself. Date-only values are unaffected.
        # NOTE(review): any cached row in range short-circuits the API call,
        # even if the cache covers only part of the requested window.
        filtered_data = [
            Price(**price)
            for price in cached_data
            if start_date <= price["time"].split("T")[0] <= end_date
        ]
        if filtered_data:
            return filtered_data

    # If not in cache or no data in range, fetch from API
    headers = {}
    if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
        headers["X-API-KEY"] = api_key

    # Let requests build and URL-encode the query string.
    response = requests.get(
        "https://api.financialdatasets.ai/prices/",
        params={
            "ticker": ticker,
            "interval": "day",
            "interval_multiplier": 1,
            "start_date": start_date,
            "end_date": end_date,
        },
        headers=headers,
    )
    if response.status_code != 200:
        raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")

    # Parse response with Pydantic model
    prices = PriceResponse(**response.json()).prices
    if not prices:
        return []

    # Cache the results as dicts
    _cache.set_prices(ticker, [p.model_dump() for p in prices])
    return prices
def get_insider_trades(
    ticker: str,
    end_date: str,
    start_date: str | None = None,
    limit: int = 1000,
) -> list[InsiderTrade]:
    """Fetch insider trades for ``ticker`` from the cache or the API.

    Args:
        ticker: Ticker symbol to look up.
        end_date: Inclusive upper bound, expected as ``YYYY-MM-DD``.
        start_date: Optional inclusive lower bound (``YYYY-MM-DD``). Paging
            beyond the first page only happens when this is given.
        limit: Page size requested from the API.

    Returns:
        The matching trades, or an empty list when the API returns none.

    Raises:
        Exception: If the API responds with a non-200 status code.
    """

    def day(stamp: str) -> str:
        """Return the date part of an ISO date or timestamp string."""
        return stamp.split("T")[0]

    # Check cache first
    if cached_data := _cache.get_insider_trades(ticker):
        # Filter on the date part so a timestamp falling on end_date is not
        # excluded by plain string comparison against "YYYY-MM-DD".
        filtered_data = []
        for trade in cached_data:
            trade_day = day(trade.get("transaction_date") or trade["filing_date"])
            if (start_date is None or trade_day >= start_date) and trade_day <= end_date:
                filtered_data.append(InsiderTrade(**trade))
        filtered_data.sort(key=lambda x: x.transaction_date or x.filing_date, reverse=True)
        if filtered_data:
            return filtered_data

    # If not in cache or insufficient data, fetch from API
    headers = {}
    if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
        headers["X-API-KEY"] = api_key

    all_trades = []
    seen_keys = set()  # records already collected from earlier pages
    current_end_date = end_date

    while True:
        params = {"ticker": ticker, "filing_date_lte": current_end_date}
        if start_date:
            params["filing_date_gte"] = start_date
        params["limit"] = limit

        response = requests.get(
            "https://api.financialdatasets.ai/insider-trades/",
            params=params,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")

        insider_trades = InsiderTradeResponse(**response.json()).insider_trades
        if not insider_trades:
            break

        # The next page is requested with filing_date_lte set to this page's
        # oldest date (inclusive), so boundary-day records come back again.
        # Only records from *earlier* pages are dropped; identical rows
        # within one page are kept as returned.
        page_keys = [repr(trade.model_dump()) for trade in insider_trades]
        new_trades = [trade for trade, key in zip(insider_trades, page_keys) if key not in seen_keys]
        seen_keys.update(page_keys)
        all_trades.extend(new_trades)

        # Only continue pagination if we have a start_date and got a full page
        if not start_date or len(insider_trades) < limit:
            break

        # A page that added nothing new means further requests cannot help.
        if not new_trades:
            break

        # Move the window to the oldest filing date in the current batch.
        next_end_date = day(min(trade.filing_date for trade in insider_trades))

        # Stop once start_date is reached, or when the window did not move
        # (a full page sharing one filing date would otherwise repeat the
        # same request forever).
        if next_end_date <= start_date or next_end_date >= current_end_date:
            break
        current_end_date = next_end_date

    if not all_trades:
        return []

    # Cache the results
    _cache.set_insider_trades(ticker, [trade.model_dump() for trade in all_trades])
    return all_trades
| ) -> list[CompanyNews]: 198 | """Fetch company news from cache or API.""" 199 | # Check cache first 200 | if cached_data := _cache.get_company_news(ticker): 201 | # Filter cached data by date range 202 | filtered_data = [CompanyNews(**news) for news in cached_data 203 | if (start_date is None or news["date"] >= start_date) 204 | and news["date"] <= end_date] 205 | filtered_data.sort(key=lambda x: x.date, reverse=True) 206 | if filtered_data: 207 | return filtered_data 208 | 209 | # If not in cache or insufficient data, fetch from API 210 | headers = {} 211 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): 212 | headers["X-API-KEY"] = api_key 213 | 214 | all_news = [] 215 | current_end_date = end_date 216 | 217 | while True: 218 | url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={current_end_date}" 219 | if start_date: 220 | url += f"&start_date={start_date}" 221 | url += f"&limit={limit}" 222 | 223 | response = requests.get(url, headers=headers) 224 | if response.status_code != 200: 225 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}") 226 | 227 | data = response.json() 228 | response_model = CompanyNewsResponse(**data) 229 | company_news = response_model.news 230 | 231 | if not company_news: 232 | break 233 | 234 | all_news.extend(company_news) 235 | 236 | # Only continue pagination if we have a start_date and got a full page 237 | if not start_date or len(company_news) < limit: 238 | break 239 | 240 | # Update end_date to the oldest date from current batch for next iteration 241 | current_end_date = min(news.date for news in company_news).split('T')[0] 242 | 243 | # If we've reached or passed the start_date, we can stop 244 | if current_end_date <= start_date: 245 | break 246 | 247 | if not all_news: 248 | return [] 249 | 250 | # Cache the results 251 | _cache.set_company_news(ticker, [news.model_dump() for news in all_news]) 252 | return all_news 253 | 254 | 255 | def 
def prices_to_df(prices: list[Price]) -> pd.DataFrame:
    """Convert price records into a DataFrame indexed and sorted by date.

    Args:
        prices: Objects exposing ``model_dump()`` that yields at least the
            keys ``time``, ``open``, ``close``, ``high``, ``low``, ``volume``.

    Returns:
        A DataFrame indexed by a ``Date`` datetime index (parsed from
        ``time``), sorted ascending, with the price/volume columns coerced
        to numeric (unparseable values become NaN). An empty input yields an
        empty frame with the same columns instead of raising ``KeyError``.
    """
    numeric_cols = ["open", "close", "high", "low", "volume"]

    if not prices:
        # pd.DataFrame([]) has no "time" column, so build the empty shape
        # explicitly to keep downstream column access working.
        empty = pd.DataFrame(
            {
                "time": pd.Series(dtype="object"),
                **{col: pd.Series(dtype="float64") for col in numeric_cols},
            }
        )
        empty.index = pd.DatetimeIndex([], name="Date")
        return empty

    df = pd.DataFrame([p.model_dump() for p in prices])
    df["Date"] = pd.to_datetime(df["time"])
    df.set_index("Date", inplace=True)
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df.sort_index(inplace=True)
    return df
be empty 2 | 3 | """Utility modules for the application.""" 4 | -------------------------------------------------------------------------------- /src/utils/analysts.py: -------------------------------------------------------------------------------- 1 | """Constants and utilities related to analysts configuration.""" 2 | 3 | from agents.ben_graham import ben_graham_agent 4 | from agents.bill_ackman import bill_ackman_agent 5 | from agents.cathie_wood import cathie_wood_agent 6 | from agents.charlie_munger import charlie_munger_agent 7 | from agents.fundamentals import fundamentals_agent 8 | from agents.michael_burry import michael_burry_agent 9 | from agents.phil_fisher import phil_fisher_agent 10 | from agents.peter_lynch import peter_lynch_agent 11 | from agents.sentiment import sentiment_agent 12 | from agents.stanley_druckenmiller import stanley_druckenmiller_agent 13 | from agents.technicals import technical_analyst_agent 14 | from agents.valuation import valuation_agent 15 | from agents.warren_buffett import warren_buffett_agent 16 | 17 | # Define analyst configuration - single source of truth 18 | ANALYST_CONFIG = { 19 | "ben_graham": { 20 | "display_name": "Ben Graham", 21 | "agent_func": ben_graham_agent, 22 | "order": 0, 23 | }, 24 | "bill_ackman": { 25 | "display_name": "Bill Ackman", 26 | "agent_func": bill_ackman_agent, 27 | "order": 1, 28 | }, 29 | "cathie_wood": { 30 | "display_name": "Cathie Wood", 31 | "agent_func": cathie_wood_agent, 32 | "order": 2, 33 | }, 34 | "charlie_munger": { 35 | "display_name": "Charlie Munger", 36 | "agent_func": charlie_munger_agent, 37 | "order": 3, 38 | }, 39 | "michael_burry": { 40 | "display_name": "Michael Burry", 41 | "agent_func": michael_burry_agent, 42 | "order": 4, 43 | }, 44 | "peter_lynch": { 45 | "display_name": "Peter Lynch", 46 | "agent_func": peter_lynch_agent, 47 | "order": 5, 48 | }, 49 | "phil_fisher": { 50 | "display_name": "Phil Fisher", 51 | "agent_func": phil_fisher_agent, 52 | "order": 6, 53 | }, 
def sort_agent_signals(signals, analyst_order=None):
    """Sort agent signal rows in the configured analyst display order.

    Args:
        signals: Iterable of rows whose first element is the agent's display
            name, optionally wrapped in ANSI colour codes.
        analyst_order: Optional sequence of ``(display_name, key)`` pairs
            defining the order. Defaults to the module's ``ANALYST_ORDER``.

    Returns:
        A new list of the rows, ordered by the position of their display
        name, with "Risk Management" after all analysts and unknown names
        last. Rows with equal rank keep their input order.
    """
    import re  # local import: this module does not import re at the top

    ordered_pairs = ANALYST_ORDER if analyst_order is None else analyst_order

    # Create order mapping from the (display, key) pairs
    rank_by_name = {display: idx for idx, (display, _) in enumerate(ordered_pairs)}
    rank_by_name["Risk Management"] = len(ordered_pairs)  # Add Risk Management at the end

    # The table rows built for printing embed colour escape sequences around
    # the name, so the raw cell never equals a display name. Strip the
    # sequences before looking the name up.
    ansi_codes = re.compile(r"\x1b\[[0-9;]*m")

    def rank(row):
        return rank_by_name.get(ansi_codes.sub("", str(row[0])), 999)

    return sorted(signals, key=rank)
def _reasoning_to_text(reasoning) -> str:
    """Render a reasoning payload (string, dict, or any other object) as text."""
    if isinstance(reasoning, str):
        return reasoning
    if isinstance(reasoning, dict):
        return json.dumps(reasoning, indent=2)
    return str(reasoning)


def _wrap_text(text: str, width: int = 60) -> str:
    """Collapse whitespace in *text* and wrap it to lines of at most *width* characters."""
    import textwrap

    # Words are never split; a word longer than *width* gets a line of its own.
    return textwrap.fill(
        " ".join(text.split()),
        width=width,
        break_long_words=False,
        break_on_hyphens=False,
    )


def _format_confidence(value) -> str:
    """Format a confidence value as e.g. ``"12.3%"``; ``"N/A"`` if missing or non-numeric."""
    try:
        return f"{float(value):.1f}%"
    except (TypeError, ValueError):
        return "N/A"


def _action_color(action: str) -> str:
    """Return the colorama color for an upper-cased trading action (white if unknown)."""
    return {
        "BUY": Fore.GREEN,
        "SELL": Fore.RED,
        "HOLD": Fore.YELLOW,
        "COVER": Fore.GREEN,
        "SHORT": Fore.RED,
    }.get(action, Fore.WHITE)


def print_trading_output(result: dict) -> None:
    """
    Print formatted trading results with colored tables for multiple tickers.

    For every ticker in ``result["decisions"]`` this prints an agent-analysis
    table (one row per analyst signal, excluding the risk management agent)
    and a trading-decision table, followed by a portfolio summary table and
    the first non-empty decision reasoning as the portfolio strategy.

    Args:
        result (dict): Dictionary containing decisions and analyst signals for multiple tickers
    """
    decisions = result.get("decisions")
    if not decisions:
        print(f"{Fore.RED}No trading decisions available{Style.RESET_ALL}")
        return

    signal_colors = {
        "BULLISH": Fore.GREEN,
        "BEARISH": Fore.RED,
        "NEUTRAL": Fore.YELLOW,
    }

    # Print decisions for each ticker
    for ticker, decision in decisions.items():
        print(f"\n{Fore.WHITE}{Style.BRIGHT}Analysis for {Fore.CYAN}{ticker}{Style.RESET_ALL}")
        print(f"{Fore.WHITE}{Style.BRIGHT}{'=' * 50}{Style.RESET_ALL}")

        # Prepare analyst signals table for this ticker
        table_data = []
        for agent, signals in result.get("analyst_signals", {}).items():
            if ticker not in signals:
                continue

            # Skip Risk Management agent in the signals section
            if agent == "risk_management_agent":
                continue

            signal = signals[ticker]
            agent_name = agent.replace("_agent", "").replace("_", " ").title()
            # Tolerate a missing/None signal instead of failing on .upper()
            signal_type = str(signal.get("signal") or "").upper()
            confidence = signal.get("confidence", 0)
            signal_color = signal_colors.get(signal_type, Fore.WHITE)

            reasoning = signal.get("reasoning")
            reasoning_str = _wrap_text(_reasoning_to_text(reasoning)) if reasoning else ""

            table_data.append(
                [
                    f"{Fore.CYAN}{agent_name}{Style.RESET_ALL}",
                    f"{signal_color}{signal_type}{Style.RESET_ALL}",
                    f"{Fore.WHITE}{confidence}%{Style.RESET_ALL}",
                    f"{Fore.WHITE}{reasoning_str}{Style.RESET_ALL}",
                ]
            )

        # Sort the signals according to the predefined order
        table_data = sort_agent_signals(table_data)

        print(f"\n{Fore.WHITE}{Style.BRIGHT}AGENT ANALYSIS:{Style.RESET_ALL} [{Fore.CYAN}{ticker}{Style.RESET_ALL}]")
        print(
            tabulate(
                table_data,
                headers=[f"{Fore.WHITE}Agent", "Signal", "Confidence", "Reasoning"],
                tablefmt="grid",
                colalign=("left", "center", "right", "left"),
            )
        )

        # Print Trading Decision Table
        action = str(decision.get("action") or "").upper()
        action_color = _action_color(action)

        # Decision reasoning may be a string or a structured object; render both.
        reasoning = decision.get("reasoning", "")
        wrapped_reasoning = _wrap_text(_reasoning_to_text(reasoning)) if reasoning else ""

        decision_data = [
            ["Action", f"{action_color}{action}{Style.RESET_ALL}"],
            ["Quantity", f"{action_color}{decision.get('quantity')}{Style.RESET_ALL}"],
            [
                "Confidence",
                f"{Fore.WHITE}{_format_confidence(decision.get('confidence'))}{Style.RESET_ALL}",
            ],
            ["Reasoning", f"{Fore.WHITE}{wrapped_reasoning}{Style.RESET_ALL}"],
        ]

        print(f"\n{Fore.WHITE}{Style.BRIGHT}TRADING DECISION:{Style.RESET_ALL} [{Fore.CYAN}{ticker}{Style.RESET_ALL}]")
        print(tabulate(decision_data, tablefmt="grid", colalign=("left", "left")))

    # Print Portfolio Summary
    print(f"\n{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY:{Style.RESET_ALL}")

    # The first non-empty decision reasoning is shown as the portfolio strategy.
    portfolio_manager_reasoning = next(
        (d.get("reasoning") for d in decisions.values() if d.get("reasoning")),
        None,
    )

    portfolio_data = []
    for ticker, decision in decisions.items():
        action = str(decision.get("action") or "").upper()
        action_color = _action_color(action)
        portfolio_data.append(
            [
                f"{Fore.CYAN}{ticker}{Style.RESET_ALL}",
                f"{action_color}{action}{Style.RESET_ALL}",
                f"{action_color}{decision.get('quantity')}{Style.RESET_ALL}",
                f"{Fore.WHITE}{_format_confidence(decision.get('confidence'))}{Style.RESET_ALL}",
            ]
        )

    headers = [f"{Fore.WHITE}Ticker", "Action", "Quantity", "Confidence"]

    # Print the portfolio summary table
    print(
        tabulate(
            portfolio_data,
            headers=headers,
            tablefmt="grid",
            colalign=("left", "center", "right", "right"),
        )
    )

    # Print Portfolio Manager's reasoning if available
    if portfolio_manager_reasoning:
        wrapped_reasoning = _wrap_text(_reasoning_to_text(portfolio_manager_reasoning))
        print(f"\n{Fore.WHITE}{Style.BRIGHT}Portfolio Strategy:{Style.RESET_ALL}")
        print(f"{Fore.CYAN}{wrapped_reasoning}{Style.RESET_ALL}")
rows and summary rows 235 | ticker_rows = [] 236 | summary_rows = [] 237 | 238 | for row in table_rows: 239 | if isinstance(row[1], str) and "PORTFOLIO SUMMARY" in row[1]: 240 | summary_rows.append(row) 241 | else: 242 | ticker_rows.append(row) 243 | 244 | 245 | # Display latest portfolio summary 246 | if summary_rows: 247 | latest_summary = summary_rows[-1] 248 | print(f"\n{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY:{Style.RESET_ALL}") 249 | 250 | # Extract values and remove commas before converting to float 251 | cash_str = latest_summary[7].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "") 252 | position_str = latest_summary[6].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "") 253 | total_str = latest_summary[8].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "") 254 | 255 | print(f"Cash Balance: {Fore.CYAN}${float(cash_str):,.2f}{Style.RESET_ALL}") 256 | print(f"Total Position Value: {Fore.YELLOW}${float(position_str):,.2f}{Style.RESET_ALL}") 257 | print(f"Total Value: {Fore.WHITE}${float(total_str):,.2f}{Style.RESET_ALL}") 258 | print(f"Return: {latest_summary[9]}") 259 | 260 | # Display performance metrics if available 261 | if latest_summary[10]: # Sharpe ratio 262 | print(f"Sharpe Ratio: {latest_summary[10]}") 263 | if latest_summary[11]: # Sortino ratio 264 | print(f"Sortino Ratio: {latest_summary[11]}") 265 | if latest_summary[12]: # Max drawdown 266 | print(f"Max Drawdown: {latest_summary[12]}") 267 | 268 | # Add vertical spacing 269 | print("\n" * 2) 270 | 271 | # Print the table with just ticker rows 272 | print( 273 | tabulate( 274 | ticker_rows, 275 | headers=[ 276 | "Date", 277 | "Ticker", 278 | "Action", 279 | "Quantity", 280 | "Price", 281 | "Shares", 282 | "Position Value", 283 | "Bullish", 284 | "Bearish", 285 | "Neutral", 286 | ], 287 | tablefmt="grid", 288 | colalign=( 289 | "left", # Date 290 | "left", # Ticker 291 | "center", # Action 292 | "right", # Quantity 293 | "right", # Price 294 | "right", # Shares 295 | "right", # 
def format_backtest_row(
    date: str,
    ticker: str,
    action: str,
    quantity: float,
    price: float,
    shares_owned: float,
    position_value: float,
    bullish_count: int,
    bearish_count: int,
    neutral_count: int,
    is_summary: bool = False,
    total_value: float = None,
    return_pct: float = None,
    cash_balance: float = None,
    total_position_value: float = None,
    sharpe_ratio: float = None,
    sortino_ratio: float = None,
    max_drawdown: float = None,
) -> list[str]:
    """Format a row for the backtest results table.

    Args:
        date: Row date, emitted unchanged as the first cell.
        ticker, action, quantity, price, shares_owned, position_value,
        bullish_count, bearish_count, neutral_count: Values for a ticker row.
        is_summary: When True, build a portfolio summary row instead.
        total_value, return_pct, cash_balance, total_position_value: Required
            numeric values for a summary row.
        sharpe_ratio, sortino_ratio, max_drawdown: Optional summary metrics;
            each is rendered as an empty string when None.

    Returns:
        A list of display strings: 10 cells for a ticker row, or 13 cells for
        a summary row (cells 6-12 hold position value, cash, total value,
        return and the three metrics, which print_backtest_results reads by index).
    """
    if is_summary:
        return_color = Fore.GREEN if return_pct >= 0 else Fore.RED
        return [
            date,
            f"{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY{Style.RESET_ALL}",
            "",  # Action
            "",  # Quantity
            "",  # Price
            "",  # Shares
            f"{Fore.YELLOW}${total_position_value:,.2f}{Style.RESET_ALL}",  # Total Position Value
            f"{Fore.CYAN}${cash_balance:,.2f}{Style.RESET_ALL}",  # Cash Balance
            f"{Fore.WHITE}${total_value:,.2f}{Style.RESET_ALL}",  # Total Value
            f"{return_color}{return_pct:+.2f}%{Style.RESET_ALL}",  # Return
            f"{Fore.YELLOW}{sharpe_ratio:.2f}{Style.RESET_ALL}" if sharpe_ratio is not None else "",  # Sharpe Ratio
            f"{Fore.YELLOW}{sortino_ratio:.2f}{Style.RESET_ALL}" if sortino_ratio is not None else "",  # Sortino Ratio
            f"{Fore.RED}{abs(max_drawdown):.2f}%{Style.RESET_ALL}" if max_drawdown is not None else "",  # Max Drawdown
        ]

    # The action color is only needed for ticker rows, so it is resolved here.
    action_label = action.upper()
    action_color = {
        "BUY": Fore.GREEN,
        "COVER": Fore.GREEN,
        "SELL": Fore.RED,
        "SHORT": Fore.RED,
        "HOLD": Fore.WHITE,
    }.get(action_label, Fore.WHITE)

    return [
        date,
        f"{Fore.CYAN}{ticker}{Style.RESET_ALL}",
        f"{action_color}{action_label}{Style.RESET_ALL}",
        f"{action_color}{quantity:,.0f}{Style.RESET_ALL}",
        f"{Fore.WHITE}{price:,.2f}{Style.RESET_ALL}",
        f"{Fore.WHITE}{shares_owned:,.0f}{Style.RESET_ALL}",
        f"{Fore.YELLOW}{position_value:,.2f}{Style.RESET_ALL}",
        f"{Fore.GREEN}{bullish_count}{Style.RESET_ALL}",
        f"{Fore.RED}{bearish_count}{Style.RESET_ALL}",
        f"{Fore.BLUE}{neutral_count}{Style.RESET_ALL}",
    ]
def download_model(model_name: str, ollama_url: str) -> bool:
    """Pull a model through the Ollama HTTP API and wait until it is listed.

    Sends a POST to ``{ollama_url}/api/pull`` and then polls
    ``get_available_models`` every 10 seconds, for at most 30 minutes.

    Args:
        model_name: Name of the model to pull.
        ollama_url: Base URL of the Ollama service.

    Returns:
        True once the model appears in the available-model list; False if the
        pull request fails or the wait times out.
    """
    print(f"{Fore.YELLOW}Downloading model {model_name} to the Docker Ollama container...{Style.RESET_ALL}")
    print(f"{Fore.CYAN}This may take some time. Please be patient.{Style.RESET_ALL}")

    # Phase 1: ask the service to start pulling the model
    try:
        pull_response = requests.post(f"{ollama_url}/api/pull", json={"name": model_name}, timeout=10)
    except requests.RequestException as e:
        print(f"{Fore.RED}Error initiating download request: {e}{Style.RESET_ALL}")
        return False

    if pull_response.status_code != 200:
        print(f"{Fore.RED}Failed to initiate model download. Status code: {pull_response.status_code}{Style.RESET_ALL}")
        if pull_response.text:
            print(f"{Fore.RED}Error: {pull_response.text}{Style.RESET_ALL}")
        return False

    # Phase 2: poll until the model shows up or the time budget is spent
    print(f"{Fore.CYAN}Download initiated. Checking periodically for completion...{Style.RESET_ALL}")

    max_wait_time = 1800  # 30 minutes max wait
    check_interval = 10  # Check every 10 seconds

    for already_waited in range(0, max_wait_time, check_interval):
        if model_name in get_available_models(ollama_url):
            print(f"{Fore.GREEN}Model {model_name} downloaded successfully.{Style.RESET_ALL}")
            return True

        time.sleep(check_interval)
        elapsed = already_waited + check_interval

        # Report progress once per full minute of waiting
        if elapsed % 60 == 0:
            minutes = elapsed // 60
            print(f"{Fore.CYAN}Download in progress... ({minutes} minute{'s' if minutes != 1 else ''} elapsed){Style.RESET_ALL}")

    # Loop exhausted without the model appearing
    print(f"{Fore.RED}Timed out waiting for model download to complete after {max_wait_time // 60} minutes.{Style.RESET_ALL}")
    return False
def call_llm(
    prompt: Any,
    model_name: str,
    model_provider: str,
    pydantic_model: Type[T],
    agent_name: Optional[str] = None,
    max_retries: int = 3,
    default_factory = None
) -> T:
    """
    Makes an LLM call with retry logic, handling both JSON supported and non-JSON supported models.

    Models that report JSON-mode support (and models with no model info) are
    wrapped with structured output and their result is returned directly.
    For models that report no JSON mode, the JSON payload is extracted from
    the raw response text and validated into ``pydantic_model``.

    Args:
        prompt: The prompt to send to the LLM
        model_name: Name of the model to use
        model_provider: Provider of the model
        pydantic_model: The Pydantic model class to structure the output
        agent_name: Optional name of the agent for progress updates
        max_retries: Maximum number of retries (default: 3)
        default_factory: Optional factory function to create default response on failure

    Returns:
        An instance of the specified Pydantic model. On failure this is the
        value of ``default_factory()`` when provided, otherwise a generic
        default built by ``create_default_response``.
    """
    from llm.models import get_model, get_model_info

    model_info = get_model_info(model_name)
    llm = get_model(model_name, model_provider)

    # True only when the model is known and reports that it lacks JSON mode.
    needs_manual_json = bool(model_info and not model_info.has_json_mode())

    # Models with JSON mode can produce the structured output themselves
    if not needs_manual_json:
        llm = llm.with_structured_output(
            pydantic_model,
            method="json_mode",
        )

    def fallback() -> T:
        # Prefer the caller's default over the generic one
        if default_factory:
            return default_factory()
        return create_default_response(pydantic_model)

    # Call the LLM with retries
    for attempt in range(max_retries):
        try:
            result = llm.invoke(prompt)

            if not needs_manual_json:
                return result

            # No JSON mode: extract and parse the JSON manually
            parsed_result = extract_json_from_response(result.content)
            if parsed_result:
                return pydantic_model(**parsed_result)
            # Nothing parseable in this response: fall through to the next attempt

        except Exception as e:
            if agent_name:
                progress.update_status(agent_name, None, f"Error - retry {attempt + 1}/{max_retries}")

            if attempt == max_retries - 1:
                print(f"Error in LLM call after {max_retries} attempts: {e}")
                return fallback()

    # Reached when every attempt returned a response without parseable JSON
    # (or when max_retries is not positive).
    return fallback()
def extract_json_from_response(content: str) -> Optional[dict]:
    """Extract a JSON object from an LLM response.

    Looks for a markdown ```json fenced block and parses its contents. A
    missing closing fence is tolerated (the rest of the text is parsed), and
    a response that is itself a bare JSON object (starts with "{") is parsed
    directly.

    Args:
        content: Raw response text.

    Returns:
        The parsed dict, or None when no JSON object can be extracted
        (no candidate text, invalid JSON, non-string content, or JSON whose
        top-level value is not an object).
    """
    fence = "```json"
    try:
        fence_start = content.find(fence)
        if fence_start != -1:
            candidate = content[fence_start + len(fence):]  # Skip past ```json
            fence_end = candidate.find("```")
            if fence_end != -1:
                candidate = candidate[:fence_end]
        elif content.lstrip().startswith("{"):
            # No fence, but the whole response looks like a JSON object
            candidate = content
        else:
            return None
        parsed = json.loads(candidate.strip())
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError/TypeError: content is not a str; ValueError: invalid JSON
        print(f"Error extracting JSON from response: {e}")
        return None

    # Callers unpack the result with **, so only a JSON object is usable
    return parsed if isinstance(parsed, dict) else None
def is_ollama_installed() -> bool:
    """Check if Ollama is installed on the system.

    Returns:
        True when an ``ollama`` executable is found on PATH on macOS, Linux
        or Windows; False otherwise, including on any other operating system.
    """
    import shutil

    system = platform.system().lower()
    if system not in ("darwin", "linux", "windows"):
        return False  # Unsupported OS

    # shutil.which performs the PATH lookup (honoring PATHEXT on Windows)
    # without spawning `which`/`where` in a subprocess.
    return shutil.which("ollama") is not None
def start_ollama_server() -> bool:
    """Start the Ollama server if it's not already running.

    Launches ``ollama serve`` as a background process and polls
    ``is_ollama_server_running`` once per second for up to 10 seconds.

    Returns:
        True if the server is already running or becomes reachable in time;
        False on an unsupported OS, a launch error, or a timeout.
    """
    if is_ollama_server_running():
        print(f"{Fore.GREEN}Ollama server is already running.{Style.RESET_ALL}")
        return True

    system = platform.system().lower()
    if system not in ("darwin", "linux", "windows"):
        print(f"{Fore.RED}Unsupported operating system: {system}{Style.RESET_ALL}")
        return False

    try:
        # The server's output is never read here, so discard it. Unread PIPEs
        # would eventually fill and block the long-running server process.
        subprocess.Popen(
            ["ollama", "serve"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=(system == "windows"),
        )

        # Wait for server to start
        for _ in range(10):  # Try for 10 seconds
            if is_ollama_server_running():
                print(f"{Fore.GREEN}Ollama server started successfully.{Style.RESET_ALL}")
                return True
            time.sleep(1)

        print(f"{Fore.RED}Failed to start Ollama server. Timed out waiting for server to become available.{Style.RESET_ALL}")
        return False
    except Exception as e:
        print(f"{Fore.RED}Error starting Ollama server: {e}{Style.RESET_ALL}")
        return False
Please restart this application after installing Ollama.{Style.RESET_ALL}") 123 | return False 124 | return False 125 | except Exception as e: 126 | print(f"{Fore.RED}Failed to open browser: {e}{Style.RESET_ALL}") 127 | return False 128 | else: 129 | # Only offer command-line installation as a fallback for advanced users 130 | if questionary.confirm("Would you like to try the command-line installation instead? (For advanced users)", default=False).ask(): 131 | print(f"{Fore.YELLOW}Attempting command-line installation...{Style.RESET_ALL}") 132 | try: 133 | install_process = subprocess.run(["bash", "-c", "curl -fsSL https://ollama.com/install.sh | sh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 134 | 135 | if install_process.returncode == 0: 136 | print(f"{Fore.GREEN}Ollama installed successfully via command line.{Style.RESET_ALL}") 137 | return True 138 | else: 139 | print(f"{Fore.RED}Command-line installation failed. Please use the app download method instead.{Style.RESET_ALL}") 140 | return False 141 | except Exception as e: 142 | print(f"{Fore.RED}Error during command-line installation: {e}{Style.RESET_ALL}") 143 | return False 144 | return False 145 | elif system == "linux": # Linux 146 | print(f"{Fore.YELLOW}Installing Ollama...{Style.RESET_ALL}") 147 | try: 148 | # Run the installation command as a single command 149 | install_process = subprocess.run(["bash", "-c", "curl -fsSL https://ollama.com/install.sh | sh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 150 | 151 | if install_process.returncode == 0: 152 | print(f"{Fore.GREEN}Ollama installed successfully.{Style.RESET_ALL}") 153 | return True 154 | else: 155 | print(f"{Fore.RED}Failed to install Ollama. 
Error: {install_process.stderr}{Style.RESET_ALL}") 156 | return False 157 | except Exception as e: 158 | print(f"{Fore.RED}Error during Ollama installation: {e}{Style.RESET_ALL}") 159 | return False 160 | elif system == "windows": # Windows 161 | print(f"{Fore.YELLOW}Automatic installation on Windows is not supported.{Style.RESET_ALL}") 162 | print(f"Please download and install Ollama from: {OLLAMA_DOWNLOAD_URL['windows']}") 163 | 164 | # Ask if they want to open the download page 165 | if questionary.confirm("Do you want to open the Ollama download page in your browser?").ask(): 166 | try: 167 | import webbrowser 168 | 169 | webbrowser.open(OLLAMA_DOWNLOAD_URL["windows"]) 170 | print(f"{Fore.YELLOW}After installation, please restart this application.{Style.RESET_ALL}") 171 | 172 | # Ask if they want to try continuing after installation 173 | if questionary.confirm("Have you installed Ollama?", default=False).ask(): 174 | # Check if it's now installed 175 | if is_ollama_installed() and start_ollama_server(): 176 | print(f"{Fore.GREEN}Ollama is now properly installed and running!{Style.RESET_ALL}") 177 | return True 178 | else: 179 | print(f"{Fore.RED}Ollama installation not detected. Please restart this application after installing Ollama.{Style.RESET_ALL}") 180 | return False 181 | except Exception as e: 182 | print(f"{Fore.RED}Failed to open browser: {e}{Style.RESET_ALL}") 183 | return False 184 | 185 | return False 186 | 187 | 188 | def download_model(model_name: str) -> bool: 189 | """Download an Ollama model.""" 190 | if not is_ollama_server_running(): 191 | if not start_ollama_server(): 192 | return False 193 | 194 | print(f"{Fore.YELLOW}Downloading model {model_name}...{Style.RESET_ALL}") 195 | print(f"{Fore.CYAN}This may take a while depending on your internet speed and the model size.{Style.RESET_ALL}") 196 | print(f"{Fore.CYAN}The download is happening in the background. 
Please be patient...{Style.RESET_ALL}") 197 | 198 | try: 199 | # Use the Ollama CLI to download the model 200 | process = subprocess.Popen(["ollama", "pull", model_name], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True) # Redirect stderr to stdout to capture all output # Line buffered 201 | 202 | # Show some progress to the user 203 | print(f"{Fore.CYAN}Download progress:{Style.RESET_ALL}") 204 | 205 | # For tracking progress 206 | last_percentage = 0 207 | last_phase = "" 208 | bar_length = 40 209 | 210 | while True: 211 | output = process.stdout.readline() 212 | if output == "" and process.poll() is not None: 213 | break 214 | if output: 215 | output = output.strip() 216 | # Try to extract percentage information using a more lenient approach 217 | percentage = None 218 | current_phase = None 219 | 220 | # Example patterns in Ollama output: 221 | # "downloading: 23.45 MB / 42.19 MB [================>-------------] 55.59%" 222 | # "downloading model: 76%" 223 | # "pulling manifest: 100%" 224 | 225 | # Check for percentage in the output 226 | import re 227 | 228 | percentage_match = re.search(r"(\d+(\.\d+)?)%", output) 229 | if percentage_match: 230 | try: 231 | percentage = float(percentage_match.group(1)) 232 | except ValueError: 233 | percentage = None 234 | 235 | # Try to determine the current phase (downloading, extracting, etc.) 
236 | phase_match = re.search(r"^([a-zA-Z\s]+):", output) 237 | if phase_match: 238 | current_phase = phase_match.group(1).strip() 239 | 240 | # If we found a percentage, display a progress bar 241 | if percentage is not None: 242 | # Only update if there's a significant change (avoid flickering) 243 | if abs(percentage - last_percentage) >= 1 or (current_phase and current_phase != last_phase): 244 | last_percentage = percentage 245 | if current_phase: 246 | last_phase = current_phase 247 | 248 | # Create a progress bar 249 | filled_length = int(bar_length * percentage / 100) 250 | bar = "█" * filled_length + "░" * (bar_length - filled_length) 251 | 252 | # Build the status line with the phase if available 253 | phase_display = f"{Fore.CYAN}{last_phase.capitalize()}{Style.RESET_ALL}: " if last_phase else "" 254 | status_line = f"\r{phase_display}{Fore.GREEN}{bar}{Style.RESET_ALL} {Fore.YELLOW}{percentage:.1f}%{Style.RESET_ALL}" 255 | 256 | # Print the status line without a newline to update in place 257 | print(status_line, end="", flush=True) 258 | else: 259 | # If we couldn't extract a percentage but have identifiable output 260 | if "download" in output.lower() or "extract" in output.lower() or "pulling" in output.lower(): 261 | # Don't print a newline for percentage updates 262 | if "%" in output: 263 | print(f"\r{Fore.GREEN}{output}{Style.RESET_ALL}", end="", flush=True) 264 | else: 265 | print(f"{Fore.GREEN}{output}{Style.RESET_ALL}") 266 | 267 | # Wait for the process to finish 268 | return_code = process.wait() 269 | 270 | # Ensure we print a newline after the progress bar 271 | print() 272 | 273 | if return_code == 0: 274 | print(f"{Fore.GREEN}Model {model_name} downloaded successfully!{Style.RESET_ALL}") 275 | return True 276 | else: 277 | print(f"{Fore.RED}Failed to download model {model_name}. 
def ensure_ollama_and_model(model_name: str) -> bool:
    """Make sure Ollama is usable and ``model_name`` is available to it.

    When ``OLLAMA_BASE_URL`` points at a Docker-hosted Ollama
    (``http://ollama:`` or ``http://host.docker.internal:``), the whole check
    is delegated to ``docker.ensure_ollama_and_model``. Otherwise the steps
    run in order: install Ollama (after asking), start the server, then
    download the model (after asking) if it is not already present locally.

    Args:
        model_name: Name of the Ollama model that must be available.

    Returns:
        The delegate's result in Docker mode. Otherwise True when the model
        is already present or was downloaded successfully, and False when a
        step fails or the user declines a prompt.
    """
    docker_prefixes = ("http://ollama:", "http://host.docker.internal:")
    if os.environ.get("OLLAMA_BASE_URL", "").startswith(docker_prefixes):
        # Docker setup: the Ollama server lives in another container/host.
        ollama_url = os.environ.get("OLLAMA_BASE_URL", "http://ollama:11434")
        return docker.ensure_ollama_and_model(model_name, ollama_url)

    # Step 1: the Ollama binary must exist on this machine.
    if not is_ollama_installed():
        print(f"{Fore.YELLOW}Ollama is not installed on your system.{Style.RESET_ALL}")

        wants_install = questionary.confirm("Do you want to install Ollama?").ask()
        if not wants_install:
            print(f"{Fore.RED}Ollama is required to use local models.{Style.RESET_ALL}")
            return False
        if not install_ollama():
            return False

    # Step 2: the Ollama server must be up.
    if not is_ollama_server_running():
        print(f"{Fore.YELLOW}Starting Ollama server...{Style.RESET_ALL}")
        if not start_ollama_server():
            return False

    # Step 3: nothing more to do when the model is already on disk.
    if model_name in get_locally_available_models():
        return True

    print(f"{Fore.YELLOW}Model {model_name} is not available locally.{Style.RESET_ALL}")

    # Extra hint in the prompt for model names that look large.
    if "70b" in model_name:
        size_note = " This is a large model (up to several GB) and may take a while to download."
    elif "34b" in model_name or "8x7b" in model_name:
        size_note = " This is a medium-sized model (1-2 GB) and may take a few minutes to download."
    else:
        size_note = ""

    prompt = f"Do you want to download the {model_name} model?{size_note} The download will happen in the background."
    if questionary.confirm(prompt).ask():
        return download_model(model_name)

    print(f"{Fore.RED}The model is required to proceed.{Style.RESET_ALL}")
    return False
class AgentProgress:
    """Track the latest status of each agent and show it in a live table."""

    def __init__(self):
        # agent name -> {"status": last status text, "ticker": last ticker or None}
        self.agent_status: Dict[str, Dict[str, Optional[str]]] = {}
        self.table = Table(show_header=False, box=None, padding=(0, 1))
        self.live = Live(self.table, console=console, refresh_per_second=4)
        self.started = False

    def start(self):
        """Start the live display; does nothing if it is already running."""
        if self.started:
            return
        self.live.start()
        self.started = True

    def stop(self):
        """Stop the live display; does nothing if it is not running."""
        if not self.started:
            return
        self.live.stop()
        self.started = False

    def update_status(self, agent_name: str, ticker: Optional[str] = None, status: str = ""):
        """Record new ticker/status values for an agent and redraw the table.

        Args:
            agent_name: Key identifying the agent; a record is created on first use.
            ticker: New ticker to remember; a falsy value keeps the previous one.
            status: New status text; an empty string keeps the previous one.
        """
        record = self.agent_status.setdefault(agent_name, {"status": "", "ticker": None})

        if ticker:
            record["ticker"] = ticker
        if status:
            record["status"] = status

        self._refresh_display()

    def _refresh_display(self):
        """Rebuild the table with one row per agent.

        Rows are ordered by agent name, except that names containing
        "risk_management" and then "portfolio_management" are placed last.
        """
        # NOTE(review): only the columns are reset here; confirm the table
        # keeps no other per-row state that accumulates across refreshes.
        self.table.columns.clear()
        self.table.add_column(width=100)

        def rank(name):
            # Ordinary agents first, then risk management, then portfolio management.
            if "risk_management" in name:
                return 2
            if "portfolio_management" in name:
                return 3
            return 1

        # Lower-cased status -> (symbol, style); anything else is "in progress".
        markers = {
            "done": ("✓", Style(color="green", bold=True)),
            "error": ("✗", Style(color="red", bold=True)),
        }
        in_progress = ("⋯", Style(color="yellow"))

        for name in sorted(self.agent_status, key=lambda n: (rank(n), n)):
            record = self.agent_status[name]
            status = record["status"]
            ticker = record["ticker"]

            symbol, style = markers.get(status.lower(), in_progress)

            # "warren_buffett_agent" -> "Warren Buffett"
            label = name.replace("_agent", "").replace("_", " ").title()

            line = Text()
            line.append(f"{symbol} ", style=style)
            line.append(f"{label:<20}", style=Style(bold=True))

            if ticker:
                line.append(f"[{ticker}] ", style=Style(color="cyan"))
            line.append(status, style=style)

            self.table.add_row(line)
def save_graph_as_png(app: CompiledGraph, output_file_path) -> None:
    """Render the compiled graph as a Mermaid PNG and write it to disk.

    Args:
        app: Compiled graph; the bytes written are the result of
            ``app.get_graph().draw_mermaid_png(...)`` using the API draw method.
        output_file_path: Destination path. Any falsy value (``""`` or
            ``None``) falls back to ``"graph.png"``.
    """
    # Render first so that no file is created if drawing fails.
    png_image = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)

    # Truthiness rather than len(): None and path objects have no len(),
    # so the previous check raised TypeError for them.
    file_path = output_file_path or "graph.png"

    with open(file_path, "wb") as f:
        f.write(png_image)