├── .github └── workflows │ ├── ci.yml │ └── pypi_publish.yml ├── .gitignore ├── CITATION.cff ├── DATASETS.md ├── LICENSE ├── README.md ├── examples ├── batch_example.py ├── cache_and_filter_example.py ├── custom_batch_data.py ├── lane_query_example.py ├── map_api_example.py ├── preprocess_data.py ├── preprocess_maps.py ├── scene_batch_example.py ├── scenetimebatcher_example.py ├── sim_example.py ├── simple_map_api_example.py ├── simple_sim_example.py ├── speed_example.py ├── state_example.py └── visualization_example.py ├── img └── architecture.png ├── pyproject.toml ├── src └── trajdata │ ├── __init__.py │ ├── augmentation │ ├── __init__.py │ ├── augmentation.py │ ├── low_vel_yaw_correction.py │ └── noise_histories.py │ ├── caching │ ├── __init__.py │ ├── df_cache.py │ ├── env_cache.py │ └── scene_cache.py │ ├── data_structures │ ├── __init__.py │ ├── agent.py │ ├── batch.py │ ├── batch_element.py │ ├── collation.py │ ├── data_index.py │ ├── environment.py │ ├── scene.py │ ├── scene_metadata.py │ ├── scene_tag.py │ └── state.py │ ├── dataset.py │ ├── dataset_specific │ ├── __init__.py │ ├── argoverse2 │ │ ├── __init__.py │ │ ├── av2_dataset.py │ │ └── av2_utils.py │ ├── eth_ucy_peds │ │ ├── __init__.py │ │ └── eupeds_dataset.py │ ├── interaction │ │ ├── __init__.py │ │ └── interaction_dataset.py │ ├── lyft │ │ ├── __init__.py │ │ ├── lyft_dataset.py │ │ └── lyft_utils.py │ ├── nuplan │ │ ├── __init__.py │ │ ├── nuplan_dataset.py │ │ └── nuplan_utils.py │ ├── nusc │ │ ├── __init__.py │ │ ├── nusc_dataset.py │ │ └── nusc_utils.py │ ├── raw_dataset.py │ ├── scene_records.py │ ├── sdd_peds │ │ ├── __init__.py │ │ ├── estimated_homography.py │ │ └── sddpeds_dataset.py │ ├── vod │ │ ├── __init__.py │ │ ├── vod_dataset.py │ │ └── vod_utils.py │ └── waymo │ │ ├── __init__.py │ │ ├── waymo_dataset.py │ │ └── waymo_utils.py │ ├── filtering │ ├── __init__.py │ └── filters.py │ ├── maps │ ├── __init__.py │ ├── lane_route.py │ ├── map_api.py │ ├── map_kdtree.py │ ├── map_strtree.py │ ├── raster_map.py │ ├── traffic_light_status.py │ ├── vec_map.py │ └── vec_map_elements.py │ ├── parallel │ ├── __init__.py │ └── data_preprocessor.py │ ├── proto │ ├── __init__.py │ ├── vectorized_map.proto │ └── vectorized_map_pb2.py │ ├── simulation │ ├── __init__.py │ ├── sim_cache.py │ ├── sim_df_cache.py │ ├── sim_metrics.py │ ├── sim_scene.py │ ├── sim_stats.py │ └── sim_vis.py │ ├── utils │ ├── __init__.py │ ├── agent_utils.py │ ├── arr_utils.py │ ├── batch_utils.py │ ├── df_utils.py │ ├── env_utils.py │ ├── map_utils.py │ ├── parallel_utils.py │ ├── py_utils.py │ ├── raster_utils.py │ ├── scene_utils.py │ ├── state_utils.py │ ├── string_utils.py │ └── vis_utils.py │ └── visualization │ ├── __init__.py │ ├── interactive_animation.py │ ├── interactive_figure.py │ ├── interactive_vis.py │ └── vis.py └── tests ├── __init__.py ├── test_batch_conversion.py ├── test_collation.py ├── test_dataset.py ├── test_datasizes.py ├── test_description_matching.py ├── test_state.py ├── test_traffic_data.py └── test_vec_map.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - '**' 6 | pull_request: 7 | branches: 8 | - '**' 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | container: python:3.8-slim 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Install prerequisites (for OpenCV) 17 | run: apt-get update && apt-get install ffmpeg libsm6 libxext6 -y 18 | - name: Install trajdata base version 19 | run: 
python -m pip install . 20 | - name: Run tests 21 | run: python -m unittest tests/test_state.py 22 | -------------------------------------------------------------------------------- /.github/workflows/pypi_publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | workflow_dispatch: 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python 27 | uses: actions/setup-python@v3 28 | with: 29 | python-version: '3.x' 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install build 34 | - name: Build package 35 | run: python -m build 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@release/v1 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | public/ 2 | .vscode/ 3 | *.html 4 | *.mp4 5 | *.avi 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 111 | __pypackages__/ 112 | 113 | # Celery stuff 114 | celerybeat-schedule 115 | celerybeat.pid 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | # pytype static type analyzer 148 | .pytype/ 149 | 150 | # Cython debug symbols 151 | cython_debug/ 152 | 153 | # PyCharm 154 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 156 | # and can be added to the global gitignore or merged into this file. For a more nuclear 157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 158 | #.idea/ 159 | 160 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as follows." 
3 | authors: 4 | - family-names: "Ivanovic" 5 | given-names: "Boris" 6 | orcid: "https://orcid.org/0000-0002-8698-202X" 7 | title: "trajdata: A unified interface to many trajectory forecasting datasets" 8 | version: 1.3.3 9 | doi: 10.5281/zenodo.6671548 10 | date-released: 2023-08-22 11 | url: "https://github.com/nvr-avg/trajdata" 12 | preferred-citation: 13 | type: conference-paper 14 | authors: 15 | - family-names: "Ivanovic" 16 | given-names: "Boris" 17 | orcid: "https://orcid.org/0000-0002-8698-202X" 18 | - family-names: "Song" 19 | given-names: "Guanyu" 20 | - family-names: "Gilitschenski" 21 | given-names: "Igor" 22 | - family-names: "Pavone" 23 | given-names: "Marco" 24 | journal: "Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks" 25 | month: 12 26 | title: "trajdata: A Unified Interface to Multiple Human Trajectory Datasets" 27 | year: 2023 -------------------------------------------------------------------------------- /examples/batch_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | 6 | from trajdata import AgentBatch, AgentType, UnifiedDataset 7 | from trajdata.augmentation import NoiseHistories 8 | from trajdata.visualization.vis import plot_agent_batch 9 | 10 | 11 | def main(): 12 | noise_hists = NoiseHistories() 13 | 14 | dataset = UnifiedDataset( 15 | desired_data=["nusc_mini-mini_train"], 16 | centric="agent", 17 | desired_dt=0.1, 18 | history_sec=(3.2, 3.2), 19 | future_sec=(4.8, 4.8), 20 | only_predict=[AgentType.VEHICLE], 21 | agent_interaction_distances=defaultdict(lambda: 30.0), 22 | incl_robot_future=False, 23 | incl_raster_map=True, 24 | raster_map_params={ 25 | "px_per_m": 2, 26 | "map_size_px": 224, 27 | "offset_frac_xy": (-0.5, 0.0), 28 | }, 29 | augmentations=[noise_hists], 30 | num_workers=0, 31 | verbose=True, 32 | data_dirs={ # Remember to change this to match your filesystem! 
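        # (A brief note: data_dirs maps each environment name in desired_data to
        #  wherever that raw dataset lives on your disk; the entry below is just
        #  an example path.)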
33 | "nusc_mini": "~/datasets/nuScenes", 34 | }, 35 | ) 36 | 37 | print(f"# Data Samples: {len(dataset):,}") 38 | 39 | dataloader = DataLoader( 40 | dataset, 41 | batch_size=4, 42 | shuffle=True, 43 | collate_fn=dataset.get_collate_fn(), 44 | num_workers=4, 45 | ) 46 | 47 | batch: AgentBatch 48 | for batch in tqdm(dataloader): 49 | plot_agent_batch(batch, batch_idx=0) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /examples/cache_and_filter_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | 7 | from trajdata import AgentBatch, AgentType, UnifiedDataset 8 | from trajdata.augmentation import NoiseHistories 9 | from trajdata.data_structures.batch_element import AgentBatchElement 10 | from trajdata.visualization.vis import plot_agent_batch 11 | 12 | 13 | def main(): 14 | noise_hists = NoiseHistories() 15 | 16 | create_dataset = lambda: UnifiedDataset( 17 | desired_data=["nusc_mini-mini_val"], 18 | centric="agent", 19 | desired_dt=0.5, 20 | history_sec=(2.0, 2.0), 21 | future_sec=(4.0, 4.0), 22 | only_predict=[AgentType.VEHICLE], 23 | agent_interaction_distances=defaultdict(lambda: 30.0), 24 | incl_robot_future=False, 25 | incl_raster_map=False, 26 | # map_params={"px_per_m": 2, "map_size_px": 224, "offset_frac_xy": (-0.5, 0.0)}, 27 | augmentations=[noise_hists], 28 | num_workers=0, 29 | verbose=True, 30 | data_dirs={ # Remember to change this to match your filesystem! 31 | "nusc_mini": "~/datasets/nuScenes", 32 | }, 33 | ) 34 | 35 | dataset = create_dataset() 36 | 37 | print(f"# Data Samples: {len(dataset):,}") 38 | 39 | print( 40 | "To demonstrate how to use caching we will first save the " 41 | "entire dataset (all BatchElements) to a cache file and then load from " 42 | "the cache file. Note that for large datasets and/or high time resolution " 43 | "this will create a large file and will use a lot of RAM." 44 | ) 45 | cache_path = "./temp_cache_file.dill" 46 | 47 | print( 48 | "We also use a custom filter function that only keeps elements with more " 49 | "than 5 neighbors" 50 | ) 51 | 52 | def my_filter(el: AgentBatchElement) -> bool: 53 | return el.num_neighbors > 5 54 | 55 | print( 56 | f"In the first run we will iterate through the entire dataset and save all " 57 | f"BatchElements to the cache file {cache_path}" 58 | ) 59 | print("This may take several minutes.") 60 | dataset.load_or_create_cache( 61 | cache_path=cache_path, num_workers=0, filter_fn=my_filter 62 | ) 63 | assert os.path.isfile(cache_path) 64 | 65 | print( 66 | "To demonstrate a consecuitve run we create a new dataset and load elements " 67 | "from the cache file." 68 | ) 69 | del dataset 70 | dataset = create_dataset() 71 | 72 | dataset.load_or_create_cache( 73 | cache_path=cache_path, num_workers=0, filter_fn=my_filter 74 | ) 75 | 76 | # Remove the temp cache file, we dont need it anymore. 77 | os.remove(cache_path) 78 | 79 | print( 80 | "We can iterate through the dataset the same way as normally, but this " 81 | "time it will be much faster because all BatchElements are in memory." 
82 | ) 83 | dataloader = DataLoader( 84 | dataset, 85 | batch_size=4, 86 | shuffle=True, 87 | collate_fn=dataset.get_collate_fn(), 88 | num_workers=0, 89 | ) 90 | 91 | batch: AgentBatch 92 | for batch in tqdm(dataloader): 93 | plot_agent_batch(batch, batch_idx=0) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /examples/custom_batch_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an example of how to extend a batch to include custom data 3 | """ 4 | 5 | from collections import defaultdict 6 | from functools import partial 7 | from typing import Tuple, Union 8 | 9 | import numpy as np 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from trajdata import AgentBatch, AgentType, UnifiedDataset 14 | from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement 15 | 16 | 17 | def custom_random_data( 18 | batch_elem: Union[AgentBatchElement, SceneBatchElement] 19 | ) -> np.ndarray: 20 | # create new data to add to each batch element 21 | return np.random.random((10, 10)) 22 | 23 | 24 | def custom_goal_location( 25 | batch_elem: Union[AgentBatchElement, SceneBatchElement] 26 | ) -> np.ndarray: 27 | # simply access existing element attributes 28 | return batch_elem.agent_future_np.position 29 | 30 | 31 | def custom_min_distance_from_others( 32 | batch_elem: Union[AgentBatchElement, SceneBatchElement] 33 | ) -> np.ndarray: 34 | # ... or more complicated calculations 35 | current_ego_loc = batch_elem.agent_history_np[-1, :2] 36 | all_distances = [ 37 | np.linalg.norm(current_ego_loc - veh[-1, :2]) 38 | for veh in batch_elem.neighbor_histories 39 | ] 40 | 41 | if not len(all_distances): 42 | return np.inf 43 | else: 44 | return np.min(all_distances) 45 | 46 | 47 | def custom_distances_squared( 48 | batch_elem: Union[AgentBatchElement, SceneBatchElement] 49 | ) -> np.ndarray: 50 | # we can chain extras together if needed 51 | return batch_elem.extras["min_distance"] ** 2 52 | 53 | 54 | def custom_raster( 55 | batch_elem: Union[AgentBatchElement, SceneBatchElement], 56 | raster_size: Tuple[int, ...], 57 | ) -> np.ndarray: 58 | # draw a custom raster 59 | img = np.zeros(raster_size) 60 | 61 | # ... 62 | return img 63 | 64 | 65 | def main(): 66 | dataset = UnifiedDataset( 67 | desired_data=["nusc_mini-mini_train"], 68 | centric="agent", 69 | desired_dt=0.1, 70 | history_sec=(3.2, 3.2), 71 | future_sec=(4.8, 4.8), 72 | only_types=[AgentType.VEHICLE], 73 | agent_interaction_distances=defaultdict(lambda: 30.0), 74 | incl_robot_future=False, 75 | incl_raster_map=True, 76 | raster_map_params={ 77 | "px_per_m": 2, 78 | "map_size_px": 224, 79 | "offset_frac_xy": (-0.5, 0.0), 80 | }, 81 | num_workers=0, 82 | verbose=True, 83 | data_dirs={ # Remember to change this to match your filesystem! 84 | "nusc_mini": "~/datasets/nuScenes", 85 | }, 86 | extras={ # a dictionary that contains functions that generate our custom data. Can be any function and has access to the batch element. 
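        # (Each value here is a callable taking the AgentBatchElement/SceneBatchElement
        #  and returning the custom data to attach; after collation, the results appear
        #  under the same keys in batch.extras, as the asserts at the bottom of this
        #  file check.)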
87 | "random_data": custom_random_data, 88 | "goal_location": custom_goal_location, 89 | "min_distance": custom_min_distance_from_others, 90 | "min_distance_sq": custom_distances_squared, # in Python >= 3.7 dictionaries are guaranteed to maintain order => you can use previously computed keys 91 | "raster": partial(custom_raster, raster_size=(100, 100)), 92 | }, 93 | ) 94 | 95 | print(f"# Data Samples: {len(dataset):,}") 96 | 97 | dataloader = DataLoader( 98 | dataset, 99 | batch_size=4, 100 | shuffle=True, 101 | collate_fn=dataset.get_collate_fn(), 102 | num_workers=4, 103 | ) 104 | 105 | batch: AgentBatch 106 | for batch in tqdm(dataloader): 107 | assert "random_data" in batch.extras 108 | assert "goal_location" in batch.extras 109 | assert "min_distance" in batch.extras 110 | assert "raster" in batch.extras 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /examples/lane_query_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an example of how to extend a batch with lane information 3 | """ 4 | 5 | import random 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | 14 | from trajdata import AgentBatch, AgentType, UnifiedDataset 15 | from trajdata.data_structures.batch_element import AgentBatchElement 16 | from trajdata.maps import VectorMap 17 | from trajdata.utils.arr_utils import transform_angles_np, transform_coords_np 18 | from trajdata.utils.state_utils import transform_state_np_2d 19 | from trajdata.visualization.vis import plot_agent_batch 20 | 21 | 22 | def get_closest_lane_point(element: AgentBatchElement) -> np.ndarray: 23 | """Closest lane for predicted agent.""" 24 | 25 | # Transform from agent coordinate frame to world coordinate frame. 
26 | vector_map: VectorMap = element.vec_map 27 | world_from_agent_tf = np.linalg.inv(element.agent_from_world_tf) 28 | agent_future_xyzh_world = transform_state_np_2d( 29 | element.agent_future_np, world_from_agent_tf 30 | ).as_format("x,y,z,h") 31 | 32 | # Use cached kdtree to find closest lane point 33 | lane_points = [] 34 | for point_xyzh in agent_future_xyzh_world: 35 | possible_lanes = vector_map.get_current_lane(point_xyzh) 36 | xyzh_on_lane = np.full((1, 4), np.nan) 37 | if len(possible_lanes) > 0: 38 | xyzh_on_lane = possible_lanes[0].center.project_onto(point_xyzh[None, :3]) 39 | xyzh_on_lane[:, :2] = transform_coords_np( 40 | xyzh_on_lane[:, :2], element.agent_from_world_tf 41 | ) 42 | xyzh_on_lane[:, -1] = transform_angles_np( 43 | xyzh_on_lane[:, -1], element.agent_from_world_tf 44 | ) 45 | 46 | lane_points.append(xyzh_on_lane) 47 | 48 | lane_points = np.concatenate(lane_points, axis=0) 49 | return lane_points 50 | 51 | 52 | def main(): 53 | dataset = UnifiedDataset( 54 | desired_data=[ 55 | # "nusc_mini-mini_train", 56 | "lyft_sample-mini_val", 57 | ], 58 | centric="agent", 59 | desired_dt=0.1, 60 | history_sec=(3.2, 3.2), 61 | future_sec=(4.8, 4.8), 62 | only_types=[AgentType.VEHICLE], 63 | state_format="x,y,z,xd,yd,xdd,ydd,h", 64 | obs_format="x,y,z,xd,yd,xdd,ydd,s,c", 65 | agent_interaction_distances=defaultdict(lambda: 30.0), 66 | incl_robot_future=False, 67 | incl_raster_map=True, 68 | raster_map_params={ 69 | "px_per_m": 2, 70 | "map_size_px": 224, 71 | "offset_frac_xy": (-0.5, 0.0), 72 | }, 73 | incl_vector_map=True, 74 | num_workers=0, 75 | verbose=True, 76 | data_dirs={ # Remember to change this to match your filesystem! 77 | # "nusc_mini": "~/datasets/nuScenes", 78 | "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", 79 | }, 80 | # A dictionary that contains functions that generate our custom data. 81 | # Can be any function and has access to the batch element. 
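        # (get_closest_lane_point above returns a (T, 4) x,y,z,h array per element;
        #  after collation it surfaces as the "closest_lane_point" tensor in
        #  batch.extras, which the plotting code below consumes.)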
82 | extras={ 83 | "closest_lane_point": get_closest_lane_point, 84 | }, 85 | ) 86 | 87 | print(f"# Data Samples: {len(dataset):,}") 88 | 89 | dataloader = DataLoader( 90 | dataset, 91 | batch_size=4, 92 | shuffle=False, 93 | collate_fn=dataset.get_collate_fn(), 94 | num_workers=0, 95 | ) 96 | 97 | # Visualize selected examples 98 | num_plots = 3 99 | # batch_idxs = [10876, 10227, 1284] 100 | batch_idxs = random.sample(range(len(dataset)), num_plots) 101 | batch: AgentBatch = dataset.get_collate_fn(pad_format="right")( 102 | [dataset[i] for i in batch_idxs] 103 | ) 104 | assert "closest_lane_point" in batch.extras 105 | 106 | for batch_i in range(num_plots): 107 | ax = plot_agent_batch( 108 | batch, batch_idx=batch_i, legend=False, show=False, close=False 109 | ) 110 | lane_points = batch.extras["closest_lane_point"][batch_i] 111 | lane_points = lane_points[ 112 | torch.logical_not(torch.any(torch.isnan(lane_points), dim=1)), : 113 | ].numpy() 114 | 115 | ax.plot( 116 | lane_points[:, 0], 117 | lane_points[:, 1], 118 | "o-", 119 | markersize=3, 120 | label="Lane points", 121 | ) 122 | 123 | ax.legend(loc="best", frameon=True) 124 | 125 | plt.show() 126 | plt.close("all") 127 | 128 | # Scan through dataset 129 | batch: AgentBatch 130 | for idx, batch in enumerate(tqdm(dataloader)): 131 | assert "closest_lane_point" in batch.extras 132 | if idx > 50: 133 | break 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /examples/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from trajdata import UnifiedDataset 4 | 5 | 6 | def main(): 7 | dataset = UnifiedDataset( 8 | desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], 9 | rebuild_cache=True, 10 | rebuild_maps=True, 11 | num_workers=os.cpu_count(), 12 | verbose=True, 13 | data_dirs={ # Remember to change this to match your filesystem! 14 | "nusc_mini": "~/datasets/nuScenes", 15 | "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", 16 | "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", 17 | }, 18 | ) 19 | print(f"Total Data Samples: {len(dataset):,}") 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /examples/preprocess_maps.py: -------------------------------------------------------------------------------- 1 | from trajdata import UnifiedDataset 2 | 3 | 4 | # @profile 5 | def main(): 6 | dataset = UnifiedDataset( 7 | # TODO(bivanovic@nvidia.com) Remove lyft from default examples 8 | desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"], 9 | rebuild_maps=True, 10 | data_dirs={ # Remember to change this to match your filesystem! 
11 | "nusc_mini": "~/datasets/nuScenes", 12 | "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", 13 | "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", 14 | }, 15 | verbose=True, 16 | ) 17 | print(f"Finished Caching Maps!") 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /examples/scene_batch_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | 6 | from trajdata import AgentType, SceneBatch, UnifiedDataset 7 | from trajdata.augmentation import NoiseHistories 8 | from trajdata.visualization.vis import plot_scene_batch 9 | 10 | 11 | def main(): 12 | noise_hists = NoiseHistories() 13 | 14 | dataset = UnifiedDataset( 15 | desired_data=["nusc_mini-mini_train"], 16 | centric="scene", 17 | desired_dt=0.1, 18 | history_sec=(3.2, 3.2), 19 | future_sec=(4.8, 4.8), 20 | only_types=[AgentType.VEHICLE], 21 | agent_interaction_distances=defaultdict(lambda: 30.0), 22 | incl_robot_future=True, 23 | incl_raster_map=True, 24 | raster_map_params={ 25 | "px_per_m": 2, 26 | "map_size_px": 224, 27 | "offset_frac_xy": (-0.5, 0.0), 28 | }, 29 | augmentations=[noise_hists], 30 | max_agent_num=20, 31 | num_workers=4, 32 | verbose=True, 33 | data_dirs={ # Remember to change this to match your filesystem! 34 | "nusc_mini": "~/datasets/nuScenes", 35 | }, 36 | ) 37 | 38 | print(f"# Data Samples: {len(dataset):,}") 39 | 40 | dataloader = DataLoader( 41 | dataset, 42 | batch_size=4, 43 | shuffle=True, 44 | collate_fn=dataset.get_collate_fn(), 45 | num_workers=4, 46 | ) 47 | 48 | batch: SceneBatch 49 | for batch in tqdm(dataloader): 50 | plot_scene_batch(batch, batch_idx=0) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /examples/scenetimebatcher_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | 6 | from trajdata import AgentBatch, AgentType, UnifiedDataset 7 | from trajdata.utils.batch_utils import SceneTimeBatcher 8 | from trajdata.visualization.vis import plot_agent_batch_all 9 | 10 | 11 | def main(): 12 | """ 13 | Here, we use SceneTimeBatcher to loop through an 14 | Agent-centric dataset with batches grouped by scene and timestep 15 | """ 16 | dataset = UnifiedDataset( 17 | desired_data=["nusc_mini-mini_train"], 18 | centric="agent", 19 | desired_dt=0.1, 20 | history_sec=(3.2, 3.2), 21 | future_sec=(4.8, 4.8), 22 | only_predict=[AgentType.VEHICLE], 23 | agent_interaction_distances=defaultdict(lambda: 30.0), 24 | incl_robot_future=False, 25 | incl_raster_map=True, 26 | raster_map_params={ 27 | "px_per_m": 2, 28 | "map_size_px": 224, 29 | "offset_frac_xy": (-0.5, 0.0), 30 | }, 31 | num_workers=0, 32 | verbose=True, 33 | data_dirs={ # Remember to change this to match your filesystem! 
34 | "nusc_mini": "~/datasets/nuScenes", 35 | }, 36 | ) 37 | 38 | print(f"# Data Samples: {len(dataset):,}") 39 | 40 | dataloader = DataLoader( 41 | dataset, 42 | batch_sampler=SceneTimeBatcher(dataset), 43 | collate_fn=dataset.get_collate_fn(), 44 | num_workers=4, 45 | ) 46 | 47 | batch: AgentBatch 48 | for batch in tqdm(dataloader): 49 | plot_agent_batch_all(batch) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /examples/sim_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, List, Tuple 3 | 4 | import numpy as np 5 | from tqdm import trange 6 | 7 | from trajdata import AgentBatch, AgentType, UnifiedDataset 8 | from trajdata.data_structures.scene_metadata import Scene 9 | from trajdata.data_structures.state import StateArray 10 | from trajdata.simulation import SimulationScene, sim_metrics, sim_stats, sim_vis 11 | from trajdata.visualization.vis import plot_agent_batch 12 | 13 | 14 | def main(): 15 | dataset = UnifiedDataset( 16 | desired_data=["nusc_mini"], 17 | only_types=[AgentType.VEHICLE], 18 | agent_interaction_distances=defaultdict(lambda: 50.0), 19 | # incl_map=True, 20 | # map_params={ 21 | # "px_per_m": 2, 22 | # "map_size_px": 224, 23 | # "offset_frac_xy": (0.0, 0.0), 24 | # }, 25 | verbose=True, 26 | # desired_dt=0.1, 27 | num_workers=4, 28 | data_dirs={ # Remember to change this to match your filesystem! 29 | "nusc_mini": "~/datasets/nuScenes", 30 | }, 31 | ) 32 | 33 | ade = sim_metrics.ADE() 34 | fde = sim_metrics.FDE() 35 | 36 | sim_env_name = "nusc_mini_sim" 37 | all_sim_scenes: List[Scene] = list() 38 | desired_scene: Scene 39 | for idx, desired_scene in enumerate(dataset.scenes()): 40 | sim_scene: SimulationScene = SimulationScene( 41 | env_name=sim_env_name, 42 | scene_name=f"sim_scene-{idx:04d}", 43 | scene=desired_scene, 44 | dataset=dataset, 45 | init_timestep=0, 46 | freeze_agents=True, 47 | ) 48 | 49 | vel_hist = sim_stats.VelocityHistogram(bins=np.linspace(0, 40, 41)) 50 | lon_acc_hist = sim_stats.LongitudinalAccHistogram(bins=np.linspace(0, 10, 11)) 51 | lat_acc_hist = sim_stats.LateralAccHistogram(bins=np.linspace(0, 10, 11)) 52 | jerk_hist = sim_stats.JerkHistogram( 53 | bins=np.linspace(0, 40, 41), dt=sim_scene.scene.dt 54 | ) 55 | 56 | obs: AgentBatch = sim_scene.reset() 57 | for t in trange(1, sim_scene.scene.length_timesteps): 58 | new_xyzh_dict: Dict[str, StateArray] = dict() 59 | for idx, agent_name in enumerate(obs.agent_name): 60 | curr_yaw = obs.curr_agent_state[idx].heading.item() 61 | curr_pos = obs.curr_agent_state[idx].position.numpy() 62 | world_from_agent = np.array( 63 | [ 64 | [np.cos(curr_yaw), np.sin(curr_yaw)], 65 | [-np.sin(curr_yaw), np.cos(curr_yaw)], 66 | ] 67 | ) 68 | next_state = np.zeros((4,)) 69 | if obs.agent_fut_len[idx] < 1: 70 | next_state[:2] = curr_pos 71 | yaw_ac = 0 72 | else: 73 | next_state[:2] = ( 74 | obs.agent_fut[idx, 0].position.numpy() @ world_from_agent 75 | + curr_pos 76 | ) 77 | yaw_ac = obs.agent_fut[idx, 0].heading.item() 78 | 79 | next_state[-1] = curr_yaw + yaw_ac 80 | new_xyzh_dict[agent_name] = StateArray.from_array(next_state, "x,y,z,h") 81 | 82 | obs = sim_scene.step(new_xyzh_dict) 83 | metrics: Dict[str, Dict[str, float]] = sim_scene.get_metrics([ade, fde]) 84 | print(metrics) 85 | 86 | stats: Dict[ 87 | str, Dict[str, Tuple[np.ndarray, np.ndarray]] 88 | ] = sim_scene.get_stats([vel_hist, lon_acc_hist, 
lat_acc_hist, jerk_hist]) 89 | sim_vis.plot_sim_stats(stats) 90 | 91 | plot_agent_batch(obs, 0, show=False, close=False) 92 | plot_agent_batch(obs, 1, show=False, close=False) 93 | plot_agent_batch(obs, 2, show=False, close=False) 94 | plot_agent_batch(obs, 3, show=True, close=True) 95 | 96 | sim_scene.finalize() 97 | sim_scene.save() 98 | 99 | all_sim_scenes.append(sim_scene.scene) 100 | 101 | dataset.env_cache.save_env_scenes_list(sim_env_name, all_sim_scenes) 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /examples/simple_map_api_example.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | from typing import Dict 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from trajdata import MapAPI, VectorMap 9 | from trajdata.maps.vec_map_elements import MapElementType 10 | from trajdata.utils import map_utils 11 | 12 | 13 | def main(): 14 | cache_path = Path("~/.unified_data_cache").expanduser() 15 | map_api = MapAPI(cache_path) 16 | 17 | ### Loading random scene and initializing VectorMap. 18 | env_name: str = np.random.choice(["nusc_mini", "lyft_sample", "nuplan_mini"]) 19 | random_location_dict: Dict[str, str] = { 20 | "nuplan_mini": np.random.choice( 21 | ["boston", "singapore", "pittsburgh", "las_vegas"] 22 | ), 23 | "nusc_mini": np.random.choice(["boston-seaport", "singapore-onenorth"]), 24 | "lyft_sample": "palo_alto", 25 | } 26 | 27 | start = time.perf_counter() 28 | vec_map: VectorMap = map_api.get_map( 29 | f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True 30 | ) 31 | end = time.perf_counter() 32 | print(f"Map loading took {(end - start)*1000:.2f} ms") 33 | 34 | start = time.perf_counter() 35 | vec_map: VectorMap = map_api.get_map( 36 | f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True 37 | ) 38 | end = time.perf_counter() 39 | print(f"Repeated (cached in memory) map loading took {(end - start)*1000:.2f} ms") 40 | 41 | print(f"Randomly chose {vec_map.env_name}, {vec_map.map_name} map.") 42 | 43 | ### Lane Graph Visualization (with rasterized map in background) 44 | fig, ax = plt.subplots() 45 | 46 | print(f"Rasterizing Map...") 47 | start = time.perf_counter() 48 | map_img, raster_from_world = vec_map.rasterize( 49 | resolution=2, 50 | return_tf_mat=True, 51 | incl_centerlines=False, 52 | area_color=(255, 255, 255), 53 | edge_color=(0, 0, 0), 54 | scene_ts=100, 55 | ) 56 | end = time.perf_counter() 57 | print(f"Map rasterization took {(end - start)*1000:.2f} ms") 58 | 59 | ax.imshow(map_img, alpha=0.5, origin="lower") 60 | 61 | lane_idx = np.random.randint(0, len(vec_map.lanes)) 62 | print(f"Visualizing random lane index {lane_idx}...") 63 | start = time.perf_counter() 64 | vec_map.visualize_lane_graph( 65 | origin_lane=lane_idx, 66 | num_hops=10, 67 | raster_from_world=raster_from_world, 68 | ax=ax, 69 | ) 70 | end = time.perf_counter() 71 | print(f"Lane visualization took {(end - start)*1000:.2f} ms") 72 | 73 | point = vec_map.lanes[lane_idx].center.xyz[0, :] 74 | point_raster = map_utils.transform_points( 75 | point[None, :], transf_mat=raster_from_world 76 | ) 77 | ax.scatter(point_raster[:, 0], point_raster[:, 1]) 78 | 79 | print("Getting nearest road area...") 80 | start = time.perf_counter() 81 | area = vec_map.get_closest_area(point, elem_type=MapElementType.ROAD_AREA) 82 | end = time.perf_counter() 83 | print(f"Getting nearest area took 
{(end-start)*1000:.2f} ms") 84 | 85 | raster_pts = map_utils.transform_points(area.exterior_polygon.xy, raster_from_world) 86 | ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=1.0, color="C0") 87 | 88 | print("Getting road areas within 100m...") 89 | start = time.perf_counter() 90 | areas = vec_map.get_areas_within( 91 | point, elem_type=MapElementType.ROAD_AREA, dist=100.0 92 | ) 93 | end = time.perf_counter() 94 | print(f"Getting areas within took {(end-start)*1000:.2f} ms") 95 | 96 | for area in areas: 97 | raster_pts = map_utils.transform_points( 98 | area.exterior_polygon.xy, raster_from_world 99 | ) 100 | ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=0.2, color="C1") 101 | 102 | ax.axis("equal") 103 | ax.grid(None) 104 | 105 | plt.show() 106 | plt.close("all") 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /examples/simple_sim_example.py: -------------------------------------------------------------------------------- 1 | from typing import Dict # Just for type annotations 2 | 3 | import numpy as np 4 | from tqdm import trange 5 | 6 | from trajdata import AgentBatch, UnifiedDataset 7 | from trajdata.data_structures.scene_metadata import Scene 8 | from trajdata.data_structures.state import StateArray # Just for type annotations 9 | from trajdata.simulation import SimulationScene 10 | 11 | dataset = UnifiedDataset( 12 | desired_data=["nusc_mini"], 13 | data_dirs={ # Remember to change this to match your filesystem! 14 | "nusc_mini": "~/datasets/nuScenes", 15 | }, 16 | ) 17 | 18 | desired_scene: Scene = dataset.get_scene(scene_idx=0) 19 | sim_scene = SimulationScene( 20 | env_name="nusc_mini_sim", 21 | scene_name="sim_scene", 22 | scene=desired_scene, 23 | dataset=dataset, 24 | init_timestep=0, 25 | freeze_agents=True, 26 | ) 27 | 28 | obs: AgentBatch = sim_scene.reset() 29 | for t in trange(1, sim_scene.scene.length_timesteps): 30 | new_xyzh_dict: Dict[str, StateArray] = dict() 31 | 32 | # Everything inside the for loop just sets 33 | # agents' next states to their current ones.
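    # (A hedged alternative: to roll agents forward under a constant-velocity model
    #  instead, the loop body below could compute something like
    #  next_state[:2] = curr_pos + obs.curr_agent_state[idx].velocity.numpy() * sim_scene.scene.dt,
    #  using the same `velocity` state property shown in examples/state_example.py.)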
34 | for idx, agent_name in enumerate(obs.agent_name): 35 | curr_yaw = obs.curr_agent_state[idx].heading.item() 36 | curr_pos = obs.curr_agent_state[idx].position.numpy() 37 | 38 | next_state = np.zeros((4,)) 39 | next_state[:2] = curr_pos 40 | next_state[-1] = curr_yaw 41 | new_xyzh_dict[agent_name] = StateArray.from_array(next_state, "x,y,z,h") 42 | 43 | obs = sim_scene.step(new_xyzh_dict) 44 | -------------------------------------------------------------------------------- /examples/speed_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | 7 | from trajdata import AgentBatch, AgentType, UnifiedDataset 8 | from trajdata.augmentation import NoiseHistories 9 | 10 | 11 | def main(): 12 | noise_hists = NoiseHistories() 13 | 14 | dataset = UnifiedDataset( 15 | desired_data=["nusc_mini-mini_train"], 16 | centric="agent", 17 | desired_dt=0.1, 18 | history_sec=(3.2, 3.2), 19 | future_sec=(4.8, 4.8), 20 | only_predict=[AgentType.VEHICLE], 21 | agent_interaction_distances=defaultdict(lambda: 30.0), 22 | incl_robot_future=True, 23 | incl_raster_map=True, 24 | raster_map_params={ 25 | "px_per_m": 2, 26 | "map_size_px": 224, 27 | "offset_frac_xy": (-0.5, 0.0), 28 | }, 29 | incl_vector_map=True, 30 | augmentations=[noise_hists], 31 | num_workers=0, 32 | verbose=True, 33 | data_dirs={ # Remember to change this to match your filesystem! 34 | "nusc_mini": "~/datasets/nuScenes", 35 | }, 36 | ) 37 | 38 | print(f"# Data Samples: {len(dataset):,}") 39 | 40 | dataloader = DataLoader( 41 | dataset, 42 | batch_size=64, 43 | shuffle=True, 44 | collate_fn=dataset.get_collate_fn(), 45 | num_workers=os.cpu_count() // 2, 46 | ) 47 | 48 | batch: AgentBatch 49 | for batch in tqdm(dataloader): 50 | pass 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /examples/state_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from trajdata import AgentBatch, AgentType, UnifiedDataset 7 | from trajdata.data_structures.state import StateArray, StateTensor 8 | 9 | 10 | def main(): 11 | dataset = UnifiedDataset( 12 | desired_data=["lyft_sample-mini_val"], 13 | centric="agent", 14 | desired_dt=0.1, 15 | history_sec=(3.2, 3.2), 16 | future_sec=(4.8, 4.8), 17 | only_predict=[AgentType.VEHICLE], 18 | state_format="x,y,z,xd,yd,xdd,ydd,h", 19 | agent_interaction_distances=defaultdict(lambda: 30.0), 20 | incl_robot_future=False, 21 | incl_raster_map=True, 22 | raster_map_params={ 23 | "px_per_m": 2, 24 | "map_size_px": 224, 25 | "offset_frac_xy": (-0.5, 0.0), 26 | }, 27 | num_workers=0, 28 | verbose=True, 29 | data_dirs={ # Remember to change this to match your filesystem! 30 | "lyft_sample": "~/datasets/lyft_sample/scenes/sample.zarr", 31 | }, 32 | ) 33 | 34 | print(f"# Data Samples: {len(dataset):,}") 35 | 36 | dataloader = DataLoader( 37 | dataset, 38 | batch_size=4, 39 | shuffle=True, 40 | collate_fn=dataset.get_collate_fn(), 41 | num_workers=4, 42 | ) 43 | 44 | # batchElement has properties that correspond to agent states 45 | ego_state = dataset[0].curr_agent_state_np.copy() 46 | print(ego_state) 47 | 48 | # StateArray types offer easy conversion to whatever format you want your state 49 | # e.g. 
we want x,y position and cos/sin heading: 50 | print(ego_state.as_format("x,y,c,s")) 51 | 52 | # We can also access elements via properties 53 | print(ego_state.position3d) 54 | print(ego_state.velocity) 55 | 56 | # We can set elements of states via properties. E.g., let's reset the heading to 0 57 | ego_state.heading = 0 58 | print(ego_state) 59 | 60 | # We can request elements that aren't directly stored in the state, e.g. cos/sin heading 61 | print(ego_state.heading_vector) 62 | 63 | # However, we can't set properties that aren't directly stored in the state tensor 64 | try: 65 | ego_state.heading_vector = 0.0 66 | except AttributeError as e: 67 | print(e) 68 | 69 | # Finally, StateArrays are just np.ndarrays under the hood, and any normal np operation 70 | # should convert them to a normal array 71 | print(ego_state**2) 72 | 73 | # To convert an np.ndarray into a StateArray, we just need to specify what format it is. 74 | # Note that StateArrays can have an arbitrary number of batch elements 75 | print(StateArray.from_array(np.random.randn(1, 2, 3), "x,y,z")) 76 | 77 | # Analogous to StateArray wrapping np.ndarrays, the StateTensor class gives the same 78 | # functionality to torch.Tensors 79 | batch: AgentBatch = next(iter(dataloader)) 80 | ego_state_t: StateTensor = batch.curr_agent_state 81 | 82 | print(ego_state_t.as_format("x,y,c,s")) 83 | print(ego_state_t.position3d) 84 | print(ego_state_t.velocity) 85 | ego_state_t.heading = 0 86 | print(ego_state_t) 87 | print(ego_state_t.heading_vector) 88 | 89 | # Furthermore, we can use the from_numpy() and numpy() methods to convert to and from 90 | # StateTensors with the same format 91 | print(ego_state_t.numpy()) 92 | print(StateTensor.from_numpy(ego_state)) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /examples/visualization_example.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | 6 | from trajdata import AgentBatch, AgentType, UnifiedDataset 7 | from trajdata.visualization.interactive_animation import ( 8 | InteractiveAnimation, 9 | animate_agent_batch_interactive, 10 | ) 11 | from trajdata.visualization.interactive_vis import plot_agent_batch_interactive 12 | from trajdata.visualization.vis import plot_agent_batch 13 | 14 | 15 | def main(): 16 | dataset = UnifiedDataset( 17 | desired_data=["nusc_mini"], 18 | centric="agent", 19 | desired_dt=0.1, 20 | # history_sec=(3.2, 3.2), 21 | # future_sec=(4.8, 4.8), 22 | only_predict=[AgentType.VEHICLE], 23 | state_format="x,y,z,xd,yd,h", 24 | obs_format="x,y,z,xd,yd,s,c", 25 | # agent_interaction_distances=defaultdict(lambda: 30.0), 26 | incl_robot_future=False, 27 | incl_raster_map=True, 28 | raster_map_params={ 29 | "px_per_m": 2, 30 | "map_size_px": 224, 31 | "offset_frac_xy": (-0.5, 0.0), 32 | }, 33 | num_workers=4, 34 | verbose=True, 35 | data_dirs={ # Remember to change this to match your filesystem!
36 | "nusc_mini": "~/datasets/nuScenes", 37 | "lyft_sample": "~/datasets/lyft/scenes/sample.zarr", 38 | "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", 39 | }, 40 | ) 41 | 42 | print(f"# Data Samples: {len(dataset):,}") 43 | 44 | dataloader = DataLoader( 45 | dataset, 46 | batch_size=4, 47 | shuffle=True, 48 | collate_fn=dataset.get_collate_fn(), 49 | num_workers=0, 50 | ) 51 | 52 | batch: AgentBatch 53 | for batch in tqdm(dataloader): 54 | plot_agent_batch_interactive(batch, batch_idx=0, cache_path=dataset.cache_path) 55 | plot_agent_batch(batch, batch_idx=0) 56 | 57 | animation = InteractiveAnimation( 58 | animate_agent_batch_interactive, 59 | batch=batch, 60 | batch_idx=0, 61 | cache_path=dataset.cache_path, 62 | ) 63 | animation.show() 64 | # break 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /img/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/trajdata/51f0efa9d572fa7da480ef8caf089d0d6987de9f/img/architecture.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | 4 | [build-system] 5 | requires = ["setuptools>=58", "wheel"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | classifiers = [ 10 | "Development Status :: 3 - Alpha", 11 | "Intended Audience :: Developers", 12 | "Programming Language :: Python :: 3.8", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | name = "trajdata" 16 | version = "1.4.0" 17 | authors = [{ name = "Boris Ivanovic", email = "bivanovic@nvidia.com" }] 18 | description = "A unified interface to many trajectory forecasting datasets." 
19 | readme = "README.md" 20 | requires-python = ">=3.8" 21 | dependencies = [ 22 | "numpy>=1.19", 23 | "tqdm>=4.62", 24 | "matplotlib>=3.5", 25 | "dill>=0.3.4", 26 | "pandas>=1.4.1", 27 | "pyarrow>=7.0.0", 28 | "torch>=1.10.2", 29 | "zarr>=2.11.0", 30 | "kornia>=0.6.4", 31 | "seaborn>=0.12", 32 | "bokeh>=3.0.3", 33 | "geopandas>=0.13.2", 34 | "protobuf==3.19.4", 35 | "scipy>=1.9.0", 36 | "opencv-python>=4.5.0", 37 | "shapely>=2.0.0", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | av2 = ["av2==0.2.1"] 42 | dev = ["black", "isort", "pytest", "pytest-xdist", "twine", "build"] 43 | interaction = ["lanelet2==1.2.1"] 44 | lyft = ["l5kit==1.5.0"] 45 | nusc = ["nuscenes-devkit==1.1.9"] 46 | waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"] 47 | vod = ["vod-devkit==1.1.1"] 48 | 49 | [project.urls] 50 | "Homepage" = "https://github.com/nvr-avg/trajdata" 51 | "Bug Tracker" = "https://github.com/nvr-avg/trajdata/issues" 52 | -------------------------------------------------------------------------------- /src/trajdata/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_structures import AgentBatch, AgentType, SceneBatch 2 | from .dataset import UnifiedDataset 3 | from .maps import MapAPI, VectorMap 4 | -------------------------------------------------------------------------------- /src/trajdata/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentation import Augmentation, BatchAugmentation, DatasetAugmentation 2 | from .low_vel_yaw_correction import LowSpeedYawCorrection 3 | from .noise_histories import NoiseHistories 4 | -------------------------------------------------------------------------------- /src/trajdata/augmentation/augmentation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from trajdata.data_structures.batch import AgentBatch, SceneBatch 4 | 5 | 6 | class Augmentation: 7 | def __init__(self) -> None: 8 | raise NotImplementedError() 9 | 10 | 11 | class DatasetAugmentation(Augmentation): 12 | def apply(self, scene_data_df: pd.DataFrame) -> None: 13 | raise NotImplementedError() 14 | 15 | 16 | class BatchAugmentation(Augmentation): 17 | def apply_agent(self, agent_batch: AgentBatch) -> None: 18 | raise NotImplementedError() 19 | 20 | def apply_scene(self, scene_batch: SceneBatch) -> None: 21 | raise NotImplementedError() 22 | -------------------------------------------------------------------------------- /src/trajdata/augmentation/low_vel_yaw_correction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from trajdata.augmentation.augmentation import DatasetAugmentation 5 | 6 | 7 | class LowSpeedYawCorrection(DatasetAugmentation): 8 | def __init__(self, speed_threshold: float = 0.0) -> None: 9 | self.speed_threshold = speed_threshold 10 | 11 | def apply(self, scene_data_df: pd.DataFrame) -> None: 12 | speed = np.linalg.norm(scene_data_df[["vx", "vy"]], axis=1) 13 | 14 | scene_data_df["yaw_diffs"] = scene_data_df["heading"].diff() 15 | # Doing this because the first row is always nan 16 | scene_data_df["yaw_diffs"].iat[0] = 0.0 17 | 18 | agent_ids: pd.Series = scene_data_df.index.get_level_values(0).astype( 19 | "category" 20 | ) 21 | 22 | # The point of the border mask is to catch data like this: 23 | # index agent_id vx vy 24 | # 0 1 7.3 9.1 25 | # 1 2 0.0 0.0 26 | # 
... 27 | # As implemented, we would currently only 28 | # return index 0 (since we chop off the top with the 1: in the slice below), but 29 | # we want to return 1 so that's why the + 1 at the end. 30 | border_mask: np.ndarray = np.concatenate( 31 | [[0], np.nonzero(agent_ids[1:] != agent_ids[:-1])[0] + 1] 32 | ) 33 | 34 | scene_data_df["yaw_diffs"].iloc[border_mask] = 0.0 35 | scene_data_df["yaw_diffs"].iloc[speed < self.speed_threshold] = 0.0 36 | 37 | mask_arr = np.ones((len(scene_data_df),), dtype=bool) # np.bool was removed in NumPy 1.24+ 38 | mask_arr[border_mask] = False 39 | scene_data_df["heading"].iloc[mask_arr] = 0.0 40 | 41 | scene_data_df.loc[:, "yaw_diffs"] += scene_data_df["heading"] 42 | scene_data_df.loc[:, "heading"] = scene_data_df.groupby("agent_id")[ 43 | "yaw_diffs" 44 | ].cumsum() 45 | 46 | scene_data_df.drop(columns="yaw_diffs", inplace=True) 47 | -------------------------------------------------------------------------------- /src/trajdata/augmentation/noise_histories.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from trajdata.augmentation.augmentation import BatchAugmentation 4 | from trajdata.data_structures.batch import AgentBatch, SceneBatch 5 | from trajdata.utils.arr_utils import PadDirection, mask_up_to 6 | 7 | 8 | class NoiseHistories(BatchAugmentation): 9 | def __init__( 10 | self, 11 | mean: float = 0.0, 12 | stddev: float = 0.1, 13 | ) -> None: 14 | self.mean = mean 15 | self.stddev = stddev 16 | 17 | def apply_agent(self, agent_batch: AgentBatch) -> None: 18 | agent_hist_noise = torch.normal( 19 | self.mean, self.stddev, size=agent_batch.agent_hist.shape 20 | ) 21 | neigh_hist_noise = torch.normal( 22 | self.mean, self.stddev, size=agent_batch.neigh_hist.shape 23 | ) 24 | 25 | if agent_batch.history_pad_dir == PadDirection.BEFORE: 26 | agent_hist_noise[..., -1, :] = 0 27 | neigh_hist_noise[..., -1, :] = 0 28 | else: 29 | len_mask = ~mask_up_to( 30 | agent_batch.agent_hist_len, 31 | delta=-1, 32 | max_len=agent_batch.agent_hist.shape[1], 33 | ).unsqueeze(-1) 34 | agent_hist_noise[len_mask.expand(-1, -1, agent_hist_noise.shape[-1])] = 0 35 | 36 | len_mask = ~mask_up_to( 37 | agent_batch.neigh_hist_len, 38 | delta=-1, 39 | max_len=agent_batch.neigh_hist.shape[2], 40 | ).unsqueeze(-1) 41 | neigh_hist_noise[ 42 | len_mask.expand(-1, -1, -1, neigh_hist_noise.shape[-1]) 43 | ] = 0 44 | 45 | agent_batch.agent_hist += agent_hist_noise 46 | agent_batch.neigh_hist += neigh_hist_noise 47 | 48 | def apply_scene(self, scene_batch: SceneBatch) -> None: 49 | scene_batch.agent_hist[..., :-1, :] += torch.normal( 50 | self.mean, self.stddev, size=scene_batch.agent_hist[..., :-1, :].shape 51 | ) 52 | -------------------------------------------------------------------------------- /src/trajdata/caching/__init__.py: -------------------------------------------------------------------------------- 1 | from .env_cache import EnvCache 2 | from .scene_cache import SceneCache 3 | -------------------------------------------------------------------------------- /src/trajdata/caching/env_cache.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, NamedTuple, Union 3 | 4 | import dill 5 | 6 | from trajdata.data_structures.scene_metadata import Scene 7 | 8 | 9 | class EnvCache: 10 | def __init__(self, cache_location: Path) -> None: 11 | self.path = cache_location 12 | 13 | def env_is_cached(self, env_name: str) -> bool: 14 | return (self.path / env_name /
"scenes_list.dill").exists() 15 | 16 | def scene_is_cached(self, env_name: str, scene_name: str, scene_dt: float) -> bool: 17 | return EnvCache.scene_metadata_path( 18 | self.path, env_name, scene_name, scene_dt 19 | ).exists() 20 | 21 | @staticmethod 22 | def scene_metadata_path( 23 | base_path: Path, env_name: str, scene_name: str, scene_dt: float 24 | ) -> Path: 25 | return ( 26 | base_path / env_name / scene_name / f"scene_metadata_dt{scene_dt:.2f}.dill" 27 | ) 28 | 29 | def load_scene(self, env_name: str, scene_name: str, scene_dt: float) -> Scene: 30 | scene_file: Path = EnvCache.scene_metadata_path( 31 | self.path, env_name, scene_name, scene_dt 32 | ) 33 | with open(scene_file, "rb") as f: 34 | scene: Scene = dill.load(f) 35 | 36 | return scene 37 | 38 | def save_scene(self, scene: Scene) -> Path: 39 | scene_file: Path = EnvCache.scene_metadata_path( 40 | self.path, scene.env_name, scene.name, scene.dt 41 | ) 42 | 43 | scene_cache_dir: Path = scene_file.parent 44 | scene_cache_dir.mkdir(parents=True, exist_ok=True) 45 | 46 | with open(scene_file, "wb") as f: 47 | dill.dump(scene, f) 48 | 49 | return scene_file 50 | 51 | @staticmethod 52 | def save_scene_with_path(base_path: Path, scene: Scene) -> Path: 53 | scene_file: Path = EnvCache.scene_metadata_path( 54 | base_path, scene.env_name, scene.name, scene.dt 55 | ) 56 | 57 | scene_cache_dir: Path = scene_file.parent 58 | scene_cache_dir.mkdir(parents=True, exist_ok=True) 59 | 60 | with open(scene_file, "wb") as f: 61 | dill.dump(scene, f) 62 | 63 | return scene_file 64 | 65 | def load_env_scenes_list(self, env_name: str) -> List[NamedTuple]: 66 | env_cache_dir: Path = self.path / env_name 67 | with open(env_cache_dir / "scenes_list.dill", "rb") as f: 68 | scenes_list: List[NamedTuple] = dill.load(f) 69 | 70 | return scenes_list 71 | 72 | def save_env_scenes_list( 73 | self, env_name: str, scenes_list: List[NamedTuple] 74 | ) -> None: 75 | env_cache_dir: Path = self.path / env_name 76 | env_cache_dir.mkdir(parents=True, exist_ok=True) 77 | with open(env_cache_dir / "scenes_list.dill", "wb") as f: 78 | dill.dump(scenes_list, f) 79 | 80 | @staticmethod 81 | def load(scene_info_path: Union[Path, str]) -> Scene: 82 | with open(scene_info_path, "rb") as handle: 83 | scene: Scene = dill.load(handle) 84 | 85 | return scene 86 | -------------------------------------------------------------------------------- /src/trajdata/caching/scene_cache.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Type 4 | 5 | if TYPE_CHECKING: 6 | from trajdata.maps import TrafficLightStatus, VectorMap 7 | 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | 13 | from trajdata.augmentation.augmentation import Augmentation 14 | from trajdata.data_structures.agent import AgentMetadata 15 | from trajdata.data_structures.scene_metadata import Scene 16 | from trajdata.data_structures.state import StateArray 17 | 18 | 19 | class SceneCache: 20 | def __init__( 21 | self, 22 | cache_path: Path, 23 | scene: Scene, 24 | augmentations: Optional[List[Augmentation]] = None, 25 | ) -> None: 26 | """ 27 | Creates and prepares the cache for online data loading. 
28 | """ 29 | self.path = cache_path 30 | self.scene = scene 31 | self.dt = scene.dt 32 | self.augmentations = augmentations 33 | 34 | # Ensuring the scene cache folder exists 35 | self.scene_dir: Path = SceneCache.scene_cache_dir( 36 | self.path, self.scene.env_name, self.scene.name 37 | ) 38 | self.scene_dir.mkdir(parents=True, exist_ok=True) 39 | 40 | self.obs_type: Type[StateArray] = None 41 | 42 | @staticmethod 43 | def scene_cache_dir(cache_path: Path, env_name: str, scene_name: str) -> Path: 44 | """Standardized convention to compute scene cache folder path""" 45 | return cache_path / env_name / scene_name 46 | 47 | def write_cache_to_disk(self) -> None: 48 | """Saves agent data to disk for fast loading later (just like save_agent_data), 49 | but using the class attributes for the sources of data and file paths. 50 | """ 51 | raise NotImplementedError() 52 | 53 | # AGENT STATE DATA 54 | @staticmethod 55 | def save_agent_data( 56 | agent_data: Any, 57 | cache_path: Path, 58 | scene: Scene, 59 | ) -> None: 60 | """Saves agent data to disk for fast loading later.""" 61 | raise NotImplementedError() 62 | 63 | def get_value(self, agent_id: str, scene_ts: int, attribute: str) -> float: 64 | """ 65 | Get a single attribute value for an agent at a timestep. 66 | """ 67 | raise NotImplementedError() 68 | 69 | def get_raw_state(self, agent_id: str, scene_ts: int) -> StateArray: 70 | """ 71 | Get an agent's raw state (without transformations applied) 72 | """ 73 | raise NotImplementedError() 74 | 75 | def get_state(self, agent_id: str, scene_ts: int) -> StateArray: 76 | """ 77 | Get an agent's state at a specific timestep. 78 | """ 79 | raise NotImplementedError() 80 | 81 | def get_states(self, agent_ids: List[str], scene_ts: int) -> StateArray: 82 | """ 83 | Get multiple agents' states at a specific timestep. 84 | """ 85 | raise NotImplementedError() 86 | 87 | def set_obs_frame(self, obs_frame: StateArray) -> None: 88 | """ 89 | Set frame in which to return observations 90 | """ 91 | raise NotImplementedError() 92 | 93 | def reset_obs_frame(self) -> None: 94 | """ 95 | Reset observation frame to be same as world frame 96 | """ 97 | raise NotImplementedError() 98 | 99 | def set_obs_format(self, format_str: str) -> None: 100 | """ 101 | Sets observation format (which elements to include and their order) 102 | """ 103 | raise NotImplementedError() 104 | 105 | def reset_obs_format(self) -> None: 106 | """ 107 | Resets observation format to default (set by subclass) 108 | """ 109 | raise NotImplementedError() 110 | 111 | def interpolate_data(self, desired_dt: float, method: str = "linear") -> None: 112 | """Increase the sampling frequency of the data by interpolation. 113 | 114 | Args: 115 | desired_dt (float): The desired spacing between timesteps. 116 | method (str, optional): The type of interpolation to use, currently only "linear" is implemented. Defaults to "linear". 
117 | """ 118 | raise NotImplementedError() 119 | 120 | def get_agent_history( 121 | self, 122 | agent_info: AgentMetadata, 123 | scene_ts: int, 124 | history_sec: Tuple[Optional[float], Optional[float]], 125 | ) -> Tuple[StateArray, np.ndarray]: 126 | """ 127 | Returns (agent_history_state, agent_extent) 128 | """ 129 | raise NotImplementedError() 130 | 131 | def get_agent_future( 132 | self, 133 | agent_info: AgentMetadata, 134 | scene_ts: int, 135 | future_sec: Tuple[Optional[float], Optional[float]], 136 | ) -> Tuple[StateArray, np.ndarray]: 137 | """ 138 | Returns (agent_future_state, agent_extent) 139 | """ 140 | raise NotImplementedError() 141 | 142 | def get_agents_history( 143 | self, 144 | scene_ts: int, 145 | agents: List[AgentMetadata], 146 | history_sec: Tuple[Optional[float], Optional[float]], 147 | ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: 148 | raise NotImplementedError() 149 | 150 | def get_agents_future( 151 | self, 152 | scene_ts: int, 153 | agents: List[AgentMetadata], 154 | future_sec: Tuple[Optional[float], Optional[float]], 155 | ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: 156 | raise NotImplementedError() 157 | 158 | # TRAFFIC LIGHT INFO 159 | @staticmethod 160 | def save_traffic_light_data( 161 | traffic_light_status_data: Any, cache_path: Path, scene: Scene 162 | ) -> None: 163 | """Saves traffic light status to disk for easy access later""" 164 | raise NotImplementedError() 165 | 166 | def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bool: 167 | raise NotImplementedError() 168 | 169 | def get_traffic_light_status_dict( 170 | self, desired_dt: Optional[float] = None 171 | ) -> Dict[Tuple[str, int], TrafficLightStatus]: 172 | """Returns lookup table for traffic light status in the current scene 173 | lane_id, scene_ts -> TrafficLightStatus""" 174 | raise NotImplementedError() 175 | 176 | # MAPS 177 | @staticmethod 178 | def are_maps_cached(cache_path: Path, env_name: str) -> bool: 179 | raise NotImplementedError() 180 | 181 | @staticmethod 182 | def is_map_cached( 183 | cache_path: Path, env_name: str, map_name: str, resolution: float 184 | ) -> bool: 185 | raise NotImplementedError() 186 | 187 | @staticmethod 188 | def finalize_and_cache_map( 189 | cache_path: Path, 190 | vector_map: VectorMap, 191 | map_params: Dict[str, Any], 192 | ) -> None: 193 | raise NotImplementedError() 194 | 195 | def load_map_patch( 196 | self, 197 | world_x: float, 198 | world_y: float, 199 | desired_patch_size: int, 200 | resolution: float, 201 | offset_xy: Tuple[float, float], 202 | agent_heading: float, 203 | return_rgb: bool, 204 | rot_pad_factor: float = 1.0, 205 | no_map_val: float = 0.0, 206 | ) -> Tuple[np.ndarray, np.ndarray, bool]: 207 | raise NotImplementedError() 208 | -------------------------------------------------------------------------------- /src/trajdata/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent, AgentMetadata, AgentType, FixedExtent, VariableExtent 2 | from .batch import AgentBatch, SceneBatch 3 | from .batch_element import AgentBatchElement, SceneBatchElement 4 | from .collation import agent_collate_fn, scene_collate_fn 5 | from .data_index import AgentDataIndex, DataIndex, SceneDataIndex 6 | from .environment import EnvMetadata 7 | from .scene import SceneTime, SceneTimeAgent 8 | from .scene_metadata import Scene, SceneMetadata 9 | from .scene_tag import SceneTag 10 | from .state import NP_STATE_TYPES, 
TORCH_STATE_TYPES, StateArray, StateTensor
11 |
--------------------------------------------------------------------------------
/src/trajdata/data_structures/agent.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import IntEnum
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 |
8 | class AgentType(IntEnum):
9 |     UNKNOWN = 0
10 |     VEHICLE = 1
11 |     PEDESTRIAN = 2
12 |     BICYCLE = 3
13 |     MOTORCYCLE = 4
14 |
15 |
16 | class Extent:
17 |     def get_extents(self, start_ts: int, end_ts: int) -> np.ndarray:
18 |         """Get the agent's extents within the specified scene timesteps.
19 |
20 |         Args:
21 |             start_ts (int): The first scene timestep to get extents for (inclusive)
22 |             end_ts (int): The last scene timestep to get extents for (inclusive)
23 |
24 |         Returns:
25 |             np.ndarray: The extents as a (T, 3)-shaped ndarray (length, width, height)
26 |         """
27 |         raise NotImplementedError()
28 |
29 |
30 | @dataclass
31 | class FixedExtent(Extent):
32 |     length: float
33 |     width: float
34 |     height: float
35 |
36 |     def get_extents(self, start_ts: int, end_ts: int) -> np.ndarray:
37 |         return np.repeat(
38 |             np.array([[self.length, self.width, self.height]]),
39 |             end_ts - start_ts + 1,
40 |             axis=0,
41 |         )
42 |
43 |
44 | class VariableExtent(Extent):
45 |     pass
46 |
47 |
48 | class AgentMetadata:
49 |     """Holds agent metadata, e.g., name, type, but without the memory footprint of all the actual underlying scene data."""
50 |
51 |     def __init__(
52 |         self,
53 |         name: str,
54 |         agent_type: AgentType,
55 |         first_timestep: int,
56 |         last_timestep: int,
57 |         extent: Extent,
58 |     ) -> None:
59 |         self.name = name
60 |         self.type = agent_type
61 |         self.first_timestep = first_timestep
62 |         self.last_timestep = last_timestep
63 |         self.extent = extent
64 |
65 |     def __repr__(self) -> str:
66 |         return "/".join([self.type.name, self.name])
67 |
68 |
69 | class Agent:
70 |     """Holds the data for a particular agent."""
71 |
72 |     def __init__(
73 |         self,
74 |         metadata: AgentMetadata,
75 |         data: pd.DataFrame,
76 |     ) -> None:
77 |         self.name = metadata.name
78 |         self.type = metadata.type
79 |         self.metadata = metadata
80 |         self.data = data
81 |
--------------------------------------------------------------------------------
/src/trajdata/data_structures/data_index.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 |
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 |
7 | class DataIndex:
8 |     """The data index is effectively a big list of tuples taking the form:
9 |
10 |     [(scene_path, total_index_len, valid_scene_ts)] for scene-centric data, or
11 |     [(scene_path, total_index_len, [(agent_name, valid_agent_ts)])] for agent-centric data
12 |     """
13 |
14 |     def __init__(
15 |         self,
16 |         data_index: Union[
17 |             List[Tuple[str, int, np.ndarray]],
18 |             List[Tuple[str, int, List[Tuple[str, np.ndarray]]]],
19 |         ],
20 |         verbose: bool = False,
21 |     ) -> None:
22 |         scene_paths, full_index_len, _ = zip(*data_index)
23 |
24 |         self._cumulative_lengths: np.ndarray = np.concatenate(
25 |             ([0], np.cumsum(full_index_len))
26 |         )
27 |         self._len: int = self._cumulative_lengths[-1].item()
28 |
29 |         self._scene_paths: np.ndarray = np.array(scene_paths).astype(np.string_)
30 |
31 |     def __len__(self) -> int:
32 |         return self._len
33 |
34 |     def __getitem__(self, index: int) -> Tuple[str, int, int]:
35 |         scene_idx: int = (
36 |             np.searchsorted(self._cumulative_lengths, index, side="right").item() - 1
37 | ) 38 | 39 | scene_path: str = str(self._scene_paths[scene_idx], encoding="utf-8") 40 | scene_elem_index: int = index - self._cumulative_lengths[scene_idx].item() 41 | return (scene_path, scene_idx, scene_elem_index) 42 | 43 | 44 | class AgentDataIndex(DataIndex): 45 | def __init__( 46 | self, 47 | data_index: List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], 48 | verbose: bool = False, 49 | ) -> None: 50 | super().__init__(data_index) 51 | 52 | agent_timesteps: List[List[Tuple[str, np.ndarray]]] = [ 53 | agent_ts_index for _, _, agent_ts_index in data_index 54 | ] 55 | 56 | self._agent_ids: List[np.ndarray] = list() 57 | self._agent_times: List[np.ndarray] = list() 58 | self._cumulative_scene_lengths: List[np.ndarray] = list() 59 | for scene_data_index in tqdm( 60 | agent_timesteps, desc="Structuring Agent Data Index", disable=not verbose 61 | ): 62 | agent_ids, agent_times = zip(*scene_data_index) 63 | 64 | self._agent_ids.append(np.array(agent_ids).astype(np.string_)) 65 | 66 | agent_ts: np.ndarray = np.stack(agent_times) 67 | self._agent_times.append(agent_ts) 68 | self._cumulative_scene_lengths.append( 69 | np.concatenate(([0], np.cumsum(agent_ts[:, 1] - agent_ts[:, 0] + 1))) 70 | ) 71 | 72 | def __getitem__(self, index: int) -> Tuple[str, str, int]: 73 | scene_path, scene_idx, scene_elem_index = super().__getitem__(index) 74 | 75 | agent_idx: int = ( 76 | np.searchsorted( 77 | self._cumulative_scene_lengths[scene_idx], 78 | scene_elem_index, 79 | side="right", 80 | ).item() 81 | - 1 82 | ) 83 | 84 | agent_id: str = str(self._agent_ids[scene_idx][agent_idx], encoding="utf-8") 85 | 86 | agent_timestep: int = ( 87 | scene_elem_index 88 | - self._cumulative_scene_lengths[scene_idx][agent_idx].item() 89 | + self._agent_times[scene_idx][agent_idx, 0] 90 | ).item() 91 | 92 | assert ( 93 | self._agent_times[scene_idx][agent_idx, 0] 94 | <= agent_timestep 95 | <= self._agent_times[scene_idx][agent_idx, 1] 96 | ) 97 | 98 | return scene_path, agent_id, agent_timestep 99 | 100 | 101 | class SceneDataIndex(DataIndex): 102 | def __init__( 103 | self, data_index: List[Tuple[str, int, np.ndarray]], verbose: bool = False 104 | ) -> None: 105 | super().__init__(data_index) 106 | 107 | self.scene_ts: List[np.ndarray] = [ 108 | valid_ts 109 | for _, _, valid_ts in tqdm( 110 | data_index, desc="Structuring Scene Data Index", disable=not verbose 111 | ) 112 | ] 113 | 114 | def __getitem__(self, index: int) -> Tuple[str, int]: 115 | scene_path, scene_idx, scene_elem_index = super().__getitem__(index) 116 | 117 | scene_ts: int = self.scene_ts[scene_idx][scene_elem_index].item() 118 | 119 | return scene_path, scene_ts 120 | -------------------------------------------------------------------------------- /src/trajdata/data_structures/environment.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | from trajdata.data_structures.scene_tag import SceneTag 6 | 7 | 8 | class EnvMetadata: 9 | def __init__( 10 | self, 11 | name: str, 12 | data_dir: str, 13 | dt: float, 14 | parts: List[Tuple[str]], 15 | scene_split_map: Dict[str, str], 16 | map_locations: Optional[Tuple[str]] = None, 17 | ) -> None: 18 | self.name = name 19 | self.data_dir = Path(data_dir).expanduser().resolve() 20 | self.dt = dt 21 | self.map_locations = map_locations 22 | self.parts = parts 23 | self.scene_tags: List[SceneTag] = [ 24 | SceneTag(tag_tuple) 25 | # Cartesian product of the given list of 
tuples 26 | for tag_tuple in itertools.product(*([(name,)] + parts)) 27 | ] 28 | self.scene_split_map = scene_split_map 29 | -------------------------------------------------------------------------------- /src/trajdata/data_structures/scene.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Set 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from trajdata import filtering 7 | from trajdata.caching import SceneCache 8 | from trajdata.data_structures.agent import Agent, AgentMetadata, AgentType 9 | from trajdata.data_structures.scene_metadata import Scene 10 | from trajdata.data_structures.state import StateArray 11 | 12 | 13 | class SceneTime: 14 | """Holds the data for a particular scene at a particular timestep.""" 15 | 16 | def __init__( 17 | self, 18 | scene: Scene, 19 | scene_ts: int, 20 | agents: List[AgentMetadata], 21 | cache: SceneCache, 22 | ) -> None: 23 | self.scene = scene 24 | self.ts = scene_ts 25 | self.agents = agents 26 | self.cache = cache 27 | 28 | @classmethod 29 | def from_cache( 30 | cls, 31 | scene: Scene, 32 | scene_ts: int, 33 | cache: SceneCache, 34 | only_types: Optional[Set[AgentType]] = None, 35 | no_types: Optional[Set[AgentType]] = None, 36 | ): 37 | agents_present: List[AgentMetadata] = scene.agent_presence[scene_ts] 38 | filtered_agents: List[AgentMetadata] = filtering.agent_types( 39 | agents_present, no_types, only_types 40 | ) 41 | 42 | return cls(scene, scene_ts, filtered_agents, cache) 43 | 44 | def get_agent_distances_to(self, agent: Agent) -> np.ndarray: 45 | agent_pos: StateArray = self.cache.get_state(agent.name, self.ts).position 46 | nb_pos: np.ndarray = np.stack( 47 | [self.cache.get_state(nb.name, self.ts).position for nb in self.agents] 48 | ) 49 | 50 | return np.linalg.norm(nb_pos - agent_pos, axis=1) 51 | 52 | 53 | class SceneTimeAgent: 54 | """Holds the data for a particular agent in a scene at a particular timestep.""" 55 | 56 | def __init__( 57 | self, 58 | scene: Scene, 59 | scene_ts: int, 60 | agents: List[AgentMetadata], 61 | agent: AgentMetadata, 62 | cache: SceneCache, 63 | robot: Optional[AgentMetadata] = None, 64 | ) -> None: 65 | self.scene = scene 66 | self.ts = scene_ts 67 | self.agents = agents 68 | self.agent = agent 69 | self.cache = cache 70 | self.robot = robot 71 | 72 | @classmethod 73 | def from_cache( 74 | cls, 75 | scene: Scene, 76 | scene_ts: int, 77 | agent_id: str, 78 | cache: SceneCache, 79 | only_types: Optional[Set[AgentType]] = None, 80 | no_types: Optional[Set[AgentType]] = None, 81 | incl_robot_future: bool = False, 82 | ): 83 | agents_present: List[AgentMetadata] = scene.agent_presence[scene_ts] 84 | filtered_agents: List[AgentMetadata] = filtering.agent_types( 85 | agents_present, no_types, only_types 86 | ) 87 | 88 | agent_metadata = next((a for a in filtered_agents if a.name == agent_id), None) 89 | 90 | if incl_robot_future: 91 | ego_metadata = next((a for a in filtered_agents if a.name == "ego"), None) 92 | 93 | return cls( 94 | scene, 95 | scene_ts, 96 | agents=filtered_agents, 97 | agent=agent_metadata, 98 | cache=cache, 99 | robot=ego_metadata, 100 | ) 101 | else: 102 | return cls( 103 | scene, 104 | scene_ts, 105 | agents=filtered_agents, 106 | agent=agent_metadata, 107 | cache=cache, 108 | ) 109 | 110 | # @profile 111 | def get_agent_distances_to(self, agent_info: AgentMetadata) -> np.ndarray: 112 | agent_pos: StateArray = self.cache.get_state(agent_info.name, self.ts).position 113 | 114 | curr_poses: StateArray = 
self.cache.get_states( 115 | [a.name for a in self.agents], self.ts 116 | ).position 117 | return np.linalg.norm(curr_poses - agent_pos, axis=1) 118 | -------------------------------------------------------------------------------- /src/trajdata/data_structures/scene_metadata.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Any, List, Optional 3 | 4 | from trajdata.data_structures.agent import AgentMetadata 5 | from trajdata.data_structures.environment import EnvMetadata 6 | 7 | # Holds scene metadata (e.g., name, dt), but without the memory 8 | # footprint of all the actual underlying scene data. 9 | SceneMetadata = namedtuple("SceneMetadata", ["env_name", "name", "dt", "raw_data_idx"]) 10 | 11 | 12 | class Scene: 13 | """Holds the data for a particular scene.""" 14 | 15 | def __init__( 16 | self, 17 | env_metadata: EnvMetadata, 18 | name: str, 19 | location: str, 20 | data_split: str, 21 | length_timesteps: int, 22 | raw_data_idx: int, 23 | data_access_info: Any, 24 | description: Optional[str] = None, 25 | agents: Optional[List[AgentMetadata]] = None, 26 | agent_presence: Optional[List[List[AgentMetadata]]] = None, 27 | ) -> None: 28 | self.env_metadata = env_metadata 29 | self.env_name = env_metadata.name 30 | self.name = name 31 | self.location = location 32 | self.data_split = data_split 33 | self.dt = env_metadata.dt 34 | self.length_timesteps = length_timesteps 35 | self.raw_data_idx = raw_data_idx 36 | self.data_access_info = data_access_info 37 | self.description = description 38 | self.agents = agents 39 | self.agent_presence = agent_presence 40 | 41 | def length_seconds(self) -> float: 42 | return self.length_timesteps * self.dt 43 | 44 | def __repr__(self) -> str: 45 | return "/".join([self.env_name, self.name]) 46 | 47 | def update_agent_info( 48 | self, 49 | new_agents: List[AgentMetadata], 50 | new_agent_presence: List[List[AgentMetadata]], 51 | ) -> None: 52 | self.agents = new_agents 53 | self.agent_presence = new_agent_presence 54 | 55 | def to_metadata(self) -> SceneMetadata: 56 | return SceneMetadata(self.env_name, self.name, self.dt, self.raw_data_idx) 57 | -------------------------------------------------------------------------------- /src/trajdata/data_structures/scene_tag.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Set, Tuple 3 | 4 | 5 | class SceneTag: 6 | def __init__(self, tag_tuple: Tuple[str, ...]) -> None: 7 | self._tag_tuple: Set[str] = set(tag_tuple) 8 | 9 | def contains(self, query: Set[str]) -> bool: 10 | return query.issubset(self._tag_tuple) 11 | 12 | def matches_any(self, regex: re.Pattern) -> bool: 13 | return any(regex.search(x) is not None for x in self._tag_tuple) 14 | 15 | def __contains__(self, item) -> bool: 16 | return item in self._tag_tuple 17 | 18 | def __repr__(self) -> str: 19 | return "-".join(self._tag_tuple) 20 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/__init__.py: -------------------------------------------------------------------------------- 1 | from .raw_dataset import RawDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/argoverse2/__init__.py: -------------------------------------------------------------------------------- 1 | from .av2_dataset import Av2Dataset 2 | 
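# A minimal usage sketch of the SceneTag class from scene_tag.py above
# (illustrative values, not a repository file). Note that because the tag
# tuple is stored as a set, the "-".join(...) repr order is unspecified.
import re

from trajdata.data_structures.scene_tag import SceneTag

tag = SceneTag(("nusc_mini", "mini_train", "boston"))
assert tag.contains({"nusc_mini", "boston"})  # set containment over tags
assert tag.matches_any(re.compile(r"mini_"))  # regex match against any tag
assert "boston" in tag                        # single-tag membership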
--------------------------------------------------------------------------------
/src/trajdata/dataset_specific/argoverse2/av2_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Dict, List, Optional, Tuple, Type
3 |
4 | import pandas as pd
5 | import tqdm
6 | from av2.datasets.motion_forecasting.constants import (
7 |     AV2_SCENARIO_OBS_TIMESTEPS,
8 |     AV2_SCENARIO_STEP_HZ,
9 |     AV2_SCENARIO_TOTAL_TIMESTEPS,
10 | )
11 |
12 | from trajdata.caching.env_cache import EnvCache
13 | from trajdata.caching.scene_cache import SceneCache
14 | from trajdata.data_structures import AgentMetadata, EnvMetadata, Scene, SceneMetadata
15 | from trajdata.data_structures.scene_tag import SceneTag
16 | from trajdata.dataset_specific.argoverse2.av2_utils import (
17 |     AV2_SPLITS,
18 |     Av2Object,
19 |     Av2ScenarioIds,
20 |     av2_map_to_vector_map,
21 |     get_track_metadata,
22 |     scenario_name_to_split,
23 | )
24 | from trajdata.dataset_specific.raw_dataset import RawDataset
25 | from trajdata.dataset_specific.scene_records import Argoverse2Record
26 | from trajdata.utils import arr_utils
27 |
28 | AV2_MOTION_FORECASTING = "av2_motion_forecasting"
29 | AV2_DT = 1 / AV2_SCENARIO_STEP_HZ
30 |
31 |
32 | class Av2Dataset(RawDataset):
33 |
34 |     def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
35 |         if env_name != AV2_MOTION_FORECASTING:
36 |             raise ValueError(f"Unknown Argoverse 2 env name: {env_name}")
37 |
38 |         scenario_ids = Av2ScenarioIds.create(Path(data_dir))
39 |
40 |         return EnvMetadata(
41 |             name=env_name,
42 |             data_dir=data_dir,
43 |             dt=AV2_DT,
44 |             parts=[AV2_SPLITS],
45 |             scene_split_map=scenario_ids.scene_split_map,
46 |             map_locations=None,
47 |         )
48 |
49 |     def load_dataset_obj(self, verbose: bool = False) -> None:
50 |         if verbose:
51 |             print(f"Loading {self.name} dataset...", flush=True)
52 |         self.dataset_obj = Av2Object(self.metadata.data_dir)
53 |
54 |     def _get_matching_scenes_from_obj(
55 |         self,
56 |         scene_tag: SceneTag,
57 |         scene_desc_contains: Optional[List[str]],
58 |         env_cache: EnvCache,
59 |     ) -> List[SceneMetadata]:
60 |         """Compute SceneMetadata for all samples from self.dataset_obj.
61 |
62 |         Also saves records to env_cache for later reuse.
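
        Note: the records saved here are lightweight Argoverse2Record tuples
        (scenario name + index), so listing scenes does not load the
        underlying scenario data.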
63 |         """
64 |         if scene_desc_contains:
65 |             raise ValueError("Argoverse dataset does not support scene descriptions.")
66 |
67 |         record_list = []
68 |         metadata_list = []
69 |
70 |         for idx, scenario_name in enumerate(self.dataset_obj.scenario_names):
71 |             record_list.append(Argoverse2Record(scenario_name, idx))
72 |             metadata_list.append(
73 |                 SceneMetadata(
74 |                     env_name=self.metadata.name,
75 |                     name=scenario_name,
76 |                     dt=AV2_DT,
77 |                     raw_data_idx=idx,
78 |                 )
79 |             )
80 |
81 |         self.cache_all_scenes_list(env_cache, record_list)
82 |         return metadata_list
83 |
84 |     def _get_matching_scenes_from_cache(
85 |         self,
86 |         scene_tag: SceneTag,
87 |         scene_desc_contains: Optional[List[str]],
88 |         env_cache: EnvCache,
89 |     ) -> List[Scene]:
90 |         """Computes Scene data for all samples by reading data from env_cache."""
91 |         if scene_desc_contains:
92 |             raise ValueError("Argoverse dataset does not support scene descriptions.")
93 |
94 |         record_list: List[Argoverse2Record] = env_cache.load_env_scenes_list(self.name)
95 |         return [
96 |             self._create_scene(record.name, record.data_idx) for record in record_list
97 |         ]
98 |
99 |     def get_scene(self, scene_info: SceneMetadata) -> Scene:
100 |         return self._create_scene(scene_info.name, scene_info.raw_data_idx)
101 |
102 |     def _create_scene(self, scenario_name: str, data_idx: int) -> Scene:
103 |         data_split = scenario_name_to_split(scenario_name)
104 |         return Scene(
105 |             env_metadata=self.metadata,
106 |             name=scenario_name,
107 |             location=scenario_name,
108 |             data_split=data_split,
109 |             length_timesteps=(
110 |                 AV2_SCENARIO_OBS_TIMESTEPS
111 |                 if data_split == "test"
112 |                 else AV2_SCENARIO_TOTAL_TIMESTEPS
113 |             ),
114 |             raw_data_idx=data_idx,
115 |             data_access_info=None,
116 |         )
117 |
118 |     def get_agent_info(
119 |         self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
120 |     ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
121 |         """
122 |         Get frame-level information from source dataset, caching it
123 |         to cache_path.
124 |
125 |         Always called after cache_maps, can load map if needed
126 |         to associate map information to positions.
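
        Note: the returned agent_presence is indexed by scene timestep, i.e.
        agent_presence[t] lists the metadata of every agent observed at
        timestep t (see the loop over track.object_states below).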
127 | """ 128 | scenario = self.dataset_obj.load_scenario(scene.name) 129 | 130 | agent_list: List[AgentMetadata] = [] 131 | agent_presence: List[List[AgentMetadata]] = [[] for _ in scenario.timestamps_ns] 132 | 133 | df_records = [] 134 | 135 | for track in scenario.tracks: 136 | track_metadata = get_track_metadata(track) 137 | if track_metadata is None: 138 | continue 139 | 140 | agent_list.append(track_metadata) 141 | 142 | for object_state in track.object_states: 143 | agent_presence[int(object_state.timestep)].append(track_metadata) 144 | 145 | df_records.append( 146 | { 147 | "agent_id": track_metadata.name, 148 | "scene_ts": object_state.timestep, 149 | "x": object_state.position[0], 150 | "y": object_state.position[1], 151 | "z": 0.0, 152 | "vx": object_state.velocity[0], 153 | "vy": object_state.velocity[1], 154 | "heading": object_state.heading, 155 | } 156 | ) 157 | 158 | df = pd.DataFrame.from_records(df_records) 159 | df.set_index(["agent_id", "scene_ts"], inplace=True) 160 | df.sort_index(inplace=True) 161 | 162 | df[["ax", "ay"]] = ( 163 | arr_utils.agent_aware_diff( 164 | df[["vx", "vy"]].to_numpy(), df.index.get_level_values(0) 165 | ) 166 | / AV2_DT 167 | ) 168 | cache_class.save_agent_data(df, cache_path, scene) 169 | 170 | return agent_list, agent_presence 171 | 172 | def cache_maps( 173 | self, 174 | cache_path: Path, 175 | map_cache_class: Type[SceneCache], 176 | map_params: Dict[str, Any], 177 | ) -> None: 178 | """ 179 | Get static, scene-level info from the source dataset, caching it 180 | to cache_path. (Primarily this is info needed to construct VectorMap) 181 | 182 | Resolution is in pixels per meter. 183 | """ 184 | for scenario_name in tqdm.tqdm( 185 | self.dataset_obj.scenario_names, 186 | desc=f"{self.name} cache maps", 187 | dynamic_ncols=True, 188 | ): 189 | av2_map = self.dataset_obj.load_map(scenario_name) 190 | vector_map = av2_map_to_vector_map(f"{self.name}:{scenario_name}", av2_map) 191 | map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) 192 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/eth_ucy_peds/__init__.py: -------------------------------------------------------------------------------- 1 | from .eupeds_dataset import EUPedsDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/interaction/__init__.py: -------------------------------------------------------------------------------- 1 | from .interaction_dataset import InteractionDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/lyft/__init__.py: -------------------------------------------------------------------------------- 1 | from .lyft_dataset import LyftDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/lyft/lyft_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Final, List 2 | 3 | import l5kit.data.proto.road_network_pb2 as l5_pb2 4 | import numpy as np 5 | import pandas as pd 6 | from l5kit.data import ChunkedDataset 7 | from l5kit.data.map_api import InterpolationMethod, MapAPI 8 | from l5kit.geometry import rotation33_as_yaw 9 | from tqdm import tqdm 10 | 11 | from trajdata.data_structures import ( 12 | Agent, 13 | AgentMetadata, 14 | AgentType, 15 | FixedExtent, 16 | Scene, 17 | VariableExtent, 18 | ) 19 | from 
trajdata.maps.vec_map import VectorMap 20 | from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane 21 | from trajdata.utils import map_utils 22 | 23 | LYFT_DT: Final[float] = 0.1 24 | 25 | 26 | def agg_ego_data(lyft_obj: ChunkedDataset, scene: Scene) -> Agent: 27 | scene_frame_start = scene.data_access_info[0] 28 | scene_frame_end = scene.data_access_info[1] 29 | 30 | ego_translations = lyft_obj.frames[scene_frame_start:scene_frame_end][ 31 | "ego_translation" 32 | ][:, :3] 33 | 34 | # Doing this prepending so that the first velocity isn't zero (rather it's just the first actual velocity duplicated) 35 | prepend_pos = ego_translations[0, :2] - ( 36 | ego_translations[1, :2] - ego_translations[0, :2] 37 | ) 38 | ego_velocities = ( 39 | np.diff( 40 | ego_translations[:, :2], axis=0, prepend=np.expand_dims(prepend_pos, axis=0) 41 | ) 42 | / LYFT_DT 43 | ) 44 | 45 | # Doing this prepending so that the first acceleration isn't zero (rather it's just the first actual acceleration duplicated) 46 | prepend_vel = ego_velocities[0] - (ego_velocities[1] - ego_velocities[0]) 47 | ego_accelerations = ( 48 | np.diff(ego_velocities, axis=0, prepend=np.expand_dims(prepend_vel, axis=0)) 49 | / LYFT_DT 50 | ) 51 | 52 | ego_rotations = lyft_obj.frames[scene_frame_start:scene_frame_end]["ego_rotation"] 53 | ego_yaws = np.array( 54 | [rotation33_as_yaw(ego_rotations[i]) for i in range(scene.length_timesteps)] 55 | ) 56 | 57 | ego_extents = FixedExtent(length=4.869, width=1.852, height=1.476).get_extents( 58 | scene_frame_start, scene_frame_end - 1 59 | ) 60 | extent_cols: List[str] = ["length", "width", "height"] 61 | 62 | ego_data_np = np.concatenate( 63 | [ 64 | ego_translations, 65 | ego_velocities, 66 | ego_accelerations, 67 | np.expand_dims(ego_yaws, axis=1), 68 | ego_extents, 69 | ], 70 | axis=1, 71 | ) 72 | ego_data_df = pd.DataFrame( 73 | ego_data_np, 74 | columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"] + extent_cols, 75 | index=pd.MultiIndex.from_tuples( 76 | [("ego", idx) for idx in range(ego_data_np.shape[0])], 77 | names=["agent_id", "scene_ts"], 78 | ), 79 | ) 80 | 81 | ego_metadata = AgentMetadata( 82 | name="ego", 83 | agent_type=AgentType.VEHICLE, 84 | first_timestep=0, 85 | last_timestep=ego_data_np.shape[0] - 1, 86 | extent=VariableExtent(), 87 | ) 88 | return Agent( 89 | metadata=ego_metadata, 90 | data=ego_data_df, 91 | ) 92 | 93 | 94 | def lyft_type_to_unified_type(lyft_type: int) -> AgentType: 95 | # TODO(bivanovic): Currently not handling TRAM or ANIMAL. 96 | if lyft_type in [0, 1, 2, 16]: 97 | return AgentType.UNKNOWN 98 | elif lyft_type in [3, 4, 6, 7, 8, 9]: 99 | return AgentType.VEHICLE 100 | elif lyft_type in [10, 12]: 101 | return AgentType.BICYCLE 102 | elif lyft_type in [11, 13]: 103 | return AgentType.MOTORCYCLE 104 | elif lyft_type == 14: 105 | return AgentType.PEDESTRIAN 106 | 107 | 108 | def populate_vector_map(vector_map: VectorMap, mapAPI: MapAPI) -> None: 109 | maximum_bound: np.ndarray = np.full((3,), np.nan) 110 | minimum_bound: np.ndarray = np.full((3,), np.nan) 111 | for l5_element in tqdm(mapAPI.elements, desc="Creating Vectorized Map"): 112 | if mapAPI.is_lane(l5_element): 113 | l5_element_id: str = mapAPI.id_as_str(l5_element.id) 114 | l5_lane: l5_pb2.Lane = l5_element.element.lane 115 | 116 | lane_dict = mapAPI.get_lane_coords(l5_element_id) 117 | left_pts = lane_dict["xyz_left"] 118 | right_pts = lane_dict["xyz_right"] 119 | 120 | # Ensuring the left and right bounds have the same numbers of points. 
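            # (l5kit's InterpolationMethod.INTER_ENSURE_LEN resamples the shorter
            # boundary up to the target point count, so the midlane below can be a
            # simple pointwise average of left_pts and right_pts.)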
121 | if len(left_pts) < len(right_pts): 122 | left_pts = mapAPI.interpolate( 123 | left_pts, len(right_pts), InterpolationMethod.INTER_ENSURE_LEN 124 | ) 125 | elif len(right_pts) < len(left_pts): 126 | right_pts = mapAPI.interpolate( 127 | right_pts, len(left_pts), InterpolationMethod.INTER_ENSURE_LEN 128 | ) 129 | 130 | midlane_pts: np.ndarray = (left_pts + right_pts) / 2 131 | 132 | # Computing the maximum and minimum map coordinates. 133 | maximum_bound = np.fmax(maximum_bound, left_pts.max(axis=0)) 134 | minimum_bound = np.fmin(minimum_bound, left_pts.min(axis=0)) 135 | 136 | maximum_bound = np.fmax(maximum_bound, right_pts.max(axis=0)) 137 | minimum_bound = np.fmin(minimum_bound, right_pts.min(axis=0)) 138 | 139 | maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) 140 | minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) 141 | 142 | # Adding the element to the map. 143 | new_lane = RoadLane( 144 | id=l5_element_id, 145 | center=Polyline(midlane_pts), 146 | left_edge=Polyline(left_pts), 147 | right_edge=Polyline(right_pts), 148 | ) 149 | 150 | new_lane.next_lanes.update( 151 | [mapAPI.id_as_str(gid) for gid in l5_lane.lanes_ahead] 152 | ) 153 | 154 | left_lane_change_id: str = mapAPI.id_as_str( 155 | l5_lane.adjacent_lane_change_left 156 | ) 157 | if left_lane_change_id: 158 | new_lane.adj_lanes_left.add(left_lane_change_id) 159 | 160 | right_lane_change_id: str = mapAPI.id_as_str( 161 | l5_lane.adjacent_lane_change_right 162 | ) 163 | if right_lane_change_id: 164 | new_lane.adj_lanes_right.add(right_lane_change_id) 165 | 166 | vector_map.add_map_element(new_lane) 167 | 168 | if mapAPI.is_crosswalk(l5_element): 169 | l5_element_id: str = mapAPI.id_as_str(l5_element.id) 170 | crosswalk_pts: np.ndarray = mapAPI.get_crosswalk_coords(l5_element_id)[ 171 | "xyz" 172 | ] 173 | 174 | # Computing the maximum and minimum map coordinates. 175 | maximum_bound = np.fmax(maximum_bound, crosswalk_pts.max(axis=0)) 176 | minimum_bound = np.fmin(minimum_bound, crosswalk_pts.min(axis=0)) 177 | 178 | vector_map.add_map_element( 179 | PedCrosswalk(id=l5_element_id, polygon=Polyline(crosswalk_pts)) 180 | ) 181 | 182 | # Setting the map bounds. 
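    # (Note: np.fmax/np.fmin ignore NaNs, so the np.full((3,), np.nan)
    # initialization above is simply replaced by the first real coordinates.)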
183 | # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] 184 | vector_map.extent = np.concatenate((minimum_bound, maximum_bound)) 185 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/nuplan/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuplan_dataset import NuplanDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/nusc/__init__.py: -------------------------------------------------------------------------------- 1 | from .nusc_dataset import NuscDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/raw_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Type, Union 3 | 4 | from trajdata.caching import EnvCache, SceneCache 5 | from trajdata.data_structures import ( 6 | AgentMetadata, 7 | EnvMetadata, 8 | Scene, 9 | SceneMetadata, 10 | SceneTag, 11 | ) 12 | 13 | 14 | class RawDataset: 15 | def __init__( 16 | self, env_name: str, data_dir: str, parallelizable: bool, has_maps: bool 17 | ) -> None: 18 | metadata = self.compute_metadata(env_name, data_dir) 19 | 20 | self.metadata = metadata 21 | self.name = metadata.name 22 | self.scene_tags = metadata.scene_tags 23 | self.dataset_obj = None 24 | 25 | self.parallelizable = parallelizable 26 | self.has_maps = has_maps 27 | 28 | def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: 29 | raise NotImplementedError() 30 | 31 | def get_matching_scene_tags(self, query: Set[str]) -> List[SceneTag]: 32 | return [scene_tag for scene_tag in self.scene_tags if scene_tag.contains(query)] 33 | 34 | def load_dataset_obj(self, verbose: bool = False) -> None: 35 | raise NotImplementedError() 36 | 37 | def del_dataset_obj(self) -> None: 38 | del self.dataset_obj 39 | self.dataset_obj = None 40 | 41 | def _get_matching_scenes_from_cache( 42 | self, 43 | scene_tag: SceneTag, 44 | scene_desc_contains: Optional[List[str]], 45 | env_cache: EnvCache, 46 | ) -> List[Scene]: 47 | raise NotImplementedError() 48 | 49 | def _get_matching_scenes_from_obj( 50 | self, 51 | scene_tag: SceneTag, 52 | scene_desc_contains: Optional[List[str]], 53 | env_cache: EnvCache, 54 | ) -> List[SceneMetadata]: 55 | raise NotImplementedError() 56 | 57 | def cache_all_scenes_list( 58 | self, env_cache: EnvCache, all_scenes_list: List[NamedTuple] 59 | ) -> None: 60 | env_cache.save_env_scenes_list(self.name, all_scenes_list) 61 | 62 | def get_matching_scenes( 63 | self, 64 | scene_tag: SceneTag, 65 | scene_desc_contains: Optional[List[str]], 66 | env_cache: EnvCache, 67 | rebuild_cache: bool, 68 | ) -> Union[List[Scene], List[SceneMetadata]]: 69 | if self.dataset_obj is None and not rebuild_cache: 70 | return self._get_matching_scenes_from_cache( 71 | scene_tag, scene_desc_contains, env_cache 72 | ) 73 | else: 74 | return self._get_matching_scenes_from_obj( 75 | scene_tag, scene_desc_contains, env_cache 76 | ) 77 | 78 | def get_scene(self, scene_info: SceneMetadata) -> Scene: 79 | raise NotImplementedError() 80 | 81 | def get_agent_info( 82 | self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] 83 | ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: 84 | """ 85 | Get frame-level information from source dataset, caching it 86 | to 
cache_path. 87 | 88 | Always called after cache_maps, can load map if needed 89 | to associate map information to positions. 90 | """ 91 | raise NotImplementedError() 92 | 93 | def cache_maps( 94 | self, 95 | cache_path: Path, 96 | map_cache_class: Type[SceneCache], 97 | map_params: Dict[str, Any], 98 | ) -> None: 99 | """ 100 | Get static, scene-level info from the source dataset, caching it 101 | to cache_path. (Primarily this is info needed to construct VectorMap) 102 | 103 | Resolution is in pixels per meter. 104 | """ 105 | raise NotImplementedError() 106 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/scene_records.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class Argoverse2Record(NamedTuple): 5 | name: str 6 | data_idx: int 7 | 8 | 9 | class EUPedsRecord(NamedTuple): 10 | name: str 11 | location: str 12 | length: str 13 | split: str 14 | data_idx: int 15 | 16 | 17 | class SDDPedsRecord(NamedTuple): 18 | name: str 19 | length: str 20 | data_idx: int 21 | 22 | 23 | class InteractionRecord(NamedTuple): 24 | name: str 25 | length: str 26 | data_idx: int 27 | 28 | 29 | class NuscSceneRecord(NamedTuple): 30 | name: str 31 | location: str 32 | length: str 33 | desc: str 34 | data_idx: int 35 | 36 | 37 | class VODSceneRecord(NamedTuple): 38 | token: str 39 | name: str 40 | location: str 41 | length: str 42 | desc: str 43 | data_idx: int 44 | 45 | 46 | class LyftSceneRecord(NamedTuple): 47 | name: str 48 | length: str 49 | data_idx: int 50 | 51 | 52 | class WaymoSceneRecord(NamedTuple): 53 | name: str 54 | length: str 55 | data_idx: int 56 | 57 | 58 | class NuPlanSceneRecord(NamedTuple): 59 | name: str 60 | location: str 61 | length: str 62 | split: str 63 | # desc: str 64 | data_idx: int 65 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/sdd_peds/__init__.py: -------------------------------------------------------------------------------- 1 | from .sddpeds_dataset import SDDPedsDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/sdd_peds/estimated_homography.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Final 2 | 3 | # Please see https://github.com/crowdbotp/OpenTraj/tree/master/datasets/SDD for more information. 4 | # These homographies (transformations from pixel values to world coordinates) were estimated, 5 | # albeit most of them with high certainty. The certainty values indicate how reliable the 6 | # estimate is (or is not). Some of these scales were estimated using google maps, others are a pure guess. 
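# Assumed usage convention (not stated in this file): multiplying raw SDD pixel
# coordinates by "scale" yields meters, e.g. 100 px in bookstore_0 is roughly
# 100 * 0.0384 = 3.84 m.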
7 | SDD_HOMOGRAPHY_SCALES: Final[Dict[str, Dict[str, float]]] = { 8 | "bookstore_0": {"certainty": 1.0, "scale": 0.038392063}, 9 | "bookstore_1": {"certainty": 1.0, "scale": 0.039892913}, 10 | "bookstore_2": {"certainty": 1.0, "scale": 0.04062433}, 11 | "bookstore_3": {"certainty": 1.0, "scale": 0.039098596}, 12 | "bookstore_4": {"certainty": 1.0, "scale": 0.0396}, 13 | "bookstore_5": {"certainty": 0.9, "scale": 0.0396}, 14 | "bookstore_6": {"certainty": 0.9, "scale": 0.0413}, 15 | "coupa_0": {"certainty": 1.0, "scale": 0.027995674}, 16 | "coupa_1": {"certainty": 1.0, "scale": 0.023224545}, 17 | "coupa_2": {"certainty": 1.0, "scale": 0.024}, 18 | "coupa_3": {"certainty": 1.0, "scale": 0.025524906}, 19 | "deathCircle_0": {"certainty": 1.0, "scale": 0.04064}, 20 | "deathCircle_1": {"certainty": 1.0, "scale": 0.039076923}, 21 | "deathCircle_2": {"certainty": 1.0, "scale": 0.03948382}, 22 | "deathCircle_3": {"certainty": 1.0, "scale": 0.028478209}, 23 | "deathCircle_4": {"certainty": 1.0, "scale": 0.038980137}, 24 | "gates_0": {"certainty": 1.0, "scale": 0.03976968}, 25 | "gates_1": {"certainty": 1.0, "scale": 0.03770837}, 26 | "gates_2": {"certainty": 1.0, "scale": 0.037272793}, 27 | "gates_3": {"certainty": 1.0, "scale": 0.034515323}, 28 | "gates_4": {"certainty": 1.0, "scale": 0.04412268}, 29 | "gates_5": {"certainty": 1.0, "scale": 0.0342392}, 30 | "gates_6": {"certainty": 1.0, "scale": 0.0342392}, 31 | "gates_7": {"certainty": 1.0, "scale": 0.04540353}, 32 | "gates_8": {"certainty": 1.0, "scale": 0.045191525}, 33 | "hyang_0": {"certainty": 1.0, "scale": 0.034749693}, 34 | "hyang_1": {"certainty": 1.0, "scale": 0.0453136}, 35 | "hyang_10": {"certainty": 1.0, "scale": 0.054460944}, 36 | "hyang_11": {"certainty": 1.0, "scale": 0.054992233}, 37 | "hyang_12": {"certainty": 1.0, "scale": 0.054104065}, 38 | "hyang_13": {"certainty": 0.0, "scale": 0.0541}, 39 | "hyang_14": {"certainty": 0.0, "scale": 0.0541}, 40 | "hyang_2": {"certainty": 1.0, "scale": 0.054992233}, 41 | "hyang_3": {"certainty": 1.0, "scale": 0.056642}, 42 | "hyang_4": {"certainty": 1.0, "scale": 0.034265612}, 43 | "hyang_5": {"certainty": 1.0, "scale": 0.029655497}, 44 | "hyang_6": {"certainty": 1.0, "scale": 0.052936449}, 45 | "hyang_7": {"certainty": 1.0, "scale": 0.03540125}, 46 | "hyang_8": {"certainty": 1.0, "scale": 0.034592381}, 47 | "hyang_9": {"certainty": 1.0, "scale": 0.038031423}, 48 | "little_0": {"certainty": 1.0, "scale": 0.028930169}, 49 | "little_1": {"certainty": 1.0, "scale": 0.028543144}, 50 | "little_2": {"certainty": 1.0, "scale": 0.028543144}, 51 | "little_3": {"certainty": 1.0, "scale": 0.028638926}, 52 | "nexus_0": {"certainty": 1.0, "scale": 0.043986494}, 53 | "nexus_1": {"certainty": 1.0, "scale": 0.043316805}, 54 | "nexus_10": {"certainty": 1.0, "scale": 0.043991753}, 55 | "nexus_11": {"certainty": 1.0, "scale": 0.043766154}, 56 | "nexus_2": {"certainty": 1.0, "scale": 0.042247434}, 57 | "nexus_3": {"certainty": 1.0, "scale": 0.045883871}, 58 | "nexus_4": {"certainty": 1.0, "scale": 0.045883871}, 59 | "nexus_5": {"certainty": 1.0, "scale": 0.045395745}, 60 | "nexus_6": {"certainty": 1.0, "scale": 0.037929168}, 61 | "nexus_7": {"certainty": 1.0, "scale": 0.037106087}, 62 | "nexus_8": {"certainty": 1.0, "scale": 0.037106087}, 63 | "nexus_9": {"certainty": 1.0, "scale": 0.044917895}, 64 | "quad_0": {"certainty": 1.0, "scale": 0.043606807}, 65 | "quad_1": {"certainty": 1.0, "scale": 0.042530206}, 66 | "quad_2": {"certainty": 1.0, "scale": 0.043338169}, 67 | "quad_3": {"certainty": 1.0, "scale": 
0.044396842}, 68 | } 69 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/vod/__init__.py: -------------------------------------------------------------------------------- 1 | from .vod_dataset import VODDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/dataset_specific/waymo/__init__.py: -------------------------------------------------------------------------------- 1 | from .waymo_dataset import WaymoDataset 2 | -------------------------------------------------------------------------------- /src/trajdata/filtering/__init__.py: -------------------------------------------------------------------------------- 1 | from .filters import * 2 | -------------------------------------------------------------------------------- /src/trajdata/filtering/filters.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | from math import ceil 3 | from typing import List, Optional, Set, Tuple 4 | 5 | from trajdata.data_structures.agent import AgentMetadata, AgentType 6 | 7 | 8 | def agent_types( 9 | agents: List[AgentMetadata], no_types: Set[AgentType], only_types: Set[AgentType] 10 | ) -> List[AgentMetadata]: 11 | agents_list: List[AgentMetadata] = agents 12 | 13 | if no_types is not None: 14 | agents_list = [agent for agent in agents_list if agent.type not in no_types] 15 | 16 | if only_types is not None: 17 | agents_list = [agent for agent in agents_list if agent.type in only_types] 18 | 19 | return agents_list 20 | 21 | 22 | def all_agents_excluded_types( 23 | no_types: Optional[List[AgentType]], agents: List[AgentMetadata] 24 | ) -> bool: 25 | return no_types is not None and all( 26 | agent_info.type in no_types for agent_info in agents 27 | ) 28 | 29 | 30 | def no_agent_included_types( 31 | only_types: Optional[List[AgentType]], agents: List[AgentMetadata] 32 | ) -> bool: 33 | return only_types is not None and all( 34 | agent_info.type not in only_types for agent_info in agents 35 | ) 36 | 37 | 38 | def get_valid_ts( 39 | agent_info: AgentMetadata, 40 | dt: float, 41 | history_sec: Tuple[Optional[float], Optional[float]], 42 | future_sec: Tuple[Optional[float], Optional[float]], 43 | ) -> Tuple[int, int]: 44 | """The returned timesteps are both inclusive. 
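
    For example, with dt=0.5, history_sec=(1.0, None), and future_sec=(2.0, None),
    an agent observed over timesteps [5, 30] yields
    first_valid_ts = 5 + ceil(1.0 / 0.5) = 7 and last_valid_ts = 30 - ceil(2.0 / 0.5) = 26.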
45 |
46 |     Args:
47 |         agent_info (AgentMetadata): Metadata (including first/last observed timesteps) for the agent of interest.
48 |         dt (float): The scene's timestep size (in seconds).
49 |         history_sec (Tuple[Optional[float], Optional[float]]): The (min, max) required amount of agent history (in seconds).
50 |         future_sec (Tuple[Optional[float], Optional[float]]): The (min, max) required amount of agent future (in seconds).
51 |
52 |     Returns:
53 |         Tuple[int, int]: The first and last valid timesteps for this agent (both inclusive).
54 |     """
55 |     first_valid_ts = agent_info.first_timestep
56 |     if history_sec[0] is not None:
57 |         min_history = ceil(Decimal(str(history_sec[0])) / Decimal(str(dt)))
58 |         first_valid_ts += min_history
59 |
60 |     last_valid_ts = agent_info.last_timestep
61 |     if future_sec[0] is not None:
62 |         min_future = ceil(Decimal(str(future_sec[0])) / Decimal(str(dt)))
63 |         last_valid_ts -= min_future
64 |
65 |     return first_valid_ts, last_valid_ts
66 |
67 |
68 | def satisfies_history(
69 |     agent_info: AgentMetadata,
70 |     ts: int,
71 |     dt: float,
72 |     history_sec: Tuple[Optional[float], Optional[float]],
73 | ) -> bool:
74 |     if history_sec[0] is not None:
75 |         min_history = ceil(Decimal(str(history_sec[0])) / Decimal(str(dt)))
76 |         agent_history_satisfies = ts - agent_info.first_timestep >= min_history
77 |     else:
78 |         agent_history_satisfies = True
79 |
80 |     return agent_history_satisfies
81 |
82 |
83 | def satisfies_future(
84 |     agent_info: AgentMetadata,
85 |     ts: int,
86 |     dt: float,
87 |     future_sec: Tuple[Optional[float], Optional[float]],
88 | ) -> bool:
89 |     if future_sec[0] is not None:
90 |         min_future = ceil(Decimal(str(future_sec[0])) / Decimal(str(dt)))
91 |         agent_future_satisfies = agent_info.last_timestep - ts >= min_future
92 |     else:
93 |         agent_future_satisfies = True
94 |
95 |     return agent_future_satisfies
96 |
97 |
98 | def satisfies_times(
99 |     agent_info: AgentMetadata,
100 |     ts: int,
101 |     dt: float,
102 |     history_sec: Tuple[Optional[float], Optional[float]],
103 |     future_sec: Tuple[Optional[float], Optional[float]],
104 | ) -> bool:
105 |     agent_history_satisfies = satisfies_history(agent_info, ts, dt, history_sec)
106 |     agent_future_satisfies = satisfies_future(agent_info, ts, dt, future_sec)
107 |     return agent_history_satisfies and agent_future_satisfies
108 |
109 |
110 | def no_agent_satisfies_time(
111 |     ts: int,
112 |     dt: float,
113 |     history_sec: Tuple[Optional[float], Optional[float]],
114 |     future_sec: Tuple[Optional[float], Optional[float]],
115 |     agents: List[AgentMetadata],
116 | ) -> bool:
117 |     return all(
118 |         not satisfies_times(agent_info, ts, dt, history_sec, future_sec)
119 |         for agent_info in agents
120 |     )
121 |
--------------------------------------------------------------------------------
/src/trajdata/maps/__init__.py:
--------------------------------------------------------------------------------
1 | from .map_api import MapAPI
2 | from .raster_map import RasterizedMap, RasterizedMapMetadata, RasterizedMapPatch
3 | from .traffic_light_status import TrafficLightStatus
4 | from .vec_map import VectorMap
5 |
--------------------------------------------------------------------------------
/src/trajdata/maps/lane_route.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Set
3 |
4 |
5 | @dataclass
6 | class LaneRoute:
7 |     lane_idxs: Set[int]
8 |
--------------------------------------------------------------------------------
/src/trajdata/maps/map_api.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Optional
4 |
5 | if TYPE_CHECKING:
6 |     from trajdata.maps.map_kdtree import MapElementKDTree
7 |     from trajdata.caching.scene_cache import SceneCache
8 |
9 | from pathlib import Path
10 | from typing import Dict
11 |
12 | from trajdata.maps.vec_map import VectorMap
13 | from trajdata.proto.vectorized_map_pb2 import VectorizedMap
14 | from trajdata.utils import map_utils
15 |
16 |
17 | class MapAPI:
18 |     def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> None:
19 |         """A simple interface for loading trajdata's vector maps which does not require
20 |         instantiation of a `UnifiedDataset` object.
21 |
22 |         Args:
23 |             unified_cache_path (Path): Path to trajdata's local cache on disk.
24 |             keep_in_memory (bool): Whether loaded maps should be stored
25 |                 in memory (memoized) for later re-use. For most cases (e.g., batched dataloading),
26 |                 this is a good idea. However, this can cause rapid memory usage growth for some
27 |                 datasets (e.g., Waymo) and it can be better to disable this. Defaults to False.
28 |         """
29 |         self.unified_cache_path: Path = unified_cache_path
30 |         self.maps: Dict[str, VectorMap] = dict()
31 |         self._keep_in_memory = keep_in_memory
32 |
33 |     def get_map(
34 |         self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs
35 |     ) -> VectorMap:
36 |         if map_id not in self.maps:
37 |             env_name, map_name = map_id.split(":")
38 |             env_maps_path: Path = self.unified_cache_path / env_name / "maps"
39 |             stored_vec_map: VectorizedMap = map_utils.load_vector_map(
40 |                 env_maps_path / f"{map_name}.pb"
41 |             )
42 |
43 |             vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs)
44 |             vec_map.search_kdtrees = map_utils.load_kdtrees(
45 |                 env_maps_path / f"{map_name}_kdtrees.dill"
46 |             )
47 |             vec_map.search_rtrees = map_utils.load_rtrees(
48 |                 env_maps_path / f"{map_name}_rtrees.dill"
49 |             )
50 |
51 |             if self._keep_in_memory:
52 |                 self.maps[map_id] = vec_map
53 |         else:
54 |             vec_map = self.maps[map_id]
55 |
56 |         if scene_cache is not None:
57 |             vec_map.associate_scene_data(
58 |                 scene_cache.get_traffic_light_status_dict(
59 |                     kwargs.get("desired_dt", None)
60 |                 )
61 |             )
62 |
63 |         return vec_map
64 |
--------------------------------------------------------------------------------
/src/trajdata/maps/map_kdtree.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import defaultdict
4 | from typing import TYPE_CHECKING, Dict
5 |
6 | if TYPE_CHECKING:
7 |     from trajdata.maps.vec_map import VectorMap
8 |
9 | from typing import Optional, Tuple
10 |
11 | import numpy as np
12 | from scipy.spatial import KDTree
13 | from tqdm import tqdm
14 |
15 | from trajdata.maps.vec_map_elements import MapElement, MapElementType, Polyline
16 | from trajdata.utils.arr_utils import angle_wrap
17 |
18 |
19 | class MapElementKDTree:
20 |     """
21 |     Constructs a KDTree of MapElements and exposes fast lookup functions.
22 |
23 |     Inheriting classes need to implement the _extract_points_and_metadata function that defines for a MapElement
24 |     the coordinates we want to store in the KDTree (see LaneCenterKDTree below).
25 | """ 26 | 27 | def __init__(self, vector_map: VectorMap, verbose: bool = False) -> None: 28 | # Build kd-tree 29 | self.kdtree, self.polyline_inds, self.metadata = self._build_kdtree( 30 | vector_map, verbose 31 | ) 32 | 33 | def _build_kdtree(self, vector_map: VectorMap, verbose: bool = False): 34 | polylines = [] 35 | polyline_inds = [] 36 | metadata = defaultdict(list) 37 | 38 | map_elem: MapElement 39 | for map_elem in tqdm( 40 | vector_map.iter_elems(), 41 | desc=f"Building K-D Trees", 42 | leave=False, 43 | total=len(vector_map), 44 | disable=not verbose, 45 | ): 46 | result = self._extract_points_and_metadata(map_elem) 47 | if result is not None: 48 | points, extras = result 49 | polyline_inds.extend([len(polylines)] * points.shape[0]) 50 | 51 | # Apply any map offsets to ensure we're in the same coordinate area as the 52 | # original world map. 53 | polylines.append(points) 54 | 55 | for k, v in extras.items(): 56 | metadata[k].append(v) 57 | 58 | points = np.concatenate(polylines, axis=0) 59 | polyline_inds = np.array(polyline_inds) 60 | metadata = {k: np.concatenate(v) for k, v in metadata.items()} 61 | 62 | kdtree = KDTree(points) 63 | return kdtree, polyline_inds, metadata 64 | 65 | def _extract_points_and_metadata( 66 | self, map_element: MapElement 67 | ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]: 68 | """Defines the coordinates we want to store in the KDTree for a MapElement. 69 | Args: 70 | map_element (MapElement): the MapElement to store in the KDTree. 71 | Returns: 72 | Optional[np.ndarray]: coordinates based on which we can search the KDTree, or None. 73 | If None, the MapElement will not be stored. 74 | Else, tuple of 75 | np.ndarray: [B,d] set of B d-dimensional points to add, 76 | Dict[str, np.ndarray] mapping names to meta-information about the points 77 | """ 78 | raise NotImplementedError() 79 | 80 | def closest_point(self, query_points: np.ndarray) -> np.ndarray: 81 | """Find the closest KDTree points to (a batch of) query points. 82 | 83 | Args: 84 | query_points: np.ndarray of shape (..., data_dim). 85 | 86 | Return: 87 | np.ndarray of shape (..., data_dim), the KDTree points closest to query_point. 88 | """ 89 | _, data_inds = self.kdtree.query(query_points, k=1) 90 | pts = self.kdtree.data[data_inds] 91 | return pts 92 | 93 | def closest_polyline_ind(self, query_points: np.ndarray) -> np.ndarray: 94 | """Find the index of the closest polyline(s) in self.polylines.""" 95 | _, data_ind = self.kdtree.query(query_points, k=1) 96 | return self.polyline_inds[data_ind] 97 | 98 | def polyline_inds_in_range(self, point: np.ndarray, range: float) -> np.ndarray: 99 | """Find the index of polylines in self.polylines within 'range' distance to 'point'.""" 100 | data_inds = self.kdtree.query_ball_point(point, range) 101 | return np.unique(self.polyline_inds[data_inds], axis=0) 102 | 103 | 104 | class LaneCenterKDTree(MapElementKDTree): 105 | """KDTree for lane center polylines.""" 106 | 107 | def __init__( 108 | self, vector_map: VectorMap, max_segment_len: Optional[float] = None 109 | ) -> None: 110 | """ 111 | Args: 112 | vec_map: the VectorizedMap object to build the KDTree for 113 | max_segment_len (float, optional): if specified, we will insert extra points into the KDTree 114 | such that all polyline segments are shorter then max_segment_len. 
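
        Example (illustrative):
            tree = LaneCenterKDTree(vector_map, max_segment_len=5.0)
            nearest_xyz = tree.closest_point(np.array([10.0, 2.0, 0.0]))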
115 | """ 116 | self.max_segment_len = max_segment_len 117 | super().__init__(vector_map) 118 | 119 | def _extract_points_and_metadata( 120 | self, map_element: MapElement 121 | ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]: 122 | if map_element.elem_type == MapElementType.ROAD_LANE: 123 | pts: Polyline = map_element.center 124 | if self.max_segment_len is not None: 125 | pts = pts.interpolate(max_dist=self.max_segment_len) 126 | 127 | # We only want to store xyz in the kdtree, not heading. 128 | return pts.xyz, {"heading": pts.h} 129 | else: 130 | return None 131 | 132 | def current_lane_inds( 133 | self, 134 | xyzh: np.ndarray, 135 | distance_threshold: float, 136 | heading_threshold: float, 137 | sorted: bool = True, 138 | dist_weight: float = 1.0, 139 | heading_weight: float = 0.1, 140 | ) -> np.ndarray: 141 | """ 142 | Args: 143 | xyzh (np.ndarray): [...,d]: (batch of) position and heading in world frame 144 | distance_threshold (Optional[float], optional). Defaults to None. 145 | heading_threshold (float, optional). Defaults to np.pi/8. 146 | 147 | Returns: 148 | np.ndarray: List of polyline inds that could be considered the current lane 149 | for the provided position and heading, ordered by heading similarity 150 | """ 151 | query_point = xyzh[..., :3] # query on xyz 152 | heading = xyzh[..., 3] 153 | data_inds = np.array( 154 | self.kdtree.query_ball_point(query_point, distance_threshold) 155 | ) 156 | 157 | if len(data_inds) == 0: 158 | return [] 159 | possible_points = self.kdtree.data[data_inds] 160 | possible_headings = self.metadata["heading"][data_inds] 161 | 162 | heading_errs = np.abs(angle_wrap(heading - possible_headings)) 163 | dist_errs = np.linalg.norm( 164 | query_point[None, :] - possible_points, ord=2, axis=-1 165 | ) 166 | 167 | under_thresh = heading_errs < heading_threshold 168 | lane_inds = self.polyline_inds[data_inds[under_thresh]] 169 | 170 | # we don't want to return duplicates of lanes 171 | unique_lane_inds = np.unique(lane_inds) 172 | 173 | if not sorted: 174 | return unique_lane_inds 175 | 176 | # if we are sorting results, evaluate cost: 177 | costs = ( 178 | dist_weight * dist_errs[under_thresh] 179 | + heading_weight * heading_errs[under_thresh] 180 | ) 181 | 182 | # cost for a lane is minimum over all possible points for that lane 183 | min_costs = [np.min(costs[lane_inds == ind]) for ind in unique_lane_inds] 184 | 185 | return unique_lane_inds[np.argsort(min_costs)] 186 | -------------------------------------------------------------------------------- /src/trajdata/maps/map_strtree.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple 4 | 5 | if TYPE_CHECKING: 6 | from trajdata.maps.vec_map import VectorMap 7 | 8 | import numpy as np 9 | from shapely import LinearRing, Polygon, STRtree, linearrings, points, polygons 10 | from tqdm import tqdm 11 | 12 | from trajdata.maps.vec_map_elements import ( 13 | MapElement, 14 | MapElementType, 15 | PedCrosswalk, 16 | PedWalkway, 17 | RoadArea, 18 | ) 19 | 20 | 21 | def polygon_with_holes_geometry(map_element: MapElement) -> Polygon: 22 | assert isinstance(map_element, RoadArea) 23 | points = linearrings(map_element.exterior_polygon.xy) 24 | holes: Optional[List[LinearRing]] = None 25 | if len(map_element.interior_holes) > 0: 26 | holes = [linearrings(hole.xy) for hole in map_element.interior_holes] 27 | 28 | return polygons(points, holes=holes) 29 | 
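# A quick sanity sketch of the helper above (illustrative, not repository
# code): a RoadArea with one interior hole becomes a shapely Polygon whose
# area excludes the hole.
#
#     outer = [(0, 0), (10, 0), (10, 10), (0, 10)]
#     hole = [(4, 4), (6, 4), (6, 6), (4, 6)]
#     poly = polygons(linearrings(outer), holes=[linearrings(hole)])
#     assert poly.area == 100 - 4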
30 |
31 | def polygon_geometry(map_element: MapElement) -> Polygon:
32 | assert isinstance(map_element, (PedWalkway, PedCrosswalk))
33 | return polygons(map_element.polygon.xy)
34 |
35 |
36 | # Dictionary mapping map_elem_type to function returning
37 | # shapely polygon for that map element
38 | MAP_ELEM_TO_GEOMETRY: Dict[MapElementType, Callable[[MapElement], Polygon]] = {
39 | MapElementType.ROAD_AREA: polygon_with_holes_geometry,
40 | MapElementType.PED_CROSSWALK: polygon_geometry,
41 | MapElementType.PED_WALKWAY: polygon_geometry,
42 | }
43 |
44 |
45 | class MapElementSTRTree:
46 | """
47 | Constructs an STR-tree (a static R-tree) over polygonal MapElements and exposes fast lookup functions.
48 |
49 | The geometry stored for each MapElement type is determined by the
50 | MAP_ELEM_TO_GEOMETRY mapping defined above.
51 | """
52 |
53 | def __init__(
54 | self,
55 | vector_map: VectorMap,
56 | elem_type: MapElementType,
57 | verbose: bool = False,
58 | ) -> None:
59 | # Build R-tree
60 | self.strtree, self.elem_ids = self._build_strtree(
61 | vector_map, elem_type, verbose
62 | )
63 |
64 | def _build_strtree(
65 | self,
66 | vector_map: VectorMap,
67 | elem_type: MapElementType,
68 | verbose: bool = False,
69 | ) -> Tuple[STRtree, np.ndarray]:
70 | geometries: List[Polygon] = []
71 | ids: List[str] = []
72 | geometry_fn = MAP_ELEM_TO_GEOMETRY[elem_type]
73 |
74 | map_elem: MapElement
75 | for elem_id, map_elem in tqdm(
76 | vector_map.elements[elem_type].items(),
77 | desc=f"Building STR Tree for {elem_type.name} elements",
78 | leave=False,
79 | disable=not verbose,
80 | ):
81 | ids.append(elem_id)
82 | geometries.append(geometry_fn(map_elem))
83 |
84 | return STRtree(geometries), np.array(ids)
85 |
86 | def query_point(
87 | self,
88 | point: np.ndarray,
89 | **kwargs,
90 | ) -> np.ndarray:
91 | """
92 | Returns the IDs of all elements of the type this tree
93 | was built for (fixed at construction) that intersect
94 | with the query point.
95 |
96 | Args:
97 | point (np.ndarray): point to query
98 | kwargs: passed on to STRtree.query(), see
99 | https://pygeos.readthedocs.io/en/latest/strtree.html
100 | Can be used for predicate based queries, e.g.
101 | predicate="dwithin", distance=100
102 | returns all elements which are within 100m of the query point.
103 |
104 | Returns:
105 | np.ndarray[str]: 1d array of ids of all elements matching the query
106 | """
107 | indices = self.strtree.query(points(point), **kwargs)
108 | return self.elem_ids[indices]
109 |
110 | def nearest_area(
111 | self,
112 | point: np.ndarray,
113 | **kwargs,
114 | ) -> str:
115 | """
116 | Returns the ID of the element (of the type this tree was built for)
117 | that is closest to the query point.
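For example (hypothetical usage), `tree.nearest_area(np.array([x, y]))`
returns the ID of the single element nearest to (x, y).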
118 |
119 | Args:
120 | point (np.ndarray): point to query
121 | kwargs: passed on to STRtree.nearest(), see
122 | https://pygeos.readthedocs.io/en/latest/strtree.html
123 |
124 | Returns:
125 | str: the ID of the nearest element
126 | """
127 | idx = self.strtree.nearest(points(point), **kwargs)
128 | return self.elem_ids[idx]
129 |
-------------------------------------------------------------------------------- /src/trajdata/maps/raster_map.py: --------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 |
7 |
8 | class RasterizedMapMetadata:
9 | def __init__(
10 | self,
11 | name: str,
12 | shape: Tuple[int, int, int],
13 | layers: List[str],
14 | layer_rgb_groups: Tuple[List[int], List[int], List[int]],
15 | resolution: float, # px/m
16 | map_from_world: np.ndarray, # Transformation from world coordinates [m] to map coordinates [px]
17 | ) -> None:
18 | self.name: str = name
19 | self.shape: Tuple[int, int, int] = shape
20 | self.layers: List[str] = layers
21 | self.layer_rgb_groups: Tuple[List[int], List[int], List[int]] = layer_rgb_groups
22 | self.resolution: float = resolution
23 | self.map_from_world: np.ndarray = map_from_world
24 |
25 |
26 | class RasterizedMap:
27 | def __init__(
28 | self,
29 | metadata: RasterizedMapMetadata,
30 | data: np.ndarray,
31 | ) -> None:
32 | assert data.shape == metadata.shape
33 | self.metadata: RasterizedMapMetadata = metadata
34 | self.data: np.ndarray = data
35 |
36 | @property
37 | def shape(self) -> Tuple[int, int, int]:
38 | return self.data.shape
39 |
40 | @staticmethod
41 | def to_img(
42 | map_arr: Tensor,
43 | idx_groups: Optional[Tuple[List[int], List[int], List[int]]] = None,
44 | ) -> np.ndarray:
45 | if idx_groups is None:
46 | return map_arr.permute(1, 2, 0).numpy()
47 |
48 | return torch.stack(
49 | [
50 | torch.amax(map_arr[idx_groups[0]], dim=0),
51 | torch.amax(map_arr[idx_groups[1]], dim=0),
52 | torch.amax(map_arr[idx_groups[2]], dim=0),
53 | ],
54 | dim=-1,
55 | ).numpy()
56 |
57 |
58 | class RasterizedMapPatch:
59 | def __init__(
60 | self,
61 | data: np.ndarray,
62 | rot_angle: float,
63 | crop_size: int,
64 | resolution: float,
65 | raster_from_world_tf: np.ndarray,
66 | has_data: bool,
67 | ) -> None:
68 | self.data = data
69 | self.rot_angle = rot_angle
70 | self.crop_size = crop_size
71 | self.resolution = resolution
72 | self.raster_from_world_tf = raster_from_world_tf
73 | self.has_data = has_data
74 |
-------------------------------------------------------------------------------- /src/trajdata/maps/traffic_light_status.py: --------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 |
4 | class TrafficLightStatus(IntEnum):
5 | NO_DATA = -1
6 | UNKNOWN = 0
7 | GREEN = 1
8 | RED = 2
9 |
-------------------------------------------------------------------------------- /src/trajdata/maps/vec_map_elements.py: --------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from enum import IntEnum
3 | from typing import List, Optional, Set
4 |
5 | import numpy as np
6 |
7 | from trajdata.utils import map_utils
8 |
9 |
10 | class MapElementType(IntEnum):
11 | ROAD_LANE = 1
12 | ROAD_AREA = 2
13 | PED_CROSSWALK = 3
14 | PED_WALKWAY = 4
15 |
16 |
17 | @dataclass
18 | class Polyline:
19 | points: np.ndarray
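# points has shape (N, D): rows are consecutive points, columns are
# (x, y[, z[, h]]); 2D (xy) inputs are zero-padded to 3D in __post_init__ below.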
20 |
21 | def __post_init__(self) -> None:
22 | if self.points.shape[-1] < 2:
23 | raise ValueError(
24 | f"Polylines are expected to have 2 (xy), 3 (xyz), or 4 (xyzh) dimensions, but received {self.points.shape[-1]}."
25 | )
26 |
27 | if self.points.shape[-1] == 2:
28 | # If only xy are passed in, then append zero to the end for z.
29 | self.points = np.append(
30 | self.points, np.zeros_like(self.points[:, [0]]), axis=-1
31 | )
32 |
33 | @property
34 | def midpoint(self) -> np.ndarray:
35 | num_pts: int = self.points.shape[0]
36 | return self.points[num_pts // 2]
37 |
38 | @property
39 | def has_heading(self) -> bool:
40 | return self.points.shape[-1] == 4
41 |
42 | @property
43 | def xy(self) -> np.ndarray:
44 | return self.points[..., :2]
45 |
46 | @property
47 | def xyz(self) -> np.ndarray:
48 | return self.points[..., :3]
49 |
50 | @property
51 | def xyzh(self) -> np.ndarray:
52 | if self.has_heading:
53 | return self.points[..., :4]
54 | else:
55 | raise ValueError(
56 | f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4."
57 | )
58 |
59 | @property
60 | def h(self) -> np.ndarray:
61 | return self.points[..., 3]
62 |
63 | def interpolate(
64 | self, num_pts: Optional[int] = None, max_dist: Optional[float] = None
65 | ) -> "Polyline":
66 | return Polyline(
67 | map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist)
68 | )
69 |
70 | def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray:
71 | """Project the given points onto this Polyline.
72 |
73 | Args:
74 | xyz_or_xyzh (np.ndarray): Points to project, of shape (M, D)
75 |
76 | Returns:
77 | np.ndarray: The projected points, of shape (M, D)
78 |
79 | Note:
80 | D = 4 if this Polyline has headings, otherwise D = 3
81 | """
82 | # xyz is (M, 1, 3); we do not use heading for projection.
83 | xyz = xyz_or_xyzh[:, np.newaxis, :3]
84 |
85 | # p0, p1 are (1, N, 3)
86 | p0: np.ndarray = self.points[np.newaxis, :-1, :3]
87 | p1: np.ndarray = self.points[np.newaxis, 1:, :3]
88 |
89 | # 1. Compute projections of each point to each line segment in a
90 | # batched manner.
91 | line_seg_diffs: np.ndarray = p1 - p0
92 | point_seg_diffs: np.ndarray = xyz - p0
93 |
94 | dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum(
95 | axis=-1, keepdims=True
96 | )
97 | norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2
98 |
99 | # Clip ensures that the projected point stays within the line segment boundaries.
100 | projs: np.ndarray = (
101 | p0 + np.clip(dot_products / norms, a_min=0, a_max=1) * line_seg_diffs
102 | )
103 |
104 | # 2. Find the nearest projections to the original points.
105 | closest_proj_idxs: np.ndarray = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1)
106 |
107 | if self.has_heading:
108 | # Adding in the heading of the corresponding p0 point (which makes
109 | # sense as p0 to p1 is a line => same heading along it).
110 | return np.concatenate( 111 | [ 112 | projs[range(xyz.shape[0]), closest_proj_idxs], 113 | np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), 114 | ], 115 | axis=-1, 116 | ) 117 | else: 118 | return projs[range(xyz.shape[0]), closest_proj_idxs] 119 | 120 | 121 | @dataclass 122 | class MapElement: 123 | id: str 124 | 125 | 126 | @dataclass 127 | class RoadLane(MapElement): 128 | center: Polyline 129 | left_edge: Optional[Polyline] = None 130 | right_edge: Optional[Polyline] = None 131 | adj_lanes_left: Set[str] = field(default_factory=lambda: set()) 132 | adj_lanes_right: Set[str] = field(default_factory=lambda: set()) 133 | next_lanes: Set[str] = field(default_factory=lambda: set()) 134 | prev_lanes: Set[str] = field(default_factory=lambda: set()) 135 | elem_type: MapElementType = MapElementType.ROAD_LANE 136 | 137 | def __post_init__(self) -> None: 138 | if not self.center.has_heading: 139 | self.center = Polyline( 140 | np.append( 141 | self.center.xyz, 142 | map_utils.get_polyline_headings(self.center.xyz), 143 | axis=-1, 144 | ) 145 | ) 146 | 147 | def __hash__(self) -> int: 148 | return hash(self.id) 149 | 150 | @property 151 | def reachable_lanes(self) -> Set[str]: 152 | return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes 153 | 154 | 155 | @dataclass 156 | class RoadArea(MapElement): 157 | exterior_polygon: Polyline 158 | interior_holes: List[Polyline] = field(default_factory=lambda: list()) 159 | elem_type: MapElementType = MapElementType.ROAD_AREA 160 | 161 | 162 | @dataclass 163 | class PedCrosswalk(MapElement): 164 | polygon: Polyline 165 | elem_type: MapElementType = MapElementType.PED_CROSSWALK 166 | 167 | 168 | @dataclass 169 | class PedWalkway(MapElement): 170 | polygon: Polyline 171 | elem_type: MapElementType = MapElementType.PED_WALKWAY 172 | -------------------------------------------------------------------------------- /src/trajdata/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_preprocessor import ParallelDatasetPreprocessor, scene_paths_collate_fn 2 | -------------------------------------------------------------------------------- /src/trajdata/parallel/data_preprocessor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List, Optional, Tuple, Type 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | from trajdata.caching import EnvCache, SceneCache 8 | from trajdata.data_structures import Scene, SceneMetadata 9 | from trajdata.utils import agent_utils, env_utils 10 | 11 | 12 | def scene_paths_collate_fn(filled_scenes: List) -> List: 13 | return filled_scenes 14 | 15 | 16 | class ParallelDatasetPreprocessor(Dataset): 17 | def __init__( 18 | self, 19 | scene_info_list: List[SceneMetadata], 20 | envs_dir_dict: Dict[str, str], 21 | env_cache_path: str, 22 | desired_dt: Optional[float], 23 | cache_class: Type[SceneCache], 24 | rebuild_cache: bool, 25 | ) -> None: 26 | self.env_cache_path = np.array(env_cache_path).astype(np.string_) 27 | self.desired_dt = desired_dt 28 | self.cache_class = cache_class 29 | self.rebuild_cache = rebuild_cache 30 | 31 | env_names: List[str] = list(envs_dir_dict.keys()) 32 | scene_idxs_names: List[Tuple[int, str]] = [ 33 | (idx, scene_info.name) for idx, scene_info in enumerate(scene_info_list) 34 | ] 35 | scene_name_idxs, scene_names = zip(*scene_idxs_names) 36 | 37 | self.scene_idxs = np.array( 38 | [scene_info.raw_data_idx for 
scene_info in scene_info_list], dtype=int 39 | ) 40 | self.env_name_idxs = np.array( 41 | [env_names.index(scene_info.env_name) for scene_info in scene_info_list], 42 | dtype=int, 43 | ) 44 | 45 | self.scene_name_idxs = np.array(scene_name_idxs, dtype=int) 46 | self.env_names_arr = np.array(env_names).astype(np.string_) 47 | self.scene_names_arr = np.array(scene_names).astype(np.string_) 48 | self.data_dir_arr = np.array(list(envs_dir_dict.values())).astype(np.string_) 49 | 50 | self.data_len: int = len(scene_info_list) 51 | 52 | def __len__(self) -> int: 53 | return self.data_len 54 | 55 | def __getitem__(self, idx: int) -> str: 56 | env_cache_path: Path = Path(str(self.env_cache_path, encoding="utf-8")) 57 | env_cache: EnvCache = EnvCache(env_cache_path) 58 | 59 | env_idx: int = self.env_name_idxs[idx] 60 | scene_idx: int = self.scene_name_idxs[idx] 61 | 62 | env_name: str = str(self.env_names_arr[env_idx], encoding="utf-8") 63 | raw_dataset = env_utils.get_raw_dataset( 64 | env_name, str(self.data_dir_arr[env_idx], encoding="utf-8") 65 | ) 66 | 67 | scene_name: str = str(self.scene_names_arr[scene_idx], encoding="utf-8") 68 | 69 | scene_info = SceneMetadata( 70 | env_name, scene_name, raw_dataset.metadata.dt, self.scene_idxs[idx] 71 | ) 72 | 73 | # Leaving verbose False here so that we don't spam 74 | # stdout with loading messages. 75 | raw_dataset.load_dataset_obj(verbose=False) 76 | scene: Scene = agent_utils.get_agent_data( 77 | scene_info, 78 | raw_dataset, 79 | env_cache, 80 | self.rebuild_cache, 81 | self.cache_class, 82 | self.desired_dt, 83 | ) 84 | raw_dataset.del_dataset_obj() 85 | 86 | if scene is None: 87 | # This provides an escape hatch in case there's a reason we 88 | # don't want to add a scene to the list of scenes. As an example, 89 | # nuPlan has a scene with only a single frame of data which we 90 | # can't do much with in terms of prediction/planning/etc. 91 | return None 92 | 93 | scene_path: Path = EnvCache.scene_metadata_path( 94 | env_cache.path, scene.env_name, scene.name, scene.dt 95 | ) 96 | return str(scene_path) 97 | -------------------------------------------------------------------------------- /src/trajdata/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/trajdata/51f0efa9d572fa7da480ef8caf089d0d6987de9f/src/trajdata/proto/__init__.py -------------------------------------------------------------------------------- /src/trajdata/proto/vectorized_map.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package trajdata; 4 | 5 | message VectorizedMap { 6 | // The name of this map in the format environment_name:map_name 7 | string name = 1; 8 | 9 | // The full set of map elements. 10 | repeated MapElement elements = 2; 11 | 12 | // The coordinates of the cuboid (in m) 13 | // containing all elements in this map. 14 | Point max_pt = 3; 15 | Point min_pt = 4; 16 | 17 | // The original world coordinates (in m) of the bottom-left of the map 18 | // (account for a change in the origin for storage efficiency). 19 | Point shifted_origin = 5; 20 | } 21 | 22 | message MapElement { 23 | // A unique ID to identify this element. 24 | bytes id = 1; 25 | 26 | // Type specific data. 
27 | oneof element_data { 28 | RoadLane road_lane = 2; 29 | RoadArea road_area = 3; 30 | PedCrosswalk ped_crosswalk = 4; 31 | PedWalkway ped_walkway = 5; 32 | } 33 | } 34 | 35 | message Point { 36 | double x = 1; 37 | double y = 2; 38 | double z = 3; 39 | } 40 | 41 | message Polyline { 42 | // Position deltas in millimeters. The origin is an arbitrary location. 43 | // From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 44 | // The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from 45 | // the origin. For subsequent points, this field stores the difference between the point's 46 | // coordinates and the previous point's coordinates. This is for representation efficiency. 47 | repeated sint32 dx_mm = 1; 48 | repeated sint32 dy_mm = 2; 49 | repeated sint32 dz_mm = 3; 50 | repeated double h_rad = 4; 51 | } 52 | 53 | message RoadLane { 54 | // The polyline data for the lane. A polyline is a list of points with 55 | // segments defined between consecutive points. 56 | Polyline center = 1; 57 | 58 | // The polyline data for the (optional) left boundary of this lane. 59 | optional Polyline left_boundary = 2; 60 | 61 | // The polyline data for the (optional) right boundary of this lane. 62 | optional Polyline right_boundary = 3; 63 | 64 | // A list of IDs for lanes that this lane may be entered from. 65 | repeated bytes entry_lanes = 4; 66 | 67 | // A list of IDs for lanes that this lane may exit to. 68 | repeated bytes exit_lanes = 5; 69 | 70 | // A list of neighbors to the left of this lane. Neighbor lanes 71 | // include only adjacent lanes going the same direction. 72 | repeated bytes adjacent_lanes_left = 6; 73 | 74 | // A list of neighbors to the right of this lane. Neighbor lanes 75 | // include only adjacent lanes going the same direction. 76 | repeated bytes adjacent_lanes_right = 7; 77 | } 78 | 79 | message RoadArea { 80 | // The polygon defining the outline of general driveable area. This acts as a 81 | // catch-all when there's a road segment without lane information. 82 | // For example, intersections in nuScenes are not "lanes". 83 | Polyline exterior_polygon = 1; 84 | 85 | // The area within these polygons is NOT included. 86 | repeated Polyline interior_holes = 2; 87 | } 88 | 89 | message PedCrosswalk { 90 | // The polygon defining the outline of the crosswalk. The polygon is assumed 91 | // to be closed (i.e. a segment exists between the last point and the first 92 | // point). 93 | Polyline polygon = 1; 94 | } 95 | 96 | message PedWalkway { 97 | // The polygon defining the outline of the pedestrian walkway (e.g., sidewalk). 98 | // The polygon is assumed to be closed (i.e. a segment exists between the last 99 | // point and the first point). 100 | Polyline polygon = 1; 101 | } -------------------------------------------------------------------------------- /src/trajdata/proto/vectorized_map_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: vectorized_map.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 17 | b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' 18 | ) 19 | 20 | 21 | _VECTORIZEDMAP = DESCRIPTOR.message_types_by_name["VectorizedMap"] 22 | _MAPELEMENT = DESCRIPTOR.message_types_by_name["MapElement"] 23 | _POINT = DESCRIPTOR.message_types_by_name["Point"] 24 | _POLYLINE = DESCRIPTOR.message_types_by_name["Polyline"] 25 | _ROADLANE = DESCRIPTOR.message_types_by_name["RoadLane"] 26 | _ROADAREA = DESCRIPTOR.message_types_by_name["RoadArea"] 27 | _PEDCROSSWALK = DESCRIPTOR.message_types_by_name["PedCrosswalk"] 28 | _PEDWALKWAY = DESCRIPTOR.message_types_by_name["PedWalkway"] 29 | VectorizedMap = _reflection.GeneratedProtocolMessageType( 30 | "VectorizedMap", 31 | (_message.Message,), 32 | { 33 | "DESCRIPTOR": _VECTORIZEDMAP, 34 | "__module__": "vectorized_map_pb2" 35 | # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) 36 | }, 37 | ) 38 | _sym_db.RegisterMessage(VectorizedMap) 39 | 40 | MapElement = _reflection.GeneratedProtocolMessageType( 41 | "MapElement", 42 | (_message.Message,), 43 | { 44 | "DESCRIPTOR": _MAPELEMENT, 45 | "__module__": "vectorized_map_pb2" 46 | # @@protoc_insertion_point(class_scope:trajdata.MapElement) 47 | }, 48 | ) 49 | 
_sym_db.RegisterMessage(MapElement) 50 | 51 | Point = _reflection.GeneratedProtocolMessageType( 52 | "Point", 53 | (_message.Message,), 54 | { 55 | "DESCRIPTOR": _POINT, 56 | "__module__": "vectorized_map_pb2" 57 | # @@protoc_insertion_point(class_scope:trajdata.Point) 58 | }, 59 | ) 60 | _sym_db.RegisterMessage(Point) 61 | 62 | Polyline = _reflection.GeneratedProtocolMessageType( 63 | "Polyline", 64 | (_message.Message,), 65 | { 66 | "DESCRIPTOR": _POLYLINE, 67 | "__module__": "vectorized_map_pb2" 68 | # @@protoc_insertion_point(class_scope:trajdata.Polyline) 69 | }, 70 | ) 71 | _sym_db.RegisterMessage(Polyline) 72 | 73 | RoadLane = _reflection.GeneratedProtocolMessageType( 74 | "RoadLane", 75 | (_message.Message,), 76 | { 77 | "DESCRIPTOR": _ROADLANE, 78 | "__module__": "vectorized_map_pb2" 79 | # @@protoc_insertion_point(class_scope:trajdata.RoadLane) 80 | }, 81 | ) 82 | _sym_db.RegisterMessage(RoadLane) 83 | 84 | RoadArea = _reflection.GeneratedProtocolMessageType( 85 | "RoadArea", 86 | (_message.Message,), 87 | { 88 | "DESCRIPTOR": _ROADAREA, 89 | "__module__": "vectorized_map_pb2" 90 | # @@protoc_insertion_point(class_scope:trajdata.RoadArea) 91 | }, 92 | ) 93 | _sym_db.RegisterMessage(RoadArea) 94 | 95 | PedCrosswalk = _reflection.GeneratedProtocolMessageType( 96 | "PedCrosswalk", 97 | (_message.Message,), 98 | { 99 | "DESCRIPTOR": _PEDCROSSWALK, 100 | "__module__": "vectorized_map_pb2" 101 | # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) 102 | }, 103 | ) 104 | _sym_db.RegisterMessage(PedCrosswalk) 105 | 106 | PedWalkway = _reflection.GeneratedProtocolMessageType( 107 | "PedWalkway", 108 | (_message.Message,), 109 | { 110 | "DESCRIPTOR": _PEDWALKWAY, 111 | "__module__": "vectorized_map_pb2" 112 | # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) 113 | }, 114 | ) 115 | _sym_db.RegisterMessage(PedWalkway) 116 | 117 | if _descriptor._USE_C_DESCRIPTORS == False: 118 | DESCRIPTOR._options = None 119 | _VECTORIZEDMAP._serialized_start = 35 120 | _VECTORIZEDMAP._serialized_end = 211 121 | _MAPELEMENT._serialized_start = 214 122 | _MAPELEMENT._serialized_end = 430 123 | _POINT._serialized_start = 432 124 | _POINT._serialized_end = 472 125 | _POLYLINE._serialized_start = 474 126 | _POLYLINE._serialized_end = 544 127 | _ROADLANE._serialized_start = 547 128 | _ROADLANE._serialized_end = 827 129 | _ROADAREA._serialized_start = 829 130 | _ROADAREA._serialized_end = 929 131 | _PEDCROSSWALK._serialized_start = 931 132 | _PEDCROSSWALK._serialized_end = 982 133 | _PEDWALKWAY._serialized_start = 984 134 | _PEDWALKWAY._serialized_end = 1033 135 | # @@protoc_insertion_point(module_scope) 136 | -------------------------------------------------------------------------------- /src/trajdata/simulation/__init__.py: -------------------------------------------------------------------------------- 1 | from .sim_scene import SimulationScene 2 | -------------------------------------------------------------------------------- /src/trajdata/simulation/sim_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import numpy as np 4 | 5 | from trajdata.caching.scene_cache import SceneCache 6 | from trajdata.data_structures.state import StateArray 7 | from trajdata.simulation.sim_metrics import SimMetric 8 | from trajdata.simulation.sim_stats import SimStatistic 9 | 10 | 11 | class SimulationCache(SceneCache): 12 | def reset(self) -> None: 13 | raise NotImplementedError() 14 | 15 | def 
append_state(self, xyzh_dict: Dict[str, StateArray]) -> None: 16 | raise NotImplementedError() 17 | 18 | def add_agents(self, agent_data: List[Tuple]) -> None: 19 | raise NotImplementedError() 20 | 21 | def save_sim_scene(self) -> None: 22 | raise NotImplementedError() 23 | 24 | def calculate_metrics( 25 | self, metrics: List[SimMetric], ts_range: Optional[Tuple[int, int]] = None 26 | ) -> Dict[str, Dict[str, float]]: 27 | raise NotImplementedError() 28 | 29 | def calculate_stats( 30 | self, stats: List[SimStatistic], ts_range: Optional[Tuple[int, int]] = None 31 | ) -> Dict[str, Dict[str, Tuple[np.ndarray, np.ndarray]]]: 32 | raise NotImplementedError() 33 | -------------------------------------------------------------------------------- /src/trajdata/simulation/sim_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | class SimMetric: 8 | def __init__(self, name: str) -> None: 9 | self.name = name 10 | 11 | def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float]: 12 | raise NotImplementedError() 13 | 14 | 15 | class ADE(SimMetric): 16 | def __init__(self) -> None: 17 | super().__init__("ade") 18 | 19 | def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float]: 20 | err_df = pd.DataFrame(index=gt_df.index, columns=["error"]) 21 | err_df["error"] = np.linalg.norm(gt_df[["x", "y"]] - sim_df[["x", "y"]], axis=1) 22 | return err_df.groupby("agent_id")["error"].mean().to_dict() 23 | 24 | 25 | class FDE(SimMetric): 26 | def __init__(self) -> None: 27 | super().__init__("fde") 28 | 29 | def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame) -> Dict[str, float]: 30 | err_df = pd.DataFrame(index=gt_df.index, columns=["error"]) 31 | err_df["error"] = np.linalg.norm(gt_df[["x", "y"]] - sim_df[["x", "y"]], axis=1) 32 | return err_df.groupby("agent_id")["error"].last().to_dict() 33 | -------------------------------------------------------------------------------- /src/trajdata/simulation/sim_stats.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch import Tensor 7 | 8 | from trajdata.utils import arr_utils 9 | 10 | 11 | class SimStatistic: 12 | def __init__(self, name: str) -> None: 13 | self.name = name 14 | 15 | def __call__(self, scene_df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 16 | raise NotImplementedError() 17 | 18 | 19 | class VelocityHistogram(SimStatistic): 20 | def __init__(self, bins: List[int]) -> None: 21 | super().__init__("vel_hist") 22 | self.bins = bins 23 | 24 | def __call__(self, scene_df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 25 | velocities: np.ndarray = np.linalg.norm(scene_df[["vx", "vy"]], axis=1) 26 | 27 | return np.histogram(velocities, bins=self.bins) 28 | 29 | 30 | class LongitudinalAccHistogram(SimStatistic): 31 | def __init__(self, bins: List[int]) -> None: 32 | super().__init__("lon_acc_hist") 33 | self.bins = bins 34 | 35 | def __call__(self, scene_df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 36 | accels: np.ndarray = np.linalg.norm(scene_df[["ax", "ay"]], axis=1) 37 | lon_accels: np.ndarray = accels * np.cos(scene_df["heading"]) 38 | 39 | return np.histogram(lon_accels, bins=self.bins) 40 | 41 | 42 | class LateralAccHistogram(SimStatistic): 43 | def __init__(self, bins: List[int]) -> None: 44 | 
super().__init__("lat_acc_hist") 45 | self.bins = bins 46 | 47 | def __call__(self, scene_df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 48 | accels: np.ndarray = np.linalg.norm(scene_df[["ax", "ay"]], axis=1) 49 | lat_accels: np.ndarray = accels * np.sin(scene_df["heading"]) 50 | 51 | return np.histogram(lat_accels, bins=self.bins) 52 | 53 | 54 | class JerkHistogram(SimStatistic): 55 | def __init__(self, bins: List[int], dt: float) -> None: 56 | super().__init__("jerk_hist") 57 | self.bins = bins 58 | self.dt = dt 59 | 60 | def __call__(self, scene_df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: 61 | accels: np.ndarray = np.linalg.norm(scene_df[["ax", "ay"]], axis=1) 62 | jerk: np.ndarray = ( 63 | arr_utils.agent_aware_diff(accels, scene_df.index.get_level_values(0)) 64 | / self.dt 65 | ) 66 | 67 | return np.histogram(jerk, bins=self.bins) 68 | 69 | 70 | def calc_stats( 71 | positions: Tensor, heading: Tensor, dt: float, bins: Dict[str, Tensor] 72 | ) -> Dict[str, Tensor]: 73 | """Calculate scene statistics for a simulated scene. 74 | 75 | Args: 76 | positions (Tensor): N x T x 2 tensor of agent positions (in world coordinates). 77 | heading (Tensor): N x T x 1 tensor of agent headings (in world coordinates). 78 | dt (float): The data's delta timestep. 79 | bins (Dict[str, Tensor]): A mapping from statistic name to a Tensor of bin edges. 80 | 81 | Returns: 82 | Dict[str, Tensor]: A mapping of value names to histograms. 83 | """ 84 | 85 | velocity: Tensor = ( 86 | torch.diff( 87 | positions, 88 | dim=1, 89 | prepend=positions[:, [0]] - (positions[:, [1]] - positions[:, [0]]), 90 | ) 91 | / dt 92 | ) 93 | velocity_norm: Tensor = torch.linalg.vector_norm(velocity, dim=-1) 94 | 95 | accel: Tensor = ( 96 | torch.diff( 97 | velocity, 98 | dim=1, 99 | prepend=velocity[:, [0]] - (velocity[:, [1]] - velocity[:, [0]]), 100 | ) 101 | / dt 102 | ) 103 | accel_norm: Tensor = torch.linalg.vector_norm(accel, dim=-1) 104 | 105 | lon_acc: Tensor = accel_norm * torch.cos(heading.squeeze(-1)) 106 | lat_acc: Tensor = accel_norm * torch.sin(heading.squeeze(-1)) 107 | 108 | jerk: Tensor = ( 109 | torch.diff( 110 | accel_norm, 111 | dim=1, 112 | prepend=accel_norm[:, [0]] - (accel_norm[:, [1]] - accel_norm[:, [0]]), 113 | ) 114 | / dt 115 | ) 116 | 117 | return { 118 | "velocity": torch.histogram(velocity_norm, bins["velocity"]), 119 | "lon_accel": torch.histogram(lon_acc, bins["lon_accel"]), 120 | "lat_accel": torch.histogram(lat_acc, bins["lat_accel"]), 121 | "jerk": torch.histogram(jerk, bins["jerk"]), 122 | } 123 | -------------------------------------------------------------------------------- /src/trajdata/simulation/sim_vis.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def plot_sim_stats( 9 | stats: Dict[str, Dict[str, Tuple[np.ndarray, np.ndarray]]], 10 | show: bool = True, 11 | close: bool = True, 12 | ): 13 | fig, axes = plt.subplots(nrows=len(stats), ncols=2, figsize=(4, 8)) 14 | 15 | axes[0, 0].set_title("Ground Truth") 16 | axes[0, 1].set_title("Simulated") 17 | 18 | for row, scene_stats in enumerate(stats.values()): 19 | histogram, bins = scene_stats["gt"] 20 | axes[row, 0].hist(histogram, bins, linewidth=0.5, edgecolor="white") 21 | 22 | histogram, bins = scene_stats["sim"] 23 | axes[row, 1].hist(histogram, bins, linewidth=0.5, edgecolor="white") 24 | 25 | plt.tight_layout() 26 | 27 | if show: 28 | plt.show() 29 | 30 
| if close: 31 | plt.close(fig) 32 | -------------------------------------------------------------------------------- /src/trajdata/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/trajdata/51f0efa9d572fa7da480ef8caf089d0d6987de9f/src/trajdata/utils/__init__.py -------------------------------------------------------------------------------- /src/trajdata/utils/agent_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | from trajdata.caching import EnvCache, SceneCache 4 | from trajdata.data_structures import Scene, SceneMetadata 5 | from trajdata.dataset_specific import RawDataset 6 | from trajdata.utils import scene_utils 7 | 8 | 9 | def get_agent_data( 10 | scene_info: SceneMetadata, 11 | raw_dataset: RawDataset, 12 | env_cache: EnvCache, 13 | rebuild_cache: bool, 14 | cache_class: Type[SceneCache], 15 | desired_dt: Optional[float] = None, 16 | ) -> Scene: 17 | if not rebuild_cache and env_cache.scene_is_cached( 18 | scene_info.env_name, scene_info.name, scene_info.dt 19 | ): 20 | scene: Scene = env_cache.load_scene( 21 | scene_info.env_name, scene_info.name, scene_info.dt 22 | ) 23 | 24 | # If the original data is already cached... 25 | if scene_utils.enforce_desired_dt(scene_info, desired_dt, dry_run=True): 26 | # If the original data is already cached, 27 | # but this scene's dt doesn't match what we desire: 28 | # First, interpolate and save the data. 29 | # Then, return the interpolated scene. 30 | 31 | # Interpolating the scene metadata and caching it. 32 | scene_utils.enforce_desired_dt(scene, desired_dt) 33 | env_cache.save_scene(scene) 34 | 35 | # Interpolating the agent data and caching it. 36 | # The core point of doing this here rather than in Line 45 and below 37 | # is that we do not need to access the raw dataset object, we can 38 | # leverage the already cached data. 39 | scene_cache: SceneCache = cache_class(env_cache.path, scene) 40 | scene_cache.write_cache_to_disk() 41 | 42 | # Once this scene's dt matches what we desire: Return it. 43 | return scene 44 | 45 | # Obtaining and caching the original scene data. 46 | scene: Scene = raw_dataset.get_scene(scene_info) 47 | agent_list, agent_presence = raw_dataset.get_agent_info( 48 | scene, env_cache.path, cache_class 49 | ) 50 | if agent_list is None and agent_presence is None: 51 | raise ValueError(f"Scene {scene_info.name} contains no agents!") 52 | 53 | scene.update_agent_info(agent_list, agent_presence) 54 | env_cache.save_scene(scene) 55 | 56 | if scene_utils.enforce_desired_dt(scene, desired_dt, dry_run=True): 57 | # In case the user specified a desired_dt that's different from the scene's 58 | # native dt, we will perform the interpolation here and cache the result for 59 | # later reuse. 60 | 61 | # Interpolating the scene metadata and caching it. 62 | scene_utils.enforce_desired_dt(scene, desired_dt) 63 | env_cache.save_scene(scene) 64 | 65 | # Interpolating the agent data and caching it. 
66 | scene_cache: SceneCache = cache_class(env_cache.path, scene) 67 | scene_cache.write_cache_to_disk() 68 | 69 | return scene 70 | -------------------------------------------------------------------------------- /src/trajdata/utils/batch_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any, Dict, Iterator, List, Optional, Tuple 3 | 4 | import numpy as np 5 | from torch.utils.data import Sampler 6 | 7 | from trajdata import UnifiedDataset 8 | from trajdata.data_structures import ( 9 | AgentBatch, 10 | AgentBatchElement, 11 | AgentDataIndex, 12 | AgentType, 13 | SceneBatchElement, 14 | SceneTimeAgent, 15 | ) 16 | from trajdata.data_structures.collation import agent_collate_fn 17 | 18 | 19 | class SceneTimeBatcher(Sampler): 20 | _agent_data_index: AgentDataIndex 21 | _agent_idx: int 22 | 23 | def __init__( 24 | self, agent_centric_dataset: UnifiedDataset, agent_idx_to_follow: int = 0 25 | ) -> None: 26 | """ 27 | Returns a sampler (to be used in a torch.utils.data.DataLoader) 28 | which works with an agent-centric UnifiedDataset, yielding 29 | batches consisting of whole scenes (AgentBatchElements for all agents 30 | in a particular scene at a particular time) 31 | 32 | Args: 33 | agent_centric_dataset (UnifiedDataset) 34 | agent_idx_to_follow (int): index of agent to return batches for. Defaults to 0, 35 | meaning we include all scene frames where the ego agent appears, which 36 | usually covers the entire dataset. 37 | """ 38 | super().__init__(agent_centric_dataset) 39 | self._agent_data_index = agent_centric_dataset._data_index 40 | self._agent_idx = agent_idx_to_follow 41 | self._cumulative_lengths = np.concatenate( 42 | [ 43 | [0], 44 | np.cumsum( 45 | [ 46 | cumulative_scene_length[self._agent_idx + 1] 47 | - cumulative_scene_length[self._agent_idx] 48 | for cumulative_scene_length in self._agent_data_index._cumulative_scene_lengths 49 | ] 50 | ), 51 | ] 52 | ) 53 | 54 | def __len__(self): 55 | return self._cumulative_lengths[-1] 56 | 57 | def __iter__(self) -> Iterator[int]: 58 | for idx in range(len(self)): 59 | # TODO(apoorvas) May not need to do this search, since we only support an iterable style access? 
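# Bisect the cumulative per-scene sample counts to find which scene this
# flat index falls into: with the leading 0 in self._cumulative_lengths,
# searchsorted(..., side="right") - 1 maps any idx in
# [cumulative_lengths[i], cumulative_lengths[i + 1]) to scene i.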
60 | scene_idx: int = (
61 | np.searchsorted(self._cumulative_lengths, idx, side="right").item() - 1
62 | )
63 |
64 | # offset into dataset index to reach current scene
65 | scene_offset = self._agent_data_index._cumulative_lengths[scene_idx].item()
66 |
67 | # how far along we are in the current scene
68 | scene_elem_index = idx - self._cumulative_lengths[scene_idx].item()
69 |
70 | # convert to scene-timestep for the tracked agent
71 | scene_ts = (
72 | scene_elem_index
73 | + self._agent_data_index._agent_times[scene_idx][self._agent_idx, 0]
74 | )
75 |
76 | # build a set of indices into the agent-centric dataset for all agents that exist at this scene and timestep
77 | indices = []
78 | for agent_idx, agent_times in enumerate(
79 | self._agent_data_index._agent_times[scene_idx]
80 | ):
81 | if scene_ts > agent_times[1]:
82 | # we are past the last timestep for this agent (times are inclusive)
83 | continue
84 | agent_offset = scene_ts - agent_times[0]
85 | if agent_offset < 0:
86 | # this agent hasn't entered the scene yet
87 | continue
88 |
89 | # compute index into original dataset, first into scene, then into this agent's part in scene, and then the offset
90 | index_to_add = (
91 | scene_offset
92 | + self._agent_data_index._cumulative_scene_lengths[scene_idx][
93 | agent_idx
94 | ]
95 | + agent_offset
96 | )
97 | indices.append(index_to_add)
98 |
99 | yield indices
100 |
101 |
102 | def convert_to_agent_batch(
103 | scene_batch_element: SceneBatchElement,
104 | only_types: Optional[List[AgentType]] = None,
105 | no_types: Optional[List[AgentType]] = None,
106 | agent_interaction_distances: Dict[Tuple[AgentType, AgentType], float] = defaultdict(
107 | lambda: np.inf
108 | ),
109 | incl_map: bool = False,
110 | map_params: Optional[Dict[str, Any]] = None,
111 | max_neighbor_num: Optional[int] = None,
112 | state_format: Optional[str] = None,
113 | standardize_data: bool = True,
114 | standardize_derivatives: bool = False,
115 | pad_format: str = "outside",
116 | ) -> AgentBatch:
117 | """
118 | Converts a SceneBatchElement into an AgentBatch consisting of
119 | AgentBatchElements for all agents present in the given scene at the given time step.
120 |
121 | Args:
122 | scene_batch_element (SceneBatchElement): element to process
123 | only_types (Optional[List[AgentType]], optional): AgentTypes to consider. Defaults to None.
124 | no_types (Optional[List[AgentType]], optional): AgentTypes to ignore. Defaults to None.
125 | agent_interaction_distances (Dict[Tuple[AgentType, AgentType], float], optional): Distance threshold for interaction. Defaults to defaultdict(lambda: np.inf).
126 | incl_map (bool, optional): Whether to include map info. Defaults to False.
127 | map_params (Optional[Dict[str, Any]], optional): Map params. Defaults to None.
128 | max_neighbor_num (Optional[int], optional): Max number of neighbors to allow. Defaults to None.
129 | state_format (Optional[str], optional): Desired state format. Defaults to None, in which case the scene batch element's own state format is used.
130 | standardize_data (bool): Whether to return data relative to current agent state. Defaults to True.
131 | standardize_derivatives: Whether to transform relative velocities and accelerations as well. Defaults to False.
132 | pad_format (str, optional): Pad format when collating agent trajectories. Defaults to "outside".
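Example (hypothetical usage; `scene_elem` stands in for a SceneBatchElement
obtained from a scene-centric UnifiedDataset):

    agent_batch = convert_to_agent_batch(
        scene_elem, only_types=[AgentType.VEHICLE], incl_map=False
    )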
133 |
134 | Returns:
135 | AgentBatch: batch of AgentBatchElements corresponding to all agents in the SceneBatchElement
136 | """
137 | data_idx = scene_batch_element.data_index
138 | cache = scene_batch_element.cache
139 | scene = cache.scene
140 | dt = scene_batch_element.dt
141 | ts = scene_batch_element.scene_ts
142 | if state_format is None:
143 | # Fall back to the element's own state format if none was specified.
144 | state_format = scene_batch_element.centered_agent_state_np._format
145 |
146 | batch_elems: List[AgentBatchElement] = []
147 | for j, agent_name in enumerate(scene_batch_element.agent_names):
148 | history_sec = dt * (scene_batch_element.agent_histories[j].shape[0] - 1)
149 | future_sec = dt * (scene_batch_element.agent_futures[j].shape[0])
150 | cache.reset_obs_frame()
151 | scene_time_agent: SceneTimeAgent = SceneTimeAgent.from_cache(
152 | scene,
153 | ts,
154 | agent_name,
155 | cache,
156 | only_types=only_types,
157 | no_types=no_types,
158 | )
159 |
160 | batch_elems.append(
161 | AgentBatchElement(
162 | cache=cache,
163 | data_index=data_idx,
164 | scene_time_agent=scene_time_agent,
165 | history_sec=(history_sec, history_sec),
166 | future_sec=(future_sec, future_sec),
167 | agent_interaction_distances=agent_interaction_distances,
168 | incl_raster_map=incl_map,
169 | raster_map_params=map_params,
170 | state_format=state_format,
171 | standardize_data=standardize_data,
172 | standardize_derivatives=standardize_derivatives,
173 | max_neighbor_num=max_neighbor_num,
174 | )
175 | )
176 |
177 | return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format)
-------------------------------------------------------------------------------- /src/trajdata/utils/df_utils.py: --------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | def downsample_multi_index_df(
8 | df: pd.DataFrame, downsample_dt_factor: int
9 | ) -> pd.DataFrame:
10 | """
11 | Downsamples MultiIndex dataframe, assuming level=1 of the index
12 | corresponds to the scene timestep.
13 | """
14 | subsampled_df = df.groupby(level=0).apply(
15 | lambda g: g.reset_index(level=0, drop=True)
16 | .iloc[::downsample_dt_factor]
17 | .rename(index=lambda ts: ts // downsample_dt_factor)
18 | )
19 |
20 | return subsampled_df
21 |
22 |
23 | def upsample_ts_index_df(
24 | df: pd.DataFrame,
25 | upsample_dt_factor: int,
26 | method: str,
27 | preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
28 | postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
29 | ):
30 | """
31 | Upsamples a time indexed dataframe, applying the specified method.
32 | Calls preprocess and postprocess before and after upsampling, respectively.
33 |
34 | If original data is at frames 2,3,4,5, and upsample_dt_factor is 3, then
35 | the original data will live at frames 6,9,12,15, and new data will
36 | be generated according to method for frames 7,8, 10,11, and 13,14 (frames after the last original frame are not generated).
37 | """
38 | if preprocess:
39 | df = preprocess(df)
40 |
41 | # first, we multiply ts index by upsample factor
42 | df = df.rename(index=lambda ts: ts * upsample_dt_factor)
43 |
44 | # get the index by adding the number of frames needed per original index
45 | new_index = pd.Index(
46 | (df.index.to_numpy()[:, None] + np.arange(upsample_dt_factor)).flatten()[
47 | : -(upsample_dt_factor - 1)
48 | ],
49 | name=df.index.name,
50 | )
51 |
52 | # reindex and interpolate according to method
53 | df = df.reindex(new_index).interpolate(method=method, limit_area="inside")
54 |
55 | if postprocess:
56 | df = postprocess(df)
57 |
58 | return df
59 |
60 |
61 | def upsample_multi_index_df(
62 | df: pd.DataFrame,
63 | upsample_dt_factor: int,
64 | method: str,
65 | preprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
66 | postprocess: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
67 | ) -> pd.DataFrame:
68 | return df.groupby(level=[0]).apply(
69 | lambda g: upsample_ts_index_df(
70 | g.reset_index(level=[0], drop=True),
71 | upsample_dt_factor,
72 | method,
73 | preprocess,
74 | postprocess,
75 | )
76 | )
77 |
78 |
79 | def interpolate_multi_index_df(
80 | df: pd.DataFrame, data_dt: float, desired_dt: float, method: str = "linear"
81 | ) -> pd.DataFrame:
82 | """
83 | Interpolates the given dataframe, indexed with (elem_id, scene_ts)
84 | where scene_ts corresponds to timesteps spaced data_dt apart, to a new
85 | desired_dt.
86 | """
87 | upsample_dt_ratio: float = data_dt / desired_dt
88 | downsample_dt_ratio: float = desired_dt / data_dt
89 | if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer():
90 | raise ValueError(
91 | f"Data's dt of {data_dt}s "
92 | f"is neither an integer multiple nor an integer divisor of the desired dt {desired_dt}s."
93 | ) 94 | 95 | if upsample_dt_ratio >= 1: 96 | return upsample_multi_index_df(df, int(upsample_dt_ratio), method) 97 | elif downsample_dt_ratio >= 1: 98 | return downsample_multi_index_df(df, int(downsample_dt_ratio)) 99 | -------------------------------------------------------------------------------- /src/trajdata/utils/env_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from trajdata.dataset_specific import RawDataset 4 | 5 | 6 | def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: 7 | if "nusc" in dataset_name: 8 | from trajdata.dataset_specific.nusc import NuscDataset 9 | 10 | return NuscDataset(dataset_name, data_dir, parallelizable=False, has_maps=True) 11 | 12 | if "vod" in dataset_name: 13 | from trajdata.dataset_specific.vod import VODDataset 14 | 15 | return VODDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) 16 | 17 | if "lyft" in dataset_name: 18 | from trajdata.dataset_specific.lyft import LyftDataset 19 | 20 | return LyftDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) 21 | 22 | if "eupeds" in dataset_name: 23 | from trajdata.dataset_specific.eth_ucy_peds import EUPedsDataset 24 | 25 | return EUPedsDataset( 26 | dataset_name, data_dir, parallelizable=True, has_maps=False 27 | ) 28 | 29 | if "sdd" in dataset_name: 30 | from trajdata.dataset_specific.sdd_peds import SDDPedsDataset 31 | 32 | return SDDPedsDataset( 33 | dataset_name, data_dir, parallelizable=True, has_maps=False 34 | ) 35 | 36 | if "nuplan" in dataset_name: 37 | from trajdata.dataset_specific.nuplan import NuplanDataset 38 | 39 | return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) 40 | 41 | if "waymo" in dataset_name: 42 | from trajdata.dataset_specific.waymo import WaymoDataset 43 | 44 | return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) 45 | 46 | if "interaction" in dataset_name: 47 | from trajdata.dataset_specific.interaction import InteractionDataset 48 | 49 | return InteractionDataset( 50 | dataset_name, data_dir, parallelizable=True, has_maps=True 51 | ) 52 | 53 | if "av2" in dataset_name: 54 | from trajdata.dataset_specific.argoverse2 import Av2Dataset 55 | 56 | return Av2Dataset(dataset_name, data_dir, parallelizable=True, has_maps=True) 57 | 58 | raise ValueError(f"Dataset with name '{dataset_name}' is not supported") 59 | 60 | 61 | def get_raw_datasets(data_dirs: Dict[str, str]) -> List[RawDataset]: 62 | raw_datasets: List[RawDataset] = list() 63 | 64 | for dataset_name, data_dir in data_dirs.items(): 65 | raw_datasets.append(get_raw_dataset(dataset_name, data_dir)) 66 | 67 | return raw_datasets 68 | -------------------------------------------------------------------------------- /src/trajdata/utils/parallel_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from multiprocessing import Manager, Pool 3 | from typing import Callable, Iterable, List, Optional 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | def parallel_apply( 9 | element_fn: Callable, 10 | element_list: Iterable, 11 | num_workers: int, 12 | desc: Optional[str] = None, 13 | disable: bool = False, 14 | ) -> List: 15 | return list(parallel_iapply(element_fn, element_list, num_workers, desc, disable)) 16 | 17 | 18 | def parallel_iapply( 19 | element_fn: Callable, 20 | element_list: Iterable, 21 | num_workers: int, 22 | desc: Optional[str] = None, 23 | disable: bool = False, 24 | ) -> 
Iterable: 25 | with Pool(processes=num_workers) as pool: 26 | for fn_output in tqdm( 27 | pool.imap(element_fn, element_list), 28 | desc=desc, 29 | total=len(element_list), 30 | disable=disable, 31 | ): 32 | yield fn_output 33 | 34 | 35 | def pickle_objects(objs: List) -> List[bytes]: 36 | pickled_objs: List[bytes] = list() 37 | for obj in objs: 38 | pickled_objs.append(pickle.dumps(obj)) 39 | 40 | return pickled_objs 41 | 42 | 43 | class AsyncExecutor: 44 | def __init__( 45 | self, num_workers: int, total_jobs: int, desc: str, position: int, disable: bool 46 | ): 47 | self.pool = Pool(processes=num_workers) 48 | self.manager = Manager() 49 | self.results_queue = self.manager.Queue() 50 | 51 | self.pbar = tqdm( 52 | desc=desc, position=position, total=total_jobs, disable=disable 53 | ) 54 | 55 | def error(self, err): 56 | raise err 57 | 58 | def prompt(self, results: List[bytes]): 59 | self.pbar.update(len(results)) 60 | for result in results: 61 | self.results_queue.put(result) 62 | 63 | def schedule(self, function, args): 64 | return self.pool.apply_async( 65 | function, args, callback=self.prompt, error_callback=self.error 66 | ) 67 | 68 | def wait(self): 69 | self.pool.close() 70 | self.pool.join() 71 | self.pbar.close() 72 | -------------------------------------------------------------------------------- /src/trajdata/utils/py_utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | from typing import Dict, List, Set, Tuple, Union 4 | 5 | 6 | def hash_dict(o: Union[Dict, List, Tuple, Set]) -> str: 7 | """ 8 | Makes a hash from a dictionary, list, tuple or set to any level, that contains 9 | only other hashable types (including any lists, tuples, sets, and 10 | dictionaries). 11 | """ 12 | string_rep: str = json.dumps(o) 13 | return hashlib.sha1(str.encode(string_rep)).hexdigest() 14 | -------------------------------------------------------------------------------- /src/trajdata/utils/scene_utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import List, Optional, Union 3 | 4 | from trajdata.data_structures import AgentMetadata, Scene, SceneMetadata 5 | 6 | 7 | def enforce_desired_dt( 8 | scene_info: Union[Scene, SceneMetadata], 9 | desired_dt: Optional[float], 10 | dry_run: bool = False, 11 | ) -> bool: 12 | """Enforces that a scene's data is at the desired frequency (specified by desired_dt 13 | if it's not None) through interpolation. 14 | 15 | Args: 16 | scene_info (Scene | SceneMetadata): The scene to interpolate to the desired data frequency. 17 | desired_dt (Optional[float]): The desired data timestep difference (in seconds). 18 | dry_run (bool): If True, only check if the scene meets the desired data frequency (without modifying scene_info). Defaults to False. 19 | 20 | Returns: 21 | bool: True if the scene was modified (or would be modified if dry_run=True), False otherwise. 
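Example (hypothetical usage):

    # Check, without modifying the scene, whether a 2 Hz scene
    # would need resampling to reach 10 Hz.
    needs_resampling = enforce_desired_dt(scene, desired_dt=0.1, dry_run=True)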
22 | """ 23 | if desired_dt is not None and scene_info.dt != desired_dt: 24 | if not dry_run and scene_info.dt > desired_dt: 25 | interpolate_scene_dt(scene_info, desired_dt) 26 | elif not dry_run and scene_info.dt < desired_dt: 27 | subsample_scene_dt(scene_info, desired_dt) 28 | return True 29 | 30 | return False 31 | 32 | 33 | def interpolate_scene_dt(scene: Scene, desired_dt: float) -> None: 34 | dt_ratio: float = scene.dt / desired_dt 35 | if not dt_ratio.is_integer(): 36 | raise ValueError( 37 | f"Cannot interpolate scene: {scene.dt} is not integer divisible by {desired_dt} for {str(scene)}" 38 | ) 39 | 40 | dt_factor: int = int(dt_ratio) 41 | 42 | # E.g., the scene is currently at dt = 0.5s (2 Hz), 43 | # but we want desired_dt = 0.1s (10 Hz). 44 | scene.length_timesteps = (scene.length_timesteps - 1) * dt_factor + 1 45 | agent_presence: List[List[AgentMetadata]] = [ 46 | [] for _ in range(scene.length_timesteps) 47 | ] 48 | for agent in scene.agents: 49 | agent.first_timestep *= dt_factor 50 | agent.last_timestep *= dt_factor 51 | 52 | for scene_ts in range(agent.first_timestep, agent.last_timestep + 1): 53 | agent_presence[scene_ts].append(agent) 54 | 55 | scene.update_agent_info(scene.agents, agent_presence) 56 | scene.dt = desired_dt 57 | # Note we do not touch scene_info.env_metadata.dt, this will serve as our 58 | # source of the "original" data dt information. 59 | 60 | 61 | def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: 62 | dt_ratio: float = desired_dt / scene.dt 63 | if not dt_ratio.is_integer(): 64 | raise ValueError( 65 | f"Cannot subsample scene: {desired_dt} is not integer divisible by {scene.dt} for {str(scene)}" 66 | ) 67 | 68 | dt_factor: int = int(dt_ratio) 69 | 70 | # E.g., the scene is currently at dt = 0.1s (10 Hz), 71 | # but we want desired_dt = 0.5s (2 Hz). 72 | scene.length_timesteps = (scene.length_timesteps - 1) // dt_factor + 1 73 | agent_presence: List[List[AgentMetadata]] = [ 74 | [] for _ in range(scene.length_timesteps) 75 | ] 76 | for agent in scene.agents: 77 | # Need to be careful with the first timestep, since agents can have 78 | # first timesteps that are not exactly divisible by the dt_factor. 79 | agent.first_timestep = ceil(agent.first_timestep / dt_factor) 80 | agent.last_timestep //= dt_factor 81 | 82 | for scene_ts in range(agent.first_timestep, agent.last_timestep + 1): 83 | agent_presence[scene_ts].append(agent) 84 | 85 | scene.update_agent_info(scene.agents, agent_presence) 86 | scene.dt = desired_dt 87 | # Note we do not touch scene_info.env_metadata.dt, this will serve as our 88 | # source of the "original" data dt information. 
89 |
-------------------------------------------------------------------------------- /src/trajdata/utils/state_utils.py: --------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 |
5 | from trajdata.data_structures.state import StateArray, StateTensor
6 | from trajdata.utils.arr_utils import (
7 | angle_wrap,
8 | rotation_matrix,
9 | transform_angles_np,
10 | transform_coords_2d_np,
11 | transform_coords_np,
12 | )
13 |
14 |
15 | def transform_state_np_2d(state: StateArray, tf_mat_2d: np.ndarray):
16 | """
17 | Transforms a state into another coordinate frame given a 2D transform matrix.
18 | Positions are translated and rotated; derivative and heading quantities are rotated only.
19 | """
20 | new_state = state.copy()
21 | attributes = state._format_dict.keys()
22 | if "x" in attributes and "y" in attributes:
23 | # transform xy position with translation and rotation
24 | new_state.position = transform_coords_np(state.position, tf_mat_2d)
25 | if "xd" in attributes and "yd" in attributes:
26 | # transform velocities
27 | new_state.velocity = transform_coords_np(
28 | state.velocity, tf_mat_2d, translate=False
29 | )
30 | if "xdd" in attributes and "ydd" in attributes:
31 | # transform acceleration
32 | new_state.acceleration = transform_coords_np(
33 | state.acceleration, tf_mat_2d, translate=False
34 | )
35 | if "c" in attributes and "s" in attributes:
36 | new_state.heading_vector = transform_coords_np(
37 | state.heading_vector, tf_mat_2d, translate=False
38 | )
39 | if "h" in attributes:
40 | new_state.heading = transform_angles_np(state.heading, tf_mat_2d)
41 |
42 | return new_state
43 |
44 |
45 | def convert_to_frame_state(
46 | state: StateArray,
47 | stationary: bool = True,
48 | grounded: bool = True,
49 | ) -> StateArray:
50 | """
51 | Returns a StateArray for a reference frame centered on the passed-in state.
52 | """
53 | frame: StateArray = state.copy()
54 | attributes = state._format_dict.keys()
55 | if stationary:
56 | if "xd" in attributes and "yd" in attributes:
57 | frame.velocity = 0
58 | if "xdd" in attributes and "ydd" in attributes:
59 | frame.acceleration = 0
60 | if grounded:
61 | if "z" in attributes:
62 | frame.set_attr("z", 0)
63 |
64 | return frame
65 |
66 |
67 | def transform_to_frame(
68 | state: StateArray, frame_state: StateArray, rot_mat: Optional[np.ndarray] = None
69 | ) -> StateArray:
70 | """
71 | Returns state with coordinates relative to a frame with state frame_state.
72 | Does not modify state in place.
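For example (hypothetical usage), neighbor states can be expressed in an
ego-centric frame via `transform_to_frame(neighbor_states, ego_state)`,
where both arguments are world-frame StateArrays.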
73 | 
74 |     Args:
75 |         state (StateArray): the state to transform, given in world coordinates
76 |         frame_state (StateArray): the frame's state, given in world coordinates
77 |         rot_mat (Optional[np.ndarray]): rotation matrix A such that c = A @ b maps world coordinates b into the new frame;
78 |             if not given, it is computed from frame_state
79 |     """
80 |     new_state = state.copy()
81 |     attributes = state._format_dict.keys()
82 | 
83 |     frame_heading = frame_state.heading[..., 0]
84 |     if rot_mat is None:
85 |         rot_mat = rotation_matrix(-frame_heading)
86 | 
87 |     if "x" in attributes and "y" in attributes:
88 |         # transform xy position with translation and rotation
89 |         new_state.position = transform_coords_2d_np(
90 |             state.position, offset=-frame_state.position, rot_mat=rot_mat
91 |         )
92 |     if "xd" in attributes and "yd" in attributes:
93 |         # transform velocities
94 |         new_state.velocity = transform_coords_2d_np(
95 |             state.velocity, offset=-frame_state.velocity, rot_mat=rot_mat
96 |         )
97 |     if "xdd" in attributes and "ydd" in attributes:
98 |         # transform acceleration
99 |         new_state.acceleration = transform_coords_2d_np(
100 |             state.acceleration, offset=-frame_state.acceleration, rot_mat=rot_mat
101 |         )
102 |     if "c" in attributes and "s" in attributes:
103 |         new_state.heading_vector = transform_coords_2d_np(
104 |             state.heading_vector, rot_mat=rot_mat
105 |         )
106 |     if "h" in attributes:
107 |         new_state.heading = angle_wrap(state.heading - frame_heading)
108 | 
109 |     return new_state
110 | 
111 | 
112 | def transform_from_frame(
113 |     state: StateArray, frame_state: StateArray, rot_mat: Optional[np.ndarray] = None
114 | ) -> StateArray:
115 |     """
116 |     Returns state with coordinates in the world frame.
117 |     Does not modify state in place.
118 | 
119 |     Args:
120 |         state (StateArray): the state to transform, given in the frame's coordinates
121 |         frame_state (StateArray): the frame's state, given in world coordinates
122 |         rot_mat (Optional[np.ndarray]): rotation matrix A such that c = A @ b maps frame coordinates b into the world frame;
123 |             if not given, it is computed from frame_state
124 |     """
125 |     new_state = state.copy()
126 |     attributes = state._format_dict.keys()
127 | 
128 |     frame_heading = frame_state.heading[..., 0]
129 |     if rot_mat is None:
130 |         rot_mat = rotation_matrix(frame_heading)
131 | 
132 |     if "x" in attributes and "y" in attributes:
133 |         # transform xy position with translation and rotation
134 |         new_state.position = (
135 |             transform_coords_2d_np(state.position, rot_mat=rot_mat)
136 |             + frame_state.position
137 |         )
138 |     if "xd" in attributes and "yd" in attributes:
139 |         # transform velocities (use rot_mat so a user-provided matrix is respected)
140 |         new_state.velocity = (
141 |             transform_coords_2d_np(
142 |                 state.velocity,
143 |                 rot_mat=rot_mat,
144 |             )
145 |             + frame_state.velocity
146 |         )
147 |     if "xdd" in attributes and "ydd" in attributes:
148 |         # transform acceleration
149 |         new_state.acceleration = (
150 |             transform_coords_2d_np(
151 |                 state.acceleration,
152 |                 rot_mat=rot_mat,
153 |             )
154 |             + frame_state.acceleration
155 |         )
156 |     if "c" in attributes and "s" in attributes:
157 |         new_state.heading_vector = transform_coords_2d_np(
158 |             state.heading_vector,
159 |             rot_mat=rot_mat,
160 |         )
161 |     if "h" in attributes:
162 |         new_state.heading = angle_wrap(state.heading + frame_heading)
163 | 
164 |     return new_state
165 | 
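As a worked example of the frame math above (plain NumPy with hypothetical values, assuming the usual convention that rot_mat left-multiplies column vectors; this re-derives the position and heading rules rather than calling StateArray itself):

    import numpy as np

    frame_xy, frame_h = np.array([10.0, 5.0]), np.pi / 2  # frame facing "north"
    world_xy, world_h = np.array([10.0, 7.0]), np.pi      # a point 2m north of it

    c, s = np.cos(-frame_h), np.sin(-frame_h)
    rot = np.array([[c, -s], [s, c]])  # rotation by -frame_h, as in transform_to_frame

    local_xy = rot @ (world_xy - frame_xy)  # -> [2, 0]: 2m "ahead" of the frame
    local_h = world_h - frame_h             # -> pi/2 after angle wrapping

    assert np.allclose(local_xy, [2.0, 0.0])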
--------------------------------------------------------------------------------
/src/trajdata/utils/string_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | from trajdata.data_structures.scene_tag import SceneTag
4 | 
5 | 
6 | def pretty_string_tags(tag_lst: List[SceneTag]) -> List[str]:
7 |     return [str(tag) for tag in tag_lst]
8 | 
--------------------------------------------------------------------------------
/src/trajdata/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | from .interactive_figure import InteractiveFigure
2 | from .interactive_vis import plot_agent_batch_interactive
3 | from .vis import plot_agent_batch, plot_scene_batch
--------------------------------------------------------------------------------
/src/trajdata/visualization/interactive_figure.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | 
3 | import bokeh.plotting as plt
4 | import numpy as np
5 | import torch
6 | from bokeh.models import ColumnDataSource, Range1d
7 | from bokeh.models.renderers import GlyphRenderer
8 | from bokeh.plotting import figure
9 | from torch import Tensor
10 | 
11 | from trajdata.data_structures.agent import AgentType
12 | from trajdata.data_structures.state import StateTensor
13 | from trajdata.maps import VectorMap
14 | from trajdata.utils import vis_utils
15 | 
16 | 
17 | class InteractiveFigure:
18 |     def __init__(self, **kwargs) -> None:
19 |         self.aspect_ratio: float = kwargs.get("aspect_ratio", 16 / 9)
20 |         self.width: int = kwargs.pop("width", 1280)  # pop: figure() receives width/height explicitly below
21 |         self.height: int = kwargs.pop("height", int(self.width / self.aspect_ratio))
22 | 
23 |         # We'll be tracking the maxes and mins of data with these.
24 |         self.x_min = np.inf
25 |         self.x_max = -np.inf
26 |         self.y_min = np.inf
27 |         self.y_max = -np.inf
28 | 
29 |         self.raw_figure = figure(width=self.width, height=self.height, **kwargs)
30 |         vis_utils.apply_default_settings(self.raw_figure)
31 | 
32 |     def update_mins_maxs(self, x_min, x_max, y_min, y_max) -> None:
33 |         self.x_min = min(self.x_min, x_min)
34 |         self.x_max = max(self.x_max, x_max)
35 |         self.y_min = min(self.y_min, y_min)
36 |         self.y_max = max(self.y_max, y_max)
37 | 
38 |     def show(self) -> None:
39 |         if np.isfinite((self.x_min, self.x_max, self.y_min, self.y_max)).all():
40 |             (
41 |                 x_range_min,
42 |                 x_range_max,
43 |                 y_range_min,
44 |                 y_range_max,
45 |             ) = vis_utils.calculate_figure_sizes(
46 |                 data_bbox=(self.x_min, self.x_max, self.y_min, self.y_max),
47 |                 aspect_ratio=self.aspect_ratio,
48 |             )
49 | 
50 |             self.raw_figure.x_range = Range1d(x_range_min, x_range_max)
51 |             self.raw_figure.y_range = Range1d(y_range_min, y_range_max)
52 | 
53 |         plt.show(self.raw_figure)
54 | 
55 |     def add_line(self, states: StateTensor, **kwargs) -> GlyphRenderer:
56 |         xy_pos = states.position.cpu().numpy()
57 | 
58 |         x_min, y_min = np.nanmin(xy_pos, axis=0)
59 |         x_max, y_max = np.nanmax(xy_pos, axis=0)
60 |         self.update_mins_maxs(x_min.item(), x_max.item(), y_min.item(), y_max.item())
61 | 
62 |         return self.raw_figure.line(xy_pos[:, 0], xy_pos[:, 1], **kwargs)
63 | 
64 |     def add_lines(self, lines_data: ColumnDataSource, **kwargs) -> GlyphRenderer:
65 |         self.update_mins_maxs(*vis_utils.get_multi_line_bbox(lines_data))
66 |         return self.raw_figure.multi_line(
67 |             source=lines_data,
68 |             # This is to ensure that the columns given in the
69 |             # ColumnDataSource are respected (e.g., "line_color").
70 |             **{x: x for x in lines_data.column_names},
71 |             **kwargs,
72 |         )
73 | 
74 |     def add_map(
75 |         self,
76 |         map_from_world_tf: np.ndarray,
77 |         vec_map: VectorMap,
78 |         bbox: Optional[Tuple[float, float, float, float]] = None,
79 |         **kwargs,
80 |     ) -> Tuple[
81 |         GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer
82 |     ]:
83 |         """Draws the given vector map's elements onto the figure.
84 | 
85 |         Args:
86 |             map_from_world_tf (np.ndarray): transform from world coordinates into the map frame
87 |             vec_map (VectorMap): the vector map whose elements should be drawn
88 |             bbox (Tuple[float, float, float, float]): x_min, x_max, y_min, y_max
89 | 
90 |         Returns:
91 |             Tuple[GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer]: renderers for the drawn map elements
92 |         """
93 |         return vis_utils.draw_map_elems(
94 |             self.raw_figure, vec_map, map_from_world_tf, bbox, **kwargs
95 |         )
96 | 
97 |     def add_agent(
98 |         self,
99 |         agent_type: AgentType,
100 |         agent_state: StateTensor,
101 |         agent_extent: Tensor,
102 |         **kwargs,
103 |     ) -> Tuple[GlyphRenderer, GlyphRenderer]:
104 |         """Draws an agent at the given location, heading, and dimensions.
105 | 
106 |         Args:
107 |             agent_type (AgentType): the agent's type, which affects the drawn shape and hover tooltip
108 |             agent_state (StateTensor): the agent's current state (its position, heading, and velocity are used)
109 |             agent_extent (Tensor): the agent's extents, with length and width as the first two entries
110 |         """
111 |         if torch.any(torch.isnan(agent_extent)):
112 |             raise ValueError("Agent extents cannot be NaN!")
113 | 
114 |         length = agent_extent[0].item()
115 |         width = agent_extent[1].item()
116 | 
117 |         x, y = agent_state.position.cpu().numpy()
118 |         heading = agent_state.heading.cpu().numpy()
119 | 
120 |         agent_rect_coords, dir_patch_coords = vis_utils.compute_agent_rect_coords(
121 |             agent_type, heading, length, width
122 |         )
123 | 
124 |         source = {
125 |             "x": agent_rect_coords[:, 0] + x,
126 |             "y": agent_rect_coords[:, 1] + y,
127 |             "type": [vis_utils.pretty_print_agent_type(agent_type)],
128 |             "speed": [torch.linalg.norm(agent_state.velocity).item()],
129 |         }
130 | 
131 |         r = self.raw_figure.patch(
132 |             x="x",
133 |             y="y",
134 |             source=source,
135 |             **kwargs,
136 |         )
137 |         p = self.raw_figure.patch(
138 |             x=dir_patch_coords[:, 0] + x, y=dir_patch_coords[:, 1] + y, **kwargs
139 |         )
140 | 
141 |         return r, p
142 | 
143 |     def add_agents(
144 |         self,
145 |         agent_rects_data: ColumnDataSource,
146 |         dir_patches_data: ColumnDataSource,
147 |         **kwargs,
148 |     ) -> Tuple[GlyphRenderer, GlyphRenderer]:
149 |         r = self.raw_figure.patches(
150 |             source=agent_rects_data,
151 |             # This is to ensure that the columns given in the
152 |             # ColumnDataSource are respected (e.g., "line_color").
153 |             xs="xs",
154 |             ys="ys",
155 |             fill_alpha="fill_alpha",
156 |             fill_color="fill_color",
157 |             line_color="line_color",
158 |             **kwargs,
159 |         )
160 | 
161 |         p = self.raw_figure.patches(
162 |             source=dir_patches_data,
163 |             # This is to ensure that the columns given in the
164 |             # ColumnDataSource are respected (e.g., "line_color").
165 | **{x: x for x in dir_patches_data.column_names}, 166 | **kwargs, 167 | ) 168 | 169 | return r, p 170 | -------------------------------------------------------------------------------- /src/trajdata/visualization/interactive_vis.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from bokeh.models import ColumnDataSource 5 | 6 | from trajdata.data_structures.agent import AgentType 7 | from trajdata.data_structures.batch import AgentBatch, SceneBatch 8 | from trajdata.data_structures.state import StateArray, StateTensor 9 | from trajdata.maps.map_api import MapAPI 10 | from trajdata.utils import vis_utils 11 | from trajdata.utils.arr_utils import transform_coords_2d_np 12 | from trajdata.visualization.interactive_figure import InteractiveFigure 13 | 14 | 15 | def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: Path): 16 | fig = InteractiveFigure( 17 | tooltips=[ 18 | ("Class", "@type"), 19 | ("Position", "(@x, @y) m"), 20 | ("Speed", "@speed_mps m/s (@speed_kph km/h)"), 21 | ] 22 | ) 23 | 24 | agent_type: int = batch.agent_type[batch_idx].item() 25 | num_neighbors: int = batch.num_neigh[batch_idx].item() 26 | agent_hist_np: StateArray = batch.agent_hist[batch_idx].cpu().numpy() 27 | neigh_hist_np: StateArray = batch.neigh_hist[batch_idx].cpu().numpy() 28 | neigh_types = batch.neigh_types[batch_idx].cpu().numpy() 29 | agent_histories = ColumnDataSource( 30 | data={ 31 | "xs": [agent_hist_np.get_attr("x")] 32 | + [ 33 | neigh_hist_np[n_neigh].get_attr("x") for n_neigh in range(num_neighbors) 34 | ], 35 | "ys": [agent_hist_np.get_attr("y")] 36 | + [ 37 | neigh_hist_np[n_neigh].get_attr("y") for n_neigh in range(num_neighbors) 38 | ], 39 | "line_dash": ["dashed"] * (num_neighbors + 1), 40 | "line_color": [vis_utils.get_agent_type_color(agent_type)] 41 | + [ 42 | vis_utils.get_agent_type_color(neigh_types[n_neigh]) 43 | for n_neigh in range(num_neighbors) 44 | ], 45 | } 46 | ) 47 | 48 | agent_fut_np: StateArray = batch.agent_fut[batch_idx].cpu().numpy() 49 | neigh_fut_np: StateArray = batch.neigh_fut[batch_idx].cpu().numpy() 50 | agent_futures = ColumnDataSource( 51 | data={ 52 | "xs": [agent_fut_np.get_attr("x")] 53 | + [neigh_fut_np[n_neigh].get_attr("x") for n_neigh in range(num_neighbors)], 54 | "ys": [agent_fut_np.get_attr("y")] 55 | + [neigh_fut_np[n_neigh].get_attr("y") for n_neigh in range(num_neighbors)], 56 | "line_dash": ["solid"] * (num_neighbors + 1), 57 | "line_color": [vis_utils.get_agent_type_color(agent_type)] 58 | + [ 59 | vis_utils.get_agent_type_color(neigh_types[n_neigh]) 60 | for n_neigh in range(num_neighbors) 61 | ], 62 | } 63 | ) 64 | 65 | agent_state: StateArray = batch.agent_hist[batch_idx, -1].cpu().numpy() 66 | x, y = agent_state.position 67 | 68 | if batch.map_names is not None: 69 | map_vis_radius: float = 50.0 70 | mapAPI = MapAPI(cache_path) 71 | fig.add_map( 72 | batch.agents_from_world_tf[batch_idx].cpu().numpy(), 73 | mapAPI.get_map( 74 | batch.map_names[batch_idx], 75 | incl_road_lanes=True, 76 | incl_road_areas=True, 77 | incl_ped_crosswalks=True, 78 | incl_ped_walkways=True, 79 | ), 80 | # x_min, x_max, y_min, y_max 81 | bbox=( 82 | x - map_vis_radius, 83 | x + map_vis_radius, 84 | y - map_vis_radius, 85 | y + map_vis_radius, 86 | ), 87 | ) 88 | 89 | fig.add_lines(agent_histories) 90 | fig.add_lines(agent_futures) 91 | 92 | agent_extent: np.ndarray = batch.agent_hist_extent[batch_idx, -1] 93 | if agent_extent.isnan().any(): 94 | 
raise ValueError("Agent extents cannot be NaN!") 95 | 96 | length = agent_extent[0].item() 97 | width = agent_extent[1].item() 98 | 99 | heading: float = agent_state.heading.item() 100 | speed_mps: float = np.linalg.norm(agent_state.velocity).item() 101 | 102 | agent_rect_coords = transform_coords_2d_np( 103 | np.array( 104 | [ 105 | [-length / 2, -width / 2], 106 | [-length / 2, width / 2], 107 | [length / 2, width / 2], 108 | [length / 2, -width / 2], 109 | ] 110 | ), 111 | angle=heading, 112 | ) 113 | 114 | agent_rects_data = { 115 | "x": [x], 116 | "y": [y], 117 | "xs": [agent_rect_coords[:, 0] + x], 118 | "ys": [agent_rect_coords[:, 1] + y], 119 | "fill_color": [vis_utils.get_agent_type_color(agent_type)], 120 | "line_color": ["black"], 121 | "fill_alpha": [0.7], 122 | "type": [str(AgentType(agent_type))[len("AgentType.") :]], 123 | "speed_mps": [speed_mps], 124 | "speed_kph": [speed_mps * 3.6], 125 | } 126 | 127 | size = 1.0 128 | if agent_type == AgentType.PEDESTRIAN: 129 | size = 0.25 130 | 131 | dir_patch_coords = transform_coords_2d_np( 132 | np.array( 133 | [ 134 | [0, np.sqrt(3) / 3], 135 | [-1 / 2, -np.sqrt(3) / 6], 136 | [1 / 2, -np.sqrt(3) / 6], 137 | ] 138 | ) 139 | * size, 140 | angle=heading - np.pi / 2, 141 | ) 142 | dir_patches_data = { 143 | "xs": [dir_patch_coords[:, 0] + x], 144 | "ys": [dir_patch_coords[:, 1] + y], 145 | "fill_color": [vis_utils.get_agent_type_color(agent_type)], 146 | "line_color": ["black"], 147 | "alpha": [0.7], 148 | } 149 | 150 | for n_neigh in range(num_neighbors): 151 | agent_type: int = batch.neigh_types[batch_idx, n_neigh].item() 152 | agent_state: StateArray = batch.neigh_hist[batch_idx, n_neigh, -1].cpu().numpy() 153 | agent_extent: np.ndarray = batch.neigh_hist_extents[batch_idx, n_neigh, -1] 154 | 155 | if agent_extent.isnan().any(): 156 | raise ValueError("Agent extents cannot be NaN!") 157 | 158 | length = agent_extent[0].item() 159 | width = agent_extent[1].item() 160 | 161 | x, y = agent_state.position 162 | heading: float = agent_state.heading.item() 163 | speed_mps: float = np.linalg.norm(agent_state.velocity).item() 164 | 165 | agent_rect_coords, dir_patch_coords = vis_utils.compute_agent_rect_coords( 166 | agent_type, heading, length, width 167 | ) 168 | 169 | agent_rects_data["x"].append(x) 170 | agent_rects_data["y"].append(y) 171 | agent_rects_data["xs"].append(agent_rect_coords[:, 0] + x) 172 | agent_rects_data["ys"].append(agent_rect_coords[:, 1] + y) 173 | agent_rects_data["fill_color"].append( 174 | vis_utils.get_agent_type_color(agent_type) 175 | ) 176 | agent_rects_data["line_color"].append("black") 177 | agent_rects_data["fill_alpha"].append(0.7) 178 | agent_rects_data["type"].append(str(AgentType(agent_type))[len("AgentType.") :]) 179 | agent_rects_data["speed_mps"].append(speed_mps) 180 | agent_rects_data["speed_kph"].append(speed_mps * 3.6) 181 | 182 | dir_patches_data["xs"].append(dir_patch_coords[:, 0] + x) 183 | dir_patches_data["ys"].append(dir_patch_coords[:, 1] + y) 184 | dir_patches_data["fill_color"].append( 185 | vis_utils.get_agent_type_color(agent_type) 186 | ) 187 | dir_patches_data["line_color"].append("black") 188 | dir_patches_data["alpha"].append(0.7) 189 | 190 | rects, _ = fig.add_agents( 191 | ColumnDataSource(data=agent_rects_data), ColumnDataSource(data=dir_patches_data) 192 | ) 193 | 194 | fig.raw_figure.hover.renderers = [rects] 195 | fig.show() 196 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/trajdata/51f0efa9d572fa7da480ef8caf089d0d6987de9f/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import defaultdict 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from trajdata.data_structures.agent import AgentType 7 | from trajdata.data_structures.batch import AgentBatch 8 | from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement 9 | from trajdata.data_structures.state import NP_STATE_TYPES, TORCH_STATE_TYPES 10 | from trajdata.dataset import UnifiedDataset 11 | 12 | 13 | class TestDataset(unittest.TestCase): 14 | def test_dataloading(self): 15 | dataset = UnifiedDataset( 16 | desired_data=["nusc_mini-mini_val"], 17 | centric="agent", 18 | desired_dt=0.1, 19 | history_sec=(3.2, 3.2), 20 | future_sec=(4.8, 4.8), 21 | only_predict=[AgentType.VEHICLE], 22 | agent_interaction_distances=defaultdict(lambda: 30.0), 23 | incl_robot_future=True, 24 | incl_raster_map=True, 25 | standardize_data=False, 26 | raster_map_params={ 27 | "px_per_m": 2, 28 | "map_size_px": 224, 29 | "offset_frac_xy": (-0.5, 0.0), 30 | }, 31 | num_workers=4, 32 | verbose=True, 33 | data_dirs={ # Remember to change this to match your filesystem! 34 | "nusc_mini": "~/datasets/nuScenes", 35 | }, 36 | ) 37 | 38 | dataloader = DataLoader( 39 | dataset, 40 | batch_size=4, 41 | shuffle=True, 42 | collate_fn=dataset.get_collate_fn(), 43 | num_workers=0, 44 | ) 45 | 46 | i = 0 47 | batch: AgentBatch 48 | for batch in dataloader: 49 | i += 1 50 | 51 | batch.to("cuda") 52 | 53 | self.assertIsInstance(batch.curr_agent_state, dataset.torch_state_type) 54 | self.assertIsInstance(batch.agent_hist, dataset.torch_obs_type) 55 | self.assertIsInstance(batch.agent_fut, dataset.torch_obs_type) 56 | self.assertIsInstance(batch.robot_fut, dataset.torch_obs_type) 57 | 58 | if i == 5: 59 | break 60 | 61 | def test_dict_dataloading(self): 62 | dataset = UnifiedDataset( 63 | desired_data=["nusc_mini-mini_val"], 64 | centric="agent", 65 | desired_dt=0.1, 66 | history_sec=(3.2, 3.2), 67 | future_sec=(4.8, 4.8), 68 | only_predict=[AgentType.VEHICLE], 69 | agent_interaction_distances=defaultdict(lambda: 30.0), 70 | incl_robot_future=True, 71 | incl_raster_map=True, 72 | standardize_data=False, 73 | raster_map_params={ 74 | "px_per_m": 2, 75 | "map_size_px": 224, 76 | "offset_frac_xy": (-0.5, 0.0), 77 | }, 78 | num_workers=4, 79 | verbose=True, 80 | data_dirs={ # Remember to change this to match your filesystem! 
81 | "nusc_mini": "~/datasets/nuScenes", 82 | }, 83 | ) 84 | 85 | dataloader = DataLoader( 86 | dataset, 87 | batch_size=4, 88 | shuffle=True, 89 | collate_fn=dataset.get_collate_fn(return_dict=True), 90 | num_workers=0, 91 | ) 92 | 93 | i = 0 94 | for batch in dataloader: 95 | i += 1 96 | 97 | self.assertIsInstance(batch["curr_agent_state"], dataset.torch_state_type) 98 | self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) 99 | self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) 100 | self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) 101 | 102 | if i == 5: 103 | break 104 | 105 | dataset = UnifiedDataset( 106 | desired_data=["nusc_mini-mini_val"], 107 | centric="scene", 108 | desired_dt=0.1, 109 | history_sec=(3.2, 3.2), 110 | future_sec=(4.8, 4.8), 111 | only_predict=[AgentType.VEHICLE], 112 | agent_interaction_distances=defaultdict(lambda: 30.0), 113 | incl_robot_future=True, 114 | incl_raster_map=True, 115 | standardize_data=False, 116 | raster_map_params={ 117 | "px_per_m": 2, 118 | "map_size_px": 224, 119 | "offset_frac_xy": (-0.5, 0.0), 120 | }, 121 | num_workers=4, 122 | verbose=True, 123 | data_dirs={ # Remember to change this to match your filesystem! 124 | "nusc_mini": "~/datasets/nuScenes", 125 | }, 126 | ) 127 | 128 | dataloader = DataLoader( 129 | dataset, 130 | batch_size=4, 131 | shuffle=True, 132 | collate_fn=dataset.get_collate_fn(return_dict=True), 133 | num_workers=0, 134 | ) 135 | 136 | i = 0 137 | for batch in dataloader: 138 | i += 1 139 | 140 | self.assertIsInstance( 141 | batch["centered_agent_state"], dataset.torch_state_type 142 | ) 143 | self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) 144 | self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) 145 | self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) 146 | 147 | if i == 5: 148 | break 149 | 150 | def test_default_datatypes_agent(self): 151 | dataset = UnifiedDataset( 152 | desired_data=["nusc_mini-mini_val"], 153 | centric="agent", 154 | desired_dt=0.1, 155 | history_sec=(3.2, 3.2), 156 | future_sec=(4.8, 4.8), 157 | only_predict=[AgentType.VEHICLE], 158 | agent_interaction_distances=defaultdict(lambda: 30.0), 159 | incl_robot_future=True, 160 | incl_raster_map=True, 161 | standardize_data=False, 162 | raster_map_params={ 163 | "px_per_m": 2, 164 | "map_size_px": 224, 165 | "offset_frac_xy": (-0.5, 0.0), 166 | }, 167 | num_workers=4, 168 | verbose=True, 169 | data_dirs={ # Remember to change this to match your filesystem! 
170 | "nusc_mini": "~/datasets/nuScenes", 171 | }, 172 | ) 173 | 174 | elem: AgentBatchElement = dataset[0] 175 | self.assertIsInstance(elem.curr_agent_state_np, dataset.np_state_type) 176 | self.assertIsInstance(elem.agent_history_np, dataset.np_obs_type) 177 | self.assertIsInstance(elem.agent_future_np, dataset.np_obs_type) 178 | self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) 179 | 180 | def test_default_datatypes_scene(self): 181 | dataset = UnifiedDataset( 182 | desired_data=["nusc_mini-mini_val"], 183 | centric="scene", 184 | desired_dt=0.1, 185 | history_sec=(3.2, 3.2), 186 | future_sec=(4.8, 4.8), 187 | only_predict=[AgentType.VEHICLE], 188 | agent_interaction_distances=defaultdict(lambda: 30.0), 189 | incl_robot_future=True, 190 | incl_raster_map=True, 191 | standardize_data=False, 192 | raster_map_params={ 193 | "px_per_m": 2, 194 | "map_size_px": 224, 195 | "offset_frac_xy": (-0.5, 0.0), 196 | }, 197 | num_workers=4, 198 | verbose=True, 199 | data_dirs={ # Remember to change this to match your filesystem! 200 | "nusc_mini": "~/datasets/nuScenes", 201 | }, 202 | ) 203 | 204 | elem: SceneBatchElement = dataset[0] 205 | self.assertIsInstance(elem.centered_agent_state_np, dataset.np_state_type) 206 | self.assertIsInstance(elem.agent_histories[0], dataset.np_obs_type) 207 | self.assertIsInstance(elem.agent_futures[0], dataset.np_obs_type) 208 | self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) 209 | 210 | 211 | if __name__ == "__main__": 212 | unittest.main() 213 | -------------------------------------------------------------------------------- /tests/test_description_matching.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from trajdata import UnifiedDataset 4 | 5 | 6 | class TestDescriptionMatching(unittest.TestCase): 7 | def test_night(self): 8 | dataset = UnifiedDataset( 9 | desired_data=["nusc_mini"], scene_description_contains=["night"] 10 | ) 11 | 12 | for scene_info in dataset.scenes(): 13 | self.assertIn("night", scene_info.description) 14 | 15 | def test_intersection(self): 16 | dataset = UnifiedDataset( 17 | desired_data=["nusc_mini"], scene_description_contains=["intersection"] 18 | ) 19 | 20 | for scene_info in dataset.scenes(): 21 | self.assertIn("intersection", scene_info.description) 22 | 23 | def test_intersection_more_initial(self): 24 | dataset = UnifiedDataset( 25 | desired_data=["nusc_mini", "nuplan_mini"], 26 | scene_description_contains=["intersection"], 27 | data_dirs={ # Remember to change this to match your filesystem! 
28 | "nusc_mini": "~/datasets/nuScenes", 29 | "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", 30 | }, 31 | ) 32 | 33 | for scene_info in dataset.scenes(): 34 | self.assertIn("intersection", scene_info.description) 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tests/test_state.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from trajdata.data_structures.state import NP_STATE_TYPES, TORCH_STATE_TYPES 7 | 8 | AgentStateArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] 9 | AgentObsArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] 10 | AgentStateTensor = TORCH_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] 11 | AgentObsTensor = TORCH_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] 12 | 13 | 14 | class TestStateTensor(unittest.TestCase): 15 | def test_construction(self): 16 | a = AgentStateTensor(torch.rand(2, 8)) 17 | b = torch.rand(8).as_subclass(AgentStateTensor) 18 | c = AgentObsTensor(torch.rand(5, 9)) 19 | 20 | def test_class_propagation(self): 21 | a = AgentStateTensor(torch.rand(2, 8)) 22 | self.assertTrue(isinstance(a.to("cpu"), AgentStateTensor)) 23 | 24 | a = AgentStateTensor(torch.rand(2, 8)) 25 | self.assertTrue(isinstance(a.cpu(), AgentStateTensor)) 26 | 27 | b = AgentStateTensor(torch.rand(2, 8)) 28 | self.assertTrue(isinstance(a + b, AgentStateTensor)) 29 | 30 | b = torch.rand(2, 8) 31 | self.assertTrue(isinstance(a + b, AgentStateTensor)) 32 | 33 | a += 1 34 | self.assertTrue(isinstance(a, AgentStateTensor)) 35 | 36 | def test_property_access(self): 37 | a = AgentStateTensor(torch.rand(2, 8)) 38 | position = a[..., :3] 39 | velocity = a[..., 3:5] 40 | acc = a[..., 5:7] 41 | h = a[..., 7:] 42 | 43 | self.assertTrue(torch.allclose(a.position3d, position)) 44 | self.assertTrue(torch.allclose(a.velocity, velocity)) 45 | self.assertTrue(torch.allclose(a.acceleration, acc)) 46 | self.assertTrue(torch.allclose(a.heading, h)) 47 | 48 | def test_heading_conversion(self): 49 | a = AgentStateTensor(torch.rand(2, 8)) 50 | h = a[..., 7:] 51 | hv = a.heading_vector 52 | self.assertTrue(torch.allclose(torch.atan2(hv[..., 1], hv[..., 0])[:, None], h)) 53 | 54 | def test_long_lat_velocity(self): 55 | a = AgentStateTensor(torch.rand(8)) 56 | velocity = a[3:5] 57 | h = a[7] 58 | lonlat_v = a.as_format("v_lon,v_lat") 59 | lonlat_v_correct = ( 60 | torch.tensor([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])[None, ...] 61 | @ velocity[..., None] 62 | )[..., 0] 63 | 64 | self.assertTrue(torch.allclose(lonlat_v, lonlat_v_correct)) 65 | 66 | b = a.as_format("x,y,xd,yd,s,c") 67 | s = b[-2] 68 | c = b[-1] 69 | lonlat_v = b.as_format("v_lon,v_lat") 70 | lonlat_v_correct = ( 71 | torch.tensor([[c, s], [-s, c]])[None, ...] 
@ velocity[..., None] 72 | )[..., 0] 73 | 74 | self.assertTrue(torch.allclose(lonlat_v, lonlat_v_correct)) 75 | 76 | def test_long_lat_conversion(self): 77 | a = AgentStateTensor(torch.rand(2, 8)) 78 | b = a.as_format("xd,yd,h") 79 | c = b.as_format("v_lon,v_lat,h") 80 | d = c.as_format("xd,yd,h") 81 | self.assertTrue(torch.allclose(b, d)) 82 | 83 | def test_as_format(self): 84 | a = AgentStateTensor(torch.rand(2, 8)) 85 | b = a.as_format("x,y,z,xd,yd,xdd,ydd,s,c") 86 | self.assertTrue(isinstance(b, AgentObsTensor)) 87 | self.assertTrue(torch.allclose(a, b.as_format(a._format))) 88 | 89 | def test_as_tensor(self): 90 | a = AgentStateTensor(torch.rand(2, 8)) 91 | b = a.as_tensor() 92 | self.assertTrue(isinstance(b, torch.Tensor)) 93 | self.assertFalse(isinstance(b, AgentStateTensor)) 94 | 95 | def test_tensor_ops(self): 96 | a = AgentStateTensor(torch.rand(2, 8)) 97 | b = a[0] + a[1] 98 | c = torch.mean(b) 99 | self.assertFalse(isinstance(c, AgentStateTensor)) 100 | self.assertTrue(isinstance(c, torch.Tensor)) 101 | 102 | 103 | class TestStateArray(unittest.TestCase): 104 | def test_construction(self): 105 | a = np.random.rand(2, 8).view(AgentStateArray) 106 | c = np.random.rand(5, 9).view(AgentObsArray) 107 | 108 | def test_property_access(self): 109 | a = np.random.rand(2, 8).view(AgentStateArray) 110 | position = a[..., :3] 111 | velocity = a[..., 3:5] 112 | acc = a[..., 5:7] 113 | h = a[..., 7:] 114 | 115 | self.assertTrue(np.allclose(a.position3d, position)) 116 | self.assertTrue(np.allclose(a.velocity, velocity)) 117 | self.assertTrue(np.allclose(a.acceleration, acc)) 118 | self.assertTrue(np.allclose(a.heading, h)) 119 | 120 | def test_property_setting(self): 121 | a = np.random.rand(2, 8).view(AgentStateArray) 122 | a.heading = 0.0 123 | self.assertTrue(np.allclose(a[..., -1], np.zeros([2, 1]))) 124 | 125 | def test_heading_conversion(self): 126 | a = np.random.rand(2, 8).view(AgentStateArray) 127 | h = a[..., 7:] 128 | hv = a.heading_vector 129 | self.assertTrue(np.allclose(np.arctan2(hv[..., 1], hv[..., 0])[:, None], h)) 130 | 131 | def test_long_lat_velocity(self): 132 | a = np.random.rand(8).view(AgentStateArray) 133 | velocity = a[3:5] 134 | h = a[7] 135 | lonlat_v = a.as_format("v_lon,v_lat") 136 | lonlat_v_correct = ( 137 | np.array([[np.cos(h), np.sin(h)], [-np.sin(h), np.cos(h)]])[None, ...] 138 | @ velocity[..., None] 139 | )[..., 0] 140 | 141 | self.assertTrue(np.allclose(lonlat_v, lonlat_v_correct)) 142 | 143 | b = a.as_format("x,y,xd,yd,s,c") 144 | s = b[-2] 145 | c = b[-1] 146 | lonlat_v = b.as_format("v_lon,v_lat") 147 | lonlat_v_correct = ( 148 | np.array([[c, s], [-s, c]])[None, ...] 
@ velocity[..., None] 149 | )[..., 0] 150 | 151 | self.assertTrue(np.allclose(lonlat_v, lonlat_v_correct)) 152 | 153 | def test_long_lat_conversion(self): 154 | a = np.random.rand(2, 8).view(AgentStateArray) 155 | b = a.as_format("xd,yd,h") 156 | c = b.as_format("v_lon,v_lat,h") 157 | d = c.as_format("xd,yd,h") 158 | self.assertTrue(np.allclose(b, d)) 159 | 160 | def test_as_format(self): 161 | a = np.random.rand(2, 8).view(AgentStateArray) 162 | b = a.as_format("x,y,z,xd,yd,xdd,ydd,s,c") 163 | self.assertTrue(isinstance(b, AgentObsArray)) 164 | self.assertTrue(np.allclose(a, b.as_format(a._format))) 165 | 166 | def test_as_ndarray(self): 167 | a: AgentStateArray = np.random.rand(2, 8).view(AgentStateArray) 168 | b = a.as_ndarray() 169 | self.assertTrue(isinstance(b, np.ndarray)) 170 | self.assertFalse(isinstance(b, AgentStateArray)) 171 | 172 | def test_tensor_ops(self): 173 | a = np.random.rand(2, 8).view(AgentStateArray) 174 | b = a[0] + a[1] 175 | c = np.mean(b) 176 | self.assertFalse(isinstance(c, AgentStateArray)) 177 | self.assertTrue(isinstance(c, float)) 178 | 179 | 180 | if __name__ == "__main__": 181 | unittest.main() 182 | -------------------------------------------------------------------------------- /tests/test_traffic_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from trajdata import UnifiedDataset 4 | from trajdata.caching.df_cache import DataFrameCache 5 | 6 | 7 | class TestTrafficLightData(unittest.TestCase): 8 | @classmethod 9 | def setUpClass(cls) -> None: 10 | kwargs = { 11 | "desired_data": ["nuplan_mini-mini_val"], 12 | "centric": "scene", 13 | "history_sec": (3.2, 3.2), 14 | "future_sec": (4.8, 4.8), 15 | "incl_robot_future": False, 16 | "incl_raster_map": True, 17 | "cache_location": "~/.unified_data_cache", 18 | "raster_map_params": { 19 | "px_per_m": 2, 20 | "map_size_px": 224, 21 | "offset_frac_xy": (-0.5, 0.0), 22 | }, 23 | "num_workers": 64, 24 | "verbose": True, 25 | "data_dirs": { # Remember to change this to match your filesystem! 
26 | "nuplan_mini": "~/datasets/nuplan/dataset/nuplan-v1.1", 27 | }, 28 | } 29 | 30 | cls.dataset = UnifiedDataset( 31 | **kwargs, 32 | desired_dt=0.05, 33 | ) 34 | 35 | cls.downsampled_dataset = UnifiedDataset( 36 | **kwargs, 37 | desired_dt=0.1, 38 | ) 39 | 40 | cls.upsampled_dataset = UnifiedDataset( 41 | **kwargs, 42 | desired_dt=0.025, 43 | ) 44 | 45 | cls.scene_num: int = 100 46 | 47 | def test_traffic_light_loading(self): 48 | # get random scene 49 | scene = self.dataset.get_scene(self.scene_num) 50 | scene_cache = DataFrameCache(self.dataset.cache_path, scene) 51 | traffic_light_status = scene_cache.get_traffic_light_status_dict() 52 | 53 | # just check if the loading works without errors 54 | self.assertTrue(traffic_light_status is not None) 55 | 56 | def test_downsampling(self): 57 | # get random scene from both datasets 58 | scene = self.dataset.get_scene(self.scene_num) 59 | downsampled_scene = self.downsampled_dataset.get_scene(self.scene_num) 60 | 61 | self.assertEqual(scene.name, downsampled_scene.name) 62 | 63 | scene_cache = DataFrameCache(self.dataset.cache_path, scene) 64 | downsampled_scene_cache = DataFrameCache( 65 | self.downsampled_dataset.cache_path, downsampled_scene 66 | ) 67 | traffic_light_status = scene_cache.get_traffic_light_status_dict() 68 | downsampled_traffic_light_status = ( 69 | downsampled_scene_cache.get_traffic_light_status_dict() 70 | ) 71 | 72 | orig_lane_ids = set(key[0] for key in traffic_light_status.keys()) 73 | downsampled_lane_ids = set( 74 | key[0] for key in downsampled_traffic_light_status.keys() 75 | ) 76 | self.assertSetEqual(orig_lane_ids, downsampled_lane_ids) 77 | 78 | # check that matching indices match 79 | for ( 80 | lane_id, 81 | scene_ts, 82 | ), downsampled_status in downsampled_traffic_light_status.items(): 83 | if scene_ts % 2 == 0: 84 | try: 85 | prev_status = traffic_light_status[lane_id, scene_ts * 2] 86 | except KeyError: 87 | prev_status = None 88 | 89 | try: 90 | next_status = traffic_light_status[lane_id, scene_ts * 2 + 1] 91 | except KeyError: 92 | next_status = None 93 | 94 | self.assertTrue( 95 | prev_status is not None or next_status is not None, 96 | f"Lane {lane_id} at t={scene_ts} has status {downsampled_status} " 97 | f"in the downsampled dataset, but neither t={2*scene_ts} nor " 98 | f"t={2*scene_ts + 1} were found in the original dataset.", 99 | ) 100 | self.assertTrue( 101 | downsampled_status == prev_status 102 | or downsampled_status == next_status, 103 | f"Lane {lane_id} at t={scene_ts*2, scene_ts*2 + 1} in the original dataset " 104 | f"had status {prev_status, next_status}, but in the downsampled dataset, " 105 | f"{lane_id} at t={scene_ts} had status {downsampled_status}", 106 | ) 107 | 108 | def test_upsampling(self): 109 | # get random scene from both datasets 110 | scene = self.dataset.get_scene(self.scene_num) 111 | upsampled_scene = self.upsampled_dataset.get_scene(self.scene_num) 112 | scene_cache = DataFrameCache(self.dataset.cache_path, scene) 113 | upsampled_scene_cache = DataFrameCache( 114 | self.upsampled_dataset.cache_path, upsampled_scene 115 | ) 116 | traffic_light_status = scene_cache.get_traffic_light_status_dict() 117 | upsampled_traffic_light_status = ( 118 | upsampled_scene_cache.get_traffic_light_status_dict() 119 | ) 120 | 121 | # check that matching indices match 122 | for (lane_id, scene_ts), status in upsampled_traffic_light_status.items(): 123 | if scene_ts % 2 == 0: 124 | orig_status = traffic_light_status[lane_id, scene_ts // 2] 125 | self.assertEqual( 126 | status, 127 | 
orig_status,
128 |                     f"Lane {lane_id} at t={scene_ts // 2} in the original dataset "
129 |                     f"had status {orig_status}, but in the upsampled dataset, "
130 |                     f"{lane_id} at t={scene_ts} had status {status}",
131 |                 )
132 |             else:
133 |                 try:
134 |                     prev_status = traffic_light_status[lane_id, scene_ts // 2]
135 |                 except KeyError:
136 |                     prev_status = None
137 |                 try:
138 |                     next_status = traffic_light_status[lane_id, scene_ts // 2 + 1]
139 |                 except KeyError:
140 |                     next_status = None
141 | 
142 |                 self.assertTrue(
143 |                     prev_status is not None or next_status is not None,
144 |                     f"Lane {lane_id} at t={scene_ts} has status {status} "
145 |                     f"in the upsampled dataset, but neither t={scene_ts // 2} nor "
146 |                     f"t={scene_ts // 2 + 1} was found in the original dataset.",
147 |                 )
148 | 
149 |                 self.assertTrue(
150 |                     status == prev_status or status == next_status,
151 |                     f"Lane {lane_id} at t={scene_ts // 2, scene_ts // 2 + 1} in the original dataset "
152 |                     f"had status {prev_status, next_status}, but in the upsampled dataset, "
153 |                     f"{lane_id} at t={scene_ts} had status {status}",
154 |                 )
155 | 
156 | 
157 | if __name__ == "__main__":
158 |     unittest.main()
159 | 
--------------------------------------------------------------------------------
/tests/test_vec_map.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from pathlib import Path
3 | from typing import Dict, List
4 | 
5 | import numpy as np
6 | from shapely import contains_xy, dwithin, linearrings, points, polygons
7 | 
8 | from trajdata import MapAPI, VectorMap
9 | from trajdata.maps.vec_map_elements import MapElementType
10 | 
11 | 
12 | class TestVectorMap(unittest.TestCase):
13 |     @classmethod
14 |     def setUpClass(cls) -> None:
15 |         cache_path = Path("~/.unified_data_cache").expanduser()
16 |         cls.map_api = MapAPI(cache_path)
17 |         cls.proto_loading_kwargs = {
18 |             "incl_road_lanes": True,
19 |             "incl_road_areas": True,
20 |             "incl_ped_crosswalks": True,
21 |             "incl_ped_walkways": True,
22 |         }
23 | 
24 |         cls.location_dict: Dict[str, List[str]] = {
25 |             "nusc_mini": ["boston-seaport", "singapore-onenorth"],
26 |         }
27 | 
28 |     # TODO(pkarkus) This assumes we already have the maps cached. It would be
29 |     # better to attempt to cache them if the cache does not yet exist.
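    # A minimal sketch of what that could look like (hypothetical, not part of
    # the original test): constructing a UnifiedDataset for the same environment
    # populates the cache at cache_location as a side effect, e.g.,
    #
    #     from trajdata import UnifiedDataset
    #     UnifiedDataset(
    #         desired_data=["nusc_mini"],
    #         cache_location="~/.unified_data_cache",
    #         data_dirs={"nusc_mini": "~/datasets/nuScenes"},
    #     )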
30 | def test_map_existence(self): 31 | for env_name, map_names in self.location_dict.items(): 32 | for map_name in map_names: 33 | vec_map: VectorMap = self.map_api.get_map( 34 | f"{env_name}:{map_name}", **self.proto_loading_kwargs 35 | ) 36 | assert vec_map is not None 37 | 38 | def test_proto_equivalence(self): 39 | for env_name, map_names in self.location_dict.items(): 40 | for map_name in map_names: 41 | vec_map: VectorMap = self.map_api.get_map( 42 | f"{env_name}:{map_name}", **self.proto_loading_kwargs 43 | ) 44 | 45 | assert maps_equal( 46 | VectorMap.from_proto( 47 | vec_map.to_proto(), **self.proto_loading_kwargs 48 | ), 49 | vec_map, 50 | ) 51 | 52 | def test_road_area_queries(self): 53 | env_name = next(self.location_dict.keys().__iter__()) 54 | map_name = self.location_dict[env_name][0] 55 | 56 | vec_map: VectorMap = self.map_api.get_map( 57 | f"{env_name}:{map_name}", **self.proto_loading_kwargs 58 | ) 59 | 60 | if vec_map.search_rtrees is None: 61 | return 62 | 63 | point = vec_map.lanes[0].center.xy[0, :] 64 | closest_area = vec_map.get_closest_area( 65 | point, elem_type=MapElementType.ROAD_AREA 66 | ) 67 | holes = closest_area.interior_holes 68 | if len(holes) == 0: 69 | holes = None 70 | closest_area_polygon = polygons(closest_area.exterior_polygon.xy, holes=holes) 71 | self.assertTrue(contains_xy(closest_area_polygon, point[None, :2])) 72 | 73 | rnd_points = np.random.uniform( 74 | low=vec_map.extent[:2], high=vec_map.extent[3:5], size=(10, 2) 75 | ) 76 | 77 | NEARBY_DIST = 150.0 78 | for point in rnd_points: 79 | nearby_areas = vec_map.get_areas_within( 80 | point, elem_type=MapElementType.ROAD_AREA, dist=NEARBY_DIST 81 | ) 82 | for area in nearby_areas: 83 | holes = [linearrings(hole.xy) for hole in area.interior_holes] 84 | if len(holes) == 0: 85 | holes = None 86 | area_polygon = polygons(area.exterior_polygon.xy, holes=holes) 87 | point_pt = points(point) 88 | self.assertTrue(dwithin(area_polygon, point_pt, distance=NEARBY_DIST)) 89 | 90 | for elem_type in [ 91 | MapElementType.PED_CROSSWALK, 92 | MapElementType.PED_WALKWAY, 93 | ]: 94 | for point in rnd_points: 95 | nearby_areas = vec_map.get_areas_within( 96 | point, elem_type=elem_type, dist=NEARBY_DIST 97 | ) 98 | for area in nearby_areas: 99 | area_polygon = polygons(area.polygon.xy) 100 | point_pt = points(point) 101 | if not dwithin(area_polygon, point_pt, distance=NEARBY_DIST): 102 | print( 103 | f"{elem_type.name} at {area_polygon} is not within {NEARBY_DIST} of {point_pt}", 104 | ) 105 | 106 | # TODO(bivanovic): Add more! 107 | 108 | 109 | def maps_equal(map1: VectorMap, map2: VectorMap) -> bool: 110 | elements1_set = set([elem.id for elem in map1.iter_elems()]) 111 | elements2_set = set([elem.id for elem in map2.iter_elems()]) 112 | return elements1_set == elements2_set 113 | 114 | 115 | if __name__ == "__main__": 116 | unittest.main() 117 | --------------------------------------------------------------------------------