# Repository layout:
#
# ├── cnn.py
# ├── data.py
# ├── rnn.py
# ├── store
# │   ├── 2018-12-15.zip
# │   ├── 2018-12-16.zip
# │   ├── 2018-12-17.zip
# │   ├── 2018-12-18.zip
# │   ├── 2018-12-19.zip
# │   ├── 2018-12-20.zip
# │   ├── 2018-12-21.zip
# │   ├── 2018-12-24.zip
# │   ├── 2018-12-25.zip
# │   ├── 2018-12-26.zip
# │   ├── 2018-12-27.zip
# │   └── 2018-12-28.zip
# └── test.py
#
# /cnn.py:
# --------------------------------------------------------------------------------
"""Train and evaluate a two-layer CNN on the samples stored in ./nndata.

The script uses the TensorFlow 1.x graph API. It loads ``list.data`` and
``label.data``, holds out the last ``TEST_LEN`` rows for testing, trains
with plain gradient descent and prints the accuracy on the held-out rows.
"""
import tensorflow as tf
import os
import json
import numpy as np

# Silence TensorFlow's C++ INFO and WARNING log lines.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

TEST_LEN = 5000   # number of trailing rows held out for testing
EPOCHS = 10       # full passes over the training rows
LOG_EVERY = 50    # report the training accuracy every LOG_EVERY batches


def read(name):
    """Return the JSON value stored in ``./nndata/<name>.data``.

    The file is opened in a ``with`` block so the handle is released even
    when decoding fails.
    """
    with open('./nndata/' + name + '.data', 'r') as f:
        return json.load(f)


# Load the data set. The names avoid shadowing the builtin ``list``.
samples = np.array(read('list'))
labels = np.array(read('label'))

train_list = samples[:len(samples) - TEST_LEN]
train_label = labels[:len(labels) - TEST_LEN]

test_list = samples[-TEST_LEN:]
test_label = labels[-TEST_LEN:]

print('数据加载完毕')
width = 48        # values per input row
height = 48       # rows per input sample
n_classes = 4     # length of each label vector
batch_size = 48   # samples per training batch
n_batch = (len(labels) - TEST_LEN) // batch_size  # batches per epoch

x = tf.placeholder(tf.float32, [None, width * height])           # flattened inputs
y_actual = tf.placeholder(tf.float32, shape=[None, n_classes])   # target label vectors


def weight_variable(shape):
    """Create a weight variable initialised from a truncated normal (stddev 0.1)."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    """Create a bias variable initialised to the constant 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    """Apply a stride-1, SAME-padded 2-D convolution of ``x`` with kernel ``W``."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool(x):
    """Apply 2x2 max pooling with stride 2 and SAME padding."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


# Build the network.
x_image = tf.reshape(x, [-1, width, height, 1])  # add the channel axis for conv2d

W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)  # first convolution layer
h_pool1 = max_pool(h_conv1)                               # first pooling layer

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  # second convolution layer
h_pool2 = max_pool(h_conv2)                               # second pooling layer

# Two stride-2 pools shrink each spatial side by 4 (48 -> 12); the second
# convolution emits 64 channels.
flat_size = (width // 4) * (height // 4) * 64

W_fc1 = weight_variable([flat_size, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, flat_size])          # flatten to vectors
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)   # first fully connected layer

keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)  # dropout layer

W_fc2 = weight_variable([1024, n_classes])
b_fc2 = bias_variable([n_classes])
logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
y_predict = tf.nn.softmax(logits)  # class probabilities

# Summed cross entropy, computed from the logits. The earlier
# -sum(y * log(softmax)) form produces NaN as soon as one probability
# underflows to 0; the fused op is numerically stable and equal otherwise.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_actual, logits=logits))
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_predict, 1), tf.argmax(y_actual, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# The context manager closes the session (and frees its resources) on exit.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(EPOCHS):
        for i in range(n_batch):
            start = i * batch_size
            batch_xs = train_list[start:start + batch_size]
            batch_ys = train_label[start:start + batch_size]
            if i % LOG_EVERY == 0:
                # Dropout is disabled (keep_prob 1.0) whenever accuracy is measured.
                train_acc = sess.run(
                    accuracy,
                    feed_dict={x: batch_xs, y_actual: batch_ys, keep_prob: 1.0})
                print('step ', i, ' training accuracy ', train_acc)
            sess.run(
                train_step,
                feed_dict={x: batch_xs, y_actual: batch_ys, keep_prob: 0.5})

    # Evaluate once on the held-out rows after all epochs have finished.
    test_acc = sess.run(
        accuracy,
        feed_dict={x: test_list, y_actual: test_label, keep_prob: 1.0})
    print('===================================')
    print("test accuracy ", test_acc)
    print('===================================')
-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import json 3 | import os 4 | import zipfile 5 | 6 | 7 | PATH_DATA = 'D:\quant\python\data' 8 | PATH_ZIP = 'D:\quant\python\zip' 9 | PATH_NNDATA = 'D:\\quant\\python\\nndata\\' 10 | 11 | DATA_ORIGINAL = [] 12 | DATA_FINAL = [] 13 | DATA_FINAL_SHAPE = 48 #数据矩阵的形状 14 | DEPTH = 19 #市场深度的长度 15 | 16 | PERCENT = 0.005 #百分之多少算暴涨暴跌 17 | RATIO = 0.5 #占未来的概率 18 | LABEL_LENGT = 1000 #标签数据的长度 19 | 20 | ''' 21 | 将store 目录下的压缩的数据文件移入zip目录下,并创建nndata目录。 22 | zip 目录下的文件是从okex抓取的价格以及市场深度的原始数据,用于加工成可卷积的数据 48*48 的二维矩阵 包括原料数据和标签数据 23 | 最后存入nndata 目录下 24 | ''' 25 | 26 | 27 | 28 | #解压文件 29 | def unzip(zip_path,data_path): 30 | list = os.listdir(zip_path) 31 | for i in range(0,len(list)): 32 | path = os.path.join(zip_path,list[i]) 33 | if os.path.isfile(path): 34 | z = zipfile.ZipFile(path, 'r') 35 | z.extractall(path=r''+data_path) 36 | z.close() 37 | 38 | #过滤异常数据 39 | def filter(str): 40 | str = str.replace('\n[',',') 41 | return json.loads(str) 42 | 43 | 44 | def write(path,data): 45 | f = open(path, 'w') 46 | f.write(str(data)) 47 | f.close() 48 | 49 | 50 | def getData(data_path): 51 | list = os.listdir(data_path) 52 | result = [] 53 | for i in range(0,len(list)): 54 | path = os.path.join(data_path,list[i]) 55 | if os.path.isfile(path): 56 | f = open(path,'r') 57 | result = result + filter(f.read()) 58 | f.close() 59 | return result 60 | 61 | 62 | #获得ma数据 63 | def ma(list,size,index,key): 64 | start = 0 65 | num = index + 1 66 | total = 0 67 | if size 0 代表什么都不是 1暴涨暴跌 2暴涨 3暴跌 115 | for j in range(0,height): 116 | item = item + data[i+j-length+1] 117 | for j in range(0,length): 118 | 119 | if data[i+j][40]/data[i][40]>=1.01: 120 | up +=1 121 | elif data[i+j][40]/data[i][40]<=0.99: 122 | down +=1 123 | if j==length-1: 124 | if up/length>=ratio: 125 | lab = [0,0,1,0] 126 | elif down/length>=ratio: 
127 | lab = [0,0,0,1] 128 | elif up/length>=ratio and down/length>=ratio: 129 | lab = [0,1,0,0] 130 | 131 | 132 | list.append(item) 133 | label.append(lab) 134 | return {"list":list,"label":label} 135 | 136 | 137 | def main(): 138 | 139 | unzip(PATH_ZIP,PATH_DATA) 140 | DATA_ORIGINAL = getData(PATH_DATA) 141 | #print(len(DATA_ORIGINAL)) 142 | d = final(DATA_ORIGINAL,DATA_FINAL_SHAPE) 143 | 144 | ''' 145 | 回测代码,忽略 146 | up = 0 147 | down = 0 148 | price = 0 149 | target = 'down' 150 | my = 1 151 | price = DATA_ORIGINAL[0]['last'] 152 | for j in range(len(DATA_ORIGINAL)): 153 | depth = DATA_ORIGINAL[j] 154 | bids = depth['bids'] 155 | asks = depth['asks'] 156 | bidsArea = 0 157 | asksArea = 0 158 | bidsNum = 0 159 | asksNum = 0 160 | 161 | for i in range(0,len(bids)): 162 | bidsArea = bidsArea + bids[i][0]*bids[i][1] 163 | bidsNum = bidsNum + bids[i][1] 164 | for i in range(0,len(asks)): 165 | asksArea = asksArea + asks[i][0]*asks[i][1] 166 | asksNum = asksNum + asks[i][1] 167 | if(round(my,2)<0.001): 168 | return print(j) 169 | if bidsArea/asksArea>1.5: 170 | 171 | if target=='up': 172 | my = ((bids[2][0] -price)/price*10+1)*my*0.997 173 | print('up:' +str(my) +' '+str(bidsArea/asksArea)) 174 | target = 'down' 175 | price = depth['last'] 176 | down+=1 177 | elif bidsArea/asksArea<1.35: 178 | 179 | if target=='down': 180 | my = ((price - asks[2][0])/price*10+1)*my*0.997 181 | print('down:' +str(my) +' '+str(bidsArea/asksArea)) 182 | 183 | target = 'up' 184 | price = depth['last'] 185 | up+=1 186 | 187 | 188 | print('up:'+str(up)+' down:'+str(down)+' my:'+str(round(my,2))) 189 | ''' 190 | 191 | #print('结束!!!!!!!!!!!!正在生成神经网络的数据') 192 | 193 | 194 | train = nnData(DATA_FINAL_SHAPE,d,RATIO,LABEL_LENGT) 195 | write(PATH_NNDATA+'list.data',train['list']) 196 | write(PATH_NNDATA+'label.data',train['label']) 197 | 198 | print('神经网络的数据成功,共生成list:'+str(len(train['list'])) + ' lab:'+str(len(train['label']))) 199 | #return train 200 | 201 | main() 202 | 203 | 204 | 205 | 206 | 
# --------------------------------------------------------------------------------
# /rnn.py:
# --------------------------------------------------------------------------------
"""Train and evaluate a single-layer LSTM classifier on the samples in ./nndata.

The script uses the TensorFlow 1.x graph API. Each flattened sample is
reshaped into ``max_time`` steps of ``n_inputs`` values. The last
``TEST_LEN`` rows are held out and scored after every epoch.
"""
import tensorflow as tf
import os
import json
import numpy as np

# Silence TensorFlow's C++ INFO and WARNING log lines.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


TEST_LEN = 1000  # number of trailing rows held out for testing
EPOCHS = 6       # full passes over the training rows


def read(name):
    """Return the JSON value stored in ``./nndata/<name>.data``.

    The file is opened in a ``with`` block so the handle is released even
    when decoding fails.
    """
    with open('./nndata/' + name + '.data', 'r') as f:
        return json.load(f)


# Load the data set. The names avoid shadowing the builtin ``list``.
samples = np.array(read('list'))
labels = np.array(read('label'))

n_inputs = 48     # values fed to the LSTM at each time step
max_time = 48     # time steps per sample
lstm_size = 100   # hidden units in the LSTM cell
n_classes = 4     # length of each label vector
batch_size = 50   # samples per training batch
n_batch = (len(labels) - TEST_LEN) // batch_size  # batches per epoch

# ``None`` lets the first (batch) dimension take any length.
x = tf.placeholder(tf.float32, [None, n_inputs * max_time])
# Target label vectors.
y = tf.placeholder(tf.float32, [None, n_classes])

# Output-layer weights and biases.
weights = tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
biases = tf.Variable(tf.constant(0.1, shape=[n_classes]))


def RNN(X, weights, biases):
    """Return softmax class probabilities for a batch of flattened samples.

    ``X`` is reshaped to ``[batch, max_time, n_inputs]`` and run through one
    LSTM cell. The final hidden state is projected with ``weights`` and
    ``biases`` and passed through a softmax.
    """
    inputs = tf.reshape(X, [-1, max_time, n_inputs])
    lstm_cell = tf.nn.rnn_cell.LSTMCell(lstm_size)
    # final_state[0] is the cell state, final_state[1] the hidden state.
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, dtype=tf.float32)
    results = tf.nn.softmax(tf.matmul(final_state[1], weights) + biases)
    return results


prediction = RNN(x, weights, biases)

# Cross entropy is the NEGATIVE mean of y * log(p). Without the minus sign
# the optimizer would minimise log-likelihood and train the model to be
# wrong. The probabilities are clipped so log(0) cannot turn the loss into NaN.
cross_entropy = -tf.reduce_mean(
    y * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# argmax gives the index of the largest entry; a row is correct when the
# predicted index equals the target index.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(EPOCHS):
        # Batches cover only the first n_batch * batch_size rows, so the last
        # TEST_LEN rows are never used for training.
        for i in range(n_batch):
            start = i * batch_size
            batch_xs = samples[start:start + batch_size]
            batch_ys = labels[start:start + batch_size]
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

        # Score the held-out tail. Scoring samples[:TEST_LEN] would measure
        # rows the model has just been trained on.
        acc = sess.run(
            accuracy,
            feed_dict={x: samples[-TEST_LEN:], y: labels[-TEST_LEN:]})
        print ("Iter " + str(epoch) + ", Testing Accuracy= " + str(acc))


# --------------------------------------------------------------------------------
# /store/2018-12-15.zip:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-15.zip
# --------------------------------------------------------------------------------
# /store/2018-12-16.zip:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-16.zip
# --------------------------------------------------------------------------------
# /store/2018-12-17.zip:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-17.zip
# --------------------------------------------------------------------------------
# /store/2018-12-18.zip:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-18.zip
# --------------------------------------------------------------------------------
# /store/2018-12-19.zip:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-19.zip -------------------------------------------------------------------------------- /store/2018-12-20.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-20.zip -------------------------------------------------------------------------------- /store/2018-12-21.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-21.zip -------------------------------------------------------------------------------- /store/2018-12-24.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-24.zip -------------------------------------------------------------------------------- /store/2018-12-25.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-25.zip -------------------------------------------------------------------------------- /store/2018-12-26.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-26.zip -------------------------------------------------------------------------------- /store/2018-12-27.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-27.zip 
# --------------------------------------------------------------------------------
# /store/2018-12-28.zip:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Jax2014/quant-cnn/fa50faa4aa5c501439a575e7920c4173766e9f4e/store/2018-12-28.zip
# --------------------------------------------------------------------------------
# /test.py:
# --------------------------------------------------------------------------------
"""Sanity check: print how many samples and labels ./nndata contains."""
import os
import json


def read(name):
    """Return the JSON value stored in ``./nndata/<name>.data``.

    The path is relative to the current working directory. The file is
    opened in a ``with`` block so the handle is released even when
    decoding fails.
    """
    with open('./nndata/' + name + '.data', 'r') as f:
        return json.load(f)


def main():
    """Load both data files and print their lengths separated by a space."""
    samples = read('list')
    labels = read('label')
    print(str(len(samples)) + ' ' + str(len(labels)))


# Run the check only when executed as a script, so importing the module
# performs no file I/O.
if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------