├── .dockerignore
├── .env.example
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yml
├── poetry.lock
├── pyproject.toml
├── run.bat
├── run.sh
└── src
    ├── agents
    │   ├── ben_graham.py
    │   ├── bill_ackman.py
    │   ├── cathie_wood.py
    │   ├── charlie_munger.py
    │   ├── fundamentals.py
    │   ├── michael_burry.py
    │   ├── peter_lynch.py
    │   ├── phil_fisher.py
    │   ├── portfolio_manager.py
    │   ├── risk_manager.py
    │   ├── sentiment.py
    │   ├── stanley_druckenmiller.py
    │   ├── technicals.py
    │   ├── valuation.py
    │   └── warren_buffett.py
    ├── backtester.py
    ├── data
    │   ├── cache.py
    │   └── models.py
    ├── graph
    │   └── state.py
    ├── llm
    │   └── models.py
    ├── main.py
    ├── tools
    │   └── api.py
    └── utils
        ├── __init__.py
        ├── analysts.py
        ├── display.py
        ├── docker.py
        ├── llm.py
        ├── ollama.py
        ├── progress.py
        └── visualize.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Git
2 | .git
3 | .gitignore
4 |
5 | # Poetry
6 | .venv
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 | .pytest_cache/
11 |
12 | # Environment
13 | .env
14 |
15 | # IDEs and editors
16 | .idea/
17 | .vscode/
18 | *.swp
19 | *.swo
20 |
21 | # Logs and data
22 | logs/
23 | data/
24 | *.log
25 |
26 | # OS specific
27 | .DS_Store
28 | Thumbs.db
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # For running LLMs hosted by anthropic (claude-3-5-sonnet, claude-3-opus, claude-3-5-haiku)
2 | # Get your Anthropic API key from https://anthropic.com/
3 | ANTHROPIC_API_KEY=your-anthropic-api-key
4 | 
5 | # For running LLMs hosted by deepseek (deepseek-chat, deepseek-reasoner, etc.)
6 | # Get your DeepSeek API key from https://deepseek.com/
7 | DEEPSEEK_API_KEY=your-deepseek-api-key
8 | 
9 | # For running LLMs hosted by groq (deepseek, llama3, etc.)
10 | # Get your Groq API key from https://groq.com/
11 | GROQ_API_KEY=your-groq-api-key
12 | 
13 | # For running LLMs hosted by gemini (gemini-2.0-flash, gemini-2.0-pro)
14 | # Get your Google API key from https://console.cloud.google.com/
15 | GOOGLE_API_KEY=your-google-api-key
16 | 
17 | # For getting financial data to power the hedge fund
18 | # Get your Financial Datasets API key from https://financialdatasets.ai/
19 | FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key
20 | 
21 | # For running LLMs hosted by openai (gpt-4o, gpt-4o-mini, etc.)
22 | # Get your OpenAI API key from https://platform.openai.com/
23 | OPENAI_API_KEY=your-openai-api-key
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **Screenshot**
14 | Add a screenshot of the bug to help explain your problem.
15 |
16 | **Additional context**
17 | Add any other context about the problem here.
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the feature you'd like**
11 | A clear and concise description of what you want to happen.
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | env/
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 |
24 | # Virtual Environment
25 | venv/
26 | ENV/
27 |
28 | # Environment Variables
29 | .env
30 |
31 | # IDE
32 | .idea/
33 | .vscode/
34 | *.swp
35 | *.swo
36 | .cursorrules
37 | .cursorignore
38 | .cursorindexingignore
39 |
40 | # OS
41 | .DS_Store
42 | Thumbs.db
43 |
44 | # graph
45 | *.png
46 |
47 | # Txt files
48 | *.txt
49 |
50 | # PDF files
51 | *.pdf
52 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim
2 |
3 | WORKDIR /app
4 |
5 | # Install Poetry
6 | RUN pip install poetry==1.7.1
7 |
8 | # Copy only dependency files first for better caching
9 | COPY pyproject.toml poetry.lock* /app/
10 |
11 | # Configure Poetry to not use a virtual environment
12 | RUN poetry config virtualenvs.create false \
13 |     && poetry install --no-interaction --no-ansi --no-root
14 |
15 | # Copy rest of the source code
16 | COPY . /app/
17 |
18 | # Default command (will be overridden by Docker Compose)
19 | CMD ["python", "src/main.py"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Virat Singh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AI Hedge Fund
2 |
3 | This is a proof of concept for an AI-powered hedge fund. The goal of this project is to explore the use of AI to make trading decisions. This project is for **educational** purposes only and is not intended for real trading or investment.
4 |
5 | This system employs several agents working together:
6 |
7 | 1. Ben Graham Agent - The godfather of value investing, only buys hidden gems with a margin of safety
8 | 2. Bill Ackman Agent - An activist investor, takes bold positions and pushes for change
9 | 3. Cathie Wood Agent - The queen of growth investing, believes in the power of innovation and disruption
10 | 4. Charlie Munger Agent - Warren Buffett's partner, only buys wonderful businesses at fair prices
11 | 5. Michael Burry Agent - The Big Short contrarian who hunts for deep value
12 | 6. Peter Lynch Agent - Practical investor who seeks "ten-baggers" in everyday businesses
13 | 7. Phil Fisher Agent - Meticulous growth investor who uses deep "scuttlebutt" research
14 | 8. Stanley Druckenmiller Agent - Macro legend who hunts for asymmetric opportunities with growth potential
15 | 9. Warren Buffett Agent - The Oracle of Omaha, seeks wonderful companies at a fair price
16 | 10. Valuation Agent - Calculates the intrinsic value of a stock and generates trading signals
17 | 11. Sentiment Agent - Analyzes market sentiment and generates trading signals
18 | 12. Fundamentals Agent - Analyzes fundamental data and generates trading signals
19 | 13. Technicals Agent - Analyzes technical indicators and generates trading signals
20 | 14. Risk Manager - Calculates risk metrics and sets position limits
21 | 15. Portfolio Manager - Makes final trading decisions and generates orders
22 |
23 |
24 |
25 |
26 | **Note**: the system simulates trading decisions; it does not actually trade.
27 |
28 | [Follow @virattt on Twitter](https://twitter.com/virattt)
29 |
30 | ## Disclaimer
31 |
32 | This project is for **educational and research purposes only**.
33 |
34 | - Not intended for real trading or investment
35 | - No warranties or guarantees provided
36 | - Past performance does not indicate future results
37 | - Creator assumes no liability for financial losses
38 | - Consult a financial advisor for investment decisions
39 |
40 | By using this software, you agree to use it solely for learning purposes.
41 |
42 | ## Table of Contents
43 | - [Setup](#setup)
44 | - [Using Poetry](#using-poetry)
45 | - [Using Docker](#using-docker)
46 | - [Usage](#usage)
47 | - [Running the Hedge Fund](#running-the-hedge-fund)
48 | - [Running the Backtester](#running-the-backtester)
49 | - [Project Structure](#project-structure)
50 | - [Contributing](#contributing)
51 | - [Feature Requests](#feature-requests)
52 | - [License](#license)
53 |
54 | ## Setup
55 |
56 | ### Using Poetry
57 |
58 | 1. Clone the repository:
59 | ```bash
60 | git clone https://github.com/virattt/ai-hedge-fund.git
61 | cd ai-hedge-fund
62 | ```
63 |
64 | 2. Install Poetry (if not already installed):
65 | ```bash
66 | curl -sSL https://install.python-poetry.org | python3 -
67 | ```
68 |
69 | 3. Install dependencies:
70 | ```bash
71 | poetry install
72 | ```
73 |
74 | 4. Set up your environment variables:
75 | ```bash
76 | # Create .env file for your API keys
77 | cp .env.example .env
78 | ```
79 |
80 | 5. Set your API keys:
81 | ```bash
82 | # For running LLMs hosted by openai (gpt-4o, gpt-4o-mini, etc.)
83 | # Get your OpenAI API key from https://platform.openai.com/
84 | OPENAI_API_KEY=your-openai-api-key
85 |
86 | # For running LLMs hosted by groq (deepseek, llama3, etc.)
87 | # Get your Groq API key from https://groq.com/
88 | GROQ_API_KEY=your-groq-api-key
89 |
90 | # For getting financial data to power the hedge fund
91 | # Get your Financial Datasets API key from https://financialdatasets.ai/
92 | FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key
93 | ```
94 |
95 | ### Using Docker
96 |
97 | 1. Make sure you have Docker installed on your system. If not, you can download it from [Docker's official website](https://www.docker.com/get-started).
98 |
99 | 2. Clone the repository:
100 | ```bash
101 | git clone https://github.com/virattt/ai-hedge-fund.git
102 | cd ai-hedge-fund
103 | ```
104 |
105 | 3. Set up your environment variables:
106 | ```bash
107 | # Create .env file for your API keys
108 | cp .env.example .env
109 | ```
110 |
111 | 4. Edit the .env file to add your API keys as described above.
112 |
113 | 5. Build the Docker image:
114 | ```bash
115 | # On Linux/Mac:
116 | ./run.sh build
117 |
118 | # On Windows:
119 | run.bat build
120 | ```
121 |
122 | **Important**: You must set `OPENAI_API_KEY`, `GROQ_API_KEY`, `ANTHROPIC_API_KEY`, or `DEEPSEEK_API_KEY` for the hedge fund to work. If you want to use LLMs from all providers, you will need to set all API keys.
123 |
124 | Financial data for AAPL, GOOGL, MSFT, NVDA, and TSLA is free and does not require an API key.
125 |
126 | For any other ticker, you will need to set the `FINANCIAL_DATASETS_API_KEY` in the .env file.
127 |
128 | ## Usage
129 |
130 | ### Running the Hedge Fund
131 |
132 | #### With Poetry
133 | ```bash
134 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA
135 | ```
136 |
137 | #### With Docker
138 | ```bash
139 | # On Linux/Mac:
140 | ./run.sh --ticker AAPL,MSFT,NVDA main
141 |
142 | # On Windows:
143 | run.bat --ticker AAPL,MSFT,NVDA main
144 | ```
145 |
146 | **Example Output:**
147 |
148 |
149 | You can also specify a `--ollama` flag to run the AI hedge fund using local LLMs.
150 |
151 | ```bash
152 | # With Poetry:
153 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --ollama
154 |
155 | # With Docker (on Linux/Mac):
156 | ./run.sh --ticker AAPL,MSFT,NVDA --ollama main
157 |
158 | # With Docker (on Windows):
159 | run.bat --ticker AAPL,MSFT,NVDA --ollama main
160 | ```
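   | 
   | If you are running through Docker, make sure the model you want has already been pulled into the Ollama container. The helper scripts can do this for you, for example:
   | 
   | ```bash
   | # On Linux/Mac:
   | ./run.sh pull llama3
   | 
   | # On Windows:
   | run.bat pull llama3
   | ```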
161 |
162 | You can also specify a `--show-reasoning` flag to print the reasoning of each agent to the console.
163 |
164 | ```bash
165 | # With Poetry:
166 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --show-reasoning
167 |
168 | # With Docker (on Linux/Mac):
169 | ./run.sh --ticker AAPL,MSFT,NVDA --show-reasoning main
170 |
171 | # With Docker (on Windows):
172 | run.bat --ticker AAPL,MSFT,NVDA --show-reasoning main
173 | ```
174 |
175 | You can optionally specify the start and end dates to make decisions for a specific time period.
176 |
177 | ```bash
178 | # With Poetry:
179 | poetry run python src/main.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01
180 |
181 | # With Docker (on Linux/Mac):
182 | ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 main
183 |
184 | # With Docker (on Windows):
185 | run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 main
186 | ```
187 |
188 | ### Running the Backtester
189 |
190 | #### With Poetry
191 | ```bash
192 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA
193 | ```
194 |
195 | #### With Docker
196 | ```bash
197 | # On Linux/Mac:
198 | ./run.sh --ticker AAPL,MSFT,NVDA backtest
199 |
200 | # On Windows:
201 | run.bat --ticker AAPL,MSFT,NVDA backtest
202 | ```
203 |
204 | **Example Output:**
205 |
206 |
207 |
208 | You can optionally specify the start and end dates to backtest over a specific time period.
209 |
210 | ```bash
211 | # With Poetry:
212 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01
213 |
214 | # With Docker (on Linux/Mac):
215 | ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest
216 |
217 | # With Docker (on Windows):
218 | run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest
219 | ```
220 |
221 | You can also specify a `--ollama` flag to run the backtester using local LLMs.
222 | ```bash
223 | # With Poetry:
224 | poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA --ollama
225 |
226 | # With Docker (on Linux/Mac):
227 | ./run.sh --ticker AAPL,MSFT,NVDA --ollama backtest
228 |
229 | # With Docker (on Windows):
230 | run.bat --ticker AAPL,MSFT,NVDA --ollama backtest
231 | ```
232 |
233 |
234 | ## Project Structure
235 | ```
236 | ai-hedge-fund/
237 | ├── src/
238 | │   ├── agents/                   # Agent definitions and workflow
239 | │   │   ├── bill_ackman.py        # Bill Ackman agent
240 | │   │   ├── fundamentals.py       # Fundamental analysis agent
241 | │   │   ├── portfolio_manager.py  # Portfolio management agent
242 | │   │   ├── risk_manager.py       # Risk management agent
243 | │   │   ├── sentiment.py          # Sentiment analysis agent
244 | │   │   ├── technicals.py         # Technical analysis agent
245 | │   │   ├── valuation.py          # Valuation analysis agent
246 | │   │   ├── ...                   # Other agents
247 | │   │   └── warren_buffett.py     # Warren Buffett agent
248 | │   ├── tools/                    # Agent tools
249 | │   │   └── api.py                # API tools
250 | │   ├── backtester.py             # Backtesting tools
251 | │   └── main.py                   # Main entry point
252 | ├── pyproject.toml
253 | └── ...
254 | ```
255 |
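   | Each analyst in `src/agents/` follows the same pattern (see `ben_graham.py`): a function that takes the shared `AgentState`, stores its signals under `state["data"]["analyst_signals"]`, and returns a message for the workflow graph. Below is a minimal sketch of a new analyst in that style; it is illustrative only (the signal is stubbed out), and registering it with the workflow (see `src/utils/analysts.py`) is not shown here:
   | 
   | ```python
   | import json
   | 
   | from langchain_core.messages import HumanMessage
   | 
   | from graph.state import AgentState, show_agent_reasoning
   | 
   | 
   | def my_custom_agent(state: AgentState):
   |     """Toy analyst: emits a neutral signal for every ticker."""
   |     signals = {}
   |     for ticker in state["data"]["tickers"]:
   |         signals[ticker] = {"signal": "neutral", "confidence": 0.0, "reasoning": "stub"}
   | 
   |     message = HumanMessage(content=json.dumps(signals), name="my_custom_agent")
   | 
   |     if state["metadata"]["show_reasoning"]:
   |         show_agent_reasoning(signals, "My Custom Agent")
   | 
   |     state["data"]["analyst_signals"]["my_custom_agent"] = signals
   |     return {"messages": [message], "data": state["data"]}
   | ```
   | 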
256 | ## Contributing
257 |
258 | 1. Fork the repository
259 | 2. Create a feature branch
260 | 3. Commit your changes
261 | 4. Push to the branch
262 | 5. Create a Pull Request
263 |
264 | **Important**: Please keep your pull requests small and focused. This will make it easier to review and merge.
265 |
266 | ## Feature Requests
267 |
268 | If you have a feature request, please open an [issue](https://github.com/virattt/ai-hedge-fund/issues) and make sure it is tagged with `enhancement`.
269 |
270 | ## License
271 |
272 | This project is licensed under the MIT License - see the LICENSE file for details.
273 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | ollama:
3 | image: ollama/ollama:latest
4 | container_name: ollama
5 | environment:
6 | - OLLAMA_HOST=0.0.0.0
7 | # Apple Silicon GPU acceleration
8 | - METAL_DEVICE=on
9 | - METAL_DEVICE_INDEX=0
10 | volumes:
11 | - ollama_data:/root/.ollama
12 | ports:
13 | - "11434:11434"
14 | restart: unless-stopped
15 |
16 | hedge-fund:
17 | build: .
18 | image: ai-hedge-fund
19 | depends_on:
20 | - ollama
21 | volumes:
22 | - ./.env:/app/.env
23 | command: python src/main.py --ticker AAPL,MSFT,NVDA
24 | environment:
25 | - PYTHONUNBUFFERED=1
26 | - OLLAMA_BASE_URL=http://ollama:11434
27 | tty: true
28 | stdin_open: true
29 |
30 | hedge-fund-reasoning:
31 | build: .
32 | image: ai-hedge-fund
33 | depends_on:
34 | - ollama
35 | volumes:
36 | - ./.env:/app/.env
37 | command: python src/main.py --ticker AAPL,MSFT,NVDA --show-reasoning
38 | environment:
39 | - PYTHONUNBUFFERED=1
40 | - OLLAMA_BASE_URL=http://ollama:11434
41 | tty: true
42 | stdin_open: true
43 |
44 | hedge-fund-ollama:
45 | build: .
46 | image: ai-hedge-fund
47 | depends_on:
48 | - ollama
49 | volumes:
50 | - ./.env:/app/.env
51 | command: python src/main.py --ticker AAPL,MSFT,NVDA --ollama
52 | environment:
53 | - PYTHONUNBUFFERED=1
54 | - OLLAMA_BASE_URL=http://ollama:11434
55 | tty: true
56 | stdin_open: true
57 |
58 | backtester:
59 | build: .
60 | image: ai-hedge-fund
61 | depends_on:
62 | - ollama
63 | volumes:
64 | - ./.env:/app/.env
65 | command: python src/backtester.py --ticker AAPL,MSFT,NVDA
66 | environment:
67 | - PYTHONUNBUFFERED=1
68 | - OLLAMA_BASE_URL=http://ollama:11434
69 | tty: true
70 | stdin_open: true
71 |
72 | backtester-ollama:
73 | build: .
74 | image: ai-hedge-fund
75 | depends_on:
76 | - ollama
77 | volumes:
78 | - ./.env:/app/.env
79 | command: python src/backtester.py --ticker AAPL,MSFT,NVDA --ollama
80 | environment:
81 | - PYTHONUNBUFFERED=1
82 | - OLLAMA_BASE_URL=http://ollama:11434
83 | tty: true
84 | stdin_open: true
85 |
86 | volumes:
87 | ollama_data:
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "ai-hedge-fund"
3 | version = "0.1.0"
4 | description = "An AI-powered hedge fund that uses multiple agents to make trading decisions"
5 | authors = ["Your Name "]
6 | readme = "README.md"
7 | packages = [
8 | { include = "src", from = "." }
9 | ]
10 | [tool.poetry.dependencies]
11 | python = "^3.9"
12 | langchain = "0.3.0"
13 | langchain-anthropic = "0.3.5"
14 | langchain-groq = "0.2.3"
15 | langchain-openai = "^0.3.5"
16 | langchain-deepseek = "^0.1.2"
17 | langchain-ollama = "^0.2.0"
18 | langgraph = "0.2.56"
19 | pandas = "^2.1.0"
20 | numpy = "^1.24.0"
21 | python-dotenv = "1.0.0"
22 | matplotlib = "^3.9.2"
23 | tabulate = "^0.9.0"
24 | colorama = "^0.4.6"
25 | questionary = "^2.1.0"
26 | rich = "^13.9.4"
27 | langchain-google-genai = "^2.0.11"
28 |
29 | [tool.poetry.group.dev.dependencies]
30 | pytest = "^7.4.0"
31 | black = "^23.7.0"
32 | isort = "^5.12.0"
33 | flake8 = "^6.1.0"
34 |
35 | [build-system]
36 | requires = ["poetry-core"]
37 | build-backend = "poetry.core.masonry.api"
38 |
39 | [tool.black]
40 | line-length = 420
41 | target-version = ['py39']
42 | include = '\.pyi?$'
--------------------------------------------------------------------------------
/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | setlocal enabledelayedexpansion
3 |
4 | :: Default values
5 | set TICKER=AAPL,MSFT,NVDA
6 | set USE_OLLAMA=
7 | set START_DATE=
8 | set END_DATE=
9 | set INITIAL_AMOUNT=100000.0
10 | set MARGIN_REQUIREMENT=0.0
11 | set SHOW_REASONING=
12 | set COMMAND=
13 | set MODEL_NAME=
14 | 
   | :: Jump straight to argument parsing. Without this, execution would fall
   | :: through into the :show_help block below and exit before doing any work.
   | call :parse_args %*
   | exit /b %ERRORLEVEL%
   | 
15 | :: Help function
16 | :show_help
17 | echo AI Hedge Fund Docker Runner
18 | echo.
19 | echo Usage: run.bat [OPTIONS] COMMAND
20 | echo.
21 | echo Options:
22 | echo --ticker SYMBOLS Comma-separated list of ticker symbols (e.g., AAPL,MSFT,NVDA)
23 | echo --start-date DATE Start date in YYYY-MM-DD format
24 | echo --end-date DATE End date in YYYY-MM-DD format
25 | echo --initial-cash AMT Initial cash position (default: 100000.0)
26 | echo --margin-requirement RATIO Margin requirement ratio (default: 0.0)
27 | echo --ollama Use Ollama for local LLM inference
28 | echo --show-reasoning Show reasoning from each agent
29 | echo.
30 | echo Commands:
31 | echo main Run the main hedge fund application
32 | echo backtest Run the backtester
33 | echo build Build the Docker image
34 | echo compose Run using Docker Compose with integrated Ollama
35 | echo ollama Start only the Ollama container for model management
36 | echo pull MODEL Pull a specific model into the Ollama container
37 | echo help Show this help message
38 | echo.
39 | echo Examples:
40 | echo run.bat --ticker AAPL,MSFT,NVDA main
41 | echo run.bat --ticker AAPL,MSFT,NVDA --ollama main
42 | echo run.bat --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest
43 | echo run.bat compose # Run with Docker Compose (includes Ollama)
44 | echo run.bat ollama # Start only the Ollama container
45 | echo run.bat pull llama3 # Pull the llama3 model to Ollama
46 | echo.
47 | goto :eof
48 |
49 | :: Parse arguments
50 | :parse_args
51 | if "%~1"=="" goto :check_command
52 | if "%~1"=="--ticker" (
53 | set TICKER=%~2
54 | shift
55 | shift
56 | goto :parse_args
57 | )
58 | if "%~1"=="--start-date" (
59 | set START_DATE=--start-date %~2
60 | shift
61 | shift
62 | goto :parse_args
63 | )
64 | if "%~1"=="--end-date" (
65 | set END_DATE=--end-date %~2
66 | shift
67 | shift
68 | goto :parse_args
69 | )
70 | if "%~1"=="--initial-cash" (
71 | set INITIAL_AMOUNT=%~2
72 | shift
73 | shift
74 | goto :parse_args
75 | )
76 | if "%~1"=="--margin-requirement" (
77 | set MARGIN_REQUIREMENT=%~2
78 | shift
79 | shift
80 | goto :parse_args
81 | )
82 | if "%~1"=="--ollama" (
83 | set USE_OLLAMA=--ollama
84 | shift
85 | goto :parse_args
86 | )
87 | if "%~1"=="--show-reasoning" (
88 | set SHOW_REASONING=--show-reasoning
89 | shift
90 | goto :parse_args
91 | )
92 | if "%~1"=="main" (
93 | set COMMAND=main
94 | shift
95 | goto :parse_args
96 | )
97 | if "%~1"=="backtest" (
98 | set COMMAND=backtest
99 | shift
100 | goto :parse_args
101 | )
102 | if "%~1"=="build" (
103 | set COMMAND=build
104 | shift
105 | goto :parse_args
106 | )
107 | if "%~1"=="compose" (
108 | set COMMAND=compose
109 | shift
110 | goto :parse_args
111 | )
112 | if "%~1"=="ollama" (
113 | set COMMAND=ollama
114 | shift
115 | goto :parse_args
116 | )
117 | if "%~1"=="pull" (
118 | set COMMAND=pull
119 | set MODEL_NAME=%~2
120 | shift
121 | shift
122 | goto :parse_args
123 | )
124 | if "%~1"=="help" (
125 | call :show_help
126 | exit /b 0
127 | )
128 | if "%~1"=="--help" (
129 | call :show_help
130 | exit /b 0
131 | )
132 | echo Unknown option: %~1
133 | call :show_help
134 | exit /b 1
135 |
136 | :check_command
137 | if "!COMMAND!"=="" (
138 | echo Error: No command specified.
139 | call :show_help
140 | exit /b 1
141 | )
142 |
143 | :: Show help if 'help' command is provided
144 | if "!COMMAND!"=="help" (
145 | call :show_help
146 | exit /b 0
147 | )
148 |
149 | :: Check for Docker Compose existence
150 | docker compose version >nul 2>&1
151 | if !ERRORLEVEL! EQU 0 (
152 | set COMPOSE_CMD=docker compose
153 | ) else (
154 | docker-compose --version >nul 2>&1
155 | if !ERRORLEVEL! EQU 0 (
156 | set COMPOSE_CMD=docker-compose
157 | ) else (
158 | echo Error: Docker Compose is not installed.
159 | exit /b 1
160 | )
161 | )
162 |
163 | :: Build the Docker image if 'build' command is provided
164 | if "!COMMAND!"=="build" (
165 | docker build -t ai-hedge-fund .
166 | exit /b 0
167 | )
168 |
169 | :: Start Ollama container if 'ollama' command is provided
170 | if "!COMMAND!"=="ollama" (
171 | echo Starting Ollama container...
172 | !COMPOSE_CMD! up -d ollama
173 |
174 | :: Check if Ollama is running
175 | echo Waiting for Ollama to start...
176 | for /l %%i in (1, 1, 30) do (
177 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1
178 | if !ERRORLEVEL! EQU 0 (
179 | echo Ollama is now running.
180 | :: Show available models
181 | echo Available models:
182 | !COMPOSE_CMD! exec ollama ollama list
183 |
184 | echo.
185 | echo Manage your models using:
186 |         echo    run.bat pull ^<model_name^>    # Download a model
187 | echo run.bat ollama # Start Ollama and show models
188 | exit /b 0
189 | )
190 | timeout /t 1 /nobreak >nul
191 | echo.
192 | )
193 |
194 | echo Failed to start Ollama within the expected time. You may need to check the container logs.
195 | exit /b 1
196 | )
197 |
198 | :: Pull a model if 'pull' command is provided
199 | if "!COMMAND!"=="pull" (
200 | if "!MODEL_NAME!"=="" (
201 | echo Error: No model name specified.
202 |         echo Usage: run.bat pull ^<model_name^>
203 | echo Example: run.bat pull llama3
204 | exit /b 1
205 | )
206 |
207 | :: Start Ollama if it's not already running
208 | !COMPOSE_CMD! up -d ollama
209 |
210 | :: Wait for Ollama to start
211 | echo Ensuring Ollama is running...
212 | for /l %%i in (1, 1, 30) do (
213 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1
214 | if !ERRORLEVEL! EQU 0 (
215 | echo Ollama is running.
216 | goto :pull_model
217 | )
218 | timeout /t 1 /nobreak >nul
219 | echo.
220 | )
221 |
222 | :pull_model
223 | :: Pull the model
224 | echo Pulling model: !MODEL_NAME!
225 | echo This may take some time depending on the model size and your internet connection.
226 | echo You can press Ctrl+C to cancel at any time (the model will continue downloading in the background).
227 |
228 | !COMPOSE_CMD! exec ollama ollama pull "!MODEL_NAME!"
229 |
230 | :: Check if the model was successfully pulled
231 | !COMPOSE_CMD! exec ollama ollama list | findstr /i "!MODEL_NAME!" >nul
232 | if !ERRORLEVEL! EQU 0 (
233 | echo Model !MODEL_NAME! was successfully downloaded.
234 | ) else (
235 | echo Warning: Model !MODEL_NAME! may not have been properly downloaded.
236 | echo Check the Ollama container status with: run.bat ollama
237 | )
238 |
239 | exit /b 0
240 | )
241 |
242 | :: Run with Docker Compose if 'compose' command is provided
243 | if "!COMMAND!"=="compose" (
244 | echo Running with Docker Compose (includes Ollama)...
245 | !COMPOSE_CMD! up --build
246 | exit /b 0
247 | )
248 |
249 | :: Check if .env file exists, if not create from .env.example
250 | if not exist .env (
251 | if exist .env.example (
252 | echo No .env file found. Creating from .env.example...
253 | copy .env.example .env
254 | echo Please edit .env file to add your API keys.
255 | ) else (
256 | echo Error: No .env or .env.example file found.
257 | exit /b 1
258 | )
259 | )
260 |
261 | :: Set script path and parameters based on command
262 | if "!COMMAND!"=="main" (
263 |     set SCRIPT_PATH=src/main.py
264 |     set INITIAL_PARAM=--initial-cash !INITIAL_AMOUNT!
265 | ) else if "!COMMAND!"=="backtest" (
266 |     set SCRIPT_PATH=src/backtester.py
267 |     set INITIAL_PARAM=--initial-capital !INITIAL_AMOUNT!
268 | )
273 |
274 | :: If using Ollama, make sure the service is started
275 | if not "!USE_OLLAMA!"=="" (
276 | echo Setting up Ollama container for local LLM inference...
277 |
278 | :: Start Ollama container if not already running
279 | !COMPOSE_CMD! up -d ollama
280 |
281 | :: Wait for Ollama to start
282 | echo Waiting for Ollama to start...
283 | for /l %%i in (1, 1, 30) do (
284 | !COMPOSE_CMD! exec ollama curl -s http://localhost:11434/api/version >nul 2>&1
285 | if !ERRORLEVEL! EQU 0 (
286 | echo Ollama is running.
287 | :: Show available models
288 | echo Available models:
289 | !COMPOSE_CMD! exec ollama ollama list
290 | goto :continue_ollama
291 | )
292 | timeout /t 1 /nobreak >nul
293 | echo.
294 | )
295 |
296 | :continue_ollama
297 | :: Build the AI Hedge Fund image if needed
298 | docker images -q ai-hedge-fund 2>nul | findstr /r /c:"^..*$" >nul
299 | if !ERRORLEVEL! NEQ 0 (
300 | echo Building AI Hedge Fund image...
301 | docker build -t ai-hedge-fund .
302 | )
303 |
304 | :: Create command override for Docker Compose
305 | set COMMAND_OVERRIDE=
306 |
307 | if not "!START_DATE!"=="" (
308 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !START_DATE!
309 | )
310 |
311 | if not "!END_DATE!"=="" (
312 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !END_DATE!
313 | )
314 |
315 | if not "!INITIAL_PARAM!"=="" (
316 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! !INITIAL_PARAM!
317 | )
318 |
319 | if not "!MARGIN_REQUIREMENT!"=="" (
320 | set COMMAND_OVERRIDE=!COMMAND_OVERRIDE! --margin-requirement !MARGIN_REQUIREMENT!
321 | )
322 |
323 | :: Run the command with Docker Compose
324 | echo Running AI Hedge Fund with Ollama using Docker Compose...
325 |
326 | :: Use the appropriate service based on command and reasoning flag
327 | if "!COMMAND!"=="main" (
328 | if not "!SHOW_REASONING!"=="" (
329 | !COMPOSE_CMD! run --rm hedge-fund-reasoning python src/main.py --ticker !TICKER! !COMMAND_OVERRIDE! !SHOW_REASONING! --ollama
330 | ) else (
331 | !COMPOSE_CMD! run --rm hedge-fund-ollama python src/main.py --ticker !TICKER! !COMMAND_OVERRIDE! --ollama
332 | )
333 | ) else if "!COMMAND!"=="backtest" (
334 | !COMPOSE_CMD! run --rm backtester-ollama python src/backtester.py --ticker !TICKER! !COMMAND_OVERRIDE! !SHOW_REASONING! --ollama
335 | )
336 |
337 | exit /b 0
338 | )
339 |
340 | :: Standard Docker run (without Ollama)
341 | :: Build the command
342 | set CMD=docker run -it --rm -v %cd%\.env:/app/.env
343 |
344 | :: Add the command
345 | set CMD=!CMD! ai-hedge-fund python !SCRIPT_PATH! --ticker !TICKER! !START_DATE! !END_DATE! !INITIAL_PARAM! --margin-requirement !MARGIN_REQUIREMENT! !SHOW_REASONING!
346 |
347 | :: Run the command
348 | echo Running: !CMD!
349 | !CMD!
350 |
351 | :: Exit
352 | exit /b 0
353 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Help text to display when --help is provided
4 | show_help() {
5 | echo "AI Hedge Fund Docker Runner"
6 | echo ""
7 | echo "Usage: ./run.sh [OPTIONS] COMMAND"
8 | echo ""
9 | echo "Options:"
10 | echo " --ticker SYMBOLS Comma-separated list of ticker symbols (e.g., AAPL,MSFT,NVDA)"
11 | echo " --start-date DATE Start date in YYYY-MM-DD format"
12 | echo " --end-date DATE End date in YYYY-MM-DD format"
13 | echo " --initial-cash AMT Initial cash position (default: 100000.0)"
14 | echo " --margin-requirement RATIO Margin requirement ratio (default: 0.0)"
15 | echo " --ollama Use Ollama for local LLM inference"
16 | echo " --show-reasoning Show reasoning from each agent"
17 | echo ""
18 | echo "Commands:"
19 | echo " main Run the main hedge fund application"
20 | echo " backtest Run the backtester"
21 | echo " build Build the Docker image"
22 | echo " compose Run using Docker Compose with integrated Ollama"
23 | echo " ollama Start only the Ollama container for model management"
24 | echo " pull MODEL Pull a specific model into the Ollama container"
25 | echo " help Show this help message"
26 | echo ""
27 | echo "Examples:"
28 | echo " ./run.sh --ticker AAPL,MSFT,NVDA main"
29 | echo " ./run.sh --ticker AAPL,MSFT,NVDA --ollama main"
30 | echo " ./run.sh --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 backtest"
31 | echo " ./run.sh compose # Run with Docker Compose (includes Ollama)"
32 | echo " ./run.sh ollama # Start only the Ollama container"
33 | echo " ./run.sh pull llama3 # Pull the llama3 model to Ollama"
34 | echo ""
35 | }
36 |
37 | # Default values
38 | TICKER="AAPL,MSFT,NVDA"
39 | USE_OLLAMA=""
40 | START_DATE=""
41 | END_DATE=""
42 | INITIAL_AMOUNT="100000.0"
43 | MARGIN_REQUIREMENT="0.0"
44 | SHOW_REASONING=""
45 | COMMAND=""
46 | MODEL_NAME=""
47 |
48 | # Parse arguments
49 | while [[ $# -gt 0 ]]; do
50 | case $1 in
51 | --ticker)
52 | TICKER="$2"
53 | shift 2
54 | ;;
55 | --start-date)
56 | START_DATE="--start-date $2"
57 | shift 2
58 | ;;
59 | --end-date)
60 | END_DATE="--end-date $2"
61 | shift 2
62 | ;;
63 | --initial-cash)
64 | INITIAL_AMOUNT="$2"
65 | shift 2
66 | ;;
67 | --margin-requirement)
68 | MARGIN_REQUIREMENT="$2"
69 | shift 2
70 | ;;
71 | --ollama)
72 | USE_OLLAMA="--ollama"
73 | shift
74 | ;;
75 | --show-reasoning)
76 | SHOW_REASONING="--show-reasoning"
77 | shift
78 | ;;
79 | main|backtest|build|help|compose|ollama)
80 | COMMAND="$1"
81 | shift
82 | ;;
83 | pull)
84 | COMMAND="pull"
85 | MODEL_NAME="$2"
86 | shift 2
87 | ;;
88 | --help)
89 | show_help
90 | exit 0
91 | ;;
92 | *)
93 | echo "Unknown option: $1"
94 | show_help
95 | exit 1
96 | ;;
97 | esac
98 | done
99 |
100 | # Check if command is provided
101 | if [ -z "$COMMAND" ]; then
102 | echo "Error: No command specified."
103 | show_help
104 | exit 1
105 | fi
106 |
107 | # Show help if 'help' command is provided
108 | if [ "$COMMAND" = "help" ]; then
109 | show_help
110 | exit 0
111 | fi
112 |
113 | # Check for Docker Compose existence
114 | if ! command -v docker-compose &> /dev/null && ! docker compose version &> /dev/null; then
115 | echo "Error: Docker Compose is not installed."
116 | exit 1
117 | fi
118 |
119 | # Determine which Docker Compose command to use
120 | if command -v docker-compose &> /dev/null; then
121 | COMPOSE_CMD="docker-compose"
122 | else
123 | COMPOSE_CMD="docker compose"
124 | fi
125 |
126 | # Detect system architecture for GPU configuration
127 | ARCH=$(uname -m)
128 | OS=$(uname -s)
129 | GPU_CONFIG=""
130 |
131 | # Set appropriate GPU configuration based on architecture
132 | if [ "$OS" = "Darwin" ] && { [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; }; then
133 | echo "Detected Apple Silicon (M-series) - Metal GPU acceleration should be enabled"
134 | # Metal GPU is handled via environment variables in docker-compose.yml
135 | elif command -v nvidia-smi &> /dev/null; then
136 | echo "NVIDIA GPU detected - Adding NVIDIA GPU configuration"
137 | GPU_CONFIG="-f docker-compose.yml -f docker-compose.nvidia.yml"
138 | fi
139 |
140 | # Build the Docker image if 'build' command is provided
141 | if [ "$COMMAND" = "build" ]; then
142 | docker build -t ai-hedge-fund .
143 | exit 0
144 | fi
145 |
146 | # Start Ollama container if 'ollama' command is provided
147 | if [ "$COMMAND" = "ollama" ]; then
148 | echo "Starting Ollama container..."
149 | $COMPOSE_CMD $GPU_CONFIG up -d ollama
150 |
151 | # Check if Ollama is running
152 | echo "Waiting for Ollama to start..."
153 | for i in {1..30}; do
154 | if docker run --rm --network=host curlimages/curl:latest curl -s http://localhost:11434/api/version &> /dev/null; then
155 | echo "Ollama is now running."
156 | # Show available models
157 | echo "Available models:"
158 | docker exec -t ollama ollama list
159 |
160 | echo -e "\nManage your models using:"
161 |             echo "  ./run.sh pull <model_name>   # Download a model"
162 | echo " ./run.sh ollama # Start Ollama and show models"
163 | exit 0
164 | fi
165 | echo -n "."
166 | sleep 1
167 | done
168 |
169 | echo "Failed to start Ollama within the expected time. You may need to check the container logs."
170 | exit 1
171 | fi
172 |
173 | # Pull a model if 'pull' command is provided
174 | if [ "$COMMAND" = "pull" ]; then
175 | if [ -z "$MODEL_NAME" ]; then
176 | echo "Error: No model name specified."
177 |         echo "Usage: ./run.sh pull <model_name>"
178 | echo "Example: ./run.sh pull llama3"
179 | exit 1
180 | fi
181 |
182 | # Start Ollama if it's not already running
183 | $COMPOSE_CMD $GPU_CONFIG up -d ollama
184 |
185 | # Wait for Ollama to start
186 | echo "Ensuring Ollama is running..."
187 | for i in {1..30}; do
188 | if docker run --rm --network=host curlimages/curl:latest curl -s http://localhost:11434/api/version &> /dev/null; then
189 | echo "Ollama is running."
190 | break
191 | fi
192 | echo -n "."
193 | sleep 1
194 | done
195 |
196 | # Pull the model
197 | echo "Pulling model: $MODEL_NAME"
198 | echo "This may take some time depending on the model size and your internet connection."
199 | echo "You can press Ctrl+C to cancel at any time (the model will continue downloading in the background)."
200 |
201 | docker exec -t ollama ollama pull "$MODEL_NAME"
202 |
203 | # Check if the model was successfully pulled
204 | if docker exec -t ollama ollama list | grep -q "$MODEL_NAME"; then
205 | echo "Model $MODEL_NAME was successfully downloaded."
206 | else
207 | echo "Warning: Model $MODEL_NAME may not have been properly downloaded."
208 | echo "Check the Ollama container status with: ./run.sh ollama"
209 | fi
210 |
211 | exit 0
212 | fi
213 |
214 | # Run with Docker Compose
215 | if [ "$COMMAND" = "compose" ]; then
216 | echo "Running with Docker Compose (includes Ollama)..."
217 | $COMPOSE_CMD $GPU_CONFIG up --build
218 | exit 0
219 | fi
220 |
221 | # Check if .env file exists, if not create from .env.example
222 | if [ ! -f .env ]; then
223 | if [ -f .env.example ]; then
224 | echo "No .env file found. Creating from .env.example..."
225 | cp .env.example .env
226 | echo "Please edit .env file to add your API keys."
227 | else
228 | echo "Error: No .env or .env.example file found."
229 | exit 1
230 | fi
231 | fi
232 |
233 | # Set script path and parameters based on command
234 | if [ "$COMMAND" = "main" ]; then
235 |     SCRIPT_PATH="src/main.py"
236 |     INITIAL_PARAM="--initial-cash $INITIAL_AMOUNT"
237 | elif [ "$COMMAND" = "backtest" ]; then
238 |     SCRIPT_PATH="src/backtester.py"
239 |     INITIAL_PARAM="--initial-capital $INITIAL_AMOUNT"
240 | fi
245 |
246 | # If using Ollama, make sure the service is started
247 | if [ -n "$USE_OLLAMA" ]; then
248 | echo "Setting up Ollama container for local LLM inference..."
249 |
250 | # Start Ollama container if not already running
251 | $COMPOSE_CMD $GPU_CONFIG up -d ollama
252 |
253 | # Wait for Ollama to start
254 | echo "Waiting for Ollama to start..."
255 | for i in {1..30}; do
256 | if docker run --rm --network=host curlimages/curl:latest curl -s http://localhost:11434/api/version &> /dev/null; then
257 | echo "Ollama is running."
258 | # Show available models
259 | echo "Available models:"
260 | docker exec -t ollama ollama list
261 | break
262 | fi
263 | echo -n "."
264 | sleep 1
265 | done
266 |
267 | # Build the AI Hedge Fund image if needed
268 | if [[ "$(docker images -q ai-hedge-fund 2> /dev/null)" == "" ]]; then
269 | echo "Building AI Hedge Fund image..."
270 | docker build -t ai-hedge-fund .
271 | fi
272 |
273 | # Create command override for Docker Compose
274 | COMMAND_OVERRIDE=""
275 |
276 | if [ -n "$START_DATE" ]; then
277 | COMMAND_OVERRIDE="$COMMAND_OVERRIDE $START_DATE"
278 | fi
279 |
280 | if [ -n "$END_DATE" ]; then
281 | COMMAND_OVERRIDE="$COMMAND_OVERRIDE $END_DATE"
282 | fi
283 |
284 | if [ -n "$INITIAL_PARAM" ]; then
285 | COMMAND_OVERRIDE="$COMMAND_OVERRIDE $INITIAL_PARAM"
286 | fi
287 |
288 | if [ -n "$MARGIN_REQUIREMENT" ]; then
289 | COMMAND_OVERRIDE="$COMMAND_OVERRIDE --margin-requirement $MARGIN_REQUIREMENT"
290 | fi
291 |
292 | # Run the command with Docker Compose
293 | echo "Running AI Hedge Fund with Ollama using Docker Compose..."
294 |
295 | # Use the appropriate service based on command and reasoning flag
296 | if [ "$COMMAND" = "main" ]; then
297 | if [ -n "$SHOW_REASONING" ]; then
298 | $COMPOSE_CMD $GPU_CONFIG run --rm hedge-fund-reasoning python src/main.py --ticker $TICKER $COMMAND_OVERRIDE $SHOW_REASONING --ollama
299 | else
300 | $COMPOSE_CMD $GPU_CONFIG run --rm hedge-fund-ollama python src/main.py --ticker $TICKER $COMMAND_OVERRIDE --ollama
301 | fi
302 | elif [ "$COMMAND" = "backtest" ]; then
303 | $COMPOSE_CMD $GPU_CONFIG run --rm backtester-ollama python src/backtester.py --ticker $TICKER $COMMAND_OVERRIDE $SHOW_REASONING --ollama
304 | fi
305 |
306 | exit 0
307 | fi
308 |
309 | # Standard Docker run (without Ollama)
310 | # Build the command
311 | CMD="docker run -it --rm -v $(pwd)/.env:/app/.env"
312 |
313 | # Add the command
314 | CMD="$CMD ai-hedge-fund python $SCRIPT_PATH --ticker $TICKER $START_DATE $END_DATE $INITIAL_PARAM --margin-requirement $MARGIN_REQUIREMENT $SHOW_REASONING"
315 |
316 | # Run the command
317 | echo "Running: $CMD"
318 | $CMD
--------------------------------------------------------------------------------
/src/agents/ben_graham.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 | from graph.state import AgentState, show_agent_reasoning
3 | from tools.api import get_financial_metrics, get_market_cap, search_line_items
4 | from langchain_core.prompts import ChatPromptTemplate
5 | from langchain_core.messages import HumanMessage
6 | from pydantic import BaseModel
7 | import json
8 | from typing_extensions import Literal
9 | from utils.progress import progress
10 | from utils.llm import call_llm
11 | import math
12 |
13 |
14 | class BenGrahamSignal(BaseModel):
15 | signal: Literal["bullish", "bearish", "neutral"]
16 | confidence: float
17 | reasoning: str
18 |
19 |
20 | def ben_graham_agent(state: AgentState):
21 | """
22 | Analyzes stocks using Benjamin Graham's classic value-investing principles:
23 | 1. Earnings stability over multiple years.
24 | 2. Solid financial strength (low debt, adequate liquidity).
25 | 3. Discount to intrinsic value (e.g. Graham Number or net-net).
26 | 4. Adequate margin of safety.
27 | """
28 | data = state["data"]
29 | end_date = data["end_date"]
30 | tickers = data["tickers"]
31 |
32 | analysis_data = {}
33 | graham_analysis = {}
34 |
35 | for ticker in tickers:
36 | progress.update_status("ben_graham_agent", ticker, "Fetching financial metrics")
37 | metrics = get_financial_metrics(ticker, end_date, period="annual", limit=10)
38 |
39 | progress.update_status("ben_graham_agent", ticker, "Gathering financial line items")
40 | financial_line_items = search_line_items(ticker, ["earnings_per_share", "revenue", "net_income", "book_value_per_share", "total_assets", "total_liabilities", "current_assets", "current_liabilities", "dividends_and_other_cash_distributions", "outstanding_shares"], end_date, period="annual", limit=10)
41 |
42 | progress.update_status("ben_graham_agent", ticker, "Getting market cap")
43 | market_cap = get_market_cap(ticker, end_date)
44 |
45 | # Perform sub-analyses
46 | progress.update_status("ben_graham_agent", ticker, "Analyzing earnings stability")
47 | earnings_analysis = analyze_earnings_stability(metrics, financial_line_items)
48 |
49 | progress.update_status("ben_graham_agent", ticker, "Analyzing financial strength")
50 | strength_analysis = analyze_financial_strength(financial_line_items)
51 |
52 | progress.update_status("ben_graham_agent", ticker, "Analyzing Graham valuation")
53 | valuation_analysis = analyze_valuation_graham(financial_line_items, market_cap)
54 |
55 | # Aggregate scoring
56 | total_score = earnings_analysis["score"] + strength_analysis["score"] + valuation_analysis["score"]
57 | max_possible_score = 15 # total possible from the three analysis functions
58 |
59 | # Map total_score to signal
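   |         # (with max_possible_score = 15, this means: >= 10.5 bullish, <= 4.5 bearish, otherwise neutral)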
60 | if total_score >= 0.7 * max_possible_score:
61 | signal = "bullish"
62 | elif total_score <= 0.3 * max_possible_score:
63 | signal = "bearish"
64 | else:
65 | signal = "neutral"
66 |
67 | analysis_data[ticker] = {"signal": signal, "score": total_score, "max_score": max_possible_score, "earnings_analysis": earnings_analysis, "strength_analysis": strength_analysis, "valuation_analysis": valuation_analysis}
68 |
69 | progress.update_status("ben_graham_agent", ticker, "Generating Ben Graham analysis")
70 | graham_output = generate_graham_output(
71 | ticker=ticker,
72 | analysis_data=analysis_data,
73 | model_name=state["metadata"]["model_name"],
74 | model_provider=state["metadata"]["model_provider"],
75 | )
76 |
77 | graham_analysis[ticker] = {"signal": graham_output.signal, "confidence": graham_output.confidence, "reasoning": graham_output.reasoning}
78 |
79 | progress.update_status("ben_graham_agent", ticker, "Done")
80 |
81 | # Wrap results in a single message for the chain
82 | message = HumanMessage(content=json.dumps(graham_analysis), name="ben_graham_agent")
83 |
84 | # Optionally display reasoning
85 | if state["metadata"]["show_reasoning"]:
86 | show_agent_reasoning(graham_analysis, "Ben Graham Agent")
87 |
88 | # Store signals in the overall state
89 | state["data"]["analyst_signals"]["ben_graham_agent"] = graham_analysis
90 |
91 | return {"messages": [message], "data": state["data"]}
92 |
93 |
94 | def analyze_earnings_stability(metrics: list, financial_line_items: list) -> dict:
95 | """
96 | Graham wants at least several years of consistently positive earnings (ideally 5+).
97 | We'll check:
98 | 1. Number of years with positive EPS.
99 | 2. Growth in EPS from first to last period.
100 | """
101 | score = 0
102 | details = []
103 |
104 | if not metrics or not financial_line_items:
105 | return {"score": score, "details": "Insufficient data for earnings stability analysis"}
106 |
107 | eps_vals = []
108 | for item in financial_line_items:
109 | if item.earnings_per_share is not None:
110 | eps_vals.append(item.earnings_per_share)
111 |
112 | if len(eps_vals) < 2:
113 | details.append("Not enough multi-year EPS data.")
114 | return {"score": score, "details": "; ".join(details)}
115 |
116 | # 1. Consistently positive EPS
117 | positive_eps_years = sum(1 for e in eps_vals if e > 0)
118 | total_eps_years = len(eps_vals)
119 | if positive_eps_years == total_eps_years:
120 | score += 3
121 | details.append("EPS was positive in all available periods.")
122 | elif positive_eps_years >= (total_eps_years * 0.8):
123 | score += 2
124 | details.append("EPS was positive in most periods.")
125 | else:
126 | details.append("EPS was negative in multiple periods.")
127 |
128 | # 2. EPS growth from earliest to latest
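   |     # Note: financial_line_items are ordered newest-first (index 0 = latest period),
   |     # so growth from the earliest to the latest period means eps_vals[0] > eps_vals[-1].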
129 | if eps_vals[0] > eps_vals[-1]:
130 | score += 1
131 | details.append("EPS grew from earliest to latest period.")
132 | else:
133 | details.append("EPS did not grow from earliest to latest period.")
134 |
135 | return {"score": score, "details": "; ".join(details)}
136 |
137 |
138 | def analyze_financial_strength(financial_line_items: list) -> dict:
139 | """
140 | Graham checks liquidity (current ratio >= 2), manageable debt,
141 | and dividend record (preferably some history of dividends).
142 | """
143 | score = 0
144 | details = []
145 |
146 | if not financial_line_items:
147 | return {"score": score, "details": "No data for financial strength analysis"}
148 |
149 | latest_item = financial_line_items[0]
150 | total_assets = latest_item.total_assets or 0
151 | total_liabilities = latest_item.total_liabilities or 0
152 | current_assets = latest_item.current_assets or 0
153 | current_liabilities = latest_item.current_liabilities or 0
154 |
155 | # 1. Current ratio
156 | if current_liabilities > 0:
157 | current_ratio = current_assets / current_liabilities
158 | if current_ratio >= 2.0:
159 | score += 2
160 | details.append(f"Current ratio = {current_ratio:.2f} (>=2.0: solid).")
161 | elif current_ratio >= 1.5:
162 | score += 1
163 | details.append(f"Current ratio = {current_ratio:.2f} (moderately strong).")
164 | else:
165 | details.append(f"Current ratio = {current_ratio:.2f} (<1.5: weaker liquidity).")
166 | else:
167 | details.append("Cannot compute current ratio (missing or zero current_liabilities).")
168 |
169 | # 2. Debt vs. Assets
170 | if total_assets > 0:
171 | debt_ratio = total_liabilities / total_assets
172 | if debt_ratio < 0.5:
173 | score += 2
174 | details.append(f"Debt ratio = {debt_ratio:.2f}, under 0.50 (conservative).")
175 | elif debt_ratio < 0.8:
176 | score += 1
177 | details.append(f"Debt ratio = {debt_ratio:.2f}, somewhat high but could be acceptable.")
178 | else:
179 | details.append(f"Debt ratio = {debt_ratio:.2f}, quite high by Graham standards.")
180 | else:
181 | details.append("Cannot compute debt ratio (missing total_assets).")
182 |
183 | # 3. Dividend track record
184 | div_periods = [item.dividends_and_other_cash_distributions for item in financial_line_items if item.dividends_and_other_cash_distributions is not None]
185 | if div_periods:
186 | # In many data feeds, dividend outflow is shown as a negative number
187 | # (money going out to shareholders). We'll consider any negative as 'paid a dividend'.
188 | div_paid_years = sum(1 for d in div_periods if d < 0)
189 | if div_paid_years > 0:
190 | # e.g. if at least half the periods had dividends
191 | if div_paid_years >= (len(div_periods) // 2 + 1):
192 | score += 1
193 | details.append("Company paid dividends in the majority of the reported years.")
194 | else:
195 | details.append("Company has some dividend payments, but not most years.")
196 | else:
197 | details.append("Company did not pay dividends in these periods.")
198 | else:
199 | details.append("No dividend data available to assess payout consistency.")
200 |
201 | return {"score": score, "details": "; ".join(details)}
202 |
203 |
204 | def analyze_valuation_graham(financial_line_items: list, market_cap: float) -> dict:
205 | """
206 | Core Graham approach to valuation:
207 | 1. Net-Net Check: (Current Assets - Total Liabilities) vs. Market Cap
208 | 2. Graham Number: sqrt(22.5 * EPS * Book Value per Share)
209 | 3. Compare per-share price to Graham Number => margin of safety
210 | """
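   |     # Worked example with hypothetical numbers: EPS = 3.00 and BVPS = 20.00 give
   |     # Graham Number = sqrt(22.5 * 3.00 * 20.00) = sqrt(1350) ≈ 36.74, so a market
   |     # price below roughly $36.74 per share would pass this test with a margin of safety.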
211 | if not financial_line_items or not market_cap or market_cap <= 0:
212 | return {"score": 0, "details": "Insufficient data to perform valuation"}
213 |
214 | latest = financial_line_items[0]
215 | current_assets = latest.current_assets or 0
216 | total_liabilities = latest.total_liabilities or 0
217 | book_value_ps = latest.book_value_per_share or 0
218 | eps = latest.earnings_per_share or 0
219 | shares_outstanding = latest.outstanding_shares or 0
220 |
221 | details = []
222 | score = 0
223 |
224 | # 1. Net-Net Check
225 | # NCAV = Current Assets - Total Liabilities
226 | # If NCAV > Market Cap => historically a strong buy signal
227 | net_current_asset_value = current_assets - total_liabilities
228 | if net_current_asset_value > 0 and shares_outstanding > 0:
229 | net_current_asset_value_per_share = net_current_asset_value / shares_outstanding
230 | price_per_share = market_cap / shares_outstanding if shares_outstanding else 0
231 |
232 | details.append(f"Net Current Asset Value = {net_current_asset_value:,.2f}")
233 | details.append(f"NCAV Per Share = {net_current_asset_value_per_share:,.2f}")
234 | details.append(f"Price Per Share = {price_per_share:,.2f}")
235 |
236 | if net_current_asset_value > market_cap:
237 | score += 4 # Very strong Graham signal
238 | details.append("Net-Net: NCAV > Market Cap (classic Graham deep value).")
239 | else:
240 | # For partial net-net discount
241 | if net_current_asset_value_per_share >= (price_per_share * 0.67):
242 | score += 2
243 | details.append("NCAV Per Share >= 2/3 of Price Per Share (moderate net-net discount).")
244 | else:
245 | details.append("NCAV not exceeding market cap or insufficient data for net-net approach.")
246 |
247 | # 2. Graham Number
248 | # GrahamNumber = sqrt(22.5 * EPS * BVPS).
249 | # Compare the result to the current price_per_share
250 | # If GrahamNumber >> price, indicates undervaluation
251 | graham_number = None
252 | if eps > 0 and book_value_ps > 0:
253 | graham_number = math.sqrt(22.5 * eps * book_value_ps)
254 | details.append(f"Graham Number = {graham_number:.2f}")
255 | else:
256 | details.append("Unable to compute Graham Number (EPS or Book Value missing/<=0).")
257 |
258 | # 3. Margin of Safety relative to Graham Number
259 | if graham_number and shares_outstanding > 0:
260 | current_price = market_cap / shares_outstanding
261 | if current_price > 0:
262 | margin_of_safety = (graham_number - current_price) / current_price
263 | details.append(f"Margin of Safety (Graham Number) = {margin_of_safety:.2%}")
264 | if margin_of_safety > 0.5:
265 | score += 3
266 | details.append("Price is well below Graham Number (>=50% margin).")
267 | elif margin_of_safety > 0.2:
268 | score += 1
269 | details.append("Some margin of safety relative to Graham Number.")
270 | else:
271 | details.append("Price close to or above Graham Number, low margin of safety.")
272 | else:
273 | details.append("Current price is zero or invalid; can't compute margin of safety.")
274 | # else: already appended details for missing graham_number
275 |
276 | return {"score": score, "details": "; ".join(details)}
277 |
278 |
279 | def generate_graham_output(
280 | ticker: str,
281 | analysis_data: dict[str, any],
282 | model_name: str,
283 | model_provider: str,
284 | ) -> BenGrahamSignal:
285 | """
286 | Generates an investment decision in the style of Benjamin Graham:
287 | - Value emphasis, margin of safety, net-nets, conservative balance sheet, stable earnings.
288 | - Return the result in a JSON structure: { signal, confidence, reasoning }.
289 | """
290 |
291 | template = ChatPromptTemplate.from_messages([
292 | (
293 | "system",
294 | """You are a Benjamin Graham AI agent, making investment decisions using his principles:
295 | 1. Insist on a margin of safety by buying below intrinsic value (e.g., using Graham Number, net-net).
296 | 2. Emphasize the company's financial strength (low leverage, ample current assets).
297 | 3. Prefer stable earnings over multiple years.
298 | 4. Consider dividend record for extra safety.
299 | 5. Avoid speculative or high-growth assumptions; focus on proven metrics.
300 |
301 | When providing your reasoning, be thorough and specific by:
302 | 1. Explaining the key valuation metrics that influenced your decision the most (Graham Number, NCAV, P/E, etc.)
303 | 2. Highlighting the specific financial strength indicators (current ratio, debt levels, etc.)
304 | 3. Referencing the stability or instability of earnings over time
305 | 4. Providing quantitative evidence with precise numbers
306 | 5. Comparing current metrics to Graham's specific thresholds (e.g., "Current ratio of 2.5 exceeds Graham's minimum of 2.0")
307 | 6. Using Benjamin Graham's conservative, analytical voice and style in your explanation
308 |
309 | For example, if bullish: "The stock trades at a 35% discount to net current asset value, providing an ample margin of safety. The current ratio of 2.5 and debt-to-equity of 0.3 indicate strong financial position..."
310 | For example, if bearish: "Despite consistent earnings, the current price of $50 exceeds our calculated Graham Number of $35, offering no margin of safety. Additionally, the current ratio of only 1.2 falls below Graham's preferred 2.0 threshold..."
311 |
312 | Return a rational recommendation: bullish, bearish, or neutral, with a confidence level (0-100) and thorough reasoning.
313 | """
314 | ),
315 | (
316 | "human",
317 | """Based on the following analysis, create a Graham-style investment signal:
318 |
319 | Analysis Data for {ticker}:
320 | {analysis_data}
321 |
322 | Return JSON exactly in this format:
323 | {{
324 | "signal": "bullish" or "bearish" or "neutral",
325 | "confidence": float (0-100),
326 | "reasoning": "string"
327 | }}
328 | """
329 | )
330 | ])
331 |
332 | prompt = template.invoke({
333 | "analysis_data": json.dumps(analysis_data, indent=2),
334 | "ticker": ticker
335 | })
336 |
337 | def create_default_ben_graham_signal():
338 | return BenGrahamSignal(signal="neutral", confidence=0.0, reasoning="Error in generating analysis; defaulting to neutral.")
339 |
340 | return call_llm(
341 | prompt=prompt,
342 | model_name=model_name,
343 | model_provider=model_provider,
344 | pydantic_model=BenGrahamSignal,
345 | agent_name="ben_graham_agent",
346 | default_factory=create_default_ben_graham_signal,
347 | )
348 |
--------------------------------------------------------------------------------
/src/agents/fundamentals.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import HumanMessage
2 | from graph.state import AgentState, show_agent_reasoning
3 | from utils.progress import progress
4 | import json
5 |
6 | from tools.api import get_financial_metrics
7 |
8 |
9 | ##### Fundamental Agent #####
10 | def fundamentals_agent(state: AgentState):
11 | """Analyzes fundamental data and generates trading signals for multiple tickers."""
12 | data = state["data"]
13 | end_date = data["end_date"]
14 | tickers = data["tickers"]
15 |
16 | # Initialize fundamental analysis for each ticker
17 | fundamental_analysis = {}
18 |
19 | for ticker in tickers:
20 | progress.update_status("fundamentals_agent", ticker, "Fetching financial metrics")
21 |
22 | # Get the financial metrics
23 | financial_metrics = get_financial_metrics(
24 | ticker=ticker,
25 | end_date=end_date,
26 | period="ttm",
27 | limit=10,
28 | )
29 |
30 | if not financial_metrics:
31 | progress.update_status("fundamentals_agent", ticker, "Failed: No financial metrics found")
32 | continue
33 |
34 | # Pull the most recent financial metrics
35 | metrics = financial_metrics[0]
36 |
37 | # Initialize signals list for different fundamental aspects
38 | signals = []
39 | reasoning = {}
40 |
41 | progress.update_status("fundamentals_agent", ticker, "Analyzing profitability")
42 | # 1. Profitability Analysis
43 | return_on_equity = metrics.return_on_equity
44 | net_margin = metrics.net_margin
45 | operating_margin = metrics.operating_margin
46 |
47 | thresholds = [
48 | (return_on_equity, 0.15), # Strong ROE above 15%
49 | (net_margin, 0.20), # Healthy profit margins
50 | (operating_margin, 0.15), # Strong operating efficiency
51 | ]
52 | profitability_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)
53 |
54 | signals.append("bullish" if profitability_score >= 2 else "bearish" if profitability_score == 0 else "neutral")
55 | reasoning["profitability_signal"] = {
56 | "signal": signals[0],
57 |             "details": (f"ROE: {return_on_equity:.2%}" if return_on_equity is not None else "ROE: N/A") + ", " + (f"Net Margin: {net_margin:.2%}" if net_margin is not None else "Net Margin: N/A") + ", " + (f"Op Margin: {operating_margin:.2%}" if operating_margin is not None else "Op Margin: N/A"),
58 | }
59 |
60 | progress.update_status("fundamentals_agent", ticker, "Analyzing growth")
61 | # 2. Growth Analysis
62 | revenue_growth = metrics.revenue_growth
63 | earnings_growth = metrics.earnings_growth
64 | book_value_growth = metrics.book_value_growth
65 |
66 | thresholds = [
67 | (revenue_growth, 0.10), # 10% revenue growth
68 | (earnings_growth, 0.10), # 10% earnings growth
69 | (book_value_growth, 0.10), # 10% book value growth
70 | ]
71 | growth_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)
72 |
73 | signals.append("bullish" if growth_score >= 2 else "bearish" if growth_score == 0 else "neutral")
74 | reasoning["growth_signal"] = {
75 | "signal": signals[1],
76 |             "details": (f"Revenue Growth: {revenue_growth:.2%}" if revenue_growth is not None else "Revenue Growth: N/A") + ", " + (f"Earnings Growth: {earnings_growth:.2%}" if earnings_growth is not None else "Earnings Growth: N/A") + ", " + (f"Book Value Growth: {book_value_growth:.2%}" if book_value_growth is not None else "Book Value Growth: N/A"),
77 | }
78 |
79 | progress.update_status("fundamentals_agent", ticker, "Analyzing financial health")
80 | # 3. Financial Health
81 | current_ratio = metrics.current_ratio
82 | debt_to_equity = metrics.debt_to_equity
83 | free_cash_flow_per_share = metrics.free_cash_flow_per_share
84 | earnings_per_share = metrics.earnings_per_share
85 |
86 | health_score = 0
87 | if current_ratio and current_ratio > 1.5: # Strong liquidity
88 | health_score += 1
89 | if debt_to_equity and debt_to_equity < 0.5: # Conservative debt levels
90 | health_score += 1
91 | if free_cash_flow_per_share and earnings_per_share and free_cash_flow_per_share > earnings_per_share * 0.8: # Strong FCF conversion
92 | health_score += 1
93 |
94 | signals.append("bullish" if health_score >= 2 else "bearish" if health_score == 0 else "neutral")
95 | reasoning["financial_health_signal"] = {
96 | "signal": signals[2],
97 |             "details": (f"Current Ratio: {current_ratio:.2f}" if current_ratio is not None else "Current Ratio: N/A") + ", " + (f"D/E: {debt_to_equity:.2f}" if debt_to_equity is not None else "D/E: N/A"),
98 | }
99 |
100 | progress.update_status("fundamentals_agent", ticker, "Analyzing valuation ratios")
101 | # 4. Price to X ratios
102 | pe_ratio = metrics.price_to_earnings_ratio
103 | pb_ratio = metrics.price_to_book_ratio
104 | ps_ratio = metrics.price_to_sales_ratio
105 |
106 | thresholds = [
107 | (pe_ratio, 25), # Reasonable P/E ratio
108 | (pb_ratio, 3), # Reasonable P/B ratio
109 | (ps_ratio, 5), # Reasonable P/S ratio
110 | ]
111 | price_ratio_score = sum(metric is not None and metric > threshold for metric, threshold in thresholds)
112 |
113 | signals.append("bearish" if price_ratio_score >= 2 else "bullish" if price_ratio_score == 0 else "neutral")
114 | reasoning["price_ratios_signal"] = {
115 | "signal": signals[3],
116 |             "details": (f"P/E: {pe_ratio:.2f}" if pe_ratio is not None else "P/E: N/A") + ", " + (f"P/B: {pb_ratio:.2f}" if pb_ratio is not None else "P/B: N/A") + ", " + (f"P/S: {ps_ratio:.2f}" if ps_ratio is not None else "P/S: N/A"),
117 | }
118 |
119 | progress.update_status("fundamentals_agent", ticker, "Calculating final signal")
120 | # Determine overall signal
121 | bullish_signals = signals.count("bullish")
122 | bearish_signals = signals.count("bearish")
123 |
124 | if bullish_signals > bearish_signals:
125 | overall_signal = "bullish"
126 | elif bearish_signals > bullish_signals:
127 | overall_signal = "bearish"
128 | else:
129 | overall_signal = "neutral"
130 |
131 | # Calculate confidence level
132 | total_signals = len(signals)
133 |         confidence = round(max(bullish_signals, bearish_signals) / total_signals * 100, 2)
134 |
135 | fundamental_analysis[ticker] = {
136 | "signal": overall_signal,
137 | "confidence": confidence,
138 | "reasoning": reasoning,
139 | }
140 |
141 | progress.update_status("fundamentals_agent", ticker, "Done")
142 |
143 | # Create the fundamental analysis message
144 | message = HumanMessage(
145 | content=json.dumps(fundamental_analysis),
146 | name="fundamentals_agent",
147 | )
148 |
149 | # Print the reasoning if the flag is set
150 | if state["metadata"]["show_reasoning"]:
151 | show_agent_reasoning(fundamental_analysis, "Fundamental Analysis Agent")
152 |
153 | # Add the signal to the analyst_signals list
154 | state["data"]["analyst_signals"]["fundamentals_agent"] = fundamental_analysis
155 |
156 | return {
157 | "messages": [message],
158 | "data": data,
159 | }
160 |
--------------------------------------------------------------------------------
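
The scoring idiom in fundamentals.py is worth isolating: each dimension counts how many metrics are both present and above a hard threshold, and the count maps to a bullish/neutral/bearish label. A stripped-down sketch of that pattern with invented values and thresholds:

from __future__ import annotations


def threshold_score(pairs: list[tuple[float | None, float]]) -> int:
    """Count metrics that are present and exceed their threshold; None counts as a miss."""
    return sum(metric is not None and metric > threshold for metric, threshold in pairs)


def score_to_signal(score: int) -> str:
    """Two or more passing metrics -> bullish, zero -> bearish, otherwise neutral."""
    return "bullish" if score >= 2 else "bearish" if score == 0 else "neutral"


# (metric, threshold): 18% ROE vs a 15% bar, missing net margin, 22% op margin vs a 15% bar
profitability = [(0.18, 0.15), (None, 0.20), (0.22, 0.15)]
score = threshold_score(profitability)
print(score, score_to_signal(score))  # 2 bullish
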
/src/agents/michael_burry.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from datetime import datetime, timedelta
4 | import json
5 | from typing_extensions import Literal
6 |
7 | from graph.state import AgentState, show_agent_reasoning
8 | from langchain_core.messages import HumanMessage
9 | from langchain_core.prompts import ChatPromptTemplate
10 | from pydantic import BaseModel
11 |
12 | from tools.api import (
13 | get_company_news,
14 | get_financial_metrics,
15 | get_insider_trades,
16 | get_market_cap,
17 | search_line_items,
18 | )
19 | from utils.llm import call_llm
20 | from utils.progress import progress
21 |
22 | __all__ = [
23 | "MichaelBurrySignal",
24 | "michael_burry_agent",
25 | ]
26 |
27 | ###############################################################################
28 | # Pydantic output model
29 | ###############################################################################
30 |
31 |
32 | class MichaelBurrySignal(BaseModel):
33 | """Schema returned by the LLM."""
34 |
35 | signal: Literal["bullish", "bearish", "neutral"]
36 | confidence: float # 0–100
37 | reasoning: str
38 |
39 |
40 | ###############################################################################
41 | # Core agent
42 | ###############################################################################
43 |
44 |
45 | def michael_burry_agent(state: AgentState): # noqa: C901 (complexity is fine here)
46 | """Analyse stocks using Michael Burry's deep‑value, contrarian framework."""
47 |
48 | data = state["data"]
49 | end_date: str = data["end_date"] # YYYY‑MM‑DD
50 | tickers: list[str] = data["tickers"]
51 |
52 | # We look one year back for insider trades / news flow
53 | start_date = (datetime.fromisoformat(end_date) - timedelta(days=365)).date().isoformat()
54 |
55 | analysis_data: dict[str, dict] = {}
56 | burry_analysis: dict[str, dict] = {}
57 |
58 | for ticker in tickers:
59 | # ------------------------------------------------------------------
60 | # Fetch raw data
61 | # ------------------------------------------------------------------
62 | progress.update_status("michael_burry_agent", ticker, "Fetching financial metrics")
63 | metrics = get_financial_metrics(ticker, end_date, period="ttm", limit=5)
64 |
65 | progress.update_status("michael_burry_agent", ticker, "Fetching line items")
66 | line_items = search_line_items(
67 | ticker,
68 | [
69 | "free_cash_flow",
70 | "net_income",
71 | "total_debt",
72 | "cash_and_equivalents",
73 | "total_assets",
74 | "total_liabilities",
75 | "outstanding_shares",
76 | "issuance_or_purchase_of_equity_shares",
77 | ],
78 | end_date,
79 | )
80 |
81 | progress.update_status("michael_burry_agent", ticker, "Fetching insider trades")
82 | insider_trades = get_insider_trades(ticker, end_date=end_date, start_date=start_date)
83 |
84 | progress.update_status("michael_burry_agent", ticker, "Fetching company news")
85 | news = get_company_news(ticker, end_date=end_date, start_date=start_date, limit=250)
86 |
87 | progress.update_status("michael_burry_agent", ticker, "Fetching market cap")
88 | market_cap = get_market_cap(ticker, end_date)
89 |
90 | # ------------------------------------------------------------------
91 | # Run sub‑analyses
92 | # ------------------------------------------------------------------
93 | progress.update_status("michael_burry_agent", ticker, "Analyzing value")
94 | value_analysis = _analyze_value(metrics, line_items, market_cap)
95 |
96 | progress.update_status("michael_burry_agent", ticker, "Analyzing balance sheet")
97 | balance_sheet_analysis = _analyze_balance_sheet(metrics, line_items)
98 |
99 | progress.update_status("michael_burry_agent", ticker, "Analyzing insider activity")
100 | insider_analysis = _analyze_insider_activity(insider_trades)
101 |
102 | progress.update_status("michael_burry_agent", ticker, "Analyzing contrarian sentiment")
103 | contrarian_analysis = _analyze_contrarian_sentiment(news)
104 |
105 | # ------------------------------------------------------------------
106 | # Aggregate score & derive preliminary signal
107 | # ------------------------------------------------------------------
108 | total_score = (
109 | value_analysis["score"]
110 | + balance_sheet_analysis["score"]
111 | + insider_analysis["score"]
112 | + contrarian_analysis["score"]
113 | )
114 | max_score = (
115 | value_analysis["max_score"]
116 | + balance_sheet_analysis["max_score"]
117 | + insider_analysis["max_score"]
118 | + contrarian_analysis["max_score"]
119 | )
120 |
121 | if total_score >= 0.7 * max_score:
122 | signal = "bullish"
123 | elif total_score <= 0.3 * max_score:
124 | signal = "bearish"
125 | else:
126 | signal = "neutral"
127 |
128 | # ------------------------------------------------------------------
129 | # Collect data for LLM reasoning & output
130 | # ------------------------------------------------------------------
131 | analysis_data[ticker] = {
132 | "signal": signal,
133 | "score": total_score,
134 | "max_score": max_score,
135 | "value_analysis": value_analysis,
136 | "balance_sheet_analysis": balance_sheet_analysis,
137 | "insider_analysis": insider_analysis,
138 | "contrarian_analysis": contrarian_analysis,
139 | "market_cap": market_cap,
140 | }
141 |
142 | progress.update_status("michael_burry_agent", ticker, "Generating LLM output")
143 | burry_output = _generate_burry_output(
144 | ticker=ticker,
145 | analysis_data=analysis_data,
146 | model_name=state["metadata"]["model_name"],
147 | model_provider=state["metadata"]["model_provider"],
148 | )
149 |
150 | burry_analysis[ticker] = {
151 | "signal": burry_output.signal,
152 | "confidence": burry_output.confidence,
153 | "reasoning": burry_output.reasoning,
154 | }
155 |
156 | progress.update_status("michael_burry_agent", ticker, "Done")
157 |
158 | # ----------------------------------------------------------------------
159 | # Return to the graph
160 | # ----------------------------------------------------------------------
161 | message = HumanMessage(content=json.dumps(burry_analysis), name="michael_burry_agent")
162 |
163 | if state["metadata"].get("show_reasoning"):
164 | show_agent_reasoning(burry_analysis, "Michael Burry Agent")
165 |
166 | state["data"]["analyst_signals"]["michael_burry_agent"] = burry_analysis
167 |
168 | return {"messages": [message], "data": state["data"]}
169 |
170 |
171 | ###############################################################################
172 | # Sub‑analysis helpers
173 | ###############################################################################
174 |
175 |
176 | def _latest_line_item(line_items: list):
177 | """Return the most recent line‑item object or *None*."""
178 | return line_items[0] if line_items else None
179 |
180 |
181 | # ----- Value ----------------------------------------------------------------
182 |
183 | def _analyze_value(metrics, line_items, market_cap):
184 | """Free cash‑flow yield, EV/EBIT, other classic deep‑value metrics."""
185 |
186 | max_score = 6 # 4 pts for FCF‑yield, 2 pts for EV/EBIT
187 | score = 0
188 | details: list[str] = []
189 |
190 | # Free‑cash‑flow yield
191 | latest_item = _latest_line_item(line_items)
192 | fcf = getattr(latest_item, "free_cash_flow", None) if latest_item else None
193 | if fcf is not None and market_cap:
194 | fcf_yield = fcf / market_cap
195 | if fcf_yield >= 0.15:
196 | score += 4
197 | details.append(f"Extraordinary FCF yield {fcf_yield:.1%}")
198 | elif fcf_yield >= 0.12:
199 | score += 3
200 | details.append(f"Very high FCF yield {fcf_yield:.1%}")
201 | elif fcf_yield >= 0.08:
202 | score += 2
203 | details.append(f"Respectable FCF yield {fcf_yield:.1%}")
204 | else:
205 | details.append(f"Low FCF yield {fcf_yield:.1%}")
206 | else:
207 | details.append("FCF data unavailable")
208 |
209 | # EV/EBIT (from financial metrics)
210 | if metrics:
211 | ev_ebit = getattr(metrics[0], "ev_to_ebit", None)
212 | if ev_ebit is not None:
213 | if ev_ebit < 6:
214 | score += 2
215 | details.append(f"EV/EBIT {ev_ebit:.1f} (<6)")
216 | elif ev_ebit < 10:
217 | score += 1
218 | details.append(f"EV/EBIT {ev_ebit:.1f} (<10)")
219 | else:
220 | details.append(f"High EV/EBIT {ev_ebit:.1f}")
221 | else:
222 | details.append("EV/EBIT data unavailable")
223 | else:
224 | details.append("Financial metrics unavailable")
225 |
226 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
227 |
228 |
229 | # ----- Balance sheet --------------------------------------------------------
230 |
231 | def _analyze_balance_sheet(metrics, line_items):
232 | """Leverage and liquidity checks."""
233 |
234 | max_score = 3
235 | score = 0
236 | details: list[str] = []
237 |
238 | latest_metrics = metrics[0] if metrics else None
239 | latest_item = _latest_line_item(line_items)
240 |
241 | debt_to_equity = getattr(latest_metrics, "debt_to_equity", None) if latest_metrics else None
242 | if debt_to_equity is not None:
243 | if debt_to_equity < 0.5:
244 | score += 2
245 | details.append(f"Low D/E {debt_to_equity:.2f}")
246 | elif debt_to_equity < 1:
247 | score += 1
248 | details.append(f"Moderate D/E {debt_to_equity:.2f}")
249 | else:
250 | details.append(f"High leverage D/E {debt_to_equity:.2f}")
251 | else:
252 | details.append("Debt‑to‑equity data unavailable")
253 |
254 | # Quick liquidity sanity check (cash vs total debt)
255 | if latest_item is not None:
256 | cash = getattr(latest_item, "cash_and_equivalents", None)
257 | total_debt = getattr(latest_item, "total_debt", None)
258 | if cash is not None and total_debt is not None:
259 | if cash > total_debt:
260 | score += 1
261 | details.append("Net cash position")
262 | else:
263 | details.append("Net debt position")
264 | else:
265 | details.append("Cash/debt data unavailable")
266 |
267 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
268 |
269 |
270 | # ----- Insider activity -----------------------------------------------------
271 |
272 | def _analyze_insider_activity(insider_trades):
273 | """Net insider buying over the last 12 months acts as a hard catalyst."""
274 |
275 | max_score = 2
276 | score = 0
277 | details: list[str] = []
278 |
279 | if not insider_trades:
280 | details.append("No insider trade data")
281 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
282 |
283 | shares_bought = sum(t.transaction_shares or 0 for t in insider_trades if (t.transaction_shares or 0) > 0)
284 | shares_sold = abs(sum(t.transaction_shares or 0 for t in insider_trades if (t.transaction_shares or 0) < 0))
285 | net = shares_bought - shares_sold
286 | if net > 0:
287 | score += 2 if net / max(shares_sold, 1) > 1 else 1
288 | details.append(f"Net insider buying of {net:,} shares")
289 | else:
290 | details.append("Net insider selling")
291 |
292 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
293 |
294 |
295 | # ----- Contrarian sentiment -------------------------------------------------
296 |
297 | def _analyze_contrarian_sentiment(news):
298 | """Very rough gauge: a wall of recent negative headlines can be a *positive* for a contrarian."""
299 |
300 | max_score = 1
301 | score = 0
302 | details: list[str] = []
303 |
304 | if not news:
305 | details.append("No recent news")
306 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
307 |
308 | # Count negative sentiment articles
309 | sentiment_negative_count = sum(
310 | 1 for n in news if n.sentiment and n.sentiment.lower() in ["negative", "bearish"]
311 | )
312 |
313 | if sentiment_negative_count >= 5:
314 | score += 1 # The more hated, the better (assuming fundamentals hold up)
315 | details.append(f"{sentiment_negative_count} negative headlines (contrarian opportunity)")
316 | else:
317 | details.append("Limited negative press")
318 |
319 | return {"score": score, "max_score": max_score, "details": "; ".join(details)}
320 |
321 |
322 | ###############################################################################
323 | # LLM generation
324 | ###############################################################################
325 |
326 | def _generate_burry_output(
327 | ticker: str,
328 | analysis_data: dict,
329 | *,
330 | model_name: str,
331 | model_provider: str,
332 | ) -> MichaelBurrySignal:
333 | """Call the LLM to craft the final trading signal in Burry's voice."""
334 |
335 | template = ChatPromptTemplate.from_messages(
336 | [
337 | (
338 | "system",
339 | """You are an AI agent emulating Dr. Michael J. Burry. Your mandate:
340 | - Hunt for deep value in US equities using hard numbers (free cash flow, EV/EBIT, balance sheet)
341 | - Be contrarian: hatred in the press can be your friend if fundamentals are solid
342 | - Focus on downside first – avoid leveraged balance sheets
343 | - Look for hard catalysts such as insider buying, buybacks, or asset sales
344 | - Communicate in Burry's terse, data‑driven style
345 |
346 | When providing your reasoning, be thorough and specific by:
347 | 1. Starting with the key metric(s) that drove your decision
348 | 2. Citing concrete numbers (e.g. "FCF yield 14.7%", "EV/EBIT 5.3")
349 | 3. Highlighting risk factors and why they are acceptable (or not)
350 | 4. Mentioning relevant insider activity or contrarian opportunities
351 | 5. Using Burry's direct, number-focused communication style with minimal words
352 |
353 | For example, if bullish: "FCF yield 12.8%. EV/EBIT 6.2. Debt-to-equity 0.4. Net insider buying 25k shares. Market missing value due to overreaction to recent litigation. Strong buy."
354 | For example, if bearish: "FCF yield only 2.1%. Debt-to-equity concerning at 2.3. Management diluting shareholders. Pass."
355 | """,
356 | ),
357 | (
358 | "human",
359 | """Based on the following data, create the investment signal as Michael Burry would:
360 |
361 | Analysis Data for {ticker}:
362 | {analysis_data}
363 |
364 | Return the trading signal in the following JSON format exactly:
365 | {{
366 | "signal": "bullish" | "bearish" | "neutral",
367 | "confidence": float between 0 and 100,
368 | "reasoning": "string"
369 | }}
370 | """,
371 | ),
372 | ]
373 | )
374 |
375 | prompt = template.invoke({"analysis_data": json.dumps(analysis_data, indent=2), "ticker": ticker})
376 |
377 | # Default fallback signal in case parsing fails
378 | def create_default_michael_burry_signal():
379 | return MichaelBurrySignal(signal="neutral", confidence=0.0, reasoning="Parsing error – defaulting to neutral")
380 |
381 | return call_llm(
382 | prompt=prompt,
383 | model_name=model_name,
384 | model_provider=model_provider,
385 | pydantic_model=MichaelBurrySignal,
386 | agent_name="michael_burry_agent",
387 | default_factory=create_default_michael_burry_signal,
388 | )
389 |
--------------------------------------------------------------------------------
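
_analyze_value awards points in free-cash-flow-yield bands (FCF divided by market cap). A tiny standalone rendering of the same banding, with hypothetical figures:

from __future__ import annotations


def fcf_yield_points(fcf: float | None, market_cap: float | None) -> tuple[int, str]:
    """Band FCF yield: >=15% -> 4 pts, >=12% -> 3, >=8% -> 2, else 0."""
    if fcf is None or not market_cap:
        return 0, "FCF data unavailable"
    fcf_yield = fcf / market_cap
    if fcf_yield >= 0.15:
        return 4, f"Extraordinary FCF yield {fcf_yield:.1%}"
    if fcf_yield >= 0.12:
        return 3, f"Very high FCF yield {fcf_yield:.1%}"
    if fcf_yield >= 0.08:
        return 2, f"Respectable FCF yield {fcf_yield:.1%}"
    return 0, f"Low FCF yield {fcf_yield:.1%}"


print(fcf_yield_points(1.3e9, 1.0e10))  # (3, 'Very high FCF yield 13.0%')
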
/src/agents/portfolio_manager.py:
--------------------------------------------------------------------------------
1 | import json
2 | from langchain_core.messages import HumanMessage
3 | from langchain_core.prompts import ChatPromptTemplate
4 |
5 | from graph.state import AgentState, show_agent_reasoning
6 | from pydantic import BaseModel, Field
7 | from typing_extensions import Literal
8 | from utils.progress import progress
9 | from utils.llm import call_llm
10 |
11 |
12 | class PortfolioDecision(BaseModel):
13 | action: Literal["buy", "sell", "short", "cover", "hold"]
14 | quantity: int = Field(description="Number of shares to trade")
15 | confidence: float = Field(description="Confidence in the decision, between 0.0 and 100.0")
16 | reasoning: str = Field(description="Reasoning for the decision")
17 |
18 |
19 | class PortfolioManagerOutput(BaseModel):
20 | decisions: dict[str, PortfolioDecision] = Field(description="Dictionary of ticker to trading decisions")
21 |
22 |
23 | ##### Portfolio Management Agent #####
24 | def portfolio_management_agent(state: AgentState):
25 | """Makes final trading decisions and generates orders for multiple tickers"""
26 |
27 | # Get the portfolio and analyst signals
28 | portfolio = state["data"]["portfolio"]
29 | analyst_signals = state["data"]["analyst_signals"]
30 | tickers = state["data"]["tickers"]
31 |
32 | progress.update_status("portfolio_management_agent", None, "Analyzing signals")
33 |
34 | # Get position limits, current prices, and signals for every ticker
35 | position_limits = {}
36 | current_prices = {}
37 | max_shares = {}
38 | signals_by_ticker = {}
39 | for ticker in tickers:
40 | progress.update_status("portfolio_management_agent", ticker, "Processing analyst signals")
41 |
42 | # Get position limits and current prices for the ticker
43 | risk_data = analyst_signals.get("risk_management_agent", {}).get(ticker, {})
44 | position_limits[ticker] = risk_data.get("remaining_position_limit", 0)
45 | current_prices[ticker] = risk_data.get("current_price", 0)
46 |
47 | # Calculate maximum shares allowed based on position limit and price
48 | if current_prices[ticker] > 0:
49 | max_shares[ticker] = int(position_limits[ticker] / current_prices[ticker])
50 | else:
51 | max_shares[ticker] = 0
52 |
53 | # Get signals for the ticker
54 | ticker_signals = {}
55 | for agent, signals in analyst_signals.items():
56 | if agent != "risk_management_agent" and ticker in signals:
57 | ticker_signals[agent] = {"signal": signals[ticker]["signal"], "confidence": signals[ticker]["confidence"]}
58 | signals_by_ticker[ticker] = ticker_signals
59 |
60 | progress.update_status("portfolio_management_agent", None, "Making trading decisions")
61 |
62 | # Generate the trading decision
63 | result = generate_trading_decision(
64 | tickers=tickers,
65 | signals_by_ticker=signals_by_ticker,
66 | current_prices=current_prices,
67 | max_shares=max_shares,
68 | portfolio=portfolio,
69 | model_name=state["metadata"]["model_name"],
70 | model_provider=state["metadata"]["model_provider"],
71 | )
72 |
73 | # Create the portfolio management message
74 | message = HumanMessage(
75 | content=json.dumps({ticker: decision.model_dump() for ticker, decision in result.decisions.items()}),
76 | name="portfolio_management",
77 | )
78 |
79 | # Print the decision if the flag is set
80 | if state["metadata"]["show_reasoning"]:
81 | show_agent_reasoning({ticker: decision.model_dump() for ticker, decision in result.decisions.items()}, "Portfolio Management Agent")
82 |
83 | progress.update_status("portfolio_management_agent", None, "Done")
84 |
85 | return {
86 | "messages": state["messages"] + [message],
87 | "data": state["data"],
88 | }
89 |
90 |
91 | def generate_trading_decision(
92 | tickers: list[str],
93 | signals_by_ticker: dict[str, dict],
94 | current_prices: dict[str, float],
95 | max_shares: dict[str, int],
96 |     portfolio: dict,
97 | model_name: str,
98 | model_provider: str,
99 | ) -> PortfolioManagerOutput:
100 | """Attempts to get a decision from the LLM with retry logic"""
101 | # Create the prompt template
102 | template = ChatPromptTemplate.from_messages(
103 | [
104 | (
105 | "system",
106 |                 """You are a portfolio manager making final trading decisions for multiple tickers based on the team's signals.
107 |
108 | Trading Rules:
109 | - For long positions:
110 | * Only buy if you have available cash
111 | * Only sell if you currently hold long shares of that ticker
112 | * Sell quantity must be ≤ current long position shares
113 | * Buy quantity must be ≤ max_shares for that ticker
114 |
115 | - For short positions:
116 | * Only short if you have available margin (position value × margin requirement)
117 | * Only cover if you currently have short shares of that ticker
118 | * Cover quantity must be ≤ current short position shares
119 | * Short quantity must respect margin requirements
120 |
121 | - The max_shares values are pre-calculated to respect position limits
122 | - Consider both long and short opportunities based on signals
123 | - Maintain appropriate risk management with both long and short exposure
124 |
125 | Available Actions:
126 | - "buy": Open or add to long position
127 | - "sell": Close or reduce long position
128 | - "short": Open or add to short position
129 | - "cover": Close or reduce short position
130 | - "hold": No action
131 |
132 | Inputs:
133 | - signals_by_ticker: dictionary of ticker → signals
134 | - max_shares: maximum shares allowed per ticker
135 | - portfolio_cash: current cash in portfolio
136 | - portfolio_positions: current positions (both long and short)
137 | - current_prices: current prices for each ticker
138 | - margin_requirement: current margin requirement for short positions (e.g., 0.5 means 50%)
139 | - total_margin_used: total margin currently in use
140 | """,
141 | ),
142 | (
143 | "human",
144 | """Based on the team's analysis, make your trading decisions for each ticker.
145 |
146 | Here are the signals by ticker:
147 | {signals_by_ticker}
148 |
149 | Current Prices:
150 | {current_prices}
151 |
152 | Maximum Shares Allowed For Purchases:
153 | {max_shares}
154 |
155 | Portfolio Cash: {portfolio_cash}
156 | Current Positions: {portfolio_positions}
157 | Current Margin Requirement: {margin_requirement}
158 | Total Margin Used: {total_margin_used}
159 |
160 | Output strictly in JSON with the following structure:
161 | {{
162 | "decisions": {{
163 | "TICKER1": {{
164 | "action": "buy/sell/short/cover/hold",
165 | "quantity": integer,
166 | "confidence": float between 0 and 100,
167 | "reasoning": "string"
168 | }},
169 | "TICKER2": {{
170 | ...
171 | }},
172 | ...
173 | }}
174 | }}
175 | """,
176 | ),
177 | ]
178 | )
179 |
180 | # Generate the prompt
181 | prompt = template.invoke(
182 | {
183 | "signals_by_ticker": json.dumps(signals_by_ticker, indent=2),
184 | "current_prices": json.dumps(current_prices, indent=2),
185 | "max_shares": json.dumps(max_shares, indent=2),
186 | "portfolio_cash": f"{portfolio.get('cash', 0):.2f}",
187 | "portfolio_positions": json.dumps(portfolio.get('positions', {}), indent=2),
188 | "margin_requirement": f"{portfolio.get('margin_requirement', 0):.2f}",
189 | "total_margin_used": f"{portfolio.get('margin_used', 0):.2f}",
190 | }
191 | )
192 |
193 | # Create default factory for PortfolioManagerOutput
194 | def create_default_portfolio_output():
195 | return PortfolioManagerOutput(decisions={ticker: PortfolioDecision(action="hold", quantity=0, confidence=0.0, reasoning="Error in portfolio management, defaulting to hold") for ticker in tickers})
196 |
197 | return call_llm(prompt=prompt, model_name=model_name, model_provider=model_provider, pydantic_model=PortfolioManagerOutput, agent_name="portfolio_management_agent", default_factory=create_default_portfolio_output)
198 |
--------------------------------------------------------------------------------
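
Before prompting, the portfolio manager converts each ticker's remaining dollar limit into a whole-share cap: integer division of the limit by price, with a guard for missing prices. A worked sketch (the prices and limits below are invented):

def max_shares_for(limit_usd: float, price: float) -> int:
    """Whole shares purchasable under a dollar position limit; 0 when the price is unknown."""
    return int(limit_usd / price) if price > 0 else 0


print(max_shares_for(20_000.0, 187.45))  # 106
print(max_shares_for(20_000.0, 0.0))     # 0, no price data
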
/src/agents/risk_manager.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import HumanMessage
2 | from graph.state import AgentState, show_agent_reasoning
3 | from utils.progress import progress
4 | from tools.api import get_prices, prices_to_df
5 | import json
6 |
7 |
8 | ##### Risk Management Agent #####
9 | def risk_management_agent(state: AgentState):
10 | """Controls position sizing based on real-world risk factors for multiple tickers."""
11 | portfolio = state["data"]["portfolio"]
12 | data = state["data"]
13 | tickers = data["tickers"]
14 |
15 | # Initialize risk analysis for each ticker
16 | risk_analysis = {}
17 | current_prices = {} # Store prices here to avoid redundant API calls
18 |
19 | for ticker in tickers:
20 | progress.update_status("risk_management_agent", ticker, "Analyzing price data")
21 |
22 | prices = get_prices(
23 | ticker=ticker,
24 | start_date=data["start_date"],
25 | end_date=data["end_date"],
26 | )
27 |
28 | if not prices:
29 | progress.update_status("risk_management_agent", ticker, "Failed: No price data found")
30 | continue
31 |
32 | prices_df = prices_to_df(prices)
33 |
34 | progress.update_status("risk_management_agent", ticker, "Calculating position limits")
35 |
36 | # Calculate portfolio value
37 | current_price = prices_df["close"].iloc[-1]
38 | current_prices[ticker] = current_price # Store the current price
39 |
40 |         # Approximate this ticker's current position value by its cost basis
41 |         current_position_value = portfolio.get("cost_basis", {}).get(ticker, 0)
42 | 
43 |         # Calculate total portfolio value as cash plus the cost basis of all positions
44 |         total_portfolio_value = portfolio.get("cash", 0) + sum(portfolio.get("cost_basis", {}).values())
45 |
46 | # Base limit is 20% of portfolio for any single position
47 | position_limit = total_portfolio_value * 0.20
48 |
49 | # For existing positions, subtract current position value from limit
50 | remaining_position_limit = position_limit - current_position_value
51 |
52 | # Ensure we don't exceed available cash
53 | max_position_size = min(remaining_position_limit, portfolio.get("cash", 0))
54 |
55 | risk_analysis[ticker] = {
56 | "remaining_position_limit": float(max_position_size),
57 | "current_price": float(current_price),
58 | "reasoning": {
59 | "portfolio_value": float(total_portfolio_value),
60 | "current_position": float(current_position_value),
61 | "position_limit": float(position_limit),
62 | "remaining_limit": float(remaining_position_limit),
63 | "available_cash": float(portfolio.get("cash", 0)),
64 | },
65 | }
66 |
67 | progress.update_status("risk_management_agent", ticker, "Done")
68 |
69 | message = HumanMessage(
70 | content=json.dumps(risk_analysis),
71 | name="risk_management_agent",
72 | )
73 |
74 | if state["metadata"]["show_reasoning"]:
75 | show_agent_reasoning(risk_analysis, "Risk Management Agent")
76 |
77 | # Add the signal to the analyst_signals list
78 | state["data"]["analyst_signals"]["risk_management_agent"] = risk_analysis
79 |
80 | return {
81 | "messages": state["messages"] + [message],
82 | "data": data,
83 | }
84 |
--------------------------------------------------------------------------------
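
The sizing rule in risk_manager.py reduces to three steps: cap any single name at 20% of total portfolio value, subtract what is already held (at cost basis), then bound the result by available cash. A worked example, all figures hypothetical:

def remaining_limit(portfolio_value: float, position_cost: float, cash: float, cap: float = 0.20) -> float:
    """Dollar room left in one ticker under a fixed percentage cap, bounded by cash."""
    position_limit = portfolio_value * cap       # e.g. 20% of the book per name
    remaining = position_limit - position_cost   # subtract what is already held
    return min(remaining, cash)                  # never deploy more than cash on hand


# $100k book, $12k already in the name, $15k free cash: the 20% cap binds first
print(remaining_limit(100_000, 12_000, 15_000))  # 8000.0
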
/src/agents/sentiment.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import HumanMessage
2 | from graph.state import AgentState, show_agent_reasoning
3 | from utils.progress import progress
4 | import pandas as pd
5 | import numpy as np
6 | import json
7 |
8 | from tools.api import get_insider_trades, get_company_news
9 |
10 |
11 | ##### Sentiment Agent #####
12 | def sentiment_agent(state: AgentState):
13 | """Analyzes market sentiment and generates trading signals for multiple tickers."""
14 | data = state.get("data", {})
15 | end_date = data.get("end_date")
16 | tickers = data.get("tickers")
17 |
18 | # Initialize sentiment analysis for each ticker
19 | sentiment_analysis = {}
20 |
21 | for ticker in tickers:
22 | progress.update_status("sentiment_agent", ticker, "Fetching insider trades")
23 |
24 | # Get the insider trades
25 | insider_trades = get_insider_trades(
26 | ticker=ticker,
27 | end_date=end_date,
28 | limit=1000,
29 | )
30 |
31 | progress.update_status("sentiment_agent", ticker, "Analyzing trading patterns")
32 |
33 | # Get the signals from the insider trades
34 | transaction_shares = pd.Series([t.transaction_shares for t in insider_trades]).dropna()
35 | insider_signals = np.where(transaction_shares < 0, "bearish", "bullish").tolist()
36 |
37 | progress.update_status("sentiment_agent", ticker, "Fetching company news")
38 |
39 | # Get the company news
40 | company_news = get_company_news(ticker, end_date, limit=100)
41 |
42 | # Get the sentiment from the company news
43 | sentiment = pd.Series([n.sentiment for n in company_news]).dropna()
44 | news_signals = np.where(sentiment == "negative", "bearish",
45 | np.where(sentiment == "positive", "bullish", "neutral")).tolist()
46 |
47 | progress.update_status("sentiment_agent", ticker, "Combining signals")
48 | # Combine signals from both sources with weights
49 | insider_weight = 0.3
50 | news_weight = 0.7
51 |
52 | # Calculate weighted signal counts
53 | bullish_signals = (
54 | insider_signals.count("bullish") * insider_weight +
55 | news_signals.count("bullish") * news_weight
56 | )
57 | bearish_signals = (
58 | insider_signals.count("bearish") * insider_weight +
59 | news_signals.count("bearish") * news_weight
60 | )
61 |
62 | if bullish_signals > bearish_signals:
63 | overall_signal = "bullish"
64 | elif bearish_signals > bullish_signals:
65 | overall_signal = "bearish"
66 | else:
67 | overall_signal = "neutral"
68 |
69 | # Calculate confidence level based on the weighted proportion
70 | total_weighted_signals = len(insider_signals) * insider_weight + len(news_signals) * news_weight
71 | confidence = 0 # Default confidence when there are no signals
72 | if total_weighted_signals > 0:
73 |             confidence = round(max(bullish_signals, bearish_signals) / total_weighted_signals * 100, 2)
74 | reasoning = f"Weighted Bullish signals: {bullish_signals:.1f}, Weighted Bearish signals: {bearish_signals:.1f}"
75 |
76 | sentiment_analysis[ticker] = {
77 | "signal": overall_signal,
78 | "confidence": confidence,
79 | "reasoning": reasoning,
80 | }
81 |
82 | progress.update_status("sentiment_agent", ticker, "Done")
83 |
84 | # Create the sentiment message
85 | message = HumanMessage(
86 | content=json.dumps(sentiment_analysis),
87 | name="sentiment_agent",
88 | )
89 |
90 | # Print the reasoning if the flag is set
91 | if state["metadata"]["show_reasoning"]:
92 | show_agent_reasoning(sentiment_analysis, "Sentiment Analysis Agent")
93 |
94 | # Add the signal to the analyst_signals list
95 | state["data"]["analyst_signals"]["sentiment_agent"] = sentiment_analysis
96 |
97 | return {
98 | "messages": [message],
99 | "data": data,
100 | }
101 |
--------------------------------------------------------------------------------
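
The sentiment agent's aggregation is a weighted vote: insider-trade signals count at 0.3, news signals at 0.7, and confidence is the winning side's share of the total weighted votes. A compact self-contained sketch with invented vote lists:

def weighted_vote(insider: list[str], news: list[str], w_insider: float = 0.3, w_news: float = 0.7):
    """Blend two lists of 'bullish'/'bearish'/'neutral' votes into one signal plus confidence."""
    bullish = insider.count("bullish") * w_insider + news.count("bullish") * w_news
    bearish = insider.count("bearish") * w_insider + news.count("bearish") * w_news
    total = len(insider) * w_insider + len(news) * w_news
    signal = "bullish" if bullish > bearish else "bearish" if bearish > bullish else "neutral"
    confidence = round(max(bullish, bearish) / total * 100, 2) if total else 0
    return signal, confidence


print(weighted_vote(["bearish"] * 4, ["bullish"] * 6 + ["neutral"] * 2))  # ('bullish', 61.76)
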
/src/agents/technicals.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from langchain_core.messages import HumanMessage
4 |
5 | from graph.state import AgentState, show_agent_reasoning
6 |
7 | import json
8 | import pandas as pd
9 | import numpy as np
10 |
11 | from tools.api import get_prices, prices_to_df
12 | from utils.progress import progress
13 |
14 |
15 | ##### Technical Analyst #####
16 | def technical_analyst_agent(state: AgentState):
17 | """
18 | Sophisticated technical analysis system that combines multiple trading strategies for multiple tickers:
19 | 1. Trend Following
20 | 2. Mean Reversion
21 | 3. Momentum
22 | 4. Volatility Analysis
23 | 5. Statistical Arbitrage Signals
24 | """
25 | data = state["data"]
26 | start_date = data["start_date"]
27 | end_date = data["end_date"]
28 | tickers = data["tickers"]
29 |
30 | # Initialize analysis for each ticker
31 | technical_analysis = {}
32 |
33 | for ticker in tickers:
34 | progress.update_status("technical_analyst_agent", ticker, "Analyzing price data")
35 |
36 | # Get the historical price data
37 | prices = get_prices(
38 | ticker=ticker,
39 | start_date=start_date,
40 | end_date=end_date,
41 | )
42 |
43 | if not prices:
44 | progress.update_status("technical_analyst_agent", ticker, "Failed: No price data found")
45 | continue
46 |
47 | # Convert prices to a DataFrame
48 | prices_df = prices_to_df(prices)
49 |
50 | progress.update_status("technical_analyst_agent", ticker, "Calculating trend signals")
51 | trend_signals = calculate_trend_signals(prices_df)
52 |
53 | progress.update_status("technical_analyst_agent", ticker, "Calculating mean reversion")
54 | mean_reversion_signals = calculate_mean_reversion_signals(prices_df)
55 |
56 | progress.update_status("technical_analyst_agent", ticker, "Calculating momentum")
57 | momentum_signals = calculate_momentum_signals(prices_df)
58 |
59 | progress.update_status("technical_analyst_agent", ticker, "Analyzing volatility")
60 | volatility_signals = calculate_volatility_signals(prices_df)
61 |
62 | progress.update_status("technical_analyst_agent", ticker, "Statistical analysis")
63 | stat_arb_signals = calculate_stat_arb_signals(prices_df)
64 |
65 | # Combine all signals using a weighted ensemble approach
66 | strategy_weights = {
67 | "trend": 0.25,
68 | "mean_reversion": 0.20,
69 | "momentum": 0.25,
70 | "volatility": 0.15,
71 | "stat_arb": 0.15,
72 | }
73 |
74 | progress.update_status("technical_analyst_agent", ticker, "Combining signals")
75 | combined_signal = weighted_signal_combination(
76 | {
77 | "trend": trend_signals,
78 | "mean_reversion": mean_reversion_signals,
79 | "momentum": momentum_signals,
80 | "volatility": volatility_signals,
81 | "stat_arb": stat_arb_signals,
82 | },
83 | strategy_weights,
84 | )
85 |
86 | # Generate detailed analysis report for this ticker
87 | technical_analysis[ticker] = {
88 | "signal": combined_signal["signal"],
89 | "confidence": round(combined_signal["confidence"] * 100),
90 | "strategy_signals": {
91 | "trend_following": {
92 | "signal": trend_signals["signal"],
93 | "confidence": round(trend_signals["confidence"] * 100),
94 | "metrics": normalize_pandas(trend_signals["metrics"]),
95 | },
96 | "mean_reversion": {
97 | "signal": mean_reversion_signals["signal"],
98 | "confidence": round(mean_reversion_signals["confidence"] * 100),
99 | "metrics": normalize_pandas(mean_reversion_signals["metrics"]),
100 | },
101 | "momentum": {
102 | "signal": momentum_signals["signal"],
103 | "confidence": round(momentum_signals["confidence"] * 100),
104 | "metrics": normalize_pandas(momentum_signals["metrics"]),
105 | },
106 | "volatility": {
107 | "signal": volatility_signals["signal"],
108 | "confidence": round(volatility_signals["confidence"] * 100),
109 | "metrics": normalize_pandas(volatility_signals["metrics"]),
110 | },
111 | "statistical_arbitrage": {
112 | "signal": stat_arb_signals["signal"],
113 | "confidence": round(stat_arb_signals["confidence"] * 100),
114 | "metrics": normalize_pandas(stat_arb_signals["metrics"]),
115 | },
116 | },
117 | }
118 | progress.update_status("technical_analyst_agent", ticker, "Done")
119 |
120 | # Create the technical analyst message
121 | message = HumanMessage(
122 | content=json.dumps(technical_analysis),
123 | name="technical_analyst_agent",
124 | )
125 |
126 | if state["metadata"]["show_reasoning"]:
127 | show_agent_reasoning(technical_analysis, "Technical Analyst")
128 |
129 | # Add the signal to the analyst_signals list
130 | state["data"]["analyst_signals"]["technical_analyst_agent"] = technical_analysis
131 |
132 | return {
133 | "messages": state["messages"] + [message],
134 | "data": data,
135 | }
136 |
137 |
138 | def calculate_trend_signals(prices_df):
139 | """
140 | Advanced trend following strategy using multiple timeframes and indicators
141 | """
142 | # Calculate EMAs for multiple timeframes
143 | ema_8 = calculate_ema(prices_df, 8)
144 | ema_21 = calculate_ema(prices_df, 21)
145 | ema_55 = calculate_ema(prices_df, 55)
146 |
147 | # Calculate ADX for trend strength
148 | adx = calculate_adx(prices_df, 14)
149 |
150 | # Determine trend direction and strength
151 | short_trend = ema_8 > ema_21
152 | medium_trend = ema_21 > ema_55
153 |
154 | # Combine signals with confidence weighting
155 | trend_strength = adx["adx"].iloc[-1] / 100.0
156 |
157 | if short_trend.iloc[-1] and medium_trend.iloc[-1]:
158 | signal = "bullish"
159 | confidence = trend_strength
160 | elif not short_trend.iloc[-1] and not medium_trend.iloc[-1]:
161 | signal = "bearish"
162 | confidence = trend_strength
163 | else:
164 | signal = "neutral"
165 | confidence = 0.5
166 |
167 | return {
168 | "signal": signal,
169 | "confidence": confidence,
170 | "metrics": {
171 | "adx": float(adx["adx"].iloc[-1]),
172 | "trend_strength": float(trend_strength),
173 | },
174 | }
175 |
176 |
177 | def calculate_mean_reversion_signals(prices_df):
178 | """
179 | Mean reversion strategy using statistical measures and Bollinger Bands
180 | """
181 | # Calculate z-score of price relative to moving average
182 | ma_50 = prices_df["close"].rolling(window=50).mean()
183 | std_50 = prices_df["close"].rolling(window=50).std()
184 | z_score = (prices_df["close"] - ma_50) / std_50
185 |
186 | # Calculate Bollinger Bands
187 | bb_upper, bb_lower = calculate_bollinger_bands(prices_df)
188 |
189 | # Calculate RSI with multiple timeframes
190 | rsi_14 = calculate_rsi(prices_df, 14)
191 | rsi_28 = calculate_rsi(prices_df, 28)
192 |
193 | # Mean reversion signals
194 | price_vs_bb = (prices_df["close"].iloc[-1] - bb_lower.iloc[-1]) / (bb_upper.iloc[-1] - bb_lower.iloc[-1])
195 |
196 | # Combine signals
197 | if z_score.iloc[-1] < -2 and price_vs_bb < 0.2:
198 | signal = "bullish"
199 | confidence = min(abs(z_score.iloc[-1]) / 4, 1.0)
200 | elif z_score.iloc[-1] > 2 and price_vs_bb > 0.8:
201 | signal = "bearish"
202 | confidence = min(abs(z_score.iloc[-1]) / 4, 1.0)
203 | else:
204 | signal = "neutral"
205 | confidence = 0.5
206 |
207 | return {
208 | "signal": signal,
209 | "confidence": confidence,
210 | "metrics": {
211 | "z_score": float(z_score.iloc[-1]),
212 | "price_vs_bb": float(price_vs_bb),
213 | "rsi_14": float(rsi_14.iloc[-1]),
214 | "rsi_28": float(rsi_28.iloc[-1]),
215 | },
216 | }
217 |
218 |
219 | def calculate_momentum_signals(prices_df):
220 | """
221 | Multi-factor momentum strategy
222 | """
223 | # Price momentum
224 | returns = prices_df["close"].pct_change()
225 | mom_1m = returns.rolling(21).sum()
226 | mom_3m = returns.rolling(63).sum()
227 | mom_6m = returns.rolling(126).sum()
228 |
229 | # Volume momentum
230 | volume_ma = prices_df["volume"].rolling(21).mean()
231 | volume_momentum = prices_df["volume"] / volume_ma
232 |
233 | # Relative strength
234 | # (would compare to market/sector in real implementation)
235 |
236 | # Calculate momentum score
237 | momentum_score = (0.4 * mom_1m + 0.3 * mom_3m + 0.3 * mom_6m).iloc[-1]
238 |
239 | # Volume confirmation
240 | volume_confirmation = volume_momentum.iloc[-1] > 1.0
241 |
242 | if momentum_score > 0.05 and volume_confirmation:
243 | signal = "bullish"
244 | confidence = min(abs(momentum_score) * 5, 1.0)
245 | elif momentum_score < -0.05 and volume_confirmation:
246 | signal = "bearish"
247 | confidence = min(abs(momentum_score) * 5, 1.0)
248 | else:
249 | signal = "neutral"
250 | confidence = 0.5
251 |
252 | return {
253 | "signal": signal,
254 | "confidence": confidence,
255 | "metrics": {
256 | "momentum_1m": float(mom_1m.iloc[-1]),
257 | "momentum_3m": float(mom_3m.iloc[-1]),
258 | "momentum_6m": float(mom_6m.iloc[-1]),
259 | "volume_momentum": float(volume_momentum.iloc[-1]),
260 | },
261 | }
262 |
263 |
264 | def calculate_volatility_signals(prices_df):
265 | """
266 | Volatility-based trading strategy
267 | """
268 | # Calculate various volatility metrics
269 | returns = prices_df["close"].pct_change()
270 |
271 | # Historical volatility
272 | hist_vol = returns.rolling(21).std() * math.sqrt(252)
273 |
274 | # Volatility regime detection
275 | vol_ma = hist_vol.rolling(63).mean()
276 | vol_regime = hist_vol / vol_ma
277 |
278 | # Volatility mean reversion
279 | vol_z_score = (hist_vol - vol_ma) / hist_vol.rolling(63).std()
280 |
281 | # ATR ratio
282 | atr = calculate_atr(prices_df)
283 | atr_ratio = atr / prices_df["close"]
284 |
285 | # Generate signal based on volatility regime
286 | current_vol_regime = vol_regime.iloc[-1]
287 | vol_z = vol_z_score.iloc[-1]
288 |
289 | if current_vol_regime < 0.8 and vol_z < -1:
290 | signal = "bullish" # Low vol regime, potential for expansion
291 | confidence = min(abs(vol_z) / 3, 1.0)
292 | elif current_vol_regime > 1.2 and vol_z > 1:
293 | signal = "bearish" # High vol regime, potential for contraction
294 | confidence = min(abs(vol_z) / 3, 1.0)
295 | else:
296 | signal = "neutral"
297 | confidence = 0.5
298 |
299 | return {
300 | "signal": signal,
301 | "confidence": confidence,
302 | "metrics": {
303 | "historical_volatility": float(hist_vol.iloc[-1]),
304 | "volatility_regime": float(current_vol_regime),
305 | "volatility_z_score": float(vol_z),
306 | "atr_ratio": float(atr_ratio.iloc[-1]),
307 | },
308 | }
309 |
310 |
311 | def calculate_stat_arb_signals(prices_df):
312 | """
313 | Statistical arbitrage signals based on price action analysis
314 | """
315 | # Calculate price distribution statistics
316 | returns = prices_df["close"].pct_change()
317 |
318 | # Skewness and kurtosis
319 | skew = returns.rolling(63).skew()
320 | kurt = returns.rolling(63).kurt()
321 |
322 | # Test for mean reversion using Hurst exponent
323 | hurst = calculate_hurst_exponent(prices_df["close"])
324 |
325 | # Correlation analysis
326 | # (would include correlation with related securities in real implementation)
327 |
328 | # Generate signal based on statistical properties
329 | if hurst < 0.4 and skew.iloc[-1] > 1:
330 | signal = "bullish"
331 | confidence = (0.5 - hurst) * 2
332 | elif hurst < 0.4 and skew.iloc[-1] < -1:
333 | signal = "bearish"
334 | confidence = (0.5 - hurst) * 2
335 | else:
336 | signal = "neutral"
337 | confidence = 0.5
338 |
339 | return {
340 | "signal": signal,
341 | "confidence": confidence,
342 | "metrics": {
343 | "hurst_exponent": float(hurst),
344 | "skewness": float(skew.iloc[-1]),
345 | "kurtosis": float(kurt.iloc[-1]),
346 | },
347 | }
348 |
349 |
350 | def weighted_signal_combination(signals, weights):
351 | """
352 | Combines multiple trading signals using a weighted approach
353 | """
354 | # Convert signals to numeric values
355 | signal_values = {"bullish": 1, "neutral": 0, "bearish": -1}
356 |
357 | weighted_sum = 0
358 | total_confidence = 0
359 |
360 | for strategy, signal in signals.items():
361 | numeric_signal = signal_values[signal["signal"]]
362 | weight = weights[strategy]
363 | confidence = signal["confidence"]
364 |
365 | weighted_sum += numeric_signal * weight * confidence
366 | total_confidence += weight * confidence
367 |
368 | # Normalize the weighted sum
369 | if total_confidence > 0:
370 | final_score = weighted_sum / total_confidence
371 | else:
372 | final_score = 0
373 |
374 | # Convert back to signal
375 | if final_score > 0.2:
376 | signal = "bullish"
377 | elif final_score < -0.2:
378 | signal = "bearish"
379 | else:
380 | signal = "neutral"
381 |
382 | return {"signal": signal, "confidence": abs(final_score)}
383 |
384 |
385 | def normalize_pandas(obj):
386 | """Convert pandas Series/DataFrames to primitive Python types"""
387 | if isinstance(obj, pd.Series):
388 | return obj.tolist()
389 | elif isinstance(obj, pd.DataFrame):
390 | return obj.to_dict("records")
391 | elif isinstance(obj, dict):
392 | return {k: normalize_pandas(v) for k, v in obj.items()}
393 | elif isinstance(obj, (list, tuple)):
394 | return [normalize_pandas(item) for item in obj]
395 | return obj
396 |
397 |
398 | def calculate_rsi(prices_df: pd.DataFrame, period: int = 14) -> pd.Series:
399 | delta = prices_df["close"].diff()
400 | gain = (delta.where(delta > 0, 0)).fillna(0)
401 | loss = (-delta.where(delta < 0, 0)).fillna(0)
402 | avg_gain = gain.rolling(window=period).mean()
403 | avg_loss = loss.rolling(window=period).mean()
404 | rs = avg_gain / avg_loss
405 | rsi = 100 - (100 / (1 + rs))
406 | return rsi
407 |
408 |
409 | def calculate_bollinger_bands(prices_df: pd.DataFrame, window: int = 20) -> tuple[pd.Series, pd.Series]:
410 | sma = prices_df["close"].rolling(window).mean()
411 | std_dev = prices_df["close"].rolling(window).std()
412 | upper_band = sma + (std_dev * 2)
413 | lower_band = sma - (std_dev * 2)
414 | return upper_band, lower_band
415 |
416 |
417 | def calculate_ema(df: pd.DataFrame, window: int) -> pd.Series:
418 | """
419 | Calculate Exponential Moving Average
420 |
421 | Args:
422 | df: DataFrame with price data
423 | window: EMA period
424 |
425 | Returns:
426 | pd.Series: EMA values
427 | """
428 | return df["close"].ewm(span=window, adjust=False).mean()
429 |
430 |
431 | def calculate_adx(df: pd.DataFrame, period: int = 14) -> pd.DataFrame:
432 | """
433 | Calculate Average Directional Index (ADX)
434 |
435 | Args:
436 | df: DataFrame with OHLC data
437 | period: Period for calculations
438 |
439 | Returns:
440 |         DataFrame with ADX values; note that intermediate helper columns are added to df in place
441 | """
442 | # Calculate True Range
443 | df["high_low"] = df["high"] - df["low"]
444 | df["high_close"] = abs(df["high"] - df["close"].shift())
445 | df["low_close"] = abs(df["low"] - df["close"].shift())
446 | df["tr"] = df[["high_low", "high_close", "low_close"]].max(axis=1)
447 |
448 | # Calculate Directional Movement
449 | df["up_move"] = df["high"] - df["high"].shift()
450 | df["down_move"] = df["low"].shift() - df["low"]
451 |
452 | df["plus_dm"] = np.where((df["up_move"] > df["down_move"]) & (df["up_move"] > 0), df["up_move"], 0)
453 | df["minus_dm"] = np.where((df["down_move"] > df["up_move"]) & (df["down_move"] > 0), df["down_move"], 0)
454 |
455 | # Calculate ADX
456 | df["+di"] = 100 * (df["plus_dm"].ewm(span=period).mean() / df["tr"].ewm(span=period).mean())
457 | df["-di"] = 100 * (df["minus_dm"].ewm(span=period).mean() / df["tr"].ewm(span=period).mean())
458 | df["dx"] = 100 * abs(df["+di"] - df["-di"]) / (df["+di"] + df["-di"])
459 | df["adx"] = df["dx"].ewm(span=period).mean()
460 |
461 | return df[["adx", "+di", "-di"]]
462 |
463 |
464 | def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
465 | """
466 | Calculate Average True Range
467 |
468 | Args:
469 | df: DataFrame with OHLC data
470 | period: Period for ATR calculation
471 |
472 | Returns:
473 | pd.Series: ATR values
474 | """
475 | high_low = df["high"] - df["low"]
476 | high_close = abs(df["high"] - df["close"].shift())
477 | low_close = abs(df["low"] - df["close"].shift())
478 |
479 | ranges = pd.concat([high_low, high_close, low_close], axis=1)
480 | true_range = ranges.max(axis=1)
481 |
482 | return true_range.rolling(period).mean()
483 |
484 |
485 | def calculate_hurst_exponent(price_series: pd.Series, max_lag: int = 20) -> float:
486 | """
487 | Calculate Hurst Exponent to determine long-term memory of time series
488 | H < 0.5: Mean reverting series
489 | H = 0.5: Random walk
490 | H > 0.5: Trending series
491 |
492 | Args:
493 | price_series: Array-like price data
494 | max_lag: Maximum lag for R/S calculation
495 |
496 | Returns:
497 | float: Hurst exponent
498 | """
499 |     lags = range(2, max_lag)
500 |     values = np.asarray(price_series)  # use raw values so pandas index alignment doesn't zero out the lagged differences
501 |     tau = [max(1e-8, np.sqrt(np.std(values[lag:] - values[:-lag]))) for lag in lags]  # epsilon avoids log(0)
502 |
503 | # Return the Hurst exponent from linear fit
504 | try:
505 | reg = np.polyfit(np.log(lags), np.log(tau), 1)
506 | return reg[0] # Hurst exponent is the slope
507 | except (ValueError, RuntimeWarning):
508 | # Return 0.5 (random walk) if calculation fails
509 | return 0.5
510 |
--------------------------------------------------------------------------------
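
weighted_signal_combination maps labels to -1/0/+1, weights each strategy by both its configured weight and its own confidence, normalizes, and re-thresholds the score at plus or minus 0.2. A self-contained numerical walkthrough of the same arithmetic (the inputs are invented):

SIGNAL_VALUES = {"bullish": 1, "neutral": 0, "bearish": -1}


def combine(signals: dict, weights: dict) -> dict:
    """Confidence-weighted vote across strategies, re-thresholded at +/-0.2."""
    weighted_sum = sum(SIGNAL_VALUES[s["signal"]] * weights[k] * s["confidence"] for k, s in signals.items())
    total_conf = sum(weights[k] * s["confidence"] for k, s in signals.items())
    score = weighted_sum / total_conf if total_conf else 0
    label = "bullish" if score > 0.2 else "bearish" if score < -0.2 else "neutral"
    return {"signal": label, "confidence": abs(score)}


demo = combine(
    {"trend": {"signal": "bullish", "confidence": 0.6}, "momentum": {"signal": "neutral", "confidence": 0.5}},
    {"trend": 0.25, "momentum": 0.25},
)
print(demo)  # {'signal': 'bullish', 'confidence': 0.545...}
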
/src/agents/valuation.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | """Valuation Agent
4 |
5 | Implements four complementary valuation methodologies and aggregates them with
6 | configurable weights.
7 | """
8 |
9 | from statistics import median
10 | import json
11 | from langchain_core.messages import HumanMessage
12 | from graph.state import AgentState, show_agent_reasoning
13 | from utils.progress import progress
14 |
15 | from tools.api import (
16 | get_financial_metrics,
17 | get_market_cap,
18 | search_line_items,
19 | )
20 |
21 | def valuation_agent(state: AgentState):
22 | """Run valuation across tickers and write signals back to `state`."""
23 |
24 | data = state["data"]
25 | end_date = data["end_date"]
26 | tickers = data["tickers"]
27 |
28 | valuation_analysis: dict[str, dict] = {}
29 |
30 | for ticker in tickers:
31 | progress.update_status("valuation_agent", ticker, "Fetching financial data")
32 |
33 | # --- Historical financial metrics (pull 8 latest TTM snapshots for medians) ---
34 | financial_metrics = get_financial_metrics(
35 | ticker=ticker,
36 | end_date=end_date,
37 | period="ttm",
38 | limit=8,
39 | )
40 | if not financial_metrics:
41 | progress.update_status("valuation_agent", ticker, "Failed: No financial metrics found")
42 | continue
43 | most_recent_metrics = financial_metrics[0]
44 |
45 | # --- Fine‑grained line‑items (need two periods to calc WC change) ---
46 | progress.update_status("valuation_agent", ticker, "Gathering line items")
47 | line_items = search_line_items(
48 | ticker=ticker,
49 | line_items=[
50 | "free_cash_flow",
51 | "net_income",
52 | "depreciation_and_amortization",
53 | "capital_expenditure",
54 | "working_capital",
55 | ],
56 | end_date=end_date,
57 | period="ttm",
58 | limit=2,
59 | )
60 | if len(line_items) < 2:
61 | progress.update_status("valuation_agent", ticker, "Failed: Insufficient financial line items")
62 | continue
63 | li_curr, li_prev = line_items[0], line_items[1]
64 |
65 | # ------------------------------------------------------------------
66 | # Valuation models
67 | # ------------------------------------------------------------------
68 | wc_change = li_curr.working_capital - li_prev.working_capital
69 |
70 | # Owner Earnings
71 | owner_val = calculate_owner_earnings_value(
72 | net_income=li_curr.net_income,
73 | depreciation=li_curr.depreciation_and_amortization,
74 | capex=li_curr.capital_expenditure,
75 | working_capital_change=wc_change,
76 | growth_rate=most_recent_metrics.earnings_growth or 0.05,
77 | )
78 |
79 | # Discounted Cash Flow
80 | dcf_val = calculate_intrinsic_value(
81 | free_cash_flow=li_curr.free_cash_flow,
82 | growth_rate=most_recent_metrics.earnings_growth or 0.05,
83 | discount_rate=0.10,
84 | terminal_growth_rate=0.03,
85 | num_years=5,
86 | )
87 |
88 |         # Implied equity value from the EV/EBITDA multiple
89 | ev_ebitda_val = calculate_ev_ebitda_value(financial_metrics)
90 |
91 | # Residual Income Model
92 | rim_val = calculate_residual_income_value(
93 | market_cap=most_recent_metrics.market_cap,
94 | net_income=li_curr.net_income,
95 | price_to_book_ratio=most_recent_metrics.price_to_book_ratio,
96 | book_value_growth=most_recent_metrics.book_value_growth or 0.03,
97 | )
98 |
99 | # ------------------------------------------------------------------
100 | # Aggregate & signal
101 | # ------------------------------------------------------------------
102 | market_cap = get_market_cap(ticker, end_date)
103 | if not market_cap:
104 | progress.update_status("valuation_agent", ticker, "Failed: Market cap unavailable")
105 | continue
106 |
107 | method_values = {
108 | "dcf": {"value": dcf_val, "weight": 0.35},
109 | "owner_earnings": {"value": owner_val, "weight": 0.35},
110 | "ev_ebitda": {"value": ev_ebitda_val, "weight": 0.20},
111 | "residual_income": {"value": rim_val, "weight": 0.10},
112 | }
113 |
114 | total_weight = sum(v["weight"] for v in method_values.values() if v["value"] > 0)
115 | if total_weight == 0:
116 |             progress.update_status("valuation_agent", ticker, "Failed: All valuation methods returned zero")
117 | continue
118 |
119 | for v in method_values.values():
120 | v["gap"] = (v["value"] - market_cap) / market_cap if v["value"] > 0 else None
121 |
122 | weighted_gap = sum(
123 | v["weight"] * v["gap"] for v in method_values.values() if v["gap"] is not None
124 | ) / total_weight
125 |
126 | signal = "bullish" if weighted_gap > 0.15 else "bearish" if weighted_gap < -0.15 else "neutral"
127 | confidence = round(min(abs(weighted_gap) / 0.30 * 100, 100))
128 |
129 | reasoning = {
130 | f"{m}_analysis": {
131 | "signal": (
132 | "bullish" if vals["gap"] and vals["gap"] > 0.15 else
133 | "bearish" if vals["gap"] and vals["gap"] < -0.15 else "neutral"
134 | ),
135 | "details": (
136 | f"Value: ${vals['value']:,.2f}, Market Cap: ${market_cap:,.2f}, "
137 | f"Gap: {vals['gap']:.1%}, Weight: {vals['weight']*100:.0f}%"
138 | ),
139 | }
140 | for m, vals in method_values.items() if vals["value"] > 0
141 | }
142 |
143 | valuation_analysis[ticker] = {
144 | "signal": signal,
145 | "confidence": confidence,
146 | "reasoning": reasoning,
147 | }
148 | progress.update_status("valuation_agent", ticker, "Done")
149 |
150 | # ---- Emit message (for LLM tool chain) ----
151 | msg = HumanMessage(content=json.dumps(valuation_analysis), name="valuation_agent")
152 | if state["metadata"].get("show_reasoning"):
153 | show_agent_reasoning(valuation_analysis, "Valuation Analysis Agent")
154 | state["data"]["analyst_signals"]["valuation_agent"] = valuation_analysis
155 | return {"messages": [msg], "data": data}
156 |
157 | #############################
158 | # Helper Valuation Functions
159 | #############################
160 |
161 | def calculate_owner_earnings_value(
162 | net_income: float | None,
163 | depreciation: float | None,
164 | capex: float | None,
165 | working_capital_change: float | None,
166 | growth_rate: float = 0.05,
167 | required_return: float = 0.15,
168 | margin_of_safety: float = 0.25,
169 | num_years: int = 5,
170 | ) -> float:
171 | """Buffett owner‑earnings valuation with margin‑of‑safety."""
172 | if not all(isinstance(x, (int, float)) for x in [net_income, depreciation, capex, working_capital_change]):
173 | return 0
174 |
175 | owner_earnings = net_income + depreciation - capex - working_capital_change
176 | if owner_earnings <= 0:
177 | return 0
178 |
179 | pv = 0.0
180 | for yr in range(1, num_years + 1):
181 | future = owner_earnings * (1 + growth_rate) ** yr
182 | pv += future / (1 + required_return) ** yr
183 |
184 | terminal_growth = min(growth_rate, 0.03)
185 | term_val = (owner_earnings * (1 + growth_rate) ** num_years * (1 + terminal_growth)) / (
186 | required_return - terminal_growth
187 | )
188 | pv_term = term_val / (1 + required_return) ** num_years
189 |
190 | intrinsic = pv + pv_term
191 | return intrinsic * (1 - margin_of_safety)
192 |
193 |
194 | def calculate_intrinsic_value(
195 | free_cash_flow: float | None,
196 | growth_rate: float = 0.05,
197 | discount_rate: float = 0.10,
198 | terminal_growth_rate: float = 0.02,
199 | num_years: int = 5,
200 | ) -> float:
201 | """Classic DCF on FCF with constant growth and terminal value."""
202 | if free_cash_flow is None or free_cash_flow <= 0:
203 | return 0
204 |
205 | pv = 0.0
206 | for yr in range(1, num_years + 1):
207 | fcft = free_cash_flow * (1 + growth_rate) ** yr
208 | pv += fcft / (1 + discount_rate) ** yr
209 |
210 | term_val = (
211 | free_cash_flow * (1 + growth_rate) ** num_years * (1 + terminal_growth_rate)
212 | ) / (discount_rate - terminal_growth_rate)
213 | pv_term = term_val / (1 + discount_rate) ** num_years
214 |
215 | return pv + pv_term
216 |
217 |
218 | def calculate_ev_ebitda_value(financial_metrics: list):
219 | """Implied equity value via median EV/EBITDA multiple."""
220 | if not financial_metrics:
221 | return 0
222 | m0 = financial_metrics[0]
223 | if not (m0.enterprise_value and m0.enterprise_value_to_ebitda_ratio):
224 | return 0
225 | if m0.enterprise_value_to_ebitda_ratio == 0:
226 | return 0
227 |
228 | ebitda_now = m0.enterprise_value / m0.enterprise_value_to_ebitda_ratio
229 | med_mult = median([
230 | m.enterprise_value_to_ebitda_ratio for m in financial_metrics if m.enterprise_value_to_ebitda_ratio
231 | ])
232 | ev_implied = med_mult * ebitda_now
233 | net_debt = (m0.enterprise_value or 0) - (m0.market_cap or 0)
234 | return max(ev_implied - net_debt, 0)
235 |
236 |
237 | def calculate_residual_income_value(
238 | market_cap: float | None,
239 | net_income: float | None,
240 | price_to_book_ratio: float | None,
241 | book_value_growth: float = 0.03,
242 | cost_of_equity: float = 0.10,
243 | terminal_growth_rate: float = 0.03,
244 | num_years: int = 5,
245 | ):
246 | """Residual Income Model (Edwards‑Bell‑Ohlson)."""
247 | if not (market_cap and net_income and price_to_book_ratio and price_to_book_ratio > 0):
248 | return 0
249 |
250 | book_val = market_cap / price_to_book_ratio
251 | ri0 = net_income - cost_of_equity * book_val
252 | if ri0 <= 0:
253 | return 0
254 |
255 | pv_ri = 0.0
256 | for yr in range(1, num_years + 1):
257 | ri_t = ri0 * (1 + book_value_growth) ** yr
258 | pv_ri += ri_t / (1 + cost_of_equity) ** yr
259 |
260 | term_ri = ri0 * (1 + book_value_growth) ** (num_years + 1) / (
261 | cost_of_equity - terminal_growth_rate
262 | )
263 | pv_term = term_ri / (1 + cost_of_equity) ** num_years
264 |
265 | intrinsic = book_val + pv_ri + pv_term
266 | return intrinsic * 0.8 # 20% margin of safety
267 |
--------------------------------------------------------------------------------
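
A worked example of the valuation helpers above (illustrative only: the inputs
are made up, and `src` is assumed to be on the Python path):

    from agents.valuation import calculate_intrinsic_value

    # Hypothetical inputs: $100M TTM free cash flow, 5% growth, 10% discount rate.
    dcf_value = calculate_intrinsic_value(
        free_cash_flow=100_000_000,
        growth_rate=0.05,
        discount_rate=0.10,
        terminal_growth_rate=0.03,
        num_years=5,
    )
    market_cap = 1_200_000_000  # hypothetical
    gap = (dcf_value - market_cap) / market_cap
    print(f"DCF value: ${dcf_value:,.0f}  gap: {gap:+.1%}")

With these inputs the intrinsic value comes out to roughly $1.6B, a gap of
around +34% against the hypothetical market cap, comfortably past the +15%
threshold that valuation_agent maps to a "bullish" signal.
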
/src/data/cache.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | 
4 | class Cache:
5 |     """In-memory cache for API responses."""
6 | 
7 |     def __init__(self):
8 |         self._prices_cache: dict[str, list[dict[str, Any]]] = {}
9 |         self._financial_metrics_cache: dict[str, list[dict[str, Any]]] = {}
10 |         self._line_items_cache: dict[str, list[dict[str, Any]]] = {}
11 |         self._insider_trades_cache: dict[str, list[dict[str, Any]]] = {}
12 |         self._company_news_cache: dict[str, list[dict[str, Any]]] = {}
13 | 
14 |     def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]:
15 |         """Merge existing and new data, avoiding duplicates based on a key field."""
16 |         if not existing:
17 |             return new_data
18 | 
19 |         # Create a set of existing keys for O(1) lookup
20 |         existing_keys = {item[key_field] for item in existing}
21 | 
22 |         # Only add items that don't exist yet
23 |         merged = existing.copy()
24 |         merged.extend([item for item in new_data if item[key_field] not in existing_keys])
25 |         return merged
26 | 
27 |     def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
28 |         """Get cached price data if available."""
29 |         return self._prices_cache.get(ticker)
30 | 
31 |     def set_prices(self, ticker: str, data: list[dict[str, Any]]):
32 |         """Append new price data to cache."""
33 |         self._prices_cache[ticker] = self._merge_data(self._prices_cache.get(ticker), data, key_field="time")
34 | 
35 |     def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]] | None:
36 |         """Get cached financial metrics if available."""
37 |         return self._financial_metrics_cache.get(ticker)
38 | 
39 |     def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]):
40 |         """Append new financial metrics to cache."""
41 |         self._financial_metrics_cache[ticker] = self._merge_data(self._financial_metrics_cache.get(ticker), data, key_field="report_period")
42 | 
43 |     def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None:
44 |         """Get cached line items if available."""
45 |         return self._line_items_cache.get(ticker)
46 | 
47 |     def set_line_items(self, ticker: str, data: list[dict[str, Any]]):
48 |         """Append new line items to cache."""
49 |         self._line_items_cache[ticker] = self._merge_data(self._line_items_cache.get(ticker), data, key_field="report_period")
50 | 
51 |     def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None:
52 |         """Get cached insider trades if available."""
53 |         return self._insider_trades_cache.get(ticker)
54 | 
55 |     def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]):
56 |         """Append new insider trades to cache."""
57 |         self._insider_trades_cache[ticker] = self._merge_data(self._insider_trades_cache.get(ticker), data, key_field="filing_date")  # Could also use transaction_date if preferred
58 | 
59 |     def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None:
60 |         """Get cached company news if available."""
61 |         return self._company_news_cache.get(ticker)
62 | 
63 |     def set_company_news(self, ticker: str, data: list[dict[str, Any]]):
64 |         """Append new company news to cache."""
65 |         self._company_news_cache[ticker] = self._merge_data(self._company_news_cache.get(ticker), data, key_field="date")
66 | 
67 | 
68 | # Global cache instance
69 | _cache = Cache()
70 | 
71 | 
72 | def get_cache() -> Cache:
73 |     """Get the global cache instance."""
74 |     return _cache
75 | 
--------------------------------------------------------------------------------
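
A quick sketch of the deduplication behaviour (the ticker and rows are toy
values; real cache entries are full Price dicts). `_merge_data` keys rows on
a field such as "time", so re-setting overlapping data only appends rows that
are genuinely new:

    from data.cache import get_cache

    cache = get_cache()
    cache.set_prices("AAPL", [{"time": "2024-01-02", "close": 185.6}])
    # Overlapping rows are deduplicated on the "time" key, so only the
    # 2024-01-03 row is actually appended by this second call.
    cache.set_prices("AAPL", [
        {"time": "2024-01-02", "close": 185.6},
        {"time": "2024-01-03", "close": 184.3},
    ])
    print(len(cache.get_prices("AAPL")))  # 2
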
/src/data/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
4 | class Price(BaseModel):
5 | open: float
6 | close: float
7 | high: float
8 | low: float
9 | volume: int
10 | time: str
11 |
12 |
13 | class PriceResponse(BaseModel):
14 | ticker: str
15 | prices: list[Price]
16 |
17 |
18 | class FinancialMetrics(BaseModel):
19 | ticker: str
20 | report_period: str
21 | period: str
22 | currency: str
23 | market_cap: float | None
24 | enterprise_value: float | None
25 | price_to_earnings_ratio: float | None
26 | price_to_book_ratio: float | None
27 | price_to_sales_ratio: float | None
28 | enterprise_value_to_ebitda_ratio: float | None
29 | enterprise_value_to_revenue_ratio: float | None
30 | free_cash_flow_yield: float | None
31 | peg_ratio: float | None
32 | gross_margin: float | None
33 | operating_margin: float | None
34 | net_margin: float | None
35 | return_on_equity: float | None
36 | return_on_assets: float | None
37 | return_on_invested_capital: float | None
38 | asset_turnover: float | None
39 | inventory_turnover: float | None
40 | receivables_turnover: float | None
41 | days_sales_outstanding: float | None
42 | operating_cycle: float | None
43 | working_capital_turnover: float | None
44 | current_ratio: float | None
45 | quick_ratio: float | None
46 | cash_ratio: float | None
47 | operating_cash_flow_ratio: float | None
48 | debt_to_equity: float | None
49 | debt_to_assets: float | None
50 | interest_coverage: float | None
51 | revenue_growth: float | None
52 | earnings_growth: float | None
53 | book_value_growth: float | None
54 | earnings_per_share_growth: float | None
55 | free_cash_flow_growth: float | None
56 | operating_income_growth: float | None
57 | ebitda_growth: float | None
58 | payout_ratio: float | None
59 | earnings_per_share: float | None
60 | book_value_per_share: float | None
61 | free_cash_flow_per_share: float | None
62 |
63 |
64 | class FinancialMetricsResponse(BaseModel):
65 | financial_metrics: list[FinancialMetrics]
66 |
67 |
68 | class LineItem(BaseModel):
69 | ticker: str
70 | report_period: str
71 | period: str
72 | currency: str
73 |
74 | # Allow additional fields dynamically
75 | model_config = {"extra": "allow"}
76 |
77 |
78 | class LineItemResponse(BaseModel):
79 | search_results: list[LineItem]
80 |
81 |
82 | class InsiderTrade(BaseModel):
83 | ticker: str
84 | issuer: str | None
85 | name: str | None
86 | title: str | None
87 | is_board_director: bool | None
88 | transaction_date: str | None
89 | transaction_shares: float | None
90 | transaction_price_per_share: float | None
91 | transaction_value: float | None
92 | shares_owned_before_transaction: float | None
93 | shares_owned_after_transaction: float | None
94 | security_title: str | None
95 | filing_date: str
96 |
97 |
98 | class InsiderTradeResponse(BaseModel):
99 | insider_trades: list[InsiderTrade]
100 |
101 |
102 | class CompanyNews(BaseModel):
103 | ticker: str
104 | title: str
105 | author: str
106 | source: str
107 | date: str
108 | url: str
109 | sentiment: str | None = None
110 |
111 |
112 | class CompanyNewsResponse(BaseModel):
113 | news: list[CompanyNews]
114 |
115 |
116 | class CompanyFacts(BaseModel):
117 | ticker: str
118 | name: str
119 | cik: str | None = None
120 | industry: str | None = None
121 | sector: str | None = None
122 | category: str | None = None
123 | exchange: str | None = None
124 | is_active: bool | None = None
125 | listing_date: str | None = None
126 | location: str | None = None
127 | market_cap: float | None = None
128 | number_of_employees: int | None = None
129 | sec_filings_url: str | None = None
130 | sic_code: str | None = None
131 | sic_industry: str | None = None
132 | sic_sector: str | None = None
133 | website_url: str | None = None
134 | weighted_average_shares: int | None = None
135 |
136 |
137 | class CompanyFactsResponse(BaseModel):
138 | company_facts: CompanyFacts
139 |
140 |
141 | class Position(BaseModel):
142 | cash: float = 0.0
143 | shares: int = 0
144 | ticker: str
145 |
146 |
147 | class Portfolio(BaseModel):
148 | positions: dict[str, Position] # ticker -> Position mapping
149 | total_cash: float = 0.0
150 |
151 |
152 | class AnalystSignal(BaseModel):
153 | signal: str | None = None
154 | confidence: float | None = None
155 | reasoning: dict | str | None = None
156 | max_position_size: float | None = None # For risk management signals
157 |
158 |
159 | class TickerAnalysis(BaseModel):
160 | ticker: str
161 | analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping
162 |
163 |
164 | class AgentStateData(BaseModel):
165 | tickers: list[str]
166 | portfolio: Portfolio
167 | start_date: str
168 | end_date: str
169 | ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping
170 |
171 |
172 | class AgentStateMetadata(BaseModel):
173 | show_reasoning: bool = False
174 | model_config = {"extra": "allow"}
175 |
--------------------------------------------------------------------------------
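
Because LineItem sets model_config = {"extra": "allow"}, whatever fields are
requested via search_line_items become attributes at validation time even
though the class declares none of them. That is what lets valuation_agent
read li_curr.free_cash_flow. A minimal sketch with made-up values:

    from data.models import LineItem

    item = LineItem(
        ticker="AAPL",
        report_period="2024-09-28",
        period="ttm",
        currency="USD",
        free_cash_flow=100_000_000.0,  # not declared on the model; kept via extra="allow"
    )
    print(item.free_cash_flow)
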
/src/graph/state.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import Annotated, Any, Sequence, TypedDict
2 |
3 | import operator
4 | from langchain_core.messages import BaseMessage
5 |
6 |
7 | import json
8 |
9 |
10 | def merge_dicts(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
11 | return {**a, **b}
12 |
13 |
14 | # Define agent state
15 | class AgentState(TypedDict):
16 | messages: Annotated[Sequence[BaseMessage], operator.add]
17 |     data: Annotated[dict[str, Any], merge_dicts]
18 |     metadata: Annotated[dict[str, Any], merge_dicts]
19 |
20 |
21 | def show_agent_reasoning(output, agent_name):
22 | print(f"\n{'=' * 10} {agent_name.center(28)} {'=' * 10}")
23 |
24 | def convert_to_serializable(obj):
25 | if hasattr(obj, "to_dict"): # Handle Pandas Series/DataFrame
26 | return obj.to_dict()
27 | elif hasattr(obj, "__dict__"): # Handle custom objects
28 | return obj.__dict__
29 | elif isinstance(obj, (int, float, bool, str)):
30 | return obj
31 | elif isinstance(obj, (list, tuple)):
32 | return [convert_to_serializable(item) for item in obj]
33 | elif isinstance(obj, dict):
34 | return {key: convert_to_serializable(value) for key, value in obj.items()}
35 | else:
36 | return str(obj) # Fallback to string representation
37 |
38 | if isinstance(output, (dict, list)):
39 | # Convert the output to JSON-serializable format
40 | serializable_output = convert_to_serializable(output)
41 | print(json.dumps(serializable_output, indent=2))
42 | else:
43 | try:
44 | # Parse the string as JSON and pretty print it
45 | parsed_output = json.loads(output)
46 | print(json.dumps(parsed_output, indent=2))
47 | except json.JSONDecodeError:
48 | # Fallback to original string if not valid JSON
49 | print(output)
50 |
51 | print("=" * 48)
52 |
--------------------------------------------------------------------------------
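
The Annotated reducers above tell LangGraph how to combine partial updates
from parallel agent nodes: message lists are concatenated with operator.add,
while the data and metadata dicts are shallow-merged. Note that merge_dicts
lets the right-hand update win on key collisions, which is why agents mutate
the shared analyst_signals dict in place rather than returning disjoint
copies. A small sketch of the merge semantics:

    from graph.state import merge_dicts

    left = {"tickers": ["AAPL"], "analyst_signals": {"valuation_agent": {"AAPL": "bullish"}}}
    right = {"analyst_signals": {"sentiment_agent": {"AAPL": "neutral"}}}

    # Shallow merge: "tickers" survives, but right's "analyst_signals" replaces
    # left's entirely; the valuation_agent entry is gone from the result.
    print(merge_dicts(left, right))
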
/src/llm/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | from langchain_anthropic import ChatAnthropic
3 | from langchain_deepseek import ChatDeepSeek
4 | from langchain_google_genai import ChatGoogleGenerativeAI
5 | from langchain_groq import ChatGroq
6 | from langchain_openai import ChatOpenAI
7 | from langchain_ollama import ChatOllama
8 | from enum import Enum
9 | from pydantic import BaseModel
10 | from typing import Tuple
11 |
12 |
13 | class ModelProvider(str, Enum):
14 | """Enum for supported LLM providers"""
15 | ANTHROPIC = "Anthropic"
16 | DEEPSEEK = "DeepSeek"
17 | GEMINI = "Gemini"
18 | GROQ = "Groq"
19 | OPENAI = "OpenAI"
20 | OLLAMA = "Ollama"
21 |
22 |
23 |
24 | class LLMModel(BaseModel):
25 | """Represents an LLM model configuration"""
26 | display_name: str
27 | model_name: str
28 | provider: ModelProvider
29 |
30 | def to_choice_tuple(self) -> Tuple[str, str, str]:
31 | """Convert to format needed for questionary choices"""
32 | return (self.display_name, self.model_name, self.provider.value)
33 |
34 | def has_json_mode(self) -> bool:
35 | """Check if the model supports JSON mode"""
36 | if self.is_deepseek() or self.is_gemini():
37 | return False
38 | # Only certain Ollama models support JSON mode
39 | if self.is_ollama():
40 | return "llama3" in self.model_name or "neural-chat" in self.model_name
41 | return True
42 |
43 | def is_deepseek(self) -> bool:
44 | """Check if the model is a DeepSeek model"""
45 | return self.model_name.startswith("deepseek")
46 |
47 | def is_gemini(self) -> bool:
48 | """Check if the model is a Gemini model"""
49 | return self.model_name.startswith("gemini")
50 |
51 | def is_ollama(self) -> bool:
52 | """Check if the model is an Ollama model"""
53 | return self.provider == ModelProvider.OLLAMA
54 |
55 |
56 | # Define available models
57 | AVAILABLE_MODELS = [
58 | LLMModel(
59 | display_name="[anthropic] claude-3.5-haiku",
60 | model_name="claude-3-5-haiku-latest",
61 | provider=ModelProvider.ANTHROPIC
62 | ),
63 | LLMModel(
64 | display_name="[anthropic] claude-3.5-sonnet",
65 | model_name="claude-3-5-sonnet-latest",
66 | provider=ModelProvider.ANTHROPIC
67 | ),
68 | LLMModel(
69 | display_name="[anthropic] claude-3.7-sonnet",
70 | model_name="claude-3-7-sonnet-latest",
71 | provider=ModelProvider.ANTHROPIC
72 | ),
73 | LLMModel(
74 | display_name="[deepseek] deepseek-r1",
75 | model_name="deepseek-reasoner",
76 | provider=ModelProvider.DEEPSEEK
77 | ),
78 | LLMModel(
79 | display_name="[deepseek] deepseek-v3",
80 | model_name="deepseek-chat",
81 | provider=ModelProvider.DEEPSEEK
82 | ),
83 | LLMModel(
84 | display_name="[gemini] gemini-2.0-flash",
85 | model_name="gemini-2.0-flash",
86 | provider=ModelProvider.GEMINI
87 | ),
88 | LLMModel(
89 | display_name="[gemini] gemini-2.5-pro",
90 | model_name="gemini-2.5-pro-exp-03-25",
91 | provider=ModelProvider.GEMINI
92 | ),
93 | LLMModel(
94 | display_name="[groq] llama-4-scout-17b",
95 | model_name="meta-llama/llama-4-scout-17b-16e-instruct",
96 | provider=ModelProvider.GROQ
97 | ),
98 | LLMModel(
99 | display_name="[groq] llama-4-maverick-17b",
100 | model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
101 | provider=ModelProvider.GROQ
102 | ),
103 | LLMModel(
104 | display_name="[openai] gpt-4.5",
105 | model_name="gpt-4.5-preview",
106 | provider=ModelProvider.OPENAI
107 | ),
108 | LLMModel(
109 | display_name="[openai] gpt-4o",
110 | model_name="gpt-4o",
111 | provider=ModelProvider.OPENAI
112 | ),
113 | LLMModel(
114 | display_name="[openai] o3",
115 | model_name="o3",
116 | provider=ModelProvider.OPENAI
117 | ),
118 | LLMModel(
119 | display_name="[openai] o4-mini",
120 | model_name="o4-mini",
121 | provider=ModelProvider.OPENAI
122 | ),
123 | ]
124 |
125 | # Define Ollama models separately
126 | OLLAMA_MODELS = [
127 | LLMModel(
128 | display_name="[ollama] gemma3 (4B)",
129 | model_name="gemma3:4b",
130 | provider=ModelProvider.OLLAMA
131 | ),
132 | LLMModel(
133 | display_name="[ollama] qwen2.5 (7B)",
134 | model_name="qwen2.5",
135 | provider=ModelProvider.OLLAMA
136 | ),
137 | LLMModel(
138 | display_name="[ollama] llama3.1 (8B)",
139 | model_name="llama3.1:latest",
140 | provider=ModelProvider.OLLAMA
141 | ),
142 | LLMModel(
143 | display_name="[ollama] gemma3 (12B)",
144 | model_name="gemma3:12b",
145 | provider=ModelProvider.OLLAMA
146 | ),
147 | LLMModel(
148 | display_name="[ollama] mistral-small3.1 (24B)",
149 | model_name="mistral-small3.1",
150 | provider=ModelProvider.OLLAMA
151 | ),
152 | LLMModel(
153 | display_name="[ollama] gemma3 (27B)",
154 | model_name="gemma3:27b",
155 | provider=ModelProvider.OLLAMA
156 | ),
157 | LLMModel(
158 | display_name="[ollama] qwen2.5 (32B)",
159 | model_name="qwen2.5:32b",
160 | provider=ModelProvider.OLLAMA
161 | ),
162 | LLMModel(
163 | display_name="[ollama] llama-3.3 (70B)",
164 | model_name="llama3.3:70b-instruct-q4_0",
165 | provider=ModelProvider.OLLAMA
166 | ),
167 | ]
168 |
169 | # Create LLM_ORDER in the format expected by the UI
170 | LLM_ORDER = [model.to_choice_tuple() for model in AVAILABLE_MODELS]
171 |
172 | # Create Ollama LLM_ORDER separately
173 | OLLAMA_LLM_ORDER = [model.to_choice_tuple() for model in OLLAMA_MODELS]
174 |
175 | def get_model_info(model_name: str) -> LLMModel | None:
176 | """Get model information by model_name"""
177 | all_models = AVAILABLE_MODELS + OLLAMA_MODELS
178 | return next((model for model in all_models if model.model_name == model_name), None)
179 |
180 | def get_model(model_name: str, model_provider: ModelProvider) -> ChatAnthropic | ChatDeepSeek | ChatGoogleGenerativeAI | ChatGroq | ChatOllama | ChatOpenAI | None:
181 | if model_provider == ModelProvider.GROQ:
182 | api_key = os.getenv("GROQ_API_KEY")
183 | if not api_key:
184 | # Print error to console
185 |             print("API Key Error: Please make sure GROQ_API_KEY is set in your .env file.")
186 | raise ValueError("Groq API key not found. Please make sure GROQ_API_KEY is set in your .env file.")
187 | return ChatGroq(model=model_name, api_key=api_key)
188 | elif model_provider == ModelProvider.OPENAI:
189 | # Get and validate API key
190 | api_key = os.getenv("OPENAI_API_KEY")
191 | if not api_key:
192 | # Print error to console
193 |             print("API Key Error: Please make sure OPENAI_API_KEY is set in your .env file.")
194 | raise ValueError("OpenAI API key not found. Please make sure OPENAI_API_KEY is set in your .env file.")
195 | return ChatOpenAI(model=model_name, api_key=api_key)
196 | elif model_provider == ModelProvider.ANTHROPIC:
197 | api_key = os.getenv("ANTHROPIC_API_KEY")
198 | if not api_key:
199 |             print("API Key Error: Please make sure ANTHROPIC_API_KEY is set in your .env file.")
200 | raise ValueError("Anthropic API key not found. Please make sure ANTHROPIC_API_KEY is set in your .env file.")
201 | return ChatAnthropic(model=model_name, api_key=api_key)
202 | elif model_provider == ModelProvider.DEEPSEEK:
203 | api_key = os.getenv("DEEPSEEK_API_KEY")
204 | if not api_key:
205 |             print("API Key Error: Please make sure DEEPSEEK_API_KEY is set in your .env file.")
206 | raise ValueError("DeepSeek API key not found. Please make sure DEEPSEEK_API_KEY is set in your .env file.")
207 | return ChatDeepSeek(model=model_name, api_key=api_key)
208 | elif model_provider == ModelProvider.GEMINI:
209 | api_key = os.getenv("GOOGLE_API_KEY")
210 | if not api_key:
211 |             print("API Key Error: Please make sure GOOGLE_API_KEY is set in your .env file.")
212 | raise ValueError("Google API key not found. Please make sure GOOGLE_API_KEY is set in your .env file.")
213 | return ChatGoogleGenerativeAI(model=model_name, api_key=api_key)
214 | elif model_provider == ModelProvider.OLLAMA:
215 | # For Ollama, we use a base URL instead of an API key
216 | # Check if OLLAMA_HOST is set (for Docker on macOS)
217 | ollama_host = os.getenv("OLLAMA_HOST", "localhost")
218 | base_url = os.getenv("OLLAMA_BASE_URL", f"http://{ollama_host}:11434")
219 | return ChatOllama(
220 | model=model_name,
221 | base_url=base_url,
222 | )
--------------------------------------------------------------------------------
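
A minimal sketch of selecting and instantiating a model (the API key below is
a placeholder for illustration; in practice it comes from .env):

    import os
    from llm.models import ModelProvider, get_model, get_model_info

    os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # hypothetical key
    info = get_model_info("gpt-4o")
    llm = get_model("gpt-4o", ModelProvider.OPENAI)

    # has_json_mode() tells callers whether they can rely on native JSON output
    # or must extract JSON from free-form text themselves.
    if info and not info.has_json_mode():
        print("manual JSON extraction required")
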
/src/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from dotenv import load_dotenv
4 | from langchain_core.messages import HumanMessage
5 | from langgraph.graph import END, StateGraph
6 | from colorama import Fore, Style, init
7 | import questionary
8 | from agents.portfolio_manager import portfolio_management_agent
9 | from agents.risk_manager import risk_management_agent
10 | from graph.state import AgentState
11 | from utils.display import print_trading_output
12 | from utils.analysts import ANALYST_ORDER, get_analyst_nodes
13 | from utils.progress import progress
14 | from llm.models import LLM_ORDER, OLLAMA_LLM_ORDER, get_model_info, ModelProvider
15 | from utils.ollama import ensure_ollama_and_model
16 |
17 | import argparse
18 | from datetime import datetime
19 | from dateutil.relativedelta import relativedelta
20 | from utils.visualize import save_graph_as_png
21 | import json
22 |
23 | # Load environment variables from .env file
24 | load_dotenv()
25 |
26 | init(autoreset=True)
27 |
28 |
29 | def parse_hedge_fund_response(response):
30 | """Parses a JSON string and returns a dictionary."""
31 | try:
32 | return json.loads(response)
33 | except json.JSONDecodeError as e:
34 | print(f"JSON decoding error: {e}\nResponse: {repr(response)}")
35 | return None
36 | except TypeError as e:
37 | print(f"Invalid response type (expected string, got {type(response).__name__}): {e}")
38 | return None
39 | except Exception as e:
40 | print(f"Unexpected error while parsing response: {e}\nResponse: {repr(response)}")
41 | return None
42 |
43 |
44 | ##### Run the Hedge Fund #####
45 | def run_hedge_fund(
46 | tickers: list[str],
47 | start_date: str,
48 | end_date: str,
49 | portfolio: dict,
50 | show_reasoning: bool = False,
51 |     selected_analysts: list[str] | None = None,
52 | model_name: str = "gpt-4o",
53 | model_provider: str = "OpenAI",
54 | ):
55 | # Start progress tracking
56 | progress.start()
57 |
58 | try:
59 |         # Build a workflow for the selected analysts; create_workflow
60 |         # defaults to all analysts when none are given, so this also
61 |         # works when run_hedge_fund is imported from another module
62 |         workflow = create_workflow(selected_analysts)
63 |         agent = workflow.compile()
64 | 
65 | 
66 | final_state = agent.invoke(
67 | {
68 | "messages": [
69 | HumanMessage(
70 | content="Make trading decisions based on the provided data.",
71 | )
72 | ],
73 | "data": {
74 | "tickers": tickers,
75 | "portfolio": portfolio,
76 | "start_date": start_date,
77 | "end_date": end_date,
78 | "analyst_signals": {},
79 | },
80 | "metadata": {
81 | "show_reasoning": show_reasoning,
82 | "model_name": model_name,
83 | "model_provider": model_provider,
84 | },
85 | },
86 | )
87 |
88 | return {
89 | "decisions": parse_hedge_fund_response(final_state["messages"][-1].content),
90 | "analyst_signals": final_state["data"]["analyst_signals"],
91 | }
92 | finally:
93 | # Stop progress tracking
94 | progress.stop()
95 |
96 |
97 | def start(state: AgentState):
98 | """Initialize the workflow with the input message."""
99 | return state
100 |
101 |
102 | def create_workflow(selected_analysts=None):
103 | """Create the workflow with selected analysts."""
104 | workflow = StateGraph(AgentState)
105 | workflow.add_node("start_node", start)
106 |
107 | # Get analyst nodes from the configuration
108 | analyst_nodes = get_analyst_nodes()
109 |
110 | # Default to all analysts if none selected
111 | if selected_analysts is None:
112 | selected_analysts = list(analyst_nodes.keys())
113 | # Add selected analyst nodes
114 | for analyst_key in selected_analysts:
115 | node_name, node_func = analyst_nodes[analyst_key]
116 | workflow.add_node(node_name, node_func)
117 | workflow.add_edge("start_node", node_name)
118 |
119 | # Always add risk and portfolio management
120 | workflow.add_node("risk_management_agent", risk_management_agent)
121 | workflow.add_node("portfolio_management_agent", portfolio_management_agent)
122 |
123 | # Connect selected analysts to risk management
124 | for analyst_key in selected_analysts:
125 | node_name = analyst_nodes[analyst_key][0]
126 | workflow.add_edge(node_name, "risk_management_agent")
127 |
128 | workflow.add_edge("risk_management_agent", "portfolio_management_agent")
129 | workflow.add_edge("portfolio_management_agent", END)
130 |
131 | workflow.set_entry_point("start_node")
132 | return workflow
133 |
134 |
135 | if __name__ == "__main__":
136 | parser = argparse.ArgumentParser(description="Run the hedge fund trading system")
137 | parser.add_argument("--initial-cash", type=float, default=100000.0, help="Initial cash position. Defaults to 100000.0)")
138 | parser.add_argument("--margin-requirement", type=float, default=0.0, help="Initial margin requirement. Defaults to 0.0")
139 | parser.add_argument("--tickers", type=str, required=True, help="Comma-separated list of stock ticker symbols")
140 | parser.add_argument(
141 | "--start-date",
142 | type=str,
143 | help="Start date (YYYY-MM-DD). Defaults to 3 months before end date",
144 | )
145 | parser.add_argument("--end-date", type=str, help="End date (YYYY-MM-DD). Defaults to today")
146 | parser.add_argument("--show-reasoning", action="store_true", help="Show reasoning from each agent")
147 | parser.add_argument("--show-agent-graph", action="store_true", help="Show the agent graph")
148 | parser.add_argument("--ollama", action="store_true", help="Use Ollama for local LLM inference")
149 |
150 | args = parser.parse_args()
151 |
152 | # Parse tickers from comma-separated string
153 | tickers = [ticker.strip() for ticker in args.tickers.split(",")]
154 |
155 | # Select analysts
156 | selected_analysts = None
157 | choices = questionary.checkbox(
158 | "Select your AI analysts.",
159 | choices=[questionary.Choice(display, value=value) for display, value in ANALYST_ORDER],
160 | instruction="\n\nInstructions: \n1. Press Space to select/unselect analysts.\n2. Press 'a' to select/unselect all.\n3. Press Enter when done to run the hedge fund.\n",
161 | validate=lambda x: len(x) > 0 or "You must select at least one analyst.",
162 | style=questionary.Style(
163 | [
164 | ("checkbox-selected", "fg:green"),
165 | ("selected", "fg:green noinherit"),
166 | ("highlighted", "noinherit"),
167 | ("pointer", "noinherit"),
168 | ]
169 | ),
170 | ).ask()
171 |
172 | if not choices:
173 | print("\n\nInterrupt received. Exiting...")
174 | sys.exit(0)
175 | else:
176 | selected_analysts = choices
177 | print(f"\nSelected analysts: {', '.join(Fore.GREEN + choice.title().replace('_', ' ') + Style.RESET_ALL for choice in choices)}\n")
178 |
179 | # Select LLM model based on whether Ollama is being used
180 | model_choice = None
181 | model_provider = None
182 |
183 | if args.ollama:
184 | print(f"{Fore.CYAN}Using Ollama for local LLM inference.{Style.RESET_ALL}")
185 |
186 | # Select from Ollama-specific models
187 | model_choice = questionary.select(
188 | "Select your Ollama model:",
189 | choices=[questionary.Choice(display, value=value) for display, value, _ in OLLAMA_LLM_ORDER],
190 | style=questionary.Style(
191 | [
192 | ("selected", "fg:green bold"),
193 | ("pointer", "fg:green bold"),
194 | ("highlighted", "fg:green"),
195 | ("answer", "fg:green bold"),
196 | ]
197 | ),
198 | ).ask()
199 |
200 | if not model_choice:
201 | print("\n\nInterrupt received. Exiting...")
202 | sys.exit(0)
203 |
204 | # Ensure Ollama is installed, running, and the model is available
205 | if not ensure_ollama_and_model(model_choice):
206 | print(f"{Fore.RED}Cannot proceed without Ollama and the selected model.{Style.RESET_ALL}")
207 | sys.exit(1)
208 |
209 | model_provider = ModelProvider.OLLAMA.value
210 | print(f"\nSelected {Fore.CYAN}Ollama{Style.RESET_ALL} model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n")
211 | else:
212 | # Use the standard cloud-based LLM selection
213 | model_choice = questionary.select(
214 | "Select your LLM model:",
215 | choices=[questionary.Choice(display, value=value) for display, value, _ in LLM_ORDER],
216 | style=questionary.Style(
217 | [
218 | ("selected", "fg:green bold"),
219 | ("pointer", "fg:green bold"),
220 | ("highlighted", "fg:green"),
221 | ("answer", "fg:green bold"),
222 | ]
223 | ),
224 | ).ask()
225 |
226 | if not model_choice:
227 | print("\n\nInterrupt received. Exiting...")
228 | sys.exit(0)
229 | else:
230 | # Get model info using the helper function
231 | model_info = get_model_info(model_choice)
232 | if model_info:
233 | model_provider = model_info.provider.value
234 | print(f"\nSelected {Fore.CYAN}{model_provider}{Style.RESET_ALL} model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n")
235 | else:
236 | model_provider = "Unknown"
237 | print(f"\nSelected model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n")
238 |
239 | # Create the workflow with selected analysts
240 | workflow = create_workflow(selected_analysts)
241 | app = workflow.compile()
242 |
243 | if args.show_agent_graph:
244 | file_path = ""
245 | if selected_analysts is not None:
246 | for selected_analyst in selected_analysts:
247 | file_path += selected_analyst + "_"
248 | file_path += "graph.png"
249 | save_graph_as_png(app, file_path)
250 |
251 | # Validate dates if provided
252 | if args.start_date:
253 | try:
254 | datetime.strptime(args.start_date, "%Y-%m-%d")
255 | except ValueError:
256 | raise ValueError("Start date must be in YYYY-MM-DD format")
257 |
258 | if args.end_date:
259 | try:
260 | datetime.strptime(args.end_date, "%Y-%m-%d")
261 | except ValueError:
262 | raise ValueError("End date must be in YYYY-MM-DD format")
263 |
264 | # Set the start and end dates
265 | end_date = args.end_date or datetime.now().strftime("%Y-%m-%d")
266 | if not args.start_date:
267 | # Calculate 3 months before end_date
268 | end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
269 | start_date = (end_date_obj - relativedelta(months=3)).strftime("%Y-%m-%d")
270 | else:
271 | start_date = args.start_date
272 |
273 | # Initialize portfolio with cash amount and stock positions
274 | portfolio = {
275 | "cash": args.initial_cash, # Initial cash amount
276 | "margin_requirement": args.margin_requirement, # Initial margin requirement
277 | "margin_used": 0.0, # total margin usage across all short positions
278 | "positions": {
279 | ticker: {
280 | "long": 0, # Number of shares held long
281 | "short": 0, # Number of shares held short
282 | "long_cost_basis": 0.0, # Average cost basis for long positions
283 | "short_cost_basis": 0.0, # Average price at which shares were sold short
284 | "short_margin_used": 0.0, # Dollars of margin used for this ticker's short
285 | }
286 | for ticker in tickers
287 | },
288 | "realized_gains": {
289 | ticker: {
290 | "long": 0.0, # Realized gains from long positions
291 | "short": 0.0, # Realized gains from short positions
292 | }
293 | for ticker in tickers
294 | },
295 | }
296 |
297 | # Run the hedge fund
298 | result = run_hedge_fund(
299 | tickers=tickers,
300 | start_date=start_date,
301 | end_date=end_date,
302 | portfolio=portfolio,
303 | show_reasoning=args.show_reasoning,
304 | selected_analysts=selected_analysts,
305 | model_name=model_choice,
306 | model_provider=model_provider,
307 | )
308 | print_trading_output(result)
309 |
--------------------------------------------------------------------------------
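
run_hedge_fund can also be driven programmatically, bypassing the interactive
prompts, provided the relevant API keys are set in .env. A sketch with a
minimal single-ticker portfolio mirroring the structure built above (the
ticker and dates are illustrative):

    from main import run_hedge_fund

    portfolio = {
        "cash": 100000.0,
        "margin_requirement": 0.0,
        "margin_used": 0.0,
        "positions": {
            "AAPL": {"long": 0, "short": 0, "long_cost_basis": 0.0,
                     "short_cost_basis": 0.0, "short_margin_used": 0.0},
        },
        "realized_gains": {"AAPL": {"long": 0.0, "short": 0.0}},
    }

    result = run_hedge_fund(
        tickers=["AAPL"],
        start_date="2024-01-01",
        end_date="2024-03-31",
        portfolio=portfolio,
        selected_analysts=["valuation_analyst"],
        model_name="gpt-4o",
        model_provider="OpenAI",
    )
    print(result["decisions"])
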
/src/tools/api.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import pandas as pd
4 | import requests
5 |
6 | from data.cache import get_cache
7 | from data.models import (
8 | CompanyNews,
9 | CompanyNewsResponse,
10 | FinancialMetrics,
11 | FinancialMetricsResponse,
12 | Price,
13 | PriceResponse,
14 | LineItem,
15 | LineItemResponse,
16 | InsiderTrade,
17 | InsiderTradeResponse,
18 | CompanyFactsResponse,
19 | )
20 |
21 | # Global cache instance
22 | _cache = get_cache()
23 |
24 |
25 | def get_prices(ticker: str, start_date: str, end_date: str) -> list[Price]:
26 | """Fetch price data from cache or API."""
27 | # Check cache first
28 | if cached_data := _cache.get_prices(ticker):
29 | # Filter cached data by date range and convert to Price objects
30 | filtered_data = [Price(**price) for price in cached_data if start_date <= price["time"] <= end_date]
31 | if filtered_data:
32 | return filtered_data
33 |
34 | # If not in cache or no data in range, fetch from API
35 | headers = {}
36 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
37 | headers["X-API-KEY"] = api_key
38 |
39 | url = f"https://api.financialdatasets.ai/prices/?ticker={ticker}&interval=day&interval_multiplier=1&start_date={start_date}&end_date={end_date}"
40 | response = requests.get(url, headers=headers)
41 | if response.status_code != 200:
42 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")
43 |
44 | # Parse response with Pydantic model
45 | price_response = PriceResponse(**response.json())
46 | prices = price_response.prices
47 |
48 | if not prices:
49 | return []
50 |
51 | # Cache the results as dicts
52 | _cache.set_prices(ticker, [p.model_dump() for p in prices])
53 | return prices
54 |
55 |
56 | def get_financial_metrics(
57 | ticker: str,
58 | end_date: str,
59 | period: str = "ttm",
60 | limit: int = 10,
61 | ) -> list[FinancialMetrics]:
62 | """Fetch financial metrics from cache or API."""
63 | # Check cache first
64 | if cached_data := _cache.get_financial_metrics(ticker):
65 | # Filter cached data by date and limit
66 | filtered_data = [FinancialMetrics(**metric) for metric in cached_data if metric["report_period"] <= end_date]
67 | filtered_data.sort(key=lambda x: x.report_period, reverse=True)
68 | if filtered_data:
69 | return filtered_data[:limit]
70 |
71 | # If not in cache or insufficient data, fetch from API
72 | headers = {}
73 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
74 | headers["X-API-KEY"] = api_key
75 |
76 | url = f"https://api.financialdatasets.ai/financial-metrics/?ticker={ticker}&report_period_lte={end_date}&limit={limit}&period={period}"
77 | response = requests.get(url, headers=headers)
78 | if response.status_code != 200:
79 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")
80 |
81 | # Parse response with Pydantic model
82 | metrics_response = FinancialMetricsResponse(**response.json())
83 | # Return the FinancialMetrics objects directly instead of converting to dict
84 | financial_metrics = metrics_response.financial_metrics
85 |
86 | if not financial_metrics:
87 | return []
88 |
89 | # Cache the results as dicts
90 | _cache.set_financial_metrics(ticker, [m.model_dump() for m in financial_metrics])
91 | return financial_metrics
92 |
93 |
94 | def search_line_items(
95 | ticker: str,
96 | line_items: list[str],
97 | end_date: str,
98 | period: str = "ttm",
99 | limit: int = 10,
100 | ) -> list[LineItem]:
101 | """Fetch line items from API."""
102 |     # Line items are not cached; always fetch directly from the API
103 | headers = {}
104 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
105 | headers["X-API-KEY"] = api_key
106 |
107 | url = "https://api.financialdatasets.ai/financials/search/line-items"
108 |
109 | body = {
110 | "tickers": [ticker],
111 | "line_items": line_items,
112 | "end_date": end_date,
113 | "period": period,
114 | "limit": limit,
115 | }
116 | response = requests.post(url, headers=headers, json=body)
117 | if response.status_code != 200:
118 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")
119 | data = response.json()
120 | response_model = LineItemResponse(**data)
121 | search_results = response_model.search_results
122 | if not search_results:
123 | return []
124 |
125 |     # Trim to the requested limit
126 | return search_results[:limit]
127 |
128 |
129 | def get_insider_trades(
130 | ticker: str,
131 | end_date: str,
132 | start_date: str | None = None,
133 | limit: int = 1000,
134 | ) -> list[InsiderTrade]:
135 | """Fetch insider trades from cache or API."""
136 | # Check cache first
137 | if cached_data := _cache.get_insider_trades(ticker):
138 | # Filter cached data by date range
139 | filtered_data = [InsiderTrade(**trade) for trade in cached_data
140 | if (start_date is None or (trade.get("transaction_date") or trade["filing_date"]) >= start_date)
141 | and (trade.get("transaction_date") or trade["filing_date"]) <= end_date]
142 | filtered_data.sort(key=lambda x: x.transaction_date or x.filing_date, reverse=True)
143 | if filtered_data:
144 | return filtered_data
145 |
146 | # If not in cache or insufficient data, fetch from API
147 | headers = {}
148 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
149 | headers["X-API-KEY"] = api_key
150 |
151 | all_trades = []
152 | current_end_date = end_date
153 |
154 | while True:
155 | url = f"https://api.financialdatasets.ai/insider-trades/?ticker={ticker}&filing_date_lte={current_end_date}"
156 | if start_date:
157 | url += f"&filing_date_gte={start_date}"
158 | url += f"&limit={limit}"
159 |
160 | response = requests.get(url, headers=headers)
161 | if response.status_code != 200:
162 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")
163 |
164 | data = response.json()
165 | response_model = InsiderTradeResponse(**data)
166 | insider_trades = response_model.insider_trades
167 |
168 | if not insider_trades:
169 | break
170 |
171 | all_trades.extend(insider_trades)
172 |
173 | # Only continue pagination if we have a start_date and got a full page
174 | if not start_date or len(insider_trades) < limit:
175 | break
176 |
177 | # Update end_date to the oldest filing date from current batch for next iteration
178 | current_end_date = min(trade.filing_date for trade in insider_trades).split('T')[0]
179 |
180 | # If we've reached or passed the start_date, we can stop
181 | if current_end_date <= start_date:
182 | break
183 |
184 | if not all_trades:
185 | return []
186 |
187 | # Cache the results
188 | _cache.set_insider_trades(ticker, [trade.model_dump() for trade in all_trades])
189 | return all_trades
190 |
191 |
192 | def get_company_news(
193 | ticker: str,
194 | end_date: str,
195 | start_date: str | None = None,
196 | limit: int = 1000,
197 | ) -> list[CompanyNews]:
198 | """Fetch company news from cache or API."""
199 | # Check cache first
200 | if cached_data := _cache.get_company_news(ticker):
201 | # Filter cached data by date range
202 | filtered_data = [CompanyNews(**news) for news in cached_data
203 | if (start_date is None or news["date"] >= start_date)
204 | and news["date"] <= end_date]
205 | filtered_data.sort(key=lambda x: x.date, reverse=True)
206 | if filtered_data:
207 | return filtered_data
208 |
209 | # If not in cache or insufficient data, fetch from API
210 | headers = {}
211 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
212 | headers["X-API-KEY"] = api_key
213 |
214 | all_news = []
215 | current_end_date = end_date
216 |
217 | while True:
218 | url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={current_end_date}"
219 | if start_date:
220 | url += f"&start_date={start_date}"
221 | url += f"&limit={limit}"
222 |
223 | response = requests.get(url, headers=headers)
224 | if response.status_code != 200:
225 | raise Exception(f"Error fetching data: {ticker} - {response.status_code} - {response.text}")
226 |
227 | data = response.json()
228 | response_model = CompanyNewsResponse(**data)
229 | company_news = response_model.news
230 |
231 | if not company_news:
232 | break
233 |
234 | all_news.extend(company_news)
235 |
236 | # Only continue pagination if we have a start_date and got a full page
237 | if not start_date or len(company_news) < limit:
238 | break
239 |
240 | # Update end_date to the oldest date from current batch for next iteration
241 | current_end_date = min(news.date for news in company_news).split('T')[0]
242 |
243 | # If we've reached or passed the start_date, we can stop
244 | if current_end_date <= start_date:
245 | break
246 |
247 | if not all_news:
248 | return []
249 |
250 | # Cache the results
251 | _cache.set_company_news(ticker, [news.model_dump() for news in all_news])
252 | return all_news
253 |
254 |
255 | def get_market_cap(
256 | ticker: str,
257 | end_date: str,
258 | ) -> float | None:
259 | """Fetch market cap from the API."""
260 | # Check if end_date is today
261 | if end_date == datetime.datetime.now().strftime("%Y-%m-%d"):
262 | # Get the market cap from company facts API
263 | headers = {}
264 | if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"):
265 | headers["X-API-KEY"] = api_key
266 |
267 | url = f"https://api.financialdatasets.ai/company/facts/?ticker={ticker}"
268 | response = requests.get(url, headers=headers)
269 | if response.status_code != 200:
270 | print(f"Error fetching company facts: {ticker} - {response.status_code}")
271 | return None
272 |
273 | data = response.json()
274 | response_model = CompanyFactsResponse(**data)
275 | return response_model.company_facts.market_cap
276 |
277 | financial_metrics = get_financial_metrics(ticker, end_date)
278 | if not financial_metrics:
279 | return None
280 |
281 | market_cap = financial_metrics[0].market_cap
282 |
283 | if not market_cap:
284 | return None
285 |
286 | return market_cap
287 |
288 |
289 | def prices_to_df(prices: list[Price]) -> pd.DataFrame:
290 | """Convert prices to a DataFrame."""
291 | df = pd.DataFrame([p.model_dump() for p in prices])
292 | df["Date"] = pd.to_datetime(df["time"])
293 | df.set_index("Date", inplace=True)
294 | numeric_cols = ["open", "close", "high", "low", "volume"]
295 | for col in numeric_cols:
296 | df[col] = pd.to_numeric(df[col], errors="coerce")
297 | df.sort_index(inplace=True)
298 | return df
299 |
300 |
301 | # Convenience wrapper: fetch prices and return them as a DataFrame
302 | def get_price_data(ticker: str, start_date: str, end_date: str) -> pd.DataFrame:
303 | prices = get_prices(ticker, start_date, end_date)
304 | return prices_to_df(prices)
305 |
--------------------------------------------------------------------------------
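
A short sketch of the public helpers (requires FINANCIAL_DATASETS_API_KEY in
the environment for most tickers; the ticker and dates are illustrative).
Repeated calls for the same ticker are served from the in-memory cache where
one exists:

    from tools.api import get_financial_metrics, get_price_data

    df = get_price_data("AAPL", "2024-01-01", "2024-03-31")
    print(df[["open", "close", "volume"]].tail())

    metrics = get_financial_metrics("AAPL", end_date="2024-03-31", period="ttm", limit=4)
    if metrics:
        print(metrics[0].price_to_earnings_ratio)
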
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility modules for the application."""
4 |
--------------------------------------------------------------------------------
/src/utils/analysts.py:
--------------------------------------------------------------------------------
1 | """Constants and utilities related to analysts configuration."""
2 |
3 | from agents.ben_graham import ben_graham_agent
4 | from agents.bill_ackman import bill_ackman_agent
5 | from agents.cathie_wood import cathie_wood_agent
6 | from agents.charlie_munger import charlie_munger_agent
7 | from agents.fundamentals import fundamentals_agent
8 | from agents.michael_burry import michael_burry_agent
9 | from agents.phil_fisher import phil_fisher_agent
10 | from agents.peter_lynch import peter_lynch_agent
11 | from agents.sentiment import sentiment_agent
12 | from agents.stanley_druckenmiller import stanley_druckenmiller_agent
13 | from agents.technicals import technical_analyst_agent
14 | from agents.valuation import valuation_agent
15 | from agents.warren_buffett import warren_buffett_agent
16 |
17 | # Define analyst configuration - single source of truth
18 | ANALYST_CONFIG = {
19 | "ben_graham": {
20 | "display_name": "Ben Graham",
21 | "agent_func": ben_graham_agent,
22 | "order": 0,
23 | },
24 | "bill_ackman": {
25 | "display_name": "Bill Ackman",
26 | "agent_func": bill_ackman_agent,
27 | "order": 1,
28 | },
29 | "cathie_wood": {
30 | "display_name": "Cathie Wood",
31 | "agent_func": cathie_wood_agent,
32 | "order": 2,
33 | },
34 | "charlie_munger": {
35 | "display_name": "Charlie Munger",
36 | "agent_func": charlie_munger_agent,
37 | "order": 3,
38 | },
39 | "michael_burry": {
40 | "display_name": "Michael Burry",
41 | "agent_func": michael_burry_agent,
42 | "order": 4,
43 | },
44 | "peter_lynch": {
45 | "display_name": "Peter Lynch",
46 | "agent_func": peter_lynch_agent,
47 | "order": 5,
48 | },
49 | "phil_fisher": {
50 | "display_name": "Phil Fisher",
51 | "agent_func": phil_fisher_agent,
52 | "order": 6,
53 | },
54 | "stanley_druckenmiller": {
55 | "display_name": "Stanley Druckenmiller",
56 | "agent_func": stanley_druckenmiller_agent,
57 | "order": 7,
58 | },
59 | "warren_buffett": {
60 | "display_name": "Warren Buffett",
61 | "agent_func": warren_buffett_agent,
62 | "order": 8,
63 | },
64 | "technical_analyst": {
65 | "display_name": "Technical Analyst",
66 | "agent_func": technical_analyst_agent,
67 | "order": 9,
68 | },
69 | "fundamentals_analyst": {
70 | "display_name": "Fundamentals Analyst",
71 | "agent_func": fundamentals_agent,
72 | "order": 10,
73 | },
74 | "sentiment_analyst": {
75 | "display_name": "Sentiment Analyst",
76 | "agent_func": sentiment_agent,
77 | "order": 11,
78 | },
79 | "valuation_analyst": {
80 | "display_name": "Valuation Analyst",
81 | "agent_func": valuation_agent,
82 | "order": 12,
83 | },
84 | }
85 |
86 | # Derive ANALYST_ORDER from ANALYST_CONFIG for backwards compatibility
87 | ANALYST_ORDER = [(config["display_name"], key) for key, config in sorted(ANALYST_CONFIG.items(), key=lambda x: x[1]["order"])]
88 |
89 |
90 | def get_analyst_nodes():
91 | """Get the mapping of analyst keys to their (node_name, agent_func) tuples."""
92 | return {key: (f"{key}_agent", config["agent_func"]) for key, config in ANALYST_CONFIG.items()}
93 |
--------------------------------------------------------------------------------
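
ANALYST_CONFIG is the single registration point: adding one entry is enough
for an analyst to appear in the CLI picker (via ANALYST_ORDER) and in the
workflow (via get_analyst_nodes). A hypothetical registration, reusing an
existing agent function as a stand-in:

    from utils.analysts import ANALYST_CONFIG, get_analyst_nodes
    from agents.valuation import valuation_agent  # stand-in for a real new agent

    ANALYST_CONFIG["my_analyst"] = {
        "display_name": "My Analyst",
        "agent_func": valuation_agent,
        "order": 13,
    }

    # The workflow node name is derived automatically as "<key>_agent".
    print(get_analyst_nodes()["my_analyst"])  # ("my_analyst_agent", <function ...>)

Note that ANALYST_ORDER is computed at import time, so in practice you would
add the entry in this file rather than mutate the dict at runtime.
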
/src/utils/display.py:
--------------------------------------------------------------------------------
1 | from colorama import Fore, Style
2 | from tabulate import tabulate
3 | from .analysts import ANALYST_ORDER
4 | import os
5 | import json
6 |
7 |
8 | def sort_agent_signals(signals):
9 | """Sort agent signals in a consistent order."""
10 | # Create order mapping from ANALYST_ORDER
11 | analyst_order = {display: idx for idx, (display, _) in enumerate(ANALYST_ORDER)}
12 | analyst_order["Risk Management"] = len(ANALYST_ORDER) # Add Risk Management at the end
13 |
14 |     return sorted(signals, key=lambda x: analyst_order.get(x[0].replace(Fore.CYAN, "").replace(Style.RESET_ALL, ""), 999))  # strip ANSI codes before lookup
15 |
16 |
17 | def print_trading_output(result: dict) -> None:
18 | """
19 | Print formatted trading results with colored tables for multiple tickers.
20 |
21 | Args:
22 | result (dict): Dictionary containing decisions and analyst signals for multiple tickers
23 | """
24 | decisions = result.get("decisions")
25 | if not decisions:
26 | print(f"{Fore.RED}No trading decisions available{Style.RESET_ALL}")
27 | return
28 |
29 | # Print decisions for each ticker
30 | for ticker, decision in decisions.items():
31 | print(f"\n{Fore.WHITE}{Style.BRIGHT}Analysis for {Fore.CYAN}{ticker}{Style.RESET_ALL}")
32 | print(f"{Fore.WHITE}{Style.BRIGHT}{'=' * 50}{Style.RESET_ALL}")
33 |
34 | # Prepare analyst signals table for this ticker
35 | table_data = []
36 | for agent, signals in result.get("analyst_signals", {}).items():
37 | if ticker not in signals:
38 | continue
39 |
40 | # Skip Risk Management agent in the signals section
41 | if agent == "risk_management_agent":
42 | continue
43 |
44 | signal = signals[ticker]
45 | agent_name = agent.replace("_agent", "").replace("_", " ").title()
46 | signal_type = signal.get("signal", "").upper()
47 | confidence = signal.get("confidence", 0)
48 |
49 | signal_color = {
50 | "BULLISH": Fore.GREEN,
51 | "BEARISH": Fore.RED,
52 | "NEUTRAL": Fore.YELLOW,
53 | }.get(signal_type, Fore.WHITE)
54 |
55 | # Get reasoning if available
56 | reasoning_str = ""
57 | if "reasoning" in signal and signal["reasoning"]:
58 | reasoning = signal["reasoning"]
59 |
60 | # Handle different types of reasoning (string, dict, etc.)
61 | if isinstance(reasoning, str):
62 | reasoning_str = reasoning
63 | elif isinstance(reasoning, dict):
64 | # Convert dict to string representation
65 | reasoning_str = json.dumps(reasoning, indent=2)
66 | else:
67 | # Convert any other type to string
68 | reasoning_str = str(reasoning)
69 |
70 | # Wrap long reasoning text to make it more readable
71 | wrapped_reasoning = ""
72 | current_line = ""
73 | # Use a fixed width of 60 characters to match the table column width
74 | max_line_length = 60
75 | for word in reasoning_str.split():
76 | if len(current_line) + len(word) + 1 > max_line_length:
77 | wrapped_reasoning += current_line + "\n"
78 | current_line = word
79 | else:
80 | if current_line:
81 | current_line += " " + word
82 | else:
83 | current_line = word
84 | if current_line:
85 | wrapped_reasoning += current_line
86 |
87 | reasoning_str = wrapped_reasoning
88 |
89 | table_data.append(
90 | [
91 | f"{Fore.CYAN}{agent_name}{Style.RESET_ALL}",
92 | f"{signal_color}{signal_type}{Style.RESET_ALL}",
93 | f"{Fore.WHITE}{confidence}%{Style.RESET_ALL}",
94 | f"{Fore.WHITE}{reasoning_str}{Style.RESET_ALL}",
95 | ]
96 | )
97 |
98 | # Sort the signals according to the predefined order
99 | table_data = sort_agent_signals(table_data)
100 |
101 | print(f"\n{Fore.WHITE}{Style.BRIGHT}AGENT ANALYSIS:{Style.RESET_ALL} [{Fore.CYAN}{ticker}{Style.RESET_ALL}]")
102 | print(
103 | tabulate(
104 | table_data,
105 | headers=[f"{Fore.WHITE}Agent", "Signal", "Confidence", "Reasoning"],
106 | tablefmt="grid",
107 | colalign=("left", "center", "right", "left"),
108 | )
109 | )
110 |
111 | # Print Trading Decision Table
112 | action = decision.get("action", "").upper()
113 | action_color = {
114 | "BUY": Fore.GREEN,
115 | "SELL": Fore.RED,
116 | "HOLD": Fore.YELLOW,
117 | "COVER": Fore.GREEN,
118 | "SHORT": Fore.RED,
119 | }.get(action, Fore.WHITE)
120 |
121 | # Get reasoning and format it
122 | reasoning = decision.get("reasoning", "")
123 | # Wrap long reasoning text to make it more readable
124 | wrapped_reasoning = ""
125 | if reasoning:
126 | current_line = ""
127 | # Use a fixed width of 60 characters to match the table column width
128 | max_line_length = 60
129 | for word in reasoning.split():
130 | if len(current_line) + len(word) + 1 > max_line_length:
131 | wrapped_reasoning += current_line + "\n"
132 | current_line = word
133 | else:
134 | if current_line:
135 | current_line += " " + word
136 | else:
137 | current_line = word
138 | if current_line:
139 | wrapped_reasoning += current_line
140 |
141 | decision_data = [
142 | ["Action", f"{action_color}{action}{Style.RESET_ALL}"],
143 | ["Quantity", f"{action_color}{decision.get('quantity')}{Style.RESET_ALL}"],
144 | [
145 | "Confidence",
146 | f"{Fore.WHITE}{decision.get('confidence'):.1f}%{Style.RESET_ALL}",
147 | ],
148 | ["Reasoning", f"{Fore.WHITE}{wrapped_reasoning}{Style.RESET_ALL}"],
149 | ]
150 |
151 | print(f"\n{Fore.WHITE}{Style.BRIGHT}TRADING DECISION:{Style.RESET_ALL} [{Fore.CYAN}{ticker}{Style.RESET_ALL}]")
152 | print(tabulate(decision_data, tablefmt="grid", colalign=("left", "left")))
153 |
154 | # Print Portfolio Summary
155 | print(f"\n{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY:{Style.RESET_ALL}")
156 | portfolio_data = []
157 |
158 | # Extract portfolio manager reasoning (common for all tickers)
159 | portfolio_manager_reasoning = None
160 | for ticker, decision in decisions.items():
161 | if decision.get("reasoning"):
162 | portfolio_manager_reasoning = decision.get("reasoning")
163 | break
164 |
165 | for ticker, decision in decisions.items():
166 | action = decision.get("action", "").upper()
167 | action_color = {
168 | "BUY": Fore.GREEN,
169 | "SELL": Fore.RED,
170 | "HOLD": Fore.YELLOW,
171 | "COVER": Fore.GREEN,
172 | "SHORT": Fore.RED,
173 | }.get(action, Fore.WHITE)
174 | portfolio_data.append(
175 | [
176 | f"{Fore.CYAN}{ticker}{Style.RESET_ALL}",
177 | f"{action_color}{action}{Style.RESET_ALL}",
178 | f"{action_color}{decision.get('quantity')}{Style.RESET_ALL}",
179 | f"{Fore.WHITE}{decision.get('confidence'):.1f}%{Style.RESET_ALL}",
180 | ]
181 | )
182 |
183 | headers = [f"{Fore.WHITE}Ticker", "Action", "Quantity", "Confidence"]
184 |
185 | # Print the portfolio summary table
186 | print(
187 | tabulate(
188 | portfolio_data,
189 | headers=headers,
190 | tablefmt="grid",
191 | colalign=("left", "center", "right", "right"),
192 | )
193 | )
194 |
195 | # Print Portfolio Manager's reasoning if available
196 | if portfolio_manager_reasoning:
197 | # Handle different types of reasoning (string, dict, etc.)
198 | reasoning_str = ""
199 | if isinstance(portfolio_manager_reasoning, str):
200 | reasoning_str = portfolio_manager_reasoning
201 | elif isinstance(portfolio_manager_reasoning, dict):
202 | # Convert dict to string representation
203 | reasoning_str = json.dumps(portfolio_manager_reasoning, indent=2)
204 | else:
205 | # Convert any other type to string
206 | reasoning_str = str(portfolio_manager_reasoning)
207 |
208 | # Wrap long reasoning text to make it more readable
209 | wrapped_reasoning = ""
210 | current_line = ""
211 | # Use a fixed width of 60 characters to match the table column width
212 | max_line_length = 60
213 | for word in reasoning_str.split():
214 | if len(current_line) + len(word) + 1 > max_line_length:
215 | wrapped_reasoning += current_line + "\n"
216 | current_line = word
217 | else:
218 | if current_line:
219 | current_line += " " + word
220 | else:
221 | current_line = word
222 | if current_line:
223 | wrapped_reasoning += current_line
224 |
225 | print(f"\n{Fore.WHITE}{Style.BRIGHT}Portfolio Strategy:{Style.RESET_ALL}")
226 | print(f"{Fore.CYAN}{wrapped_reasoning}{Style.RESET_ALL}")
227 |
228 |
229 | def print_backtest_results(table_rows: list) -> None:
230 | """Print the backtest results in a nicely formatted table"""
231 | # Clear the screen
232 | os.system("cls" if os.name == "nt" else "clear")
233 |
234 | # Split rows into ticker rows and summary rows
235 | ticker_rows = []
236 | summary_rows = []
237 |
238 | for row in table_rows:
239 | if isinstance(row[1], str) and "PORTFOLIO SUMMARY" in row[1]:
240 | summary_rows.append(row)
241 | else:
242 | ticker_rows.append(row)
243 |
244 |
245 | # Display latest portfolio summary
246 | if summary_rows:
247 | latest_summary = summary_rows[-1]
248 | print(f"\n{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY:{Style.RESET_ALL}")
249 |
250 | # Extract values and remove commas before converting to float
251 | cash_str = latest_summary[7].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "")
252 | position_str = latest_summary[6].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "")
253 | total_str = latest_summary[8].split("$")[1].split(Style.RESET_ALL)[0].replace(",", "")
254 |
255 | print(f"Cash Balance: {Fore.CYAN}${float(cash_str):,.2f}{Style.RESET_ALL}")
256 | print(f"Total Position Value: {Fore.YELLOW}${float(position_str):,.2f}{Style.RESET_ALL}")
257 | print(f"Total Value: {Fore.WHITE}${float(total_str):,.2f}{Style.RESET_ALL}")
258 | print(f"Return: {latest_summary[9]}")
259 |
260 | # Display performance metrics if available
261 | if latest_summary[10]: # Sharpe ratio
262 | print(f"Sharpe Ratio: {latest_summary[10]}")
263 | if latest_summary[11]: # Sortino ratio
264 | print(f"Sortino Ratio: {latest_summary[11]}")
265 | if latest_summary[12]: # Max drawdown
266 | print(f"Max Drawdown: {latest_summary[12]}")
267 |
268 | # Add vertical spacing
269 | print("\n" * 2)
270 |
271 | # Print the table with just ticker rows
272 | print(
273 | tabulate(
274 | ticker_rows,
275 | headers=[
276 | "Date",
277 | "Ticker",
278 | "Action",
279 | "Quantity",
280 | "Price",
281 | "Shares",
282 | "Position Value",
283 | "Bullish",
284 | "Bearish",
285 | "Neutral",
286 | ],
287 | tablefmt="grid",
288 | colalign=(
289 | "left", # Date
290 | "left", # Ticker
291 | "center", # Action
292 | "right", # Quantity
293 | "right", # Price
294 | "right", # Shares
295 | "right", # Position Value
296 | "right", # Bullish
297 | "right", # Bearish
298 | "right", # Neutral
299 | ),
300 | )
301 | )
302 |
303 | # Add vertical spacing
304 | print("\n" * 4)
305 |
306 |
307 | def format_backtest_row(
308 | date: str,
309 | ticker: str,
310 | action: str,
311 | quantity: float,
312 | price: float,
313 | shares_owned: float,
314 | position_value: float,
315 | bullish_count: int,
316 | bearish_count: int,
317 | neutral_count: int,
318 | is_summary: bool = False,
319 | total_value: float = None,
320 | return_pct: float = None,
321 | cash_balance: float = None,
322 | total_position_value: float = None,
323 | sharpe_ratio: float = None,
324 | sortino_ratio: float = None,
325 | max_drawdown: float = None,
326 | ) -> list:
327 | """Format a row for the backtest results table"""
328 | # Color the action
329 | action_color = {
330 | "BUY": Fore.GREEN,
331 | "COVER": Fore.GREEN,
332 | "SELL": Fore.RED,
333 | "SHORT": Fore.RED,
334 | "HOLD": Fore.WHITE,
335 | }.get(action.upper(), Fore.WHITE)
336 |
337 | if is_summary:
338 | return_color = Fore.GREEN if return_pct >= 0 else Fore.RED
339 | return [
340 | date,
341 | f"{Fore.WHITE}{Style.BRIGHT}PORTFOLIO SUMMARY{Style.RESET_ALL}",
342 | "", # Action
343 | "", # Quantity
344 | "", # Price
345 | "", # Shares
346 | f"{Fore.YELLOW}${total_position_value:,.2f}{Style.RESET_ALL}", # Total Position Value
347 | f"{Fore.CYAN}${cash_balance:,.2f}{Style.RESET_ALL}", # Cash Balance
348 | f"{Fore.WHITE}${total_value:,.2f}{Style.RESET_ALL}", # Total Value
349 | f"{return_color}{return_pct:+.2f}%{Style.RESET_ALL}", # Return
350 | f"{Fore.YELLOW}{sharpe_ratio:.2f}{Style.RESET_ALL}" if sharpe_ratio is not None else "", # Sharpe Ratio
351 | f"{Fore.YELLOW}{sortino_ratio:.2f}{Style.RESET_ALL}" if sortino_ratio is not None else "", # Sortino Ratio
352 | f"{Fore.RED}{abs(max_drawdown):.2f}%{Style.RESET_ALL}" if max_drawdown is not None else "", # Max Drawdown
353 | ]
354 | else:
355 | return [
356 | date,
357 | f"{Fore.CYAN}{ticker}{Style.RESET_ALL}",
358 | f"{action_color}{action.upper()}{Style.RESET_ALL}",
359 | f"{action_color}{quantity:,.0f}{Style.RESET_ALL}",
360 | f"{Fore.WHITE}{price:,.2f}{Style.RESET_ALL}",
361 | f"{Fore.WHITE}{shares_owned:,.0f}{Style.RESET_ALL}",
362 | f"{Fore.YELLOW}{position_value:,.2f}{Style.RESET_ALL}",
363 | f"{Fore.GREEN}{bullish_count}{Style.RESET_ALL}",
364 | f"{Fore.RED}{bearish_count}{Style.RESET_ALL}",
365 | f"{Fore.BLUE}{neutral_count}{Style.RESET_ALL}",
366 | ]
367 |
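368 | # Illustrative only (not part of the original module): a minimal sketch of how
369 | # format_backtest_row and print_backtest_results fit together. The ticker and
370 | # all numbers below are made up for the demo.
371 | if __name__ == "__main__":
372 |     demo_rows = [
373 |         format_backtest_row(
374 |             date="2024-01-02",
375 |             ticker="AAPL",
376 |             action="buy",
377 |             quantity=10,
378 |             price=185.50,
379 |             shares_owned=10,
380 |             position_value=1855.00,
381 |             bullish_count=5,
382 |             bearish_count=2,
383 |             neutral_count=1,
384 |         )
385 |     ]
386 |     print_backtest_results(demo_rows)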
--------------------------------------------------------------------------------
/src/utils/docker.py:
--------------------------------------------------------------------------------
1 | """Utilities for working with Ollama models in Docker environments"""
2 |
3 | import requests
4 | import time
5 | from colorama import Fore, Style
6 | import questionary
7 |
8 | def ensure_ollama_and_model(model_name: str, ollama_url: str) -> bool:
9 | """Ensure the Ollama model is available in a Docker environment."""
10 | print(f"{Fore.CYAN}Docker environment detected.{Style.RESET_ALL}")
11 |
12 | # Step 1: Check if Ollama service is available
13 | if not is_ollama_available(ollama_url):
14 | return False
15 |
16 | # Step 2: Check if model is already available
17 | available_models = get_available_models(ollama_url)
18 | if model_name in available_models:
19 | print(f"{Fore.GREEN}Model {model_name} is available in the Docker Ollama container.{Style.RESET_ALL}")
20 | return True
21 |
22 | # Step 3: Model not available - ask if user wants to download
23 | print(f"{Fore.YELLOW}Model {model_name} is not available in the Docker Ollama container.{Style.RESET_ALL}")
24 |
25 | if not questionary.confirm(f"Do you want to download {model_name}?").ask():
26 | print(f"{Fore.RED}Cannot proceed without the model.{Style.RESET_ALL}")
27 | return False
28 |
29 | # Step 4: Download the model
30 | return download_model(model_name, ollama_url)
31 |
32 |
33 | def is_ollama_available(ollama_url: str) -> bool:
34 | """Check if Ollama service is available in Docker environment."""
35 | try:
36 | response = requests.get(f"{ollama_url}/api/version", timeout=5)
37 | if response.status_code == 200:
38 | return True
39 |
40 | print(f"{Fore.RED}Cannot connect to Ollama service at {ollama_url}.{Style.RESET_ALL}")
41 | print(f"{Fore.YELLOW}Make sure the Ollama service is running in your Docker environment.{Style.RESET_ALL}")
42 | return False
43 | except requests.RequestException as e:
44 | print(f"{Fore.RED}Error connecting to Ollama service: {e}{Style.RESET_ALL}")
45 | return False
46 |
47 |
48 | def get_available_models(ollama_url: str) -> list:
49 | """Get list of available models in Docker environment."""
50 | try:
51 | response = requests.get(f"{ollama_url}/api/tags", timeout=5)
52 | if response.status_code == 200:
53 | models = response.json().get("models", [])
54 | return [m["name"] for m in models]
55 |
56 | print(f"{Fore.RED}Failed to get available models from Ollama service. Status code: {response.status_code}{Style.RESET_ALL}")
57 | return []
58 | except requests.RequestException as e:
59 | print(f"{Fore.RED}Error getting available models: {e}{Style.RESET_ALL}")
60 | return []
61 |
62 |
63 | def download_model(model_name: str, ollama_url: str) -> bool:
64 | """Download a model in Docker environment."""
65 | print(f"{Fore.YELLOW}Downloading model {model_name} to the Docker Ollama container...{Style.RESET_ALL}")
66 | print(f"{Fore.CYAN}This may take some time. Please be patient.{Style.RESET_ALL}")
67 |
68 | # Step 1: Initiate the download
69 | try:
70 | response = requests.post(f"{ollama_url}/api/pull", json={"name": model_name}, timeout=10)
71 | if response.status_code != 200:
72 | print(f"{Fore.RED}Failed to initiate model download. Status code: {response.status_code}{Style.RESET_ALL}")
73 | if response.text:
74 | print(f"{Fore.RED}Error: {response.text}{Style.RESET_ALL}")
75 | return False
76 | except requests.RequestException as e:
77 | print(f"{Fore.RED}Error initiating download request: {e}{Style.RESET_ALL}")
78 | return False
79 |
80 | # Step 2: Monitor the download progress
81 | print(f"{Fore.CYAN}Download initiated. Checking periodically for completion...{Style.RESET_ALL}")
82 |
83 | total_wait_time = 0
84 | max_wait_time = 1800 # 30 minutes max wait
85 | check_interval = 10 # Check every 10 seconds
86 |
87 | while total_wait_time < max_wait_time:
88 | # Check if the model has been downloaded
89 | available_models = get_available_models(ollama_url)
90 | if model_name in available_models:
91 | print(f"{Fore.GREEN}Model {model_name} downloaded successfully.{Style.RESET_ALL}")
92 | return True
93 |
94 | # Wait before checking again
95 | time.sleep(check_interval)
96 | total_wait_time += check_interval
97 |
98 | # Print a status message every minute
99 | if total_wait_time % 60 == 0:
100 | minutes = total_wait_time // 60
101 | print(f"{Fore.CYAN}Download in progress... ({minutes} minute{'s' if minutes != 1 else ''} elapsed){Style.RESET_ALL}")
102 |
103 | # If we get here, we've timed out
104 | print(f"{Fore.RED}Timed out waiting for model download to complete after {max_wait_time // 60} minutes.{Style.RESET_ALL}")
105 | return False
106 |
107 |
108 | def delete_model(model_name: str, ollama_url: str) -> bool:
109 | """Delete a model in Docker environment."""
110 | print(f"{Fore.YELLOW}Deleting model {model_name} from Docker container...{Style.RESET_ALL}")
111 |
112 | try:
113 | response = requests.delete(f"{ollama_url}/api/delete", json={"name": model_name}, timeout=10)
114 | if response.status_code == 200:
115 | print(f"{Fore.GREEN}Model {model_name} deleted successfully.{Style.RESET_ALL}")
116 | return True
117 | else:
118 | print(f"{Fore.RED}Failed to delete model. Status code: {response.status_code}{Style.RESET_ALL}")
119 | if response.text:
120 | print(f"{Fore.RED}Error: {response.text}{Style.RESET_ALL}")
121 | return False
122 | except requests.RequestException as e:
123 | print(f"{Fore.RED}Error deleting model: {e}{Style.RESET_ALL}")
124 | return False
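125 | 
126 | 
127 | # Illustrative only (not part of the original module): a minimal sketch of the
128 | # intended call sequence. "http://ollama:11434" mirrors the default in-Docker
129 | # URL used by src/utils/ollama.py; "llama3" is just an example model name.
130 | if __name__ == "__main__":
131 |     url = "http://ollama:11434"
132 |     if ensure_ollama_and_model("llama3", url):
133 |         print(f"Models now available: {get_available_models(url)}")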
--------------------------------------------------------------------------------
/src/utils/llm.py:
--------------------------------------------------------------------------------
1 | """Helper functions for LLM"""
2 |
3 | import json
4 | from typing import TypeVar, Type, Optional, Any
5 | from pydantic import BaseModel
6 | from utils.progress import progress
7 |
8 | T = TypeVar('T', bound=BaseModel)
9 |
10 | def call_llm(
11 | prompt: Any,
12 | model_name: str,
13 | model_provider: str,
14 | pydantic_model: Type[T],
15 | agent_name: Optional[str] = None,
16 | max_retries: int = 3,
17 |     default_factory=None,
18 | ) -> T:
19 | """
20 | Makes an LLM call with retry logic, handling both JSON supported and non-JSON supported models.
21 |
22 | Args:
23 | prompt: The prompt to send to the LLM
24 | model_name: Name of the model to use
25 | model_provider: Provider of the model
26 | pydantic_model: The Pydantic model class to structure the output
27 | agent_name: Optional name of the agent for progress updates
28 | max_retries: Maximum number of retries (default: 3)
29 | default_factory: Optional factory function to create default response on failure
30 |
31 | Returns:
32 | An instance of the specified Pydantic model
33 | """
34 | from llm.models import get_model, get_model_info
35 |
36 | model_info = get_model_info(model_name)
37 | llm = get_model(model_name, model_provider)
38 |
39 |     # For models that support JSON mode (or when model info is unknown), use structured output
40 |     if not model_info or model_info.has_json_mode():
41 | llm = llm.with_structured_output(
42 | pydantic_model,
43 | method="json_mode",
44 | )
45 |
46 | # Call the LLM with retries
47 | for attempt in range(max_retries):
48 | try:
49 | # Call the LLM
50 | result = llm.invoke(prompt)
51 |
52 |             # For models without JSON mode, extract and parse the JSON from the raw text
53 | if model_info and not model_info.has_json_mode():
54 | parsed_result = extract_json_from_response(result.content)
55 | if parsed_result:
56 | return pydantic_model(**parsed_result)
57 | else:
58 | return result
59 |
60 | except Exception as e:
61 | if agent_name:
62 | progress.update_status(agent_name, None, f"Error - retry {attempt + 1}/{max_retries}")
63 |
64 | if attempt == max_retries - 1:
65 | print(f"Error in LLM call after {max_retries} attempts: {e}")
66 | # Use default_factory if provided, otherwise create a basic default
67 | if default_factory:
68 | return default_factory()
69 | return create_default_response(pydantic_model)
70 |
71 |     # Reached when a non-JSON-mode model never returned parseable JSON in any attempt
72 | return create_default_response(pydantic_model)
73 |
74 | def create_default_response(model_class: Type[T]) -> T:
75 | """Creates a safe default response based on the model's fields."""
76 | default_values = {}
77 | for field_name, field in model_class.model_fields.items():
78 | if field.annotation == str:
79 | default_values[field_name] = "Error in analysis, using default"
80 | elif field.annotation == float:
81 | default_values[field_name] = 0.0
82 | elif field.annotation == int:
83 | default_values[field_name] = 0
84 | elif hasattr(field.annotation, "__origin__") and field.annotation.__origin__ == dict:
85 | default_values[field_name] = {}
86 | else:
87 | # For other types (like Literal), try to use the first allowed value
88 | if hasattr(field.annotation, "__args__"):
89 | default_values[field_name] = field.annotation.__args__[0]
90 | else:
91 | default_values[field_name] = None
92 |
93 | return model_class(**default_values)
94 |
95 | def extract_json_from_response(content: str) -> Optional[dict]:
96 | """Extracts JSON from markdown-formatted response."""
97 | try:
98 | json_start = content.find("```json")
99 | if json_start != -1:
100 | json_text = content[json_start + 7:] # Skip past ```json
101 | json_end = json_text.find("```")
102 | if json_end != -1:
103 | json_text = json_text[:json_end].strip()
104 | return json.loads(json_text)
105 | except Exception as e:
106 | print(f"Error extracting JSON from response: {e}")
107 | return None
108 |
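109 | # Illustrative only (not part of the original module): a minimal sketch of how
110 | # call_llm is meant to be used. The model/provider strings are examples; the
111 | # exact values accepted depend on llm.models.get_model.
112 | if __name__ == "__main__":
113 |     class TradingSignal(BaseModel):
114 |         signal: str
115 |         confidence: float
116 | 
117 |     result = call_llm(
118 |         prompt="Respond with JSON containing 'signal' and 'confidence' for AAPL.",
119 |         model_name="gpt-4o",
120 |         model_provider="OpenAI",
121 |         pydantic_model=TradingSignal,
122 |         max_retries=2,
123 |         default_factory=lambda: TradingSignal(signal="neutral", confidence=0.0),
124 |     )
125 |     print(result)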
--------------------------------------------------------------------------------
/src/utils/ollama.py:
--------------------------------------------------------------------------------
1 | """Utilities for working with Ollama models"""
2 |
3 | import platform
4 | import subprocess
5 | import requests
6 | import time
7 | from typing import List
8 | import questionary
9 | from colorama import Fore, Style
10 | import os
11 | from . import docker
12 |
13 | # Constants
14 | OLLAMA_SERVER_URL = "http://localhost:11434"
15 | OLLAMA_API_MODELS_ENDPOINT = f"{OLLAMA_SERVER_URL}/api/tags"
16 | OLLAMA_DOWNLOAD_URL = {"darwin": "https://ollama.com/download/darwin", "windows": "https://ollama.com/download/windows", "linux": "https://ollama.com/download/linux"}
17 | INSTALLATION_INSTRUCTIONS = {"darwin": "curl -fsSL https://ollama.com/install.sh | sh", "windows": "# Download from https://ollama.com/download/windows and run the installer", "linux": "curl -fsSL https://ollama.com/install.sh | sh"}
18 |
19 |
20 | def is_ollama_installed() -> bool:
21 | """Check if Ollama is installed on the system."""
22 | system = platform.system().lower()
23 |
24 | if system == "darwin" or system == "linux": # macOS or Linux
25 | try:
26 | result = subprocess.run(["which", "ollama"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
27 | return result.returncode == 0
28 | except Exception:
29 | return False
30 | elif system == "windows": # Windows
31 | try:
32 | result = subprocess.run(["where", "ollama"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True)
33 | return result.returncode == 0
34 | except Exception:
35 | return False
36 | else:
37 | return False # Unsupported OS
38 |
39 |
40 | def is_ollama_server_running() -> bool:
41 | """Check if the Ollama server is running."""
42 | try:
43 | response = requests.get(OLLAMA_API_MODELS_ENDPOINT, timeout=2)
44 | return response.status_code == 200
45 | except requests.RequestException:
46 | return False
47 |
48 |
49 | def get_locally_available_models() -> List[str]:
50 | """Get a list of models that are already downloaded locally."""
51 | if not is_ollama_server_running():
52 | return []
53 |
54 | try:
55 | response = requests.get(OLLAMA_API_MODELS_ENDPOINT, timeout=5)
56 | if response.status_code == 200:
57 | data = response.json()
58 | return [model["name"] for model in data["models"]] if "models" in data else []
59 | return []
60 | except requests.RequestException:
61 | return []
62 |
63 |
64 | def start_ollama_server() -> bool:
65 | """Start the Ollama server if it's not already running."""
66 | if is_ollama_server_running():
67 | print(f"{Fore.GREEN}Ollama server is already running.{Style.RESET_ALL}")
68 | return True
69 |
70 | system = platform.system().lower()
71 |
72 | try:
73 | if system == "darwin" or system == "linux": # macOS or Linux
74 | subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
75 | elif system == "windows": # Windows
76 | subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
77 | else:
78 | print(f"{Fore.RED}Unsupported operating system: {system}{Style.RESET_ALL}")
79 | return False
80 |
81 | # Wait for server to start
82 | for _ in range(10): # Try for 10 seconds
83 | if is_ollama_server_running():
84 | print(f"{Fore.GREEN}Ollama server started successfully.{Style.RESET_ALL}")
85 | return True
86 | time.sleep(1)
87 |
88 | print(f"{Fore.RED}Failed to start Ollama server. Timed out waiting for server to become available.{Style.RESET_ALL}")
89 | return False
90 | except Exception as e:
91 | print(f"{Fore.RED}Error starting Ollama server: {e}{Style.RESET_ALL}")
92 | return False
93 |
94 |
95 | def install_ollama() -> bool:
96 | """Install Ollama on the system."""
97 | system = platform.system().lower()
98 | if system not in OLLAMA_DOWNLOAD_URL:
99 | print(f"{Fore.RED}Unsupported operating system for automatic installation: {system}{Style.RESET_ALL}")
100 | print(f"Please visit https://ollama.com/download to install Ollama manually.")
101 | return False
102 |
103 | if system == "darwin": # macOS
104 | print(f"{Fore.YELLOW}Ollama for Mac is available as an application download.{Style.RESET_ALL}")
105 |
106 | # Default to offering the app download first for macOS users
107 | if questionary.confirm("Would you like to download the Ollama application?", default=True).ask():
108 | try:
109 | import webbrowser
110 |
111 | webbrowser.open(OLLAMA_DOWNLOAD_URL["darwin"])
112 | print(f"{Fore.YELLOW}Please download and install the application, then restart this program.{Style.RESET_ALL}")
113 | print(f"{Fore.CYAN}After installation, you may need to open the Ollama app once before continuing.{Style.RESET_ALL}")
114 |
115 | # Ask if they want to try continuing after installation
116 | if questionary.confirm("Have you installed the Ollama app and opened it at least once?", default=False).ask():
117 | # Check if it's now installed
118 | if is_ollama_installed() and start_ollama_server():
119 | print(f"{Fore.GREEN}Ollama is now properly installed and running!{Style.RESET_ALL}")
120 | return True
121 | else:
122 | print(f"{Fore.RED}Ollama installation not detected. Please restart this application after installing Ollama.{Style.RESET_ALL}")
123 | return False
124 | return False
125 | except Exception as e:
126 | print(f"{Fore.RED}Failed to open browser: {e}{Style.RESET_ALL}")
127 | return False
128 | else:
129 | # Only offer command-line installation as a fallback for advanced users
130 | if questionary.confirm("Would you like to try the command-line installation instead? (For advanced users)", default=False).ask():
131 | print(f"{Fore.YELLOW}Attempting command-line installation...{Style.RESET_ALL}")
132 | try:
133 | install_process = subprocess.run(["bash", "-c", "curl -fsSL https://ollama.com/install.sh | sh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
134 |
135 | if install_process.returncode == 0:
136 | print(f"{Fore.GREEN}Ollama installed successfully via command line.{Style.RESET_ALL}")
137 | return True
138 | else:
139 | print(f"{Fore.RED}Command-line installation failed. Please use the app download method instead.{Style.RESET_ALL}")
140 | return False
141 | except Exception as e:
142 | print(f"{Fore.RED}Error during command-line installation: {e}{Style.RESET_ALL}")
143 | return False
144 | return False
145 | elif system == "linux": # Linux
146 | print(f"{Fore.YELLOW}Installing Ollama...{Style.RESET_ALL}")
147 | try:
148 | # Run the installation command as a single command
149 | install_process = subprocess.run(["bash", "-c", "curl -fsSL https://ollama.com/install.sh | sh"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
150 |
151 | if install_process.returncode == 0:
152 | print(f"{Fore.GREEN}Ollama installed successfully.{Style.RESET_ALL}")
153 | return True
154 | else:
155 | print(f"{Fore.RED}Failed to install Ollama. Error: {install_process.stderr}{Style.RESET_ALL}")
156 | return False
157 | except Exception as e:
158 | print(f"{Fore.RED}Error during Ollama installation: {e}{Style.RESET_ALL}")
159 | return False
160 | elif system == "windows": # Windows
161 | print(f"{Fore.YELLOW}Automatic installation on Windows is not supported.{Style.RESET_ALL}")
162 | print(f"Please download and install Ollama from: {OLLAMA_DOWNLOAD_URL['windows']}")
163 |
164 | # Ask if they want to open the download page
165 | if questionary.confirm("Do you want to open the Ollama download page in your browser?").ask():
166 | try:
167 | import webbrowser
168 |
169 | webbrowser.open(OLLAMA_DOWNLOAD_URL["windows"])
170 | print(f"{Fore.YELLOW}After installation, please restart this application.{Style.RESET_ALL}")
171 |
172 | # Ask if they want to try continuing after installation
173 | if questionary.confirm("Have you installed Ollama?", default=False).ask():
174 | # Check if it's now installed
175 | if is_ollama_installed() and start_ollama_server():
176 | print(f"{Fore.GREEN}Ollama is now properly installed and running!{Style.RESET_ALL}")
177 | return True
178 | else:
179 | print(f"{Fore.RED}Ollama installation not detected. Please restart this application after installing Ollama.{Style.RESET_ALL}")
180 | return False
181 | except Exception as e:
182 | print(f"{Fore.RED}Failed to open browser: {e}{Style.RESET_ALL}")
183 | return False
184 |
185 | return False
186 |
187 |
188 | def download_model(model_name: str) -> bool:
189 | """Download an Ollama model."""
190 | if not is_ollama_server_running():
191 | if not start_ollama_server():
192 | return False
193 |
194 | print(f"{Fore.YELLOW}Downloading model {model_name}...{Style.RESET_ALL}")
195 | print(f"{Fore.CYAN}This may take a while depending on your internet speed and the model size.{Style.RESET_ALL}")
196 | print(f"{Fore.CYAN}The download is happening in the background. Please be patient...{Style.RESET_ALL}")
197 |
198 | try:
199 | # Use the Ollama CLI to download the model
200 |         process = subprocess.Popen(["ollama", "pull", model_name], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)  # stderr merged into stdout to capture all output; bufsize=1 line-buffers in text mode
201 |
202 | # Show some progress to the user
203 | print(f"{Fore.CYAN}Download progress:{Style.RESET_ALL}")
204 |
205 | # For tracking progress
206 | last_percentage = 0
207 | last_phase = ""
208 | bar_length = 40
209 |
210 | while True:
211 | output = process.stdout.readline()
212 | if output == "" and process.poll() is not None:
213 | break
214 | if output:
215 | output = output.strip()
216 | # Try to extract percentage information using a more lenient approach
217 | percentage = None
218 | current_phase = None
219 |
220 | # Example patterns in Ollama output:
221 | # "downloading: 23.45 MB / 42.19 MB [================>-------------] 55.59%"
222 | # "downloading model: 76%"
223 | # "pulling manifest: 100%"
224 |
225 | # Check for percentage in the output
226 | import re
227 |
228 | percentage_match = re.search(r"(\d+(\.\d+)?)%", output)
229 | if percentage_match:
230 | try:
231 | percentage = float(percentage_match.group(1))
232 | except ValueError:
233 | percentage = None
234 |
235 | # Try to determine the current phase (downloading, extracting, etc.)
236 | phase_match = re.search(r"^([a-zA-Z\s]+):", output)
237 | if phase_match:
238 | current_phase = phase_match.group(1).strip()
239 |
240 | # If we found a percentage, display a progress bar
241 | if percentage is not None:
242 | # Only update if there's a significant change (avoid flickering)
243 | if abs(percentage - last_percentage) >= 1 or (current_phase and current_phase != last_phase):
244 | last_percentage = percentage
245 | if current_phase:
246 | last_phase = current_phase
247 |
248 | # Create a progress bar
249 | filled_length = int(bar_length * percentage / 100)
250 | bar = "█" * filled_length + "░" * (bar_length - filled_length)
251 |
252 | # Build the status line with the phase if available
253 | phase_display = f"{Fore.CYAN}{last_phase.capitalize()}{Style.RESET_ALL}: " if last_phase else ""
254 | status_line = f"\r{phase_display}{Fore.GREEN}{bar}{Style.RESET_ALL} {Fore.YELLOW}{percentage:.1f}%{Style.RESET_ALL}"
255 |
256 | # Print the status line without a newline to update in place
257 | print(status_line, end="", flush=True)
258 | else:
259 | # If we couldn't extract a percentage but have identifiable output
260 | if "download" in output.lower() or "extract" in output.lower() or "pulling" in output.lower():
261 | # Don't print a newline for percentage updates
262 | if "%" in output:
263 | print(f"\r{Fore.GREEN}{output}{Style.RESET_ALL}", end="", flush=True)
264 | else:
265 | print(f"{Fore.GREEN}{output}{Style.RESET_ALL}")
266 |
267 | # Wait for the process to finish
268 | return_code = process.wait()
269 |
270 | # Ensure we print a newline after the progress bar
271 | print()
272 |
273 | if return_code == 0:
274 | print(f"{Fore.GREEN}Model {model_name} downloaded successfully!{Style.RESET_ALL}")
275 | return True
276 | else:
277 | print(f"{Fore.RED}Failed to download model {model_name}. Check your internet connection and try again.{Style.RESET_ALL}")
278 | return False
279 | except Exception as e:
280 | print(f"\n{Fore.RED}Error downloading model {model_name}: {e}{Style.RESET_ALL}")
281 | return False
282 |
283 |
284 | def ensure_ollama_and_model(model_name: str) -> bool:
285 | """Ensure Ollama is installed, running, and the requested model is available."""
286 | # Check if we're running in Docker
287 |     in_docker = os.environ.get("OLLAMA_BASE_URL", "").startswith(("http://ollama:", "http://host.docker.internal:"))
288 |
289 | # In Docker environment, we need a different approach
290 | if in_docker:
291 | ollama_url = os.environ.get("OLLAMA_BASE_URL", "http://ollama:11434")
292 | return docker.ensure_ollama_and_model(model_name, ollama_url)
293 |
294 | # Regular flow for non-Docker environments
295 | # Check if Ollama is installed
296 | if not is_ollama_installed():
297 | print(f"{Fore.YELLOW}Ollama is not installed on your system.{Style.RESET_ALL}")
298 |
299 | # Ask if they want to install it
300 | if questionary.confirm("Do you want to install Ollama?").ask():
301 | if not install_ollama():
302 | return False
303 | else:
304 | print(f"{Fore.RED}Ollama is required to use local models.{Style.RESET_ALL}")
305 | return False
306 |
307 | # Make sure the server is running
308 | if not is_ollama_server_running():
309 | print(f"{Fore.YELLOW}Starting Ollama server...{Style.RESET_ALL}")
310 | if not start_ollama_server():
311 | return False
312 |
313 | # Check if the model is already downloaded
314 | available_models = get_locally_available_models()
315 | if model_name not in available_models:
316 | print(f"{Fore.YELLOW}Model {model_name} is not available locally.{Style.RESET_ALL}")
317 |
318 | # Ask if they want to download it
319 | model_size_info = ""
320 | if "70b" in model_name:
321 | model_size_info = " This is a large model (up to several GB) and may take a while to download."
322 | elif "34b" in model_name or "8x7b" in model_name:
323 | model_size_info = " This is a medium-sized model (1-2 GB) and may take a few minutes to download."
324 |
325 | if questionary.confirm(f"Do you want to download the {model_name} model?{model_size_info} The download will happen in the background.").ask():
326 | return download_model(model_name)
327 | else:
328 | print(f"{Fore.RED}The model is required to proceed.{Style.RESET_ALL}")
329 | return False
330 |
331 | return True
332 |
333 |
334 | def delete_model(model_name: str) -> bool:
335 | """Delete a locally downloaded Ollama model."""
336 | # Check if we're running in Docker
337 |     in_docker = os.environ.get("OLLAMA_BASE_URL", "").startswith(("http://ollama:", "http://host.docker.internal:"))
338 |
339 | # In Docker environment, delegate to docker module
340 | if in_docker:
341 | ollama_url = os.environ.get("OLLAMA_BASE_URL", "http://ollama:11434")
342 | return docker.delete_model(model_name, ollama_url)
343 |
344 | # Non-Docker environment
345 | if not is_ollama_server_running():
346 | if not start_ollama_server():
347 | return False
348 |
349 | print(f"{Fore.YELLOW}Deleting model {model_name}...{Style.RESET_ALL}")
350 |
351 | try:
352 | # Use the Ollama CLI to delete the model
353 | process = subprocess.run(["ollama", "rm", model_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
354 |
355 | if process.returncode == 0:
356 | print(f"{Fore.GREEN}Model {model_name} deleted successfully.{Style.RESET_ALL}")
357 | return True
358 | else:
359 | print(f"{Fore.RED}Failed to delete model {model_name}. Error: {process.stderr}{Style.RESET_ALL}")
360 | return False
361 | except Exception as e:
362 | print(f"{Fore.RED}Error deleting model {model_name}: {e}{Style.RESET_ALL}")
363 | return False
364 |
365 |
366 | # Command-line entry point: check for a model and download it if needed
367 | if __name__ == "__main__":
368 | import sys
369 | import argparse
370 |
371 | parser = argparse.ArgumentParser(description="Ollama model manager")
372 | parser.add_argument("--check-model", help="Check if model exists and download if needed")
373 | args = parser.parse_args()
374 |
375 | if args.check_model:
376 | print(f"Ensuring Ollama is installed and model {args.check_model} is available...")
377 | result = ensure_ollama_and_model(args.check_model)
378 | sys.exit(0 if result else 1)
379 | else:
380 | print("No action specified. Use --check-model to check if a model exists.")
381 | sys.exit(1)
382 |
--------------------------------------------------------------------------------
/src/utils/progress.py:
--------------------------------------------------------------------------------
1 | from rich.console import Console
2 | from rich.live import Live
3 | from rich.table import Table
4 | from rich.style import Style
5 | from rich.text import Text
6 | from typing import Dict, Optional
7 |
8 | console = Console()
9 |
10 |
11 | class AgentProgress:
12 | """Manages progress tracking for multiple agents."""
13 |
14 | def __init__(self):
15 |         self.agent_status: Dict[str, Dict[str, Optional[str]]] = {}
16 | self.table = Table(show_header=False, box=None, padding=(0, 1))
17 | self.live = Live(self.table, console=console, refresh_per_second=4)
18 | self.started = False
19 |
20 | def start(self):
21 | """Start the progress display."""
22 | if not self.started:
23 | self.live.start()
24 | self.started = True
25 |
26 | def stop(self):
27 | """Stop the progress display."""
28 | if self.started:
29 | self.live.stop()
30 | self.started = False
31 |
32 | def update_status(self, agent_name: str, ticker: Optional[str] = None, status: str = ""):
33 | """Update the status of an agent."""
34 | if agent_name not in self.agent_status:
35 | self.agent_status[agent_name] = {"status": "", "ticker": None}
36 |
37 | if ticker:
38 | self.agent_status[agent_name]["ticker"] = ticker
39 | if status:
40 | self.agent_status[agent_name]["status"] = status
41 |
42 | self._refresh_display()
43 |
44 | def _refresh_display(self):
45 | """Refresh the progress display."""
46 | self.table.columns.clear()
47 | self.table.add_column(width=100)
48 |
49 | # Sort agents with Risk Management and Portfolio Management at the bottom
50 | def sort_key(item):
51 | agent_name = item[0]
52 | if "risk_management" in agent_name:
53 | return (2, agent_name)
54 | elif "portfolio_management" in agent_name:
55 | return (3, agent_name)
56 | else:
57 | return (1, agent_name)
58 |
59 | for agent_name, info in sorted(self.agent_status.items(), key=sort_key):
60 | status = info["status"]
61 | ticker = info["ticker"]
62 |
63 | # Create the status text with appropriate styling
64 | if status.lower() == "done":
65 | style = Style(color="green", bold=True)
66 | symbol = "✓"
67 | elif status.lower() == "error":
68 | style = Style(color="red", bold=True)
69 | symbol = "✗"
70 | else:
71 | style = Style(color="yellow")
72 | symbol = "⋯"
73 |
74 | agent_display = agent_name.replace("_agent", "").replace("_", " ").title()
75 | status_text = Text()
76 | status_text.append(f"{symbol} ", style=style)
77 | status_text.append(f"{agent_display:<20}", style=Style(bold=True))
78 |
79 | if ticker:
80 | status_text.append(f"[{ticker}] ", style=Style(color="cyan"))
81 | status_text.append(status, style=style)
82 |
83 | self.table.add_row(status_text)
84 |
85 |
86 | # Create a global instance
87 | progress = AgentProgress()
88 |
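89 | # Illustrative only (not part of the original module): a minimal sketch of the
90 | # start/update/stop lifecycle for the shared progress instance. The agent and
91 | # ticker names are made up.
92 | if __name__ == "__main__":
93 |     import time
94 | 
95 |     progress.start()
96 |     progress.update_status("warren_buffett_agent", "AAPL", "Analyzing fundamentals")
97 |     time.sleep(1)
98 |     progress.update_status("warren_buffett_agent", "AAPL", "Done")
99 |     time.sleep(1)
100 |     progress.stop()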
--------------------------------------------------------------------------------
/src/utils/visualize.py:
--------------------------------------------------------------------------------
1 | from langgraph.graph.state import CompiledGraph
2 | from langchain_core.runnables.graph import MermaidDrawMethod
3 |
4 |
5 | def save_graph_as_png(app: CompiledGraph, output_file_path: str) -> None:
6 | png_image = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
7 |     file_path = output_file_path or "graph.png"
8 | with open(file_path, "wb") as f:
9 | f.write(png_image)
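10 | 
11 | 
12 | # Illustrative only (not part of the original module): render a tiny one-node
13 | # graph. Drawing uses the remote Mermaid API, so this needs network access;
14 | # the state schema and node are made up for the demo.
15 | if __name__ == "__main__":
16 |     from typing import TypedDict
17 | 
18 |     from langgraph.graph import END, StateGraph
19 | 
20 |     class DemoState(TypedDict):
21 |         value: int
22 | 
23 |     graph = StateGraph(DemoState)
24 |     graph.add_node("noop", lambda state: state)
25 |     graph.set_entry_point("noop")
26 |     graph.add_edge("noop", END)
27 |     save_graph_as_png(graph.compile(), "demo_graph.png")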
--------------------------------------------------------------------------------