├── .gitignore ├── BackTrader ├── __init__.py ├── base_back_trader.py ├── core_trade_logic.py ├── market_choose.py └── position_analysis.py ├── DailyPicking └── board_mom_hot.py ├── GetBaseData ├── __init__.py ├── ch_eng_mapping.py ├── get_board_data.py ├── get_cash_flow_data.py ├── get_dc_data.py ├── get_history_data.py ├── handle_data_show.py ├── thirdpart │ ├── __init__.py │ └── stock_board_industry_em.py └── use_baostock │ ├── __init__.py │ ├── cal_base_indicators.py │ ├── stock_k_data_dask.py │ └── stock_k_multi_cycle.py ├── LICENSE ├── MachineLearning ├── __init__.py ├── annotation_platform │ └── buy_and_sell_signals.py ├── data_process │ ├── __init__.py │ ├── base_data.py │ ├── dask_data.py │ └── indicators_cal.py └── model_train │ ├── Informer │ ├── __init__.py │ ├── main.ipynb │ └── train.py │ ├── __init__.py │ └── base_lgb.py ├── Monitor ├── BaseInfoStockMonitor.py └── __init__.py ├── README.md ├── StrategyLib ├── AutomaticInvestmentPlan │ ├── __init__.py │ ├── result_show.py │ ├── stable_dog.py │ └── stable_dog_unstop.py ├── ChanStrategy │ ├── BasicChan │ │ ├── __init__.py │ │ ├── basic_enum.py │ │ ├── basic_structure.py │ │ └── basic_tools.py │ ├── Test │ │ ├── __init__.py │ │ └── test_plot.py │ ├── __init__.py │ └── automatic_drawing.py ├── ChooseAssetStrategy │ └── board_mom.py ├── OneAssetStrategy │ ├── Demark9.py │ ├── EMA_Ma_Crossover.py │ ├── Ma5Ma10.py │ ├── MacdDeviate.py │ ├── README.md │ ├── __init__.py │ ├── macd30.py │ ├── macd_30m.py │ ├── macd_30m_dayMacd.py │ └── macd_day.py ├── __init__.py └── macd_day.py ├── StrategyResearch ├── board_mom.py ├── board_stock_analysis.py ├── data_research │ ├── compare_sk_ta.py │ └── data_pipeline_dask.py ├── linear_regression.py ├── popularity.py ├── time_series │ └── CR.py └── tmp.py ├── StudyDoc └── Advances.in.Financial.Machine.Learning-Wiley(2018).pdf ├── Utils ├── ShowKline │ ├── OfficeCase.py │ ├── __init__.py │ ├── base_kline.py │ └── chan_plot.py ├── TechnicalIndicators │ ├── __init__.py │ └── 
basic_indicators.py ├── __init__.py ├── base_utils.py └── info_push.py ├── api ├── hist.html ├── stock_api.py ├── test.html ├── test.py └── 示例 │ ├── README.md │ ├── demo.html │ └── js │ ├── echarts.min.js │ ├── jquery-3.3.1.min.js │ ├── k-line.js │ └── tmpData.js ├── pyproject.toml ├── requirements.txt ├── start_stable_dog.sh └── web_ui ├── __init__.py ├── show_page.py └── time_sharing.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | **/__pycache__ 4 | 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | .DS_Store 134 | .idea/ 135 | .vscode/ 136 | .env 137 | Data/ 138 | ShowHtml/ 139 | tmp/ 140 | log/ 141 | 142 | PrivacyConfig/ 143 | result/ 144 | # StrategyResearch/ 145 | model/ 146 | 147 | dask-worker-space/ 148 | nohup.out 149 | 150 | .ipynb_checkpoints/ 151 | .vscode/ 152 | .idea/ 153 | 154 | poetry.lock 155 | .ruff_cache/ -------------------------------------------------------------------------------- /BackTrader/__init__.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2021/12/22 22:28 4 | # @Author : Adolf 5 | # @File : __init__.py.py 6 | # @Function: 7 | -------------------------------------------------------------------------------- /BackTrader/base_back_trader.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2021/12/22 22:29 4 | # @Author : Adolf 5 | # @File : base_back_trader.py 6 | # @Function: 7 | 8 | import itertools 9 | import json 10 | import os 11 | import random 12 | import statistics 13 | from dataclasses import dataclass, field 14 | 15 | # from functools import reduce 16 | import pandas as pd 17 | 18 | from BackTrader.core_trade_logic import CoreTradeLogic 19 | from BackTrader.position_analysis import BaseTransactionAnalysis 20 | from 
GetBaseData.handle_data_show import show_data_from_df 21 | from Utils.ShowKline.base_kline import draw_chart 22 | from Utils.TechnicalIndicators.basic_indicators import MACD, SMA 23 | 24 | # from tqdm.auto import tqdm 25 | 26 | pd.set_option("expand_frame_repr", False) 27 | pd.set_option("display.max_rows", None) 28 | 29 | 30 | @dataclass 31 | class TradeStructureConfig: 32 | LOG_LEVEL: str = field( 33 | default="INFO", 34 | metadata={ 35 | "help": "日志级别,默认INFO,可选DEBUG、INFO、WARNING、ERROR、CRITICAL" 36 | }, 37 | ) 38 | CODE_NAME: str | None = field(default=None, metadata={"help": "股票代码"}) 39 | RANDOM_SEED: int = field(default=42, metadata={"help": "随机种子"}) 40 | STRATEGY_PARAMS: dict = field( 41 | default_factory=lambda: dict(), metadata={"help": "策略参数"} 42 | ) 43 | START_STAMP: str = field(default=None, metadata={"help": "开始时间"}) 44 | END_STAMP: str = field(default=None, metadata={"help": "结束时间"}) 45 | SHOW_DATA_PATH: str = field(default=None, metadata={"help": "展示数据路径"}) 46 | 47 | 48 | class TradeStructure(CoreTradeLogic): 49 | def __init__(self, config): 50 | self.config = TradeStructureConfig(**config) 51 | super().__init__() 52 | 53 | self.logger.info("Trade is begging ......") 54 | 55 | self.data = None 56 | self.transaction_analysis = BaseTransactionAnalysis(logger=self.logger) 57 | 58 | self.stock_result = None 59 | self.pl_result = None 60 | 61 | # 设置随机种子,保证实验结果的可复现性 62 | random.seed(self.config.RANDOM_SEED) 63 | 64 | # 加载数据集 65 | def load_dataset(self, data_path, start_stamp=None, end_stamp=None): 66 | df = pd.read_csv(data_path) 67 | df["market_cap"] = (df["amount"] * 100 / df["turn"]) / pow(10, 8) 68 | 69 | if start_stamp is not None: 70 | df = df[df["date"] > start_stamp] 71 | 72 | if end_stamp is not None: 73 | df = df[df["date"] < end_stamp] 74 | 75 | df.reset_index(drop=True, inplace=True) 76 | 77 | # self.logger.debug(df) 78 | self.data = df 79 | self.data = self.data[ 80 | [ 81 | "date", 82 | "open", 83 | "high", 84 | "low", 85 | "close", 86 | 
"volume", 87 | "turn", 88 | "market_cap", 89 | "code", 90 | ] 91 | ] 92 | 93 | # 计算基础的交易指标 94 | def cal_base_technical_indicators( 95 | self, sma_list=(5, 10, 20), macd_parm=(12, 26, 9) 96 | ): 97 | if sma_list is not None: 98 | for sma_parm in sma_list: 99 | self.data["sma" + str(sma_parm)] = SMA( 100 | self.data["close"], timeperiod=sma_parm 101 | ) 102 | if macd_parm is not None: 103 | macd_df = MACD( 104 | close=self.data["close"], 105 | fastperiod=macd_parm[0], 106 | slowperiod=macd_parm[1], 107 | signalperiod=macd_parm[2], 108 | ) 109 | 110 | self.data["macd"], self.data["histogram"], self.data["signal"] = [ 111 | macd_df["MACD_12_26_9"], 112 | macd_df["MACDh_12_26_9"], 113 | macd_df["MACDs_12_26_9"], 114 | ] 115 | 116 | # 计算需要使用到的指标 117 | def cal_technical_indicators(self, indicators_config): 118 | """可以计算需要用到的指标,如果不重写则使用默认的指标""" 119 | self.cal_base_technical_indicators() 120 | # raise NotImplementedError 121 | 122 | # 使用的到的交涉策略细节(已废弃) 123 | # def trading_algorithm(self): 124 | # raise NotImplementedError 125 | 126 | # 需要保证show_data里面的核心数据没有空值,不然会造成数据无法显示 127 | # @staticmethod 128 | def show_one_stock(self, show_data): 129 | if not os.path.exists("ShowHtml"): 130 | os.mkdir("ShowHtml") 131 | show_data_path = ( 132 | self.config.SHOW_DATA_PATH 133 | if self.config.SHOW_DATA_PATH 134 | else "ShowHtml/demo.html" 135 | ) 136 | show_data = show_data_from_df(df_or_dfpath=show_data) 137 | # import pdb;pdb.set_trace() 138 | draw_chart(input_data=show_data, show_html_path=show_data_path) 139 | 140 | # 使用一套参数对一只股票进行回测 141 | def run_one_stock_once(self, code_name, indicators_config=None): 142 | if indicators_config is None: 143 | indicators_config = self.config.STRATEGY_PARAMS 144 | 145 | data_path = os.path.join("Data/RealData/hfq/", code_name + ".csv") 146 | 147 | self.load_dataset( 148 | data_path=data_path, 149 | start_stamp=self.config.START_STAMP, 150 | end_stamp=self.config.END_STAMP, 151 | ) 152 | 153 | self.cal_technical_indicators(indicators_config) 154 | 
def run_one_stock(self, code_name=None):
    """Back-test one stock, sweeping strategy parameters when asked to.

    If any value in ``config.STRATEGY_PARAMS`` is a list, every list is
    treated as a sweep axis and the strategy is run once per combination;
    non-list values are held fixed across the sweep. Otherwise the strategy
    is run once with the configured parameters.

    Args:
        code_name: Stock code; defaults to ``config.CODE_NAME``.

    Returns:
        The single run's result, or for a sweep the mean over the runs that
        produced a usable value. ``None`` when no run produced one.
    """
    indicators_config = self.config.STRATEGY_PARAMS

    if code_name is None:
        code_name = self.config.CODE_NAME

    is_sweep = bool(indicators_config) and any(
        isinstance(value, list) for value in indicators_config.values()
    )

    if is_sweep:
        param_names = list(indicators_config)
        # Wrap fixed values so itertools.product does not try to iterate a
        # scalar (TypeError) or split a string into characters.
        axes = [
            value if isinstance(value, list) else [value]
            for value in indicators_config.values()
        ]

        valid_results = []
        for combo in itertools.product(*axes):
            self.logger.debug(combo)
            one_result = self.run_one_stock_once(
                code_name=code_name,
                indicators_config=dict(zip(param_names, combo)),
            )
            # run_one_stock_once returns None when there were no trades;
            # ``x == x`` is False only for NaN. Neither can be averaged.
            if one_result is not None and one_result == one_result:
                valid_results.append(one_result)

        pl_ration = statistics.mean(valid_results) if valid_results else None
    else:
        pl_ration = self.run_one_stock_once(code_name=code_name)

    self.logger.success(f"{code_name}的盈亏比是{pl_ration}")
    self.show_one_stock(self.data)

    return pl_ration


def run(self) -> None:
    """Run the back-test for whatever ``config.CODE_NAME`` designates.

    * list of codes: each code is run and the results are averaged;
    * string containing "ALL_MARKET": codes are loaded from
      ``Data/RealData/ALL_MARKET_CODE.json``; a trailing ``_<N>`` samples N
      of them at random. A code whose run raises is logged and skipped;
    * anything else: a single ``run_one_stock()`` call.

    Results that are ``None`` or NaN are left out of the average, and the
    summary line is only logged when there is a value to report.
    """
    code_name = self.config.CODE_NAME
    self.logger.debug(code_name)

    if isinstance(code_name, list):
        results = [self.run_one_stock(code_name=code) for code in code_name]
        valid_results = [r for r in results if r is not None and r == r]
        pl_ration = statistics.mean(valid_results) if valid_results else None

    elif "ALL_MARKET" in code_name.upper():
        with open("Data/RealData/ALL_MARKET_CODE.json") as all_market_code:
            market_code_dict = json.load(all_market_code)
        self.logger.debug(market_code_dict)

        market_code_list = list(market_code_dict.keys())

        if code_name.upper() != "ALL_MARKET":
            sample_num = int(code_name.split("_")[-1])
            market_code_list = random.sample(market_code_list, sample_num)

        self.logger.debug(market_code_list)

        valid_results = []
        for code in market_code_list:
            self.logger.info(code)
            try:
                one_pl_ration = self.run_one_stock(code_name=code)
            except Exception as e:
                # Best-effort sweep: one broken data file must not abort the
                # whole market run, but the failure should be visible.
                self.logger.warning(f"{code} skipped: {e!r}")
                continue
            if one_pl_ration is not None and one_pl_ration == one_pl_ration:
                valid_results.append(one_pl_ration)

        # statistics.mean raises on an empty list, so guard it.
        pl_ration = statistics.mean(valid_results) if valid_results else None

    else:
        pl_ration = self.run_one_stock()

    if pl_ration is not None:
        self.logger.success(
            f"策略交易一次的收益的数学期望为:{pl_ration * 100:.2f}%"
        )
def base_trade(self, data) -> pd.DataFrame:
    """Walk ``data`` row by row and simulate one-position-at-a-time trading.

    The ``buy_logic`` / ``sell_logic`` hooks are called with no arguments,
    so implementations must read the current row and open position from
    ``self.trade_state``.

    Side effect: ``data`` is modified in place; ``data.loc[index, "buy"]``
    and ``data.loc[index, "sell"]`` are set to 1 on the rows where a trade
    is opened or closed.

    Args:
        data: Frame iterated with ``iterrows()``; each row is passed to
            ``buy`` / ``sell`` as ``trading_step``.

    Returns:
        A DataFrame with one row per completed (bought and sold) trade,
        built from the ``OneTransactionRecord`` fields plus a ``pct``
        column. When no trade completed, the empty frame is returned
        without a ``pct`` column.
    """
    # Start with no open position and no remembered previous row.
    self.trade_state.one_transaction_record = OneTransactionRecord()

    self.trade_state.history_trading_step = []
    transaction_record_list = []

    for index, trading_step in data.iterrows():
        # The first row only seeds the history; no trading happens on it.
        if len(self.trade_state.history_trading_step) == 0:
            self.trade_state.history_trading_step.append(trading_step)
            continue

        self.trade_state.trading_step = trading_step
        # Open a position only when none is open (buy_date still None).
        if (
            self.trade_state.one_transaction_record.buy_date is None
            and self.buy_logic()
        ):
            data.loc[index, "buy"] = 1
            one_transaction_record = self.buy(
                index, trading_step, self.trade_state.one_transaction_record
            )
            # NOTE: this skips the history update at the bottom of the loop.
            continue

        # Close only an open position, and never on the row whose date
        # equals the buy date.
        if (
            self.trade_state.one_transaction_record.buy_date != trading_step.date
            and self.trade_state.one_transaction_record.buy_date is not None
            and self.sell_logic()
        ):
            data.loc[index, "sell"] = 1
            one_transaction_record = self.sell(
                index, trading_step, self.trade_state.one_transaction_record
            )

            transaction_record_list.append(one_transaction_record)
            # Fresh record: the next iteration may open a new position.
            self.trade_state.one_transaction_record = OneTransactionRecord()

        # Keep only the most recent row as history.
        self.trade_state.history_trading_step.append(trading_step)
        if len(self.trade_state.history_trading_step) > 1:
            self.trade_state.history_trading_step.pop(0)

    self.logger.debug(transaction_record_list)
    transaction_record_df = pd.DataFrame(transaction_record_list)
    self.logger.debug(transaction_record_df)

    if len(transaction_record_df) == 0:
        return transaction_record_df

    # Per-trade return: price ratio scaled once by (1 - trade_rate),
    # rounded to 4 decimals, minus 1.
    transaction_record_df["pct"] = (
        round(
            (
                transaction_record_df["sell_price"]
                / transaction_record_df["buy_price"]
            )
            * (1 - self.trade_rate),
            4,
        )
        - 1
    )

    self.logger.info(transaction_record_df)

    return transaction_record_df
def sell_logic(self, trading_step=None, one_transaction_record=None, *args, **kwargs):
    """Decide whether the held asset should be sold on this step.

    ``CoreTradeLogic.base_trade`` invokes this hook with no arguments, so
    both parameters are optional and default to the state that
    ``base_trade`` stores on ``self.trade_state`` before calling the hook.
    Explicit arguments are still honoured.

    Args:
        trading_step: Current row; must support ``["choose_assert"]``.
            Defaults to ``self.trade_state.trading_step``.
        one_transaction_record: Open position; its ``pos_asset`` is
            compared. Defaults to
            ``self.trade_state.one_transaction_record``.

    Returns:
        True when the asset selected for this step differs from the asset
        currently held, otherwise False.
    """
    if trading_step is None:
        trading_step = self.trade_state.trading_step
    if one_transaction_record is None:
        one_transaction_record = self.trade_state.one_transaction_record

    return bool(trading_step["choose_assert"] != one_transaction_record.pos_asset)
pd.read_csv(self.config.SAVE_PATH) 124 | 125 | # choose_data['choose_assert'].dropna(inplace=True) 126 | choose_data = choose_data[~pd.isnull(choose_data["choose_assert"])] 127 | 128 | if self.config.LOG_LEVEL == "DEBUG": 129 | choose_data = choose_data[:100] 130 | 131 | choose_data.reset_index(drop=True, inplace=True) 132 | 133 | self.logger.success(choose_data) 134 | 135 | transaction_record_df = self.base_trade(choose_data) 136 | 137 | pl = self.transaction_analysis.cal_trader_analysis(transaction_record_df) 138 | 139 | return pl 140 | -------------------------------------------------------------------------------- /BackTrader/position_analysis.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2022/1/3 15:53 4 | # @Author : Adolf 5 | # @File : position_analysis.py 6 | # @Function: 7 | import pandas as pd 8 | 9 | from Utils.base_utils import run_once 10 | 11 | 12 | class BaseTransactionAnalysis: 13 | def __init__(self, logger): 14 | self.logger = logger 15 | 16 | @staticmethod 17 | def cal_max_down(df, pct_name="strategy_net", time_stamp="date"): 18 | res_df = df.copy() 19 | res_df["max2here"] = res_df[pct_name].expanding().max() 20 | res_df["dd2here"] = res_df[pct_name] / res_df["max2here"] - 1 21 | # 计算最大回撤,以及最大回撤结束时间 22 | end_date, max_draw_down = tuple( 23 | res_df.sort_values(by=["dd2here"]).iloc[0][[time_stamp, "dd2here"]] 24 | ) 25 | # 计算最大回撤开始时间 26 | start_date = ( 27 | res_df[res_df[time_stamp] <= end_date] 28 | .sort_values(by=pct_name, ascending=False) 29 | .iloc[0][time_stamp] 30 | ) 31 | # 将无关的变量删除 32 | res_df.drop(["max2here", "dd2here"], axis=1, inplace=True) 33 | return max_draw_down, start_date, end_date 34 | 35 | def cal_trader_analysis(self, data): 36 | # self.logger.debug(data) 37 | 38 | # 计算策略的收益率 39 | data["strategy_net"] = (1 + data["pct"]).cumprod() 40 | data["pct_show"] = data["pct"].apply(lambda x: format(x, ".2%")) 41 | 42 | 
strategy_pct = data.tail(1)["strategy_net"].item() 43 | 44 | # 计算策略的成功率 45 | success_rate = len(data[data["pct"] > 0]) / len(data) 46 | 47 | # 计算策略的赔率 48 | odds = data["pct"].mean() 49 | 50 | # 计算盈亏比 51 | if len(data[data["pct"] > 0]) > 0: 52 | profit = data[data["pct"] > 0]["pct"].mean() 53 | else: 54 | profit = 0 55 | 56 | if len(data[data["pct"] < 0]) > 0: 57 | loss = data[data["pct"] < 0]["pct"].mean() 58 | else: 59 | loss = 0 60 | 61 | pl_ratio = profit * success_rate + loss * (1 - success_rate) 62 | 63 | # 计算策略的平均持股天数 64 | mean_holding_day = data["holding_time"].mean() 65 | 66 | # 计算总持有时间 67 | sum_holding_day = data["holding_time"].sum() 68 | 69 | # 统计策略的交易次数 70 | trade_nums = len(data) 71 | 72 | # strategy_annual_return = strategy_pct ** (250 / len(self.Data)) - 1 73 | 74 | # 统计策略的最大回撤 75 | strategy_max_draw_down, strategy_start_date, strategy_end_date = ( 76 | self.cal_max_down(df=data, pct_name="strategy_net", time_stamp="buy_date") 77 | ) 78 | 79 | result_dict = dict() 80 | result_dict["股票代码"] = data.loc[0, "pos_asset"] 81 | result_dict["平均持有时间"] = mean_holding_day 82 | result_dict["交易次数"] = trade_nums 83 | result_dict["计算总持有时间"] = sum_holding_day 84 | 85 | result_dict["策略收益率"] = strategy_pct 86 | 87 | result_dict["策略成功率"] = success_rate 88 | result_dict["策略赔率"] = odds 89 | 90 | result_dict["策略最大回撤"] = strategy_max_draw_down 91 | 92 | result_dict["策略最大回撤开始时间"] = strategy_start_date 93 | result_dict["策略最大回撤结束时间"] = strategy_end_date 94 | 95 | result_dict["策略的盈亏比"] = pl_ratio 96 | # self.logger.info(result_dict) 97 | 98 | result_df = pd.DataFrame.from_dict( 99 | result_dict, orient="index", columns=["result"] 100 | ) 101 | self.logger.success(result_df) 102 | 103 | return result_df 104 | 105 | @run_once 106 | def cal_asset_analysis(self, data): 107 | # 计算标的收益率 108 | asset_pct = data.close[len(data) - 1] / data.close[0] 109 | 110 | # 计算标的年化 111 | asset_pct_annual_return = asset_pct ** (250 / len(data)) - 1 112 | 113 | # 统计策略的最大回撤 114 | asset_max_draw_down, 
def normalization(data):
    """Min-max scale *data* so its smallest value maps to 0 and largest to 1.

    A constant input has zero range; dividing would give 0/0 = NaN for every
    element, so an all-zero float array of the same shape is returned
    instead.

    Args:
        data: Array-like of numbers accepted by ``np.min`` / ``np.max``.

    Returns:
        Array of the same shape as *data* with values in [0, 1].
    """
    lowest = np.min(data)
    _range = np.max(data) - lowest
    if _range == 0:
        return np.zeros_like(data, dtype=float)
    return (data - lowest) / _range
def get_choose_board(top_n=10):
    """Pick the industry boards with the strongest recent momentum.

    Every file listed in ``board_list`` is scored with
    ``cal_one_board_mom``; the score is the fitted slope multiplied by the
    fit's R², and boards are ranked by that score in descending order.

    Args:
        top_n: Number of boards to keep (default 10).

    Returns:
        Board names (file name without ``.csv``) of the ``top_n`` highest
        scoring boards, best first. The names and their mapped codes are
        also printed.
    """
    board_res = {}

    for board_name in board_list:
        (w0, R2) = cal_one_board_mom(board_name)
        board_res[board_name.replace(".csv", "")] = w0 * R2

    # Stable descending sort on the score keeps tie order as inserted.
    choose_board = sorted(board_res, key=board_res.get, reverse=True)[:top_n]
    choose_board_code = [industry_board_name_mapping[key] for key in choose_board]

    print(choose_board)
    print(choose_board_code)

    return choose_board
stock_board_industry_cons[["code", "name", "price"]] 107 | stock_board_industry_cons["hot_rank"] = stock_board_industry_cons["code"].apply( 108 | lambda x: stock_hot_rank[x] if x in stock_hot_rank else 9999 109 | ) 110 | 111 | stock_board_industry_cons.sort_values(by="hot_rank", inplace=True) 112 | # stock_board_industry_cons = stock_board_industry_cons[::-1] 113 | # print(stock_board_industry_cons) 114 | # print(stock_board_industry_cons[:5]) 115 | res_df_list.append(stock_board_industry_cons[:5]) 116 | 117 | # break 118 | res_df = pd.concat(res_df_list) 119 | res_df.sort_values(by="hot_rank", inplace=True) 120 | res_df.reset_index(drop=True, inplace=True) 121 | 122 | res_df = res_df[: len(res_df) // 2] 123 | print(res_df) 124 | -------------------------------------------------------------------------------- /GetBaseData/__init__.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2021/11/30 23:07 4 | # @Author : Adolf 5 | # @File : __init__.py.py 6 | # @Function: 7 | -------------------------------------------------------------------------------- /GetBaseData/ch_eng_mapping.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2021/11/30 23:20 4 | # @Author : Adolf 5 | # @File : ch_eng_mapping.py 6 | # @Function: 7 | 8 | ch_eng_mapping_dict = { 9 | "序号": "index", 10 | "代码": "code", 11 | "名称": "name", 12 | "最新价": "price", 13 | "涨跌额": "priceChg", 14 | "涨跌幅": "pctChg", 15 | "成交量": "volume", 16 | "成交额": "amount", 17 | "振幅": "amplitude", 18 | "最高": "high", 19 | "最低": "low", 20 | "今开": "open", 21 | "昨收": "pre_close", 22 | "量比": "QRR", 23 | "换手率": "turn", 24 | "市盈率-动态": "pe", 25 | "市净率": "pb", 26 | "日期": "date", 27 | "开盘": "open", 28 | "收盘": "close", 29 | "开盘价": "open", 30 | "最高价": "high", 31 | "最低价": "low", 32 | "收盘价": "close", 33 | "现价": "price", 34 | "涨跌": "pctChg", 
35 | "换手": "turn", 36 | "市盈率": "pe", 37 | "总市值": "TMC", 38 | "流通市值": "FMC", 39 | } 40 | -------------------------------------------------------------------------------- /GetBaseData/get_board_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description: 3 | Author: adolf 4 | Date: 2022-07-23 16:04:07 5 | LastEditTime: 2022-08-18 23:25:46 6 | LastEditors: adolf 7 | """ 8 | 9 | import json 10 | import os 11 | 12 | from loguru import logger 13 | from tqdm.auto import tqdm 14 | 15 | from GetBaseData.ch_eng_mapping import ch_eng_mapping_dict 16 | 17 | # import akshare as ak 18 | # from pprint import pprint 19 | from GetBaseData.thirdpart.stock_board_industry_em import ( 20 | stock_board_industry_hist_em, 21 | stock_board_industry_name_em, 22 | ) 23 | 24 | logger.info("开始获取股票板块数据") 25 | 26 | # stock_board_concept_name_em_df = ak.stock_board_concept_name_em() 27 | # print(stock_board_concept_name_em_df) 28 | 29 | # board_name_mapping = stock_board_concept_name_em_df.set_index(['板块名称'])['板块代码'].to_dict() 30 | # pprint(board_name_mapping) 31 | 32 | # with open("Data/BoardData/ALL_BOARD_NAME.json", "w") as all_market_code: 33 | # json.dump(board_name_mapping, all_market_code, ensure_ascii=False) 34 | 35 | # stock_board_industry_name_em_df = ak.stock_board_industry_name_em() 36 | stock_board_industry_name_em_df = stock_board_industry_name_em() 37 | # print(stock_board_industry_name_em_df) 38 | 39 | industry_board_name_mapping = stock_board_industry_name_em_df.set_index(["板块名称"])[ 40 | "板块代码" 41 | ].to_dict() 42 | # pprint(industry_board_name_mapping) 43 | 44 | with open("Data/BoardData/ALL_INDUSTRY_BOARD.json", "w") as all_market_code: 45 | json.dump(industry_board_name_mapping, all_market_code, ensure_ascii=False) 46 | 47 | save_path = "Data/BoardData/industry_origin/" 48 | if not os.path.exists(save_path): 49 | os.mkdir(save_path) 50 | 51 | for key, value in tqdm(industry_board_name_mapping.items()): 52 | # 
@ray.remote
def get_one_stock_cash_data(code):
    """Fetch one stock's individual fund-flow table and save it as a CSV.

    The result is written to ``Data/CashFlow/<code>.csv``. If that file
    already exists it is read back and merged with the freshly fetched table
    before being overwritten.

    Args:
        code: Stock code string. Codes starting with ``"6"`` are queried on
            market ``"sh"``; all others on ``"sz"``.

    Returns:
        None. Any exception is logged together with the code and swallowed so
        one failing stock does not abort the whole batch.
    """
    try:
        # Many of these tasks are launched at once; exist_ok avoids the
        # check-then-create race, and makedirs also creates a missing "Data/".
        os.makedirs("Data/CashFlow/", exist_ok=True)

        csv_path = f"Data/CashFlow/{code}.csv"
        market = "sh" if code[0] == "6" else "sz"
        now = ak.stock_individual_fund_flow(stock=code, market=market)
        if os.path.exists(csv_path):
            origin = pd.read_csv(csv_path)
            # NOTE(review): an inner merge on all shared columns keeps only
            # rows present in BOTH tables, so history can shrink rather than
            # grow - confirm this is intended (vs. concat + drop_duplicates).
            now = pd.merge(origin, now, how="inner")
        now.to_csv(csv_path, index=False)
    except Exception as e:
        logger.error(e)
        logger.error(code)
def to_iterator(obj_ids):
    """Yield task results one at a time, as ``ray.wait`` reports them ready.

    Args:
        obj_ids: List of Ray object references still to be collected.

    Yields:
        The value of ``ray.get`` for the first reference in each ready batch
        returned by ``ray.wait``; the loop ends when no references remain.
    """
    pending = obj_ids
    while pending:
        finished, pending = ray.wait(pending)
        yield ray.get(finished[0])
@ray.remote
def get_one_stock_data(code):
    """Download the daily history of one stock and write it to CSV.

    The backward-adjusted ("hfq") series is always saved under ``hfq_path``.
    When the module-level ``only_hfq`` flag is false, the forward-adjusted
    ("qfq") series and the unadjusted series are also saved, under
    ``qfq_path`` and ``origin_path``. Every saved table has its columns
    renamed through ``ch_eng_mapping_dict`` and gains ``code`` and ``name``
    columns.

    Args:
        code: Stock code string; also used as the CSV file stem.

    Returns:
        0 on success, None when an exception was logged.
    """
    try:
        # (extra akshare keyword arguments, output directory) per series.
        # qfq_path / origin_path are only defined when only_hfq is False, so
        # they must not be referenced unless that branch is taken.
        jobs = [({"adjust": "hfq"}, hfq_path)]
        if not only_hfq:
            jobs.append(({"adjust": "qfq"}, qfq_path))
            # No "adjust" argument: library default, i.e. the raw series.
            jobs.append(({}, origin_path))

        for extra_kwargs, out_dir in jobs:
            stock_zh_a_hist_df = ak.stock_zh_a_hist(symbol=code, **extra_kwargs)
            stock_zh_a_hist_df.rename(columns=ch_eng_mapping_dict, inplace=True)
            stock_zh_a_hist_df["code"] = code
            stock_zh_a_hist_df["name"] = code_name_mapping[code]
            stock_zh_a_hist_df.to_csv(out_dir.joinpath(code + ".csv"), index=False)

        return 0
    except Exception as e:
        logger.error(code)
        logger.error(e)
        # NOTE(review): this runs inside a @ray.remote task, so the append
        # presumably changes a worker-side copy and the driver's
        # error_code_list stays empty - verify before relying on it.
        error_code_list.append(code)
start_date="20170301", 23 | # end_date='20210907', 24 | # adjust="qfq") 25 | # 26 | # print(stock_zh_a_hist_df) 27 | 28 | # 分时数据 29 | # stock_zh_a_hist_min_em_df = ak.stock_zh_a_hist_min_em(symbol="000001", 30 | # # start_date="2021-09-01 09:32:00", 31 | # # end_date="2021-09-06 09:32:00", 32 | # period='5', # choice of {'1', '5', '15', '30', '60'}; 33 | # adjust='qfq') 34 | # # {'', 'qfq', 'hfq'}; '': 不复权, 'qfq': 前复权, 'hfq': 后复权, 35 | # 分时数据只能返回最近的, 36 | # 其中 1 分钟数据返回近 5 个交易日数据且不复权 37 | # TODO 寻找较长时间的分时数据源头 38 | 39 | # print(stock_zh_a_hist_min_em_df) 40 | 41 | # 指数数据 42 | # stock_zh_index_spot_df = ak.stock_zh_index_spot() 43 | # print(stock_zh_index_spot_df) 44 | 45 | # "上证指数", "深证成指", "创业板指", "沪深300","中证500" 46 | # index_list = ["sh000001", "sz399001", "sz399006", "sz399300", "sh000905"] 47 | 48 | # stock_zh_index_daily_tx_df = ak.stock_zh_index_daily_tx(symbol="sh000001") 49 | # print(stock_zh_index_daily_tx_df) 50 | 51 | # 使用baostock获取数据 52 | 53 | lg = bs.login() 54 | date = "2023-01-20" 55 | stock_df = bs.query_all_stock(date).get_data() 56 | # print(stock_df) 57 | 58 | # 获取沪深A股历史K线数据 59 | # 详细指标参数,参见“历史行情指标参数”章节;“分钟线”参数与“日线”参数不同。“分钟线”不包含指数。 60 | # 分钟线指标:date,time,code,open,high,low,close,volume,amount,adjustflag 61 | # 周月线指标:date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg 62 | # adjustflag 复权状态(1:后复权, 2:前复权,3:不复权) 63 | 64 | st_time = time.time() 65 | 66 | for index, row in tqdm(stock_df.iterrows(), total=stock_df.shape[0]): 67 | if row["tradeStatus"] == "0" or "bj" in row["code"]: 68 | continue 69 | code = row["code"] 70 | rs = bs.query_history_k_data_plus( 71 | code, 72 | "date,code,open,high,low,close,volume,amount,turn,peTTM,pctChg,tradestatus,isST", 73 | start_date="1990-12-19", 74 | end_date=date, 75 | frequency="d", 76 | adjustflag="1", 77 | ) 78 | # 打印结果集 79 | data_list = [] 80 | while (rs.error_code == "0") & rs.next(): 81 | # 获取一条记录,将记录合并在一起 82 | data_list.append(rs.get_row_data()) 83 | result = pd.DataFrame(data_list, 
def show_data_from_df(
    df_or_dfpath: "str | pd.DataFrame | None" = None,
    use_all_data: bool = True,
    start_date: str = None,
    end_date: str = None,
):
    """Convert K-line data into the list-based dict used for charting.

    Args:
        df_or_dfpath: A DataFrame, or the path of a CSV file to load. It must
            provide ``date``, ``open``, ``close``, ``high``, ``low`` and
            ``volume`` columns. If there is no ``MACD`` column, ``DIFF``,
            ``DEA`` and ``MACD`` are computed from ``close``.
        use_all_data: When false, only the last 300 rows are kept.
        start_date: If given, keep rows whose ``date`` is >= this value.
        end_date: If given, keep rows whose ``date`` is <= this value.

    Returns:
        dict with keys ``datas`` (``[open, close, high, low]`` per row),
        ``times``, ``vols``, ``macds``, ``difs``, ``deas``, ``buy`` and
        ``sell``. ``buy``/``sell`` are empty lists when the columns are
        absent. If an ``index`` column exists, each time is
        ``"<date>_<index>"``.

    Raises:
        ValueError: If ``df_or_dfpath`` is neither a str nor a DataFrame.
    """
    if isinstance(df_or_dfpath, pd.DataFrame):
        # Work on a copy: the caller's frame must not gain MACD columns or
        # have its dates rewritten (repeated calls used to stack suffixes).
        df = df_or_dfpath.copy()
    elif isinstance(df_or_dfpath, str):
        df = pd.read_csv(df_or_dfpath)
    else:
        raise ValueError("df_or_dfpath must be str or pd.DataFrame")

    if "MACD" not in df.columns:
        df["DIFF"], df["DEA"], df["MACD"] = MACD(df["close"])

    if not use_all_data:
        df = df[-300:]

    if start_date is not None:
        df = df[df["date"] >= start_date]

    if end_date is not None:
        df = df[df["date"] <= end_date]

    # Row order is open, close, high, low.
    datas = [
        list(oclh)
        for oclh in zip(
            df["open"].tolist(),
            df["close"].tolist(),
            df["high"].tolist(),
            df["low"].tolist(),
        )
    ]

    # Build the labels as a separate Series instead of overwriting "date".
    times = df["date"]
    if "index" in df:
        times = times + "_" + df["index"].map(str)

    buy_list = df["buy"].tolist() if "buy" in df else []
    sell_list = df["sell"].tolist() if "sell" in df else []

    return {
        "datas": datas,
        "times": times.tolist(),
        "vols": df["volume"].tolist(),
        "macds": df["MACD"].tolist(),
        "difs": df["DIFF"].tolist(),
        "deas": df["DEA"].tolist(),
        "buy": buy_list,
        "sell": sell_list,
    }
def cal_base_indicators(code_path):
    """Load one daily CSV and attach moving-average and MACD columns.

    Args:
        code_path: File name inside ``Data/Baostock/day/``.

    Returns:
        The loaded DataFrame with ``MA5``, ``MA10``, ``MA20``, ``MA30`` and
        ``MA60`` columns set from ``TA.SMA``. If the file has no
        ``HISTOGRAM`` column, the frame returned by ``TA.MACD`` (with an
        added ``HISTOGRAM`` = ``MACD`` - ``SIGNAL`` column) is concatenated
        column-wise.
    """
    frame = pd.read_csv("Data/Baostock/day/" + code_path)

    for window in (5, 10, 20, 30, 60):
        frame[f"MA{window}"] = TA.SMA(frame, period=window)

    # Already processed on an earlier run: do not append the MACD columns twice.
    if "HISTOGRAM" in frame.columns:
        return frame

    macd_frame = TA.MACD(frame)
    macd_frame["HISTOGRAM"] = macd_frame["MACD"] - macd_frame["SIGNAL"]
    return pd.concat([frame, macd_frame], axis=1)
def get_base_k_data(code, start_date=None, end_date=None, frequency=None):
    """Query K-line history for one code from baostock and save it as CSV.

    The query uses ``adjustflag="1"`` and the result is written to
    ``Data/RealData/hfq/<code>.csv``.

    Args:
        code: baostock security code, e.g. ``"sh.600570"``.
        start_date: First date to request; ``None`` means ``"1990-12-19"``.
        end_date: Last date to request; ``None`` means the module-level
            ``date`` value.
        frequency: baostock frequency string; ``None`` means ``"d"``.

    Returns:
        pandas.DataFrame: The rows that were fetched, with ``rs.fields`` as
        column names.
    """
    # The three optional arguments were already None-tolerant in the body;
    # giving them defaults lets callers actually omit them.
    if start_date is None:
        start_date = "1990-12-19"
    if end_date is None:
        end_date = date
    if frequency is None:
        frequency = "d"
    rs = bs.query_history_k_data_plus(
        code,
        "date,code,open,high,low,close,volume,amount,turn,peTTM,pctChg,tradestatus,isST",
        start_date=start_date,
        end_date=end_date,
        frequency=frequency,
        adjustflag="1",
    )
    # Collect the result set row by row.
    data_list = []
    # `and` short-circuits, so the cursor is not advanced after an error.
    while rs.error_code == "0" and rs.next():
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)
    result.to_csv("Data/RealData/hfq/" + code + ".csv", index=False)
    return result
def get_base_k_data(code, last_day, file_name="day", frequency="d"):
    """Query unadjusted K-line data for one code and save it as CSV.

    The query uses ``adjustflag="3"`` and requests the fields
    ``date,code,open,high,low,close,volume,amount``. The file is written to
    ``Data/Baostock/<file_name>/<code>.csv``; the directory is created when
    missing.

    Args:
        code: baostock security code, e.g. ``"sh.000001"``.
        last_day: Currently unused - the ``end_date`` argument of the query
            is commented out. Kept so existing callers keep working.
        file_name: Sub-directory name under ``Data/Baostock``; works with or
            without a trailing slash.
        frequency: baostock frequency string.

    Returns:
        None.
    """
    rs = bs.query_history_k_data_plus(
        code,
        # "date,code,open,high,low,close,volume,amount,turn,peTTM,tradestatus,isST",
        "date,code,open,high,low,close,volume,amount",
        # start_date="1990-12-19",
        # end_date=last_day,
        frequency=frequency,
        adjustflag="3",
    )
    # Collect the result set row by row.
    data_list = []
    # `and` short-circuits, so the cursor is not advanced after an error.
    while rs.error_code == "0" and rs.next():
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)

    # Join path components instead of concatenating strings: with the default
    # file_name="day" (no trailing slash) concatenation produced
    # "Data/Baostock/day<code>.csv", outside the directory created here.
    out_dir = os.path.join("Data/Baostock", file_name)
    os.makedirs(out_dir, exist_ok=True)
    result.to_csv(os.path.join(out_dir, code + ".csv"), index=False)
"week/", 54 | # "m": "month/", 55 | } 56 | 57 | last_day = "2023-06-10" 58 | 59 | stock_df = bs.query_all_stock(last_day).get_data() 60 | # print(stock_df) 61 | code_list = [] 62 | for _, row in stock_df.iterrows(): 63 | if ( 64 | "510300" in row.code 65 | or "000001" in row.code 66 | or "399006" in row.code 67 | or "399106" in row.code 68 | ): 69 | print("row.code:", row.code) 70 | code_list.append(row.code) 71 | if row.code[:6] in [ 72 | "sh.600", 73 | "sh.601", 74 | "sh.603", 75 | "sh.605", 76 | "sz.300", 77 | "sz.000", 78 | "sz.002", 79 | ]: 80 | code_list.append(row.code) 81 | # print(code_list) 82 | # print(len(code_list)) 83 | code_list = ["sh.000001", "sz.399001", "sz.399006"] 84 | 85 | for code_name in tqdm(code_list, total=len(code_list)): 86 | for frequency in time_cycle: 87 | file_name = time_cycle[frequency] 88 | get_base_k_data( 89 | code=code_name, 90 | last_day=last_day, 91 | file_name=file_name, 92 | frequency=frequency, 93 | ) 94 | 95 | bs.logout() 96 | 97 | 98 | main() 99 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 PKQ1688 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MachineLearning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKQ1688/stock_quant/c88e695a5cf17f2a445b671026f1535d369910ce/MachineLearning/__init__.py -------------------------------------------------------------------------------- /MachineLearning/annotation_platform/buy_and_sell_signals.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author : adolf adolf1321794021@gmail.com 3 | @Date : 2023-03-09 21:40:11 4 | @LastEditors : adolf 5 | @LastEditTime : 2023-03-12 21:29:22 6 | @FilePath : /stock_quant/MachineLearning/annotation_platform/buy_and_sell_signals.py 7 | @Description : 8 | """ 9 | 10 | import os 11 | import random 12 | from datetime import datetime, timedelta 13 | 14 | import akshare as ak 15 | import numpy as np 16 | import pandas as pd 17 | import streamlit as st 18 | from loguru import logger 19 | from streamlit_echarts import st_pyecharts 20 | 21 | from GetBaseData.handle_data_show import show_data_from_df 22 | from Utils.ShowKline.base_kline import draw_chart 23 | 24 | # st.set_page_config(page_title="股票买卖标注平台", layout="wide") 25 | 26 | 27 | def _init_session_state(): 28 | # 缓存内存 29 | if "code_name" not in st.session_state: 30 | st.session_state["code_name"] = "sh.600570" 31 | 32 | if "start_time" not in st.session_state: 33 | 
def base_show_fun(code_name, start_time, end_time):
    """Draw the chart for one stock and ask the user for a trade action.

    Args:
        code_name: File stem of the CSV under ``Data/RealData/Baostock/day/``.
        start_time: First date to show; formatted as ``%Y-%m-%d``.
        end_time: Last date to show; formatted as ``%Y-%m-%d``.

    Returns:
        tuple: ``(trade_result, show_data)`` - the option picked in the radio
        widget and the dict returned by ``show_data_from_df``.

    Side effects:
        Renders Streamlit widgets and flips
        ``st.session_state["account_status"]`` between ``"cash"`` and
        ``"hold"`` when the buy or sell option is selected.
    """
    date_fmt = "%Y-%m-%d"
    start_time = datetime.strftime(start_time, date_fmt)
    end_time = datetime.strftime(end_time, date_fmt)

    _show_data = show_data_from_df(
        f"Data/RealData/Baostock/day/{code_name}.csv",
        start_date=start_time,
        end_date=end_time,
    )
    st_pyecharts(draw_chart(_show_data), height="600%", width="100%")

    # Banner text, the action on offer, and the status that action leads to.
    if st.session_state["account_status"] == "cash":
        banner, action, next_status = "账户状态:现金", "买", "hold"
    else:
        banner, action, next_status = "账户状态:持仓", "卖", "cash"

    st.write(banner)
    _trade_result = st.radio(label="操作", options=[action, "保持"])
    if _trade_result == action:
        st.session_state["account_status"] = next_status

    return _trade_result, _show_data
{}".format(st.session_state["code_name"])) 98 | 99 | if st.session_state["start_time"] is None: 100 | st.session_state["start_time"] = st.date_input( 101 | "开始时间", 102 | value=pd.to_datetime("2019-01-01"), 103 | max_value=pd.to_datetime("2023-03-01"), 104 | ) 105 | 106 | if st.session_state["end_time"] is None: 107 | st.session_state["end_time"] = st.date_input( 108 | "结束时间", 109 | value=pd.to_datetime("2020-01-01"), 110 | max_value=pd.to_datetime("2023-03-01"), 111 | ) 112 | 113 | while st.session_state["end_time"] not in st.session_state["trade_calendar"]: 114 | st.session_state["end_time"] = st.session_state["end_time"] + timedelta( 115 | days=1 116 | ) 117 | 118 | trade_result, show_data = base_show_fun( 119 | code_name=st.session_state["code_name"], 120 | start_time=st.session_state["start_time"], 121 | end_time=st.session_state["end_time"], 122 | ) 123 | next_button = st.button("next day") 124 | 125 | if next_button: 126 | with open(f"Data/LabelData/{st.session_state['code_name']}.tsv", "a") as f: 127 | save_data = list() 128 | save_data.append(st.session_state["code_name"]) 129 | save_data.append(show_data["times"][0]) 130 | save_data.append(show_data["times"][-1]) 131 | save_data.append(trade_result) 132 | save_data.append(st.session_state["account_status"]) 133 | line = "\t".join(save_data) 134 | f.write(f"{line}\n") 135 | 136 | st.session_state["end_time"] = st.session_state["end_time"] + timedelta( 137 | days=1 138 | ) 139 | while ( 140 | st.session_state["end_time"] not in st.session_state["trade_calendar"] 141 | ): 142 | st.session_state["end_time"] = st.session_state["end_time"] + timedelta( 143 | days=1 144 | ) 145 | st.success("保存成功", icon="✅") 146 | st.experimental_rerun() 147 | 148 | refresh = st.button("重开") 149 | if refresh: 150 | _init_session_state() 151 | st.experimental_rerun() 152 | 153 | with dataset_tab: 154 | # pass 155 | rank_texts_list = [] 156 | # logger.info(st.session_state) 157 | # 判断一个文件是否存在 158 | if 
def init_dict(window=60):
    """Create the empty column dictionary for one flattened sample table.

    Args:
        window: Number of look-back steps. For every step ``i`` in
            ``range(window)`` the keys ``open_i``, ``high_i``, ``low_i``,
            ``close_i``, ``volume_i`` and ``turn_i`` are created, in that
            order. Defaults to 60.

    Returns:
        dict[str, list]: The per-step keys followed by a final ``"label"``
        key, each mapped to its own empty list.
    """
    features = ("open", "high", "low", "close", "volume", "turn")
    # Step-major order: all six features of step 0, then step 1, and so on.
    new_data_dict = {
        f"{feature}_{i}": [] for i in range(window) for feature in features
    }
    new_data_dict["label"] = []
    return new_data_dict
def get_handle_data(data_name):
    """Turn one stock's daily K-line CSV into a table of 60-day samples.

    Reads ``Data/RealData/hfq/<data_name>``. For every row position ``index``
    with ``60 <= index < len - 1`` the 60 rows before ``index`` form one
    sample: open/high/low/close/volume are min-max scaled inside that window,
    ``turn`` is copied unchanged, and the label is a strength bucket taken from
    ``pctChg`` of row ``index + 1``. The table is written to
    ``Data/HandleData/base_ohlcv_data/<data_name>``.

    Args:
        data_name: File name of the CSV inside ``Data/RealData/hfq``.

    Returns:
        The sample DataFrame, or ``None`` if processing failed (the traceback
        and ``data_name`` are printed in that case).
    """
    try:
        data_t = pd.read_csv(f"Data/RealData/hfq/{data_name}", dtype={"code": str})
        data_t = data_t[
            ["date", "open", "high", "low", "close", "volume", "turn", "code", "pctChg"]
        ]

        new_data_dict = init_dict()

        # The first 60 rows have no full look-back window and the last row has
        # no following day to label, so only these positions yield samples.
        for index in range(60, len(data_t) - 1):
            tmp_data = data_t[index - 60 : index].copy()

            # Scale each price/volume column to [0, 1] within its own window.
            # A constant column gives 0/0 and therefore NaN, as before.
            for feature in ["open", "high", "low", "close", "volume"]:
                tmp_data[feature] = tmp_data[[feature]].apply(
                    lambda x: (x - np.min(x)) / (np.max(x) - np.min(x))
                )

            for feature in ("open", "high", "low", "close", "volume", "turn"):
                window_values = tmp_data[feature].values.tolist()
                for i in range(60):
                    new_data_dict[f"{feature}_{i}"].append(window_values[i])

            # NOTE(review): the window ends at row index - 1 while the label is
            # read from row index + 1, so row `index` itself is never used —
            # confirm that this one-day gap is intentional.
            next_pct = data_t.loc[index + 1, "pctChg"]
            if next_pct > 7:
                label = "超强"
            elif next_pct > 3:
                label = "中强"
            elif next_pct > 0:
                label = "小强"
            elif next_pct > -3:
                label = "小弱"
            elif next_pct > -7:
                label = "中弱"
            else:
                label = "超弱"
            new_data_dict["label"].append(label)

        res_data = pd.DataFrame(new_data_dict)

        # exist_ok avoids the check-then-create race between parallel workers
        # and also creates a missing parent directory.
        os.makedirs("Data/HandleData/base_ohlcv_data", exist_ok=True)

        res_data.to_csv(f"Data/HandleData/base_ohlcv_data/{data_name}", index=False)
        return res_data
    except Exception:
        # Best-effort per file: report and let the caller continue with the
        # next one, without swallowing KeyboardInterrupt/SystemExit.
        print(traceback.format_exc())
        print(data_name)
        return None
def get_handle_data(data_name):
    """Turn one stock's daily K-line CSV into a table of 60-day samples.

    Reads ``Data/RealData/hfq/<data_name>``. For every row position ``index``
    with ``60 <= index < len - 1`` the 60 rows before ``index`` form one
    sample: open/high/low/close/volume are min-max scaled inside that window,
    ``turn`` is copied unchanged, and the label is an integer class derived
    from ``pctChg`` of row ``index + 1`` (0 = strongest rise ... 5 = strongest
    fall). The table is written to
    ``Data/HandleData/base_ohlcv_data/<data_name>``.

    Args:
        data_name: File name of the CSV inside ``Data/RealData/hfq``.

    Returns:
        The sample DataFrame, or ``None`` if processing failed (the traceback
        and ``data_name`` are printed in that case).
    """
    # Imported here because the module only imports ``os`` under
    # ``if __name__ == "__main__"``; without this the function raises
    # NameError when the module is imported rather than run as a script.
    import os

    try:
        data_t = pd.read_csv(f"Data/RealData/hfq/{data_name}", dtype={"code": str})
        data_t = data_t[
            ["date", "open", "high", "low", "close", "volume", "turn", "code", "pctChg"]
        ]

        new_data_dict = init_dict()

        # The first 60 rows have no full look-back window and the last row has
        # no following day to label, so only these positions yield samples.
        for index in range(60, len(data_t) - 1):
            tmp_data = data_t[index - 60 : index].copy()

            # Scale each price/volume column to [0, 1] within its own window.
            # A constant column gives 0/0 and therefore NaN, as before.
            for feature in ["open", "high", "low", "close", "volume"]:
                tmp_data[feature] = tmp_data[[feature]].apply(
                    lambda x: (x - np.min(x)) / (np.max(x) - np.min(x))
                )

            for feature in ("open", "high", "low", "close", "volume", "turn"):
                window_values = tmp_data[feature].values.tolist()
                for i in range(60):
                    new_data_dict[f"{feature}_{i}"].append(window_values[i])

            # NOTE(review): the window ends at row index - 1 while the label is
            # read from row index + 1, so row `index` itself is never used —
            # confirm that this one-day gap is intentional.
            next_pct = data_t.loc[index + 1, "pctChg"]
            if next_pct > 7:
                label = 0  # very strong
            elif next_pct > 3:
                label = 1  # moderately strong
            elif next_pct > 0:
                label = 2  # slightly strong
            elif next_pct > -3:
                label = 3  # slightly weak
            elif next_pct > -7:
                label = 4  # moderately weak
            else:
                label = 5  # very weak
            new_data_dict["label"].append(label)

        res_data = pd.DataFrame(new_data_dict)

        # exist_ok avoids the check-then-create race between parallel workers
        # and also creates a missing parent directory.
        os.makedirs("Data/HandleData/base_ohlcv_data", exist_ok=True)

        res_data.to_csv(f"Data/HandleData/base_ohlcv_data/{data_name}", index=False)
        return res_data
    except Exception:
        # Best-effort per file: report the failure and signal it with None so
        # callers can tell a failed file from a DataFrame result.
        print(traceback.format_exc())
        print(data_name)
        return None
def cal_indicators(filename):
    """Compute indicator features for one K-line CSV and save them.

    Adds 5- and 10-period simple moving averages of ``close``, the three MACD
    outputs of ``ta.macd`` (stored as ``macd``, ``histogram``, ``signal``), a
    14-period ATR, and ``pct`` (the next row's ``pctChg``). It then drops the
    descriptive columns and all rows containing NaN, and writes the result to
    ``Data/HandleData/indicator_data/<file name>``.

    Args:
        filename: Path (``str`` or ``pathlib.Path``) of the source CSV.

    Returns:
        The processed DataFrame, or ``None`` if any step failed (the error and
        the file name are printed in that case).
    """
    # Local import: the module itself imports pathlib only under __main__.
    from pathlib import Path

    # Normalising here lets both str and Path arguments work, including on the
    # error path below.
    csv_name = Path(filename).name
    try:
        data = pd.read_csv(filename)

        # 5-day and 10-day moving averages.
        data["sma5"] = ta.sma(data.close, length=5)
        data["sma10"] = ta.sma(data.close, length=10)

        # MACD values.
        data[["macd", "histogram", "signal"]] = ta.macd(
            data.close, fast=12, slow=26, signal=9
        )

        # Average true range.
        data["atr"] = ta.atr(data.high, data.low, data.close, length=14)

        # Next row's percentage change; the last row becomes NaN and is removed
        # by dropna() below.
        data["pct"] = data.pctChg.shift(-1)

        data.drop(
            ["date", "amount", "amplitude", "priceChg", "code", "name"],
            axis=1,
            inplace=True,
        )

        data.dropna(inplace=True)
        data.to_csv(f"Data/HandleData/indicator_data/{csv_name}", index=False)
        return data
    except Exception as e:
        print(e)
        print(csv_name)
        return None
"""Train a baseline LightGBM regressor on the pre-computed indicator tables.

Stock 000001 is the training set and stock 000002 the validation/test set;
the regression target is the ``pctChg`` column.
"""

import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error

print("Load data...")
df_train = pd.read_csv("Data/HandleData/indicator_data/000001.csv")
df_test = pd.read_csv("Data/HandleData/indicator_data/000002.csv")

y_train = df_train["pctChg"].values
y_test = df_test["pctChg"].values

X_train = df_train.drop("pctChg", axis=1).values
X_test = df_test.drop("pctChg", axis=1).values

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

# specify your configurations as a dict
params = {
    "task": "train",
    "boosting_type": "gbdt",
    "objective": "regression",
    # AUC is a classification metric and a set has no stable order, which made
    # the metric driving early stopping unpredictable; use the L2 loss only.
    "metric": ["l2"],
    "num_leaves": 31,
    "learning_rate": 0.05,
    "feature_fraction": 0.9,
    "bagging_fraction": 0.8,
    "bagging_freq": 5,
    "verbose": 0,
}

print("Start training...")
# train; early stopping goes through the callback API because the
# ``early_stopping_rounds`` keyword is no longer accepted by newer LightGBM.
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=20,
    valid_sets=[lgb_eval],
    callbacks=[lgb.early_stopping(stopping_rounds=5)],
)

print("Save model...")
# save model to file
gbm.save_model("model.txt")

print("Start predicting...")
# predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
# eval
print("The rmse of prediction is:", mean_squared_error(y_test, y_pred) ** 0.5)
def pd_ztjytime(now=None):
    """Tell whether the given moment falls inside the midday trading pause.

    The pause is the open interval between 11:30:01 and 13:00:00 on the day of
    ``now``; sub-second precision is ignored. Only the time of day is checked,
    not whether the date is a trading day.

    Args:
        now: The ``datetime.datetime`` to test. Defaults to the current local
            time, read once so all bounds refer to the same calendar day.

    Returns:
        ``True`` while trading is paused, ``False`` otherwise.
    """
    current = datetime.datetime.now() if now is None else now
    current = current.replace(microsecond=0)
    pause_start = current.replace(hour=11, minute=30, second=1)
    pause_end = current.replace(hour=13, minute=0, second=0)
    return pause_start < current < pause_end
def monitor_condition(code_res):
    """Run the per-stock monitor rule and report whether it fired.

    The rule is the module-level function named ``monitor_function_<code>``,
    where ``<code>`` is ``code_res["代码"]``. It is looked up by name instead of
    being passed through ``exec``, so text in the quote data can never be
    executed as code. The rule signals a hit by setting
    ``code_res["condition"]`` to a truthy value.

    Args:
        code_res: Quote dict for one stock; mutated in place (``condition`` is
            reset to ``False`` before the rule runs).

    Returns:
        ``True`` if the rule set ``condition``; ``False`` if it did not, if no
        rule exists for the code, or if the rule raised (the error is printed).
    """
    code_res["condition"] = False
    try:
        handler_name = "monitor_function_{}".format(code_res["代码"])
        handler = globals().get(handler_name)
        if not callable(handler):
            # Same message the interpreter produced for a missing rule before.
            raise NameError(f"name '{handler_name}' is not defined")
        handler(code_res)
    except Exception as e:
        # A faulty or missing rule must not stop monitoring of other stocks.
        print(e)

    return bool(code_res["condition"])
优化回测代码交易核心代码部分 19 | - [ ] 构建多机器分布式机器学习框架 20 | 21 | ## 项目的安装 22 | 23 | ### 1.1 安装项目的依赖 24 | 25 | ```bash 26 | brew install ta-lib 27 | pip install poetry 28 | poetry install 29 | ``` 30 | 31 | ## 项目的基础设定 32 | 33 | ### 2.1 用于控制台使用代理 34 | 35 | ```bash 36 | export http_proxy="http://127.0.0.1:7890" 37 | export https_proxy="http://127.0.0.1:7890" 38 | ``` 39 | 40 | ### 2.2 设置python运行路径 41 | 42 | ```bash 43 | export PYTHONPATH=$(pwd):$PYTHONPATH 44 | ``` 45 | 46 | ### 2.3 设置python不生成pyc(__pycache__) 47 | 48 | ```bash 49 | export PYTHONDONTWRITEBYTECODE=1 50 | ``` 51 | 52 | ### 2.4 国内配置github的真实ip 53 | 54 | - 通过网址```https://ipaddress.com/website/github.com```获取到github的真实ip 55 | - 通过修改```sudo vi /etc/hosts```文件,向其中添加```140.82.112.4 github.com``` 56 | - 安装nscd,如果已经安装了忽略。centos使用```sudo yum install -y nscd``` 57 | - 刷新本地dns缓存```service nscd restart``` 58 | - 备选刷新dns缓存命令```sudo /etc/init.d/nscd restart``` 59 | 60 | ## 获取需要使用到的基本数据 61 | 62 | ### 3.1 获取基础股票数据 63 | 64 |   从东方财富官网获取个股的历史数据,包含前复权,后复权,未复权。 65 | 66 | ```bash 67 | python GetBaseData/get_dc_data.py 68 | ``` 69 | 70 | ### 3.2 获取基础的个股资金流量数据 71 | 72 |   从东方财富官网获取不同股票的近100日的超大、大、中、小单数据变化。 73 | 74 | ```bash 75 | python GetBaseData/get_cash_flow_data.py 76 | ``` 77 | 78 | ### 3.3 获取不同板块的历史数据 79 | 80 |   从东方财富官网获取板块的历史数据。 81 | 目前仅支持行业板块,暂不支持概念板块的数据 82 | 83 | ```bash 84 | python GetBaseData/get_board_data.py 85 | ``` 86 | 87 | ## 如果要使用前端展示缠论结果 88 | 89 | ```bash 90 | streamlit run StrategyLib/ChanStrategy/automatic_drawing.py 91 | ``` 92 | 93 | ### 4.0 Contributing 94 | 95 | 1. Fork it () 96 | 2. Study how it's implemented. 97 | 3. Create your feature branch (`git checkout -b my-new-feature`). 98 | 4. Run [black](https://github.com/ambv/black) code formatter on the finta.py to ensure uniform code style. 99 | 5. Commit your changes (`git commit -am 'Add some feature'`). 100 | 6. Push to the branch (`git push origin my-new-feature`). 101 | 7. Create a new Pull Request. 
def auto_investment_plan():
    """Render the Streamlit page for one automatic-investment backtest.

    Reads the backtest parameters from sidebar text inputs, runs
    ``get_AI_plan_result`` and shows the record table, a one-line summary and
    three line charts (return rate, total invested, asset close price). Any
    exception is shown on the page instead of being raised.
    """
    try:
        st.subheader("定投回测结果")
        code = st.sidebar.text_input("code", value="sz.399006")
        gap_days = st.sidebar.text_input("interval", value=1)
        first_buy_day = st.sidebar.text_input("start", value="2019-01-02")
        want_rate = st.sidebar.text_input("target", value=1.1)
        if_intelli = st.sidebar.text_input("if_intelli", value="yes")
        threshold = st.sidebar.text_input("threshold", value=500000)
        st.text("定投记录表:")
        res, stock_data = get_AI_plan_result(
            code=code,
            gap_days=int(gap_days),
            first_buy_day=first_buy_day,
            want_rate=float(want_rate),
            if_intelli=if_intelli == "yes",
            threshold=int(threshold),
        )
        res.drop(["buy_index"], axis=1, inplace=True)
        st.dataframe(res, width=900)

        last_row = len(res) - 1
        natual_day = getDays(res.loc[0, "date"], res.loc[last_row, "date"])
        put_count = len(res[res["put"] != 0])
        total_put_in = int(res.loc[last_row, "put_in"])
        profit = int(res.loc[last_row, "account"] - res.loc[last_row, "put_in"])
        rate = round(res.loc[last_row, "rate"], 3)
        # Ratio of the last close to the first open. Rounded to 3 decimals like
        # the return rate; without the digits argument it became a bare integer.
        asset_change = round(
            stock_data.loc[len(stock_data) - 1, "close"] / stock_data.loc[0, "open"],
            3,
        )
        st.text(
            f"达成目标自然日天数:{natual_day},投入次数/总金额:{put_count}/{total_put_in},总收益:{profit},收益率:{rate},标的涨跌幅:{asset_change}"
        )

        # Plot the return rate as a percentage gain/loss.
        chart_data = res[["rate"]].apply(lambda x: (x - 1) * 100)
        st.text("定投收益率:")
        st.line_chart(chart_data, y="rate")
        st.text("达成目标收益率时投入总金额:")
        st.line_chart(res, y="put_in")
        st.text("股票行情:")
        st.line_chart(stock_data, y="close", x="date")

    except Exception as e:
        # Page-level boundary: show the problem to the user rather than crash.
        st.title("出错了")
        st.write(e)
| except Exception as e: 42 | print(e) 43 | bs.login() 44 | rs = bs.query_history_k_data_plus( 45 | code, 46 | "date,code,open,high,low,close,volume,amount", 47 | start_date=first_buy_day, 48 | end_date="2023-06-06", 49 | frequency="d", 50 | adjustflag="3", 51 | ) 52 | # 打印结果集 53 | data_list = [] 54 | while (rs.error_code == "0") & rs.next(): 55 | # 获取一条记录,将记录合并在一起 56 | data_list.append(rs.get_row_data()) 57 | data = pd.DataFrame(data_list, columns=rs.fields) 58 | bs.logout() 59 | 60 | data = data.fillna("") 61 | # data = data[data["tradestatus"] == 1] 62 | data = data[["date", "code", "open", "high", "low", "close", "volume"]] 63 | 64 | macd_df = TA.MACD(data) 65 | data["MACD"], data["SIGNAL"] = [macd_df["MACD"], macd_df["SIGNAL"]] 66 | data["HISTOGRAM"] = data["MACD"] - data["SIGNAL"] 67 | 68 | data = pd.concat([data, macd_df]) 69 | data = data[data["date"] >= first_buy_day] 70 | # print(data) 71 | data.reset_index(inplace=True, drop=True) 72 | # first_buy_day=data.loc[0,"date"] 73 | my_account = Account() 74 | rate_list = [] 75 | for index, row in data.iterrows(): 76 | times = 1 77 | my_account.put = 0 78 | buy_flag = True 79 | if if_intelli: 80 | # if row.HISTOGRAM > 0 or (index >2 and data.loc[index-1].HISTOGRAM >0 and data.loc[index-2].HISTOGRAM >0 ): 81 | if index != 0 and (row.HISTOGRAM > 0 and row.close > row.open): 82 | print(row.date + "macd 红柱,或今日上涨不投") 83 | buy_flag = False 84 | if my_account.put_in != 0: 85 | if my_account.rate > want_rate * 0.95: 86 | print( 87 | f"{row.date} + 目前收益率已到达高水位线{want_rate * 0.95},不再买入" 88 | ) 89 | buy_flag = False 90 | if my_account.rate < 0.9: 91 | times = min(int(my_account.put_in / 1000 / 6), 5) 92 | if my_account.rate < 0.8: 93 | times = min(int(my_account.put_in / 1000 / 3), 10) 94 | if my_account.rate < 0.7: 95 | times = min(int(my_account.put_in / 1000 / 2), 20) 96 | if my_account.put_in > threshold: 97 | print(f"{row.date} + 目前投入已达到{threshold},需要卖出一半股票以降低仓位") 98 | buy_flag = False 99 | sell_amount = 
if __name__ == "__main__":
    # Manual smoke run of the strategy. Guarded so that importing
    # get_AI_plan_result from this module (as the Streamlit pages do) no longer
    # runs a full backtest, with its file/network access, at import time.
    code = "sz.002044"
    gap_days = 1
    first_buy_day = "2021-02-22"
    want_rate = 1.1

    res, stock = get_AI_plan_result(
        code=code,
        gap_days=int(gap_days),
        first_buy_day=first_buy_day,
        want_rate=float(want_rate),
        threshold=500000,
    )

    print(res)
def getDays(day1, day2):
    """Return the number of calendar days from ``day1`` to ``day2``.

    Both arguments are date strings in ``%Y-%m-%d`` format. The result is
    negative when ``day2`` is earlier than ``day1``.
    """
    date_format = "%Y-%m-%d"
    begin = datetime.strptime(day1, date_format)
    finish = datetime.strptime(day2, date_format)
    return (finish - begin).days
print(records) 71 | stock_data.reset_index(drop=True, inplace=True) 72 | st.text( 73 | f"自然日天数:{natual_day},总收益:{total_earned} ,标的涨跌幅:{stock_data.loc[len(stock_data) - 1, 'close'] / stock_data.loc[0, 'open']}" 74 | ) 75 | st.dataframe(records, width=900) 76 | st.line_chart(records, y="total_earned") 77 | st.line_chart(stock_data, y="close", x="date") 78 | -------------------------------------------------------------------------------- /StrategyLib/ChanStrategy/BasicChan/__init__.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2022/2/2 22:25 4 | # @Author : Adolf 5 | # @File : __init__.py.py 6 | # @Function: 7 | -------------------------------------------------------------------------------- /StrategyLib/ChanStrategy/BasicChan/basic_enum.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2022/2/2 15:08 4 | # @Author : Adolf 5 | # @File : basic_enum.py 6 | # @Function: 7 | 8 | from enum import Enum 9 | 10 | # class Operate(Enum): 11 | # # 持有状态 12 | # HL = "持多" # Hold Long 13 | # HS = "持空" # Hold Short 14 | # HO = "持币" # Hold Other 15 | # 16 | # # 多头操作 17 | # LO = "开多" # Long Open 18 | # LE = "平多" # Long Exit 19 | # LA1 = "第一次加多仓" # Long Add 1 20 | # LA2 = "第二次加多仓" # Long Add 2 21 | # LR1 = "第一次减多仓" # Long Reduce 1 22 | # LR2 = "第二次减多仓" # Long Reduce 2 23 | # 24 | # # 空头操作 25 | # SO = "开空" # Short Open 26 | # SE = "平空" # Short Exit 27 | # SA1 = "第一次加空仓" # Short Add 1 28 | # SA2 = "第二次加空仓" # Short Add 2 29 | # SR1 = "第一次减空仓" # Short Reduce 1 30 | # SR2 = "第二次减空仓" # Short Reduce 2 31 | 32 | 33 | class Mark(Enum): 34 | D = "底分型" 35 | G = "顶分型" 36 | 37 | 38 | class Direction(Enum): 39 | Up = "向上" 40 | Down = "向下" 41 | 42 | 43 | class Freq(Enum): 44 | Tick = "Tick" 45 | F1 = "1分钟" 46 | F5 = "5分钟" 47 | F15 = "15分钟" 48 | F30 = "30分钟" 49 | F60 = "60分钟" 50 | D = "daily" 
def create_fake_bis(fxs: list[FX]) -> list[FakeBI]:
    """Build the list of fake strokes (``FakeBI``) joining consecutive fractals.

    :param fxs: fractal sequence; top and bottom fractals must alternate.
        When the sequence has an odd length its last element is ignored.
    :return: one ``FakeBI`` per pair of neighbouring fractals
    """
    if len(fxs) % 2:
        fxs = fxs[:-1]

    result = []
    for start_fx, end_fx in zip(fxs, fxs[1:]):
        assert start_fx.mark != end_fx.mark
        if start_fx.mark == Mark.D:
            # Bottom fractal first: the stroke rises from its low to the next high.
            direction, top, bottom = Direction.Up, end_fx.high, start_fx.low
        elif start_fx.mark == Mark.G:
            # Top fractal first: the stroke falls from its high to the next low.
            direction, top, bottom = Direction.Down, start_fx.high, end_fx.low
        else:
            raise ValueError
        result.append(
            FakeBI(
                symbol=start_fx.symbol,
                sdt=start_fx.dt,
                edt=end_fx.dt,
                direction=direction,
                high=top,
                low=bottom,
                power=round(top - bottom, 2),
            )
        )
    return result
@dataclass
class ZS:
    """A pivot zone (zhongshu) described by the strokes that form it."""

    symbol: str
    bis: list[BI]

    @property
    def sdt(self):
        """Start time of the zone (start of its first stroke)."""
        first_bi = self.bis[0]
        return first_bi.sdt

    @property
    def edt(self):
        """End time of the zone (end of its last stroke)."""
        last_bi = self.bis[-1]
        return last_bi.edt

    @property
    def sdir(self):
        """Direction of the first stroke of the zone."""
        first_bi = self.bis[0]
        return first_bi.direction

    @property
    def edir(self):
        """Direction of the last stroke of the zone."""
        last_bi = self.bis[-1]
        return last_bi.direction

    @property
    def zz(self):
        """Middle axis of the zone: the midpoint between ``zd`` and ``zg``."""
        lower, upper = self.zd, self.zg
        return lower + (upper - lower) / 2

    @property
    def gg(self):
        """Highest point of the zone over all of its strokes."""
        return max(bi.high for bi in self.bis)

    @property
    def zg(self):
        """Lowest of the highs of the first three strokes."""
        return min(bi.high for bi in self.bis[:3])

    @property
    def dd(self):
        """Lowest point of the zone over all of its strokes."""
        return min(bi.low for bi in self.bis)

    @property
    def zd(self):
        """Highest of the lows of the first three strokes."""
        return max(bi.low for bi in self.bis[:3])

    def __repr__(self):
        parts = (
            f"sdt={self.sdt}",
            f"sdir={self.sdir}",
            f"edt={self.edt}",
            f"edir={self.edir}",
            f"len_bis={len(self.bis)}",
            f"zg={self.zg}",
            f"zd={self.zd}",
            f"gg={self.gg}",
            f"dd={self.dd}",
            f"zz={self.zz}",
        )
        return "ZS(" + ", ".join(parts) + ")"
def test_heat_map():
    """Render a random 24 x 7 heat map to HTML and check the file is written."""
    hour_labels = [f"{hour}hour" for hour in range(24)]
    day_labels = [f"{day}day" for day in range(7)]

    # One cell per (hour, day) pair, hours in the outer loop.
    cells = []
    for hour_label in hour_labels:
        for day_label in day_labels:
            cells.append(
                {"x": hour_label, "y": day_label, "heat": random.randint(0, 50)}
            )

    chart = chan_plot.heat_map(cells, x_label=hour_labels, y_label=day_labels)
    output_path = "ShowHtml/render.html"
    chart.render(output_path)
    assert os.path.exists(output_path)
    os.remove(output_path)
def _bars_from_frame(df, symbol, period, columns):
    """Build one ``RawBar`` per row of ``df``.

    :param df: frame holding the K-line rows
    :param symbol: value stored in every bar's ``symbol`` field
    :param period: value stored in every bar's ``freq`` field
    :param columns: maps the RawBar field names ``dt``, ``open``, ``close``,
        ``high``, ``low`` and ``amount`` (and optionally ``vol``) to column
        names of ``df``; without a ``vol`` entry the volume is set to 0
    :return: list of ``RawBar`` whose ``id`` is the row's index label
    """
    vol_column = columns.get("vol")
    return [
        RawBar(
            symbol=symbol,
            id=i,
            freq=period,
            open=row[columns["open"]],
            dt=row[columns["dt"]],
            close=row[columns["close"]],
            high=row[columns["high"]],
            low=row[columns["low"]],
            vol=0 if vol_column is None else row[vol_column],
            amount=row[columns["amount"]],
        )
        for i, row in df.iterrows()
    ]


def _clip_to_dates(df, date_column, start_date, end_date):
    """Keep the rows of ``df`` dated strictly between the two bounds.

    A bound that is ``None`` or empty is not applied. If the column is missing
    or a date cannot be parsed, the problem is printed and ``df`` is returned
    unfiltered (best effort, as before).
    """
    import pandas as pd  # local import: this file has no module-level pandas import

    try:
        stamps = pd.to_datetime(df[date_column])
        keep = pd.Series(True, index=df.index)
        if start_date:
            keep &= stamps > pd.to_datetime(start_date)
        if end_date:
            keep &= stamps < pd.to_datetime(end_date)
    except (KeyError, ValueError, TypeError) as e:
        print(e)
        print(df.columns)
        return df
    # Re-number so that the bar ids built from the index stay contiguous.
    return df[keep].reset_index(drop=True)


def chan_show_base():
    """Read the sidebar inputs, download the K-line data and return ``RawBar``s.

    The symbol, period and adjust mode come from the Streamlit sidebar; a
    start/end date is only asked for (and only applied) for the daily, weekly
    and monthly periods.

    Previously the start/end filter was applied to the frame *after* the bars
    had been built, so it never affected the returned bars. The index data is
    now filtered before the bars are built; the A-share history call already
    receives the date range as arguments.

    :return: list of ``RawBar``
    :raises ValueError: if the selected period is not one of the known choices
    """
    symbol = st.sidebar.text_input(
        "stock code,上证指数:sh000001,深证成指:sz399001,创业板指:sz399006,沪深300:sz399300,中证500:sh000905",
        "000001",
    )
    period = st.sidebar.selectbox(
        "time period",
        ("daily", "weekly", "monthly", "1min", "5min", "15min", "30min", "60min"),
    )
    adjust = st.sidebar.selectbox("adjust", ("qfq", "hfq", "origin"))

    start_date = None
    end_date = None

    if period in ["daily", "weekly", "monthly"]:
        start_date = st.sidebar.text_input("start time", "20170301")
        end_date = st.sidebar.text_input("end time", "20210907")

    st.title("股票数据展示")
    st.write("下面是表格")

    if symbol in ["sh000001", "sz399001", "sz399006", "sz399300", "sh000905"]:
        df = ak.stock_zh_index_daily_tx(symbol=symbol)
        print(df)
        # The index call takes no date range, so clip the frame here.
        df = _clip_to_dates(df, "date", start_date, end_date)
        # No "vol" entry: the volume of index bars is set to 0.
        columns = {
            "dt": "date",
            "open": "open",
            "close": "close",
            "high": "high",
            "low": "low",
            "amount": "amount",
        }

    elif period in ["daily", "weekly", "monthly"]:
        df = ak.stock_zh_a_hist(
            symbol=symbol,
            period=period,  # choice of {'daily', 'weekly', 'monthly'}
            start_date=start_date,
            end_date=end_date,
            adjust=adjust,
        )
        columns = {
            "dt": "日期",
            "open": "开盘",
            "close": "收盘",
            "high": "最高",
            "low": "最低",
            "vol": "成交量",
            "amount": "成交额",
        }

    elif period in ["1min", "5min", "15min", "30min", "60min"]:
        df = ak.stock_zh_a_hist_min_em(
            symbol=symbol, period=period.replace("min", ""), adjust=adjust
        )
        columns = {
            "dt": "时间",
            "open": "开盘",
            "close": "收盘",
            "high": "最高",
            "low": "最低",
            "vol": "成交量",
            "amount": "成交额",
        }

    else:
        raise ValueError(f"unsupported period: {period}")

    return _bars_from_frame(df, symbol, period, columns)
def choose_rule(self, data):
    """Pick, for every row, the board with the largest momentum value.

    For each name in ``self.all_data_list`` the column ``<stem>_mom`` is read
    (``stem`` being the part of the name before the first dot). NaN values are
    ignored, and the stem of the column with the highest value is written to
    the ``choose_assert`` column of that row. If a row has no usable value the
    error is logged as a warning and the row is left untouched.

    :param data: frame holding one ``<stem>_mom`` column per board
    :return: the same frame, modified in place
    """
    mom_columns = [f"{name.split('.')[0]}_mom" for name in self.all_data_list]

    for index, row in data.iterrows():
        usable = [column for column in mom_columns if not pd.isna(row[column])]
        ranked = sorted(
            row[usable].to_dict().items(), key=lambda item: item[1], reverse=True
        )

        try:
            data.loc[index, "choose_assert"] = ranked[0][0].replace("_mom", "")
        except Exception as e:
            self.logger.warning(e)

    self.logger.debug(data)

    return data
class Demark9Strategy(TradeStructure):
    """Trade on the DeMark 9 counts produced by ``pandas_ta.td_seq``.

    Buys when the ``TD_SEQ_DN`` count of the current step equals 9 and sells
    when the ``TD_SEQ_UP`` count equals 9.
    """

    def cal_technical_indicators(self, indicators_config):
        """Append the ``td_seq`` columns computed from the close price to ``self.data``.

        :param indicators_config: strategy parameters; only logged here
        """
        self.logger.debug(indicators_config)

        return_demark = ta.td_seq(self.data.close, asint=True, show_all=False)
        self.data = pd.concat([self.data, return_demark], axis=1)

        self.logger.debug(self.data.head(30))
        # Log the rows on which either count reaches 9 (the signal rows).
        show_data = self.data[
            (self.data["TD_SEQ_UP"] == 9) | (self.data["TD_SEQ_DN"] == 9)
        ]
        self.logger.info(show_data)
        # NOTE: a stray ``exit()`` used to follow here. It terminated the
        # process right after the indicators were computed, so buy_logic and
        # sell_logic below could never run.

    def buy_logic(self):
        """Return True when the down count of the current step equals 9."""
        return bool(self.trade_state.trading_step.TD_SEQ_DN == 9)

    def sell_logic(self):
        """Return True when the up count of the current step equals 9."""
        return bool(self.trade_state.trading_step.TD_SEQ_UP == 9)
class MaEmaCrossover(TradeStructure):
    """Trade with an EMA & MA crossover filtered by RSI.

    Entry rules:
    1. MA > EMA
    2. close > open
    3. 70 > RSI > 50
    4. close > MA > EMA

    Exit rules:
    5. stop loss at the MA value of the entry bar
    6. profit-loss ratio 1:1.5 (take profit is 1.5 times the distance between
       the entry close and the stop loss, above the entry close)
    """

    def cal_technical_indicators(self, indicators_config):
        """Compute the ``ma``, ``ema`` and ``rsi`` columns from the close price.

        :param indicators_config: must provide ``sma_length`` and ``ema_length``
        """
        self.logger.debug(indicators_config)
        # ``.copy()`` makes the column subset an independent frame, so the
        # assignments below do not write into a slice of the original data.
        self.data = self.data[
            ["date", "open", "high", "low", "close", "volume", "code"]
        ].copy()

        self.data["ma"] = ta.sma(
            self.data["close"], length=indicators_config["sma_length"]
        )
        self.data["ema"] = ta.ema(
            self.data["close"], length=indicators_config["ema_length"]
        )

        self.data["rsi"] = ta.rsi(self.data["close"])

        self.logger.debug(self.data.tail(30))

    def buy_logic(self):
        """Return True when every entry rule holds, recording the exit levels.

        On a buy signal the stop loss and take profit are stored on the
        current transaction record. Every other path returns an explicit
        ``False`` (some rejection paths used to fall through and return
        ``None``).
        """
        step = self.trade_state.trading_step
        record = self.trade_state.one_transaction_record
        self.logger.debug(step)
        self.logger.debug(record)

        if not (step.ma > step.ema and step.close > step.open):
            return False
        if not (step.rsi > 50 and step.rsi < 70):
            return False
        if not (step.close > step.ma > step.ema):
            return False

        record.stop_loss = step.ma
        record.take_profit = (step.close - record.stop_loss) * 1.5 + step.close
        return True

    def sell_logic(self):
        """Return True when the close leaves the [stop_loss, take_profit] band.

        Returns False while either level is still unset.
        """
        step = self.trade_state.trading_step
        record = self.trade_state.one_transaction_record
        self.logger.debug(step)
        self.logger.debug(record)

        if record.take_profit is None or record.stop_loss is None:
            return False

        return bool(
            step.close > record.take_profit or step.close < record.stop_loss
        )
def buy_logic(self):
    """Return True when sma5 has just crossed above sma10.

    The current step must have ``sma5 > sma10`` while the first entry of the
    step history still had ``sma5 < sma10``. The history is only consulted
    when the current step already satisfies the first condition.
    """
    state = self.trade_state
    self.logger.debug(pformat(state, indent=4, width=20))

    current = state.trading_step
    if not current.sma5 > current.sma10:
        return False

    previous = state.history_trading_step[0]
    return bool(previous.sma5 < previous.sma10)
# from StrategyLib.OneAssetStrategy import macd_deviate_config 15 | from Utils.TechnicalIndicators.basic_indicators import MACD 16 | 17 | 18 | class MACDDeviate(TradeStructure): 19 | """ 20 | 使用MACD底背离策略,如果价格低点创下30日新低,并比上一个低点价格更低,但是对应的MACD值更高,则在其后的第一个MACD金叉买入,第一个MACD死叉卖出。 21 | """ 22 | 23 | def __init__(self, config): 24 | super(MACDDeviate, self).__init__(config) 25 | 26 | def cal_technical_indicators(self, indicators_config): 27 | self.logger.debug(f"macd config:{indicators_config}") 28 | 29 | self.data["30_lowest"] = self.data.low.rolling(30).min() 30 | 31 | self.data["DIFF"], self.data["DEA"], self.data["MACD"] = MACD(self.data.close) 32 | # self.data["MACD"], self.data["SIGNAL"] = [macd_df["MACD"], macd_df["SIGNAL"]] 33 | # self.data["HISTOGRAM"] = self.data["MACD"] - self.data["SIGNAL"] 34 | 35 | # 获取到MACD金叉点和死叉点 36 | self.data.loc[ 37 | (self.data["MACD"] > 0) & (self.data["MACD"].shift(1) < 0), "trade" 38 | ] = "LONG" 39 | self.data.loc[ 40 | (self.data["MACD"] < 0) & (self.data["MACD"].shift(1) > 0), "trade" 41 | ] = "SHORT" 42 | 43 | self.data["price_state"] = 0 44 | # 寻找价格的极值点 45 | price_res = argrelextrema(self.data.low.values, np.less, order=1)[0].tolist() 46 | 47 | last_low_price = None 48 | last_macd = None 49 | for index in price_res: 50 | if last_low_price is not None and last_macd is not None: 51 | if ( 52 | self.data.loc[index, "low"] < last_low_price - 0.1 53 | and self.data.loc[index, "low"] == self.data.loc[index, "30_lowest"] 54 | ): 55 | if self.data.loc[index, "MACD"] > last_macd: 56 | self.data.loc[index, "price_state"] = 1 57 | 58 | self.logger.debug(self.data.loc[index, "date"]) 59 | 60 | last_low_price = self.data.loc[index, "low"] 61 | last_macd = self.data.loc[index, "MACD"] 62 | 63 | # self.logger.debug(self.data) 64 | 65 | def trading_algorithm(self): 66 | price_flag = 0 67 | lowest_30 = 0 68 | for index, row in self.data.iterrows(): 69 | if row["low"] < lowest_30: 70 | price_flag = 0 71 | lowest_30 = row["low"] 72 | 73 | if 
row["price_state"] == 1: 74 | price_flag = 1 75 | lowest_30 = row["low"] 76 | 77 | if row["trade"] == "LONG" and price_flag == 1: 78 | self.data.loc[index, "trade"] = "BUY" 79 | price_flag = 0 80 | elif row["trade"] == "SHORT": 81 | self.data.loc[index, "trade"] = "SELL" 82 | 83 | 84 | if __name__ == "__main__": 85 | config = { 86 | "RANDOM_SEED": 42, 87 | "LOG_LEVEL": "SUCCESS", 88 | "CODE_NAME": "600570", 89 | # "CODE_NAME": "ALL_MARKET_10", 90 | # "CODE_NAME": ["sh.600238",], 91 | # "CODE_NAME": ["sh.603806", "sh.603697", "sh.603700", "sh.600570", "sh.603809","sh.600238","sh.603069","sh.600764","sz.002044"], 92 | "START_STAMP": "2015-05-01", 93 | "END_STAMP": "2022-12-20", 94 | # "SHOW_DATA_PATH": "", 95 | # "STRATEGY_PARAMS": {} 96 | } 97 | MACD_strategy = MACDDeviate(config=config) 98 | MACD_strategy.run() 99 | -------------------------------------------------------------------------------- /StrategyLib/OneAssetStrategy/README.md: -------------------------------------------------------------------------------- 1 | 8 | # 构建交易策略 9 | 10 | ### 1.Ma5Ma10 11 | - 策略名称:双均线策略 12 | - 策略逻辑:5日均线和10日均线策略,当5日均线上穿10日均线时买入,当5日均线下穿10日均线时卖出 13 | 14 | ### 2.EMA_Ma_Crossover 15 | - 策略名称:均线动量策略 16 | - 策略逻辑:使用EMA & MA Crossover 和 RSI 进行交易,具体策略如下:
17 | 1、MA > EMA;
18 | 2、close > open;
19 | 3、70 > RSI > 50;
20 | 4、close > Ma > Ema;
21 | 5、stop loss Ma;
22 | 6、profit-loss ratio 1:1.5 23 | 24 | ### 3.MACD_MA 25 | - 策略名称:MACD结合均线策略 26 | - 策略逻辑:
27 | 选股:1、60日线上 2、金叉 3、macd在0轴上方
28 | 买点:第一根绿线
class MACD309Strategy(TradeStructure):
    """Trade 30-minute bars using the daily MACD histogram as the signal.

    Buys while the daily histogram is >= -0.1 and sells while it is <= 0.1.
    """

    def load_dataset(self, data_path, start_stamp=None, end_stamp=None):
        """Load the 30-minute and daily Baostock CSVs that match ``data_path``.

        The ``Data/RealData/hfq/`` part of ``data_path`` is replaced by the
        30-minute and the daily Baostock directory respectively. The two
        frames are stored in ``self.data`` under the keys ``"30min"`` and
        ``"day"``. ``start_stamp`` and ``end_stamp`` are accepted but not used.
        """
        min30_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/30min/"
        )
        day_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/day/"
        )
        self.data = {}
        self.data["30min"] = pd.read_csv(min30_data_path)
        self.data["day"] = pd.read_csv(day_data_path)

    def cal_technical_indicators(self, indicators_config):
        """Compute daily and 30-minute MACD columns and merge them.

        After this call ``self.data`` is no longer a dict but the 30-minute
        frame, joined on ``date`` with the daily ``_5_10`` (sma5 - sma10) and
        ``HISTOGRAM_day`` columns and extended with its own ``MACD``,
        ``SIGNAL`` and ``HISTOGRAM`` columns.

        :param indicators_config: strategy parameters; only logged here
        """
        # Both module-level imports of these packages are commented out in
        # this file, so ``TA`` and ``ta`` were undefined and this method
        # raised NameError. Import them locally to restore the names.
        import pandas_ta as ta
        from finta import TA

        self.logger.debug(indicators_config)

        day = self.data["day"]
        macd_day = TA.MACD(day)
        day["MACD"], day["SIGNAL"] = [macd_day["MACD"], macd_day["SIGNAL"]]
        day["HISTOGRAM_day"] = day["MACD"] - day["SIGNAL"]
        day["sma5"] = ta.sma(day["close"], length=5)
        day["sma10"] = ta.sma(day["close"], length=10)
        day["_5_10"] = round(day["sma5"] - day["sma10"], 3)
        self.logger.debug(day.tail())

        # Keep only the daily columns that are merged into the 30-minute bars.
        self.data["day"] = day[["date", "_5_10", "HISTOGRAM_day"]]
        self.logger.info(self.data["day"].tail(n=20))

        self.data["30min"] = pd.merge(self.data["30min"], self.data["day"], on="date")
        self.logger.debug(self.data["30min"].tail(n=20))

        macd_df = TA.MACD(self.data["30min"])
        # From here on ``self.data`` is the merged 30-minute frame.
        self.data = self.data["30min"]
        self.data["MACD"], self.data["SIGNAL"] = [macd_df["MACD"], macd_df["SIGNAL"]]
        self.data["HISTOGRAM"] = self.data["MACD"] - self.data["SIGNAL"]
        self.logger.info(self.data.tail(n=30))

    def buy_logic(self):
        """Return True when the daily MACD histogram is at least -0.1."""
        return bool(self.trade_state.trading_step.HISTOGRAM_day >= -0.1)

    def sell_logic(self):
        """Return True when the daily MACD histogram is at most 0.1."""
        return bool(self.trade_state.trading_step.HISTOGRAM_day <= 0.1)
class MACD30CurMacdStrategy(TradeStructure):
    """30-minute MACD strategy confirmed by an intraday-refreshed daily MACD.

    On every step the daily close of the current date is overwritten with the
    30-minute close and the daily MACD is recomputed, so the daily signal
    reflects the price seen so far that day.
    """

    @staticmethod
    def _clip_dates(frame, start_stamp, end_stamp):
        """Return a copy of ``frame`` limited to the given ``date`` bounds."""
        if start_stamp is not None:
            frame = frame[frame["date"] >= start_stamp]
        if end_stamp is not None:
            frame = frame[frame["date"] <= end_stamp]
        # Copy so later column assignments do not write into a slice.
        return frame.copy()

    def load_dataset(self, data_path, start_stamp=None, end_stamp=None):
        """Load the 30-minute bars into ``self.data`` and daily bars into
        ``self.day_data``.

        Args:
            data_path: Path containing ``Data/RealData/hfq/``; that prefix is
                swapped for the Baostock 30-minute and daily directories.
            start_stamp: Inclusive lower ``date`` bound; ``None`` disables it.
            end_stamp: Inclusive upper ``date`` bound; ``None`` disables it.
        """
        min30_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/30min/"
        )
        day_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/day/"
        )
        self.data = self._clip_dates(
            pd.read_csv(min30_data_path), start_stamp, end_stamp
        )
        self.data["buy"] = 0
        self.data["sell"] = 0
        self.hist_ratio = 0.005
        self.day_data = self._clip_dates(
            pd.read_csv(day_data_path), start_stamp, end_stamp
        )
        # Positional counter used to address "today" and the preceding days.
        self.day_data["index"] = list(range(len(self.day_data)))

    def cal_technical_indicators(self, indicators_config):
        """Compute the 30-minute MACD columns and reset the signal columns.

        Args:
            indicators_config: Only logged; no value is read from it.
        """
        self.logger.debug(indicators_config)

        macd_df = TA.MACD(self.data)
        self.data["MACD"] = macd_df["MACD"]
        self.data["SIGNAL"] = macd_df["SIGNAL"]
        self.data["HISTOGRAM"] = self.data["MACD"] - self.data["SIGNAL"]
        self.data["index"] = list(range(len(self.data)))
        self.data["buy"] = 0
        self.data["sell"] = 0

    def _refresh_day_macd(self):
        """Recompute the daily MACD with the current 30-minute close.

        Returns:
            The value of the ``index`` column for the current trading date.
        """
        step = self.trade_state.trading_step
        today = self.day_data["date"] == step.date

        # Use the 30-minute close as today's daily close before recomputing.
        self.day_data.loc[today, "close"] = step.close
        macd_day = TA.MACD(self.day_data)
        self.day_data["MACD_day"] = macd_day["MACD"]
        self.day_data["SIGNAL_day"] = macd_day["SIGNAL"]
        self.day_data["HISTOGRAM_day"] = (
            self.day_data["MACD_day"] - self.day_data["SIGNAL_day"]
        )
        self.day_data["HISTOGRAM_ratio"] = (
            self.day_data["MACD_day"] / self.day_data["SIGNAL_day"]
        )
        return self.day_data.loc[today, "index"].item()

    def _day_value(self, cur_index, column):
        """Return ``column`` of the daily row whose ``index`` is ``cur_index``."""
        return self.day_data.loc[self.day_data["index"] == cur_index, column].item()

    def buy_logic(self):
        """Return True when intraday and daily MACD both support an entry.

        All three conditions must hold: the 30-minute histogram is positive,
        today's daily histogram exceeds the mean of up to four preceding
        daily histograms, and today's daily MACD is above -30.
        """
        cur_index = self._refresh_day_macd()

        # The window excludes today; with no history the mean is NaN and the
        # comparison below is False.
        window_start = max(cur_index - 4, 0)
        previous_avg = self.day_data.iloc[window_start:cur_index][
            "HISTOGRAM_day"
        ].mean()
        histogram_above_avg = (
            self._day_value(cur_index, "HISTOGRAM_day") > previous_avg
        )

        return bool(
            self.trade_state.trading_step.HISTOGRAM > 0
            and histogram_above_avg
            and self._day_value(cur_index, "MACD_day") > -30
        )

    def sell_logic(self):
        """Return True when the 30-minute histogram is negative and today's
        daily MACD is below 20.
        """
        cur_index = self._refresh_day_macd()
        return bool(
            self.trade_state.trading_step.HISTOGRAM < 0
            and self._day_value(cur_index, "MACD_day") < 20
        )
class MACD30DayMacdStrategy(TradeStructure):
    """30-minute MACD strategy confirmed by the end-of-day daily MACD.

    Daily indicators are merged onto the 30-minute bars by ``date``; a trade
    needs both the daily and the 30-minute histogram to agree.
    """

    def load_dataset(self, data_path, start_stamp=None, end_stamp=None):
        """Load and date-filter the 30-minute and daily CSV files.

        Args:
            data_path: Path containing ``Data/RealData/hfq/``; that prefix is
                swapped for the Baostock 30-minute and daily directories.
            start_stamp: Inclusive lower ``date`` bound; ``None`` disables it.
            end_stamp: Inclusive upper ``date`` bound; ``None`` disables it.
        """
        min30_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/30min/"
        )
        day_data_path = data_path.replace(
            "Data/RealData/hfq/", "Data/RealData/Baostock/day/"
        )
        frames = {
            "30min": pd.read_csv(min30_data_path),
            "day": pd.read_csv(day_data_path),
        }
        for key, frame in frames.items():
            # Comparing the column against None raises, so skip unset bounds.
            if start_stamp is not None:
                frame = frame[frame["date"] >= start_stamp]
            if end_stamp is not None:
                frame = frame[frame["date"] <= end_stamp]
            # Copy so the column assignments below do not write into a slice.
            frames[key] = frame.copy()

        frames["30min"]["buy"] = 0
        frames["30min"]["sell"] = 0
        frames["30min"]["index"] = list(range(len(frames["30min"])))
        self.data = frames

    def cal_technical_indicators(self, indicators_config):
        """Attach daily and 30-minute MACD columns and flatten ``self.data``.

        After this call ``self.data`` is the 30-minute frame (no longer a
        dict) carrying ``_5_10``, ``HISTOGRAM_day``, ``sma10``, ``MACD``,
        ``SIGNAL`` and ``HISTOGRAM``.

        Args:
            indicators_config: Only logged; no value is read from it.
        """
        # The module-level imports of these packages are commented out, and the
        # daily MACD was called through the undefined bare name ``MACD``.
        import pandas_ta as ta
        from finta import TA

        self.logger.debug(indicators_config)

        day = self.data["day"]
        macd_day = TA.MACD(day)
        day["MACD"] = macd_day["MACD"]
        day["SIGNAL"] = macd_day["SIGNAL"]
        day["HISTOGRAM_day"] = day["MACD"] - day["SIGNAL"]
        day["sma5"] = ta.sma(day["close"], length=5)
        day["sma10"] = ta.sma(day["close"], length=10)
        day["_5_10"] = round(day["sma5"] - day["sma10"], 3)
        self.logger.debug(day.tail())

        day = day[["date", "_5_10", "HISTOGRAM_day", "sma10"]]
        self.logger.debug(day.tail(n=20))

        bars = pd.merge(self.data["30min"], day, on="date")
        self.logger.debug(bars.tail(n=20))

        macd_df = TA.MACD(bars)
        bars["MACD"] = macd_df["MACD"]
        bars["SIGNAL"] = macd_df["SIGNAL"]
        bars["HISTOGRAM"] = bars["MACD"] - bars["SIGNAL"]
        self.data = bars

    def buy_logic(self):
        """Return True when the daily histogram is at least -0.1 and the
        30-minute histogram is non-negative.
        """
        step = self.trade_state.trading_step
        return bool(step.HISTOGRAM_day >= -0.1 and step.HISTOGRAM >= 0)

    def sell_logic(self):
        """Return True when the daily histogram is at most 0.1 and the
        30-minute histogram is non-positive.
        """
        step = self.trade_state.trading_step
        return bool(step.HISTOGRAM_day <= 0.1 and step.HISTOGRAM <= 0)
class MACDdayStrategy(TradeStructure):
    """Daily strategy that trades on the sign of the daily MACD histogram."""

    def load_dataset(self, data_path, start_stamp=None, end_stamp=None):
        """Load the daily CSV file and keep the requested date range.

        Args:
            data_path: Path containing ``Data/RealData/hfq/``; that prefix is
                swapped for ``Data/Baostock/day/``.
            start_stamp: Inclusive lower ``date`` bound; ``None`` disables it.
            end_stamp: Inclusive upper ``date`` bound; ``None`` disables it.
        """
        self.logger.info(data_path)
        # NOTE(review): this target directory has no "RealData/" segment;
        # confirm it matches where the daily files are actually stored.
        day_data_path = data_path.replace("Data/RealData/hfq/", "Data/Baostock/day/")
        self.data = {}

        day = pd.read_csv(day_data_path)
        # Comparing the column against None raises, so skip unset bounds.
        if start_stamp is not None:
            day = day[day["date"] >= start_stamp]
        if end_stamp is not None:
            day = day[day["date"] <= end_stamp]
        # Copy so indicator columns are not written into a filtered slice.
        self.data["day"] = day.copy()

    def cal_technical_indicators(self, indicators_config):
        """Compute MACD and SMA columns and make the daily frame ``self.data``.

        Args:
            indicators_config: Only logged; no value is read from it.
        """
        self.logger.debug(indicators_config)

        day = self.data["day"]
        macd_day = TA.MACD(day)
        day["MACD"] = macd_day["MACD"]
        day["SIGNAL"] = macd_day["SIGNAL"]
        day["HISTOGRAM_day"] = day["MACD"] - day["SIGNAL"]
        day["sma5"] = ta.sma(day["close"], length=5)
        day["sma10"] = ta.sma(day["close"], length=10)
        day["_5_10"] = round(day["sma5"] - day["sma10"], 3)

        self.data = day
        self.data["buy"] = 0
        self.data["sell"] = 0

    def buy_logic(self):
        """Return True when the daily histogram is non-negative."""
        return bool(self.trade_state.trading_step.HISTOGRAM_day >= 0)

    def sell_logic(self):
        """Return True when the daily histogram is non-positive."""
        return bool(self.trade_state.trading_step.HISTOGRAM_day <= 0)
# Rank every industry board by its latest 20-period momentum and by the
# 10-period moving average of that momentum.
os.environ["LOGURU_LEVEL"] = "INFO"

pd.set_option("expand_frame_repr", False)
pd.set_option("display.max_rows", 100)

board_data_path = "Data/BoardData/"

with open(board_data_path + "ALL_INDUSTRY_BOARD.json") as f:
    board_dict = json.load(f)

board_list = board_dict.keys()
mom_list = []
mom_sma_list = []
for one_board in board_list:
    df = pd.read_csv(board_data_path + f"industry_origin/{one_board}.csv")
    df = df[["date", "open", "close", "high", "low", "volume"]]

    df["mom"] = TA.MOM(df, period=20)
    df["mom_sma"] = TA.SMA(df, period=10, column="mom")

    # Only the most recent value of each indicator is kept per board.
    latest_mom = df.mom.tail(1).item()
    latest_mom_sma = df.mom_sma.tail(1).item()
    mom_list.append(round(latest_mom, 3))
    mom_sma_list.append(round(latest_mom_sma, 3))

board_mom_dict = {"board": board_list, "mom": mom_list, "mom_sma": mom_sma_list}
board_mom_df = pd.DataFrame.from_dict(board_mom_dict)
board_mom_df["code"] = board_mom_df["board"].apply(board_dict.__getitem__)
board_mom_df.sort_values(by=["mom_sma", "mom"], ascending=False, inplace=True)
logger.debug(board_mom_df)
def normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data: Numeric array-like (NumPy array or pandas Series).

    Returns:
        ``(data - min) / (max - min)`` with the same shape as ``data``. When
        every value is identical the range is zero, so zeros are returned
        instead of the NaN values a division by zero would produce.
    """
    low = np.min(data)
    span = np.max(data) - low
    if span == 0:
        # Flat input: keep the container type and shape, force a float zero.
        return (data - low) * 0.0
    return (data - low) / span
def use_ta_cal_one_board_mom(board_data_path, period=20):
    """Return the pandas_ta linear-regression slope of a board's recent closes.

    Args:
        board_data_path: CSV file name inside ``Data/BoardData/industry_origin/``.
        period: Number of trailing rows kept and the regression length.

    Returns:
        The last value of the ``linreg`` column computed over the kept rows.
    """
    ohlcv_columns = ["date", "open", "high", "low", "close", "volume"]
    frame = pd.read_csv("Data/BoardData/industry_origin/" + board_data_path)

    # Keep only the trailing ``period`` rows and renumber them from zero.
    recent = frame[ohlcv_columns][-period:]
    recent.reset_index(drop=True, inplace=True)

    recent["linreg"] = ta.linreg(recent.close, length=period, slope=True)
    return recent["linreg"].values[-1]
def normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data: Numeric array-like (NumPy array or pandas Series).

    Returns:
        ``(data - min) / (max - min)`` with the same shape as ``data``. When
        every value is identical the range is zero, so zeros are returned
        instead of NaN; a flat price window then regresses to a slope of 0
        rather than feeding NaN into the model fit.
    """
    low = np.min(data)
    span = np.max(data) - low
    if span == 0:
        # Flat input: keep the container type and shape, force a float zero.
        return (data - low) * 0.0
    return (data - low) / span
@ray.remote
def cal_linear_regression(board_name):
    """Compute the rolling regression-slope momentum of one industry board.

    Args:
        board_name: Board whose CSV is read from ``industry_origin/``.

    Returns:
        DataFrame with ``date``, ``<board_name>_close`` and
        ``<board_name>_mom``, where the momentum is the slope returned by
        ``cal_one_date_mom`` over each 20-row window of closes.
    """
    time_period = 20

    raw = pd.read_csv(board_data_path + f"industry_origin/{board_name}.csv")
    prices = raw[["date", "open", "close", "high", "low", "volume"]]
    prices["mid"] = (
        prices["open"] + prices["close"] + prices["high"] + prices["low"]
    ) / 4
    logger.info(prices)

    def window_slope(closes):
        return cal_one_date_mom(closes, time_period)

    prices["line_w"] = prices["close"].rolling(window=time_period).apply(window_slope)

    renamed = {"close": f"{board_name}_close", "line_w": f"{board_name}_mom"}
    result = prices[["date", "close", "line_w"]].rename(columns=renamed)
    logger.info(result)

    return result
def choose_what_need(all_df):
    """Backtest "hold the top-momentum board for the next row" and log it.

    For every row the board with the highest non-NaN ``<board>_mom`` is
    recorded in ``top_mom`` and the following row's return of that board in
    ``top_mom_pct``. The cumulative product of ``1 + top_mom_pct`` over the
    last 1000 rows is logged as ``strategy_net``.

    Args:
        all_df: Merged board history with ``<board>_close`` and
            ``<board>_mom`` columns for every board. It is modified in place:
            the ``<board>_pct``, ``top_mom`` and ``top_mom_pct`` columns are
            added to it.

    Returns:
        None. The result is only written to the log.
    """
    for board_name in board_list:
        close = all_df[f"{board_name}_close"]
        all_df[f"{board_name}_pct"] = close / close.shift(1) - 1

    mom_columns = [f"{board_name}_mom" for board_name in board_list]

    for index, row in all_df.iterrows():
        available = [column for column in mom_columns if not pd.isna(row[column])]
        ranked = sorted(
            row[available].to_dict().items(), key=lambda item: item[1], reverse=True
        )
        try:
            top_column = ranked[0][0]
            all_df.loc[index, "top_mom"] = top_column
            all_df.loc[index, "top_mom_pct"] = all_df.loc[
                index + 1, top_column.replace("_mom", "_pct")
            ]
        except (IndexError, KeyError) as e:
            # IndexError: no board has a momentum value on this row yet.
            # KeyError: the last row has no following row to read from.
            logger.warning(e)

    # Copy the tail so the new column is not assigned on a slice of all_df.
    recent = all_df[-1000:].copy()
    recent["strategy_net"] = (1 + recent["top_mom_pct"]).cumprod()
    logger.info(recent)
-------------------------------------------------------------------------------- 1 | """ 2 | Description: 3 | Author: adolf 4 | Date: 2022-07-26 23:45:09 5 | LastEditTime: 2022-08-01 20:36:30 6 | LastEditors: adolf 7 | """ 8 | 9 | from datetime import date 10 | 11 | import akshare as ak 12 | import pandas as pd 13 | 14 | pd.set_option("display.max_columns", None) 15 | pd.set_option("display.max_rows", 200) 16 | 17 | today = date.today() 18 | d1 = today.strftime("%Y%m%d") 19 | print("今天的日期:", d1) 20 | # exit() 21 | 22 | # 问财热度排行 23 | stock_hot_rank_wc_df = ak.stock_hot_rank_wc(date=d1) 24 | print(stock_hot_rank_wc_df[:200]) 25 | exit() 26 | 27 | # 东财热度排行 28 | stock_hot_rank_em_df = ak.stock_hot_rank_em() 29 | print(stock_hot_rank_em_df) 30 | 31 | # 淘股吧热度排行 32 | stock_hot_tgb_df = ak.stock_hot_tgb() 33 | print(stock_hot_tgb_df) 34 | 35 | # 雪球讨论热度榜 36 | new_hot = ak.stock_hot_tweet_xq(symbol="本周新增") 37 | new_hot["new_hot_rank"] = new_hot.index 38 | print(new_hot[:100]) 39 | 40 | # old_hot = ak.stock_hot_tweet_xq(symbol="最热门") 41 | # old_hot['old_hot_rank'] = old_hot.index 42 | # print(old_hot[:100]) 43 | 44 | # hot_df = pd.merge(new_hot, old_hot) 45 | # hot_df['diff'] = hot_df['new_hot_rank'] - hot_df['old_hot_rank'] 46 | # # hot_df.sort_values(by='diff', ascending=False, inplace=True) 47 | # hot_df = hot_df.loc[hot_df.new_hot_rank < 2000] 48 | # print(hot_df[:100]) 49 | -------------------------------------------------------------------------------- /StrategyResearch/time_series/CR.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author : adolf 3 | Date : 2022-12-18 22:55:21 4 | LastEditors : adolf adolf1321794021@gmail.com 5 | LastEditTime : 2023-01-09 22:12:50 6 | FilePath : /stock_quant/StrategyResearch/time_series/CR.py 7 | """ 8 | 9 | import warnings 10 | 11 | import pandas_ta as ta 12 | 13 | from BackTrader.base_back_trader import TradeStructure 14 | 15 | warnings.filterwarnings("ignore", category=FutureWarning) 
class Intertwine(TradeStructure):
    """Exploratory strategy: derives a set of indicators and dumps them.

    The buy/sell hooks are stubs that only log the current trade state.
    """

    def cal_technical_indicators(self, indicators_config):
        """Attach SMA, cross, MACD, Bollinger and ATR columns to ``self.data``.

        ``indicators_config`` is accepted for interface compatibility and is
        not read here.
        """
        # 5- and 10-bar simple moving averages of the close.
        for length in (5, 10):
            self.data[f"sma{length}"] = ta.sma(self.data.close, length=length)

        # Crossings between the two averages, one column per direction.
        self.data["ma_long"] = ta.cross(self.data.sma5, self.data.sma10)
        self.data["ma_short"] = ta.cross(self.data.sma10, self.data.sma5)

        # MACD (12 / 26 / 9).
        macd_columns = ["macd", "histogram", "signal"]
        self.data[macd_columns] = ta.macd(
            self.data.close, fast=12, slow=26, signal=9
        )

        # Bollinger bands (20 bars, 2 standard deviations).
        band_columns = ["lower", "mid", "upper", "width", "percent"]
        self.data[band_columns] = ta.bbands(self.data.close, length=20, std=2)

        # Average true range over 14 bars.
        self.data["atr"] = ta.atr(
            self.data.high, self.data.low, self.data.close, length=14
        )

        # Columns that are not needed downstream.
        unused_columns = ["signal", "market_cap", "code", "width", "percent"]
        self.data.drop(columns=unused_columns, inplace=True)

        self.logger.debug(self.data.tail(30))

        # NOTE(review): this terminates the whole process right after the
        # indicator dump, so buy_logic / sell_logic are never reached.
        # Presumably a debugging stop -- confirm before relying on this class.
        exit()

    def buy_logic(self):
        """Log the current trading step and open transaction record."""
        state = self.trade_state
        self.logger.debug(state.trading_step)
        self.logger.debug(state.one_transaction_record)

    def sell_logic(self):
        """Log the current trading step and open transaction record."""
        state = self.trade_state
        self.logger.debug(state.trading_step)
        self.logger.debug(state.one_transaction_record)


if __name__ == "__main__":
    # Other keys seen in the original (left disabled): RANDOM_SEED, END_STAMP,
    # SHOW_DATA_PATH, STRATEGY_PARAMS; CODE_NAME may also be "ALL_MARKET_10"
    # or a list of codes such as ["600570", "002610", "300663"].
    config = {
        "LOG_LEVEL": "DEBUG",
        "CODE_NAME": "600519",
        "START_STAMP": "2022-01-01",
    }
    strategy = Intertwine(config)
    strategy.run()
-------------------------------------------------------------------------------- /StrategyResearch/tmp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description: 3 | Author: adolf 4 | Date: 2022-08-14 23:26:03 5 | LastEditTime: 2022-08-15 00:10:45 6 | LastEditors: adolf 7 | """ 8 | 9 | import pandas as pd 10 | import pandas_ta as ta 11 | 12 | # df = pd.DataFrame() 13 | 14 | # print(df.ta.indicators()) 15 | 16 | # Help about an indicator such as bbands 17 | # help(ta.vp) 18 | 19 | df = pd.read_csv("Data/RealData/hfq/000001.csv") 20 | 21 | df = df[-1000:] 22 | 23 | # print(df) 24 | test = ta.vp(close=df["close"], volume=df["volume"], width=20) 25 | print(test) 26 | -------------------------------------------------------------------------------- /StudyDoc/Advances.in.Financial.Machine.Learning-Wiley(2018).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKQ1688/stock_quant/c88e695a5cf17f2a445b671026f1535d369910ce/StudyDoc/Advances.in.Financial.Machine.Learning-Wiley(2018).pdf -------------------------------------------------------------------------------- /Utils/ShowKline/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*-''' 3 | # @Project : stock_quant 4 | # @Date : 2022/1/6 16:49 5 | # @Author : Adolf 6 | # @File : __init__.py.py 7 | -------------------------------------------------------------------------------- /Utils/ShowKline/base_kline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*-''' 3 | # @Project : stock_quant 4 | # @Date : 2022/1/6 16:55 5 | # @Author : Adolf 6 | # @File : base_kline.py 7 | 8 | from pyecharts import options as opts 9 | from pyecharts.charts import Bar, Grid, Kline, Line 10 | from pyecharts.commons.utils import JsCode 11 | from 
pyecharts.options import InitOpts 12 | 13 | 14 | def calculate_ma(input_data, day_count: int): 15 | result: list[float | str] = [] 16 | 17 | for i in range(len(input_data["times"])): 18 | if i < day_count: 19 | result.append("-") 20 | continue 21 | sum_total = 0.0 22 | for j in range(day_count): 23 | sum_total += float(input_data["datas"][i - j][1]) 24 | result.append(abs(float("%.2f" % (sum_total / day_count)))) 25 | return result 26 | 27 | 28 | def draw_chart(input_data, show_html_path="ShowHtml/CandleChart.html"): 29 | kline = Kline() 30 | points = [] 31 | colors_div = {"buy": "red", "sell": "green"} 32 | for label in ["buy", "sell"]: 33 | for i, val in enumerate(input_data[label]): 34 | if val == 1: 35 | coord = [input_data["times"][i], input_data["datas"][i][1]] 36 | point = opts.MarkPointItem( 37 | coord=coord, name=label, itemstyle_opts={"color": colors_div[label]} 38 | ) 39 | points.append(point) 40 | 41 | # points.extend([opts.MarkPointItem(type_="max", name="最大值"), 42 | # opts.MarkPointItem(type_="min", name="最小值")]) 43 | 44 | # import pdb;pdb.set_trace() 45 | kline.add_xaxis(xaxis_data=input_data["times"]) 46 | kline.add_yaxis( 47 | series_name="", 48 | y_axis=input_data["datas"], 49 | itemstyle_opts=opts.ItemStyleOpts( 50 | color="#ef232a", 51 | color0="#14b143", 52 | border_color="#ef232a", 53 | border_color0="#14b143", 54 | ), 55 | markpoint_opts=opts.MarkPointOpts(data=points), 56 | # markline_opts=opts.MarkLineOpts( 57 | # label_opts=opts.LabelOpts( 58 | # position="middle", color="blue", font_size=15 59 | # ), 60 | # data=split_data_part(input_data), 61 | # symbol=["circle", "none"], 62 | # ), 63 | ) 64 | # kline.set_series_opts( 65 | # markarea_opts=opts.MarkAreaOpts(is_silent=True, data=split_data_part()) 66 | # ) 67 | kline.set_global_opts( 68 | title_opts=opts.TitleOpts(title="K线展示图", pos_left="0"), 69 | xaxis_opts=opts.AxisOpts( 70 | type_="category", 71 | is_scale=True, 72 | boundary_gap=False, 73 | 
axisline_opts=opts.AxisLineOpts(is_on_zero=False), 74 | splitline_opts=opts.SplitLineOpts(is_show=False), 75 | split_number=20, 76 | min_="dataMin", 77 | max_="dataMax", 78 | ), 79 | yaxis_opts=opts.AxisOpts( 80 | is_scale=True, splitline_opts=opts.SplitLineOpts(is_show=True) 81 | ), 82 | tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="line"), 83 | datazoom_opts=[ 84 | opts.DataZoomOpts( 85 | is_show=True, type_="inside", xaxis_index=[0, 0], range_end=100 86 | ), 87 | opts.DataZoomOpts( 88 | is_show=True, xaxis_index=[0, 1], pos_top="97%", range_end=100 89 | ), 90 | opts.DataZoomOpts(is_show=False, xaxis_index=[0, 2], range_end=100), 91 | ], 92 | ) 93 | 94 | kline_line_ma = Line() 95 | kline_line_ma.add_xaxis(xaxis_data=input_data["times"]) 96 | kline_line_ma.add_yaxis( 97 | series_name="MA5", 98 | y_axis=calculate_ma(input_data=input_data, day_count=5), 99 | is_smooth=True, 100 | linestyle_opts=opts.LineStyleOpts(opacity=0.5), 101 | label_opts=opts.LabelOpts(is_show=False), 102 | is_symbol_show=False, 103 | ) 104 | kline_line_ma.add_yaxis( 105 | series_name="MA10", 106 | y_axis=calculate_ma(input_data=input_data, day_count=10), 107 | is_smooth=True, 108 | linestyle_opts=opts.LineStyleOpts(opacity=0.5), 109 | label_opts=opts.LabelOpts(is_show=False), 110 | is_symbol_show=False, 111 | ) 112 | kline_line_ma.set_global_opts( 113 | xaxis_opts=opts.AxisOpts( 114 | type_="category", grid_index=1, axislabel_opts=opts.LabelOpts(is_show=False) 115 | ), 116 | yaxis_opts=opts.AxisOpts( 117 | grid_index=1, 118 | split_number=3, 119 | axisline_opts=opts.AxisLineOpts(is_on_zero=False), 120 | axistick_opts=opts.AxisTickOpts(is_show=False), 121 | splitline_opts=opts.SplitLineOpts(is_show=False), 122 | axislabel_opts=opts.LabelOpts(is_show=True), 123 | ), 124 | ) 125 | overlap_kline_line = kline.overlap(kline_line_ma) 126 | 127 | # bar_vol = Bar() 128 | # bar_vol.add_xaxis(xaxis_data=input_data["times"]) 129 | # bar_vol.add_yaxis( 130 | # series_name="Volume", 131 
| # y_axis=input_data["vols"], 132 | # xaxis_index=1, 133 | # yaxis_index=1, 134 | # label_opts=opts.LabelOpts(is_show=False), 135 | # # 根据 echarts demo 的原版是这么写的 136 | # # itemstyle_opts=opts.ItemStyleOpts( 137 | # # color=JsCode( 138 | # # """ 139 | # # function(params) { 140 | # # var colorList; 141 | # # if (input_data.datas[params.dataIndex][1]>input_data.datas[params.dataIndex][0]) { 142 | # # colorList = '#ef232a'; 143 | # # } else { 144 | # # colorList = '#14b143'; 145 | # # } 146 | # # return colorList; 147 | # # } 148 | # # """) 149 | # # ) 150 | # # 改进后在 grid 中 add_js_func 后变成如下 151 | # itemstyle_opts=opts.ItemStyleOpts( 152 | # color=JsCode( 153 | # """ 154 | # function(params) { 155 | # var colorList; 156 | # if (barData[params.dataIndex][1] > barData[params.dataIndex][0]) { 157 | # colorList = '#ef232a'; 158 | # } else { 159 | # colorList = '#14b143'; 160 | # } 161 | # return colorList; 162 | # } 163 | # """ 164 | # ) 165 | # ), 166 | # ) 167 | # bar_vol.set_global_opts( 168 | # xaxis_opts=opts.AxisOpts( 169 | # type_="category", 170 | # grid_index=1, 171 | # axislabel_opts=opts.LabelOpts(is_show=False), 172 | # ), 173 | # legend_opts=opts.LegendOpts(is_show=False), 174 | # ) 175 | 176 | # 成交量图 177 | bar_vol = Bar() 178 | bar_vol.add_xaxis(input_data["times"]) 179 | bar_vol.add_yaxis( 180 | series_name="Volume", 181 | y_axis=input_data["vols"], 182 | bar_width="60%", 183 | label_opts=opts.LabelOpts(is_show=False), 184 | itemstyle_opts=opts.ItemStyleOpts( 185 | color=JsCode( 186 | """ 187 | function(params) { 188 | var colorList; 189 | if (params.data >= 0) { 190 | colorList = '#ef232a'; 191 | } else { 192 | colorList = '#14b143'; 193 | } 194 | return colorList; 195 | } 196 | """ 197 | ) 198 | ), 199 | ) 200 | bar_vol.set_global_opts( 201 | xaxis_opts=opts.AxisOpts( 202 | type_="category", 203 | grid_index=1, 204 | axislabel_opts=opts.LabelOpts(is_show=True, font_size=8, color="#9b9da9"), 205 | is_show=False, 206 | ), 207 | yaxis_opts=opts.AxisOpts( 208 
| is_scale=True, 209 | axislabel_opts=opts.LabelOpts( 210 | color="#c7c7c7", font_size=8, position="inside", is_show=False 211 | ), 212 | is_show=False, 213 | ), 214 | legend_opts=opts.LegendOpts(is_show=False), 215 | ) 216 | 217 | # macd图 218 | bar_macd = Bar() 219 | bar_macd.add_xaxis(xaxis_data=input_data["times"]) 220 | bar_macd.add_yaxis( 221 | series_name="MACD", 222 | y_axis=input_data["macds"], 223 | xaxis_index=2, 224 | yaxis_index=2, 225 | label_opts=opts.LabelOpts(is_show=False), 226 | itemstyle_opts=opts.ItemStyleOpts( 227 | color=JsCode( 228 | """ 229 | function(params) { 230 | var colorList; 231 | if (params.data >= 0) { 232 | colorList = '#ef232a'; 233 | } else { 234 | colorList = '#14b143'; 235 | } 236 | return colorList; 237 | } 238 | """ 239 | ) 240 | ), 241 | ) 242 | bar_macd.set_global_opts( 243 | xaxis_opts=opts.AxisOpts( 244 | type_="category", grid_index=2, axislabel_opts=opts.LabelOpts(is_show=False) 245 | ), 246 | yaxis_opts=opts.AxisOpts( 247 | grid_index=2, 248 | split_number=4, 249 | axisline_opts=opts.AxisLineOpts(is_on_zero=False), 250 | axistick_opts=opts.AxisTickOpts(is_show=False), 251 | splitline_opts=opts.SplitLineOpts(is_show=False), 252 | axislabel_opts=opts.LabelOpts(is_show=True), 253 | ), 254 | legend_opts=opts.LegendOpts(is_show=False), 255 | ) 256 | 257 | line_macd = Line() 258 | line_macd.add_xaxis(xaxis_data=input_data["times"]) 259 | line_macd.add_yaxis( 260 | series_name="DIF", 261 | y_axis=input_data["difs"], 262 | xaxis_index=2, 263 | yaxis_index=2, 264 | label_opts=opts.LabelOpts(is_show=False), 265 | is_symbol_show=False, 266 | ) 267 | line_macd.add_yaxis( 268 | series_name="DEA", 269 | y_axis=input_data["deas"], 270 | xaxis_index=2, 271 | yaxis_index=2, 272 | label_opts=opts.LabelOpts(is_show=False), 273 | is_symbol_show=False, 274 | ) 275 | line_macd.set_global_opts(legend_opts=opts.LegendOpts(is_show=False)) 276 | overlap_macd_line = bar_macd.overlap(line_macd) 277 | ops = InitOpts(width="100%", height="800px") 
278 | grid_chart = Grid(init_opts=ops) 279 | # grid_chart = Grid() 280 | grid_chart.add_js_funcs("var barData = {}".format(input_data["datas"])) 281 | 282 | grid_chart.add( 283 | overlap_kline_line, 284 | # grid_opts=grid0_opts, 285 | grid_opts=opts.GridOpts(pos_left="3%", pos_right="1%", height="60%"), 286 | ) 287 | 288 | # # Volume 柱状图 289 | grid_chart.add( 290 | bar_vol, 291 | # grid_opts=grid1_opts 292 | grid_opts=opts.GridOpts( 293 | pos_left="3%", pos_right="1%", pos_top="71%", height="10%" 294 | ), 295 | ) 296 | 297 | # # MACD DIFS DEAS 298 | grid_chart.add( 299 | overlap_macd_line, 300 | # grid_opts=grid2_opts, 301 | grid_opts=opts.GridOpts( 302 | pos_left="3%", pos_right="1%", pos_top="82%", height="14%" 303 | ), 304 | ) 305 | 306 | # grid_chart.render(path="ShowHtml/CandleChart.html") 307 | 308 | if show_html_path is not None: 309 | grid_chart.render(path=show_html_path, height="600%", width="100%") 310 | 311 | return grid_chart 312 | # if show_render: 313 | # grid_chart.render() 314 | # make_snapshot(snapshot, grid_chart.render(), "bar.png") 315 | 316 | 317 | if __name__ == "__main__": 318 | from GetBaseData.handle_data_show import show_data_from_df 319 | 320 | show_data = show_data_from_df("Data/RealData/hfq/600570.csv") 321 | draw_chart(show_data, show_html_path="ShowHtml/CandleChartV2.html") 322 | -------------------------------------------------------------------------------- /Utils/TechnicalIndicators/__init__.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2022/2/2 22:30 4 | # @Author : Adolf 5 | # @File : __init__.py.py 6 | # @Function: 7 | -------------------------------------------------------------------------------- /Utils/TechnicalIndicators/basic_indicators.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # @Project : stock_quant 3 | # @Date : 2022/2/2 22:30 4 | # 
# @Author : Adolf
# @File : basic_indicators.py
# @Function:
"""
Common technical-analysis indicators: SMA, EMA, MACD, KDJ and R-squared.
"""

import numpy as np


def SMA(close: np.ndarray, timeperiod=5):
    """Simple moving average.

    https://baike.baidu.com/item/%E7%A7%BB%E5%8A%A8%E5%B9%B3%E5%9D%87%E7%BA%BF/217887

    :param close: sequence of closing prices (array or list)
    :param timeperiod: window length
    :return: np.ndarray rounded to 4 decimals; the first ``timeperiod - 1``
        entries are averages over the shorter prefix that is available
    """
    close = np.asarray(close, dtype=np.double)
    res = [
        # Clamp the window start at 0 so early entries use the partial prefix.
        close[max(0, i - timeperiod + 1) : i + 1].mean()
        for i in range(len(close))
    ]
    return np.array(res, dtype=np.double).round(4)


def EMA(close: np.ndarray, timeperiod=5):
    """Exponential moving average, seeded with the first value.

    https://baike.baidu.com/item/EMA/12646151

    :param close: sequence of closing prices
    :param timeperiod: smoothing period
    :return: np.ndarray rounded to 4 decimals
    """
    res = []
    for i, price in enumerate(close):
        if i == 0:
            res.append(price)
        else:
            ema = (2 * price + res[-1] * (timeperiod - 1)) / (timeperiod + 1)
            res.append(ema)
    return np.array(res, dtype=np.double).round(4)


def MACD(close: np.ndarray, fastperiod=12, slowperiod=26, signalperiod=9):
    """MACD (moving average convergence / divergence).

    https://baike.baidu.com/item/MACD%E6%8C%87%E6%A0%87/6271283

    :param close: sequence of closing prices
    :param fastperiod: fast EMA period, default 12
    :param slowperiod: slow EMA period, default 26
    :param signalperiod: signal EMA period, default 9
    :return: (diff, dea, macd) as np.ndarray, each rounded to 4 decimals;
        ``macd`` is ``(diff - dea) * 2``
    """
    ema_fast = EMA(close, timeperiod=fastperiod)
    ema_slow = EMA(close, timeperiod=slowperiod)
    diff = ema_fast - ema_slow
    dea = EMA(diff, timeperiod=signalperiod)
    macd = (diff - dea) * 2
    return diff.round(4), dea.round(4), macd.round(4)


def KDJ(close: np.ndarray, high: np.ndarray, low: np.ndarray, n: int = 9):
    """KDJ stochastic indicator.

    :param close: sequence of closing prices
    :param high: sequence of high prices
    :param low: sequence of low prices
    :param n: look-back window for the highest high / lowest low, default 9
    :return: (k, d, j) as np.ndarray, each rounded to 4 decimals
    """
    close = np.asarray(close, dtype=np.double)
    high = np.asarray(high, dtype=np.double)
    low = np.asarray(low, dtype=np.double)

    hv = []
    lv = []
    for i in range(len(close)):
        start = max(0, i - n + 1)
        hv.append(high[start : i + 1].max())
        lv.append(low[start : i + 1].min())

    hv = np.around(hv, decimals=2)
    lv = np.around(lv, decimals=2)

    # RSV is defined as 0 when the window has no range. Dividing only where
    # the range is non-zero avoids the divide-by-zero RuntimeWarning that
    # np.where(hv == lv, 0, ...) triggers by evaluating both branches.
    spread = hv - lv
    rsv = (
        np.divide(close - lv, spread, out=np.zeros_like(spread), where=spread != 0)
        * 100
    )

    k = []
    d = []
    j = []
    for i in range(len(rsv)):
        if i < n:
            # Warm-up: K and D simply track RSV.
            k_ = rsv[i]
            d_ = k_
        else:
            k_ = (2 / 3) * k[i - 1] + (1 / 3) * rsv[i]
            d_ = (2 / 3) * d[i - 1] + (1 / 3) * k_

        k.append(k_)
        d.append(d_)
        j.append(3 * k_ - 2 * d_)

    k = np.array(k, dtype=np.double)
    d = np.array(d, dtype=np.double)
    j = np.array(j, dtype=np.double)
    return k.round(4), d.round(4), j.round(4)


def RSQ(close) -> float:
    """Goodness of fit (R squared) of a least-squares line through ``close``.

    The x values are the positions 0..len(close)-1.

    :param close: sequence of closing prices
    :return: R squared rounded to 4 decimals; 0.0 when the fit is degenerate
        (fewer than two points)
    """
    y = np.asarray(close, dtype=np.double)
    num = len(y)
    x = np.arange(num, dtype=np.double)

    x_sum = x.sum()
    y_sum = y.sum()
    x_squared_sum = (x * x).sum()
    xy_product_sum = (x * y).sum()

    delta = float(num * x_squared_sum - x_sum * x_sum)
    if delta == 0:
        return 0.0
    y_intercept = (x_squared_sum * y_sum - x_sum * xy_product_sum) / delta
    slope = (num * xy_product_sum - x_sum * y_sum) / delta

    # The small constant keeps the ratio finite for a perfectly flat series.
    ss_tot = ((y - y.mean()) ** 2).sum() + 0.00001
    ss_err = ((y - slope * x - y_intercept) ** 2).sum()
    rsq = 1 - ss_err / ss_tot

    return float(round(rsq, 4))
# !/usr/bin/env python
# @Project : stock_quant
# @Date : 2021/12/23 00:02
# @Author : Adolf
# @File : base_utils.py
# @Function:
import functools
import sys


def get_logger(level="INFO", console=True, logger_file=None):
    """Configure and return the shared loguru logger.

    All handlers previously attached to the loguru logger are removed first,
    so the returned logger only has the sinks requested here.

    :param level: log level name (trace, debug, info, warning, error,
        critical); case-insensitive
    :param console: whether to emit log records to the console (stderr)
    :param logger_file: path of the log file; None disables file logging
    :return: the configured loguru logger
    """
    # Imported here so that importing this module (e.g. only for run_once)
    # does not require loguru to be installed.
    import loguru

    logger = loguru.logger
    logger.remove()

    logger_format = """{time:YYYY-MM-DD HH:mm:ss}| {level} | {name}=>{function}=>{line}\n{message}"""

    if console:
        logger.add(sys.stderr, format=logger_format, colorize=True, level=level.upper())

    # File sink: a new file is started every day at midnight and files are
    # kept for at most 7 days.
    if logger_file is not None:
        logger.add(
            logger_file,
            enqueue=True,
            level=level.upper(),
            encoding="utf-8",
            rotation="00:00",
            retention="7 days",
        )

    return logger


def run_once(f):
    """Decorate ``f`` so that only its first call is executed.

    The first call returns whatever ``f`` returns; every later call is a
    no-op that returns None. The flag is exposed as ``wrapper.has_run`` and
    can be reset to False to re-arm the function.
    """

    # functools.wraps keeps f's __name__ / __doc__ on the wrapper so logs and
    # introspection still show the decorated function.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.has_run = True
            return f(*args, **kwargs)
        return None

    wrapper.has_run = False
    return wrapper
# !/usr/bin/env python
# @Project : stock_quant
# @Date : 2022/1/18 23:29
# @Author : Adolf
# @File : info_push.py
# @Function:
import json
import logging


def post_msg_to_dingtalk(
    title="", msg="", token="", at=None, type="text", timeout=10
):
    """Send a message to a DingTalk robot webhook.

    Delivery is best effort: network errors and non-zero ``errcode``
    responses are logged through the root logger and not raised.

    :param title: message title; prefixed to the content as ``[title]``
    :param msg: message body
    :param token: access token of the DingTalk robot
    :param at: list of mobile numbers to @-mention; None means nobody
    :param type: message type, either "text" or "markdown"
    :param timeout: seconds to wait for the HTTP request before giving up
    :raises ValueError: if ``type`` is neither "text" nor "markdown"
    """
    if at is None:
        at = []

    if type == "markdown":
        # NOTE: @-mentions do not take effect for markdown messages
        # (known limitation noted by the original author).
        data = {
            "msgtype": "markdown",
            "markdown": {"title": "[" + title + "]" + title, "text": "" + msg},
            "at": {},
        }
    elif type == "text":
        data = {
            "msgtype": "text",
            "text": {"content": "[" + title + "]" + msg},
            "at": {},
        }
    else:
        # Previously an unknown type fell through and crashed later with an
        # UnboundLocalError on `data`; fail early with a clear message.
        raise ValueError(f"unsupported DingTalk message type: {type!r}")
    data["at"]["atMobiles"] = at

    # Imported here so argument validation above does not depend on the
    # HTTP client being importable.
    import requests

    url = "https://oapi.dingtalk.com/robot/send?access_token=" + token
    json_data = json.dumps(data)
    try:
        response = requests.post(
            url=url,
            data=json_data,
            headers={"Content-Type": "application/json"},
            # Without a timeout requests may block forever on a dead socket.
            timeout=timeout,
        ).json()
        # Explicit check instead of `assert`, which is stripped under -O.
        if response["errcode"] != 0:
            raise RuntimeError(f"DingTalk returned an error: {response}")
    except Exception as e:
        logging.getLogger().error(f"发送钉钉提醒失败,请检查;{e}")

Opt records

12 | 13 |
14 |
15 | start_date: 16 | 17 | end_date: 18 | 19 | user_id: 20 | 21 | stock_code: 22 | 23 | count: 24 | 25 | 26 | 27 | 28 |
29 | 30 |
31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
user_iddatestock_codestock_profit_rateprofit_rateover_profit
47 |
48 | 49 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /api/stock_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import time 4 | 5 | import baostock as bs 6 | import pandas as pd 7 | import uvicorn 8 | from fastapi import Body, FastAPI, Response 9 | 10 | # from pyecharts.components import Table 11 | from finta import TA 12 | from loguru import logger 13 | from pymongo import MongoClient 14 | 15 | app = FastAPI() 16 | mongo_config = {"host": "172.22.67.15", "port": 27017} 17 | db = MongoClient(mongo_config["host"], mongo_config["port"])["stock_db"] 18 | 19 | # 打印db的信息 20 | logger.info(db) 21 | 22 | 23 | @app.post("/stock_data") 24 | def get_stock_data( 25 | code=Body("sh.600570"), 26 | start_date=Body("2022-12-19"), 27 | end_date=Body("2023-04-10"), 28 | frequency=Body("d"), 29 | ): 30 | print( 31 | f"code:{code} start_date:{start_date} end_date:{end_date} frequency:{frequency}" 32 | ) 33 | if os.path.exists(f"Data/Baostock/day/{code}.csv"): 34 | data_df = pd.read_csv(f"Data/Baostock/day/{code}.csv") 35 | 36 | data_df = data_df[data_df["date"] >= start_date] 37 | data_df = data_df[data_df["date"] <= end_date] 38 | 39 | else: 40 | bs.login() 41 | rs = bs.query_history_k_data_plus( 42 | code, 43 | "date,code,open,high,low,close,volume,amount", 44 | start_date=start_date, 45 | end_date=end_date, 46 | frequency=frequency, 47 | adjustflag="3", 48 | ) 49 | # 打印结果集 50 | data_list = [] 51 | while (rs.error_code == "0") & rs.next(): 52 | # 获取一条记录,将记录合并在一起 53 | data_list.append(rs.get_row_data()) 54 | data_df = pd.DataFrame(data_list, columns=rs.fields) 55 | 56 | data_df["MA5"] = TA.SMA(data_df, period=5) 57 | data_df["MA10"] = TA.SMA(data_df, period=10) 58 | data_df["MA20"] = TA.SMA(data_df, period=20) 59 | data_df["MA30"] = TA.SMA(data_df, period=30) 60 | data_df["MA60"] = TA.SMA(data_df, period=60) 61 | 62 | macd_df = 
TA.MACD(data_df) 63 | macd_df["HISTOGRAM"] = macd_df["MACD"] - macd_df["SIGNAL"] 64 | data_df = pd.concat([data_df, macd_df], axis=1) 65 | 66 | bs.logout() 67 | 68 | data_df = data_df.fillna("") 69 | 70 | base_data_list = [ 71 | list(oclh) 72 | for oclh in zip( 73 | data_df["date"].tolist(), 74 | # data_df["code"].tolist(), 75 | data_df["open"].tolist(), 76 | data_df["close"].tolist(), 77 | data_df["high"].tolist(), 78 | data_df["low"].tolist(), 79 | data_df["volume"].tolist(), 80 | data_df["amount"].tolist(), strict=False, 81 | ) 82 | ] 83 | # macd_list = [ 84 | # list(macd) for macd in zip( 85 | # data_df["MACD"].tolist(), 86 | # data_df["SIGNAL"].tolist(), 87 | # data_df["HISTOGRAM"].to_list(), 88 | # ) 89 | # ] 90 | macd_list = data_df["HISTOGRAM"].tolist() 91 | diff = data_df["MACD"].tolist() 92 | dea = data_df["SIGNAL"].tolist() 93 | 94 | ma5_list = data_df["MA5"].tolist() 95 | ma10_list = data_df["MA10"].tolist() 96 | ma20_list = data_df["MA20"].tolist() 97 | ma30_list = data_df["MA30"].tolist() 98 | ma60_list = data_df["MA60"].tolist() 99 | return { 100 | "base_data": base_data_list, 101 | "macd": macd_list, 102 | "diff": diff, 103 | "dea": dea, 104 | "ma5": ma5_list, 105 | "ma10": ma10_list, 106 | "ma20": ma20_list, 107 | "ma30": ma30_list, 108 | "ma60": ma60_list, 109 | } 110 | 111 | 112 | @app.post("/get_records") 113 | def get_records( 114 | user_id=Body(None), 115 | start_date=Body(None), 116 | end_date=Body(None), 117 | stock_code=Body(None), 118 | count=Body(None), 119 | ): 120 | print("user_id:", user_id) 121 | filter = {} 122 | if user_id is not None and user_id != "": 123 | filter["user_id"] = user_id 124 | if start_date is not None and start_date != "": 125 | filter["date"] = {"$gte": start_date} 126 | if end_date is not None and end_date != "": 127 | if "date" in filter: 128 | filter["date"]["$lte"] = end_date 129 | else: 130 | filter["date"] = {"$lte": end_date} 131 | if stock_code is not None and stock_code != "": 132 | filter["stock_code"] = 
stock_code 133 | print(f"filter:{filter}") 134 | print(f"count:{count}") 135 | count = int(count) 136 | return_records = [] 137 | profit_rate = 1 138 | stock_profit_rate = 1 139 | for history in db["play_records"].find(filter).limit(count): 140 | history.pop("_id") 141 | history.pop("records") 142 | profit_rate *= 1 + float(history["profit_rate"]) 143 | stock_profit_rate *= 1 + float(history["stock_profit_rate"]) 144 | print(f"profit_rate {profit_rate} stock_profit_rate {stock_profit_rate}") 145 | print("history: ", history) 146 | return_records.append( 147 | [ 148 | history["user_id"], 149 | history["date"], 150 | history["stock_code"], 151 | float_to_pct(history["stock_profit_rate"]), 152 | float_to_pct(history["profit_rate"]), 153 | float_to_pct(history["over_profit"]), 154 | ] 155 | ) 156 | return { 157 | "records": return_records, 158 | "profit_rate": profit_rate - 1, 159 | "stock_profit_rate": stock_profit_rate - 1, 160 | } 161 | 162 | 163 | @app.post("/push_records") 164 | def push_records( 165 | records=Body(None), 166 | user_id=Body(None), 167 | stock_code=Body(None), 168 | stock_profit_rate=Body(None), 169 | ): 170 | print(f"get {user_id} records from browser:{records}") 171 | profit_rate = cal_profit_rate(records) 172 | today = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) 173 | over_profit = float(profit_rate) - float(stock_profit_rate) 174 | db["play_records"].insert_one( 175 | { 176 | "user_id": user_id, 177 | "records": records, 178 | "profit_rate": profit_rate, 179 | "date": today, 180 | "stock_code": stock_code, 181 | "stock_profit_rate": stock_profit_rate, 182 | "over_profit": over_profit, 183 | } 184 | ) 185 | return f"success,profit_rate is {float_to_pct(profit_rate)},over_profit is {float_to_pct(over_profit)}" 186 | 187 | 188 | def float_to_pct(f): 189 | val = round(float(f) * 100, 2) 190 | return str(val) + "%" 191 | 192 | 193 | def cal_profit_rate(records): 194 | profit_rate = 1 195 | assert len(records) % 2 == 0 196 | i = 0 
197 | while i < len(records): 198 | profit_rate = records[i + 1]["price"] / records[i]["price"] * profit_rate 199 | i = i + 2 200 | return round((profit_rate - 1), 2) 201 | 202 | 203 | @app.get("/index") 204 | def func(): 205 | with open("api/test.html", encoding="utf8") as file: 206 | content = file.read() 207 | # 4.返回响应数据 208 | return Response(content=content, media_type="text/html") 209 | 210 | 211 | @app.get("/hist") 212 | def func(): 213 | with open("api/hist.html", encoding="utf8") as file: 214 | content = file.read() 215 | # 4.返回响应数据 216 | return Response(content=content, media_type="text/html") 217 | 218 | 219 | if __name__ == "__main__": 220 | # 获取本机ip 221 | ip = socket.gethostbyname(socket.gethostname()) 222 | print(f"ip : {ip}") 223 | uvicorn.run(app, host=ip, port=8502) 224 | -------------------------------------------------------------------------------- /api/test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def test_push_records(): 5 | url = "http://localhost:9999/push_records" 6 | json_param = [ 7 | {"opt": "buy", "date": "2021-09-09", "price": 1200}, 8 | {"opt": "buy", "date": "2021-09-09", "price": 2000}, 9 | ] 10 | rec = { 11 | "user_id": "rose", 12 | "records": json_param, 13 | "stock_code": "000111", 14 | "stock_profit_rate": "1", 15 | } 16 | response = requests.post(url, json=rec) 17 | print(f"response {response.text}") 18 | 19 | 20 | # def test_get_code_data(): 21 | url = "http://172.22.67.15:9999/stock_data" 22 | rec = {"start_date": "2023-04-19", "end_date": "2023-04-19", "stock_code": "sz.002044"} 23 | response = requests.post(url, json=rec) 24 | print(len(response.json())) 25 | print(f"response {response.json()}") 26 | 27 | 28 | def test_get_records(): 29 | url = "http://localhost:9999/get_records" 30 | rec = {"user_id": "rose", "start_date": "2023-04-19 14:15:10", "count": 10} 31 | response = requests.post(url, json=rec) 32 | print(len(response.json())) 33 | 
print(f"response {response.json()}") 34 | -------------------------------------------------------------------------------- /api/示例/README.md: -------------------------------------------------------------------------------- 1 | # H5-Kline 是基于echarts封装的非常轻量级的股票行情图表 2 | ## 功能点 3 | - 分时行情图表 4 | - K线周期图表 5 | - MA移动平均线 6 | - MACD、DIF、DEA幅图指标 7 | ## 适用场景 8 | - pc端门户网站 9 | - H5移动端(在小部分低端机上,K线滑动流畅性会低一些,不影响功能) 10 | ## 使用方法 11 | - 准备行情数据 详情请看tmpData.js里面的数据格式 12 | - 准备div初始化图标,详情请看demo.html 13 | ## 效果图 14 | ![Image text](https://github.com/2557606319/H5-Kline/blob/master/kline1.gif) 15 | ![Image text](https://github.com/2557606319/H5-Kline/blob/master/kline2.gif) 16 | ![Image text](https://github.com/2557606319/H5-Kline/blob/master/kline3.gif) 17 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "stock-quant" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["adolf "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.12" 10 | akshare = "^1.14.77" 11 | baostock = "^0.8.9" 12 | pandas = "^2.2.2" 13 | tqdm = "^4.66.5" 14 | pyecharts = "^2.0.6" 15 | schedule = "^1.2.2" 16 | streamlit = "^1.38.0" 17 | streamlit-echarts = "^0.4.0" 18 | tabulate = "^0.9.0" 19 | loguru = "^0.7.2" 20 | pathos = "^0.3.2" 21 | dask = {extras = ["complete"], version = "^2024.9.0"} 22 | scikit-learn = "^1.5.2" 23 | matplotlib = "^3.9.2" 24 | fastapi = "^0.114.2" 25 | uvicorn = "^0.30.6" 26 | pymongo = "^4.8.0" 27 | ruff = "^0.6.5" 28 | ray = "^2.35.0" 29 | 30 | 31 | [tool.poetry.group.dev.dependencies] 32 | lightgbm = "^4.5.0" 33 | 34 | [build-system] 35 | requires = ["poetry-core"] 36 | build-backend = "poetry.core.masonry.api" 37 | 38 | [tool.ruff] 39 | # https://github.com/astral-sh/ruff 40 | # https://docs.astral.sh/ruff/settings 41 | output-format = "grouped" 42 | show-fixes = true 43 | target-version = 
"py312" 44 | 45 | [tool.ruff.format] 46 | docstring-code-format = true 47 | docstring-code-line-length = 79 48 | skip-magic-trailing-comma = true 49 | 50 | [tool.ruff.lint] 51 | select = [ # UPDATEME with additional rules from https://docs.astral.sh/ruff/rules/ 52 | "F", 53 | "E", 54 | "W", 55 | "I", 56 | "N", 57 | "S", 58 | "B", 59 | "UP", 60 | "C90", 61 | "T20", 62 | "EM", 63 | "PL", 64 | "C4", 65 | "PT", 66 | "TD", 67 | "ICN", 68 | "RET", 69 | "RSE", 70 | "ARG", 71 | "SIM", 72 | "TID", 73 | "PTH", 74 | "TCH", 75 | "FIX", 76 | "FLY", 77 | "YTT", 78 | "RUF" 79 | ] 80 | ignore = ["D200", "D104", "E501", "D101", "D100", "D102", "ANN101", "T201"] 81 | task-tags = ["TODO", "FIXME", "XXX", "UPDATEME"] # UPDATEME by modifying or removing this setting after addressing all UPDATEMEs 82 | 83 | 84 | [tool.ruff.lint.flake8-pytest-style] 85 | fixture-parentheses = false 86 | mark-parentheses = false 87 | 88 | [tool.ruff.lint.flake8-tidy-imports] 89 | ban-relative-imports = "all" 90 | 91 | [tool.ruff.lint.pycodestyle] 92 | max-line-length = 88 93 | max-doc-length = 90 94 | 95 | [tool.ruff.lint.pydocstyle] 96 | convention = "numpy" 97 | 98 | [tool.ruff.lint.pylint] 99 | max-bool-expr = 3 100 | 101 | [tool.ruff.lint.isort] 102 | split-on-trailing-comma = false 103 | sections.typing = ["typing", "types", "typing_extensions", "mypy", "mypy_extensions"] 104 | sections.testing = ["pytest", "tests"] 105 | section-order = [ 106 | "future", 107 | "typing", 108 | "standard-library", 109 | "third-party", 110 | "first-party", 111 | "local-folder", 112 | "testing" 113 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | akshare 2 | baostock 3 | pandas 4 | tqdm 5 | pyecharts 6 | schedule 7 | streamlit 8 | streamlit-echarts 9 | tabulate 10 | loguru 11 | pathos 12 | dask[complete] 13 | scikit-learn 14 | matplotlib 15 | fastapi 16 | uvicorn 17 | pymongo 18 | 
class MultiApp:
    """Registry of named Streamlit pages chosen through a sidebar radio button."""

    def __init__(self):
        # Titles in registration order (used as the radio options) and the
        # title -> page-callable lookup.
        self.apps, self.app_dict = [], {}

    def add_app(self, title, func):
        """Register ``func`` under ``title``; an already-registered title is ignored."""
        if title in self.apps:
            return
        self.apps.append(title)
        self.app_dict[title] = func

    def run(self):
        """Show the sidebar selector and call the page that was picked."""
        selected = st.sidebar.radio("选择服务类型", self.apps, format_func=str)
        page = self.app_dict[selected]
        page()
def welcome():
    """Render a landing page with a title, a heading and a sample line chart.

    The chart data is a hard-coded seven-point sample, not live market data.
    """
    st.title("量化策略回测系统")
    st.markdown("#### 今日大盘走势")

    options = {
        # ECharts accepts only "value", "category", "time" or "log" as an axis
        # type; the labelled x axis is a category axis.
        "xAxis": {
            "type": "category",
            "data": ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"],
        },
        # "价格" is the axis caption, so it goes in `name`; as a `type` it is
        # not one of the accepted axis types.
        "yAxis": {"type": "value", "name": "价格"},
        "series": [{"data": [820, 932, 901, 934, 1290, 1330, 1320], "type": "line"}],
    }
    st_echarts(options=options)
def get_index_data(code="sh.000001", start_date="2013-01-01", end_date="2023-06-10"):
    """Fetch daily ``date``/``code``/``close`` rows for one security from baostock.

    Parameters
    ----------
    code : str
        Security code passed straight to baostock, e.g. ``"sh.000001"``.
    start_date, end_date : str
        Query window passed straight to baostock; the defaults reproduce the
        range that used to be hard-coded in the call.

    Returns
    -------
    pandas.DataFrame
        One row per record read from the result set, with columns taken from
        ``rs.fields``. If the query reports a non-"0" error code no rows are
        read and the frame is empty; no exception is raised here.
    """
    rs = bs.query_history_k_data_plus(
        code=code,
        fields="date,code,close",
        start_date=start_date,
        end_date=end_date,
        frequency="d",
        adjustflag="3",
    )
    rows = []
    # Logical `and` instead of bitwise `&`: the check short-circuits, so the
    # cursor is not advanced once an error has been reported.
    while rs.error_code == "0" and rs.next():
        rows.append(rs.get_row_data())
    return pd.DataFrame(rows, columns=rs.fields)
def three_inidexs():
    """Render the page that plots the three module-level close-price series on one line chart."""
    st.title("指数历史行情")

    # (legend label, close-price list) per index, in display order.
    named_series = (
        ("上证指数", y_data_sh),
        ("深证指数", y_data_sz),
        ("创业板指数", y_data_cr),
    )
    legend_names = [label for label, _ in named_series]
    line_series = [
        {"name": label, "type": "line", "data": closes}
        for label, closes in named_series
    ]

    chart = {
        "title": {"text": "三大指数走势"},
        "legend": {"data": legend_names},
        # x_data_sz holds the date column shared by all three series.
        "xAxis": {"type": "category", "data": x_data_sz},
        "yAxis": {"type": "value"},
        "series": line_series,
    }
    st_echarts(options=chart, height="600px")