├── .gitignore
├── README.md
├── SierraCharts_Chartbook.Cht
├── live_trading
    ├── __init__.py
    ├── dataframe_pipeline
    │   ├── __init__.py
    │   ├── data_event.py
    │   ├── footprint_dataframe.py
    │   └── main_dataframe.py
    ├── events
    │   ├── __init__.py
    │   ├── account_event.py
    │   └── position_event.py
    └── trade_management
    │   ├── __init__.py
    │   ├── account_manager.py
    │   ├── model_handling.py
    │   ├── run_live_trading.py
    │   ├── trade_data_processor.py
    │   └── trading_logic.py
├── live_trading_app.py
├── model_training
    ├── __init__.py
    ├── agents
    │   ├── __init__.py
    │   └── ppo_train_test_split.py
    ├── gym_envs
    │   ├── __init__.py
    │   └── trading_env.py
    └── preprocessing
    │   ├── __init__.py
    │   ├── dataframe_processor.py
    │   └── feature_engineering.py
├── model_training_app.py
├── requirements.txt
├── sierracharts_data_downloader
    └── sc_data_downloader.py
└── trade29_scpy-0.0.14.tar.gz


/.gitignore:
--------------------------------------------------------------------------------
# Python
*.pyc
*.pyo
*.pyd
__pycache__/

# Environments
.env
.venv
env/
venv/
ENV/

# Distutils
*.egg
*.egg-info/
dist/
build/
eggs/
parts/
var/
sdist/
develop-eggs/

# IDEs
.idea/
.vscode/

# OS generated files
.DS_Store
Thumbs.db

# Data Files
model_training/data/*.csv

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Project Overview

This project is a Python-based application that uses deep reinforcement learning to train a trading model and trade it live on E-mini S&P 500 (ES) futures data through Sierra Charts _(New to Sierra Charts and how to connect it to Python? Check out my articles here https://medium.com/@investinatech)_. Sierra Charts is a powerful trading platform that provides highly granular market data. The project is divided into several modules, each with a specific role in the overall functionality of the application.

## Project Structure

The project is organized into the following main directories:

- [`model_training`]: This directory contains the code for training machine learning models. It includes the `agents` module for training, `gym_envs` for the trading environment, and `preprocessing` for data preprocessing and feature engineering.

- [`live_trading`]: This directory contains the code for live trading. It includes the `trade_management` module for managing trades, `dataframe_pipeline` for processing data events into dataframes, and `events` for handling account and position events.

- [`sierracharts_data_downloader`]: This directory contains the `sc_data_downloader.py` script for downloading data from Sierra Charts.

The project also includes two main application files:

- [`model_training_app.py`]: This is the main entry point for the model training application.

- [`live_trading_app.py`]: This is the main entry point for the live trading application.

## Key Components

### Model Training

The model training process is handled by the [`run_training`] function in [`model_training/agents/ppo_train_test_split.py`]. This function uses the Proximal Policy Optimization (PPO) algorithm from the Stable Baselines3 library to train the model on time-series cross-validation folds (scikit-learn's `TimeSeriesSplit`). The training data is preprocessed and feature engineered using the `DataFrameProcessor` and [`FeatureEngineering`] classes in [`model_training/preprocessing/dataframe_processor.py`] and [`model_training/preprocessing/feature_engineering.py`] respectively.

### Live Trading

The live trading process is handled by the `run_live_trading` function in [`live_trading/trade_management/run_live_trading.py`]. This function uses the trained model to make trading decisions.
The trading data is processed into a dataframe using the [`DataEvent`], [`FootprintDataframe`], and [`MainDataframe`] classes in [`live_trading/dataframe_pipeline/data_event.py`], [`live_trading/dataframe_pipeline/footprint_dataframe.py`], and [`live_trading/dataframe_pipeline/main_dataframe.py`] respectively. The `AccountEvent` and `PositionEvent` classes in [`live_trading/events/account_event.py`] and [`live_trading/events/position_event.py`] are used to handle account and position events.

## Running the Applications

Before running the applications, ensure the following steps are completed:

1. **Download Sierra Chartbook**: I provided the necessary chartbook file [`SierraCharts_Chartbook.Cht`]; open it in Sierra Charts.

2. **Set Up Virtual Environment**: Create a virtual environment (venv) to manage Python dependencies.
   - Use `python -m venv .venv` to create a new venv.
   - Use `.venv\Scripts\activate` to activate the venv on Windows (on macOS/Linux, use `source .venv/bin/activate`).

3. **Install Dependencies**: Install the project's dependencies using `pip install -r requirements.txt`.

4. **Environment Configuration**:
   - Create a `.env` file to store your configuration and sensitive data, such as the `SC_API_KEY` used by the live trading modules. The `.env` file is listed in `.gitignore`, so it stays out of version control.

5. **Data Preparation**:
   - Run the Sierra Chart (SC) data downloader: Execute the `sc_data_downloader.py` script from the [`sierracharts_data_downloader`] directory to fetch the necessary trading data.

6. **Model Training**:
   - Run the model training application: Execute the [`model_training_app.py`] script located in the root of the project directory.

7. **Live Trading**:
   - Run the live trading application: Execute the [`live_trading_app.py`] script for live trading operations.

Ensure each step is successfully completed before proceeding to the next. For detailed instructions on each step, refer to the respective sections of this documentation.
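
Because each step depends on the previous one, a small preflight check can catch the most common gaps before `live_trading_app.py` starts. The sketch below is hypothetical: `preflight_check` is not part of this repository. It assumes the model and scaler paths hard-coded in `trade_data_processor.py` and the `SC_API_KEY` variable read by the live trading modules. It only inspects the process environment, so load your `.env` first (or export the variable).

```python
import os
from pathlib import Path

# Artifact paths that TradeDataProcessor loads at start-up, relative to the repo root.
MODEL_PATH = Path("model_training", "training", "saved_models", "PPO_fold_1.zip")
SCALER_PATH = Path("model_training", "training", "saved_scalers", "scaler_fold_1")


def preflight_check(root="."):
    """Hypothetical helper (not part of this repository).

    Returns a list of problems that would stop live trading from starting
    cleanly; an empty list means the basic prerequisites are in place.
    """
    problems = []

    # The live trading modules read SC_API_KEY (loaded from .env by python-dotenv).
    # This sketch only checks the current process environment.
    if not os.environ.get("SC_API_KEY"):
        problems.append("SC_API_KEY is not set")

    # The trained model and the fitted scaler come from the model training step.
    for artifact in (MODEL_PATH, SCALER_PATH):
        if not (Path(root) / artifact).is_file():
            problems.append(f"missing {artifact}")

    return problems


if __name__ == "__main__":
    issues = preflight_check()
    for issue in issues:
        print("Preflight:", issue)
    if not issues:
        print("Preflight: OK")
```

Returning a list rather than raising on the first failure reports every missing prerequisite in one run.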


## Dependencies

The project's dependencies are listed in the [`requirements.txt`] file. To install these dependencies, run `pip install -r requirements.txt`.

## Environment Variables

The project uses environment variables to store sensitive information such as the `SC_API_KEY` passed to the Sierra Charts bridge calls. These variables are loaded from a [`.env`] file at runtime using the python-dotenv library.

## Data

The project uses data from Sierra Charts, which is downloaded using the `sc_data_downloader.py` script in the [`sierracharts_data_downloader`] directory.

## Ignored Files

The [`.gitignore`] file lists the files and directories that are ignored by Git. These include Python cache files, the `.env` file and virtual environment directories, build and packaging artifacts, IDE-specific files, OS-generated files, and the CSV data files in `model_training/data/`.

## Conclusion

This project is a complete application that uses deep reinforcement learning to train a trading agent and run it live. It demonstrates the use of various Python libraries and techniques, including data preprocessing, feature engineering, deep reinforcement learning, and live trading.
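
As described under Live Trading, the trained model's output drives the orders. In `live_trading/trade_management/trading_logic.py` the PPO action is treated as a discrete code: 0 opens a long, 2 opens a short, 1 closes a long, 3 closes a short, and 4 means no action. Entries are one-contract orders with `target_offset=10` and `stop_offset=10`. The hypothetical `order_intent` function below mirrors only that branch structure. It ignores the session window and the stop conditions that the real function also checks, and it makes the mapping easy to unit-test.

```python
# Action codes as used by execute_trading_logic() in trading_logic.py.
LONG_ENTRY, LONG_EXIT, SHORT_ENTRY, SHORT_EXIT, NO_ACTION = 0, 1, 2, 3, 4

# Position encoding used by PositionEvent/AccountManager.
FLAT, LONG, SHORT = 0, 1, -1


def order_intent(position, action):
    """Hypothetical helper: what trading_logic.py would do for one bar.

    Mirrors only the position/action branches; the session window and the
    stop-loss / profit-target flags checked by the real function are ignored.
    """
    if position == FLAT:
        if action == LONG_ENTRY:
            return "buy 1"   # submit_order(is_buy=True, qty=1, ...) with target/stop offsets of 10
        if action == SHORT_ENTRY:
            return "sell 1"  # submit_order(is_buy=False, qty=1, ...) with target/stop offsets of 10
        return "hold"        # NO_ACTION, or an exit signal while already flat
    if position == LONG and action == LONG_EXIT:
        return "flatten"     # flatten_and_cancel()
    if position == SHORT and action == SHORT_EXIT:
        return "flatten"     # flatten_and_cancel()
    return "hold"            # any other action leaves the open position untouched


if __name__ == "__main__":
    for pos in (FLAT, LONG, SHORT):
        print(pos, [order_intent(pos, a) for a in range(5)])
```

An action that does not apply to the current position (for example, a long exit while flat) is simply ignored, which is why the agent can emit any of the five codes on every bar.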
--------------------------------------------------------------------------------
/SierraCharts_Chartbook.Cht:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/SierraCharts_Chartbook.Cht
--------------------------------------------------------------------------------
/live_trading/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/live_trading/__init__.py
--------------------------------------------------------------------------------
/live_trading/dataframe_pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/live_trading/dataframe_pipeline/__init__.py
--------------------------------------------------------------------------------
/live_trading/dataframe_pipeline/data_event.py:
--------------------------------------------------------------------------------
# Subscribes to Sierra Charts graph data (including volume-by-price tuples) and processes it into a DataFrame

import pandas as pd
import os
from dotenv import find_dotenv, load_dotenv


# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


class DataEvent:
    def __init__(self, sc):
        self.sc = sc
        self.subscribe()


    def subscribe(self):
        self.data_reqid = self.sc.graph_data_request(
            key=sc_api_key,
            historical_init_bars=30,
            realtime_update_bars=50,
            include_vbp=True,
            #update_frequency=1,
            on_bar_close=True,
            base_data='1;2;3;4;5',
            sg_data="ID2.[SG1;SG10];ID4.[SG1];ID5.[SG2-SG3];ID6.[SG1]",
        )


    def process_dataframe(self, df):
        df = df.sort_values(by=['Date', 'Price'])
        df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d %H:%M:%S')
        df.set_index('Date', inplace=True)
        df = df.dropna()  # Remove rows with missing values
        return df


    def process_data_event(self, msg):
        msg.df.columns = [
            'Date',
            'BarNumber',
            'Open',
            'High',
            'Low',
            'Close',
            'Volume',
            'Delta',
            'CVD',
            'TodayOpen',
            'PrevHigh',
            'PrevLow',
            'VWAP',
            'Tuples'
        ]


        # Each bar carries a list of volume-at-price tuples; explode to one row per price level
        exploded_df = msg.df.explode('Tuples')
        tuple_cols = ['Price', 'Bid', 'Ask', 'TotalVolume', 'NumberOfTrades']
        exploded_df[tuple_cols] = pd.DataFrame(exploded_df['Tuples'].tolist(), index=exploded_df.index)
        exploded_df.drop(columns=['BarNumber', 'Tuples'], inplace=True)
        print(f"After processing data, DataFrame shape: {exploded_df.shape}")
        return self.process_dataframe(exploded_df)

--------------------------------------------------------------------------------
/live_trading/dataframe_pipeline/footprint_dataframe.py:
--------------------------------------------------------------------------------
# footprint_dataframe.py

import pandas as pd

class FootprintDataframe:
    def __init__(self, df):
        # Use the provided DataFrame, or an empty one if None (no copy is made)
        self.df = df if df is not None else pd.DataFrame()


    def process_footprint_dataframe(self):
        """
        Computes per-bar footprint levels (POC, HVN, LVN):

        - Uses pandas groupby/apply to process each bar (timestamp) instead of
          looping over the groups manually.

        - The custom function 'calculate_levels' is applied to each group to
          find the point of control (POC), high-volume node (HVN) and
          low-volume node (LVN) for that bar.

        - The per-bar levels are merged back onto the original DataFrame and
          missing values are forward-filled in place.
        """
        # Ensure DataFrame has necessary columns
        required_columns = {'TotalVolume', 'Price'}
        if not required_columns.issubset(self.df.columns):
            raise ValueError(f"DataFrame must contain columns: {required_columns}")

        # Group the price-level rows by bar timestamp (the Date index).
        # groupby also handles a unique index (one row per group), so no special case is needed.
        grouped = self.df.groupby(level=0)


        def calculate_levels(group):
            sorted_volumes = group.sort_values('TotalVolume')
            poc = sorted_volumes.iloc[-1]  # highest-volume price level
            hvn = sorted_volumes.iloc[-2] if len(sorted_volumes) > 1 else poc  # second-highest volume
            lvn = sorted_volumes.iloc[0]  # lowest-volume price level

            return pd.Series({
                'POC_Price': poc['Price'],
                'POC_Volume': poc['TotalVolume'],
                'HVN_Price': hvn['Price'],
                'HVN_Volume': hvn['TotalVolume'],
                'LVN_Price': lvn['Price'],
                'LVN_Volume': lvn['TotalVolume']
            })

        combined_df = grouped.apply(calculate_levels)

        # Merge with Original DataFrame
        self.df = self.df.merge(combined_df, left_index=True, right_index=True, how='left')


        # Handle Missing Values
        self.df.ffill(inplace=True)

        print(self.df.tail(50))
        print(f"After processing footprint dataframe, DataFrame shape: {self.df.shape}")
        return self.df.dropna()

--------------------------------------------------------------------------------
/live_trading/dataframe_pipeline/main_dataframe.py:
--------------------------------------------------------------------------------
import pandas as pd
from live_trading.dataframe_pipeline.footprint_dataframe import FootprintDataframe

class MainDataframe:
    def __init__(self, df):
        # Use the provided DataFrame, or an empty one if None (no copy is made)
        self.df = df if df is not None else pd.DataFrame()


    def process_main_dataframe(self):
        # Create an instance of FootprintDataframe
        footprintdataframe = FootprintDataframe(self.df)

        # Compute the per-bar POC/HVN/LVN levels
        fp_df = footprintdataframe.process_footprint_dataframe()

        # Grouping the data by timestamp
        grouped_data = fp_df.groupby(fp_df.index)

        # Collect one row of POC, HVN, LVN and bar data per timestamp
        poc_hvn_lvn_data = []

        for date, group in grouped_data:
            # POC, HVN, and LVN data: after the merge these values are identical
            # on every row of a bar, so the first row is representative
            poc_price = group['POC_Price'].iloc[0]
            poc_volume = group['POC_Volume'].iloc[0]
            hvn_price = group['HVN_Price'].iloc[0]
            hvn_volume = group['HVN_Volume'].iloc[0]
            lvn_price = group['LVN_Price'].iloc[0]
            lvn_volume = group['LVN_Volume'].iloc[0]

            # Additional columns data
            open_price = group['Open'].iloc[0]
            high_price = group['High'].iloc[0]
            low_price = group['Low'].iloc[0]
            close_price = group['Close'].iloc[0]
            volume = group['Volume'].iloc[0]
            delta = group['Delta'].iloc[0]
            cvd = group['CVD'].iloc[0]

            poc_hvn_lvn_data.append((date, poc_price, poc_volume, hvn_price, hvn_volume, lvn_price, lvn_volume,
                                     open_price, high_price, low_price, close_price, volume, delta, cvd))

        # Creating a DataFrame from the aggregated data
        columns = ['Date', 'POC_Price', 'POC_Volume', 'HVN_Price', 'HVN_Volume', 'LVN_Price', 'LVN_Volume',
                   'Open', 'High', 'Low', 'Close', 'Volume', 'Delta', 'CVD']

        new_dataframe = pd.DataFrame(poc_hvn_lvn_data, columns=columns)
        new_dataframe.set_index('Date', inplace=True)

        print(f"After processing main dataframe, DataFrame shape: {new_dataframe.shape}")
        print(new_dataframe.tail(10))
        # Return the enhanced DataFrame
        return new_dataframe

--------------------------------------------------------------------------------
/live_trading/events/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/live_trading/events/__init__.py
--------------------------------------------------------------------------------
/live_trading/events/account_event.py:
--------------------------------------------------------------------------------
import os
from dotenv import find_dotenv, load_dotenv

# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


class AccountEvent:
    def __init__(self, sc):
        self.sc = sc
        self.acct_reqid = None
        self.subscribe()

        self.available_funds = None
        self.trade_account = None
        self.open_positions_pnl = None
        self.daily_pnl = None
        self.cash_balance = None


    def subscribe(self):
        self.acct_reqid = self.sc.get_account_status(key=sc_api_key, subscribe=True)
        #print(self.msg.dict)


    def process_account_event(self, msg):
        #print(msg)
        data = {
            'available_funds': msg.available_funds,
            'trade_account': msg.trade_account,
            'cash_balance': msg.cash_balance,
            'open_positions_pnl': msg.open_positions_pnl,
            'daily_pnl': msg.daily_pnl
        }
        self.available_funds = data['available_funds']
        self.trade_account = data['trade_account']
        self.cash_balance = data['cash_balance']
        self.open_positions_pnl = data['open_positions_pnl']
        self.daily_pnl = data['daily_pnl']
        return data
--------------------------------------------------------------------------------
/live_trading/events/position_event.py:
--------------------------------------------------------------------------------
import os
from dotenv import find_dotenv, load_dotenv

# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


class PositionEvent:
    def __init__(self, sc):
        self.sc = sc
        self.pos_reqid = None
        self.subscribe()

        self.position = None
        self.avg_price = None
        self.open_pnl = 0

        self.cumulative_pnl = 0.0
        self.previous_open_pnl = 0.0


    def subscribe(self):
        self.pos_reqid = self.sc.get_position_status(key=sc_api_key, subscribe=True)


    def process_position_event(self, msg):
        # Trade position: 1.0 = long, -1.0 = short, 0.0 = flat
        position = 1.0 if msg.qty > 0 else -1.0 if msg.qty < 0 else 0.0

        # If a position is still open, accumulate the change in open PnL
        if position != 0:
            change_in_pnl = msg.open_pnl - self.previous_open_pnl
            self.cumulative_pnl += change_in_pnl
            self.previous_open_pnl = msg.open_pnl
        # If the position is closed, reset the open-PnL baseline for the next trade
        else:
            self.previous_open_pnl = 0.0

        data = {
            'position': position,
            'avg_price': msg.avg_price,
            'open_pnl': msg.open_pnl,
            'cumulative_pnl': self.cumulative_pnl,
            'previous_open_pnl': self.previous_open_pnl,
        }
        self.position = data['position']
        self.avg_price = data['avg_price']
        self.open_pnl = data['open_pnl']
        self.cumulative_pnl = data['cumulative_pnl']
        self.previous_open_pnl = data['previous_open_pnl']
        return data
--------------------------------------------------------------------------------
/live_trading/trade_management/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/live_trading/trade_management/__init__.py
--------------------------------------------------------------------------------
/live_trading/trade_management/account_manager.py:
--------------------------------------------------------------------------------
# account_manager.py

class AccountManager:
    DEFAULT_BEGINNING_BALANCE = 50000
    DEFAULT_HIGH_WATER_MARK = 0.0
    STOP_LOSS_THRESHOLD = 1000
    PROFIT_TARGET_THRESHOLD = 1000


    def __init__(self):
        self.beginning_balance = self.DEFAULT_BEGINNING_BALANCE
        self.high_water_mark = self.DEFAULT_HIGH_WATER_MARK
        self.current_balance = None
        self.stop_trading = False
        self.stop_trading_after_profit_target = False
        self.current_position = None
        self.entry_balance = None


    def _initialize_balance(self, position_event):
        """Recomputes the current balance as the beginning balance plus the cumulative PnL."""
        if position_event.cumulative_pnl is None:
            print("Warning: cumulative_pnl is not initialized.")
            position_event.cumulative_pnl = 0.0
        print("cumulative_pnl:", position_event.cumulative_pnl)
        self.current_balance = self.beginning_balance + position_event.cumulative_pnl
        print(f"Account Value: {self.current_balance}")


    def _update_high_water_mark(self):
        """Updates the high water mark if the current balance exceeds it."""
        if self.current_balance > self.high_water_mark:
            self.high_water_mark = self.current_balance
        print(f"High water mark: {self.high_water_mark}")


    def _check_stop_conditions(self, position_event):
        """Checks and updates trading stop conditions."""
        # Stop loss: balance has fallen STOP_LOSS_THRESHOLD below the high water mark
        stop_loss_flag = 0
        if self.current_balance <= (self.high_water_mark - self.STOP_LOSS_THRESHOLD):
            print("Agent Message: Stop loss reached. I'm no longer allowed to trade!")
            stop_loss_flag = 1
            self.stop_trading = True

        # Profit target: cumulative PnL has reached PROFIT_TARGET_THRESHOLD
        profit_target_flag = 0
        if position_event.cumulative_pnl >= self.PROFIT_TARGET_THRESHOLD:
            profit_target_flag = 1
            self.stop_trading_after_profit_target = True

        return stop_loss_flag, profit_target_flag


    def _update_position_status(self, position_event):
        """Updates the position status."""
        self.current_position = position_event.position
        long_position = int(self.current_position == 1.0)
        short_position = int(self.current_position == -1.0)
        print(f"Long position: {long_position}")
        print(f"Short position: {short_position}")

        if self.current_position in [1.0, -1.0]:
            self.entry_balance = self.current_balance
        else:
            self.entry_balance = None
        print(f"Entry balance: {self.entry_balance}")

        drawdown = min(position_event.open_pnl, 0) if position_event.position is not None else 0
        print(f"Drawdown: {drawdown}")
        print(f"Open PnL: {position_event.open_pnl}")

        return long_position, short_position, drawdown


    def reset_account(self, position_event):
        """Resets the account state for a new trading session."""
        #print(f"Resetting: Current Cumulative PnL: {position_event.cumulative_pnl}, High Water Mark: {self.high_water_mark}")
        position_event.cumulative_pnl = 0.0
        self.high_water_mark = self.DEFAULT_HIGH_WATER_MARK
        # Clear the stop flags so the agent can trade again in the next session
        self.stop_trading = False
        self.stop_trading_after_profit_target = False
        #print(f"Reset Complete: Cumulative PnL: {position_event.cumulative_pnl}, High Water Mark: {self.high_water_mark}")


    def manage_account(self, position_event):
        """Manages the account based on the given position event."""
        print(f"Beginning balance: {self.beginning_balance}")

        self._initialize_balance(position_event)
        self._update_high_water_mark()
        stop_loss_flag, profit_target_flag = self._check_stop_conditions(position_event)
        long_position, short_position, drawdown = self._update_position_status(position_event)

        return {
            'stop_trading': self.stop_trading,
            'stop_trading_after_profit_target': self.stop_trading_after_profit_target,
            'current_position': self.current_position,
            'entry_balance': self.entry_balance,
            'drawdown': drawdown,
            'stop_loss_flag': stop_loss_flag,
            'profit_target_flag': profit_target_flag,
            'long_position': long_position,
            'short_position': short_position,
            'high_water_mark': self.high_water_mark,
            'cumulative_pnl': position_event.cumulative_pnl,
            'beginning_balance': self.beginning_balance,
            'stop_loss': self.STOP_LOSS_THRESHOLD,
            'current_balance': self.current_balance,
        }
--------------------------------------------------------------------------------
/live_trading/trade_management/model_handling.py:
--------------------------------------------------------------------------------
# model_handling.py

import joblib
import numpy as np
from stable_baselines3 import PPO

class ModelHandler:
    def __init__(self, model_path, scaler_path):
        self.model = PPO.load(model_path)
        self.scaler = joblib.load(scaler_path)


    def prepare_observation(self, df_enhanced, long_position, short_position, current_balance, stop_loss_flag, drawdown, open_pnl):
        # Scale the market features with the scaler fitted during training
        live_features_scaled = self.scaler.transform(df_enhanced)
        # Use the most recent row (df_enhanced is sorted by Date, then Price) and append the account state
        obs = np.hstack((live_features_scaled[-1], [long_position, short_position, current_balance, stop_loss_flag, drawdown, open_pnl]))
        return obs


    def predict_action(self, observation):
        action, _states = self.model.predict(observation)
        return action

--------------------------------------------------------------------------------
/live_trading/trade_management/run_live_trading.py:
--------------------------------------------------------------------------------
from live_trading.trade_management.trade_data_processor import TradeDataProcessor

def run_live_trading():
    processor = TradeDataProcessor()
    processor.process_data()
--------------------------------------------------------------------------------
/live_trading/trade_management/trade_data_processor.py:
--------------------------------------------------------------------------------
from datetime import datetime

from trade29.sc.bridge import SCBridge

from live_trading.trade_management.model_handling import ModelHandler
from live_trading.trade_management.account_manager import AccountManager
from live_trading.trade_management.trading_logic import execute_trading_logic

from live_trading.events.account_event import AccountEvent
from live_trading.events.position_event import PositionEvent

from live_trading.dataframe_pipeline.data_event import DataEvent
from live_trading.dataframe_pipeline.footprint_dataframe import FootprintDataframe

import os
from dotenv import find_dotenv, load_dotenv


# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


class TradeDataProcessor:
    def __init__(self, start_time="08:30:00", end_time="14:30:00"):

        # Create an instance of SCBridge
        self.sc = SCBridge()

        # Pass the SCBridge instance to our DataEvent, AccountEvent, PositionEvent
        self.data_event = DataEvent(self.sc)
        self.account_event = AccountEvent(self.sc)
        self.position_event = PositionEvent(self.sc)

        self.start_time = datetime.strptime(start_time, "%H:%M:%S").time()
        self.end_time = datetime.strptime(end_time, "%H:%M:%S").time()

        # Initialize ModelHandler
        self.model_handler = ModelHandler(
            os.path.join("model_training", "training", "saved_models", "PPO_fold_1.zip"),
            os.path.join("model_training", "training", "saved_scalers", "scaler_fold_1")
        )

        # Initialize account manager
        self.account_manager = AccountManager()


    def prepare_data(self, data_message):
        # Extract and process the data from the message
        self.df = self.data_event.process_data_event(data_message)

        # Add the footprint (POC/HVN/LVN) features the model expects
        footprint_dataframe = FootprintDataframe(self.df)
        df_enhanced = footprint_dataframe.process_footprint_dataframe()
        return df_enhanced


    def manage_account(self):
        # Update balance, high water mark and stop conditions from the latest position event
        self.account_info = self.account_manager.manage_account(self.position_event)
        return self.account_info


    def process_data(self):
        # Event loop: block on the SCBridge response queue and dispatch each message by request ID
        while True:
            # Get the next message
            self.msg = self.sc.get_response_queue().get()

            # Check if the request ID is for account data
            if self.msg.request_id == self.account_event.acct_reqid:
                self.account_event.process_account_event(self.msg)

            # Check if the request ID is for position data
            elif self.msg.request_id == self.position_event.pos_reqid:
                self.position_event.process_position_event(self.msg)

            # Check if the request ID is for chart (bar) data
            elif self.msg.request_id == self.data_event.data_reqid:
                df_enhanced = self.prepare_data(self.msg)
                #print(df_enhanced)

                # Call the manage_account method and get the updated account info
                self.account_info = self.manage_account()

                # Model handling
                obs = self.model_handler.prepare_observation(
                    df_enhanced,
                    self.account_info['long_position'],
                    self.account_info['short_position'],
                    self.account_info['current_balance'],
                    self.account_info['stop_loss_flag'],
                    self.account_info['drawdown'],
                    self.position_event.open_pnl,
                )
                action = self.model_handler.predict_action(obs)

                # Trading logic: time of the most recent bar
                current_time = self.df.index[-1].time()

                # Call the execute_trading_logic function
                execute_trading_logic(
                    self.sc,
                    self.account_info,
                    current_time,
                    action,
                    self.start_time,
                    self.end_time
                )

                # Exit time - Check if the current time is past the exit time
                exit_time = datetime.strptime("14:31:00", "%H:%M:%S").time()
                if current_time >= exit_time:

                    # Reset values for the next trading session
                    self.account_manager.reset_account(position_event=self.position_event)

                    # Check if there are any open positions
                    if self.account_info['current_position'] != 0:
                        self.account_info['current_position'] = 0
                        self.sc.flatten_and_cancel(key=sc_api_key)  # Flatten the position and cancel all open orders
                        print("End of Day: Exiting all positions")

--------------------------------------------------------------------------------
/live_trading/trade_management/trading_logic.py:
--------------------------------------------------------------------------------
# trading_logic.py

import os
from dotenv import find_dotenv, load_dotenv

# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


def execute_trading_logic(sc, account_info, current_time, action, start_time, end_time):
    # Action codes from the PPO model:
    # 0 = long entry, 1 = long exit, 2 = short entry, 3 = short exit, 4 = no action

    # Flatten if the stop-loss condition has been hit
    if account_info['stop_trading'] and account_info['current_position'] != 0:
        sc.flatten_and_cancel(key=sc_api_key)  # Flatten the position and cancel all open orders
        print("Agent Message: Exiting due to stop trading condition being True")
        account_info["current_position"] = 0

    # Flatten if the profit target has been reached
    if account_info['stop_trading_after_profit_target'] and account_info['current_position'] != 0:
        sc.flatten_and_cancel(key=sc_api_key)
        print("Agent Message: Exiting due to stop trading after profit target condition being True")
        account_info['current_position'] = 0

    # Only act inside the session window and while no stop condition is active
    if start_time <= current_time <= end_time and not account_info['stop_trading'] \
            and not account_info['stop_trading_after_profit_target']:
        if account_info['current_position'] == 0:

            # Buy entry
            if action == 0:
                print("Long condition")
                sc.submit_order(key=sc_api_key, qty=1, is_buy=True, target_enabled=True, target_offset=10, stop_enabled=True, stop_offset=10)
                account_info['current_position'] = 1
                print("Entry balance:", account_info["entry_balance"])

            # Short entry
            elif action == 2:
                print("Short condition")
                sc.submit_order(key=sc_api_key, qty=1, is_buy=False, target_enabled=True, target_offset=10, stop_enabled=True, stop_offset=10)
                account_info['current_position'] = -1
                print("Entry balance:", account_info['entry_balance'])

            # No action
            elif action == 4:
                print("No action")

        # Long Position Open
        elif account_info['current_position'] == 1:

            # Long exit
            if action == 1:
                sc.flatten_and_cancel(key=sc_api_key)  # Flatten the position and cancel all open orders
                print("Long exit")
                account_info['current_position'] = 0

        # Short Position Open
        elif account_info['current_position'] == -1:

            # Short exit
            if action == 3:
                sc.flatten_and_cancel(key=sc_api_key)  # Flatten the position and cancel all open orders
                print("Short exit")
                account_info['current_position'] = 0
--------------------------------------------------------------------------------
/live_trading_app.py:
--------------------------------------------------------------------------------
from live_trading.trade_management.run_live_trading import run_live_trading

def main():

    run_live_trading()

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/model_training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/model_training/__init__.py
--------------------------------------------------------------------------------
/model_training/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/model_training/agents/__init__.py
--------------------------------------------------------------------------------
/model_training/agents/ppo_train_test_split.py:
--------------------------------------------------------------------------------
import os

import joblib
import pandas as pd
from sklearn.model_selection import TimeSeriesSplit
from sklearn.preprocessing import MinMaxScaler
from stable_baselines3 import PPO

from model_training.preprocessing.dataframe_processor import DataFrameProcessor
from model_training.preprocessing.feature_engineering import FeatureEngineering
from model_training.gym_envs.trading_env import TradingEnv


def run_training():
    processor = DataFrameProcessor(os.path.join('model_training', 'data', 'example_esDataset.csv'))
    df = processor.process_data()

    feature_engineer = FeatureEngineering(df)
    df_enhanced = feature_engineer.perform_feature_engineering()

    # Number of splits for time-series cross-validation
    n_splits = 2

    # TimeSeriesSplit: each fold trains on earlier rows and tests on the rows that follow
    tscv = TimeSeriesSplit(n_splits=n_splits)

    # Store evaluation metrics for each fold
    evaluation_results = []

    for fold_num, (train_index, test_index) in enumerate(tscv.split(df_enhanced), start=1):
        print(f"Fold {fold_num}:")

        # Define the log path for this fold
        log_path = os.path.join('model_training', 'training', 'logs', f'fold_{fold_num}')
        os.makedirs(log_path, exist_ok=True)
        print(f"Log path for fold {fold_num}: {log_path}")

        # Splitting the data
        features_train, features_test = df_enhanced.iloc[train_index], df_enhanced.iloc[test_index]

        # Scaling the features (the scaler is fit on the training fold only)
        scaler = MinMaxScaler()
        features_train_scaled = scaler.fit_transform(features_train)
        features_test_scaled = scaler.transform(features_test)

        # Convert the scaled data into a DataFrame
        features_train_scaled = pd.DataFrame(features_train_scaled, columns=features_train.columns, index=features_train.index)
        features_test_scaled = pd.DataFrame(features_test_scaled, columns=features_test.columns, index=features_test.index)
        print(f"After scaling, training data shape: {features_train_scaled.shape}, test data shape: {features_test_scaled.shape}")

        # Initialize and train the model.
        # Note: TradingEnv computes PnL from the 'Price' column, so it is given the
        # unscaled features; the scaled frames above are currently unused.
        #env_train = TradingEnv(features_train_scaled, start_time="08:30:00", end_time="14:30:00")
        env_train = TradingEnv(features_train, start_time="08:30:00", end_time="14:30:00")
        model = PPO("MlpPolicy", env_train, verbose=1, tensorboard_log=log_path)
        model.learn(total_timesteps=500)

        # Evaluate the model on the test fold
        #env_test = TradingEnv(features_test_scaled, start_time="08:30:00", end_time="14:30:00")
        env_test = TradingEnv(features_test, start_time="08:30:00", end_time="14:30:00")
        obs = env_test.reset()
        total_rewards = 0
        done = False
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env_test.step(action)
            total_rewards += reward
            env_test.render(action=action, reward=reward)

        # Store evaluation results
        evaluation_results.append(total_rewards)

        # Define the path to save the model
        PPO_path = os.path.join('model_training', 'training', 'saved_models', f'PPO_fold_{fold_num}')

        # Save the model
        model.save(PPO_path)

        # Create the directory for the scaler if it doesn't exist
        scaler_directory = os.path.join('model_training', 'training', 'saved_scalers')
        os.makedirs(scaler_directory, exist_ok=True)

        # Save the scaler
        scaler_filename = os.path.join(scaler_directory, f'scaler_fold_{fold_num}')
        joblib.dump(scaler, scaler_filename)

    # Analysis of evaluation results across all folds
    print("Evaluation Results:", evaluation_results)
--------------------------------------------------------------------------------
/model_training/gym_envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/model_training/gym_envs/__init__.py
--------------------------------------------------------------------------------
/model_training/gym_envs/trading_env.py:
--------------------------------------------------------------------------------
##############################################
################# IMPORTS ####################
import gym
from gym import spaces
import numpy as np
from datetime import datetime


#######################################################
###### TRADING ENVIRONMENT FOR FINANCIAL MARKETS ######
class TradingEnv(gym.Env):
    """
    A custom trading environment compatible with the OpenAI Gym interface.
    It simulates intraday trading with long entries/exits, short entries/exits, and holding.

    Parameters:
    - features (DataFrame): Historical market data features, indexed by timestamp and including a 'Price' column.
    - initial_balance (float): The starting balance for the trading account.
    - tick_value (float): Multiplier applied to the price difference when computing PnL.
    - stop_loss (float): Maximum allowed drop of the balance below its high water mark before the stop loss flag is set.
    - profit_target (float): The profit target (stored, but not currently used by the environment).
    - start_time (str): The opening time of the trading session ("HH:MM:SS").
    - end_time (str): The closing time of the trading session ("HH:MM:SS").
    """

    def __init__(
        self,
        features,
        initial_balance=50000,
        tick_value=12.50,
        stop_loss=1000,
        profit_target=1000,
        start_time="08:30:00",
        end_time="14:30:00",
    ):
        super(TradingEnv, self).__init__()

        # Keep the column names and index of the market data features
        self.feature_columns = features.columns
        self.dates = features.index

        # Setting up trading session timings
        self.start_time = datetime.strptime(start_time, "%H:%M:%S").time()
        self.end_time = datetime.strptime(end_time, "%H:%M:%S").time()

        # Validate the trading session timings
        if self.start_time >= self.end_time:
            raise ValueError("Start time must be earlier than end time.")

        # Initialize variables for account management
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.tick_value = tick_value
        self.high_water_mark = initial_balance
        self.stop_loss = stop_loss
        self.profit_target = profit_target

        # Position management variables
        self.position = None
        self.entry_balance = None
        self.entry_price = None
        self.last_trade_reward = 0
        self.stop_loss_flag = 0

        # Organize market data chronologically (this DataFrame is what the environment steps through)
        self.features = features.sort_index()
        self.unique_timestamps = self.features.index.unique()

        # Define the action space:
        # 0 = long entry, 1 = long exit, 2 = short entry, 3 = short exit, 4 = no action
        self.action_space = spaces.Discrete(5)

        # Define the observation space (market data + account information)
        market_data_shape = features.shape[1]
        additional_info_shape = 6  # Balance, long flag, short flag, stop loss flag, open PnL, drawdown
        total_observation_shape = market_data_shape + additional_info_shape
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(total_observation_shape,), dtype=np.float32)

        # Initialize the environment to its start state
        self.reset()


    #####################################################
    ####### RESET METHOD FOR TRADING ENVIRONMENT ########
    def reset(self):
        """
        Resets the trading environment to its initial state.
        This method is typically called at the beginning of each new episode.

        Returns:
        - The initial observation of the environment.
        """

        # Reset the index for the current timestamp to the start of the dataset
        self.current_timestamp_index = 0
        self.current_timestamp = self.unique_timestamps[self.current_timestamp_index]

        # Reset the row counter for the current state in the dataset
        self.current_row = 0

        # Reinitialize the account balance to the initial balance
        self.balance = self.initial_balance

        # Reset position-related variables to None, indicating no current position
        self.position = None
        self.entry_price = None
        self.high_water_mark = self.initial_balance  # Reset the high water mark to the initial balance

        # Reset the step at which the current position was entered (if any)
        self.entry_step = None

        # Reset the stop loss flag, used to indicate if stop loss was triggered in the last step
        self.stop_loss_flag = 0

        # Generate the next observation after the reset
        return self._next_observation()


    #######################################################
    ############# NEXT OBSERVATION METHOD #################
    def _next_observation(self):
        """
        Generates the next observation of the market and the agent's state.

        This method is called to update the state of the environment after each action.

        Returns:
        - The next observation, which includes both market data and account information.
        """

        # Extract the market data for the current timestamp from the features DataFrame
        filtered_df = self.features.loc[[self.current_timestamp]]

        # Ensure that the current row is within the bounds of the filtered data
        if self.current_row < len(filtered_df):
            row = filtered_df.iloc[self.current_row]
            market_data = row.values

            # Update the current market price from the 'Price' column
            self.current_market_price = row['Price']

            # Calculate the open profit and loss (PnL) based on the current position
            if self.position == 'long':
                self.open_pnl = (self.current_market_price - self.entry_price) * self.tick_value
            elif self.position == 'short':
                self.open_pnl = (self.entry_price - self.current_market_price) * self.tick_value
            else:
                self.open_pnl = 0

            # Drawdown is the open loss (zero when flat or in profit)
            self.drawdown = min(self.open_pnl, 0) if self.position is not None else 0

            # Raise the high water mark whenever the balance makes a new high
            self.high_water_mark = max(self.high_water_mark, self.balance)

            # Set the stop loss flag when the balance is stop_loss or more below the high water mark
            self.stop_loss_flag = 1 if self.balance <= (self.high_water_mark - self.stop_loss) else 0

            # Combine market data with additional account information for the observation
            position_flags = np.array([int(self.position == 'long'), int(self.position == 'short')])
            additional_info = np.array([self.balance, *position_flags, self.stop_loss_flag, self.open_pnl, self.drawdown])
            observation = np.concatenate((market_data, additional_info))
            # Cast to float32 to match the dtype declared in observation_space
            return observation.astype(np.float32)
        else:
            # All rows of the current timestamp are consumed: advance to the next timestamp
            self.current_timestamp_index += 1
            if self.current_timestamp_index >= len(self.unique_timestamps):
                # Reset the environment if the end of the dataset is reached
                return self.reset()
            else:
                # Move to the next timestamp and reset the current row
                self.current_timestamp = self.unique_timestamps[self.current_timestamp_index]
                self.current_row = 0
                return self._next_observation()


    ##############################################
    ############# STEP METHOD ####################

    def step(self, action):
        reward = 0
        done = False
        info = {}

        # Define transaction cost
        transaction_cost = 10  # Example fixed cost per trade (currently not applied)

        # Extract the current time
        current_time = self.current_timestamp.time()

        # End-of-day exit (after trading hours): close any open position at the current price
        if current_time > self.end_time and self.position is not None:
            if self.position == 'long':
                pnl = (self.current_market_price - self.entry_price) * self.tick_value
            else:
                pnl = (self.entry_price - self.current_market_price) * self.tick_value
            self.balance += pnl
            reward += pnl * 10
            self.position = None
            self.entry_price = None
            self.entry_balance = None
            self.drawdown = 0
            #print("End of Day Exiting all positions")

        # Check if the current time is within the trading hours
        if self.start_time <= current_time <= self.end_time:

            # No open position: the agent may enter long or short
            if self.position is None:
                # Buy Entry
                if action == 0:
                    self.position = 'long'
                    self.entry_price = self.current_market_price
                    #reward -= transaction_cost  # Applying transaction cost

                # Short Entry
                elif action == 2:
                    self.position = 'short'
                    self.entry_price = self.current_market_price
                    #reward -= transaction_cost  # Applying transaction cost

                # Do Nothing
                elif action == 4:
                    pass

            elif self.position == 'long':
                # Buy Exit
                if action == 1:
                    pnl = (self.current_market_price - self.entry_price) * self.tick_value
                    self.balance += pnl
                    reward += pnl * 10
                    self.position = None
                    self.entry_price = None
                    #reward -= transaction_cost  # Applying transaction cost

            elif self.position == 'short':
                # Short Exit
                if action == 3:
                    pnl = (self.entry_price - self.current_market_price) * self.tick_value
                    self.balance += pnl
                    reward += pnl * 10
                    self.position = None
                    self.entry_price = None
                    #reward -= transaction_cost  # Applying transaction cost

        # Risk management penalty (example, can be customized):
        # penalize holding a position whose open loss is worse than -500
        risk_penalty = 0

        if self.position == 'long':
            position_pnl = (self.current_market_price - self.entry_price) * self.tick_value
        elif self.position == 'short':
            position_pnl = (self.entry_price - self.current_market_price) * self.tick_value
        else:
            position_pnl = 0

        if position_pnl < -500:
            risk_penalty = 10

        reward -= risk_penalty


        # Sequential event processing
        # Increment the row index within the current timestamp
        self.current_row += 1

        # Check if all rows for the current timestamp are processed
        current_timestamp_rows = len(self.features.loc[[self.current_timestamp]])
        if self.current_row >= current_timestamp_rows or self.current_timestamp_index >= len(self.unique_timestamps) - 1:
            # Move to the next timestamp
            self.current_timestamp_index += 1
            if self.current_timestamp_index >= len(self.unique_timestamps):
                done = True
            else:
                self.current_timestamp = self.unique_timestamps[self.current_timestamp_index]
                self.current_row = 0
        else:
            # Update current market price if still within the current timestamp
            self.current_market_price = self.features.loc[[self.current_timestamp]].iloc[self.current_row]['Price']

        observation = self._next_observation()
        return observation, reward, done, info


    ##############################################
    ############# RENDER METHOD ##################
    def render(self, mode='human', action=None, reward=None):
        if mode == 'human':
            # Human-readable printout of the current state in a single line
            state_info = (
                f"Timestamp: {self.current_timestamp}, "
                f"Current Row: {self.current_row}, "
                f"Current Position: {self.position}, "
                f"Action: {action}, "
                f"Balance: {self.balance}, "
                f"Entry Price: {self.entry_price}, "
                f"Current Market Price: {self.current_market_price}, "
                f"Open PnL: {self.open_pnl}, "
                f"Reward: {reward}, "
                f"Drawdown: {self.drawdown}"
            )
            print(state_info)
        else:
            raise NotImplementedError("Only 'human' mode is supported for rendering.")

--------------------------------------------------------------------------------
/model_training/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/model_training/preprocessing/__init__.py
--------------------------------------------------------------------------------
/model_training/preprocessing/dataframe_processor.py:
--------------------------------------------------------------------------------
import pandas as pd
import os


class DataFrameProcessor:
    def __init__(self, file_path):
        self.file_path = file_path
        self.df = None


    def process_data(self):
        # Load the CSV, index it by the parsed 'Date' timestamps, and drop incomplete rows
        self.df = pd.read_csv(self.file_path)
        self.df['Date'] = pd.to_datetime(self.df['Date'], format='%Y-%m-%d %H:%M:%S')
        self.df = self.df.set_index('Date')
        self.df = self.df.dropna()
        #print(f"Heads of DataFrame: {self.df.head()}")
        #print(f"Tails of DataFrame: {self.df.tail()}")
        print(f"After processing data, DataFrame shape: {self.df.shape}")
        return self.df


# Example usage
#processor = DataFrameProcessor(file_path=os.path.join("model_training", "data", "es_tuples_dataset.csv"))
#df = processor.process_data()
--------------------------------------------------------------------------------
/model_training/preprocessing/feature_engineering.py:
--------------------------------------------------------------------------------
import pandas as pd


class FeatureEngineering:
    def __init__(self, df):
        # Use the provided DataFrame, or an empty one if None is passed
        self.df = df if df is not None else pd.DataFrame()


    def perform_feature_engineering(self):
        """
        Adds volume-profile levels for every timestamp:

        - POC (point of control): the price level with the highest 'TotalVolume'.
        - HVN (high volume node): the level with the second-highest 'TotalVolume'
          (the POC itself when a timestamp has only one price level).
        - LVN (low volume node): the level with the lowest 'TotalVolume'.

        The levels are computed per timestamp with pandas groupby/apply, merged back
        onto every row of the original DataFrame, forward-filled, and any rows that
        still contain missing values are dropped.
        """
        # Ensure DataFrame has necessary columns
        required_columns = {'TotalVolume', 'Price'}
        if not required_columns.issubset(self.df.columns):
            raise ValueError(f"DataFrame must contain columns: {required_columns}")

        # Group data by timestamp (the index). This also works when every timestamp
        # is unique: each group then holds a single row.
        grouped = self.df.groupby(level=0)

        def calculate_levels(group):
            sorted_volumes = group.sort_values('TotalVolume')
            poc = sorted_volumes.iloc[-1]
            hvn = sorted_volumes.iloc[-2] if len(sorted_volumes) > 1 else poc
            lvn = sorted_volumes.iloc[0]

            return pd.Series({
                'POC_Price': poc['Price'],
                'POC_Volume': poc['TotalVolume'],
                'HVN_Price': hvn['Price'],
                'HVN_Volume': hvn['TotalVolume'],
                'LVN_Price': lvn['Price'],
                'LVN_Volume': lvn['TotalVolume']
            })

        combined_df = grouped.apply(calculate_levels)

        # Merge with Original DataFrame
        self.df = self.df.merge(combined_df, left_index=True, right_index=True, how='left')


        # Handle Missing Values
        self.df.ffill(inplace=True)

        #print(self.df.tail(50))
        print(f"After processing footprint dataframe, DataFrame shape: {self.df.shape}")
        return self.df.dropna()
--------------------------------------------------------------------------------
/model_training_app.py:
--------------------------------------------------------------------------------
# model_training_app.py


from model_training.agents.ppo_train_test_split import run_training
#from model_training.agents.test_code import run_training


def main():
    # Run the PPO training pipeline from model_training/agents/ppo_train_test_split.py
    run_training()
    #run_testing()


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
#flask
stable-baselines3
shimmy>=0.2.1
scikit-learn
python-dateutil
python-dotenv
gym
#requests
#gunicorn
trade29_scpy-0.0.14.tar.gz
tensorboard
--------------------------------------------------------------------------------
/sierracharts_data_downloader/sc_data_downloader.py:
--------------------------------------------------------------------------------

# Import libraries
import pandas as pd
from trade29.sc.bridge import SCBridge
import os
from dotenv import find_dotenv, load_dotenv


# Load environment variables
load_dotenv(find_dotenv())
sc_api_key = os.environ.get("SC_API_KEY")


# Process the DataFrame: sort by timestamp and price, then index by timestamp
def process_dataframe(df):
    df = df.sort_values(by=['Date', 'Price'])
    df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d %H:%M:%S')
    df.set_index('Date', inplace=True)
    return df


# Create an instance of the SCBridge class
sc = SCBridge()


# Request the data
sc.graph_data_request(
    sc_api_key,
    historical_init_bars=5,  # Max historical bars is 600,000
    include_vbp=True,
    on_bar_close=True,
    base_data='1;2;3;4;5',
    sg_data="ID2.[SG1;SG10];ID4.[SG1];ID5.[SG2-SG3];ID6.[SG1]"
)


# Wait for the response
msg = sc.get_response_queue().get()


# Rename columns for the entire msg.df
msg.df.columns = [
    'Date',
    'BarNumber',
    'Open',
    'High',
    'Low',
    'Close',
    'Volume',
    'Delta',
    'CVD',
    'TodayOpen',
    'PrevHigh',
    'PrevLow',
    'VWAP',
    'Tuples'
]


# Expand the volume-by-price tuples into one row per price level per bar
exploded_df = msg.df.explode('Tuples')
tuple_cols = ['Price', 'Bid', 'Ask', 'TotalVolume', 'NumberOfTrades']
exploded_df[tuple_cols] = pd.DataFrame(exploded_df['Tuples'].tolist(), index=exploded_df.index)
exploded_df.drop(columns=['BarNumber', 'Tuples'], inplace=True)
vbp_df = process_dataframe(exploded_df)
print(vbp_df.shape)
#print(vbp_df.tail(10))


# Save the DataFrame to a CSV file (creating the data directory if needed)
file_path = os.path.join('model_training', 'data', 'ES_tuples.csv')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
vbp_df.to_csv(file_path, index=True)
print(f"DataFrame saved to {file_path}")
--------------------------------------------------------------------------------
/trade29_scpy-0.0.14.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Invest-In-a-Tech/Deep-Reinforcement-Learning-DayTradingAgent/71c28cadc2e3fdaaab6a37a00f9cdc9b8d5abc0d/trade29_scpy-0.0.14.tar.gz
--------------------------------------------------------------------------------
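`FeatureEngineering.perform_feature_engineering` derives three volume-profile levels per timestamp from the exploded volume-by-price rows: the POC (highest `TotalVolume`), the HVN (second-highest, or the POC when a bar has a single price level) and the LVN (lowest). The plain-Python sketch below applies the same rule to one bar's `(price, total_volume)` pairs so the definition can be checked by hand. The function name `volume_profile_levels` and the sample prices are hypothetical. Ties are resolved here by Python's stable `sorted`, which may order equal volumes differently from the pandas `sort_values` call used in the repository.

```python
from typing import Dict, List, Tuple


def volume_profile_levels(levels: List[Tuple[float, float]]) -> Dict[str, float]:
    """Compute POC, HVN and LVN for one bar from (price, total_volume) pairs.

    Hypothetical helper mirroring calculate_levels in FeatureEngineering.
    """
    if not levels:
        raise ValueError("a bar needs at least one price level")
    # Sort ascending by volume, as calculate_levels does with sort_values('TotalVolume')
    ordered = sorted(levels, key=lambda pv: pv[1])
    poc = ordered[-1]                               # highest volume
    hvn = ordered[-2] if len(ordered) > 1 else poc  # second-highest, or the POC
    lvn = ordered[0]                                # lowest volume
    return {
        'POC_Price': poc[0], 'POC_Volume': poc[1],
        'HVN_Price': hvn[0], 'HVN_Volume': hvn[1],
        'LVN_Price': lvn[0], 'LVN_Volume': lvn[1],
    }


# Illustrative bar with three price levels (made-up numbers)
bar = [(4700.00, 120), (4700.25, 300), (4700.50, 45)]
print(volume_profile_levels(bar))
```

For this sample bar the POC is 4700.25 (volume 300), the HVN is 4700.00 (volume 120) and the LVN is 4700.50 (volume 45). In the repository the same values are then merged back onto every row that shares the timestamp.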