├── .gitignore ├── LICENSE ├── README.md ├── abat ├── __init__.py ├── backend │ ├── __init__.py │ └── orm.py ├── common.py ├── event.py ├── md.py ├── strategy.py ├── trade.py └── utils │ ├── __init__.py │ ├── db_utils.py │ ├── fh_utils.py │ └── redis.py ├── agent ├── __init__.py ├── md_agent.py └── td_agent.py ├── analysis ├── __init__.py └── account.py ├── backend ├── __init__.py ├── check.py └── orm.py ├── config.py ├── mass ├── alipay_code200.png ├── dashang_code200.png └── webchat_code200.png ├── requirements.txt ├── run.py ├── run_all.sh ├── strategy ├── __init__.py ├── bs_against_files │ ├── __init__.py │ ├── csv_orders.py │ └── csv_orders_with_feedback.py ├── file_strategy.py └── simple_strategy.py └── trader.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | /strategy/file_order/ 107 | *.csv 108 | file_order/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ABAT_trader_4_blockchain 2 | Auto Backtest Analysis Trade Framework (简称ABAT)支持期货、数字货币进行量化交易,集成回测、分析、交易于一体。与市场上同类回测框架相比,有如下优势: 3 | 4 | - 更加完备地支持多品种、多周期策略交易 5 | - 对跨周期策略回测更加真实,回测模式下,将不同周期数据按时间排序推送,从而激活对应周期的响应函数 6 | - 框架采用分布式部署架构,行情推送与主框架可分离部署,通过redis进行数据广播 7 | - 未来将可以支持股票、期货、数字货币多种接口,同时交易 8 | 9 | 主要组件: 10 | 11 | - 行情代理 md agent 12 | - 交易代理 trade agent 13 | - 行情推送 md feeder 14 | - 回测及实时行情交易框架 ABAT 15 | 16 | 当前项目主要用于数字货币的自动化交易及策略分析。 17 | 18 | 由于精力有限,目前暂未实现前端展现。交易自动执行,可通过日志查看交易执行情况。 19 | 20 | ## 安装 21 | 22 | 系统环境要求: 23 | > Python 3.6 24 | > 25 | > MySQL 5.7 (具体配置见下文) 26 | > 27 | > Redis 3.0.6 28 | 29 | ## 配置
30 | 31 | > config.py 配置文件 32 | 33 | ## 策略运行示例 34 | 35 | strategy 目录下 36 | - simple_strategy.py 实现简单均线交叉策略回测\ 37 | - file_strategy.py 调仓文件导入式的交易(实时行情)\ 38 | 39 | 其他策略 coming soon 40 | 41 | ## 策略研发框架 42 | 43 | 执行策略 44 | 45 | ```python 46 | # 参数设置 47 | strategy_params = {} 48 | md_agent_params_list = [ 49 | # { 50 | # 'name': 'min1', 51 | # 'md_period': PeriodType.Min1, 52 | # 'instrument_id_list': ['rb1805', 'i1801'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 53 | # 'init_md_date_to': '2017-9-1', 54 | # 'dict_or_df_as_param': dict 55 | # }, 56 | { 57 | 'name': 'tick', 58 | 'md_period': PeriodType.Tick, 59 | 'instrument_id_list': ['ethusdt', 'eosusdt'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 60 | }] 61 | run_mode_realtime_params = { 62 | 'run_mode': RunMode.Realtime, 63 | 'enable_timer_thread': True, 64 | 'seconds_of_timer_interval': 15, 65 | } 66 | run_mode_backtest_params = { 67 | 'run_mode': RunMode.Backtest, 68 | 'date_from': '2017-9-4', 69 | 'date_to': '2017-9-27', 70 | 'init_cash': 1000000, 71 | 'trade_mode': BacktestTradeMode.Order_2_Deal 72 | } 73 | # 初始化策略处理器 74 | stghandler = StgHandlerBase.factory( 75 | stg_class_obj=ReadFileStg, 76 | strategy_params=strategy_params, 77 | md_agent_params_list=md_agent_params_list, 78 | **run_mode_realtime_params) 79 | # 开始执行策略 80 | stghandler.start() 81 | # 策略执行 2 分钟后关闭 82 | time.sleep(120) 83 | stghandler.keep_running = False 84 | stghandler.join() 85 | logging.info("执行结束") 86 | ``` 87 | 88 | StgHandlerBase.factory 为工厂方法,用于生成策略执行对象实例 89 | 90 | ```Python 91 | def factory(stg_class_obj: StgBase.__class__, strategy_params, md_agent_params_list, run_mode: RunMode, **run_mode_params): 92 | """ 93 | 建立策略对象 94 | 建立数据库相应记录信息 95 | 根据运行模式(实时、回测):选择相应的md_agent以及trade_agent 96 | :param stg_class_obj: 策略类型 StgBase 的子类 97 | :param strategy_params: 策略参数,策略对象初始化时传入使用 98 | :param md_agent_params_list: 行情代理(md_agent)参数,支持同时订阅多周期、多品种,例如同时订阅 [ethusdt, eosusdt] 1min 行情、[btcusdt, ethbtc] tick 行情 99 | :param run_mode: 运行模式
RunMode.Realtime 或 RunMode.Backtest 100 | :param run_mode_params: 运行参数, 101 | :return: 策略处理对象实例 102 | """ 103 | ``` 104 | 105 | ## 欢迎赞助 106 | 107 | #### 微信 108 | 109 | ![微信支付](https://github.com/mmmaaaggg/ABAT_trader_4_blockchain/blob/master/mass/webchat_code200.png?raw=true) 110 | 111 | #### 支付宝 112 | 113 | ![支付宝](https://github.com/mmmaaaggg/ABAT_trader_4_blockchain/blob/master/mass/alipay_code200.png?raw=true) 114 | 115 | #### 微信打赏(¥10) 116 | 117 | ![微信打赏](https://github.com/mmmaaaggg/ABAT_trader_4_blockchain/blob/master/mass/dashang_code200.png?raw=true) 118 | 119 | ## MySQL 配置方法 120 | 121 | 1. Ubuntu 18.04 环境下安装 MySQL 5.7 122 | 123 | ```bash 124 | sudo apt install mysql-server 125 | ``` 126 | 2. 默认安装过程中没有设置用户名密码的步骤,因此安装完成后需要手动重置 root 密码,方法如下: 127 | 128 | ```bash 129 | cd /etc/mysql 130 | sudo more debian.cnf 131 | ``` 132 | 会看到类似如下内容: 133 | ```bash 134 | # Automatically generated for Debian scripts. DO NOT TOUCH! 135 | [client] 136 | host = localhost 137 | user = debian-sys-maint 138 | password = j1bsABuuDRGKCV5s 139 | socket = /var/run/mysqld/mysqld.sock 140 | [mysql_upgrade] 141 | host = localhost 142 | user = debian-sys-maint 143 | password = j1bsABuuDRGKCV5s 144 | socket = /var/run/mysqld/mysqld.sock 145 | ``` 146 | 147 | 以debian-sys-maint为用户名登录,密码就是debian.cnf里 password = 后面的内容。 148 | 使用mysql -u debian-sys-maint -p 进行登录。 149 | 进入mysql之后修改root用户的密码,具体命令如下: 150 | ```mysql 151 | use mysql; 152 | 153 | update user set authentication_string=PASSWORD("Dcba4321") where user='root'; 154 | 155 | update user set plugin="mysql_native_password" where user='root'; 156 | 157 | flush privileges; 158 | ``` 159 | 3. 然后就可以通过root用户登录了 160 | 161 | ```bash 162 | mysql -uroot -p 163 | ``` 164 | 165 | 4. 创建用户 mg 默认密码 Abcd1234 166 | 167 | ```mysql 168 | CREATE USER 'mg'@'%' IDENTIFIED BY 'Abcd1234'; 169 | ``` 170 | 5. 创建数据库 abat 171 | 172 | ```mysql 173 | CREATE DATABASE `abat` default charset utf8 collate utf8_general_ci; 174 | ``` 175 | 6.
授权 176 | 177 | ```mysql 178 | grant all privileges on abat.* to 'mg'@'localhost' identified by 'Abcd1234'; 179 | 180 | flush privileges; #刷新系统权限表 181 | ``` 182 | -------------------------------------------------------------------------------- /abat/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:10 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | 11 | -------------------------------------------------------------------------------- /abat/backend/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:16 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from sqlalchemy import create_engine 11 | from config import Config 12 | engines = {key: create_engine(url) for key, url in Config.DB_URL_DIC.items()} 13 | # engine_md = engines[Config.DB_SCHEMA_MD] 14 | engine_abat = engines[Config.DB_SCHEMA_ABAT] 15 | -------------------------------------------------------------------------------- /abat/backend/orm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/10/9 4 | @author: MG 5 | """ 6 | from datetime import datetime, timedelta 7 | from sqlalchemy import MetaData, Table, Column, Integer, String, DateTime, Float, Boolean, SmallInteger, Date, Time 8 | from sqlalchemy.dialects.mysql import DOUBLE 9 | from sqlalchemy.orm import mapper 10 | from sqlalchemy.ext.declarative import declarative_base 11 | from sqlalchemy.sql import func 12 | from pandas import Timedelta 13 | from config import Config 14 | from abat.backend import engine_abat 15 | from abat.utils.db_utils import with_db_session 16 | from abat.common import Action, 
Direction 17 | from abat.utils.fh_utils import str_2_date, pd_timedelta_2_timedelta 18 | import logging 19 | 20 | BaseModel = declarative_base() 21 | # 每一次实务均产生数据库插入或更新动作(默认:否) 22 | UPDATE_OR_INSERT_PER_ACTION = False 23 | 24 | 25 | class StgRunInfo(BaseModel): 26 | """策略运行信息""" 27 | 28 | __tablename__ = 'stg_run_info' 29 | stg_run_id = Column(Integer, autoincrement=True, primary_key=True) 30 | stg_name = Column(String(200)) 31 | dt_from = Column(DateTime()) 32 | dt_to = Column(DateTime()) 33 | stg_params = Column(String(2000)) 34 | md_agent_params_list = Column(String(2000)) 35 | run_mode = Column(SmallInteger) 36 | run_mode_params = Column(String(2000)) 37 | 38 | 39 | class OrderInfo(BaseModel): 40 | """订单信息""" 41 | 42 | __tablename__ = 'order_info' 43 | order_id = Column(Integer, primary_key=True, autoincrement=True) 44 | stg_run_id = Column(Integer) 45 | # order_dt = Column(DateTime, server_default=func.now()) 46 | order_date = Column(Date) # 对应行情数据中 ActionDate 47 | order_time = Column(Time) # 对应行情数据中 ActionTime 48 | order_millisec = Column(Integer) # 对应行情数据中 ActionMillisec 49 | direction = Column(Boolean) # 0:空;1:多 50 | action = Column(Integer) # 0:关:1:开 51 | instrument_id = Column(String(30)) 52 | order_price = Column(DOUBLE) 53 | order_vol = Column(DOUBLE) # 订单量 54 | traded_vol = Column(DOUBLE, server_default='0') # 保证金 , comment="成交数量" 55 | margin = Column(DOUBLE, server_default='0') # 保证金 , comment="占用保证金" 56 | 57 | # 每一次实务均产生数据库插入或更新动作(默认:否) 58 | 59 | def __repr__(self): 60 | return "".format( 61 | self) 62 | 63 | @staticmethod 64 | def remove_order_info(stg_run_id: int): 65 | """ 66 | 仅作为调试工具使用,删除指定 stg_run_id 相关的 order_info 67 | :param stg_run_id: 68 | :return: 69 | """ 70 | with with_db_session(engine_abat) as session: 71 | session.execute('DELETE FROM order_info WHERE stg_run_id=:stg_run_id', 72 | {'stg_run_id': stg_run_id}) 73 | session.commit() 74 | 75 | @staticmethod 76 | def create_by_dic(order_dic): 77 | order_info = OrderInfo() 78 | 
order_info.order_date = order_dic['TradingDay'] 79 | order_info.order_time = pd_timedelta_2_timedelta(order_dic['InsertTime']) 80 | order_info.order_millisec = 0 81 | order_info.direction = Direction.create_by_direction(order_dic['Direction']) 82 | order_info.action = Action.create_by_offsetflag(order_dic['CombOffsetFlag']) 83 | order_info.instrument_id = order_dic['InstrumentID'] 84 | order_info.order_price = order_dic['LimitPrice'] 85 | order_info.order_vol = order_dic['VolumeTotalOriginal'] 86 | order_info.traded_vol = order_dic['VolumeTraded'] 87 | order_info.margin = 0 88 | return order_info 89 | 90 | 91 | class TradeInfo(BaseModel): 92 | """记录成交信息""" 93 | __tablename__ = 'trade_info' 94 | trade_id = Column(Integer, primary_key=True, autoincrement=True) # , comment="成交id" 95 | stg_run_id = Column(Integer) 96 | order_id = Column(Integer) # , comment="对应订单id" 97 | # order_dt = Column(DateTime, server_default=func.now()) 98 | order_price = Column(DOUBLE) # , comment="原订单价格" 99 | order_vol = Column(DOUBLE) # 订单量 , comment="原订单数量" 100 | trade_date = Column(Date) # 对应行情数据中 ActionDate 101 | trade_time = Column(Time) # 对应行情数据中 ActionTime 102 | trade_millisec = Column(Integer) # 对应行情数据中 ActionMillisec 103 | direction = Column(Boolean) # 0:空;1:多 104 | action = Column(Integer) # 0:关:1:开 105 | instrument_id = Column(String(30)) 106 | trade_price = Column(DOUBLE) # , comment="成交价格" 107 | trade_vol = Column(DOUBLE) # 订单量 , comment="成交数量" 108 | margin = Column(DOUBLE, server_default='0') # 保证金 , comment="占用保证金" 109 | commission = Column(DOUBLE, server_default='0') # 佣金、手续费 , comment="佣金、手续费" 110 | 111 | def set_trade_time(self, value): 112 | if isinstance(value, Timedelta): 113 | # print(value, 'parse to timedelta') 114 | self.trade_time = timedelta(seconds=value.seconds) 115 | else: 116 | self.trade_time = value 117 | 118 | @staticmethod 119 | def remove_trade_info(stg_run_id: int): 120 | """ 121 | 仅作为调试工具使用,删除指定 stg_run_id 相关的 trade_info 122 | :param stg_run_id: 123 | 
:return: 124 | """ 125 | with with_db_session(engine_abat) as session: 126 | session.execute('DELETE FROM trade_info WHERE stg_run_id=:stg_run_id', 127 | {'stg_run_id': stg_run_id}) 128 | session.commit() 129 | 130 | @staticmethod 131 | def create_by_order_info(order_info: OrderInfo): 132 | direction, action, instrument_id = order_info.direction, order_info.action, order_info.instrument_id 133 | order_price, order_vol, order_id = order_info.order_price, order_info.order_vol, order_info.order_id 134 | order_date, order_time, order_millisec = order_info.order_date, order_info.order_time, order_info.order_millisec 135 | stg_run_id = order_info.stg_run_id 136 | 137 | # TODO: 以后还可以增加滑点,成交比例等 138 | # instrument_info = Config.instrument_info_dic[instrument_id] 139 | # multiple = instrument_info['VolumeMultiple'] 140 | # margin_ratio = instrument_info['LongMarginRatio'] 141 | multiple, margin_ratio = 1, 1 142 | margin = order_vol * order_price * multiple * margin_ratio 143 | commission = 0 144 | trade_info = TradeInfo(stg_run_id=stg_run_id, 145 | order_id=order_id, 146 | trade_date=order_date, 147 | trade_time=order_time, 148 | trade_millisec=order_millisec, 149 | direction=direction, 150 | action=action, 151 | instrument_id=instrument_id, 152 | order_price=order_price, 153 | order_vol=order_vol, 154 | trade_price=order_price, 155 | trade_vol=order_vol, 156 | margin=margin, 157 | commission=commission 158 | ) 159 | if UPDATE_OR_INSERT_PER_ACTION: 160 | with with_db_session(engine_abat, expire_on_commit=False) as session: 161 | session.add(trade_info) 162 | session.commit() 163 | return trade_info 164 | 165 | 166 | class PosStatusInfo(BaseModel): 167 | """ 168 | 持仓状态数据 169 | 当持仓状态从有仓位到清仓时(position>0 --> position==0),计算清仓前的浮动收益,并设置到 floating_pl 字段最为当前状态的浮动收益 170 | 在调用 create_by_self 时,则需要处理一下,当 position==0 时,floating_pl 直接设置为 0,避免引起后续计算上的混淆 171 | """ 172 | __tablename__ = 'pos_status_info' 173 | pos_status_info_id = Column(Integer, primary_key=True, autoincrement=True) 174 | 
stg_run_id = Column(Integer) # 对应回测了策略 StgRunID 此数据与 AccSumID 对应数据相同 175 | trade_id = Column(Integer) # , comment="最新的成交id" 176 | # update_dt = Column(DateTime) # 每个订单变化生成一条记录 此数据与 AccSumID 对应数据相同 177 | trade_date = Column(Date) # 对应行情数据中 ActionDate 178 | trade_time = Column(Time) # 对应行情数据中 ActionTime 179 | trade_millisec = Column(Integer) # 对应行情数据中 ActionMillisec 180 | direction = Column(Integer) 181 | instrument_id = Column(String(30)) 182 | position = Column(DOUBLE, default=0.0) 183 | avg_price = Column(DOUBLE, default=0.0) # 所持投资品种上一交易日所有交易的加权平均价 184 | cur_price = Column(DOUBLE, default=0.0) 185 | floating_pl = Column(DOUBLE, default=0.0) 186 | floating_pl_chg = Column(DOUBLE, default=0.0) 187 | floating_pl_cum = Column(DOUBLE, default=0.0) 188 | margin = Column(DOUBLE, default=0.0) 189 | margin_chg = Column(DOUBLE, default=0.0) 190 | position_date = Column(Integer, default=0) 191 | logger = logging.getLogger(__tablename__) 192 | 193 | def __repr__(self): 194 | return "".format( 195 | self) 196 | 197 | @staticmethod 198 | def create_by_trade_info(trade_info: TradeInfo): 199 | direction, action, instrument_id = trade_info.direction, trade_info.action, trade_info.instrument_id 200 | trade_price, trade_vol, trade_id = trade_info.trade_price, trade_info.trade_vol, trade_info.trade_id 201 | trade_date, trade_time, trade_millisec = trade_info.trade_date, trade_info.trade_time, trade_info.trade_millisec 202 | stg_run_id = trade_info.stg_run_id 203 | if action == int(Action.Close): 204 | raise ValueError('trade_info.action 不能为 close') 205 | pos_status_info = PosStatusInfo(stg_run_id=stg_run_id, 206 | trade_id=trade_id, 207 | trade_date=trade_date, 208 | trade_time=trade_time, 209 | trade_millisec=trade_millisec, 210 | direction=direction, 211 | instrument_id=instrument_id, 212 | position=trade_vol, 213 | avg_price=trade_price, 214 | cur_price=trade_price, 215 | margin=0, 216 | margin_chg=0, 217 | floating_pl=0, 218 | floating_pl_chg=0, 219 | floating_pl_cum=0, 220 | ) 
221 | if UPDATE_OR_INSERT_PER_ACTION: 222 | # 更新最新持仓纪录 223 | with with_db_session(engine_abat, expire_on_commit=False) as session: 224 | session.add(pos_status_info) 225 | session.commit() 226 | return pos_status_info 227 | 228 | def update_by_trade_info(self, trade_info: TradeInfo): 229 | """ 230 | 创建新的对象,根据 trade_info 更新相关信息 231 | :param trade_info: 232 | :return: 233 | """ 234 | # 复制前一个持仓状态 235 | pos_status_info = self.create_by_self() 236 | direction, action, instrument_id = trade_info.direction, trade_info.action, trade_info.instrument_id 237 | trade_price, trade_vol, trade_id = trade_info.trade_price, trade_info.trade_vol, trade_info.trade_id 238 | trade_date, trade_time, trade_millisec = trade_info.trade_date, trade_info.trade_time, trade_info.trade_millisec 239 | 240 | # 获取合约信息 241 | # instrument_info = Config.instrument_info_dic[instrument_id] 242 | # multiple = instrument_info['VolumeMultiple'] 243 | # margin_ratio = instrument_info['LongMarginRatio'] 244 | multiple, margin_ratio = 1, 1 245 | 246 | # 计算仓位、方向、平均价格 247 | pos_direction, position, avg_price = pos_status_info.direction, pos_status_info.position, pos_status_info.avg_price 248 | if pos_direction == direction: 249 | if action == Action.Open: 250 | # 方向相同:开仓:加仓; 251 | pos_status_info.avg_price = (position * avg_price + trade_price * trade_vol) / (position + trade_vol) 252 | pos_status_info.position = position + trade_vol 253 | else: 254 | # 方向相同:关仓:减仓; 255 | if trade_vol > position: 256 | raise ValueError("当前持仓%d,平仓%d,错误" % (position, trade_vol)) 257 | elif trade_vol == position: 258 | # 清仓前计算浮动收益 259 | # 未清仓的情况将在下面的代码中统一计算浮动收益 260 | if pos_status_info.direction == Direction.Long: 261 | pos_status_info.floating_pl = (trade_price - avg_price) * position * multiple 262 | else: 263 | pos_status_info.floating_pl = (avg_price - trade_price) * position * multiple 264 | 265 | pos_status_info.avg_price = 0 266 | pos_status_info.position = 0 267 | 268 | else: 269 | pos_status_info.avg_price = (position * 
avg_price - trade_price * trade_vol) / ( 270 | position - trade_vol) 271 | pos_status_info.position = position - trade_vol 272 | elif position == 0: 273 | pos_status_info.avg_price = trade_price 274 | pos_status_info.position = trade_vol 275 | pos_status_info.direction = direction 276 | else: 277 | # 方向相反 278 | raise ValueError("当前仓位:%s %d手,目标操作:%s %d手,请先平仓在开仓" % ( 279 | "多头" if pos_direction == Direction.Long else "空头", position, 280 | "多头" if direction == Direction.Long else "空头", trade_vol, 281 | )) 282 | # if position == trade_vol: 283 | # # 方向相反,量相同:清仓 284 | # pos_status_info.avg_price = 0 285 | # pos_status_info.position = 0 286 | # else: 287 | # holding_amount = position * avg_price 288 | # trade_amount = trade_price * trade_vol 289 | # position_rest = position - trade_vol 290 | # avg_price = (holding_amount - trade_amount) / position_rest 291 | # if position > trade_vol: 292 | # # 减仓 293 | # pos_status_info.avg_price = avg_price 294 | # pos_status_info.position = position_rest 295 | # else: 296 | # # 多空反手 297 | # self.logger.warning("%s 持%s:%d -> %d 多空反手", self.instrument_id, 298 | # '多' if direction == int(Direction.Long) else '空', position, position_rest) 299 | # pos_status_info.avg_price = avg_price 300 | # pos_status_info.position = position_rest 301 | # pos_status_info.direction = Direction.Short if direction == int(Direction.Short) else Direction.Long 302 | 303 | # 设置其他属性 304 | pos_status_info.cur_price = trade_price 305 | pos_status_info.trade_date = trade_date 306 | pos_status_info.trade_time = trade_time 307 | pos_status_info.trade_millisec = trade_millisec 308 | 309 | # 计算 floating_pl margin 310 | position = pos_status_info.position 311 | # cur_price = pos_status_info.cur_price 312 | avg_price = pos_status_info.avg_price 313 | pos_status_info.margin = position * trade_price * multiple * margin_ratio 314 | # 如果当前仓位不为 0 则计算浮动收益 315 | if position > 0: 316 | if pos_status_info.direction == Direction.Long: 317 | pos_status_info.floating_pl = 
(trade_price - avg_price) * position * multiple 318 | else: 319 | pos_status_info.floating_pl = (avg_price - trade_price) * position * multiple 320 | # 如果前一状态仓位为 0 则不进行差值计算 321 | if self.position == 0: 322 | pos_status_info.margin_chg = pos_status_info.margin 323 | pos_status_info.floating_pl_chg = pos_status_info.floating_pl 324 | else: 325 | pos_status_info.margin_chg = pos_status_info.margin - self.margin 326 | pos_status_info.floating_pl_chg = pos_status_info.floating_pl - self.floating_pl 327 | 328 | pos_status_info.floating_pl_cum += pos_status_info.floating_pl_chg 329 | 330 | if UPDATE_OR_INSERT_PER_ACTION: 331 | # 更新最新持仓纪录 332 | with with_db_session(engine_abat, expire_on_commit=False) as session: 333 | session.add(pos_status_info) 334 | session.commit() 335 | return pos_status_info 336 | 337 | def update_by_md(self, md: dict): 338 | """ 339 | 创建新的对象,根据 trade_info 更新相关信息 340 | :param md: 341 | :return: 342 | """ 343 | trade_date = md['ActionDay'] 344 | trade_time = pd_timedelta_2_timedelta(md['ActionTime']) 345 | trade_millisec = int(md.setdefault('ActionMillisec', 0)) 346 | trade_price = float(md['close']) 347 | instrument_id = md['InstrumentID'] 348 | pos_status_info = self.create_by_self() 349 | pos_status_info.cur_price = trade_price 350 | pos_status_info.trade_date = trade_date 351 | pos_status_info.trade_time = trade_time 352 | pos_status_info.trade_millisec = trade_millisec 353 | 354 | # 计算 floating_pl margin 355 | # instrument_info = Config.instrument_info_dic[instrument_id] 356 | # multiple = instrument_info['VolumeMultiple'] 357 | # margin_ratio = instrument_info['LongMarginRatio'] 358 | multiple, margin_ratio = 1, 1 359 | position = pos_status_info.position 360 | cur_price = pos_status_info.cur_price 361 | avg_price = pos_status_info.avg_price 362 | pos_status_info.margin = position * cur_price * multiple * margin_ratio 363 | pos_status_info.margin_chg = pos_status_info.margin - self.margin 364 | if pos_status_info.direction == Direction.Long: 
365 | pos_status_info.floating_pl = (cur_price - avg_price) * position * multiple 366 | else: 367 | pos_status_info.floating_pl = (avg_price - cur_price) * position * multiple 368 | pos_status_info.floating_pl_chg = pos_status_info.floating_pl - self.floating_pl 369 | pos_status_info.floating_pl_cum += pos_status_info.floating_pl_chg 370 | 371 | if UPDATE_OR_INSERT_PER_ACTION: 372 | # 更新最新持仓纪录 373 | with with_db_session(engine_abat, expire_on_commit=False) as session: 374 | session.add(pos_status_info) 375 | session.commit() 376 | return pos_status_info 377 | 378 | def create_by_self(self): 379 | """ 380 | 创建新的对象 381 | 若当前对象持仓为0(position==0),则 浮动收益部分设置为0 382 | :return: 383 | """ 384 | position = self.position 385 | pos_status_info = PosStatusInfo(stg_run_id=self.stg_run_id, 386 | trade_id=self.trade_id, 387 | trade_date=self.trade_date, 388 | trade_time=self.trade_time, 389 | trade_millisec=self.trade_millisec, 390 | direction=self.direction, 391 | instrument_id=self.instrument_id, 392 | position=position, 393 | avg_price=self.avg_price, 394 | cur_price=self.cur_price, 395 | floating_pl=self.floating_pl if position > 0 else 0, 396 | floating_pl_cum=self.floating_pl_cum, 397 | margin=self.margin) 398 | return pos_status_info 399 | 400 | @staticmethod 401 | def create_by_dic(position_date_inv_pos_dic: dict) -> dict: 402 | if position_date_inv_pos_dic is None: 403 | return None 404 | position_date_pos_info_dic = {} 405 | for position_date, pos_dic in position_date_inv_pos_dic.items(): 406 | pos_info = PosStatusInfo() 407 | pos_info.trade_date = pos_dic['TradingDay'] 408 | pos_info.trade_time = None 409 | pos_info.direction = Direction.create_by_posi_direction(pos_dic['PosiDirection']) 410 | pos_info.instrument_id = pos_dic['InstrumentID'] 411 | pos_info.position = pos_dic['Position'] 412 | pos_info.avg_price = pos_dic['PositionCost'] 413 | pos_info.cur_price = 0 414 | pos_info.floating_pl = pos_dic['PositionProfit'] 415 | pos_info.floating_pl_chg = 0 416 | 
pos_info.margin = pos_dic['UseMargin'] 417 | pos_info.margin_chg = 0 418 | pos_info.position_date = int(position_date) 419 | position_date_pos_info_dic[position_date] = pos_info 420 | return position_date_pos_info_dic 421 | 422 | @staticmethod 423 | def remove_pos_status_info(stg_run_id: int): 424 | """ 425 | 仅作为调试工具使用,删除指定 stg_run_id 相关的 pos_status_info 426 | :param stg_run_id: 427 | :return: 428 | """ 429 | with with_db_session(engine_abat) as session: 430 | session.execute('DELETE FROM pos_status_info WHERE stg_run_id=:stg_run_id', 431 | {'stg_run_id': stg_run_id}) 432 | session.commit() 433 | 434 | 435 | class AccountStatusInfo(BaseModel): 436 | """持仓状态数据""" 437 | __tablename__ = 'account_status_info' 438 | account_status_info_id = Column(Integer, primary_key=True, autoincrement=True) 439 | stg_run_id = Column(Integer) # 对应回测了策略 StgRunID 此数据与 AccSumID 对应数据相同 440 | trade_date = Column(Date) # 对应行情数据中 ActionDate 441 | trade_time = Column(Time) # 对应行情数据中 ActionTime 442 | trade_millisec = Column(Integer) # 对应行情数据中 ActionMillisec 443 | available_cash = Column(DOUBLE, default=0.0) # 可用资金, double 444 | curr_margin = Column(DOUBLE, default=0.0) # 当前保证金总额, double 445 | close_profit = Column(DOUBLE, default=0.0) 446 | position_profit = Column(DOUBLE, default=0.0) 447 | floating_pl_cum = Column(DOUBLE, default=0.0) 448 | fee_tot = Column(DOUBLE, default=0.0) 449 | balance_tot = Column(DOUBLE, default=0.0) 450 | 451 | @staticmethod 452 | def create(stg_run_id, init_cash: int, md: dict): 453 | """ 454 | 根据 md 及 初始化资金 创建对象,默认日期为当前md数据-1天 455 | :param stg_run_id: 456 | :param init_cash: 457 | :param md: 458 | :return: 459 | """ 460 | trade_date = str_2_date(md['ActionDay']) - timedelta(days=1) 461 | trade_time = pd_timedelta_2_timedelta(md['ActionTime']) 462 | trade_millisec = int(md.setdefault('ActionMillisec', 0)) 463 | trade_price = float(md['close']) 464 | acc_status_info = AccountStatusInfo(stg_run_id=stg_run_id, 465 | trade_date=trade_date, 466 | trade_time=trade_time, 
467 | trade_millisec=trade_millisec, 468 | available_cash=init_cash, 469 | balance_tot=init_cash, 470 | ) 471 | if UPDATE_OR_INSERT_PER_ACTION: 472 | # 更新最新持仓纪录 473 | with with_db_session(engine_abat, expire_on_commit=False) as session: 474 | session.add(acc_status_info) 475 | session.commit() 476 | return acc_status_info 477 | 478 | def create_by_self(self): 479 | """ 480 | 创建新的对象,默认前一日持仓信息的最新价,等于下一交易日的结算价(即 AvePrice) 481 | :return: 482 | """ 483 | account_status_info = AccountStatusInfo(stg_run_id=self.stg_run_id, 484 | trade_date=self.trade_date, 485 | trade_time=self.trade_time, 486 | trade_millisec=self.trade_millisec, 487 | available_cash=self.available_cash, 488 | curr_margin=self.curr_margin, 489 | close_profit=self.close_profit, 490 | position_profit=self.position_profit, 491 | floating_pl_cum=self.floating_pl_cum, 492 | fee_tot=self.fee_tot, 493 | balance_tot=self.balance_tot 494 | ) 495 | return account_status_info 496 | 497 | def update_by_pos_status_info(self, pos_status_info_dic, md: dict): 498 | """ 499 | 根据 持仓列表更新账户信息 500 | :param pos_status_info_dic: 501 | :return: 502 | """ 503 | account_status_info = self.create_by_self() 504 | # 上一次更新日期、时间 505 | # trade_date_last, trade_time_last, trade_millisec_last = \ 506 | # account_status_info.trade_date, account_status_info.trade_time, account_status_info.trade_millisec 507 | # 更新日期、时间 508 | trade_date = md['ActionDay'] 509 | trade_time = pd_timedelta_2_timedelta(md['ActionTime']) 510 | trade_millisec = int(md.setdefault('ActionMillisec', 0)) 511 | 512 | available_cash_chg = 0 513 | curr_margin = 0 514 | close_profit = 0 515 | position_profit = 0 516 | floating_pl_chg = 0 517 | margin_chg = 0 518 | floating_pl_cum = 0 519 | for instrument_id, pos_status_info in pos_status_info_dic.items(): 520 | curr_margin += pos_status_info.margin 521 | if pos_status_info.position == 0: 522 | close_profit += pos_status_info.floating_pl 523 | else: 524 | position_profit += pos_status_info.floating_pl 525 | floating_pl_chg 
+= pos_status_info.floating_pl_chg 526 | margin_chg += pos_status_info.margin_chg 527 | floating_pl_cum += pos_status_info.floating_pl_cum 528 | 529 | available_cash_chg = floating_pl_chg - margin_chg 530 | account_status_info.curr_margin = curr_margin 531 | # # 对于同一时间,平仓后又开仓的情况,不能将close_profit重置为0 532 | # if trade_date == trade_date_last and trade_time == trade_time_last and trade_millisec == trade_millisec_last: 533 | # account_status_info.close_profit += close_profit 534 | # else: 535 | # 一个单位时段只允许一次,不需要考虑上面的情况 536 | account_status_info.close_profit = close_profit 537 | 538 | account_status_info.position_profit = position_profit 539 | account_status_info.available_cash += available_cash_chg 540 | account_status_info.floating_pl_cum = floating_pl_cum 541 | account_status_info.balance_tot = account_status_info.available_cash + curr_margin 542 | 543 | account_status_info.trade_date = trade_date 544 | account_status_info.trade_time = trade_time 545 | account_status_info.trade_millisec = trade_millisec 546 | if UPDATE_OR_INSERT_PER_ACTION: 547 | # 更新最新持仓纪录 548 | with with_db_session(engine_abat, expire_on_commit=False) as session: 549 | session.add(account_status_info) 550 | session.commit() 551 | return account_status_info 552 | 553 | 554 | def init(): 555 | from abat.backend import engine_abat 556 | 557 | BaseModel.metadata.create_all(engine_abat) 558 | with with_db_session(engine_abat) as session: 559 | for table_name, _ in BaseModel.metadata.tables.items(): 560 | sql_str = "ALTER TABLE %s ENGINE = MyISAM" % table_name 561 | session.execute(sql_str) 562 | print("所有表结构建立完成") 563 | 564 | 565 | if __name__ == "__main__": 566 | init() 567 | # 创建user表,继承metadata类 568 | # Engine使用Schama Type创建一个特定的结构对象 569 | # stg_info_table = Table("stg_info", metadata, autoload=True) 570 | -------------------------------------------------------------------------------- /abat/common.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:15 6 | @File : common.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from enum import IntEnum, unique 11 | 12 | 13 | @unique 14 | class PeriodType(IntEnum): 15 | """ 16 | 周期类型 17 | """ 18 | Tick = 0 19 | Min1 = 1 20 | Min5 = 2 21 | Min10 = 3 22 | Min15 = 4 23 | Min30 = 5 24 | Hour1 = 10 25 | Day1 = 20 26 | Week1 = 30 27 | Mon1 = 40 28 | Year1 = 100 29 | 30 | @staticmethod 31 | def get_min_count(period_type) -> int: 32 | """ 33 | 返回给的周期类型对应的分钟数 34 | :param period_type: 35 | :return: 36 | """ 37 | if PeriodType.Min1 == period_type: 38 | min_count = 1 39 | elif PeriodType.Min5 == period_type: 40 | min_count = 5 41 | elif PeriodType.Min10 == period_type: 42 | min_count = 10 43 | elif PeriodType.Min15 == period_type: 44 | min_count = 15 45 | elif PeriodType.Min30 == period_type: 46 | min_count = 30 47 | else: 48 | raise ValueError('不支持 %s 周期' % period_type) 49 | return min_count 50 | 51 | @unique 52 | class RunMode(IntEnum): 53 | """ 54 | 运行模式,目前支持两种:实时行情模式,回测模式 55 | """ 56 | Realtime = 0 57 | Backtest = 1 58 | 59 | 60 | @unique 61 | class Direction(IntEnum): 62 | """买卖方向""" 63 | Short = 0 # 空头 64 | Long = 1 # 多头 65 | 66 | @staticmethod 67 | def create_by_direction(direction_str): 68 | # if isinstance(direction_str, str): 69 | # if direction_str == D_Buy_str: 70 | # return Direction.Long 71 | # elif direction_str == D_Sell_str: 72 | # return Direction.Short 73 | # else: 74 | # raise ValueError('Direction不支持 %s' % direction_str) 75 | # else: 76 | # if direction_str == D_Buy: 77 | # return Direction.Long 78 | # elif direction_str == D_Sell: 79 | # return Direction.Short 80 | # else: 81 | # raise ValueError('Direction不支持 %s' % direction_str) 82 | pass 83 | 84 | @staticmethod 85 | def create_by_posi_direction(posi_direction): 86 | # if isinstance(posi_direction, str): 87 | # if posi_direction == PD_Long_str: 88 | # return Direction.Long 89 | # elif posi_direction 
== PD_Short_str: 90 | # return Direction.Short 91 | # else: 92 | # raise ValueError('Direction不支持 %s' % posi_direction) 93 | # else: 94 | # if posi_direction == PD_Long: 95 | # return Direction.Long 96 | # elif posi_direction == PD_Short: 97 | # return Direction.Short 98 | # else: 99 | # raise ValueError('Direction不支持 %s' % posi_direction) 100 | pass 101 | 102 | 103 | @unique 104 | class Action(IntEnum): 105 | """开仓平仓""" 106 | Open = 0 # 开仓 107 | Close = 1 # 平仓 108 | ForceClose = 2 # 强平 109 | CloseToday = 3 # 平今 110 | CloseYesterday = 4 # 平昨 111 | ForceOff = 5 # 强减 112 | LocalForceClose = 6 # 本地强平 113 | 114 | @staticmethod 115 | def create_by_offsetflag(offset_flag): 116 | """ 117 | 将 Api 中 OffsetFlag 变为 Action 类型 118 | :param offset_flag: 119 | :return: 120 | """ 121 | if offset_flag == '0': 122 | return Action.Open 123 | elif offset_flag in {'1', '5'}: 124 | return Action.Close 125 | elif offset_flag == '3': 126 | return Action.CloseToday 127 | elif offset_flag == '4': 128 | return Action.CloseYesterday 129 | elif offset_flag == '2': 130 | return Action.ForceClose 131 | elif offset_flag == '6': 132 | return Action.LocalForceClose 133 | else: 134 | raise ValueError('Action不支持 %s' % offset_flag) 135 | 136 | 137 | @unique 138 | class PositionDateType(IntEnum): 139 | """今日持仓历史持仓标示""" 140 | Today = 1 # 今日持仓 141 | History = 2 # 历史持仓 142 | 143 | 144 | @unique 145 | class BacktestTradeMode(IntEnum): 146 | """ 147 | 回测模式下的成交模式 148 | """ 149 | Order_2_Deal = 0 # 一种简单回测模式,无论开仓平仓等操作,下单即成交 150 | MD_2_Deal = 1 # 根据下单后行情变化确定何时成交 151 | 152 | 153 | @unique 154 | class ContextKey(IntEnum): 155 | """ 156 | 策略执行逻辑中 context[key] 的部分定制化key 157 | """ 158 | instrument_id_list = 0 159 | 160 | 161 | -------------------------------------------------------------------------------- /abat/event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/11/13 4 | @author: MG 5 | """ 6 | import logging 7 | from 
enum import IntEnum, unique 8 | from config import Config 9 | from abat.common import PeriodType 10 | logger = logging.getLogger() 11 | 12 | 13 | @unique 14 | class EventType(IntEnum): 15 | Tick_MD_Event = 0 16 | Min1_MD_Event = 1 17 | Min5_MD_Event = 5 18 | Min15_MD_Event = 15 19 | 20 | @staticmethod 21 | def try_2_period_type(event_type): 22 | """ 23 | 将 EventType类型转换为 PeriodType类型 24 | 由于两个对象不一定一一对应,因此,无法匹配的类型返回None 25 | :param event_type: 26 | :return: 27 | """ 28 | if event_type == EventType.Tick_MD_Event: 29 | period_type = PeriodType.Tick 30 | elif event_type == EventType.Min1_MD_Event: 31 | period_type = PeriodType.Min1 32 | elif event_type == EventType.Min5_MD_Event: 33 | period_type = PeriodType.Min5 34 | elif event_type == EventType.Min15_MD_Event: 35 | period_type = PeriodType.Min15 36 | else: 37 | # raise ValueError('event_type:%s 不是有效的类型' % event_type) 38 | logger.warning('event_type:%s 不是有效的类型', event_type) 39 | period_type = None 40 | return period_type 41 | 42 | 43 | class EventAgent: 44 | """ 45 | 事件注册器,用于集中通过事件具备 处理相关类型的事件 46 | """ 47 | def __init__(self): 48 | self._event_handler_dic = {} 49 | self._event_key_handler_dic = {} 50 | self.logger = logging.getLogger(self.__class__.__name__) 51 | 52 | def register_handler(self, event_type, handler, handler_name=None, key=None): 53 | """ 54 | 注册器:注册时间类型及对应的事件处理程序 55 | Key!=None,则该句柄只响应匹配 key 的事件 56 | Key=None,公共事件响应句柄,所有相应类型事件均响应 57 | :param event_type: 58 | :param handler: 59 | :param handler_name: 事件句柄名称,每个事件中,句柄名称唯一,重复将被覆盖 60 | :param key: 事件匹配 key,None代表公共事件响应句柄,所有该类型事件均响应 61 | :return: 62 | """ 63 | if event_type not in self._event_handler_dic: 64 | self._event_handler_dic[event_type] = {} 65 | self._event_key_handler_dic[event_type] = {} 66 | if handler_name is None: 67 | handler_name = handler.__name__ 68 | self.logger.debug("注册事件处理句柄 %s -> <%s>%s", 69 | event_type, handler_name, (" [%s]" % key) if key is not None else "") 70 | if key is None: 71 | if handler_name in 
self._event_handler_dic[event_type]: 72 | self.logger.warning('相应事件处理句柄 %s -> <%s> 已经存在,重新注册将覆盖原来执行句柄', 73 | event_type, handler_name) 74 | self._event_handler_dic[event_type][handler_name] = handler 75 | else: 76 | if key not in self._event_key_handler_dic[event_type]: 77 | self._event_key_handler_dic[event_type][key] = {} 78 | self._event_key_handler_dic[event_type][key][handler_name] = handler 79 | 80 | def send_event(self, event_type, data, key=None): 81 | """ 82 | 触发事件 83 | 如果带Key,则触发 匹配 key 事件,以及全部 无 key 的事件响应句柄 84 | 否则,只触发公共事件响应句柄 85 | :param event_type: 86 | :param data: 87 | :param key: 事件匹配 key,None代表只触发公共事件响应句柄 88 | :return: 89 | """ 90 | # 公共事件响应句柄 91 | if event_type in self._event_handler_dic: 92 | handler_dic = self._event_handler_dic[event_type] 93 | error_name_list = [] 94 | for handler_name, handler in handler_dic.items(): 95 | try: 96 | handler(data) 97 | except: 98 | self.logger.exception('%s run with error will be remove from register', handler_name) 99 | error_name_list.append(handler_name) 100 | for handler_name in error_name_list: 101 | del handler_dic[handler_name] 102 | self.logger.warning('从注册器中移除 %s - %s', event_type, handler_name) 103 | # key 匹配事件响应句柄 104 | if key is not None and event_type in self._event_key_handler_dic: 105 | key_handler_dic = self._event_key_handler_dic[event_type] 106 | if key in key_handler_dic: 107 | handler_dic = key_handler_dic[key] 108 | error_name_list = [] 109 | for handler_name, handler in handler_dic.items(): 110 | try: 111 | handler(data) 112 | except: 113 | self.logger.exception('%s run with error will be remove from register', handler_name) 114 | error_name_list.append(handler_name) 115 | for handler_name in error_name_list: 116 | del handler_dic[handler_name] 117 | self.logger.warning('从注册器中移除 %s - %s - %s', event_type, str(key), handler_name) 118 | 119 | 120 | event_agent = EventAgent() 121 | 122 | if __name__ == '__main__': 123 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 124 | 
event_agent.register_handler(EventType.Tick_MD_Event, lambda x: print('T:', x), 'print_t') 125 | event_agent.register_handler(EventType.Min1_MD_Event, lambda x: print('M:', x), 'print_m') 126 | event_agent.register_handler(EventType.Tick_MD_Event, lambda x: print('T 1:', x), 'print_t with key1', key='1') 127 | event_agent.send_event(EventType.Tick_MD_Event, 'sdf') 128 | event_agent.send_event(EventType.Tick_MD_Event, 'sdf', key='1') 129 | -------------------------------------------------------------------------------- /abat/md.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:12 6 | @File : md.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from config import Config 11 | from abat.common import PeriodType, RunMode 12 | from threading import Thread 13 | import time 14 | import pandas as pd 15 | import logging 16 | from abc import ABC, abstractmethod 17 | logger = logging.getLogger(__package__) 18 | 19 | 20 | class MdAgentBase(Thread, ABC): 21 | 22 | @staticmethod 23 | def factory(run_mode: RunMode, instrument_id_list, md_period: PeriodType, name=None, **kwargs): 24 | # if run_mode == RunMode.Backtest: 25 | # md_agent = MdAgentBacktest(instrument_id_list, md_period, name, **kwargs) 26 | # elif run_mode == RunMode.Realtime: 27 | # md_agent = MdAgentRealtime(instrument_id_list, md_period, name, **kwargs) 28 | # else: 29 | # raise ValueError("run_mode:%s exception", run_mode) 30 | md_agent_class = md_agent_class_dic[run_mode] 31 | md_agent = md_agent_class(instrument_id_list, md_period, name, **kwargs) 32 | return md_agent 33 | 34 | def __init__(self, instrument_id_set, md_period: PeriodType, name=None, 35 | init_load_md_count=None, init_md_date_from=None, init_md_date_to=None, **kwargs): 36 | if name is None: 37 | name = md_period 38 | super().__init__(name=name, daemon=True) 39 | self.md_period = md_period 40 | 
self.keep_running = None 41 | self.instrument_id_set = instrument_id_set 42 | self.init_load_md_count = init_load_md_count 43 | self.init_md_date_from = init_md_date_from 44 | self.init_md_date_to = init_md_date_to 45 | self.logger = logging.getLogger() 46 | 47 | @abstractmethod 48 | def load_history(self, date_from=None, date_to=None, load_md_count=None)->(pd.DataFrame, dict): 49 | """ 50 | 从mysql中加载历史数据 51 | 实时行情推送时进行合并后供数据分析使用 52 | :param date_from: None代表沿用类的 init_md_date_from 属性 53 | :param date_to: None代表沿用类的 init_md_date_from 属性 54 | :param load_md_count: 0 代表不限制,None代表沿用类的 init_load_md_count 属性,其他数字代表相应的最大加载条数 55 | :return: md_df 或者 56 | ret_data { 57 | 'md_df': md_df, 'datetime_key': 'ts_start', 58 | 'date_key': **, 'time_key': **, 'microseconds_key': ** 59 | } 60 | """ 61 | 62 | @abstractmethod 63 | def connect(self): 64 | """链接redis、初始化历史数据""" 65 | 66 | @abstractmethod 67 | def release(self): 68 | """释放channel资源""" 69 | 70 | def subscribe(self, instrument_id_set=None): 71 | """订阅合约""" 72 | if instrument_id_set is None: 73 | return 74 | self.instrument_id_set |= instrument_id_set 75 | 76 | def unsubscribe(self, instrument_id_set): 77 | """退订合约""" 78 | if instrument_id_set is None: 79 | self.instrument_id_set = set() 80 | else: 81 | self.instrument_id_set -= instrument_id_set 82 | 83 | 84 | md_agent_class_dic = {RunMode.Backtest: MdAgentBase, RunMode.Realtime: MdAgentBase} 85 | 86 | 87 | def register_realtime_md_agent(agent: MdAgentBase) -> MdAgentBase: 88 | md_agent_class_dic[RunMode.Realtime] = agent 89 | logger.info('设置 realtime md agent:%s', agent.__class__.__name__) 90 | return agent 91 | 92 | 93 | def register_backtest_md_agent(agent: MdAgentBase) -> MdAgentBase: 94 | md_agent_class_dic[RunMode.Backtest] = agent 95 | logger.info('设置 backtest md agent:%s', agent.__class__.__name__) 96 | return agent 97 | 98 | 99 | if __name__ == "__main__": 100 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 101 | 102 | instrument_id_list = 
set(['jm1711', 'rb1712', 'pb1801', 'IF1710']) 103 | md_agent = MdAgentBase.factory(RunMode.Realtime, instrument_id_list, md_period=PeriodType.Min1, 104 | init_load_md_count=100) 105 | md_df = md_agent.load_history() 106 | print(md_df.shape) 107 | md_agent.connect() 108 | md_agent.subscribe(instrument_id_list) 109 | md_agent.start() 110 | for n in range(120): 111 | time.sleep(1) 112 | md_agent.keep_running = False 113 | md_agent.join() 114 | md_agent.release() 115 | print("all finished") 116 | -------------------------------------------------------------------------------- /abat/strategy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/9/2 4 | @author: MG 5 | """ 6 | from threading import Thread 7 | import warnings 8 | import json 9 | import numpy as np 10 | import pandas as pd 11 | from config import Config 12 | import logging 13 | from queue import Empty 14 | import time 15 | from datetime import date, datetime, timedelta 16 | from abc import ABC 17 | from abat.backend import engine_abat 18 | from abat.utils.db_utils import with_db_session 19 | from abat.md import MdAgentBase 20 | from abat.common import PeriodType, RunMode, ContextKey, Direction, BacktestTradeMode 21 | from abat.utils.fh_utils import try_2_date 22 | from abat.backend.orm import StgRunInfo 23 | from abat.trade import trader_agent_class_dic 24 | 25 | 26 | logger_stg_base = logging.getLogger('StgBase') 27 | 28 | 29 | class StgBase: 30 | 31 | def __init__(self): 32 | # 记录各个周期 md 数据 33 | self._md_period_df_dic = {} 34 | # 记录在行情推送过程中最新的一笔md数据 35 | # self._period_curr_md_dic = {} 36 | # 记录各个周期 md 列信息 37 | self._md_period_df_col_name_list_dic = {} 38 | self.trade_agent = None 39 | self.logger = logging.getLogger(self.__class__.__name__) 40 | self._on_period_event_dic = { 41 | PeriodType.Tick: EventHandlersRelation(PeriodType.Tick, 42 | self.on_prepare_tick, self.on_tick, pd.DataFrame), 43 | PeriodType.Min1: 
EventHandlersRelation(PeriodType.Min1, 44 | self.on_prepare_min1, self.on_min1, pd.DataFrame), 45 | PeriodType.Hour1: EventHandlersRelation(PeriodType.Hour1, 46 | self.on_prepare_hour1, self.on_hour1, pd.DataFrame), 47 | PeriodType.Day1: EventHandlersRelation(PeriodType.Day1, 48 | self.on_prepare_day1, self.on_day1, pd.DataFrame), 49 | PeriodType.Week1: EventHandlersRelation(PeriodType.Week1, 50 | self.on_prepare_week1, self.on_week1, pd.DataFrame), 51 | PeriodType.Mon1: EventHandlersRelation(PeriodType.Mon1, 52 | self.on_prepare_month1, self.on_month1, pd.DataFrame), 53 | } 54 | self._period_context_dic = {} 55 | 56 | def on_timer(self): 57 | pass 58 | 59 | def load_md_period_df(self, period, md_df, context): 60 | """初始化加载 md 数据""" 61 | self._md_period_df_dic[period] = md_df 62 | self._md_period_df_col_name_list_dic[period] = list(md_df.columns) if isinstance(md_df, pd.DataFrame) else None 63 | self._period_context_dic[period] = context 64 | # prepare_event_handler = self._on_period_prepare_event_dic[period] 65 | prepare_event_handler = self._on_period_event_dic[period].prepare_event 66 | prepare_event_handler(md_df, context) 67 | 68 | def init(self): 69 | """ 70 | 加载历史数据后,启动周期策略执行函数之前 71 | 执行初始化动作,连接 trade_agent 72 | 以后还可以放出其他初始化动作 73 | :return: 74 | """ 75 | self.trade_agent.connect() 76 | 77 | def release(self): 78 | """ 79 | 80 | :return: 81 | """ 82 | self.trade_agent.release() 83 | 84 | def _on_period_md_append(self, period, md): 85 | """ 86 | (仅供 on_period_md_handler 调用使用) 87 | 用于整理接收到的各个周期行情数据 88 | :param period: 89 | :param md: 90 | :return: 91 | """ 92 | # self._period_curr_md_dic[period] = md 93 | md_df = pd.DataFrame([md]) 94 | if period in self._md_period_df_dic: 95 | col_name_list = self._md_period_df_col_name_list_dic[period] 96 | md_df_his = self._md_period_df_dic[period].append(md_df[col_name_list]) 97 | self._md_period_df_dic[period] = md_df_his 98 | else: 99 | md_df_his = md_df 100 | self._md_period_df_dic[period] = md_df_his 101 | 
self._md_period_df_col_name_list_dic[period] = \ 102 | list(md_df.columns) if isinstance(md_df, pd.DataFrame) else None 103 | # logger_stg_base.debug('%s -> %s', period, md) 104 | return md_df_his 105 | 106 | def _on_period_md_event(self, period, md_df_his): 107 | """ 108 | (仅供 on_period_md_handler 调用使用) 109 | 用于将各个周期数据传入对应周期事件处理函数 110 | :param period: 111 | :param md_df_his: 112 | :return: 113 | """ 114 | # event_handler = self._on_period_md_event_dic[period] 115 | event_handler = self._on_period_event_dic[period].md_event 116 | context = self._period_context_dic[period] 117 | # self._trade_agent.curr_md = md 118 | event_handler(md_df_his, context) 119 | 120 | def on_period_md_handler(self, period, md): 121 | """响应 period 数据""" 122 | # 本机测试,延时0.155秒,从分钟K线合成到交易策略端收到数据 123 | logger_stg_base.debug("%s -> %s", PeriodType(period), md) 124 | # self._on_period_md_event(period, md_df_his) 125 | period_event_relation = self._on_period_event_dic[period] 126 | event_handler = period_event_relation.md_event 127 | param_type = period_event_relation.param_type 128 | context = self._period_context_dic[period] 129 | # TODO 由于每一次进入都需要进行判断,增加不必要的计算,考虑通过优化提高运行效率 130 | if param_type is dict: 131 | param = md 132 | elif param_type is pd.DataFrame: 133 | param = self._on_period_md_append(period, md) 134 | else: 135 | raise ValueError("不支持 %s 类型作为 %s 的事件参数" % (param_type, period)) 136 | event_handler(param, context) 137 | 138 | def on_prepare_tick(self, md_df, context): 139 | """Tick 历史数据加载执行语句""" 140 | pass 141 | 142 | def on_prepare_min1(self, md_df, context): 143 | """1分钟线 历史数据加载执行语句""" 144 | pass 145 | 146 | def on_prepare_hour1(self, md_df, context): 147 | """1小时线 历史数据加载执行语句""" 148 | pass 149 | 150 | def on_prepare_day1(self, md_df, context): 151 | """1日线 历史数据加载执行语句""" 152 | pass 153 | 154 | def on_prepare_week1(self, md_df, context): 155 | """1周线 历史数据加载执行语句""" 156 | pass 157 | 158 | def on_prepare_month1(self, md_df, context): 159 | """1月线 历史数据加载执行语句""" 160 | pass 161 | 162 | def 
on_tick(self, md_df, context): 163 | """Tick策略执行语句,需要相应策略实现具体的策略算法""" 164 | pass 165 | 166 | def on_min1(self, md_df, context): 167 | """1分钟线策略执行语句,需要相应策略实现具体的策略算法""" 168 | pass 169 | 170 | def on_hour1(self, md_df, context): 171 | """1小时线策略执行语句,需要相应策略实现具体的策略算法""" 172 | pass 173 | 174 | def on_day1(self, md_df, context): 175 | """1日线策略执行语句,需要相应策略实现具体的策略算法""" 176 | pass 177 | 178 | def on_week1(self, md_df, context): 179 | """1周线策略执行语句,需要相应策略实现具体的策略算法""" 180 | pass 181 | 182 | def on_month1(self, md_df, context): 183 | """1月线策略执行语句,需要相应策略实现具体的策略算法""" 184 | pass 185 | 186 | def open_long(self, instrument_id, price, vol): 187 | self.trade_agent.open_long(instrument_id, price, vol) 188 | 189 | def close_long(self, instrument_id, price, vol): 190 | self.trade_agent.close_long(instrument_id, price, vol) 191 | 192 | def open_short(self, instrument_id, price, vol): 193 | self.trade_agent.open_short(instrument_id, price, vol) 194 | 195 | def close_short(self, instrument_id, price, vol): 196 | self.trade_agent.close_short(instrument_id, price, vol) 197 | 198 | def get_position(self, instrument_id, **kwargs) -> dict: 199 | """ 200 | position_date 作为key, PosStatusInfo 为 val 201 | 返回 position_date_pos_info_dic 202 | :param instrument_id: 203 | :return: 204 | """ 205 | return self.trade_agent.get_position(instrument_id, **kwargs) 206 | 207 | def get_order(self, instrument_id) -> list: 208 | return self.trade_agent.get_order(instrument_id) 209 | 210 | def cancel_order(self, instrument_id): 211 | return self.trade_agent.cancel_order(instrument_id) 212 | 213 | @property 214 | def datetime_last_update_position(self): 215 | return self.trade_agent.datetime_last_update_position 216 | 217 | @property 218 | def datetime_last_rtn_trade_dic(self): 219 | return self.trade_agent.datetime_last_rtn_trade_dic 220 | 221 | @property 222 | def datetime_last_update_position_dic(self): 223 | return self.trade_agent.datetime_last_update_position_dic 224 | 225 | @property 226 | def 
datetime_last_send_order_dic(self): 227 | return self.trade_agent.datetime_last_send_order_dic 228 | 229 | def get_balance(self, non_zero_only=True, trade_type_only=True, currency=None, force_refresh=False) -> dict: 230 | """ 231 | 调用接口 查询 各个币种仓位 232 | :param non_zero_only: 只保留非零币种 233 | :param trade_type_only: 只保留 trade 类型币种,frozen 类型的不保存 234 | :param currency: 只返回制定币种 usdt eth 等 235 | :param force_refresh: 强制刷新,默认没30秒允许重新查询一次 236 | :return: {'usdt': {: {'currency': 'usdt', 'type': 'trade', 'balance': 144.09238}}} 237 | """ 238 | return self.trade_agent.get_balance(non_zero_only, trade_type_only, currency, force_refresh) 239 | 240 | def get_holding_currency(self, force_refresh=False, exclude_usdt=True) -> dict: 241 | """ 242 | 持仓情况dict(非usdt),仅包含交易状态 type = 'trade' 的记录 243 | :param force_refresh: 244 | :param exclude_usdt: 默认为True,剔除 usdt 245 | :return: 246 | {'eos': {: {'currency': 'eos', 'type': 'trade', 'balance': 144.09238}}} 247 | """ 248 | cur_balance_dic = self.get_balance(non_zero_only=True, force_refresh=force_refresh) 249 | balance_dic = {} 250 | for currency, dic in cur_balance_dic.items(): 251 | if exclude_usdt and currency == 'usdt': 252 | continue 253 | for pos_date_type, dic_sub in dic.items(): 254 | if dic_sub['type'] != 'trade': 255 | continue 256 | balance_dic.setdefault(currency, {})[pos_date_type] = dic_sub 257 | return balance_dic 258 | 259 | 260 | class StgHandlerBase(Thread, ABC): 261 | logger = logging.getLogger("StgHandlerBase") 262 | 263 | @staticmethod 264 | def factory(stg_class_obj: StgBase.__class__, strategy_params, md_agent_params_list, run_mode: RunMode, **run_mode_params): 265 | """ 266 | 建立策略对象 267 | 建立数据库相应记录信息 268 | 根据运行模式(实时、回测):选择相应的md_agent以及trade_agent 269 | :param stg_class_obj: 策略类型 StgBase 的子类 270 | :param strategy_params: 策略参数 271 | :param md_agent_params_list: 行情代理(md_agent)参数,支持同时订阅多周期、多品种, 272 | 例如:同时订阅 [ethusdt, eosusdt] 1min 行情、[btcusdt, ethbtc] tick 行情 273 | :param run_mode: 运行模式 RunMode.Realtime 或 
RunMode.Backtest 274 | :param run_mode_params: 运行参数,回测模式下:运行起止时间,实时行情下:加载定时器等设置 275 | :return: 策略执行对象实力 276 | """ 277 | stg_run_info = StgRunInfo(stg_name=stg_class_obj.__name__, # '{.__name__}'.format(stg_class_obj) 278 | dt_from=datetime.now(), 279 | dt_to=None, 280 | stg_params=json.dumps(strategy_params), 281 | md_agent_params_list=json.dumps(md_agent_params_list), 282 | run_mode=int(run_mode), 283 | run_mode_params=json.dumps(run_mode_params)) 284 | with with_db_session(engine_abat) as session: 285 | session.add(stg_run_info) 286 | session.commit() 287 | stg_run_id = stg_run_info.stg_run_id 288 | # 设置运行模式:回测模式,实时模式。初始化交易接口 289 | # if run_mode == RunMode.Backtest: 290 | # trade_agent = BacktestTraderAgent(stg_run_id, run_mode_params) 291 | # elif run_mode == RunMode.Realtime: 292 | # trade_agent = RealTimeTraderAgent(stg_run_id, run_mode_params) 293 | # else: 294 | # raise ValueError('run_mode %d error' % run_mode) 295 | trade_agent_class = trader_agent_class_dic[run_mode] 296 | # 初始化策略实体,传入参数 297 | stg_base = stg_class_obj(**strategy_params) 298 | # 设置策略交易接口 trade_agent,这里不适用参数传递的方式而使用属性赋值, 299 | # 因为stg子类被继承后,参数主要用于设置策略所需各种参数使用 300 | stg_base.trade_agent = trade_agent_class(stg_run_id, run_mode_params) 301 | # 对不同周期设置相应的md_agent 302 | # 初始化各个周期的 md_agent 303 | md_period_agent_dic = {} 304 | for md_agent_param in md_agent_params_list: 305 | period = md_agent_param['md_period'] 306 | md_agent = MdAgentBase.factory(run_mode, **md_agent_param) 307 | md_period_agent_dic[period] = md_agent 308 | # 对各个周期分别加载历史数据,设置对应 handler 309 | # 通过 md_agent 加载各个周期的历史数据 310 | his_df_dic = md_agent.load_history() 311 | if his_df_dic is None: 312 | StgHandlerBase.logger.warning('加载 %s 历史数据为 None', period) 313 | continue 314 | if isinstance(his_df_dic, dict): 315 | md_df = his_df_dic['md_df'] 316 | else: 317 | md_df = his_df_dic 318 | warnings.warn('load_history 返回 df 数据格式即将废弃,请更新成 dict', DeprecationWarning) 319 | 320 | context = {ContextKey.instrument_id_list: 
list(md_agent.instrument_id_set)} 321 | stg_base.load_md_period_df(period, md_df, context) 322 | StgHandlerBase.logger.debug('加载 %s 历史数据 %s 条', period, 'None' if md_df is None else str(md_df.shape[0])) 323 | # 初始化 StgHandlerBase 实例 324 | if run_mode == RunMode.Realtime: 325 | stg_handler = StgHandlerRealtime(stg_run_id=stg_run_id, stg_base=stg_base, 326 | md_period_agent_dic=md_period_agent_dic, **run_mode_params) 327 | elif run_mode == RunMode.Backtest: 328 | stg_handler = StgHandlerBacktest(stg_run_id=stg_run_id, stg_base=stg_base, 329 | md_period_agent_dic=md_period_agent_dic, **run_mode_params) 330 | else: 331 | raise ValueError('run_mode %d error' % run_mode) 332 | StgHandlerBase.logger.debug('初始化 %r 完成', stg_handler) 333 | return stg_handler 334 | 335 | def __init__(self, stg_run_id, stg_base: StgBase, run_mode, md_period_agent_dic): 336 | super().__init__(daemon=True) 337 | self.stg_run_id = stg_run_id 338 | self.run_mode = run_mode 339 | # 初始化策略实体,传入参数 340 | self.stg_base = stg_base 341 | # 设置工作状态 342 | self.keep_running = None 343 | # 日志 344 | self.logger = logging.getLogger() 345 | # 对不同周期设置相应的md_agent 346 | self.md_period_agent_dic = md_period_agent_dic 347 | 348 | def stg_run_ending(self): 349 | """ 350 | 处理策略结束相关事项 351 | 释放策略资源 352 | 更新策略执行信息 353 | :return: 354 | """ 355 | self.stg_base.release() 356 | # 更新数据库 td_to 字段 357 | with with_db_session(engine_abat) as session: 358 | session.query(StgRunInfo).filter(StgRunInfo.stg_run_id == self.stg_run_id).update( 359 | {StgRunInfo.dt_to: datetime.now()}) 360 | # sql_str = StgRunInfo.update().where( 361 | # StgRunInfo.c.stg_run_id == self.stg_run_id).values(dt_to=datetime.now()) 362 | # session.execute(sql_str) 363 | session.commit() 364 | 365 | def __repr__(self): 366 | return '<{0.__class__.__name__}:{0.stg_run_id} {0.run_mode}>'.format(self) 367 | 368 | 369 | class StgHandlerRealtime(StgHandlerBase): 370 | 371 | def __init__(self, stg_run_id, stg_base: StgBase, md_period_agent_dic, **kwargs): 372 | 
super().__init__(stg_run_id=stg_run_id, stg_base=stg_base, run_mode=RunMode.Realtime, 373 | md_period_agent_dic=md_period_agent_dic) 374 | # 对不同周期设置相应的md_agent 375 | self.md_period_agent_dic = md_period_agent_dic 376 | # 设置线程池 377 | self.running_thread = {} 378 | # 日志 379 | self.logger = logging.getLogger() 380 | # 设置推送超时时间 381 | self.timeout_pull = 60 382 | # 设置独立的时间线程 383 | self.enable_timer_thread = kwargs.setdefault('enable_timer_thread', False) 384 | self.seconds_of_timer_interval = kwargs.setdefault('seconds_of_timer_interval', 9999) 385 | 386 | def run(self): 387 | 388 | # TODO: 以后再加锁,防止多线程,目前只是为了防止误操作导致的重复执行 389 | if self.keep_running: 390 | return 391 | else: 392 | self.keep_running = True 393 | 394 | try: 395 | # 策略初始化 396 | self.stg_base.init() 397 | # 对各个周期分别设置对应 handler 398 | for period, md_agent in self.md_period_agent_dic.items(): 399 | # 获取对应事件响应函数 400 | on_period_md_handler = self.stg_base.on_period_md_handler 401 | # 异步运行:每一个周期及其对应的 handler 作为一个线程独立运行 402 | thread_name = 'run_md_agent %s' % md_agent.name 403 | run_md_agent_thread = Thread(target=self.run_md_agent, name=thread_name, 404 | args=(md_agent, on_period_md_handler), daemon=True) 405 | self.running_thread[period] = run_md_agent_thread 406 | self.logger.info("加载 %s 线程", thread_name) 407 | run_md_agent_thread.start() 408 | 409 | if self.enable_timer_thread: 410 | thread_name = 'run_timer' 411 | timer_thread = Thread(target=self.run_timer, name=thread_name, daemon=True) 412 | self.logger.info("加载 %s 线程", thread_name) 413 | timer_thread.start() 414 | 415 | # 各个线程分别join等待结束信号 416 | for period, run_md_agent_thread in self.running_thread.items(): 417 | run_md_agent_thread.join() 418 | self.logger.info('%s period %s finished', run_md_agent_thread.name, period) 419 | finally: 420 | self.keep_running = False 421 | self.stg_run_ending() 422 | 423 | def run_timer(self): 424 | """ 425 | 负责定时运行策略对象的 on_timer 方法 426 | :return: 427 | """ 428 | while self.keep_running: 429 | try: 430 | 
self.stg_base.on_timer() 431 | except: 432 | self.logger.exception('on_timer 函数运行异常') 433 | finally: 434 | time.sleep(self.seconds_of_timer_interval) 435 | 436 | def run_md_agent(self, md_agent, handler): 437 | """ 438 | md_agent pull 方法的事件驱动处理函数 439 | :param md_agent: 440 | :param handler: self.stgbase对象的响应 md_agent 的梳理函数:根据不同的 md_period 可能是 on_tick、 on_min、 on_day、 on_week、 on_month 等其中一个 441 | :return: 442 | """ 443 | period = md_agent.md_period 444 | self.logger.info('启动 %s 行情监听线程', period) 445 | md_agent.connect() 446 | md_agent.subscribe() # 参数为空相当于 md_agent.subscribe(md_agent.instrument_id_list) 447 | md_agent.start() 448 | while self.keep_running: 449 | try: 450 | if not self.keep_running: 451 | break 452 | # 加载数据,是设置超时时间,防止长时间阻塞 453 | md_dic = md_agent.pull(self.timeout_pull) 454 | handler(period, md_dic) 455 | except Empty: 456 | # 工作状态检查 457 | pass 458 | except Exception: 459 | self.logger.exception('%s 事件处理句柄执行异常,对应行情数据md_dic:\n%s', 460 | period, md_dic) 461 | # time.sleep(1) 462 | md_agent.release() 463 | self.logger.info('period:%s finished', period) 464 | 465 | 466 | class StgHandlerBacktest(StgHandlerBase): 467 | 468 | def __init__(self, stg_run_id, stg_base: StgBase, md_period_agent_dic, date_from, date_to, **kwargs): 469 | super().__init__(stg_run_id=stg_run_id, stg_base=stg_base, run_mode=RunMode.Backtest, 470 | md_period_agent_dic=md_period_agent_dic) 471 | # 回测 ID 每一次测试生成一个新的ID,在数据库中作为本次测试的唯一标识 472 | # TODO: 这一ID需要从数据库生成 473 | # self.backtest_id = 1 474 | # self.stg_base._trade_agent.backtest_id = self.backtest_id 475 | # 设置回测时间区间 476 | self.date_from = try_2_date(date_from) 477 | self.date_to = try_2_date(date_to) 478 | if not isinstance(self.date_from, date): 479 | raise ValueError("date_from: %s", date_from) 480 | if not isinstance(self.date_to, date): 481 | raise ValueError("date_from: %s", date_to) 482 | # 初始资金账户金额 483 | self.init_cash = kwargs['init_cash'] 484 | # 载入回测时间段各个周期的历史数据,供回测使用 485 | # 对各个周期分别进行处理 486 | self.backtest_his_df_dic = 
{} 487 | for period, md_agent in self.md_period_agent_dic.items(): 488 | md_df = md_agent.load_history(date_from, date_to, load_md_count=0) 489 | if md_df is None: 490 | continue 491 | if isinstance(md_df, pd.DataFrame): 492 | # 对于 CTP 老程序接口直接返回的是 df,因此补充相关的 key 数据 493 | # TODO: 未来这部分代码将逐步给更替 494 | warnings.warn('load_history 需要返回 dict 类型数据, 对 DataFame 的数据处理即将废弃', DeprecationWarning) 495 | if period == PeriodType.Tick: 496 | his_df_dic = {'md_df': md_df, 497 | 'date_key': 'ActionDay', 'time_key': 'ActionTime', 498 | 'microseconds_key': 'ActionMillisec'} 499 | else: 500 | his_df_dic = {'md_df': md_df, 501 | 'date_key': 'ActionDay', 'time_key': 'ActionTime'} 502 | self.backtest_his_df_dic[period] = his_df_dic 503 | self.logger.debug('加载 %s 回测数据 %d 条记录', period, md_df.shape[0]) 504 | else: 505 | self.backtest_his_df_dic[period] = his_df_dic = md_df 506 | self.logger.debug('加载 %s 回测数据 %d 条记录', period, his_df_dic['md_df'].shape[0]) 507 | 508 | def run(self): 509 | """ 510 | 执行回测 511 | :return: 512 | """ 513 | # TODO: 以后再加锁,防止多线程,目前只是为了防止误操作导致的重复执行 514 | if self.keep_running: 515 | self.logger.warning('当前任务正在执行中..,避免重复执行') 516 | return 517 | else: 518 | self.keep_running = True 519 | self.logger.info('执行回测任务【%s - %s】开始', self.date_from, self.date_to) 520 | try: 521 | # 策略初始化 522 | self.stg_base.init() 523 | # 对每一个周期构建时间轴及对应记录的数组下标 524 | period_dt_idx_dic = {} 525 | for period, his_df_dic in self.backtest_his_df_dic.items(): 526 | his_df = his_df_dic['md_df'] 527 | datetime_s = his_df[his_df_dic['datetime_key']] if 'datetime_key' in his_df_dic else None 528 | date_s = his_df[his_df_dic['date_key']] if 'date_key' in his_df_dic else None 529 | time_s = his_df[his_df_dic['time_key']] if 'time_key' in his_df_dic else None 530 | microseconds_s = his_df[his_df_dic['microseconds_key']] if 'microseconds_key' in his_df_dic else None 531 | df_len = his_df.shape[0] 532 | # 整理日期轴 533 | dt_idx_dic = {} 534 | if datetime_s is not None: 535 | for idx in range(df_len): 536 | if 
microseconds_s: 537 | dt = datetime_s[idx] + timedelta(microseconds=int(microseconds_s[idx])) 538 | else: 539 | dt = datetime_s[idx] 540 | 541 | if dt in dt_idx_dic: 542 | dt_idx_dic[dt].append(idx) 543 | else: 544 | dt_idx_dic[dt] = [idx] 545 | elif date_s is not None and time_s is not None: 546 | for idx in range(df_len): 547 | # action_date = date_s[idx] 548 | # dt = datetime(action_date.year, action_date.month, action_date.day) + time_s[ 549 | # idx] + timedelta(microseconds=int(microseconds_s[idx])) 550 | if microseconds_s: 551 | dt = datetime.combine(date_s[idx], time_s[idx]) + timedelta( 552 | microseconds=int(microseconds_s[idx])) 553 | else: 554 | dt = datetime.combine(date_s[idx], time_s[idx]) 555 | 556 | if dt in dt_idx_dic: 557 | dt_idx_dic[dt].append(idx) 558 | else: 559 | dt_idx_dic[dt] = [idx] 560 | 561 | # action_day_s = his_df['ActionDay'] 562 | # action_time_s = his_df['ActionTime'] 563 | # # Tick 数据 存在 ActionMillisec 记录秒以下级别数据 564 | # if period == PeriodType.Tick: 565 | # action_milsec_s = his_df['ActionMillisec'] 566 | # dt_idx_dic = {} 567 | # for idx in range(df_len): 568 | # action_date = action_day_s[idx] 569 | # dt = datetime(action_date.year, action_date.month, action_date.day) + action_time_s[ 570 | # idx] + timedelta(microseconds=int(action_milsec_s[idx])) 571 | # if dt in dt_idx_dic: 572 | # dt_idx_dic[dt].append(idx) 573 | # else: 574 | # dt_idx_dic[dt] = [idx] 575 | # else: 576 | # dt_idx_dic = {} 577 | # for idx in range(df_len): 578 | # action_date = action_day_s[idx] 579 | # dt = datetime(action_date.year, action_date.month, action_date.day) + action_time_s[ 580 | # idx] 581 | # if dt in dt_idx_dic: 582 | # dt_idx_dic[dt].append(idx) 583 | # else: 584 | # dt_idx_dic[dt] = [idx] 585 | # 记录各个周期时间戳 586 | period_dt_idx_dic[period] = dt_idx_dic 587 | 588 | # 按照时间顺序将各个周期数据依次推入对应 handler 589 | period_idx_df = pd.DataFrame(period_dt_idx_dic).sort_index() 590 | for row_num in range(period_idx_df.shape[0]): 591 | period_idx_s = 
period_idx_df.ix[row_num, :] 592 | for period, idx_list in period_idx_s.items(): 593 | if all(np.isnan(idx_list)): 594 | continue 595 | his_df = self.backtest_his_df_dic[period]['md_df'] 596 | for idx_row in idx_list: 597 | # TODO: 这里存在着性能优化空间 DataFrame -> Series -> dict 效率太低 598 | md = his_df.ix[idx_row].to_dict() 599 | # 在回测阶段,需要对 trade_agent 设置最新的md数据,一遍交易接口确认相应的k线日期 600 | self.stg_base.trade_agent.set_curr_md(period, md) 601 | # 执行策略相应的事件响应函数 602 | self.stg_base.on_period_md_handler(period, md) 603 | # 根据最新的 md 及 持仓信息 更新 账户信息 604 | self.stg_base.trade_agent.update_account_info() 605 | self.logger.info('执行回测任务【%s - %s】完成', self.date_from, self.date_to) 606 | finally: 607 | self.keep_running = False 608 | self.stg_run_ending() 609 | 610 | 611 | class EventHandlersRelation: 612 | """ 613 | 用于记录事件类型与其对应的各种相关事件句柄之间的关系 614 | """ 615 | 616 | def __init__(self, period_type, prepare_event, md_event, param_type): 617 | self.period_type = period_type 618 | self.prepare_event = prepare_event 619 | self.md_event = md_event 620 | self.param_type = param_type 621 | 622 | 623 | class MACroseStg(StgBase): 624 | 625 | def __init__(self): 626 | super().__init__() 627 | self.ma5 = [] 628 | self.ma10 = [] 629 | 630 | def on_prepare_min1(self, md_df, context): 631 | if md_df: 632 | self.ma5 = list(md_df['close'].rolling(5, 5).mean())[10:] 633 | self.ma10 = list(md_df['close'].rolling(10, 10).mean())[10:] 634 | 635 | def on_min1(self, md_df, context): 636 | close = md_df['close'].iloc[-1] 637 | self.ma5.append(md_df['close'].iloc[-5:].mean()) 638 | self.ma10.append(md_df['close'].iloc[-10:].mean()) 639 | instrument_id = context[ContextKey.instrument_id_list][0] 640 | if self.ma5[-2] < self.ma10[-2] and self.ma5[-1] > self.ma10[-1]: 641 | position_date_pos_info_dic = self.get_position(instrument_id) 642 | no_target_position = True 643 | if position_date_pos_info_dic is not None: 644 | for position_date, pos_info in position_date_pos_info_dic.items(): 645 | direction = 
pos_info.direction 646 | if direction == Direction.Short: 647 | self.close_short(instrument_id, close, pos_info.position) 648 | elif direction == Direction.Long: 649 | no_target_position = False 650 | if no_target_position: 651 | self.open_long(instrument_id, close, 1) 652 | elif self.ma5[-2] > self.ma10[-2] and self.ma5[-1] < self.ma10[-1]: 653 | position_date_pos_info_dic = self.get_position(instrument_id) 654 | no_target_position = True 655 | if position_date_pos_info_dic is not None: 656 | for position_date, pos_info in position_date_pos_info_dic.items(): 657 | direction = pos_info.direction 658 | if direction == Direction.Long: 659 | self.close_long(instrument_id, close, pos_info.position) 660 | elif direction == Direction.Short: 661 | no_target_position = False 662 | if no_target_position: 663 | self.open_short(instrument_id, close, 1) 664 | 665 | 666 | if __name__ == '__main__': 667 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 668 | # 参数设置 669 | strategy_params = {} 670 | md_agent_params_list = [{ 671 | 'name': 'min1', 672 | 'md_period': PeriodType.Min1, 673 | 'instrument_id_list': ['ethbtc'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 674 | 'init_md_date_to': '2017-9-1', 675 | }] 676 | run_mode_realtime_params = { 677 | 'run_mode': RunMode.Realtime, 678 | } 679 | run_mode_backtest_params = { 680 | 'run_mode': RunMode.Backtest, 681 | 'date_from': '2018-6-18', 682 | 'date_to': '2018-6-19', 683 | 'init_cash': 1000000, 684 | 'trade_mode': BacktestTradeMode.Order_2_Deal 685 | } 686 | # run_mode = RunMode.BackTest 687 | # 初始化策略处理器 688 | stghandler = StgHandlerBase.factory(stg_class_obj=MACroseStg, 689 | strategy_params=strategy_params, 690 | md_agent_params_list=md_agent_params_list, 691 | **run_mode_backtest_params) 692 | stghandler.start() 693 | time.sleep(10) 694 | stghandler.keep_running = False 695 | stghandler.join() 696 | logging.info("执行结束") 697 | -------------------------------------------------------------------------------- 
/abat/trade.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:12 6 | @File : trade.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | import logging 11 | from abc import abstractmethod, ABC 12 | from datetime import datetime 13 | from abat.common import RunMode 14 | from abat.backend.orm import OrderInfo 15 | 16 | logger = logging.getLogger(__package__) 17 | 18 | 19 | class TraderAgent(ABC): 20 | """ 21 | 交易代理(抽象类),回测交易代理,实盘交易代理的父类 22 | """ 23 | 24 | def __init__(self, stg_run_id, run_mode_params: dict): 25 | """ 26 | stg_run_id 作为每一次独立的执行策略过程的唯一标识 27 | :param stg_run_id: 28 | """ 29 | self.stg_run_id = stg_run_id 30 | self.run_mode_params = run_mode_params 31 | self.logger = logging.getLogger(self.__class__.__name__) 32 | 33 | @abstractmethod 34 | def connect(self): 35 | raise NotImplementedError() 36 | 37 | @abstractmethod 38 | def open_long(self, instrument_id, price, vol): 39 | raise NotImplementedError() 40 | 41 | @abstractmethod 42 | def close_long(self, instrument_id, price, vol): 43 | raise NotImplementedError() 44 | 45 | @abstractmethod 46 | def open_short(self, instrument_id, price, vol): 47 | raise NotImplementedError() 48 | 49 | @abstractmethod 50 | def close_short(self, instrument_id, price, vol): 51 | raise NotImplementedError() 52 | 53 | @abstractmethod 54 | def get_position(self, instrument_id) -> dict: 55 | raise NotImplementedError() 56 | 57 | @abstractmethod 58 | def get_order(self, instrument_id) -> OrderInfo: 59 | raise NotImplementedError() 60 | 61 | @abstractmethod 62 | def release(self): 63 | raise NotImplementedError() 64 | 65 | @property 66 | @abstractmethod 67 | def datetime_last_update_position(self) -> datetime: 68 | raise NotImplementedError() 69 | 70 | @property 71 | @abstractmethod 72 | def datetime_last_rtn_trade_dic(self) -> dict: 73 | raise NotImplementedError() 74 | 75 | @property 76 
| @abstractmethod 77 | def datetime_last_update_position_dic(self) -> dict: 78 | raise NotImplementedError() 79 | 80 | @property 81 | @abstractmethod 82 | def datetime_last_send_order_dic(self) -> dict: 83 | raise NotImplementedError() 84 | 85 | @property 86 | @abstractmethod 87 | def get_balance(self) -> dict: 88 | raise NotImplementedError() 89 | 90 | 91 | trader_agent_class_dic = {RunMode.Backtest: TraderAgent, RunMode.Realtime: TraderAgent} 92 | 93 | 94 | def register_realtime_trader_agent(agent: TraderAgent) -> TraderAgent: 95 | trader_agent_class_dic[RunMode.Realtime] = agent 96 | logger.info('设置 realtime trade agent:%s', agent.__class__.__name__) 97 | return agent 98 | 99 | 100 | def register_backtest_trader_agent(agent: TraderAgent) -> TraderAgent: 101 | trader_agent_class_dic[RunMode.Backtest] = agent 102 | logger.info('设置 backtest trade agent:%s', agent.__class__.__name__) 103 | return agent 104 | -------------------------------------------------------------------------------- /abat/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 15:17 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | 11 | -------------------------------------------------------------------------------- /abat/utils/db_utils.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/4/8 21:11 6 | @File : db_utils.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 数据库相关工具 9 | """ 10 | from sqlalchemy.orm.session import Session 11 | from sqlalchemy.orm import sessionmaker 12 | from sqlalchemy.ext.declarative import DeclarativeMeta 13 | import json 14 | from datetime import date, datetime, timedelta 15 | from abat.utils.fh_utils import date_2_str 16 | from sqlalchemy.ext.compiler import compiles 17 | from sqlalchemy.sql.expression import Insert 18 | 19 | 20 | class SessionWrapper: 21 | """用于对session对象进行封装,方便使用with语句进行close控制""" 22 | 23 | def __init__(self, session): 24 | self.session = session 25 | 26 | def __enter__(self) -> Session: 27 | return self.session 28 | 29 | def __exit__(self, exc_type, exc_val, exc_tb): 30 | self.session.close() 31 | # logger.debug('db session closed') 32 | 33 | 34 | def with_db_session(engine, expire_on_commit=True): 35 | """创建session对象,返回 session_wrapper 可以使用with语句进行调用""" 36 | db_session = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) 37 | session = db_session() 38 | return SessionWrapper(session) 39 | 40 | 41 | def get_db_session(engine) -> Session: 42 | """创建session对象 使用后需注意主动关闭""" 43 | db_session = sessionmaker(bind=engine) 44 | session = db_session() 45 | return session 46 | 47 | 48 | class AlchemyEncoder(json.JSONEncoder): 49 | def default(self, obj): 50 | # print("obj.__class__", obj.__class__, "isinstance(obj.__class__, DeclarativeMeta)", isinstance(obj.__class__, DeclarativeMeta)) 51 | if isinstance(obj.__class__, DeclarativeMeta): 52 | # an SQLAlchemy class 53 | fields = {} 54 | for field in [x for x in dir(obj) if not x.startswith('_') and x != 'metadata']: 55 | data = obj.__getattribute__(field) 56 | try: 57 | json.dumps(data) # this will fail on non-encodable values, like other classes 58 | fields[field] = data 59 | except TypeError: # 添加了对datetime的处理 60 | print(data) 61 | if isinstance(data, datetime): 
62 | fields[field] = data.isoformat() 63 | elif isinstance(data, date): 64 | fields[field] = data.isoformat() 65 | elif isinstance(data, timedelta): 66 | fields[field] = (datetime.min + data).time().isoformat() 67 | else: 68 | fields[field] = None 69 | # a json-encodable dict 70 | return fields 71 | elif isinstance(obj, date): 72 | return json.dumps(date_2_str(obj)) 73 | 74 | return json.JSONEncoder.default(self, obj) 75 | 76 | 77 | @compiles(Insert) 78 | def append_string(insert, compiler, **kw): 79 | """ 80 | 支持 ON DUPLICATE KEY UPDATE 81 | 通过使用 on_duplicate_key_update=True 开启 82 | :param insert: 83 | :param compiler: 84 | :param kw: 85 | :return: 86 | """ 87 | s = compiler.visit_insert(insert, **kw) 88 | if insert.kwargs.get('on_duplicate_key_update'): 89 | fields = s[s.find("(") + 1:s.find(")")].replace(" ", "").split(",") 90 | generated_directive = ["{0}=VALUES({0})".format(field) for field in fields] 91 | return s + " ON DUPLICATE KEY UPDATE " + ",".join(generated_directive) 92 | return s -------------------------------------------------------------------------------- /abat/utils/redis.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/16 17:54 6 | @File : redis.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from redis import StrictRedis, ConnectionPool 11 | from redis.client import PubSub 12 | from config import Config 13 | from abat.common import PeriodType 14 | _redis_client_dic = {} 15 | 16 | 17 | def get_channel(market=None, period: PeriodType=PeriodType.Mon1, symbol=''): 18 | """ 19 | 'md.{market}.{period}.{symbol}' or 'md.{period}.{symbol}' 20 | :param market: 21 | :param period: 22 | :param symbol: 23 | :return: 24 | """ 25 | if market: 26 | channel_str = f'md.{market}.{period.name}.{symbol}' 27 | else: 28 | channel_str = f'md.{period.name}.{symbol}' 29 | # md.market.tick.pair 30 | return channel_str 31 | 32 | 33 | def get_redis(db=0) -> StrictRedis: 34 | """ 35 | get StrictRedis object 36 | :param db: 37 | :return: 38 | """ 39 | if db in _redis_client_dic: 40 | redis_client = _redis_client_dic[db] 41 | else: 42 | conn = ConnectionPool(host=Config.REDIS_INFO_DIC['REDIS_HOST'], 43 | port=Config.REDIS_INFO_DIC['REDIS_PORT'], 44 | db=db) 45 | redis_client = StrictRedis(connection_pool=conn) 46 | _redis_client_dic[db] = redis_client 47 | return redis_client -------------------------------------------------------------------------------- /agent/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 19:53 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | 11 | -------------------------------------------------------------------------------- /agent/md_agent.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/20 19:53 6 | @File : md_agent.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | import json 11 | import time 12 | from queue import Queue 13 | from abat.md import MdAgentBase, register_backtest_md_agent, register_realtime_md_agent 14 | from abat.utils.fh_utils import bytes_2_str 15 | from abat.common import PeriodType 16 | from backend import engine_md 17 | from abat.utils.redis import get_redis, get_channel 18 | from backend.orm import MDMin1 19 | from abat.utils.db_utils import with_db_session 20 | import pandas as pd 21 | from config import Config 22 | 23 | 24 | class MdAgentPub(MdAgentBase): 25 | 26 | def load_history(self, date_from=None, date_to=None, load_md_count=None)->(pd.DataFrame, dict): 27 | """ 28 | 从mysql中加载历史数据 29 | 实时行情推送时进行合并后供数据分析使用 30 | :param date_from: None代表沿用类的 init_md_date_from 属性 31 | :param date_to: None代表沿用类的 init_md_date_from 属性 32 | :param load_md_count: 0 代表不限制,None代表沿用类的 init_load_md_count 属性,其他数字代表相应的最大加载条数 33 | :return: md_df 或者 34 | ret_data { 35 | 'md_df': md_df, 'datetime_key': 'ts_start', 36 | 'date_key': **, 'time_key': **, 'microseconds_key': ** 37 | } 38 | """ 39 | # 如果 init_md_date_from 以及 init_md_date_to 为空,则不加载历史数据 40 | if self.init_md_date_from is None and self.init_md_date_to is None: 41 | ret_data = {'md_df': None, 'datetime_key': 'ts_start'} 42 | return ret_data 43 | 44 | if self.md_period == PeriodType.Tick: 45 | # sql_str = """SELECT * FROM md_tick 46 | # WHERE InstrumentID IN (%s) %s 47 | # ORDER BY ActionDay DESC, ActionTime DESC, ActionMillisec DESC %s""" 48 | raise ValueError("暂不支持 tick 级回测") 49 | elif self.md_period == PeriodType.Min1: 50 | # 将sql 语句形势改成由 sqlalchemy 进行sql 拼装方式 51 | # sql_str = """select * from md_min_1 52 | # where InstrumentID in ('j1801') and tradingday>='2017-08-14' 53 | # order by ActionDay, ActionTime, ActionMillisec limit 200""" 54 | # sql_str = """SELECT * FROM md_min_1 55 | # WHERE 
InstrumentID IN (%s) %s 56 | # ORDER BY ActionDay DESC, ActionTime DESC %s""" 57 | with with_db_session(engine_md) as session: 58 | query = session.query( 59 | MDMin1.symbol.label('pair'), MDMin1.ts_start.label('ts_start'), 60 | MDMin1.open.label('open'), MDMin1.high.label('high'), 61 | MDMin1.low.label('low'), MDMin1.close.label('close'), 62 | MDMin1.vol.label('vol'), MDMin1.amount.label('amount'), MDMin1.count.label('count') 63 | ).filter( 64 | MDMin1.symbol.in_(self.instrument_id_set) 65 | ).order_by(MDMin1.ts_start.desc()) 66 | # 设置参数 67 | params = list(self.instrument_id_set) 68 | # date_from 起始日期 69 | if date_from is None: 70 | date_from = self.init_md_date_from 71 | if date_from is not None: 72 | # qry_str_date_from = " and tradingday>='%s'" % date_from 73 | query = query.filter(MDMin1.ts_start >= date_from) 74 | params.append(date_from) 75 | # date_to 截止日期 76 | if date_to is None: 77 | date_to = self.init_md_date_to 78 | if date_to is not None: 79 | # qry_str_date_to = " and tradingday<='%s'" % date_to 80 | query = query.filter(MDMin1.ts_start <= date_to) 81 | params.append(date_to) 82 | 83 | # load_limit 最大记录数 84 | if load_md_count is None: 85 | load_md_count = self.init_load_md_count 86 | if load_md_count is not None and load_md_count > 0: 87 | qry_str_limit = " limit %d" % load_md_count 88 | query = query.limite(load_md_count) 89 | params.append(load_md_count) 90 | 91 | sql_str = str(query) 92 | else: 93 | raise ValueError('%s error' % self.md_period) 94 | 95 | # 合约列表 96 | # qry_str_inst_list = "'" + "', '".join(self.instrument_id_set) + "'" 97 | # 拼接sql 98 | # qry_sql_str = sql_str % (qry_str_inst_list, qry_str_date_from + qry_str_date_to, qry_str_limit) 99 | 100 | # 加载历史数据 101 | md_df = pd.read_sql(sql_str, engine_md, params=params) 102 | # self.md_df = md_df 103 | ret_data = {'md_df': md_df, 'datetime_key': 'ts_start'} 104 | return ret_data 105 | 106 | 107 | @register_realtime_md_agent 108 | class MdAgentRealtime(MdAgentPub): 109 | 110 | def 
__init__(self, instrument_id_set, md_period: PeriodType, name=None, init_load_md_count=None, 111 | init_md_date_from=None, init_md_date_to=None, **kwargs): 112 | super().__init__(instrument_id_set, md_period, name=name, init_load_md_count=init_load_md_count, 113 | init_md_date_from=init_md_date_from, init_md_date_to=init_md_date_to, **kwargs) 114 | self.pub_sub = None 115 | self.md_queue = Queue() 116 | 117 | def connect(self): 118 | """链接redis、初始化历史数据""" 119 | redis_client = get_redis() 120 | self.pub_sub = redis_client.pubsub() 121 | 122 | def release(self): 123 | """释放channel资源""" 124 | self.pub_sub.close() 125 | 126 | def subscribe(self, instrument_id_set=None): 127 | """订阅合约""" 128 | super().subscribe(instrument_id_set) 129 | if instrument_id_set is None: 130 | instrument_id_set = self.instrument_id_set 131 | # channel_head = Config.REDIS_CHANNEL[self.md_period] 132 | # channel_list = [channel_head + instrument_id for instrument_id in instrument_id_set] 133 | channel_list = [get_channel(Config.MARKET_NAME, self.md_period, instrument_id) 134 | for instrument_id in instrument_id_set] 135 | self.pub_sub.psubscribe(*channel_list) 136 | 137 | def run(self): 138 | """启动多线程获取MD""" 139 | if not self.keep_running: 140 | self.keep_running = True 141 | for item in self.pub_sub.listen(): 142 | if self.keep_running: 143 | if item['type'] == 'pmessage': 144 | # self.logger.debug("pmessage:", item) 145 | md_dic_str = bytes_2_str(item['data']) 146 | md_dic = json.loads(md_dic_str) 147 | self.md_queue.put(md_dic) 148 | else: 149 | self.logger.debug("%s response: %s", self.name, item) 150 | else: 151 | break 152 | 153 | def unsubscribe(self, instrument_id_set): 154 | """退订合约""" 155 | if instrument_id_set is None: 156 | tmp_set = self.instrument_id_set 157 | super().unsubscribe(instrument_id_set) 158 | instrument_id_set = tmp_set 159 | else: 160 | super().unsubscribe(instrument_id_set) 161 | 162 | # channel_head = Config.REDIS_CHANNEL[self.md_period] 163 | # channel_list = 
[channel_head + instrument_id for instrument_id in instrument_id_set] 164 | channel_list = [get_channel(Config.MARKET_NAME, self.md_period, instrument_id) 165 | for instrument_id in instrument_id_set] 166 | if self.pub_sub is not None: # 在回测模式下有可能不进行 connect 调用以及 subscribe 订阅,因此,没有 pub_sub 实例 167 | self.pub_sub.punsubscribe(*channel_list) 168 | 169 | def pull(self, timeout=None): 170 | """阻塞方式提取合约数据""" 171 | md = self.md_queue.get(block=True, timeout=timeout) 172 | self.md_queue.task_done() 173 | return md 174 | 175 | 176 | @register_backtest_md_agent 177 | class MdAgentBacktest(MdAgentPub): 178 | 179 | def __init__(self, instrument_id_set, md_period: PeriodType, name=None, init_load_md_count=None, 180 | init_md_date_from=None, init_md_date_to=None, **kwargs): 181 | super().__init__(instrument_id_set, md_period, name=name, init_load_md_count=init_load_md_count, 182 | init_md_date_from=init_md_date_from, init_md_date_to=init_md_date_to, **kwargs) 183 | self.timeout = 1 184 | 185 | def connect(self): 186 | """链接redis、初始化历史数据""" 187 | pass 188 | 189 | def release(self): 190 | """释放channel资源""" 191 | pass 192 | 193 | def run(self): 194 | """启动多线程获取MD""" 195 | if not self.keep_running: 196 | self.keep_running = True 197 | while self.keep_running: 198 | time.sleep(self.timeout) 199 | else: 200 | self.logger.info('%s job finished', self.name) -------------------------------------------------------------------------------- /agent/td_agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/10/3 4 | @author: MG 5 | """ 6 | import logging 7 | from collections import OrderedDict 8 | from datetime import datetime, timedelta, date 9 | from huobitrade import setKey 10 | from abat.backend.orm import OrderInfo, TradeInfo, PosStatusInfo, AccountStatusInfo 11 | from config import Config 12 | from abat.utils.db_utils import with_db_session 13 | from abat.backend import engine_abat 14 | from 
abat.common import Direction, Action, BacktestTradeMode, PositionDateType 15 | from abat.utils.fh_utils import try_n_times, ceil, floor 16 | from abat.trade import TraderAgent, register_backtest_trader_agent, register_realtime_trader_agent 17 | from huobitrade.service import HBRestAPI 18 | from backend import engine_md 19 | from backend.orm import SymbolPair 20 | from collections import defaultdict 21 | from enum import Enum 22 | 23 | logger = logging.getLogger() 24 | # 设置秘钥 25 | setKey(Config.EXCHANGE_ACCESS_KEY, Config.EXCHANGE_SECRET_KEY) 26 | 27 | 28 | class OrderType(Enum): 29 | """ 30 | buy-market:市价买, sell-market:市价卖, buy-limit:限价买, sell-limit:限价卖, buy-ioc:IOC买单, sell-ioc:IOC卖单 31 | """ 32 | buy_market = 'buy-market' 33 | sell_market = 'sell-market' 34 | buy_limit = 'buy-limit' 35 | sell_limit = 'sell-limit' 36 | buy_ioc = 'buy-ioc' 37 | sell_ioc = 'sell-ioc' 38 | 39 | 40 | @register_backtest_trader_agent 41 | class BacktestTraderAgent(TraderAgent): 42 | """ 43 | 供调用模拟交易接口使用 44 | """ 45 | 46 | def __init__(self, stg_run_id, run_mode_params: dict): 47 | super().__init__(stg_run_id, run_mode_params) 48 | # 标示 order 成交模式 49 | self.trade_mode = run_mode_params.setdefault('trade_mode', BacktestTradeMode.Order_2_Deal) 50 | # 账户初始资金 51 | self.init_cash = run_mode_params['init_cash'] 52 | # 用来标示当前md,一般执行买卖交易是,对时间,价格等信息进行记录 53 | self.curr_md_period_type = None 54 | self.curr_md = None 55 | # 用来保存历史的 order_info trade_info pos_status_info account_info 56 | self.order_info_list = [] 57 | self.trade_info_list = [] 58 | self.pos_status_info_dic = OrderedDict() 59 | self.account_info_list = [] 60 | # 持仓信息 初始化持仓状态字典,key为 instrument_id 61 | self._pos_status_info_dic = {} 62 | self._order_info_dic = {} 63 | # 账户信息 64 | self._account_status_info = None 65 | 66 | def set_curr_md(self, period_type, md): 67 | self.curr_md_period_type = period_type 68 | self.curr_md = md 69 | 70 | def connect(self): 71 | pass 72 | 73 | def _save_order_info(self, instrument_id, price: float, vol: 
int, direction: Direction, action: Action): 74 | order_date = self.curr_md['ts_start'].date() 75 | order_info = OrderInfo(stg_run_id=self.stg_run_id, 76 | order_date=order_date, 77 | order_time=self.curr_md['ts_start'].time(), 78 | order_millisec=0, 79 | direction=int(direction), 80 | action=int(action), 81 | instrument_id=instrument_id, 82 | order_price=float(price), 83 | order_vol=int(vol) 84 | ) 85 | if False: # 暂时不用 86 | with with_db_session(engine_abat, expire_on_commit=False) as session: 87 | session.add(order_info) 88 | session.commit() 89 | self.order_info_list.append(order_info) 90 | self._order_info_dic.setdefault(instrument_id, []).append(order_info) 91 | # 更新成交信息 92 | # Order_2_Deal 模式:下单即成交 93 | if self.trade_mode == BacktestTradeMode.Order_2_Deal: 94 | self._save_trade_info(order_info) 95 | 96 | def _save_trade_info(self, order_info: OrderInfo): 97 | """ 98 | 根据订单信息保存成交结果 99 | :param order_info: 100 | :return: 101 | """ 102 | trade_info = TradeInfo.create_by_order_info(order_info) 103 | self.trade_info_list.append(trade_info) 104 | # 更新持仓信息 105 | self._save_pos_status_info(trade_info) 106 | 107 | def _save_pos_status_info(self, trade_info: TradeInfo) -> AccountStatusInfo: 108 | """ 109 | 根据成交信息保存最新持仓信息 110 | :param trade_info: 111 | :return: 112 | """ 113 | instrument_id = trade_info.instrument_id 114 | if instrument_id in self._pos_status_info_dic: 115 | pos_status_info_last = self._pos_status_info_dic[instrument_id] 116 | pos_status_info = pos_status_info_last.update_by_trade_info(trade_info) 117 | else: 118 | pos_status_info = PosStatusInfo.create_by_trade_info(trade_info) 119 | # 更新 120 | trade_date, trade_time, trade_millisec = \ 121 | pos_status_info.trade_date, pos_status_info.trade_time, pos_status_info.trade_millisec 122 | self.pos_status_info_dic[(trade_date, trade_time, trade_millisec)] = pos_status_info 123 | self._pos_status_info_dic[instrument_id] = pos_status_info 124 | # self.c_save_acount_info(pos_status_info) 125 | 126 | def 
_create_account_status_info(self) -> AccountStatusInfo: 127 | stg_run_id, init_cash, md = self.stg_run_id, self.init_cash, self.curr_md 128 | trade_date = md['ts_start'].date() 129 | trade_time = md['ts_start'].time() 130 | trade_millisec = 0 131 | # trade_price = float(self.curr_md['close']) 132 | acc_status_info = AccountStatusInfo(stg_run_id=stg_run_id, 133 | trade_date=trade_date, 134 | trade_time=trade_time, 135 | trade_millisec=trade_millisec, 136 | available_cash=init_cash, 137 | balance_tot=init_cash, 138 | ) 139 | if Config.UPDATE_OR_INSERT_PER_ACTION: 140 | # 更新最新持仓纪录 141 | with with_db_session(engine_abat, expire_on_commit=False) as session: 142 | session.add(acc_status_info) 143 | session.commit() 144 | return acc_status_info 145 | 146 | def _update_by_pos_status_info(self) -> AccountStatusInfo: 147 | """根据 持仓列表更新账户信息""" 148 | 149 | pos_status_info_dic, md = self._pos_status_info_dic, self.curr_md 150 | 151 | account_status_info = self._account_status_info.create_by_self() 152 | # 上一次更新日期、时间 153 | # trade_date_last, trade_time_last, trade_millisec_last = \ 154 | # account_status_info.trade_date, account_status_info.trade_time, account_status_info.trade_millisec 155 | # 更新日期、时间 156 | trade_date = md['ts_start'].date() 157 | trade_time = md['ts_start'].time() 158 | trade_millisec = 0 159 | 160 | available_cash_chg = 0 161 | curr_margin = 0 162 | close_profit = 0 163 | position_profit = 0 164 | floating_pl_chg = 0 165 | margin_chg = 0 166 | floating_pl_cum = 0 167 | for instrument_id, pos_status_info in pos_status_info_dic.items(): 168 | curr_margin += pos_status_info.margin 169 | if pos_status_info.position == 0: 170 | close_profit += pos_status_info.floating_pl 171 | else: 172 | position_profit += pos_status_info.floating_pl 173 | floating_pl_chg += pos_status_info.floating_pl_chg 174 | margin_chg += pos_status_info.margin_chg 175 | floating_pl_cum += pos_status_info.floating_pl_cum 176 | 177 | available_cash_chg = floating_pl_chg - margin_chg 178 | 
account_status_info.curr_margin = curr_margin 179 | # # 对于同一时间,平仓后又开仓的情况,不能将close_profit重置为0 180 | # if trade_date == trade_date_last and trade_time == trade_time_last and trade_millisec == trade_millisec_last: 181 | # account_status_info.close_profit += close_profit 182 | # else: 183 | # 一个单位时段只允许一次,不需要考虑上面的情况 184 | account_status_info.close_profit = close_profit 185 | 186 | account_status_info.position_profit = position_profit 187 | account_status_info.available_cash += available_cash_chg 188 | account_status_info.floating_pl_cum = floating_pl_cum 189 | account_status_info.balance_tot = account_status_info.available_cash + curr_margin 190 | 191 | account_status_info.trade_date = trade_date 192 | account_status_info.trade_time = trade_time 193 | account_status_info.trade_millisec = trade_millisec 194 | if Config.UPDATE_OR_INSERT_PER_ACTION: 195 | # 更新最新持仓纪录 196 | with with_db_session(engine_abat, expire_on_commit=False) as session: 197 | session.add(account_status_info) 198 | session.commit() 199 | return account_status_info 200 | 201 | def _update_pos_status_info_by_md(self, pos_status_info_last) -> PosStatusInfo: 202 | """创建新的对象,根据 trade_info 更新相关信息""" 203 | md = self.curr_md 204 | trade_date = md['ts_start'].date() 205 | trade_time = md['ts_start'].time() 206 | trade_millisec = 0 207 | trade_price = float(md['close']) 208 | instrument_id = md['pair'] 209 | 210 | pos_status_info = pos_status_info_last.create_by_self() 211 | pos_status_info.cur_price = trade_price 212 | pos_status_info.trade_date = trade_date 213 | pos_status_info.trade_time = trade_time 214 | pos_status_info.trade_millisec = trade_millisec 215 | 216 | # 计算 floating_pl margin 217 | # instrument_info = Config.instrument_info_dic[instrument_id] 218 | # multiple = instrument_info['VolumeMultiple'] 219 | # margin_ratio = instrument_info['LongMarginRatio'] 220 | multiple, margin_ratio = 1, 1 221 | position = pos_status_info.position 222 | cur_price = pos_status_info.cur_price 223 | avg_price = 
pos_status_info.avg_price 224 | pos_status_info.margin = position * cur_price * multiple * margin_ratio 225 | pos_status_info.margin_chg = pos_status_info.margin - pos_status_info_last.margin 226 | if pos_status_info.direction == Direction.Long: 227 | pos_status_info.floating_pl = (cur_price - avg_price) * position * multiple 228 | else: 229 | pos_status_info.floating_pl = (avg_price - cur_price) * position * multiple 230 | pos_status_info.floating_pl_chg = pos_status_info.floating_pl - pos_status_info_last.floating_pl 231 | pos_status_info.floating_pl_cum += pos_status_info.floating_pl_chg 232 | 233 | if Config.UPDATE_OR_INSERT_PER_ACTION: 234 | # 更新最新持仓纪录 235 | with with_db_session(engine_abat, expire_on_commit=False) as session: 236 | session.add(pos_status_info) 237 | session.commit() 238 | return pos_status_info 239 | 240 | def update_account_info(self): 241 | """ 242 | 更新 持仓盈亏数据 汇总统计当前周期账户盈利情况 243 | :return: 244 | """ 245 | if self.curr_md is None: 246 | return 247 | if self._account_status_info is None: 248 | # self._account_status_info = AccountStatusInfo.create(self.stg_run_id, self.init_cash, self.curr_md) 249 | self._account_status_info = self._create_account_status_info() 250 | self.account_info_list.append(self._account_status_info) 251 | 252 | instrument_id = self.curr_md['pair'] 253 | if instrument_id in self._pos_status_info_dic: 254 | pos_status_info_last = self._pos_status_info_dic[instrument_id] 255 | trade_date = pos_status_info_last.trade_date 256 | trade_time = pos_status_info_last.trade_time 257 | # 如果当前K线以及更新则不需再次更新。如果当前K线以及有交易产生,则 pos_info 将会在 _save_pos_status_info 函数中被更新,因此无需再次更新 258 | if trade_date == self.curr_md['ts_start'].date() and trade_time == self.curr_md['ts_start'].time(): 259 | return 260 | # 说明上一根K线位置已经平仓,下一根K先位置将记录清除 261 | if pos_status_info_last.position == 0: 262 | del self._pos_status_info_dic[instrument_id] 263 | # 根据 md 数据更新 仓位信息 264 | # pos_status_info = pos_status_info_last.update_by_md(self.curr_md) 265 | 
pos_status_info = self._update_pos_status_info_by_md(pos_status_info_last) 266 | self._pos_status_info_dic[instrument_id] = pos_status_info 267 | 268 | # 统计账户信息,更新账户信息 269 | # account_status_info = self._account_status_info.update_by_pos_status_info( 270 | # self._pos_status_info_dic, self.curr_md) 271 | account_status_info = self._update_by_pos_status_info() 272 | self._account_status_info = account_status_info 273 | self.account_info_list.append(self._account_status_info) 274 | 275 | def open_long(self, instrument_id, price, vol): 276 | self._save_order_info(instrument_id, price, vol, Direction.Long, Action.Open) 277 | 278 | def close_long(self, instrument_id, price, vol): 279 | self._save_order_info(instrument_id, price, vol, Direction.Long, Action.Close) 280 | 281 | def open_short(self, instrument_id, price, vol): 282 | self._save_order_info(instrument_id, price, vol, Direction.Short, Action.Open) 283 | 284 | def close_short(self, instrument_id, price, vol): 285 | self._save_order_info(instrument_id, price, vol, Direction.Short, Action.Close) 286 | 287 | def get_position(self, instrument_id, **kwargs) -> dict: 288 | if instrument_id in self._pos_status_info_dic: 289 | pos_status_info = self._pos_status_info_dic[instrument_id] 290 | position_date_pos_info_dic = {PositionDateType.History: pos_status_info} 291 | else: 292 | position_date_pos_info_dic = None 293 | return position_date_pos_info_dic 294 | 295 | @property 296 | def datetime_last_update_position(self) -> datetime: 297 | return datetime.now() 298 | 299 | @property 300 | def datetime_last_rtn_trade_dic(self) -> dict: 301 | raise NotImplementedError() 302 | 303 | @property 304 | def datetime_last_update_position_dic(self) -> dict: 305 | raise NotImplementedError() 306 | 307 | @property 308 | def datetime_last_send_order_dic(self) -> dict: 309 | raise NotImplementedError() 310 | 311 | def release(self): 312 | try: 313 | with with_db_session(engine_abat) as session: 314 | 
session.add_all(self.order_info_list) 315 | session.add_all(self.trade_info_list) 316 | session.add_all(self.pos_status_info_dic.values()) 317 | session.add_all(self.account_info_list) 318 | session.commit() 319 | except: 320 | self.logger.exception("release exception") 321 | 322 | def get_order(self, instrument_id) -> OrderInfo: 323 | if instrument_id in self._order_info_dic: 324 | return self._order_info_dic[instrument_id] 325 | else: 326 | return None 327 | 328 | def get_balance(self): 329 | position_date_pos_info_dic = {key: {PositionDateType.History: pos_status_info} 330 | for key, pos_status_info in self._pos_status_info_dic.items()} 331 | return position_date_pos_info_dic 332 | 333 | 334 | @register_realtime_trader_agent 335 | class RealTimeTraderAgent(TraderAgent): 336 | """ 337 | 供调用实时交易接口使用 338 | """ 339 | 340 | def __init__(self, stg_run_id, run_mode_params: dict): 341 | super().__init__(stg_run_id, run_mode_params) 342 | self.trader_api = HBRestAPI() 343 | self.currency_balance_dic = {} 344 | self.currency_balance_last_get_datetime = None 345 | self.symbol_currency_dic = None 346 | self.symbol_precision_dic = None 347 | self._datetime_last_rtn_trade_dic = {} 348 | self._datetime_last_update_position_dic = {} 349 | 350 | def connect(self): 351 | with with_db_session(engine_md) as session: 352 | data = session.query(SymbolPair).all() 353 | self.symbol_currency_dic = { 354 | f'{sym.base_currency}{sym.quote_currency}': sym.base_currency 355 | for sym in data} 356 | self.symbol_precision_dic = { 357 | f'{sym.base_currency}{sym.quote_currency}': (int(sym.price_precision), int(sym.amount_precision)) 358 | for sym in data} 359 | 360 | # @try_n_times(times=3, sleep_time=2, logger=logger) 361 | def open_long(self, symbol, price, vol): 362 | """买入多头""" 363 | price_precision, amount_precision = self.symbol_precision_dic[symbol] 364 | if isinstance(price, float): 365 | price = format(price, f'.{price_precision}f') 366 | if isinstance(vol, float): 367 | if vol < 10 
** -amount_precision: 368 | logger.warning('%s open_long 订单量 %f 太小,忽略', symbol, vol) 369 | return 370 | vol = format(floor(vol, amount_precision), f'.{amount_precision}f') 371 | self.trader_api.send_order(vol, symbol, OrderType.buy_limit.value, price) 372 | self._datetime_last_rtn_trade_dic[symbol] = datetime.now() 373 | 374 | def close_long(self, symbol, price, vol): 375 | """卖出多头""" 376 | price_precision, amount_precision = self.symbol_precision_dic[symbol] 377 | if isinstance(price, float): 378 | price = format(price, f'.{price_precision}f') 379 | if isinstance(vol, float): 380 | if vol < 10 ** -amount_precision: 381 | logger.warning('%s close_long 订单量 %f 太小,忽略', symbol, vol) 382 | return 383 | vol = format(floor(vol, amount_precision), f'.{amount_precision}f') 384 | self.trader_api.send_order(vol, symbol, OrderType.sell_limit.value, price) 385 | self._datetime_last_rtn_trade_dic[symbol] = datetime.now() 386 | 387 | def open_short(self, instrument_id, price, vol): 388 | # self.trader_api.open_short(instrument_id, price, vol) 389 | raise NotImplementedError() 390 | 391 | def close_short(self, instrument_id, price, vol): 392 | # self.trader_api.close_short(instrument_id, price, vol) 393 | raise NotImplementedError() 394 | 395 | def get_position(self, instrument_id, force_refresh=False) -> dict: 396 | """ 397 | instrument_id(相当于 symbol ) 398 | symbol ethusdt, btcusdt 399 | currency eth, btc 400 | :param instrument_id: 401 | :param force_refresh: 402 | :return: 403 | """ 404 | symbol = instrument_id 405 | currency = self.get_currency(symbol) 406 | # currency = instrument_id 407 | # self.logger.debug('symbol:%s force_refresh=%s', symbol, force_refresh) 408 | position_date_inv_pos_dic = self.get_balance(currency=currency, force_refresh=force_refresh) 409 | return position_date_inv_pos_dic 410 | 411 | def get_currency(self, symbol): 412 | """ 413 | 根据 symbol 找到对应的 currency 414 | symbol: ethusdt, btcusdt 415 | currency: eth, btc 416 | :param symbol: 417 | :return: 418 | 
""" 419 | return self.symbol_currency_dic[symbol] 420 | 421 | def get_balance(self, non_zero_only=False, trade_type_only=True, currency=None, force_refresh=False): 422 | """ 423 | 调用接口 查询 各个币种仓位 424 | :param non_zero_only: 只保留非零币种 425 | :param trade_type_only: 只保留 trade 类型币种,frozen 类型的不保存 426 | :param currency: 只返回制定币种 usdt eth 等 427 | :param force_refresh: 强制刷新,默认没30秒允许重新查询一次 428 | :return: {'usdt': {: {'currency': 'usdt', 'type': 'trade', 'balance': 144.09238}}} 429 | """ 430 | if force_refresh or self.currency_balance_last_get_datetime is None or \ 431 | self.currency_balance_last_get_datetime < datetime.now() - timedelta(seconds=30): 432 | ret_data = self.trader_api.get_balance() 433 | acc_balance = ret_data['data']['list'] 434 | self.logger.debug('更新持仓数据: %d 条', len(acc_balance)) 435 | acc_balance_new_dic = defaultdict(dict) 436 | for balance_dic in acc_balance: 437 | currency_curr = balance_dic['currency'] 438 | self._datetime_last_update_position_dic[currency_curr] = datetime.now() 439 | 440 | if non_zero_only and balance_dic['balance'] == '0': 441 | continue 442 | 443 | if trade_type_only and balance_dic['type'] != 'trade': 444 | continue 445 | balance_dic['balance'] = float(balance_dic['balance']) 446 | # self.logger.debug(balance_dic) 447 | if PositionDateType.History in acc_balance_new_dic[currency_curr]: 448 | balance_dic_old = acc_balance_new_dic[currency_curr][PositionDateType.History] 449 | balance_dic_old['balance'] += balance_dic['balance'] 450 | # TODO: 日后可以考虑将 PositionDateType.History 替换为 type 451 | acc_balance_new_dic[currency_curr][PositionDateType.History] = balance_dic 452 | else: 453 | acc_balance_new_dic[currency_curr] = {PositionDateType.History: balance_dic} 454 | 455 | self.currency_balance_dic = acc_balance_new_dic 456 | self.currency_balance_last_get_datetime = datetime.now() 457 | 458 | if currency is not None: 459 | if currency in self.currency_balance_dic: 460 | ret_data = self.currency_balance_dic[currency] 461 | # for 
position_date_type, data in self.currency_balance_dic[currency].items(): 462 | # if data['currency'] == currency: 463 | # ret_data = data 464 | # break 465 | else: 466 | ret_data = None 467 | else: 468 | ret_data = self.currency_balance_dic 469 | return ret_data 470 | 471 | @property 472 | def datetime_last_update_position(self) -> datetime: 473 | return self.currency_balance_last_get_datetime 474 | 475 | @property 476 | def datetime_last_rtn_trade_dic(self) -> dict: 477 | return self._datetime_last_rtn_trade_dic 478 | 479 | @property 480 | def datetime_last_update_position_dic(self) -> dict: 481 | return self._datetime_last_update_position_dic 482 | 483 | @property 484 | def datetime_last_send_order_dic(self) -> dict: 485 | raise NotImplementedError() 486 | 487 | def get_order(self, instrument_id, states='submitted') -> list: 488 | """ 489 | 490 | :param instrument_id: 491 | :param states: 492 | :return: 格式如下: 493 | [{'id': 603164274, 'symbol': 'ethusdt', 'account-id': 909325, 'amount': '4.134700000000000000', 494 | 'price': '983.150000000000000000', 'created-at': 1515166787246, 'type': 'buy-limit', 495 | 'field-amount': '4.134700000000000000', 'field-cash-amount': '4065.030305000000000000', 496 | 'field-fees': '0.008269400000000000', 'finished-at': 1515166795669, 'source': 'web', 497 | 'state': 'filled', 'canceled-at': 0}, 498 | ... 
] 499 | """ 500 | symbol = instrument_id 501 | ret_data = self.trader_api.get_orders_info(symbol=symbol, states=states) 502 | return ret_data['data'] 503 | 504 | def cancel_order(self, instrument_id): 505 | symbol = instrument_id 506 | order_list = self.get_order(symbol) 507 | order_id_list = [data['id'] for data in order_list] 508 | return self.trader_api.batchcancel_order(order_id_list) 509 | 510 | def release(self): 511 | pass 512 | 513 | 514 | if __name__ == "__main__": 515 | import time 516 | 517 | # 测试交易 下单接口及撤单接口 518 | # symbol, vol, price = 'ocnusdt', 1, 0.00004611 # OCN/USDT 519 | symbol, vol, price = 'eosusdt', 1.0251, 4.1234 # OCN/USDT 520 | 521 | td = RealTimeTraderAgent(stg_run_id=1, run_mode_params={}) 522 | td.open_long(symbol=symbol, price=price, vol=vol) 523 | order_dic_list = td.get_order(instrument_id=symbol) 524 | print('after open_long', order_dic_list) 525 | assert len(order_dic_list) == 1 526 | td.cancel_order(instrument_id=symbol) 527 | time.sleep(1) 528 | order_dic_list = td.get_order(instrument_id=symbol) 529 | print('after cancel', order_dic_list) 530 | assert len(order_dic_list) == 0 531 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/1/14 4 | @author: MG 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /analysis/account.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2018/1/14 4 | @author: MG 5 | """ 6 | import matplotlib.pyplot as plt 7 | from abat.utils.fh_utils import return_risk_analysis 8 | import pandas as pd 9 | import numpy as np 10 | import logging 11 | from abat.backend.orm import AccountStatusInfo, StgRunInfo 12 | from abat.backend import engine_abat 13 | from abat.utils.db_utils import 
with_db_session 14 | from sqlalchemy import func, or_, and_, column, not_ 15 | 16 | logger = logging.getLogger() 17 | 18 | 19 | def get_account_balance(stg_run_id): 20 | """ 21 | 获取 account_info 账户走势数据 22 | :param stg_run_id: 23 | :return: 24 | """ 25 | with with_db_session(engine_abat) as session: 26 | sql_str = str( 27 | session.query( 28 | func.concat(AccountStatusInfo.trade_date, ' ', AccountStatusInfo.trade_time).label('trade_datetime'), 29 | AccountStatusInfo.available_cash.label('available_cash'), 30 | AccountStatusInfo.curr_margin.label('curr_margin'), 31 | AccountStatusInfo.balance_tot.label('balance_tot') 32 | ).filter(AccountStatusInfo.stg_run_id == stg_run_id).order_by( 33 | AccountStatusInfo.trade_date, AccountStatusInfo.trade_time 34 | ) 35 | ) 36 | # sql_str = """SELECT concat(trade_date, " ", trade_time) trade_datetime, available_cash, curr_margin, balance_tot 37 | # FROM account_status_info where stg_run_id=%s order by trade_date, trade_time""" 38 | data_df = pd.read_sql(sql_str, engine_abat, params=[' ', stg_run_id]) 39 | data_df["return_rate"] = (data_df["balance_tot"].pct_change().fillna(0) + 1).cumprod() 40 | data_df = data_df.set_index("trade_datetime") 41 | return data_df 42 | 43 | 44 | if __name__ == "__main__": 45 | with with_db_session(engine_abat) as session: 46 | # stg_run_id = session.execute("select max(stg_run_id) from stg_run_info").fetchone()[0] 47 | stg_run_id = session.query(func.max(StgRunInfo.stg_run_id)).scalar() 48 | # stg_run_id = 2 49 | data_df = get_account_balance(stg_run_id) 50 | 51 | logger.info("\n%s", data_df) 52 | data_df.plot(ylim=[min(data_df["available_cash"]), max(data_df["balance_tot"])]) 53 | data_df.plot(ylim=[min(data_df["curr_margin"]), max(data_df["curr_margin"])]) 54 | stat_df = return_risk_analysis(data_df[['return_rate']], freq=None) 55 | logger.info("\n%s", stat_df) 56 | plt.show() 57 | -------------------------------------------------------------------------------- /backend/__init__.py: 
-------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/12 13:47 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from sqlalchemy import create_engine 11 | from config import Config 12 | 13 | engines = {key: create_engine(url) for key, url in Config.DB_URL_DIC.items()} 14 | 15 | engine_md = engines[Config.DB_SCHEMA_MD] 16 | -------------------------------------------------------------------------------- /backend/check.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/6 9:57 6 | @File : check.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 用于对系统配置的环境进行检测,检查是否环境可用,包括mysql、redis等 9 | """ 10 | from config import Config 11 | from abat.common import PeriodType 12 | import threading 13 | import json 14 | import time 15 | import logging 16 | from abat.utils.fh_utils import bytes_2_str 17 | from abat.utils.redis import get_redis, get_channel 18 | logger = logging.getLogger() 19 | _signal = {} 20 | 21 | 22 | def _timer(channel): 23 | global _signal 24 | count = 0 25 | r = get_redis() 26 | while not _signal['redis']: 27 | md_str = json.dumps({"message": "Hello World!!", "count": count}) 28 | r.publish(channel, md_str) 29 | logger.debug("发布成功 %s", md_str) 30 | count += 1 31 | if count >= 15: 32 | break 33 | time.sleep(1) 34 | 35 | 36 | def check_redis(): 37 | global _signal 38 | # channel_header = Config.REDIS_CHANNEL[PeriodType.Tick] 39 | instrument_id = 'rb1805' 40 | # channel = channel_header + 'test.' 
+ instrument_id 41 | channel = get_channel('huobi', PeriodType.Year1, instrument_id) 42 | _signal['redis'] = False 43 | 44 | timer_t = threading.Thread(target=_timer, args=(channel,)) 45 | timer_t.start() 46 | 47 | def _receiver(channel): 48 | # 接收订阅的行情,成功接收后退出 49 | global _signal 50 | redis_client = get_redis() 51 | pub_sub = redis_client.pubsub() 52 | pub_sub.psubscribe(channel) 53 | for item in pub_sub.listen(): 54 | logger.debug("接收成功 %s", item) 55 | if item['type'] == 'pmessage': 56 | md_dic_str = bytes_2_str(item['data']) 57 | md_dic = json.loads(md_dic_str) 58 | if "message" in md_dic and "count" in md_dic: 59 | _signal['redis'] = True 60 | logger.debug("接收到消息") 61 | break 62 | 63 | receiver_t = threading.Thread(target=_receiver, args=(channel,)) 64 | receiver_t.start() 65 | 66 | for n in range(20): 67 | if _signal['redis']: 68 | logging.debug("检测redis %d %s", n, _signal['redis']) 69 | timer_t.join(1) 70 | break 71 | time.sleep(1) 72 | else: 73 | logger.error("redis 检测未通过") 74 | 75 | return _signal['redis'] 76 | 77 | 78 | def check(): 79 | ok_list = [] 80 | is_ok = check_redis() 81 | ok_list.append(is_ok) 82 | if is_ok: 83 | logger.info("redis 检测成功") 84 | 85 | return all(ok_list) 86 | 87 | 88 | if __name__ == "__main__": 89 | is_ok = check() 90 | logger.info("全部检测完成,%s", is_ok) 91 | -------------------------------------------------------------------------------- /backend/orm.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/12 13:02 6 | @File : orm.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | from sqlalchemy import Column, Integer, String, UniqueConstraint, TIMESTAMP 11 | from sqlalchemy.dialects.mysql import DOUBLE 12 | from sqlalchemy.ext.declarative import declarative_base 13 | from abat.utils.db_utils import with_db_session 14 | from backend import engine_md 15 | from config import Config 16 | import logging 17 | logger = logging.getLogger() 18 | BaseModel = declarative_base() 19 | 20 | 21 | class SymbolPair(BaseModel): 22 | __tablename__ = 'symbol_pair_info' 23 | id = Column(Integer, autoincrement=True, unique=True) 24 | market = Column(String(10), primary_key=True) 25 | base_currency = Column(String(10), primary_key=True) 26 | quote_currency = Column(String(10), primary_key=True) 27 | price_precision = Column(Integer) 28 | amount_precision = Column(Integer) 29 | symbol_partition = Column(String(12)) 30 | __table_args__ = ( 31 | UniqueConstraint('base_currency', 'quote_currency'), 32 | ) 33 | 34 | 35 | class MDTick(BaseModel): 36 | __tablename__ = 'md_min1_tick_bc' 37 | id = Column(Integer, autoincrement=True, unique=True) 38 | market = Column(String(10), primary_key=True) 39 | symbol = Column(String(10), primary_key=True) 40 | ts_start = Column(TIMESTAMP) 41 | ts_curr = Column(TIMESTAMP, primary_key=True) 42 | open = Column(DOUBLE) 43 | high = Column(DOUBLE) 44 | low = Column(DOUBLE) 45 | close = Column(DOUBLE) 46 | amount = Column(DOUBLE) 47 | vol = Column(DOUBLE) 48 | count = Column(DOUBLE) 49 | 50 | 51 | class MDMin1(BaseModel): 52 | __tablename__ = 'md_min1_bc' 53 | market = Column(String(10), primary_key=True) 54 | symbol = Column(String(10), primary_key=True) 55 | ts_start = Column(TIMESTAMP, primary_key=True) 56 | ts_curr = Column(TIMESTAMP) 57 | open = Column(DOUBLE) 58 | high = Column(DOUBLE) 59 | low = Column(DOUBLE) 60 | close = Column(DOUBLE) 61 | amount = 
Column(DOUBLE) 62 | vol = Column(DOUBLE) 63 | count = Column(DOUBLE) 64 | 65 | 66 | class MDMin1Temp(BaseModel): 67 | __tablename__ = 'md_min1_bc_temp' 68 | market = Column(String(10), primary_key=True) 69 | symbol = Column(String(10), primary_key=True) 70 | ts_start = Column(TIMESTAMP, primary_key=True) 71 | ts_curr = Column(TIMESTAMP) 72 | open = Column(DOUBLE) 73 | high = Column(DOUBLE) 74 | low = Column(DOUBLE) 75 | close = Column(DOUBLE) 76 | amount = Column(DOUBLE) 77 | vol = Column(DOUBLE) 78 | count = Column(DOUBLE) 79 | 80 | 81 | class MDMin60(BaseModel): 82 | __tablename__ = 'md_min60_bc' 83 | market = Column(String(10), primary_key=True) 84 | symbol = Column(String(10), primary_key=True) 85 | ts_start = Column(TIMESTAMP, primary_key=True) 86 | ts_curr = Column(TIMESTAMP) 87 | open = Column(DOUBLE) 88 | high = Column(DOUBLE) 89 | low = Column(DOUBLE) 90 | close = Column(DOUBLE) 91 | amount = Column(DOUBLE) 92 | vol = Column(DOUBLE) 93 | count = Column(DOUBLE) 94 | 95 | 96 | class MDMin60Temp(BaseModel): 97 | __tablename__ = 'md_min60_bc_temp' 98 | market = Column(String(10), primary_key=True) 99 | symbol = Column(String(10), primary_key=True) 100 | ts_start = Column(TIMESTAMP, primary_key=True) 101 | ts_curr = Column(TIMESTAMP) 102 | open = Column(DOUBLE) 103 | high = Column(DOUBLE) 104 | low = Column(DOUBLE) 105 | close = Column(DOUBLE) 106 | amount = Column(DOUBLE) 107 | vol = Column(DOUBLE) 108 | count = Column(DOUBLE) 109 | 110 | 111 | class MDMinDaily(BaseModel): 112 | __tablename__ = 'md_daily_bc' 113 | market = Column(String(10), primary_key=True) 114 | symbol = Column(String(10), primary_key=True) 115 | ts_start = Column(TIMESTAMP, primary_key=True) 116 | ts_curr = Column(TIMESTAMP) 117 | open = Column(DOUBLE) 118 | high = Column(DOUBLE) 119 | low = Column(DOUBLE) 120 | close = Column(DOUBLE) 121 | amount = Column(DOUBLE) 122 | vol = Column(DOUBLE) 123 | count = Column(DOUBLE) 124 | 125 | 126 | class MDMinDailyTemp(BaseModel): 127 | __tablename__ 
= 'md_daily_bc_temp' 128 | market = Column(String(10), primary_key=True) 129 | symbol = Column(String(10), primary_key=True) 130 | ts_start = Column(TIMESTAMP, primary_key=True) 131 | ts_curr = Column(TIMESTAMP) 132 | open = Column(DOUBLE) 133 | high = Column(DOUBLE) 134 | low = Column(DOUBLE) 135 | close = Column(DOUBLE) 136 | amount = Column(DOUBLE) 137 | vol = Column(DOUBLE) 138 | count = Column(DOUBLE) 139 | 140 | 141 | def init(alter_table=False): 142 | BaseModel.metadata.create_all(engine_md) 143 | if alter_table: 144 | with with_db_session(engine=engine_md) as session: 145 | for table_name, _ in BaseModel.metadata.tables.items(): 146 | sql_str = f"show table status from {Config.DB_SCHEMA_MD} where name=:table_name" 147 | row_data = session.execute(sql_str, params={'table_name': table_name}).first() 148 | if row_data is None: 149 | continue 150 | if row_data[1].lower() == 'myisam': 151 | continue 152 | 153 | logger.info('修改 %s 表引擎为 MyISAM', table_name) 154 | sql_str = "ALTER TABLE %s ENGINE = MyISAM" % table_name 155 | session.execute(sql_str) 156 | 157 | # This is an issue https://www.mail-archive.com/sqlalchemy@googlegroups.com/msg19744.html 158 | session.execute(f"ALTER TABLE {SymbolPair.__tablename__} CHANGE COLUMN `id` `id` INT(11) NULL AUTO_INCREMENT") 159 | session.commit() 160 | # This is an issue https://www.mail-archive.com/sqlalchemy@googlegroups.com/msg19744.html 161 | session.execute(f"ALTER TABLE {MDTick.__tablename__} CHANGE COLUMN `id` `id` INT(11) NULL AUTO_INCREMENT") 162 | session.commit() 163 | 164 | logger.info("所有表结构建立完成") 165 | 166 | 167 | if __name__ == "__main__": 168 | init() 169 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/6/9 4 | @author: MG 5 | """ 6 | import logging 7 | from logging.handlers import RotatingFileHandler 8 | 9 | logger = 
logging.getLogger() 10 | 11 | 12 | class ConfigBase: 13 | # 交易所名称 14 | MARKET_NAME = 'huobi' 15 | 16 | # api configuration 17 | EXCHANGE_ACCESS_KEY = "" 18 | EXCHANGE_SECRET_KEY = "" 19 | 20 | # mysql db info 21 | DB_SCHEMA_ABAT = 'abat' 22 | DB_SCHEMA_MD = 'bc_md' 23 | DB_URL_DIC = { 24 | DB_SCHEMA_MD: 'mysql://mg:****10.0.3.66/' + DB_SCHEMA_MD, 25 | DB_SCHEMA_ABAT: 'mysql://mg:****@10.0.3.66/' + DB_SCHEMA_ABAT, 26 | } 27 | 28 | # redis info 29 | REDIS_INFO_DIC = {'REDIS_HOST': '192.168.239.131', 30 | 'REDIS_PORT': '6379', 31 | } 32 | 33 | # evn configuration 34 | LOG_FORMAT = '%(asctime)s %(levelname)s %(name)s %(filename)s.%(funcName)s:%(lineno)d|%(message)s' 35 | 36 | # 每一次实务均产生数据库插入或更新动作(默认:否) 37 | UPDATE_OR_INSERT_PER_ACTION = False 38 | 39 | 40 | class ConfigProduct(ConfigBase): 41 | # 测试子账户 key 42 | EXCHANGE_ACCESS_KEY = '***' 43 | EXCHANGE_SECRET_KEY = '***' 44 | 45 | DB_URL_DIC = { 46 | ConfigBase.DB_SCHEMA_MD: 'mysql://mg:***@10.0.3.66/' + ConfigBase.DB_SCHEMA_MD, 47 | ConfigBase.DB_SCHEMA_ABAT: 'mysql://mg:***@10.0.3.66/' + ConfigBase.DB_SCHEMA_ABAT, 48 | } 49 | 50 | 51 | # 开发配置(SIMNOW MD + Trade) 52 | # Config = ConfigBase() 53 | # 测试配置(测试行情库) 54 | # Config = ConfigTest() 55 | # 生产配置 56 | Config = ConfigProduct() 57 | 58 | # 设定默认日志格式 59 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 60 | # 设置rest调用日志输出级别 61 | logging.getLogger('requests.packages.urllib3.connectionpool').setLevel(logging.WARNING) 62 | logging.getLogger('urllib3.connectionpool').setLevel(logging.WARNING) 63 | logging.getLogger('EventAgent').setLevel(logging.INFO) 64 | logging.getLogger('StgBase').setLevel(logging.INFO) 65 | 66 | # 配置文件日至 67 | Rthandler = RotatingFileHandler('log.log', maxBytes=10 * 1024 * 1024, backupCount=5) 68 | Rthandler.setLevel(logging.INFO) 69 | formatter = logging.Formatter(Config.LOG_FORMAT) 70 | Rthandler.setFormatter(formatter) 71 | logging.getLogger('').addHandler(Rthandler) 72 | 
-------------------------------------------------------------------------------- /mass/alipay_code200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmmaaaggg/IBATS_HuobiTrader_old/d271e896332116fe43356391acdcff9725e7cb34/mass/alipay_code200.png -------------------------------------------------------------------------------- /mass/dashang_code200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmmaaaggg/IBATS_HuobiTrader_old/d271e896332116fe43356391acdcff9725e7cb34/mass/dashang_code200.png -------------------------------------------------------------------------------- /mass/webchat_code200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmmaaaggg/IBATS_HuobiTrader_old/d271e896332116fe43356391acdcff9725e7cb34/mass/webchat_code200.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2018.4.16 2 | chardet==3.0.4 3 | huobitrade==0.1.9 4 | idna==2.6 5 | mysqlclient==1.3.12 6 | numpy==1.14.4 7 | pandas==0.23.0 8 | prodconpattern==0.1.1 9 | pymongo==3.6.1 10 | python-dateutil==2.7.3 11 | pytz==2018.4 12 | pyzmq==17.0.0 13 | redis==2.10.6 14 | requests>=2.20.0 15 | six==1.11.0 16 | SQLAlchemy==1.2.8 17 | urllib3==1.22 18 | websocket-client==0.48.0 19 | xlrd==1.1.0 20 | click>=7.0 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/9/28 8:42 6 | @File : run.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | import click 11 | import logging 12 | import time 13 | from abat.common import PeriodType, RunMode, BacktestTradeMode 14 | from abat.strategy import StgHandlerBase 15 | from strategy.bs_against_files.csv_orders_with_feedback import ReadFileStg 16 | from strategy.simple_strategy import MACroseStg 17 | 18 | logger = logging.getLogger() 19 | 20 | strategy_list = [ReadFileStg] # , MACroseStg 21 | promt_str = '输入对应数字选择执行策略:\n' + \ 22 | '\n'.join(['%d) %s' % (num, foo.__name__) for num, foo in enumerate(strategy_list)]) + '\n' 23 | 24 | 25 | @click.command() 26 | @click.option('--num', type=click.IntRange(0, len(strategy_list) - 1), prompt=promt_str) 27 | @click.option('--init', type=click.BOOL, default=False) 28 | def main(num, init): 29 | if init: 30 | from abat.backend.orm import init 31 | init() 32 | 33 | stg_func = strategy_list[num] 34 | 35 | DEBUG = False 36 | # 参数设置 37 | strategy_params = {} 38 | md_agent_params_list = [ 39 | # { 40 | # 'name': 'min1', 41 | # 'md_period': PeriodType.Min1, 42 | # 'instrument_id_list': ['rb1805', 'i1801'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 43 | # 'init_md_date_to': '2017-9-1', 44 | # 'dict_or_df_as_param': dict 45 | # }, 46 | { 47 | 'name': 'tick', 48 | 'md_period': PeriodType.Tick, 49 | 'instrument_id_list': ['ethusdt', 'eosusdt'], # 50 | }] 51 | run_mode_realtime_params = { 52 | 'run_mode': RunMode.Realtime, 53 | 'enable_timer_thread': True, 54 | 'seconds_of_timer_interval': 15, 55 | } 56 | run_mode_backtest_params = { 57 | 'run_mode': RunMode.Backtest, 58 | 'date_from': '2017-9-4', 59 | 'date_to': '2017-9-27', 60 | 'init_cash': 1000000, 61 | 'trade_mode': BacktestTradeMode.Order_2_Deal 62 | } 63 | # run_mode = RunMode.BackTest 64 | # 初始化策略处理器 65 | stghandler = StgHandlerBase.factory( 66 | stg_class_obj=stg_func, 67 | 
strategy_params=strategy_params, 68 | md_agent_params_list=md_agent_params_list, 69 | **run_mode_realtime_params) 70 | 71 | if DEBUG: 72 | stghandler.run() 73 | else: 74 | # 开始执行策略 75 | stghandler.start() 76 | try: 77 | while True: 78 | # Idle (polling every 2 s) until Ctrl+C raises KeyboardInterrupt; there is no automatic 2-minute shutdown 79 | time.sleep(2) 80 | except KeyboardInterrupt: 81 | logger.warning('程序中断中...') 82 | 83 | stghandler.keep_running = False 84 | stghandler.join(timeout=2) 85 | 86 | logger.info("执行结束") 87 | 88 | 89 | if __name__ == "__main__": 90 | 91 | main(standalone_mode=False) 92 | -------------------------------------------------------------------------------- /run_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | screen_name="worker" 3 | echo create $screen_name 4 | screen -dmS $screen_name 5 | screen -x -S $screen_name -p 0 -X stuff "cd /home/mushrooman/wspy/run_sh\n" 6 | screen -x -S $screen_name -p 0 -X stuff "./worker.sh\n" 7 | 8 | screen_name="beat" 9 | echo create $screen_name 10 | screen -dmS $screen_name 11 | screen -x -S $screen_name -p 0 -X stuff "cd /home/mushrooman/wspy/run_sh\n" 12 | screen -x -S $screen_name -p 0 -X stuff "./beat.sh\n" 13 | 14 | screen_name="feeder" 15 | echo create $screen_name 16 | screen -dmS $screen_name 17 | screen -x -S $screen_name -p 0 -X stuff "cd /home/mushrooman/wspy/run_sh\n" 18 | screen -x -S $screen_name -p 0 -X stuff "./feeder.sh\n" 19 | 20 | screen_name="trader" 21 | echo create $screen_name 22 | screen -dmS $screen_name 23 | screen -x -S $screen_name -p 0 -X stuff "cd /home/mushrooman/wspy/run_sh\n" 24 | screen -x -S $screen_name -p 0 -X stuff "./trader.sh\n" 25 | 26 | screen -ls 27 | -------------------------------------------------------------------------------- /strategy/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/11/18 4 | @author: MG 5 | """ 6 | 7 | 
-------------------------------------------------------------------------------- /strategy/bs_against_files/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/10/9 8:27 6 | @File : __init__.py.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | 11 | -------------------------------------------------------------------------------- /strategy/bs_against_files/csv_orders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author : MG 4 | @Time : 2017/11/18 5 | @author : MG 6 | @desc : 接受文件订单 7 | 追踪tick级行情实时成交(仅适合小资金) 8 | 追踪tick级行情固定点位止损 9 | 目前仅支持做多,不支持做空 10 | 11 | 每15秒进行一次文件检查 12 | 文件格式(csv xls),每个symbol一行,不可重复,例:卖出eth 买入eos, 不考虑套利的情况(套利需要单独开发其他策略) 13 | 显示如下(非文件格式,csv文件以‘,’为分隔符,这个近为视觉好看,以表格形式显示): 14 | currency symbol weight stop_loss_rate 15 | eth ethusdt 0.5 0.3 16 | eos eosusdt 0.5 0.4 17 | """ 18 | import threading 19 | import time 20 | import logging 21 | import pandas as pd 22 | from datetime import datetime, timedelta 23 | from abat.strategy import StgBase, StgHandlerBase 24 | from config import Config 25 | from abat.common import PeriodType, RunMode, BacktestTradeMode, Direction, PositionDateType 26 | from collections import defaultdict 27 | import os 28 | # 下面代码是必要的引用 29 | # md_agent md_agent 并没有“显式”的被使用,但是在被引用期间,已经将相应的 agent 类注册到了相应的列表中 30 | import agent.md_agent 31 | import agent.td_agent 32 | 33 | DEBUG = False 34 | 35 | 36 | class TargetPosition: 37 | 38 | def __init__(self, direction, currency, position, symbol, 39 | price=None, stop_loss_price=None, has_stop_loss=False, gap_threshold_vol=None): 40 | self.direction = direction 41 | self.currency = currency 42 | self.position = position 43 | self.symbol = symbol 44 | self.price = price 45 | self.stop_loss_price = stop_loss_price 46 | self.has_stop_loss = has_stop_loss 47 | 
self.gap_threshold_vol = gap_threshold_vol 48 | 49 | def check_stop_loss(self, close): 50 | """ 51 | 根据当前价格计算是否已经到达止损点位 52 | 如果此前已经到达过止损点位则不再比较,也不需重置状态 53 | :param close: 54 | :return: 55 | """ 56 | # 如果此前已经到达过止损点位则不再比较,也不需重置状态 57 | if self.stop_loss_price is None or self.has_stop_loss: 58 | return 59 | self.has_stop_loss = (self.direction == Direction.Long and close < self.stop_loss_price) or ( 60 | self.direction == Direction.Short and close > self.stop_loss_price) 61 | if self.has_stop_loss: 62 | logging.warning('%s 处于止损状态。止损价格 %f 当前价格 %f', self.symbol, self.stop_loss_price, close) 63 | 64 | def get_target_position(self): 65 | return self.direction, self.currency, self.position, self.symbol, \ 66 | self.price, self.stop_loss_price, self.has_stop_loss, \ 67 | self.gap_threshold_vol 68 | 69 | 70 | class ReadFileStg(StgBase): 71 | _folder_path = os.path.abspath(os.path.join(os.path.curdir, r'file_order')) 72 | 73 | def __init__(self): 74 | super().__init__() 75 | self._mutex = threading.Lock() 76 | self._last_check_datetime = datetime.now() - timedelta(minutes=1) 77 | self.interval_timedelta = timedelta(seconds=15) 78 | self.symbol_target_position_dic = {} 79 | # 设定相应周期的事件驱动句柄 接收的参数类型 80 | self._on_period_event_dic[PeriodType.Tick].param_type = dict 81 | # 记录合约最近一次执行操作的时间 82 | self.symbol_last_deal_datetime = {} 83 | # 记录合约最近一个发送买卖请求的时间 84 | self.instrument_lastest_order_datetime_dic = {} 85 | # 目前由于交易是异步执行,在尚未记录每一笔订单的情况下,时间太短可能会导致仓位与请求但出现不同步现象,导致下单过多的问题 86 | self.timedelta_between_deal = timedelta(seconds=3) 87 | self.min_order_vol = 0.1 88 | self.symbol_latest_price_dic = defaultdict(float) 89 | 90 | def fetch_pos_by_file(self): 91 | """读取仓位配置csv文件,返回目标仓位DataFrame""" 92 | # 检查最近一次文件检查的时间,避免重复查询 93 | if self._last_check_datetime + self.interval_timedelta > datetime.now(): 94 | return 95 | # 获取文件列表 96 | file_name_list = os.listdir(self._folder_path) 97 | if file_name_list is None: 98 | # self.logger.info('No file') 99 | return 100 | # 读取所有 csv 文件 101 | position_df 
= None 102 | file_path_list = [] 103 | for file_name in file_name_list: 104 | file_base_name, file_extension = os.path.splitext(file_name) 105 | if file_extension.lower() != '.csv': 106 | continue 107 | file_path = os.path.join(self._folder_path, file_name) 108 | file_path_list.append(file_path) 109 | position_df_tmp = pd.read_csv(file_path) 110 | if position_df is None: 111 | position_df = position_df_tmp 112 | else: 113 | is_ok = True 114 | for col_name in ('currency', 'symbol', 'weight', 'stop_loss_price'): 115 | if col_name not in position_df_tmp.columns: 116 | is_ok = False 117 | self.logger.error('%s 文件格式不正确,缺少 %s 列数据', file_name, col_name) 118 | break 119 | 120 | if not is_ok: 121 | continue 122 | position_df = position_df.append(position_df_tmp) 123 | 124 | # 调试阶段暂时不重命名备份,不影响程序使用 125 | if not DEBUG: 126 | # 文件备份 127 | backup_file_name = file_base_name + datetime.now().strftime( 128 | '%Y-%m-%d %H_%M_%S') + file_extension + '.bak' 129 | os.rename(file_path, os.path.join(self._folder_path, backup_file_name)) 130 | 131 | return position_df, file_path_list 132 | 133 | def on_timer(self): 134 | """ 135 | 每15秒进行一次文件检查 136 | 获得目标持仓currency, 权重,止损点位 137 | 生成相应交易指令 138 | :param md_df: 139 | :param context: 140 | :return: 141 | """ 142 | with self._mutex: 143 | position_df, file_path_list = self.fetch_pos_by_file() 144 | if position_df is None or position_df.shape[0] == 0: 145 | return 146 | self.logger.debug('仓位调整目标:\n%s', position_df) 147 | target_holding_dic = position_df.set_index('currency').dropna().to_dict('index') 148 | if len(self.symbol_latest_price_dic) == 0: 149 | self.logger.warning('当前程序没有缓存到有效的最新价格数据,交易指令暂缓执行') 150 | return 151 | 152 | # {currency: (Direction, currency, target_position, symbol, target_price, stop_loss_price) 153 | symbol_target_position_dic = {} 154 | # 检查目标仓位与当前持仓是否相符,否则执行相应交易 155 | target_currency_set = set(list(position_df['currency'])) 156 | holding_currency_dic = self.get_holding_currency() 157 | # 检查是否所有持仓符合目标配置文件要求 158 | 
is_all_fit_target = True 159 | 160 | # 如果当前 currency 不在目标持仓列表里面,则卖出 161 | for num, (currency, balance_dic) in enumerate(holding_currency_dic.items(), start=1): 162 | # currency 在目标持仓中,无需清仓 163 | if currency in target_currency_set: 164 | continue 165 | # hc 为 货币交易所的一种手续费代币工具,不做交易使用 166 | if currency == 'hc': 167 | continue 168 | # 若持仓余额 小于 0.0001 则放弃清仓 169 | tot_balance = 0 170 | for _, dic in balance_dic.items(): 171 | tot_balance += dic['balance'] 172 | if tot_balance < 0.0001: 173 | continue 174 | 175 | symbol = self.get_symbol_by_currency(currency) 176 | symbol_target_position_dic[symbol] = TargetPosition(Direction.Long, currency, 0, symbol) 177 | is_all_fit_target = False 178 | 179 | # 生成目标持仓列表买入指令 180 | for num, (currency, position_dic) in enumerate(target_holding_dic.items()): 181 | weight = position_dic['weight'] 182 | stop_loss_price = position_dic['stop_loss_price'] 183 | symbol = self.get_symbol_by_currency(currency) 184 | target_vol, gap_threshold_vol = self.calc_vol(symbol, weight) 185 | if target_vol is None: 186 | self.logger.warning('%s 持仓权重 %.2f %% 无法计算目标持仓量', currency, weight * 100) 187 | continue 188 | # 检查当前持仓是否与目标持仓一致,如果一致则跳过 189 | # position_date_pos_info_dic = self.get_position(symbol) 190 | # if position_date_pos_info_dic is not None and len(position_date_pos_info_dic) > 0: 191 | # # 有持仓,比较是否满足目标仓位,否则下指令 192 | # position_cur = sum([pos_info['balance'] for pos_info in position_date_pos_info_dic.values()]) 193 | # position_gap = target_vol - position_cur 194 | # # 实盘情况下,很少绝对一致,在一定区间内即可 195 | # if position_gap > gap_threshold_vol: 196 | # # 当前合约累计持仓与目标持仓不一致,则添加目标持仓任务 197 | # is_all_fit_target = False 198 | # else: 199 | # is_all_fit_target = False 200 | # 无论仓位是否存在,均生成交易指令,待交易执行阶段进行比较(以上代码不影响是否生产建仓指令) 201 | 202 | # 多头目标持仓 203 | symbol_target_position_dic[symbol] = TargetPosition(Direction.Long, currency, target_vol, symbol, 204 | None, stop_loss_price, 205 | gap_threshold_vol=gap_threshold_vol) 206 | 207 | # if is_all_fit_target: 208 | # # 文件备份 
file_path_list 209 | # for file_path in file_path_list: 210 | # file_base_name_with_path, file_extension = os.path.split(file_path) 211 | # backup_file_path = file_base_name_with_path + datetime.now().strftime( 212 | # '%Y-%m-%d %H_%M_%S') + file_extension + '.bak' 213 | # # 调试阶段暂时不重命名备份,不影响程序使用 214 | # os.rename(file_path, 215 | # os.path.join( 216 | # file_base_name_with_path, 217 | # datetime.now().strftime('%Y-%m-%d %H_%M_%S') + file_extension + '.bak')) 218 | # self.logger.info('备份仓位配置文件:%s -> %s', file_path, backup_file_path) 219 | # el 220 | if len(symbol_target_position_dic) > 0: 221 | self.symbol_target_position_dic = symbol_target_position_dic 222 | self.logger.info('发现新的目标持仓指令\n%s', symbol_target_position_dic) 223 | else: 224 | self.symbol_target_position_dic = None 225 | self.logger.debug('无仓位调整指令') 226 | 227 | def do_order(self, md_dic, instrument_id, order_vol, price=None, direction=Direction.Long, stop_loss_price=0, 228 | msg=""): 229 | # if True: 230 | # self.logger.info("%s %s %f 价格 %f [%s]", 231 | # instrument_id, '买入' if position > 0 else '卖出', position, price, msg) 232 | # return 233 | # position == 0 则代表无需操作 234 | # 执行交易 235 | if direction == Direction.Long: 236 | if order_vol == 0: 237 | return 238 | elif order_vol > 0: 239 | if price is None or price == 0: 240 | price = md_dic['close'] 241 | # TODO: 稍后按盘口卖一档价格挂单 242 | 243 | if DEBUG: 244 | # debug 模式下,价格不要真实成交,只要看一下有委托单就可以了 245 | price /= 2 246 | 247 | if stop_loss_price is not None and stop_loss_price > 0 and price <= stop_loss_price: 248 | self.logger.warning('已经出发止损价 %.6f 停止买入操作', stop_loss_price) 249 | return 250 | 251 | self.open_long(instrument_id, price, order_vol) 252 | self.logger.info("%s %s -> 开多 %.4f 价格:%.4f", instrument_id, msg, order_vol, price) 253 | elif order_vol < 0: 254 | if price is None or price == 0: 255 | price = md_dic['close'] 256 | # TODO: 稍后按盘口卖一档价格挂单 257 | 258 | if DEBUG: 259 | # debug 模式下,价格不要真实成交,只要看一下有委托单就可以了 260 | price += price 261 | 262 | order_vol_net = 
-order_vol 263 | self.close_long(instrument_id, price, order_vol_net) 264 | self.logger.info("%s %s -> 平多 %.4f 价格:%.4f", instrument_id, msg, order_vol_net, price) 265 | else: 266 | raise ValueError('目前不支持做空') 267 | self.instrument_lastest_order_datetime_dic[instrument_id] = datetime.now() 268 | 269 | def on_tick(self, md_dic, context): 270 | """ 271 | tick 级数据进行交易操作 272 | :param md_dic: 273 | :param context: 274 | :return: 275 | """ 276 | # self.logger.debug('get tick data: %s', md_dic) 277 | symbol = md_dic['symbol'] 278 | # 更新最新价格 279 | close_cur = md_dic['close'] 280 | self.symbol_latest_price_dic[symbol] = close_cur 281 | # 计算是否需要进行调仓操作 282 | if self.symbol_target_position_dic is None or symbol not in self.symbol_target_position_dic: 283 | return 284 | if self.datetime_last_update_position is None: 285 | logging.debug("尚未获取持仓数据,跳过") 286 | return 287 | 288 | target_currency = self.trade_agent.get_currency(symbol) 289 | # self.logger.debug('target_position_dic[%s]: %s', symbol, self.target_position_dic[symbol]) 290 | # 如果的当前合约近期存在交易回报,则交易回报时间一定要小于查询持仓时间: 291 | # 防止出现以及成交单持仓信息未及时更新导致的数据不同步问题 292 | if symbol in self.datetime_last_rtn_trade_dic: 293 | if target_currency not in self.datetime_last_update_position_dic: 294 | logging.debug("持仓数据中没有包含当前合约,最近一次成交回报时间:%s,跳过", 295 | self.datetime_last_rtn_trade_dic[symbol]) 296 | self.get_position(symbol, force_refresh=True) 297 | return 298 | if self.datetime_last_rtn_trade_dic[symbol] > self.datetime_last_update_position_dic[target_currency]: 299 | logging.debug("持仓数据尚未更新完成,最近一次成交回报时间:%s 晚于 最近一次持仓更新时间:%s", 300 | self.datetime_last_rtn_trade_dic[symbol], 301 | self.datetime_last_update_position_dic[target_currency]) 302 | self.get_position(symbol, force_refresh=True) 303 | return 304 | 305 | # 过于密集执行可能会导致重复下单的问题 306 | if symbol in self.symbol_last_deal_datetime: 307 | last_deal_datetime = self.symbol_last_deal_datetime[symbol] 308 | if last_deal_datetime + self.timedelta_between_deal > datetime.now(): 309 | # 
logging.debug("最近一次交易时间:%s,防止交易密度过大,跳过", last_deal_datetime) 310 | return 311 | 312 | with self._mutex: 313 | target_position = self.symbol_target_position_dic[symbol] 314 | target_position.check_stop_loss(close_cur) 315 | 316 | # 撤销所有相关订单 317 | self.cancel_order(symbol) 318 | 319 | # 计算目标仓位方向及交易数量 320 | position_date_pos_info_dic = self.get_position(symbol) 321 | if position_date_pos_info_dic is None: 322 | # 无当前持仓,有目标仓位,直接按照目标仓位进行开仓动作 323 | # target_direction, target_currency, target_position, symbol, target_price, \ 324 | # stop_loss_price, has_stop_loss, gap_threshold_vol = self.get_target_position(symbol) 325 | if not target_position.has_stop_loss: 326 | self.do_order(md_dic, symbol, target_position.position, target_position.price, 327 | target_position.direction, target_position.stop_loss_price, msg='当前无持仓') 328 | else: 329 | # 如果当前有持仓,执行两类动作: 330 | # 1)若 当前持仓与目标持仓不匹配,则进行相应的调仓操作 331 | # 2)若 当前持仓价格超出止损价位,则进行清仓操作 332 | 333 | position_holding = sum( 334 | [pos_info_dic['balance'] for pos_info_dic in position_date_pos_info_dic.values()]) 335 | self.logger.debug('当前 %s 持仓 %f', target_position.currency, position_holding) 336 | # 比较当前持仓总量与目标仓位是否一致 337 | # 如果当前有持仓,目标仓位也有持仓,则需要进一步比对 338 | # target_direction, target_currency, target_position, symbol, target_price, \ 339 | # stop_loss_price, has_stop_loss, gap_threshold_vol = self.get_target_position(symbol) 340 | if target_position.has_stop_loss: 341 | # 已经触发止损,如果依然有持仓,则进行持续清仓操作 342 | self.do_order(md_dic, symbol, -position_holding, None, 343 | target_position.direction, msg="止损") 344 | else: 345 | # 汇总全部同方向持仓,如果不够目标仓位,则加仓 346 | # 对全部的反方向持仓进行平仓 347 | 348 | # 如果持仓超过目标仓位,则平仓多出的部分,如果不足则补充多的部分 349 | position_gap = target_position.position - position_holding 350 | if position_gap > target_position.gap_threshold_vol: 351 | if position_holding < target_position.gap_threshold_vol: 352 | msg = '建仓' 353 | else: 354 | msg = "补充仓位" 355 | # 如果不足则补充多的部分 356 | self.do_order(md_dic, symbol, position_gap, target_position.price, 357 | 
target_position.direction, target_position.stop_loss_price, msg=msg) 358 | elif position_gap < - target_position.gap_threshold_vol: 359 | if target_position.position == 0: 360 | msg = '清仓' 361 | else: 362 | msg = "持仓超过目标仓位,减仓 %.4f" % position_gap 363 | # 如果持仓超过目标仓位,则平仓多出的部分 364 | self.do_order(md_dic, symbol, position_gap, target_position.price, 365 | target_position.direction, target_position.stop_loss_price, msg=msg) 366 | else: 367 | self.logger.debug('当前持仓 %f 与目标持仓%f 差距 %f 过小,忽略此调整', 368 | position_holding, target_position.position, position_gap) 369 | 370 | # 更新最近执行时间 371 | self.symbol_last_deal_datetime[symbol] = datetime.now() 372 | 373 | def get_symbol_by_currency(self, currency): 374 | """目前暂时仅支持currency 与 usdt 之间转换""" 375 | return currency + 'usdt' 376 | 377 | def calc_vol(self, symbol, weight, gap_threshold_precision=0.01): 378 | """ 379 | 根据权重及当前账号总市值,计算当前 symbol 对应多少 vol 380 | :param symbol: 381 | :param weight: 382 | :return: 383 | """ 384 | holding_currency_dic = self.get_holding_currency(exclude_usdt=False) 385 | # tot_value = sum([dic['balance'] * self.symbol_latest_price_dic[self.get_symbol_by_currency(currency)] 386 | # for currency, dic in holding_currency_dic.items()]) 387 | if symbol not in self.symbol_latest_price_dic or self.symbol_latest_price_dic[symbol] == 0: 388 | self.logger.error('%s 没有找到有效的最新价格', symbol) 389 | weight_vol = None 390 | gap_threshold_vol = None 391 | else: 392 | tot_value = 0 393 | for currency, dic in holding_currency_dic.items(): 394 | for pos_date_type, dic_sub in dic.items(): 395 | if currency == 'usdt': 396 | tot_value += dic_sub['balance'] 397 | else: 398 | tot_value += dic_sub['balance'] * self.symbol_latest_price_dic[ 399 | self.get_symbol_by_currency(currency)] 400 | 401 | weight_vol = tot_value * weight / self.symbol_latest_price_dic[symbol] 402 | gap_threshold_vol = tot_value * gap_threshold_precision / self.symbol_latest_price_dic[symbol] 403 | 404 | return weight_vol, gap_threshold_vol 405 | 406 | def 
get_target_position(self, symbol): 407 | dic = self.symbol_target_position_dic[symbol] 408 | return dic['direction'], dic['currency'], dic['position'], dic['symbol'], \ 409 | dic['price'], dic['stop_loss_price'], dic.setdefault('has_stop_loss', False), \ 410 | dic.setdefault('gap_threshold_vol', None) 411 | 412 | 413 | if __name__ == '__main__': 414 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 415 | DEBUG = False 416 | # 参数设置 417 | strategy_params = {} 418 | md_agent_params_list = [ 419 | # { 420 | # 'name': 'min1', 421 | # 'md_period': PeriodType.Min1, 422 | # 'instrument_id_list': ['rb1805', 'i1801'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 423 | # 'init_md_date_to': '2017-9-1', 424 | # 'dict_or_df_as_param': dict 425 | # }, 426 | { 427 | 'name': 'tick', 428 | 'md_period': PeriodType.Tick, 429 | 'instrument_id_list': ['ethusdt', 'eosusdt'], # 430 | }] 431 | run_mode_realtime_params = { 432 | 'run_mode': RunMode.Realtime, 433 | 'enable_timer_thread': True, 434 | 'seconds_of_timer_interval': 15, 435 | } 436 | run_mode_backtest_params = { 437 | 'run_mode': RunMode.Backtest, 438 | 'date_from': '2017-9-4', 439 | 'date_to': '2017-9-27', 440 | 'init_cash': 1000000, 441 | 'trade_mode': BacktestTradeMode.Order_2_Deal 442 | } 443 | # run_mode = RunMode.BackTest 444 | # 初始化策略处理器 445 | stghandler = StgHandlerBase.factory( 446 | stg_class_obj=ReadFileStg, 447 | strategy_params=strategy_params, 448 | md_agent_params_list=md_agent_params_list, 449 | **run_mode_realtime_params) 450 | if DEBUG: 451 | stghandler.run() 452 | else: 453 | # 开始执行策略 454 | stghandler.start() 455 | # 策略执行 2 分钟后关闭 456 | time.sleep(120) 457 | stghandler.keep_running = False 458 | stghandler.join() 459 | 460 | logging.info("执行结束") 461 | # print(os.path.abspath(r'..\file_order')) 462 | -------------------------------------------------------------------------------- /strategy/bs_against_files/csv_orders_with_feedback.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author : MG 4 | @Time : 2017/11/18 5 | @author : MG 6 | @desc : 监控csv文件(每15秒) 7 | 对比当前日期与回测csv文件中最新记录是否匹配,存在最新交易请求,生成交易请求(order*.csv文件) 8 | 追踪tick级行情实时成交(仅适合小资金) 9 | 追踪tick级行情本金买入点 * N % 止损 10 | 目前仅支持做多,不支持做空 11 | 12 | 每15秒进行一次文件检查 13 | 文件格式(csv xls),每个symbol一行,不可重复,例:卖出eth 买入eos, 不考虑套利的情况(套利需要单独开发其他策略) 14 | 显示如下(非文件格式,csv文件以‘,’为分隔符,这个近为视觉好看,以表格形式显示): 15 | currency symbol weight stop_loss_rate 16 | eth ethusdt 0.5 0.3 17 | eos eosusdt 0.5 0.4 18 | """ 19 | import re 20 | import threading 21 | import time 22 | import logging 23 | import pandas as pd 24 | from datetime import datetime, timedelta, date 25 | from abat.strategy import StgBase, StgHandlerBase 26 | from abat.utils.fh_utils import str_2_date, get_folder_path 27 | from config import Config 28 | from abat.common import PeriodType, RunMode, BacktestTradeMode, Direction, PositionDateType 29 | from collections import defaultdict 30 | from backend import engine_md 31 | from backend.orm import SymbolPair 32 | from abat.utils.db_utils import with_db_session, get_db_session 33 | from sqlalchemy import func 34 | import os 35 | import json 36 | # 下面代码是必要的引用 37 | # md_agent md_agent 并没有“显式”的被使用,但是在被引用期间,已经将相应的 agent 类注册到了相应的列表中 38 | import agent.md_agent 39 | import agent.td_agent 40 | 41 | DEBUG = False 42 | # "TradeBookCryptoCurrency2018-10-08.csv" 43 | PATTERN_BACKTEST_FILE_NAME = re.compile(r'(?<=^TradeBookCryptoCurrency)[\d\-]{8,10}(?=.csv$)') 44 | PATTERN_ORDER_FILE_NAME = re.compile(r'^order.*\.csv$') 45 | PATTERN_FEEDBACK_FILE_NAME = re.compile(r'^feedback.*\.json$') 46 | 47 | 48 | class TargetPosition: 49 | 50 | def __init__(self, direction: Direction, currency, symbol, weight=None, position=None, price=None, 51 | stop_loss_rate=None, stop_loss_price=None, has_stop_loss=False, gap_threshold_vol=0.01): 52 | self.direction = direction 53 | self.currency = currency 54 | self.position = 
position 55 | self.symbol = symbol 56 | self.price = price 57 | self.stop_loss_rate = stop_loss_rate 58 | self.stop_loss_price = stop_loss_price 59 | self.has_stop_loss = has_stop_loss 60 | self.gap_threshold_vol = gap_threshold_vol 61 | self.weight = weight 62 | 63 | def check_stop_loss(self, close): 64 | """ 65 | 根据当前价格计算是否已经到达止损点位 66 | 如果此前已经到达过止损点位则不再比较,也不需重置状态 67 | :param close: 68 | :return: 69 | """ 70 | # 如果此前已经到达过止损点位则不再比较,也不需重置状态 71 | if self.stop_loss_price is None or self.has_stop_loss: 72 | return 73 | self.has_stop_loss = (self.direction == Direction.Long and close < self.stop_loss_price) or ( 74 | self.direction == Direction.Short and close > self.stop_loss_price) 75 | if self.has_stop_loss: 76 | logging.warning('%s 处于止损状态。止损价格 %f 当前价格 %f', self.symbol, self.stop_loss_price, close) 77 | 78 | def get_target_position(self): 79 | return self.direction, self.currency, self.position, self.symbol, \ 80 | self.price, self.stop_loss_price, self.has_stop_loss, \ 81 | self.gap_threshold_vol 82 | 83 | def to_dict(self): 84 | return {attr: getattr(self, attr) for attr in dir(self) 85 | if attr.find('_') != 0 and not callable(getattr(self, attr))} 86 | 87 | def __repr__(self): 88 | return f'' 91 | 92 | 93 | class ReadFileStg(StgBase): 94 | _folder_path = get_folder_path(r'file_order') 95 | 96 | def __init__(self, symbol_list=None): 97 | super().__init__() 98 | self.symbol_list = symbol_list 99 | self._mutex = threading.Lock() 100 | self._last_check_datetime = datetime.now() - timedelta(minutes=1) 101 | self.interval_timedelta = timedelta(seconds=15) 102 | self.symbol_target_position_dic = {} 103 | # 设定相应周期的事件驱动句柄 接收的参数类型 104 | self._on_period_event_dic[PeriodType.Tick].param_type = dict 105 | # 记录合约最近一次执行操作的时间 106 | self.symbol_last_deal_datetime = {} 107 | # 记录合约最近一个发送买卖请求的时间 108 | self.instrument_lastest_order_datetime_dic = {} 109 | # 目前由于交易是异步执行,在尚未记录每一笔订单的情况下,时间太短可能会导致仓位与请求但出现不同步现象,导致下单过多的问题 110 | self.timedelta_between_deal = timedelta(seconds=3) 111 | 
self.min_order_vol = 0.1 112 | self.symbol_latest_price_dic = defaultdict(float) 113 | self.weight = 1 if not DEBUG else 0.2 # 默认仓位权重 114 | self.stop_loss_rate = -0.03 115 | # 初始化 symbol 基本信息 116 | with with_db_session(engine_md) as session: 117 | symbol_info_list = session.query(SymbolPair).filter( 118 | func.concat(SymbolPair.base_currency, SymbolPair.quote_currency).in_(symbol_list)).all() 119 | self.symbol_info_dic = {symbol.base_currency+symbol.quote_currency: symbol for symbol in symbol_info_list} 120 | self.logger.info('接受订单文件目录:%s', self._folder_path) 121 | self.load_feedback_file() 122 | 123 | def fetch_pos_by_file(self): 124 | """读取仓位配置csv文件,返回目标仓位DataFrame""" 125 | # 检查最近一次文件检查的时间,避免重复查询 126 | if self._last_check_datetime + self.interval_timedelta > datetime.now(): 127 | return 128 | # 获取文件列表 129 | file_name_list = os.listdir(self._folder_path) 130 | if file_name_list is None: 131 | # self.logger.info('No file') 132 | return 133 | # 读取所有 csv 文件 134 | position_df = None 135 | file_path_list = [] 136 | for file_name in file_name_list: 137 | # 仅处理 order*.csv文件 138 | if PATTERN_ORDER_FILE_NAME.search(file_name) is None: 139 | continue 140 | self.logger.debug('处理文件 order 文件: %s', file_name) 141 | file_base_name, file_extension = os.path.splitext(file_name) 142 | file_path = os.path.join(self._folder_path, file_name) 143 | file_path_list.append(file_path) 144 | position_df_tmp = pd.read_csv(file_path) 145 | if position_df is None: 146 | position_df = position_df_tmp 147 | else: 148 | is_ok = True 149 | for col_name in ('currency', 'symbol', 'weight', 'stop_loss_rate'): 150 | if col_name not in position_df_tmp.columns: 151 | is_ok = False 152 | self.logger.error('%s 文件格式不正确,缺少 %s 列数据', file_name, col_name) 153 | break 154 | 155 | if not is_ok: 156 | continue 157 | position_df = position_df.append(position_df_tmp) 158 | 159 | # 调试阶段暂时不重命名备份,不影响程序使用 160 | if not DEBUG: 161 | # 文件备份 162 | backup_file_name = f"{file_base_name} {datetime.now().strftime('%Y-%m-%d 
%H_%M_%S')}" \ 163 | f"{file_extension}.bak" 164 | os.rename(file_path, os.path.join(self._folder_path, backup_file_name)) 165 | self.logger.info('备份 order 文件 %s -> %s', file_name, backup_file_name) 166 | 167 | return position_df, file_path_list 168 | 169 | def handle_backtest_file(self): 170 | """ 171 | 处理王淳的回测文件,生成相应的交易指令文件 172 | :return: 173 | """ 174 | with self._mutex: 175 | # 获取文件列表 176 | file_name_list = os.listdir(self._folder_path) 177 | if file_name_list is None: 178 | # self.logger.info('No file') 179 | return 180 | # 读取所有 csv 文件 181 | for file_name in file_name_list: 182 | file_base_name, file_extension = os.path.splitext(file_name) 183 | # 仅处理 order*.csv文件 184 | m = PATTERN_BACKTEST_FILE_NAME.search(file_name) 185 | if m is None: 186 | continue 187 | file_date_str = m.group() 188 | file_date = str_2_date(file_date_str) 189 | if file_date != date.today(): 190 | self.logger.warning('文件:%s 日期与当前系统日期 %s 不匹配,不予处理', file_name, date.today()) 191 | continue 192 | self.logger.debug('处理文件 %s 文件日期:%s', file_name, file_date_str) 193 | file_path = os.path.join(self._folder_path, file_name) 194 | data_df = pd.read_csv(file_path) 195 | if data_df is None or data_df.shape[0] == 0: 196 | continue 197 | if str_2_date(data_df.iloc[-1]['Date']) != file_date: 198 | self.logger.warning('文件:%s 回测记录中最新日期与当前文件日期 %s 不匹配,不予处理', file_name, file_date) 199 | continue 200 | 201 | # 生成交易指令文件 202 | currency = data_df.iloc[-1]['InstruLong'].lower() 203 | order_dic = { 204 | 'currency': [currency], 205 | 'symbol': [f'{currency}usdt'], 206 | 'weight': [self.weight], 207 | 'stop_loss_rate': [self.stop_loss_rate], 208 | } 209 | order_file_name = f'order_{file_date_str}.csv' 210 | order_file_path = os.path.join(self._folder_path, order_file_name) 211 | order_df = pd.DataFrame(order_dic) 212 | order_df.to_csv(order_file_path) 213 | # 调试阶段暂时不重命名备份,不影响程序使用 214 | if not DEBUG: 215 | # 文件备份 216 | backup_file_name = f"{file_base_name} {datetime.now().strftime('%Y-%m-%d %H_%M_%S')}" \ 217 | 
f"{file_extension}.bak" 218 | os.rename(file_path, os.path.join(self._folder_path, backup_file_name)) 219 | 220 | def handle_order_file(self): 221 | """ 222 | 获得目标持仓currency, 权重,止损点位 223 | 生成相应交易指令 224 | 另外,如果发现新的交易order文件,则将所有的 feedback 文件备份(根据新的order进行下单,生成新的feedback文件) 225 | :return: 226 | """ 227 | with self._mutex: 228 | position_df, file_path_list = self.fetch_pos_by_file() 229 | if position_df is None or position_df.shape[0] == 0: 230 | return 231 | if len(self.symbol_latest_price_dic) == 0: 232 | self.logger.warning('当前程序没有缓存到有效的最新价格数据,交易指令暂缓执行') 233 | return 234 | # 如果存在新的 order 指令,则将所有的 feedback 文件备份(根据新的order进行下单,生成新的feedback文件) 235 | self.backup_feedback_files() 236 | self.logger.debug('仓位调整目标:\n%s', position_df) 237 | target_holding_dic = position_df.set_index('currency').dropna().to_dict('index') 238 | 239 | # {currency: (Direction, currency, target_position, symbol, target_price, stop_loss_price) 240 | symbol_target_position_dic = {} 241 | # 检查目标仓位与当前持仓是否相符,否则执行相应交易 242 | target_currency_set = set(list(position_df['currency'])) 243 | holding_currency_dic = self.get_holding_currency() 244 | # 检查是否所有持仓符合目标配置文件要求 245 | is_all_fit_target = True 246 | 247 | # 如果当前 currency 不在目标持仓列表里面,则卖出 248 | for num, (currency, balance_dic) in enumerate(holding_currency_dic.items(), start=1): 249 | # currency 在目标持仓中,无需清仓 250 | if currency in target_currency_set: 251 | continue 252 | # hc 为 货币交易所的一种手续费代币工具,不做交易使用 253 | # if currency == 'hc': 254 | # continue 255 | # 若持仓余额 小于 0.0001 则放弃清仓 256 | tot_balance = 0 257 | for _, dic in balance_dic.items(): 258 | tot_balance += dic['balance'] 259 | if tot_balance < 0.0001: 260 | continue 261 | 262 | symbol = self.get_symbol_by_currency(currency) 263 | if self.symbol_list is not None and symbol not in self.symbol_list: 264 | self.logger.warning('%s 持仓: %.6f 不在当前订阅列表中,也不在目标持仓中,该持仓将不会被操作', 265 | symbol, tot_balance) 266 | continue 267 | self.logger.info('计划卖出 %s', symbol) 268 | # TODO: 最小下单量在数据库中有对应信息,有待改进 269 | weight = 0 270 | 
target_vol, gap_threshold_vol, stop_loss_price = self.calc_vol_and_stop_loss_price( 271 | symbol, weight, gap_threshold_precision=None) 272 | symbol_target_position_dic[symbol] = TargetPosition( 273 | Direction.Long, currency, symbol, 274 | weight=weight, position=0, gap_threshold_vol=gap_threshold_vol) 275 | is_all_fit_target = False 276 | 277 | # 生成目标持仓列表买入指令 278 | for num, (currency, position_dic) in enumerate(target_holding_dic.items()): 279 | weight = position_dic['weight'] 280 | stop_loss_rate = position_dic['stop_loss_rate'] 281 | # stop_loss_price = position_dic['stop_loss_rate'] 282 | symbol = self.get_symbol_by_currency(currency) 283 | target_vol, gap_threshold_vol, stop_loss_price = self.calc_vol_and_stop_loss_price( 284 | symbol, weight, stop_loss_rate) 285 | if target_vol is None: 286 | self.logger.warning('%s 持仓权重 %.2f %% 无法计算目标持仓量', currency, weight * 100) 287 | continue 288 | # 检查当前持仓是否与目标持仓一致,如果一致则跳过 289 | # position_date_pos_info_dic = self.get_position(symbol) 290 | # if position_date_pos_info_dic is not None and len(position_date_pos_info_dic) > 0: 291 | # # 有持仓,比较是否满足目标仓位,否则下指令 292 | # position_cur = sum([pos_info['balance'] for pos_info in position_date_pos_info_dic.values()]) 293 | # position_gap = target_vol - position_cur 294 | # # 实盘情况下,很少绝对一致,在一定区间内即可 295 | # if position_gap > gap_threshold_vol: 296 | # # 当前合约累计持仓与目标持仓不一致,则添加目标持仓任务 297 | # is_all_fit_target = False 298 | # else: 299 | # is_all_fit_target = False 300 | # 无论仓位是否存在,均生成交易指令,待交易执行阶段进行比较(以上代码不影响是否生产建仓指令) 301 | 302 | # 多头目标持仓 303 | self.logger.info('计划买入 %s 目标仓位:%f 止损价:%f', symbol, target_vol, stop_loss_price) 304 | symbol_target_position_dic[symbol] = TargetPosition( 305 | Direction.Long, currency, symbol, 306 | weight=weight, position=target_vol, price=None, 307 | stop_loss_rate=stop_loss_rate, stop_loss_price=stop_loss_price, gap_threshold_vol=gap_threshold_vol) 308 | 309 | symbol_target_position_dic_len = len(symbol_target_position_dic) 310 | if 
symbol_target_position_dic_len > 0: 311 | self.symbol_target_position_dic = symbol_target_position_dic 312 | self.logger.info('发现新的目标持仓指令:') 313 | self.logger_symbol_target_position_dic() 314 | # 生成 feedback 文件 315 | self.create_feedback_file() 316 | else: 317 | self.symbol_target_position_dic = None 318 | self.logger.debug('无仓位调整指令') 319 | 320 | def logger_symbol_target_position_dic(self): 321 | """ 322 | 展示当前目标持仓信息 323 | :return: 324 | """ 325 | symbol_target_position_dic_len = len(self.symbol_target_position_dic) 326 | for num, (key, val) in enumerate(self.symbol_target_position_dic.items()): 327 | self.logger.info('%d/%d) %s, %r', num, symbol_target_position_dic_len, key, val) 328 | 329 | def on_timer(self): 330 | """ 331 | 每15秒进行一次文件检查 332 | 1)检查王淳的回测文件,匹配最新日期 "TradeBookCryptoCurrency2018-10-08.csv" 中的日期是否与系统日期一致,如果一致则处理,生成“交易指令文件” 333 | 2)生成相应目标仓位文件 order_2018-10-08.csv 334 | :param md_df: 335 | :param context: 336 | :return: 337 | """ 338 | self.get_balance() 339 | self.handle_backtest_file() 340 | self.handle_order_file() 341 | 342 | def do_order(self, md_dic, instrument_id, order_vol, price=None, direction=Direction.Long, stop_loss_price=0, 343 | msg=""): 344 | # if True: 345 | # self.logger.info("%s %s %f 价格 %f [%s]", 346 | # instrument_id, '买入' if position > 0 else '卖出', position, price, msg) 347 | # return 348 | # position == 0 则代表无需操作 349 | # 执行交易 350 | if direction == Direction.Long: 351 | if order_vol == 0: 352 | return 353 | elif order_vol > 0: 354 | if price is None or price == 0: 355 | price = md_dic['close'] 356 | # TODO: 稍后按盘口卖一档价格挂单 357 | 358 | # if DEBUG: 359 | # # debug 模式下,价格不要真实成交,只要看一下有委托单就可以了 360 | # price /= 2 361 | 362 | if stop_loss_price is not None and stop_loss_price > 0 and price <= stop_loss_price: 363 | self.logger.warning('%s 当前价格 %.6f 已经触发止损价 %.6f 停止买入操作', 364 | instrument_id, price, stop_loss_price) 365 | return 366 | 367 | self.open_long(instrument_id, price, order_vol) 368 | self.logger.info("%s %s -> 开多 %.4f 价格:%.4f", 
instrument_id, msg, order_vol, price) 369 | elif order_vol < 0: 370 | if price is None or price == 0: 371 | price = md_dic['close'] 372 | # TODO: 稍后按盘口卖一档价格挂单 373 | 374 | # if DEBUG: 375 | # # debug 模式下,价格不要真实成交,只要看一下有委托单就可以了 376 | # price += price 377 | 378 | order_vol_net = -order_vol 379 | self.close_long(instrument_id, price, order_vol_net) 380 | self.logger.info("%s %s -> 平多 %.4f 价格:%.4f", instrument_id, msg, order_vol_net, price) 381 | else: 382 | raise ValueError('目前不支持做空') 383 | self.instrument_lastest_order_datetime_dic[instrument_id] = datetime.now() 384 | 385 | def on_tick(self, md_dic, context): 386 | """ 387 | tick 级数据进行交易操作 388 | :param md_dic: 389 | :param context: 390 | :return: 391 | """ 392 | # self.logger.debug('get tick data: %s', md_dic) 393 | symbol = md_dic['symbol'] 394 | # 更新最新价格 395 | close_cur = md_dic['close'] 396 | self.symbol_latest_price_dic[symbol] = close_cur 397 | # 计算是否需要进行调仓操作 398 | if self.symbol_target_position_dic is None or symbol not in self.symbol_target_position_dic: 399 | # self.logger.debug("当前 symbol='%s' 无操作", symbol) 400 | return 401 | if self.datetime_last_update_position is None: 402 | self.logger.debug("尚未获取持仓数据,跳过") 403 | return 404 | 405 | target_position = self.symbol_target_position_dic[symbol] 406 | # 权重为空,或者清仓的情况下,无需重新计算仓位 407 | if target_position.weight is not None and not(target_position.weight == 0 and target_position.position == 0): 408 | # target_position 为交易指令生产是产生的止损价格,无需浮动,否则永远无法止损了 409 | target_vol, gap_threshold_vol, stop_loss_price = self.calc_vol_and_stop_loss_price( 410 | symbol, target_position.weight, target_position.stop_loss_rate) 411 | target_position.position = target_vol 412 | target_position.gap_threshold_vol = gap_threshold_vol 413 | 414 | # target_currency = self.trade_agent.get_currency(symbol) 415 | target_currency = target_position.currency 416 | # self.logger.debug('target_position_dic[%s]: %s', symbol, self.target_position_dic[symbol]) 417 | # 如果的当前合约近期存在交易回报,则交易回报时间一定要小于查询持仓时间: 
418 | # 防止出现以及成交单持仓信息未及时更新导致的数据不同步问题 419 | if symbol in self.datetime_last_rtn_trade_dic: 420 | if target_currency not in self.datetime_last_update_position_dic: 421 | self.logger.debug("%s 持仓数据中没有包含当前合约,最近一次成交回报时间:%s,跳过", 422 | target_currency, self.datetime_last_rtn_trade_dic[symbol]) 423 | self.get_position(symbol, force_refresh=True) 424 | # 此处可以不 return 因为当前火币交易所接口是同步返回持仓结果的 425 | # 不过为了兼容其他交易所,因此统一使用这种方式 426 | return 427 | if self.datetime_last_rtn_trade_dic[symbol] > self.datetime_last_update_position_dic[target_currency]: 428 | self.logger.debug("%s 持仓数据尚未更新完成,最近一次成交回报时间:%s > 最近一次持仓更新时间:%s", 429 | target_currency, 430 | self.datetime_last_rtn_trade_dic[symbol], 431 | self.datetime_last_update_position_dic[target_currency]) 432 | self.get_position(symbol, force_refresh=True) 433 | # 此处可以不 return 因为当前火币交易所接口是同步返回持仓结果的 434 | # 不过为了兼容其他交易所,因此统一使用这种方式 435 | return 436 | 437 | # 过于密集执行可能会导致重复下单的问题 438 | if symbol in self.symbol_last_deal_datetime: 439 | last_deal_datetime = self.symbol_last_deal_datetime[symbol] 440 | if last_deal_datetime + self.timedelta_between_deal > datetime.now(): 441 | # logging.debug("最近一次交易时间:%s,防止交易密度过大,跳过", last_deal_datetime) 442 | return 443 | 444 | with self._mutex: 445 | target_position.check_stop_loss(close_cur) 446 | # self.logger.debug("当前持仓目标:%r", target_position) 447 | # 撤销所有相关订单 448 | self.cancel_order(symbol) 449 | 450 | # 计算目标仓位方向及交易数量 451 | position_date_pos_info_dic = self.get_position(symbol) 452 | if position_date_pos_info_dic is None: 453 | # 无当前持仓,有目标仓位,直接按照目标仓位进行开仓动作 454 | # target_direction, target_currency, target_position, symbol, target_price, \ 455 | # stop_loss_price, has_stop_loss, gap_threshold_vol = self.get_target_position(symbol) 456 | if not target_position.has_stop_loss: 457 | self.do_order(md_dic, symbol, target_position.position, target_position.price, 458 | target_position.direction, target_position.stop_loss_price, msg='当前无持仓') 459 | else: 460 | # 如果当前有持仓,执行两类动作: 461 | # 1)若 当前持仓与目标持仓不匹配,则进行相应的调仓操作 
462 | # 2)若 当前持仓价格超出止损价位,则进行清仓操作 463 | 464 | position_holding = sum( 465 | [pos_info_dic['balance'] for pos_info_dic in position_date_pos_info_dic.values()]) 466 | # 比较当前持仓总量与目标仓位是否一致 467 | # 如果当前有持仓,目标仓位也有持仓,则需要进一步比对 468 | # target_direction, target_currency, target_position, symbol, target_price, \ 469 | # stop_loss_price, has_stop_loss, gap_threshold_vol = self.get_target_position(symbol) 470 | if target_position.has_stop_loss: 471 | if position_holding <= target_position.gap_threshold_vol: 472 | self.logger.debug('当前 %s 持仓 %f -> %f 价格 %.6f 已处于止损状态,剩余仓位低于阀值 %f 无需进一步清仓', 473 | target_currency, position_holding, target_position.position, close_cur, 474 | target_position.gap_threshold_vol) 475 | else: 476 | self.logger.debug('当前 %s 持仓 %f -> %f 价格 %.6f 已处于止损状态', 477 | target_currency, position_holding, target_position.position, close_cur) 478 | # 已经触发止损,如果依然有持仓,则进行持续清仓操作 479 | self.do_order(md_dic, symbol, -position_holding, None, 480 | target_position.direction, msg="止损") 481 | else: 482 | # 汇总全部同方向持仓,如果不够目标仓位,则加仓 483 | # 对全部的反方向持仓进行平仓 484 | 485 | # 如果持仓超过目标仓位,则平仓多出的部分,如果不足则补充多的部分 486 | position_gap = target_position.position - position_holding 487 | if position_gap > target_position.gap_threshold_vol: 488 | if position_holding < target_position.gap_threshold_vol: 489 | msg = '建仓' 490 | else: 491 | msg = "补充仓位" 492 | # 如果不足则补充多的部分 493 | self.logger.debug('当前 %s 持仓 %f -> %f 价格 %.6f %s', 494 | target_currency, position_holding, target_position.position, close_cur, msg) 495 | self.do_order(md_dic, symbol, position_gap, target_position.price, 496 | target_position.direction, target_position.stop_loss_price, msg=msg) 497 | elif position_gap < - target_position.gap_threshold_vol: 498 | if target_position.position == 0: 499 | msg = '清仓' 500 | else: 501 | msg = "持仓超过目标仓位,减仓 %.4f" % position_gap 502 | # 如果持仓超过目标仓位,则平仓多出的部分 503 | self.logger.debug('当前 %s 持仓 %f -> %f 价格 %.6f %s', 504 | target_currency, position_holding, target_position.position, close_cur, msg) 505 | 
self.do_order(md_dic, symbol, position_gap, target_position.price, 506 | target_position.direction, target_position.stop_loss_price, msg=msg) 507 | else: 508 | self.logger.debug('当前 %s 持仓 %f -> %f 差距 %f 小于最小调整范围 %f,忽略此调整', 509 | target_currency, position_holding, target_position.position, position_gap, 510 | target_position.gap_threshold_vol) 511 | 512 | # 更新最近执行时间 513 | self.symbol_last_deal_datetime[symbol] = datetime.now() 514 | 515 | def get_symbol_by_currency(self, currency): 516 | """目前暂时仅支持currency 与 usdt 之间转换""" 517 | return currency + 'usdt' 518 | 519 | def calc_vol_and_stop_loss_price(self, symbol, weight, stop_loss_rate=None, gap_threshold_precision: (int, None)=2): 520 | """ 521 | 根据权重及当前账号总市值,计算当前 symbol 对应多少 vol, 根据 stop_loss_rate 计算止损价格(目前仅考虑做多的情况) 522 | :param symbol: 523 | :param weight: 524 | :param stop_loss_rate: 为空则不计算 525 | :param gap_threshold_precision: 为了避免反复调整,造成手续费摊高,设置最小调整精度, None 则使用当前Symbol默认值 526 | :return: 527 | """ 528 | if gap_threshold_precision is None: 529 | gap_threshold_precision = self.symbol_info_dic[symbol].amount_precision 530 | holding_currency_dic = self.get_holding_currency(exclude_usdt=False) 531 | # tot_value = sum([dic['balance'] * self.symbol_latest_price_dic[self.get_symbol_by_currency(currency)] 532 | # for currency, dic in holding_currency_dic.items()]) 533 | if symbol not in self.symbol_latest_price_dic or self.symbol_latest_price_dic[symbol] == 0: 534 | self.logger.error('%s 没有找到有效的最新价格', symbol) 535 | price_latest = None 536 | weight_vol = None 537 | gap_threshold_vol = None 538 | stop_loss_price = None 539 | else: 540 | tot_value = 0 541 | for currency, dic in holding_currency_dic.items(): 542 | for pos_date_type, dic_sub in dic.items(): 543 | if currency == 'usdt': 544 | tot_value += dic_sub['balance'] 545 | else: 546 | tot_value += dic_sub['balance'] * self.symbol_latest_price_dic[ 547 | self.get_symbol_by_currency(currency)] 548 | 549 | price_latest = self.symbol_latest_price_dic[symbol] 550 | weight_vol = 
tot_value * weight / price_latest 551 | gap_threshold_vol = None if gap_threshold_precision is None else \ 552 | tot_value * (0.1 ** gap_threshold_precision) / price_latest 553 | stop_loss_price = None if stop_loss_rate is None else \ 554 | price_latest * (1 + stop_loss_rate) 555 | 556 | self.logger.debug('%s price_latest=%.4f, weight_vol=%f [%.1f%%], gap_threshold_vol=%.4f, stop_loss_price=%.4f', 557 | symbol, 558 | 0 if price_latest is None else price_latest, 559 | 0 if weight_vol is None else weight_vol, 560 | 0 if weight is None else (weight * 100), 561 | 0 if gap_threshold_vol is None else gap_threshold_vol, 562 | 0 if stop_loss_price is None else stop_loss_price, 563 | ) 564 | return weight_vol, gap_threshold_vol, stop_loss_price 565 | 566 | def get_target_position(self, symbol): 567 | dic = self.symbol_target_position_dic[symbol] 568 | return dic['direction'], dic['currency'], dic['position'], dic['symbol'], \ 569 | dic['price'], dic['stop_loss_price'], dic.setdefault('has_stop_loss', False), \ 570 | dic.setdefault('gap_threshold_vol', None) 571 | 572 | def backup_feedback_files(self): 573 | """ 574 | 将所有的 feedback 文件备份 575 | :return: 576 | """ 577 | # 获取文件列表 578 | file_name_list = os.listdir(self._folder_path) 579 | if file_name_list is None: 580 | # self.logger.info('No file') 581 | return 582 | 583 | for file_name in file_name_list: 584 | # 仅处理 feedback*.csv文件 585 | if PATTERN_FEEDBACK_FILE_NAME.search(file_name) is None: 586 | continue 587 | file_base_name, file_extension = os.path.splitext(file_name) 588 | file_path = os.path.join(self._folder_path, file_name) 589 | # 文件备份 590 | backup_file_name = f"{file_base_name} {datetime.now().strftime('%Y-%m-%d %H_%M_%S')}" \ 591 | f"{file_extension}.bak" 592 | os.rename(file_path, os.path.join(self._folder_path, backup_file_name)) 593 | self.logger.info('备份 Feedback 文件 %s -> %s', file_name, backup_file_name) 594 | 595 | def create_feedback_file(self): 596 | """ 597 | 根据 symbol_target_position_dic 创建 feedback 文件 
598 | :return: 599 | """ 600 | symbol_target_position_dic = self.symbol_target_position_dic 601 | data_dic = {} 602 | for key, val in symbol_target_position_dic.items(): 603 | val_dic = val.to_dict() 604 | val_dic['direction'] = int(val_dic['direction']) 605 | data_dic[key] = val_dic 606 | 607 | file_name = f"feedback_{datetime.now().strftime('%Y-%m-%d %H_%M_%S')}.json" 608 | file_path = os.path.join(self._folder_path, file_name) 609 | with open(file_path, 'w') as file: 610 | json.dump(data_dic, file) 611 | self.logger.info('生产 feedback 文件:%s', file_name) 612 | return file_path 613 | 614 | def load_feedback_file(self): 615 | """ 616 | 加载 feedback 文件,更新 self.symbol_target_position_dic 617 | :return: 618 | """ 619 | # 获取文件列表 620 | file_name_list = os.listdir(self._folder_path) 621 | if file_name_list is None or len(file_name_list) == 0: 622 | # self.logger.info('No file') 623 | return 624 | # 读取所有 csv 文件 625 | for file_name in file_name_list: 626 | # 仅处理 order*.csv文件 627 | if PATTERN_FEEDBACK_FILE_NAME.search(file_name) is None: 628 | continue 629 | self.logger.debug('处理文件 feedback文件: %s', file_name) 630 | file_path = os.path.join(self._folder_path, file_name) 631 | 632 | with open(file_path) as file: 633 | data_dic = json.load(file) 634 | # 构建 symbol_target_position_dic 对象 635 | symbol_target_position_dic = {} 636 | for key, val in data_dic.items(): 637 | val['direction'] = Direction(val['direction']) 638 | symbol_target_position_dic[key] = TargetPosition(**val) 639 | 640 | self.symbol_target_position_dic = symbol_target_position_dic 641 | self.logger.info('加载 feedback 文件:%s', file_name) 642 | self.logger_symbol_target_position_dic() 643 | break 644 | else: 645 | logging.info('没有可用的 feedback 文件可加载') 646 | 647 | 648 | if __name__ == '__main__': 649 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 650 | DEBUG = False 651 | symbol_list = ['ethusdt', 'eosusdt'] 652 | # 参数设置 653 | strategy_params = {'symbol_list': symbol_list} 654 | 
md_agent_params_list = [ 655 | # { 656 | # 'name': 'min1', 657 | # 'md_period': PeriodType.Min1, 658 | # 'instrument_id_list': ['rb1805', 'i1801'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 659 | # 'init_md_date_to': '2017-9-1', 660 | # 'dict_or_df_as_param': dict 661 | # }, 662 | { 663 | 'name': 'tick', 664 | 'md_period': PeriodType.Tick, 665 | 'instrument_id_list': symbol_list, # 666 | }] 667 | run_mode_realtime_params = { 668 | 'run_mode': RunMode.Realtime, 669 | 'enable_timer_thread': True, 670 | 'seconds_of_timer_interval': 15, 671 | } 672 | run_mode_backtest_params = { 673 | 'run_mode': RunMode.Backtest, 674 | 'date_from': '2017-9-4', 675 | 'date_to': '2017-9-27', 676 | 'init_cash': 1000000, 677 | 'trade_mode': BacktestTradeMode.Order_2_Deal 678 | } 679 | # run_mode = RunMode.BackTest 680 | # 初始化策略处理器 681 | stghandler = StgHandlerBase.factory( 682 | stg_class_obj=ReadFileStg, 683 | strategy_params=strategy_params, 684 | md_agent_params_list=md_agent_params_list, 685 | **run_mode_realtime_params) 686 | if DEBUG: 687 | stghandler.run() 688 | else: 689 | # 开始执行策略 690 | stghandler.start() 691 | # 策略执行 2 分钟后关闭 692 | time.sleep(180) 693 | stghandler.keep_running = False 694 | stghandler.join() 695 | 696 | logging.info("执行结束") 697 | # print(os.path.abspath(r'..\file_order')) 698 | -------------------------------------------------------------------------------- /strategy/file_strategy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2017/11/18 4 | @author: MG 5 | """ 6 | import threading 7 | import time 8 | import logging 9 | import pandas as pd 10 | from datetime import datetime, timedelta 11 | from abat.strategy import StgBase, StgHandlerBase 12 | from config import Config 13 | from abat.common import PeriodType, RunMode, BacktestTradeMode, Direction, PositionDateType 14 | import os 15 | # 下面代码是必要的引用 16 | # md_agent md_agent 并没有“显式”的被使用,但是在被引用期间,已经将相应的 agent 类注册到了相应的列表中 17 | 
import agent.md_agent 18 | import agent.td_agent 19 | 20 | 21 | class ReadFileStg(StgBase): 22 | _folder_path = os.path.abspath(r'.\file_order') 23 | 24 | def __init__(self): 25 | super().__init__() 26 | self._mutex = threading.Lock() 27 | self._last_check_datetime = datetime.now() - timedelta(minutes=1) 28 | self.interval_timedelta = timedelta(seconds=15) 29 | self.target_position_dic = {} 30 | # 设定相应周期的事件驱动句柄 接收的参数类型 31 | self._on_period_event_dic[PeriodType.Tick].param_type = dict 32 | # 记录合约最近一次执行操作的时间 33 | self.symbol_last_deal_datetime = {} 34 | # 记录合约最近一个发送买卖请求的时间 35 | self.instrument_lastest_order_datetime_dic = {} 36 | # 目前由于交易是异步执行,在尚未记录每一笔订单的情况下,时间太短可能会导致仓位与请求但出现不同步现象,导致下单过多的问题 37 | self.timedelta_between_deal = timedelta(seconds=3) 38 | 39 | def fetch_pos_by_file(self): 40 | # 检查最近一次文件检查的时间,避免重复查询 41 | if self._last_check_datetime + self.interval_timedelta > datetime.now(): 42 | return 43 | # 获取文件列表 44 | file_name_list = os.listdir(self._folder_path) 45 | if file_name_list is None: 46 | # self.logger.info('No file') 47 | return 48 | # 读取所有 csv 文件 49 | position_df = None 50 | for file_name in file_name_list: 51 | file_base_name, file_extension = os.path.splitext(file_name) 52 | if file_extension.lower() != '.csv': 53 | continue 54 | file_path = os.path.join(self._folder_path, file_name) 55 | position_df_tmp = pd.read_csv(file_path) 56 | if position_df is None: 57 | position_df = position_df_tmp 58 | else: 59 | position_df = position_df.append(position_df_tmp) 60 | 61 | # 文件备份 62 | backup_file_name = file_base_name + datetime.now().strftime( 63 | '%Y-%m-%d %H_%M_%S') + file_extension + '.bak' 64 | # 调试阶段暂时不重命名备份,不影响程序使用 65 | os.rename(file_path, os.path.join(self._folder_path, backup_file_name)) 66 | 67 | return position_df 68 | 69 | def on_timer(self): 70 | """ 71 | 每15秒进行一次文件检查 72 | 文件格式(csv xls),每个symbol一行,不可重复,例:卖出eth 买入eos, 不考虑套利的情况(套利需要单独开发其他策略) 73 | currency target_vol symbol price 74 | eth 0 ethusdt 800 75 | eos 200 eosusdt 40 76 | :param md_df: 
77 | :param context: 78 | :return: 79 | """ 80 | with self._mutex: 81 | position_df = self.fetch_pos_by_file() 82 | if position_df is None or position_df.shape[0] == 0: 83 | return 84 | # 检查目标仓位与当前持仓是否相符,不符,则执行相应交易 85 | 86 | long_pos_df = position_df.set_index('currency').dropna() 87 | self.logger.debug('仓位调整方案:\n%s', long_pos_df) 88 | target_position_dic = {} 89 | for num, (currency, (target_vol, symbol, price)) in enumerate(long_pos_df.T.items()): 90 | # 检查当前持仓是否与目标持仓一致,如果一致则清空 self.target_position 91 | position_cur = 0 92 | position_date_pos_info_dic = self.get_position(symbol) 93 | if position_date_pos_info_dic is not None and len(position_date_pos_info_dic) > 0: 94 | is_fit = False 95 | for position_date_type, pos_info in position_date_pos_info_dic.items(): 96 | if pos_info['balance'] == target_vol: 97 | is_fit = True 98 | break 99 | else: 100 | position_cur += pos_info['balance'] 101 | 102 | if is_fit: 103 | continue 104 | 105 | position_gap = target_vol - position_cur 106 | if position_gap == 0: 107 | continue 108 | # 当前合约累计持仓与目标持仓不一致,则添加目标持仓任务 109 | # 多头目标持仓 110 | target_position_dic[symbol] = (Direction.Long, currency, target_vol, symbol, price) 111 | 112 | self.target_position_dic = target_position_dic 113 | if len(target_position_dic) > 0: 114 | self.logger.info('发现新的目标持仓指令\n%s', target_position_dic) 115 | 116 | def do_order(self, md_dic, instrument_id, position, price=None, direction=Direction.Long, msg=""): 117 | # if True: 118 | # self.logger.info("%s %s %f 价格 %f [%s]", 119 | # instrument_id, '买入' if position > 0 else '卖出', position, price, msg) 120 | # return 121 | # position == 0 则代表无需操作 122 | # 执行交易 123 | if direction == Direction.Long: 124 | if position == 0: 125 | return 126 | elif position > 0: 127 | if price is None: 128 | price = md_dic['close'] 129 | # TODO: 稍后按盘口卖一档价格挂单 130 | 131 | self.open_long(instrument_id, price, position) 132 | self.logger.info("%s %s -> 开多 %d %.0f", instrument_id, msg, position, price) 133 | elif position < 0: 134 | 
if price is None: 135 | price = md_dic['close'] 136 | # TODO: 稍后按盘口卖一档价格挂单 137 | position_net = -position 138 | self.close_long(instrument_id, price, position_net) 139 | self.logger.info("%s %s -> 平多 %d %.0f", instrument_id, msg, position_net, price) 140 | else: 141 | raise ValueError('目前不支持做空') 142 | self.instrument_lastest_order_datetime_dic[instrument_id] = datetime.now() 143 | 144 | def on_tick(self, md_dic, context): 145 | """ 146 | tick级数据进行交易操作 147 | :param md_dic: 148 | :param context: 149 | :return: 150 | """ 151 | # self.logger.debug('get tick data: %s', md_dic) 152 | if self.target_position_dic is None or len(self.target_position_dic) == 0: 153 | return 154 | if self.datetime_last_update_position is None: 155 | logging.debug("尚未获取持仓数据,跳过") 156 | return 157 | 158 | # self.logger.debug('target_position_dic: %s', self.target_position_dic) 159 | symbol = md_dic['symbol'] 160 | if symbol not in self.target_position_dic: 161 | return 162 | currency = self.trade_agent.get_currency(symbol) 163 | # self.logger.debug('target_position_dic[%s]: %s', symbol, self.target_position_dic[symbol]) 164 | # 如果的当前合约近期存在交易回报,则交易回报时间一定要小于查询持仓时间: 165 | # 防止出现以及成交单持仓信息未及时更新导致的数据不同步问题 166 | if symbol in self.datetime_last_rtn_trade_dic: 167 | if currency not in self.datetime_last_update_position_dic: 168 | logging.debug("持仓数据中没有包含当前合约,最近一次成交回报时间:%s,跳过", 169 | self.datetime_last_rtn_trade_dic[symbol]) 170 | self.get_position(symbol, force_refresh=True) 171 | return 172 | if self.datetime_last_rtn_trade_dic[symbol] > self.datetime_last_update_position_dic[currency]: 173 | logging.debug("持仓数据尚未更新完成,最近一次成交回报时间:%s,最近一次持仓更新时间:%s", 174 | self.datetime_last_rtn_trade_dic[symbol], 175 | self.datetime_last_update_position_dic[currency]) 176 | self.get_position(symbol, force_refresh=True) 177 | return 178 | 179 | # 过于密集执行可能会导致重复下单的问题 180 | if symbol in self.symbol_last_deal_datetime: 181 | last_deal_datetime = self.symbol_last_deal_datetime[symbol] 182 | if last_deal_datetime + 
self.timedelta_between_deal > datetime.now(): 183 | # logging.debug("最近一次交易时间:%s,防止交易密度过大,跳过", last_deal_datetime) 184 | return 185 | 186 | with self._mutex: 187 | 188 | # 撤销所有相关订单 189 | self.cancel_order(symbol) 190 | 191 | # 计算目标仓位方向及交易数量 192 | position_date_pos_info_dic = self.get_position(symbol) 193 | if position_date_pos_info_dic is None: 194 | # 如果当前无持仓,直接按照目标仓位进行开仓动作 195 | if symbol not in self.target_position_dic: 196 | # 当前无持仓,目标仓位也没有 197 | pass 198 | else: 199 | # 当前无持仓,直接按照目标仓位进行开仓动作 200 | direction, currency, target_vol, symbol, price = self.target_position_dic[symbol] 201 | self.do_order(md_dic, symbol, target_vol, price, 202 | msg='当前无持仓') 203 | else: 204 | # 如果当前有持仓 205 | # 比较当前持仓总量与目标仓位是否一致 206 | if symbol not in self.target_position_dic: 207 | currency = self.trade_agent.get_currency(symbol) 208 | # 如果当前有持仓,目标仓位为空,则当前持仓无论多空全部平仓 209 | for position_date_type, pos_info_dic in position_date_pos_info_dic.items(): 210 | position = pos_info_dic['balance'] 211 | self.do_order(md_dic, symbol, -position, 212 | msg='目标仓位0,全部平仓') 213 | else: 214 | # 如果当前有持仓,目标仓位也有持仓,则需要进一步比对 215 | direction_target, currency_target, vol_target, symbol, price = self.target_position_dic[symbol] 216 | # 汇总全部同方向持仓,如果不够目标仓位,则加仓 217 | # 对全部的反方向持仓进行平仓 218 | position_holding = 0 219 | for position_date_type, pos_info_dic in position_date_pos_info_dic.items(): 220 | direction = Direction.Long 221 | position = pos_info_dic['balance'] 222 | if direction != direction_target: 223 | self.do_order(md_dic, symbol, -position, price, 224 | msg="目标仓位反向 %d,平仓" % position) 225 | continue 226 | else: 227 | position_holding += position 228 | 229 | # 如果持仓超过目标仓位,则平仓多出的部分,如果不足则补充多的部分 230 | position_gap = vol_target - position_holding 231 | if position_gap > 0: 232 | # 如果不足则补充多的部分 233 | self.do_order(md_dic, symbol, position_gap, price, 234 | msg="补充仓位") 235 | elif position_gap < 0: 236 | # 如果持仓超过目标仓位,则平仓多出的部分 237 | self.do_order(md_dic, symbol, position_gap, price, 238 | msg="持仓超量,平仓 %d" % 
position_gap) 239 | 240 | # 更新最近执行时间 241 | self.symbol_last_deal_datetime[symbol] = datetime.now() 242 | 243 | 244 | if __name__ == '__main__': 245 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 246 | # 参数设置 247 | strategy_params = {} 248 | md_agent_params_list = [ 249 | # { 250 | # 'name': 'min1', 251 | # 'md_period': PeriodType.Min1, 252 | # 'instrument_id_list': ['rb1805', 'i1801'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 253 | # 'init_md_date_to': '2017-9-1', 254 | # 'dict_or_df_as_param': dict 255 | # }, 256 | { 257 | 'name': 'tick', 258 | 'md_period': PeriodType.Tick, 259 | 'instrument_id_list': ['ethusdt', 'eosusdt'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 260 | }] 261 | run_mode_realtime_params = { 262 | 'run_mode': RunMode.Realtime, 263 | 'enable_timer_thread': True, 264 | 'seconds_of_timer_interval': 15, 265 | } 266 | run_mode_backtest_params = { 267 | 'run_mode': RunMode.Backtest, 268 | 'date_from': '2017-9-4', 269 | 'date_to': '2017-9-27', 270 | 'init_cash': 1000000, 271 | 'trade_mode': BacktestTradeMode.Order_2_Deal 272 | } 273 | # run_mode = RunMode.BackTest 274 | # 初始化策略处理器 275 | stghandler = StgHandlerBase.factory( 276 | stg_class_obj=ReadFileStg, 277 | strategy_params=strategy_params, 278 | md_agent_params_list=md_agent_params_list, 279 | **run_mode_realtime_params) 280 | # 开始执行策略 281 | stghandler.start() 282 | # 策略执行 2 分钟后关闭 283 | time.sleep(120) 284 | stghandler.keep_running = False 285 | stghandler.join() 286 | logging.info("执行结束") 287 | # print(os.path.abspath(r'..\file_order')) 288 | -------------------------------------------------------------------------------- /strategy/simple_strategy.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @author : MG 5 | @Time : 2018/6/21 11:26 6 | @File : simple_strategy.py 7 | @contact : mmmaaaggg@163.com 8 | @desc : 9 | """ 10 | import logging 11 | import time 12 | from abat.common import BacktestTradeMode, PeriodType, RunMode, ContextKey, Direction 13 | from abat.strategy import StgBase 14 | import agent.md_agent 15 | import agent.td_agent 16 | from abat.strategy import StgHandlerBase 17 | from config import Config 18 | logger = logging.getLogger() 19 | 20 | 21 | class MACroseStg(StgBase): 22 | 23 | def __init__(self, unit=1): 24 | super().__init__() 25 | self.ma5 = [] 26 | self.ma10 = [] 27 | self.unit = unit 28 | 29 | def on_prepare_min1(self, md_df, context): 30 | if md_df is not None: 31 | self.ma5 = list(md_df['close'].rolling(5, 5).mean())[10:] 32 | self.ma10 = list(md_df['close'].rolling(10, 10).mean())[10:] 33 | 34 | def on_min1(self, md_df, context): 35 | close = md_df['close'].iloc[-1] 36 | self.ma5.append(md_df['close'].iloc[-5:].mean()) 37 | self.ma10.append(md_df['close'].iloc[-10:].mean()) 38 | instrument_id = context[ContextKey.instrument_id_list][0] 39 | if self.ma5[-2] < self.ma10[-2] and self.ma5[-1] > self.ma10[-1]: 40 | position_date_pos_info_dic = self.get_position(instrument_id) 41 | no_target_position = True 42 | if position_date_pos_info_dic is not None: 43 | for position_date, pos_info in position_date_pos_info_dic.items(): 44 | direction = pos_info.direction 45 | if direction == Direction.Short: 46 | self.close_short(instrument_id, close, pos_info.position) 47 | elif direction == Direction.Long: 48 | no_target_position = False 49 | if no_target_position: 50 | self.open_long(instrument_id, close, self.unit) 51 | elif self.ma5[-2] > self.ma10[-2] and self.ma5[-1] < self.ma10[-1]: 52 | position_date_pos_info_dic = self.get_position(instrument_id) 53 | no_target_position = True 54 | if position_date_pos_info_dic is not None: 55 | for position_date, pos_info in 
position_date_pos_info_dic.items(): 56 | direction = pos_info.direction 57 | if direction == Direction.Long: 58 | self.close_long(instrument_id, close, pos_info.position) 59 | elif direction == Direction.Short: 60 | no_target_position = False 61 | if no_target_position: 62 | self.open_short(instrument_id, close, self.unit) 63 | 64 | 65 | if __name__ == '__main__': 66 | logging.basicConfig(level=logging.DEBUG, format=Config.LOG_FORMAT) 67 | # 参数设置 68 | strategy_params = {'unit': 100000} 69 | md_agent_params_list = [{ 70 | 'name': 'min1', 71 | 'md_period': PeriodType.Min1, 72 | 'instrument_id_list': ['ethbtc'], # ['jm1711', 'rb1712', 'pb1801', 'IF1710'], 73 | 'init_md_date_to': '2018-6-17', 74 | 'init_md_date_to': '2018-6-19', 75 | }] 76 | run_mode_realtime_params = { 77 | 'run_mode': RunMode.Realtime, 78 | } 79 | run_mode_backtest_params = { 80 | 'run_mode': RunMode.Backtest, 81 | 'date_from': '2018-6-18', 82 | 'date_to': '2018-6-19', 83 | 'init_cash': 1000000, 84 | 'trade_mode': BacktestTradeMode.Order_2_Deal 85 | } 86 | # run_mode = RunMode.BackTest 87 | # 初始化策略处理器 88 | stghandler = StgHandlerBase.factory(stg_class_obj=MACroseStg, 89 | strategy_params=strategy_params, 90 | md_agent_params_list=md_agent_params_list, 91 | **run_mode_backtest_params) 92 | stghandler.start() 93 | time.sleep(10) 94 | stghandler.keep_running = False 95 | stghandler.join() 96 | logging.info("执行结束") 97 | -------------------------------------------------------------------------------- /trader.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "run trader" 3 | cd /home/mushrooman/wspy/ABAT_trader_4_blockchain/ 4 | source venv/bin/activate 5 | python3 run.py --num 0 6 | --------------------------------------------------------------------------------