├── README.md
├── data
│   └── lipschitz_generator.py
├── freeze1.txt
├── freeze2.txt
├── logs
│   └── .empty
├── model
│   ├── __init__.py
│   ├── model.py
│   ├── tester.py
│   └── trainer.py
├── outputs
│   ├── models
│   │   └── .empty
│   ├── predictions
│   │   ├── .empty
│   │   └── metric_calculation.py
│   └── tensorboard
│       └── .empty
├── run.sh
├── run_test.sh
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── data_utils.py
    └── test_data_utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Frigate: Frugal Spatio-temporal Forecasting on Road Networks
2 | This repository is the official implementation of [**Frigate: Frugal Spatio-temporal Forecasting on Road Networks**](https://doi.org/10.1145/3580305.3599357)
3 | 
4 | ## Requirements
5 | This code has been tested under two configurations:
6 | **Config one**
7 | - Python: 3.9.0
8 | - PyTorch: 1.9.0 (CUDA 11.1)
9 | - PyTorch Geometric: 1.7.2
10 | - Numpy: 1.23.3
11 | - Pandas: 1.5.1
12 | - SciPy: 1.9.1
13 | - NetworkX: 2.8.7
14 | 
15 | **Config two**
16 | - Python: 3.9.0
17 | - PyTorch: 1.13.1
18 | - PyTorch Geometric: 2.2.0
19 | - Numpy: 1.23.5
20 | - Pandas: 1.5.2
21 | - SciPy: 1.9.3
22 | - NetworkX: 2.8.8
23 | 
24 | Other requirements: tensorboardX and tqdm are needed for logging and progress display.
25 | 
26 | A full list of packages from ```$ pip freeze``` for each of the two conda environments is provided in ```freeze1.txt``` and ```freeze2.txt``` to help in case of package clashes.
27 | 
28 | ## Data
29 | Download the [preprocessed dataset](https://drive.google.com/file/d/1l715iYVktwi8WFs_eOAvoVWS2pPzYiDJ/view?usp=share_link).
30 | Unzip the archive and move its contents into the ```data``` folder.
31 | 
32 | The expected file structure after this step is:
33 | ```bash
34 | Frigate
35 | ├── data
36 | │   ├── Beijing
37 | │   ├── Chengdu
38 | │   └── Harbin
39 | ├── logs
40 | ├── model
41 | │   ├── __init__.py
42 | │   ├── model.py
43 | │   ├── tester.py
44 | │   └── trainer.py
45 | ├── outputs
46 | │   ├── models
47 | │   ├── predictions
48 | │   └── tensorboard
49 | ├── run.sh
50 | ├── run_test.sh
51 | ├── test.py
52 | ├── train.py
53 | └── utils
54 |     ├── __init__.py
55 |     ├── data_utils.py
56 |     └── test_data_utils.py
57 | ```
58 | 
59 | ## Training
60 | The script ```run.sh``` is provided to facilitate training. Change the dataset's name on line 1 and
61 | the path to the seen nodes on line 17 for different configurations. A few seen-node ```.npy``` files are already provided in the dataset folders.
62 | 
63 | ```run.sh``` takes one argument that tells which GPU to run the training code on. For example, to run the training code on GPU 0,
64 | the command is
65 | 
66 | ```bash
67 | bash run.sh 0
68 | ```
69 | 
70 | ## Evaluation
71 | The script ```run_test.sh``` is provided to facilitate evaluation. You need to set four things in the file:
72 | 1. ```dataset```
73 | 2. ```seen_path```
74 | 3. ```run_num```
75 | 4. ```model_name```
76 | 
77 | The run number and model name are used to locate the trained model; both can be found in the logs. Note that the model name
78 | is just the model file's name, not the full path to it. The test script automatically loads the correct model based
79 | on the ```run_num``` and ```model_name``` parameters.
80 | To run the evaluation script on GPU 0, do the following:
81 | 
82 | ```bash
83 | bash run_test.sh 0
84 | ```
85 | 
86 | The script will display the MAE metric and will save the predictions in ```outputs/predictions/run_<run_num>/pred_true.npz```.
87 | A metric calculation script is also provided in ```outputs/predictions``` that takes a file in the format saved by this script and
88 | computes the metrics.
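For example, mirroring the usage shown in the script's own docstring, the metrics for training run 1 (replace ```run_1``` with your own run number) can be recomputed with:

```bash
cd outputs/predictions
python3 metric_calculation.py --pred_file "run_1/pred_true.npz"
```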
89 | 
90 | ## ACM Reference Format
91 | > Mridul Gupta, Hariprasad Kodamana, and Sayan Ranu. 2023. Frigate: Frugal Spatio-temporal Forecasting on Road Networks. _In Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD ’23), August 6–10, 2023, Long Beach, CA, USA_. ACM, New York, NY, USA, 12 pages. https://doi.org/10.1145/3580305.3599357
92 | 
93 | ### Bibtex
94 | ```tex
95 | @inproceedings{FrigateGNN,
96 |   author = {Gupta, Mridul and Kodamana, Hariprasad and Ranu, Sayan},
97 |   title = {Frigate: Frugal Spatio-temporal Forecasting on Road Networks},
98 |   booktitle = {Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '23)},
99 |   location = {Long Beach, CA, USA},
100 |   publisher = {ACM},
101 |   address = {New York, NY, USA},
102 |   numpages = {12},
103 |   url = {https://doi.org/10.1145/3580305.3599357},
104 |   year = {2023},
105 |   doi = {10.1145/3580305.3599357},
106 | }
107 | ```
108 | 
--------------------------------------------------------------------------------
/data/lipschitz_generator.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import numpy as np
4 | import networkx as nx
5 | from pathlib import Path
6 | import multiprocessing as mp
7 | from scipy.sparse import coo_array
8 | 
9 | 
10 | def sssplr(G, nodes):
11 |     result = {}
12 |     for n in nodes:
13 |         result[n] = nx.single_source_dijkstra_path_length(G, n, cutoff=None, weight='haversine')
14 |     return result
15 | 
16 | 
17 | def lipschitz_node_embeddings(G, nodes, k):
18 |     G_temp = G.reverse(copy=True)
19 |     anchor_nodes = np.random.choice(nodes, size=k, replace=False)
20 |     num_workers = 16 if k > 16 else k
21 |     results = []
22 |     per_worker = k/num_workers
23 |     pool = mp.Pool(processes=num_workers)
24 |     for n in range(num_workers):
25 |         start, end = int(per_worker*n), int(per_worker*(n+1))
26 |         results.append(
27 |             pool.apply_async(
28 |                 sssplr, args=[
29 |                     G_temp, anchor_nodes[start:end]]))
30 |     lips_dist_list = [result.get() for result in results]
31 |     pool.close()
32 |     pool.join()
33 |     lips_dist = {}
34 |     for d in lips_dist_list:
35 |         lips_dist.update(d)
36 |     embeddings = np.zeros((len(nodes), k))
37 |     for i, node_i in enumerate(anchor_nodes):
38 |         sd = lips_dist[node_i]
39 |         for j, node_j in enumerate(nodes):
40 |             dist = sd.get(node_j, -1)
41 |             if dist!=-1:
42 |                 embeddings[node_j, i] = 1/(dist+1)
43 |     embeddings = (embeddings - embeddings.mean(axis=0))/embeddings.std(axis=0)
44 |     return embeddings
45 | 
46 | 
47 | def main():
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument("--dataset", required=True, type=str,
50 |                         help="Path to the directory having adj_mx.pkl file, eg: Chengdu")
51 |     parser.add_argument('-k', type=int, default=16,
52 |                         help="Number of Lipschitz anchor nodes")
53 |     parser.add_argument('--output_filename', required=True, type=str,
54 |                         help="Output file name")
55 |     pargs = parser.parse_args()
56 |     # ---------------------------------------------------
57 |     adj_path = Path(pargs.dataset,"adj_mx.pkl")
58 |     with open(adj_path, "rb") as f:
59 |         adj_data = pickle.load(f)
60 |     adj = coo_array((adj_data['v'],adj_data['ij']),shape=adj_data['shape']).todense()
61 |     G = nx.DiGraph()
62 |     G.add_nodes_from(range(adj_data['shape'][0]))
63 |     for i,j in zip(adj_data['ij'][0],adj_data['ij'][1]):
64 |         G.add_edge(i, j, haversine=adj[i,j])
65 |     nodes = list(range(adj_data['shape'][0]))
66 |     embeddings = lipschitz_node_embeddings(G, nodes, pargs.k)
67 | 
np.savez_compressed(pargs.output_filename+'.npz', lipschitz=embeddings) 68 | 69 | 70 | if __name__=="__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /freeze1.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | aiohttp 3 | aiosignal 4 | anyio==3.6.2 5 | argon2-cffi==21.3.0 6 | argon2-cffi-bindings==21.2.0 7 | asttokens==2.2.0 8 | async-timeout 9 | attrs 10 | backcall==0.2.0 11 | beautifulsoup4==4.11.1 12 | bleach==5.0.1 13 | blinker==1.4 14 | brotlipy==0.7.0 15 | cachetools 16 | certifi 17 | cffi 18 | charset-normalizer==2.1.1 19 | click 20 | contourpy==1.0.6 21 | cryptography 22 | cycler==0.11.0 23 | debugpy==1.6.4 24 | decorator==5.1.1 25 | defusedxml==0.7.1 26 | docopt==0.6.2 27 | entrypoints==0.4 28 | executing==1.2.0 29 | fastjsonschema==2.16.2 30 | fonttools==4.38.0 31 | frozenlist 32 | google-auth 33 | google-auth-oauthlib 34 | googledrivedownloader==0.4 35 | grpcio 36 | h5py==3.7.0 37 | idna 38 | importlib-metadata 39 | ipykernel==6.17.1 40 | ipython==8.7.0 41 | ipython-genutils==0.2.0 42 | ipywidgets==8.0.2 43 | isodate==0.6.1 44 | jedi==0.18.2 45 | Jinja2==3.1.2 46 | joblib==1.2.0 47 | Js2Py==0.71 48 | jsonschema==4.17.3 49 | jupyter==1.0.0 50 | jupyter-console==6.4.4 51 | jupyter-server==1.23.3 52 | jupyter_client==7.4.7 53 | jupyter_core==5.1.0 54 | jupyterlab-pygments==0.2.2 55 | jupyterlab-widgets==3.0.3 56 | kiwisolver==1.4.4 57 | Markdown 58 | MarkupSafe==2.1.1 59 | matplotlib==3.6.2 60 | matplotlib-inline==0.1.6 61 | mistune==2.0.4 62 | mkl-fft==1.3.1 63 | mkl-random 64 | mkl-service==2.4.0 65 | multidict 66 | nbclassic==0.4.8 67 | nbclient==0.7.2 68 | nbconvert==7.2.5 69 | nbformat==5.7.0 70 | nest-asyncio==1.5.6 71 | networkx==2.8.7 72 | notebook==6.5.2 73 | notebook_shim==0.2.2 74 | numpy 75 | oauthlib 76 | packaging==21.3 77 | pandas==1.5.1 78 | pandocfilters==1.5.0 79 | parso==0.8.3 80 | pexpect==4.8.0 81 | pickleshare==0.7.5 82 | Pillow==9.3.0 83 | platformdirs==2.5.4 84 | prettytable==3.5.0 85 | prometheus-client==0.15.0 86 | prompt-toolkit==3.0.33 87 | protobuf==3.20.1 88 | psutil==5.9.4 89 | ptyprocess==0.7.0 90 | pure-eval==0.2.2 91 | pyaml==21.10.1 92 | pyasn1 93 | pyasn1-modules==0.2.8 94 | pycparser 95 | Pygments==2.13.0 96 | pyjsparser==2.7.1 97 | PyJWT 98 | pyOpenSSL 99 | pyparsing==3.0.9 100 | PyPrind==2.11.3 101 | pyrsistent==0.19.2 102 | pySmartDL==1.3.4 103 | PySocks 104 | python-dateutil==2.8.2 105 | python-louvain==0.16 106 | pytz==2022.5 107 | pytz-deprecation-shim==0.1.0.post0 108 | PyYAML==6.0 109 | pyzmq==24.0.1 110 | qtconsole==5.4.0 111 | QtPy==2.3.0 112 | rdflib==6.2.0 113 | requests 114 | requests-oauthlib==1.3.0 115 | rsa 116 | scikit-learn==1.1.3 117 | scipy==1.9.1 118 | Send2Trash==1.8.0 119 | six 120 | sniffio==1.3.0 121 | soupsieve==2.3.2.post1 122 | stack-data==0.6.2 123 | tensorboard 124 | tensorboard-data-server 125 | tensorboard-plugin-wit 126 | tensorboardX 127 | terminado==0.17.0 128 | threadpoolctl==3.1.0 129 | tinycss2==1.2.1 130 | torch==1.9.0+cu111 131 | torch-geometric==1.7.2 132 | torch-scatter==2.0.9 133 | torch-sparse==0.6.10 134 | torchaudio==0.9.0 135 | torchvision==0.10.0+cu111 136 | tornado==6.2 137 | tqdm==4.64.1 138 | traitlets==5.6.0 139 | typing_extensions==4.4.0 140 | tzdata==2022.6 141 | tzlocal==4.2 142 | urllib3 143 | wcwidth==0.2.5 144 | webencodings==0.5.1 145 | websocket-client==1.4.2 146 | Werkzeug 147 | widgetsnbextension==4.0.3 148 | yarl 149 | zipp 150 | 
-------------------------------------------------------------------------------- /freeze2.txt: -------------------------------------------------------------------------------- 1 | Bottleneck 2 | brotlipy==0.7.0 3 | certifi 4 | cffi 5 | charset-normalizer 6 | cryptography 7 | flit_core 8 | idna 9 | Jinja2 10 | joblib 11 | MarkupSafe 12 | mkl-fft==1.3.1 13 | mkl-random 14 | mkl-service==2.4.0 15 | numexpr 16 | numpy 17 | packaging 18 | pandas==1.5.2 19 | Pillow==9.3.0 20 | protobuf==3.20.1 21 | psutil 22 | pycparser 23 | pyOpenSSL 24 | pyparsing 25 | PySocks 26 | python-dateutil 27 | pytz 28 | requests 29 | scikit-learn 30 | scipy==1.9.3 31 | six 32 | tensorboardX 33 | threadpoolctl 34 | torch==1.13.1 35 | torch-cluster 36 | torch-geometric 37 | torch-scatter 38 | torch-sparse 39 | torchaudio==0.13.1 40 | torchvision==0.14.1 41 | tqdm 42 | typing_extensions 43 | urllib3 44 | -------------------------------------------------------------------------------- /logs/.empty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idea-iitd/Frigate/3fbdd9f911542b13565313db972cc7c10339334e/logs/.empty -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch_geometric.nn as gnn 6 | from torch_geometric.nn.conv import MessagePassing 7 | 8 | 9 | class FrigateConv(MessagePassing): 10 | def __init__(this, in_channels, out_channels): 11 | super(FrigateConv, this).__init__(aggr='add') 12 | this.lin = nn.Linear(in_channels, out_channels) 13 | this.lin_r = nn.Linear(in_channels, out_channels) 14 | this.lin_rout = nn.Linear(out_channels, out_channels) 15 | this.lin_ew = nn.Linear(1, 16) 16 | this.gate = nn.Sequential( 17 | nn.Linear(16 * 3, 3), 18 | nn.ReLU(), 19 | nn.Linear(3, 1), 20 | nn.Sigmoid(), 21 | ) 22 | #for p in this.lin_r.parameters(): 23 | # nn.init.constant_(p.data, 0.) 
24 | # p.requires_grad = False 25 | def forward(this, x, edge_index, edge_weight, lipschitz_embeddings): 26 | if isinstance(x, torch.Tensor): 27 | x_r = x 28 | x = this.lin(x) 29 | x = (x, x) 30 | else: 31 | x_r = this.lin_r(x[1]) 32 | x_rest = this.lin(x[0]) 33 | x = (x_rest, x_r) 34 | out = this.propagate(edge_index, x=x, edge_weight=edge_weight, lipschitz_embeddings=lipschitz_embeddings) 35 | #out += this.lin_rout(x_r) 36 | out = F.normalize(out, p=2., dim=-1) 37 | return out 38 | def message(this, x_j, edge_index_i, edge_index_j, edge_weight, lipschitz_embeddings): 39 | edge_weight_j = edge_weight.view(-1, 1) 40 | edge_weight_j = this.lin_ew(edge_weight_j) 41 | gating_input = torch.cat((edge_weight_j, lipschitz_embeddings[edge_index_i], 42 | lipschitz_embeddings[edge_index_j]), dim=1) 43 | gating = this.gate(gating_input) 44 | output = x_j * gating 45 | return output 46 | 47 | 48 | class GNN(nn.Module): 49 | def __init__(this, input_dim, hidden_dim, output_dim, nlayers): 50 | super().__init__() 51 | this.nlayers = nlayers 52 | this.gc = nn.ModuleList() 53 | this.gc.append(FrigateConv(input_dim, hidden_dim)) 54 | for _ in range(this.nlayers - 2): 55 | this.gc.append(FrigateConv(hidden_dim, hidden_dim)) 56 | this.gc.append(FrigateConv(hidden_dim, output_dim)) 57 | this.gs_sum = gnn.SAGEConv(1, 1, root_weight=False) 58 | this.reset_param(this.gs_sum) 59 | this.freezer(this.gs_sum) 60 | 61 | def reset_param(this, module): 62 | for n, p in module.named_parameters(): 63 | if n.endswith('bias'): 64 | v = 0. 65 | elif n.endswith('weight'): 66 | v = 1. 67 | torch.nn.init.constant_(p, v) 68 | 69 | def freezer(this, module): 70 | for p in module.parameters(): 71 | p.requires_grad = False 72 | 73 | def forward(this, xs, adjs, edge_weight, lipschitz, mu, std): 74 | last = len(adjs) - 1 75 | if mu is None or std is None: 76 | mu = 0. 77 | std = 1. 78 | x_org = xs.clone()[:, :, :, :1] * std + mu 79 | for i, (edge_index, e_id, size) in enumerate(adjs): 80 | xs_target = xs[:, :, :size[1], :] 81 | xs = this.gc[i]((xs, xs_target), edge_index, edge_weight[e_id], 82 | lipschitz_embeddings=lipschitz) 83 | xs = F.relu(xs) 84 | if i == last: 85 | x_org_selected = x_org[:, :, :size[0], :] 86 | x_org_targ = x_org[:, :, :size[1], :] 87 | nz_org = (x_org_selected != 0).to(torch.float32) 88 | nz_targ = (x_org_targ != 0).to(torch.float32) 89 | x_sum = this.gs_sum((x_org_selected, x_org_targ), edge_index) 90 | count = this.gs_sum((nz_org, nz_targ), edge_index) 91 | return xs, x_sum, count 92 | 93 | 94 | class Encoder(nn.Module): 95 | def __init__(this, input_dim, hidden_dim, enc_input_dim, enc_hidden_dim, nlayers): 96 | super().__init__() 97 | this.gnn = GNN(input_dim, hidden_dim, enc_input_dim//2, nlayers) 98 | this.rnn = nn.LSTM(enc_input_dim, enc_hidden_dim, batch_first=True, num_layers=2, dropout=0.1) 99 | this.h_0 = nn.Parameter(torch.randn(2, enc_hidden_dim)) 100 | this.c_0 = nn.Parameter(torch.randn(2, enc_hidden_dim)) 101 | this.reset_bias() 102 | def reset_bias(this): 103 | r"""https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745/3 104 | """ 105 | for names in this.rnn._all_weights: 106 | for name in filter(lambda n: "bias" in n, names): 107 | bias = getattr(this.rnn, name) 108 | n = bias.size(0) 109 | start, end = n//4, n//2 110 | bias.data[start:end].fill_(1.) 111 | def forward(this, xs, xrev, adjs, adjs_rev, edge_weight, lipschitz, mu, std): 112 | hs, xsum, count = this.gnn(xs, adjs, edge_weight, lipschitz, mu, std) 113 | mean = xsum / count 114 | mean[mean!=mean]=0. 
115 | hs_rev, rev_sum, rev_count = this.gnn(xrev, adjs_rev, edge_weight, lipschitz, mu, std) 116 | mean2 = rev_sum / rev_count 117 | mean2[mean2!=mean2]=0. 118 | mean = (mean + mean2) / 2 119 | hs = torch.cat((hs,hs_rev), dim=-1) 120 | batch_size = hs.size(0) 121 | n_nodes = hs.size(2) 122 | h_0 = this.h_0.repeat(batch_size, 1, 1).permute(1, 0, 2).contiguous() 123 | c_0 = this.c_0.repeat(batch_size, 1, 1).permute(1, 0, 2).contiguous() 124 | rnn_outputs = [] 125 | for n in range(n_nodes): 126 | out, (_, _) = this.rnn(hs[:, :, n, :], (h_0, c_0)) 127 | rnn_outputs.append(out[:,-1:,:]) 128 | rnn_outputs = torch.stack(rnn_outputs, dim=2) 129 | return rnn_outputs, mean 130 | 131 | 132 | class Decoder(nn.Module): 133 | def __init__(this, enc_hidden_dim, dec_hidden_dim, dec_output_dim): 134 | super().__init__() 135 | this.rnn = nn.LSTM(enc_hidden_dim + dec_output_dim, dec_hidden_dim, 136 | proj_size=dec_output_dim, batch_first=True, num_layers=2, dropout=0.1) 137 | this.h_0 = nn.Parameter(torch.randn(2, dec_output_dim)) 138 | this.c_0 = nn.Parameter(torch.randn(2, dec_hidden_dim)) 139 | this.s = nn.Parameter(torch.randn(1, dec_output_dim)) 140 | this.reset_bias() 141 | def reset_bias(this): 142 | r"""https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745/3 143 | """ 144 | for names in this.rnn._all_weights: 145 | for name in filter(lambda n: "bias" in n, names): 146 | bias = getattr(this.rnn, name) 147 | n = bias.size(0) 148 | start, end = n//4, n//2 149 | bias.data[start:end].fill_(1.) 150 | def forward(this, enc_outputs, ys=None, mean=0, TFRate=1., future=12): 151 | TFRate = 0. 152 | if ys is None: 153 | TFRate = 0. 154 | decoder_outputs = [] 155 | batch_size = enc_outputs.size(0) 156 | n_nodes = enc_outputs.size(2) 157 | last_output = this.s.repeat(batch_size, 1, 1) 158 | h_last = this.h_0.repeat(batch_size, 1, 1).permute(1, 0, 2).contiguous() 159 | c_last = this.c_0.repeat(batch_size, 1, 1).permute(1, 0, 2).contiguous() 160 | for n in range(n_nodes): 161 | enc_output = enc_outputs[:, :, n, :] 162 | y_slice = ys[:, :, n, :] if ys is not None else None 163 | decoder_outputs_per_node = [] 164 | node_mean = torch.mean(mean[:, :, n, :], dim=1, keepdim=True) if isinstance(mean, torch.Tensor) and len(mean.shape) == 4 else mean 165 | for t in range(future): 166 | if torch.rand(1) >= TFRate: 167 | loop = last_output + node_mean 168 | else: 169 | if t == 0: 170 | loop = last_output + node_mean 171 | else: 172 | loop = y_slice[:, t-1:t, :] 173 | input = torch.cat((enc_output, loop), dim=2) 174 | last_output, (h_last, c_last) = this.rnn(input, (h_last, c_last)) 175 | decoder_outputs_per_node.append(last_output+node_mean) 176 | decoder_outputs_per_node = torch.cat(decoder_outputs_per_node, dim=1) 177 | decoder_outputs.append(decoder_outputs_per_node) 178 | decoder_outputs = torch.stack(decoder_outputs, dim=2) 179 | return decoder_outputs 180 | 181 | 182 | class Frigate(nn.Module): 183 | def __init__(this, gnn_input_dim, gnn_hidden_dim, 184 | enc_input_dim, enc_hidden_dim, dec_hidden_dim, 185 | output_dim, nlayers): 186 | super().__init__() 187 | this.enc = Encoder(gnn_input_dim, gnn_hidden_dim, enc_input_dim, 188 | enc_hidden_dim, nlayers) 189 | this.dec = Decoder(enc_hidden_dim, dec_hidden_dim, output_dim) 190 | def forward(this, xs, xrev, adjs, adjs_rev, edge_weight, lipschitz, ys=None, TFRate=1., future=12, mu=None, std=None): 191 | enc_output, mean = this.enc(xs, xrev, adjs, adjs_rev, edge_weight, lipschitz, mu, std) 192 | dec_output = this.dec(enc_output, ys, mean, TFRate, future) 193 | 
return dec_output 194 | -------------------------------------------------------------------------------- /model/tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | from pathlib import Path 6 | from tqdm import tqdm, trange 7 | 8 | from .model import Frigate 9 | 10 | 11 | def model_test(model_args, 12 | device, 13 | dataloaders, 14 | edge_weight, 15 | run_num, 16 | model_name, 17 | n_nodes, 18 | ): 19 | model = Frigate(**model_args).to(device) 20 | model_path = Path("outputs","models",f"run_{run_num}",f"{model_name}") 21 | checkpoint = torch.load(model_path, map_location=device) 22 | model.load_state_dict(checkpoint) 23 | prediction_path = Path("outputs","predictions",f"run_{run_num}") 24 | prediction_path.mkdir(exist_ok=True) 25 | prediction_name = Path(prediction_path, "pred_true.npz") 26 | # 27 | test_loss = test_step(model, dataloaders['test_loader'], device, edge_weight, prediction_name, n_nodes) 28 | 29 | 30 | def test_step(model, 31 | test_loader, 32 | device, 33 | edge_weight, 34 | prediction_name, 35 | n_nodes, 36 | ): 37 | model.eval() 38 | dataloader = test_loader['dataloader'] 39 | nbrloader = test_loader['neighbor_loader'] 40 | rnbrloader = test_loader['rev_loader'] 41 | lipschitz = torch.tensor(dataloader.dataset.ls, dtype=torch.float32).to(device) 42 | mu = torch.tensor(dataloader.dataset.mu, dtype=torch.float32).to(device) 43 | sig = torch.tensor(dataloader.dataset.sig, dtype=torch.float32).to(device) 44 | accumulator = Accumulator() 45 | nb = len(dataloader) 46 | updates = nb // 10 if nb > 10 else 1 47 | loop = tqdm(enumerate(dataloader), total=nb, unit='batch', file=sys.stdout) 48 | loop.set_description("Testing") 49 | prediction = np.ones((32*nb,12,n_nodes,1)) * np.inf 50 | truths = np.ones((32*nb,12,n_nodes,1)) * np.inf 51 | with torch.no_grad(): 52 | for b, (xs, ys) in loop: 53 | xs, ys = xs.to(device), ys.to(device) 54 | xs, ys = xs.to(torch.float32), ys.to(torch.float32) 55 | bs = xs.shape[0] 56 | for n1, n2 in zip(nbrloader, rnbrloader): 57 | batch_size, n_ids, adjs = n1 58 | adjs = [adj.to(device) for adj in adjs] 59 | _, n_id_rev, adjs_rev = n2 60 | adjs_rev = [adj.to(device) for adj in adjs_rev] 61 | x_slice = xs[:, :, n_ids, :] 62 | x_rev = xs[:, :, n_id_rev, :] 63 | y_slice = ys[:, :, n_ids[:batch_size], :1] 64 | y_hat = model(x_slice, x_rev, adjs, adjs_rev, edge_weight, lipschitz, mu=mu, std=sig) 65 | accumulator(y_hat, y_slice) 66 | prediction[b*32:b*32+bs,:,n_ids[:batch_size],:] = y_hat.detach().cpu().numpy() 67 | truths[b*32:b*32+bs,:,n_ids[:batch_size],:] = y_slice.detach().cpu().numpy() 68 | if b % updates == 0: 69 | loop.set_postfix(loss=accumulator.get_score()) 70 | np.savez_compressed(prediction_name, predictions=prediction, truths=truths, ignore_val=np.inf) 71 | return accumulator.get_score() 72 | 73 | 74 | class Accumulator: 75 | def __init__(this): 76 | this.score = 0 77 | this.n = 0 78 | def __call__(this, y_pred, y_true): 79 | if isinstance(y_pred, torch.Tensor): 80 | y_pred = y_pred.detach().cpu().numpy().reshape(-1) 81 | if isinstance(y_true, torch.Tensor): 82 | y_true = y_true.detach().cpu().numpy().reshape(-1) 83 | n = y_true.shape[0] 84 | this.score *= this.n / (this.n + n) 85 | this.n += n 86 | this.score += np.sum(np.absolute(y_true - y_pred)) / this.n 87 | def get_score(this): 88 | return this.score 89 | -------------------------------------------------------------------------------- /model/trainer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | from pathlib import Path 6 | import torch.optim as optim 7 | from tqdm import tqdm, trange 8 | from datetime import datetime 9 | from tensorboardX import SummaryWriter 10 | 11 | from .model import Frigate 12 | 13 | 14 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, 15 | num_training_steps, last_epoch=-1): 16 | r""" 17 | https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/optimization.py#L75 18 | """ 19 | def lr_lambda(current_step: int): 20 | if current_step < num_warmup_steps: 21 | return float(current_step) / float(max(1, num_warmup_steps)) 22 | return max( 23 | 0.0, float(num_training_steps - current_step) / float(max(1, 24 | num_training_steps - num_warmup_steps))) 25 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 26 | 27 | 28 | def get_tfrate_calculator(nepochs): 29 | def TFRate_calculator(epoch): 30 | a = -1.5 31 | b = -a - 1.3 32 | c = 1 33 | x = epoch / nepochs 34 | return max(min(a*x**x+b*x+c, 1), 0) 35 | return TFRate_calculator 36 | 37 | 38 | def model_train(model_args, 39 | device, 40 | nepochs, 41 | dataloaders, 42 | edge_weight, 43 | loss_fn, 44 | run_num, 45 | logger, 46 | log_dir='outputs/tensorboard', 47 | ): 48 | model = Frigate(**model_args).to(device) 49 | opt = optim.AdamW(model.parameters(), lr=5e-4) 50 | num_training_steps = nepochs * len(dataloaders['train_loader']['dataloader']) * len(dataloaders['train_loader']['neighbor_loader']) 51 | lr_scheduler = get_linear_schedule_with_warmup(opt, 0, num_training_steps) 52 | best_val_loss = float('inf') 53 | patience = 5 54 | exhausted = 0 55 | writer = SummaryWriter(logdir=os.path.join(log_dir, f'run_{run_num}')) 56 | model_path = Path("outputs","models",f"run_{run_num}") 57 | model_path.mkdir(exist_ok=True) 58 | model_name = Path(model_path, f'model_best_loss_{datetime.now().strftime("%d%B%Y_%H_%M_%S")}.pt') 59 | TFRate_calculator = get_tfrate_calculator(nepochs) 60 | # 61 | for e in range(1, nepochs + 1): 62 | logger.info(f'epoch {e}') 63 | TFRate = TFRate_calculator(e - 1) 64 | train_loss = train_step(model, opt, lr_scheduler, dataloaders['train_loader'], device, logger, loss_fn, e, edge_weight, TFRate) 65 | val_loss = val_step(model, dataloaders['val_loader'], device, logger, e, edge_weight) 66 | writer.add_scalar('loss/train', train_loss, e) 67 | writer.add_scalar('loss/val', val_loss, e) 68 | if best_val_loss > val_loss: 69 | best_val_loss = val_loss 70 | exhausted = 0 71 | save_model(model, model_name) 72 | logger.info(f"Model saved at epoch {e} to {model_name} with loss {best_val_loss}") 73 | else: 74 | exhausted += 1 75 | if exhausted >= patience: 76 | logger.info(f"Early stopping at epoch: {e}") 77 | break 78 | 79 | 80 | def train_step(model, 81 | opt, 82 | lr_scheduler, 83 | train_loader, 84 | device, 85 | logger, 86 | loss_fn, 87 | epoch, 88 | edge_weight, 89 | TFRate, 90 | ): 91 | model.train() 92 | accumulator = Accumulator() 93 | dataloader = train_loader['dataloader'] 94 | nbrloader = train_loader['neighbor_loader'] 95 | rnbrloader = train_loader['rev_loader'] 96 | lipschitz = torch.tensor(dataloader.dataset.ls, dtype=torch.float32).to(device) 97 | mu = torch.tensor(dataloader.dataset.mu, dtype=torch.float32).to(device) 98 | sig = torch.tensor(dataloader.dataset.sig, dtype=torch.float32).to(device) 99 | nb = len(dataloader) 100 | updates = nb // 10 if nb > 10 else 1 101 | loop = 
tqdm(enumerate(dataloader), total=nb, unit='batch', file=sys.stdout) 102 | loop.set_description(f"Training epoch: {epoch}") 103 | for b, (xs, ys) in loop: 104 | xs, ys = xs.to(device), ys.to(device) 105 | xs, ys = xs.to(torch.float32), ys.to(torch.float32) 106 | for nf, nb in zip(nbrloader, rnbrloader): 107 | batch_size, n_ids, adjs = nf 108 | _, n_id_rev, adjs_rev = nb 109 | adjs = [adj.to(device) for adj in adjs] 110 | adjs_rev = [adj.to(device) for adj in adjs_rev] 111 | x_slice = xs[:, :, n_ids, :] 112 | x_slice_rev = xs[:, :, n_id_rev, :] 113 | y_slice = ys[:, :, n_ids[:batch_size], :1] 114 | y_hat = model(x_slice, x_slice_rev, adjs, adjs_rev, edge_weight, lipschitz, y_slice, TFRate=TFRate, mu=mu, std=sig) 115 | loss = loss_fn(y_hat, y_slice) 116 | accumulator(y_hat, y_slice) 117 | opt.zero_grad() 118 | loss.backward() 119 | opt.step() 120 | lr_scheduler.step() 121 | if b % updates == 0: 122 | loop.set_postfix(loss=accumulator.get_score()) 123 | return accumulator.get_score() 124 | 125 | 126 | def val_step(model, 127 | val_loader, 128 | device, 129 | logger, 130 | epoch, 131 | edge_weight, 132 | ): 133 | model.eval() 134 | dataloader = val_loader['dataloader'] 135 | nbrloader = val_loader['neighbor_loader'] 136 | rnbrloader = val_loader['rev_loader'] 137 | lipschitz = torch.tensor(dataloader.dataset.ls, dtype=torch.float32).to(device) 138 | mu = torch.tensor(dataloader.dataset.mu, dtype=torch.float32).to(device) 139 | sig = torch.tensor(dataloader.dataset.sig, dtype=torch.float32).to(device) 140 | accumulator = Accumulator() 141 | nb = len(dataloader) 142 | updates = 1#nb // 10 if nb > 10 else 1 143 | loop = tqdm(enumerate(dataloader), total=nb, unit='batch', file=sys.stdout) 144 | loop.set_description(f"Valid epoch: {epoch}") 145 | with torch.no_grad(): 146 | for b, (xs, ys) in loop: 147 | xs, ys = xs.to(device), ys.to(device) 148 | xs, ys = xs.to(torch.float32), ys.to(torch.float32) 149 | for nf, nb in zip(nbrloader, rnbrloader): 150 | batch_size, n_ids, adjs = nf 151 | _, n_id_rev, adjs_rev = nb 152 | adjs = [adj.to(device) for adj in adjs] 153 | adjs_rev = [adj.to(device) for adj in adjs_rev] 154 | x_slice = xs[:, :, n_ids, :] 155 | x_slice_rev = xs[:, :, n_id_rev, :] 156 | y_slice = ys[:, :, n_ids[:batch_size], :1] 157 | y_hat = model(x_slice, x_slice_rev, adjs, adjs_rev, edge_weight, lipschitz, mu=mu, std=sig) 158 | accumulator(y_hat, y_slice) 159 | if b % updates == 0: 160 | loop.set_postfix(loss=accumulator.get_score()) 161 | return accumulator.get_score() 162 | 163 | 164 | class Accumulator: 165 | def __init__(this): 166 | this.score = 0 167 | this.n = 0 168 | def __call__(this, y_pred, y_true): 169 | if isinstance(y_pred, torch.Tensor): 170 | y_pred = y_pred.detach().cpu().numpy().reshape(-1) 171 | if isinstance(y_true, torch.Tensor): 172 | y_true = y_true.detach().cpu().numpy().reshape(-1) 173 | n = y_true.shape[0] 174 | this.score *= this.n / (this.n + n) 175 | this.n += n 176 | this.score += np.sum(np.absolute(y_true - y_pred)) / this.n 177 | def get_score(this): 178 | return this.score 179 | 180 | 181 | def save_model(model, model_file): 182 | torch.save(model.state_dict(), model_file) 183 | -------------------------------------------------------------------------------- /outputs/models/.empty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idea-iitd/Frigate/3fbdd9f911542b13565313db972cc7c10339334e/outputs/models/.empty -------------------------------------------------------------------------------- 
/outputs/predictions/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idea-iitd/Frigate/3fbdd9f911542b13565313db972cc7c10339334e/outputs/predictions/.empty
--------------------------------------------------------------------------------
/outputs/predictions/metric_calculation.py:
--------------------------------------------------------------------------------
1 | r"""This script is an interface between the predictions saved by Frigate
2 | and metrics calculation.
3 | The predictions file format:
4 | There should be a file named pred_true.npz in run_<run_num>.
5 | The npz file has 3 keys: 'truths', 'predictions', 'ignore_val'
6 | 
7 | The shape of 'truths' and 'predictions' is the same and is equal to
8 | (batch_size, Delta, n_nodes, 1). Not all columns from 0 to n_nodes-1
9 | have valid entries. Only the nodes that weren't seen in training
10 | have predictions. The columns corresponding to seen nodes contain
11 | 'ignore_val' which is np.inf for now.
12 | 
13 | This file takes one argument: --pred_file
14 | Ex: if you want to calculate the metrics for run_1/pred_true.npz then run
15 | $ python3 metric_calculation.py --pred_file "run_1/pred_true.npz"
16 | """
17 | import argparse
18 | import numpy as np
19 | 
20 | 
21 | def main():
22 |     parser = argparse.ArgumentParser()
23 |     parser.add_argument('--pred_file', required=True, type=str)
24 |     pargs = parser.parse_args()
25 |     pred_dict = np.load(pargs.pred_file)
26 |     truths = pred_dict['truths']
27 |     predictions = pred_dict['predictions']
28 |     ignore_val = pred_dict['ignore_val']
29 |     nodes, = np.where(truths[0, 0, :, 0] != ignore_val)
30 |     batches, = np.where(truths[:, 0, nodes[0], 0] != ignore_val)
31 |     print(f"Calculating MAE of predictions on {nodes}")
32 |     truths = truths[batches, :, :, 0][:, :, nodes]
33 |     predictions = predictions[batches, :, :, 0][:, :, nodes]
34 |     MAE = np.mean(np.abs(truths - predictions))
35 |     print(f"MAE = {MAE}")
36 | 
37 | 
38 | if __name__=="__main__":
39 |     main()
40 | 
--------------------------------------------------------------------------------
/outputs/tensorboard/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/idea-iitd/Frigate/3fbdd9f911542b13565313db972cc7c10339334e/outputs/tensorboard/.empty
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | dataset='Harbin'
2 | future=12
3 | past=12
4 | nepochs=50
5 | nlayers=10
6 | gnn_input_dim=18
7 | gnn_hidden_dim=64
8 | enc_input_dim=64
9 | enc_hidden_dim=64
10 | dec_hidden_dim=64
11 | output_dim=1
12 | gpu="${1}"
13 | #
14 | traffic_path="data/${dataset}/traffic_df.pkl.gz"
15 | lipschitz_path="data/${dataset}/lipschitz.npz"
16 | adj_path="data/${dataset}/adj_mx.pkl"
17 | seen_path="data/${dataset}/seen_30.npy"
18 | 
19 | # Now running in theatres!
20 | if echo $* | grep -e "--debug" -q 21 | then 22 | CUDA_VISIBLE_DEVICES=$gpu python3 -m pdb train.py --traffic_path "${traffic_path}" --lipschitz_path "${lipschitz_path}" --adj_path "${adj_path}" --seen_path "${seen_path}" --keep_tod --future "${future}" --past "${past}" --nepochs "${nepochs}" --nlayers "${nlayers}" --gnn_input_dim "${gnn_input_dim}" --gnn_hidden_dim "${gnn_hidden_dim}" --enc_input_dim "${enc_input_dim}" --enc_hidden_dim "${enc_hidden_dim}" --dec_hidden_dim "${dec_hidden_dim}" --output_dim "${output_dim}" 23 | # debug 24 | else 25 | CUDA_VISIBLE_DEVICES=$gpu python3 train.py --traffic_path "${traffic_path}" --lipschitz_path "${lipschitz_path}" --adj_path "${adj_path}" --seen_path "${seen_path}" --keep_tod --future "${future}" --past "${past}" --nepochs "${nepochs}" --nlayers "${nlayers}" --gnn_input_dim "${gnn_input_dim}" --gnn_hidden_dim "${gnn_hidden_dim}" --enc_input_dim "${enc_input_dim}" --enc_hidden_dim "${enc_hidden_dim}" --dec_hidden_dim "${dec_hidden_dim}" --output_dim "${output_dim}" 26 | fi 27 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | dataset='Harbin' 2 | future=12 3 | past=12 4 | nlayers=10 5 | gnn_input_dim=18 6 | gnn_hidden_dim=64 7 | enc_input_dim=64 8 | enc_hidden_dim=64 9 | dec_hidden_dim=64 10 | output_dim=1 11 | run_num=1 12 | model_name='model_best_loss_01February2023_03_49_14.pt' 13 | gpu="${1}" 14 | # 15 | traffic_path="data/${dataset}/traffic_df.pkl.gz" 16 | lipschitz_path="data/${dataset}/lipschitz.npz" 17 | adj_path="data/${dataset}/adj_mx.pkl" 18 | seen_path="data/${dataset}/seen_30.npy" 19 | 20 | # Now running in theatres! 21 | if echo $* | grep -e "--debug" -q 22 | then 23 | CUDA_VISIBLE_DEVICES=$gpu python3 -m pdb test.py --traffic_path "${traffic_path}" --lipschitz_path "${lipschitz_path}" --adj_path "${adj_path}" --seen_path "${seen_path}" --keep_tod --future "${future}" --past "${past}" --nlayers "${nlayers}" --gnn_input_dim "${gnn_input_dim}" --gnn_hidden_dim "${gnn_hidden_dim}" --enc_input_dim "${enc_input_dim}" --enc_hidden_dim "${enc_hidden_dim}" --dec_hidden_dim "${dec_hidden_dim}" --output_dim "${output_dim}" --model_name "${model_name}" --run_num "${run_num}" 24 | # debug 25 | else 26 | CUDA_VISIBLE_DEVICES=$gpu python3 test.py --traffic_path "${traffic_path}" --lipschitz_path "${lipschitz_path}" --adj_path "${adj_path}" --seen_path "${seen_path}" --keep_tod --future "${future}" --past "${past}" --nlayers "${nlayers}" --gnn_input_dim "${gnn_input_dim}" --gnn_hidden_dim "${gnn_hidden_dim}" --enc_input_dim "${enc_input_dim}" --enc_hidden_dim "${enc_hidden_dim}" --dec_hidden_dim "${dec_hidden_dim}" --output_dim "${output_dim}" --model_name "${model_name}" --run_num "${run_num}" 27 | fi 28 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from io import StringIO 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from utils.test_data_utils import get_dataloader_and_adj_mx 8 | from model.tester import model_test 9 | 10 | 11 | def main(): 12 | # --------------------------------------------------- 13 | parser = argparse.ArgumentParser(description='Test the model') 14 | parser.add_argument('--traffic_path', type=str, required=True, 15 | help='path to traffic data (pkl gz format)') 16 | parser.add_argument('--lipschitz_path', 
type=str, required=True, 17 | help='path to lipschitz data (npz)') 18 | parser.add_argument('--adj_path', type=str, required=True, 19 | help='path to adjacency data (pickle)') 20 | parser.add_argument('--seen_path', type=str, required=True, 21 | help='path to seen nodes index (npy)') 22 | parser.add_argument('--keep_tod', default=False, action='store_true', 23 | help='whether to keep time of day (boolean flag)') 24 | parser.add_argument('--future', type=int, default=12, 25 | help='how far in the future to predict') 26 | parser.add_argument('--past', type=int, default=12, 27 | help='how far in the past to look') 28 | parser.add_argument('--nlayers', type=int, default=10, 29 | help='number of layers used in the GNN') 30 | parser.add_argument('--gnn_input_dim', type=int, required=True, 31 | help='number of input dimensions taken by gnn') 32 | parser.add_argument('--gnn_hidden_dim', type=int, required=True, 33 | help='number of hidden dimensions of gnn') 34 | parser.add_argument('--enc_input_dim', type=int, required=True, 35 | help='number of input dimensions taken by lstm\'s encoder') 36 | parser.add_argument('--enc_hidden_dim', type=int, required=True, 37 | help='number of hidden dimensions of lstm encoder') 38 | parser.add_argument('--dec_hidden_dim', type=int, required=True, 39 | help='number of hidden dimensions of lstm decoder') 40 | parser.add_argument('--output_dim', type=int, required=True, 41 | help='number of output dimensions') 42 | parser.add_argument('--model_name', type=str, required=True, 43 | help='trained model\'s name that corresponds to given hyperparams') 44 | parser.add_argument('--run_num', type=int, required=True, 45 | help='used to find path to model, the run num of training') 46 | parser.print_usage = parser.print_help 47 | pargs = parser.parse_args() 48 | # --------------------------------------------------- 49 | model_args = { 50 | 'gnn_input_dim':pargs.gnn_input_dim, 51 | 'gnn_hidden_dim':pargs.gnn_hidden_dim, 52 | 'enc_input_dim':pargs.enc_input_dim, 53 | 'enc_hidden_dim':pargs.enc_hidden_dim, 54 | 'dec_hidden_dim':pargs.dec_hidden_dim, 55 | 'output_dim':pargs.output_dim, 56 | 'nlayers':pargs.nlayers, 57 | } 58 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 59 | dataloaders, edge_weight, n_nodes = get_dataloader_and_adj_mx( 60 | pargs.traffic_path, 61 | pargs.lipschitz_path, 62 | pargs.adj_path, 63 | pargs.seen_path, 64 | keep_tod=pargs.keep_tod, 65 | f=pargs.future, 66 | p=pargs.past, 67 | nlayers=pargs.nlayers, 68 | ) 69 | with StringIO() as s: 70 | pprint(vars(pargs), stream=s, indent=4) 71 | print(s.getvalue()) 72 | edge_weight = edge_weight.to(device).to(torch.float32) 73 | model_test(model_args, 74 | device, 75 | dataloaders, 76 | edge_weight, 77 | pargs.run_num, 78 | pargs.model_name, 79 | n_nodes 80 | ) 81 | 82 | 83 | if __name__=="__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | import torch 5 | import logging 6 | import argparse 7 | from io import StringIO 8 | from pathlib import Path 9 | from pprint import pprint 10 | 11 | from utils.data_utils import get_dataloader_and_adj_mx 12 | from model.trainer import model_train 13 | 14 | 15 | def masked_mae_loss(y_pred, y_true): 16 | mask = (y_true != 0).float() 17 | mask /= mask.mean() 18 | loss = torch.abs(y_pred - y_true) 19 | loss[loss != loss] = 0. 
20 | return loss.mean() 21 | 22 | 23 | def get_run_num(log_dir="outputs/tensorboard"): 24 | p = re.compile(r'run_\d+$') 25 | files = (int(file.split('_')[1]) for file in os.listdir(log_dir) if p.match(file)) 26 | run_num = 1 27 | try: 28 | run_num = 1 + max(files) 29 | except ValueError as e: 30 | pass 31 | return run_num 32 | 33 | 34 | def config_logging(run_num, log_dir='logs'): 35 | path = Path(log_dir, f'run_{run_num}') 36 | path.mkdir(exist_ok=True) 37 | logger = logging.getLogger() 38 | file = Path(path, 'log.txt') 39 | fh = logging.FileHandler(file) 40 | fh.setLevel(logging.INFO) 41 | logger.setLevel(logging.INFO) 42 | logger.addHandler(fh) 43 | sys.stderr = open(Path(log_dir, f'run_{run_num}', 'stderr.txt'), 'w') 44 | return logger 45 | 46 | 47 | def main(): 48 | # ---------------- parser setup ------------------- 49 | parser = argparse.ArgumentParser(description='Train the model') 50 | parser.add_argument('--traffic_path', type=str, required=True, 51 | help='path to traffic data (pkl gz format)') 52 | parser.add_argument('--lipschitz_path', type=str, required=True, 53 | help='path to lipschitz data (npz)') 54 | parser.add_argument('--adj_path', type=str, required=True, 55 | help='path to adjacency data (pickle)') 56 | parser.add_argument('--seen_path', type=str, required=True, 57 | help='path to seen nodes index (npy)') 58 | parser.add_argument('--keep_tod', default=False, action='store_true', 59 | help='whether to keep time of day (boolean flag)') 60 | parser.add_argument('--future', type=int, default=12, 61 | help='how far in the future to predict') 62 | parser.add_argument('--past', type=int, default=12, 63 | help='how far in the past to look') 64 | parser.add_argument('--nepochs', type=int, required=True, 65 | help='number of epochs') 66 | parser.add_argument('--nlayers', type=int, default=10, 67 | help='number of layers used in the GNN') 68 | parser.add_argument('--gnn_input_dim', type=int, required=True, 69 | help='number of input dimensions taken by gnn') 70 | parser.add_argument('--gnn_hidden_dim', type=int, required=True, 71 | help='number of hidden dimensions of gnn') 72 | parser.add_argument('--enc_input_dim', type=int, required=True, 73 | help='number of input dimensions taken by lstm\'s encoder') 74 | parser.add_argument('--enc_hidden_dim', type=int, required=True, 75 | help='number of hidden dimensions of lstm encoder') 76 | parser.add_argument('--dec_hidden_dim', type=int, required=True, 77 | help='number of hidden dimensions of lstm decoder') 78 | parser.add_argument('--output_dim', type=int, required=True, 79 | help='number of output dimensions') 80 | parser.print_usage = parser.print_help 81 | pargs = parser.parse_args() 82 | # --------------------------------------------------- 83 | model_args = { 84 | 'gnn_input_dim':pargs.gnn_input_dim, 85 | 'gnn_hidden_dim':pargs.gnn_hidden_dim, 86 | 'enc_input_dim':pargs.enc_input_dim, 87 | 'enc_hidden_dim':pargs.enc_hidden_dim, 88 | 'dec_hidden_dim':pargs.dec_hidden_dim, 89 | 'output_dim':pargs.output_dim, 90 | 'nlayers':pargs.nlayers, 91 | } 92 | run_num = get_run_num() 93 | logger = config_logging(run_num) 94 | print(f"This is run number: {run_num}\n Logs will be saved in logs/run_{run_num}") 95 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 96 | dataloaders, edge_weight = get_dataloader_and_adj_mx( 97 | pargs.traffic_path, 98 | pargs.lipschitz_path, 99 | pargs.adj_path, 100 | pargs.seen_path, 101 | keep_tod=pargs.keep_tod, 102 | f=pargs.future, 103 | p=pargs.past, 104 | nlayers=pargs.nlayers, 
105 | ) 106 | with StringIO() as s: 107 | pprint(vars(pargs), stream=s, indent=4) 108 | logger.info(s.getvalue()) 109 | edge_weight = edge_weight.to(device).to(torch.float32) 110 | model_train(model_args, device, pargs.nepochs, dataloaders, edge_weight, 111 | #masked_mae_loss, run_num, logger) # doesn't help 112 | torch.nn.L1Loss(), run_num, logger) 113 | 114 | 115 | if __name__=="__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idea-iitd/Frigate/3fbdd9f911542b13565313db972cc7c10339334e/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | from scipy.sparse import coo_array 6 | 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | from torch_geometric.utils import dense_to_sparse 10 | from torch_geometric.data import NeighborSampler 11 | 12 | 13 | class TrafficDataset(Dataset): 14 | def __init__(this, xs, ys, ls, mu, sig, unseen_node_ids, val=False): 15 | this.xs = xs 16 | this.ys = ys 17 | this.ls = ls # lipschitz 18 | this.mu = mu 19 | this.sig = sig 20 | this.unseen_node_ids = unseen_node_ids 21 | this.val = val 22 | def __len__(this): 23 | return len(this.xs) 24 | def __getitem__(this, idx): 25 | x = (this.xs[idx] - this.mu) / this.sig 26 | x[:, this.unseen_node_ids, :] = -this.mu / this.sig 27 | y = this.ys[idx] 28 | if not this.val: 29 | y[:, this.unseen_node_ids, :] = -this.mu / this.sig 30 | l = np.tile(this.ls.reshape(1, this.ls.shape[0], -1), 31 | (x.shape[0], 1, 1)) 32 | x = np.concatenate((x, l), axis=2) 33 | return x, y 34 | 35 | 36 | def get_xy(data, f, p): 37 | x_offsets = np.arange(-(p - 1), 1) 38 | y_offsets = np.arange(1, f + 1) 39 | # tmin - (p - 1) = 0 => tmin = p - 1 40 | # tmax + f = L - 1 => tmax = L - f - 1 41 | tmin = p - 1 42 | tmax = len(data) - f - 1 43 | xs, ys = [], [] 44 | for t in range(tmin, tmax + 1): # tmax inclusive range 45 | xs.append(data[t + x_offsets, :, :]) 46 | ys.append(data[t + y_offsets, :, :]) 47 | xs, ys = list(map(np.stack, [xs, ys])) 48 | return xs, ys 49 | 50 | 51 | def get_dataloader(traffic_path, lipschitz_path, keep_tod, f, p, unseen_node_ids): 52 | traffic_data_df = pd.read_pickle(traffic_path) 53 | traffic_data = traffic_data_df.values 54 | mu, sig = np.mean(traffic_data), np.std(traffic_data) 55 | if keep_tod: 56 | index = traffic_data_df.index.values 57 | index = ((index.astype('datetime64[ns]') - index.astype('datetime64[D]'))/ 58 | np.timedelta64(1,'D')) 59 | tod = np.tile(index, (traffic_data.shape[1], 1)).T 60 | traffic_data = np.transpose(np.stack((traffic_data, tod)), (1, 2, 0)) 61 | else: 62 | traffic_data = traffic_data.reshape(traffic_data.shape[0], -1, 1) 63 | cut_point1 = int(0.7 * len(traffic_data)) 64 | cut_point2 = int(0.9 * len(traffic_data)) 65 | train_data = traffic_data[:cut_point1, :, :] 66 | val_data = traffic_data[cut_point1:cut_point2, :, :] 67 | xys = list(map(lambda arg:get_xy(arg, f, p), [train_data, val_data])) 68 | ls = np.load(lipschitz_path)['lipschitz'] 69 | train_datasets = TrafficDataset(*xys[0], ls, mu, sig, unseen_node_ids) 70 | val_datasets = TrafficDataset(*xys[1], ls, mu, sig, unseen_node_ids, val=True) 71 | train_loader = DataLoader(train_datasets, 
batch_size=32, shuffle=True) 72 | val_loader = DataLoader(val_datasets, batch_size=32, shuffle=False) 73 | return train_loader, val_loader 74 | 75 | 76 | def get_adjacency(adj_path): 77 | with open(adj_path, 'rb') as pkl: 78 | sparse_adj_data = pickle.load(pkl) 79 | v = sparse_adj_data['v'] 80 | ij = sparse_adj_data['ij'] 81 | shape = sparse_adj_data['shape'] 82 | adj_mx = torch.tensor(coo_array((v, ij), shape=shape).todense()) 83 | edge_index, edge_weight = dense_to_sparse(adj_mx) 84 | return edge_index, edge_weight, sparse_adj_data['shape'] 85 | 86 | 87 | def get_nbrloader(edge_index, node_ids, nlayers): 88 | node_ids = torch.tensor(node_ids, dtype=torch.long) 89 | return NeighborSampler(edge_index, node_idx=node_ids, batch_size=32, 90 | sizes=[-1 for _ in range(nlayers)]) 91 | 92 | 93 | def get_dataloader_and_adj_mx(traffic_path, lipschitz_path, adj_path, seen_path, 94 | *, keep_tod=True, f=12, p=12, nlayers=10): 95 | seen_node_ids = np.load(seen_path) 96 | edge_index, edge_weight, shape = get_adjacency(adj_path) 97 | unseen_node_ids = np.setdiff1d(np.arange(shape[0]), seen_node_ids) 98 | rev_index = torch.flip(edge_index, dims=[0]) 99 | 100 | dataloaders = get_dataloader(traffic_path, lipschitz_path, keep_tod, f, p, 101 | unseen_node_ids) 102 | train_nbrloader = get_nbrloader(edge_index, seen_node_ids, nlayers) 103 | rev_train_nbrloader = get_nbrloader(rev_index, seen_node_ids, nlayers) 104 | val_nbrloader = get_nbrloader(edge_index, unseen_node_ids, nlayers) 105 | rev_val_nbrloader = get_nbrloader(rev_index, unseen_node_ids, nlayers) 106 | dataloaders_ = {'train_loader': { 107 | 'dataloader': dataloaders[0], 108 | 'neighbor_loader': train_nbrloader, 109 | 'rev_loader': rev_train_nbrloader, 110 | }, 111 | 'val_loader': { 112 | 'dataloader': dataloaders[1], 113 | 'neighbor_loader': val_nbrloader, 114 | 'rev_loader': rev_val_nbrloader, 115 | }, 116 | } 117 | return dataloaders_, edge_weight 118 | -------------------------------------------------------------------------------- /utils/test_data_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import pandas as pd 4 | from scipy.sparse import coo_array 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch_geometric.utils import dense_to_sparse 9 | from torch_geometric.data import NeighborSampler 10 | 11 | 12 | class TrafficDataset(Dataset): 13 | def __init__(this, xs, ys, ls, mu, sig, unseen_node_ids): 14 | this.xs = xs 15 | this.ys = ys 16 | this.ls = ls # lipschitz 17 | this.mu = mu 18 | this.sig = sig 19 | this.unseen_node_ids = unseen_node_ids 20 | def __len__(this): 21 | return len(this.xs) 22 | def __getitem__(this, idx): 23 | x = (this.xs[idx] - this.mu) / this.sig 24 | x[:, this.unseen_node_ids, :] = -this.mu / this.sig 25 | y = this.ys[idx] 26 | l = np.tile(this.ls.reshape(1, this.ls.shape[0], -1), 27 | (x.shape[0], 1, 1)) 28 | x = np.concatenate((x, l), axis=2) 29 | return x, y 30 | 31 | 32 | def get_xy(data, f, p): 33 | x_offsets = np.arange(-(p - 1), 1) 34 | y_offsets = np.arange(1, f + 1) 35 | # tmin - (p - 1) = 0 => tmin = p - 1 36 | # tmax + f = L - 1 => tmax = L - f - 1 37 | tmin = p - 1 38 | tmax = len(data) - f - 1 39 | xs, ys = [], [] 40 | for t in range(tmin, tmax + 1): # tmax inclusive range 41 | xs.append(data[t + x_offsets, :, :]) 42 | ys.append(data[t + y_offsets, :, :]) 43 | xs, ys = list(map(np.stack, [xs, ys])) 44 | return xs, ys 45 | 46 | 47 | def get_dataloader(traffic_path, lipschitz_path, 
keep_tod, f, p, unseen_node_ids): 48 | traffic_data_df = pd.read_pickle(traffic_path) 49 | traffic_data = traffic_data_df.values 50 | mu, sig = np.mean(traffic_data), np.std(traffic_data) 51 | if keep_tod: 52 | index = traffic_data_df.index.values 53 | index = ((index.astype('datetime64[ns]') - index.astype('datetime64[D]'))/ 54 | np.timedelta64(1,'D')) 55 | tod = np.tile(index, (traffic_data.shape[1], 1)).T 56 | traffic_data = np.transpose(np.stack((traffic_data, tod)), (1, 2, 0)) 57 | else: 58 | traffic_data = traffic_data.reshape(traffic_data.shape[0], -1, 1) 59 | cut_point2 = int(0.9 * len(traffic_data)) 60 | test_data = traffic_data[cut_point2:, :, :] 61 | xy = get_xy(test_data, f, p) 62 | ls = np.load(lipschitz_path)['lipschitz'] 63 | datasets = TrafficDataset(*xy, ls, mu, sig, unseen_node_ids) 64 | test_loader = DataLoader(datasets, batch_size=32, shuffle=False) 65 | return test_loader 66 | 67 | 68 | def get_adjacency(adj_path): 69 | with open(adj_path, 'rb') as pkl: 70 | sparse_adj_data = pickle.load(pkl) 71 | v = sparse_adj_data['v'] 72 | ij = sparse_adj_data['ij'] 73 | shape = sparse_adj_data['shape'] 74 | adj_mx = torch.tensor(coo_array((v, ij), shape=shape).todense()) 75 | edge_index, edge_weight = dense_to_sparse(adj_mx) 76 | return edge_index, edge_weight, sparse_adj_data['shape'] 77 | 78 | 79 | def get_nbrloader(edge_index, node_ids, nlayers): 80 | node_ids = torch.tensor(node_ids, dtype=torch.long) 81 | return NeighborSampler(edge_index, node_idx=node_ids, batch_size=32, 82 | sizes=[-1 for _ in range(nlayers)]) 83 | 84 | 85 | def get_dataloader_and_adj_mx(traffic_path, lipschitz_path, adj_path, seen_path, 86 | *, keep_tod=True, f=12, p=12, nlayers=10): 87 | seen_node_ids = np.load(seen_path) 88 | edge_index, edge_weight, shape = get_adjacency(adj_path) 89 | unseen_node_ids = np.setdiff1d(np.arange(shape[0]), seen_node_ids) 90 | rev_index = torch.flip(edge_index, dims=[0]) 91 | 92 | dataloader = get_dataloader(traffic_path, lipschitz_path, keep_tod, f, p, 93 | unseen_node_ids) 94 | test_nbrloader = get_nbrloader(edge_index, unseen_node_ids, nlayers) 95 | test_rnbrloader = get_nbrloader(rev_index, unseen_node_ids, nlayers) 96 | dataloaders = {'test_loader': { 97 | 'dataloader': dataloader, 98 | 'neighbor_loader': test_nbrloader, 99 | 'rev_loader': test_rnbrloader, 100 | }, 101 | } 102 | return dataloaders, edge_weight, shape[0] 103 | --------------------------------------------------------------------------------