├── README.md ├── Web前端配置及使用教程.pdf ├── __init__.py ├── checkpoint_ ├── checkpoint ├── my_modelv1.ckpt.data-00000-of-00001 ├── my_modelv1.ckpt.index ├── my_modelv2.ckpt.data-00000-of-00001 └── my_modelv2.ckpt.index ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── cnn_attention_lstm.cpython-39.pyc └── cnn_attention_lstm.py ├── pred.npy ├── pred.py ├── pred_API.py ├── test_x.npy ├── test_y.npy ├── train_v2.py ├── train_x.npy ├── train_y.npy ├── 主要算法(独立份)使用教程.pdf ├── 原数据 ├── data(1).sql ├── 处理后的数据表.xlsx ├── 玉米期货数据周报7.25.xlsx └── 相关性分析用表.xlsx ├── 我是热力图.png ├── 时间步处理.py ├── 相关性分析.py └── 相关性分析数据.npy /README.md: -------------------------------------------------------------------------------- 1 | # CNN_Attention_LSTM 2 | 基于相关性分析的CNN_Attention_LSTM期货价格预测模型 3 | -------------------------------------------------------------------------------- /Web前端配置及使用教程.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/Web前端配置及使用教程.pdf -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/__init__.py -------------------------------------------------------------------------------- /checkpoint_/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "my_modelv2.ckpt" 2 | all_model_checkpoint_paths: "my_modelv2.ckpt" 3 | -------------------------------------------------------------------------------- /checkpoint_/my_modelv1.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/checkpoint_/my_modelv1.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoint_/my_modelv1.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/checkpoint_/my_modelv1.ckpt.index -------------------------------------------------------------------------------- /checkpoint_/my_modelv2.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/checkpoint_/my_modelv2.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoint_/my_modelv2.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/checkpoint_/my_modelv2.ckpt.index -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/cnn_attention_lstm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/models/__pycache__/cnn_attention_lstm.cpython-39.pyc -------------------------------------------------------------------------------- /models/cnn_attention_lstm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | #这里是主要使用的模型结构 CNN_Attention_LSTM 4 | def attention_3d_block(inputs,TIME_STEPS,SINGLE_ATTENTION_VECTOR): 5 | # inputs.shape = (batch_size, time_steps, input_dim) 6 | # inputs = tf.expand_dims(inputs,1) 7 | input_dim = int(inputs.shape[2]) 8 | a = tf.keras.layers.Permute((2, 1))(inputs) 9 | a = tf.keras.layers.Reshape((input_dim, TIME_STEPS))(a) # this line is not useful. It's just to know which dimension is what. 10 | a = tf.keras.layers.Dense(TIME_STEPS, activation='softmax')(a) 11 | if SINGLE_ATTENTION_VECTOR: 12 | a = tf.keras.layers.Lambda(lambda x: tf.keras.backend.mean(x, axis=1), name='dim_reduction')(a) 13 | a = tf.keras.layers.RepeatVector(input_dim)(a) 14 | a_probs = tf.keras.layers.Permute((2, 1), name='attention_vec')(a) 15 | output_attention_mul = tf.keras.layers.Multiply()([inputs, a_probs]) 16 | return output_attention_mul 17 | 18 | def conv_lstm(TIME_STEPS, INPUT_DIM,lstm_units = 32): 19 | tf.keras.backend.clear_session() # 清除之前的模型,省得压满内存 20 | inputs = tf.keras.Input(shape=(TIME_STEPS, INPUT_DIM,)) 21 | x = tf.keras.layers.Conv1D(65,3,padding='same')(inputs) 22 | x = tf.keras.layers.MaxPooling1D(2)(x) 23 | x = tf.keras.layers.ReLU()(x) 24 | x = tf.keras.layers.Conv1D(128,3,padding='same')(x) 25 | x = tf.keras.layers.MaxPooling1D(2)(x) 26 | x = tf.keras.layers.ReLU()(x) 27 | x = tf.keras.layers.Dropout(0.5)(x) 28 | x = tf.keras.layers.Flatten()(x) 29 | x = tf.keras.layers.RepeatVector(TIME_STEPS)(x) 30 | x = tf.keras.layers.LSTM(lstm_units,return_sequences=True)(x) 31 | x = attention_3d_block(x,TIME_STEPS,1) 32 | x = tf.keras.layers.LSTM(lstm_units)(x) 33 | x = tf.keras.layers.Dense(1024)(x) 34 | output = tf.keras.layers.Dense(1,kernel_regularizer=tf.keras.regularizers.L1L2())(x) 35 | model = tf.keras.Model(inputs=[inputs], outputs=output) 36 | return model -------------------------------------------------------------------------------- /pred.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/pred.npy -------------------------------------------------------------------------------- /pred.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import pandas as pd 4 | import math 5 | import matplotlib.pyplot as plt 6 | from sklearn.preprocessing import MinMaxScaler 7 | import models.cnn_attention_lstm as mv 8 | from sklearn.metrics import mean_squared_error, mean_absolute_error,mean_absolute_percentage_error 9 | #此处为预测程序 10 | a = np.load('train_x.npy') 11 | b = np.load('test_x.npy') 12 | train_y_set = np.load('train_y.npy') 13 | test_y_set = np.load('test_y.npy') 14 | # 归一化 15 | sc = MinMaxScaler(feature_range=(0 , 1)) # 定义归一化:归一化到(0,1)之间 16 | training_set_scaled = sc.fit_transform(a) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 17 | y_training_set_scaled = sc.fit_transform(train_y_set) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 18 | 19 | test_set_scaled = sc.fit_transform(b) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 20 | y_test_set_scaled = sc.fit_transform(test_y_set) # 利用训练集的属性对测试集进行归一化 21 | 22 | 23 | x_train = [] 24 | y_train = [] 25 | 26 | x_test = [] 27 | y_test = [] 28 | 29 | # 利用for循环,遍历整个训练集,提取训练集中连续30天的收盘价作为输入特征x_train,第30天的数据作为标签,for循环共构建1135-235-60=840组数据。 30 | for i in range(60, len(training_set_scaled)): 31 | x_train.append(training_set_scaled[i - 60:i, :]) 32 | y_train.append(y_training_set_scaled[i, :]) 33 | 34 | # 对训练集进行打乱 35 | np.random.seed(7) 36 | np.random.shuffle(x_train) 37 | np.random.seed(7) 38 | np.random.shuffle(y_train) 39 | tf.random.set_seed(7) 40 | # 将训练集由list格式变为array格式 41 | x_train, y_train = np.array(x_train), np.array(y_train) 42 | y_train = np.reshape(y_train,(y_train.shape[0],)) 43 | 44 | # print(x_train.shape) 45 | # print(y_train.shape) 46 | # 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。 47 | 48 | # 利用for循环,遍历整个测试集,提取测试集中连续30天的收盘价作为输入特征x_train,第31天的数据作为标签,for循环共构建235-30=200组数据。 49 | for i in range(60, len(test_set_scaled)): 50 | x_test.append(test_set_scaled[i - 60:i, :]) 51 | y_test.append(y_test_set_scaled[i, :]) 52 | # 测试集变array并reshape为符合LSTM输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数] 53 | x_test, y_test = np.array(x_test), np.array(y_test) 54 | y_test = np.reshape(y_test,(y_test.shape[0],)) #标签为一维数组 55 | 56 | 57 | model = mv.conv_lstm(60,6,128) #attention_lstm(时间步长,特征数量,lstm神经元数量) 58 | 59 | 60 | model.load_weights('./checkpoint_/my_modelv2.ckpt') 61 | 62 | ################## predict ###################### 63 | # 测试集输入模型进行预测 64 | predicted_stock_price_t = model.predict(x_train) 65 | 66 | # 对预测数据还原---从(0,1)反归一化到原始范围 67 | predicted_stock_price_t = sc.inverse_transform(predicted_stock_price_t) 68 | 69 | # 对真实数据还原---从(0,1)反归一化到原始范围 70 | y_train = tf.reshape(y_train,(-1,1)) 71 | real_stock_price_t = sc.inverse_transform(y_train) 72 | print(real_stock_price_t.shape) 73 | 74 | # 画出真实数据和预测数据的对比曲线 75 | plt.plot(real_stock_price_t, color='red', label='Maize Close Price') 76 | plt.plot(predicted_stock_price_t, color='blue', label='Predicted Maize Close Price') 77 | plt.title('Maize Close Price traindataset Prediction') 78 | plt.xlabel('Time') 79 | plt.ylabel('Maize Close Price') 80 | plt.legend() 81 | plt.show() 82 | 83 | ##########evaluate############## 84 | # calculate MSE 均方误差 ---> E[(预测值-真实值)^2] (预测值减真实值求平方后求均值) 85 | mse = mean_squared_error(predicted_stock_price_t, real_stock_price_t) 86 | # calculate RMSE 均方根误差--->sqrt[MSE] (对均方误差开方) 87 | rmse = math.sqrt(mean_squared_error(predicted_stock_price_t, real_stock_price_t)) 88 | # calculate MAE 平均绝对误差----->E[|预测值-真实值|](预测值减真实值求绝对值后求均值) 89 | mae = mean_absolute_error(predicted_stock_price_t, real_stock_price_t) 90 | mape = mean_absolute_percentage_error(predicted_stock_price_t, real_stock_price_t) 91 | print('MSE: %.6f' % mse) 92 | print('RMSE: %.6f' % rmse) 93 | print('MAE: %.6f' % mae) 94 | print('MAPE: %.6f' % mape) -------------------------------------------------------------------------------- /pred_API.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import pandas as pd 4 | import math 5 | import matplotlib.pyplot as plt 6 | from sklearn.preprocessing import MinMaxScaler 7 | import models.cnn_attention_lstm as mv 8 | #此处为预测程序连接前端的接口 9 | 10 | def day_60_data_pred(data_x_set): 11 | """ 12 | 这是一个预测函数 13 | :param data_x_set:连续60天的数据集 14 | :return: 预测后具有六个特征的时间步 15 | """ 16 | sc = MinMaxScaler(feature_range=(0 , 1)) 17 | training_set_scaled = sc.fit_transform(data_x_set) 18 | x_train = tf.expand_dims(training_set_scaled,0) 19 | 20 | model = mv.conv_lstm(60,6,128) #attention_lstm(时间步长,特征数量,lstm神经元数量) 21 | 22 | model.load_weights('./checkpoint_/my_modelv2.ckpt') 23 | ################## predict ###################### 24 | # 测试集输入模型进行预测 25 | predicted_stock_price_t = model.predict(x_train) 26 | c = list(predicted_stock_price_t[0]) 27 | for i in range(5): 28 | c.append(np.random.uniform(0,1,1)) 29 | c = np.expand_dims(c,0) 30 | predicted_stock_price_t = sc.inverse_transform(c) 31 | return predicted_stock_price_t 32 | 33 | def next_nday_pred(dataset,n,method): 34 | """ 35 | 这是一个连续预测函数 36 | :param dataset: 连续前60天时间步且每个时间步6个特征的数据 37 | :param n: 需要连续预测的天数 38 | :param method: 预测方法,是一个函数 39 | :return: 为n天玉米主力收盘价的预测值列表 40 | """ 41 | pred_day = [] 42 | for i in range(n): 43 | dataset = dataset[i:, :] 44 | c = method(dataset) 45 | pred_day.append(c[0]) 46 | dataset = list(dataset) 47 | for j in pred_day: 48 | dataset.append(j) 49 | dataset = np.array(dataset) 50 | pred_nday = [] 51 | for k in pred_day: 52 | pred_nday.append(k[0]) 53 | return np.array(pred_nday) 54 | # if __name__ == '__main__': 55 | # # b = pd.read_excel(r'原数据/相关性分析用表.xlsx','Sheet1') 56 | # # b = b.iloc[1:61,1:7].values 57 | # # np.save('pred.npy',b) 58 | # b = np.load('pred.npy') 59 | # a = next_nday_pred(b,7,day_60_data_pred) 60 | # for i in range(len(a)): 61 | # print(f"未来七天的玉米主力收盘价预测第{i+1}天为:",round(a[i],2)) -------------------------------------------------------------------------------- /test_x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/test_x.npy -------------------------------------------------------------------------------- /test_y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/test_y.npy -------------------------------------------------------------------------------- /train_v2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pandas as pd 3 | import numpy as np 4 | import models.cnn_attention_lstm as mv 5 | from sklearn.preprocessing import MinMaxScaler 6 | #模型训练程序 7 | 8 | # dataframe = pd.read_excel('原数据/相关性分析用表.xlsx','Sheet1') 9 | # train_set = dataframe.iloc[375:1876,1:2].values 10 | # test_set = dataframe.iloc[1:375,1:2].values 11 | # tz_train_set = dataframe.iloc[375:1876,1:7].values 12 | # tz_test_set = dataframe.iloc[1:375,1:7].values 13 | # 14 | # train_set = np.nan_to_num(train_set,nan=2844) 15 | # test_set = np.nan_to_num(test_set,nan=2844) 16 | # tz_train_set = np.nan_to_num(tz_train_set,nan=2923) 17 | # tz_test_set = np.nan_to_num(tz_test_set,nan=2923) 18 | # 19 | # np.save('train_x.npy',tz_train_set) 20 | # np.save('test_x.npy',tz_test_set) 21 | # np.save('train_y.npy',train_set) 22 | # np.save('test_y.npy',test_set) 23 | 24 | a = np.load('train_x.npy') 25 | b = np.load('test_x.npy') 26 | train_y_set = np.load('train_y.npy') 27 | test_y_set = np.load('test_y.npy') 28 | # 归一化 29 | sc = MinMaxScaler(feature_range=(0 , 1)) # 定义归一化:归一化到(0,1)之间 30 | training_set_scaled = sc.fit_transform(a) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 31 | y_training_set_scaled = sc.fit_transform(train_y_set) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 32 | 33 | test_set_scaled = sc.fit_transform(b) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化 34 | y_test_set_scaled = sc.fit_transform(test_y_set) # 利用训练集的属性对测试集进行归一化 35 | 36 | 37 | x_train = [] 38 | y_train = [] 39 | 40 | x_test = [] 41 | y_test = [] 42 | 43 | # 利用for循环,遍历整个训练集,提取训练集中连续30天的收盘价作为输入特征x_train,第30天的数据作为标签,for循环共构建1135-235-60=840组数据。 44 | for i in range(60, len(training_set_scaled)): 45 | x_train.append(training_set_scaled[i - 60:i, :]) 46 | y_train.append(y_training_set_scaled[i, :]) 47 | 48 | # 对训练集进行打乱 49 | np.random.seed(7) 50 | np.random.shuffle(x_train) 51 | np.random.seed(7) 52 | np.random.shuffle(y_train) 53 | tf.random.set_seed(7) 54 | # 将训练集由list格式变为array格式 55 | x_train, y_train = np.array(x_train), np.array(y_train) 56 | y_train = np.reshape(y_train,(y_train.shape[0],)) 57 | 58 | # print(x_train.shape) 59 | # print(y_train.shape) 60 | # 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。 61 | 62 | # 利用for循环,遍历整个测试集,提取测试集中连续30天的收盘价作为输入特征x_train,第31天的数据作为标签,for循环共构建235-30=200组数据。 63 | for i in range(60, len(test_set_scaled)): 64 | x_test.append(test_set_scaled[i - 60:i, :]) 65 | y_test.append(y_test_set_scaled[i, :]) 66 | # 测试集变array并reshape为符合LSTM输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数] 67 | x_test, y_test = np.array(x_test), np.array(y_test) 68 | y_test = np.reshape(y_test,(y_test.shape[0],)) #标签为一维数组 69 | 70 | 71 | model = mv.conv_lstm(60,6,128) #attention_lstm(时间步长,特征数量,lstm神经元数量) 72 | 73 | model.compile(loss=tf.keras.losses.MeanSquaredError(), optimizer=tf.keras.optimizers.Adam(1e-4)) 74 | 75 | histroy = model.fit(x_train,y_train,validation_data=(x_test,y_test),epochs=3000,batch_size=128,callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoint_/my_modelv2.ckpt',save_weights_only=True,save_best_only=True,verbose=1)]) 76 | -------------------------------------------------------------------------------- /train_x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/train_x.npy -------------------------------------------------------------------------------- /train_y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/train_y.npy -------------------------------------------------------------------------------- /主要算法(独立份)使用教程.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/主要算法(独立份)使用教程.pdf -------------------------------------------------------------------------------- /原数据/处理后的数据表.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/原数据/处理后的数据表.xlsx -------------------------------------------------------------------------------- /原数据/玉米期货数据周报7.25.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/原数据/玉米期货数据周报7.25.xlsx -------------------------------------------------------------------------------- /原数据/相关性分析用表.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/原数据/相关性分析用表.xlsx -------------------------------------------------------------------------------- /我是热力图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/我是热力图.png -------------------------------------------------------------------------------- /时间步处理.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | #其实就是数据预处理,不用管,经供参考 4 | data = pd.read_excel(r'C:\Users\20426\Desktop\okk.xlsx',sheet_name='Sheet1') 5 | 6 | dataset = data.iloc[1:1261,:14].values 7 | # np.save('10_12.npy',dataset) 8 | # dataset = np.load('10_12.npy') 9 | dataset = np.nan_to_num(dataset) 10 | a = [] 11 | b = [] 12 | for i in range(len(dataset)): 13 | for j in range(len(dataset[i])):#0-1 14 | aa = round(dataset[i][j]/30,2) 15 | for k in range(30): 16 | if j == 0: 17 | a.append(aa) 18 | else: 19 | b.append(aa) 20 | c = [] 21 | for i in range(len(a)): 22 | kk = [] 23 | kk.append(a[i]) 24 | kk.append(b[i]) 25 | c.append(kk) 26 | 27 | c = pd.DataFrame(c) 28 | c.to_excel('okk.xlsx') 29 | -------------------------------------------------------------------------------- /相关性分析.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import seaborn as sns 5 | #热力图生成 6 | data = pd.read_excel(r'C:\Users\20426\Desktop\okk.xlsx',sheet_name='Sheet1') 7 | df = data.iloc[1:1161,1:14] 8 | result5 = df.corr(method='pearson') 9 | # np.save('相关性分析数据.npy',dataset) 10 | # dataset = np.load('相关性分析数据.npy') 11 | # df = pd.DataFrame(dataset) 12 | rc = {'font.sans-serif': 'SimHei', 13 | 'axes.unicode_minus': False} 14 | sns.set(font_scale=0.5,rc=rc) # 设置字体大小 15 | sns.heatmap(result5, 16 | annot=True, # 显示相关系数的数据 17 | center=0.5, # 居中 18 | fmt='.2f', # 只显示两位小数 19 | linewidth=0.5, # 设置每个单元格的距离 20 | linecolor='blue', # 设置间距线的颜色 21 | vmin=0, vmax=1, # 设置数值最小值和最大值 22 | xticklabels=True, yticklabels=True, # 显示x轴和y轴 23 | square=True, # 每个方格都是正方形 24 | cbar=True, # 绘制颜色条 25 | cmap='coolwarm_r', # 设置热力图颜色 26 | ) 27 | plt.savefig("我是热力图.png",dpi=2000)#保存图片,分辨率为600 28 | plt.ion() #显示图片 29 | -------------------------------------------------------------------------------- /相关性分析数据.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanshuishou/CNN_Attention_LSTM/3b178cbbfcc572857ba32cddaa1390a4ddc2df13/相关性分析数据.npy --------------------------------------------------------------------------------