├── Fig ├── Atten-U-net.png ├── LSTM-TAU-net.png └── Temp-Atten-Unet.png ├── README.md ├── __pycache__ ├── convLSTM_network.cpython-36.pyc ├── dataset.cpython-36.pyc ├── evaluation.cpython-36.pyc ├── network.cpython-36.pyc └── solver.cpython-36.pyc ├── convLSTM_network.py ├── dataset.py ├── evaluation.py ├── main.py ├── model.py ├── models └── RCA_U_Net-100-0.0005-80-0.0000.pkl ├── network.py ├── solver.py └── train.py /Fig/Atten-U-net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/Fig/Atten-U-net.png -------------------------------------------------------------------------------- /Fig/LSTM-TAU-net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/Fig/LSTM-TAU-net.png -------------------------------------------------------------------------------- /Fig/Temp-Atten-Unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/Fig/Temp-Atten-Unet.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvLSTM-RAU-net 2 | Spatio-temporal prediction model based on historical observations and WRF numerical predictions. This is a promising but unfinished idea; time flies, so I hope someone else will pick it up and finish it.
This project originated at DeeCamp 2019; the relevant conclusions are summarized in https://blog.csdn.net/maliang_1993/article/details/99622197 3 | 4 | 5 | **U-Net: Convolutional Networks for Biomedical Image Segmentation** 6 | 7 | https://arxiv.org/abs/1505.04597 8 | 9 | 10 | **Attention U-Net: Learning Where to Look for the Pancreas** 11 | 12 | https://arxiv.org/abs/1804.03999 13 | 14 | 15 | **Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting** 16 | 17 | https://arxiv.org/abs/1506.04214 18 | 19 | 20 | ## Atten-U-Net 21 | ![Atten-U-Net](/Fig/Atten-U-net.png) 22 | 23 | ## Temp-Atten-Unet 24 | ![TAU-Net](/Fig/Temp-Atten-Unet.png) 25 | 26 | ## ConvLSTM-TAU-net 27 | ![convLSTM-Net](/Fig/LSTM-TAU-net.png) 28 | 29 | -------------------------------------------------------------------------------- /__pycache__/convLSTM_network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/__pycache__/convLSTM_network.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/network.cpython-36.pyc: --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/solver.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/__pycache__/solver.cpython-36.pyc -------------------------------------------------------------------------------- /convLSTM_network.py: --------------------------------------------------------------------------------
"""ConvLSTM building blocks and a ConvLSTM encoder/decoder network.

Contents
--------
ConvLSTMCell
    One recurrent step of a convolutional LSTM.
ConvLSTM
    A stack of ConvLSTMCell layers unrolled over the time axis.
convLSTM_model
    A two-level downsample/upsample network that interleaves frame-wise
    Conv2d + BatchNorm + ReLU blocks with single-layer ConvLSTMs.
"""
import datetime
import os
import random
import time

import numpy as np
import pandas as pd
import torch
from torch import nn


class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    All four gates (input, forget, output, candidate) are produced by one
    Conv2d applied to the channel-wise concatenation of the current input
    frame and the previous hidden state.
    """

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_size: (int, int)
            Default spatial size of the hidden state, stored as
            (self.height, self.width). It is used by ``init_hidden`` when no
            explicit size is passed.
        input_dim: int
            Number of channels of the input tensor.
        hidden_dim: int
            Number of channels of the hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel. Padding is ``kernel_size // 2``,
            so odd kernel sizes keep the spatial size unchanged.
        bias: bool
            Whether or not to add the bias.
        """
        super(ConvLSTMCell, self).__init__()

        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: Tensor of shape (b, input_dim, h, w)
        cur_state: pair (h_cur, c_cur), each of shape (b, hidden_dim, h, w)

        Returns
        -------
        (h_next, c_next), each of shape (b, hidden_dim, h, w)
        """
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size=None):
        """Return zero-initialized (h, c) states for ``batch_size`` samples.

        Parameters
        ----------
        batch_size: int
        image_size: (int, int) or None
            Spatial size of the states. Defaults to the ``input_size`` given
            at construction time.

        The states are allocated on the device and with the dtype of this
        cell's convolution weights, so they follow ``module.to(...)``. (They
        were previously placed on CUDA whenever CUDA was available, which
        broke CPU-resident models on GPU machines.)
        """
        if image_size is None:
            image_size = (self.height, self.width)
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class ConvLSTM(nn.Module):
    """A stack of ``num_layers`` ConvLSTMCell layers unrolled over time.

    Layer ``i`` consumes the full hidden-state sequence of layer ``i - 1``.
    ``kernel_size`` and ``hidden_dim`` may be given once (shared by every
    layer) or as lists of length ``num_layers``.
    """

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
                                          input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """Run every layer over the whole sequence.

        Parameters
        ----------
        input_tensor:
            5-D Tensor of shape (t, b, c, h, w), or (b, t, c, h, w) when
            ``batch_first`` is True.
        hidden_state:
            None to start every layer from zeros (sized from the input's
            spatial dimensions), or a per-layer sequence of (h, c) pairs.

        Returns
        -------
        layer_output_list, last_state_list
            ``layer_output_list[k]`` has shape (b, t, hidden_dim_k, h, w) (always
            batch-first). ``last_state_list[k]`` is the final [h, c] of that
            layer. Both contain only the last layer unless
            ``return_all_layers`` is True.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0),
                                             image_size=tuple(input_tensor.shape[-2:]))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)  # sequence length
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]  # (b, c, h, w)
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size=None):
        """Zero (h, c) states for every layer; see ConvLSTMCell.init_hidden."""
        return [cell.init_hidden(batch_size, image_size) for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless kernel_size is a tuple or a list of tuples."""
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all(isinstance(elem, tuple) for elem in kernel_size))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Repeat a non-list ``param`` once per layer; lists pass through unchanged."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param


def _merge_time(x):
    """Fold the time axis into the batch axis: (B, T, C, H, W) -> (B*T, C, H, W)."""
    return x.reshape(-1, *x.shape[2:])


def _split_time(x, batch, steps):
    """Inverse of ``_merge_time``: (B*T, C, H, W) -> (B, T, C, H, W)."""
    return x.reshape(batch, steps, *x.shape[1:])


class convLSTM_model(torch.nn.Module):
    """Two-level encoder/decoder that mixes frame-wise convolutions with ConvLSTMs.

    Encoder: conv -> ConvLSTM -> (maxpool, conv, ConvLSTM, conv) x 2.
    Decoder: (upsample, conv, ConvLSTM, conv) x 2, where the last conv maps to
    one output channel with no normalization or activation.

    NOTE(review): despite the U-Net-like layout, ``forward`` has no skip
    connections between encoder and decoder levels.

    Submodule attribute names are part of the saved ``state_dict`` format and
    must not be renamed.
    """

    def __init__(self, histgc_feature, width, height):
        """
        Parameters
        ----------
        histgc_feature: int
            Number of input channels per frame.
        width, height: int
            Spatial size of the input frames. They must be divisible by 4
            because there are two 2x max-pool/upsample stages.
        """
        super(convLSTM_model, self).__init__()
        self.histgc_feature = histgc_feature
        self.width = width
        self.height = height
        filter_1 = 16
        filter_2 = 32
        filter_3 = 64

        # NOTE: ConvLSTM input sizes are passed as [width, height] to match the
        # (..., W, H) axis order used by forward(); the names are only labels.
        self.conv1 = torch.nn.Conv2d(in_channels=histgc_feature, out_channels=filter_1,
                                     padding=1, kernel_size=(3, 3))
        self.BN_1 = torch.nn.BatchNorm2d(num_features=filter_1)
        self.convlstm1 = ConvLSTM(input_size=[width, height], input_dim=filter_1,
                                  hidden_dim=filter_1, num_layers=1, batch_first=True,
                                  return_all_layers=True, kernel_size=(3, 3))

        self.maxpool2 = torch.nn.MaxPool2d(2)
        self.conv2_1 = torch.nn.Conv2d(in_channels=filter_1, out_channels=filter_2,
                                       padding=1, kernel_size=(3, 3))
        self.BN_2_1 = torch.nn.BatchNorm2d(num_features=filter_2)
        self.convlstm2 = ConvLSTM(input_size=[width // 2, height // 2], input_dim=filter_2,
                                  hidden_dim=filter_2, num_layers=1, batch_first=True,
                                  return_all_layers=True, kernel_size=(3, 3))
        self.conv2_2 = torch.nn.Conv2d(in_channels=filter_2, out_channels=filter_2,
                                       padding=1, kernel_size=(3, 3))
        self.BN_2_2 = torch.nn.BatchNorm2d(num_features=filter_2)

        self.maxpool3 = torch.nn.MaxPool2d(2)
        self.conv3_1 = torch.nn.Conv2d(in_channels=filter_2, out_channels=filter_3,
                                       padding=1, kernel_size=(3, 3))
        self.BN_3_1 = torch.nn.BatchNorm2d(num_features=filter_3)
        self.convlstm3 = ConvLSTM(input_size=[width // 4, height // 4], input_dim=filter_3,
                                  hidden_dim=filter_3, num_layers=1, batch_first=True,
                                  return_all_layers=True, kernel_size=(3, 3))
        self.conv3_2 = torch.nn.Conv2d(in_channels=filter_3, out_channels=filter_3,
                                       padding=1, kernel_size=(3, 3))
        self.BN_3_2 = torch.nn.BatchNorm2d(num_features=filter_3)

        self.upsample_2 = torch.nn.Upsample(scale_factor=2)
        self.convu_2_1 = torch.nn.Conv2d(in_channels=filter_3, out_channels=filter_2,
                                         padding=1, kernel_size=(3, 3))
        self.BN_u_2_1 = torch.nn.BatchNorm2d(num_features=filter_2)
        self.convlstm_u_2 = ConvLSTM(input_size=[width // 2, height // 2], input_dim=filter_2,
                                     hidden_dim=filter_2, num_layers=1, batch_first=True,
                                     return_all_layers=True, kernel_size=(3, 3))
        self.convu_2_2 = torch.nn.Conv2d(in_channels=filter_2, out_channels=filter_1,
                                         padding=1, kernel_size=(3, 3))
        self.BN_u_2_2 = torch.nn.BatchNorm2d(num_features=filter_1)

        self.upsample_1 = torch.nn.Upsample(scale_factor=2)
        self.convu_1_1 = torch.nn.Conv2d(in_channels=filter_1, out_channels=filter_1,
                                         padding=1, kernel_size=(3, 3))
        self.BN_u_1_1 = torch.nn.BatchNorm2d(num_features=filter_1)
        self.convlstm_u_1 = ConvLSTM(input_size=[width, height], input_dim=filter_1,
                                     hidden_dim=filter_1, num_layers=1, batch_first=True,
                                     return_all_layers=True, kernel_size=(3, 3))

        self.convu_1_2 = torch.nn.Conv2d(in_channels=filter_1, out_channels=1,
                                         padding=1, kernel_size=(3, 3))

    def forward(self, histgc_inputs):
        '''
        histgc_inputs: Tensor of shape [batch, time, c, w, h].
        Returns a Tensor of shape [batch, time, 1, w, h].
        '''
        batch, steps = histgc_inputs.shape[:2]

        def conv_bn_relu(x, conv, bn):
            # Frame-wise Conv2d -> BatchNorm2d -> ReLU on a (B*T, C, H, W) tensor.
            return torch.relu(bn(conv(x)))

        def run_lstm(lstm, x):
            # ConvLSTM needs (B, T, C, H, W); the other layers work on (B*T, C, H, W).
            layer_outputs, _ = lstm(_split_time(x, batch, steps), None)
            return _merge_time(layer_outputs[0])

        x = conv_bn_relu(_merge_time(histgc_inputs), self.conv1, self.BN_1)
        x = run_lstm(self.convlstm1, x)

        # first downsampling
        x = conv_bn_relu(self.maxpool2(x), self.conv2_1, self.BN_2_1)
        x = run_lstm(self.convlstm2, x)
        x = conv_bn_relu(x, self.conv2_2, self.BN_2_2)

        # second downsampling
        x = conv_bn_relu(self.maxpool3(x), self.conv3_1, self.BN_3_1)
        x = run_lstm(self.convlstm3, x)
        x = conv_bn_relu(x, self.conv3_2, self.BN_3_2)

        # first upsampling
        x = conv_bn_relu(self.upsample_2(x), self.convu_2_1, self.BN_u_2_1)
        x = run_lstm(self.convlstm_u_2, x)
        x = conv_bn_relu(x, self.convu_2_2, self.BN_u_2_2)  # back to filter_1 channels

        # second upsampling
        x = conv_bn_relu(self.upsample_1(x), self.convu_1_1, self.BN_u_1_1)
        x = run_lstm(self.convlstm_u_1, x)

        # project to a single output channel (no BN / activation)
        x = self.convu_1_2(x)
        return _split_time(x, batch, steps)
# -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # created by Ma Liang 4 | # contact with liang.ma@nlpr.ia.ac.cn 5 | 6 | import numpy as np 7 | import datetime 8 | import os 9 | import pandas as pd 10 | import random 11 | import time 12 | import threading 13 | import multiprocessing 14 |
15 | def check_file(tt,datelist,hour): 16 | ''' 17 | chech file at the time of 'tt' and its continuous 24 hours and histort hours 18 | pretime 25 times include time.now() 19 | history time include 'hour' times 20 | return if file is ready at the time of 'tt' 21 | ''' 22 | ruitufile = '/data/output/ruitu_data/{}/{}.npy'.format(tt.strftime('%Y%m'),tt.strftime('%Y%m%d%H')) 23 | sign = os.path.exists(ruitufile) 24 | if sign: 25 | pass 26 | # shape0 = np.load(ruitufile).shape[0] 27 | # sign = sign and shape0==25 28 | # if not shape0==25: 29 | # print(ruitufile) 30 | # os.remove(ruitufile) 31 | else: 32 | return False 33 | pretimelist = [ tt+datetime.timedelta(seconds=3600*i) for i in range(25)] 34 | pretimelist = pretimelist+ [ tt-datetime.timedelta(seconds=3600*i) for i in range(hour)] 35 | for pretime in pretimelist: 36 | # gaughDir = '/data/output/guance_data/{}/{}.npy'.format(pretime) 37 | timestring = pretime.strftime("%Y%m%d%H%M") 38 | sign = (timestring in datelist ) and sign 39 | if sign==False : 40 | # print(timestring,os.path.exists(ruitufile),timestring in datelist) 41 | break 42 | return sign 43 | 44 | 45 | def file_dataset(hour ): 46 | '''write a data-ready file list''' 47 | print('creating the dataset with history ', hour, ' hours') 48 | file_dict = pd.read_csv('/data/output/all_guance_data_name_list/all_gc_filename_list.csv',index_col=0) 49 | datelist = [str(line).split('_')[1] for line in file_dict.values] 50 | file_dict.index = datelist 51 | start_time, end_time = datetime.datetime(2016,10,1,0),datetime.datetime(2019,4,1,0) 52 | pretimelist=[] 53 | pretime= start_time 54 | while pretime<=end_time: 55 | if check_file(pretime,datelist,hour): 56 | pretimelist.append(pretime) 57 | pretime += datetime.timedelta(seconds=3600*3) 58 | pretimelist = np.array(pretimelist) 59 | np.save('/data/code/ml/pretimelist_{}.npy'.format(hour),pretimelist) 60 | print('finishing creating dataset with history ', hour, ' hours') 61 | return None 62 | 63 | def my_test_dataset( 
batch, history_hour, season=None ): 64 | '''return list shape [number , batch]''' 65 | file_dict = pd.read_csv('/data/output/all_guance_data_name_list/2019_04_07_gc_filename_list.csv', index_col=0) 66 | datelist = [str(line).split('_')[1] for line in file_dict.values] 67 | file_dict.index = datelist 68 | target = '/data/code/ml/pretimelist_test_{}.npy'.format(history_hour) 69 | if not os.path.exists(target): 70 | file_test_dataset( history_hour ) 71 | pretimelist = np.load(target, allow_pickle=True) 72 | 73 | if season=='summer': 74 | tmp = [] 75 | for pretime in pretimelist: 76 | if pretime.month in [4,5,6,7,8,9]: 77 | tmp.append(pretime) 78 | pretimelist = np.array(tmp) 79 | print('dataset lenght',len(pretimelist)) 80 | pretimelist = pretimelist[:len(pretimelist)//batch*batch] 81 | pretimelist = np.array(pretimelist).reshape(-1, batch) 82 | return pretimelist, file_dict 83 | 84 | def file_test_dataset(hour ): 85 | '''write a data-ready file list''' 86 | print('creating the dataset with history ', hour, ' hours') 87 | file_dict = pd.read_csv('/data/output/all_guance_data_name_list/2019_04_07_gc_filename_list.csv',index_col=0) 88 | datelist = [str(line).split('_')[1] for line in file_dict.values] 89 | file_dict.index = datelist 90 | start_time, end_time = datetime.datetime(2019,4,1,0),datetime.datetime(2019,7,31,21) 91 | pretimelist=[] 92 | pretime= start_time 93 | while pretime<=end_time: 94 | if check_file(pretime,datelist,hour): 95 | pretimelist.append(pretime) 96 | pretime += datetime.timedelta(seconds=3600*3) 97 | pretimelist = np.array(pretimelist) 98 | np.save('/data/code/ml/pretimelist_test_{}.npy'.format(hour),pretimelist) 99 | print('finishing creating dataset with history ', hour, ' hours') 100 | return None 101 | 102 | 103 | def my_dataset( batch, history_hour, season=None ): 104 | '''return list shape [number , batch]''' 105 | file_dict = pd.read_csv('/data/output/all_guance_data_name_list/all_gc_filename_list.csv', index_col=0) 106 | datelist = 
[str(line).split('_')[1] for line in file_dict.values] 107 | file_dict.index = datelist 108 | target = '/data/code/ml/pretimelist_{}.npy'.format(history_hour) 109 | if not os.path.exists(target): 110 | file_dataset( history_hour ) 111 | pretimelist = np.load(target, allow_pickle=True) 112 | 113 | if season=='summer': 114 | tmp = [] 115 | for pretime in pretimelist: 116 | if pretime.month in [6,7,8,9]: 117 | tmp.append(pretime) 118 | pretimelist = np.array(tmp) 119 | print('dataset lenght',len(pretimelist)) 120 | pretimelist = pretimelist[:len(pretimelist)//batch*batch] 121 | random.shuffle(pretimelist) 122 | pretimelist = np.array(pretimelist).reshape(-1, batch) 123 | return pretimelist, file_dict 124 | 125 | 126 | 127 | def conbime_thread(batch_list, batch_time): 128 | ''' 129 | parallization the thread to read the data 130 | ''' 131 | # print("Sub-process(es) begin.") 132 | ruitulist, gaugelist, histgaugelist, jobresults = [], [], [], [] 133 | pool = multiprocessing.Pool(processes=12) # 创建4个进程 134 | for filelist, pretime in zip(batch_list, batch_time): 135 | jobresults.append(pool.apply_async(read_one, (filelist, pretime))) 136 | for res in jobresults: 137 | ruituFile, gaugeFile, histgaugeFile = res.get() 138 | ruitulist.append(ruituFile) 139 | gaugelist.append(gaugeFile) 140 | histgaugelist.append(histgaugeFile) 141 | pool.close() # 关闭进程池,表示不能在往进程池中添加进程 142 | pool.join() # 等待进程池中的所有进程执行完毕,必须在close()之后调用 143 | # print("Sub-process(es) done.") 144 | gaugelist, ruitulist, histgaugelist = np.array(gaugelist), np.array(ruitulist), np.array(histgaugelist) 145 | # print(gaugelist.shape, ruitulist.shape, histgaugelist.shape) 146 | return ruitulist, gaugelist, histgaugelist 147 | 148 | 149 | def read_one(filelist, pretime): 150 | '''read single data in training data with preprocessing ''' 151 | # tt = time.time() 152 | ruituFile = np.load(filelist[0])[:,:,:80,:84] 153 | # print('processing',pretime) 154 | gaugeFile = np.array([np.load(file) for file in 
filelist[1:25]])[:,4:5,:80,:84] 155 | histgaugeFile = np.array([np.load(file) for file in filelist[25:]])[:,:,:80,:84] 156 | ruituFile, gaugeFile, histgaugeFile = norm_preprocess(ruituFile, gaugeFile, histgaugeFile, pretime) 157 | # print(time.time()-tt) 158 | return ruituFile, gaugeFile, histgaugeFile 159 | 160 | 161 | def norm_preprocess(ruituFile, gaugeFile, histgaugeFile, pretime): 162 | ''' 163 | processing with abnormal values, time , geography values, normalized values. 164 | ''' 165 | # print(ruituFile.shape, gaugeFile.shape, histgaugeFile.shape) 166 | #remoev the abnormal value 167 | assert ruituFile.shape[0]==25,print(pretime,'without full prediction') 168 | if (np.abs(ruituFile) > 10000).any(): 169 | meantmp = ruituFile.mean(axis=(0,2,3)) 170 | for i in range(ruituFile.shape[1]): 171 | ruituFile[:,i,:,:][np.abs(ruituFile[:,i,:,:])>10000] = meantmp[i] 172 | 173 | histgaugeFile[np.isnan(histgaugeFile)]=200000 174 | if (np.abs(histgaugeFile) > 10000).any(): 175 | meantmp = histgaugeFile.mean(axis=(0,2,3)) 176 | for i in range(histgaugeFile.shape[1]): 177 | histgaugeFile[:,i,:,:][np.abs(histgaugeFile[:,i,:,:])>10000] = meantmp[i] 178 | #normal the value 179 | ruituInfo = pd.read_csv('/data/output/ruitu_info.csv') 180 | ruitu_mean, ruitu_std = np.ones_like(ruituFile),np.ones_like(ruituFile) 181 | for i in range(len(ruituInfo)): 182 | ruitu_mean[:,i,:,:] *= ruituInfo['mean'].iloc[i] 183 | ruitu_std[:,i,:,:] *= ruituInfo['std'].iloc[i] 184 | ruituFile = (ruituFile-ruitu_mean)/ruitu_std 185 | 186 | gaugeInfo = pd.read_csv('/data/output/gauge_info.csv') 187 | gauge_mean, gauge_std = np.ones_like(histgaugeFile),np.ones_like(histgaugeFile) 188 | for i in range(len(gaugeInfo)): 189 | gauge_mean[:,i,:,:] *= gaugeInfo['mean'].iloc[i] 190 | gauge_std[:,i,:,:] *= gaugeInfo['std'].iloc[i] 191 | histgaugeFile = (histgaugeFile-gauge_mean)/gauge_std 192 | 193 | #add time and geo info 194 | geoinfo = np.load('/data/output/height_norm.npy') 195 | hist_hour = 
histgaugeFile.shape[0] 196 | pretimelist = [pretime+datetime.timedelta(seconds=i*3600) for i in range(-hist_hour+1, 25)] 197 | yearvariancelist = [ np.sin(2*np.pi*(tt.toordinal()-730180)/365.25) for tt in pretimelist] 198 | dayvariancelist = [ np.sin(2*np.pi*(tt.hour-3)/24) for tt in pretimelist] 199 | ruituFile[1:25, 32:35, :, :] = ruituFile[1:25, 32:35, :, :] - ruituFile[0:24,32:35,:,:] 200 | ruituFile_new = ruituFile[1:].copy() 201 | histgaugeFile[:,7,:,:] = np.array([geoinfo]*histgaugeFile.shape[0]) 202 | histgaugeFile[:,10,:,:] = np.array([sli*yvar for sli, yvar in zip(np.ones([hist_hour,80,84]),yearvariancelist[:hist_hour])]) 203 | histgaugeFile[:,11,:,:] = np.array([sli*dvar for sli, dvar in zip(np.ones([hist_hour,80,84]),dayvariancelist[:hist_hour])]) 204 | tmpyear = np.expand_dims([sli*yvar for sli, yvar in zip(np.ones([24,80,84]),yearvariancelist[hist_hour:])], axis=1) 205 | tmpday = np.expand_dims([sli*dvar for sli, dvar in zip(np.ones([24,80,84]),dayvariancelist[hist_hour:])], axis=1) 206 | tmpgeo = np.expand_dims(np.array([geoinfo]*ruituFile_new.shape[0]),axis=1) 207 | ruituFile_new = np.concatenate((ruituFile_new, tmpyear, tmpday, tmpgeo),axis=1) 208 | # print(ruituFile_new.shape, gaugeFile.shape, histgaugeFile.shape) 209 | return ruituFile_new, gaugeFile, histgaugeFile 210 | 211 | 212 | def load_data2(pretimelist, file_dict, history_hour, binary=False): 213 | ''' 214 | load batch data in parallized way, more faster. 
215 | input args: load_data2(pretimelist, file_dict, history_hour, binary=False) 216 | return args: ruitudata, gaugedata, histgaugedata 217 | shape: [batch ,24, channels_1, height, width],[batch ,24 , 1, height, width],[batch , historyhour, channels_2, height, width] 218 | if binary is True, the gaugedata will return in shape [batch ,time, 2, height, width] 219 | ''' 220 | pretimelist = list(pretimelist) 221 | batchfile = [] 222 | for batch_time in pretimelist: 223 | ruituFile = ['/data/output/ruitu_data/{}/{}.npy'.format(batch_time.strftime('%Y%m'),batch_time.strftime('%Y%m%d%H'))] 224 | time24h = [ batch_time+datetime.timedelta(seconds=3600*i) for i in range(1,25)] 225 | gaugeFile = ['/data/output/guance_data/{}/{}'.format(tt.strftime('%Y%m'),file_dict.loc[tt.strftime('%Y%m%d%H%M')].values[0]) for tt in time24h] 226 | timehist = [ batch_time-datetime.timedelta(seconds=3600*i) for i in range(history_hour)] 227 | histgaugeFile = ['/data/output/guance_data/{}/{}'.format(tt.strftime('%Y%m'),file_dict.loc[tt.strftime('%Y%m%d%H%M')].values[0]) for tt in timehist] 228 | singlefile = ruituFile+gaugeFile+histgaugeFile 229 | batchfile.append(singlefile) 230 | 231 | ruitudata, gaugedata, histgaugedata = conbime_thread(batchfile, pretimelist) 232 | 233 | if binary: 234 | # gaugedata = (gaugedata>=0.1).astype('int') 235 | gaugebinary = np.concatenate((gaugedata>=0.1, gaugedata<0.1),axis=2).astype('int') 236 | 237 | gaugedata[ gaugedata<0.1]=0 238 | histgaugedata = np.concatenate((histgaugedata, np.zeros_like(histgaugedata)), axis=1) 239 | return np.array(ruitudata)[:,:,:,:80,:80], np.array(gaugebinary)[:,:,:,:80,:80], np.array(gaugedata[:,:,:,:80,:80]), np.array(histgaugedata[:,:,:,:80,:80]) 240 | 241 | 242 | 243 | # def load_data(pretimelist,file_dict): 244 | # '''pretimelist is a batch timelist at once 245 | # output shape = [batch, 24, channel, 80, 84],[batch, 24, channel, 80, 84] 246 | # ''' 247 | # print('old') 248 | # t1 = time.time() 249 | # pretimelist = 
list(pretimelist) 250 | # gaugedata = [] 251 | # ruitudata = [] 252 | # for batch_time in pretimelist: 253 | # ruitutmp = np.load('/data/output/ruitu_data/{}/{}.npy'.format(batch_time.strftime('%Y%m'),batch_time.strftime('%Y%m%d%H')))[:24,:,:80,:84] 254 | # time24h = [ batch_time+datetime.timedelta(seconds=3600) for i in range(24)] 255 | # guagetmp = np.array([np.load('/data/output/guance_data/{}/{}'.format(tt.strftime('%Y%m'),file_dict.loc[tt.strftime('%Y%m%d%H%M')].values[0])) for tt in time24h])[:,4:5,:80,:84] 256 | # gaugedata.append(guagetmp) 257 | # ruitudata.append(ruitutmp) 258 | # print('total:',time.time()-t1) 259 | # return np.array(gaugedata), np.array(ruitudata) 260 | 261 | 262 | if __name__=='__main__': 263 | batch = 8 264 | historyhour = 24 265 | batch_filelist, file_dict = my_dataset( batch, historyhour,season='summer') 266 | 267 | split_num=0.7 268 | train_num = int(len(batch_filelist)*split_num) 269 | mydataset = {'train':batch_filelist[:train_num], 'test': batch_filelist[train_num:]} 270 | 271 | for filelist in mydataset['train']: 272 | tt = time.time() 273 | ruitudata, gaugedata, histgaugedata = load_data2(filelist,file_dict,historyhour, binary=True) 274 | print(gaugedata.shape, ruitudata.shape, histgaugedata.shape, 'finished time cost:',time.time()-tt) 275 | # print(gaugedata.mean(axis=(0,1,3,4)),gaugedata.std(axis=(0,1,3,4))) 276 | # print(ruitudata.mean(axis=(0,1,3,4)),ruitudata.std(axis=(0,1,3,4))) 277 | # print(histgaugedata.mean(axis=(0,1,3,4)),histgaugedata.std(axis=(0,1,3,4))) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | # SR : Segmentation Result 8 | # GT : Ground Truth 9 | 10 | 11 | 12 | class SoftIoULoss(nn.Module): 13 | def __init__(self, 
n_classes=2): 14 | super(SoftIoULoss, self).__init__() 15 | self.n_classes = n_classes 16 | 17 | # @staticmethod 18 | # def to_one_hot(tensor, n_classes): 19 | # n, h, w = tensor.size() 20 | # one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) 21 | # return one_hot 22 | 23 | def forward(self, input, target): 24 | # logit => N x Classes x H x W 25 | # target => N x H x W 26 | 27 | N = len(input) 28 | 29 | pred = F.softmax(input, dim=1) 30 | # target_onehot = self.to_one_hot(target, self.n_classes) 31 | target_onehot = input 32 | # Numerator Product 33 | inter = pred * target_onehot 34 | # Sum over all pixels N x C x H x W => N x C 35 | inter = inter.view(N, self.n_classes, -1).sum(2) 36 | 37 | # Denominator 38 | union = pred + target_onehot - (pred * target_onehot) 39 | # Sum over all pixels N x C x H x W => N x C 40 | union = union.view(N, self.n_classes, -1).sum(2) 41 | 42 | loss = inter / (union + 1e-16) 43 | 44 | # Return average loss over classes and batch 45 | return -loss.mean() 46 | 47 | class FocalLoss(nn.Module): 48 | def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=None): 49 | super(FocalLoss, self).__init__() 50 | self.alpha = alpha 51 | self.gamma = gamma 52 | self.weight = weight 53 | self.ignore_index = ignore_index 54 | self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight) 55 | 56 | def forward(self, preds, labels): 57 | if self.ignore_index is not None: 58 | mask = labels != self.ignore 59 | labels = labels[mask] 60 | preds = preds[mask] 61 | 62 | logpt = -self.bce_fn(preds, labels) 63 | pt = torch.exp(logpt) 64 | loss = -((1 - pt) ** self.gamma) * self.alpha * logpt 65 | return loss 66 | 67 | 68 | def get_best_threshold(SR, GT): 69 | # DC : Dice Coefficient 70 | # print('first', SR.min(),SR.max()) 71 | max_dc = 0 72 | best_threhold = None 73 | thresholdlist = torch.linspace(0.,1.,21).cuda() 74 | for i, threshold in enumerate(thresholdlist): 75 | DC = get_DC(SR,GT,threshold) 76 | if max_dc < DC : 77 | 
max_dc = DC 78 | best_threhold = threshold 79 | # print( 'second',SR.min(),SR.max()) 80 | return best_threhold 81 | 82 | 83 | def get_accuracy(SR,GT,threshold=0.5): 84 | SR = SR > threshold 85 | GT = GT == torch.max(GT) 86 | corr = torch.sum(SR==GT) 87 | tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3) 88 | acc = float(corr)/float(tensor_size) 89 | 90 | return acc 91 | 92 | def get_sensitivity(SR,GT,threshold=0.5): 93 | # Sensitivity == Recall 94 | SR = SR > threshold 95 | GT = GT == torch.max(GT) 96 | 97 | # TP : True Positive 98 | # FN : False Negative 99 | TP = ((SR==1)+(GT==1))==2 100 | FN = ((SR==0)+(GT==1))==2 101 | 102 | SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6) 103 | 104 | return SE 105 | 106 | def get_specificity(SR,GT,threshold=0.5): 107 | SR = SR > threshold 108 | GT = GT == torch.max(GT) 109 | 110 | # TN : True Negative 111 | # FP : False Positive 112 | TN = ((SR==0)+(GT==0))==2 113 | FP = ((SR==1)+(GT==0))==2 114 | 115 | SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6) 116 | 117 | return SP 118 | 119 | def get_precision(SR,GT,threshold=0.5): 120 | SR = SR > threshold 121 | GT = GT == torch.max(GT) 122 | 123 | # TP : True Positive 124 | # FP : False Positive 125 | TP = ((SR==1)+(GT==1))==2 126 | FP = ((SR==1)+(GT==0))==2 127 | 128 | PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6) 129 | 130 | return PC 131 | 132 | def get_F1(SR,GT,threshold=0.5): 133 | # Sensitivity == Recall 134 | SE = get_sensitivity(SR,GT,threshold=threshold) 135 | PC = get_precision(SR,GT,threshold=threshold) 136 | 137 | F1 = 2*SE*PC/(SE+PC + 1e-6) 138 | 139 | return F1 140 | 141 | def get_JS(SR,GT,threshold=0.5): 142 | # JS : Jaccard similarity 143 | SR = SR > threshold 144 | GT = GT == torch.max(GT) 145 | 146 | Inter = torch.sum((SR+GT)==2) 147 | Union = torch.sum((SR+GT)>=1) 148 | 149 | JS = float(Inter)/(float(Union) + 1e-6) 150 | 151 | return JS 152 | 153 | def get_DC(SR,GT,threshold=0.5): 154 | # DC : Dice Coefficient 155 | SR = SR 
> threshold 156 | GT = GT == torch.max(GT) 157 | 158 | Inter = torch.sum((SR+GT)==2) 159 | DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6) 160 | 161 | return DC 162 | 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from solver import Solver 4 | from dataset import my_dataset, load_data2, my_test_dataset 5 | from torch.backends import cudnn 6 | import random 7 | 8 | def main(config): 9 | cudnn.benchmark = True 10 | 11 | # Create directories if not exist 12 | if not os.path.exists(config.model_path): 13 | os.makedirs(config.model_path) 14 | if not os.path.exists(config.result_path): 15 | os.makedirs(config.result_path) 16 | config.result_path = os.path.join(config.result_path,config.model_type) 17 | if not os.path.exists(config.result_path): 18 | os.makedirs(config.result_path) 19 | 20 | lr = random.random()*0.0005 + 0.0000005 21 | augmentation_prob= random.random()*0.7 22 | # lr = 0.0005 23 | # augmentation_prob = 0 24 | epoch = random.choice([100,150,200,250]) 25 | # epoch = 100 26 | decay_ratio = 0.8 27 | decay_epoch = int(epoch*decay_ratio) 28 | 29 | config.augmentation_prob = augmentation_prob 30 | config.num_epochs = epoch 31 | config.lr = lr 32 | config.num_epochs_decay = decay_epoch 33 | 34 | print(config) 35 | 36 | batch = config.batch_size 37 | historyhour = config.historyhour 38 | batch_filelist, file_dict = my_dataset( batch, historyhour,season='summer') 39 | 40 | batch_test, file_dict_test = my_test_dataset( batch, historyhour, season=False) 41 | split_num=0.9 42 | valid_num=1 43 | train_num = int(len(batch_filelist)*split_num) 44 | valid_num = int(len(batch_filelist)*valid_num) 45 | mydataset = {'train':batch_filelist[:train_num], 46 | 'valid':batch_filelist[train_num:valid_num], 47 | 'test': batch_test} 48 | 49 | solver = Solver(config, mydataset['train'], 
mydataset['valid'], mydataset['test']) 50 | 51 | 52 | # Train and sample the images 53 | if config.mode == 'train': 54 | solver.train() 55 | elif config.mode == 'test': 56 | solver.test() 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser() 61 | 62 | 63 | # model hyper-parameters 64 | parser.add_argument('--image_size', type=int, default=224) 65 | parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net') 66 | 67 | # training hyper-parameters 68 | parser.add_argument('--img_ch', type=int, default=47) 69 | parser.add_argument('--output_ch', type=int, default=2) 70 | parser.add_argument('--num_epochs', type=int, default=250) 71 | parser.add_argument('--num_epochs_decay', type=int, default=70) 72 | parser.add_argument('--batch_size', type=int, default=8) 73 | parser.add_argument('--num_workers', type=int, default=8) 74 | parser.add_argument('--lr', type=float, default=0.002) 75 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam 76 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam 77 | parser.add_argument('--augmentation_prob', type=float, default=0.1) 78 | 79 | parser.add_argument('--log_step', type=int, default=2) 80 | parser.add_argument('--val_step', type=int, default=2) 81 | parser.add_argument('--historyhour', type=int, default=24) 82 | parser.add_argument('--test_only', type=bool, default=False) 83 | # misc 84 | parser.add_argument('--mode', type=str, default='train') 85 | parser.add_argument('--model_type', type=str, default='RCA_U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net') 86 | parser.add_argument('--model_path', type=str, default='./models') 87 | # parser.add_argument('--train_path', type=str, default='./dataset/train/') 88 | # parser.add_argument('--valid_path', type=str, default='./dataset/valid/') 89 | # parser.add_argument('--test_path', type=str, default='./dataset/test/') 90 | parser.add_argument('--result_path', 
type=str, default='./result/') 91 | 92 | parser.add_argument('--cuda_idx', type=int, default=1) 93 | 94 | config = parser.parse_args() 95 | main(config) 96 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | import time 5 | import torch 6 | from torch import nn,device 7 | from torch.autograd import Variable 8 | import pandas as pd 9 | 10 | import datetime 11 | import pandas as pd 12 | import random 13 | import threading 14 | import multiprocessing 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | 19 | class ConvLSTMCell(nn.Module): 20 | 21 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 22 | """ 23 | Initialize ConvLSTM cell. 24 | 25 | Parameters 26 | ---------- 27 | input_size: (int, int) 28 | Height and width of input tensor as (height, width). 29 | input_dim: int 30 | Number of channels of input tensor. 31 | hidden_dim: int 32 | Number of channels of hidden state. 33 | kernel_size: (int, int) 34 | Size of the convolutional kernel. 35 | bias: bool 36 | Whether or not to add the bias. 
37 | """ 38 | 39 | super(ConvLSTMCell, self).__init__() 40 | 41 | self.height, self.width = input_size 42 | self.input_dim = input_dim 43 | self.hidden_dim = hidden_dim 44 | 45 | self.kernel_size = kernel_size 46 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 47 | self.bias = bias 48 | 49 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 50 | out_channels=4 * self.hidden_dim, 51 | kernel_size=self.kernel_size, 52 | padding=self.padding, 53 | bias=self.bias) 54 | 55 | def forward(self, input_tensor, cur_state): 56 | 57 | h_cur, c_cur = cur_state 58 | 59 | # print(input_tensor.shape) 60 | # print(cur_state.shape) 61 | 62 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 63 | 64 | combined_conv = self.conv(combined) 65 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 66 | 67 | # i = torch.sigmoid(cc_i) 68 | # f = torch.sigmoid(cc_f) 69 | # o = torch.sigmoid(cc_o) 70 | # g = torch.relu(cc_g) 71 | 72 | # c_next = f * c_cur + i * g 73 | # h_next = o * torch.relu(c_next) 74 | 75 | i = torch.sigmoid(cc_i) 76 | f = torch.sigmoid(cc_f) 77 | o = torch.sigmoid(cc_o) 78 | g = torch.tanh(cc_g) 79 | 80 | c_next = f * c_cur + i * g 81 | h_next = o * torch.tanh(c_next) 82 | 83 | return h_next, c_next 84 | 85 | def init_hidden(self, batch_size): 86 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 87 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).to(device), 88 | Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).to(device)) 89 | 90 | 91 | class ConvLSTM(nn.Module): 92 | 93 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 94 | batch_first=False, bias=True, return_all_layers=False): 95 | super(ConvLSTM, self).__init__() 96 | 97 | self._check_kernel_size_consistency(kernel_size) 98 | 99 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == 
num_layers 100 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 101 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 102 | if not len(kernel_size) == len(hidden_dim) == num_layers: 103 | raise ValueError('Inconsistent list length.') 104 | 105 | self.height, self.width = input_size 106 | 107 | self.input_dim = input_dim 108 | self.hidden_dim = hidden_dim 109 | self.kernel_size = kernel_size 110 | self.num_layers = num_layers 111 | self.batch_first = batch_first 112 | self.bias = bias 113 | self.return_all_layers = return_all_layers 114 | 115 | cell_list = [] 116 | for i in range(0, self.num_layers): 117 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1] 118 | 119 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width), 120 | input_dim=cur_input_dim, 121 | hidden_dim=self.hidden_dim[i], 122 | kernel_size=self.kernel_size[i], 123 | bias=self.bias)) 124 | 125 | self.cell_list = nn.ModuleList(cell_list) 126 | 127 | def forward(self, input_tensor, hidden_state=None): 128 | """ 129 | 130 | Parameters 131 | ---------- 132 | input_tensor: todo 133 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 134 | hidden_state: todo 135 | None. 
todo implement stateful (num_layers, 2, batch, filter, h, w) 136 | 137 | Returns 138 | ------- 139 | last_state_list, layer_output 140 | """ 141 | if not self.batch_first: 142 | # (t, b, c, h, w) -> (b, t, c, h, w) 143 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 144 | print('hidden_state_shape before:',type( hidden_state)) 145 | # Implement stateful ConvLSTM 146 | if hidden_state is None: 147 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0)) 148 | print('hidden_state_shape after:',len(hidden_state), hidden_state[0].shape) 149 | layer_output_list = [] 150 | last_state_list = [] 151 | 152 | seq_len = input_tensor.size(1) # 读取sequence length 153 | cur_layer_input = input_tensor 154 | 155 | for layer_idx in range(self.num_layers): 156 | 157 | h, c = hidden_state[layer_idx] # (b,c,h,w) 158 | output_inner = [] 159 | for t in range(seq_len): 160 | 161 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 162 | cur_state=[h, c]) 163 | output_inner.append(h) 164 | 165 | layer_output = torch.stack(output_inner, dim=1) 166 | cur_layer_input = layer_output 167 | 168 | layer_output_list.append(layer_output) 169 | last_state_list.append([h, c]) 170 | 171 | if not self.return_all_layers: 172 | layer_output_list = layer_output_list[-1:] 173 | last_state_list = last_state_list[-1:] 174 | 175 | return layer_output_list, last_state_list 176 | 177 | def _init_hidden(self, batch_size): 178 | init_states = [] 179 | for i in range(self.num_layers): 180 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 181 | return init_states 182 | 183 | @staticmethod 184 | def _check_kernel_size_consistency(kernel_size): 185 | if not (isinstance(kernel_size, tuple) or (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 186 | raise ValueError('`kernel_size` must be tuple or list of tuples') 187 | 188 | @staticmethod 189 | def _extend_for_multilayer(param, num_layers): 190 | if not isinstance(param, 
list): 191 | param = [param] * num_layers 192 | return param 193 | 194 | 195 | class encoder_model(torch.nn.Module): 196 | 197 | def __init__(self, histgc_feature, width,height): 198 | 199 | super(encoder_model,self).__init__() 200 | self.histgc_feature=histgc_feature 201 | self.width=width 202 | self.height=height 203 | filter_1=16 204 | filter_2=32 205 | filter_3=64 206 | filter_4=128 207 | filter_5=128 208 | self.conv1=torch.nn.Conv2d(in_channels=histgc_feature,out_channels=filter_1, 209 | padding=1,kernel_size=(3,3)) 210 | self.BN_1=torch.nn.BatchNorm2d(num_features=filter_1) 211 | self.convlstm1=ConvLSTM(input_size=[width,height],input_dim=filter_1, 212 | hidden_dim=filter_1, num_layers=1,batch_first=True, 213 | return_all_layers=True,kernel_size=(3,3)) 214 | 215 | self.maxpool2=torch.nn.MaxPool2d(2) 216 | self.conv2_1=torch.nn.Conv2d(in_channels=filter_1,out_channels=filter_2, 217 | padding=1,kernel_size=(3,3)) 218 | self.BN_2_1=torch.nn.BatchNorm2d(num_features=filter_2) 219 | self.convlstm2=ConvLSTM(input_size=[width//2,height//2],input_dim=filter_2, 220 | hidden_dim=filter_2, num_layers=1,batch_first=True, 221 | return_all_layers=True,kernel_size=(3,3)) 222 | self.conv2_2=torch.nn.Conv2d(in_channels=filter_2,out_channels=filter_2, 223 | padding=1,kernel_size=(3,3)) 224 | self.BN_2_2=torch.nn.BatchNorm2d(num_features=filter_2) 225 | 226 | 227 | self.maxpool3=torch.nn.MaxPool2d(2) 228 | self.conv3_1=torch.nn.Conv2d(in_channels=filter_2,out_channels=filter_3, 229 | padding=1,kernel_size=(3,3)) 230 | self.BN_3_1=torch.nn.BatchNorm2d(num_features=filter_3) 231 | self.convlstm3=ConvLSTM(input_size=[width//4,height//4],input_dim=filter_3, 232 | hidden_dim=filter_3, num_layers=1,batch_first=True, 233 | return_all_layers=True,kernel_size=(3,3)) 234 | self.conv3_2=torch.nn.Conv2d(in_channels=filter_3,out_channels=filter_3, 235 | padding=1,kernel_size=(3,3)) 236 | self.BN_3_2=torch.nn.BatchNorm2d(num_features=filter_3) 237 | 238 | 239 | 
self.maxpool4=torch.nn.MaxPool2d(2) 240 | self.conv4_1=torch.nn.Conv2d(in_channels=filter_3,out_channels=filter_4, 241 | padding=1,kernel_size=(3,3)) 242 | self.BN_4_1=torch.nn.BatchNorm2d(num_features=filter_4) 243 | self.convlstm4=ConvLSTM(input_size=[width//8,height//8],input_dim=filter_4, 244 | hidden_dim=filter_4, num_layers=1,batch_first=True, 245 | return_all_layers=True,kernel_size=(3,3)) 246 | self.conv4_2=torch.nn.Conv2d(in_channels=filter_4,out_channels=filter_4, 247 | padding=1,kernel_size=(3,3)) 248 | self.BN_4_2=torch.nn.BatchNorm2d(num_features=filter_4) 249 | 250 | 251 | self.maxpool5=torch.nn.MaxPool2d(2) 252 | self.conv5_1=torch.nn.Conv2d(in_channels=filter_4,out_channels=filter_5, 253 | padding=1,kernel_size=(3,3)) 254 | self.BN_5_1=torch.nn.BatchNorm2d(num_features=filter_5) 255 | self.convlstm5=ConvLSTM(input_size=[width//16,height//16],input_dim=filter_5, 256 | hidden_dim=filter_5, num_layers=1,batch_first=True, 257 | return_all_layers=True,kernel_size=(3,3)) 258 | 259 | 260 | def forward(self, histgc_inputs): 261 | 262 | B,T,C,W,H=histgc_inputs.size() 263 | 264 | x1=self.conv1(torch.reshape(histgc_inputs, (-1, C, W, H))) 265 | x1=self.BN_1(x1) 266 | x1=torch.relu(x1) 267 | x1 = torch.reshape(x1, (B, T, x1.size(1), x1.size(2), x1.size(3))) 268 | 269 | x1, state1 = self.convlstm1(x1,None) 270 | x1=x1[0] 271 | 272 | #第1次下采样 273 | B,T,C,W,H=x1.size() 274 | x2=self.maxpool2(torch.reshape(x1,(-1,C,W,H))) 275 | x2=self.conv2_1(x2) 276 | x2=self.BN_2_1(x2) 277 | x2=torch.relu(x2) 278 | x2 = torch.reshape(x2, (B, T, x2.size(1), x2.size(2), x2.size(3))) 279 | 280 | x2, state2=self.convlstm2(x2,None) 281 | x2=x2[0] 282 | 283 | B,T,C,W,H=x2.size() 284 | x2=self.conv2_2(torch.reshape(x2,(-1,C,W,H))) 285 | x2=self.BN_2_2(x2) 286 | x2=torch.relu(x2) 287 | x2 = torch.reshape(x2, (B, T, x2.size(1), x2.size(2), x2.size(3))) 288 | 289 | #第2次下采样 290 | B,T,C,W,H=x2.size() 291 | x3=self.maxpool3(torch.reshape(x2,(-1,C,W,H))) 292 | x3=self.conv3_1(x3) 293 | 
x3=self.BN_3_1(x3) 294 | x3=torch.relu(x3) 295 | x3 = torch.reshape(x3, (B, T, x3.size(1), x3.size(2), x3.size(3))) 296 | 297 | x3, state3=self.convlstm3(x3,None) 298 | x3=x3[0] 299 | 300 | B,T,C,W,H=x3.size() 301 | x3=self.conv3_2(torch.reshape(x3,(-1,C,W,H))) 302 | x3=self.BN_3_2(x3) 303 | x3=torch.relu(x3) 304 | x3 = torch.reshape(x3, (B, T, x3.size(1), x3.size(2), x3.size(3))) 305 | 306 | #第3次下采样 307 | B,T,C,W,H=x3.size() 308 | x4=self.maxpool4(torch.reshape(x3,(-1,C,W,H))) 309 | x4=self.conv4_1(x4) 310 | x4=self.BN_4_1(x4) 311 | x4=torch.relu(x4) 312 | x4 = torch.reshape(x4, (B, T, x4.size(1), x4.size(2), x4.size(3))) 313 | 314 | x4, state4=self.convlstm4(x4,None) 315 | x4=x4[0] 316 | 317 | B,T,C,W,H=x4.size() 318 | x4=self.conv4_2(torch.reshape(x4,(-1,C,W,H))) 319 | x4=self.BN_4_2(x4) 320 | x4=torch.relu(x4) 321 | x4 = torch.reshape(x4, (B, T, x4.size(1), x4.size(2), x4.size(3))) 322 | 323 | #第4次下采样 324 | B,T,C,W,H=x4.size() 325 | x5=self.maxpool5(torch.reshape(x4,(-1,C,W,H))) 326 | x5=self.conv5_1(x5) 327 | x5=self.BN_5_1(x5) 328 | x5=torch.relu(x5) 329 | x5 = torch.reshape(x5, (B, T, x5.size(1), x5.size(2), x5.size(3))) 330 | 331 | x5, state5=self.convlstm5(x5,None) 332 | x5=x5[0] 333 | 334 | state=state1+state2+state3+state4+state5 335 | 336 | return state 337 | 338 | class forecaster_model(torch.nn.Module): 339 | 340 | 341 | def __init__(self, fut_guance_feature ,width,height): 342 | 343 | super(forecaster_model,self).__init__() 344 | 345 | self.fut_guance_feature=fut_guance_feature 346 | self.width=width 347 | self.height=height 348 | filter_1=16 349 | filter_2=32 350 | filter_3=64 351 | filter_4=128 352 | filter_5=128 353 | self.convd_1=torch.nn.Conv2d(in_channels=fut_guance_feature, out_channels=filter_1, 354 | padding=1,kernel_size=(3,3)) 355 | self.BN_d_1=torch.nn.BatchNorm2d(num_features=filter_1) 356 | 357 | self.maxpoold_1=torch.nn.MaxPool2d(2) 358 | self.convd_2=torch.nn.Conv2d(in_channels=filter_1, out_channels=filter_2, 359 | 
padding=1,kernel_size=(3,3)) 360 | self.BN_d_2=torch.nn.BatchNorm2d(num_features=filter_2) 361 | 362 | self.maxpoold_2=torch.nn.MaxPool2d(2) 363 | self.convd_3=torch.nn.Conv2d(in_channels=filter_2, out_channels=filter_3, 364 | padding=1,kernel_size=(3,3)) 365 | self.BN_d_3=torch.nn.BatchNorm2d(num_features=filter_3) 366 | 367 | self.maxpoold_3=torch.nn.MaxPool2d(2) 368 | self.convd_4=torch.nn.Conv2d(in_channels=filter_3, out_channels=filter_4, 369 | padding=1,kernel_size=(3,3)) 370 | self.BN_d_4=torch.nn.BatchNorm2d(num_features=filter_4) 371 | 372 | self.maxpoold_4=torch.nn.MaxPool2d(2) 373 | self.convd_5=torch.nn.Conv2d(in_channels=filter_4, out_channels=filter_5, 374 | padding=1,kernel_size=(3,3)) 375 | self.BN_d_5=torch.nn.BatchNorm2d(num_features=filter_5) 376 | 377 | 378 | self.convlstm_u_5=ConvLSTM(input_size=[width//16,height//16],input_dim=filter_5, 379 | hidden_dim=filter_5, num_layers=1,batch_first=True, 380 | return_all_layers=True, kernel_size=(3,3)) 381 | 382 | self.convu_5_1=torch.nn.Conv2d(in_channels=filter_5,out_channels=filter_4, 383 | padding=1, kernel_size=(3,3)) 384 | self.BN_u_5_1=torch.nn.BatchNorm2d(num_features=filter_4) 385 | 386 | self.upsample_4=torch.nn.Upsample(scale_factor=2) 387 | self.convu_4_1=torch.nn.Conv2d(in_channels=filter_4,out_channels=filter_4, 388 | padding=1, kernel_size=(3,3)) 389 | self.BN_u_4_1=torch.nn.BatchNorm2d(num_features=filter_4) 390 | self.convlstm_u_4=ConvLSTM(input_size=[width//8,height//8],input_dim=filter_4, 391 | hidden_dim=filter_4, num_layers=1,batch_first=True, 392 | return_all_layers=True, kernel_size=(3,3)) 393 | self.convu_4_2=torch.nn.Conv2d(in_channels=filter_4,out_channels=filter_3, 394 | padding=1, kernel_size=(3,3)) 395 | self.BN_u_4_2=torch.nn.BatchNorm2d(num_features=filter_3) 396 | 397 | self.upsample_3=torch.nn.Upsample(scale_factor=2) 398 | self.convu_3_1=torch.nn.Conv2d(in_channels=filter_3,out_channels=filter_3, 399 | padding=1, kernel_size=(3,3)) 400 | 
self.BN_u_3_1=torch.nn.BatchNorm2d(num_features=filter_3) 401 | self.convlstm_u_3=ConvLSTM(input_size=[width//4,height//4],input_dim=filter_3, 402 | hidden_dim=filter_3, num_layers=1,batch_first=True, 403 | return_all_layers=True, kernel_size=(3,3)) 404 | self.convu_3_2=torch.nn.Conv2d(in_channels=filter_3,out_channels=filter_2, 405 | padding=1, kernel_size=(3,3)) 406 | self.BN_u_3_2=torch.nn.BatchNorm2d(num_features=filter_2) 407 | 408 | 409 | self.upsample_2=torch.nn.Upsample(scale_factor=2) 410 | self.convu_2_1=torch.nn.Conv2d(in_channels=filter_2,out_channels=filter_2, 411 | padding=1, kernel_size=(3,3)) 412 | self.BN_u_2_1=torch.nn.BatchNorm2d(num_features=filter_2) 413 | self.convlstm_u_2=ConvLSTM(input_size=[width//2,height//2],input_dim=filter_2, 414 | hidden_dim=filter_2, num_layers=1,batch_first=True, 415 | return_all_layers=True, kernel_size=(3,3)) 416 | self.convu_2_2=torch.nn.Conv2d(in_channels=filter_2,out_channels=filter_1, 417 | padding=1, kernel_size=(3,3)) 418 | self.BN_u_2_2=torch.nn.BatchNorm2d(num_features=filter_1) 419 | 420 | self.upsample_1=torch.nn.Upsample(scale_factor=2) 421 | self.convu_1_1=torch.nn.Conv2d(in_channels=filter_1,out_channels=filter_1, 422 | padding=1, kernel_size=(3,3)) 423 | self.BN_u_1_1=torch.nn.BatchNorm2d(num_features=filter_1) 424 | self.convlstm_u_1=ConvLSTM(input_size=[width,height],input_dim=filter_1, 425 | hidden_dim=filter_1, num_layers=1,batch_first=True, 426 | return_all_layers=True, kernel_size=(3,3)) 427 | 428 | self.convu_1_2=torch.nn.Conv2d(in_channels=filter_1,out_channels=1, 429 | padding=1, kernel_size=(3,3)) 430 | 431 | 432 | def forward(self, guance_data,encoder_states): 433 | 434 | state_1=encoder_states[0] 435 | state_2=encoder_states[1] 436 | state_3=encoder_states[2] 437 | state_4=encoder_states[3] 438 | state_5=encoder_states[4] 439 | 440 | #变为filter1 441 | B,T,C,W,H=guance_data.size() 442 | print('guance data', guance_data.shape) 443 | x_gc_1=self.convd_1(torch.reshape(guance_data,(-1,C,W,H))) 
444 | x_gc_1=self.BN_d_1(x_gc_1) 445 | x_gc_1=torch.relu(x_gc_1) 446 | 447 | #第1次下采样 448 | x_gc_2=self.maxpoold_1(x_gc_1) 449 | #变为filter2 450 | x_gc_2=self.convd_2(x_gc_2) 451 | x_gc_2=self.BN_d_2(x_gc_2) 452 | x_gc_2=torch.relu(x_gc_2) 453 | 454 | #第2次下采样 455 | x_gc_3=self.maxpoold_2(x_gc_2) 456 | #变为filter3 457 | x_gc_3=self.convd_3(x_gc_3) 458 | x_gc_3=self.BN_d_3(x_gc_3) 459 | x_gc_3=torch.relu(x_gc_3) 460 | 461 | #第3次下采样 462 | x_gc_4=self.maxpoold_3(x_gc_3) 463 | #变为filter4 464 | x_gc_4=self.convd_4(x_gc_4) 465 | x_gc_4=self.BN_d_4(x_gc_4) 466 | x_gc_4=torch.relu(x_gc_4) 467 | 468 | #第4次下采样 469 | x_gc_5=self.maxpoold_4(x_gc_4) 470 | #变为filter5 471 | x_gc_5=self.convd_5(x_gc_5) 472 | x_gc_5=self.BN_d_5(x_gc_5) 473 | x_gc_5=torch.relu(x_gc_5) 474 | 475 | x_gc_1 = torch.reshape(x_gc_1, (B,T,-1, W, H)) #filter1 476 | x_gc_2 = torch.reshape(x_gc_2, (B,T,-1, W//2, H//2)) #filter2 477 | x_gc_3 = torch.reshape(x_gc_3, (B,T,-1, W//4, H//4)) #filter3 478 | x_gc_4 = torch.reshape(x_gc_4, (B,T,-1, W//8, H//8)) #filter4 479 | x_gc_5 = torch.reshape(x_gc_5, (B,T,-1, W//16, H//16)) #filter5 480 | 481 | #第1次convlstm 482 | # print(x_gc_5.size()) 483 | # print(state_5[0].size()) 484 | 485 | x5, state_new_5 = self.convlstm_u_5(x_gc_5,[state_5]) 486 | x5=x5[0] 487 | x5=x5+x_gc_5 488 | 489 | #变为filter4 490 | B,T,C,W,H=x5.size() 491 | x5 = self.convu_5_1(torch.reshape(x5, (-1, C,W,H))) 492 | x5=self.BN_u_5_1(x5) 493 | x5=torch.relu(x5) 494 | 495 | #第1次上采样 496 | x4=self.upsample_4(x5) 497 | x4=self.convu_4_1(x4) 498 | x4=self.BN_u_4_1(x4) 499 | x4=torch.relu(x4) 500 | x4=torch.reshape(x4,(B,T, x4.size(1),x4.size(2),x4.size(3))) 501 | 502 | x4, state_new_4 = self.convlstm_u_4(x4,[state_4]) 503 | x4=x4[0] 504 | 505 | x4=x4+x_gc_4 506 | 507 | #变为filter3 508 | B,T,C,W,H=x4.size() 509 | x4 = self.convu_4_2(torch.reshape(x4, (-1, C,W,H))) 510 | x4=self.BN_u_4_2(x4) 511 | x4=torch.relu(x4) 512 | 513 | #第2次上采样 514 | x3=self.upsample_3(x4) 515 | x3=self.convu_3_1(x3) 516 | 
x3=self.BN_u_3_1(x3) 517 | x3=torch.relu(x3) 518 | x3=torch.reshape(x3,(B,T, x3.size(1),x3.size(2),x3.size(3))) 519 | 520 | x3, state_new_3 =self.convlstm_u_3(x3,[state_3]) 521 | x3=x3[0] 522 | 523 | x3=x3+x_gc_3 524 | 525 | #变为filter2 526 | B,T,C,W,H=x3.size() 527 | x3 = self.convu_3_2(torch.reshape(x3, (-1, C,W,H))) 528 | x3=self.BN_u_3_2(x3) 529 | x3=torch.relu(x3) 530 | 531 | #第3次上采样 532 | x2=self.upsample_2(x3) 533 | x2=self.convu_2_1(x2) 534 | x2=self.BN_u_2_1(x2) 535 | x2=torch.relu(x2) 536 | x2=torch.reshape(x2,(B,T, x2.size(1),x2.size(2),x2.size(3))) 537 | 538 | 539 | print(x2.size()) 540 | print(state_2[0].size()) 541 | x2, state_new_2 =self.convlstm_u_2(x2,[state_2]) 542 | x2=x2[0] 543 | 544 | x2=x2+x_gc_2 545 | 546 | #变为filter1 547 | B,T,C,W,H=x2.size() 548 | x2 = self.convu_2_2(torch.reshape(x2, (-1, C,W,H))) 549 | x2=self.BN_u_2_2(x2) 550 | x2=torch.relu(x2) 551 | 552 | #第4次上采样 553 | x1=self.upsample_1(x2) 554 | x1=self.convu_1_1(x1) 555 | x1=self.BN_u_1_1(x1) 556 | x1=torch.relu(x1) 557 | x1=torch.reshape(x1,(B,T, x1.size(1),x1.size(2),x1.size(3))) 558 | 559 | x1, state_new_1 =self.convlstm_u_1(x1,[state_1]) 560 | x1=x1[0] 561 | 562 | x1=x1+x_gc_1 563 | 564 | #变为通道1 565 | B,T,C,W,H=x1.size() 566 | x1 = self.convu_1_2(torch.reshape(x1, (-1, C,W,H))) 567 | x1=torch.reshape(x1,(B,T, x1.size(1),x1.size(2),x1.size(3))) 568 | 569 | new_state=state_new_1+state_new_2+state_new_3+state_new_4+state_new_5 570 | 571 | return [x1,new_state] 572 | 573 | 574 | class combined_net(torch.nn.Module): 575 | 576 | def __init__(self,histgc_feature,fut_guance_feature,width,height): 577 | 578 | super(combined_net,self).__init__() 579 | 580 | self.histgc_feature=histgc_feature 581 | self.fut_guance_feature=fut_guance_feature 582 | self.width=width 583 | self.height=height 584 | self.encoder=encoder_model(histgc_feature,width,height) 585 | self.forecaster=forecaster_model(fut_guance_feature,width,height) 586 | 587 | 588 | # def forward(self,hist_gc_inputs,fut_gc_inputs): 589 
| 590 | # encoder_states=self.encoder(hist_gc_inputs) 591 | # inputdata = fut_gc_inputs[:,0:1,:,:] 592 | # y = torch.zeros(fut_gc_inputs.size(), device = fut_gc_inputs.device) 593 | 594 | # state = [] 595 | # for i in range(fut_gc_inputs.size(1)): 596 | # print('time',i) 597 | # x = self.forecaster(inputdata, encoder_states) 598 | 599 | # # if i>3: 600 | # # inputdata=x[0] 601 | # # else: 602 | # inputdata=fut_gc_inputs[:,i:i+1,:,:] 603 | # encoder_states=x[1] 604 | # y[:,i:i+1,:,:]=x[0] 605 | # state.append(x[1]) 606 | 607 | # return [y,state] 608 | def forward(self,hist_gc_inputs,fut_gc_inputs): 609 | 610 | encoder_states=self.encoder(hist_gc_inputs) 611 | 612 | x = self.forecaster(fut_gc_inputs,encoder_states) 613 | 614 | y=x[0] 615 | state=x[1] 616 | 617 | return [y,state] -------------------------------------------------------------------------------- /models/RCA_U_Net-100-0.0005-80-0.0000.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trichtu/ConvLSTM-RAU-net/53db02574cfa87c598c52f030a1bf07f5e607d04/models/RCA_U_Net-100-0.0005-80-0.0000.pkl -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from convLSTM_network import convLSTM_model 6 | 7 | def init_weights(net, init_type='normal', gain=0.02): 8 | def init_func(m): 9 | classname = m.__class__.__name__ 10 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 11 | if init_type == 'normal': 12 | init.normal_(m.weight.data, 0.0, gain) 13 | elif init_type == 'xavier': 14 | init.xavier_normal_(m.weight.data, gain=gain) 15 | elif init_type == 'kaiming': 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 17 | elif init_type == 'orthogonal': 18 | 
init.orthogonal_(m.weight.data, gain=gain) 19 | else: 20 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 21 | if hasattr(m, 'bias') and m.bias is not None: 22 | init.constant_(m.bias.data, 0.0) 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, gain) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | print('initialize network with %s' % init_type) 28 | net.apply(init_func) 29 | 30 | class conv_block(nn.Module): 31 | def __init__(self,ch_in,ch_out): 32 | super(conv_block,self).__init__() 33 | self.conv = nn.Sequential( 34 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 35 | nn.BatchNorm2d(ch_out), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 38 | nn.BatchNorm2d(ch_out), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | 43 | def forward(self,x): 44 | x = self.conv(x) 45 | return x 46 | 47 | 48 | 49 | class up_conv(nn.Module): 50 | def __init__(self,ch_in,ch_out): 51 | super(up_conv,self).__init__() 52 | self.up = nn.Sequential( 53 | nn.Upsample(scale_factor=2, mode='bilinear'), 54 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 55 | nn.BatchNorm2d(ch_out), 56 | nn.ReLU(inplace=True) 57 | ) 58 | 59 | def forward(self,x): 60 | x = self.up(x) 61 | return x 62 | 63 | 64 | 65 | class recurr_conv_block(nn.Module): 66 | def __init__(self,ch_in,ch_out): 67 | super(recurr_conv_block,self).__init__() 68 | self.conv1 = nn.Sequential( 69 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 70 | nn.BatchNorm2d(ch_out), 71 | nn.ReLU(inplace=True)) 72 | self.conv2 = nn.Sequential( 73 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 74 | nn.BatchNorm2d(ch_out), 75 | nn.ReLU(inplace=True)) 76 | self.conv3 = nn.Sequential( 77 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 78 | nn.BatchNorm2d(ch_out), 79 | nn.ReLU(inplace=True)) 80 | 81 | 82 | def 
forward(self,x,y): 83 | state = self.conv1(x) 84 | if not (type(y)==type(x)): 85 | x = self.conv2(state) 86 | else: 87 | x = self.conv2(state+y) 88 | state = self.conv3(x) 89 | return x, state 90 | 91 | 92 | 93 | 94 | class single_conv(nn.Module): 95 | def __init__(self,ch_in,ch_out): 96 | super(single_conv,self).__init__() 97 | self.conv = nn.Sequential( 98 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 99 | nn.BatchNorm2d(ch_out), 100 | nn.ReLU(inplace=True) 101 | ) 102 | 103 | def forward(self,x): 104 | x = self.conv(x) 105 | return x 106 | 107 | class Attention_block(nn.Module): 108 | def __init__(self,F_g,F_l,F_int): 109 | super(Attention_block,self).__init__() 110 | self.W_g = nn.Sequential( 111 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 112 | nn.BatchNorm2d(F_int) 113 | ) 114 | 115 | self.W_x = nn.Sequential( 116 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 117 | nn.BatchNorm2d(F_int) 118 | ) 119 | 120 | self.psi = nn.Sequential( 121 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 122 | nn.BatchNorm2d(1), 123 | nn.Sigmoid() 124 | ) 125 | 126 | self.relu = nn.ReLU(inplace=True) 127 | 128 | def forward(self,g,x): 129 | g1 = self.W_g(g) 130 | x1 = self.W_x(x) 131 | psi = self.relu(g1+x1) 132 | psi = self.psi(psi) 133 | 134 | return x*psi 135 | 136 | 137 | class Recurr_Com_Att_U_Net(nn.Module): 138 | def __init__(self,img_ch=47,output_ch=2): 139 | super(Recurr_Com_Att_U_Net,self).__init__() 140 | 141 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 142 | print('img_ch',img_ch) 143 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=16) 144 | self.Conv2 = recurr_conv_block(ch_in=16,ch_out=32) 145 | self.Conv3 = recurr_conv_block(ch_in=32,ch_out=64) 146 | self.Conv4 = recurr_conv_block(ch_in=64,ch_out=128) 147 | self.Conv5 = recurr_conv_block(ch_in=128,ch_out=256) 148 | 149 | self.Up5 = up_conv(ch_in=256,ch_out=128) 150 | self.Att5 = 
Attention_block(F_g=128,F_l=128,F_int=64) 151 | self.Up_conv5 = conv_block(ch_in=256, ch_out=128) 152 | 153 | self.Up4 = up_conv(ch_in=128,ch_out=64) 154 | self.Att4 = Attention_block(F_g=64,F_l=64,F_int=32) 155 | self.Up_conv4 = conv_block(ch_in=128, ch_out=64) 156 | 157 | self.Up3 = up_conv(ch_in=64,ch_out=32) 158 | self.Att3 = Attention_block(F_g=32,F_l=32,F_int=16) 159 | self.Up_conv3 = conv_block(ch_in=64, ch_out=32) 160 | 161 | self.Up2 = up_conv(ch_in=32,ch_out=16) 162 | self.Att2 = Attention_block(F_g=16,F_l=16,F_int=8) 163 | self.Up_conv2 = conv_block(ch_in=32, ch_out=16) 164 | 165 | self.Conv_1x1 = nn.Conv2d(20,output_ch,kernel_size=1,stride=1,padding=0) 166 | self.Conv_rain = nn.Conv2d(20,1,kernel_size=1,stride=1,padding=0) 167 | 168 | def forward(self,x, hist_rain, state2_pre=None,state3_pre=None,state4_pre=None,state5_pre=None): 169 | # encoding path 170 | x_x = torch.cat((x,hist_rain), dim=1) 171 | # print('xx',x_x.shape) 172 | x1 = self.Conv1(x_x) 173 | 174 | x2 = self.Maxpool(x1) 175 | 176 | x2,state2 = self.Conv2(x2,state2_pre) 177 | 178 | x3 = self.Maxpool(x2) 179 | 180 | x3,state3 = self.Conv3(x3,state3_pre) 181 | 182 | x4 = self.Maxpool(x3) 183 | x4,state4 = self.Conv4(x4,state4_pre) 184 | 185 | x5 = self.Maxpool(x4) 186 | x5,state5 = self.Conv5(x5,state5_pre) 187 | 188 | # decoding + concat path 189 | d5 = self.Up5(x5) 190 | x4 = self.Att5(g=d5,x=x4) 191 | d5 = torch.cat((x4,d5),dim=1) 192 | d5 = self.Up_conv5(d5) 193 | 194 | d4 = self.Up4(d5) 195 | x3 = self.Att4(g=d4,x=x3) 196 | d4 = torch.cat((x3,d4),dim=1) 197 | d4 = self.Up_conv4(d4) 198 | 199 | d3 = self.Up3(d4) 200 | x2 = self.Att3(g=d3,x=x2) 201 | d3 = torch.cat((x2,d3),dim=1) 202 | d3 = self.Up_conv3(d3) 203 | 204 | d2 = self.Up2(d3) 205 | x1 = self.Att2(g=d2,x=x1) 206 | d2 = torch.cat((x1,d2),dim=1) 207 | d2 = self.Up_conv2(d2) 208 | d2 = torch.cat((d2, hist_rain, x[:,32:35,:,:] ), dim=1) 209 | 210 | d1 = self.Conv_1x1(d2) 211 | out = F.softmax(d1, dim=1) 212 | rain = 
self.Conv_rain(d2)*out[:,0:1,:,:]+hist_rain 213 | 214 | return out, rain, state2, state3, state4, state5 215 | 216 | 217 | class entire_model(nn.Module): 218 | def __init__(self,img_ch=47,output_ch=2): 219 | super(entire_model,self).__init__() 220 | self.convLSTM_layer = convLSTM_model(41,80,80) 221 | self.TAUnet = Recurr_Com_Att_U_Net(img_ch,output_ch) 222 | 223 | def forward(self,input, hist_rain): 224 | out = torch.zeros([input.size()[0],input.size()[1], 2, input.size()[3],input.size()[4]], device = input.device) 225 | rain = torch.zeros([input.size()[0],input.size()[1], 1, input.size()[3],input.size()[4]], device = input.device) 226 | self.prerain48 = self.convLSTM_layer(hist_rain) 227 | self.prerain24 = self.prerain48[:,24:,:,:,:] 228 | for i in range(input.size()[1]): 229 | if i==0: 230 | out[:,i,:,:,:],rain[:,i,:,:,:],state2,state3,state4,state5 = self.TAUnet(input[:,i,:,:,:], self.prerain24[:,i,:,:,:]) 231 | else: 232 | out[:,i,:,:,:],rain[:,i,:,:,:],state2,state3,state4,state5 = self.TAUnet(input[:,i,:,:,:], self.prerain24[:,i,:,:,:], state2, state3, state4, state5) 233 | 234 | return out, rain, self.prerain24 235 | 236 | 237 | # def recurrent_model(LSTMmodel,unetmodel, input, histrain): 238 | # ''' 239 | # model: model to iterate 240 | # input shape : [Batch, Time, Filter, W, H ] 241 | # ''' 242 | # # LSTM_model 243 | # out = torch.zeros([input.size()[0],input.size()[1], 2, input.size()[3],input.size()[4]], device = input.device) 244 | # rain = torch.zeros([input.size()[0],input.size()[1], 1, input.size()[3],input.size()[4]], device = input.device) 245 | # # histrain = torch.zeros([input.size()[0],input.size()[1], 1, input.size()[3],input.size()[4]], device = input.device) 246 | # prerain48 = LSTMmodel(histrain) 247 | # prerain24 = prerain48[:,24:,:,:,:] 248 | # for i in range(input.size()[1]): 249 | # if i==0: 250 | # out[:,i,:,:,:],rain[:,i,:,:,:],state2,state3,state4,state5 = model(input[:,i,:,:,:], prerain24[:,i,:,:,:]) 251 | # else: 252 | # 
out[:,i,:,:,:],rain[:,i,:,:,:],state2,state3,state4,state5 = model(input[:,i,:,:,:], prerain24[:,i,:,:,:], state2, state3, state4, state5) 253 | 254 | # return out, rain, prerain24 255 | 256 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import datetime 5 | import torch 6 | import torchvision 7 | from torch import optim 8 | from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | from evaluation import * 11 | from network import Recurr_Com_Att_U_Net,recurrent_model, entire_model 12 | import csv 13 | import pandas as pd 14 | from dataset import load_data2 15 | import random 16 | from dataset import my_dataset, load_data2, my_test_dataset 17 | from convLSTM_network import convLSTM_model 18 | 19 | class Solver(object): 20 | def __init__(self, config, train_loader, valid_loader, test_loader): 21 | 22 | # Data loader 23 | self.train_loader = train_loader 24 | self.valid_loader = valid_loader 25 | self.test_loader = test_loader 26 | 27 | # Models 28 | self.unet = None 29 | self.optimizer = None 30 | self.img_ch = config.img_ch 31 | self.output_ch = config.output_ch 32 | self.criterion = torch.nn.BCELoss() 33 | self.criterion2 = SoftIoULoss(2) 34 | self.criterion3 = FocalLoss() 35 | self.criterion4 = torch.nn.MSELoss() 36 | self.augmentation_prob = config.augmentation_prob 37 | 38 | # Hyper-parameters 39 | self.lr = config.lr 40 | self.beta1 = config.beta1 41 | self.beta2 = config.beta2 42 | 43 | # Training settings 44 | self.num_epochs = config.num_epochs 45 | self.num_epochs_decay = config.num_epochs_decay 46 | self.batch_size = config.batch_size 47 | 48 | # Step size 49 | self.log_step = config.log_step 50 | self.val_step = config.val_step 51 | self.test_only = config.test_only 52 | # Path 53 | self.model_path = config.model_path 54 | self.result_path = 
config.result_path 55 | self.mode = config.mode 56 | 57 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | self.model_type = config.model_type 59 | print('model:',self.model_type, 'batch_size:',self.batch_size) 60 | self.t = config.t 61 | self.build_model() 62 | 63 | def build_model(self): 64 | """Build generator and discriminator.""" 65 | if self.model_type =='RCA_U_Net': 66 | self.unet = entire_model(img_ch=self.img_ch, output_ch=self.output_ch) 67 | 68 | self.best_threshold = 0.5 69 | self.optimizer = optim.Adam(list(self.unet.parameters()), 70 | self.lr, [self.beta1, self.beta2]) 71 | self.unet.to(self.device) 72 | 73 | # self.print_network(self.unet, self.model_type) 74 | 75 | def print_network(self, model, name): 76 | """Print out the network information.""" 77 | num_params = 0 78 | for p in model.parameters(): 79 | num_params += p.numel() 80 | print(model) 81 | print(name) 82 | print("The number of parameters: {}".format(num_params)) 83 | 84 | def to_data(self, x): 85 | """Convert variable to tensor.""" 86 | if torch.cuda.is_available(): 87 | x = x.cpu() 88 | return x.data 89 | 90 | def update_lr(self, g_lr, d_lr): 91 | for param_group in self.optimizer.param_groups: 92 | param_group['lr'] = lr 93 | 94 | def reset_grad(self): 95 | """Zero the gradient buffers.""" 96 | self.unet.zero_grad() 97 | 98 | def compute_accuracy(self,SR,GT): 99 | SR_flat = SR.view(-1) 100 | GT_flat = GT.view(-1) 101 | 102 | acc = GT_flat.data.cpu()==(SR_flat.data.cpu()>0.5) 103 | 104 | def tensor2img(self,x): 105 | img = (x[:,0,:,:]>x[:,1,:,:]).float() 106 | img = img*255 107 | return img 108 | 109 | 110 | def train(self): 111 | """Train encoder, generator and discriminator.""" 112 | 113 | #====================================== Training ===========================================# 114 | #===========================================================================================# 115 | 116 | unet_path = os.path.join(self.model_path, 
'%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,self.lr,self.num_epochs_decay,self.augmentation_prob)) 117 | 118 | # U-Net Train 119 | if False : #os.path.isfile(unet_path): 120 | # Load the pretrained Encoder 121 | self.unet.load_state_dict(torch.load(unet_path)) 122 | print('%s is Successfully Loaded from %s'%(self.model_type,unet_path)) 123 | else: 124 | if os.path.isfile(unet_path): 125 | # Load the pretrained Encoder 126 | self.unet.load_state_dict(torch.load(unet_path)) 127 | print('%s is Successfully Loaded from %s'%(self.model_type,unet_path)) 128 | print(self.best_threshold) 129 | # Train for Encoder 130 | lr = self.lr 131 | best_unet_score = 0. 132 | best_threshold = 0. 133 | best_unet_loss =100000. 134 | file_dict = pd.read_csv('/data/output/all_guance_data_name_list/all_gc_filename_list.csv',index_col=0) 135 | datelist = [str(line).split('_')[1] for line in file_dict.values] 136 | file_dict.index = datelist 137 | historyhour = 24 138 | if self.test_only: 139 | self.num_epochs=0 140 | for epoch in range(self.num_epochs): 141 | self.unet.train(True) 142 | epoch_loss = 0 143 | tt_threshold = 0 144 | acc = 0. # Accuracy 145 | SE = 0. # Sensitivity (Recall) 146 | SP = 0. # Specificity 147 | PC = 0. # Precision 148 | F1 = 0. # F1 Score 149 | JS = 0. # Jaccard Similarity 150 | DC = 0. 
# Dice Coefficient 151 | epoch_loss=0 152 | length = 0 153 | trainlist = self.train_loader 154 | trainlist = np.array(trainlist).reshape(-1) 155 | random.shuffle(trainlist) 156 | trainlist = trainlist.reshape(-1,self.batch_size) 157 | for i, batchlist in enumerate(trainlist): 158 | # GT : Ground Truth 159 | tt = time.time() 160 | images, GT, rain_true, histrain =load_data2(batchlist, file_dict, 24, binary=True) 161 | # images, GT, rain_true = load_processed_data(batchlist) 162 | print(time.time()-tt) 163 | histrain = torch.FloatTensor(histrain).to(self.device) 164 | rain_true = torch.FloatTensor(rain_true).to(self.device) 165 | images = torch.FloatTensor(images).to(self.device) 166 | GT = torch.FloatTensor(GT).to(self.device) 167 | # SR : Segmentation Result 168 | SR_probs, rain_pred, prerain24 = self.unet( images, histrain) 169 | # SR_probs, rain_pred = self.unet(images) 170 | SR_flat = SR_probs.view(SR_probs.size(0),-1) 171 | GT_flat = GT.view(GT.size(0),-1) 172 | rpred_flat = rain_pred.view(rain_pred.size(0),-1) 173 | rtrue_flat = rain_true.view(rain_true.size(0),-1) 174 | rpred24_flat = rain_pred.view(prerain24.size(0),-1) 175 | 176 | # loss = self.criterion2(SR_flat,GT_flat)+self.criterion3(SR_flat,GT_flat)+self.criterion4(rpred_flat,rtrue_flat) 177 | loss = self.criterion(SR_flat,GT_flat)+4*self.criterion2(SR_flat,GT_flat)+self.criterion3(SR_flat,GT_flat)+self.criterion4(rpred_flat,rtrue_flat)+4*self.criterion4(rpred24_flat,rtrue_flat) 178 | epoch_loss += loss.item() 179 | self.reset_grad() 180 | loss.backward() 181 | self.optimizer.step() 182 | SR_probs = SR_probs.view(-1,2,80,80) 183 | GT = GT.view(-1,2,80,80) 184 | # tmp_threshold = get_best_threshold(SR_probs, GT) 185 | tmp = get_accuracy(SR_probs, GT) 186 | print('epoch: ',epoch,'batch number: ',i,'/',len(trainlist),'training loss:',loss.item(),'acc',tmp) 187 | # Backprop + optimize 188 | epoch_loss += loss.item() 189 | acc += get_accuracy(SR_probs,GT) 190 | SE += get_sensitivity(SR_probs,GT) 191 | SP += 
get_specificity(SR_probs,GT) 192 | PC += get_precision(SR_probs,GT) 193 | F1 += get_F1(SR_probs[:,0,:,:],GT[:,0,:,:]) 194 | JS += get_JS(SR_probs[:,0,:,:],GT[:,0,:,:]) 195 | DC += get_DC(SR_probs[:,0,:,:],GT[:,0,:,:]) 196 | length = len(trainlist) 197 | acc = acc/length 198 | SE = SE/length 199 | SP = SP/length 200 | PC = PC/length 201 | F1 = F1/length 202 | JS = JS/length 203 | DC = DC/length 204 | epoch_loss = epoch_loss/length 205 | # Print the log info 206 | print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % ( 207 | epoch+1, self.num_epochs, \ 208 | epoch_loss,\ 209 | acc,SE,SP,PC,F1,JS,DC)) 210 | 211 | 212 | 213 | # Decay learning rate 214 | if (epoch+1) > (self.num_epochs - self.num_epochs_decay): 215 | lr -= (self.lr / float(self.num_epochs_decay)) 216 | for param_group in self.optimizer.param_groups: 217 | param_group['lr'] = lr 218 | print ('Decay learning rate to lr: {}.'.format(lr)) 219 | 220 | 221 | if best_unet_loss >epoch_loss: 222 | best_unet_loss = epoch_loss 223 | best_epoch = epoch 224 | best_unet = self.unet.state_dict() 225 | print('Best %s model loss : %.4f'%(self.model_type,best_unet_loss)) 226 | torch.save(best_unet,unet_path) 227 | # ===================================== Validation ====================================# 228 | self.unet.train(False) 229 | self.unet.eval() 230 | 231 | acc = 0. # Accuracy 232 | SE = 0. # Sensitivity (Recall) 233 | SP = 0. # Specificity 234 | PC = 0. # Precision 235 | F1 = 0. # F1 Score 236 | JS = 0. # Jaccard Similarity 237 | DC = 0. 
# Dice Coefficient 238 | length=0 239 | # print(self.valid_loader) 240 | for i, batchlist in enumerate(self.valid_loader): 241 | # GT : Ground Truth 242 | images, GT, rain_true, histrain =load_data2(batchlist, file_dict, 24, binary=True) 243 | histrain = torch.FloatTensor(histrain).to(self.device) 244 | rain_true = torch.FloatTensor(rain_true).to(self.device) 245 | images = torch.FloatTensor(images).to(self.device) 246 | GT = torch.FloatTensor(GT).to(self.device) 247 | # SR : Segmentation Result 248 | SR_probs, rain_pred, prerain24= self.unet(images, histrain) 249 | SR_flat = SR_probs.view(SR_probs.size(0),-1) 250 | GT_flat = GT.view(GT.size(0),-1) 251 | rpred_flat = rain_pred.view(rain_pred.size(0),-1) 252 | rtrue_flat = rain_true.view(rain_true.size(0),-1) 253 | rpred24_flat = rain_pred.view(prerain24.size(0),-1) 254 | loss = self.criterion2(SR_flat,GT_flat)+ self.criterion3(rpred_flat,rtrue_flat)+self.criterion3(rpred24_flat,rtrue_flat) 255 | epoch_loss += loss.item() 256 | SR_probs = SR_probs.view(-1,2,80,80) 257 | GT = GT.view(-1,2,80,80) 258 | acc += get_accuracy(SR_probs,GT,best_threshold) 259 | SE += get_sensitivity(SR_probs,GT) 260 | SP += get_specificity(SR_probs,GT) 261 | PC += get_precision(SR_probs,GT) 262 | F1 += get_F1(SR_probs[:,0,:,:],GT[:,0,:,:]) 263 | JS += get_JS(SR_probs[:,0,:,:],GT[:,0,:,:]) 264 | DC += get_DC(SR_probs[:,0,:,:],GT[:,0,:,:]) 265 | tmp = get_accuracy(SR_probs,GT) 266 | print('epoch: ',epoch,'batch number: ',i,'validation loss:',loss.item(),'acc',tmp) 267 | length = len(self.valid_loader) 268 | acc = acc/length 269 | SE = SE/length 270 | SP = SP/length 271 | PC = PC/length 272 | F1 = F1/length 273 | JS = JS/length 274 | DC = DC/length 275 | epoch_loss/length 276 | unet_score = JS + DC 277 | print('epoch_loss:', epoch_loss,' best_unet_loss', best_unet_loss,'unet_score:',unet_score) 278 | print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'%(acc,SE,SP,PC,F1,JS,DC)) 279 | 280 | ''' 281 | 
torchvision.utils.save_image(images.data.cpu(), 282 | os.path.join(self.result_path, 283 | '%s_valid_%d_image.png'%(self.model_type,epoch+1))) 284 | torchvision.utils.save_image(SR.data.cpu(), 285 | os.path.join(self.result_path, 286 | '%s_valid_%d_SR.png'%(self.model_type,epoch+1))) 287 | torchvision.utils.save_image(GT.data.cpu(), 288 | os.path.join(self.result_path, 289 | '%s_valid_%d_GT.png'%(self.model_type,epoch+1))) 290 | ''' 291 | 292 | 293 | # Save Best U-Net model 294 | # if unet_score > best_unet_score : 295 | # best_unet_score = unet_scored 296 | # if best_unet_loss >epoch_loss: 297 | # best_unet_loss = epdoch_loss 298 | # best_epoch = epoch 299 | # best_unet = self.unet.state_dict() 300 | # print('Best %s model loss : %.4f'%(self.model_type,best_unet_loss)) 301 | # torch.save(best_unet,unet_path) 302 | # print('saveing picture') 303 | # rain_compare_gc_rt_pre(images[:,:,34,:,:]*10,rain_true[:,:,0,:,:],rain_pred[:,:,0,:,:], prerain24[:,:,0,:,:],vmax=5) 304 | #===================================== Test ====================================# 305 | if not self.test_only: 306 | del self.unet 307 | del best_unet 308 | self.build_model() 309 | self.unet.load_state_dict(torch.load(unet_path)) 310 | 311 | self.unet.train(False) 312 | self.unet.eval() 313 | 314 | acc = 0. # Accuracy 315 | SE = 0. # Sensitivity (Recall) 316 | SP = 0. # Specificity 317 | PC = 0. # Precision 318 | F1 = 0. # F1 Score 319 | JS = 0. # Jaccard Similarity 320 | DC = 0. 
# Dice Coefficient 321 | length=0 322 | batch_test, file_dict_test = my_test_dataset( self.batch_size, historyhour, season=False) 323 | for i, batchlist in enumerate(self.test_loader): 324 | print('batch',i) 325 | # GT : Ground Truth 326 | images, GT, rain_true, histrain =load_data2(batchlist, file_dict_test, 24, binary=True) 327 | histrain = torch.FloatTensor(histrain).to(self.device) 328 | rain_true = torch.FloatTensor(rain_true).to(self.device) 329 | images = torch.FloatTensor(images).to(self.device) 330 | GT = torch.FloatTensor(GT).to(self.device) 331 | # SR : Segmentation Result 332 | SR, rain_pred, prerain24= self.unet(images, histrain) 333 | np.save('./vis/prediction_{}.npy'.format(i), SR.cpu().detach().numpy()) 334 | np.save('./vis/ground_truth_{}.npy'.format(i),GT.cpu().detach().numpy()) 335 | np.save('./vis/ruitu_pre_{}.npy'.format(i),images[:,:,34,:,:].cpu().detach().numpy()) 336 | np.save('./vis/prerain_{}.npy'.format(i),rain_pred.cpu().detach().numpy()) 337 | np.save('./vis/prerain24_{}.npy'.format(i),prerain24.cpu().detach().numpy()) 338 | np.save('./vis/ground_rain_{}.npy'.format(i),rain_true.cpu().detach().numpy()) 339 | del SR, GT, images,rain_pred,prerain24,rain_true 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | import time 5 | import torch 6 | from torch import nn,device 7 | from torch.autograd import Variable 8 | import pandas as pd 9 | from dataset import * 10 | from model import * 11 | import datetime 12 | import pandas as pd 13 | import random 14 | import threading 15 | import multiprocessing 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | 20 | def rain_compare_gc_rt_pre(ruitudata,guancedata,pre,vmax=5): 21 | ''' 22 | 输入: 23 | ruitudata: 睿图数据,只包含降水信息 24 | guancedata: 观测数据,同只包含降水信息 25 | pre: 模型预测数据,只包含降水信息 26 | 27 | 
输出: 预报时刻的对应的睿图、观测和模型预测降水图 28 | ''' 29 | import matplotlib as mpl 30 | vmax=vmax 31 | 32 | ruitudata=ruitudata*10 33 | ruitudata[ruitudata>vmax]=vmax 34 | 35 | guancedata=guancedata*10 36 | guancedata[guancedata>vmax]=vmax 37 | 38 | pre=pre*10 39 | pre[pre>vmax]=vmax 40 | 41 | #确定预测时长,对比每个时次的降水信息 42 | time_length=ruitudata.shape[1] 43 | for i in np.arange(time_length): 44 | 45 | fig=plt.figure(i,figsize=(20,5)) 46 | 47 | plt.subplot(1,3,1) 48 | norm=mpl.colors.Normalize(vmin=0,vmax=5) 49 | plt.imshow(ruitudata[0,i,:,:],norm=norm) #显示热力图,范围正则化 50 | # plt.colorbar(ruitu_f) 51 | plt.title('ruitu:'+str(i)) 52 | #plt.tight_layout() #貌似会造成一个colorbar消失 53 | 54 | plt.subplot(1,3,2) 55 | plt.imshow(guancedata[0,i,:,:],norm=norm) 56 | # plt.colorbar(guance_f) 57 | plt.title('guance:'+str(i)) 58 | #plt.tight_layout() 59 | 60 | plt.subplot(1,3,3) 61 | plt.imshow(pre[0,i,:,:],norm=norm) 62 | # plt.colorbar(pre_f) 63 | plt.title('prediction:'+str(i)) 64 | 65 | plt.savefig('/data/code/ml/encoder_decoder/vis_test/'+str(i)+'_'+str(np.round(np.random.random(),4))+'.jpg') #保存所有时次图 66 | 67 | # plt.show() 68 | 69 | 70 | def train_test_by_batch(model,epochs,Batch_size,historyhour=24,season='summer',fut_time_steps=6 ,k=10,binary=False,rain_threshold=10): 71 | 72 | ''' 73 | 输入: 74 | model:待训练模型 75 | epochs:训练的次数 76 | Batch_size: 每个Batch中的数据样本数 77 | historyhour=24:表示选取历史24小时数据 78 | season:对夏季数据(6`9月)进行训练 79 | k:表示每隔k次batch训练进行1次测试 80 | binary:表示是否进行二分类, 81 | 如果为False,load_data2中的输出的guancedata为[batch,timesteps,1,widh,high],降水没有进行归一化 82 | 如果为True, 表示对降水进行了one-hot编码,guancedata维度为[batch,timesteps,2,widh,high] 83 | rain_threshold:如果binary为False,则需要进行回归,回归时 84 | ''' 85 | #调用模型,编译模型,当模型为2分类时,使用交叉熵,当为回归时,使用MSE 86 | # model=build_forecaster_model(binary=binary) 87 | 88 | #创建列表来保存所有train,和test 的loss,acc 89 | total_loss_train=[] 90 | total_acc_train=[] 91 | 92 | total_loss_test=[] 93 | total_acc_test=[] 94 | 95 | #定义优化方法和损失函数 96 | loss_func=torch.nn.MSELoss() #MSE损失 97 | # loss_func=torch.nn.L1Loss() 
#MAE损失 98 | # loss_func=torch.nn.SmoothL1Loss() #计算平滑L1损失,属于 Huber Loss中的一种(因为参数δ固定为1了) 99 | 100 | # ssim_loss=SSIM() 101 | opt=torch.optim.Adam(model.parameters(),lr=0.001) 102 | # scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.5) 103 | 104 | # mult_step_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[EPOCH//2, EPOCH//4*3], gamma=0.1) 105 | 106 | k=k 107 | #每训练k个batch进行一次测试 108 | 109 | #将数据集划分为训练和测试集 110 | batch_filelist,file_dict,train_dataset,test_dataset=load_all_batch_filelist(Batch_size,historyhour,season=season,split_num=0.8) 111 | 112 | #对train_dataset进行遍历训练epoch次 113 | for i in range(epochs): 114 | 115 | #定义模型训练 116 | model.train() 117 | 118 | #每个epoch记录一次loss 119 | loss_train=[] 120 | 121 | #每次epoch之前,对训练数据集进行shuffle 122 | train_dataset= np.array(train_dataset).reshape(-1) 123 | random.shuffle(train_dataset) 124 | train_dataset = train_dataset.reshape(-1,Batch_size) 125 | 126 | #确定在训练数据集上有多少个batch 127 | train_batch_num=len(train_dataset) 128 | 129 | #将每个epoch创建一个记录损失和准确率的文件 130 | train_log_file='/data/code/fzl/encoder_forecaster_1_terminal/inference_model/train_log_file/log_epoch_'+str(i)+'.txt' 131 | 132 | 133 | with open(train_log_file,'w') as f: 134 | 135 | # for j in range(train_batch_num): 136 | for j in range(train_batch_num): 137 | 138 | #获取一次batch数据,并进行训练 139 | ruitudata, guancedata, histguancedata = load_data2(train_dataset[j], file_dict,history_hour=historyhour, binary=binary) 140 | 141 | #获取合适宽度范围 142 | histguancedata=histguancedata[:,:,:,0:80,0:80] 143 | ruitudata=ruitudata[:,:,:,0:80,0:80] 144 | guancedata=guancedata[:,:,:,0:80,0:80] 145 | 146 | guancedata[guancedata<=0.1]=0 #卡阈值 147 | 148 | if fut_time_steps<24: 149 | guancedata=guancedata[:,0:fut_time_steps,:,:,:] 150 | ruitudata=ruitudata[:,0:fut_time_steps,:,:,:] 151 | 152 | if fut_time_steps==1: 153 | guancedata=np.expand_dims(guancedata,axis=1) 154 | ruitudata=np.expand_dims(ruitudata,axis=1) 155 | 156 | 157 | #只取睿图降水信息 158 | if 
ruitu_features==1: 159 | ruitudata=ruitudata[:,:,34,:,:] 160 | ruitudata=np.expand_dims(ruitudata,axis=2) 161 | 162 | #取出历史观测的最后一个时刻的降水信息,并将其乘以10 163 | hist_gc_0=histguancedata[:,-1,4,:,:]*0.6713073720808679+0.052805550785578505 164 | hist_gc_0=np.expand_dims(hist_gc_0,axis=1) 165 | hist_gc_0=np.expand_dims(hist_gc_0,axis=1) 166 | 167 | 168 | #在训练的时候,forecaster的输入为 滞后一个小时的观测和睿图信息 169 | forecaster_input_gc= np.concatenate((hist_gc_0,guancedata[:,0:-1,:,:,:]+ ruitudata[:,0:-1,:,:,:]),axis=1) 170 | print('hist_gc_0:',hist_gc_0.shape) 171 | print(ruitudata.shape) 172 | print(guancedata.shape) 173 | print(forecaster_input_gc.shape) 174 | # #只取历史观测降水信息 175 | # if guance_features==1: 176 | # histguancedata=histguancedata[:,:,4,:,:] 177 | # histguancedata=np.expand_dims(histguancedata,axis=2) 178 | 179 | 180 | # ruitudata=torch.from_numpy(ruitudata).type(torch.FloatTensor).cuda().to(device) 181 | guancedata=torch.from_numpy(guancedata).type(torch.FloatTensor).cuda().to(device) 182 | histguancedata=torch.from_numpy(histguancedata).type(torch.FloatTensor).cuda().to(device) 183 | forecaster_input_gc=torch.from_numpy(forecaster_input_gc).type(torch.FloatTensor).cuda().to(device) 184 | 185 | #如果binary为False,需要对guance降水进行归一化 186 | if binary==False: 187 | guancedata=guancedata/rain_threshold 188 | forecaster_input_gc=forecaster_input_gc/rain_threshold 189 | 190 | # scheduler.step() 191 | 192 | pred=model(histguancedata,forecaster_input_gc) 193 | pred=pred[0] #不获取状态 194 | 195 | B, S, C, H, W = guancedata.size() 196 | 197 | 198 | #每隔20次batch画一次图,看看效果 199 | if j%10==0: 200 | rain_compare_gc_rt_pre(ruitudata[:,:,0,:,:],guancedata[:,:,0,:,:].cpu().detach().numpy(),pred[:,:,0,:,:].cpu().detach().numpy()) 201 | 202 | pred=pred.view(-1,C,H,W) 203 | guancedata=guancedata.view(-1,C,H,W) 204 | 205 | # print(type(pred)) 206 | loss=loss_func(pred,guancedata) 207 | 208 | # loss=loss_func(pred,guancedata) 209 | 210 | #将每个batch的训练信息保存下来,h[0]为loss,h[1]为acc 211 | loss_train.append(loss.item()) 212 
| 213 | #输出每次bathc训练的损失和准确率,输出到文件中 214 | print('Epoch {}/{} : train_batch {}/{}: -----train_loss:{:.4f} \n'.format(i,epochs-1,j,train_batch_num-1,loss.item())) 215 | f.write('Epoch {}/{} : train_batch {}/{}: -----train_loss:{:.4f} \n'.format(i,epochs-1,j,train_batch_num-1,loss.item())) 216 | f.write('\n') 217 | 218 | # loss.backward() 219 | 220 | opt.zero_grad() 221 | loss.backward() 222 | opt.step() 223 | 224 | 225 | #每个epoch画出训练和测试的损失信息和准确率信息,将其保存到对应文件中去 226 | plt.figure(figsize=(10,6)) 227 | plt.plot(np.arange(len(loss_train)),loss_train,'-r',label='loss-train') 228 | plt.title('epoch_'+str(i)) 229 | plt.legend() 230 | plt.savefig('/data/code/fzl/encoder_forecaster_1_terminal/inference_model/'+'epoch_'+str(i)+'.png') 231 | plt.show() 232 | 233 | total_loss_train.append(loss_train) 234 | f.close() 235 | 236 | #每个epoch保存一次模型 237 | # torch.save(model,'pytorch_encoder_forecaster_model_epoch_'+str(i)+'.pkl') 238 | torch.save({'state_dict': model.state_dict()}, '/data/code/fzl/encoder_forecaster_1_terminal/inference_model/epoch_'+str(i)+'_checkpoint.pth.tar') 239 | # torch.save() 240 | 241 | return model,total_loss_train 242 | 243 | 244 | 245 | def inference(k=5,rain_threshold=10): 246 | Batch_size=8 247 | 248 | epochs = 10 249 | 250 | fut_time_steps=6 251 | 252 | filter_1=16 253 | filter_2=32 254 | filter_3=64 255 | filter_4=128 256 | filter_5=128 257 | 258 | #ruitu_features=46 259 | width=80 260 | height=80 261 | histgc_feature=41 262 | fut_guance_feature=1 263 | ruitu_features=1 264 | device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 265 | model=combined_net(histgc_feature,fut_guance_feature,width,height).to(device) 266 | # model = encoder_forecaster_net()epoch_5_checkpoint.pth.tar 267 | checkpoint = torch.load('/data/code/fzl/encoder_forecaster_1_terminal/inference_model/epoch_5_checkpoint.pth.tar') 268 | model.load_state_dict(checkpoint['state_dict']) 269 | model.eval() 270 | historyhour=24 271 | binary=False 272 | 
batch_filelist,file_dict,train_dataset,test_dataset=load_all_batch_filelist(Batch_size,historyhour,season='summer',split_num=0.8) 273 | #确定在训练数据集上有多少个batch 274 | train_batch_num=len(train_dataset) 275 | for j in range(train_batch_num): 276 | 277 | #获取一次batch数据,并进行训练 278 | ruitudata, guancedata, histguancedata = load_data2(train_dataset[j], file_dict,history_hour=historyhour, binary=binary) 279 | 280 | #获取合适宽度范围 281 | histguancedata=histguancedata[:,:,:,0:80,0:80] 282 | ruitudata=ruitudata[:,:,:,0:80,0:80] 283 | guancedata=guancedata[:,:,:,0:80,0:80] 284 | 285 | guancedata[guancedata<=0.1]=0 #卡阈值 286 | 287 | if fut_time_steps<24: 288 | guancedata=guancedata[:,0:fut_time_steps,:,:,:] 289 | ruitudata=ruitudata[:,0:fut_time_steps,:,:,:] 290 | 291 | if fut_time_steps==1: 292 | guancedata=np.expand_dims(guancedata,axis=1) 293 | ruitudata=np.expand_dims(ruitudata,axis=1) 294 | 295 | 296 | #只取睿图降水信息 297 | if ruitu_features==1: 298 | ruitudata=ruitudata[:,:,34,:,:] 299 | ruitudata=np.expand_dims(ruitudata,axis=2) 300 | 301 | #取出历史观测的最后一个时刻的降水信息,并将其乘以10 302 | hist_gc_0=histguancedata[:,-1,4,:,:]*10 303 | hist_gc_0=np.expand_dims(hist_gc_0,axis=1) 304 | hist_gc_0=np.expand_dims(hist_gc_0,axis=1) 305 | 306 | 307 | #在训练的时候,forecaster的输入为 滞后一个小时的观测和睿图信息 308 | forecaster_input_gc= np.concatenate((hist_gc_0,guancedata[:,0:-1,:,:,:]+ ruitudata[:,0:-1,:,:,:]),axis=1) 309 | print('hist_gc_0:',hist_gc_0.shape) 310 | print(ruitudata.shape) 311 | print(guancedata.shape) 312 | print(forecaster_input_gc.shape) 313 | # #只取历史观测降水信息 314 | # if guance_features==1: 315 | # histguancedata=histguancedata[:,:,4,:,:] 316 | # histguancedata=np.expand_dims(histguancedata,axis=2) 317 | 318 | 319 | # ruitudata=torch.from_numpy(ruitudata).type(torch.FloatTensor).cuda().to(device) 320 | guancedata=torch.from_numpy(guancedata).type(torch.FloatTensor).cuda().to(device) 321 | histguancedata=torch.from_numpy(histguancedata).type(torch.FloatTensor).cuda().to(device) 322 | 
forecaster_input_gc=torch.from_numpy(forecaster_input_gc).type(torch.FloatTensor).cuda().to(device) 323 | 324 | #如果binary为False,需要对guance降水进行归一化 325 | if binary==False: 326 | guancedata=guancedata/rain_threshold 327 | forecaster_input_gc=forecaster_input_gc/rain_threshold 328 | 329 | # scheduler.step() 330 | 331 | pred=model(histguancedata,forecaster_input_gc) 332 | pred=pred[0] #不获取状态 333 | 334 | B, S, C, H, W = guancedata.size() 335 | 336 | 337 | #每隔20次batch画一次图,看看效果 338 | if j%5==0: 339 | print('peeking in picture') 340 | rain_compare_gc_rt_pre(ruitudata[:,:,0,:,:],guancedata[:,:,0,:,:].cpu().detach().numpy(),pred[:,:,0,:,:].cpu().detach().numpy()) 341 | 342 | 343 | 344 | 345 | 346 | if __name__=='__main__': 347 | inference() 348 | # Batch_size=8 349 | 350 | # epochs = 10 351 | 352 | # fut_time_steps=1 353 | 354 | # filter_1=16 355 | # filter_2=32 356 | # filter_3=64 357 | # filter_4=128 358 | # filter_5=128 359 | 360 | # #ruitu_features=46 361 | # width=80 362 | # height=80 363 | # histgc_feature=41 364 | # fut_guance_feature=1 365 | # ruitu_features=1 366 | 367 | # device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") 368 | # model=combined_net(histgc_feature,fut_guance_feature,width,height).to(device) 369 | # T_model,total_loss_train=train_test_by_batch(model,Batch_size=Batch_size,epochs=epochs,k=5,rain_threshold=10) 370 | 371 | --------------------------------------------------------------------------------