├── .gitignore
├── .vscode
│   └── launch.json
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── Tutorials
│   ├── 01_Creating_trading_environment.md
│   ├── 02_Trading_simulation_backbone.md
│   ├── 03_Trading_with_RL.md
│   ├── 04_Indicators_and_Metrics.md
│   └── Documents
│       ├── 01_FinRock.jpg
│       ├── 02_FinRock.jpg
│       ├── 02_FinRock_render.png
│       ├── 03_FinRock.jpg
│       ├── 03_FinRock_render.png
│       ├── 04_FinRock.jpg
│       └── 04_FinRock_render.png
├── bin
│   ├── create_sinusoid_data.py
│   └── plot_data.py
├── experiments
│   ├── playing_random_sinusoid.py
│   ├── testing_ppo_sinusoid_continuous.py
│   ├── testing_ppo_sinusoid_discrete.py
│   ├── training_ppo_sinusoid_continuous.py
│   └── training_ppo_sinusoid_discrete.py
├── finrock
│   ├── __init__.py
│   ├── data_feeder.py
│   ├── indicators.py
│   ├── metrics.py
│   ├── render.py
│   ├── reward.py
│   ├── scalers.py
│   ├── state.py
│   └── trading_env.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.egg-info
3 | api.json
4 | venv*
5 | Datasets
6 | runs
7 | Models
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "justMyCode": false
14 | }
15 | ]
16 | }
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## [0.5.0] - 2024-01-30
2 | ### Added:
3 | - Added `MACD` indicator to `indicators` file.
4 | - Added `reward.AccountValueChangeReward` object to calculate reward based on the change in the account value.
5 | - Added `scalers.ZScoreScaler` that doesn't require min and max to transform data, but uses mean and std instead.
6 | - Added `ActionSpace` object to handle the action space of the agent.
7 | - Added support for continuous actions (float values between 0 and 1).
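
The `ZScoreScaler` idea above can be sketched in a few lines. This is an illustrative standalone version of the same mean/std standardization, not the actual `finrock.scalers` implementation (whose interface may differ):

```python
import numpy as np

def zscore_transform(values: np.ndarray) -> np.ndarray:
    # Standardize to zero mean and unit variance using mean/std,
    # so no min/max range needs to be known in advance.
    mean = values.mean()
    std = values.std()
    return (values - mean) / std

prices = np.array([100.0, 102.0, 101.0, 105.0, 103.0])
scaled = zscore_transform(prices)
```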
8 |
9 | ### Changed:
10 | - Updated all indicators to accept a `config` parameter so they can be serialized (configurations saved to/loaded from file).
11 | - Changed `reward.simpleReward` to `reward.SimpleReward` Object.
12 | - Updated `state.State` to have `open`, `high`, `low`, `close` and `volume` attributes.
13 | - Updated `data_feeder.PdDataFeeder` to be serializable by including `save_config` and `load_config` methods.
14 | - Included trading fees into `trading_env.TradingEnv` object.
15 | - Updated `trading_env.TradingEnv` to have `reset` method, which resets the environment to the initial state.
16 | - Included `save_config` and `load_config` methods into `trading_env.TradingEnv` object, so we can save/load the environment configuration.
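
The `save_config` / `load_config` pattern above boils down to serializing the constructor arguments. A minimal standalone sketch of that round-trip follows; the class name, field names, and file name here are illustrative, not FinRock's actual ones:

```python
import json
import os
import tempfile

class EnvConfig:
    # Hypothetical subset of environment settings, for illustration only
    def __init__(self, initial_balance=1000.0, window_size=50):
        self.initial_balance = initial_balance
        self.window_size = window_size

    def save_config(self, folder: str) -> None:
        # Persist the constructor arguments as JSON
        with open(os.path.join(folder, "env_config.json"), "w") as f:
            json.dump(self.__dict__, f)

    @classmethod
    def load_config(cls, folder: str) -> "EnvConfig":
        # Rebuild the object from the saved JSON
        with open(os.path.join(folder, "env_config.json")) as f:
            return cls(**json.load(f))

with tempfile.TemporaryDirectory() as folder:
    EnvConfig(initial_balance=500.0, window_size=30).save_config(folder)
    restored = EnvConfig.load_config(folder)
```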
17 |
18 | ## [0.4.0] - 2024-01-02
19 | ### Added:
20 | - Created `indicators` file, where I added `BolingerBands`, `RSI`, `PSAR`, `SMA` indicators
21 | - Added `SharpeRatio` and `MaxDrawdown` metrics to `metrics`
22 | - Included indicators handling into `data_feeder.PdDataFeeder` object
23 | - Included indicators handling into `state.State` object
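
The two metrics added above follow standard formulas; the sketch below illustrates them in isolation and is not the package's exact `finrock.metrics` implementation (which may annualize returns or handle edge cases differently):

```python
import numpy as np

def sharpe_ratio(returns: np.ndarray, risk_free_rate: float = 0.0) -> float:
    # Mean excess return divided by the standard deviation of returns
    excess = returns - risk_free_rate
    return float(excess.mean() / excess.std())

def max_drawdown(account_values: np.ndarray) -> float:
    # Largest relative drop from a running peak of the account value
    running_peak = np.maximum.accumulate(account_values)
    drawdowns = (account_values - running_peak) / running_peak
    return float(drawdowns.min())

values = np.array([1000.0, 1100.0, 990.0, 1210.0])
mdd = max_drawdown(values)  # worst drop: 990 after a 1100 peak, i.e. -10%
```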
24 |
25 | ### Changed:
26 | - Changed `rockrl` package dependency from `0.0.4` to `0.4.1`
27 | - Refactored `render.PygameRender` object to handle indicators rendering (getting very messy)
28 | - Updated `scalers.MinMaxScaler` to handle indicators scaling
29 | - Updated `trading_env.TradingEnv` to raise an error with `np.nan` data and skip `None` states
30 |
31 |
32 | ## [0.3.0] - 2023-12-05
33 | ### Added:
34 | - Added `DifferentActions` and `AccountValue` as metrics. Metrics are the main way to evaluate the performance of the agent.
35 | - Now `metrics.Metrics` object can be used to calculate the metrics within trading environment.
36 | - Included `rockrl==0.0.4` as a dependency, which is a reinforcement learning package that I created.
37 | - Added `experiments/training_ppo_sinusoid.py` to train a simple Dense agent using PPO algorithm on the sinusoid data with discrete actions.
38 | - Added `experiments/testing_ppo_sinusoid.py` to test the trained agent on the sinusoid data with discrete actions.
39 |
40 | ### Changed:
41 | - Renamed and moved `playing.py` to `experiments/playing_random_sinusoid.py`
42 | - Upgraded `finrock.render.PygameRender`: rendering can now be paused/resumed with the spacebar, and account value is rendered along with the actions
43 |
44 |
45 | ## [0.2.0] - 2023-11-29
46 | ### Added:
47 | - Created `reward.simpleReward` function to calculate reward based on the action and the difference between the current price and the previous price
48 | - Created `scalers.MinMaxScaler` object to transform the price data to a range between 0 and 1 and prepare it for the neural networks input
49 | - Created `state.Observations` object to hold the observations of the agent with set window size
50 | - Updated `render.PygameRender` object to render the agent's actions
51 | - Updated `state.State` to hold `assets`, `balance` and `allocation_percentage` for the specific state
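
The `simpleReward` logic described above can be sketched as a standalone function. This illustrates the idea (reward the price move in the direction of the action), not the actual `finrock.reward` signature, and the action encoding is hypothetical:

```python
def simple_reward(action: int, previous_price: float, current_price: float) -> float:
    # Hypothetical action encoding: 2 = buy/long, 1 = hold, 0 = sell
    price_change = current_price - previous_price
    if action == 2:   # holding assets: gain when the price rises
        return price_change
    if action == 0:   # out of the market: gain when the price falls
        return -price_change
    return 0.0        # hold: no reward either way
```

For example, `simple_reward(2, 100.0, 105.0)` yields `5.0`, while `simple_reward(0, 100.0, 105.0)` yields `-5.0`.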
52 |
53 |
54 | ## [0.1.0] - 2023-10-17
55 | ### Initial Release:
56 | - Created the project
57 | - Created code to create random sinusoidal price data
58 | - Created `state.State` object, which holds the state of the market
59 | - Created `render.PygameRender` object, which renders the state of the market using `pygame` library
60 | - Created `trading_env.TradingEnv` object, which is the environment for the agent to interact with
61 | - Created `data_feeder.PdDataFeeder` object, which feeds the environment with data from a pandas dataframe
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | global-exclude *.pyc
2 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FinRock
2 | Reinforcement Learning package for Finance
3 |
4 | # Environment Structure:
5 |
6 |
7 |
8 |
9 | ### Install requirements:
10 | ```
11 | pip install -r requirements.txt
12 | pip install pygame
13 | pip install .
14 | ```
15 |
16 | ### Create sinusoid data:
17 | ```
18 | python bin/create_sinusoid_data.py
19 | ```
20 |
21 | ### Train RL (PPO) agent on discrete actions:
22 | ```
23 | python experiments/training_ppo_sinusoid_discrete.py
24 | ```
25 |
26 | ### Test trained agent (change the path to your saved model):
27 | ```
28 | python experiments/testing_ppo_sinusoid_discrete.py
29 | ```
30 |
31 | ### Environment Render:
32 |
33 |
34 |
35 |
36 | ## Links to YouTube videos:
37 | - [Introduction to FinRock package](https://youtu.be/xU_YJB7vilA)
38 | - [Complete Trading Simulation Backbone](https://youtu.be/1z5geob8Yho)
39 | - [Training RL agent on Sinusoid data](https://youtu.be/JkA4BuYvWyE)
40 | - [Included metrics and indicators into environment](https://youtu.be/bGpBEnKzIdo)
41 |
42 | # TODO:
43 | - [ ] Train model on `continuous` actions (control allocation percentage)
44 | - [ ] Add more indicators
45 | - [ ] Add more metrics
46 | - [ ] Add more reward functions
47 | - [ ] Add more scalers
48 | - [ ] Train RL agent on real data
49 | - [ ] Add more RL algorithms
50 | - [ ] Refactor rendering, maybe move to browser?
--------------------------------------------------------------------------------
/Tutorials/01_Creating_trading_environment.md:
--------------------------------------------------------------------------------
1 | # Introduction to FinRock package
2 |
3 | ### Environment Structure:
4 |
5 |
6 |
7 |
8 | ### Link to YouTube video:
9 | https://youtu.be/xU_YJB7vilA
10 |
11 | ### Link to tutorial code:
12 | https://github.com/pythonlessons/FinRock/tree/0.1.0
13 |
14 | ### Download tutorial code:
15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.1.0.zip
16 |
17 |
18 | ### Install requirements:
19 | ```
20 | pip install -r requirements.txt
21 | pip install pygame
22 | ```
23 |
24 | ### Create sinusoid data:
25 | ```
26 | python bin/create_sinusoid_data.py
27 | ```
28 |
29 | ### Run environment:
30 | ```
31 | python playing.py
32 | ```
--------------------------------------------------------------------------------
/Tutorials/02_Trading_simulation_backbone.md:
--------------------------------------------------------------------------------
1 | # Complete Trading Simulation Backbone
2 |
3 | ### Environment Structure:
4 |
5 |
6 |
7 |
8 | ### Link to YouTube video:
9 | https://youtu.be/1z5geob8Yho
10 |
11 | ### Link to tutorial code:
12 | https://github.com/pythonlessons/FinRock/tree/0.2.0
13 |
14 | ### Download tutorial code:
15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.2.0.zip
16 |
17 |
18 | ### Install requirements:
19 | ```
20 | pip install -r requirements.txt
21 | pip install pygame
22 | ```
23 |
24 | ### Create sinusoid data:
25 | ```
26 | python bin/create_sinusoid_data.py
27 | ```
28 |
29 | ### Run environment:
30 | ```
31 | python playing.py
32 | ```
33 |
34 | ### Environment Render:
35 |
36 |
37 |
--------------------------------------------------------------------------------
/Tutorials/03_Trading_with_RL.md:
--------------------------------------------------------------------------------
1 | # Trading with RL
2 |
3 | ### Environment Structure:
4 |
5 |
6 |
7 |
8 | ### Link to YouTube video:
9 | https://youtu.be/JkA4BuYvWyE
10 |
11 | ### Link to tutorial code:
12 | https://github.com/pythonlessons/FinRock/tree/0.3.0
13 |
14 | ### Download tutorial code:
15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.3.0.zip
16 |
17 |
18 | ### Install requirements:
19 | ```
20 | pip install -r requirements.txt
21 | pip install pygame
22 | pip install .
23 | ```
24 |
25 | ### Create sinusoid data:
26 | ```
27 | python bin/create_sinusoid_data.py
28 | ```
29 |
30 | ### Train RL (PPO) agent on discrete actions:
31 | ```
32 | python experiments/training_ppo_sinusoid.py
33 | ```
34 |
35 | ### Test trained agent (Change path to the saved model):
36 | ```
37 | python experiments/testing_ppo_sinusoid.py
38 | ```
39 |
40 | ### Environment Render:
41 |
42 |
43 |
--------------------------------------------------------------------------------
/Tutorials/04_Indicators_and_Metrics.md:
--------------------------------------------------------------------------------
1 | # Indicators and Metrics
2 |
3 | ### Environment Structure:
4 |
5 |
6 |
7 |
8 | ### Link to YouTube video:
9 | https://youtu.be/bGpBEnKzIdo
10 |
11 | ### Link to tutorial code:
12 | https://github.com/pythonlessons/FinRock/tree/0.4.0
13 |
14 | ### Download tutorial code:
15 | https://github.com/pythonlessons/FinRock/archive/refs/tags/0.4.0.zip
16 |
17 |
18 | ### Install requirements:
19 | ```
20 | pip install -r requirements.txt
21 | pip install pygame
22 | pip install .
23 | ```
24 |
25 | ### Create sinusoid data:
26 | ```
27 | python bin/create_sinusoid_data.py
28 | ```
29 |
30 | ### Train RL (PPO) agent on discrete actions:
31 | ```
32 | python experiments/training_ppo_sinusoid.py
33 | ```
34 |
35 | ### Test trained agent (Change path to the saved model):
36 | ```
37 | python experiments/testing_ppo_sinusoid.py
38 | ```
39 |
40 | ### Environment Render:
41 |
42 |
43 |
--------------------------------------------------------------------------------
/Tutorials/Documents/01_FinRock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/01_FinRock.jpg
--------------------------------------------------------------------------------
/Tutorials/Documents/02_FinRock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/02_FinRock.jpg
--------------------------------------------------------------------------------
/Tutorials/Documents/02_FinRock_render.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/02_FinRock_render.png
--------------------------------------------------------------------------------
/Tutorials/Documents/03_FinRock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/03_FinRock.jpg
--------------------------------------------------------------------------------
/Tutorials/Documents/03_FinRock_render.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/03_FinRock_render.png
--------------------------------------------------------------------------------
/Tutorials/Documents/04_FinRock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/04_FinRock.jpg
--------------------------------------------------------------------------------
/Tutorials/Documents/04_FinRock_render.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pythonlessons/FinRock/a8968663cc655a182e858133006b4209fc00a650/Tutorials/Documents/04_FinRock_render.png
--------------------------------------------------------------------------------
/bin/create_sinusoid_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | from datetime import datetime, timedelta
6 |
7 | def create_sinusoidal_df(
8 | amplitude = 2000.0, # Amplitude of the price variations
9 | frequency = 0.01, # Frequency of the price variations
10 | phase = 0.0, # Phase shift of the price variations
11 | num_samples = 10000, # Number of data samples
12 | data_shift = 20000, # shift the data up
13 | trendline_down = 5000, # total linear downward drift applied over the series
14 | plot = False,
15 | ):
16 | """Create a dataframe with sinusoidal data"""
17 |
18 | # Generate the time axis
19 | t = np.linspace(0, 2 * np.pi * frequency * num_samples, num_samples)
20 |
21 | # Get the current datetime
22 | now = datetime.now()
23 |
24 | # Set hours, minutes, and seconds to zero
25 | now = now.replace(hour=0, minute=0, second=0, microsecond=0)
26 |
27 | # Generate timestamps for each day
28 | # timestamps = [now - timedelta(days=i) for i in range(num_samples)]
29 | timestamps = [now - timedelta(hours=i*4) for i in range(num_samples)]
30 |
31 | # Convert datetime objects to strings
32 | timestamps = [ts.strftime('%Y-%m-%d %H:%M:%S') for ts in timestamps]
33 |
34 | # Invert the order of the timestamps
35 | timestamps = timestamps[::-1]
36 |
37 | # Generate the sinusoidal data for prices
38 | sin_data = amplitude * np.sin(t + phase)
39 | sin_data += data_shift # shift the data up
40 |
41 | # apply a linear downward drift to create a downtrend
42 | sin_data -= np.linspace(0, trendline_down, num_samples)
43 |
44 | # Add random noise
45 | noise = np.random.uniform(0.95, 1.05, len(t)) # generate random noise
46 | noisy_sin_data = sin_data * noise # add noise to the original data
47 |
48 | price_range = np.max(noisy_sin_data) - np.min(noisy_sin_data)
49 |
50 | # Generate random close prices around the noisy sinusoid
51 | # (low and high prices are derived from open/close further below)
52 | close_prices = noisy_sin_data + np.random.uniform(-0.05 * price_range, 0.05 * price_range, len(noisy_sin_data))
53 |
54 | # open prices usually are close to the close prices of the previous day
55 | open_prices = np.zeros(len(close_prices))
56 | open_prices[0] = close_prices[0]
57 | open_prices[1:] = close_prices[:-1]
58 |
59 | # high prices are always above open and close prices
60 | high_prices = np.maximum(open_prices, close_prices) + np.random.uniform(0, 0.1 * price_range, len(close_prices))
61 |
62 | # low prices are always below open and close prices
63 | low_prices = np.minimum(open_prices, close_prices) - np.random.uniform(0, 0.1 * price_range, len(close_prices))
64 |
65 | if plot:
66 | # Plot the price data
67 | plt.figure(figsize=(10, 6))
68 | plt.plot(t, noisy_sin_data, label='Noisy Sinusoidal Data')
69 | plt.plot(t, open_prices, label='Open')
70 | plt.plot(t, low_prices, label='Low')
71 | plt.plot(t, close_prices, label='Close')
72 | plt.plot(t, high_prices, label='High')
73 | plt.xlabel('Time')
74 | plt.ylabel('Price')
75 | plt.title('Fake Price Data')
76 | plt.legend()
77 | plt.grid(True)
78 | plt.show()
79 |
80 | # assemble the OHLC data into a dataframe: timestamp, open, high, low, close
81 | df = pd.DataFrame({'timestamp': timestamps, 'open': open_prices, 'high': high_prices, 'low': low_prices, 'close': close_prices})
82 |
83 | return df
84 |
85 | if __name__ == '__main__':
86 | # Create a dataframe with sinusoidal data
87 | df = create_sinusoidal_df()
88 |
89 | # Create a directory to store the datasets
90 | os.makedirs('Datasets', exist_ok=True)
91 |
92 | # Save the dataframe to a CSV file
93 | df.to_csv('Datasets/random_sinusoid.csv')
--------------------------------------------------------------------------------
/bin/plot_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 |
4 | df = pd.read_csv('Datasets/random_sinusoid.csv', index_col='timestamp', parse_dates=True)
5 | df = df[['open', 'high', 'low', 'close']]
6 | # limit to last 1000 data points
7 | df = df[-1000:]
8 |
9 | # plot the data
10 | plt.figure(figsize=(10, 6))
11 | plt.plot(df['close'])
12 | plt.xlabel('Time')
13 | plt.ylabel('Price')
14 | plt.title('random_sinusoid.csv')
15 | plt.grid(True)
16 | plt.show()
--------------------------------------------------------------------------------
/experiments/playing_random_sinusoid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from finrock.data_feeder import PdDataFeeder
5 | from finrock.trading_env import TradingEnv
6 | from finrock.render import PygameRender
7 | from finrock.scalers import ZScoreScaler
8 | from finrock.reward import AccountValueChangeReward
9 | from finrock.indicators import BolingerBands, SMA, RSI, PSAR, MACD
10 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio
11 |
12 | df = pd.read_csv('Datasets/random_sinusoid.csv')
13 |
14 | pd_data_feeder = PdDataFeeder(
15 | df = df,
16 | indicators = [
17 | BolingerBands(data=df, period=20, std=2),
18 | RSI(data=df, period=14),
19 | PSAR(data=df),
20 | MACD(data=df),
21 | SMA(data=df, period=7),
22 | ]
23 | )
24 |
25 | env = TradingEnv(
26 | data_feeder = pd_data_feeder,
27 | output_transformer = ZScoreScaler(),
28 | initial_balance = 1000.0,
29 | max_episode_steps = 1000,
30 | window_size = 50,
31 | reward_function = AccountValueChangeReward(),
32 | metrics = [
33 | DifferentActions(),
34 | AccountValue(),
35 | MaxDrawdown(),
36 | SharpeRatio(),
37 | ]
38 | )
39 | action_space = env.action_space
40 | input_shape = env.observation_space.shape
41 |
42 | env.save_config()
43 |
44 | pygameRender = PygameRender(frame_rate=60)
45 |
46 | state, info = env.reset()
47 | pygameRender.render(info)
48 | rewards = 0.0
49 | while True:
50 | # simulate a model prediction by sampling a random action
51 | action = np.random.randint(0, action_space)
52 |
53 | state, reward, terminated, truncated, info = env.step(action)
54 | rewards += reward
55 | pygameRender.render(info)
56 |
57 | if terminated or truncated:
58 | print(info['states'][-1].account_value, rewards)
59 | rewards = 0.0
60 | state, info = env.reset()
61 | pygameRender.reset()
--------------------------------------------------------------------------------
/experiments/testing_ppo_sinusoid_continuous.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import tensorflow as tf
4 | tf.get_logger().setLevel('ERROR')
5 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
6 | tf.config.experimental.set_memory_growth(gpu, True)
7 |
8 | from finrock.data_feeder import PdDataFeeder
9 | from finrock.trading_env import TradingEnv
10 | from finrock.render import PygameRender
11 |
12 |
13 | df = pd.read_csv('Datasets/random_sinusoid.csv')
14 | df = df[-1000:]
15 |
16 | model_path = "runs/1704798174"
17 |
18 | pd_data_feeder = PdDataFeeder.load_config(df, model_path)
19 | env = TradingEnv.load_config(pd_data_feeder, model_path)
20 |
21 | action_space = env.action_space
22 | input_shape = env.observation_space.shape
23 | pygameRender = PygameRender(frame_rate=120)
24 |
25 | agent = tf.keras.models.load_model(f'{model_path}/ppo_sinusoid_actor.h5')
26 |
27 | state, info = env.reset()
28 | pygameRender.render(info)
29 | rewards = 0.0
30 | while True:
31 |     # predict a continuous action with the trained actor (the trailing sigma output is dropped)
32 | action = agent.predict(np.expand_dims(state, axis=0), verbose=False)[0][:-1]
33 |
34 | state, reward, terminated, truncated, info = env.step(action)
35 | rewards += reward
36 | pygameRender.render(info)
37 |
38 | if terminated or truncated:
39 | print(rewards)
40 | for metric, value in info['metrics'].items():
41 | print(metric, value)
42 | state, info = env.reset()
43 | rewards = 0.0
44 | pygameRender.reset()
45 | pygameRender.render(info)
--------------------------------------------------------------------------------
/experiments/testing_ppo_sinusoid_discrete.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import tensorflow as tf
4 | tf.get_logger().setLevel('ERROR')
5 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
6 | tf.config.experimental.set_memory_growth(gpu, True)
7 |
8 | from finrock.data_feeder import PdDataFeeder
9 | from finrock.trading_env import TradingEnv
10 | from finrock.render import PygameRender
11 |
12 |
13 | df = pd.read_csv('Datasets/random_sinusoid.csv')
14 | df = df[-1000:]
15 |
16 | model_path = "runs/1704746665"
17 |
18 | pd_data_feeder = PdDataFeeder.load_config(df, model_path)
19 | env = TradingEnv.load_config(pd_data_feeder, model_path)
20 |
21 | action_space = env.action_space
22 | input_shape = env.observation_space.shape
23 | pygameRender = PygameRender(frame_rate=120)
24 |
25 | agent = tf.keras.models.load_model(f'{model_path}/ppo_sinusoid_actor.h5')
26 |
27 | state, info = env.reset()
28 | pygameRender.render(info)
29 | rewards = 0.0
30 | while True:
31 |     # predict action probabilities with the trained actor and act greedily
32 | prob = agent.predict(np.expand_dims(state, axis=0), verbose=False)[0]
33 | action = np.argmax(prob)
34 |
35 | state, reward, terminated, truncated, info = env.step(action)
36 | rewards += reward
37 | pygameRender.render(info)
38 |
39 | if terminated or truncated:
40 | print(rewards)
41 | for metric, value in info['metrics'].items():
42 | print(metric, value)
43 | state, info = env.reset()
44 | rewards = 0.0
45 | pygameRender.reset()
46 | pygameRender.render(info)
--------------------------------------------------------------------------------
/experiments/training_ppo_sinusoid_continuous.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import tensorflow as tf
4 | tf.get_logger().setLevel('ERROR')
5 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
6 | tf.config.experimental.set_memory_growth(gpu, True)
7 |
8 | from keras import layers, models
9 |
10 | from finrock.data_feeder import PdDataFeeder
11 | from finrock.trading_env import TradingEnv, ActionSpace
12 | from finrock.scalers import ZScoreScaler
13 | from finrock.reward import AccountValueChangeReward
14 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio
15 | from finrock.indicators import BolingerBands, RSI, PSAR, SMA, MACD
16 |
17 | from rockrl.utils.misc import MeanAverage
18 | from rockrl.utils.memory import MemoryManager
19 | from rockrl.tensorflow import PPOAgent
20 | from rockrl.utils.vectorizedEnv import VectorizedEnv
21 |
22 | df = pd.read_csv('Datasets/random_sinusoid.csv')
23 | df = df[:-1000] # leave 1000 for testing
24 |
25 | pd_data_feeder = PdDataFeeder(
26 | df,
27 | indicators = [
28 | BolingerBands(data=df, period=20, std=2),
29 | RSI(data=df, period=14),
30 | PSAR(data=df),
31 | MACD(data=df),
32 | SMA(data=df, period=7),
33 | ]
34 | )
35 |
36 | num_envs = 10
37 | env = VectorizedEnv(
38 | env_object = TradingEnv,
39 | num_envs = num_envs,
40 | data_feeder = pd_data_feeder,
41 | output_transformer = ZScoreScaler(),
42 | initial_balance = 1000.0,
43 | max_episode_steps = 1000,
44 | window_size = 50,
45 | reward_function = AccountValueChangeReward(),
46 | action_space = ActionSpace.CONTINUOUS,
47 | metrics = [
48 | DifferentActions(),
49 | AccountValue(),
50 | MaxDrawdown(),
51 | SharpeRatio(),
52 | ]
53 | )
54 |
55 | action_space = env.action_space
56 | input_shape = env.observation_space.shape
57 |
58 |
59 | def actor_model(input_shape, action_space):
60 | input = layers.Input(shape=input_shape, dtype=tf.float32)
61 | x = layers.Flatten()(input)
62 | x = layers.Dense(512, activation='elu')(x)
63 | x = layers.Dense(256, activation='elu')(x)
64 | x = layers.Dense(64, activation='elu')(x)
65 | x = layers.Dropout(0.2)(x)
66 | action = layers.Dense(action_space, activation="tanh")(x)
67 | sigma = layers.Dense(action_space)(x)
68 | sigma = layers.Dense(1, activation='sigmoid')(sigma)
69 | output = layers.concatenate([action, sigma]) # continuous action space
70 | return models.Model(inputs=input, outputs=output)
71 |
72 | def critic_model(input_shape):
73 | input = layers.Input(shape=input_shape, dtype=tf.float32)
74 | x = layers.Flatten()(input)
75 | x = layers.Dense(512, activation='elu')(x)
76 | x = layers.Dense(256, activation='elu')(x)
77 | x = layers.Dense(64, activation='elu')(x)
78 | x = layers.Dropout(0.2)(x)
79 | output = layers.Dense(1, activation=None)(x)
80 | return models.Model(inputs=input, outputs=output)
81 |
82 |
83 | agent = PPOAgent(
84 | actor = actor_model(input_shape, action_space),
85 | critic = critic_model(input_shape),
86 | optimizer=tf.keras.optimizers.Adam(learning_rate=0.00005),
87 | batch_size=128,
88 | lamda=0.95,
89 | kl_coeff=0.5,
90 | c2=0.01,
91 | writer_comment='ppo_sinusoid',
92 | action_space="continuous",
93 | )
94 | pd_data_feeder.save_config(agent.logdir)
95 | env.env.save_config(agent.logdir)
96 |
97 | memory = MemoryManager(num_envs=num_envs)
98 | meanAverage = MeanAverage(best_mean_score_episode=1000)
99 | states, infos = env.reset()
100 | rewards = 0.0
101 | while True:
102 | action, prob = agent.act(states)
103 |
104 | next_states, reward, terminated, truncated, infos = env.step(action)
105 | memory.append(states, action, reward, prob, terminated, truncated, next_states, infos)
106 | states = next_states
107 |
108 | for index in memory.done_indices():
109 | env_memory = memory[index]
110 | history = agent.train(env_memory)
111 | mean_reward = meanAverage(np.sum(env_memory.rewards))
112 |
113 | if meanAverage.is_best(agent.epoch):
114 | agent.save_models('ppo_sinusoid')
115 |
116 | if history['kl_div'] > 0.2 and agent.epoch > 1000:
117 | agent.reduce_learning_rate(0.995, verbose=False)
118 |
119 | info = env_memory.infos[-1]
120 | print(agent.epoch, np.sum(env_memory.rewards), mean_reward, info["metrics"]['account_value'], history['kl_div'])
121 | agent.log_to_writer(info['metrics'])
122 | states[index], infos[index] = env.reset(index=index)
123 |
124 | if agent.epoch >= 20000:
125 | break
126 |
127 | env.close()
128 | exit()
--------------------------------------------------------------------------------
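The continuous actor head above concatenates `action_space` tanh means with a single sigmoid sigma; the actual sampling happens inside rockrl's `PPOAgent`. The sketch below is only an illustration of how such a head can be turned into an action (the function name and the clipping choice are assumptions, not rockrl's API):

```python
import numpy as np

def sample_continuous_action(actor_output: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Split the actor head into per-dimension means and a shared sigma,
    then sample a Gaussian action around the means."""
    means, sigma = actor_output[:-1], actor_output[-1]
    action = rng.normal(loc=means, scale=sigma)
    # keep the sampled action inside the tanh output range
    return np.clip(action, -1.0, 1.0)

# actor output for a 2-dimensional action space: [mean_0, mean_1, sigma]
rng = np.random.default_rng(42)
action = sample_continuous_action(np.array([0.5, -0.2, 0.1]), rng)
print(action.shape)  # (2,)
```

This also explains the `[0][:-1]` slice in `testing_ppo_sinusoid_continuous.py`: at inference time only the means are kept and the sigma output is dropped.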
/experiments/training_ppo_sinusoid_discrete.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import tensorflow as tf
4 | tf.get_logger().setLevel('ERROR')
5 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
6 | tf.config.experimental.set_memory_growth(gpu, True)
7 |
8 | from keras import layers, models
9 |
10 | from finrock.data_feeder import PdDataFeeder
11 | from finrock.trading_env import TradingEnv
12 | from finrock.scalers import MinMaxScaler, ZScoreScaler
13 | from finrock.reward import SimpleReward, AccountValueChangeReward
14 | from finrock.metrics import DifferentActions, AccountValue, MaxDrawdown, SharpeRatio
15 | from finrock.indicators import BolingerBands, RSI, PSAR, SMA, MACD
16 |
17 | from rockrl.utils.misc import MeanAverage
18 | from rockrl.utils.memory import MemoryManager
19 | from rockrl.tensorflow import PPOAgent
20 | from rockrl.utils.vectorizedEnv import VectorizedEnv
21 |
22 | df = pd.read_csv('Datasets/random_sinusoid.csv')
23 | df = df[:-1000]
24 |
25 |
26 | pd_data_feeder = PdDataFeeder(
27 | df,
28 | indicators = [
29 | BolingerBands(data=df, period=20, std=2),
30 | RSI(data=df, period=14),
31 | PSAR(data=df),
32 | MACD(data=df),
33 | SMA(data=df, period=7),
34 | ]
35 | )
36 |
37 | num_envs = 10
38 | env = VectorizedEnv(
39 | env_object = TradingEnv,
40 | num_envs = num_envs,
41 | data_feeder = pd_data_feeder,
42 | output_transformer = ZScoreScaler(),
43 | initial_balance = 1000.0,
44 | max_episode_steps = 1000,
45 | window_size = 50,
46 | reward_function = AccountValueChangeReward(),
47 | metrics = [
48 | DifferentActions(),
49 | AccountValue(),
50 | MaxDrawdown(),
51 | SharpeRatio(),
52 | ]
53 | )
54 |
55 | action_space = env.action_space
56 | input_shape = env.observation_space.shape
57 |
58 | def actor_model(input_shape, action_space):
59 | input = layers.Input(shape=input_shape, dtype=tf.float32)
60 | x = layers.Flatten()(input)
61 | x = layers.Dense(512, activation='elu')(x)
62 | x = layers.Dense(256, activation='elu')(x)
63 | x = layers.Dense(64, activation='elu')(x)
64 | x = layers.Dropout(0.2)(x)
65 | output = layers.Dense(action_space, activation='softmax')(x) # discrete action space
66 | return models.Model(inputs=input, outputs=output)
67 |
68 | def critic_model(input_shape):
69 | input = layers.Input(shape=input_shape, dtype=tf.float32)
70 | x = layers.Flatten()(input)
71 | x = layers.Dense(512, activation='elu')(x)
72 | x = layers.Dense(256, activation='elu')(x)
73 | x = layers.Dense(64, activation='elu')(x)
74 | x = layers.Dropout(0.2)(x)
75 | output = layers.Dense(1, activation=None)(x)
76 | return models.Model(inputs=input, outputs=output)
77 |
78 | agent = PPOAgent(
79 | actor = actor_model(input_shape, action_space),
80 | critic = critic_model(input_shape),
81 | optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
82 | batch_size=128,
83 | lamda=0.95,
84 | kl_coeff=0.5,
85 | c2=0.01,
86 | writer_comment='ppo_sinusoid_discrete',
87 | )
88 |
89 | pd_data_feeder.save_config(agent.logdir)
90 | env.env.save_config(agent.logdir)
91 |
92 | memory = MemoryManager(num_envs=num_envs)
93 | meanAverage = MeanAverage(best_mean_score_episode=1000)
94 | states, infos = env.reset()
95 | rewards = 0.0
96 | while True:
97 | action, prob = agent.act(states)
98 |
99 | next_states, reward, terminated, truncated, infos = env.step(action)
100 | memory.append(states, action, reward, prob, terminated, truncated, next_states, infos)
101 | states = next_states
102 |
103 | for index in memory.done_indices():
104 | env_memory = memory[index]
105 | history = agent.train(env_memory)
106 | mean_reward = meanAverage(np.sum(env_memory.rewards))
107 |
108 | if meanAverage.is_best(agent.epoch):
109 | agent.save_models('ppo_sinusoid')
110 |
111 | if history['kl_div'] > 0.05 and agent.epoch > 1000:
112 | agent.reduce_learning_rate(0.995, verbose=False)
113 |
114 | info = env_memory.infos[-1]
115 | print(agent.epoch, np.sum(env_memory.rewards), mean_reward, info["metrics"]['account_value'], history['kl_div'])
116 | agent.log_to_writer(info['metrics'])
117 | states[index], infos[index] = env.reset(index=index)
118 |
119 | if agent.epoch >= 10000:
120 | break
121 |
122 | env.close()
123 | exit()
--------------------------------------------------------------------------------
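Both training loops shrink the learning rate by a fixed factor (`agent.reduce_learning_rate(0.995, ...)`) whenever the post-update KL divergence exceeds a threshold after a warm-up period. A minimal standalone restatement of that schedule, with `decay_lr_on_kl` as a hypothetical helper name:

```python
def decay_lr_on_kl(lr: float, kl_div: float, epoch: int,
                   kl_threshold: float = 0.05, factor: float = 0.995,
                   warmup_epochs: int = 1000) -> float:
    """After the warm-up period, shrink the learning rate whenever
    the policy update drifted too far from the previous policy."""
    if kl_div > kl_threshold and epoch > warmup_epochs:
        return lr * factor
    return lr

lr_after = decay_lr_on_kl(0.0001, kl_div=0.08, epoch=2000)  # decayed
lr_warm = decay_lr_on_kl(0.0001, kl_div=0.08, epoch=500)    # warm-up, unchanged
```

Note the discrete script uses a tighter threshold (0.05) than the continuous one (0.2), since the softmax policy tends to move faster per update.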
/finrock/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.5.0"
--------------------------------------------------------------------------------
/finrock/data_feeder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import importlib
4 | import pandas as pd
5 | from finrock.state import State
6 | from finrock.indicators import Indicator
7 |
8 |
9 | class PdDataFeeder:
10 | def __init__(
11 | self,
12 | df: pd.DataFrame,
13 | indicators: list = [],
14 | min: float = None,
15 | max: float = None,
16 | ) -> None:
17 | self._df = df
18 | self._min = min
19 | self._max = max
20 | self._indicators = indicators
21 | self._cache = {}
22 |
23 |         assert isinstance(self._df, pd.DataFrame), "df must be a pandas.DataFrame"
24 | assert 'timestamp' in self._df.columns, "df must have 'timestamp' column"
25 | assert 'open' in self._df.columns, "df must have 'open' column"
26 | assert 'high' in self._df.columns, "df must have 'high' column"
27 | assert 'low' in self._df.columns, "df must have 'low' column"
28 | assert 'close' in self._df.columns, "df must have 'close' column"
29 |
30 |         assert isinstance(self._indicators, list), "indicators must be a list"
31 |         assert all(isinstance(indicator, Indicator) for indicator in self._indicators), "indicators must be a list of Indicator objects"
32 |
33 | @property
34 | def __name__(self) -> str:
35 | return self.__class__.__name__
36 |
37 | @property
38 | def name(self) -> str:
39 | return self.__name__
40 |
41 | @property
42 | def min(self) -> float:
43 | return self._min or self._df['low'].min()
44 |
45 | @property
46 | def max(self) -> float:
47 | return self._max or self._df['high'].max()
48 |
49 | def __len__(self) -> int:
50 | return len(self._df)
51 |
52 | def __getitem__(self, idx: int, args=None) -> State:
53 | # Use cache to speed up training
54 | if idx in self._cache:
55 | return self._cache[idx]
56 |
57 | indicators = []
58 | for indicator in self._indicators:
59 | results = indicator(idx)
60 | if results is None:
61 | self._cache[idx] = None
62 | return None
63 |
64 | indicators.append(results)
65 |
66 | data = self._df.iloc[idx]
67 | state = State(
68 | timestamp=data['timestamp'],
69 | open=data['open'],
70 | high=data['high'],
71 | low=data['low'],
72 | close=data['close'],
73 | volume=data.get('volume', 0.0),
74 | indicators=indicators
75 | )
76 | self._cache[idx] = state
77 |
78 | return state
79 |
80 | def __iter__(self) -> State:
81 |         """ Create a generator that iterates over the sequence of states. """
82 | for index in range(len(self)):
83 | yield self[index]
84 |
85 | def save_config(self, path: str) -> None:
86 | config = {
87 | "indicators": [],
88 | "min": self.min,
89 | "max": self.max
90 | }
91 | for indicator in self._indicators:
92 | config["indicators"].append(indicator.config())
93 |
94 | # save config into json file
95 | with open(os.path.join(path, "PdDataFeeder.json"), 'w') as outfile:
96 | json.dump(config, outfile, indent=4)
97 |
98 | @staticmethod
99 |     def load_config(df, path: str) -> "PdDataFeeder":
100 | # load config from json file
101 | config_path = os.path.join(path, "PdDataFeeder.json")
102 | if not os.path.exists(config_path):
103 | raise Exception(f"PdDataFeeder Config file not found in {path}")
104 |
105 | with open(config_path) as json_file:
106 | config = json.load(json_file)
107 |
108 | _indicators = []
109 | for indicator in config["indicators"]:
110 | indicator_class = getattr(importlib.import_module(".indicators", package=__package__), indicator["name"])
111 | ind = indicator_class(data=df, **indicator)
112 | _indicators.append(ind)
113 |
114 | pdDataFeeder = PdDataFeeder(df=df, indicators=_indicators, min=config["min"], max=config["max"])
115 |
116 | return pdDataFeeder
--------------------------------------------------------------------------------
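`PdDataFeeder.__getitem__` memoises each `State` and returns `None` while any indicator is still warming up (its rolling window still produces NaNs). A toy stand-in illustrating that contract with plain pandas (`MiniFeeder` is hypothetical, not part of FinRock):

```python
import pandas as pd

class MiniFeeder:
    """Toy stand-in for PdDataFeeder: memoises rows and yields None
    while a rolling indicator is still warming up (NaN values)."""
    def __init__(self, df: pd.DataFrame, period: int = 3):
        self._df = df.copy()
        self._df['SMA'] = self._df['close'].rolling(period).mean()
        self._cache = {}

    def __getitem__(self, idx: int):
        if idx in self._cache:
            return self._cache[idx]
        row = self._df.iloc[idx]
        result = None if pd.isna(row['SMA']) else row.to_dict()
        self._cache[idx] = result
        return result

feeder = MiniFeeder(pd.DataFrame({'close': [1.0, 2.0, 3.0, 4.0]}))
print(feeder[0])         # None: a 3-period SMA needs three rows
print(feeder[3]['SMA'])  # 3.0
```

The `None` sentinel is what lets `TradingEnv` skip leading rows until every configured indicator has enough history.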
/finrock/indicators.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from .render import RenderOptions, RenderType, WindowType
4 |
5 | """ Implemented indicators:
6 | - SMA
7 | - Bolinger Bands
8 | - RSI
9 | - PSAR
10 | - MACD (Moving Average Convergence Divergence)
11 |
12 | TODO:
13 | - Commodity Channel Index (CCI)
14 | - Average Directional Index (ADX)
15 | """
16 |
17 |
18 | class Indicator:
19 | """ Base class for indicators
20 | """
21 | def __init__(
22 | self,
23 | data: pd.DataFrame,
24 | target_column: str='close',
25 | render_options: dict={},
26 | min: float=None,
27 | max: float=None,
28 | **kwargs
29 | ) -> None:
30 | self._data = data.copy()
31 | self._target_column = target_column
32 | self._custom_render_options = render_options
33 | self._render_options = render_options
34 | self._min = min # if min is not None else self._data[target_column].min()
35 | self._max = max # if max is not None else self._data[target_column].max()
36 | self.values = {}
37 |
38 |         assert isinstance(self._data, pd.DataFrame), "data must be a pandas.DataFrame"
39 | assert self._target_column in self._data.columns, f"data must have '{self._target_column}' column"
40 |
41 | self.compute()
42 | if not self._custom_render_options:
43 | self._render_options = self.default_render_options()
44 |
45 | @property
46 | def min(self):
47 | return self._min
48 |
49 | @min.setter
50 | def min(self, min: float):
51 | self._min = self._min or min
52 | if not self._custom_render_options:
53 | self._render_options = self.default_render_options()
54 |
55 | @property
56 | def max(self):
57 | return self._max
58 |
59 | @max.setter
60 | def max(self, max: float):
61 | self._max = self._max or max
62 | if not self._custom_render_options:
63 | self._render_options = self.default_render_options()
64 |
65 | @property
66 | def target_column(self):
67 | return self._target_column
68 |
69 | @property
70 | def __name__(self) -> str:
71 | return self.__class__.__name__
72 |
73 | @property
74 | def name(self):
75 | return self.__name__
76 |
77 | @property
78 | def names(self):
79 | return self._names
80 |
81 | def compute(self):
82 | raise NotImplementedError
83 |
84 | def default_render_options(self):
85 | return {}
86 |
87 | def render_options(self):
88 | return {name: option.copy() for name, option in self._render_options.items()}
89 |
90 | def __getitem__(self, index: int):
91 | row = self._data.iloc[index]
92 | for name in self.names:
93 | if pd.isna(row[name]):
94 | return None
95 |
96 | self.values[name] = row[name]
97 | if self._render_options.get(name):
98 | self._render_options[name].value = row[name]
99 |
100 | return self.serialise()
101 |
102 | def __call__(self, index: int):
103 | return self[index]
104 |
105 | def serialise(self):
106 | return {
107 | 'name': self.name,
108 | 'names': self.names,
109 | 'values': self.values.copy(),
110 | 'target_column': self.target_column,
111 | 'render_options': self.render_options(),
112 | 'min': self.min,
113 | 'max': self.max
114 | }
115 |
116 | def config(self):
117 | return {
118 | 'name': self.name,
119 | 'names': self.names,
120 | 'target_column': self.target_column,
121 | 'min': self.min,
122 | 'max': self.max
123 | }
124 |
125 |
126 |
127 | class SMA(Indicator):
128 | """ Trend indicator
129 |
130 | A simple moving average (SMA) calculates the average of a selected range of prices, usually closing prices, by the number
131 | of periods in that range.
132 |
133 | The SMA is a technical indicator for determining if an asset price will continue or reverse a bull or bear trend. It is
134 | calculated by summing up the closing prices of a stock over time and then dividing that total by the number of time periods
135 | being examined. Short-term averages respond quickly to changes in the price of the underlying, while long-term averages are
136 | slow to react.
137 |
138 | https://www.investopedia.com/terms/s/sma.asp
139 | """
140 | def __init__(
141 | self,
142 | data: pd.DataFrame,
143 | period: int=20,
144 | target_column: str='close',
145 | render_options: dict={},
146 | **kwargs
147 | ):
148 | self._period = period
149 | self._names = [f'SMA{period}']
150 | super().__init__(data, target_column, render_options, **kwargs)
151 | self.min = self._data[self._names[0]].min()
152 | self.max = self._data[self._names[0]].max()
153 |
154 | def default_render_options(self):
155 | return {name: RenderOptions(
156 | name=name,
157 | color=(100, 100, 255),
158 | window_type=WindowType.MAIN,
159 | render_type=RenderType.LINE,
160 | min=self.min,
161 | max=self.max
162 | ) for name in self._names}
163 |
164 | def compute(self):
165 | self._data[self.names[0]] = self._data[self.target_column].rolling(self._period).mean()
166 |
167 | def config(self):
168 | config = super().config()
169 | config['period'] = self._period
170 | return config
171 |
172 |
173 | class BolingerBands(Indicator):
174 | """ Volatility indicator
175 |
176 |     Bollinger Bands are a type of price envelope developed by John Bollinger. (Price envelopes define
177 | upper and lower price range levels.) Bollinger Bands are envelopes plotted at a standard deviation level above and
178 | below a simple moving average of the price. Because the distance of the bands is based on standard deviation, they
179 | adjust to volatility swings in the underlying price.
180 |
181 | Bollinger Bands use 2 parameters, Period and Standard Deviations, StdDev. The default values are 20 for period, and 2
182 | for standard deviations, although you may customize the combinations.
183 |
184 | Bollinger bands help determine whether prices are high or low on a relative basis. They are used in pairs, both upper
185 | and lower bands and in conjunction with a moving average. Further, the pair of bands is not intended to be used on its own.
186 | Use the pair to confirm signals given with other indicators.
187 | """
188 | def __init__(
189 | self,
190 | data: pd.DataFrame,
191 | period: int=20,
192 | std: int=2,
193 | target_column: str='close',
194 | render_options: dict={},
195 | **kwargs
196 | ):
197 | self._period = period
198 | self._std = std
199 | self._names = ['SMA', 'BB_up', 'BB_dn']
200 | super().__init__(data, target_column, render_options, **kwargs)
201 | self.min = self._data['BB_dn'].min()
202 | self.max = self._data['BB_up'].max()
203 |
204 | def compute(self):
205 | self._data['SMA'] = self._data[self.target_column].rolling(self._period).mean()
206 | self._data['BB_up'] = self._data['SMA'] + self._data[self.target_column].rolling(self._period).std() * self._std
207 | self._data['BB_dn'] = self._data['SMA'] - self._data[self.target_column].rolling(self._period).std() * self._std
208 |
209 | def default_render_options(self):
210 | return {name: RenderOptions(
211 | name=name,
212 | color=(100, 100, 255),
213 | window_type=WindowType.MAIN,
214 | render_type=RenderType.LINE,
215 | min=self.min,
216 | max=self.max
217 | ) for name in self._names}
218 |
219 | def config(self):
220 | config = super().config()
221 | config['period'] = self._period
222 | config['std'] = self._std
223 | return config
224 |
225 | class RSI(Indicator):
226 | """ Momentum indicator
227 |
228 | The Relative Strength Index (RSI), developed by J. Welles Wilder, is a momentum oscillator that measures the speed and
229 | change of price movements. The RSI oscillates between zero and 100. Traditionally the RSI is considered overbought when
230 | above 70 and oversold when below 30. Signals can be generated by looking for divergences and failure swings.
231 | RSI can also be used to identify the general trend.
232 | """
233 | def __init__(
234 | self,
235 | data: pd.DataFrame,
236 | period: int=14,
237 | target_column: str='close',
238 | render_options: dict={},
239 | min: float=0.0,
240 | max: float=100.0,
241 | **kwargs
242 | ):
243 | self._period = period
244 | self._names = ['RSI']
245 | super().__init__(data, target_column, render_options, min=min, max=max, **kwargs)
246 |
247 | def compute(self):
248 | delta = self._data[self.target_column].diff()
249 | up = delta.clip(lower=0)
250 | down = -1 * delta.clip(upper=0)
251 | ema_up = up.ewm(com=self._period-1, adjust=True, min_periods=self._period).mean()
252 | ema_down = down.ewm(com=self._period-1, adjust=True, min_periods=self._period).mean()
253 | rs = ema_up / ema_down
254 | self._data['RSI'] = 100 - (100 / (1 + rs))
255 |
256 | def default_render_options(self):
257 | custom_options = {
258 | "RSI0": 0,
259 | "RSI30": 30,
260 | "RSI70": 70,
261 | "RSI100": 100
262 | }
263 | options = {name: RenderOptions(
264 | name=name,
265 | color=(100, 100, 255),
266 | window_type=WindowType.SEPERATE,
267 | render_type=RenderType.LINE,
268 | min=self.min,
269 | max=self.max
270 | ) for name in self._names}
271 |
272 | for name, value in custom_options.items():
273 | options[name] = RenderOptions(
274 | name=name,
275 | color=(192, 192, 192),
276 | window_type=WindowType.SEPERATE,
277 | render_type=RenderType.LINE,
278 | min=self.min,
279 | max=self.max,
280 | value=value
281 | )
282 | return options
283 |
284 | def config(self):
285 | config = super().config()
286 | config['period'] = self._period
287 | return config
288 |
289 |
290 | class PSAR(Indicator):
291 | """ Parabolic Stop and Reverse (Parabolic SAR)
292 |
293 | The Parabolic Stop and Reverse, more commonly known as the
294 | Parabolic SAR,is a trend-following indicator developed by
295 | J. Welles Wilder. The Parabolic SAR is displayed as a single
296 | parabolic line (or dots) underneath the price bars in an uptrend,
297 | and above the price bars in a downtrend.
298 |
299 | https://school.stockcharts.com/doku.php?id=technical_indicators:parabolic_sar
300 | """
301 | def __init__(
302 | self,
303 | data: pd.DataFrame,
304 | step: float=0.02,
305 | max_step: float=0.2,
306 | target_column: str='close',
307 | render_options: dict={},
308 | **kwargs
309 | ):
310 | self._names = ['PSAR']
311 | self._step = step
312 | self._max_step = max_step
313 | super().__init__(data, target_column, render_options, **kwargs)
314 | self.min = self._data['PSAR'].min()
315 | self.max = self._data['PSAR'].max()
316 |
317 | def default_render_options(self):
318 | return {name: RenderOptions(
319 | name=name,
320 | color=(100, 100, 255),
321 | window_type=WindowType.MAIN,
322 | render_type=RenderType.DOT,
323 | min=self.min,
324 | max=self.max
325 | ) for name in self._names}
326 |
327 | def compute(self):
328 | high = self._data['high']
329 | low = self._data['low']
330 | close = self._data[self.target_column]
331 |
332 | up_trend = True
333 | acceleration_factor = self._step
334 | up_trend_high = high.iloc[0]
335 | down_trend_low = low.iloc[0]
336 |
337 | self._psar = close.copy()
338 | self._psar_up = pd.Series(index=self._psar.index, dtype="float64")
339 | self._psar_down = pd.Series(index=self._psar.index, dtype="float64")
340 |
341 | for i in range(2, len(close)):
342 | reversal = False
343 |
344 | max_high = high.iloc[i]
345 | min_low = low.iloc[i]
346 |
347 | if up_trend:
348 | self._psar.iloc[i] = self._psar.iloc[i - 1] + (
349 | acceleration_factor * (up_trend_high - self._psar.iloc[i - 1])
350 | )
351 |
352 | if min_low < self._psar.iloc[i]:
353 | reversal = True
354 | self._psar.iloc[i] = up_trend_high
355 | down_trend_low = min_low
356 | acceleration_factor = self._step
357 | else:
358 | if max_high > up_trend_high:
359 | up_trend_high = max_high
360 | acceleration_factor = min(
361 | acceleration_factor + self._step, self._max_step
362 | )
363 |
364 | low1 = low.iloc[i - 1]
365 | low2 = low.iloc[i - 2]
366 | if low2 < self._psar.iloc[i]:
367 | self._psar.iloc[i] = low2
368 | elif low1 < self._psar.iloc[i]:
369 | self._psar.iloc[i] = low1
370 | else:
371 | self._psar.iloc[i] = self._psar.iloc[i - 1] - (
372 | acceleration_factor * (self._psar.iloc[i - 1] - down_trend_low)
373 | )
374 |
375 | if max_high > self._psar.iloc[i]:
376 | reversal = True
377 | self._psar.iloc[i] = down_trend_low
378 | up_trend_high = max_high
379 | acceleration_factor = self._step
380 | else:
381 | if min_low < down_trend_low:
382 | down_trend_low = min_low
383 | acceleration_factor = min(
384 | acceleration_factor + self._step, self._max_step
385 | )
386 |
387 | high1 = high.iloc[i - 1]
388 | high2 = high.iloc[i - 2]
389 | if high2 > self._psar.iloc[i]:
390 |                     self._psar.iloc[i] = high2
391 | elif high1 > self._psar.iloc[i]:
392 | self._psar.iloc[i] = high1
393 |
394 | up_trend = up_trend != reversal # XOR
395 |
396 | if up_trend:
397 | self._psar_up.iloc[i] = self._psar.iloc[i]
398 | else:
399 | self._psar_down.iloc[i] = self._psar.iloc[i]
400 |
401 | # calculate psar indicator
402 | self._data['PSAR'] = self._psar
403 |
404 | def config(self):
405 | config = super().config()
406 | config['step'] = self._step
407 | config['max_step'] = self._max_step
408 | return config
409 |
410 |
411 | class MACD(Indicator):
412 | """ Moving Average Convergence Divergence (MACD)
413 | """
414 | def __init__(
415 | self,
416 | data: pd.DataFrame,
417 | fast_ma: int = 12,
418 | slow_ma: int = 26,
419 | histogram: int = 9,
420 | target_column: str='close',
421 | render_options: dict={},
422 | **kwargs
423 | ):
424 | self._fast_ma = fast_ma
425 | self._slow_ma = slow_ma
426 | self._histogram = histogram
427 | self._names = ['MACD', 'MACD_signal']
428 | super().__init__(data, target_column, render_options, **kwargs)
429 |         self.min = min(self._data['MACD'].min(), self._data['MACD_signal'].min())
430 |         self.max = max(self._data['MACD'].max(), self._data['MACD_signal'].max())
431 |
432 | def compute(self):
433 | # Calculate the Short Term Exponential Moving Average (EMA)
434 | short_ema = self._data[self.target_column].ewm(span=self._fast_ma, adjust=False).mean()
435 |
436 | # Calculate the Long Term Exponential Moving Average (EMA)
437 | long_ema = self._data[self.target_column].ewm(span=self._slow_ma, adjust=False).mean()
438 |
439 | # Calculate the Moving Average Convergence/Divergence (MACD)
440 | self._data["MACD"] = short_ema - long_ema
441 |
442 |         # Calculate the Signal Line (EMA of the MACD over the histogram period)
443 |         self._data["MACD_signal"] = self._data["MACD"].ewm(span=self._histogram, adjust=False).mean()
444 |
445 | def default_render_options(self):
446 | return {name: RenderOptions(
447 | name=name,
448 | color=(100, 100, 255),
449 | window_type=WindowType.SEPERATE,
450 | render_type=RenderType.LINE,
451 | min=self.min,
452 | max=self.max
453 | ) for name in self._names}
454 |
455 | def config(self):
456 | config = super().config()
457 | config['fast_ma'] = self._fast_ma
458 | config['slow_ma'] = self._slow_ma
459 | config['histogram'] = self._histogram
460 | return config
--------------------------------------------------------------------------------
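For quick experimentation outside the class hierarchy, the logic of `RSI.compute` above can be restated as a self-contained function (the name `wilder_rsi` is illustrative). A strictly rising series has no losses, so the gain/loss ratio saturates and the RSI pins at 100:

```python
import pandas as pd

def wilder_rsi(close: pd.Series, period: int = 14) -> pd.Series:
    """Same EMA-smoothed gain/loss ratio as RSI.compute."""
    delta = close.diff()
    up = delta.clip(lower=0)
    down = -1 * delta.clip(upper=0)
    ema_up = up.ewm(com=period - 1, adjust=True, min_periods=period).mean()
    ema_down = down.ewm(com=period - 1, adjust=True, min_periods=period).mean()
    rs = ema_up / ema_down
    return 100 - (100 / (1 + rs))

# a strictly rising series has no losses, so RSI saturates at 100
prices = pd.Series([float(p) for p in range(1, 11)])
rsi = wilder_rsi(prices, period=3)
print(rsi.iloc[-1])  # 100.0
```

The `min_periods=period` argument is what produces the leading NaNs that make `PdDataFeeder` return `None` until the indicator has warmed up.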
/finrock/metrics.py:
--------------------------------------------------------------------------------
1 | from .state import State
2 | import numpy as np
3 |
4 | """ Metrics are used to track and log information about the environment.
5 | possible metrics:
6 | + DifferentActions,
7 | + AccountValue,
8 | + MaxDrawdown,
9 | + SharpeRatio,
10 | - AverageProfit,
11 | - AverageLoss,
12 | - AverageTrade,
13 | - WinRate,
14 | - LossRate,
15 | - AverageWin,
16 | - AverageLoss,
17 | - AverageWinLossRatio,
18 | - AverageTradeDuration,
19 | - AverageTradeReturn,
20 | """
21 |
22 | class Metric:
23 | def __init__(self, name: str="metric") -> None:
24 | self.name = name
25 | self.reset()
26 |
27 | @property
28 | def __name__(self) -> str:
29 | return self.__class__.__name__
30 |
31 | def update(self, state: State):
32 | assert isinstance(state, State), f'state must be State, received: {type(state)}'
33 |
34 | return state
35 |
36 | @property
37 | def result(self):
38 | raise NotImplementedError
39 |
40 | def reset(self, prev_state: State=None):
41 | assert prev_state is None or isinstance(prev_state, State), f'prev_state must be None or State, received: {type(prev_state)}'
42 |
43 | return prev_state
44 |
45 |
46 | class DifferentActions(Metric):
47 | def __init__(self, name: str="different_actions") -> None:
48 | super().__init__(name=name)
49 |
50 | def update(self, state: State):
51 | super().update(state)
52 |
53 | if not self.prev_state:
54 | self.prev_state = state
55 | else:
56 | if state.allocation_percentage != self.prev_state.allocation_percentage:
57 | self.different_actions += 1
58 |
59 | self.prev_state = state
60 |
61 | @property
62 | def result(self):
63 | return self.different_actions
64 |
65 | def reset(self, prev_state: State=None):
66 | super().reset(prev_state)
67 |
68 | self.prev_state = prev_state
69 | self.different_actions = 0
70 |
71 |
72 | class AccountValue(Metric):
73 | def __init__(self, name: str="account_value") -> None:
74 | super().__init__(name=name)
75 |
76 | def update(self, state: State):
77 | super().update(state)
78 |
79 | self.account_value = state.account_value
80 |
81 | @property
82 | def result(self):
83 | return self.account_value
84 |
85 | def reset(self, prev_state: State=None):
86 | super().reset(prev_state)
87 |
88 | self.account_value = prev_state.account_value if prev_state else 0.0
89 |
90 |
91 | class MaxDrawdown(Metric):
92 | """ The Maximum Drawdown (MDD) is a measure of the largest peak-to-trough decline in the
93 | value of a portfolio or investment during a specific period
94 |
95 | The Maximum Drawdown Ratio represents the proportion of the peak value that was lost during
96 | the largest decline. It is a measure of the risk associated with a particular investment or
97 | portfolio. Investors and fund managers use the Maximum Drawdown and its ratio to assess the
98 | historical downside risk and potential losses that could be incurred.
99 | """
100 | def __init__(self, name: str="max_drawdown") -> None:
101 | super().__init__(name=name)
102 |
103 | def update(self, state: State):
104 | super().update(state)
105 |
106 |         # Use max to track the peak account value
107 | self.max_account_value = max(self.max_account_value, state.account_value)
108 |
109 | # Calculate drawdown
110 | drawdown = (state.account_value - self.max_account_value) / self.max_account_value
111 |
112 | # Update max drawdown if the current drawdown is greater
113 | self.max_drawdown = min(self.max_drawdown, drawdown)
114 |
115 | @property
116 | def result(self):
117 | return self.max_drawdown
118 |
119 | def reset(self, prev_state: State=None):
120 | super().reset(prev_state)
121 |
122 | self.max_account_value = prev_state.account_value if prev_state else 0.0
123 | self.max_drawdown = 0.0
124 |
125 |
126 | class SharpeRatio(Metric):
127 | """ The Sharpe Ratio, is a measure of the risk-adjusted performance of an investment or a portfolio.
128 | It helps investors evaluate the return of an investment relative to its risk.
129 |
130 | A higher Sharpe Ratio indicates a better risk-adjusted performance. Investors and portfolio managers
131 | often use the Sharpe Ratio to compare the risk-adjusted returns of different investments or portfolios.
132 | It allows them to assess whether the additional return earned by taking on additional risk is justified.
133 | """
134 | def __init__(self, ratio_days=365.25, name: str='sharpe_ratio'):
135 | self.ratio_days = ratio_days
136 | super().__init__(name=name)
137 |
138 | def update(self, state: State):
139 | super().update(state)
140 | time_difference_days = (state.date - self.prev_state.date).days
141 | if time_difference_days >= 1:
142 | self.daily_returns.append((state.account_value - self.prev_state.account_value) / self.prev_state.account_value)
143 | self.prev_state = state
144 |
145 | @property
146 | def result(self):
147 | if len(self.daily_returns) == 0:
148 | return 0.0
149 |
150 | mean = np.mean(self.daily_returns)
151 | std = np.std(self.daily_returns)
152 | if std == 0:
153 | return 0.0
154 |
155 | sharpe_ratio = mean / std * np.sqrt(self.ratio_days)
156 |
157 | return sharpe_ratio
158 |
159 | def reset(self, prev_state: State=None):
160 | super().reset(prev_state)
161 | self.prev_state = prev_state
162 | self.daily_returns = []
--------------------------------------------------------------------------------
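The two risk metrics above (`MaxDrawdown`, `SharpeRatio`) reduce to short numpy computations once the account-value series is available. A standalone sketch under the same conventions (drawdown reported as a negative fraction of the peak; annualization factor 365.25 matching `ratio_days`; function names here are illustrative):

```python
import numpy as np

def sharpe_ratio(daily_returns, ratio_days: float = 365.25) -> float:
    """Annualized Sharpe ratio of a series of daily returns."""
    returns = np.asarray(daily_returns, dtype=float)
    if returns.size == 0 or returns.std() == 0:
        return 0.0
    return float(returns.mean() / returns.std() * np.sqrt(ratio_days))

def max_drawdown(account_values) -> float:
    """Largest peak-to-trough decline, as a negative fraction of the running peak."""
    values = np.asarray(account_values, dtype=float)
    peaks = np.maximum.accumulate(values)  # running peak, like self.max_account_value
    return float(np.min((values - peaks) / peaks))

mdd = max_drawdown([1000, 1200, 900, 1100])  # trough 900 against peak 1200
```

Both functions mirror the class logic: the drawdown only ever measures declines from the highest value seen so far, and the Sharpe ratio degrades to 0.0 when there are no returns or no variance.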
/finrock/render.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from .state import State
3 |
4 | class RenderType(Enum):
5 | LINE = 0
6 | DOT = 1
7 |
8 | class WindowType(Enum):
9 | MAIN = 0
10 | SEPERATE = 1
11 |
12 | class RenderOptions:
13 | def __init__(
14 | self,
15 | name: str,
16 | color: tuple,
17 | window_type: WindowType,
18 | render_type: RenderType,
19 | min: float,
20 | max: float,
21 | value: float = None,
22 | ):
23 | self.name = name
24 | self.color = color
25 | self.window_type = window_type
26 | self.render_type = render_type
27 | self.min = min
28 | self.max = max
29 | self.value = value
30 |
31 | def copy(self):
32 | return RenderOptions(
33 | name=self.name,
34 | color=self.color,
35 | window_type=self.window_type,
36 | render_type=self.render_type,
37 | min=self.min,
38 | max=self.max,
39 | value=self.value
40 | )
41 |
42 | class ColorTheme:
43 | black = (0, 0, 0)
44 | white = (255, 255, 255)
45 | red = (255, 10, 0)
46 | lightblue = (100, 100, 255)
47 | green = (0, 240, 0)
48 |
49 | background = black
50 | up_candle = green
51 | down_candle = red
52 | wick = white
53 | text = white
54 | buy = green
55 | sell = red
56 | font = 'Noto Sans'
57 | font_ratio = 0.02
58 |
59 | class MainWindow:
60 | def __init__(
61 | self,
62 | width: int,
63 | height: int,
64 | top_offset: int,
65 | bottom_offset: int,
66 | window_size: int,
67 | candle_spacing,
68 | font_ratio: float=0.02,
69 | spacing_ratio: float=0.02,
70 | split_offset: int=0
71 | ):
72 | self.width = width
73 | self.height = height
74 | self.top_offset = top_offset
75 | self.bottom_offset = bottom_offset
76 | self.window_size = window_size
77 | self.candle_spacing = candle_spacing
78 | self.font_ratio = font_ratio
79 | self.spacing_ratio = spacing_ratio
80 | self.split_offset = split_offset
81 |
82 | self.seperate_window_ratio = 0.15
83 |
84 | @property
85 | def font_size(self):
86 | return int(self.height * self.font_ratio)
87 |
88 | @property
89 | def candle_width(self):
90 | return self.width // self.window_size - self.candle_spacing
91 |
92 | @property
93 | def chart_height(self):
94 | return self.height - (2 * self.top_offset + self.bottom_offset)
95 |
96 | @property
97 | def spacing(self):
98 | return int(self.height * self.spacing_ratio)
99 |
100 | @property
101 | def screen_shape(self):
102 | return (self.width, self.height)
103 |
104 | @screen_shape.setter
105 | def screen_shape(self, value: tuple):
106 | self.width, self.height = value
107 |
108 | def map_price_to_window(self, price: float, max_low: float, max_high: float):
109 | max_range = max_high - max_low
110 | height = self.chart_height - self.split_offset - self.bottom_offset - self.top_offset * 2
111 | value = int(height - (price - max_low) / max_range * height) + self.top_offset
112 | return value
113 |
114 | def map_to_seperate_window(self, value: float, min: float, max: float):
115 | self.split_offset = int(self.height * self.seperate_window_ratio)
116 | max_range = max - min
117 | new_value = int(self.split_offset - (value - min) / max_range * self.split_offset)
118 | height = self.chart_height - self.split_offset + new_value
119 | return height
120 |
121 |
122 | class PygameRender:
123 | def __init__(
124 | self,
125 | window_size: int=100,
126 | screen_width: int=1440,
127 | screen_height: int=1080,
128 | top_offset: int=25,
129 | bottom_offset: int=25,
130 | candle_spacing: int=1,
131 | color_theme = ColorTheme(),
132 | frame_rate: int=30,
133 | render_balance: bool=True,
134 | ):
135 | # pygame window settings
136 | self.screen_width = screen_width
137 | self.screen_height = screen_height
138 | self.top_offset = top_offset
139 | self.bottom_offset = bottom_offset
140 | self.candle_spacing = candle_spacing
141 | self.window_size = window_size
142 | self.color_theme = color_theme
143 | self.frame_rate = frame_rate
144 | self.render_balance = render_balance
145 |
146 | self.mainWindow = MainWindow(
147 | width=self.screen_width,
148 | height=self.screen_height,
149 | top_offset=self.top_offset,
150 | bottom_offset=self.bottom_offset,
151 | window_size=self.window_size,
152 | candle_spacing=self.candle_spacing,
153 | font_ratio=self.color_theme.font_ratio
154 | )
155 |
156 | self._states = []
157 |
158 | try:
159 | import pygame
160 | self.pygame = pygame
161 | except ImportError:
162 | raise ImportError('Please install pygame (pip install pygame)')
163 |
164 | self.pygame.init()
165 | self.pygame.display.init()
166 | self.window = self.pygame.display.set_mode(self.mainWindow.screen_shape, self.pygame.RESIZABLE)
167 | self.clock = self.pygame.time.Clock()
168 |
169 | def reset(self):
170 | self._states = []
171 |
172 | def _prerender(func):
173 | """ Decorator for input data validation and pygame window rendering"""
174 | def wrapper(self, info: dict, rgb_array: bool=False):
175 | self._states += info.get('states', [])
176 |
177 | if not self._states or not bool(self.window._pixels_address):
178 | return
179 |
180 | for event in self.pygame.event.get():
181 | if event.type == self.pygame.QUIT:
182 | self.pygame.quit()
183 | return
184 |
185 | if event.type == self.pygame.VIDEORESIZE:
186 | self.mainWindow.screen_shape = (event.w, event.h)
187 |
188 | # pause if spacebar is pressed
189 | if event.type == self.pygame.KEYDOWN:
190 | if event.key == self.pygame.K_SPACE:
191 | print('Paused')
192 | while True:
193 | event = self.pygame.event.wait()
194 | if event.type == self.pygame.KEYDOWN:
195 | if event.key == self.pygame.K_SPACE:
196 | print('Unpaused')
197 | break
198 | if event.type == self.pygame.QUIT:
199 | self.pygame.quit()
200 | return
201 |
202 | self.mainWindow.screen_shape = self.pygame.display.get_surface().get_size()
203 |
204 |
205 | canvas = func(self, info)
206 | canvas = self.pygame.transform.scale(canvas, self.mainWindow.screen_shape)
207 | # The following line copies our drawings from `canvas` to the visible window
208 | self.window.blit(canvas, canvas.get_rect())
209 | self.pygame.display.update()
210 | self.clock.tick(self.frame_rate)
211 |
212 | if rgb_array:
213 | return self.pygame.surfarray.array3d(canvas)
214 |
215 | return wrapper
216 |
217 | def render_indicators(self, state: State, canvas: object, candle_offset: int, max_low: float, max_high: float):
218 | # connect last 2 points with a line
219 | for i, indicator in enumerate(state.indicators):
220 | for name, render_option in indicator["render_options"].items():
221 |
222 | index = self._states.index(state)
223 | if not index:
224 | return
225 | last_state = self._states[index - 1]
226 |
227 | if render_option.render_type == RenderType.LINE:
228 | prev_render_option = last_state.indicators[i]["render_options"][name]
229 | if render_option.window_type == WindowType.MAIN:
230 |
231 | cur_value_map = self.mainWindow.map_price_to_window(render_option.value, max_low, max_high)
232 | prev_value_map = self.mainWindow.map_price_to_window(prev_render_option.value, max_low, max_high)
233 |
234 | elif render_option.window_type == WindowType.SEPERATE:
235 |
236 | cur_value_map = self.mainWindow.map_to_seperate_window(render_option.value, render_option.min, render_option.max)
237 | prev_value_map = self.mainWindow.map_to_seperate_window(prev_render_option.value, prev_render_option.min, prev_render_option.max)
238 |
239 | self.pygame.draw.line(canvas, render_option.color,
240 | (candle_offset - self.mainWindow.candle_width / 2, prev_value_map),
241 | (candle_offset + self.mainWindow.candle_width / 2, cur_value_map))
242 |
243 | elif render_option.render_type == RenderType.DOT:
244 | if render_option.window_type == WindowType.MAIN:
245 | self.pygame.draw.circle(canvas, render_option.color,
246 | (candle_offset, self.mainWindow.map_price_to_window(render_option.value, max_low, max_high)), 2)
247 |                     elif render_option.window_type == WindowType.SEPERATE:
248 | raise NotImplementedError('Seperate window for indicators is not implemented yet')
249 |
250 | def render_candle(self, state: State, canvas: object, candle_offset: int, max_low: float, max_high: float, font: object):
251 |         assert isinstance(state, State), f'state must be a State object, received: {type(state)}'
252 |
253 | # Calculate candle coordinates
254 | candle_y_open = self.mainWindow.map_price_to_window(state.open, max_low, max_high)
255 | candle_y_close = self.mainWindow.map_price_to_window(state.close, max_low, max_high)
256 | candle_y_high = self.mainWindow.map_price_to_window(state.high, max_low, max_high)
257 | candle_y_low = self.mainWindow.map_price_to_window(state.low, max_low, max_high)
258 |
259 | # Determine candle color
260 | if state.open < state.close:
261 | # up candle
262 | candle_color = self.color_theme.up_candle
263 | candle_body_y = candle_y_close
264 | candle_body_height = candle_y_open - candle_y_close
265 | else:
266 | # down candle
267 | candle_color = self.color_theme.down_candle
268 | candle_body_y = candle_y_open
269 | candle_body_height = candle_y_close - candle_y_open
270 |
271 | # Draw candlestick wicks
272 | self.pygame.draw.line(canvas, self.color_theme.wick,
273 | (candle_offset + self.mainWindow.candle_width // 2, candle_y_high),
274 | (candle_offset + self.mainWindow.candle_width // 2, candle_y_low))
275 |
276 | # Draw candlestick body
277 | self.pygame.draw.rect(canvas, candle_color, (candle_offset, candle_body_y, self.mainWindow.candle_width, candle_body_height))
278 |
279 | # Compare with previous state to determine whether buy or sell action was taken and draw arrow
280 | index = self._states.index(state)
281 | if index > 0:
282 | last_state = self._states[index - 1]
283 |
284 | if last_state.allocation_percentage < state.allocation_percentage:
285 | # buy
286 | candle_y_low = self.mainWindow.map_price_to_window(last_state.low, max_low, max_high)
287 | self.pygame.draw.polygon(canvas, self.color_theme.buy, [
288 | (candle_offset - self.mainWindow.candle_width / 2, candle_y_low + self.mainWindow.spacing / 2),
289 | (candle_offset - self.mainWindow.candle_width * 0.1, candle_y_low + self.mainWindow.spacing),
290 | (candle_offset - self.mainWindow.candle_width * 0.9, candle_y_low + self.mainWindow.spacing)
291 | ])
292 |
293 | # add account_value label bellow candle
294 | if self.render_balance:
295 | text = str(int(last_state.account_value))
296 | buy_label = font.render(text, True, self.color_theme.text)
297 | label_width, label_height = font.size(text)
298 | canvas.blit(buy_label, (candle_offset - (self.mainWindow.candle_width + label_width) / 2, candle_y_low + self.mainWindow.spacing))
299 |
300 | elif last_state.allocation_percentage > state.allocation_percentage:
301 | # sell
302 | candle_y_high = self.mainWindow.map_price_to_window(last_state.high, max_low, max_high)
303 | self.pygame.draw.polygon(canvas, self.color_theme.sell, [
304 | (candle_offset - self.mainWindow.candle_width / 2, candle_y_high - self.mainWindow.spacing / 2),
305 | (candle_offset - self.mainWindow.candle_width * 0.1, candle_y_high - self.mainWindow.spacing),
306 | (candle_offset - self.mainWindow.candle_width * 0.9, candle_y_high - self.mainWindow.spacing)
307 | ])
308 |
309 | # add account_value label above candle
310 | if self.render_balance:
311 | text = str(int(last_state.account_value))
312 | sell_label = font.render(text, True, self.color_theme.text)
313 | label_width, label_height = font.size(text)
314 | canvas.blit(sell_label, (candle_offset - (self.mainWindow.candle_width + label_width) / 2, candle_y_high - self.mainWindow.spacing - label_height))
315 |
316 | @_prerender
317 | def render(self, info: dict):
318 | canvas = self.pygame.Surface(self.mainWindow.screen_shape)
319 | canvas.fill(self.color_theme.background)
320 |
321 | max_high = max([state.high for state in self._states[-self.window_size:]])
322 | max_low = min([state.low for state in self._states[-self.window_size:]])
323 |
324 | candle_offset = self.candle_spacing
325 |
326 | # Set font for labels
327 | font = self.pygame.font.SysFont(self.color_theme.font, self.mainWindow.font_size)
328 |
329 | for state in self._states[-self.window_size:]:
330 |
331 | # draw indicators
332 | self.render_indicators(state, canvas, candle_offset, max_low, max_high)
333 |
334 | # draw candle
335 | self.render_candle(state, canvas, candle_offset, max_low, max_high, font)
336 |
337 | # Move to the next candle
338 | candle_offset += self.mainWindow.candle_width + self.candle_spacing
339 |
340 | # Draw max and min ohlc values on the chart
341 | label_width, label_height = font.size(str(max_low))
342 | label_y_low = font.render(str(max_low), True, self.color_theme.text)
343 | canvas.blit(label_y_low, (self.candle_spacing + 5, self.mainWindow.height - label_height * 2))
344 |
345 |         label_width, label_height = font.size(str(max_high))
346 | label_y_high = font.render(str(max_high), True, self.color_theme.text)
347 | canvas.blit(label_y_high, (self.candle_spacing + 5, label_height))
348 |
349 | return canvas
--------------------------------------------------------------------------------
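`map_price_to_window` inverts the y-axis because pygame's origin is top-left, so higher prices must map to smaller y-coordinates. The core mapping can be checked in isolation with this simplified sketch (offset bookkeeping from the real method is dropped; the function name and constants are illustrative):

```python
def map_price_to_y(price: float, low: float, high: float, height: int, top_offset: int = 0) -> int:
    """Map a price in [low, high] to a screen y-coordinate; higher prices land nearer the top."""
    fraction = (price - low) / (high - low)  # 0.0 at low, 1.0 at high
    return int(height - fraction * height) + top_offset

y_low = map_price_to_y(10.0, 10.0, 20.0, 100)   # lowest price -> bottom of the chart
y_high = map_price_to_y(20.0, 10.0, 20.0, 100)  # highest price -> top of the chart
```

The real method additionally subtracts `split_offset` and the top/bottom offsets from the usable height so candles do not overlap the separate indicator window or the labels.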
/finrock/reward.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .state import Observations
3 |
4 | class Reward:
5 | def __init__(self) -> None:
6 | pass
7 |
8 | @property
9 | def __name__(self) -> str:
10 | return self.__class__.__name__
11 |
12 | def __call__(self, observations: Observations) -> float:
13 | raise NotImplementedError
14 |
15 | def reset(self, observations: Observations):
16 | pass
17 |
18 |
19 | class SimpleReward(Reward):
20 | def __init__(self) -> None:
21 | super().__init__()
22 |
23 | def __call__(self, observations: Observations) -> float:
24 |         assert isinstance(observations, Observations), "observations must be an instance of Observations"
25 |
26 | last_state, next_state = observations[-2:]
27 |
28 | # buy
29 | if next_state.allocation_percentage > last_state.allocation_percentage:
30 | # check whether it was good or bad to buy
31 | order_size = next_state.allocation_percentage - last_state.allocation_percentage
32 | reward = (next_state.close - last_state.close) / last_state.close * order_size
33 | hold_reward = (next_state.close - last_state.close) / last_state.close * last_state.allocation_percentage
34 | reward += hold_reward
35 |
36 | # sell
37 | elif next_state.allocation_percentage < last_state.allocation_percentage:
38 | # check whether it was good or bad to sell
39 | order_size = last_state.allocation_percentage - next_state.allocation_percentage
40 | reward = -1 * (next_state.close - last_state.close) / last_state.close * order_size
41 | hold_reward = (next_state.close - last_state.close) / last_state.close * next_state.allocation_percentage
42 | reward += hold_reward
43 |
44 | # hold
45 | else:
46 | # check whether it was good or bad to hold
47 | ratio = -1 if not last_state.allocation_percentage else last_state.allocation_percentage
48 | reward = (next_state.close - last_state.close) / last_state.close * ratio
49 |
50 | return reward
51 |
52 | class AccountValueChangeReward(Reward):
53 | def __init__(self) -> None:
54 | super().__init__()
55 |         self.ratio_days = 365.25
56 |
57 | def reset(self, observations: Observations):
58 | super().reset(observations)
59 | self.returns = []
60 |
61 | def __call__(self, observations: Observations) -> float:
62 |         assert isinstance(observations, Observations), "observations must be an instance of Observations"
63 |
64 | last_state, next_state = observations[-2:]
65 | reward = (next_state.account_value - last_state.account_value) / last_state.account_value
66 |
67 | return reward
--------------------------------------------------------------------------------
/finrock/scalers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | np.seterr(all="ignore")
3 | import warnings
4 | from .state import Observations
5 |
6 |
7 | class Scaler:
8 | def __init__(self):
9 | pass
10 |
11 | def transform(self, observations: Observations) -> np.ndarray:
12 | raise NotImplementedError
13 |
14 | def __call__(self, observations) -> np.ndarray:
15 |         assert isinstance(observations, Observations), "observations must be an instance of Observations"
16 | return self.transform(observations)
17 |
18 | @property
19 | def __name__(self) -> str:
20 | return self.__class__.__name__
21 |
22 | @property
23 | def name(self) -> str:
24 | return self.__name__
25 |
26 |
27 | class MinMaxScaler(Scaler):
28 | def __init__(self, min: float, max: float):
29 | super().__init__()
30 | self._min = min
31 | self._max = max
32 |
33 | def transform(self, observations: Observations) -> np.ndarray:
34 | transformed_data = []
35 | for state in observations:
36 | data = []
37 | for name in ['open', 'high', 'low', 'close']:
38 | value = getattr(state, name)
39 | transformed_value = (value - self._min) / (self._max - self._min)
40 | data.append(transformed_value)
41 |
42 | data.append(state.allocation_percentage)
43 |
44 | # append scaled indicators
45 | for indicator in state.indicators:
46 | for value in indicator["values"].values():
47 | transformed_value = (value - indicator["min"]) / (indicator["max"] - indicator["min"])
48 | data.append(transformed_value)
49 |
50 | transformed_data.append(data)
51 |
52 | results = np.array(transformed_data)
53 |
54 | return results
55 |
56 |
57 | class ZScoreScaler(Scaler):
58 | def __init__(self):
59 | super().__init__()
60 | warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in reduce")
61 |
62 | def transform(self, observations: Observations) -> np.ndarray:
63 | full_data = []
64 | for state in observations:
65 | data = [getattr(state, name) for name in ['open', 'high', 'low', 'close', 'allocation_percentage']]
66 | data += [value for indicator in state.indicators for value in indicator["values"].values()]
67 | full_data.append(data)
68 |
69 | results = np.array(full_data)
70 |
71 |         # replace NaNs produced by zero-division (e.g. when a feature such as allocation_percentage never changes)
72 | returns = np.nan_to_num(np.diff(results, axis=0) / results[:-1])
73 |
74 | z_scores = np.nan_to_num((returns - np.mean(returns, axis=0)) / np.std(returns, axis=0))
75 |
76 | return z_scores
--------------------------------------------------------------------------------
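`ZScoreScaler` standardizes per-step returns rather than raw prices, so the output is scale-free across instruments. The transformation is equivalent to this numpy sketch (one column per feature, one row per step; the function name and sample data are illustrative):

```python
import numpy as np

def zscore_returns(series: np.ndarray) -> np.ndarray:
    """Per-column z-scores of the step-over-step returns of `series`."""
    # step returns; NaNs from division by zero are mapped to 0 as in the scaler
    returns = np.nan_to_num(np.diff(series, axis=0) / series[:-1])
    std = np.std(returns, axis=0)
    return np.nan_to_num((returns - np.mean(returns, axis=0)) / std)

prices = np.array([[100.0], [110.0], [99.0], [104.0]])
z = zscore_returns(prices)  # one fewer row than the input, zero mean, unit variance
```

Note the output has `window_size - 1` rows because of the `diff`, which is why the scaler's result shape differs from `MinMaxScaler`'s.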
/finrock/state.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import numpy as np
3 | from datetime import datetime
4 |
5 | class State:
6 | def __init__(
7 | self,
8 | timestamp: str,
9 | open: float,
10 | high: float,
11 | low: float,
12 | close: float,
13 | volume: float=0.0,
14 |         indicators: list=None
15 |     ):
16 |         self.timestamp = timestamp
17 |         self.open = open
18 |         self.high = high
19 |         self.low = low
20 |         self.close = close
21 |         self.volume = volume
22 |         self.indicators = indicators if indicators is not None else []  # avoid a shared mutable default
23 |
24 | try:
25 | self.date = datetime.strptime(timestamp, '%Y-%m-%d %H:%M:%S')
26 | except ValueError:
27 | raise ValueError(f'received invalid timestamp date format: {timestamp}, expected: YYYY-MM-DD HH:MM:SS')
28 |
29 | self._balance = 0.0 # balance in cash
30 | self._assets = 0.0 # balance in assets
31 | self._allocation_percentage = 0.0 # percentage of assets allocated to this state
32 |
33 | @property
34 | def balance(self):
35 | return self._balance
36 |
37 | @balance.setter
38 | def balance(self, value: float):
39 | self._balance = value
40 |
41 | @property
42 | def assets(self):
43 | return self._assets
44 |
45 | @assets.setter
46 | def assets(self, value: float):
47 | self._assets = value
48 |
49 | @property
50 | def account_value(self):
51 | return self.balance + self.assets * self.close
52 |
53 | @property
54 | def allocation_percentage(self):
55 | return self._allocation_percentage
56 |
57 | @allocation_percentage.setter
58 | def allocation_percentage(self, value: float):
59 | assert 0.0 <= value <= 1.0, f'allocation_percentage value must be between 0.0 and 1.0, received: {value}'
60 | self._allocation_percentage = value
61 |
62 |
63 | class Observations:
64 | def __init__(
65 | self,
66 | window_size: int,
67 |         observations: typing.List[State]=None,
68 |     ):
69 |         self._observations = observations if observations is not None else []  # avoid a shared mutable default
70 | self._window_size = window_size
71 |
72 |         assert isinstance(self._observations, list), "observations must be a list"
73 |         assert len(self._observations) <= self._window_size, f'observations length must be <= window_size, received: {len(self._observations)}'
74 |         assert all(isinstance(observation, State) for observation in self._observations), "observations must be a list of State objects"
75 |
76 | def __len__(self) -> int:
77 | return len(self._observations)
78 |
79 | @property
80 | def window_size(self) -> int:
81 | return self._window_size
82 |
83 | @property
84 | def observations(self) -> typing.List[State]:
85 | return self._observations
86 |
87 | @property
88 | def full(self) -> bool:
89 | return len(self._observations) == self._window_size
90 |
91 | def __getitem__(self, idx: int) -> State:
92 | try:
93 | return self._observations[idx]
94 | except IndexError:
95 | raise IndexError(f'index out of range: {idx}, observations length: {len(self._observations)}')
96 |
97 | def __iter__(self) -> State:
98 | """ Create a generator that iterate over the Sequence."""
99 | for index in range(len(self)):
100 | yield self[index]
101 |
102 | def reset(self) -> None:
103 | self._observations = []
104 |
105 | def append(self, state: State) -> None:
106 | # state should be State object or None
107 | assert isinstance(state, State) or state is None, "state must be a State object or None"
108 | self._observations.append(state)
109 |
110 | if len(self._observations) > self._window_size:
111 | self._observations.pop(0)
112 |
113 | @property
114 | def close(self) -> np.ndarray:
115 | return np.array([state.close for state in self._observations])
116 |
117 | @property
118 | def high(self) -> np.ndarray:
119 | return np.array([state.high for state in self._observations])
120 |
121 | @property
122 | def low(self) -> np.ndarray:
123 | return np.array([state.low for state in self._observations])
124 |
125 | @property
126 | def open(self) -> np.ndarray:
127 | return np.array([state.open for state in self._observations])
128 |
129 | @property
130 | def allocation_percentage(self) -> np.ndarray:
131 | return np.array([state.allocation_percentage for state in self._observations])
132 |
133 | @property
134 | def volume(self) -> np.ndarray:
135 | return np.array([state.volume for state in self._observations])
--------------------------------------------------------------------------------
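`Observations.append` implements a fixed-size sliding window: the new state is appended and, once the list exceeds `window_size`, the oldest entry is dropped from the front. The behavior is the same as a `collections.deque` with `maxlen` (shown here with plain integers standing in for `State` objects):

```python
from collections import deque

window_size = 3
window = deque(maxlen=window_size)  # equivalent to append + pop(0) in Observations.append
for step in range(5):
    window.append(step)

# only the most recent `window_size` items survive
contents = list(window)
```

A deque evicts in O(1) where `list.pop(0)` is O(n); for the small windows used here the difference is negligible, which is presumably why the plain list was kept.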
/finrock/trading_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import typing
4 | import importlib
5 | import numpy as np
6 |
7 | from enum import Enum
8 | from .state import State, Observations
9 | from .data_feeder import PdDataFeeder
10 | from .reward import SimpleReward
11 |
12 | class ActionSpace(Enum):
13 | DISCRETE = 3
14 | CONTINUOUS = 2
15 |
16 | class TradingEnv:
17 | def __init__(
18 | self,
19 | data_feeder: PdDataFeeder,
20 | output_transformer: typing.Callable = None,
21 | initial_balance: float = 1000.0,
22 | max_episode_steps: int = None,
23 | window_size: int = 50,
24 | reward_function: typing.Callable = SimpleReward(),
25 | action_space: ActionSpace = ActionSpace.DISCRETE,
26 | metrics: typing.List[typing.Callable] = [],
27 | order_fee_percent: float = 0.001
28 | ) -> None:
29 | self._data_feeder = data_feeder
30 | self._output_transformer = output_transformer
31 | self._initial_balance = initial_balance
32 | self._max_episode_steps = max_episode_steps if max_episode_steps is not None else len(data_feeder)
33 | self._window_size = window_size
34 | self._reward_function = reward_function
35 | self._metrics = metrics
36 | self._order_fee_percent = order_fee_percent
37 |
38 | self._observations = Observations(window_size=window_size)
39 | self._observation_space = np.zeros(self.reset()[0].shape)
40 | self._action_space = action_space
41 | self.fee_ratio = 1 - self._order_fee_percent
42 |
43 | @property
44 | def action_space(self):
45 | return self._action_space.value
46 |
47 | @property
48 | def observation_space(self):
49 | return self._observation_space
50 |
51 | def _get_obs(self, index: int, balance: float=None) -> State:
52 | next_state = self._data_feeder[index]
53 | if next_state is None:
54 | return None
55 |
56 | if balance is not None:
57 | next_state.balance = balance
58 |
59 | return next_state
60 |
61 | def _get_terminated(self):
62 | return False
63 |
64 | def _take_action(self, action_pred: typing.Union[int, np.ndarray]) -> typing.Tuple[int, float]:
65 | """
66 | """
67 | # validate action is in range
68 |
69 | if isinstance(action_pred, np.ndarray):
70 | order_size = np.clip(action_pred[1], 0, 1)
71 | order_size = np.around(order_size, decimals=2)
72 |             action = min(int((np.clip(action_pred[0], -1, 1) + 1) * 1.5), 2)  # map [-1, 1] to actions {0, 1, 2}; clamp the 1.0 edge case
73 | elif action_pred in [0, 1, 2]:
74 | order_size = 1.0
75 | action = action_pred
76 |             assert action in range(self._action_space.value), f'action must be in range {self._action_space.value}, received: {action}'
77 | else:
78 |             raise ValueError(f'invalid action type: {type(action_pred)}')
79 |
80 |
81 | # get last state and next state
82 | last_state, next_state = self._observations[-2:]
83 |
84 | # modify action to hold (0) if we are out of balance
85 | if action == 2 and last_state.allocation_percentage == 1.0:
86 | action = 0
87 |
88 | # modify action to hold (0) if we are out of assets
89 | elif action == 1 and last_state.allocation_percentage == 0.0:
90 | action = 0
91 |
92 | if order_size == 0:
93 | action = 0
94 |
95 | if action == 2: # buy
96 | buy_order_size = order_size
97 | next_state.allocation_percentage = last_state.allocation_percentage + (1 - last_state.allocation_percentage) * buy_order_size
98 | next_state.assets = last_state.assets + (last_state.balance * buy_order_size / last_state.close) * self.fee_ratio
99 | next_state.balance = last_state.balance - (last_state.balance * buy_order_size) * self.fee_ratio
100 |
101 | elif action == 1: # sell
102 | sell_order_size = order_size
103 | next_state.allocation_percentage = last_state.allocation_percentage - last_state.allocation_percentage * sell_order_size
104 | next_state.balance = last_state.balance + (last_state.assets * sell_order_size * last_state.close) * self.fee_ratio
105 | next_state.assets = last_state.assets - (last_state.assets * sell_order_size) * self.fee_ratio
106 |
107 | else: # hold
108 | next_state.allocation_percentage = last_state.allocation_percentage
109 | next_state.assets = last_state.assets
110 | next_state.balance = last_state.balance
111 |
112 | if next_state.allocation_percentage > 1.0:
113 | raise ValueError(f'next_state.allocation_percentage > 1.0: {next_state.allocation_percentage}')
114 | elif next_state.allocation_percentage < 0.0:
115 | raise ValueError(f'next_state.allocation_percentage < 0.0: {next_state.allocation_percentage}')
116 |
117 | return action, order_size
118 |
119 | @property
120 | def metrics(self):
121 | return self._metrics
122 |
123 | def _metricsHandler(self, observation: State):
124 | metrics = {}
125 | # Loop through metrics and update
126 | for metric in self._metrics:
127 | metric.update(observation)
128 | metrics[metric.name] = metric.result
129 |
130 | return metrics
131 |
132 | def step(self, action: int) -> typing.Tuple[State, float, bool, bool, dict]:
133 |
134 | index = self._env_step_indexes.pop(0)
135 |
136 | observation = self._get_obs(index)
137 | # update observations object with new observation
138 | self._observations.append(observation)
139 |
140 | action, order_size = self._take_action(action)
141 | reward = self._reward_function(self._observations)
142 | terminated = self._get_terminated()
143 |         truncated = not self._env_step_indexes  # truncate once all step indexes are consumed
144 | info = {
145 | "states": [observation],
146 | "metrics": self._metricsHandler(observation)
147 | }
148 |
149 | transformed_obs = self._output_transformer.transform(self._observations)
150 |
151 | if np.isnan(transformed_obs).any():
152 | raise ValueError("transformed_obs contains nan values, check your data")
153 |
154 | return transformed_obs, reward, terminated, truncated, info
155 |
156 |     def reset(self) -> typing.Tuple[np.ndarray, dict]:
157 | """ Reset the environment and return the initial state
158 | """
159 | size = len(self._data_feeder) - self._max_episode_steps
160 | self._env_start_index = np.random.randint(0, size) if size > 0 else 0
161 | self._env_step_indexes = list(range(self._env_start_index, self._env_start_index + self._max_episode_steps))
162 |
163 | # Initial observations are the first states of the window size
164 | self._observations.reset()
165 | while not self._observations.full:
166 | obs = self._get_obs(self._env_step_indexes.pop(0), balance=self._initial_balance)
167 | if obs is None:
168 | continue
169 | # update observations object with new observation
170 | self._observations.append(obs)
171 |
172 | info = {
173 | "states": self._observations.observations,
174 | "metrics": {}
175 | }
176 |
177 | # reset metrics with last state
178 | for metric in self._metrics:
179 | metric.reset(self._observations.observations[-1])
180 |
181 | transformed_obs = self._output_transformer.transform(self._observations)
182 | if np.isnan(transformed_obs).any():
183 | raise ValueError("transformed_obs contains nan values, check your data")
184 |
185 | # return state and info
186 | return transformed_obs, info
187 |
188 | def render(self):
189 | raise NotImplementedError
190 |
191 | def close(self):
192 | """ Close the environment
193 | """
194 | pass
195 |
196 | def config(self):
197 | """ Return the environment configuration
198 | """
199 | return {
200 |             "data_feeder": self._data_feeder.__class__.__name__,
201 |             "output_transformer": self._output_transformer.__class__.__name__,
202 | "initial_balance": self._initial_balance,
203 | "max_episode_steps": self._max_episode_steps,
204 | "window_size": self._window_size,
205 |             "reward_function": self._reward_function.__class__.__name__,
206 |             "metrics": [metric.__class__.__name__ for metric in self._metrics],
207 | "order_fee_percent": self._order_fee_percent,
208 | "observation_space_shape": tuple(self.observation_space.shape),
209 | "action_space": self._action_space.name,
210 | }
211 |
212 | def save_config(self, path: str = ""):
213 | """ Save the environment configuration
214 | """
215 | output_path = os.path.join(path, "TradingEnv.json")
216 | with open(output_path, "w") as f:
217 | json.dump(self.config(), f, indent=4)
218 |
219 | @staticmethod
220 | def load_config(data_feeder, path: str = "", **kwargs):
221 | """ Load the environment configuration
222 | """
223 |
224 | input_path = os.path.join(path, "TradingEnv.json")
225 | if not os.path.exists(input_path):
226 |             raise FileNotFoundError(f"TradingEnv Config file not found in {path}")
227 | with open(input_path, "r") as f:
228 | config = json.load(f)
229 |
230 | environment = TradingEnv(
231 | data_feeder = data_feeder,
232 | output_transformer = getattr(importlib.import_module(".scalers", package=__package__), config["output_transformer"])(),
233 | initial_balance = kwargs.get("initial_balance") or config["initial_balance"],
234 | max_episode_steps = kwargs.get("max_episode_steps") or config["max_episode_steps"],
235 | window_size = kwargs.get("window_size") or config["window_size"],
236 | reward_function = getattr(importlib.import_module(".reward", package=__package__), config["reward_function"])(),
237 | action_space = ActionSpace[config["action_space"]],
238 | metrics = [getattr(importlib.import_module(".metrics", package=__package__), metric)() for metric in config["metrics"]],
239 | order_fee_percent = kwargs.get("order_fee_percent") or config["order_fee_percent"]
240 | )
241 |
242 | return environment
--------------------------------------------------------------------------------
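In the `_take_action` branches above, one common fee convention is to charge the fee only on the side of the trade being received: the assets on a buy, the cash proceeds on a sell. A minimal, self-contained sketch of that accounting, with hypothetical `apply_buy`/`apply_sell` helpers that are not part of the FinRock API:

```python
def apply_buy(balance: float, assets: float, close: float,
              order_size: float, fee_ratio: float) -> tuple:
    """Spend `order_size` fraction of cash; the fee shrinks the assets received."""
    spend = balance * order_size
    assets += (spend / close) * fee_ratio  # fee taken from the filled quantity
    balance -= spend                       # full cash amount leaves the account
    return balance, assets

def apply_sell(balance: float, assets: float, close: float,
               order_size: float, fee_ratio: float) -> tuple:
    """Sell `order_size` fraction of assets; the fee shrinks the cash proceeds."""
    sold = assets * order_size
    balance += sold * close * fee_ratio  # fee taken from the proceeds
    assets -= sold                       # full asset amount leaves the account
    return balance, assets
```

With `fee_ratio = 1 - order_fee_percent / 100`, a full buy followed by a full sell at the same price returns strictly less cash than it started with, which is the behaviour a fee model should have; charging `fee_ratio` on both sides of the same trade instead cancels out and amounts to a smaller fee-free order.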
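`reset()` above draws a random contiguous window of `max_episode_steps` indexes from the data feed, so each episode starts at a different point in history while always fitting inside the data. A standalone sketch of that sampling logic (the `sample_episode_indexes` name is hypothetical; only numpy is assumed):

```python
import numpy as np

def sample_episode_indexes(data_len: int, max_episode_steps: int) -> list:
    """Pick a random start index so the episode fits inside the data,
    then enumerate the step indexes the episode will consume."""
    size = data_len - max_episode_steps
    start = np.random.randint(0, size) if size > 0 else 0
    return list(range(start, start + max_episode_steps))
```

The environment then pops indexes from the front of this list on every `step()`, and truncates the episode when the list is exhausted.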
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | matplotlib
4 | rockrl==0.4.4
5 | tensorflow==2.10.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 |
4 | DIR = os.path.abspath(os.path.dirname(__file__))
5 |
6 | with open(os.path.join(DIR, 'README.md')) as fh:
7 | long_description = fh.read()
8 |
9 | with open(os.path.join(DIR, 'requirements.txt')) as fh:
10 | requirements = fh.read().splitlines()
11 |
12 | def get_version(initpath: str) -> str:
13 |     """ Read the version string from the package's __init__ file
14 |
15 | Params:
16 | initpath (str): path to the init file of the python package relative to the setup file
17 |
18 | Returns:
19 | str: The version string in the form 0.0.1
20 | """
21 |
22 | path = os.path.join(os.path.dirname(__file__), initpath)
23 |
24 | with open(path, "r") as handle:
25 | for line in handle.read().splitlines():
26 | if line.startswith("__version__"):
27 | return line.split("=")[1].strip().strip("\"'")
28 | else:
29 | raise RuntimeError("Unable to find version string.")
30 |
31 | setup(
32 | name = 'finrock',
33 | version = get_version("finrock/__init__.py"),
34 | long_description = long_description,
35 | long_description_content_type = 'text/markdown',
36 | url='https://pylessons.com/',
37 | author='PyLessons',
38 | author_email='pythonlessons0@gmail.com',
39 | install_requires=requirements,
40 | python_requires='>=3',
41 |     packages = find_packages(),
42 | include_package_data=True,
43 | project_urls={
44 | 'Source': 'https://github.com/pythonlessons/FinRock/',
45 | 'Tracker': 'https://github.com/pythonlessons/FinRock/issues',
46 | },
47 |     description="Reinforcement Learning for Financial Trading",
48 | )
--------------------------------------------------------------------------------
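`get_version` above scans the package's `__init__.py` line by line and relies on Python's `for`/`else` clause to raise only when no `__version__` line was found. The same parse can be sketched on a plain string (the `parse_version` helper is hypothetical, not part of setup.py):

```python
def parse_version(source: str) -> str:
    """Return the version from a `__version__ = "..."` line, or raise."""
    for line in source.splitlines():
        if line.startswith("__version__"):
            return line.split("=")[1].strip().strip("\"'")
    else:  # for/else: runs only if the loop finished without returning
        raise RuntimeError("Unable to find version string.")
```

The double `strip` first removes surrounding whitespace, then the quote characters, so both `"0.5.0"` and `'0.5.0'` styles parse to the same value.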