├── checkpoints
│   └── ckpt_files_stored_here
├── README.md
├── data
│   └── dataset_download.txt
├── LICENSE
├── .gitignore
├── dataset.py
├── model.py
└── train.py

--------------------------------------------------------------------------------
/checkpoints/ckpt_files_stored_here:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HighAir
Source code for "HighAir: A Hierarchical Graph Neural Network-Based Air Quality Forecasting Method".

--------------------------------------------------------------------------------
/data/dataset_download.txt:
--------------------------------------------------------------------------------
The pre-processed dataset can be downloaded from Google Drive:
https://drive.google.com/file/d/1OLiFjHeY35o_PT4rLVfwUB1wsIzKSLOv/view?usp=sharing
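To fetch it programmatically, a minimal sketch using the gdown package (an
assumption; gdown is not a dependency of this repository), with the file id
taken from the link above and a guessed archive name:

import gdown  # pip install gdown

# Hypothetical output name; adjust to whatever the Drive file actually is.
url = 'https://drive.google.com/uc?id=1OLiFjHeY35o_PT4rLVfwUB1wsIzKSLOv'
gdown.download(url, output='highair_data.zip', quiet=False)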
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Friger

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
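# Expected data layout, inferred from the loading code below (the key names
# come from the code; the exact array shapes are assumptions, not verified
# against the released files):
#
#   data/<city>_<split>.txt : one JSON dict per city and split, holding
#       "<station name>" : per-sample AQI series, TIME_WINDOW + PRED_TIME long
#       "sim"            : per-sample edge features of the station graph
#       "conn"           : static edge list of the station graph
#       "weather"        : per-sample observed weather
#       "weather_for"    : per-sample weather forecast
#       "poi"            : static per-station POI features
#   data/city_<split>.txt : city-level dict with "aqi", "conn", "sim" and
#       "weather" keys for the city graph
#   stations.txt : maps each city name to its list of station keys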
import torch
import torch.utils.data as Data

import os
import json

# Each sample contains TIME_WINDOW hours of observations followed by
# PRED_TIME hours to be forecast.
TIME_WINDOW = 24
PRED_TIME = 12

DATA_PATH = './'


class _CityDataset(Data.Dataset):
    """Shared loader for one split ('train', 'val' or 'test'); the three
    public dataset classes below only select the split."""

    def __init__(self, split):
        with open(os.path.join(DATA_PATH, 'data', 'city_{}.txt'.format(split)), 'r') as f:
            self.cities = json.load(f)

        with open(os.path.join(DATA_PATH, 'data', 'jiaxing_{}.txt'.format(split)), 'r') as f:
            self.jiaxing = json.load(f)

        with open(os.path.join(DATA_PATH, 'data', 'shanghai_{}.txt'.format(split)), 'r') as f:
            self.shanghai = json.load(f)

        with open(os.path.join(DATA_PATH, 'data', 'suzhou_{}.txt'.format(split)), 'r') as f:
            self.suzhou = json.load(f)

        with open(os.path.join(DATA_PATH, 'stations.txt'), 'r') as f:
            self.stations = json.load(f)

    def GetCityData(self, city_name, city_source, index):
        station_list = self.stations[city_name]
        city_aqi = []
        city_y = []
        for x in station_list:
            # First TIME_WINDOW hours are the model input, the rest the target.
            city_aqi.append(city_source[x][index][:TIME_WINDOW])
            city_y.append(city_source[x][index][TIME_WINDOW:])

        city_aqi = torch.FloatTensor(city_aqi)
        city_y = torch.FloatTensor(city_y)
        city_sim = torch.FloatTensor(city_source['sim'][index])
        city_conn = torch.tensor(city_source['conn'])
        city_weather = torch.FloatTensor(city_source['weather'][index])
        city_for = torch.FloatTensor(city_source['weather_for'][index])
        city_poi = torch.FloatTensor(city_source['poi'])

        city_data = [city_aqi, city_conn, city_poi, city_sim,
                     city_weather, city_for, city_y]

        return city_data

    def __getitem__(self, index):
        jiaxing_data = self.GetCityData('jiaxing', self.jiaxing, index)
        shanghai_data = self.GetCityData('shanghai', self.shanghai, index)
        suzhou_data = self.GetCityData('suzhou', self.suzhou, index)

        cities_aqi = torch.FloatTensor(self.cities['aqi'][index])
        cities_conn = torch.tensor(self.cities['conn'])
        cities_weather = torch.FloatTensor(self.cities['weather'][index])
        cities_sim = torch.FloatTensor(self.cities['sim'][index])

        cities_data = [cities_aqi, cities_conn, cities_sim, cities_weather]

        return cities_data, jiaxing_data, shanghai_data, suzhou_data

    def __len__(self):
        return len(self.shanghai['weather'])


class trainDataset(_CityDataset):
    def __init__(self):
        super().__init__('train')


class valDataset(_CityDataset):
    def __init__(self):
        super().__init__('val')


class testDataset(_CityDataset):
    def __init__(self):
        super().__init__('test')
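

if __name__ == '__main__':
    # Minimal sanity check; assumes the pre-processed files from
    # data/dataset_download.txt have been unpacked so that ./data/*.txt and
    # ./stations.txt exist. The commented shapes are what the code above
    # implies, not verified constants.
    loader = Data.DataLoader(trainDataset(), batch_size=4, shuffle=True)
    cities_data, jiaxing_data, shanghai_data, suzhou_data = next(iter(loader))
    print(cities_data[0].shape)    # city-level AQI: (batch, num_cities, TIME_WINDOW)
    print(jiaxing_data[0].shape)   # station AQI: (batch, num_stations, TIME_WINDOW)
    print(jiaxing_data[-1].shape)  # targets: (batch, num_stations, PRED_TIME)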

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_mean


class RNNEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(RNNEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size,
                            hidden_size,
                            num_layers,
                            batch_first=True)

    def forward(self, x, h0, c0):
        # Run the LSTM over the input sequence and return only the final
        # hidden and cell states, which seed the decoder.
        out, (h_n, c_n) = self.lstm(x, (h0, c0))
        return h_n, c_n


class RNNDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(RNNDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size,
                            hidden_size,
                            num_layers,
                            batch_first=True)
        self.lin = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()

    def forward(self, x, h_0, c_0):
        # One decoding step: out has shape (batch_size, seq_length, hidden_size)
        # and is projected to a single non-negative AQI value per step.
        out, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        out = self.lin(out)
        out = self.relu(out)
        return out, h_n, c_n


class GlobalModel(torch.nn.Module):
    """City-level graph: encodes each city's AQI series with an LSTM and
    propagates information between cities with one round of message passing."""

    def __init__(self, aqi_em, rnn_h, rnn_l, gnn_h):
        super(GlobalModel, self).__init__()
        self.aqi_em = aqi_em
        self.rnn_h = rnn_h
        self.gnn_h = gnn_h
        self.aqi_embed = Seq(Lin(1, aqi_em), ReLU())
        self.aqi_rnn = nn.LSTM(aqi_em, rnn_h, rnn_l, batch_first=True)
        self.city_gnn = CityGNN(rnn_h, 2, gnn_h)

    def batchInput(self, x, edge_w, edge_conn):
        # Flatten a batch of graphs into one big graph: node features and
        # edge weights are stacked, and the node indices in the edge list
        # are offset by i * sta_num for the i-th graph.
        sta_num = x.shape[1]
        x = x.reshape(-1, x.shape[-1])
        edge_w = edge_w.reshape(-1, edge_w.shape[-1])
        for i in range(edge_conn.size(0)):
            edge_conn[i, :] = torch.add(edge_conn[i, :], i * sta_num)
        edge_conn = edge_conn.transpose(0, 1)
        edge_conn = edge_conn.reshape(2, -1)
        return x, edge_w, edge_conn

    def forward(self, city_aqi, city_conn, city_w, city_num):
        city_aqi = city_aqi.unsqueeze(dim=-1)
        city_aqi = self.aqi_embed(city_aqi)
        city_aqi, _ = self.aqi_rnn(city_aqi.reshape(-1, 24, self.aqi_em))
        city_aqi = city_aqi.reshape(-1, city_num, 24, self.rnn_h)
        city_aqi = city_aqi.transpose(1, 2)
        city_aqi = city_aqi.reshape(-1, city_num, city_aqi.shape[-1])

        # The connectivity is identical for every sample, so the tiling
        # order of this repeat does not matter.
        city_conn = city_conn.transpose(1, 2).repeat(24, 1, 1)
        city_w = city_w.reshape(-1, city_w.shape[-2], city_w.shape[-1])
        city_x, city_weight, city_conn = self.batchInput(
            city_aqi, city_w, city_conn)
        out = self.city_gnn(city_x, city_conn, city_weight)
        out = out.reshape(-1, 24, city_num, out.shape[-1])

        return out


class CityGNN(torch.nn.Module):
    """One round of message passing on the city graph."""

    def __init__(self, node_h, edge_h, gnn_h):
        super(CityGNN, self).__init__()
        self.node_mlp_1 = Seq(Lin(2 * node_h + edge_h, gnn_h),
                              ReLU(inplace=True))
        self.node_mlp_2 = Seq(Lin(node_h + gnn_h, gnn_h), ReLU(inplace=True))

    def forward(self, x, edge_index, edge_attr):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        row, col = edge_index
        out = torch.cat([x[row], x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Average the incoming messages at each target node.
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)


class CityModel(nn.Module):
    """Station graph: forecasts per-station AQI for one city, conditioned on
    that city's representation from GlobalModel."""

    def __init__(self, aqi_em, poi_em, wea_em, rnn_h, rnn_l, gnn_h):
        super(CityModel, self).__init__()
        self.rnn_h = rnn_h
        self.gnn_h = gnn_h
        self.rnn_l = rnn_l
        self.aqi_embed = Seq(Lin(1, aqi_em), ReLU())
        self.poi_embed = Seq(Lin(5, poi_em), ReLU())
        self.city_embed = Seq(Lin(gnn_h, wea_em), ReLU())
        self.wea_embed = Seq(Lin(5, wea_em), ReLU())
        self.sta_gnn = StaGNN(aqi_em + poi_em, 2, gnn_h, 2 * wea_em)
        self.encoder = RNNEncoder(input_size=gnn_h,
                                  hidden_size=rnn_h,
                                  num_layers=rnn_l)
        self.decoder_embed = Seq(Lin(1, aqi_em), ReLU())
        self.decoder = RNNDecoder(input_size=4 + aqi_em,
                                  hidden_size=rnn_h,
                                  num_layers=rnn_l)

    def batchInput(self, x, edge_w, edge_conn):
        # Same graph-batching trick as in GlobalModel.
        sta_num = x.shape[1]
        x = x.reshape(-1, x.shape[-1])
        edge_w = edge_w.reshape(-1, edge_w.shape[-1])
        for i in range(edge_conn.size(0)):
            edge_conn[i, :] = torch.add(edge_conn[i, :], i * sta_num)
        edge_conn = edge_conn.transpose(0, 1)
        edge_conn = edge_conn.reshape(2, -1)
        return x, edge_w, edge_conn

    def forward(self, city_data, city_u, device):
        sta_aqi, sta_conn, sta_poi, sta_w, sta_wea, sta_for, _ = city_data
        sta_num = sta_aqi.shape[1]
        sta_x = sta_aqi.unsqueeze(dim=-1)
        sta_x = self.aqi_embed(sta_x)
        sta_poi = self.poi_embed(sta_poi)
        sta_poi = sta_poi.unsqueeze(dim=-2).repeat_interleave(24, dim=-2)
        sta_x = torch.cat([sta_x, sta_poi], dim=-1)
        sta_x = sta_x.transpose(1, 2)
        sta_x = sta_x.reshape(-1, sta_x.shape[-2], sta_x.shape[-1])

        sta_conn = sta_conn.transpose(1, 2).repeat(24, 1, 1)
        sta_w = sta_w.reshape(-1, sta_w.shape[-2], sta_w.shape[-1])
        sta_x, sta_weight, sta_conn = self.batchInput(sta_x, sta_w, sta_conn)
        city_u = self.city_embed(city_u)
        sta_wea = self.wea_embed(sta_wea)
        sta_u = torch.cat([city_u, sta_wea], dim=-1)
        sta_x = self.sta_gnn(sta_x, sta_conn, sta_weight, sta_u, sta_num)
        sta_x = sta_x.reshape(-1, 24, sta_num, sta_x.shape[-1]).transpose(1, 2)
        sta_x = sta_x.reshape(-1, 24, sta_x.shape[-1])

        h0 = torch.randn(self.rnn_l, sta_x.size(0), self.rnn_h).to(device)
        c0 = torch.randn(self.rnn_l, sta_x.size(0), self.rnn_h).to(device)
        h_x, c_x = self.encoder(sta_x, h0, c0)

        # Autoregressive decoding: feed the previous prediction back in,
        # together with the weather forecast for the current step.
        outputs = torch.zeros((sta_x.size(0), sta_for.size(1), 1)).to(device)
        aqi = sta_aqi[:, :, -1].reshape(-1, 1)
        # repeat_interleave (not repeat) keeps each forecast row aligned with
        # the batch-major (sample, station) ordering of the decoder inputs.
        sta_for = sta_for.repeat_interleave(sta_num, dim=0)
        for i in range(sta_for.size(1)):
            aqi_em = self.decoder_embed(aqi)
            inputs = torch.cat((aqi_em, sta_for[:, i]), dim=-1)
            inputs = inputs.unsqueeze(dim=1)
            output, h_x, c_x = self.decoder(inputs, h_x, c_x)
            output = output.reshape(-1, 1)
            outputs[:, i] = output
            aqi = output
        outputs = outputs.reshape(-1, sta_num, sta_for.size(1))

        return outputs


class StaGNN(torch.nn.Module):
    """Message passing on the station graph, with a per-(sample, hour)
    context vector u concatenated into the node update."""

    def __init__(self, node_h, edge_h, gnn_h, u_h):
        super(StaGNN, self).__init__()
        self.node_mlp_1 = Seq(Lin(2 * node_h + edge_h, gnn_h),
                              ReLU(inplace=True))
        self.node_mlp_2 = Seq(Lin(node_h + gnn_h + u_h, gnn_h),
                              ReLU(inplace=True))

    def forward(self, x, edge_index, edge_attr, u, sta_num):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        u = u.reshape(-1, u.shape[-1])
        # Broadcast the context to every station in its graph;
        # repeat_interleave matches the node ordering produced in CityModel.
        u = u.repeat_interleave(sta_num, dim=0)
        row, col = edge_index
        out = torch.cat([x[row], x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out, u], dim=1)
        return self.node_mlp_2(out)
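

if __name__ == '__main__':
    # Shape check with random inputs on CPU. The batch/station/edge counts
    # are arbitrary stand-ins that merely satisfy the dimensions hard-wired
    # in the constructors (24 h AQI windows, 5 POI features, 5 weather
    # features, 4 forecast features); they are not dataset values.
    batch, city_num, sta_num, n_edges = 2, 10, 3, 12
    global_model = GlobalModel(aqi_em=16, rnn_h=64, rnn_l=1, gnn_h=32)
    city_model = CityModel(aqi_em=16, poi_em=8, wea_em=12,
                           rnn_h=64, rnn_l=1, gnn_h=32)

    city_u = global_model(torch.rand(batch, city_num, 24),
                          torch.randint(0, city_num, (batch, n_edges, 2)),
                          torch.rand(batch, 24, n_edges, 2),
                          city_num)
    print(city_u.shape)  # (batch, 24, city_num, gnn_h)

    city_data = [
        torch.rand(batch, sta_num, 24),                  # station AQI history
        torch.randint(0, sta_num, (batch, n_edges, 2)),  # station edge list
        torch.rand(batch, sta_num, 5),                   # POI features
        torch.rand(batch, 24, n_edges, 2),               # hourly edge weights
        torch.rand(batch, 24, 5),                        # observed weather
        torch.rand(batch, 12, 4),                        # weather forecast
        torch.rand(batch, sta_num, 12),                  # targets (unused here)
    ]
    outputs = city_model(city_data, city_u[:, :, 4], 'cpu')
    print(outputs.shape)  # (batch, sta_num, 12)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data

from dataset import testDataset, trainDataset, valDataset
from model import CityModel, GlobalModel

parser = argparse.ArgumentParser(description='Multi-city AQI forecasting')
parser.add_argument('--device', type=str, default='cuda', help='torch device')
parser.add_argument('--run_times', type=int, default=5, help='independent training runs')
parser.add_argument('--epoch', type=int, default=300, help='training epochs')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--city_num', type=int, default=10, help='cities in the city graph')
parser.add_argument('--gnn_h', type=int, default=32, help='GNN hidden size')
parser.add_argument('--rnn_h', type=int, default=64, help='LSTM hidden size')
parser.add_argument('--rnn_l', type=int, default=1, help='LSTM layers')
parser.add_argument('--aqi_em', type=int, default=16, help='AQI embedding size')
parser.add_argument('--poi_em', type=int, default=8, help='poi embedding')
parser.add_argument('--wea_em', type=int, default=12, help='wea embedding')
parser.add_argument('--lr', type=float, default=0.001, help='lr')
parser.add_argument('--wd', type=float, default=0.001, help='weight decay')
parser.add_argument('--pred_step', type=int, default=12, help='forecast steps')
args = parser.parse_args()

device = args.device

# Example invocation (the values shown are just the argparse defaults made
# explicit); use --device cpu to run without a GPU:
#   python train.py --device cuda --run_times 5 --epoch 300 --batch_size 128

train_dataset = trainDataset()
train_loader = Data.DataLoader(train_dataset,
                               batch_size=args.batch_size,
                               num_workers=4,
                               shuffle=True)

val_dataset = valDataset()
val_loader = Data.DataLoader(val_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=True)

test_dataset = testDataset()
test_loader = Data.DataLoader(test_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=False)

for runtimes in range(args.run_times):

    # One shared city-level model plus one station-level model per city.
    global_model = GlobalModel(args.aqi_em, args.rnn_h, args.rnn_l,
                               args.gnn_h).to(device)
    jiaxing_model = CityModel(args.aqi_em, args.poi_em, args.wea_em,
                              args.rnn_h, args.rnn_l, args.gnn_h).to(device)
    shanghai_model = CityModel(args.aqi_em, args.poi_em, args.wea_em,
                               args.rnn_h, args.rnn_l, args.gnn_h).to(device)
    suzhou_model = CityModel(args.aqi_em, args.poi_em, args.wea_em, args.rnn_h,
                             args.rnn_l, args.gnn_h).to(device)

    global_model_num = sum(p.numel() for p in global_model.parameters()
                           if p.requires_grad)
    print('global_model:', 'Trainable,', global_model_num)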
    shanghai_model_num = sum(p.numel() for p in shanghai_model.parameters()
                             if p.requires_grad)
    print('shanghai_model_num:', 'Trainable,', shanghai_model_num)

    criterion = nn.MSELoss()
    params = list(global_model.parameters()) + list(jiaxing_model.parameters()) + \
        list(shanghai_model.parameters()) + list(suzhou_model.parameters())
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    val_loss_min = np.inf
    for epoch in range(args.epoch):
        for i, (cities_data, jiaxing_data, shanghai_data,
                suzhou_data) in enumerate(train_loader):
            cities_aqi, cities_conn, cities_sim, _ = [x.to(device) for x in cities_data]
            # City-level pass; columns 4, 6 and 8 of the output are the
            # representations of Jiaxing, Shanghai and Suzhou in the
            # ten-city graph.
            city_u = global_model(cities_aqi, cities_conn, cities_sim,
                                  args.city_num)

            jiaxing_data = [item.to(device, non_blocking=True) for item in jiaxing_data]
            jiaxing_outputs = jiaxing_model(jiaxing_data, city_u[:, :, 4], device)
            jiaxing_loss = criterion(jiaxing_outputs, jiaxing_data[-1])

            shanghai_data = [item.to(device, non_blocking=True) for item in shanghai_data]
            shanghai_outputs = shanghai_model(shanghai_data, city_u[:, :, 6], device)
            shanghai_loss = criterion(shanghai_outputs, shanghai_data[-1])

            suzhou_data = [item.to(device, non_blocking=True) for item in suzhou_data]
            suzhou_outputs = suzhou_model(suzhou_data, city_u[:, :, 8], device)
            suzhou_loss = criterion(suzhou_outputs, suzhou_data[-1])

            jiaxing_model.zero_grad()
            shanghai_model.zero_grad()
            suzhou_model.zero_grad()
            global_model.zero_grad()

            # The three city losses are optimized jointly.
            loss = jiaxing_loss + shanghai_loss + suzhou_loss

            loss.backward()
            optimizer.step()

            if i % 20 == 0 and epoch % 50 == 0:
                for name, city_loss in [('jiaxing', jiaxing_loss),
                                        ('shanghai', shanghai_loss),
                                        ('suzhou', suzhou_loss)]:
                    print('{},Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                        name, epoch, args.epoch, i,
                        len(train_loader), city_loss.item()))

        val_loss = 0
        with torch.no_grad():
            for j, (cities_data_val, jiaxing_data_val, shanghai_data_val,
                    suzhou_data_val) in enumerate(val_loader):
                cities_aqi_val, cities_conn_val, cities_sim_val, _ = [
                    x.to(device) for x in cities_data_val]
                city_u_val = global_model(cities_aqi_val, cities_conn_val,
                                          cities_sim_val, args.city_num)

                jiaxing_data_val = [item.to(device, non_blocking=True) for item in jiaxing_data_val]
                jiaxing_outputs_val = jiaxing_model(jiaxing_data_val, city_u_val[:, :, 4], device)
                jiaxing_loss_val = criterion(jiaxing_outputs_val, jiaxing_data_val[-1])

                shanghai_data_val = [item.to(device, non_blocking=True) for item in shanghai_data_val]
                shanghai_outputs_val = shanghai_model(shanghai_data_val, city_u_val[:, :, 6], device)
                shanghai_loss_val = criterion(shanghai_outputs_val, shanghai_data_val[-1])

                suzhou_data_val = [item.to(device, non_blocking=True) for item in suzhou_data_val]
                suzhou_outputs_val = suzhou_model(suzhou_data_val, city_u_val[:, :, 8], device)
                suzhou_loss_val = criterion(suzhou_outputs_val, suzhou_data_val[-1])
                val_loss = val_loss + jiaxing_loss_val.item() \
                    + shanghai_loss_val.item() + suzhou_loss_val.item()

        # Checkpoint on validation improvement, but only in the last 30%
        # of training.
        if val_loss < val_loss_min and epoch > (args.epoch * 0.7):
            torch.save(global_model.state_dict(), './checkpoints/global.ckpt')
            torch.save(jiaxing_model.state_dict(), './checkpoints/jiaxing.ckpt')
            torch.save(shanghai_model.state_dict(), './checkpoints/shanghai.ckpt')
            torch.save(suzhou_model.state_dict(), './checkpoints/suzhou.ckpt')
            val_loss_min = val_loss

    print('Finished Training')

    # MAE / RMSE accumulators: rows are (jiaxing, shanghai, suzhou),
    # columns are the 1 h / 3 h / 6 h / 12 h horizons.
    mae_loss = torch.zeros(3, 4)
    rmse_loss = torch.zeros(3, 4)

    def cal_loss(outputs, y, index):
        # Accumulate absolute and squared errors at horizons 1, 3, 6 and 12
        # (time indices 0, 2, 5 and -1).
        temp_loss = torch.abs(outputs - y)
        mae_loss[index, 0] += temp_loss[:, :, 0].sum().item()
        mae_loss[index, 1] += temp_loss[:, :, 2].sum().item()
        mae_loss[index, 2] += temp_loss[:, :, 5].sum().item()
        mae_loss[index, 3] += temp_loss[:, :, -1].sum().item()

        temp_loss = torch.pow(temp_loss, 2)
        rmse_loss[index, 0] += temp_loss[:, :, 0].sum().item()
        rmse_loss[index, 1] += temp_loss[:, :, 2].sum().item()
        rmse_loss[index, 2] += temp_loss[:, :, 5].sum().item()
        rmse_loss[index, 3] += temp_loss[:, :, -1].sum().item()

    with torch.no_grad():
        global_model.load_state_dict(torch.load('./checkpoints/global.ckpt'))
        jiaxing_model.load_state_dict(torch.load('./checkpoints/jiaxing.ckpt'))
        shanghai_model.load_state_dict(torch.load('./checkpoints/shanghai.ckpt'))
        suzhou_model.load_state_dict(torch.load('./checkpoints/suzhou.ckpt'))

        for cities_data, jiaxing_data, shanghai_data, suzhou_data in test_loader:
            cities_aqi, cities_conn, cities_sim, _ = [x.to(device) for x in cities_data]
            city_u = global_model(cities_aqi, cities_conn, cities_sim,
                                  args.city_num)

            jiaxing_data = [item.to(device, non_blocking=True) for item in jiaxing_data]
            jiaxing_outputs = jiaxing_model(jiaxing_data, city_u[:, :, 4], device)

            shanghai_data = [item.to(device, non_blocking=True) for item in shanghai_data]
            shanghai_outputs = shanghai_model(shanghai_data, city_u[:, :, 6], device)

            suzhou_data = [item.to(device, non_blocking=True) for item in suzhou_data]
            suzhou_outputs = suzhou_model(suzhou_data, city_u[:, :, 8], device)

            cal_loss(jiaxing_outputs, jiaxing_data[-1], 0)
            cal_loss(shanghai_outputs, shanghai_data[-1], 1)
            cal_loss(suzhou_outputs, suzhou_data[-1], 2)

    # Normalize by (number of test samples) x (number of stations per city):
    # Jiaxing has 2 stations, Shanghai 10 and Suzhou 8.
    mae_loss = mae_loss.numpy()
    jiaxing_mae_loss = mae_loss[0] / (len(test_dataset) * 2)
    shanghai_mae_loss = mae_loss[1] / (len(test_dataset) * 10)
    suzhou_mae_loss = mae_loss[2] / (len(test_dataset) * 8)

    jiaxing_rmse_loss = torch.sqrt(rmse_loss[0] / (len(test_dataset) * 2))
    shanghai_rmse_loss = torch.sqrt(rmse_loss[1] / (len(test_dataset) * 10))
    suzhou_rmse_loss = torch.sqrt(rmse_loss[2] / (len(test_dataset) * 8))

    print('jiaxing_mae:', jiaxing_mae_loss)
    print('shanghai_mae:', shanghai_mae_loss)
    print('suzhou_mae:', suzhou_mae_loss)

    print('jiaxing_rmse:', jiaxing_rmse_loss)
    print('shanghai_rmse:', shanghai_rmse_loss)
    print('suzhou_rmse:', suzhou_rmse_loss)
--------------------------------------------------------------------------------
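
For reference, a minimal standalone version of the horizon-wise metrics that
cal_loss in train.py accumulates (MAE and RMSE at the 1 h / 3 h / 6 h / 12 h
horizons), assuming outputs and y are (batch, stations, 12) tensors:

import torch

def horizon_metrics(outputs, y, horizons=(1, 3, 6, 12)):
    # MAE and RMSE at selected forecast horizons (1-indexed hours).
    err = outputs - y
    mae = {h: err[:, :, h - 1].abs().mean().item() for h in horizons}
    rmse = {h: err[:, :, h - 1].pow(2).mean().sqrt().item() for h in horizons}
    return mae, rmse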