├── terra ├── __init__.py ├── viz │ ├── __init__.py │ ├── game │ │ ├── settings.py │ │ ├── utils.py │ │ ├── agent.py │ │ ├── world.py │ │ └── game.py │ ├── play.py │ └── main_manual.py ├── env_generation │ ├── __init__.py │ ├── visualize_maps.py │ ├── generate_dataset.py │ ├── README.md │ ├── generate_relocations.py │ ├── generate_foundations.py │ ├── create_train_data.py │ ├── convert_to_terra.py │ └── openstreet.py ├── settings.py ├── map_utils │ ├── test_openstreet_generator.py │ ├── test_openstreet_plugin.py │ ├── jax_terrain_generation.py │ ├── openstreet_dataset_generator.py │ └── openstreet_plugin.py ├── map_generator.py ├── map.py ├── actions.py ├── curriculum.py ├── agent.py ├── wrappers.py ├── config.py ├── utils.py └── env.py ├── tests ├── __init__.py └── unit │ └── __init__.py ├── benchmark ├── __init__.py ├── plot_benchmark.py └── benchmark.py ├── requirements-dev.txt ├── assets ├── overview.gif ├── scaling-envs.png └── scaling-devices.png ├── data └── custom │ ├── images │ ├── custom_1.png │ ├── custom_2.png │ ├── custom_3.png │ ├── custom_4.png │ ├── custom_5.png │ ├── custom_6.png │ ├── custom_7.png │ ├── custom_8.png │ ├── custom_9.png │ ├── custom_10.png │ ├── custom_11.png │ └── custom_12.png │ ├── dumpability │ ├── custom_1.png │ ├── custom_2.png │ ├── custom_3.png │ ├── custom_4.png │ ├── custom_5.png │ ├── custom_6.png │ ├── custom_7.png │ ├── custom_8.png │ ├── custom_9.png │ ├── custom_10.png │ ├── custom_11.png │ └── custom_12.png │ └── occupancy │ ├── custom_1.png │ ├── custom_10.png │ ├── custom_11.png │ ├── custom_12.png │ ├── custom_2.png │ ├── custom_3.png │ ├── custom_4.png │ ├── custom_5.png │ ├── custom_6.png │ ├── custom_7.png │ ├── custom_8.png │ └── custom_9.png ├── .flake8 ├── .pre-commit-config.yaml ├── requirements.txt ├── pyproject.toml ├── .gitignore ├── setup.py ├── .pylintrc ├── config └── env_generation_config.yaml ├── environment.yml └── README.md /terra/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /terra/viz/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | -------------------------------------------------------------------------------- /terra/env_generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/assets/overview.gif -------------------------------------------------------------------------------- 
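The directory tree above gives the package layout: `terra/env.py` exposes the batched environment, `terra/config.py` its configuration, `terra/env_generation/` the map-generation tooling, and `terra/viz/` the pygame visualization and manual-play scripts. For orientation, the sketch below shows a minimal driving loop. It is adapted from `terra/viz/play.py` and `terra/viz/main_manual.py` (both included later in this file); the bare `TerraEnvBatch()` call without rendering flags and the per-environment key splitting passed to `step` are assumptions, not verbatim repo code.

```python
import jax
import jax.numpy as jnp

from terra.config import BatchConfig, EnvConfig
from terra.env import TerraEnvBatch

n_envs = 4
rng = jax.random.PRNGKey(0)

# Assumed: the rendering/display flags used in play.py are optional and default off.
env = TerraEnvBatch()

# One EnvConfig and one PRNG key per parallel environment.
env_cfgs = jax.vmap(lambda x: EnvConfig.new())(jnp.arange(n_envs))
rng, _rng = jax.random.split(rng)
timestep = env.reset(env_cfgs, jax.random.split(_rng, n_envs))

# Step the whole batch with a "do nothing" action, as main_manual.py does for a single env.
action_type = BatchConfig().action_type
actions = action_type.new(action_type.do_nothing().action[None].repeat(n_envs, 0))
rng, _rng = jax.random.split(rng)
timestep = env.step(timestep, actions, jax.random.split(_rng, n_envs))
```

Keeping one `EnvConfig` and one key per environment is what allows `reset` and `step` to be vectorized over the whole batch.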
/assets/scaling-envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/assets/scaling-envs.png -------------------------------------------------------------------------------- /assets/scaling-devices.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/assets/scaling-devices.png -------------------------------------------------------------------------------- /data/custom/images/custom_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_1.png -------------------------------------------------------------------------------- /data/custom/images/custom_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_2.png -------------------------------------------------------------------------------- /data/custom/images/custom_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_3.png -------------------------------------------------------------------------------- /data/custom/images/custom_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_4.png -------------------------------------------------------------------------------- /data/custom/images/custom_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_5.png -------------------------------------------------------------------------------- /data/custom/images/custom_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_6.png -------------------------------------------------------------------------------- /data/custom/images/custom_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_7.png -------------------------------------------------------------------------------- /data/custom/images/custom_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_8.png -------------------------------------------------------------------------------- /data/custom/images/custom_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_9.png -------------------------------------------------------------------------------- /data/custom/images/custom_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_10.png -------------------------------------------------------------------------------- /data/custom/images/custom_11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_11.png -------------------------------------------------------------------------------- /data/custom/images/custom_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/images/custom_12.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_1.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_2.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_3.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_4.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_5.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_6.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_7.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_8.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_9.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_1.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_10.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_10.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_11.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_12.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_2.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_3.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_4.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_5.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_6.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_7.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_8.png -------------------------------------------------------------------------------- /data/custom/occupancy/custom_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/occupancy/custom_9.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E203, W, PAI 3 | per-file-ignores = 4 | terra/noise/simplex_noise.py: F401, E741, E731 5 | -------------------------------------------------------------------------------- /data/custom/dumpability/custom_10.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_10.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_11.png -------------------------------------------------------------------------------- /data/custom/dumpability/custom_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra/HEAD/data/custom/dumpability/custom_12.png -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.12 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | chex 4 | tqdm 5 | flax 6 | matplotlib 7 | pygame 8 | wandb 9 | tensorflow_probability 10 | osmnx 11 | opencv-python 12 | pathlib 13 | scikit-image -------------------------------------------------------------------------------- /terra/settings.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | IntMap = jnp.int16 4 | INTMAP_MAX = jnp.iinfo(IntMap).max 5 | 6 | IntLowDim = jnp.int8 7 | INTLOWDIM_MAX = jnp.iinfo(IntLowDim).max 8 | 9 | Float = jnp.float32 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py311'] # Adjust this line to match your target Python version 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | .git 8 | | .venv 9 | | __pycache__ 10 | )/ 11 | ''' 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/* 2 | tests/unit/__pycache__/* 3 | .ipynb_checkpoints/* 4 | .vscode/* 5 | .others/* 6 | .coverage 7 | __pycache__/* 8 | *.egg* 9 | terra/map_utils/__pycache__/ 10 | cache/ 11 | docs/ 12 | terra/viz/game/__pycache__/* 13 | terra/viz/__pycache__/* 14 | viz/ 15 | terra/env_generation/__pycache__/ 16 | terra/digbench/ 17 | data/* 18 | !data/custom/ 19 | !data/custom/** 20 | *.pkl -------------------------------------------------------------------------------- /terra/viz/game/settings.py: -------------------------------------------------------------------------------- 1 | MAP_TILES = ( 2 | 192 # 64 * 3, total number of tiles for a nice visualization (scales with MAP_EDGE) 3 | ) 4 | COLORS = { 5 | 0: "#cfcfcf", # neutral 6 | 5: "#E4DCCF", # final dumping area to terminate the episode 7 | 3: "#ab9f95", # non-dumpable (e.g. 
road) 8 | 4: "#8800ff", # to dig 9 | 2: "#000000", # obstacle 10 | 1: "#002B5B", # action map dump 11 | -1: "#26bd6c", # action map dug 12 | "agent_body": (0, 43, 91), 13 | "agent_cabin": { 14 | "loaded": (165, 115, 75), 15 | "not_loaded": (234, 84, 85), 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # Specifying all dependencies, including direct and indirect, for clarity 4 | requires = [ 5 | "jax", 6 | "jaxlib", 7 | "chex", 8 | "tqdm", 9 | "flax", 10 | "matplotlib", 11 | "pygame", 12 | "wandb", 13 | "tensorflow_probability", 14 | "osmnx", 15 | "opencv-python", 16 | "scikit-image", 17 | ] 18 | 19 | setup( 20 | name="terra", 21 | version="0.0.1", 22 | keywords="memory, environment, agent, rl, jax, gym, grid, gridworld, excavator", 23 | description="Minimalistic grid map environment built with JAX", 24 | packages=find_packages(), 25 | install_requires=requires, 26 | python_requires=">=3.10", 27 | ) 28 | -------------------------------------------------------------------------------- /terra/map_utils/test_openstreet_generator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def _convert_terra_img_to_cv2(img): 6 | img = img.astype(np.int8) 7 | img = np.where(img == 0, 255, img) 8 | img = np.where(img == -1, 0, img) 9 | img = np.where(img == 1, 100, img) 10 | img = img[..., None].repeat(3, -1) 11 | return img.astype(np.uint8) 12 | 13 | 14 | if __name__ == "__main__": 15 | img_idx = 1 16 | path = f"/home/antonio/Downloads/img_generator/3_buildings/60x60/img_{img_idx}.npy" 17 | img = np.load(path) 18 | print(img) 19 | print(img.shape) 20 | 21 | print((img == 0).sum()) 22 | img = _convert_terra_img_to_cv2(img) 23 | print((img == 255).sum()) 24 | cv2.imshow("img", img.repeat(10, 0).repeat(10, 1)) 25 | cv2.waitKey(0) 26 | cv2.destroyAllWindows() 27 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable= 3 | missing-docstring, 4 | too-few-public-methods, 5 | fixme, 6 | broad-except, 7 | too-many-instance-attributes, 8 | too-many-arguments, 9 | raise-missing-from, 10 | no-self-use, 11 | arguments-renamed 12 | iterating-dictionary 13 | max-public-method=20 14 | 15 | 16 | [REPORTS] 17 | output-format=colorized 18 | files-output=no 19 | reports=no 20 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 21 | 22 | 23 | [MISCELLANEOUS] 24 | notes=FIXME,TODO 25 | 26 | 27 | [FORMAT] 28 | max-line-length=120 29 | max-module-lines=1500 30 | 31 | 32 | [SIMILARITIES] 33 | min-similarity-lines=6 34 | ignore-comments=yes 35 | ignore-docstrings=yes 36 | 37 | 38 | [ELIF] 39 | max-nested-blocks=6 40 | 41 | 42 | [DESIGN] 43 | max-complexity=10 44 | max-args=12 45 | ignored-argument-names=_.* 46 | max-locals=25 47 | max-parents=15 48 | max-attributes=10 49 | min-public-methods=0 50 | max-public-methods=15 51 | max-bool-expr=5 52 | 53 | 54 | [EXCEPTIONS] 55 | overgeneral-exceptions=Exception 56 | -------------------------------------------------------------------------------- /terra/map_generator.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax.numpy as jnp 4 
| from jax import Array 5 | 6 | from terra.settings import IntMap 7 | 8 | 9 | class MapParams(NamedTuple): 10 | edge_min: IntMap 11 | edge_max: IntMap 12 | depth: int = -1 13 | 14 | 15 | class GridMap(NamedTuple): 16 | """ 17 | Clarifications on the map representation. 18 | 19 | The x axis corresponds to the first dimension of the map matrix. 20 | The y axis to the second. 21 | The origin is on the top left corner of the map matrix. 22 | 23 | The term "width" is associated with the x direction. 24 | The term "height" is associated with the y direction. 25 | """ 26 | 27 | map: IntMap 28 | 29 | @property 30 | def width(self) -> int: 31 | return self.map.shape[0] 32 | 33 | @property 34 | def height(self) -> int: 35 | return self.map.shape[1] 36 | 37 | @staticmethod 38 | def new(map: Array) -> "GridMap": 39 | assert len(map.shape) == 2 40 | 41 | return GridMap(map=map) 42 | 43 | @staticmethod 44 | def dummy_map() -> "GridMap": 45 | return GridMap.new(jnp.full((1, 1), fill_value=0, dtype=jnp.bool_)) 46 | -------------------------------------------------------------------------------- /terra/viz/game/utils.py: -------------------------------------------------------------------------------- 1 | import pygame as pg 2 | 3 | 4 | def rotate_triangle(center, points, scale, angle): 5 | vCenter = pg.math.Vector2(center) 6 | 7 | rotated_point = [pg.math.Vector2(p).rotate(angle) for p in points] 8 | 9 | triangle_points = [(vCenter + p * scale) for p in rotated_point] 10 | return triangle_points 11 | 12 | 13 | def validate_angle_divisions(angle_count): 14 | if 360 % angle_count != 0: 15 | raise ValueError(f"Angle count must divide 360 evenly, got {angle_count}") 16 | if angle_count % 4 != 0: 17 | raise ValueError(f"Angle count must be a multiple of 4, got {angle_count}") 18 | 19 | 20 | def agent_base_to_angle(agent_base, base_angles=8): 21 | """ 22 | Args: 23 | agent_base: The index of the base direction 24 | base_angles: Total number of possible base directions (default: 8) 25 | 26 | Returns: 27 | The angle in degrees 28 | """ 29 | 30 | validate_angle_divisions(base_angles) 31 | angle_increment = 360 / base_angles 32 | return (360 - (agent_base * angle_increment)) % 360 33 | 34 | 35 | def agent_cabin_to_angle(agent_cabin, cabin_angles=8): 36 | """ 37 | Args: 38 | agent_cabin: The index of the cabin direction 39 | cabin_angles: Total number of possible cabin directions (default: 8) 40 | 41 | Returns: 42 | The angle in degrees 43 | """ 44 | validate_angle_divisions(cabin_angles) 45 | angle_increment = 360 / cabin_angles 46 | return (360 - (agent_cabin * angle_increment)) % 360 47 | -------------------------------------------------------------------------------- /terra/viz/play.py: -------------------------------------------------------------------------------- 1 | import time 2 | import jax 3 | import jax.numpy as jnp 4 | import pygame as pg 5 | from pygame.locals import ( 6 | KEYDOWN, 7 | QUIT, 8 | ) 9 | from terra.config import EnvConfig 10 | from terra.env import TerraEnvBatch 11 | 12 | 13 | def main(): 14 | n_envs_x = 4 15 | n_envs_y = 10 16 | n_envs = n_envs_x * n_envs_y 17 | seed = 24 18 | rng = jax.random.PRNGKey(seed) 19 | shuffle_maps = True 20 | env = TerraEnvBatch( 21 | rendering=True, 22 | display=True, 23 | n_envs_x_rendering=n_envs_x, 24 | n_envs_y_rendering=n_envs_y, 25 | shuffle_maps=shuffle_maps, 26 | ) 27 | 28 | print("Starting the environment...") 29 | start_time = time.time() 30 | env_cfgs = jax.vmap(lambda x: EnvConfig.new())(jnp.arange(n_envs)) 31 | rng, _rng = jax.random.split(rng) 32 
| _rng = jax.random.split(_rng, n_envs) 33 | timestep = env.reset(env_cfgs, _rng) 34 | end_time = time.time() 35 | print(f"Environment started. Compilation time: {end_time - start_time} seconds.") 36 | print("Press any key to query the next set of environments.") 37 | playing = True 38 | while playing: 39 | for event in pg.event.get(): 40 | if event.type == KEYDOWN: 41 | rng, _rng = jax.random.split(rng) 42 | _rng = jax.random.split(_rng, n_envs) 43 | timestep = env.reset(env_cfgs, _rng) 44 | 45 | elif event.type == QUIT: 46 | playing = False 47 | 48 | env.terra_env.render_obs_pygame( 49 | timestep.observation, 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /config/env_generation_config.yaml: -------------------------------------------------------------------------------- 1 | resolution: 1 # Resolution of the images in meters per pixel 2 | n_imgs: 2000 # it has to be the same else we can't stack images together 3 | sizes: [64] 4 | 5 | trenches: 6 | difficulty_levels: ["single", "double", "double_diagonal", "triple", "triple_diagonal"] 7 | trenches_per_level: [[1, 1], [2, 2], [2, 2], [3, 3], [3, 3]] 8 | trench_dims: # Updated section 9 | single: 10 | min_ratio: [0.04, 0.05] 11 | max_ratio: [0.2, 0.3] 12 | diagonal: False 13 | double: 14 | min_ratio: [0.06, 0.08] 15 | max_ratio: [0.25, 0.5] 16 | diagonal: False 17 | double_diagonal: 18 | min_ratio: [0.06, 0.08] 19 | max_ratio: [0.25, 0.5] 20 | diagonal: True 21 | triple: 22 | min_ratio: [0.06, 0.08] 23 | max_ratio: [0.25, 0.5] 24 | diagonal: False 25 | triple_diagonal: 26 | min_ratio: [0.06, 0.08] 27 | max_ratio: [0.25, 0.5] 28 | diagonal: True 29 | img_edge_min: 64 30 | img_edge_max: 64 31 | # obstacles 32 | n_obs_min: 1 33 | n_obs_max: 3 34 | size_obstacle_min: 5 35 | size_obstacle_max: 9 36 | # dumping constraints 37 | n_nodump_min: 1 38 | n_nodump_max: 3 39 | size_nodump_min: 8 40 | size_nodump_max: 12 41 | 42 | foundations: 43 | dataset_rel_path: "data/openstreet/" 44 | min_size: 8 45 | max_size: 64 46 | max_buildings: 3000 47 | 48 | relocations: 49 | img_edge_min: 64 50 | img_edge_max: 64 51 | n_dump_min: 1 52 | n_dump_max: 3 53 | size_dump_min: 10 54 | size_dump_max: 16 55 | n_obs_min: 1 56 | n_obs_max: 3 57 | size_obstacle_min: 5 58 | size_obstacle_max: 9 59 | n_nodump_min: 1 60 | n_nodump_max: 2 61 | size_nodump_min: 7 62 | size_nodump_max: 10 63 | n_dirt_min: 1 64 | n_dirt_max: 3 65 | size_dirt_min: 6 66 | size_dirt_max: 10 67 | -------------------------------------------------------------------------------- /benchmark/plot_benchmark.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | if __name__ == "__main__": 5 | policy = "random" 6 | n_steps_per_env = 100 7 | gpu_path = "/home/antonio/thesis/benchmarks/terra/Benchmark_random_gpu:0_2023_04_20_14_50_44.csv" 8 | cpu_path = "/home/antonio/thesis/benchmarks/terra/Benchmark_random_TFRT_CPU_0_2023_04_20_14_51_15.csv" 9 | 10 | # config_path = f"/media/benchmarks/terra/Config_{policy}_{device}_{now}.csv" 11 | 12 | gpu_df = pd.read_csv(gpu_path).to_dict() 13 | cpu_df = pd.read_csv(cpu_path).to_dict() 14 | 15 | batch_sizes = gpu_df["batch_size"].values() 16 | gpu_times = gpu_df["time"] 17 | cpu_times = cpu_df["time"] 18 | 19 | gpu_step_env = gpu_df["avg_step_env"] 20 | cpu_step_env = cpu_df["avg_step_env"] 21 | 22 | # Times 23 | fig = plt.figure(0) 24 | plt.scatter(batch_sizes, 
gpu_times.values(), c="g", label="gpu") 25 | plt.plot(batch_sizes, gpu_times.values(), c="g", label="gpu") 26 | plt.scatter(batch_sizes, cpu_times.values(), c="r", label="cpu") 27 | plt.plot(batch_sizes, cpu_times.values(), c="r", label="cpu") 28 | plt.xlabel("number of environments") 29 | plt.ylabel("duration (s)") 30 | plt.title(f"Terra - {n_steps_per_env} steps per environment - {policy} policy") 31 | plt.legend() 32 | plt.xscale("log") 33 | plt.yscale("log") 34 | 35 | fig = plt.figure(1) 36 | plt.scatter(batch_sizes, gpu_step_env.values(), c="g", label="gpu") 37 | plt.plot(batch_sizes, gpu_step_env.values(), c="g", label="gpu") 38 | plt.scatter(batch_sizes, cpu_step_env.values(), c="r", label="cpu") 39 | plt.plot(batch_sizes, cpu_step_env.values(), c="r", label="cpu") 40 | plt.xlabel("number of environments") 41 | plt.ylabel("avg step duration (s)") 42 | plt.title(f"Terra - {n_steps_per_env} steps per environment - {policy} policy") 43 | plt.legend() 44 | plt.xscale("log") 45 | plt.yscale("log") 46 | 47 | plt.show() 48 | -------------------------------------------------------------------------------- /terra/env_generation/visualize_maps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | 6 | 7 | def visualize_and_save_map(data, folder, filename, title): 8 | plt.figure(figsize=(6, 6)) 9 | plt.imshow(data, cmap="viridis", interpolation="nearest") 10 | plt.title(title) 11 | plt.axis("off") 12 | plt.savefig(folder / f"{filename}.jpg", bbox_inches="tight") 13 | plt.close() 14 | 15 | 16 | def visualize_maps_recursive( 17 | base_folder, map_categories=["images", "occupancy", "dumpability"] 18 | ): 19 | base_folder = Path(base_folder) 20 | if not base_folder.exists(): 21 | print(f"No folder found for {base_folder}, skipping.") 22 | return 23 | 24 | for category in map_categories: 25 | category_folder = base_folder / category 26 | if category_folder.exists(): 27 | image_output_folder = category_folder / "visualized" 28 | image_output_folder.mkdir(parents=True, exist_ok=True) 29 | 30 | npy_files = list(category_folder.glob("*.npy")) 31 | for npy_file in tqdm( 32 | npy_files, desc=f"Processing {category} in {base_folder.name}" 33 | ): 34 | map_data = np.load(npy_file) 35 | print(map_data.shape) 36 | filename = npy_file.stem # Removes the file extension 37 | visualize_and_save_map( 38 | map_data, 39 | image_output_folder, 40 | filename, 41 | f"{category.capitalize()}: {filename}", 42 | ) 43 | 44 | print( 45 | f"Visualization complete for {category}. 
Images saved in {image_output_folder}" 46 | ) 47 | else: 48 | # If the current category folder doesn't exist, check for subdirectories to recurse into 49 | for subfolder in base_folder.iterdir(): 50 | if subfolder.is_dir(): 51 | visualize_maps_recursive(subfolder, map_categories) 52 | 53 | 54 | if __name__ == "__main__": 55 | digbench_path = Path(__file__).resolve().parents[1] 56 | visualize_maps_recursive( 57 | "/home/lorenzo/git/terra_jax/terra/data/terra/train/foundations", 58 | map_categories=["occupancy"], 59 | ) 60 | -------------------------------------------------------------------------------- /benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | from time import gmtime 3 | from time import strftime 4 | 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from terra.actions import TrackedActionType 10 | from terra.config import EnvConfig 11 | from terra.env import TerraEnvBatch 12 | 13 | if __name__ == "__main__": 14 | policy = "random" 15 | 16 | device = jnp.ones(1).device_buffer.device() 17 | print(f"Device = {device}\n") 18 | 19 | now = strftime("%Y_%m_%d_%H_%M_%S", gmtime()) 20 | benchmark_path = f"/media/benchmarks/terra/Benchmark_{policy}_{device}_{now}.csv" 21 | config_path = f"/media/benchmarks/terra/Config_{policy}_{device}_{now}.csv" 22 | 23 | batch_sizes = [1e3, 1e4, 1e5, 1e6] 24 | episode_length = 100 25 | 26 | benchmark_dict = { 27 | "batch_size": [], 28 | "time": [], 29 | "avg_step_batch": [], 30 | "avg_step_env": [], 31 | } 32 | for batch_size in batch_sizes: 33 | batch_size = int(batch_size) 34 | print("\n") 35 | print(f"{batch_size=}") 36 | print(f"{episode_length=}") 37 | 38 | seeds = np.random.randint(0, 1000000, (batch_size)) 39 | env_batch = TerraEnvBatch(env_cfg=EnvConfig()) 40 | states = env_batch.reset(seeds) 41 | 42 | duration = 0 43 | for i in range(episode_length): 44 | actions = np.random.randint( 45 | TrackedActionType.FORWARD, TrackedActionType.DO + 1, (batch_size) 46 | ) 47 | s = time.time() 48 | _, (states, reward, dones, infos) = env_batch.step(states, actions) 49 | e = time.time() 50 | duration += e - s 51 | 52 | benchmark_dict["batch_size"].append(batch_size) 53 | benchmark_dict["time"].append(duration) 54 | benchmark_dict["avg_step_batch"].append((duration) / episode_length) 55 | benchmark_dict["avg_step_env"].append( 56 | (duration) / (episode_length * batch_size) 57 | ) 58 | 59 | print(f"Duration = {benchmark_dict['time'][-1]}") 60 | print( 61 | f"Average step duration per batch = {benchmark_dict['avg_step_batch'][-1]}" 62 | ) 63 | print( 64 | f"Average step duration per environment = {benchmark_dict['avg_step_env'][-1]}" 65 | ) 66 | 67 | benchmark_df = pd.DataFrame(benchmark_dict) 68 | benchmark_df.to_csv(benchmark_path, index=False) 69 | -------------------------------------------------------------------------------- /terra/map_utils/test_openstreet_plugin.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import cv2 4 | import numpy as np 5 | import skimage.measure 6 | 7 | wm, hm = 60, 60 # meters 8 | 9 | div = 10 10 | w, h = div * wm, div * hm 11 | 12 | for i in range(2700): 13 | try: 14 | with open( 15 | f"/home/antonio/Downloads/openstreet_v1/metadata/building_{i}.json" 16 | ) as f: 17 | meta1 = json.load(f) 18 | w1 = int(meta1["real_dimensions"]["height"]) 19 | h1 = int(meta1["real_dimensions"]["width"]) 20 | if w1 < 40 and h1 < 40: 21 | break 22 | except: 23 | continue 24 | 25 | 
for j in range(i + 1, 2700): 26 | try: 27 | with open( 28 | f"/home/antonio/Downloads/openstreet_v1/metadata/building_{j}.json" 29 | ) as f: 30 | meta2 = json.load(f) 31 | w2 = int(meta2["real_dimensions"]["height"]) 32 | h2 = int(meta2["real_dimensions"]["width"]) 33 | if w2 < 40 and h2 < 40: 34 | break 35 | except: 36 | continue 37 | 38 | print(f"{(w1, h1)=}") 39 | print(f"{(w2, h2)=}") 40 | 41 | building_img_path1 = f"/home/antonio/Downloads/openstreet_v1/images/building_{i}.png" 42 | building_img_path2 = f"/home/antonio/Downloads/openstreet_v1/images/building_{j}.png" 43 | 44 | 45 | pic1 = cv2.imread(building_img_path1) 46 | pic2 = cv2.imread(building_img_path2) 47 | print(f"{pic1.shape=}") 48 | print(f"{pic2.shape=}") 49 | 50 | p = np.ones((w, h, 3)) * 255 51 | 52 | pic1 = cv2.resize(pic1, (w1 * div, h1 * div)).astype( 53 | np.uint8 54 | ) # , interpolation=cv2.INTER_AREA) 55 | pic2 = cv2.resize(pic2, (w2 * div, h2 * div)).astype( 56 | np.uint8 57 | ) # , interpolation=cv2.INTER_AREA) 58 | 59 | 60 | print(f"{pic1.shape=}") 61 | print(f"{pic2.shape=}") 62 | 63 | p[: pic1.shape[0], : pic1.shape[1]] = pic1 64 | p[-pic2.shape[0] :, -pic2.shape[1] :] = pic2 65 | 66 | # pd = cv2.resize(p, (40, 40)).repeat(10, 0).repeat(10, 1) 67 | pd = ( 68 | skimage.measure.block_reduce(p, (p.shape[0] // wm, p.shape[1] // hm, 1), np.min) 69 | .astype(np.uint8) 70 | .repeat(10, 0) 71 | .repeat(10, 1) 72 | ) 73 | print(f"{pd.shape=}") 74 | 75 | pd = np.where(pd == 255, 255, 0).astype(np.uint8) 76 | p_grey = np.where(p < 255, 100, p).astype(np.uint8) 77 | # p += pic1 + pic2 78 | 79 | # cv2.imshow("building1", pic1) 80 | # cv2.imshow("building2", pic2) 81 | # cv2.imshow("buildings", p) 82 | # cv2.imshow("building downsampled", pd) 83 | cv2.imshow("building comparison", np.where(p_grey == 100, 100, pd).astype(np.uint8)) 84 | cv2.waitKey(0) 85 | cv2.destroyAllWindows() 86 | -------------------------------------------------------------------------------- /terra/viz/main_manual.py: -------------------------------------------------------------------------------- 1 | import time 2 | import jax 3 | import jax.numpy as jnp 4 | import pygame as pg 5 | from pygame.locals import ( 6 | K_UP, 7 | K_DOWN, 8 | K_LEFT, 9 | K_RIGHT, 10 | K_a, 11 | K_d, 12 | K_k, 13 | K_l, 14 | K_SPACE, 15 | KEYDOWN, 16 | QUIT, 17 | ) 18 | from terra.config import BatchConfig 19 | from terra.config import EnvConfig 20 | from terra.env import TerraEnvBatch 21 | 22 | 23 | def main(): 24 | batch_cfg = BatchConfig() 25 | action_type = batch_cfg.action_type 26 | n_envs_x = 1 27 | n_envs_y = 1 28 | n_envs = n_envs_x * n_envs_y 29 | seed = 24 30 | rng = jax.random.PRNGKey(seed) 31 | env = TerraEnvBatch( 32 | rendering=True, 33 | display=True, 34 | n_envs_x_rendering=n_envs_x, 35 | n_envs_y_rendering=n_envs_y, 36 | ) 37 | 38 | print("Starting the environment...") 39 | start_time = time.time() 40 | env_cfgs = jax.vmap(lambda x: EnvConfig.new())(jnp.arange(n_envs)) 41 | rng, _rng = jax.random.split(rng) 42 | _rng = _rng[None] 43 | timestep = env.reset(env_cfgs, _rng) 44 | print(f"{timestep.state.agent.width=}") 45 | print(f"{timestep.state.agent.height=}") 46 | 47 | rng, _rng = jax.random.split(rng) 48 | _rng = _rng[None] 49 | 50 | def repeat_action(action, n_times=n_envs): 51 | return action_type.new(action.action[None].repeat(n_times, 0)) 52 | 53 | # Trigger the JIT compilation 54 | timestep = env.step(timestep, repeat_action(action_type.do_nothing()), _rng) 55 | end_time = time.time() 56 | print(f"Environment started. 
Compilation time: {end_time - start_time} seconds.") 57 | 58 | playing = True 59 | while playing: 60 | for event in pg.event.get(): 61 | if event.type == KEYDOWN: 62 | action = None 63 | if event.key == K_UP: 64 | action = action_type.forward() 65 | elif event.key == K_DOWN: 66 | action = action_type.backward() 67 | elif event.key == K_LEFT: 68 | action = action_type.anticlock() 69 | elif event.key == K_RIGHT: 70 | action = action_type.clock() 71 | elif event.key == K_a: 72 | action = action_type.cabin_anticlock() 73 | elif event.key == K_d: 74 | action = action_type.cabin_clock() 75 | elif event.key == K_k: 76 | action = action_type.wheels_left() 77 | elif event.key == K_l: 78 | action = action_type.wheels_right() 79 | elif event.key == K_SPACE: 80 | action = action_type.do() 81 | 82 | if action is not None: 83 | print("Action: ", action) 84 | rng, _rng = jax.random.split(rng) 85 | _rng = _rng[None] 86 | timestep = env.step( 87 | timestep, 88 | repeat_action(action), 89 | _rng, 90 | ) 91 | print("Reward: ", timestep.reward.item()) 92 | 93 | elif event.type == QUIT: 94 | playing = False 95 | 96 | env.terra_env.render_obs_pygame( 97 | timestep.observation, 98 | timestep.info, 99 | ) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /terra/viz/game/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from .settings import COLORS 4 | from .utils import agent_base_to_angle 5 | from .utils import agent_cabin_to_angle 6 | from .utils import rotate_triangle 7 | 8 | 9 | class Agent: 10 | def __init__(self, width, height, tile_size, angles_base, angles_cabin) -> None: 11 | self.width = width 12 | self.height = height 13 | self.tile_size = tile_size 14 | self.angles_base = angles_base 15 | self.angles_cabin = angles_cabin 16 | 17 | def create_agent(self, px_center, py_center, angle_base, angle_cabin, loaded): 18 | # Convert angle_base index to degrees using the util function 19 | base_angle_degrees = agent_base_to_angle(angle_base, self.angles_base) 20 | 21 | # Calculate center position in pixels 22 | center_x = px_center * self.tile_size 23 | center_y = py_center * self.tile_size 24 | 25 | # Calculate half-dimensions for the body 26 | half_width = (self.width * self.tile_size) / 2 27 | half_height = (self.height * self.tile_size) / 2 28 | 29 | # Create the four corners of the rectangle (unrotated) 30 | rect_points = [ 31 | (-half_width, -half_height), # top-left 32 | (half_width, -half_height), # top-right 33 | (half_width, half_height), # bottom-right 34 | (-half_width, half_height) # bottom-left 35 | ] 36 | 37 | # Rotate and translate each point 38 | agent_body = [] 39 | for x, y in rect_points: 40 | # Convert to radians for trigonometric calculations 41 | angle_rad = math.radians(base_angle_degrees) 42 | # Apply rotation 43 | rotated_x = x * math.cos(angle_rad) - y * math.sin(angle_rad) 44 | rotated_y = x * math.sin(angle_rad) + y * math.cos(angle_rad) 45 | # Translate to actual position 46 | agent_body.append((center_y + rotated_x, center_x + rotated_y)) 47 | 48 | # Use the actual dimensions for the agent 49 | w = self.width 50 | h = self.height 51 | 52 | # Calculate cabin angle (global orientation) 53 | cabin_relative_degrees = agent_cabin_to_angle(angle_cabin, self.angles_cabin) 54 | global_cabin_angle = (cabin_relative_degrees + base_angle_degrees) % 360 55 | 56 | # Create cabin triangle 57 | scaling = self.tile_size / 
3 58 | points = [ 59 | (3 / scaling, 0), 60 | (-1.5 / scaling, -1.5 / scaling), 61 | (-1.5 / scaling, 1.5 / scaling), 62 | ] 63 | agent_cabin = rotate_triangle( 64 | (center_y, center_x), points, self.tile_size, global_cabin_angle 65 | ) 66 | 67 | out = { 68 | "body": { 69 | "vertices": agent_body, 70 | "width": w, 71 | "height": h, 72 | "color": COLORS["agent_body"], 73 | }, 74 | "cabin": { 75 | "vertices": agent_cabin, 76 | "color": COLORS["agent_cabin"]["loaded"] 77 | if loaded 78 | else COLORS["agent_cabin"]["not_loaded"], 79 | }, 80 | } 81 | return out 82 | 83 | def update(self, agent_pos, base_dir, cabin_dir, loaded): 84 | agent_pos = np.asarray(agent_pos, dtype=np.int32) 85 | base_dir = np.asarray(base_dir, dtype=np.int32) 86 | cabin_dir = np.asarray(cabin_dir, dtype=np.int32) 87 | loaded = np.asarray(loaded, dtype=bool) 88 | self.agent = self.create_agent( 89 | agent_pos[0].item(), 90 | agent_pos[1].item(), 91 | base_dir.item(), 92 | cabin_dir.item(), 93 | loaded.item(), 94 | ) 95 | -------------------------------------------------------------------------------- /terra/map_utils/jax_terrain_generation.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.scipy.signal import convolve2d 4 | from terra.settings import IntMap 5 | 6 | 7 | def generate_clustered_bitmap( 8 | width: int, 9 | height: int, 10 | n_clusters: int, 11 | n_tiles_per_cluster: int, 12 | kernel_size_aggregation: int, 13 | kernel_size_initial_sampling: int, 14 | key: int, 15 | placeholder=None, 16 | ): 17 | kernel_init = jnp.ones((kernel_size_initial_sampling, kernel_size_initial_sampling)) 18 | 19 | def _loop_init(i, carry): 20 | """ 21 | Init the map by sampling spaced tiles based on the position 22 | of the previously sampled ones 23 | (low chance to sample neighbouring tiles). 
24 | """ 25 | key, map, set_value = carry 26 | 27 | mask = (map != 0).astype(IntMap) 28 | mask_convolved = convolve2d(mask, kernel_init, mode="same", boundary="fill") 29 | mask_convolved_opposite = (mask_convolved <= mask_convolved.min()).astype( 30 | IntMap 31 | ) 32 | p_sampling = mask_convolved_opposite / mask_convolved_opposite.sum() 33 | 34 | key, subkey = jax.random.split(key) 35 | idx = jax.random.choice( 36 | subkey, jnp.arange(start=0, stop=width * height), p=p_sampling.reshape(-1) 37 | ) 38 | map = map.reshape(-1).at[idx].set(set_value).reshape(width, height) 39 | 40 | set_value *= -1 41 | carry = key, map, set_value 42 | return carry 43 | 44 | carry = key, jnp.zeros((width, height), dtype=IntMap), -1 45 | carry = jax.lax.fori_loop( 46 | lower=0, upper=n_clusters, body_fun=_loop_init, init_val=carry 47 | ) 48 | key, map, _ = carry 49 | 50 | def _loop(i, carry): 51 | key, map, set_value = carry 52 | 53 | mask = jax.lax.cond( 54 | set_value < 0, 55 | lambda: (map < 0).astype(IntMap), 56 | lambda: (map > 0).astype(IntMap), 57 | ) 58 | 59 | mask_probs = convolve2d(mask, kernel, mode="same", boundary="fill") 60 | mask_probs = mask_probs * (~(mask).astype(jnp.bool_)) 61 | mask_probs = mask_probs / mask_probs.sum() 62 | 63 | key, subkey = jax.random.split(key) 64 | # 75% random sample, 25% argmax 65 | do_random_sample = jax.random.randint(subkey, (), 0, 4).astype(jnp.bool_) 66 | 67 | def _random_sample(key, map): 68 | key, *subkeys = jax.random.split(key, 3) 69 | next_tile_idx = jax.random.choice( 70 | subkeys[1], jnp.arange(0, width * height), p=mask_probs.reshape(-1) 71 | ) 72 | map = ( 73 | map.reshape(-1).at[next_tile_idx].set(set_value).reshape(width, height) 74 | ) 75 | return key, map 76 | 77 | def _argmax(key, map): 78 | next_tile_idx = jnp.argmax(mask_probs.reshape(-1)) 79 | map = ( 80 | map.reshape(-1).at[next_tile_idx].set(set_value).reshape(width, height) 81 | ) 82 | return key, map 83 | 84 | key, map = jax.lax.cond(do_random_sample, _random_sample, _argmax, key, map) 85 | 86 | set_value *= -1 87 | carry = key, map, set_value 88 | return carry 89 | 90 | kernel = jnp.ones((kernel_size_aggregation, kernel_size_aggregation)) 91 | 92 | carry = key, map, -1 93 | carry = jax.lax.fori_loop(0, n_tiles_per_cluster * n_clusters, _loop, carry) 94 | key, map, _ = carry 95 | return map, key 96 | 97 | 98 | if __name__ == "__main__": 99 | key = jax.random.PRNGKey(131) 100 | map, key = generate_clustered_bitmap( 101 | 10, 102 | 10, 103 | 4, 104 | 3, 105 | 5, 106 | key, 107 | ) 108 | import numpy as np 109 | 110 | # import cv2 111 | map = np.array(map) 112 | print(f"{map=}") 113 | -------------------------------------------------------------------------------- /terra/env_generation/generate_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import yaml 4 | import argparse 5 | from terra.env_generation.generate_foundations import download_foundations, create_foundations 6 | from terra.env_generation.create_train_data import ( 7 | create_procedural_trenches, 8 | create_foundations as create_train_foundations 9 | ) 10 | from terra.env_generation.generate_relocations import create_relocations 11 | import terra.env_generation.convert_to_terra as convert_to_terra 12 | 13 | def generate_complete_dataset(config_path="config/env_generation/config.yml"): 14 | """ 15 | Generate a complete dataset in one go - combining foundations generation and training data creation. 
16 | 17 | Args: 18 | config_path: Path to the configuration file 19 | """ 20 | # Get the package directory 21 | package_dir = os.path.dirname( 22 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 23 | ) 24 | 25 | # Load configuration 26 | with open(package_dir + "/" + config_path, "r") as file: 27 | config = yaml.safe_load(file) 28 | 29 | # Create necessary directories 30 | os.makedirs("data/", exist_ok=True) 31 | os.makedirs("data/terra/", exist_ok=True) 32 | os.makedirs("data/openstreet/", exist_ok=True) 33 | 34 | n_imgs = config["n_imgs"] 35 | 36 | print("Step 1: Downloading and processing foundation maps...") 37 | # Download foundations 38 | # Read foundation parameters from the config file 39 | foundations_config = config.get("foundations", {}) 40 | # Backup to using sizes list if not provided in foundations config 41 | if "min_size" in foundations_config and "max_size" in foundations_config: 42 | foundation_min_size = foundations_config.get("min_size") 43 | foundation_max_size = foundations_config.get("max_size") 44 | else: 45 | raise ValueError("min_size and max_size must be provided in the config file") 46 | max_buildings = foundations_config.get("max_buildings", 100) 47 | print(f"max_buildings: {max_buildings}") 48 | 49 | print(f"Foundation min_size: {foundation_min_size}, max_size: {foundation_max_size}, max_buildings: {max_buildings}") 50 | 51 | # Get bounding box from config, or use default 52 | bbox = config.get("center_bbox", (47.5376, 47.6126, 7.5401, 7.6842)) 53 | 54 | # Download foundations 55 | dataset_folder = os.path.join(package_dir, "data", "openstreet") 56 | download_foundations( 57 | dataset_folder, 58 | min_size=(foundation_min_size, foundation_min_size), 59 | max_size=(foundation_max_size, foundation_max_size), 60 | center_bbox=bbox, 61 | max_buildings=max_buildings 62 | ) 63 | create_foundations(dataset_folder) 64 | 65 | print("Step 2: Creating procedural trenches and processing training data...") 66 | # Create procedural trenches 67 | create_procedural_trenches(config) 68 | 69 | # Process foundation maps for training 70 | create_train_foundations(config) 71 | 72 | # Create relocations maps 73 | relocations_config = config.get("relocations", {}) 74 | create_relocations(relocations_config, n_imgs) 75 | 76 | # Generate Terra format datasets 77 | print("Step 3: Converting data to Terra format...") 78 | sizes = [(size, size) for size in config["sizes"]] 79 | npy_dataset_folder = package_dir + "/data/terra" 80 | for size in sizes: 81 | convert_to_terra.generate_dataset_terra_format(npy_dataset_folder, size, n_imgs) 82 | 83 | print("Dataset generation complete!") 84 | print(f"Data saved to {os.path.join(package_dir, 'data/terra')}") 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser(description="Generate complete Terra training dataset.") 88 | parser.add_argument( 89 | "--config", 90 | type=str, 91 | default="config/env_generation_config.yaml", 92 | help="Path to the configuration file" 93 | ) 94 | args = parser.parse_args() 95 | 96 | generate_complete_dataset(args.config) -------------------------------------------------------------------------------- /terra/map.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax.numpy as jnp 4 | from jax import Array 5 | 6 | from terra.map_generator import GridMap 7 | from terra.settings import IntLowDim 8 | 9 | 10 | class GridWorld(NamedTuple): 11 | """ 12 | Here we define the encoding of the maps. 
13 | - target map 14 | - 1: must dump here to terminate the episode 15 | - 0: free 16 | - -1: must dig here 17 | - action map 18 | - -1: dug here during the episode 19 | - 0: free 20 | - greater than 0: dumped here 21 | - dumpability mask 22 | - 1: can dump 23 | - 0: can't dump 24 | - padding mask 25 | - 0: traversable 26 | - 1: non traversable 27 | - traversability mask 28 | - -1: agent occupancy 29 | - 0: traversable 30 | - 1: non traversable 31 | - last dig mask 32 | - 1: dug here during previous dig action 33 | - 0: not dug here during previous dig action 34 | - local map target positive (contains the sum of all the positive target map tiles in a given workspace) 35 | - local map target negative (contains the sum of all the negative target map tiles in a given workspace) 36 | - local map action positive (contains the sum of all the positive action map tiles in a given workspace) 37 | - local map action negative (contains the sum of all the negative action map tiles in a given workspace) 38 | - local obstacles map (contains the sum of all the padding mask tiles in a given workspace) 39 | - local dumpability mask (contains the sum of all the dumpability mask tiles in a given workspace) 40 | """ 41 | 42 | target_map: GridMap 43 | action_map: GridMap 44 | padding_mask: GridMap 45 | dumpability_mask: GridMap 46 | dumpability_mask_init: GridMap 47 | last_dig_mask: GridMap 48 | 49 | trench_axes: Array 50 | trench_type: jnp.int32 # type of trench (number of branches), or -1 if not a trench 51 | 52 | # Dummies for wrappers 53 | traversability_mask: GridMap = GridMap.dummy_map() 54 | local_map_target_pos: GridMap = GridMap.dummy_map() 55 | local_map_target_neg: GridMap = GridMap.dummy_map() 56 | local_map_action_pos: GridMap = GridMap.dummy_map() 57 | local_map_action_neg: GridMap = GridMap.dummy_map() 58 | local_map_dumpability: GridMap = GridMap.dummy_map() 59 | local_map_obstacles: GridMap = GridMap.dummy_map() 60 | 61 | @property 62 | def width(self) -> int: 63 | return self.target_map.width 64 | 65 | @property 66 | def height(self) -> int: 67 | return self.target_map.height 68 | 69 | @property 70 | def max_traversable_x(self) -> int: 71 | return (self.padding_mask.map[:, 0] == 0).sum() 72 | 73 | @property 74 | def max_traversable_y(self) -> int: 75 | return (self.padding_mask.map[0] == 0).sum() 76 | 77 | @classmethod 78 | def new( 79 | cls, 80 | target_map: Array, 81 | padding_mask: Array, 82 | trench_axes: Array, 83 | trench_type: Array, 84 | dumpability_mask_init: Array, 85 | action_map: Array, 86 | ) -> "GridWorld": 87 | action_map = GridMap.new(IntLowDim(action_map)) 88 | target_map = GridMap.new(IntLowDim(target_map)) 89 | padding_mask = GridMap.new(IntLowDim(padding_mask)) 90 | dumpability_mask_init_gm = GridMap.new(dumpability_mask_init.astype(jnp.bool_)) 91 | dumpability_mask = GridMap.new(dumpability_mask_init.astype(jnp.bool_)) 92 | last_dig_mask = GridMap.new(jnp.zeros_like(target_map.map, dtype=jnp.bool_)) 93 | 94 | world = cls( 95 | target_map=target_map, 96 | action_map=action_map, 97 | padding_mask=padding_mask, 98 | trench_axes=trench_axes, 99 | trench_type=trench_type, 100 | dumpability_mask=dumpability_mask, 101 | dumpability_mask_init=dumpability_mask_init_gm, 102 | last_dig_mask=last_dig_mask, 103 | ) 104 | 105 | return world 106 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # IMPORTANT: The correct version of laxlib is NOT installed 
by default since 2 | # it is not available in the conda-forge channel. You must install it manually 3 | # within the environment. Use the following command: 4 | # pip install -U "jax[cuda12]" 5 | 6 | name: terra 7 | channels: 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=5.1=1_gnu 12 | - bzip2=1.0.8=h5eee18b_5 13 | - ca-certificates=2023.12.12=h06a4308_0 14 | - expat=2.5.0=h6a678d5_0 15 | - ld_impl_linux-64=2.38=h1181459_1 16 | - libffi=3.4.4=h6a678d5_0 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgomp=11.2.0=h1234567_1 19 | - libstdcxx-ng=11.2.0=h1234567_1 20 | - libuuid=1.41.5=h5eee18b_0 21 | - ncurses=6.4=h6a678d5_0 22 | - openssl=3.0.13=h7f8727e_0 23 | - python=3.12.2=h996f2a0_0 24 | - readline=8.2=h5eee18b_0 25 | - setuptools=68.2.2=py312h06a4308_0 26 | - sqlite=3.41.2=h5eee18b_0 27 | - tk=8.6.12=h1ccaba5_0 28 | - wheel=0.41.2=py312h06a4308_0 29 | - xz=5.4.6=h5eee18b_0 30 | - zlib=1.2.13=h5eee18b_0 31 | - pip: 32 | - absl-py==2.1.0 33 | - appdirs==1.4.4 34 | - astunparse==1.6.3 35 | - attrs==23.2.0 36 | - black==24.3.0 37 | - certifi==2024.2.2 38 | - cfgv==3.4.0 39 | - charset-normalizer==3.3.2 40 | - chex==0.1.85 41 | - click==8.1.7 42 | - click-plugins==1.1.1 43 | - cligj==0.7.2 44 | - cloudpickle==3.0.0 45 | - contourpy==1.2.0 46 | - cycler==0.12.1 47 | - decorator==5.1.1 48 | - distlib==0.3.8 49 | - dm-tree==0.1.8 50 | - docker-pycreds==0.4.0 51 | - etils==1.7.0 52 | - filelock==3.13.1 53 | - fiona==1.9.6 54 | - flake8==7.0.0 55 | - flatbuffers==24.3.7 56 | - flax==0.8.2 57 | - fonttools==4.50.0 58 | - fsspec==2024.3.0 59 | - gast==0.5.4 60 | - geopandas==0.14.3 61 | - gitdb==4.0.11 62 | - gitpython==3.1.42 63 | - google-pasta==0.2.0 64 | - grpcio==1.62.1 65 | - gviz-api==1.10.0 66 | - h5py==3.10.0 67 | - identify==2.5.35 68 | - idna==3.6 69 | - imageio==2.34.0 70 | - importlib-resources==6.3.1 71 | - keras==3.0.5 72 | - kiwisolver==1.4.5 73 | - lazy-loader==0.3 74 | - libclang==18.1.1 75 | - markdown==3.6 76 | - markdown-it-py==3.0.0 77 | - markupsafe==2.1.5 78 | - matplotlib==3.8.3 79 | - mccabe==0.7.0 80 | - mdurl==0.1.2 81 | - ml-dtypes==0.3.2 82 | - msgpack==1.0.8 83 | - mypy-extensions==1.0.0 84 | - namex==0.0.7 85 | - nest-asyncio==1.6.0 86 | - networkx==3.2.1 87 | - nodeenv==1.8.0 88 | - numpy==1.26.4 89 | - nvidia-cublas-cu12==12.4.2.65 90 | - nvidia-cuda-cupti-cu12==12.4.99 91 | - nvidia-cuda-nvcc-cu12==12.4.99 92 | - nvidia-cuda-nvrtc-cu12==12.4.99 93 | - nvidia-cuda-runtime-cu12==12.4.99 94 | - nvidia-cudnn-cu12==8.9.7.29 95 | - nvidia-cufft-cu12==11.2.0.44 96 | - nvidia-cusolver-cu12==11.6.0.99 97 | - nvidia-cusparse-cu12==12.3.0.142 98 | - nvidia-nccl-cu12==2.20.5 99 | - nvidia-nvjitlink-cu12==12.4.99 100 | - opencv-python==4.9.0.80 101 | - opt-einsum==3.3.0 102 | - optax==0.2.1 103 | - orbax-checkpoint==0.5.6 104 | - osmnx==1.9.1 105 | - packaging==24.0 106 | - pandas==2.2.1 107 | - pathlib==1.0.1 108 | - pathspec==0.12.1 109 | - pillow==10.2.0 110 | - pip==24.0 111 | - platformdirs==4.2.0 112 | - pre-commit==3.6.2 113 | - protobuf==3.20.3 114 | - psutil==5.9.8 115 | - pycodestyle==2.11.1 116 | - pyflakes==3.2.0 117 | - pygame==2.5.2 118 | - pygments==2.17.2 119 | - pyparsing==3.1.2 120 | - pyproj==3.6.1 121 | - python-dateutil==2.9.0.post0 122 | - pytz==2024.1 123 | - pyyaml==6.0.1 124 | - requests==2.31.0 125 | - rich==13.7.1 126 | - scikit-image==0.22.0 127 | - scipy==1.12.0 128 | - sentry-sdk==1.42.0 129 | - setproctitle==1.3.3 130 | - shapely==2.0.3 131 | - six==1.16.0 132 | - smmap==5.0.1 133 | - tdqm==0.0.1 134 | - 
tensorboard==2.16.2 135 | - tensorboard-data-server==0.7.2 136 | - tensorboard-plugin-profile==2.15.1 137 | - tensorflow==2.16.1 138 | - tensorflow-probability==0.24.0 139 | - tensorstore==0.1.56 140 | - termcolor==2.4.0 141 | - tifffile==2024.2.12 142 | - toolz==0.12.1 143 | - tqdm==4.66.2 144 | - typing-extensions==4.10.0 145 | - tzdata==2024.1 146 | - urllib3==2.2.1 147 | - virtualenv==20.25.1 148 | - wandb==0.16.4 149 | - werkzeug==3.0.1 150 | - wrapt==1.16.0 151 | - zipp==3.18.1 152 | -------------------------------------------------------------------------------- /terra/env_generation/README.md: -------------------------------------------------------------------------------- 1 | # Env Generation 2 | 3 | This folder contains the essential tools for generating maps to train Terra agents. It leverages both procedurally generated environments and real-world building footprints. 4 | 5 | ## Available Map Types 6 | 7 | - **Foundations**: Downloaded from OpenStreetMap and projected onto a grid map. 8 | - **Trenches**: Procedurally generated trenches featuring 1, 2, or 3 axes, along with obstacles, no-dumping zones, and terminal dumping constraints. 9 | 10 | ## Generating Maps 11 | 12 | ### Step 1: Create Training Maps 13 | 14 | 1. Generate procedural trenches, add constraints and obstacles, and reformat the maps for Terra use: 15 | ```bash 16 | python generate_dataset.py 17 | ``` 18 | This will create a data/train folder which contains the maps used during training. 19 | 20 | ### Step 3: Verify Map Generation 21 | 22 | 1. Ensure the maps are correctly generated by running: 23 | ```bash 24 | DATASET_PATH="/terra/digbench/data/train/" DATASET_SIZE= python -m terra.viz.play 25 | ``` 26 | For example: 27 | ``` 28 | DATASET_PATH=/terra/data/terra/train DATASET_SIZE=24 python -m terra.viz.play 29 | ``` 30 | Replace `` with the actual path to your Terra installation and `` with the desired dataset size. 31 | 32 | ## Data Generation Workflows In Detail 33 | 34 | This section provides a deeper understanding of how each type of training data is generated and processed in the Terra system. 35 | 36 | ### Foundations Data Workflow 37 | 38 | The foundations data is based on real-world building footprints from OpenStreetMap: 39 | 40 | 1. **Download and Processing**: 41 | - `generate_foundations.py` downloads building footprints from OpenStreetMap using the specified bounding box 42 | - Buildings are projected onto a grid with configurable resolution 43 | - Images undergo preprocessing including padding, hole filling, and filtering 44 | 45 | 2. **convert_to_terra**: 46 | - The `create_foundations` function in `create_train_data.py` handles: 47 | - Downsampling of images to fit maximum size requirements 48 | - Converting images to the Terra format 49 | - Generating occupancy and dumpability maps 50 | - The convert_to_terra module applies final transformations to make the data usable for training 51 | 52 | 3. **Configuration**: 53 | - Parameters are loaded from the config YAML file, including: 54 | - Resolution and size constraints 55 | - Dataset paths 56 | - Optional obstacle and non-dumpable zone parameters 57 | 58 | ### Trenches Data Workflow 59 | 60 | Trenches are procedurally generated environments for excavation tasks: 61 | 62 | 1. 
**Generation**: 63 | - `create_procedural_trenches` in `create_train_data.py` handles generation 64 | - The `generate_trenches_v2` function in `procedural_data.py` creates trenches with 1, 2, or 3 axes 65 | - Trenches are organized in different difficulty levels based on configuration 66 | 67 | 2. **Feature Addition**: 68 | - Obstacles are added with configurable parameters (number, size) 69 | - No-dumping zones are placed with constraints 70 | - Terminal dumping constraints are applied 71 | 72 | 3. **Data Organization**: 73 | - Trenches are saved in folders organized by difficulty level 74 | - Each trench includes image data, metadata, occupancy, and dumpability maps 75 | - The `generate_trenches_terra` function in `convert_to_terra.py` converts all data to the Terra format 76 | 77 | ### Curriculum Generation 78 | 79 | For structured training progression: 80 | 81 | 1. **Generating Curriculum Data**: 82 | - `generate_curriculum.py` creates a progression of environments with increasing difficulty 83 | - Different environment types can be integrated into the curriculum 84 | - Each stage is stored in appropriately named folders 85 | 86 | 2. **Usage**: 87 | - Configure the curriculum in the config file 88 | - Run the curriculum generator 89 | - The resulting data follows a progression suitable for staged training 90 | 91 | ### Data Format Conversion 92 | 93 | All generated data undergoes format conversion for training: 94 | 95 | 1. **Conversion Process**: 96 | - `convert_to_terra.py` contains functions to convert all data to the Terra format 97 | - `generate_dataset_terra_format` converts data to multiple resolutions 98 | - Images, occupancy maps, and dumpability maps are all properly formatted 99 | 100 | 2. **Output Structure**: 101 | - Final data is organized in the `/data/terra/train/` directory 102 | - Subdirectories include foundations, trenches (with difficulty levels) and custom maps 103 | - Each environment has consistent format and metadata 104 | -------------------------------------------------------------------------------- /terra/env_generation/generate_relocations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | from pathlib import Path 6 | from terra.env_generation.procedural_data import ( 7 | add_non_dumpables, 8 | add_obstacles, 9 | initialize_image, 10 | save_or_display_image 11 | ) 12 | 13 | from terra.env_generation.utils import color_dict, _get_img_mask 14 | 15 | def add_dump_zones(img, n_dump_min, n_dump_max, size_dump_min, size_dump_max): 16 | w, h = img.shape[:2] 17 | n_dump = 0 18 | n_dump_todo = np.random.randint(n_dump_min, n_dump_max + 1) 19 | cumulative_mask = np.zeros_like(img[..., 0], dtype=bool) 20 | 21 | while n_dump < n_dump_todo: 22 | # Randomly determine the size of the dumping zone 23 | sizeox = np.random.randint(size_dump_min, size_dump_max + 1) 24 | sizeoy = np.random.randint(size_dump_min, size_dump_max + 1) 25 | # Randomly select a position for the dumping zone 26 | x = np.random.randint(0, w - sizeox) 27 | y = np.random.randint(0, h - sizeoy) 28 | # Update the dumping zone layer 29 | img[x : x + sizeox, y : y + sizeoy] = np.array(color_dict["dumping"]) 30 | cumulative_mask[x : x + sizeox, y : y + sizeoy] = True 31 | n_dump += 1 # Increment the count of zones added 32 | 33 | return img, cumulative_mask 34 | 35 | def add_dirt_tiles(img, occ, dmp, cumulative_mask, n_dirt_min, n_dirt_max, size_dirt_min, size_dirt_max): 36 | w, h = img.shape[:2] 37 | n_dirt = 0 38 
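# Standalone usage sketch for add_dump_zones above (all numbers are assumptions,
# not project defaults); create_relocations() further below wires the same call
# from a config dict:
#   demo = initialize_image(30, 40, color_dict["neutral"])
#   demo, dump_mask = add_dump_zones(demo, 1, 3, 4, 8)
#   print(demo.shape, dump_mask.sum())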
| n_dirt_todo = np.random.randint(n_dirt_min, n_dirt_max + 1) 39 | drt = np.ones_like(img) * 255 40 | 41 | mask_occ = _get_img_mask(occ, color_dict["obstacle"]) 42 | mask_dmp = _get_img_mask(occ, color_dict["nondumpable"]) 43 | 44 | while n_dirt < n_dirt_todo: 45 | # Randomly determine the size of the dirt pile 46 | sizeox = np.random.randint(size_dirt_min, size_dirt_max + 1) 47 | sizeoy = np.random.randint(size_dirt_min, size_dirt_max + 1) 48 | # Randomly select a position for the dirt pile 49 | x = np.random.randint(0, w - sizeox) 50 | y = np.random.randint(0, h - sizeoy) 51 | # Check if the selected area overlaps with existing features 52 | if np.all(cumulative_mask[x : x + sizeox, y : y + sizeoy] == 0) and np.all( 53 | mask_occ[x : x + sizeox, y : y + sizeoy] == 0 54 | ) and np.all(mask_dmp[x : x + sizeox, y : y + sizeoy] == 0): 55 | drt[x : x + sizeox, y : y + sizeoy] = np.array(color_dict["dirt"]) 56 | cumulative_mask[x : x + sizeox, y : y + sizeoy] = True 57 | n_dirt += 1 58 | 59 | return drt, cumulative_mask 60 | 61 | def save_action_image(drt, save_folder, i): 62 | # make dir if does not exist 63 | os.makedirs(save_folder, exist_ok=True) 64 | save_folder_action = Path(save_folder) / "actions" 65 | save_folder_action.mkdir(parents=True, exist_ok=True) 66 | cv2.imwrite( 67 | os.path.join(save_folder_action, "trench_" + str(i) + ".png"), drt 68 | ) # Added .png extension 69 | 70 | def create_relocations(config, n_imgs): 71 | img_edge_min = config["img_edge_min"] 72 | img_edge_max = config["img_edge_max"] 73 | n_dump_min = config["n_dump_min"] 74 | n_dump_max = config["n_dump_max"] 75 | size_dump_min = config["size_dump_min"] 76 | size_dump_max = config["size_dump_max"] 77 | n_obs_min = config["n_obs_min"] 78 | n_obs_max = config["n_obs_max"] 79 | size_obstacle_min = config["size_obstacle_min"] 80 | size_obstacle_max = config["size_obstacle_max"] 81 | n_nodump_min = config["n_nodump_min"] 82 | n_nodump_max = config["n_nodump_max"] 83 | size_nodump_min = config["size_nodump_min"] 84 | size_nodump_max = config["size_nodump_max"] 85 | n_dirt_min = config["n_dirt_min"] 86 | n_dirt_max = config["n_dirt_max"] 87 | size_dirt_min = config["size_dirt_min"] 88 | size_dirt_max = config["size_dirt_max"] 89 | 90 | package_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 91 | save_folder = os.path.join(package_dir, "data", "terra", "relocations") 92 | 93 | i = 0 94 | while i < n_imgs: 95 | img = initialize_image(img_edge_min, img_edge_max, color_dict["neutral"]) 96 | img, cumulative_mask = add_dump_zones(img, n_dump_min, n_dump_max, size_dump_min, size_dump_max) 97 | occ, cumulative_mask = add_obstacles(img, cumulative_mask, n_obs_min, n_obs_max, size_obstacle_min, size_obstacle_max) 98 | dmp, cumulative_mask = add_non_dumpables(img, occ, cumulative_mask, n_nodump_min, n_nodump_max, size_nodump_min, size_nodump_max) 99 | drt, cumulative_mask = add_dirt_tiles(img, occ, dmp, cumulative_mask, n_dirt_min, n_dirt_max, size_dirt_min, size_dirt_max) 100 | save_or_display_image(img, occ, dmp, {}, save_folder, i) 101 | save_action_image(drt, save_folder, i) 102 | i += 1 103 | 104 | print("Relocations created successfully.") 105 | -------------------------------------------------------------------------------- /terra/actions.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | from typing import NamedTuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax import Array 7 | 8 | from terra.settings 
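# Hedged sketch for the previous file (terra/env_generation/generate_relocations.py):
# a minimal config dict with the keys create_relocations() reads. Every numeric value
# below is an assumption chosen only to make the example run, not a project default.
relocation_config = {
    "img_edge_min": 30, "img_edge_max": 40,
    "n_dump_min": 1, "n_dump_max": 3,
    "size_dump_min": 4, "size_dump_max": 8,
    "n_obs_min": 1, "n_obs_max": 3,
    "size_obstacle_min": 2, "size_obstacle_max": 5,
    "n_nodump_min": 1, "n_nodump_max": 2,
    "size_nodump_min": 2, "size_nodump_max": 5,
    "n_dirt_min": 1, "n_dirt_max": 3,
    "size_dirt_min": 2, "size_dirt_max": 4,
}
# from terra.env_generation.generate_relocations import create_relocations
# create_relocations(relocation_config, n_imgs=5)  # writes under <package>/data/terra/relocations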
import IntLowDim 9 | 10 | ActionType = IntEnum 11 | 12 | 13 | class TrackedActionType(ActionType): 14 | """ 15 | Tracked robot specific actions. 16 | """ 17 | 18 | DO_NOTHING = -1 19 | FORWARD = 0 20 | BACKWARD = 1 21 | CLOCK = 2 22 | ANTICLOCK = 3 23 | CABIN_CLOCK = 4 24 | CABIN_ANTICLOCK = 5 25 | DO = 6 26 | 27 | 28 | Action = NamedTuple 29 | 30 | 31 | class TrackedAction(Action): 32 | type: Array = jnp.full((1,), fill_value=0, dtype=IntLowDim) 33 | action: Array = jnp.full( 34 | (1,), fill_value=TrackedActionType.DO_NOTHING, dtype=IntLowDim 35 | ) 36 | 37 | @classmethod 38 | def new(cls, action: TrackedActionType) -> "TrackedAction": 39 | return TrackedAction( 40 | action=IntLowDim(action), type=jnp.zeros_like(action, dtype=IntLowDim) 41 | ) 42 | 43 | @classmethod 44 | def do_nothing(cls): 45 | return cls.new(jnp.full((1,), TrackedActionType.DO_NOTHING, dtype=IntLowDim)) 46 | 47 | @classmethod 48 | def forward(cls): 49 | return cls.new(jnp.full((1,), TrackedActionType.FORWARD, dtype=IntLowDim)) 50 | 51 | @classmethod 52 | def backward(cls): 53 | return cls.new(jnp.full((1,), TrackedActionType.BACKWARD, dtype=IntLowDim)) 54 | 55 | @classmethod 56 | def clock(cls): 57 | return cls.new(jnp.full((1,), TrackedActionType.CLOCK, dtype=IntLowDim)) 58 | 59 | @classmethod 60 | def anticlock(cls): 61 | return cls.new(jnp.full((1,), TrackedActionType.ANTICLOCK, dtype=IntLowDim)) 62 | 63 | @classmethod 64 | def cabin_clock(cls): 65 | return cls.new(jnp.full((1,), TrackedActionType.CABIN_CLOCK, dtype=IntLowDim)) 66 | 67 | @classmethod 68 | def cabin_anticlock(cls): 69 | return cls.new( 70 | jnp.full((1,), TrackedActionType.CABIN_ANTICLOCK, dtype=IntLowDim) 71 | ) 72 | 73 | @classmethod 74 | def do(cls): 75 | return cls.new(jnp.full((1,), TrackedActionType.DO, dtype=IntLowDim)) 76 | 77 | @classmethod 78 | def random(cls, key: jnp.int32): 79 | return cls.new( 80 | jax.random.choice( 81 | key, 82 | jnp.arange(TrackedActionType.FORWARD, TrackedActionType.DO + 1), 83 | (1,), 84 | ) 85 | ) 86 | 87 | @staticmethod 88 | def get_num_actions(): 89 | return 7 90 | 91 | 92 | class WheeledActionType(ActionType): 93 | """ 94 | Wheeled robot specific actions. 
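# Hedged usage sketch for TrackedAction above (assumes the terra package is
# importable from the repo root; the PRNG key value is arbitrary):
import jax
from terra.actions import TrackedAction

forward_action = TrackedAction.forward()                      # deterministic constructor
sampled_action = TrackedAction.random(jax.random.PRNGKey(0))  # uniform over FORWARD..DO
num_actions = TrackedAction.get_num_actions()                 # 7 discrete actions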
95 | """ 96 | 97 | DO_NOTHING = -1 98 | FORWARD = 0 99 | BACKWARD = 1 100 | WHEELS_LEFT = 2 101 | WHEELS_RIGHT = 3 102 | CABIN_CLOCK = 4 103 | CABIN_ANTICLOCK = 5 104 | DO = 6 105 | 106 | 107 | class WheeledAction(Action): 108 | type: Array = jnp.full((1,), fill_value=1, dtype=IntLowDim) 109 | action: Array = jnp.full( 110 | (1,), fill_value=WheeledActionType.DO_NOTHING, dtype=IntLowDim 111 | ) 112 | 113 | @classmethod 114 | def new(cls, action: WheeledActionType) -> "WheeledAction": 115 | return WheeledAction( 116 | action=IntLowDim(action), type=jnp.ones_like(action, dtype=IntLowDim) 117 | ) 118 | 119 | @classmethod 120 | def do_nothing(cls): 121 | return cls.new(jnp.full((1,), WheeledActionType.DO_NOTHING, dtype=IntLowDim)) 122 | 123 | @classmethod 124 | def forward(cls): 125 | return cls.new(jnp.full((1,), WheeledActionType.FORWARD, dtype=IntLowDim)) 126 | 127 | @classmethod 128 | def backward(cls): 129 | return cls.new(jnp.full((1,), WheeledActionType.BACKWARD, dtype=IntLowDim)) 130 | 131 | @classmethod 132 | def wheels_left(cls): 133 | return cls.new(jnp.full((1,), WheeledActionType.WHEELS_LEFT, dtype=IntLowDim)) 134 | 135 | @classmethod 136 | def wheels_right(cls): 137 | return cls.new( 138 | jnp.full((1,), WheeledActionType.WHEELS_RIGHT, dtype=IntLowDim) 139 | ) 140 | 141 | @classmethod 142 | def cabin_clock(cls): 143 | return cls.new(jnp.full((1,), WheeledActionType.CABIN_CLOCK, dtype=IntLowDim)) 144 | 145 | @classmethod 146 | def cabin_anticlock(cls): 147 | return cls.new( 148 | jnp.full((1,), WheeledActionType.CABIN_ANTICLOCK, dtype=IntLowDim) 149 | ) 150 | 151 | @classmethod 152 | def do(cls): 153 | return cls.new(jnp.full((1,), WheeledActionType.DO, dtype=IntLowDim)) 154 | 155 | @classmethod 156 | def random(cls, key: jnp.int32): 157 | return cls.new( 158 | jax.random.choice( 159 | key, 160 | jnp.arange(WheeledActionType.FORWARD, WheeledActionType.DO + 1), 161 | (1,), 162 | ) 163 | ) 164 | 165 | @staticmethod 166 | def get_num_actions(): 167 | return 7 168 | -------------------------------------------------------------------------------- /terra/viz/game/world.py: -------------------------------------------------------------------------------- 1 | import pygame as pg 2 | from .settings import COLORS 3 | import numpy as np 4 | 5 | 6 | class World: 7 | def __init__(self, grid_length_x, grid_length_y, width, height, tile_size): 8 | self.grid_length_x = grid_length_x 9 | self.grid_length_y = grid_length_y 10 | self.width = width 11 | self.height = height 12 | self.tile_size = tile_size 13 | 14 | def grid_to_world(self, grid_x, grid_y, bitmap_code): 15 | rect = [ 16 | (grid_x * self.tile_size, grid_y * self.tile_size), 17 | (grid_x * self.tile_size + self.tile_size, grid_y * self.tile_size), 18 | ( 19 | grid_x * self.tile_size + self.tile_size, 20 | grid_y * self.tile_size + self.tile_size, 21 | ), 22 | (grid_x * self.tile_size, grid_y * self.tile_size + self.tile_size), 23 | ] 24 | 25 | # Handle custom colors (hex strings) or standard COLORS dictionary lookup 26 | if isinstance(bitmap_code, str) and bitmap_code.startswith('#'): 27 | # Custom color (hex string) 28 | color = bitmap_code 29 | elif isinstance(bitmap_code, (int, str)) and bitmap_code in COLORS: 30 | # Standard color from dictionary 31 | color = COLORS[bitmap_code] 32 | else: 33 | # Fallback to neutral color 34 | color = COLORS[0] 35 | 36 | out = { 37 | "grid": [grid_x, grid_y], 38 | "cart_rect": rect, 39 | "color": color, 40 | } 41 | 42 | return out 43 | 44 | def _get_dirt_gradient_color(self, dirt_amount, 
max_dirt_amount): 45 | """ 46 | Generate a gradient color for dirt amount. 47 | Light blue for small amounts, dark blue for large amounts. 48 | 49 | Args: 50 | dirt_amount: Current dirt amount on this tile 51 | max_dirt_amount: Maximum dirt amount across all tiles for normalization 52 | Returns: 53 | Hex color string 54 | """ 55 | if dirt_amount <= 0 or max_dirt_amount <= 0: 56 | return COLORS[0] # neutral color 57 | 58 | # Normalize dirt amount to 0-1 range 59 | intensity = min(dirt_amount / max_dirt_amount, 1.0) 60 | 61 | # Define gradient from light blue to dark blue 62 | # Light blue: RGB(173, 216, 230) -> #ADD8E6 63 | # Dark blue: RGB(0, 43, 91) -> #002B5B (original dumped dirt color) 64 | light_blue = (173, 216, 230) # Light blue 65 | dark_blue = (0, 43, 91) # Dark blue (original) 66 | 67 | # Interpolate between light and dark blue 68 | r = int(light_blue[0] + (dark_blue[0] - light_blue[0]) * intensity) 69 | g = int(light_blue[1] + (dark_blue[1] - light_blue[1]) * intensity) 70 | b = int(light_blue[2] + (dark_blue[2] - light_blue[2]) * intensity) 71 | 72 | # Convert to hex 73 | hex_color = f"#{r:02x}{g:02x}{b:02x}" 74 | 75 | return hex_color 76 | 77 | def update(self, action_map, target_map, obstacles_mask, dumpability_mask): 78 | action_map = np.asarray(action_map, dtype=np.int32) 79 | action_map = action_map.swapaxes(0, 1) 80 | 81 | target_map = np.asarray(target_map, dtype=np.int32) 82 | target_map = target_map.swapaxes(0, 1) 83 | if obstacles_mask is not None: 84 | obstacles_mask = np.asarray(obstacles_mask, dtype=np.bool_) 85 | obstacles_mask = obstacles_mask.swapaxes(0, 1) 86 | if dumpability_mask is not None: 87 | dumpability_mask = np.asarray(dumpability_mask, dtype=np.bool_) 88 | dumpability_mask = dumpability_mask.swapaxes(0, 1) 89 | 90 | # Find max dirt amount for gradient normalization (per-frame) 91 | max_dirt_amount = np.max(action_map[action_map > 0]) if np.any(action_map > 0) else 1 92 | 93 | world = [] 94 | 95 | for grid_x in range(self.grid_length_x): 96 | world.append([]) 97 | for grid_y in range(self.grid_length_y): 98 | dirt_amount = action_map[grid_x, grid_y] 99 | 100 | if target_map[grid_x, grid_y] == -1: 101 | # to dig 102 | tile = 4 103 | elif target_map[grid_x, grid_y] == 1: 104 | # final dumping area to terminate the episode 105 | tile = 5 106 | else: 107 | # neutral 108 | tile = 0 109 | 110 | if obstacles_mask is not None and obstacles_mask[grid_x, grid_y] == 1: 111 | # obstacle 112 | tile = 2 113 | if ( 114 | dumpability_mask is not None 115 | and dumpability_mask[grid_x, grid_y] == 0 116 | ): 117 | # non-dumpable (e.g. 
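# Worked example of _get_dirt_gradient_color above (illustrative): with
# max_dirt_amount = 4, a tile holding 2 units of dirt gets intensity 0.5, i.e.
# halfway between light blue (173, 216, 230) and dark blue (0, 43, 91):
#   r = int(173 + (0 - 173) * 0.5)   -> 86
#   g = int(216 + (43 - 216) * 0.5)  -> 129
#   b = int(230 + (91 - 230) * 0.5)  -> 160
# giving the hex string "#5681a0".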
road) 118 | tile = 3 119 | if dirt_amount > 0: 120 | # action map dump - use gradient based on dirt amount 121 | tile = self._get_dirt_gradient_color(dirt_amount, max_dirt_amount) 122 | if dirt_amount < 0: 123 | # action map dug 124 | tile = -1 125 | 126 | world_tile = self.grid_to_world(grid_x, grid_y, tile) 127 | world[grid_x].append(world_tile) 128 | 129 | self.action_map = world 130 | -------------------------------------------------------------------------------- /terra/curriculum.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import Array 6 | from terra.config import Rewards 7 | 8 | 9 | def print_arrays(arr, what): 10 | print(f"{what}: {arr}") 11 | 12 | 13 | class CurriculumManager(NamedTuple): 14 | """ 15 | This class defines the logic to change the environment configuration given the performance of the agent. 16 | This class is not stateful, the state of the curriculum is fully contained in the EnvConfig object. 17 | """ 18 | 19 | max_level: int 20 | increase_level_threshold: int 21 | decrease_level_threshold: int 22 | max_steps_in_episode_per_level: Array 23 | apply_trench_rewards_per_level: Array 24 | reward_type_per_level: Array 25 | last_level_type: str 26 | 27 | def _update_single_cfg(self, timestep, rng): 28 | """ 29 | Update the environment configuration based on the timestep. This function is vammaped therefore the timestep.arrays have dimensions (batch_size, ...) (1 step per env). 30 | """ 31 | env_cfg = timestep.env_cfg 32 | done = timestep.done 33 | completed = timestep.info["task_done"] 34 | 35 | failure = done & ~completed 36 | success = done & completed 37 | 38 | def handle_update(): 39 | consecutive_failures = jax.lax.cond( 40 | failure, 41 | lambda: env_cfg.curriculum.consecutive_failures + 1, 42 | lambda: 0, 43 | ) 44 | consecutive_successes = jax.lax.cond( 45 | success, 46 | lambda: env_cfg.curriculum.consecutive_successes + 1, 47 | lambda: 0, 48 | ) 49 | return consecutive_failures, consecutive_successes 50 | 51 | consecutive_failures, consecutive_successes = jax.lax.cond( 52 | done, 53 | handle_update, 54 | lambda: ( 55 | env_cfg.curriculum.consecutive_failures, 56 | env_cfg.curriculum.consecutive_successes, 57 | ), 58 | ) 59 | 60 | do_increase = consecutive_successes >= self.increase_level_threshold 61 | do_decrease = consecutive_failures >= self.decrease_level_threshold 62 | 63 | level, consecutive_failures, consecutive_successes = jax.lax.cond( 64 | do_increase, 65 | lambda: ( 66 | jax.lax.cond( 67 | env_cfg.curriculum.level < self.max_level, 68 | lambda: env_cfg.curriculum.level + 1, 69 | lambda: jax.lax.cond( 70 | self.last_level_type == "none", 71 | lambda: env_cfg.curriculum.level, 72 | lambda: jax.lax.cond( 73 | self.last_level_type == "random", 74 | lambda: jax.random.randint(rng, (), 0, self.max_level + 1), 75 | lambda: 97, # Error case 76 | ), 77 | ), 78 | ), 79 | 0, # Reset consecutive_failures 80 | 0, # Reset consecutive_successes 81 | ), 82 | lambda: jax.lax.cond( 83 | do_decrease, 84 | lambda: ( 85 | jnp.maximum(env_cfg.curriculum.level - 1, 0), 86 | 0, # Reset consecutive_failures 87 | 0, # Reset consecutive_successes 88 | ), 89 | lambda: ( 90 | env_cfg.curriculum.level, 91 | consecutive_failures, # Keep the current count 92 | consecutive_successes, # Keep the current count 93 | ), 94 | ), 95 | ) 96 | 97 | max_steps_in_episode = self.max_steps_in_episode_per_level[level] 98 | apply_trench_rewards = 
self.apply_trench_rewards_per_level[level] 99 | 100 | rewards_list = [Rewards.dense, Rewards.sparse] 101 | reward_type = self.reward_type_per_level[level] 102 | rewards = jax.lax.switch(reward_type, rewards_list) 103 | 104 | env_cfg = env_cfg._replace( 105 | rewards=rewards, 106 | apply_trench_rewards=apply_trench_rewards, 107 | max_steps_in_episode=max_steps_in_episode, 108 | curriculum=env_cfg.curriculum._replace( 109 | level=level, 110 | consecutive_failures=consecutive_failures, 111 | consecutive_successes=consecutive_successes, 112 | ), 113 | ) 114 | timestep = timestep._replace(env_cfg=env_cfg) 115 | return timestep 116 | 117 | def _reset_single_cfg(self, env_cfg): 118 | level = env_cfg.curriculum.level 119 | max_steps_in_episode = self.max_steps_in_episode_per_level[level] 120 | apply_trench_rewards = self.apply_trench_rewards_per_level[level] 121 | 122 | rewards_list = [Rewards.dense, Rewards.sparse] 123 | reward_type = self.reward_type_per_level[level] 124 | rewards = jax.lax.switch(reward_type, rewards_list) 125 | 126 | env_cfg = env_cfg._replace( 127 | rewards=rewards, 128 | apply_trench_rewards=apply_trench_rewards, 129 | max_steps_in_episode=max_steps_in_episode, 130 | ) 131 | return env_cfg 132 | 133 | def update_cfgs(self, timesteps, rng): 134 | batch_size = timesteps.done.shape[0] 135 | if rng.ndim == 1: 136 | rngs = jax.random.split(rng, batch_size) 137 | else: 138 | rngs = rng 139 | return jax.vmap(self._update_single_cfg)(timesteps, rngs) 140 | 141 | def reset_cfgs(self, env_cfgs): 142 | return jax.vmap(self._reset_single_cfg)(env_cfgs) 143 | -------------------------------------------------------------------------------- /terra/agent.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import Array 6 | 7 | from terra.config import EnvConfig 8 | from terra.settings import IntLowDim 9 | from terra.settings import IntMap 10 | from terra.utils import compute_polygon_mask 11 | from terra.utils import get_agent_corners 12 | 13 | 14 | class AgentState(NamedTuple): 15 | """ 16 | Clarifications on the agent state representation. 17 | 18 | angle_base: 19 | Orientations of the agent are an integer between 0 and 3 (included), 20 | where 0 means that it is aligned with the x axis, and for every positive 21 | increment, 90 degrees are added in the direction of the arrow going from 22 | the x axis to the y axes (anti-clockwise). 23 | """ 24 | 25 | pos_base: IntMap 26 | angle_base: IntLowDim 27 | angle_cabin: IntLowDim 28 | wheel_angle: IntLowDim 29 | loaded: IntLowDim 30 | 31 | 32 | class Agent(NamedTuple): 33 | """ 34 | Defines the state and type of the agent. 
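# Hedged sketch for terra/curriculum.py above: assembling a CurriculumManager from
# per-level settings. The array contents and level count here are assumptions made
# for the example, not the project's defaults (those live in CurriculumGlobalConfig).
import jax.numpy as jnp
from terra.curriculum import CurriculumManager

manager = CurriculumManager(
    max_level=2,                                   # three levels: 0, 1, 2
    increase_level_threshold=20,                   # consecutive successes to level up
    decrease_level_threshold=50,                   # consecutive failures to level down
    max_steps_in_episode_per_level=jnp.array([300, 400, 500]),
    apply_trench_rewards_per_level=jnp.array([False, True, True]),
    reward_type_per_level=jnp.array([0, 0, 1]),    # 0 = dense, 1 = sparse
    last_level_type="random",                      # resample a level once the last one is mastered
)
# env_cfgs = manager.reset_cfgs(env_cfgs)          # vmapped over a batch of EnvConfig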
35 | """ 36 | 37 | agent_state: AgentState 38 | 39 | width: int 40 | height: int 41 | 42 | moving_dumped_dirt: bool 43 | 44 | @staticmethod 45 | def new( 46 | key: jax.random.PRNGKey, 47 | env_cfg: EnvConfig, 48 | max_traversable_x: int, 49 | max_traversable_y: int, 50 | padding_mask: Array, 51 | action_map: Array, 52 | ) -> tuple["Agent", jax.random.PRNGKey]: 53 | pos_base, angle_base, key = jax.lax.cond( 54 | env_cfg.agent.random_init_state, 55 | lambda k: _get_random_init_state( 56 | k, 57 | env_cfg, 58 | max_traversable_x, 59 | max_traversable_y, 60 | padding_mask, 61 | action_map, 62 | env_cfg.agent.width, 63 | env_cfg.agent.height, 64 | ), 65 | lambda k: _get_top_left_init_state(k, env_cfg), 66 | key, 67 | ) 68 | 69 | agent_state = AgentState( 70 | pos_base=pos_base, 71 | angle_base=angle_base, 72 | angle_cabin=jnp.full((1,), 0, dtype=IntLowDim), 73 | wheel_angle=jnp.full((1,), 0, dtype=IntLowDim), 74 | loaded=jnp.full((1,), 0, dtype=IntLowDim), 75 | ) 76 | 77 | width = env_cfg.agent.width 78 | height = env_cfg.agent.height 79 | 80 | moving_dumped_dirt = False 81 | 82 | return Agent(agent_state=agent_state, width=width, height=height, moving_dumped_dirt=moving_dumped_dirt), key 83 | 84 | 85 | def _get_top_left_init_state(key: jax.random.PRNGKey, env_cfg: EnvConfig): 86 | max_center_coord = jnp.ceil( 87 | jnp.max( 88 | jnp.array([env_cfg.agent.width // 2 - 1, env_cfg.agent.height // 2 - 1]) 89 | ) 90 | ).astype(IntMap) 91 | pos_base = IntMap(jnp.array([max_center_coord, max_center_coord])) 92 | theta = jnp.full((1,), fill_value=0, dtype=IntMap) 93 | return pos_base, theta, key 94 | 95 | 96 | def _get_random_init_state( 97 | key: jax.random.PRNGKey, 98 | env_cfg: EnvConfig, 99 | max_traversable_x: int, 100 | max_traversable_y: int, 101 | padding_mask: Array, 102 | action_map: Array, 103 | agent_width: int, 104 | agent_height: int, 105 | ): 106 | def _get_random_agent_state(carry): 107 | key, padding_mask, pos_base, angle_base = carry 108 | max_center_coord = jnp.ceil( 109 | jnp.max( 110 | jnp.array([env_cfg.agent.width / 2 - 1, env_cfg.agent.height / 2 - 1]) 111 | ) 112 | ).astype(IntMap) 113 | key, subkey_x, subkey_y, subkey_angle = jax.random.split(key, 4) 114 | 115 | max_w = jnp.minimum(max_traversable_x, env_cfg.maps.edge_length_px) 116 | max_h = jnp.minimum(max_traversable_y, env_cfg.maps.edge_length_px) 117 | 118 | x = jax.random.randint( 119 | subkey_x, 120 | (1,), 121 | minval=max_center_coord, 122 | maxval=max_w - max_center_coord, 123 | ) 124 | y = jax.random.randint( 125 | subkey_y, 126 | (1,), 127 | minval=max_center_coord, 128 | maxval=max_h - max_center_coord, 129 | ) 130 | pos_base = IntMap(jnp.concatenate((x, y))) 131 | angle_base = jax.random.randint( 132 | subkey_angle, (1,), 0, env_cfg.agent.angles_base, dtype=IntMap 133 | ) 134 | return key, padding_mask, pos_base, angle_base 135 | 136 | def _check_agent_obstacles_intersection(carry): 137 | key, padding_mask, pos_base, angle_base = carry 138 | map_width = padding_mask.shape[0] 139 | map_height = padding_mask.shape[1] 140 | 141 | def _check_intersection(): 142 | """ 143 | Checks that the agent does not spawn where an obstacle is (or else it will get stuck forever). 144 | The check takes the four agent corners and checks that in the tiles included 145 | within the corners there is no obstacle-encoded tile. 146 | The padding mask is the map encoding obstacles (1 for obstacle and 0 for no obstacle). 
147 | """ 148 | agent_corners_xy = get_agent_corners( 149 | pos_base, angle_base, agent_width, agent_height, env_cfg.agent.angles_base 150 | ) 151 | polygon_mask = compute_polygon_mask( 152 | agent_corners_xy, map_width, map_height 153 | ) 154 | 155 | obstacle_inside = jnp.any(jnp.logical_and(polygon_mask, padding_mask == 1)) 156 | action_illegal = jnp.any(jnp.logical_and(polygon_mask, action_map != 0)) 157 | return obstacle_inside | action_illegal 158 | 159 | keep_searching = jax.lax.cond( 160 | jnp.any(pos_base < 0) | jnp.any(angle_base < 0), 161 | lambda: True, 162 | _check_intersection, 163 | ) 164 | return keep_searching 165 | 166 | key, padding_mask, pos_base, angle_base = jax.lax.while_loop( 167 | _check_agent_obstacles_intersection, 168 | _get_random_agent_state, 169 | ( 170 | key, 171 | padding_mask, 172 | jnp.array([-1, -1], dtype=IntMap), 173 | jnp.full((1,), -1, dtype=IntMap), 174 | ), 175 | ) 176 | 177 | return pos_base, angle_base, key 178 | -------------------------------------------------------------------------------- /terra/env_generation/generate_foundations.py: -------------------------------------------------------------------------------- 1 | from terra.env_generation import openstreet 2 | from terra.env_generation import utils 3 | import os 4 | import shutil 5 | 6 | # set seed 7 | from pyproj import CRS, Transformer 8 | import random 9 | import pathlib 10 | 11 | random.seed(42) 12 | 13 | 14 | def download_foundations( 15 | main_folder, 16 | center_bbox=(47.378177, 47.364622, 8.526535, 8.544894), 17 | min_size=(20, 20), 18 | max_size=(100, 100), 19 | padding=3, 20 | resolution=0.05, 21 | max_buildings=None, 22 | ): 23 | dataset_folder = main_folder + "/foundations_raw" 24 | # if it does not exist 25 | if not os.path.exists(dataset_folder): 26 | os.makedirs(dataset_folder) 27 | os.makedirs(dataset_folder + "/images", exist_ok=True) 28 | os.makedirs(dataset_folder + "/metadata", exist_ok=True) 29 | try: 30 | openstreet.get_building_shapes_from_OSM( 31 | *center_bbox, option=2, save_folder=dataset_folder, max_buildings=max_buildings 32 | ) 33 | except Exception as e: 34 | print(e) 35 | 36 | # filter out small cases 37 | image_folder = main_folder + "/foundations_raw/images" 38 | save_folder = main_folder + "/foundations_filtered/images" 39 | metadata_folder = main_folder + "/foundations_raw/metadata" 40 | utils.size_filter(image_folder, save_folder, metadata_folder, min_size, max_size) 41 | 42 | # pad the edges 43 | image_folder = main_folder + "/foundations_filtered/images" 44 | save_folder = main_folder + "/foundations_filtered_padded" 45 | metadata_folder = main_folder + "/foundations_raw/metadata" 46 | utils.pad_images_and_update_metadata( 47 | image_folder, metadata_folder, padding, (255, 255, 255), save_folder 48 | ) 49 | 50 | # set resolution 51 | image_folder = main_folder + "/foundations_filtered_padded" 52 | metadata_folder = main_folder + "/foundations_filtered_padded" 53 | image_resized_folder = main_folder + "/foundations_filtered_padded_resized" 54 | utils.preprocess_dataset_fixed_resolution( 55 | image_folder, metadata_folder, image_resized_folder, resolution 56 | ) 57 | 58 | # filter out small cases again, after resizing and all 59 | image_folder = main_folder + "/foundations_filtered_padded_resized" 60 | save_folder = main_folder + "/foundations_filtered_padded_resized_refiltered" 61 | metadata_folder = main_folder + "/foundations_filtered_padded_resized" 62 | utils.size_filter( 63 | image_folder, 64 | save_folder, 65 | metadata_folder, 66 | min_size, 67 
| max_size, 68 | copy_metadata=True, 69 | ) 70 | 71 | 72 | def create_exterior_foundations(main_folder, padding=5): 73 | # fill holes 74 | image_folder = main_folder + "/foundations_filtered_padded_resized_refiltered" 75 | dataset_folder = main_folder + "/exterior_foundations_filled" 76 | # make it if it doesn't exist 77 | os.makedirs(dataset_folder, exist_ok=True) 78 | save_folder = dataset_folder + "/images" 79 | metadata_folder = main_folder + "/foundations_filtered_padded" 80 | utils.fill_dataset(image_folder, save_folder, copy_metadata=False) 81 | # copy metadata folder to save folder and change its name to metadata 82 | utils.copy_metadata(metadata_folder, dataset_folder + "/metadata") 83 | # make occupancy, in this case is the same as the images folder 84 | # copy folder but change name 85 | shutil.copytree( 86 | main_folder + "/exterior_foundations_filled/images", 87 | main_folder + "/exterior_foundations_filled/occupancy", 88 | dirs_exist_ok=True, 89 | ) 90 | # pad the edges for navigation 91 | image_folder = main_folder + "/exterior_foundations_filled/images" 92 | save_folder = main_folder + "/exterior_foundations" 93 | metadata_folder = main_folder + "/foundations_filtered_padded" 94 | utils.pad_images_and_update_metadata( 95 | image_folder, metadata_folder, padding, (220, 220, 200), save_folder 96 | ) 97 | # restructure folder format 98 | image_folder = main_folder + "/exterior_foundations/images" 99 | metadata_folder = main_folder + "/exterior_foundations/metadata" 100 | utils.copy_metadata(save_folder, metadata_folder) 101 | # remove all json files from save_folder 102 | os.system("rm " + save_folder + "/*.json") 103 | # move all remaining images in /exterior_foundations inside /exterior_foundations/images 104 | # Create the destination directory if it doesn't exist 105 | os.makedirs(image_folder, exist_ok=True) 106 | 107 | # Get a list of all files in the source directory 108 | files = os.listdir(save_folder) 109 | 110 | # Filter the files to only include PNG images 111 | png_files = [file for file in files if file.lower().endswith(".png")] 112 | 113 | # Move each PNG image to the destination directory 114 | for file in png_files: 115 | source_path = os.path.join(save_folder, file) 116 | destination_path = os.path.join(image_folder, file) 117 | shutil.move(source_path, destination_path) 118 | 119 | shutil.copytree( 120 | main_folder + "/exterior_foundations/images", 121 | main_folder + "/exterior_foundations/occupancy", 122 | dirs_exist_ok=True, 123 | ) 124 | 125 | 126 | def create_exterior_foundations_traversable(main_folder): 127 | dataset_folder = main_folder + "/exterior_foundations" 128 | save_folder = main_folder + "/exterior_foundations_traversable" 129 | utils.make_obstacles_traversable( 130 | dataset_folder + "/images", save_folder + "/images" 131 | ) 132 | # copy metadata folder to save folder and change its name to metadata 133 | utils.copy_metadata(dataset_folder + "/metadata", save_folder + "/metadata") 134 | # generate empty occupancy 135 | utils.generate_empty_occupancy( 136 | dataset_folder + "/images", save_folder + "/occupancy" 137 | ) 138 | 139 | 140 | def create_foundations(main_folder): 141 | dataset_folder = main_folder + "/foundations_filtered_padded_resized_refiltered" 142 | save_folder = main_folder + "/foundations" 143 | utils.invert_dataset_apply_dump_foundations(dataset_folder, save_folder) 144 | utils.copy_metadata(dataset_folder, save_folder + f"/metadata") 145 | utils.generate_empty_occupancy(dataset_folder, save_folder + f"/occupancy") 146 | 
147 | 148 | if __name__ == "__main__": 149 | # Basel center_bbox = (47.5376, 47.6126, 7.5401, 7.6842) 150 | # Basel center_bbox small = (47.5645, 47.572, 7.5867, 7.5979) 151 | # Zurich center_bbox small (benchmark) = (47.378177, 47.364622, 8.526535, 8.544894) 152 | sizes = [(20, 60)] # , (40, 80), (80, 160), (160, 320), (320, 640)] 153 | package_dir = os.path.dirname( 154 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 155 | ) 156 | for size in sizes: 157 | dataset_folder = os.path.join(package_dir, "data", "openstreet") 158 | download_foundations( 159 | dataset_folder, 160 | min_size=(size[0], size[0]), 161 | max_size=(size[1], size[1]), 162 | center_bbox=(47.5376, 47.6126, 7.5401, 7.6842), 163 | max_buildings=100 164 | ) 165 | create_foundations(dataset_folder) 166 | -------------------------------------------------------------------------------- /terra/wrappers.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import Array 4 | 5 | from terra.config import EnvConfig 6 | from terra.state import State 7 | from terra.utils import angle_idx_to_rad 8 | from terra.utils import apply_local_cartesian_to_cyl 9 | from terra.utils import apply_rot_transl 10 | from terra.utils import compute_polygon_mask 11 | from terra.utils import get_arm_angle_int 12 | from terra.settings import IntLowDim 13 | 14 | 15 | class TraversabilityMaskWrapper: 16 | @staticmethod 17 | def wrap(state: State) -> State: 18 | """ 19 | Encodes the traversability mask in GridWorld. 20 | 21 | The traversability mask has the same size as the action map, and encodes: 22 | - 0: no obstacle 23 | - 1: obstacle (digged or dumped tile) 24 | - (-1): agent occupying the tile 25 | """ 26 | # encode map obstacles 27 | traversability_mask = (state.world.action_map.map != 0).astype(IntLowDim) 28 | 29 | # encode agent pos and size in the map 30 | agent_corners = state._get_agent_corners( 31 | state.agent.agent_state.pos_base, 32 | state.agent.agent_state.angle_base, 33 | state.env_cfg.agent.width, 34 | state.env_cfg.agent.height, 35 | ) 36 | 37 | map_width = state.world.width 38 | map_height = state.world.height 39 | 40 | polygon_mask = compute_polygon_mask(agent_corners, map_width, map_height) 41 | traversability_mask = jnp.where(polygon_mask, -1, traversability_mask) 42 | 43 | padding_mask = state.world.padding_mask.map 44 | tm = jnp.where(padding_mask == 1, padding_mask, traversability_mask) 45 | 46 | return state._replace( 47 | world=state.world._replace( 48 | traversability_mask=state.world.traversability_mask._replace( 49 | map=tm.astype(IntLowDim) 50 | ) 51 | ) 52 | ) 53 | 54 | 55 | class LocalMapWrapper: 56 | @staticmethod 57 | def _wrap(state: State, map_to_wrap: Array) -> Array: 58 | """ 59 | Encodes the local map in the GridWorld. 60 | 61 | The local map is of dim angles_cabin, and encodes the cumulative 62 | sum of tiles to dig in the area spanned by the cyilindrical tile. 
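For example (illustrative numbers): a tile one cell to the agent's right in
local coordinates, (x=0, y=1), maps to r = 1 and theta = arctan2(-0, 1) = 0 rad,
while a tile one cell below, (x=1, y=0), maps to r = 1 and theta = -pi/2
(see apply_local_cartesian_to_cyl). Each cabin angle then accumulates the map
values of the tiles whose (r, theta) fall inside the sector it can reach.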
63 | """ 64 | current_pos_idx = state._get_current_pos_vector_idx( 65 | pos_base=state.agent.agent_state.pos_base, 66 | map_height=state.env_cfg.maps.edge_length_px, 67 | ) 68 | map_global_coords = state._map_to_flattened_global_coords( 69 | state.world.width, state.world.height, state.env_cfg.tile_size 70 | ) 71 | current_pos = state._get_current_pos_from_flattened_map( 72 | map_global_coords, current_pos_idx 73 | ) 74 | current_arm_angle = get_arm_angle_int( 75 | state.agent.agent_state.angle_base, 76 | state.agent.agent_state.angle_cabin, 77 | state.env_cfg.agent.angles_base, 78 | state.env_cfg.agent.angles_cabin, 79 | ) 80 | 81 | # Get the cumsum of the action height map in cyl coords, for every of [r, theta] portion of local space 82 | angles_cabin = ( 83 | EnvConfig().agent.angles_cabin 84 | ) # TODO: make state.env_cfg work instead of recreating the object every time 85 | arm_angles_ints = jnp.arange(angles_cabin) 86 | arm_angles_rads = jax.vmap( 87 | lambda angle: angle_idx_to_rad(angle, EnvConfig().agent.angles_cabin) 88 | )(arm_angles_ints) 89 | 90 | possible_states_arm = jax.vmap(lambda angle: jnp.hstack([current_pos, angle]))( 91 | arm_angles_rads 92 | ) 93 | possible_maps_local_coords_arm = jax.vmap( 94 | lambda arm_state: apply_rot_transl(arm_state, map_global_coords) 95 | )(possible_states_arm) 96 | possible_maps_cyl_coords = jax.vmap(apply_local_cartesian_to_cyl)( 97 | possible_maps_local_coords_arm 98 | ) # (n_angles x 2 x width*height) 99 | 100 | local_cartesian_masks = jax.vmap(lambda map: state._get_dig_dump_mask_cyl(map))( 101 | possible_maps_cyl_coords 102 | ) 103 | map_to_wrap_reshaped = map_to_wrap.reshape(state.world.height, state.world.width) 104 | local_cyl_height_map = jax.vmap( 105 | lambda mask: (map_to_wrap_reshaped * mask.reshape(state.world.height, state.world.width)).sum() 106 | )(local_cartesian_masks) 107 | 108 | # Roll it to bring it back in agent view 109 | local_cyl_height_map = jnp.roll( 110 | local_cyl_height_map, -current_arm_angle, axis=0 111 | ) 112 | 113 | return local_cyl_height_map.astype(IntLowDim) 114 | 115 | @staticmethod 116 | def wrap_target_map(state: State) -> State: 117 | target_map_pos = jnp.clip(state.world.target_map.map, a_min=0) 118 | target_map_neg = jnp.clip(state.world.target_map.map, a_max=0) 119 | local_map_target_pos = LocalMapWrapper._wrap(state, target_map_pos) 120 | local_map_target_neg = LocalMapWrapper._wrap(state, target_map_neg) 121 | return state._replace( 122 | world=state.world._replace( 123 | local_map_target_pos=state.world.local_map_target_pos._replace( 124 | map=local_map_target_pos 125 | ), 126 | local_map_target_neg=state.world.local_map_target_neg._replace( 127 | map=local_map_target_neg 128 | ), 129 | ) 130 | ) 131 | 132 | @staticmethod 133 | def wrap_action_map(state: State) -> State: 134 | action_map_pos = jnp.clip(state.world.action_map.map, a_min=0) 135 | action_map_neg = jnp.clip(state.world.action_map.map, a_max=0) 136 | local_map_action_pos = LocalMapWrapper._wrap(state, action_map_pos) 137 | local_map_action_neg = LocalMapWrapper._wrap(state, action_map_neg) 138 | return state._replace( 139 | world=state.world._replace( 140 | local_map_action_pos=state.world.local_map_action_pos._replace( 141 | map=local_map_action_pos 142 | ), 143 | local_map_action_neg=state.world.local_map_action_neg._replace( 144 | map=local_map_action_neg 145 | ), 146 | ) 147 | ) 148 | 149 | @staticmethod 150 | def wrap_dumpability_mask(state: State) -> State: 151 | dumpability_mask = state.world.dumpability_mask.map 152 | 
local_map_dumpability = LocalMapWrapper._wrap(state, dumpability_mask) 153 | return state._replace( 154 | world=state.world._replace( 155 | local_map_dumpability=state.world.local_map_dumpability._replace( 156 | map=local_map_dumpability 157 | ) 158 | ) 159 | ) 160 | 161 | @staticmethod 162 | def wrap_obstacles_mask(state: State) -> State: 163 | obstacles_mask = state.world.padding_mask.map 164 | local_map_obstacles = LocalMapWrapper._wrap(state, obstacles_mask) 165 | return state._replace( 166 | world=state.world._replace( 167 | local_map_obstacles=state.world.local_map_obstacles._replace( 168 | map=local_map_obstacles 169 | ) 170 | ) 171 | ) 172 | 173 | @staticmethod 174 | def wrap(state: State) -> State: 175 | state = LocalMapWrapper.wrap_target_map(state) 176 | state = LocalMapWrapper.wrap_action_map(state) 177 | state = LocalMapWrapper.wrap_dumpability_mask(state) 178 | state = LocalMapWrapper.wrap_obstacles_mask(state) 179 | return state 180 | -------------------------------------------------------------------------------- /terra/config.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | from typing import NamedTuple 3 | 4 | from terra.actions import Action 5 | from terra.actions import TrackedAction # noqa: F401 6 | from terra.actions import WheeledAction # noqa: F401 7 | 8 | 9 | class ExcavatorDims(NamedTuple): 10 | WIDTH: float = 6.08 # longer side 11 | HEIGHT: float = 3.5 # shorter side 12 | 13 | 14 | class RewardsType(IntEnum): 15 | DENSE = 0 16 | SPARSE = 1 17 | 18 | 19 | class ImmutableMapsConfig(NamedTuple): 20 | """ 21 | Define the max size of the map in meters. 22 | This defines the proportion between the map and the agent. 23 | """ 24 | 25 | edge_length_m: float = 44.0 # map edge length in meters 26 | edge_length_px: int = 0 # updated in the code 27 | 28 | 29 | class TargetMapConfig(NamedTuple): 30 | pass 31 | 32 | 33 | class ActionMapConfig(NamedTuple): 34 | pass 35 | 36 | 37 | class ImmutableAgentConfig(NamedTuple): 38 | """ 39 | The part of the AgentConfig that won't change based on curriculum. 
40 | """ 41 | 42 | dimensions: ExcavatorDims = ExcavatorDims() 43 | angles_base: int = 12 44 | angles_cabin: int = 12 45 | max_wheel_angle: int = 2 46 | wheel_step: float = 25.0 # difference between next angles in discretization (in degrees) 47 | num_state_obs: int = 6 # number of state observations (used to determine network input) 48 | 49 | 50 | class AgentConfig(NamedTuple): 51 | random_init_state: bool = True 52 | 53 | angles_base: int = ImmutableAgentConfig().angles_base 54 | angles_cabin: int = ImmutableAgentConfig().angles_cabin 55 | max_wheel_angle: int = ImmutableAgentConfig().max_wheel_angle 56 | wheel_step: float = ImmutableAgentConfig().wheel_step 57 | 58 | move_tiles: int = 5 # number of tiles of progress for every move action 59 | # Note: move_tiles is also used as radius of excavation 60 | # (we dig as much as move_tiles in the radial distance) 61 | 62 | dig_depth: int = 1 # how much every dig action digs 63 | 64 | height: int = 0 # updated in the code 65 | width: int = 0 # updated in the code 66 | 67 | 68 | class Rewards(NamedTuple): 69 | existence: float 70 | 71 | collision_move: float 72 | move_while_loaded: float 73 | move: float 74 | move_with_turned_wheels: float 75 | 76 | collision_turn: float 77 | base_turn: float 78 | 79 | cabin_turn: float 80 | wheel_turn: float 81 | 82 | dig_wrong: float # dig where the target map is not negative (exclude case of positive action map -> moving dumped terrain) 83 | dump_wrong: float # given if loaded stayed the same or tried to dump in non-dumpable tile 84 | 85 | dig_correct: ( 86 | float # dig where the target map is negative, and not more than required 87 | ) 88 | dump_correct: float # dump where the target map is positive 89 | 90 | terminal: float # given if the action map is the same as the target map where it matters (digged tiles) 91 | 92 | normalizer: float # constant scaling factor for all rewards 93 | 94 | @staticmethod 95 | def dense(): 96 | return Rewards( 97 | existence=-0.1, 98 | collision_move=-0.2, 99 | move_while_loaded=0.0, 100 | move=-0.1, 101 | move_with_turned_wheels=-0.1, 102 | collision_turn=-0.1, 103 | base_turn=-0.1, 104 | cabin_turn=-0.05, 105 | wheel_turn=-0.05, 106 | dig_wrong=-0.25, 107 | dump_wrong=-1.0, 108 | dig_correct=0.2, 109 | dump_correct=0.15, 110 | terminal=100.0, 111 | normalizer=100.0, 112 | ) 113 | 114 | @staticmethod 115 | def sparse(): 116 | return Rewards( 117 | existence=-0.1, 118 | collision_move=-0.1, 119 | move_while_loaded=0.0, 120 | move=-0.05, 121 | move_with_turned_wheels=-0.15, 122 | collision_turn=-0.1, 123 | base_turn=-0.1, 124 | cabin_turn=-0.01, 125 | wheel_turn=-0.005, 126 | dig_wrong=-0.3, 127 | dump_wrong=-0.3, 128 | dig_correct=0.0, 129 | dump_correct=0.0, 130 | terminal=100.0, 131 | normalizer=100.0, 132 | ) 133 | 134 | 135 | class CurriculumConfig(NamedTuple): 136 | """State of the curriculum. 
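# Hedged sketch: RewardsType indexes into the two reward presets above, mirroring
# how CurriculumManager selects them with jax.lax.switch in terra/curriculum.py.
import jax

_rewards_list = [Rewards.dense, Rewards.sparse]
_sparse_rewards = jax.lax.switch(int(RewardsType.SPARSE), _rewards_list)
# _sparse_rewards.dig_correct is 0.0: the sparse preset gives no shaping for correct digs.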
This config should not be changed.""" 137 | 138 | level: int = 0 139 | consecutive_failures: int = 0 140 | consecutive_successes: int = 0 141 | 142 | 143 | class EnvConfig(NamedTuple): 144 | agent: AgentConfig = AgentConfig() 145 | 146 | target_map: TargetMapConfig = TargetMapConfig() 147 | action_map: ActionMapConfig = ActionMapConfig() 148 | 149 | maps: ImmutableMapsConfig = ImmutableMapsConfig() 150 | 151 | rewards: Rewards = Rewards.dense() 152 | 153 | apply_trench_rewards: bool = False 154 | alignment_coefficient: float = -0.08 155 | distance_coefficient: float = -0.04 156 | 157 | curriculum: CurriculumConfig = CurriculumConfig() 158 | 159 | max_steps_in_episode: int = 0 # changed by CurriculumManager 160 | tile_size: float = 0 # updated in the code 161 | 162 | @classmethod 163 | def new(cls): 164 | return EnvConfig() 165 | 166 | 167 | class MapsDimsConfig(NamedTuple): 168 | maps_edge_length: int = 0 # updated in the code 169 | 170 | 171 | class CurriculumGlobalConfig(NamedTuple): 172 | increase_level_threshold: int = 20 173 | decrease_level_threshold: int = 50 174 | last_level_type = "random" # ["random", "none"] 175 | 176 | # NOTE: all maps need to have the same size 177 | levels = [ 178 | { 179 | "maps_path": "terra/foundations", 180 | "max_steps_in_episode": 400, 181 | "rewards_type": RewardsType.DENSE, 182 | "apply_trench_rewards": False, 183 | }, 184 | { 185 | "maps_path": "terra/trenches/single", 186 | "max_steps_in_episode": 400, 187 | "rewards_type": RewardsType.DENSE, 188 | "apply_trench_rewards": True, 189 | }, 190 | { 191 | "maps_path": "terra/trenches/double", 192 | "max_steps_in_episode": 400, 193 | "rewards_type": RewardsType.DENSE, 194 | "apply_trench_rewards": True, 195 | }, 196 | { 197 | "maps_path": "terra/trenches/double_diagonal", 198 | "max_steps_in_episode": 400, 199 | "rewards_type": RewardsType.DENSE, 200 | "apply_trench_rewards": True, 201 | }, 202 | { 203 | "maps_path": "terra/foundations", 204 | "max_steps_in_episode": 400, 205 | "rewards_type": RewardsType.DENSE, 206 | "apply_trench_rewards": False, 207 | }, 208 | { 209 | "maps_path": "terra/trenches/triple_diagonal", 210 | "max_steps_in_episode": 400, 211 | "rewards_type": RewardsType.DENSE, 212 | "apply_trench_rewards": True, 213 | }, 214 | { 215 | "maps_path": "terra/foundations_large", 216 | "max_steps_in_episode": 500, 217 | "rewards_type": RewardsType.DENSE, 218 | "apply_trench_rewards": False, 219 | }, 220 | ] 221 | 222 | 223 | class BatchConfig(NamedTuple): 224 | action_type: Action = WheeledAction # [WheeledAction, TrackedAction] 225 | 226 | # Config to get data for batched env initialization 227 | agent: ImmutableAgentConfig = ImmutableAgentConfig() 228 | maps: ImmutableMapsConfig = ImmutableMapsConfig() 229 | maps_dims: MapsDimsConfig = MapsDimsConfig() 230 | 231 | curriculum_global: CurriculumGlobalConfig = CurriculumGlobalConfig() 232 | -------------------------------------------------------------------------------- /terra/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from jax import Array 5 | from terra.settings import IntLowDim 6 | from terra.settings import Float 7 | 8 | 9 | def increase_angle_circular(angle: IntLowDim, max_angle: IntLowDim) -> IntLowDim: 10 | """ 11 | Increases the angle by 1 until max angle. In case of max angle, 0 is returned. 
12 | 13 | Args: 14 | angle: int >= 0 15 | max_angle: int > 0 16 | """ 17 | return (angle + 1) % max_angle 18 | 19 | 20 | def decrease_angle_circular(angle: IntLowDim, max_angle: IntLowDim) -> IntLowDim: 21 | """ 22 | Decreases the angle by 1 until 0. In case of a negative value, max_angle - 1 is returned. 23 | 24 | Args: 25 | angle: int >= 0 26 | max_angle: int > 0 27 | """ 28 | return (angle + max_angle - 1) % max_angle 29 | 30 | 31 | def apply_rot_transl(anchor_state: Array, global_coords: Array) -> Array: 32 | """ 33 | Applies the following transform to every element of global_coords: 34 | local_coords = [R|t]^-1 global_coords 35 | where R and t are extracted from anchor state. 36 | 37 | Args: 38 | - anchor_state: (3, ) Array containing [x, y, theta (rad)] 39 | Note: this is intended as the local frame expressed in the global frame. 40 | - global_coords: (2, N) Array containing [x, y] 41 | Returns: 42 | - local_coords: (2, N) Array containing local [x, y] 43 | """ 44 | theta = anchor_state[2] 45 | costheta, sintheta = jnp.cos(theta), jnp.sin(theta) 46 | R_t = jnp.array([[costheta, sintheta], [-sintheta, costheta]]) 47 | 48 | t = anchor_state[:2] 49 | neg_Rt_t = -R_t @ t 50 | 51 | # Build the inverse transformation matrix 52 | T = jnp.block([[R_t, neg_Rt_t[:, None]], [jnp.array([0.0, 0.0, 1.0])]]) 53 | 54 | local_coords = jnp.einsum( 55 | "ij,jk->ik", 56 | T, 57 | jnp.vstack([global_coords, jnp.ones((1, global_coords.shape[1]))]), 58 | ) 59 | return local_coords[:2] 60 | 61 | 62 | def apply_local_cartesian_to_cyl(local_coords: Array) -> Array: 63 | """ 64 | Transforms the input array from local cartesian coordinates to 65 | cyilindrical coordinates. 66 | 67 | Note: this function takes also care of the fact that we use an 68 | unconventional reference frame (x vertical axis to the bottom, 69 | y horizontal axis to the right). You can see this in the computation of theta. 70 | 71 | Args: 72 | - local_coords: (2, N) Array with [x, y] rows 73 | Returns: 74 | - cyl_coords: (2, N) Array with [r, theta] rows, 75 | Note: theta belongs to [-pi, pi] 76 | """ 77 | x, y = local_coords[0], local_coords[1] 78 | r = jnp.sqrt(x * x + y * y) 79 | theta = jnp.arctan2(-x, y) 80 | return jnp.vstack([r, theta]) 81 | 82 | 83 | def wrap_angle_rad(angle: Float) -> Float: 84 | """ 85 | Wraps an angle in rad to the interval [-pi, pi] 86 | """ 87 | return (angle + np.pi) % (2 * np.pi) - np.pi 88 | 89 | 90 | def angle_idx_to_rad(angle: IntLowDim, idx_tot: IntLowDim) -> Float: 91 | """ 92 | Converts an angle idx (e.g. 4) to an angle in rad, given the max 93 | angle idx possible (e.g. 8). 94 | """ 95 | angle = 2.0 * np.pi * angle / idx_tot 96 | return wrap_angle_rad(angle) 97 | 98 | 99 | def get_arm_angle_int( 100 | angle_base, angle_cabin, n_angles_base, n_angles_cabin 101 | ) -> IntLowDim: 102 | """ 103 | Returns the equivalent int angle of the arm, expressed in 104 | the range of numbers allowed by the cabin angles. 105 | """ 106 | angles_cabin_base_ratio = round(n_angles_cabin / n_angles_base) 107 | return (angles_cabin_base_ratio * angle_base + angle_cabin) % n_angles_cabin 108 | 109 | 110 | def get_distance_point_to_line(p, abc): 111 | """ 112 | p = Array[x, y] 113 | abc = Array[A, B, C] 114 | """ 115 | # Note: this is swapped because in digmap we are computing the axes coefficients 116 | # with the opposite convention. 
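# Worked examples for the angle helpers above (illustrative):
#   increase_angle_circular(11, 12) -> 0        (wraps past the last index)
#   decrease_angle_circular(0, 12)  -> 11
#   angle_idx_to_rad(3, 12)         -> pi / 2   (3/12 of a full turn)
#   get_arm_angle_int(angle_base=1, angle_cabin=2,
#                     n_angles_base=12, n_angles_cabin=12) -> 3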
117 | p_x = p[1] 118 | p_y = p[0] 119 | 120 | numerator = jnp.abs(abc[0] * p_x + abc[1] * p_y + abc[2]) 121 | denominator = jnp.sqrt(abc[0] ** 2 + abc[1] ** 2) 122 | distance = numerator / denominator 123 | return distance 124 | 125 | 126 | def get_min_distance_point_to_lines(p, lines, trench_type): 127 | """ 128 | p = Array[x, y] 129 | lines = Array[Array[A, B, C]] 130 | trench_type = int, number of axis of a trench (-1 if not a trench) 131 | """ 132 | p = p.astype(Float) 133 | lines = lines.astype(Float) 134 | 135 | def _for_body_it(i, d_min): 136 | d = get_distance_point_to_line(p, lines[i]) 137 | return jnp.min( 138 | jnp.concatenate((jnp.array([d]), d_min)), 139 | axis=0, 140 | keepdims=True, 141 | ).astype(Float) 142 | 143 | d_min = jax.lax.fori_loop( 144 | 0, 145 | trench_type, 146 | _for_body_it, 147 | jnp.full((1,), 9999.0, dtype=Float), 148 | ) 149 | return d_min[0] 150 | 151 | 152 | def get_agent_corners( 153 | pos_base: Array, 154 | base_orientation: IntLowDim, 155 | agent_width: IntLowDim, 156 | agent_height: IntLowDim, 157 | angles_base: IntLowDim, 158 | ): 159 | """ 160 | Gets the coordinates of the 4 corners of the agent. 161 | The function uses a biased rounding strategy to avoid rectangle shrinkage. 162 | """ 163 | # Determine half dimensions using floor/ceil to properly handle odd dimensions. 164 | half_width_left = jnp.floor(agent_width / 2.0) 165 | half_width_right = jnp.ceil(agent_width / 2.0) 166 | half_height_bottom = jnp.floor(agent_height / 2.0) 167 | half_height_top = jnp.ceil(agent_height / 2.0) 168 | 169 | # Define corners in local coordinates relative to the center. 170 | local_corners = jnp.array([ 171 | [-half_width_left, -half_height_bottom], 172 | [ half_width_right, -half_height_bottom], 173 | [ half_width_right, half_height_top], 174 | [-half_width_left, half_height_top] 175 | ]) 176 | 177 | # Convert degrees to radians using JAX. 178 | angle_rad = (base_orientation.astype(jnp.float32) / jnp.array(angles_base, dtype=jnp.float32)) * (2 * jnp.pi) 179 | cos_a = jnp.cos(angle_rad) 180 | sin_a = jnp.sin(angle_rad) 181 | # Build the rotation matrix. 182 | R = jnp.array([[cos_a, -sin_a], 183 | [sin_a, cos_a]]) 184 | R = R.squeeze() 185 | 186 | # Rotate local corners and translate by the center position. 187 | global_corners_float = (R @ local_corners.T).T + jnp.array(pos_base, dtype=IntLowDim) 188 | 189 | # Bias the rounding: use floor if below the center, ceil otherwise. 190 | center_arr = jnp.array(pos_base, dtype=IntLowDim) 191 | biased_corners = jnp.where( 192 | global_corners_float < center_arr, 193 | jnp.floor(global_corners_float), 194 | jnp.ceil(global_corners_float) 195 | ).astype(IntLowDim) 196 | 197 | return biased_corners 198 | 199 | 200 | def compute_polygon_mask(corners: Array, map_width: int, map_height: int) -> Array: 201 | """ 202 | Compute a mask (map_width x map_height) indicating the cells covered 203 | by the polygon defined by its corners. 204 | """ 205 | # Create a grid of points. 
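# Illustrative check of get_agent_corners above (the call shape is an assumption
# about how it is typically invoked; the numbers follow from the math):
#   get_agent_corners(jnp.array([10, 10]), jnp.array([0]), 9, 5, 12)
# returns the four corners of a 9 x 5 footprint centred at (10, 10) with no
# rotation, i.e. approximately [[6, 8], [15, 8], [15, 13], [6, 13]].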
206 | xs = jnp.arange(map_height) 207 | ys = jnp.arange(map_width) 208 | X, Y = jnp.meshgrid(xs, ys, indexing='xy') 209 | pts = jnp.stack([Y, X], axis=-1).reshape((-1, 2)) # (N,2) as [y,x] 210 | edges = jnp.roll(corners, -1, axis=0) - corners # (4,2) 211 | diff = pts[None, :, :] - corners[:, None, :] # (4, N, 2) 212 | edges_exp = edges[:, None, :] # (4, 1, 2) 213 | cross = edges_exp[..., 0] * diff[..., 1] - edges_exp[..., 1] * diff[..., 0] # (4, N) 214 | inside = jnp.logical_or(jnp.all(cross > 0, axis=0), jnp.all(cross < 0, axis=0)) 215 | mask = inside.reshape((map_height, map_width)) 216 | return mask 217 | -------------------------------------------------------------------------------- /terra/viz/game/game.py: -------------------------------------------------------------------------------- 1 | import pygame as pg 2 | import sys 3 | from PIL import Image 4 | from .world import World 5 | from .agent import Agent 6 | from .settings import MAP_TILES 7 | from terra.config import ExcavatorDims, ImmutableMapsConfig, ImmutableAgentConfig 8 | import threading 9 | import math 10 | 11 | 12 | def get_agent_dims(agent_w_m, agent_h_m, tile_size_m): 13 | """TODO repeated function, move to utils and share with env.""" 14 | agent_height = ( 15 | round(agent_w_m / tile_size_m) 16 | if (round(agent_w_m / tile_size_m)) % 2 != 0 17 | else round(agent_w_m / tile_size_m) + 1 18 | ) 19 | agent_width = ( 20 | round(agent_h_m / tile_size_m) 21 | if (round(agent_h_m / tile_size_m)) % 2 != 0 22 | else round(agent_h_m / tile_size_m) + 1 23 | ) 24 | return agent_width, agent_height 25 | 26 | 27 | class Game: 28 | def __init__( 29 | self, 30 | screen, 31 | surface, 32 | clock, 33 | maps_size_px, 34 | n_envs_x=1, 35 | n_envs_y=1, 36 | display=True, 37 | ): 38 | self.screen = screen 39 | self.surface = surface 40 | self.clock = clock 41 | self.display = display 42 | self.width, self.height = self.screen.get_size() 43 | 44 | self.n_envs_x = n_envs_x 45 | self.n_envs_y = n_envs_y 46 | self.n_envs = n_envs_x * n_envs_y 47 | self.worlds = [] 48 | self.agents = [] 49 | 50 | tile_size_m = ImmutableMapsConfig().edge_length_m / maps_size_px 51 | self.maps_size_px = maps_size_px 52 | tile_size = MAP_TILES // maps_size_px 53 | self.tile_size = tile_size 54 | excavator_dims = ExcavatorDims() 55 | agent_h, agent_w = get_agent_dims( 56 | excavator_dims.WIDTH, excavator_dims.HEIGHT, tile_size_m 57 | ) 58 | angles_base = ImmutableAgentConfig().angles_base 59 | angles_cabin = ImmutableAgentConfig().angles_cabin 60 | print(f"Agent size (in rendering): {agent_w}x{agent_h}") 61 | print(f"Tile size (in rendering): {tile_size_m}") 62 | print(f"Rendering tile size: {tile_size}") 63 | print(f"Number of possible base rotations: {angles_base}") 64 | print(f"Number of possible cabin rotations: {angles_cabin}") 65 | for _ in range(self.n_envs): 66 | self.worlds.append( 67 | World(maps_size_px, maps_size_px, self.width, self.height, tile_size) 68 | ) 69 | self.agents.append(Agent(agent_w, agent_h, tile_size, angles_base, angles_cabin)) 70 | 71 | self.frames = [] 72 | 73 | self.old_agents = [] 74 | self.count = 0 75 | 76 | def run( 77 | self, 78 | active_grid, 79 | target_grid, 80 | padding_mask, 81 | dumpability_mask, 82 | agent_pos, 83 | base_dir, 84 | cabin_dir, 85 | loaded, 86 | generate_gif, 87 | target_tiles=None, 88 | ): 89 | # self.events() 90 | self.update( 91 | active_grid, 92 | target_grid, 93 | padding_mask, 94 | dumpability_mask, 95 | agent_pos, 96 | base_dir, 97 | cabin_dir, 98 | loaded, 99 | target_tiles, 100 | ) 101 | self.draw() 
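# Worked example for get_agent_dims above (illustrative, assuming a 64-tile map
# edge): tile_size_m = 44.0 / 64 = 0.6875 m, so the excavator's 6.08 m side spans
# round(6.08 / 0.6875) = 9 tiles and the 3.5 m side spans round(3.5 / 0.6875) = 5
# tiles; a value that rounded to an even number would be bumped by one so the
# footprint always covers an odd number of tiles and stays centred on one tile.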
102 | if generate_gif: 103 | frame = pg.surfarray.array3d(pg.display.get_surface()) 104 | self.frames.append(frame.swapaxes(0, 1)) 105 | 106 | def create_gif(self, gif_path="/home/antonio/Downloads/Terra.gif"): 107 | image_frames = [Image.fromarray(frame) for frame in self.frames] 108 | image_frames[0].save( 109 | gif_path, 110 | save_all=True, 111 | append_images=image_frames[1:], 112 | loop=0, 113 | duration=100, 114 | ) 115 | print(f"GIF generated at {gif_path}") 116 | self.frames = [] 117 | 118 | def events(self): 119 | for event in pg.event.get(): 120 | if event.type == pg.QUIT: 121 | pg.quit() 122 | sys.exit() 123 | if event.type == pg.KEYDOWN: 124 | if event.key == pg.K_ESCAPE: 125 | pg.quit() 126 | sys.exit() 127 | 128 | def update( 129 | self, 130 | active_grid, 131 | target_grid, 132 | padding_mask, 133 | dumpability_mask, 134 | agent_pos, 135 | base_dir, 136 | cabin_dir, 137 | loaded, 138 | target_tiles=None, 139 | ): 140 | def update_world_agent( 141 | world, 142 | agent, 143 | active_grid, 144 | target_grid, 145 | padding_mask, 146 | dumpability_mask, 147 | agent_pos, 148 | base_dir, 149 | cabin_dir, 150 | loaded, 151 | target_tiles=None, 152 | ): 153 | world.update(active_grid, target_grid, padding_mask, dumpability_mask) 154 | agent.update(agent_pos, base_dir, cabin_dir, loaded) 155 | if target_tiles is not None: 156 | world.target_tiles = target_tiles 157 | 158 | threads = [] 159 | for i in range(self.n_envs): 160 | ag = active_grid[i] 161 | tg = target_grid[i] 162 | pm = padding_mask[i] 163 | dm = dumpability_mask[i] 164 | ap = agent_pos[i] 165 | bd = base_dir[i] 166 | cd = cabin_dir[i] 167 | ld = loaded[i] 168 | tt = None if target_tiles is None else target_tiles[i] 169 | thread = threading.Thread( 170 | target=update_world_agent, 171 | args=(self.worlds[i], self.agents[i], ag, tg, pm, dm, ap, bd, cd, ld, tt), 172 | ) 173 | thread.start() 174 | threads.append(thread) 175 | 176 | for thread in threads: 177 | thread.join() 178 | 179 | def draw(self): 180 | self.surface.fill("#F0F0F0") 181 | agent_surfaces = [] 182 | agent_positions = [] 183 | 184 | for i, (world, agent) in enumerate(zip(self.worlds, self.agents)): 185 | ix = i % self.n_envs_y 186 | iy = i // self.n_envs_y 187 | 188 | total_offset_x = ( 189 | ix * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size 190 | ) 191 | total_offset_y = ( 192 | iy * (self.maps_size_px + 4) * self.tile_size + 4 * self.tile_size 193 | ) 194 | 195 | # Draw terrain 196 | for x in range(world.grid_length_x): 197 | for y in range(world.grid_length_y): 198 | sq = world.action_map[x][y]["cart_rect"] 199 | c = world.action_map[x][y]["color"] 200 | rect = pg.Rect( 201 | sq[0][0] + total_offset_x, 202 | sq[0][1] + total_offset_y, 203 | self.tile_size, 204 | self.tile_size, 205 | ) 206 | pg.draw.rect(self.surface, c, rect, 0) 207 | 208 | # Highlight target tiles (where the digger will dig / dump) 209 | if hasattr(world, 'target_tiles') and world.target_tiles is not None: 210 | flat_idx = y * world.grid_length_x + x 211 | if flat_idx < len(world.target_tiles) and world.target_tiles[flat_idx]: 212 | pg.draw.rect(self.surface, "#FF3300", rect, 2) 213 | 214 | body_vertices = agent.agent["body"]["vertices"] 215 | ca = agent.agent["body"]["color"] 216 | 217 | # Calculate the bounding box 218 | min_x = min(v[0] for v in body_vertices) 219 | min_y = min(v[1] for v in body_vertices) 220 | max_x = max(v[0] for v in body_vertices) 221 | max_y = max(v[1] for v in body_vertices) 222 | 223 | # Calculate surface size with a small padding 224 | 
surface_width = math.ceil(max_x - min_x) + 2 225 | surface_height = math.ceil(max_y - min_y) + 2 226 | 227 | # Create surface for the agent 228 | agent_surfaces.append( 229 | pg.Surface((surface_width, surface_height), pg.SRCALPHA) 230 | ) 231 | 232 | # Calculate surface position 233 | agent_x = min_x + total_offset_x 234 | agent_y = min_y + total_offset_y 235 | agent_positions.append((agent_x, agent_y)) 236 | 237 | # Adjust vertices for the agent's surface 238 | offset_vertices = [(v[0] - min_x, v[1] - min_y) for v in body_vertices] 239 | 240 | # Draw agent body as polygon 241 | pg.draw.polygon(agent_surfaces[-1], ca, offset_vertices) 242 | 243 | # Get cabin vertices and adjust for agent surface 244 | cabin = agent.agent["cabin"]["vertices"] 245 | cabin_offset = [(v[0] - min_x, v[1] - min_y) for v in cabin] 246 | cabin_color = agent.agent["cabin"]["color"] 247 | pg.draw.polygon(agent_surfaces[-1], cabin_color, cabin_offset) 248 | 249 | self.screen.blit(self.surface, (0, 0)) 250 | 251 | for agent_surface, agent_position in zip(agent_surfaces, agent_positions): 252 | self.screen.blit(agent_surface, agent_position) 253 | 254 | if self.display: 255 | pg.display.flip() 256 | -------------------------------------------------------------------------------- /terra/map_utils/openstreet_dataset_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | import skimage.measure 7 | 8 | 9 | def _filter_buildings_on_dims(w_max: float, h_max: float) -> list[int]: 10 | idx_list = [] 11 | for i in range(2700): 12 | try: 13 | with open(f"{openstreet_path}/metadata/building_{i}.json") as f: 14 | meta1 = json.load(f) 15 | w1 = int(meta1["real_dimensions"]["length"]) 16 | h1 = int(meta1["real_dimensions"]["width"]) 17 | if w1 < w_max and h1 < h_max: 18 | idx_list.append(i) 19 | except: 20 | continue 21 | return idx_list 22 | 23 | 24 | def _get_building_h_w(idx: int) -> tuple[float, float]: 25 | with open(f"{openstreet_path}/metadata/building_{idx}.json") as f: 26 | meta1 = json.load(f) 27 | w1 = int(meta1["real_dimensions"]["length"]) 28 | h1 = int(meta1["real_dimensions"]["width"]) 29 | return w1, h1 30 | 31 | 32 | def _convert_img_to_terra(img): 33 | """ 34 | Converts an image from [0, 255] convention 35 | to [-1, 0, 1] convention. 
36 | """ 37 | img = img[..., 0].astype(np.int16) # squeeze to 1 channel 38 | img = np.where(img == 0, -1, img) 39 | img = np.where(img == 100, 1, img) 40 | img = np.where(img == 255, 0, img) 41 | return img.astype(np.int8) 42 | 43 | 44 | def _handle_options(img, option: int, img_path: str | None): 45 | """ 46 | The input image is a NxMx3 uint8 image, where: 47 | 0 --> dig 48 | 100 --> dump 49 | 255 --> nothing 50 | 51 | option: [1, 2] 52 | 1 --> save image at path (converted to [-1, 0, 1]) 53 | 2 --> show img 54 | """ 55 | target_path = Path(img_path).parent 56 | target_path.mkdir(parents=True, exist_ok=True) 57 | if option == 1: 58 | img = _convert_img_to_terra(img) 59 | np.save(img_path, img) 60 | elif option == 2: 61 | img = img.astype(np.uint8).repeat(10, 0).repeat(10, 1) 62 | cv2.imshow("buildings", img.astype(np.uint8)) 63 | cv2.waitKey(0) 64 | cv2.destroyAllWindows() 65 | else: 66 | raise RuntimeError(f"Option {option} doesn't exist.") 67 | 68 | 69 | def generate_openstreet_2( 70 | wm, hm: int, div: int, option: int = 1, max_n_imgs: int | None = None 71 | ): 72 | """ 73 | option: [1, 2] 74 | 1 --> save image at path 75 | 2 --> show imgs 76 | """ 77 | w, h = div * wm, div * hm 78 | 79 | idx_list = _filter_buildings_on_dims(int(0.65 * wm), int(0.65 * hm)) 80 | n = len(idx_list) 81 | idx_list = np.array(idx_list) 82 | idx_list_a = idx_list[None].repeat(n, 0).reshape(-1) 83 | idx_list_b = idx_list[:, None].repeat(n, 1).reshape(-1) 84 | idx_list = np.vstack([idx_list_a, idx_list_b]) # 2x(N^2) -> all combinations 85 | np.random.shuffle(idx_list.swapaxes(0, 1)) 86 | n_combinations = idx_list.shape[-1] 87 | 88 | n_overlapping = 0 89 | img_idx = 0 90 | for i in range(n_combinations): 91 | if max_n_imgs is not None and i >= max_n_imgs: 92 | break 93 | 94 | idx1 = idx_list[0, i] 95 | idx2 = idx_list[1, i] 96 | 97 | building_img_path1 = f"{openstreet_path}/images/building_{idx1}.png" 98 | building_img_path2 = f"{openstreet_path}/images/building_{idx2}.png" 99 | 100 | w1, h1 = _get_building_h_w(idx1) 101 | w2, h2 = _get_building_h_w(idx2) 102 | 103 | pic1 = cv2.imread(building_img_path1) 104 | pic2 = cv2.imread(building_img_path2) 105 | 106 | p = np.ones((w, h, 3)) * 255 107 | p1 = p.copy() 108 | 109 | pic1 = cv2.resize(pic1, (w1 * div, h1 * div)).astype( 110 | np.uint8 111 | ) # , interpolation=cv2.INTER_AREA) 112 | pic2 = cv2.resize(pic2, (w2 * div, h2 * div)).astype( 113 | np.uint8 114 | ) # , interpolation=cv2.INTER_AREA) 115 | 116 | for _ in range(40): 117 | x1 = np.random.randint(0, w - pic1.shape[0]) 118 | y1 = np.random.randint(0, h - pic1.shape[1]) 119 | p[x1 : x1 + pic1.shape[0], y1 : y1 + pic1.shape[1]] = np.where( 120 | pic1 < 255, 0, pic1 121 | ) 122 | mask1 = p == 0 123 | 124 | x2 = np.random.randint(0, w - pic2.shape[0]) 125 | y2 = np.random.randint(0, h - pic2.shape[1]) 126 | p1[x2 : x2 + pic2.shape[0], y2 : y2 + pic2.shape[1]] = pic2 127 | mask2 = p1 == 0 128 | overlapping = np.any(mask1 * mask2) 129 | if not overlapping: 130 | p[x2 : x2 + pic2.shape[0], y2 : y2 + pic2.shape[1]] = np.where( 131 | pic2 < 255, 100, pic2 132 | ) 133 | break 134 | 135 | if not overlapping: 136 | # pd = cv2.resize(p, (40, 40)).repeat(10, 0).repeat(10, 1) 137 | pd = skimage.measure.block_reduce( 138 | p, (p.shape[0] // wm, p.shape[1] // hm, 1), np.min 139 | ) 140 | 141 | _handle_options( 142 | pd, 143 | option, 144 | f"{target_path}/2_buildings/{wm}x{hm}/img_{img_idx}.npy", 145 | ) 146 | img_idx += 1 147 | 148 | continue 149 | else: 150 | n_overlapping += 1 151 | print(f"{n_overlapping=}") 152 | 153 | 
print(f"Generated {img_idx} images.") 154 | 155 | 156 | def generate_openstreet_3( 157 | wm, hm: int, div: int, option: int = 1, max_n_imgs: int | None = None 158 | ): 159 | """ 160 | option: [1, 2] 161 | 1 --> save image at path 162 | 2 --> show imgs 163 | """ 164 | w, h = div * wm, div * hm 165 | 166 | idx_list = _filter_buildings_on_dims(int(0.65 * wm), int(0.65 * hm)) 167 | n = len(idx_list) 168 | idx_list = np.array(idx_list) 169 | idx_list_plain = idx_list.copy() 170 | idx_list_a = idx_list[None].repeat(n, 0).reshape(-1) 171 | idx_list_b = idx_list[:, None].repeat(n, 1).reshape(-1) 172 | idx_list = np.vstack([idx_list_a, idx_list_b]) # 2x(N^2) -> all combinations of 2 173 | 174 | img_idx = 0 175 | for idx3 in idx_list_plain.tolist(): 176 | if max_n_imgs is not None and img_idx >= max_n_imgs: 177 | break 178 | np.random.shuffle(idx_list.swapaxes(0, 1)) 179 | n_combinations = idx_list.shape[-1] 180 | 181 | n_overlapping = 0 182 | for i in range(n_combinations): 183 | if max_n_imgs is not None and img_idx >= max_n_imgs: 184 | break 185 | 186 | idx1 = idx_list[0, i] 187 | idx2 = idx_list[1, i] 188 | 189 | building_img_path1 = f"{openstreet_path}/images/building_{idx1}.png" 190 | building_img_path2 = f"{openstreet_path}/images/building_{idx2}.png" 191 | building_img_path3 = f"{openstreet_path}/images/building_{idx3}.png" 192 | 193 | w1, h1 = _get_building_h_w(idx1) 194 | w2, h2 = _get_building_h_w(idx2) 195 | w3, h3 = _get_building_h_w(idx3) 196 | 197 | pic1 = cv2.imread(building_img_path1) 198 | pic2 = cv2.imread(building_img_path2) 199 | pic3 = cv2.imread(building_img_path3) 200 | 201 | p = np.ones((w, h, 3)) * 255 202 | p1 = p.copy() 203 | p2 = p.copy() 204 | 205 | pic1 = cv2.resize(pic1, (w1 * div, h1 * div)).astype(np.uint8) 206 | pic2 = cv2.resize(pic2, (w2 * div, h2 * div)).astype(np.uint8) 207 | pic3 = cv2.resize(pic3, (w3 * div, h3 * div)).astype(np.uint8) 208 | 209 | for _ in range(40): 210 | x1 = np.random.randint(0, w - pic1.shape[0]) 211 | y1 = np.random.randint(0, h - pic1.shape[1]) 212 | p[x1 : x1 + pic1.shape[0], y1 : y1 + pic1.shape[1]] = np.where( 213 | pic1 < 255, 0, pic1 214 | ) 215 | mask1 = p < 255 216 | 217 | x2 = np.random.randint(0, w - pic2.shape[0]) 218 | y2 = np.random.randint(0, h - pic2.shape[1]) 219 | p1[x2 : x2 + pic2.shape[0], y2 : y2 + pic2.shape[1]] = pic2 220 | mask2 = p1 < 255 221 | 222 | overlapping = np.any(mask1 * mask2) 223 | overlapping2 = True 224 | if not overlapping: 225 | x3 = np.random.randint(0, w - pic3.shape[0]) 226 | y3 = np.random.randint(0, h - pic3.shape[1]) 227 | p2[x3 : x3 + pic3.shape[0], y3 : y3 + pic3.shape[1]] = pic3 228 | mask3 = p2 < 255 229 | 230 | overlapping2 = np.any(mask1 * mask2 * mask3) 231 | if not overlapping2: 232 | p[x2 : x2 + pic2.shape[0], y2 : y2 + pic2.shape[1]] = np.where( 233 | pic2 < 255, 100, pic2 234 | ) 235 | p[x3 : x3 + pic3.shape[0], y3 : y3 + pic3.shape[1]] = np.where( 236 | pic3 < 255, 0, pic3 237 | ) 238 | break 239 | 240 | if not overlapping2: 241 | # pd = cv2.resize(p, (40, 40)).repeat(10, 0).repeat(10, 1) 242 | pd = skimage.measure.block_reduce( 243 | p, (p.shape[0] // wm, p.shape[1] // hm, 1), np.min 244 | ) 245 | 246 | # p_grey = np.where(p < 255, 100, p).astype(np.uint8) 247 | 248 | _handle_options( 249 | pd, 250 | option, 251 | f"{target_path}/3_buildings/{wm}x{hm}/img_{img_idx}.npy", 252 | ) 253 | img_idx += 1 254 | continue 255 | else: 256 | n_overlapping += 1 257 | print(f"{n_overlapping=}") 258 | 259 | print(f"Generated {img_idx} images.") 260 | 261 | 262 | if __name__ == "__main__": 263 | 
wm, hm = 20, 20 # meters 264 | div = 10 265 | openstreet_path = "/media/openstreet" 266 | target_path = "/media/img_generator" 267 | generate_openstreet_2(wm, hm, div, option=1, max_n_imgs=None) 268 | -------------------------------------------------------------------------------- /terra/map_utils/openstreet_plugin.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import osmnx as ox 8 | from PIL import Image 9 | from pyproj import CRS, Transformer 10 | from shapely.geometry import Point 11 | from shapely.geometry import Polygon 12 | 13 | 14 | def get_building_shapes_from_OSM( 15 | north, south, east, west, option=1, save_folder="data/", folder_path=None 16 | ): 17 | """ 18 | Extracts building shapes from OpenStreetMap given a bounding box of coordinates. 19 | 20 | Parameters: 21 | north (float): northern boundary of bounding box 22 | south (float): southern boundary of bounding box 23 | east (float): eastern boundary of bounding box 24 | west (float): western boundary of bounding box 25 | option (int, optional): Option for operation. 1 for saving a binary map, 2 for saving individual buildings. Defaults to 1. 26 | save_path (str, optional): File path to save output if option is 1. Defaults to 'output.png'. 27 | folder_path (str, optional): Folder path to save output if option is 2. Defaults to None. 28 | 29 | Returns: 30 | None 31 | """ 32 | # make sure the save folder exists 33 | if not os.path.exists(save_folder): 34 | os.makedirs(save_folder) 35 | 36 | # Fetch buildings from OSM\ 37 | bbox = (north, south, east, west) 38 | buildings = ox.geometries.geometries_from_bbox(*bbox, tags={"building": True}) 39 | print("got buildings") 40 | 41 | # Define coordinate reference systems 42 | wgs84 = CRS("EPSG:4326") # WGS84 (lat-long) coordinate system 43 | utm = CRS("EPSG:32633") # UTM zone 33N (covers central Europe) 44 | 45 | # Create transformer 46 | transformer = Transformer.from_crs(wgs84, utm, always_xy=True) 47 | 48 | # Check option 49 | if option == 1: 50 | extract_crop(buildings, wgs84, utm, north, south, east, west, save_folder, transformer) 51 | elif option == 2: 52 | extract_single_buildings(buildings, wgs84, utm, folder_path=folder_path, transformer=transformer) 53 | else: 54 | print("Invalid option selected. Choose either 1 or 2.") 55 | 56 | 57 | def extract_crop(buildings, wgs84, utm, north, south, east, west, save_folder, transformer): 58 | # Convert the bounding box to UTM 59 | west_utm, south_utm = transformer.transform(west, south) 60 | east_utm, north_utm = transformer.transform(east, north) 61 | 62 | # Create a Polygon that represents the bounding box 63 | bbox_polygon = Polygon( 64 | [ 65 | (west_utm, south_utm), 66 | (east_utm, south_utm), 67 | (east_utm, north_utm), 68 | (west_utm, north_utm), 69 | ] 70 | ) 71 | 72 | # Count how many buildings have their centroid within the bounding box 73 | buildings_in_bbox = sum( 74 | bbox_polygon.contains( 75 | Point(transformer.transform(*building.centroid.coords[0])) 76 | ) 77 | for building in buildings.geometry 78 | ) 79 | 80 | # Check the number of buildings. Return if less than 2. 
81 | if buildings_in_bbox < 2: 82 | print("Less than 2 buildings found in the given bounding box") 83 | return 84 | 85 | try: 86 | # Convert total bounds to UTM 87 | total_bounds_utm = [ 88 | transformer.transform( 89 | buildings.total_bounds[i], buildings.total_bounds[i + 1] 90 | ) 91 | for i in range(0, 4, 2) 92 | ] 93 | aspect_ratio = ( 94 | (total_bounds_utm[1][0] - total_bounds_utm[0][0]) 95 | / (total_bounds_utm[1][1] - total_bounds_utm[0][1]) 96 | ) ** (-1) 97 | dpi = 50 98 | max_pixels = 600 99 | max_inches = int(max_pixels / dpi) 100 | if aspect_ratio > 1: 101 | figsize = (max_inches, int(max_inches / aspect_ratio)) 102 | else: 103 | figsize = (int(max_inches * aspect_ratio), max_inches) 104 | # if 0 set it to 1 105 | if figsize[0] == 0: 106 | figsize = (1, figsize[1]) 107 | if figsize[1] == 0: 108 | figsize = (figsize[0], 1) 109 | 110 | # Plot and save the binary map 111 | fig, ax = plt.subplots(figsize=figsize, dpi=20) 112 | ax.axis("off") 113 | print("Plotting binary map...") 114 | ox.plot_footprints(buildings, ax=ax, color="black", bgcolor="white", show=False) 115 | print("Done") 116 | # save path include the coordinates of the bounding box 117 | save_path = f"{save_folder}/output_{north}_{south}_{east}_{west}" 118 | plt.savefig( 119 | save_path + ".png", 120 | dpi=300, 121 | bbox_inches="tight", 122 | pad_inches=0, 123 | transparent=True, 124 | ) 125 | print(f"Saved binary map in {save_path}") 126 | plt.close() 127 | except ValueError: 128 | print("No buildings found in the given bounding box") 129 | return 130 | 131 | # Post-process the image with OpenCV to count the number of distinct buildings 132 | image = cv2.imread(f"{save_path}.png", cv2.IMREAD_GRAYSCALE) 133 | ret, thresh = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY) 134 | num_labels, labels = cv2.connectedComponents(thresh) 135 | 136 | # Check the number of distinct buildings (subtract 1 because the background is also considered a label) 137 | num_buildings = num_labels - 1 138 | if num_buildings < 2: 139 | print("Less than 2 distinct buildings found in the image") 140 | os.remove( 141 | f"{save_path}.png" 142 | ) # Delete the image if it has less than 2 buildings 143 | return 144 | 145 | # Open the image file with PIL and get its size in pixels 146 | with Image.open(f"{save_path}.png") as img: 147 | width_px, height_px = img.size 148 | 149 | # Convert the size from pixels to inches 150 | new_size = (width_px / dpi, height_px / dpi) 151 | 152 | print(f"old size {figsize} and new size {new_size}") 153 | 154 | width = total_bounds_utm[1][0] - total_bounds_utm[0][0] 155 | height = total_bounds_utm[1][1] - total_bounds_utm[0][1] 156 | new_height = new_size[0] / figsize[0] * height 157 | new_width = new_size[1] / figsize[1] * width 158 | print(f"old width {width} and new width {new_width}") 159 | print(f"old height {height} and new height {new_height}") 160 | 161 | # Save metadata file with real dimensions in meters 162 | metadata = {"real_dimensions": {"height": new_width, "width": new_height}} 163 | with open(f"{save_path}.json", "w") as file: 164 | json.dump(metadata, file) 165 | 166 | 167 | def extract_single_buildings(buildings, wgs84, utm, folder_path=None, transformer=None): 168 | if folder_path is None: 169 | raise ValueError("No folder path provided") 170 | 171 | for i, building in enumerate(buildings.geometry): 172 | if building.area < 1e-9: 173 | continue 174 | 175 | bounds_utm = [ 176 | transformer.transform(building.bounds[i], building.bounds[i + 1]) 177 | for i in range(0, 4, 2) 178 | ] 179 | 
aspect_ratio = ( 180 | (bounds_utm[1][0] - bounds_utm[0][0]) 181 | / (bounds_utm[1][1] - bounds_utm[0][1]) 182 | ) ** (-1) 183 | dpi = 50 184 | max_pixels = 600 185 | max_inches = int(max_pixels / dpi) 186 | if aspect_ratio > 1: 187 | figsize = (max_inches, int(max_inches / aspect_ratio)) 188 | else: 189 | figsize = (int(max_inches * aspect_ratio), max_inches) 190 | # if 0 set it to 1 191 | if figsize[0] == 0: 192 | figsize = (1, figsize[1]) 193 | if figsize[1] == 0: 194 | figsize = (figsize[0], 1) 195 | 196 | fig, ax = plt.subplots(figsize=figsize) 197 | ax.axis("off") 198 | ax.set_xlim(bounds_utm[0][0], bounds_utm[1][0]) 199 | ax.set_ylim(bounds_utm[0][1], bounds_utm[1][1]) 200 | ax.invert_yaxis() 201 | ax.set_aspect("equal", adjustable="box") 202 | ox.plot_footprints( 203 | buildings.iloc[i : i + 1], ax=ax, color="black", bgcolor="white", show=False 204 | ) 205 | # plt.tight_layout() 206 | # new_size = fig.get_size_inches() 207 | # apply padding 208 | plt.savefig( 209 | f"{folder_path}/images/building_{i}.png", 210 | dpi=dpi, 211 | pad_inches=1.0, 212 | bbox_inches="tight", 213 | ) 214 | plt.close() 215 | # Open the image file with PIL and get its size in pixels 216 | with Image.open(f"{folder_path}/images/building_{i}.png") as img: 217 | width_px, height_px = img.size 218 | 219 | # Convert the size from pixels to inches 220 | new_size = (width_px / dpi, height_px / dpi) 221 | 222 | print(f"old size {figsize} and new size {new_size} for building {i}") 223 | 224 | width = bounds_utm[1][0] - bounds_utm[0][0] 225 | height = bounds_utm[1][1] - bounds_utm[0][1] 226 | new_height = new_size[0] / figsize[0] * height 227 | new_width = new_size[1] / figsize[1] * width 228 | print(f"old width {width} and new width {new_width} for building {i}") 229 | print( 230 | "old height {} and new height {} for building {}".format( 231 | height, new_height, i 232 | ) 233 | ) 234 | 235 | metadata = { 236 | "building_index": i, 237 | "real_dimensions": {"width": new_width, "height": new_height}, 238 | } 239 | 240 | with open(f"{folder_path}/metadata/building_{i}.json", "w") as file: 241 | json.dump(metadata, file) 242 | 243 | 244 | def collect_random_crops(bbox: tuple, scale_factor: float, save_folder: str): 245 | """ 246 | Collects random crops of the map using get_building_shapes_from_OSM option 1. 247 | Select a random crop of inside the main bbox. The size of the random crop is expressed as a fraction 248 | of the total bbox size by scale_factor. 249 | 250 | Parameters: 251 | bbox (Tuple): Bounding box coordinates (north, south, east, west). 252 | scale_factor (float): Scale factor to determine the size of the random crop. 
253 | 254 | Returns: 255 | None 256 | """ 257 | # Load the metadata file to obtain real dimensions 258 | # Convert total bounds to UTM 259 | length = bbox[1] - bbox[0] 260 | width = bbox[3] - bbox[2] 261 | 262 | # Calculate the size of the random crop 263 | crop_length = length * scale_factor 264 | crop_width = width * scale_factor 265 | 266 | # Calculate the range of valid crop positions 267 | min_x = bbox[3] + crop_width 268 | max_x = bbox[2] - crop_width 269 | min_y = bbox[1] + crop_length 270 | max_y = bbox[0] - crop_length 271 | 272 | # Generate random crop position 273 | random_x = random.uniform(min_x, max_x) 274 | random_y = random.uniform(min_y, max_y) 275 | 276 | # Define the new crop bounding box 277 | crop_bbox = ( 278 | random_y + crop_length, 279 | random_y - crop_length, 280 | random_x + crop_width, 281 | random_x - crop_width, 282 | ) 283 | 284 | # Call get_building_shapes_from_OSM with option 2 to save the random crop 285 | folder_path = f"{save_folder}/random_crops" 286 | print("getting random crop") 287 | get_building_shapes_from_OSM(*crop_bbox, option=1, save_folder=folder_path) 288 | 289 | print(f"Random crop saved in {folder_path}") 290 | 291 | 292 | if __name__ == "__main__": 293 | # to get the bbox use https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/satellite.ipynb#scrollTo=kvB16LLk0qPY 294 | center_bbox = (47.378177, 47.364622, 8.526535, 8.544894) 295 | zurich_bbox = (47.3458, 47.409, 8.5065, 8.5814) 296 | folder_path = "/home/antonio/Downloads/openstreet_v1" 297 | 298 | # get_building_shapes_from_OSM(*zurich_bbox, option=1, folder_path=folder_path, save_folder=folder_path + "/" + 299 | # "random_crops") 300 | # num_crop = 100 301 | # for i in range(num_crop): 302 | # print(f"crop {i} of {num_crop}") 303 | # collect_random_crops(center_bbox, 0.02, folder_path) 304 | get_building_shapes_from_OSM( 305 | *center_bbox, option=2, folder_path=folder_path, save_folder=folder_path 306 | ) 307 | -------------------------------------------------------------------------------- /terra/env_generation/create_train_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import math 5 | import numpy as np 6 | import cv2 7 | import skimage 8 | from pathlib import Path 9 | from terra.env_generation.procedural_data import ( 10 | generate_trenches_v2, 11 | add_obstacles, 12 | add_non_dumpables, 13 | initialize_image, 14 | save_or_display_image, 15 | convert_terra_pad_to_color, 16 | ) 17 | from terra.env_generation.convert_to_terra import ( 18 | _convert_dumpability_to_terra, 19 | _convert_img_to_terra, 20 | _convert_occupancy_to_terra, 21 | ) 22 | import terra.env_generation.convert_to_terra as convert_to_terra 23 | from terra.env_generation.utils import _get_img_mask, color_dict 24 | import os 25 | import yaml 26 | 27 | # Define package directory at module level 28 | PACKAGE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 29 | 30 | def create_procedural_trenches(config): 31 | # Load configurations from YAML 32 | resolution = config["resolution"] 33 | trenches_config = config["trenches"] 34 | difficulty_levels = trenches_config["difficulty_levels"] 35 | 36 | # Fix for loading tuples/lists correctly 37 | trenches_per_level = config["trenches"]["trenches_per_level"] 38 | corrected_trenches_per_level = [tuple(level) for level in trenches_per_level] 39 | 40 | n_imgs = config["n_imgs"] 41 | 42 | # Load new configurations for 
obstacles and non-dumpables 43 | n_obs_min = trenches_config["n_obs_min"] 44 | n_obs_max = trenches_config["n_obs_max"] 45 | size_obstacle_min = trenches_config["size_obstacle_min"] 46 | size_obstacle_max = trenches_config["size_obstacle_max"] 47 | n_nodump_min = trenches_config["n_nodump_min"] 48 | n_nodump_max = trenches_config["n_nodump_max"] 49 | size_nodump_min = trenches_config["size_nodump_min"] 50 | size_nodump_max = trenches_config["size_nodump_max"] 51 | 52 | for level, n_trenches in zip(difficulty_levels, corrected_trenches_per_level): 53 | save_folder = os.path.join("data/terra", "trenches", level) 54 | os.makedirs(save_folder, exist_ok=True) 55 | 56 | # Updated to use new configuration structure 57 | trench_dims_config = trenches_config["trench_dims"][level] 58 | trench_dims_min_ratio = trench_dims_config["min_ratio"] 59 | trench_dims_max_ratio = trench_dims_config["max_ratio"] 60 | 61 | trench_dims_min = ( 62 | max(1, int(trench_dims_min_ratio[0] * trenches_config["img_edge_min"])), 63 | max(1, int(trench_dims_min_ratio[1] * trenches_config["img_edge_max"])), 64 | ) 65 | trench_dims_max = ( 66 | max(1, int(trench_dims_max_ratio[0] * trenches_config["img_edge_min"])), 67 | max(1, int(trench_dims_max_ratio[1] * trenches_config["img_edge_max"])), 68 | ) 69 | 70 | diagonal = trench_dims_config["diagonal"] 71 | 72 | generate_trenches_v2( 73 | n_imgs, 74 | trenches_config["img_edge_min"], 75 | trenches_config["img_edge_max"], 76 | trench_dims_min, 77 | trench_dims_max, 78 | n_trenches, # Fixed to correctly pass the tuple/list 79 | resolution, 80 | save_folder, 81 | n_obs_min, 82 | n_obs_max, 83 | size_obstacle_min, 84 | size_obstacle_max, 85 | n_nodump_min, 86 | n_nodump_max, 87 | size_nodump_min, 88 | size_nodump_max, 89 | diagonal, 90 | ) 91 | 92 | 93 | def create_foundations(config, 94 | n_obs_min=1, 95 | n_obs_max=3, 96 | size_obstacle_min=6, 97 | size_obstacle_max=8, 98 | n_nodump_min=1, 99 | n_nodump_max=3, 100 | size_nodump_min=8, 101 | size_nodump_max=10, 102 | expansion_factor=1, 103 | all_dumpable=False, 104 | copy_metadata=True, 105 | has_dumpability=False, 106 | center_padding=True): 107 | """ 108 | Creates foundation environments using configurations from a YAML file. 109 | 110 | Parameters: 111 | - config (dict): Configuration dictionary loaded from YAML file. 112 | - n_obs_min (int): Minimum number of obstacles to add. 113 | - n_obs_max (int): Maximum number of obstacles to add. 114 | - size_obstacle_min (int): Minimum size of obstacles. 115 | - size_obstacle_max (int): Maximum size of obstacles. 116 | - n_nodump_min (int): Minimum number of non-dumpable areas. 117 | - n_nodump_max (int): Maximum number of non-dumpable areas. 118 | - size_nodump_min (int): Minimum size of non-dumpable areas. 119 | - size_nodump_max (int): Maximum size of non-dumpable areas. 120 | - expansion_factor (int): Factor to expand the image by. 121 | - all_dumpable (bool): Whether all areas should be dumpable. 122 | - copy_metadata (bool): Whether to copy metadata. 123 | - has_dumpability (bool): Whether the image has dumpability information. 124 | - center_padding (bool): Whether to center the padding. 
125 | """ 126 | # Extract configuration parameters 127 | foundation_config = config["foundations"] 128 | n_imgs = config["n_imgs"] 129 | size = foundation_config["max_size"] 130 | dataset_path = foundation_config["dataset_rel_path"] 131 | 132 | # Define save folder for the envs using os.path.join 133 | save_folder = os.path.join(PACKAGE_DIR, "data", "terra", "foundations") 134 | save_folder_large = os.path.join(PACKAGE_DIR, "data", "terra", "foundations_large") 135 | 136 | # Choose different downsampling factors for different curriculum levels 137 | downsampling_factors = { 138 | save_folder: 2, 139 | save_folder_large: 1, 140 | } 141 | 142 | # Get the full dataset path using os.path.join 143 | full_dataset_path = os.path.join(PACKAGE_DIR, dataset_path) 144 | 145 | # Process foundation images 146 | max_size = size 147 | foundations_name = "foundations" 148 | img_folder = Path(full_dataset_path) / foundations_name / "images" 149 | metadata_folder = Path(full_dataset_path) / foundations_name / "metadata" 150 | occupancy_folder = Path(full_dataset_path) / foundations_name / "occupancy" 151 | dumpability_folder = Path(full_dataset_path) / foundations_name / "dumpability" 152 | filename_start = sorted(os.listdir(img_folder))[0].split("_")[0] 153 | 154 | 155 | for curriculum_level, downsampling_factor in downsampling_factors.items(): 156 | for i, fn in enumerate(os.listdir(img_folder)): 157 | if i >= n_imgs: 158 | break 159 | 160 | print(f"Processing foundation nr {i + 1}") 161 | 162 | n = int(fn.split(".png")[0].split("_")[1]) 163 | filename = filename_start + f"_{n}.png" 164 | file_path = img_folder / filename 165 | 166 | occupancy_path = occupancy_folder / filename 167 | img = cv2.imread(str(file_path)) 168 | 169 | occupancy = cv2.imread(str(occupancy_path)) 170 | 171 | if has_dumpability: 172 | dumpability_path = dumpability_folder / filename 173 | dumpability = cv2.imread(str(dumpability_path)) 174 | 175 | with open( 176 | metadata_folder / f"{filename.split('.png')[0]}.json" 177 | ) as json_file: 178 | metadata = json.load(json_file) 179 | 180 | # Calculate downsample factors based on max_size 181 | downsample_factor_w = int(max(1, math.ceil(img.shape[1] / max_size))) * downsampling_factor 182 | downsample_factor_h = int(max(1, math.ceil(img.shape[0] / max_size))) * downsampling_factor 183 | 184 | img_downsampled = skimage.measure.block_reduce( 185 | img, (downsample_factor_h, downsample_factor_w, 1), np.max 186 | ) 187 | img = img_downsampled 188 | occupancy_downsampled = skimage.measure.block_reduce( 189 | occupancy, (downsample_factor_h, downsample_factor_w, 1), np.min, cval=0 190 | ) 191 | occupancy = occupancy_downsampled 192 | if has_dumpability: 193 | dumpability_downsampled = skimage.measure.block_reduce( 194 | dumpability, 195 | (downsample_factor_h, downsample_factor_w, 1), 196 | np.min, 197 | cval=0, 198 | ) 199 | dumpability = dumpability_downsampled 200 | 201 | # assert img_downsampled.shape[:-1] == occupancy_downsampled.shape 202 | img_terra = _convert_img_to_terra(img, all_dumpable) 203 | 204 | # Pad to max size 205 | if center_padding: 206 | xdim = max_size - img_terra.shape[0] 207 | ydim = max_size - img_terra.shape[1] 208 | # Note: applying full dumping tiles for the centered version 209 | img_terra_pad = np.ones((max_size, max_size), dtype=img_terra.dtype) 210 | img_terra_pad[ 211 | xdim // 2 : max_size - (xdim - xdim // 2), 212 | ydim // 2 : max_size - (ydim - ydim // 2), 213 | ] = img_terra 214 | # Note: applying no occupancy for the centered version (mismatch with 
Terra env) 215 | img_terra_occupancy = np.zeros((max_size, max_size), dtype=np.bool_) 216 | img_terra_occupancy[ 217 | xdim // 2 : max_size - (xdim - xdim // 2), 218 | ydim // 2 : max_size - (ydim - ydim // 2), 219 | ] = _convert_occupancy_to_terra(occupancy) 220 | if has_dumpability: 221 | img_terra_dumpability = np.zeros((max_size, max_size), dtype=np.bool_) 222 | img_terra_dumpability[ 223 | xdim // 2 : max_size - (xdim - xdim // 2), 224 | ydim // 2 : max_size - (ydim - ydim // 2), 225 | ] = _convert_dumpability_to_terra(dumpability) 226 | else: 227 | img_terra_pad = np.zeros((max_size, max_size), dtype=img_terra.dtype) 228 | img_terra_pad[: img_terra.shape[0], : img_terra.shape[1]] = img_terra 229 | img_terra_occupancy = np.ones((max_size, max_size), dtype=np.bool_) 230 | img_terra_occupancy[: occupancy.shape[0], : occupancy.shape[1]] = ( 231 | _convert_occupancy_to_terra(occupancy) 232 | ) 233 | if has_dumpability: 234 | img_terra_dumpability = np.zeros((max_size, max_size), dtype=np.bool_) 235 | img_terra_dumpability[ 236 | : dumpability.shape[0], : dumpability.shape[1] 237 | ] = _convert_dumpability_to_terra(dumpability) 238 | 239 | img_terra_pad = img_terra_pad.repeat(expansion_factor, 0).repeat( 240 | expansion_factor, 1 241 | ) 242 | img_terra_pad = convert_terra_pad_to_color(img_terra_pad, color_dict) 243 | dumping_image = initialize_image(size, size, color_dict["dumping"]) 244 | 245 | # Create a mask where img_terra_pad is not equal to color_dict["digging"] 246 | mask = np.all(img_terra_pad != color_dict["digging"], axis=-1) 247 | 248 | # Use the mask to assign values from dumping_image to img_terra_pad 249 | img_terra_pad[mask] = dumping_image[mask] 250 | 251 | cumulative_mask = np.zeros_like(img_terra_pad, dtype=np.bool_) 252 | # where the img_terra_pad is [255, 255, 255] set to True across the three channels 253 | cumulative_mask[img_terra_pad == 255] = True 254 | occ, cumulative_mask = add_obstacles( 255 | img_terra_pad, 256 | cumulative_mask, 257 | n_obs_min, 258 | n_obs_max, 259 | size_obstacle_min, 260 | size_obstacle_max, 261 | ) 262 | 263 | dmp, cumulative_mask = add_non_dumpables( 264 | img_terra_pad, 265 | occ, 266 | cumulative_mask, 267 | n_nodump_min, 268 | n_nodump_max, 269 | size_nodump_min, 270 | size_nodump_max, 271 | ) 272 | save_or_display_image(img_terra_pad, occ, dmp, metadata, curriculum_level, n) 273 | 274 | print("Foundations created successfully.") 275 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌍 Terra - Earthwork planning environment in JAX 2 | 3 | ![img](assets/overview.gif) 4 | 5 | Terra is a flexible and abstracted grid world environment for training intelligent agents in the context of earthworks planning. It makes it possible to approach motion and excavation high-level planning as a reinforcement learning problem, providing a multi-GPU JAX-accelerated environment. We show that we can train an agent capable of planning earthworks in trenches and foundations environments in less than 1 minute on 8 Nvidia RTX-4090 GPUs. 
6 | 7 | ## Features 8 | - 🚜 Two Agent Types: Wheeled and tracked excavator embodiments for different types of actions 9 | - 🏞️ Realistic Maps: Up-to-3-axes trenches and building foundations with obstacles and dumping constraints 10 | - 🔥 Performance: Easily scale to more than 1M steps per second on a single GPU 11 | - 🚀 Scaling: Out of the box multi-device training 12 | - 📖 Curriculum: Customizable RL curriculum via config interface 13 | - 🔧 Tooling: Visualization, evaluation, manual play, and maps inspection scripts 14 | - 🏌 Baselines: We provide baseline results and PPO-based training scripts inspired by [purejaxrl](https://github.com/luchris429/purejaxrl) and [xland-minigrid](https://github.com/corl-team/xland-minigrid) 15 | 16 | ## Installation 17 | Clone the repo, and if you want to use Terra in your project, use 18 | ~~~ 19 | pip install -e . 20 | ~~~ 21 | You can check out [terra-baselines](https://github.com/leggedrobotics/rl-excavation-planning) for an installation workflow example. 22 | 23 | ### JAX 24 | The JAX installation is hardware-dependent and therefore needs to be done separately. Follow [this link](https://jax.readthedocs.io/en/latest/installation.html) to install the right one for you. 25 | 26 | ## Usage 27 | The standard workflow is made of the following steps: 28 | 1. Generate the maps by following this [README](https://github.com/leggedrobotics/terra/blob/main/terra/env_generation/README.md) (you can check out a preview of the generated maps in the `data/` folder) 29 | 2. Set up the curriculum in `config.py` 30 | 3. Build your own training script or use the ready-to-use script from our [terra-baselines](https://github.com/leggedrobotics/rl-excavation-planning). 31 | 4. Train 🚀 32 | 5. Run [evaluations](https://github.com/leggedrobotics/rl-excavation-planning/blob/master/eval.py) and [visualization](https://github.com/leggedrobotics/rl-excavation-planning/blob/master/visualize.py). 33 | 34 | # Environment Setup Instructions 35 | 36 | This repository includes configuration files to help you set up the required environment for this project. 37 | 38 | ## Using Conda Environment (Recommended) 39 | 40 | The `environment.yml` file contains all the necessary dependencies to reproduce the project environment named "terra" with Python 3.12.2. 41 | 42 | ### Creating the Environment 43 | 44 | To create a new conda environment from the provided `environment.yml` file, run: 45 | 46 | ```bash 47 | conda env create -f environment.yml 48 | ``` 49 | 50 | This command will create a new environment named "terra" with all the specified dependencies. 51 | 52 | ### Activating the Environment 53 | 54 | After creating the environment, activate it with: 55 | 56 | ```bash 57 | conda activate terra 58 | ``` 59 | 60 | ### Installing Terra and JAX 61 | 62 | After activating the environment, it is necessary to install Terra 63 | 64 | ```bash 65 | pip install -e . 66 | ``` 67 | 68 | and [JAX](https://docs.jax.dev/en/latest/installation.html). 69 | 70 | At the moment you should use jaxlib and jax version <= 0.4.26, as `jax.tree_map` is deprecated in newer versions: 71 | 72 | ```bash 73 | pip install -U "jax[cuda12]==0.4.26" jaxlib==0.4.26+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 74 | ``` 75 | 76 | If your setup does not require the pinned version, you can instead install the latest CUDA 12 build (NVIDIA GPU): 77 | 78 | ```bash 79 | pip install -U "jax[cuda12]" 80 | ``` 81 | 82 | ### Verifying the Environment 83 | 84 | To verify that the environment was set up correctly, you can: 85 | 86 | 1. 
Check that you're in the correct environment: 87 | ```bash 88 | conda info 89 | ``` 90 | The active environment should be displayed as "terra". 91 | 92 | 2. List all installed packages: 93 | ```bash 94 | conda list 95 | ``` 96 | 97 | 3. Check Python version: 98 | ```bash 99 | python --version 100 | ``` 101 | This should output "Python 3.12.2". 102 | 103 | ## Environment Files: requirements.txt vs environment.yml 104 | 105 | ### environment.yml 106 | 107 | The `environment.yml` file is used by conda to create environments and has several advantages: 108 | - Specifies the Python version 109 | - Can include conda and pip dependencies in one file 110 | - Can include dependencies from multiple channels 111 | - Resolves dependencies more effectively than pip alone 112 | - Handles non-Python dependencies (e.g., C libraries) 113 | 114 | ### requirements.txt 115 | 116 | The `requirements.txt` file is a standard pip requirements file that specifies Python packages to be installed with pip. This file: 117 | - Lists Python dependencies only 118 | - Can be used with `pip install -r requirements.txt` 119 | - Doesn't specify Python version 120 | - Doesn't handle non-Python dependencies 121 | 122 | For this project, the `environment.yml` file is the recommended way to set up your environment as it ensures all dependencies (including the correct Python version) are properly installed. 123 | 124 | ## Updating the Environment 125 | 126 | If you need to update the environment after changes to the `environment.yml` file: 127 | 128 | ```bash 129 | conda env update -f environment.yml --prune 130 | ``` 131 | 132 | The `--prune` option removes dependencies that are no longer specified in the updated file. 133 | 134 | 135 | ### Basic Usage 136 | ``` python 137 | import jax 138 | from terra.env import TerraEnvBatch 139 | from terra.config import EnvConfig 140 | 141 | key = jax.random.PRNGKey(0) 142 | key, reset_key, step_key = jax.random.split(key, 3) 143 | 144 | # create Terra and configs 145 | env = TerraEnvBatch() 146 | env_params = EnvConfig() 147 | 148 | # jitted reset and step functions 149 | timestep = env.reset(env_params, reset_key) 150 | timestep = env.step(timestep, action, step_key)  # `action` is produced by your agent/policy (see terra-baselines) 151 | ``` 152 | 153 | ### Map generation 154 | Running the standard map generation will produce the following folder structure. This includes foundations and trenches, and additional curriculum maps that help with the training in case the sparse reward strategy is used. 155 | ``` 156 | - data 157 | - custom <-- custom-made maps for training special behaviour 158 | - openstreet 159 | - terra 160 | - foundations 161 | - trenches 162 | - train <-- npy maps formatted for Terra 163 | - trenches <-- trenches with 1 to 3 intersecting axes 164 | - foundations <-- building foundations from OpenStreetMap 165 | - dumpability <-- encodes where the agent can dump 166 | - images <-- encodes the target dig profile 167 | - occupancy <-- encodes the obstacles 168 | - foundations_large <-- same as foundations but bigger 169 | - custom 170 | ``` 171 | 172 | ### Training Configurations 173 | In Terra the settings are expressed as curriculum levels. To set the levels of the training, you can edit the `config.py` file. For example, if you want to start your training with dense rewards on foundations and then shift to shorter episodes with sparse rewards, you can set the curriculum as follows. 
174 | ``` python 175 | class CurriculumGlobalConfig(NamedTuple): 176 | increase_level_threshold: int = 3 177 | decrease_level_threshold: int = 10 178 | 179 | levels = [ 180 | { 181 | "maps_path": "terra/foundations", 182 | "max_steps_in_episode": 300, 183 | "rewards_type": RewardsType.DENSE, 184 | "apply_trench_rewards": False, 185 | }, 186 | { 187 | "maps_path": "terra/trenches/single", 188 | "max_steps_in_episode": 200, 189 | "rewards_type": RewardsType.SPARSE, 190 | "apply_trench_rewards": True, 191 | } 192 | ] 193 | ``` 194 | Note that `apply_trench_rewards` should be `True` if you are training on trenches. This enables an additional reward that penalizes the distance of the agent from any trench axis at dig time, pushing the agent to be aligned to them. 195 | 196 | To select the embodiment to use, set the following to either `TrackedAction` or `WheeledAction`. Check out `state.py` for the documentation of the embodiment-specific state transitions. 197 | ``` python 198 | class BatchConfig(NamedTuple): 199 | action_type: Action = TrackedAction 200 | ``` 201 | 202 | ## Tools 🔧 203 | We provide debugging tools to explore Terra maps and play with the different agents. 204 | 205 | You can play on a single environment using your keyboard with 206 | ``` python 207 | DATASET_PATH=/path/to/dataset DATASET_SIZE= python -m terra.viz.main_manual 208 | ``` 209 | and you can inspect the generated maps with 210 | ``` python 211 | DATASET_PATH=/path/to/dataset DATASET_SIZE= python -m terra.viz.play 212 | ``` 213 | note that these scripts assume that the maps are stored in the `data/` folder. 214 | 215 | ## Rules 🔮 216 | In Terra the agent can move around, dig, and dump terrain. The goal of the agent is to dig the target shape given at the beginning of the episode, planning around obstacles and dumping constraints. The target map defines all the tiles that can be dug, and the action map stores the progress. Tiles are dug in batches, where a batch is defined by the conical section representing the full reach of the excavator arm for a given base and cabin pose. Therefore, with a `DO` action, the agent digs all the tiles in the target map that are within reach, and subsequently with another `DO` action it distributes the dirt evenly on the dumpable tiles within reach. 217 | 218 | ### Agent Types 219 | Two types of excavators are abstracted in Terra: tracked and wheeled. The difference is that the tracked excavator is able to turn the base on the spot whereas the wheeled turns by doing an L-shaped movement (e.g. forward-turn-forward). 220 | 221 | ### Map Types 222 | Terra comes with two types of maps: foundations and trenches. Foundations are produced by projecting real building profiles from OpenStreetMap on the grid map. Trenches are procedurally generated and are divided in three categories based on the number of axes the trench has (1 to 3). All the maps have additional layers to encode obstacles, regions where the excavator can't dump terrain (e.g. roads), and regions where the excavator needs to dump all the terrain to terminate the episode (terminal dumping areas). Check out `map.py` for the documentation of the map layering and logic. 223 | 224 | ## Observation Space 🔍 225 | The agent in Terra perceives the environment through a rich observation space that provides comprehensive information about the state of the world and the agent itself. 
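Before going through the individual fields, the snippet below is a minimal sketch of how one might inspect an observation at runtime. It assumes the observation is exposed as a (possibly nested) dictionary of JAX arrays, for example on the timestep returned by `env.reset`/`env.step` in the Basic Usage example above; the attribute name and the shapes used in `dummy_obs` are illustrative assumptions, not part of the documented API.
``` python
import jax
import jax.numpy as jnp

def describe_observation(obs) -> None:
    """Print name, shape, and dtype of every array in a (nested) observation dict."""
    for path, leaf in jax.tree_util.tree_leaves_with_path(obs):
        leaf = jnp.asarray(leaf)
        print(f"{jax.tree_util.keystr(path)}: shape={leaf.shape}, dtype={leaf.dtype}")

# Dummy observation mimicking two of the fields documented below (shapes are illustrative only).
dummy_obs = {
    "agent_states": jnp.zeros((6,), dtype=jnp.int32),
    "action_map": jnp.zeros((64, 64), dtype=jnp.int8),
}
describe_observation(dummy_obs)
```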
The observation is a dictionary with the following components: 226 | 227 | - **agent_states**: A 6-dimensional vector containing: 228 | - Position of the base (x, y coordinates) 229 | - Base rotation angle 230 | - Cabin rotation angle 231 | - Whether the agent is loaded with dirt (0 or 1) 232 | - Wheel angle (only used by the wheeled excavator) 233 | 234 | - **Local Maps**: The agent has access to local maps representing different aspects of the environment from the agent's perspective: 235 | - **local_map_action_neg/pos**: Current state of the terrain (negative/positive height) within the agent's reach 236 | - **local_map_target_neg/pos**: Target digging profile (negative/positive height) within reach 237 | - **local_map_dumpability**: Areas where the agent can dump soil within reach 238 | - **local_map_obstacles**: Obstacles within the agent's reach 239 | 240 | - **Global Maps**: Full maps of the environment: 241 | - **action_map**: Current state of the terrain across the entire map 242 | - **target_map**: Target digging profile across the entire map 243 | - **traversability_mask**: Areas where the agent can navigate 244 | - **dumpability_mask**: Areas where the agent can dump soil 245 | - **padding_mask**: Areas with obstacles 246 | 247 | - **Agent Dimensions**: 248 | - **agent_width**: Width of the agent 249 | - **agent_height**: Height of the agent 250 | 251 | This rich observation space allows the agent to understand both its immediate surroundings and the global state of the environment, enabling effective planning for navigation, digging, and dumping operations. 252 | 253 | ## Performance 🔥 254 | We benchmark the environments by measuring the runtime of our PPO algorithm, including environment steps and model updates, on Nvidia RTX-4090 GPUs. For all experiments we keep 32 steps, 32 minibatches, and 1 update epoch constant. 255 | 256 | Scaling on single device | Scaling across devices 257 | :-------------------------:|:-------------------------: 258 | ![](assets/scaling-envs.png) | ![](assets/scaling-devices.png) 259 | 260 | ## Baselines 🎮 261 | We release a set of baselines and checkpoints in [terra-baselines](https://github.com/leggedrobotics/rl-excavation-planning). 262 | 263 | ## Citation 264 | If you use this code in your research or project, please cite the following: 265 | ``` 266 | @misc{terra, 267 | title={Terra - Earthwork planning environment in JAX}, 268 | author={Antonio Arbues and Lorenzo Terenzi}, 269 | howpublished={\url{https://github.com/leggedrobotics/terra}}, 270 | year={2024} 271 | } 272 | ``` 273 | -------------------------------------------------------------------------------- /terra/env_generation/convert_to_terra.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import skimage 10 | from scipy.signal import convolve2d 11 | from skimage import measure 12 | from tqdm import tqdm 13 | import time 14 | 15 | import terra.env_generation.utils as utils 16 | from terra.env_generation.utils import _get_img_mask, color_dict 17 | 18 | 19 | def _convert_img_to_terra(img, all_dumpable=False): 20 | """ 21 | Converts an image from color_dict RGB convention 22 | to [-1, 0, 1] Terra convention. 
23 | """ 24 | img = img.astype(np.int16) 25 | img = np.where(img == np.array(color_dict["digging"]), -1, img) 26 | img = np.where(img == np.array(color_dict["dumping"]), 1, img) 27 | if all_dumpable: 28 | img = np.where(img == np.array(color_dict["neutral"]), 1, img) 29 | else: 30 | img = np.where(img == np.array(color_dict["neutral"]), 0, img) 31 | img = np.where((img != -1) & (img != 1), 0, img) 32 | img = img[..., 0] # take only 1 channel 33 | return img.astype(np.int8) 34 | 35 | 36 | def _convert_occupancy_to_terra(img): 37 | img = img.astype(np.int16) 38 | mask = _get_img_mask(img, np.array(color_dict["obstacle"])) 39 | img = np.where(mask, 1, 0) 40 | return img.astype(np.bool_) 41 | 42 | 43 | def _convert_dumpability_to_terra(img): 44 | img = img.astype(np.int16) 45 | mask = _get_img_mask(img, np.array(color_dict["nondumpable"])) 46 | img = np.where(mask, 0, 1) 47 | return img.astype(np.bool_) 48 | 49 | 50 | def _convert_actions_to_terra(img): 51 | img = img.astype(np.int16) 52 | mask = _get_img_mask(img, np.array(color_dict["dirt"])) 53 | img = np.where(mask, 1, 0) 54 | return img.astype(np.int8) 55 | 56 | 57 | def _convert_all_imgs_to_terra( 58 | img_folder, 59 | metadata_folder, 60 | occupancy_folder, 61 | dumpability_folder, 62 | destination_folder, 63 | size, 64 | n_imgs, 65 | expansion_factor=1, 66 | all_dumpable=False, 67 | copy_metadata=True, 68 | downsample=True, 69 | has_dumpability=True, 70 | center_padding=False, 71 | actions_folder=None, 72 | ): 73 | max_size = size[1] 74 | print("max size: ", max_size) 75 | # try: 76 | 77 | filename_start = sorted(os.listdir(img_folder))[0].split("_")[0] 78 | 79 | for i, fn in tqdm(enumerate(os.listdir(img_folder))): 80 | if i >= n_imgs: 81 | break 82 | 83 | n = int(fn.split(".png")[0].split("_")[1]) 84 | filename = filename_start + f"_{n}.png" 85 | file_path = img_folder / filename 86 | 87 | occupancy_path = occupancy_folder / filename 88 | img = cv2.imread(str(file_path)) 89 | occupancy = cv2.imread(str(occupancy_path)) 90 | 91 | if has_dumpability: 92 | dumpability_path = dumpability_folder / filename 93 | dumpability = cv2.imread(str(dumpability_path)) 94 | 95 | if downsample: 96 | with open( 97 | metadata_folder / f"{filename.split('.png')[0]}.json" 98 | ) as json_file: 99 | metadata = json.load(json_file) 100 | 101 | # Calculate downsample factors based on max_size 102 | downsample_factor_w = max(1, math.ceil(img.shape[1] / max_size)) 103 | downsample_factor_h = max(1, math.ceil(img.shape[0] / max_size)) 104 | 105 | img_downsampled = skimage.measure.block_reduce( 106 | img, (downsample_factor_h, downsample_factor_w, 1), np.max 107 | ) 108 | img = img_downsampled 109 | occupancy_downsampled = skimage.measure.block_reduce( 110 | occupancy, (downsample_factor_h, downsample_factor_w, 1), np.min, cval=0 111 | ) 112 | occupancy = occupancy_downsampled 113 | if has_dumpability: 114 | dumpability_downsampled = skimage.measure.block_reduce( 115 | dumpability, 116 | (downsample_factor_h, downsample_factor_w, 1), 117 | np.min, 118 | cval=0, 119 | ) 120 | dumpability = dumpability_downsampled 121 | 122 | img_terra = _convert_img_to_terra(img, all_dumpable) 123 | # Pad to max size 124 | if center_padding: 125 | xdim = max_size - img_terra.shape[0] 126 | ydim = max_size - img_terra.shape[1] 127 | # Note: applying full dumping tiles for the centered version 128 | img_terra_pad = np.ones((max_size, max_size), dtype=img_terra.dtype) 129 | print( 130 | "xdim:", 131 | xdim, 132 | "max_size:", 133 | max_size, 134 | "ydim:", 135 | ydim, 136 | 
"img_terra shape:", 137 | img_terra.shape, 138 | ) 139 | img_terra_pad[ 140 | xdim // 2 : max_size - (xdim - xdim // 2), 141 | ydim // 2 : max_size - (ydim - ydim // 2), 142 | ] = img_terra 143 | # Note: applying no occupancy for the centered version (mismatch with Terra env) 144 | img_terra_occupancy = np.zeros((max_size, max_size), dtype=np.bool_) 145 | img_terra_occupancy[ 146 | xdim // 2 : max_size - (xdim - xdim // 2), 147 | ydim // 2 : max_size - (ydim - ydim // 2), 148 | ] = _convert_occupancy_to_terra(occupancy) 149 | if has_dumpability: 150 | img_terra_dumpability = np.zeros((max_size, max_size), dtype=np.bool_) 151 | img_terra_dumpability[ 152 | xdim // 2 : max_size - (xdim - xdim // 2), 153 | ydim // 2 : max_size - (ydim - ydim // 2), 154 | ] = _convert_dumpability_to_terra(dumpability) 155 | else: 156 | img_terra_pad = np.zeros((max_size, max_size), dtype=img_terra.dtype) 157 | img_terra_pad[: img_terra.shape[0], : img_terra.shape[1]] = img_terra 158 | img_terra_occupancy = np.ones((max_size, max_size), dtype=np.bool_) 159 | img_terra_occupancy = _convert_occupancy_to_terra(occupancy) 160 | if has_dumpability: 161 | img_terra_dumpability = np.zeros((max_size, max_size), dtype=np.bool_) 162 | img_terra_dumpability = _convert_dumpability_to_terra(dumpability) 163 | 164 | destination_folder_images = destination_folder / "images" 165 | destination_folder_images.mkdir(parents=True, exist_ok=True) 166 | destination_folder_occupancy = destination_folder / "occupancy" 167 | destination_folder_occupancy.mkdir(parents=True, exist_ok=True) 168 | destination_folder_dumpability = destination_folder / "dumpability" 169 | destination_folder_dumpability.mkdir(parents=True, exist_ok=True) 170 | if copy_metadata: 171 | destination_folder_metadata = destination_folder / "metadata" 172 | destination_folder_metadata.mkdir(parents=True, exist_ok=True) 173 | 174 | img_terra_pad = img_terra_pad.repeat(expansion_factor, axis=0).repeat( 175 | expansion_factor, axis=1 176 | ) 177 | img_terra_occupancy = img_terra_occupancy.repeat( 178 | expansion_factor, axis=0 179 | ).repeat(expansion_factor, axis=1) 180 | if has_dumpability: 181 | img_terra_dumpability = img_terra_dumpability.repeat( 182 | expansion_factor, 0 183 | ).repeat(expansion_factor, 1) 184 | 185 | np.save(destination_folder_images / f"img_{i + 1}", img_terra_pad) 186 | np.save(destination_folder_occupancy / f"img_{i + 1}", img_terra_occupancy) 187 | if has_dumpability: 188 | np.save( 189 | destination_folder_dumpability / f"img_{i + 1}", img_terra_dumpability 190 | ) 191 | else: 192 | np.save( 193 | destination_folder_dumpability / f"img_{i + 1}", 194 | np.ones_like(img_terra_pad), 195 | ) 196 | if actions_folder is not None: 197 | actions_path = actions_folder / filename 198 | actions = cv2.imread(str(actions_path)) 199 | actions_terra = _convert_actions_to_terra(actions) 200 | actions_terra = actions_terra.repeat(expansion_factor, axis=0).repeat(expansion_factor, axis=1) 201 | destination_folder_actions = destination_folder / "actions" 202 | destination_folder_actions.mkdir(parents=True, exist_ok=True) 203 | np.save(destination_folder_actions / f"img_{i + 1}", actions_terra) 204 | if copy_metadata: 205 | utils.copy_and_increment_filenames(str(metadata_folder), str(destination_folder_metadata)) 206 | 207 | 208 | def generate_foundations_terra(dataset_folder, size, n_imgs, all_dumpable): 209 | print("Converting foundations...") 210 | foundations_levels = ["foundations", "foundations_large"] 211 | for level in foundations_levels: 212 | 
img_folder = Path(dataset_folder) / level / "images" 213 | metadata_folder = Path(dataset_folder) / level / "metadata" 214 | occupancy_folder = Path(dataset_folder) / level/ "occupancy" 215 | dumpability_folder = Path(dataset_folder) / level / "dumpability" 216 | destination_folder = Path(dataset_folder) / "train" / level 217 | destination_folder.mkdir(parents=True, exist_ok=True) 218 | _convert_all_imgs_to_terra( 219 | img_folder, 220 | metadata_folder, 221 | occupancy_folder, 222 | dumpability_folder, 223 | destination_folder, 224 | size, 225 | n_imgs, 226 | all_dumpable=all_dumpable, 227 | copy_metadata=True, 228 | downsample=False, 229 | has_dumpability=True, 230 | center_padding=False, 231 | actions_folder=None, 232 | ) 233 | 234 | 235 | def generate_trenches_terra(dataset_folder, size, n_imgs, expansion_factor, all_dumpable): 236 | print("Converting trenches...") 237 | trenches_name = "trenches" 238 | trenches_path = Path(dataset_folder) / trenches_name 239 | levels = [d.name for d in trenches_path.iterdir() if d.is_dir()] 240 | for level in levels: 241 | img_folder = trenches_path / level / "images" 242 | metadata_folder = trenches_path / level / "metadata" 243 | occupancy_folder = trenches_path / level / "occupancy" 244 | dumpability_folder = trenches_path / level / "dumpability" 245 | destination_folder = Path(dataset_folder) / "train" / trenches_name / level 246 | destination_folder.mkdir(parents=True, exist_ok=True) 247 | _convert_all_imgs_to_terra( 248 | img_folder, 249 | metadata_folder, 250 | occupancy_folder, 251 | dumpability_folder, 252 | destination_folder, 253 | size, 254 | n_imgs, 255 | expansion_factor=expansion_factor, 256 | all_dumpable=all_dumpable, 257 | actions_folder=None, 258 | ) 259 | 260 | def generate_relocations_terra(dataset_folder, size, n_imgs): 261 | print("Converting relocations...") 262 | img_folder = Path(dataset_folder) / "relocations" / "images" 263 | metadata_folder = Path(dataset_folder) / "relocations" / "metadata" 264 | occupancy_folder = Path(dataset_folder) / "relocations"/ "occupancy" 265 | dumpability_folder = Path(dataset_folder) / "relocations" / "dumpability" 266 | actions_folder = Path(dataset_folder) / "relocations" / "actions" 267 | destination_folder = Path(dataset_folder) / "train" / "relocations" 268 | destination_folder.mkdir(parents=True, exist_ok=True) 269 | _convert_all_imgs_to_terra( 270 | img_folder, 271 | metadata_folder, 272 | occupancy_folder, 273 | dumpability_folder, 274 | destination_folder, 275 | size, 276 | n_imgs, 277 | all_dumpable=False, 278 | copy_metadata=False, 279 | downsample=False, 280 | has_dumpability=True, 281 | center_padding=False, 282 | actions_folder=actions_folder 283 | ) 284 | 285 | def generate_custom_terra(dataset_folder, size, n_imgs): 286 | print("Converting custom maps...") 287 | img_folder = Path(dataset_folder) / ".." / "custom" / "images" 288 | metadata_folder = Path(dataset_folder) / ".." / "custom" / "metadata" 289 | occupancy_folder = Path(dataset_folder) / ".." / "custom"/ "occupancy" 290 | dumpability_folder = Path(dataset_folder) / ".." 
/ "custom" / "dumpability" 291 | destination_folder = Path(dataset_folder) / "train" / "custom" 292 | destination_folder.mkdir(parents=True, exist_ok=True) 293 | _convert_all_imgs_to_terra( 294 | img_folder, 295 | metadata_folder, 296 | occupancy_folder, 297 | dumpability_folder, 298 | destination_folder, 299 | size, 300 | n_imgs, 301 | all_dumpable=False, 302 | copy_metadata=False, 303 | downsample=False, 304 | has_dumpability=True, 305 | center_padding=False, 306 | actions_folder=None, 307 | ) 308 | 309 | 310 | def generate_dataset_terra_format(dataset_folder, size, n_imgs=1000): 311 | print("dataset folder: ", dataset_folder) 312 | generate_foundations_terra(dataset_folder, size, n_imgs, all_dumpable=False) 313 | print("Foundations processed successfully.") 314 | generate_trenches_terra( 315 | dataset_folder, size, n_imgs, expansion_factor=1, all_dumpable=False 316 | ) 317 | print("Trenches processed successfully.") 318 | generate_relocations_terra(dataset_folder, size, n_imgs) 319 | print("Custom maps processed successfully.") 320 | generate_custom_terra(dataset_folder, size, n_imgs) 321 | print("Custom maps processed successfully.") 322 | -------------------------------------------------------------------------------- /terra/env_generation/openstreet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | import osmnx as ox 4 | import cv2 5 | from PIL import Image 6 | from pyproj import CRS, Transformer 7 | import os 8 | from shapely.geometry import Polygon, Point 9 | import random 10 | from typing import Tuple 11 | import concurrent.futures 12 | import geopandas as gpd 13 | 14 | def get_building_shapes_from_OSM( 15 | north, south, east, west, option=1, save_folder="data/", max_buildings=None 16 | ): 17 | """ 18 | Extracts building shapes from OpenStreetMap given a bounding box of coordinates. 19 | 20 | Parameters: 21 | north (float): northern boundary of bounding box 22 | south (float): southern boundary of bounding box 23 | east (float): eastern boundary of bounding box 24 | west (float): western boundary of bounding box 25 | option (int, optional): Option for operation. 1 for saving a binary map, 2 for saving individual buildings. Defaults to 1. 26 | save_path (str, optional): File path to save output if option is 1. Defaults to 'output.png'. 27 | folder_path (str, optional): Folder path to save output if option is 2. Defaults to None. 28 | 29 | Returns: 30 | None 31 | """ 32 | # make sure the save folder exists 33 | if not os.path.exists(save_folder): 34 | os.makedirs(save_folder) 35 | 36 | # Fetch buildings from OSM\ 37 | bbox = (north, south, east, west) 38 | try: 39 | buildings = ox.features.features_from_bbox(*bbox, tags={"building": True}) 40 | except Exception as e: 41 | return 42 | print("got buildings") 43 | 44 | # Define coordinate reference systems 45 | wgs84 = CRS("EPSG:4326") # WGS84 (lat-long) coordinate system 46 | utm = CRS("EPSG:32633") # UTM zone 33N (covers central Europe) 47 | 48 | # Create transformer 49 | transformer = Transformer.from_crs(wgs84, utm, always_xy=True) 50 | 51 | # Check option 52 | if option == 1: 53 | extract_crop(buildings, wgs84, utm, north, south, east, west, save_folder, transformer) 54 | elif option == 2: 55 | extract_single_buildings_parallel(buildings, wgs84, utm, folder_path=save_folder, transformer=transformer, max_buildings=max_buildings) 56 | else: 57 | print("Invalid option selected. 
Choose either 1 or 2.") 58 | 59 | 60 | def process_single_building(i, building, wgs84, transformer, folder_path): 61 | """ 62 | Worker function to process one building. This creates a single-row GeoDataFrame, 63 | sets up the plot, saves the image, computes its dimensions, and writes metadata. 64 | """ 65 | if building.area < 1e-9: 66 | return 67 | 68 | # Create a GeoDataFrame for the single building. 69 | gdf = gpd.GeoDataFrame({'geometry': [building]}, crs=wgs84) 70 | 71 | # Calculate UTM bounds from the building's bounding box. 72 | b = building.bounds # (minx, miny, maxx, maxy) 73 | bounds_utm = [ 74 | transformer.transform(b[0], b[1]), 75 | transformer.transform(b[2], b[3]) 76 | ] 77 | width_utm = bounds_utm[1][0] - bounds_utm[0][0] 78 | height_utm = bounds_utm[1][1] - bounds_utm[0][1] 79 | if height_utm == 0 or width_utm == 0: 80 | return 81 | # Compute aspect ratio as in the original code. 82 | aspect_ratio = (width_utm / height_utm) ** (-1) 83 | dpi = 50 84 | max_pixels = 600 85 | max_inches = int(max_pixels / dpi) 86 | if aspect_ratio > 1: 87 | figsize = (max_inches, int(max_inches / aspect_ratio)) 88 | else: 89 | figsize = (int(max_inches * aspect_ratio), max_inches) 90 | if figsize[0] == 0: 91 | figsize = (1, figsize[1]) 92 | if figsize[1] == 0: 93 | figsize = (figsize[0], 1) 94 | 95 | # Create a matplotlib figure. 96 | fig, ax = plt.subplots(figsize=figsize) 97 | ax.axis("off") 98 | ax.set_xlim(bounds_utm[0][0], bounds_utm[1][0]) 99 | ax.set_ylim(bounds_utm[0][1], bounds_utm[1][1]) 100 | ax.invert_yaxis() 101 | ax.set_aspect("equal", adjustable="box") 102 | ox.plot_footprints(gdf, ax=ax, color="black", bgcolor="white", show=False) 103 | 104 | # Ensure the images folder exists. 105 | images_folder = os.path.join(folder_path, "images") 106 | os.makedirs(images_folder, exist_ok=True) 107 | image_path = os.path.join(images_folder, f"building_{i}.png") 108 | plt.savefig(image_path, dpi=dpi, pad_inches=1.0, bbox_inches="tight") 109 | plt.close(fig) 110 | 111 | # Open the saved image and compute its dimensions. 112 | try: 113 | with Image.open(image_path) as img: 114 | width_px, height_px = img.size 115 | except Exception as e: 116 | print(f"Error opening image for building {i}: {e}") 117 | return 118 | 119 | new_size = (width_px / dpi, height_px / dpi) 120 | print(f"Building {i}: old size {figsize} and new size {new_size}") 121 | 122 | width = width_utm 123 | height = height_utm 124 | new_height = new_size[0] / figsize[0] * height 125 | new_width = new_size[1] / figsize[1] * width 126 | print(f"Building {i}: old width {width} and new width {new_width}") 127 | print(f"Building {i}: old height {height} and new height {new_height}") 128 | 129 | metadata = { 130 | "building_index": i, 131 | "real_dimensions": {"width": new_height, "height": new_width}, 132 | } 133 | metadata_folder = os.path.join(folder_path, "metadata") 134 | os.makedirs(metadata_folder, exist_ok=True) 135 | metadata_path = os.path.join(metadata_folder, f"building_{i}.json") 136 | with open(metadata_path, "w") as f: 137 | json.dump(metadata, f) 138 | 139 | def extract_single_buildings_parallel(buildings, wgs84, utm, folder_path=None, transformer=None, max_buildings=None): 140 | """ 141 | Processes building extraction in parallel using multiple CPU cores. 142 | Each building is handled independently. 
143 | """ 144 | if folder_path is None: 145 | raise ValueError("No folder path provided") 146 | if transformer is None: 147 | raise ValueError("No transformer provided") 148 | 149 | tasks = [] 150 | for i, building in enumerate(buildings.geometry): 151 | if max_buildings is not None and i >= max_buildings: 152 | break 153 | if building.area < 1e-9: 154 | continue 155 | tasks.append((i, building, wgs84, transformer, folder_path)) 156 | 157 | with concurrent.futures.ProcessPoolExecutor() as executor: 158 | futures = [executor.submit(process_single_building, *args) for args in tasks] 159 | for future in concurrent.futures.as_completed(futures): 160 | try: 161 | future.result() 162 | except Exception as e: 163 | print("Error processing building:", e) 164 | 165 | 166 | def extract_crop(buildings, wgs84, utm, north, south, east, west, save_folder, transformer): 167 | # Convert the bounding box to UTM 168 | west_utm, south_utm = transformer.transform(west, south) 169 | east_utm, north_utm = transformer.transform(east, north) 170 | print("width crop in meters", east_utm - west_utm) 171 | print("height crop in meters", north_utm - south_utm) 172 | # Create a Polygon that represents the bounding box 173 | bbox_polygon = Polygon( 174 | [ 175 | (west_utm, south_utm), 176 | (east_utm, south_utm), 177 | (east_utm, north_utm), 178 | (west_utm, north_utm), 179 | ] 180 | ) 181 | 182 | # Count how many buildings have their centroid within the bounding box 183 | try: 184 | buildings_in_bbox = sum( 185 | bbox_polygon.contains( 186 | Point(transformer.transform(*building.centroid.coords[0])) 187 | ) 188 | for building in buildings.geometry 189 | ) 190 | except Exception as e: 191 | buildings_in_bbox = 0 192 | 193 | # Check the number of buildings. Return if less than 2. 
194 | if buildings_in_bbox < 2: 195 | print("Less than 2 buildings found in the given bounding box") 196 | return 197 | 198 | try: 199 | # Convert total bounds to UTM 200 | total_bounds_utm = [ 201 | transformer.transform( 202 | buildings.total_bounds[i], buildings.total_bounds[i + 1] 203 | ) 204 | for i in range(0, 4, 2) 205 | ] 206 | aspect_ratio = ( 207 | (total_bounds_utm[1][0] - total_bounds_utm[0][0]) 208 | / (total_bounds_utm[1][1] - total_bounds_utm[0][1]) 209 | ) ** (-1) 210 | dpi = 50 211 | max_pixels = 600 212 | max_inches = int(max_pixels / dpi) 213 | if aspect_ratio > 1: 214 | figsize = (max_inches, int(max_inches / aspect_ratio)) 215 | else: 216 | figsize = (int(max_inches * aspect_ratio), max_inches) 217 | # if 0 set it to 1 218 | if figsize[0] == 0: 219 | figsize = (1, figsize[1]) 220 | if figsize[1] == 0: 221 | figsize = (figsize[0], 1) 222 | 223 | # Plot and save the binary map 224 | fig, ax = plt.subplots(figsize=figsize, dpi=20) 225 | ax.axis("off") 226 | print("Plotting binary map...") 227 | ox.plot_footprints(buildings, ax=ax, color="black", bgcolor="white", show=False) 228 | print("Done") 229 | # save path include the coordinates of the bounding box 230 | save_path = f"{save_folder}/output_{north}_{south}_{east}_{west}" 231 | plt.savefig( 232 | save_path + ".png", 233 | dpi=300, 234 | bbox_inches="tight", 235 | pad_inches=0, 236 | transparent=True, 237 | ) 238 | print("Saved binary map in {}".format(save_path)) 239 | plt.close() 240 | except ValueError: 241 | print("No buildings found in the given bounding box") 242 | return 243 | 244 | # Post-process the image with OpenCV to count the number of distinct buildings 245 | image = cv2.imread(f"{save_path}.png", cv2.IMREAD_GRAYSCALE) 246 | ret, thresh = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY) 247 | num_labels, labels = cv2.connectedComponents(thresh) 248 | 249 | # Check the number of distinct buildings (subtract 1 because the background is also considered a label) 250 | num_buildings = num_labels - 1 251 | if num_buildings < 2: 252 | print("Less than 2 distinct buildings found in the image") 253 | os.remove( 254 | f"{save_path}.png" 255 | ) # Delete the image if it has less than 2 buildings 256 | return 257 | 258 | # Open the image file with PIL and get its size in pixels 259 | with Image.open(f"{save_path}.png") as img: 260 | width_px, height_px = img.size 261 | 262 | # Convert the size from pixels to inches 263 | new_size = (width_px / dpi, height_px / dpi) 264 | 265 | print("old size {} and new size {}".format(figsize, new_size)) 266 | 267 | width = total_bounds_utm[1][0] - total_bounds_utm[0][0] 268 | height = total_bounds_utm[1][1] - total_bounds_utm[0][1] 269 | new_height = new_size[0] / figsize[0] * height 270 | new_width = new_size[1] / figsize[1] * width 271 | print("old width {} and new width {}".format(width, new_width)) 272 | print("old height {} and new height {}".format(height, new_height)) 273 | 274 | # Save metadata file with real dimensions in meters 275 | metadata = { 276 | "coordinates": {"north": north, "south": south, "east": east, "west": west}, 277 | "real_dimensions": {"height": new_width, "width": new_height}, 278 | } 279 | with open(f"{save_path}.json", "w") as file: 280 | json.dump(metadata, file) 281 | 282 | 283 | def collect_random_crops( 284 | outer_bbox_wgs84: Tuple, crop_size_utm: Tuple[float, float], save_folder: str 285 | ): 286 | """ 287 | Creates a random crop in wgs84 coordinates and using the outer_bbox_wgs84 as a boundary and the crops_size_utm in meters 288 | as the size 
of the crop. 289 | 290 | Args: 291 | outer_bbox_wgs84: (north, south, east, west) 292 | crop_size_utm: (width, height) 293 | save_folder: str 294 | """ 295 | north, south, east, west = outer_bbox_wgs84 296 | width_crop, height_crop = crop_size_utm 297 | 298 | # Define coordinate reference systems 299 | utm = CRS("EPSG:32633") # UTM zone 33N (covers central Europe) 300 | wgs84 = CRS("EPSG:4326") # WGS84 (covers the entire globe) 301 | 302 | # Create transformers 303 | transformer_to_utm = Transformer.from_crs(wgs84, utm, always_xy=True) 304 | transformer_from_utm = Transformer.from_crs(utm, wgs84, always_xy=True) 305 | 306 | west_utm, south_utm = transformer_to_utm.transform(west, south) 307 | east_utm, north_utm = transformer_to_utm.transform(east, north) 308 | 309 | # Get width and height of the outer box in UTM coordinates 310 | width_utm = east_utm - west_utm 311 | height_utm = north_utm - south_utm 312 | 313 | # Choose a random point within the UTM coordinate box as the lower left corner of the new box 314 | west_crop = west_utm + random.uniform(0, width_utm - width_crop) 315 | south_crop = south_utm + random.uniform(0, height_utm - height_crop) 316 | 317 | # Create a new box with the given size around this point 318 | east_crop = west_crop + width_crop 319 | north_crop = south_crop + height_crop 320 | 321 | # Convert the UTM coordinates of this box back to WGS84 coordinates 322 | west_wgs84, south_wgs84 = transformer_from_utm.transform(west_crop, south_crop) 323 | east_wgs84, north_wgs84 = transformer_from_utm.transform(east_crop, north_crop) 324 | 325 | crop_wgs84 = (north_wgs84, south_wgs84, east_wgs84, west_wgs84) 326 | 327 | if not os.path.exists(save_folder): 328 | os.makedirs(save_folder) 329 | 330 | # Call the get_building_shapes_from_OSM function with these new coordinates and the given save folder 331 | get_building_shapes_from_OSM(*crop_wgs84, option=1, save_folder=save_folder) 332 | -------------------------------------------------------------------------------- /terra/env.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from functools import partial 3 | from typing import NamedTuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import Array 8 | 9 | from terra.actions import Action 10 | from terra.config import BatchConfig 11 | from terra.config import EnvConfig 12 | from terra.maps_buffer import init_maps_buffer 13 | from terra.state import State 14 | from terra.wrappers import LocalMapWrapper 15 | from terra.wrappers import TraversabilityMaskWrapper 16 | from terra.curriculum import CurriculumManager 17 | import pygame as pg 18 | from terra.viz.game.game import Game 19 | from terra.viz.game.settings import MAP_TILES 20 | 21 | 22 | class TimeStep(NamedTuple): 23 | state: State 24 | observation: dict[str, jax.Array] 25 | reward: jax.Array 26 | done: jax.Array 27 | info: dict 28 | env_cfg: EnvConfig 29 | 30 | 31 | class TerraEnv(NamedTuple): 32 | rendering_engine: Game | None = None 33 | 34 | @classmethod 35 | def new( 36 | cls, 37 | maps_size_px: int, 38 | rendering: bool = False, 39 | n_envs_x: int = 1, 40 | n_envs_y: int = 1, 41 | display: bool = False, 42 | ) -> "TerraEnv": 43 | re = None 44 | tile_size_rendering = MAP_TILES // maps_size_px 45 | if rendering: 46 | pg.init() 47 | pg.mixer.init() 48 | display_dims = ( 49 | n_envs_y * (maps_size_px + 4) * tile_size_rendering 50 | + 4 * tile_size_rendering, 51 | n_envs_x * (maps_size_px + 4) * tile_size_rendering 52 | + 4 * 
tile_size_rendering, 53 | ) 54 | if not display: 55 | print("TerraEnv: disabling display...") 56 | screen = pg.display.set_mode( 57 | display_dims, pg.FULLSCREEN | pg.HIDDEN 58 | ) 59 | else: 60 | screen = pg.display.set_mode(display_dims) 61 | surface = pg.Surface(display_dims, pg.SRCALPHA) 62 | clock = pg.time.Clock() 63 | re = Game( 64 | screen, 65 | surface, 66 | clock, 67 | maps_size_px=maps_size_px, 68 | n_envs_x=n_envs_x, 69 | n_envs_y=n_envs_y, 70 | display=display, 71 | ) 72 | return TerraEnv(rendering_engine=re) 73 | 74 | @partial(jax.jit, static_argnums=(0,)) 75 | def reset( 76 | self, 77 | key: jax.random.PRNGKey, 78 | target_map: Array, 79 | padding_mask: Array, 80 | trench_axes: Array, 81 | trench_type: Array, 82 | dumpability_mask_init: Array, 83 | action_map: Array, 84 | env_cfg: EnvConfig, 85 | ) -> tuple[State, dict[str, Array]]: 86 | """ 87 | Resets the environment using values from config files, and a seed. 88 | """ 89 | state = State.new( 90 | key, 91 | env_cfg, 92 | target_map, 93 | padding_mask, 94 | trench_axes, 95 | trench_type, 96 | dumpability_mask_init, 97 | action_map, 98 | ) 99 | state = self.wrap_state(state) 100 | 101 | observations = self._state_to_obs_dict(state) 102 | dummy_action = BatchConfig().action_type.do_nothing() 103 | 104 | return TimeStep( 105 | state=state, 106 | observation=observations, 107 | reward=jnp.zeros(()), 108 | done=jnp.zeros((), dtype=bool), 109 | info=state._get_infos(dummy_action, False), 110 | env_cfg=env_cfg, 111 | ) 112 | 113 | @staticmethod 114 | def wrap_state(state: State) -> State: 115 | state = TraversabilityMaskWrapper.wrap(state) 116 | state = LocalMapWrapper.wrap(state) 117 | return state 118 | 119 | @partial(jax.jit, static_argnums=(0,)) 120 | def _reset_existent( 121 | self, 122 | state: State, 123 | target_map: Array, 124 | padding_mask: Array, 125 | trench_axes: Array, 126 | trench_type: Array, 127 | dumpability_mask_init: Array, 128 | action_map: Array, 129 | env_cfg: EnvConfig, 130 | ) -> tuple[State, dict[str, Array]]: 131 | """ 132 | Resets the env, assuming that it already exists. 133 | """ 134 | state = state._reset( 135 | env_cfg, 136 | target_map, 137 | padding_mask, 138 | trench_axes, 139 | trench_type, 140 | dumpability_mask_init, 141 | action_map 142 | ) 143 | state = self.wrap_state(state) 144 | observations = self._state_to_obs_dict(state) 145 | return state, observations 146 | 147 | def render_obs_pygame( 148 | self, 149 | obs: dict[str, Array], 150 | info=None, 151 | generate_gif: bool = False, 152 | ) -> Array: 153 | """ 154 | Renders the environment at a given observation. 
155 | """ 156 | if info is not None: 157 | target_tiles = info["target_tiles"] 158 | else: 159 | target_tiles = None 160 | 161 | self.rendering_engine.run( 162 | active_grid=obs["action_map"], 163 | target_grid=obs["target_map"], 164 | padding_mask=obs["padding_mask"], 165 | dumpability_mask=obs["dumpability_mask"], 166 | agent_pos=obs["agent_state"][..., [0, 1]], 167 | base_dir=obs["agent_state"][..., [2]], 168 | cabin_dir=obs["agent_state"][..., [3]], 169 | loaded=obs["agent_state"][..., [5]], 170 | target_tiles=target_tiles, 171 | generate_gif=generate_gif, 172 | ) 173 | 174 | @partial(jax.jit, static_argnums=(0,)) 175 | def step( 176 | self, 177 | state: State, 178 | action: Action, 179 | target_map: Array, 180 | padding_mask: Array, 181 | trench_axes: Array, 182 | trench_type: Array, 183 | dumpability_mask_init: Array, 184 | action_map: Array, 185 | env_cfg: EnvConfig, 186 | ) -> TimeStep: 187 | new_state = state._step(action) 188 | reward = state._get_reward(new_state, action) 189 | new_state = self.wrap_state(new_state) 190 | obs = self._state_to_obs_dict(new_state) 191 | 192 | done, task_done = state._is_done( 193 | new_state.world.action_map.map, 194 | new_state.world.target_map.map, 195 | new_state.agent.agent_state.loaded, 196 | ) 197 | 198 | def _reset_branch(s, o, cfg): 199 | s_reset, o_reset = self._reset_existent( 200 | s, 201 | target_map, 202 | padding_mask, 203 | trench_axes, 204 | trench_type, 205 | dumpability_mask_init, 206 | action_map, 207 | cfg, 208 | ) 209 | return s_reset, o_reset, cfg 210 | 211 | def _nominal_branch(s, o, cfg): 212 | return s, o, cfg 213 | 214 | new_state, obs, env_cfg = jax.lax.cond( 215 | done, 216 | _reset_branch, 217 | _nominal_branch, 218 | new_state, 219 | obs, 220 | env_cfg, 221 | ) 222 | 223 | infos = new_state._get_infos(action, task_done) 224 | return TimeStep( 225 | state=new_state, 226 | observation=obs, 227 | reward=reward, 228 | done=done, 229 | info=infos, 230 | env_cfg=env_cfg, # now the right, possibly flipped `apply_trench_rewards` 231 | ) 232 | 233 | @staticmethod 234 | def _state_to_obs_dict(state: State) -> dict[str, Array]: 235 | """ 236 | Transforms a State object to an observation dictionary. 237 | """ 238 | agent_state = jnp.hstack( 239 | [ 240 | state.agent.agent_state.pos_base, # pos_base is encoded in traversability_mask 241 | state.agent.agent_state.angle_base, 242 | state.agent.agent_state.angle_cabin, 243 | state.agent.agent_state.wheel_angle, 244 | state.agent.agent_state.loaded, 245 | ] 246 | ) 247 | # Note: not all of those fields are used by the network for training! 248 | return { 249 | "agent_state": agent_state, 250 | "local_map_action_neg": state.world.local_map_action_neg.map, 251 | "local_map_action_pos": state.world.local_map_action_pos.map, 252 | "local_map_target_neg": state.world.local_map_target_neg.map, 253 | "local_map_target_pos": state.world.local_map_target_pos.map, 254 | "local_map_dumpability": state.world.local_map_dumpability.map, 255 | "local_map_obstacles": state.world.local_map_obstacles.map, 256 | "traversability_mask": state.world.traversability_mask.map, 257 | "action_map": state.world.action_map.map, 258 | "target_map": state.world.target_map.map, 259 | "agent_width": state.agent.width, 260 | "agent_height": state.agent.height, 261 | "padding_mask": state.world.padding_mask.map, 262 | "dumpability_mask": state.world.dumpability_mask.map, 263 | } 264 | 265 | 266 | class TerraEnvBatch: 267 | """ 268 | Takes care of the parallelization of the environment. 
269 | """ 270 | 271 | def __init__( 272 | self, 273 | batch_cfg: BatchConfig = BatchConfig(), 274 | rendering: bool = False, 275 | n_envs_x_rendering: int = 1, 276 | n_envs_y_rendering: int = 1, 277 | display: bool = False, 278 | shuffle_maps: bool = False, 279 | single_map_path: str = None, 280 | ) -> None: 281 | self.maps_buffer, self.batch_cfg = init_maps_buffer(batch_cfg, shuffle_maps, single_map_path) 282 | self.terra_env = TerraEnv.new( 283 | maps_size_px=self.batch_cfg.maps_dims.maps_edge_length, 284 | rendering=rendering, 285 | n_envs_x=n_envs_x_rendering, 286 | n_envs_y=n_envs_y_rendering, 287 | display=display, 288 | ) 289 | max_curriculum_level = len(batch_cfg.curriculum_global.levels) - 1 290 | max_steps_in_episode_per_level = jnp.array( 291 | [ 292 | level["max_steps_in_episode"] 293 | for level in batch_cfg.curriculum_global.levels 294 | ], 295 | dtype=jnp.int32, 296 | ) 297 | apply_trench_rewards_per_level = jnp.array( 298 | [ 299 | level["apply_trench_rewards"] 300 | for level in batch_cfg.curriculum_global.levels 301 | ], 302 | dtype=jnp.bool_, 303 | ) 304 | reward_type_per_level = jnp.array( 305 | [level["rewards_type"] for level in batch_cfg.curriculum_global.levels], 306 | dtype=jnp.int32, 307 | ) 308 | self.curriculum_manager = CurriculumManager( 309 | max_level=max_curriculum_level, 310 | increase_level_threshold=batch_cfg.curriculum_global.increase_level_threshold, 311 | decrease_level_threshold=batch_cfg.curriculum_global.decrease_level_threshold, 312 | max_steps_in_episode_per_level=max_steps_in_episode_per_level, 313 | apply_trench_rewards_per_level=apply_trench_rewards_per_level, 314 | reward_type_per_level=reward_type_per_level, 315 | last_level_type=batch_cfg.curriculum_global.last_level_type, 316 | ) 317 | 318 | def update_env_cfgs(self, env_cfgs: EnvConfig) -> EnvConfig: 319 | tile_size = ( 320 | self.batch_cfg.maps.edge_length_m 321 | / self.batch_cfg.maps_dims.maps_edge_length 322 | ) 323 | print(f"tile_size: {tile_size}") 324 | agent_w = self.batch_cfg.agent.dimensions.WIDTH 325 | agent_h = self.batch_cfg.agent.dimensions.HEIGHT 326 | agent_height = ( 327 | round(agent_w / tile_size) 328 | if (round(agent_w / tile_size)) % 2 != 0 329 | else round(agent_w / tile_size) + 1 330 | ) 331 | agent_width = ( 332 | round(agent_h / tile_size) 333 | if (round(agent_h / tile_size)) % 2 != 0 334 | else round(agent_h / tile_size) + 1 335 | ) 336 | print(f"agent_width: {agent_width}, agent_height: {agent_height}") 337 | 338 | # Repeat to match the number of environments 339 | n_envs = env_cfgs.agent.dig_depth.shape[ 340 | 0 341 | ] # leading dimension of any field in the config is the number of envs 342 | tile_size = jnp.repeat(jnp.array([tile_size], dtype=jnp.float32), n_envs) 343 | agent_width = jnp.repeat(jnp.array([agent_width], dtype=jnp.int32), n_envs) 344 | agent_height = jnp.repeat(jnp.array([agent_height], dtype=jnp.int32), n_envs) 345 | edge_length_px = jnp.repeat( 346 | jnp.array([self.batch_cfg.maps_dims.maps_edge_length], dtype=jnp.int32), 347 | n_envs, 348 | ) 349 | env_cfgs = env_cfgs._replace( 350 | tile_size=tile_size, 351 | agent=env_cfgs.agent._replace(width=agent_width, height=agent_height), 352 | maps=env_cfgs.maps._replace(edge_length_px=edge_length_px), 353 | ) 354 | return env_cfgs 355 | 356 | def _get_map_init(self, key: jax.random.PRNGKey, env_cfgs: EnvConfig): 357 | return jax.vmap(self.maps_buffer.get_map_init)(key, env_cfgs) 358 | 359 | def _get_map(self, maps_buffer_keys: jax.random.PRNGKey, env_cfgs: EnvConfig): 360 | return 
jax.vmap(self.maps_buffer.get_map)(maps_buffer_keys, env_cfgs) 361 | 362 | @partial(jax.jit, static_argnums=(0,)) 363 | def reset(self, env_cfgs: EnvConfig, rng_key: jax.random.PRNGKey) -> State: 364 | env_cfgs = self.curriculum_manager.reset_cfgs(env_cfgs) 365 | env_cfgs = self.update_env_cfgs(env_cfgs) 366 | ( 367 | target_maps, 368 | padding_masks, 369 | trench_axes, 370 | trench_type, 371 | dumpability_mask_init, 372 | action_maps, 373 | new_rng_key, 374 | ) = self._get_map_init(rng_key, env_cfgs) 375 | timestep = jax.vmap(self.terra_env.reset)( 376 | rng_key, 377 | target_maps, 378 | padding_masks, 379 | trench_axes, 380 | trench_type, 381 | dumpability_mask_init, 382 | action_maps, 383 | env_cfgs, 384 | ) 385 | return timestep 386 | 387 | @partial(jax.jit, static_argnums=(0,)) 388 | def step( 389 | self, 390 | timestep: TimeStep, 391 | actions: Action, 392 | maps_buffer_keys: jax.random.PRNGKey, 393 | ) -> tuple[State, tuple[dict, Array, Array, dict]]: 394 | # Update env_cfgs based on the curriculum, and get the new maps 395 | timestep = self.curriculum_manager.update_cfgs(timestep, maps_buffer_keys) 396 | ( 397 | target_maps, 398 | padding_masks, 399 | trench_axes, 400 | trench_type, 401 | dumpability_mask_init, 402 | action_maps, 403 | maps_buffer_keys, 404 | ) = self._get_map(maps_buffer_keys, timestep.env_cfg) 405 | # Step the environment 406 | timestep = jax.vmap(self.terra_env.step)( 407 | timestep.state, 408 | actions, 409 | target_maps, 410 | padding_masks, 411 | trench_axes, 412 | trench_type, 413 | dumpability_mask_init, 414 | action_maps, 415 | timestep.env_cfg, 416 | ) 417 | return timestep 418 | --------------------------------------------------------------------------------
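The batched environment API above (TerraEnvBatch.reset / TerraEnvBatch.step) is easiest to see end to end with a small driver. The sketch below is not part of the repository: it relies only on the signatures visible in terra/env.py, and it assumes that EnvConfig() builds a default config whose leaves are array-like, so a per-environment batch can be formed by stacking along a new leading axis (the project may provide its own helper for this).

# Hypothetical driver for TerraEnvBatch (a minimal sketch, not repository code).
# Assumption: EnvConfig() has defaults and every leaf can be stacked n_envs times.
import jax
import jax.numpy as jnp

from terra.config import BatchConfig, EnvConfig
from terra.env import TerraEnvBatch

n_envs = 4
env = TerraEnvBatch(rendering=False)

# Assumed batching: add a leading n_envs axis to every field of the default config,
# since update_env_cfgs() reads the batch size from env_cfgs.agent.dig_depth.shape[0].
env_cfgs = jax.tree_util.tree_map(
    lambda x: jnp.stack([jnp.asarray(x)] * n_envs, axis=0), EnvConfig()
)

# reset() and step() are vmapped internally, so they expect one PRNG key per environment.
reset_keys = jax.random.split(jax.random.PRNGKey(0), n_envs)
timestep = env.reset(env_cfgs, reset_keys)

# Tile the do-nothing action (the same dummy action TerraEnv.reset uses) across the batch.
action = BatchConfig().action_type.do_nothing()
actions = jax.tree_util.tree_map(
    lambda x: jnp.stack([jnp.asarray(x)] * n_envs, axis=0), action
)

step_keys = jax.random.split(jax.random.PRNGKey(1), n_envs)
timestep = env.step(timestep, actions, step_keys)
print(timestep.reward.shape, timestep.done.shape)  # expected: (n_envs,) each

The per-environment key batches mirror how reset() and step() pass their keys straight into jax.vmap over the maps buffer and the single-environment TerraEnv, which is why a single unsplit key would not work here.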