├── README.md
├── Selector.py
├── appendix.json
├── configs.json
├── fetch_kline.py
├── requirements.txt
└── select_stock.py


/README.md:
--------------------------------------------------------------------------------
# A Python Implementation of Z哥's Trading Strategies

> **Updated 2025-07-03:** added the 填坑战法 strategy.

---

## Contents

* [Overview](#overview)
* [Quick Start](#quick-start)
  * [Install Dependencies](#install-dependencies)
  * [Tushare Token (Optional)](#tushare-token-optional)
  * [Mootdx Prerequisites](#mootdx-prerequisites)
  * [Download Historical Quotes](#download-historical-quotes)
  * [Run Stock Selection](#run-stock-selection)
* [Parameters](#parameters)
  * [`fetch_kline.py`](#fetch_klinepy)
    * [K-line Frequency Codes](#k-line-frequency-codes)
  * [`select_stock.py`](#select_stockpy)
  * [Built-in Strategy Parameters](#built-in-strategy-parameters)
    * [1. BBIKDJSelector (少妇战法)](#1-bbikdjselector-少妇战法)
    * [2. PeakKDJSelector (填坑战法)](#2-peakkdjselector-填坑战法)
    * [3. BBIShortLongSelector (补票战法)](#3-bbishortlongselector-补票战法)
    * [4. BreakoutVolumeKDJSelector (TePu 战法)](#4-breakoutvolumekdjselector-tepu-战法)
* [Project Structure](#project-structure)
* [Disclaimer](#disclaimer)

---

## Overview

| Name | Description |
| --------------------- | ---------------------------------------------------------------------------------------------------------------- |
| **`fetch_kline.py`** | *Filters A-share stocks by market capitalization* and downloads their **historical K-line data** as CSV files. Supports three data sources (**AkShare / Tushare / Mootdx**), incremental updates, and multithreaded downloads. *This version no longer saves a market-cap snapshot*; it is fetched live on every run. |
| **`select_stock.py`** | Reads the local CSV quotes, runs the **Selector** classes defined in `configs.json` in batch, and writes the results to `select_results.log` and the console. |

Built-in strategies (see `Selector.py`):

* **BBIKDJSelector** (少妇战法)
* **PeakKDJSelector** (填坑战法)
* **BBIShortLongSelector** (补票战法)
* **BreakoutVolumeKDJSelector** (TePu 战法)

---

## Quick Start

### Install Dependencies

```bash
# Create and activate a Python 3.12 virtual environment (recommended)
conda create -n stock python=3.12
conda activate stock

# Enter the project directory (replace with your actual path)
cd "your/path"

# Install dependencies
pip install -r requirements.txt

# If cffi fails to install, upgrade it first and retry
pip install --upgrade cffi
```

> Main dependencies: `akshare`, `tushare`, `mootdx`, `pandas`, `tqdm`, and others.

### Tushare Token (Optional)

If you choose **Tushare** as the data source, follow these steps:

1. **Register an account**
   Sign up through the referral link [https://tushare.pro/register?reg=820660](https://tushare.pro/register?reg=820660). *Registering through this link earns me 50 points. Thank you for the support!*
2. **Unlock basic access**
   After logging in, go to 「**平台介绍 → 社区捐助**」 (Platform Introduction → Community Donation) and donate **200 CNY/year** as prompted to unlock the basic Tushare APIs.
3. **Get your token**
   Open your profile page, click **「接口 Token」** (API Token), and copy the generated token.
4. **Put it in the code**
   In `fetch_kline.py`, in the token-handling part of `main()` (around **line 306**; the exact line may vary):

   ```python
   ts_token = "***"  # ← replace with your token
   ```

### Mootdx Prerequisites

**Note: Mootdx returns unadjusted prices, which biases the selection results. Prefer Tushare whenever possible.**

Before using the **Mootdx** data source, probe for the fastest quote server once:

```bash
python -m mootdx bestip -vv
```

The tool saves the best IP, which makes later downloads more stable.

### Download Historical Quotes

```bash
# --datasource   mootdx / akshare / tushare
# --frequency    K-line frequency code (4 = daily)
# --exclude-gem  exclude ChiNext / STAR Market / BSE stocks
# --min-mktcap   minimum total market cap (CNY)
# --max-mktcap   maximum total market cap (CNY)
# --start        start date (YYYYMMDD or today)
# --end          end date
# --out          output directory
# --workers      number of worker threads
python fetch_kline.py \
  --datasource mootdx \
  --frequency 4 \
  --exclude-gem \
  --min-mktcap 5e9 \
  --max-mktcap +inf \
  --start 20200101 \
  --end today \
  --out ./data \
  --workers 10
```

The *first run* downloads the full history; later runs **update incrementally**.

### Run Stock Selection

```bash
# --data-dir  directory of CSV quotes
# --config    Selector configuration
# --date      trading day (default: latest)
python select_stock.py \
  --data-dir ./data \
  --config ./configs.json \
  --date 2025-07-02
```

Sample output:

```
============== 选股结果 [填坑战法] ===============
交易日: 2025-07-02
符合条件股票数: 2
600690, 000333
```

---

## Parameters

### `fetch_kline.py`

| Option | Default | Description |
| ------------------- | -------- | ------------------------------------ |
| `--datasource` | `tushare` | Data source: `tushare` / `akshare` / `mootdx` |
| `--frequency` | `4` | K-line frequency code (see the table below) |
| `--exclude-gem` | `True` | Exclude ChiNext / STAR Market / BSE stocks |
| `--min-mktcap` | `5e9` | Minimum total market cap (CNY) |
| `--max-mktcap` | `+inf` | Maximum total market cap (CNY) |
| `--start` | `20190101` | Start date, `YYYYMMDD` or `today` |
| `--end` | `today` | End date, `YYYYMMDD` or `today` |
| `--out` | `./data` | Output directory |
| `--workers` | `3` | Number of worker threads |

#### K-line Frequency Codes

| Code | Period | Mootdx keyword | Use |
| :-: | :--: | :--------: | ---- |
| 0 | 5 min | `5m` | High frequency |
| 1 | 15 min | `15m` | High frequency |
| 2 | 30 min | `30m` | High frequency |
| 3 | 60 min | `1h` | Swing |
| 4 | Daily | `day` | ★ Common |
| 5 | Weekly | `week` | Medium/long term |
| 6 | Monthly | `mon` | Medium/long term |
| 7 | 1 min | `1m` | Tick |
| 8 | 1 min | `1m` | Tick |
| 9 | Daily | `day` | Fallback |
| 10 | Quarterly | `3mon` | Long cycle |
| 11 | Yearly | `year` | Long cycle |

### `select_stock.py`

| Option | Default | Description |
| ------------ | ---------------- | ------------- |
| `--data-dir` | `./data` | Directory of CSV quotes |
| `--config` | `./configs.json` | Selector configuration file |
| `--date` | Latest trading day | Selection date |
| `--tickers` | `all` | Stock universe (comma-separated list) |

Run `python select_stock.py --help` for more advanced options and explanations.

### Built-in Strategy Parameters

All parameters below come from **`configs.json`** and can be tuned freely.

#### 1. BBIKDJSelector (少妇战法)

| Parameter | Preset | Description |
| ----------------- | ------ | --------------------------------------------------- |
| `j_threshold` | `1` | Today's **J** value is *below* this threshold (either this or the `j_q_threshold` condition must hold) |
| `bbi_min_window` | `20` | Shortest window for detecting a rising BBI (trading days) |
| `max_window` | `60` | Longest window used in the checks (trading days) |
| `price_range_pct` | `0.5` | Over the last *max\_window* bars, the closing-price range (`high/low − 1`) must not exceed this value |
| `bbi_q_threshold` | `0.1` | Quantile threshold for negative BBI first differences (pullback tolerance) |
| `j_q_threshold` | `0.10` | Today's **J** value is *no higher than* this quantile of J over the recent window |

#### 2. PeakKDJSelector (填坑战法)

| Parameter | Preset | Description |
| ---------------- | ------ | ----------------------------------------------------------- |
| `j_threshold` | `10` | Today's **J** value is *below* this threshold (either this or the `j_q_threshold` condition must hold) |
| `max_window` | `100` | Longest window used in the checks (trading days) |
| `fluc_threshold` | `0.03` | Maximum allowed relative difference between today's close and the close of the reference peak (坑口) |
| `gap_threshold` | `0.2` | How far the reference peak must sit above the lowest close in the interval (`oc_prev > min_close × (1+gap_threshold)`) |
| `j_q_threshold` | `0.10` | Today's **J** value is *no higher than* this quantile of J over the recent window |

#### 3. BBIShortLongSelector (补票战法)

| Parameter | Preset | Description |
| ----------------- | ----- | ----------------------- |
| `n_short` | `3` | Window for the short-period **RSV** (trading days) |
| `n_long` | `21` | Window for the long-period **RSV** (trading days) |
| `m` | `3` | Look-back window (last *m* days) over which the RSV conditions are checked |
| `bbi_min_window` | `2` | Shortest window for detecting a rising BBI (trading days) |
| `max_window` | `60` | Longest window used in the checks (trading days) |
| `bbi_q_threshold` | `0.2` | Quantile threshold for negative BBI first differences |

#### 4. BreakoutVolumeKDJSelector (TePu 战法)

| Parameter | Preset | Description |
| ------------------ | -------- | --------------------------------------------------- |
| `j_threshold` | `1` | Today's **J** value is *below* this threshold (either this or the `j_q_threshold` condition must hold) |
| `j_q_threshold` | `0.10` | Today's **J** value is *no higher than* this quantile of J over the recent window |
| `up_threshold` | `3.0` | A single-day gain of at least this percentage counts as a "breakout" |
| `volume_threshold` | `0.6667` | Volume on the breakout day must be **≥ 1/volume\_threshold** times the volume of every other day in the window (1.5× for the preset) |
| `offset` | `15` | Look-back window for finding the breakout day (trading days) |
| `max_window` | `60` | Longest window used in the checks (trading days) |
| `price_range_pct` | `0.5` | Over the last *max\_window* bars, the closing-price range (`high/low − 1`) must not exceed this value |

---

## Project Structure

```
.
├── appendix.json            # Extra stock pool
├── configs.json             # Selector configuration
├── fetch_kline.py           # Quote download script
├── select_stock.py          # Batch selection script
├── Selector.py              # Strategy implementations
├── data/                    # CSV output directory
├── fetch.log                # Download log
└── select_results.log       # Selection log
```

---

## Disclaimer

* This repository is for learning and technical research only and **does not constitute investment advice**. The stock market carries risk; invest with caution.
* Thanks to **@Zettaranc** for generously sharing on Bilibili: [https://b23.tv/JxIOaNE](https://b23.tv/JxIOaNE)

--------------------------------------------------------------------------------
/Selector.py:
--------------------------------------------------------------------------------
from typing import Dict, List, Optional, Any

from scipy.signal import find_peaks
import numpy as np
import pandas as pd


# --------------------------- Common indicators --------------------------- #

def compute_kdj(df: pd.DataFrame, n: int = 9) -> pd.DataFrame:
    """Append K, D and J columns (n-period RSV smoothed with 2/3-1/3 weights)."""
    if df.empty:
        return df.assign(K=np.nan, D=np.nan, J=np.nan)

    low_n = df["low"].rolling(window=n, min_periods=1).min()
    high_n = df["high"].rolling(window=n, min_periods=1).max()
    rsv = (df["close"] - low_n) / (high_n - low_n + 1e-9) * 100

    K = np.zeros_like(rsv, dtype=float)
    D = np.zeros_like(rsv, dtype=float)
    for i in range(len(df)):
        if i == 0:
            K[i] = D[i] = 50.0  # conventional seed value
        else:
            K[i] = 2 / 3 * K[i - 1] + 1 / 3 * rsv.iloc[i]
            D[i] = 2 / 3 * D[i - 1] + 1 / 3 * K[i]
    J = 3 * K - 2 * D
    return df.assign(K=K, D=D, J=J)


def compute_bbi(df: pd.DataFrame) -> pd.Series:
    """BBI: the mean of the 3-, 6-, 12- and 24-period moving averages of close."""
    ma3 = df["close"].rolling(3).mean()
    ma6 = df["close"].rolling(6).mean()
    ma12 = df["close"].rolling(12).mean()
    ma24 = df["close"].rolling(24).mean()
    return (ma3 + ma6 + ma12 + ma24) / 4


def compute_rsv(
    df: pd.DataFrame,
    n: int,
) -> pd.Series:
    """
    Formula: RSV(N) = 100 × (C - LLV(L,N)) ÷ (HHV(C,N) - LLV(L,N))
    - the high side uses the highest close (HHV of close)
    - the low side uses the lowest low
      (LLV of low)
    """
    low_n = df["low"].rolling(window=n, min_periods=1).min()
    high_close_n = df["close"].rolling(window=n, min_periods=1).max()
    rsv = (df["close"] - low_n) / (high_close_n - low_n + 1e-9) * 100.0
    return rsv


def compute_dif(df: pd.DataFrame, fast: int = 12, slow: int = 26) -> pd.Series:
    """Compute the MACD DIF line (fast EMA minus slow EMA)."""
    ema_fast = df["close"].ewm(span=fast, adjust=False).mean()
    ema_slow = df["close"].ewm(span=slow, adjust=False).mean()
    return ema_fast - ema_slow


def bbi_deriv_uptrend(
    bbi: pd.Series,
    *,
    min_window: int,
    max_window: int | None = None,
    q_threshold: float = 0.0,
) -> bool:
    """
    Decide whether BBI is "rising overall".

    Let T be the latest trading day. Over the window [T-w+1, T] (w is adaptive,
    with min_window ≤ w ≤ max_window), first normalize BBI:
    BBI_norm(t) = BBI(t) / BBI(T-w+1).

    Then take the first difference Δ(t) = BBI_norm(t) - BBI_norm(t-1).
    A window passes if the q_threshold quantile of Δ(t) is ≥ 0. Windows are
    searched from longest to shortest, and the function returns True as soon
    as one passes. With q_threshold=0 this reduces to "monotonically
    non-decreasing throughout" (the legacy behavior).

    Parameters
    ----------
    bbi : pd.Series
        BBI series (latest value last).
    min_window : int
        Minimum length of the detection window.
    max_window : int | None
        Maximum length of the detection window; None means no upper bound.
    q_threshold : float, default 0.0
        Fraction of first differences allowed to be negative (0 ≤ q_threshold ≤ 1).
    """
    if not 0.0 <= q_threshold <= 1.0:
        raise ValueError("q_threshold must lie within [0, 1]")

    bbi = bbi.dropna()
    if len(bbi) < min_window:
        return False

    longest = min(len(bbi), max_window or len(bbi))

    # Search downward from the longest window; any passing window is enough
    for w in range(longest, min_window - 1, -1):
        seg = bbi.iloc[-w:]                # window [T-w+1, T]
        norm = seg / seg.iloc[0]           # normalize
        diffs = np.diff(norm.values)       # first difference
        if np.quantile(diffs, q_threshold) >= 0:
            return True
    return False


def _find_peaks(
    df: pd.DataFrame,
    *,
    column: str = "high",
    distance: Optional[int] = None,
    prominence: Optional[float] = None,
    height: Optional[float] = None,
    width: Optional[float] = None,
    rel_height: float = 0.5,
    **kwargs: Any,
) -> pd.DataFrame:
    """Run scipy.signal.find_peaks on one column and return the peak rows."""

    if column not in df.columns:
        raise KeyError(f"'{column}' not found in DataFrame columns: {list(df.columns)}")

    y = df[column].to_numpy()

    indices, props = find_peaks(
        y,
        distance=distance,
        prominence=prominence,
        height=height,
        width=width,
        rel_height=rel_height,
        **kwargs,
    )

    peaks_df = df.iloc[indices].copy()
    peaks_df["is_peak"] = True

    # Flatten SciPy arrays into columns (only those with same length as indices)
    for key, arr in props.items():
        if isinstance(arr, (list, np.ndarray)) and len(arr) == len(indices):
            peaks_df[f"peak_{key}"] = arr

    return peaks_df


# --------------------------- Selector classes --------------------------- #
class BBIKDJSelector:
    """
    Adaptive *BBI (derivative)* + *KDJ* selector
        • BBI: tolerates a bbi_q_threshold fraction of pullbacks
        • KDJ: J < threshold, or J at or below the j_q_threshold quantile of recent J
        • MACD: DIF > 0
        • Closing-price range ≤ price_range_pct
    """

    def __init__(
        self,
        j_threshold: float = -5,
        bbi_min_window: int = 90,
        max_window: int = 90,
        price_range_pct: float = 100.0,
        bbi_q_threshold: float = 0.05,
        j_q_threshold: float = 0.10,
    ) -> None:
        self.j_threshold = j_threshold
        self.bbi_min_window = bbi_min_window
        self.max_window = max_window
        self.price_range_pct = price_range_pct
        self.bbi_q_threshold = bbi_q_threshold  # formerly q_threshold
        self.j_q_threshold = j_q_threshold      # newly added

    # ---------- Per-stock filter ---------- #
    def _passes_filters(self, hist: pd.DataFrame) -> bool:
        hist = hist.copy()
        hist["BBI"] = compute_bbi(hist)

        # 0. Closing-price range constraint (last max_window bars)
        win = hist.tail(self.max_window)
        high, low = win["close"].max(), win["close"].min()
        if low <= 0 or (high / low - 1) > self.price_range_pct:
            return False

        # 1. Rising BBI (some pullback allowed)
        if not bbi_deriv_uptrend(
            hist["BBI"],
            min_window=self.bbi_min_window,
            max_window=self.max_window,
            q_threshold=self.bbi_q_threshold,
        ):
            return False

        # 2. KDJ filter: either of two conditions
        kdj = compute_kdj(hist)
        j_today = float(kdj.iloc[-1]["J"])

        # Quantile of J over the last max_window bars
        j_window = kdj["J"].tail(self.max_window).dropna()
        if j_window.empty:
            return False
        j_quantile = float(j_window.quantile(self.j_q_threshold))

        if not (j_today < self.j_threshold or j_today <= j_quantile):
            return False

        # 3. MACD: DIF > 0
        hist["DIF"] = compute_dif(hist)
        return hist["DIF"].iloc[-1] > 0

    # ---------- Batch over many stocks ---------- #
    def select(
        self, date: pd.Timestamp, data: Dict[str, pd.DataFrame]
    ) -> List[str]:
        picks: List[str] = []
        for code, df in data.items():
            hist = df[df["date"] <= date]
            if hist.empty:
                continue
            # Keep an extra 20-bar buffer
            hist = hist.tail(self.max_window + 20)
            if self._passes_filters(hist):
                picks.append(code)
        return picks


class PeakKDJSelector:
    """
    Peaks + KDJ selector
    """

    def __init__(
        self,
        j_threshold: float = -5,
        max_window: int = 90,
        fluc_threshold: float = 0.03,
        gap_threshold: float = 0.02,
        j_q_threshold: float = 0.10,
    ) -> None:
        self.j_threshold = j_threshold
        self.max_window = max_window
        self.fluc_threshold = fluc_threshold  # upper bound on today ↔ peak_(t-n) relative difference
        self.gap_threshold = gap_threshold    # oc_prev must exceed the interval's lowest close by this ratio
        self.j_q_threshold = j_q_threshold

    # ---------- Per-stock filter ---------- #
    def _passes_filters(self, hist: pd.DataFrame) -> bool:
        if hist.empty:
            return False

        hist = hist.copy().sort_values("date")
        hist["oc_max"] = hist[["open", "close"]].max(axis=1)

        # 1. Extract peaks
        peaks_df = _find_peaks(
            hist,
            column="oc_max",
            distance=6,
            prominence=0.5,
        )

        # Need at least two peaks before the selection date
        date_today = hist.iloc[-1]["date"]
        peaks_df = peaks_df[peaks_df["date"] < date_today]
        if len(peaks_df) < 2:
            return False

        peak_t = peaks_df.iloc[-1]  # most recent peak
        peaks_list = peaks_df.reset_index(drop=True)
        oc_t = peak_t["oc_max"]
        total_peaks = len(peaks_list)

        # 2. Walk backwards to find peak_(t-n)
        target_peak = None
        for idx in range(total_peaks - 2, -1, -1):
            peak_prev = peaks_list.loc[idx]
            oc_prev = peak_prev["oc_max"]
            if oc_t <= oc_prev:  # require peak_t > peak_(t-n)
                continue

            # Only when there are at least 3 peaks: every peak strictly
            # between peak_(t-n) and peak_t must be lower than peak_(t-n)
            if total_peaks >= 3 and idx < total_peaks - 2:
                inter_oc = peaks_list.loc[idx + 1 : total_peaks - 2, "oc_max"]
                if not (inter_oc < oc_prev).all():
                    continue

            # oc_prev must exceed the interval's lowest close by gap_threshold
            date_prev = peak_prev["date"]
            mask = (hist["date"] > date_prev) & (hist["date"] < peak_t["date"])
            min_close = hist.loc[mask, "close"].min()
            if pd.isna(min_close):
                continue  # no data in the interval
            if oc_prev <= min_close * (1 + self.gap_threshold):
                continue

            target_peak = peak_prev

            break

        if target_peak is None:
            return False

        # 3. Relative difference between today's close and the target peak's close
        close_today = hist.iloc[-1]["close"]
        fluc_pct = abs(close_today - target_peak["close"]) / target_peak["close"]
        if fluc_pct > self.fluc_threshold:
            return False

        # 4.
        # KDJ filter
        kdj = compute_kdj(hist)
        j_today = float(kdj.iloc[-1]["J"])
        j_window = kdj["J"].tail(self.max_window).dropna()
        if j_window.empty:
            return False
        j_quantile = float(j_window.quantile(self.j_q_threshold))
        if not (j_today < self.j_threshold or j_today <= j_quantile):
            return False

        return True

    # ---------- Batch over many stocks ---------- #
    def select(
        self,
        date: pd.Timestamp,
        data: Dict[str, pd.DataFrame],
    ) -> List[str]:
        picks: List[str] = []
        for code, df in data.items():
            hist = df[df["date"] <= date]
            if hist.empty:
                continue
            hist = hist.tail(self.max_window + 20)  # extra buffer
            if self._passes_filters(hist):
                picks.append(code)
        return picks


class BBIShortLongSelector:
    """
    Rising BBI + short/long-period RSV conditions + DIF > 0 selector
    """
    def __init__(
        self,
        n_short: int = 3,
        n_long: int = 21,
        m: int = 3,
        bbi_min_window: int = 90,
        max_window: int = 150,
        bbi_q_threshold: float = 0.05,
    ) -> None:
        if m < 2:
            raise ValueError("m must be >= 2")
        self.n_short = n_short
        self.n_long = n_long
        self.m = m
        self.bbi_min_window = bbi_min_window
        self.max_window = max_window
        self.bbi_q_threshold = bbi_q_threshold  # newly added parameter

    # ---------- Per-stock filter ---------- #
    def _passes_filters(self, hist: pd.DataFrame) -> bool:
        hist = hist.copy()
        hist["BBI"] = compute_bbi(hist)

        # 1. Rising BBI (some pullback allowed)
        if not bbi_deriv_uptrend(
            hist["BBI"],
            min_window=self.bbi_min_window,
            max_window=self.max_window,
            q_threshold=self.bbi_q_threshold,
        ):
            return False

        # 2.
        # Compute short- and long-period RSV -----------------
        hist["RSV_short"] = compute_rsv(hist, self.n_short)
        hist["RSV_long"] = compute_rsv(hist, self.n_long)

        if len(hist) < self.m:
            return False  # not enough data

        win = hist.iloc[-self.m :]                    # last m days
        long_ok = (win["RSV_long"] >= 80).all()       # long RSV ≥ 80 on every day

        short_series = win["RSV_short"]
        short_start_end_ok = (
            short_series.iloc[0] >= 80 and short_series.iloc[-1] >= 80
        )
        short_has_below_20 = (short_series < 20).any()

        if not (long_ok and short_start_end_ok and short_has_below_20):
            return False

        # 3. MACD: DIF > 0 -------------------
        hist["DIF"] = compute_dif(hist)
        return hist["DIF"].iloc[-1] > 0

    # ---------- Batch over many stocks ---------- #
    def select(
        self,
        date: pd.Timestamp,
        data: Dict[str, pd.DataFrame],
    ) -> List[str]:
        picks: List[str] = []
        for code, df in data.items():
            hist = df[df["date"] <= date]
            if hist.empty:
                continue
            # Keep enough history: RSV window + BBI detection window + m
            need_len = (
                max(self.n_short, self.n_long)
                + self.bbi_min_window
                + self.m
            )
            hist = hist.tail(max(need_len, self.max_window))
            if self._passes_filters(hist):
                picks.append(code)
        return picks


class BreakoutVolumeKDJSelector:
    """
    Volume breakout + KDJ + DIF > 0 + closing-price range selector
    """

    def __init__(
        self,
        j_threshold: float = 0.0,
        up_threshold: float = 3.0,
        volume_threshold: float = 2.0 / 3,
        offset: int = 15,
        max_window: int = 120,
        price_range_pct: float = 10.0,
        j_q_threshold: float = 0.10,        # newly added
    ) -> None:
        self.j_threshold = j_threshold
        self.up_threshold = up_threshold
        self.volume_threshold = volume_threshold
        self.offset = offset
        self.max_window = max_window
        self.price_range_pct = price_range_pct
        self.j_q_threshold = j_q_threshold  # newly added

    # ---------- Per-stock filter ---------- #
    def _passes_filters(self, hist: pd.DataFrame) -> bool:
        if len(hist) < self.offset + 2:
            return False

        hist = hist.tail(self.max_window).copy()

        # ---- Closing-price range constraint ----
        high, low = hist["close"].max(), hist["close"].min()
        if low <= 0 or (high / low - 1) > self.price_range_pct:
            return False

        # ---- Technical indicators ----
        hist = compute_kdj(hist)
        hist["pct_chg"] = hist["close"].pct_change() * 100
        hist["DIF"] = compute_dif(hist)

        # 0) Constraints on the selection date: J < j_threshold or J within
        #    the low quantile of recent J; and DIF > 0
        j_today = float(hist["J"].iloc[-1])

        j_window = hist["J"].tail(self.max_window).dropna()
        if j_window.empty:
            return False
        j_quantile = float(j_window.quantile(self.j_q_threshold))

        # Reject if neither J condition holds
        if not (j_today < self.j_threshold or j_today <= j_quantile):
            return False
        if hist["DIF"].iloc[-1] <= 0:
            return False

        # ---- Volume breakout conditions ----
        n = len(hist)
        wnd_start = max(0, n - self.offset - 1)
        last_idx = n - 1

        for t_idx in range(wnd_start, last_idx):  # candidate breakout day T
            row = hist.iloc[t_idx]

            # 1) Single-day gain
            if row["pct_chg"] < self.up_threshold:
                continue

            # 2) Relative volume: every other day ≤ volume_threshold × vol_T
            vol_T = row["volume"]
            if vol_T <= 0:
                continue
            vols_except_T = hist["volume"].drop(index=hist.index[t_idx])
            if not (vols_except_T <= self.volume_threshold * vol_T).all():
                continue

            # 3) New closing high
            if row["close"] <= hist["close"].iloc[:t_idx].max():
                continue

            # 4) From T onward, J stays above today's J minus 10
            if not (hist["J"].iloc[t_idx:last_idx] > hist["J"].iloc[-1] - 10).all():
                continue

            return True  # all conditions met

        return False

    # ---------- Batch over many stocks ---------- #
    def select(
        self, date: pd.Timestamp, data: Dict[str, pd.DataFrame]
    ) -> List[str]:
        picks: List[str] = []
        for code, df in data.items():
            hist = df[df["date"] <= date]
            if hist.empty:
                continue
            if self._passes_filters(hist):
                picks.append(code)
        return picks

--------------------------------------------------------------------------------
/appendix.json:
--------------------------------------------------------------------------------
{"data": [
    "002527"
]}

--------------------------------------------------------------------------------
/configs.json:
--------------------------------------------------------------------------------
{
  "selectors": [
    {
      "class": "BBIKDJSelector",
      "alias": "少妇战法",
      "activate": true,
      "params": {
        "j_threshold": 1,
        "bbi_min_window": 20,
        "max_window": 60,
        "price_range_pct": 0.5,
        "bbi_q_threshold": 0.1,
        "j_q_threshold": 0.10
      }
    },
    {
      "class": "BBIShortLongSelector",
      "alias": "补票战法",
      "activate": true,
      "params": {
        "n_short": 3,
        "n_long": 21,
        "m": 3,
        "bbi_min_window": 2,
        "max_window": 60,
        "bbi_q_threshold": 0.2
      }
    },
    {
      "class": "BreakoutVolumeKDJSelector",
      "alias": "TePu战法",
      "activate": true,
      "params": {
        "j_threshold": 1,
        "j_q_threshold": 0.10,
        "up_threshold": 3.0,
        "volume_threshold": 0.6667,
        "offset": 15,
        "max_window": 60,
        "price_range_pct": 0.5
      }
    },
    {
      "class": "PeakKDJSelector",
      "alias": "填坑战法",
      "activate": true,
      "params": {
        "j_threshold": 10,
        "max_window": 100,
        "fluc_threshold": 0.03,
        "j_q_threshold": 0.10,
        "gap_threshold": 0.2
      }
    }
  ]
}

--------------------------------------------------------------------------------
/fetch_kline.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import argparse
import datetime as dt
import json
import logging
import random
import sys
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Optional

import akshare as ak
import pandas as pd
import tushare as ts
from mootdx.quotes import Quotes
from tqdm import tqdm

warnings.filterwarnings("ignore")

# --------------------------- Global logging config --------------------------- #
LOG_FILE = Path("fetch.log")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(LOG_FILE, mode="a", encoding="utf-8"),
    ],
)
logger = logging.getLogger("fetch_mktcap")

# Silence noisy INFO logs from third-party libraries
for noisy in ("httpx", "urllib3", "_client", "akshare"):
    logging.getLogger(noisy).setLevel(logging.WARNING)

# --------------------------- Market-cap snapshot --------------------------- #

def _get_mktcap_ak() -> pd.DataFrame:
    """Live snapshot; returns columns: code, mktcap (unit: CNY)."""
    for attempt in range(1, 4):
        try:
            df = ak.stock_zh_a_spot_em()
            break
        except Exception as e:
            logger.warning("AKShare market-cap snapshot failed (%d/3): %s", attempt, e)
            time.sleep(random.uniform(1, 3) * attempt)
    else:
        raise RuntimeError("AKShare failed to return the market-cap snapshot three times in a row!")

    df = df[["代码", "总市值"]].rename(columns={"代码": "code", "总市值": "mktcap"})
    df["mktcap"] = pd.to_numeric(df["mktcap"], errors="coerce")
    return df

# --------------------------- Stock-pool filtering --------------------------- #

def get_constituents(
    min_cap: float,
    max_cap: float,
    small_player: bool,
    mktcap_df: Optional[pd.DataFrame] = None,
) -> List[str]:
    df = mktcap_df if mktcap_df is not None else _get_mktcap_ak()

    cond = (df["mktcap"] >= min_cap) & (df["mktcap"] <= max_cap)
    if small_player:
        # Drop ChiNext (300/301), STAR Market (688) and BSE (8/4) codes
        cond &= ~df["code"].str.startswith(("300", "301", "688", "8", "4"))

    codes = df.loc[cond, "code"].str.zfill(6).tolist()

    # Extra stock pool from appendix.json
    try:
        with open("appendix.json", "r", encoding="utf-8") as f:
            appendix_codes = json.load(f)["data"]
    except FileNotFoundError:
        appendix_codes = []
    codes = list(dict.fromkeys(appendix_codes + codes))  # de-duplicate, keep order

    logger.info("Filter kept %d stocks", len(codes))
    return codes

# --------------------------- Historical K-line download --------------------------- #
COLUMN_MAP_HIST_AK = {
    "日期": "date",
    "开盘": "open",
    "收盘": "close",
    "最高": "high",
    "最低": "low",
    "成交量": "volume",
    "成交额": "amount",
    "换手率": "turnover",
}

_FREQ_MAP = {
    0: "5m",
    1: "15m",
    2: "30m",
    3: "1h",
    4: "day",
    5: "week",
    6: "mon",
    7: "1m",
    8: "1m",
    9: "day",
    10: "3mon",
    11: "year",
}

# ---------- Tushare helpers ---------- #

def _to_ts_code(code: str) -> str:
    return f"{code.zfill(6)}.SH" if code.startswith(("60", "68", "9")) else f"{code.zfill(6)}.SZ"


def _get_kline_tushare(code: str, start: str, end: str, adjust: str) -> pd.DataFrame:
    ts_code = _to_ts_code(code)
    adj_flag = None if adjust == "" else adjust
    for attempt in range(1, 4):
        try:
            df = ts.pro_bar(
                ts_code=ts_code,
                adj=adj_flag,
                start_date=start,
                end_date=end,
                freq="D",
            )
            break
        except Exception as e:
            logger.warning("Tushare fetch of %s failed (%d/3): %s", code, attempt, e)
            time.sleep(random.uniform(1, 2) * attempt)
    else:
        return pd.DataFrame()

    if df is None or df.empty:
        return pd.DataFrame()

    df = df.rename(columns={"trade_date": "date", "vol": "volume"})[
        ["date", "open", "close", "high", "low", "volume"]
    ].copy()
    df["date"] = pd.to_datetime(df["date"])
    df[[c for c in df.columns if c != "date"]] = df[[c for c in df.columns if c != "date"]].apply(
        pd.to_numeric, errors="coerce"
    )
    return df.sort_values("date").reset_index(drop=True)

# ---------- AKShare helpers ---------- #

def _get_kline_akshare(code: str, start: str, end: str, adjust: str) -> pd.DataFrame:
    for attempt in range(1, 4):
        try:
            df = ak.stock_zh_a_hist(
                symbol=code,
                period="daily",
                start_date=start,
                end_date=end,
                adjust=adjust,
            )
            break
        except Exception as e:
            logger.warning("AKShare fetch of %s failed (%d/3): %s", code, attempt, e)
            time.sleep(random.uniform(1, 2) * attempt)
    else:
        return pd.DataFrame()

    if df is None or df.empty:
        return pd.DataFrame()

    df = (
        df[list(COLUMN_MAP_HIST_AK)]
        .rename(columns=COLUMN_MAP_HIST_AK)
        .assign(date=lambda x: pd.to_datetime(x["date"]))
    )
    df[[c for c in df.columns if c != "date"]] = df[[c for c in df.columns if c != "date"]].apply(
        pd.to_numeric, errors="coerce"
    )
    df = df[["date", "open", "close", "high", "low", "volume"]]
    return df.sort_values("date").reset_index(drop=True)

# ---------- Mootdx helpers ---------- #

def _get_kline_mootdx(code: str, start: str, end: str, adjust: str, freq_code: int) -> pd.DataFrame:
    symbol = code.zfill(6)
    freq = _FREQ_MAP.get(freq_code, "day")
    client = Quotes.factory(market="std")
    try:
        df = client.bars(symbol=symbol, frequency=freq, adjust=adjust or None)
    except Exception as e:
        logger.warning("Mootdx fetch of %s failed: %s", code, e)
        return pd.DataFrame()
    if df is None or df.empty:
        return pd.DataFrame()

    df = df.rename(
        columns={"datetime": "date", "open": "open", "high": "high", "low": "low", "close": "close", "vol": "volume"}
    )
    df["date"] = pd.to_datetime(df["date"]).dt.normalize()
    start_ts = pd.to_datetime(start, format="%Y%m%d")
    end_ts = pd.to_datetime(end, format="%Y%m%d")
    df = df[(df["date"].dt.date >= start_ts.date()) & (df["date"].dt.date <= end_ts.date())].copy()
    df = df.sort_values("date").reset_index(drop=True)
    return df[["date", "open", "close", "high", "low", "volume"]]

# ---------- Unified interface ---------- #

def get_kline(
    code: str,
    start: str,
    end: str,
    adjust: str,
    datasource: str,
    freq_code: int = 4,
) -> pd.DataFrame:
    if datasource == "tushare":
        return _get_kline_tushare(code, start, end, adjust)
    elif datasource == "akshare":
        return _get_kline_akshare(code, start, end, adjust)
    elif datasource == "mootdx":
        return _get_kline_mootdx(code, start, end, adjust, freq_code)
    else:
        raise ValueError("datasource must be 'tushare', 'akshare' or 'mootdx'")

# ---------- Data validation ---------- #

def validate(df: pd.DataFrame) -> pd.DataFrame:
    df = df.drop_duplicates(subset="date").sort_values("date").reset_index(drop=True)
    if df["date"].isna().any():
        raise ValueError("Missing dates found!")
    if (df["date"] > pd.Timestamp.today()).any():
        raise ValueError("Data contains future dates; the fetch may be wrong!")
    return df

def drop_dup_columns(df: pd.DataFrame) -> pd.DataFrame:
    return df.loc[:, ~df.columns.duplicated()]

# ---------- Fetch a single stock ---------- #
def fetch_one(
    code: str,
    start: str,
    end: str,
    out_dir: Path,
    incremental: bool,
    datasource: str,
    freq_code: int,
):
    csv_path = out_dir / f"{code}.csv"

    # Incremental update: if local data exists, resume from its last date
    if incremental and csv_path.exists():
        try:
            existing = pd.read_csv(csv_path, parse_dates=["date"])
            last_date = existing["date"].max()
            if last_date.date() > pd.to_datetime(end, format="%Y%m%d").date():
                logger.debug("%s is already up to date; nothing to do", code)
                return
            start = last_date.strftime("%Y%m%d")
        except Exception:
            logger.exception("Failed to read %s; downloading again", csv_path)

    for attempt in range(1, 4):
        try:
            new_df = get_kline(code, start, end, "qfq", datasource, freq_code)
            if new_df.empty:
                logger.debug("%s has no new data", code)
                break
            new_df = validate(new_df)
            if csv_path.exists() and incremental:
                old_df = pd.read_csv(
                    csv_path,
                    parse_dates=["date"],
                    index_col=False
                )
                old_df = drop_dup_columns(old_df)
                new_df = drop_dup_columns(new_df)
                new_df = (
                    pd.concat([old_df, new_df], ignore_index=True)
                    .drop_duplicates(subset="date")
                    .sort_values("date")
                )
            new_df.to_csv(csv_path, index=False)
            break
        except Exception:
            logger.exception("%s: fetch attempt %d failed", code, attempt)
            time.sleep(random.uniform(1, 3) * attempt)  # linear backoff with jitter
    else:
        logger.error("%s: all three attempts failed; skipped!", code)


# ---------- Entry point ---------- #

def main():
    parser = argparse.ArgumentParser(description="Filter A-share stocks by market cap and download historical K-lines")
    parser.add_argument("--datasource", choices=["tushare", "akshare", "mootdx"], default="tushare", help="Historical K-line data source")
    parser.add_argument("--frequency", type=int, choices=list(_FREQ_MAP.keys()), default=4, help="K-line frequency code; see the README")
    # BooleanOptionalAction makes --exclude-gem a real flag (and adds --no-exclude-gem);
    # a plain default=True would require a value and treat any string as truthy.
    parser.add_argument("--exclude-gem", action=argparse.BooleanOptionalAction, default=True, help="Exclude ChiNext / STAR Market / BSE stocks")
    parser.add_argument("--min-mktcap", type=float, default=5e9, help="Minimum total market cap (inclusive), in CNY")
    parser.add_argument("--max-mktcap", type=float, default=float("+inf"), help="Maximum total market cap (inclusive), in CNY; unlimited by default")
    parser.add_argument("--start", default="20190101", help="Start date, YYYYMMDD or 'today'")
    parser.add_argument("--end", default="today", help="End date, YYYYMMDD or 'today'")
    parser.add_argument("--out", default="./data", help="Output directory")
    parser.add_argument("--workers", type=int, default=3, help="Number of worker threads")
    args = parser.parse_args()

    # ---------- Token handling ---------- #
    if args.datasource == "tushare":
        ts_token = " "  # put your Tushare token here
        ts.set_token(ts_token)
        global pro
        pro = ts.pro_api()

    # ---------- Date parsing ---------- #
    start = dt.date.today().strftime("%Y%m%d") if args.start.lower() == "today" else args.start
    end = dt.date.today().strftime("%Y%m%d") if args.end.lower() == "today" else args.end

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---------- Market-cap snapshot & stock pool ---------- #
    mktcap_df = _get_mktcap_ak()

    codes_from_filter = get_constituents(
        args.min_mktcap,
        args.max_mktcap,
        args.exclude_gem,
        mktcap_df=mktcap_df,
    )
    # Add codes that already have a local CSV so old data is updated too
    local_codes = [p.stem for p in out_dir.glob("*.csv")]
    codes = sorted(set(codes_from_filter) | set(local_codes))

    if not codes:
        logger.error("Filter returned no stocks; please adjust the parameters!")
        sys.exit(1)

    logger.info(
        "Fetching %d stocks | source: %s | frequency: %s | dates: %s → %s",
        len(codes),
        args.datasource,
        _FREQ_MAP[args.frequency],
        start,
        end,
    )

    # ---------- Multithreaded fetch ---------- #
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = [
            executor.submit(
                fetch_one,
                code,
                start,
                end,
                out_dir,
                True,  # incremental
                args.datasource,
                args.frequency,
            )
            for code in codes
        ]
        # Drain the futures only to drive the progress bar; fetch_one logs its own errors.
        for _ in tqdm(as_completed(futures), total=len(futures), desc="Downloading"):
            pass

    logger.info("All tasks finished; data saved to %s", out_dir.resolve())


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
akshare==1.17.7
mootdx==0.11.7
pandas==2.3.0
tqdm==4.66.4
tushare==1.4.21
scipy==1.14.1
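
Note on `fetch_kline.py`: both `_get_kline_akshare` and `fetch_one` rely on Python's `for ... else` to detect that every attempt failed, and they sleep for a random interval that grows linearly with the attempt number. The sketch below isolates that pattern. The helper name `retry_with_backoff` and its `task`, `base_delay`, and `sleep` parameters are illustrative assumptions and do not exist in the repository.

```python
import random
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    task: Callable[[], T],
    attempts: int = 3,
    base_delay: Tuple[float, float] = (1.0, 3.0),
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Run `task` up to `attempts` times; return None if every attempt fails.

    Hypothetical helper, not part of the repository.
    """
    for attempt in range(1, attempts + 1):
        try:
            result = task()
            break  # success: the else branch below is skipped
        except Exception as exc:
            print(f"attempt {attempt}/{attempts} failed: {exc}")
            # Linear backoff with jitter: the wait scales with the attempt number.
            sleep(random.uniform(*base_delay) * attempt)
    else:
        # Reached only when the loop ended without `break`, i.e. all attempts failed.
        return None
    return result
```

Passing `sleep` as a parameter is a design choice for this sketch only: it lets a caller swap in a no-op and exercise the retry logic without real delays.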
--------------------------------------------------------------------------------
/select_stock.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import argparse
import importlib
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List

import pandas as pd

# ---------- Logging ----------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Also write the log to a file
        logging.FileHandler("select_results.log", encoding="utf-8"),
    ],
)
logger = logging.getLogger("select")


# ---------- Helpers ----------

def load_data(data_dir: Path, codes: Iterable[str]) -> Dict[str, pd.DataFrame]:
    frames: Dict[str, pd.DataFrame] = {}
    for code in codes:
        fp = data_dir / f"{code}.csv"
        if not fp.exists():
            logger.warning("%s does not exist; skipping", fp.name)
            continue
        df = pd.read_csv(fp, parse_dates=["date"]).sort_values("date")
        frames[code] = df
    return frames


def load_config(cfg_path: Path) -> List[Dict[str, Any]]:
    if not cfg_path.exists():
        logger.error("Config file %s does not exist", cfg_path)
        sys.exit(1)
    with cfg_path.open(encoding="utf-8") as f:
        cfg_raw = json.load(f)

    # Accept three shapes: a single object, an array of objects, or an object with a "selectors" key
    if isinstance(cfg_raw, list):
        cfgs = cfg_raw
    elif isinstance(cfg_raw, dict) and "selectors" in cfg_raw:
        cfgs = cfg_raw["selectors"]
    else:
        cfgs = [cfg_raw]

    if not cfgs:
        logger.error("configs.json defines no Selector")
        sys.exit(1)

    return cfgs


def instantiate_selector(cfg: Dict[str, Any]):
    """Dynamically load a Selector class and instantiate it."""
    cls_name: str = cfg.get("class")
    if not cls_name:
        raise ValueError("Missing 'class' field")

    try:
        module = importlib.import_module("Selector")
        cls = getattr(module, cls_name)
    except (ModuleNotFoundError, AttributeError) as e:
        raise ImportError(f"Cannot load Selector.{cls_name}: {e}") from e

    params = cfg.get("params", {})
    return cfg.get("alias", cls_name), cls(**params)


# ---------- Main ----------

def main():
    p = argparse.ArgumentParser(description="Run selectors defined in configs.json")
    p.add_argument("--data-dir", default="./data", help="Directory of CSV market data")
    p.add_argument("--config", default="./configs.json", help="Selector config file")
    p.add_argument("--date", help="Trade date YYYY-MM-DD; defaults to the latest date in the data")
    p.add_argument("--tickers", default="all", help="'all' or a comma-separated list of stock codes")
    args = p.parse_args()

    # --- Load market data ---
    data_dir = Path(args.data_dir)
    if not data_dir.exists():
        logger.error("Data directory %s does not exist", data_dir)
        sys.exit(1)

    codes = (
        [f.stem for f in data_dir.glob("*.csv")]
        if args.tickers.lower() == "all"
        else [c.strip() for c in args.tickers.split(",") if c.strip()]
    )
    if not codes:
        logger.error("Stock pool is empty!")
        sys.exit(1)

    data = load_data(data_dir, codes)
    if not data:
        logger.error("Failed to load any market data")
        sys.exit(1)

    trade_date = (
        pd.to_datetime(args.date)
        if args.date
        else max(df["date"].max() for df in data.values())
    )
    if not args.date:
        logger.info("--date not given; using the latest date %s", trade_date.date())

    # --- Load Selector configs ---
    selector_cfgs = load_config(Path(args.config))

    # --- Run each Selector ---
    for cfg in selector_cfgs:
        if cfg.get("activate", True) is False:
            continue
        try:
            alias, selector = instantiate_selector(cfg)
        except Exception as e:
            logger.error("Skipping config %s: %s", cfg, e)
            continue

        picks = selector.select(trade_date, data)

        # Write the results to the log file and the console
        logger.info("")
        logger.info("============== Selection results [%s] ==============", alias)
        logger.info("Trade date: %s", trade_date.date())
        logger.info("Matching stocks: %d", len(picks))
        logger.info("%s", ", ".join(picks) if picks else "No matching stocks")


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
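
Note on `select_stock.py`: the script accepts three config shapes in `load_config`, then builds each selector with `cls(**params)` and calls `selector.select(trade_date, data)`. The sketch below walks through that flow using only the standard library. `DemoSelector`, `normalize_configs`, and `fake_module` are hypothetical names introduced for illustration, and the per-stock data is simplified to plain lists of closing prices, whereas the real script passes pandas DataFrames and imports classes from `Selector.py`.

```python
import json
from types import SimpleNamespace
from typing import Any, Dict, List


def normalize_configs(cfg_raw: Any) -> List[Dict[str, Any]]:
    """Mirror the three shapes that load_config accepts."""
    if isinstance(cfg_raw, list):
        return cfg_raw                    # array of selector objects
    if isinstance(cfg_raw, dict) and "selectors" in cfg_raw:
        return cfg_raw["selectors"]       # object with a "selectors" key
    return [cfg_raw]                      # single selector object


class DemoSelector:
    """Hypothetical selector: pick codes whose last close exceeds min_close."""

    def __init__(self, min_close: float = 0.0):
        self.min_close = min_close

    def select(self, trade_date, data):
        # Simplification: each value is a list of closes, not a DataFrame.
        return [code for code, closes in data.items() if closes[-1] > self.min_close]


# Stand-in for importlib.import_module("Selector") in the real script.
fake_module = SimpleNamespace(DemoSelector=DemoSelector)

raw = json.loads(
    '{"selectors": [{"class": "DemoSelector", "alias": "demo",'
    ' "activate": true, "params": {"min_close": 10}}]}'
)
cfg = normalize_configs(raw)[0]
selector = getattr(fake_module, cfg["class"])(**cfg.get("params", {}))
picks = selector.select("2025-07-03", {"000001": [9.5, 10.2], "600000": [8.0, 7.9]})
print(cfg.get("alias", cfg["class"]), picks)
```

Because the class is looked up by name and constructed from `params`, a new strategy only needs a class with a compatible `__init__` and a `select(trade_date, data)` method, plus a matching entry in `configs.json`.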