├── README.md ├── ex01 ├── DC.py ├── Init_StockALL_Sp.py └── SVM.py ├── ex02 ├── Cap_Update_daily.py ├── DC.py ├── Deal.py ├── Filter.py ├── Init_StockALL_Sp.py ├── Model_Evaluate.py ├── Operator.py ├── Portfolio.py ├── SVM.py └── main.py └── 可视化 ├── db.py └── drawbodonglv.py /README.md: -------------------------------------------------------------------------------- 1 | # python_toshare_practice 2 | 基于toshare的机器量化分析 3 | 4 | ### ex01 5 | 数据采集、预处理与建模 6 | ### ex02 7 | 模型评估与仓位管理、模拟交易与回测 8 | ### 可视化 9 | 利用Tushare金融数据接口获取股票数据,保存到mysql数据。 10 | 单一波动率指标计算 + 蒙特卡洛模拟Markowitz模型 + 风险最小模型 + 夏普最优组合 11 | -------------------------------------------------------------------------------- /ex01/DC.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf8 -*- 2 | import numpy as np 3 | import pymysql 4 | 5 | 6 | class data_collect(object): 7 | 8 | def __init__(self, in_code,start_dt,end_dt): 9 | ans = self.collectDATA(in_code,start_dt,end_dt) 10 | 11 | def collectDATA(self,in_code,start_dt,end_dt): 12 | # 建立数据库连接,获取日线基础行情(开盘价,收盘价,最高价,最低价,成交量,成交额) 13 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 14 | cursor = db.cursor() 15 | sql_done_set = "SELECT * FROM stock_all a where stock_code = '%s' and state_dt >= '%s' and state_dt <= '%s' order by state_dt asc" % (in_code, start_dt, end_dt) 16 | cursor.execute(sql_done_set) 17 | done_set = cursor.fetchall() 18 | if len(done_set) == 0: 19 | raise Exception 20 | self.date_seq = [] 21 | self.open_list = [] 22 | self.close_list = [] 23 | self.high_list = [] 24 | self.low_list = [] 25 | self.vol_list = [] 26 | self.amount_list = [] 27 | for i in range(len(done_set)): 28 | self.date_seq.append(done_set[i][0]) 29 | self.open_list.append(float(done_set[i][2])) 30 | self.close_list.append(float(done_set[i][3])) 31 | self.high_list.append(float(done_set[i][4])) 32 | self.low_list.append(float(done_set[i][5])) 33 | self.vol_list.append(float(done_set[i][6])) 34 | self.amount_list.append(float(done_set[i][7])) 35 | cursor.close() 36 | db.close() 37 | # 将日线行情整合为训练集(其中self.train是输入集,self.target是输出集,self.test_case是end_dt那天的单条测试输入) 38 | self.data_train = [] 39 | self.data_target = [] 40 | self.data_target_onehot = [] 41 | self.cnt_pos = 0 42 | self.test_case = [] 43 | 44 | for i in range(1,len(self.close_list)): 45 | train = [self.open_list[i-1],self.close_list[i-1],self.high_list[i-1],self.low_list[i-1],self.vol_list[i-1],self.amount_list[i-1]] 46 | self.data_train.append(np.array(train)) 47 | 48 | if self.close_list[i]/self.close_list[i-1] > 1.0: 49 | self.data_target.append(float(1.00)) 50 | self.data_target_onehot.append([1,0,0]) 51 | else: 52 | self.data_target.append(float(0.00)) 53 | self.data_target_onehot.append([0,1,0]) 54 | self.cnt_pos =len([x for x in self.data_target if x == 1.00]) 55 | self.test_case = np.array([self.open_list[-1],self.close_list[-1],self.high_list[-1],self.low_list[-1],self.vol_list[-1],self.amount_list[-1]]) 56 | self.data_train = np.array(self.data_train) 57 | self.data_target = np.array(self.data_target) 58 | return 1 -------------------------------------------------------------------------------- /ex01/Init_StockALL_Sp.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import tushare as ts 3 | import pymysql 4 | 5 | if __name__ == '__main__': 6 | 7 | # 设置tushare pro的token并获取连接 8 | ts.set_token('ec02769191b25376fcfeb57aa35078b8c85e5dd35808fda58b395221') 9 | pro = ts.pro_api() 10 | # 设定获取日线行情的初始日期和终止日期,其中终止日期设定为昨天。 11 | start_dt = '20100101' 12 | time_temp = datetime.datetime.now() - datetime.timedelta(days=1) 13 | end_dt = time_temp.strftime('%Y%m%d') 14 | # 建立数据库连接,剔除已入库的部分 15 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 16 | cursor = db.cursor() 17 | # 设定需要获取数据的股票池 18 | stock_pool = ['603912.SH','300666.SZ','300618.SZ','002049.SZ','300672.SZ'] 19 | total = len(stock_pool) 20 | # 循环获取单个股票的日线行情 21 | for i in range(len(stock_pool)): 22 | try: 23 | df = pro.daily(ts_code=stock_pool[i], start_date=start_dt, end_date=end_dt) 24 | # 打印进度 25 | print('Seq: ' + str(i+1) + ' of ' + str(total) + ' Code: ' + str(stock_pool[i])) 26 | c_len = df.shape[0] 27 | except Exception as aa: 28 | print(aa) 29 | print('No DATA Code: ' + str(i)) 30 | continue 31 | for j in range(c_len): 32 | resu0 = list(df.ix[c_len-1-j]) 33 | resu = [] 34 | for k in range(len(resu0)): 35 | if str(resu0[k]) == 'nan': 36 | resu.append(-1) 37 | else: 38 | resu.append(resu0[k]) 39 | state_dt = (datetime.datetime.strptime(resu[1], "%Y%m%d")).strftime('%Y-%m-%d') 40 | try: 41 | sql_insert = "INSERT INTO stock_all(state_dt,stock_code,open,close,high,low,vol,amount,pre_close,amt_change,pct_change) VALUES ('%s', '%s', '%.2f', '%.2f','%.2f','%.2f','%i','%.2f','%.2f','%.2f','%.2f')" % (state_dt,str(resu[0]),float(resu[2]),float(resu[5]),float(resu[3]),float(resu[4]),float(resu[9]),float(resu[10]),float(resu[6]),float(resu[7]),float(resu[8])) 42 | cursor.execute(sql_insert) 43 | db.commit() 44 | except Exception as err: 45 | continue 46 | cursor.close() 47 | db.close() 48 | print('All Finished!') 49 | -------------------------------------------------------------------------------- /ex01/SVM.py: -------------------------------------------------------------------------------- 1 | from sklearn import svm 2 | import DC 3 | 4 | if __name__ == '__main__': 5 | stock = '002049.SZ' 6 | dc = DC.data_collect(stock, '2017-03-01', '2018-03-01') 7 | train = dc.data_train 8 | target = dc.data_target 9 | test_case = [dc.test_case] 10 | model = svm.SVC() # 建模 11 | model.fit(train, target) # 训练 12 | ans2 = model.predict(test_case) # 预测 13 | # 输出对2018-03-02的涨跌预测,1表示涨,0表示不涨。 14 | print(ans2[0]) 15 | 16 | 17 | -------------------------------------------------------------------------------- /ex02/Cap_Update_daily.py: -------------------------------------------------------------------------------- 1 | import pymysql 2 | 3 | def cap_update_daily(state_dt): 4 | para_norisk = (1.0 + 0.04/365) 5 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 6 | cursor = db.cursor() 7 | sql_pool = "select * from my_stock_pool" 8 | cursor.execute(sql_pool) 9 | done_set = cursor.fetchall() 10 | db.commit() 11 | new_lock_cap = 0.00 12 | for i in range(len(done_set)): 13 | stock_code = str(done_set[i][0]) 14 | stock_vol = float(done_set[i][2]) 15 | sql = "select * from stock_info a where a.stock_code = '%s' and a.state_dt <= '%s' order by a.state_dt desc limit 1"%(stock_code,state_dt) 16 | cursor.execute(sql) 17 | done_temp = cursor.fetchall() 18 | db.commit() 19 | if len(done_temp) > 0: 20 | cur_close_price = float(done_temp[0][3]) 21 | new_lock_cap += cur_close_price * stock_vol 22 | else: 23 | print('Cap_Update_daily Err!!') 24 | raise Exception 25 | sql_cap = "select * from my_capital order by seq asc" 26 | cursor.execute(sql_cap) 27 | done_cap = cursor.fetchall() 28 | db.commit() 29 | new_cash_cap = float(done_cap[-1][2]) * para_norisk 30 | new_total_cap = new_cash_cap + new_lock_cap 31 | sql_insert = "insert into my_capital(capital,money_lock,money_rest,bz,state_dt)values('%.2f','%.2f','%.2f','%s','%s')"%(new_total_cap,new_lock_cap,new_cash_cap,str('Daily_Update'),state_dt) 32 | cursor.execute(sql_insert) 33 | db.commit() 34 | return 1 -------------------------------------------------------------------------------- /ex02/DC.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf8 -*- 2 | import numpy as np 3 | import pymysql 4 | 5 | 6 | class data_collect(object): 7 | 8 | def __init__(self, in_code,start_dt,end_dt): 9 | ans = self.collectDATA(in_code,start_dt,end_dt) 10 | 11 | def collectDATA(self,in_code,start_dt,end_dt): 12 | # 建立数据库连接,获取日线基础行情(开盘价,收盘价,最高价,最低价,成交量,成交额) 13 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 14 | cursor = db.cursor() 15 | sql_done_set = "SELECT * FROM stock_all a where stock_code = '%s' and state_dt >= '%s' and state_dt <= '%s' order by state_dt asc" % (in_code, start_dt, end_dt) 16 | cursor.execute(sql_done_set) 17 | done_set = cursor.fetchall() 18 | if len(done_set) == 0: 19 | raise Exception 20 | self.date_seq = [] 21 | self.open_list = [] 22 | self.close_list = [] 23 | self.high_list = [] 24 | self.low_list = [] 25 | self.vol_list = [] 26 | self.amount_list = [] 27 | for i in range(len(done_set)): 28 | self.date_seq.append(done_set[i][0]) 29 | self.open_list.append(float(done_set[i][2])) 30 | self.close_list.append(float(done_set[i][3])) 31 | self.high_list.append(float(done_set[i][4])) 32 | self.low_list.append(float(done_set[i][5])) 33 | self.vol_list.append(float(done_set[i][6])) 34 | self.amount_list.append(float(done_set[i][7])) 35 | cursor.close() 36 | db.close() 37 | # 将日线行情整合为训练集(其中self.train是输入集,self.target是输出集,self.test_case是end_dt那天的单条测试输入) 38 | self.data_train = [] 39 | self.data_target = [] 40 | self.data_target_onehot = [] 41 | self.cnt_pos = 0 42 | self.test_case = [] 43 | 44 | for i in range(1,len(self.close_list)): 45 | train = [self.open_list[i-1],self.close_list[i-1],self.high_list[i-1],self.low_list[i-1],self.vol_list[i-1],self.amount_list[i-1]] 46 | self.data_train.append(np.array(train)) 47 | 48 | if self.close_list[i]/self.close_list[i-1] > 1.0: 49 | self.data_target.append(float(1.00)) 50 | self.data_target_onehot.append([1,0,0]) 51 | else: 52 | self.data_target.append(float(0.00)) 53 | self.data_target_onehot.append([0,1,0]) 54 | self.cnt_pos =len([x for x in self.data_target if x == 1.00]) 55 | self.test_case = np.array([self.open_list[-1],self.close_list[-1],self.high_list[-1],self.low_list[-1],self.vol_list[-1],self.amount_list[-1]]) 56 | self.data_train = np.array(self.data_train) 57 | self.data_target = np.array(self.data_target) 58 | return 1 -------------------------------------------------------------------------------- /ex02/Deal.py: -------------------------------------------------------------------------------- 1 | import pymysql.cursors 2 | 3 | class Deal(object): 4 | cur_capital = 0.00 5 | cur_money_lock = 0.00 6 | cur_money_rest = 0.00 7 | stock_pool = [] 8 | stock_map1 = {} 9 | stock_map2 = {} 10 | stock_map3 = {} 11 | stock_all = [] 12 | ban_list = [] 13 | 14 | def __init__(self,state_dt): 15 | # 建立数据库连接 16 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 17 | cursor = db.cursor() 18 | try: 19 | sql_select = 'select * from my_capital a order by seq desc limit 1' 20 | cursor.execute(sql_select) 21 | done_set = cursor.fetchall() 22 | self.cur_capital = 0.00 23 | self.cur_money_lock = 0.00 24 | self.cur_money_rest = 0.00 25 | if len(done_set) > 0: 26 | self.cur_capital = float(done_set[0][0]) 27 | self.cur_money_rest = float(done_set[0][2]) 28 | sql_select2 = 'select * from my_stock_pool' 29 | cursor.execute(sql_select2) 30 | done_set2 = cursor.fetchall() 31 | self.stock_pool = [] 32 | self.stock_all = [] 33 | self.stock_map1 = [] 34 | self.stock_map2 = [] 35 | self.stock_map3 = [] 36 | self.ban_list = [] 37 | if len(done_set2) > 0: 38 | self.stock_pool = [x[0] for x in done_set2 if x[2] > 0] 39 | self.stock_all = [x[0] for x in done_set2] 40 | self.stock_map1 = {x[0]: float(x[1]) for x in done_set2} 41 | self.stock_map2 = {x[0]: int(x[2]) for x in done_set2} 42 | self.stock_map3 = {x[0]: int(x[3]) for x in done_set2} 43 | for i in range(len(done_set2)): 44 | sql = "select * from stock_info a where a.stock_code = '%s' and a.state_dt = '%s'"%(done_set2[i][0],state_dt) 45 | cursor.execute(sql) 46 | done_temp = cursor.fetchall() 47 | db.commit() 48 | self.cur_money_lock += float(done_temp[0][3]) * float(done_set2[i][2]) 49 | # sql_select3 = 'select * from ban_list' 50 | # cursor.execute(sql_select3) 51 | # done_set3 = cursor.fetchall() 52 | # if len(done_set3) > 0: 53 | # self.ban_list = [x[0] for x in done_set3] 54 | 55 | 56 | except Exception as excp: 57 | #db.rollback() 58 | print(excp) 59 | 60 | db.close() 61 | -------------------------------------------------------------------------------- /ex02/Filter.py: -------------------------------------------------------------------------------- 1 | import pymysql.cursors 2 | import Deal 3 | import Operator 4 | 5 | def filter_main(stock_new,state_dt,predict_dt,poz): 6 | # 建立数据库连接 7 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 8 | cursor = db.cursor() 9 | 10 | #先更新持股天数 11 | sql_update_hold_days = 'update my_stock_pool w set w.hold_days = w.hold_days + 1' 12 | cursor.execute(sql_update_hold_days) 13 | db.commit() 14 | 15 | #先卖出 16 | deal = Deal.Deal(state_dt) 17 | stock_pool_local = deal.stock_pool 18 | for stock in stock_pool_local: 19 | sql_predict = "select predict from model_ev_resu a where a.state_dt = '%s' and a.stock_code = '%s'"%(predict_dt,stock) 20 | cursor.execute(sql_predict) 21 | done_set_predict = cursor.fetchall() 22 | predict = 0 23 | if len(done_set_predict) > 0: 24 | predict = int(done_set_predict[0][0]) 25 | ans = Operator.sell(stock,state_dt,predict) 26 | 27 | #后买入 28 | for stock_index in range(len(stock_new)): 29 | deal_buy = Deal.Deal(state_dt) 30 | 31 | # # 如果模型f1分值低于50则不买入 32 | # sql_f1_check = "select * from model_ev_resu a where a.stock_code = '%s' and a.state_dt < '%s' order by a.state_dt desc limit 1"%(stock_new[stock_index],state_dt) 33 | # cursor.execute(sql_f1_check) 34 | # done_check = cursor.fetchall() 35 | # db.commit() 36 | # if len(done_check) > 0: 37 | # if float(done_check[0][4]) < 0.5: 38 | # print('F1 Warning !!') 39 | # continue 40 | 41 | 42 | ans = Operator.buy(stock_new[stock_index],state_dt,poz[stock_index]*deal_buy.cur_money_rest) 43 | del deal_buy 44 | db.close() 45 | -------------------------------------------------------------------------------- /ex02/Init_StockALL_Sp.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import tushare as ts 3 | import pymysql 4 | 5 | if __name__ == '__main__': 6 | 7 | # 设置tushare pro的token并获取连接 8 | ts.set_token('ec02769191b25376fcfeb57aa35078b8c85e5dd35808fda58b395221') 9 | pro = ts.pro_api() 10 | # 设定获取日线行情的初始日期和终止日期,其中终止日期设定为昨天。 11 | start_dt = '20160101' 12 | time_temp = datetime.datetime.now() - datetime.timedelta(days=1) 13 | end_dt = time_temp.strftime('%Y%m%d') 14 | # 建立数据库连接,剔除已入库的部分 15 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 16 | cursor = db.cursor() 17 | # 设定需要获取数据的股票池 18 | stock_pool = ['603912.SH','300666.SZ','300618.SZ','002049.SZ','300672.SZ'] 19 | total = len(stock_pool) 20 | # 循环获取单个股票的日线行情 21 | for i in range(len(stock_pool)): 22 | try: 23 | df = pro.daily(ts_code=stock_pool[i], start_date=start_dt, end_date=end_dt) 24 | print('Seq: ' + str(i+1) + ' of ' + str(total) + ' Code: ' + str(stock_pool[i])) 25 | c_len = df.shape[0] 26 | except Exception as aa: 27 | print(aa) 28 | print('No DATA Code: ' + str(i)) 29 | continue 30 | for j in range(c_len): 31 | resu0 = list(df.ix[c_len-1-j]) 32 | resu = [] 33 | for k in range(len(resu0)): 34 | if str(resu0[k]) == 'nan': 35 | resu.append(-1) 36 | else: 37 | resu.append(resu0[k]) 38 | state_dt = (datetime.datetime.strptime(resu[1], "%Y%m%d")).strftime('%Y-%m-%d') 39 | try: 40 | sql_insert = "INSERT INTO stock_all(state_dt,stock_code,open,close,high,low,vol,amount,pre_close,amt_change,pct_change) VALUES ('%s', '%s', '%.2f', '%.2f','%.2f','%.2f','%i','%.2f','%.2f','%.2f','%.2f')" % (state_dt,str(resu[0]),float(resu[2]),float(resu[5]),float(resu[3]),float(resu[4]),float(resu[9]),float(resu[10]),float(resu[6]),float(resu[7]),float(resu[8])) 41 | cursor.execute(sql_insert) 42 | db.commit() 43 | except Exception as err: 44 | continue 45 | cursor.close() 46 | db.close() 47 | print('All Finished!') 48 | -------------------------------------------------------------------------------- /ex02/Model_Evaluate.py: -------------------------------------------------------------------------------- 1 | from sklearn import svm 2 | import pymysql.cursors 3 | import datetime 4 | import DC 5 | import tushare as ts 6 | 7 | 8 | def model_eva(stock,state_dt,para_window,para_dc_window): 9 | # 建立数据库连接,设置tushare token 10 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 11 | cursor = db.cursor() 12 | ts.set_token('ec02769191b25376fcfeb57aa35078b8c85e5dd35808fda58b395221') 13 | pro = ts.pro_api() 14 | # 建评估时间序列, para_window参数代表回测窗口长度 15 | model_test_date_start = (datetime.datetime.strptime(state_dt, '%Y-%m-%d') - datetime.timedelta(days=para_window)).strftime( 16 | '%Y%m%d') 17 | model_test_date_end = state_dt 18 | df = pro.trade_cal(exchange_id='', is_open = 1,start_date=model_test_date_start, end_date=model_test_date_end) 19 | date_temp = list(df.iloc[:,1]) 20 | model_test_date_seq = [(datetime.datetime.strptime(x, "%Y%m%d")).strftime('%Y-%m-%d') for x in date_temp] 21 | # 清空评估用的中间表model_ev_mid 22 | sql_truncate_model_test = 'truncate table model_ev_mid' 23 | cursor.execute(sql_truncate_model_test) 24 | db.commit() 25 | return_flag = 0 26 | # 开始回测,其中para_dc_window参数代表建模时数据预处理所需的时间窗长度 27 | for d in range(len(model_test_date_seq)): 28 | model_test_new_start = (datetime.datetime.strptime(model_test_date_seq[d], '%Y-%m-%d') - datetime.timedelta(days=para_dc_window)).strftime('%Y-%m-%d') 29 | model_test_new_end = model_test_date_seq[d] 30 | try: 31 | dc = DC.data_collect(stock, model_test_new_start, model_test_new_end) 32 | if len(set(dc.data_target)) <= 1: 33 | continue 34 | except Exception as exp: 35 | print("DC Errrrr") 36 | return_flag = 1 37 | break 38 | train = dc.data_train 39 | target = dc.data_target 40 | test_case = [dc.test_case] 41 | model = svm.SVC() # 建模 42 | model.fit(train, target) # 训练 43 | ans2 = model.predict(test_case) # 预测 44 | # 将预测结果插入到中间表 45 | sql_insert = "insert into model_ev_mid(state_dt,stock_code,resu_predict)values('%s','%s','%.2f')" % (model_test_new_end, stock, float(ans2[0])) 46 | cursor.execute(sql_insert) 47 | db.commit() 48 | if return_flag == 1: 49 | acc = recall = acc_neg = f1 = 0 50 | return -1 51 | else: 52 | # 在中间表中刷真实值 53 | for i in range(len(model_test_date_seq)): 54 | sql_select = "select * from stock_all a where a.stock_code = '%s' and a.state_dt >= '%s' order by a.state_dt asc limit 2" % (stock, model_test_date_seq[i]) 55 | cursor.execute(sql_select) 56 | done_set2 = cursor.fetchall() 57 | if len(done_set2) <= 1: 58 | break 59 | resu = 0 60 | if float(done_set2[1][3]) / float(done_set2[0][3]) > 1.00: 61 | resu = 1 62 | sql_update = "update model_ev_mid w set w.resu_real = '%.2f' where w.state_dt = '%s' and w.stock_code = '%s'" % (resu, model_test_date_seq[i], stock) 63 | cursor.execute(sql_update) 64 | db.commit() 65 | # 计算查全率 66 | sql_resu_recall_son = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_predict = 1 and a.resu_real = 1" 67 | cursor.execute(sql_resu_recall_son) 68 | recall_son = cursor.fetchall()[0][0] 69 | sql_resu_recall_mon = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_real = 1" 70 | cursor.execute(sql_resu_recall_mon) 71 | recall_mon = cursor.fetchall()[0][0] 72 | if recall_mon == 0: 73 | acc = recall = acc_neg = f1 = 0 74 | else: 75 | recall = recall_son / recall_mon 76 | # 计算查准率 77 | sql_resu_acc_son = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_predict = 1 and a.resu_real = 1" 78 | cursor.execute(sql_resu_acc_son) 79 | acc_son = cursor.fetchall()[0][0] 80 | sql_resu_acc_mon = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_predict = 1" 81 | cursor.execute(sql_resu_acc_mon) 82 | acc_mon = cursor.fetchall()[0][0] 83 | if acc_mon == 0: 84 | acc = recall = acc_neg = f1 = 0 85 | else: 86 | acc = acc_son / acc_mon 87 | # 计算查准率(负样本) 88 | sql_resu_acc_neg_son = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_predict = -1 and a.resu_real = -1" 89 | cursor.execute(sql_resu_acc_neg_son) 90 | acc_neg_son = cursor.fetchall()[0][0] 91 | sql_resu_acc_neg_mon = "select count(*) from model_ev_mid a where a.resu_real is not null and a.resu_predict = -1" 92 | cursor.execute(sql_resu_acc_neg_mon) 93 | acc_neg_mon = cursor.fetchall()[0][0] 94 | if acc_neg_mon == 0: 95 | acc_neg_mon = -1 96 | acc_neg = -1 97 | else: 98 | acc_neg = acc_neg_son / acc_neg_mon 99 | # 计算 F1 分值 100 | if acc + recall == 0: 101 | acc = recall = acc_neg = f1 = 0 102 | else: 103 | f1 = (2 * acc * recall) / (acc + recall) 104 | sql_predict = "select resu_predict from model_ev_mid a where a.state_dt = '%s'" % (model_test_date_seq[-1]) 105 | cursor.execute(sql_predict) 106 | done_predict = cursor.fetchall() 107 | predict = 0 108 | if len(done_predict) != 0: 109 | predict = int(done_predict[0][0]) 110 | # 将评估结果存入结果表model_ev_resu中 111 | sql_final_insert = "insert into model_ev_resu(state_dt,stock_code,acc,recall,f1,acc_neg,bz,predict)values('%s','%s','%.4f','%.4f','%.4f','%.4f','%s','%s')" % (state_dt, stock, acc, recall, f1, acc_neg, 'svm', str(predict)) 112 | cursor.execute(sql_final_insert) 113 | db.commit() 114 | db.close() 115 | print(str(state_dt) + ' Precision : ' + str(acc) + ' Recall : ' + str(recall) + ' F1 : ' + str(f1) + ' Acc_Neg : ' + str(acc_neg)) 116 | return 1 117 | 118 | 119 | -------------------------------------------------------------------------------- /ex02/Operator.py: -------------------------------------------------------------------------------- 1 | import pymysql.cursors 2 | import Deal 3 | 4 | def buy(stock_code,opdate,buy_money): 5 | # 建立数据库连接 6 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 7 | cursor = db.cursor() 8 | deal_buy = Deal.Deal(opdate) 9 | #后买入 10 | if deal_buy.cur_money_rest+1 >= buy_money: 11 | sql_buy = "select * from stock_info a where a.state_dt = '%s' and a.stock_code = '%s'" % (opdate, stock_code) 12 | cursor.execute(sql_buy) 13 | done_set_buy = cursor.fetchall() 14 | if len(done_set_buy) == 0: 15 | return -1 16 | buy_price = float(done_set_buy[0][3]) 17 | if buy_price >= 195: 18 | return 0 19 | vol, rest = divmod(min(deal_buy.cur_money_rest, buy_money), buy_price * 100) 20 | vol = vol * 100 21 | if vol == 0: 22 | return 0 23 | new_capital = deal_buy.cur_capital - vol * buy_price * 0.0005 24 | new_money_lock = deal_buy.cur_money_lock + vol * buy_price 25 | new_money_rest = deal_buy.cur_money_rest - vol * buy_price * 1.0005 26 | sql_buy_update2 = "insert into my_capital(capital,money_lock,money_rest,deal_action,stock_code,stock_vol,state_dt,deal_price)VALUES ('%.2f', '%.2f', '%.2f','%s','%s','%i','%s','%.2f')" % (new_capital, new_money_lock,new_money_rest, 'buy', stock_code, vol, opdate, buy_price) 27 | cursor.execute(sql_buy_update2) 28 | db.commit() 29 | if stock_code in deal_buy.stock_all: 30 | new_buy_price = (deal_buy.stock_map1[stock_code] * deal_buy.stock_map2[stock_code] + vol * buy_price) / (deal_buy.stock_map2[stock_code] + vol) 31 | new_vol = deal_buy.stock_map2[stock_code] + vol 32 | sql_buy_update3 = "update my_stock_pool w set w.buy_price = (select '%.2f' from dual) where w.stock_code = '%s'" % (new_buy_price, stock_code) 33 | sql_buy_update3b = "update my_stock_pool w set w.hold_vol = (select '%i' from dual) where w.stock_code = '%s'" % (new_vol, stock_code) 34 | sql_buy_update3c = "update my_stock_pool w set w.hold_days = (select '%i' from dual) where w.stock_code = '%s'" % (1, stock_code) 35 | cursor.execute(sql_buy_update3) 36 | cursor.execute(sql_buy_update3b) 37 | cursor.execute(sql_buy_update3c) 38 | db.commit() 39 | else: 40 | sql_buy_update3 = "insert into my_stock_pool(stock_code,buy_price,hold_vol,hold_days) VALUES ('%s','%.2f','%i','%i')" % (stock_code, buy_price, vol, int(1)) 41 | cursor.execute(sql_buy_update3) 42 | db.commit() 43 | db.close() 44 | return 1 45 | db.close() 46 | return 0 47 | 48 | def sell(stock_code,opdate,predict): 49 | # 建立数据库连接 50 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 51 | cursor = db.cursor() 52 | 53 | deal = Deal.Deal(opdate) 54 | init_price = deal.stock_map1[stock_code] 55 | hold_vol = deal.stock_map2[stock_code] 56 | hold_days = deal.stock_map3[stock_code] 57 | sql_sell_select = "select * from stock_info a where a.state_dt = '%s' and a.stock_code = '%s'" % (opdate, stock_code) 58 | cursor.execute(sql_sell_select) 59 | done_set_sell_select = cursor.fetchall() 60 | if len(done_set_sell_select) == 0: 61 | return -1 62 | sell_price = float(done_set_sell_select[0][3]) 63 | 64 | if sell_price > init_price*1.03 and hold_vol > 0: 65 | new_money_lock = deal.cur_money_lock - sell_price*hold_vol 66 | new_money_rest = deal.cur_money_rest + sell_price*hold_vol 67 | new_capital = deal.cur_capital + (sell_price-init_price)*hold_vol 68 | new_profit = (sell_price-init_price)*hold_vol 69 | new_profit_rate = sell_price/init_price 70 | sql_sell_insert = "insert into my_capital(capital,money_lock,money_rest,deal_action,stock_code,stock_vol,profit,profit_rate,bz,state_dt,deal_price)values('%.2f','%.2f','%.2f','%s','%s','%.2f','%.2f','%.2f','%s','%s','%.2f')" %(new_capital,new_money_lock,new_money_rest,'SELL',stock_code,hold_vol,new_profit,new_profit_rate,'GOODSELL',opdate,sell_price) 71 | cursor.execute(sql_sell_insert) 72 | db.commit() 73 | sql_sell_update = "delete from my_stock_pool where stock_code = '%s'" % (stock_code) 74 | cursor.execute(sql_sell_update) 75 | db.commit() 76 | db.close() 77 | return 1 78 | 79 | elif sell_price < init_price*0.97 and hold_vol > 0: 80 | new_money_lock = deal.cur_money_lock - sell_price*hold_vol 81 | new_money_rest = deal.cur_money_rest + sell_price*hold_vol 82 | new_capital = deal.cur_capital + (sell_price-init_price)*hold_vol 83 | new_profit = (sell_price-init_price)*hold_vol 84 | new_profit_rate = sell_price/init_price 85 | sql_sell_insert2 = "insert into my_capital(capital,money_lock,money_rest,deal_action,stock_code,stock_vol,profit,profit_rate,bz,state_dt,deal_price)values('%.2f','%.2f','%.2f','%s','%s','%.2f','%.2f','%.2f','%s','%s','%.2f')" %(new_capital,new_money_lock,new_money_rest,'SELL',stock_code,hold_vol,new_profit,new_profit_rate,'BADSELL',opdate,sell_price) 86 | cursor.execute(sql_sell_insert2) 87 | db.commit() 88 | sql_sell_update2 = "delete from my_stock_pool where stock_code = '%s'" % (stock_code) 89 | cursor.execute(sql_sell_update2) 90 | db.commit() 91 | # sql_ban_insert = "insert into ban_list(stock_code) values ('%s')" %(stock_code) 92 | # cursor.execute(sql_ban_insert) 93 | # db.commit() 94 | db.close() 95 | return 1 96 | 97 | elif hold_days >= 4 and hold_vol > 0: 98 | new_money_lock = deal.cur_money_lock - sell_price * hold_vol 99 | new_money_rest = deal.cur_money_rest + sell_price * hold_vol 100 | new_capital = deal.cur_capital + (sell_price - init_price) * hold_vol 101 | new_profit = (sell_price - init_price) * hold_vol 102 | new_profit_rate = sell_price / init_price 103 | sql_sell_insert3 = "insert into my_capital(capital,money_lock,money_rest,deal_action,stock_code,stock_vol,profit,profit_rate,bz,state_dt,deal_price)values('%.2f','%.2f','%.2f','%s','%s','%.2f','%.2f','%.2f','%s','%s','%.2f')" % (new_capital, new_money_lock, new_money_rest, 'OVERTIME', stock_code, hold_vol, new_profit, new_profit_rate,'OVERTIMESELL', opdate,sell_price) 104 | cursor.execute(sql_sell_insert3) 105 | db.commit() 106 | sql_sell_update3 = "delete from my_stock_pool where stock_code = '%s'" % (stock_code) 107 | cursor.execute(sql_sell_update3) 108 | db.commit() 109 | db.close() 110 | return 1 111 | 112 | elif predict == -1: 113 | new_money_lock = deal.cur_money_lock - sell_price * hold_vol 114 | new_money_rest = deal.cur_money_rest + sell_price * hold_vol 115 | new_capital = deal.cur_capital + (sell_price - init_price) * hold_vol 116 | new_profit = (sell_price - init_price) * hold_vol 117 | new_profit_rate = sell_price / init_price 118 | sql_sell_insert4 = "insert into my_capital(capital,money_lock,money_rest,deal_action,stock_code,stock_vol,profit,profit_rate,bz,state_dt,deal_price)values('%.2f','%.2f','%.2f','%s','%s','%.2f','%.2f','%.2f','%s','%s','%.2f')" % ( 119 | new_capital, new_money_lock, new_money_rest, 'Predict', stock_code, hold_vol, new_profit, new_profit_rate, 120 | 'PredictSell', opdate, sell_price) 121 | cursor.execute(sql_sell_insert4) 122 | db.commit() 123 | sql_sell_update3 = "delete from my_stock_pool where stock_code = '%s'" % (stock_code) 124 | cursor.execute(sql_sell_update3) 125 | db.commit() 126 | db.close() 127 | return 1 128 | db.close() 129 | return 0 130 | 131 | -------------------------------------------------------------------------------- /ex02/Portfolio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | import pymysql 4 | import copy 5 | import tushare as ts 6 | 7 | 8 | # 返回的resu中 特征值按由小到大排列,对应的是其特征向量 9 | def get_portfolio(stock_list,state_dt,para_window): 10 | # 建数据库连接,设置Tushare的token 11 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 12 | cursor = db.cursor() 13 | ts.set_token('ec02769191b25376fcfeb57aa35078b8c85e5dd35808fda58b395221') 14 | pro = ts.pro_api() 15 | 16 | portfilio = stock_list 17 | 18 | # 建评估时间序列, para_window参数代表回测窗口长度 19 | model_test_date_start = (datetime.datetime.strptime(state_dt, '%Y-%m-%d') - datetime.timedelta(days=para_window)).strftime( 20 | '%Y%m%d') 21 | model_test_date_end = (datetime.datetime.strptime(state_dt, "%Y-%m-%d")).strftime('%Y%m%d') 22 | df = pro.trade_cal(exchange_id='', is_open=1, start_date=model_test_date_start, end_date=model_test_date_end) 23 | date_temp = list(df.iloc[:, 1]) 24 | model_test_date_seq = [(datetime.datetime.strptime(x, "%Y%m%d")).strftime('%Y-%m-%d') for x in date_temp] 25 | 26 | list_return = [] 27 | for i in range(len(model_test_date_seq)-4): 28 | ti = model_test_date_seq[i] 29 | ri = [] 30 | for j in range(len(portfilio)): 31 | sql_select = "select * from stock_all a where a.stock_code = '%s' and a.state_dt >= '%s' and a.state_dt <= '%s' order by state_dt asc" % (portfilio[j], model_test_date_seq[i], model_test_date_seq[i + 4]) 32 | cursor.execute(sql_select) 33 | done_set = cursor.fetchall() 34 | db.commit() 35 | temp = [x[3] for x in done_set] 36 | base_price = 0.00 37 | after_mean_price = 0.00 38 | if len(temp) <= 1: 39 | r = 0.00 40 | else: 41 | base_price = temp[0] 42 | after_mean_price = np.array(temp[1:]).mean() 43 | r = (float(after_mean_price/base_price)-1.00)*100.00 44 | ri.append(r) 45 | del done_set 46 | del temp 47 | del base_price 48 | del after_mean_price 49 | list_return.append(ri) 50 | 51 | # 求协方差矩阵 52 | cov = np.cov(np.array(list_return).T) 53 | # 求特征值和其对应的特征向量 54 | ans = np.linalg.eig(cov) 55 | # 排序,特征向量中负数置0,非负数归一 56 | ans_index = copy.copy(ans[0]) 57 | ans_index.sort() 58 | resu = [] 59 | for k in range(len(ans_index)): 60 | con_temp = [] 61 | con_temp.append(ans_index[k]) 62 | content_temp1 = ans[1][np.argwhere(ans[0] == ans_index[k])[0][0]] 63 | content_temp2 = [] 64 | content_sum = np.array([x for x in content_temp1 if x >= 0.00]).sum() 65 | for m in range(len(content_temp1)): 66 | if content_temp1[m] >= 0 and content_sum > 0: 67 | content_temp2.append(content_temp1[m]/content_sum) 68 | else: 69 | content_temp2.append(0.00) 70 | con_temp.append(content_temp2) 71 | # 计算夏普率 72 | sharp_temp = np.array(copy.copy(list_return)) * content_temp2 73 | sharp_exp = sharp_temp.mean() 74 | sharp_base = 0.04 75 | sharp_std = np.std(sharp_temp) 76 | if sharp_std == 0.00: 77 | sharp = 0.00 78 | else: 79 | sharp = (sharp_exp - sharp_base) / sharp_std 80 | 81 | con_temp.append(sharp) 82 | resu.append(con_temp) 83 | 84 | return resu 85 | 86 | if __name__ == '__main__': 87 | pf = ['603912.SH', '300666.SZ', '300618.SZ', '002049.SZ', '300672.SZ'] 88 | ans = get_portfolio(pf,'2018-01-01',90) 89 | print('************** Market Trend ****************') 90 | print('Risk : ' + str(round(ans[0][0], 2))) 91 | print('Sharp ratio : ' + str(round(ans[0][2], 2))) 92 | 93 | for i in range(5): 94 | print('----------------------------------------------') 95 | print('Stock_code : ' + str(pf[i]) + ' Position : ' + str(round(ans[0][1][i] * 100, 2)) + '%') 96 | print('----------------------------------------------') 97 | 98 | print('************** Best Return *****************') 99 | print('Risk : ' + str(round(ans[1][0], 2))) 100 | print('Sharp ratio : ' + str(round(ans[1][2], 2))) 101 | for j in range(5): 102 | print('----------------------------------------------') 103 | print('Stock_code : ' + str(pf[j]) + ' Position : ' + str( 104 | round(ans[1][1][j] * 100, 2)) + '%') 105 | print('----------------------------------------------') 106 | -------------------------------------------------------------------------------- /ex02/SVM.py: -------------------------------------------------------------------------------- 1 | from sklearn import svm 2 | import DC 3 | 4 | if __name__ == '__main__': 5 | stock = '002049.SZ' 6 | dc = DC.data_collect(stock, '2017-03-01', '2018-03-01') 7 | train = dc.data_train 8 | target = dc.data_target 9 | test_case = [dc.test_case] 10 | model = svm.SVC() # 建模 11 | model.fit(train, target) # 训练 12 | ans2 = model.predict(test_case) # 预测 13 | # 输出对2018-03-02的涨跌预测,1表示涨,0表示不涨。 14 | print(ans2[0]) 15 | 16 | 17 | -------------------------------------------------------------------------------- /ex02/main.py: -------------------------------------------------------------------------------- 1 | import pymysql 2 | import Model_Evaluate as ev 3 | import Filter 4 | import Portfolio as pf 5 | from pylab import * 6 | import Cap_Update_daily as cap_update 7 | import tushare as ts 8 | import datetime 9 | import matplotlib.pyplot as plt 10 | 11 | import warnings 12 | 13 | 14 | def get_sharp_rate(): 15 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 16 | cursor = db.cursor() 17 | 18 | sql_cap = "select * from my_capital a order by seq asc" 19 | cursor.execute(sql_cap) 20 | done_exp = cursor.fetchall() 21 | db.commit() 22 | cap_list = [float(x[0]) for x in done_exp] 23 | return_list = [] 24 | base_cap = float(done_exp[0][0]) 25 | for i in range(len(cap_list)): 26 | if i == 0: 27 | return_list.append(float(1.00)) 28 | else: 29 | ri = (float(done_exp[i][0]) - float(done_exp[0][0]))/float(done_exp[0][0]) 30 | return_list.append(ri) 31 | std = float(np.array(return_list).std()) 32 | exp_portfolio = (float(done_exp[-1][0]) - float(done_exp[0][0]))/float(done_exp[0][0]) 33 | exp_norisk = 0.04*(5.0/12.0) 34 | sharp_rate = (exp_portfolio - exp_norisk)/(std) 35 | 36 | return sharp_rate,std 37 | 38 | if __name__ == '__main__': 39 | 40 | warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn", lineno=193) 41 | 42 | # 建立数据库连接,设置tushare的token,定义一些初始化参数 43 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='123456', db='stock', charset='utf8') 44 | cursor = db.cursor() 45 | ts.set_token('ec02769191b25376fcfeb57aa35078b8c85e5dd35808fda58b395221') 46 | pro = ts.pro_api() 47 | year = 2018 48 | date_seq_start = str(year) + '-03-01' 49 | date_seq_end = str(year) + '-04-01' 50 | stock_pool = ['603912.SH', '300666.SZ', '300618.SZ', '002049.SZ', '300672.SZ'] 51 | 52 | # 先清空之前的测试记录,并创建中间表 53 | sql_wash1 = 'delete from my_capital where seq != 1' 54 | cursor.execute(sql_wash1) 55 | db.commit() 56 | sql_wash3 = 'truncate table my_stock_pool' 57 | cursor.execute(sql_wash3) 58 | db.commit() 59 | # 清空行情源表,并插入相关股票的行情数据。该操作是为了提高回测计算速度而剔除行情表(stock_all)中的冗余数据。 60 | sql_wash4 = 'truncate table stock_info' 61 | cursor.execute(sql_wash4) 62 | db.commit() 63 | in_str = '(' 64 | for x in range(len(stock_pool)): 65 | if x != len(stock_pool)-1: 66 | in_str += str('\'') + str(stock_pool[x])+str('\',') 67 | else: 68 | in_str += str('\'') + str(stock_pool[x]) + str('\')') 69 | sql_insert = "insert into stock_info(select * from stock_all a where a.stock_code in %s)"%(in_str) 70 | cursor.execute(sql_insert) 71 | db.commit() 72 | 73 | 74 | # 建回测时间序列 75 | back_test_date_start = (datetime.datetime.strptime(date_seq_start, '%Y-%m-%d')).strftime('%Y%m%d') 76 | back_test_date_end = (datetime.datetime.strptime(date_seq_end, "%Y-%m-%d")).strftime('%Y%m%d') 77 | df = pro.trade_cal(exchange_id='', is_open=1, start_date=back_test_date_start, end_date=back_test_date_end) 78 | date_temp = list(df.iloc[:, 1]) 79 | date_seq = [(datetime.datetime.strptime(x, "%Y%m%d")).strftime('%Y-%m-%d') for x in date_temp] 80 | print(date_seq) 81 | 82 | #开始模拟交易 83 | index = 1 84 | day_index = 0 85 | for i in range(1,len(date_seq)): 86 | day_index += 1 87 | # 每日推进式建模,并获取对下一个交易日的预测结果 88 | for stock in stock_pool: 89 | try: 90 | ans2 = ev.model_eva(stock,date_seq[i],90,365) 91 | # print('Date : ' + str(date_seq[i]) + ' Update : ' + str(stock)) 92 | except Exception as ex: 93 | print(ex) 94 | continue 95 | # 每5个交易日更新一次配仓比例 96 | if divmod(day_index+4,5)[1] == 0: 97 | portfolio_pool = stock_pool 98 | if len(portfolio_pool) < 5: 99 | print('Less than 5 stocks for portfolio!! state_dt : ' + str(date_seq[i])) 100 | continue 101 | pf_src = pf.get_portfolio(portfolio_pool,date_seq[i-1],year) 102 | # 取最佳收益方向的资产组合 103 | risk = pf_src[1][0] 104 | weight = pf_src[1][1] 105 | Filter.filter_main(portfolio_pool,date_seq[i],date_seq[i-1],weight) 106 | else: 107 | Filter.filter_main([],date_seq[i],date_seq[i - 1], []) 108 | cap_update_ans = cap_update.cap_update_daily(date_seq[i]) 109 | print('Runnig to Date : ' + str(date_seq[i])) 110 | print('ALL FINISHED!!') 111 | 112 | sharp,c_std = get_sharp_rate() 113 | print('Sharp Rate : ' + str(sharp)) 114 | print('Risk Factor : ' + str(c_std)) 115 | 116 | sql_show_btc = "select * from stock_index a where a.stock_code = 'SH' and a.state_dt >= '%s' and a.state_dt <= '%s' order by state_dt asc"%(date_seq_start,date_seq_end) 117 | cursor.execute(sql_show_btc) 118 | done_set_show_btc = cursor.fetchall() 119 | #btc_x = [x[0] for x in done_set_show_btc] 120 | btc_x = list(range(len(done_set_show_btc))) 121 | btc_y = [x[3] / done_set_show_btc[0][3] for x in done_set_show_btc] 122 | dict_anti_x = {} 123 | dict_x = {} 124 | for a in range(len(btc_x)): 125 | dict_anti_x[btc_x[a]] = done_set_show_btc[a][0] 126 | dict_x[done_set_show_btc[a][0]] = btc_x[a] 127 | 128 | #sql_show_profit = "select * from my_capital order by state_dt asc" 129 | sql_show_profit = "select max(a.capital),a.state_dt from my_capital a where a.state_dt is not null group by a.state_dt order by a.state_dt asc" 130 | cursor.execute(sql_show_profit) 131 | done_set_show_profit = cursor.fetchall() 132 | profit_x = [dict_x[x[1]] for x in done_set_show_profit] 133 | profit_y = [x[0] / done_set_show_profit[0][0] for x in done_set_show_profit] 134 | # 绘制收益率曲线(含大盘基准收益曲线) 135 | def c_fnx(val, poz): 136 | if val in dict_anti_x.keys(): 137 | return dict_anti_x[val] 138 | else: 139 | return '' 140 | 141 | 142 | fig = plt.figure(figsize=(20, 12)) 143 | ax = fig.add_subplot(111) 144 | ax.xaxis.set_major_formatter(FuncFormatter(c_fnx)) 145 | 146 | plt.plot(btc_x, btc_y, color='blue') 147 | plt.plot(profit_x, profit_y, color='red') 148 | 149 | plt.show() 150 | 151 | cursor.close() 152 | db.close() 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /可视化/db.py: -------------------------------------------------------------------------------- 1 | import pymysql 2 | 3 | config = dict(host='localhost', user='root', password='123456', 4 | cursorclass=pymysql.cursors.DictCursor) 5 | # 建立连接 6 | conn = pymysql.Connect(**config) 7 | # 自动确认commit True 8 | conn.autocommit(1) 9 | # 设置光标 10 | cursor = conn.cursor() 11 | 12 | 13 | # 一个根据pandas自动识别type来设定table的type 14 | def make_table_sql(df): 15 | columns = df.columns.tolist() 16 | types = df.ftypes 17 | # 添加id 制动递增主键模式 18 | make_table = [] 19 | for item in columns: 20 | if 'int' in types[item]: 21 | char = item + ' INT' 22 | elif 'float' in types[item]: 23 | char = item + ' FLOAT' 24 | elif 'object' in types[item]: 25 | char = item + ' VARCHAR(255)' 26 | elif 'datetime' in types[item]: 27 | char = item + ' DATETIME' 28 | make_table.append(char) 29 | return ','.join(make_table) 30 | 31 | 32 | # csv 格式输入 mysql 中 33 | def csv2mysql(db_name, table_name, df): 34 | # 创建database 35 | cursor.execute('CREATE DATABASE IF NOT EXISTS {}'.format(db_name)) 36 | # 选择连接database 37 | conn.select_db(db_name) 38 | # 创建table 39 | cursor.execute('DROP TABLE IF EXISTS {}'.format(table_name)) 40 | cursor.execute('CREATE TABLE {}({})'.format(table_name, make_table_sql(df))) 41 | # 提取数据转list 这里有与pandas时间模式无法写入因此换成str 此时mysql上格式已经设置完成 42 | df['date'] = df['date'].astype('str') 43 | values = df.values.tolist() 44 | # 根据columns个数 45 | s = ','.join(['%s' for _ in range(len(df.columns))]) 46 | # executemany批量操作 插入数据 批量操作比逐个操作速度快很多 47 | cursor.executemany('INSERT INTO {} VALUES ({})'.format(table_name, s), values) 48 | 49 | 50 | if __name__ == "__main__": 51 | pass 52 | -------------------------------------------------------------------------------- /可视化/drawbodonglv.py: -------------------------------------------------------------------------------- 1 | import tushare as ts 2 | from time import sleep 3 | from threading import Thread 4 | import matplotlib.pyplot as plt 5 | from matplotlib.widgets import Button 6 | import pandas as pd 7 | import numpy as np 8 | import wx 9 | import wx.lib.plot as plot 10 | 11 | pd.set_option('display.max_rows', None) 12 | pd.set_option('display.max_columns', None) 13 | pd.set_option('max_colwidth', 100) 14 | pd.set_option('display.width', 1000) 15 | 16 | 17 | def stock_history_data_get(code, startTime, endTime): 18 | df = ts.get_hist_data(code, start=startTime, end=endTime) 19 | df_v1 = df.reset_index() 20 | # print(df_v1) 21 | 22 | df_v2 = df_v1[['date', 'close', 'p_change']].sort_values(by='date', ascending=True) 23 | df_v2['p_change'] = df_v2['p_change'] * 0.01 24 | stv_4 = [0, 0, 0] 25 | stv_8 = [0, 0, 0, 0, 0, 0, 0] 26 | stv_12 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 27 | stv_16 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 28 | for i in range(0, len(df_v2)): 29 | sample = list(df_v2['p_change'][i:i + 4]) 30 | stv = np.std(sample) 31 | stv_4.append(stv) 32 | if i == len(df_v2) - 4: 33 | break 34 | df_v2['stv_4'] = stv_4 35 | 36 | for i in range(0, len(df_v2)): 37 | sample = list(df_v2['p_change'][i:i + 8]) 38 | stv = np.std(sample) 39 | stv_8.append(stv) 40 | if i == len(df_v2) - 8: 41 | break 42 | df_v2['stv_8'] = stv_8 43 | 44 | for i in range(0, len(df_v2)): 45 | sample = list(df_v2['p_change'][i:i + 12]) 46 | stv = np.std(sample) 47 | stv_12.append(stv) 48 | if i == len(df_v2) - 12: 49 | break 50 | df_v2['stv_12'] = stv_12 51 | 52 | for i in range(0, len(df_v2)): 53 | sample = list(df_v2['p_change'][i:i + 16]) 54 | stv = np.std(sample) 55 | stv_16.append(stv) 56 | if i == len(df_v2) - 16: 57 | break 58 | df_v2['stv_16'] = stv_16 59 | 60 | df_v2['hv_4'] = round(df_v2['stv_4'] * 12 * 1040, 2).apply(lambda x: '%.2f%%' % (x)) 61 | df_v2['hv_8'] = round(df_v2['stv_8'] * 12 * 1040, 2).apply(lambda x: '%.2f%%' % (x)) 62 | df_v2['hv_12'] = round(df_v2['stv_12'] * 12 * 1040, 2).apply(lambda x: '%.2f%%' % (x)) 63 | df_v2['hv_16'] = round(df_v2['stv_16'] * 12 * 1040, 2).apply(lambda x: '%.2f%%' % (x)) 64 | # df_v2.to_csv('//Users//hejie//AnacondaProjects//181902PythonInFinace//50.csv', index=True) 65 | df_v2.to_csv('50.csv', index=True) 66 | 67 | 68 | def getDF(code, startTime, endTime): 69 | stock_history_data_get(code, startTime, endTime) 70 | df = pd.read_csv('./50.csv', header=0) 71 | return df 72 | 73 | 74 | def getData(df): 75 | df = pd.read_csv('./50.csv', header=0) 76 | dataList = [] 77 | for index, item in enumerate(df['hv_4'].tolist()): 78 | dataList.append([index, item[:-1]]) 79 | return dataList 80 | # print(dataList) 81 | 82 | 83 | class MyFrame(wx.Frame): 84 | def __init__(self): 85 | self.frame1 = wx.Frame(None, title="sz50股票波动率", id=-1, size=(500, 350)) 86 | self.panel1 = wx.Panel(self.frame1) 87 | self.panel1.SetBackgroundColour("white") 88 | self.code = wx.TextCtrl(self.panel1, value="sz50", pos=(100, 220), size=(150, 20)) 89 | wx.StaticText(self.panel1, -1, "标签代码:", pos=(30, 220), size=(60, 20)) 90 | wx.StaticText(self.panel1, -1, "股票时间:", pos=(30, 260), size=(60, 20)) 91 | self.startTime = wx.TextCtrl(self.panel1, value="2018-03-01", pos=(100, 260), size=(100, 20)) 92 | self.endTime = wx.TextCtrl(self.panel1, value="2019-05-03", pos=(230, 260), size=(100, 20)) 93 | Button1 = wx.Button(self.panel1, -1, "查找", (280, 215)) 94 | Button1.Bind(wx.EVT_BUTTON, self.redraw) 95 | 96 | plotter = plot.PlotCanvas(self.panel1) 97 | plotter.SetInitialSize(size=(500, 200)) 98 | code = self.code.GetValue() 99 | startTime = self.startTime.GetValue() 100 | endTime = self.endTime.GetValue() 101 | self.df = getDF(code, startTime, endTime) 102 | data = getData(self.df) 103 | line = plot.PolyLine(data, colour='red', width=1) 104 | 105 | gc = plot.PlotGraphics([line], 'sz50股票波动率', '时间', '波动率') 106 | plotter.Draw(gc) 107 | 108 | self.frame1.Show(True) 109 | 110 | def redraw(self, event): 111 | plotter = plot.PlotCanvas(self.panel1) 112 | plotter.SetInitialSize(size=(500, 200)) 113 | code = self.code.GetValue() 114 | startTime = self.startTime.GetValue() 115 | endTime = self.endTime.GetValue() 116 | self.df = getDF(code, startTime, endTime) 117 | for i in range(len(self.df)): 118 | data2 = getData(self.df[:i]) 119 | line = plot.PolyLine(data2, colour='red', width=1) 120 | gc = plot.PlotGraphics([line], code + '股票波动率', '时间', '波动率') 121 | plotter.Draw(gc) 122 | 123 | 124 | app = wx.App() 125 | f = MyFrame() 126 | app.MainLoop() 127 | --------------------------------------------------------------------------------