├── .gitignore ├── .vscode └── launch.json ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── Tutorials ├── 01_Creating_trading_environment.md ├── 02_Trading_simulation_backbone.md ├── 03_Trading_with_RL.md ├── 04_Indicators_and_Metrics.md └── Documents │ ├── 01_FinRock.jpg │ ├── 02_FinRock.jpg │ ├── 02_FinRock_render.png │ ├── 03_FinRock.jpg │ ├── 03_FinRock_render.png │ ├── 04_FinRock.jpg │ └── 04_FinRock_render.png ├── bin ├── create_sinusoid_data.py └── plot_data.py ├── experiments ├── playing_random_sinusoid.py ├── testing_ppo_sinusoid_continuous.py ├── testing_ppo_sinusoid_discrete.py ├── training_ppo_sinusoid_continuous.py └── training_ppo_sinusoid_discrete.py ├── finrock ├── __init__.py ├── data_feeder.py ├── indicators.py ├── metrics.py ├── render.py ├── reward.py ├── scalers.py ├── state.py └── trading_env.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.egg-info 3 | api.json 4 | venv* 5 | Datasets 6 | runs 7 | Models -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": false 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## [0.5.0] - 2024-01-30 2 | ### Added: 3 | - Added `MACD` indicator to `indicators` file. 
4 | - Added `reward.AccountValueChangeReward` object to calculate reward based on the change in the account value. 5 | - Added `scalers.ZScoreScaler` that doesn't require min and max to transform data, but uses mean and std instead. 6 | - Added `ActionSpace` object to handle the action space of the agent. 7 | - Added support for continuous actions. (float values between 0 and 1) 8 | 9 | ### Changed: 10 | - Updated all indicators to have a `config` parameter, which lets us serialize the indicators. (save/load configurations to/from file) 11 | - Changed `reward.simpleReward` to `reward.SimpleReward` Object. 12 | - Updated `state.State` to have `open`, `high`, `low`, `close` and `volume` attributes. 13 | - Updated `data_feeder.PdDataFeeder` to be serializable by including `save_config` and `load_config` methods. 14 | - Included trading fees into `trading_env.TradingEnv` object. 15 | - Updated `trading_env.TradingEnv` to have `reset` method, which resets the environment to the initial state. 16 | - Included `save_config` and `load_config` methods into `trading_env.TradingEnv` object, so we can save/load the environment configuration. 17 | 18 | ## [0.4.0] - 2024-01-02 19 | ### Added: 20 | - Created `indicators` file, where I added `BolingerBands`, `RSI`, `PSAR`, `SMA` indicators 21 | - Added `SharpeRatio` and `MaxDrawdown` metrics to `metrics` 22 | - Included indicators handling into `data_feeder.PdDataFeeder` object 23 | - Included indicators handling into `state.State` object 24 | 25 | ### Changed: 26 | - Changed `rockrl` package dependency from `0.0.4` to `0.4.1` 27 | - Refactored `render.PygameRender` object to handle indicators rendering (getting very messy) 28 | - Updated `scalers.MinMaxScaler` to handle indicators scaling 29 | - Updated `trading_env.TradingEnv` to raise an error with `np.nan` data and skip `None` states 30 | 31 | 32 | ## [0.3.0] - 2023-12-05 33 | ### Added: 34 | - Added `DifferentActions` and `AccountValue` as metrics.
Metrics are the main way to evaluate the performance of the agent. 35 | - Now `metrics.Metrics` object can be used to calculate the metrics within trading environment. 36 | - Included `rockrl==0.0.4` as a dependency, which is a reinforcement learning package that I created. 37 | - Added `experiments/training_ppo_sinusoid.py` to train a simple Dense agent using PPO algorithm on the sinusoid data with discrete actions. 38 | - Added `experiments/testing_ppo_sinusoid.py` to test the trained agent on the sinusoid data with discrete actions. 39 | 40 | ### Changed: 41 | - Renamed and moved `playing.py` to `experiments/playing_random_sinusoid.py` 42 | - Upgraded `finrock.render.PygameRender`, now we can stop/resume rendering with spacebar and render account value along with the actions 43 | 44 | 45 | ## [0.2.0] - 2023-11-29 46 | ### Added: 47 | - Created `reward.simpleReward` function to calculate reward based on the action and the difference between the current price and the previous price 48 | - Created `scalers.MinMaxScaler` object to transform the price data to a range between 0 and 1 and prepare it for the neural networks input 49 | - Created `state.Observations` object to hold the observations of the agent with set window size 50 | - Updated `render.PygameRender` object to render the agent's actions 51 | - Updated `state.State` to hold current state `assets`, `balance` and `allocation_percentage` on specific State 52 | 53 | 54 | ## [0.1.0] - 2023-10-17 55 | ### Initial Release: 56 | - Created the project 57 | - Created code to create random sinusoidal price data 58 | - Created `state.State` object, which holds the state of the market 59 | - Created `render.PygameRender` object, which renders the state of the market using `pygame` library 60 | - Created `trading_env.TradingEnv` object, which is the environment for the agent to interact with 61 | - Created `data_feeder.PdDataFeeder` object, which feeds the environment with data from a pandas dataframe 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-exclude *.pyc 2 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FinRock 2 | Reinforcement Learning package for Finance 3 | 4 | # Environment Structure: 5 |

6 | 7 |

8 | 9 | ### Install requirements: 10 | ``` 11 | pip install -r requirements.txt 12 | pip install pygame 13 | pip install . 14 | ``` 15 | 16 | ### Create sinusoid data: 17 | ``` 18 | python bin/create_sinusoid_data.py 19 | ``` 20 | 21 | ### Train RL (PPO) agent on discrete actions: 22 | ``` 23 | python experiments/training_ppo_sinusoid_discrete.py 24 | ``` 25 | 26 | ### Test trained agent (Change path to the saved model): 27 | ``` 28 | python experiments/testing_ppo_sinusoid_discrete.py 29 | ``` 30 | 31 | ### Environment Render: 32 |

33 | 34 |

35 | 36 | ## Links to YouTube videos: 37 | - [Introduction to FinRock package](https://youtu.be/xU_YJB7vilA) 38 | - [Complete Trading Simulation Backbone](https://youtu.be/1z5geob8Yho) 39 | - [Training RL agent on Sinusoid data](https://youtu.be/JkA4BuYvWyE) 40 | - [Included metrics and indicators into environment](https://youtu.be/bGpBEnKzIdo) 41 | 42 | # TODO: 43 | - [ ] Train model on `continuous` actions (control allocation percentage) 44 | - [ ] Add more indicators 45 | - [ ] Add more metrics 46 | - [ ] Add more reward functions 47 | - [ ] Add more scalers 48 | - [ ] Train RL agent on real data 49 | - [ ] Add more RL algorithms 50 | - [ ] Refactor rendering, maybe move to browser? -------------------------------------------------------------------------------- /Tutorials/01_Creating_trading_environment.md: -------------------------------------------------------------------------------- 1 | # Introduction to FinRock package 2 | 3 | ### Environment Structure: 4 |

5 | 6 |

7 | 8 | ### Link to YouTube video: 9 | https://youtu.be/xU_YJB7vilA 10 | 11 | ### Link to tutorial code: 12 | https://github.com/pythonlessons/FinRock/tree/0.1.0 13 | 14 | ### Download tutorial code: 15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.1.0.zip 16 | 17 | 18 | ### Install requirements: 19 | ``` 20 | pip install -r requirements.txt 21 | pip install pygame 22 | ``` 23 | 24 | ### Create sinusoid data: 25 | ``` 26 | python bin/create_sinusoid_data.py 27 | ``` 28 | 29 | ### Run environment: 30 | ``` 31 | python playing.py 32 | ``` -------------------------------------------------------------------------------- /Tutorials/02_Trading_simulation_backbone.md: -------------------------------------------------------------------------------- 1 | # Complete Trading Simulation Backbone 2 | 3 | ### Environment Structure: 4 |

5 | 6 |

7 | 8 | ### Link to YouTube video: 9 | https://youtu.be/1z5geob8Yho 10 | 11 | ### Link to tutorial code: 12 | https://github.com/pythonlessons/FinRock/tree/0.2.0 13 | 14 | ### Download tutorial code: 15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.2.0.zip 16 | 17 | 18 | ### Install requirements: 19 | ``` 20 | pip install -r requirements.txt 21 | pip install pygame 22 | ``` 23 | 24 | ### Create sinusoid data: 25 | ``` 26 | python bin/create_sinusoid_data.py 27 | ``` 28 | 29 | ### Run environment: 30 | ``` 31 | python playing.py 32 | ``` 33 | 34 | ### Environment Render: 35 |

36 | 37 |

-------------------------------------------------------------------------------- /Tutorials/03_Trading_with_RL.md: -------------------------------------------------------------------------------- 1 | # Training RL agent on Sinusoid data 2 | 3 | ### Environment Structure: 4 |

5 | 6 |

7 | 8 | ### Link to YouTube video: 9 | https://youtu.be/JkA4BuYvWyE 10 | 11 | ### Link to tutorial code: 12 | https://github.com/pythonlessons/FinRock/tree/0.3.0 13 | 14 | ### Download tutorial code: 15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.3.0.zip 16 | 17 | 18 | ### Install requirements: 19 | ``` 20 | pip install -r requirements.txt 21 | pip install pygame 22 | pip install . 23 | ``` 24 | 25 | ### Create sinusoid data: 26 | ``` 27 | python bin/create_sinusoid_data.py 28 | ``` 29 | 30 | ### Train RL (PPO) agent on discrete actions: 31 | ``` 32 | python experiments/training_ppo_sinusoid.py 33 | ``` 34 | 35 | ### Test trained agent (Change path to the saved model): 36 | ``` 37 | python experiments/testing_ppo_sinusoid.py 38 | ``` 39 | 40 | ### Environment Render: 41 |

42 | 43 |

-------------------------------------------------------------------------------- /Tutorials/04_Indicators_and_Metrics.md: -------------------------------------------------------------------------------- 1 | # Included metrics and indicators into environment 2 | 3 | ### Environment Structure: 4 |

5 | 6 |

7 | 8 | ### Link to YouTube video: 9 | https://youtu.be/bGpBEnKzIdo 10 | 11 | ### Link to tutorial code: 12 | https://github.com/pythonlessons/FinRock/tree/0.4.0 13 | 14 | ### Download tutorial code: 15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.4.0.zip 16 | 17 | 18 | ### Install requirements: 19 | ``` 20 | pip install -r requirements.txt 21 | pip install pygame 22 | pip install . 23 | ``` 24 | 25 | ### Create sinusoid data: 26 | ``` 27 | python bin/create_sinusoid_data.py 28 | ``` 29 | 30 | ### Train RL (PPO) agent on discrete actions: 31 | ``` 32 | python experiments/training_ppo_sinusoid.py 33 | ``` 34 | 35 | ### Test trained agent (Change path to the saved model): 36 | ``` 37 | python experiments/testing_ppo_sinusoid.py 38 | ``` 39 | 40 | ### Environment Render: 41 |

42 | 43 |

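As a rough illustration of what two of the metrics covered in this tutorial measure, here is a minimal pure-Python sketch of a Sharpe ratio and a maximum drawdown computed from a list of account values. It is not FinRock's own `metrics.SharpeRatio` or `metrics.MaxDrawdown` code; the function names, the `periods_per_year` parameter and the zero risk-free rate are assumptions made for the example.

```python
import math
from statistics import mean, pstdev


def sharpe_ratio(account_values, periods_per_year=252):
    """Annualised Sharpe ratio of an account-value series.

    Assumes a risk-free rate of zero; periods_per_year is a hypothetical
    annualisation factor, not a FinRock setting.
    """
    # Simple per-step returns between consecutive account values
    returns = [(cur - prev) / prev
               for prev, cur in zip(account_values, account_values[1:])]
    if len(returns) < 2:
        return 0.0
    std = pstdev(returns)
    if std == 0:
        # No variation in returns: the ratio is undefined, report 0.0
        return 0.0
    return math.sqrt(periods_per_year) * mean(returns) / std


def max_drawdown(account_values):
    """Largest peak-to-trough drop, expressed as a fraction of the peak."""
    peak = account_values[0]
    worst = 0.0
    for value in account_values:
        peak = max(peak, value)                    # highest value seen so far
        worst = max(worst, (peak - value) / peak)  # drop relative to that peak
    return worst


if __name__ == "__main__":
    history = [1000.0, 1100.0, 990.0, 1050.0, 1200.0]
    print("max drawdown:", max_drawdown(history))
    print("sharpe ratio:", sharpe_ratio(history))
```

In this sketch a higher Sharpe ratio means more return per unit of volatility, while a lower maximum drawdown means the account never fell far below its previous peak.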
-------------------------------------------------------------------------------- /Tutorials/Documents/01_FinRock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/01_FinRock.jpg -------------------------------------------------------------------------------- /Tutorials/Documents/02_FinRock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/02_FinRock.jpg -------------------------------------------------------------------------------- /Tutorials/Documents/02_FinRock_render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/02_FinRock_render.png -------------------------------------------------------------------------------- /Tutorials/Documents/03_FinRock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/03_FinRock.jpg -------------------------------------------------------------------------------- /Tutorials/Documents/03_FinRock_render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/03_FinRock_render.png -------------------------------------------------------------------------------- /Tutorials/Documents/04_FinRock.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/04_FinRock.jpg -------------------------------------------------------------------------------- /Tutorials/Documents/04_FinRock_render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/04_FinRock_render.png -------------------------------------------------------------------------------- /bin/create_sinusoid_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | from datetime import datetime, timedelta 6 | 7 | def create_sinusoidal_df( 8 | amplitude = 2000.0, # Amplitude of the price variations 9 | frequency = 0.01, # Frequency of the price variations 10 | phase = 0.0, # Phase shift of the price variations 11 | num_samples = 10000, # Number of data samples 12 | data_shift = 20000, # shift the data up 13 | trendline_down = 5000, # total downward drift applied across the series 14 | plot = False, 15 | ): 16 | """Create a dataframe with sinusoidal data""" 17 | 18 | # Generate the time axis 19 | t = np.linspace(0, 2 * np.pi * frequency * num_samples, num_samples) 20 | 21 | # Get the current datetime 22 | now = datetime.now() 23 | 24 | # Set hours, minutes, and seconds to zero 25 | now = now.replace(hour=0, minute=0, second=0, microsecond=0) 26 | 27 | # Generate timestamps at 4-hour intervals, going back from today (daily alternative is commented out) 28 | # timestamps = [now - timedelta(days=i) for i in range(num_samples)] 29 | timestamps = [now - timedelta(hours=i*4) for i in range(num_samples)] 30 | 31 | # Convert datetime objects to strings 32 | timestamps = [ts.strftime('%Y-%m-%d %H:%M:%S') for ts in timestamps] 33 | 34 | # Invert the order of the timestamps 35 | timestamps = timestamps[::-1] 36 | 37 | # Generate the sinusoidal data for prices
38 | sin_data = amplitude * np.sin(t + phase) 39 | sin_data += data_shift # shift the data up 40 | 41 | # gradually shift sin_data down, to create a downward trendline 42 | sin_data -= np.linspace(0, trendline_down, num_samples) 43 | 44 | # Add random noise 45 | noise = np.random.uniform(0.95, 1.05, len(t)) # generate random noise 46 | noisy_sin_data = sin_data * noise # add noise to the original data 47 | 48 | price_range = np.max(noisy_sin_data) - np.min(noisy_sin_data) 49 | 50 | # Generate random low and close prices (low prices are recomputed below from open and close) 51 | low_prices = noisy_sin_data - np.random.uniform(0, 0.1 * price_range, len(noisy_sin_data)) 52 | close_prices = noisy_sin_data + np.random.uniform(-0.05 * price_range, 0.05 * price_range, len(noisy_sin_data)) 53 | 54 | # open prices are set to the close prices of the previous candle 55 | open_prices = np.zeros(len(close_prices)) 56 | open_prices[0] = close_prices[0] 57 | open_prices[1:] = close_prices[:-1] 58 | 59 | # high prices are always above open and close prices 60 | high_prices = np.maximum(open_prices, close_prices) + np.random.uniform(0, 0.1 * price_range, len(close_prices)) 61 | 62 | # low prices are always below open and close prices 63 | low_prices = np.minimum(open_prices, close_prices) - np.random.uniform(0, 0.1 * price_range, len(close_prices)) 64 | 65 | if plot: 66 | # Plot the price data 67 | plt.figure(figsize=(10, 6)) 68 | plt.plot(t, noisy_sin_data, label='Noisy Sinusoidal Data') 69 | plt.plot(t, open_prices, label='Open') 70 | plt.plot(t, low_prices, label='Low') 71 | plt.plot(t, close_prices, label='Close') 72 | plt.plot(t, high_prices, label='High') 73 | plt.xlabel('Time') 74 | plt.ylabel('Price') 75 | plt.title('Fake Price Data') 76 | plt.legend() 77 | plt.grid(True) 78 | plt.show() 79 | 80 | # collect the data into a DataFrame with columns ['timestamp', 'open', 'high', 'low', 'close'] 81 | df = pd.DataFrame({'timestamp': timestamps, 'open': open_prices, 'high': high_prices, 'low': low_prices, 'close': close_prices}) 82 | 83 | return df 84 |
85 | if __name__ == '__main__': 86 | # Create a dataframe with sinusoidal data 87 | df = create_sinusoidal_df() 88 | 89 | # Create a directory to store the datasets 90 | os.makedirs('Datasets', exist_ok=True) 91 | 92 | # Save the dataframe to a CSV file 93 | df.to_csv(f'Datasets/random_sinusoid.csv') -------------------------------------------------------------------------------- /bin/plot_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | 4 | df = pd.read_csv('Datasets/random_sinusoid.csv', index_col='timestamp', parse_dates=True) 5 | df = df[['open', 'high', 'low', 'close']] 6 | # limit to last 1000 data points 7 | df = df[-1000:] 8 | 9 | # plot the data 10 | plt.figure(figsize=(10, 6)) 11 | plt.plot(df['close']) 12 | plt.xlabel('Time') 13 | plt.ylabel('Price') 14 | plt.title('random_sinusoid.csv') 15 | plt.grid(True) 16 | plt.show() -------------------------------------------------------------------------------- /experiments/playing_random_sinusoid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from finrock.data_feeder import PdDataFeeder 5 | from finrock.trading_env import TradingEnv 6 | from finrock.render import PygameRender 7 | from finrock.scalers import ZScoreScaler 8 | from finrock.reward import AccountValueChangeReward 9 | from finrock.indicators import BolingerBands, SMA, RSI, PSAR, MACD 10 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio 11 | 12 | df = pd.read_csv('Datasets/random_sinusoid.csv') 13 | 14 | pd_data_feeder = PdDataFeeder( 15 | df = df, 16 | indicators = [ 17 | BolingerBands(data=df, period=20, std=2), 18 | RSI(data=df, period=14), 19 | PSAR(data=df), 20 | MACD(data=df), 21 | SMA(data=df, period=7), 22 | ] 23 | ) 24 | 25 | env = TradingEnv( 26 | data_feeder = pd_data_feeder, 27 | output_transformer = 
ZScoreScaler(), 28 | initial_balance = 1000.0, 29 | max_episode_steps = 1000, 30 | window_size = 50, 31 | reward_function = AccountValueChangeReward(), 32 | metrics = [ 33 | DifferentActions(), 34 | AccountValue(), 35 | MaxDrawdown(), 36 | SharpeRatio(), 37 | ] 38 | ) 39 | action_space = env.action_space 40 | input_shape = env.observation_space.shape 41 | 42 | env.save_config() 43 | 44 | pygameRender = PygameRender(frame_rate=60) 45 | 46 | state, info = env.reset() 47 | pygameRender.render(info) 48 | rewards = 0.0 49 | while True: 50 | # simulate model prediction, now use random action 51 | action = np.random.randint(0, action_space) 52 | 53 | state, reward, terminated, truncated, info = env.step(action) 54 | rewards += reward 55 | pygameRender.render(info) 56 | 57 | if terminated or truncated: 58 | print(info['states'][-1].account_value, rewards) 59 | rewards = 0.0 60 | state, info = env.reset() 61 | pygameRender.reset() -------------------------------------------------------------------------------- /experiments/testing_ppo_sinusoid_continuous.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import tensorflow as tf 4 | tf.get_logger().setLevel('ERROR') 5 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 6 | tf.config.experimental.set_memory_growth(gpu, True) 7 | 8 | from finrock.data_feeder import PdDataFeeder 9 | from finrock.trading_env import TradingEnv 10 | from finrock.render import PygameRender 11 | 12 | 13 | df = pd.read_csv('Datasets/random_sinusoid.csv') 14 | df = df[-1000:] 15 | 16 | model_path = "runs/1704798174" 17 | 18 | pd_data_feeder = PdDataFeeder.load_config(df, model_path) 19 | env = TradingEnv.load_config(pd_data_feeder, model_path) 20 | 21 | action_space = env.action_space 22 | input_shape = env.observation_space.shape 23 | pygameRender = PygameRender(frame_rate=120) 24 | 25 | agent = 
tf.keras.models.load_model(f'{model_path}/ppo_sinusoid_actor.h5') 26 | 27 | state, info = env.reset() 28 | pygameRender.render(info) 29 | rewards = 0.0 30 | while True: 31 | # predict the action with the trained actor; the last output is sigma, so drop it 32 | action = agent.predict(np.expand_dims(state, axis=0), verbose=False)[0][:-1] 33 | 34 | state, reward, terminated, truncated, info = env.step(action) 35 | rewards += reward 36 | pygameRender.render(info) 37 | 38 | if terminated or truncated: 39 | print(rewards) 40 | for metric, value in info['metrics'].items(): 41 | print(metric, value) 42 | state, info = env.reset() 43 | rewards = 0.0 44 | pygameRender.reset() 45 | pygameRender.render(info) -------------------------------------------------------------------------------- /experiments/testing_ppo_sinusoid_discrete.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import tensorflow as tf 4 | tf.get_logger().setLevel('ERROR') 5 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 6 | tf.config.experimental.set_memory_growth(gpu, True) 7 | 8 | from finrock.data_feeder import PdDataFeeder 9 | from finrock.trading_env import TradingEnv 10 | from finrock.render import PygameRender 11 | 12 | 13 | df = pd.read_csv('Datasets/random_sinusoid.csv') 14 | df = df[-1000:] 15 | 16 | model_path = "runs/1704746665" 17 | 18 | pd_data_feeder = PdDataFeeder.load_config(df, model_path) 19 | env = TradingEnv.load_config(pd_data_feeder, model_path) 20 | 21 | action_space = env.action_space 22 | input_shape = env.observation_space.shape 23 | pygameRender = PygameRender(frame_rate=120) 24 | 25 | agent = tf.keras.models.load_model(f'{model_path}/ppo_sinusoid_actor.h5') 26 | 27 | state, info = env.reset() 28 | pygameRender.render(info) 29 | rewards = 0.0 30 | while True: 31 | # predict action probabilities with the trained actor and pick the most likely action 32 | prob = agent.predict(np.expand_dims(state, axis=0), verbose=False)[0] 33 | action = 
np.argmax(prob) 34 | 35 | state, reward, terminated, truncated, info = env.step(action) 36 | rewards += reward 37 | pygameRender.render(info) 38 | 39 | if terminated or truncated: 40 | print(rewards) 41 | for metric, value in info['metrics'].items(): 42 | print(metric, value) 43 | state, info = env.reset() 44 | rewards = 0.0 45 | pygameRender.reset() 46 | pygameRender.render(info) -------------------------------------------------------------------------------- /experiments/training_ppo_sinusoid_continuous.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import tensorflow as tf 4 | tf.get_logger().setLevel('ERROR') 5 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 6 | tf.config.experimental.set_memory_growth(gpu, True) 7 | 8 | from keras import layers, models 9 | 10 | from finrock.data_feeder import PdDataFeeder 11 | from finrock.trading_env import TradingEnv, ActionSpace 12 | from finrock.scalers import ZScoreScaler 13 | from finrock.reward import AccountValueChangeReward 14 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio 15 | from finrock.indicators import BolingerBands, RSI, PSAR, SMA, MACD 16 | 17 | from rockrl.utils.misc import MeanAverage 18 | from rockrl.utils.memory import MemoryManager 19 | from rockrl.tensorflow import PPOAgent 20 | from rockrl.utils.vectorizedEnv import VectorizedEnv 21 | 22 | df = pd.read_csv('Datasets/random_sinusoid.csv') 23 | df = df[:-1000] # leave 1000 for testing 24 | 25 | pd_data_feeder = PdDataFeeder( 26 | df, 27 | indicators = [ 28 | BolingerBands(data=df, period=20, std=2), 29 | RSI(data=df, period=14), 30 | PSAR(data=df), 31 | MACD(data=df), 32 | SMA(data=df, period=7), 33 | ] 34 | ) 35 | 36 | num_envs = 10 37 | env = VectorizedEnv( 38 | env_object = TradingEnv, 39 | num_envs = num_envs, 40 | data_feeder = pd_data_feeder, 41 | output_transformer = ZScoreScaler(), 42 | initial_balance = 
1000.0, 43 | max_episode_steps = 1000, 44 | window_size = 50, 45 | reward_function = AccountValueChangeReward(), 46 | action_space = ActionSpace.CONTINUOUS, 47 | metrics = [ 48 | DifferentActions(), 49 | AccountValue(), 50 | MaxDrawdown(), 51 | SharpeRatio(), 52 | ] 53 | ) 54 | 55 | action_space = env.action_space 56 | input_shape = env.observation_space.shape 57 | 58 | 59 | def actor_model(input_shape, action_space): 60 | input = layers.Input(shape=input_shape, dtype=tf.float32) 61 | x = layers.Flatten()(input) 62 | x = layers.Dense(512, activation='elu')(x) 63 | x = layers.Dense(256, activation='elu')(x) 64 | x = layers.Dense(64, activation='elu')(x) 65 | x = layers.Dropout(0.2)(x) 66 | action = layers.Dense(action_space, activation="tanh")(x) 67 | sigma = layers.Dense(action_space)(x) 68 | sigma = layers.Dense(1, activation='sigmoid')(sigma) 69 | output = layers.concatenate([action, sigma]) # continuous action space 70 | return models.Model(inputs=input, outputs=output) 71 | 72 | def critic_model(input_shape): 73 | input = layers.Input(shape=input_shape, dtype=tf.float32) 74 | x = layers.Flatten()(input) 75 | x = layers.Dense(512, activation='elu')(x) 76 | x = layers.Dense(256, activation='elu')(x) 77 | x = layers.Dense(64, activation='elu')(x) 78 | x = layers.Dropout(0.2)(x) 79 | output = layers.Dense(1, activation=None)(x) 80 | return models.Model(inputs=input, outputs=output) 81 | 82 | 83 | agent = PPOAgent( 84 | actor = actor_model(input_shape, action_space), 85 | critic = critic_model(input_shape), 86 | optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005), 87 | batch_size=128, 88 | lamda=0.95, 89 | kl_coeff=0.5, 90 | c2=0.01, 91 | writer_comment='ppo_sinusoid', 92 | action_space="continuous", 93 | ) 94 | pd_data_feeder.save_config(agent.logdir) 95 | env.env.save_config(agent.logdir) 96 | 97 | memory = MemoryManager(num_envs=num_envs) 98 | meanAverage = MeanAverage(best_mean_score_episode=1000) 99 | states, infos = env.reset() 100 | rewards = 0.0 101 | 
while True: 102 | action, prob = agent.act(states) 103 | 104 | next_states, reward, terminated, truncated, infos = env.step(action) 105 | memory.append(states, action, reward, prob, terminated, truncated, next_states, infos) 106 | states = next_states 107 | 108 | for index in memory.done_indices(): 109 | env_memory = memory[index] 110 | history = agent.train(env_memory) 111 | mean_reward = meanAverage(np.sum(env_memory.rewards)) 112 | 113 | if meanAverage.is_best(agent.epoch): 114 | agent.save_models('ppo_sinusoid') 115 | 116 | if history['kl_div'] > 0.2 and agent.epoch > 1000: 117 | agent.reduce_learning_rate(0.995, verbose=False) 118 | 119 | info = env_memory.infos[-1] 120 | print(agent.epoch, np.sum(env_memory.rewards), mean_reward, info["metrics"]['account_value'], history['kl_div']) 121 | agent.log_to_writer(info['metrics']) 122 | states[index], infos[index] = env.reset(index=index) 123 | 124 | if agent.epoch >= 20000: 125 | break 126 | 127 | env.close() 128 | exit() -------------------------------------------------------------------------------- /experiments/training_ppo_sinusoid_discrete.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import tensorflow as tf 4 | tf.get_logger().setLevel('ERROR') 5 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 6 | tf.config.experimental.set_memory_growth(gpu, True) 7 | 8 | from keras import layers, models 9 | 10 | from finrock.data_feeder import PdDataFeeder 11 | from finrock.trading_env import TradingEnv 12 | from finrock.scalers import MinMaxScaler, ZScoreScaler 13 | from finrock.reward import SimpleReward, AccountValueChangeReward 14 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio 15 | from finrock.indicators import BolingerBands, RSI, PSAR, SMA, MACD 16 | 17 | from rockrl.utils.misc import MeanAverage 18 | from rockrl.utils.memory import MemoryManager 19 | from 
rockrl.tensorflow import PPOAgent 20 | from rockrl.utils.vectorizedEnv import VectorizedEnv 21 | 22 | df = pd.read_csv('Datasets/random_sinusoid.csv') 23 | df = df[:-1000] 24 | 25 | 26 | pd_data_feeder = PdDataFeeder( 27 | df, 28 | indicators = [ 29 | BolingerBands(data=df, period=20, std=2), 30 | RSI(data=df, period=14), 31 | PSAR(data=df), 32 | MACD(data=df), 33 | SMA(data=df, period=7), 34 | ] 35 | ) 36 | 37 | num_envs = 10 38 | env = VectorizedEnv( 39 | env_object = TradingEnv, 40 | num_envs = num_envs, 41 | data_feeder = pd_data_feeder, 42 | output_transformer = ZScoreScaler(), 43 | initial_balance = 1000.0, 44 | max_episode_steps = 1000, 45 | window_size = 50, 46 | reward_function = AccountValueChangeReward(), 47 | metrics = [ 48 | DifferentActions(), 49 | AccountValue(), 50 | MaxDrawdown(), 51 | SharpeRatio(), 52 | ] 53 | ) 54 | 55 | action_space = env.action_space 56 | input_shape = env.observation_space.shape 57 | 58 | def actor_model(input_shape, action_space): 59 | input = layers.Input(shape=input_shape, dtype=tf.float32) 60 | x = layers.Flatten()(input) 61 | x = layers.Dense(512, activation='elu')(x) 62 | x = layers.Dense(256, activation='elu')(x) 63 | x = layers.Dense(64, activation='elu')(x) 64 | x = layers.Dropout(0.2)(x) 65 | output = layers.Dense(action_space, activation='softmax')(x) # discrete action space 66 | return models.Model(inputs=input, outputs=output) 67 | 68 | def critic_model(input_shape): 69 | input = layers.Input(shape=input_shape, dtype=tf.float32) 70 | x = layers.Flatten()(input) 71 | x = layers.Dense(512, activation='elu')(x) 72 | x = layers.Dense(256, activation='elu')(x) 73 | x = layers.Dense(64, activation='elu')(x) 74 | x = layers.Dropout(0.2)(x) 75 | output = layers.Dense(1, activation=None)(x) 76 | return models.Model(inputs=input, outputs=output) 77 | 78 | agent = PPOAgent( 79 | actor = actor_model(input_shape, action_space), 80 | critic = critic_model(input_shape), 81 | 
optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), 82 | batch_size=128, 83 | lamda=0.95, 84 | kl_coeff=0.5, 85 | c2=0.01, 86 | writer_comment='ppo_sinusoid_discrete', 87 | ) 88 | 89 | pd_data_feeder.save_config(agent.logdir) 90 | env.env.save_config(agent.logdir) 91 | 92 | memory = MemoryManager(num_envs=num_envs) 93 | meanAverage = MeanAverage(best_mean_score_episode=1000) 94 | states, infos = env.reset() 95 | rewards = 0.0 96 | while True: 97 | action, prob = agent.act(states) 98 | 99 | next_states, reward, terminated, truncated, infos = env.step(action) 100 | memory.append(states, action, reward, prob, terminated, truncated, next_states, infos) 101 | states = next_states 102 | 103 | for index in memory.done_indices(): 104 | env_memory = memory[index] 105 | history = agent.train(env_memory) 106 | mean_reward = meanAverage(np.sum(env_memory.rewards)) 107 | 108 | if meanAverage.is_best(agent.epoch): 109 | agent.save_models('ppo_sinusoid') 110 | 111 | if history['kl_div'] > 0.05 and agent.epoch > 1000: 112 | agent.reduce_learning_rate(0.995, verbose=False) 113 | 114 | info = env_memory.infos[-1] 115 | print(agent.epoch, np.sum(env_memory.rewards), mean_reward, info["metrics"]['account_value'], history['kl_div']) 116 | agent.log_to_writer(info['metrics']) 117 | states[index], infos[index] = env.reset(index=index) 118 | 119 | if agent.epoch >= 10000: 120 | break 121 | 122 | env.close() 123 | exit() -------------------------------------------------------------------------------- /finrock/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.0" -------------------------------------------------------------------------------- /finrock/data_feeder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import importlib 4 | import pandas as pd 5 | from finrock.state import State 6 | from finrock.indicators import Indicator 7 | 8 | 9 | class 
PdDataFeeder: 10 | def __init__( 11 | self, 12 | df: pd.DataFrame, 13 | indicators: list = [], 14 | min: float = None, 15 | max: float = None, 16 | ) -> None: 17 | self._df = df 18 | self._min = min 19 | self._max = max 20 | self._indicators = indicators 21 | self._cache = {} 22 | 23 | assert isinstance(self._df, pd.DataFrame) == True, "df must be a pandas.DataFrame" 24 | assert 'timestamp' in self._df.columns, "df must have 'timestamp' column" 25 | assert 'open' in self._df.columns, "df must have 'open' column" 26 | assert 'high' in self._df.columns, "df must have 'high' column" 27 | assert 'low' in self._df.columns, "df must have 'low' column" 28 | assert 'close' in self._df.columns, "df must have 'close' column" 29 | 30 | assert isinstance(self._indicators, list) == True, "indicators must be an iterable" 31 | assert all(isinstance(indicator, Indicator) for indicator in self._indicators) == True, "indicators must be a list of Indicator objects" 32 | 33 | @property 34 | def __name__(self) -> str: 35 | return self.__class__.__name__ 36 | 37 | @property 38 | def name(self) -> str: 39 | return self.__name__ 40 | 41 | @property 42 | def min(self) -> float: 43 | return self._min or self._df['low'].min() 44 | 45 | @property 46 | def max(self) -> float: 47 | return self._max or self._df['high'].max() 48 | 49 | def __len__(self) -> int: 50 | return len(self._df) 51 | 52 | def __getitem__(self, idx: int, args=None) -> State: 53 | # Use cache to speed up training 54 | if idx in self._cache: 55 | return self._cache[idx] 56 | 57 | indicators = [] 58 | for indicator in self._indicators: 59 | results = indicator(idx) 60 | if results is None: 61 | self._cache[idx] = None 62 | return None 63 | 64 | indicators.append(results) 65 | 66 | data = self._df.iloc[idx] 67 | state = State( 68 | timestamp=data['timestamp'], 69 | open=data['open'], 70 | high=data['high'], 71 | low=data['low'], 72 | close=data['close'], 73 | volume=data.get('volume', 0.0), 74 | indicators=indicators 75 | ) 76 
| self._cache[idx] = state 77 | 78 | return state 79 | 80 | def __iter__(self) -> State: 81 | """ Create a generator that iterates over the Sequence.""" 82 | for index in range(len(self)): 83 | yield self[index] 84 | 85 | def save_config(self, path: str) -> None: 86 | config = { 87 | "indicators": [], 88 | "min": self.min, 89 | "max": self.max 90 | } 91 | for indicator in self._indicators: 92 | config["indicators"].append(indicator.config()) 93 | 94 | # save config into json file 95 | with open(os.path.join(path, "PdDataFeeder.json"), 'w') as outfile: 96 | json.dump(config, outfile, indent=4) 97 | 98 | @staticmethod 99 | def load_config(df, path: str) -> 'PdDataFeeder': 100 | # load config from json file 101 | config_path = os.path.join(path, "PdDataFeeder.json") 102 | if not os.path.exists(config_path): 103 | raise Exception(f"PdDataFeeder Config file not found in {path}") 104 | 105 | with open(config_path) as json_file: 106 | config = json.load(json_file) 107 | 108 | _indicators = [] 109 | for indicator in config["indicators"]: 110 | indicator_class = getattr(importlib.import_module(".indicators", package=__package__), indicator["name"]) 111 | ind = indicator_class(data=df, **indicator) 112 | _indicators.append(ind) 113 | 114 | pdDataFeeder = PdDataFeeder(df=df, indicators=_indicators, min=config["min"], max=config["max"]) 115 | 116 | return pdDataFeeder -------------------------------------------------------------------------------- /finrock/indicators.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from .render import RenderOptions, RenderType, WindowType 4 | 5 | """ Implemented indicators: 6 | - SMA 7 | - Bolinger Bands 8 | - RSI 9 | - PSAR 10 | - MACD (Moving Average Convergence Divergence) 11 | 12 | TODO: 13 | - Commodity Channel Index (CCI), and X is the 14 | - Average Directional Index (ADX) 15 | """ 16 | 17 | 18 | class Indicator: 19 | """ Base class for indicators 20 | """ 21 | def __init__( 22 | 
self, 23 | data: pd.DataFrame, 24 | target_column: str='close', 25 | render_options: dict={}, 26 | min: float=None, 27 | max: float=None, 28 | **kwargs 29 | ) -> None: 30 | self._data = data.copy() 31 | self._target_column = target_column 32 | self._custom_render_options = render_options 33 | self._render_options = render_options 34 | self._min = min # if min is not None else self._data[target_column].min() 35 | self._max = max # if max is not None else self._data[target_column].max() 36 | self.values = {} 37 | 38 | assert isinstance(self._data, pd.DataFrame) == True, "data must be a pandas.DataFrame" 39 | assert self._target_column in self._data.columns, f"data must have '{self._target_column}' column" 40 | 41 | self.compute() 42 | if not self._custom_render_options: 43 | self._render_options = self.default_render_options() 44 | 45 | @property 46 | def min(self): 47 | return self._min 48 | 49 | @min.setter 50 | def min(self, min: float): 51 | self._min = self._min or min 52 | if not self._custom_render_options: 53 | self._render_options = self.default_render_options() 54 | 55 | @property 56 | def max(self): 57 | return self._max 58 | 59 | @max.setter 60 | def max(self, max: float): 61 | self._max = self._max or max 62 | if not self._custom_render_options: 63 | self._render_options = self.default_render_options() 64 | 65 | @property 66 | def target_column(self): 67 | return self._target_column 68 | 69 | @property 70 | def __name__(self) -> str: 71 | return self.__class__.__name__ 72 | 73 | @property 74 | def name(self): 75 | return self.__name__ 76 | 77 | @property 78 | def names(self): 79 | return self._names 80 | 81 | def compute(self): 82 | raise NotImplementedError 83 | 84 | def default_render_options(self): 85 | return {} 86 | 87 | def render_options(self): 88 | return {name: option.copy() for name, option in self._render_options.items()} 89 | 90 | def __getitem__(self, index: int): 91 | row = self._data.iloc[index] 92 | for name in self.names: 93 | if 
pd.isna(row[name]): 94 | return None 95 | 96 | self.values[name] = row[name] 97 | if self._render_options.get(name): 98 | self._render_options[name].value = row[name] 99 | 100 | return self.serialise() 101 | 102 | def __call__(self, index: int): 103 | return self[index] 104 | 105 | def serialise(self): 106 | return { 107 | 'name': self.name, 108 | 'names': self.names, 109 | 'values': self.values.copy(), 110 | 'target_column': self.target_column, 111 | 'render_options': self.render_options(), 112 | 'min': self.min, 113 | 'max': self.max 114 | } 115 | 116 | def config(self): 117 | return { 118 | 'name': self.name, 119 | 'names': self.names, 120 | 'target_column': self.target_column, 121 | 'min': self.min, 122 | 'max': self.max 123 | } 124 | 125 | 126 | 127 | class SMA(Indicator): 128 | """ Trend indicator 129 | 130 | A simple moving average (SMA) calculates the average of a selected range of prices, usually closing prices, by the number 131 | of periods in that range. 132 | 133 | The SMA is a technical indicator for determining if an asset price will continue or reverse a bull or bear trend. It is 134 | calculated by summing up the closing prices of a stock over time and then dividing that total by the number of time periods 135 | being examined. Short-term averages respond quickly to changes in the price of the underlying, while long-term averages are 136 | slow to react. 
137 | 138 | https://www.investopedia.com/terms/s/sma.asp 139 | """ 140 | def __init__( 141 | self, 142 | data: pd.DataFrame, 143 | period: int=20, 144 | target_column: str='close', 145 | render_options: dict={}, 146 | **kwargs 147 | ): 148 | self._period = period 149 | self._names = [f'SMA{period}'] 150 | super().__init__(data, target_column, render_options, **kwargs) 151 | self.min = self._data[self._names[0]].min() 152 | self.max = self._data[self._names[0]].max() 153 | 154 | def default_render_options(self): 155 | return {name: RenderOptions( 156 | name=name, 157 | color=(100, 100, 255), 158 | window_type=WindowType.MAIN, 159 | render_type=RenderType.LINE, 160 | min=self.min, 161 | max=self.max 162 | ) for name in self._names} 163 | 164 | def compute(self): 165 | self._data[self.names[0]] = self._data[self.target_column].rolling(self._period).mean() 166 | 167 | def config(self): 168 | config = super().config() 169 | config['period'] = self._period 170 | return config 171 | 172 | 173 | class BolingerBands(Indicator): 174 | """ Volatility indicator 175 | 176 | Bollinger Bands are a type of price envelope developed by John Bollinger. (Price envelopes define 177 | upper and lower price range levels.) Bollinger Bands are envelopes plotted at a standard deviation level above and 178 | below a simple moving average of the price. Because the distance of the bands is based on standard deviation, they 179 | adjust to volatility swings in the underlying price. 180 | 181 | Bollinger Bands use 2 parameters, Period and Standard Deviations, StdDev. The default values are 20 for period, and 2 182 | for standard deviations, although you may customize the combinations. 183 | 184 | Bollinger bands help determine whether prices are high or low on a relative basis. They are used in pairs, both upper 185 | and lower bands and in conjunction with a moving average. Further, the pair of bands is not intended to be used on its own. 
186 | Use the pair to confirm signals given with other indicators. 187 | """ 188 | def __init__( 189 | self, 190 | data: pd.DataFrame, 191 | period: int=20, 192 | std: int=2, 193 | target_column: str='close', 194 | render_options: dict={}, 195 | **kwargs 196 | ): 197 | self._period = period 198 | self._std = std 199 | self._names = ['SMA', 'BB_up', 'BB_dn'] 200 | super().__init__(data, target_column, render_options, **kwargs) 201 | self.min = self._data['BB_dn'].min() 202 | self.max = self._data['BB_up'].max() 203 | 204 | def compute(self): 205 | self._data['SMA'] = self._data[self.target_column].rolling(self._period).mean() 206 | self._data['BB_up'] = self._data['SMA'] + self._data[self.target_column].rolling(self._period).std() * self._std 207 | self._data['BB_dn'] = self._data['SMA'] - self._data[self.target_column].rolling(self._period).std() * self._std 208 | 209 | def default_render_options(self): 210 | return {name: RenderOptions( 211 | name=name, 212 | color=(100, 100, 255), 213 | window_type=WindowType.MAIN, 214 | render_type=RenderType.LINE, 215 | min=self.min, 216 | max=self.max 217 | ) for name in self._names} 218 | 219 | def config(self): 220 | config = super().config() 221 | config['period'] = self._period 222 | config['std'] = self._std 223 | return config 224 | 225 | class RSI(Indicator): 226 | """ Momentum indicator 227 | 228 | The Relative Strength Index (RSI), developed by J. Welles Wilder, is a momentum oscillator that measures the speed and 229 | change of price movements. The RSI oscillates between zero and 100. Traditionally the RSI is considered overbought when 230 | above 70 and oversold when below 30. Signals can be generated by looking for divergences and failure swings. 231 | RSI can also be used to identify the general trend. 
232 | """ 233 | def __init__( 234 | self, 235 | data: pd.DataFrame, 236 | period: int=14, 237 | target_column: str='close', 238 | render_options: dict={}, 239 | min: float=0.0, 240 | max: float=100.0, 241 | **kwargs 242 | ): 243 | self._period = period 244 | self._names = ['RSI'] 245 | super().__init__(data, target_column, render_options, min=min, max=max, **kwargs) 246 | 247 | def compute(self): 248 | delta = self._data[self.target_column].diff() 249 | up = delta.clip(lower=0) 250 | down = -1 * delta.clip(upper=0) 251 | ema_up = up.ewm(com=self._period-1, adjust=True, min_periods=self._period).mean() 252 | ema_down = down.ewm(com=self._period-1, adjust=True, min_periods=self._period).mean() 253 | rs = ema_up / ema_down 254 | self._data['RSI'] = 100 - (100 / (1 + rs)) 255 | 256 | def default_render_options(self): 257 | custom_options = { 258 | "RSI0": 0, 259 | "RSI30": 30, 260 | "RSI70": 70, 261 | "RSI100": 100 262 | } 263 | options = {name: RenderOptions( 264 | name=name, 265 | color=(100, 100, 255), 266 | window_type=WindowType.SEPERATE, 267 | render_type=RenderType.LINE, 268 | min=self.min, 269 | max=self.max 270 | ) for name in self._names} 271 | 272 | for name, value in custom_options.items(): 273 | options[name] = RenderOptions( 274 | name=name, 275 | color=(192, 192, 192), 276 | window_type=WindowType.SEPERATE, 277 | render_type=RenderType.LINE, 278 | min=self.min, 279 | max=self.max, 280 | value=value 281 | ) 282 | return options 283 | 284 | def config(self): 285 | config = super().config() 286 | config['period'] = self._period 287 | return config 288 | 289 | 290 | class PSAR(Indicator): 291 | """ Parabolic Stop and Reverse (Parabolic SAR) 292 | 293 | The Parabolic Stop and Reverse, more commonly known as the 294 | Parabolic SAR, is a trend-following indicator developed by 295 | J. Welles Wilder. The Parabolic SAR is displayed as a single 296 | parabolic line (or dots) underneath the price bars in an uptrend, 297 | and above the price bars in a downtrend. 
298 | 299 | https://school.stockcharts.com/doku.php?id=technical_indicators:parabolic_sar 300 | """ 301 | def __init__( 302 | self, 303 | data: pd.DataFrame, 304 | step: float=0.02, 305 | max_step: float=0.2, 306 | target_column: str='close', 307 | render_options: dict={}, 308 | **kwargs 309 | ): 310 | self._names = ['PSAR'] 311 | self._step = step 312 | self._max_step = max_step 313 | super().__init__(data, target_column, render_options, **kwargs) 314 | self.min = self._data['PSAR'].min() 315 | self.max = self._data['PSAR'].max() 316 | 317 | def default_render_options(self): 318 | return {name: RenderOptions( 319 | name=name, 320 | color=(100, 100, 255), 321 | window_type=WindowType.MAIN, 322 | render_type=RenderType.DOT, 323 | min=self.min, 324 | max=self.max 325 | ) for name in self._names} 326 | 327 | def compute(self): 328 | high = self._data['high'] 329 | low = self._data['low'] 330 | close = self._data[self.target_column] 331 | 332 | up_trend = True 333 | acceleration_factor = self._step 334 | up_trend_high = high.iloc[0] 335 | down_trend_low = low.iloc[0] 336 | 337 | self._psar = close.copy() 338 | self._psar_up = pd.Series(index=self._psar.index, dtype="float64") 339 | self._psar_down = pd.Series(index=self._psar.index, dtype="float64") 340 | 341 | for i in range(2, len(close)): 342 | reversal = False 343 | 344 | max_high = high.iloc[i] 345 | min_low = low.iloc[i] 346 | 347 | if up_trend: 348 | self._psar.iloc[i] = self._psar.iloc[i - 1] + ( 349 | acceleration_factor * (up_trend_high - self._psar.iloc[i - 1]) 350 | ) 351 | 352 | if min_low < self._psar.iloc[i]: 353 | reversal = True 354 | self._psar.iloc[i] = up_trend_high 355 | down_trend_low = min_low 356 | acceleration_factor = self._step 357 | else: 358 | if max_high > up_trend_high: 359 | up_trend_high = max_high 360 | acceleration_factor = min( 361 | acceleration_factor + self._step, self._max_step 362 | ) 363 | 364 | low1 = low.iloc[i - 1] 365 | low2 = low.iloc[i - 2] 366 | if low2 < 
self._psar.iloc[i]: 367 | self._psar.iloc[i] = low2 368 | elif low1 < self._psar.iloc[i]: 369 | self._psar.iloc[i] = low1 370 | else: 371 | self._psar.iloc[i] = self._psar.iloc[i - 1] - ( 372 | acceleration_factor * (self._psar.iloc[i - 1] - down_trend_low) 373 | ) 374 | 375 | if max_high > self._psar.iloc[i]: 376 | reversal = True 377 | self._psar.iloc[i] = down_trend_low 378 | up_trend_high = max_high 379 | acceleration_factor = self._step 380 | else: 381 | if min_low < down_trend_low: 382 | down_trend_low = min_low 383 | acceleration_factor = min( 384 | acceleration_factor + self._step, self._max_step 385 | ) 386 | 387 | high1 = high.iloc[i - 1] 388 | high2 = high.iloc[i - 2] 389 | if high2 > self._psar.iloc[i]: 390 | self._psar.iloc[i] = high2 391 | elif high1 > self._psar.iloc[i]: 392 | self._psar.iloc[i] = high1 393 | 394 | up_trend = up_trend != reversal # XOR 395 | 396 | if up_trend: 397 | self._psar_up.iloc[i] = self._psar.iloc[i] 398 | else: 399 | self._psar_down.iloc[i] = self._psar.iloc[i] 400 | 401 | # calculate psar indicator 402 | self._data['PSAR'] = self._psar 403 | 404 | def config(self): 405 | config = super().config() 406 | config['step'] = self._step 407 | config['max_step'] = self._max_step 408 | return config 409 | 410 | 411 | class MACD(Indicator): 412 | """ Moving Average Convergence Divergence (MACD) 413 | """ 414 | def __init__( 415 | self, 416 | data: pd.DataFrame, 417 | fast_ma: int = 12, 418 | slow_ma: int = 26, 419 | histogram: int = 9, 420 | target_column: str='close', 421 | render_options: dict={}, 422 | **kwargs 423 | ): 424 | self._fast_ma = fast_ma 425 | self._slow_ma = slow_ma 426 | self._histogram = histogram 427 | self._names = ['MACD', 'MACD_signal'] 428 | super().__init__(data, target_column, render_options, **kwargs) 429 | self.min = self._data['MACD_signal'].min() 430 | self.max = self._data['MACD_signal'].max() 431 | 432 | def compute(self): 433 | # Calculate the Short Term Exponential Moving Average (EMA) 434 | short_ema = 
self._data[self.target_column].ewm(span=self._fast_ma, adjust=False).mean() 435 | 436 | # Calculate the Long Term Exponential Moving Average (EMA) 437 | long_ema = self._data[self.target_column].ewm(span=self._slow_ma, adjust=False).mean() 438 | 439 | # Calculate the Moving Average Convergence/Divergence (MACD) 440 | self._data["MACD"] = short_ema - long_ema 441 | 442 | # Calculate the Signal Line, using the configured signal period instead of a hard-coded 9 443 | self._data["MACD_signal"] = self._data["MACD"].ewm(span=self._histogram, adjust=False).mean() 444 | 445 | def default_render_options(self): 446 | return {name: RenderOptions( 447 | name=name, 448 | color=(100, 100, 255), 449 | window_type=WindowType.SEPERATE, 450 | render_type=RenderType.LINE, 451 | min=self.min, 452 | max=self.max 453 | ) for name in self._names} 454 | 455 | def config(self): 456 | config = super().config() 457 | config['fast_ma'] = self._fast_ma 458 | config['slow_ma'] = self._slow_ma 459 | config['histogram'] = self._histogram 460 | return config -------------------------------------------------------------------------------- /finrock/metrics.py: -------------------------------------------------------------------------------- 1 | from .state import State 2 | import numpy as np 3 | 4 | """ Metrics are used to track and log information about the environment. 
5 | possible metrics: 6 | + DifferentActions, 7 | + AccountValue, 8 | + MaxDrawdown, 9 | + SharpeRatio, 10 | - AverageProfit, 11 | - AverageLoss, 12 | - AverageTrade, 13 | - WinRate, 14 | - LossRate, 15 | - AverageWin, 16 | - AverageLoss, 17 | - AverageWinLossRatio, 18 | - AverageTradeDuration, 19 | - AverageTradeReturn, 20 | """ 21 | 22 | class Metric: 23 | def __init__(self, name: str="metric") -> None: 24 | self.name = name 25 | self.reset() 26 | 27 | @property 28 | def __name__(self) -> str: 29 | return self.__class__.__name__ 30 | 31 | def update(self, state: State): 32 | assert isinstance(state, State), f'state must be State, received: {type(state)}' 33 | 34 | return state 35 | 36 | @property 37 | def result(self): 38 | raise NotImplementedError 39 | 40 | def reset(self, prev_state: State=None): 41 | assert prev_state is None or isinstance(prev_state, State), f'prev_state must be None or State, received: {type(prev_state)}' 42 | 43 | return prev_state 44 | 45 | 46 | class DifferentActions(Metric): 47 | def __init__(self, name: str="different_actions") -> None: 48 | super().__init__(name=name) 49 | 50 | def update(self, state: State): 51 | super().update(state) 52 | 53 | if not self.prev_state: 54 | self.prev_state = state 55 | else: 56 | if state.allocation_percentage != self.prev_state.allocation_percentage: 57 | self.different_actions += 1 58 | 59 | self.prev_state = state 60 | 61 | @property 62 | def result(self): 63 | return self.different_actions 64 | 65 | def reset(self, prev_state: State=None): 66 | super().reset(prev_state) 67 | 68 | self.prev_state = prev_state 69 | self.different_actions = 0 70 | 71 | 72 | class AccountValue(Metric): 73 | def __init__(self, name: str="account_value") -> None: 74 | super().__init__(name=name) 75 | 76 | def update(self, state: State): 77 | super().update(state) 78 | 79 | self.account_value = state.account_value 80 | 81 | @property 82 | def result(self): 83 | return self.account_value 84 | 85 | def reset(self, 
prev_state: State=None): 86 | super().reset(prev_state) 87 | 88 | self.account_value = prev_state.account_value if prev_state else 0.0 89 | 90 | 91 | class MaxDrawdown(Metric): 92 | """ The Maximum Drawdown (MDD) is a measure of the largest peak-to-trough decline in the 93 | value of a portfolio or investment during a specific period. 94 | 95 | The Maximum Drawdown Ratio represents the proportion of the peak value that was lost during 96 | the largest decline. It is a measure of the risk associated with a particular investment or 97 | portfolio. Investors and fund managers use the Maximum Drawdown and its ratio to assess the 98 | historical downside risk and potential losses that could be incurred. 99 | """ 100 | def __init__(self, name: str="max_drawdown") -> None: 101 | super().__init__(name=name) 102 | 103 | def update(self, state: State): 104 | super().update(state) 105 | 106 | # Use max to track the running peak account value 107 | self.max_account_value = max(self.max_account_value, state.account_value) 108 | 109 | # Calculate drawdown relative to the peak (zero or negative) 110 | drawdown = (state.account_value - self.max_account_value) / self.max_account_value 111 | 112 | # Update max drawdown if the current drawdown is deeper (more negative) 113 | self.max_drawdown = min(self.max_drawdown, drawdown) 114 | 115 | @property 116 | def result(self): 117 | return self.max_drawdown 118 | 119 | def reset(self, prev_state: State=None): 120 | super().reset(prev_state) 121 | 122 | self.max_account_value = prev_state.account_value if prev_state else 0.0 123 | self.max_drawdown = 0.0 124 | 125 | 126 | class SharpeRatio(Metric): 127 | """ The Sharpe Ratio is a measure of the risk-adjusted performance of an investment or a portfolio. 128 | It helps investors evaluate the return of an investment relative to its risk. 129 | 130 | A higher Sharpe Ratio indicates a better risk-adjusted performance. Investors and portfolio managers 131 | often use the Sharpe Ratio to compare the risk-adjusted returns of different investments or portfolios.
132 | It allows them to assess whether the additional return earned by taking on additional risk is justified. 133 | """ 134 | def __init__(self, ratio_days=365.25, name: str='sharpe_ratio'): 135 | self.ratio_days = ratio_days 136 | super().__init__(name=name) 137 | 138 | def update(self, state: State): 139 | super().update(state) 140 | time_difference_days = (state.date - self.prev_state.date).days 141 | if time_difference_days >= 1: 142 | self.daily_returns.append((state.account_value - self.prev_state.account_value) / self.prev_state.account_value) 143 | self.prev_state = state 144 | 145 | @property 146 | def result(self): 147 | if len(self.daily_returns) == 0: 148 | return 0.0 149 | 150 | mean = np.mean(self.daily_returns) 151 | std = np.std(self.daily_returns) 152 | if std == 0: 153 | return 0.0 154 | 155 | sharpe_ratio = mean / std * np.sqrt(self.ratio_days) 156 | 157 | return sharpe_ratio 158 | 159 | def reset(self, prev_state: State=None): 160 | super().reset(prev_state) 161 | self.prev_state = prev_state 162 | self.daily_returns = [] -------------------------------------------------------------------------------- /finrock/render.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from .state import State 3 | 4 | class RenderType(Enum): 5 | LINE = 0 6 | DOT = 1 7 | 8 | class WindowType(Enum): 9 | MAIN = 0 10 | SEPERATE = 1 11 | 12 | class RenderOptions: 13 | def __init__( 14 | self, 15 | name: str, 16 | color: tuple, 17 | window_type: WindowType, 18 | render_type: RenderType, 19 | min: float, 20 | max: float, 21 | value: float = None, 22 | ): 23 | self.name = name 24 | self.color = color 25 | self.window_type = window_type 26 | self.render_type = render_type 27 | self.min = min 28 | self.max = max 29 | self.value = value 30 | 31 | def copy(self): 32 | return RenderOptions( 33 | name=self.name, 34 | color=self.color, 35 | window_type=self.window_type, 36 | render_type=self.render_type, 37 | 
min=self.min, 38 | max=self.max, 39 | value=self.value 40 | ) 41 | 42 | class ColorTheme: 43 | black = (0, 0, 0) 44 | white = (255, 255, 255) 45 | red = (255, 10, 0) 46 | lightblue = (100, 100, 255) 47 | green = (0, 240, 0) 48 | 49 | background = black 50 | up_candle = green 51 | down_candle = red 52 | wick = white 53 | text = white 54 | buy = green 55 | sell = red 56 | font = 'Noto Sans' 57 | font_ratio = 0.02 58 | 59 | class MainWindow: 60 | def __init__( 61 | self, 62 | width: int, 63 | height: int, 64 | top_offset: int, 65 | bottom_offset: int, 66 | window_size: int, 67 | candle_spacing, 68 | font_ratio: float=0.02, 69 | spacing_ratio: float=0.02, 70 | split_offset: int=0 71 | ): 72 | self.width = width 73 | self.height = height 74 | self.top_offset = top_offset 75 | self.bottom_offset = bottom_offset 76 | self.window_size = window_size 77 | self.candle_spacing = candle_spacing 78 | self.font_ratio = font_ratio 79 | self.spacing_ratio = spacing_ratio 80 | self.split_offset = split_offset 81 | 82 | self.seperate_window_ratio = 0.15 83 | 84 | @property 85 | def font_size(self): 86 | return int(self.height * self.font_ratio) 87 | 88 | @property 89 | def candle_width(self): 90 | return self.width // self.window_size - self.candle_spacing 91 | 92 | @property 93 | def chart_height(self): 94 | return self.height - (2 * self.top_offset + self.bottom_offset) 95 | 96 | @property 97 | def spacing(self): 98 | return int(self.height * self.spacing_ratio) 99 | 100 | @property 101 | def screen_shape(self): 102 | return (self.width, self.height) 103 | 104 | @screen_shape.setter 105 | def screen_shape(self, value: tuple): 106 | self.width, self.height = value 107 | 108 | def map_price_to_window(self, price: float, max_low: float, max_high: float): 109 | max_range = max_high - max_low 110 | height = self.chart_height - self.split_offset - self.bottom_offset - self.top_offset * 2 111 | value = int(height - (price - max_low) / max_range * height) + self.top_offset 112 | return 
value 113 | 114 | def map_to_seperate_window(self, value: float, min: float, max: float): 115 | self.split_offset = int(self.height * self.seperate_window_ratio) 116 | max_range = max - min 117 | new_value = int(self.split_offset - (value - min) / max_range * self.split_offset) 118 | height = self.chart_height - self.split_offset + new_value 119 | return height 120 | 121 | 122 | class PygameRender: 123 | def __init__( 124 | self, 125 | window_size: int=100, 126 | screen_width: int=1440, 127 | screen_height: int=1080, 128 | top_offset: int=25, 129 | bottom_offset: int=25, 130 | candle_spacing: int=1, 131 | color_theme = ColorTheme(), 132 | frame_rate: int=30, 133 | render_balance: bool=True, 134 | ): 135 | # pygame window settings 136 | self.screen_width = screen_width 137 | self.screen_height = screen_height 138 | self.top_offset = top_offset 139 | self.bottom_offset = bottom_offset 140 | self.candle_spacing = candle_spacing 141 | self.window_size = window_size 142 | self.color_theme = color_theme 143 | self.frame_rate = frame_rate 144 | self.render_balance = render_balance 145 | 146 | self.mainWindow = MainWindow( 147 | width=self.screen_width, 148 | height=self.screen_height, 149 | top_offset=self.top_offset, 150 | bottom_offset=self.bottom_offset, 151 | window_size=self.window_size, 152 | candle_spacing=self.candle_spacing, 153 | font_ratio=self.color_theme.font_ratio 154 | ) 155 | 156 | self._states = [] 157 | 158 | try: 159 | import pygame 160 | self.pygame = pygame 161 | except ImportError: 162 | raise ImportError('Please install pygame (pip install pygame)') 163 | 164 | self.pygame.init() 165 | self.pygame.display.init() 166 | self.window = self.pygame.display.set_mode(self.mainWindow.screen_shape, self.pygame.RESIZABLE) 167 | self.clock = self.pygame.time.Clock() 168 | 169 | def reset(self): 170 | self._states = [] 171 | 172 | def _prerender(func): 173 | """ Decorator for input data validation and pygame window rendering""" 174 | def wrapper(self, info: 
dict, rgb_array: bool=False): 175 | self._states += info.get('states', []) 176 | 177 | if not self._states or not bool(self.window._pixels_address): 178 | return 179 | 180 | for event in self.pygame.event.get(): 181 | if event.type == self.pygame.QUIT: 182 | self.pygame.quit() 183 | return 184 | 185 | if event.type == self.pygame.VIDEORESIZE: 186 | self.mainWindow.screen_shape = (event.w, event.h) 187 | 188 | # pause if spacebar is pressed 189 | if event.type == self.pygame.KEYDOWN: 190 | if event.key == self.pygame.K_SPACE: 191 | print('Paused') 192 | while True: 193 | event = self.pygame.event.wait() 194 | if event.type == self.pygame.KEYDOWN: 195 | if event.key == self.pygame.K_SPACE: 196 | print('Unpaused') 197 | break 198 | if event.type == self.pygame.QUIT: 199 | self.pygame.quit() 200 | return 201 | 202 | self.mainWindow.screen_shape = self.pygame.display.get_surface().get_size() 203 | 204 | 205 | canvas = func(self, info) 206 | canvas = self.pygame.transform.scale(canvas, self.mainWindow.screen_shape) 207 | # The following line copies our drawings from `canvas` to the visible window 208 | self.window.blit(canvas, canvas.get_rect()) 209 | self.pygame.display.update() 210 | self.clock.tick(self.frame_rate) 211 | 212 | if rgb_array: 213 | return self.pygame.surfarray.array3d(canvas) 214 | 215 | return wrapper 216 | 217 | def render_indicators(self, state: State, canvas: object, candle_offset: int, max_low: float, max_high: float): 218 | # connect last 2 points with a line 219 | for i, indicator in enumerate(state.indicators): 220 | for name, render_option in indicator["render_options"].items(): 221 | 222 | index = self._states.index(state) 223 | if not index: 224 | return 225 | last_state = self._states[index - 1] 226 | 227 | if render_option.render_type == RenderType.LINE: 228 | prev_render_option = last_state.indicators[i]["render_options"][name] 229 | if render_option.window_type == WindowType.MAIN: 230 | 231 | cur_value_map = 
self.mainWindow.map_price_to_window(render_option.value, max_low, max_high) 232 | prev_value_map = self.mainWindow.map_price_to_window(prev_render_option.value, max_low, max_high) 233 | 234 | elif render_option.window_type == WindowType.SEPERATE: 235 | 236 | cur_value_map = self.mainWindow.map_to_seperate_window(render_option.value, render_option.min, render_option.max) 237 | prev_value_map = self.mainWindow.map_to_seperate_window(prev_render_option.value, prev_render_option.min, prev_render_option.max) 238 | 239 | self.pygame.draw.line(canvas, render_option.color, 240 | (candle_offset - self.mainWindow.candle_width / 2, prev_value_map), 241 | (candle_offset + self.mainWindow.candle_width / 2, cur_value_map)) 242 | 243 | elif render_option.render_type == RenderType.DOT: 244 | if render_option.window_type == WindowType.MAIN: 245 | self.pygame.draw.circle(canvas, render_option.color, 246 | (candle_offset, self.mainWindow.map_price_to_window(render_option.value, max_low, max_high)), 2) 247 | elif render_option.window_type == WindowType.SEPERATE: 248 | raise NotImplementedError('Separate window for indicators is not implemented yet') 249 | 250 | def render_candle(self, state: State, canvas: object, candle_offset: int, max_low: float, max_high: float, font: object): 251 | assert isinstance(state, State) == True # check if state is a State object 252 | 253 | # Calculate candle coordinates 254 | candle_y_open = self.mainWindow.map_price_to_window(state.open, max_low, max_high) 255 | candle_y_close = self.mainWindow.map_price_to_window(state.close, max_low, max_high) 256 | candle_y_high = self.mainWindow.map_price_to_window(state.high, max_low, max_high) 257 | candle_y_low = self.mainWindow.map_price_to_window(state.low, max_low, max_high) 258 | 259 | # Determine candle color 260 | if state.open < state.close: 261 | # up candle 262 | candle_color = self.color_theme.up_candle 263 | candle_body_y = candle_y_close 264 | candle_body_height = candle_y_open - candle_y_close 265 |
else: 266 | # down candle 267 | candle_color = self.color_theme.down_candle 268 | candle_body_y = candle_y_open 269 | candle_body_height = candle_y_close - candle_y_open 270 | 271 | # Draw candlestick wicks 272 | self.pygame.draw.line(canvas, self.color_theme.wick, 273 | (candle_offset + self.mainWindow.candle_width // 2, candle_y_high), 274 | (candle_offset + self.mainWindow.candle_width // 2, candle_y_low)) 275 | 276 | # Draw candlestick body 277 | self.pygame.draw.rect(canvas, candle_color, (candle_offset, candle_body_y, self.mainWindow.candle_width, candle_body_height)) 278 | 279 | # Compare with previous state to determine whether buy or sell action was taken and draw arrow 280 | index = self._states.index(state) 281 | if index > 0: 282 | last_state = self._states[index - 1] 283 | 284 | if last_state.allocation_percentage < state.allocation_percentage: 285 | # buy 286 | candle_y_low = self.mainWindow.map_price_to_window(last_state.low, max_low, max_high) 287 | self.pygame.draw.polygon(canvas, self.color_theme.buy, [ 288 | (candle_offset - self.mainWindow.candle_width / 2, candle_y_low + self.mainWindow.spacing / 2), 289 | (candle_offset - self.mainWindow.candle_width * 0.1, candle_y_low + self.mainWindow.spacing), 290 | (candle_offset - self.mainWindow.candle_width * 0.9, candle_y_low + self.mainWindow.spacing) 291 | ]) 292 | 293 | # add account_value label below candle 294 | if self.render_balance: 295 | text = str(int(last_state.account_value)) 296 | buy_label = font.render(text, True, self.color_theme.text) 297 | label_width, label_height = font.size(text) 298 | canvas.blit(buy_label, (candle_offset - (self.mainWindow.candle_width + label_width) / 2, candle_y_low + self.mainWindow.spacing)) 299 | 300 | elif last_state.allocation_percentage > state.allocation_percentage: 301 | # sell 302 | candle_y_high = self.mainWindow.map_price_to_window(last_state.high, max_low, max_high) 303 | self.pygame.draw.polygon(canvas, self.color_theme.sell, [ 304 |
(candle_offset - self.mainWindow.candle_width / 2, candle_y_high - self.mainWindow.spacing / 2), 305 | (candle_offset - self.mainWindow.candle_width * 0.1, candle_y_high - self.mainWindow.spacing), 306 | (candle_offset - self.mainWindow.candle_width * 0.9, candle_y_high - self.mainWindow.spacing) 307 | ]) 308 | 309 | # add account_value label above candle 310 | if self.render_balance: 311 | text = str(int(last_state.account_value)) 312 | sell_label = font.render(text, True, self.color_theme.text) 313 | label_width, label_height = font.size(text) 314 | canvas.blit(sell_label, (candle_offset - (self.mainWindow.candle_width + label_width) / 2, candle_y_high - self.mainWindow.spacing - label_height)) 315 | 316 | @_prerender 317 | def render(self, info: dict): 318 | canvas = self.pygame.Surface(self.mainWindow.screen_shape) 319 | canvas.fill(self.color_theme.background) 320 | 321 | max_high = max([state.high for state in self._states[-self.window_size:]]) 322 | max_low = min([state.low for state in self._states[-self.window_size:]]) 323 | 324 | candle_offset = self.candle_spacing 325 | 326 | # Set font for labels 327 | font = self.pygame.font.SysFont(self.color_theme.font, self.mainWindow.font_size) 328 | 329 | for state in self._states[-self.window_size:]: 330 | 331 | # draw indicators 332 | self.render_indicators(state, canvas, candle_offset, max_low, max_high) 333 | 334 | # draw candle 335 | self.render_candle(state, canvas, candle_offset, max_low, max_high, font) 336 | 337 | # Move to the next candle 338 | candle_offset += self.mainWindow.candle_width + self.candle_spacing 339 | 340 | # Draw max and min ohlc values on the chart 341 | label_width, label_height = font.size(str(max_low)) 342 | label_y_low = font.render(str(max_low), True, self.color_theme.text) 343 | canvas.blit(label_y_low, (self.candle_spacing + 5, self.mainWindow.height - label_height * 2)) 344 | 345 | label_width, label_height = font.size(str(max_high)) 346 | label_y_high =
font.render(str(max_high), True, self.color_theme.text) 347 | canvas.blit(label_y_high, (self.candle_spacing + 5, label_height)) 348 | 349 | return canvas -------------------------------------------------------------------------------- /finrock/reward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .state import Observations 3 | 4 | class Reward: 5 | def __init__(self) -> None: 6 | pass 7 | 8 | @property 9 | def __name__(self) -> str: 10 | return self.__class__.__name__ 11 | 12 | def __call__(self, observations: Observations) -> float: 13 | raise NotImplementedError 14 | 15 | def reset(self, observations: Observations): 16 | pass 17 | 18 | 19 | class SimpleReward(Reward): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | 23 | def __call__(self, observations: Observations) -> float: 24 | assert isinstance(observations, Observations) == True, "observations must be an instance of Observations" 25 | 26 | last_state, next_state = observations[-2:] 27 | 28 | # buy 29 | if next_state.allocation_percentage > last_state.allocation_percentage: 30 | # check whether it was good or bad to buy 31 | order_size = next_state.allocation_percentage - last_state.allocation_percentage 32 | reward = (next_state.close - last_state.close) / last_state.close * order_size 33 | hold_reward = (next_state.close - last_state.close) / last_state.close * last_state.allocation_percentage 34 | reward += hold_reward 35 | 36 | # sell 37 | elif next_state.allocation_percentage < last_state.allocation_percentage: 38 | # check whether it was good or bad to sell 39 | order_size = last_state.allocation_percentage - next_state.allocation_percentage 40 | reward = -1 * (next_state.close - last_state.close) / last_state.close * order_size 41 | hold_reward = (next_state.close - last_state.close) / last_state.close * next_state.allocation_percentage 42 | reward += hold_reward 43 | 44 | # hold 45 | else: 46 | # check whether it was 
good or bad to hold 47 | ratio = -1 if not last_state.allocation_percentage else last_state.allocation_percentage 48 | reward = (next_state.close - last_state.close) / last_state.close * ratio 49 | 50 | return reward 51 | 52 | class AccountValueChangeReward(Reward): 53 | def __init__(self) -> None: 54 | super().__init__() 55 | self.ratio_days=365.25 56 | 57 | def reset(self, observations: Observations): 58 | super().reset(observations) 59 | self.returns = [] 60 | 61 | def __call__(self, observations: Observations) -> float: 62 | assert isinstance(observations, Observations) == True, "observations must be an instance of Observations" 63 | 64 | last_state, next_state = observations[-2:] 65 | reward = (next_state.account_value - last_state.account_value) / last_state.account_value 66 | 67 | return reward -------------------------------------------------------------------------------- /finrock/scalers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.seterr(all="ignore") 3 | import warnings 4 | from .state import Observations 5 | 6 | 7 | class Scaler: 8 | def __init__(self): 9 | pass 10 | 11 | def transform(self, observations: Observations) -> np.ndarray: 12 | raise NotImplementedError 13 | 14 | def __call__(self, observations) -> np.ndarray: 15 | assert isinstance(observations, Observations) == True, "observations must be an instance of Observations" 16 | return self.transform(observations) 17 | 18 | @property 19 | def __name__(self) -> str: 20 | return self.__class__.__name__ 21 | 22 | @property 23 | def name(self) -> str: 24 | return self.__name__ 25 | 26 | 27 | class MinMaxScaler(Scaler): 28 | def __init__(self, min: float, max: float): 29 | super().__init__() 30 | self._min = min 31 | self._max = max 32 | 33 | def transform(self, observations: Observations) -> np.ndarray: 34 | transformed_data = [] 35 | for state in observations: 36 | data = [] 37 | for name in ['open', 'high', 'low', 'close']: 38 | 
value = getattr(state, name) 39 | transformed_value = (value - self._min) / (self._max - self._min) 40 | data.append(transformed_value) 41 | 42 | data.append(state.allocation_percentage) 43 | 44 | # append scaled indicators 45 | for indicator in state.indicators: 46 | for value in indicator["values"].values(): 47 | transformed_value = (value - indicator["min"]) / (indicator["max"] - indicator["min"]) 48 | data.append(transformed_value) 49 | 50 | transformed_data.append(data) 51 | 52 | results = np.array(transformed_data) 53 | 54 | return results 55 | 56 | 57 | class ZScoreScaler(Scaler): 58 | def __init__(self): 59 | super().__init__() 60 | warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in reduce") 61 | 62 | def transform(self, observations: Observations) -> np.ndarray: 63 | full_data = [] 64 | for state in observations: 65 | data = [getattr(state, name) for name in ['open', 'high', 'low', 'close', 'allocation_percentage']] 66 | data += [value for indicator in state.indicators for value in indicator["values"].values()] 67 | full_data.append(data) 68 | 69 | results = np.array(full_data) 70 | 71 | # nan to zero, when divided by zero and allocation_percentage is not changed 72 | returns = np.nan_to_num(np.diff(results, axis=0) / results[:-1]) 73 | 74 | z_scores = np.nan_to_num((returns - np.mean(returns, axis=0)) / np.std(returns, axis=0)) 75 | 76 | return z_scores -------------------------------------------------------------------------------- /finrock/state.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import numpy as np 3 | from datetime import datetime 4 | 5 | class State: 6 | def __init__( 7 | self, 8 | timestamp: str, 9 | open: float, 10 | high: float, 11 | low: float, 12 | close: float, 13 | volume: float=0.0, 14 | indicators: list=[] 15 | ): 16 | self.timestamp = timestamp 17 | self.open = open 18 | self.high = high 19 | self.low = low 20 | self.close = close 
21 | self.volume = volume 22 | self.indicators = indicators 23 | 24 | try: 25 | self.date = datetime.strptime(timestamp, '%Y-%m-%d %H:%M:%S') 26 | except ValueError: 27 | raise ValueError(f'received invalid timestamp date format: {timestamp}, expected: YYYY-MM-DD HH:MM:SS') 28 | 29 | self._balance = 0.0 # balance in cash 30 | self._assets = 0.0 # balance in assets 31 | self._allocation_percentage = 0.0 # percentage of assets allocated to this state 32 | 33 | @property 34 | def balance(self): 35 | return self._balance 36 | 37 | @balance.setter 38 | def balance(self, value: float): 39 | self._balance = value 40 | 41 | @property 42 | def assets(self): 43 | return self._assets 44 | 45 | @assets.setter 46 | def assets(self, value: float): 47 | self._assets = value 48 | 49 | @property 50 | def account_value(self): 51 | return self.balance + self.assets * self.close 52 | 53 | @property 54 | def allocation_percentage(self): 55 | return self._allocation_percentage 56 | 57 | @allocation_percentage.setter 58 | def allocation_percentage(self, value: float): 59 | assert 0.0 <= value <= 1.0, f'allocation_percentage value must be between 0.0 and 1.0, received: {value}' 60 | self._allocation_percentage = value 61 | 62 | 63 | class Observations: 64 | def __init__( 65 | self, 66 | window_size: int, 67 | observations: typing.List[State]=[], 68 | ): 69 | self._observations = observations 70 | self._window_size = window_size 71 | 72 | assert isinstance(self._observations, list) == True, "observations must be a list" 73 | assert len(self._observations) <= self._window_size, f'observations length must be <= window_size, received: {len(self._observations)}' 74 | assert all(isinstance(observation, State) for observation in self._observations) == True, "observations must be a list of State objects" 75 | 76 | def __len__(self) -> int: 77 | return len(self._observations) 78 | 79 | @property 80 | def window_size(self) -> int: 81 | return self._window_size 82 | 83 | @property 84 | def 
observations(self) -> typing.List[State]: 85 | return self._observations 86 | 87 | @property 88 | def full(self) -> bool: 89 | return len(self._observations) == self._window_size 90 | 91 | def __getitem__(self, idx: int) -> State: 92 | try: 93 | return self._observations[idx] 94 | except IndexError: 95 | raise IndexError(f'index out of range: {idx}, observations length: {len(self._observations)}') 96 | 97 | def __iter__(self) -> State: 98 | """ Create a generator that iterates over the stored observations.""" 99 | for index in range(len(self)): 100 | yield self[index] 101 | 102 | def reset(self) -> None: 103 | self._observations = [] 104 | 105 | def append(self, state: State) -> None: 106 | # state should be State object or None 107 | assert isinstance(state, State) or state is None, "state must be a State object or None" 108 | self._observations.append(state) 109 | 110 | if len(self._observations) > self._window_size: 111 | self._observations.pop(0) 112 | 113 | @property 114 | def close(self) -> np.ndarray: 115 | return np.array([state.close for state in self._observations]) 116 | 117 | @property 118 | def high(self) -> np.ndarray: 119 | return np.array([state.high for state in self._observations]) 120 | 121 | @property 122 | def low(self) -> np.ndarray: 123 | return np.array([state.low for state in self._observations]) 124 | 125 | @property 126 | def open(self) -> np.ndarray: 127 | return np.array([state.open for state in self._observations]) 128 | 129 | @property 130 | def allocation_percentage(self) -> np.ndarray: 131 | return np.array([state.allocation_percentage for state in self._observations]) 132 | 133 | @property 134 | def volume(self) -> np.ndarray: 135 | return np.array([state.volume for state in self._observations]) -------------------------------------------------------------------------------- /finrock/trading_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import typing 4 | import
importlib 5 | import numpy as np 6 | 7 | from enum import Enum 8 | from .state import State, Observations 9 | from .data_feeder import PdDataFeeder 10 | from .reward import SimpleReward 11 | 12 | class ActionSpace(Enum): 13 | DISCRETE = 3 14 | CONTINUOUS = 2 15 | 16 | class TradingEnv: 17 | def __init__( 18 | self, 19 | data_feeder: PdDataFeeder, 20 | output_transformer: typing.Callable = None, 21 | initial_balance: float = 1000.0, 22 | max_episode_steps: int = None, 23 | window_size: int = 50, 24 | reward_function: typing.Callable = SimpleReward(), 25 | action_space: ActionSpace = ActionSpace.DISCRETE, 26 | metrics: typing.List[typing.Callable] = [], 27 | order_fee_percent: float = 0.001 28 | ) -> None: 29 | self._data_feeder = data_feeder 30 | self._output_transformer = output_transformer 31 | self._initial_balance = initial_balance 32 | self._max_episode_steps = max_episode_steps if max_episode_steps is not None else len(data_feeder) 33 | self._window_size = window_size 34 | self._reward_function = reward_function 35 | self._metrics = metrics 36 | self._order_fee_percent = order_fee_percent 37 | 38 | self._observations = Observations(window_size=window_size) 39 | self._observation_space = np.zeros(self.reset()[0].shape) 40 | self._action_space = action_space 41 | self.fee_ratio = 1 - self._order_fee_percent 42 | 43 | @property 44 | def action_space(self): 45 | return self._action_space.value 46 | 47 | @property 48 | def observation_space(self): 49 | return self._observation_space 50 | 51 | def _get_obs(self, index: int, balance: float=None) -> State: 52 | next_state = self._data_feeder[index] 53 | if next_state is None: 54 | return None 55 | 56 | if balance is not None: 57 | next_state.balance = balance 58 | 59 | return next_state 60 | 61 | def _get_terminated(self): 62 | return False 63 | 64 | def _take_action(self, action_pred: typing.Union[int, np.ndarray]) -> typing.Tuple[int, float]: 65 | """ 66 | """ 67 | # validate action is in range 68 | 69 | if 
isinstance(action_pred, np.ndarray): 70 | order_size = np.clip(action_pred[1], 0, 1) 71 | order_size = np.around(order_size, decimals=2) 72 | action = int((np.clip(action_pred[0], -1, 1) + 1) * 1.5) # scale from -1,1 to 0,3 73 | elif action_pred in [0, 1, 2]: 74 | order_size = 1.0 75 | action = action_pred 76 | assert (action in list(range(self._action_space.value))) == True, f'action must be in range {self._action_space.value}, received: {action}' 77 | else: 78 | raise ValueError(f'invalid action type: {type(action_pred)}') 79 | 80 | 81 | # get last state and next state 82 | last_state, next_state = self._observations[-2:] 83 | 84 | # modify action to hold (0) if we are out of balance 85 | if action == 2 and last_state.allocation_percentage == 1.0: 86 | action = 0 87 | 88 | # modify action to hold (0) if we are out of assets 89 | elif action == 1 and last_state.allocation_percentage == 0.0: 90 | action = 0 91 | 92 | if order_size == 0: 93 | action = 0 94 | 95 | if action == 2: # buy 96 | buy_order_size = order_size 97 | next_state.allocation_percentage = last_state.allocation_percentage + (1 - last_state.allocation_percentage) * buy_order_size 98 | next_state.assets = last_state.assets + (last_state.balance * buy_order_size / last_state.close) * self.fee_ratio 99 | next_state.balance = last_state.balance - (last_state.balance * buy_order_size) * self.fee_ratio 100 | 101 | elif action == 1: # sell 102 | sell_order_size = order_size 103 | next_state.allocation_percentage = last_state.allocation_percentage - last_state.allocation_percentage * sell_order_size 104 | next_state.balance = last_state.balance + (last_state.assets * sell_order_size * last_state.close) * self.fee_ratio 105 | next_state.assets = last_state.assets - (last_state.assets * sell_order_size) * self.fee_ratio 106 | 107 | else: # hold 108 | next_state.allocation_percentage = last_state.allocation_percentage 109 | next_state.assets = last_state.assets 110 | next_state.balance = last_state.balance 111 | 112
| if next_state.allocation_percentage > 1.0: 113 | raise ValueError(f'next_state.allocation_percentage > 1.0: {next_state.allocation_percentage}') 114 | elif next_state.allocation_percentage < 0.0: 115 | raise ValueError(f'next_state.allocation_percentage < 0.0: {next_state.allocation_percentage}') 116 | 117 | return action, order_size 118 | 119 | @property 120 | def metrics(self): 121 | return self._metrics 122 | 123 | def _metricsHandler(self, observation: State): 124 | metrics = {} 125 | # Loop through metrics and update 126 | for metric in self._metrics: 127 | metric.update(observation) 128 | metrics[metric.name] = metric.result 129 | 130 | return metrics 131 | 132 | def step(self, action: int) -> typing.Tuple[State, float, bool, bool, dict]: 133 | 134 | index = self._env_step_indexes.pop(0) 135 | 136 | observation = self._get_obs(index) 137 | # update observations object with new observation 138 | self._observations.append(observation) 139 | 140 | action, order_size = self._take_action(action) 141 | reward = self._reward_function(self._observations) 142 | terminated = self._get_terminated() 143 | truncated = False if self._env_step_indexes else True 144 | info = { 145 | "states": [observation], 146 | "metrics": self._metricsHandler(observation) 147 | } 148 | 149 | transformed_obs = self._output_transformer.transform(self._observations) 150 | 151 | if np.isnan(transformed_obs).any(): 152 | raise ValueError("transformed_obs contains nan values, check your data") 153 | 154 | return transformed_obs, reward, terminated, truncated, info 155 | 156 | def reset(self) -> typing.Tuple[State, dict]: 157 | """ Reset the environment and return the initial state 158 | """ 159 | size = len(self._data_feeder) - self._max_episode_steps 160 | self._env_start_index = np.random.randint(0, size) if size > 0 else 0 161 | self._env_step_indexes = list(range(self._env_start_index, self._env_start_index + self._max_episode_steps)) 162 | 163 | # Initial observations are the first states 
of the window size 164 | self._observations.reset() 165 | while not self._observations.full: 166 | obs = self._get_obs(self._env_step_indexes.pop(0), balance=self._initial_balance) 167 | if obs is None: 168 | continue 169 | # update observations object with new observation 170 | self._observations.append(obs) 171 | 172 | info = { 173 | "states": self._observations.observations, 174 | "metrics": {} 175 | } 176 | 177 | # reset metrics with last state 178 | for metric in self._metrics: 179 | metric.reset(self._observations.observations[-1]) 180 | 181 | transformed_obs = self._output_transformer.transform(self._observations) 182 | if np.isnan(transformed_obs).any(): 183 | raise ValueError("transformed_obs contains nan values, check your data") 184 | 185 | # return state and info 186 | return transformed_obs, info 187 | 188 | def render(self): 189 | raise NotImplementedError 190 | 191 | def close(self): 192 | """ Close the environment 193 | """ 194 | pass 195 | 196 | def config(self): 197 | """ Return the environment configuration 198 | """ 199 | return { 200 | "data_feeder": self._data_feeder.__name__, 201 | "output_transformer": self._output_transformer.__name__, 202 | "initial_balance": self._initial_balance, 203 | "max_episode_steps": self._max_episode_steps, 204 | "window_size": self._window_size, 205 | "reward_function": self._reward_function.__name__, 206 | "metrics": [metric.__name__ for metric in self._metrics], 207 | "order_fee_percent": self._order_fee_percent, 208 | "observation_space_shape": tuple(self.observation_space.shape), 209 | "action_space": self._action_space.name, 210 | } 211 | 212 | def save_config(self, path: str = ""): 213 | """ Save the environment configuration 214 | """ 215 | output_path = os.path.join(path, "TradingEnv.json") 216 | with open(output_path, "w") as f: 217 | json.dump(self.config(), f, indent=4) 218 | 219 | @staticmethod 220 | def load_config(data_feeder, path: str = "", **kwargs): 221 | """ Load the environment configuration 
222 | """ 223 | 224 | input_path = os.path.join(path, "TradingEnv.json") 225 | if not os.path.exists(input_path): 226 | raise Exception(f"TradingEnv Config file not found in {path}") 227 | with open(input_path, "r") as f: 228 | config = json.load(f) 229 | 230 | environment = TradingEnv( 231 | data_feeder = data_feeder, 232 | output_transformer = getattr(importlib.import_module(".scalers", package=__package__), config["output_transformer"])(), 233 | initial_balance = kwargs.get("initial_balance") or config["initial_balance"], 234 | max_episode_steps = kwargs.get("max_episode_steps") or config["max_episode_steps"], 235 | window_size = kwargs.get("window_size") or config["window_size"], 236 | reward_function = getattr(importlib.import_module(".reward", package=__package__), config["reward_function"])(), 237 | action_space = ActionSpace[config["action_space"]], 238 | metrics = [getattr(importlib.import_module(".metrics", package=__package__), metric)() for metric in config["metrics"]], 239 | order_fee_percent = kwargs.get("order_fee_percent") or config["order_fee_percent"] 240 | ) 241 | 242 | return environment -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | matplotlib 4 | rockrl==0.4.4 5 | tensorflow==2.10.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | DIR = os.path.abspath(os.path.dirname(__file__)) 5 | 6 | with open(os.path.join(DIR, 'README.md')) as fh: 7 | long_description = fh.read() 8 | 9 | with open(os.path.join(DIR, 'requirements.txt')) as fh: 10 | requirements = fh.read().splitlines() 11 | 12 | def get_version(initpath: str) -> str: 13 | """ Get from the init of the source code the version string 14 | 15 | 
Params: 16 | initpath (str): path to the init file of the python package relative to the setup file 17 | 18 | Returns: 19 | str: The version string in the form 0.0.1 20 | """ 21 | 22 | path = os.path.join(os.path.dirname(__file__), initpath) 23 | 24 | with open(path, "r") as handle: 25 | for line in handle.read().splitlines(): 26 | if line.startswith("__version__"): 27 | return line.split("=")[1].strip().strip("\"'") 28 | else: 29 | raise RuntimeError("Unable to find version string.") 30 | 31 | setup( 32 | name = 'finrock', 33 | version = get_version("finrock/__init__.py"), 34 | long_description = long_description, 35 | long_description_content_type = 'text/markdown', 36 | url='https://pylessons.com/', 37 | author='PyLessons', 38 | author_email='pythonlessons0@gmail.com', 39 | install_requires=requirements, 40 | python_requires='>=3', 41 | packages = find_packages(exclude=['*.pyc']), 42 | include_package_data=True, 43 | project_urls={ 44 | 'Source': 'https://github.com/pythonlessons/FinRock/', 45 | 'Tracker': 'https://github.com/pythonlessons/FinRock/issues', 46 | }, 47 | description="Reinforcement Learning for Financial Trading", 48 | ) --------------------------------------------------------------------------------
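Note on `TradingEnv.load_config` in `finrock/trading_env.py`: overrides are merged with the pattern `kwargs.get(key) or config[key]`. Because `or` falls back whenever the left side is falsy, an explicit override such as `order_fee_percent=0.0` would be replaced by the saved value. The sketch below illustrates this with a None-aware alternative. The helper name `pick_override` and the sample numbers are hypothetical and are not part of FinRock.

```python
import typing


def pick_override(kwargs: dict, config: dict, key: str) -> typing.Any:
    """Return kwargs[key] unless it is missing or None, else config[key].

    Hypothetical helper for illustration only; FinRock does not define it.
    """
    value = kwargs.get(key)
    return config[key] if value is None else value


# Made-up saved configuration and caller overrides.
saved = {"initial_balance": 1000.0, "order_fee_percent": 0.001}
overrides = {"order_fee_percent": 0.0}

# The `or` pattern: 0.0 is falsy, so the saved fee is used instead.
with_or = overrides.get("order_fee_percent") or saved["order_fee_percent"]

# The None-aware pattern keeps the explicit zero.
none_aware = pick_override(overrides, saved, "order_fee_percent")

# A key that was not overridden still falls back to the saved value.
fallback = pick_override(overrides, saved, "initial_balance")

print(with_or, none_aware, fallback)
```

Whether a zero fee or zero balance is a meaningful override depends on how the environment is used; if it is, a None check like the one above is one way to preserve it.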