├── config
├── __init__.py
├── trading_mode.py
├── config.json
├── exceptions.py
├── config_manager.py
└── config_validator.py
├── core
├── __init__.py
├── order_handling
│ ├── __init__.py
│ ├── execution_strategy
│ │ ├── __init__.py
│ │ ├── order_execution_strategy_interface.py
│ │ ├── order_execution_strategy_factory.py
│ │ ├── backtest_order_execution_strategy.py
│ │ └── live_order_execution_strategy.py
│ ├── fee_calculator.py
│ ├── exceptions.py
│ ├── order_book.py
│ ├── order.py
│ └── order_status_tracker.py
├── bot_management
│ ├── bot_controller
│ │ ├── __init__.py
│ │ ├── exceptions.py
│ │ └── bot_controller.py
│ ├── notification
│ │ ├── notification_content.py
│ │ └── notification_handler.py
│ └── event_bus.py
├── validation
│ ├── exceptions.py
│ └── order_validator.py
├── services
│ ├── exchange_service_factory.py
│ ├── exceptions.py
│ ├── exchange_interface.py
│ └── backtest_exchange_service.py
└── grid_management
│ └── grid_level.py
├── tests
├── __init__.py
├── conftest.py
├── order_handling
│ ├── test_fee_calculator.py
│ ├── test_order_execution_strategy_factory.py
│ ├── test_backtest_order_execution_strategy.py
│ ├── test_order.py
│ ├── test_order_book.py
│ └── test_live_order_execution_strategy.py
├── utils
│ ├── test_config_name_generator.py
│ ├── test_performance_results_saver.py
│ ├── test_logging_config.py
│ └── test_arg_parser.py
├── services
│ └── test_exchange_service_factory.py
├── bot_management
│ ├── test_event_bus.py
│ ├── test_notification_handler.py
│ └── test_bot_controller.py
├── validation
│ └── test_order_validator.py
├── grid_management
│ └── test_grid_level.py
├── strategies
│ └── test_plotter.py
└── config
│ ├── test_config_validator.py
│ └── test_config_manager.py
├── utils
├── __init__.py
├── constants.py
├── config_name_generator.py
├── logging_config.py
├── performance_results_saver.py
└── arg_parser.py
├── strategies
├── __init__.py
├── spacing_type.py
├── strategy_type.py
├── trading_strategy_interface.py
└── plotter.py
├── .vscode
├── extensions.json
└── settings.json
├── monitoring
├── configs
│ ├── grafana
│ │ └── provisioning
│ │ │ ├── dashboards.yml
│ │ │ └── datasources.yml
│ ├── loki
│ │ ├── rules.yaml
│ │ └── loki.yaml
│ └── promtail
│ │ └── promtail.yaml
└── dashboards
│ └── grid_trading_bot_dashboard.json
├── .github
├── dependabot.yml
├── workflows
│ └── run-tests-on-push-or-merge-pr-master.yml
└── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── LICENSE.txt
├── .pre-commit-config.yaml
├── docker-compose.yml
├── pyproject.toml
├── CODE_OF_CONDUCT.md
└── main.py
/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/strategies/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/order_handling/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/bot_management/bot_controller/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/order_handling/execution_strategy/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": ["ms-python.python", "charliermarsh.ruff"]
3 | }
4 |
--------------------------------------------------------------------------------
/monitoring/configs/grafana/provisioning/dashboards.yml:
--------------------------------------------------------------------------------
1 | apiVersion: 1
2 | providers:
3 | - name: 'GridTradingBot'
4 | orgId: 1
5 | folder: 'TradingBot Monitoring'
6 | type: file
7 | options:
8 | path: /var/lib/grafana/dashboards
9 |
--------------------------------------------------------------------------------
/monitoring/configs/grafana/provisioning/datasources.yml:
--------------------------------------------------------------------------------
1 | apiVersion: 1
2 |
3 | datasources:
4 | - name: Loki
5 | type: loki
6 | access: proxy
7 | url: http://loki:3100
8 | uid: Loki
9 | editable: false
10 | isDefault: true
11 |
--------------------------------------------------------------------------------
/core/bot_management/bot_controller/exceptions.py:
--------------------------------------------------------------------------------
class BotControllerError(Exception):
    """Root of the exception hierarchy raised by the BotController."""


class CommandParsingError(BotControllerError):
    """Raised when an incoming command cannot be parsed."""


class StrategyControlError(BotControllerError):
    """Raised when the strategy fails to start, stop, or restart."""
11 |
--------------------------------------------------------------------------------
/core/validation/exceptions.py:
--------------------------------------------------------------------------------
class InsufficientBalanceError(Exception):
    """Raised when the available balance cannot cover a buy or sell order."""


class InsufficientCryptoBalanceError(Exception):
    """Raised when the crypto holdings are too small to complete a sell order."""


class InvalidOrderQuantityError(Exception):
    """Raised when an order's quantity (amount) is not valid."""
17 |
--------------------------------------------------------------------------------
/core/order_handling/fee_calculator.py:
--------------------------------------------------------------------------------
1 | from config.config_manager import ConfigManager
2 |
3 |
class FeeCalculator:
    """Computes trading fees as a flat proportion of a trade's value."""

    def __init__(
        self,
        config_manager: ConfigManager,
    ):
        """Store the config manager and cache its trading fee rate.

        The rate is read once here; later changes in the config manager are
        not reflected in ``self.trading_fee``.
        """
        self.config_manager = config_manager
        fee_rate: float = config_manager.get_trading_fee()
        self.trading_fee: float = fee_rate

    def calculate_fee(
        self,
        trade_value: float,
    ) -> float:
        """Return the fee owed on ``trade_value`` (``trade_value * trading_fee``)."""
        fee = trade_value * self.trading_fee
        return fee
17 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 | ignore:
8 | # Ignoring these dependencies temporarily for compatibility
9 | - dependency-name: "numpy"
10 | - dependency-name: "pandas"
11 | labels:
12 | - "pip-dependencies"
13 |
14 | - package-ecosystem: "github-actions"
15 | directory: "/"
16 | schedule:
17 | interval: "weekly"
18 | labels:
19 | - "actions"
20 |
--------------------------------------------------------------------------------
/core/order_handling/exceptions.py:
--------------------------------------------------------------------------------
1 | from .order import OrderSide, OrderType
2 |
3 |
class OrderExecutionFailedError(Exception):
    """Raised when an order could not be executed.

    The parameters of the attempted order are exposed as attributes
    (``order_side``, ``order_type``, ``pair``, ``quantity``, ``price``);
    ``str(exc)`` is the supplied ``message``.
    """

    def __init__(
        self,
        message: str,
        order_side: OrderSide,
        order_type: OrderType,
        pair: str,
        quantity: float,
        price: float,
    ):
        super().__init__(message)
        # Keep the failing order's parameters available to whoever catches this.
        self.order_side, self.order_type = order_side, order_type
        self.pair = pair
        self.quantity, self.price = quantity, price
20 |
--------------------------------------------------------------------------------
/config/trading_mode.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class TradingMode(Enum):
    """Supported trading modes, keyed by the string used in the config file."""

    BACKTEST = "backtest"
    PAPER_TRADING = "paper_trading"
    LIVE = "live"

    @staticmethod
    def from_string(mode_str: str):
        """Look up the TradingMode for ``mode_str`` (e.g. ``"live"``).

        Raises:
            ValueError: if ``mode_str`` matches no member. The message lists all
                valid values, and the original lookup error is suppressed.
        """
        try:
            mode = TradingMode(mode_str)
        except ValueError:
            valid = ", ".join(member.value for member in TradingMode)
            raise ValueError(
                f"Invalid trading mode: '{mode_str}'. Available modes are: {valid}",
            ) from None
        return mode
18 |
--------------------------------------------------------------------------------
/strategies/spacing_type.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class SpacingType(Enum):
    """How grid levels are spaced, keyed by the string used in the config file."""

    ARITHMETIC = "arithmetic"
    GEOMETRIC = "geometric"

    @staticmethod
    def from_string(spacing_type_str: str):
        """Look up the SpacingType for ``spacing_type_str`` (e.g. ``"geometric"``).

        Raises:
            ValueError: if ``spacing_type_str`` matches no member. The message
                lists all valid values, and the original lookup error is suppressed.
        """
        try:
            spacing = SpacingType(spacing_type_str)
        except ValueError:
            valid = ", ".join(member.value for member in SpacingType)
            raise ValueError(
                f"Invalid spacing type: '{spacing_type_str}'. Available spacings are: {valid}",
            ) from None
        return spacing
17 |
--------------------------------------------------------------------------------
/strategies/strategy_type.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class StrategyType(Enum):
    """Available grid strategies, keyed by the string used in the config file."""

    SIMPLE_GRID = "simple_grid"
    HEDGED_GRID = "hedged_grid"

    @staticmethod
    def from_string(strategy_type_str: str):
        """Look up the StrategyType for ``strategy_type_str`` (e.g. ``"simple_grid"``).

        Raises:
            ValueError: if ``strategy_type_str`` matches no member. The message
                lists all valid values, and the original lookup error is suppressed.
        """
        try:
            strategy = StrategyType(strategy_type_str)
        except ValueError:
            valid = ", ".join(member.value for member in StrategyType)
            raise ValueError(
                f"Invalid strategy type: '{strategy_type_str}'. Available strategies are: {valid}",
            ) from None
        return strategy
17 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *.DS_Store
3 | __pycache__/
4 | *.pyc
5 | *.py[cod]
6 |
7 | # Distribution / packaging
8 | build/
9 | dist/
10 | *.egg-info/
11 | *.lock
12 |
13 | # Installer logs
14 | pip-log.txt
15 | pip-delete-this-directory.txt
16 |
17 | # Unit test / coverage reports
18 | .tox/
19 | .coverage
20 | .cache
21 | coverage.xml
22 |
23 | # Ignore profiling output files
24 | *.prof
25 |
26 | # VS Code
27 | .vscode/*
28 | !.vscode/settings.json
29 | !.vscode/extensions.json
30 |
31 | # Secrets
32 | secrets.json
33 | .env.yml
34 | .env
35 | loki-data/
36 | grafana-data/
37 |
38 | # `uv` Package Manager
39 | .uv/
40 |
41 | # Project specific
42 | /data/
43 | logs/*
44 |
--------------------------------------------------------------------------------
/monitoring/configs/loki/rules.yaml:
--------------------------------------------------------------------------------
1 | groups:
2 | - name: grid_trading_alerts
3 | rules:
4 | - alert: HighCPUUsage
5 | expr: |
6 | avg_over_time({job="grid_trading_bot"} | json | unwrap cpu [5m]) > 80
7 | for: 2m
8 | labels:
9 | severity: warning
10 | annotations:
11 | summary: High CPU usage detected
12 | - alert: OrderExecutionFailure
13 | expr: |
14 | count_over_time({job="grid_trading_bot"} |= "Failed to execute order" [5m]) > 3
15 | for: 1m
16 | labels:
17 | severity: critical
18 | annotations:
19 | summary: Multiple order execution failures detected
20 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.autoTestDiscoverOnSaveEnabled": true,
3 | "python.testing.unittestEnabled": false,
4 | "python.testing.pytestEnabled": true,
5 | "python.testing.pytestArgs": ["tests"],
6 | "python.defaultInterpreterPath": "./.venv/bin/python",
7 |
8 | "python.analysis.autoImportCompletions": true,
9 | "editor.codeActionsOnSave": {
10 | "source.fixAll": "explicit"
11 | },
12 | "editor.formatOnSave": true,
13 | "editor.insertSpaces": true,
14 | "editor.tabSize": 2,
15 | "editor.detectIndentation": false,
16 | "editor.rulers": [120],
17 | "[python]": {
18 | "editor.defaultFormatter": "charliermarsh.ruff",
19 | "editor.tabSize": 4
20 | },
21 | "files.watcherExclude": {
22 | "**/.venv/**": true,
23 | "**/__pycache__/**": true,
24 | "**/.pytest_cache/**": true
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/core/order_handling/execution_strategy/order_execution_strategy_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from ..order import Order, OrderSide
4 |
5 |
class OrderExecutionStrategyInterface(ABC):
    """Abstract contract for placing and querying orders.

    Every method is a coroutine. Concrete implementations are handed out by
    ``OrderExecutionStrategyFactory``.
    """

    @abstractmethod
    async def execute_market_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """Execute a market order for ``quantity`` of ``pair`` on ``order_side``.

        How ``price`` is used is left to the implementation.
        Returns the resulting Order, or None (presumably when no order was
        produced — confirm against implementations).
        """

    @abstractmethod
    async def execute_limit_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """Place a limit order for ``quantity`` of ``pair`` at ``price``.

        Returns the resulting Order, or None (presumably when no order was
        produced — confirm against implementations).
        """

    @abstractmethod
    async def get_order(
        self,
        order_id: str,
        pair: str,
    ) -> Order | None:
        """Fetch the order identified by ``order_id`` for ``pair``, or None."""
34 |
--------------------------------------------------------------------------------
/core/services/exchange_service_factory.py:
--------------------------------------------------------------------------------
1 | from config.config_manager import ConfigManager
2 | from config.trading_mode import TradingMode
3 |
4 | from .backtest_exchange_service import BacktestExchangeService
5 | from .live_exchange_service import LiveExchangeService
6 |
7 |
8 | class ExchangeServiceFactory:
9 | @staticmethod
10 | def create_exchange_service(
11 | config_manager: ConfigManager,
12 | trading_mode: TradingMode,
13 | ):
14 | if trading_mode == TradingMode.BACKTEST:
15 | return BacktestExchangeService(config_manager)
16 | elif trading_mode == TradingMode.PAPER_TRADING:
17 | return LiveExchangeService(config_manager, is_paper_trading_activated=True)
18 | elif trading_mode == TradingMode.LIVE:
19 | return LiveExchangeService(config_manager, is_paper_trading_activated=False)
20 | else:
21 | raise ValueError(f"Unsupported trading mode: {trading_mode}")
22 |
--------------------------------------------------------------------------------
/config/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "exchange": {
3 | "name": "binance",
4 | "trading_fee": 0.001,
5 | "trading_mode": "backtest"
6 | },
7 | "pair": {
8 | "base_currency": "SOL",
9 | "quote_currency": "USDT"
10 | },
11 | "trading_settings": {
12 | "timeframe": "1m",
13 | "period": {
14 | "start_date": "2024-12-10T09:00:00Z",
15 | "end_date": "2024-12-15T23:00:00Z"
16 | },
17 | "initial_balance": 150
18 | },
19 | "grid_strategy": {
20 | "type": "hedged_grid",
21 | "spacing": "geometric",
22 | "num_grids": 8,
23 | "range": {
24 | "top": 240,
25 | "bottom": 210
26 | }
27 | },
28 | "risk_management": {
29 | "take_profit": {
30 | "enabled": false,
31 | "threshold": 3700
32 | },
33 | "stop_loss": {
34 | "enabled": false,
35 | "threshold": 2830
36 | }
37 | },
38 | "logging": {
39 | "log_level": "INFO",
40 | "log_to_file": true
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/core/services/exceptions.py:
--------------------------------------------------------------------------------
class UnsupportedExchangeError(Exception):
    """Raised when the requested exchange is not supported."""


class DataFetchError(Exception):
    """Raised when fetching data still fails after all retries."""


class HistoricalMarketDataFileNotFoundError(Exception):
    """Raised when the historical market data file cannot be found."""


class UnsupportedTimeframeError(Exception):
    """Raised when a timeframe is not supported by the given exchange."""


class UnsupportedPairError(Exception):
    """Raised when a crypto pair is not supported by the given exchange."""


class OrderCancellationError(Exception):
    """Raised when cancelling an order fails."""


class MissingEnvironmentVariableError(Exception):
    """Raised when required environment variables (EXCHANGE_API_KEY and/or EXCHANGE_SECRET_KEY) are missing."""
41 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jordan TETE
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/core/order_handling/execution_strategy/order_execution_strategy_factory.py:
--------------------------------------------------------------------------------
1 | from config.config_manager import ConfigManager
2 | from config.trading_mode import TradingMode
3 | from core.services.exchange_interface import ExchangeInterface
4 |
5 | from .backtest_order_execution_strategy import BacktestOrderExecutionStrategy
6 | from .live_order_execution_strategy import LiveOrderExecutionStrategy
7 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface
8 |
9 |
10 | class OrderExecutionStrategyFactory:
11 | @staticmethod
12 | def create(
13 | config_manager: ConfigManager,
14 | exchange_service: ExchangeInterface,
15 | ) -> OrderExecutionStrategyInterface:
16 | trading_mode = config_manager.get_trading_mode()
17 |
18 | if trading_mode == TradingMode.LIVE or trading_mode == TradingMode.PAPER_TRADING:
19 | return LiveOrderExecutionStrategy(exchange_service=exchange_service)
20 | elif trading_mode == TradingMode.BACKTEST:
21 | return BacktestOrderExecutionStrategy()
22 | else:
23 | raise ValueError(f"Unknown trading mode: {trading_mode}")
24 |
--------------------------------------------------------------------------------
/utils/constants.py:
--------------------------------------------------------------------------------
# Per-exchange candle limits, keyed by exchange id.
# Presumably the maximum number of OHLCV candles returned per fetch request — confirm against callers.
CANDLE_LIMITS = {
    "binance": 1000,
    "coinbase": 300,
    "kraken": 720,
    "bitfinex": 5000,
    "bitstamp": 1000,
    "huobi": 2000,
    "okex": 1440,
    "bybit": 200,
    "bittrex": 500,
    "poloniex": 500,
    "gateio": 1000,
    "kucoin": 1500,
}

# Duration of each supported timeframe string, in milliseconds.
TIMEFRAME_MAPPINGS = {
    "1s": 1 * 1000,  # 1 second
    "1m": 60 * 1000,  # 1 minute
    "3m": 3 * 60 * 1000,  # 3 minutes
    "5m": 5 * 60 * 1000,  # 5 minutes
    "15m": 15 * 60 * 1000,  # 15 minutes
    "30m": 30 * 60 * 1000,  # 30 minutes
    "1h": 60 * 60 * 1000,  # 1 hour
    "2h": 2 * 60 * 60 * 1000,  # 2 hours
    "4h": 4 * 60 * 60 * 1000,  # 4 hours
    "6h": 6 * 60 * 60 * 1000,  # 6 hours
    "8h": 8 * 60 * 60 * 1000,  # 8 hours
    "12h": 12 * 60 * 60 * 1000,  # 12 hours
    "1d": 24 * 60 * 60 * 1000,  # 1 day
    "3d": 3 * 24 * 60 * 60 * 1000,  # 3 days
    "1w": 7 * 24 * 60 * 60 * 1000,  # 1 week
    "1M": 30 * 24 * 60 * 60 * 1000,  # 1 month (approximated as 30 days)
}

# Resource-usage alert thresholds; values look like percentages — confirm with the health-check code.
RESSOURCE_THRESHOLDS = {
    "cpu": 90,
    "bot_cpu": 80,
    "memory": 80,
    "bot_memory": 70,
    "disk": 90,
}

# Correctly spelled alias; the original (misspelled) name is kept for existing importers.
RESOURCE_THRESHOLDS = RESSOURCE_THRESHOLDS
40 |
--------------------------------------------------------------------------------
/.github/workflows/run-tests-on-push-or-merge-pr-master.yml:
--------------------------------------------------------------------------------
1 | name: on_push_or_merge_pr_master
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 |
11 | jobs:
12 | test:
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - name: Checkout code
17 | uses: actions/checkout@v6
18 |
19 | - name: Install uv
20 | uses: astral-sh/setup-uv@v7
21 |
22 | - name: "Set up Python"
23 | uses: actions/setup-python@v6
24 | with:
25 | python-version-file: "pyproject.toml"
26 |
27 | - name: Install the project
28 | run: uv sync --all-extras --dev
29 |
30 | - name: Set PYTHONPATH
31 | run: echo "PYTHONPATH=$(pwd)" >> $GITHUB_ENV
32 |
33 | - name: Run tests and upload coverage
34 | run: uv run pytest --cov=core --cov=config --cov=strategies --cov=utils --cov-report=xml:coverage.xml --cov-report=term
35 | continue-on-error: true
36 |
37 | - name: Upload coverage reports to Codecov
38 | uses: codecov/codecov-action@v5
39 | with:
40 | fail_ci_if_error: true
41 | files: ./coverage.xml
42 | token: ${{ secrets.CODECOV_TOKEN }}
43 | verbose: true
44 |
--------------------------------------------------------------------------------
/utils/config_name_generator.py:
--------------------------------------------------------------------------------
1 | from datetime import UTC, datetime
2 |
3 | from config.config_manager import ConfigManager
4 |
5 |
def generate_config_name(config_manager: ConfigManager) -> str:
    """
    Build a unique, descriptive name for the bot's configuration.

    The name is these parts joined with underscores:
    - a ``bot`` prefix
    - the base/quote pair
    - the trading mode, strategy type and spacing type (enum member names)
    - ``size<num_grids>``
    - ``range<bottom>-<top>``
    - the current UTC time as ``YYYYMMDD_HHMM``

    Args:
        config_manager (ConfigManager): Source of the configuration values.

    Returns:
        str: e.g. ``bot_SOL_USDT_BACKTEST_strategyHEDGED_GRID_spacingGEOMETRIC_size8_range210-240_20241210_0900``.
    """
    base = config_manager.get_base_currency()
    quote = config_manager.get_quote_currency()
    mode_name = config_manager.get_trading_mode().name
    strategy_name = config_manager.get_strategy_type().name
    spacing_name = config_manager.get_spacing_type().name
    num_grids = config_manager.get_num_grids()
    top = config_manager.get_top_range()
    bottom = config_manager.get_bottom_range()
    timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M")

    segments = [
        "bot",
        f"{base}_{quote}",
        mode_name,
        f"strategy{strategy_name}",
        f"spacing{spacing_name}",
        f"size{num_grids}",
        f"range{bottom}-{top}",
        timestamp,
    ]
    return "_".join(segments)
30 |
--------------------------------------------------------------------------------
/monitoring/configs/loki/loki.yaml:
--------------------------------------------------------------------------------
1 | auth_enabled: false
2 |
3 | server:
4 | http_listen_port: 3100
5 | grpc_listen_port: 9096
6 | log_level: debug
7 |
8 | common:
9 | instance_addr: 127.0.0.1
10 | path_prefix: /tmp/loki
11 | storage:
12 | filesystem:
13 | chunks_directory: /tmp/loki/chunks
14 | rules_directory: /tmp/loki/rules
15 | replication_factor: 1
16 | ring:
17 | kvstore:
18 | store: inmemory
19 |
20 | frontend:
21 | max_outstanding_per_tenant: 2048
22 |
23 | limits_config:
24 | max_global_streams_per_user: 1000
25 | ingestion_rate_mb: 500
26 | ingestion_burst_size_mb: 500
27 | volume_enabled: true
28 | reject_old_samples: false
29 | reject_old_samples_max_age: 2160h # 90 days; only enforced when reject_old_samples is true (adjust as needed)
30 |
31 | query_range:
32 | results_cache:
33 | cache:
34 | embedded_cache:
35 | enabled: true
36 | max_size_mb: 100
37 |
38 | ruler:
39 | rule_path: /etc/loki/rules
40 | storage:
41 | type: local
42 | local:
43 | directory: /etc/loki/rules
44 | enable_api: true
45 |
46 | schema_config:
47 | configs:
48 | - from: 2020-10-24
49 | store: tsdb
50 | object_store: filesystem
51 | schema: v13
52 | index:
53 | prefix: index_
54 | period: 24h
55 |
56 | analytics:
57 | reporting_enabled: false
58 |
--------------------------------------------------------------------------------
/config/exceptions.py:
--------------------------------------------------------------------------------
class ConfigError(Exception):
    """Base class for all configuration-related errors."""


class ConfigFileNotFoundError(ConfigError):
    """Raised when the configuration file cannot be found.

    Attributes:
        config_file: The path that was looked up.
        message: The full message (``"<message>: <config_file>"``), also ``str(exc)``.
    """

    def __init__(self, config_file, message="Configuration file not found"):
        self.config_file = config_file
        self.message = f"{message}: {config_file}"
        super().__init__(self.message)


class ConfigValidationError(ConfigError):
    """Raised when a configuration fails validation.

    Attributes:
        missing_fields: Names of required fields that are absent (empty if none).
        invalid_fields: Names of fields whose values are invalid (empty if none).
        message: The full message, also ``str(exc)``. The field details are
            appended after a colon only when there is at least one field to report.
    """

    def __init__(self, missing_fields=None, invalid_fields=None, message="Configuration validation error"):
        self.missing_fields = missing_fields or []
        self.invalid_fields = invalid_fields or []
        details = []
        # map(str, ...) so non-string field identifiers don't make join() raise TypeError.
        if self.missing_fields:
            details.append(f"Missing required fields - {', '.join(map(str, self.missing_fields))}")
        if self.invalid_fields:
            details.append(f"Invalid fields - {', '.join(map(str, self.invalid_fields))}")
        # Avoid a dangling "...: " when no field details were supplied.
        self.message = f"{message}: {', '.join(details)}" if details else message
        super().__init__(self.message)


class ConfigParseError(ConfigError):
    """Raised when the configuration file exists but cannot be parsed.

    Attributes:
        config_file: The path of the file being parsed.
        original_exception: The underlying parser exception.
        message: ``"<message> (<config_file>): <original_exception>"``, also ``str(exc)``.
    """

    def __init__(self, config_file, original_exception, message="Error parsing configuration file"):
        self.config_file = config_file
        self.original_exception = original_exception
        self.message = f"{message} ({config_file}): {original_exception!s}"
        super().__init__(self.message)
33 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
@pytest.fixture
def valid_config():
    """Return a fresh configuration dict that is expected to pass validation.

    It describes an ETH/USDT backtest on Binance with a 20-level geometric
    simple grid. A new dict is built on every call, so tests may mutate it freely.
    """
    exchange = {
        "name": "binance",
        "trading_fee": 0.001,
        "trading_mode": "backtest",
    }
    pair = {
        "base_currency": "ETH",
        "quote_currency": "USDT",
    }
    trading_settings = {
        "initial_balance": 10000,
        "timeframe": "1m",
        "period": {
            "start_date": "2024-07-04T00:00:00Z",
            "end_date": "2024-07-11T00:00:00Z",
        },
        # NOTE(review): base_currency above is ETH but this path points at SOL_USDT data — confirm intended.
        "historical_data_file": "data/SOL_USDT/2024/1m.csv",
    }
    grid_strategy = {
        "type": "simple_grid",
        "spacing": "geometric",
        "num_grids": 20,
        "range": {
            "top": 3100,
            "bottom": 2850,
        },
    }
    risk_management = {
        "take_profit": {
            "enabled": False,
            "threshold": 3700,
        },
        "stop_loss": {
            "enabled": False,
            "threshold": 2830,
        },
    }
    logging_settings = {
        "log_level": "INFO",
        "log_to_file": True,
    }
    return {
        "exchange": exchange,
        "pair": pair,
        "trading_settings": trading_settings,
        "grid_strategy": grid_strategy,
        "risk_management": risk_management,
        "logging": logging_settings,
    }
50 |
--------------------------------------------------------------------------------
/core/bot_management/notification/notification_content.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 |
4 |
@dataclass
class NotificationContent:
    """A notification's title plus its message template.

    ``message`` may contain placeholders such as ``{order_details}``. These are
    presumably filled with ``str.format`` by the sender — confirm with the caller.
    """

    title: str
    message: str


class NotificationType(Enum):
    """Catalogue of bot notifications; each member's value is its NotificationContent."""

    ORDER_PLACED = NotificationContent(
        title="Order Placement Successful",
        message="Order placed successfully:\n{order_details}",
    )
    ORDER_FILLED = NotificationContent(
        title="Order Filled",
        message="Order has been filled successfully:\n{order_details}",
    )
    ORDER_FAILED = NotificationContent(
        title="Order Placement Failed",
        message="Failed to place order:\n{error_details}",
    )
    ORDER_CANCELLED = NotificationContent(
        title="Order Cancellation",
        message="Order has been cancelled:\n{order_details}",
    )
    ERROR_OCCURRED = NotificationContent(
        title="Error Occurred",
        message="An error occurred in the trading bot:\n{error_details}",
    )
    TAKE_PROFIT_TRIGGERED = NotificationContent(
        title="Take Profit Triggered",
        message="Take profit triggered with order details:\n{order_details}",
    )
    STOP_LOSS_TRIGGERED = NotificationContent(
        title="Stop Loss Triggered",
        message="Stop loss triggered with order details:\n{order_details}",
    )
    HEALTH_CHECK_ALERT = NotificationContent(
        title="Health Check Alert",
        message="{alert_details}",
    )
44 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # Ruff for linting and formatting
3 | - repo: https://github.com/astral-sh/ruff-pre-commit
4 | rev: v0.8.4
5 | hooks:
6 | - id: ruff
7 | name: ruff (linting)
8 | args: [--fix]
9 | - id: ruff-format
10 | name: ruff (formatting)
11 |
12 | # Built-in hooks for basic file checks
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v5.0.0
15 | hooks:
16 | - id: trailing-whitespace
17 | name: trim trailing whitespace
18 | - id: end-of-file-fixer
19 | name: fix end of files
20 | - id: check-yaml
21 | name: check yaml
22 | - id: check-json
23 | name: check json
24 | - id: check-toml
25 | name: check toml
26 | - id: check-merge-conflict
27 | name: check for merge conflicts
28 | - id: check-added-large-files
29 | name: check for added large files
30 | args: ["--maxkb=1000"]
31 | - id: debug-statements
32 | name: debug statements (Python)
33 |
34 | # Python-specific checks
35 | - repo: https://github.com/pre-commit/pygrep-hooks
36 | rev: v1.10.0
37 | hooks:
38 | - id: python-check-blanket-noqa
39 | name: check blanket noqa
40 | - id: python-check-blanket-type-ignore
41 | name: check blanket type ignore
42 | - id: python-no-log-warn
43 | name: check for deprecated log warn
44 | - id: python-use-type-annotations
45 | name: check use type annotations
46 | - id: text-unicode-replacement-char
47 | name: check for unicode replacement chars
48 |
--------------------------------------------------------------------------------
/core/grid_management/grid_level.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from ..order_handling.order import Order
4 |
5 |
class GridCycleState(Enum):
    """Where a grid level currently is in its buy/sell cycle."""

    READY_TO_BUY_OR_SELL = "ready_to_buy_or_sell"  # Level is ready for either a buy or a sell order
    READY_TO_BUY = "ready_to_buy"  # Level is ready for a buy order
    WAITING_FOR_BUY_FILL = "waiting_for_buy_fill"  # Buy order placed, waiting for execution
    READY_TO_SELL = "ready_to_sell"  # Level is ready for a sell order
    WAITING_FOR_SELL_FILL = "waiting_for_sell_fill"  # Sell order placed, waiting for execution


class GridLevel:
    """A single price level of the grid.

    It holds the level's cycle state, every order recorded at it, and optional
    links to paired buy/sell levels. The pairing semantics are set by the grid
    manager — presumably the counterpart levels this one trades against; confirm.
    """

    def __init__(self, price: float, state: GridCycleState):
        self.price: float = price
        self.orders: list[Order] = []  # Track all orders at this level
        self.state: GridCycleState = state
        self.paired_buy_level: GridLevel | None = None
        self.paired_sell_level: GridLevel | None = None

    def add_order(self, order: "Order") -> None:
        """
        Record an order at this level.
        """
        self.orders.append(order)

    def __str__(self) -> str:
        """Summarise the level; paired levels are shown by price (or None)."""
        buy_price = self.paired_buy_level.price if self.paired_buy_level is not None else None
        sell_price = self.paired_sell_level.price if self.paired_sell_level is not None else None
        # The original emitted a stray ")" after paired_buy_level, leaving the output unbalanced.
        return (
            f"GridLevel(price={self.price}, "
            f"state={self.state.name}, "
            f"num_orders={len(self.orders)}, "
            f"paired_buy_level={buy_price}, "
            f"paired_sell_level={sell_price})"
        )

    def __repr__(self) -> str:
        return self.__str__()
39 |
--------------------------------------------------------------------------------
/monitoring/configs/promtail/promtail.yaml:
--------------------------------------------------------------------------------
1 | server:
2 | http_listen_port: 9080
3 | grpc_listen_port: 0
4 | log_level: debug
5 |
6 | positions:
7 | filename: /tmp/positions.yaml
8 |
9 | clients:
10 | - url: "http://loki:3100/loki/api/v1/push"
11 | batchsize: 1
12 | batchwait: 1s
13 |
14 | scrape_configs:
15 | - job_name: bot_logs
16 | static_configs:
17 | - targets:
18 | - localhost
19 | labels:
20 | job: grid_trading_bot
21 | __path__: /logs/**/*.log
22 |
23 | relabel_configs:
24 | - source_labels: [__path__]
25 | regex: '.*/bot_(?P<base>[A-Z]+)_(?P<quote>[A-Z]+)_(?P<mode>[A-Z]+)_strategy(?P<strategy>[A-Z_]+)_spacing(?P<spacing>[A-Z]+)_size(?P<size>\d+)_range(?P<range>\d+-\d+)_.*\.log$'
26 | target_label: filename
27 |
28 | - source_labels: [base]
29 | target_label: base_currency
30 | - source_labels: [quote]
31 | target_label: quote_currency
32 | - source_labels: [mode]
33 | target_label: trading_mode
34 | - source_labels: [strategy]
35 | target_label: strategy_type
36 | - source_labels: [spacing]
37 | target_label: spacing_type
38 | - source_labels: [size]
39 | target_label: grid_size
40 | - source_labels: [range]
41 | target_label: grid_range
42 |
43 | pipeline_stages:
44 | - regex:
45 | expression: '^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - (?P<service>[A-Za-z0-9_]+) - (?P<level>[A-Z]+) - (?P<message>.*)$'
46 |
47 | - labels:
48 | service: '{{ .service }}'
49 | level: '{{ .level }}'
50 |
51 | # Convert timestamp to Loki format
52 | - timestamp:
53 | source: timestamp
54 | format: "2006-01-02 15:04:05,000"
55 |
--------------------------------------------------------------------------------
/strategies/trading_strategy_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import logging
3 |
4 |
class TradingStrategyInterface(ABC):
    """
    Contract shared by every concrete trading strategy.

    The base class only stores the collaborators common to all strategies and a
    logger named after the concrete class; subclasses must override every
    abstract method declared below.
    """

    def __init__(self, config_manager, balance_tracker):
        """
        Store the shared collaborators.

        Args:
            config_manager: Source of the trading configuration (e.g., exchange, fees).
            balance_tracker: Tracker of the fiat and crypto balances used by the strategy.
        """
        strategy_name = type(self).__name__
        self.logger = logging.getLogger(strategy_name)
        self.config_manager = config_manager
        self.balance_tracker = balance_tracker

    @abstractmethod
    def initialize_strategy(self):
        """
        Prepare strategy-specific settings (grids, limits, etc.).
        Subclasses must provide this.
        """

    @abstractmethod
    async def run(self):
        """
        Execute the strategy against historical or live data.
        Subclasses must provide this.
        """

    @abstractmethod
    def plot_results(self):
        """
        Plot the strategy's performance once a simulation has finished.
        Subclasses must provide this.
        """

    @abstractmethod
    def generate_performance_report(self) -> tuple[dict, list]:
        """
        Build a summary of the strategy's performance (ROI, max drawdown, etc.).
        Subclasses must provide this.
        """
54 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Pull Request Template
2 |
3 | ## 📋 Description
4 |
5 | > **Provide a brief description of the changes introduced in this pull request.**
6 | > Explain the problem you are solving or the feature you are adding.
7 |
8 | Fixes # (issue number, if applicable)
9 |
10 | ---
11 |
12 | ## 🔍 Type of Change
13 |
14 | > **Select the type of change your PR introduces:**
15 |
16 | - [ ] 🐞 Bug fix (non-breaking change that fixes an issue)
17 | - [ ] ✨ New feature (non-breaking change that adds functionality)
18 | - [ ] 💥 Breaking change (fix or feature that would cause existing functionality to change)
19 | - [ ] 📝 Documentation update
20 | - [ ] 🔧 Code refactor or optimization
21 | - [ ] ✅ Test improvement or addition
22 | - [ ] Other (please specify):
23 |
24 | ---
25 |
26 | ## ✅ Checklist
27 |
28 | > **Ensure your pull request meets the following requirements:**
29 |
30 | - [ ] I have tested my changes locally and ensured they work as expected.
31 | - [ ] I have added necessary documentation (if applicable).
32 | - [ ] I have added tests that prove my fix/feature works as intended.
33 | - [ ] My code follows the project's code style and guidelines.
34 | - [ ] I have updated the README (if applicable).
35 |
36 | ---
37 |
38 | ## 🚦 How Has This Been Tested?
39 |
40 | > **Explain how you tested your changes.**
41 | > Include steps, test cases, or a description of your testing environment.
42 |
43 | ---
44 |
45 | ## 🔗 Related Issues or Discussions
46 |
47 | > **Mention any related issues or discussions (e.g., `Fixes #123`, `Closes #456`, or links to discussions).**
48 |
49 | ---
50 |
51 | ## 📷 Screenshots (if applicable)
52 |
53 | > **Include screenshots, logs, or recordings to illustrate your changes.**
54 |
55 | ---
56 |
57 | ## 🗒️ Additional Notes
58 |
59 | > **Add any additional context, notes, or considerations here.**
60 |
--------------------------------------------------------------------------------
/tests/order_handling/test_fee_calculator.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import pytest
4 |
5 | from core.order_handling.fee_calculator import FeeCalculator
6 |
7 |
class TestFeeCalculator:
    """Tests for FeeCalculator backed by a stub config manager with a 0.1% fee."""

    @pytest.fixture
    def config_manager(self):
        stub_config = Mock()
        stub_config.get_trading_fee.return_value = 0.001  # 0.1% trading fee
        return stub_config

    @pytest.fixture
    def fee_calculator(self, config_manager):
        return FeeCalculator(config_manager)

    def test_calculate_fee_basic(self, fee_calculator):
        # 0.1% of 1000 is 1.
        assert fee_calculator.calculate_fee(1000) == pytest.approx(1)

    def test_calculate_fee_zero(self, fee_calculator):
        # A zero-value trade must cost exactly nothing.
        assert fee_calculator.calculate_fee(0) == 0

    def test_calculate_fee_small_value(self, fee_calculator):
        # One-cent trade: 0.1% of 0.01.
        one_cent = 0.01
        assert fee_calculator.calculate_fee(one_cent) == pytest.approx(0.00001, rel=1e-5)

    def test_calculate_fee_large_value(self, fee_calculator):
        # One-million trade: 0.1% of 1,000,000.
        one_million = 1_000_000
        assert fee_calculator.calculate_fee(one_million) == pytest.approx(1000)

    def test_trading_fee_from_config(self, config_manager, fee_calculator):
        assert fee_calculator.trading_fee == config_manager.get_trading_fee()

    def test_calculate_fee_tiny_trade_value_case(self, fee_calculator):
        tiny_trade = 0.0000001
        assert fee_calculator.calculate_fee(tiny_trade) == pytest.approx(tiny_trade * 0.001, rel=1e-9)
46 |
--------------------------------------------------------------------------------
/core/services/exchange_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | import pandas as pd
5 |
6 |
7 | class ExchangeInterface(ABC):
8 | @abstractmethod
9 | async def get_balance(self) -> dict[str, Any]:
10 | """Fetches the account balance, returning a dictionary with fiat and crypto balances."""
11 | pass
12 |
13 | @abstractmethod
14 | async def place_order(
15 | self,
16 | pair: str,
17 | order_side: str,
18 | order_type: str,
19 | amount: float,
20 | price: float | None = None,
21 | ) -> dict[str, str | float]:
22 | """Places an order, returning a dictionary with order details including id and status."""
23 | pass
24 |
25 | @abstractmethod
26 | def fetch_ohlcv(
27 | self,
28 | pair: str,
29 | timeframe: str,
30 | start_date: str,
31 | end_date: str,
32 | ) -> pd.DataFrame:
33 | """
34 | Fetches historical OHLCV data as a list of dictionaries, each containing open, high, low,
35 | close, and volume for the specified time period.
36 | """
37 | pass
38 |
39 | @abstractmethod
40 | async def get_current_price(
41 | self,
42 | pair: str,
43 | ) -> float:
44 | """Fetches the current market price for the specified trading pair."""
45 | pass
46 |
47 | @abstractmethod
48 | async def cancel_order(
49 | self,
50 | order_id: str,
51 | pair: str,
52 | ) -> dict[str, str | float]:
53 | """Attempts to cancel an order by ID, returning the result of the cancellation."""
54 | pass
55 |
56 | @abstractmethod
57 | async def get_exchange_status(self) -> dict:
58 | """Fetches current exchange status."""
59 | pass
60 |
61 | @abstractmethod
62 | async def close_connection(self) -> None:
63 | """Close current exchange connection."""
64 | pass
65 |
--------------------------------------------------------------------------------
/utils/logging_config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from logging.handlers import RotatingFileHandler
3 | import os
4 |
5 |
6 | def setup_logging(
7 | log_level: int,
8 | log_to_file: bool = False,
9 | config_name: str | None = None,
10 | max_file_size: int = 5_000_000, # 5MB default max file size for rotation
11 | backup_count: int = 5, # Default number of backup files
12 | ) -> None:
13 | """
14 | Sets up logging with options for console, rotating file logging, and log differentiation.
15 |
16 | Args:
17 | log_level (int): The logging level (e.g., logging.INFO, logging.DEBUG).
18 | log_to_file (bool): Whether to log to a file.
19 | config_name (Optional[str]): Name of the bot configuration to differentiate logs.
20 | max_file_size (int): Maximum size of log file in bytes before rotation.
21 | backup_count (int): Number of backup log files to keep.
22 | """
23 | handlers = []
24 |
25 | console_handler = logging.StreamHandler()
26 | console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
27 | handlers.append(console_handler)
28 | log_file_path = ""
29 |
30 | if log_to_file:
31 | log_dir = "logs"
32 | os.makedirs(log_dir, exist_ok=True)
33 |
34 | if config_name:
35 | log_file_path = os.path.join(log_dir, f"{config_name}.log")
36 | else:
37 | log_file_path = os.path.join(log_dir, "grid_trading_bot.log")
38 |
39 | file_handler = RotatingFileHandler(log_file_path, maxBytes=max_file_size, backupCount=backup_count)
40 | file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
41 | handlers.append(file_handler)
42 |
43 | logging.basicConfig(level=log_level, handlers=handlers)
44 | logging.info(f"Logging initialized. Log level: {logging.getLevelName(log_level)}")
45 |
46 | if log_to_file:
47 | logging.info(f"File logging enabled. Logs are stored in: {log_file_path}")
48 |
--------------------------------------------------------------------------------
/tests/utils/test_config_name_generator.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, patch
2 |
3 | import pytest
4 |
5 | from utils.config_name_generator import generate_config_name
6 |
7 |
@pytest.fixture
def mock_config_manager():
    """Config manager double describing a live BTC/USD grid bot with 10 grids over 30000-50000."""
    manager = Mock()

    plain_returns = {
        "get_base_currency": "BTC",
        "get_quote_currency": "USD",
        "get_num_grids": 10,
        "get_top_range": 50000,
        "get_bottom_range": 30000,
    }
    for getter_name, value in plain_returns.items():
        getattr(manager, getter_name).return_value = value

    # These getters return enum-like objects; only their `.name` is configured.
    enum_names = {
        "get_trading_mode": "LIVE",
        "get_strategy_type": "GRID",
        "get_spacing_type": "PERCENTAGE",
    }
    for getter_name, member_name in enum_names.items():
        getattr(manager, getter_name).return_value.name = member_name

    return manager
20 |
21 |
@patch("utils.config_name_generator.datetime")
def test_generate_config_name(mock_datetime, mock_config_manager):
    # Freeze the timestamp suffix so the generated name is deterministic.
    mock_datetime.now.return_value.strftime.return_value = "20241220_1200"

    generated = generate_config_name(mock_config_manager)

    assert generated == "bot_BTC_USD_LIVE_strategyGRID_spacingPERCENTAGE_size10_range30000-50000_20241220_1200"
    queried_getters = (
        "get_base_currency",
        "get_quote_currency",
        "get_trading_mode",
        "get_strategy_type",
        "get_spacing_type",
        "get_num_grids",
        "get_top_range",
        "get_bottom_range",
    )
    for getter_name in queried_getters:
        getattr(mock_config_manager, getter_name).assert_called_once()
38 |
39 |
def test_generate_config_name_edge_cases(mock_config_manager):
    # Swap the currency pair and zero out every numeric grid parameter.
    overrides = {
        "get_base_currency": "ETH",
        "get_quote_currency": "EUR",
        "get_num_grids": 0,
        "get_top_range": 0,
        "get_bottom_range": 0,
    }
    for getter_name, value in overrides.items():
        getattr(mock_config_manager, getter_name).return_value = value

    generated = generate_config_name(mock_config_manager)

    assert "bot_ETH_EUR" in generated
    assert "_size0_range0-0" in generated
51 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | loki:
3 | image: grafana/loki:3.0.0
4 | container_name: loki
5 | restart: unless-stopped
6 | user: root
7 | ports:
8 | - "3100:3100"
9 | command: -config.file=/etc/loki/loki.yaml
10 | volumes:
11 | - ./monitoring/configs/loki/loki.yaml:/etc/loki/loki.yaml
12 | - ./monitoring/configs/loki/rules.yaml:/etc/loki/rules/fake/loki-rules.yml
13 | - loki-data:/tmp/loki/chunks
14 | cpus: 0.5
15 | mem_limit: 512m
16 | networks:
17 | - default
18 |
19 | grafana:
20 | image: grafana/grafana:11.0.0
21 | container_name: grafana
22 | restart: unless-stopped
23 | ports:
24 | - "3000:3000"
25 | environment:
26 | - GF_AUTH_ANONYMOUS_ENABLED=false
27 | - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER}
28 | - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD}
29 | - GF_USERS_DEFAULT_THEME=dark
30 | - GF_LOG_MODE=console
31 | - GF_LOG_LEVEL=debug
32 | - GF_PANELS_ENABLE_ALPHA=false
33 | - GF_FEATURE_TOGGLES_ENABLE=lokiLogsDataplane
34 | - GF_INSTALL_PLUGINS=grafana-polystat-panel
35 | volumes:
36 | - ./monitoring/configs/grafana/provisioning/datasources.yml:/etc/grafana/provisioning/datasources/provisioning-datasources.yaml:ro
37 | - ./monitoring/configs/grafana/provisioning/dashboards.yml:/etc/grafana/provisioning/dashboards/dashboards.yml:ro
38 | - ./monitoring/dashboards/grid_trading_bot_dashboard.json:/var/lib/grafana/dashboards/grid_trading_bot_dashboard.json:ro
39 | - grafana-data:/var/lib/grafana
40 | depends_on:
41 | - loki
42 | cpus: 0.5
43 | mem_limit: 512m
44 | networks:
45 | - default
46 |
47 | promtail:
48 | image: grafana/promtail:3.0.0
49 | container_name: promtail
50 | restart: unless-stopped
51 | volumes:
52 | - ./monitoring/configs/promtail/promtail.yaml:/etc/promtail/promtail.yaml
53 | - ./logs:/logs:ro
54 | - promtail-positions:/tmp
55 | command: -config.file=/etc/promtail/promtail.yaml -config.expand-env=true
56 | depends_on:
57 | - loki
58 | cpus: 0.5
59 | mem_limit: 512m
60 | networks:
61 | - default
62 |
63 | volumes:
64 | grafana-data:
65 | driver: local
66 | loki-data:
67 | driver: local
68 | promtail-positions:
69 | driver: local
70 |
--------------------------------------------------------------------------------
/core/order_handling/order_book.py:
--------------------------------------------------------------------------------
1 | from ..grid_management.grid_level import GridLevel
2 | from .order import Order, OrderSide, OrderStatus
3 |
4 |
class OrderBook:
    """
    In-memory store of orders, split by side, that also remembers which grid level
    (if any) each order was added for.
    """

    def __init__(self):
        self.buy_orders: list[Order] = []
        self.sell_orders: list[Order] = []
        # Orders added without a grid level (e.g. take profit or stop loss orders).
        self.non_grid_orders: list[Order] = []
        # Order -> GridLevel for every order added together with a grid level.
        self.order_to_grid_map: dict[Order, GridLevel] = {}

    def add_order(
        self,
        order: Order,
        grid_level: GridLevel | None = None,
    ) -> None:
        """
        Record ``order`` on its side and link it to ``grid_level`` when one is given.

        Any order whose side is not ``OrderSide.BUY`` is filed with the sell orders.
        """
        side_bucket = self.buy_orders if order.side == OrderSide.BUY else self.sell_orders
        side_bucket.append(order)

        if not grid_level:
            self.non_grid_orders.append(order)
        else:
            self.order_to_grid_map[order] = grid_level

    def _pair_with_grid_levels(self, orders: list[Order]) -> list[tuple[Order, GridLevel | None]]:
        """Pair each order with its grid level, or None when it has none."""
        level_of = self.order_to_grid_map.get
        return [(entry, level_of(entry)) for entry in orders]

    def _all_orders(self) -> list[Order]:
        """Return a new list of all buy orders followed by all sell orders."""
        return [*self.buy_orders, *self.sell_orders]

    def get_buy_orders_with_grid(self) -> list[tuple[Order, GridLevel | None]]:
        return self._pair_with_grid_levels(self.buy_orders)

    def get_sell_orders_with_grid(self) -> list[tuple[Order, GridLevel | None]]:
        return self._pair_with_grid_levels(self.sell_orders)

    def get_all_buy_orders(self) -> list[Order]:
        # Returns the internal list itself, not a copy.
        return self.buy_orders

    def get_all_sell_orders(self) -> list[Order]:
        # Returns the internal list itself, not a copy.
        return self.sell_orders

    def get_open_orders(self) -> list[Order]:
        return [entry for entry in self._all_orders() if entry.is_open()]

    def get_completed_orders(self) -> list[Order]:
        return [entry for entry in self._all_orders() if entry.is_filled()]

    def get_grid_level_for_order(self, order: Order) -> GridLevel | None:
        return self.order_to_grid_map.get(order)

    def update_order_status(
        self,
        order_id: str,
        new_status: OrderStatus,
    ) -> None:
        """Set ``new_status`` on the first order (buys before sells) whose identifier matches."""
        for candidate in self._all_orders():
            if candidate.identifier != order_id:
                continue
            candidate.status = new_status
            return
57 |
--------------------------------------------------------------------------------
/core/order_handling/execution_strategy/backtest_order_execution_strategy.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from ..order import Order, OrderSide, OrderStatus, OrderType
4 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface
5 |
6 |
class BacktestOrderExecutionStrategy(OrderExecutionStrategyInterface):
    """
    Execution strategy for backtests: builds ``Order`` objects locally instead of
    sending requests to an exchange.
    """

    # Sequence number appended to generated IDs. A seconds-resolution timestamp alone
    # gives every order created within the same second an identical identifier, which
    # breaks lookups that match orders by id.
    _order_sequence: int = 0

    @staticmethod
    def _next_order_id() -> str:
        """Return a new ``backtest-``-prefixed identifier that is unique within this process."""
        # Increment on the base class explicitly so subclasses share a single counter.
        BacktestOrderExecutionStrategy._order_sequence += 1
        return f"backtest-{int(time.time())}-{BacktestOrderExecutionStrategy._order_sequence}"

    async def execute_market_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """
        Simulate a market order that is fully filled at ``price``
        (``filled == quantity``, ``remaining == 0``).
        """
        order_id = self._next_order_id()
        timestamp = int(time.time() * 1000)  # milliseconds since the epoch
        # NOTE(review): status is OPEN although the order is reported as fully filled;
        # datetime="111" looks like a placeholder — confirm both are intended.
        return Order(
            identifier=order_id,
            status=OrderStatus.OPEN,
            order_type=OrderType.MARKET,
            side=order_side,
            price=price,
            average=price,
            amount=quantity,
            filled=quantity,
            remaining=0,
            timestamp=timestamp,
            datetime="111",
            last_trade_timestamp=1,
            symbol=pair,
            time_in_force="GTC",
        )

    async def execute_limit_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """Simulate placing a limit order: open, nothing filled yet (``remaining == quantity``)."""
        order_id = self._next_order_id()
        return Order(
            identifier=order_id,
            status=OrderStatus.OPEN,
            order_type=OrderType.LIMIT,
            side=order_side,
            price=price,
            average=price,
            amount=quantity,
            filled=0,
            remaining=quantity,
            timestamp=0,
            datetime="",
            last_trade_timestamp=1,
            symbol=pair,
            time_in_force="GTC",
        )

    async def get_order(
        self,
        order_id: str,
        pair: str,
    ) -> Order | None:
        """
        Return a placeholder Order for ``order_id`` on ``pair``.

        Apart from the identifier and symbol every field is a fixed stub value
        (BUY limit, amount 1 at price 100, status OPEN) — it does not reflect any
        previously created order.
        """
        return Order(
            identifier=order_id,
            status=OrderStatus.OPEN,
            order_type=OrderType.LIMIT,
            side=OrderSide.BUY,
            price=100,
            average=100,
            amount=1,
            filled=1,
            remaining=0,
            timestamp=0,
            datetime="111",
            last_trade_timestamp=1,
            symbol=pair,
            time_in_force="GTC",
        )
80 |
--------------------------------------------------------------------------------
/tests/services/test_exchange_service_factory.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, patch
2 |
3 | import pytest
4 |
5 | from config.config_manager import ConfigManager
6 | from config.trading_mode import TradingMode
7 | from core.services.backtest_exchange_service import BacktestExchangeService
8 | from core.services.exchange_service_factory import ExchangeServiceFactory
9 | from core.services.live_exchange_service import LiveExchangeService
10 |
11 |
class TestExchangeServiceFactory:
    """Tests for ExchangeServiceFactory.create_exchange_service."""

    @pytest.fixture
    def config_manager(self):
        # ConfigManager double that defaults to live trading on Binance.
        manager = Mock(spec=ConfigManager)
        manager.get_trading_mode.return_value = TradingMode.LIVE
        manager.get_exchange_name.return_value = "binance"
        return manager

    @patch("core.services.live_exchange_service.ccxtpro")
    @patch("core.services.live_exchange_service.getattr")
    def test_create_live_exchange_service_with_env_vars(self, mock_getattr, mock_ccxtpro, config_manager, monkeypatch):
        credentials = {
            "EXCHANGE_API_KEY": "test_api_key",
            "EXCHANGE_SECRET_KEY": "test_secret_key",
        }
        for env_name, env_value in credentials.items():
            monkeypatch.setenv(env_name, env_value)

        exchange_class = mock_ccxtpro.binance
        exchange_class.return_value = Mock()
        mock_getattr.return_value = exchange_class

        service = ExchangeServiceFactory.create_exchange_service(config_manager, TradingMode.LIVE)

        assert isinstance(service, LiveExchangeService), "Expected a LiveExchangeService instance"
        mock_getattr.assert_called_once_with(mock_ccxtpro, "binance")
        expected_client_options = {
            "apiKey": "test_api_key",
            "secret": "test_secret_key",
            "enableRateLimit": True,
        }
        exchange_class.assert_called_once_with(expected_client_options)

    @patch("core.services.live_exchange_service.ccxtpro")
    @patch("core.services.live_exchange_service.getattr")
    def test_create_backtest_exchange_service(self, mock_getattr, mock_ccxtpro, config_manager):
        config_manager.get_trading_mode.return_value = TradingMode.BACKTEST

        service = ExchangeServiceFactory.create_exchange_service(config_manager, TradingMode.BACKTEST)

        assert isinstance(service, BacktestExchangeService), "Expected a BacktestExchangeService instance"

    def test_invalid_trading_mode(self, config_manager):
        bogus_mode = "invalid_mode"
        config_manager.get_trading_mode.return_value = bogus_mode
        with pytest.raises(ValueError, match="Unsupported trading mode: invalid_mode"):
            ExchangeServiceFactory.create_exchange_service(config_manager, bogus_mode)
53 |
--------------------------------------------------------------------------------
/utils/performance_results_saver.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 | import json
3 | import logging
4 | import os
5 | from typing import Any
6 |
7 | import pandas as pd
8 |
9 |
10 | def save_or_append_performance_results(
11 | new_results: dict[str, Any],
12 | file_path: str,
13 | ) -> None:
14 | """
15 | Saves or appends performance results to a JSON file.
16 |
17 | Args:
18 | new_results: Dictionary containing performance summary and orders.
19 | file_path: Path to the JSON file.
20 | """
21 | try:
22 | if os.path.exists(file_path):
23 | with open(file_path) as json_file:
24 | try:
25 | all_results = json.load(json_file)
26 | if not isinstance(all_results, list):
27 | logging.error(f"Existing file {file_path} is not a valid JSON list. Overwriting the file.")
28 | all_results = []
29 | except json.JSONDecodeError:
30 | logging.warning(f"Could not decode JSON from {file_path}. Overwriting the file.")
31 | all_results = []
32 | else:
33 | all_results = []
34 |
35 | cleaned_performance_summary = {
36 | key: (
37 | value.isoformat()
38 | if isinstance(value, datetime | pd.Timestamp)
39 | else str(value)
40 | if isinstance(value, timedelta)
41 | else value
42 | )
43 | for key, value in new_results.get("performance_summary").items()
44 | }
45 |
46 | order_keys = ["Order Side", "Type", "Status", "Price", "Quantity", "Timestamp", "Grid Level", "Slippage"]
47 | cleaned_orders = [
48 | {
49 | key: (value.isoformat() if isinstance(value, datetime | pd.Timestamp) else value)
50 | for key, value in zip(order_keys, order, strict=False)
51 | }
52 | for order in new_results.get("orders")
53 | ]
54 |
55 | cleaned_results = {
56 | "config": new_results.get("config"),
57 | "performance_summary": cleaned_performance_summary,
58 | "orders": cleaned_orders,
59 | }
60 | all_results.append(cleaned_results)
61 |
62 | with open(file_path, "w") as json_file:
63 | json.dump(all_results, json_file, indent=4)
64 |
65 | logging.info(f"Performance metrics saved to {file_path}")
66 |
67 | except OSError as e:
68 | logging.error(f"Failed to save performance metrics to {file_path}: {e}")
69 |
70 | except Exception as e:
71 | logging.error(f"An unexpected error occurred while saving performance metrics: {e}")
72 |
--------------------------------------------------------------------------------
/tests/order_handling/test_order_execution_strategy_factory.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import pytest
4 |
5 | from config.config_manager import ConfigManager
6 | from config.trading_mode import TradingMode
7 | from core.order_handling.execution_strategy.backtest_order_execution_strategy import (
8 | BacktestOrderExecutionStrategy,
9 | )
10 | from core.order_handling.execution_strategy.live_order_execution_strategy import (
11 | LiveOrderExecutionStrategy,
12 | )
13 | from core.order_handling.execution_strategy.order_execution_strategy_factory import (
14 | OrderExecutionStrategyFactory,
15 | )
16 | from core.services.exchange_interface import ExchangeInterface
17 |
18 |
class TestOrderExecutionStrategyFactory:
    """Tests for OrderExecutionStrategyFactory.create across trading modes."""

    @pytest.fixture
    def config_manager(self):
        return Mock(spec=ConfigManager)

    @pytest.fixture
    def exchange_service(self):
        return Mock(spec=ExchangeInterface)

    @staticmethod
    def _create_for_mode(config_manager, exchange_service, mode):
        """Configure the trading mode on the config double and build a strategy."""
        config_manager.get_trading_mode.return_value = mode
        return OrderExecutionStrategyFactory.create(config_manager, exchange_service)

    def test_create_live_strategy(self, config_manager, exchange_service):
        strategy = self._create_for_mode(config_manager, exchange_service, TradingMode.LIVE)

        assert isinstance(
            strategy,
            LiveOrderExecutionStrategy,
        ), "Expected LiveOrderExecutionStrategy instance for live trading mode"
        assert (
            strategy.exchange_service == exchange_service
        ), "Expected exchange_service to be set correctly in LiveOrderExecutionStrategy"

    def test_create_paper_trading_strategy(self, config_manager, exchange_service):
        strategy = self._create_for_mode(config_manager, exchange_service, TradingMode.PAPER_TRADING)

        assert isinstance(
            strategy,
            LiveOrderExecutionStrategy,
        ), "Expected LiveOrderExecutionStrategy instance for paper trading mode"
        assert (
            strategy.exchange_service == exchange_service
        ), "Expected exchange_service to be set correctly in LiveOrderExecutionStrategy"

    def test_create_backtest_strategy(self, config_manager, exchange_service):
        strategy = self._create_for_mode(config_manager, exchange_service, TradingMode.BACKTEST)

        assert isinstance(
            strategy,
            BacktestOrderExecutionStrategy,
        ), "Expected BacktestOrderExecutionStrategy instance for backtesting mode"

    def test_invalid_trading_mode_raises_error(self, config_manager, exchange_service):
        # A plain string stands in for an unknown trading mode.
        with pytest.raises(ValueError, match="Unknown trading mode: UNKNOWN_MODE"):
            self._create_for_mode(config_manager, exchange_service, "UNKNOWN_MODE")
65 |
--------------------------------------------------------------------------------
/tests/utils/test_performance_results_saver.py:
--------------------------------------------------------------------------------
1 | from datetime import UTC, datetime, timedelta
2 | from unittest.mock import mock_open, patch
3 |
4 | import pytest
5 |
6 | from utils.performance_results_saver import save_or_append_performance_results
7 |
8 |
@pytest.fixture
def new_results_fixture():
    """One run's results: a two-hour summary plus one filled buy and one filled sell order."""

    def iso_on_dec_20(hour, minute=0):
        # All timestamps fall on 2024-12-20 in UTC, pre-rendered as ISO strings.
        return datetime(2024, 12, 20, hour, minute, 0, tzinfo=UTC).isoformat()

    summary = {
        "start_time": iso_on_dec_20(10),
        "end_time": iso_on_dec_20(12),
        "total_profit": 500.0,
        "runtime": str(timedelta(hours=2)),
    }
    buy_row = ["BUY", "LIMIT", "FILLED", 1000.0, 0.5, iso_on_dec_20(10, 30), "Level 1", 0.1]
    sell_row = ["SELL", "LIMIT", "FILLED", 1500.0, 0.5, iso_on_dec_20(11, 30), "Level 2", 0.05]

    return {
        "config": "config.json",
        "performance_summary": summary,
        "orders": [buy_row, sell_row],
    }
42 |
43 |
def test_save_or_append_performance_results_invalid_json(new_results_fixture):
    # An existing but undecodable file is read, reported with a warning, then rewritten.
    corrupt_file = mock_open(read_data="INVALID_JSON")
    with (
        patch("builtins.open", corrupt_file) as mocked_file,
        patch("os.path.exists", return_value=True),
        patch("utils.performance_results_saver.logging.warning") as mock_logger_warning,
    ):
        save_or_append_performance_results(new_results_fixture, "results.json")

    mocked_file.assert_any_call("results.json")
    mocked_file.assert_any_call("results.json", "w")
    mock_logger_warning.assert_called_once_with("Could not decode JSON from results.json. Overwriting the file.")
55 |
56 |
def test_save_or_append_performance_results_os_error(new_results_fixture):
    # Any OSError from open() is reported through logging.error instead of propagating.
    failing_open = patch("builtins.open", side_effect=OSError("Test OS Error"))
    error_spy = patch("utils.performance_results_saver.logging.error")
    with failing_open, error_spy as mock_logger_error:
        save_or_append_performance_results(new_results_fixture, "results.json")

    mock_logger_error.assert_called_once_with("Failed to save performance metrics to results.json: Test OS Error")
65 |
66 |
def test_save_or_append_performance_results_unexpected_exception(new_results_fixture):
    # Non-OSError failures hit the generic handler and are logged, not raised.
    failing_open = patch("builtins.open", side_effect=Exception("Unexpected Error"))
    error_spy = patch("utils.performance_results_saver.logging.error")
    with failing_open, error_spy as mock_logger_error:
        save_or_append_performance_results(new_results_fixture, "results.json")

    expected_message = "An unexpected error occurred while saving performance metrics: Unexpected Error"
    mock_logger_error.assert_any_call(expected_message)
77 |
--------------------------------------------------------------------------------
/tests/utils/test_logging_config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from unittest.mock import MagicMock, patch
3 |
4 | import pytest
5 |
6 | from utils.logging_config import setup_logging
7 |
8 |
@pytest.fixture
def mock_makedirs():
    """Replace os.makedirs with a mock for the duration of the test."""
    patcher = patch("os.makedirs")
    replacement = patcher.start()
    try:
        yield replacement
    finally:
        patcher.stop()
13 |
14 |
@pytest.fixture
def mock_basic_config():
    """Replace logging.basicConfig with a mock for the duration of the test."""
    patcher = patch("logging.basicConfig")
    replacement = patcher.start()
    try:
        yield replacement
    finally:
        patcher.stop()
19 |
20 |
@pytest.fixture
def mock_rotating_file_handler():
    """Replace ``RotatingFileHandler`` so no log file is opened; instances are plain MagicMocks."""
    handler_instance = MagicMock()
    with patch("logging.handlers.RotatingFileHandler", return_value=handler_instance) as handler_class:
        yield handler_class
26 |
27 |
def test_setup_logging_console_only(mock_basic_config):
    """With file logging disabled, basicConfig receives exactly one console StreamHandler."""
    setup_logging(log_level=logging.INFO, log_to_file=False)

    mock_basic_config.assert_called_once()
    configured = mock_basic_config.call_args.kwargs["handlers"]
    assert len(configured) == 1
    (console_handler,) = configured
    assert isinstance(console_handler, logging.StreamHandler)
35 |
36 |
def test_setup_logging_file_logging(mock_rotating_file_handler, mock_basic_config, mock_makedirs):
    """Explicit file-logging settings create the log directory and configure two handlers."""
    setup_logging(
        log_level=logging.DEBUG,
        log_to_file=True,
        config_name="test_config",
        max_file_size=10_000_000,
        backup_count=3,
    )

    mock_makedirs.assert_called_once_with("logs", exist_ok=True)
    mock_basic_config.assert_called_once()

    configured = mock_basic_config.call_args.kwargs["handlers"]
    assert len(configured) == 2
    assert any(isinstance(candidate, logging.StreamHandler) for candidate in configured)
55 |
56 |
def test_setup_logging_default_file_logging(mock_rotating_file_handler, mock_basic_config, mock_makedirs):
    """File logging with default settings still creates the log directory and two handlers."""
    setup_logging(log_level=logging.WARNING, log_to_file=True)

    mock_makedirs.assert_called_once_with("logs", exist_ok=True)
    mock_basic_config.assert_called_once()

    configured = mock_basic_config.call_args.kwargs["handlers"]
    assert len(configured) == 2
    assert any(isinstance(candidate, logging.StreamHandler) for candidate in configured)
69 |
70 |
def test_setup_logging_logs_info(mock_basic_config, mock_rotating_file_handler, mock_makedirs, caplog):
    """
    setup_logging announces the active log level and the log-file location at INFO level.

    ``mock_makedirs`` is requested because enabling file logging makes setup_logging call
    ``os.makedirs("logs", exist_ok=True)``; without the mock the test would create a real
    ``logs`` directory in the working directory.
    """
    with caplog.at_level(logging.INFO):
        setup_logging(log_level=logging.INFO, log_to_file=True, config_name="test_config")

    assert "Logging initialized. Log level: INFO" in caplog.text
    assert "File logging enabled. Logs are stored in: logs/test_config.log" in caplog.text
    mock_makedirs.assert_called_once_with("logs", exist_ok=True)
77 |
78 |
def test_setup_logging_directory_creation_error(mock_makedirs):
    """A failure to create the log directory propagates unchanged to the caller."""
    creation_failure = OSError("Directory creation failed")
    mock_makedirs.side_effect = creation_failure

    with pytest.raises(OSError, match="Directory creation failed"):
        setup_logging(log_level=logging.DEBUG, log_to_file=True)
84 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "grid_trading_bot"
3 | version = "0.1.0"
4 | description = "Open-source cryptocurrency trading bot that runs grid trading strategies in backtest, paper-trading and live modes."
5 | authors = [{ name = "Jordan TETE", email = "tetej171@gmail.com" }]
6 | readme = "README.md"
7 | license = { file = "LICENSE.txt" }
8 | requires-python = ">=3.12"
9 | dependencies = [
10 | "pandas==2.2.3",
11 | "numpy==2.1.3",
12 | "plotly==6.4.0",
13 | "tabulate==0.9.0",
14 | "aiohttp==3.10.11",
15 | "apprise==1.9.5",
16 | "ccxt==4.4.82",
17 | "configparser==7.2.0",
18 | "psutil==7.1.3",
19 | "python-dotenv==1.2.1"
20 | ]
21 |
22 | [project.optional-dependencies]
23 | dev = [
24 | "pytest==8.4.2",
25 | "pytest-asyncio==0.26.0",
26 | "pytest-cov==7.0.0",
27 | "pytest-timeout==2.4.0",
28 | "pre-commit==4.4.0",
29 | ]
30 |
31 | [project.urls]
32 | repository = "https://github.com/jordantete/grid_trading_bot"
33 | issues = "https://github.com/jordantete/grid_trading_bot/issues"
34 | discussions = "https://github.com/jordantete/grid_trading_bot/discussions"
35 |
36 | [tool.pytest.ini_options]
37 | asyncio_mode = "auto"
38 | testpaths = ["tests"]
39 | python_files = ["test_*.py"]
40 | python_classes = ["Test*"]
41 | python_functions = ["test_*"]
42 | log_cli = true
43 | log_cli_level = "INFO"
44 | markers = [
45 | "asyncio: mark a test as an async test",
46 | "timeout: mark a test with a timeout"
47 | ]
48 |
49 | [tool.coverage.run]
50 | omit = [
51 | "*/interface*.py"
52 | ]
53 |
54 | [tool.ruff]
55 | line-length = 120
56 | target-version = "py312"
57 | src = ["core", "strategies", "config", "utils", "tests"]
58 |
59 | [tool.ruff.lint]
60 | select = [
61 | "E", # pycodestyle errors
62 | "W", # pycodestyle warnings
63 | "F", # Pyflakes
64 | "I", # isort
65 | "B", # flake8-bugbear
66 | "C4", # flake8-comprehensions
67 | "UP", # pyupgrade
68 | "N", # pep8-naming
69 | "S", # flake8-bandit (security)
70 | "T20", # flake8-print
71 | "SIM", # flake8-simplify
72 | "RUF", # Ruff-specific rules
73 | "PT", # flake8-pytest-style
74 | "Q", # flake8-quotes
75 | "A", # flake8-builtins
76 | "COM", # flake8-commas
77 | "DTZ", # flake8-datetimez
78 | "TCH", # flake8-type-checking
79 | ]
80 | ignore = [
81 | "S101", # assert used (OK in tests)
82 | "COM812", # missing trailing comma (handled by formatter)
83 | ]
84 |
85 | [tool.ruff.format]
86 | quote-style = "double"
87 | indent-style = "space"
88 | skip-magic-trailing-comma = false
89 | line-ending = "auto"
90 | docstring-code-format = true
91 |
92 | [tool.ruff.lint.isort]
93 | known-first-party = ["core", "strategies", "config", "utils"]
94 | force-sort-within-sections = true
95 |
96 | [build-system]
97 | requires = ["setuptools>=42", "wheel"]
98 | build-backend = "setuptools.build_meta"
99 |
100 | [tool.setuptools.packages.find]
101 | where = ["."]
102 | exclude = ["logs", "data"]
103 |
104 | [tool.setuptools]
105 | include-package-data = true
106 |
--------------------------------------------------------------------------------
/tests/order_handling/test_backtest_order_execution_strategy.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | import pytest
4 |
5 | from core.order_handling.execution_strategy.backtest_order_execution_strategy import (
6 | BacktestOrderExecutionStrategy,
7 | )
8 | from core.order_handling.order import OrderSide, OrderStatus, OrderType
9 |
10 |
@pytest.mark.asyncio
class TestBacktestOrderExecutionStrategy:
    """Tests for the simulated order execution strategy used in backtests."""

    @pytest.fixture
    def setup_strategy(self):
        """A fresh backtest execution strategy for every test."""
        return BacktestOrderExecutionStrategy()

    @staticmethod
    def _check_order(order, **expected):
        """Assert that ``order`` exists and that each keyword names an attribute holding the given value."""
        assert order is not None
        for attr_name, expected_value in expected.items():
            assert getattr(order, attr_name) == expected_value, attr_name

    @patch("time.time", return_value=1680000000)
    async def test_execute_market_order(self, mock_time, setup_strategy):
        side, pair, quantity, price = OrderSide.BUY, "BTC/USDT", 0.5, 30000

        order = await setup_strategy.execute_market_order(side, pair, quantity, price)

        # The identifier embeds the patched clock; the timestamp is that clock in milliseconds.
        self._check_order(
            order,
            identifier="backtest-1680000000",
            status=OrderStatus.OPEN,
            order_type=OrderType.MARKET,
            side=side,
            price=price,
            amount=quantity,
            filled=quantity,
            remaining=0,
            symbol=pair,
            time_in_force="GTC",
            timestamp=1680000000000,
        )

    @patch("time.time", return_value=1680000000)
    async def test_execute_limit_order(self, mock_time, setup_strategy):
        side, pair, quantity, price = OrderSide.SELL, "ETH/USDT", 1, 2000

        order = await setup_strategy.execute_limit_order(side, pair, quantity, price)

        # Limit orders start unfilled and carry a zero timestamp.
        self._check_order(
            order,
            identifier="backtest-1680000000",
            status=OrderStatus.OPEN,
            order_type=OrderType.LIMIT,
            side=side,
            price=price,
            amount=quantity,
            filled=0,
            remaining=quantity,
            symbol=pair,
            time_in_force="GTC",
            timestamp=0,
        )

    async def test_get_order(self, setup_strategy):
        order_id, order_symbol = "test-order-1", "BTC/USDT"

        order = await setup_strategy.get_order(order_id, order_symbol)

        self._check_order(
            order,
            identifier=order_id,
            status=OrderStatus.OPEN,
            order_type=OrderType.LIMIT,
            side=OrderSide.BUY,
            price=100,
            average=100,
            amount=1,
            filled=1,
            remaining=0,
            symbol=order_symbol,
            time_in_force="GTC",
            timestamp=0,
        )
83 |
--------------------------------------------------------------------------------
/core/bot_management/notification/notification_handler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from concurrent.futures import ThreadPoolExecutor
3 | import logging
4 |
5 | import apprise
6 |
7 | from config.trading_mode import TradingMode
8 | from core.bot_management.event_bus import EventBus, Events
9 | from core.order_handling.order import Order
10 |
11 | from .notification_content import NotificationType
12 |
13 |
class NotificationHandler:
    """
    Handles sending notifications through various channels using the Apprise library.
    Supports multiple notification services like Telegram, Discord, Slack, etc.

    Notifications are enabled only when at least one URL is supplied and the trading mode is
    LIVE or PAPER_TRADING; otherwise every send call returns without doing anything.
    """

    # Shared by all instances: ``send_notification`` blocks while Apprise delivers, so the async
    # wrapper runs it on this pool instead of on the event-loop thread.
    _executor = ThreadPoolExecutor(max_workers=3)

    # Seconds ``async_send_notification`` waits for a delivery before logging a timeout.
    _SEND_TIMEOUT_SECONDS = 5

    def __init__(
        self,
        event_bus: EventBus,
        urls: list[str] | None,
        trading_mode: TradingMode,
    ):
        """
        Args:
            event_bus: Bus on which ``Events.ORDER_FILLED`` is subscribed when notifications are enabled.
            urls: Apprise service URLs; ``None`` or an empty list disables notifications.
            trading_mode: Current trading mode; only LIVE and PAPER_TRADING enable notifications.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.event_bus = event_bus
        self.enabled = bool(urls) and trading_mode in {TradingMode.LIVE, TradingMode.PAPER_TRADING}
        # Serializes calls to ``async_send_notification`` on this handler.
        self.lock = asyncio.Lock()
        self.apprise_instance = apprise.Apprise() if self.enabled else None

        if self.enabled and urls is not None:
            self.event_bus.subscribe(Events.ORDER_FILLED, self._send_notification_on_order_filled)

            for url in urls:
                self.apprise_instance.add(url)

    @staticmethod
    def _extract_placeholders(template: str) -> set[str]:
        """
        Return the keyword names referenced by the replacement fields of a ``str.format`` template.

        ``string.Formatter.parse`` is used instead of whitespace splitting so that placeholders
        followed by punctuation (``"{price}."``) or adjacent to each other (``"{a}{b}"``) are
        recognised, and escaped braces (``"{{"``) are ignored. Attribute/index suffixes are
        dropped (``"{order.id}"`` -> ``"order"``) because ``str.format`` only looks up the leading
        name among its keyword arguments.
        """
        from string import Formatter

        names = set()
        for _literal, field_name, _spec, _conversion in Formatter().parse(template):
            if field_name:
                names.add(field_name.split(".", 1)[0].split("[", 1)[0])
        return names

    def send_notification(
        self,
        content: NotificationType | str,
        **kwargs,
    ) -> None:
        """
        Send a notification synchronously (blocks until Apprise returns).

        Does nothing when notifications are disabled.

        Args:
            content: A ``NotificationType`` whose ``value`` provides ``title`` and a ``str.format``
                message template, or a plain string sent with the title "Notification".
            **kwargs: Values for the template's placeholders; missing ones are rendered as "N/A".
        """
        if not (self.enabled and self.apprise_instance):
            return

        if isinstance(content, NotificationType):
            title = content.value.title
            message_template = content.value.message
            required_placeholders = self._extract_placeholders(message_template)
            missing_placeholders = required_placeholders - kwargs.keys()

            if missing_placeholders:
                self.logger.warning(
                    f"Missing placeholders for notification: {missing_placeholders}. "
                    "Defaulting to 'N/A' for missing values.",
                )

            message = message_template.format(**{key: kwargs.get(key, "N/A") for key in required_placeholders})
        else:
            title = "Notification"
            message = content

        # Apprise signals delivery problems with a falsy return value rather than an exception.
        if not self.apprise_instance.notify(title=title, body=message):
            self.logger.warning(f"Notification '{title}' could not be delivered by one or more services.")

    async def async_send_notification(
        self,
        content: NotificationType | str,
        **kwargs,
    ) -> None:
        """
        Send a notification without blocking the event loop.

        Runs ``send_notification`` on the shared thread pool and waits up to
        ``_SEND_TIMEOUT_SECONDS``. Failures, including timeouts, are logged instead of raised.
        On timeout the worker thread is not interrupted; it keeps running in the background.
        """
        async with self.lock:
            loop = asyncio.get_running_loop()
            try:
                await asyncio.wait_for(
                    loop.run_in_executor(self._executor, lambda: self.send_notification(content, **kwargs)),
                    timeout=self._SEND_TIMEOUT_SECONDS,
                )
            except TimeoutError:
                # str(TimeoutError()) is empty, so the generic message below would carry no detail.
                self.logger.error(
                    f"Failed to send notification: timed out (limit: {self._SEND_TIMEOUT_SECONDS}s)",
                )
            except Exception as e:
                self.logger.error(f"Failed to send notification: {e!s}")

    async def _send_notification_on_order_filled(self, order: Order) -> None:
        """Event-bus callback for ``Events.ORDER_FILLED``: forward the order's string form."""
        await self.async_send_notification(NotificationType.ORDER_FILLED, order_details=str(order))
84 |
--------------------------------------------------------------------------------
/core/order_handling/order.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import pandas as pd
4 |
5 |
6 | class OrderSide(Enum):
7 | BUY = "buy"
8 | SELL = "sell"
9 |
10 |
11 | class OrderType(Enum):
12 | MARKET = "market"
13 | LIMIT = "limit"
14 |
15 |
16 | class OrderStatus(Enum):
17 | OPEN = "open"
18 | CLOSED = "closed"
19 | CANCELED = "canceled"
20 | EXPIRED = "expired"
21 | REJECTED = "rejected"
22 | UNKNOWN = "unknown"
23 |
24 |
25 | class Order:
26 | def __init__(
27 | self,
28 | identifier: str,
29 | status: OrderStatus,
30 | order_type: OrderType,
31 | side: OrderSide,
32 | price: float,
33 | average: float | None,
34 | amount: float,
35 | filled: float,
36 | remaining: float,
37 | timestamp: int,
38 | datetime: str | None,
39 | last_trade_timestamp: int | None,
40 | symbol: str,
41 | time_in_force: str | None,
42 | trades: list[dict[str, str | float]] | None = None,
43 | fee: dict[str, str | float] | None = None,
44 | cost: float | None = None,
45 | info: dict[str, str | float | dict] | None = None,
46 | ):
47 | self.identifier = identifier
48 | self.status = status # 'open', 'closed', 'canceled', 'expired', 'rejected'
49 | self.order_type = order_type # 'market', 'limit'
50 | self.side = side # 'buy', 'sell'
51 | self.price = price # float price in quote currency (may be empty for market orders)
52 | self.average = average # float average filling price
53 | self.amount = amount # ordered amount of base currency
54 | self.filled = filled # filled amount of base currency
55 | self.remaining = remaining # remaining amount to fill
56 | self.timestamp = timestamp # order placing/opening Unix timestamp in milliseconds
57 | self.datetime = datetime # ISO8601 datetime of 'timestamp' with milliseconds
58 | self.last_trade_timestamp = last_trade_timestamp # Unix timestamp of the most recent trade on this order
59 | self.symbol = symbol # symbol
60 | self.time_in_force = time_in_force # 'GTC', 'IOC', 'FOK', 'PO'
61 | self.trades = trades # a list of order trades/executions
62 | self.fee = fee # fee info, if available
63 | self.cost = cost # 'filled' * 'price' (filling price used where available)
64 | self.info = info # Original unparsed structure for debugging or auditing
65 |
66 | def is_filled(self) -> bool:
67 | return self.status == OrderStatus.CLOSED
68 |
69 | def is_canceled(self) -> bool:
70 | return self.status == OrderStatus.CANCELED
71 |
72 | def is_open(self) -> bool:
73 | return self.status == OrderStatus.OPEN
74 |
75 | def format_last_trade_timestamp(self) -> str | None:
76 | if self.last_trade_timestamp is None:
77 | return None
78 | return pd.Timestamp(self.last_trade_timestamp, unit="s").isoformat()
79 |
80 | def __str__(self) -> str:
81 | return (
82 | f"Order(id={self.identifier}, status={self.status}, "
83 | f"type={self.order_type}, side={self.side}, price={self.price}, average={self.average}, "
84 | f"amount={self.amount}, filled={self.filled}, remaining={self.remaining}, "
85 | f"timestamp={self.timestamp}, datetime={self.datetime}, symbol={self.symbol}, "
86 | f"time_in_force={self.time_in_force}, trades={self.trades}, fee={self.fee}, cost={self.cost})"
87 | )
88 |
89 | def __repr__(self) -> str:
90 | return self.__str__()
91 |
--------------------------------------------------------------------------------
/monitoring/dashboards/grid_trading_bot_dashboard.json:
--------------------------------------------------------------------------------
1 | {
2 | "title": "Grid Trading Bot Dashboard",
3 | "variables": [
4 | {
5 | "name": "trading_pair",
6 | "type": "query",
7 | "datasource": "Loki",
8 | "query": "label_values(trading_pair)"
9 | },
10 | {
11 | "name": "trading_mode",
12 | "type": "query",
13 | "datasource": "Loki",
14 | "query": "label_values(trading_mode)"
15 | },
16 | {
17 | "name": "strategy",
18 | "type": "query",
19 | "datasource": "Loki",
20 | "query": "label_values(strategy_type)"
21 | }
22 | ],
23 | "panels": [
24 | {
25 | "title": "Strategy Overview",
26 | "type": "stat",
27 | "datasource": "Loki",
28 | "targets": [
29 | {
30 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | json | line_format \"{{.grid_size}} grids, Range: {{.grid_range}}, Spacing: {{.spacing_type}}\""
31 | }
32 | ]
33 | },
34 | {
35 | "title": "ROI Over Time",
36 | "type": "timeseries",
37 | "datasource": "Loki",
38 | "targets": [
39 | {
40 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | regexp \"ROI\\s+\\|\\s+(?P<roi>[\\-\\d\\.]+)%\" | unwrap roi"
41 | }
42 | ]
43 | },
44 | {
45 | "title": "Grid Level States",
46 | "type": "table",
47 | "datasource": "Loki",
48 | "targets": [
49 | {
50 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | json | grid_price != \"\" and grid_state != \"\" | line_format \"{{.grid_price}} - {{.grid_state}}\""
51 | }
52 | ]
53 | },
54 | {
55 | "title": "Order Flow",
56 | "type": "timeseries",
57 | "datasource": "Loki",
58 | "targets": [
59 | {
60 | "expr": "sum(count_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\", order_side=\"BUY\"}[5m]))",
61 | "legendFormat": "Buy Orders"
62 | },
63 | {
64 | "expr": "sum(count_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\", order_side=\"SELL\"}[5m]))",
65 | "legendFormat": "Sell Orders"
66 | }
67 | ]
68 | },
69 | {
70 | "title": "Balance History",
71 | "type": "timeseries",
72 | "datasource": "Loki",
73 | "targets": [
74 | {
75 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | regexp \"Balance: (?P<balance>[\\d\\.]+)\" | unwrap balance"
76 | }
77 | ]
78 | },
79 | {
80 | "title": "System Health",
81 | "type": "gauge",
82 | "datasource": "Loki",
83 | "targets": [
84 | {
85 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\"} | json | unwrap cpu"
86 | }
87 | ],
88 | "fieldConfig": {
89 | "defaults": {
90 | "thresholds": {
91 | "steps": [
92 | { "value": 0, "color": "green" },
93 | { "value": 70, "color": "yellow" },
94 | { "value": 85, "color": "red" }
95 | ]
96 | }
97 | }
98 | }
99 | }
100 | ],
101 | "refresh": "10s",
102 | "schemaVersion": 36
103 | }
104 |
--------------------------------------------------------------------------------
/utils/arg_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import traceback
5 |
6 |
def validate_args(args):
    """
    Check that the paths referenced by parsed CLI arguments are usable.

    Args:
        args: Parsed arguments exposing ``config`` (list of paths, or a falsy value) and
            ``save_performance_results`` (a path, or a falsy value).
    Raises:
        ValueError: If a config file does not exist, or the directory meant to hold the
            performance results does not exist.
    """
    # A falsy ``config`` means there is nothing to check.
    for config_path in args.config or ():
        if not os.path.exists(config_path):
            raise ValueError(f"Config file does not exist: {config_path}")

    results_path = args.save_performance_results
    if not results_path:
        return

    # A bare file name has an empty dirname (the working directory), which is not checked.
    results_dir = os.path.dirname(results_path)
    if results_dir and not os.path.exists(results_dir):
        raise ValueError(f"The directory for saving performance results does not exist: {results_dir}")


def _build_console_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing the bot's command-line interface."""
    parser = argparse.ArgumentParser(
        description="📈 Spot Grid Trading Bot - Automate your grid trading strategy with confidence\n\n"
        "This bot lets you automate your trading by implementing a grid strategy. "
        "Set your parameters, watch it execute, and manage your trades more effectively. "
        "Ideal for both beginners and experienced traders!",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    required_group = parser.add_argument_group("Required Arguments")
    required_group.add_argument(
        "--config",
        type=str,
        nargs="+",
        required=True,
        metavar="CONFIG",
        help="Path(s) to the configuration file(s) containing strategy details.",
    )

    optional_group = parser.add_argument_group("Optional Arguments")
    optional_group.add_argument(
        "--save_performance_results",
        type=str,
        metavar="FILE",
        help="Path to save simulation results (e.g., results.json).",
    )
    optional_group.add_argument(
        "--no-plot",
        action="store_true",
        help="Disable the display of plots at the end of the simulation.",
    )
    optional_group.add_argument(
        "--profile",
        action="store_true",
        help="Enable profiling for performance analysis.",
    )
    return parser


def parse_and_validate_console_args(cli_args=None):
    """
    Parse the command line and validate the referenced paths.

    Args:
        cli_args: Argument list to parse instead of ``sys.argv[1:]`` (useful in tests).
    Returns:
        argparse.Namespace: The parsed and validated arguments.
    Raises:
        SystemExit: Re-raised unchanged when argparse exits with code 0 (``--help``).
        RuntimeError: If parsing fails, validation fails, or anything unexpected goes wrong;
            the original exception is chained as ``__cause__``.
    """
    try:
        parsed_args = _build_console_parser().parse_args(cli_args)
        validate_args(parsed_args)
    except SystemExit as exit_error:
        # Exit code 0 means argparse printed help on request; let the program end normally.
        if exit_error.code == 0:
            raise
        logging.error(f"Argument parsing failed: {exit_error}")
        raise RuntimeError("Failed to parse arguments. Please check your inputs.") from exit_error
    except ValueError as validation_error:
        logging.error(f"Validation failed: {validation_error}")
        raise RuntimeError("Argument validation failed.") from validation_error
    except Exception as unexpected_error:
        logging.error(f"An unexpected error occurred while parsing arguments: {unexpected_error}")
        logging.error(traceback.format_exc())
        raise RuntimeError("An unexpected error occurred during argument parsing.") from unexpected_error

    return parsed_args
95 |
--------------------------------------------------------------------------------
/tests/bot_management/test_event_bus.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | from unittest.mock import AsyncMock, Mock
4 |
5 | import pytest
6 |
7 | from core.bot_management.event_bus import EventBus, Events
8 |
9 |
class TestEventBus:
    """Tests for EventBus subscription, async/sync publishing and isolation of failing callbacks."""

    @pytest.fixture
    def event_bus(self):
        """A fresh EventBus for each test."""
        return EventBus()

    def test_subscribe(self, event_bus):
        listener = Mock()

        event_bus.subscribe(Events.ORDER_FILLED, listener)

        registry = event_bus.subscribers
        assert Events.ORDER_FILLED in registry
        assert listener in registry[Events.ORDER_FILLED]

    @pytest.mark.asyncio
    async def test_publish_async_single_callback(self, event_bus):
        listener = AsyncMock()
        event_bus.subscribe(Events.ORDER_FILLED, listener)

        await event_bus.publish(Events.ORDER_FILLED, {"data": "test"})

        listener.assert_awaited_once_with({"data": "test"})

    @pytest.mark.asyncio
    async def test_publish_async_multiple_callbacks(self, event_bus):
        listeners = [AsyncMock(), AsyncMock()]
        for listener in listeners:
            event_bus.subscribe(Events.ORDER_FILLED, listener)

        await event_bus.publish(Events.ORDER_FILLED, {"data": "test"})

        for listener in listeners:
            listener.assert_awaited_once_with({"data": "test"})

    @pytest.mark.asyncio
    async def test_publish_async_with_exception(self, event_bus, caplog):
        broken_listener = AsyncMock(side_effect=Exception("Test Error"))
        event_bus.subscribe(Events.ORDER_FILLED, broken_listener)

        await event_bus.publish(Events.ORDER_FILLED, {"data": "test"})
        # Let every task the bus scheduled finish so its error log is captured.
        await asyncio.gather(*event_bus._tasks, return_exceptions=True)

        assert "Error in async callback 'AsyncMock'" in caplog.text
        assert "Test Error" in caplog.text

    def test_publish_sync(self, event_bus):
        listener = Mock()
        event_bus.subscribe(Events.ORDER_FILLED, listener)

        event_bus.publish_sync(Events.ORDER_FILLED, {"data": "test"})

        listener.assert_called_once_with({"data": "test"})

    @pytest.mark.asyncio
    async def test_safe_invoke_async(self, event_bus, caplog):
        listener = AsyncMock()

        await event_bus._safe_invoke_async(listener, {"data": "test"})
        # Wait for any tasks the bus created before checking the await.
        await asyncio.gather(*event_bus._tasks, return_exceptions=True)

        listener.assert_awaited_once_with({"data": "test"})

    @pytest.mark.asyncio
    async def test_safe_invoke_async_with_exception(self, event_bus, caplog):
        caplog.set_level(logging.DEBUG)
        broken_listener = AsyncMock(side_effect=Exception("Async Error"))

        await event_bus._safe_invoke_async(broken_listener, {"data": "test"})
        await asyncio.gather(*event_bus._tasks, return_exceptions=True)

        for expected_fragment in ("Error in async callback", "Async Error", "Task created for callback"):
            assert expected_fragment in caplog.text

    def test_safe_invoke_sync(self, event_bus, caplog):
        listener = Mock()

        event_bus._safe_invoke_sync(listener, {"data": "test"})

        listener.assert_called_once_with({"data": "test"})

    def test_safe_invoke_sync_with_exception(self, event_bus, caplog):
        broken_listener = Mock(side_effect=Exception("Sync Error"))

        event_bus._safe_invoke_sync(broken_listener, {"data": "test"})

        assert "Error in sync subscriber callback" in caplog.text
89 |
--------------------------------------------------------------------------------
/tests/utils/test_arg_parser.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | from utils.arg_parser import parse_and_validate_console_args
7 |
8 |
@pytest.mark.parametrize(
    ("args", "expected_config"),
    [
        (["--config", "config1.json"], ["config1.json"]),
        (["--config", "config1.json", "config2.json"], ["config1.json", "config2.json"]),
    ],
)
@patch("os.path.exists", return_value=True)  # treat every config path as existing
def test_parse_and_validate_console_args_required(mock_exists, args, expected_config):
    """``--config`` accepts one or more paths and exposes them as a list."""
    argv = ["program_name", *args]
    with patch.object(sys, "argv", argv):
        parsed = parse_and_validate_console_args()

    assert parsed.config == expected_config, f"Expected {expected_config}, got {parsed.config}"
21 |
22 |
@patch("os.path.exists", return_value=True)
def test_parse_and_validate_console_args_save_performance_results_exists(mock_exists):
    """A results path whose directory exists is returned unchanged."""
    argv = ["program_name", "--config", "config.json", "--save_performance_results", "results.json"]
    with patch.object(sys, "argv", argv):
        parsed = parse_and_validate_console_args()

    assert parsed.save_performance_results == "results.json"
32 |
33 |
def test_parse_and_validate_console_args_save_performance_results_dir_does_not_exist():
    """A results path inside a missing directory fails validation and logs the directory name."""

    def only_config_exists(path):
        return path == "config.json"

    argv = ["program_name", "--config", "config.json", "--save_performance_results", "non_existent_dir/results.json"]
    with (
        patch("os.path.exists", side_effect=only_config_exists),
        patch.object(sys, "argv", argv),
        patch("utils.arg_parser.logging.error") as mock_log,
        pytest.raises(RuntimeError, match="Argument validation failed."),
    ):
        parse_and_validate_console_args()

    mock_log.assert_called_once_with(
        "Validation failed: The directory for saving performance results does not exist: non_existent_dir",
    )
49 |
50 |
@patch("os.path.exists", return_value=True)
def test_parse_and_validate_console_args_no_plot(mock_exists):
    """``--no-plot`` is exposed as ``no_plot=True``."""
    argv = ["program_name", "--config", "config.json", "--no-plot"]
    with patch.object(sys, "argv", argv):
        parsed = parse_and_validate_console_args()

    assert hasattr(parsed, "no_plot"), "The `no_plot` attribute is missing from the parsed result."
    assert parsed.no_plot is True, "The `no_plot` flag was not set to True."
58 |
59 |
@patch("os.path.exists", return_value=True)
def test_parse_and_validate_console_args_profile(mock_exists):
    """``--profile`` is exposed as ``profile=True``."""
    argv = ["program_name", "--config", "config.json", "--profile"]
    with patch.object(sys, "argv", argv):
        parsed = parse_and_validate_console_args()

    assert hasattr(parsed, "profile"), "The `profile` attribute is missing from the parsed result."
    assert parsed.profile is True, "The `profile` flag was not set to True."
67 |
68 |
@patch("utils.arg_parser.logging.error")
def test_parse_and_validate_console_args_argument_error(mock_log):
    """``--config`` without a value makes argparse exit with code 2, surfaced as RuntimeError."""
    with (
        patch.object(sys, "argv", ["program_name", "--config"]),
        pytest.raises(RuntimeError, match="Failed to parse arguments. Please check your inputs."),
    ):
        parse_and_validate_console_args()

    mock_log.assert_called_once_with("Argument parsing failed: 2")
75 |
76 |
@patch("utils.arg_parser.logging.error")
def test_parse_and_validate_console_args_unexpected_error(mock_log):
    """Any other exception during validation is logged and wrapped in a generic RuntimeError."""
    argv = ["program_name", "--config", "config.json", "--save_performance_results", "results.json"]
    with (
        patch.object(sys, "argv", argv),
        patch("os.path.exists", side_effect=Exception("Unexpected error")),
        pytest.raises(RuntimeError, match="An unexpected error occurred during argument parsing."),
    ):
        parse_and_validate_console_args()

    mock_log.assert_any_call("An unexpected error occurred while parsing arguments: Unexpected error")
90 |
--------------------------------------------------------------------------------
/core/validation/order_validator.py:
--------------------------------------------------------------------------------
1 | from .exceptions import (
2 | InsufficientBalanceError,
3 | InsufficientCryptoBalanceError,
4 | InvalidOrderQuantityError,
5 | )
6 |
7 |
8 | class OrderValidator:
9 | def __init__(self, tolerance: float = 1e-6, threshold_ratio: float = 0.5):
10 | """
11 | Initializes the OrderValidator with a specified tolerance and threshold.
12 |
13 | Args:
14 | tolerance (float): Minimum precision tolerance for validation.
15 | threshold_ratio (float): Threshold below which an insufficient balance/crypto error is triggered early.
16 | """
17 | self.tolerance = tolerance
18 | self.threshold_ratio = threshold_ratio
19 |
20 | def adjust_and_validate_buy_quantity(self, balance: float, order_quantity: float, price: float) -> float:
21 | """
22 | Adjusts and validates the buy order quantity based on the available balance.
23 |
24 | Args:
25 | balance (float): Available fiat balance.
26 | order_quantity (float): Requested buy quantity.
27 | price (float): Price of the asset.
28 |
29 | Returns:
30 | float: Adjusted and validated buy order quantity.
31 |
32 | Raises:
33 | InsufficientBalanceError: If the balance is insufficient to place any valid order.
34 | InvalidOrderQuantityError: If the adjusted quantity is invalid.
35 | """
36 | total_cost = order_quantity * price
37 |
38 | if balance < total_cost * self.threshold_ratio:
39 | raise InsufficientBalanceError(
40 | f"Balance {balance:.2f} is far below the required cost {total_cost:.2f} "
41 | f"(threshold ratio: {self.threshold_ratio}).",
42 | )
43 |
44 | if total_cost > balance:
45 | adjusted_quantity = max((balance - self.tolerance) / price, 0)
46 |
47 | if adjusted_quantity <= 0 or (adjusted_quantity * price) < self.tolerance:
48 | raise InsufficientBalanceError(
49 | f"Insufficient balance: {balance:.2f} to place any buy order at price {price:.2f}.",
50 | )
51 | else:
52 | adjusted_quantity = order_quantity
53 |
54 | self._validate_quantity(adjusted_quantity, is_buy=True)
55 | return adjusted_quantity
56 |
def adjust_and_validate_sell_quantity(self, crypto_balance: float, order_quantity: float) -> float:
    """
    Adjusts and validates the sell order quantity based on the available crypto balance.

    The quantity is capped at ``crypto_balance - self.tolerance``, so a small buffer of
    the asset is always held back.

    Args:
        crypto_balance (float): Available crypto balance.
        order_quantity (float): Requested sell quantity.

    Returns:
        float: Adjusted and validated sell order quantity.

    Raises:
        InsufficientCryptoBalanceError: If the crypto balance is below
            ``order_quantity * self.threshold_ratio``, or if nothing is left to sell
            once the tolerance buffer is held back.
        InvalidOrderQuantityError: If the resulting quantity is not positive
            (e.g. a non-positive requested quantity).
    """
    if crypto_balance < order_quantity * self.threshold_ratio:
        raise InsufficientCryptoBalanceError(
            f"Crypto balance {crypto_balance:.6f} is far below the required quantity {order_quantity:.6f} "
            f"(threshold ratio: {self.threshold_ratio}).",
        )

    sellable = crypto_balance - self.tolerance

    if order_quantity > sellable:
        # The balance, not the request, limits the order. Mirror the buy path: when the
        # tolerance buffer swallows the whole balance, report insufficient funds instead
        # of letting it surface later as a generic invalid-quantity error.
        if sellable <= 0:
            raise InsufficientCryptoBalanceError(
                f"Insufficient crypto balance: {crypto_balance:.6f} to place any sell order.",
            )
        adjusted_quantity = sellable
    else:
        adjusted_quantity = order_quantity

    self._validate_quantity(adjusted_quantity, is_buy=False)
    return adjusted_quantity
81 |
def _validate_quantity(self, quantity: float, is_buy: bool) -> None:
    """
    Validates the adjusted order quantity.

    Args:
        quantity (float): Adjusted quantity to validate.
        is_buy (bool): Whether the order is a buy order (only affects the error message).

    Raises:
        InvalidOrderQuantityError: If the quantity is not a positive number
            (zero, negative, or NaN).
    """
    # ``not (quantity > 0)`` rather than ``quantity <= 0``: every comparison with NaN
    # is False, so the latter would let a NaN quantity pass validation.
    if not (quantity > 0):
        order_type = "buy" if is_buy else "sell"
        raise InvalidOrderQuantityError(f"Invalid {order_type} quantity: {quantity:.6f}")
96 |
--------------------------------------------------------------------------------
/tests/validation/test_order_validator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from core.validation.exceptions import (
4 | InsufficientBalanceError,
5 | InsufficientCryptoBalanceError,
6 | InvalidOrderQuantityError,
7 | )
8 | from core.validation.order_validator import OrderValidator
9 |
10 |
class TestOrderValidator:
    """Tests for OrderValidator's buy/sell quantity adjustment, using its default settings."""

    @pytest.fixture
    def validator(self):
        # Default tolerance and threshold_ratio; the expectations below rely on them.
        return OrderValidator()

    def test_adjust_and_validate_buy_quantity_valid(self, validator):
        balance = 5000  # Covers the full 1 * 3000 cost
        order_quantity = 1
        price = 3000

        adjusted_quantity = validator.adjust_and_validate_buy_quantity(balance, order_quantity, price)
        assert adjusted_quantity == order_quantity

    def test_adjust_and_validate_buy_quantity_adjusted(self, validator):
        balance = 2000  # Short of the 3000 cost, but not short enough to trip the threshold check
        order_quantity = 1
        price = 3000

        adjusted_quantity = validator.adjust_and_validate_buy_quantity(balance, order_quantity, price)
        expected_quantity = (balance - validator.tolerance) / price
        assert adjusted_quantity == pytest.approx(expected_quantity, rel=1e-6)

    def test_adjust_and_validate_buy_quantity_insufficient_balance(self, validator):
        balance = 10  # Far below the 3000 cost, so the threshold check rejects it
        order_quantity = 1
        price = 3000

        with pytest.raises(InsufficientBalanceError, match="far below the required cost"):
            validator.adjust_and_validate_buy_quantity(balance, order_quantity, price)

    def test_adjust_and_validate_buy_quantity_invalid_quantity(self, validator):
        balance = 5000
        order_quantity = -1  # Negative request: passes the balance checks, fails quantity validation
        price = 3000

        with pytest.raises(InvalidOrderQuantityError, match="Invalid buy quantity"):
            validator.adjust_and_validate_buy_quantity(balance, order_quantity, price)

    def test_adjust_and_validate_sell_quantity_valid(self, validator):
        crypto_balance = 5  # Comfortably more than the requested 3
        order_quantity = 3

        adjusted_quantity = validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity)
        assert adjusted_quantity == order_quantity

    def test_adjust_and_validate_sell_quantity_adjusted(self, validator):
        crypto_balance = 2  # Short of the requested 3, but not short enough to trip the threshold check
        order_quantity = 3

        adjusted_quantity = validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity)
        expected_quantity = crypto_balance - validator.tolerance
        assert adjusted_quantity == pytest.approx(expected_quantity, rel=1e-6)

    def test_adjust_and_validate_sell_quantity_insufficient_balance(self, validator):
        crypto_balance = 0.1  # Far below the requested 3, so the threshold check rejects it
        order_quantity = 3

        with pytest.raises(InsufficientCryptoBalanceError, match="far below the required quantity"):
            validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity)

    def test_adjust_and_validate_sell_quantity_invalid_quantity(self, validator):
        crypto_balance = 5
        order_quantity = -3  # Negative request: passes the balance check, fails quantity validation

        with pytest.raises(InvalidOrderQuantityError, match="Invalid sell quantity"):
            validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity)

    def test_adjust_and_validate_buy_quantity_tolerance_threshold(self, validator):
        # 1400 is less than threshold_ratio of the 3000 cost, so the early "far below"
        # check fires before any quantity adjustment is attempted.
        balance = 1400
        order_quantity = 1
        price = 3000

        with pytest.raises(InsufficientBalanceError, match="Balance .* is far below the required cost .*"):
            validator.adjust_and_validate_buy_quantity(balance, order_quantity, price)

    def test_adjust_and_validate_sell_quantity_tolerance_threshold(self, validator):
        # 0.001 is a tiny fraction of the requested 3, well under threshold_ratio, so the
        # early "far below" check fires.
        crypto_balance = 0.001
        order_quantity = 3

        with pytest.raises(
            InsufficientCryptoBalanceError,
            match="Crypto balance .* is far below the required quantity .*",
        ):
            validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity)
95 |
--------------------------------------------------------------------------------
/tests/order_handling/test_order.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from core.order_handling.order import Order, OrderSide, OrderStatus, OrderType
4 |
5 |
class TestOrder:
    """Tests for Order construction, status predicates, formatting and string forms."""

    @pytest.fixture
    def sample_order(self):
        # An open, unfilled limit buy used as the baseline for most tests.
        fields = {
            "identifier": "123",
            "status": OrderStatus.OPEN,
            "order_type": OrderType.LIMIT,
            "side": OrderSide.BUY,
            "price": 1000.0,
            "average": None,
            "amount": 5.0,
            "filled": 0.0,
            "remaining": 5.0,
            "timestamp": 1695890800,
            "datetime": "2024-01-01T00:00:00Z",
            "last_trade_timestamp": None,
            "symbol": "BTC/USDT",
            "time_in_force": "GTC",
        }
        return Order(**fields)

    def test_create_order_with_valid_data(self, sample_order):
        expected = {
            "identifier": "123",
            "status": OrderStatus.OPEN,
            "order_type": OrderType.LIMIT,
            "side": OrderSide.BUY,
            "price": 1000.0,
            "amount": 5.0,
            "filled": 0.0,
            "remaining": 5.0,
            "timestamp": 1695890800,
            "datetime": "2024-01-01T00:00:00Z",
            "symbol": "BTC/USDT",
            "time_in_force": "GTC",
        }
        for attribute, value in expected.items():
            assert getattr(sample_order, attribute) == value

    def test_is_filled(self, sample_order):
        for status, filled in ((OrderStatus.CLOSED, True), (OrderStatus.OPEN, False)):
            sample_order.status = status
            assert sample_order.is_filled() is filled

    def test_is_canceled(self, sample_order):
        for status, canceled in ((OrderStatus.CANCELED, True), (OrderStatus.OPEN, False)):
            sample_order.status = status
            assert sample_order.is_canceled() is canceled

    def test_is_open(self, sample_order):
        for status, is_open in ((OrderStatus.OPEN, True), (OrderStatus.CLOSED, False)):
            sample_order.status = status
            assert sample_order.is_open() is is_open

    def test_format_last_trade_timestamp(self, sample_order):
        # Without a last trade there is nothing to format.
        assert sample_order.format_last_trade_timestamp() is None

        # With one, an ISO-8601 string is produced.
        sample_order.last_trade_timestamp = 1695890800
        formatted = sample_order.format_last_trade_timestamp()
        assert formatted == "2023-09-28T08:46:40"

    def test_order_str_representation(self, sample_order):
        rendered = str(sample_order)
        for fragment in (
            "Order(id=123, status=OrderStatus.OPEN",
            "type=OrderType.LIMIT, side=OrderSide.BUY, price=1000.0",
        ):
            assert fragment in rendered

    def test_order_repr_representation(self, sample_order):
        assert repr(sample_order) == str(sample_order)

    def test_order_with_trades_and_fee(self):
        def make_trades():
            return [
                {"id": "trade1", "price": 1950.0, "amount": 1.0},
                {"id": "trade2", "price": 1950.0, "amount": 2.0},
            ]

        def make_fee():
            return {"currency": "USDT", "cost": 5.0}

        filled_sell = Order(
            identifier="456",
            status=OrderStatus.CLOSED,
            order_type=OrderType.LIMIT,
            side=OrderSide.SELL,
            price=2000.0,
            average=1950.0,
            amount=3.0,
            filled=3.0,
            remaining=0.0,
            timestamp=1695890800,
            datetime="2024-01-01T00:00:00Z",
            last_trade_timestamp=1695890900,
            symbol="ETH/USDT",
            time_in_force="GTC",
            trades=make_trades(),
            fee=make_fee(),
            cost=5850.0,
        )
        assert filled_sell.is_filled() is True
        assert filled_sell.fee == make_fee()
        assert filled_sell.trades == make_trades()
        assert filled_sell.cost == 5850.0
105 |
--------------------------------------------------------------------------------
/tests/grid_management/test_grid_level.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import pytest
4 |
5 | from core.grid_management.grid_level import GridCycleState, GridLevel
6 | from core.order_handling.order import Order
7 |
8 |
class TestGridLevel:
    """Tests for GridLevel construction, order tracking, level pairing and state changes."""

    @pytest.fixture
    def grid_level(self):
        return GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)

    @staticmethod
    def _attach_mock_order(level):
        # Give the level one Order-shaped mock, as a real placement would.
        level.add_order(Mock(spec=Order))

    def test_grid_level_initialization(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)

        assert level.price == 1000
        assert level.state == GridCycleState.READY_TO_BUY
        assert level.orders == []
        assert level.paired_buy_level is None
        assert level.paired_sell_level is None

    def test_add_order(self, grid_level):
        placed = Mock(spec=Order)
        grid_level.add_order(placed)

        assert len(grid_level.orders) == 1
        assert grid_level.orders[0] == placed

    def test_str_representation(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)
        level.paired_buy_level = GridLevel(price=900, state=GridCycleState.READY_TO_SELL)
        level.paired_sell_level = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY)

        text = str(level)
        for fragment in (
            "price=1000",
            "state=READY_TO_BUY",
            "paired_buy_level=900",
            "paired_sell_level=1100",
        ):
            assert fragment in text

    def test_update_paired_levels(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)
        below = GridLevel(price=900, state=GridCycleState.READY_TO_SELL)
        above = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY)

        level.paired_buy_level, level.paired_sell_level = below, above

        assert level.paired_buy_level.price == 900
        assert level.paired_sell_level.price == 1100

    def test_state_transition_to_waiting_for_buy_fill(self, grid_level):
        grid_level.state = GridCycleState.READY_TO_BUY
        self._attach_mock_order(grid_level)

        grid_level.state = GridCycleState.WAITING_FOR_BUY_FILL
        assert grid_level.state == GridCycleState.WAITING_FOR_BUY_FILL

    def test_state_transition_to_ready_to_sell(self, grid_level):
        grid_level.state = GridCycleState.READY_TO_BUY
        self._attach_mock_order(grid_level)

        grid_level.state = GridCycleState.READY_TO_SELL
        assert grid_level.state == GridCycleState.READY_TO_SELL

    def test_state_transition_to_waiting_for_sell_fill(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_SELL)
        self._attach_mock_order(level)

        level.state = GridCycleState.WAITING_FOR_SELL_FILL
        assert level.state == GridCycleState.WAITING_FOR_SELL_FILL

    def test_state_transition_to_ready_to_buy_or_sell(self):
        level = GridLevel(price=1000, state=GridCycleState.WAITING_FOR_SELL_FILL)

        level.state = GridCycleState.READY_TO_BUY_OR_SELL
        assert level.state == GridCycleState.READY_TO_BUY_OR_SELL

    def test_paired_levels_initialization(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)
        below = GridLevel(price=900, state=GridCycleState.READY_TO_SELL)
        above = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY)

        level.paired_buy_level, level.paired_sell_level = below, above

        assert level.paired_buy_level == below
        assert level.paired_sell_level == above

    def test_orders_list(self):
        level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY)
        placed = [Mock(spec=Order), Mock(spec=Order)]

        for order in placed:
            level.add_order(order)

        # Orders must be kept in insertion order.
        assert len(level.orders) == len(placed)
        for position, order in enumerate(placed):
            assert level.orders[position] == order
104 |
--------------------------------------------------------------------------------
/core/bot_management/bot_controller/bot_controller.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from tabulate import tabulate
5 |
6 | from core.bot_management.event_bus import EventBus, Events
7 | from core.bot_management.grid_trading_bot import GridTradingBot
8 |
9 | from .exceptions import CommandParsingError, StrategyControlError
10 |
11 |
class BotController:
    """
    Handles user commands and manages the lifecycle of the GridTradingBot.
    """

    def __init__(
        self,
        bot: GridTradingBot,
        event_bus: EventBus,
    ):
        """
        Initializes the BotController.

        Args:
            bot: The GridTradingBot instance to control.
            event_bus: The EventBus instance to publish/subscribe Events.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.bot = bot
        self.event_bus = event_bus
        # Flipped to True by the STOP_BOT handler; checked once per command-loop iteration.
        self._stop_listening = False
        self.event_bus.subscribe(Events.STOP_BOT, self._handle_stop_event)

    async def command_listener(self):
        """
        Listens for user commands and processes them until a STOP_BOT event is received.

        ``input()`` blocks, so it is run in the loop's default executor to keep the event
        loop responsive. The stop flag is only re-checked after a command has been read
        and handled, so a pending prompt is not interrupted by a stop.
        """
        self.logger.info("Command listener started. Type 'quit' to exit.")
        loop = asyncio.get_event_loop()

        while not self._stop_listening:
            try:
                command = await loop.run_in_executor(
                    None,
                    input,
                    "Enter command (quit, orders, balance, stop, restart, pause): ",
                )
                await self._handle_command(command.strip().lower())

            except CommandParsingError as e:
                self.logger.warning(f"Command error: {e}")

            except Exception as e:
                # Keep the listener alive on unexpected failures; log with traceback.
                self.logger.error(f"Unexpected error in command listener: {e}", exc_info=True)

    async def _handle_command(self, command: str):
        """
        Handles individual commands from the user.

        Args:
            command: The command entered by the user (already stripped and lower-cased
                by ``command_listener``).

        Raises:
            CommandParsingError: If the command is unknown or its arguments are invalid.
            StrategyControlError: If a pause operation fails while publishing events.
        """
        if command == "quit":
            self.logger.info("Stop bot command received")
            self.event_bus.publish_sync(Events.STOP_BOT, "User requested shutdown")

        elif command == "orders":
            await self._display_orders()

        elif command == "balance":
            await self._display_balance()

        elif command == "stop":
            self.event_bus.publish_sync(Events.STOP_BOT, "User issued stop command")

        elif command == "restart":
            self.event_bus.publish_sync(Events.STOP_BOT, "User issued restart command")
            self.event_bus.publish_sync(Events.START_BOT, "User issued restart command")

        elif command.split()[:1] == ["pause"]:
            # Match the "pause" keyword exactly (optionally followed by a duration),
            # not any word that merely starts with those letters, such as "pauses".
            await self._pause_bot(command)

        else:
            raise CommandParsingError(f"Unknown command: {command}")

    def _stop_listener(self):
        """
        Stops the command listener loop.
        """
        self._stop_listening = True
        self.logger.info("Command listener stopped.")

    def _handle_stop_event(self, reason: str) -> None:
        """
        Handles the STOP_BOT event and stops the command listener.

        NOTE(review): "pause" and "restart" also publish STOP_BOT, so they end this
        listener loop as well once the current command finishes — confirm that is
        intended.

        Args:
            reason: The reason for stopping the bot.
        """
        self.logger.info(f"Received STOP_BOT event: {reason}")
        self._stop_listener()

    async def _display_orders(self):
        """
        Displays formatted orders retrieved from the bot's strategy as a table.
        """
        self.logger.info("Display orders bot command received")
        formatted_orders = self.bot.strategy.get_formatted_orders()
        orders_table = tabulate(
            formatted_orders,
            headers=["Order Side", "Type", "Status", "Price", "Quantity", "Timestamp", "Grid Level", "Slippage"],
            tablefmt="pipe",
        )
        self.logger.info("\nFormatted Orders:\n" + orders_table)

    async def _display_balance(self):
        """
        Displays the current balances retrieved from the bot.
        """
        self.logger.info("Display balance bot command received")
        current_balances = self.bot.get_balances()
        self.logger.info(f"Current balances: {current_balances}")

    async def _pause_bot(self, command: str):
        """
        Pauses the bot for a specified duration.

        Publishes STOP_BOT, sleeps for the requested number of seconds, then publishes
        START_BOT.

        Args:
            command: The pause command containing the duration, e.g. ``"pause 30"``.

        Raises:
            CommandParsingError: If the duration is missing, not an integer, or negative.
            StrategyControlError: If publishing events or sleeping fails.
        """
        self.logger.info("Pause bot command received")
        # Parse before touching the bot so a bad command never stops it, and so parse
        # failures are reported as CommandParsingError rather than wrapped below.
        duration = self._parse_pause_duration(command)

        try:
            await self.event_bus.publish(Events.STOP_BOT, "User issued pause command")
            self.logger.info(f"Bot paused for {duration} seconds.")
            await asyncio.sleep(duration)
            self.logger.info("Resuming bot after pause.")
            await self.event_bus.publish(Events.START_BOT, "Resuming bot after pause")

        except Exception as e:
            raise StrategyControlError(f"Error during pause operation: {e}") from e

    @staticmethod
    def _parse_pause_duration(command: str) -> int:
        """
        Extracts the non-negative pause duration (seconds) from a pause command.

        Args:
            command: The pause command, e.g. ``"pause 30"``.

        Returns:
            int: The requested pause duration in seconds.

        Raises:
            CommandParsingError: If the duration is missing, not an integer, or negative.
        """
        message = "Invalid pause duration. Please specify in seconds."
        parts = command.split()

        try:
            # IndexError covers a bare "pause" with no duration argument.
            duration = int(parts[1])
        except (IndexError, ValueError):
            raise CommandParsingError(message) from None

        if duration < 0:
            raise CommandParsingError(message)

        return duration
146 |
--------------------------------------------------------------------------------
/core/bot_management/event_bus.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from collections.abc import Awaitable, Callable
3 | import inspect
4 | import logging
5 | from typing import Any
6 |
7 |
class Events:
    """
    String identifiers for the event types routed through the EventBus.
    """

    # Order lifecycle notifications.
    ORDER_FILLED: str = "order_filled"
    ORDER_CANCELLED: str = "order_cancelled"

    # Bot lifecycle control.
    START_BOT: str = "start_bot"
    STOP_BOT: str = "stop_bot"
17 |
18 |
19 | class EventBus:
20 | """
21 | A simple event bus for managing pub-sub interactions with support for both sync and async publishing.
22 | """
23 |
24 | def __init__(self):
25 | """
26 | Initializes the EventBus with an empty subscriber list.
27 | """
28 | self.logger = logging.getLogger(self.__class__.__name__)
29 | self.subscribers: dict[str, list[Callable[[Any], None]]] = {}
30 | self._tasks: set[asyncio.Task] = set()
31 |
32 | def subscribe(
33 | self,
34 | event_type: str,
35 | callback: Callable[[Any], None] | Callable[[Any], Awaitable[None]],
36 | ) -> None:
37 | """
38 | Subscribes a callback to a specific event type.
39 |
40 | Args:
41 | event_type: The type of event to subscribe to.
42 | callback: The callback function to invoke when the event is published.
43 | """
44 | if event_type not in self.subscribers:
45 | self.subscribers[event_type] = []
46 |
47 | self.subscribers[event_type].append(callback)
48 | callback_name = getattr(callback, "__name__", str(callback))
49 | caller_frame = inspect.stack()[1]
50 | caller_name = f"{caller_frame.function} (from {caller_frame.filename}:{caller_frame.lineno})"
51 | self.logger.info(f"Callback '{callback_name}' subscribed to event: {event_type} by {caller_name}")
52 |
53 | async def publish(
54 | self,
55 | event_type: str,
56 | data: Any = None,
57 | ) -> None:
58 | """
59 | Publishes an event asynchronously to all subscribers.
60 | """
61 | if event_type not in self.subscribers:
62 | self.logger.warning(f"No subscribers for event: {event_type}")
63 | return
64 |
65 | self.logger.info(f"Publishing async event: {event_type} with data: {data}")
66 | tasks = [
67 | self._safe_invoke_async(callback, data)
68 | if asyncio.iscoroutinefunction(callback)
69 | else asyncio.to_thread(self._safe_invoke_sync, callback, data)
70 | for callback in self.subscribers[event_type]
71 | ]
72 | if tasks:
73 | results = await asyncio.gather(*tasks, return_exceptions=True)
74 | for result in results:
75 | if isinstance(result, Exception):
76 | self.logger.error(f"Exception in async event callback: {result}", exc_info=True)
77 |
78 | def publish_sync(
79 | self,
80 | event_type: str,
81 | data: Any,
82 | ) -> None:
83 | """
84 | Publishes an event synchronously to all subscribers.
85 | """
86 | if event_type in self.subscribers:
87 | self.logger.info(f"Publishing sync event: {event_type} with data: {data}")
88 | loop = asyncio.get_event_loop()
89 | for callback in self.subscribers[event_type]:
90 | if asyncio.iscoroutinefunction(callback):
91 | asyncio.run_coroutine_threadsafe(self._safe_invoke_async(callback, data), loop)
92 | else:
93 | self._safe_invoke_sync(callback, data)
94 |
95 | async def _safe_invoke_async(
96 | self,
97 | callback: Callable[[Any], None],
98 | data: Any,
99 | ) -> None:
100 | """
101 | Safely invokes an async callback, suppressing and logging any exceptions.
102 | """
103 | task = asyncio.create_task(self._invoke_callback(callback, data))
104 | self._tasks.add(task)
105 |
106 | def remove_task(completed_task: asyncio.Task):
107 | if not completed_task.cancelled():
108 | self._tasks.discard(completed_task)
109 |
110 | task.add_done_callback(remove_task)
111 | self.logger.debug(f"Task created for callback '{callback.__name__}' with data: {data}")
112 |
113 | async def _invoke_callback(
114 | self,
115 | callback: Callable[[Any], None],
116 | data: Any,
117 | ) -> None:
118 | try:
119 | self.logger.info(f"Executing async callback '{callback.__name__}' for event with data: {data}")
120 | await callback(data)
121 | except Exception as e:
122 | self.logger.error(f"Error in async callback '{callback.__name__}': {e}", exc_info=True)
123 |
124 | def _safe_invoke_sync(
125 | self,
126 | callback: Callable[[Any], None],
127 | data: Any,
128 | ) -> None:
129 | """
130 | Safely invokes a sync callback, suppressing and logging any exceptions.
131 | """
132 | try:
133 | callback(data)
134 | except Exception as e:
135 | self.logger.error(f"Error in sync subscriber callback: {e}", exc_info=True)
136 |
137 | async def shutdown(self):
138 | """
139 | Cancels all active tasks tracked by the EventBus for graceful shutdown.
140 | """
141 | self.logger.info("Shutting down EventBus...")
142 | for task in self._tasks:
143 | task.cancel()
144 | await asyncio.gather(*self._tasks, return_exceptions=True)
145 | self._tasks.clear()
146 | self.logger.info("EventBus shutdown complete.")
147 |
--------------------------------------------------------------------------------
/tests/order_handling/test_order_book.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import pytest
4 |
5 | from core.grid_management.grid_level import GridLevel
6 | from core.order_handling.order import Order, OrderSide, OrderStatus
7 | from core.order_handling.order_book import OrderBook
8 |
9 |
class TestOrderBook:
    """Tests for OrderBook bookkeeping of grid-linked and standalone orders."""

    @pytest.fixture
    def setup_order_book(self):
        return OrderBook()

    def test_add_order_with_grid(self, setup_order_book):
        book = setup_order_book
        bid = Mock(spec=Order, side=OrderSide.BUY)
        ask = Mock(spec=Order, side=OrderSide.SELL)
        level = Mock(spec=GridLevel)

        for placed in (bid, ask):
            book.add_order(placed, level)

        assert len(book.buy_orders) == 1
        assert len(book.sell_orders) == 1
        for placed in (bid, ask):
            assert book.order_to_grid_map[placed] == level

    def test_add_order_without_grid(self, setup_order_book):
        book = setup_order_book
        standalone = Mock(spec=Order, side=OrderSide.SELL)

        book.add_order(standalone)

        assert len(book.non_grid_orders) == 1
        assert book.non_grid_orders[0] == standalone

    def test_get_buy_orders_with_grid(self, setup_order_book):
        book = setup_order_book
        bid = Mock(spec=Order, side=OrderSide.BUY)
        level = Mock(spec=GridLevel)

        book.add_order(bid, level)
        pairs = book.get_buy_orders_with_grid()

        assert len(pairs) == 1
        assert pairs[0] == (bid, level)

    def test_get_sell_orders_with_grid(self, setup_order_book):
        book = setup_order_book
        ask = Mock(spec=Order, side=OrderSide.SELL)
        level = Mock(spec=GridLevel)

        book.add_order(ask, level)
        pairs = book.get_sell_orders_with_grid()

        assert len(pairs) == 1
        assert pairs[0] == (ask, level)

    def test_get_all_buy_orders(self, setup_order_book):
        book = setup_order_book
        bids = [Mock(spec=Order, side=OrderSide.BUY) for _ in range(2)]

        for bid in bids:
            book.add_order(bid)
        found = book.get_all_buy_orders()

        assert len(found) == len(bids)
        assert all(bid in found for bid in bids)

    def test_get_all_sell_orders(self, setup_order_book):
        book = setup_order_book
        asks = [Mock(spec=Order, side=OrderSide.SELL) for _ in range(2)]

        for ask in asks:
            book.add_order(ask)
        found = book.get_all_sell_orders()

        assert len(found) == len(asks)
        assert all(ask in found for ask in asks)

    def test_get_open_orders(self, setup_order_book):
        book = setup_order_book
        live = Mock(spec=Order, side=OrderSide.BUY, is_open=Mock(return_value=True))
        finished = Mock(spec=Order, side=OrderSide.SELL, is_open=Mock(return_value=False))

        for order in (live, finished):
            book.add_order(order)
        open_orders = book.get_open_orders()

        assert len(open_orders) == 1
        assert live in open_orders

    def test_get_completed_orders(self, setup_order_book):
        book = setup_order_book
        done = Mock(spec=Order, side=OrderSide.BUY, is_filled=Mock(return_value=True))
        waiting = Mock(spec=Order, side=OrderSide.BUY, is_filled=Mock(return_value=False))

        for order in (done, waiting):
            book.add_order(order)
        completed = book.get_completed_orders()

        assert len(completed) == 1
        assert done in completed

    def test_get_grid_level_for_order(self, setup_order_book):
        book = setup_order_book
        bid = Mock(spec=Order, side=OrderSide.BUY)
        level = Mock(spec=GridLevel)

        book.add_order(bid, level)

        assert book.get_grid_level_for_order(bid) == level

    def test_update_order_status(self, setup_order_book):
        book = setup_order_book
        tracked = Mock(spec=Order, identifier="order_123", side=OrderSide.BUY, status=OrderStatus.OPEN)

        book.add_order(tracked)
        book.update_order_status("order_123", OrderStatus.CLOSED)

        assert tracked.status == OrderStatus.CLOSED

    def test_update_order_status_nonexistent_order(self, setup_order_book):
        book = setup_order_book
        tracked = Mock(spec=Order, identifier="order_123", status=OrderStatus.OPEN, side=OrderSide.BUY)

        book.add_order(tracked)
        book.update_order_status("nonexistent_order", OrderStatus.CLOSED)

        # An unknown identifier must leave existing orders untouched.
        assert tracked.status == OrderStatus.OPEN
138 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement by emailing
63 | <tetej171@gmail.com>.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import cProfile
3 | import logging
4 | import os
5 | from typing import Any
6 |
7 | from dotenv import load_dotenv
8 |
9 | from config.config_manager import ConfigManager
10 | from config.config_validator import ConfigValidator
11 | from config.exceptions import ConfigError
12 | from config.trading_mode import TradingMode
13 | from core.bot_management.bot_controller.bot_controller import BotController
14 | from core.bot_management.event_bus import EventBus
15 | from core.bot_management.grid_trading_bot import GridTradingBot
16 | from core.bot_management.health_check import HealthCheck
17 | from core.bot_management.notification.notification_handler import NotificationHandler
18 | from utils.arg_parser import parse_and_validate_console_args
19 | from utils.config_name_generator import generate_config_name
20 | from utils.logging_config import setup_logging
21 | from utils.performance_results_saver import save_or_append_performance_results
22 |
23 |
def initialize_config(config_path: str) -> ConfigManager:
    """Load environment variables and build a validated ConfigManager for ``config_path``.

    Any ``ConfigError`` raised while loading/validating the config is logged and the
    process is terminated with exit status 1.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The initialized ConfigManager.

    Raises:
        SystemExit: With code 1 when the configuration cannot be loaded.
    """
    load_dotenv()
    try:
        return ConfigManager(config_path, ConfigValidator())

    except ConfigError as e:
        logging.error(f"An error occurred during the initialization of ConfigManager: {e}")
        # ``exit()`` is injected by the ``site`` module for interactive use only;
        # raising SystemExit is what ``sys.exit`` does and works without ``site``.
        raise SystemExit(1) from e
32 |
33 |
def _parse_notification_urls(raw_urls: str) -> list[str]:
    """Split a comma-separated URL list, stripping whitespace and dropping empty entries."""
    return [url.strip() for url in raw_urls.split(",") if url.strip()]


def initialize_notification_handler(config_manager: ConfigManager, event_bus: EventBus) -> NotificationHandler:
    """Create the NotificationHandler from the ``APPRISE_NOTIFICATION_URLS`` environment variable.

    The variable holds a comma-separated list of Apprise URLs. When it is unset or empty,
    the handler receives an empty list rather than ``[""]`` (which is what a plain
    ``"".split(",")`` would produce).

    Args:
        config_manager: Source of the trading mode passed to the handler.
        event_bus: Event bus the handler is constructed with.

    Returns:
        The constructed NotificationHandler.
    """
    notification_urls = _parse_notification_urls(os.getenv("APPRISE_NOTIFICATION_URLS", ""))
    trading_mode = config_manager.get_trading_mode()
    return NotificationHandler(event_bus, notification_urls, trading_mode)
38 |
39 |
async def run_bot(
    config_path: str,
    profile: bool = False,
    save_performance_results_path: str | None = None,
    no_plot: bool = False,
) -> dict[str, Any] | None:
    """Build and run one grid-trading bot for ``config_path``.

    In live / paper-trading mode the bot runs concurrently with its command listener
    and health check; in every other mode only ``bot.run()`` is awaited. The event bus
    is always shut down afterwards.

    Args:
        config_path: Path to the JSON configuration file.
        profile: When True, run ``bot.run()`` under cProfile and write the stats to
            ``profile_results.prof``.
        save_performance_results_path: Forwarded to ``GridTradingBot``.
        no_plot: Forwarded to ``GridTradingBot``.

    Returns:
        The value returned by ``bot.run()`` in backtest or profile mode (presumably the
        performance summary that ``main`` saves — TODO confirm against
        ``GridTradingBot.run``); None in live / paper-trading mode, on cancellation, or
        when an unexpected error was logged.
    """
    config_manager = initialize_config(config_path)
    config_name = generate_config_name(config_manager)
    setup_logging(config_manager.get_logging_level(), config_manager.should_log_to_file(), config_name)
    event_bus = EventBus()
    notification_handler = initialize_notification_handler(config_manager, event_bus)
    bot = GridTradingBot(
        config_path,
        config_manager,
        notification_handler,
        event_bus,
        save_performance_results_path,
        no_plot,
    )
    bot_controller = BotController(bot, event_bus)
    health_check = HealthCheck(bot, notification_handler, event_bus)

    try:
        if profile:
            # This coroutine already runs inside an event loop, so a nested
            # ``asyncio.run()`` would raise RuntimeError. Profile the awaited coroutine
            # instead (this also samples any other coroutine sharing the loop).
            profiler = cProfile.Profile()
            profiler.enable()
            try:
                return await bot.run()
            finally:
                profiler.disable()
                profiler.dump_stats("profile_results.prof")

        if bot.trading_mode in {TradingMode.LIVE, TradingMode.PAPER_TRADING}:
            bot_task = asyncio.create_task(bot.run(), name="BotTask")
            bot_controller_task = asyncio.create_task(bot_controller.command_listener(), name="BotControllerTask")
            health_check_task = asyncio.create_task(health_check.start(), name="HealthCheckTask")
            await asyncio.gather(bot_task, bot_controller_task, health_check_task)
            return None

        # NOTE(review): GridTradingBot also receives save_performance_results_path;
        # confirm it does not save the same results itself, or they get saved twice.
        return await bot.run()

    except asyncio.CancelledError:
        logging.info("Cancellation received. Shutting down gracefully.")

    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}", exc_info=True)

    finally:
        try:
            await event_bus.shutdown()

        except Exception as e:
            logging.error(f"Error during EventBus shutdown: {e}", exc_info=True)

    return None
87 |
88 |
async def cleanup_tasks():
    """Cancel every other unfinished task on the running loop and wait for them to settle.

    Failures raised by the cancelled tasks are collected through
    ``return_exceptions=True`` and logged individually instead of being discarded.
    """
    logging.info("Shutting down bot and cleaning up tasks...")

    current_task = asyncio.current_task()
    # A cancelled task is always done(), so done() alone filters finished tasks.
    # A list (not a set) keeps a stable order so results can be zipped back to tasks.
    tasks_to_cancel = [task for task in asyncio.all_tasks() if task is not current_task and not task.done()]

    logging.info(f"Tasks to cancel: {len(tasks_to_cancel)}")

    for task in tasks_to_cancel:
        logging.info(f"Task to cancel: {task} - Done: {task.done()} - Cancelled: {task.cancelled()}")
        task.cancel()

    try:
        results = await asyncio.gather(*tasks_to_cancel, return_exceptions=True)

    except asyncio.CancelledError:
        # Only reachable when this cleanup task itself is cancelled.
        logging.warning("Task cleanup was cancelled before all tasks finished.")

    except Exception as e:
        logging.error(f"Error during task cancellation: {e}", exc_info=True)

    else:
        for task, result in zip(tasks_to_cancel, results):
            # CancelledError derives from BaseException, so only real failures match.
            if isinstance(result, Exception):
                logging.error(f"Task {task.get_name()} raised during cancellation: {result}", exc_info=result)
        logging.info("Tasks cancelled successfully.")
111 |
112 |
if __name__ == "__main__":
    args = parse_and_validate_console_args()

    async def main():
        """Run one bot per ``--config`` path concurrently, save results, then cancel leftovers."""
        try:
            tasks = [
                run_bot(config_path, args.profile, args.save_performance_results, args.no_plot)
                for config_path in args.config
            ]

            # return_exceptions=True keeps one failing bot from aborting the others.
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for config_path, result in zip(args.config, results):
                if isinstance(result, BaseException):
                    # There is no active exception here, so exc_info=True would log no
                    # traceback; pass the exception object to keep it.
                    logging.error(
                        f"Error occurred while running bot for config {config_path}: {result}",
                        exc_info=result,
                    )
                elif result is not None and args.save_performance_results:
                    # run_bot returns None in live/paper mode and on handled failures.
                    save_or_append_performance_results(result, args.save_performance_results)

        except Exception as e:
            logging.error(f"Critical error in main: {e}", exc_info=True)

        finally:
            await cleanup_tasks()
            logging.info("All tasks have completed.")

    asyncio.run(main())
143 |
--------------------------------------------------------------------------------
/strategies/plotter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import plotly.graph_objects as go
3 | from plotly.subplots import make_subplots
4 |
5 | from core.grid_management.grid_manager import GridManager
6 | from core.order_handling.order import Order, OrderSide
7 | from core.order_handling.order_book import OrderBook
8 |
9 |
class Plotter:
    """Plot a strategy run with Plotly.

    The figure has three rows sharing the x-axis:
    row 1 shows candlesticks with the trigger-price line, grid lines and trade markers;
    row 2 shows volume; row 3 shows account value.
    """

    def __init__(
        self,
        grid_manager: GridManager,
        order_book: OrderBook,
    ):
        """Keep the grid manager (grid levels, trigger/central price) and order book (completed orders)."""
        self.grid_manager = grid_manager
        self.order_book = order_book

    def plot_results(
        self,
        data: pd.DataFrame,
    ) -> None:
        """Build the full results figure from ``data`` and display it.

        Args:
            data: Candle data; the columns read are ``open``, ``high``, ``low``, ``close``,
                ``volume`` and ``account_value``, and the index is used as the x-axis.
        """
        fig = make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02)
        self._add_candlestick_trace(fig, data)
        trigger_price = self.grid_manager.get_trigger_price()
        self._add_trigger_price_line(fig, trigger_price)
        self._add_grid_lines(fig, self.grid_manager.price_grids, self.grid_manager.central_price)
        self._add_trade_markers(fig, self.order_book.get_completed_orders())
        self._add_volume_trace(fig, data)
        self._add_account_value_trace(fig, data)

        fig.update_layout(
            title="Grid Trading Strategy Results",
            yaxis_title="Price (USDT)",
            yaxis2_title="Volume",
            yaxis3_title="Equity",
            xaxis={"rangeslider": {"visible": False}},
            showlegend=False,
        )
        fig.show()

    @staticmethod
    def _x_bounds(fig: go.Figure) -> tuple | None:
        """Return ``(first_x, last_x)`` of the figure's first trace, or None if it has no x values.

        Horizontal lines are drawn between these bounds; returning None lets callers
        skip them instead of raising IndexError on empty data.
        """
        if not fig.data:
            return None
        x_values = fig.data[0].x
        # len() instead of truthiness: x may be a numpy array, whose truth value is ambiguous.
        if x_values is None or len(x_values) == 0:
            return None
        return x_values[0], x_values[-1]

    def _add_candlestick_trace(
        self,
        fig: go.Figure,
        data: pd.DataFrame,
    ) -> None:
        """Add the OHLC candlestick trace to row 1."""
        fig.add_trace(
            go.Candlestick(
                x=data.index,
                open=data["open"],
                high=data["high"],
                low=data["low"],
                close=data["close"],
                name="",
            ),
            row=1,
            col=1,
        )

    def _add_trigger_price_line(
        self,
        fig: go.Figure,
        trigger_price: float,
    ) -> None:
        """Draw a dashed blue line at ``trigger_price`` across the first trace's x-range, plus a label.

        Nothing is drawn when the first trace has no x values.
        """
        bounds = self._x_bounds(fig)
        if bounds is None:
            return
        first_x, last_x = bounds

        fig.add_trace(
            go.Scatter(
                x=[first_x, last_x],
                y=[trigger_price, trigger_price],
                mode="lines",
                line={"color": "blue", "width": 2, "dash": "dash"},
                name="Trigger Price",
            ),
        )
        fig.add_annotation(
            x=last_x,
            y=trigger_price,
            text=f"Trigger Price: {trigger_price:.2f}",
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowcolor="blue",
            ax=20,
            ay=-20,
            font={"size": 10, "color": "blue"},
        )

    def _add_grid_lines(
        self,
        fig: go.Figure,
        grids: list[float],
        central_price: float,
    ) -> None:
        """Draw one dashed horizontal line per grid price across the first trace's x-range.

        Levels below ``central_price`` are green, all others red. Nothing is drawn when
        the first trace has no x values.
        """
        bounds = self._x_bounds(fig)
        if bounds is None:
            return
        first_x, last_x = bounds

        for price in grids:
            color = "green" if price < central_price else "red"
            fig.add_trace(
                go.Scatter(
                    x=[first_x, last_x],
                    y=[price, price],
                    mode="lines",
                    line={"color": color, "dash": "dash"},
                    showlegend=False,
                ),
            )

    def _add_trade_markers(
        self,
        fig: go.Figure,
        orders: list[Order],
    ) -> None:
        """Add one marker trace per order on row 1: green up-triangles for buys, red down-triangles otherwise."""
        for order in orders:
            is_buy = order.side == OrderSide.BUY
            trade_time = order.format_last_trade_timestamp()
            fig.add_trace(
                go.Scatter(
                    x=[trade_time],
                    y=[order.price],
                    mode="markers",
                    marker={
                        "symbol": "triangle-up" if is_buy else "triangle-down",
                        "color": "green" if is_buy else "red",
                        "size": 12,
                        "line": {"color": "black", "width": 2},
                    },
                    name=f"{order.side.name} Order",
                    # Plotly renders hover text as HTML: line breaks must be <br>, not "\n".
                    text=f"Price: {order.price}<br>Qty: {order.filled}<br>Date: {trade_time}",
                    hoverinfo="x+y+text",
                ),
                row=1,
                col=1,
            )

    def _add_volume_trace(
        self,
        fig: go.Figure,
        data: pd.DataFrame,
    ) -> None:
        """Add a volume bar chart to row 2, green where close >= open and red otherwise."""
        volume_colors = [
            "green" if close >= open_ else "red" for close, open_ in zip(data["close"], data["open"], strict=False)
        ]

        fig.add_trace(
            go.Bar(
                x=data.index,
                y=data["volume"],
                marker={"color": volume_colors},
                name="",
            ),
            row=2,
            col=1,
        )

        fig.update_yaxes(
            title="Volume",
            row=2,
            col=1,
        )

    def _add_account_value_trace(
        self,
        fig: go.Figure,
        data: pd.DataFrame,
    ) -> None:
        """Add the account-value line (purple) to row 3."""
        fig.add_trace(
            go.Scatter(
                x=data.index,
                y=data["account_value"],
                mode="lines",
                name="",
                line={"color": "purple", "width": 2},
            ),
            row=3,
            col=1,
        )
174 |
--------------------------------------------------------------------------------
/tests/strategies/test_plotter.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, patch
2 |
3 | import pandas as pd
4 | import plotly.graph_objects as go
5 | from plotly.subplots import make_subplots
6 | import pytest
7 |
8 | from core.grid_management.grid_manager import GridManager
9 | from core.order_handling.order import Order, OrderSide, OrderType
10 | from core.order_handling.order_book import OrderBook
11 | from strategies.plotter import Plotter
12 |
13 |
class TestPlotter:
    """Unit tests for each trace builder of ``Plotter`` and for the end-to-end ``plot_results`` call."""

    @staticmethod
    def _three_row_figure():
        """Build the same three-row subplot layout that ``plot_results`` uses."""
        return make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02)

    @staticmethod
    def _filled_limit_order(identifier, side, price, quantity, iso_datetime):
        """Create a fully filled BTC/USDT GTC limit order with fixed timestamps."""
        return Order(
            identifier=identifier,
            status=None,
            order_type=OrderType.LIMIT,
            side=side,
            price=price,
            average=None,
            amount=quantity,
            filled=quantity,
            remaining=0.0,
            timestamp=1695890800,
            datetime=iso_datetime,
            last_trade_timestamp=1695890800,
            symbol="BTC/USDT",
            time_in_force="GTC",
        )

    @pytest.fixture
    def setup_plotter(self):
        grid_manager_mock = Mock(spec=GridManager)
        order_book_mock = Mock(spec=OrderBook)
        return Plotter(grid_manager=grid_manager_mock, order_book=order_book_mock), grid_manager_mock, order_book_mock

    def test_add_grid_lines(self, setup_plotter):
        plotter, grid_manager, _ = setup_plotter
        grid_manager.price_grids = [90, 100, 110]
        grid_manager.central_price = 100

        figure = go.Figure()
        # Dummy trace: the grid lines take their x-range from the first trace.
        figure.add_trace(go.Scatter(x=[1, 2, 3], y=[100, 105, 110]))

        plotter._add_grid_lines(figure, grid_manager.price_grids, grid_manager.central_price)

        # The dummy trace plus one line per grid level.
        assert len(figure.data) == 4
        # 90 lies below the central price (green); 100 and 110 do not (red).
        assert [trace.line.color for trace in figure.data[1:]] == ["green", "red", "red"]

    def test_add_trigger_price_line(self, setup_plotter):
        plotter = setup_plotter[0]
        figure = go.Figure()
        figure.add_trace(go.Scatter(x=[1, 2, 3], y=[100, 105, 110]))

        plotter._add_trigger_price_line(figure, 105)

        # The dummy trace plus the trigger price line.
        assert len(figure.data) == 2
        assert figure.data[1].line.color == "blue"
        assert "Trigger Price" in figure.layout.annotations[0].text

    def test_add_trade_markers(self, setup_plotter):
        plotter, _, order_book = setup_plotter
        figure = self._three_row_figure()

        orders = [
            self._filled_limit_order("123", OrderSide.BUY, 1000.0, 5.0, "2024-01-01T00:00:00Z"),
            self._filled_limit_order("124", OrderSide.SELL, 1200.0, 3.0, "2024-01-02T00:00:00Z"),
        ]
        order_book.get_completed_orders.return_value = orders

        plotter._add_trade_markers(figure, orders)

        assert len(figure.data) == 2
        assert [trace.marker.color for trace in figure.data] == ["green", "red"]

    def test_add_volume_trace(self, setup_plotter):
        plotter = setup_plotter[0]
        figure = self._three_row_figure()
        candles = pd.DataFrame(
            {"open": [100, 110], "close": [110, 105], "volume": [500, 700]},
            index=pd.date_range("2024-01-01", periods=2),
        )

        plotter._add_volume_trace(figure, candles)

        assert len(figure.data) == 1
        volume_trace = figure.data[0]
        assert volume_trace.type == "bar"
        assert volume_trace.y.tolist() == [500, 700]
        assert list(volume_trace.marker.color) == ["green", "red"]

    def test_add_account_value_trace(self, setup_plotter):
        plotter = setup_plotter[0]
        figure = self._three_row_figure()
        equity = pd.DataFrame({"account_value": [10000, 10500]}, index=pd.date_range("2024-01-01", periods=2))

        plotter._add_account_value_trace(figure, equity)

        assert len(figure.data) == 1
        equity_trace = figure.data[0]
        assert equity_trace.type == "scatter"
        assert equity_trace.y.tolist() == [10000, 10500]
        assert equity_trace.line.color == "purple"

    @patch("plotly.graph_objects.Figure.show")
    def test_plot_results(self, mock_show, setup_plotter):
        plotter, grid_manager, order_book = setup_plotter
        grid_manager.price_grids = [90, 100, 110]
        grid_manager.central_price = 100
        grid_manager.get_trigger_price.return_value = 100
        order_book.get_completed_orders.return_value = []

        market_data = pd.DataFrame(
            {
                "open": [100, 105],
                "high": [110, 115],
                "low": [95, 100],
                "close": [105, 110],
                "volume": [500, 700],
                "account_value": [10000, 10500],
            },
            index=pd.date_range("2024-01-01", periods=2),
        )

        plotter.plot_results(market_data)

        mock_show.assert_called_once()
148 |
--------------------------------------------------------------------------------
/config/config_manager.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 |
5 | from strategies.spacing_type import SpacingType
6 | from strategies.strategy_type import StrategyType
7 |
8 | from .exceptions import ConfigFileNotFoundError, ConfigParseError
9 | from .trading_mode import TradingMode
10 |
11 |
class ConfigManager:
    """Load a JSON configuration file, validate it, and expose accessors for its sections.

    Missing sections read as empty dicts. Missing leaf values fall back to the
    per-accessor default shown in each method.
    """

    def __init__(self, config_file, config_validator):
        """Load and validate ``config_file`` immediately.

        Args:
            config_file: Path to the JSON configuration file.
            config_validator: Object whose ``validate(config)`` is called with the parsed config.

        Raises:
            ConfigFileNotFoundError: If ``config_file`` does not exist.
            ConfigParseError: If the file is not valid JSON.
            Any exception raised by ``config_validator.validate`` (for ConfigValidator,
            presumably ConfigValidationError — confirm).
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.config_file = config_file
        self.config_validator = config_validator
        self.config = None
        self.load_config()

    def load_config(self):
        """Read ``self.config_file`` into ``self.config`` and run the validator on it."""
        if not os.path.exists(self.config_file):
            self.logger.error(f"Config file {self.config_file} does not exist.")
            raise ConfigFileNotFoundError(self.config_file)

        with open(self.config_file) as file:
            # Only the JSON parse belongs in this try; validation errors propagate as-is.
            try:
                self.config = json.load(file)
            except json.JSONDecodeError as e:
                self.logger.error(f"Failed to parse config file {self.config_file}: {e}")
                raise ConfigParseError(self.config_file, e) from e

        self.config_validator.validate(self.config)

    def get(self, key, default=None):
        """Return the top-level ``key`` of the config, or ``default`` if absent."""
        return self.config.get(key, default)

    # --- General Accessor Methods ---
    def get_exchange(self):
        return self.config.get("exchange", {})

    def get_exchange_name(self):
        exchange = self.get_exchange()
        return exchange.get("name", None)

    def get_trading_fee(self):
        exchange = self.get_exchange()
        return exchange.get("trading_fee", 0)

    def get_trading_mode(self) -> TradingMode | None:
        """Return ``exchange.trading_mode`` parsed into a TradingMode, or None if unset/empty."""
        exchange = self.get_exchange()
        trading_mode = exchange.get("trading_mode", None)

        if trading_mode:
            return TradingMode.from_string(trading_mode)
        return None

    def get_pair(self):
        return self.config.get("pair", {})

    def get_base_currency(self):
        pair = self.get_pair()
        return pair.get("base_currency", None)

    def get_quote_currency(self):
        pair = self.get_pair()
        return pair.get("quote_currency", None)

    def get_trading_settings(self):
        return self.config.get("trading_settings", {})

    def get_timeframe(self):
        trading_settings = self.get_trading_settings()
        return trading_settings.get("timeframe", "1h")

    def get_period(self):
        trading_settings = self.get_trading_settings()
        return trading_settings.get("period", {})

    def get_start_date(self):
        period = self.get_period()
        return period.get("start_date", None)

    def get_end_date(self):
        period = self.get_period()
        return period.get("end_date", None)

    def get_initial_balance(self):
        trading_settings = self.get_trading_settings()
        return trading_settings.get("initial_balance", 10000)

    def get_historical_data_file(self):
        trading_settings = self.get_trading_settings()
        return trading_settings.get("historical_data_file", None)

    # --- Grid Accessor Methods ---
    def get_grid_settings(self):
        return self.config.get("grid_strategy", {})

    def get_strategy_type(self) -> StrategyType | None:
        """Return ``grid_strategy.type`` parsed into a StrategyType, or None if unset/empty."""
        grid_settings = self.get_grid_settings()
        strategy_type = grid_settings.get("type", None)

        if strategy_type:
            return StrategyType.from_string(strategy_type)
        return None

    def get_spacing_type(self) -> SpacingType | None:
        """Return ``grid_strategy.spacing`` parsed into a SpacingType, or None if unset/empty."""
        grid_settings = self.get_grid_settings()
        spacing_type = grid_settings.get("spacing", None)

        if spacing_type:
            return SpacingType.from_string(spacing_type)
        return None

    def get_num_grids(self):
        grid_settings = self.get_grid_settings()
        return grid_settings.get("num_grids", None)

    def get_grid_range(self):
        grid_settings = self.get_grid_settings()
        return grid_settings.get("range", {})

    def get_top_range(self):
        grid_range = self.get_grid_range()
        return grid_range.get("top", None)

    def get_bottom_range(self):
        grid_range = self.get_grid_range()
        return grid_range.get("bottom", None)

    # --- Risk management (Take Profit / Stop Loss) Accessor Methods ---
    def get_risk_management(self):
        return self.config.get("risk_management", {})

    def get_take_profit(self):
        risk_management = self.get_risk_management()
        return risk_management.get("take_profit", {})

    def is_take_profit_enabled(self):
        take_profit = self.get_take_profit()
        return take_profit.get("enabled", False)

    def get_take_profit_threshold(self):
        take_profit = self.get_take_profit()
        return take_profit.get("threshold", None)

    def get_stop_loss(self):
        risk_management = self.get_risk_management()
        return risk_management.get("stop_loss", {})

    def is_stop_loss_enabled(self):
        stop_loss = self.get_stop_loss()
        return stop_loss.get("enabled", False)

    def get_stop_loss_threshold(self):
        stop_loss = self.get_stop_loss()
        return stop_loss.get("threshold", None)

    # --- Logging Accessor Methods ---
    # Locals are named ``logging_settings`` so they do not shadow the ``logging`` module.
    def get_logging(self):
        return self.config.get("logging", {})

    def get_logging_level(self):
        """Return ``logging.log_level``, or None if absent (the field is required by validation)."""
        logging_settings = self.get_logging()
        return logging_settings.get("log_level", None)

    def should_log_to_file(self) -> bool:
        logging_settings = self.get_logging()
        return logging_settings.get("log_to_file", False)
166 |
--------------------------------------------------------------------------------
/tests/config/test_config_validator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from config.config_validator import ConfigValidator
4 | from config.exceptions import ConfigValidationError
5 |
6 |
class TestConfigValidator:
    """Tests for ``ConfigValidator.validate``.

    Valid configurations must pass silently. Every defect must be reported under its
    dotted field path in ``missing_fields`` or ``invalid_fields`` of the raised
    ``ConfigValidationError``.
    """

    @pytest.fixture
    def config_validator(self):
        return ConfigValidator()

    @staticmethod
    def _expect_validation_error(validator, config):
        """Require ``validator.validate(config)`` to raise ConfigValidationError and return the error."""
        with pytest.raises(ConfigValidationError) as raised:
            validator.validate(config)
        return raised.value

    def test_validate_valid_config(self, config_validator, valid_config):
        try:
            config_validator.validate(valid_config)
        except ConfigValidationError:
            pytest.fail("Valid configuration raised ConfigValidationError")

    def test_validate_missing_required_fields(self, config_validator):
        section_names = ("exchange", "pair", "trading_settings", "grid_strategy", "risk_management", "logging")
        error = self._expect_validation_error(config_validator, {name: {} for name in section_names})

        expected_missing = (
            "pair.base_currency",
            "pair.quote_currency",
            "trading_settings.initial_balance",
            "trading_settings.period.start_date",
            "trading_settings.period.end_date",
            "grid_strategy.num_grids",
            "grid_strategy.range.top",
            "grid_strategy.range.bottom",
            "logging.log_level",
        )
        expected_invalid = (
            "exchange.name",
            "exchange.trading_fee",
            "trading_settings.timeframe",
        )

        for field_path in expected_missing:
            assert field_path in error.missing_fields
        for field_path in expected_invalid:
            assert field_path in error.invalid_fields

    def test_validate_invalid_exchange(self, config_validator, valid_config):
        # Empty exchange name and a negative trading fee.
        valid_config["exchange"] = {"name": "", "trading_fee": -0.01}
        error = self._expect_validation_error(config_validator, valid_config)
        assert "exchange.name" in error.invalid_fields
        assert "exchange.trading_fee" in error.invalid_fields

    def test_validate_valid_trading_modes(self, config_validator, valid_config):
        for trading_mode in ("live", "paper_trading", "backtest"):
            valid_config["exchange"]["trading_mode"] = trading_mode
            try:
                config_validator.validate(valid_config)
            except ConfigValidationError:
                pytest.fail(f"Valid trading_mode '{trading_mode}' raised ConfigValidationError")

    def test_validate_invalid_trading_mode(self, config_validator, valid_config):
        valid_config["exchange"]["trading_mode"] = "invalid_mode"
        with pytest.raises(ConfigValidationError, match="exchange.trading_mode"):
            config_validator.validate(valid_config)

    def test_validate_invalid_timeframe(self, config_validator, valid_config):
        # "3h" is not an accepted timeframe.
        valid_config["trading_settings"]["timeframe"] = "3h"
        error = self._expect_validation_error(config_validator, valid_config)
        assert "trading_settings.timeframe" in error.invalid_fields

    def test_validate_missing_period_fields(self, config_validator, valid_config):
        # Neither start_date nor end_date present.
        valid_config["trading_settings"]["period"] = {}
        error = self._expect_validation_error(config_validator, valid_config)
        assert "trading_settings.period.start_date" in error.missing_fields
        assert "trading_settings.period.end_date" in error.missing_fields

    def test_validate_invalid_grid_settings(self, config_validator, valid_config):
        grid_strategy = valid_config["grid_strategy"]

        # Unknown grid type.
        grid_strategy["type"] = "invalid_type"
        error = self._expect_validation_error(config_validator, valid_config)
        assert "grid_strategy.type" in error.invalid_fields

        # num_grids explicitly missing.
        grid_strategy["num_grids"] = None
        error = self._expect_validation_error(config_validator, valid_config)
        assert "grid_strategy.num_grids" in error.missing_fields

        # Inverted range: bottom must be below top.
        grid_strategy["range"] = {"top": 2800, "bottom": 2850}
        error = self._expect_validation_error(config_validator, valid_config)
        assert "grid_strategy.range.top" in error.invalid_fields
        assert "grid_strategy.range.bottom" in error.invalid_fields

    def test_validate_limits_invalid_type(self, config_validator, valid_config):
        # Neither "yes" nor 1 is a real boolean.
        valid_config["risk_management"] = {
            "take_profit": {"enabled": "yes"},
            "stop_loss": {"enabled": 1},
        }
        error = self._expect_validation_error(config_validator, valid_config)
        assert "risk_management.take_profit.enabled" in error.invalid_fields
        assert "risk_management.stop_loss.enabled" in error.invalid_fields

    def test_validate_logging_invalid_level(self, config_validator, valid_config):
        # "VERBOSE" is not a recognised log level.
        valid_config["logging"] = {
            "log_level": "VERBOSE",
            "log_to_file": True,
        }
        error = self._expect_validation_error(config_validator, valid_config)
        assert "logging.log_level" in error.invalid_fields

    def test_validate_logging_missing_level(self, config_validator, valid_config):
        valid_config["logging"] = {"log_to_file": True}
        error = self._expect_validation_error(config_validator, valid_config)
        assert "logging.log_level" in error.missing_fields
124 |
--------------------------------------------------------------------------------
/tests/bot_management/test_notification_handler.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, patch
2 |
3 | import pytest
4 |
5 | from config.trading_mode import TradingMode
6 | from core.bot_management.event_bus import EventBus, Events
7 | from core.bot_management.notification.notification_content import NotificationType
8 | from core.bot_management.notification.notification_handler import NotificationHandler
9 | from core.order_handling.order import Order, OrderSide, OrderStatus, OrderType
10 |
11 |
12 | class TestNotificationHandler:
13 | @pytest.fixture
14 | def event_bus(self):
15 | return Mock(spec=EventBus)
16 |
17 | @pytest.fixture
18 | def notification_handler_enabled(self, event_bus):
19 | urls = ["json://localhost:8080/path"]
20 | handler = NotificationHandler(
21 | event_bus=event_bus,
22 | urls=urls,
23 | trading_mode=TradingMode.LIVE,
24 | )
25 | return handler
26 |
27 | @pytest.fixture
28 | def notification_handler_disabled(self, event_bus):
29 | return NotificationHandler(
30 | event_bus=event_bus,
31 | urls=None,
32 | trading_mode=TradingMode.BACKTEST,
33 | )
34 |
35 | @pytest.fixture
36 | def mock_order(self):
37 | return Order(
38 | identifier="test-123",
39 | status=OrderStatus.CLOSED,
40 | order_type=OrderType.LIMIT,
41 | side=OrderSide.BUY,
42 | price=1000.0,
43 | average=1000.0,
44 | amount=1.0,
45 | filled=1.0,
46 | remaining=0.0,
47 | timestamp=1234567890000,
48 | datetime="2024-01-01T00:00:00Z",
49 | last_trade_timestamp="2024-01-01T00:00:00Z",
50 | symbol="BTC/USDT",
51 | time_in_force="GTC",
52 | )
53 |
54 | @patch("apprise.Apprise")
55 | def test_notification_handler_enabled_initialization(self, mock_apprise, event_bus):
56 | handler = NotificationHandler(
57 | event_bus=event_bus,
58 | urls=["mock://example.com"],
59 | trading_mode=TradingMode.LIVE,
60 | )
61 | assert handler.enabled is True
62 | mock_apprise.return_value.add.assert_called_once_with("mock://example.com")
63 | event_bus.subscribe.assert_called_once_with(Events.ORDER_FILLED, handler._send_notification_on_order_filled)
64 |
65 | @patch("apprise.Apprise")
66 | def test_notification_handler_disabled_initialization(self, mock_apprise, event_bus):
67 | handler = NotificationHandler(
68 | event_bus=event_bus,
69 | urls=None,
70 | trading_mode=TradingMode.BACKTEST,
71 | )
72 | assert handler.enabled is False
73 | mock_apprise.assert_not_called()
74 | event_bus.subscribe.assert_not_called()
75 |
@pytest.mark.asyncio
async def test_send_notification_with_predefined_content(self, notification_handler_enabled, mock_order):
    order_text = str(mock_order)
    expected_body = f"Order has been filled successfully:\n{order_text}"

    with patch.object(notification_handler_enabled.apprise_instance, "notify") as mock_notify:
        notification_handler_enabled.send_notification(NotificationType.ORDER_FILLED, order_details=order_text)

    mock_notify.assert_called_once_with(title="Order Filled", body=expected_body)
89 |
@pytest.mark.asyncio
async def test_send_notification_with_missing_placeholder(self, notification_handler_enabled):
    apprise_instance = notification_handler_enabled.apprise_instance

    # No order_details supplied: the handler should warn and substitute 'N/A'.
    with patch("logging.Logger.warning") as mock_warning:
        with patch.object(apprise_instance, "notify") as mock_notify:
            notification_handler_enabled.send_notification(NotificationType.ORDER_FILLED)

    mock_warning.assert_called_once_with(
        "Missing placeholders for notification: {'order_details'}. Defaulting to 'N/A' for missing values.",
    )
    mock_notify.assert_called_once_with(title="Order Filled", body="Order has been filled successfully:\nN/A")
106 |
@pytest.mark.asyncio
async def test_send_notification_with_order_failed(self, notification_handler_enabled):
    reason = "Insufficient funds"

    with patch.object(notification_handler_enabled.apprise_instance, "notify") as mock_notify:
        notification_handler_enabled.send_notification(NotificationType.ORDER_FAILED, error_details=reason)

    mock_notify.assert_called_once_with(
        title="Order Placement Failed",
        body="Failed to place order:\n" + reason,
    )
122 |
@pytest.mark.asyncio
async def test_async_send_notification_success(self, notification_handler_enabled):
    handler = notification_handler_enabled

    def run_inline(func, *args, **kwargs):
        # Stand-in for executor.submit: run the callable synchronously in this thread.
        return func(*args, **kwargs)

    with (
        patch.object(handler, "_executor", create=True) as fake_executor,
        patch.object(handler, "send_notification") as mock_send,
    ):
        fake_executor.submit = run_inline
        await handler.async_send_notification(NotificationType.ORDER_FILLED, order_details="test")

    mock_send.assert_called_once_with(NotificationType.ORDER_FILLED, order_details="test")
144 |
@pytest.mark.asyncio
async def test_event_subscription_and_notification_on_order_filled(
    self,
    notification_handler_enabled,
    mock_order,
):
    # patch.object detects the async method and substitutes an awaitable mock.
    with patch.object(notification_handler_enabled, "async_send_notification") as mock_async_send:
        await notification_handler_enabled._send_notification_on_order_filled(mock_order)

    mock_async_send.assert_called_once_with(NotificationType.ORDER_FILLED, order_details=str(mock_order))
156 |
def test_send_notification_disabled(self, notification_handler_disabled):
    # A disabled handler must never reach Apprise.
    with patch("apprise.Apprise.notify") as mock_notify:
        notification_handler_disabled.send_notification(NotificationType.ORDER_FILLED, order_details="test")
        mock_notify.assert_not_called()
162 |
--------------------------------------------------------------------------------
/tests/order_handling/test_live_order_execution_strategy.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import AsyncMock, Mock, patch
2 |
3 | import pytest
4 |
5 | from core.order_handling.exceptions import OrderExecutionFailedError
6 | from core.order_handling.execution_strategy.live_order_execution_strategy import (
7 | LiveOrderExecutionStrategy,
8 | )
9 | from core.order_handling.order import OrderSide, OrderStatus, OrderType
10 | from core.services.exceptions import DataFetchError
11 |
12 |
@pytest.mark.asyncio
class TestLiveOrderExecutionStrategy:
    """Tests for LiveOrderExecutionStrategy against a mocked exchange service."""

    @pytest.fixture
    def setup_strategy(self):
        """
        Returns ``(strategy, mock_exchange_service)``.

        ``retry_delay=0`` keeps the retry and cancel paths from really sleeping. With the
        default 1-second delay, the retry tests below would spend several seconds in
        ``asyncio.sleep``.
        """
        mock_exchange_service = Mock()
        strategy = LiveOrderExecutionStrategy(exchange_service=mock_exchange_service, retry_delay=0)
        return strategy, mock_exchange_service

    @patch("time.time", return_value=1680000000)  # Freeze wall-clock time; the order ID comes from the mocked response
    async def test_execute_market_order_success(self, mock_time, setup_strategy):
        strategy, exchange_service = setup_strategy
        pair = "BTC/USDT"
        quantity = 0.5
        price = 30000
        raw_order = {
            "id": "test-order-id",
            "status": "closed",
            "type": "market",
            "side": "buy",
            "price": price,
            "amount": quantity,
            "filled": quantity,
            "remaining": 0,
            "symbol": pair,
            "timestamp": 1680000000000,
        }

        exchange_service.place_order = AsyncMock(return_value=raw_order)

        order = await strategy.execute_market_order(OrderSide.BUY, pair, quantity, price)

        # The order is placed once, with lower-cased type/side and the unadjusted price.
        exchange_service.place_order.assert_awaited_once_with(pair, "market", "buy", quantity, price)
        assert order is not None
        assert order.identifier == "test-order-id"
        assert order.status == OrderStatus.CLOSED
        assert order.order_type == OrderType.MARKET
        assert order.side == OrderSide.BUY
        assert order.price == price

    async def test_execute_market_order_retries(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        pair = "BTC/USDT"
        quantity = 0.5
        price = 30000

        exchange_service.place_order = AsyncMock(side_effect=Exception("Order failed"))

        with pytest.raises(OrderExecutionFailedError):
            await strategy.execute_market_order(OrderSide.BUY, pair, quantity, price)

        assert exchange_service.place_order.call_count == strategy.max_retries

    async def test_execute_limit_order_success(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        pair = "ETH/USDT"
        quantity = 1
        price = 2000
        raw_order = {
            "id": "test-limit-order-id",
            "status": "open",
            "type": "limit",
            "side": "sell",
            "price": price,
            "amount": quantity,
            "filled": 0,
            "remaining": quantity,
            "symbol": pair,
        }

        exchange_service.place_order = AsyncMock(return_value=raw_order)

        order = await strategy.execute_limit_order(OrderSide.SELL, pair, quantity, price)

        exchange_service.place_order.assert_awaited_once_with(pair, "limit", "sell", quantity, price)
        assert order is not None
        assert order.identifier == "test-limit-order-id"
        assert order.status == OrderStatus.OPEN
        assert order.order_type == OrderType.LIMIT
        assert order.side == OrderSide.SELL
        assert order.price == price

    async def test_execute_limit_order_data_fetch_error(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        pair = "ETH/USDT"
        quantity = 1
        price = 2000

        exchange_service.place_order = AsyncMock(side_effect=DataFetchError("Exchange API error"))

        with pytest.raises(OrderExecutionFailedError):
            await strategy.execute_limit_order(OrderSide.SELL, pair, quantity, price)

    async def test_get_order_success(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        order_id = "test-order-id"
        pair = "BTC/USDT"
        raw_order = {
            "id": order_id,
            "status": "open",
            "type": "limit",
            "side": "buy",
            "price": 100,
            "amount": 1,
            "filled": 0,
            "remaining": 1,
            "symbol": pair,
        }

        exchange_service.fetch_order = AsyncMock(return_value=raw_order)

        order = await strategy.get_order(order_id, pair)

        exchange_service.fetch_order.assert_awaited_once_with(order_id, pair)
        assert order is not None
        assert order.identifier == order_id
        assert order.symbol == pair
        assert order.status == OrderStatus.OPEN
        assert order.order_type == OrderType.LIMIT

    async def test_get_order_data_fetch_error(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        order_id = "test-order-id"
        pair = "BTC/USDT"

        exchange_service.fetch_order = AsyncMock(side_effect=DataFetchError("Order not found"))

        with pytest.raises(DataFetchError):
            await strategy.get_order(order_id, pair)

    async def test_handle_partial_fill(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        partial_order = Mock(identifier="partial-order", filled=0.5)
        exchange_service.cancel_order = AsyncMock(return_value={"status": "canceled"})

        await strategy._handle_partial_fill(partial_order, "BTC/USDT")
        exchange_service.cancel_order.assert_called_once_with("partial-order", "BTC/USDT")

    async def test_retry_cancel_order(self, setup_strategy):
        strategy, exchange_service = setup_strategy
        order_id = "test-order-id"
        pair = "BTC/USDT"

        exchange_service.cancel_order = AsyncMock(
            side_effect=[
                {"status": "failed"},
                {"status": "canceled"},
            ],
        )

        result = await strategy._retry_cancel_order(order_id, pair)

        assert result is True
        assert exchange_service.cancel_order.call_count == 2

    async def test_adjust_price_buy(self, setup_strategy):
        strategy, _ = setup_strategy
        price = 30000
        adjusted_price = await strategy._adjust_price(OrderSide.BUY, price, 1)

        assert adjusted_price > price

    async def test_adjust_price_sell(self, setup_strategy):
        strategy, _ = setup_strategy
        price = 30000
        adjusted_price = await strategy._adjust_price(OrderSide.SELL, price, 1)

        assert adjusted_price < price
177 |
--------------------------------------------------------------------------------
/core/order_handling/order_status_tracker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from core.bot_management.event_bus import EventBus, Events
5 | from core.order_handling.order import Order, OrderStatus
6 | from core.order_handling.order_book import OrderBook
7 |
8 |
class OrderStatusTracker:
    """
    Tracks the status of pending orders and publishes events
    when their states change (e.g., open, filled, canceled).
    """

    def __init__(
        self,
        order_book: OrderBook,
        order_execution_strategy,
        event_bus: EventBus,
        polling_interval: float = 15.0,
    ):
        """
        Initializes the OrderStatusTracker.

        Args:
            order_book: OrderBook instance to manage and query orders.
            order_execution_strategy: Strategy for querying order statuses from the exchange.
            event_bus: EventBus instance for publishing state change events.
            polling_interval: Time interval (in seconds) between status checks.
        """
        self.order_book = order_book
        self.order_execution_strategy = order_execution_strategy
        self.event_bus = event_bus
        self.polling_interval = polling_interval
        self._monitoring_task = None
        # Per-order query tasks currently in flight; each removes itself from the set when done.
        self._active_tasks = set()
        self.logger = logging.getLogger(self.__class__.__name__)

    async def _track_open_order_statuses(self) -> None:
        """
        Periodically checks the statuses of open orders until the task is cancelled.

        An unexpected error in one polling cycle is logged and polling continues, so a
        transient failure (e.g. in ``get_open_orders``) does not silently stop order tracking.
        On cancellation, the in-flight per-order queries are cancelled and then
        ``CancelledError`` is re-raised, so the task is correctly reported as cancelled.
        """
        try:
            while True:
                try:
                    await self._process_open_orders()
                except Exception as error:
                    # CancelledError is a BaseException, so it is not caught here.
                    self.logger.error(f"Unexpected error in OrderStatusTracker: {error}", exc_info=True)
                await asyncio.sleep(self.polling_interval)

        except asyncio.CancelledError:
            self.logger.info("OrderStatusTracker monitoring task was cancelled.")
            await self._cancel_active_tasks()
            raise

    async def _process_open_orders(self) -> None:
        """
        Processes open orders by querying their statuses concurrently and handling state changes.
        """
        open_orders = self.order_book.get_open_orders()
        tasks = [self._create_task(self._query_and_handle_order(order)) for order in open_orders]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                # This runs outside an ``except`` block, so pass the exception itself to log its
                # traceback (``exc_info=True`` here would only log "NoneType: None").
                self.logger.error(f"Error during order processing: {result}", exc_info=result)

    async def _query_and_handle_order(self, local_order: Order):
        """
        Fetches the remote state of ``local_order`` and handles any status change.

        Errors are logged and not propagated, so one failing order does not affect the others.
        """
        try:
            remote_order = await self.order_execution_strategy.get_order(local_order.identifier, local_order.symbol)
            self._handle_order_status_change(remote_order)

        except Exception as error:
            self.logger.error(
                f"Failed to query remote order with identifier {local_order.identifier}: {error}",
                exc_info=True,
            )

    def _handle_order_status_change(
        self,
        remote_order: Order,
    ) -> None:
        """
        Handles changes in the status of the order data fetched from the exchange.

        CLOSED orders are marked filled in the order book and published as ``ORDER_FILLED``.
        CANCELED orders are marked canceled and published as ``ORDER_CANCELLED``.
        OPEN orders are only logged. An UNKNOWN status (missing from the exchange payload) is
        logged as an error and otherwise ignored. This method logs errors instead of raising.

        Args:
            remote_order: The latest `Order` object fetched from the exchange.
        """
        try:
            status = remote_order.status

            if status == OrderStatus.UNKNOWN:
                self.logger.error(f"Missing 'status' in remote order object: {remote_order}")
                return

            if status == OrderStatus.CLOSED:
                self.order_book.update_order_status(remote_order.identifier, OrderStatus.CLOSED)
                self.event_bus.publish_sync(Events.ORDER_FILLED, remote_order)
                self.logger.info(f"Order {remote_order.identifier} filled.")
            elif status == OrderStatus.CANCELED:
                self.order_book.update_order_status(remote_order.identifier, OrderStatus.CANCELED)
                self.event_bus.publish_sync(Events.ORDER_CANCELLED, remote_order)
                self.logger.warning(f"Order {remote_order.identifier} was canceled.")
            elif status == OrderStatus.OPEN:  # Still open
                if remote_order.filled > 0:
                    self.logger.info(
                        f"Order {remote_order} partially filled. Filled: {remote_order.filled}, "
                        f"Remaining: {remote_order.remaining}.",
                    )
                else:
                    self.logger.info(f"Order {remote_order} is still open. No fills yet.")
            else:
                self.logger.warning(
                    f"Unhandled order status '{remote_order.status}' for order {remote_order.identifier}.",
                )

        except Exception as e:
            self.logger.error(f"Error handling order status change: {e}", exc_info=True)

    def _create_task(self, coro):
        """
        Creates a managed asyncio task and adds it to the active task set.

        Args:
            coro: Coroutine to be scheduled as a task.

        Returns:
            The created task. It removes itself from the active set when it finishes.
        """
        task = asyncio.create_task(coro)
        self._active_tasks.add(task)
        task.add_done_callback(self._active_tasks.discard)
        return task

    async def _cancel_active_tasks(self):
        """
        Cancels all active tasks tracked by the tracker and waits for them to finish.
        """
        # Work on a snapshot: done-callbacks discard tasks from the live set while we wait.
        tasks = list(self._active_tasks)
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self._active_tasks.clear()

    def start_tracking(self) -> None:
        """
        Starts the order tracking task (no-op with a warning if it is already running).
        """
        if self._monitoring_task and not self._monitoring_task.done():
            self.logger.warning("OrderStatusTracker is already running.")
            return
        self._monitoring_task = asyncio.create_task(self._track_open_order_statuses())
        self.logger.info("OrderStatusTracker has started tracking open orders.")

    async def stop_tracking(self) -> None:
        """
        Stops the order tracking task and cancels any in-flight order queries.
        """
        if self._monitoring_task:
            self._monitoring_task.cancel()
            try:
                await self._monitoring_task
            except asyncio.CancelledError:
                # Expected: the monitoring task logs its cancellation and re-raises CancelledError.
                pass
            await self._cancel_active_tasks()
            self._monitoring_task = None
            self.logger.info("OrderStatusTracker has stopped tracking open orders.")
166 |
--------------------------------------------------------------------------------
/tests/config/test_config_manager.py:
--------------------------------------------------------------------------------
1 | import json
2 | from unittest.mock import Mock, mock_open, patch
3 |
4 | import pytest
5 |
6 | from config.config_manager import ConfigManager
7 | from config.config_validator import ConfigValidator
8 | from config.exceptions import ConfigFileNotFoundError, ConfigParseError
9 | from config.trading_mode import TradingMode
10 | from strategies.spacing_type import SpacingType
11 | from strategies.strategy_type import StrategyType
12 |
13 |
class TestConfigManager:
    """
    Tests for ConfigManager loading and accessors, based on the ``valid_config`` fixture.

    Note: ``pytest.raises(match=...)`` performs a regex search, so literal dots in the
    expected error messages are escaped.
    """

    @pytest.fixture
    def mock_validator(self):
        return Mock(spec=ConfigValidator)

    @pytest.fixture
    def config_manager(self, mock_validator, valid_config):
        # Mocking both open and os.path.exists to simulate a valid config file
        mocked_open = mock_open(read_data=json.dumps(valid_config))
        with patch("builtins.open", mocked_open), patch("os.path.exists", return_value=True):
            return ConfigManager("config.json", mock_validator)

    def test_load_config_valid(self, config_manager, valid_config, mock_validator):
        mock_validator.validate.assert_called_once_with(valid_config)
        assert config_manager.config == valid_config

    def test_load_config_file_not_found(self, mock_validator):
        with patch("os.path.exists", return_value=False), pytest.raises(ConfigFileNotFoundError):
            ConfigManager("config.json", mock_validator)

    def test_load_config_json_decode_error(self, mock_validator):
        invalid_json = '{"invalid_json": '  # Malformed JSON
        mocked_open = mock_open(read_data=invalid_json)
        with (
            patch("builtins.open", mocked_open),
            patch("os.path.exists", return_value=True),
            pytest.raises(ConfigParseError),
        ):
            ConfigManager("config.json", mock_validator)

    def test_get_exchange_name(self, config_manager):
        assert config_manager.get_exchange_name() == "binance"

    def test_get_trading_fee(self, config_manager):
        assert config_manager.get_trading_fee() == 0.001

    def test_get_base_currency(self, config_manager):
        assert config_manager.get_base_currency() == "ETH"

    def test_get_quote_currency(self, config_manager):
        assert config_manager.get_quote_currency() == "USDT"

    def test_get_initial_balance(self, config_manager):
        assert config_manager.get_initial_balance() == 10000

    def test_get_spacing_type(self, config_manager):
        assert config_manager.get_spacing_type() == SpacingType.GEOMETRIC

    def test_get_strategy_type(self, config_manager):
        assert config_manager.get_strategy_type() == StrategyType.SIMPLE_GRID

    def test_get_trading_mode(self, config_manager):
        assert config_manager.get_trading_mode() == TradingMode.BACKTEST

    def test_get_timeframe(self, config_manager):
        assert config_manager.get_timeframe() == "1m"

    def test_get_period(self, config_manager):
        expected_period = {
            "start_date": "2024-07-04T00:00:00Z",
            "end_date": "2024-07-11T00:00:00Z",
        }
        assert config_manager.get_period() == expected_period

    def test_get_start_date(self, config_manager):
        assert config_manager.get_start_date() == "2024-07-04T00:00:00Z"

    def test_get_end_date(self, config_manager):
        assert config_manager.get_end_date() == "2024-07-11T00:00:00Z"

    def test_get_num_grids(self, config_manager):
        assert config_manager.get_num_grids() == 20

    def test_get_grid_range(self, config_manager):
        expected_range = {
            "top": 3100,
            "bottom": 2850,
        }
        assert config_manager.get_grid_range() == expected_range

    def test_get_top_range(self, config_manager):
        assert config_manager.get_top_range() == 3100

    def test_get_bottom_range(self, config_manager):
        assert config_manager.get_bottom_range() == 2850

    def test_is_take_profit_enabled(self, config_manager):
        # Same strictness as the *_default tests: a real bool, not just a falsy value.
        assert config_manager.is_take_profit_enabled() is False

    def test_get_take_profit_threshold(self, config_manager):
        assert config_manager.get_take_profit_threshold() == 3700

    def test_get_stop_loss_threshold(self, config_manager):
        assert config_manager.get_stop_loss_threshold() == 2830

    def test_is_stop_loss_enabled(self, config_manager):
        assert config_manager.is_stop_loss_enabled() is False

    def test_get_log_level(self, config_manager):
        assert config_manager.get_logging_level() == "INFO"

    def test_should_log_to_file_true(self, config_manager):
        assert config_manager.should_log_to_file() is True

    def test_get_trading_mode_invalid_value(self, config_manager):
        config_manager.config["exchange"]["trading_mode"] = "invalid_mode"

        with pytest.raises(
            ValueError,
            match=r"Invalid trading mode: 'invalid_mode'\. Available modes are: backtest, paper_trading, live",
        ):
            config_manager.get_trading_mode()

    def test_get_spacing_type_invalid_value(self, config_manager):
        config_manager.config["grid_strategy"]["spacing"] = "invalid_spacing"

        with pytest.raises(
            ValueError,
            match=r"Invalid spacing type: 'invalid_spacing'\. Available spacings are: arithmetic, geometric",
        ):
            config_manager.get_spacing_type()

    def test_get_strategy_type_invalid_value(self, config_manager):
        config_manager.config["grid_strategy"]["type"] = "invalid_strategy"

        with pytest.raises(
            ValueError,
            match=r"Invalid strategy type: 'invalid_strategy'\. Available strategies are: simple_grid, hedged_grid",
        ):
            config_manager.get_strategy_type()

    def test_get_timeframe_default(self, config_manager):
        del config_manager.config["trading_settings"]["timeframe"]
        assert config_manager.get_timeframe() == "1h"

    def test_get_historical_data_file_default(self, config_manager):
        del config_manager.config["trading_settings"]["historical_data_file"]
        assert config_manager.get_historical_data_file() is None

    def test_is_take_profit_enabled_default(self, config_manager):
        del config_manager.config["risk_management"]["take_profit"]
        assert config_manager.is_take_profit_enabled() is False

    def test_get_take_profit_threshold_default(self, config_manager):
        del config_manager.config["risk_management"]["take_profit"]
        assert config_manager.get_take_profit_threshold() is None

    def test_is_stop_loss_enabled_default(self, config_manager):
        del config_manager.config["risk_management"]["stop_loss"]
        assert config_manager.is_stop_loss_enabled() is False

    def test_get_stop_loss_threshold_default(self, config_manager):
        del config_manager.config["risk_management"]["stop_loss"]
        assert config_manager.get_stop_loss_threshold() is None
168 |
--------------------------------------------------------------------------------
/core/order_handling/execution_strategy/live_order_execution_strategy.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 |
4 | from core.services.exceptions import DataFetchError
5 | from core.services.exchange_interface import ExchangeInterface
6 |
7 | from ..exceptions import OrderExecutionFailedError
8 | from ..order import Order, OrderSide, OrderStatus, OrderType
9 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface
10 |
11 |
class LiveOrderExecutionStrategy(OrderExecutionStrategyInterface):
    """
    Places and queries orders on a live exchange through an ``ExchangeInterface``.

    Args:
        exchange_service: Exchange adapter used to place, fetch and cancel orders.
        max_retries: Maximum number of attempts for market orders and for order cancellation.
        retry_delay: Seconds to wait between attempts.
        max_slippage: Upper bound, as a fraction of the original price, on the price
            adjustment applied across market-order retries.
    """

    def __init__(
        self,
        exchange_service: ExchangeInterface,
        max_retries: int = 3,
        retry_delay: int = 1,
        max_slippage: float = 0.01,
    ) -> None:
        self.exchange_service = exchange_service
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.max_slippage = max_slippage
        self.logger = logging.getLogger(self.__class__.__name__)

    async def execute_market_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """
        Places a market order, retrying up to ``max_retries`` times until it is fully filled.

        If an attempt comes back still OPEN (partially filled), that order is canceled and
        only the unfilled remainder is re-submitted, so the total executed size does not
        exceed ``quantity``. Each retry price is derived from the ORIGINAL ``price``, so
        slippage adjustments never compound beyond ``max_slippage``.

        Args:
            order_side: BUY or SELL.
            pair: Trading pair symbol, e.g. "BTC/USDT".
            quantity: Total base-asset quantity to execute.
            price: Reference price sent with the order.

        Returns:
            The CLOSED (fully filled) order.

        Raises:
            OrderExecutionFailedError: If no attempt produced a fully filled order.
        """
        original_price = price
        remaining_quantity = quantity

        for attempt in range(self.max_retries):
            try:
                raw_order = await self.exchange_service.place_order(
                    pair,
                    OrderType.MARKET.value.lower(),
                    order_side.name.lower(),
                    remaining_quantity,
                    price,
                )
                order_result = await self._parse_order_result(raw_order)

                if order_result.status == OrderStatus.CLOSED:
                    return order_result  # Order fully filled

                if order_result.status == OrderStatus.OPEN:
                    await self._handle_partial_fill(order_result, pair)
                    # Whatever was filled stays filled after the cancel; only retry the rest.
                    filled = order_result.filled or 0.0
                    remaining_quantity = max(remaining_quantity - filled, 0.0)

                await asyncio.sleep(self.retry_delay)
                self.logger.info(f"Retrying order. Attempt {attempt + 1}/{self.max_retries}.")
                price = await self._adjust_price(order_side, original_price, attempt)

            except Exception as e:
                self.logger.error(f"Attempt {attempt + 1} failed with error: {e!s}")
                await asyncio.sleep(self.retry_delay)

        raise OrderExecutionFailedError(
            "Failed to execute Market order after maximum retries.",
            order_side,
            OrderType.MARKET,
            pair,
            quantity,
            price,
        )

    async def execute_limit_order(
        self,
        order_side: OrderSide,
        pair: str,
        quantity: float,
        price: float,
    ) -> Order | None:
        """
        Places a single limit order (no retries).

        Raises:
            OrderExecutionFailedError: Wrapping any error from placing or parsing the order.
        """
        try:
            raw_order = await self.exchange_service.place_order(
                pair,
                OrderType.LIMIT.value.lower(),
                order_side.name.lower(),
                quantity,
                price,
            )
            order_result = await self._parse_order_result(raw_order)
            return order_result

        except DataFetchError as e:
            self.logger.error(f"DataFetchError during order execution for {pair} - {e}")
            raise OrderExecutionFailedError(
                f"Failed to execute Limit order on {pair}: {e}",
                order_side,
                OrderType.LIMIT,
                pair,
                quantity,
                price,
            ) from e

        except Exception as e:
            self.logger.error(f"Unexpected error in execute_limit_order: {e}")
            raise OrderExecutionFailedError(
                f"Unexpected error during order execution: {e}",
                order_side,
                OrderType.LIMIT,
                pair,
                quantity,
                price,
            ) from e

    async def get_order(
        self,
        order_id: str,
        pair: str,
    ) -> Order | None:
        """
        Fetches an order from the exchange and parses it into an ``Order``.

        Raises:
            DataFetchError: Propagated unchanged from the exchange service, or wrapping any
                other error raised while fetching/parsing.
        """
        try:
            raw_order = await self.exchange_service.fetch_order(order_id, pair)
            order_result = await self._parse_order_result(raw_order)
            return order_result

        except DataFetchError:
            raise  # bare raise keeps the original traceback

        except Exception as e:
            raise DataFetchError(f"Unexpected error during order status retrieval: {e!s}") from e

    async def _parse_order_result(
        self,
        raw_order_result: dict,
    ) -> Order:
        """
        Parses the raw order response from the exchange into an Order object.

        Args:
            raw_order_result: The raw response from the exchange.

        Returns:
            An Order object with standardized fields.
        """
        # These keys may be present with an explicit None value, so fall back to "unknown"
        # for missing *or* null values instead of calling .lower() on None.
        status = (raw_order_result.get("status") or "unknown").lower()
        order_type = (raw_order_result.get("type") or "unknown").lower()
        side = (raw_order_result.get("side") or "unknown").lower()

        return Order(
            identifier=raw_order_result.get("id", ""),
            status=OrderStatus(status),
            order_type=OrderType(order_type),
            side=OrderSide(side),
            price=raw_order_result.get("price", 0.0),
            average=raw_order_result.get("average"),
            amount=raw_order_result.get("amount", 0.0),
            filled=raw_order_result.get("filled", 0.0),
            remaining=raw_order_result.get("remaining", 0.0),
            timestamp=raw_order_result.get("timestamp", 0),
            datetime=raw_order_result.get("datetime"),
            last_trade_timestamp=raw_order_result.get("lastTradeTimestamp"),
            symbol=raw_order_result.get("symbol", ""),
            time_in_force=raw_order_result.get("timeInForce"),
            trades=raw_order_result.get("trades", []),
            fee=raw_order_result.get("fee"),
            cost=raw_order_result.get("cost"),
            info=raw_order_result.get("info", raw_order_result),
        )

    async def _adjust_price(
        self,
        order_side: OrderSide,
        price: float,
        attempt: int,
    ) -> float:
        """
        Returns ``price`` moved against the trader by ``max_slippage * attempt / max_retries``.

        BUY prices are increased and SELL prices are decreased.
        """
        adjustment = self.max_slippage / self.max_retries * attempt
        return price * (1 + adjustment) if order_side == OrderSide.BUY else price * (1 - adjustment)

    async def _handle_partial_fill(
        self,
        order: Order,
        pair: str,
    ) -> bool:
        """
        Cancels a partially filled order so that its remainder can be retried.

        Returns:
            True if the cancellation succeeded, False otherwise.
        """
        self.logger.info(
            f"Order partially filled with {order.filled}. Attempting to cancel and retry the remaining quantity.",
        )

        canceled = await self._retry_cancel_order(order.identifier, pair)
        if not canceled:
            self.logger.error(f"Unable to cancel partially filled order {order.identifier} after retries.")
        return canceled

    async def _retry_cancel_order(
        self,
        order_id: str,
        pair: str,
    ) -> bool:
        """
        Tries to cancel ``order_id`` up to ``max_retries`` times.

        Returns:
            True once the exchange reports the order as "canceled", False if every attempt fails.
        """
        for cancel_attempt in range(self.max_retries):
            try:
                cancel_result = await self.exchange_service.cancel_order(order_id, pair)

                if cancel_result.get("status") == "canceled":
                    self.logger.info(f"Successfully canceled order {order_id}.")
                    return True

                self.logger.warning(f"Cancel attempt {cancel_attempt + 1} for order {order_id} failed.")

            except Exception as e:
                self.logger.warning(f"Error during cancel attempt {cancel_attempt + 1} for order {order_id}: {e!s}")

            # No point waiting after the final attempt.
            if cancel_attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay)
        return False
196 |
--------------------------------------------------------------------------------
/core/services/backtest_exchange_service.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from typing import Any
5 |
6 | import ccxt
7 | import pandas as pd
8 |
9 | from config.config_manager import ConfigManager
10 | from utils.constants import CANDLE_LIMITS, TIMEFRAME_MAPPINGS
11 |
12 | from .exceptions import (
13 | DataFetchError,
14 | HistoricalMarketDataFileNotFoundError,
15 | UnsupportedExchangeError,
16 | UnsupportedPairError,
17 | UnsupportedTimeframeError,
18 | )
19 | from .exchange_interface import ExchangeInterface
20 |
21 |
22 | class BacktestExchangeService(ExchangeInterface):
23 | def __init__(self, config_manager: ConfigManager):
24 | self.logger = logging.getLogger(self.__class__.__name__)
25 | self.config_manager = config_manager
26 | self.historical_data_file = self.config_manager.get_historical_data_file()
27 | self.exchange_name = self.config_manager.get_exchange_name()
28 | self.exchange = self._initialize_exchange()
29 |
30 | def _initialize_exchange(self) -> ccxt.Exchange | None:
31 | try:
32 | return getattr(ccxt, self.exchange_name)()
33 | except AttributeError:
34 | raise UnsupportedExchangeError(f"The exchange '{self.exchange_name}' is not supported.") from None
35 |
36 | def _is_timeframe_supported(self, timeframe: str) -> bool:
37 | if timeframe not in self.exchange.timeframes:
38 | self.logger.error(f"Timeframe '{timeframe}' is not supported by {self.exchange_name}.")
39 | return False
40 | return True
41 |
42 | def _is_pair_supported(self, pair: str) -> bool:
43 | markets = self.exchange.load_markets()
44 | return pair in markets
45 |
def fetch_ohlcv(
    self,
    pair: str,
    timeframe: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Return OHLCV candles for ``pair`` between ``start_date`` and ``end_date``.

    When ``self.historical_data_file`` is set, candles are read from that file;
    otherwise they are downloaded through the ccxt client, in chunks when the
    requested span exceeds the per-request candle limit.

    Args:
        pair: Market symbol, e.g. ``"BTC/USDT"``.
        timeframe: Candle timeframe string, e.g. ``"1h"``.
        start_date: Start of the range, as accepted by ``exchange.parse8601``.
        end_date: End of the range, as accepted by ``exchange.parse8601``.

    Returns:
        DataFrame of OHLCV data indexed by timestamp.

    Raises:
        HistoricalMarketDataFileNotFoundError: the configured data file does not exist.
        UnsupportedPairError: the exchange does not list ``pair``.
        UnsupportedTimeframeError: the exchange does not support ``timeframe``.
        DataFetchError: a date cannot be parsed or the download fails.
    """
    if self.historical_data_file:
        if not os.path.exists(self.historical_data_file):
            raise HistoricalMarketDataFileNotFoundError(
                f"Failed to load OHLCV data from file: {self.historical_data_file}",
            )

        self.logger.info(f"Loading OHLCV data from file: {self.historical_data_file}")
        return self._load_ohlcv_from_file(self.historical_data_file, start_date, end_date)

    if not self._is_pair_supported(pair):
        raise UnsupportedPairError(f"Pair: {pair} is not supported by {self.exchange_name}")

    if not self._is_timeframe_supported(timeframe):
        raise UnsupportedTimeframeError(f"Timeframe '{timeframe}' is not supported by {self.exchange_name}.")

    self.logger.info(f"Fetching OHLCV data for {pair} from {start_date} to {end_date}")
    try:
        since = self.exchange.parse8601(start_date)
        until = self.exchange.parse8601(end_date)
        # ccxt's parse8601 returns None (rather than raising) for input it cannot
        # parse; report that clearly instead of failing on the subtraction below.
        if since is None or until is None:
            raise DataFetchError(
                f"Invalid date range: start_date={start_date!r}, end_date={end_date!r} "
                "must be ISO 8601 datetime strings.",
            )

        candles_per_request = self._get_candle_limit()
        total_candles_needed = (until - since) // self._get_timeframe_in_ms(timeframe)

        if total_candles_needed > candles_per_request:
            return self._fetch_ohlcv_in_chunks(pair, timeframe, since, until, candles_per_request)
        return self._fetch_ohlcv_single_batch(pair, timeframe, since, until)
    except DataFetchError:
        # Already descriptive (e.g. raised by _fetch_with_retry once retries are
        # exhausted); re-raise it as-is instead of nesting it in another DataFetchError.
        raise
    except ccxt.NetworkError as e:
        raise DataFetchError(f"Network issue occurred while fetching OHLCV data: {e!s}") from e
    except ccxt.BaseError as e:
        raise DataFetchError(f"Exchange-specific error occurred: {e!s}") from e
    except Exception as e:
        raise DataFetchError(f"Failed to fetch OHLCV data {e!s}.") from e
85 |
def _load_ohlcv_from_file(
    self,
    file_path: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Load OHLCV rows from a CSV file and keep those between the two dates (inclusive).

    The file must contain a ``timestamp`` column. Timestamps and both date bounds
    are normalised to naive UTC, the same representation produced when exchange
    millisecond timestamps are converted with ``pd.to_datetime(..., unit="ms")``.
    Rows are sorted by timestamp so label slicing also works for files that are not
    in chronological order.

    Args:
        file_path: Path to the CSV file.
        start_date: Inclusive lower bound; strings without an offset are read as UTC.
        end_date: Inclusive upper bound; strings without an offset are read as UTC.

    Returns:
        DataFrame indexed by ``timestamp`` containing the rows in range.

    Raises:
        DataFetchError: if the file cannot be read, parsed, or sliced.
    """
    try:
        df = pd.read_csv(file_path)
        # utc=True treats naive values as UTC and converts offset-aware ones to UTC;
        # tz_convert(None) then drops the zone so the index is naive UTC throughout.
        df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True).dt.tz_convert(None)
        df.set_index("timestamp", inplace=True)
        # Label slicing with bounds that are not present in the index requires a
        # monotonic index.
        df.sort_index(inplace=True)
        start_timestamp = pd.to_datetime(start_date, utc=True).tz_convert(None)
        end_timestamp = pd.to_datetime(end_date, utc=True).tz_convert(None)
        filtered_df = df.loc[start_timestamp:end_timestamp]
        self.logger.debug(f"Loaded {len(filtered_df)} rows of OHLCV data from file.")
        return filtered_df

    except Exception as e:
        raise DataFetchError(f"Failed to load OHLCV data from file: {e!s}") from e
104 |
def _fetch_ohlcv_single_batch(
    self,
    pair: str,
    timeframe: str,
    since: int,
    until: int,
) -> pd.DataFrame:
    """Fetch candles starting at ``since`` in one request and drop those after ``until``.

    ``fetch_ohlcv`` uses this path only when the requested span fits within
    ``_get_candle_limit()`` candles, so the request asks for that many explicitly.
    Without a ``limit``, the page size is left to the exchange API's own default,
    which may be smaller and would silently truncate the range.

    Args:
        pair: Market symbol.
        timeframe: Candle timeframe string.
        since: Start of the range, epoch milliseconds.
        until: End of the range (inclusive), epoch milliseconds.

    Returns:
        The candles formatted by ``_format_ohlcv``.
    """
    ohlcv = self._fetch_with_retry(
        self.exchange.fetch_ohlcv,
        pair,
        timeframe,
        since,
        limit=self._get_candle_limit(),
    )
    return self._format_ohlcv(ohlcv, until)
114 |
def _fetch_ohlcv_in_chunks(
    self,
    pair: str,
    timeframe: str,
    since: int,
    until: int,
    candles_per_request: int,
) -> pd.DataFrame:
    """Page through ``exchange.fetch_ohlcv`` from ``since`` until ``until`` is passed.

    Each request starts 1 ms after the last candle of the previous page. Paging
    stops on an empty page, and also on a page whose last candle lies before the
    current cursor. Without that second check, an exchange that keeps returning
    the same page would make the loop spin forever.

    Args:
        pair: Market symbol.
        timeframe: Candle timeframe string.
        since: Start of the range, epoch milliseconds.
        until: End of the range (inclusive), epoch milliseconds.
        candles_per_request: ``limit`` sent with every request.

    Returns:
        All collected candles formatted by ``_format_ohlcv``.
    """
    all_ohlcv = []
    cursor = since
    while cursor < until:
        page = self._fetch_with_retry(self.exchange.fetch_ohlcv, pair, timeframe, cursor, limit=candles_per_request)
        if not page:
            break
        next_cursor = page[-1][0] + 1
        if next_cursor <= cursor:
            # The page ends before the requested start, so the cursor cannot advance.
            self.logger.warning(
                f"Exchange returned no candles after {pd.to_datetime(cursor, unit='ms')}; stopping pagination.",
            )
            break
        all_ohlcv.extend(page)
        cursor = next_cursor
        self.logger.info(f"Fetched up to {pd.to_datetime(cursor, unit='ms')}")
    return self._format_ohlcv(all_ohlcv, until)
132 |
def _format_ohlcv(
    self,
    ohlcv,
    until: int,
) -> pd.DataFrame:
    """Turn raw ccxt OHLCV rows into a timestamp-indexed DataFrame ending at ``until``.

    Args:
        ohlcv: Rows of ``[timestamp_ms, open, high, low, close, volume]``.
        until: Inclusive upper bound, in epoch milliseconds.

    Returns:
        DataFrame indexed by naive timestamps (converted from epoch milliseconds),
        sorted ascending, with at most one row per timestamp. When timestamps
        repeat, the first occurrence is kept.
    """
    df = pd.DataFrame(ohlcv, columns=["timestamp", "open", "high", "low", "close", "volume"])
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    df.set_index("timestamp", inplace=True)
    # Paged downloads can yield overlapping or out-of-order candles. Duplicate index
    # labels would make later label lookups return frames instead of single rows.
    df = df[~df.index.duplicated(keep="first")].sort_index()
    until_timestamp = pd.to_datetime(until, unit="ms")
    return df[df.index <= until_timestamp]
143 |
def _get_candle_limit(self) -> int:
    """Look up the per-request candle limit for this exchange in ``CANDLE_LIMITS``.

    Exchanges that are not listed get 500.
    """
    fallback_limit = 500
    return CANDLE_LIMITS.get(self.exchange_name, fallback_limit)
146 |
def _get_timeframe_in_ms(self, timeframe: str) -> int:
    """Return the candle duration for ``timeframe`` in milliseconds.

    Timeframes missing from ``TIMEFRAME_MAPPINGS`` fall back to one minute.
    """
    one_minute_ms = 60 * 1000
    return TIMEFRAME_MAPPINGS.get(timeframe, one_minute_ms)
149 |
def _fetch_with_retry(
    self,
    method,
    *args,
    retries=3,
    delay=5,
    **kwargs,
):
    """Call ``method(*args, **kwargs)``, retrying when it raises any exception.

    Args:
        method: Callable to invoke (e.g. a ccxt client method).
        *args: Positional arguments forwarded to ``method``.
        retries: Total number of attempts. Values below 1 are treated as 1, so the
            call is always made rather than silently returning ``None``.
        delay: Seconds to sleep between attempts.
        **kwargs: Keyword arguments forwarded to ``method``.

    Returns:
        The result of the first attempt that does not raise.

    Raises:
        DataFetchError: when every attempt raises; the last error is chained.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            return method(*args, **kwargs)
        except Exception as e:
            if attempt < attempts - 1:
                self.logger.warning(f"Attempt {attempt + 1} failed. Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                self.logger.error(f"Failed after {attempts} attempts: {e}")
                raise DataFetchError(f"Failed to fetch data after {attempts} attempts: {e!s}") from e
168 |
169 | async def place_order(
170 | self,
171 | pair: str,
172 | order_side: str,
173 | order_type: str,
174 | amount: float,
175 | price: float | None = None,
176 | ) -> dict[str, str | float]:
177 | raise NotImplementedError("place_order is not used in backtesting")
178 |
async def get_balance(self) -> dict[str, Any]:
    """Unsupported by the backtest exchange service.

    Raises:
        NotImplementedError: on every call.
    """
    message = "get_balance is not used in backtesting"
    raise NotImplementedError(message)
181 |
async def get_current_price(
    self,
    pair: str,
) -> float:
    """Unsupported by the backtest exchange service.

    Raises:
        NotImplementedError: on every call.
    """
    message = "get_current_price is not used in backtesting"
    raise NotImplementedError(message)
187 |
188 | async def cancel_order(
189 | self,
190 | order_id: str,
191 | pair: str,
192 | ) -> dict[str, str | float]:
193 | raise NotImplementedError("cancel_order is not used in backtesting")
194 |
async def get_exchange_status(self) -> dict:
    """Unsupported by the backtest exchange service.

    Raises:
        NotImplementedError: on every call.
    """
    message = "get_exchange_status is not used in backtesting"
    raise NotImplementedError(message)
197 |
async def close_connection(self) -> None:
    """Log a closing message at INFO level; this method does nothing else."""
    closing_message = "[BACKTEST] Closing WebSocket connection..."
    self.logger.info(closing_message)
200 |
--------------------------------------------------------------------------------
/tests/bot_management/test_bot_controller.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import contextlib
3 | from unittest.mock import Mock, call, patch
4 |
5 | import pytest
6 |
7 | from core.bot_management.bot_controller.bot_controller import BotController
8 | from core.bot_management.event_bus import EventBus, Events
9 | from core.bot_management.grid_trading_bot import GridTradingBot
10 |
11 |
@pytest.mark.asyncio
class TestBotController:
    """Tests for BotController's command listener and its display helpers.

    The bot and event bus are ``Mock`` objects, and ``builtins.input`` is patched
    so each test can script the commands the listener reads.
    """

    @pytest.fixture
    def setup_bot_controller(self):
        """Build a BotController wired to a mocked bot and a mocked event bus.

        Returns:
            Tuple ``(bot_controller, bot, event_bus)``.
        """
        bot = Mock(spec=GridTradingBot)
        bot.strategy = Mock()
        bot.strategy.get_formatted_orders = Mock(return_value=[])
        bot.get_balances = Mock(return_value={})
        event_bus = Mock(spec=EventBus)
        bot_controller = BotController(bot, event_bus)
        return bot_controller, bot, event_bus

    @pytest.fixture(autouse=True)
    def setup_logging(self, setup_bot_controller):
        """Replace the controller's logger with a Mock so tests can inspect log calls."""
        bot_controller, _, _ = setup_bot_controller
        bot_controller.logger = Mock()

    @pytest.fixture
    def mock_input(self):
        """Patch ``builtins.input`` for the duration of a test and yield the mock."""
        with patch("builtins.input") as mock_input:
            yield mock_input

    async def run_command_test(self, bot_controller, mock_input):
        """Run ``command_listener`` briefly, then stop it and make sure the task ends.

        The listener gets 0.1 s to consume the scripted input before
        ``_stop_listening`` is set. If it has not finished 1 s later, it is cancelled.
        """
        listener_task = asyncio.create_task(bot_controller.command_listener())

        try:
            await asyncio.sleep(0.1)
            bot_controller._stop_listening = True

            try:
                await asyncio.wait_for(listener_task, timeout=1.0)
            except asyncio.TimeoutError:
                # wait_for raises asyncio.TimeoutError, which only became an alias of
                # the builtin TimeoutError in Python 3.11; this name works on 3.10 too.
                listener_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await listener_task

        finally:
            if not listener_task.done():
                listener_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await listener_task

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_quit(self, mock_input, setup_bot_controller):
        """'quit' publishes STOP_BOT with the shutdown reason and stops listening."""
        bot_controller, _, event_bus = setup_bot_controller
        event_bus.publish_sync = Mock()

        # side_effect drives input(): first "quit", then StopIteration.
        mock_input.side_effect = ["quit", StopIteration]

        await self.run_command_test(bot_controller, mock_input)

        event_bus.publish_sync.assert_called_once_with(Events.STOP_BOT, "User requested shutdown")
        assert bot_controller._stop_listening

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_orders(self, mock_input, setup_bot_controller):
        """'orders' asks the strategy for its formatted orders."""
        bot_controller, bot, _ = setup_bot_controller
        # Set up mock to return "orders" and then raise StopIteration
        mock_input.side_effect = ["orders", StopIteration]
        bot.strategy.get_formatted_orders.return_value = [
            ["BUY", "LIMIT", "OPEN", "50000", "0.1", "2024-01-01", "1", "0.1%"],
        ]

        await self.run_command_test(bot_controller, mock_input)

        bot.strategy.get_formatted_orders.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_balance(self, mock_input, setup_bot_controller):
        """'balance' reads the bot's balances."""
        bot_controller, bot, _ = setup_bot_controller
        mock_input.side_effect = ["balance", StopIteration]
        bot.get_balances.return_value = {"USD": 1000, "BTC": 0.1}

        await self.run_command_test(bot_controller, mock_input)

        bot.get_balances.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_stop(self, mock_input, setup_bot_controller):
        """'stop' publishes exactly one STOP_BOT event."""
        bot_controller, _, event_bus = setup_bot_controller
        mock_input.side_effect = ["stop", StopIteration]
        event_bus.publish_sync = Mock()

        await self.run_command_test(bot_controller, mock_input)

        event_bus.publish_sync.assert_called_once_with(Events.STOP_BOT, "User issued stop command")

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_restart(self, mock_input, setup_bot_controller):
        """'restart' publishes STOP_BOT and START_BOT, and nothing else."""
        bot_controller, _, event_bus = setup_bot_controller
        mock_input.side_effect = ["restart", StopIteration]
        event_bus.publish_sync = Mock()

        await self.run_command_test(bot_controller, mock_input)

        assert event_bus.publish_sync.call_count == 2
        event_bus.publish_sync.assert_any_call(Events.STOP_BOT, "User issued restart command")
        event_bus.publish_sync.assert_any_call(Events.START_BOT, "User issued restart command")

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_invalid_command(self, mock_input, setup_bot_controller):
        """An unknown command logs exactly one warning."""
        bot_controller, _, _ = setup_bot_controller
        mock_input.side_effect = ["invalid", StopIteration]

        with patch.object(bot_controller.logger, "warning") as mock_logger:
            await self.run_command_test(bot_controller, mock_input)
            mock_logger.assert_called_once()

    @pytest.mark.asyncio
    async def test_handle_stop_event(self, setup_bot_controller):
        """A STOP_BOT event sets ``_stop_listening`` and logs the two expected info lines."""
        bot_controller, _, _ = setup_bot_controller

        with patch.object(bot_controller.logger, "info") as mock_logger:
            bot_controller._handle_stop_event("Test stop reason")

            assert bot_controller._stop_listening is True
            mock_logger.assert_has_calls(
                [
                    call("Received STOP_BOT event: Test stop reason"),
                    call("Command listener stopped."),
                ],
            )
            assert mock_logger.call_count == 2

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_unexpected_error(self, mock_input, setup_bot_controller):
        """An exception from input() is logged as an unexpected error with exc_info."""
        bot_controller, _, _ = setup_bot_controller
        mock_input.side_effect = Exception("Unexpected error")

        with patch.object(bot_controller.logger, "error") as mock_logger:
            await self.run_command_test(bot_controller, mock_input)

            mock_logger.assert_called_with(
                "Unexpected error in command listener: Unexpected error",
                exc_info=True,
            )

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_command_listener_invalid_pause_duration(self, mock_input, setup_bot_controller):
        """'pause' with a non-numeric duration logs exactly one warning."""
        bot_controller, _, _ = setup_bot_controller
        mock_input.side_effect = ["pause invalid", StopIteration]

        with patch.object(bot_controller.logger, "warning") as mock_logger:
            await self.run_command_test(bot_controller, mock_input)
            mock_logger.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_display_orders(self, setup_bot_controller):
        """_display_orders fetches the formatted orders and makes two info log calls."""
        bot_controller, bot, _ = setup_bot_controller
        orders = [["BUY", "LIMIT", "OPEN", "50000", "0.1", "2024-01-01", "1", "0.1%"]]
        bot.strategy.get_formatted_orders.return_value = orders

        with patch.object(bot_controller.logger, "info") as mock_logger:
            await bot_controller._display_orders()
            assert mock_logger.call_count == 2
            bot.strategy.get_formatted_orders.assert_called_once()

    @pytest.mark.asyncio
    @pytest.mark.timeout(2)
    async def test_display_balance(self, setup_bot_controller):
        """_display_balance logs the balances returned by the bot."""
        bot_controller, bot, _ = setup_bot_controller
        balances = {"USD": 1000, "BTC": 0.1}
        bot.get_balances.return_value = balances

        with patch.object(bot_controller.logger, "info") as mock_logger:
            await bot_controller._display_balance()
            mock_logger.assert_any_call(f"Current balances: {balances}")
            bot.get_balances.assert_called_once()
191 |
--------------------------------------------------------------------------------
/config/config_validator.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from strategies.spacing_type import SpacingType
4 | from strategies.strategy_type import StrategyType
5 |
6 | from .exceptions import ConfigValidationError
7 | from .trading_mode import TradingMode
8 |
9 |
10 | class ConfigValidator:
11 | def __init__(self):
12 | self.logger = logging.getLogger(self.__class__.__name__)
13 |
14 | def validate(self, config):
15 | missing_fields = []
16 | invalid_fields = []
17 | missing_fields += self._validate_required_fields(config)
18 | invalid_fields += self._validate_exchange(config)
19 | missing_fields += self._validate_pair(config)
20 | missing_trading_settings, invalid_trading_settings = self._validate_trading_settings(config)
21 | missing_fields += missing_trading_settings
22 | invalid_fields += invalid_trading_settings
23 | missing_grid_settings, invalid_grid_settings = self._validate_grid_strategy(config)
24 | missing_fields += missing_grid_settings
25 | invalid_fields += invalid_grid_settings
26 | invalid_fields += self._validate_limits(config)
27 | missing_logging_settings, invalid_logging_settings = self._validate_logging(config)
28 | missing_fields += missing_logging_settings
29 | invalid_fields += invalid_logging_settings
30 |
31 | if missing_fields or invalid_fields:
32 | raise ConfigValidationError(missing_fields=missing_fields, invalid_fields=invalid_fields)
33 |
34 | def _validate_required_fields(self, config):
35 | required_fields = ["exchange", "pair", "trading_settings", "grid_strategy", "risk_management", "logging"]
36 | missing_fields = [field for field in required_fields if field not in config]
37 | if missing_fields:
38 | self.logger.error(f"Missing required fields: {missing_fields}")
39 | return missing_fields
40 |
41 | def _validate_exchange(self, config):
42 | invalid_fields = []
43 | exchange = config.get("exchange", {})
44 |
45 | if not exchange.get("name"):
46 | self.logger.error("Exchange name is missing.")
47 | invalid_fields.append("exchange.name")
48 |
49 | trading_fee = exchange.get("trading_fee")
50 | if trading_fee is None or not isinstance(trading_fee, float | int) or trading_fee < 0:
51 | self.logger.error("Invalid or missing trading fee.")
52 | invalid_fields.append("exchange.trading_fee")
53 |
54 | trading_mode_str = exchange.get("trading_mode")
55 | if not trading_mode_str:
56 | invalid_fields.append("exchange.trading_mode")
57 | else:
58 | try:
59 | TradingMode.from_string(trading_mode_str)
60 | except ValueError:
61 | invalid_fields.append("exchange.trading_mode")
62 |
63 | return invalid_fields
64 |
65 | def _validate_pair(self, config):
66 | missing_fields = []
67 | pair = config.get("pair", {})
68 |
69 | if not pair.get("base_currency"):
70 | missing_fields.append("pair.base_currency")
71 | if not pair.get("quote_currency"):
72 | missing_fields.append("pair.quote_currency")
73 |
74 | if missing_fields:
75 | self.logger.error(f"Missing pair configuration fields: {missing_fields}")
76 |
77 | return missing_fields
78 |
79 | def _validate_trading_settings(self, config):
80 | missing_fields = []
81 | invalid_fields = []
82 | trading_settings = config.get("trading_settings", {})
83 |
84 | if not trading_settings.get("initial_balance"):
85 | missing_fields.append("trading_settings.initial_balance")
86 |
87 | # Validate timeframe
88 | timeframe = trading_settings.get("timeframe")
89 | valid_timeframes = ["1s", "1m", "3m", "5m", "15m", "30m", "1h", "2h", "6h", "12h", "1d", "1w", "1M"]
90 | if timeframe not in valid_timeframes:
91 | self.logger.error(f"Invalid timeframe: {timeframe}. Must be one of {valid_timeframes}.")
92 | invalid_fields.append("trading_settings.timeframe")
93 |
94 | # Validate period
95 | period = trading_settings.get("period", {})
96 | start_date = period.get("start_date")
97 | end_date = period.get("end_date")
98 |
99 | if not start_date:
100 | missing_fields.append("trading_settings.period.start_date")
101 | if not end_date:
102 | missing_fields.append("trading_settings.period.end_date")
103 |
104 | return missing_fields, invalid_fields
105 |
106 | def _validate_grid_strategy(self, config):
107 | missing_fields = []
108 | invalid_fields = []
109 | grid = config.get("grid_strategy", {})
110 |
111 | grid_type = grid.get("type")
112 | if grid_type is None:
113 | missing_fields.append("grid_strategy.type")
114 | else:
115 | try:
116 | StrategyType.from_string(grid_type)
117 |
118 | except ValueError as e:
119 | self.logger.error(str(e))
120 | invalid_fields.append("grid_strategy.type")
121 |
122 | spacing = grid.get("spacing")
123 | if spacing is None:
124 | missing_fields.append("grid_strategy.spacing")
125 | else:
126 | try:
127 | SpacingType.from_string(spacing)
128 |
129 | except ValueError as e:
130 | self.logger.error(str(e))
131 | invalid_fields.append("grid_strategy.spacing")
132 |
133 | num_grids = grid.get("num_grids")
134 | if num_grids is None:
135 | missing_fields.append("grid_strategy.num_grids")
136 | elif not isinstance(num_grids, int) or num_grids <= 0:
137 | self.logger.error("Grid strategy 'num_grids' must be a positive integer.")
138 | invalid_fields.append("grid_strategy.num_grids")
139 |
140 | range_ = grid.get("range", {})
141 | top = range_.get("top")
142 | bottom = range_.get("bottom")
143 | if top is None:
144 | missing_fields.append("grid_strategy.range.top")
145 | if bottom is None:
146 | missing_fields.append("grid_strategy.range.bottom")
147 |
148 | if top is not None and bottom is not None:
149 | if not isinstance(top, int | float) or not isinstance(bottom, int | float):
150 | self.logger.error("'top' and 'bottom' in 'grid_strategy.range' must be numbers.")
151 | invalid_fields.append("grid_strategy.range.top")
152 | invalid_fields.append("grid_strategy.range.bottom")
153 | elif bottom >= top:
154 | self.logger.error("'grid_strategy.range.bottom' must be less than 'grid_strategy.range.top'.")
155 | invalid_fields.append("grid_strategy.range.top")
156 | invalid_fields.append("grid_strategy.range.bottom")
157 |
158 | return missing_fields, invalid_fields
159 |
160 | def _validate_limits(self, config):
161 | invalid_fields = []
162 | limits = config.get("risk_management", {})
163 | take_profit = limits.get("take_profit", {})
164 | stop_loss = limits.get("stop_loss", {})
165 |
166 | # Validate take profit
167 | if not isinstance(take_profit.get("enabled"), bool):
168 | self.logger.error("Take profit enabled flag must be a boolean.")
169 | invalid_fields.append("risk_management.take_profit.enabled")
170 |
171 | if take_profit.get("threshold") is None or not isinstance(take_profit.get("threshold"), float | int):
172 | self.logger.error("Invalid or missing take profit threshold.")
173 | invalid_fields.append("risk_management.take_profit.threshold")
174 |
175 | # Validate stop loss
176 | if not isinstance(stop_loss.get("enabled"), bool):
177 | self.logger.error("Stop loss enabled flag must be a boolean.")
178 | invalid_fields.append("risk_management.stop_loss.enabled")
179 |
180 | if stop_loss.get("threshold") is None or not isinstance(stop_loss.get("threshold"), float | int):
181 | self.logger.error("Invalid or missing stop loss threshold.")
182 | invalid_fields.append("risk_management.stop_loss.threshold")
183 |
184 | return invalid_fields
185 |
186 | def _validate_logging(self, config):
187 | missing_fields = []
188 | invalid_fields = []
189 | logging_settings = config.get("logging", {})
190 |
191 | # Validate log level
192 | log_level = logging_settings.get("log_level")
193 | valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
194 | if log_level is None:
195 | missing_fields.append("logging.log_level")
196 | elif log_level.upper() not in valid_log_levels:
197 | self.logger.error(f"Invalid log level: {log_level}. Must be one of {valid_log_levels}.")
198 | invalid_fields.append("logging.log_level")
199 |
200 | # Validate log to file
201 | if not isinstance(logging_settings.get("log_to_file"), bool):
202 | self.logger.error("log_to_file must be a boolean.")
203 | invalid_fields.append("logging.log_to_file")
204 |
205 | if missing_fields:
206 | self.logger.error(f"Missing logging fields: {missing_fields}")
207 |
208 | return missing_fields, invalid_fields
209 |
--------------------------------------------------------------------------------