├── cover.png
├── requirements.txt
├── .gitmodules
├── .idea
│   ├── other.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── vcs.xml
│   ├── modules.xml
│   └── a_share_pipeline.iml
├── utils
│   ├── log.py
│   ├── sqlite.py
│   ├── psqldb.py
│   └── date_time.py
├── reports
│   ├── template
│   │   ├── index.html.template
│   │   └── card_predict_result.template
│   ├── index.html
│   └── sqlite_to_html.py
├── config.py
├── filter_stock.py
├── README.md
├── .gitignore
├── train_single.py
├── train_batch_bak.py
├── train_batch.py
├── train_helper.py
├── predict_single_psql.py
├── predict_batch_psql.py
├── run_single.py
├── env_single.py
└── run_batch.py
/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dyh/a_share_pipeline/HEAD/cover.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | baostock
2 | requests
3 | psycopg2-binary
4 | gym
5 | stockstats
6 | matplotlib
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ElegantRL_master"]
2 | path = ElegantRL_master
3 | url = https://github.com/AI4Finance-LLC/ElegantRL.git
4 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def get_logger(log_file_path, log_level=logging.INFO):
5 | """
6 | Return a logger instance
7 | :param log_file_path: path of the log file
8 | :param log_level: logging level, e.g. logging.INFO, logging.ERROR, logging.DEBUG
9 | :return: a logger instance
10 | """
11 | logger = logging.getLogger()
12 | logger.setLevel(log_level)
13 |
14 | logfile = log_file_path
15 | fh = logging.FileHandler(logfile, mode='a')
16 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
17 | fh.setFormatter(formatter)
18 | fh.setLevel(log_level)
19 | logger.addHandler(fh)
20 |
21 | sh = logging.StreamHandler()
22 | sh.setLevel(log_level)
23 | logger.addHandler(sh)
24 |
25 | return logger
26 |
--------------------------------------------------------------------------------
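A minimal usage sketch for `get_logger` (not part of the repository; the log file name is an example). Note that the helper attaches its handlers to the root logger, so calling it more than once adds duplicate handlers.

```python
# Usage sketch for utils/log.py; 'train.log' is an example path.
from utils.log import get_logger

logger = get_logger('train.log')      # file handler + console handler, both at INFO level
logger.info('training started')       # written to train.log and echoed to the console
logger.error('something went wrong')
```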
/.idea/a_share_pipeline.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/reports/template/index.html.template:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | <%page_title%>
8 |
9 |
10 |
11 |
12 |
13 |
23 |
24 | <%page_content%>
25 |
26 | 页面生成时间 <%page_time_point%>
27 |
28 |
--------------------------------------------------------------------------------
/reports/template/card_predict_result.template:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | 预测 <%date%>
5 |
6 |
7 |
8 | <%tic%> 持股 <%most_hold%> %
9 | 买卖 <%most_action%> %
10 |
11 |
12 |
15 |
16 |
17 |
18 |
<%tic%> <%date%>
19 |
20 |
21 |
22 |
23 | | T-1回报率/低买高卖价差 |
24 | 买卖 |
25 | 持股 |
26 | 算法 |
27 | 验证周期 |
28 | 预测周期 |
29 |
30 | <%predict_result_table_tr_td%>
31 |
32 |
33 |
34 |

35 |
36 |
37 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 |
2 | # Minimum number of shares per trade; the minimum can be set to 100 shares, i.e. 1 lot
3 | MINIMUM_TRADE_SHARES = 1
4 |
5 | # Debug the state scale
6 | IF_DEBUG_STATE_SCALE = False
7 |
8 | # Debug the reward scale
9 | IF_DEBUG_REWARD_SCALE = True
10 |
11 | # List of multiple stock codes
12 | BATCH_A_STOCK_CODE = []
13 |
14 | # Threshold for updating the reward
15 | REWARD_THRESHOLD = 256 * 1.5
16 |
17 | # Hyper-parameter ID
18 | MODEL_HYPER_PARAMETERS = -1
19 |
20 | # Agent name
21 | AGENT_NAME = ''
22 |
23 | # Agent prediction period
24 | AGENT_WORK_DAY = 0
25 |
26 | # Prediction period
27 | PREDICT_PERIOD = ''
28 |
29 | # List holding a single stock code
30 | SINGLE_A_STOCK_CODE = []
31 |
32 | # Whether this is an actual (production) prediction
33 | IF_ACTUAL_PREDICT = False
34 |
35 | # Work-day flag, used to load the corresponding weights
36 | # Validation period of the weights
37 | VALI_DAYS_FLAG = ''
38 |
39 | ## time_fmt = '%Y-%m-%d'
40 | START_DATE = ''
41 | START_EVAL_DATE = ''
42 | END_DATE = ''
43 | # Date to output
44 | OUTPUT_DATE = ''
45 |
46 | DATA_SAVE_DIR = f'stock_db'
47 |
48 | # Path of the .pth weight files
49 | WEIGHTS_PATH = 'weights'
50 |
51 | # Path of the SQLite database for batch stocks
52 | STOCK_DB_PATH = './' + DATA_SAVE_DIR + '/stock.db'
53 |
54 | # ----
55 | # PostgreSQL
56 | PSQL_HOST = '192.168.192.1'
57 | PSQL_PORT = '5432'
58 | PSQL_DATABASE = 'a_share'
59 | PSQL_USER = 'dyh'
60 | PSQL_PASSWORD = '9898BdlLsdfsHsbgp'
61 | # ----
62 |
63 | ## stockstats technical indicator column names
64 | ## check https://pypi.org/project/stockstats/ for different names
65 | TECHNICAL_INDICATORS_LIST = ['macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 'close_30_sma', 'close_60_sma']
66 |
--------------------------------------------------------------------------------
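config.py holds module-level settings that the scripts overwrite at run time. A small sketch of that pattern, mirroring what train_single.py does (the values shown are examples only):

```python
# Sketch of how other modules consume the globals in config.py; values are examples.
import config

config.SINGLE_A_STOCK_CODE = ['sh.600036', ]   # scripts assign the module-level globals in place
config.AGENT_NAME = 'AgentPPO'
config.VALI_DAYS_FLAG = '60'

# The same weights-path layout that train_single.py builds from these settings.
model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/' \
                    f'{config.SINGLE_A_STOCK_CODE[0]}/single_{config.VALI_DAYS_FLAG}'
print(model_folder_path)   # ./weights/single/AgentPPO/sh.600036/single_60
```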
/filter_stock.py:
--------------------------------------------------------------------------------
1 | import baostock as bs
2 | import pandas as pd
3 |
4 | from stock_data import StockData
5 |
6 | if __name__ == '__main__':
7 |
8 | list1 = StockData.get_batch_a_share_code_list_string(date_filter='2004-05-01')
9 |
10 | # Log in to the baostock system
11 | lg = bs.login()
12 | # Print the login response
13 | print('login respond error_code:' + lg.error_code)
14 | print('login respond error_msg:' + lg.error_msg)
15 |
16 | # Fetch industry classification data
17 | rs = bs.query_stock_industry()
18 | # rs = bs.query_stock_basic(code_name="浦发银行")
19 | print('query_stock_industry error_code:' + rs.error_code)
20 | print('query_stock_industry respond error_msg:' + rs.error_msg)
21 |
22 | # Containers for the result sets
23 | industry_list = []
24 | # Quarterly profitability query results
25 | profit_list = []
26 |
27 | fields1 = ['code', 'code_name', 'industry', 'pubDate', 'statDate', 'roeAvg', 'npMargin', 'gpMargin',
28 | 'netProfit', 'epsTTM', 'MBRevenue', 'totalShare', 'liqaShare']
29 |
30 | index1 = 0
31 | count1 = len(list1)
32 |
33 | while (rs.error_code == '0') & rs.next():
34 | # Fetch one record at a time and merge the records together
35 | industry_item = rs.get_row_data()
36 |
37 | updateDate, code, code_name, industry, industryClassification = industry_item
38 |
39 | if code in list1:
40 |
41 | rs_profit = bs.query_profit_data(code=code, year=2021, quarter=1)
42 |
43 | while (rs_profit.error_code == '0') & rs_profit.next():
44 | code, pubDate, statDate, roeAvg, npMargin, gpMargin, netProfit, epsTTM, MBRevenue, totalShare, liqaShare = rs_profit.get_row_data()
45 |
46 | # profit_list.append(rs_profit.get_row_data())
47 | profit_list.append(
48 | [code, code_name, industry, pubDate, statDate, roeAvg, npMargin, gpMargin, netProfit, epsTTM,
49 | MBRevenue, totalShare, liqaShare])
50 |
51 | pass
52 | pass
53 |
54 | index1 += 1
55 | print(index1, '/', count1)
56 |
57 | pass
58 |
59 | result_profit = pd.DataFrame(profit_list, columns=fields1)
60 |
61 | # result = pd.DataFrame(industry_list, columns=rs.fields)
62 |
63 | # Write the result set to a CSV file
64 | result_profit.to_csv("./stock_industry_profit.csv", index=False)
65 |
66 | print(result_profit)
67 |
68 | # Log out
69 | bs.logout()
70 | pass
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Share Pipeline
2 |
3 | - A-share decision pipeline
4 |
5 | ### Video
6 |
7 | bilibili
8 |
9 | [![cover](cover.png)](https://www.bilibili.com/video/BV1ph411Y78q/ "bilibili")
10 |
11 |
12 | ### Framework used:
13 | - https://github.com/AI4Finance-LLC/ElegantRL
14 |
15 | ### Data sources:
16 | - baostock.com
17 | - sinajs.cn
18 |
19 | ## Runtime environment
20 |
21 | - ubuntu 18.04.5
22 | - python 3.6+, pip 20+
23 | - pytorch 1.7+
24 | - pip install -r requirements.txt
25 |
26 |
27 | ## How to run
28 |
29 | 1. Download the code
30 |
31 | ```
32 | $ git clone https://github.com/dyh/a_share_pipeline.git
33 | ```
34 |
35 | 2. Enter the directory
36 |
37 | ```
38 | $ cd a_share_pipeline
39 | ```
40 |
41 | 3. Update the ElegantRL submodule
42 |
43 | ```
44 | $ cd ElegantRL_master/
45 | $ git submodule update --init --recursive
46 | $ git pull
47 | $ cd ..
48 | ```
49 |
50 | > If git pull reports that you are not on any branch, check out ElegantRL's master branch:
51 |
52 | ```
53 | $ git checkout master
54 | ```
55 |
56 | 4. Create a Python virtual environment
57 |
58 | ```
59 | $ python3 -m venv venv
60 | ```
61 |
62 | > If venv cannot be found, install it first:
63 |
64 | ```
65 | $ sudo apt install python3-pip
66 | $ sudo apt-get install python3-venv
67 | ```
68 |
69 | 5. Activate the virtual environment
70 |
71 | ```
72 | $ source venv/bin/activate
73 | ```
74 |
75 | 6. Upgrade pip
76 |
77 | ```
78 | $ python -m pip install --upgrade pip
79 | ```
80 |
81 | 7. Install PyTorch
82 |
83 | > Find the install command matching your operating system, environment tooling and CUDA version at https://pytorch.org/get-started/locally/, then copy, paste and run it. For ease of demonstration, the CPU-only build is used below; the install command is:
84 |
85 | ```
86 | $ pip3 install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
87 | ```
88 |
89 | 8. Install the other packages
90 |
91 | ```
92 | $ pip install -r requirements.txt
93 | ```
94 |
95 | > If the installation above fails, consider installing these system packages first:
96 |
97 | ```
98 | $ sudo apt-get update && sudo apt-get install cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx
99 | ```
100 |
101 | 9. Train a model for sh.600036
102 |
103 | ```
104 | $ python train_single.py
105 | ```
106 |
107 | 10. Run predictions for sh.600036
108 |
109 | ```
110 | $ python predict_single_sqlite.py
111 | ```
112 |
113 | 11. Generate the report for sh.600036
114 |
115 | ```
116 | $ cd reports
117 | $ python sqlite_to_html.py
118 | ```
119 |
--------------------------------------------------------------------------------
/utils/sqlite.py:
--------------------------------------------------------------------------------
1 | # coding = utf-8
2 |
3 | import logging
4 | import sqlite3
5 |
6 |
7 | class SQLite:
8 | def __init__(self, dbname):
9 | self.name = dbname
10 | self.conn = sqlite3.connect(self.name, check_same_thread=False)
11 |
12 | def commit(self):
13 | self.conn.commit()
14 |
15 | def close(self):
16 | self.conn.close()
17 |
18 | # Execute SQL without returning a result set
19 | def execute_non_query(self, sql, values=()):
20 | value = False
21 | try:
22 | cursor = self.conn.cursor()
23 | cursor.execute(sql, values)
24 | # self.commit()
25 | value = True
26 | except sqlite3.OperationalError as e:
27 | logging.exception(e)
28 | pass
29 | except Exception as e:
30 | logging.exception(e)
31 | finally:
32 | return value
33 | # self.close()
34 |
35 | # Query and return all rows
36 | def fetchall(self, sql, values=()):
37 | try:
38 | cursor = self.conn.cursor()
39 | cursor.execute(sql, values)
40 | values = cursor.fetchall()
41 | return values
42 | except sqlite3.OperationalError as e:
43 | logging.exception(e)
44 | pass
45 | except Exception as e:
46 | logging.exception(e)
47 | finally:
48 | pass
49 |
50 | # Query and return a single row
51 | def fetchone(self, sql, values=()):
52 | try:
53 | cursor = self.conn.cursor()
54 | cursor.execute(sql, values)
55 | value = cursor.fetchone()
56 | return value
57 | except sqlite3.OperationalError as e:
58 | logging.exception(e)
59 | pass
60 | except Exception as e:
61 | logging.exception(e)
62 | finally:
63 | pass
64 |
65 | # Check whether a table exists
66 | def table_exists(self, table_name):
67 | try:
68 | cursor = self.conn.cursor()
69 | cursor.execute(f"select name from sqlite_master where type = 'table' and name = '{table_name}'")
70 | value = cursor.fetchone()
71 | return value
72 | except sqlite3.OperationalError as e:
73 | logging.exception(e)
74 | pass
75 | except Exception as e:
76 | logging.exception(e)
77 | finally:
78 | pass
79 |
80 | # Drop a table
81 | def drop_table(self, table_name):
82 | value = False
83 | try:
84 | cursor = self.conn.cursor()
85 | cursor.execute(f"DROP TABLE IF EXISTS '{table_name}'")
86 | value = True
87 | except sqlite3.OperationalError as e:
88 | logging.exception(e)
89 | pass
90 | except Exception as e:
91 | logging.exception(e)
92 | finally:
93 | return value
94 |
--------------------------------------------------------------------------------
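A minimal usage sketch for the SQLite wrapper above (the table and column names are examples). Because `execute_non_query()` leaves committing to the caller, `commit()` must be called explicitly.

```python
# Usage sketch for utils/sqlite.py; the 'demo' table is an example.
from utils.sqlite import SQLite

db = SQLite('./stock_db/stock.db')
db.execute_non_query('CREATE TABLE IF NOT EXISTS demo (id INTEGER PRIMARY KEY, name TEXT)')
db.execute_non_query('INSERT INTO demo (name) VALUES (?)', ('sh.600036',))
db.commit()                                      # execute_non_query() does not commit by itself

print(db.table_exists('demo'))                   # ('demo',) if the table exists, otherwise None
print(db.fetchall('SELECT id, name FROM demo'))  # e.g. [(1, 'sh.600036')]
db.close()
```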
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | *.pth
132 | *.zip
133 | *.npy
134 | *.npz
135 | *.df
136 | *.csv
137 | *.db
138 |
139 | stock.db
140 |
141 | /pipeline/informer/checkpoints/
142 | /pipeline/informer/temp_dataset/
143 | /pipeline/finrl/datasets_temp/
144 | /pipeline/finrl/results/
145 | /pipeline/finrl/tensorboard_log/
146 | /pipeline/finrl/trained_models/
147 | /pipeline/informer/results/
148 | /pipeline/elegant/AgentPPO/
149 | /pipeline/elegant/datasets/
150 | /pipeline/elegant/bak/
151 | /pipeline/elegant/AgentSAC/
152 | /pipeline/elegant/AgentDDPG/
153 | /pipeline/elegant/AgentDuelingDQN/
154 | /pipeline/elegant/AgentTD3/
155 | /pipeline/elegant/AgentModSAC/
156 | /pipeline/elegant/AgentDoubleDQN/
157 | /pipeline/elegant/AgentSharedSAC/
158 | /weights/
159 | /stock_db/
160 | /temp/
161 |
--------------------------------------------------------------------------------
/reports/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | A股预测
8 |
9 |
10 |
11 |
12 |
13 |
23 |
24 |
25 |
26 |
27 | 预测 2021-06-24 周四
28 |
29 |
30 |
31 | sh.600036 持股 0.0 %
32 | 买卖 0.0 %
33 |
34 |
35 |
38 |
39 |
40 |
41 |
sh.600036 2021-06-24 周四
42 |
43 |
44 |
45 |
46 | | T-1回报率/低买高卖价差 |
47 | 买卖 |
48 | 持股 |
49 | 算法 |
50 | 验证周期 |
51 | 预测周期 |
52 |
53 | | 5.22% / 9.91% | -97.0% | 0.0% | PPO | 90天 | 第36/40天 |
| 3.46% / 9.91% | 0.0% | 0.0% | SAC | 20天 | 第36/40天 |
| 3.37% / 9.91% | 0.0% | 0.0% | DDPG | 60天 | 第36/40天 |
| 2.6% / 9.91% | 0.0% | 97.0% | PPO | 50天 | 第36/40天 |
| 2.6% / 9.91% | 0.0% | 97.0% | SAC | 30天 | 第36/40天 |
| 2.6% / 9.91% | 0.0% | 97.0% | SAC | 90天 | 第36/40天 |
| 2.3% / 9.91% | 0.0% | 97.0% | PPO | 20天 | 第36/40天 |
| 2.22% / 9.91% | 0.0% | 97.0% | DDPG | 30天 | 第36/40天 |
| 1.81% / 9.91% | -97.0% | 0.0% | TD3 | 50天 | 第36/40天 |
| 1.6% / 9.91% | 28.0% | 69.0% | SAC | 60天 | 第36/40天 |
| 1.44% / 9.91% | 0.0% | 93.0% | PPO | 40天 | 第36/40天 |
| 1.4% / 9.91% | 0.0% | 93.0% | TD3 | 90天 | 第36/40天 |
| 1.19% / 9.91% | 0.0% | 93.0% | PPO | 30天 | 第36/40天 |
| -0.15% / 9.91% | 0.0% | 90.0% | SAC | 72天 | 第36/40天 |
| -0.33% / 9.91% | 0.0% | 93.0% | TD3 | 30天 | 第36/40天 |
| -1.06% / 9.91% | 0.0% | 93.0% | DDPG | 90天 | 第36/40天 |
| -1.21% / 9.91% | 0.0% | 0.0% | TD3 | 72天 | 第36/40天 |
| -1.24% / 9.91% | 0.0% | 0.0% | DDPG | 50天 | 第36/40天 |
| -1.39% / 9.91% | -76.0% | 0.0% | TD3 | 20天 | 第36/40天 |
| -1.98% / 9.91% | 0.0% | 0.0% | DDPG | 20天 | 第36/40天 |
| -2.29% / 9.91% | 17.0% | 41.0% | SAC | 50天 | 第36/40天 |
| -2.33% / 9.91% | 0.0% | 0.0% | TD3 | 60天 | 第36/40天 |
| -2.44% / 9.91% | 0.0% | 0.0% | SAC | 40天 | 第36/40天 |
| -3.08% / 9.91% | -83.0% | 7.0% | PPO | 60天 | 第36/40天 |
| -3.47% / 9.91% | 0.0% | 0.0% | DDPG | 72天 | 第36/40天 |
| -3.94% / 9.91% | -90.0% | 0.0% | DDPG | 40天 | 第36/40天 |
| -4.27% / 9.91% | 90.0% | 90.0% | TD3 | 40天 | 第36/40天 |
| -12.21% / 9.91% | 83.0% | 83.0% | PPO | 72天 | 第36/40天 |
54 |
55 |
56 |
57 |

58 |
59 |
60 |
61 |
62 |
63 | 页面生成时间 2021-06-24 12:49:45
64 |
65 |
--------------------------------------------------------------------------------
/utils/psqldb.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import time
3 |
4 | import logging
5 | import psycopg2
6 |
7 |
8 | # Check whether the exception indicates a connection timeout or an already-closed connection
9 | def is_timeout_or_closed_error(str_exception):
10 | if "Connection timed out" in str_exception or "connection already closed" in str_exception:
11 | return True
12 | else:
13 | return False
14 |
15 |
16 | class Psqldb:
17 | def __init__(self, database, user, password, host, port):
18 | self.database = database
19 | self.user = user
20 | self.password = password
21 | self.host = host
22 | self.port = port
23 | self.conn = None
24 | # Establish the initial connection
25 | self.re_connect()
26 |
27 | # Reconnect to the database
28 | def re_connect(self):
29 | # If connecting times out, keep retrying in a loop
30 | while True:
31 | try:
32 | print(__name__, "re_connect")
33 |
34 | self.conn = psycopg2.connect(database=self.database, user=self.user,
35 | password=self.password, host=self.host, port=self.port)
36 | # Connected successfully, exit the loop
37 | break
38 | except psycopg2.OperationalError as e:
39 | logging.exception(e)
40 | except Exception as e:
41 | # Other errors: log and stop retrying
42 | logging.exception(e)
43 | break
44 |
45 | # Sleep 5 seconds before reconnecting
46 | print(__name__, "re-connecting PostgreSQL in 5 seconds")
47 | time.sleep(5)
48 |
49 | # Commit the transaction
50 | def commit(self):
51 | try:
52 | self.conn.commit()
53 | except psycopg2.OperationalError as e:
54 | logging.exception(e)
55 | except Exception as e:
56 | logging.exception(e)
57 |
58 | # Close the database connection
59 | def close(self):
60 | try:
61 | print(__name__, "close")
62 |
63 | self.conn.close()
64 | except psycopg2.OperationalError as e:
65 | logging.exception(e)
66 | except Exception as e:
67 | logging.exception(e)
68 |
69 | # Execute SQL without returning a result set
70 | def execute_non_query(self, sql, values=()):
71 | result = False
72 | while True:
73 | try:
74 | # console(__name__, "execute_non_query")
75 |
76 | cursor = self.conn.cursor()
77 | cursor.execute(sql, values)
78 | # self.commit()
79 | result = True
80 | break
81 | except psycopg2.OperationalError as e:
82 | logging.exception(e)
83 | str_exception = str(e)
84 | if is_timeout_or_closed_error(str_exception):
85 | # Reconnect to the database
86 | self.re_connect()
87 | else:
88 | break
89 | except Exception as e:
90 | # Other errors: log and stop retrying
91 | logging.exception(e)
92 | str_exception = str(e)
93 | if is_timeout_or_closed_error(str_exception):
94 | # console(__name__, "re-connecting PostgreSQL immediately")
95 | # Reconnect to the database
96 | self.re_connect()
97 | else:
98 | break
99 | # Sleep 5 seconds before reconnecting
100 | print(__name__, "re-connecting PostgreSQL in 5 seconds")
101 | time.sleep(5)
102 |
103 | return result
104 |
105 | # Query and return all rows
106 | def fetchall(self, sql, values=()):
107 | results = None
108 | while True:
109 | try:
110 | # console(__name__, "fetchall")
111 |
112 | cursor = self.conn.cursor()
113 | cursor.execute(sql, values)
114 | results = cursor.fetchall()
115 | break
116 | except psycopg2.OperationalError as e:
117 | logging.exception(e)
118 | str_exception = str(e)
119 | if is_timeout_or_closed_error(str_exception):
120 | # Reconnect to the database
121 | self.re_connect()
122 | else:
123 | break
124 | except Exception as e:
125 | # Other errors: log and ignore
126 | logging.exception(e)
127 | str_exception = str(e)
128 | if is_timeout_or_closed_error(str_exception):
129 | # Reconnect to the database
130 | self.re_connect()
131 | else:
132 | break
133 | return results
134 |
135 | # Query and return a single row
136 | def fetchone(self, sql, values=()):
137 | result = None
138 | while True:
139 | try:
140 | # console(__name__, "fetchone")
141 |
142 | cursor = self.conn.cursor()
143 | cursor.execute(sql, values)
144 | result = cursor.fetchone()
145 | break
146 | except psycopg2.OperationalError as e:
147 | logging.exception(e)
148 | str_exception = str(e)
149 | if is_timeout_or_closed_error(str_exception):
150 | # Reconnect to the database
151 | self.re_connect()
152 | else:
153 | break
154 | except Exception as e:
155 | # Other errors: log and stop retrying
156 | logging.exception(e)
157 | str_exception = str(e)
158 | if is_timeout_or_closed_error(str_exception):
159 | # Reconnect to the database
160 | self.re_connect()
161 | else:
162 | break
163 | return result
164 |
165 | # Check whether a table exists
166 | def table_exists(self, table_name):
167 | result = None
168 | try:
169 | cursor = self.conn.cursor()
170 | cursor.execute(f"select count(*) from pg_class where relname = \'{table_name}\';")
171 | value = cursor.fetchone()[0]
172 | if value == 0:
173 | result = None
174 | else:
175 | result = value
176 | except psycopg2.OperationalError as e:
177 | logging.exception(e)
178 | pass
179 | except Exception as e:
180 | logging.exception(e)
181 | finally:
182 | return result
183 | pass
184 | #
185 | # # Drop a table
186 | # def drop_table(self, table_name):
187 | # value = False
188 | # try:
189 | # cursor = self.conn.cursor()
190 | # cursor.execute("DROP TABLE IF EXISTS '{}'".format(table_name))
191 | # value = True
192 | # except psycopg2.OperationalError as e:
193 | # logging.exception(e)
194 | # pass
195 | # except Exception as e:
196 | # logging.exception(e)
197 | # finally:
198 | # return value
199 |
--------------------------------------------------------------------------------
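A minimal usage sketch for `Psqldb`, wired to the connection settings in config.py (the query is an example; it only assumes a reachable PostgreSQL server).

```python
# Usage sketch for utils/psqldb.py; connection parameters come from config.py.
import config
from utils.psqldb import Psqldb

db = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER,
            password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT)

print(db.fetchone('SELECT version()'))    # e.g. ('PostgreSQL 12.x ...',)
print(db.table_exists('some_table'))      # row count from pg_class, or None if the table is absent
db.close()
```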
/utils/date_time.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 |
4 |
5 | def time_point(time_format='%Y%m%d_%H%M%S'):
6 | return time.strftime(time_format, time.localtime())
7 |
8 |
9 | def get_datetime_from_date_str(str_date):
10 | """
11 | Build a datetime.date object from a date string
12 | :param str_date: date string in the form 2021-10-20
13 | :return: datetime.date object
14 | """
15 | year_temp = str_date.split('-')[0]
16 | month_temp = str_date.split('-')[1]
17 | day_temp = str_date.split('-')[2]
18 | return datetime.date(int(year_temp), int(month_temp), int(day_temp))
19 |
20 |
21 | # def get_next_work_day(datetime_date, next_flag=1):
22 | # """
23 | # 获取下一个工作日
24 | # :param datetime_date: 日期类型
25 | # :param next_flag: -1 前一天,+1 后一天
26 | # :return: datetime.date
27 | # """
28 | # if next_flag > 0:
29 | # next_date = datetime_date + datetime.timedelta(
30 | # days=7 - datetime_date.weekday() if datetime_date.weekday() > 3 else 1)
31 | # else:
32 | # # 周二至周六
33 | # if 0 < datetime_date.weekday() < 6:
34 | # next_date = datetime_date - datetime.timedelta(days=1)
35 | # pass
36 | # elif 0 == datetime_date.weekday():
37 | # # 周一
38 | # next_date = datetime_date - datetime.timedelta(days=3)
39 | # else:
40 | # # 6 == datetime_date.weekday():
41 | # # 周日
42 | # next_date = datetime_date - datetime.timedelta(days=2)
43 | # pass
44 | # pass
45 | # return next_date
46 |
47 | def get_next_work_day(datetime_date, next_flag=1):
48 | """
49 | Get the working day that is next_flag working days away (weekends are skipped)
50 | :param datetime_date: datetime.date
51 | :param next_flag: -1 for the previous working day, +1 for the next working day
52 | :return: datetime.date
53 | """
54 | for _ in range(abs(next_flag)):
55 | if next_flag > 0:
56 | datetime_date = datetime_date + datetime.timedelta(
57 | days=7 - datetime_date.weekday() if datetime_date.weekday() > 3 else 1)
58 | else:
59 | # Tuesday to Saturday
60 | if 0 < datetime_date.weekday() < 6:
61 | datetime_date = datetime_date - datetime.timedelta(days=1)
62 | pass
63 | elif 0 == datetime_date.weekday():
64 | # Monday
65 | datetime_date = datetime_date - datetime.timedelta(days=3)
66 | else:
67 | # 6 == datetime_date.weekday():
68 | # Sunday
69 | datetime_date = datetime_date - datetime.timedelta(days=2)
70 | pass
71 | pass
72 | return datetime_date
73 |
74 |
75 | def get_next_day(datetime_date, next_flag=1):
76 | """
77 | Get the date next_flag calendar days away
78 | :param datetime_date: datetime.date
79 | :param next_flag: -2 for 2 days earlier, +1 for 1 day later
80 | :return: datetime.date
81 | """
82 | if next_flag > 0:
83 | next_date = datetime_date + datetime.timedelta(days=abs(next_flag))
84 | else:
85 | next_date = datetime_date - datetime.timedelta(days=abs(next_flag))
86 | return next_date
87 |
88 |
89 | def get_today_date():
90 | # Get today's date
91 | time_format = '%Y-%m-%d'
92 | return time.strftime(time_format, time.localtime())
93 |
94 |
95 | def is_greater(date1, date2):
96 | """
97 | Whether date1 is later than date2
98 | :param date1: date string, e.g. 2001-03-01
99 | :param date2: date string, e.g. 2001-01-01
100 | :return: True/False
101 | """
102 | temp1 = time.strptime(date1, '%Y-%m-%d')
103 | temp2 = time.strptime(date2, '%Y-%m-%d')
104 |
105 | if temp1 > temp2:
106 | return True
107 | else:
108 | return False
109 | pass
110 |
111 |
112 | def get_begin_vali_date_list(end_vali_date):
113 | """
114 | Get a list of (work_days, begin_vali_date) tuples for several validation periods
115 | :return: list of tuples
116 | """
117 | list_result = list()
118 | # for work_days in [20, 30, 40, 50, 60, 72, 90, 100, 150, 200, 300, 500, 518, 1000, 1200, 1268]:
119 | for work_days in [30, 40, 50, 60, 72, 90, 100, 150, 200, 300, 500, 518, 1000, 1200]:
120 | begin_vali_date = get_next_work_day(end_vali_date, next_flag=-work_days)
121 | list_result.append((work_days, begin_vali_date))
122 |
123 | list_result.reverse()
124 |
125 | return list_result
126 |
127 |
128 | def get_end_vali_date_list(begin_vali_date):
129 | """
130 | Get a list of (work_days, end_vali_date) tuples for several validation periods
131 | :return: list of tuples
132 | """
133 |
134 | list_result = list()
135 |
136 | # # 20周期
137 | # end_vali_date = get_next_day(begin_vali_date, next_flag=28)
138 | # list_result.append((20, end_vali_date))
139 | #
140 | # # 30周期
141 | # end_vali_date = get_next_day(begin_vali_date, next_flag=42)
142 | # list_result.append((30, end_vali_date))
143 | #
144 | # # 40周期
145 | # end_vali_date = get_next_day(begin_vali_date, next_flag=56)
146 | # list_result.append((40, end_vali_date))
147 | #
148 | # # 50周期
149 | # end_vali_date = get_next_day(begin_vali_date, next_flag=77)
150 | # list_result.append((50, end_vali_date))
151 | #
152 | # # 60周期
153 | # end_vali_date = get_next_day(begin_vali_date, next_flag=91)
154 | # list_result.append((60, end_vali_date))
155 | #
156 | # # 72周期
157 | # end_vali_date = get_next_day(begin_vali_date, next_flag=108)
158 | # list_result.append((72, end_vali_date))
159 | #
160 | # # 90周期
161 | # end_vali_date = get_next_day(begin_vali_date, next_flag=134)
162 | # list_result.append((90, end_vali_date))
163 |
164 | # 50-day period
165 | work_days = 50
166 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
167 | list_result.append((work_days, end_vali_date))
168 |
169 | # 100-day period
170 | work_days = 100
171 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
172 | list_result.append((work_days, end_vali_date))
173 |
174 | # 150-day period
175 | work_days = 150
176 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
177 | list_result.append((work_days, end_vali_date))
178 |
179 | # 200-day period
180 | work_days = 200
181 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
182 | list_result.append((work_days, end_vali_date))
183 |
184 | # 300-day period
185 | work_days = 300
186 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
187 | list_result.append((work_days, end_vali_date))
188 |
189 | # 500-day period
190 | work_days = 500
191 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
192 | list_result.append((work_days, end_vali_date))
193 |
194 | # 1000-day period
195 | work_days = 1000
196 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days)
197 | list_result.append((work_days, end_vali_date))
198 |
199 | return list_result
200 |
201 |
202 | def get_week_day(string_time_point):
203 | week_day_dict = {
204 | 0: '周一',
205 | 1: '周二',
206 | 2: '周三',
207 | 3: '周四',
208 | 4: '周五',
209 | 5: '周六',
210 | 6: '周日',
211 | }
212 |
213 | day_of_week = datetime.datetime.fromtimestamp(
214 | time.mktime(time.strptime(string_time_point, "%Y-%m-%d"))).weekday()
215 |
216 | return week_day_dict[day_of_week]
217 |
--------------------------------------------------------------------------------
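A short behaviour sketch for the helpers above (dates chosen as examples). `get_next_work_day` only skips weekends; public holidays are not taken into account.

```python
# Behaviour sketch for utils/date_time.py; the dates are examples.
from utils.date_time import get_datetime_from_date_str, get_next_work_day, get_week_day

d = get_datetime_from_date_str('2021-06-25')   # a Friday
print(get_next_work_day(d, next_flag=1))       # 2021-06-28, the following Monday
print(get_next_work_day(d, next_flag=-5))      # five working days earlier: 2021-06-18
print(get_week_day('2021-06-24'))              # '周四' (Thursday)
```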
/reports/sqlite_to_html.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 |
4 | import time
5 |
6 | import config
7 | from utils.date_time import get_week_day
8 | from utils.sqlite import SQLite
9 |
10 | if __name__ == '__main__':
11 | # The tic(s) to render into the HTML report
12 | config.SINGLE_A_STOCK_CODE = ['sh.600036', ]
13 |
14 | # Accumulated trade-detail text
15 | text_trade_detail = ''
16 |
17 | # Read the page template from index.html.template
18 | with open('./template/index.html.template', 'r') as index_page_template:
19 | text_index_page_template = index_page_template.read()
20 |
21 | all_text_card = ''
22 |
23 | # Read the result-card template from card_predict_result.template
24 | with open('./template/card_predict_result.template', 'r') as result_card_template:
25 | text_card = result_card_template.read()
26 |
27 | db_path = '../stock_db/stock.db'
28 |
29 | # Connect to the database
30 | sqlite = SQLite(db_path)
31 |
32 | for tic in config.SINGLE_A_STOCK_CODE:
33 |
34 | table_name = tic + '_report'
35 |
36 | # Copy the result-card template; the values below are filled into the copy
37 | copy_text_card = text_card
38 |
39 | # Get the maximum hold and maximum action from the database, used to compute percentages
40 | sql_cmd = f'SELECT abs("hold") FROM "{table_name}" ORDER BY abs("hold") DESC LIMIT 1'
41 | max_hold = sqlite.fetchone(sql_cmd)
42 | max_hold = max_hold[0]
43 |
44 | # max_action
45 | sql_cmd = f'SELECT abs("action") FROM "{table_name}" ORDER BY abs("action") DESC LIMIT 1'
46 | max_action = sqlite.fetchone(sql_cmd)
47 | max_action = max_action[0]
48 |
49 | # Get the latest date in the database
50 | sql_cmd = f'SELECT "date" FROM "{table_name}" ORDER BY "date" DESC LIMIT 1'
51 | max_date = sqlite.fetchone(sql_cmd)
52 | max_date = str(max_date[0])
53 |
54 | # Query the batch of rows for that latest date
55 | sql_cmd = f'SELECT "id", "agent", "vali_period_value", "pred_period_name", "action", "hold", "day", "episode_return", "max_return", "trade_detail" ' \
56 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' ORDER BY episode_return DESC'
57 |
58 | list_result = sqlite.fetchall(sql_cmd)
59 |
60 | text_table_tr_td = ''
61 |
62 | for item_result in list_result:
63 | # Replace the placeholder strings
64 | copy_text_card = copy_text_card.replace('<%tic%>', tic)
65 | copy_text_card = copy_text_card.replace('<%tic_no_dot%>', tic.replace('.', ''))
66 |
67 | id1, agent1, vali_period_value1, pred_period_name1, action1, hold1, day1, episode_return1, max_return1, trade_detail1 = item_result
68 |
69 | # <%day%>
70 | copy_text_card = copy_text_card.replace('<%day%>', day1)
71 |
72 | # Convert to percentages
73 | action1 = float(action1)
74 | max_action = float(max_action)
75 | hold1 = float(hold1)
76 | max_hold = float(max_hold)
77 | episode_return1 = float(episode_return1)
78 | max_return1 = float(max_return1)
79 |
80 | action1 = round(action1 * 100 / max_action, 0)
81 | hold1 = round(hold1 * 100 / max_hold, 0)
82 | # agent1
83 | agent1 = agent1[5:]
84 |
85 | # Returns
86 | episode_return1 = round((episode_return1 - 1) * 100, 2)
87 | max_return1 = round((max_return1 - 1) * 100, 2)
88 |
89 | text_table_tr_td += f'' \
90 | f'| {episode_return1}% / {max_return1}% | ' \
91 | f'{action1}% | ' \
92 | f'{hold1}% | ' \
93 | f'{agent1} | ' \
94 | f'{vali_period_value1}天 | ' \
95 | f'第{day1}/{pred_period_name1}天 | ' \
96 | f'
'
97 |
98 | # The trade details (trade_detail1) are saved to a separate file
99 |
100 | text_trade_detail += f'\r\n{"-" * 20} {agent1} {vali_period_value1}天 {"-" * 20}\r\n'
101 |
102 | text_trade_detail += f'{episode_return1}% / {max_return1}% ' \
103 | f' {action1}% ' \
104 | f' {hold1}% ' \
105 | f' {agent1} ' \
106 | f' {vali_period_value1}天 ' \
107 | f' 第{day1}/{pred_period_name1}天\r\n'
108 |
109 | text_trade_detail += '\r\n交易详情\r\n\r\n'
110 |
111 | text_trade_detail += trade_detail1
112 |
113 | pass
114 | pass
115 |
116 | # Date
117 | date1 = max_date + ' ' + get_week_day(max_date)
118 | copy_text_card = copy_text_card.replace('<%date%>', date1)
119 |
120 | # Group by hold and pick the most frequent hold value
121 | sql_cmd = f'SELECT "hold", COUNT(id) as count1 ' \
122 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' GROUP BY "hold"' \
123 | f' ORDER BY count1 DESC, abs("hold") DESC LIMIT 1'
124 |
125 | most_hold = sqlite.fetchone(sql_cmd)[0]
126 | most_hold = float(most_hold)
127 | most_hold = round(most_hold * 100 / max_hold, 0)
128 | copy_text_card = copy_text_card.replace('<%most_hold%>', str(most_hold))
129 |
130 | # Group by action and pick the most frequent action value
131 | sql_cmd = f'SELECT "action", COUNT("id") as count1 ' \
132 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' GROUP BY "action"' \
133 | f' ORDER BY count1 DESC, abs("action") DESC LIMIT 1'
134 |
135 | most_action = sqlite.fetchone(sql_cmd)[0]
136 | most_action = float(most_action)
137 | most_action = round(most_action * 100 / max_action, 0)
138 | copy_text_card = copy_text_card.replace('<%most_action%>', str(most_action))
139 |
140 | # Table
141 | all_text_card += copy_text_card.replace('<%predict_result_table_tr_td%>', text_table_tr_td)
142 | all_text_card += '\r\n'
143 |
144 | if text_trade_detail != '':
145 | # Write the trade-details file
146 | with open(f'./{tic}.txt', 'w') as file_detail:
147 | file_detail.write(text_trade_detail)
148 | pass
149 | pass
150 | pass
151 | pass
152 |
153 | sqlite.close()
154 | pass
155 |
156 | # Insert the rendered cards into the page template
157 | text_index_page_template = text_index_page_template.replace('<%page_content%>', all_text_card)
158 |
159 | current_time_point = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
160 | text_index_page_template = text_index_page_template.replace('<%page_time_point%>', current_time_point)
161 |
162 | text_index_page_template = text_index_page_template.replace('<%page_title%>', 'A股预测')
163 |
164 | INDEX_HTML_PAGE_PATH = './index.html'
165 |
166 | # Write the HTML page
167 | with open(INDEX_HTML_PAGE_PATH, 'w') as file_index:
168 | file_index.write(text_index_page_template)
169 | pass
170 | pass
171 |
172 | pass
173 |
--------------------------------------------------------------------------------
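The report generator above fills plain-text placeholders such as `<%page_title%>`, `<%page_content%>` and `<%page_time_point%>` with `str.replace()`. A stripped-down sketch of that substitution step (card rendering omitted):

```python
# Sketch of the placeholder substitution used by reports/sqlite_to_html.py.
import time

with open('./template/index.html.template', 'r') as f:
    page = f.read()

page = page.replace('<%page_title%>', 'A股预测')
page = page.replace('<%page_content%>', '<!-- rendered cards go here -->')
page = page.replace('<%page_time_point%>', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))

with open('./index.html', 'w') as f:
    f.write(page)
```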
/train_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from datetime import datetime
5 |
6 | import config
7 | import numpy as np
8 | from stock_data import StockData
9 |
10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \
11 | AgentDoubleDQN
12 |
13 | from run_single import Arguments, train_and_evaluate_mp, train_and_evaluate
14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \
15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite
16 | from utils import date_time
17 |
18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, \
19 | get_today_date
20 | # from env_train_single import StockTradingEnv
21 | from env_single import StockTradingEnvSingle
22 |
23 | if __name__ == '__main__':
24 |
25 | # Initialize the hyper-parameter table
26 | init_model_hyper_parameters_table_sqlite()
27 |
28 | # Set when the program starts; do not change it afterwards
29 | config.SINGLE_A_STOCK_CODE = ['sh.600036', ]
30 |
31 | # Initial capital
32 | initial_capital = 150000
33 |
34 | # Maximum number of shares per single buy/sell
35 | max_stock = 3000
36 |
37 | initial_stocks_train = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32)
38 | initial_stocks_vali = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32)
39 |
40 | # Hold 0-3000 shares by default
41 | initial_stocks_train[0] = 3000.0
42 | initial_stocks_vali[0] = 3000.0
43 |
44 | if_on_policy = False
45 | # if_use_gae = True
46 |
47 | config.IF_ACTUAL_PREDICT = False
48 |
49 | config.START_DATE = "2003-05-01"
50 | config.START_EVAL_DATE = ""
51 |
52 | # Overall end date: today's date minus 60 working days
53 | predict_work_days = 60
54 |
55 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -predict_work_days))
56 | # config.END_DATE = '2021-07-21'
57 |
58 | # Update stock data, unadjusted prices, table name fe_fillzero
59 | StockData.update_stock_data_to_sqlite(list_stock_code=config.SINGLE_A_STOCK_CODE, adjustflag='3', table_name='fe_fillzero')
60 |
61 | # Work well: AgentPPO(), AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC()
62 | # AgentDoubleDQN works well in single-process mode?
63 | # Do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
64 |
65 | loop_index = 0
66 |
67 | # Main loop
68 | while True:
69 |
70 | # Clear the training-history table
71 | clear_train_history_table_sqlite()
72 |
73 | # From the model_hyper_parameters table, pick the record with the smallest training_times
74 | # Fetch the hyper-parameters
75 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
76 | eval_reward_scale, training_times, time_point = query_model_hyper_parameters_sqlite()
77 |
78 | if if_on_policy == 'True':
79 | if_on_policy = True
80 | else:
81 | if_on_policy = False
82 | pass
83 |
84 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
85 |
86 | # Derive the agent parameters
87 | agent_class = None
88 | train_reward_scaling = 2 ** train_reward_scale
89 | eval_reward_scaling = 2 ** eval_reward_scale
90 |
91 | # Model name
92 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
93 |
94 | if config.AGENT_NAME == 'AgentPPO':
95 | agent_class = AgentPPO()
96 | elif config.AGENT_NAME == 'AgentSAC':
97 | agent_class = AgentSAC()
98 | elif config.AGENT_NAME == 'AgentTD3':
99 | agent_class = AgentTD3()
100 | elif config.AGENT_NAME == 'AgentDDPG':
101 | agent_class = AgentDDPG()
102 | elif config.AGENT_NAME == 'AgentModSAC':
103 | agent_class = AgentModSAC()
104 | elif config.AGENT_NAME == 'AgentDuelingDQN':
105 | agent_class = AgentDuelingDQN()
106 | elif config.AGENT_NAME == 'AgentSharedSAC':
107 | agent_class = AgentSharedSAC()
108 | elif config.AGENT_NAME == 'AgentDoubleDQN':
109 | agent_class = AgentDoubleDQN()
110 | pass
111 |
112 | # Prediction period
113 | work_days = int(str(hyper_parameters_model_name).split('_')[1])
114 |
115 | # End date of the prediction
116 | end_vali_date = get_datetime_from_date_str(config.END_DATE)
117 |
118 | # Start date of the prediction
119 | begin_date = date_time.get_next_work_day(end_vali_date, next_flag=-work_days)
120 |
121 | # Update the work-day flag, used by run_single.py to load the trained weights file
122 | config.VALI_DAYS_FLAG = str(work_days)
123 |
124 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \
125 | f'/single_{config.VALI_DAYS_FLAG}'
126 |
127 | if not os.path.exists(model_folder_path):
128 | os.makedirs(model_folder_path)
129 | pass
130 |
131 | # Start date of the prediction
132 | config.START_EVAL_DATE = str(begin_date)
133 |
134 | print('\r\n')
135 | print('-' * 40)
136 | print('config.AGENT_NAME', config.AGENT_NAME)
137 | print('# 训练-预测周期', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE)
138 | print('# work_days', work_days)
139 | print('# model_folder_path', model_folder_path)
140 | print('# initial_capital', initial_capital)
141 | print('# max_stock', max_stock)
142 |
143 | args = Arguments(if_on_policy=if_on_policy)
144 | args.agent = agent_class
145 | # args.agent.if_use_gae = if_use_gae
146 | args.agent.lambda_entropy = 0.04
147 | args.gpu_id = 0
148 |
149 | tech_indicator_list = [
150 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
151 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
152 |
153 | gamma = 0.99
154 |
155 | print('# initial_stocks_train', initial_stocks_train)
156 | print('# initial_stocks_vali', initial_stocks_vali)
157 |
158 | buy_cost_pct = 0.003
159 | sell_cost_pct = 0.003
160 | start_date = config.START_DATE
161 | start_eval_date = config.START_EVAL_DATE
162 | end_eval_date = config.END_DATE
163 |
164 | # train
165 | args.env = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock,
166 | initial_capital=initial_capital,
167 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date,
168 | end_date=start_eval_date, env_eval_date=end_eval_date,
169 | ticker_list=config.SINGLE_A_STOCK_CODE,
170 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train,
171 | if_eval=False)
172 |
173 | # eval
174 | args.env_eval = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock,
175 | initial_capital=initial_capital,
176 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
177 | start_date=start_date,
178 | end_date=start_eval_date, env_eval_date=end_eval_date,
179 | ticker_list=config.SINGLE_A_STOCK_CODE,
180 | tech_indicator_list=tech_indicator_list,
181 | initial_stocks=initial_stocks_vali,
182 | if_eval=True)
183 |
184 | args.env.target_return = 10
185 | args.env_eval.target_return = 10
186 |
187 | # Reward scaling
188 | args.env.reward_scaling = train_reward_scaling
189 | args.env_eval.reward_scaling = eval_reward_scaling
190 |
191 | print('train reward_scaling', args.env.reward_scaling)
192 | print('eval reward_scaling', args.env_eval.reward_scaling)
193 |
194 | # Hyperparameters
195 | args.gamma = gamma
196 | # args.gamma = 0.99
197 |
198 | # reward_scaling is adjusted in args.env above, so leave this unchanged
199 | args.reward_scale = 2 ** 0
200 |
201 | # args.break_step = int(break_step / 30)
202 | args.break_step = break_step
203 |
204 | print('break_step', args.break_step)
205 |
206 | args.net_dim = 2 ** 9
207 | args.max_step = args.env.max_step
208 |
209 | # args.max_memo = args.max_step * 4
210 | args.max_memo = (args.max_step - 1) * 8
211 |
212 | args.batch_size = 2 ** 12
213 | # args.batch_size = 2305
214 | print('batch_size', args.batch_size)
215 |
216 | # ----
217 | # args.repeat_times = 2 ** 3
218 | args.repeat_times = 2 ** 4
219 | # ----
220 |
221 | args.eval_gap = 2 ** 4
222 | args.eval_times1 = 2 ** 3
223 | args.eval_times2 = 2 ** 5
224 |
225 | args.if_allow_break = False
226 |
227 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster)
228 |
229 | # train_and_evaluate(args)
230 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward.
231 |
232 | # Save the trained model
233 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth')
234 |
235 | # Keep an extra timestamped copy
236 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth',
237 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth')
238 |
239 | # Save the learning-curve plot
240 | # plot_learning_curve.jpg
241 | timepoint_temp = date_time.time_point()
242 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg'
243 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg',
244 | plot_learning_curve_file_path)
245 |
246 | # After training, increment training_times in the model_hyper_parameters table and refresh its time point.
247 |
248 | # Check whether the train_history table has records; if so, take the integer division by 256 plus 128, and subtract the corresponding value from the hyper-parameters in the model_hyper_parameters table.
249 |
250 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id,
251 | origin_train_reward_scale=train_reward_scale,
252 | origin_eval_reward_scale=eval_reward_scale,
253 | origin_training_times=training_times)
254 |
255 | print('>', config.AGENT_NAME, break_step, 'steps')
256 |
257 | # Loop counter
258 | loop_index += 1
259 |
260 | # # Touch all 5 models once, then exit
261 | # if loop_index == 5:
262 | # break
263 |
264 | # print('>', 'while 循环次数', loop_index, '\r\n')
265 |
266 | print('sleep 10 秒\r\n')
267 | time.sleep(10)
268 |
269 | # TODO train once, then exit
270 | break
271 |
272 | pass
273 |
--------------------------------------------------------------------------------
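The agent is chosen from `hyper_parameters_model_name` with an if/elif chain above. A lookup-table alternative with the same agent classes (purely a stylistic sketch, not how the repository does it):

```python
# Sketch: dictionary-based agent selection, equivalent to the if/elif chain in train_single.py.
from agent import (AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC,
                   AgentDuelingDQN, AgentSharedSAC, AgentDoubleDQN)

AGENT_CLASSES = {
    'AgentPPO': AgentPPO, 'AgentSAC': AgentSAC, 'AgentTD3': AgentTD3,
    'AgentDDPG': AgentDDPG, 'AgentModSAC': AgentModSAC,
    'AgentDuelingDQN': AgentDuelingDQN, 'AgentSharedSAC': AgentSharedSAC,
    'AgentDoubleDQN': AgentDoubleDQN,
}

def make_agent(model_name):
    """Instantiate the agent named by hyper_parameters_model_name, e.g. 'AgentPPO_60'."""
    return AGENT_CLASSES[model_name.split('_')[0]]()
```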
/train_batch_bak.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from datetime import datetime
5 |
6 | import config
7 | import numpy as np
8 | from stock_data import StockData
9 |
10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \
11 | AgentDoubleDQN
12 |
13 | from run_batch import Arguments, train_and_evaluate_mp, train_and_evaluate
14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \
15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite
16 | from utils import date_time
17 |
18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, \
19 | get_today_date
20 | from env_batch import StockTradingEnvBatch
21 |
22 | if __name__ == '__main__':
23 | # Start time of the run
24 | time_begin = datetime.now()
25 |
26 | # Initialize the hyper-parameter table
27 | init_model_hyper_parameters_table_sqlite()
28 |
29 | fe_table_name = 'fe_fillzero_train'
30 |
31 | # The 2003 group, represented by sz.000028
32 | # Do not change the order of the stocks
33 | # config.BATCH_A_STOCK_CODE = ['sz.000028', 'sh.600585', 'sz.000538', 'sh.600036']
34 | config.START_DATE = "2004-05-01"
35 |
36 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275')
37 |
38 | # Initial capital: 150,000 yuan per stock
39 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE)
40 |
41 | # Maximum number of shares per single buy/sell
42 | # TODO derive the per-trade maximum from the average of the latest closing prices of the stocks
43 | max_stock = 3000
44 |
45 | initial_stocks_train = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
46 | initial_stocks_vali = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
47 |
48 | # Hold 0-3000 shares by default
49 | initial_stocks_train = max_stock * initial_stocks_train
50 | initial_stocks_vali = max_stock * initial_stocks_vali
51 |
52 | # if_on_policy = False
53 |
54 | config.IF_ACTUAL_PREDICT = False
55 |
56 | config.START_EVAL_DATE = ""
57 |
58 | # Overall end date: today's date minus 60 working days
59 | predict_work_days = 60
60 |
61 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -predict_work_days))
62 | # config.END_DATE = '2021-07-21'
63 |
64 | # TODO because the data volume is large, update it in stock_data.py instead
65 | # # Update stock data, unadjusted prices
66 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3',
67 | # table_name=fe_table_name, if_incremental_update=False)
68 |
69 | # Work well: AgentPPO(), AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC()
70 | # AgentDoubleDQN works well in single-process mode?
71 | # Do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
72 |
73 | loop_index = 0
74 |
75 | # Main loop
76 | while True:
77 |
78 | # Clear the training-history table
79 | clear_train_history_table_sqlite()
80 |
81 | # From the model_hyper_parameters table, pick the record with the smallest training_times
82 | # Fetch the hyper-parameters
83 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
84 | eval_reward_scale, training_times, time_point = query_model_hyper_parameters_sqlite()
85 |
86 | if if_on_policy == 'True':
87 | if_on_policy = True
88 | else:
89 | if_on_policy = False
90 | pass
91 |
92 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
93 |
94 | # Derive the agent parameters
95 | agent_class = None
96 | train_reward_scaling = 2 ** train_reward_scale
97 | eval_reward_scaling = 2 ** eval_reward_scale
98 |
99 | # Model name
100 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
101 |
102 | if config.AGENT_NAME == 'AgentPPO':
103 | agent_class = AgentPPO()
104 | elif config.AGENT_NAME == 'AgentSAC':
105 | agent_class = AgentSAC()
106 | elif config.AGENT_NAME == 'AgentTD3':
107 | agent_class = AgentTD3()
108 | elif config.AGENT_NAME == 'AgentDDPG':
109 | agent_class = AgentDDPG()
110 | elif config.AGENT_NAME == 'AgentModSAC':
111 | agent_class = AgentModSAC()
112 | elif config.AGENT_NAME == 'AgentDuelingDQN':
113 | agent_class = AgentDuelingDQN()
114 | elif config.AGENT_NAME == 'AgentSharedSAC':
115 | agent_class = AgentSharedSAC()
116 | elif config.AGENT_NAME == 'AgentDoubleDQN':
117 | agent_class = AgentDoubleDQN()
118 | pass
119 |
120 | # Prediction period
121 | work_days = int(str(hyper_parameters_model_name).split('_')[1])
122 |
123 | # End date of the prediction
124 | end_vali_date = get_datetime_from_date_str(config.END_DATE)
125 |
126 | # Start date of the prediction
127 | begin_date = date_time.get_next_work_day(end_vali_date, next_flag=-work_days)
128 |
129 | # Update the work-day flag, used by run_single.py to load the trained weights file
130 | config.VALI_DAYS_FLAG = str(work_days)
131 |
132 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
133 | f'batch_{config.VALI_DAYS_FLAG}'
134 |
135 | if not os.path.exists(model_folder_path):
136 | os.makedirs(model_folder_path)
137 | pass
138 |
139 | # Start date of the prediction
140 | config.START_EVAL_DATE = str(begin_date)
141 |
142 | print('\r\n')
143 | print('-' * 40)
144 | print('config.AGENT_NAME', config.AGENT_NAME)
145 | print('# 训练-预测周期', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE)
146 | print('# work_days', work_days)
147 | print('# model_folder_path', model_folder_path)
148 | print('# initial_capital', initial_capital)
149 | print('# max_stock', max_stock)
150 |
151 | args = Arguments(if_on_policy=if_on_policy)
152 | args.agent = agent_class
153 | # args.agent.if_use_gae = if_use_gae
154 | args.agent.lambda_entropy = 0.04
155 | args.gpu_id = 0
156 |
157 | tech_indicator_list = [
158 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
159 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
160 |
161 | gamma = 0.99
162 |
163 | # print('# initial_stocks_train', initial_stocks_train)
164 | # print('# initial_stocks_vali', initial_stocks_vali)
165 |
166 | buy_cost_pct = 0.003
167 | sell_cost_pct = 0.003
168 | start_date = config.START_DATE
169 | start_eval_date = config.START_EVAL_DATE
170 | end_eval_date = config.END_DATE
171 |
172 | # train
173 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
174 | initial_capital=initial_capital,
175 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date,
176 | end_date=start_eval_date, env_eval_date=end_eval_date,
177 | ticker_list=config.BATCH_A_STOCK_CODE,
178 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train,
179 | if_eval=False, fe_table_name=fe_table_name)
180 |
181 | # eval
182 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
183 | initial_capital=initial_capital,
184 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
185 | start_date=start_date,
186 | end_date=start_eval_date, env_eval_date=end_eval_date,
187 | ticker_list=config.BATCH_A_STOCK_CODE,
188 | tech_indicator_list=tech_indicator_list,
189 | initial_stocks=initial_stocks_vali,
190 | if_eval=True, fe_table_name=fe_table_name)
191 |
192 | args.env.target_return = 10
193 | args.env_eval.target_return = 10
194 |
195 | # Reward scaling
196 | args.env.reward_scale = train_reward_scaling
197 | args.env_eval.reward_scale = eval_reward_scaling
198 |
199 | print('train reward_scaling', args.env.reward_scale)
200 | print('eval reward_scaling', args.env_eval.reward_scale)
201 |
202 | # Hyperparameters
203 | args.gamma = gamma
204 | # args.gamma = 0.99
205 |
206 | # reward_scaling is adjusted in args.env above, so leave this unchanged
207 | args.reward_scale = 2 ** 0
208 |
209 | # args.break_step = int(break_step / 30)
210 | args.break_step = break_step
211 |
212 | print('break_step', args.break_step)
213 |
214 | args.net_dim = 2 ** 9
215 | args.max_step = args.env.max_step
216 |
217 | # args.max_memo = args.max_step * 4
218 | args.max_memo = (args.max_step - 1) * 8
219 |
220 | args.batch_size = 2 ** 12
221 | # args.batch_size = 2305
222 | print('batch_size', args.batch_size)
223 |
224 | # ----
225 | # args.repeat_times = 2 ** 3
226 | args.repeat_times = 2 ** 4
227 | # ----
228 |
229 | args.eval_gap = 2 ** 4
230 | args.eval_times1 = 2 ** 3
231 | args.eval_times2 = 2 ** 5
232 |
233 | args.if_allow_break = False
234 |
235 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster)
236 |
237 | # train_and_evaluate(args)
238 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward.
239 |
240 | # Save the trained model
241 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth')
242 |
243 | # Keep an extra timestamped copy
244 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth',
245 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth')
246 |
247 | # Save the learning-curve plot
248 | # plot_learning_curve.jpg
249 | timepoint_temp = date_time.time_point()
250 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg'
251 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg',
252 | plot_learning_curve_file_path)
253 |
254 | # After training, increment training_times in the model_hyper_parameters table and refresh its time point.
255 |
256 | # Check whether the train_history table has records; if so, take the integer division by 256 plus 128, and subtract the corresponding value from the hyper-parameters in the model_hyper_parameters table.
257 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id,
258 | origin_train_reward_scale=train_reward_scale,
259 | origin_eval_reward_scale=eval_reward_scale,
260 | origin_training_times=training_times)
261 |
262 | print('>', config.AGENT_NAME, break_step, 'steps')
263 |
264 | # End time of the run
265 | time_end = datetime.now()
266 | duration = (time_end - time_begin).total_seconds()
267 | print('检测耗时', duration, '秒')
268 |
269 | # Loop counter
270 | loop_index += 1
271 | print('>', 'while 循环次数', loop_index, '\r\n')
272 |
273 | print('sleep 10 秒\r\n')
274 | time.sleep(10)
275 |
276 | # TODO train once, then exit
277 | # break
278 |
279 | pass
280 |
--------------------------------------------------------------------------------
/train_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from datetime import datetime
5 |
6 | import config
7 | import numpy as np
8 | from stock_data import StockData
9 |
10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \
11 | AgentDoubleDQN
12 |
13 | from run_batch import Arguments, train_and_evaluate_mp, train_and_evaluate
14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \
15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite
16 | from utils import date_time
17 |
18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, get_today_date
19 | from env_batch import StockTradingEnvBatch
20 |
21 | if __name__ == '__main__':
22 | # Start time of the run
23 | time_begin = datetime.now()
24 |
25 | # Initialize the hyper-parameter table
26 | init_model_hyper_parameters_table_sqlite()
27 |
28 | fe_table_name = 'fe_fillzero_train'
29 |
30 | # Do not change the order of the stocks
31 | # config.BATCH_A_STOCK_CODE = ['sz.000028', 'sh.600585', 'sz.000538', 'sh.600036']
32 | config.START_DATE = "2004-05-01"
33 |
34 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275')
35 |
36 | # Initial capital: 150,000 yuan per stock
37 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE)
38 |
39 | # Maximum number of shares per single buy/sell
40 | max_stock = 50000
41 |
42 | initial_stocks_train = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
43 | initial_stocks_vali = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
44 |
45 | # Default holdings: max_stock shares per stock
46 | initial_stocks_train = max_stock * initial_stocks_train
47 | initial_stocks_vali = max_stock * initial_stocks_vali
48 |
49 | config.IF_ACTUAL_PREDICT = False
50 |
51 | config.START_EVAL_DATE = ""
52 |
53 | # TODO because the data volume is large, update it in stock_data.py instead
54 | # # Update stock data, unadjusted prices
55 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3',
56 | # table_name=fe_table_name, if_incremental_update=False)
57 |
58 | # work well: AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(),
59 | # AgentDoubleDQN works well in single-process mode?
60 | # do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
61 |
62 | loop_index = 0
63 |
64 | # main loop
65 | while True:
66 |
67 | # clear the training history table
68 | clear_train_history_table_sqlite()
69 |
70 | # pick the record with the smallest training_times from the model_hyper_parameters table
71 | # fetch the hyper-parameters
72 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
73 | eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, \
74 | state_tech_scale = query_model_hyper_parameters_sqlite()
75 |
76 | if if_on_policy == 'True':
77 | if_on_policy = True
78 | else:
79 | if_on_policy = False
80 | pass
81 |
82 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
83 |
84 | # select the Agent class
85 | agent_class = None
86 |
87 | # model name
88 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
89 | # model prediction period
90 | config.AGENT_WORK_DAY = int(str(hyper_parameters_model_name).split('_')[1])
91 |
92 | if config.AGENT_NAME == 'AgentPPO':
93 | agent_class = AgentPPO()
94 | elif config.AGENT_NAME == 'AgentSAC':
95 | agent_class = AgentSAC()
96 | elif config.AGENT_NAME == 'AgentTD3':
97 | agent_class = AgentTD3()
98 | elif config.AGENT_NAME == 'AgentDDPG':
99 | agent_class = AgentDDPG()
100 | elif config.AGENT_NAME == 'AgentModSAC':
101 | agent_class = AgentModSAC()
102 | elif config.AGENT_NAME == 'AgentDuelingDQN':
103 | agent_class = AgentDuelingDQN()
104 | elif config.AGENT_NAME == 'AgentSharedSAC':
105 | agent_class = AgentSharedSAC()
106 | elif config.AGENT_NAME == 'AgentDoubleDQN':
107 | agent_class = AgentDoubleDQN()
108 | pass
109 |
110 | # update the workday flag, used by run_single.py to load the trained weights file
111 | config.VALI_DAYS_FLAG = str(config.AGENT_WORK_DAY)
112 |
113 | # TODO overall end date: today's date, reserving 60 workdays for validating predictions
114 |
115 | # config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()),
116 | # -1 * config.AGENT_WORK_DAY))
117 |
118 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -60))
119 |
120 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
121 | f'batch_{config.VALI_DAYS_FLAG}'
122 |
123 | if not os.path.exists(model_folder_path):
124 | os.makedirs(model_folder_path)
125 | pass
126 |
127 | # start date of the evaluation period
128 | config.START_EVAL_DATE = str(
129 | get_next_work_day(get_datetime_from_date_str(get_today_date()), -2 * config.AGENT_WORK_DAY))
130 |
131 | print('\r\n')
132 | print('-' * 40)
133 | print('config.AGENT_NAME', config.AGENT_NAME)
134 | print('# train-predict period', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE)
135 | print('# work_days', config.AGENT_WORK_DAY)
136 | print('# model_folder_path', model_folder_path)
137 | print('# initial_capital', initial_capital)
138 | print('# max_stock', max_stock)
139 |
140 | args = Arguments(if_on_policy=if_on_policy)
141 | args.agent = agent_class
142 | # args.agent.if_use_gae = if_use_gae
143 | args.agent.lambda_entropy = 0.04
144 | args.gpu_id = 0
145 |
146 | tech_indicator_list = [
147 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
148 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
149 |
150 | gamma = 0.99
151 |
152 | buy_cost_pct = 0.003
153 | sell_cost_pct = 0.003
154 | start_date = config.START_DATE
155 | start_eval_date = config.START_EVAL_DATE
156 | end_eval_date = config.END_DATE
157 |
158 | # train
159 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
160 | initial_capital=initial_capital,
161 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date,
162 | end_date=start_eval_date, env_eval_date=end_eval_date,
163 | ticker_list=config.BATCH_A_STOCK_CODE,
164 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train,
165 | if_eval=False, fe_table_name=fe_table_name)
166 |
167 | # eval
168 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
169 | initial_capital=initial_capital,
170 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
171 | start_date=start_date,
172 | end_date=start_eval_date, env_eval_date=end_eval_date,
173 | ticker_list=config.BATCH_A_STOCK_CODE,
174 | tech_indicator_list=tech_indicator_list,
175 | initial_stocks=initial_stocks_vali,
176 | if_eval=True, fe_table_name=fe_table_name)
177 |
178 | args.env.target_return = 100
179 | args.env_eval.target_return = 100
180 |
181 | # reward scale
182 | args.env.reward_scale = train_reward_scale
183 | args.env_eval.reward_scale = eval_reward_scale
184 |
185 | args.env.state_amount_scale = state_amount_scale
186 | args.env.state_price_scale = state_price_scale
187 | args.env.state_stocks_scale = state_stocks_scale
188 | args.env.state_tech_scale = state_tech_scale
189 |
190 | args.env_eval.state_amount_scale = state_amount_scale
191 | args.env_eval.state_price_scale = state_price_scale
192 | args.env_eval.state_stocks_scale = state_stocks_scale
193 | args.env_eval.state_tech_scale = state_tech_scale
194 |
195 | print('train reward_scale', args.env.reward_scale)
196 | print('eval reward_scale', args.env_eval.reward_scale)
197 |
198 | print('state_amount_scale', state_amount_scale)
199 | print('state_price_scale', state_price_scale)
200 | print('state_stocks_scale', state_stocks_scale)
201 | print('state_tech_scale', state_tech_scale)
202 |
203 | # Hyperparameters
204 | args.gamma = gamma
205 | # args.gamma = 0.99
206 |
207 | # reward_scaling is adjusted inside args.env, leave it unchanged here
208 | # args.reward_scale = 2 ** 0
209 | args.reward_scale = 1
210 |
211 | # args.break_step = int(break_step / 30)
212 | args.break_step = break_step
213 |
214 | print('break_step', args.break_step)
215 |
216 | args.net_dim = 2 ** 9
217 | args.max_step = args.env.max_step
218 |
219 | # args.max_memo = args.max_step * 4
220 | args.max_memo = (args.max_step - 1) * 8
221 |
222 | args.batch_size = 2 ** 12
223 | # args.batch_size = 2305
224 | print('batch_size', args.batch_size)
225 |
226 | # ----
227 | # args.repeat_times = 2 ** 3
228 | args.repeat_times = 2 ** 4
229 | # ----
230 |
231 | args.eval_gap = 2 ** 4
232 | args.eval_times1 = 2 ** 3
233 | args.eval_times2 = 2 ** 5
234 |
235 | args.if_allow_break = False
236 |
237 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster)
238 |
239 | # train_and_evaluate(args)
240 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward.
241 |
242 | # save the trained model
243 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth')
244 |
245 | # keep an extra copy
246 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth',
247 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth')
248 |
249 | # save the training learning-curve plot
250 | # plot_learning_curve.jpg
251 | timepoint_temp = date_time.time_point()
252 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg'
253 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg',
254 | plot_learning_curve_file_path)
255 |
256 | # After training, increment training_times in the model_hyper_parameters table by 1 and refresh its time_point.
257 |
258 | # Check whether the train_history table has records; if so, floor-divide by 256 + 128 and subtract the corresponding amount from the hyper-parameters in the model_hyper_parameters table.
259 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id,
260 | origin_train_reward_scale=train_reward_scale,
261 | origin_eval_reward_scale=eval_reward_scale,
262 | origin_training_times=training_times,
263 | origin_state_amount_scale=state_amount_scale,
264 | origin_state_price_scale=state_price_scale,
265 | origin_state_stocks_scale=state_stocks_scale,
266 | origin_state_tech_scale=state_tech_scale)
267 |
268 | print('>', config.AGENT_NAME, break_step, 'steps')
269 |
270 | # end time of this run
271 | time_end = datetime.now()
272 | duration = (time_end - time_begin).total_seconds()
273 | print('elapsed time', duration, 'seconds')
274 |
275 | # loop counter
276 | loop_index += 1
277 | print('>', 'while loop count', loop_index, '\r\n')
278 |
279 | print('sleep 10 seconds\r\n')
280 | time.sleep(10)
281 |
282 | # TODO exit after a single training pass
283 | break
284 |
285 | pass
286 |
--------------------------------------------------------------------------------
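Note: the scale columns read from model_hyper_parameters (train_reward_scale, eval_reward_scale and the four state_*_scale values) are stored as base-2 exponents, e.g. -12 stands for a multiplier of 2 ** -12; predict_single_psql.py converts train_reward_scale with exactly that expression. A minimal sketch of how such an exponent would be applied, assuming the StockTradingEnv* classes multiply the raw value by 2 ** exponent:

# Minimal sketch: applying a base-2 scale exponent from model_hyper_parameters.
def apply_scale(raw_value: float, exponent: int) -> float:
    # assumption: the environment's effective multiplier is 2 ** exponent
    return raw_value * (2 ** exponent)


# Example: a raw reward of 4096.0 with train_reward_scale = -12 becomes 1.0.
assert apply_scale(4096.0, -12) == 1.0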
/train_helper.py:
--------------------------------------------------------------------------------
1 | import config
2 | from utils import date_time
3 | from utils.date_time import get_next_work_day
4 | from utils.sqlite import SQLite
5 |
6 |
7 | def init_model_hyper_parameters_table_sqlite():
8 | # initialize the model hyper-parameter table model_hyper_parameters and the training history table train_history
9 |
10 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S')
11 |
12 | # connect to the database
13 | db_path = config.STOCK_DB_PATH
14 |
15 | # table name
16 | table_name = 'model_hyper_parameters'
17 |
18 | sqlite = SQLite(db_path)
19 |
20 | if_exists = sqlite.table_exists(table_name)
21 |
22 | if if_exists is None:
23 | # create the table on first initialization
24 | sqlite.execute_non_query(sql=f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, '
25 | f'model_name TEXT NOT NULL UNIQUE, if_on_policy INTEGER NOT NULL, '
26 | f'break_step INTEGER NOT NULL, train_reward_scale INTEGER NOT NULL, '
27 | f'eval_reward_scale INTEGER NOT NULL, training_times INTEGER NOT NULL, '
28 | f'state_amount_scale INTEGER NOT NULL, state_price_scale INTEGER NOT NULL, '
29 | f'state_stocks_scale INTEGER NOT NULL, state_tech_scale INTEGER NOT NULL, '
30 | f'time_point TEXT NOT NULL, if_active INTEGER NOT NULL);')
31 |
32 | # commit
33 | sqlite.commit()
34 | pass
35 |
36 | # initialize default values
37 | for agent_item in ['AgentSAC', 'AgentPPO', 'AgentTD3', 'AgentDDPG', 'AgentModSAC', ]:
38 |
39 | if agent_item == 'AgentPPO':
40 | if_on_policy = 1
41 | break_step = 50000
42 | elif agent_item == 'AgentModSAC':
43 | if_on_policy = 0
44 | break_step = 50000
45 | else:
46 | if_on_policy = 0
47 | break_step = 5000
48 | pass
49 |
50 | # e.g. for 2 ** -6 only the exponent -6 is stored in the database
51 | train_reward_scale = -12
52 | eval_reward_scale = -8
53 |
54 | # number of training runs
55 | training_times = 0
56 |
57 | state_amount_scale = -25
58 | state_price_scale = -9
59 | state_stocks_scale = -16
60 | state_tech_scale = -9
61 |
62 | for work_days in [300, ]:
63 | # insert the default hyper-parameter row for this model
64 | sql_cmd = f'INSERT INTO "{table_name}" (model_name, if_on_policy, break_step, ' \
65 | f'train_reward_scale, eval_reward_scale, ' \
66 | f'state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale,' \
67 | f'training_times, time_point, if_active) ' \
68 | f'VALUES (?,?,?,?,?,?,?,?,?,?,?,1)'
69 |
70 | sql_values = (agent_item + '_' + str(work_days), if_on_policy, break_step,
71 | train_reward_scale, eval_reward_scale,
72 | state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale,
73 | training_times, time_point)
74 |
75 | sqlite.execute_non_query(sql_cmd, sql_values)
76 |
77 | pass
78 | pass
79 |
80 | # commit
81 | sqlite.commit()
82 | pass
83 |
84 | # table name
85 | table_name = 'train_history'
86 |
87 | if_exists = sqlite.table_exists(table_name)
88 |
89 | if if_exists is None:
90 | # create the table on first initialization
91 | sqlite.execute_non_query(sql=f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, '
92 | f'model_id TEXT NOT NULL, '
93 | f'train_reward_value NUMERIC NOT NULL, eval_reward_value NUMERIC NOT NULL, '
94 | f'state_amount_value NUMERIC NOT NULL, state_price_value NUMERIC NOT NULL, '
95 | f'state_stocks_value NUMERIC NOT NULL, state_tech_value NUMERIC NOT NULL, '
96 | f'time_point TEXT NOT NULL);')
97 |
98 | # commit
99 | sqlite.commit()
100 | pass
101 | pass
102 |
103 | sqlite.close()
104 |
105 | pass
106 |
107 |
108 | def query_model_hyper_parameters_sqlite(model_name=None):
109 | # query the model hyper-parameters by model_name
110 |
111 | # connect to the database
112 | db_path = config.STOCK_DB_PATH
113 |
114 | # table name
115 | table_name = 'model_hyper_parameters'
116 |
117 | sqlite = SQLite(db_path)
118 |
119 | # 'state_amount_scale INTEGER NOT NULL, state_price_scale INTEGER NOT NULL, '
120 | # 'state_stocks_scale INTEGER NOT NULL, state_tech_scale INTEGER NOT NULL, '
121 |
122 | if model_name is None:
123 | query_sql = f'SELECT id, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, ' \
124 | f'training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, ' \
125 | f'state_tech_scale FROM "{table_name}" WHERE if_active=1 ' \
126 | f' ORDER BY training_times ASC LIMIT 1'
127 | else:
128 | # unique record
129 | query_sql = f'SELECT id, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, ' \
130 | f'training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, ' \
131 | f'state_tech_scale FROM "{table_name}" WHERE model_name="{model_name}"' \
132 | f' LIMIT 1'
133 | pass
134 |
135 | id1, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, training_times, time_point, \
136 | state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale = sqlite.fetchone(query_sql)
137 |
138 | sqlite.close()
139 |
140 | return id1, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, training_times, \
141 | time_point, state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale
142 |
143 |
144 | def update_model_hyper_parameters_table_sqlite(model_hyper_parameters_id, train_reward_scale, eval_reward_scale,
145 | training_times, state_amount_scale, state_price_scale,
146 | state_stocks_scale, state_tech_scale):
147 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S')
148 |
149 | # update the hyper-parameter table
150 | # connect to the database
151 | db_path = config.STOCK_DB_PATH
152 |
153 | # table name
154 | table_name = 'model_hyper_parameters'
155 |
156 | sqlite = SQLite(db_path)
157 |
158 | # update the hyper-parameter row
159 | sqlite.execute_non_query(sql=f'UPDATE "{table_name}" SET train_reward_scale={train_reward_scale}, '
160 | f'eval_reward_scale={eval_reward_scale}, training_times={training_times}, '
161 | f'time_point="{time_point}", state_amount_scale={state_amount_scale}, '
162 | f'state_price_scale={state_price_scale}, state_stocks_scale={state_stocks_scale}, '
163 | f'state_tech_scale={state_tech_scale} WHERE id={model_hyper_parameters_id}')
164 | # commit
165 | sqlite.commit()
166 |
167 | sqlite.close()
168 | pass
169 |
170 |
171 | def clear_train_history_table_sqlite():
172 | # clear the training history table
173 | # connect to the database
174 | db_path = config.STOCK_DB_PATH
175 |
176 | # 表名
177 | table_name = 'train_history'
178 |
179 | sqlite = SQLite(db_path)
180 |
181 | sqlite.execute_non_query(sql=f'DELETE FROM "{table_name}"')
182 | # commit
183 | sqlite.commit()
184 | pass
185 |
186 | sqlite.close()
187 | pass
188 |
189 |
190 | def insert_train_history_record_sqlite(model_id, train_reward_value=0.0, eval_reward_value=0.0,
191 | state_amount_value=0.0, state_price_value=0.0, state_stocks_value=0.0,
192 | state_tech_value=0.0):
193 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S')
194 |
195 | # insert a training history record
196 | # connect to the database
197 | db_path = config.STOCK_DB_PATH
198 |
199 | # table name
200 | table_name = 'train_history'
201 |
202 | sqlite = SQLite(db_path)
203 |
204 | sql_cmd = f'INSERT INTO "{table_name}" ' \
205 | f'(model_id, train_reward_value, eval_reward_value, time_point, ' \
206 | f'state_amount_value, state_price_value, state_stocks_value, state_tech_value) VALUES (?,?,?,?,?,?,?,?);'
207 |
208 | sql_values = (model_id, train_reward_value, eval_reward_value, time_point,
209 | state_amount_value, state_price_value, state_stocks_value, state_tech_value)
210 |
211 | sqlite.execute_non_query(sql_cmd, sql_values)
212 |
213 | # commit
214 | sqlite.commit()
215 |
216 | sqlite.close()
217 |
218 | pass
219 |
220 |
221 | def loop_scale_one(max_state_value, origin_state_scale):
222 | if max_state_value is None:
223 | new_state_scale = origin_state_scale
224 | else:
225 | i = 0
226 | while max_state_value >= 1.0:
227 | max_state_value = max_state_value / 2
228 | i += 1
229 | pass
230 | new_state_scale = origin_state_scale - i
231 | pass
232 |
233 | return new_state_scale
234 |
235 |
236 | def update_model_hyper_parameters_by_train_history(model_hyper_parameters_id, origin_train_reward_scale,
237 | origin_eval_reward_scale, origin_training_times,
238 | origin_state_amount_scale, origin_state_price_scale,
239 | origin_state_stocks_scale, origin_state_tech_scale):
240 | # update the hyper-parameter table based on the reward history
241 | # read back the maxima recorded in train_history
242 | # connect to the database
243 | db_path = config.STOCK_DB_PATH
244 |
245 | # table name
246 | table_name = 'train_history'
247 |
248 | sqlite = SQLite(db_path)
249 |
250 | query_sql = f'SELECT MAX(train_reward_value), MAX(eval_reward_value), ' \
251 | f'MAX(state_amount_value), MAX(state_price_value), ' \
252 | f'MAX(state_stocks_value), MAX(state_tech_value) FROM "{table_name}" ' \
253 | f' WHERE model_id="{model_hyper_parameters_id}"'
254 |
255 | max_train_reward_value, max_eval_reward_value, max_state_amount_value, max_state_price_value, \
256 | max_state_stocks_value, max_state_tech_value = sqlite.fetchone(query_sql)
257 |
258 | sqlite.close()
259 |
260 | # reward threshold
261 | reward_threshold = config.REWARD_THRESHOLD
262 |
263 | if max_train_reward_value is None:
264 | new_train_reward_scale = origin_train_reward_scale
265 | print('> keep origin train_reward_scale', new_train_reward_scale)
266 | pass
267 | else:
268 | if max_train_reward_value >= reward_threshold:
269 | new_train_reward_scale = origin_train_reward_scale - (max_train_reward_value // reward_threshold)
270 | print('> modify train_reward_scale:', origin_train_reward_scale, '->', new_train_reward_scale)
271 | else:
272 | new_train_reward_scale = origin_train_reward_scale
273 | print('> keep origin train_reward_scale', new_train_reward_scale)
274 | pass
275 | pass
276 |
277 | if max_eval_reward_value is None:
278 | new_eval_reward_scale = origin_eval_reward_scale
279 | print('> keep origin eval_reward_scale', new_eval_reward_scale)
280 | pass
281 | else:
282 | if max_eval_reward_value >= reward_threshold:
283 | new_eval_reward_scale = origin_eval_reward_scale - (max_eval_reward_value // reward_threshold)
284 |
285 | print('> modify eval_reward_scale:', origin_eval_reward_scale, '->', new_eval_reward_scale)
286 | pass
287 | else:
288 | new_eval_reward_scale = origin_eval_reward_scale
289 |
290 | print('> keep origin eval_reward_scale', new_eval_reward_scale)
291 | pass
292 | pass
293 | pass
294 |
295 | new_state_amount_scale = loop_scale_one(max_state_amount_value, origin_state_amount_scale)
296 | if new_state_amount_scale == origin_state_amount_scale:
297 | print('> keep origin state_amount_scale', new_state_amount_scale)
298 | else:
299 | print('> modify state_amount_scale:', origin_state_amount_scale, '->', new_state_amount_scale)
300 | pass
301 |
302 | new_state_price_scale = loop_scale_one(max_state_price_value, origin_state_price_scale)
303 | if new_state_price_scale == origin_state_price_scale:
304 | print('> keep origin state_price_scale', new_state_price_scale)
305 | else:
306 | print('> modify state_price_scale:', origin_state_price_scale, '->', new_state_price_scale)
307 | pass
308 |
309 | new_state_stocks_scale = loop_scale_one(max_state_stocks_value, origin_state_stocks_scale)
310 | if new_state_stocks_scale == origin_state_stocks_scale:
311 | print('> keep origin state_stocks_scale', new_state_stocks_scale)
312 | else:
313 | print('> modify state_stocks_scale:', origin_state_stocks_scale, '->', new_state_stocks_scale)
314 | pass
315 |
316 | new_state_tech_scale = loop_scale_one(max_state_tech_value, origin_state_tech_scale)
317 | if new_state_tech_scale == origin_state_tech_scale:
318 | print('> keep origin state_tech_scale', new_state_tech_scale)
319 | else:
320 | print('> modify state_tech_scale:', origin_state_tech_scale, '->', new_state_tech_scale)
321 | pass
322 |
323 | # update the hyper-parameter table
324 | update_model_hyper_parameters_table_sqlite(model_hyper_parameters_id=model_hyper_parameters_id,
325 | train_reward_scale=new_train_reward_scale,
326 | eval_reward_scale=new_eval_reward_scale,
327 | training_times=origin_training_times + 1,
328 | state_amount_scale=new_state_amount_scale,
329 | state_price_scale=new_state_price_scale,
330 | state_stocks_scale=new_state_stocks_scale,
331 | state_tech_scale=new_state_tech_scale)
332 |
333 | pass
334 |
335 |
336 | def query_begin_vali_date_list_by_agent_name(agent_name, end_vali_date):
337 | list_result = list()
338 |
339 | # connect to the database
340 | db_path = config.STOCK_DB_PATH
341 |
342 | # table name
343 | table_name = 'model_hyper_parameters'
344 |
345 | sqlite = SQLite(db_path)
346 |
347 | query_sql = f'SELECT model_name FROM "{table_name}" WHERE if_active=1 AND model_name LIKE "{agent_name}%" ' \
348 | f' ORDER BY model_name ASC'
349 |
350 | list_temp = sqlite.fetchall(query_sql)
351 |
352 | sqlite.close()
353 |
354 | for work_days in list_temp:
355 | # AgentSAC_60 --> 60
356 | work_days = int(str(work_days[0]).split('_')[1])
357 | begin_vali_date = get_next_work_day(end_vali_date, next_flag=-work_days)
358 | list_result.append((work_days, begin_vali_date))
359 | pass
360 |
361 | list_temp.clear()
362 |
363 | return list_result
364 |
--------------------------------------------------------------------------------
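Note: loop_scale_one above implements the automatic state rescaling: it halves the largest value recorded in train_history until the value drops below 1.0 and lowers the stored exponent by the number of halvings. A short worked example (assuming train_helper.py is importable):

from train_helper import loop_scale_one

# 5.3 needs three halvings (5.3 -> 2.65 -> 1.325 -> 0.6625), so -25 becomes -28.
assert loop_scale_one(5.3, -25) == -28
# No recorded maximum: keep the original exponent.
assert loop_scale_one(None, -25) == -25
# Already below 1.0: exponent unchanged.
assert loop_scale_one(0.8, -9) == -9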
/predict_single_psql.py:
--------------------------------------------------------------------------------
1 | from stock_data import StockData
2 | # from train_single import get_agent_args
3 | from train_helper import query_model_hyper_parameters_sqlite, query_begin_vali_date_list_by_agent_name
4 | from utils.psqldb import Psqldb
5 | from agent import *
6 | from utils.date_time import *
7 | from env_single import StockTradingEnvSingle, FeatureEngineer
8 | from run_single import *
9 | from datetime import datetime
10 |
11 |
12 | def calc_max_return(price_ary, initial_capital_temp):
13 | max_return_temp = 0
14 | min_value = 0
15 |
16 | assert price_ary.shape[0] > 1
17 |
18 | count_price = price_ary.shape[0]
19 |
20 | for index_left in range(0, count_price - 1):
21 |
22 | for index_right in range(index_left + 1, count_price):
23 |
24 | assert price_ary[index_left][0] > 0
25 |
26 | assert price_ary[index_right][0] > 0
27 |
28 | temp_value = price_ary[index_right][0] - price_ary[index_left][0]
29 |
30 | if temp_value > max_return_temp:
31 | max_return_temp = temp_value
32 | # max_value = price_ary[index1][0]
33 | min_value = price_ary[index_left][0]  # buy price (the lower leg of the best trade)
34 | pass
35 | pass
36 |
37 | if min_value == 0:
38 | ret = 0
39 | else:
40 | ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp
41 |
42 | return ret
43 |
44 |
45 | if __name__ == '__main__':
46 | # run prediction and save the results to the postgresql database
47 | # start time of this run
48 | time_begin = datetime.now()
49 |
50 | config.OUTPUT_DATE = '2021-08-03'
51 |
52 | initial_capital = 150000
53 |
54 | max_stock = 3000
55 |
56 | # for tic_item in ['sh.600036', 'sh.600667']:
57 | # loop
58 | for tic_item in ['sh.600036', ]:
59 |
60 | # the stock code to predict
61 | config.SINGLE_A_STOCK_CODE = [tic_item, ]
62 |
63 | # PostgreSQL connection object
64 | psql_object = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER,
65 | password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT)
66 |
67 | # work well: AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(),
68 | # AgentDoubleDQN works well in single-process mode?
69 | # do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
70 | for agent_item in ['AgentSAC', 'AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentModSAC', ]:
71 |
72 | config.AGENT_NAME = agent_item
73 | # config.CWD = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1'
74 |
75 | break_step = int(3e6)
76 |
77 | if_on_policy = False
78 | # if_use_gae = False
79 |
80 | # the prediction start and end dates are both fixed
81 |
82 | # date list
83 | # from April 16 going back: periods of 20, 30, 40, 50, 60, 72, 90 days
84 | # end_vali_date = get_datetime_from_date_str('2021-04-16')
85 | config.IF_ACTUAL_PREDICT = True
86 |
87 | config.START_DATE = "2003-05-01"
88 |
89 | # 29 days before, 1 day after
90 | config.PREDICT_PERIOD = '60'
91 |
92 | # fixed dates
93 | config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), -10))
94 | # config.START_EVAL_DATE = "2021-05-22"
95 |
96 | # one workday after OUTPUT_DATE
97 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), +1))
98 |
99 | # create the prediction result table
100 | StockData.create_predict_result_table_psql(list_tic=config.SINGLE_A_STOCK_CODE)
101 |
102 | # update stock data
103 | StockData.update_stock_data_to_sqlite(list_stock_code=config.SINGLE_A_STOCK_CODE)
104 |
105 | # prediction cut-off date
106 | end_vali_date = get_datetime_from_date_str(config.END_DATE)
107 |
108 | # get a list of N begin dates
109 | list_begin_vali_date = query_begin_vali_date_list_by_agent_name(agent_item, end_vali_date)
110 |
111 | # loop over the validation-date list
112 | for vali_days_count, begin_vali_date in list_begin_vali_date:
113 |
114 | # config.START_EVAL_DATE = str(begin_vali_date)
115 |
116 | # update the workday flag, used by run_single.py to load the trained weights file
117 | config.VALI_DAYS_FLAG = str(vali_days_count)
118 |
119 | # config.PREDICT_PERIOD = str(vali_days_count)
120 |
121 | # weights file directory
122 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \
123 | # f'/single_{config.VALI_DAYS_FLAG}'
124 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \
125 | f'/single_{config.VALI_DAYS_FLAG}'
126 |
127 | # predict only if the weights directory exists
128 | if os.path.exists(model_folder_path):
129 |
130 | print('\r\n')
131 | print('#' * 40)
132 | print('config.AGENT_NAME', config.AGENT_NAME)
133 | print('# prediction period', config.START_EVAL_DATE, '-', config.END_DATE)
134 | print('# model work_days', vali_days_count)
135 | print('# model_folder_path', model_folder_path)
136 | print('# initial_capital', initial_capital)
137 | print('# max_stock', max_stock)
138 |
139 | initial_stocks = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32)
140 | initial_stocks[0] = 100.0
141 |
142 | # fetch the hyper-parameters
143 | model_name = agent_item + '_' + str(vali_days_count)
144 |
145 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
146 | eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, \
147 | state_stocks_scale, state_tech_scale = query_model_hyper_parameters_sqlite(model_name=model_name)  # the helper returns 12 values; the state scales are unused here
148 |
149 | if if_on_policy == 1:
150 | if_on_policy = True
151 | else:
152 | if_on_policy = False
153 | pass
154 |
155 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
156 |
157 | # select the Agent class
158 | agent_class = None
159 | train_reward_scaling = 2 ** train_reward_scale
160 | eval_reward_scaling = 2 ** eval_reward_scale
161 |
162 | # model name
163 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
164 |
165 | if config.AGENT_NAME == 'AgentPPO':
166 | agent_class = AgentPPO()
167 | elif config.AGENT_NAME == 'AgentSAC':
168 | agent_class = AgentSAC()
169 | elif config.AGENT_NAME == 'AgentTD3':
170 | agent_class = AgentTD3()
171 | elif config.AGENT_NAME == 'AgentDDPG':
172 | agent_class = AgentDDPG()
173 | elif config.AGENT_NAME == 'AgentModSAC':
174 | agent_class = AgentModSAC()
175 | elif config.AGENT_NAME == 'AgentDuelingDQN':
176 | agent_class = AgentDuelingDQN()
177 | elif config.AGENT_NAME == 'AgentSharedSAC':
178 | agent_class = AgentSharedSAC()
179 | elif config.AGENT_NAME == 'AgentDoubleDQN':
180 | agent_class = AgentDoubleDQN()
181 | pass
182 |
183 | # prediction period (work days)
184 | work_days = int(str(hyper_parameters_model_name).split('_')[1])
185 |
186 | args = Arguments(if_on_policy=if_on_policy)
187 | args.agent = agent_class
188 |
189 | args.gpu_id = 0
190 | # args.agent.if_use_gae = if_use_gae
191 | args.agent.lambda_entropy = 0.04
192 |
193 | tech_indicator_list = [
194 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
195 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
196 |
197 | gamma = 0.99
198 |
199 | buy_cost_pct = 0.003
200 | sell_cost_pct = 0.003
201 | start_date = config.START_DATE
202 | start_eval_date = config.START_EVAL_DATE
203 | end_eval_date = config.END_DATE
204 |
205 | args.env = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock,
206 | initial_capital=initial_capital,
207 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
208 | start_date=start_date,
209 | end_date=start_eval_date, env_eval_date=end_eval_date,
210 | ticker_list=config.SINGLE_A_STOCK_CODE,
211 | tech_indicator_list=tech_indicator_list,
212 | initial_stocks=initial_stocks,
213 | if_eval=True)
214 |
215 | args.env_eval = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock,
216 | initial_capital=initial_capital,
217 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
218 | start_date=start_date,
219 | end_date=start_eval_date, env_eval_date=end_eval_date,
220 | ticker_list=config.SINGLE_A_STOCK_CODE,
221 | tech_indicator_list=tech_indicator_list,
222 | initial_stocks=initial_stocks,
223 | if_eval=True)
224 |
225 | args.env.target_return = 100
226 | args.env_eval.target_return = 100
227 |
228 | # reward scale
229 | args.env.reward_scaling = train_reward_scaling
230 | args.env_eval.reward_scaling = eval_reward_scaling
231 |
232 | print('train/eval reward scaling:', args.env.reward_scaling, args.env_eval.reward_scaling)
233 |
234 | # Hyperparameters
235 | args.gamma = gamma
236 | # ----
237 | args.break_step = break_step
238 | # ----
239 |
240 | args.net_dim = 2 ** 9
241 | args.max_step = args.env.max_step
242 |
243 | # ----
244 | # args.max_memo = args.max_step * 4
245 | args.max_memo = (args.max_step - 1) * 8
246 | # ----
247 |
248 | # ----
249 | args.batch_size = 2 ** 12
250 | # args.batch_size = 2305
251 | # ----
252 |
253 | # ----
254 | # args.repeat_times = 2 ** 3
255 | args.repeat_times = 2 ** 4
256 | # ----
257 |
258 | args.eval_gap = 2 ** 4
259 | args.eval_times1 = 2 ** 3
260 | args.eval_times2 = 2 ** 5
261 |
262 | # ----
263 | args.if_allow_break = False
264 | # args.if_allow_break = True
265 | # ----
266 |
267 | # ----------------------------
268 | args.init_before_training()
269 |
270 | '''basic arguments'''
271 | cwd = args.cwd
272 | env = args.env
273 | agent = args.agent
274 | gpu_id = args.gpu_id # necessary for Evaluator?
275 |
276 | '''training arguments'''
277 | net_dim = args.net_dim
278 | max_memo = args.max_memo
279 | break_step = args.break_step
280 | batch_size = args.batch_size
281 | target_step = args.target_step
282 | repeat_times = args.repeat_times
283 | if_break_early = args.if_allow_break
284 | if_per = args.if_per
285 | gamma = args.gamma
286 | reward_scale = args.reward_scale
287 |
288 | '''evaluating arguments'''
289 | eval_gap = args.eval_gap
290 | eval_times1 = args.eval_times1
291 | eval_times2 = args.eval_times2
292 | if args.env_eval is not None:
293 | env_eval = args.env_eval
294 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()):
295 | env_eval = PreprocessEnv(gym.make(env.env_name))
296 | else:
297 | env_eval = deepcopy(env)
298 |
299 | del args # In order to show these hyper-parameters clearly, I put them above.
300 |
301 | '''init: environment'''
302 | max_step = env.max_step
303 | state_dim = env.state_dim
304 | action_dim = env.action_dim
305 | if_discrete = env.if_discrete
306 |
307 | '''init: Agent, ReplayBuffer, Evaluator'''
308 | agent.init(net_dim, state_dim, action_dim, if_per)
309 |
310 | # ----
311 | # work_days: period count, used to store and load the trained model
312 | model_file_path = f'{model_folder_path}/actor.pth'
313 |
314 | # load the model if it exists
315 | if os.path.exists(model_file_path):
316 | agent.save_load_model(model_folder_path, if_save=False)
317 |
318 | '''prepare for training'''
319 | agent.state = env.reset()
320 |
321 | episode_return = 0.0 # sum of rewards in an episode
322 | episode_step = 1
323 | max_step = env.max_step
324 | if_discrete = env.if_discrete
325 |
326 | state = env.reset()
327 |
328 | with torch.no_grad(): # speed up running
329 |
330 | # for episode_step in range(max_step):
331 | while True:
332 | s_tensor = torch.as_tensor((state,), device=agent.device)
333 | a_tensor = agent.act(s_tensor)
334 | if if_discrete:
335 | a_tensor = a_tensor.argmax(dim=1)
336 | action = a_tensor.detach().cpu().numpy()[
337 | 0]  # detach() is not strictly needed here because torch.no_grad() is active
338 | state, reward, done, _ = env.step(action)
339 | episode_return += reward
340 | if done:
341 | break
342 | pass
343 | pass
344 |
345 | # find the target prediction date and save the result to the database
346 | for item in env.list_buy_or_sell_output:
347 | tic, date, action, hold, day, episode_return = item
348 | if str(date) == config.OUTPUT_DATE:
349 | # quick estimate of the maximum buy-low/sell-high return
350 | max_return = calc_max_return(env.price_ary, env.initial_capital)
351 |
352 | # found the target day, store it to psql
353 | StockData.update_predict_result_to_psql(psql=psql_object, agent=config.AGENT_NAME,
354 | vali_period_value=config.VALI_DAYS_FLAG,
355 | pred_period_name=config.PREDICT_PERIOD,
356 | tic=tic, date=date, action=action,
357 | hold=hold,
358 | day=day, episode_return=episode_return,
359 | max_return=max_return,
360 | trade_detail=env.output_text_trade_detail)
361 |
362 | break
363 |
364 | pass
365 | pass
366 | pass
367 | # episode_return = getattr(env, 'episode_return', episode_return)
368 | pass
369 | else:
370 | print('model file not found', model_file_path)
371 | pass
372 | # ----
373 |
374 | pass
375 | pass
376 |
377 | psql_object.close()
378 | pass
379 |
380 | # end time of this run
381 | time_end = datetime.now()
382 | duration = (time_end - time_begin).total_seconds()
383 | print('elapsed time', duration, 'seconds')
384 | pass
385 |
--------------------------------------------------------------------------------
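Note: calc_max_return above checks every (buy day, sell day) pair, which is O(n^2) in the number of price rows. A hypothetical single-pass sketch of the same buy-low/sell-high estimate, assuming the intended min_value is the buy price of the trade that maximises the price difference:

import numpy as np


def calc_max_return_fast(price_ary: np.ndarray, initial_capital: float, col: int = 0) -> float:
    # Single-pass variant of the brute-force calc_max_return (illustrative sketch).
    prices = price_ary[:, col]
    assert prices.shape[0] > 1 and np.all(prices > 0)

    best_profit = 0.0      # largest sell_price - buy_price seen so far
    best_buy = 0.0         # buy price of that best trade
    running_min = prices[0]
    for p in prices[1:]:
        if p - running_min > best_profit:
            best_profit = p - running_min
            best_buy = running_min
        if p < running_min:
            running_min = p

    if best_buy == 0.0:    # no profitable pair found
        return 0
    return (initial_capital / best_buy * best_profit + initial_capital) / initial_capital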
/predict_batch_psql.py:
--------------------------------------------------------------------------------
1 | from stock_data import StockData
2 | from train_helper import query_model_hyper_parameters_sqlite, query_begin_vali_date_list_by_agent_name
3 | from utils.psqldb import Psqldb
4 | from agent import *
5 | from utils.date_time import *
6 | from env_batch import StockTradingEnvBatch
7 | from run_batch import *
8 | from datetime import datetime
9 |
10 | from utils.sqlite import SQLite
11 |
12 |
13 | def calc_max_return(stock_index_temp, price_ary, initial_capital_temp):
14 | max_return_temp = 0
15 | min_value = 0
16 |
17 | assert price_ary.shape[0] > 1
18 |
19 | count_price = price_ary.shape[0]
20 |
21 | for index_left in range(0, count_price - 1):
22 |
23 | for index_right in range(index_left + 1, count_price):
24 |
25 | assert price_ary[index_left][stock_index_temp] > 0
26 |
27 | assert price_ary[index_right][stock_index_temp] > 0
28 |
29 | temp_value = price_ary[index_right][stock_index_temp] - price_ary[index_left][stock_index_temp]
30 |
31 | if temp_value > max_return_temp:
32 | max_return_temp = temp_value
33 | # max_value = price_ary[index1][0]
34 | min_value = price_ary[index_left][stock_index_temp]  # buy price (the lower leg of the best trade)
35 | pass
36 | pass
37 |
38 | if min_value == 0:
39 | ret = 0
40 | else:
41 | ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp
42 |
43 | return ret
44 |
45 |
46 | if __name__ == '__main__':
47 | # run prediction and save the results to the postgresql database
48 | # start time of this run
49 | time_begin = datetime.now()
50 |
51 | # create the prediction summary table
52 | StockData.create_predict_summary_table_psql(table_name='predict_summary')
53 |
54 | # clear the prediction summary table
55 | StockData.clear_predict_summary_table_psql(table_name='predict_summary')
56 |
57 | # TODO
58 | # get today's date and check whether it is a workday
59 | weekday = get_datetime_from_date_str(get_today_date()).weekday()
60 | if weekday < 5:  # Monday(0) .. Friday(4) count as workdays
61 | # workday
62 | now = datetime.now().strftime("%H:%M")
63 | t1 = '09:00'
64 |
65 | if now >= t1:
66 | # workday, at or after 09:00: predict the next workday
67 | config.OUTPUT_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), +1))
68 | pass
69 | else:
70 | # workday, before 09:00: predict today
71 | config.OUTPUT_DATE = get_today_date()
72 | pass
73 | pass
74 | else:
75 | # weekend / holiday
76 | # predict the next workday
77 | config.OUTPUT_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), +1))
78 | pass
79 | pass
80 |
81 | # if today is not a workday, predict the next workday
82 | # config.OUTPUT_DATE = '2021-08-06'
83 |
84 | config.START_DATE = "2004-05-01"
85 |
86 | # do not change the order of the stock codes
87 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275')
88 |
89 | fe_table_name = 'fe_fillzero_predict'
90 |
91 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE)
92 |
93 | max_stock = 50000
94 |
95 | config.IF_ACTUAL_PREDICT = True
96 |
97 | # prediction period
98 | config.PREDICT_PERIOD = '10'
99 |
100 | # PostgreSQL connection object
101 | psql_object = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER,
102 | password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT)
103 |
104 | # if_first_time = True
105 |
106 | # work well: AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(),
107 | # AgentDoubleDQN works well in single-process mode?
108 | # do not work well: AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC()
109 |
110 | # for agent_item in ['AgentTD3', 'AgentSAC', 'AgentPPO', 'AgentDDPG', 'AgentModSAC', ]:
111 | for agent_item in ['AgentTD3', ]:
112 |
113 | config.AGENT_NAME = agent_item
114 |
115 | break_step = int(3e6)
116 |
117 | if_on_policy = False
118 | # if_use_gae = False
119 |
120 | # fixed dates
121 | config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE),
122 | -1 * int(config.PREDICT_PERIOD)))
123 | # config.START_EVAL_DATE = "2021-05-22"
124 |
125 | # one workday after OUTPUT_DATE
126 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), +1))
127 |
128 | # # update stock data
129 | # # only update it once
130 | # if if_first_time is True:
131 | # # create the prediction result table
132 | # StockData.create_predict_result_table_psql(list_tic=config.BATCH_A_STOCK_CODE)
133 | #
134 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3',
135 | # table_name=fe_table_name, if_incremental_update=False)
136 | # if_first_time = False
137 | # pass
138 |
139 | # prediction cut-off date
140 | end_vali_date = get_datetime_from_date_str(config.END_DATE)
141 |
142 | # get a list of N begin dates
143 | list_begin_vali_date = query_begin_vali_date_list_by_agent_name(agent_item, end_vali_date)
144 |
145 | # loop over the validation-date list
146 | for vali_days_count, begin_vali_date in list_begin_vali_date:
147 |
148 | # update the workday flag, used by run_batch.py to load the trained weights file
149 | config.VALI_DAYS_FLAG = str(vali_days_count)
150 |
151 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
152 | f'batch_{config.VALI_DAYS_FLAG}'
153 |
154 | # predict only if the weights directory exists
155 | if os.path.exists(model_folder_path):
156 |
157 | print('\r\n')
158 | print('#' * 40)
159 | print('config.AGENT_NAME', config.AGENT_NAME)
160 | print('# prediction period', config.START_EVAL_DATE, '-', config.END_DATE)
161 | print('# model work_days', vali_days_count)
162 | print('# model_folder_path', model_folder_path)
163 | print('# initial_capital', initial_capital)
164 | print('# max_stock', max_stock)
165 |
166 | # initial_stocks = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
167 | initial_stocks = np.zeros(len(config.BATCH_A_STOCK_CODE), dtype=np.float32)
168 |
169 | # hold one lot (100 shares) by default
170 | # initial_stocks = initial_stocks * 100.0
171 |
172 | # fetch the hyper-parameters
173 | model_name = agent_item + '_' + str(vali_days_count)
174 |
175 | # hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
176 | # eval_reward_scale, training_times, time_point \
177 | # = query_model_hyper_parameters_sqlite(model_name=model_name)
178 |
179 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \
180 | eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, \
181 | state_tech_scale = query_model_hyper_parameters_sqlite(model_name=model_name)
182 |
183 | if if_on_policy == 1:
184 | if_on_policy = True
185 | else:
186 | if_on_policy = False
187 | pass
188 |
189 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id)
190 |
191 | # select the Agent class
192 | agent_class = None
193 |
194 | # model name
195 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0]
196 |
197 | if config.AGENT_NAME == 'AgentPPO':
198 | agent_class = AgentPPO()
199 | elif config.AGENT_NAME == 'AgentSAC':
200 | agent_class = AgentSAC()
201 | elif config.AGENT_NAME == 'AgentTD3':
202 | agent_class = AgentTD3()
203 | elif config.AGENT_NAME == 'AgentDDPG':
204 | agent_class = AgentDDPG()
205 | elif config.AGENT_NAME == 'AgentModSAC':
206 | agent_class = AgentModSAC()
207 | elif config.AGENT_NAME == 'AgentDuelingDQN':
208 | agent_class = AgentDuelingDQN()
209 | elif config.AGENT_NAME == 'AgentSharedSAC':
210 | agent_class = AgentSharedSAC()
211 | elif config.AGENT_NAME == 'AgentDoubleDQN':
212 | agent_class = AgentDoubleDQN()
213 | pass
214 |
215 | args = Arguments(if_on_policy=if_on_policy)
216 | args.agent = agent_class
217 |
218 | args.gpu_id = 0
219 | # args.agent.if_use_gae = if_use_gae
220 | args.agent.lambda_entropy = 0.04
221 |
222 | tech_indicator_list = [
223 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30',
224 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST
225 |
226 | gamma = 0.99
227 |
228 | buy_cost_pct = 0.003
229 | sell_cost_pct = 0.003
230 | start_date = config.START_DATE
231 | start_eval_date = config.START_EVAL_DATE
232 | end_eval_date = config.END_DATE
233 |
234 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
235 | initial_capital=initial_capital,
236 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
237 | start_date=start_date,
238 | end_date=start_eval_date, env_eval_date=end_eval_date,
239 | ticker_list=config.BATCH_A_STOCK_CODE,
240 | tech_indicator_list=tech_indicator_list,
241 | initial_stocks=initial_stocks,
242 | if_eval=True, fe_table_name=fe_table_name)
243 |
244 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock,
245 | initial_capital=initial_capital,
246 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct,
247 | start_date=start_date,
248 | end_date=start_eval_date, env_eval_date=end_eval_date,
249 | ticker_list=config.BATCH_A_STOCK_CODE,
250 | tech_indicator_list=tech_indicator_list,
251 | initial_stocks=initial_stocks,
252 | if_eval=True, fe_table_name=fe_table_name)
253 |
254 | args.env.target_return = 100
255 | args.env_eval.target_return = 100
256 |
257 | # reward scale
258 | args.env.reward_scale = train_reward_scale
259 | args.env_eval.reward_scale = eval_reward_scale
260 |
261 | args.env.state_amount_scale = state_amount_scale
262 | args.env.state_price_scale = state_price_scale
263 | args.env.state_stocks_scale = state_stocks_scale
264 | args.env.state_tech_scale = state_tech_scale
265 |
266 | args.env_eval.state_amount_scale = state_amount_scale
267 | args.env_eval.state_price_scale = state_price_scale
268 | args.env_eval.state_stocks_scale = state_stocks_scale
269 | args.env_eval.state_tech_scale = state_tech_scale
270 |
271 | print('train reward_scale', args.env.reward_scale)
272 | print('eval reward_scale', args.env_eval.reward_scale)
273 |
274 | print('state_amount_scale', state_amount_scale)
275 | print('state_price_scale', state_price_scale)
276 | print('state_stocks_scale', state_stocks_scale)
277 | print('state_tech_scale', state_tech_scale)
278 |
279 | # Hyperparameters
280 | args.gamma = gamma
281 | # ----
282 | args.break_step = break_step
283 | # ----
284 |
285 | args.net_dim = 2 ** 9
286 | args.max_step = args.env.max_step
287 |
288 | # ----
289 | # args.max_memo = args.max_step * 4
290 | args.max_memo = (args.max_step - 1) * 8
291 | # ----
292 |
293 | # ----
294 | args.batch_size = 2 ** 12
295 | # args.batch_size = 2305
296 | # ----
297 |
298 | # ----
299 | # args.repeat_times = 2 ** 3
300 | args.repeat_times = 2 ** 4
301 | # ----
302 |
303 | args.eval_gap = 2 ** 4
304 | args.eval_times1 = 2 ** 3
305 | args.eval_times2 = 2 ** 5
306 |
307 | # ----
308 | args.if_allow_break = False
309 | # args.if_allow_break = True
310 | # ----
311 |
312 | # ----------------------------
313 | args.init_before_training()
314 |
315 | '''basic arguments'''
316 | cwd = args.cwd
317 | env = args.env
318 | agent = args.agent
319 | gpu_id = args.gpu_id # necessary for Evaluator?
320 |
321 | '''training arguments'''
322 | net_dim = args.net_dim
323 | max_memo = args.max_memo
324 | break_step = args.break_step
325 | batch_size = args.batch_size
326 | target_step = args.target_step
327 | repeat_times = args.repeat_times
328 | if_break_early = args.if_allow_break
329 | if_per = args.if_per
330 | gamma = args.gamma
331 | reward_scale = args.reward_scale
332 |
333 | '''evaluating arguments'''
334 | eval_gap = args.eval_gap
335 | eval_times1 = args.eval_times1
336 | eval_times2 = args.eval_times2
337 | if args.env_eval is not None:
338 | env_eval = args.env_eval
339 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()):
340 | env_eval = PreprocessEnv(gym.make(env.env_name))
341 | else:
342 | env_eval = deepcopy(env)
343 |
344 | del args # In order to show these hyper-parameters clearly, I put them above.
345 |
346 | '''init: environment'''
347 | max_step = env.max_step
348 | state_dim = env.state_dim
349 | action_dim = env.action_dim
350 | if_discrete = env.if_discrete
351 |
352 | '''init: Agent, ReplayBuffer, Evaluator'''
353 | agent.init(net_dim, state_dim, action_dim, if_per)
354 |
355 | # ----
356 | # work_days: period count, used to store and load the trained model
357 | model_file_path = f'{model_folder_path}/actor.pth'
358 |
359 | # load the model if it exists
360 | if os.path.exists(model_file_path):
361 | agent.save_load_model(model_folder_path, if_save=False)
362 |
363 | '''prepare for training'''
364 | agent.state = env.reset()
365 |
366 | episode_return_total = 0.0 # sum of rewards in an episode
367 | episode_step = 1
368 | max_step = env.max_step
369 | if_discrete = env.if_discrete
370 |
371 | state = env.reset()
372 |
373 | with torch.no_grad(): # speed up running
374 |
375 | # for episode_step in range(max_step):
376 | while True:
377 | s_tensor = torch.as_tensor((state,), device=agent.device)
378 | a_tensor = agent.act(s_tensor)
379 | if if_discrete:
380 | a_tensor = a_tensor.argmax(dim=1)
381 | action = a_tensor.detach().cpu().numpy()[
382 | 0]  # detach() is not strictly needed here because torch.no_grad() is active
383 | state, reward, done, _ = env.step(action)
384 | episode_return_total += reward
385 |
386 | if done:
387 | break
388 | pass
389 | pass
390 |
391 | # look up the stock name by its tic code
392 | sqlite_query_name_by_tic = SQLite(dbname=config.STOCK_DB_PATH)
393 |
394 | # find the target prediction date and save the result to the database
395 | for stock_index in range(len(env.list_buy_or_sell_output)):
396 | list_one_stock = env.list_buy_or_sell_output[stock_index]
397 |
398 | trade_detail_all = ''
399 |
400 | for item in list_one_stock:
401 | tic, date, action, hold, day, episode_return, trade_detail = item
402 |
403 | trade_detail_all += trade_detail
404 |
405 | if str(date) == config.OUTPUT_DATE:
406 | # quick estimate of the maximum buy-low/sell-high return
407 | max_return = calc_max_return(stock_index_temp=stock_index, price_ary=env.price_ary,
408 | initial_capital_temp=env.initial_capital)
409 |
410 | stock_name = StockData.get_stock_name_by_tic(sqlite=sqlite_query_name_by_tic,
411 | tic=tic,
412 | table_name='tic_list_275')
413 |
414 | # found the target day, store it to psql
415 | StockData.update_predict_summary_result_to_psql(psql=psql_object,
416 | agent=config.AGENT_NAME,
417 | vali_period_value=config.VALI_DAYS_FLAG,
418 | pred_period_name=config.PREDICT_PERIOD,
419 | tic=tic, name=stock_name, date=date,
420 | action=action,
421 | hold=hold,
422 | day=day,
423 | episode_return=episode_return,
424 | max_return=max_return,
425 | trade_detail=trade_detail_all,
426 | table_name='predict_summary')
427 |
428 | break
429 |
430 | pass
431 | pass
432 | pass
433 | pass
434 |
435 | # close the database connection
436 | sqlite_query_name_by_tic.close()
437 |
438 | # episode_return = getattr(env, 'episode_return', episode_return)
439 | pass
440 | else:
441 | print('model file not found', model_file_path)
442 | pass
443 | # ----
444 | pass
445 | pass
446 |
447 | psql_object.close()
448 | pass
449 |
450 | # 结束预测的时间
451 | time_end = datetime.now()
452 | duration = (time_end - time_begin).total_seconds()
453 | print('elapsed time', duration, 'seconds')
454 | pass
455 |
--------------------------------------------------------------------------------
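Note: the if/elif chain that maps config.AGENT_NAME to an agent instance appears in train_batch.py, predict_single_psql.py and predict_batch_psql.py above. A compact, table-driven equivalent (a sketch, not a change to the pipeline):

# Sketch: dictionary dispatch replacing the repeated AGENT_NAME if/elif chains.
from agent import (AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC,
                   AgentDuelingDQN, AgentSharedSAC, AgentDoubleDQN)

AGENT_CLASSES = {
    'AgentPPO': AgentPPO, 'AgentSAC': AgentSAC, 'AgentTD3': AgentTD3,
    'AgentDDPG': AgentDDPG, 'AgentModSAC': AgentModSAC,
    'AgentDuelingDQN': AgentDuelingDQN, 'AgentSharedSAC': AgentSharedSAC,
    'AgentDoubleDQN': AgentDoubleDQN,
}


def make_agent(agent_name: str):
    # returns a fresh agent instance, raises KeyError for unknown names
    return AGENT_CLASSES[agent_name]()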
/run_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import time
4 | import torch
5 | import numpy as np
6 | import numpy.random as rd
7 | from copy import deepcopy
8 | from ElegantRL_master.elegantrl.replay import ReplayBuffer, ReplayBufferMP
9 | from ElegantRL_master.elegantrl.env import PreprocessEnv
10 | import config
11 |
12 | """[ElegantRL](https://github.com/AI4Finance-LLC/ElegantRL)"""
13 |
14 |
15 | class Arguments:
16 | def __init__(self, agent=None, env=None, gpu_id=None, if_on_policy=False):
17 | self.agent = agent # Deep Reinforcement Learning algorithm
18 |
19 | self.cwd = None # current work directory. cwd is None means set it automatically
20 | self.env = env # the environment for training
21 | self.env_eval = None # the environment for evaluating
22 | self.gpu_id = gpu_id # choose the GPU for running. gpu_id is None means set it automatically
23 |
24 | '''Arguments for training (off-policy)'''
25 | self.net_dim = 2 ** 8 # the network width
26 | self.batch_size = 2 ** 8 # num of transitions sampled from replay buffer.
27 | self.repeat_times = 2 ** 0 # repeatedly update network to keep critic's loss small
28 | self.target_step = 2 ** 10 # collect target_step, then update network
29 | self.max_memo = 2 ** 17 # capacity of replay buffer
30 | if if_on_policy: # (on-policy)
31 | self.net_dim = 2 ** 9
32 | self.batch_size = 2 ** 9
33 | self.repeat_times = 2 ** 4
34 | self.target_step = 2 ** 12
35 | self.max_memo = self.target_step
36 | self.gamma = 0.99 # discount factor of future rewards
37 | self.reward_scale = 2 ** 0 # an approximate target reward, usually close to 256
38 | self.if_per = False # Prioritized Experience Replay for sparse reward
39 |
40 | self.rollout_num = 2 # the number of rollout workers (larger is not always faster)
41 | self.num_threads = 8 # cpu_num for evaluate model, torch.set_num_threads(self.num_threads)
42 |
43 | '''Arguments for evaluate'''
44 | self.break_step = 2 ** 20 # break training after 'total_step > break_step'
45 | self.if_remove = True # remove the cwd folder? (True, False, None:ask me)
46 | self.if_allow_break = True # allow breaking off training when the goal is reached (early termination)
47 | self.eval_gap = 2 ** 5 # evaluate the agent per eval_gap seconds
48 | self.eval_times1 = 2 ** 2 # evaluation times
49 | self.eval_times2 = 2 ** 4 # evaluation times if 'eval_reward > max_reward'
50 | self.random_seed = 0 # initialize random seed in self.init_before_training()
51 |
52 | def init_before_training(self, if_main=True):
53 | if self.agent is None:
54 | raise RuntimeError('\n| Why agent=None? Assignment args.agent = AgentXXX please.')
55 | if not hasattr(self.agent, 'init'):
56 | raise RuntimeError('\n| There should be agent=AgentXXX() instead of agent=AgentXXX')
57 | if self.env is None:
58 | raise RuntimeError('\n| Why env=None? Assignment args.env = XxxEnv() please.')
59 | if isinstance(self.env, str) or not hasattr(self.env, 'env_name'):
60 | raise RuntimeError('\n| What is env.env_name? use env=PreprocessEnv(env). It is a Wrapper.')
61 |
62 | '''set gpu_id automatically'''
63 | if self.gpu_id is None: # set gpu_id automatically
64 | import sys
65 | self.gpu_id = sys.argv[-1][-4]
66 | else:
67 | self.gpu_id = str(self.gpu_id)
68 | if not self.gpu_id.isdigit(): # set gpu_id as '0' in default
69 | self.gpu_id = '0'
70 |
71 | '''set cwd automatically'''
72 | if self.cwd is None:
73 | # ----
74 | agent_name = self.agent.__class__.__name__
75 | # self.cwd = f'./{agent_name}/{self.env.env_name}_{self.gpu_id}'
76 | # self.cwd = f'./{agent_name}/{self.env.env_name}'
77 | self.cwd = f'./{config.WEIGHTS_PATH}/{self.env.env_name}'
78 |
79 | # model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \
80 | # f'/single_{config.VALI_DAYS_FLAG}'
81 | # ----
82 |
83 | if if_main:
84 | print(f'| GPU id: {self.gpu_id}, cwd: {self.cwd}')
85 |
86 | import shutil # remove history according to bool(if_remove)
87 | if self.if_remove is None:
88 | self.if_remove = bool(input("PRESS 'y' to REMOVE: {}? ".format(self.cwd)) == 'y')
89 | if self.if_remove:
90 | shutil.rmtree(self.cwd, ignore_errors=True)
91 | print("| Remove history")
92 | os.makedirs(self.cwd, exist_ok=True)
93 |
94 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
95 | torch.set_num_threads(self.num_threads)
96 | torch.set_default_dtype(torch.float32)
97 | torch.manual_seed(self.random_seed)
98 | np.random.seed(self.random_seed)
99 |
100 |
101 | '''single process training'''
102 |
103 |
104 | def train_and_evaluate(args):
105 | args.init_before_training()
106 |
107 | '''basic arguments'''
108 | cwd = args.cwd
109 | env = args.env
110 | agent = args.agent
111 | gpu_id = args.gpu_id # necessary for Evaluator?
112 |
113 | '''training arguments'''
114 | net_dim = args.net_dim
115 | max_memo = args.max_memo
116 | break_step = args.break_step
117 | batch_size = args.batch_size
118 | target_step = args.target_step
119 | repeat_times = args.repeat_times
120 | if_break_early = args.if_allow_break
121 | if_per = args.if_per
122 | gamma = args.gamma
123 | reward_scale = args.reward_scale
124 |
125 | '''evaluating arguments'''
126 | eval_gap = args.eval_gap
127 | eval_times1 = args.eval_times1
128 | eval_times2 = args.eval_times2
129 | if args.env_eval is not None:
130 | env_eval = args.env_eval
131 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()):
132 | env_eval = PreprocessEnv(gym.make(env.env_name))
133 | else:
134 | env_eval = deepcopy(env)
135 |
136 | del args # In order to show these hyper-parameters clearly, I put them above.
137 |
138 | '''init: environment'''
139 | max_step = env.max_step
140 | state_dim = env.state_dim
141 | action_dim = env.action_dim
142 | if_discrete = env.if_discrete
143 |
144 | '''init: Agent, ReplayBuffer, Evaluator'''
145 | agent.init(net_dim, state_dim, action_dim, if_per)
146 |
147 | # ----
148 | # folder path
149 | model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \
150 | f'/single_{config.VALI_DAYS_FLAG}'
151 | # file path
152 | model_file_path = f'{model_folder_path}/actor.pth'
153 |
154 | # load the model if it exists
155 | if os.path.exists(model_file_path):
156 | agent.save_load_model(model_folder_path, if_save=False)
157 | pass
158 | # ----
159 |
160 | if_on_policy = getattr(agent, 'if_on_policy', False)
161 |
162 | buffer = ReplayBuffer(max_len=max_memo + max_step, state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
163 | if_on_policy=if_on_policy, if_per=if_per, if_gpu=True)
164 |
165 | evaluator = Evaluator(cwd=cwd, agent_id=gpu_id, device=agent.device, env=env_eval,
166 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, )
167 |
168 | '''prepare for training'''
169 | agent.state = env.reset()
170 | if if_on_policy:
171 | steps = 0
172 | else: # explore_before_training for off-policy
173 | with torch.no_grad(): # update replay buffer
174 | steps = explore_before_training(env, buffer, target_step, reward_scale, gamma)
175 |
176 | agent.update_net(buffer, target_step, batch_size, repeat_times) # pre-training and hard update
177 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None
178 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) else None
179 | total_step = steps
180 |
181 | '''start training'''
182 | if_reach_goal = False
183 | while not ((if_break_early and if_reach_goal)
184 | or total_step > break_step
185 | or os.path.exists(f'{cwd}/stop')):
186 | steps = agent.explore_env(env, buffer, target_step, reward_scale, gamma)
187 | total_step += steps
188 |
189 | obj_a, obj_c = agent.update_net(buffer, target_step, batch_size, repeat_times)
190 |
191 | if_reach_goal = evaluator.evaluate_save(agent.act, steps, obj_a, obj_c)
192 | evaluator.draw_plot()
193 |
194 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}')
195 |
196 |
197 | '''multiprocessing training'''
198 |
199 |
200 | def train_and_evaluate_mp(args):
201 | act_workers = args.rollout_num
202 | import multiprocessing as mp # Python built-in multiprocessing library
203 |
204 | pipe1_eva, pipe2_eva = mp.Pipe() # Pipe() for Process mp_evaluate_agent()
205 | pipe2_exp_list = list() # Pipe() for Process mp_explore_in_env()
206 |
207 | process_train = mp.Process(target=mp_train, args=(args, pipe2_eva, pipe2_exp_list))
208 | process_evaluate = mp.Process(target=mp_evaluate, args=(args, pipe1_eva))
209 | process = [process_train, process_evaluate]
210 |
211 | for worker_id in range(act_workers):
212 | exp_pipe1, exp_pipe2 = mp.Pipe(duplex=True)
213 | pipe2_exp_list.append(exp_pipe1)
214 | process.append(mp.Process(target=mp_explore, args=(args, exp_pipe2, worker_id)))
215 |
216 | [p.start() for p in process]
217 | process_evaluate.join()
218 | process_train.join()
219 | [p.terminate() for p in process]
220 |
221 |
222 | def mp_train(args, pipe1_eva, pipe1_exp_list):
223 | args.init_before_training(if_main=False)
224 |
225 | '''basic arguments'''
226 | env = args.env
227 | cwd = args.cwd
228 | agent = args.agent
229 | rollout_num = args.rollout_num
230 |
231 | '''training arguments'''
232 | net_dim = args.net_dim
233 | max_memo = args.max_memo
234 | break_step = args.break_step
235 | batch_size = args.batch_size
236 | target_step = args.target_step
237 | repeat_times = args.repeat_times
238 | if_break_early = args.if_allow_break
239 | if_per = args.if_per
240 | del args # In order to show these hyper-parameters clearly, I put them above.
241 |
242 | '''init: environment'''
243 | max_step = env.max_step
244 | state_dim = env.state_dim
245 | action_dim = env.action_dim
246 | if_discrete = env.if_discrete
247 |
248 | '''init: Agent, ReplayBuffer'''
249 | agent.init(net_dim, state_dim, action_dim, if_per)
250 |
251 | # ----
252 | # model directory path
253 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \
254 | f'/single_{config.VALI_DAYS_FLAG}'
255 |
256 | # f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1'
257 |
258 | # model file path
259 | model_file_path = f'{model_folder_path}/actor.pth'
260 |
261 | # if the model file already exists, load it
262 | if os.path.exists(model_file_path):
263 | agent.save_load_model(model_folder_path, if_save=False)
264 | pass
265 | # ----
266 |
267 | if_on_policy = getattr(agent, 'if_on_policy', False)
268 |
269 | '''send'''
270 | pipe1_eva.send(agent.act) # send
271 | # act = pipe2_eva.recv() # recv
272 |
273 | buffer_mp = ReplayBufferMP(max_len=max_memo + max_step * rollout_num, if_on_policy=if_on_policy,
274 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
275 | rollout_num=rollout_num, if_gpu=True, if_per=if_per)
276 |
277 | '''prepare for training'''
278 | if if_on_policy:
279 | steps = 0
280 | else: # explore_before_training for off-policy
281 | with torch.no_grad(): # update replay buffer
282 | steps = 0
283 | for i in range(rollout_num):
284 | pipe1_exp = pipe1_exp_list[i]
285 |
286 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
287 | buf_state, buf_other = pipe1_exp.recv()
288 |
289 | steps += len(buf_state)
290 | buffer_mp.extend_buffer(buf_state, buf_other, i)
291 |
292 | agent.update_net(buffer_mp, target_step, batch_size, repeat_times) # pre-training and hard update
293 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None
294 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) \
295 | else None
296 | total_step = steps
297 | '''send'''
298 | pipe1_eva.send((agent.act, steps, 0, 0.5)) # send
299 | # act, steps, obj_a, obj_c = pipe2_eva.recv() # recv
300 |
301 | '''start training'''
302 | if_solve = False
303 | while not ((if_break_early and if_solve)
304 | or total_step > break_step
305 | or os.path.exists(f'{cwd}/stop')):
306 | '''update ReplayBuffer'''
307 | steps = 0 # send by pipe1_eva
308 | for i in range(rollout_num):
309 | pipe1_exp = pipe1_exp_list[i]
310 | '''send'''
311 | pipe1_exp.send(agent.act)
312 | # agent.act = pipe2_exp.recv()
313 | '''recv'''
314 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
315 | buf_state, buf_other = pipe1_exp.recv()
316 |
317 | steps += len(buf_state)
318 | buffer_mp.extend_buffer(buf_state, buf_other, i)
319 | total_step += steps
320 |
321 | '''update network parameters'''
322 | obj_a, obj_c = agent.update_net(buffer_mp, target_step, batch_size, repeat_times)
323 |
324 | '''saves the agent with max reward'''
325 | '''send'''
326 | pipe1_eva.send((agent.act, steps, obj_a, obj_c))
327 | # q_i_eva_get = pipe2_eva.recv()
328 |
329 | if_solve = pipe1_eva.recv()
330 |
331 | if pipe1_eva.poll():
332 | '''recv'''
333 | # pipe2_eva.send(if_solve)
334 | if_solve = pipe1_eva.recv()
335 |
336 | buffer_mp.print_state_norm(env.neg_state_avg if hasattr(env, 'neg_state_avg') else None,
337 | env.div_state_std if hasattr(env, 'div_state_std') else None) # 2020-12-12
338 |
339 | '''send'''
340 | pipe1_eva.send('stop')
341 | # q_i_eva_get = pipe2_eva.recv()
342 | time.sleep(4)
343 |
344 |
345 | def mp_explore(args, pipe2_exp, worker_id):
346 | args.init_before_training(if_main=False)
347 |
348 | '''basic arguments'''
349 | env = args.env
350 | agent = args.agent
351 | rollout_num = args.rollout_num
352 |
353 | '''training arguments'''
354 | net_dim = args.net_dim
355 | max_memo = args.max_memo
356 | target_step = args.target_step
357 | gamma = args.gamma
358 | if_per = args.if_per
359 | reward_scale = args.reward_scale
360 |
361 | random_seed = args.random_seed
362 | torch.manual_seed(random_seed + worker_id)
363 | np.random.seed(random_seed + worker_id)
364 | del args # In order to show these hyper-parameters clearly, I put them above.
365 |
366 | '''init: environment'''
367 | max_step = env.max_step
368 | state_dim = env.state_dim
369 | action_dim = env.action_dim
370 | if_discrete = env.if_discrete
371 |
372 | '''init: Agent, ReplayBuffer'''
373 | agent.init(net_dim, state_dim, action_dim, if_per)
374 |
375 | # ----
376 | # model directory path
377 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \
378 | # f'/single_{config.VALI_DAYS_FLAG}'
379 |
380 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \
381 | f'/single_{config.VALI_DAYS_FLAG}'
382 |
383 | # model file path
384 | model_file_path = f'{model_folder_path}/actor.pth'
385 |
386 | # if the model file already exists, load it
387 | if os.path.exists(model_file_path):
388 | agent.save_load_model(model_folder_path, if_save=False)
389 | pass
390 | # ----
391 |
392 | agent.state = env.reset()
393 |
394 | if_on_policy = getattr(agent, 'if_on_policy', False)
395 | buffer = ReplayBuffer(max_len=max_memo // rollout_num + max_step, if_on_policy=if_on_policy,
396 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
397 | if_per=if_per, if_gpu=False)
398 |
399 | '''start exploring'''
400 | exp_step = target_step // rollout_num
401 | with torch.no_grad():
402 | if not if_on_policy:
403 | explore_before_training(env, buffer, exp_step, reward_scale, gamma)
404 |
405 | buffer.update_now_len_before_sample()
406 |
407 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
408 | # buf_state, buf_other = pipe1_exp.recv()
409 |
410 | buffer.empty_buffer_before_explore()
411 |
412 | while True:
413 | agent.explore_env(env, buffer, exp_step, reward_scale, gamma)
414 |
415 | buffer.update_now_len_before_sample()
416 |
417 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
418 | # buf_state, buf_other = pipe1_exp.recv()
419 |
420 | buffer.empty_buffer_before_explore()
421 |
422 | # pipe1_exp.send(agent.act)
423 | agent.act = pipe2_exp.recv()
424 |
425 |
426 | def mp_evaluate(args, pipe2_eva):
427 | args.init_before_training(if_main=True)
428 |
429 | '''basic arguments'''
430 | cwd = args.cwd
431 | env = args.env
432 | env_eval = env if args.env_eval is None else args.env_eval
433 | agent_id = args.gpu_id
434 |
435 | '''evaluating arguments'''
436 | eval_gap = args.eval_gap
437 | eval_times1 = args.eval_times1
438 | eval_times2 = args.eval_times2
439 | del args # In order to show these hyper-parameters clearly, I put them above.
440 |
441 | '''init: Evaluator'''
442 | evaluator = Evaluator(cwd=cwd, agent_id=agent_id, device=torch.device("cpu"), env=env_eval,
443 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) # build Evaluator
444 |
445 | '''act_cpu without gradient for pipe1_eva'''
446 | # pipe1_eva.send(agent.act)
447 | act = pipe2_eva.recv()
448 |
449 | act_cpu = deepcopy(act).to(torch.device("cpu")) # for pipe1_eva
450 | [setattr(param, 'requires_grad', False) for param in act_cpu.parameters()]
451 |
452 | '''start evaluating'''
453 | with torch.no_grad(): # speed up running
454 | act, steps, obj_a, obj_c = pipe2_eva.recv() # pipe2_eva (act, steps, obj_a, obj_c)
455 |
456 | if_loop = True
457 | while if_loop:
458 | '''update actor'''
459 | while not pipe2_eva.poll(): # wait until pipe2_eva not empty
460 | time.sleep(1)
461 | steps_sum = 0
462 | while pipe2_eva.poll(): # receive the latest object from pipe
463 | '''recv'''
464 | # pipe1_eva.send((agent.act, steps, obj_a, obj_c))
465 | # pipe1_eva.send('stop')
466 | q_i_eva_get = pipe2_eva.recv()
467 |
468 | if q_i_eva_get == 'stop':
469 | if_loop = False
470 | break
471 | act, steps, obj_a, obj_c = q_i_eva_get
472 | steps_sum += steps
473 | act_cpu.load_state_dict(act.state_dict())
474 | if_solve = evaluator.evaluate_save(act_cpu, steps_sum, obj_a, obj_c)
475 | '''send'''
476 | pipe2_eva.send(if_solve)
477 | # if_solve = pipe1_eva.recv()
478 |
479 | evaluator.draw_plot()
480 |
481 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}')
482 |
483 | while pipe2_eva.poll(): # empty the pipe
484 | pipe2_eva.recv()
485 |
486 |
487 | '''utils'''
488 |
489 |
490 | class Evaluator:
491 | def __init__(self, cwd, agent_id, eval_times1, eval_times2, eval_gap, env, device):
492 | self.recorder = [(0., -np.inf, 0., 0., 0.), ] # total_step, r_avg, r_std, obj_a, obj_c
493 | self.r_max = -np.inf
494 | self.total_step = 0
495 |
496 | self.cwd = cwd # constant
497 | self.device = device
498 | self.agent_id = agent_id
499 | self.eval_gap = eval_gap
500 | self.eval_times1 = eval_times1
501 | self.eval_times2 = eval_times2
502 | self.env = env
503 | self.target_return = env.target_return
504 |
505 | self.used_time = None
506 | self.start_time = time.time()
507 | self.eval_time = -1 # an early time
508 | print(f"{'ID':>2} {'Step':>8} {'MaxR':>8} |"
509 | f"{'avgR':>8} {'stdR':>8} {'objA':>8} {'objC':>8} |"
510 | f"{'avgS':>6} {'stdS':>4}")
511 |
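# Editor's note on the method below: evaluation is two-staged. At most once per
# eval_gap seconds it runs eval_times1 episodes; only when their mean return
# beats the running best r_max does it run (eval_times2 - eval_times1) extra
# episodes and re-average before overwriting {cwd}/actor.pth. Rough sketch:
#
#   if now - last_eval > eval_gap:
#       returns = [run_episode() for _ in range(eval_times1)]
#       if mean(returns) > r_max:
#           returns += [run_episode() for _ in range(eval_times2 - eval_times1)]
#           if mean(returns) > r_max:
#               r_max = mean(returns); torch.save(act.state_dict(), f'{cwd}/actor.pth')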
512 | def evaluate_save(self, act, steps, obj_a, obj_c) -> bool:
513 | self.total_step += steps # update total training steps
514 |
515 | if time.time() - self.eval_time > self.eval_gap:
516 | self.eval_time = time.time()
517 |
518 | rewards_steps_list = [get_episode_return(self.env, act, self.device) for _ in range(self.eval_times1)]
519 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list)
520 |
521 | if r_avg > self.r_max: # evaluate actor twice to save CPU Usage and keep precision
522 | rewards_steps_list += [get_episode_return(self.env, act, self.device)
523 | for _ in range(self.eval_times2 - self.eval_times1)]
524 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list)
525 | if r_avg > self.r_max: # save checkpoint with highest episode return
526 | self.r_max = r_avg # update max reward (episode return)
527 |
528 | '''save actor.pth'''
529 | act_save_path = f'{self.cwd}/actor.pth'
530 | torch.save(act.state_dict(), act_save_path)
531 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |") # save policy and print
532 |
533 | self.recorder.append((self.total_step, r_avg, r_std, obj_a, obj_c)) # update recorder
534 |
535 | if_reach_goal = bool(self.r_max > self.target_return) # check if_reach_goal
536 | if if_reach_goal and self.used_time is None:
537 | self.used_time = int(time.time() - self.start_time)
538 | print(f"{'ID':>2} {'Step':>8} {'TargetR':>8} |"
539 | f"{'avgR':>8} {'stdR':>8} {'UsedTime':>8} ########\n"
540 | f"{self.agent_id:<2} {self.total_step:8.2e} {self.target_return:8.2f} |"
541 | f"{r_avg:8.2f} {r_std:8.2f} {self.used_time:>8} ########")
542 |
543 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |"
544 | f"{r_avg:8.2f} {r_std:8.2f} {obj_a:8.2f} {obj_c:8.2f} |"
545 | f"{s_avg:6.0f} {s_std:4.0f}")
546 | else:
547 | if_reach_goal = False
548 | return if_reach_goal
549 |
550 | def draw_plot(self):
551 | if len(self.recorder) == 0:
552 | print("| save_npy_draw_plot() WARNNING: len(self.recorder)==0")
553 | return None
554 |
555 | '''convert to array and save as npy'''
556 | np.save('%s/recorder.npy' % self.cwd, self.recorder)
557 |
558 | '''draw plot and save as jpg'''
559 | train_time = int(time.time() - self.start_time)
560 | total_step = int(self.recorder[-1][0])
561 | save_title = f"plot_step_time_maxR_{int(total_step)}_{int(train_time)}_{self.r_max:.3f}"
562 |
563 | save_learning_curve(self.recorder, self.cwd, save_title)
564 |
565 | @staticmethod
566 | def get_r_avg_std_s_avg_std(rewards_steps_list):
567 | rewards_steps_ary = np.array(rewards_steps_list)
568 | r_avg, s_avg = rewards_steps_ary.mean(axis=0) # average of episode return and episode step
569 | r_std, s_std = rewards_steps_ary.std(axis=0) # standard dev. of episode return and episode step
570 | return r_avg, r_std, s_avg, s_std
571 |
572 |
573 | def get_episode_return(env, act, device) -> (float, int):
574 | episode_return = 0.0 # sum of rewards in an episode
575 | episode_step = 1
576 | max_step = env.max_step
577 | if_discrete = env.if_discrete
578 |
579 | state = env.reset()
580 | for episode_step in range(max_step):
581 | s_tensor = torch.as_tensor((state,), device=device)
582 | a_tensor = act(s_tensor)
583 | if if_discrete:
584 | a_tensor = a_tensor.argmax(dim=1)
585 | action = a_tensor.detach().cpu().numpy()[0] # detach() is not strictly needed because torch.no_grad() wraps this call
586 | state, reward, done, _ = env.step(action)
587 | episode_return += reward
588 | if done:
589 | break
590 | episode_return = getattr(env, 'episode_return', episode_return)
591 | return episode_return, episode_step + 1
592 |
593 |
594 | def save_learning_curve(recorder, cwd='.', save_title='learning curve'):
595 | recorder = np.array(recorder) # recorder_ary.append((self.total_step, r_avg, r_std, obj_a, obj_c))
596 | steps = recorder[:, 0] # x-axis is training steps
597 | r_avg = recorder[:, 1]
598 | r_std = recorder[:, 2]
599 | obj_a = recorder[:, 3]
600 | obj_c = recorder[:, 4]
601 |
602 | '''plot subplots'''
603 | import matplotlib as mpl
604 | mpl.use('Agg')
605 | """Generating matplotlib graphs without a running X server [duplicate]
606 | write `mpl.use('Agg')` before `import matplotlib.pyplot as plt`
607 | https://stackoverflow.com/a/4935945/9293137
608 | """
609 | import matplotlib.pyplot as plt
610 | fig, axs = plt.subplots(2)
611 |
612 | axs0 = axs[0]
613 | axs0.cla()
614 | color0 = 'lightcoral'
615 | axs0.set_xlabel('Total Steps')
616 | axs0.set_ylabel('Episode Return')
617 | axs0.plot(steps, r_avg, label='Episode Return', color=color0)
618 | axs0.fill_between(steps, r_avg - r_std, r_avg + r_std, facecolor=color0, alpha=0.3)
619 |
620 | ax11 = axs[1]
621 | ax11.cla()
622 | color11 = 'royalblue'
623 | ax11.set_xlabel('Total Steps')
624 | ax11.set_ylabel('objA', color=color11)
625 | ax11.plot(steps, obj_a, label='objA', color=color11)
626 | ax11.tick_params(axis='y', labelcolor=color11)
627 |
628 | ax12 = axs[1].twinx()
629 | color12 = 'darkcyan'
630 | ax12.set_ylabel('objC', color=color12)
631 | ax12.fill_between(steps, obj_c, facecolor=color12, alpha=0.2, )
632 | ax12.tick_params(axis='y', labelcolor=color12)
633 |
634 | '''plot save'''
635 | plt.title(save_title, y=2.3)
636 | plt.savefig(f"{cwd}/plot_learning_curve.jpg")
637 | plt.close('all') # avoiding warning about too many open figures, rcParam `figure.max_open_warning`
638 | # plt.show() # with mpl.use('Agg') (no GUI backend), plt.show() cannot display the figure
639 |
640 |
641 | def explore_before_training(env, buffer, target_step, reward_scale, gamma) -> int:
642 | # just for off-policy. Because on-policy don't explore before training.
643 | if_discrete = env.if_discrete
644 | action_dim = env.action_dim
645 |
646 | state = env.reset()
647 | steps = 0
648 |
649 | while steps < target_step:
650 | action = rd.randint(action_dim) if if_discrete else rd.uniform(-1, 1, size=action_dim)
651 | next_state, reward, done, _ = env.step(action)
652 | steps += 1
653 |
654 | scaled_reward = reward * reward_scale
655 | mask = 0.0 if done else gamma
656 | other = (scaled_reward, mask, action) if if_discrete else (scaled_reward, mask, *action)
657 | buffer.append_buffer(state, other)
658 |
659 | state = env.reset() if done else next_state
660 | return steps
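# Editor's note: each transition appended above is stored as (state, other) with
# other = (reward * reward_scale, mask, action components) and
# mask = 0.0 if done else gamma, i.e. the discount and the terminal flag are
# folded into one number that downstream target computations can multiply
# directly with the next-state value.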
661 |
--------------------------------------------------------------------------------
/env_single.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import pandas as pd
4 | import numpy as np
5 |
6 | import config
7 | from train_helper import insert_train_history_record_sqlite
8 |
9 |
10 | class StockTradingEnvSingle:
11 | def __init__(self, cwd='./envs/FinRL', gamma=0.99,
12 | max_stock=1e2, initial_capital=1e6, buy_cost_pct=1e-3, sell_cost_pct=1e-3,
13 | start_date='2008-03-19', end_date='2016-01-01', env_eval_date='2021-01-01',
14 | ticker_list=None, tech_indicator_list=None, initial_stocks=None, if_eval=False):
15 |
16 | self.price_ary, self.tech_ary, self.tic_ary, self.date_ary = self.load_data(cwd, if_eval, ticker_list,
17 | tech_indicator_list,
18 | start_date, end_date,
19 | env_eval_date, )
20 | stock_dim = self.price_ary.shape[1]
21 |
22 | self.gamma = gamma
23 | self.max_stock = max_stock
24 | self.buy_cost_pct = buy_cost_pct
25 | self.sell_cost_pct = sell_cost_pct
26 | self.initial_capital = initial_capital
27 | self.initial_stocks = np.zeros(stock_dim, dtype=np.float32) if initial_stocks is None else initial_stocks
28 |
29 | # reset()
30 | self.day = None
31 | self.amount = None
32 | self.stocks = None
33 | self.total_asset = None
34 | self.initial_total_asset = None
35 | self.gamma_reward = 0.0
36 |
37 | # environment information
38 | self.env_name = 'StockTradingEnv-v1'
39 | self.state_dim = 1 + 2 * stock_dim + self.tech_ary.shape[1]
40 | self.action_dim = stock_dim
41 | self.max_step = len(self.price_ary) - 1
42 | self.if_discrete = False
43 | self.target_return = 3.5
44 | self.episode_return = 0.0
45 |
46 | # buffer for the trade-detail output text
47 | self.output_text_trade_detail = ''
48 |
49 | # output list of buy/sell records
50 | self.list_buy_or_sell_output = []
51 |
52 | # reward scaling factor
53 | self.reward_scaling = 0.0
54 |
55 | # whether this env is for eval or train
56 | self.if_eval = if_eval
57 |
58 | pass
59 |
60 | def reset(self):
61 | self.day = 0
62 | price = self.price_ary[self.day]
63 |
64 | # ----
65 | np.random.seed(round(time.time()))
66 | random_float = np.random.uniform(0.0, 1.01, size=self.initial_stocks.shape)
67 |
68 | # for an actual prediction (output to the web report), keep holdings and cash fixed
69 | if config.IF_ACTUAL_PREDICT is True:
70 | self.stocks = self.initial_stocks.copy()
71 | self.amount = self.initial_capital - (self.stocks * price).sum()
72 | pass
73 | else:
74 | # if this is eval during the training process
75 | self.stocks = random_float * self.initial_stocks.copy() // 100 * 100
76 | self.amount = self.initial_capital * np.random.uniform(0.95, 1.05) - (self.stocks * price).sum()
77 | pass
78 | pass
79 |
80 | self.total_asset = self.amount + (self.stocks * price).sum()
81 | self.initial_total_asset = self.total_asset
82 | self.gamma_reward = 0.0
83 |
84 | state = np.hstack((self.amount * 2 ** -12,
85 | price,
86 | self.stocks * 2 ** -4,
87 | self.tech_ary[self.day],)).astype(np.float32) * 2 ** -8
88 |
89 | # clear the output text buffer
90 | self.output_text_trade_detail = ''
91 |
92 | # reset the output list
93 | self.list_buy_or_sell_output.clear()
94 | self.list_buy_or_sell_output = []
95 |
96 | return state
97 |
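# Editor's note: the observation assembled in reset() (and rebuilt in step()) is
# the flat vector
#   state = 2**-8 * [amount * 2**-12, close_1..N, shares_1..N * 2**-4, tech_1..M]
# whose length matches self.state_dim = 1 + 2 * stock_dim + tech_ary.shape[1]
# set in __init__. The power-of-two factors only rescale features to comparable
# magnitudes for the networks; they do not change the trading logic.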
98 | def step(self, actions):
99 | actions_temp = (actions * self.max_stock).astype(int)
100 |
101 | # ----
102 | yesterday_price = self.price_ary[self.day]
103 | # ----
104 |
105 | self.day += 1
106 |
107 | price = self.price_ary[self.day]
108 |
109 | tic_ary_temp = self.tic_ary[self.day]
110 | # date
111 | date_ary_temp = self.date_ary[self.day]
112 | date_temp = date_ary_temp[0]
113 |
114 | self.output_text_trade_detail += f'第 {self.day + 1} 天,{date_temp}\r\n'
115 |
116 | for index in np.where(actions_temp < 0)[0]: # sell_index:
117 | if price[index] > 0: # sell only if the current price is > 0 (no missing data on this date)
118 |
119 | sell_num_shares = min(self.stocks[index], -actions_temp[index])
120 |
121 | tic_temp = tic_ary_temp[index]
122 |
123 | if sell_num_shares >= 100:
124 | # if the action is <= -100, floor-divide and sell in whole lots (multiples of 100 shares)
125 | sell_num_shares = sell_num_shares // 100 * 100
126 | self.stocks[index] -= sell_num_shares
127 | self.amount += price[index] * sell_num_shares * (1 - self.sell_cost_pct)
128 |
129 | if self.if_eval is True:
130 | # tic, date, sell/buy, hold, day x
131 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
132 |
133 | list_item = (tic_temp, date_temp, -1 * sell_num_shares, self.stocks[index], self.day + 1,
134 | episode_return_temp)
135 | # append to the output list
136 | self.list_buy_or_sell_output.append(list_item)
137 | pass
138 | else:
139 | # when sell_num_shares < 100: if self.stocks[index] >= 100, amplify the action and sell one lot (100 shares)
140 | if self.stocks[index] >= 100:
141 | sell_num_shares = 100
142 | self.stocks[index] -= sell_num_shares
143 | self.amount += price[index] * sell_num_shares * (1 - self.sell_cost_pct)
144 |
145 | if self.if_eval is True:
146 | # tic, date, sell/buy, hold, day x
147 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
148 |
149 | list_item = (tic_temp, date_temp, -1 * sell_num_shares, self.stocks[index], self.day + 1,
150 | episode_return_temp)
151 | # append to the output list
152 | self.list_buy_or_sell_output.append(list_item)
153 | pass
154 | else:
155 | # if self.stocks[index] is less than one lot (100 shares), do nothing
156 | sell_num_shares = 0
157 |
158 | if self.if_eval is True:
159 | # tic, date, sell/buy, hold, day x
160 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
161 |
162 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
163 | # append to the output list
164 | self.list_buy_or_sell_output.append(list_item)
165 | pass
166 | pass
167 | pass
168 |
169 | if yesterday_price[index] != 0:
170 | price_diff_percent = str(round((price[index] - yesterday_price[index]) / yesterday_price[index], 4))
171 | else:
172 | price_diff_percent = '0.0'
173 | pass
174 |
175 | price_diff = str(round(price[index] - yesterday_price[index], 6))
176 | self.output_text_trade_detail += f' > {tic_temp},预测涨跌:{round(-1 * actions[index], 4)},' \
177 | f'实际涨跌:{price_diff_percent} ¥{price_diff} 元,' \
178 | f'卖出:{sell_num_shares} 股, 持股数量 {self.stocks[index]},' \
179 | f'现金:{self.amount},资产:{self.total_asset} \r\n'
180 | pass
181 | pass
182 |
183 | for index in np.where(actions_temp > 0)[0]: # buy_index:
184 | if price[index] > 0: # Buy only if the price is > 0 (no missing data in this particular date)
185 | buy_num_shares = min(self.amount // price[index], actions_temp[index])
186 |
187 | tic_temp = tic_ary_temp[index]
188 |
189 | if buy_num_shares >= 100:
190 | # if the action is >= +100, floor-divide and buy in whole lots (multiples of 100 shares)
191 | buy_num_shares = buy_num_shares // 100 * 100
192 | self.stocks[index] += buy_num_shares
193 | self.amount -= price[index] * buy_num_shares * (1 + self.buy_cost_pct)
194 |
195 | if self.if_eval is True:
196 | # tic, date, sell/buy, hold, day x
197 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
198 |
199 | list_item = (tic_temp, date_temp, buy_num_shares, self.stocks[index], self.day + 1,
200 | episode_return_temp)
201 |
202 | # append to the output list
203 | self.list_buy_or_sell_output.append(list_item)
204 | pass
205 | else:
206 | # when buy_num_shares < 100: if self.amount // price[index] >= 100, amplify the action and buy one lot (100 shares)
207 | if (self.amount // price[index]) >= 100:
208 | buy_num_shares = 100
209 | self.stocks[index] += buy_num_shares
210 | self.amount -= price[index] * buy_num_shares * (1 + self.buy_cost_pct)
211 |
212 | if self.if_eval is True:
213 | # tic, date, sell/buy, hold, day x
214 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
215 |
216 | list_item = (tic_temp, date_temp, buy_num_shares, self.stocks[index], self.day + 1,
217 | episode_return_temp)
218 |
219 | # append to the output list
220 | self.list_buy_or_sell_output.append(list_item)
221 | else:
222 | # if self.amount // price[index] is less than 100, do nothing
223 | # less than one lot, so do not buy
224 | buy_num_shares = 0
225 |
226 | if self.if_eval is True:
227 | # tic, date, sell/buy, hold, day x
228 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
229 |
230 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
231 | # append to the output list
232 | self.list_buy_or_sell_output.append(list_item)
233 | pass
234 | pass
235 | pass
236 |
237 | if yesterday_price[index] != 0:
238 | price_diff_percent = str(round((price[index] - yesterday_price[index]) / yesterday_price[index], 4))
239 | else:
240 | price_diff_percent = '0.0'
241 | pass
242 |
243 | price_diff = str(round(price[index] - yesterday_price[index], 6))
244 | self.output_text_trade_detail += f' > {tic_temp},预测涨跌:{round(-1 * actions[index], 4)},' \
245 | f'实际涨跌:{price_diff_percent} ¥{price_diff} 元,' \
246 | f'买入:{buy_num_shares} 股, 持股数量:{self.stocks[index]},' \
247 | f'现金:{self.amount},资产:{self.total_asset} \r\n'
248 |
249 | pass
250 | pass
251 |
252 | if self.if_eval is True:
253 | for index in np.where(actions_temp == 0)[0]: # hold_index:
254 | if price[index] > 0: # record the hold only if the price is > 0 (no missing data on this date)
255 | # tic, date, sell/buy, hold, day x
256 | tic_temp = tic_ary_temp[index]
257 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset
258 |
259 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp)
260 | # append to the output list
261 | self.list_buy_or_sell_output.append(list_item)
262 | pass
263 | pass
264 | pass
265 | pass
266 |
267 | state = np.hstack((self.amount * 2 ** -12,
268 | price,
269 | self.stocks * 2 ** -4,
270 | self.tech_ary[self.day],)).astype(np.float32) * 2 ** -8
271 |
272 | total_asset = self.amount + (self.stocks * price).sum()
273 | reward = (total_asset - self.total_asset) * self.reward_scaling
274 |
275 | self.total_asset = total_asset
276 |
277 | self.gamma_reward = self.gamma_reward * self.gamma + reward
278 | done = self.day == self.max_step
279 |
280 | if done:
281 | reward = self.gamma_reward
282 | self.episode_return = total_asset / self.initial_total_asset
283 | # if self.if_eval is True:
284 | # print(config.AGENT_NAME, 'eval DONE reward:', str(reward))
285 | # pass
286 | # else:
287 | # print(config.AGENT_NAME, 'train DONE reward:', str(reward))
288 | # pass
289 | # pass
290 |
291 | if config.IF_ACTUAL_PREDICT:
292 | print(self.output_text_trade_detail)
293 | print(f'第 {self.day + 1} 天,{date_temp},现金:{self.amount},'
294 | f'股票:{str((self.stocks * price).sum())},总资产:{self.total_asset}')
295 | else:
296 | # if self.if_eval is True:
297 | # print(config.AGENT_NAME, 'eval reward', str(reward))
298 | # else:
299 | # print(config.AGENT_NAME, 'train reward', str(reward))
300 | # pass
301 |
302 | if reward > config.REWARD_THRESHOLD:
303 | # if this is prediction (eval)
304 | if self.if_eval is True:
305 | insert_train_history_record_sqlite(model_id=config.MODEL_HYPER_PARAMETERS, train_reward_value=0.0,
306 | eval_reward_value=reward)
307 |
308 | print('>>>>', config.AGENT_NAME, 'eval reward', str(reward))
309 | else:
310 | # if this is training
311 | insert_train_history_record_sqlite(model_id=config.MODEL_HYPER_PARAMETERS,
312 | train_reward_value=reward, eval_reward_value=0.0)
313 |
314 | print(config.AGENT_NAME, 'train reward', str(reward))
315 | pass
316 |
317 | pass
318 | pass
319 |
320 | return state, reward, done, dict()
321 |
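# Editor's note: a compressed view of the reward defined in step() above
# (attribute names as used in the code):
#
#   total_asset  = amount + (stocks * price).sum()
#   reward       = (total_asset - previous_total_asset) * reward_scaling
#   gamma_reward = gamma_reward * gamma + reward      # running discounted sum
#   if done:
#       reward = gamma_reward
#       episode_return = total_asset / initial_total_asset
#
# Trades are rounded to whole A-share board lots of 100 shares: orders of at
# least one lot are floor-divided to a multiple of 100, smaller orders are
# amplified to exactly one lot when the holding/cash allows it, otherwise the
# position is left unchanged and a zero-quantity record is logged in eval mode.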
322 | def load_data(self, cwd='./envs/FinRL', if_eval=None,
323 | ticker_list=None, tech_indicator_list=None,
324 | start_date='2008-03-19', end_date='2016-01-01', env_eval_date='2021-01-01'):
325 |
326 | # read the fe_fillzero (feature-engineered, zero-filled) data from the database
327 | from stock_data import StockData
328 | processed_df = StockData.get_fe_fillzero_from_sqlite(begin_date=start_date, end_date=env_eval_date,
329 | list_stock_code=config.SINGLE_A_STOCK_CODE,
330 | if_actual_predict=config.IF_ACTUAL_PREDICT,
331 | table_name='fe_fillzero')
332 |
333 | def data_split_train(df, start, end):
334 | data = df[(df.date >= start) & (df.date < end)]
335 | data = data.sort_values(["date", "tic"], ignore_index=True)
336 | data.index = data.date.factorize()[0]
337 | return data
338 |
339 | def data_split_eval(df, start, end):
340 | data = df[(df.date >= start) & (df.date <= end)]
341 | data = data.sort_values(["date", "tic"], ignore_index=True)
342 | data.index = data.date.factorize()[0]
343 | return data
344 |
345 | train_df = data_split_train(processed_df, start_date, end_date)
346 | eval_df = data_split_eval(processed_df, end_date, env_eval_date)
347 |
348 | train_price_ary, train_tech_ary, train_tic_ary, train_date_ary = self.convert_df_to_ary(train_df,
349 | tech_indicator_list)
350 | eval_price_ary, eval_tech_ary, eval_tic_ary, eval_date_ary = self.convert_df_to_ary(eval_df,
351 | tech_indicator_list)
352 |
353 | if if_eval is None:
354 | price_ary = np.concatenate((train_price_ary, eval_price_ary), axis=0)
355 | tech_ary = np.concatenate((train_tech_ary, eval_tech_ary), axis=0)
356 | tic_ary = None
357 | date_ary = None
358 | elif if_eval:
359 | price_ary = eval_price_ary
360 | tech_ary = eval_tech_ary
361 | tic_ary = eval_tic_ary
362 | date_ary = eval_date_ary
363 | else:
364 | price_ary = train_price_ary
365 | tech_ary = train_tech_ary
366 | tic_ary = train_tic_ary
367 | date_ary = train_date_ary
368 |
369 | return price_ary, tech_ary, tic_ary, date_ary
370 |
371 | @staticmethod
372 | def convert_df_to_ary(df, tech_indicator_list):
373 | tech_ary = list()
374 | price_ary = list()
375 | tic_ary = list()
376 | date_ary = list()
377 |
378 | from stock_data import fields_prep
379 | columns_list = fields_prep.split(',')
380 |
381 | for day in range(len(df.index.unique())):
382 | # item = df.loc[day]
383 | list_temp = df.loc[day]
384 | if list_temp.ndim == 1:
385 | list_temp = [df.loc[day]]
386 | pass
387 | item = pd.DataFrame(data=list_temp, columns=columns_list)
388 |
389 | tech_items = [item[tech].values.tolist() for tech in tech_indicator_list]
390 | tech_items_flatten = sum(tech_items, [])
391 | tech_ary.append(tech_items_flatten)
392 | price_ary.append(item.close) # adjusted close price (adjcp)
393 |
394 | # ----
395 | # tic_ary.append(list(item.tic))
396 | # date_ary.append(list(item.date))
397 |
398 | tic_ary.append(item.tic)
399 | date_ary.append(item.date)
400 |
401 | # ----
402 |
403 | pass
404 |
405 | price_ary = np.array(price_ary)
406 | tech_ary = np.array(tech_ary)
407 |
408 | tic_ary = np.array(tic_ary)
409 | date_ary = np.array(date_ary)
410 |
411 | print(f'| price_ary.shape: {price_ary.shape}, tech_ary.shape: {tech_ary.shape}')
412 | return price_ary, tech_ary, tic_ary, date_ary
413 |
414 | def draw_cumulative_return(self, args, _torch) -> list:
415 | state_dim = self.state_dim
416 | action_dim = self.action_dim
417 |
418 | agent = args.agent
419 | net_dim = args.net_dim
420 | cwd = args.cwd
421 |
422 | agent.init(net_dim, state_dim, action_dim)
423 | agent.save_load_model(cwd=cwd, if_save=False)
424 | act = agent.act
425 | device = agent.device
426 |
427 | state = self.reset()
428 | episode_returns = list() # the cumulative_return / initial_account
429 | with _torch.no_grad():
430 | for i in range(self.max_step):
431 | s_tensor = _torch.as_tensor((state,), device=device)
432 | a_tensor = act(s_tensor)
433 | action = a_tensor.cpu().numpy()[0] # detach() is not needed here because torch.no_grad() wraps this call
434 | state, reward, done, _ = self.step(action)
435 |
436 | total_asset = self.amount + (self.price_ary[self.day] * self.stocks).sum()
437 | episode_return = total_asset / self.initial_total_asset
438 | episode_returns.append(episode_return)
439 | if done:
440 | break
441 |
442 | import matplotlib.pyplot as plt
443 | plt.plot(episode_returns)
444 | plt.grid()
445 | plt.title('cumulative return')
446 | plt.xlabel('day')
447 | plt.ylabel('multiple of initial_account')
448 | plt.savefig(f'{cwd}/cumulative_return.jpg')
449 | return episode_returns
450 |
451 |
452 | class FeatureEngineer:
453 | """Provides methods for preprocessing the stock price data
454 | from finrl.preprocessing.preprocessors import FeatureEngineer
455 |
456 | Attributes
457 | ----------
458 | use_technical_indicator : boolean
459 | use technical indicators or not
460 | tech_indicator_list : list
461 | a list of technical indicator names (modified from config.py)
462 | use_turbulence : boolean
463 | use turbulence index or not
464 | user_defined_feature : boolean
465 | use user-defined features or not
466 |
467 | Methods
468 | -------
469 | preprocess_data()
470 | main method to do the feature engineering
471 |
472 | """
473 |
474 | def __init__(
475 | self,
476 | use_technical_indicator=True,
477 | tech_indicator_list=None, # config.TECHNICAL_INDICATORS_LIST,
478 | use_turbulence=False,
479 | user_defined_feature=False,
480 | ):
481 | self.use_technical_indicator = use_technical_indicator
482 | self.tech_indicator_list = tech_indicator_list
483 | self.use_turbulence = use_turbulence
484 | self.user_defined_feature = user_defined_feature
485 |
486 | def preprocess_data(self, df):
487 | """main method to do the feature engineering
488 | @:param df: source dataframe
489 | @:return: the processed dataframe
490 | """
491 |
492 | if self.use_technical_indicator:
493 | # add technical indicators using stockstats
494 | df = self.add_technical_indicator(df)
495 | print("Successfully added technical indicators")
496 |
497 | # add turbulence index for multiple stock
498 | if self.use_turbulence:
499 | df = self.add_turbulence(df)
500 | print("Successfully added turbulence index")
501 |
502 | # add user defined feature
503 | if self.user_defined_feature:
504 | df = self.add_user_defined_feature(df)
505 | print("Successfully added user defined features")
506 |
507 | # fill the missing values at the beginning and the end
508 | df = df.fillna(method="bfill").fillna(method="ffill")
509 | return df
510 |
511 | def add_technical_indicator(self, data):
512 | """
513 | calculate technical indicators
514 | use the stockstats package to add technical indicators
515 | :param data: (df) pandas dataframe
516 | :return: (df) pandas dataframe
517 | """
518 | from stockstats import StockDataFrame as Sdf # for Sdf.retype
519 |
520 | df = data.copy()
521 | df = df.sort_values(by=['tic', 'date'])
522 | stock = Sdf.retype(df.copy())
523 | unique_ticker = stock.tic.unique()
524 |
525 | for indicator in self.tech_indicator_list:
526 | indicator_df = pd.DataFrame()
527 | for i in range(len(unique_ticker)):
528 | try:
529 | temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
530 | temp_indicator = pd.DataFrame(temp_indicator)
531 | temp_indicator['tic'] = unique_ticker[i]
532 | temp_indicator['date'] = df[df.tic == unique_ticker[i]]['date'].to_list()
533 | indicator_df = indicator_df.append(
534 | temp_indicator, ignore_index=True
535 | )
536 | except Exception as e:
537 | print(e)
538 | df = df.merge(indicator_df[['tic', 'date', indicator]], on=['tic', 'date'], how='left')
539 | df = df.sort_values(by=['date', 'tic'])
540 | return df
541 |
542 | def add_turbulence(self, data):
543 | """
544 | add turbulence index from a precalculated dataframe
545 | :param data: (df) pandas dataframe
546 | :return: (df) pandas dataframe
547 | """
548 | df = data.copy()
549 | turbulence_index = self.calculate_turbulence(df)
550 | df = df.merge(turbulence_index, on="date")
551 | df = df.sort_values(["date", "tic"]).reset_index(drop=True)
552 | return df
553 |
554 | @staticmethod
555 | def add_user_defined_feature(data):
556 | """
557 | add user defined features
558 | :param data: (df) pandas dataframe
559 | :return: (df) pandas dataframe
560 | """
561 | df = data.copy()
562 | df["daily_return"] = df.close.pct_change(1)
563 | # df['return_lag_1']=df.close.pct_change(2)
564 | # df['return_lag_2']=df.close.pct_change(3)
565 | # df['return_lag_3']=df.close.pct_change(4)
566 | # df['return_lag_4']=df.close.pct_change(5)
567 | return df
568 |
569 | @staticmethod
570 | def calculate_turbulence(data):
571 | """calculate turbulence index based on dow 30"""
572 | # can add other market assets
573 | df = data.copy()
574 | df_price_pivot = df.pivot(index="date", columns="tic", values="close")
575 | # use returns to calculate turbulence
576 | df_price_pivot = df_price_pivot.pct_change()
577 |
578 | unique_date = df.date.unique()
579 | # start after a year
580 | start = 252
581 | turbulence_index = [0] * start
582 | # turbulence_index = [0]
583 | count = 0
584 | for i in range(start, len(unique_date)):
585 | current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
586 | # use a one-year rolling window to calculate covariance
587 | hist_price = df_price_pivot[
588 | (df_price_pivot.index < unique_date[i])
589 | & (df_price_pivot.index >= unique_date[i - 252])
590 | ]
591 | # drop tickers that have more missing values than the "oldest" ticker
592 | filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1)
593 |
594 | cov_temp = filtered_hist_price.cov()
595 | current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0)
596 | temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
597 | current_temp.values.T
598 | )
599 | if temp > 0:
600 | count += 1
601 | if count > 2:
602 | turbulence_temp = temp[0][0]
603 | else:
604 | # avoid large outliers while the calculation is just beginning
605 | turbulence_temp = 0
606 | else:
607 | turbulence_temp = 0
608 | turbulence_index.append(turbulence_temp)
609 |
610 | turbulence_index = pd.DataFrame(
611 | {"date": df_price_pivot.index, "turbulence": turbulence_index}
612 | )
613 | return turbulence_index
614 |
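# Editor's note: calculate_turbulence() above computes, for each date t after the
# first 252 trading days, the Mahalanobis distance of that day's return vector
# y_t from the trailing one-year window of returns:
#
#   turbulence_t = (y_t - mu) @ pinv(Sigma) @ (y_t - mu).T
#
# where mu and Sigma are the window's mean vector and covariance matrix; the
# first two positive values are forced to 0 to suppress start-up outliers.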
--------------------------------------------------------------------------------
/run_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import time
4 | import torch
5 | import numpy as np
6 | import numpy.random as rd
7 | from copy import deepcopy
8 | from ElegantRL_master.elegantrl.replay import ReplayBuffer, ReplayBufferMP
9 | from ElegantRL_master.elegantrl.env import PreprocessEnv
10 | import config
11 |
12 | """[ElegantRL](https://github.com/AI4Finance-LLC/ElegantRL)"""
13 |
14 |
15 | class Arguments:
16 | def __init__(self, agent=None, env=None, gpu_id=None, if_on_policy=False):
17 | self.agent = agent # Deep Reinforcement Learning algorithm
18 |
19 | self.cwd = None # current work directory. cwd is None means set it automatically
20 | self.env = env # the environment for training
21 | self.env_eval = None # the environment for evaluating
22 | self.gpu_id = gpu_id # choose the GPU for running. gpu_id is None means set it automatically
23 |
24 | '''Arguments for training (off-policy)'''
25 | self.net_dim = 2 ** 8 # the network width
26 | self.batch_size = 2 ** 8 # num of transitions sampled from replay buffer.
27 | self.repeat_times = 2 ** 0 # repeatedly update network to keep critic's loss small
28 | self.target_step = 2 ** 10 # collect target_step, then update network
29 | self.max_memo = 2 ** 17 # capacity of replay buffer
30 | if if_on_policy: # (on-policy)
31 | self.net_dim = 2 ** 9
32 | self.batch_size = 2 ** 9
33 | self.repeat_times = 2 ** 4
34 | self.target_step = 2 ** 12
35 | self.max_memo = self.target_step
36 | self.gamma = 0.99 # discount factor of future rewards
37 | self.reward_scale = 2 ** 0 # reward scaling; an approximate target reward is usually close to 256
38 | self.if_per = False # Prioritized Experience Replay for sparse reward
39 |
40 | self.rollout_num = 2 # the number of rollout workers (larger is not always faster)
41 | self.num_threads = 8 # number of CPU threads for evaluating the model, torch.set_num_threads(self.num_threads)
42 |
43 | '''Arguments for evaluate'''
44 | self.break_step = 2 ** 20 # break training after 'total_step > break_step'
45 | self.if_remove = True # remove the cwd folder? (True, False, None:ask me)
46 | self.if_allow_break = True # allow break training when reach goal (early termination)
47 | self.eval_gap = 2 ** 5 # evaluate the agent per eval_gap seconds
48 | self.eval_times1 = 2 ** 2 # evaluation times
49 | self.eval_times2 = 2 ** 4 # evaluation times if 'eval_reward > max_reward'
50 | self.random_seed = 0 # initialize random seed in self.init_before_training()
51 |
52 | def init_before_training(self, if_main=True):
53 | if self.agent is None:
54 | raise RuntimeError('\n| Why agent=None? Assignment args.agent = AgentXXX please.')
55 | if not hasattr(self.agent, 'init'):
56 | raise RuntimeError('\n| There should be agent=AgentXXX() instead of agent=AgentXXX')
57 | if self.env is None:
58 | raise RuntimeError('\n| Why env=None? Assignment args.env = XxxEnv() please.')
59 | if isinstance(self.env, str) or not hasattr(self.env, 'env_name'):
60 | raise RuntimeError('\n| What is env.env_name? use env=PreprocessEnv(env). It is a Wrapper.')
61 |
62 | '''set gpu_id automatically'''
63 | if self.gpu_id is None: # set gpu_id automatically
64 | import sys
65 | self.gpu_id = sys.argv[-1][-4]
66 | else:
67 | self.gpu_id = str(self.gpu_id)
68 | if not self.gpu_id.isdigit(): # set gpu_id as '0' in default
69 | self.gpu_id = '0'
70 |
71 | '''set cwd automatically'''
72 | if self.cwd is None:
73 | # ----
74 | agent_name = self.agent.__class__.__name__
75 | # self.cwd = f'./{agent_name}/{self.env.env_name}_{self.gpu_id}'
76 | # self.cwd = f'./{agent_name}/{self.env.env_name}'
77 | self.cwd = f'./{config.WEIGHTS_PATH}/{self.env.env_name}'
78 |
79 | # model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.BATCH_A_STOCK_CODE[0]}' \
80 | # f'/single_{config.VALI_DAYS_FLAG}'
81 | # ----
82 |
83 | if if_main:
84 | print(f'| GPU id: {self.gpu_id}, cwd: {self.cwd}')
85 |
86 | import shutil # remove history according to bool(if_remove)
87 | if self.if_remove is None:
88 | self.if_remove = bool(input("PRESS 'y' to REMOVE: {}? ".format(self.cwd)) == 'y')
89 | if self.if_remove:
90 | shutil.rmtree(self.cwd, ignore_errors=True)
91 | print("| Remove history")
92 | os.makedirs(self.cwd, exist_ok=True)
93 |
94 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
95 | torch.set_num_threads(self.num_threads)
96 | torch.set_default_dtype(torch.float32)
97 | torch.manual_seed(self.random_seed)
98 | np.random.seed(self.random_seed)
99 |
100 |
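# Editor's note: a minimal, hypothetical wiring of the Arguments class above,
# placeholders only (the real entry code lives in the run_* scripts; the
# agent/env names below are assumptions, not taken from this repo):
#
#   from ElegantRL_master.elegantrl.agent import AgentPPO   # assumed import path
#   env = SomeStockTradingEnv(if_eval=False)    # any env exposing env_name, state_dim, action_dim, ...
#   args = Arguments(agent=AgentPPO(), env=env, if_on_policy=True)
#   args.env_eval = SomeStockTradingEnv(if_eval=True)
#   train_and_evaluate(args)        # single-process training
#   # train_and_evaluate_mp(args)   # multiprocessing variant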
101 | '''single process training'''
102 |
103 |
104 | def train_and_evaluate(args):
105 | args.init_before_training()
106 |
107 | '''basic arguments'''
108 | cwd = args.cwd
109 | env = args.env
110 | agent = args.agent
111 | gpu_id = args.gpu_id # necessary for Evaluator?
112 |
113 | '''training arguments'''
114 | net_dim = args.net_dim
115 | max_memo = args.max_memo
116 | break_step = args.break_step
117 | batch_size = args.batch_size
118 | target_step = args.target_step
119 | repeat_times = args.repeat_times
120 | if_break_early = args.if_allow_break
121 | if_per = args.if_per
122 | gamma = args.gamma
123 | reward_scale = args.reward_scale
124 |
125 | '''evaluating arguments'''
126 | eval_gap = args.eval_gap
127 | eval_times1 = args.eval_times1
128 | eval_times2 = args.eval_times2
129 | if args.env_eval is not None:
130 | env_eval = args.env_eval
131 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()):
132 | env_eval = PreprocessEnv(gym.make(env.env_name))
133 | else:
134 | env_eval = deepcopy(env)
135 |
136 | del args # In order to show these hyper-parameters clearly, I put them above.
137 |
138 | '''init: environment'''
139 | max_step = env.max_step
140 | state_dim = env.state_dim
141 | action_dim = env.action_dim
142 | if_discrete = env.if_discrete
143 |
144 | '''init: Agent, ReplayBuffer, Evaluator'''
145 | agent.init(net_dim, state_dim, action_dim, if_per)
146 |
147 | # ----
148 | # model directory path
149 | # model_folder_path = f'./{config.AGENT_NAME}/batch/{config.BATCH_A_STOCK_CODE[0]}' \
150 | # f'/batch_{config.VALI_DAYS_FLAG}'
151 |
152 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
153 | f'batch_{config.VALI_DAYS_FLAG}'
154 |
155 | # model file path
156 | model_file_path = f'{model_folder_path}/actor.pth'
157 |
158 | # if the model file already exists, load it
159 | if os.path.exists(model_file_path):
160 | agent.save_load_model(model_folder_path, if_save=False)
161 | pass
162 | # ----
163 |
164 | if_on_policy = getattr(agent, 'if_on_policy', False)
165 |
166 | buffer = ReplayBuffer(max_len=max_memo + max_step, state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
167 | if_on_policy=if_on_policy, if_per=if_per, if_gpu=True)
168 |
169 | evaluator = Evaluator(cwd=cwd, agent_id=gpu_id, device=agent.device, env=env_eval,
170 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, )
171 |
172 | '''prepare for training'''
173 | agent.state = env.reset()
174 | if if_on_policy:
175 | steps = 0
176 | else: # explore_before_training for off-policy
177 | with torch.no_grad(): # update replay buffer
178 | steps = explore_before_training(env, buffer, target_step, reward_scale, gamma)
179 |
180 | agent.update_net(buffer, target_step, batch_size, repeat_times) # pre-training and hard update
181 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None
182 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) else None
183 | total_step = steps
184 |
185 | '''start training'''
186 | if_reach_goal = False
187 | while not ((if_break_early and if_reach_goal)
188 | or total_step > break_step
189 | or os.path.exists(f'{cwd}/stop')):
190 | steps = agent.explore_env(env, buffer, target_step, reward_scale, gamma)
191 | total_step += steps
192 |
193 | obj_a, obj_c = agent.update_net(buffer, target_step, batch_size, repeat_times)
194 |
195 | if_reach_goal = evaluator.evaluate_save(agent.act, steps, obj_a, obj_c)
196 | evaluator.draw_plot()
197 |
198 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}')
199 |
200 |
201 | '''multiprocessing training'''
202 |
203 |
204 | def train_and_evaluate_mp(args):
205 | act_workers = args.rollout_num
206 | import multiprocessing as mp # Python built-in multiprocessing library
207 |
208 | pipe1_eva, pipe2_eva = mp.Pipe() # Pipe() for Process mp_evaluate()
209 | pipe2_exp_list = list() # Pipe() list for Process mp_explore()
210 |
211 | process_train = mp.Process(target=mp_train, args=(args, pipe2_eva, pipe2_exp_list))
212 | process_evaluate = mp.Process(target=mp_evaluate, args=(args, pipe1_eva))
213 | process = [process_train, process_evaluate]
214 |
215 | for worker_id in range(act_workers):
216 | exp_pipe1, exp_pipe2 = mp.Pipe(duplex=True)
217 | pipe2_exp_list.append(exp_pipe1)
218 | process.append(mp.Process(target=mp_explore, args=(args, exp_pipe2, worker_id)))
219 |
220 | [p.start() for p in process]
221 | process_evaluate.join()
222 | process_train.join()
223 | [p.terminate() for p in process]
224 |
225 |
226 | def mp_train(args, pipe1_eva, pipe1_exp_list):
227 | args.init_before_training(if_main=False)
228 |
229 | '''basic arguments'''
230 | env = args.env
231 | cwd = args.cwd
232 | agent = args.agent
233 | rollout_num = args.rollout_num
234 |
235 | '''training arguments'''
236 | net_dim = args.net_dim
237 | max_memo = args.max_memo
238 | break_step = args.break_step
239 | batch_size = args.batch_size
240 | target_step = args.target_step
241 | repeat_times = args.repeat_times
242 | if_break_early = args.if_allow_break
243 | if_per = args.if_per
244 | del args # In order to show these hyper-parameters clearly, I put them above.
245 |
246 | '''init: environment'''
247 | max_step = env.max_step
248 | state_dim = env.state_dim
249 | action_dim = env.action_dim
250 | if_discrete = env.if_discrete
251 |
252 | '''init: Agent, ReplayBuffer'''
253 | agent.init(net_dim, state_dim, action_dim, if_per)
254 |
255 | # ----
256 | # model directory path
257 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
258 | f'batch_{config.VALI_DAYS_FLAG}'
259 |
260 | # f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.BATCH_A_STOCK_CODE[0]}/StockTradingEnv-v1'
261 |
262 | # model file path
263 | model_file_path = f'{model_folder_path}/actor.pth'
264 |
265 | # if the model file already exists, load it
266 | if os.path.exists(model_file_path):
267 | agent.save_load_model(model_folder_path, if_save=False)
268 | pass
269 | # ----
270 |
271 | if_on_policy = getattr(agent, 'if_on_policy', False)
272 |
273 | '''send'''
274 | pipe1_eva.send(agent.act) # send
275 | # act = pipe2_eva.recv() # recv
276 |
277 | buffer_mp = ReplayBufferMP(max_len=max_memo + max_step * rollout_num, if_on_policy=if_on_policy,
278 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
279 | rollout_num=rollout_num, if_gpu=True, if_per=if_per)
280 |
281 | '''prepare for training'''
282 | if if_on_policy:
283 | steps = 0
284 | else: # explore_before_training for off-policy
285 | with torch.no_grad(): # update replay buffer
286 | steps = 0
287 | for i in range(rollout_num):
288 | pipe1_exp = pipe1_exp_list[i]
289 |
290 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
291 | buf_state, buf_other = pipe1_exp.recv()
292 |
293 | steps += len(buf_state)
294 | buffer_mp.extend_buffer(buf_state, buf_other, i)
295 |
296 | agent.update_net(buffer_mp, target_step, batch_size, repeat_times) # pre-training and hard update
297 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None
298 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) \
299 | else None
300 | total_step = steps
301 | '''send'''
302 | pipe1_eva.send((agent.act, steps, 0, 0.5)) # send
303 | # act, steps, obj_a, obj_c = pipe2_eva.recv() # recv
304 |
305 | '''start training'''
306 | if_solve = False
307 | while not ((if_break_early and if_solve)
308 | or total_step > break_step
309 | or os.path.exists(f'{cwd}/stop')):
310 | '''update ReplayBuffer'''
311 | steps = 0 # send by pipe1_eva
312 | for i in range(rollout_num):
313 | pipe1_exp = pipe1_exp_list[i]
314 | '''send'''
315 | pipe1_exp.send(agent.act)
316 | # agent.act = pipe2_exp.recv()
317 | '''recv'''
318 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
319 | buf_state, buf_other = pipe1_exp.recv()
320 |
321 | steps += len(buf_state)
322 | buffer_mp.extend_buffer(buf_state, buf_other, i)
323 | total_step += steps
324 |
325 | '''update network parameters'''
326 | obj_a, obj_c = agent.update_net(buffer_mp, target_step, batch_size, repeat_times)
327 |
328 | '''saves the agent with max reward'''
329 | '''send'''
330 | pipe1_eva.send((agent.act, steps, obj_a, obj_c))
331 | # q_i_eva_get = pipe2_eva.recv()
332 |
333 | if_solve = pipe1_eva.recv()
334 |
335 | if pipe1_eva.poll():
336 | '''recv'''
337 | # pipe2_eva.send(if_solve)
338 | if_solve = pipe1_eva.recv()
339 |
340 | buffer_mp.print_state_norm(env.neg_state_avg if hasattr(env, 'neg_state_avg') else None,
341 | env.div_state_std if hasattr(env, 'div_state_std') else None) # 2020-12-12
342 |
343 | '''send'''
344 | pipe1_eva.send('stop')
345 | # q_i_eva_get = pipe2_eva.recv()
346 | time.sleep(4)
347 |
348 |
349 | def mp_explore(args, pipe2_exp, worker_id):
350 | args.init_before_training(if_main=False)
351 |
352 | '''basic arguments'''
353 | env = args.env
354 | agent = args.agent
355 | rollout_num = args.rollout_num
356 |
357 | '''training arguments'''
358 | net_dim = args.net_dim
359 | max_memo = args.max_memo
360 | target_step = args.target_step
361 | gamma = args.gamma
362 | if_per = args.if_per
363 | reward_scale = args.reward_scale
364 |
365 | random_seed = args.random_seed
366 | torch.manual_seed(random_seed + worker_id)
367 | np.random.seed(random_seed + worker_id)
368 | del args # In order to show these hyper-parameters clearly, I put them above.
369 |
370 | '''init: environment'''
371 | max_step = env.max_step
372 | state_dim = env.state_dim
373 | action_dim = env.action_dim
374 | if_discrete = env.if_discrete
375 |
376 | '''init: Agent, ReplayBuffer'''
377 | agent.init(net_dim, state_dim, action_dim, if_per)
378 |
379 | # ----
380 | # model directory path
381 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.BATCH_A_STOCK_CODE[0]}' \
382 | # f'/single_{config.VALI_DAYS_FLAG}'
383 |
384 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \
385 | f'batch_{config.VALI_DAYS_FLAG}'
386 |
387 | # model file path
388 | model_file_path = f'{model_folder_path}/actor.pth'
389 |
390 | # if the model file already exists, load it
391 | if os.path.exists(model_file_path):
392 | agent.save_load_model(model_folder_path, if_save=False)
393 | pass
394 | # ----
395 |
396 | agent.state = env.reset()
397 |
398 | if_on_policy = getattr(agent, 'if_on_policy', False)
399 | buffer = ReplayBuffer(max_len=max_memo // rollout_num + max_step, if_on_policy=if_on_policy,
400 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim,
401 | if_per=if_per, if_gpu=False)
402 |
403 | '''start exploring'''
404 | exp_step = target_step // rollout_num
405 | with torch.no_grad():
406 | if not if_on_policy:
407 | explore_before_training(env, buffer, exp_step, reward_scale, gamma)
408 |
409 | buffer.update_now_len_before_sample()
410 |
411 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
412 | # buf_state, buf_other = pipe1_exp.recv()
413 |
414 | buffer.empty_buffer_before_explore()
415 |
416 | while True:
417 | agent.explore_env(env, buffer, exp_step, reward_scale, gamma)
418 |
419 | buffer.update_now_len_before_sample()
420 |
421 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len]))
422 | # buf_state, buf_other = pipe1_exp.recv()
423 |
424 | buffer.empty_buffer_before_explore()
425 |
426 | # pipe1_exp.send(agent.act)
427 | agent.act = pipe2_exp.recv()
428 |
429 |
430 | def mp_evaluate(args, pipe2_eva):
431 | args.init_before_training(if_main=True)
432 |
433 | '''basic arguments'''
434 | cwd = args.cwd
435 | env = args.env
436 | env_eval = env if args.env_eval is None else args.env_eval
437 | agent_id = args.gpu_id
438 |
439 | '''evaluating arguments'''
440 | eval_gap = args.eval_gap
441 | eval_times1 = args.eval_times1
442 | eval_times2 = args.eval_times2
443 | del args # In order to show these hyper-parameters clearly, I put them above.
444 |
445 | '''init: Evaluator'''
446 | evaluator = Evaluator(cwd=cwd, agent_id=agent_id, device=torch.device("cpu"), env=env_eval,
447 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) # build Evaluator
448 |
449 | '''act_cpu without gradient for pipe1_eva'''
450 | # pipe1_eva.send(agent.act)
451 | act = pipe2_eva.recv()
452 |
453 | act_cpu = deepcopy(act).to(torch.device("cpu")) # for pipe1_eva
454 | [setattr(param, 'requires_grad', False) for param in act_cpu.parameters()]
455 |
456 | '''start evaluating'''
457 | with torch.no_grad(): # speed up running
458 | act, steps, obj_a, obj_c = pipe2_eva.recv() # pipe2_eva (act, steps, obj_a, obj_c)
459 |
460 | if_loop = True
461 | while if_loop:
462 | '''update actor'''
463 | while not pipe2_eva.poll(): # wait until pipe2_eva not empty
464 | time.sleep(1)
465 | steps_sum = 0
466 | while pipe2_eva.poll(): # receive the latest object from pipe
467 | '''recv'''
468 | # pipe1_eva.send((agent.act, steps, obj_a, obj_c))
469 | # pipe1_eva.send('stop')
470 | q_i_eva_get = pipe2_eva.recv()
471 |
472 | if q_i_eva_get == 'stop':
473 | if_loop = False
474 | break
475 | act, steps, obj_a, obj_c = q_i_eva_get
476 | steps_sum += steps
477 | act_cpu.load_state_dict(act.state_dict())
478 | if_solve = evaluator.evaluate_save(act_cpu, steps_sum, obj_a, obj_c)
479 | '''send'''
480 | pipe2_eva.send(if_solve)
481 | # if_solve = pipe1_eva.recv()
482 |
483 | evaluator.draw_plot()
484 |
485 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}')
486 |
487 | while pipe2_eva.poll(): # empty the pipe
488 | pipe2_eva.recv()
489 |
490 |
491 | '''utils'''
492 |
493 |
494 | class Evaluator:
495 | def __init__(self, cwd, agent_id, eval_times1, eval_times2, eval_gap, env, device):
496 | self.recorder = [(0., -np.inf, 0., 0., 0.), ] # total_step, r_avg, r_std, obj_a, obj_c
497 | self.r_max = -np.inf
498 | self.total_step = 0
499 |
500 | self.cwd = cwd # constant
501 | self.device = device
502 | self.agent_id = agent_id
503 | self.eval_gap = eval_gap
504 | self.eval_times1 = eval_times1
505 | self.eval_times2 = eval_times2
506 | self.env = env
507 | self.target_return = env.target_return
508 |
509 | self.used_time = None
510 | self.start_time = time.time()
511 | self.eval_time = -1 # an early time
512 | print(f"{'ID':>2} {'Step':>8} {'MaxR':>8} |"
513 | f"{'avgR':>8} {'stdR':>8} {'objA':>8} {'objC':>8} |"
514 | f"{'avgS':>6} {'stdS':>4}")
515 |
516 | def evaluate_save(self, act, steps, obj_a, obj_c) -> bool:
517 | self.total_step += steps # update total training steps
518 |
519 | if time.time() - self.eval_time > self.eval_gap:
520 | self.eval_time = time.time()
521 |
522 | rewards_steps_list = [get_episode_return(self.env, act, self.device) for _ in range(self.eval_times1)]
523 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list)
524 |
525 |             if r_avg > self.r_max:  # re-evaluate the actor only when a new best looks likely, to save CPU time while keeping precision
526 | rewards_steps_list += [get_episode_return(self.env, act, self.device)
527 | for _ in range(self.eval_times2 - self.eval_times1)]
528 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list)
529 | if r_avg > self.r_max: # save checkpoint with highest episode return
530 | self.r_max = r_avg # update max reward (episode return)
531 |
532 | '''save actor.pth'''
533 | act_save_path = f'{self.cwd}/actor.pth'
534 | torch.save(act.state_dict(), act_save_path)
535 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |") # save policy and print
536 |
537 | self.recorder.append((self.total_step, r_avg, r_std, obj_a, obj_c)) # update recorder
538 |
539 | if_reach_goal = bool(self.r_max > self.target_return) # check if_reach_goal
540 | if if_reach_goal and self.used_time is None:
541 | self.used_time = int(time.time() - self.start_time)
542 | print(f"{'ID':>2} {'Step':>8} {'TargetR':>8} |"
543 | f"{'avgR':>8} {'stdR':>8} {'UsedTime':>8} ########\n"
544 | f"{self.agent_id:<2} {self.total_step:8.2e} {self.target_return:8.2f} |"
545 | f"{r_avg:8.2f} {r_std:8.2f} {self.used_time:>8} ########")
546 |
547 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |"
548 | f"{r_avg:8.2f} {r_std:8.2f} {obj_a:8.2f} {obj_c:8.2f} |"
549 | f"{s_avg:6.0f} {s_std:4.0f}")
550 | else:
551 | if_reach_goal = False
552 | return if_reach_goal
553 |
554 | def draw_plot(self):
555 | if len(self.recorder) == 0:
556 | print("| save_npy_draw_plot() WARNNING: len(self.recorder)==0")
557 | return None
558 |
559 | '''convert to array and save as npy'''
560 |         np.save(f"{self.cwd}/recorder.npy", self.recorder)
561 |
562 | '''draw plot and save as png'''
563 | train_time = int(time.time() - self.start_time)
564 | total_step = int(self.recorder[-1][0])
565 | save_title = f"plot_step_time_maxR_{int(total_step)}_{int(train_time)}_{self.r_max:.3f}"
566 |
567 | save_learning_curve(self.recorder, self.cwd, save_title)
568 |
569 | @staticmethod
570 | def get_r_avg_std_s_avg_std(rewards_steps_list):
571 | rewards_steps_ary = np.array(rewards_steps_list)
572 | r_avg, s_avg = rewards_steps_ary.mean(axis=0) # average of episode return and episode step
573 | r_std, s_std = rewards_steps_ary.std(axis=0) # standard dev. of episode return and episode step
574 | return r_avg, r_std, s_avg, s_std
575 |
576 |
577 | def get_episode_return(env, act, device) -> (float, int):
578 | episode_return = 0.0 # sum of rewards in an episode
579 | episode_step = 1
580 | max_step = env.max_step
581 | if_discrete = env.if_discrete
582 |
583 | state = env.reset()
584 | for episode_step in range(max_step):
585 | s_tensor = torch.as_tensor((state,), device=device)
586 | a_tensor = act(s_tensor)
587 | if if_discrete:
588 | a_tensor = a_tensor.argmax(dim=1)
589 |         action = a_tensor.detach().cpu().numpy()[0]  # detach() is redundant inside the caller's torch.no_grad(), but harmless
590 | state, reward, done, _ = env.step(action)
591 | episode_return += reward
592 | if done:
593 | break
594 | episode_return = getattr(env, 'episode_return', episode_return)
595 | return episode_return, episode_step + 1
596 |
597 |
598 | def save_learning_curve(recorder, cwd='.', save_title='learning curve'):
599 | recorder = np.array(recorder) # recorder_ary.append((self.total_step, r_avg, r_std, obj_a, obj_c))
600 | steps = recorder[:, 0] # x-axis is training steps
601 | r_avg = recorder[:, 1]
602 | r_std = recorder[:, 2]
603 | obj_a = recorder[:, 3]
604 | obj_c = recorder[:, 4]
605 |
606 | '''plot subplots'''
607 | import matplotlib as mpl
608 | mpl.use('Agg')
609 | """Generating matplotlib graphs without a running X server [duplicate]
610 | write `mpl.use('Agg')` before `import matplotlib.pyplot as plt`
611 | https://stackoverflow.com/a/4935945/9293137
612 | """
613 | import matplotlib.pyplot as plt
614 | fig, axs = plt.subplots(2)
615 |
616 | axs0 = axs[0]
617 | axs0.cla()
618 | color0 = 'lightcoral'
619 | axs0.set_xlabel('Total Steps')
620 | axs0.set_ylabel('Episode Return')
621 | axs0.plot(steps, r_avg, label='Episode Return', color=color0)
622 | axs0.fill_between(steps, r_avg - r_std, r_avg + r_std, facecolor=color0, alpha=0.3)
623 |
624 | ax11 = axs[1]
625 | ax11.cla()
626 | color11 = 'royalblue'
627 |     ax11.set_xlabel('Total Steps')
628 | ax11.set_ylabel('objA', color=color11)
629 | ax11.plot(steps, obj_a, label='objA', color=color11)
630 | ax11.tick_params(axis='y', labelcolor=color11)
631 |
632 | ax12 = axs[1].twinx()
633 | color12 = 'darkcyan'
634 | ax12.set_ylabel('objC', color=color12)
635 | ax12.fill_between(steps, obj_c, facecolor=color12, alpha=0.2, )
636 | ax12.tick_params(axis='y', labelcolor=color12)
637 |
638 | '''plot save'''
639 | plt.title(save_title, y=2.3)
640 | plt.savefig(f"{cwd}/plot_learning_curve.jpg")
641 |     plt.close('all')  # avoid the warning about too many open figures (rcParam `figure.max_open_warning`)
642 |     # plt.show()  # with the 'Agg' backend figures are drawn without a GUI, so plt.show() cannot display them
643 |
644 |
645 | def explore_before_training(env, buffer, target_step, reward_scale, gamma) -> int:
646 |     # Off-policy algorithms only: fill the replay buffer with random-action transitions; on-policy algorithms do not explore before training.
647 | if_discrete = env.if_discrete
648 | action_dim = env.action_dim
649 |
650 | state = env.reset()
651 | steps = 0
652 |
653 | while steps < target_step:
654 | action = rd.randint(action_dim) if if_discrete else rd.uniform(-1, 1, size=action_dim)
655 | next_state, reward, done, _ = env.step(action)
656 | steps += 1
657 |
658 | scaled_reward = reward * reward_scale
659 | mask = 0.0 if done else gamma
660 | other = (scaled_reward, mask, action) if if_discrete else (scaled_reward, mask, *action)
661 | buffer.append_buffer(state, other)
662 |
663 | state = env.reset() if done else next_state
664 | return steps
665 |
--------------------------------------------------------------------------------