├── .project ├── .pydevproject ├── .settings └── org.eclipse.core.resources.prefs ├── README.md ├── classification ├── __init__.py ├── calc_profit.py ├── featureSelection.py └── predictStock.py ├── clawer ├── USAStock_clawer.py ├── __init__.py └── code_name.py ├── data_download.py ├── getData ├── __init__.py └── getTrainData.py └── 美股涨跌预测系统的探究-141499罗斌.docx /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | MLHomework34 4 | 5 | 6 | 7 | 8 | 9 | org.python.pydev.PyDevBuilder 10 | 11 | 12 | 13 | 14 | 15 | org.python.pydev.pythonNature 16 | 17 | 18 | -------------------------------------------------------------------------------- /.pydevproject: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | /${PROJECT_DIR_NAME} 5 | 6 | python 3.0 7 | C:\Python34\python.exe 8 | 9 | -------------------------------------------------------------------------------- /.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding//clawer/USAStock_clawer.py=utf-8 3 | encoding/=UTF-8 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stockPrediction 2 | 本项目对美股股票的涨跌进行了研究,从问题出发并提出猜想,然后定义了机器学习的实验任务。通过多次实验得到实验数据,最终证明了所提出的猜想:中国股市确实和美国股市存在着一定的联系,并且通过这些隐含的联系可以预测某些美国股票的涨跌。 3 | -------------------------------------------------------------------------------- /classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxls080511/stockPrediction/7ccb30b9f2b3537347e00b613ec0dbbaa97c6bc7/classification/__init__.py -------------------------------------------------------------------------------- /classification/calc_profit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | def calc(): 5 | log_out=open('I:\\MLHomework\\Data\\2\\clf8.log','w',encoding='utf-8')#log文件 6 | result_out=open('I:\\MLHomework\\Data\\2\\result8.log','w',encoding='utf-8')#log文件 7 | clf_dir='I:\\MLHomework\\Data\\2\\clf8' 8 | #先读入5-8月中国股市的数据 9 | chinese_stock={} 10 | with open('I:\\MLHomework\\Data\\4\\testData4m\\ASX.csv.data','r',encoding='utf-8') as f: 11 | for line in f: 12 | array=line.split('\t') 13 | date=array[0] 14 | v=int(array[len(array)-1])#最后的值不需要了 15 | l=list(0 for i in range(len(array)-2)) 16 | for i in range(1,len(array)-2): 17 | l[i]=int(array[i]) 18 | chinese_stock[date]=l 19 | 20 | us_stock_record={} 21 | for filename in os.listdir(clf_dir): 22 | us_stock_record[filename]={} 23 | with open('I:\\MLHomework\\stockData_american\\'+filename+'.csv','r',encoding='utf-8') as f: 24 | firstline=True#第一行不是数据 25 | for line in f: 26 | if(firstline): 27 | firstline=False 28 | continue 29 | else: 30 | array=line.strip().split(',') 31 | date=array[0] 32 | openValue=float(array[1]) 33 | closeValue=float(array[4]) 34 | l=list(0 for i in range(2)) 35 | l[0]=openValue 36 | l[1]=closeValue 37 | us_stock_record[filename][date]=l 38 | 39 | money=1000000.0 40 | #for循环,每一天 41 | for year in range(2015,2016): 42 | for month in range(5,9): 43 | for day in range(1,32): 44 | date=str(year)+'-' 45 | if(month<10): 46 | date+='0'+str(month)+'-' 47 | else: 48 | date+=str(month)+'-' 49 | 50 | if(day<10): 51 | date+='0'+str(day) 52 | else: 53 | date+=str(day) 54 | #此时的date是一个日期string 55 | 56 | if(date in chinese_stock.keys()): 57 | print(date+'\t'+str(money)) 58 | result_out.write(str(money)+'\n') 59 | log_out.write(date+'\t'+str(money)+'\n') 60 | #得到当天的中国股市情况 61 | l=chinese_stock[date] 62 | #对于高准确率的股票进行预测,并记下当天的开盘价收盘价 63 | us_stock={}#记录美股的预测情况 64 | buy_num=0#今天需要买入的股票数量 65 | for stock_name in os.listdir(clf_dir): 66 | record=list(0 for i in range(3)) 67 | #读入分类器 68 | with open(clf_dir+'\\'+stock_name,'br') as f: 69 | clf=pickle.load(f) 70 | result=clf.predict(l) 71 | if(1==result): 72 | buy_num+=1 73 | try: 74 | record[0]=int(result[0])#第一列存预测结果 75 | record[1]=us_stock_record[stock_name][date][0]#开盘价 76 | record[2]=us_stock_record[stock_name][date][1]#收盘价 77 | except: 78 | #print('error') 79 | #buy_num=0#如果出现异常,就不交易了,日期报错 80 | record[0]=0; 81 | if(1==result): 82 | buy_num-=1 83 | us_stock[stock_name]=record 84 | 85 | #计算今天的收益 86 | if(buy_num!=0): 87 | div_money=(money+0.0)/buy_num 88 | today_sum_money=0.0 89 | for stock_name in us_stock.keys(): 90 | if(us_stock[stock_name][0]==1): 91 | p=(us_stock[stock_name][2]+0.0)/us_stock[stock_name][1]#收益率 92 | earn=div_money*(p-1.0) 93 | today_sum_money+=div_money*p 94 | log_out.write(stock_name+'\t'+str(us_stock[stock_name][1])+'\t'+str(us_stock[stock_name][2])+'\t'+str(earn)+'\n') 95 | #以收盘价卖出所有股票 96 | money=today_sum_money 97 | log_out.close() 98 | result_out.close() 99 | return money 100 | 101 | 102 | 103 | #返回最后剩下的钱 104 | 105 | 106 | if __name__ == '__main__': 107 | print(calc()) -------------------------------------------------------------------------------- /classification/featureSelection.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn.feature_selection import RFE 3 | from sklearn.linear_model import LogisticRegression 4 | from sklearn.ensemble import RandomForestClassifier 5 | from sklearn import svm 6 | 7 | def test_data(clf,testFile,remove_list): 8 | sum=0 9 | right_sum=0 10 | with open(testFile,'r',encoding='utf-8') as f: 11 | for line in f: 12 | sum+=1 13 | array=line.split('\t') 14 | v=int(array[len(array)-1]) 15 | l=[] 16 | for i in range(1,len(array)-1): 17 | if(i-1 not in remove_list): 18 | l.append(int(array[i])) 19 | 20 | #print(len(l)) 21 | result=clf.predict(l) 22 | if(result==v): 23 | right_sum+=1 24 | #print(str(v)+'\t'+str(result)) 25 | print('precision='+str(right_sum/(sum+0.0))) 26 | print(str(right_sum)+'/'+str(sum)) 27 | return right_sum/(sum+0.0) 28 | 29 | if __name__ == '__main__': 30 | #读文件 31 | data=[] 32 | target=[] 33 | with open('I:\\MLHomework\\Data\\2\\trainData1y\\NOBGY.csv.data','r',encoding='utf-8') as f: 34 | for line in f: 35 | array=line.split('\t') 36 | #print(len(array)) 37 | l=list(0 for i in range(len(array)-2)) 38 | for i in range(1,len(array)-1): 39 | l[i-1]=int(array[i]) 40 | data.append(l) 41 | #print(len(l)) 42 | target.append(int(array[len(array)-1])) 43 | 44 | 45 | # create a base classifier used to evaluate a subset of attributes 46 | #model = RandomForestClassifier() 47 | remove_list=[] 48 | 49 | print('RandomForestClassifier') 50 | clf=RandomForestClassifier() 51 | clf.fit(data, target) 52 | test_data(clf,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list) 53 | print('SVM') 54 | clf11=svm.SVC(kernel='rbf') 55 | clf11.fit(data, target) 56 | test_data(clf11,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list) 57 | 58 | print('1') 59 | for i in range(0,len(clf.feature_importances_)): 60 | if(clf.feature_importances_[i]==0.0): 61 | remove_list.append(i) 62 | print(len(remove_list)) 63 | #读文件 64 | data2=[] 65 | target2=[] 66 | with open('I:\\MLHomework\\Data\\2\\trainData1y\\NOBGY.csv.data','r',encoding='utf-8') as f: 67 | for line in f: 68 | array=line.split('\t') 69 | l=[] 70 | for i in range(1,len(array)-2): 71 | if(i-1 not in remove_list): 72 | l.append(int(array[i])) 73 | data2.append(l) 74 | target2.append(int(array[len(array)-1])) 75 | 76 | print('RandomForestClassifier') 77 | clf2=RandomForestClassifier() 78 | clf2.fit(data2, target2) 79 | test_data(clf2,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list) 80 | print('SVM') 81 | clf22=svm.SVC(kernel='rbf') 82 | clf22.fit(data2, target2) 83 | test_data(clf22,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list) 84 | # for importance in clf.feature_importances_: 85 | # print(importance) 86 | 87 | -------------------------------------------------------------------------------- /classification/predictStock.py: -------------------------------------------------------------------------------- 1 | from sklearn.naive_bayes import GaussianNB 2 | from sklearn import tree 3 | from sklearn import svm 4 | from sklearn.ensemble import RandomForestClassifier 5 | from sklearn.naive_bayes import GaussianNB 6 | 7 | import os 8 | import pickle 9 | 10 | def train_data(filePath): 11 | data=[] 12 | target=[] 13 | with open(filePath,'r',encoding='utf-8') as f: 14 | for line in f: 15 | array=line.split('\t') 16 | l=list(0 for i in range(len(array)-2)) 17 | for i in range(1,len(array)-2): 18 | l[i]=int(array[i]) 19 | data.append(l) 20 | target.append(int(array[len(array)-1])) 21 | 22 | #clf=RandomForestClassifier() 23 | try: 24 | clf=svm.SVC(kernel='rbf') 25 | #clf=RandomForestClassifier() 26 | #clf=GaussianNB() 27 | #clf=tree.DecisionTreeClassifier() 28 | 29 | clf.fit(data, target) 30 | return clf 31 | except: 32 | return None 33 | 34 | 35 | 36 | 37 | def test_data(clf,testFile): 38 | sum=0 39 | right_sum=0 40 | with open(testFile,'r',encoding='utf-8') as f: 41 | for line in f: 42 | sum+=1 43 | array=line.split('\t') 44 | v=int(array[len(array)-1]) 45 | l=list(0 for i in range(len(array)-2)) 46 | for i in range(1,len(array)-2): 47 | l[i]=int(array[i]) 48 | result=clf.predict(l) 49 | if(result==v): 50 | right_sum+=1 51 | print(str(v)+'\t'+str(result)) 52 | print('precision='+str(right_sum/(sum+0.0))) 53 | print(str(right_sum)+'/'+str(sum)) 54 | return right_sum/(sum+0.0) 55 | 56 | def perdict_all(): 57 | trainDir='I:\\MLHomework\\\Data\\4\\trainData4m' 58 | testDir='I:\\MLHomework\\\Data\\4\\testData4m' 59 | clfDir='I:\\MLHomework\\\Data\\2\\clf8'#分类器保存位置 60 | 61 | out=open('I:\\MLHomework\\\Data\\2\\result_svm3.txt','w',encoding='utf-8') 62 | #数据统计 63 | sum=0 64 | pn_list=[0,0,0,0,0,0,0,0,0,0]#分别表示准确率50-60的,60-70的,70-80的,80-90的,90-100的数量 65 | #训练数据文件夹 66 | for filename in os.listdir(trainDir): 67 | #测试数据文件夹 68 | if(filename in os.listdir(testDir)): 69 | print(filename[0:len(filename)-9]) 70 | clf=train_data(trainDir+'\\'+filename) 71 | 72 | if(clf!=None): 73 | p=test_data(clf,testDir+'\\'+filename) 74 | sum+=1 75 | #把这个股票代码保存起来 76 | out.write(filename[0:len(filename)-9]+'\t'+str(p)+'\n') 77 | 78 | if(p>=0.8): 79 | #把分类器保存起来 80 | with open(clfDir+'\\'+filename[0:len(filename)-9],'bw') as f: 81 | pickle.dump(clf,f) 82 | # if(os.path.exists(trainDir+'\\'+filename)): 83 | # os.remove(trainDir+'\\'+filename) 84 | # if(os.path.exists(testDir+'\\'+filename)): 85 | # os.remove(testDir+'\\'+filename) 86 | if(p==1):#防止在1的时候数组出界 87 | p-=0.01 88 | 89 | pn_list[int((p*100)/10)]=pn_list[int((p*100)/10)]+1 90 | 91 | for i in range(10): 92 | print(str(i*10)+'~'+str(10+i*10)+'='+str(pn_list[i])+'\t\t\t占'+str(pn_list[i]/(sum+0.0))) 93 | out.close() 94 | 95 | def perdict_193Stock(clfDir,testDataPath,outFilePath): 96 | l=list(0 for i in range(2892))#一共中国股票是2892维/个 97 | with open(testDataPath,'r',encoding='utf-8') as f: 98 | for line in f: 99 | array=line.split('\t') 100 | for i in range(2892): 101 | l[i]=array[i] 102 | 103 | 104 | out=open(outFilePath,'w',encoding='utf-8') 105 | for filename in os.listdir(clfDir): 106 | print(filename) 107 | with open(clfDir+'\\'+filename,'br') as f: 108 | clf=pickle.load(f) 109 | result=clf.predict(l) 110 | out.write(filename+'\t'+str(result[0])+'\n') 111 | out.close() 112 | 113 | 114 | if __name__ == '__main__': 115 | #perdict_193Stock('I:\\MLHomework\\Data\\2\\clf','I:\\MLHomework\\1dayData.txt','I:\\MLHomework\\1dayDataResult.txt') 116 | perdict_all() -------------------------------------------------------------------------------- /clawer/USAStock_clawer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | import re 3 | import requests 4 | import urllib 5 | import time 6 | import os 7 | from bs4 import BeautifulSoup 8 | import re 9 | 10 | def get_encoding(text): 11 | try: 12 | return re.search('charset=\"?([^\"]+)"',text).group(1) 13 | except: 14 | return None 15 | 16 | 17 | def clawer_page(url): 18 | try: 19 | response = requests.get(url) 20 | text=response.text 21 | encoding=get_encoding(text) 22 | if(encoding is not None): 23 | response.encoding=encoding 24 | text=response.text 25 | return text 26 | except: 27 | print('error') 28 | return None 29 | 30 | 31 | 32 | 33 | if __name__ == '__main__': 34 | for i in range(8): 35 | url='http://data.tsci.com.cn/US/USCODE.aspx?M=3&First=All&Sid=&uhu=df&P='+str(i) 36 | page=clawer_page(url) 37 | out=open('D:\\programing\\Python\\MLHomework\\美股数据\\美国证券交易所'+str(i),'w',encoding='utf-8') 38 | out.write(page) 39 | out.close() 40 | print(i) 41 | -------------------------------------------------------------------------------- /clawer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxls080511/stockPrediction/7ccb30b9f2b3537347e00b613ec0dbbaa97c6bc7/clawer/__init__.py -------------------------------------------------------------------------------- /clawer/code_name.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def china(): 4 | out=open('D:\\programing\\Python\\MLHomework\\american_codeName.txt','w',encoding='utf-8') 5 | with open('D:\\programing\\Python\\MLHomework\\american_codeNameaaa.txt','r',encoding='utf-8') as f: 6 | for line in f: 7 | array=line.strip().split('\t') 8 | match=re.search('([^(]+)\\(([^\\)]+)\\)',array[0]) 9 | name=match.group(1) 10 | code=match.group(2) 11 | market=array[2] 12 | out.write(name+'\t'+code+'\t'+market+'\n') 13 | out.close() 14 | 15 | 16 | def american(): 17 | out=open('D:\\programing\\Python\\MLHomework\\american_codeName2.txt','w',encoding='utf-8') 18 | with open('D:\\programing\\Python\\MLHomework\\美股抽取.txt','r',encoding='utf-8') as f: 19 | for line in f: 20 | array=line.strip().split('\t') 21 | name='' 22 | code='' 23 | market='' 24 | if('.NASDQ' in array[0]): 25 | array1=array[0].split('.NASDQ') 26 | market='NASDQ' 27 | name=array1[1].strip() 28 | code=array1[0].strip() 29 | if('.NYSE' in array[0]): 30 | array1=array[0].split('.NYSE') 31 | market='NYSE' 32 | name=array1[1].strip() 33 | code=array1[0].strip() 34 | if('.AMEX' in array[0]): 35 | array1=array[0].split('.AMEX') 36 | market='AMEX' 37 | name=array1[1].strip() 38 | code=array1[0].strip() 39 | 40 | out.write(name+'\t'+code+'\t'+market+'\n') 41 | out.close() 42 | 43 | if __name__ == '__main__': 44 | china() 45 | american() -------------------------------------------------------------------------------- /data_download.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | import urllib 3 | import urllib2 4 | import requests 5 | import codecs 6 | import socket 7 | 8 | def download(): 9 | with codecs.open('D:\\programing\\Python\\MLHomework\\china_codeName.txt','r','utf-8') as f: 10 | for line in f: 11 | array=line.strip().split('\t') 12 | name=array[0] 13 | code=array[1] 14 | market=array[2] 15 | if(u'上证'==market): 16 | print code+u'上证' 17 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.ss' 18 | urllib.urlretrieve(url, 'D:\\programing\\Python\\MLHomework\\stockData_china\\'+code+'.ss.csv') 19 | elif(u'深成'==market): 20 | print code+u'深成' 21 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.sz' 22 | urllib.urlretrieve(url, 'D:\\programing\\Python\\MLHomework\\stockData_china\\'+code+'.sz.csv') 23 | 24 | def american_download(): 25 | #socket.setdefaulttimeout(15) 26 | n=0 27 | with codecs.open('I:\\MLHomework\\american_codeName.txt','r','utf-8') as f: 28 | for line in f: 29 | array=line.strip().split('\t') 30 | name=array[0] 31 | code=array[1] 32 | market=array[2] 33 | 34 | n+=1 35 | print code,n 36 | url='http://table.finance.yahoo.com/table.csv?s='+code+'&d=8&e=24&f=2015&g=d&a=1&b=1&c=2010&ignore=.csv' 37 | try: 38 | urllib.urlretrieve(url, 'I:\\MLHomework\\stockData_american1\\'+code+'.csv') 39 | except: 40 | print 'error' 41 | 42 | 43 | def china_download_day(seqFilePath,downloadDir,outDataPath,y,m,d): 44 | data=[] 45 | out=codecs.open(outDataPath,'w','utf-8')#输出data数据 46 | n=0 47 | with codecs.open(seqFilePath,'r','utf-8') as f: 48 | for line in f: 49 | array=line.strip().split('.') 50 | code=array[0] 51 | market=array[1] 52 | print(code+market+'\t'+str(n)) 53 | n+=1 54 | savedFilePath='' 55 | if(u'ss'==market): 56 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.ss&d='+str(m-1)+'&e='+str(d)+'&f='+str(y)+'&g=d&a='+str(m-1)+'&b='+str(d)+'&c='+str(y)+'&ignore=.csv' 57 | urllib.urlretrieve(url, downloadDir+'\\'+code+'.ss.csv') 58 | savedFilePath=downloadDir+'\\'+code+'.ss.csv' 59 | elif(u'sz'==market): 60 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.sz&d='+str(m-1)+'&e='+str(d)+'&f='+str(y)+'&g=d&a='+str(m-1)+'&b='+str(d)+'&c='+str(y)+'&ignore=.csv' 61 | urllib.urlretrieve(url, downloadDir+'\\'+code+'.sz.csv') 62 | savedFilePath=downloadDir+'\\'+code+'.sz.csv' 63 | #下载完成之后就可以得到这个文件中的数据 64 | with codecs.open(savedFilePath,'r','utf-8') as f: 65 | firstline=True 66 | for line in f: 67 | if(firstline): 68 | firstline=False 69 | continue 70 | else:#那些error的会出错 最后break就行 71 | rise=1#上涨 72 | date='' 73 | try: 74 | array=line.strip().split(',') 75 | date=array[0] 76 | openValue=float(array[1]) 77 | closeValue=float(array[4]) 78 | if(closeValue