├── requirements.txt ├── LICENSE ├── utils.py ├── dataloader.py ├── model.py ├── README.md ├── model2.py ├── learn.py └── learn_memory.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | pandas==1.2.3 3 | scipy==1.6.1 4 | torch==1.8.0 5 | scikit_learn==0.24.1 6 | tqdm==4.61.1 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MingjieWang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | 5 | def mse(pred, label): 6 | loss = (pred - label)**2 7 | return torch.mean(loss) 8 | 9 | def mae(pred, label): 10 | loss = (pred - label).abs() 11 | return torch.mean(loss) 12 | 13 | def cal_cos_similarity(x, y): # the 2nd dimension of x and y are the same 14 | xy = x.mm(torch.t(y)) 15 | x_norm = torch.sqrt(torch.sum(x*x, dim =1)).reshape(-1, 1) 16 | y_norm = torch.sqrt(torch.sum(y*y, dim =1)).reshape(-1, 1) 17 | cos_similarity = xy/x_norm.mm(torch.t(y_norm)) 18 | cos_similarity[cos_similarity != cos_similarity] = 0 19 | return cos_similarity 20 | 21 | 22 | def cal_convariance(x, y): # the 2nd dimension of x and y are the same 23 | e_x = torch.mean(x, dim = 1).reshape(-1, 1) 24 | e_y = torch.mean(y, dim = 1).reshape(-1, 1) 25 | e_x_e_y = e_x.mm(torch.t(e_y)) 26 | x_extend = x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1) 27 | y_extend = y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1) 28 | e_xy = torch.mean(x_extend*y_extend, dim = 2) 29 | return e_xy - e_x_e_y 30 | 31 | 32 | def metric_fn(preds): 33 | preds = preds[~np.isnan(preds['label'])] 34 | precision = {} 35 | recall = {} 36 | temp = preds.groupby(level='datetime').apply(lambda x: x.sort_values(by='score', ascending=False)) 37 | if len(temp.index[0]) > 2: 38 | temp = temp.reset_index(level =0).drop('datetime', axis = 1) 39 | 40 | for k in [1, 3, 5, 10, 20, 30, 50, 100]: 41 | precision[k] = temp.groupby(level='datetime').apply(lambda x:(x.label[:k]>0).sum()/k).mean() 42 | recall[k] = temp.groupby(level='datetime').apply(lambda x:(x.label[:k]>0).sum()/(x.label>0).sum()).mean() 43 | 44 | ic = 
preds.groupby(level='datetime').apply(lambda x: x.label.corr(x.score)).mean() 45 | rank_ic = preds.groupby(level='datetime').apply(lambda x: x.label.corr(x.score, method='spearman')).mean() 46 | 47 | return precision, recall, ic, rank_ic 48 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class DataLoader: 5 | 6 | def __init__(self, df_feature, df_label, df_market_value, df_stock_index, batch_size=800, pin_memory=True, start_index = 0, device=None): 7 | 8 | assert len(df_feature) == len(df_label) 9 | 10 | self.df_feature = df_feature.values 11 | self.df_label = df_label.values 12 | self.df_market_value = df_market_value 13 | self.df_stock_index = df_stock_index 14 | self.device = device 15 | 16 | if pin_memory: 17 | self.df_feature = torch.tensor(self.df_feature, dtype=torch.float, device=device) 18 | self.df_label = torch.tensor(self.df_label, dtype=torch.float, device=device) 19 | self.df_market_value = torch.tensor(self.df_market_value, dtype=torch.float, device=device) 20 | self.df_stock_index = torch.tensor(self.df_stock_index, dtype=torch.long, device=device) 21 | 22 | self.index = df_label.index 23 | 24 | self.batch_size = batch_size 25 | self.pin_memory = pin_memory 26 | self.start_index = start_index 27 | 28 | self.daily_count = df_label.groupby(level=0).size().values 29 | self.daily_index = np.roll(np.cumsum(self.daily_count), 1) 30 | self.daily_index[0] = 0 31 | 32 | @property 33 | def batch_length(self): 34 | 35 | if self.batch_size <= 0: 36 | return self.daily_length 37 | 38 | return len(self.df_label) // self.batch_size 39 | 40 | @property 41 | def daily_length(self): 42 | 43 | return len(self.daily_count) 44 | 45 | def iter_batch(self): 46 | if self.batch_size <= 0: 47 | yield from self.iter_daily_shuffle() 48 | return 49 | 50 | indices = np.arange(len(self.df_label)) 51 | np.random.shuffle(indices) 52 | 53 | for i in range(len(indices))[::self.batch_size]: 54 | if len(indices) - i < self.batch_size: 55 | break 56 | yield i, indices[i:i+self.batch_size] # NOTE: advanced indexing will cause copy 57 | 58 | def iter_daily_shuffle(self): 59 | indices = np.arange(len(self.daily_count)) 60 | np.random.shuffle(indices) 61 | for i in indices: 62 | yield i, slice(self.daily_index[i], self.daily_index[i] + self.daily_count[i]) 63 | 64 | def iter_daily(self): 65 | indices = np.arange(len(self.daily_count)) 66 | for i in indices: 67 | yield i, slice(self.daily_index[i], self.daily_index[i] + self.daily_count[i]) 68 | # for idx, count in zip(self.daily_index, self.daily_count): 69 | # yield slice(idx, idx + count) # NOTE: slice index will not cause copy 70 | 71 | def get(self, slc): 72 | outs = self.df_feature[slc], self.df_label[slc][:,0], self.df_market_value[slc], self.df_stock_index[slc] 73 | # outs = self.df_feature[slc], self.df_label[slc] 74 | 75 | if not self.pin_memory: 76 | outs = tuple(torch.tensor(x, device=self.device) for x in outs) 77 | 78 | return outs + (self.index[slc],) 79 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import numpy as np 5 | from utils import cal_cos_similarity 6 | 7 | class MLP(nn.Module): 8 | 9 | def __init__(self, d_feat, hidden_size=512, num_layers=3, dropout=0.0): 10 
| super().__init__() 11 | 12 | self.mlp = nn.Sequential() 13 | 14 | for i in range(num_layers): 15 | if i > 0: 16 | self.mlp.add_module('drop_%d'%i, nn.Dropout(dropout)) 17 | self.mlp.add_module('fc_%d'%i, nn.Linear( 18 | 360 if i == 0 else hidden_size, hidden_size)) 19 | self.mlp.add_module('relu_%d'%i, nn.ReLU()) 20 | 21 | self.mlp.add_module('fc_out', nn.Linear(hidden_size, 1)) 22 | 23 | def forward(self, x): 24 | # feature 25 | # [N, F] 26 | return self.mlp(x).squeeze() 27 | 28 | class HIST(nn.Module): 29 | def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU", K =3): 30 | super().__init__() 31 | 32 | self.d_feat = d_feat 33 | self.hidden_size = hidden_size 34 | 35 | self.rnn = nn.GRU( 36 | input_size=d_feat, 37 | hidden_size=hidden_size, 38 | num_layers=num_layers, 39 | batch_first=True, 40 | dropout=dropout, 41 | ) 42 | 43 | self.fc_ps = nn.Linear(hidden_size, hidden_size) 44 | torch.nn.init.xavier_uniform_(self.fc_ps.weight) 45 | self.fc_hs = nn.Linear(hidden_size, hidden_size) 46 | torch.nn.init.xavier_uniform_(self.fc_hs.weight) 47 | 48 | self.fc_ps_fore = nn.Linear(hidden_size, hidden_size) 49 | torch.nn.init.xavier_uniform_(self.fc_ps_fore.weight) 50 | self.fc_hs_fore = nn.Linear(hidden_size, hidden_size) 51 | torch.nn.init.xavier_uniform_(self.fc_hs_fore.weight) 52 | 53 | self.fc_ps_back = nn.Linear(hidden_size, hidden_size) 54 | torch.nn.init.xavier_uniform_(self.fc_ps_back.weight) 55 | self.fc_hs_back = nn.Linear(hidden_size, hidden_size) 56 | torch.nn.init.xavier_uniform_(self.fc_hs_back.weight) 57 | self.fc_indi = nn.Linear(hidden_size, hidden_size) 58 | torch.nn.init.xavier_uniform_(self.fc_indi.weight) 59 | 60 | self.leaky_relu = nn.LeakyReLU() 61 | self.softmax_s2t = torch.nn.Softmax(dim = 0) 62 | self.softmax_t2s = torch.nn.Softmax(dim = 1) 63 | 64 | self.fc_out_ps = nn.Linear(hidden_size, 1) 65 | self.fc_out_hs = nn.Linear(hidden_size, 1) 66 | self.fc_out_indi = nn.Linear(hidden_size, 1) 67 | self.fc_out = nn.Linear(hidden_size, 1) 68 | self.K = K 69 | 70 | def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same 71 | xy = x.mm(torch.t(y)) 72 | x_norm = torch.sqrt(torch.sum(x*x, dim =1)).reshape(-1, 1) 73 | y_norm = torch.sqrt(torch.sum(y*y, dim =1)).reshape(-1, 1) 74 | cos_similarity = xy/x_norm.mm(torch.t(y_norm)) 75 | cos_similarity[cos_similarity != cos_similarity] = 0 76 | return cos_similarity 77 | 78 | def forward(self, x, concept_matrix, market_value): 79 | device = torch.device(torch.get_device(x)) 80 | x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T] 81 | x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F] 82 | x_hidden, _ = self.rnn(x_hidden) 83 | x_hidden = x_hidden[:, -1, :] 84 | 85 | # Predefined Concept Module 86 | 87 | market_value_matrix = market_value.reshape(market_value.shape[0], 1).repeat(1, concept_matrix.shape[1]) 88 | stock_to_concept = concept_matrix * market_value_matrix 89 | 90 | stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1) 91 | stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix) 92 | 93 | stock_to_concept_sum = stock_to_concept_sum + (torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)) 94 | stock_to_concept = stock_to_concept / stock_to_concept_sum 95 | hidden = torch.t(stock_to_concept).mm(x_hidden) 96 | 97 | hidden = hidden[hidden.sum(1)!=0] 98 | stock_to_concept = x_hidden.mm(torch.t(hidden)) 99 | # stock_to_concept = cal_cos_similarity(x_hidden, hidden) 100 | stock_to_concept = 
self.softmax_s2t(stock_to_concept) 101 | hidden = torch.t(stock_to_concept).mm(x_hidden) 102 | 103 | concept_to_stock = cal_cos_similarity(x_hidden, hidden) 104 | concept_to_stock = self.softmax_t2s(concept_to_stock) 105 | 106 | p_shared_info = concept_to_stock.mm(hidden) 107 | p_shared_info = self.fc_ps(p_shared_info) 108 | 109 | p_shared_back = self.fc_ps_back(p_shared_info) 110 | output_ps = self.fc_ps_fore(p_shared_info) 111 | output_ps = self.leaky_relu(output_ps) 112 | 113 | pred_ps = self.fc_out_ps(output_ps).squeeze() 114 | 115 | # Hidden Concept Module 116 | h_shared_info = x_hidden - p_shared_back 117 | hidden = h_shared_info 118 | h_stock_to_concept = cal_cos_similarity(h_shared_info, hidden) 119 | 120 | dim = h_stock_to_concept.shape[0] 121 | diag = h_stock_to_concept.diagonal(0) 122 | h_stock_to_concept = h_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device) 123 | # row = torch.linspace(0,dim-1,dim).to(device).long() 124 | # column = h_stock_to_concept.argmax(1) 125 | row = torch.linspace(0, dim-1, dim).reshape([-1, 1]).repeat(1, self.K).reshape(1, -1).long().to(device) 126 | column = torch.topk(h_stock_to_concept, self.K, dim = 1)[1].reshape(1, -1) 127 | mask = torch.zeros([h_stock_to_concept.shape[0], h_stock_to_concept.shape[1]], device = h_stock_to_concept.device) 128 | mask[row, column] = 1 129 | h_stock_to_concept = h_stock_to_concept * mask 130 | h_stock_to_concept = h_stock_to_concept + torch.diag_embed((h_stock_to_concept.sum(0)!=0).float()*diag) 131 | hidden = torch.t(h_shared_info).mm(h_stock_to_concept).t() 132 | hidden = hidden[hidden.sum(1)!=0] 133 | 134 | h_concept_to_stock = cal_cos_similarity(h_shared_info, hidden) 135 | h_concept_to_stock = self.softmax_t2s(h_concept_to_stock) 136 | h_shared_info = h_concept_to_stock.mm(hidden) 137 | h_shared_info = self.fc_hs(h_shared_info) 138 | 139 | h_shared_back = self.fc_hs_back(h_shared_info) 140 | output_hs = self.fc_hs_fore(h_shared_info) 141 | output_hs = self.leaky_relu(output_hs) 142 | pred_hs = self.fc_out_hs(output_hs).squeeze() 143 | 144 | # Individual Information Module 145 | individual_info = x_hidden - p_shared_back - h_shared_back 146 | output_indi = individual_info 147 | output_indi = self.fc_indi(output_indi) 148 | output_indi = self.leaky_relu(output_indi) 149 | pred_indi = self.fc_out_indi(output_indi).squeeze() 150 | # Stock Trend Prediction 151 | all_info = output_ps + output_hs + output_indi 152 | pred_all = self.fc_out(all_info).squeeze() 153 | return pred_all 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MTMD: Multi-Scale Temporal Memory Learning and Efficient Debiasing Framework for Stock Trend Forecasting 2 | The official implementation of the paper "[MTMD: Multi-Scale Temporal Memory Learning and Efficient Debiasing Framework for Stock Trend Forecasting](https://ieeexplore.ieee.org/document/10906481/)". 3 | ![image](https://i.ibb.co/5MFPqTJ/12.png) 4 | 5 | 🎺🎺🎺 Good News! We have established new code in the QLIB library, which allows you to test MTMD with dozens of models and larger datasets simultaneously! And excitingly, MTMD remains the SOTA (State Of The Art) model. Please check [here](https://github.com/tianshijing/qlib/blob/main/examples/benchmarks/README.md)! 6 | 7 | 🔈🔈🔈 Good News! We are thrilled to announce that our article has been officially published in IEEE Transactions on Emerging Topics in Computational Intelligence. 
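## How the memory module works

MTMD attaches an external memory bank to the predefined-concept and hidden-concept branches of HIST (see `read` and `upload` in `model2.py`): each batch of stock embeddings queries the bank with softmax attention and is gated by what it retrieves. Below is a minimal, self-contained sketch of the read step; the `(96, 64)` bank shape mirrors the initialization in `learn_memory.py`, while the tensor names and the batch size of 300 are illustrative only:

```
import torch
import torch.nn.functional as F

m_items = F.normalize(torch.rand(96, 64), dim=1)  # memory bank: 96 slots of width 64
queries = torch.rand(300, 64)                     # one day's stock embeddings

score = queries @ m_items.t()        # [300, 96] query-to-slot similarity logits
attn = F.softmax(score, dim=1)       # attention over memory slots (score_memory)
retrieved = attn.detach() @ m_items  # [300, 64] soft retrieval from the bank
updated = queries * retrieved        # element-wise gating, as in read()
```

During training, `upload` runs in the opposite direction: each slot aggregates the queries that attend to it most strongly and the bank is re-normalized, so recurring temporal patterns accumulate in memory across batches.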
8 | 9 | ## Environment 10 | 1. Install Python 3.7, 3.8 or 3.9. 11 | 2. Install the requirements in [requirements.txt](https://github.com/Wentao-Xu/HIST/blob/main/requirements.txt). 12 | 3. Install the quantitative investment platform [Qlib](https://github.com/microsoft/qlib) and download the data from Qlib: 13 | ``` 14 | # install Qlib from source 15 | pip install --upgrade cython 16 | git clone https://github.com/microsoft/qlib.git && cd qlib 17 | python setup.py install 18 | 19 | # Download the stock features of Alpha360 from Qlib 20 | python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn --version v2 21 | mkdir data 22 | ``` 23 | 4. Please download the [concept matrix](https://github.com/Wentao-Xu/HIST/tree/main/data), which is provided by [tushare](https://tushare.pro/document/2?doc_id=81). 24 | 5. Please put the concept data and the stock data in the new 'data' folder. 25 | 26 | 27 | ## The results in Qlib: 28 | 29 | | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | 30 | |-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------| 31 | | Transformer(Ashish Vaswani, et al.) | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03 | -0.3378±0.37 | -0.1653±0.05 | 32 | | TabNet(Sercan O. Arik, et al.) | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00 | -0.3892±0.00 | -0.2145±0.00 | 33 | | MLP | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02 | 0.0274±0.23 | -0.1385±0.03 | 34 | | Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 | 35 | | CatBoost(Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 | 36 | | XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 | 37 | | DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0390±0.00 | 0.2946±0.01 | 0.0486±0.00 | 0.3836±0.01 | 0.0462±0.01 | 0.6151±0.18 | -0.0915±0.01 | 38 | | LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 | 39 | | TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 | 40 | | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 | 41 | | LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 | 42 | | ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 | 43 | | GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 | 44 | | AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 | 45 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 | 46 | | TCTS(Xueqing Wu, et al.) 
| Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 | 47 | | TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 | 48 | | IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 | 49 | | HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 | 50 | | KRNN | Alpha360 | 0.0173±0.01 | 0.1210±0.06 | 0.0270±0.01 | 0.2018±0.04 | -0.0465±0.05 | -0.5415±0.62 | -0.2919±0.13 | 51 | | Sandwich | Alpha360 | 0.0258±0.00 | 0.1924±0.04 | 0.0337±0.00 | 0.2624±0.03 | 0.0005±0.03 | 0.0001±0.33 | -0.1752±0.05 | 52 | | MTMD(Mingjie Wang, et al.) | Alpha360 | 0.0538±0.00 | 0.3849±0.01 | 0.0672±0.00 | 0.4656±0.01 | 0.1022±0.02 | 1.4031±0.26 | -0.0664±0.01 | 53 | 54 | 55 | 56 | ## Reproduce the stock trend forecasting results 57 | ![image](https://i.ibb.co/X7CVp2v/res.png) 58 | 59 | ``` 60 | git clone https://github.com/MingjieWang0606/MTMD-Public.git 61 | cd MTMD-Public 62 | mkdir output 63 | ``` 64 | 65 | ### Reproduce our MTMD framework 66 | ``` 67 | # CSI 100 68 | python learn_memory.py --model_name HIST --data_set csi100 --hidden_size 128 --num_layers 2 --outdir ./output/csi100_MTMD 69 | 70 | # CSI 300 71 | python learn_memory.py --model_name HIST --data_set csi300 --hidden_size 128 --num_layers 2 --outdir ./output/csi300_MTMD 72 | ``` 73 | 74 | ### Reproduce our HIST framework 75 | ``` 76 | # CSI 100 77 | python learn.py --model_name HIST --data_set csi100 --hidden_size 128 --num_layers 2 --outdir ./output/csi100_HIST 78 | 79 | # CSI 300 80 | python learn.py --model_name HIST --data_set csi300 --hidden_size 128 --num_layers 2 --outdir ./output/csi300_HIST 81 | ``` 82 | ### Reproduce the baselines 83 | * MLP 84 | ``` 85 | # MLP on CSI 100 86 | python learn.py --model_name MLP --data_set csi100 --hidden_size 512 --num_layers 3 --outdir ./output/csi100_MLP 87 | 88 | # MLP on CSI 300 89 | python learn.py --model_name MLP --data_set csi300 --hidden_size 512 --num_layers 3 --outdir ./output/csi300_MLP 90 | ``` 91 | 92 | * LSTM 93 | ``` 94 | # LSTM on CSI 100 95 | python learn.py --model_name LSTM --data_set csi100 --hidden_size 128 --num_layers 2 --outdir ./output/csi100_LSTM 96 | 97 | # LSTM on CSI 300 98 | python learn.py --model_name LSTM --data_set csi300 --hidden_size 128 --num_layers 2 --outdir ./output/csi300_LSTM 99 | ``` 100 | 101 | * GRU 102 | ``` 103 | # GRU on CSI 100 104 | python learn.py --model_name GRU --data_set csi100 --hidden_size 128 --num_layers 2 --outdir ./output/csi100_GRU 105 | 106 | # GRU on CSI 300 107 | python learn.py --model_name GRU --data_set csi300 --hidden_size 64 --num_layers 2 --outdir ./output/csi300_GRU 108 | ``` 109 | 110 | * SFM 111 | ``` 112 | # SFM on CSI 100 113 | python learn.py --model_name SFM --data_set csi100 --hidden_size 64 --num_layers 2 --outdir ./output/csi100_SFM 114 | 115 | # SFM on CSI 300 116 | python learn.py --model_name SFM --data_set csi300 --hidden_size 128 --num_layers 2 --outdir ./output/csi300_SFM 117 | ``` 118 | 119 | * GATs 120 | ``` 121 | # GATs on CSI 100 122 | python learn.py --model_name GATs --data_set csi100 --hidden_size 128 --num_layers 2 --outdir ./output/csi100_GATs 123 | 124 | # GATs on CSI 300 125 | python learn.py --model_name GATs --data_set csi300 --hidden_size 64 --num_layers 2 --outdir ./output/csi300_GATs 126 | ``` 127 | 128 | * ALSTM 
129 | ``` 130 | # ALSTM on CSI 100 131 | python learn.py --model_name ALSTM --data_set csi100 --hidden_size 64 --num_layers 2 --outdir ./output/csi100_ALSTM 132 | 133 | # ALSTM on CSI 300 134 | python learn.py --model_name ALSTM --data_set csi300 --hidden_size 128 --num_layers 2 --outdir ./output/csi300_ALSTM 135 | ``` 136 | 137 | * Transformer 138 | ``` 139 | # Transformer on CSI 100 140 | python learn.py --model_name Transformer --data_set csi100 --hidden_size 32 --num_layers 3 --outdir ./output/csi100_Transformer 141 | 142 | # Transformer on CSI 300 143 | python learn.py --model_name Transformer --data_set csi300 --hidden_size 32 --num_layers 3 --outdir ./output/csi300_Transformer 144 | ``` 145 | 146 | * ALSTM+TRA 147 | 148 | We reproduce the ALSTM+TRA with its [source code](https://github.com/microsoft/qlib/tree/main/examples/benchmarks/TRA). 149 | 150 | ### Acknowledgements 151 | Special thanks to ChenFeng, Zhang Mingze, Tian Junxi and LiTingXin for their help and discussion! 152 | Thanks for the clean and efficient [HIST](https://github.com/Wentao-Xu/HIST) code. 153 | 154 | 155 | ## Citation 156 | Please cite the following paper if you use this code in your work. 157 | ``` 158 | @article{wang2022mtmd, 159 | title={MTMD: Multi-Scale Temporal Memory Learning and Efficient Debiasing Framework for Stock Trend Forecasting}, 160 | author={Mingjie Wang and Juanxi Tian and Mingze Zhang and Jianxiong Guo and Weijia Jia}, 161 | journal={arXiv preprint arXiv:2212.08656}, 162 | year={2022} 163 | } 164 | ``` 165 | 166 | -------------------------------------------------------------------------------- /model2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import numpy as np 5 | from utils import cal_cos_similarity 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | 10 | class MLP(nn.Module): 11 | 12 | def __init__(self, d_feat, hidden_size=512, num_layers=3, dropout=0.0): 13 | super().__init__() 14 | 15 | self.mlp = nn.Sequential() 16 | 17 | for i in range(num_layers): 18 | if i > 0: 19 | self.mlp.add_module('drop_%d'%i, nn.Dropout(dropout)) 20 | self.mlp.add_module('fc_%d'%i, nn.Linear( 21 | 360 if i == 0 else hidden_size, hidden_size)) 22 | self.mlp.add_module('relu_%d'%i, nn.ReLU()) 23 | 24 | self.mlp.add_module('fc_out', nn.Linear(hidden_size, 1)) 25 | 26 | def forward(self, x): 27 | # feature 28 | # [N, F] 29 | return self.mlp(x).squeeze() 30 | def get_score(K, Q): # similarity between queries Q [n, d] and memory keys K [m, d] 31 | score = torch.matmul(Q, torch.t(K)) 32 | score_query = F.softmax(score, dim=0) 33 | score_memory = F.softmax(score,dim=1) 34 | return score_query, score_memory 35 | 36 | def read(Q,K): # memory read: gate each query by its softmax-weighted retrieval from K 37 | softmax_score_query, softmax_score_memory = get_score(K,Q) 38 | query_reshape = Q.contiguous() 39 | concat_memory = torch.matmul(softmax_score_memory.detach(), K) # (b X h X w) X d 40 | updated_query = query_reshape*concat_memory 41 | return updated_query,softmax_score_query,softmax_score_memory 42 | 43 | def upload(Q,K): # memory update: write each slot's best-matching queries back into K 44 | softmax_score_query, softmax_score_memory = get_score(K,Q) 45 | query_reshape = Q.contiguous() 46 | _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1) 47 | _, updating_indices = torch.topk(softmax_score_query, 1, dim=0) 48 | m, d = K.size() 49 | query_update = torch.zeros((m,d)).cuda() 50 | random_update = torch.zeros((m,d)).cuda() 51 | for i in range(m): 52 | idx = torch.nonzero(gathering_indices.squeeze(1)==i) 53 | a, _ = idx.size() 54 | if a != 0: 55 | 
query_update[i] = torch.sum(((softmax_score_query[idx,i] / torch.max(softmax_score_query[:,i])) 56 | *query_reshape[idx].squeeze(1)), dim=0) 57 | else: 58 | query_update[i] = 0 59 | updated_memory = F.normalize(query_update.cuda() + K.cuda(), dim=1) 60 | return updated_memory.detach() 61 | 62 | class HIST(nn.Module): 63 | def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU", K =3): 64 | super().__init__() 65 | 66 | self.d_feat = d_feat 67 | self.hidden_size = hidden_size 68 | 69 | self.rnn = nn.GRU( 70 | input_size=d_feat, 71 | hidden_size=hidden_size, 72 | num_layers=num_layers, 73 | batch_first=True, 74 | dropout=dropout, 75 | ) 76 | 77 | self.fc_ps = nn.Linear(hidden_size, hidden_size) 78 | torch.nn.init.xavier_uniform_(self.fc_ps.weight) 79 | self.fc_hs = nn.Linear(hidden_size, hidden_size) 80 | torch.nn.init.xavier_uniform_(self.fc_hs.weight) 81 | 82 | self.fc_ps_fore = nn.Linear(hidden_size, hidden_size) 83 | torch.nn.init.xavier_uniform_(self.fc_ps_fore.weight) 84 | self.fc_hs_fore = nn.Linear(hidden_size, hidden_size) 85 | torch.nn.init.xavier_uniform_(self.fc_hs_fore.weight) 86 | 87 | self.fc_ps_back = nn.Linear(hidden_size, hidden_size) 88 | torch.nn.init.xavier_uniform_(self.fc_ps_back.weight) 89 | self.fc_hs_back = nn.Linear(hidden_size, hidden_size) 90 | torch.nn.init.xavier_uniform_(self.fc_hs_back.weight) 91 | self.fc_indi = nn.Linear(hidden_size, hidden_size) 92 | torch.nn.init.xavier_uniform_(self.fc_indi.weight) 93 | 94 | self.leaky_relu = nn.LeakyReLU() 95 | self.softmax_s2t = torch.nn.Softmax(dim = 0) 96 | self.softmax_t2s = torch.nn.Softmax(dim = 1) 97 | 98 | self.fc_out_ps = nn.Linear(hidden_size, 1) 99 | self.fc_out_hs = nn.Linear(hidden_size, 1) 100 | self.fc_out_indi = nn.Linear(hidden_size, 1) 101 | self.fc_out = nn.Linear(hidden_size, 1) 102 | self.K = K 103 | 104 | def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same 105 | xy = x.mm(torch.t(y)) 106 | x_norm = torch.sqrt(torch.sum(x*x, dim =1)).reshape(-1, 1) 107 | y_norm = torch.sqrt(torch.sum(y*y, dim =1)).reshape(-1, 1) 108 | cos_similarity = xy/x_norm.mm(torch.t(y_norm)) 109 | cos_similarity[cos_similarity != cos_similarity] = 0 110 | return cos_similarity 111 | 112 | def forward(self, x, concept_matrix, market_value,m_items,train): 113 | #print(x.size()) 114 | device = torch.device(torch.get_device(x)) 115 | ######################### modification ######################### 116 | # get memory items; this part does not need to change 117 | m_item0, m_item1 = m_items 118 | #[m_item1] = m_items 119 | ######################### modification ######################### 120 | x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T] 121 | x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F] 122 | x_hidden, _ = self.rnn(x_hidden) 123 | x_hidden = x_hidden[:, -1, :] 124 | # print(x_hidden.shape) 125 | # print(m_items0.shape) 126 | 127 | # Predefined Concept Module 128 | # representation initialization 129 | market_value_matrix = market_value.reshape(market_value.shape[0], 1).repeat(1, concept_matrix.shape[1]) 130 | stock_to_concept = concept_matrix * market_value_matrix 131 | stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1) 132 | stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix) 133 | 134 | stock_to_concept_sum = stock_to_concept_sum + (torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)) 135 | stock_to_concept = stock_to_concept / stock_to_concept_sum 136 | hidden = 
torch.t(stock_to_concept).mm(x_hidden) 137 | 138 | hidden = hidden[hidden.sum(1)!=0] 139 | 140 | stock_to_concept = x_hidden.mm(torch.t(hidden)) 141 | # stock_to_concept = cal_cos_similarity(x_hidden, hidden) 142 | 143 | stock_to_concept = self.softmax_s2t(stock_to_concept) 144 | hidden = torch.t(stock_to_concept).mm(x_hidden) 145 | 146 | concept_to_stock = cal_cos_similarity(x_hidden, hidden) 147 | concept_to_stock = self.softmax_t2s(concept_to_stock) 148 | 149 | p_shared_info = concept_to_stock.mm(hidden) 150 | p_shared_info = self.fc_ps(p_shared_info) 151 | 152 | # ######################### modification ######################### 153 | # ## Modification point 0: add the memory module to X^(t,0), i.e. x_hidden 154 | # #import pdb;pdb.set_trace() 155 | p_shared_info, _,softmax_score_memory0 = read(p_shared_info,m_item0) 156 | if train: 157 | softmax_score_memory0 = upload(p_shared_info,m_item0) 158 | else: 159 | softmax_score_memory0 = m_item0 160 | ######################### modification ######################### 161 | 162 | p_shared_back = self.fc_ps_back(p_shared_info) # output of backcast branch 163 | output_ps = self.fc_ps_fore(p_shared_info) # output of forecast branch 164 | output_ps = self.leaky_relu(output_ps) 165 | 166 | pred_ps = self.fc_out_ps(output_ps).squeeze() 167 | 168 | # Hidden Concept Module 169 | h_shared_info = x_hidden - p_shared_back 170 | hidden = h_shared_info 171 | h_stock_to_concept = cal_cos_similarity(hidden, hidden) 172 | 173 | dim = h_stock_to_concept.shape[0] 174 | diag = h_stock_to_concept.diagonal(0) 175 | 176 | h_stock_to_concept = h_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device) 177 | # row = torch.linspace(0,dim-1,dim).to(device).long() 178 | # column = h_stock_to_concept.argmax(1) 179 | row = torch.linspace(0, dim-1, dim).reshape([-1, 1]).repeat(1, self.K).reshape(1, -1).long().to(device) 180 | column = torch.topk(h_stock_to_concept, self.K, dim = 1)[1].reshape(1, -1) 181 | mask = torch.zeros([h_stock_to_concept.shape[0], h_stock_to_concept.shape[1]], device = h_stock_to_concept.device) 182 | mask[row, column] = 1 183 | h_stock_to_concept = h_stock_to_concept * mask 184 | h_stock_to_concept = h_stock_to_concept + torch.diag_embed((h_stock_to_concept.sum(0)!=0).float()*diag) 185 | hidden = torch.t(h_shared_info).mm(h_stock_to_concept).t() 186 | hidden = hidden[hidden.sum(1)!=0] 187 | 188 | h_concept_to_stock = cal_cos_similarity(h_shared_info, hidden) 189 | 190 | 191 | h_concept_to_stock = self.softmax_t2s(h_concept_to_stock) 192 | 193 | h_shared_info = h_concept_to_stock.mm(hidden) 194 | h_shared_info = self.fc_hs(h_shared_info) 195 | 196 | ######################### modification ######################### 197 | ## Modification point 1 198 | #import pdb;pdb.set_trace() 199 | h_shared_info, _,softmax_score_memory1 = read(h_shared_info,m_item1) 200 | 201 | if train: 202 | softmax_score_memory1 = upload(h_shared_info,m_item1) 203 | else: 204 | softmax_score_memory1 = m_item1 205 | ######################### modification ######################### 206 | 207 | h_shared_back = self.fc_hs_back(h_shared_info) 208 | output_hs = self.fc_hs_fore(h_shared_info) 209 | 210 | output_hs = self.leaky_relu(output_hs) 211 | pred_hs = self.fc_out_hs(output_hs).squeeze() 212 | 213 | # Individual Information Module 214 | 215 | individual_info = x_hidden - p_shared_back - h_shared_back 216 | # ######################### modification ######################### 217 | # ## Modification point 2: add the memory module to X^(t,2), i.e. individual_info 218 | # #import pdb;pdb.set_trace() 219 | # individual_info, _,softmax_score_memory2 = 
read(individual_info,m_item2) 220 | # if train: 221 | # softmax_score_memory2 = upload(individual_info,m_item2) 222 | # else: 223 | # softmax_score_memory2 = m_item2 224 | # ######################### modification ######################### 225 | output_indi = individual_info 226 | output_indi = self.fc_indi(output_indi) 227 | output_indi = self.leaky_relu(output_indi) 228 | pred_indi = self.fc_out_indi(output_indi).squeeze() 229 | # Stock Trend Prediction 230 | all_info = output_ps + output_hs + output_indi 231 | pred_all = self.fc_out(all_info).squeeze() 232 | return pred_all, [softmax_score_memory0,softmax_score_memory1] #m_items -------------------------------------------------------------------------------- /learn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os 5 | import copy 6 | import json 7 | import argparse 8 | import datetime 9 | import collections 10 | import numpy as np 11 | import pandas as pd 12 | from tqdm import tqdm 13 | import qlib 14 | # region in [REG_CN, REG_US] 15 | from qlib.config import REG_US, REG_CN 16 | # provider_uri = "~/.qlib/qlib_data/us_data" # target_dir 17 | provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir 18 | qlib.init(provider_uri=provider_uri, region=REG_CN) 19 | from qlib.data.dataset import DatasetH 20 | from qlib.data.dataset.handler import DataHandlerLP 21 | from torch.utils.tensorboard import SummaryWriter 22 | from qlib.contrib.model.pytorch_gru import GRUModel 23 | from qlib.contrib.model.pytorch_lstm import LSTMModel 24 | from qlib.contrib.model.pytorch_gats import GATModel 25 | from qlib.contrib.model.pytorch_sfm import SFM_Model 26 | from qlib.contrib.model.pytorch_alstm import ALSTMModel 27 | from qlib.contrib.model.pytorch_transformer import Transformer 28 | from model import MLP, HIST 29 | from utils import metric_fn, mse 30 | from dataloader import DataLoader 31 | 32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 33 | 34 | EPS = 1e-12 35 | 36 | 37 | def get_model(model_name): 38 | 39 | if model_name.upper() == 'MLP': 40 | return MLP 41 | 42 | if model_name.upper() == 'LSTM': 43 | return LSTMModel 44 | 45 | if model_name.upper() == 'GRU': 46 | return GRUModel 47 | 48 | if model_name.upper() == 'GATS': 49 | return GATModel 50 | 51 | if model_name.upper() == 'SFM': 52 | return SFM_Model 53 | 54 | if model_name.upper() == 'ALSTM': 55 | return ALSTMModel 56 | 57 | if model_name.upper() == 'TRANSFORMER': 58 | return Transformer 59 | 60 | if model_name.upper() == 'HIST': 61 | return HIST 62 | 63 | raise ValueError('unknown model name `%s`'%model_name) 64 | 65 | 66 | def average_params(params_list): 67 | assert isinstance(params_list, (tuple, list, collections.deque)) 68 | n = len(params_list) 69 | if n == 1: 70 | return params_list[0] 71 | new_params = collections.OrderedDict() 72 | keys = None 73 | for i, params in enumerate(params_list): 74 | if keys is None: 75 | keys = params.keys() 76 | for k, v in params.items(): 77 | if k not in keys: 78 | raise ValueError('the %d-th model has different params'%i) 79 | if k not in new_params: 80 | new_params[k] = v / n 81 | else: 82 | new_params[k] += v / n 83 | return new_params 84 | 85 | 86 | 87 | def loss_fn(pred, label, args): 88 | mask = ~torch.isnan(label) 89 | return mse(pred[mask], label[mask]) 90 | 91 | 92 | global_log_file = None 93 | def pprint(*args): 94 | # print with UTC+8 time 95 | time = '['+str(datetime.datetime.utcnow()+ 96 | 
datetime.timedelta(hours=8))[:19]+'] -' 97 | print(time, *args, flush=True) 98 | 99 | if global_log_file is None: 100 | return 101 | with open(global_log_file, 'a') as f: 102 | print(time, *args, flush=True, file=f) 103 | 104 | 105 | global_step = -1 106 | def train_epoch(epoch, model, optimizer, train_loader, writer, args, stock2concept_matrix = None): 107 | 108 | global global_step 109 | 110 | model.train() 111 | 112 | for i, slc in tqdm(train_loader.iter_batch(), total=train_loader.batch_length): 113 | global_step += 1 114 | feature, label, market_value , stock_index, _ = train_loader.get(slc) 115 | if args.model_name == 'HIST': 116 | pred = model(feature, stock2concept_matrix[stock_index], market_value) 117 | else: 118 | pred = model(feature) 119 | loss = loss_fn(pred, label, args) 120 | 121 | optimizer.zero_grad() 122 | loss.backward() 123 | torch.nn.utils.clip_grad_value_(model.parameters(), 3.) 124 | optimizer.step() 125 | 126 | 127 | def test_epoch(epoch, model, test_loader, writer, args, stock2concept_matrix=None, prefix='Test'): 128 | 129 | model.eval() 130 | 131 | losses = [] 132 | preds = [] 133 | 134 | for i, slc in tqdm(test_loader.iter_daily(), desc=prefix, total=test_loader.daily_length): 135 | 136 | feature, label, market_value, stock_index, index = test_loader.get(slc) 137 | 138 | with torch.no_grad(): 139 | if args.model_name == 'HIST': 140 | pred = model(feature, stock2concept_matrix[stock_index], market_value) 141 | else: 142 | pred = model(feature) 143 | loss = loss_fn(pred, label, args) 144 | preds.append(pd.DataFrame({ 'score': pred.cpu().numpy(), 'label': label.cpu().numpy(), }, index=index)) 145 | 146 | losses.append(loss.item()) 147 | #evaluate 148 | preds = pd.concat(preds, axis=0) 149 | precision, recall, ic, rank_ic = metric_fn(preds) 150 | scores = ic 151 | # scores = (precision[3] + precision[5] + precision[10] + precision[30])/4.0 152 | # scores = -1.0 * mse 153 | 154 | writer.add_scalar(prefix+'/Loss', np.mean(losses), epoch) 155 | writer.add_scalar(prefix+'/std(Loss)', np.std(losses), epoch) 156 | writer.add_scalar(prefix+'/'+args.metric, np.mean(scores), epoch) 157 | writer.add_scalar(prefix+'/std('+args.metric+')', np.std(scores), epoch) 158 | 159 | return np.mean(losses), scores, precision, recall, ic, rank_ic 160 | 161 | def inference(model, data_loader, stock2concept_matrix=None): 162 | 163 | model.eval() 164 | 165 | preds = [] 166 | for i, slc in tqdm(data_loader.iter_daily(), total=data_loader.daily_length): 167 | 168 | feature, label, market_value, stock_index, index = data_loader.get(slc) 169 | with torch.no_grad(): 170 | if args.model_name == 'HIST': 171 | pred = model(feature, stock2concept_matrix[stock_index], market_value) 172 | else: 173 | pred = model(feature) 174 | preds.append(pd.DataFrame({ 'score': pred.cpu().numpy(), 'label': label.cpu().numpy(), }, index=index)) 175 | 176 | preds = pd.concat(preds, axis=0) 177 | return preds 178 | 179 | 180 | def create_loaders(args): 181 | 182 | start_time = datetime.datetime.strptime(args.train_start_date, '%Y-%m-%d') 183 | end_time = datetime.datetime.strptime(args.test_end_date, '%Y-%m-%d') 184 | train_end_time = datetime.datetime.strptime(args.train_end_date, '%Y-%m-%d') 185 | 186 | hanlder = {'class': 'Alpha360', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': start_time, 'end_time': end_time, 'fit_start_time': start_time, 'fit_end_time': train_end_time, 'instruments': args.data_set, 'infer_processors': [{'class': 'RobustZScoreNorm', 'kwargs': {'fields_group': 'feature', 
'clip_outlier': True}}, {'class': 'Fillna', 'kwargs': {'fields_group': 'feature'}}], 'learn_processors': [{'class': 'DropnaLabel'}, {'class': 'CSRankNorm', 'kwargs': {'fields_group': 'label'}}], 'label': ['Ref($close, -1) / $close - 1']}} 187 | segments = { 'train': (args.train_start_date, args.train_end_date), 'valid': (args.valid_start_date, args.valid_end_date), 'test': (args.test_start_date, args.test_end_date)} 188 | dataset = DatasetH(hanlder,segments) 189 | 190 | df_train, df_valid, df_test = dataset.prepare( ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L,) 191 | import pickle5 as pickle 192 | with open(args.market_value_path, "rb") as fh: 193 | df_market_value = pickle.load(fh) 194 | #df_market_value = pd.read_pickle(args.market_value_path) 195 | df_market_value = df_market_value/1000000000 196 | stock_index = np.load(args.stock_index, allow_pickle=True).item() 197 | 198 | start_index = 0 199 | slc = slice(pd.Timestamp(args.train_start_date), pd.Timestamp(args.train_end_date)) 200 | df_train['market_value'] = df_market_value[slc] 201 | df_train['market_value'] = df_train['market_value'].fillna(df_train['market_value'].mean()) 202 | df_train['stock_index'] = 733 203 | df_train['stock_index'] = df_train.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int) 204 | 205 | train_loader = DataLoader(df_train["feature"], df_train["label"], df_train['market_value'], df_train['stock_index'], batch_size=args.batch_size, pin_memory=args.pin_memory, start_index=start_index, device = device) 206 | 207 | slc = slice(pd.Timestamp(args.valid_start_date), pd.Timestamp(args.valid_end_date)) 208 | df_valid['market_value'] = df_market_value[slc] 209 | df_valid['market_value'] = df_valid['market_value'].fillna(df_train['market_value'].mean()) 210 | df_valid['stock_index'] = 733 211 | df_valid['stock_index'] = df_valid.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int) 212 | start_index += len(df_valid.groupby(level=0).size()) 213 | 214 | valid_loader = DataLoader(df_valid["feature"], df_valid["label"], df_valid['market_value'], df_valid['stock_index'], pin_memory=True, start_index=start_index, device = device) 215 | 216 | slc = slice(pd.Timestamp(args.test_start_date), pd.Timestamp(args.test_end_date)) 217 | df_test['market_value'] = df_market_value[slc] 218 | df_test['market_value'] = df_test['market_value'].fillna(df_train['market_value'].mean()) 219 | df_test['stock_index'] = 733 220 | df_test['stock_index'] = df_test.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int) 221 | start_index += len(df_test.groupby(level=0).size()) 222 | 223 | test_loader = DataLoader(df_test["feature"], df_test["label"], df_test['market_value'], df_test['stock_index'], pin_memory=True, start_index=start_index, device = device) 224 | 225 | return train_loader, valid_loader, test_loader 226 | 227 | 228 | def main(args): 229 | seed = np.random.randint(1000000) 230 | np.random.seed(seed) 231 | torch.manual_seed(seed) 232 | suffix = "%s_dh%s_dn%s_drop%s_lr%s_bs%s_seed%s%s"%( 233 | args.model_name, args.hidden_size, args.num_layers, args.dropout, 234 | args.lr, args.batch_size, args.seed, args.annot 235 | ) 236 | 237 | output_path = args.outdir 238 | if not output_path: 239 | output_path = './output/' + suffix 240 | if not os.path.exists(output_path): 241 | os.makedirs(output_path) 242 | 243 | if not args.overwrite and os.path.exists(output_path+'/'+'info.json'): 244 | print('already runned, exit.') 245 | return 
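    # NOTE: the training loop below keeps the last `smooth_steps` checkpoints in
    # a deque and evaluates the element-wise average of their parameters (see
    # average_params above, a simple Polyak/SWA-style smoothing); after each
    # evaluation the un-averaged checkpoint is restored so training continues
    # from the raw weights.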
246 | 247 | writer = SummaryWriter(log_dir=output_path) 248 | 249 | global global_log_file 250 | global_log_file = output_path + '/' + args.name + '_run.log' 251 | 252 | pprint('create loaders...') 253 | train_loader, valid_loader, test_loader = create_loaders(args) 254 | 255 | stock2concept_matrix = np.load(args.stock2concept_matrix) 256 | if args.model_name == 'HIST': 257 | stock2concept_matrix = torch.Tensor(stock2concept_matrix).to(device) 258 | 259 | all_precision = [] 260 | all_recall = [] 261 | all_ic = [] 262 | all_rank_ic = [] 263 | for times in range(args.repeat): 264 | pprint('create model...') 265 | if args.model_name == 'SFM': 266 | model = get_model(args.model_name)(d_feat = args.d_feat, output_dim = 32, freq_dim = 25, hidden_size = args.hidden_size, dropout_W = 0.5, dropout_U = 0.5, device = device) 267 | elif args.model_name == 'ALSTM': 268 | model = get_model(args.model_name)(args.d_feat, args.hidden_size, args.num_layers, args.dropout, 'LSTM') 269 | elif args.model_name == 'Transformer': 270 | model = get_model(args.model_name)(args.d_feat, args.hidden_size, args.num_layers, dropout=0.5) 271 | elif args.model_name == 'HIST': 272 | model = get_model(args.model_name)(d_feat = args.d_feat, num_layers = args.num_layers, K = args.K) 273 | else: 274 | model = get_model(args.model_name)(d_feat = args.d_feat, num_layers = args.num_layers) 275 | 276 | model.to(device) 277 | 278 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 279 | best_score = -np.inf 280 | best_epoch = 0 281 | stop_round = 0 282 | best_param = copy.deepcopy(model.state_dict()) 283 | params_list = collections.deque(maxlen=args.smooth_steps) 284 | for epoch in range(args.n_epochs): 285 | pprint('Running', times,'Epoch:', epoch) 286 | 287 | pprint('training...') 288 | train_epoch(epoch, model, optimizer, train_loader, writer, args, stock2concept_matrix) 289 | torch.save(model.state_dict(), output_path+'/model.bin.e'+str(epoch)) 290 | torch.save(optimizer.state_dict(), output_path+'/optimizer.bin.e'+str(epoch)) 291 | 292 | params_ckpt = copy.deepcopy(model.state_dict()) 293 | params_list.append(params_ckpt) 294 | avg_params = average_params(params_list) 295 | model.load_state_dict(avg_params) 296 | 297 | pprint('evaluating...') 298 | train_loss, train_score, train_precision, train_recall, train_ic, train_rank_ic = test_epoch(epoch, model, train_loader, writer, args, stock2concept_matrix, prefix='Train') 299 | val_loss, val_score, val_precision, val_recall, val_ic, val_rank_ic = test_epoch(epoch, model, valid_loader, writer, args, stock2concept_matrix, prefix='Valid') 300 | test_loss, test_score, test_precision, test_recall, test_ic, test_rank_ic = test_epoch(epoch, model, test_loader, writer, args, stock2concept_matrix, prefix='Test') 301 | 302 | pprint('train_loss %.6f, valid_loss %.6f, test_loss %.6f'%(train_loss, val_loss, test_loss)) 303 | pprint('train_score %.6f, valid_score %.6f, test_score %.6f'%(train_score, val_score, test_score)) 304 | # pprint('train_mse %.6f, valid_mse %.6f, test_mse %.6f'%(train_mse, val_mse, test_mse)) 305 | # pprint('train_mae %.6f, valid_mae %.6f, test_mae %.6f'%(train_mae, val_mae, test_mae)) 306 | pprint('train_ic %.6f, valid_ic %.6f, test_ic %.6f'%(train_ic, val_ic, test_ic)) 307 | pprint('train_rank_ic %.6f, valid_rank_ic %.6f, test_rank_ic %.6f'%(train_rank_ic, val_rank_ic, test_rank_ic)) 308 | pprint('Train Precision: ', train_precision) 309 | pprint('Valid Precision: ', val_precision) 310 | pprint('Test Precision: ', test_precision) 311 | pprint('Train Recall: ', 
train_recall) 312 | pprint('Valid Recall: ', val_recall) 313 | pprint('Test Recall: ', test_recall) 314 | model.load_state_dict(params_ckpt) 315 | 316 | if val_score > best_score: 317 | best_score = val_score 318 | stop_round = 0 319 | best_epoch = epoch 320 | best_param = copy.deepcopy(avg_params) 321 | else: 322 | stop_round += 1 323 | if stop_round >= args.early_stop: 324 | pprint('early stop') 325 | break 326 | 327 | pprint('best score:', best_score, '@', best_epoch) 328 | model.load_state_dict(best_param) 329 | torch.save(best_param, output_path+'/model.bin') 330 | 331 | pprint('inference...') 332 | res = dict() 333 | for name in ['train', 'valid', 'test']: 334 | 335 | pred= inference(model, eval(name+'_loader'), stock2concept_matrix) 336 | pred.to_pickle(output_path+'/pred.pkl.'+name+str(times)) 337 | 338 | precision, recall, ic, rank_ic = metric_fn(pred) 339 | 340 | pprint(('%s: IC %.6f Rank IC %.6f')%( 341 | name, ic.mean(), rank_ic.mean())) 342 | pprint(name, ': Precision ', precision) 343 | pprint(name, ': Recall ', recall) 344 | res[name+'-IC'] = ic 345 | # res[name+'-ICIR'] = ic.mean() / ic.std() 346 | res[name+'-RankIC'] = rank_ic 347 | # res[name+'-RankICIR'] = rank_ic.mean() / rank_ic.std() 348 | 349 | all_precision.append(list(precision.values())) 350 | all_recall.append(list(recall.values())) 351 | all_ic.append(ic) 352 | all_rank_ic.append(rank_ic) 353 | 354 | pprint('save info...') 355 | writer.add_hparams( 356 | vars(args), 357 | { 358 | 'hparam/'+key: value 359 | for key, value in res.items() 360 | } 361 | ) 362 | 363 | info = dict( 364 | config=vars(args), 365 | best_epoch=best_epoch, 366 | best_score=res, 367 | ) 368 | default = lambda x: str(x)[:10] if isinstance(x, pd.Timestamp) else x 369 | with open(output_path+'/info.json', 'w') as f: 370 | json.dump(info, f, default=default, indent=4) 371 | pprint(('IC: %.4f (%.4f), Rank IC: %.4f (%.4f)')%(np.array(all_ic).mean(), np.array(all_ic).std(), np.array(all_rank_ic).mean(), np.array(all_rank_ic).std())) 372 | precision_mean = np.array(all_precision).mean(axis= 0) 373 | precision_std = np.array(all_precision).std(axis= 0) 374 | N = [1, 3, 5, 10, 20, 30, 50, 100] 375 | for k in range(len(N)): 376 | pprint (('Precision@%d: %.4f (%.4f)')%(N[k], precision_mean[k], precision_std[k])) 377 | 378 | pprint('finished.') 379 | 380 | 381 | class ParseConfigFile(argparse.Action): 382 | 383 | def __call__(self, parser, namespace, filename, option_string=None): 384 | 385 | if not os.path.exists(filename): 386 | raise ValueError('cannot find config at `%s`'%filename) 387 | 388 | with open(filename) as f: 389 | config = json.load(f) 390 | for key, value in config.items(): 391 | setattr(namespace, key, value) 392 | 393 | 394 | def parse_args(): 395 | 396 | parser = argparse.ArgumentParser() 397 | 398 | # model 399 | parser.add_argument('--model_name', default='HIST') 400 | parser.add_argument('--d_feat', type=int, default=6) 401 | parser.add_argument('--hidden_size', type=int, default=128) 402 | parser.add_argument('--num_layers', type=int, default=2) 403 | parser.add_argument('--dropout', type=float, default=0.0) 404 | parser.add_argument('--K', type=int, default=1) 405 | 406 | # training 407 | parser.add_argument('--n_epochs', type=int, default=200) 408 | parser.add_argument('--lr', type=float, default=2e-4) 409 | parser.add_argument('--early_stop', type=int, default=30) 410 | parser.add_argument('--smooth_steps', type=int, default=5) 411 | parser.add_argument('--metric', default='IC') 412 | parser.add_argument('--loss', 
default='mse') 413 | parser.add_argument('--repeat', type=int, default=10) 414 | 415 | # data 416 | parser.add_argument('--data_set', type=str, default='csi300') 417 | parser.add_argument('--pin_memory', action='store_false', default=True) 418 | parser.add_argument('--batch_size', type=int, default=-1) # -1 indicates a daily batch 419 | parser.add_argument('--least_samples_num', type=float, default=1137.0) 420 | parser.add_argument('--label', default='') # specify other labels 421 | parser.add_argument('--train_start_date', default='2007-01-01') 422 | parser.add_argument('--train_end_date', default='2014-12-31') 423 | parser.add_argument('--valid_start_date', default='2015-01-01') 424 | parser.add_argument('--valid_end_date', default='2016-12-31') 425 | parser.add_argument('--test_start_date', default='2017-01-01') 426 | parser.add_argument('--test_end_date', default='2020-12-31') 427 | 428 | # other 429 | parser.add_argument('--seed', type=int, default=0) 430 | parser.add_argument('--annot', default='') 431 | parser.add_argument('--config', action=ParseConfigFile, default='') 432 | parser.add_argument('--name', type=str, default='csi300_HIST') 433 | 434 | # input for csi 300 435 | parser.add_argument('--market_value_path', default='./data/csi300_market_value_07to20.pkl') 436 | parser.add_argument('--stock2concept_matrix', default='./data/csi300_stock2concept.npy') 437 | parser.add_argument('--stock_index', default='./data/csi300_stock_index.npy') 438 | 439 | parser.add_argument('--outdir', default='./output/csi300_HIST') 440 | parser.add_argument('--overwrite', action='store_true', default=False) 441 | 442 | args = parser.parse_args() 443 | 444 | return args 445 | 446 | 447 | if __name__ == '__main__': 448 | 449 | args = parse_args() 450 | main(args) 451 | -------------------------------------------------------------------------------- /learn_memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import os 5 | import copy 6 | import json 7 | import argparse 8 | import datetime 9 | import collections 10 | import numpy as np 11 | import pandas as pd 12 | from tqdm import tqdm 13 | import qlib 14 | # region in [REG_CN, REG_US] 15 | from qlib.config import REG_US, REG_CN 16 | # provider_uri = "~/.qlib/qlib_data/us_data" # target_dir 17 | provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir 18 | qlib.init(provider_uri=provider_uri, region=REG_CN) 19 | from qlib.data.dataset import DatasetH 20 | from qlib.data.dataset.handler import DataHandlerLP 21 | from torch.utils.tensorboard import SummaryWriter 22 | from qlib.contrib.model.pytorch_gru import GRUModel 23 | from qlib.contrib.model.pytorch_lstm import LSTMModel 24 | from qlib.contrib.model.pytorch_gats import GATModel 25 | from qlib.contrib.model.pytorch_sfm import SFM_Model 26 | from qlib.contrib.model.pytorch_alstm import ALSTMModel 27 | from qlib.contrib.model.pytorch_transformer import Transformer 28 | from model2 import * 29 | from utils import metric_fn, mse 30 | from dataloader import DataLoader 31 | 32 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 33 | 34 | EPS = 1e-12 35 | 36 | 37 | def get_model(model_name): 38 | 39 | if model_name.upper() == 'MLP': 40 | return MLP 41 | 42 | if model_name.upper() == 'LSTM': 43 | return LSTMModel 44 | 45 | if model_name.upper() == 'GRU': 46 | return GRUModel 47 | 48 | if model_name.upper() == 'GATS': 49 | return GATModel 50 | 51 | if 
model_name.upper() == 'SFM': 52 | return SFM_Model 53 | 54 | if model_name.upper() == 'ALSTM': 55 | return ALSTMModel 56 | 57 | if model_name.upper() == 'TRANSFORMER': 58 | return Transformer 59 | 60 | if model_name.upper() == 'HIST': 61 | return HIST 62 | 63 | raise ValueError('unknown model name `%s`'%model_name) 64 | 65 | def gather_loss(query, keys): 66 | 67 | batch_size,dims = query.size() # b X h X w X d 68 | 69 | loss_mse = torch.nn.MSELoss() 70 | 71 | softmax_score_query, softmax_score_memory = get_score(keys, query) 72 | 73 | query_reshape = query.contiguous().view(batch_size, dims) 74 | 75 | _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1) 76 | 77 | gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach()) 78 | 79 | return gathering_loss 80 | 81 | def spread_loss(query, keys): 82 | batch_size, dims = query.size() # b X h X w X d 83 | 84 | loss = torch.nn.TripletMarginLoss(margin=1.0) 85 | 86 | softmax_score_query, softmax_score_memory = get_score(keys, query) 87 | 88 | query_reshape = query.contiguous().view(batch_size, dims) 89 | 90 | _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1) 91 | 92 | #1st, 2nd closest memories 93 | pos = keys[gathering_indices[:,0]] 94 | neg = keys[gathering_indices[:,1]] 95 | 96 | spreading_loss = loss(query_reshape,pos.detach(), neg.detach()) 97 | 98 | return spreading_loss 99 | 100 | def average_params(params_list): 101 | assert isinstance(params_list, (tuple, list, collections.deque)) 102 | n = len(params_list) 103 | if n == 1: 104 | return params_list[0] 105 | new_params = collections.OrderedDict() 106 | keys = None 107 | for i, params in enumerate(params_list): 108 | if keys is None: 109 | keys = params.keys() 110 | for k, v in params.items(): 111 | if k not in keys: 112 | raise ValueError('the %d-th model has different params'%i) 113 | if k not in new_params: 114 | new_params[k] = v / n 115 | else: 116 | new_params[k] += v / n 117 | return new_params 118 | 119 | 120 | 121 | def loss_fn(pred, label, args): 122 | mask = ~torch.isnan(label) 123 | return mse(pred[mask], label[mask]) 124 | 125 | 126 | global_log_file = None 127 | def pprint(*args): 128 | # print with UTC+8 time 129 | time = '['+str(datetime.datetime.utcnow()+ 130 | datetime.timedelta(hours=8))[:19]+'] -' 131 | print(time, *args, flush=True) 132 | 133 | if global_log_file is None: 134 | return 135 | with open(global_log_file, 'a') as f: 136 | print(time, *args, flush=True, file=f) 137 | 138 | 139 | global_step = -1 140 | def train_epoch(epoch, model, optimizer, train_loader, writer, args, stock2concept_matrix = None,m_items = None): 141 | 142 | global global_step 143 | 144 | model.train() 145 | 146 | for i, slc in tqdm(train_loader.iter_batch(), total=train_loader.batch_length): 147 | # 148 | global_step += 1 149 | feature, label, market_value , stock_index, _ = train_loader.get(slc) 150 | if args.model_name == 'HIST': 151 | pred,m_items = model(feature, stock2concept_matrix[stock_index], market_value,m_items,train=True) 152 | else: 153 | pred = model(feature) 154 | 155 | loss = loss_fn(pred, label, args) 156 | 157 | #loss += (loss_gather_loss*0.1)#+)(loss_spread_loss *0.1) 158 | 159 | optimizer.zero_grad() 160 | loss.backward(retain_graph=True) 161 | torch.nn.utils.clip_grad_value_(model.parameters(), 3.) 
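        # clip_grad_value_ clips every gradient element into [-3, 3] in place;
        # unlike clip_grad_norm_ it bounds individual entries rather than
        # rescaling the global gradient norm.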
162 | optimizer.step() 163 | return m_items 164 | 165 | def test_epoch(rep,epoch, model, test_loader, writer, args, stock2concept_matrix=None,m_items = None, prefix='Test', train=False): 166 | 167 | model.eval() 168 | 169 | losses = [] 170 | preds = [] 171 | 172 | for i, slc in tqdm(test_loader.iter_daily(), desc=prefix, total=test_loader.daily_length): 173 | 174 | feature, label, market_value, stock_index, index = test_loader.get(slc) 175 | 176 | with torch.no_grad(): 177 | if args.model_name == 'HIST': 178 | pred, m_items = model(feature, stock2concept_matrix[stock_index], market_value,m_items,train=False) 179 | else: 180 | pred, m_items= model(feature,m_items,train=False) 181 | loss = loss_fn(pred, label, args) 182 | preds.append(pd.DataFrame({ 'score': pred.cpu().numpy(), 'label': label.cpu().numpy(), }, index=index)) 183 | 184 | losses.append(loss.item()) 185 | #evaluate 186 | preds = pd.concat(preds, axis=0) 187 | if not os.path.exists(args.outdir+"/csv"): 188 | os.makedirs(args.outdir+"/csv") 189 | preds.to_csv(args.outdir+"/csv/+"+args.model_name+"_r"+str(rep)+"_e"+str(epoch)+'.csv') 190 | precision, recall, ic, rank_ic = metric_fn(preds) 191 | scores = ic 192 | # scores = (precision[3] + precision[5] + precision[10] + precision[30])/4.0 193 | # scores = -1.0 * mse 194 | 195 | writer.add_scalar(prefix+'/Loss', np.mean(losses), epoch) 196 | writer.add_scalar(prefix+'/std(Loss)', np.std(losses), epoch) 197 | writer.add_scalar(prefix+'/'+args.metric, np.mean(scores), epoch) 198 | writer.add_scalar(prefix+'/std('+args.metric+')', np.std(scores), epoch) 199 | 200 | return np.mean(losses), scores, precision, recall, ic, rank_ic 201 | 202 | def inference(model, data_loader, stock2concept_matrix=None,m_items = None, train=False): 203 | 204 | model.eval() 205 | 206 | preds = [] 207 | for i, slc in tqdm(data_loader.iter_daily(), total=data_loader.daily_length): 208 | 209 | feature, label, market_value, stock_index, index = data_loader.get(slc) 210 | with torch.no_grad(): 211 | if args.model_name == 'HIST': 212 | pred,m_items_out = model(feature, stock2concept_matrix[stock_index], market_value,m_items,False) 213 | else: 214 | pred = model(feature) 215 | preds.append(pd.DataFrame({ 'score': pred.cpu().numpy(), 'label': label.cpu().numpy(), }, index=index)) 216 | 217 | preds = pd.concat(preds, axis=0) 218 | return preds 219 | 220 | 221 | def create_loaders(args): 222 | 223 | start_time = datetime.datetime.strptime(args.train_start_date, '%Y-%m-%d') 224 | end_time = datetime.datetime.strptime(args.test_end_date, '%Y-%m-%d') 225 | train_end_time = datetime.datetime.strptime(args.train_end_date, '%Y-%m-%d') 226 | 227 | hanlder = {'class': 'Alpha360', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': start_time, 'end_time': end_time, 'fit_start_time': start_time, 'fit_end_time': train_end_time, 'instruments': args.data_set, 'infer_processors': [{'class': 'RobustZScoreNorm', 'kwargs': {'fields_group': 'feature', 'clip_outlier': True}}, {'class': 'Fillna', 'kwargs': {'fields_group': 'feature'}}], 'learn_processors': [{'class': 'DropnaLabel'}, {'class': 'CSRankNorm', 'kwargs': {'fields_group': 'label'}}], 'label': ['Ref($close, -1) / $close - 1']}} 228 | segments = { 'train': (args.train_start_date, args.train_end_date), 'valid': (args.valid_start_date, args.valid_end_date), 'test': (args.test_start_date, args.test_end_date)} 229 | dataset = DatasetH(hanlder,segments) 230 | 231 | df_train, df_valid, df_test = dataset.prepare( ["train", "valid", "test"], col_set=["feature", "label"], 

def create_loaders(args):

    start_time = datetime.datetime.strptime(args.train_start_date, '%Y-%m-%d')
    end_time = datetime.datetime.strptime(args.test_end_date, '%Y-%m-%d')
    train_end_time = datetime.datetime.strptime(args.train_end_date, '%Y-%m-%d')

    handler = {
        'class': 'Alpha360',
        'module_path': 'qlib.contrib.data.handler',
        'kwargs': {
            'start_time': start_time,
            'end_time': end_time,
            'fit_start_time': start_time,
            'fit_end_time': train_end_time,
            'instruments': args.data_set,
            'infer_processors': [
                {'class': 'RobustZScoreNorm', 'kwargs': {'fields_group': 'feature', 'clip_outlier': True}},
                {'class': 'Fillna', 'kwargs': {'fields_group': 'feature'}},
            ],
            'learn_processors': [
                {'class': 'DropnaLabel'},
                {'class': 'CSRankNorm', 'kwargs': {'fields_group': 'label'}},
            ],
            # label is the next-day return
            'label': ['Ref($close, -1) / $close - 1'],
        },
    }
    segments = {
        'train': (args.train_start_date, args.train_end_date),
        'valid': (args.valid_start_date, args.valid_end_date),
        'test': (args.test_start_date, args.test_end_date),
    }
    dataset = DatasetH(handler, segments)

    df_train, df_valid, df_test = dataset.prepare(
        ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L,
    )
    import pickle5 as pickle  # pickle protocol 5 backport for Python < 3.8
    with open(args.market_value_path, "rb") as fh:
        df_market_value = pickle.load(fh)
    # df_market_value = pd.read_pickle(args.market_value_path)
    df_market_value = df_market_value / 1000000000
    stock_index = np.load(args.stock_index, allow_pickle=True).item()

    start_index = 0
    slc = slice(pd.Timestamp(args.train_start_date), pd.Timestamp(args.train_end_date))
    df_train['market_value'] = df_market_value[slc]
    df_train['market_value'] = df_train['market_value'].fillna(df_train['market_value'].mean())
    # instruments missing from the index map fall back to id 733
    df_train['stock_index'] = df_train.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int)

    train_loader = DataLoader(df_train["feature"], df_train["label"], df_train['market_value'], df_train['stock_index'], batch_size=args.batch_size, pin_memory=args.pin_memory, start_index=start_index, device=device)

    slc = slice(pd.Timestamp(args.valid_start_date), pd.Timestamp(args.valid_end_date))
    df_valid['market_value'] = df_market_value[slc]
    df_valid['market_value'] = df_valid['market_value'].fillna(df_train['market_value'].mean())
    df_valid['stock_index'] = df_valid.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int)
    start_index += len(df_valid.groupby(level=0).size())

    valid_loader = DataLoader(df_valid["feature"], df_valid["label"], df_valid['market_value'], df_valid['stock_index'], pin_memory=True, start_index=start_index, device=device)

    slc = slice(pd.Timestamp(args.test_start_date), pd.Timestamp(args.test_end_date))
    df_test['market_value'] = df_market_value[slc]
    df_test['market_value'] = df_test['market_value'].fillna(df_train['market_value'].mean())
    df_test['stock_index'] = df_test.index.get_level_values('instrument').map(stock_index).fillna(733).astype(int)
    start_index += len(df_test.groupby(level=0).size())

    test_loader = DataLoader(df_test["feature"], df_test["label"], df_test['market_value'], df_test['stock_index'], pin_memory=True, start_index=start_index, device=device)

    return train_loader, valid_loader, test_loader
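
# Assumed on-disk input formats (inferred from the usage in create_loaders
# above, not guaranteed by this repo):
#
#   args.market_value_path  pickled pandas object indexed by
#                           (datetime, instrument); raw market caps,
#                           scaled down by 1e9 above
#   args.stock_index        a .npy-pickled dict mapping instrument code
#                           to an integer id, with 733 as the fallback id
#                           for instruments absent from the map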

def main(args):
    # NOTE: a fresh random seed is drawn every run; args.seed only appears
    # in the output suffix below.
    seed = np.random.randint(1000000)
    np.random.seed(seed)
    torch.manual_seed(seed)
    suffix = "%s_dh%s_dn%s_drop%s_lr%s_bs%s_seed%s%s" % (
        args.model_name, args.hidden_size, args.num_layers, args.dropout,
        args.lr, args.batch_size, args.seed, args.annot
    )

    output_path = args.outdir
    if not output_path:
        output_path = './output/' + suffix
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if not args.overwrite and os.path.exists(output_path + '/info.json'):
        print('already run, exiting.')
        return

    writer = SummaryWriter(log_dir=output_path)

    global global_log_file
    global_log_file = output_path + '/' + args.name + '_run.log'

    pprint('create loaders...')
    train_loader, valid_loader, test_loader = create_loaders(args)

    stock2concept_matrix = np.load(args.stock2concept_matrix)
    if args.model_name == 'HIST':
        stock2concept_matrix = torch.Tensor(stock2concept_matrix).to(device)

    # NOTE: these accumulators are currently never filled
    all_precision = []
    all_recall = []
    all_ic = []
    all_rank_ic = []
    for times in range(args.repeat):

        pprint('create model...')
        ######################### modification #########################
        # two memory banks of 96 slots x 64 dims (sizes are hard-coded)
        m_item0 = F.normalize(torch.rand((96, 64), dtype=torch.float), dim=1).to(device)
        m_item1 = F.normalize(torch.rand((96, 64), dtype=torch.float), dim=1).to(device)
        # m_item2 = F.normalize(torch.rand((96, 64), dtype=torch.float), dim=1).to(device)
        m_items = [m_item0, m_item1]
        # m_items = [m_item1]
        ######################### modification #########################
        if args.model_name == 'SFM':
            model = get_model(args.model_name)(d_feat=args.d_feat, output_dim=32, freq_dim=25, hidden_size=args.hidden_size, dropout_W=0.5, dropout_U=0.5, device=device)
        elif args.model_name == 'ALSTM':
            model = get_model(args.model_name)(args.d_feat, args.hidden_size, args.num_layers, args.dropout, 'LSTM')
        elif args.model_name == 'Transformer':
            model = get_model(args.model_name)(args.d_feat, args.hidden_size, args.num_layers, dropout=0.5)
        elif args.model_name == 'HIST':
            model = get_model(args.model_name)(d_feat=args.d_feat, num_layers=args.num_layers, K=args.K)
        else:
            model = get_model(args.model_name)(d_feat=args.d_feat, num_layers=args.num_layers)

        model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        best_score = -np.inf
        best_epoch = 0
        best_test_score = -np.inf
        best_test_epoch = 0
        stop_round = 0
        best_param = copy.deepcopy(model.state_dict())
        params_list = collections.deque(maxlen=args.smooth_steps)
        # m_items = torch.load("m_items.bin.r1.e44")
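        # `average_params` is called in the loop below and is presumably
        # defined earlier in this file; conceptually it returns the
        # element-wise mean of the last `smooth_steps` checkpoints (a minimal
        # sketch, assuming all state_dicts share the same keys):
        #
        #     def average_params(params_list):
        #         avg = copy.deepcopy(params_list[0])
        #         for k in avg:
        #             avg[k] = sum(p[k] for p in params_list) / len(params_list)
        #         return avg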
        for epoch in range(args.n_epochs):
            pprint('Running', times, 'Epoch:', epoch)

            pprint('training...')
            # carry the updated memory items into the next epoch
            m_items = train_epoch(epoch, model, optimizer, train_loader, writer, args, stock2concept_matrix, m_items)

            params_ckpt = copy.deepcopy(model.state_dict())
            params_list.append(params_ckpt)
            avg_params = average_params(params_list)
            model.load_state_dict(avg_params)

            pprint('evaluating...')
            train_loss, train_score, train_precision, train_recall, train_ic, train_rank_ic = test_epoch(times, epoch, model, train_loader, writer, args, stock2concept_matrix, m_items=m_items, prefix='Train', train=False)
            val_loss, val_score, val_precision, val_recall, val_ic, val_rank_ic = test_epoch(times, epoch, model, valid_loader, writer, args, stock2concept_matrix, m_items=m_items, prefix='Valid', train=False)
            torch.save(model, output_path + '/model.bin' + '.r' + str(times) + '.e' + str(epoch))
            torch.save(optimizer, output_path + '/optimizer.bin' + '.r' + str(times) + '.e' + str(epoch))
            test_loss, test_score, test_precision, test_recall, test_ic, test_rank_ic = test_epoch(times, epoch, model, test_loader, writer, args, stock2concept_matrix, m_items=m_items, prefix='Test', train=False)

            pprint('train_loss %.6f, valid_loss %.6f, test_loss %.6f' % (train_loss, val_loss, test_loss))
            pprint('train_score %.6f, valid_score %.6f, test_score %.6f' % (train_score, val_score, test_score))
            # pprint('train_mse %.6f, valid_mse %.6f, test_mse %.6f' % (train_mse, val_mse, test_mse))
            # pprint('train_mae %.6f, valid_mae %.6f, test_mae %.6f' % (train_mae, val_mae, test_mae))
            pprint('train_ic %.6f, valid_ic %.6f, test_ic %.6f' % (train_ic, val_ic, test_ic))
            pprint('train_rank_ic %.6f, valid_rank_ic %.6f, test_rank_ic %.6f' % (train_rank_ic, val_rank_ic, test_rank_ic))
            pprint('Train Precision: ', train_precision)
            pprint('Valid Precision: ', val_precision)
            pprint('Test Precision: ', test_precision)
            pprint('Train Recall: ', train_recall)
            pprint('Valid Recall: ', val_recall)
            pprint('Test Recall: ', test_recall)
            model.load_state_dict(params_ckpt)
            if test_score > best_test_score:
                best_test_score = test_score
                best_test_epoch = epoch
            if val_score > best_score:
                best_score = val_score
                stop_round = 0
                best_epoch = epoch
                best_param = copy.deepcopy(avg_params)
            else:
                stop_round += 1
                if stop_round >= args.early_stop:
                    pprint('early stop')
                    break
        pprint('best test score:', best_test_score, '@', best_test_epoch)
        pprint('best score:', best_score, '@', best_epoch)
        model.load_state_dict(best_param)
        torch.save(best_param, output_path + '/model.bin')


class ParseConfigFile(argparse.Action):

    def __call__(self, parser, namespace, filename, option_string=None):

        if not os.path.exists(filename):
            raise ValueError('cannot find config at `%s`' % filename)

        with open(filename) as f:
            config = json.load(f)
            for key, value in config.items():
                setattr(namespace, key, value)


def parse_args():

    parser = argparse.ArgumentParser()

    # model
    parser.add_argument('--model_name', default='HIST')
    parser.add_argument('--d_feat', type=int, default=6)
    parser.add_argument('--hidden_size', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--K', type=int, default=1)

    # training
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--early_stop', type=int, default=30)
    parser.add_argument('--smooth_steps', type=int, default=5)
    parser.add_argument('--metric', default='IC')
    parser.add_argument('--loss', default='mse')
    parser.add_argument('--repeat', type=int, default=2)

    # data
    parser.add_argument('--data_set', type=str, default='csi300')
    # store_false: passing --pin_memory on the command line disables pinning
    parser.add_argument('--pin_memory', action='store_false', default=True)
    parser.add_argument('--batch_size', type=int, default=-1)  # -1 indicates daily batching
    parser.add_argument('--least_samples_num', type=float, default=1137.0)
    parser.add_argument('--label', default='')  # specify other labels
    parser.add_argument('--train_start_date', default='2007-01-01')
    parser.add_argument('--train_end_date', default='2014-12-31')
    parser.add_argument('--valid_start_date', default='2015-01-01')
    parser.add_argument('--valid_end_date', default='2016-12-31')
    parser.add_argument('--test_start_date', default='2017-01-01')
    parser.add_argument('--test_end_date', default='2020-12-31')

    # other
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--annot', default='')
    parser.add_argument('--config', action=ParseConfigFile, default='')
    parser.add_argument('--name', type=str, default='csi300_HIST')

    # input for csi 300
    parser.add_argument('--market_value_path', default='./data/csi300_market_value_07to20.pkl')
    parser.add_argument('--stock2concept_matrix', default='./data/csi300_stock2concept.npy')
    parser.add_argument('--stock_index', default='./data/csi300_stock_index.npy')

    parser.add_argument('--outdir', default='./output/csi300_HIST')
    parser.add_argument('--overwrite', action='store_true', default=False)

    args = parser.parse_args()

    return args
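
# Example JSON config consumed by ParseConfigFile via `--config` (a sketch;
# keys must match the argument names defined in parse_args above):
#
#     {
#         "model_name": "HIST",
#         "data_set": "csi300",
#         "lr": 0.0002,
#         "n_epochs": 200,
#         "outdir": "./output/csi300_HIST"
#     }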

if __name__ == '__main__':

    args = parse_args()
    main(args)
--------------------------------------------------------------------------------