├── .gitignore ├── LICENSE ├── README.md ├── TrajLearn ├── TrajectoryBatchDataset.py ├── config_loader.py ├── evaluator.py ├── logger.py ├── mixed_res.py ├── model.py ├── preprocess.py ├── trainer.py └── utils.py ├── baselines ├── HigherOrderAttnLSTM.py ├── HigherOrderGRU.py ├── HigherOrderLSTM.py ├── HigherOrderMarkovChain.py └── __init__.py ├── configs.yaml ├── download_data.sh ├── environment.yml ├── img └── architecture.jpg └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | data/ 3 | plots/ 4 | old/ 5 | visualization 6 | 7 | # OS 8 | .DS_Store 9 | 10 | # Byte-compiled / optimized / DLL files 11 | *.pyc 12 | *.pth 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # Jupyter Notebook 18 | .ipynb_checkpoints 19 | 20 | # Environments 21 | .env 22 | .venv 23 | env/ 24 | venv/ 25 | ENV/ 26 | env.bak/ 27 | venv.bak/ 28 | 29 | ### Visual Studio Code ### 30 | .vscode/* 31 | !.vscode/settings.json 32 | !.vscode/tasks.json 33 | !.vscode/launch.json 34 | !.vscode/extensions.json 35 | 36 | ## Data 37 | ./data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Amirhossein Nadiri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrajLearn: A Novel Model for Trajectory Prediction 2 | 3 | ## Overview 4 | 5 | *TrajLearn* is a transformer-based model designed to predict future trajectories using higher-order mobility flow representations (hexagonal grids). This model integrates a beam search variant to enhance spatial continuity, providing superior accuracy compared to existing methods. It is a powerful solution for trajectory prediction tasks in autonomous vehicles, robotics, and human motion analysis. 6 | 7 |

 8 | ![TrajLearn model architecture](img/architecture.jpg) 9 |
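For intuition, the hexagonal (H3) representation that TrajLearn consumes can be sketched as follows. This is a minimal illustration using the `h3` library (assuming h3 v4); the coordinates and resolution below are made up for the example and are not part of the repository:

```python
import h3

# A short GPS trace as (latitude, longitude) pairs -- illustrative values only.
trace = [(41.900, 12.490), (41.905, 12.496), (41.912, 12.503)]

# Map each point to its hexagon at resolution 7 and drop consecutive duplicates,
# turning the raw trace into the higher-order sequence of hexagon IDs the model learns from.
hex_sequence = []
for lat, lon in trace:
    cell = h3.latlng_to_cell(lat, lon, 7)
    if not hex_sequence or cell != hex_sequence[-1]:
        hex_sequence.append(cell)

print(hex_sequence)  # a list of H3 cell IDs; its length depends on how far apart the points are
```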

 10 | 11 | ## Getting Started 12 | 13 | ### Prerequisites 14 | 15 | - This implementation requires **Python version `>= 3.8`**. 16 | - Ensure you have a compatible Python environment. Refer to `environment.yml` for the required packages. 17 | 18 | ### Step-by-Step Instructions 19 | 20 | 0. **Clone the Repository** 21 | 22 | First, clone the repository to your computer: 23 | ```bash 24 | git clone https://github.com/amir-ni/Trajectory-prediction 25 | ``` 26 | 27 | 1. **Download the Datasets**: 28 | 29 | Make the dataset downloader script executable: 30 | ```bash 31 | chmod +x ./download_data.sh 32 | ``` 33 | Run the script to download the datasets. You can specify the datasets by passing them as arguments (`geolife`, `porto`, or `rome`). For example: 34 | ```bash 35 | ./download_data.sh geolife porto rome 36 | ``` 37 | 38 | 2. **Prepare the Datasets**: 39 | 40 | After downloading, run the following command to prepare and transform the datasets: 41 | ```bash 42 | python3 TrajLearn/preprocess.py --input_dir <input_dir> --output_dir <output_dir> --embedding_dim <embedding_dim> --datasets <datasets> 43 | ``` 44 | You can specify the `input_dir`, `output_dir`, `embedding_dim`, and the `datasets` to be processed: 45 | - **`--input_dir`**: Directory where the raw datasets are stored. Defaults to `./data`. 46 | - **`--output_dir`**: Directory where the transformed datasets will be saved. Defaults to `./data`. 47 | - **`--embedding_dim`**: Dimension of the generated embedding vectors. If not provided, embeddings will not be generated. 48 | - **`--datasets`**: Select which datasets to process (`geolife`, `porto`, `rome`). Multiple datasets can be processed by specifying them in a space-separated list. For example: 49 | ```bash 50 | python3 TrajLearn/preprocess.py --datasets rome geolife porto --embedding_dim 512 51 | ``` 52 | 53 | 3. **Set Up the Model Configuration**: 54 | 55 | The configuration of the model, such as batch size, learning rates, and dataset-specific settings, is passed to the model as a `yaml` configuration file. This file can also contain multiple configurations, in which case separate models will be trained sequentially. An example configuration, used to generate the results reported in the paper, can be found in `configs.yaml`. 56 | 57 | You can create or modify this file according to your needs. Some configurations are described below; additional configurations are listed at the end of this document. 58 | - **`data_dir`**: Directory where the dataset is stored. If you have not changed the default output directory in the previous steps, the address would be `./data`. 59 | - **`dataset`**: Name of the dataset being used, such as `rome7`. 60 | - **`model_checkpoint_directory`**: Directory path where model checkpoints will be saved during training. 61 | - **`min_input_length`**: The minimum length of input sequences used during training and testing. 62 | - **`max_input_length`**: The maximum length of input sequences allowed for model training. 63 | - **`test_input_length`**: Length of the input sequence during testing. 64 | - **`test_prediction_length`**: The number of future steps the model will predict during testing. 65 | 66 | After modifying these parameters to your requirements, save them in a `yaml` file. This file will be used during training and testing to control model behavior. 67 | 68 | 4. **Train the Model**: 69 | 70 | After configuring the model, you can start the training process. Use the following command: 71 | ```bash 72 | python3 main.py configs.yaml 73 | ``` 74 | 75 | This will train the model using the parameters specified in the `configs.yaml` file. Replace `configs.yaml` with the path to your own `yaml` file if you saved it elsewhere. 76 | 77 | 5. 
**Test the Model**: 78 | 79 | Once the model is trained, you can evaluate its performance by running: 80 | ```bash 81 | python3 main.py configs.yaml --test 82 | ``` 83 | 84 | This will test the trained model on the test split of the dataset. 85 | 86 | 87 | 88 | ### Additional Configuration Options: 89 | 90 | - **`test_ratio`**: Proportion of the dataset used for testing. For example, a value of `0.2` means 20% of the dataset will be used for testing. 91 | - **`validation_ratio`**: Proportion of the dataset used for validation. For example, a value of `0.1` means 10% of the dataset will be used for validation. 92 | - **`delimiter`**: The character that separates values in your dataset files (default is `" "`). 93 | - **`batch_size`**: The number of samples processed together in one forward/backward pass. 94 | - **`device`**: The computational device to use for training and testing. Set to `cuda` for GPU acceleration or `cpu` if no GPU is available. 95 | - **`max_epochs`**: The maximum number of training epochs, where one epoch means a complete pass through the entire dataset. 96 | - **`block_size`**: Block size used for processing sequences. Defines the length of the sequence chunks used during training and testing. 97 | - **`learning_rate`**: Initial learning rate for the optimizer. Adjust this to control how fast the model learns. 98 | - **`weight_decay`**: Regularization term to avoid overfitting by penalizing large weights. Higher values provide stronger regularization. 99 | - **`beta1`**: Beta1 hyperparameter for the AdamW optimizer, which controls the decay rate for the first moment estimate. 100 | - **`beta2`**: Beta2 hyperparameter for the AdamW optimizer, controlling the decay rate for the second moment estimate. 101 | - **`grad_clip`**: Threshold for gradient clipping. Gradients that exceed this value will be clipped to prevent exploding gradients. 102 | - **`decay_lr`**: Boolean flag to indicate whether the learning rate should be decayed over time. 103 | - **`warmup_iters`**: Number of iterations during which the learning rate will increase from a small value to the initial learning rate (used in learning rate scheduling). 104 | - **`lr_decay_iters`**: Number of iterations over which the learning rate decays. 105 | - **`min_lr`**: Minimum learning rate after decay. The learning rate will not decrease below this value. 106 | - **`seed`**: Random seed for reproducibility. Ensures that experiments can be replicated with the same results. 107 | - **`n_layer`**: Number of layers in the transformer model. More layers can increase model capacity but also computational cost. 108 | - **`n_head`**: Number of attention heads in the transformer model, which allows the model to focus on different parts of the input sequence simultaneously. 109 | - **`n_embd`**: Dimensionality of the embedding space. This represents the size of the vector representations for each token in the input sequence. 110 | - **`bias`**: Boolean flag to indicate whether to include bias terms in the model's layers. Set to `False` to exclude bias. 111 | - **`dropout`**: Dropout rate used for regularization. A value of `0` means no dropout will be applied. 112 | - **`custom_initialization`**: A boolean flag that specifies whether to use an axial-coordinate-based initialization for the model's training. 113 | - **`train_from_checkpoint_if_exist`**: A boolean flag that indicates whether to resume training from an existing checkpoint if one is found. 
114 | - **`patience`**: Integer value indicating the number of epochs to wait for before early stopping. 115 | - **`continuity`**: Boolean flag to enforce spatial continuity constraints on predictions. 116 | - **`beam_width`**: Integer specifying the beam width for beam search. 117 | - **`store_predictions`**: Boolean flag to enable or disable storing the predicted sequences. 118 | 119 | 120 | ## Contact 121 | 122 | This project was developed by [Amirhossein Nadiri](https://github.com/amir-ni). 123 | 124 | For any inquiries or collaboration, feel free to reach out at [anadiri@yorku.ca](mailto:anadiri@yorku.ca). 125 | 126 | ## License 127 | 128 | This project is open-source software licensed under the [LICENSE](LICENSE). 129 | 130 | ## Citation 131 | 132 | If you use this project or TrajLearn in your research, please consider citing it as follows: 133 | 134 | ```tex 135 | @article{TrajLearn, 136 | author = {Nadiri, Amirhossein and Li, Jing and Faraji, Ali and Abuoda, Ghadeer and Papagelis, Manos}, 137 | title = {TrajLearn: Trajectory Prediction Learning using Deep Generative Models}, 138 | year = {2025}, 139 | publisher = {Association for Computing Machinery}, 140 | address = {New York, NY, USA}, 141 | journal = {ACM Trans. Spatial Algorithms Syst.}, 142 | volume = {In Press}, 143 | numpages = {33}, 144 | } 145 | ``` 146 | -------------------------------------------------------------------------------- /TrajLearn/TrajectoryBatchDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from collections import defaultdict 5 | import torch 6 | import numpy as np 7 | import pandas as pd 8 | from torch.utils.data import IterableDataset 9 | 10 | 11 | class TrajectoryBatchDataset(IterableDataset): 12 | def __init__(self, dataset_directory, dataset_type='train', delimiter=' ', validation_ratio=0.1, test_ratio=0.2): 13 | self.dataset_directory = dataset_directory 14 | 15 | full_data = [np.array([int(j) for j in i.strip().split(delimiter)]) for i in pd.read_csv( 16 | os.path.join(dataset_directory, 'data.txt'), header=None)[0]] 17 | 18 | self.number_of_trajectories = len(full_data) 19 | self.vocab_size = sum(1 for _ in open( 20 | os.path.join(dataset_directory, 'vocab.txt'), encoding='utf-8')) 21 | 22 | if dataset_type == 'train': 23 | self.data = full_data[:-int(self.number_of_trajectories * 24 | (validation_ratio + test_ratio))] 25 | elif dataset_type == 'val': 26 | self.data = full_data[-int(self.number_of_trajectories * ( 27 | validation_ratio + test_ratio)): -int(self.number_of_trajectories * test_ratio)] 28 | elif dataset_type == 'test': 29 | self.data = full_data[-int(self.number_of_trajectories * test_ratio):] 30 | else: 31 | raise ValueError('Invalid type') 32 | 33 | self.dataX = [] 34 | self.dataY = [] 35 | self.batches = [] 36 | self.dataset_type = dataset_type 37 | 38 | def create_batches(self, batch_size, observe, predict=1, shuffle=True, drop_last=False): 39 | 40 | if isinstance(observe, int): 41 | observe = [observe] 42 | if isinstance(predict, int): 43 | predict = [predict] * len(observe) 44 | 45 | for trajectory in self.data: 46 | for j, observe_length in enumerate(observe): 47 | for i in range(0, len(trajectory) - observe_length - predict[j] + 1): 48 | self.dataX.append(trajectory[i:i+observe_length]) 49 | self.dataY.append( 50 | trajectory[i+observe_length:i+observe_length+predict[j]]) 51 | 52 | # Group indices of same size together 53 | size_to_indices = defaultdict(list) 54 | for i, x in 
enumerate(self.dataX): 55 | size_to_indices[len(x)].append(i) 56 | 57 | # Prepare the list of batches and shuffle it 58 | batches = [] 59 | for size_indices in size_to_indices.values(): 60 | for i in range(0, len(size_indices), batch_size): 61 | batch = size_indices[i:i+batch_size] 62 | if len(batch) == batch_size or not drop_last: 63 | batches.append(batch) 64 | 65 | if shuffle: 66 | random.shuffle(batches) 67 | 68 | self.batches = batches 69 | 70 | def get_neighbors(self): 71 | with open(os.path.join(self.dataset_directory, 'neighbors.json'), encoding='utf-8') as neighbors_file: 72 | neighbors = json.load(neighbors_file) 73 | neighbors = {int(k): v + [0] for k, v in neighbors.items()} 74 | neighbors[0] = [] 75 | return neighbors 76 | 77 | def __len__(self): 78 | return len(self.batches) 79 | 80 | def __getitem__(self, index): 81 | batch_indices = self.batches[index] 82 | return torch.LongTensor(np.stack([self.dataX[i] for i in batch_indices])), torch.LongTensor(np.stack([self.dataY[i] for i in batch_indices])) 83 | 84 | def __iter__(self): 85 | # worker_info = torch.utils.data.get_worker_info() 86 | 87 | # if worker_info is None: 88 | # batches = self.batches 89 | # else: 90 | # n_workers = worker_info.num_workers 91 | # n_data = len(self.batches) 92 | # chunk_size = n_data // n_workers 93 | 94 | # chunk_start = chunk_size * worker_info.id 95 | # batches = self.batches[chunk_start: chunk_start + chunk_size] 96 | # for i in range(len(self.dataX)): 97 | # yield torch.LongTensor(self.dataX[i]), torch.LongTensor(self.dataY[i]) 98 | 99 | for batch_indices in self.batches: 100 | yield torch.LongTensor(np.stack([self.dataX[i] for i in batch_indices])), torch.LongTensor(np.stack([self.dataY[i] for i in batch_indices])) 101 | -------------------------------------------------------------------------------- /TrajLearn/config_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | default_config = { 4 | "test_ratio": 0.2, 5 | "validation_ratio": 0.1, 6 | "delimiter": " ", 7 | "min_input_length": 10, 8 | "max_input_length": 14, 9 | "test_input_length": 10, 10 | "test_prediction_length": 5, 11 | "batch_size": 128, 12 | "device": "cpu", 13 | "max_epochs": 10, 14 | "block_size": 24, 15 | "learning_rate": 5.e-3, 16 | "weight_decay": 5.e-1, 17 | "beta1": 0.9, 18 | "beta2": 0.95, 19 | "grad_clip": 1.0, 20 | "decay_lr": True, 21 | "warmup_iters": 200, 22 | "lr_decay_iters": 40000, 23 | "min_lr": 5.e-7, 24 | "seed": 42, 25 | "data_dir": "./data", 26 | "dataset": "geolife7", 27 | "n_layer": 12, 28 | "n_head": 6, 29 | "n_embd": 512, 30 | "bias": False, 31 | "dropout": 0.1, 32 | "model_checkpoint_directory": "./models/", 33 | "train_from_checkpoint_if_exist": False, 34 | "custom_initialization": False, 35 | "patience": 3, 36 | "continuity": True, 37 | "beam_width": 5, 38 | "store_predictions": False, 39 | } 40 | 41 | def load_config(config_file: str) -> dict: 42 | """ 43 | Load a YAML configuration file and apply default values for missing parameters. 44 | 45 | Parameters: 46 | - config_file (str): Path to the configuration YAML file. 47 | 48 | Returns: 49 | - config_list (dict): A dictionary with the final configuration, including defaults. 
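    Example of the expected file layout (a hypothetical `configs.yaml`; the top-level keys name
    independent configurations, and any parameter omitted under a configuration falls back to
    `default_config`):

        rome7-experiment:
          data_dir: ./data
          dataset: rome7
          device: cuda
          batch_size: 128
          max_epochs: 10
        geolife7-experiment:
          dataset: geolife7
          beam_width: 5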
50 | """ 51 | with open(config_file, 'r', encoding='utf-8') as stream: 52 | try: 53 | config_list = yaml.safe_load(stream) 54 | 55 | if config_list is None: 56 | config_list = {} 57 | 58 | for config_name, config_values in config_list.items(): 59 | if config_values is None: 60 | config_values = {} 61 | 62 | for key, value in default_config.items(): 63 | config_values[key] = config_values.get(key, value) 64 | 65 | config_list[config_name] = config_values 66 | 67 | return config_list 68 | 69 | except yaml.YAMLError as exc: 70 | print(f"Error loading YAML file: {exc}") 71 | return None 72 | -------------------------------------------------------------------------------- /TrajLearn/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | from typing import Any 5 | from logging import Logger 6 | from tqdm import tqdm 7 | import torch 8 | from nltk.translate.bleu_score import sentence_bleu 9 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset 10 | 11 | def calculate_bleu(predictions: torch.Tensor, targets: torch.Tensor) -> float: 12 | bleu_score = 0.0 13 | with warnings.catch_warnings(): 14 | warnings.simplefilter("ignore") 15 | for prediction, target in zip(predictions, targets): 16 | prediction = prediction.tolist() 17 | target = target.tolist() 18 | bleu_score += sentence_bleu([target], prediction) 19 | return bleu_score 20 | 21 | 22 | @torch.no_grad() 23 | def evaluate_model( 24 | model: torch.nn.Module, 25 | dataset: TrajectoryBatchDataset, 26 | config: Any, 27 | logger: Logger, 28 | top_k: list = None, 29 | ) -> list: 30 | model.eval() 31 | device = config["device"] 32 | device_type = 'cuda' if 'cuda' in device else 'cpu' 33 | prediction_length = config["test_prediction_length"] 34 | ctx = torch.amp.autocast(device_type=device_type, dtype=torch.float32) 35 | 36 | if top_k is None: 37 | top_k = [1, 3, 5] 38 | 39 | beam_width = config["beam_width"] 40 | 41 | if config["continuity"]: 42 | neighbors = dataset.get_neighbors() 43 | 44 | total_bleu_score = 0.0 45 | correct_predictions = {k: torch.zeros( 46 | prediction_length, dtype=torch.int32).to(device) for k in top_k} 47 | 48 | if config["store_predictions"]: 49 | pred_results_buffer = ["input sequence,true label,predicted label\n"] 50 | pred_results_file = open(os.path.join(logger.log_directory, 'predictions.txt'), 'w', encoding='utf-8') 51 | 52 | start_time = time.time() 53 | 54 | total_samples = 0 55 | for X, Y in (pbar := tqdm(dataset, leave=False)): 56 | x, y = X.to(device), Y.to(device) 57 | beams = torch.zeros((x.shape[0], 1, 0), dtype=torch.int32).to(device) 58 | scores = torch.zeros((x.shape[0], 1), dtype=torch.float32).to(device) 59 | 60 | total_samples += x.shape[0] 61 | for j in range(prediction_length): 62 | new_scores, new_beams = [], [] 63 | for b in range(beams.shape[1]): 64 | beam = beams[:, b:b+1] 65 | with ctx: 66 | input_sequence = torch.cat((x, beam.squeeze(1)), dim=1) 67 | logits, _ = model( 68 | input_sequence[:, -config["block_size"]:]) 69 | logits = torch.squeeze(logits, dim=1) 70 | probs = torch.softmax(logits, dim=1) 71 | 72 | # Apply the mask 73 | if config["continuity"]: 74 | last_prediction = input_sequence[:, -1] 75 | mask = torch.zeros_like(logits, dtype=torch.bool) 76 | for idx, item in enumerate(last_prediction): 77 | mask[idx, neighbors[item.item()]] = True 78 | probs[~mask] = 0 79 | 80 | # Get top-k probabilities and their indices 81 | top_probs, indices = torch.topk(probs, beam_width) 82 | indices 
= torch.where( 83 | indices == 0, torch.ones_like(indices), indices) 84 | # Append new indices to beam 85 | new_beam = torch.cat( 86 | (beam.repeat(1, beam_width, 1), indices.unsqueeze(2)), dim=2) 87 | new_score = scores[:, b:b+1] + \ 88 | torch.log(top_probs) # Update scores 89 | new_scores.append(new_score) 90 | new_beams.append(new_beam) 91 | # Concatenate along beam dimension 92 | new_scores = torch.cat(new_scores, dim=1) 93 | # Concatenate along beam dimension 94 | new_beams = torch.cat(new_beams, dim=1) 95 | top_scores, top_beams = torch.topk(new_scores.view( 96 | X.shape[0], -1), beam_width) # Reshape scores to 2D and get top-k 97 | # Reshape beams to 3D for gathering 98 | beams = new_beams.view(X.shape[0], -1, new_beams.shape[2]) 99 | beams = torch.gather(beams, 1, top_beams.unsqueeze( 100 | 2).expand(-1, -1, beams.shape[2])) # Gather the top-k beams 101 | scores = top_scores # Update scores with top-k scores 102 | 103 | for k in correct_predictions.keys(): 104 | if beam_width >= k: 105 | predictions = beams[:, :k] # Get the top-k beams 106 | for beam_number in range(k): 107 | correct_predictions[k][j] += torch.sum((predictions[:, beam_number:beam_number+1].squeeze(1) == 108 | y[:, :j+1]).all(dim=1).int()) 109 | 110 | total_bleu_score += calculate_bleu(beams[:, 0], y) 111 | 112 | if config["store_predictions"]: 113 | beams_np = beams[:, 0].cpu().numpy() 114 | xs_np, ys_np = x.cpu().numpy(), y.cpu().numpy() 115 | for sample_id, x_np in enumerate(xs_np): 116 | prediction_result = f"{' '.join(map(str, x_np))}, {' '.join(map(str, ys_np[sample_id]))}, {' '.join(map(str, beams_np[sample_id]))}\n" 117 | pred_results_buffer.append(prediction_result) 118 | if len(pred_results_buffer) > 10000: 119 | pred_results_file.writelines(pred_results_buffer) 120 | pred_results_buffer = [] 121 | 122 | acc1 = ((100 * correct_predictions[1][prediction_length-1]) / total_samples).item() 123 | pbar.set_postfix(**{"Acc@1": acc1}) 124 | 125 | if config["store_predictions"]: 126 | pred_results_file.writelines(pred_results_buffer) 127 | 128 | test_duration = time.time() - start_time 129 | 130 | avg_bleu_score = total_bleu_score / total_samples 131 | acc = {k: (100 * v) / (total_samples) 132 | for k, v in correct_predictions.items()} 133 | results = [f"Dataset: {config['dataset']}"] 134 | for k, v in acc.items(): 135 | results.append(f"Accuracy@{k}: {v[-1]:.4f}") 136 | results.append(f"BLEU score: {avg_bleu_score:.4f}") 137 | results.append(f"Test duration: {test_duration:.3f}(s)") 138 | results.append(f"Samples: {total_samples}") 139 | logger.info(", ".join(results)) 140 | 141 | return results 142 | -------------------------------------------------------------------------------- /TrajLearn/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from logging.handlers import RotatingFileHandler 4 | 5 | def get_logger(log_directory, 6 | name, 7 | phase="train", 8 | console_level=logging.DEBUG, 9 | file_level=logging.INFO, 10 | max_file_size=10 * 1024 * 1024, 11 | backup_count=5, 12 | log_format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d]: %(message)s', 13 | date_format='%Y-%m-%d %H:%M:%S'): 14 | """ 15 | Creates and configures a logger with console and file handlers. 16 | 17 | Parameters: 18 | - log_directory (str): Directory where log files will be stored. 19 | - phase (str): Logging phase (e.g., "train", "test"). The log file will be named accordingly. 
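    - name (str): Name used to distinguish this logger; together with the phase it forms the underlying logger's registered name, "{name}-{phase}".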
20 | - console_level (int): Logging level for console output (default: logging.DEBUG). 21 | - file_level (int): Logging level for file output (default: logging.INFO). 22 | - max_file_size (int): Maximum size of the log file before it gets rotated (default: 10MB). 23 | - backup_count (int): Number of backup log files to keep (default: 5). 24 | - log_format (str): Format of the log messages. 25 | - date_format (str): Format of the date in log messages. 26 | 27 | Returns: 28 | - logger (logging.Logger): Configured logger object. 29 | """ 30 | logger = logging.getLogger(f"{name}-{phase}") 31 | 32 | if not logger.hasHandlers(): 33 | logger.setLevel(logging.DEBUG) 34 | 35 | formatter = logging.Formatter(log_format, date_format) 36 | 37 | console_handler = logging.StreamHandler() 38 | console_handler.setLevel(console_level) 39 | console_handler.setFormatter(formatter) 40 | 41 | os.makedirs(log_directory, exist_ok=True) 42 | logger.log_directory = log_directory 43 | 44 | logfile = os.path.join(log_directory, f"{phase}.log") 45 | file_handler = RotatingFileHandler(logfile, mode='a', 46 | maxBytes=max_file_size, backupCount=backup_count) 47 | file_handler.setLevel(file_level) 48 | file_handler.setFormatter(formatter) 49 | 50 | logger.addHandler(console_handler) 51 | logger.addHandler(file_handler) 52 | 53 | return logger 54 | -------------------------------------------------------------------------------- /TrajLearn/mixed_res.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | import json 4 | import pickle 5 | import argparse 6 | from pathlib import Path 7 | from collections import defaultdict 8 | from collections.abc import Callable 9 | 10 | import h3 11 | import geopandas 12 | import matplotlib 13 | import numpy as np 14 | import pandas as pd 15 | import contextily as cx 16 | import matplotlib.pyplot as plt 17 | from matplotlib.colors import ListedColormap 18 | 19 | 20 | def gps_to_h3( 21 | gps_points: list, 22 | resolution: int, 23 | boundary: dict, 24 | ): 25 | """ 26 | Convert a list of GPS coordinates to H3 hexagons at a specified resolution. 27 | 28 | Arguments: 29 | - gps_points: list of tuples, where each tuple contains (longitude, latitude) in degrees. 30 | Example: [(lon1, lat1), (lon2, lat2), ...] 31 | - resolution: int, the H3 resolution level to use for converting coordinates. 32 | Higher resolutions create smaller hexagons. 33 | 34 | Returns: 35 | - A list of H3 hexagon IDs corresponding to the input GPS coordinates. 36 | """ 37 | result = [] 38 | for lon, lat in gps_points: 39 | if boundary["Min_lon"] <= lon and boundary["Min_lat"] <= lat and \ 40 | boundary["Max_lon"] >= lon and boundary["Max_lat"] >= lat: 41 | cell = h3.latlng_to_cell(lat, lon, resolution) 42 | if len(result) == 0 or cell != result[-1]: 43 | result.append(h3.latlng_to_cell(lat, lon, resolution)) 44 | return result 45 | 46 | 47 | def preprocess_resolution( 48 | dataset: str = 'geolife', 49 | min_resolution: int = 7, 50 | max_resolution: int = 9, 51 | output_dir: str = 'data/mixed_res', 52 | save_csv: bool = False, 53 | use_boundary: bool = True, 54 | ): 55 | """ 56 | Preprocess a trajectory dataset by converting GPS points to H3 hexagon representations 57 | at various resolutions, and save the results and associated data structures. 58 | 59 | This function reads a CSV file containing GPS points, converts these points to H3 60 | hexagon IDs at specified resolutions, and optionally saves the results as CSV and 61 | pickle files. 
It also calculates hexagon counts and neighbor relationships. 62 | 63 | Arguments: 64 | - dataset: str, name of the dataset to process (e.g., 'geolife', 'rome', 'porto'). 65 | - min_resolution: int, the minimum H3 resolution level for hexagon generation. 66 | - max_resolution: int, the maximum H3 resolution level for hexagon generation. 67 | - output_dir: str, the directory where processed data and outputs will be saved. 68 | - save_csv: bool, if True, saves processed data and hexagon counts as CSV files. 69 | 70 | Steps Performed: 71 | 1. Reads GPS points from a dataset CSV file and applies a transformation to generate 72 | H3 hexagon IDs at each resolution from `min_resolution` to `max_resolution`. 73 | 2. Saves the transformed dataset to a CSV and pickle file if `save_csv` is True. 74 | 3. Computes and saves hexagon counts and unique hexagon IDs at each resolution. 75 | 4. Computes neighbor relationships for hexagons and saves them as a pickle file. 76 | 77 | Outputs: 78 | - Processed dataset with hexagon columns at specified resolutions (CSV and/or pickle). 79 | - Hexagon count file, detailing how many times each hexagon appears (CSV and pickle). 80 | - A dictionary of unique hexagons at each resolution (pickle). 81 | - Neighbor relationships for each hexagon, excluding child hexagons (pickle). 82 | """ 83 | os.makedirs(output_dir, exist_ok=True) 84 | 85 | boundary = { 86 | 'geolife':{ 87 | "Min_lat":39.50000000, 88 | "Max_lat":40.50000000, 89 | "Min_lon":116.00000000, 90 | "Max_lon":117.00000000 91 | }, 92 | 'rome':{ 93 | "Min_lat":41.793710, 94 | "Max_lat":41.991390, 95 | "Min_lon":12.372598, 96 | "Max_lon":12.622537 97 | }, 98 | 'porto': { 99 | "Min_lat": -1000, 100 | "Max_lat": 1000, 101 | "Min_lon": -1000, 102 | "Max_lon": 1000 103 | } 104 | } 105 | 106 | data_path = { 107 | 'geolife':'data/raw_aggrigated/geolife_aggregated.csv', 108 | 'rome': 'data/raw_aggrigated/rome_taxi_aggregated.csv', 109 | 'porto': 'data/raw_aggrigated/porto.csv' 110 | } 111 | 112 | file_path = data_path[dataset] 113 | print(f"Processing file: {file_path}") 114 | 115 | target_column = 'route_points' if 'geolife' in file_path or 'rome_taxi' in file_path else 'TIMESTAMP' 116 | target_time = 'date' if 'geolife' in file_path or 'rome_taxi' in file_path else 'POLYLINE' 117 | df = pd.read_csv(file_path, usecols=[target_column, target_time]).rename(columns={target_column: 'points', target_time: 'time'}) 118 | 119 | df['points'] = df['points'].apply(ast.literal_eval) 120 | print("Processed points") 121 | for resolution in range(min_resolution, max_resolution+1): 122 | column_name = f'hex_{resolution}' 123 | if use_boundary: 124 | df[column_name] = df['points'].apply(gps_to_h3, args=(resolution, boundary[dataset],)) 125 | else: 126 | df[column_name] = df['points'].apply(gps_to_h3, args=(resolution,)) 127 | print(f"Processed resolution: {resolution}") 128 | 129 | if save_csv: 130 | df.to_csv(f"{output_dir}/{dataset}.csv", index=False) 131 | print("Saved dataset csv") 132 | 133 | with open(f'{output_dir}/{dataset}.pkl', 'wb') as f: 134 | pickle.dump(df, f) 135 | print("Saved dataset pickle") 136 | 137 | hex_counts = defaultdict(int) 138 | hexes = {res: set() for res in range(min_resolution, max_resolution + 1)} 139 | 140 | for res in range(min_resolution, max_resolution + 1): 141 | column_name = f'hex_{res}' 142 | for hex_list in df[column_name]: 143 | for hex_id in hex_list: 144 | hex_counts[hex_id] += 1 145 | hexes[res].add(hex_id) 146 | 147 | if save_csv: 148 | hex_df = pd.DataFrame.from_dict(hex_counts, 
orient='index', columns=['occurrences']) 149 | hex_df.index.name = 'hex_id' 150 | hex_df.to_csv(f'{output_dir}/hex_count_{dataset}.csv') 151 | print("Saved hexagon count csv") 152 | 153 | 154 | with open(f'{output_dir}/hexes_{dataset}.pkl', 'wb') as f: 155 | pickle.dump(hexes, f) 156 | print("Saved hexagons dictionary pickle") 157 | 158 | 159 | with open(f'{output_dir}/hex_count_{dataset}.pkl', 'wb') as f: 160 | pickle.dump(hex_counts, f) 161 | print("Saved hexagon count pickle") 162 | 163 | neighbors = defaultdict(set) 164 | for resolution in range(min_resolution, max_resolution+1): 165 | for hex_id in hexes[resolution]: 166 | hex_neighbors = set() 167 | hex_children = set() 168 | 169 | for children_resolution in range(resolution+1, max_resolution+1): 170 | hex_children.update(h3.cell_to_children(hex_id, children_resolution)) 171 | 172 | for hex_child in hex_children: 173 | hex_neighbors.update(h3.grid_ring(hex_child, 1)) 174 | neighbors[hex_id] = hex_neighbors - hex_children 175 | 176 | for hex_id, neighbors_set in list(neighbors.items()): 177 | for neighbor in neighbors_set: 178 | neighbors[neighbor].add(hex_id) 179 | 180 | with open(f'{output_dir}/neighbors_{dataset}.pkl', 'wb') as f: 181 | pickle.dump(neighbors, f) 182 | print("Saved hexagon neighbors pickle") 183 | 184 | 185 | def visualize(hex_counts: dict, output_path: str, bins:list = None, **kwargs): 186 | """ 187 | Generate and display a heatmap of hexagon counts on a map, and save it as an image. 188 | 189 | Arguments: 190 | - hex_seq: list of str, a list of hexagon sequences. 191 | - zoom_level: int, the zoom level for the map. 192 | - output_dir: str, the directory path to save any output if necessary. 193 | - **kwargs: Options to pass to geopandas plotting method. 194 | 195 | The function flattens the input hex sequences, counts the occurrences of each hexagon, 196 | creates a GeoJSON object, and visualizes it using Plotly with a heatmap. 197 | """ 198 | matplotlib.rcParams.update({'font.size': 28}) 199 | df_hex_plot = pd.DataFrame(hex_counts.items(), columns=['hex_id', 'count']) 200 | 201 | df_hex_plot['geometry'] = df_hex_plot['hex_id'].apply(lambda x: h3.cells_to_h3shape([x])) 202 | df = geopandas.GeoDataFrame(df_hex_plot.drop(columns=['hex_id']), crs='EPSG:4326') 203 | df = df.to_crs(epsg=3857) 204 | 205 | _, ax = plt.subplots(figsize=(24,24)) 206 | ax.get_xaxis().set_visible(False) 207 | ax.get_yaxis().set_visible(False) 208 | 209 | if bins is not None: 210 | labels = ['Low', 'Medium', 'High'] 211 | df['count'] = pd.cut(df['count'], bins=bins, labels=labels, include_lowest=True) 212 | 213 | df.plot( 214 | ax=ax, 215 | alpha=0.9, 216 | edgecolor=(134/256, 218/256, 227/256, 0.1), 217 | linewidth=0.001, 218 | column='count', 219 | legend=True, 220 | **kwargs, 221 | ) 222 | 223 | cx.add_basemap(ax, crs=df.crs, source=cx.providers.CartoDB.Positron) 224 | 225 | plt.tight_layout() 226 | plt.savefig(output_path, format="pdf", bbox_inches='tight') 227 | 228 | 229 | def threshold_split_condition(threshold: int = 100): 230 | """ 231 | Create a custom split condition function based on a threshold for hexagon occurrences. 232 | 233 | This function returns a nested `split_condition` function that evaluates whether 234 | a hexagon should be split based on the number of occurrences compared to a specified 235 | threshold. 236 | 237 | Arguments: 238 | - threshold: int, the minimum number of occurrences required for a hexagon to be 239 | considered for splitting. Default is 100. 
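    Example (mirroring the call in `main()`, which uses a threshold of 150):

        split_fn = threshold_split_condition(150)
        split_fn(current_res=7, hex_count=200, neighbors_stat={})  # True, since 200 > 150
        split_fn(current_res=7, hex_count=80, neighbors_stat={})   # False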
240 | 241 | Returns: 242 | - A function `split_condition(current_res, hex_count, neighbors_stat)` that takes: 243 | - current_res: int, the current resolution of the hexagon (not used in this function). 244 | - hex_count: int, the count of occurrences in the current hexagon. 245 | - neighbors_stat: any, information/statistics about neighboring hexagons 246 | (not used in this function). 247 | The `split_condition` function returns `True` if `hex_count` exceeds the threshold, 248 | indicating the hexagon should be split; otherwise, it returns `False`. 249 | """ 250 | def split_condition(current_res, hex_count, neighbors_stat): 251 | """ 252 | Decide if a hexagon should split based on its occurrence count. 253 | 254 | Arguments: 255 | - current_res: int, the current resolution of the hexagon (not used in this function). 256 | - hex_count: int, the count of occurrences in the current hexagon. 257 | - neighbors_stat: any, information/statistics about neighboring hexagons 258 | (not used in this function). 259 | 260 | Returns: 261 | - True if `hex_count` exceeds the specified `threshold`, indicating the hexagon 262 | should be split; otherwise, False. 263 | """ 264 | return hex_count > threshold 265 | 266 | return split_condition 267 | 268 | 269 | def complex_split_condition(threshold=100, std_ratio=2): 270 | """ 271 | Create a split condition function that decides whether to split a hexagon based on: 272 | 1. The normalized occurrence count (using the coefficient of variation) exceeding a threshold. 273 | 2. Significant normalized variance in the number of occurrences between the hexagon and its neighbors. 274 | 275 | Arguments: 276 | - threshold_ratio: float, the minimum coefficient of variation for the hexagon's count to consider splitting. 277 | - variance_ratio: float, the minimum normalized variance between the hexagon and its neighbors for significant change. 278 | 279 | Returns: 280 | - A function `split_condition(current_res, hex_count, neighbors_stat)` that returns True if the 281 | hexagon meets the split criteria; otherwise, False. 282 | """ 283 | def split_condition(current_res, hex_count, neighbors_stat): 284 | """ 285 | Determine if a hexagon should be split based on normalized variance and threshold ratio. 286 | 287 | Arguments: 288 | - current_res: int, the current resolution of the hexagon. 289 | - hex_count: int, the number of occurrences in the current hexagon. 290 | - neighbors_stat: dict, where keys are neighbor hex IDs and values are their occurrence counts. 291 | 292 | Returns: 293 | - Boolean: True if the hexagon meets the split conditions; otherwise, False. 294 | """ 295 | if not neighbors_stat: 296 | return False 297 | 298 | neighbor_counts = np.array(list(neighbors_stat.values())) 299 | mean_neighbors = np.mean(neighbor_counts) 300 | if mean_neighbors == 0: 301 | return False 302 | 303 | if hex_count <= threshold: 304 | return False 305 | 306 | normalized_std = np.std(neighbor_counts) / mean_neighbors 307 | 308 | return normalized_std > std_ratio 309 | 310 | return split_condition 311 | 312 | 313 | def skewness_stopping_condition(threshold: float = 0.2): 314 | """ 315 | Create a stopping condition function based on the skewness of a distribution. 316 | 317 | This function returns a nested `stopping_condition` function that evaluates whether 318 | the skewness of the distribution of hexagon occurrences meets a specified threshold. 319 | The stopping condition is used to decide if the iterative process should halt. 
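    The skewness used here is the adjusted Fisher-Pearson estimator,
    g1 = n / ((n - 1) * (n - 2)) * sum(((x - mean) / s) ** 3),
    where s is the sample standard deviation (ddof=1), matching the computation in the nested
    `stopping_condition` below.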
320 | 321 | Arguments: 322 | - threshold: float, the skewness threshold for stopping. If the skewness of the 323 | distribution of occurrences is less than this threshold, the condition 324 | returns `True`, indicating that the process should stop. Default is 0.2. 325 | 326 | Returns: 327 | - A function `stopping_condition(hexagons)` that takes: 328 | - hexagons: dict[str, int], a dictionary where keys are hexagon IDs and values are 329 | the count of occurrences in each hexagon. 330 | The function returns `True` if the skewness of the values in `hexagons` is below 331 | the specified threshold, indicating that the stopping condition is met; otherwise, 332 | it returns `False`. 333 | """ 334 | def stopping_condition(hexagons: dict): 335 | """ 336 | Evaluate the skewness of the hexagon occurrence counts to determine if the process should stop. 337 | 338 | Arguments: 339 | - hexagons: dict[str, int], a dictionary where keys are hexagon IDs and values are 340 | the count of occurrences in each hexagon. 341 | 342 | Returns: 343 | - True if the skewness is less than the specified threshold, indicating that the 344 | condition for stopping is met; otherwise, False. 345 | """ 346 | data = np.array(list(hexagons.values())) 347 | n = len(data) 348 | if n < 3: 349 | raise ValueError("Skewness calculation requires at least 3 data points.") 350 | 351 | mean = np.mean(data) 352 | std_dev = np.std(data, ddof=1) 353 | skewness = (n / ((n - 1) * (n - 2))) * np.sum(((data - mean) / std_dev) ** 3) 354 | print(skewness) 355 | return skewness < threshold 356 | 357 | return stopping_condition 358 | 359 | 360 | def mixed_resolution( 361 | split_condition_fn: Callable, 362 | stopping_condition_fn: Callable, 363 | dataset: str = 'geolife', 364 | min_resolution: int = 7, 365 | max_resolution: int = 9, 366 | input_dir: str = 'data/mixed_res', 367 | output_dir: str = 'data/mixed_res', 368 | max_iterations: int = 5, 369 | ): 370 | """ 371 | Perform iterative hexagon refinement based on split and stopping conditions. 372 | 373 | This function reads preprocessed data including hexagon counts, unique hexagon sets, 374 | and neighbor relationships, and iteratively refines the hexagon set by splitting 375 | hexagons that meet the given `split_condition_fn`. The process halts when the 376 | `stopping_condition_fn` is satisfied or the maximum number of iterations is reached. 377 | 378 | Arguments: 379 | - split_condition_fn: Callable, a function that takes the current resolution, 380 | hexagon count, and neighbor statistics and returns a boolean 381 | indicating whether the hexagon should be split. 382 | - stopping_condition_fn: Callable, a function that takes a dictionary of hexagons 383 | and their counts and returns a boolean indicating whether 384 | the stopping condition is met. 385 | - dataset: str, the name of the dataset to process (e.g., 'geolife'). 386 | - min_resolution: int, the initial resolution of hexagons to start processing from. 387 | - max_resolution: int, the maximum resolution allowed for splitting hexagons. 388 | - input_dir: str, the directory path to load the preprocessed input files. 389 | - output_dir: str, the directory path to save any output if necessary. 390 | - max_iterations: int, the maximum number of iterations for the process. 391 | 392 | Process: 393 | 1. Load preprocessed neighbor relationships, hexagon counts, and unique hexagon sets 394 | from pickle files. 395 | 2. Initialize the hexagon set from the `min_resolution` level. 396 | 3. 
Iterate up to `max_iterations` times, splitting hexagons that meet the `split_condition_fn`. 397 | 4. Check the `stopping_condition_fn` at each iteration to potentially stop the process early. 398 | 5. Split marked hexagons into their children at the next higher resolution. 399 | 6. Print status messages and information about each iteration. 400 | 401 | Outputs: 402 | - Prints the number of hexagons processed at each iteration. 403 | - Stops the process when the stopping condition is met or no hexagons meet the split condition. 404 | """ 405 | os.makedirs(output_dir, exist_ok=True) 406 | 407 | with open(f'{input_dir}/neighbors_{dataset}.pkl', 'rb') as f: 408 | neighbors = pickle.load(f) 409 | 410 | with open(f'{input_dir}/hex_count_{dataset}.pkl', 'rb') as f: 411 | hex_counts = pickle.load(f) 412 | 413 | with open(f'{input_dir}/hexes_{dataset}.pkl', 'rb') as f: 414 | hexes = pickle.load(f) 415 | 416 | hexagon_set = hexes[min_resolution] 417 | iteration = 0 418 | 419 | while iteration < max_iterations: 420 | iteration += 1 421 | print(f"Iteration {iteration}: Dataset has {len(hexagon_set)} hexagons") 422 | if stopping_condition_fn({hexagon: hex_counts[hexagon] for hexagon in hexagon_set}): 423 | print(f"Stopping condition invoked at iteration {iteration}") 424 | break 425 | 426 | marked_for_split = set() 427 | 428 | for hex_id in hexagon_set: 429 | current_res = h3.get_resolution(hex_id) 430 | if current_res == max_resolution: 431 | continue 432 | neighbors_stat = {neighbor: hex_counts[neighbor] for neighbor in neighbors[hex_id]} 433 | hex_count = hex_counts[hex_id] 434 | 435 | if split_condition_fn(current_res, hex_count, neighbors_stat): 436 | marked_for_split.add(hex_id) 437 | 438 | if len(marked_for_split) == 0: 439 | print(f"None of hexagons met split condition at iteration {iteration}") 440 | break 441 | else: 442 | print(f"{len(marked_for_split)} of hexagons met split condition at iteration {iteration}") 443 | 444 | new_hexagon_set = set() 445 | for hexagon in hexagon_set: 446 | if hexagon in marked_for_split: 447 | new_hexagon_set.update(h3.cell_to_children(hexagon)) 448 | else: 449 | new_hexagon_set.add(hexagon) 450 | hexagon_set = new_hexagon_set 451 | 452 | with open(f'{output_dir}/final_hexes_{dataset}.pkl', 'wb') as f: 453 | pickle.dump(hexagon_set, f) 454 | print("Saved final hexagons set pickle") 455 | 456 | visualize({hexagon: h3.get_resolution(hexagon) for hexagon in hexagon_set}, "mixed-res-heatmap.pdf", categorical=True, cmap=ListedColormap(["#66CDAA", "#9370DB", "#86DAE3"]), legend_kwds={"loc": "upper right", "title":"Resolution", "markerscale":2.5},) 457 | visualize({hexagon: hex_counts[hexagon] for hexagon in hexes[min_resolution]}, "mixed-res-map.pdf", bins=[0,500,10000,np.inf], cmap=ListedColormap(["#66CDAA", "#9370DB", "#86DAE3"]), categorical=True, k=10, legend_kwds={"loc": "upper right", "title":"Movement Density", "markerscale":2.5,},) 458 | print("Heatmap saved") 459 | 460 | 461 | def apply_processing( 462 | dataset: str = 'geolife', 463 | min_resolution: int = 7, 464 | max_resolution: int = 9, 465 | output_dir: str = 'data/mixed_res', 466 | date_column: str = "time" 467 | ): 468 | 469 | with open(f'{output_dir}/{dataset}.pkl', 'rb') as f: 470 | hexes = pickle.load(f) 471 | 472 | with open(f'{output_dir}/final_hexes_{dataset}.pkl', 'rb') as f: 473 | hexagon_set = pickle.load(f) 474 | 475 | 476 | boundary = { 477 | 'geolife':{ 478 | "Min_lat":39.50000000, 479 | "Max_lat":40.50000000, 480 | "Min_lon":116.00000000, 481 | "Max_lon":117.00000000 482 | }, 483 | 
'rome':{ 484 | "Min_lat":41.793710, 485 | "Max_lat":41.991390, 486 | "Min_lon":12.372598, 487 | "Max_lon":12.622537 488 | }, 489 | 'porto': { 490 | "Min_lat": -1000, 491 | "Max_lat": 1000, 492 | "Min_lon": -1000, 493 | "Max_lon": 1000 494 | } 495 | } 496 | 497 | def gps_to_mixed_h3(gps_points: list, boundary: dict): 498 | result = [] 499 | for lon, lat in gps_points: 500 | if boundary["Min_lon"] <= lon and boundary["Min_lat"] <= lat and \ 501 | boundary["Max_lon"] >= lon and boundary["Max_lat"] >= lat: 502 | for resolution in range(min_resolution, max_resolution+1): 503 | cell = h3.latlng_to_cell(lat, lon, resolution) 504 | if cell in hexagon_set: 505 | if len(result) == 0 or cell != result[-1]: 506 | result.append(cell) 507 | break 508 | return result 509 | 510 | hexes['points'] = hexes['points'].apply(gps_to_mixed_h3, args=(boundary[dataset],)) 511 | 512 | output_dir = Path(output_dir) 513 | (output_dir / dataset).mkdir(parents=True, exist_ok=True) 514 | 515 | vocab = ["EOT"] + list(hexagon_set) 516 | vocab_file_path = output_dir / dataset / 'vocab.txt' 517 | with vocab_file_path.open('w', encoding='utf-8') as vocab_file: 518 | vocab_file.write("\n".join(vocab) + "\n") 519 | 520 | mapping = {k: v for v, k in enumerate(vocab)} 521 | mapping_file_path = output_dir / dataset / 'mapping.json' 522 | with mapping_file_path.open('w', encoding='utf-8') as mapping_file: 523 | json.dump(mapping, mapping_file, ensure_ascii=False) 524 | 525 | hexes['points'] = hexes['points'].apply(lambda x: [str(mapping[j]) for j in x]) 526 | df_mapped = hexes.sort_values(by=[date_column])["points"].to_list() 527 | data_file_path = output_dir / dataset / 'data.txt' 528 | with data_file_path.open('w', encoding='utf-8') as data_file: 529 | for item in df_mapped: 530 | data_file.write(' '.join(item) + f" {mapping['EOT']}\n") 531 | 532 | 533 | def main() -> None: 534 | """ 535 | Main function to handle argument parsing and execute data preprocessing. 536 | """ 537 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning') 538 | parser.add_argument('dataset', type=str, help='Dataset') 539 | args = parser.parse_args() 540 | 541 | preprocess_resolution(dataset=args.dataset) 542 | mixed_resolution( 543 | split_condition_fn=threshold_split_condition(150), 544 | stopping_condition_fn=skewness_stopping_condition(0.2), 545 | dataset=args.dataset 546 | ) 547 | apply_processing(dataset=args.dataset) 548 | 549 | 550 | if __name__ == '__main__': 551 | main() 552 | -------------------------------------------------------------------------------- /TrajLearn/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import inspect 3 | from dataclasses import dataclass 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # @torch.jit.script 11 | 12 | class CausalSelfAttention(nn.Module): 13 | """Self-attention module with causal masking.""" 14 | 15 | def __init__(self, config): 16 | super().__init__() 17 | assert config.n_embd % config.n_head == 0, "Embedding size must be divisible by the number of heads." 
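        # c_attn projects the input once into concatenated query/key/value vectors (3 * n_embd wide),
        # and c_proj maps the merged head outputs back to n_embd. The "bias" buffer registered below is a
        # lower-triangular (block_size x block_size) mask, reshaped to (1, 1, block_size, block_size),
        # which forward() uses to hide future positions (causal masking).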
18 | self.n_head = config.n_head 19 | self.n_embd = config.n_embd 20 | self.dropout = config.dropout 21 | self.c_attn = nn.Linear( 22 | config.n_embd, 3 * config.n_embd, bias=config.bias) 23 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 24 | self.attn_dropout = nn.Dropout(config.dropout) 25 | self.resid_dropout = nn.Dropout(config.dropout) 26 | self.register_buffer( 27 | "bias", 28 | torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size) 29 | ) 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | B, T, C = x.size() 33 | 34 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 35 | k = k.view(B, T, self.n_head, C // 36 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 37 | q = q.view(B, T, self.n_head, C // 38 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 39 | v = v.view(B, T, self.n_head, C // 40 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 41 | 42 | # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 43 | 44 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 45 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) 46 | att = F.softmax(att, dim=-1) 47 | att = self.attn_dropout(att) 48 | y = att @ v 49 | y = y.transpose(1, 2).contiguous().view(B, T, C) 50 | 51 | y = self.resid_dropout(self.c_proj(y)) 52 | return y 53 | 54 | 55 | class MLP(nn.Module): 56 | """Multilayer Perceptron with GELU activation.""" 57 | 58 | def __init__(self, config): 59 | super().__init__() 60 | self.c_fc = nn.Linear( 61 | config.n_embd, 4 * config.n_embd, bias=config.bias) 62 | self.c_proj = nn.Linear( 63 | 4 * config.n_embd, config.n_embd, bias=config.bias) 64 | self.dropout = nn.Dropout(config.dropout) 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | x = self.c_fc(x) 68 | x = F.gelu(x, approximate='tanh') 69 | x = self.c_proj(x) 70 | return self.dropout(x) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer block consisting of attention and MLP layers.""" 75 | 76 | def __init__(self, config): 77 | super().__init__() 78 | self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias) 79 | self.attn = CausalSelfAttention(config) 80 | self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias) 81 | self.mlp = MLP(config) 82 | 83 | def forward(self, x: torch.Tensor) -> torch.Tensor: 84 | x = x + self.attn(self.ln_1(x)) 85 | return x + self.mlp(self.ln_2(x)) 86 | 87 | 88 | @dataclass 89 | class ModelConfig: 90 | block_size: int = 1024 91 | vocab_size: int = 50304 92 | n_layer: int = 12 93 | n_head: int = 12 94 | n_embd: int = 768 95 | dropout: float = 0.0 96 | bias: bool = True 97 | 98 | 99 | class CausalLM(nn.Module): 100 | """Causal Language Model with optional custom initialization.""" 101 | 102 | def __init__(self, config, custom_init=None): 103 | super().__init__() 104 | assert config.vocab_size is not None 105 | assert config.block_size is not None 106 | self.config = config 107 | 108 | self.transformer = nn.ModuleDict(dict( 109 | wte=nn.Embedding(config.vocab_size, config.n_embd), 110 | wpe=nn.Embedding(config.block_size, config.n_embd), 111 | drop=nn.Dropout(config.dropout), 112 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 113 | ln_f=nn.LayerNorm(config.n_embd, bias=config.bias), 114 | )) 115 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 116 | self.transformer.wte.weight = self.lm_head.weight 117 | 118 | self.apply(self._init_weights) 119 | for name, param in self.named_parameters(): 120 | if 
name.endswith('c_proj.weight'): 121 | nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) 122 | if custom_init is not None: 123 | self.transformer.wte.weight.data.copy_(custom_init) 124 | 125 | def _init_weights(self, module: nn.Module): 126 | if isinstance(module, (nn.Linear, nn.Embedding)): 127 | nn.init.normal_(module.weight, mean=0.0, std=0.02) 128 | if hasattr(module, 'bias') and module.bias is not None: 129 | nn.init.zeros_(module.bias) 130 | 131 | def forward(self, idx, targets=None): 132 | device = idx.device 133 | b, t = idx.size() 134 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 135 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # (1, t) 136 | 137 | tok_emb = self.transformer.wte(idx) # (b, t, n_embd) 138 | pos_emb = self.transformer.wpe(pos) # (1, t, n_embd) 139 | x = self.transformer.drop(tok_emb + pos_emb) 140 | for block in self.transformer.h: 141 | x = block(x) 142 | x = self.transformer.ln_f(x) 143 | logits = self.lm_head(x[:, [-1], :]) 144 | 145 | if targets is not None: 146 | loss = F.cross_entropy( 147 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 148 | else: 149 | loss = None 150 | 151 | return logits, loss 152 | 153 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 154 | """ 155 | This long function is unfortunately doing something very simple and is being very defensive: 156 | We are separating out all parameters of the model into two buckets: those that will experience 157 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 158 | We are then returning the PyTorch optimizer object. 159 | """ 160 | 161 | decay = set() 162 | no_decay = set() 163 | whitelist_weight_modules = (nn.Linear, ) 164 | blacklist_weight_modules = ( 165 | nn.LayerNorm, nn.Embedding) 166 | for mn, m in self.named_modules(): 167 | for pn, p in m.named_parameters(): 168 | fpn = '%s.%s' % (mn, pn) if mn else pn 169 | if pn.endswith('bias'): 170 | no_decay.add(fpn) 171 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 172 | if fpn != 'lm_head.weight': 173 | decay.add(fpn) 174 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 175 | no_decay.add(fpn) 176 | 177 | param_dict = {pn: p for pn, p in self.named_parameters()} 178 | optim_groups = [ 179 | {"params": [param_dict[pn] for pn in sorted( 180 | list(decay))], "weight_decay": weight_decay}, 181 | {"params": [param_dict[pn] 182 | for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 183 | ] 184 | use_fused = (device_type == 'cuda') and ( 185 | 'fused' in inspect.signature(torch.optim.AdamW).parameters) 186 | extra_args = dict(fused=True) if use_fused else dict() 187 | optimizer = torch.optim.AdamW( 188 | optim_groups, lr=learning_rate, betas=betas, **extra_args) 189 | 190 | return optimizer 191 | -------------------------------------------------------------------------------- /TrajLearn/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pathlib import Path 4 | from typing import List, Dict, Optional 5 | import numpy as np 6 | import pandas as pd 7 | import h3 8 | 9 | def generate_embeddings( 10 | vocab: list, 11 | embedding_dim: int, 12 | mean: float = 0, 13 | std: float = 0.02, 14 | projection_matrix: np.ndarray = None, 15 | random_seed: int = 10, 16 | ) -> np.ndarray: 17 | """ 18 | 
Generates embeddings for a given vocabulary based on axial coordinates. 19 | 20 | Args: 21 | vocab (list): A list of H3 hex addresses representing the vocabulary for which embeddings will be generated. 22 | embedding_dim (int): The dimension of the generated embedding vectors. 23 | mean (float, optional): The mean of the normal distribution used for embedding normalization. Default is 0. 24 | std (float, optional): The standard deviation of the normal distribution used for embedding normalization. Default is 0.02. 25 | projection_matrix (np.ndarray, optional): A 2D numpy array used for projecting the axial coordinates. If None, a random matrix will be generated. 26 | """ 27 | origin_hex = vocab[0] 28 | base_i, base_j = h3.cell_to_local_ij(origin_hex, origin_hex) 29 | 30 | axial_coordinates = [] 31 | for h3_hex in vocab: 32 | target_i, target_j = h3.cell_to_local_ij(origin_hex, h3_hex) 33 | q, r = target_i - base_i, target_j - base_j 34 | axial_coordinates.append((q,r)) 35 | 36 | np.random.seed(random_seed) 37 | if projection_matrix is None: 38 | projection_matrix = np.random.randn(2, embedding_dim) 39 | 40 | projected_embedding = np.dot(axial_coordinates, projection_matrix) 41 | 42 | # standardized_embedding = (projected_embedding - np.mean(projected_embedding)) / np.std(projected_embedding) 43 | 44 | # normalized_embedding = standardized_embedding * std + mean 45 | 46 | eot_embedding = np.random.normal(loc=mean, scale=std, size=(1, embedding_dim)) 47 | 48 | normal_samples = np.random.normal(loc=mean, scale=std, size=(len(vocab) * embedding_dim)) 49 | flat_projected_embedding = projected_embedding.flatten() 50 | sorted_indices = np.argsort(flat_projected_embedding) 51 | sorted_projected_embedding = np.empty_like(flat_projected_embedding) 52 | sorted_projected_embedding[sorted_indices] = np.sort(normal_samples) 53 | sorted_projected_embedding = sorted_projected_embedding.reshape(projected_embedding.shape) 54 | 55 | return np.concatenate((eot_embedding, sorted_projected_embedding), axis=0) 56 | 57 | 58 | def process_datasets( 59 | input_dir: Path, 60 | output_dir: Path, 61 | datasets_to_process: List[str], 62 | embedding_dim: Optional[int] 63 | ) -> None: 64 | """ 65 | Process trajectory datasets for geolife, porto, and rome, generating vocab, mapping, neighbors, 66 | and transformed trajectory data. 67 | 68 | Args: 69 | input_dir (Path): Directory containing the input datasets. 70 | output_dir (Path): Directory where the processed data will be saved. 71 | datasets_to_process (List[str]): List of datasets to process (e.g., 'geolife', 'porto', 'rome'). 
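        embedding_dim (Optional[int]): Dimension of the generated embedding vectors; if None,
            no `embeddings.npy` file is written.

    Example (hypothetical call mirroring the CLI defaults):

        process_datasets(Path("data"), Path("data"), ["rome"], embedding_dim=512)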
72 | """ 73 | datasets = { 74 | "geolife": [ 75 | ("geolife7", input_dir / "geolife" / "ho_geolife_res7.csv", "date"), 76 | ("geolife8", input_dir / "geolife" / "ho_geolife_res8.csv", "date"), 77 | ("geolife9", input_dir / "geolife" / "ho_geolife_res9.csv", "date") 78 | ], 79 | "porto": [ 80 | ("porto7", input_dir / "porto" / "ho_porto_res7.csv", "TIMESTAMP"), 81 | ("porto8", input_dir / "porto" / "ho_porto_res8.csv", "TIMESTAMP"), 82 | ("porto9", input_dir / "porto" / "ho_porto_res9.csv", "TIMESTAMP") 83 | ], 84 | "rome": [ 85 | ("rome7", input_dir / "rome" / "ho_rome_res7.csv", "date"), 86 | ("rome8", input_dir / "rome" / "ho_rome_res8.csv", "date"), 87 | ("rome9", input_dir / "rome" / "ho_rome_res9.csv", "date") 88 | ] 89 | } 90 | 91 | for dataset_key in datasets_to_process: 92 | if dataset_key in datasets: 93 | for dataset in datasets[dataset_key]: 94 | dataset_name, file_path, date_column = dataset 95 | 96 | dataset_output_dir = output_dir / dataset_name 97 | dataset_output_dir.mkdir(parents=True, exist_ok=True) 98 | 99 | if not file_path.exists(): 100 | print(f"Warning: {file_path} does not exist. Skipping this dataset.") 101 | continue 102 | 103 | df = pd.read_csv(file_path, header=0, usecols=["higher_order_trajectory", date_column], 104 | dtype={"higher_order_trajectory": "string", date_column: "string"}) 105 | df = df.sort_values(by=[date_column])["higher_order_trajectory"].to_numpy() 106 | 107 | df_split = [i.split() for i in df] 108 | 109 | vocab = list(np.unique(np.concatenate(df_split, axis=0))) 110 | 111 | if embedding_dim is not None: 112 | embeddings = generate_embeddings(vocab, embedding_dim) 113 | embeddings_file_path = dataset_output_dir / 'embeddings.npy' 114 | np.save(embeddings_file_path, embeddings) 115 | 116 | vocab = ["EOT"] + vocab 117 | vocab_file_path = dataset_output_dir / 'vocab.txt' 118 | with vocab_file_path.open('w', encoding='utf-8') as vocab_file: 119 | vocab_file.write("\n".join(vocab) + "\n") 120 | 121 | mapping = {k: v for v, k in enumerate(vocab)} 122 | mapping_file_path = dataset_output_dir / 'mapping.json' 123 | with mapping_file_path.open('w', encoding='utf-8') as mapping_file: 124 | json.dump(mapping, mapping_file, ensure_ascii=False) 125 | 126 | neighbors: Dict[int, List[int]] = dict() 127 | for x in vocab[1:]: 128 | neighbors[mapping[str(x)]] = [mapping[i] for i in h3.grid_ring(str(x)) if i in vocab] 129 | neighbors_file_path = dataset_output_dir / 'neighbors.json' 130 | with neighbors_file_path.open('w', encoding='utf-8') as neighbors_file: 131 | json.dump(neighbors, neighbors_file, ensure_ascii=False) 132 | 133 | df_mapped = [[str(mapping[j]) for j in i] for i in df_split] 134 | data_file_path = dataset_output_dir / 'data.txt' 135 | with data_file_path.open('w', encoding='utf-8') as data_file: 136 | for item in df_mapped: 137 | data_file.write(' '.join(item) + f" {mapping['EOT']}\n") 138 | 139 | print(f"Processing completed for {dataset_name}.") 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning for geolife, porto, and rome datasets') 143 | 144 | parser.add_argument('--input_dir', type=Path, default=Path('data'), 145 | help='Path to input dataset files (default: ./data)') 146 | 147 | parser.add_argument('--output_dir', type=Path, default=Path('data'), 148 | help='Path to output directory (default: ./data)') 149 | 150 | parser.add_argument('--datasets', type=str, nargs='+', choices=['geolife', 'porto', 'rome'], required=True, 151 | help='Specify which datasets to process 
(choose from geolife, porto, rome)') 152 | 153 | parser.add_argument('--embedding_dim', type=int, 154 | help="Dimension of the generated embedding vectors. If not provided, embeddings will not be generated.") 155 | 156 | args = parser.parse_args() 157 | 158 | process_datasets(args.input_dir, args.output_dir, args.datasets, args.embedding_dim) 159 | -------------------------------------------------------------------------------- /TrajLearn/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import torch 5 | from torch.optim import Optimizer 6 | from tqdm import tqdm 7 | from typing import Optional, Tuple 8 | 9 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset 10 | 11 | 12 | class Trainer: 13 | """ 14 | Trainer class to handle the training and validation of a given model on trajectory datasets. 15 | """ 16 | def __init__(self, model: torch.nn.Module, dataset: TrajectoryBatchDataset, config: dict, 17 | logger, model_checkpoint_directory: str, always_save_checkpoint: bool = False, 18 | optimizer: Optional[Optimizer] = None): 19 | """ 20 | Initialize the Trainer class with model, dataset, configurations, and other options. 21 | 22 | Args: 23 | model (torch.nn.Module): The model to train. 24 | dataset (TrajectoryBatchDataset): The dataset to use for training. 25 | config (dict): Configuration dictionary with training parameters. 26 | logger: Logger for logging training progress. 27 | model_checkpoint_directory (str): Directory to save model checkpoints. 28 | always_save_checkpoint (bool): Whether to save a checkpoint every epoch. Defaults to False. 29 | optimizer (Optional[Optimizer]): Optimizer to use. If None, a default optimizer is configured. 
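        Note: a validation split is also loaded from data_dir/dataset using the configured validation_ratio and test_ratio, and batched with min_input_length.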
30 | """ 31 | self.model = model 32 | self.logger = logger 33 | self.train_dataset = dataset 34 | self.always_save_checkpoint = always_save_checkpoint 35 | self.device = config["device"] 36 | self.device_type = 'cuda' if 'cuda' in self.device else 'cpu' 37 | self.config = config 38 | self.out_dir = model_checkpoint_directory 39 | self.max_epochs = config["max_epochs"] 40 | self.block_size = config["block_size"] 41 | self.batch_size = config["batch_size"] 42 | self.min_input_length = config["min_input_length"] 43 | self.max_input_length = config["max_input_length"] 44 | self.learning_rate = config["learning_rate"] 45 | self.weight_decay = config["weight_decay"] 46 | self.beta1 = config["beta1"] 47 | self.beta2 = config["beta2"] 48 | self.grad_clip = config["grad_clip"] 49 | self.decay_lr = config["decay_lr"] 50 | self.warmup_iters = config["warmup_iters"] 51 | self.lr_decay_iters = config["lr_decay_iters"] 52 | self.min_lr = config["min_lr"] 53 | self.patience = config["patience"] 54 | self.early_stopping_counter = 0 55 | 56 | dtype = 'float32' 57 | ptdtype = { 58 | 'float32': torch.float32, 59 | 'bfloat16': torch.bfloat16, 60 | 'float16': torch.float16 61 | }[dtype] 62 | self.ctx = torch.amp.autocast(device_type=self.device_type, dtype=ptdtype) 63 | self.scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 64 | 65 | if optimizer is None: 66 | self.optimizer = model.configure_optimizers( 67 | self.weight_decay, self.learning_rate, (self.beta1, self.beta2), self.device_type 68 | ) 69 | else: 70 | self.optimizer = optimizer 71 | 72 | input_lengths = list(range(self.min_input_length, self.max_input_length + 1)) 73 | self.train_dataset.create_batches(self.batch_size, input_lengths) 74 | 75 | self.validation_dataset = TrajectoryBatchDataset( 76 | os.path.join(config["data_dir"], config["dataset"]), 77 | dataset_type='val', 78 | delimiter=config["delimiter"], 79 | validation_ratio=config["validation_ratio"], 80 | test_ratio=config["test_ratio"] 81 | ) 82 | self.validation_dataset.create_batches(self.batch_size, self.min_input_length) 83 | 84 | @torch.no_grad() 85 | def val_epoch(self) -> Tuple[float, float]: 86 | """ 87 | Run a single validation epoch. 88 | 89 | Returns: 90 | Tuple[float, float]: Average validation loss and accuracy. 91 | """ 92 | self.model.eval() 93 | total_val_loss = 0 94 | total_correct = 0 95 | total_samples = 0 96 | 97 | for X, Y in tqdm(self.validation_dataset, leave=False): 98 | x = X.to(self.device) 99 | y = Y.to(self.device) 100 | with self.ctx: 101 | output, loss = self.model(x, y) 102 | total_val_loss += loss.item() 103 | total_correct += (output.argmax(dim=2)[:, -1] == y[:, -1]).sum().item() 104 | total_samples += Y.shape[0] 105 | 106 | avg_val_loss = total_val_loss / total_samples 107 | val_accuracy = total_correct / total_samples 108 | return avg_val_loss, val_accuracy 109 | 110 | def train(self): 111 | """ 112 | Train the model for a specified number of epochs, validate, and save checkpoints. 
113 | """ 114 | iter_num = 0 115 | best_val_loss = float('inf') 116 | self.logger.info("Starting training") 117 | 118 | for epoch in range(self.max_epochs): 119 | self.model.train() 120 | t_epoch_start = time.time() 121 | total_loss = 0 122 | total_samples = 0 123 | 124 | for X, Y in (pbar := tqdm(self.train_dataset, leave=False)): 125 | iter_num += 1 126 | lr = self.get_lr(iter_num) if self.decay_lr else self.learning_rate 127 | for param_group in self.optimizer.param_groups: 128 | param_group['lr'] = lr 129 | 130 | with self.ctx: 131 | _, loss = self.model(X.to(self.device), Y.to(self.device)) 132 | total_loss += loss.item() 133 | 134 | total_samples += X.shape[0] 135 | 136 | self.scaler.scale(loss).backward() 137 | if self.grad_clip != 0.0: 138 | self.scaler.unscale_(self.optimizer) 139 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) 140 | self.scaler.step(self.optimizer) 141 | self.scaler.update() 142 | self.optimizer.zero_grad(set_to_none=True) 143 | pbar.set_postfix({'loss': total_loss / total_samples}) 144 | 145 | dt = time.time() - t_epoch_start 146 | avg_loss = total_loss / total_samples 147 | self.logger.info(f"Training epoch {epoch + 1}/{self.max_epochs}, " 148 | f"Training loss: {avg_loss:.3g}, Time: {dt:.1f}s") 149 | 150 | t_val_start = time.time() 151 | avg_val_loss, val_accuracy = self.val_epoch() 152 | dt = time.time() - t_val_start 153 | self.logger.info(f'Validation loss: {avg_val_loss:.3g}, ' 154 | f'Validation Accuracy: {val_accuracy * 100:.2f}%, Time: {dt:.1f}s') 155 | 156 | if avg_val_loss < best_val_loss: 157 | best_val_loss = avg_val_loss 158 | self.save_checkpoint() 159 | self.early_stopping_counter = 0 160 | else: 161 | self.early_stopping_counter += 1 162 | if self.early_stopping_counter >= self.patience: 163 | self.logger.info("Early stopping triggered.") 164 | break 165 | 166 | def get_lr(self, it: int) -> float: 167 | """ 168 | Calculate learning rate with optional warmup and cosine decay. 169 | 170 | Args: 171 | it (int): Current iteration number. 172 | 173 | Returns: 174 | float: Calculated learning rate for the current iteration. 175 | """ 176 | if it < self.warmup_iters: 177 | return self.learning_rate * it / self.warmup_iters 178 | if it == self.warmup_iters: 179 | self.logger.info("Warm-up iterations ended, starting cosine decay") 180 | if it == self.lr_decay_iters: 181 | self.logger.info("Decay iterations ended, using minimum learning rate") 182 | if it >= self.lr_decay_iters: 183 | return self.min_lr 184 | 185 | decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters) 186 | assert 0 <= decay_ratio <= 1 187 | coefficient = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 188 | return self.min_lr + coefficient * (self.learning_rate - self.min_lr) 189 | 190 | def save_checkpoint(self): 191 | """ 192 | Save model and optimizer state as a checkpoint. 
193 | """ 194 | checkpoint = { 195 | 'model': self.model.state_dict(), 196 | 'optimizer': self.optimizer.state_dict(), 197 | 'config': self.config, 198 | } 199 | checkpoint_path = os.path.join(self.out_dir, 'checkpoint.pt') 200 | try: 201 | torch.save(checkpoint, checkpoint_path) 202 | self.logger.info("Saved current best model to " + checkpoint_path) 203 | except Exception as e: 204 | self.logger.error(f"Failed to save checkpoint: {e}") 205 | -------------------------------------------------------------------------------- /TrajLearn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import pickle 5 | import random 6 | from pathlib import Path 7 | from typing import Dict, Any, Optional 8 | 9 | import numpy as np 10 | import torch 11 | from TrajLearn.TrajectoryBatchDataset import TrajectoryBatchDataset 12 | from TrajLearn.model import ModelConfig, CausalLM 13 | from TrajLearn.evaluator import evaluate_model 14 | from TrajLearn.trainer import Trainer 15 | from TrajLearn.logger import get_logger 16 | from baselines import HigherOrderMarkovChain 17 | 18 | 19 | def setup_environment(seed: int) -> None: 20 | """ 21 | Set up the environment by configuring CUDA and setting random seeds. 22 | 23 | Args: 24 | - seed (int): The seed for random number generators. 25 | - device_id (str): The CUDA device ID to set for training. 26 | """ 27 | torch.cuda.cudnn_enabled = False 28 | torch.backends.cudnn.deterministic = True 29 | 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cuda.matmul.allow_tf32 = True 35 | torch.backends.cudnn.allow_tf32 = True 36 | 37 | 38 | def get_dataset(config: Dict[str, Any], test_mode: bool = False) -> TrajectoryBatchDataset: 39 | """ 40 | Load the trajectory dataset based on configuration. 41 | 42 | Args: 43 | - config (Dict[str, Any]): Configuration dictionary. 44 | - test_mode (bool): Whether to load test or training data (default is False). 45 | 46 | Returns: 47 | - TrajectoryBatchDataset: The dataset object. 48 | """ 49 | dataset_type = 'test' if test_mode else 'train' 50 | dataset_path = Path(config["data_dir"]) / config["dataset"] 51 | dataset = TrajectoryBatchDataset( 52 | dataset_path, 53 | dataset_type=dataset_type, 54 | delimiter=config["delimiter"], 55 | validation_ratio=config["validation_ratio"], 56 | test_ratio=config["test_ratio"] 57 | ) 58 | config["vocab_size"] = dataset.vocab_size 59 | return dataset 60 | 61 | 62 | def load_model(model: torch.nn.Module | HigherOrderMarkovChain, checkpoint_path: Optional[Path], device: str) -> torch.nn.Module: 63 | """ 64 | Load a model from a checkpoint. 65 | 66 | Args: 67 | - config (Dict[str, Any]): Configuration dictionary. 68 | - dataset (TrajectoryBatchDataset): Dataset to extract vocabulary size. 69 | - checkpoint_path (Optional[Path]): Path to the model checkpoint (default is None). 70 | 71 | Returns: 72 | - Module: The initialized model, possibly with loaded weights. 73 | """ 74 | if isinstance(model, HigherOrderMarkovChain): 75 | with open(checkpoint_path, 'rb') as f: 76 | checkpoint = pickle.load(f) 77 | optimizer = None 78 | else: 79 | checkpoint = torch.load(checkpoint_path, map_location=device) 80 | optimizer = checkpoint['optimizer'] 81 | config = checkpoint['config'] 82 | state_dict = checkpoint['model'] 83 | unwanted_prefix = '_orig_mod.' 
84 | for k, _ in list(state_dict.items()): 85 | if k.startswith(unwanted_prefix): 86 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 87 | model.load_state_dict(state_dict) 88 | 89 | return model, config, optimizer 90 | 91 | def initialize_model(config, custom_init=None): 92 | model_config = ModelConfig( 93 | block_size=config["block_size"], 94 | vocab_size=config["vocab_size"], 95 | n_layer=config["n_layer"], 96 | n_head=config["n_head"], 97 | n_embd=config["n_embd"], 98 | dropout=config["dropout"], 99 | bias=config["bias"] 100 | ) 101 | return CausalLM(model_config, custom_init) 102 | 103 | 104 | def train_model( 105 | name: str, 106 | dataset: TrajectoryBatchDataset, 107 | config: Dict[str, Any], 108 | model: Optional[torch.nn.Module | HigherOrderMarkovChain] = None 109 | ) -> None: 110 | """ 111 | Set up and execute the training process. 112 | 113 | Args: 114 | - name (str): Name for the current training session (used for saving logs/checkpoints). 115 | - dataset (TrajectoryBatchDataset): Dataset object for training. 116 | - config (Dict[str, Any]): Configuration dictionary. 117 | - model (Optional[torch.nn.Module]): The model to be trained (can be None before loading). 118 | """ 119 | time_str = name + "-" + time.strftime("%Y%m%d-%H%M%S") 120 | model_checkpoint_directory = Path(config["model_checkpoint_directory"]) / time_str 121 | Path(model_checkpoint_directory).mkdir(parents=True, exist_ok=True) 122 | log_directory = model_checkpoint_directory / 'logs' 123 | 124 | if model is None: 125 | if config['custom_initialization']: 126 | custom_init_path = os.path.join(config["data_dir"], config["dataset"], 'embeddings.npy') 127 | embeddings_np = np.load(custom_init_path) 128 | custom_init = torch.from_numpy(embeddings_np).to(torch.float32) 129 | model = initialize_model(config, custom_init=custom_init) 130 | else: 131 | model = initialize_model(config=config) 132 | 133 | if config['train_from_checkpoint_if_exist']: 134 | model_checkpoints = sorted(glob.glob(str(Path(config["model_checkpoint_directory"]) / (name + "-*")))) 135 | if len(model_checkpoints) > 0: 136 | last_checkpoint = Path(model_checkpoints[-1]) / 'checkpoint.pt' 137 | model, config, optimizer = load_model(model, last_checkpoint, config['device']) 138 | 139 | logger = get_logger(log_directory, name, phase="train") 140 | 141 | if isinstance(model, HigherOrderMarkovChain): 142 | model.train(dataset, logger, str(model_checkpoint_directory)) 143 | else: 144 | model.to(config["device"]) 145 | trainer = Trainer(model, dataset, config, logger, str(model_checkpoint_directory)) 146 | trainer.train() 147 | 148 | 149 | def test_model(name: str, dataset: TrajectoryBatchDataset, config: Dict[str, Any], model: Optional[torch.nn.Module] = None) -> list: 150 | """ 151 | Set up and execute the testing process. 152 | 153 | Args: 154 | - name (str): Name of the configuration (used for loading the model checkpoint). 155 | - dataset (TrajectoryBatchDataset): Dataset object for testing. 156 | - config (Dict[str, Any]): Configuration dictionary. 157 | - model (Optional[torch.nn.Module]): The model to be tested (can be None before loading). 
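    Returns: - list: Evaluation results produced by the evaluator (e.g., accuracy and BLEU summaries).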
158 | """ 159 | model_checkpoint_directory = sorted(glob.glob(str(Path(config["model_checkpoint_directory"]) / (name + "-2024*"))))[-1] 160 | log_directory = Path(model_checkpoint_directory) / 'logs' 161 | 162 | logger = get_logger(log_directory, name, phase="test") 163 | 164 | if model is None: 165 | model = initialize_model(config) 166 | 167 | checkpoint_path = Path(model_checkpoint_directory) / 'checkpoint.pt' 168 | model, _, __ = load_model(model, checkpoint_path, config['device']) 169 | 170 | prediction_length = config["test_prediction_length"] 171 | dataset.create_batches( 172 | config["batch_size"], config["test_input_length"], prediction_length, False, False) 173 | 174 | if isinstance(model, HigherOrderMarkovChain): 175 | results = model.evaluate(dataset) 176 | logger.info(", ".join(results)) 177 | return results 178 | else: 179 | model.to(config["device"]) 180 | return evaluate_model(model, dataset, config, logger) 181 | -------------------------------------------------------------------------------- /baselines/HigherOrderAttnLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | class HigherOrderAttnLSTM(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"]) 10 | self.lstm = nn.LSTM(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True) 11 | self.attention = nn.Linear(config["n_embd"], 1) 12 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"]) 13 | 14 | def forward(self, x, targets=None): 15 | x = self.embedding(x) 16 | x = torch.squeeze(x, dim=2) 17 | lstm_outputs, _ = self.lstm(x) 18 | attention_scores = self.attention(lstm_outputs) 19 | attention_weights = torch.softmax(attention_scores.squeeze(dim=-1), dim=1) 20 | attended_outputs = torch.bmm(attention_weights.unsqueeze(dim=1), lstm_outputs) 21 | logits = self.lm_head(attended_outputs) 22 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None 23 | return logits, loss 24 | 25 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 26 | return optim.AdamW(self.parameters(), betas=betas, lr=learning_rate) -------------------------------------------------------------------------------- /baselines/HigherOrderGRU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | class HigherOrderGRU(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"]) 10 | self.gru = nn.GRU(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True) 11 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"]) 12 | 13 | def forward(self, x, targets=None): 14 | x = self.embedding(x) 15 | x = torch.squeeze(x, dim=2) 16 | x, _ = self.gru(x) 17 | logits = self.lm_head(x[:, [-1], :]) 18 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None 19 | return logits, loss 20 | 21 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 22 | return optim.AdamW(self.parameters(), 
betas=betas, lr=learning_rate) -------------------------------------------------------------------------------- /baselines/HigherOrderLSTM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | 6 | class HigherOrderLSTM(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.embedding = nn.Embedding(config["vocab_size"], config["n_embd"]) 10 | self.lstm = nn.LSTM(config["n_embd"], config["n_embd"], num_layers=config["n_layer"], dropout=config["dropout"], batch_first=True) 11 | self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"]) 12 | 13 | def forward(self, x, targets=None): 14 | x = self.embedding(x) 15 | x = torch.squeeze(x, dim=2) 16 | x, _ = self.lstm(x) 17 | logits = self.lm_head(x[:, [-1], :]) 18 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) if targets is not None else None 19 | return logits, loss 20 | 21 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 22 | return optim.AdamW(self.parameters(), betas=betas, lr=learning_rate) -------------------------------------------------------------------------------- /baselines/HigherOrderMarkovChain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pickle 4 | from collections import defaultdict 5 | from nltk.translate.bleu_score import sentence_bleu 6 | from tqdm import tqdm 7 | 8 | # Assuming that the TrajectoryBatchDataset class is defined as provided 9 | # and that the dataset instance is ready and create_batches() has been called. 10 | 11 | class HigherOrderMarkovChain: 12 | def __init__(self, config, order=1): 13 | self.order = order 14 | self.transition_counts = None 15 | self.transition_probs = None 16 | self.state_index_mapping = {} 17 | self.index_state_mapping = {} 18 | self.states = [] 19 | self.num_states = 0 20 | self.config = config 21 | self.logger = None 22 | 23 | 24 | def train(self, dataset, logger, model_checkpoint_directory: str): 25 | """ 26 | Train the Higher-Order Markov Chain using the provided sequences. 27 | 28 | :param sequences: List of sequences (list of lists), where each sequence is a list of states. 
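    :param dataset: TrajectoryBatchDataset whose .data attribute provides these sequences. :param logger: Logger used to report training progress. :param model_checkpoint_directory: Directory where the trained model checkpoint is written.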
29 | """ 30 | logger.info("Building transition matrix...") 31 | self.logger = logger 32 | sequences = dataset.data 33 | self._build_state_mappings(sequences) 34 | self._build_transition_matrix(sequences) 35 | self.save_checkpoint(model_checkpoint_directory) 36 | 37 | def _build_state_mappings(self, sequences): 38 | # Build state to index mappings 39 | state_set = set() 40 | for sequence in sequences: 41 | state_set.update(sequence) 42 | self.states = list(state_set) 43 | self.num_states = len(self.states) 44 | self.state_index_mapping = {state: idx for idx, state in enumerate(self.states)} 45 | self.index_state_mapping = {idx: state for state, idx in self.state_index_mapping.items()} 46 | 47 | 48 | def _build_transition_matrix(self, sequences): 49 | # Initialize transition counts with ones for smoothing 50 | self.transition_counts = defaultdict(lambda: defaultdict(lambda: 1)) 51 | 52 | for sequence in tqdm(sequences, desc="Processing training sequences"): 53 | for i in range(len(sequence) - self.order): 54 | current_state = tuple(sequence[i:i + self.order]) 55 | next_state = sequence[i + self.order] 56 | self.transition_counts[current_state][next_state] += 1 57 | 58 | # Convert counts to probabilities 59 | self.transition_probs = {} 60 | for current_state, next_states in self.transition_counts.items(): 61 | total = sum(next_states.values()) 62 | self.transition_probs[current_state] = {state: count / total for state, count in next_states.items()} 63 | 64 | def predict_next_state(self, current_sequence): 65 | """ 66 | Predict the next state(s) given the current sequence. 67 | 68 | :param current_sequence: List of the current sequence of states. 69 | :return: List of predicted next states, sorted by probability in descending order. 70 | """ 71 | current_state = tuple(current_sequence[-self.order:]) 72 | next_state_probs = self.transition_probs.get(current_state, {}) 73 | if not next_state_probs: 74 | return [] 75 | sorted_next_states = sorted(next_state_probs.items(), key=lambda item: item[1], reverse=True) 76 | return [state for state, prob in sorted_next_states[:5]] # Return top 5 77 | 78 | def predict_next_n_steps(self, sequence, n=5): 79 | """ 80 | Predict the next n steps given the initial sequence. 81 | 82 | :param sequence: Initial sequence of states (list). 83 | :param n: Number of steps to predict. 84 | :return: List of lists containing predicted states at each step. 85 | """ 86 | predictions = [] 87 | current_sequence = sequence.copy() 88 | for _ in range(n): 89 | next_states = self.predict_next_state(current_sequence) 90 | if not next_states: 91 | break # If no next state is found 92 | predictions.append(next_states) 93 | current_sequence.append(next_states[0]) # Append the most probable next state 94 | return predictions 95 | 96 | def save_checkpoint(self, model_checkpoint_directory): 97 | """ 98 | Save the trained model to a file. 99 | 100 | :param filepath: Path to the file where the model will be saved. 
101 | """ 102 | model_data = { 103 | 'order': self.order, 104 | 'transition_probs': self.transition_probs, 105 | 'state_index_mapping': self.state_index_mapping, 106 | 'index_state_mapping': self.index_state_mapping, 107 | 'states': self.states, 108 | 'num_states': self.num_states 109 | } 110 | checkpoint = { 111 | 'model': model_data, 112 | 'config': self.config, 113 | } 114 | checkpoint_path = os.path.join(model_checkpoint_directory, 'checkpoint.pt') 115 | try: 116 | with open(checkpoint_path, 'wb') as f: 117 | pickle.dump(checkpoint, f) 118 | except Exception as e: 119 | self.logger.error(f"Failed to save checkpoint: {e}") 120 | 121 | def load_state_dict(self, model_data): 122 | """ 123 | Load a trained model from a file. 124 | 125 | :param filepath: Path to the file where the model is saved. 126 | :return: An instance of HigherOrderMarkovChain with loaded parameters. 127 | """ 128 | self.transition_probs = model_data['transition_probs'] 129 | self.state_index_mapping = model_data['state_index_mapping'] 130 | self.index_state_mapping = model_data['index_state_mapping'] 131 | self.states = model_data['states'] 132 | self.num_states = model_data['num_states'] 133 | self.order = model_data['order'] 134 | 135 | def evaluate(self, test_dataset): 136 | """ 137 | Evaluate the model using the test sequences. 138 | 139 | :param test_sequences: List of test sequences. 140 | :return: Dictionary containing evaluation metrics. 141 | """ 142 | print("Evaluating model...") 143 | total_predictions = 0 144 | hit_step5_at1 = 0 145 | hit_step5_at3 = 0 146 | hit_step5_at5 = 0 147 | bleu_scores = 0 148 | 149 | start_time = time.time() 150 | for test_sequence in tqdm(test_dataset.data, desc="Processing test sequences"): 151 | if len(test_sequence) < self.order + 5: 152 | continue 153 | for i in range(len(test_sequence) - (self.order + 5) + 1): 154 | original_sequence = test_sequence[i:i + self.order + 5] # Include the initial sequence and actual next steps 155 | observe_sequence = original_sequence[:self.order] 156 | actual_next_steps = original_sequence[self.order:self.order + 5] 157 | 158 | predicted_next_steps = self.predict_next_n_steps(observe_sequence, n=5) 159 | 160 | if len(predicted_next_steps) < 5: 161 | continue # Skip if predictions are incomplete 162 | 163 | actual_step5 = actual_next_steps[4] 164 | predicted_step5 = predicted_next_steps[4] 165 | 166 | if actual_step5 == predicted_step5[0]: 167 | hit_step5_at1 += 1 168 | if actual_step5 in predicted_step5[:3]: 169 | hit_step5_at3 += 1 170 | if actual_step5 in predicted_step5[:5]: 171 | hit_step5_at5 += 1 172 | 173 | # For BLEU score calculation 174 | predicted_sentence = [predicted_next_steps[j][0] for j in range(len(predicted_next_steps))] 175 | bleu_scores += sentence_bleu([actual_next_steps], predicted_sentence) 176 | 177 | total_predictions += 1 178 | 179 | test_duration = time.time() - start_time 180 | 181 | return [ 182 | f"Accuracy@1: {hit_step5_at1 / total_predictions if total_predictions else 0}", 183 | f"Accuracy@3: {hit_step5_at3 / total_predictions if total_predictions else 0}", 184 | f"Accuracy@5: {hit_step5_at5 / total_predictions if total_predictions else 0}", 185 | f"BLEU score: {bleu_scores / total_predictions if total_predictions else 0}", 186 | f"Test duration: {test_duration:.3f}(s)", 187 | f"Samples: {total_predictions}" 188 | ] 189 | -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .HigherOrderAttnLSTM import HigherOrderAttnLSTM 2 | from .HigherOrderLSTM import HigherOrderLSTM 3 | from .HigherOrderGRU import HigherOrderGRU 4 | from .HigherOrderMarkovChain import HigherOrderMarkovChain 5 | 6 | __all__ = ["HigherOrderAttnLSTM", "HigherOrderLSTM", "HigherOrderGRU", "HigherOrderMarkovChain"] 7 | -------------------------------------------------------------------------------- /configs.yaml: -------------------------------------------------------------------------------- 1 | porto7: 2 | test_ratio: 0.2 3 | validation_ratio: 0.1 4 | delimiter: " " 5 | min_input_length: 10 6 | max_input_length: 14 7 | test_input_length: 10 8 | test_prediction_length: 5 9 | batch_size: 256 10 | device: cuda 11 | max_epochs: 5 12 | block_size: 24 13 | learning_rate: 5.e-3 14 | weight_decay: 8.e-3 15 | beta1: 0.9 16 | beta2: 0.95 17 | grad_clip: 1.0 18 | decay_lr: True 19 | warmup_iters: 100 20 | lr_decay_iters: 20000 21 | min_lr: 5.e-8 22 | seed: 42 23 | data_dir: ./data 24 | dataset: porto7 25 | n_layer: 12 26 | n_head: 4 27 | n_embd: 512 28 | bias: False 29 | dropout: 0 30 | model_checkpoint_directory: ./models/ 31 | train_from_checkpoint_if_exist: False 32 | custom_initialization: False 33 | patience: 3 34 | porto8: 35 | test_ratio: 0.2 36 | validation_ratio: 0.1 37 | delimiter: " " 38 | min_input_length: 10 39 | max_input_length: 14 40 | test_input_length: 10 41 | test_prediction_length: 5 42 | batch_size: 256 43 | device: cuda 44 | max_epochs: 4 45 | block_size: 24 46 | learning_rate: 5.e-4 47 | weight_decay: 5.e-3 48 | beta1: 0.9 49 | beta2: 0.95 50 | grad_clip: 1.0 51 | decay_lr: True 52 | warmup_iters: 200 53 | lr_decay_iters: 40000 54 | min_lr: 1.e-8 55 | seed: 42 56 | data_dir: ./data 57 | dataset: porto8 58 | n_layer: 12 59 | n_head: 4 60 | n_embd: 512 61 | bias: False 62 | dropout: 0 63 | model_checkpoint_directory: ./models/ 64 | train_from_checkpoint_if_exist: False 65 | custom_initialization: False 66 | patience: 3 67 | porto9: 68 | test_ratio: 0.2 69 | validation_ratio: 0.1 70 | delimiter: " " 71 | min_input_length: 10 72 | max_input_length: 14 73 | test_input_length: 10 74 | test_prediction_length: 5 75 | batch_size: 256 76 | device: cuda 77 | max_epochs: 3 78 | block_size: 24 79 | learning_rate: 1.e-3 80 | weight_decay: 5.e-3 81 | beta1: 0.9 82 | beta2: 0.95 83 | grad_clip: 1.0 84 | decay_lr: True 85 | warmup_iters: 400 86 | lr_decay_iters: 80000 87 | min_lr: 5.e-9 88 | seed: 42 89 | data_dir: ./data 90 | dataset: porto9 91 | n_layer: 12 92 | n_head: 4 93 | n_embd: 512 94 | bias: False 95 | dropout: 0 96 | model_checkpoint_directory: ./models/ 97 | train_from_checkpoint_if_exist: False 98 | custom_initialization: False 99 | patience: 3 100 | rome7: 101 | test_ratio: 0.2 102 | validation_ratio: 0.1 103 | delimiter: " " 104 | min_input_length: 10 105 | max_input_length: 14 106 | test_input_length: 10 107 | test_prediction_length: 5 108 | batch_size: 256 109 | device: cuda 110 | max_epochs: 10 111 | block_size: 24 112 | learning_rate: 5.e-4 113 | weight_decay: 5.e-3 114 | beta1: 0.9 115 | beta2: 0.95 116 | grad_clip: 1.0 117 | decay_lr: True 118 | warmup_iters: 100 119 | lr_decay_iters: 10000 120 | min_lr: 5.e-8 121 | seed: 42 122 | data_dir: ./data 123 | dataset: rome7 124 | n_layer: 12 125 | n_head: 4 126 | n_embd: 512 127 | bias: False 128 | dropout: 0 129 | model_checkpoint_directory: ./models/ 130 | train_from_checkpoint_if_exist: False 131 | custom_initialization: False 132 | patience: 3 133 | rome8: 134 | test_ratio: 0.2 135 | validation_ratio: 0.1 136 | delimiter: " 
" 137 | min_input_length: 10 138 | max_input_length: 14 139 | test_input_length: 10 140 | test_prediction_length: 5 141 | batch_size: 256 142 | device: cuda 143 | max_epochs: 10 144 | block_size: 24 145 | learning_rate: 5.e-4 146 | weight_decay: 5.e-3 147 | beta1: 0.9 148 | beta2: 0.95 149 | grad_clip: 1.0 150 | decay_lr: True 151 | warmup_iters: 200 152 | lr_decay_iters: 24000 153 | min_lr: 1.e-8 154 | seed: 42 155 | data_dir: ./data 156 | dataset: rome8 157 | n_layer: 12 158 | n_head: 4 159 | n_embd: 512 160 | bias: False 161 | dropout: 0 162 | model_checkpoint_directory: ./models/ 163 | train_from_checkpoint_if_exist: False 164 | custom_initialization: False 165 | patience: 3 166 | rome9: 167 | test_ratio: 0.2 168 | validation_ratio: 0.1 169 | delimiter: " " 170 | min_input_length: 10 171 | max_input_length: 14 172 | test_input_length: 10 173 | test_prediction_length: 5 174 | batch_size: 256 175 | device: cuda 176 | max_epochs: 8 177 | block_size: 24 178 | learning_rate: 1.e-3 179 | weight_decay: 5.e-3 180 | beta1: 0.9 181 | beta2: 0.95 182 | grad_clip: 1.0 183 | decay_lr: True 184 | warmup_iters: 400 185 | lr_decay_iters: 60000 186 | min_lr: 5.e-9 187 | seed: 42 188 | data_dir: ./data 189 | dataset: rome9 190 | n_layer: 12 191 | n_head: 4 192 | n_embd: 512 193 | bias: False 194 | dropout: 0 195 | model_checkpoint_directory: ./models/ 196 | train_from_checkpoint_if_exist: False 197 | custom_initialization: False 198 | patience: 3 199 | geolife7: 200 | test_ratio: 0.2 201 | validation_ratio: 0.1 202 | delimiter: " " 203 | min_input_length: 10 204 | max_input_length: 14 205 | test_input_length: 10 206 | test_prediction_length: 5 207 | batch_size: 256 208 | device: cuda 209 | max_epochs: 10 210 | block_size: 24 211 | learning_rate: 5.e-4 212 | weight_decay: 2.e-3 213 | beta1: 0.9 214 | beta2: 0.95 215 | grad_clip: 1.0 216 | decay_lr: True 217 | warmup_iters: 100 218 | lr_decay_iters: 12000 219 | min_lr: 2.e-8 220 | seed: 42 221 | data_dir: ./data 222 | dataset: geolife7 223 | n_layer: 12 224 | n_head: 4 225 | n_embd: 512 226 | bias: False 227 | dropout: 0 228 | model_checkpoint_directory: ./models/ 229 | train_from_checkpoint_if_exist: False 230 | custom_initialization: False 231 | patience: 3 232 | geolife8: 233 | test_ratio: 0.2 234 | validation_ratio: 0.1 235 | delimiter: " " 236 | min_input_length: 10 237 | max_input_length: 14 238 | test_input_length: 10 239 | test_prediction_length: 5 240 | batch_size: 256 241 | device: cuda 242 | max_epochs: 10 243 | block_size: 24 244 | learning_rate: 5.e-4 245 | weight_decay: 5.e-3 246 | beta1: 0.9 247 | beta2: 0.95 248 | grad_clip: 1.0 249 | decay_lr: True 250 | warmup_iters: 200 251 | lr_decay_iters: 24000 252 | min_lr: 1.e-8 253 | seed: 42 254 | data_dir: ./data 255 | dataset: geolife8 256 | n_layer: 12 257 | n_head: 4 258 | n_embd: 512 259 | bias: False 260 | dropout: 0 261 | model_checkpoint_directory: ./models/ 262 | train_from_checkpoint_if_exist: False 263 | custom_initialization: False 264 | patience: 3 265 | geolife9: 266 | test_ratio: 0.2 267 | validation_ratio: 0.1 268 | delimiter: " " 269 | min_input_length: 10 270 | max_input_length: 14 271 | test_input_length: 10 272 | test_prediction_length: 5 273 | batch_size: 256 274 | device: cuda 275 | max_epochs: 6 276 | block_size: 24 277 | learning_rate: 5.e-3 278 | weight_decay: 1.e-2 279 | beta1: 0.9 280 | beta2: 0.95 281 | grad_clip: 1.0 282 | decay_lr: True 283 | warmup_iters: 400 284 | lr_decay_iters: 60000 285 | min_lr: 1.e-9 286 | seed: 42 287 | data_dir: ./data 288 | dataset: 
geolife9 289 | n_layer: 12 290 | n_head: 4 291 | n_embd: 512 292 | bias: False 293 | dropout: 0 294 | model_checkpoint_directory: ./models/ 295 | train_from_checkpoint_if_exist: False 296 | custom_initialization: False 297 | patience: 3 -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script downloads datasets, extracts them, and verifies MD5 checksums. 4 | 5 | # Exit the script if any command fails 6 | set -e 7 | 8 | # Variables 9 | DATA_DIR="./data" 10 | BASE_URL="https://zenodo.org/records/8076553/files" 11 | 12 | echo "Starting download and extraction script..." 13 | 14 | # Check if arguments are provided 15 | if [ "$#" -eq 0 ]; then 16 | echo "No datasets specified. Please provide dataset names as arguments." 17 | exit 1 18 | fi 19 | 20 | # Create data directory if it doesn't exist 21 | echo "Checking for data directory..." 22 | if [ ! -d "$DATA_DIR" ]; then 23 | echo "Creating data directory at $DATA_DIR" 24 | mkdir -p "$DATA_DIR" 25 | else 26 | echo "Data directory already exists. Proceeding..." 27 | fi 28 | 29 | # Loop through all datasets provided as arguments 30 | for DATASET in "$@"; do 31 | echo "Processing dataset: $DATASET" 32 | 33 | # Define output zip and dataset URL 34 | OUTPUT_ZIP="${DATASET}.zip" 35 | DATASET_URL="${BASE_URL}/ho_${OUTPUT_ZIP}?download=1" 36 | 37 | # Download the dataset 38 | echo "Checking if the dataset is already downloaded..." 39 | if [ -f "$DATA_DIR/$OUTPUT_ZIP" ]; then 40 | echo "Dataset zip file already exists. Skipping download for $DATASET." 41 | else 42 | echo "Downloading dataset from $DATASET_URL..." 43 | wget --show-progress -O "$DATA_DIR/$OUTPUT_ZIP" "$DATASET_URL" 44 | echo "Download completed for $DATASET." 45 | fi 46 | 47 | # Extract the dataset 48 | echo "Checking if dataset is already extracted..." 49 | if [ -d "$DATA_DIR/$DATASET" ]; then 50 | echo "Dataset already extracted. Skipping extraction for $DATASET." 51 | else 52 | echo "Extracting dataset to $DATASET..." 53 | unzip -q "$DATA_DIR/$OUTPUT_ZIP" -d "$DATA_DIR/$DATASET" 54 | echo "Extraction completed for $DATASET." 55 | fi 56 | 57 | # Clean up the zip file 58 | if [ -f "$DATA_DIR/$OUTPUT_ZIP" ]; then 59 | echo "Cleaning up zip file for $DATASET..." 60 | rm -f "$DATA_DIR/$OUTPUT_ZIP" 61 | fi 62 | 63 | # Verify MD5 checksums 64 | echo "Verifying MD5 checksums for files in $DATA_DIR/$DATASET..." 65 | for file in "$DATA_DIR/$DATASET"/*; do 66 | if [ -f "$file" ]; then 67 | md5sum "$file" 68 | fi 69 | done 70 | echo "MD5 checksum verification completed for $DATASET." 71 | done 72 | 73 | echo "Script completed successfully for all datasets!" 
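# Note: the md5sum output above is informational; compare it against the checksums published with the datasets if verification is required.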
74 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: trajlearn 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.10 7 | - pip: 8 | - transformers==4.44.1 9 | - numpy==1.26.4 10 | - pandas==2.2.2 11 | - h3==4.1.2 12 | - torch==2.3.0 13 | - geojson==3.0.1 14 | - geopandas==1.0.1 15 | - geopy==2.4.1 16 | - matplotlib==3.9.0 17 | - nltk==3.8.1 18 | - PyYAML==6.0.1 -------------------------------------------------------------------------------- /img/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amir-ni/Trajectory-prediction/570b89b4e0261056df3297487726b68a4983df9f/img/architecture.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from TrajLearn.utils import setup_environment, get_dataset, test_model, train_model 4 | from TrajLearn.config_loader import load_config 5 | from baselines import HigherOrderAttnLSTM, HigherOrderLSTM, HigherOrderGRU, HigherOrderMarkovChain 6 | 7 | def main() -> None: 8 | """ 9 | Main function to handle argument parsing and execute training or testing. 10 | """ 11 | parser = argparse.ArgumentParser(description='Trajectory Prediction Learning') 12 | parser.add_argument('config', type=str, help='Path to configuration file') 13 | parser.add_argument('--baseline', type=str, required=False, help='Baseline model to run') 14 | parser.add_argument('--test', default=False, action='store_true') 15 | args = parser.parse_args() 16 | 17 | config_list = load_config(args.config) 18 | 19 | for name, config in config_list.items(): 20 | setup_environment(config["seed"]) 21 | 22 | dataset = get_dataset(config, test_mode=args.test) 23 | 24 | if args.baseline: 25 | if args.baseline == "gru": 26 | model = HigherOrderGRU(config) 27 | elif args.baseline == "lstm": 28 | model = HigherOrderLSTM(config) 29 | elif args.baseline == "lstm-attn": 30 | model = HigherOrderAttnLSTM(config) 31 | elif args.baseline == "mc": 32 | model = HigherOrderMarkovChain(config) 33 | else: 34 | raise ValueError("Baseline not found.") 35 | else: 36 | model = None 37 | 38 | if args.test: 39 | test_model(name, dataset, config, model) 40 | else: 41 | train_model(name, dataset, config, model) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | --------------------------------------------------------------------------------
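
In addition to training, `main.py` exposes a `--test` flag for evaluation and a `--baseline` option (`gru`, `lstm`, `lstm-attn`, or `mc`), as defined in its argument parser above. A minimal usage sketch, assuming the bundled `configs.yaml` and checkpoints already written to `./models/` by a prior training run:

```bash
# Evaluate previously trained TrajLearn models for every configuration in configs.yaml
python3 main.py configs.yaml --test

# Train a baseline model instead of the transformer (choices: gru, lstm, lstm-attn, mc)
python3 main.py configs.yaml --baseline gru

# Evaluate that baseline after training
python3 main.py configs.yaml --baseline gru --test
```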