├── config ├── __init__.py ├── trading_mode.py ├── config.json ├── exceptions.py ├── config_manager.py └── config_validator.py ├── core ├── __init__.py ├── order_handling │ ├── __init__.py │ ├── execution_strategy │ │ ├── __init__.py │ │ ├── order_execution_strategy_interface.py │ │ ├── order_execution_strategy_factory.py │ │ ├── backtest_order_execution_strategy.py │ │ └── live_order_execution_strategy.py │ ├── fee_calculator.py │ ├── exceptions.py │ ├── order_book.py │ ├── order.py │ └── order_status_tracker.py ├── bot_management │ ├── bot_controller │ │ ├── __init__.py │ │ ├── exceptions.py │ │ └── bot_controller.py │ ├── notification │ │ ├── notification_content.py │ │ └── notification_handler.py │ └── event_bus.py ├── validation │ ├── exceptions.py │ └── order_validator.py ├── services │ ├── exchange_service_factory.py │ ├── exceptions.py │ ├── exchange_interface.py │ └── backtest_exchange_service.py └── grid_management │ └── grid_level.py ├── tests ├── __init__.py ├── conftest.py ├── order_handling │ ├── test_fee_calculator.py │ ├── test_order_execution_strategy_factory.py │ ├── test_backtest_order_execution_strategy.py │ ├── test_order.py │ ├── test_order_book.py │ └── test_live_order_execution_strategy.py ├── utils │ ├── test_config_name_generator.py │ ├── test_performance_results_saver.py │ ├── test_logging_config.py │ └── test_arg_parser.py ├── services │ └── test_exchange_service_factory.py ├── bot_management │ ├── test_event_bus.py │ ├── test_notification_handler.py │ └── test_bot_controller.py ├── validation │ └── test_order_validator.py ├── grid_management │ └── test_grid_level.py ├── strategies │ └── test_plotter.py └── config │ ├── test_config_validator.py │ └── test_config_manager.py ├── utils ├── __init__.py ├── constants.py ├── config_name_generator.py ├── logging_config.py ├── performance_results_saver.py └── arg_parser.py ├── strategies ├── __init__.py ├── spacing_type.py ├── strategy_type.py ├── trading_strategy_interface.py └── plotter.py ├── 
.vscode ├── extensions.json └── settings.json ├── monitoring ├── configs │ ├── grafana │ │ └── provisioning │ │ │ ├── dashboards.yml │ │ │ └── datasources.yml │ ├── loki │ │ ├── rules.yaml │ │ └── loki.yaml │ └── promtail │ │ └── promtail.yaml └── dashboards │ └── grid_trading_bot_dashboard.json ├── .github ├── dependabot.yml ├── workflows │ └── run-tests-on-push-or-merge-pr-master.yml └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE.txt ├── .pre-commit-config.yaml ├── docker-compose.yml ├── pyproject.toml ├── CODE_OF_CONDUCT.md └── main.py /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /strategies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/order_handling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/bot_management/bot_controller/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/order_handling/execution_strategy/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": ["ms-python.python", "charliermarsh.ruff"] 3 | } 4 | -------------------------------------------------------------------------------- /monitoring/configs/grafana/provisioning/dashboards.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | providers: 3 | - name: 'GridTradingBot' 4 | orgId: 1 5 | folder: 'TradingBot Monitoring' 6 | type: file 7 | options: 8 | path: /var/lib/grafana/dashboards 9 | -------------------------------------------------------------------------------- /monitoring/configs/grafana/provisioning/datasources.yml: -------------------------------------------------------------------------------- 1 | apiVersion: 1 2 | 3 | datasources: 4 | - name: Loki 5 | type: loki 6 | access: proxy 7 | url: http://loki:3100 8 | uid: Loki 9 | editable: false 10 | isDefault: true 11 | -------------------------------------------------------------------------------- /core/bot_management/bot_controller/exceptions.py: -------------------------------------------------------------------------------- 1 | class BotControllerError(Exception): 2 | """Base exception class for BotController errors.""" 3 | 4 | 5 | class CommandParsingError(BotControllerError): 6 | """Exception raised when there is an error parsing a command.""" 7 | 8 | 9 | class StrategyControlError(BotControllerError): 10 | """Exception raised when starting, stopping, or restarting the strategy fails.""" 11 | -------------------------------------------------------------------------------- /core/validation/exceptions.py: -------------------------------------------------------------------------------- 1 | class InsufficientBalanceError(Exception): 2 | """Raised 
when balance is insufficient to place a buy or sell order.""" 3 | 4 | pass 5 | 6 | 7 | class InsufficientCryptoBalanceError(Exception): 8 | """Raised when crypto balance is insufficient to complete a sell order.""" 9 | 10 | pass 11 | 12 | 13 | class InvalidOrderQuantityError(Exception): 14 | """Raised when order quantity (amount) is invalid.""" 15 | 16 | pass 17 | -------------------------------------------------------------------------------- /core/order_handling/fee_calculator.py: -------------------------------------------------------------------------------- 1 | from config.config_manager import ConfigManager 2 | 3 | 4 | class FeeCalculator: 5 | def __init__( 6 | self, 7 | config_manager: ConfigManager, 8 | ): 9 | self.config_manager = config_manager 10 | self.trading_fee: float = self.config_manager.get_trading_fee() 11 | 12 | def calculate_fee( 13 | self, 14 | trade_value: float, 15 | ) -> float: 16 | return trade_value * self.trading_fee 17 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | ignore: 8 | # Ignoring these dependencies temporarily for compatibility 9 | - dependency-name: "numpy" 10 | - dependency-name: "pandas" 11 | labels: 12 | - "pip-dependencies" 13 | 14 | - package-ecosystem: "github-actions" 15 | directory: "/" 16 | schedule: 17 | interval: "weekly" 18 | labels: 19 | - "actions" 20 | -------------------------------------------------------------------------------- /core/order_handling/exceptions.py: -------------------------------------------------------------------------------- 1 | from .order import OrderSide, OrderType 2 | 3 | 4 | class OrderExecutionFailedError(Exception): 5 | def __init__( 6 | self, 7 | message: str, 8 | order_side: OrderSide, 9 | order_type: OrderType, 10 | pair: 
str, 11 | quantity: float, 12 | price: float, 13 | ): 14 | super().__init__(message) 15 | self.order_side = order_side 16 | self.order_type = order_type 17 | self.pair = pair 18 | self.quantity = quantity 19 | self.price = price 20 | -------------------------------------------------------------------------------- /config/trading_mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TradingMode(Enum): 5 | BACKTEST = "backtest" 6 | PAPER_TRADING = "paper_trading" 7 | LIVE = "live" 8 | 9 | @staticmethod 10 | def from_string(mode_str: str): 11 | try: 12 | return TradingMode(mode_str) 13 | except ValueError: 14 | available_modes = ", ".join([mode.value for mode in TradingMode]) 15 | raise ValueError( 16 | f"Invalid trading mode: '{mode_str}'. Available modes are: {available_modes}", 17 | ) from None 18 | -------------------------------------------------------------------------------- /strategies/spacing_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class SpacingType(Enum): 5 | ARITHMETIC = "arithmetic" 6 | GEOMETRIC = "geometric" 7 | 8 | @staticmethod 9 | def from_string(spacing_type_str: str): 10 | try: 11 | return SpacingType(spacing_type_str) 12 | except ValueError: 13 | available_spacings = ", ".join([spacing.value for spacing in SpacingType]) 14 | raise ValueError( 15 | f"Invalid spacing type: '{spacing_type_str}'. 
Available spacings are: {available_spacings}", 16 | ) from None 17 | -------------------------------------------------------------------------------- /strategies/strategy_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class StrategyType(Enum): 5 | SIMPLE_GRID = "simple_grid" 6 | HEDGED_GRID = "hedged_grid" 7 | 8 | @staticmethod 9 | def from_string(strategy_type_str: str): 10 | try: 11 | return StrategyType(strategy_type_str) 12 | except ValueError: 13 | available_strategies = ", ".join([strat.value for strat in StrategyType]) 14 | raise ValueError( 15 | f"Invalid strategy type: '{strategy_type_str}'. Available strategies are: {available_strategies}", 16 | ) from None 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *.DS_Store 3 | __pycache__/ 4 | *.pyc 5 | *.py[cod] 6 | 7 | # Distribution / packaging 8 | build/ 9 | dist/ 10 | *.egg-info/ 11 | *.lock 12 | 13 | # Installer logs 14 | pip-log.txt 15 | pip-delete-this-directory.txt 16 | 17 | # Unit test / coverage reports 18 | .tox/ 19 | .coverage 20 | .cache 21 | coverage.xml 22 | 23 | # Ignore profiling output files 24 | *.prof 25 | 26 | # VS Code 27 | .vscode/* 28 | !.vscode/settings.json 29 | !.vscode/extensions.json 30 | 31 | # Secrets 32 | secrets.json 33 | .env.yml 34 | .env 35 | loki-data/ 36 | grafana-data/ 37 | 38 | # `uv` Package Manager 39 | .uv/ 40 | 41 | # Project specific 42 | /data/ 43 | logs/* 44 | -------------------------------------------------------------------------------- /monitoring/configs/loki/rules.yaml: -------------------------------------------------------------------------------- 1 | groups: 2 | - name: grid_trading_alerts 3 | rules: 4 | - alert: HighCPUUsage 5 | expr: | 6 | avg_over_time({job="grid_trading_bot"} | json | unwrap cpu 
[5m]) > 80 7 | for: 2m 8 | labels: 9 | severity: warning 10 | annotations: 11 | summary: High CPU usage detected 12 | - alert: OrderExecutionFailure 13 | expr: | 14 | count_over_time({job="grid_trading_bot"} |= "Failed to execute order" [5m]) > 3 15 | for: 1m 16 | labels: 17 | severity: critical 18 | annotations: 19 | summary: Multiple order execution failures detected 20 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.autoTestDiscoverOnSaveEnabled": true, 3 | "python.testing.unittestEnabled": false, 4 | "python.testing.pytestEnabled": true, 5 | "python.testing.pytestArgs": ["tests"], 6 | "python.defaultInterpreterPath": "./.venv/bin/python", 7 | 8 | "python.analysis.autoImportCompletions": true, 9 | "editor.codeActionsOnSave": { 10 | "source.fixAll": "explicit" 11 | }, 12 | "editor.formatOnSave": true, 13 | "editor.insertSpaces": true, 14 | "editor.tabSize": 2, 15 | "editor.detectIndentation": false, 16 | "editor.rulers": [120], 17 | "[python]": { 18 | "editor.defaultFormatter": "charliermarsh.ruff", 19 | "editor.tabSize": 4 20 | }, 21 | "files.watcherExclude": { 22 | "**/.venv/**": true, 23 | "**/__pycache__/**": true, 24 | "**/.pytest_cache/**": true 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /core/order_handling/execution_strategy/order_execution_strategy_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from ..order import Order, OrderSide 4 | 5 | 6 | class OrderExecutionStrategyInterface(ABC): 7 | @abstractmethod 8 | async def execute_market_order( 9 | self, 10 | order_side: OrderSide, 11 | pair: str, 12 | quantity: float, 13 | price: float, 14 | ) -> Order | None: 15 | pass 16 | 17 | @abstractmethod 18 | async def execute_limit_order( 19 | self, 20 
| order_side: OrderSide, 21 | pair: str, 22 | quantity: float, 23 | price: float, 24 | ) -> Order | None: 25 | pass 26 | 27 | @abstractmethod 28 | async def get_order( 29 | self, 30 | order_id: str, 31 | pair: str, 32 | ) -> Order | None: 33 | pass 34 | -------------------------------------------------------------------------------- /core/services/exchange_service_factory.py: -------------------------------------------------------------------------------- 1 | from config.config_manager import ConfigManager 2 | from config.trading_mode import TradingMode 3 | 4 | from .backtest_exchange_service import BacktestExchangeService 5 | from .live_exchange_service import LiveExchangeService 6 | 7 | 8 | class ExchangeServiceFactory: 9 | @staticmethod 10 | def create_exchange_service( 11 | config_manager: ConfigManager, 12 | trading_mode: TradingMode, 13 | ): 14 | if trading_mode == TradingMode.BACKTEST: 15 | return BacktestExchangeService(config_manager) 16 | elif trading_mode == TradingMode.PAPER_TRADING: 17 | return LiveExchangeService(config_manager, is_paper_trading_activated=True) 18 | elif trading_mode == TradingMode.LIVE: 19 | return LiveExchangeService(config_manager, is_paper_trading_activated=False) 20 | else: 21 | raise ValueError(f"Unsupported trading mode: {trading_mode}") 22 | -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "exchange": { 3 | "name": "binance", 4 | "trading_fee": 0.001, 5 | "trading_mode": "backtest" 6 | }, 7 | "pair": { 8 | "base_currency": "SOL", 9 | "quote_currency": "USDT" 10 | }, 11 | "trading_settings": { 12 | "timeframe": "1m", 13 | "period": { 14 | "start_date": "2024-12-10T09:00:00Z", 15 | "end_date": "2024-12-15T23:00:00Z" 16 | }, 17 | "initial_balance": 150 18 | }, 19 | "grid_strategy": { 20 | "type": "hedged_grid", 21 | "spacing": "geometric", 22 | "num_grids": 8, 23 | "range": { 24 | "top": 
240, 25 | "bottom": 210 26 | } 27 | }, 28 | "risk_management": { 29 | "take_profit": { 30 | "enabled": false, 31 | "threshold": 3700 32 | }, 33 | "stop_loss": { 34 | "enabled": false, 35 | "threshold": 2830 36 | } 37 | }, 38 | "logging": { 39 | "log_level": "INFO", 40 | "log_to_file": true 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /core/services/exceptions.py: -------------------------------------------------------------------------------- 1 | class UnsupportedExchangeError(Exception): 2 | """Raised when the exchange is not supported.""" 3 | 4 | pass 5 | 6 | 7 | class DataFetchError(Exception): 8 | """Raised when data fetching fails after retries.""" 9 | 10 | pass 11 | 12 | 13 | class HistoricalMarketDataFileNotFoundError(Exception): 14 | """Raised when historical market data has not been found in repository.""" 15 | 16 | pass 17 | 18 | 19 | class UnsupportedTimeframeError(Exception): 20 | """Raised when a timeframe is not supported by a given exchange.""" 21 | 22 | pass 23 | 24 | 25 | class UnsupportedPairError(Exception): 26 | """Raised when a crypto pair is not supported by a given exchange.""" 27 | 28 | pass 29 | 30 | 31 | class OrderCancellationError(Exception): 32 | """Raised when order cancellation fails.""" 33 | 34 | pass 35 | 36 | 37 | class MissingEnvironmentVariableError(Exception): 38 | """Raised when env variable are missing (EXCHANGE_API_KEY and/or EXCHANGE_SECRET_KEY).""" 39 | 40 | pass 41 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jordan TETE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/order_handling/execution_strategy/order_execution_strategy_factory.py: -------------------------------------------------------------------------------- 1 | from config.config_manager import ConfigManager 2 | from config.trading_mode import TradingMode 3 | from core.services.exchange_interface import ExchangeInterface 4 | 5 | from .backtest_order_execution_strategy import BacktestOrderExecutionStrategy 6 | from .live_order_execution_strategy import LiveOrderExecutionStrategy 7 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface 8 | 9 | 10 | class OrderExecutionStrategyFactory: 11 | @staticmethod 12 | def create( 13 | config_manager: ConfigManager, 14 | exchange_service: ExchangeInterface, 15 | ) -> OrderExecutionStrategyInterface: 16 | trading_mode = config_manager.get_trading_mode() 17 | 18 | if trading_mode == TradingMode.LIVE or trading_mode == TradingMode.PAPER_TRADING: 19 | return LiveOrderExecutionStrategy(exchange_service=exchange_service) 20 | elif trading_mode == TradingMode.BACKTEST: 21 | return BacktestOrderExecutionStrategy() 
22 | else: 23 | raise ValueError(f"Unknown trading mode: {trading_mode}") 24 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | CANDLE_LIMITS = { 2 | "binance": 1000, 3 | "coinbase": 300, 4 | "kraken": 720, 5 | "bitfinex": 5000, 6 | "bitstamp": 1000, 7 | "huobi": 2000, 8 | "okex": 1440, 9 | "bybit": 200, 10 | "bittrex": 500, 11 | "poloniex": 500, 12 | "gateio": 1000, 13 | "kucoin": 1500, 14 | } 15 | 16 | TIMEFRAME_MAPPINGS = { 17 | "1s": 1 * 1000, # 1 second 18 | "1m": 60 * 1000, # 1 minute 19 | "3m": 3 * 60 * 1000, # 3 minutes 20 | "5m": 5 * 60 * 1000, # 5 minutes 21 | "15m": 15 * 60 * 1000, # 15 minutes 22 | "30m": 30 * 60 * 1000, # 30 minutes 23 | "1h": 60 * 60 * 1000, # 1 hour 24 | "2h": 2 * 60 * 60 * 1000, # 2 hours 25 | "6h": 6 * 60 * 60 * 1000, # 6 hours 26 | "12h": 12 * 60 * 60 * 1000, # 12 hours 27 | "1d": 24 * 60 * 60 * 1000, # 1 day 28 | "3d": 3 * 24 * 60 * 60 * 1000, # 3 days 29 | "1w": 7 * 24 * 60 * 60 * 1000, # 1 week 30 | "1M": 30 * 24 * 60 * 60 * 1000, # 1 month (approximated as 30 days) 31 | } 32 | 33 | RESSOURCE_THRESHOLDS = { 34 | "cpu": 90, 35 | "bot_cpu": 80, 36 | "memory": 80, 37 | "bot_memory": 70, 38 | "disk": 90, 39 | } 40 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-on-push-or-merge-pr-master.yml: -------------------------------------------------------------------------------- 1 | name: on_push_or_merge_pr_master 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | test: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout code 17 | uses: actions/checkout@v6 18 | 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v7 21 | 22 | - name: "Set up Python" 23 | uses: actions/setup-python@v6 24 | with: 25 | python-version-file: "pyproject.toml" 26 | 27 | - name: 
Install the project 28 | run: uv sync --all-extras --dev 29 | 30 | - name: Set PYTHONPATH 31 | run: echo "PYTHONPATH=$(pwd)" >> $GITHUB_ENV 32 | 33 | - name: Run tests and upload coverage 34 | run: uv run pytest --cov=core --cov=config --cov=strategies --cov=utils --cov-report=xml:coverage.xml --cov-report=term 35 | continue-on-error: true 36 | 37 | - name: Upload coverage reports to Codecov 38 | uses: codecov/codecov-action@v5 39 | with: 40 | fail_ci_if_error: true 41 | files: ./coverage.xml 42 | token: ${{ secrets.CODECOV_TOKEN }} 43 | verbose: true 44 | -------------------------------------------------------------------------------- /utils/config_name_generator.py: -------------------------------------------------------------------------------- 1 | from datetime import UTC, datetime 2 | 3 | from config.config_manager import ConfigManager 4 | 5 | 6 | def generate_config_name(config_manager: ConfigManager) -> str: 7 | """ 8 | Generates a unique and descriptive name for the bot's configuration. 9 | 10 | Args: 11 | config_manager (ConfigManager): Config manager instance to retrieve key parameters. 12 | 13 | Returns: 14 | str: A descriptive configuration name including trading pair, mode, strategy, 15 | grid spacing, grid range, and timestamp. 
16 | """ 17 | trading_pair = f"{config_manager.get_base_currency()}_{config_manager.get_quote_currency()}" 18 | trading_mode = config_manager.get_trading_mode().name 19 | grid_strategy_type = config_manager.get_strategy_type().name 20 | grid_spacing_type = config_manager.get_spacing_type().name 21 | grid_size = config_manager.get_num_grids() 22 | grid_top = config_manager.get_top_range() 23 | grid_bottom = config_manager.get_bottom_range() 24 | start_time = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M") 25 | 26 | return ( 27 | f"bot_{trading_pair}_{trading_mode}_strategy{grid_strategy_type}_" 28 | f"spacing{grid_spacing_type}_size{grid_size}_range{grid_bottom}-{grid_top}_{start_time}" 29 | ) 30 | -------------------------------------------------------------------------------- /monitoring/configs/loki/loki.yaml: -------------------------------------------------------------------------------- 1 | auth_enabled: false 2 | 3 | server: 4 | http_listen_port: 3100 5 | grpc_listen_port: 9096 6 | log_level: debug 7 | 8 | common: 9 | instance_addr: 127.0.0.1 10 | path_prefix: /tmp/loki 11 | storage: 12 | filesystem: 13 | chunks_directory: /tmp/loki/chunks 14 | rules_directory: /tmp/loki/rules 15 | replication_factor: 1 16 | ring: 17 | kvstore: 18 | store: inmemory 19 | 20 | frontend: 21 | max_outstanding_per_tenant: 2048 22 | 23 | limits_config: 24 | max_global_streams_per_user: 1000 25 | ingestion_rate_mb: 500 26 | ingestion_burst_size_mb: 500 27 | volume_enabled: true 28 | reject_old_samples: false 29 | reject_old_samples_max_age: 2160h # Accept logs up to 90 days old (adjust as needed) 30 | 31 | query_range: 32 | results_cache: 33 | cache: 34 | embedded_cache: 35 | enabled: true 36 | max_size_mb: 100 37 | 38 | ruler: 39 | rule_path: /etc/loki/rules 40 | storage: 41 | type: local 42 | local: 43 | directory: /etc/loki/rules 44 | enable_api: true 45 | 46 | schema_config: 47 | configs: 48 | - from: 2020-10-24 49 | store: tsdb 50 | object_store: filesystem 51 | schema: v13 52 | 
index: 53 | prefix: index_ 54 | period: 24h 55 | 56 | analytics: 57 | reporting_enabled: false 58 | -------------------------------------------------------------------------------- /config/exceptions.py: -------------------------------------------------------------------------------- 1 | class ConfigError(Exception): 2 | """Base class for all configuration-related errors.""" 3 | 4 | pass 5 | 6 | 7 | class ConfigFileNotFoundError(ConfigError): 8 | def __init__(self, config_file, message="Configuration file not found"): 9 | self.config_file = config_file 10 | self.message = f"{message}: {config_file}" 11 | super().__init__(self.message) 12 | 13 | 14 | class ConfigValidationError(ConfigError): 15 | def __init__(self, missing_fields=None, invalid_fields=None, message="Configuration validation error"): 16 | self.missing_fields = missing_fields or [] 17 | self.invalid_fields = invalid_fields or [] 18 | details = [] 19 | if self.missing_fields: 20 | details.append(f"Missing required fields - {', '.join(self.missing_fields)}") 21 | if self.invalid_fields: 22 | details.append(f"Invalid fields - {', '.join(self.invalid_fields)}") 23 | self.message = f"{message}: {', '.join(details)}" 24 | super().__init__(self.message) 25 | 26 | 27 | class ConfigParseError(ConfigError): 28 | def __init__(self, config_file, original_exception, message="Error parsing configuration file"): 29 | self.config_file = config_file 30 | self.original_exception = original_exception 31 | self.message = f"{message} ({config_file}): {original_exception!s}" 32 | super().__init__(self.message) 33 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def valid_config(): 6 | """Fixture providing a valid configuration for testing.""" 7 | return { 8 | "exchange": { 9 | "name": "binance", 10 | "trading_fee": 0.001, 11 | "trading_mode": 
"backtest", 12 | }, 13 | "pair": { 14 | "base_currency": "ETH", 15 | "quote_currency": "USDT", 16 | }, 17 | "trading_settings": { 18 | "initial_balance": 10000, 19 | "timeframe": "1m", 20 | "period": { 21 | "start_date": "2024-07-04T00:00:00Z", 22 | "end_date": "2024-07-11T00:00:00Z", 23 | }, 24 | "historical_data_file": "data/SOL_USDT/2024/1m.csv", 25 | }, 26 | "grid_strategy": { 27 | "type": "simple_grid", 28 | "spacing": "geometric", 29 | "num_grids": 20, 30 | "range": { 31 | "top": 3100, 32 | "bottom": 2850, 33 | }, 34 | }, 35 | "risk_management": { 36 | "take_profit": { 37 | "enabled": False, 38 | "threshold": 3700, 39 | }, 40 | "stop_loss": { 41 | "enabled": False, 42 | "threshold": 2830, 43 | }, 44 | }, 45 | "logging": { 46 | "log_level": "INFO", 47 | "log_to_file": True, 48 | }, 49 | } 50 | -------------------------------------------------------------------------------- /core/bot_management/notification/notification_content.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | 4 | 5 | @dataclass 6 | class NotificationContent: 7 | title: str 8 | message: str 9 | 10 | 11 | class NotificationType(Enum): 12 | ORDER_PLACED = NotificationContent( 13 | title="Order Placement Successful", 14 | message=("Order placed successfully:\n{order_details}"), 15 | ) 16 | ORDER_FILLED = NotificationContent( 17 | title="Order Filled", 18 | message=("Order has been filled successfully:\n{order_details}"), 19 | ) 20 | ORDER_FAILED = NotificationContent( 21 | title="Order Placement Failed", 22 | message=("Failed to place order:\n{error_details}"), 23 | ) 24 | ORDER_CANCELLED = NotificationContent( 25 | title="Order Cancellation", 26 | message=("Order has been cancelled:\n{order_details}"), 27 | ) 28 | ERROR_OCCURRED = NotificationContent( 29 | title="Error Occurred", 30 | message="An error occurred in the trading bot:\n{error_details}", 31 | ) 32 | TAKE_PROFIT_TRIGGERED = 
NotificationContent( 33 | title="Take Profit Triggered", 34 | message="Take profit triggered with order details:\n{order_details}", 35 | ) 36 | STOP_LOSS_TRIGGERED = NotificationContent( 37 | title="Stop Loss Triggered", 38 | message="Stop loss triggered with order details:\n{order_details}", 39 | ) 40 | HEALTH_CHECK_ALERT = NotificationContent( 41 | title="Health Check Alert", 42 | message=("{alert_details}"), 43 | ) 44 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Ruff for linting and formatting 3 | - repo: https://github.com/astral-sh/ruff-pre-commit 4 | rev: v0.8.4 5 | hooks: 6 | - id: ruff 7 | name: ruff (linting) 8 | args: [--fix] 9 | - id: ruff-format 10 | name: ruff (formatting) 11 | 12 | # Built-in hooks for basic file checks 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v5.0.0 15 | hooks: 16 | - id: trailing-whitespace 17 | name: trim trailing whitespace 18 | - id: end-of-file-fixer 19 | name: fix end of files 20 | - id: check-yaml 21 | name: check yaml 22 | - id: check-json 23 | name: check json 24 | - id: check-toml 25 | name: check toml 26 | - id: check-merge-conflict 27 | name: check for merge conflicts 28 | - id: check-added-large-files 29 | name: check for added large files 30 | args: ["--maxkb=1000"] 31 | - id: debug-statements 32 | name: debug statements (Python) 33 | 34 | # Python-specific checks 35 | - repo: https://github.com/pre-commit/pygrep-hooks 36 | rev: v1.10.0 37 | hooks: 38 | - id: python-check-blanket-noqa 39 | name: check blanket noqa 40 | - id: python-check-blanket-type-ignore 41 | name: check blanket type ignore 42 | - id: python-no-log-warn 43 | name: check for deprecated log warn 44 | - id: python-use-type-annotations 45 | name: check use type annotations 46 | - id: text-unicode-replacement-char 47 | name: check for unicode replacement chars 48 | 
-------------------------------------------------------------------------------- /core/grid_management/grid_level.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from ..order_handling.order import Order 4 | 5 | 6 | class GridCycleState(Enum): 7 | READY_TO_BUY_OR_SELL = "ready_to_buy_or_sell" # Level is ready for both a buy or a sell order 8 | READY_TO_BUY = "ready_to_buy" # Level is ready for a buy order 9 | WAITING_FOR_BUY_FILL = "waiting_for_buy_fill" # Buy order placed, waiting for execution 10 | READY_TO_SELL = "ready_to_sell" # Level is ready for a sell order 11 | WAITING_FOR_SELL_FILL = "waiting_for_sell_fill" # Sell order placed, waiting for execution 12 | 13 | 14 | class GridLevel: 15 | def __init__(self, price: float, state: GridCycleState): 16 | self.price: float = price 17 | self.orders: list[Order] = [] # Track all orders at this level 18 | self.state: GridCycleState = state 19 | self.paired_buy_level: GridLevel | None = None 20 | self.paired_sell_level: GridLevel | None = None 21 | 22 | def add_order(self, order: Order) -> None: 23 | """ 24 | Record an order at this level. 
25 | """ 26 | self.orders.append(order) 27 | 28 | def __str__(self) -> str: 29 | return ( 30 | f"GridLevel(price={self.price}, " 31 | f"state={self.state.name}, " 32 | f"num_orders={len(self.orders)}, " 33 | f"paired_buy_level={self.paired_buy_level.price if self.paired_buy_level else None}), " 34 | f"paired_sell_level={self.paired_sell_level.price if self.paired_sell_level else None})" 35 | ) 36 | 37 | def __repr__(self) -> str: 38 | return self.__str__() 39 | -------------------------------------------------------------------------------- /monitoring/configs/promtail/promtail.yaml: -------------------------------------------------------------------------------- 1 | server: 2 | http_listen_port: 9080 3 | grpc_listen_port: 0 4 | log_level: debug 5 | 6 | positions: 7 | filename: /tmp/positions.yaml 8 | 9 | clients: 10 | - url: "http://loki:3100/loki/api/v1/push" 11 | batchsize: 1 12 | batchwait: 1s 13 | 14 | scrape_configs: 15 | - job_name: bot_logs 16 | static_configs: 17 | - targets: 18 | - localhost 19 | labels: 20 | job: grid_trading_bot 21 | __path__: /logs/**/*.log 22 | 23 | relabel_configs: 24 | - source_labels: [__path__] 25 | regex: '.*/bot_(?P[A-Z]+)_(?P[A-Z]+)_(?P[A-Z]+)_strategy(?P[A-Z_]+)_spacing(?P[A-Z]+)_size(?P\d+)_range(?P\d+-\d+)_.*\.log$' 26 | target_label: filename 27 | 28 | - source_labels: [base] 29 | target_label: base_currency 30 | - source_labels: [quote] 31 | target_label: quote_currency 32 | - source_labels: [mode] 33 | target_label: trading_mode 34 | - source_labels: [strategy] 35 | target_label: strategy_type 36 | - source_labels: [spacing] 37 | target_label: spacing_type 38 | - source_labels: [size] 39 | target_label: grid_size 40 | - source_labels: [range] 41 | target_label: grid_range 42 | 43 | pipeline_stages: 44 | - regex: 45 | expression: '^(?P\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - (?P[A-Za-z0-9_]+) - (?P[A-Z]+) - (?P.*)$' 46 | 47 | - labels: 48 | service: '{{ .service }}' 49 | level: '{{ .level }}' 50 | 51 | # Convert 
timestamp to Loki format 52 | - timestamp: 53 | source: timestamp 54 | format: "2006-01-02 15:04:05,000" 55 | -------------------------------------------------------------------------------- /strategies/trading_strategy_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import logging 3 | 4 | 5 | class TradingStrategyInterface(ABC): 6 | """ 7 | Abstract base class for all trading strategies. 8 | Requires implementation of key methods for any concrete strategy. 9 | """ 10 | 11 | def __init__(self, config_manager, balance_tracker): 12 | """ 13 | Initializes the strategy with the given configuration manager and balance tracker. 14 | 15 | Args: 16 | config_manager: Provides access to the trading configuration (e.g., exchange, fees). 17 | balance_tracker: Tracks the balance and crypto balance for the strategy. 18 | """ 19 | self.logger = logging.getLogger(self.__class__.__name__) 20 | self.config_manager = config_manager 21 | self.balance_tracker = balance_tracker 22 | 23 | @abstractmethod 24 | def initialize_strategy(self): 25 | """ 26 | Method to initialize the strategy with specific settings (grids, limits, etc.). 27 | Must be implemented by any subclass. 28 | """ 29 | pass 30 | 31 | @abstractmethod 32 | async def run(self): 33 | """ 34 | Run the strategy with historical or live data. 35 | Must be implemented by any subclass. 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def plot_results(self): 41 | """ 42 | Plots the strategy performance after simulation. 43 | Must be implemented by any subclass. 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def generate_performance_report(self) -> tuple[dict, list]: 49 | """ 50 | Generates a report summarizing the strategy's performance (ROI, max drawdown, etc.). 51 | Must be implemented by any subclass. 
52 | """ 53 | pass 54 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Pull Request Template 2 | 3 | ## 📋 Description 4 | 5 | > **Provide a brief description of the changes introduced in this pull request.** 6 | > Explain the problem you are solving or the feature you are adding. 7 | 8 | Fixes # (issue number, if applicable) 9 | 10 | --- 11 | 12 | ## 🔍 Type of Change 13 | 14 | > **Select the type of change your PR introduces:** 15 | 16 | - [ ] 🐞 Bug fix (non-breaking change that fixes an issue) 17 | - [ ] ✨ New feature (non-breaking change that adds functionality) 18 | - [ ] 💥 Breaking change (fix or feature that would cause existing functionality to change) 19 | - [ ] 📝 Documentation update 20 | - [ ] 🔧 Code refactor or optimization 21 | - [ ] ✅ Test improvement or addition 22 | - [ ] Other (please specify): 23 | 24 | --- 25 | 26 | ## ✅ Checklist 27 | 28 | > **Ensure your pull request meets the following requirements:** 29 | 30 | - [ ] I have tested my changes locally and ensured they work as expected. 31 | - [ ] I have added necessary documentation (if applicable). 32 | - [ ] I have added tests that prove my fix/feature works as intended. 33 | - [ ] My code follows the project's code style and guidelines. 34 | - [ ] I have updated the README (if applicable). 35 | 36 | --- 37 | 38 | ## 🚦 How Has This Been Tested? 39 | 40 | > **Explain how you tested your changes.** 41 | > Include steps, test cases, or a description of your testing environment. 
42 | 43 | --- 44 | 45 | ## 🔗 Related Issues or Discussions 46 | 47 | > **Mention any related issues or discussions (e.g., `Fixes #123`, `Closes #456`, or links to discussions).** 48 | 49 | --- 50 | 51 | ## 📷 Screenshots (if applicable) 52 | 53 | > **Include screenshots, logs, or recordings to illustrate your changes.** 54 | 55 | --- 56 | 57 | ## 🗒️ Additional Notes 58 | 59 | > **Add any additional context, notes, or considerations here.** 60 | -------------------------------------------------------------------------------- /tests/order_handling/test_fee_calculator.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | 5 | from core.order_handling.fee_calculator import FeeCalculator 6 | 7 | 8 | class TestFeeCalculator: 9 | @pytest.fixture 10 | def config_manager(self): 11 | mock_config = Mock() 12 | mock_config.get_trading_fee.return_value = 0.001 # 0.1% trading fee 13 | return mock_config 14 | 15 | @pytest.fixture 16 | def fee_calculator(self, config_manager): 17 | return FeeCalculator(config_manager) 18 | 19 | def test_calculate_fee_basic(self, fee_calculator): 20 | trade_value = 1000 21 | expected_fee = 1 # 0.1% of 1000 22 | assert fee_calculator.calculate_fee(trade_value) == pytest.approx(expected_fee) 23 | 24 | def test_calculate_fee_zero(self, fee_calculator): 25 | trade_value = 0 26 | expected_fee = 0 27 | assert fee_calculator.calculate_fee(trade_value) == expected_fee 28 | 29 | def test_calculate_fee_small_value(self, fee_calculator): 30 | trade_value = 0.01 # 1 cent trade 31 | expected_fee = 0.00001 # 0.1% of 0.01 32 | assert fee_calculator.calculate_fee(trade_value) == pytest.approx(expected_fee, rel=1e-5) 33 | 34 | def test_calculate_fee_large_value(self, fee_calculator): 35 | trade_value = 1_000_000 # 1 million trade 36 | expected_fee = 1000 # 0.1% of 1 million 37 | assert fee_calculator.calculate_fee(trade_value) == pytest.approx(expected_fee) 38 | 39 | def 
test_trading_fee_from_config(self, config_manager, fee_calculator): 40 | assert fee_calculator.trading_fee == config_manager.get_trading_fee() 41 | 42 | def test_calculate_fee_tiny_trade_value_case(self, fee_calculator): 43 | trade_value = 0.0000001 44 | expected_fee = trade_value * 0.001 45 | assert fee_calculator.calculate_fee(trade_value) == pytest.approx(expected_fee, rel=1e-9) 46 | -------------------------------------------------------------------------------- /core/services/exchange_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | import pandas as pd 5 | 6 | 7 | class ExchangeInterface(ABC): 8 | @abstractmethod 9 | async def get_balance(self) -> dict[str, Any]: 10 | """Fetches the account balance, returning a dictionary with fiat and crypto balances.""" 11 | pass 12 | 13 | @abstractmethod 14 | async def place_order( 15 | self, 16 | pair: str, 17 | order_side: str, 18 | order_type: str, 19 | amount: float, 20 | price: float | None = None, 21 | ) -> dict[str, str | float]: 22 | """Places an order, returning a dictionary with order details including id and status.""" 23 | pass 24 | 25 | @abstractmethod 26 | def fetch_ohlcv( 27 | self, 28 | pair: str, 29 | timeframe: str, 30 | start_date: str, 31 | end_date: str, 32 | ) -> pd.DataFrame: 33 | """ 34 | Fetches historical OHLCV data as a pandas DataFrame containing open, high, low, 35 | close, and volume for the specified time period.
36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | async def get_current_price( 41 | self, 42 | pair: str, 43 | ) -> float: 44 | """Fetches the current market price for the specified trading pair.""" 45 | pass 46 | 47 | @abstractmethod 48 | async def cancel_order( 49 | self, 50 | order_id: str, 51 | pair: str, 52 | ) -> dict[str, str | float]: 53 | """Attempts to cancel an order by ID, returning the result of the cancellation.""" 54 | pass 55 | 56 | @abstractmethod 57 | async def get_exchange_status(self) -> dict: 58 | """Fetches current exchange status.""" 59 | pass 60 | 61 | @abstractmethod 62 | async def close_connection(self) -> None: 63 | """Close current exchange connection.""" 64 | pass 65 | -------------------------------------------------------------------------------- /utils/logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.handlers import RotatingFileHandler 3 | import os 4 | 5 | 6 | def setup_logging( 7 | log_level: int, 8 | log_to_file: bool = False, 9 | config_name: str | None = None, 10 | max_file_size: int = 5_000_000, # 5MB default max file size for rotation 11 | backup_count: int = 5, # Default number of backup files 12 | ) -> None: 13 | """ 14 | Sets up logging with options for console, rotating file logging, and log differentiation. 15 | 16 | Args: 17 | log_level (int): The logging level (e.g., logging.INFO, logging.DEBUG). 18 | log_to_file (bool): Whether to log to a file. 19 | config_name (Optional[str]): Name of the bot configuration to differentiate logs. 20 | max_file_size (int): Maximum size of log file in bytes before rotation. 21 | backup_count (int): Number of backup log files to keep. 
22 | """ 23 | handlers = [] 24 | 25 | console_handler = logging.StreamHandler() 26 | console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) 27 | handlers.append(console_handler) 28 | log_file_path = "" 29 | 30 | if log_to_file: 31 | log_dir = "logs" 32 | os.makedirs(log_dir, exist_ok=True) 33 | 34 | if config_name: 35 | log_file_path = os.path.join(log_dir, f"{config_name}.log") 36 | else: 37 | log_file_path = os.path.join(log_dir, "grid_trading_bot.log") 38 | 39 | file_handler = RotatingFileHandler(log_file_path, maxBytes=max_file_size, backupCount=backup_count) 40 | file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) 41 | handlers.append(file_handler) 42 | 43 | logging.basicConfig(level=log_level, handlers=handlers) 44 | logging.info(f"Logging initialized. Log level: {logging.getLevelName(log_level)}") 45 | 46 | if log_to_file: 47 | logging.info(f"File logging enabled. Logs are stored in: {log_file_path}") 48 | -------------------------------------------------------------------------------- /tests/utils/test_config_name_generator.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | 5 | from utils.config_name_generator import generate_config_name 6 | 7 | 8 | @pytest.fixture 9 | def mock_config_manager(): 10 | mock_manager = Mock() 11 | mock_manager.get_base_currency.return_value = "BTC" 12 | mock_manager.get_quote_currency.return_value = "USD" 13 | mock_manager.get_trading_mode.return_value.name = "LIVE" 14 | mock_manager.get_strategy_type.return_value.name = "GRID" 15 | mock_manager.get_spacing_type.return_value.name = "PERCENTAGE" 16 | mock_manager.get_num_grids.return_value = 10 17 | mock_manager.get_top_range.return_value = 50000 18 | mock_manager.get_bottom_range.return_value = 30000 19 | return mock_manager 20 | 21 | 22 | 
@patch("utils.config_name_generator.datetime") 23 | def test_generate_config_name(mock_datetime, mock_config_manager): 24 | mock_datetime.now.return_value.strftime.return_value = "20241220_1200" 25 | 26 | result = generate_config_name(mock_config_manager) 27 | 28 | expected_name = "bot_BTC_USD_LIVE_strategyGRID_spacingPERCENTAGE_size10_range30000-50000_20241220_1200" 29 | assert result == expected_name 30 | mock_config_manager.get_base_currency.assert_called_once() 31 | mock_config_manager.get_quote_currency.assert_called_once() 32 | mock_config_manager.get_trading_mode.assert_called_once() 33 | mock_config_manager.get_strategy_type.assert_called_once() 34 | mock_config_manager.get_spacing_type.assert_called_once() 35 | mock_config_manager.get_num_grids.assert_called_once() 36 | mock_config_manager.get_top_range.assert_called_once() 37 | mock_config_manager.get_bottom_range.assert_called_once() 38 | 39 | 40 | def test_generate_config_name_edge_cases(mock_config_manager): 41 | mock_config_manager.get_base_currency.return_value = "ETH" 42 | mock_config_manager.get_quote_currency.return_value = "EUR" 43 | mock_config_manager.get_num_grids.return_value = 0 44 | mock_config_manager.get_top_range.return_value = 0 45 | mock_config_manager.get_bottom_range.return_value = 0 46 | 47 | result = generate_config_name(mock_config_manager) 48 | 49 | assert "bot_ETH_EUR" in result 50 | assert "_size0_range0-0" in result 51 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | loki: 3 | image: grafana/loki:3.0.0 4 | container_name: loki 5 | restart: unless-stopped 6 | user: root 7 | ports: 8 | - "3100:3100" 9 | command: -config.file=/etc/loki/loki.yaml 10 | volumes: 11 | - ./monitoring/configs/loki/loki.yaml:/etc/loki/loki.yaml 12 | - ./monitoring/configs/loki/rules.yaml:/etc/loki/rules/fake/loki-rules.yml 13 | - 
loki-data:/tmp/loki/chunks 14 | cpus: 0.5 15 | mem_limit: 512m 16 | networks: 17 | - default 18 | 19 | grafana: 20 | image: grafana/grafana:11.0.0 21 | container_name: grafana 22 | restart: unless-stopped 23 | ports: 24 | - "3000:3000" 25 | environment: 26 | - GF_AUTH_ANONYMOUS_ENABLED=false 27 | - GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER} 28 | - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD} 29 | - GF_USERS_DEFAULT_THEME=dark 30 | - GF_LOG_MODE=console 31 | - GF_LOG_LEVEL=debug 32 | - GF_PANELS_ENABLE_ALPHA=false 33 | - GF_FEATURE_TOGGLES_ENABLE=lokiLogsDataplane 34 | - GF_INSTALL_PLUGINS=grafana-polystat-panel 35 | volumes: 36 | - ./monitoring/configs/grafana/provisioning/datasources.yml:/etc/grafana/provisioning/datasources/provisioning-datasources.yaml:ro 37 | - ./monitoring/configs/grafana/provisioning/dashboards.yml:/etc/grafana/provisioning/dashboards/dashboards.yml:ro 38 | - ./monitoring/dashboards/grid_trading_bot_dashboard.json:/var/lib/grafana/dashboards/grid_trading_bot_dashboard.json:ro 39 | - grafana-data:/var/lib/grafana 40 | depends_on: 41 | - loki 42 | cpus: 0.5 43 | mem_limit: 512m 44 | networks: 45 | - default 46 | 47 | promtail: 48 | image: grafana/promtail:3.0.0 49 | container_name: promtail 50 | restart: unless-stopped 51 | volumes: 52 | - ./monitoring/configs/promtail/promtail.yaml:/etc/promtail/promtail.yaml 53 | - ./logs:/logs:ro 54 | - promtail-positions:/tmp 55 | command: -config.file=/etc/promtail/promtail.yaml -config.expand-env=true 56 | depends_on: 57 | - loki 58 | cpus: 0.5 59 | mem_limit: 512m 60 | networks: 61 | - default 62 | 63 | volumes: 64 | grafana-data: 65 | driver: local 66 | loki-data: 67 | driver: local 68 | promtail-positions: 69 | driver: local 70 | -------------------------------------------------------------------------------- /core/order_handling/order_book.py: -------------------------------------------------------------------------------- 1 | from ..grid_management.grid_level import GridLevel 2 | from 
.order import Order, OrderSide, OrderStatus 3 | 4 | 5 | class OrderBook: 6 | def __init__(self): 7 | self.buy_orders: list[Order] = [] 8 | self.sell_orders: list[Order] = [] 9 | self.non_grid_orders: list[Order] = [] # Orders that are not linked to any grid level 10 | self.order_to_grid_map: dict[Order, GridLevel] = {} # Mapping of Order -> GridLevel 11 | 12 | def add_order( 13 | self, 14 | order: Order, 15 | grid_level: GridLevel | None = None, 16 | ) -> None: 17 | if order.side == OrderSide.BUY: 18 | self.buy_orders.append(order) 19 | else: 20 | self.sell_orders.append(order) 21 | 22 | if grid_level: 23 | self.order_to_grid_map[order] = grid_level # Store the grid level associated with this order 24 | else: 25 | self.non_grid_orders.append(order) # This is a non-grid order like take profit or stop loss 26 | 27 | def get_buy_orders_with_grid(self) -> list[tuple[Order, GridLevel | None]]: 28 | return [(order, self.order_to_grid_map.get(order, None)) for order in self.buy_orders] 29 | 30 | def get_sell_orders_with_grid(self) -> list[tuple[Order, GridLevel | None]]: 31 | return [(order, self.order_to_grid_map.get(order, None)) for order in self.sell_orders] 32 | 33 | def get_all_buy_orders(self) -> list[Order]: 34 | return self.buy_orders 35 | 36 | def get_all_sell_orders(self) -> list[Order]: 37 | return self.sell_orders 38 | 39 | def get_open_orders(self) -> list[Order]: 40 | return [order for order in self.buy_orders + self.sell_orders if order.is_open()] 41 | 42 | def get_completed_orders(self) -> list[Order]: 43 | return [order for order in self.buy_orders + self.sell_orders if order.is_filled()] 44 | 45 | def get_grid_level_for_order(self, order: Order) -> GridLevel | None: 46 | return self.order_to_grid_map.get(order) 47 | 48 | def update_order_status( 49 | self, 50 | order_id: str, 51 | new_status: OrderStatus, 52 | ) -> None: 53 | for order in self.buy_orders + self.sell_orders: 54 | if order.identifier == order_id: 55 | order.status = new_status 56 | break 
57 | -------------------------------------------------------------------------------- /core/order_handling/execution_strategy/backtest_order_execution_strategy.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from ..order import Order, OrderSide, OrderStatus, OrderType 4 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface 5 | 6 | 7 | class BacktestOrderExecutionStrategy(OrderExecutionStrategyInterface): 8 | async def execute_market_order( 9 | self, 10 | order_side: OrderSide, 11 | pair: str, 12 | quantity: float, 13 | price: float, 14 | ) -> Order | None: 15 | order_id = f"backtest-{int(time.time())}" 16 | timestamp = int(time.time() * 1000) 17 | return Order( 18 | identifier=order_id, 19 | status=OrderStatus.OPEN, 20 | order_type=OrderType.MARKET, 21 | side=order_side, 22 | price=price, 23 | average=price, 24 | amount=quantity, 25 | filled=quantity, 26 | remaining=0, 27 | timestamp=timestamp, 28 | datetime="111", 29 | last_trade_timestamp=1, 30 | symbol=pair, 31 | time_in_force="GTC", 32 | ) 33 | 34 | async def execute_limit_order( 35 | self, 36 | order_side: OrderSide, 37 | pair: str, 38 | quantity: float, 39 | price: float, 40 | ) -> Order | None: 41 | order_id = f"backtest-{int(time.time())}" 42 | return Order( 43 | identifier=order_id, 44 | status=OrderStatus.OPEN, 45 | order_type=OrderType.LIMIT, 46 | side=order_side, 47 | price=price, 48 | average=price, 49 | amount=quantity, 50 | filled=0, 51 | remaining=quantity, 52 | timestamp=0, 53 | datetime="", 54 | last_trade_timestamp=1, 55 | symbol=pair, 56 | time_in_force="GTC", 57 | ) 58 | 59 | async def get_order( 60 | self, 61 | order_id: str, 62 | pair: str, 63 | ) -> Order | None: 64 | return Order( 65 | identifier=order_id, 66 | status=OrderStatus.OPEN, 67 | order_type=OrderType.LIMIT, 68 | side=OrderSide.BUY, 69 | price=100, 70 | average=100, 71 | amount=1, 72 | filled=1, 73 | remaining=0, 74 | timestamp=0, 75 | 
datetime="111", 76 | last_trade_timestamp=1, 77 | symbol=pair, 78 | time_in_force="GTC", 79 | ) 80 | -------------------------------------------------------------------------------- /tests/services/test_exchange_service_factory.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | 5 | from config.config_manager import ConfigManager 6 | from config.trading_mode import TradingMode 7 | from core.services.backtest_exchange_service import BacktestExchangeService 8 | from core.services.exchange_service_factory import ExchangeServiceFactory 9 | from core.services.live_exchange_service import LiveExchangeService 10 | 11 | 12 | class TestExchangeServiceFactory: 13 | @pytest.fixture 14 | def config_manager(self): 15 | config_manager = Mock(spec=ConfigManager) 16 | config_manager.get_trading_mode.return_value = TradingMode.LIVE 17 | config_manager.get_exchange_name.return_value = "binance" 18 | return config_manager 19 | 20 | @patch("core.services.live_exchange_service.ccxtpro") 21 | @patch("core.services.live_exchange_service.getattr") 22 | def test_create_live_exchange_service_with_env_vars(self, mock_getattr, mock_ccxtpro, config_manager, monkeypatch): 23 | monkeypatch.setenv("EXCHANGE_API_KEY", "test_api_key") 24 | monkeypatch.setenv("EXCHANGE_SECRET_KEY", "test_secret_key") 25 | 26 | mock_exchange_instance = Mock() 27 | mock_ccxtpro.binance.return_value = mock_exchange_instance 28 | mock_getattr.return_value = mock_ccxtpro.binance 29 | 30 | service = ExchangeServiceFactory.create_exchange_service(config_manager, TradingMode.LIVE) 31 | 32 | assert isinstance(service, LiveExchangeService), "Expected a LiveExchangeService instance" 33 | mock_getattr.assert_called_once_with(mock_ccxtpro, "binance") 34 | mock_ccxtpro.binance.assert_called_once_with( 35 | { 36 | "apiKey": "test_api_key", 37 | "secret": "test_secret_key", 38 | "enableRateLimit": True, 39 | }, 40 | ) 41 | 42 | 
@patch("core.services.live_exchange_service.ccxtpro") 43 | @patch("core.services.live_exchange_service.getattr") 44 | def test_create_backtest_exchange_service(self, mock_getattr, mock_ccxtpro, config_manager): 45 | config_manager.get_trading_mode.return_value = TradingMode.BACKTEST 46 | service = ExchangeServiceFactory.create_exchange_service(config_manager, TradingMode.BACKTEST) 47 | assert isinstance(service, BacktestExchangeService), "Expected a BacktestExchangeService instance" 48 | 49 | def test_invalid_trading_mode(self, config_manager): 50 | config_manager.get_trading_mode.return_value = "invalid_mode" 51 | with pytest.raises(ValueError, match="Unsupported trading mode: invalid_mode"): 52 | ExchangeServiceFactory.create_exchange_service(config_manager, "invalid_mode") 53 | -------------------------------------------------------------------------------- /utils/performance_results_saver.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | import json 3 | import logging 4 | import os 5 | from typing import Any 6 | 7 | import pandas as pd 8 | 9 | 10 | def save_or_append_performance_results( 11 | new_results: dict[str, Any], 12 | file_path: str, 13 | ) -> None: 14 | """ 15 | Saves or appends performance results to a JSON file. 16 | 17 | Args: 18 | new_results: Dictionary containing performance summary and orders. 19 | file_path: Path to the JSON file. 20 | """ 21 | try: 22 | if os.path.exists(file_path): 23 | with open(file_path) as json_file: 24 | try: 25 | all_results = json.load(json_file) 26 | if not isinstance(all_results, list): 27 | logging.error(f"Existing file {file_path} is not a valid JSON list. Overwriting the file.") 28 | all_results = [] 29 | except json.JSONDecodeError: 30 | logging.warning(f"Could not decode JSON from {file_path}. 
Overwriting the file.") 31 | all_results = [] 32 | else: 33 | all_results = [] 34 | 35 | cleaned_performance_summary = { 36 | key: ( 37 | value.isoformat() 38 | if isinstance(value, datetime | pd.Timestamp) 39 | else str(value) 40 | if isinstance(value, timedelta) 41 | else value 42 | ) 43 | for key, value in new_results.get("performance_summary").items() 44 | } 45 | 46 | order_keys = ["Order Side", "Type", "Status", "Price", "Quantity", "Timestamp", "Grid Level", "Slippage"] 47 | cleaned_orders = [ 48 | { 49 | key: (value.isoformat() if isinstance(value, datetime | pd.Timestamp) else value) 50 | for key, value in zip(order_keys, order, strict=False) 51 | } 52 | for order in new_results.get("orders") 53 | ] 54 | 55 | cleaned_results = { 56 | "config": new_results.get("config"), 57 | "performance_summary": cleaned_performance_summary, 58 | "orders": cleaned_orders, 59 | } 60 | all_results.append(cleaned_results) 61 | 62 | with open(file_path, "w") as json_file: 63 | json.dump(all_results, json_file, indent=4) 64 | 65 | logging.info(f"Performance metrics saved to {file_path}") 66 | 67 | except OSError as e: 68 | logging.error(f"Failed to save performance metrics to {file_path}: {e}") 69 | 70 | except Exception as e: 71 | logging.error(f"An unexpected error occurred while saving performance metrics: {e}") 72 | -------------------------------------------------------------------------------- /tests/order_handling/test_order_execution_strategy_factory.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | 5 | from config.config_manager import ConfigManager 6 | from config.trading_mode import TradingMode 7 | from core.order_handling.execution_strategy.backtest_order_execution_strategy import ( 8 | BacktestOrderExecutionStrategy, 9 | ) 10 | from core.order_handling.execution_strategy.live_order_execution_strategy import ( 11 | LiveOrderExecutionStrategy, 12 | ) 13 | from 
core.order_handling.execution_strategy.order_execution_strategy_factory import ( 14 | OrderExecutionStrategyFactory, 15 | ) 16 | from core.services.exchange_interface import ExchangeInterface 17 | 18 | 19 | class TestOrderExecutionStrategyFactory: 20 | @pytest.fixture 21 | def config_manager(self): 22 | return Mock(spec=ConfigManager) 23 | 24 | @pytest.fixture 25 | def exchange_service(self): 26 | return Mock(spec=ExchangeInterface) 27 | 28 | def test_create_live_strategy(self, config_manager, exchange_service): 29 | config_manager.get_trading_mode.return_value = TradingMode.LIVE 30 | strategy = OrderExecutionStrategyFactory.create(config_manager, exchange_service) 31 | 32 | assert isinstance( 33 | strategy, 34 | LiveOrderExecutionStrategy, 35 | ), "Expected LiveOrderExecutionStrategy instance for live trading mode" 36 | assert ( 37 | strategy.exchange_service == exchange_service 38 | ), "Expected exchange_service to be set correctly in LiveOrderExecutionStrategy" 39 | 40 | def test_create_paper_trading_strategy(self, config_manager, exchange_service): 41 | config_manager.get_trading_mode.return_value = TradingMode.PAPER_TRADING 42 | strategy = OrderExecutionStrategyFactory.create(config_manager, exchange_service) 43 | 44 | assert isinstance( 45 | strategy, 46 | LiveOrderExecutionStrategy, 47 | ), "Expected LiveOrderExecutionStrategy instance for paper trading mode" 48 | assert ( 49 | strategy.exchange_service == exchange_service 50 | ), "Expected exchange_service to be set correctly in LiveOrderExecutionStrategy" 51 | 52 | def test_create_backtest_strategy(self, config_manager, exchange_service): 53 | config_manager.get_trading_mode.return_value = TradingMode.BACKTEST 54 | strategy = OrderExecutionStrategyFactory.create(config_manager, exchange_service) 55 | 56 | assert isinstance( 57 | strategy, 58 | BacktestOrderExecutionStrategy, 59 | ), "Expected BacktestOrderExecutionStrategy instance for backtesting mode" 60 | 61 | def 
test_invalid_trading_mode_raises_error(self, config_manager, exchange_service): 62 | config_manager.get_trading_mode.return_value = "UNKNOWN_MODE" # Simulate an invalid mode 63 | with pytest.raises(ValueError, match="Unknown trading mode: UNKNOWN_MODE"): 64 | OrderExecutionStrategyFactory.create(config_manager, exchange_service) 65 | -------------------------------------------------------------------------------- /tests/utils/test_performance_results_saver.py: -------------------------------------------------------------------------------- 1 | from datetime import UTC, datetime, timedelta 2 | from unittest.mock import mock_open, patch 3 | 4 | import pytest 5 | 6 | from utils.performance_results_saver import save_or_append_performance_results 7 | 8 | 9 | @pytest.fixture 10 | def new_results_fixture(): 11 | return { 12 | "config": "config.json", 13 | "performance_summary": { 14 | "start_time": datetime(2024, 12, 20, 10, 0, 0, tzinfo=UTC).isoformat(), 15 | "end_time": datetime(2024, 12, 20, 12, 0, 0, tzinfo=UTC).isoformat(), 16 | "total_profit": 500.0, 17 | "runtime": str(timedelta(hours=2)), 18 | }, 19 | "orders": [ 20 | [ 21 | "BUY", 22 | "LIMIT", 23 | "FILLED", 24 | 1000.0, 25 | 0.5, 26 | datetime(2024, 12, 20, 10, 30, 0, tzinfo=UTC).isoformat(), 27 | "Level 1", 28 | 0.1, 29 | ], 30 | [ 31 | "SELL", 32 | "LIMIT", 33 | "FILLED", 34 | 1500.0, 35 | 0.5, 36 | datetime(2024, 12, 20, 11, 30, 0, tzinfo=UTC).isoformat(), 37 | "Level 2", 38 | 0.05, 39 | ], 40 | ], 41 | } 42 | 43 | 44 | def test_save_or_append_performance_results_invalid_json(new_results_fixture): 45 | with ( 46 | patch("builtins.open", mock_open(read_data="INVALID_JSON")) as mocked_file, 47 | patch("os.path.exists", return_value=True), 48 | patch("utils.performance_results_saver.logging.warning") as mock_logger_warning, 49 | ): 50 | save_or_append_performance_results(new_results_fixture, "results.json") 51 | 52 | mocked_file.assert_any_call("results.json") 53 | mocked_file.assert_any_call("results.json", 
"w") 54 | mock_logger_warning.assert_called_once_with("Could not decode JSON from results.json. Overwriting the file.") 55 | 56 | 57 | def test_save_or_append_performance_results_os_error(new_results_fixture): 58 | with ( 59 | patch("builtins.open", side_effect=OSError("Test OS Error")), 60 | patch("utils.performance_results_saver.logging.error") as mock_logger_error, 61 | ): 62 | save_or_append_performance_results(new_results_fixture, "results.json") 63 | 64 | mock_logger_error.assert_called_once_with("Failed to save performance metrics to results.json: Test OS Error") 65 | 66 | 67 | def test_save_or_append_performance_results_unexpected_exception(new_results_fixture): 68 | with ( 69 | patch("builtins.open", side_effect=Exception("Unexpected Error")), 70 | patch("utils.performance_results_saver.logging.error") as mock_logger_error, 71 | ): 72 | save_or_append_performance_results(new_results_fixture, "results.json") 73 | 74 | mock_logger_error.assert_any_call( 75 | "An unexpected error occurred while saving performance metrics: Unexpected Error", 76 | ) 77 | -------------------------------------------------------------------------------- /tests/utils/test_logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import MagicMock, patch 3 | 4 | import pytest 5 | 6 | from utils.logging_config import setup_logging 7 | 8 | 9 | @pytest.fixture 10 | def mock_makedirs(): 11 | with patch("os.makedirs") as mocked_makedirs: 12 | yield mocked_makedirs 13 | 14 | 15 | @pytest.fixture 16 | def mock_basic_config(): 17 | with patch("logging.basicConfig") as mocked_basic_config: 18 | yield mocked_basic_config 19 | 20 | 21 | @pytest.fixture 22 | def mock_rotating_file_handler(): 23 | with patch("logging.handlers.RotatingFileHandler") as mocked_handler: 24 | mocked_handler.return_value = MagicMock() 25 | yield mocked_handler 26 | 27 | 28 | def test_setup_logging_console_only(mock_basic_config): 29 | 
setup_logging(log_level=logging.INFO, log_to_file=False) 30 | 31 | mock_basic_config.assert_called_once() 32 | handlers = mock_basic_config.call_args[1]["handlers"] 33 | assert len(handlers) == 1 34 | assert isinstance(handlers[0], logging.StreamHandler) 35 | 36 | 37 | @patch("os.makedirs") 38 | @patch("logging.basicConfig") 39 | @patch("logging.handlers.RotatingFileHandler") 40 | def test_setup_logging_file_logging(mock_rotating_file_handler, mock_basic_config, mock_makedirs): 41 | setup_logging( 42 | log_level=logging.DEBUG, 43 | log_to_file=True, 44 | config_name="test_config", 45 | max_file_size=10_000_000, 46 | backup_count=3, 47 | ) 48 | 49 | mock_makedirs.assert_called_once_with("logs", exist_ok=True) 50 | mock_basic_config.assert_called_once() 51 | handlers = mock_basic_config.call_args[1]["handlers"] 52 | 53 | assert len(handlers) == 2 54 | assert any(isinstance(handler, logging.StreamHandler) for handler in handlers) 55 | 56 | 57 | @patch("os.makedirs") 58 | @patch("logging.basicConfig") 59 | @patch("logging.handlers.RotatingFileHandler") 60 | def test_setup_logging_default_file_logging(mock_rotating_file_handler, mock_basic_config, mock_makedirs): 61 | setup_logging(log_level=logging.WARNING, log_to_file=True) 62 | 63 | mock_makedirs.assert_called_once_with("logs", exist_ok=True) 64 | mock_basic_config.assert_called_once() 65 | handlers = mock_basic_config.call_args[1]["handlers"] 66 | 67 | assert len(handlers) == 2 68 | assert any(isinstance(handler, logging.StreamHandler) for handler in handlers) 69 | 70 | 71 | def test_setup_logging_logs_info(mock_basic_config, mock_rotating_file_handler, caplog): 72 | with caplog.at_level(logging.INFO): 73 | setup_logging(log_level=logging.INFO, log_to_file=True, config_name="test_config") 74 | 75 | assert "Logging initialized. Log level: INFO" in caplog.text 76 | assert "File logging enabled. 
Logs are stored in: logs/test_config.log" in caplog.text 77 | 78 | 79 | def test_setup_logging_directory_creation_error(mock_makedirs): 80 | mock_makedirs.side_effect = OSError("Directory creation failed") 81 | 82 | with pytest.raises(OSError, match="Directory creation failed"): 83 | setup_logging(log_level=logging.DEBUG, log_to_file=True) 84 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "grid_trading_bot" 3 | version = "0.1.0" 4 | description = "Open-source cryptocurrency trading bot designed to perform grid trading strategies using historical data for backtesting." 5 | authors = [{ name = "Jordan TETE", email = "tetej171@gmail.com" }] 6 | readme = "README.md" 7 | license = { file = "LICENSE.txt" } 8 | requires-python = ">=3.12" 9 | dependencies = [ 10 | "pandas==2.2.3", 11 | "numpy==2.1.3", 12 | "plotly==6.4.0", 13 | "tabulate==0.9.0", 14 | "aiohttp==3.10.11", 15 | "apprise==1.9.5", 16 | "ccxt==4.4.82", 17 | "configparser==7.2.0", 18 | "psutil==7.1.3", 19 | "python-dotenv==1.2.1" 20 | ] 21 | 22 | [project.optional-dependencies] 23 | dev = [ 24 | "pytest==8.4.2", 25 | "pytest-asyncio==0.26.0", 26 | "pytest-cov==7.0.0", 27 | "pytest-timeout==2.4.0", 28 | "pre-commit==4.4.0", 29 | ] 30 | 31 | [project.urls] 32 | repository= "https://github.com/jordantete/grid_trading_bot" 33 | issues = "https://github.com/jordantete/grid_trading_bot/issues" 34 | discussions = "https://github.com/jordantete/grid_trading_bot/discussions" 35 | 36 | [tool.pytest.ini_options] 37 | asyncio_mode = "auto" 38 | testpaths = ["tests"] 39 | python_files = ["test_*.py"] 40 | python_classes = ["Test*"] 41 | python_functions = ["test_*"] 42 | log_cli = true 43 | log_cli_level = "INFO" 44 | markers = [ 45 | "asyncio: mark a test as an async test", 46 | "timeout: mark a test with a timeout" 47 | ] 48 | 49 | [tool.coverage.run] 50 | omit = [ 
51 | "*/interface*.py" 52 | ] 53 | 54 | [tool.ruff] 55 | line-length = 120 56 | target-version = "py312" 57 | src = ["core", "strategies", "config", "utils", "tests"] 58 | 59 | [tool.ruff.lint] 60 | select = [ 61 | "E", # pycodestyle errors 62 | "W", # pycodestyle warnings 63 | "F", # Pyflakes 64 | "I", # isort 65 | "B", # flake8-bugbear 66 | "C4", # flake8-comprehensions 67 | "UP", # pyupgrade 68 | "N", # pep8-naming 69 | "S", # flake8-bandit (security) 70 | "T20", # flake8-print 71 | "SIM", # flake8-simplify 72 | "RUF", # Ruff-specific rules 73 | "PT", # flake8-pytest-style 74 | "Q", # flake8-quotes 75 | "A", # flake8-builtins 76 | "COM", # flake8-commas 77 | "DTZ", # flake8-datetimez 78 | "TCH", # flake8-type-checking 79 | ] 80 | ignore = [ 81 | "S101", # assert used (OK in tests) 82 | "COM812", # missing trailing comma (handled by formatter) 83 | ] 84 | 85 | [tool.ruff.format] 86 | quote-style = "double" 87 | indent-style = "space" 88 | skip-magic-trailing-comma = false 89 | line-ending = "auto" 90 | docstring-code-format = true 91 | 92 | [tool.ruff.lint.isort] 93 | known-first-party = ["core", "strategies", "config", "utils"] 94 | force-sort-within-sections = true 95 | 96 | [build-system] 97 | requires = ["setuptools>=42", "wheel"] 98 | build-backend = "setuptools.build_meta" 99 | 100 | [tool.setuptools.packages.find] 101 | where = ["."] 102 | exclude = ["logs", "data"] 103 | 104 | [tool.setuptools] 105 | include-package-data = true 106 | -------------------------------------------------------------------------------- /tests/order_handling/test_backtest_order_execution_strategy.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from core.order_handling.execution_strategy.backtest_order_execution_strategy import ( 6 | BacktestOrderExecutionStrategy, 7 | ) 8 | from core.order_handling.order import OrderSide, OrderStatus, OrderType 9 | 10 | 11 | 
@pytest.mark.asyncio 12 | class TestBacktestOrderExecutionStrategy: 13 | @pytest.fixture 14 | def setup_strategy(self): 15 | return BacktestOrderExecutionStrategy() 16 | 17 | @patch("time.time", return_value=1680000000) 18 | async def test_execute_market_order(self, mock_time, setup_strategy): 19 | strategy = setup_strategy 20 | order_side = OrderSide.BUY 21 | pair = "BTC/USDT" 22 | quantity = 0.5 23 | price = 30000 24 | 25 | order = await strategy.execute_market_order(order_side, pair, quantity, price) 26 | 27 | assert order is not None 28 | assert order.identifier == "backtest-1680000000" 29 | assert order.status == OrderStatus.OPEN 30 | assert order.order_type == OrderType.MARKET 31 | assert order.side == order_side 32 | assert order.price == price 33 | assert order.amount == quantity 34 | assert order.filled == quantity 35 | assert order.remaining == 0 36 | assert order.symbol == pair 37 | assert order.time_in_force == "GTC" 38 | assert order.timestamp == 1680000000000 39 | 40 | @patch("time.time", return_value=1680000000) 41 | async def test_execute_limit_order(self, mock_time, setup_strategy): 42 | strategy = setup_strategy 43 | order_side = OrderSide.SELL 44 | pair = "ETH/USDT" 45 | quantity = 1 46 | price = 2000 47 | 48 | order = await strategy.execute_limit_order(order_side, pair, quantity, price) 49 | 50 | assert order is not None 51 | assert order.identifier == "backtest-1680000000" 52 | assert order.status == OrderStatus.OPEN 53 | assert order.order_type == OrderType.LIMIT 54 | assert order.side == order_side 55 | assert order.price == price 56 | assert order.amount == quantity 57 | assert order.filled == 0 58 | assert order.remaining == quantity 59 | assert order.symbol == pair 60 | assert order.time_in_force == "GTC" 61 | assert order.timestamp == 0 62 | 63 | async def test_get_order(self, setup_strategy): 64 | strategy = setup_strategy 65 | order_id = "test-order-1" 66 | order_symbol = "BTC/USDT" 67 | 68 | order = await strategy.get_order(order_id, 
order_symbol) 69 | 70 | assert order is not None 71 | assert order.identifier == order_id 72 | assert order.status == OrderStatus.OPEN 73 | assert order.order_type == OrderType.LIMIT 74 | assert order.side == OrderSide.BUY 75 | assert order.price == 100 76 | assert order.average == 100 77 | assert order.amount == 1 78 | assert order.filled == 1 79 | assert order.remaining == 0 80 | assert order.symbol == order_symbol 81 | assert order.time_in_force == "GTC" 82 | assert order.timestamp == 0 83 | -------------------------------------------------------------------------------- /core/bot_management/notification/notification_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from concurrent.futures import ThreadPoolExecutor 3 | import logging 4 | 5 | import apprise 6 | 7 | from config.trading_mode import TradingMode 8 | from core.bot_management.event_bus import EventBus, Events 9 | from core.order_handling.order import Order 10 | 11 | from .notification_content import NotificationType 12 | 13 | 14 | class NotificationHandler: 15 | """ 16 | Handles sending notifications through various channels using the Apprise library. 17 | Supports multiple notification services like Telegram, Discord, Slack, etc. 
18 | """ 19 | 20 | _executor = ThreadPoolExecutor(max_workers=3) 21 | 22 | def __init__( 23 | self, 24 | event_bus: EventBus, 25 | urls: list[str] | None, 26 | trading_mode: TradingMode, 27 | ): 28 | self.logger = logging.getLogger(self.__class__.__name__) 29 | self.event_bus = event_bus 30 | self.enabled = bool(urls) and trading_mode in {TradingMode.LIVE, TradingMode.PAPER_TRADING} 31 | self.lock = asyncio.Lock() 32 | self.apprise_instance = apprise.Apprise() if self.enabled else None 33 | 34 | if self.enabled and urls is not None: 35 | self.event_bus.subscribe(Events.ORDER_FILLED, self._send_notification_on_order_filled) 36 | 37 | for url in urls: 38 | self.apprise_instance.add(url) 39 | 40 | def send_notification( 41 | self, 42 | content: NotificationType | str, 43 | **kwargs, 44 | ) -> None: 45 | """Send content via Apprise (no-op when disabled); templates are filled from kwargs, "N/A" if missing.""" 46 | if self.enabled and self.apprise_instance: 47 | if isinstance(content, NotificationType): 48 | title = content.value.title 49 | message_template = content.value.message 50 | # Use the format parser: splitting on whitespace mis-extracts "{field}." or "{price:.2f}". 51 | from string import Formatter 52 | required_placeholders = {name for _, name, _, _ in Formatter().parse(message_template) if name} 53 | missing_placeholders = required_placeholders - kwargs.keys() 54 | 55 | if missing_placeholders: 56 | self.logger.warning( 57 | f"Missing placeholders for notification: {missing_placeholders}. 
" 58 | "Defaulting to 'N/A' for missing values.", 59 | ) 60 | 61 | message = message_template.format(**{key: kwargs.get(key, "N/A") for key in required_placeholders}) 62 | else: 63 | title = "Notification" 64 | message = content 65 | 66 | self.apprise_instance.notify(title=title, body=message) 67 | 68 | async def async_send_notification( 69 | self, 70 | content: NotificationType | str, 71 | **kwargs, 72 | ) -> None: 73 | async with self.lock: 74 | loop = asyncio.get_running_loop() 75 | try: 76 | await asyncio.wait_for( 77 | loop.run_in_executor(self._executor, lambda: self.send_notification(content, **kwargs)), 78 | timeout=5, 79 | ) 80 | except Exception as e: 81 | self.logger.error(f"Failed to send notification: {e!s}") 82 | 83 | async def _send_notification_on_order_filled(self, order: Order) -> None: 84 | await self.async_send_notification(NotificationType.ORDER_FILLED, order_details=str(order)) 85 | -------------------------------------------------------------------------------- /core/order_handling/order.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import pandas as pd 4 | 5 | 6 | class OrderSide(Enum): 7 | BUY = "buy" 8 | SELL = "sell" 9 | 10 | 11 | class OrderType(Enum): 12 | MARKET = "market" 13 | LIMIT = "limit" 14 | 15 | 16 | class OrderStatus(Enum): 17 | OPEN = "open" 18 | CLOSED = "closed" 19 | CANCELED = "canceled" 20 | EXPIRED = "expired" 21 | REJECTED = "rejected" 22 | UNKNOWN = "unknown" 23 | 24 | 25 | class Order: 26 | def __init__( 27 | self, 28 | identifier: str, 29 | status: OrderStatus, 30 | order_type: OrderType, 31 | side: OrderSide, 32 | price: float, 33 | average: float | None, 34 | amount: float, 35 | filled: float, 36 | remaining: float, 37 | timestamp: int, 38 | datetime: str | None, 39 | last_trade_timestamp: int | None, 40 | symbol: str, 41 | time_in_force: str | None, 42 | trades: list[dict[str, str | float]] | None = None, 43 | fee: dict[str, str | float] | 
None = None, 44 | cost: float | None = None, 45 | info: dict[str, str | float | dict] | None = None, 46 | ): 47 | self.identifier = identifier 48 | self.status = status # 'open', 'closed', 'canceled', 'expired', 'rejected' 49 | self.order_type = order_type # 'market', 'limit' 50 | self.side = side # 'buy', 'sell' 51 | self.price = price # float price in quote currency (may be empty for market orders) 52 | self.average = average # float average filling price 53 | self.amount = amount # ordered amount of base currency 54 | self.filled = filled # filled amount of base currency 55 | self.remaining = remaining # remaining amount to fill 56 | self.timestamp = timestamp # order placing/opening Unix timestamp in milliseconds 57 | self.datetime = datetime # ISO8601 datetime of 'timestamp' with milliseconds 58 | self.last_trade_timestamp = last_trade_timestamp # Unix timestamp of the most recent trade on this order 59 | self.symbol = symbol # symbol 60 | self.time_in_force = time_in_force # 'GTC', 'IOC', 'FOK', 'PO' 61 | self.trades = trades # a list of order trades/executions 62 | self.fee = fee # fee info, if available 63 | self.cost = cost # 'filled' * 'price' (filling price used where available) 64 | self.info = info # Original unparsed structure for debugging or auditing 65 | 66 | def is_filled(self) -> bool: 67 | return self.status == OrderStatus.CLOSED 68 | 69 | def is_canceled(self) -> bool: 70 | return self.status == OrderStatus.CANCELED 71 | 72 | def is_open(self) -> bool: 73 | return self.status == OrderStatus.OPEN 74 | 75 | def format_last_trade_timestamp(self) -> str | None: 76 | if self.last_trade_timestamp is None: 77 | return None 78 | return pd.Timestamp(self.last_trade_timestamp, unit="s").isoformat() 79 | 80 | def __str__(self) -> str: 81 | return ( 82 | f"Order(id={self.identifier}, status={self.status}, " 83 | f"type={self.order_type}, side={self.side}, price={self.price}, average={self.average}, " 84 | f"amount={self.amount}, filled={self.filled}, 
remaining={self.remaining}, " 85 | f"timestamp={self.timestamp}, datetime={self.datetime}, symbol={self.symbol}, " 86 | f"time_in_force={self.time_in_force}, trades={self.trades}, fee={self.fee}, cost={self.cost})" 87 | ) 88 | 89 | def __repr__(self) -> str: 90 | return self.__str__() 91 | -------------------------------------------------------------------------------- /monitoring/dashboards/grid_trading_bot_dashboard.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Grid Trading Bot Dashboard", 3 | "variables": [ 4 | { 5 | "name": "trading_pair", 6 | "type": "query", 7 | "datasource": "Loki", 8 | "query": "label_values(trading_pair)" 9 | }, 10 | { 11 | "name": "trading_mode", 12 | "type": "query", 13 | "datasource": "Loki", 14 | "query": "label_values(trading_mode)" 15 | }, 16 | { 17 | "name": "strategy", 18 | "type": "query", 19 | "datasource": "Loki", 20 | "query": "label_values(strategy_type)" 21 | } 22 | ], 23 | "panels": [ 24 | { 25 | "title": "Strategy Overview", 26 | "type": "stat", 27 | "datasource": "Loki", 28 | "targets": [ 29 | { 30 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | json | line_format \"{{.grid_size}} grids, Range: {{.grid_range}}, Spacing: {{.spacing_type}}\"" 31 | } 32 | ] 33 | }, 34 | { 35 | "title": "ROI Over Time", 36 | "type": "timeseries", 37 | "datasource": "Loki", 38 | "targets": [ 39 | { 40 | "expr": "last_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | regexp `ROI\\s+\\|\\s+(?P<roi>[\\-\\d\\.]+)%` | unwrap roi [5m])" 41 | } 42 | ] 43 | }, 44 | { 45 | "title": "Grid Level States", 46 | "type": "table", 47 | "datasource": "Loki", 48 | "targets": [ 49 | { 50 | "expr": "{job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | json | grid_price != \"\" and grid_state != 
\"\" | line_format \"{{.grid_price}} - {{.grid_state}}\"" 51 | } 52 | ] 53 | }, 54 | { 55 | "title": "Order Flow", 56 | "type": "timeseries", 57 | "datasource": "Loki", 58 | "targets": [ 59 | { 60 | "expr": "sum(count_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\", order_side=\"BUY\"}[5m]))", 61 | "legendFormat": "Buy Orders" 62 | }, 63 | { 64 | "expr": "sum(count_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\", order_side=\"SELL\"}[5m]))", 65 | "legendFormat": "Sell Orders" 66 | } 67 | ] 68 | }, 69 | { 70 | "title": "Balance History", 71 | "type": "timeseries", 72 | "datasource": "Loki", 73 | "targets": [ 74 | { 75 | "expr": "last_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\", trading_mode=\"$trading_mode\", strategy_type=\"$strategy\"} | regexp `Balance: (?P<balance>[\\d\\.]+)` | unwrap balance [5m])" 76 | } 77 | ] 78 | }, 79 | { 80 | "title": "System Health", 81 | "type": "gauge", 82 | "datasource": "Loki", 83 | "targets": [ 84 | { 85 | "expr": "avg_over_time({job=\"grid_trading_bot\", trading_pair=\"$trading_pair\"} | json | unwrap cpu [5m])" 86 | } 87 | ], 88 | "fieldConfig": { 89 | "defaults": { 90 | "thresholds": { 91 | "steps": [ 92 | { "value": 0, "color": "green" }, 93 | { "value": 70, "color": "yellow" }, 94 | { "value": 85, "color": "red" } 95 | ] 96 | } 97 | } 98 | } 99 | } 100 | ], 101 | "refresh": "10s", 102 | "schemaVersion": 36 103 | } 104 | -------------------------------------------------------------------------------- /utils/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import traceback 5 | 6 | 7 | def validate_args(args): 8 | """ 9 | Validates parsed arguments. 10 | 11 | Args: 12 | args: Parsed arguments object. 13 | Raises: 14 | ValueError: If validation fails.
15 | """ 16 | # Validate --config 17 | if args.config: 18 | for config_path in args.config: 19 | if not os.path.exists(config_path): 20 | raise ValueError(f"Config file does not exist: {config_path}") 21 | 22 | # Validate --save_performance_results directory 23 | if args.save_performance_results: 24 | save_performance_dir = os.path.dirname(args.save_performance_results) 25 | if save_performance_dir and not os.path.exists(save_performance_dir): 26 | raise ValueError(f"The directory for saving performance results does not exist: {save_performance_dir}") 27 | 28 | 29 | def parse_and_validate_console_args(cli_args=None): 30 | """ 31 | Parses and validates console arguments. 32 | 33 | Args: 34 | cli_args: Optional CLI arguments for testing. 35 | Returns: 36 | argparse.Namespace: Parsed and validated arguments. 37 | Raises: 38 | RuntimeError: If argument parsing or validation fails. 39 | """ 40 | try: 41 | parser = argparse.ArgumentParser( 42 | description="📈 Spot Grid Trading Bot - Automate your grid trading strategy with confidence\n\n" 43 | "This bot lets you automate your trading by implementing a grid strategy. " 44 | "Set your parameters, watch it execute, and manage your trades more effectively. 
" 45 | "Ideal for both beginners and experienced traders!", 46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 47 | ) 48 | 49 | required_args = parser.add_argument_group("Required Arguments") 50 | required_args.add_argument( 51 | "--config", 52 | type=str, 53 | nargs="+", 54 | required=True, 55 | metavar="CONFIG", 56 | help="Path(s) to the configuration file(s) containing strategy details.", 57 | ) 58 | 59 | optional_args = parser.add_argument_group("Optional Arguments") 60 | optional_args.add_argument( 61 | "--save_performance_results", 62 | type=str, 63 | metavar="FILE", 64 | help="Path to save simulation results (e.g., results.json).", 65 | ) 66 | optional_args.add_argument( 67 | "--no-plot", 68 | action="store_true", 69 | help="Disable the display of plots at the end of the simulation.", 70 | ) 71 | optional_args.add_argument( 72 | "--profile", 73 | action="store_true", 74 | help="Enable profiling for performance analysis.", 75 | ) 76 | 77 | args = parser.parse_args(cli_args) 78 | validate_args(args) 79 | return args 80 | 81 | except SystemExit as e: 82 | if e.code == 0: # Exit code 0 indicates a successful --help invocation 83 | raise 84 | logging.error(f"Argument parsing failed: {e}") 85 | raise RuntimeError("Failed to parse arguments. 
Please check your inputs.") from e 86 | 87 | except ValueError as e: 88 | logging.error(f"Validation failed: {e}") 89 | raise RuntimeError("Argument validation failed.") from e 90 | 91 | except Exception as e: 92 | logging.error(f"An unexpected error occurred while parsing arguments: {e}") 93 | logging.error(traceback.format_exc()) 94 | raise RuntimeError("An unexpected error occurred during argument parsing.") from e 95 | -------------------------------------------------------------------------------- /tests/bot_management/test_event_bus.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from unittest.mock import AsyncMock, Mock 4 | 5 | import pytest 6 | 7 | from core.bot_management.event_bus import EventBus, Events 8 | 9 | 10 | class TestEventBus: 11 | @pytest.fixture 12 | def event_bus(self): 13 | return EventBus() 14 | 15 | def test_subscribe(self, event_bus): 16 | callback = Mock() 17 | event_bus.subscribe(Events.ORDER_FILLED, callback) 18 | assert Events.ORDER_FILLED in event_bus.subscribers 19 | assert callback in event_bus.subscribers[Events.ORDER_FILLED] 20 | 21 | @pytest.mark.asyncio 22 | async def test_publish_async_single_callback(self, event_bus): 23 | async_callback = AsyncMock() 24 | event_bus.subscribe(Events.ORDER_FILLED, async_callback) 25 | await event_bus.publish(Events.ORDER_FILLED, {"data": "test"}) 26 | async_callback.assert_awaited_once_with({"data": "test"}) 27 | 28 | @pytest.mark.asyncio 29 | async def test_publish_async_multiple_callbacks(self, event_bus): 30 | async_callback_1 = AsyncMock() 31 | async_callback_2 = AsyncMock() 32 | event_bus.subscribe(Events.ORDER_FILLED, async_callback_1) 33 | event_bus.subscribe(Events.ORDER_FILLED, async_callback_2) 34 | await event_bus.publish(Events.ORDER_FILLED, {"data": "test"}) 35 | async_callback_1.assert_awaited_once_with({"data": "test"}) 36 | async_callback_2.assert_awaited_once_with({"data": "test"}) 37 | 38 | 
@pytest.mark.asyncio 39 | async def test_publish_async_with_exception(self, event_bus, caplog): 40 | failing_callback = AsyncMock(side_effect=Exception("Test Error")) 41 | event_bus.subscribe(Events.ORDER_FILLED, failing_callback) 42 | 43 | await event_bus.publish(Events.ORDER_FILLED, {"data": "test"}) 44 | 45 | # Wait for all tasks in the event bus to complete 46 | await asyncio.gather(*event_bus._tasks, return_exceptions=True) 47 | 48 | assert "Error in async callback 'AsyncMock'" in caplog.text 49 | assert "Test Error" in caplog.text 50 | 51 | def test_publish_sync(self, event_bus): 52 | sync_callback = Mock() 53 | event_bus.subscribe(Events.ORDER_FILLED, sync_callback) 54 | event_bus.publish_sync(Events.ORDER_FILLED, {"data": "test"}) 55 | sync_callback.assert_called_once_with({"data": "test"}) 56 | 57 | @pytest.mark.asyncio 58 | async def test_safe_invoke_async(self, event_bus, caplog): 59 | async_callback = AsyncMock() 60 | 61 | await event_bus._safe_invoke_async(async_callback, {"data": "test"}) 62 | 63 | # Wait for all tasks in the EventBus to complete 64 | await asyncio.gather(*event_bus._tasks, return_exceptions=True) 65 | 66 | async_callback.assert_awaited_once_with({"data": "test"}) 67 | 68 | @pytest.mark.asyncio 69 | async def test_safe_invoke_async_with_exception(self, event_bus, caplog): 70 | failing_callback = AsyncMock(side_effect=Exception("Async Error")) 71 | caplog.set_level(logging.DEBUG) 72 | 73 | await event_bus._safe_invoke_async(failing_callback, {"data": "test"}) 74 | await asyncio.gather(*event_bus._tasks, return_exceptions=True) 75 | 76 | assert "Error in async callback" in caplog.text 77 | assert "Async Error" in caplog.text 78 | assert "Task created for callback" in caplog.text 79 | 80 | def test_safe_invoke_sync(self, event_bus, caplog): 81 | sync_callback = Mock() 82 | event_bus._safe_invoke_sync(sync_callback, {"data": "test"}) 83 | sync_callback.assert_called_once_with({"data": "test"}) 84 | 85 | def 
test_safe_invoke_sync_with_exception(self, event_bus, caplog): 86 | failing_callback = Mock(side_effect=Exception("Sync Error")) 87 | event_bus._safe_invoke_sync(failing_callback, {"data": "test"}) 88 | assert "Error in sync subscriber callback" in caplog.text 89 | -------------------------------------------------------------------------------- /tests/utils/test_arg_parser.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from utils.arg_parser import parse_and_validate_console_args 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ("args", "expected_config"), 11 | [ 12 | (["--config", "config1.json"], ["config1.json"]), 13 | (["--config", "config1.json", "config2.json"], ["config1.json", "config2.json"]), 14 | ], 15 | ) 16 | @patch("os.path.exists", return_value=True) # Mock os.path.exists to always return True 17 | def test_parse_and_validate_console_args_required(mock_exists, args, expected_config): 18 | with patch.object(sys, "argv", ["program_name", *args]): 19 | result = parse_and_validate_console_args() 20 | assert result.config == expected_config, f"Expected {expected_config}, got {result.config}" 21 | 22 | 23 | @patch("os.path.exists", return_value=True) 24 | def test_parse_and_validate_console_args_save_performance_results_exists(mock_exists): 25 | with patch.object( 26 | sys, 27 | "argv", 28 | ["program_name", "--config", "config.json", "--save_performance_results", "results.json"], 29 | ): 30 | result = parse_and_validate_console_args() 31 | assert result.save_performance_results == "results.json" 32 | 33 | 34 | def test_parse_and_validate_console_args_save_performance_results_dir_does_not_exist(): 35 | with ( 36 | patch("os.path.exists", side_effect=lambda path: path == "config.json"), 37 | patch.object( 38 | sys, 39 | "argv", 40 | ["program_name", "--config", "config.json", "--save_performance_results", "non_existent_dir/results.json"], 41 | ), 
42 | patch("utils.arg_parser.logging.error") as mock_log, 43 | ): 44 | with pytest.raises(RuntimeError, match="Argument validation failed."): 45 | parse_and_validate_console_args() 46 | mock_log.assert_called_once_with( 47 | "Validation failed: The directory for saving performance results does not exist: non_existent_dir", 48 | ) 49 | 50 | 51 | @patch("os.path.exists", return_value=True) 52 | def test_parse_and_validate_console_args_no_plot(mock_exists): 53 | with patch.object(sys, "argv", ["program_name", "--config", "config.json", "--no-plot"]): 54 | result = parse_and_validate_console_args() 55 | 56 | assert hasattr(result, "no_plot"), "The `no_plot` attribute is missing from the parsed result." 57 | assert result.no_plot is True, "The `no_plot` flag was not set to True." 58 | 59 | 60 | @patch("os.path.exists", return_value=True) 61 | def test_parse_and_validate_console_args_profile(mock_exists): 62 | with patch.object(sys, "argv", ["program_name", "--config", "config.json", "--profile"]): 63 | result = parse_and_validate_console_args() 64 | 65 | assert hasattr(result, "profile"), "The `profile` attribute is missing from the parsed result." 66 | assert result.profile is True, "The `profile` flag was not set to True." 67 | 68 | 69 | @patch("utils.arg_parser.logging.error") 70 | def test_parse_and_validate_console_args_argument_error(mock_log): 71 | with patch.object(sys, "argv", ["program_name", "--config"]): 72 | with pytest.raises(RuntimeError, match="Failed to parse arguments. 
Please check your inputs."): 73 | parse_and_validate_console_args() 74 | mock_log.assert_called_once_with("Argument parsing failed: 2") 75 | 76 | 77 | @patch("utils.arg_parser.logging.error") 78 | def test_parse_and_validate_console_args_unexpected_error(mock_log): 79 | with ( 80 | patch.object( 81 | sys, 82 | "argv", 83 | ["program_name", "--config", "config.json", "--save_performance_results", "results.json"], 84 | ), 85 | patch("os.path.exists", side_effect=Exception("Unexpected error")), 86 | ): 87 | with pytest.raises(RuntimeError, match="An unexpected error occurred during argument parsing."): 88 | parse_and_validate_console_args() 89 | mock_log.assert_any_call("An unexpected error occurred while parsing arguments: Unexpected error") 90 | -------------------------------------------------------------------------------- /core/validation/order_validator.py: -------------------------------------------------------------------------------- 1 | from .exceptions import ( 2 | InsufficientBalanceError, 3 | InsufficientCryptoBalanceError, 4 | InvalidOrderQuantityError, 5 | ) 6 | 7 | 8 | class OrderValidator: 9 | def __init__(self, tolerance: float = 1e-6, threshold_ratio: float = 0.5): 10 | """ 11 | Initializes the OrderValidator with a specified tolerance and threshold. 12 | 13 | Args: 14 | tolerance (float): Minimum precision tolerance for validation. 15 | threshold_ratio (float): Threshold below which an insufficient balance/crypto error is triggered early. 16 | """ 17 | self.tolerance = tolerance 18 | self.threshold_ratio = threshold_ratio 19 | 20 | def adjust_and_validate_buy_quantity(self, balance: float, order_quantity: float, price: float) -> float: 21 | """ 22 | Adjusts and validates the buy order quantity based on the available balance. 23 | 24 | Args: 25 | balance (float): Available fiat balance. 26 | order_quantity (float): Requested buy quantity. 27 | price (float): Price of the asset. 
28 | 29 | Returns: 30 | float: Adjusted and validated buy order quantity. 31 | 32 | Raises: 33 | InsufficientBalanceError: If the balance is insufficient to place any valid order. 34 | InvalidOrderQuantityError: If the adjusted quantity is invalid. 35 | """ 36 | total_cost = order_quantity * price 37 | 38 | if balance < total_cost * self.threshold_ratio: 39 | raise InsufficientBalanceError( 40 | f"Balance {balance:.2f} is far below the required cost {total_cost:.2f} " 41 | f"(threshold ratio: {self.threshold_ratio}).", 42 | ) 43 | 44 | if total_cost > balance: 45 | adjusted_quantity = max((balance - self.tolerance) / price, 0) 46 | 47 | if adjusted_quantity <= 0 or (adjusted_quantity * price) < self.tolerance: 48 | raise InsufficientBalanceError( 49 | f"Insufficient balance: {balance:.2f} to place any buy order at price {price:.2f}.", 50 | ) 51 | else: 52 | adjusted_quantity = order_quantity 53 | 54 | self._validate_quantity(adjusted_quantity, is_buy=True) 55 | return adjusted_quantity 56 | 57 | def adjust_and_validate_sell_quantity(self, crypto_balance: float, order_quantity: float) -> float: 58 | """ 59 | Adjusts and validates the sell order quantity based on the available crypto balance. 60 | 61 | Args: 62 | crypto_balance (float): Available crypto balance. 63 | order_quantity (float): Requested sell quantity. 64 | 65 | Returns: 66 | float: Adjusted and validated sell order quantity. 67 | 68 | Raises: 69 | InsufficientCryptoBalanceError: If the crypto balance is insufficient to place any valid order. 70 | InvalidOrderQuantityError: If the adjusted quantity is invalid. 
71 | """ 72 | if crypto_balance < order_quantity * self.threshold_ratio: 73 | raise InsufficientCryptoBalanceError( 74 | f"Crypto balance {crypto_balance:.6f} is far below the required quantity {order_quantity:.6f} " 75 | f"(threshold ratio: {self.threshold_ratio}).", 76 | ) 77 | 78 | adjusted_quantity = min(order_quantity, crypto_balance - self.tolerance) 79 | 80 | # A dust balance (<= tolerance) leaves nothing to sell: raise the documented balance error, 81 | # mirroring the buy path, instead of "Invalid sell quantity" with a negative amount. 82 | if order_quantity > 0 and adjusted_quantity <= 0: 83 | raise InsufficientCryptoBalanceError( 84 | f"Insufficient crypto balance: {crypto_balance:.6f} to place any sell order.", 85 | ) 86 | 87 | self._validate_quantity(adjusted_quantity, is_buy=False) 88 | return adjusted_quantity 89 | 90 | def _validate_quantity(self, quantity: float, is_buy: bool) -> None: 91 | """ 92 | Validates the adjusted order quantity. 93 | 94 | Args: 95 | quantity (float): Adjusted quantity to validate. 96 | is_buy (bool): Whether the order is a buy order. 97 | 98 | Raises: 99 | InvalidOrderQuantityError: If the quantity is invalid. 100 | """ 101 | if quantity <= 0: 102 | order_type = "buy" if is_buy else "sell" 103 | raise InvalidOrderQuantityError(f"Invalid {order_type} quantity: {quantity:.6f}") 104 | -------------------------------------------------------------------------------- /tests/validation/test_order_validator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from core.validation.exceptions import ( 4 | InsufficientBalanceError, 5 | InsufficientCryptoBalanceError, 6 | InvalidOrderQuantityError, 7 | ) 8 | from core.validation.order_validator import OrderValidator 9 | 10 | 11 | class TestOrderValidator: 12 | @pytest.fixture 13 | def validator(self): 14 | return OrderValidator() 15 | 16 | def test_adjust_and_validate_buy_quantity_valid(self, validator): 17 | balance = 5000 18 | order_quantity = 1 19 | price = 3000 20 | 21 | adjusted_quantity = validator.adjust_and_validate_buy_quantity(balance, order_quantity, price) 22 | assert adjusted_quantity == order_quantity 23 | 24 | def test_adjust_and_validate_buy_quantity_adjusted(self, validator): 25 | balance = 2000 # Insufficient for full quantity 26 | order_quantity = 1 27 | price = 3000 28 | 29 | 
adjusted_quantity = validator.adjust_and_validate_buy_quantity(balance, order_quantity, price) 30 | expected_quantity = (balance - validator.tolerance) / price 31 | assert adjusted_quantity == pytest.approx(expected_quantity, rel=1e-6) 32 | 33 | def test_adjust_and_validate_buy_quantity_insufficient_balance(self, validator): 34 | balance = 10 # Far below required cost 35 | order_quantity = 1 36 | price = 3000 37 | 38 | with pytest.raises(InsufficientBalanceError, match="far below the required cost"): 39 | validator.adjust_and_validate_buy_quantity(balance, order_quantity, price) 40 | 41 | def test_adjust_and_validate_buy_quantity_invalid_quantity(self, validator): 42 | balance = 5000 43 | order_quantity = -1 # Invalid quantity 44 | price = 3000 45 | 46 | with pytest.raises(InvalidOrderQuantityError, match="Invalid buy quantity"): 47 | validator.adjust_and_validate_buy_quantity(balance, order_quantity, price) 48 | 49 | def test_adjust_and_validate_sell_quantity_valid(self, validator): 50 | crypto_balance = 5 51 | order_quantity = 3 52 | 53 | adjusted_quantity = validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity) 54 | assert adjusted_quantity == order_quantity 55 | 56 | def test_adjust_and_validate_sell_quantity_adjusted(self, validator): 57 | crypto_balance = 2 # Insufficient for full quantity 58 | order_quantity = 3 59 | 60 | adjusted_quantity = validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity) 61 | expected_quantity = crypto_balance - validator.tolerance 62 | assert adjusted_quantity == pytest.approx(expected_quantity, rel=1e-6) 63 | 64 | def test_adjust_and_validate_sell_quantity_insufficient_balance(self, validator): 65 | crypto_balance = 0.1 # Far below required amount 66 | order_quantity = 3 67 | 68 | with pytest.raises(InsufficientCryptoBalanceError, match="far below the required quantity"): 69 | validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity) 70 | 71 | def 
test_adjust_and_validate_sell_quantity_invalid_quantity(self, validator): 72 | crypto_balance = 5 73 | order_quantity = -3 # Invalid quantity 74 | 75 | with pytest.raises(InvalidOrderQuantityError, match="Invalid sell quantity"): 76 | validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity) 77 | 78 | def test_adjust_and_validate_buy_quantity_tolerance_threshold(self, validator): 79 | balance = 1400 # Just above tolerance threshold 80 | order_quantity = 1 81 | price = 3000 82 | 83 | with pytest.raises(InsufficientBalanceError, match="Balance .* is far below the required cost .*"): 84 | validator.adjust_and_validate_buy_quantity(balance, order_quantity, price) 85 | 86 | def test_adjust_and_validate_sell_quantity_tolerance_threshold(self, validator): 87 | crypto_balance = 0.001 # Just above tolerance threshold 88 | order_quantity = 3 89 | 90 | with pytest.raises( 91 | InsufficientCryptoBalanceError, 92 | match="Crypto balance .* is far below the required quantity .*", 93 | ): 94 | validator.adjust_and_validate_sell_quantity(crypto_balance, order_quantity) 95 | -------------------------------------------------------------------------------- /tests/order_handling/test_order.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from core.order_handling.order import Order, OrderSide, OrderStatus, OrderType 4 | 5 | 6 | class TestOrder: 7 | @pytest.fixture 8 | def sample_order(self): 9 | return Order( 10 | identifier="123", 11 | status=OrderStatus.OPEN, 12 | order_type=OrderType.LIMIT, 13 | side=OrderSide.BUY, 14 | price=1000.0, 15 | average=None, 16 | amount=5.0, 17 | filled=0.0, 18 | remaining=5.0, 19 | timestamp=1695890800, 20 | datetime="2024-01-01T00:00:00Z", 21 | last_trade_timestamp=None, 22 | symbol="BTC/USDT", 23 | time_in_force="GTC", 24 | ) 25 | 26 | def test_create_order_with_valid_data(self, sample_order): 27 | assert sample_order.identifier == "123" 28 | assert sample_order.status == 
OrderStatus.OPEN 29 | assert sample_order.order_type == OrderType.LIMIT 30 | assert sample_order.side == OrderSide.BUY 31 | assert sample_order.price == 1000.0 32 | assert sample_order.amount == 5.0 33 | assert sample_order.filled == 0.0 34 | assert sample_order.remaining == 5.0 35 | assert sample_order.timestamp == 1695890800 36 | assert sample_order.datetime == "2024-01-01T00:00:00Z" 37 | assert sample_order.symbol == "BTC/USDT" 38 | assert sample_order.time_in_force == "GTC" 39 | 40 | def test_is_filled(self, sample_order): 41 | sample_order.status = OrderStatus.CLOSED 42 | assert sample_order.is_filled() is True 43 | sample_order.status = OrderStatus.OPEN 44 | assert sample_order.is_filled() is False 45 | 46 | def test_is_canceled(self, sample_order): 47 | sample_order.status = OrderStatus.CANCELED 48 | assert sample_order.is_canceled() is True 49 | sample_order.status = OrderStatus.OPEN 50 | assert sample_order.is_canceled() is False 51 | 52 | def test_is_open(self, sample_order): 53 | sample_order.status = OrderStatus.OPEN 54 | assert sample_order.is_open() is True 55 | sample_order.status = OrderStatus.CLOSED 56 | assert sample_order.is_open() is False 57 | 58 | def test_format_last_trade_timestamp(self, sample_order): 59 | # Case 1: No last trade timestamp 60 | assert sample_order.format_last_trade_timestamp() is None 61 | 62 | # Case 2: Valid last trade timestamp 63 | sample_order.last_trade_timestamp = 1695890800 64 | assert sample_order.format_last_trade_timestamp() == "2023-09-28T08:46:40" 65 | 66 | def test_order_str_representation(self, sample_order): 67 | order_str = str(sample_order) 68 | assert "Order(id=123, status=OrderStatus.OPEN" in order_str 69 | assert "type=OrderType.LIMIT, side=OrderSide.BUY, price=1000.0" in order_str 70 | 71 | def test_order_repr_representation(self, sample_order): 72 | order_repr = repr(sample_order) 73 | assert order_repr == str(sample_order) 74 | 75 | def test_order_with_trades_and_fee(self): 76 | order = Order( 77 | 
identifier="456", 78 | status=OrderStatus.CLOSED, 79 | order_type=OrderType.LIMIT, 80 | side=OrderSide.SELL, 81 | price=2000.0, 82 | average=1950.0, 83 | amount=3.0, 84 | filled=3.0, 85 | remaining=0.0, 86 | timestamp=1695890800, 87 | datetime="2024-01-01T00:00:00Z", 88 | last_trade_timestamp=1695890900, 89 | symbol="ETH/USDT", 90 | time_in_force="GTC", 91 | trades=[ 92 | {"id": "trade1", "price": 1950.0, "amount": 1.0}, 93 | {"id": "trade2", "price": 1950.0, "amount": 2.0}, 94 | ], 95 | fee={"currency": "USDT", "cost": 5.0}, 96 | cost=5850.0, 97 | ) 98 | assert order.is_filled() is True 99 | assert order.fee == {"currency": "USDT", "cost": 5.0} 100 | assert order.trades == [ 101 | {"id": "trade1", "price": 1950.0, "amount": 1.0}, 102 | {"id": "trade2", "price": 1950.0, "amount": 2.0}, 103 | ] 104 | assert order.cost == 5850.0 105 | -------------------------------------------------------------------------------- /tests/grid_management/test_grid_level.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | 5 | from core.grid_management.grid_level import GridCycleState, GridLevel 6 | from core.order_handling.order import Order 7 | 8 | 9 | class TestGridLevel: 10 | @pytest.fixture 11 | def grid_level(self): 12 | return GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 13 | 14 | def test_grid_level_initialization(self): 15 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 16 | 17 | assert grid_level.price == 1000 18 | assert grid_level.state == GridCycleState.READY_TO_BUY 19 | assert grid_level.orders == [] 20 | assert grid_level.paired_buy_level is None 21 | assert grid_level.paired_sell_level is None 22 | 23 | def test_add_order(self, grid_level): 24 | mock_order = Mock(spec=Order) 25 | grid_level.add_order(mock_order) 26 | 27 | assert len(grid_level.orders) == 1 28 | assert grid_level.orders[0] == mock_order 29 | 30 | def 
test_str_representation(self): 31 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 32 | grid_level.paired_buy_level = GridLevel(price=900, state=GridCycleState.READY_TO_SELL) 33 | grid_level.paired_sell_level = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY) 34 | 35 | repr_str = str(grid_level) 36 | assert "price=1000" in repr_str 37 | assert "state=READY_TO_BUY" in repr_str 38 | assert "paired_buy_level=900" in repr_str 39 | assert "paired_sell_level=1100" in repr_str 40 | 41 | def test_update_paired_levels(self): 42 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 43 | paired_buy_level = GridLevel(price=900, state=GridCycleState.READY_TO_SELL) 44 | paired_sell_level = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY) 45 | 46 | grid_level.paired_buy_level = paired_buy_level 47 | grid_level.paired_sell_level = paired_sell_level 48 | 49 | assert grid_level.paired_buy_level.price == 900 50 | assert grid_level.paired_sell_level.price == 1100 51 | 52 | def test_state_transition_to_waiting_for_buy_fill(self, grid_level): 53 | grid_level.state = GridCycleState.READY_TO_BUY 54 | mock_order = Mock(spec=Order) 55 | grid_level.add_order(mock_order) 56 | 57 | grid_level.state = GridCycleState.WAITING_FOR_BUY_FILL 58 | assert grid_level.state == GridCycleState.WAITING_FOR_BUY_FILL 59 | 60 | def test_state_transition_to_ready_to_sell(self, grid_level): 61 | grid_level.state = GridCycleState.READY_TO_BUY 62 | mock_order = Mock(spec=Order) 63 | grid_level.add_order(mock_order) 64 | 65 | grid_level.state = GridCycleState.READY_TO_SELL 66 | assert grid_level.state == GridCycleState.READY_TO_SELL 67 | 68 | def test_state_transition_to_waiting_for_sell_fill(self): 69 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_SELL) 70 | mock_order = Mock(spec=Order) 71 | grid_level.add_order(mock_order) 72 | 73 | grid_level.state = GridCycleState.WAITING_FOR_SELL_FILL 74 | assert grid_level.state == 
GridCycleState.WAITING_FOR_SELL_FILL 75 | 76 | def test_state_transition_to_ready_to_buy_or_sell(self): 77 | grid_level = GridLevel(price=1000, state=GridCycleState.WAITING_FOR_SELL_FILL) 78 | 79 | grid_level.state = GridCycleState.READY_TO_BUY_OR_SELL 80 | assert grid_level.state == GridCycleState.READY_TO_BUY_OR_SELL 81 | 82 | def test_paired_levels_initialization(self): 83 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 84 | paired_buy_level = GridLevel(price=900, state=GridCycleState.READY_TO_SELL) 85 | paired_sell_level = GridLevel(price=1100, state=GridCycleState.READY_TO_BUY) 86 | 87 | grid_level.paired_buy_level = paired_buy_level 88 | grid_level.paired_sell_level = paired_sell_level 89 | 90 | assert grid_level.paired_buy_level == paired_buy_level 91 | assert grid_level.paired_sell_level == paired_sell_level 92 | 93 | def test_orders_list(self): 94 | grid_level = GridLevel(price=1000, state=GridCycleState.READY_TO_BUY) 95 | order1 = Mock(spec=Order) 96 | order2 = Mock(spec=Order) 97 | 98 | grid_level.add_order(order1) 99 | grid_level.add_order(order2) 100 | 101 | assert len(grid_level.orders) == 2 102 | assert grid_level.orders[0] == order1 103 | assert grid_level.orders[1] == order2 104 | -------------------------------------------------------------------------------- /core/bot_management/bot_controller/bot_controller.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from tabulate import tabulate 5 | 6 | from core.bot_management.event_bus import EventBus, Events 7 | from core.bot_management.grid_trading_bot import GridTradingBot 8 | 9 | from .exceptions import CommandParsingError, StrategyControlError 10 | 11 | 12 | class BotController: 13 | """ 14 | Handles user commands and manages the lifecycle of the GridTradingBot. 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | bot: GridTradingBot, 20 | event_bus: EventBus, 21 | ): 22 | """ 23 | Initializes the BotController. 24 | 25 | Args: 26 | bot: The GridTradingBot instance to control. 27 | event_bus: The EventBus instance to publish/subscribe Events. 28 | """ 29 | self.logger = logging.getLogger(self.__class__.__name__) 30 | self.bot = bot 31 | self.event_bus = event_bus 32 | self._stop_listening = False 33 | self.event_bus.subscribe(Events.STOP_BOT, self._handle_stop_event) 34 | 35 | async def command_listener(self): 36 | """ 37 | Listens for user commands and processes them. 38 | """ 39 | self.logger.info("Command listener started. Type 'quit' to exit.") 40 | loop = asyncio.get_event_loop() 41 | 42 | while not self._stop_listening: 43 | try: 44 | command = await loop.run_in_executor( 45 | None, 46 | input, 47 | "Enter command (quit, orders, balance, stop, restart, pause): ", 48 | ) 49 | await self._handle_command(command.strip().lower()) 50 | 51 | except CommandParsingError as e: 52 | self.logger.warning(f"Command error: {e}") 53 | 54 | except Exception as e: 55 | self.logger.error(f"Unexpected error in command listener: {e}", exc_info=True) 56 | 57 | async def _handle_command(self, command: str): 58 | """ 59 | Handles individual commands from the user. 60 | 61 | Args: 62 | command: The command entered by the user. 
63 | """ 64 | if command == "quit": 65 | self.logger.info("Stop bot command received") 66 | self.event_bus.publish_sync(Events.STOP_BOT, "User requested shutdown") 67 | 68 | elif command == "orders": 69 | await self._display_orders() 70 | 71 | elif command == "balance": 72 | await self._display_balance() 73 | 74 | elif command == "stop": 75 | self.event_bus.publish_sync(Events.STOP_BOT, "User issued stop command") 76 | 77 | elif command == "restart": 78 | self.event_bus.publish_sync(Events.STOP_BOT, "User issued restart command") 79 | self.event_bus.publish_sync(Events.START_BOT, "User issued restart command") 80 | 81 | elif command.startswith("pause"): 82 | await self._pause_bot(command) 83 | 84 | else: 85 | raise CommandParsingError(f"Unknown command: {command}") 86 | 87 | def _stop_listener(self): 88 | """ 89 | Stops the command listener loop. 90 | """ 91 | self._stop_listening = True 92 | self.logger.info("Command listener stopped.") 93 | 94 | def _handle_stop_event(self, reason: str) -> None: 95 | """ 96 | Handles the STOP_BOT event and stops the command listener. 97 | 98 | Args: 99 | reason: The reason for stopping the bot. 100 | """ 101 | self.logger.info(f"Received STOP_BOT event: {reason}") 102 | self._stop_listener() 103 | 104 | async def _display_orders(self): 105 | """ 106 | Displays formatted orders retrieved from the bot. 107 | """ 108 | self.logger.info("Display orders bot command received") 109 | formatted_orders = self.bot.strategy.get_formatted_orders() 110 | orders_table = tabulate( 111 | formatted_orders, 112 | headers=["Order Side", "Type", "Status", "Price", "Quantity", "Timestamp", "Grid Level", "Slippage"], 113 | tablefmt="pipe", 114 | ) 115 | self.logger.info("\nFormatted Orders:\n" + orders_table) 116 | 117 | async def _display_balance(self): 118 | """ 119 | Displays the current balances retrieved from the bot. 
120 | """ 121 | self.logger.info("Display balance bot command received") 122 | current_balances = self.bot.get_balances() 123 | self.logger.info(f"Current balances: {current_balances}") 124 | 125 | async def _pause_bot(self, command: str): 126 | """ 127 | Pauses the bot for a specified duration. 128 | 129 | Args: 130 | command: The pause command containing the duration. 131 | """ 132 | try: 133 | self.logger.info("Pause bot command received") 134 | duration = int(command.split()[1]) 135 | await self.event_bus.publish(Events.STOP_BOT, "User issued pause command") 136 | self.logger.info(f"Bot paused for {duration} seconds.") 137 | await asyncio.sleep(duration) 138 | self.logger.info("Resuming bot after pause.") 139 | await self.event_bus.publish(Events.START_BOT, "Resuming bot after pause") 140 | 141 | except ValueError: 142 | raise CommandParsingError("Invalid pause duration. Please specify in seconds.") from None 143 | 144 | except Exception as e: 145 | raise StrategyControlError(f"Error during pause operation: {e}") from e 146 | -------------------------------------------------------------------------------- /core/bot_management/event_bus.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from collections.abc import Awaitable, Callable 3 | import inspect 4 | import logging 5 | from typing import Any 6 | 7 | 8 | class Events: 9 | """ 10 | Defines event types for the EventBus. 11 | """ 12 | 13 | ORDER_FILLED = "order_filled" 14 | ORDER_CANCELLED = "order_cancelled" 15 | START_BOT = "start_bot" 16 | STOP_BOT = "stop_bot" 17 | 18 | 19 | class EventBus: 20 | """ 21 | A simple event bus for managing pub-sub interactions with support for both sync and async publishing. 22 | """ 23 | 24 | def __init__(self): 25 | """ 26 | Initializes the EventBus with an empty subscriber list. 
27 | """ 28 | self.logger = logging.getLogger(self.__class__.__name__) 29 | self.subscribers: dict[str, list[Callable[[Any], None]]] = {} 30 | self._tasks: set[asyncio.Task] = set() 31 | 32 | def subscribe( 33 | self, 34 | event_type: str, 35 | callback: Callable[[Any], None] | Callable[[Any], Awaitable[None]], 36 | ) -> None: 37 | """ 38 | Subscribes a callback to a specific event type. 39 | 40 | Args: 41 | event_type: The type of event to subscribe to. 42 | callback: The callback function to invoke when the event is published. 43 | """ 44 | if event_type not in self.subscribers: 45 | self.subscribers[event_type] = [] 46 | 47 | self.subscribers[event_type].append(callback) 48 | callback_name = getattr(callback, "__name__", str(callback)) 49 | caller_frame = inspect.stack()[1] 50 | caller_name = f"{caller_frame.function} (from {caller_frame.filename}:{caller_frame.lineno})" 51 | self.logger.info(f"Callback '{callback_name}' subscribed to event: {event_type} by {caller_name}") 52 | 53 | async def publish( 54 | self, 55 | event_type: str, 56 | data: Any = None, 57 | ) -> None: 58 | """ 59 | Publishes an event asynchronously to all subscribers. 
60 | """ 61 | if event_type not in self.subscribers: 62 | self.logger.warning(f"No subscribers for event: {event_type}") 63 | return 64 | 65 | self.logger.info(f"Publishing async event: {event_type} with data: {data}") 66 | tasks = [ 67 | self._safe_invoke_async(callback, data) 68 | if asyncio.iscoroutinefunction(callback) 69 | else asyncio.to_thread(self._safe_invoke_sync, callback, data) 70 | for callback in self.subscribers[event_type] 71 | ] 72 | if tasks: 73 | results = await asyncio.gather(*tasks, return_exceptions=True) 74 | for result in results: 75 | if isinstance(result, Exception): 76 | self.logger.error(f"Exception in async event callback: {result}", exc_info=True) 77 | 78 | def publish_sync( 79 | self, 80 | event_type: str, 81 | data: Any, 82 | ) -> None: 83 | """ 84 | Publishes an event synchronously to all subscribers. 85 | """ 86 | if event_type in self.subscribers: 87 | self.logger.info(f"Publishing sync event: {event_type} with data: {data}") 88 | loop = asyncio.get_event_loop() 89 | for callback in self.subscribers[event_type]: 90 | if asyncio.iscoroutinefunction(callback): 91 | asyncio.run_coroutine_threadsafe(self._safe_invoke_async(callback, data), loop) 92 | else: 93 | self._safe_invoke_sync(callback, data) 94 | 95 | async def _safe_invoke_async( 96 | self, 97 | callback: Callable[[Any], None], 98 | data: Any, 99 | ) -> None: 100 | """ 101 | Safely invokes an async callback, suppressing and logging any exceptions. 
102 | """ 103 | task = asyncio.create_task(self._invoke_callback(callback, data)) 104 | self._tasks.add(task) 105 | 106 | def remove_task(completed_task: asyncio.Task): 107 | if not completed_task.cancelled(): 108 | self._tasks.discard(completed_task) 109 | 110 | task.add_done_callback(remove_task) 111 | self.logger.debug(f"Task created for callback '{callback.__name__}' with data: {data}") 112 | 113 | async def _invoke_callback( 114 | self, 115 | callback: Callable[[Any], None], 116 | data: Any, 117 | ) -> None: 118 | try: 119 | self.logger.info(f"Executing async callback '{callback.__name__}' for event with data: {data}") 120 | await callback(data) 121 | except Exception as e: 122 | self.logger.error(f"Error in async callback '{callback.__name__}': {e}", exc_info=True) 123 | 124 | def _safe_invoke_sync( 125 | self, 126 | callback: Callable[[Any], None], 127 | data: Any, 128 | ) -> None: 129 | """ 130 | Safely invokes a sync callback, suppressing and logging any exceptions. 131 | """ 132 | try: 133 | callback(data) 134 | except Exception as e: 135 | self.logger.error(f"Error in sync subscriber callback: {e}", exc_info=True) 136 | 137 | async def shutdown(self): 138 | """ 139 | Cancels all active tasks tracked by the EventBus for graceful shutdown. 
140 | """ 141 | self.logger.info("Shutting down EventBus...") 142 | for task in self._tasks: 143 | task.cancel() 144 | await asyncio.gather(*self._tasks, return_exceptions=True) 145 | self._tasks.clear() 146 | self.logger.info("EventBus shutdown complete.") 147 | -------------------------------------------------------------------------------- /tests/order_handling/test_order_book.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | 5 | from core.grid_management.grid_level import GridLevel 6 | from core.order_handling.order import Order, OrderSide, OrderStatus 7 | from core.order_handling.order_book import OrderBook 8 | 9 | 10 | class TestOrderBook: 11 | @pytest.fixture 12 | def setup_order_book(self): 13 | return OrderBook() 14 | 15 | def test_add_order_with_grid(self, setup_order_book): 16 | order_book = setup_order_book 17 | buy_order = Mock(spec=Order, side=OrderSide.BUY) 18 | sell_order = Mock(spec=Order, side=OrderSide.SELL) 19 | grid_level = Mock(spec=GridLevel) 20 | 21 | order_book.add_order(buy_order, grid_level) 22 | order_book.add_order(sell_order, grid_level) 23 | 24 | assert len(order_book.buy_orders) == 1 25 | assert len(order_book.sell_orders) == 1 26 | assert order_book.order_to_grid_map[buy_order] == grid_level 27 | assert order_book.order_to_grid_map[sell_order] == grid_level 28 | 29 | def test_add_order_without_grid(self, setup_order_book): 30 | order_book = setup_order_book 31 | non_grid_order = Mock(spec=Order, side=OrderSide.SELL) 32 | 33 | order_book.add_order(non_grid_order) 34 | 35 | assert len(order_book.non_grid_orders) == 1 36 | assert order_book.non_grid_orders[0] == non_grid_order 37 | 38 | def test_get_buy_orders_with_grid(self, setup_order_book): 39 | order_book = setup_order_book 40 | buy_order = Mock(spec=Order, side=OrderSide.BUY) 41 | grid_level = Mock(spec=GridLevel) 42 | 43 | order_book.add_order(buy_order, grid_level) 44 | result = 
order_book.get_buy_orders_with_grid() 45 | 46 | assert len(result) == 1 47 | assert result[0] == (buy_order, grid_level) 48 | 49 | def test_get_sell_orders_with_grid(self, setup_order_book): 50 | order_book = setup_order_book 51 | sell_order = Mock(spec=Order, side=OrderSide.SELL) 52 | grid_level = Mock(spec=GridLevel) 53 | 54 | order_book.add_order(sell_order, grid_level) 55 | result = order_book.get_sell_orders_with_grid() 56 | 57 | assert len(result) == 1 58 | assert result[0] == (sell_order, grid_level) 59 | 60 | def test_get_all_buy_orders(self, setup_order_book): 61 | order_book = setup_order_book 62 | buy_order_1 = Mock(spec=Order, side=OrderSide.BUY) 63 | buy_order_2 = Mock(spec=Order, side=OrderSide.BUY) 64 | 65 | order_book.add_order(buy_order_1) 66 | order_book.add_order(buy_order_2) 67 | result = order_book.get_all_buy_orders() 68 | 69 | assert len(result) == 2 70 | assert buy_order_1 in result 71 | assert buy_order_2 in result 72 | 73 | def test_get_all_sell_orders(self, setup_order_book): 74 | order_book = setup_order_book 75 | sell_order_1 = Mock(spec=Order, side=OrderSide.SELL) 76 | sell_order_2 = Mock(spec=Order, side=OrderSide.SELL) 77 | 78 | order_book.add_order(sell_order_1) 79 | order_book.add_order(sell_order_2) 80 | result = order_book.get_all_sell_orders() 81 | 82 | assert len(result) == 2 83 | assert sell_order_1 in result 84 | assert sell_order_2 in result 85 | 86 | def test_get_open_orders(self, setup_order_book): 87 | order_book = setup_order_book 88 | open_order = Mock(spec=Order, side=OrderSide.BUY, is_open=Mock(return_value=True)) 89 | closed_order = Mock(spec=Order, side=OrderSide.SELL, is_open=Mock(return_value=False)) 90 | 91 | order_book.add_order(open_order) 92 | order_book.add_order(closed_order) 93 | result = order_book.get_open_orders() 94 | 95 | assert len(result) == 1 96 | assert open_order in result 97 | 98 | def test_get_completed_orders(self, setup_order_book): 99 | order_book = setup_order_book 100 | completed_order = 
Mock(spec=Order, side=OrderSide.BUY, is_filled=Mock(return_value=True)) 101 | pending_order = Mock(spec=Order, side=OrderSide.BUY, is_filled=Mock(return_value=False)) 102 | 103 | order_book.add_order(completed_order) 104 | order_book.add_order(pending_order) 105 | result = order_book.get_completed_orders() 106 | 107 | assert len(result) == 1 108 | assert completed_order in result 109 | 110 | def test_get_grid_level_for_order(self, setup_order_book): 111 | order_book = setup_order_book 112 | order = Mock(spec=Order, side=OrderSide.BUY) 113 | grid_level = Mock(spec=GridLevel) 114 | 115 | order_book.add_order(order, grid_level) 116 | result = order_book.get_grid_level_for_order(order) 117 | 118 | assert result == grid_level 119 | 120 | def test_update_order_status(self, setup_order_book): 121 | order_book = setup_order_book 122 | order = Mock(spec=Order, identifier="order_123", side=OrderSide.BUY, status=OrderStatus.OPEN) 123 | 124 | order_book.add_order(order) 125 | order_book.update_order_status("order_123", OrderStatus.CLOSED) 126 | 127 | assert order.status == OrderStatus.CLOSED 128 | 129 | def test_update_order_status_nonexistent_order(self, setup_order_book): 130 | order_book = setup_order_book 131 | order = Mock(spec=Order, identifier="order_123", status=OrderStatus.OPEN) 132 | order.side = OrderSide.BUY 133 | 134 | order_book.add_order(order) 135 | order_book.update_order_status("nonexistent_order", OrderStatus.CLOSED) 136 | 137 | assert order.status == OrderStatus.OPEN # Ensure no changes for non-existent orders 138 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible 
disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement by emailing 63 | <tetej171@gmail.com>. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior.
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import cProfile 3 | import logging 4 | import os 5 | from typing import Any 6 | 7 | from dotenv import load_dotenv 8 | 9 | from config.config_manager import ConfigManager 10 | from config.config_validator import ConfigValidator 11 | from config.exceptions import ConfigError 12 | from config.trading_mode import TradingMode 13 | from core.bot_management.bot_controller.bot_controller import BotController 14 | from core.bot_management.event_bus import EventBus 15 | from core.bot_management.grid_trading_bot import GridTradingBot 16 | from core.bot_management.health_check import HealthCheck 17 | from core.bot_management.notification.notification_handler import NotificationHandler 18 | from utils.arg_parser import parse_and_validate_console_args 19 | from utils.config_name_generator import generate_config_name 20 | from utils.logging_config import setup_logging 21 | from utils.performance_results_saver import save_or_append_performance_results 22 | 23 | 24 | def initialize_config(config_path: str) -> ConfigManager: 25 | load_dotenv() 26 | try: 27 | return ConfigManager(config_path, ConfigValidator()) 28 | 29 | except ConfigError as e: 30 | logging.error(f"An error occured during the initialization of ConfigManager {e}") 31 | exit(1) 32 | 33 | 34 | def initialize_notification_handler(config_manager: ConfigManager, event_bus: EventBus) -> NotificationHandler: 35 | notification_urls = os.getenv("APPRISE_NOTIFICATION_URLS", "").split(",") 36 | trading_mode = config_manager.get_trading_mode() 37 | return NotificationHandler(event_bus, notification_urls, trading_mode) 38 | 39 | 40 | async def run_bot( 41 | config_path: str, 42 | profile: bool = False, 43 | save_performance_results_path: str | None = None, 44 | no_plot: bool = False, 45 | ) -> dict[str, Any] | None: 46 | 
config_manager = initialize_config(config_path) 47 | config_name = generate_config_name(config_manager) 48 | setup_logging(config_manager.get_logging_level(), config_manager.should_log_to_file(), config_name) 49 | event_bus = EventBus() 50 | notification_handler = initialize_notification_handler(config_manager, event_bus) 51 | bot = GridTradingBot( 52 | config_path, 53 | config_manager, 54 | notification_handler, 55 | event_bus, 56 | save_performance_results_path, 57 | no_plot, 58 | ) 59 | bot_controller = BotController(bot, event_bus) 60 | health_check = HealthCheck(bot, notification_handler, event_bus) 61 | 62 | if profile: 63 | cProfile.runctx("asyncio.run(bot.run())", globals(), locals(), "profile_results.prof") 64 | return None 65 | 66 | try: 67 | if bot.trading_mode in {TradingMode.LIVE, TradingMode.PAPER_TRADING}: 68 | bot_task = asyncio.create_task(bot.run(), name="BotTask") 69 | bot_controller_task = asyncio.create_task(bot_controller.command_listener(), name="BotControllerTask") 70 | health_check_task = asyncio.create_task(health_check.start(), name="HealthCheckTask") 71 | await asyncio.gather(bot_task, bot_controller_task, health_check_task) 72 | else: 73 | await bot.run() 74 | 75 | except asyncio.CancelledError: 76 | logging.info("Cancellation received. 
Shutting down gracefully.") 77 | 78 | except Exception as e: 79 | logging.error(f"An unexpected error occurred: {e}", exc_info=True) 80 | 81 | finally: 82 | try: 83 | await event_bus.shutdown() 84 | 85 | except Exception as e: 86 | logging.error(f"Error during EventBus shutdown: {e}", exc_info=True) 87 | 88 | 89 | async def cleanup_tasks(): 90 | logging.info("Shutting down bot and cleaning up tasks...") 91 | 92 | current_task = asyncio.current_task() 93 | tasks_to_cancel = { 94 | task for task in asyncio.all_tasks() if task is not current_task and not task.done() and not task.cancelled() 95 | } 96 | 97 | logging.info(f"Tasks to cancel: {len(tasks_to_cancel)}") 98 | 99 | for task in tasks_to_cancel: 100 | logging.info(f"Task to cancel: {task} - Done: {task.done()} - Cancelled: {task.cancelled()}") 101 | task.cancel() 102 | 103 | try: 104 | await asyncio.gather(*tasks_to_cancel, return_exceptions=True) 105 | 106 | except asyncio.CancelledError: 107 | logging.info("Tasks cancelled successfully.") 108 | 109 | except Exception as e: 110 | logging.error(f"Error during task cancellation: {e}", exc_info=True) 111 | 112 | 113 | if __name__ == "__main__": 114 | args = parse_and_validate_console_args() 115 | 116 | async def main(): 117 | try: 118 | tasks = [ 119 | run_bot(config_path, args.profile, args.save_performance_results, args.no_plot) 120 | for config_path in args.config 121 | ] 122 | 123 | results = await asyncio.gather(*tasks, return_exceptions=True) 124 | 125 | for index, result in enumerate(results): 126 | if isinstance(result, Exception): 127 | logging.error( 128 | f"Error occurred while running bot for config {args.config[index]}: {result}", 129 | exc_info=True, 130 | ) 131 | else: 132 | if args.save_performance_results: 133 | save_or_append_performance_results(result, args.save_performance_results) 134 | 135 | except Exception as e: 136 | logging.error(f"Critical error in main: {e}", exc_info=True) 137 | 138 | finally: 139 | await cleanup_tasks() 140 | 
logging.info("All tasks have completed.") 141 | 142 | asyncio.run(main()) 143 | -------------------------------------------------------------------------------- /strategies/plotter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import plotly.graph_objects as go 3 | from plotly.subplots import make_subplots 4 | 5 | from core.grid_management.grid_manager import GridManager 6 | from core.order_handling.order import Order, OrderSide 7 | from core.order_handling.order_book import OrderBook 8 | 9 | 10 | class Plotter: 11 | def __init__( 12 | self, 13 | grid_manager: GridManager, 14 | order_book: OrderBook, 15 | ): 16 | self.grid_manager = grid_manager 17 | self.order_book = order_book 18 | 19 | def plot_results( 20 | self, 21 | data: pd.DataFrame, 22 | ) -> None: 23 | fig = make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02) 24 | self._add_candlestick_trace(fig, data) 25 | trigger_price = self.grid_manager.get_trigger_price() 26 | self._add_trigger_price_line(fig, trigger_price) 27 | self._add_grid_lines(fig, self.grid_manager.price_grids, self.grid_manager.central_price) 28 | self._add_trade_markers(fig, self.order_book.get_completed_orders()) 29 | self._add_volume_trace(fig, data) 30 | self._add_account_value_trace(fig, data) 31 | 32 | fig.update_layout( 33 | title="Grid Trading Strategy Results", 34 | yaxis_title="Price (USDT)", 35 | yaxis2_title="Volume", 36 | yaxis3_title="Equity", 37 | xaxis={"rangeslider": {"visible": False}}, 38 | showlegend=False, 39 | ) 40 | fig.show() 41 | 42 | def _add_candlestick_trace( 43 | self, 44 | fig: go.Figure, 45 | data: pd.DataFrame, 46 | ) -> None: 47 | fig.add_trace( 48 | go.Candlestick( 49 | x=data.index, 50 | open=data["open"], 51 | high=data["high"], 52 | low=data["low"], 53 | close=data["close"], 54 | name="", 55 | ), 56 | row=1, 57 | col=1, 58 | ) 59 | 60 | def _add_trigger_price_line( 61 | self, 62 | fig: 
go.Figure, 63 | trigger_price: float, 64 | ): 65 | fig.add_trace( 66 | go.Scatter( 67 | x=[fig.data[0].x[0], fig.data[0].x[-1]], 68 | y=[trigger_price, trigger_price], 69 | mode="lines", 70 | line={"color": "blue", "width": 2, "dash": "dash"}, 71 | name="Central Price", 72 | ), 73 | ) 74 | fig.add_annotation( 75 | x=fig.data[0].x[-1], 76 | y=trigger_price, 77 | text=f"Trigger Price: {trigger_price:.2f}", 78 | showarrow=True, 79 | arrowhead=2, 80 | arrowsize=1, 81 | arrowcolor="blue", 82 | ax=20, 83 | ay=-20, 84 | font={"size": 10, "color": "blue"}, 85 | ) 86 | 87 | def _add_grid_lines( 88 | self, 89 | fig: go.Figure, 90 | grids: list[float], 91 | central_price: float, 92 | ) -> None: 93 | for price in grids: 94 | color = "green" if price < central_price else "red" 95 | fig.add_trace( 96 | go.Scatter( 97 | x=[fig.data[0].x[0], fig.data[0].x[-1]], 98 | y=[price, price], 99 | mode="lines", 100 | line={"color": color, "dash": "dash"}, 101 | showlegend=False, 102 | ), 103 | ) 104 | 105 | def _add_trade_markers( 106 | self, 107 | fig: go.Figure, 108 | orders: list[Order], 109 | ) -> None: 110 | for order in orders: 111 | icon_name = "triangle-up" if order.side == OrderSide.BUY else "triangle-down" 112 | icon_color = "green" if order.side == OrderSide.BUY else "red" 113 | fig.add_trace( 114 | go.Scatter( 115 | x=[order.format_last_trade_timestamp()], 116 | y=[order.price], 117 | mode="markers", 118 | marker={ 119 | "symbol": icon_name, 120 | "color": icon_color, 121 | "size": 12, 122 | "line": {"color": "black", "width": 2}, 123 | }, 124 | name=f"{order.side.name} Order", 125 | text=f"Price: {order.price}\nQty: {order.filled}\nDate: {order.format_last_trade_timestamp()}", 126 | hoverinfo="x+y+text", 127 | ), 128 | row=1, 129 | col=1, 130 | ) 131 | 132 | def _add_volume_trace( 133 | self, 134 | fig: go.Figure, 135 | data: pd.DataFrame, 136 | ) -> None: 137 | volume_colors = [ 138 | "green" if close >= open_ else "red" for close, open_ in zip(data["close"], data["open"], 
strict=False) 139 | ] 140 | 141 | fig.add_trace( 142 | go.Bar( 143 | x=data.index, 144 | y=data["volume"], 145 | marker={"color": volume_colors}, 146 | name="", 147 | ), 148 | row=2, 149 | col=1, 150 | ) 151 | 152 | fig.update_yaxes( 153 | title="Volume", 154 | row=2, 155 | col=1, 156 | ) 157 | 158 | def _add_account_value_trace( 159 | self, 160 | fig: go.Figure, 161 | data: pd.DataFrame, 162 | ) -> None: 163 | fig.add_trace( 164 | go.Scatter( 165 | x=data.index, 166 | y=data["account_value"], 167 | mode="lines", 168 | name="", 169 | line={"color": "purple", "width": 2}, 170 | ), 171 | row=3, 172 | col=1, 173 | ) 174 | -------------------------------------------------------------------------------- /tests/strategies/test_plotter.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pandas as pd 4 | import plotly.graph_objects as go 5 | from plotly.subplots import make_subplots 6 | import pytest 7 | 8 | from core.grid_management.grid_manager import GridManager 9 | from core.order_handling.order import Order, OrderSide, OrderType 10 | from core.order_handling.order_book import OrderBook 11 | from strategies.plotter import Plotter 12 | 13 | 14 | class TestPlotter: 15 | @pytest.fixture 16 | def setup_plotter(self): 17 | grid_manager = Mock(spec=GridManager) 18 | order_book = Mock(spec=OrderBook) 19 | plotter = Plotter(grid_manager=grid_manager, order_book=order_book) 20 | return plotter, grid_manager, order_book 21 | 22 | def test_add_grid_lines(self, setup_plotter): 23 | plotter, grid_manager, _ = setup_plotter 24 | fig = go.Figure() 25 | 26 | mock_x_data = [1, 2, 3] # Example x-axis values 27 | fig.add_trace(go.Scatter(x=mock_x_data, y=[100, 105, 110])) # Add a dummy trace 28 | 29 | grid_manager.price_grids = [90, 100, 110] 30 | grid_manager.central_price = 100 31 | 32 | plotter._add_grid_lines(fig, grid_manager.price_grids, grid_manager.central_price) 33 | 34 | assert 
len(fig.data) == 4 # 1 dummy trace + 3 grid lines 35 | assert fig.data[1].line.color == "green" # Below central price 36 | assert fig.data[2].line.color == "red" # Above central price 37 | assert fig.data[3].line.color == "red" # Above central price 38 | 39 | def test_add_trigger_price_line(self, setup_plotter): 40 | plotter, grid_manager, _ = setup_plotter 41 | fig = go.Figure() 42 | trigger_price = 105 43 | 44 | fig.add_trace(go.Scatter(x=[1, 2, 3], y=[100, 105, 110])) 45 | 46 | plotter._add_trigger_price_line(fig, trigger_price) 47 | 48 | assert len(fig.data) == 2 # 1 dummy trace + 1 trigger price line 49 | assert fig.data[1].line.color == "blue" 50 | assert "Trigger Price" in fig.layout.annotations[0].text 51 | 52 | def test_add_trade_markers(self, setup_plotter): 53 | plotter, _, order_book = setup_plotter 54 | fig = make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02) 55 | 56 | orders = [ 57 | Order( 58 | identifier="123", 59 | status=None, 60 | order_type=OrderType.LIMIT, 61 | side=OrderSide.BUY, 62 | price=1000.0, 63 | average=None, 64 | amount=5.0, 65 | filled=5.0, 66 | remaining=0.0, 67 | timestamp=1695890800, 68 | datetime="2024-01-01T00:00:00Z", 69 | last_trade_timestamp=1695890800, 70 | symbol="BTC/USDT", 71 | time_in_force="GTC", 72 | ), 73 | Order( 74 | identifier="124", 75 | status=None, 76 | order_type=OrderType.LIMIT, 77 | side=OrderSide.SELL, 78 | price=1200.0, 79 | average=None, 80 | amount=3.0, 81 | filled=3.0, 82 | remaining=0.0, 83 | timestamp=1695890800, 84 | datetime="2024-01-02T00:00:00Z", 85 | last_trade_timestamp=1695890800, 86 | symbol="BTC/USDT", 87 | time_in_force="GTC", 88 | ), 89 | ] 90 | order_book.get_completed_orders.return_value = orders 91 | 92 | plotter._add_trade_markers(fig, orders) 93 | 94 | assert len(fig.data) == 2 95 | assert fig.data[0].marker.color == "green" 96 | assert fig.data[1].marker.color == "red" 97 | 98 | def test_add_volume_trace(self, setup_plotter): 99 | 
plotter, _, _ = setup_plotter 100 | fig = make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02) 101 | 102 | data = pd.DataFrame( 103 | {"open": [100, 110], "close": [110, 105], "volume": [500, 700]}, 104 | index=pd.date_range("2024-01-01", periods=2), 105 | ) 106 | 107 | plotter._add_volume_trace(fig, data) 108 | 109 | assert len(fig.data) == 1 110 | assert fig.data[0].type == "bar" 111 | assert fig.data[0].y.tolist() == [500, 700] 112 | assert list(fig.data[0].marker.color) == ["green", "red"] 113 | 114 | def test_add_account_value_trace(self, setup_plotter): 115 | plotter, _, _ = setup_plotter 116 | fig = make_subplots(rows=3, cols=1, shared_xaxes=True, row_heights=[0.70, 0.15, 0.15], vertical_spacing=0.02) 117 | data = pd.DataFrame({"account_value": [10000, 10500]}, index=pd.date_range("2024-01-01", periods=2)) 118 | 119 | plotter._add_account_value_trace(fig, data) 120 | 121 | assert len(fig.data) == 1 122 | assert fig.data[0].type == "scatter" 123 | assert fig.data[0].y.tolist() == [10000, 10500] 124 | assert fig.data[0].line.color == "purple" 125 | 126 | @patch("plotly.graph_objects.Figure.show") 127 | def test_plot_results(self, mock_show, setup_plotter): 128 | plotter, grid_manager, order_book = setup_plotter 129 | data = pd.DataFrame( 130 | { 131 | "open": [100, 105], 132 | "high": [110, 115], 133 | "low": [95, 100], 134 | "close": [105, 110], 135 | "volume": [500, 700], 136 | "account_value": [10000, 10500], 137 | }, 138 | index=pd.date_range("2024-01-01", periods=2), 139 | ) 140 | 141 | grid_manager.price_grids = [90, 100, 110] 142 | grid_manager.central_price = 100 143 | grid_manager.get_trigger_price.return_value = 100 144 | order_book.get_completed_orders.return_value = [] 145 | 146 | plotter.plot_results(data) 147 | mock_show.assert_called_once() 148 | -------------------------------------------------------------------------------- /config/config_manager.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | from strategies.spacing_type import SpacingType 6 | from strategies.strategy_type import StrategyType 7 | 8 | from .exceptions import ConfigFileNotFoundError, ConfigParseError 9 | from .trading_mode import TradingMode 10 | 11 | 12 | class ConfigManager: 13 | def __init__(self, config_file, config_validator): 14 | self.logger = logging.getLogger(self.__class__.__name__) 15 | self.config_file = config_file 16 | self.config_validator = config_validator 17 | self.config = None 18 | self.load_config() 19 | 20 | def load_config(self): 21 | if not os.path.exists(self.config_file): 22 | self.logger.error(f"Config file {self.config_file} does not exist.") 23 | raise ConfigFileNotFoundError(self.config_file) 24 | 25 | with open(self.config_file) as file: 26 | try: 27 | self.config = json.load(file) 28 | self.config_validator.validate(self.config) 29 | except json.JSONDecodeError as e: 30 | self.logger.error(f"Failed to parse config file {self.config_file}: {e}") 31 | raise ConfigParseError(self.config_file, e) from e 32 | 33 | def get(self, key, default=None): 34 | return self.config.get(key, default) 35 | 36 | # --- General Accessor Methods --- 37 | def get_exchange(self): 38 | return self.config.get("exchange", {}) 39 | 40 | def get_exchange_name(self): 41 | exchange = self.get_exchange() 42 | return exchange.get("name", None) 43 | 44 | def get_trading_fee(self): 45 | exchange = self.get_exchange() 46 | return exchange.get("trading_fee", 0) 47 | 48 | def get_trading_mode(self) -> TradingMode | None: 49 | exchange = self.get_exchange() 50 | trading_mode = exchange.get("trading_mode", None) 51 | 52 | if trading_mode: 53 | return TradingMode.from_string(trading_mode) 54 | 55 | def get_pair(self): 56 | return self.config.get("pair", {}) 57 | 58 | def get_base_currency(self): 59 | pair = self.get_pair() 60 | return pair.get("base_currency", 
None) 61 | 62 | def get_quote_currency(self): 63 | pair = self.get_pair() 64 | return pair.get("quote_currency", None) 65 | 66 | def get_trading_settings(self): 67 | return self.config.get("trading_settings", {}) 68 | 69 | def get_timeframe(self): 70 | trading_settings = self.get_trading_settings() 71 | return trading_settings.get("timeframe", "1h") 72 | 73 | def get_period(self): 74 | trading_settings = self.get_trading_settings() 75 | return trading_settings.get("period", {}) 76 | 77 | def get_start_date(self): 78 | period = self.get_period() 79 | return period.get("start_date", None) 80 | 81 | def get_end_date(self): 82 | period = self.get_period() 83 | return period.get("end_date", None) 84 | 85 | def get_initial_balance(self): 86 | trading_settings = self.get_trading_settings() 87 | return trading_settings.get("initial_balance", 10000) 88 | 89 | def get_historical_data_file(self): 90 | trading_settings = self.get_trading_settings() 91 | return trading_settings.get("historical_data_file", None) 92 | 93 | # --- Grid Accessor Methods --- 94 | def get_grid_settings(self): 95 | return self.config.get("grid_strategy", {}) 96 | 97 | def get_strategy_type(self) -> StrategyType | None: 98 | grid_settings = self.get_grid_settings() 99 | strategy_type = grid_settings.get("type", None) 100 | 101 | if strategy_type: 102 | return StrategyType.from_string(strategy_type) 103 | 104 | def get_spacing_type(self) -> SpacingType | None: 105 | grid_settings = self.get_grid_settings() 106 | spacing_type = grid_settings.get("spacing", None) 107 | 108 | if spacing_type: 109 | return SpacingType.from_string(spacing_type) 110 | 111 | def get_num_grids(self): 112 | grid_settings = self.get_grid_settings() 113 | return grid_settings.get("num_grids", None) 114 | 115 | def get_grid_range(self): 116 | grid_settings = self.get_grid_settings() 117 | return grid_settings.get("range", {}) 118 | 119 | def get_top_range(self): 120 | grid_range = self.get_grid_range() 121 | return 
grid_range.get("top", None) 122 | 123 | def get_bottom_range(self): 124 | grid_range = self.get_grid_range() 125 | return grid_range.get("bottom", None) 126 | 127 | # --- Risk management (Take Profit / Stop Loss) Accessor Methods --- 128 | def get_risk_management(self): 129 | return self.config.get("risk_management", {}) 130 | 131 | def get_take_profit(self): 132 | risk_management = self.get_risk_management() 133 | return risk_management.get("take_profit", {}) 134 | 135 | def is_take_profit_enabled(self): 136 | take_profit = self.get_take_profit() 137 | return take_profit.get("enabled", False) 138 | 139 | def get_take_profit_threshold(self): 140 | take_profit = self.get_take_profit() 141 | return take_profit.get("threshold", None) 142 | 143 | def get_stop_loss(self): 144 | risk_management = self.get_risk_management() 145 | return risk_management.get("stop_loss", {}) 146 | 147 | def is_stop_loss_enabled(self): 148 | stop_loss = self.get_stop_loss() 149 | return stop_loss.get("enabled", False) 150 | 151 | def get_stop_loss_threshold(self): 152 | stop_loss = self.get_stop_loss() 153 | return stop_loss.get("threshold", None) 154 | 155 | # --- Logging Accessor Methods --- 156 | def get_logging(self): 157 | return self.config.get("logging", {}) 158 | 159 | def get_logging_level(self): 160 | logging = self.get_logging() 161 | return logging.get("log_level", {}) 162 | 163 | def should_log_to_file(self) -> bool: 164 | logging = self.get_logging() 165 | return logging.get("log_to_file", False) 166 | -------------------------------------------------------------------------------- /tests/config/test_config_validator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from config.config_validator import ConfigValidator 4 | from config.exceptions import ConfigValidationError 5 | 6 | 7 | class TestConfigValidator: 8 | @pytest.fixture 9 | def config_validator(self): 10 | return ConfigValidator() 11 | 12 | def 
test_validate_valid_config(self, config_validator, valid_config): 13 | try: 14 | config_validator.validate(valid_config) 15 | except ConfigValidationError: 16 | pytest.fail("Valid configuration raised ConfigValidationError") 17 | 18 | def test_validate_missing_required_fields(self, config_validator): 19 | invalid_config = { 20 | "exchange": {}, 21 | "pair": {}, 22 | "trading_settings": {}, 23 | "grid_strategy": {}, 24 | "risk_management": {}, 25 | "logging": {}, 26 | } 27 | with pytest.raises(ConfigValidationError) as excinfo: 28 | config_validator.validate(invalid_config) 29 | 30 | missing_fields = excinfo.value.missing_fields 31 | invalid_fields = excinfo.value.invalid_fields 32 | 33 | assert "pair.base_currency" in missing_fields 34 | assert "pair.quote_currency" in missing_fields 35 | assert "trading_settings.initial_balance" in missing_fields 36 | assert "trading_settings.period.start_date" in missing_fields 37 | assert "trading_settings.period.end_date" in missing_fields 38 | assert "grid_strategy.num_grids" in missing_fields 39 | assert "grid_strategy.range.top" in missing_fields 40 | assert "grid_strategy.range.bottom" in missing_fields 41 | assert "logging.log_level" in missing_fields 42 | 43 | assert "exchange.name" in invalid_fields 44 | assert "exchange.trading_fee" in invalid_fields 45 | assert "trading_settings.timeframe" in invalid_fields 46 | 47 | def test_validate_invalid_exchange(self, config_validator, valid_config): 48 | valid_config["exchange"] = {"name": "", "trading_fee": -0.01} # Invalid exchange 49 | with pytest.raises(ConfigValidationError) as excinfo: 50 | config_validator.validate(valid_config) 51 | assert "exchange.name" in excinfo.value.invalid_fields 52 | assert "exchange.trading_fee" in excinfo.value.invalid_fields 53 | 54 | def test_validate_valid_trading_modes(self, config_validator, valid_config): 55 | for mode in ["live", "paper_trading", "backtest"]: 56 | valid_config["exchange"]["trading_mode"] = mode 57 | try: 58 | 
config_validator.validate(valid_config) 59 | except ConfigValidationError: 60 | pytest.fail(f"Valid trading_mode '{mode}' raised ConfigValidationError") 61 | 62 | def test_validate_invalid_trading_mode(self, config_validator, valid_config): 63 | valid_config["exchange"]["trading_mode"] = "invalid_mode" 64 | with pytest.raises(ConfigValidationError, match="exchange.trading_mode"): 65 | config_validator.validate(valid_config) 66 | 67 | def test_validate_invalid_timeframe(self, config_validator, valid_config): 68 | valid_config["trading_settings"]["timeframe"] = "3h" # Invalid timeframe 69 | with pytest.raises(ConfigValidationError) as excinfo: 70 | config_validator.validate(valid_config) 71 | assert "trading_settings.timeframe" in excinfo.value.invalid_fields 72 | 73 | def test_validate_missing_period_fields(self, config_validator, valid_config): 74 | valid_config["trading_settings"]["period"] = {} # Missing start and end date 75 | with pytest.raises(ConfigValidationError) as excinfo: 76 | config_validator.validate(valid_config) 77 | assert "trading_settings.period.start_date" in excinfo.value.missing_fields 78 | assert "trading_settings.period.end_date" in excinfo.value.missing_fields 79 | 80 | def test_validate_invalid_grid_settings(self, config_validator, valid_config): 81 | # Test invalid grid type 82 | valid_config["grid_strategy"]["type"] = "invalid_type" # Invalid grid type 83 | with pytest.raises(ConfigValidationError) as excinfo: 84 | config_validator.validate(valid_config) 85 | assert "grid_strategy.type" in excinfo.value.invalid_fields 86 | 87 | # Test missing num_grids 88 | valid_config["grid_strategy"]["num_grids"] = None 89 | with pytest.raises(ConfigValidationError) as excinfo: 90 | config_validator.validate(valid_config) 91 | assert "grid_strategy.num_grids" in excinfo.value.missing_fields 92 | 93 | # Test invalid top/bottom range (bottom should be less than top) 94 | valid_config["grid_strategy"]["range"] = {"top": 2800, "bottom": 2850} # Invalid 
range 95 | with pytest.raises(ConfigValidationError) as excinfo: 96 | config_validator.validate(valid_config) 97 | assert "grid_strategy.range.top" in excinfo.value.invalid_fields 98 | assert "grid_strategy.range.bottom" in excinfo.value.invalid_fields 99 | 100 | def test_validate_limits_invalid_type(self, config_validator, valid_config): 101 | valid_config["risk_management"] = { 102 | "take_profit": {"enabled": "yes"}, # Invalid boolean 103 | "stop_loss": {"enabled": 1}, # Invalid boolean 104 | } 105 | with pytest.raises(ConfigValidationError) as excinfo: 106 | config_validator.validate(valid_config) 107 | assert "risk_management.take_profit.enabled" in excinfo.value.invalid_fields 108 | assert "risk_management.stop_loss.enabled" in excinfo.value.invalid_fields 109 | 110 | def test_validate_logging_invalid_level(self, config_validator, valid_config): 111 | valid_config["logging"] = { 112 | "log_level": "VERBOSE", # Invalid log level 113 | "log_to_file": True, 114 | } 115 | with pytest.raises(ConfigValidationError) as excinfo: 116 | config_validator.validate(valid_config) 117 | assert "logging.log_level" in excinfo.value.invalid_fields 118 | 119 | def test_validate_logging_missing_level(self, config_validator, valid_config): 120 | valid_config["logging"] = {"log_to_file": True} 121 | with pytest.raises(ConfigValidationError) as excinfo: 122 | config_validator.validate(valid_config) 123 | assert "logging.log_level" in excinfo.value.missing_fields 124 | -------------------------------------------------------------------------------- /tests/bot_management/test_notification_handler.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | 5 | from config.trading_mode import TradingMode 6 | from core.bot_management.event_bus import EventBus, Events 7 | from core.bot_management.notification.notification_content import NotificationType 8 | from 
core.bot_management.notification.notification_handler import NotificationHandler 9 | from core.order_handling.order import Order, OrderSide, OrderStatus, OrderType 10 | 11 | 12 | class TestNotificationHandler: 13 | @pytest.fixture 14 | def event_bus(self): 15 | return Mock(spec=EventBus) 16 | 17 | @pytest.fixture 18 | def notification_handler_enabled(self, event_bus): 19 | urls = ["json://localhost:8080/path"] 20 | handler = NotificationHandler( 21 | event_bus=event_bus, 22 | urls=urls, 23 | trading_mode=TradingMode.LIVE, 24 | ) 25 | return handler 26 | 27 | @pytest.fixture 28 | def notification_handler_disabled(self, event_bus): 29 | return NotificationHandler( 30 | event_bus=event_bus, 31 | urls=None, 32 | trading_mode=TradingMode.BACKTEST, 33 | ) 34 | 35 | @pytest.fixture 36 | def mock_order(self): 37 | return Order( 38 | identifier="test-123", 39 | status=OrderStatus.CLOSED, 40 | order_type=OrderType.LIMIT, 41 | side=OrderSide.BUY, 42 | price=1000.0, 43 | average=1000.0, 44 | amount=1.0, 45 | filled=1.0, 46 | remaining=0.0, 47 | timestamp=1234567890000, 48 | datetime="2024-01-01T00:00:00Z", 49 | last_trade_timestamp="2024-01-01T00:00:00Z", 50 | symbol="BTC/USDT", 51 | time_in_force="GTC", 52 | ) 53 | 54 | @patch("apprise.Apprise") 55 | def test_notification_handler_enabled_initialization(self, mock_apprise, event_bus): 56 | handler = NotificationHandler( 57 | event_bus=event_bus, 58 | urls=["mock://example.com"], 59 | trading_mode=TradingMode.LIVE, 60 | ) 61 | assert handler.enabled is True 62 | mock_apprise.return_value.add.assert_called_once_with("mock://example.com") 63 | event_bus.subscribe.assert_called_once_with(Events.ORDER_FILLED, handler._send_notification_on_order_filled) 64 | 65 | @patch("apprise.Apprise") 66 | def test_notification_handler_disabled_initialization(self, mock_apprise, event_bus): 67 | handler = NotificationHandler( 68 | event_bus=event_bus, 69 | urls=None, 70 | trading_mode=TradingMode.BACKTEST, 71 | ) 72 | assert handler.enabled is 
False 73 | mock_apprise.assert_not_called() 74 | event_bus.subscribe.assert_not_called() 75 | 76 | @pytest.mark.asyncio 77 | async def test_send_notification_with_predefined_content(self, notification_handler_enabled, mock_order): 78 | handler = notification_handler_enabled 79 | with patch.object(handler.apprise_instance, "notify") as mock_notify: 80 | handler.send_notification( 81 | NotificationType.ORDER_FILLED, 82 | order_details=str(mock_order), 83 | ) 84 | 85 | mock_notify.assert_called_once_with( 86 | title="Order Filled", 87 | body=f"Order has been filled successfully:\n{mock_order!s}", 88 | ) 89 | 90 | @pytest.mark.asyncio 91 | async def test_send_notification_with_missing_placeholder(self, notification_handler_enabled): 92 | handler = notification_handler_enabled 93 | with ( 94 | patch.object(handler.apprise_instance, "notify") as mock_notify, 95 | patch("logging.Logger.warning") as mock_warning, 96 | ): 97 | handler.send_notification(NotificationType.ORDER_FILLED) 98 | 99 | mock_warning.assert_called_once_with( 100 | "Missing placeholders for notification: {'order_details'}. 
Defaulting to 'N/A' for missing values.", 101 | ) 102 | mock_notify.assert_called_once_with( 103 | title="Order Filled", 104 | body="Order has been filled successfully:\nN/A", 105 | ) 106 | 107 | @pytest.mark.asyncio 108 | async def test_send_notification_with_order_failed(self, notification_handler_enabled): 109 | handler = notification_handler_enabled 110 | error_details = "Insufficient funds" 111 | 112 | with patch.object(handler.apprise_instance, "notify") as mock_notify: 113 | handler.send_notification( 114 | NotificationType.ORDER_FAILED, 115 | error_details=error_details, 116 | ) 117 | 118 | mock_notify.assert_called_once_with( 119 | title="Order Placement Failed", 120 | body=f"Failed to place order:\n{error_details}", 121 | ) 122 | 123 | @pytest.mark.asyncio 124 | async def test_async_send_notification_success(self, notification_handler_enabled): 125 | handler = notification_handler_enabled 126 | 127 | # Mock both the executor and send_notification 128 | with ( 129 | patch.object(handler, "_executor", create=True) as mock_executor, 130 | patch.object(handler, "send_notification") as mock_send, 131 | ): 132 | # Configure the mock executor to run the function directly 133 | mock_executor.submit = lambda f, *args, **kwargs: f(*args, **kwargs) 134 | 135 | await handler.async_send_notification( 136 | NotificationType.ORDER_FILLED, 137 | order_details="test", 138 | ) 139 | 140 | mock_send.assert_called_once_with( 141 | NotificationType.ORDER_FILLED, 142 | order_details="test", 143 | ) 144 | 145 | @pytest.mark.asyncio 146 | async def test_event_subscription_and_notification_on_order_filled( 147 | self, 148 | notification_handler_enabled, 149 | mock_order, 150 | ): 151 | handler = notification_handler_enabled 152 | with patch.object(handler, "async_send_notification") as mock_async_send: 153 | await handler._send_notification_on_order_filled(mock_order) 154 | 155 | mock_async_send.assert_called_once_with(NotificationType.ORDER_FILLED, order_details=str(mock_order)) 
156 | 157 | def test_send_notification_disabled(self, notification_handler_disabled): 158 | handler = notification_handler_disabled 159 | with patch("apprise.Apprise.notify") as mock_notify: 160 | handler.send_notification(NotificationType.ORDER_FILLED, order_details="test") 161 | mock_notify.assert_not_called() 162 | -------------------------------------------------------------------------------- /tests/order_handling/test_live_order_execution_strategy.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, Mock, patch 2 | 3 | import pytest 4 | 5 | from core.order_handling.exceptions import OrderExecutionFailedError 6 | from core.order_handling.execution_strategy.live_order_execution_strategy import ( 7 | LiveOrderExecutionStrategy, 8 | ) 9 | from core.order_handling.order import OrderSide, OrderStatus, OrderType 10 | from core.services.exceptions import DataFetchError 11 | 12 | 13 | @pytest.mark.asyncio 14 | class TestLiveOrderExecutionStrategy: 15 | @pytest.fixture 16 | def setup_strategy(self): 17 | mock_exchange_service = Mock() 18 | strategy = LiveOrderExecutionStrategy(exchange_service=mock_exchange_service) 19 | return strategy, mock_exchange_service 20 | 21 | @patch("time.time", return_value=1680000000) # Mock time for predictable order IDs 22 | async def test_execute_market_order_success(self, mock_time, setup_strategy): 23 | strategy, exchange_service = setup_strategy 24 | pair = "BTC/USDT" 25 | quantity = 0.5 26 | price = 30000 27 | raw_order = { 28 | "id": "test-order-id", 29 | "status": "closed", 30 | "type": "market", 31 | "side": "buy", 32 | "price": price, 33 | "amount": quantity, 34 | "filled": quantity, 35 | "remaining": 0, 36 | "symbol": pair, 37 | "timestamp": 1680000000000, 38 | } 39 | 40 | exchange_service.place_order = AsyncMock(return_value=raw_order) 41 | 42 | order = await strategy.execute_market_order(OrderSide.BUY, pair, quantity, price) 43 | 44 | assert order is not 
None 45 | assert order.identifier == "test-order-id" 46 | assert order.status == OrderStatus.CLOSED 47 | assert order.order_type == OrderType.MARKET 48 | assert order.side == OrderSide.BUY 49 | assert order.price == price 50 | 51 | async def test_execute_market_order_retries(self, setup_strategy): 52 | strategy, exchange_service = setup_strategy 53 | pair = "BTC/USDT" 54 | quantity = 0.5 55 | price = 30000 56 | 57 | exchange_service.place_order = AsyncMock(side_effect=Exception("Order failed")) 58 | 59 | with pytest.raises(OrderExecutionFailedError): 60 | await strategy.execute_market_order(OrderSide.BUY, pair, quantity, price) 61 | 62 | assert exchange_service.place_order.call_count == strategy.max_retries 63 | 64 | async def test_execute_limit_order_success(self, setup_strategy): 65 | strategy, exchange_service = setup_strategy 66 | pair = "ETH/USDT" 67 | quantity = 1 68 | price = 2000 69 | raw_order = { 70 | "id": "test-limit-order-id", 71 | "status": "open", 72 | "type": "limit", 73 | "side": "sell", 74 | "price": price, 75 | "amount": quantity, 76 | "filled": 0, 77 | "remaining": quantity, 78 | "symbol": pair, 79 | } 80 | 81 | exchange_service.place_order = AsyncMock(return_value=raw_order) 82 | 83 | order = await strategy.execute_limit_order(OrderSide.SELL, pair, quantity, price) 84 | 85 | assert order is not None 86 | assert order.identifier == "test-limit-order-id" 87 | assert order.status == OrderStatus.OPEN 88 | assert order.order_type == OrderType.LIMIT 89 | assert order.side == OrderSide.SELL 90 | assert order.price == price 91 | 92 | async def test_execute_limit_order_data_fetch_error(self, setup_strategy): 93 | strategy, exchange_service = setup_strategy 94 | pair = "ETH/USDT" 95 | quantity = 1 96 | price = 2000 97 | 98 | exchange_service.place_order = AsyncMock(side_effect=DataFetchError("Exchange API error")) 99 | 100 | with pytest.raises(OrderExecutionFailedError): 101 | await strategy.execute_limit_order(OrderSide.SELL, pair, quantity, price) 102 
| 103 | async def test_get_order_success(self, setup_strategy): 104 | strategy, exchange_service = setup_strategy 105 | order_id = "test-order-id" 106 | pair = "BTC/USDT" 107 | raw_order = { 108 | "id": order_id, 109 | "status": "open", 110 | "type": "limit", 111 | "side": "buy", 112 | "price": 100, 113 | "amount": 1, 114 | "filled": 0, 115 | "remaining": 1, 116 | "symbol": pair, 117 | } 118 | 119 | exchange_service.fetch_order = AsyncMock(return_value=raw_order) 120 | 121 | order = await strategy.get_order(order_id, pair) 122 | 123 | assert order is not None 124 | assert order.identifier == order_id 125 | assert order.symbol == pair 126 | assert order.status == OrderStatus.OPEN 127 | assert order.order_type == OrderType.LIMIT 128 | 129 | async def test_get_order_data_fetch_error(self, setup_strategy): 130 | strategy, exchange_service = setup_strategy 131 | order_id = "test-order-id" 132 | pair = "BTC/USDT" 133 | 134 | exchange_service.fetch_order = AsyncMock(side_effect=DataFetchError("Order not found")) 135 | 136 | with pytest.raises(DataFetchError): 137 | await strategy.get_order(order_id, pair) 138 | 139 | async def test_handle_partial_fill(self, setup_strategy): 140 | strategy, exchange_service = setup_strategy 141 | partial_order = Mock(identifier="partial-order", filled=0.5) 142 | exchange_service.cancel_order = AsyncMock(return_value={"status": "canceled"}) 143 | 144 | await strategy._handle_partial_fill(partial_order, "BTC/USDT") 145 | exchange_service.cancel_order.assert_called_once_with("partial-order", "BTC/USDT") 146 | 147 | async def test_retry_cancel_order(self, setup_strategy): 148 | strategy, exchange_service = setup_strategy 149 | order_id = "test-order-id" 150 | pair = "BTC/USDT" 151 | 152 | exchange_service.cancel_order = AsyncMock( 153 | side_effect=[ 154 | {"status": "failed"}, 155 | {"status": "canceled"}, 156 | ], 157 | ) 158 | 159 | result = await strategy._retry_cancel_order(order_id, pair) 160 | 161 | assert result is True 162 | assert 
exchange_service.cancel_order.call_count == 2 163 | 164 | async def test_adjust_price_buy(self, setup_strategy): 165 | strategy, _ = setup_strategy 166 | price = 30000 167 | adjusted_price = await strategy._adjust_price(OrderSide.BUY, price, 1) 168 | 169 | assert adjusted_price > price 170 | 171 | async def test_adjust_price_sell(self, setup_strategy): 172 | strategy, _ = setup_strategy 173 | price = 30000 174 | adjusted_price = await strategy._adjust_price(OrderSide.SELL, price, 1) 175 | 176 | assert adjusted_price < price 177 | -------------------------------------------------------------------------------- /core/order_handling/order_status_tracker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from core.bot_management.event_bus import EventBus, Events 5 | from core.order_handling.order import Order, OrderStatus 6 | from core.order_handling.order_book import OrderBook 7 | 8 | 9 | class OrderStatusTracker: 10 | """ 11 | Tracks the status of pending orders and publishes events 12 | when their states change (e.g., open, filled, canceled). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | order_book: OrderBook, 18 | order_execution_strategy, 19 | event_bus: EventBus, 20 | polling_interval: float = 15.0, 21 | ): 22 | """ 23 | Initializes the OrderStatusTracker. 24 | 25 | Args: 26 | order_book: OrderBook instance to manage and query orders. 27 | order_execution_strategy: Strategy for querying order statuses from the exchange. 28 | event_bus: EventBus instance for publishing state change events. 29 | polling_interval: Time interval (in seconds) between status checks. 
30 | """ 31 | self.order_book = order_book 32 | self.order_execution_strategy = order_execution_strategy 33 | self.event_bus = event_bus 34 | self.polling_interval = polling_interval 35 | self._monitoring_task = None 36 | self._active_tasks = set() 37 | self.logger = logging.getLogger(self.__class__.__name__) 38 | 39 | async def _track_open_order_statuses(self) -> None: 40 | """ 41 | Periodically checks the statuses of open orders and updates their states. 42 | """ 43 | try: 44 | while True: 45 | await self._process_open_orders() 46 | await asyncio.sleep(self.polling_interval) 47 | 48 | except asyncio.CancelledError: 49 | self.logger.info("OrderStatusTracker monitoring task was cancelled.") 50 | await self._cancel_active_tasks() 51 | 52 | except Exception as error: 53 | self.logger.error(f"Unexpected error in OrderStatusTracker: {error}") 54 | 55 | async def _process_open_orders(self) -> None: 56 | """ 57 | Processes open orders by querying their statuses and handling state changes. 58 | """ 59 | open_orders = self.order_book.get_open_orders() 60 | tasks = [self._create_task(self._query_and_handle_order(order)) for order in open_orders] 61 | results = await asyncio.gather(*tasks, return_exceptions=True) 62 | 63 | for result in results: 64 | if isinstance(result, Exception): 65 | self.logger.error(f"Error during order processing: {result}", exc_info=True) 66 | 67 | async def _query_and_handle_order(self, local_order: Order): 68 | """ 69 | Query order and handling state changes if needed. 
70 | """ 71 | try: 72 | remote_order = await self.order_execution_strategy.get_order(local_order.identifier, local_order.symbol) 73 | self._handle_order_status_change(remote_order) 74 | 75 | except Exception as error: 76 | self.logger.error( 77 | f"Failed to query remote order with identifier {local_order.identifier}: {error}", 78 | exc_info=True, 79 | ) 80 | 81 | def _handle_order_status_change( 82 | self, 83 | remote_order: Order, 84 | ) -> None: 85 | """ 86 | Handles changes in the status of the order data fetched from the exchange. 87 | 88 | Args: 89 | remote_order: The latest `Order` object fetched from the exchange. 90 | 91 | Raises: 92 | ValueError: If critical fields (e.g., status) are missing from the remote order. 93 | """ 94 | try: 95 | if remote_order.status == OrderStatus.UNKNOWN: 96 | self.logger.error(f"Missing 'status' in remote order object: {remote_order}", exc_info=True) 97 | raise ValueError("Order data from the exchange is missing the 'status' field.") 98 | elif remote_order.status == OrderStatus.CLOSED: 99 | self.order_book.update_order_status(remote_order.identifier, OrderStatus.CLOSED) 100 | self.event_bus.publish_sync(Events.ORDER_FILLED, remote_order) 101 | self.logger.info(f"Order {remote_order.identifier} filled.") 102 | elif remote_order.status == OrderStatus.CANCELED: 103 | self.order_book.update_order_status(remote_order.identifier, OrderStatus.CANCELED) 104 | self.event_bus.publish_sync(Events.ORDER_CANCELLED, remote_order) 105 | self.logger.warning(f"Order {remote_order.identifier} was canceled.") 106 | elif remote_order.status == OrderStatus.OPEN: # Still open 107 | if remote_order.filled > 0: 108 | self.logger.info( 109 | f"Order {remote_order} partially filled. Filled: {remote_order.filled}, " 110 | f"Remaining: {remote_order.remaining}.", 111 | ) 112 | else: 113 | self.logger.info(f"Order {remote_order} is still open. 
No fills yet.") 114 | else: 115 | self.logger.warning( 116 | f"Unhandled order status '{remote_order.status}' for order {remote_order.identifier}.", 117 | ) 118 | 119 | except Exception as e: 120 | self.logger.error(f"Error handling order status change: {e}", exc_info=True) 121 | 122 | def _create_task(self, coro): 123 | """ 124 | Creates a managed asyncio task and adds it to the active task set. 125 | 126 | Args: 127 | coro: Coroutine to be scheduled as a task. 128 | """ 129 | task = asyncio.create_task(coro) 130 | self._active_tasks.add(task) 131 | task.add_done_callback(self._active_tasks.discard) 132 | return task 133 | 134 | async def _cancel_active_tasks(self): 135 | """ 136 | Cancels all active tasks tracked by the tracker. 137 | """ 138 | for task in self._active_tasks: 139 | task.cancel() 140 | await asyncio.gather(*self._active_tasks, return_exceptions=True) 141 | self._active_tasks.clear() 142 | 143 | def start_tracking(self) -> None: 144 | """ 145 | Starts the order tracking task. 146 | """ 147 | if self._monitoring_task and not self._monitoring_task.done(): 148 | self.logger.warning("OrderStatusTracker is already running.") 149 | return 150 | self._monitoring_task = asyncio.create_task(self._track_open_order_statuses()) 151 | self.logger.info("OrderStatusTracker has started tracking open orders.") 152 | 153 | async def stop_tracking(self) -> None: 154 | """ 155 | Stops the order tracking task. 
156 | """ 157 | if self._monitoring_task: 158 | self._monitoring_task.cancel() 159 | try: 160 | await self._monitoring_task 161 | except asyncio.CancelledError: 162 | self.logger.info("OrderStatusTracker monitoring task was cancelled.") 163 | await self._cancel_active_tasks() 164 | self._monitoring_task = None 165 | self.logger.info("OrderStatusTracker has stopped tracking open orders.") 166 | -------------------------------------------------------------------------------- /tests/config/test_config_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | from unittest.mock import Mock, mock_open, patch 3 | 4 | import pytest 5 | 6 | from config.config_manager import ConfigManager 7 | from config.config_validator import ConfigValidator 8 | from config.exceptions import ConfigFileNotFoundError, ConfigParseError 9 | from config.trading_mode import TradingMode 10 | from strategies.spacing_type import SpacingType 11 | from strategies.strategy_type import StrategyType 12 | 13 | 14 | class TestConfigManager: 15 | @pytest.fixture 16 | def mock_validator(self): 17 | return Mock(spec=ConfigValidator) 18 | 19 | @pytest.fixture 20 | def config_manager(self, mock_validator, valid_config): 21 | # Mocking both open and os.path.exists to simulate a valid config file 22 | mocked_open = mock_open(read_data=json.dumps(valid_config)) 23 | with patch("builtins.open", mocked_open), patch("os.path.exists", return_value=True): 24 | return ConfigManager("config.json", mock_validator) 25 | 26 | def test_load_config_valid(self, config_manager, valid_config, mock_validator): 27 | mock_validator.validate.assert_called_once_with(valid_config) 28 | assert config_manager.config == valid_config 29 | 30 | def test_load_config_file_not_found(self, mock_validator): 31 | with patch("os.path.exists", return_value=False), pytest.raises(ConfigFileNotFoundError): 32 | ConfigManager("config.json", mock_validator) 33 | 34 | def 
test_load_config_json_decode_error(self, mock_validator): 35 | invalid_json = '{"invalid_json": ' # Malformed JSON 36 | mocked_open = mock_open(read_data=invalid_json) 37 | with ( 38 | patch("builtins.open", mocked_open), 39 | patch("os.path.exists", return_value=True), 40 | pytest.raises(ConfigParseError), 41 | ): 42 | ConfigManager("config.json", mock_validator) 43 | 44 | def test_get_exchange_name(self, config_manager): 45 | assert config_manager.get_exchange_name() == "binance" 46 | 47 | def test_get_trading_fee(self, config_manager): 48 | assert config_manager.get_trading_fee() == 0.001 49 | 50 | def test_get_base_currency(self, config_manager): 51 | assert config_manager.get_base_currency() == "ETH" 52 | 53 | def test_get_quote_currency(self, config_manager): 54 | assert config_manager.get_quote_currency() == "USDT" 55 | 56 | def test_get_initial_balance(self, config_manager): 57 | assert config_manager.get_initial_balance() == 10000 58 | 59 | def test_get_spacing_type(self, config_manager): 60 | assert config_manager.get_spacing_type() == SpacingType.GEOMETRIC 61 | 62 | def test_get_strategy_type(self, config_manager): 63 | assert config_manager.get_strategy_type() == StrategyType.SIMPLE_GRID 64 | 65 | def test_get_trading_mode(self, config_manager): 66 | assert config_manager.get_trading_mode() == TradingMode.BACKTEST 67 | 68 | def test_get_timeframe(self, config_manager): 69 | assert config_manager.get_timeframe() == "1m" 70 | 71 | def test_get_period(self, config_manager): 72 | expected_period = { 73 | "start_date": "2024-07-04T00:00:00Z", 74 | "end_date": "2024-07-11T00:00:00Z", 75 | } 76 | assert config_manager.get_period() == expected_period 77 | 78 | def test_get_start_date(self, config_manager): 79 | assert config_manager.get_start_date() == "2024-07-04T00:00:00Z" 80 | 81 | def test_get_end_date(self, config_manager): 82 | assert config_manager.get_end_date() == "2024-07-11T00:00:00Z" 83 | 84 | def test_get_num_grids(self, config_manager): 85 | 
assert config_manager.get_num_grids() == 20 86 | 87 | def test_get_grid_range(self, config_manager): 88 | expected_range = { 89 | "top": 3100, 90 | "bottom": 2850, 91 | } 92 | assert config_manager.get_grid_range() == expected_range 93 | 94 | def test_get_top_range(self, config_manager): 95 | assert config_manager.get_top_range() == 3100 96 | 97 | def test_get_bottom_range(self, config_manager): 98 | assert config_manager.get_bottom_range() == 2850 99 | 100 | def test_is_take_profit_enabled(self, config_manager): 101 | assert not config_manager.is_take_profit_enabled() 102 | 103 | def test_get_take_profit_threshold(self, config_manager): 104 | assert config_manager.get_take_profit_threshold() == 3700 105 | 106 | def test_get_stop_loss_threshold(self, config_manager): 107 | assert config_manager.get_stop_loss_threshold() == 2830 108 | 109 | def test_is_stop_loss_enabled(self, config_manager): 110 | assert not config_manager.is_stop_loss_enabled() 111 | 112 | def test_get_log_level(self, config_manager): 113 | assert config_manager.get_logging_level() == "INFO" 114 | 115 | def test_should_log_to_file_true(self, config_manager): 116 | assert config_manager.should_log_to_file() is True 117 | 118 | def test_get_trading_mode_invalid_value(self, config_manager): 119 | config_manager.config["exchange"]["trading_mode"] = "invalid_mode" 120 | 121 | with pytest.raises( 122 | ValueError, 123 | match="Invalid trading mode: 'invalid_mode'. Available modes are: backtest, paper_trading, live", 124 | ): 125 | config_manager.get_trading_mode() 126 | 127 | def test_get_spacing_type_invalid_value(self, config_manager): 128 | config_manager.config["grid_strategy"]["spacing"] = "invalid_spacing" 129 | 130 | with pytest.raises( 131 | ValueError, 132 | match="Invalid spacing type: 'invalid_spacing'. 
Available spacings are: arithmetic, geometric", 133 | ): 134 | config_manager.get_spacing_type() 135 | 136 | def test_get_strategy_type_invalid_value(self, config_manager): 137 | config_manager.config["grid_strategy"]["type"] = "invalid_strategy" 138 | 139 | with pytest.raises( 140 | ValueError, 141 | match="Invalid strategy type: 'invalid_strategy'. Available strategies are: simple_grid, hedged_grid", 142 | ): 143 | config_manager.get_strategy_type() 144 | 145 | def test_get_timeframe_default(self, config_manager): 146 | del config_manager.config["trading_settings"]["timeframe"] 147 | assert config_manager.get_timeframe() == "1h" 148 | 149 | def test_get_historical_data_file_default(self, config_manager): 150 | del config_manager.config["trading_settings"]["historical_data_file"] 151 | assert config_manager.get_historical_data_file() is None 152 | 153 | def test_is_take_profit_enabled_default(self, config_manager): 154 | del config_manager.config["risk_management"]["take_profit"] 155 | assert config_manager.is_take_profit_enabled() is False 156 | 157 | def test_get_take_profit_threshold_default(self, config_manager): 158 | del config_manager.config["risk_management"]["take_profit"] 159 | assert config_manager.get_take_profit_threshold() is None 160 | 161 | def test_is_stop_loss_enabled_default(self, config_manager): 162 | del config_manager.config["risk_management"]["stop_loss"] 163 | assert config_manager.is_stop_loss_enabled() is False 164 | 165 | def test_get_stop_loss_threshold_default(self, config_manager): 166 | del config_manager.config["risk_management"]["stop_loss"] 167 | assert config_manager.get_stop_loss_threshold() is None 168 | -------------------------------------------------------------------------------- /core/order_handling/execution_strategy/live_order_execution_strategy.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | from core.services.exceptions import 
DataFetchError 5 | from core.services.exchange_interface import ExchangeInterface 6 | 7 | from ..exceptions import OrderExecutionFailedError 8 | from ..order import Order, OrderSide, OrderStatus, OrderType 9 | from .order_execution_strategy_interface import OrderExecutionStrategyInterface 10 | 11 | 12 | class LiveOrderExecutionStrategy(OrderExecutionStrategyInterface): 13 | def __init__( 14 | self, 15 | exchange_service: ExchangeInterface, 16 | max_retries: int = 3, 17 | retry_delay: int = 1, 18 | max_slippage: float = 0.01, 19 | ) -> None: 20 | self.exchange_service = exchange_service 21 | self.max_retries = max_retries 22 | self.retry_delay = retry_delay 23 | self.max_slippage = max_slippage 24 | self.logger = logging.getLogger(self.__class__.__name__) 25 | 26 | async def execute_market_order( 27 | self, 28 | order_side: OrderSide, 29 | pair: str, 30 | quantity: float, 31 | price: float, 32 | ) -> Order | None: 33 | for attempt in range(self.max_retries): 34 | try: 35 | raw_order = await self.exchange_service.place_order( 36 | pair, 37 | OrderType.MARKET.value.lower(), 38 | order_side.name.lower(), 39 | quantity, 40 | price, 41 | ) 42 | order_result = await self._parse_order_result(raw_order) 43 | 44 | if order_result.status == OrderStatus.CLOSED: 45 | return order_result # Order fully filled 46 | 47 | elif order_result.status == OrderStatus.OPEN: 48 | await self._handle_partial_fill(order_result, pair) 49 | 50 | await asyncio.sleep(self.retry_delay) 51 | self.logger.info(f"Retrying order. 
Attempt {attempt + 1}/{self.max_retries}.") 52 | price = await self._adjust_price(order_side, price, attempt) 53 | 54 | except Exception as e: 55 | self.logger.error(f"Attempt {attempt + 1} failed with error: {e!s}") 56 | await asyncio.sleep(self.retry_delay) 57 | 58 | raise OrderExecutionFailedError( 59 | "Failed to execute Market order after maximum retries.", 60 | order_side, 61 | OrderType.MARKET, 62 | pair, 63 | quantity, 64 | price, 65 | ) 66 | 67 | async def execute_limit_order( 68 | self, 69 | order_side: OrderSide, 70 | pair: str, 71 | quantity: float, 72 | price: float, 73 | ) -> Order | None: 74 | try: 75 | raw_order = await self.exchange_service.place_order( 76 | pair, 77 | OrderType.LIMIT.value.lower(), 78 | order_side.name.lower(), 79 | quantity, 80 | price, 81 | ) 82 | order_result = await self._parse_order_result(raw_order) 83 | return order_result 84 | 85 | except DataFetchError as e: 86 | self.logger.error(f"DataFetchError during order execution for {pair} - {e}") 87 | raise OrderExecutionFailedError( 88 | f"Failed to execute Limit order on {pair}: {e}", 89 | order_side, 90 | OrderType.LIMIT, 91 | pair, 92 | quantity, 93 | price, 94 | ) from e 95 | 96 | except Exception as e: 97 | self.logger.error(f"Unexpected error in execute_limit_order: {e}") 98 | raise OrderExecutionFailedError( 99 | f"Unexpected error during order execution: {e}", 100 | order_side, 101 | OrderType.LIMIT, 102 | pair, 103 | quantity, 104 | price, 105 | ) from e 106 | 107 | async def get_order( 108 | self, 109 | order_id: str, 110 | pair: str, 111 | ) -> Order | None: 112 | try: 113 | raw_order = await self.exchange_service.fetch_order(order_id, pair) 114 | order_result = await self._parse_order_result(raw_order) 115 | return order_result 116 | 117 | except DataFetchError as e: 118 | raise e 119 | 120 | except Exception as e: 121 | raise DataFetchError(f"Unexpected error during order status retrieval: {e!s}") from e 122 | 123 | async def _parse_order_result( 124 | self, 125 | 
raw_order_result: dict, 126 | ) -> Order: 127 | """ 128 | Parses the raw order response from the exchange into an Order object. 129 | 130 | Args: 131 | raw_order_result: The raw response from the exchange. 132 | 133 | Returns: 134 | An Order object with standardized fields. 135 | """ 136 | return Order( 137 | identifier=raw_order_result.get("id", ""), 138 | status=OrderStatus(raw_order_result.get("status", "unknown").lower()), 139 | order_type=OrderType(raw_order_result.get("type", "unknown").lower()), 140 | side=OrderSide(raw_order_result.get("side", "unknown").lower()), 141 | price=raw_order_result.get("price", 0.0), 142 | average=raw_order_result.get("average"), 143 | amount=raw_order_result.get("amount", 0.0), 144 | filled=raw_order_result.get("filled", 0.0), 145 | remaining=raw_order_result.get("remaining", 0.0), 146 | timestamp=raw_order_result.get("timestamp", 0), 147 | datetime=raw_order_result.get("datetime"), 148 | last_trade_timestamp=raw_order_result.get("lastTradeTimestamp"), 149 | symbol=raw_order_result.get("symbol", ""), 150 | time_in_force=raw_order_result.get("timeInForce"), 151 | trades=raw_order_result.get("trades", []), 152 | fee=raw_order_result.get("fee"), 153 | cost=raw_order_result.get("cost"), 154 | info=raw_order_result.get("info", raw_order_result), 155 | ) 156 | 157 | async def _adjust_price( 158 | self, 159 | order_side: OrderSide, 160 | price: float, 161 | attempt: int, 162 | ) -> float: 163 | adjustment = self.max_slippage / self.max_retries * attempt 164 | return price * (1 + adjustment) if order_side == OrderSide.BUY else price * (1 - adjustment) 165 | 166 | async def _handle_partial_fill( 167 | self, 168 | order: Order, 169 | pair: str, 170 | ) -> dict | None: 171 | self.logger.info(f"Order partially filled with {order.filled}. 
Attempting to cancel and retry full quantity.") 172 | 173 | if not await self._retry_cancel_order(order.identifier, pair): 174 | self.logger.error(f"Unable to cancel partially filled order {order.identifier} after retries.") 175 | 176 | async def _retry_cancel_order( 177 | self, 178 | order_id: str, 179 | pair: str, 180 | ) -> bool: 181 | for cancel_attempt in range(self.max_retries): 182 | try: 183 | cancel_result = await self.exchange_service.cancel_order(order_id, pair) 184 | 185 | if cancel_result["status"] == "canceled": 186 | self.logger.info(f"Successfully canceled order {order_id}.") 187 | return True 188 | 189 | self.logger.warning(f"Cancel attempt {cancel_attempt + 1} for order {order_id} failed.") 190 | 191 | except Exception as e: 192 | self.logger.warning(f"Error during cancel attempt {cancel_attempt + 1} for order {order_id}: {e!s}") 193 | 194 | await asyncio.sleep(self.retry_delay) 195 | return False 196 | -------------------------------------------------------------------------------- /core/services/backtest_exchange_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import Any 5 | 6 | import ccxt 7 | import pandas as pd 8 | 9 | from config.config_manager import ConfigManager 10 | from utils.constants import CANDLE_LIMITS, TIMEFRAME_MAPPINGS 11 | 12 | from .exceptions import ( 13 | DataFetchError, 14 | HistoricalMarketDataFileNotFoundError, 15 | UnsupportedExchangeError, 16 | UnsupportedPairError, 17 | UnsupportedTimeframeError, 18 | ) 19 | from .exchange_interface import ExchangeInterface 20 | 21 | 22 | class BacktestExchangeService(ExchangeInterface): 23 | def __init__(self, config_manager: ConfigManager): 24 | self.logger = logging.getLogger(self.__class__.__name__) 25 | self.config_manager = config_manager 26 | self.historical_data_file = self.config_manager.get_historical_data_file() 27 | self.exchange_name = 
self.config_manager.get_exchange_name() 28 | self.exchange = self._initialize_exchange() 29 | 30 | def _initialize_exchange(self) -> ccxt.Exchange | None: 31 | try: 32 | return getattr(ccxt, self.exchange_name)() 33 | except AttributeError: 34 | raise UnsupportedExchangeError(f"The exchange '{self.exchange_name}' is not supported.") from None 35 | 36 | def _is_timeframe_supported(self, timeframe: str) -> bool: 37 | if timeframe not in self.exchange.timeframes: 38 | self.logger.error(f"Timeframe '{timeframe}' is not supported by {self.exchange_name}.") 39 | return False 40 | return True 41 | 42 | def _is_pair_supported(self, pair: str) -> bool: 43 | markets = self.exchange.load_markets() 44 | return pair in markets 45 | 46 | def fetch_ohlcv( 47 | self, 48 | pair: str, 49 | timeframe: str, 50 | start_date: str, 51 | end_date: str, 52 | ) -> pd.DataFrame: 53 | if self.historical_data_file: 54 | if not os.path.exists(self.historical_data_file): 55 | raise HistoricalMarketDataFileNotFoundError( 56 | f"Failed to load OHLCV data from file: {self.historical_data_file}", 57 | ) 58 | 59 | self.logger.info(f"Loading OHLCV data from file: {self.historical_data_file}") 60 | return self._load_ohlcv_from_file(self.historical_data_file, start_date, end_date) 61 | 62 | if not self._is_pair_supported(pair): 63 | raise UnsupportedPairError(f"Pair: {pair} is not supported by {self.exchange_name}") 64 | 65 | if not self._is_timeframe_supported(timeframe): 66 | raise UnsupportedTimeframeError(f"Timeframe '{timeframe}' is not supported by {self.exchange_name}.") 67 | 68 | self.logger.info(f"Fetching OHLCV data for {pair} from {start_date} to {end_date}") 69 | try: 70 | since = self.exchange.parse8601(start_date) 71 | until = self.exchange.parse8601(end_date) 72 | candles_per_request = self._get_candle_limit() 73 | total_candles_needed = (until - since) // self._get_timeframe_in_ms(timeframe) 74 | 75 | if total_candles_needed > candles_per_request: 76 | return 
self._fetch_ohlcv_in_chunks(pair, timeframe, since, until, candles_per_request) 77 | else: 78 | return self._fetch_ohlcv_single_batch(pair, timeframe, since, until) 79 | except ccxt.NetworkError as e: 80 | raise DataFetchError(f"Network issue occurred while fetching OHLCV data: {e!s}") from e 81 | except ccxt.BaseError as e: 82 | raise DataFetchError(f"Exchange-specific error occurred: {e!s}") from e 83 | except Exception as e: 84 | raise DataFetchError(f"Failed to fetch OHLCV data {e!s}.") from e 85 | 86 | def _load_ohlcv_from_file( 87 | self, 88 | file_path: str, 89 | start_date: str, 90 | end_date: str, 91 | ) -> pd.DataFrame: 92 | try: 93 | df = pd.read_csv(file_path, parse_dates=["timestamp"]) 94 | df["timestamp"] = pd.to_datetime(df["timestamp"]) 95 | df.set_index("timestamp", inplace=True) 96 | start_timestamp = pd.to_datetime(start_date).tz_localize(None) 97 | end_timestamp = pd.to_datetime(end_date).tz_localize(None) 98 | filtered_df = df.loc[start_timestamp:end_timestamp] 99 | self.logger.debug(f"Loaded {len(filtered_df)} rows of OHLCV data from file.") 100 | return filtered_df 101 | 102 | except Exception as e: 103 | raise DataFetchError(f"Failed to load OHLCV data from file: {e!s}") from e 104 | 105 | def _fetch_ohlcv_single_batch( 106 | self, 107 | pair: str, 108 | timeframe: str, 109 | since: int, 110 | until: int, 111 | ) -> pd.DataFrame: 112 | ohlcv = self._fetch_with_retry(self.exchange.fetch_ohlcv, pair, timeframe, since) 113 | return self._format_ohlcv(ohlcv, until) 114 | 115 | def _fetch_ohlcv_in_chunks( 116 | self, 117 | pair: str, 118 | timeframe: str, 119 | since: int, 120 | until: int, 121 | candles_per_request: int, 122 | ) -> pd.DataFrame: 123 | all_ohlcv = [] 124 | while since < until: 125 | ohlcv = self._fetch_with_retry(self.exchange.fetch_ohlcv, pair, timeframe, since, limit=candles_per_request) 126 | if not ohlcv: 127 | break 128 | all_ohlcv.extend(ohlcv) 129 | since = ohlcv[-1][0] + 1 130 | self.logger.info(f"Fetched up to 
{pd.to_datetime(since, unit='ms')}") 131 | return self._format_ohlcv(all_ohlcv, until) 132 | 133 | def _format_ohlcv( 134 | self, 135 | ohlcv, 136 | until: int, 137 | ) -> pd.DataFrame: 138 | df = pd.DataFrame(ohlcv, columns=["timestamp", "open", "high", "low", "close", "volume"]) 139 | df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") 140 | df.set_index("timestamp", inplace=True) 141 | until_timestamp = pd.to_datetime(until, unit="ms") 142 | return df[df.index <= until_timestamp] 143 | 144 | def _get_candle_limit(self) -> int: 145 | return CANDLE_LIMITS.get(self.exchange_name, 500) # Default to 500 if not found 146 | 147 | def _get_timeframe_in_ms(self, timeframe: str) -> int: 148 | return TIMEFRAME_MAPPINGS.get(timeframe, 60 * 1000) # Default to 1m if not found 149 | 150 | def _fetch_with_retry( 151 | self, 152 | method, 153 | *args, 154 | retries=3, 155 | delay=5, 156 | **kwargs, 157 | ): 158 | for attempt in range(retries): 159 | try: 160 | return method(*args, **kwargs) 161 | except Exception as e: 162 | if attempt < retries - 1: 163 | self.logger.warning(f"Attempt {attempt + 1} failed. 
Retrying in {delay} seconds...") 164 | time.sleep(delay) 165 | else: 166 | self.logger.error(f"Failed after {retries} attempts: {e}") 167 | raise DataFetchError(f"Failed to fetch data after {retries} attempts: {e!s}") from e 168 | 169 | async def place_order( 170 | self, 171 | pair: str, 172 | order_side: str, 173 | order_type: str, 174 | amount: float, 175 | price: float | None = None, 176 | ) -> dict[str, str | float]: 177 | raise NotImplementedError("place_order is not used in backtesting") 178 | 179 | async def get_balance(self) -> dict[str, Any]: 180 | raise NotImplementedError("get_balance is not used in backtesting") 181 | 182 | async def get_current_price( 183 | self, 184 | pair: str, 185 | ) -> float: 186 | raise NotImplementedError("get_current_price is not used in backtesting") 187 | 188 | async def cancel_order( 189 | self, 190 | order_id: str, 191 | pair: str, 192 | ) -> dict[str, str | float]: 193 | raise NotImplementedError("cancel_order is not used in backtesting") 194 | 195 | async def get_exchange_status(self) -> dict: 196 | raise NotImplementedError("get_exchange_status is not used in backtesting") 197 | 198 | async def close_connection(self) -> None: 199 | self.logger.info("[BACKTEST] Closing WebSocket connection...") 200 | -------------------------------------------------------------------------------- /tests/bot_management/test_bot_controller.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import contextlib 3 | from unittest.mock import Mock, call, patch 4 | 5 | import pytest 6 | 7 | from core.bot_management.bot_controller.bot_controller import BotController 8 | from core.bot_management.event_bus import EventBus, Events 9 | from core.bot_management.grid_trading_bot import GridTradingBot 10 | 11 | 12 | @pytest.mark.asyncio 13 | class TestBotController: 14 | @pytest.fixture 15 | def setup_bot_controller(self): 16 | bot = Mock(spec=GridTradingBot) 17 | bot.strategy = Mock() 18 | 
bot.strategy.get_formatted_orders = Mock(return_value=[]) 19 | bot.get_balances = Mock(return_value={}) 20 | event_bus = Mock(spec=EventBus) 21 | bot_controller = BotController(bot, event_bus) 22 | return bot_controller, bot, event_bus 23 | 24 | @pytest.fixture(autouse=True) 25 | def setup_logging(self, setup_bot_controller): 26 | bot_controller, _, _ = setup_bot_controller 27 | bot_controller.logger = Mock() 28 | 29 | @pytest.fixture 30 | def mock_input(self): 31 | with patch("builtins.input") as mock_input: 32 | yield mock_input 33 | 34 | async def run_command_test(self, bot_controller, mock_input): 35 | """Helper method to run command listener tests.""" 36 | listener_task = asyncio.create_task(bot_controller.command_listener()) 37 | 38 | try: 39 | await asyncio.sleep(0.1) 40 | bot_controller._stop_listening = True 41 | 42 | try: 43 | await asyncio.wait_for(listener_task, timeout=1.0) 44 | except TimeoutError: 45 | listener_task.cancel() 46 | with contextlib.suppress(asyncio.CancelledError): 47 | await listener_task 48 | 49 | finally: 50 | if not listener_task.done(): 51 | listener_task.cancel() 52 | with contextlib.suppress(asyncio.CancelledError): 53 | await listener_task 54 | 55 | @pytest.mark.asyncio 56 | @pytest.mark.timeout(2) 57 | async def test_command_listener_quit(self, mock_input, setup_bot_controller): 58 | bot_controller, _, event_bus = setup_bot_controller 59 | mock_input.return_value = "quit" 60 | event_bus.publish_sync = Mock() 61 | 62 | mock_input.side_effect = ["quit", StopIteration] 63 | 64 | await self.run_command_test(bot_controller, mock_input) 65 | 66 | event_bus.publish_sync.assert_called_once_with(Events.STOP_BOT, "User requested shutdown") 67 | assert bot_controller._stop_listening 68 | 69 | @pytest.mark.asyncio 70 | @pytest.mark.timeout(2) 71 | async def test_command_listener_orders(self, mock_input, setup_bot_controller): 72 | bot_controller, bot, _ = setup_bot_controller 73 | # Set up mock to return "orders" and then raise 
StopIteration 74 | mock_input.side_effect = ["orders", StopIteration] 75 | bot.strategy.get_formatted_orders.return_value = [ 76 | ["BUY", "LIMIT", "OPEN", "50000", "0.1", "2024-01-01", "1", "0.1%"], 77 | ] 78 | 79 | await self.run_command_test(bot_controller, mock_input) 80 | 81 | bot.strategy.get_formatted_orders.assert_called_once() 82 | 83 | @pytest.mark.asyncio 84 | @pytest.mark.timeout(2) 85 | async def test_command_listener_balance(self, mock_input, setup_bot_controller): 86 | bot_controller, bot, _ = setup_bot_controller 87 | mock_input.side_effect = ["balance", StopIteration] 88 | bot.get_balances.return_value = {"USD": 1000, "BTC": 0.1} 89 | 90 | await self.run_command_test(bot_controller, mock_input) 91 | 92 | bot.get_balances.assert_called_once() 93 | 94 | @pytest.mark.asyncio 95 | @pytest.mark.timeout(2) 96 | async def test_command_listener_stop(self, mock_input, setup_bot_controller): 97 | bot_controller, _, event_bus = setup_bot_controller 98 | mock_input.side_effect = ["stop", StopIteration] 99 | event_bus.publish_sync = Mock() 100 | 101 | await self.run_command_test(bot_controller, mock_input) 102 | 103 | event_bus.publish_sync.assert_called_once_with(Events.STOP_BOT, "User issued stop command") 104 | 105 | @pytest.mark.asyncio 106 | @pytest.mark.timeout(2) 107 | async def test_command_listener_restart(self, mock_input, setup_bot_controller): 108 | bot_controller, _, event_bus = setup_bot_controller 109 | mock_input.side_effect = ["restart", StopIteration] 110 | event_bus.publish_sync = Mock() 111 | 112 | await self.run_command_test(bot_controller, mock_input) 113 | 114 | assert event_bus.publish_sync.call_count == 2 115 | event_bus.publish_sync.assert_any_call(Events.STOP_BOT, "User issued restart command") 116 | event_bus.publish_sync.assert_any_call(Events.START_BOT, "User issued restart command") 117 | 118 | @pytest.mark.asyncio 119 | @pytest.mark.timeout(2) 120 | async def test_command_listener_invalid_command(self, mock_input, 
setup_bot_controller): 121 | bot_controller, _, _ = setup_bot_controller 122 | mock_input.side_effect = ["invalid", StopIteration] 123 | 124 | with patch.object(bot_controller.logger, "warning") as mock_logger: 125 | await self.run_command_test(bot_controller, mock_input) 126 | mock_logger.assert_called_once() 127 | 128 | @pytest.mark.asyncio 129 | async def test_handle_stop_event(self, setup_bot_controller): 130 | bot_controller, _, _ = setup_bot_controller 131 | 132 | with patch.object(bot_controller.logger, "info") as mock_logger: 133 | bot_controller._handle_stop_event("Test stop reason") 134 | 135 | assert bot_controller._stop_listening is True 136 | mock_logger.assert_has_calls( 137 | [ 138 | call("Received STOP_BOT event: Test stop reason"), 139 | call("Command listener stopped."), 140 | ], 141 | ) 142 | assert mock_logger.call_count == 2 143 | 144 | @pytest.mark.asyncio 145 | @pytest.mark.timeout(2) 146 | async def test_command_listener_unexpected_error(self, mock_input, setup_bot_controller): 147 | bot_controller, _, _ = setup_bot_controller 148 | mock_input.side_effect = Exception("Unexpected error") 149 | 150 | with patch.object(bot_controller.logger, "error") as mock_logger: 151 | await self.run_command_test(bot_controller, mock_input) 152 | 153 | mock_logger.assert_called_with( 154 | "Unexpected error in command listener: Unexpected error", 155 | exc_info=True, 156 | ) 157 | 158 | @pytest.mark.asyncio 159 | @pytest.mark.timeout(2) 160 | async def test_command_listener_invalid_pause_duration(self, mock_input, setup_bot_controller): 161 | bot_controller, _, _ = setup_bot_controller 162 | mock_input.side_effect = ["pause invalid", StopIteration] 163 | 164 | with patch.object(bot_controller.logger, "warning") as mock_logger: 165 | await self.run_command_test(bot_controller, mock_input) 166 | mock_logger.assert_called_once() 167 | 168 | @pytest.mark.asyncio 169 | @pytest.mark.timeout(2) 170 | async def test_display_orders(self, setup_bot_controller): 171 | 
bot_controller, bot, _ = setup_bot_controller 172 | orders = [["BUY", "LIMIT", "OPEN", "50000", "0.1", "2024-01-01", "1", "0.1%"]] 173 | bot.strategy.get_formatted_orders.return_value = orders 174 | 175 | with patch.object(bot_controller.logger, "info") as mock_logger: 176 | await bot_controller._display_orders() 177 | assert mock_logger.call_count == 2 178 | bot.strategy.get_formatted_orders.assert_called_once() 179 | 180 | @pytest.mark.asyncio 181 | @pytest.mark.timeout(2) 182 | async def test_display_balance(self, setup_bot_controller): 183 | bot_controller, bot, _ = setup_bot_controller 184 | balances = {"USD": 1000, "BTC": 0.1} 185 | bot.get_balances.return_value = balances 186 | 187 | with patch.object(bot_controller.logger, "info") as mock_logger: 188 | await bot_controller._display_balance() 189 | mock_logger.assert_any_call(f"Current balances: {balances}") 190 | bot.get_balances.assert_called_once() 191 | -------------------------------------------------------------------------------- /config/config_validator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from strategies.spacing_type import SpacingType 4 | from strategies.strategy_type import StrategyType 5 | 6 | from .exceptions import ConfigValidationError 7 | from .trading_mode import TradingMode 8 | 9 | 10 | class ConfigValidator: 11 | def __init__(self): 12 | self.logger = logging.getLogger(self.__class__.__name__) 13 | 14 | def validate(self, config): 15 | missing_fields = [] 16 | invalid_fields = [] 17 | missing_fields += self._validate_required_fields(config) 18 | invalid_fields += self._validate_exchange(config) 19 | missing_fields += self._validate_pair(config) 20 | missing_trading_settings, invalid_trading_settings = self._validate_trading_settings(config) 21 | missing_fields += missing_trading_settings 22 | invalid_fields += invalid_trading_settings 23 | missing_grid_settings, invalid_grid_settings = 
self._validate_grid_strategy(config) 24 | missing_fields += missing_grid_settings 25 | invalid_fields += invalid_grid_settings 26 | invalid_fields += self._validate_limits(config) 27 | missing_logging_settings, invalid_logging_settings = self._validate_logging(config) 28 | missing_fields += missing_logging_settings 29 | invalid_fields += invalid_logging_settings 30 | 31 | if missing_fields or invalid_fields: 32 | raise ConfigValidationError(missing_fields=missing_fields, invalid_fields=invalid_fields) 33 | 34 | def _validate_required_fields(self, config): 35 | required_fields = ["exchange", "pair", "trading_settings", "grid_strategy", "risk_management", "logging"] 36 | missing_fields = [field for field in required_fields if field not in config] 37 | if missing_fields: 38 | self.logger.error(f"Missing required fields: {missing_fields}") 39 | return missing_fields 40 | 41 | def _validate_exchange(self, config): 42 | invalid_fields = [] 43 | exchange = config.get("exchange", {}) 44 | 45 | if not exchange.get("name"): 46 | self.logger.error("Exchange name is missing.") 47 | invalid_fields.append("exchange.name") 48 | 49 | trading_fee = exchange.get("trading_fee") 50 | if trading_fee is None or not isinstance(trading_fee, float | int) or trading_fee < 0: 51 | self.logger.error("Invalid or missing trading fee.") 52 | invalid_fields.append("exchange.trading_fee") 53 | 54 | trading_mode_str = exchange.get("trading_mode") 55 | if not trading_mode_str: 56 | invalid_fields.append("exchange.trading_mode") 57 | else: 58 | try: 59 | TradingMode.from_string(trading_mode_str) 60 | except ValueError: 61 | invalid_fields.append("exchange.trading_mode") 62 | 63 | return invalid_fields 64 | 65 | def _validate_pair(self, config): 66 | missing_fields = [] 67 | pair = config.get("pair", {}) 68 | 69 | if not pair.get("base_currency"): 70 | missing_fields.append("pair.base_currency") 71 | if not pair.get("quote_currency"): 72 | missing_fields.append("pair.quote_currency") 73 | 74 | if 
missing_fields: 75 | self.logger.error(f"Missing pair configuration fields: {missing_fields}") 76 | 77 | return missing_fields 78 | 79 | def _validate_trading_settings(self, config): 80 | missing_fields = [] 81 | invalid_fields = [] 82 | trading_settings = config.get("trading_settings", {}) 83 | 84 | if not trading_settings.get("initial_balance"): 85 | missing_fields.append("trading_settings.initial_balance") 86 | 87 | # Validate timeframe 88 | timeframe = trading_settings.get("timeframe") 89 | valid_timeframes = ["1s", "1m", "3m", "5m", "15m", "30m", "1h", "2h", "6h", "12h", "1d", "1w", "1M"] 90 | if timeframe not in valid_timeframes: 91 | self.logger.error(f"Invalid timeframe: {timeframe}. Must be one of {valid_timeframes}.") 92 | invalid_fields.append("trading_settings.timeframe") 93 | 94 | # Validate period 95 | period = trading_settings.get("period", {}) 96 | start_date = period.get("start_date") 97 | end_date = period.get("end_date") 98 | 99 | if not start_date: 100 | missing_fields.append("trading_settings.period.start_date") 101 | if not end_date: 102 | missing_fields.append("trading_settings.period.end_date") 103 | 104 | return missing_fields, invalid_fields 105 | 106 | def _validate_grid_strategy(self, config): 107 | missing_fields = [] 108 | invalid_fields = [] 109 | grid = config.get("grid_strategy", {}) 110 | 111 | grid_type = grid.get("type") 112 | if grid_type is None: 113 | missing_fields.append("grid_strategy.type") 114 | else: 115 | try: 116 | StrategyType.from_string(grid_type) 117 | 118 | except ValueError as e: 119 | self.logger.error(str(e)) 120 | invalid_fields.append("grid_strategy.type") 121 | 122 | spacing = grid.get("spacing") 123 | if spacing is None: 124 | missing_fields.append("grid_strategy.spacing") 125 | else: 126 | try: 127 | SpacingType.from_string(spacing) 128 | 129 | except ValueError as e: 130 | self.logger.error(str(e)) 131 | invalid_fields.append("grid_strategy.spacing") 132 | 133 | num_grids = grid.get("num_grids") 134 | if 
num_grids is None: 135 | missing_fields.append("grid_strategy.num_grids") 136 | elif not isinstance(num_grids, int) or num_grids <= 0: 137 | self.logger.error("Grid strategy 'num_grids' must be a positive integer.") 138 | invalid_fields.append("grid_strategy.num_grids") 139 | 140 | range_ = grid.get("range", {}) 141 | top = range_.get("top") 142 | bottom = range_.get("bottom") 143 | if top is None: 144 | missing_fields.append("grid_strategy.range.top") 145 | if bottom is None: 146 | missing_fields.append("grid_strategy.range.bottom") 147 | 148 | if top is not None and bottom is not None: 149 | if not isinstance(top, int | float) or not isinstance(bottom, int | float): 150 | self.logger.error("'top' and 'bottom' in 'grid_strategy.range' must be numbers.") 151 | invalid_fields.append("grid_strategy.range.top") 152 | invalid_fields.append("grid_strategy.range.bottom") 153 | elif bottom >= top: 154 | self.logger.error("'grid_strategy.range.bottom' must be less than 'grid_strategy.range.top'.") 155 | invalid_fields.append("grid_strategy.range.top") 156 | invalid_fields.append("grid_strategy.range.bottom") 157 | 158 | return missing_fields, invalid_fields 159 | 160 | def _validate_limits(self, config): 161 | invalid_fields = [] 162 | limits = config.get("risk_management", {}) 163 | take_profit = limits.get("take_profit", {}) 164 | stop_loss = limits.get("stop_loss", {}) 165 | 166 | # Validate take profit 167 | if not isinstance(take_profit.get("enabled"), bool): 168 | self.logger.error("Take profit enabled flag must be a boolean.") 169 | invalid_fields.append("risk_management.take_profit.enabled") 170 | 171 | if take_profit.get("threshold") is None or not isinstance(take_profit.get("threshold"), float | int): 172 | self.logger.error("Invalid or missing take profit threshold.") 173 | invalid_fields.append("risk_management.take_profit.threshold") 174 | 175 | # Validate stop loss 176 | if not isinstance(stop_loss.get("enabled"), bool): 177 | self.logger.error("Stop loss 
enabled flag must be a boolean.") 178 | invalid_fields.append("risk_management.stop_loss.enabled") 179 | 180 | if stop_loss.get("threshold") is None or not isinstance(stop_loss.get("threshold"), float | int): 181 | self.logger.error("Invalid or missing stop loss threshold.") 182 | invalid_fields.append("risk_management.stop_loss.threshold") 183 | 184 | return invalid_fields 185 | 186 | def _validate_logging(self, config): 187 | missing_fields = [] 188 | invalid_fields = [] 189 | logging_settings = config.get("logging", {}) 190 | 191 | # Validate log level 192 | log_level = logging_settings.get("log_level") 193 | valid_log_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] 194 | if log_level is None: 195 | missing_fields.append("logging.log_level") 196 | elif log_level.upper() not in valid_log_levels: 197 | self.logger.error(f"Invalid log level: {log_level}. Must be one of {valid_log_levels}.") 198 | invalid_fields.append("logging.log_level") 199 | 200 | # Validate log to file 201 | if not isinstance(logging_settings.get("log_to_file"), bool): 202 | self.logger.error("log_to_file must be a boolean.") 203 | invalid_fields.append("logging.log_to_file") 204 | 205 | if missing_fields: 206 | self.logger.error(f"Missing logging fields: {missing_fields}") 207 | 208 | return missing_fields, invalid_fields 209 | --------------------------------------------------------------------------------