├── .gitignore
├── LICENSE
├── README.md
├── TrajLearn
│   ├── TrajectoryBatchDataset.py
│   ├── config_loader.py
│   ├── evaluator.py
│   ├── logger.py
│   ├── mixed_res.py
│   ├── model.py
│   ├── preprocess.py
│   ├── trainer.py
│   └── utils.py
├── baselines
│   ├── HigherOrderAttnLSTM.py
│   ├── HigherOrderGRU.py
│   ├── HigherOrderLSTM.py
│   ├── HigherOrderMarkovChain.py
│   └── __init__.py
├── configs.yaml
├── download_data.sh
├── environment.yml
├── img
│   └── architecture.jpg
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | data/
3 | plots/
4 | old/
5 | visualization
6 |
7 | # OS
8 | .DS_Store
9 |
10 | # Byte-compiled / optimized / DLL files
11 | *.pyc
12 | *.pth
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # Jupyter Notebook
18 | .ipynb_checkpoints
19 |
20 | # Environments
21 | .env
22 | .venv
23 | env/
24 | venv/
25 | ENV/
26 | env.bak/
27 | venv.bak/
28 |
29 | ### Visual Studio Code ###
30 | .vscode/*
31 | !.vscode/settings.json
32 | !.vscode/tasks.json
33 | !.vscode/launch.json
34 | !.vscode/extensions.json
35 |
36 | ## Data
37 | ./data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Amirhossein Nadiri
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TrajLearn: A Novel Model for Trajectory Prediction
2 |
3 | ## Overview
4 |
5 | *TrajLearn* is a transformer-based model designed to predict future trajectories using higher-order mobility flow representations (hexagonal grids). This model integrates a beam search variant to enhance spatial continuity, providing superior accuracy compared to existing methods. It is a powerful solution for trajectory prediction tasks in autonomous vehicles, robotics, and human motion analysis.
6 |
7 |
8 |
9 |
10 |
11 | ## Getting Started
12 |
13 | ### Prerequisites
14 |
15 | - This implementation requires **Python version `>= 3.8`**.
16 | - Ensure you have a compatible Python environment. Refer to `environment.yml` for the required packages.
17 |
18 | ### Step-by-Step Instructions
19 |
20 | 0. **Clone the Repository**
21 |
22 | First, clone the project on your computer:
23 | ```bash
24 | git clone https://github.com/amir-ni/Trajectory-prediction
25 | ```
26 |
27 | 1. **Download the Datasets**:
28 |
29 | Make the dataset downloader script executable:
30 | ```bash
31 | chmod +x ./download_data.sh
32 | ```
33 | Run the script to download the datasets. You can specify the datasets by passing them as arguments (`geolife`, `porto`, or `rome`). For example:
34 | ```bash
35 | ./download_data.sh geolife porto rome
36 | ```
37 |
38 | 2. **Prepare the Datasets**:
39 |
40 | After downloading, run the following command to prepare and transform the datasets:
41 | ```bash
42 | python3 TrajLearn/preprocess.py --input_dir <input_dir> --output_dir <output_dir> --embedding_dim <embedding_dim> --datasets <dataset ...>
43 | ```
44 | You can specify the `input_dir`, `output_dir`, `embedding_dim`, and `datasets` to be processed:
45 | - **`--input_dir`**: Directory where the raw datasets are stored. Defaults to `./data`.
46 | - **`--output_dir`**: Directory where the transformed datasets will be saved. Defaults to `./data`.
47 | - **`--embedding_dim`**: Optional dimensionality of the hexagon embeddings (`embeddings.npy`) generated for each dataset.
48 | - **`--datasets`**: Select which datasets to process (`geolife`, `porto`, `rome`). Multiple datasets can be processed by specifying them in a space-separated list. For example:
49 | ```bash
50 | python3 TrajLearn/preprocess.py --datasets rome geolife porto --embedding_dim 512
51 | ```
52 |
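After preprocessing, each dataset directory (for example `data/rome7`; the exact names follow the dataset/resolution pairs in `TrajLearn/preprocess.py`) should contain files along these lines:
```
data/rome7/
├── data.txt        # trajectories as space-separated token IDs, one per line
├── vocab.txt       # "EOT" followed by the H3 cell addresses
├── mapping.json    # cell address -> integer token ID
├── neighbors.json  # token ID -> IDs of neighbouring cells
└── embeddings.npy  # only written when --embedding_dim is given
```
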
53 | 3. **Set Up the Model Configuration**:
54 |
55 | The configuration of the model, such as batch size, learning rates, and dataset-specific settings, is provided through a YAML configuration file. This file can contain multiple configurations; a separate model is trained for each one, sequentially. An example configuration, used to generate the results reported in the paper, can be found in `configs.yaml`.
56 |
57 | You can create or modify this file according to your needs. Some configurations are described below; additional options are listed at the end of this document.
58 | - **`data_dir`**: Directory where the dataset is stored. If you have not changed the default output directory in the previous steps, the address would be `./data`.
59 | - **`dataset`**: Name of the dataset being used, such as `rome7`.
60 | - **`model_checkpoint_directory`**: Directory path where model checkpoints will be saved during training.
61 | - **`min_input_length`**: The minimum length of input sequences used during training and testing.
62 | - **`max_input_length`**: The maximum length of input sequences allowed for model training.
63 | - **`test_input_length`**: Length of the input sequence during testing.
64 | - **`test_prediction_length`**: The number of future steps the model will predict during testing.
65 |
66 | After modifying these parameters to your requirements, save them in a YAML file. This file controls model behavior during training and testing; a minimal example is sketched below.
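
For reference, here is a minimal configuration sketch. The block name `rome7-example` and the values are illustrative; any key you omit falls back to the defaults in `TrajLearn/config_loader.py`:
```yaml
rome7-example:
  data_dir: ./data
  dataset: rome7
  model_checkpoint_directory: ./models/
  min_input_length: 10
  max_input_length: 14
  test_input_length: 10
  test_prediction_length: 5
  batch_size: 128
  max_epochs: 10
  device: cuda
```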
67 |
68 | 4. **Train the Model**:
69 |
70 | After configuring the model, you can start the training process. Use the following command:
71 | ```bash
72 | python3 main.py configs.yaml
73 | ```
74 |
75 | This will train one model for each configuration listed in `configs.yaml`. Replace `configs.yaml` with the path to your own YAML file if needed.
76 |
77 | 5. **Test the Model**:
78 |
79 | Once the model is trained, you can evaluate its performance by running:
80 | ```bash
81 | python3 main.py configs.yaml --test
82 | ```
83 |
84 | This will evaluate the trained model on the test split of the dataset.
85 |
86 |
87 |
88 | ### Additional Configuration Options:
89 |
90 | - **`test_ratio`**: Proportion of the dataset used for testing. For example, a value of `0.2` means 20% of the dataset will be used for testing.
91 | - **`validation_ratio`**: Proportion of the dataset used for validation. For example, a value of `0.1` means 10% of the dataset will be used for validation.
92 | - **`delimiter`**: The character that separates values in your dataset files (default is `" "`).
93 | - **`batch_size`**: The number of samples processed together in one forward/backward pass.
94 | - **`device`**: The computational device to use for training and testing. Set to `cuda` for GPU acceleration or `cpu` if no GPU is available.
95 | - **`max_epochs`**: The maximum number of training epochs, where one epoch means a complete pass through the entire dataset.
96 | - **`block_size`**: Maximum context length of the model. Longer input sequences are truncated to the last `block_size` tokens during training and testing.
97 | - **`learning_rate`**: Initial learning rate for the optimizer. Adjust this to control how fast the model learns.
98 | - **`weight_decay`**: Regularization term to avoid overfitting by penalizing large weights. Higher values provide stronger regularization.
99 | - **`beta1`**: Beta1 hyperparameter for the Adam optimizer, which controls the decay rate for the first moment estimate.
100 | - **`beta2`**: Beta2 hyperparameter for the Adam optimizer, controlling the decay rate for the second moment estimate.
101 | - **`grad_clip`**: Threshold for gradient clipping. Gradients that exceed this value will be clipped to prevent exploding gradients.
102 | - **`decay_lr`**: Boolean flag to indicate whether the learning rate should be decayed over time.
103 | - **`warmup_iters`**: Number of iterations during which the learning rate will increase from a small value to the initial learning rate (used in learning rate scheduling).
104 | - **`lr_decay_iters`**: Number of iterations over which the learning rate decays.
105 | - **`min_lr`**: Minimum learning rate after decay. The learning rate will not decrease below this value.
106 | - **`seed`**: Random seed for reproducibility. Ensures that experiments can be replicated with the same results.
107 | - **`n_layer`**: Number of layers in the transformer model. More layers can increase model capacity but also computational cost.
108 | - **`n_head`**: Number of attention heads in the transformer model, which allows the model to focus on different parts of the input sequence simultaneously.
109 | - **`n_embd`**: Dimensionality of the embedding space. This represents the size of the vector representations for each token in the input sequence.
110 | - **`bias`**: Boolean flag to indicate whether to include bias terms in the model's layers. Set to `False` to exclude bias.
111 | - **`dropout`**: Dropout rate used for regularization. A value of `0` means no dropout will be applied.
112 | - **`custom_initialization`**: Boolean flag that specifies whether to use an axial-coordinate-based initialization for the hexagon (token) embeddings instead of a random one.
113 | - **`train_from_checkpoint_if_exist`**: A boolean flag that indicates whether to resume training from an existing checkpoint if one is found.
114 | - **`patience`**: Number of epochs with no improvement to wait before early stopping.
115 | - **`continuity`**: Boolean flag to enforce spatial continuity constraints on predictions.
116 | - **`beam_width`**: Integer specifying the beam width for beam search.
117 | - **`store_predictions`**: Boolean flag to enable or disable storing the predicted sequences.
118 |
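These options sit in the same per-configuration block as the required keys from step 3. As an illustrative sketch (the values shown are the defaults from `TrajLearn/config_loader.py`):

```yaml
rome7-example:
  # ...required keys from step 3...
  learning_rate: 5.e-3
  weight_decay: 5.e-1
  decay_lr: True
  warmup_iters: 200
  lr_decay_iters: 40000
  min_lr: 5.e-7
  beam_width: 5
  continuity: True
  store_predictions: False
```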
119 |
120 | ## Contact
121 |
122 | This project was developed by [Amirhossein Nadiri](https://github.com/amir-ni).
123 |
124 | For any inquiries or collaboration, feel free to reach out at [anadiri@yorku.ca](mailto:anadiri@yorku.ca).
125 |
126 | ## License
127 |
128 | This project is open-source software released under the MIT License. See the [LICENSE](LICENSE) file for details.
129 |
130 | ## Citation
131 |
132 | If you use this project or TrajLearn in your research, please consider citing it as follows:
133 |
134 | ```tex
135 | @article{TrajLearn,
136 | author = {Nadiri, Amirhossein and Li, Jing and Faraji, Ali and Abuoda, Ghadeer and Papagelis, Manos},
137 | title = {TrajLearn: Trajectory Prediction Learning using Deep Generative Models},
138 | year = {2025},
139 | publisher = {Association for Computing Machinery},
140 | address = {New York, NY, USA},
141 | journal = {ACM Trans. Spatial Algorithms Syst.},
142 | volume = {In Press},
143 | numpages = {33},
144 | }
145 | ```
146 |
--------------------------------------------------------------------------------
/TrajLearn/TrajectoryBatchDataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from collections import defaultdict
5 | import torch
6 | import numpy as np
7 | import pandas as pd
8 | from torch.utils.data import IterableDataset
9 |
10 |
11 | class TrajectoryBatchDataset(IterableDataset):
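    """Iterable dataset of trajectory prediction windows, batched so that all inputs in a batch share the same length."""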
12 | def __init__(self, dataset_directory, dataset_type='train', delimiter=' ', validation_ratio=0.1, test_ratio=0.2):
13 | self.dataset_directory = dataset_directory
14 |
15 | full_data = [np.array([int(j) for j in i.strip().split(delimiter)]) for i in pd.read_csv(
16 | os.path.join(dataset_directory, 'data.txt'), header=None)[0]]
17 |
18 | self.number_of_trajectories = len(full_data)
19 | self.vocab_size = sum(1 for _ in open(
20 | os.path.join(dataset_directory, 'vocab.txt'), encoding='utf-8'))
21 |
22 | if dataset_type == 'train':
23 | self.data = full_data[:-int(self.number_of_trajectories *
24 | (validation_ratio + test_ratio))]
25 | elif dataset_type == 'val':
26 | self.data = full_data[-int(self.number_of_trajectories * (
27 | validation_ratio + test_ratio)): -int(self.number_of_trajectories * test_ratio)]
28 | elif dataset_type == 'test':
29 | self.data = full_data[-int(self.number_of_trajectories * test_ratio):]
30 | else:
31 | raise ValueError('Invalid type')
32 |
33 | self.dataX = []
34 | self.dataY = []
35 | self.batches = []
36 | self.dataset_type = dataset_type
37 |
38 | def create_batches(self, batch_size, observe, predict=1, shuffle=True, drop_last=False):
39 |
40 | if isinstance(observe, int):
41 | observe = [observe]
42 | if isinstance(predict, int):
43 | predict = [predict] * len(observe)
44 |
45 | for trajectory in self.data:
46 | for j, observe_length in enumerate(observe):
47 | for i in range(0, len(trajectory) - observe_length - predict[j] + 1):
48 | self.dataX.append(trajectory[i:i+observe_length])
49 | self.dataY.append(
50 | trajectory[i+observe_length:i+observe_length+predict[j]])
51 |
52 | # Group indices of same size together
53 | size_to_indices = defaultdict(list)
54 | for i, x in enumerate(self.dataX):
55 | size_to_indices[len(x)].append(i)
56 |
57 | # Prepare the list of batches and shuffle it
58 | batches = []
59 | for size_indices in size_to_indices.values():
60 | for i in range(0, len(size_indices), batch_size):
61 | batch = size_indices[i:i+batch_size]
62 | if len(batch) == batch_size or not drop_last:
63 | batches.append(batch)
64 |
65 | if shuffle:
66 | random.shuffle(batches)
67 |
68 | self.batches = batches
69 |
70 | def get_neighbors(self):
71 | with open(os.path.join(self.dataset_directory, 'neighbors.json'), encoding='utf-8') as neighbors_file:
72 | neighbors = json.load(neighbors_file)
73 | neighbors = {int(k): v + [0] for k, v in neighbors.items()}
74 | neighbors[0] = []
75 | return neighbors
76 |
77 | def __len__(self):
78 | return len(self.batches)
79 |
80 | def __getitem__(self, index):
81 | batch_indices = self.batches[index]
82 | return torch.LongTensor(np.stack([self.dataX[i] for i in batch_indices])), torch.LongTensor(np.stack([self.dataY[i] for i in batch_indices]))
83 |
84 | def __iter__(self):
85 | # worker_info = torch.utils.data.get_worker_info()
86 |
87 | # if worker_info is None:
88 | # batches = self.batches
89 | # else:
90 | # n_workers = worker_info.num_workers
91 | # n_data = len(self.batches)
92 | # chunk_size = n_data // n_workers
93 |
94 | # chunk_start = chunk_size * worker_info.id
95 | # batches = self.batches[chunk_start: chunk_start + chunk_size]
96 | # for i in range(len(self.dataX)):
97 | # yield torch.LongTensor(self.dataX[i]), torch.LongTensor(self.dataY[i])
98 |
99 | for batch_indices in self.batches:
100 | yield torch.LongTensor(np.stack([self.dataX[i] for i in batch_indices])), torch.LongTensor(np.stack([self.dataY[i] for i in batch_indices]))
101 |
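# --- Illustrative usage sketch (the dataset path is an assumption; run after preprocessing) ---
if __name__ == '__main__':
    dataset = TrajectoryBatchDataset('./data/rome7', dataset_type='train')
    dataset.create_batches(batch_size=128, observe=[10, 11, 12], predict=1)
    for x, y in dataset:
        # x: (batch, observe_length) LongTensor of cell tokens; y: (batch, 1) LongTensor of targets
        print(x.shape, y.shape)
        break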
--------------------------------------------------------------------------------
/TrajLearn/config_loader.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | default_config = {
4 | "test_ratio": 0.2,
5 | "validation_ratio": 0.1,
6 | "delimiter": " ",
7 | "min_input_length": 10,
8 | "max_input_length": 14,
9 | "test_input_length": 10,
10 | "test_prediction_length": 5,
11 | "batch_size": 128,
12 | "device": "cpu",
13 | "max_epochs": 10,
14 | "block_size": 24,
15 | "learning_rate": 5.e-3,
16 | "weight_decay": 5.e-1,
17 | "beta1": 0.9,
18 | "beta2": 0.95,
19 | "grad_clip": 1.0,
20 | "decay_lr": True,
21 | "warmup_iters": 200,
22 | "lr_decay_iters": 40000,
23 | "min_lr": 5.e-7,
24 | "seed": 42,
25 | "data_dir": "./data",
26 | "dataset": "geolife7",
27 | "n_layer": 12,
28 | "n_head": 6,
29 | "n_embd": 512,
30 | "bias": False,
31 | "dropout": 0.1,
32 | "model_checkpoint_directory": "./models/",
33 | "train_from_checkpoint_if_exist": False,
34 | "custom_initialization": False,
35 | "patience": 3,
36 | "continuity": True,
37 | "beam_width": 5,
38 | "store_predictions": False,
39 | }
40 |
41 | def load_config(config_file: str) -> dict:
42 | """
43 | Load a YAML configuration file and apply default values for missing parameters.
44 |
45 | Parameters:
46 | - config_file (str): Path to the configuration YAML file.
47 |
48 | Returns:
49 | - config_list (dict): A dictionary with the final configuration, including defaults.
50 | """
51 | with open(config_file, 'r', encoding='utf-8') as stream:
52 | try:
53 | config_list = yaml.safe_load(stream)
54 |
55 | if config_list is None:
56 | config_list = {}
57 |
58 | for config_name, config_values in config_list.items():
59 | if config_values is None:
60 | config_values = {}
61 |
62 | for key, value in default_config.items():
63 | config_values[key] = config_values.get(key, value)
64 |
65 | config_list[config_name] = config_values
66 |
67 | return config_list
68 |
69 | except yaml.YAMLError as exc:
70 | print(f"Error loading YAML file: {exc}")
71 | return None
72 |
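# --- Illustrative usage sketch (the configs.yaml path is an assumption) ---
if __name__ == '__main__':
    configs = load_config('configs.yaml')
    for name, cfg in (configs or {}).items():
        print(f"{name}: dataset={cfg['dataset']}, batch_size={cfg['batch_size']}")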
--------------------------------------------------------------------------------
/TrajLearn/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import warnings
4 | from typing import Any
5 | from logging import Logger
6 | from tqdm import tqdm
7 | import torch
8 | from nltk.translate.bleu_score import sentence_bleu
9 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset
10 |
11 | def calculate_bleu(predictions: torch.Tensor, targets: torch.Tensor) -> float:
12 | bleu_score = 0.0
13 | with warnings.catch_warnings():
14 | warnings.simplefilter("ignore")
15 | for prediction, target in zip(predictions, targets):
16 | prediction = prediction.tolist()
17 | target = target.tolist()
18 | bleu_score += sentence_bleu([target], prediction)
19 | return bleu_score
20 |
21 |
22 | @torch.no_grad()
23 | def evaluate_model(
24 | model: torch.nn.Module,
25 | dataset: TrajectoryBatchDataset,
26 | config: Any,
27 | logger: Logger,
28 | top_k: list = None,
29 | ) -> list:
30 | model.eval()
31 | device = config["device"]
32 | device_type = 'cuda' if 'cuda' in device else 'cpu'
33 | prediction_length = config["test_prediction_length"]
34 | ctx = torch.amp.autocast(device_type=device_type, dtype=torch.float32)
35 |
36 | if top_k is None:
37 | top_k = [1, 3, 5]
38 |
39 | beam_width = config["beam_width"]
40 |
41 | if config["continuity"]:
42 | neighbors = dataset.get_neighbors()
43 |
44 | total_bleu_score = 0.0
45 | correct_predictions = {k: torch.zeros(
46 | prediction_length, dtype=torch.int32).to(device) for k in top_k}
47 |
48 | if config["store_predictions"]:
49 | pred_results_buffer = ["input sequence,true label,predicted label\n"]
50 | pred_results_file = open(os.path.join(logger.log_directory, 'predictions.txt'), 'w', encoding='utf-8')
51 |
52 | start_time = time.time()
53 |
54 | total_samples = 0
55 | for X, Y in (pbar := tqdm(dataset, leave=False)):
56 | x, y = X.to(device), Y.to(device)
57 | beams = torch.zeros((x.shape[0], 1, 0), dtype=torch.int32).to(device)
58 | scores = torch.zeros((x.shape[0], 1), dtype=torch.float32).to(device)
59 |
60 | total_samples += x.shape[0]
61 | for j in range(prediction_length):
62 | new_scores, new_beams = [], []
63 | for b in range(beams.shape[1]):
64 | beam = beams[:, b:b+1]
65 | with ctx:
66 | input_sequence = torch.cat((x, beam.squeeze(1)), dim=1)
67 | logits, _ = model(
68 | input_sequence[:, -config["block_size"]:])
69 | logits = torch.squeeze(logits, dim=1)
70 | probs = torch.softmax(logits, dim=1)
71 |
72 | # Apply the mask
73 | if config["continuity"]:
74 | last_prediction = input_sequence[:, -1]
75 | mask = torch.zeros_like(logits, dtype=torch.bool)
76 | for idx, item in enumerate(last_prediction):
77 | mask[idx, neighbors[item.item()]] = True
78 | probs[~mask] = 0
79 |
80 | # Get top-k probabilities and their indices
81 | top_probs, indices = torch.topk(probs, beam_width)
82 | indices = torch.where(
83 | indices == 0, torch.ones_like(indices), indices)
84 | # Append new indices to beam
85 | new_beam = torch.cat(
86 | (beam.repeat(1, beam_width, 1), indices.unsqueeze(2)), dim=2)
87 | new_score = scores[:, b:b+1] + \
88 | torch.log(top_probs) # Update scores
89 | new_scores.append(new_score)
90 | new_beams.append(new_beam)
91 | # Concatenate along beam dimension
92 | new_scores = torch.cat(new_scores, dim=1)
93 | # Concatenate along beam dimension
94 | new_beams = torch.cat(new_beams, dim=1)
95 | top_scores, top_beams = torch.topk(new_scores.view(
96 | X.shape[0], -1), beam_width) # Reshape scores to 2D and get top-k
97 | # Reshape beams to 3D for gathering
98 | beams = new_beams.view(X.shape[0], -1, new_beams.shape[2])
99 | beams = torch.gather(beams, 1, top_beams.unsqueeze(
100 | 2).expand(-1, -1, beams.shape[2])) # Gather the top-k beams
101 | scores = top_scores # Update scores with top-k scores
102 |
103 | for k in correct_predictions.keys():
104 | if beam_width >= k:
105 | predictions = beams[:, :k] # Get the top-k beams
106 | for beam_number in range(k):
107 | correct_predictions[k][j] += torch.sum((predictions[:, beam_number:beam_number+1].squeeze(1) ==
108 | y[:, :j+1]).all(dim=1).int())
109 |
110 | total_bleu_score += calculate_bleu(beams[:, 0], y)
111 |
112 | if config["store_predictions"]:
113 | beams_np = beams[:, 0].cpu().numpy()
114 | xs_np, ys_np = x.cpu().numpy(), y.cpu().numpy()
115 | for sample_id, x_np in enumerate(xs_np):
116 | prediction_result = f"{' '.join(map(str, x_np))}, {' '.join(map(str, ys_np[sample_id]))}, {' '.join(map(str, beams_np[sample_id]))}\n"
117 | pred_results_buffer.append(prediction_result)
118 | if len(pred_results_buffer) > 10000:
119 | pred_results_file.writelines(pred_results_buffer)
120 | pred_results_buffer = []
121 |
122 | acc1 = ((100 * correct_predictions[1][prediction_length-1]) / total_samples).item()
123 | pbar.set_postfix(**{"Acc@1": acc1})
124 |
125 | if config["store_predictions"]:
126 | pred_results_file.writelines(pred_results_buffer)
127 | pred_results_file.close()
128 | test_duration = time.time() - start_time
129 |
130 | avg_bleu_score = total_bleu_score / total_samples
131 | acc = {k: (100 * v) / (total_samples)
132 | for k, v in correct_predictions.items()}
133 | results = [f"Dataset: {config['dataset']}"]
134 | for k, v in acc.items():
135 | results.append(f"Accuracy@{k}: {v[-1]:.4f}")
136 | results.append(f"BLEU score: {avg_bleu_score:.4f}")
137 | results.append(f"Test duration: {test_duration:.3f}(s)")
138 | results.append(f"Samples: {total_samples}")
139 | logger.info(", ".join(results))
140 |
141 | return results
142 |
--------------------------------------------------------------------------------
/TrajLearn/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from logging.handlers import RotatingFileHandler
4 |
5 | def get_logger(log_directory,
6 | name,
7 | phase="train",
8 | console_level=logging.DEBUG,
9 | file_level=logging.INFO,
10 | max_file_size=10 * 1024 * 1024,
11 | backup_count=5,
12 | log_format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d]: %(message)s',
13 | date_format='%Y-%m-%d %H:%M:%S'):
14 | """
15 | Creates and configures a logger with console and file handlers.
16 |
17 | Parameters:
18 | - log_directory (str): Directory where log files will be stored.
19 | - phase (str): Logging phase (e.g., "train", "test"). The log file will be named accordingly.
20 | - console_level (int): Logging level for console output (default: logging.DEBUG).
21 | - file_level (int): Logging level for file output (default: logging.INFO).
22 | - max_file_size (int): Maximum size of the log file before it gets rotated (default: 10MB).
23 | - backup_count (int): Number of backup log files to keep (default: 5).
24 | - log_format (str): Format of the log messages.
25 | - date_format (str): Format of the date in log messages.
26 |
27 | Returns:
28 | - logger (logging.Logger): Configured logger object.
29 | """
30 | logger = logging.getLogger(f"{name}-{phase}")
31 |
32 | if not logger.hasHandlers():
33 | logger.setLevel(logging.DEBUG)
34 |
35 | formatter = logging.Formatter(log_format, date_format)
36 |
37 | console_handler = logging.StreamHandler()
38 | console_handler.setLevel(console_level)
39 | console_handler.setFormatter(formatter)
40 |
41 | os.makedirs(log_directory, exist_ok=True)
42 | logger.log_directory = log_directory
43 |
44 | logfile = os.path.join(log_directory, f"{phase}.log")
45 | file_handler = RotatingFileHandler(logfile, mode='a',
46 | maxBytes=max_file_size, backupCount=backup_count)
47 | file_handler.setLevel(file_level)
48 | file_handler.setFormatter(formatter)
49 |
50 | logger.addHandler(console_handler)
51 | logger.addHandler(file_handler)
52 |
53 | return logger
54 |
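# --- Illustrative usage sketch (the log directory is an assumption) ---
if __name__ == '__main__':
    logger = get_logger('./logs/example', name='rome7', phase='train')
    logger.info('Messages go to the console and to ./logs/example/train.log')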
--------------------------------------------------------------------------------
/TrajLearn/mixed_res.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | import json
4 | import pickle
5 | import argparse
6 | from pathlib import Path
7 | from collections import defaultdict
8 | from collections.abc import Callable
9 |
10 | import h3
11 | import geopandas
12 | import matplotlib
13 | import numpy as np
14 | import pandas as pd
15 | import contextily as cx
16 | import matplotlib.pyplot as plt
17 | from matplotlib.colors import ListedColormap
18 |
19 |
20 | def gps_to_h3(
21 | gps_points: list,
22 | resolution: int,
23 | boundary: dict,
24 | ):
25 | """
26 | Convert a list of GPS coordinates to H3 hexagons at a specified resolution.
27 |
28 | Arguments:
29 | - gps_points: list of tuples, where each tuple contains (longitude, latitude) in degrees.
30 | Example: [(lon1, lat1), (lon2, lat2), ...]
31 | - resolution: int, the H3 resolution level to use; higher resolutions create smaller hexagons.
32 | - boundary: dict with "Min_lat", "Max_lat", "Min_lon" and "Max_lon" keys; points outside this bounding box are skipped.
33 |
34 | Returns:
35 | - A list of H3 hexagon IDs for the in-boundary GPS points, with consecutive duplicate cells collapsed.
36 | """
37 | result = []
38 | for lon, lat in gps_points:
39 | if boundary["Min_lon"] <= lon and boundary["Min_lat"] <= lat and \
40 | boundary["Max_lon"] >= lon and boundary["Max_lat"] >= lat:
41 | cell = h3.latlng_to_cell(lat, lon, resolution)
42 | if len(result) == 0 or cell != result[-1]:
43 | result.append(cell)
44 | return result
45 |
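# Illustrative example (coordinates are assumptions, chosen inside the geolife boundary defined below):
#   gps_to_h3([(116.32, 39.98), (116.33, 39.99)], resolution=7,
#             boundary={"Min_lat": 39.5, "Max_lat": 40.5, "Min_lon": 116.0, "Max_lon": 117.0})
# returns the visited res-7 H3 cells with consecutive duplicates collapsed.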
46 |
47 | def preprocess_resolution(
48 | dataset: str = 'geolife',
49 | min_resolution: int = 7,
50 | max_resolution: int = 9,
51 | output_dir: str = 'data/mixed_res',
52 | save_csv: bool = False,
53 | use_boundary: bool = True,
54 | ):
55 | """
56 | Preprocess a trajectory dataset by converting GPS points to H3 hexagon representations
57 | at various resolutions, and save the results and associated data structures.
58 |
59 | This function reads a CSV file containing GPS points, converts these points to H3
60 | hexagon IDs at specified resolutions, and optionally saves the results as CSV and
61 | pickle files. It also calculates hexagon counts and neighbor relationships.
62 |
63 | Arguments:
64 | - dataset: str, name of the dataset to process (e.g., 'geolife', 'rome', 'porto').
65 | - min_resolution: int, the minimum H3 resolution level for hexagon generation.
66 | - max_resolution: int, the maximum H3 resolution level for hexagon generation.
67 | - output_dir: str, the directory where processed data and outputs will be saved.
68 | - save_csv: bool, if True, saves processed data and hexagon counts as CSV files.
69 |
70 | Steps Performed:
71 | 1. Reads GPS points from a dataset CSV file and applies a transformation to generate
72 | H3 hexagon IDs at each resolution from `min_resolution` to `max_resolution`.
73 | 2. Saves the transformed dataset to a CSV and pickle file if `save_csv` is True.
74 | 3. Computes and saves hexagon counts and unique hexagon IDs at each resolution.
75 | 4. Computes neighbor relationships for hexagons and saves them as a pickle file.
76 |
77 | Outputs:
78 | - Processed dataset with hexagon columns at specified resolutions (CSV and/or pickle).
79 | - Hexagon count file, detailing how many times each hexagon appears (CSV and pickle).
80 | - A dictionary of unique hexagons at each resolution (pickle).
81 | - Neighbor relationships for each hexagon, excluding child hexagons (pickle).
82 | """
83 | os.makedirs(output_dir, exist_ok=True)
84 |
85 | boundary = {
86 | 'geolife':{
87 | "Min_lat":39.50000000,
88 | "Max_lat":40.50000000,
89 | "Min_lon":116.00000000,
90 | "Max_lon":117.00000000
91 | },
92 | 'rome':{
93 | "Min_lat":41.793710,
94 | "Max_lat":41.991390,
95 | "Min_lon":12.372598,
96 | "Max_lon":12.622537
97 | },
98 | 'porto': {
99 | "Min_lat": -1000,
100 | "Max_lat": 1000,
101 | "Min_lon": -1000,
102 | "Max_lon": 1000
103 | }
104 | }
105 |
106 | data_path = {
107 | 'geolife':'data/raw_aggrigated/geolife_aggregated.csv',
108 | 'rome': 'data/raw_aggrigated/rome_taxi_aggregated.csv',
109 | 'porto': 'data/raw_aggrigated/porto.csv'
110 | }
111 |
112 | file_path = data_path[dataset]
113 | print(f"Processing file: {file_path}")
114 |
115 | target_column = 'route_points' if 'geolife' in file_path or 'rome_taxi' in file_path else 'POLYLINE'
116 | target_time = 'date' if 'geolife' in file_path or 'rome_taxi' in file_path else 'TIMESTAMP'
117 | df = pd.read_csv(file_path, usecols=[target_column, target_time]).rename(columns={target_column: 'points', target_time: 'time'})
118 |
119 | df['points'] = df['points'].apply(ast.literal_eval)
120 | print("Processed points")
121 | for resolution in range(min_resolution, max_resolution+1):
122 | column_name = f'hex_{resolution}'
123 | if use_boundary:
124 | df[column_name] = df['points'].apply(gps_to_h3, args=(resolution, boundary[dataset],))
125 | else:
126 | df[column_name] = df['points'].apply(gps_to_h3, args=(resolution,))
127 | print(f"Processed resolution: {resolution}")
128 |
129 | if save_csv:
130 | df.to_csv(f"{output_dir}/{dataset}.csv", index=False)
131 | print("Saved dataset csv")
132 |
133 | with open(f'{output_dir}/{dataset}.pkl', 'wb') as f:
134 | pickle.dump(df, f)
135 | print("Saved dataset pickle")
136 |
137 | hex_counts = defaultdict(int)
138 | hexes = {res: set() for res in range(min_resolution, max_resolution + 1)}
139 |
140 | for res in range(min_resolution, max_resolution + 1):
141 | column_name = f'hex_{res}'
142 | for hex_list in df[column_name]:
143 | for hex_id in hex_list:
144 | hex_counts[hex_id] += 1
145 | hexes[res].add(hex_id)
146 |
147 | if save_csv:
148 | hex_df = pd.DataFrame.from_dict(hex_counts, orient='index', columns=['occurrences'])
149 | hex_df.index.name = 'hex_id'
150 | hex_df.to_csv(f'{output_dir}/hex_count_{dataset}.csv')
151 | print("Saved hexagon count csv")
152 |
153 |
154 | with open(f'{output_dir}/hexes_{dataset}.pkl', 'wb') as f:
155 | pickle.dump(hexes, f)
156 | print("Saved hexagons dictionary pickle")
157 |
158 |
159 | with open(f'{output_dir}/hex_count_{dataset}.pkl', 'wb') as f:
160 | pickle.dump(hex_counts, f)
161 | print("Saved hexagon count pickle")
162 |
163 | neighbors = defaultdict(set)
164 | for resolution in range(min_resolution, max_resolution+1):
165 | for hex_id in hexes[resolution]:
166 | hex_neighbors = set()
167 | hex_children = set()
168 |
169 | for children_resolution in range(resolution+1, max_resolution+1):
170 | hex_children.update(h3.cell_to_children(hex_id, children_resolution))
171 |
172 | for hex_child in hex_children:
173 | hex_neighbors.update(h3.grid_ring(hex_child, 1))
174 | neighbors[hex_id] = hex_neighbors - hex_children
175 |
176 | for hex_id, neighbors_set in list(neighbors.items()):
177 | for neighbor in neighbors_set:
178 | neighbors[neighbor].add(hex_id)
179 |
180 | with open(f'{output_dir}/neighbors_{dataset}.pkl', 'wb') as f:
181 | pickle.dump(neighbors, f)
182 | print("Saved hexagon neighbors pickle")
183 |
184 |
185 | def visualize(hex_counts: dict, output_path: str, bins:list = None, **kwargs):
186 | """
187 | Generate and display a heatmap of hexagon counts on a map, and save it as an image.
188 |
189 | Arguments:
190 | - hex_counts: dict, mapping each hexagon ID to the value to plot (e.g. an occurrence count or a resolution).
191 | - output_path: str, path of the PDF file the figure is saved to.
192 | - bins: list, optional bin edges used to bucket the values into Low/Medium/High categories.
193 | - **kwargs: Options passed to the geopandas plotting method.
194 |
195 | The function converts each hexagon ID to its polygon shape, builds a GeoDataFrame, plots it with
196 | matplotlib over a contextily basemap, and saves the figure as a PDF.
197 | """
198 | matplotlib.rcParams.update({'font.size': 28})
199 | df_hex_plot = pd.DataFrame(hex_counts.items(), columns=['hex_id', 'count'])
200 |
201 | df_hex_plot['geometry'] = df_hex_plot['hex_id'].apply(lambda x: h3.cells_to_h3shape([x]))
202 | df = geopandas.GeoDataFrame(df_hex_plot.drop(columns=['hex_id']), crs='EPSG:4326')
203 | df = df.to_crs(epsg=3857)
204 |
205 | _, ax = plt.subplots(figsize=(24,24))
206 | ax.get_xaxis().set_visible(False)
207 | ax.get_yaxis().set_visible(False)
208 |
209 | if bins is not None:
210 | labels = ['Low', 'Medium', 'High']
211 | df['count'] = pd.cut(df['count'], bins=bins, labels=labels, include_lowest=True)
212 |
213 | df.plot(
214 | ax=ax,
215 | alpha=0.9,
216 | edgecolor=(134/256, 218/256, 227/256, 0.1),
217 | linewidth=0.001,
218 | column='count',
219 | legend=True,
220 | **kwargs,
221 | )
222 |
223 | cx.add_basemap(ax, crs=df.crs, source=cx.providers.CartoDB.Positron)
224 |
225 | plt.tight_layout()
226 | plt.savefig(output_path, format="pdf", bbox_inches='tight')
227 |
228 |
229 | def threshold_split_condition(threshold: int = 100):
230 | """
231 | Create a custom split condition function based on a threshold for hexagon occurrences.
232 |
233 | This function returns a nested `split_condition` function that evaluates whether
234 | a hexagon should be split based on the number of occurrences compared to a specified
235 | threshold.
236 |
237 | Arguments:
238 | - threshold: int, the minimum number of occurrences required for a hexagon to be
239 | considered for splitting. Default is 100.
240 |
241 | Returns:
242 | - A function `split_condition(current_res, hex_count, neighbors_stat)` that takes:
243 | - current_res: int, the current resolution of the hexagon (not used in this function).
244 | - hex_count: int, the count of occurrences in the current hexagon.
245 | - neighbors_stat: any, information/statistics about neighboring hexagons
246 | (not used in this function).
247 | The `split_condition` function returns `True` if `hex_count` exceeds the threshold,
248 | indicating the hexagon should be split; otherwise, it returns `False`.
249 | """
250 | def split_condition(current_res, hex_count, neighbors_stat):
251 | """
252 | Decide if a hexagon should split based on its occurrence count.
253 |
254 | Arguments:
255 | - current_res: int, the current resolution of the hexagon (not used in this function).
256 | - hex_count: int, the count of occurrences in the current hexagon.
257 | - neighbors_stat: any, information/statistics about neighboring hexagons
258 | (not used in this function).
259 |
260 | Returns:
261 | - True if `hex_count` exceeds the specified `threshold`, indicating the hexagon
262 | should be split; otherwise, False.
263 | """
264 | return hex_count > threshold
265 |
266 | return split_condition
267 |
268 |
269 | def complex_split_condition(threshold=100, std_ratio=2):
270 | """
271 | Create a split condition function that decides whether to split a hexagon based on:
272 | 1. The hexagon's occurrence count exceeding a minimum threshold.
273 | 2. A high coefficient of variation (normalized standard deviation) among the occurrence counts of its neighbors.
274 |
275 | Arguments:
276 | - threshold: int, the minimum occurrence count a hexagon must exceed to be considered for splitting.
277 | - std_ratio: float, the minimum coefficient of variation of the neighbors' counts required to split.
278 |
279 | Returns:
280 | - A function `split_condition(current_res, hex_count, neighbors_stat)` that returns True if the
281 | hexagon meets the split criteria; otherwise, False.
282 | """
283 | def split_condition(current_res, hex_count, neighbors_stat):
284 | """
285 | Determine if a hexagon should be split based on normalized variance and threshold ratio.
286 |
287 | Arguments:
288 | - current_res: int, the current resolution of the hexagon.
289 | - hex_count: int, the number of occurrences in the current hexagon.
290 | - neighbors_stat: dict, where keys are neighbor hex IDs and values are their occurrence counts.
291 |
292 | Returns:
293 | - Boolean: True if the hexagon meets the split conditions; otherwise, False.
294 | """
295 | if not neighbors_stat:
296 | return False
297 |
298 | neighbor_counts = np.array(list(neighbors_stat.values()))
299 | mean_neighbors = np.mean(neighbor_counts)
300 | if mean_neighbors == 0:
301 | return False
302 |
303 | if hex_count <= threshold:
304 | return False
305 |
306 | normalized_std = np.std(neighbor_counts) / mean_neighbors
307 |
308 | return normalized_std > std_ratio
309 |
310 | return split_condition
311 |
312 |
313 | def skewness_stopping_condition(threshold: float = 0.2):
314 | """
315 | Create a stopping condition function based on the skewness of a distribution.
316 |
317 | This function returns a nested `stopping_condition` function that evaluates whether
318 | the skewness of the distribution of hexagon occurrences meets a specified threshold.
319 | The stopping condition is used to decide if the iterative process should halt.
320 |
321 | Arguments:
322 | - threshold: float, the skewness threshold for stopping. If the skewness of the
323 | distribution of occurrences is less than this threshold, the condition
324 | returns `True`, indicating that the process should stop. Default is 0.2.
325 |
326 | Returns:
327 | - A function `stopping_condition(hexagons)` that takes:
328 | - hexagons: dict[str, int], a dictionary where keys are hexagon IDs and values are
329 | the count of occurrences in each hexagon.
330 | The function returns `True` if the skewness of the values in `hexagons` is below
331 | the specified threshold, indicating that the stopping condition is met; otherwise,
332 | it returns `False`.
333 | """
334 | def stopping_condition(hexagons: dict):
335 | """
336 | Evaluate the skewness of the hexagon occurrence counts to determine if the process should stop.
337 |
338 | Arguments:
339 | - hexagons: dict[str, int], a dictionary where keys are hexagon IDs and values are
340 | the count of occurrences in each hexagon.
341 |
342 | Returns:
343 | - True if the skewness is less than the specified threshold, indicating that the
344 | condition for stopping is met; otherwise, False.
345 | """
346 | data = np.array(list(hexagons.values()))
347 | n = len(data)
348 | if n < 3:
349 | raise ValueError("Skewness calculation requires at least 3 data points.")
350 |
351 | mean = np.mean(data)
352 | std_dev = np.std(data, ddof=1)
353 | skewness = (n / ((n - 1) * (n - 2))) * np.sum(((data - mean) / std_dev) ** 3)
354 | print(skewness)
355 | return skewness < threshold
356 |
357 | return stopping_condition
358 |
359 |
360 | def mixed_resolution(
361 | split_condition_fn: Callable,
362 | stopping_condition_fn: Callable,
363 | dataset: str = 'geolife',
364 | min_resolution: int = 7,
365 | max_resolution: int = 9,
366 | input_dir: str = 'data/mixed_res',
367 | output_dir: str = 'data/mixed_res',
368 | max_iterations: int = 5,
369 | ):
370 | """
371 | Perform iterative hexagon refinement based on split and stopping conditions.
372 |
373 | This function reads preprocessed data including hexagon counts, unique hexagon sets,
374 | and neighbor relationships, and iteratively refines the hexagon set by splitting
375 | hexagons that meet the given `split_condition_fn`. The process halts when the
376 | `stopping_condition_fn` is satisfied or the maximum number of iterations is reached.
377 |
378 | Arguments:
379 | - split_condition_fn: Callable, a function that takes the current resolution,
380 | hexagon count, and neighbor statistics and returns a boolean
381 | indicating whether the hexagon should be split.
382 | - stopping_condition_fn: Callable, a function that takes a dictionary of hexagons
383 | and their counts and returns a boolean indicating whether
384 | the stopping condition is met.
385 | - dataset: str, the name of the dataset to process (e.g., 'geolife').
386 | - min_resolution: int, the initial resolution of hexagons to start processing from.
387 | - max_resolution: int, the maximum resolution allowed for splitting hexagons.
388 | - input_dir: str, the directory path to load the preprocessed input files.
389 | - output_dir: str, the directory path to save any output if necessary.
390 | - max_iterations: int, the maximum number of iterations for the process.
391 |
392 | Process:
393 | 1. Load preprocessed neighbor relationships, hexagon counts, and unique hexagon sets
394 | from pickle files.
395 | 2. Initialize the hexagon set from the `min_resolution` level.
396 | 3. Iterate up to `max_iterations` times, splitting hexagons that meet the `split_condition_fn`.
397 | 4. Check the `stopping_condition_fn` at each iteration to potentially stop the process early.
398 | 5. Split marked hexagons into their children at the next higher resolution.
399 | 6. Print status messages and information about each iteration.
400 |
401 | Outputs:
402 | - Prints the number of hexagons processed at each iteration.
403 | - Stops the process when the stopping condition is met or no hexagons meet the split condition.
404 | """
405 | os.makedirs(output_dir, exist_ok=True)
406 |
407 | with open(f'{input_dir}/neighbors_{dataset}.pkl', 'rb') as f:
408 | neighbors = pickle.load(f)
409 |
410 | with open(f'{input_dir}/hex_count_{dataset}.pkl', 'rb') as f:
411 | hex_counts = pickle.load(f)
412 |
413 | with open(f'{input_dir}/hexes_{dataset}.pkl', 'rb') as f:
414 | hexes = pickle.load(f)
415 |
416 | hexagon_set = hexes[min_resolution]
417 | iteration = 0
418 |
419 | while iteration < max_iterations:
420 | iteration += 1
421 | print(f"Iteration {iteration}: Dataset has {len(hexagon_set)} hexagons")
422 | if stopping_condition_fn({hexagon: hex_counts[hexagon] for hexagon in hexagon_set}):
423 | print(f"Stopping condition invoked at iteration {iteration}")
424 | break
425 |
426 | marked_for_split = set()
427 |
428 | for hex_id in hexagon_set:
429 | current_res = h3.get_resolution(hex_id)
430 | if current_res == max_resolution:
431 | continue
432 | neighbors_stat = {neighbor: hex_counts[neighbor] for neighbor in neighbors[hex_id]}
433 | hex_count = hex_counts[hex_id]
434 |
435 | if split_condition_fn(current_res, hex_count, neighbors_stat):
436 | marked_for_split.add(hex_id)
437 |
438 | if len(marked_for_split) == 0:
439 | print(f"None of hexagons met split condition at iteration {iteration}")
440 | break
441 | else:
442 | print(f"{len(marked_for_split)} of hexagons met split condition at iteration {iteration}")
443 |
444 | new_hexagon_set = set()
445 | for hexagon in hexagon_set:
446 | if hexagon in marked_for_split:
447 | new_hexagon_set.update(h3.cell_to_children(hexagon))
448 | else:
449 | new_hexagon_set.add(hexagon)
450 | hexagon_set = new_hexagon_set
451 |
452 | with open(f'{output_dir}/final_hexes_{dataset}.pkl', 'wb') as f:
453 | pickle.dump(hexagon_set, f)
454 | print("Saved final hexagons set pickle")
455 |
456 | visualize({hexagon: h3.get_resolution(hexagon) for hexagon in hexagon_set}, "mixed-res-heatmap.pdf", categorical=True, cmap=ListedColormap(["#66CDAA", "#9370DB", "#86DAE3"]), legend_kwds={"loc": "upper right", "title":"Resolution", "markerscale":2.5},)
457 | visualize({hexagon: hex_counts[hexagon] for hexagon in hexes[min_resolution]}, "mixed-res-map.pdf", bins=[0,500,10000,np.inf], cmap=ListedColormap(["#66CDAA", "#9370DB", "#86DAE3"]), categorical=True, k=10, legend_kwds={"loc": "upper right", "title":"Movement Density", "markerscale":2.5,},)
458 | print("Heatmap saved")
459 |
460 |
461 | def apply_processing(
462 | dataset: str = 'geolife',
463 | min_resolution: int = 7,
464 | max_resolution: int = 9,
465 | output_dir: str = 'data/mixed_res',
466 | date_column: str = "time"
467 | ):
468 |
469 | with open(f'{output_dir}/{dataset}.pkl', 'rb') as f:
470 | hexes = pickle.load(f)
471 |
472 | with open(f'{output_dir}/final_hexes_{dataset}.pkl', 'rb') as f:
473 | hexagon_set = pickle.load(f)
474 |
475 |
476 | boundary = {
477 | 'geolife':{
478 | "Min_lat":39.50000000,
479 | "Max_lat":40.50000000,
480 | "Min_lon":116.00000000,
481 | "Max_lon":117.00000000
482 | },
483 | 'rome':{
484 | "Min_lat":41.793710,
485 | "Max_lat":41.991390,
486 | "Min_lon":12.372598,
487 | "Max_lon":12.622537
488 | },
489 | 'porto': {
490 | "Min_lat": -1000,
491 | "Max_lat": 1000,
492 | "Min_lon": -1000,
493 | "Max_lon": 1000
494 | }
495 | }
496 |
497 | def gps_to_mixed_h3(gps_points: list, boundary: dict):
498 | result = []
499 | for lon, lat in gps_points:
500 | if boundary["Min_lon"] <= lon and boundary["Min_lat"] <= lat and \
501 | boundary["Max_lon"] >= lon and boundary["Max_lat"] >= lat:
502 | for resolution in range(min_resolution, max_resolution+1):
503 | cell = h3.latlng_to_cell(lat, lon, resolution)
504 | if cell in hexagon_set:
505 | if len(result) == 0 or cell != result[-1]:
506 | result.append(cell)
507 | break
508 | return result
509 |
510 | hexes['points'] = hexes['points'].apply(gps_to_mixed_h3, args=(boundary[dataset],))
511 |
512 | output_dir = Path(output_dir)
513 | (output_dir / dataset).mkdir(parents=True, exist_ok=True)
514 |
515 | vocab = ["EOT"] + list(hexagon_set)
516 | vocab_file_path = output_dir / dataset / 'vocab.txt'
517 | with vocab_file_path.open('w', encoding='utf-8') as vocab_file:
518 | vocab_file.write("\n".join(vocab) + "\n")
519 |
520 | mapping = {k: v for v, k in enumerate(vocab)}
521 | mapping_file_path = output_dir / dataset / 'mapping.json'
522 | with mapping_file_path.open('w', encoding='utf-8') as mapping_file:
523 | json.dump(mapping, mapping_file, ensure_ascii=False)
524 |
525 | hexes['points'] = hexes['points'].apply(lambda x: [str(mapping[j]) for j in x])
526 | df_mapped = hexes.sort_values(by=[date_column])["points"].to_list()
527 | data_file_path = output_dir / dataset / 'data.txt'
528 | with data_file_path.open('w', encoding='utf-8') as data_file:
529 | for item in df_mapped:
530 | data_file.write(' '.join(item) + f" {mapping['EOT']}\n")
531 |
532 |
533 | def main() -> None:
534 | """
535 | Main function to handle argument parsing and execute data preprocessing.
536 | """
537 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning')
538 | parser.add_argument('dataset', type=str, help='Dataset')
539 | args = parser.parse_args()
540 |
541 | preprocess_resolution(dataset=args.dataset)
542 | mixed_resolution(
543 | split_condition_fn=threshold_split_condition(150),
544 | stopping_condition_fn=skewness_stopping_condition(0.2),
545 | dataset=args.dataset
546 | )
547 | apply_processing(dataset=args.dataset)
548 |
549 |
550 | if __name__ == '__main__':
551 | main()
552 |
--------------------------------------------------------------------------------
/TrajLearn/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import inspect
3 | from dataclasses import dataclass
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # @torch.jit.script
11 |
12 | class CausalSelfAttention(nn.Module):
13 | """Self-attention module with causal masking."""
14 |
15 | def __init__(self, config):
16 | super().__init__()
17 | assert config.n_embd % config.n_head == 0, "Embedding size must be divisible by the number of heads."
18 | self.n_head = config.n_head
19 | self.n_embd = config.n_embd
20 | self.dropout = config.dropout
21 | self.c_attn = nn.Linear(
22 | config.n_embd, 3 * config.n_embd, bias=config.bias)
23 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
24 | self.attn_dropout = nn.Dropout(config.dropout)
25 | self.resid_dropout = nn.Dropout(config.dropout)
26 | self.register_buffer(
27 | "bias",
28 | torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size)
29 | )
30 |
31 | def forward(self, x: torch.Tensor) -> torch.Tensor:
32 | B, T, C = x.size()
33 |
34 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
35 | k = k.view(B, T, self.n_head, C //
36 | self.n_head).transpose(1, 2) # (B, nh, T, hs)
37 | q = q.view(B, T, self.n_head, C //
38 | self.n_head).transpose(1, 2) # (B, nh, T, hs)
39 | v = v.view(B, T, self.n_head, C //
40 | self.n_head).transpose(1, 2) # (B, nh, T, hs)
41 |
42 | # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
43 |
44 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
45 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
46 | att = F.softmax(att, dim=-1)
47 | att = self.attn_dropout(att)
48 | y = att @ v
49 | y = y.transpose(1, 2).contiguous().view(B, T, C)
50 |
51 | y = self.resid_dropout(self.c_proj(y))
52 | return y
53 |
54 |
55 | class MLP(nn.Module):
56 | """Multilayer Perceptron with GELU activation."""
57 |
58 | def __init__(self, config):
59 | super().__init__()
60 | self.c_fc = nn.Linear(
61 | config.n_embd, 4 * config.n_embd, bias=config.bias)
62 | self.c_proj = nn.Linear(
63 | 4 * config.n_embd, config.n_embd, bias=config.bias)
64 | self.dropout = nn.Dropout(config.dropout)
65 |
66 | def forward(self, x: torch.Tensor) -> torch.Tensor:
67 | x = self.c_fc(x)
68 | x = F.gelu(x, approximate='tanh')
69 | x = self.c_proj(x)
70 | return self.dropout(x)
71 |
72 |
73 | class Block(nn.Module):
74 | """Transformer block consisting of attention and MLP layers."""
75 |
76 | def __init__(self, config):
77 | super().__init__()
78 | self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
79 | self.attn = CausalSelfAttention(config)
80 | self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
81 | self.mlp = MLP(config)
82 |
83 | def forward(self, x: torch.Tensor) -> torch.Tensor:
84 | x = x + self.attn(self.ln_1(x))
85 | return x + self.mlp(self.ln_2(x))
86 |
87 |
88 | @dataclass
89 | class ModelConfig:
90 | block_size: int = 1024
91 | vocab_size: int = 50304
92 | n_layer: int = 12
93 | n_head: int = 12
94 | n_embd: int = 768
95 | dropout: float = 0.0
96 | bias: bool = True
97 |
98 |
99 | class CausalLM(nn.Module):
100 | """Causal Language Model with optional custom initialization."""
101 |
102 | def __init__(self, config, custom_init=None):
103 | super().__init__()
104 | assert config.vocab_size is not None
105 | assert config.block_size is not None
106 | self.config = config
107 |
108 | self.transformer = nn.ModuleDict(dict(
109 | wte=nn.Embedding(config.vocab_size, config.n_embd),
110 | wpe=nn.Embedding(config.block_size, config.n_embd),
111 | drop=nn.Dropout(config.dropout),
112 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
113 | ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
114 | ))
115 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
116 | self.transformer.wte.weight = self.lm_head.weight
117 |
118 | self.apply(self._init_weights)
119 | for name, param in self.named_parameters():
120 | if name.endswith('c_proj.weight'):
121 | nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
122 | if custom_init is not None:
123 | self.transformer.wte.weight.data.copy_(custom_init)
124 |
125 | def _init_weights(self, module: nn.Module):
126 | if isinstance(module, (nn.Linear, nn.Embedding)):
127 | nn.init.normal_(module.weight, mean=0.0, std=0.02)
128 | if hasattr(module, 'bias') and module.bias is not None:
129 | nn.init.zeros_(module.bias)
130 |
131 | def forward(self, idx, targets=None):
132 | device = idx.device
133 | b, t = idx.size()
134 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
135 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # (1, t)
136 |
137 | tok_emb = self.transformer.wte(idx) # (b, t, n_embd)
138 | pos_emb = self.transformer.wpe(pos) # (1, t, n_embd)
139 | x = self.transformer.drop(tok_emb + pos_emb)
140 | for block in self.transformer.h:
141 | x = block(x)
142 | x = self.transformer.ln_f(x)
143 | logits = self.lm_head(x[:, [-1], :])
144 |
145 | if targets is not None:
146 | loss = F.cross_entropy(
147 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
148 | else:
149 | loss = None
150 |
151 | return logits, loss
152 |
153 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
154 | """
155 | This long function is unfortunately doing something very simple and is being very defensive:
156 | We are separating out all parameters of the model into two buckets: those that will experience
157 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
158 | We are then returning the PyTorch optimizer object.
159 | """
160 |
161 | decay = set()
162 | no_decay = set()
163 | whitelist_weight_modules = (nn.Linear, )
164 | blacklist_weight_modules = (
165 | nn.LayerNorm, nn.Embedding)
166 | for mn, m in self.named_modules():
167 | for pn, p in m.named_parameters():
168 | fpn = '%s.%s' % (mn, pn) if mn else pn
169 | if pn.endswith('bias'):
170 | no_decay.add(fpn)
171 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
172 | if fpn != 'lm_head.weight':
173 | decay.add(fpn)
174 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
175 | no_decay.add(fpn)
176 |
177 | param_dict = {pn: p for pn, p in self.named_parameters()}
178 | optim_groups = [
179 | {"params": [param_dict[pn] for pn in sorted(
180 | list(decay))], "weight_decay": weight_decay},
181 | {"params": [param_dict[pn]
182 | for pn in sorted(list(no_decay))], "weight_decay": 0.0},
183 | ]
184 | use_fused = (device_type == 'cuda') and (
185 | 'fused' in inspect.signature(torch.optim.AdamW).parameters)
186 | extra_args = dict(fused=True) if use_fused else dict()
187 | optimizer = torch.optim.AdamW(
188 | optim_groups, lr=learning_rate, betas=betas, **extra_args)
189 |
190 | return optimizer
191 |
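# --- Illustrative usage sketch (hyperparameters are small toy values, not the paper's settings) ---
if __name__ == '__main__':
    cfg = ModelConfig(block_size=24, vocab_size=1000, n_layer=2, n_head=2,
                      n_embd=64, dropout=0.0, bias=False)
    model = CausalLM(cfg)
    idx = torch.randint(1, cfg.vocab_size, (4, 12))  # batch of 4 token sequences of length 12
    logits, loss = model(idx)  # logits: (4, 1, vocab_size); loss is None when no targets are given
    print(logits.shape, loss)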
--------------------------------------------------------------------------------
/TrajLearn/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from pathlib import Path
4 | from typing import List, Dict, Optional
5 | import numpy as np
6 | import pandas as pd
7 | import h3
8 |
9 | def generate_embeddings(
10 | vocab: list,
11 | embedding_dim: int,
12 | mean: float = 0,
13 | std: float = 0.02,
14 | projection_matrix: np.ndarray = None,
15 | random_seed: int = 10,
16 | ) -> np.ndarray:
17 | """
18 | Generates embeddings for a given vocabulary based on axial coordinates.
19 |
20 | Args:
21 | vocab (list): A list of H3 hex addresses representing the vocabulary for which embeddings will be generated.
22 | embedding_dim (int): The dimension of the generated embedding vectors.
23 | mean (float, optional): The mean of the normal distribution used for embedding normalization. Default is 0.
24 | std (float, optional): The standard deviation of the normal distribution used for embedding normalization. Default is 0.02.
25 | projection_matrix (np.ndarray, optional): A 2D numpy array used for projecting the axial coordinates. If None, a random matrix will be generated.
26 | """
27 | origin_hex = vocab[0]
28 | base_i, base_j = h3.cell_to_local_ij(origin_hex, origin_hex)
29 |
30 | axial_coordinates = []
31 | for h3_hex in vocab:
32 | target_i, target_j = h3.cell_to_local_ij(origin_hex, h3_hex)
33 | q, r = target_i - base_i, target_j - base_j
34 | axial_coordinates.append((q,r))
35 |
36 | np.random.seed(random_seed)
37 | if projection_matrix is None:
38 | projection_matrix = np.random.randn(2, embedding_dim)
39 |
40 | projected_embedding = np.dot(axial_coordinates, projection_matrix)
41 |
42 | # standardized_embedding = (projected_embedding - np.mean(projected_embedding)) / np.std(projected_embedding)
43 |
44 | # normalized_embedding = standardized_embedding * std + mean
45 |
46 | eot_embedding = np.random.normal(loc=mean, scale=std, size=(1, embedding_dim))
47 |
48 | normal_samples = np.random.normal(loc=mean, scale=std, size=(len(vocab) * embedding_dim))
49 | flat_projected_embedding = projected_embedding.flatten()
50 | sorted_indices = np.argsort(flat_projected_embedding)
51 | sorted_projected_embedding = np.empty_like(flat_projected_embedding)
52 | sorted_projected_embedding[sorted_indices] = np.sort(normal_samples)
53 | sorted_projected_embedding = sorted_projected_embedding.reshape(projected_embedding.shape)
54 |
55 | return np.concatenate((eot_embedding, sorted_projected_embedding), axis=0)
56 |
57 |
58 | def process_datasets(
59 | input_dir: Path,
60 | output_dir: Path,
61 | datasets_to_process: List[str],
62 | embedding_dim: Optional[int]
63 | ) -> None:
64 | """
65 | Process trajectory datasets for geolife, porto, and rome, generating vocab, mapping, neighbors,
66 | and transformed trajectory data.
67 |
68 | Args:
69 | input_dir (Path): Directory containing the input datasets.
70 | output_dir (Path): Directory where the processed data will be saved.
71 |         datasets_to_process (List[str]): List of datasets to process (e.g., 'geolife', 'porto', 'rome'). embedding_dim (Optional[int]): Dimension of the generated embeddings; if None, embedding generation is skipped.
72 | """
73 | datasets = {
74 | "geolife": [
75 | ("geolife7", input_dir / "geolife" / "ho_geolife_res7.csv", "date"),
76 | ("geolife8", input_dir / "geolife" / "ho_geolife_res8.csv", "date"),
77 | ("geolife9", input_dir / "geolife" / "ho_geolife_res9.csv", "date")
78 | ],
79 | "porto": [
80 | ("porto7", input_dir / "porto" / "ho_porto_res7.csv", "TIMESTAMP"),
81 | ("porto8", input_dir / "porto" / "ho_porto_res8.csv", "TIMESTAMP"),
82 | ("porto9", input_dir / "porto" / "ho_porto_res9.csv", "TIMESTAMP")
83 | ],
84 | "rome": [
85 | ("rome7", input_dir / "rome" / "ho_rome_res7.csv", "date"),
86 | ("rome8", input_dir / "rome" / "ho_rome_res8.csv", "date"),
87 | ("rome9", input_dir / "rome" / "ho_rome_res9.csv", "date")
88 | ]
89 | }
90 |
91 | for dataset_key in datasets_to_process:
92 | if dataset_key in datasets:
93 | for dataset in datasets[dataset_key]:
94 | dataset_name, file_path, date_column = dataset
95 |
96 | dataset_output_dir = output_dir / dataset_name
97 | dataset_output_dir.mkdir(parents=True, exist_ok=True)
98 |
99 | if not file_path.exists():
100 | print(f"Warning: {file_path} does not exist. Skipping this dataset.")
101 | continue
102 |
103 | df = pd.read_csv(file_path, header=0, usecols=["higher_order_trajectory", date_column],
104 | dtype={"higher_order_trajectory": "string", date_column: "string"})
105 | df = df.sort_values(by=[date_column])["higher_order_trajectory"].to_numpy()
106 |
107 | df_split = [i.split() for i in df]
108 |
109 | vocab = list(np.unique(np.concatenate(df_split, axis=0)))
110 |
111 | if embedding_dim is not None:
112 | embeddings = generate_embeddings(vocab, embedding_dim)
113 | embeddings_file_path = dataset_output_dir / 'embeddings.npy'
114 | np.save(embeddings_file_path, embeddings)
115 |
116 | vocab = ["EOT"] + vocab
117 | vocab_file_path = dataset_output_dir / 'vocab.txt'
118 | with vocab_file_path.open('w', encoding='utf-8') as vocab_file:
119 | vocab_file.write("\n".join(vocab) + "\n")
120 |
121 | mapping = {k: v for v, k in enumerate(vocab)}
122 | mapping_file_path = dataset_output_dir / 'mapping.json'
123 | with mapping_file_path.open('w', encoding='utf-8') as mapping_file:
124 | json.dump(mapping, mapping_file, ensure_ascii=False)
125 |
126 | neighbors: Dict[int, List[int]] = dict()
127 | for x in vocab[1:]:
128 | neighbors[mapping[str(x)]] = [mapping[i] for i in h3.grid_ring(str(x)) if i in vocab]
129 | neighbors_file_path = dataset_output_dir / 'neighbors.json'
130 | with neighbors_file_path.open('w', encoding='utf-8') as neighbors_file:
131 | json.dump(neighbors, neighbors_file, ensure_ascii=False)
132 |
133 | df_mapped = [[str(mapping[j]) for j in i] for i in df_split]
134 | data_file_path = dataset_output_dir / 'data.txt'
135 | with data_file_path.open('w', encoding='utf-8') as data_file:
136 | for item in df_mapped:
137 | data_file.write(' '.join(item) + f" {mapping['EOT']}\n")
138 |
139 | print(f"Processing completed for {dataset_name}.")
140 |
141 | if __name__ == '__main__':
142 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning for geolife, porto, and rome datasets')
143 |
144 | parser.add_argument('--input_dir', type=Path, default=Path('data'),
145 | help='Path to input dataset files (default: ./data)')
146 |
147 | parser.add_argument('--output_dir', type=Path, default=Path('data'),
148 | help='Path to output directory (default: ./data)')
149 |
150 | parser.add_argument('--datasets', type=str, nargs='+', choices=['geolife', 'porto', 'rome'], required=True,
151 | help='Specify which datasets to process (choose from geolife, porto, rome)')
152 |
153 | parser.add_argument('--embedding_dim', type=int,
154 | help="Dimension of the generated embedding vectors. If not provided, embeddings will not be generated.")
155 |
156 | args = parser.parse_args()
157 |
158 | process_datasets(args.input_dir, args.output_dir, args.datasets, args.embedding_dim)
159 |
--------------------------------------------------------------------------------
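A quick way to sanity-check `generate_embeddings` is to call it on a small, hand-built H3 vocabulary. The sketch below is illustrative only (the location, resolution, and embedding size are arbitrary) and assumes the h3 v4 API pinned in `environment.yml`:

```python
import h3
from TrajLearn.preprocess import generate_embeddings

origin = h3.latlng_to_cell(41.15, -8.61, 8)       # an arbitrary resolution-8 cell
vocab = [origin] + list(h3.grid_ring(origin, 1))  # the cell plus its ring-1 neighbours
embeddings = generate_embeddings(vocab, embedding_dim=16)
print(embeddings.shape)  # (len(vocab) + 1, 16); row 0 is the EOT embedding
```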
/TrajLearn/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import torch
5 | from torch.optim import Optimizer
6 | from tqdm import tqdm
7 | from typing import Optional, Tuple
8 |
9 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset
10 |
11 |
12 | class Trainer:
13 | """
14 | Trainer class to handle the training and validation of a given model on trajectory datasets.
15 | """
16 | def __init__(self, model: torch.nn.Module, dataset: TrajectoryBatchDataset, config: dict,
17 | logger, model_checkpoint_directory: str, always_save_checkpoint: bool = False,
18 | optimizer: Optional[Optimizer] = None):
19 | """
20 | Initialize the Trainer class with model, dataset, configurations, and other options.
21 |
22 | Args:
23 | model (torch.nn.Module): The model to train.
24 | dataset (TrajectoryBatchDataset): The dataset to use for training.
25 | config (dict): Configuration dictionary with training parameters.
26 | logger: Logger for logging training progress.
27 | model_checkpoint_directory (str): Directory to save model checkpoints.
28 | always_save_checkpoint (bool): Whether to save a checkpoint every epoch. Defaults to False.
29 | optimizer (Optional[Optimizer]): Optimizer to use. If None, a default optimizer is configured.
30 | """
31 | self.model = model
32 | self.logger = logger
33 | self.train_dataset = dataset
34 | self.always_save_checkpoint = always_save_checkpoint
35 | self.device = config["device"]
36 | self.device_type = 'cuda' if 'cuda' in self.device else 'cpu'
37 | self.config = config
38 | self.out_dir = model_checkpoint_directory
39 | self.max_epochs = config["max_epochs"]
40 | self.block_size = config["block_size"]
41 | self.batch_size = config["batch_size"]
42 | self.min_input_length = config["min_input_length"]
43 | self.max_input_length = config["max_input_length"]
44 | self.learning_rate = config["learning_rate"]
45 | self.weight_decay = config["weight_decay"]
46 | self.beta1 = config["beta1"]
47 | self.beta2 = config["beta2"]
48 | self.grad_clip = config["grad_clip"]
49 | self.decay_lr = config["decay_lr"]
50 | self.warmup_iters = config["warmup_iters"]
51 | self.lr_decay_iters = config["lr_decay_iters"]
52 | self.min_lr = config["min_lr"]
53 | self.patience = config["patience"]
54 | self.early_stopping_counter = 0
55 |
56 | dtype = 'float32'
57 | ptdtype = {
58 | 'float32': torch.float32,
59 | 'bfloat16': torch.bfloat16,
60 | 'float16': torch.float16
61 | }[dtype]
62 | self.ctx = torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
63 | self.scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
64 |
65 | if optimizer is None:
66 | self.optimizer = model.configure_optimizers(
67 | self.weight_decay, self.learning_rate, (self.beta1, self.beta2), self.device_type
68 | )
69 | else:
70 | self.optimizer = optimizer
71 |
72 | input_lengths = list(range(self.min_input_length, self.max_input_length + 1))
73 | self.train_dataset.create_batches(self.batch_size, input_lengths)
74 |
75 | self.validation_dataset = TrajectoryBatchDataset(
76 | os.path.join(config["data_dir"], config["dataset"]),
77 | dataset_type='val',
78 | delimiter=config["delimiter"],
79 | validation_ratio=config["validation_ratio"],
80 | test_ratio=config["test_ratio"]
81 | )
82 | self.validation_dataset.create_batches(self.batch_size, self.min_input_length)
83 |
84 | @torch.no_grad()
85 | def val_epoch(self) -> Tuple[float, float]:
86 | """
87 | Run a single validation epoch.
88 |
89 | Returns:
90 | Tuple[float, float]: Average validation loss and accuracy.
91 | """
92 | self.model.eval()
93 | total_val_loss = 0
94 | total_correct = 0
95 | total_samples = 0
96 |
97 | for X, Y in tqdm(self.validation_dataset, leave=False):
98 | x = X.to(self.device)
99 | y = Y.to(self.device)
100 | with self.ctx:
101 | output, loss = self.model(x, y)
102 | total_val_loss += loss.item()
103 | total_correct += (output.argmax(dim=2)[:, -1] == y[:, -1]).sum().item()
104 | total_samples += Y.shape[0]
105 |
106 | avg_val_loss = total_val_loss / total_samples
107 | val_accuracy = total_correct / total_samples
108 | return avg_val_loss, val_accuracy
109 |
110 | def train(self):
111 | """
112 | Train the model for a specified number of epochs, validate, and save checkpoints.
113 | """
114 | iter_num = 0
115 | best_val_loss = float('inf')
116 | self.logger.info("Starting training")
117 |
118 | for epoch in range(self.max_epochs):
119 | self.model.train()
120 | t_epoch_start = time.time()
121 | total_loss = 0
122 | total_samples = 0
123 |
124 | for X, Y in (pbar := tqdm(self.train_dataset, leave=False)):
125 | iter_num += 1
126 | lr = self.get_lr(iter_num) if self.decay_lr else self.learning_rate
127 | for param_group in self.optimizer.param_groups:
128 | param_group['lr'] = lr
129 |
130 | with self.ctx:
131 | _, loss = self.model(X.to(self.device), Y.to(self.device))
132 | total_loss += loss.item()
133 |
134 | total_samples += X.shape[0]
135 |
136 | self.scaler.scale(loss).backward()
137 | if self.grad_clip != 0.0:
138 | self.scaler.unscale_(self.optimizer)
139 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
140 | self.scaler.step(self.optimizer)
141 | self.scaler.update()
142 | self.optimizer.zero_grad(set_to_none=True)
143 | pbar.set_postfix({'loss': total_loss / total_samples})
144 |
145 | dt = time.time() - t_epoch_start
146 | avg_loss = total_loss / total_samples
147 | self.logger.info(f"Training epoch {epoch + 1}/{self.max_epochs}, "
148 | f"Training loss: {avg_loss:.3g}, Time: {dt:.1f}s")
149 |
150 | t_val_start = time.time()
151 | avg_val_loss, val_accuracy = self.val_epoch()
152 | dt = time.time() - t_val_start
153 | self.logger.info(f'Validation loss: {avg_val_loss:.3g}, '
154 | f'Validation Accuracy: {val_accuracy * 100:.2f}%, Time: {dt:.1f}s')
155 |
156 | if avg_val_loss < best_val_loss:
157 | best_val_loss = avg_val_loss
158 | self.save_checkpoint()
159 | self.early_stopping_counter = 0
160 | else:
161 | self.early_stopping_counter += 1
162 | if self.early_stopping_counter >= self.patience:
163 | self.logger.info("Early stopping triggered.")
164 | break
165 |
166 | def get_lr(self, it: int) -> float:
167 | """
168 | Calculate learning rate with optional warmup and cosine decay.
169 |
170 | Args:
171 | it (int): Current iteration number.
172 |
173 | Returns:
174 | float: Calculated learning rate for the current iteration.
175 | """
176 | if it < self.warmup_iters:
177 | return self.learning_rate * it / self.warmup_iters
178 | if it == self.warmup_iters:
179 | self.logger.info("Warm-up iterations ended, starting cosine decay")
180 | if it == self.lr_decay_iters:
181 | self.logger.info("Decay iterations ended, using minimum learning rate")
182 | if it >= self.lr_decay_iters:
183 | return self.min_lr
184 |
185 | decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
186 | assert 0 <= decay_ratio <= 1
187 | coefficient = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
188 | return self.min_lr + coefficient * (self.learning_rate - self.min_lr)
189 |
190 | def save_checkpoint(self):
191 | """
192 | Save model and optimizer state as a checkpoint.
193 | """
194 | checkpoint = {
195 | 'model': self.model.state_dict(),
196 | 'optimizer': self.optimizer.state_dict(),
197 | 'config': self.config,
198 | }
199 | checkpoint_path = os.path.join(self.out_dir, 'checkpoint.pt')
200 | try:
201 | torch.save(checkpoint, checkpoint_path)
202 | self.logger.info("Saved current best model to " + checkpoint_path)
203 | except Exception as e:
204 | self.logger.error(f"Failed to save checkpoint: {e}")
205 |
--------------------------------------------------------------------------------
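`Trainer.get_lr` implements a linear warmup followed by cosine decay down to `min_lr`. The standalone sketch below reproduces the same schedule with illustrative hyperparameters so its shape can be inspected outside the training loop:

```python
import math

# Illustrative values; the real ones come from the per-dataset config entries.
learning_rate, min_lr = 5e-4, 1e-8
warmup_iters, lr_decay_iters = 200, 24000

def lr_at(it: int) -> float:
    if it < warmup_iters:                 # linear warmup from 0 up to learning_rate
        return learning_rate * it / warmup_iters
    if it >= lr_decay_iters:              # past the decay window: stay at the floor
        return min_lr
    ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # goes from 1 to 0 across the window
    return min_lr + coeff * (learning_rate - min_lr)

print(lr_at(100), lr_at(12100), lr_at(50000))  # mid-warmup, mid-decay (~lr/2), floor
```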
/TrajLearn/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import pickle
5 | import random
6 | from pathlib import Path
7 | from typing import Dict, Any, Optional
8 |
9 | import numpy as np
10 | import torch
11 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset
12 | from TrajLearn.model import ModelConfig, CausalLM
13 | from TrajLearn.evaluator import evaluate_model
14 | from TrajLearn.trainer import Trainer
15 | from TrajLearn.logger import get_logger
16 | from baselines import HigherOrderMarkovChain
17 |
18 |
19 | def setup_environment(seed: int) -> None:
20 | """
21 | Set up the environment by configuring CUDA and setting random seeds.
22 |
23 |     Args:
24 |     - seed (int): The seed for the Python, NumPy, and PyTorch (CPU and CUDA)
25 |       random number generators.
26 | """
27 |     torch.backends.cudnn.enabled = False
28 | torch.backends.cudnn.deterministic = True
29 |
30 | random.seed(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.backends.cuda.matmul.allow_tf32 = True
35 | torch.backends.cudnn.allow_tf32 = True
36 |
37 |
38 | def get_dataset(config: Dict[str, Any], test_mode: bool = False) -> TrajectoryBatchDataset:
39 | """
40 | Load the trajectory dataset based on configuration.
41 |
42 | Args:
43 | - config (Dict[str, Any]): Configuration dictionary.
44 | - test_mode (bool): Whether to load test or training data (default is False).
45 |
46 | Returns:
47 | - TrajectoryBatchDataset: The dataset object.
48 | """
49 | dataset_type = 'test' if test_mode else 'train'
50 | dataset_path = Path(config["data_dir"]) / config["dataset"]
51 | dataset = TrajectoryBatchDataset(
52 | dataset_path,
53 | dataset_type=dataset_type,
54 | delimiter=config["delimiter"],
55 | validation_ratio=config["validation_ratio"],
56 | test_ratio=config["test_ratio"]
57 | )
58 | config["vocab_size"] = dataset.vocab_size
59 | return dataset
60 |
61 |
62 | def load_model(model: torch.nn.Module | HigherOrderMarkovChain, checkpoint_path: Optional[Path], device: str) -> tuple:
63 |     """
64 |     Load a model from a checkpoint.
65 | 
66 |     Args:
67 |     - model (torch.nn.Module | HigherOrderMarkovChain): Model instance to load weights into.
68 |     - checkpoint_path (Optional[Path]): Path to the checkpoint file to load.
69 |     - device (str): Device to map checkpoint tensors to when loading.
70 | 
71 |     Returns:
72 |     - tuple: (model with loaded weights, checkpoint config, optimizer state dict or None).
73 | """
74 | if isinstance(model, HigherOrderMarkovChain):
75 | with open(checkpoint_path, 'rb') as f:
76 | checkpoint = pickle.load(f)
77 | optimizer = None
78 | else:
79 | checkpoint = torch.load(checkpoint_path, map_location=device)
80 | optimizer = checkpoint['optimizer']
81 | config = checkpoint['config']
82 | state_dict = checkpoint['model']
83 | unwanted_prefix = '_orig_mod.'
84 | for k, _ in list(state_dict.items()):
85 | if k.startswith(unwanted_prefix):
86 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
87 | model.load_state_dict(state_dict)
88 |
89 | return model, config, optimizer
90 |
91 | def initialize_model(config, custom_init=None):
92 | model_config = ModelConfig(
93 | block_size=config["block_size"],
94 | vocab_size=config["vocab_size"],
95 | n_layer=config["n_layer"],
96 | n_head=config["n_head"],
97 | n_embd=config["n_embd"],
98 | dropout=config["dropout"],
99 | bias=config["bias"]
100 | )
101 | return CausalLM(model_config, custom_init)
102 |
103 |
104 | def train_model(
105 | name: str,
106 | dataset: TrajectoryBatchDataset,
107 | config: Dict[str, Any],
108 | model: Optional[torch.nn.Module | HigherOrderMarkovChain] = None
109 | ) -> None:
110 | """
111 | Set up and execute the training process.
112 |
113 | Args:
114 | - name (str): Name for the current training session (used for saving logs/checkpoints).
115 | - dataset (TrajectoryBatchDataset): Dataset object for training.
116 | - config (Dict[str, Any]): Configuration dictionary.
117 | - model (Optional[torch.nn.Module]): The model to be trained (can be None before loading).
118 | """
119 | time_str = name + "-" + time.strftime("%Y%m%d-%H%M%S")
120 | model_checkpoint_directory = Path(config["model_checkpoint_directory"]) / time_str
121 | Path(model_checkpoint_directory).mkdir(parents=True, exist_ok=True)
122 | log_directory = model_checkpoint_directory / 'logs'
123 |
124 | if model is None:
125 | if config['custom_initialization']:
126 | custom_init_path = os.path.join(config["data_dir"], config["dataset"], 'embeddings.npy')
127 | embeddings_np = np.load(custom_init_path)
128 | custom_init = torch.from_numpy(embeddings_np).to(torch.float32)
129 | model = initialize_model(config, custom_init=custom_init)
130 | else:
131 | model = initialize_model(config=config)
132 |
133 | if config['train_from_checkpoint_if_exist']:
134 | model_checkpoints = sorted(glob.glob(str(Path(config["model_checkpoint_directory"]) / (name + "-*"))))
135 | if len(model_checkpoints) > 0:
136 | last_checkpoint = Path(model_checkpoints[-1]) / 'checkpoint.pt'
137 | model, config, optimizer = load_model(model, last_checkpoint, config['device'])
138 |
139 | logger = get_logger(log_directory, name, phase="train")
140 |
141 | if isinstance(model, HigherOrderMarkovChain):
142 | model.train(dataset, logger, str(model_checkpoint_directory))
143 | else:
144 | model.to(config["device"])
145 | trainer = Trainer(model, dataset, config, logger, str(model_checkpoint_directory))
146 | trainer.train()
147 |
148 |
149 | def test_model(name: str, dataset: TrajectoryBatchDataset, config: Dict[str, Any], model: Optional[torch.nn.Module] = None) -> list:
150 | """
151 | Set up and execute the testing process.
152 |
153 | Args:
154 | - name (str): Name of the configuration (used for loading the model checkpoint).
155 | - dataset (TrajectoryBatchDataset): Dataset object for testing.
156 | - config (Dict[str, Any]): Configuration dictionary.
157 | - model (Optional[torch.nn.Module]): The model to be tested (can be None before loading).
158 | """
159 |     model_checkpoint_directory = sorted(glob.glob(str(Path(config["model_checkpoint_directory"]) / (name + "-*"))))[-1]
160 | log_directory = Path(model_checkpoint_directory) / 'logs'
161 |
162 | logger = get_logger(log_directory, name, phase="test")
163 |
164 | if model is None:
165 | model = initialize_model(config)
166 |
167 | checkpoint_path = Path(model_checkpoint_directory) / 'checkpoint.pt'
168 | model, _, __ = load_model(model, checkpoint_path, config['device'])
169 |
170 | prediction_length = config["test_prediction_length"]
171 | dataset.create_batches(
172 | config["batch_size"], config["test_input_length"], prediction_length, False, False)
173 |
174 | if isinstance(model, HigherOrderMarkovChain):
175 | results = model.evaluate(dataset)
176 | logger.info(", ".join(results))
177 | return results
178 | else:
179 | model.to(config["device"])
180 | return evaluate_model(model, dataset, config, logger)
181 |
--------------------------------------------------------------------------------
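The prefix-stripping loop in `load_model` exists because a checkpoint saved from a model wrapped by `torch.compile` carries an `_orig_mod.` prefix in its state-dict keys. A small sketch of the effect (assuming `torch.compile` is usable on the current setup):

```python
import torch
from torch import nn

net = nn.Linear(4, 2)
compiled = torch.compile(net)           # wraps the module; state-dict keys gain a prefix
print(list(compiled.state_dict())[:2])  # e.g. ['_orig_mod.weight', '_orig_mod.bias']

# load_model undoes this so the keys match the un-compiled module again:
state_dict = {k.removeprefix("_orig_mod."): v for k, v in compiled.state_dict().items()}
net.load_state_dict(state_dict)
```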
/baselines/HigherOrderAttnLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 | class HigherOrderAttnLSTM(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"])
10 | self.lstm = nn.LSTM(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True)
11 | self.attention = nn.Linear(config["n_embd"], 1)
12 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"])
13 |
14 | def forward(self, x, targets=None):
15 | x = self.embedding(x)
16 | x = torch.squeeze(x, dim=2)
17 | lstm_outputs, _ = self.lstm(x)
18 | attention_scores = self.attention(lstm_outputs)
19 | attention_weights = torch.softmax(attention_scores.squeeze(dim=-1), dim=1)
20 | attended_outputs = torch.bmm(attention_weights.unsqueeze(dim=1), lstm_outputs)
21 | logits = self.lm_head(attended_outputs)
22 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None
23 | return logits, loss
24 |
25 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
26 | return optim.AdamW(self.parameters(), betas=betas, lr=learning_rate)
--------------------------------------------------------------------------------
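The baselines expect integer token ids with an extra trailing dimension, which is why `forward` calls `torch.squeeze(x, dim=2)`. A hypothetical smoke test with made-up config values (not a configuration shipped with the repository):

```python
import torch
from baselines import HigherOrderAttnLSTM

config = {"vocab_size": 100, "n_embd": 32, "n_layer": 2, "dropout": 0.0}
model = HigherOrderAttnLSTM(config)

x = torch.randint(0, 100, (4, 10, 1))  # (batch, sequence length, 1) token ids
y = torch.randint(0, 100, (4, 1))      # next-token target per sequence
logits, loss = model(x, y)
print(logits.shape, loss.item())       # logits: (4, 1, 100)
```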
/baselines/HigherOrderGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 | class HigherOrderGRU(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"])
10 | self.gru = nn.GRU(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True)
11 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"])
12 |
13 | def forward(self, x, targets=None):
14 | x = self.embedding(x)
15 | x = torch.squeeze(x, dim=2)
16 | x, _ = self.gru(x)
17 | logits = self.lm_head(x[:, [-1], :])
18 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None
19 | return logits, loss
20 |
21 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
22 | return optim.AdamW(self.parameters(), betas=betas, lr=learning_rate)
--------------------------------------------------------------------------------
/baselines/HigherOrderLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 |
6 | class HigherOrderLSTM(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"])
10 | self.lstm = nn.LSTM(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True)
11 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"])
12 |
13 | def forward(self, x, targets=None):
14 | x = self.embedding(x)
15 | x = torch.squeeze(x, dim=2)
16 | x, _ = self.lstm(x)
17 | logits = self.lm_head(x[:, [-1], :])
18 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None
19 | return logits, loss
20 |
21 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
22 | return optim.AdamW(self.parameters(), betas=betas, lr=learning_rate)
--------------------------------------------------------------------------------
/baselines/HigherOrderMarkovChain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pickle
4 | from collections import defaultdict
5 | from nltk.translate.bleu_score import sentence_bleu
6 | from tqdm import tqdm
7 |
8 | # Assuming that the TrajectoryBatchDataset class is defined as provided
9 | # and that the dataset instance is ready and create_batches() has been called.
10 |
11 | class HigherOrderMarkovChain:
12 | def __init__(self, config, order=1):
13 | self.order = order
14 | self.transition_counts = None
15 | self.transition_probs = None
16 | self.state_index_mapping = {}
17 | self.index_state_mapping = {}
18 | self.states = []
19 | self.num_states = 0
20 | self.config = config
21 | self.logger = None
22 |
23 |
24 | def train(self, dataset, logger, model_checkpoint_directory: str):
25 | """
26 |         Train the higher-order Markov chain on the sequences in the given dataset and save a checkpoint.
27 | 
28 |         :param dataset: Dataset whose `data` attribute holds the training sequences (lists of states).
29 | """
30 | logger.info("Building transition matrix...")
31 | self.logger = logger
32 | sequences = dataset.data
33 | self._build_state_mappings(sequences)
34 | self._build_transition_matrix(sequences)
35 | self.save_checkpoint(model_checkpoint_directory)
36 |
37 | def _build_state_mappings(self, sequences):
38 | # Build state to index mappings
39 | state_set = set()
40 | for sequence in sequences:
41 | state_set.update(sequence)
42 | self.states = list(state_set)
43 | self.num_states = len(self.states)
44 | self.state_index_mapping = {state: idx for idx, state in enumerate(self.states)}
45 | self.index_state_mapping = {idx: state for state, idx in self.state_index_mapping.items()}
46 |
47 |
48 | def _build_transition_matrix(self, sequences):
49 | # Initialize transition counts with ones for smoothing
50 | self.transition_counts = defaultdict(lambda: defaultdict(lambda: 1))
51 |
52 | for sequence in tqdm(sequences, desc="Processing training sequences"):
53 | for i in range(len(sequence) - self.order):
54 | current_state = tuple(sequence[i:i + self.order])
55 | next_state = sequence[i + self.order]
56 | self.transition_counts[current_state][next_state] += 1
57 |
58 | # Convert counts to probabilities
59 | self.transition_probs = {}
60 | for current_state, next_states in self.transition_counts.items():
61 | total = sum(next_states.values())
62 | self.transition_probs[current_state] = {state: count / total for state, count in next_states.items()}
63 |
64 | def predict_next_state(self, current_sequence):
65 | """
66 | Predict the next state(s) given the current sequence.
67 |
68 | :param current_sequence: List of the current sequence of states.
69 | :return: List of predicted next states, sorted by probability in descending order.
70 | """
71 | current_state = tuple(current_sequence[-self.order:])
72 | next_state_probs = self.transition_probs.get(current_state, {})
73 | if not next_state_probs:
74 | return []
75 | sorted_next_states = sorted(next_state_probs.items(), key=lambda item: item[1], reverse=True)
76 | return [state for state, prob in sorted_next_states[:5]] # Return top 5
77 |
78 | def predict_next_n_steps(self, sequence, n=5):
79 | """
80 | Predict the next n steps given the initial sequence.
81 |
82 | :param sequence: Initial sequence of states (list).
83 | :param n: Number of steps to predict.
84 | :return: List of lists containing predicted states at each step.
85 | """
86 | predictions = []
87 | current_sequence = sequence.copy()
88 | for _ in range(n):
89 | next_states = self.predict_next_state(current_sequence)
90 | if not next_states:
91 | break # If no next state is found
92 | predictions.append(next_states)
93 | current_sequence.append(next_states[0]) # Append the most probable next state
94 | return predictions
95 |
96 | def save_checkpoint(self, model_checkpoint_directory):
97 | """
98 | Save the trained model to a file.
99 |
100 |         :param model_checkpoint_directory: Directory in which 'checkpoint.pt' will be written.
101 | """
102 | model_data = {
103 | 'order': self.order,
104 | 'transition_probs': self.transition_probs,
105 | 'state_index_mapping': self.state_index_mapping,
106 | 'index_state_mapping': self.index_state_mapping,
107 | 'states': self.states,
108 | 'num_states': self.num_states
109 | }
110 | checkpoint = {
111 | 'model': model_data,
112 | 'config': self.config,
113 | }
114 | checkpoint_path = os.path.join(model_checkpoint_directory, 'checkpoint.pt')
115 | try:
116 | with open(checkpoint_path, 'wb') as f:
117 | pickle.dump(checkpoint, f)
118 | except Exception as e:
119 | self.logger.error(f"Failed to save checkpoint: {e}")
120 |
121 | def load_state_dict(self, model_data):
122 | """
123 |         Load previously saved model parameters into this instance.
124 | 
125 |         :param model_data: Dictionary of model parameters produced by save_checkpoint.
126 |         :return: None; the instance attributes are updated in place.
127 | """
128 | self.transition_probs = model_data['transition_probs']
129 | self.state_index_mapping = model_data['state_index_mapping']
130 | self.index_state_mapping = model_data['index_state_mapping']
131 | self.states = model_data['states']
132 | self.num_states = model_data['num_states']
133 | self.order = model_data['order']
134 |
135 | def evaluate(self, test_dataset):
136 | """
137 | Evaluate the model using the test sequences.
138 |
139 |         :param test_dataset: Dataset whose `data` attribute holds the test sequences.
140 |         :return: List of formatted result strings (Accuracy@1/3/5, BLEU score, duration, sample count).
141 | """
142 | print("Evaluating model...")
143 | total_predictions = 0
144 | hit_step5_at1 = 0
145 | hit_step5_at3 = 0
146 | hit_step5_at5 = 0
147 | bleu_scores = 0
148 |
149 | start_time = time.time()
150 | for test_sequence in tqdm(test_dataset.data, desc="Processing test sequences"):
151 | if len(test_sequence) < self.order + 5:
152 | continue
153 | for i in range(len(test_sequence) - (self.order + 5) + 1):
154 | original_sequence = test_sequence[i:i + self.order + 5] # Include the initial sequence and actual next steps
155 | observe_sequence = original_sequence[:self.order]
156 | actual_next_steps = original_sequence[self.order:self.order + 5]
157 |
158 | predicted_next_steps = self.predict_next_n_steps(observe_sequence, n=5)
159 |
160 | if len(predicted_next_steps) < 5:
161 | continue # Skip if predictions are incomplete
162 |
163 | actual_step5 = actual_next_steps[4]
164 | predicted_step5 = predicted_next_steps[4]
165 |
166 | if actual_step5 == predicted_step5[0]:
167 | hit_step5_at1 += 1
168 | if actual_step5 in predicted_step5[:3]:
169 | hit_step5_at3 += 1
170 | if actual_step5 in predicted_step5[:5]:
171 | hit_step5_at5 += 1
172 |
173 | # For BLEU score calculation
174 | predicted_sentence = [predicted_next_steps[j][0] for j in range(len(predicted_next_steps))]
175 | bleu_scores += sentence_bleu([actual_next_steps], predicted_sentence)
176 |
177 | total_predictions += 1
178 |
179 | test_duration = time.time() - start_time
180 |
181 | return [
182 | f"Accuracy@1: {hit_step5_at1 / total_predictions if total_predictions else 0}",
183 | f"Accuracy@3: {hit_step5_at3 / total_predictions if total_predictions else 0}",
184 | f"Accuracy@5: {hit_step5_at5 / total_predictions if total_predictions else 0}",
185 | f"BLEU score: {bleu_scores / total_predictions if total_predictions else 0}",
186 | f"Test duration: {test_duration:.3f}(s)",
187 | f"Samples: {total_predictions}"
188 | ]
189 |
--------------------------------------------------------------------------------
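A toy run of the Markov-chain baseline on hand-made sequences; it calls the internal `_build_*` helpers directly so no dataset object, logger, or checkpoint directory is needed (illustrative only, not the training path used by `main.py`):

```python
from baselines import HigherOrderMarkovChain

sequences = [[1, 2, 3, 4, 3, 2], [1, 2, 3, 2, 1, 2], [2, 3, 4, 5, 4, 3]]
mc = HigherOrderMarkovChain(config={}, order=1)
mc._build_state_mappings(sequences)
mc._build_transition_matrix(sequences)

print(mc.predict_next_state([1, 2]))       # candidate next states after 2, most likely first
print(mc.predict_next_n_steps([1, 2], 3))  # greedy three-step rollout, top-5 candidates per step
```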
/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | from .HigherOrderAttnLSTM import HigherOrderAttnLSTM
2 | from .HigherOrderLSTM import HigherOrderLSTM
3 | from .HigherOrderGRU import HigherOrderGRU
4 | from .HigherOrderMarkovChain import HigherOrderMarkovChain
5 |
6 | __all__ = ["HigherOrderAttnLSTM", "HigherOrderLSTM", "HigherOrderGRU", "HigherOrderMarkovChain"]
7 |
--------------------------------------------------------------------------------
/configs.yaml:
--------------------------------------------------------------------------------
1 | porto7:
2 | test_ratio: 0.2
3 | validation_ratio: 0.1
4 | delimiter: " "
5 | min_input_length: 10
6 | max_input_length: 14
7 | test_input_length: 10
8 | test_prediction_length: 5
9 | batch_size: 256
10 | device: cuda
11 | max_epochs: 5
12 | block_size: 24
13 | learning_rate: 5.e-3
14 | weight_decay: 8.e-3
15 | beta1: 0.9
16 | beta2: 0.95
17 | grad_clip: 1.0
18 | decay_lr: True
19 | warmup_iters: 100
20 | lr_decay_iters: 20000
21 | min_lr: 5.e-8
22 | seed: 42
23 | data_dir: ./data
24 | dataset: porto7
25 | n_layer: 12
26 | n_head: 4
27 | n_embd: 512
28 | bias: False
29 | dropout: 0
30 | model_checkpoint_directory: ./models/
31 | train_from_checkpoint_if_exist: False
32 | custom_initialization: False
33 | patience: 3
34 | porto8:
35 | test_ratio: 0.2
36 | validation_ratio: 0.1
37 | delimiter: " "
38 | min_input_length: 10
39 | max_input_length: 14
40 | test_input_length: 10
41 | test_prediction_length: 5
42 | batch_size: 256
43 | device: cuda
44 | max_epochs: 4
45 | block_size: 24
46 | learning_rate: 5.e-4
47 | weight_decay: 5.e-3
48 | beta1: 0.9
49 | beta2: 0.95
50 | grad_clip: 1.0
51 | decay_lr: True
52 | warmup_iters: 200
53 | lr_decay_iters: 40000
54 | min_lr: 1.e-8
55 | seed: 42
56 | data_dir: ./data
57 | dataset: porto8
58 | n_layer: 12
59 | n_head: 4
60 | n_embd: 512
61 | bias: False
62 | dropout: 0
63 | model_checkpoint_directory: ./models/
64 | train_from_checkpoint_if_exist: False
65 | custom_initialization: False
66 | patience: 3
67 | porto9:
68 | test_ratio: 0.2
69 | validation_ratio: 0.1
70 | delimiter: " "
71 | min_input_length: 10
72 | max_input_length: 14
73 | test_input_length: 10
74 | test_prediction_length: 5
75 | batch_size: 256
76 | device: cuda
77 | max_epochs: 3
78 | block_size: 24
79 | learning_rate: 1.e-3
80 | weight_decay: 5.e-3
81 | beta1: 0.9
82 | beta2: 0.95
83 | grad_clip: 1.0
84 | decay_lr: True
85 | warmup_iters: 400
86 | lr_decay_iters: 80000
87 | min_lr: 5.e-9
88 | seed: 42
89 | data_dir: ./data
90 | dataset: porto9
91 | n_layer: 12
92 | n_head: 4
93 | n_embd: 512
94 | bias: False
95 | dropout: 0
96 | model_checkpoint_directory: ./models/
97 | train_from_checkpoint_if_exist: False
98 | custom_initialization: False
99 | patience: 3
100 | rome7:
101 | test_ratio: 0.2
102 | validation_ratio: 0.1
103 | delimiter: " "
104 | min_input_length: 10
105 | max_input_length: 14
106 | test_input_length: 10
107 | test_prediction_length: 5
108 | batch_size: 256
109 | device: cuda
110 | max_epochs: 10
111 | block_size: 24
112 | learning_rate: 5.e-4
113 | weight_decay: 5.e-3
114 | beta1: 0.9
115 | beta2: 0.95
116 | grad_clip: 1.0
117 | decay_lr: True
118 | warmup_iters: 100
119 | lr_decay_iters: 10000
120 | min_lr: 5.e-8
121 | seed: 42
122 | data_dir: ./data
123 | dataset: rome7
124 | n_layer: 12
125 | n_head: 4
126 | n_embd: 512
127 | bias: False
128 | dropout: 0
129 | model_checkpoint_directory: ./models/
130 | train_from_checkpoint_if_exist: False
131 | custom_initialization: False
132 | patience: 3
133 | rome8:
134 | test_ratio: 0.2
135 | validation_ratio: 0.1
136 | delimiter: " "
137 | min_input_length: 10
138 | max_input_length: 14
139 | test_input_length: 10
140 | test_prediction_length: 5
141 | batch_size: 256
142 | device: cuda
143 | max_epochs: 10
144 | block_size: 24
145 | learning_rate: 5.e-4
146 | weight_decay: 5.e-3
147 | beta1: 0.9
148 | beta2: 0.95
149 | grad_clip: 1.0
150 | decay_lr: True
151 | warmup_iters: 200
152 | lr_decay_iters: 24000
153 | min_lr: 1.e-8
154 | seed: 42
155 | data_dir: ./data
156 | dataset: rome8
157 | n_layer: 12
158 | n_head: 4
159 | n_embd: 512
160 | bias: False
161 | dropout: 0
162 | model_checkpoint_directory: ./models/
163 | train_from_checkpoint_if_exist: False
164 | custom_initialization: False
165 | patience: 3
166 | rome9:
167 | test_ratio: 0.2
168 | validation_ratio: 0.1
169 | delimiter: " "
170 | min_input_length: 10
171 | max_input_length: 14
172 | test_input_length: 10
173 | test_prediction_length: 5
174 | batch_size: 256
175 | device: cuda
176 | max_epochs: 8
177 | block_size: 24
178 | learning_rate: 1.e-3
179 | weight_decay: 5.e-3
180 | beta1: 0.9
181 | beta2: 0.95
182 | grad_clip: 1.0
183 | decay_lr: True
184 | warmup_iters: 400
185 | lr_decay_iters: 60000
186 | min_lr: 5.e-9
187 | seed: 42
188 | data_dir: ./data
189 | dataset: rome9
190 | n_layer: 12
191 | n_head: 4
192 | n_embd: 512
193 | bias: False
194 | dropout: 0
195 | model_checkpoint_directory: ./models/
196 | train_from_checkpoint_if_exist: False
197 | custom_initialization: False
198 | patience: 3
199 | geolife7:
200 | test_ratio: 0.2
201 | validation_ratio: 0.1
202 | delimiter: " "
203 | min_input_length: 10
204 | max_input_length: 14
205 | test_input_length: 10
206 | test_prediction_length: 5
207 | batch_size: 256
208 | device: cuda
209 | max_epochs: 10
210 | block_size: 24
211 | learning_rate: 5.e-4
212 | weight_decay: 2.e-3
213 | beta1: 0.9
214 | beta2: 0.95
215 | grad_clip: 1.0
216 | decay_lr: True
217 | warmup_iters: 100
218 | lr_decay_iters: 12000
219 | min_lr: 2.e-8
220 | seed: 42
221 | data_dir: ./data
222 | dataset: geolife7
223 | n_layer: 12
224 | n_head: 4
225 | n_embd: 512
226 | bias: False
227 | dropout: 0
228 | model_checkpoint_directory: ./models/
229 | train_from_checkpoint_if_exist: False
230 | custom_initialization: False
231 | patience: 3
232 | geolife8:
233 | test_ratio: 0.2
234 | validation_ratio: 0.1
235 | delimiter: " "
236 | min_input_length: 10
237 | max_input_length: 14
238 | test_input_length: 10
239 | test_prediction_length: 5
240 | batch_size: 256
241 | device: cuda
242 | max_epochs: 10
243 | block_size: 24
244 | learning_rate: 5.e-4
245 | weight_decay: 5.e-3
246 | beta1: 0.9
247 | beta2: 0.95
248 | grad_clip: 1.0
249 | decay_lr: True
250 | warmup_iters: 200
251 | lr_decay_iters: 24000
252 | min_lr: 1.e-8
253 | seed: 42
254 | data_dir: ./data
255 | dataset: geolife8
256 | n_layer: 12
257 | n_head: 4
258 | n_embd: 512
259 | bias: False
260 | dropout: 0
261 | model_checkpoint_directory: ./models/
262 | train_from_checkpoint_if_exist: False
263 | custom_initialization: False
264 | patience: 3
265 | geolife9:
266 | test_ratio: 0.2
267 | validation_ratio: 0.1
268 | delimiter: " "
269 | min_input_length: 10
270 | max_input_length: 14
271 | test_input_length: 10
272 | test_prediction_length: 5
273 | batch_size: 256
274 | device: cuda
275 | max_epochs: 6
276 | block_size: 24
277 | learning_rate: 5.e-3
278 | weight_decay: 1.e-2
279 | beta1: 0.9
280 | beta2: 0.95
281 | grad_clip: 1.0
282 | decay_lr: True
283 | warmup_iters: 400
284 | lr_decay_iters: 60000
285 | min_lr: 1.e-9
286 | seed: 42
287 | data_dir: ./data
288 | dataset: geolife9
289 | n_layer: 12
290 | n_head: 4
291 | n_embd: 512
292 | bias: False
293 | dropout: 0
294 | model_checkpoint_directory: ./models/
295 | train_from_checkpoint_if_exist: False
296 | custom_initialization: False
297 | patience: 3
--------------------------------------------------------------------------------
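Each top-level key in `configs.yaml` is one experiment, and `main.py` iterates over every entry in the file it is given. To inspect the entries directly with PyYAML (bypassing `TrajLearn.config_loader`, which `main.py` normally uses), something like this works:

```python
import yaml

with open("configs.yaml", encoding="utf-8") as f:
    configs = yaml.safe_load(f)

print(list(configs))  # experiment names: porto7, porto8, ..., geolife9
print(configs["porto8"]["learning_rate"], configs["porto8"]["n_layer"])
```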
/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script downloads datasets, extracts them, and prints MD5 checksums of the extracted files for manual verification.
4 |
5 | # Exit the script if any command fails
6 | set -e
7 |
8 | # Variables
9 | DATA_DIR="./data"
10 | BASE_URL="https://zenodo.org/records/8076553/files"
11 |
12 | echo "Starting download and extraction script..."
13 |
14 | # Check if arguments are provided
15 | if [ "$#" -eq 0 ]; then
16 | echo "No datasets specified. Please provide dataset names as arguments."
17 | exit 1
18 | fi
19 |
20 | # Create data directory if it doesn't exist
21 | echo "Checking for data directory..."
22 | if [ ! -d "$DATA_DIR" ]; then
23 | echo "Creating data directory at $DATA_DIR"
24 | mkdir -p "$DATA_DIR"
25 | else
26 | echo "Data directory already exists. Proceeding..."
27 | fi
28 |
29 | # Loop through all datasets provided as arguments
30 | for DATASET in "$@"; do
31 | echo "Processing dataset: $DATASET"
32 |
33 | # Define output zip and dataset URL
34 | OUTPUT_ZIP="${DATASET}.zip"
35 | DATASET_URL="${BASE_URL}/ho_${OUTPUT_ZIP}?download=1"
36 |
37 | # Download the dataset
38 | echo "Checking if the dataset is already downloaded..."
39 | if [ -f "$DATA_DIR/$OUTPUT_ZIP" ]; then
40 | echo "Dataset zip file already exists. Skipping download for $DATASET."
41 | else
42 | echo "Downloading dataset from $DATASET_URL..."
43 | wget --show-progress -O "$DATA_DIR/$OUTPUT_ZIP" "$DATASET_URL"
44 | echo "Download completed for $DATASET."
45 | fi
46 |
47 | # Extract the dataset
48 | echo "Checking if dataset is already extracted..."
49 | if [ -d "$DATA_DIR/$DATASET" ]; then
50 | echo "Dataset already extracted. Skipping extraction for $DATASET."
51 | else
52 | echo "Extracting dataset to $DATASET..."
53 | unzip -q "$DATA_DIR/$OUTPUT_ZIP" -d "$DATA_DIR/$DATASET"
54 | echo "Extraction completed for $DATASET."
55 | fi
56 |
57 | # Clean up the zip file
58 | if [ -f "$DATA_DIR/$OUTPUT_ZIP" ]; then
59 | echo "Cleaning up zip file for $DATASET..."
60 | rm -f "$DATA_DIR/$OUTPUT_ZIP"
61 | fi
62 |
63 | # Verify MD5 checksums
64 | echo "Verifying MD5 checksums for files in $DATA_DIR/$DATASET..."
65 | for file in "$DATA_DIR/$DATASET"/*; do
66 | if [ -f "$file" ]; then
67 | md5sum "$file"
68 | fi
69 | done
70 | echo "MD5 checksum verification completed for $DATASET."
71 | done
72 |
73 | echo "Script completed successfully for all datasets!"
74 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: trajlearn
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.10
7 | - pip:
8 | - transformers==4.44.1
9 | - numpy==1.26.4
10 | - pandas==2.2.2
11 | - h3==4.1.2
12 | - torch==2.3.0
13 | - geojson==3.0.1
14 | - geopandas==1.0.1
15 | - geopy==2.4.1
16 | - matplotlib==3.9.0
17 | - nltk==3.8.1
18 | - PyYAML==6.0.1
--------------------------------------------------------------------------------
/img/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amir-ni/Trajectory-prediction/570b89b4e0261056df3297487726b68a4983df9f/img/architecture.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from TrajLearn.utils import setup_environment, get_dataset, test_model, train_model
4 | from TrajLearn.config_loader import load_config
5 | from baselines import HigherOrderAttnLSTM, HigherOrderLSTM, HigherOrderGRU, HigherOrderMarkovChain
6 |
7 | def main() -> None:
8 | """
9 | Main function to handle argument parsing and execute training or testing.
10 | """
11 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning')
12 | parser.add_argument('config', type=str, help='Path to configuration file')
13 | parser.add_argument('--baseline', type=str, required=False, help='Baseline model to run')
14 | parser.add_argument('--test', default=False, action='store_true')
15 | args = parser.parse_args()
16 |
17 | config_list = load_config(args.config)
18 |
19 | for name, config in config_list.items():
20 | setup_environment(config["seed"])
21 |
22 | dataset = get_dataset(config, test_mode=args.test)
23 |
24 | if args.baseline:
25 | if args.baseline == "gru":
26 | model = HigherOrderGRU(config)
27 | elif args.baseline == "lstm":
28 | model = HigherOrderLSTM(config)
29 | elif args.baseline == "lstm-attn":
30 | model = HigherOrderAttnLSTM(config)
31 | elif args.baseline == "mc":
32 | model = HigherOrderMarkovChain(config)
33 | else:
34 | raise ValueError("Baseline not found.")
35 | else:
36 | model = None
37 |
38 | if args.test:
39 | test_model(name, dataset, config, model)
40 | else:
41 | train_model(name, dataset, config, model)
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------