├── cover.png ├── requirements.txt ├── .gitmodules ├── .idea ├── other.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── vcs.xml ├── modules.xml └── a_share_pipeline.iml ├── utils ├── log.py ├── sqlite.py ├── psqldb.py └── date_time.py ├── reports ├── template │ ├── index.html.template │ └── card_predict_result.template ├── index.html └── sqlite_to_html.py ├── config.py ├── filter_stock.py ├── README.md ├── .gitignore ├── train_single.py ├── train_batch_bak.py ├── train_batch.py ├── train_helper.py ├── predict_single_psql.py ├── predict_batch_psql.py ├── run_single.py ├── env_single.py └── run_batch.py /cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dyh/a_share_pipeline/HEAD/cover.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | baostock 2 | requests 3 | psycopg2-binary 4 | gym 5 | stockstats 6 | matplotlib -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ElegantRL_master"] 2 | path = ElegantRL_master 3 | url = https://github.com/AI4Finance-LLC/ElegantRL.git 4 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(log_file_path, log_level=logging.INFO): 5 | """ 6 | 返回 logger 实例 7 | :param log_file_path: 日志文件路径 8 | :param log_level: 日志记录级别 logging.INFO、logging.ERROR、logging.DEBUG 等等 9 | :return: logger 实例 10 | """ 11 | logger = logging.getLogger() 12 | logger.setLevel(log_level) 13 | 14 | logfile = log_file_path 15 | fh = logging.FileHandler(logfile, mode='a') 16 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") 17 | fh.setFormatter(formatter) 18 | fh.setLevel(log_level) 19 | logger.addHandler(fh) 20 | 21 | sh = logging.StreamHandler() 22 | sh.setLevel(log_level) 23 | logger.addHandler(sh) 24 | 25 | return logger 26 | -------------------------------------------------------------------------------- /.idea/a_share_pipeline.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 18 | 19 | 20 | 22 | -------------------------------------------------------------------------------- /reports/template/index.html.template: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | <%page_title%> 8 | 9 | 10 | 11 | 12 | 13 | 23 | 24 | <%page_content%> 25 | 26 |

页面生成时间 <%page_time_point%>

27 | 28 | -------------------------------------------------------------------------------- /reports/template/card_predict_result.template: -------------------------------------------------------------------------------- 1 | 2 |
3 |

4 | 预测 <%date%> 5 |

6 | 7 |

8 | <%tic%> 持股 <%most_hold%> % 9 | 买卖 <%most_action%> % 10 |

11 | 12 |

13 | [ 点击查看 <%day%> 天预测交易详情 <%tic%>.txt ] 14 |

15 | 16 |
17 |
18 |

<%tic%> <%date%>

19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | <%predict_result_table_tr_td%> 31 |
T-1回报率/低买高卖价差买卖持股算法验证周期预测周期
32 |
33 | 34 | 日K 35 | 36 |
37 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | 2 | # 交易的最小股数,可设置最低100股,即1手 3 | MINIMUM_TRADE_SHARES = 1 4 | 5 | # 调试 state 比例 6 | IF_DEBUG_STATE_SCALE = False 7 | 8 | # 调试 reward 比例 9 | IF_DEBUG_REWARD_SCALE = True 10 | 11 | # 多只股票代码List 12 | BATCH_A_STOCK_CODE = [] 13 | 14 | # update reward 阈值 15 | REWARD_THRESHOLD = 256 * 1.5 16 | 17 | # 超参ID 18 | MODEL_HYPER_PARAMETERS = -1 19 | 20 | # 代理名称 21 | AGENT_NAME = '' 22 | 23 | # 代理预测周期 24 | AGENT_WORK_DAY = 0 25 | 26 | # 预测周期 27 | PREDICT_PERIOD = '' 28 | 29 | # 单支股票代码List 30 | SINGLE_A_STOCK_CODE = [] 31 | 32 | # 真实的预测 33 | IF_ACTUAL_PREDICT = False 34 | 35 | # 工作日标记,用于加载对应的weights 36 | # weights的vali周期 37 | VALI_DAYS_FLAG = '' 38 | 39 | ## time_fmt = '%Y-%m-%d' 40 | START_DATE = '' 41 | START_EVAL_DATE = '' 42 | END_DATE = '' 43 | # 要输出的日期 44 | OUTPUT_DATE = '' 45 | 46 | DATA_SAVE_DIR = f'stock_db' 47 | 48 | # pth路径 49 | WEIGHTS_PATH = 'weights' 50 | 51 | # batch股票数据库地址 52 | STOCK_DB_PATH = './' + DATA_SAVE_DIR + '/stock.db' 53 | 54 | # ---- 55 | # PostgreSQL 56 | PSQL_HOST = '192.168.192.1' 57 | PSQL_PORT = '5432' 58 | PSQL_DATABASE = 'a_share' 59 | PSQL_USER = 'dyh' 60 | PSQL_PASSWORD = '9898BdlLsdfsHsbgp' 61 | # ---- 62 | 63 | ## stockstats technical indicator column names 64 | ## check https://pypi.org/project/stockstats/ for different names 65 | TECHNICAL_INDICATORS_LIST = ['macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 'close_30_sma', 'close_60_sma'] 66 | -------------------------------------------------------------------------------- /filter_stock.py: -------------------------------------------------------------------------------- 1 | import baostock as bs 2 | import pandas as pd 3 | 4 | from stock_data import StockData 5 | 6 | if __name__ == '__main__': 7 | 8 | list1 = StockData.get_batch_a_share_code_list_string(date_filter='2004-05-01') 9 | 10 | # 登陆系统 11 | lg = bs.login() 12 | # 显示登陆返回信息 13 | print('login respond error_code:' + lg.error_code) 14 | print('login respond error_msg:' + lg.error_msg) 15 | 16 | # 获取行业分类数据 17 | rs = bs.query_stock_industry() 18 | # rs = bs.query_stock_basic(code_name="浦发银行") 19 | print('query_stock_industry error_code:' + rs.error_code) 20 | print('query_stock_industry respond error_msg:' + rs.error_msg) 21 | 22 | # 打印结果集 23 | industry_list = [] 24 | # 查询 季频盈利能力 25 | profit_list = [] 26 | 27 | fields1 = ['code', 'code_name', 'industry', 'pubDate', 'statDate', 'roeAvg', 'npMargin', 'gpMargin', 28 | 'netProfit', 'epsTTM', 'MBRevenue', 'totalShare', 'liqaShare'] 29 | 30 | index1 = 0 31 | count1 = len(list1) 32 | 33 | while (rs.error_code == '0') & rs.next(): 34 | # 获取一条记录,将记录合并在一起 35 | industry_item = rs.get_row_data() 36 | 37 | updateDate, code, code_name, industry, industryClassification = industry_item 38 | 39 | if code in list1: 40 | 41 | rs_profit = bs.query_profit_data(code=code, year=2021, quarter=1) 42 | 43 | while (rs_profit.error_code == '0') & rs_profit.next(): 44 | code, pubDate, statDate, roeAvg, npMargin, gpMargin, netProfit, epsTTM, MBRevenue, totalShare, liqaShare = rs_profit.get_row_data() 45 | 46 | # profit_list.append(rs_profit.get_row_data()) 47 | profit_list.append( 48 | [code, code_name, industry, pubDate, statDate, roeAvg, npMargin, gpMargin, netProfit, epsTTM, 49 | MBRevenue, totalShare, liqaShare]) 50 | 51 | pass 52 | pass 53 | 54 | index1 += 1 55 | print(index1, '/', count1) 56 | 57 | pass 58 | 59 | result_profit = pd.DataFrame(profit_list, columns=fields1) 60 | 61 | # result = pd.DataFrame(industry_list, columns=rs.fields) 62 | 63 | # 结果集输出到csv文件 64 | result_profit.to_csv("./stock_industry_profit.csv", index=False) 65 | 66 | print(result_profit) 67 | 68 | # 登出系统 69 | bs.logout() 70 | pass 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Share Pipeline 2 | 3 | - A股决策 Pipeline 4 | 5 | ### 视频 6 | 7 | bilibili 8 | 9 | [![bilibili](https://github.com/dyh/a_share_pipeline/blob/main/cover.png?raw=true)](https://www.bilibili.com/video/BV1ph411Y78q/ "bilibili") 10 | 11 | 12 | ### 使用框架: 13 | - https://github.com/AI4Finance-LLC/ElegantRL 14 | 15 | ### 数据来源: 16 | - baostock.com 17 | - sinajs.cn 18 | 19 | ## 运行环境 20 | 21 | - ubuntu 18.04.5 22 | - python 3.6+,pip 20+ 23 | - pytorch 1.7+ 24 | - pip install -r requirements.txt 25 | 26 | 27 | ## 如何运行 28 | 29 | 1. 下载代码 30 | 31 | ``` 32 | $ git clone https://github.com/dyh/a_share_pipeline.git 33 | ``` 34 | 35 | 2. 进入目录 36 | 37 | ``` 38 | $ cd a_share_pipeline 39 | ``` 40 | 41 | 3. 更新submodule ElegantRL 42 | 43 | ``` 44 | $ cd ElegantRL_master/ 45 | $ git submodule update --init --recursive 46 | $ git pull 47 | $ cd .. 48 | ``` 49 | 50 | > 如果git pull 提示没在任何分支上,可以检出ElegantRL的master分支: 51 | 52 | ``` 53 | $ git checkout master 54 | ``` 55 | 56 | 4. 创建 python 虚拟环境 57 | 58 | ``` 59 | $ python3 -m venv venv 60 | ``` 61 | 62 | > 如果提示找不到 venv,则安装venv: 63 | 64 | ``` 65 | $ sudo apt install python3-pip 66 | $ sudo apt-get install python3-venv 67 | ``` 68 | 69 | 5. 激活虚拟环境 70 | 71 | ``` 72 | $ source venv/bin/activate 73 | ``` 74 | 75 | 6. 升级pip 76 | 77 | ``` 78 | $ python -m pip install --upgrade pip 79 | ``` 80 | 81 | 7. 安装pytorch 82 | 83 | > 根据你的操作系统,运行环境工具和CUDA版本,在 https://pytorch.org/get-started/locally/ 找到对应的安装命令,复制粘贴运行。为了方便演示,以下使用CPU运行,安装命令如下: 84 | 85 | ``` 86 | $ pip3 install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 87 | ``` 88 | 89 | 8. 安装其他软件包 90 | 91 | ``` 92 | $ pip install -r requirements.txt 93 | ``` 94 | 95 | > 如果上述安装出错,可以考虑安装如下: 96 | 97 | ``` 98 | $ sudo apt-get update && sudo apt-get install cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx 99 | ``` 100 | 101 | 9. 训练模型 sh.600036 102 | 103 | ``` 104 | $ python train_single.py 105 | ``` 106 | 107 | 10. 预测数据 sh.600036 108 | 109 | ``` 110 | $ python predict_single_sqlite.py 111 | ``` 112 | 113 | 11. 生成报告 sh.600036 114 | 115 | ``` 116 | $ cd reports 117 | $ python sqlite_to_html.py 118 | ``` 119 | -------------------------------------------------------------------------------- /utils/sqlite.py: -------------------------------------------------------------------------------- 1 | # coding = utf-8 2 | 3 | import logging 4 | import sqlite3 5 | 6 | 7 | class SQLite: 8 | def __init__(self, dbname): 9 | self.name = dbname 10 | self.conn = sqlite3.connect(self.name, check_same_thread=False) 11 | 12 | def commit(self): 13 | self.conn.commit() 14 | 15 | def close(self): 16 | self.conn.close() 17 | 18 | # 执行sql但不查询 19 | def execute_non_query(self, sql, values=()): 20 | value = False 21 | try: 22 | cursor = self.conn.cursor() 23 | cursor.execute(sql, values) 24 | # self.commit() 25 | value = True 26 | except sqlite3.OperationalError as e: 27 | logging.exception(e) 28 | pass 29 | except Exception as e: 30 | logging.exception(e) 31 | finally: 32 | return value 33 | # self.close() 34 | 35 | # 查询,返回所有结果 36 | def fetchall(self, sql, values=()): 37 | try: 38 | cursor = self.conn.cursor() 39 | cursor.execute(sql, values) 40 | values = cursor.fetchall() 41 | return values 42 | except sqlite3.OperationalError as e: 43 | logging.exception(e) 44 | pass 45 | except Exception as e: 46 | logging.exception(e) 47 | finally: 48 | pass 49 | 50 | # 查询,返回一条结果 51 | def fetchone(self, sql, values=()): 52 | try: 53 | cursor = self.conn.cursor() 54 | cursor.execute(sql, values) 55 | value = cursor.fetchone() 56 | return value 57 | except sqlite3.OperationalError as e: 58 | logging.exception(e) 59 | pass 60 | except Exception as e: 61 | logging.exception(e) 62 | finally: 63 | pass 64 | 65 | # 查询,返回一条结果 66 | def table_exists(self, table_name): 67 | try: 68 | cursor = self.conn.cursor() 69 | cursor.execute(f"select name from sqlite_master where type = 'table' and name = '{table_name}'") 70 | value = cursor.fetchone() 71 | return value 72 | except sqlite3.OperationalError as e: 73 | logging.exception(e) 74 | pass 75 | except Exception as e: 76 | logging.exception(e) 77 | finally: 78 | pass 79 | 80 | # 删除一个表 81 | def drop_table(self, table_name): 82 | value = False 83 | try: 84 | cursor = self.conn.cursor() 85 | cursor.execute(f"DROP TABLE IF EXISTS '{table_name}'") 86 | value = True 87 | except sqlite3.OperationalError as e: 88 | logging.exception(e) 89 | pass 90 | except Exception as e: 91 | logging.exception(e) 92 | finally: 93 | return value 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.pth 132 | *.zip 133 | *.npy 134 | *.npz 135 | *.df 136 | *.csv 137 | *.db 138 | 139 | stock.db 140 | 141 | /pipeline/informer/checkpoints/ 142 | /pipeline/informer/temp_dataset/ 143 | /pipeline/finrl/datasets_temp/ 144 | /pipeline/finrl/results/ 145 | /pipeline/finrl/tensorboard_log/ 146 | /pipeline/finrl/trained_models/ 147 | /pipeline/informer/results/ 148 | /pipeline/elegant/AgentPPO/ 149 | /pipeline/elegant/datasets/ 150 | /pipeline/elegant/bak/ 151 | /pipeline/elegant/AgentSAC/ 152 | /pipeline/elegant/AgentDDPG/ 153 | /pipeline/elegant/AgentDuelingDQN/ 154 | /pipeline/elegant/AgentTD3/ 155 | /pipeline/elegant/AgentModSAC/ 156 | /pipeline/elegant/AgentDoubleDQN/ 157 | /pipeline/elegant/AgentSharedSAC/ 158 | /weights/ 159 | /stock_db/ 160 | /temp/ 161 | -------------------------------------------------------------------------------- /reports/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | A股预测 8 | 9 | 10 | 11 | 12 | 13 | 23 | 24 | 25 |
26 |

27 | 预测 2021-06-24 周四 28 |

29 | 30 |

31 | sh.600036 持股 0.0 % 32 | 买卖 0.0 % 33 |

34 | 35 |

36 | [ 点击查看 36 天预测交易详情 sh.600036.txt ] 37 |

38 | 39 |
40 |
41 |

sh.600036 2021-06-24 周四

42 |
43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 |
T-1回报率/低买高卖价差买卖持股算法验证周期预测周期
5.22% / 9.91%-97.0%0.0%PPO90天第36/40天
3.46% / 9.91%0.0%0.0%SAC20天第36/40天
3.37% / 9.91%0.0%0.0%DDPG60天第36/40天
2.6% / 9.91%0.0%97.0%PPO50天第36/40天
2.6% / 9.91%0.0%97.0%SAC30天第36/40天
2.6% / 9.91%0.0%97.0%SAC90天第36/40天
2.3% / 9.91%0.0%97.0%PPO20天第36/40天
2.22% / 9.91%0.0%97.0%DDPG30天第36/40天
1.81% / 9.91%-97.0%0.0%TD350天第36/40天
1.6% / 9.91%28.0%69.0%SAC60天第36/40天
1.44% / 9.91%0.0%93.0%PPO40天第36/40天
1.4% / 9.91%0.0%93.0%TD390天第36/40天
1.19% / 9.91%0.0%93.0%PPO30天第36/40天
-0.15% / 9.91%0.0%90.0%SAC72天第36/40天
-0.33% / 9.91%0.0%93.0%TD330天第36/40天
-1.06% / 9.91%0.0%93.0%DDPG90天第36/40天
-1.21% / 9.91%0.0%0.0%TD372天第36/40天
-1.24% / 9.91%0.0%0.0%DDPG50天第36/40天
-1.39% / 9.91%-76.0%0.0%TD320天第36/40天
-1.98% / 9.91%0.0%0.0%DDPG20天第36/40天
-2.29% / 9.91%17.0%41.0%SAC50天第36/40天
-2.33% / 9.91%0.0%0.0%TD360天第36/40天
-2.44% / 9.91%0.0%0.0%SAC40天第36/40天
-3.08% / 9.91%-83.0%7.0%PPO60天第36/40天
-3.47% / 9.91%0.0%0.0%DDPG72天第36/40天
-3.94% / 9.91%-90.0%0.0%DDPG40天第36/40天
-4.27% / 9.91%90.0%90.0%TD340天第36/40天
-12.21% / 9.91%83.0%83.0%PPO72天第36/40天
55 |
56 | 57 | 日K 58 | 59 |
60 | 61 | 62 | 63 |

页面生成时间 2021-06-24 12:49:45

64 | 65 | -------------------------------------------------------------------------------- /utils/psqldb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import time 3 | 4 | import logging 5 | import psycopg2 6 | 7 | 8 | # 判断是否数据库超时或者数据库连接关闭 9 | def is_timeout_or_closed_error(str_exception): 10 | if "Connection timed out" in str_exception or "connection already closed" in str_exception: 11 | return True 12 | else: 13 | return False 14 | 15 | 16 | class Psqldb: 17 | def __init__(self, database, user, password, host, port): 18 | self.database = database 19 | self.user = user 20 | self.password = password 21 | self.host = host 22 | self.port = port 23 | self.conn = None 24 | # 调用连接函数 25 | self.re_connect() 26 | 27 | # 再次连接数据库 28 | def re_connect(self): 29 | # 如果连接数据库超时,则循环再次连接 30 | while True: 31 | try: 32 | print(__name__, "re_connect") 33 | 34 | self.conn = psycopg2.connect(database=self.database, user=self.user, 35 | password=self.password, host=self.host, port=self.port) 36 | # 如果连接正常,则退出循环 37 | break 38 | except psycopg2.OperationalError as e: 39 | logging.exception(e) 40 | except Exception as e: 41 | # 其他错误,记录,退出 42 | logging.exception(e) 43 | break 44 | 45 | # sleep 5秒后重新连接 46 | print(__name__, "re-connecting PostgreSQL in 5 seconds") 47 | time.sleep(5) 48 | 49 | # 提交数据 50 | def commit(self): 51 | try: 52 | self.conn.commit() 53 | except psycopg2.OperationalError as e: 54 | logging.exception(e) 55 | except Exception as e: 56 | logging.exception(e) 57 | 58 | # 关闭数据库 59 | def close(self): 60 | try: 61 | print(__name__, "close") 62 | 63 | self.conn.close() 64 | except psycopg2.OperationalError as e: 65 | logging.exception(e) 66 | except Exception as e: 67 | logging.exception(e) 68 | 69 | # 执行sql但不查询 70 | def execute_non_query(self, sql, values=()): 71 | result = False 72 | while True: 73 | try: 74 | # console(__name__, "execute_non_query") 75 | 76 | cursor = self.conn.cursor() 77 | cursor.execute(sql, values) 78 | # self.commit() 79 | result = True 80 | break 81 | except psycopg2.OperationalError as e: 82 | logging.exception(e) 83 | str_exception = str(e) 84 | if is_timeout_or_closed_error(str_exception): 85 | # 重新连接数据库 86 | self.re_connect() 87 | else: 88 | break 89 | except Exception as e: 90 | # 其他错误,记录,退出 91 | logging.exception(e) 92 | str_exception = str(e) 93 | if is_timeout_or_closed_error(str_exception): 94 | # console(__name__, "re-connecting PostgreSQL immediately") 95 | # 重新连接数据库 96 | self.re_connect() 97 | else: 98 | break 99 | # sleep 5秒后重新连接 100 | print(__name__, "re-connecting PostgreSQL in 5 seconds") 101 | time.sleep(5) 102 | 103 | return result 104 | 105 | # 查询,返回所有结果 106 | def fetchall(self, sql, values=()): 107 | results = None 108 | while True: 109 | try: 110 | # console(__name__, "fetchall") 111 | 112 | cursor = self.conn.cursor() 113 | cursor.execute(sql, values) 114 | results = cursor.fetchall() 115 | break 116 | except psycopg2.OperationalError as e: 117 | logging.exception(e) 118 | str_exception = str(e) 119 | if is_timeout_or_closed_error(str_exception): 120 | # 重新连接数据库 121 | self.re_connect() 122 | else: 123 | break 124 | except Exception as e: 125 | # 其他错误,记录忽略 126 | logging.exception(e) 127 | str_exception = str(e) 128 | if is_timeout_or_closed_error(str_exception): 129 | # 重新连接数据库 130 | self.re_connect() 131 | else: 132 | break 133 | return results 134 | 135 | # 查询,返回一条结果 136 | def fetchone(self, sql, values=()): 137 | result = None 138 | while True: 139 | try: 140 | # console(__name__, "fetchone") 141 | 142 | cursor = self.conn.cursor() 143 | cursor.execute(sql, values) 144 | result = cursor.fetchone() 145 | break 146 | except psycopg2.OperationalError as e: 147 | logging.exception(e) 148 | str_exception = str(e) 149 | if is_timeout_or_closed_error(str_exception): 150 | # 重新连接数据库 151 | self.re_connect() 152 | else: 153 | break 154 | except Exception as e: 155 | # 其他错误,记录,退出 156 | logging.exception(e) 157 | str_exception = str(e) 158 | if is_timeout_or_closed_error(str_exception): 159 | # 重新连接数据库 160 | self.re_connect() 161 | else: 162 | break 163 | return result 164 | 165 | # 查询,返回一条结果 166 | def table_exists(self, table_name): 167 | result = None 168 | try: 169 | cursor = self.conn.cursor() 170 | cursor.execute(f"select count(*) from pg_class where relname = \'{table_name}\';") 171 | value = cursor.fetchone()[0] 172 | if value == 0: 173 | result = None 174 | else: 175 | result = value 176 | except psycopg2.OperationalError as e: 177 | logging.exception(e) 178 | pass 179 | except Exception as e: 180 | logging.exception(e) 181 | finally: 182 | return result 183 | pass 184 | # 185 | # # 删除一个表 186 | # def drop_table(self, table_name): 187 | # value = False 188 | # try: 189 | # cursor = self.conn.cursor() 190 | # cursor.execute("DROP TABLE IF EXISTS '{}'".format(table_name)) 191 | # value = True 192 | # except psycopg2.OperationalError as e: 193 | # logging.exception(e) 194 | # pass 195 | # except Exception as e: 196 | # logging.exception(e) 197 | # finally: 198 | # return value 199 | -------------------------------------------------------------------------------- /utils/date_time.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | 4 | 5 | def time_point(time_format='%Y%m%d_%H%M%S'): 6 | return time.strftime(time_format, time.localtime()) 7 | 8 | 9 | def get_datetime_from_date_str(str_date): 10 | """ 11 | 由日期字符串获得datetime对象 12 | :param str_date: 日期字符串,格式 2021-10-20 13 | :return: datetime对象 14 | """ 15 | year_temp = str_date.split('-')[0] 16 | month_temp = str_date.split('-')[1] 17 | day_temp = str_date.split('-')[2] 18 | return datetime.date(int(year_temp), int(month_temp), int(day_temp)) 19 | 20 | 21 | # def get_next_work_day(datetime_date, next_flag=1): 22 | # """ 23 | # 获取下一个工作日 24 | # :param datetime_date: 日期类型 25 | # :param next_flag: -1 前一天,+1 后一天 26 | # :return: datetime.date 27 | # """ 28 | # if next_flag > 0: 29 | # next_date = datetime_date + datetime.timedelta( 30 | # days=7 - datetime_date.weekday() if datetime_date.weekday() > 3 else 1) 31 | # else: 32 | # # 周二至周六 33 | # if 0 < datetime_date.weekday() < 6: 34 | # next_date = datetime_date - datetime.timedelta(days=1) 35 | # pass 36 | # elif 0 == datetime_date.weekday(): 37 | # # 周一 38 | # next_date = datetime_date - datetime.timedelta(days=3) 39 | # else: 40 | # # 6 == datetime_date.weekday(): 41 | # # 周日 42 | # next_date = datetime_date - datetime.timedelta(days=2) 43 | # pass 44 | # pass 45 | # return next_date 46 | 47 | def get_next_work_day(datetime_date, next_flag=1): 48 | """ 49 | 获取下一个工作日 50 | :param datetime_date: 日期类型 51 | :param next_flag: -1 前一天,+1 后一天 52 | :return: datetime.date 53 | """ 54 | for loop in range(next_flag.__abs__()): 55 | if next_flag > 0: 56 | datetime_date = datetime_date + datetime.timedelta( 57 | days=7 - datetime_date.weekday() if datetime_date.weekday() > 3 else 1) 58 | else: 59 | # 周二至周六 60 | if 0 < datetime_date.weekday() < 6: 61 | datetime_date = datetime_date - datetime.timedelta(days=1) 62 | pass 63 | elif 0 == datetime_date.weekday(): 64 | # 周一 65 | datetime_date = datetime_date - datetime.timedelta(days=3) 66 | else: 67 | # 6 == datetime_date.weekday(): 68 | # 周日 69 | datetime_date = datetime_date - datetime.timedelta(days=2) 70 | pass 71 | pass 72 | return datetime_date 73 | 74 | 75 | def get_next_day(datetime_date, next_flag=1): 76 | """ 77 | 获取下一个日期 78 | :param datetime_date: 日期类型 79 | :param next_flag: -2 前2天,+1 后1天 80 | :return: datetime.date 81 | """ 82 | if next_flag > 0: 83 | next_date = datetime_date + datetime.timedelta(days=abs(next_flag)) 84 | else: 85 | next_date = datetime_date - datetime.timedelta(days=abs(next_flag)) 86 | return next_date 87 | 88 | 89 | def get_today_date(): 90 | # 获取今天日期 91 | time_format = '%Y-%m-%d' 92 | return time.strftime(time_format, time.localtime()) 93 | 94 | 95 | def is_greater(date1, date2): 96 | """ 97 | 日期1是否大于日期2 98 | :param date1: 日期1字符串 2001-03-01 99 | :param date2: 日期2字符串 2001-01-01 100 | :return: True/Flase 101 | """ 102 | temp1 = time.strptime(date1, '%Y-%m-%d') 103 | temp2 = time.strptime(date2, '%Y-%m-%d') 104 | 105 | if temp1 > temp2: 106 | return True 107 | else: 108 | return False 109 | pass 110 | 111 | 112 | def get_begin_vali_date_list(end_vali_date): 113 | """ 114 | 获取7个日期列表 115 | :return: list() 116 | """ 117 | list_result = list() 118 | # for work_days in [20, 30, 40, 50, 60, 72, 90, 100, 150, 200, 300, 500, 518, 1000, 1200, 1268]: 119 | for work_days in [30, 40, 50, 60, 72, 90, 100, 150, 200, 300, 500, 518, 1000, 1200]: 120 | begin_vali_date = get_next_work_day(end_vali_date, next_flag=-work_days) 121 | list_result.append((work_days, begin_vali_date)) 122 | 123 | list_result.reverse() 124 | 125 | return list_result 126 | 127 | 128 | def get_end_vali_date_list(begin_vali_date): 129 | """ 130 | 获取7个日期列表 131 | :return: list() 132 | """ 133 | 134 | list_result = list() 135 | 136 | # # 20周期 137 | # end_vali_date = get_next_day(begin_vali_date, next_flag=28) 138 | # list_result.append((20, end_vali_date)) 139 | # 140 | # # 30周期 141 | # end_vali_date = get_next_day(begin_vali_date, next_flag=42) 142 | # list_result.append((30, end_vali_date)) 143 | # 144 | # # 40周期 145 | # end_vali_date = get_next_day(begin_vali_date, next_flag=56) 146 | # list_result.append((40, end_vali_date)) 147 | # 148 | # # 50周期 149 | # end_vali_date = get_next_day(begin_vali_date, next_flag=77) 150 | # list_result.append((50, end_vali_date)) 151 | # 152 | # # 60周期 153 | # end_vali_date = get_next_day(begin_vali_date, next_flag=91) 154 | # list_result.append((60, end_vali_date)) 155 | # 156 | # # 72周期 157 | # end_vali_date = get_next_day(begin_vali_date, next_flag=108) 158 | # list_result.append((72, end_vali_date)) 159 | # 160 | # # 90周期 161 | # end_vali_date = get_next_day(begin_vali_date, next_flag=134) 162 | # list_result.append((90, end_vali_date)) 163 | 164 | # 50周期 165 | work_days = 50 166 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 167 | list_result.append((work_days, end_vali_date)) 168 | 169 | # 100周期 170 | work_days = 100 171 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 172 | list_result.append((work_days, end_vali_date)) 173 | 174 | # 150周期 175 | work_days = 150 176 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 177 | list_result.append((work_days, end_vali_date)) 178 | 179 | # 200周期 180 | work_days = 200 181 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 182 | list_result.append((work_days, end_vali_date)) 183 | 184 | # 300周期 185 | work_days = 300 186 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 187 | list_result.append((work_days, end_vali_date)) 188 | 189 | # 500周期 190 | work_days = 500 191 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 192 | list_result.append((work_days, end_vali_date)) 193 | 194 | # 1000周期 195 | work_days = 1000 196 | end_vali_date = get_next_work_day(begin_vali_date, next_flag=work_days) 197 | list_result.append((work_days, end_vali_date)) 198 | 199 | return list_result 200 | 201 | 202 | def get_week_day(string_time_point): 203 | week_day_dict = { 204 | 0: '周一', 205 | 1: '周二', 206 | 2: '周三', 207 | 3: '周四', 208 | 4: '周五', 209 | 5: '周六', 210 | 6: '周日', 211 | } 212 | 213 | day_of_week = datetime.datetime.fromtimestamp( 214 | time.mktime(time.strptime(string_time_point, "%Y-%m-%d"))).weekday() 215 | 216 | return week_day_dict[day_of_week] 217 | -------------------------------------------------------------------------------- /reports/sqlite_to_html.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | import time 5 | 6 | import config 7 | from utils.date_time import get_week_day 8 | from utils.sqlite import SQLite 9 | 10 | if __name__ == '__main__': 11 | # 获取要输出到html的tic 12 | config.SINGLE_A_STOCK_CODE = ['sh.600036', ] 13 | 14 | # 交易详情 15 | text_trade_detail = '' 16 | 17 | # 从 index.html.template 文件中 读取 网页模板 18 | with open('./template/index.html.template', 'r') as index_page_template: 19 | text_index_page_template = index_page_template.read() 20 | 21 | all_text_card = '' 22 | 23 | # 从 card_predict_result.template 文件中 读取 结果卡片 模板 24 | with open('./template/card_predict_result.template', 'r') as result_card_template: 25 | text_card = result_card_template.read() 26 | 27 | db_path = '../stock_db/stock.db' 28 | 29 | # 连接数据库 30 | sqlite = SQLite(db_path) 31 | 32 | for tic in config.SINGLE_A_STOCK_CODE: 33 | 34 | table_name = tic + '_report' 35 | 36 | # 复制一个结果卡片模板,将以下信息 填充到卡片中 37 | copy_text_card = text_card 38 | 39 | # 获取数据库中最大hold,最大action,用于计算百分比 40 | sql_cmd = f'SELECT abs("hold") FROM "{table_name}" ORDER BY abs("hold") DESC LIMIT 1' 41 | max_hold = sqlite.fetchone(sql_cmd) 42 | max_hold = max_hold[0] 43 | 44 | # max_action 45 | sql_cmd = f'SELECT abs("action") FROM "{table_name}" ORDER BY abs("action") DESC LIMIT 1' 46 | max_action = sqlite.fetchone(sql_cmd) 47 | max_action = max_action[0] 48 | 49 | # 获取数据库中最大日期 50 | sql_cmd = f'SELECT "date" FROM "{table_name}" ORDER BY "date" DESC LIMIT 1' 51 | max_date = sqlite.fetchone(sql_cmd) 52 | max_date = str(max_date[0]) 53 | 54 | # 用此最大日期查询出一批数据 55 | sql_cmd = f'SELECT "id", "agent", "vali_period_value", "pred_period_name", "action", "hold", "day", "episode_return", "max_return", "trade_detail" ' \ 56 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' ORDER BY episode_return DESC' 57 | 58 | list_result = sqlite.fetchall(sql_cmd) 59 | 60 | text_table_tr_td = '' 61 | 62 | for item_result in list_result: 63 | # 替换字符串内容 64 | copy_text_card = copy_text_card.replace('<%tic%>', tic) 65 | copy_text_card = copy_text_card.replace('<%tic_no_dot%>', tic.replace('.', '')) 66 | 67 | id1, agent1, vali_period_value1, pred_period_name1, action1, hold1, day1, episode_return1, max_return1, trade_detail1 = item_result 68 | 69 | # <%day%> 70 | copy_text_card = copy_text_card.replace('<%day%>', day1) 71 | 72 | # 改为百分比 73 | action1 = float(action1) 74 | max_action = float(max_action) 75 | hold1 = float(hold1) 76 | max_hold = float(max_hold) 77 | episode_return1 = float(episode_return1) 78 | max_return1 = float(max_return1) 79 | 80 | action1 = round(action1 * 100 / max_action, 0) 81 | hold1 = round(hold1 * 100 / max_hold, 0) 82 | # agent1 83 | agent1 = agent1[5:] 84 | 85 | # 回报 86 | episode_return1 = round((episode_return1 - 1) * 100, 2) 87 | max_return1 = round((max_return1 - 1) * 100, 2) 88 | 89 | text_table_tr_td += f'' \ 90 | f'{episode_return1}% / {max_return1}%' \ 91 | f'{action1}%' \ 92 | f'{hold1}%' \ 93 | f'{agent1}' \ 94 | f'{vali_period_value1}天' \ 95 | f'第{day1}/{pred_period_name1}天' \ 96 | f'' 97 | 98 | # 交易详情,trade_detail1,保存为独立文件 99 | 100 | text_trade_detail += f'\r\n{"-" * 20} {agent1} {vali_period_value1}天 {"-" * 20}\r\n' 101 | 102 | text_trade_detail += f'{episode_return1}% / {max_return1}% ' \ 103 | f' {action1}% ' \ 104 | f' {hold1}% ' \ 105 | f' {agent1} ' \ 106 | f' {vali_period_value1}天 ' \ 107 | f' 第{day1}/{pred_period_name1}天\r\n' 108 | 109 | text_trade_detail += '\r\n交易详情\r\n\r\n' 110 | 111 | text_trade_detail += trade_detail1 112 | 113 | pass 114 | pass 115 | 116 | # 日期 117 | date1 = max_date + ' ' + get_week_day(max_date) 118 | copy_text_card = copy_text_card.replace('<%date%>', date1) 119 | 120 | # 按 hold 分组,选出数量最多的 hold 121 | sql_cmd = f'SELECT "hold", COUNT(id) as count1 ' \ 122 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' GROUP BY "hold"' \ 123 | f' ORDER BY count1 DESC, abs("hold") DESC LIMIT 1' 124 | 125 | most_hold = sqlite.fetchone(sql_cmd)[0] 126 | most_hold = float(most_hold) 127 | most_hold = round(most_hold * 100 / max_hold, 0) 128 | copy_text_card = copy_text_card.replace('<%most_hold%>', str(most_hold)) 129 | 130 | # 按 action 分组,取数量最多的 action 131 | sql_cmd = f'SELECT "action", COUNT("id") as count1 ' \ 132 | f'FROM "{table_name}" WHERE "date" = \'{max_date}\' GROUP BY "action"' \ 133 | f' ORDER BY count1 DESC, abs("action") DESC LIMIT 1' 134 | 135 | most_action = sqlite.fetchone(sql_cmd)[0] 136 | most_action = float(most_action) 137 | most_action = round(most_action * 100 / max_action, 0) 138 | copy_text_card = copy_text_card.replace('<%most_action%>', str(most_action)) 139 | 140 | # 表格 141 | all_text_card += copy_text_card.replace('<%predict_result_table_tr_td%>', text_table_tr_td) 142 | all_text_card += '\r\n' 143 | 144 | if text_trade_detail is not '': 145 | # 写入交易详情文件 146 | with open(f'./{tic}.txt', 'w') as file_detail: 147 | file_detail.write(text_trade_detail) 148 | pass 149 | pass 150 | pass 151 | pass 152 | 153 | sqlite.close() 154 | pass 155 | 156 | # 将多个 卡片模板 替换到 网页模板 157 | text_index_page_template = text_index_page_template.replace('<%page_content%>', all_text_card) 158 | 159 | current_time_point = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 160 | text_index_page_template = text_index_page_template.replace('<%page_time_point%>', current_time_point) 161 | 162 | text_index_page_template = text_index_page_template.replace('<%page_title%>', 'A股预测') 163 | 164 | INDEX_HTML_PAGE_PATH = './index.html' 165 | 166 | # 写入网页文件 167 | with open(INDEX_HTML_PAGE_PATH, 'w') as file_index: 168 | file_index.write(text_index_page_template) 169 | pass 170 | pass 171 | 172 | pass 173 | -------------------------------------------------------------------------------- /train_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from datetime import datetime 5 | 6 | import config 7 | import numpy as np 8 | from stock_data import StockData 9 | 10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \ 11 | AgentDoubleDQN 12 | 13 | from run_single import Arguments, train_and_evaluate_mp, train_and_evaluate 14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \ 15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite 16 | from utils import date_time 17 | 18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, \ 19 | get_today_date 20 | # from env_train_single import StockTradingEnv 21 | from env_single import StockTradingEnvSingle 22 | 23 | if __name__ == '__main__': 24 | 25 | # 初始化超参表 26 | init_model_hyper_parameters_table_sqlite() 27 | 28 | # 开始训练的日期,在程序启动之后,不要改变 29 | config.SINGLE_A_STOCK_CODE = ['sh.600036', ] 30 | 31 | # 初始现金 32 | initial_capital = 150000 33 | 34 | # 单次 购买/卖出 最大股数 35 | max_stock = 3000 36 | 37 | initial_stocks_train = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32) 38 | initial_stocks_vali = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32) 39 | 40 | # 默认持有0-3000股 41 | initial_stocks_train[0] = 3000.0 42 | initial_stocks_vali[0] = 3000.0 43 | 44 | if_on_policy = False 45 | # if_use_gae = True 46 | 47 | config.IF_ACTUAL_PREDICT = False 48 | 49 | config.START_DATE = "2003-05-01" 50 | config.START_EVAL_DATE = "" 51 | 52 | # 整体结束日期,今天的日期,减去60工作日 53 | predict_work_days = 60 54 | 55 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -predict_work_days)) 56 | # config.END_DATE = '2021-07-21' 57 | 58 | # 更新股票数据,不复权,表名 fe_fillzero 59 | StockData.update_stock_data_to_sqlite(list_stock_code=config.SINGLE_A_STOCK_CODE, adjustflag='3', table_name='fe_fillzero') 60 | 61 | # 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(), 62 | # AgentDoubleDQN 单进程好用? 63 | # 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC() 64 | 65 | loop_index = 0 66 | 67 | # 循环 68 | while True: 69 | 70 | # 清空训练历史记录表 71 | clear_train_history_table_sqlite() 72 | 73 | # 从 model_hyper_parameters 表中,找到 training_times 最小的记录 74 | # 获取超参 75 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 76 | eval_reward_scale, training_times, time_point = query_model_hyper_parameters_sqlite() 77 | 78 | if if_on_policy == 'True': 79 | if_on_policy = True 80 | else: 81 | if_on_policy = False 82 | pass 83 | 84 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id) 85 | 86 | # 获得Agent参数 87 | agent_class = None 88 | train_reward_scaling = 2 ** train_reward_scale 89 | eval_reward_scaling = 2 ** eval_reward_scale 90 | 91 | # 模型名称 92 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0] 93 | 94 | if config.AGENT_NAME == 'AgentPPO': 95 | agent_class = AgentPPO() 96 | elif config.AGENT_NAME == 'AgentSAC': 97 | agent_class = AgentSAC() 98 | elif config.AGENT_NAME == 'AgentTD3': 99 | agent_class = AgentTD3() 100 | elif config.AGENT_NAME == 'AgentDDPG': 101 | agent_class = AgentDDPG() 102 | elif config.AGENT_NAME == 'AgentModSAC': 103 | agent_class = AgentModSAC() 104 | elif config.AGENT_NAME == 'AgentDuelingDQN': 105 | agent_class = AgentDuelingDQN() 106 | elif config.AGENT_NAME == 'AgentSharedSAC': 107 | agent_class = AgentSharedSAC() 108 | elif config.AGENT_NAME == 'AgentDoubleDQN': 109 | agent_class = AgentDoubleDQN() 110 | pass 111 | 112 | # 预测周期 113 | work_days = int(str(hyper_parameters_model_name).split('_')[1]) 114 | 115 | # 预测的截止日期 116 | end_vali_date = get_datetime_from_date_str(config.END_DATE) 117 | 118 | # 开始预测日期 119 | begin_date = date_time.get_next_work_day(end_vali_date, next_flag=-work_days) 120 | 121 | # 更新工作日标记,用于 run_single.py 加载训练过的 weights 文件 122 | config.VALI_DAYS_FLAG = str(work_days) 123 | 124 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \ 125 | f'/single_{config.VALI_DAYS_FLAG}' 126 | 127 | if not os.path.exists(model_folder_path): 128 | os.makedirs(model_folder_path) 129 | pass 130 | 131 | # 开始预测的日期 132 | config.START_EVAL_DATE = str(begin_date) 133 | 134 | print('\r\n') 135 | print('-' * 40) 136 | print('config.AGENT_NAME', config.AGENT_NAME) 137 | print('# 训练-预测周期', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE) 138 | print('# work_days', work_days) 139 | print('# model_folder_path', model_folder_path) 140 | print('# initial_capital', initial_capital) 141 | print('# max_stock', max_stock) 142 | 143 | args = Arguments(if_on_policy=if_on_policy) 144 | args.agent = agent_class 145 | # args.agent.if_use_gae = if_use_gae 146 | args.agent.lambda_entropy = 0.04 147 | args.gpu_id = 0 148 | 149 | tech_indicator_list = [ 150 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 151 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST 152 | 153 | gamma = 0.99 154 | 155 | print('# initial_stocks_train', initial_stocks_train) 156 | print('# initial_stocks_vali', initial_stocks_vali) 157 | 158 | buy_cost_pct = 0.003 159 | sell_cost_pct = 0.003 160 | start_date = config.START_DATE 161 | start_eval_date = config.START_EVAL_DATE 162 | end_eval_date = config.END_DATE 163 | 164 | # train 165 | args.env = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock, 166 | initial_capital=initial_capital, 167 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date, 168 | end_date=start_eval_date, env_eval_date=end_eval_date, 169 | ticker_list=config.SINGLE_A_STOCK_CODE, 170 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train, 171 | if_eval=False) 172 | 173 | # eval 174 | args.env_eval = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock, 175 | initial_capital=initial_capital, 176 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 177 | start_date=start_date, 178 | end_date=start_eval_date, env_eval_date=end_eval_date, 179 | ticker_list=config.SINGLE_A_STOCK_CODE, 180 | tech_indicator_list=tech_indicator_list, 181 | initial_stocks=initial_stocks_vali, 182 | if_eval=True) 183 | 184 | args.env.target_return = 10 185 | args.env_eval.target_return = 10 186 | 187 | # 奖励 比例 188 | args.env.reward_scaling = train_reward_scaling 189 | args.env_eval.reward_scaling = eval_reward_scaling 190 | 191 | print('train reward_scaling', args.env.reward_scaling) 192 | print('eval reward_scaling', args.env_eval.reward_scaling) 193 | 194 | # Hyperparameters 195 | args.gamma = gamma 196 | # args.gamma = 0.99 197 | 198 | # reward_scaling 在 args.env里调整了,这里不动 199 | args.reward_scale = 2 ** 0 200 | 201 | # args.break_step = int(break_step / 30) 202 | args.break_step = break_step 203 | 204 | print('break_step', args.break_step) 205 | 206 | args.net_dim = 2 ** 9 207 | args.max_step = args.env.max_step 208 | 209 | # args.max_memo = args.max_step * 4 210 | args.max_memo = (args.max_step - 1) * 8 211 | 212 | args.batch_size = 2 ** 12 213 | # args.batch_size = 2305 214 | print('batch_size', args.batch_size) 215 | 216 | # ---- 217 | # args.repeat_times = 2 ** 3 218 | args.repeat_times = 2 ** 4 219 | # ---- 220 | 221 | args.eval_gap = 2 ** 4 222 | args.eval_times1 = 2 ** 3 223 | args.eval_times2 = 2 ** 5 224 | 225 | args.if_allow_break = False 226 | 227 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster) 228 | 229 | # train_and_evaluate(args) 230 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward. 231 | 232 | # 保存训练后的模型 233 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth') 234 | 235 | # 多留一份 236 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', 237 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth') 238 | 239 | # 保存训练曲线图 240 | # plot_learning_curve.jpg 241 | timepoint_temp = date_time.time_point() 242 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg' 243 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg', 244 | plot_learning_curve_file_path) 245 | 246 | # 训练结束后,model_hyper_parameters 表 中的 训练的次数 +1,训练的时间点 更新。 247 | 248 | # 判断 train_history 表,是否有记录,如果有,则整除 256 + 128。将此值更新到 model_hyper_parameters 表的 超参,减去相应的值。 249 | 250 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id, 251 | origin_train_reward_scale=train_reward_scale, 252 | origin_eval_reward_scale=eval_reward_scale, 253 | origin_training_times=training_times) 254 | 255 | print('>', config.AGENT_NAME, break_step, 'steps') 256 | 257 | # 循环次数 258 | loop_index += 1 259 | 260 | # # 5个模型都摸一遍,退出 261 | # if loop_index == 5: 262 | # break 263 | 264 | # print('>', 'while 循环次数', loop_index, '\r\n') 265 | 266 | print('sleep 10 秒\r\n') 267 | time.sleep(10) 268 | 269 | # TODO 训练一次退出 270 | break 271 | 272 | pass 273 | -------------------------------------------------------------------------------- /train_batch_bak.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from datetime import datetime 5 | 6 | import config 7 | import numpy as np 8 | from stock_data import StockData 9 | 10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \ 11 | AgentDoubleDQN 12 | 13 | from run_batch import Arguments, train_and_evaluate_mp, train_and_evaluate 14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \ 15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite 16 | from utils import date_time 17 | 18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, \ 19 | get_today_date 20 | from env_batch import StockTradingEnvBatch 21 | 22 | if __name__ == '__main__': 23 | # 开始预测的时间 24 | time_begin = datetime.now() 25 | 26 | # 初始化超参表 27 | init_model_hyper_parameters_table_sqlite() 28 | 29 | fe_table_name = 'fe_fillzero_train' 30 | 31 | # 2003年组,用 sz.000028 作为代号 32 | # 股票的顺序,不要改变 33 | # config.BATCH_A_STOCK_CODE = ['sz.000028', 'sh.600585', 'sz.000538', 'sh.600036'] 34 | config.START_DATE = "2004-05-01" 35 | 36 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275') 37 | 38 | # 初始现金,每只股票15万元 39 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE) 40 | 41 | # 单次 购买/卖出 最大股数 42 | # TODO 根据每只最近收盘价,得到多只股票平均价,计算 单次购买/卖出 最大股数 43 | max_stock = 3000 44 | 45 | initial_stocks_train = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 46 | initial_stocks_vali = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 47 | 48 | # 默认持有0-3000股 49 | initial_stocks_train = max_stock * initial_stocks_train 50 | initial_stocks_vali = max_stock * initial_stocks_vali 51 | 52 | # if_on_policy = False 53 | 54 | config.IF_ACTUAL_PREDICT = False 55 | 56 | config.START_EVAL_DATE = "" 57 | 58 | # 整体结束日期,今天的日期,减去60工作日 59 | predict_work_days = 60 60 | 61 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -predict_work_days)) 62 | # config.END_DATE = '2021-07-21' 63 | 64 | # TODO 因为数据量大,在 stock_data.py 中更新 65 | # # 更新股票数据,不复权 66 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3', 67 | # table_name=fe_table_name, if_incremental_update=False) 68 | 69 | # 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(), 70 | # AgentDoubleDQN 单进程好用? 71 | # 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC() 72 | 73 | loop_index = 0 74 | 75 | # 循环 76 | while True: 77 | 78 | # 清空训练历史记录表 79 | clear_train_history_table_sqlite() 80 | 81 | # 从 model_hyper_parameters 表中,找到 training_times 最小的记录 82 | # 获取超参 83 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 84 | eval_reward_scale, training_times, time_point = query_model_hyper_parameters_sqlite() 85 | 86 | if if_on_policy == 'True': 87 | if_on_policy = True 88 | else: 89 | if_on_policy = False 90 | pass 91 | 92 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id) 93 | 94 | # 获得Agent参数 95 | agent_class = None 96 | train_reward_scaling = 2 ** train_reward_scale 97 | eval_reward_scaling = 2 ** eval_reward_scale 98 | 99 | # 模型名称 100 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0] 101 | 102 | if config.AGENT_NAME == 'AgentPPO': 103 | agent_class = AgentPPO() 104 | elif config.AGENT_NAME == 'AgentSAC': 105 | agent_class = AgentSAC() 106 | elif config.AGENT_NAME == 'AgentTD3': 107 | agent_class = AgentTD3() 108 | elif config.AGENT_NAME == 'AgentDDPG': 109 | agent_class = AgentDDPG() 110 | elif config.AGENT_NAME == 'AgentModSAC': 111 | agent_class = AgentModSAC() 112 | elif config.AGENT_NAME == 'AgentDuelingDQN': 113 | agent_class = AgentDuelingDQN() 114 | elif config.AGENT_NAME == 'AgentSharedSAC': 115 | agent_class = AgentSharedSAC() 116 | elif config.AGENT_NAME == 'AgentDoubleDQN': 117 | agent_class = AgentDoubleDQN() 118 | pass 119 | 120 | # 预测周期 121 | work_days = int(str(hyper_parameters_model_name).split('_')[1]) 122 | 123 | # 预测的截止日期 124 | end_vali_date = get_datetime_from_date_str(config.END_DATE) 125 | 126 | # 开始预测日期 127 | begin_date = date_time.get_next_work_day(end_vali_date, next_flag=-work_days) 128 | 129 | # 更新工作日标记,用于 run_single.py 加载训练过的 weights 文件 130 | config.VALI_DAYS_FLAG = str(work_days) 131 | 132 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 133 | f'batch_{config.VALI_DAYS_FLAG}' 134 | 135 | if not os.path.exists(model_folder_path): 136 | os.makedirs(model_folder_path) 137 | pass 138 | 139 | # 开始预测的日期 140 | config.START_EVAL_DATE = str(begin_date) 141 | 142 | print('\r\n') 143 | print('-' * 40) 144 | print('config.AGENT_NAME', config.AGENT_NAME) 145 | print('# 训练-预测周期', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE) 146 | print('# work_days', work_days) 147 | print('# model_folder_path', model_folder_path) 148 | print('# initial_capital', initial_capital) 149 | print('# max_stock', max_stock) 150 | 151 | args = Arguments(if_on_policy=if_on_policy) 152 | args.agent = agent_class 153 | # args.agent.if_use_gae = if_use_gae 154 | args.agent.lambda_entropy = 0.04 155 | args.gpu_id = 0 156 | 157 | tech_indicator_list = [ 158 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 159 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST 160 | 161 | gamma = 0.99 162 | 163 | # print('# initial_stocks_train', initial_stocks_train) 164 | # print('# initial_stocks_vali', initial_stocks_vali) 165 | 166 | buy_cost_pct = 0.003 167 | sell_cost_pct = 0.003 168 | start_date = config.START_DATE 169 | start_eval_date = config.START_EVAL_DATE 170 | end_eval_date = config.END_DATE 171 | 172 | # train 173 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 174 | initial_capital=initial_capital, 175 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date, 176 | end_date=start_eval_date, env_eval_date=end_eval_date, 177 | ticker_list=config.BATCH_A_STOCK_CODE, 178 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train, 179 | if_eval=False, fe_table_name=fe_table_name) 180 | 181 | # eval 182 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 183 | initial_capital=initial_capital, 184 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 185 | start_date=start_date, 186 | end_date=start_eval_date, env_eval_date=end_eval_date, 187 | ticker_list=config.BATCH_A_STOCK_CODE, 188 | tech_indicator_list=tech_indicator_list, 189 | initial_stocks=initial_stocks_vali, 190 | if_eval=True, fe_table_name=fe_table_name) 191 | 192 | args.env.target_return = 10 193 | args.env_eval.target_return = 10 194 | 195 | # 奖励 比例 196 | args.env.reward_scale = train_reward_scaling 197 | args.env_eval.reward_scale = eval_reward_scaling 198 | 199 | print('train reward_scaling', args.env.reward_scale) 200 | print('eval reward_scaling', args.env_eval.reward_scale) 201 | 202 | # Hyperparameters 203 | args.gamma = gamma 204 | # args.gamma = 0.99 205 | 206 | # reward_scaling 在 args.env里调整了,这里不动 207 | args.reward_scale = 2 ** 0 208 | 209 | # args.break_step = int(break_step / 30) 210 | args.break_step = break_step 211 | 212 | print('break_step', args.break_step) 213 | 214 | args.net_dim = 2 ** 9 215 | args.max_step = args.env.max_step 216 | 217 | # args.max_memo = args.max_step * 4 218 | args.max_memo = (args.max_step - 1) * 8 219 | 220 | args.batch_size = 2 ** 12 221 | # args.batch_size = 2305 222 | print('batch_size', args.batch_size) 223 | 224 | # ---- 225 | # args.repeat_times = 2 ** 3 226 | args.repeat_times = 2 ** 4 227 | # ---- 228 | 229 | args.eval_gap = 2 ** 4 230 | args.eval_times1 = 2 ** 3 231 | args.eval_times2 = 2 ** 5 232 | 233 | args.if_allow_break = False 234 | 235 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster) 236 | 237 | # train_and_evaluate(args) 238 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward. 239 | 240 | # 保存训练后的模型 241 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth') 242 | 243 | # 多留一份 244 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', 245 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth') 246 | 247 | # 保存训练曲线图 248 | # plot_learning_curve.jpg 249 | timepoint_temp = date_time.time_point() 250 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg' 251 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg', 252 | plot_learning_curve_file_path) 253 | 254 | # 训练结束后,model_hyper_parameters 表 中的 训练的次数 +1,训练的时间点 更新。 255 | 256 | # 判断 train_history 表,是否有记录,如果有,则整除 256 + 128。将此值更新到 model_hyper_parameters 表的 超参,减去相应的值。 257 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id, 258 | origin_train_reward_scale=train_reward_scale, 259 | origin_eval_reward_scale=eval_reward_scale, 260 | origin_training_times=training_times) 261 | 262 | print('>', config.AGENT_NAME, break_step, 'steps') 263 | 264 | # 结束预测的时间 265 | time_end = datetime.now() 266 | duration = (time_end - time_begin).total_seconds() 267 | print('检测耗时', duration, '秒') 268 | 269 | # 循环次数 270 | loop_index += 1 271 | print('>', 'while 循环次数', loop_index, '\r\n') 272 | 273 | print('sleep 10 秒\r\n') 274 | time.sleep(10) 275 | 276 | # TODO 训练一次退出 277 | # break 278 | 279 | pass 280 | -------------------------------------------------------------------------------- /train_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from datetime import datetime 5 | 6 | import config 7 | import numpy as np 8 | from stock_data import StockData 9 | 10 | from agent import AgentPPO, AgentSAC, AgentTD3, AgentDDPG, AgentModSAC, AgentDuelingDQN, AgentSharedSAC, \ 11 | AgentDoubleDQN 12 | 13 | from run_batch import Arguments, train_and_evaluate_mp, train_and_evaluate 14 | from train_helper import init_model_hyper_parameters_table_sqlite, query_model_hyper_parameters_sqlite, \ 15 | update_model_hyper_parameters_by_train_history, clear_train_history_table_sqlite 16 | from utils import date_time 17 | 18 | from utils.date_time import get_datetime_from_date_str, get_next_work_day, get_today_date 19 | from env_batch import StockTradingEnvBatch 20 | 21 | if __name__ == '__main__': 22 | # 开始预测的时间 23 | time_begin = datetime.now() 24 | 25 | # 初始化超参表 26 | init_model_hyper_parameters_table_sqlite() 27 | 28 | fe_table_name = 'fe_fillzero_train' 29 | 30 | # 股票的顺序,不要改变 31 | # config.BATCH_A_STOCK_CODE = ['sz.000028', 'sh.600585', 'sz.000538', 'sh.600036'] 32 | config.START_DATE = "2004-05-01" 33 | 34 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275') 35 | 36 | # 初始现金,每只股票15万元 37 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE) 38 | 39 | # 单次 购买/卖出 最大股数 40 | max_stock = 50000 41 | 42 | initial_stocks_train = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 43 | initial_stocks_vali = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 44 | 45 | # 默认持有0-3000股 46 | initial_stocks_train = max_stock * initial_stocks_train 47 | initial_stocks_vali = max_stock * initial_stocks_vali 48 | 49 | config.IF_ACTUAL_PREDICT = False 50 | 51 | config.START_EVAL_DATE = "" 52 | 53 | # TODO 因为数据量大,在 stock_data.py 中更新 54 | # # 更新股票数据,不复权 55 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3', 56 | # table_name=fe_table_name, if_incremental_update=False) 57 | 58 | # 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(), 59 | # AgentDoubleDQN 单进程好用? 60 | # 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC() 61 | 62 | loop_index = 0 63 | 64 | # 循环 65 | while True: 66 | 67 | # 清空训练历史记录表 68 | clear_train_history_table_sqlite() 69 | 70 | # 从 model_hyper_parameters 表中,找到 training_times 最小的记录 71 | # 获取超参 72 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 73 | eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, \ 74 | state_tech_scale = query_model_hyper_parameters_sqlite() 75 | 76 | if if_on_policy == 'True': 77 | if_on_policy = True 78 | else: 79 | if_on_policy = False 80 | pass 81 | 82 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id) 83 | 84 | # 获得Agent参数 85 | agent_class = None 86 | 87 | # 模型名称 88 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0] 89 | # 模型预测周期 90 | config.AGENT_WORK_DAY = int(str(hyper_parameters_model_name).split('_')[1]) 91 | 92 | if config.AGENT_NAME == 'AgentPPO': 93 | agent_class = AgentPPO() 94 | elif config.AGENT_NAME == 'AgentSAC': 95 | agent_class = AgentSAC() 96 | elif config.AGENT_NAME == 'AgentTD3': 97 | agent_class = AgentTD3() 98 | elif config.AGENT_NAME == 'AgentDDPG': 99 | agent_class = AgentDDPG() 100 | elif config.AGENT_NAME == 'AgentModSAC': 101 | agent_class = AgentModSAC() 102 | elif config.AGENT_NAME == 'AgentDuelingDQN': 103 | agent_class = AgentDuelingDQN() 104 | elif config.AGENT_NAME == 'AgentSharedSAC': 105 | agent_class = AgentSharedSAC() 106 | elif config.AGENT_NAME == 'AgentDoubleDQN': 107 | agent_class = AgentDoubleDQN() 108 | pass 109 | 110 | # 更新工作日标记,用于 run_single.py 加载训练过的 weights 文件 111 | config.VALI_DAYS_FLAG = str(config.AGENT_WORK_DAY) 112 | 113 | # TODO 整体结束日期,今天的日期,预留60个工作日,用于验证predict 114 | 115 | # config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), 116 | # -1 * config.AGENT_WORK_DAY)) 117 | 118 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), -60)) 119 | 120 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 121 | f'batch_{config.VALI_DAYS_FLAG}' 122 | 123 | if not os.path.exists(model_folder_path): 124 | os.makedirs(model_folder_path) 125 | pass 126 | 127 | # 开始预测的日期 128 | config.START_EVAL_DATE = str( 129 | get_next_work_day(get_datetime_from_date_str(get_today_date()), -2 * config.AGENT_WORK_DAY)) 130 | 131 | print('\r\n') 132 | print('-' * 40) 133 | print('config.AGENT_NAME', config.AGENT_NAME) 134 | print('# 训练-预测周期', config.START_DATE, '-', config.START_EVAL_DATE, '-', config.END_DATE) 135 | print('# work_days', config.AGENT_WORK_DAY) 136 | print('# model_folder_path', model_folder_path) 137 | print('# initial_capital', initial_capital) 138 | print('# max_stock', max_stock) 139 | 140 | args = Arguments(if_on_policy=if_on_policy) 141 | args.agent = agent_class 142 | # args.agent.if_use_gae = if_use_gae 143 | args.agent.lambda_entropy = 0.04 144 | args.gpu_id = 0 145 | 146 | tech_indicator_list = [ 147 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 148 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST 149 | 150 | gamma = 0.99 151 | 152 | buy_cost_pct = 0.003 153 | sell_cost_pct = 0.003 154 | start_date = config.START_DATE 155 | start_eval_date = config.START_EVAL_DATE 156 | end_eval_date = config.END_DATE 157 | 158 | # train 159 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 160 | initial_capital=initial_capital, 161 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, start_date=start_date, 162 | end_date=start_eval_date, env_eval_date=end_eval_date, 163 | ticker_list=config.BATCH_A_STOCK_CODE, 164 | tech_indicator_list=tech_indicator_list, initial_stocks=initial_stocks_train, 165 | if_eval=False, fe_table_name=fe_table_name) 166 | 167 | # eval 168 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 169 | initial_capital=initial_capital, 170 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 171 | start_date=start_date, 172 | end_date=start_eval_date, env_eval_date=end_eval_date, 173 | ticker_list=config.BATCH_A_STOCK_CODE, 174 | tech_indicator_list=tech_indicator_list, 175 | initial_stocks=initial_stocks_vali, 176 | if_eval=True, fe_table_name=fe_table_name) 177 | 178 | args.env.target_return = 100 179 | args.env_eval.target_return = 100 180 | 181 | # 奖励 比例 182 | args.env.reward_scale = train_reward_scale 183 | args.env_eval.reward_scale = eval_reward_scale 184 | 185 | args.env.state_amount_scale = state_amount_scale 186 | args.env.state_price_scale = state_price_scale 187 | args.env.state_stocks_scale = state_stocks_scale 188 | args.env.state_tech_scale = state_tech_scale 189 | 190 | args.env_eval.state_amount_scale = state_amount_scale 191 | args.env_eval.state_price_scale = state_price_scale 192 | args.env_eval.state_stocks_scale = state_stocks_scale 193 | args.env_eval.state_tech_scale = state_tech_scale 194 | 195 | print('train reward_scale', args.env.reward_scale) 196 | print('eval reward_scale', args.env_eval.reward_scale) 197 | 198 | print('state_amount_scale', state_amount_scale) 199 | print('state_price_scale', state_price_scale) 200 | print('state_stocks_scale', state_stocks_scale) 201 | print('state_tech_scale', state_tech_scale) 202 | 203 | # Hyperparameters 204 | args.gamma = gamma 205 | # args.gamma = 0.99 206 | 207 | # reward_scaling 在 args.env里调整了,这里不动 208 | # args.reward_scale = 2 ** 0 209 | args.reward_scale = 1 210 | 211 | # args.break_step = int(break_step / 30) 212 | args.break_step = break_step 213 | 214 | print('break_step', args.break_step) 215 | 216 | args.net_dim = 2 ** 9 217 | args.max_step = args.env.max_step 218 | 219 | # args.max_memo = args.max_step * 4 220 | args.max_memo = (args.max_step - 1) * 8 221 | 222 | args.batch_size = 2 ** 12 223 | # args.batch_size = 2305 224 | print('batch_size', args.batch_size) 225 | 226 | # ---- 227 | # args.repeat_times = 2 ** 3 228 | args.repeat_times = 2 ** 4 229 | # ---- 230 | 231 | args.eval_gap = 2 ** 4 232 | args.eval_times1 = 2 ** 3 233 | args.eval_times2 = 2 ** 5 234 | 235 | args.if_allow_break = False 236 | 237 | args.rollout_num = 2 # the number of rollout workers (larger is not always faster) 238 | 239 | # train_and_evaluate(args) 240 | train_and_evaluate_mp(args) # the training process will terminate once it reaches the target reward. 241 | 242 | # 保存训练后的模型 243 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', f'{model_folder_path}/actor.pth') 244 | 245 | # 多留一份 246 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/actor.pth', 247 | f'{model_folder_path}/{date_time.time_point(time_format="%Y%m%d_%H%M%S")}.pth') 248 | 249 | # 保存训练曲线图 250 | # plot_learning_curve.jpg 251 | timepoint_temp = date_time.time_point() 252 | plot_learning_curve_file_path = f'{model_folder_path}/{timepoint_temp}.jpg' 253 | shutil.copyfile(f'./{config.WEIGHTS_PATH}/StockTradingEnv-v1/plot_learning_curve.jpg', 254 | plot_learning_curve_file_path) 255 | 256 | # 训练结束后,model_hyper_parameters 表 中的 训练的次数 +1,训练的时间点 更新。 257 | 258 | # 判断 train_history 表,是否有记录,如果有,则整除 256 + 128。将此值更新到 model_hyper_parameters 表的 超参,减去相应的值。 259 | update_model_hyper_parameters_by_train_history(model_hyper_parameters_id=hyper_parameters_id, 260 | origin_train_reward_scale=train_reward_scale, 261 | origin_eval_reward_scale=eval_reward_scale, 262 | origin_training_times=training_times, 263 | origin_state_amount_scale=state_amount_scale, 264 | origin_state_price_scale=state_price_scale, 265 | origin_state_stocks_scale=state_stocks_scale, 266 | origin_state_tech_scale=state_tech_scale) 267 | 268 | print('>', config.AGENT_NAME, break_step, 'steps') 269 | 270 | # 结束预测的时间 271 | time_end = datetime.now() 272 | duration = (time_end - time_begin).total_seconds() 273 | print('检测耗时', duration, '秒') 274 | 275 | # 循环次数 276 | loop_index += 1 277 | print('>', 'while 循环次数', loop_index, '\r\n') 278 | 279 | print('sleep 10 秒\r\n') 280 | time.sleep(10) 281 | 282 | # TODO 训练一次退出 283 | break 284 | 285 | pass 286 | -------------------------------------------------------------------------------- /train_helper.py: -------------------------------------------------------------------------------- 1 | import config 2 | from utils import date_time 3 | from utils.date_time import get_next_work_day 4 | from utils.sqlite import SQLite 5 | 6 | 7 | def init_model_hyper_parameters_table_sqlite(): 8 | # 初始化模型超参表 model_hyper_parameters 和 训练历史记录表 train_history 9 | 10 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S') 11 | 12 | # 连接数据库 13 | db_path = config.STOCK_DB_PATH 14 | 15 | # 表名 16 | table_name = 'model_hyper_parameters' 17 | 18 | sqlite = SQLite(db_path) 19 | 20 | if_exists = sqlite.table_exists(table_name) 21 | 22 | if if_exists is None: 23 | # 如果是初始化,则创建表 24 | sqlite.execute_non_query(sql=f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, ' 25 | f'model_name TEXT NOT NULL UNIQUE, if_on_policy INTEGER NOT NULL, ' 26 | f'break_step INTEGER NOT NULL, train_reward_scale INTEGER NOT NULL, ' 27 | f'eval_reward_scale INTEGER NOT NULL, training_times INTEGER NOT NULL, ' 28 | f'state_amount_scale INTEGER NOT NULL, state_price_scale INTEGER NOT NULL, ' 29 | f'state_stocks_scale INTEGER NOT NULL, state_tech_scale INTEGER NOT NULL, ' 30 | f'time_point TEXT NOT NULL, if_active INTEGER NOT NULL);') 31 | 32 | # 提交 33 | sqlite.commit() 34 | pass 35 | 36 | # 初始化默认值 37 | for agent_item in ['AgentSAC', 'AgentPPO', 'AgentTD3', 'AgentDDPG', 'AgentModSAC', ]: 38 | 39 | if agent_item == 'AgentPPO': 40 | if_on_policy = 1 41 | break_step = 50000 42 | elif agent_item == 'AgentModSAC': 43 | if_on_policy = 0 44 | break_step = 50000 45 | else: 46 | if_on_policy = 0 47 | break_step = 5000 48 | pass 49 | 50 | # 例如 2 ** -6 这里只将 -6 保存进数据库 51 | train_reward_scale = -12 52 | eval_reward_scale = -8 53 | 54 | # 训练次数 55 | training_times = 0 56 | 57 | state_amount_scale = -25 58 | state_price_scale = -9 59 | state_stocks_scale = -16 60 | state_tech_scale = -9 61 | 62 | for work_days in [300, ]: 63 | # 如果是初始化,则创建表 64 | sql_cmd = f'INSERT INTO "{table_name}" (model_name, if_on_policy, break_step, ' \ 65 | f'train_reward_scale, eval_reward_scale, ' \ 66 | f'state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale,' \ 67 | f'training_times, time_point, if_active) ' \ 68 | f'VALUES (?,?,?,?,?,?,?,?,?,?,?,1)' 69 | 70 | sql_values = (agent_item + '_' + str(work_days), if_on_policy, break_step, 71 | train_reward_scale, eval_reward_scale, 72 | state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale, 73 | training_times, time_point) 74 | 75 | sqlite.execute_non_query(sql_cmd, sql_values) 76 | 77 | pass 78 | pass 79 | 80 | # 提交 81 | sqlite.commit() 82 | pass 83 | 84 | # 表名 85 | table_name = 'train_history' 86 | 87 | if_exists = sqlite.table_exists(table_name) 88 | 89 | if if_exists is None: 90 | # 如果是初始化,则创建表 91 | sqlite.execute_non_query(sql=f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, ' 92 | f'model_id TEXT NOT NULL, ' 93 | f'train_reward_value NUMERIC NOT NULL, eval_reward_value NUMERIC NOT NULL, ' 94 | f'state_amount_value NUMERIC NOT NULL, state_price_value NUMERIC NOT NULL, ' 95 | f'state_stocks_value NUMERIC NOT NULL, state_tech_value NUMERIC NOT NULL, ' 96 | f'time_point TEXT NOT NULL);') 97 | 98 | # 提交 99 | sqlite.commit() 100 | pass 101 | pass 102 | 103 | sqlite.close() 104 | 105 | pass 106 | 107 | 108 | def query_model_hyper_parameters_sqlite(model_name=None): 109 | # 根据 model_name 查询模型超参 110 | 111 | # 连接数据库 112 | db_path = config.STOCK_DB_PATH 113 | 114 | # 表名 115 | table_name = 'model_hyper_parameters' 116 | 117 | sqlite = SQLite(db_path) 118 | 119 | # 'state_amount_scale INTEGER NOT NULL, state_price_scale INTEGER NOT NULL, ' 120 | # 'state_stocks_scale INTEGER NOT NULL, state_tech_scale INTEGER NOT NULL, ' 121 | 122 | if model_name is None: 123 | query_sql = f'SELECT id, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, ' \ 124 | f'training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, ' \ 125 | f'state_tech_scale FROM "{table_name}" WHERE if_active=1 ' \ 126 | f' ORDER BY training_times ASC LIMIT 1' 127 | else: 128 | # 唯一记录 129 | query_sql = f'SELECT id, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, ' \ 130 | f'training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, ' \ 131 | f'state_tech_scale FROM "{table_name}" WHERE model_name="{model_name}"' \ 132 | f' LIMIT 1' 133 | pass 134 | 135 | id1, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, training_times, time_point, \ 136 | state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale = sqlite.fetchone(query_sql) 137 | 138 | sqlite.close() 139 | 140 | return id1, model_name, if_on_policy, break_step, train_reward_scale, eval_reward_scale, training_times, \ 141 | time_point, state_amount_scale, state_price_scale, state_stocks_scale, state_tech_scale 142 | 143 | 144 | def update_model_hyper_parameters_table_sqlite(model_hyper_parameters_id, train_reward_scale, eval_reward_scale, 145 | training_times, state_amount_scale, state_price_scale, 146 | state_stocks_scale, state_tech_scale): 147 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S') 148 | 149 | # 更新超参表 150 | # 连接数据库 151 | db_path = config.STOCK_DB_PATH 152 | 153 | # 表名 154 | table_name = 'model_hyper_parameters' 155 | 156 | sqlite = SQLite(db_path) 157 | 158 | # 如果是初始化,则创建表 159 | sqlite.execute_non_query(sql=f'UPDATE "{table_name}" SET train_reward_scale={train_reward_scale}, ' 160 | f'eval_reward_scale={eval_reward_scale}, training_times={training_times}, ' 161 | f'time_point="{time_point}", state_amount_scale={state_amount_scale}, ' 162 | f'state_price_scale={state_price_scale}, state_stocks_scale={state_stocks_scale}, ' 163 | f'state_tech_scale={state_tech_scale} WHERE id={model_hyper_parameters_id}') 164 | # 提交 165 | sqlite.commit() 166 | 167 | sqlite.close() 168 | pass 169 | 170 | 171 | def clear_train_history_table_sqlite(): 172 | # 清空训练历史记录表 173 | # 连接数据库 174 | db_path = config.STOCK_DB_PATH 175 | 176 | # 表名 177 | table_name = 'train_history' 178 | 179 | sqlite = SQLite(db_path) 180 | 181 | sqlite.execute_non_query(sql=f'DELETE FROM "{table_name}"') 182 | # 提交 183 | sqlite.commit() 184 | pass 185 | 186 | sqlite.close() 187 | pass 188 | 189 | 190 | def insert_train_history_record_sqlite(model_id, train_reward_value=0.0, eval_reward_value=0.0, 191 | state_amount_value=0.0, state_price_value=0.0, state_stocks_value=0.0, 192 | state_tech_value=0.0): 193 | time_point = date_time.time_point(time_format='%Y-%m-%d %H:%M:%S') 194 | 195 | # 插入训练历史记录 196 | # 连接数据库 197 | db_path = config.STOCK_DB_PATH 198 | 199 | # 表名 200 | table_name = 'train_history' 201 | 202 | sqlite = SQLite(db_path) 203 | 204 | sql_cmd = f'INSERT INTO "{table_name}" ' \ 205 | f'(model_id, train_reward_value, eval_reward_value, time_point, ' \ 206 | f'state_amount_value, state_price_value, state_stocks_value, state_tech_value) VALUES (?,?,?,?,?,?,?,?);' 207 | 208 | sql_values = (model_id, train_reward_value, eval_reward_value, time_point, 209 | state_amount_value, state_price_value, state_stocks_value, state_tech_value) 210 | 211 | sqlite.execute_non_query(sql_cmd, sql_values) 212 | 213 | # 提交 214 | sqlite.commit() 215 | 216 | sqlite.close() 217 | 218 | pass 219 | 220 | 221 | def loop_scale_one(max_state_value, origin_state_scale): 222 | if max_state_value is None: 223 | new_state_scale = origin_state_scale 224 | else: 225 | i = 0 226 | while max_state_value >= 1.0: 227 | max_state_value = max_state_value / 2 228 | i += 1 229 | pass 230 | new_state_scale = origin_state_scale - i 231 | pass 232 | 233 | return new_state_scale 234 | 235 | 236 | def update_model_hyper_parameters_by_train_history(model_hyper_parameters_id, origin_train_reward_scale, 237 | origin_eval_reward_scale, origin_training_times, 238 | origin_state_amount_scale, origin_state_price_scale, 239 | origin_state_stocks_scale, origin_state_tech_scale): 240 | # 根据reward历史,更新超参表 241 | # 插入训练历史记录 242 | # 连接数据库 243 | db_path = config.STOCK_DB_PATH 244 | 245 | # 表名 246 | table_name = 'train_history' 247 | 248 | sqlite = SQLite(db_path) 249 | 250 | query_sql = f'SELECT MAX(train_reward_value), MAX(eval_reward_value), ' \ 251 | f'MAX(state_amount_value), MAX(state_price_value), ' \ 252 | f'MAX(state_stocks_value), MAX(state_tech_value) FROM "{table_name}" ' \ 253 | f' WHERE model_id="{model_hyper_parameters_id}"' 254 | 255 | max_train_reward_value, max_eval_reward_value, max_state_amount_value, max_state_price_value, \ 256 | max_state_stocks_value, max_state_tech_value = sqlite.fetchone(query_sql) 257 | 258 | sqlite.close() 259 | 260 | # reward 阈值 261 | reward_threshold = config.REWARD_THRESHOLD 262 | 263 | if max_train_reward_value is None: 264 | new_train_reward_scale = origin_train_reward_scale 265 | print('> keep origin train_reward_scale', new_train_reward_scale) 266 | pass 267 | else: 268 | if max_train_reward_value >= reward_threshold: 269 | new_train_reward_scale = origin_train_reward_scale - (max_train_reward_value // reward_threshold) 270 | print('> modify train_reward_scale:', origin_train_reward_scale, '->', new_train_reward_scale) 271 | else: 272 | new_train_reward_scale = origin_train_reward_scale 273 | print('> keep origin train_reward_scale', new_train_reward_scale) 274 | pass 275 | pass 276 | 277 | if max_eval_reward_value is None: 278 | new_eval_reward_scale = origin_eval_reward_scale 279 | print('> keep origin eval_reward_scale', new_eval_reward_scale) 280 | pass 281 | else: 282 | if max_eval_reward_value >= reward_threshold: 283 | new_eval_reward_scale = origin_eval_reward_scale - (max_eval_reward_value // reward_threshold) 284 | 285 | print('> modify eval_reward_scale:', origin_eval_reward_scale, '->', new_eval_reward_scale) 286 | pass 287 | else: 288 | new_eval_reward_scale = origin_eval_reward_scale 289 | 290 | print('> keep origin eval_reward_scale', new_eval_reward_scale) 291 | pass 292 | pass 293 | pass 294 | 295 | new_state_amount_scale = loop_scale_one(max_state_amount_value, origin_state_amount_scale) 296 | if new_state_amount_scale == origin_state_amount_scale: 297 | print('> keep origin state_amount_scale', new_state_amount_scale) 298 | else: 299 | print('> modify state_amount_scale:', origin_state_amount_scale, '->', new_state_amount_scale) 300 | pass 301 | 302 | new_state_price_scale = loop_scale_one(max_state_price_value, origin_state_price_scale) 303 | if new_state_price_scale == origin_state_price_scale: 304 | print('> keep origin state_price_scale', new_state_price_scale) 305 | else: 306 | print('> modify state_price_scale:', origin_state_price_scale, '->', new_state_price_scale) 307 | pass 308 | 309 | new_state_stocks_scale = loop_scale_one(max_state_stocks_value, origin_state_stocks_scale) 310 | if new_state_stocks_scale == origin_state_stocks_scale: 311 | print('> keep origin state_stocks_scale', new_state_stocks_scale) 312 | else: 313 | print('> modify state_stocks_scale:', origin_state_stocks_scale, '->', new_state_stocks_scale) 314 | pass 315 | 316 | new_state_tech_scale = loop_scale_one(max_state_tech_value, origin_state_tech_scale) 317 | if new_state_tech_scale == origin_state_tech_scale: 318 | print('> keep origin state_tech_scale', new_state_tech_scale) 319 | else: 320 | print('> modify state_tech_scale:', origin_state_tech_scale, '->', new_state_tech_scale) 321 | pass 322 | 323 | # 更新超参表 324 | update_model_hyper_parameters_table_sqlite(model_hyper_parameters_id=model_hyper_parameters_id, 325 | train_reward_scale=new_train_reward_scale, 326 | eval_reward_scale=new_eval_reward_scale, 327 | training_times=origin_training_times + 1, 328 | state_amount_scale=new_state_amount_scale, 329 | state_price_scale=new_state_price_scale, 330 | state_stocks_scale=new_state_stocks_scale, 331 | state_tech_scale=new_state_tech_scale) 332 | 333 | pass 334 | 335 | 336 | def query_begin_vali_date_list_by_agent_name(agent_name, end_vali_date): 337 | list_result = list() 338 | 339 | # 连接数据库 340 | db_path = config.STOCK_DB_PATH 341 | 342 | # 表名 343 | table_name = 'model_hyper_parameters' 344 | 345 | sqlite = SQLite(db_path) 346 | 347 | query_sql = f'SELECT model_name FROM "{table_name}" WHERE if_active=1 AND model_name LIKE "{agent_name}%" ' \ 348 | f' ORDER BY model_name ASC' 349 | 350 | list_temp = sqlite.fetchall(query_sql) 351 | 352 | sqlite.close() 353 | 354 | for work_days in list_temp: 355 | # AgentSAC_60 --> 60 356 | work_days = int(str(work_days[0]).split('_')[1]) 357 | begin_vali_date = get_next_work_day(end_vali_date, next_flag=-work_days) 358 | list_result.append((work_days, begin_vali_date)) 359 | pass 360 | 361 | list_temp.clear() 362 | 363 | return list_result 364 | -------------------------------------------------------------------------------- /predict_single_psql.py: -------------------------------------------------------------------------------- 1 | from stock_data import StockData 2 | # from train_single import get_agent_args 3 | from train_helper import query_model_hyper_parameters_sqlite, query_begin_vali_date_list_by_agent_name 4 | from utils.psqldb import Psqldb 5 | from agent import * 6 | from utils.date_time import * 7 | from env_single import StockTradingEnvSingle, FeatureEngineer 8 | from run_single import * 9 | from datetime import datetime 10 | 11 | 12 | def calc_max_return(price_ary, initial_capital_temp): 13 | max_return_temp = 0 14 | min_value = 0 15 | 16 | assert price_ary.shape[0] > 1 17 | 18 | count_price = price_ary.shape[0] 19 | 20 | for index_left in range(0, count_price - 1): 21 | 22 | for index_right in range(index_left + 1, count_price): 23 | 24 | assert price_ary[index_left][0] > 0 25 | 26 | assert price_ary[index_right][0] > 0 27 | 28 | temp_value = price_ary[index_right][0] - price_ary[index_left][0] 29 | 30 | if temp_value > max_return_temp: 31 | max_return_temp = temp_value 32 | # max_value = price_ary[index1][0] 33 | min_value = price_ary[index_right][0] 34 | pass 35 | pass 36 | 37 | if min_value == 0: 38 | ret = 0 39 | else: 40 | ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp 41 | 42 | return ret 43 | 44 | 45 | if __name__ == '__main__': 46 | # 预测,并保存结果到 postgresql 数据库 47 | # 开始预测的时间 48 | time_begin = datetime.now() 49 | 50 | config.OUTPUT_DATE = '2021-08-03' 51 | 52 | initial_capital = 150000 53 | 54 | max_stock = 3000 55 | 56 | # for tic_item in ['sh.600036', 'sh.600667']: 57 | # 循环 58 | for tic_item in ['sh.600036', ]: 59 | 60 | # 要预测的那一天 61 | config.SINGLE_A_STOCK_CODE = [tic_item, ] 62 | 63 | # psql对象 64 | psql_object = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER, 65 | password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT) 66 | 67 | # 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(), 68 | # AgentDoubleDQN 单进程好用? 69 | # 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC() 70 | for agent_item in ['AgentSAC', 'AgentPPO', 'AgentDDPG', 'AgentTD3', 'AgentModSAC', ]: 71 | 72 | config.AGENT_NAME = agent_item 73 | # config.CWD = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1' 74 | 75 | break_step = int(3e6) 76 | 77 | if_on_policy = False 78 | # if_use_gae = False 79 | 80 | # 预测的开始日期和结束日期,都固定 81 | 82 | # 日期列表 83 | # 4月16日向前,20,30,40,50,60,72,90周期 84 | # end_vali_date = get_datetime_from_date_str('2021-04-16') 85 | config.IF_ACTUAL_PREDICT = True 86 | 87 | config.START_DATE = "2003-05-01" 88 | 89 | # 前29后1 90 | config.PREDICT_PERIOD = '60' 91 | 92 | # 固定日期 93 | config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), -10)) 94 | # config.START_EVAL_DATE = "2021-05-22" 95 | 96 | # OUTPUT_DATE 向右3工作日 97 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), +1)) 98 | 99 | # 创建预测结果表 100 | StockData.create_predict_result_table_psql(list_tic=config.SINGLE_A_STOCK_CODE) 101 | 102 | # 更新股票数据 103 | StockData.update_stock_data_to_sqlite(list_stock_code=config.SINGLE_A_STOCK_CODE) 104 | 105 | # 预测的截止日期 106 | end_vali_date = get_datetime_from_date_str(config.END_DATE) 107 | 108 | # 获取 N 个日期list 109 | list_begin_vali_date = query_begin_vali_date_list_by_agent_name(agent_item, end_vali_date) 110 | 111 | # 循环 vali_date_list 训练7次 112 | for vali_days_count, begin_vali_date in list_begin_vali_date: 113 | 114 | # config.START_EVAL_DATE = str(begin_vali_date) 115 | 116 | # 更新工作日标记,用于 run_single.py 加载训练过的 weights 文件 117 | config.VALI_DAYS_FLAG = str(vali_days_count) 118 | 119 | # config.PREDICT_PERIOD = str(vali_days_count) 120 | 121 | # weights 文件目录 122 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \ 123 | # f'/single_{config.VALI_DAYS_FLAG}' 124 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \ 125 | f'/single_{config.VALI_DAYS_FLAG}' 126 | 127 | # 如果存在目录则预测 128 | if os.path.exists(model_folder_path): 129 | 130 | print('\r\n') 131 | print('#' * 40) 132 | print('config.AGENT_NAME', config.AGENT_NAME) 133 | print('# 预测周期', config.START_EVAL_DATE, '-', config.END_DATE) 134 | print('# 模型的 work_days', vali_days_count) 135 | print('# model_folder_path', model_folder_path) 136 | print('# initial_capital', initial_capital) 137 | print('# max_stock', max_stock) 138 | 139 | initial_stocks = np.zeros(len(config.SINGLE_A_STOCK_CODE), dtype=np.float32) 140 | initial_stocks[0] = 100.0 141 | 142 | # 获取超参 143 | model_name = agent_item + '_' + str(vali_days_count) 144 | 145 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 146 | eval_reward_scale, training_times, time_point \ 147 | = query_model_hyper_parameters_sqlite(model_name=model_name) 148 | 149 | if if_on_policy == 1: 150 | if_on_policy = True 151 | else: 152 | if_on_policy = False 153 | pass 154 | 155 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id) 156 | 157 | # 获得Agent参数 158 | agent_class = None 159 | train_reward_scaling = 2 ** train_reward_scale 160 | eval_reward_scaling = 2 ** eval_reward_scale 161 | 162 | # 模型名称 163 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0] 164 | 165 | if config.AGENT_NAME == 'AgentPPO': 166 | agent_class = AgentPPO() 167 | elif config.AGENT_NAME == 'AgentSAC': 168 | agent_class = AgentSAC() 169 | elif config.AGENT_NAME == 'AgentTD3': 170 | agent_class = AgentTD3() 171 | elif config.AGENT_NAME == 'AgentDDPG': 172 | agent_class = AgentDDPG() 173 | elif config.AGENT_NAME == 'AgentModSAC': 174 | agent_class = AgentModSAC() 175 | elif config.AGENT_NAME == 'AgentDuelingDQN': 176 | agent_class = AgentDuelingDQN() 177 | elif config.AGENT_NAME == 'AgentSharedSAC': 178 | agent_class = AgentSharedSAC() 179 | elif config.AGENT_NAME == 'AgentDoubleDQN': 180 | agent_class = AgentDoubleDQN() 181 | pass 182 | 183 | # 预测周期 184 | work_days = int(str(hyper_parameters_model_name).split('_')[1]) 185 | 186 | args = Arguments(if_on_policy=if_on_policy) 187 | args.agent = agent_class 188 | 189 | args.gpu_id = 0 190 | # args.agent.if_use_gae = if_use_gae 191 | args.agent.lambda_entropy = 0.04 192 | 193 | tech_indicator_list = [ 194 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 195 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST 196 | 197 | gamma = 0.99 198 | 199 | buy_cost_pct = 0.003 200 | sell_cost_pct = 0.003 201 | start_date = config.START_DATE 202 | start_eval_date = config.START_EVAL_DATE 203 | end_eval_date = config.END_DATE 204 | 205 | args.env = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock, 206 | initial_capital=initial_capital, 207 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 208 | start_date=start_date, 209 | end_date=start_eval_date, env_eval_date=end_eval_date, 210 | ticker_list=config.SINGLE_A_STOCK_CODE, 211 | tech_indicator_list=tech_indicator_list, 212 | initial_stocks=initial_stocks, 213 | if_eval=True) 214 | 215 | args.env_eval = StockTradingEnvSingle(cwd='', gamma=gamma, max_stock=max_stock, 216 | initial_capital=initial_capital, 217 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 218 | start_date=start_date, 219 | end_date=start_eval_date, env_eval_date=end_eval_date, 220 | ticker_list=config.SINGLE_A_STOCK_CODE, 221 | tech_indicator_list=tech_indicator_list, 222 | initial_stocks=initial_stocks, 223 | if_eval=True) 224 | 225 | args.env.target_return = 100 226 | args.env_eval.target_return = 100 227 | 228 | # 奖励 比例 229 | args.env.reward_scaling = train_reward_scaling 230 | args.env_eval.reward_scaling = eval_reward_scaling 231 | 232 | print('train/eval reward scaling:', args.env.reward_scaling, args.env_eval.reward_scaling) 233 | 234 | # Hyperparameters 235 | args.gamma = gamma 236 | # ---- 237 | args.break_step = break_step 238 | # ---- 239 | 240 | args.net_dim = 2 ** 9 241 | args.max_step = args.env.max_step 242 | 243 | # ---- 244 | # args.max_memo = args.max_step * 4 245 | args.max_memo = (args.max_step - 1) * 8 246 | # ---- 247 | 248 | # ---- 249 | args.batch_size = 2 ** 12 250 | # args.batch_size = 2305 251 | # ---- 252 | 253 | # ---- 254 | # args.repeat_times = 2 ** 3 255 | args.repeat_times = 2 ** 4 256 | # ---- 257 | 258 | args.eval_gap = 2 ** 4 259 | args.eval_times1 = 2 ** 3 260 | args.eval_times2 = 2 ** 5 261 | 262 | # ---- 263 | args.if_allow_break = False 264 | # args.if_allow_break = True 265 | # ---- 266 | 267 | # ---------------------------- 268 | args.init_before_training() 269 | 270 | '''basic arguments''' 271 | cwd = args.cwd 272 | env = args.env 273 | agent = args.agent 274 | gpu_id = args.gpu_id # necessary for Evaluator? 275 | 276 | '''training arguments''' 277 | net_dim = args.net_dim 278 | max_memo = args.max_memo 279 | break_step = args.break_step 280 | batch_size = args.batch_size 281 | target_step = args.target_step 282 | repeat_times = args.repeat_times 283 | if_break_early = args.if_allow_break 284 | if_per = args.if_per 285 | gamma = args.gamma 286 | reward_scale = args.reward_scale 287 | 288 | '''evaluating arguments''' 289 | eval_gap = args.eval_gap 290 | eval_times1 = args.eval_times1 291 | eval_times2 = args.eval_times2 292 | if args.env_eval is not None: 293 | env_eval = args.env_eval 294 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()): 295 | env_eval = PreprocessEnv(gym.make(env.env_name)) 296 | else: 297 | env_eval = deepcopy(env) 298 | 299 | del args # In order to show these hyper-parameters clearly, I put them above. 300 | 301 | '''init: environment''' 302 | max_step = env.max_step 303 | state_dim = env.state_dim 304 | action_dim = env.action_dim 305 | if_discrete = env.if_discrete 306 | 307 | '''init: Agent, ReplayBuffer, Evaluator''' 308 | agent.init(net_dim, state_dim, action_dim, if_per) 309 | 310 | # ---- 311 | # work_days,周期数,用于存储和提取训练好的模型 312 | model_file_path = f'{model_folder_path}/actor.pth' 313 | 314 | # 如果model存在,则加载 315 | if os.path.exists(model_file_path): 316 | agent.save_load_model(model_folder_path, if_save=False) 317 | 318 | '''prepare for training''' 319 | agent.state = env.reset() 320 | 321 | episode_return = 0.0 # sum of rewards in an episode 322 | episode_step = 1 323 | max_step = env.max_step 324 | if_discrete = env.if_discrete 325 | 326 | state = env.reset() 327 | 328 | with torch.no_grad(): # speed up running 329 | 330 | # for episode_step in range(max_step): 331 | while True: 332 | s_tensor = torch.as_tensor((state,), device=agent.device) 333 | a_tensor = agent.act(s_tensor) 334 | if if_discrete: 335 | a_tensor = a_tensor.argmax(dim=1) 336 | action = a_tensor.detach().cpu().numpy()[ 337 | 0] # not need detach(), because with torch.no_grad() outside 338 | state, reward, done, _ = env.step(action) 339 | episode_return += reward 340 | if done: 341 | break 342 | pass 343 | pass 344 | 345 | # 获取要预测的日期,保存到数据库中 346 | for item in env.list_buy_or_sell_output: 347 | tic, date, action, hold, day, episode_return = item 348 | if str(date) == config.OUTPUT_DATE: 349 | # 简单计算一次,低买高卖的最大回报 350 | max_return = calc_max_return(env.price_ary, env.initial_capital) 351 | 352 | # 找到要预测的那一天,存储到psql 353 | StockData.update_predict_result_to_psql(psql=psql_object, agent=config.AGENT_NAME, 354 | vali_period_value=config.VALI_DAYS_FLAG, 355 | pred_period_name=config.PREDICT_PERIOD, 356 | tic=tic, date=date, action=action, 357 | hold=hold, 358 | day=day, episode_return=episode_return, 359 | max_return=max_return, 360 | trade_detail=env.output_text_trade_detail) 361 | 362 | break 363 | 364 | pass 365 | pass 366 | pass 367 | # episode_return = getattr(env, 'episode_return', episode_return) 368 | pass 369 | else: 370 | print('未找到模型文件', model_file_path) 371 | pass 372 | # ---- 373 | 374 | pass 375 | pass 376 | 377 | psql_object.close() 378 | pass 379 | 380 | # 结束预测的时间 381 | time_end = datetime.now() 382 | duration = (time_end - time_begin).seconds 383 | print('检测耗时', duration, '秒') 384 | pass 385 | -------------------------------------------------------------------------------- /predict_batch_psql.py: -------------------------------------------------------------------------------- 1 | from stock_data import StockData 2 | from train_helper import query_model_hyper_parameters_sqlite, query_begin_vali_date_list_by_agent_name 3 | from utils.psqldb import Psqldb 4 | from agent import * 5 | from utils.date_time import * 6 | from env_batch import StockTradingEnvBatch 7 | from run_batch import * 8 | from datetime import datetime 9 | 10 | from utils.sqlite import SQLite 11 | 12 | 13 | def calc_max_return(stock_index_temp, price_ary, initial_capital_temp): 14 | max_return_temp = 0 15 | min_value = 0 16 | 17 | assert price_ary.shape[0] > 1 18 | 19 | count_price = price_ary.shape[0] 20 | 21 | for index_left in range(0, count_price - 1): 22 | 23 | for index_right in range(index_left + 1, count_price): 24 | 25 | assert price_ary[index_left][stock_index_temp] > 0 26 | 27 | assert price_ary[index_right][stock_index_temp] > 0 28 | 29 | temp_value = price_ary[index_right][stock_index_temp] - price_ary[index_left][stock_index_temp] 30 | 31 | if temp_value > max_return_temp: 32 | max_return_temp = temp_value 33 | # max_value = price_ary[index1][0] 34 | min_value = price_ary[index_right][stock_index_temp] 35 | pass 36 | pass 37 | 38 | if min_value == 0: 39 | ret = 0 40 | else: 41 | ret = (initial_capital_temp / min_value * max_return_temp + initial_capital_temp) / initial_capital_temp 42 | 43 | return ret 44 | 45 | 46 | if __name__ == '__main__': 47 | # 预测,并保存结果到 postgresql 数据库 48 | # 开始预测的时间 49 | time_begin = datetime.now() 50 | 51 | # 创建 预测汇总表 52 | StockData.create_predict_summary_table_psql(table_name='predict_summary') 53 | 54 | # 清空 预测汇总表 55 | StockData.clear_predict_summary_table_psql(table_name='predict_summary') 56 | 57 | # TODO 58 | # 获取今天日期,判断是否为工作日 59 | weekday = get_datetime_from_date_str(get_today_date()).weekday() 60 | if 0 < weekday < 6: 61 | # 工作日 62 | now = datetime.now().strftime("%H:%M") 63 | t1 = '09:00' 64 | 65 | if now >= t1: 66 | # 如果是工作日,大于等于 09:00,则预测明天 67 | config.OUTPUT_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), +1)) 68 | pass 69 | else: 70 | # 如果是工作日,小于 09:00,则预测今天 71 | config.OUTPUT_DATE = get_today_date() 72 | pass 73 | pass 74 | else: 75 | # 假期 76 | # 下一个工作日 77 | config.OUTPUT_DATE = str(get_next_work_day(get_datetime_from_date_str(get_today_date()), +1)) 78 | pass 79 | pass 80 | 81 | # 如果不是工作日,则预测下一个工作日 82 | # config.OUTPUT_DATE = '2021-08-06' 83 | 84 | config.START_DATE = "2004-05-01" 85 | 86 | # 股票的顺序,不要改变 87 | config.BATCH_A_STOCK_CODE = StockData.get_batch_a_share_code_list_string(table_name='tic_list_275') 88 | 89 | fe_table_name = 'fe_fillzero_predict' 90 | 91 | initial_capital = 150000 * len(config.BATCH_A_STOCK_CODE) 92 | 93 | max_stock = 50000 94 | 95 | config.IF_ACTUAL_PREDICT = True 96 | 97 | # 预测周期 98 | config.PREDICT_PERIOD = '10' 99 | 100 | # psql对象 101 | psql_object = Psqldb(database=config.PSQL_DATABASE, user=config.PSQL_USER, 102 | password=config.PSQL_PASSWORD, host=config.PSQL_HOST, port=config.PSQL_PORT) 103 | 104 | # if_first_time = True 105 | 106 | # 好用 AgentPPO(), # AgentSAC(), AgentTD3(), AgentDDPG(), AgentModSAC(), 107 | # AgentDoubleDQN 单进程好用? 108 | # 不好用 AgentDuelingDQN(), AgentDoubleDQN(), AgentSharedSAC() 109 | 110 | # for agent_item in ['AgentTD3', 'AgentSAC', 'AgentPPO', 'AgentDDPG', 'AgentModSAC', ]: 111 | for agent_item in ['AgentTD3', ]: 112 | 113 | config.AGENT_NAME = agent_item 114 | 115 | break_step = int(3e6) 116 | 117 | if_on_policy = False 118 | # if_use_gae = False 119 | 120 | # 固定日期 121 | config.START_EVAL_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), 122 | -1 * int(config.PREDICT_PERIOD))) 123 | # config.START_EVAL_DATE = "2021-05-22" 124 | 125 | # OUTPUT_DATE 向右1工作日 126 | config.END_DATE = str(get_next_work_day(get_datetime_from_date_str(config.OUTPUT_DATE), +1)) 127 | 128 | # # 更新股票数据 129 | # # 只更新一次啊 130 | # if if_first_time is True: 131 | # # 创建预测结果表 132 | # StockData.create_predict_result_table_psql(list_tic=config.BATCH_A_STOCK_CODE) 133 | # 134 | # StockData.update_stock_data_to_sqlite(list_stock_code=config.BATCH_A_STOCK_CODE, adjustflag='3', 135 | # table_name=fe_table_name, if_incremental_update=False) 136 | # if_first_time = False 137 | # pass 138 | 139 | # 预测的截止日期 140 | end_vali_date = get_datetime_from_date_str(config.END_DATE) 141 | 142 | # 获取 N 个日期list 143 | list_begin_vali_date = query_begin_vali_date_list_by_agent_name(agent_item, end_vali_date) 144 | 145 | # 循环 vali_date_list 训练7次 146 | for vali_days_count, begin_vali_date in list_begin_vali_date: 147 | 148 | # 更新工作日标记,用于 run_batch.py 加载训练过的 weights 文件 149 | config.VALI_DAYS_FLAG = str(vali_days_count) 150 | 151 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 152 | f'batch_{config.VALI_DAYS_FLAG}' 153 | 154 | # 如果存在目录则预测 155 | if os.path.exists(model_folder_path): 156 | 157 | print('\r\n') 158 | print('#' * 40) 159 | print('config.AGENT_NAME', config.AGENT_NAME) 160 | print('# 预测周期', config.START_EVAL_DATE, '-', config.END_DATE) 161 | print('# 模型的 work_days', vali_days_count) 162 | print('# model_folder_path', model_folder_path) 163 | print('# initial_capital', initial_capital) 164 | print('# max_stock', max_stock) 165 | 166 | # initial_stocks = np.ones(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 167 | initial_stocks = np.zeros(len(config.BATCH_A_STOCK_CODE), dtype=np.float32) 168 | 169 | # 默认持有一手 170 | # initial_stocks = initial_stocks * 100.0 171 | 172 | # 获取超参 173 | model_name = agent_item + '_' + str(vali_days_count) 174 | 175 | # hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 176 | # eval_reward_scale, training_times, time_point \ 177 | # = query_model_hyper_parameters_sqlite(model_name=model_name) 178 | 179 | hyper_parameters_id, hyper_parameters_model_name, if_on_policy, break_step, train_reward_scale, \ 180 | eval_reward_scale, training_times, time_point, state_amount_scale, state_price_scale, state_stocks_scale, \ 181 | state_tech_scale = query_model_hyper_parameters_sqlite(model_name=model_name) 182 | 183 | if if_on_policy == 1: 184 | if_on_policy = True 185 | else: 186 | if_on_policy = False 187 | pass 188 | 189 | config.MODEL_HYPER_PARAMETERS = str(hyper_parameters_id) 190 | 191 | # 获得Agent参数 192 | agent_class = None 193 | 194 | # 模型名称 195 | config.AGENT_NAME = str(hyper_parameters_model_name).split('_')[0] 196 | 197 | if config.AGENT_NAME == 'AgentPPO': 198 | agent_class = AgentPPO() 199 | elif config.AGENT_NAME == 'AgentSAC': 200 | agent_class = AgentSAC() 201 | elif config.AGENT_NAME == 'AgentTD3': 202 | agent_class = AgentTD3() 203 | elif config.AGENT_NAME == 'AgentDDPG': 204 | agent_class = AgentDDPG() 205 | elif config.AGENT_NAME == 'AgentModSAC': 206 | agent_class = AgentModSAC() 207 | elif config.AGENT_NAME == 'AgentDuelingDQN': 208 | agent_class = AgentDuelingDQN() 209 | elif config.AGENT_NAME == 'AgentSharedSAC': 210 | agent_class = AgentSharedSAC() 211 | elif config.AGENT_NAME == 'AgentDoubleDQN': 212 | agent_class = AgentDoubleDQN() 213 | pass 214 | 215 | args = Arguments(if_on_policy=if_on_policy) 216 | args.agent = agent_class 217 | 218 | args.gpu_id = 0 219 | # args.agent.if_use_gae = if_use_gae 220 | args.agent.lambda_entropy = 0.04 221 | 222 | tech_indicator_list = [ 223 | 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 224 | 'close_30_sma', 'close_60_sma'] # finrl.config.TECHNICAL_INDICATORS_LIST 225 | 226 | gamma = 0.99 227 | 228 | buy_cost_pct = 0.003 229 | sell_cost_pct = 0.003 230 | start_date = config.START_DATE 231 | start_eval_date = config.START_EVAL_DATE 232 | end_eval_date = config.END_DATE 233 | 234 | args.env = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 235 | initial_capital=initial_capital, 236 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 237 | start_date=start_date, 238 | end_date=start_eval_date, env_eval_date=end_eval_date, 239 | ticker_list=config.BATCH_A_STOCK_CODE, 240 | tech_indicator_list=tech_indicator_list, 241 | initial_stocks=initial_stocks, 242 | if_eval=True, fe_table_name=fe_table_name) 243 | 244 | args.env_eval = StockTradingEnvBatch(cwd='', gamma=gamma, max_stock=max_stock, 245 | initial_capital=initial_capital, 246 | buy_cost_pct=buy_cost_pct, sell_cost_pct=sell_cost_pct, 247 | start_date=start_date, 248 | end_date=start_eval_date, env_eval_date=end_eval_date, 249 | ticker_list=config.BATCH_A_STOCK_CODE, 250 | tech_indicator_list=tech_indicator_list, 251 | initial_stocks=initial_stocks, 252 | if_eval=True, fe_table_name=fe_table_name) 253 | 254 | args.env.target_return = 100 255 | args.env_eval.target_return = 100 256 | 257 | # 奖励 比例 258 | args.env.reward_scale = train_reward_scale 259 | args.env_eval.reward_scale = eval_reward_scale 260 | 261 | args.env.state_amount_scale = state_amount_scale 262 | args.env.state_price_scale = state_price_scale 263 | args.env.state_stocks_scale = state_stocks_scale 264 | args.env.state_tech_scale = state_tech_scale 265 | 266 | args.env_eval.state_amount_scale = state_amount_scale 267 | args.env_eval.state_price_scale = state_price_scale 268 | args.env_eval.state_stocks_scale = state_stocks_scale 269 | args.env_eval.state_tech_scale = state_tech_scale 270 | 271 | print('train reward_scale', args.env.reward_scale) 272 | print('eval reward_scale', args.env_eval.reward_scale) 273 | 274 | print('state_amount_scale', state_amount_scale) 275 | print('state_price_scale', state_price_scale) 276 | print('state_stocks_scale', state_stocks_scale) 277 | print('state_tech_scale', state_tech_scale) 278 | 279 | # Hyperparameters 280 | args.gamma = gamma 281 | # ---- 282 | args.break_step = break_step 283 | # ---- 284 | 285 | args.net_dim = 2 ** 9 286 | args.max_step = args.env.max_step 287 | 288 | # ---- 289 | # args.max_memo = args.max_step * 4 290 | args.max_memo = (args.max_step - 1) * 8 291 | # ---- 292 | 293 | # ---- 294 | args.batch_size = 2 ** 12 295 | # args.batch_size = 2305 296 | # ---- 297 | 298 | # ---- 299 | # args.repeat_times = 2 ** 3 300 | args.repeat_times = 2 ** 4 301 | # ---- 302 | 303 | args.eval_gap = 2 ** 4 304 | args.eval_times1 = 2 ** 3 305 | args.eval_times2 = 2 ** 5 306 | 307 | # ---- 308 | args.if_allow_break = False 309 | # args.if_allow_break = True 310 | # ---- 311 | 312 | # ---------------------------- 313 | args.init_before_training() 314 | 315 | '''basic arguments''' 316 | cwd = args.cwd 317 | env = args.env 318 | agent = args.agent 319 | gpu_id = args.gpu_id # necessary for Evaluator? 320 | 321 | '''training arguments''' 322 | net_dim = args.net_dim 323 | max_memo = args.max_memo 324 | break_step = args.break_step 325 | batch_size = args.batch_size 326 | target_step = args.target_step 327 | repeat_times = args.repeat_times 328 | if_break_early = args.if_allow_break 329 | if_per = args.if_per 330 | gamma = args.gamma 331 | reward_scale = args.reward_scale 332 | 333 | '''evaluating arguments''' 334 | eval_gap = args.eval_gap 335 | eval_times1 = args.eval_times1 336 | eval_times2 = args.eval_times2 337 | if args.env_eval is not None: 338 | env_eval = args.env_eval 339 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()): 340 | env_eval = PreprocessEnv(gym.make(env.env_name)) 341 | else: 342 | env_eval = deepcopy(env) 343 | 344 | del args # In order to show these hyper-parameters clearly, I put them above. 345 | 346 | '''init: environment''' 347 | max_step = env.max_step 348 | state_dim = env.state_dim 349 | action_dim = env.action_dim 350 | if_discrete = env.if_discrete 351 | 352 | '''init: Agent, ReplayBuffer, Evaluator''' 353 | agent.init(net_dim, state_dim, action_dim, if_per) 354 | 355 | # ---- 356 | # work_days,周期数,用于存储和提取训练好的模型 357 | model_file_path = f'{model_folder_path}/actor.pth' 358 | 359 | # 如果model存在,则加载 360 | if os.path.exists(model_file_path): 361 | agent.save_load_model(model_folder_path, if_save=False) 362 | 363 | '''prepare for training''' 364 | agent.state = env.reset() 365 | 366 | episode_return_total = 0.0 # sum of rewards in an episode 367 | episode_step = 1 368 | max_step = env.max_step 369 | if_discrete = env.if_discrete 370 | 371 | state = env.reset() 372 | 373 | with torch.no_grad(): # speed up running 374 | 375 | # for episode_step in range(max_step): 376 | while True: 377 | s_tensor = torch.as_tensor((state,), device=agent.device) 378 | a_tensor = agent.act(s_tensor) 379 | if if_discrete: 380 | a_tensor = a_tensor.argmax(dim=1) 381 | action = a_tensor.detach().cpu().numpy()[ 382 | 0] # not need detach(), because with torch.no_grad() outside 383 | state, reward, done, _ = env.step(action) 384 | episode_return_total += reward 385 | 386 | if done: 387 | break 388 | pass 389 | pass 390 | 391 | # 根据tic查询名称 392 | sqlite_query_name_by_tic = SQLite(dbname=config.STOCK_DB_PATH) 393 | 394 | # 获取要预测的日期,保存到数据库中 395 | for stock_index in range(len(env.list_buy_or_sell_output)): 396 | list_one_stock = env.list_buy_or_sell_output[stock_index] 397 | 398 | trade_detail_all = '' 399 | 400 | for item in list_one_stock: 401 | tic, date, action, hold, day, episode_return, trade_detail = item 402 | 403 | trade_detail_all += trade_detail 404 | 405 | if str(date) == config.OUTPUT_DATE: 406 | # 简单计算一次,低买高卖的最大回报 407 | max_return = calc_max_return(stock_index_temp=stock_index, price_ary=env.price_ary, 408 | initial_capital_temp=env.initial_capital) 409 | 410 | stock_name = StockData.get_stock_name_by_tic(sqlite=sqlite_query_name_by_tic, 411 | tic=tic, 412 | table_name='tic_list_275') 413 | 414 | # 找到要预测的那一天,存储到psql 415 | StockData.update_predict_summary_result_to_psql(psql=psql_object, 416 | agent=config.AGENT_NAME, 417 | vali_period_value=config.VALI_DAYS_FLAG, 418 | pred_period_name=config.PREDICT_PERIOD, 419 | tic=tic, name=stock_name, date=date, 420 | action=action, 421 | hold=hold, 422 | day=day, 423 | episode_return=episode_return, 424 | max_return=max_return, 425 | trade_detail=trade_detail_all, 426 | table_name='predict_summary') 427 | 428 | break 429 | 430 | pass 431 | pass 432 | pass 433 | pass 434 | 435 | # 关闭数据库连接 436 | sqlite_query_name_by_tic.close() 437 | 438 | # episode_return = getattr(env, 'episode_return', episode_return) 439 | pass 440 | else: 441 | print('未找到模型文件', model_file_path) 442 | pass 443 | # ---- 444 | pass 445 | pass 446 | 447 | psql_object.close() 448 | pass 449 | 450 | # 结束预测的时间 451 | time_end = datetime.now() 452 | duration = (time_end - time_begin).total_seconds() 453 | print('检测耗时', duration, '秒') 454 | pass 455 | -------------------------------------------------------------------------------- /run_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import time 4 | import torch 5 | import numpy as np 6 | import numpy.random as rd 7 | from copy import deepcopy 8 | from ElegantRL_master.elegantrl.replay import ReplayBuffer, ReplayBufferMP 9 | from ElegantRL_master.elegantrl.env import PreprocessEnv 10 | import config 11 | 12 | """[ElegantRL](https://github.com/AI4Finance-LLC/ElegantRL)""" 13 | 14 | 15 | class Arguments: 16 | def __init__(self, agent=None, env=None, gpu_id=None, if_on_policy=False): 17 | self.agent = agent # Deep Reinforcement Learning algorithm 18 | 19 | self.cwd = None # current work directory. cwd is None means set it automatically 20 | self.env = env # the environment for training 21 | self.env_eval = None # the environment for evaluating 22 | self.gpu_id = gpu_id # choose the GPU for running. gpu_id is None means set it automatically 23 | 24 | '''Arguments for training (off-policy)''' 25 | self.net_dim = 2 ** 8 # the network width 26 | self.batch_size = 2 ** 8 # num of transitions sampled from replay buffer. 27 | self.repeat_times = 2 ** 0 # repeatedly update network to keep critic's loss small 28 | self.target_step = 2 ** 10 # collect target_step, then update network 29 | self.max_memo = 2 ** 17 # capacity of replay buffer 30 | if if_on_policy: # (on-policy) 31 | self.net_dim = 2 ** 9 32 | self.batch_size = 2 ** 9 33 | self.repeat_times = 2 ** 4 34 | self.target_step = 2 ** 12 35 | self.max_memo = self.target_step 36 | self.gamma = 0.99 # discount factor of future rewards 37 | self.reward_scale = 2 ** 0 # an approximate target reward usually be closed to 256 38 | self.if_per = False # Prioritized Experience Replay for sparse reward 39 | 40 | self.rollout_num = 2 # the number of rollout workers (larger is not always faster) 41 | self.num_threads = 8 # cpu_num for evaluate model, torch.set_num_threads(self.num_threads) 42 | 43 | '''Arguments for evaluate''' 44 | self.break_step = 2 ** 20 # break training after 'total_step > break_step' 45 | self.if_remove = True # remove the cwd folder? (True, False, None:ask me) 46 | self.if_allow_break = True # allow break training when reach goal (early termination) 47 | self.eval_gap = 2 ** 5 # evaluate the agent per eval_gap seconds 48 | self.eval_times1 = 2 ** 2 # evaluation times 49 | self.eval_times2 = 2 ** 4 # evaluation times if 'eval_reward > max_reward' 50 | self.random_seed = 0 # initialize random seed in self.init_before_training() 51 | 52 | def init_before_training(self, if_main=True): 53 | if self.agent is None: 54 | raise RuntimeError('\n| Why agent=None? Assignment args.agent = AgentXXX please.') 55 | if not hasattr(self.agent, 'init'): 56 | raise RuntimeError('\n| There should be agent=AgentXXX() instead of agent=AgentXXX') 57 | if self.env is None: 58 | raise RuntimeError('\n| Why env=None? Assignment args.env = XxxEnv() please.') 59 | if isinstance(self.env, str) or not hasattr(self.env, 'env_name'): 60 | raise RuntimeError('\n| What is env.env_name? use env=PreprocessEnv(env). It is a Wrapper.') 61 | 62 | '''set gpu_id automatically''' 63 | if self.gpu_id is None: # set gpu_id automatically 64 | import sys 65 | self.gpu_id = sys.argv[-1][-4] 66 | else: 67 | self.gpu_id = str(self.gpu_id) 68 | if not self.gpu_id.isdigit(): # set gpu_id as '0' in default 69 | self.gpu_id = '0' 70 | 71 | '''set cwd automatically''' 72 | if self.cwd is None: 73 | # ---- 74 | agent_name = self.agent.__class__.__name__ 75 | # self.cwd = f'./{agent_name}/{self.env.env_name}_{self.gpu_id}' 76 | # self.cwd = f'./{agent_name}/{self.env.env_name}' 77 | self.cwd = f'./{config.WEIGHTS_PATH}/{self.env.env_name}' 78 | 79 | # model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \ 80 | # f'/single_{config.VALI_DAYS_FLAG}' 81 | # ---- 82 | 83 | if if_main: 84 | print(f'| GPU id: {self.gpu_id}, cwd: {self.cwd}') 85 | 86 | import shutil # remove history according to bool(if_remove) 87 | if self.if_remove is None: 88 | self.if_remove = bool(input("PRESS 'y' to REMOVE: {}? ".format(self.cwd)) == 'y') 89 | if self.if_remove: 90 | shutil.rmtree(self.cwd, ignore_errors=True) 91 | print("| Remove history") 92 | os.makedirs(self.cwd, exist_ok=True) 93 | 94 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id) 95 | torch.set_num_threads(self.num_threads) 96 | torch.set_default_dtype(torch.float32) 97 | torch.manual_seed(self.random_seed) 98 | np.random.seed(self.random_seed) 99 | 100 | 101 | '''single process training''' 102 | 103 | 104 | def train_and_evaluate(args): 105 | args.init_before_training() 106 | 107 | '''basic arguments''' 108 | cwd = args.cwd 109 | env = args.env 110 | agent = args.agent 111 | gpu_id = args.gpu_id # necessary for Evaluator? 112 | 113 | '''training arguments''' 114 | net_dim = args.net_dim 115 | max_memo = args.max_memo 116 | break_step = args.break_step 117 | batch_size = args.batch_size 118 | target_step = args.target_step 119 | repeat_times = args.repeat_times 120 | if_break_early = args.if_allow_break 121 | if_per = args.if_per 122 | gamma = args.gamma 123 | reward_scale = args.reward_scale 124 | 125 | '''evaluating arguments''' 126 | eval_gap = args.eval_gap 127 | eval_times1 = args.eval_times1 128 | eval_times2 = args.eval_times2 129 | if args.env_eval is not None: 130 | env_eval = args.env_eval 131 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()): 132 | env_eval = PreprocessEnv(gym.make(env.env_name)) 133 | else: 134 | env_eval = deepcopy(env) 135 | 136 | del args # In order to show these hyper-parameters clearly, I put them above. 137 | 138 | '''init: environment''' 139 | max_step = env.max_step 140 | state_dim = env.state_dim 141 | action_dim = env.action_dim 142 | if_discrete = env.if_discrete 143 | 144 | '''init: Agent, ReplayBuffer, Evaluator''' 145 | agent.init(net_dim, state_dim, action_dim, if_per) 146 | 147 | # ---- 148 | # 目录 path 149 | model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \ 150 | f'/single_{config.VALI_DAYS_FLAG}' 151 | # 文件 path 152 | model_file_path = f'{model_folder_path}/actor.pth' 153 | 154 | # 如果model存在,则加载 155 | if os.path.exists(model_file_path): 156 | agent.save_load_model(model_folder_path, if_save=False) 157 | pass 158 | # ---- 159 | 160 | if_on_policy = getattr(agent, 'if_on_policy', False) 161 | 162 | buffer = ReplayBuffer(max_len=max_memo + max_step, state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 163 | if_on_policy=if_on_policy, if_per=if_per, if_gpu=True) 164 | 165 | evaluator = Evaluator(cwd=cwd, agent_id=gpu_id, device=agent.device, env=env_eval, 166 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) 167 | 168 | '''prepare for training''' 169 | agent.state = env.reset() 170 | if if_on_policy: 171 | steps = 0 172 | else: # explore_before_training for off-policy 173 | with torch.no_grad(): # update replay buffer 174 | steps = explore_before_training(env, buffer, target_step, reward_scale, gamma) 175 | 176 | agent.update_net(buffer, target_step, batch_size, repeat_times) # pre-training and hard update 177 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None 178 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) else None 179 | total_step = steps 180 | 181 | '''start training''' 182 | if_reach_goal = False 183 | while not ((if_break_early and if_reach_goal) 184 | or total_step > break_step 185 | or os.path.exists(f'{cwd}/stop')): 186 | steps = agent.explore_env(env, buffer, target_step, reward_scale, gamma) 187 | total_step += steps 188 | 189 | obj_a, obj_c = agent.update_net(buffer, target_step, batch_size, repeat_times) 190 | 191 | if_reach_goal = evaluator.evaluate_save(agent.act, steps, obj_a, obj_c) 192 | evaluator.draw_plot() 193 | 194 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}') 195 | 196 | 197 | '''multiprocessing training''' 198 | 199 | 200 | def train_and_evaluate_mp(args): 201 | act_workers = args.rollout_num 202 | import multiprocessing as mp # Python built-in multiprocessing library 203 | 204 | pipe1_eva, pipe2_eva = mp.Pipe() # Pipe() for Process mp_evaluate_agent() 205 | pipe2_exp_list = list() # Pipe() for Process mp_explore_in_env() 206 | 207 | process_train = mp.Process(target=mp_train, args=(args, pipe2_eva, pipe2_exp_list)) 208 | process_evaluate = mp.Process(target=mp_evaluate, args=(args, pipe1_eva)) 209 | process = [process_train, process_evaluate] 210 | 211 | for worker_id in range(act_workers): 212 | exp_pipe1, exp_pipe2 = mp.Pipe(duplex=True) 213 | pipe2_exp_list.append(exp_pipe1) 214 | process.append(mp.Process(target=mp_explore, args=(args, exp_pipe2, worker_id))) 215 | 216 | [p.start() for p in process] 217 | process_evaluate.join() 218 | process_train.join() 219 | [p.terminate() for p in process] 220 | 221 | 222 | def mp_train(args, pipe1_eva, pipe1_exp_list): 223 | args.init_before_training(if_main=False) 224 | 225 | '''basic arguments''' 226 | env = args.env 227 | cwd = args.cwd 228 | agent = args.agent 229 | rollout_num = args.rollout_num 230 | 231 | '''training arguments''' 232 | net_dim = args.net_dim 233 | max_memo = args.max_memo 234 | break_step = args.break_step 235 | batch_size = args.batch_size 236 | target_step = args.target_step 237 | repeat_times = args.repeat_times 238 | if_break_early = args.if_allow_break 239 | if_per = args.if_per 240 | del args # In order to show these hyper-parameters clearly, I put them above. 241 | 242 | '''init: environment''' 243 | max_step = env.max_step 244 | state_dim = env.state_dim 245 | action_dim = env.action_dim 246 | if_discrete = env.if_discrete 247 | 248 | '''init: Agent, ReplayBuffer''' 249 | agent.init(net_dim, state_dim, action_dim, if_per) 250 | 251 | # ---- 252 | # 目录 path 253 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \ 254 | f'/single_{config.VALI_DAYS_FLAG}' 255 | 256 | # f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}/StockTradingEnv-v1' 257 | 258 | # 文件 path 259 | model_file_path = f'{model_folder_path}/actor.pth' 260 | 261 | # 如果model存在,则加载 262 | if os.path.exists(model_file_path): 263 | agent.save_load_model(model_folder_path, if_save=False) 264 | pass 265 | # ---- 266 | 267 | if_on_policy = getattr(agent, 'if_on_policy', False) 268 | 269 | '''send''' 270 | pipe1_eva.send(agent.act) # send 271 | # act = pipe2_eva.recv() # recv 272 | 273 | buffer_mp = ReplayBufferMP(max_len=max_memo + max_step * rollout_num, if_on_policy=if_on_policy, 274 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 275 | rollout_num=rollout_num, if_gpu=True, if_per=if_per) 276 | 277 | '''prepare for training''' 278 | if if_on_policy: 279 | steps = 0 280 | else: # explore_before_training for off-policy 281 | with torch.no_grad(): # update replay buffer 282 | steps = 0 283 | for i in range(rollout_num): 284 | pipe1_exp = pipe1_exp_list[i] 285 | 286 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 287 | buf_state, buf_other = pipe1_exp.recv() 288 | 289 | steps += len(buf_state) 290 | buffer_mp.extend_buffer(buf_state, buf_other, i) 291 | 292 | agent.update_net(buffer_mp, target_step, batch_size, repeat_times) # pre-training and hard update 293 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(env, 'act_target', None) else None 294 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(env, 'cri_target', None) in dir( 295 | agent) else None 296 | total_step = steps 297 | '''send''' 298 | pipe1_eva.send((agent.act, steps, 0, 0.5)) # send 299 | # act, steps, obj_a, obj_c = pipe2_eva.recv() # recv 300 | 301 | '''start training''' 302 | if_solve = False 303 | while not ((if_break_early and if_solve) 304 | or total_step > break_step 305 | or os.path.exists(f'{cwd}/stop')): 306 | '''update ReplayBuffer''' 307 | steps = 0 # send by pipe1_eva 308 | for i in range(rollout_num): 309 | pipe1_exp = pipe1_exp_list[i] 310 | '''send''' 311 | pipe1_exp.send(agent.act) 312 | # agent.act = pipe2_exp.recv() 313 | '''recv''' 314 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 315 | buf_state, buf_other = pipe1_exp.recv() 316 | 317 | steps += len(buf_state) 318 | buffer_mp.extend_buffer(buf_state, buf_other, i) 319 | total_step += steps 320 | 321 | '''update network parameters''' 322 | obj_a, obj_c = agent.update_net(buffer_mp, target_step, batch_size, repeat_times) 323 | 324 | '''saves the agent with max reward''' 325 | '''send''' 326 | pipe1_eva.send((agent.act, steps, obj_a, obj_c)) 327 | # q_i_eva_get = pipe2_eva.recv() 328 | 329 | if_solve = pipe1_eva.recv() 330 | 331 | if pipe1_eva.poll(): 332 | '''recv''' 333 | # pipe2_eva.send(if_solve) 334 | if_solve = pipe1_eva.recv() 335 | 336 | buffer_mp.print_state_norm(env.neg_state_avg if hasattr(env, 'neg_state_avg') else None, 337 | env.div_state_std if hasattr(env, 'div_state_std') else None) # 2020-12-12 338 | 339 | '''send''' 340 | pipe1_eva.send('stop') 341 | # q_i_eva_get = pipe2_eva.recv() 342 | time.sleep(4) 343 | 344 | 345 | def mp_explore(args, pipe2_exp, worker_id): 346 | args.init_before_training(if_main=False) 347 | 348 | '''basic arguments''' 349 | env = args.env 350 | agent = args.agent 351 | rollout_num = args.rollout_num 352 | 353 | '''training arguments''' 354 | net_dim = args.net_dim 355 | max_memo = args.max_memo 356 | target_step = args.target_step 357 | gamma = args.gamma 358 | if_per = args.if_per 359 | reward_scale = args.reward_scale 360 | 361 | random_seed = args.random_seed 362 | torch.manual_seed(random_seed + worker_id) 363 | np.random.seed(random_seed + worker_id) 364 | del args # In order to show these hyper-parameters clearly, I put them above. 365 | 366 | '''init: environment''' 367 | max_step = env.max_step 368 | state_dim = env.state_dim 369 | action_dim = env.action_dim 370 | if_discrete = env.if_discrete 371 | 372 | '''init: Agent, ReplayBuffer''' 373 | agent.init(net_dim, state_dim, action_dim, if_per) 374 | 375 | # ---- 376 | # 目录 path 377 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.SINGLE_A_STOCK_CODE[0]}' \ 378 | # f'/single_{config.VALI_DAYS_FLAG}' 379 | 380 | model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.SINGLE_A_STOCK_CODE[0]}' \ 381 | f'/single_{config.VALI_DAYS_FLAG}' 382 | 383 | # 文件 path 384 | model_file_path = f'{model_folder_path}/actor.pth' 385 | 386 | # 如果model存在,则加载 387 | if os.path.exists(model_file_path): 388 | agent.save_load_model(model_folder_path, if_save=False) 389 | pass 390 | # ---- 391 | 392 | agent.state = env.reset() 393 | 394 | if_on_policy = getattr(agent, 'if_on_policy', False) 395 | buffer = ReplayBuffer(max_len=max_memo // rollout_num + max_step, if_on_policy=if_on_policy, 396 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 397 | if_per=if_per, if_gpu=False) 398 | 399 | '''start exploring''' 400 | exp_step = target_step // rollout_num 401 | with torch.no_grad(): 402 | if not if_on_policy: 403 | explore_before_training(env, buffer, exp_step, reward_scale, gamma) 404 | 405 | buffer.update_now_len_before_sample() 406 | 407 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 408 | # buf_state, buf_other = pipe1_exp.recv() 409 | 410 | buffer.empty_buffer_before_explore() 411 | 412 | while True: 413 | agent.explore_env(env, buffer, exp_step, reward_scale, gamma) 414 | 415 | buffer.update_now_len_before_sample() 416 | 417 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 418 | # buf_state, buf_other = pipe1_exp.recv() 419 | 420 | buffer.empty_buffer_before_explore() 421 | 422 | # pipe1_exp.send(agent.act) 423 | agent.act = pipe2_exp.recv() 424 | 425 | 426 | def mp_evaluate(args, pipe2_eva): 427 | args.init_before_training(if_main=True) 428 | 429 | '''basic arguments''' 430 | cwd = args.cwd 431 | env = args.env 432 | env_eval = env if args.env_eval is None else args.env_eval 433 | agent_id = args.gpu_id 434 | 435 | '''evaluating arguments''' 436 | eval_gap = args.eval_gap 437 | eval_times1 = args.eval_times1 438 | eval_times2 = args.eval_times2 439 | del args # In order to show these hyper-parameters clearly, I put them above. 440 | 441 | '''init: Evaluator''' 442 | evaluator = Evaluator(cwd=cwd, agent_id=agent_id, device=torch.device("cpu"), env=env_eval, 443 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) # build Evaluator 444 | 445 | '''act_cpu without gradient for pipe1_eva''' 446 | # pipe1_eva.send(agent.act) 447 | act = pipe2_eva.recv() 448 | 449 | act_cpu = deepcopy(act).to(torch.device("cpu")) # for pipe1_eva 450 | [setattr(param, 'requires_grad', False) for param in act_cpu.parameters()] 451 | 452 | '''start evaluating''' 453 | with torch.no_grad(): # speed up running 454 | act, steps, obj_a, obj_c = pipe2_eva.recv() # pipe2_eva (act, steps, obj_a, obj_c) 455 | 456 | if_loop = True 457 | while if_loop: 458 | '''update actor''' 459 | while not pipe2_eva.poll(): # wait until pipe2_eva not empty 460 | time.sleep(1) 461 | steps_sum = 0 462 | while pipe2_eva.poll(): # receive the latest object from pipe 463 | '''recv''' 464 | # pipe1_eva.send((agent.act, steps, obj_a, obj_c)) 465 | # pipe1_eva.send('stop') 466 | q_i_eva_get = pipe2_eva.recv() 467 | 468 | if q_i_eva_get == 'stop': 469 | if_loop = False 470 | break 471 | act, steps, obj_a, obj_c = q_i_eva_get 472 | steps_sum += steps 473 | act_cpu.load_state_dict(act.state_dict()) 474 | if_solve = evaluator.evaluate_save(act_cpu, steps_sum, obj_a, obj_c) 475 | '''send''' 476 | pipe2_eva.send(if_solve) 477 | # if_solve = pipe1_eva.recv() 478 | 479 | evaluator.draw_plot() 480 | 481 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}') 482 | 483 | while pipe2_eva.poll(): # empty the pipe 484 | pipe2_eva.recv() 485 | 486 | 487 | '''utils''' 488 | 489 | 490 | class Evaluator: 491 | def __init__(self, cwd, agent_id, eval_times1, eval_times2, eval_gap, env, device): 492 | self.recorder = [(0., -np.inf, 0., 0., 0.), ] # total_step, r_avg, r_std, obj_a, obj_c 493 | self.r_max = -np.inf 494 | self.total_step = 0 495 | 496 | self.cwd = cwd # constant 497 | self.device = device 498 | self.agent_id = agent_id 499 | self.eval_gap = eval_gap 500 | self.eval_times1 = eval_times1 501 | self.eval_times2 = eval_times2 502 | self.env = env 503 | self.target_return = env.target_return 504 | 505 | self.used_time = None 506 | self.start_time = time.time() 507 | self.eval_time = -1 # a early time 508 | print(f"{'ID':>2} {'Step':>8} {'MaxR':>8} |" 509 | f"{'avgR':>8} {'stdR':>8} {'objA':>8} {'objC':>8} |" 510 | f"{'avgS':>6} {'stdS':>4}") 511 | 512 | def evaluate_save(self, act, steps, obj_a, obj_c) -> bool: 513 | self.total_step += steps # update total training steps 514 | 515 | if time.time() - self.eval_time > self.eval_gap: 516 | self.eval_time = time.time() 517 | 518 | rewards_steps_list = [get_episode_return(self.env, act, self.device) for _ in range(self.eval_times1)] 519 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list) 520 | 521 | if r_avg > self.r_max: # evaluate actor twice to save CPU Usage and keep precision 522 | rewards_steps_list += [get_episode_return(self.env, act, self.device) 523 | for _ in range(self.eval_times2 - self.eval_times1)] 524 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list) 525 | if r_avg > self.r_max: # save checkpoint with highest episode return 526 | self.r_max = r_avg # update max reward (episode return) 527 | 528 | '''save actor.pth''' 529 | act_save_path = f'{self.cwd}/actor.pth' 530 | torch.save(act.state_dict(), act_save_path) 531 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |") # save policy and print 532 | 533 | self.recorder.append((self.total_step, r_avg, r_std, obj_a, obj_c)) # update recorder 534 | 535 | if_reach_goal = bool(self.r_max > self.target_return) # check if_reach_goal 536 | if if_reach_goal and self.used_time is None: 537 | self.used_time = int(time.time() - self.start_time) 538 | print(f"{'ID':>2} {'Step':>8} {'TargetR':>8} |" 539 | f"{'avgR':>8} {'stdR':>8} {'UsedTime':>8} ########\n" 540 | f"{self.agent_id:<2} {self.total_step:8.2e} {self.target_return:8.2f} |" 541 | f"{r_avg:8.2f} {r_std:8.2f} {self.used_time:>8} ########") 542 | 543 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |" 544 | f"{r_avg:8.2f} {r_std:8.2f} {obj_a:8.2f} {obj_c:8.2f} |" 545 | f"{s_avg:6.0f} {s_std:4.0f}") 546 | else: 547 | if_reach_goal = False 548 | return if_reach_goal 549 | 550 | def draw_plot(self): 551 | if len(self.recorder) == 0: 552 | print("| save_npy_draw_plot() WARNNING: len(self.recorder)==0") 553 | return None 554 | 555 | '''convert to array and save as npy''' 556 | np.save('%s/recorder.npy' % self.cwd, self.recorder) 557 | 558 | '''draw plot and save as png''' 559 | train_time = int(time.time() - self.start_time) 560 | total_step = int(self.recorder[-1][0]) 561 | save_title = f"plot_step_time_maxR_{int(total_step)}_{int(train_time)}_{self.r_max:.3f}" 562 | 563 | save_learning_curve(self.recorder, self.cwd, save_title) 564 | 565 | @staticmethod 566 | def get_r_avg_std_s_avg_std(rewards_steps_list): 567 | rewards_steps_ary = np.array(rewards_steps_list) 568 | r_avg, s_avg = rewards_steps_ary.mean(axis=0) # average of episode return and episode step 569 | r_std, s_std = rewards_steps_ary.std(axis=0) # standard dev. of episode return and episode step 570 | return r_avg, r_std, s_avg, s_std 571 | 572 | 573 | def get_episode_return(env, act, device) -> (float, int): 574 | episode_return = 0.0 # sum of rewards in an episode 575 | episode_step = 1 576 | max_step = env.max_step 577 | if_discrete = env.if_discrete 578 | 579 | state = env.reset() 580 | for episode_step in range(max_step): 581 | s_tensor = torch.as_tensor((state,), device=device) 582 | a_tensor = act(s_tensor) 583 | if if_discrete: 584 | a_tensor = a_tensor.argmax(dim=1) 585 | action = a_tensor.detach().cpu().numpy()[0] # not need detach(), because with torch.no_grad() outside 586 | state, reward, done, _ = env.step(action) 587 | episode_return += reward 588 | if done: 589 | break 590 | episode_return = getattr(env, 'episode_return', episode_return) 591 | return episode_return, episode_step + 1 592 | 593 | 594 | def save_learning_curve(recorder, cwd='.', save_title='learning curve'): 595 | recorder = np.array(recorder) # recorder_ary.append((self.total_step, r_avg, r_std, obj_a, obj_c)) 596 | steps = recorder[:, 0] # x-axis is training steps 597 | r_avg = recorder[:, 1] 598 | r_std = recorder[:, 2] 599 | obj_a = recorder[:, 3] 600 | obj_c = recorder[:, 4] 601 | 602 | '''plot subplots''' 603 | import matplotlib as mpl 604 | mpl.use('Agg') 605 | """Generating matplotlib graphs without a running X server [duplicate] 606 | write `mpl.use('Agg')` before `import matplotlib.pyplot as plt` 607 | https://stackoverflow.com/a/4935945/9293137 608 | """ 609 | import matplotlib.pyplot as plt 610 | fig, axs = plt.subplots(2) 611 | 612 | axs0 = axs[0] 613 | axs0.cla() 614 | color0 = 'lightcoral' 615 | axs0.set_xlabel('Total Steps') 616 | axs0.set_ylabel('Episode Return') 617 | axs0.plot(steps, r_avg, label='Episode Return', color=color0) 618 | axs0.fill_between(steps, r_avg - r_std, r_avg + r_std, facecolor=color0, alpha=0.3) 619 | 620 | ax11 = axs[1] 621 | ax11.cla() 622 | color11 = 'royalblue' 623 | axs0.set_xlabel('Total Steps') 624 | ax11.set_ylabel('objA', color=color11) 625 | ax11.plot(steps, obj_a, label='objA', color=color11) 626 | ax11.tick_params(axis='y', labelcolor=color11) 627 | 628 | ax12 = axs[1].twinx() 629 | color12 = 'darkcyan' 630 | ax12.set_ylabel('objC', color=color12) 631 | ax12.fill_between(steps, obj_c, facecolor=color12, alpha=0.2, ) 632 | ax12.tick_params(axis='y', labelcolor=color12) 633 | 634 | '''plot save''' 635 | plt.title(save_title, y=2.3) 636 | plt.savefig(f"{cwd}/plot_learning_curve.jpg") 637 | plt.close('all') # avoiding warning about too many open figures, rcParam `figure.max_open_warning` 638 | # plt.show() # if use `mpl.use('Agg')` to draw figures without GUI, then plt can't plt.show() 639 | 640 | 641 | def explore_before_training(env, buffer, target_step, reward_scale, gamma) -> int: 642 | # just for off-policy. Because on-policy don't explore before training. 643 | if_discrete = env.if_discrete 644 | action_dim = env.action_dim 645 | 646 | state = env.reset() 647 | steps = 0 648 | 649 | while steps < target_step: 650 | action = rd.randint(action_dim) if if_discrete else rd.uniform(-1, 1, size=action_dim) 651 | next_state, reward, done, _ = env.step(action) 652 | steps += 1 653 | 654 | scaled_reward = reward * reward_scale 655 | mask = 0.0 if done else gamma 656 | other = (scaled_reward, mask, action) if if_discrete else (scaled_reward, mask, *action) 657 | buffer.append_buffer(state, other) 658 | 659 | state = env.reset() if done else next_state 660 | return steps 661 | -------------------------------------------------------------------------------- /env_single.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pandas as pd 4 | import numpy as np 5 | 6 | import config 7 | from train_helper import insert_train_history_record_sqlite 8 | 9 | 10 | class StockTradingEnvSingle: 11 | def __init__(self, cwd='./envs/FinRL', gamma=0.99, 12 | max_stock=1e2, initial_capital=1e6, buy_cost_pct=1e-3, sell_cost_pct=1e-3, 13 | start_date='2008-03-19', end_date='2016-01-01', env_eval_date='2021-01-01', 14 | ticker_list=None, tech_indicator_list=None, initial_stocks=None, if_eval=False): 15 | 16 | self.price_ary, self.tech_ary, self.tic_ary, self.date_ary = self.load_data(cwd, if_eval, ticker_list, 17 | tech_indicator_list, 18 | start_date, end_date, 19 | env_eval_date, ) 20 | stock_dim = self.price_ary.shape[1] 21 | 22 | self.gamma = gamma 23 | self.max_stock = max_stock 24 | self.buy_cost_pct = buy_cost_pct 25 | self.sell_cost_pct = sell_cost_pct 26 | self.initial_capital = initial_capital 27 | self.initial_stocks = np.zeros(stock_dim, dtype=np.float32) if initial_stocks is None else initial_stocks 28 | 29 | # reset() 30 | self.day = None 31 | self.amount = None 32 | self.stocks = None 33 | self.total_asset = None 34 | self.initial_total_asset = None 35 | self.gamma_reward = 0.0 36 | 37 | # environment information 38 | self.env_name = 'StockTradingEnv-v1' 39 | self.state_dim = 1 + 2 * stock_dim + self.tech_ary.shape[1] 40 | self.action_dim = stock_dim 41 | self.max_step = len(self.price_ary) - 1 42 | self.if_discrete = False 43 | self.target_return = 3.5 44 | self.episode_return = 0.0 45 | 46 | # 输出的缓存 47 | self.output_text_trade_detail = '' 48 | 49 | # 输出的list 50 | self.list_buy_or_sell_output = [] 51 | 52 | # 奖励 比例 53 | self.reward_scaling = 0.0 54 | 55 | # 是 eval 还是 train 56 | self.if_eval = if_eval 57 | 58 | pass 59 | 60 | def reset(self): 61 | self.day = 0 62 | price = self.price_ary[self.day] 63 | 64 | # ---- 65 | np.random.seed(round(time.time())) 66 | random_float = np.random.uniform(0.0, 1.01, size=self.initial_stocks.shape) 67 | 68 | # 如果是正式预测,输出到网页,固定 持股数和现金 69 | if config.IF_ACTUAL_PREDICT is True: 70 | self.stocks = self.initial_stocks.copy() 71 | self.amount = self.initial_capital - (self.stocks * price).sum() 72 | pass 73 | else: 74 | # 如果是train过程中的eval 75 | self.stocks = random_float * self.initial_stocks.copy() // 100 * 100 76 | self.amount = self.initial_capital * np.random.uniform(0.95, 1.05) - (self.stocks * price).sum() 77 | pass 78 | pass 79 | 80 | self.total_asset = self.amount + (self.stocks * price).sum() 81 | self.initial_total_asset = self.total_asset 82 | self.gamma_reward = 0.0 83 | 84 | state = np.hstack((self.amount * 2 ** -12, 85 | price, 86 | self.stocks * 2 ** -4, 87 | self.tech_ary[self.day],)).astype(np.float32) * 2 ** -8 88 | 89 | # 清空输出的缓存 90 | self.output_text_trade_detail = '' 91 | 92 | # 输出的list 93 | self.list_buy_or_sell_output.clear() 94 | self.list_buy_or_sell_output = [] 95 | 96 | return state 97 | 98 | def step(self, actions): 99 | actions_temp = (actions * self.max_stock).astype(int) 100 | 101 | # ---- 102 | yesterday_price = self.price_ary[self.day] 103 | # ---- 104 | 105 | self.day += 1 106 | 107 | price = self.price_ary[self.day] 108 | 109 | tic_ary_temp = self.tic_ary[self.day] 110 | # 日期 111 | date_ary_temp = self.date_ary[self.day] 112 | date_temp = date_ary_temp[0] 113 | 114 | self.output_text_trade_detail += f'第 {self.day + 1} 天,{date_temp}\r\n' 115 | 116 | for index in np.where(actions_temp < 0)[0]: # sell_index: 117 | if price[index] > 0: # Sell only if current asset is > 0 118 | 119 | sell_num_shares = min(self.stocks[index], -actions_temp[index]) 120 | 121 | tic_temp = tic_ary_temp[index] 122 | 123 | if sell_num_shares >= 100: 124 | # 若 action <= -100 地板除,卖1手整 125 | sell_num_shares = sell_num_shares // 100 * 100 126 | self.stocks[index] -= sell_num_shares 127 | self.amount += price[index] * sell_num_shares * (1 - self.sell_cost_pct) 128 | 129 | if self.if_eval is True: 130 | # tic, date, sell/buy, hold, 第x天 131 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 132 | 133 | list_item = (tic_temp, date_temp, -1 * sell_num_shares, self.stocks[index], self.day + 1, 134 | episode_return_temp) 135 | # 添加到输出list 136 | self.list_buy_or_sell_output.append(list_item) 137 | pass 138 | else: 139 | # 当sell_num_shares < 100时,判断若 self.stocks[index] >= 100 则放大效果,卖1手 140 | if self.stocks[index] >= 100: 141 | sell_num_shares = 100 142 | self.stocks[index] -= sell_num_shares 143 | self.amount += price[index] * sell_num_shares * (1 - self.sell_cost_pct) 144 | 145 | if self.if_eval is True: 146 | # tic, date, sell/buy, hold, 第x天 147 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 148 | 149 | list_item = (tic_temp, date_temp, -1 * sell_num_shares, self.stocks[index], self.day + 1, 150 | episode_return_temp) 151 | # 添加到输出list 152 | self.list_buy_or_sell_output.append(list_item) 153 | pass 154 | else: 155 | # self.stocks[index] 不足1手时,不动 156 | sell_num_shares = 0 157 | 158 | if self.if_eval is True: 159 | # tic, date, sell/buy, hold, 第x天 160 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 161 | 162 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp) 163 | # 添加到输出list 164 | self.list_buy_or_sell_output.append(list_item) 165 | pass 166 | pass 167 | pass 168 | 169 | if yesterday_price[index] != 0: 170 | price_diff_percent = str(round((price[index] - yesterday_price[index]) / yesterday_price[index], 4)) 171 | else: 172 | price_diff_percent = '0.0' 173 | pass 174 | 175 | price_diff = str(round(price[index] - yesterday_price[index], 6)) 176 | self.output_text_trade_detail += f' > {tic_temp},预测涨跌:{round(-1 * actions[index], 4)},' \ 177 | f'实际涨跌:{price_diff_percent} ¥{price_diff} 元,' \ 178 | f'卖出:{sell_num_shares} 股, 持股数量 {self.stocks[index]},' \ 179 | f'现金:{self.amount},资产:{self.total_asset} \r\n' 180 | pass 181 | pass 182 | 183 | for index in np.where(actions_temp > 0)[0]: # buy_index: 184 | if price[index] > 0: # Buy only if the price is > 0 (no missing data in this particular date) 185 | buy_num_shares = min(self.amount // price[index], actions_temp[index]) 186 | 187 | tic_temp = tic_ary_temp[index] 188 | 189 | if buy_num_shares >= 100: 190 | # 若 actions >= +100,地板除,买1手整 191 | buy_num_shares = buy_num_shares // 100 * 100 192 | self.stocks[index] += buy_num_shares 193 | self.amount -= price[index] * buy_num_shares * (1 + self.buy_cost_pct) 194 | 195 | if self.if_eval is True: 196 | # tic, date, sell/buy, hold, 第x天 197 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 198 | 199 | list_item = (tic_temp, date_temp, buy_num_shares, self.stocks[index], self.day + 1, 200 | episode_return_temp) 201 | 202 | # 添加到输出list 203 | self.list_buy_or_sell_output.append(list_item) 204 | pass 205 | else: 206 | # 当buy_num_shares < 100时,判断若 self.amount // price[index] >= 100,则放大效果,买1手 207 | if (self.amount // price[index]) >= 100: 208 | buy_num_shares = 100 209 | self.stocks[index] += buy_num_shares 210 | self.amount -= price[index] * buy_num_shares * (1 + self.buy_cost_pct) 211 | 212 | if self.if_eval is True: 213 | # tic, date, sell/buy, hold, 第x天 214 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 215 | 216 | list_item = (tic_temp, date_temp, buy_num_shares, self.stocks[index], self.day + 1, 217 | episode_return_temp) 218 | 219 | # 添加到输出list 220 | self.list_buy_or_sell_output.append(list_item) 221 | else: 222 | # self.amount // price[index] 不足100时,不动 223 | # 未达到1手,不买 224 | buy_num_shares = 0 225 | 226 | if self.if_eval is True: 227 | # tic, date, sell/buy, hold, 第x天 228 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 229 | 230 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp) 231 | # 添加到输出list 232 | self.list_buy_or_sell_output.append(list_item) 233 | pass 234 | pass 235 | pass 236 | 237 | if yesterday_price[index] != 0: 238 | price_diff_percent = str(round((price[index] - yesterday_price[index]) / yesterday_price[index], 4)) 239 | else: 240 | price_diff_percent = '0.0' 241 | pass 242 | 243 | price_diff = str(round(price[index] - yesterday_price[index], 6)) 244 | self.output_text_trade_detail += f' > {tic_temp},预测涨跌:{round(-1 * actions[index], 4)},' \ 245 | f'实际涨跌:{price_diff_percent} ¥{price_diff} 元,' \ 246 | f'买入:{buy_num_shares} 股, 持股数量:{self.stocks[index]},' \ 247 | f'现金:{self.amount},资产:{self.total_asset} \r\n' 248 | 249 | pass 250 | pass 251 | 252 | if self.if_eval is True: 253 | for index in np.where(actions_temp == 0)[0]: # sell_index: 254 | if price[index] > 0: # Buy only if the price is > 0 (no missing data in this particular date) 255 | # tic, date, sell/buy, hold, 第x天 256 | tic_temp = tic_ary_temp[index] 257 | episode_return_temp = (self.amount + (self.stocks * price).sum()) / self.initial_total_asset 258 | 259 | list_item = (tic_temp, date_temp, 0, self.stocks[index], self.day + 1, episode_return_temp) 260 | # 添加到输出list 261 | self.list_buy_or_sell_output.append(list_item) 262 | pass 263 | pass 264 | pass 265 | pass 266 | 267 | state = np.hstack((self.amount * 2 ** -12, 268 | price, 269 | self.stocks * 2 ** -4, 270 | self.tech_ary[self.day],)).astype(np.float32) * 2 ** -8 271 | 272 | total_asset = self.amount + (self.stocks * price).sum() 273 | reward = (total_asset - self.total_asset) * self.reward_scaling 274 | 275 | self.total_asset = total_asset 276 | 277 | self.gamma_reward = self.gamma_reward * self.gamma + reward 278 | done = self.day == self.max_step 279 | 280 | if done: 281 | reward = self.gamma_reward 282 | self.episode_return = total_asset / self.initial_total_asset 283 | # if self.if_eval is True: 284 | # print(config.AGENT_NAME, 'eval DONE reward:', str(reward)) 285 | # pass 286 | # else: 287 | # print(config.AGENT_NAME, 'train DONE reward:', str(reward)) 288 | # pass 289 | # pass 290 | 291 | if config.IF_ACTUAL_PREDICT: 292 | print(self.output_text_trade_detail) 293 | print(f'第 {self.day + 1} 天,{date_temp},现金:{self.amount},' 294 | f'股票:{str((self.stocks * price).sum())},总资产:{self.total_asset}') 295 | else: 296 | # if self.if_eval is True: 297 | # print(config.AGENT_NAME, 'eval reward', str(reward)) 298 | # else: 299 | # print(config.AGENT_NAME, 'train reward', str(reward)) 300 | # pass 301 | 302 | if reward > config.REWARD_THRESHOLD: 303 | # 如果是 预测 304 | if self.if_eval is True: 305 | insert_train_history_record_sqlite(model_id=config.MODEL_HYPER_PARAMETERS, train_reward_value=0.0, 306 | eval_reward_value=reward) 307 | 308 | print('>>>>', config.AGENT_NAME, 'eval reward', str(reward)) 309 | else: 310 | # 如果是 train 311 | insert_train_history_record_sqlite(model_id=config.MODEL_HYPER_PARAMETERS, 312 | train_reward_value=reward, eval_reward_value=0.0) 313 | 314 | print(config.AGENT_NAME, 'train reward', str(reward)) 315 | pass 316 | 317 | pass 318 | pass 319 | 320 | return state, reward, done, dict() 321 | 322 | def load_data(self, cwd='./envs/FinRL', if_eval=None, 323 | ticker_list=None, tech_indicator_list=None, 324 | start_date='2008-03-19', end_date='2016-01-01', env_eval_date='2021-01-01'): 325 | 326 | # 从数据库中读取fe fillzero的数据 327 | from stock_data import StockData 328 | processed_df = StockData.get_fe_fillzero_from_sqlite(begin_date=start_date, end_date=env_eval_date, 329 | list_stock_code=config.SINGLE_A_STOCK_CODE, 330 | if_actual_predict=config.IF_ACTUAL_PREDICT, 331 | table_name='fe_fillzero') 332 | 333 | def data_split_train(df, start, end): 334 | data = df[(df.date >= start) & (df.date < end)] 335 | data = data.sort_values(["date", "tic"], ignore_index=True) 336 | data.index = data.date.factorize()[0] 337 | return data 338 | 339 | def data_split_eval(df, start, end): 340 | data = df[(df.date >= start) & (df.date <= end)] 341 | data = data.sort_values(["date", "tic"], ignore_index=True) 342 | data.index = data.date.factorize()[0] 343 | return data 344 | 345 | train_df = data_split_train(processed_df, start_date, end_date) 346 | eval_df = data_split_eval(processed_df, end_date, env_eval_date) 347 | 348 | train_price_ary, train_tech_ary, train_tic_ary, train_date_ary = self.convert_df_to_ary(train_df, 349 | tech_indicator_list) 350 | eval_price_ary, eval_tech_ary, eval_tic_ary, eval_date_ary = self.convert_df_to_ary(eval_df, 351 | tech_indicator_list) 352 | 353 | if if_eval is None: 354 | price_ary = np.concatenate((train_price_ary, eval_price_ary), axis=0) 355 | tech_ary = np.concatenate((train_tech_ary, eval_tech_ary), axis=0) 356 | tic_ary = None 357 | date_ary = None 358 | elif if_eval: 359 | price_ary = eval_price_ary 360 | tech_ary = eval_tech_ary 361 | tic_ary = eval_tic_ary 362 | date_ary = eval_date_ary 363 | else: 364 | price_ary = train_price_ary 365 | tech_ary = train_tech_ary 366 | tic_ary = train_tic_ary 367 | date_ary = train_date_ary 368 | 369 | return price_ary, tech_ary, tic_ary, date_ary 370 | 371 | @staticmethod 372 | def convert_df_to_ary(df, tech_indicator_list): 373 | tech_ary = list() 374 | price_ary = list() 375 | tic_ary = list() 376 | date_ary = list() 377 | 378 | from stock_data import fields_prep 379 | columns_list = fields_prep.split(',') 380 | 381 | for day in range(len(df.index.unique())): 382 | # item = df.loc[day] 383 | list_temp = df.loc[day] 384 | if list_temp.ndim == 1: 385 | list_temp = [df.loc[day]] 386 | pass 387 | item = pd.DataFrame(data=list_temp, columns=columns_list) 388 | 389 | tech_items = [item[tech].values.tolist() for tech in tech_indicator_list] 390 | tech_items_flatten = sum(tech_items, []) 391 | tech_ary.append(tech_items_flatten) 392 | price_ary.append(item.close) # adjusted close price (adjcp) 393 | 394 | # ---- 395 | # tic_ary.append(list(item.tic)) 396 | # date_ary.append(list(item.date)) 397 | 398 | tic_ary.append(item.tic) 399 | date_ary.append(item.date) 400 | 401 | # ---- 402 | 403 | pass 404 | 405 | price_ary = np.array(price_ary) 406 | tech_ary = np.array(tech_ary) 407 | 408 | tic_ary = np.array(tic_ary) 409 | date_ary = np.array(date_ary) 410 | 411 | print(f'| price_ary.shape: {price_ary.shape}, tech_ary.shape: {tech_ary.shape}') 412 | return price_ary, tech_ary, tic_ary, date_ary 413 | 414 | def draw_cumulative_return(self, args, _torch) -> list: 415 | state_dim = self.state_dim 416 | action_dim = self.action_dim 417 | 418 | agent = args.agent 419 | net_dim = args.net_dim 420 | cwd = args.cwd 421 | 422 | agent.init(net_dim, state_dim, action_dim) 423 | agent.save_load_model(cwd=cwd, if_save=False) 424 | act = agent.act 425 | device = agent.device 426 | 427 | state = self.reset() 428 | episode_returns = list() # the cumulative_return / initial_account 429 | with _torch.no_grad(): 430 | for i in range(self.max_step): 431 | s_tensor = _torch.as_tensor((state,), device=device) 432 | a_tensor = act(s_tensor) 433 | action = a_tensor.cpu().numpy()[0] # not need detach(), because with torch.no_grad() outside 434 | state, reward, done, _ = self.step(action) 435 | 436 | total_asset = self.amount + (self.price_ary[self.day] * self.stocks).sum() 437 | episode_return = total_asset / self.initial_total_asset 438 | episode_returns.append(episode_return) 439 | if done: 440 | break 441 | 442 | import matplotlib.pyplot as plt 443 | plt.plot(episode_returns) 444 | plt.grid() 445 | plt.title('cumulative return') 446 | plt.xlabel('day') 447 | plt.xlabel('multiple of initial_account') 448 | plt.savefig(f'{cwd}/cumulative_return.jpg') 449 | return episode_returns 450 | 451 | 452 | class FeatureEngineer: 453 | """Provides methods for preprocessing the stock price data 454 | from finrl.preprocessing.preprocessors import FeatureEngineer 455 | 456 | Attributes 457 | ---------- 458 | use_technical_indicator : boolean 459 | we technical indicator or not 460 | tech_indicator_list : list 461 | a list of technical indicator names (modified from config.py) 462 | use_turbulence : boolean 463 | use turbulence index or not 464 | user_defined_feature:boolean 465 | user user defined features or not 466 | 467 | Methods 468 | ------- 469 | preprocess_data() 470 | main method to do the feature engineering 471 | 472 | """ 473 | 474 | def __init__( 475 | self, 476 | use_technical_indicator=True, 477 | tech_indicator_list=None, # config.TECHNICAL_INDICATORS_LIST, 478 | use_turbulence=False, 479 | user_defined_feature=False, 480 | ): 481 | self.use_technical_indicator = use_technical_indicator 482 | self.tech_indicator_list = tech_indicator_list 483 | self.use_turbulence = use_turbulence 484 | self.user_defined_feature = user_defined_feature 485 | 486 | def preprocess_data(self, df): 487 | """main method to do the feature engineering 488 | @:param config: source dataframe 489 | @:return: a DataMatrices object 490 | """ 491 | 492 | if self.use_technical_indicator: 493 | # add technical indicators using stockstats 494 | df = self.add_technical_indicator(df) 495 | print("Successfully added technical indicators") 496 | 497 | # add turbulence index for multiple stock 498 | if self.use_turbulence: 499 | df = self.add_turbulence(df) 500 | print("Successfully added turbulence index") 501 | 502 | # add user defined feature 503 | if self.user_defined_feature: 504 | df = self.add_user_defined_feature(df) 505 | print("Successfully added user defined features") 506 | 507 | # fill the missing values at the beginning and the end 508 | df = df.fillna(method="bfill").fillna(method="ffill") 509 | return df 510 | 511 | def add_technical_indicator(self, data): 512 | """ 513 | calculate technical indicators 514 | use stockstats package to add technical inidactors 515 | :param data: (df) pandas dataframe 516 | :return: (df) pandas dataframe 517 | """ 518 | from stockstats import StockDataFrame as Sdf # for Sdf.retype 519 | 520 | df = data.copy() 521 | df = df.sort_values(by=['tic', 'date']) 522 | stock = Sdf.retype(df.copy()) 523 | unique_ticker = stock.tic.unique() 524 | 525 | for indicator in self.tech_indicator_list: 526 | indicator_df = pd.DataFrame() 527 | for i in range(len(unique_ticker)): 528 | try: 529 | temp_indicator = stock[stock.tic == unique_ticker[i]][indicator] 530 | temp_indicator = pd.DataFrame(temp_indicator) 531 | temp_indicator['tic'] = unique_ticker[i] 532 | temp_indicator['date'] = df[df.tic == unique_ticker[i]]['date'].to_list() 533 | indicator_df = indicator_df.append( 534 | temp_indicator, ignore_index=True 535 | ) 536 | except Exception as e: 537 | print(e) 538 | df = df.merge(indicator_df[['tic', 'date', indicator]], on=['tic', 'date'], how='left') 539 | df = df.sort_values(by=['date', 'tic']) 540 | return df 541 | 542 | def add_turbulence(self, data): 543 | """ 544 | add turbulence index from a precalcualted dataframe 545 | :param data: (df) pandas dataframe 546 | :return: (df) pandas dataframe 547 | """ 548 | df = data.copy() 549 | turbulence_index = self.calculate_turbulence(df) 550 | df = df.merge(turbulence_index, on="date") 551 | df = df.sort_values(["date", "tic"]).reset_index(drop=True) 552 | return df 553 | 554 | @staticmethod 555 | def add_user_defined_feature(data): 556 | """ 557 | add user defined features 558 | :param data: (df) pandas dataframe 559 | :return: (df) pandas dataframe 560 | """ 561 | df = data.copy() 562 | df["daily_return"] = df.close.pct_change(1) 563 | # df['return_lag_1']=df.close.pct_change(2) 564 | # df['return_lag_2']=df.close.pct_change(3) 565 | # df['return_lag_3']=df.close.pct_change(4) 566 | # df['return_lag_4']=df.close.pct_change(5) 567 | return df 568 | 569 | @staticmethod 570 | def calculate_turbulence(data): 571 | """calculate turbulence index based on dow 30""" 572 | # can add other market assets 573 | df = data.copy() 574 | df_price_pivot = df.pivot(index="date", columns="tic", values="close") 575 | # use returns to calculate turbulence 576 | df_price_pivot = df_price_pivot.pct_change() 577 | 578 | unique_date = df.date.unique() 579 | # start after a year 580 | start = 252 581 | turbulence_index = [0] * start 582 | # turbulence_index = [0] 583 | count = 0 584 | for i in range(start, len(unique_date)): 585 | current_price = df_price_pivot[df_price_pivot.index == unique_date[i]] 586 | # use one year rolling window to calcualte covariance 587 | hist_price = df_price_pivot[ 588 | (df_price_pivot.index < unique_date[i]) 589 | & (df_price_pivot.index >= unique_date[i - 252]) 590 | ] 591 | # Drop tickers which has number missing values more than the "oldest" ticker 592 | filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1) 593 | 594 | cov_temp = filtered_hist_price.cov() 595 | current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0) 596 | temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot( 597 | current_temp.values.T 598 | ) 599 | if temp > 0: 600 | count += 1 601 | if count > 2: 602 | turbulence_temp = temp[0][0] 603 | else: 604 | # avoid large outlier because of the calculation just begins 605 | turbulence_temp = 0 606 | else: 607 | turbulence_temp = 0 608 | turbulence_index.append(turbulence_temp) 609 | 610 | turbulence_index = pd.DataFrame( 611 | {"date": df_price_pivot.index, "turbulence": turbulence_index} 612 | ) 613 | return turbulence_index 614 | -------------------------------------------------------------------------------- /run_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import time 4 | import torch 5 | import numpy as np 6 | import numpy.random as rd 7 | from copy import deepcopy 8 | from ElegantRL_master.elegantrl.replay import ReplayBuffer, ReplayBufferMP 9 | from ElegantRL_master.elegantrl.env import PreprocessEnv 10 | import config 11 | 12 | """[ElegantRL](https://github.com/AI4Finance-LLC/ElegantRL)""" 13 | 14 | 15 | class Arguments: 16 | def __init__(self, agent=None, env=None, gpu_id=None, if_on_policy=False): 17 | self.agent = agent # Deep Reinforcement Learning algorithm 18 | 19 | self.cwd = None # current work directory. cwd is None means set it automatically 20 | self.env = env # the environment for training 21 | self.env_eval = None # the environment for evaluating 22 | self.gpu_id = gpu_id # choose the GPU for running. gpu_id is None means set it automatically 23 | 24 | '''Arguments for training (off-policy)''' 25 | self.net_dim = 2 ** 8 # the network width 26 | self.batch_size = 2 ** 8 # num of transitions sampled from replay buffer. 27 | self.repeat_times = 2 ** 0 # repeatedly update network to keep critic's loss small 28 | self.target_step = 2 ** 10 # collect target_step, then update network 29 | self.max_memo = 2 ** 17 # capacity of replay buffer 30 | if if_on_policy: # (on-policy) 31 | self.net_dim = 2 ** 9 32 | self.batch_size = 2 ** 9 33 | self.repeat_times = 2 ** 4 34 | self.target_step = 2 ** 12 35 | self.max_memo = self.target_step 36 | self.gamma = 0.99 # discount factor of future rewards 37 | self.reward_scale = 2 ** 0 # an approximate target reward usually be closed to 256 38 | self.if_per = False # Prioritized Experience Replay for sparse reward 39 | 40 | self.rollout_num = 2 # the number of rollout workers (larger is not always faster) 41 | self.num_threads = 8 # cpu_num for evaluate model, torch.set_num_threads(self.num_threads) 42 | 43 | '''Arguments for evaluate''' 44 | self.break_step = 2 ** 20 # break training after 'total_step > break_step' 45 | self.if_remove = True # remove the cwd folder? (True, False, None:ask me) 46 | self.if_allow_break = True # allow break training when reach goal (early termination) 47 | self.eval_gap = 2 ** 5 # evaluate the agent per eval_gap seconds 48 | self.eval_times1 = 2 ** 2 # evaluation times 49 | self.eval_times2 = 2 ** 4 # evaluation times if 'eval_reward > max_reward' 50 | self.random_seed = 0 # initialize random seed in self.init_before_training() 51 | 52 | def init_before_training(self, if_main=True): 53 | if self.agent is None: 54 | raise RuntimeError('\n| Why agent=None? Assignment args.agent = AgentXXX please.') 55 | if not hasattr(self.agent, 'init'): 56 | raise RuntimeError('\n| There should be agent=AgentXXX() instead of agent=AgentXXX') 57 | if self.env is None: 58 | raise RuntimeError('\n| Why env=None? Assignment args.env = XxxEnv() please.') 59 | if isinstance(self.env, str) or not hasattr(self.env, 'env_name'): 60 | raise RuntimeError('\n| What is env.env_name? use env=PreprocessEnv(env). It is a Wrapper.') 61 | 62 | '''set gpu_id automatically''' 63 | if self.gpu_id is None: # set gpu_id automatically 64 | import sys 65 | self.gpu_id = sys.argv[-1][-4] 66 | else: 67 | self.gpu_id = str(self.gpu_id) 68 | if not self.gpu_id.isdigit(): # set gpu_id as '0' in default 69 | self.gpu_id = '0' 70 | 71 | '''set cwd automatically''' 72 | if self.cwd is None: 73 | # ---- 74 | agent_name = self.agent.__class__.__name__ 75 | # self.cwd = f'./{agent_name}/{self.env.env_name}_{self.gpu_id}' 76 | # self.cwd = f'./{agent_name}/{self.env.env_name}' 77 | self.cwd = f'./{config.WEIGHTS_PATH}/{self.env.env_name}' 78 | 79 | # model_folder_path = f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.BATCH_A_STOCK_CODE[0]}' \ 80 | # f'/single_{config.VALI_DAYS_FLAG}' 81 | # ---- 82 | 83 | if if_main: 84 | print(f'| GPU id: {self.gpu_id}, cwd: {self.cwd}') 85 | 86 | import shutil # remove history according to bool(if_remove) 87 | if self.if_remove is None: 88 | self.if_remove = bool(input("PRESS 'y' to REMOVE: {}? ".format(self.cwd)) == 'y') 89 | if self.if_remove: 90 | shutil.rmtree(self.cwd, ignore_errors=True) 91 | print("| Remove history") 92 | os.makedirs(self.cwd, exist_ok=True) 93 | 94 | os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id) 95 | torch.set_num_threads(self.num_threads) 96 | torch.set_default_dtype(torch.float32) 97 | torch.manual_seed(self.random_seed) 98 | np.random.seed(self.random_seed) 99 | 100 | 101 | '''single process training''' 102 | 103 | 104 | def train_and_evaluate(args): 105 | args.init_before_training() 106 | 107 | '''basic arguments''' 108 | cwd = args.cwd 109 | env = args.env 110 | agent = args.agent 111 | gpu_id = args.gpu_id # necessary for Evaluator? 112 | 113 | '''training arguments''' 114 | net_dim = args.net_dim 115 | max_memo = args.max_memo 116 | break_step = args.break_step 117 | batch_size = args.batch_size 118 | target_step = args.target_step 119 | repeat_times = args.repeat_times 120 | if_break_early = args.if_allow_break 121 | if_per = args.if_per 122 | gamma = args.gamma 123 | reward_scale = args.reward_scale 124 | 125 | '''evaluating arguments''' 126 | eval_gap = args.eval_gap 127 | eval_times1 = args.eval_times1 128 | eval_times2 = args.eval_times2 129 | if args.env_eval is not None: 130 | env_eval = args.env_eval 131 | elif args.env_eval in set(gym.envs.registry.env_specs.keys()): 132 | env_eval = PreprocessEnv(gym.make(env.env_name)) 133 | else: 134 | env_eval = deepcopy(env) 135 | 136 | del args # In order to show these hyper-parameters clearly, I put them above. 137 | 138 | '''init: environment''' 139 | max_step = env.max_step 140 | state_dim = env.state_dim 141 | action_dim = env.action_dim 142 | if_discrete = env.if_discrete 143 | 144 | '''init: Agent, ReplayBuffer, Evaluator''' 145 | agent.init(net_dim, state_dim, action_dim, if_per) 146 | 147 | # ---- 148 | # 目录 path 149 | # model_folder_path = f'./{config.AGENT_NAME}/batch/{config.BATCH_A_STOCK_CODE[0]}' \ 150 | # f'/batch_{config.VALI_DAYS_FLAG}' 151 | 152 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 153 | f'batch_{config.VALI_DAYS_FLAG}' 154 | 155 | # 文件 path 156 | model_file_path = f'{model_folder_path}/actor.pth' 157 | 158 | # 如果model存在,则加载 159 | if os.path.exists(model_file_path): 160 | agent.save_load_model(model_folder_path, if_save=False) 161 | pass 162 | # ---- 163 | 164 | if_on_policy = getattr(agent, 'if_on_policy', False) 165 | 166 | buffer = ReplayBuffer(max_len=max_memo + max_step, state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 167 | if_on_policy=if_on_policy, if_per=if_per, if_gpu=True) 168 | 169 | evaluator = Evaluator(cwd=cwd, agent_id=gpu_id, device=agent.device, env=env_eval, 170 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) 171 | 172 | '''prepare for training''' 173 | agent.state = env.reset() 174 | if if_on_policy: 175 | steps = 0 176 | else: # explore_before_training for off-policy 177 | with torch.no_grad(): # update replay buffer 178 | steps = explore_before_training(env, buffer, target_step, reward_scale, gamma) 179 | 180 | agent.update_net(buffer, target_step, batch_size, repeat_times) # pre-training and hard update 181 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(agent, 'act_target', None) else None 182 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(agent, 'cri_target', None) else None 183 | total_step = steps 184 | 185 | '''start training''' 186 | if_reach_goal = False 187 | while not ((if_break_early and if_reach_goal) 188 | or total_step > break_step 189 | or os.path.exists(f'{cwd}/stop')): 190 | steps = agent.explore_env(env, buffer, target_step, reward_scale, gamma) 191 | total_step += steps 192 | 193 | obj_a, obj_c = agent.update_net(buffer, target_step, batch_size, repeat_times) 194 | 195 | if_reach_goal = evaluator.evaluate_save(agent.act, steps, obj_a, obj_c) 196 | evaluator.draw_plot() 197 | 198 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}') 199 | 200 | 201 | '''multiprocessing training''' 202 | 203 | 204 | def train_and_evaluate_mp(args): 205 | act_workers = args.rollout_num 206 | import multiprocessing as mp # Python built-in multiprocessing library 207 | 208 | pipe1_eva, pipe2_eva = mp.Pipe() # Pipe() for Process mp_evaluate_agent() 209 | pipe2_exp_list = list() # Pipe() for Process mp_explore_in_env() 210 | 211 | process_train = mp.Process(target=mp_train, args=(args, pipe2_eva, pipe2_exp_list)) 212 | process_evaluate = mp.Process(target=mp_evaluate, args=(args, pipe1_eva)) 213 | process = [process_train, process_evaluate] 214 | 215 | for worker_id in range(act_workers): 216 | exp_pipe1, exp_pipe2 = mp.Pipe(duplex=True) 217 | pipe2_exp_list.append(exp_pipe1) 218 | process.append(mp.Process(target=mp_explore, args=(args, exp_pipe2, worker_id))) 219 | 220 | [p.start() for p in process] 221 | process_evaluate.join() 222 | process_train.join() 223 | [p.terminate() for p in process] 224 | 225 | 226 | def mp_train(args, pipe1_eva, pipe1_exp_list): 227 | args.init_before_training(if_main=False) 228 | 229 | '''basic arguments''' 230 | env = args.env 231 | cwd = args.cwd 232 | agent = args.agent 233 | rollout_num = args.rollout_num 234 | 235 | '''training arguments''' 236 | net_dim = args.net_dim 237 | max_memo = args.max_memo 238 | break_step = args.break_step 239 | batch_size = args.batch_size 240 | target_step = args.target_step 241 | repeat_times = args.repeat_times 242 | if_break_early = args.if_allow_break 243 | if_per = args.if_per 244 | del args # In order to show these hyper-parameters clearly, I put them above. 245 | 246 | '''init: environment''' 247 | max_step = env.max_step 248 | state_dim = env.state_dim 249 | action_dim = env.action_dim 250 | if_discrete = env.if_discrete 251 | 252 | '''init: Agent, ReplayBuffer''' 253 | agent.init(net_dim, state_dim, action_dim, if_per) 254 | 255 | # ---- 256 | # 目录 path 257 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 258 | f'batch_{config.VALI_DAYS_FLAG}' 259 | 260 | # f'./{config.WEIGHTS_PATH}/single/{config.AGENT_NAME}/{config.BATCH_A_STOCK_CODE[0]}/StockTradingEnv-v1' 261 | 262 | # 文件 path 263 | model_file_path = f'{model_folder_path}/actor.pth' 264 | 265 | # 如果model存在,则加载 266 | if os.path.exists(model_file_path): 267 | agent.save_load_model(model_folder_path, if_save=False) 268 | pass 269 | # ---- 270 | 271 | if_on_policy = getattr(agent, 'if_on_policy', False) 272 | 273 | '''send''' 274 | pipe1_eva.send(agent.act) # send 275 | # act = pipe2_eva.recv() # recv 276 | 277 | buffer_mp = ReplayBufferMP(max_len=max_memo + max_step * rollout_num, if_on_policy=if_on_policy, 278 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 279 | rollout_num=rollout_num, if_gpu=True, if_per=if_per) 280 | 281 | '''prepare for training''' 282 | if if_on_policy: 283 | steps = 0 284 | else: # explore_before_training for off-policy 285 | with torch.no_grad(): # update replay buffer 286 | steps = 0 287 | for i in range(rollout_num): 288 | pipe1_exp = pipe1_exp_list[i] 289 | 290 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 291 | buf_state, buf_other = pipe1_exp.recv() 292 | 293 | steps += len(buf_state) 294 | buffer_mp.extend_buffer(buf_state, buf_other, i) 295 | 296 | agent.update_net(buffer_mp, target_step, batch_size, repeat_times) # pre-training and hard update 297 | agent.act_target.load_state_dict(agent.act.state_dict()) if getattr(env, 'act_target', None) else None 298 | agent.cri_target.load_state_dict(agent.cri.state_dict()) if getattr(env, 'cri_target', None) in dir( 299 | agent) else None 300 | total_step = steps 301 | '''send''' 302 | pipe1_eva.send((agent.act, steps, 0, 0.5)) # send 303 | # act, steps, obj_a, obj_c = pipe2_eva.recv() # recv 304 | 305 | '''start training''' 306 | if_solve = False 307 | while not ((if_break_early and if_solve) 308 | or total_step > break_step 309 | or os.path.exists(f'{cwd}/stop')): 310 | '''update ReplayBuffer''' 311 | steps = 0 # send by pipe1_eva 312 | for i in range(rollout_num): 313 | pipe1_exp = pipe1_exp_list[i] 314 | '''send''' 315 | pipe1_exp.send(agent.act) 316 | # agent.act = pipe2_exp.recv() 317 | '''recv''' 318 | # pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 319 | buf_state, buf_other = pipe1_exp.recv() 320 | 321 | steps += len(buf_state) 322 | buffer_mp.extend_buffer(buf_state, buf_other, i) 323 | total_step += steps 324 | 325 | '''update network parameters''' 326 | obj_a, obj_c = agent.update_net(buffer_mp, target_step, batch_size, repeat_times) 327 | 328 | '''saves the agent with max reward''' 329 | '''send''' 330 | pipe1_eva.send((agent.act, steps, obj_a, obj_c)) 331 | # q_i_eva_get = pipe2_eva.recv() 332 | 333 | if_solve = pipe1_eva.recv() 334 | 335 | if pipe1_eva.poll(): 336 | '''recv''' 337 | # pipe2_eva.send(if_solve) 338 | if_solve = pipe1_eva.recv() 339 | 340 | buffer_mp.print_state_norm(env.neg_state_avg if hasattr(env, 'neg_state_avg') else None, 341 | env.div_state_std if hasattr(env, 'div_state_std') else None) # 2020-12-12 342 | 343 | '''send''' 344 | pipe1_eva.send('stop') 345 | # q_i_eva_get = pipe2_eva.recv() 346 | time.sleep(4) 347 | 348 | 349 | def mp_explore(args, pipe2_exp, worker_id): 350 | args.init_before_training(if_main=False) 351 | 352 | '''basic arguments''' 353 | env = args.env 354 | agent = args.agent 355 | rollout_num = args.rollout_num 356 | 357 | '''training arguments''' 358 | net_dim = args.net_dim 359 | max_memo = args.max_memo 360 | target_step = args.target_step 361 | gamma = args.gamma 362 | if_per = args.if_per 363 | reward_scale = args.reward_scale 364 | 365 | random_seed = args.random_seed 366 | torch.manual_seed(random_seed + worker_id) 367 | np.random.seed(random_seed + worker_id) 368 | del args # In order to show these hyper-parameters clearly, I put them above. 369 | 370 | '''init: environment''' 371 | max_step = env.max_step 372 | state_dim = env.state_dim 373 | action_dim = env.action_dim 374 | if_discrete = env.if_discrete 375 | 376 | '''init: Agent, ReplayBuffer''' 377 | agent.init(net_dim, state_dim, action_dim, if_per) 378 | 379 | # ---- 380 | # 目录 path 381 | # model_folder_path = f'./{config.AGENT_NAME}/single/{config.BATCH_A_STOCK_CODE[0]}' \ 382 | # f'/single_{config.VALI_DAYS_FLAG}' 383 | 384 | model_folder_path = f'./{config.WEIGHTS_PATH}/batch/{config.AGENT_NAME}/' \ 385 | f'batch_{config.VALI_DAYS_FLAG}' 386 | 387 | # 文件 path 388 | model_file_path = f'{model_folder_path}/actor.pth' 389 | 390 | # 如果model存在,则加载 391 | if os.path.exists(model_file_path): 392 | agent.save_load_model(model_folder_path, if_save=False) 393 | pass 394 | # ---- 395 | 396 | agent.state = env.reset() 397 | 398 | if_on_policy = getattr(agent, 'if_on_policy', False) 399 | buffer = ReplayBuffer(max_len=max_memo // rollout_num + max_step, if_on_policy=if_on_policy, 400 | state_dim=state_dim, action_dim=1 if if_discrete else action_dim, 401 | if_per=if_per, if_gpu=False) 402 | 403 | '''start exploring''' 404 | exp_step = target_step // rollout_num 405 | with torch.no_grad(): 406 | if not if_on_policy: 407 | explore_before_training(env, buffer, exp_step, reward_scale, gamma) 408 | 409 | buffer.update_now_len_before_sample() 410 | 411 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 412 | # buf_state, buf_other = pipe1_exp.recv() 413 | 414 | buffer.empty_buffer_before_explore() 415 | 416 | while True: 417 | agent.explore_env(env, buffer, exp_step, reward_scale, gamma) 418 | 419 | buffer.update_now_len_before_sample() 420 | 421 | pipe2_exp.send((buffer.buf_state[:buffer.now_len], buffer.buf_other[:buffer.now_len])) 422 | # buf_state, buf_other = pipe1_exp.recv() 423 | 424 | buffer.empty_buffer_before_explore() 425 | 426 | # pipe1_exp.send(agent.act) 427 | agent.act = pipe2_exp.recv() 428 | 429 | 430 | def mp_evaluate(args, pipe2_eva): 431 | args.init_before_training(if_main=True) 432 | 433 | '''basic arguments''' 434 | cwd = args.cwd 435 | env = args.env 436 | env_eval = env if args.env_eval is None else args.env_eval 437 | agent_id = args.gpu_id 438 | 439 | '''evaluating arguments''' 440 | eval_gap = args.eval_gap 441 | eval_times1 = args.eval_times1 442 | eval_times2 = args.eval_times2 443 | del args # In order to show these hyper-parameters clearly, I put them above. 444 | 445 | '''init: Evaluator''' 446 | evaluator = Evaluator(cwd=cwd, agent_id=agent_id, device=torch.device("cpu"), env=env_eval, 447 | eval_gap=eval_gap, eval_times1=eval_times1, eval_times2=eval_times2, ) # build Evaluator 448 | 449 | '''act_cpu without gradient for pipe1_eva''' 450 | # pipe1_eva.send(agent.act) 451 | act = pipe2_eva.recv() 452 | 453 | act_cpu = deepcopy(act).to(torch.device("cpu")) # for pipe1_eva 454 | [setattr(param, 'requires_grad', False) for param in act_cpu.parameters()] 455 | 456 | '''start evaluating''' 457 | with torch.no_grad(): # speed up running 458 | act, steps, obj_a, obj_c = pipe2_eva.recv() # pipe2_eva (act, steps, obj_a, obj_c) 459 | 460 | if_loop = True 461 | while if_loop: 462 | '''update actor''' 463 | while not pipe2_eva.poll(): # wait until pipe2_eva not empty 464 | time.sleep(1) 465 | steps_sum = 0 466 | while pipe2_eva.poll(): # receive the latest object from pipe 467 | '''recv''' 468 | # pipe1_eva.send((agent.act, steps, obj_a, obj_c)) 469 | # pipe1_eva.send('stop') 470 | q_i_eva_get = pipe2_eva.recv() 471 | 472 | if q_i_eva_get == 'stop': 473 | if_loop = False 474 | break 475 | act, steps, obj_a, obj_c = q_i_eva_get 476 | steps_sum += steps 477 | act_cpu.load_state_dict(act.state_dict()) 478 | if_solve = evaluator.evaluate_save(act_cpu, steps_sum, obj_a, obj_c) 479 | '''send''' 480 | pipe2_eva.send(if_solve) 481 | # if_solve = pipe1_eva.recv() 482 | 483 | evaluator.draw_plot() 484 | 485 | print(f'| SavedDir: {cwd}\n| UsedTime: {time.time() - evaluator.start_time:.0f}') 486 | 487 | while pipe2_eva.poll(): # empty the pipe 488 | pipe2_eva.recv() 489 | 490 | 491 | '''utils''' 492 | 493 | 494 | class Evaluator: 495 | def __init__(self, cwd, agent_id, eval_times1, eval_times2, eval_gap, env, device): 496 | self.recorder = [(0., -np.inf, 0., 0., 0.), ] # total_step, r_avg, r_std, obj_a, obj_c 497 | self.r_max = -np.inf 498 | self.total_step = 0 499 | 500 | self.cwd = cwd # constant 501 | self.device = device 502 | self.agent_id = agent_id 503 | self.eval_gap = eval_gap 504 | self.eval_times1 = eval_times1 505 | self.eval_times2 = eval_times2 506 | self.env = env 507 | self.target_return = env.target_return 508 | 509 | self.used_time = None 510 | self.start_time = time.time() 511 | self.eval_time = -1 # a early time 512 | print(f"{'ID':>2} {'Step':>8} {'MaxR':>8} |" 513 | f"{'avgR':>8} {'stdR':>8} {'objA':>8} {'objC':>8} |" 514 | f"{'avgS':>6} {'stdS':>4}") 515 | 516 | def evaluate_save(self, act, steps, obj_a, obj_c) -> bool: 517 | self.total_step += steps # update total training steps 518 | 519 | if time.time() - self.eval_time > self.eval_gap: 520 | self.eval_time = time.time() 521 | 522 | rewards_steps_list = [get_episode_return(self.env, act, self.device) for _ in range(self.eval_times1)] 523 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list) 524 | 525 | if r_avg > self.r_max: # evaluate actor twice to save CPU Usage and keep precision 526 | rewards_steps_list += [get_episode_return(self.env, act, self.device) 527 | for _ in range(self.eval_times2 - self.eval_times1)] 528 | r_avg, r_std, s_avg, s_std = self.get_r_avg_std_s_avg_std(rewards_steps_list) 529 | if r_avg > self.r_max: # save checkpoint with highest episode return 530 | self.r_max = r_avg # update max reward (episode return) 531 | 532 | '''save actor.pth''' 533 | act_save_path = f'{self.cwd}/actor.pth' 534 | torch.save(act.state_dict(), act_save_path) 535 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |") # save policy and print 536 | 537 | self.recorder.append((self.total_step, r_avg, r_std, obj_a, obj_c)) # update recorder 538 | 539 | if_reach_goal = bool(self.r_max > self.target_return) # check if_reach_goal 540 | if if_reach_goal and self.used_time is None: 541 | self.used_time = int(time.time() - self.start_time) 542 | print(f"{'ID':>2} {'Step':>8} {'TargetR':>8} |" 543 | f"{'avgR':>8} {'stdR':>8} {'UsedTime':>8} ########\n" 544 | f"{self.agent_id:<2} {self.total_step:8.2e} {self.target_return:8.2f} |" 545 | f"{r_avg:8.2f} {r_std:8.2f} {self.used_time:>8} ########") 546 | 547 | print(f"{self.agent_id:<2} {self.total_step:8.2e} {self.r_max:8.2f} |" 548 | f"{r_avg:8.2f} {r_std:8.2f} {obj_a:8.2f} {obj_c:8.2f} |" 549 | f"{s_avg:6.0f} {s_std:4.0f}") 550 | else: 551 | if_reach_goal = False 552 | return if_reach_goal 553 | 554 | def draw_plot(self): 555 | if len(self.recorder) == 0: 556 | print("| save_npy_draw_plot() WARNNING: len(self.recorder)==0") 557 | return None 558 | 559 | '''convert to array and save as npy''' 560 | np.save('%s/recorder.npy' % self.cwd, self.recorder) 561 | 562 | '''draw plot and save as png''' 563 | train_time = int(time.time() - self.start_time) 564 | total_step = int(self.recorder[-1][0]) 565 | save_title = f"plot_step_time_maxR_{int(total_step)}_{int(train_time)}_{self.r_max:.3f}" 566 | 567 | save_learning_curve(self.recorder, self.cwd, save_title) 568 | 569 | @staticmethod 570 | def get_r_avg_std_s_avg_std(rewards_steps_list): 571 | rewards_steps_ary = np.array(rewards_steps_list) 572 | r_avg, s_avg = rewards_steps_ary.mean(axis=0) # average of episode return and episode step 573 | r_std, s_std = rewards_steps_ary.std(axis=0) # standard dev. of episode return and episode step 574 | return r_avg, r_std, s_avg, s_std 575 | 576 | 577 | def get_episode_return(env, act, device) -> (float, int): 578 | episode_return = 0.0 # sum of rewards in an episode 579 | episode_step = 1 580 | max_step = env.max_step 581 | if_discrete = env.if_discrete 582 | 583 | state = env.reset() 584 | for episode_step in range(max_step): 585 | s_tensor = torch.as_tensor((state,), device=device) 586 | a_tensor = act(s_tensor) 587 | if if_discrete: 588 | a_tensor = a_tensor.argmax(dim=1) 589 | action = a_tensor.detach().cpu().numpy()[0] # not need detach(), because with torch.no_grad() outside 590 | state, reward, done, _ = env.step(action) 591 | episode_return += reward 592 | if done: 593 | break 594 | episode_return = getattr(env, 'episode_return', episode_return) 595 | return episode_return, episode_step + 1 596 | 597 | 598 | def save_learning_curve(recorder, cwd='.', save_title='learning curve'): 599 | recorder = np.array(recorder) # recorder_ary.append((self.total_step, r_avg, r_std, obj_a, obj_c)) 600 | steps = recorder[:, 0] # x-axis is training steps 601 | r_avg = recorder[:, 1] 602 | r_std = recorder[:, 2] 603 | obj_a = recorder[:, 3] 604 | obj_c = recorder[:, 4] 605 | 606 | '''plot subplots''' 607 | import matplotlib as mpl 608 | mpl.use('Agg') 609 | """Generating matplotlib graphs without a running X server [duplicate] 610 | write `mpl.use('Agg')` before `import matplotlib.pyplot as plt` 611 | https://stackoverflow.com/a/4935945/9293137 612 | """ 613 | import matplotlib.pyplot as plt 614 | fig, axs = plt.subplots(2) 615 | 616 | axs0 = axs[0] 617 | axs0.cla() 618 | color0 = 'lightcoral' 619 | axs0.set_xlabel('Total Steps') 620 | axs0.set_ylabel('Episode Return') 621 | axs0.plot(steps, r_avg, label='Episode Return', color=color0) 622 | axs0.fill_between(steps, r_avg - r_std, r_avg + r_std, facecolor=color0, alpha=0.3) 623 | 624 | ax11 = axs[1] 625 | ax11.cla() 626 | color11 = 'royalblue' 627 | axs0.set_xlabel('Total Steps') 628 | ax11.set_ylabel('objA', color=color11) 629 | ax11.plot(steps, obj_a, label='objA', color=color11) 630 | ax11.tick_params(axis='y', labelcolor=color11) 631 | 632 | ax12 = axs[1].twinx() 633 | color12 = 'darkcyan' 634 | ax12.set_ylabel('objC', color=color12) 635 | ax12.fill_between(steps, obj_c, facecolor=color12, alpha=0.2, ) 636 | ax12.tick_params(axis='y', labelcolor=color12) 637 | 638 | '''plot save''' 639 | plt.title(save_title, y=2.3) 640 | plt.savefig(f"{cwd}/plot_learning_curve.jpg") 641 | plt.close('all') # avoiding warning about too many open figures, rcParam `figure.max_open_warning` 642 | # plt.show() # if use `mpl.use('Agg')` to draw figures without GUI, then plt can't plt.show() 643 | 644 | 645 | def explore_before_training(env, buffer, target_step, reward_scale, gamma) -> int: 646 | # just for off-policy. Because on-policy don't explore before training. 647 | if_discrete = env.if_discrete 648 | action_dim = env.action_dim 649 | 650 | state = env.reset() 651 | steps = 0 652 | 653 | while steps < target_step: 654 | action = rd.randint(action_dim) if if_discrete else rd.uniform(-1, 1, size=action_dim) 655 | next_state, reward, done, _ = env.step(action) 656 | steps += 1 657 | 658 | scaled_reward = reward * reward_scale 659 | mask = 0.0 if done else gamma 660 | other = (scaled_reward, mask, action) if if_discrete else (scaled_reward, mask, *action) 661 | buffer.append_buffer(state, other) 662 | 663 | state = env.reset() if done else next_state 664 | return steps 665 | --------------------------------------------------------------------------------