├── .project
├── .pydevproject
├── .settings
│   └── org.eclipse.core.resources.prefs
├── README.md
├── classification
│   ├── __init__.py
│   ├── calc_profit.py
│   ├── featureSelection.py
│   └── predictStock.py
├── clawer
│   ├── USAStock_clawer.py
│   ├── __init__.py
│   └── code_name.py
├── data_download.py
├── getData
│   ├── __init__.py
│   └── getTrainData.py
└── 美股涨跌预测系统的探究-141499罗斌.docx
/.project:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <projectDescription>
3 | 	<name>MLHomework34</name>
4 | 	<comment></comment>
5 | 	<projects>
6 | 	</projects>
7 | 	<buildSpec>
8 | 		<buildCommand>
9 | 			<name>org.python.pydev.PyDevBuilder</name>
10 | 			<arguments>
11 | 			</arguments>
12 | 		</buildCommand>
13 | 	</buildSpec>
14 | 	<natures>
15 | 		<nature>org.python.pydev.pythonNature</nature>
16 | 	</natures>
17 | </projectDescription>
18 | 
--------------------------------------------------------------------------------
/.pydevproject:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8" standalone="no"?>
2 | <?eclipse-pydev version="1.0"?><pydev_project>
3 | <pydev_pathproperty name="org.python.pydev.PROJECT_SOURCE_PATH">
4 | <path>/${PROJECT_DIR_NAME}</path>
5 | </pydev_pathproperty>
6 | <pydev_property name="org.python.pydev.PYTHON_PROJECT_VERSION">python 3.0</pydev_property>
7 | <pydev_property name="org.python.pydev.PYTHON_PROJECT_INTERPRETER">C:\Python34\python.exe</pydev_property>
8 | </pydev_project>
9 | 
--------------------------------------------------------------------------------
/.settings/org.eclipse.core.resources.prefs:
--------------------------------------------------------------------------------
1 | eclipse.preferences.version=1
2 | encoding//clawer/USAStock_clawer.py=utf-8
3 | encoding/=UTF-8
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # stockPrediction
2 | This project studies the daily rise and fall of US stocks. Starting from a question, it formulates a hypothesis and then defines a machine-learning task to test it. Data gathered over repeated experiments support the hypothesis: the Chinese and US stock markets are linked to some degree, and these latent links can be used to predict the movement of certain US stocks.
3 |
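4 | A minimal sketch of the intended pipeline (an illustration, not the exact scripts under `classification/`; file paths and the stock name are placeholders): each line of a training file is tab-separated as `date`, the day's Chinese-market features, and a 0/1 label saying whether the target US stock rose.
5 | 
6 | ```python
7 | from sklearn import svm
8 | 
9 | def load(path):
10 |     data, target = [], []
11 |     with open(path, encoding='utf-8') as f:
12 |         for line in f:
13 |             cols = line.rstrip('\n').split('\t')
14 |             data.append([int(x) for x in cols[1:-1]])  # Chinese-market features
15 |             target.append(int(cols[-1]))               # US stock up/down label
16 |     return data, target
17 | 
18 | X_train, y_train = load('trainData/SOME_US_STOCK.csv.data')  # placeholder paths
19 | X_test, y_test = load('testData/SOME_US_STOCK.csv.data')
20 | clf = svm.SVC(kernel='rbf').fit(X_train, y_train)
21 | print(sum(clf.predict(X_test) == y_test) / len(y_test))      # accuracy on the test months
22 | ```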
--------------------------------------------------------------------------------
/classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxls080511/stockPrediction/7ccb30b9f2b3537347e00b613ec0dbbaa97c6bc7/classification/__init__.py
--------------------------------------------------------------------------------
/classification/calc_profit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | def calc():#back-test: for every day covered by the Chinese data, buy each US stock whose classifier predicts a rise, split the capital equally, and sell at the close
5 | log_out=open('I:\\MLHomework\\Data\\2\\clf8.log','w',encoding='utf-8')#detailed trade log
6 | result_out=open('I:\\MLHomework\\Data\\2\\result8.log','w',encoding='utf-8')#daily balance log
7 | clf_dir='I:\\MLHomework\\Data\\2\\clf8'#directory of pickled per-stock classifiers
8 | #first read in the Chinese stock-market data for May-August
9 | chinese_stock={}
10 | with open('I:\\MLHomework\\Data\\4\\testData4m\\ASX.csv.data','r',encoding='utf-8') as f:
11 | for line in f:
12 | array=line.split('\t')
13 | date=array[0]
14 | v=int(array[len(array)-1])#the trailing label is not needed here
15 | l=list(0 for i in range(len(array)-2))
16 | for i in range(1,len(array)-1):#columns 1..n-2 hold the feature values
17 | l[i-1]=int(array[i])
18 | chinese_stock[date]=l
19 |
20 | us_stock_record={}
21 | for filename in os.listdir(clf_dir):
22 | us_stock_record[filename]={}
23 | with open('I:\\MLHomework\\stockData_american\\'+filename+'.csv','r',encoding='utf-8') as f:
24 | firstline=True#the first line is a CSV header, not data
25 | for line in f:
26 | if(firstline):
27 | firstline=False
28 | continue
29 | else:
30 | array=line.strip().split(',')
31 | date=array[0]
32 | openValue=float(array[1])
33 | closeValue=float(array[4])
34 | l=list(0 for i in range(2))
35 | l[0]=openValue
36 | l[1]=closeValue
37 | us_stock_record[filename][date]=l
38 |
39 | money=1000000.0#starting capital
40 | #loop over every calendar day in the back-test window
41 | for year in range(2015,2016):
42 | for month in range(5,9):
43 | for day in range(1,32):
44 | date=str(year)+'-'
45 | if(month<10):
46 | date+='0'+str(month)+'-'
47 | else:
48 | date+=str(month)+'-'
49 |
50 | if(day<10):
51 | date+='0'+str(day)
52 | else:
53 | date+=str(day)
54 | #date is now a yyyy-mm-dd string
55 |
56 | if(date in chinese_stock.keys()):
57 | print(date+'\t'+str(money))
58 | result_out.write(str(money)+'\n')
59 | log_out.write(date+'\t'+str(money)+'\n')
60 | #get the Chinese-market feature vector for this day
61 | l=chinese_stock[date]
62 | #run every high-accuracy classifier and note the day's open and close prices
63 | us_stock={}#prediction record for each US stock
64 | buy_num=0#number of stocks to buy today
65 | for stock_name in os.listdir(clf_dir):
66 | record=list(0 for i in range(3))
67 | #load the pickled classifier for this stock
68 | with open(clf_dir+'\\'+stock_name,'br') as f:
69 | clf=pickle.load(f)
70 | result=clf.predict([l])#predict expects a 2-D array: one row per sample
71 | if(1==result[0]):
72 | buy_num+=1
73 | try:
74 | record[0]=int(result[0])#column 0 stores the predicted label
75 | record[1]=us_stock_record[stock_name][date][0]#opening price
76 | record[2]=us_stock_record[stock_name][date][1]#closing price
77 | except KeyError:#no US quote for this date (market holiday or missing data)
78 | #print('error')
79 | #buy_num=0#if an exception occurs, skip trading; the date is invalid
80 | record[0]=0
81 | if(1==result[0]):
82 | buy_num-=1
83 | us_stock[stock_name]=record
84 |
85 | #compute today's profit
86 | if(buy_num!=0):
87 | div_money=(money+0.0)/buy_num
88 | today_sum_money=0.0
89 | for stock_name in us_stock.keys():
90 | if(us_stock[stock_name][0]==1):
91 | p=(us_stock[stock_name][2]+0.0)/us_stock[stock_name][1]#return ratio: close/open
92 | earn=div_money*(p-1.0)
93 | today_sum_money+=div_money*p
94 | log_out.write(stock_name+'\t'+str(us_stock[stock_name][1])+'\t'+str(us_stock[stock_name][2])+'\t'+str(earn)+'\n')
95 | #sell every position at the closing price
96 | money=today_sum_money
97 | log_out.close()
98 | result_out.close()
99 | return money
100 |
101 |
102 |
103 | #returns the money remaining at the end of the back-test
104 |
105 |
106 | if __name__ == '__main__':
107 | print(calc())
--------------------------------------------------------------------------------
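The back-test above relies on two conventions that are easy to miss: each `.data` line is tab-separated as `date`, integer Chinese-market features, then a trailing label, and each file in `clf8` is a pickled scikit-learn classifier queried one day at a time. A minimal sketch of that contract (the sample row, directory and stock name are illustrative, not taken from the repository):

```python
import pickle

# assumed .data line layout: date<TAB>f1<TAB>...<TAB>fn<TAB>label
line = '2015-06-01\t1\t0\t1\t0\t1\n'        # invented sample row
cols = line.rstrip('\n').split('\t')
features = [int(x) for x in cols[1:-1]]      # drop the date and the label

# load one per-stock classifier saved by predictStock.py (path is a placeholder)
with open('clf8/SOMESTOCK', 'rb') as f:      # hypothetical stock code
    clf = pickle.load(f)

# scikit-learn's predict() takes a 2-D array: one row per sample
predicted_up = int(clf.predict([features])[0]) == 1
print(predicted_up)
```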
/classification/featureSelection.py:
--------------------------------------------------------------------------------
1 | from sklearn import datasets
2 | from sklearn.feature_selection import RFE
3 | from sklearn.linear_model import LogisticRegression
4 | from sklearn.ensemble import RandomForestClassifier
5 | from sklearn import svm
6 |
7 | def test_data(clf,testFile,remove_list):
8 | sum=0
9 | right_sum=0
10 | with open(testFile,'r',encoding='utf-8') as f:
11 | for line in f:
12 | sum+=1
13 | array=line.split('\t')
14 | v=int(array[len(array)-1])
15 | l=[]
16 | for i in range(1,len(array)-1):
17 | if(i-1 not in remove_list):
18 | l.append(int(array[i]))
19 |
20 | #print(len(l))
21 | result=clf.predict([l])[0]#predict expects a 2-D array: one row per sample
22 | if(result==v):
23 | right_sum+=1
24 | #print(str(v)+'\t'+str(result))
25 | print('precision='+str(right_sum/(sum+0.0)))
26 | print(str(right_sum)+'/'+str(sum))
27 | return right_sum/(sum+0.0)
28 |
29 | if __name__ == '__main__':
30 | #read the training file
31 | data=[]
32 | target=[]
33 | with open('I:\\MLHomework\\Data\\2\\trainData1y\\NOBGY.csv.data','r',encoding='utf-8') as f:
34 | for line in f:
35 | array=line.split('\t')
36 | #print(len(array))
37 | l=list(0 for i in range(len(array)-2))
38 | for i in range(1,len(array)-1):
39 | l[i-1]=int(array[i])
40 | data.append(l)
41 | #print(len(l))
42 | target.append(int(array[len(array)-1]))
43 |
44 |
45 | # create a base classifier used to evaluate a subset of attributes
46 | #model = RandomForestClassifier()
47 | remove_list=[]
48 |
49 | print('RandomForestClassifier')
50 | clf=RandomForestClassifier()
51 | clf.fit(data, target)
52 | test_data(clf,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list)
53 | print('SVM')
54 | clf11=svm.SVC(kernel='rbf')
55 | clf11.fit(data, target)
56 | test_data(clf11,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list)
57 |
58 | #collect the indices of features the random forest never used
59 | for i in range(0,len(clf.feature_importances_)):
60 | if(clf.feature_importances_[i]==0.0):
61 | remove_list.append(i)
62 | print(str(len(remove_list))+' zero-importance features removed')
63 | #re-read the training data, dropping the zero-importance features
64 | data2=[]
65 | target2=[]
66 | with open('I:\\MLHomework\\Data\\2\\trainData1y\\NOBGY.csv.data','r',encoding='utf-8') as f:
67 | for line in f:
68 | array=line.split('\t')
69 | l=[]
70 | for i in range(1,len(array)-1):#must match test_data: columns 1..n-2 are features
71 | if(i-1 not in remove_list):
72 | l.append(int(array[i]))
73 | data2.append(l)
74 | target2.append(int(array[len(array)-1]))
75 |
76 | print('RandomForestClassifier')
77 | clf2=RandomForestClassifier()
78 | clf2.fit(data2, target2)
79 | test_data(clf2,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list)
80 | print('SVM')
81 | clf22=svm.SVC(kernel='rbf')
82 | clf22.fit(data2, target2)
83 | test_data(clf22,'I:\\MLHomework\\Data\\2\\testData8m\\NOBGY.csv.data',remove_list)
84 | # for importance in clf.feature_importances_:
85 | # print(importance)
86 |
87 |
--------------------------------------------------------------------------------
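The script above drops zero-importance features by hand with `remove_list`. A more compact alternative (a sketch on assumed toy data, not part of the repository; the threshold value is an assumption) is scikit-learn's `SelectFromModel`, which applies the same importance-based mask to both training and test rows:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel

# tiny synthetic stand-in for the tab-separated training data
data = [[1, 0, 0], [0, 0, 1], [1, 0, 1], [0, 0, 0]]   # the middle column carries no signal
target = [1, 0, 1, 0]

forest = RandomForestClassifier(random_state=0).fit(data, target)

# keep only features whose importance is (effectively) above zero
selector = SelectFromModel(forest, threshold=1e-12, prefit=True)
data_reduced = selector.transform(data)

clf = RandomForestClassifier(random_state=0).fit(data_reduced, target)
# at test time, apply the same mask before calling predict()
print(clf.predict(selector.transform([[1, 0, 0]])))
```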
/classification/predictStock.py:
--------------------------------------------------------------------------------
1 | from sklearn.naive_bayes import GaussianNB
2 | from sklearn import tree
3 | from sklearn import svm
4 | from sklearn.ensemble import RandomForestClassifier
5 |
6 |
7 | import os
8 | import pickle
9 |
10 | def train_data(filePath):#train one classifier from a tab-separated file of date, Chinese-market features and a 0/1 label
11 | data=[]
12 | target=[]
13 | with open(filePath,'r',encoding='utf-8') as f:
14 | for line in f:
15 | array=line.split('\t')
16 | l=list(0 for i in range(len(array)-2))
17 | for i in range(1,len(array)-1):#columns 1..n-2 hold the feature values
18 | l[i-1]=int(array[i])
19 | data.append(l)
20 | target.append(int(array[len(array)-1]))
21 |
22 | #clf=RandomForestClassifier()
23 | try:
24 | clf=svm.SVC(kernel='rbf')
25 | #clf=RandomForestClassifier()
26 | #clf=GaussianNB()
27 | #clf=tree.DecisionTreeClassifier()
28 |
29 | clf.fit(data, target)
30 | return clf
31 | except:#fit can fail, e.g. when the file is empty or contains a single class
32 | return None
33 |
34 |
35 |
36 |
37 | def test_data(clf,testFile):#return the classifier's accuracy on a test file with the same layout
38 | sum=0
39 | right_sum=0
40 | with open(testFile,'r',encoding='utf-8') as f:
41 | for line in f:
42 | sum+=1
43 | array=line.split('\t')
44 | v=int(array[len(array)-1])
45 | l=list(0 for i in range(len(array)-2))
46 | for i in range(1,len(array)-1):#must mirror train_data
47 | l[i-1]=int(array[i])
48 | result=clf.predict([l])[0]#predict expects a 2-D array: one row per sample
49 | if(result==v):
50 | right_sum+=1
51 | print(str(v)+'\t'+str(result))
52 | print('precision='+str(right_sum/(sum+0.0)))
53 | print(str(right_sum)+'/'+str(sum))
54 | return right_sum/(sum+0.0)
55 |
56 | def predict_all():#train, evaluate and persist a classifier for every US stock
57 | trainDir='I:\\MLHomework\\Data\\4\\trainData4m'
58 | testDir='I:\\MLHomework\\Data\\4\\testData4m'
59 | clfDir='I:\\MLHomework\\Data\\2\\clf8'#where trained classifiers are saved
60 | 
61 | out=open('I:\\MLHomework\\Data\\2\\result_svm3.txt','w',encoding='utf-8')
62 | #summary statistics
63 | sum=0
64 | pn_list=[0,0,0,0,0,0,0,0,0,0]#counts of stocks whose test accuracy falls in each 10% bucket (0-10%, ..., 90-100%)
65 | #iterate over the training-data folder
66 | for filename in os.listdir(trainDir):
67 | #keep only stocks that also have a file in the test-data folder
68 | if(filename in os.listdir(testDir)):
69 | print(filename[0:len(filename)-9])#strip the '.csv.data' suffix to get the stock code
70 | clf=train_data(trainDir+'\\'+filename)
71 |
72 | if(clf!=None):
73 | p=test_data(clf,testDir+'\\'+filename)
74 | sum+=1
75 | #record this stock code and its test accuracy
76 | out.write(filename[0:len(filename)-9]+'\t'+str(p)+'\n')
77 |
78 | if(p>=0.8):
79 | #persist classifiers with accuracy >= 0.8 for the back-test in calc_profit.py
80 | with open(clfDir+'\\'+filename[0:len(filename)-9],'bw') as f:
81 | pickle.dump(clf,f)
82 | # if(os.path.exists(trainDir+'\\'+filename)):
83 | # os.remove(trainDir+'\\'+filename)
84 | # if(os.path.exists(testDir+'\\'+filename)):
85 | # os.remove(testDir+'\\'+filename)
86 | if(p==1):#avoid an index-out-of-range when the accuracy is exactly 1
87 | p-=0.01
88 |
89 | pn_list[int((p*100)/10)]=pn_list[int((p*100)/10)]+1
90 |
91 | for i in range(10):
92 | print(str(i*10)+'~'+str(10+i*10)+'='+str(pn_list[i])+'\t\t\tshare='+str(pn_list[i]/(sum+0.0)))
93 | out.close()
94 |
95 | def predict_193Stock(clfDir,testDataPath,outFilePath):#predict one day's movement with every saved classifier
96 | l=list(0 for i in range(2892))#the Chinese-market feature vector has 2892 dimensions, one per Chinese stock
97 | with open(testDataPath,'r',encoding='utf-8') as f:
98 | for line in f:
99 | array=line.split('\t')
100 | for i in range(2892):
101 | l[i]=array[i]
102 |
103 |
104 | out=open(outFilePath,'w',encoding='utf-8')
105 | for filename in os.listdir(clfDir):
106 | print(filename)
107 | with open(clfDir+'\\'+filename,'br') as f:
108 | clf=pickle.load(f)
109 | result=clf.predict([l])#predict expects a 2-D array: one row per sample
110 | out.write(filename+'\t'+str(result[0])+'\n')
111 | out.close()
112 |
113 |
114 | if __name__ == '__main__':
115 | #predict_193Stock('I:\\MLHomework\\Data\\2\\clf','I:\\MLHomework\\1dayData.txt','I:\\MLHomework\\1dayDataResult.txt')
116 | predict_all()
--------------------------------------------------------------------------------
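predict_all() keeps a classifier only when its test accuracy reaches 0.8, and calc_profit.py later reloads those classifiers by listing the shared directory. A minimal sketch of that hand-off (the directory name, toy data and stock code are placeholders, and the accuracy here is measured on the training rows only for brevity):

```python
import os
import pickle
from sklearn import svm

clf_dir = 'clf8'                                # placeholder for the shared classifier directory
os.makedirs(clf_dir, exist_ok=True)

# toy data standing in for one stock's .data file
X, y = [[1, 0], [0, 1], [1, 1], [0, 0]], [1, 0, 1, 0]
clf = svm.SVC(kernel='rbf').fit(X, y)

accuracy = sum(clf.predict(X) == y) / len(y)
if accuracy >= 0.8:                             # same threshold as predict_all()
    with open(os.path.join(clf_dir, 'SOMESTOCK'), 'wb') as f:   # hypothetical stock code
        pickle.dump(clf, f)

# the back-test later reloads every pickled classifier found in the directory
for name in os.listdir(clf_dir):
    with open(os.path.join(clf_dir, name), 'rb') as f:
        reloaded = pickle.load(f)
    print(name, reloaded.predict([[1, 0]])[0])
```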
/clawer/USAStock_clawer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*- #
2 | import re
3 | import requests
4 | import urllib
5 | import time
6 | import os
7 | from bs4 import BeautifulSoup
8 |
9 |
10 | def get_encoding(text):#sniff the charset declared in the page's meta tag; None if not found
11 | try:
12 | return re.search('charset=\"?([^\"]+)"',text).group(1)
13 | except:
14 | return None
15 |
16 |
17 | def clawer_page(url):#download one page and re-decode it with the charset the page declares
18 | try:
19 | response = requests.get(url)
20 | text=response.text
21 | encoding=get_encoding(text)
22 | if(encoding is not None):
23 | response.encoding=encoding
24 | text=response.text
25 | return text
26 | except:
27 | print('error')
28 | return None
29 |
30 |
31 |
32 |
33 | if __name__ == '__main__':
34 | for i in range(8):
35 | url='http://data.tsci.com.cn/US/USCODE.aspx?M=3&First=All&Sid=&uhu=df&P='+str(i)
36 | page=clawer_page(url)
37 | out=open('D:\\programing\\Python\\MLHomework\\美股数据\\美国证券交易所'+str(i),'w',encoding='utf-8')
38 | out.write(page if page is not None else '')#clawer_page returns None on failure; write an empty file instead of crashing
39 | out.close()
40 | print(i)
41 |
--------------------------------------------------------------------------------
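The crawler re-decodes each response with whatever charset the page itself declares. A small, self-contained illustration of that sniffing step (the HTML snippet is invented for the example):

```python
import re

def get_encoding(text):
    # same pattern as the crawler: the charset value must be followed by a closing quote
    try:
        return re.search('charset=\"?([^\"]+)"', text).group(1)
    except AttributeError:
        return None

sample = '<meta http-equiv="Content-Type" content="text/html; charset=gb2312">'
print(get_encoding(sample))           # -> gb2312
print(get_encoding('<html></html>'))  # -> None (no declaration, so requests' guess is kept)
```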
/clawer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxls080511/stockPrediction/7ccb30b9f2b3537347e00b613ec0dbbaa97c6bc7/clawer/__init__.py
--------------------------------------------------------------------------------
/clawer/code_name.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def china():#re-format 'Name(CODE)' entries into name<TAB>code<TAB>market lines
4 | out=open('D:\\programing\\Python\\MLHomework\\american_codeName.txt','w',encoding='utf-8')
5 | with open('D:\\programing\\Python\\MLHomework\\american_codeNameaaa.txt','r',encoding='utf-8') as f:
6 | for line in f:
7 | array=line.strip().split('\t')
8 | match=re.search('([^(]+)\\(([^\\)]+)\\)',array[0])
9 | name=match.group(1)
10 | code=match.group(2)
11 | market=array[2]
12 | out.write(name+'\t'+code+'\t'+market+'\n')
13 | out.close()
14 |
15 |
16 | def american():#split '.NASDQ'/'.NYSE'/'.AMEX' entries scraped from the US listing pages into name<TAB>code<TAB>market lines
17 | out=open('D:\\programing\\Python\\MLHomework\\american_codeName2.txt','w',encoding='utf-8')
18 | with open('D:\\programing\\Python\\MLHomework\\美股抽取.txt','r',encoding='utf-8') as f:
19 | for line in f:
20 | array=line.strip().split('\t')
21 | name=''
22 | code=''
23 | market=''
24 | if('.NASDQ' in array[0]):
25 | array1=array[0].split('.NASDQ')
26 | market='NASDQ'
27 | name=array1[1].strip()
28 | code=array1[0].strip()
29 | if('.NYSE' in array[0]):
30 | array1=array[0].split('.NYSE')
31 | market='NYSE'
32 | name=array1[1].strip()
33 | code=array1[0].strip()
34 | if('.AMEX' in array[0]):
35 | array1=array[0].split('.AMEX')
36 | market='AMEX'
37 | name=array1[1].strip()
38 | code=array1[0].strip()
39 |
40 | out.write(name+'\t'+code+'\t'+market+'\n')
41 | out.close()
42 |
43 | if __name__ == '__main__':
44 | china()
45 | american()
--------------------------------------------------------------------------------
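Both helpers only re-shape scraped text into `name<TAB>code<TAB>market` rows. A quick illustration of the `Name(CODE)` regular expression used by `china()` (the sample row is invented):

```python
import re

line = '阿里巴巴(BABA)\t-\tNYSE'   # invented sample row: name(code), filler column, market
array = line.strip().split('\t')
match = re.search('([^(]+)\\(([^\\)]+)\\)', array[0])
print(match.group(1), match.group(2), array[2])   # -> 阿里巴巴 BABA NYSE
```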
/data_download.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*- #
2 | import urllib
3 | import urllib2
4 | import requests
5 | import codecs
6 | import socket
7 |
8 | def download():#Python 2: fetch each Chinese stock's full daily history from Yahoo Finance
9 | with codecs.open('D:\\programing\\Python\\MLHomework\\china_codeName.txt','r','utf-8') as f:
10 | for line in f:
11 | array=line.strip().split('\t')
12 | name=array[0]
13 | code=array[1]
14 | market=array[2]
15 | if(u'上证'==market):#Shanghai exchange, '.ss' ticker suffix
16 | print code+u'上证'
17 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.ss'
18 | urllib.urlretrieve(url, 'D:\\programing\\Python\\MLHomework\\stockData_china\\'+code+'.ss.csv')
19 | elif(u'深成'==market):#Shenzhen exchange, '.sz' ticker suffix
20 | print code+u'深成'
21 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.sz'
22 | urllib.urlretrieve(url, 'D:\\programing\\Python\\MLHomework\\stockData_china\\'+code+'.sz.csv')
23 |
24 | def american_download():#Python 2: fetch each US stock's 2010-2015 daily history from Yahoo Finance
25 | #socket.setdefaulttimeout(15)
26 | n=0
27 | with codecs.open('I:\\MLHomework\\american_codeName.txt','r','utf-8') as f:
28 | for line in f:
29 | array=line.strip().split('\t')
30 | name=array[0]
31 | code=array[1]
32 | market=array[2]
33 |
34 | n+=1
35 | print code,n
36 | url='http://table.finance.yahoo.com/table.csv?s='+code+'&d=8&e=24&f=2015&g=d&a=1&b=1&c=2010&ignore=.csv'
37 | try:
38 | urllib.urlretrieve(url, 'I:\\MLHomework\\stockData_american1\\'+code+'.csv')
39 | except:
40 | print 'error'
41 |
42 |
43 | def china_download_day(seqFilePath,downloadDir,outDataPath,y,m,d):#Python 2: download one day's quotes for every Chinese stock and assemble a feature line
44 | data=[]
45 | out=codecs.open(outDataPath,'w','utf-8')#output file for the assembled data
46 | n=0
47 | with codecs.open(seqFilePath,'r','utf-8') as f:
48 | for line in f:
49 | array=line.strip().split('.')
50 | code=array[0]
51 | market=array[1]
52 | print(code+market+'\t'+str(n))
53 | n+=1
54 | savedFilePath=''
55 | if(u'ss'==market):
56 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.ss&d='+str(m-1)+'&e='+str(d)+'&f='+str(y)+'&g=d&a='+str(m-1)+'&b='+str(d)+'&c='+str(y)+'&ignore=.csv'
57 | urllib.urlretrieve(url, downloadDir+'\\'+code+'.ss.csv')
58 | savedFilePath=downloadDir+'\\'+code+'.ss.csv'
59 | elif(u'sz'==market):
60 | url='http://table.finance.yahoo.com/table.csv?s='+code+'.sz&d='+str(m-1)+'&e='+str(d)+'&f='+str(y)+'&g=d&a='+str(m-1)+'&b='+str(d)+'&c='+str(y)+'&ignore=.csv'
61 | urllib.urlretrieve(url, downloadDir+'\\'+code+'.sz.csv')
62 | savedFilePath=downloadDir+'\\'+code+'.sz.csv'
63 | #once the download finishes, read the quotes back out of the saved file
64 | with codecs.open(savedFilePath,'r','utf-8') as f:
65 | firstline=True
66 | for line in f:
67 | if(firstline):
68 | firstline=False
69 | continue
70 | else:#files that failed to download will raise inside the try; just break at the end
71 | rise=1#1 means the stock rose
72 | date=''
73 | try:
74 | array=line.strip().split(',')
75 | date=array[0]
76 | openValue=float(array[1])
77 | closeValue=float(array[4])
78 | if(closeValue