├── README.pdf
├── data
│   ├── v_svm.png
│   ├── k_svm(1).png
│   ├── k_svm(3).png
│   ├── HS300_05_18.xlsx
│   └── attr_heatmap.png
├── .gitignore
├── README.md
├── k_svm.py
└── v_svm.py


/README.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/README.pdf
--------------------------------------------------------------------------------
/data/v_svm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/data/v_svm.png
--------------------------------------------------------------------------------
/data/k_svm(1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/data/k_svm(1).png
--------------------------------------------------------------------------------
/data/k_svm(3).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/data/k_svm(3).png
--------------------------------------------------------------------------------
/data/HS300_05_18.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/data/HS300_05_18.xlsx
--------------------------------------------------------------------------------
/data/attr_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinhuli/SVM-and-HS300/HEAD/data/attr_heatmap.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## SVM & HS300

### k_svm.py

1. Using daily market data for the CSI 300 (HS300) index (open, high, low, close, and trading volume), build five features that describe the shape of the candlestick (K-line):

    $high\_low = \frac{high}{low} - 1$

    $high\_close = \frac{high}{close} - 1$

    $close\_low = \frac{close}{low} - 1$

    $close\_open = \frac{close}{open} - 1$

    $vol\_pct = \frac{vol_i}{vol_{i-1}} - 1$

    Using the feature data of the previous 3 days, a support vector machine classifier with a Gaussian (RBF) kernel predicts the market direction over the following 3 days.

2. Ten years of data form the training set, and the parameters are tuned with grid search and cross-validation, with $K = 10$ for $K$-Fold. The following year is the test set. The window then rolls forward one year at a time, and a new classifier is built for each year.

3. Long/short trades are placed according to the prediction. The return is the close on the third day of the forecast period minus the open on the first day, less a slippage allowance.

    The one-way slippage is 0.0002.

    If the prediction is correct:

    $profit = |Close_3 - Open_1| - slippage$

    If the prediction is wrong:

    $loss = -|Close_3 - Open_1| - slippage$

4. **Returns**

    ![k_svm(3)](https://github.com/Jensenberg/SVM-and-HS300/blob/master/data/k_svm(3).png)

    From January 2015 to October 2018, with slippage taken into account, the total return is 1.66, with a peak of about 1.75 and a maximum drawdown of 26%. Over the same period the HS300 total return is 1.11, with a maximum drawdown of 47%.

    ![k_svm(1)](https://github.com/Jensenberg/SVM-and-HS300/blob/master/data/k_svm(1).png)

    If only the next day's direction is predicted, the total return is 1.80, with a peak of 2.31 and a maximum drawdown of 34%.

5. The strategy accumulates most of its return during 2015 and performs weakly afterwards, which suggests that the predictions work well in trending phases and poorly otherwise.

6. **Prediction accuracy**

    Test accuracy falls below 50% in two of the four test periods (2015.11-2016.11 and 2017.11-2018.10). A possible reason is that the features are highly correlated, so useful information is refreshed too slowly.

    | period | C | gamma | train score | average train score (k=10) | test score |
    | :-------------: | :--: | :---: | :---------: | :------------------------: | :--------: |
    | 2014.11-2015.11 | 8 | 0.02 | 0.6592 | 0.5650 | 0.5333 |
    | 2015.11-2016.11 | 6 | 0.02 | 0.6329 | 0.5529 | 0.4750 |
    | 2016.11-2017.11 | 5 | 0.06 | 0.7063 | 0.5517 | 0.5167 |
    | 2017.11-2018.10 | 8 | 0.02 | 0.6383 | 0.5579 | 0.4566 |

7. **Reference**

    [Uqer (优矿), SVM-based market index prediction](https://uqer.io/v3/community/share/56e6629e228e5b6ef3157588)

### v_svm.py

1. Using daily market data for the CSI 300 (HS300) index (open, high, low, close, and trading volume), build seven features related to volatility, momentum and trend:

    $high\_low = \frac{high - low}{open}$

    $close\_open = \frac{close}{open} - 1$

    $vol\_pct = \frac{vol_i}{vol_{i-1}} - 1$

    $pct\_m = \frac{close_i}{open_{i-m}} - 1, \quad m = 5$

    $high\_v = \frac{high_i}{\max\{high_t,\ t = i, i-1, \dots, i-v\}} - 1, \quad v = 20$

    $low\_v = \frac{low_i}{\min\{low_t,\ t = i, i-1, \dots, i-v\}} - 1, \quad v = 20$

    $sigma = \sqrt{\sum_{i=1}^{21} \frac{(r_i - \bar r)^2}{20}} \times \frac{240}{20}$

    These represent, respectively, the day's amplitude (the maximum possible gain), the day's realised return, the change in trading volume, the cumulative return over m days, the gap between the day's high and the 20-day high, the gap between the day's low and the 20-day low, and the 20-day volatility.

2. Notes on the features

    (1) The amplitude, the daily return and the change in trading volume reflect the market's movement on the day.

    (2) The m-day cumulative return reflects the current market position and the weekly picture.

    (3) The last three features reflect the market position and volatility at a medium-term, monthly frequency.

    These seven features are fed to a support vector machine with a Gaussian (RBF) kernel to predict the next day's market direction, and long/short trades are placed accordingly.

3. The correlations between the features are fairly low

    The strongest correlation is between $sigma$ and $high\_low$, with a coefficient of 0.65.

    Next is $sigma$ and $high\_v$, with a coefficient of -0.54.

    All remaining correlation coefficients are below 0.5 in absolute value.

    ![attr_heatmap](https://github.com/Jensenberg/SVM-and-HS300/blob/master/data/attr_heatmap.png)

4. **Returns**

    The total return is 2.58, the maximum drawdown is 18.67% and the Calmar ratio is 1.43, far better than the index.

    ![v_svm](https://github.com/Jensenberg/SVM-and-HS300/blob/master/data/v_svm.png)

5. **Prediction accuracy**

    The test accuracy of these features is fairly high: it exceeds 60% in every test period except 2016.11-2017.11.

    | period | C | gamma | train score | average train score (k=10) | test score |
    | :-------------: | :--: | :---: | :---------: | :------------------------: | :--------: |
    | 2014.11-2015.11 | 1 | 0.006 | 0.6483 | 0.6413 | 0.6392 |
    | 2015.11-2016.11 | 2 | 0.006 | 0.6521 | 0.6404 | 0.6792 |
    | 2016.11-2017.11 | 1 | 0.012 | 0.6508 | 0.6454 | 0.5083 |
    | 2017.11-2018.10 | 3 | 0.008 | 0.6396 | 0.6267 | 0.6162 |

--------------------------------------------------------------------------------
/k_svm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 10 23:24:02 2018

@author: 54326
"""

import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV as GSCV

class SVM():
    '''
    The class attribute `data` must be set first; it holds the daily
    open, high, low, close and volume data
    '''

    train_years=10 # size of the training sample: ten years of data
    year_days=240 # 240 trading days per year

    def __init__(self, backward, forward):
        self.bk = backward # number of past days whose features are used
        self.fw = forward # number of future days whose return is predicted
        self.slippage = 0.0002 # slippage: 2 basis points in total, 1 each for buying and selling

    def data_pre(self):
        '''
        Prepare the data: build the features and attach the class labels
        '''
        data = self.data
        # return from the open of the first forecast day to the close of the last one
        data['retn'] = data['close'].shift(-self.fw) / data['open'].shift(-1) - 1
        attr = pd.DataFrame()
        # build five features describing the candlestick shape and the change in volume
        attr['high_low'] = data['high'] / data['low'] - 1
        attr['high_close'] = data['high'] / data['close'] - 1
        attr['close_low'] = data['close'] / data['low'] - 1
        attr['close_open'] = data['close'] / data['open'] - 1
        attr['vol_pct'] = data['volume'].pct_change()
        attr.dropna(inplace=True)

        # standardise the features
        for field in attr.columns:
            attr[field] = (attr[field] - attr[field].mean()) / attr[field].std()

        attrs = {}
        dates = attr.index
        bk = self.bk
        # use the features of the previous bk days to predict the direction over the next fw days
        for i in range(bk, len(attr) - bk + 1):
            attrs[dates[i-1]] = attr.iloc[i-bk : i, :].stack().values
        attrs = pd.DataFrame(attrs).T
        attrs['change_fw'] = data['close'].shift(-self.fw) - data['close']
        # a rise is labelled 1, a fall is labelled -1
        attrs['label'] = np.where(attrs['change_fw'] > 0, 1, -1)
        attrs.drop('change_fw', axis=1, inplace=True)
        return attrs

    def data_split(self, begin_year):
        '''
        Split the data: ten years for training, the following year for testing
        Args:
            begin_year:
                index of the starting year; 0 is the first year, and so on
        Returns:
            train_data:
                training set
            test_data:
                test set
        '''
        begin = begin_year * self.year_days # start of the training set
        end = begin + self.train_years * self.year_days # end of the training set
        attrs = self.data_pre()
        train_data = attrs.iloc[begin : end, :] # training set
        test_data = attrs.iloc[end : end + self.year_days, :] # the following year is the test set
        return train_data, test_data

    def gscv_para(self, C_list, gamma_list, x_train, y_train):
        '''
        Tune the parameters with grid search and cross-validation, allowing
        for class imbalance; k-fold uses k = 10
        Args:
            C_list:
                candidate values for C
            gamma_list:
                candidate values for gamma
            x_train:
                features of the training set
            y_train:
                labels of the training set
        Returns:
            the best C, the best gamma and the corresponding mean score
            the criterion is the SVC score: the higher the score, the better the parameters
        '''
        clf = SVC(class_weight='balanced', cache_size=4000)
        gscv = GSCV(clf, param_grid={'C': C_list, 'gamma': gamma_list},
                    n_jobs=-1, cv=10, pre_dispatch=4)
        gscv.fit(x_train, y_train)
        # read the parameters by name rather than relying on dict ordering
        best = gscv.best_params_
        return (best['C'], best['gamma']), gscv.best_score_

    def svm_fit(self, x_train, y_train):
        '''
        Tune the parameters with grid search
        '''
        C_list = list(range(1, 10))
        gamma_list = np.linspace(0.01, 0.1, 10)
        (C, gamma), score = self.gscv_para(C_list, gamma_list, x_train, y_train)
        # a parameter may land on the edge of the grid, in which case the candidate range should be reset
#        if C == 1 or C == 10:
#            if C == 1:
#                print('C touched the lower limit')
#                C_list = np.linspace(0.1, 0.5, 5)
#            else:
#                print('C touched the upper limit')
#                C_list = list(range(10, 15))
#            C, gamma = self.gscv_para(C_list, gamma_list, x_train, y_train)
#        else:
#            pass
#        if gamma == 0.01 or gamma == 0.1:
#            if gamma == 0.01:
#                print('gamma touched the lower limit')
#                gamma_list = np.linspace(0.001, 0.01, 10)
#            else:
#                print('gamma touched the upper limit')
#                gamma_list = np.linspace(0.1, 0.5, 9)
#            C, gamma = self.gscv_para(C_list, gamma_list, x_train, y_train)
#        else:
#            pass
        return C, gamma, score

    def predict(self, begin_year):
        '''
        Use the fitted classifier to predict the out-of-sample data
        '''
        train_data, test_data = self.data_split(begin_year)
        x_train, y_train = train_data.iloc[:, :-1], train_data.iloc[:, -1]
        x_test, y_test = test_data.iloc[:, :-1], test_data.iloc[:, -1]
        C, gamma, score = self.svm_fit(x_train, y_train)
        # to reuse parameters from an earlier run, skip the search above and use:
#        C, gamma, score, _ = paras[begin_year]
        clf = SVC(C=C, gamma=gamma, class_weight='balanced', cache_size=4000)
        clf.fit(x_train, y_train)
        test_score = clf.score(x_test, y_test) # accuracy on the test set
        para = C, gamma, score, test_score
        print('%d, %.4f, %.4f, %.4f' % para)
        # predictions on the test set
        y_predict = pd.Series(clf.predict(x_test), index=y_test.index, name='predict')
        return para, pd.concat([y_test, y_predict], axis=1)

    def cum_retn(self, years):
        data = self.data # use the class attribute instead of a module-level global
        retns = {}
        paras = []
        for i in range(years):
            para, label = self.predict(i)
            paras.append(para)
            for j in range(len(label)):
                date = label.index[j]
                if label.loc[date, 'label'] == label.loc[date, 'predict']:
                    # correct prediction: earn the long/short return
                    retns[date] = abs(data.loc[date, 'retn']) - self.slippage
                else:
                    # wrong prediction: take the loss
                    retns[date] = - abs(data.loc[date, 'retn']) - self.slippage
        return paras, 1 + pd.Series(retns).sort_index().cumsum()

if __name__ == '__main__':

    import matplotlib.pyplot as plt

    # path relative to the repository root, where the data file is stored
    data = pd.read_excel('data/HS300_05_18.xlsx', index_col='date')
    SVM.data = data
    hs = SVM(3, 3)
    years = int(len(data) / 240) + 1
    paras, nav = hs.cum_retn(years - SVM.train_years)
    def drawdown(nav):
        # drawdown of each point relative to the running maximum before it
        dd = []
        for i in range(1, len(nav)):
            max_i = max(nav.iloc[:i])
            dd.append(min(0, nav.iloc[i] - max_i) / max_i)
        return dd
    Drawdown = pd.Series(drawdown(nav), index=nav.index[1:])
    maxdd = min(Drawdown)

    fig, ax1 = plt.subplots(figsize=(15, 8))
    ax1.plot(nav, label='strategy')
    hs300 = data['close'].iloc[-len(nav):] / data['close'].iloc[-len(nav)]
    ax1.set_xlim(nav.index[0], nav.index[-1])
    ax1.plot(hs300, label='HS300')
    ax1.set_ylabel('Net Asset Value', fontdict={'fontsize':16})
    ax1.set_xlabel('Date', fontdict={'fontsize':16})
    ax1.legend(loc='center right', fontsize=16)
    ax2 = ax1.twinx()
    ax2.set_ylim(-1.5, 0)
    ax2.plot(Drawdown, color='c')
    ax2.set_ylabel('Max Drawdown', fontdict={'fontsize':16})
    ax2.fill_between(Drawdown.index, Drawdown, color='c')
    ax2.set_ylim(-1.5, 0)
    plt.savefig('svm_hs300.png', bbox_inches='tight')

    # in-sample (training) accuracy for each of the four rolling windows
    for i in range(4):
        train_data, test_data = hs.data_split(i)
        x_train, y_train = train_data.iloc[:, :-1], train_data.iloc[:, -1]
        x_test, y_test = test_data.iloc[:, :-1], test_data.iloc[:, -1]
        C, gamma, *res = paras[i]
        clf = SVC(C=C, gamma=gamma, class_weight='balanced', cache_size=4000)
        clf.fit(x_train, y_train)
        print('%.4f' % clf.score(x_train, y_train))

--------------------------------------------------------------------------------
/v_svm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 13:22:02 2018

@author: 54326
"""

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV as GSCV

class SVM():
    '''
    The class attribute `data` must be set first; it holds the daily
    open, high, low, close and volume data
    '''

    train_years=10 # size of the training sample: ten years of data
    year_days=240 # 240 trading days per year

    def __init__(self, momentum, span):
        self.m_span = momentum # look-back window (days) for momentum
        self.span = span # look-back window (days) for volatility
        self.slippage = 0.0002 # slippage: 2 basis points in total, 1 each for buying and selling

    def data_pre(self):
        '''
        Prepare the data: build the features and attach the class labels
        '''
        data = self.data
        data['retn'] = data['close'].pct_change()
        span = self.span

        attr = pd.DataFrame()
        # build seven features describing daily, weekly and monthly market conditions
        attr['high_low'] = (data['high'] - data['low']) / data['open']
        attr['close_open'] = data['close'] / data['open'] - 1
        attr['vol_pct'] = data['volume'].pct_change()
        # m-day cumulative return: today's close over the open m days ago (past data only)
        attr['pct_m'] = data['close'] / data['open'].shift(self.m_span) - 1
        for i in range(span, len(data)-span):
            now = data.index[i]
            begin = data.index[i-span]
            attr.loc[now, 'high_v'] = data.loc[now, 'high']\
                                / max(data.loc[begin:now, 'high']) - 1
            attr.loc[now, 'low_v'] = data.loc[now, 'low']\
                                / min(data.loc[begin:now, 'low']) - 1
            attr.loc[now, 'sigma'] = data.loc[begin:now, 'retn'].std()\
                                * self.year_days / span
        attr.dropna(inplace=True)

        # next day's raw open-to-close return, taken before standardisation
        # so that its sign is the true direction (the last row has no next day)
        next_retn = attr['close_open'].shift(-1)

        # standardise the features
        for field in attr.columns:
            attr[field] = (attr[field] - attr[field].mean()) / attr[field].std()

        # a rise is labelled 1, a fall is labelled -1
        attr['label'] = np.where(next_retn > 0, 1, -1)

        return attr

    def data_split(self, begin_year):
        '''
        Split the data: ten years for training, the following year for testing
        Args:
            begin_year:
                index of the starting year; 0 is the first year, and so on
        Returns:
            train_data:
                training set
            test_data:
                test set
        '''
        begin = begin_year * self.year_days # start of the training set
        end = begin + self.train_years * self.year_days # end of the training set
        attrs = self.data_pre()
        train_data = attrs.iloc[begin : end, :] # training set
        test_data = attrs.iloc[end : end + self.year_days, :] # the following year is the test set
        return train_data, test_data

    def gscv_para(self, C_list, gamma_list, x_train, y_train):
        '''
        Tune the parameters with grid search and cross-validation, allowing
        for class imbalance; k-fold uses k = 10
        Args:
            C_list:
                candidate values for C
            gamma_list:
                candidate values for gamma
            x_train:
                features of the training set
            y_train:
                labels of the training set
        Returns:
            the best C, the best gamma and the corresponding mean score
            the criterion is the SVC score: the higher the score, the better the parameters
        '''
        clf = SVC(class_weight='balanced', cache_size=4000)
        gscv = GSCV(clf, param_grid={'C': C_list, 'gamma': gamma_list},
                    n_jobs=-1, cv=10, pre_dispatch=4)
        gscv.fit(x_train, y_train)
        # read the parameters by name rather than relying on dict ordering
        best = gscv.best_params_
        return (best['C'], best['gamma']), gscv.best_score_

    def svm_fit(self, x_train, y_train):
        '''
        Tune the parameters with grid search
        '''
        C_list = list(range(1, 5))
        gamma_list = np.linspace(0.01, 0.05, 5)
        (C, gamma), score = self.gscv_para(C_list, gamma_list, x_train, y_train)
        # a parameter may land on the edge of the grid; compare against the
        # actual grid ends, then reset the candidate range and search again
        C_low, C_high = C == C_list[0], C == C_list[-1]
        gamma_low = np.isclose(gamma, gamma_list[0])
        gamma_high = np.isclose(gamma, gamma_list[-1])
        if C_low:
            print('C touched the lower limit')
            C_list = np.linspace(0.2, 2, 10)
        elif C_high:
            print('C touched the upper limit')
            C_list = list(range(1, 10))
        if gamma_low:
            print('gamma touched the lower limit')
            gamma_list = np.linspace(0.002, 0.02, 10)
        elif gamma_high:
            print('gamma touched the upper limit')
            gamma_list = np.linspace(0.02, 0.2, 10)
        if C_low or C_high or gamma_low or gamma_high:
            (C, gamma), score = self.gscv_para(C_list, gamma_list, x_train, y_train)
        return C, gamma, score

    def predict(self, begin_year):
        '''
        Use the fitted classifier to predict the out-of-sample data
        '''
        train_data, test_data = self.data_split(begin_year)
        x_train, y_train = train_data.iloc[:, :-1], train_data.iloc[:, -1]
        x_test, y_test = test_data.iloc[:, :-1], test_data.iloc[:, -1]
        C, gamma, score = self.svm_fit(x_train, y_train)
#        C, gamma, score, _ = paras[begin_year]
        clf = SVC(C=C, gamma=gamma, class_weight='balanced', cache_size=4000)
        clf.fit(x_train, y_train)
        test_score = clf.score(x_test, y_test) # accuracy on the test set
        para = C, gamma, score, test_score
        # C may be fractional after the grid is reset, so it is printed with %s
        print('%s, %.4f, %.4f, %.4f' % para)
        # predictions on the test set
        y_predict = pd.Series(clf.predict(x_test), index=y_test.index, name='predict')
        return para, pd.concat([y_test, y_predict], axis=1)

    def cum_retn(self, years):
        data = self.data # use the class attribute instead of a module-level global
        # open-to-close return of the day after each date, i.e. the day being predicted
        next_retn = (data['close'] / data['open'] - 1).shift(-1)
        retns = {}
        paras = []
        for i in range(years):
            para, label = self.predict(i)
            paras.append(para)
            for j in range(len(label)):
                date = label.index[j]
                if label.loc[date, 'label'] == label.loc[date, 'predict']:
                    # correct prediction: earn the long/short return
                    retn_i = abs(next_retn.loc[date])
                    retns[date] = retn_i - self.slippage
                else:
                    # wrong prediction: take the loss
                    retn_i = -abs(next_retn.loc[date])
                    retns[date] = retn_i - self.slippage
        return paras, 1 + pd.Series(retns).sort_index().cumsum()

if __name__ == '__main__':

    # path relative to the repository root, where the data file is stored
    data = pd.read_excel('data/HS300_05_18.xlsx', index_col='date')
    SVM.data = data
    hs = SVM(5, 20)
    attr = hs.data_pre()
    plt.figure(figsize=(6, 5))
    sns.heatmap(attr.iloc[:, :-1].corr())
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.tight_layout()
    plt.savefig('attr_heatmap.png')
    total_years = int(len(data) / 240) + 1
    years = total_years - SVM.train_years
    paras, nav = hs.cum_retn(years)
    def drawdown(nav):
        # drawdown of each point relative to the running maximum before it
        dd = []
        for i in range(1, len(nav)):
            max_i = max(nav.iloc[:i])
            dd.append(min(0, nav.iloc[i] - max_i) / max_i)
        return dd
    Drawdown = pd.Series(drawdown(nav), index=nav.index[1:])
    MaxDD = min(Drawdown)
    Total_return = nav.iloc[-1]
    Ann = Total_return ** (1/years) -1
    Sigma = nav.std() / (years**(1/2))
    IR = Ann / Sigma
    Calmar = - Ann / MaxDD

    fig, ax1 = plt.subplots(figsize=(15, 8))
    ax1.plot(nav, label='strategy')
    hs300 = data['close'].iloc[-len(nav):] / data['close'].iloc[-len(nav)]
    ax1.set_xlim(nav.index[0], nav.index[-1])
    ax1.plot(hs300, label='HS300')
    ax1.set_ylabel('Net Asset Value', fontdict={'fontsize':16})
    ax1.set_xlabel('Date', fontdict={'fontsize':16})
    ax1.legend(loc='center left', fontsize=14)
    ax2 = ax1.twinx()
    ax2.set_ylim(-1.5, 0)
    ax2.plot(Drawdown, color='c')
    ax2.set_ylabel('Max Drawdown', fontdict={'fontsize':16})
    ax2.fill_between(Drawdown.index, Drawdown, color='c')
    ax2.set_ylim(-1.5, 0)
    words = '''slippage: {:8.4f}
Total return: {:8.2f}
Annual return: {:8.2f}
Volatility: {:8.2f}
IR: {:8.2f}
Max Drawdown: {:8.2f}
Calmar: {:8.2f}'''.format(hs.slippage, Total_return, Ann, Sigma, IR, MaxDD, Calmar)
    ax2.text(ax2.get_xbound()[0]+1335, -0.8, words,
             fontsize=16, horizontalalignment='right',
             bbox=dict(boxstyle='square', fc='white'))
    plt.savefig('v_svm.png', bbox_inches='tight')

    # in-sample (training) accuracy for each of the four rolling windows
    for i in range(4):
        train_data, test_data = hs.data_split(i)
        x_train, y_train = train_data.iloc[:, :-1], train_data.iloc[:, -1]
        x_test, y_test = test_data.iloc[:, :-1], test_data.iloc[:, -1]
        C, gamma, *res = paras[i]
        clf = SVC(C=C, gamma=gamma, class_weight='balanced', cache_size=4000)
        clf.fit(x_train, y_train)
        print('%.4f' % clf.score(x_train, y_train))

--------------------------------------------------------------------------------
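The feature construction in `v_svm.py` and the `drawdown` helper could also be written without explicit Python loops. The block below is a minimal sketch under stated assumptions, not the repository's code: the function names `rolling_features` and `drawdown` and their default arguments are hypothetical, the column names (`open`, `high`, `low`, `close`, `volume`) are taken from the scripts, `pct_m` follows the README definition (today's close over the open `m` days earlier), and each rolling window spans `v + 1` rows to mirror the inclusive `begin:now` slices used in the script.

```python
import numpy as np
import pandas as pd


def rolling_features(data: pd.DataFrame, m: int = 5, v: int = 20,
                     year_days: int = 240) -> pd.DataFrame:
    """Seven v_svm-style features built only from current and past rows.

    Hypothetical helper: `data` is assumed to hold daily 'open', 'high',
    'low', 'close' and 'volume' columns, ordered by date.
    """
    attr = pd.DataFrame(index=data.index)
    # same-day features
    attr['high_low'] = (data['high'] - data['low']) / data['open']
    attr['close_open'] = data['close'] / data['open'] - 1
    attr['vol_pct'] = data['volume'].pct_change()
    # m-day cumulative return: today's close over the open m days ago
    attr['pct_m'] = data['close'] / data['open'].shift(m) - 1
    # a window of v + 1 rows covers today plus the previous v days
    window = v + 1
    attr['high_v'] = data['high'] / data['high'].rolling(window).max() - 1
    attr['low_v'] = data['low'] / data['low'].rolling(window).min() - 1
    # sample standard deviation of daily returns, scaled by year_days / v
    retn = data['close'].pct_change()
    attr['sigma'] = retn.rolling(window).std() * year_days / v
    # rows without a full look-back window are dropped
    return attr.dropna()


def drawdown(nav: pd.Series) -> pd.Series:
    """Drawdown of each point relative to the running maximum so far."""
    return nav / nav.cummax() - 1
```

With a net-asset-value series such as the `nav` returned by `cum_retn`, `drawdown(nav).min()` would give the maximum drawdown. Unlike the loop-based helper, this version also returns a value (zero) for the first observation.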