├── utils ├── __init__.py ├── helpers.py ├── utils_ppo.py └── models.py ├── plan_maps ├── .DS_Store ├── trench │ ├── image.npy │ ├── occupancy.npy │ ├── dumpability.npy │ └── metadata.json └── foundation │ ├── .DS_Store │ ├── image.npy │ ├── dumpability.npy │ ├── metadata.json │ └── occupancy.npy ├── assets ├── tracked-dense.gif └── wheeled-dense.gif ├── requirements.txt ├── .flake8 ├── checkpoints ├── tracked-dense.pkl └── wheeled-dense.pkl ├── pyproject.toml ├── cluster ├── train_cluster.sh ├── sweep_cluster.sh └── README.md ├── .gitignore ├── sweep.py ├── eval_ppo.py ├── visualize.py ├── README.md ├── terra_planner ├── extract_plan.py └── visualize_plan.py ├── eval.py └── train.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plan_maps/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/.DS_Store -------------------------------------------------------------------------------- /assets/tracked-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/assets/tracked-dense.gif -------------------------------------------------------------------------------- /assets/wheeled-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/assets/wheeled-dense.gif -------------------------------------------------------------------------------- /plan_maps/trench/image.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/trench/image.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | chex 4 | tqdm 5 | flax 6 | matplotlib 7 | pygame 8 | wandb 9 | tensorflow_probability -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E203, W, PAI 3 | per-file-ignores = 4 | terra/noise/simplex_noise.py: F401, E741, E731 -------------------------------------------------------------------------------- /checkpoints/tracked-dense.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/checkpoints/tracked-dense.pkl -------------------------------------------------------------------------------- /checkpoints/wheeled-dense.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/checkpoints/wheeled-dense.pkl -------------------------------------------------------------------------------- /plan_maps/foundation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/foundation/.DS_Store -------------------------------------------------------------------------------- /plan_maps/foundation/image.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/foundation/image.npy -------------------------------------------------------------------------------- /plan_maps/trench/occupancy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/trench/occupancy.npy -------------------------------------------------------------------------------- /plan_maps/trench/dumpability.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/trench/dumpability.npy -------------------------------------------------------------------------------- /plan_maps/foundation/dumpability.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/foundation/dumpability.npy -------------------------------------------------------------------------------- /plan_maps/foundation/metadata.json: -------------------------------------------------------------------------------- 1 | {"building_index": 1500, "real_dimensions": {"width": 21.574441602392632, "height": 22.864106025400492}} -------------------------------------------------------------------------------- /plan_maps/foundation/occupancy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leggedrobotics/terra-baselines/HEAD/plan_maps/foundation/occupancy.npy -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py311'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | .git 8 | | .venv 9 | | __pycache__ 10 | )/ 11 | ''' 12 | -------------------------------------------------------------------------------- /plan_maps/trench/metadata.json: -------------------------------------------------------------------------------- 1 | {"real_dimensions": {"width": 64.0, "height": 64.0}, "axes_ABC": [{"A": 0.0, "B": 21.0, "C": -525.0}, {"A": -16.0, "B": 0.0, "C": 720.0}], "lines_pts": [[[45.0, 25.0], [24.0, 25.0]], [[45.0, 41.0], [45.0, 25.0]]]} -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | 5 | def load_pkl_object(filename: str): 6 | """Helper to reload pickle objects.""" 7 | import pickle 8 | 9 | with open(filename, "rb") as input: 10 | obj = pickle.load(input) 11 | print(f"Loaded data from {filename}.") 12 | return obj 13 | 14 | 15 | def save_pkl_object(obj, filename): 16 | """Helper to store pickle objects.""" 17 | output_file = Path(filename) 18 | output_file.parent.mkdir(exist_ok=True, parents=True) 19 | 20 | with open(filename, "wb") as output: 21 | pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL) 22 | 23 | print(f"Stored data at {filename}.") 24 | -------------------------------------------------------------------------------- /cluster/train_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -n 1 3 | #SBATCH --cpus-per-task=8 4 | #SBATCH 
--gpus=rtx_4090:2 5 | #SBATCH --time=12:00:00 6 | #SBATCH --mem-per-cpu=8G 7 | #SBATCH --mail-type=END 8 | #SBATCH --mail-user=name@mail 9 | #SBATCH --job-name="training-$(date +"%Y-%m-%dT%H:%M")" 10 | #SBATCH --output=%j_training.out 11 | 12 | # Load required modules 13 | module load eth_proxy 14 | module load stack/2024-06 cuda/12.1.1 15 | 16 | # Set paths to conda 17 | CONDA_ROOT=/cluster/home/spiasecki/miniconda3 18 | CONDA_ENV=terra 19 | 20 | # Activate conda environment properly for batch jobs 21 | eval "$($CONDA_ROOT/bin/conda shell.bash hook)" 22 | conda activate $CONDA_ENV 23 | 24 | # Set environment variables and run training 25 | export DATASET_PATH=/cluster/home/spiasecki/terra/data/ 26 | export DATASET_SIZE=1800 27 | 28 | # Change to the directory containing train.py or use the full path 29 | cd /cluster/home/spiasecki/terra-baselines 30 | python train.py 31 | -------------------------------------------------------------------------------- /cluster/sweep_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -n 1 3 | #SBATCH --cpus-per-task=8 4 | #SBATCH --gpus=rtx_4090:8 5 | #SBATCH --time=120:00:00 6 | #SBATCH --mem-per-cpu=8G 7 | #SBATCH --mail-type=END 8 | #SBATCH --mail-user=name@mail 9 | #SBATCH --job-name="sweep-$(date +"%Y-%m-%dT%H:%M")" 10 | #SBATCH --output=%j_sweep.out 11 | 12 | module load eth_proxy 13 | module load stack/2024-06 cuda/12.1.1 14 | 15 | CONDA_ROOT=/cluster/home/spiasecki/miniconda3 16 | CONDA_ENV=terra 17 | 18 | eval "$($CONDA_ROOT/bin/conda shell.bash hook)" 19 | conda activate $CONDA_ENV 20 | 21 | export DATASET_PATH=/cluster/home/spiasecki/terra/data/ 22 | export DATASET_SIZE=1500 23 | 24 | cd /cluster/home/spiasecki/terra-baselines 25 | 26 | # Create the sweep and capture the sweep ID 27 | SWEEP_ID=$(python sweep.py create | grep 'Create sweep with ID:' | awk '{print $5}') 28 | 29 | # Run agents in parallel 30 | wandb agent terra-sp-thesis/sweep/$SWEEP_ID & # Agent 1 31 | wandb agent terra-sp-thesis/sweep/$SWEEP_ID & # Agent 2 32 | 33 | wait 34 | -------------------------------------------------------------------------------- /utils/utils_ppo.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from tensorflow_probability.substrates import jax as tfp 4 | 5 | 6 | def clip_action_map_in_obs(obs): 7 | """Clip action maps to [-1, 1] on the intuition that a binary map is enough for the agent to take decisions.""" 8 | obs["action_map"] = jnp.clip(obs["action_map"], a_min=-1, a_max=1) 9 | return obs 10 | 11 | 12 | def obs_to_model_input(obs, prev_actions, train_cfg): 13 | # Feature engineering 14 | if train_cfg.clip_action_maps: 15 | obs = clip_action_map_in_obs(obs) 16 | 17 | obs = [ 18 | obs["agent_state"], 19 | obs["local_map_action_neg"], 20 | obs["local_map_action_pos"], 21 | obs["local_map_target_neg"], 22 | obs["local_map_target_pos"], 23 | obs["local_map_dumpability"], 24 | obs["local_map_obstacles"], 25 | # obs["action_map"], not used in the model, performance somehow better without it! 
26 | obs["target_map"], 27 | obs["traversability_mask"], 28 | obs["dumpability_mask"], 29 | prev_actions, 30 | ] 31 | return obs 32 | 33 | 34 | def policy( 35 | apply_fn, 36 | params, 37 | obs, 38 | ): 39 | value, logits_pi = apply_fn(params, obs) 40 | pi = tfp.distributions.Categorical(logits=logits_pi) 41 | return value, pi 42 | 43 | 44 | def select_action_ppo( 45 | train_state, 46 | obs: jnp.ndarray, 47 | prev_actions: jnp.ndarray, 48 | rng: jax.random.PRNGKey, 49 | config, 50 | ): 51 | # Prepare policy input from Terra State 52 | obs = obs_to_model_input(obs, prev_actions, config) 53 | 54 | value, pi = policy(train_state.apply_fn, train_state.params, obs) 55 | action = pi.sample(seed=rng) 56 | log_prob = pi.log_prob(action) 57 | return action, log_prob, value[:, 0], pi 58 | 59 | 60 | def wrap_action(action, action_type): 61 | action = action_type.new(action[:, None]) 62 | return action 63 | -------------------------------------------------------------------------------- /cluster/README.md: -------------------------------------------------------------------------------- 1 | # Running Terra Training on the Cluster 2 | 3 | This guide provides instructions for setting up and running Terra training jobs on the compute cluster using SLURM. 4 | 5 | ## Prerequisites 6 | 7 | - Access to the cluster with SLURM workload manager 8 | - CUDA-compatible GPUs (RTX 3090) 9 | - Anaconda/Miniconda installed 10 | 11 | ## Setup 12 | 13 | 1. **Conda Environment Setup**: 14 | ```bash 15 | # Create the conda environment if you haven't already 16 | conda env create -f /cluster/home/lterenzi/terra_jax/terra/environment.yml -n terra 17 | ``` 18 | 19 | 2. **Update the Training Script**: 20 | - The script `train_cluster.sh` has been prepared for you but may need adjustments: 21 | - Ensure `CONDA_ROOT` points to your conda installation (currently set to `/home/lorenzo/anaconda3`) 22 | - Verify `CONDA_ENV` is correct (currently set to `terra`) 23 | - Update any paths if your file structure differs 24 | 25 | ## Running Training Jobs 26 | 27 | ### Submit a Training Job 28 | 29 | To submit a training job to the SLURM scheduler: 30 | 31 | ```bash 32 | # Navigate to the project directory 33 | cd /cluster/home/lterenzi/terra_jax 34 | 35 | # Submit the job 36 | sbatch terra-baselines/cluster/train_cluster.sh 37 | ``` 38 | 39 | This will submit your job to the SLURM scheduler and return a job ID. 40 | 41 | ### Monitor Your Job 42 | 43 | You can monitor the status of your job using: 44 | 45 | ```bash 46 | # Check job status 47 | squeue -u lterenzi 48 | 49 | # View job details 50 | scontrol show job 51 | 52 | # Monitor output in real-time 53 | tail -f _training.out 54 | ``` 55 | 56 | ### Common Commands 57 | 58 | - Cancel a job: `scancel ` 59 | - View job efficiency: `seff ` 60 | - Check resource availability: `sinfo` 61 | 62 | ## Customizing Training Parameters 63 | 64 | The current training script uses these parameters: 65 | 66 | - `DATASET_PATH=/cluster/home/lterenzi/terra_jax/terra/data/terra/train` 67 | - `DATASET_SIZE=200` 68 | 69 | To modify these or add other parameters, edit the `train_cluster.sh` script. 70 | 71 | ## Resource Allocation 72 | 73 | The current script requests: 74 | - 8 CPUs 75 | - 1 RTX 3090 GPU 76 | - 4048 MB memory per CPU 77 | - 3 hours maximum runtime 78 | 79 | To adjust these resources, modify the SLURM directives at the top of the `train_cluster.sh` script. 80 | 81 | ## Troubleshooting 82 | 83 | ### Common Issues 84 | 85 | 1. 
**Conda Activation Errors**: 86 | - The script uses `eval "$($CONDA_ROOT/bin/conda shell.bash hook)"` which should work in most cluster environments 87 | - If conda still fails to activate, you might need to use a module-based approach specific to your cluster 88 | 89 | 2. **Module Not Found Errors**: 90 | - The script sets `PYTHONPATH=$PYTHONPATH:/cluster/home/lterenzi/terra_jax` to find the `terra` module 91 | - If you encounter "ModuleNotFoundError", verify that the path is correct and the module is present 92 | 93 | 3. **GPU Memory Issues**: 94 | - Adjust your batch size in the training code 95 | - Monitor GPU usage with `nvidia-smi` 96 | 97 | ### Getting Help 98 | 99 | If you encounter persistent issues: 100 | - Check SLURM logs: `less _training.out` 101 | - Contact cluster administrators for cluster-specific issues 102 | 103 | ## Advanced Configuration 104 | 105 | ### Multi-GPU Training 106 | 107 | To utilize multiple GPUs, modify the SLURM parameter: 108 | 109 | ```bash 110 | #SBATCH --gpus=rtx_3090:2 # Using 2 GPUs 111 | ``` 112 | 113 | And ensure your training code is set up for distributed training. 114 | 115 | ### Checkpointing 116 | 117 | If your training jobs are long-running, consider implementing checkpointing in your training code to save progress periodically. You can add a checkpoint directory to your script: 118 | 119 | ```bash 120 | export CHECKPOINT_DIR=/cluster/home/lterenzi/terra_jax/checkpoints 121 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | *.pkl 163 | docs/ 164 | wandb/ 165 | 166 | agents/Terra/dummy.txt 167 | 168 | slurm* 169 | 170 | profile-remote/ 171 | tools/ 172 | -------------------------------------------------------------------------------- /sweep.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import time 3 | import sys 4 | import wandb 5 | from dataclasses import asdict, dataclass 6 | 7 | from terra.config import EnvConfig 8 | from train import make_states, make_train, TrainConfig 9 | 10 | @dataclass 11 | class TrainConfigSweep(TrainConfig): 12 | # Training config 13 | project: str = "sweep" 14 | total_timesteps: int = 2_000_000_000 15 | 16 | # Rewards 17 | existence: float = -0.1 18 | collision_move: float = -0.2 19 | move: float = -0.1 20 | move_with_turned_wheels: float = -0.1 21 | cabin_turn: float = -0.05 22 | wheel_turn: float = -0.05 23 | dig_wrong: float = -0.25 24 | dump_wrong: float = -1.0 25 | dig_correct: float = 0.2 26 | dump_correct: float = 0.15 27 | terminal: float = 100.0 28 | 29 | 30 | def train(config: TrainConfigSweep): 31 | run = wandb.init( 32 | entity="terra-sp-thesis", 33 | project=config.project, 34 | group=config.group, 35 | name=config.name, 36 | config=asdict(config), 37 | save_code=True, 38 | ) 39 | 40 | # Replace the rewards with the sweep values 41 | env_params = EnvConfig() 42 | env_params = env_params._replace( 43 | rewards=env_params.rewards._replace( 44 | existence=config.existence, 45 | collision_move=config.collision_move, 46 | move=config.move, 47 | cabin_turn=config.cabin_turn, 48 | wheel_turn=config.wheel_turn, 49 | dig_wrong=config.dig_wrong, 50 | dump_wrong=config.dump_wrong, 51 | dig_correct=config.dig_correct, 52 | dump_correct=config.dump_correct, 53 | terminal=config.terminal, 54 | ) 55 | ) 56 | rng, env, env_params, train_state = make_states(config, env_params) 57 | train_fn = make_train(env, env_params, config) 58 | 59 | print("Training...") 60 | try: # Try block starts here 61 | t = time.time() 62 | train_info = jax.block_until_ready(train_fn(rng, train_state)) 63 | elapsed_time = time.time() - t 64 | print(f"Done in {elapsed_time:.2f}s") 65 | except KeyboardInterrupt: # Catch Ctrl+C 66 | print("Training interrupted. 
Finalizing...") 67 | finally: # Ensure wandb.finish() is called 68 | run.finish() 69 | print("wandb session finished.") 70 | 71 | def sweep_train(): 72 | config = wandb.config 73 | if "name" not in config: 74 | config["name"] = f"sweep-{wandb.run.id}" 75 | # Convert wandb.config to TrainConfigSweep 76 | train_config = TrainConfigSweep(**dict(config)) 77 | train(train_config) 78 | 79 | if __name__ == "__main__": 80 | # If called with "create" argument, create the sweep and print the ID 81 | if len(sys.argv) > 1 and sys.argv[1] == "create": 82 | sweep_config = { 83 | "program": "train.py", 84 | "method": "bayes", 85 | "metric": { 86 | "name": "eval/positive_terminations", 87 | "goal": "maximize", 88 | }, 89 | "parameters": { 90 | "existence": {"values": [-0.1, -0.05, 0.0]}, 91 | "collision_move": {"values": [-0.3, -0.2, -0.1, 0.0]}, 92 | "move": {"values": [-0.3, -0.2, -0.1, -0.05, 0.0]}, 93 | "move_with_turned_wheels": {"values": [-0.3, -0.2, -0.1, -0.05, 0.0]}, 94 | "cabin_turn": {"values": [-0.2, -0.1, -0.05, 0.0]}, 95 | "wheel_turn": {"values": [-0.2, -0.1, -0.05, 0.0]}, 96 | "dig_wrong": {"values": [-0.5, -0.4, -0.3, -0.2, -0.1]}, 97 | "dump_wrong": {"values": [-0.5, -0.4, -0.3, -0.2, -0.1]}, 98 | "dig_correct": {"values": [0.1, 0.2, 0.3, 0.4, 0.5]}, 99 | "dump_correct": {"values": [0.1, 0.2, 0.3, 0.4, 0.5]}, 100 | "lr": {"values": [1e-4, 2e-4, 3e-4, 4e-4]}, 101 | "ent_coef": {"values": [0.0025, 0.005, 0.0075]}, 102 | } 103 | } 104 | sweep_id = wandb.sweep(sweep_config, project="sweep") 105 | else: 106 | # Called by wandb agent 107 | sweep_train() 108 | -------------------------------------------------------------------------------- /eval_ppo.py: -------------------------------------------------------------------------------- 1 | # utilities for PPO training and evaluation 2 | import jax 3 | import jax.numpy as jnp 4 | from flax.training.train_state import TrainState 5 | from typing import NamedTuple 6 | from utils.utils_ppo import select_action_ppo, wrap_action 7 | 8 | 9 | # for evaluation (evaluate for N consecutive episodes, sum rewards) 10 | # N=1 single task, N>1 for meta-RL 11 | class RolloutStats(NamedTuple): 12 | max_reward: jax.Array = jnp.asarray(-100) 13 | min_reward: jax.Array = jnp.asarray(100) 14 | reward: jax.Array = jnp.asarray(0.0) 15 | length: jax.Array = jnp.asarray(0) 16 | episodes: jax.Array = jnp.asarray(0) 17 | positive_terminations: jax.Array = jnp.asarray(0) # Count of positive terminations 18 | terminations: jax.Array = jnp.asarray(0) # Count of terminations 19 | positive_terminations_steps: jax.Array = jnp.asarray(0) 20 | 21 | action_0: jax.Array = jnp.asarray(0) 22 | action_1: jax.Array = jnp.asarray(0) 23 | action_2: jax.Array = jnp.asarray(0) 24 | action_3: jax.Array = jnp.asarray(0) 25 | action_4: jax.Array = jnp.asarray(0) 26 | action_5: jax.Array = jnp.asarray(0) 27 | action_6: jax.Array = jnp.asarray(0) 28 | 29 | 30 | # @partial(jax.pmap, axis_name="devices") 31 | def rollout( 32 | rng: jax.Array, 33 | env, 34 | env_params, 35 | train_state: TrainState, 36 | config, 37 | ) -> RolloutStats: 38 | num_envs = config.num_envs_per_device 39 | num_rollouts = config.num_rollouts_eval 40 | 41 | def _cond_fn(carry): 42 | _, stats, _ = carry 43 | # Check if the number of steps has been reached 44 | return jnp.less(stats.length, num_rollouts + 1) 45 | 46 | def _body_fn(carry): 47 | rng, stats, timestep, prev_actions = carry 48 | 49 | rng, _rng_step, _rng_model = jax.random.split(rng, 3) 50 | 51 | action, _, _, _ = select_action_ppo( 52 | train_state, 
timestep.observation, prev_actions, _rng_model, config 53 | ) 54 | _rng_step = jax.random.split(_rng_step, num_envs) 55 | action_env = wrap_action(action, env.batch_cfg.action_type) 56 | timestep = env.step(timestep, action_env, _rng_step) 57 | 58 | prev_actions = jnp.roll(prev_actions, shift=1, axis=-1) 59 | prev_actions = prev_actions.at[..., 0].set(action) 60 | 61 | terminations_update = timestep.done.sum() 62 | positive_termination_update = timestep.info["task_done"].sum() 63 | positive_termination_steps_update = (stats.length + 1) * positive_termination_update 64 | 65 | stats = RolloutStats( 66 | max_reward=jnp.maximum(stats.max_reward, timestep.reward.max()), 67 | min_reward=jnp.minimum(stats.min_reward, timestep.reward.min()), 68 | reward=stats.reward + timestep.reward.sum(), # Ensure correct aggregation 69 | length=stats.length + 1, 70 | episodes=stats.episodes + timestep.done.any(), 71 | positive_terminations=stats.positive_terminations 72 | + positive_termination_update, 73 | terminations=stats.terminations + terminations_update, 74 | positive_terminations_steps=stats.positive_terminations_steps 75 | + positive_termination_steps_update, 76 | action_0=stats.action_0 + (action == 0).sum(), 77 | action_1=stats.action_1 + (action == 1).sum(), 78 | action_2=stats.action_2 + (action == 2).sum(), 79 | action_3=stats.action_3 + (action == 3).sum(), 80 | action_4=stats.action_4 + (action == 4).sum(), 81 | action_5=stats.action_5 + (action == 5).sum(), 82 | action_6=stats.action_6 + (action == 6).sum(), 83 | ) 84 | carry = (rng, stats, timestep, prev_actions) 85 | return carry 86 | 87 | rng, _rng_reset = jax.random.split(rng) 88 | _rng_reset = jax.random.split(_rng_reset, num_envs) 89 | timestep = env.reset(env_params, _rng_reset) 90 | prev_actions = jnp.zeros((num_envs, config.num_prev_actions), dtype=jnp.int32) 91 | init_carry = (rng, RolloutStats(), timestep, prev_actions) 92 | 93 | # final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry) 94 | final_carry = jax.lax.fori_loop( 95 | 0, num_rollouts, lambda i, carry: _body_fn(carry), init_carry 96 | ) 97 | return final_carry[1] 98 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partially from https://github.com/RobertTLange/gymnax-blines 3 | """ 4 | 5 | import numpy as np 6 | import jax 7 | from tqdm import tqdm 8 | from utils.models import load_neural_network 9 | from utils.helpers import load_pkl_object 10 | from terra.env import TerraEnvBatch 11 | import jax.numpy as jnp 12 | from utils.utils_ppo import obs_to_model_input, wrap_action 13 | from terra.state import State 14 | import matplotlib.animation as animation 15 | 16 | # from utils.curriculum import Curriculum 17 | from tensorflow_probability.substrates import jax as tfp 18 | from train import TrainConfig # needed for unpickling checkpoints 19 | from terra.config import EnvConfig 20 | 21 | 22 | def rollout_episode( 23 | env: TerraEnvBatch, model, model_params, env_cfgs, rl_config, max_frames, seed 24 | ): 25 | print(f"Using {seed=}") 26 | rng = jax.random.PRNGKey(seed) 27 | rng, _rng = jax.random.split(rng) 28 | rng_reset = jax.random.split(_rng, rl_config.num_test_rollouts) 29 | timestep = env.reset(env_cfgs, rng_reset) 30 | prev_actions = jnp.zeros( 31 | (rl_config.num_test_rollouts, rl_config.num_prev_actions), 32 | dtype=jnp.int32 33 | ) 34 | 35 | t_counter = 0 36 | reward_seq = [] 37 | obs_seq = [] 38 | state_seq = [] 
# Also collect states 39 | 40 | # Add initial observation and state (after reset) 41 | obs_seq.append(timestep.observation) 42 | state_seq.append(timestep.state) 43 | 44 | while True: 45 | rng, rng_act, rng_step = jax.random.split(rng, 3) 46 | if model is not None: 47 | obs = obs_to_model_input(timestep.observation, prev_actions, rl_config) 48 | v, logits_pi = model.apply(model_params, obs) 49 | pi = tfp.distributions.Categorical(logits=logits_pi) 50 | action = pi.sample(seed=rng_act) 51 | prev_actions = jnp.roll(prev_actions, shift=1, axis=1) 52 | prev_actions = prev_actions.at[:, 0].set(action) 53 | else: 54 | raise RuntimeError("Model is None!") 55 | rng_step = jax.random.split(rng_step, rl_config.num_test_rollouts) 56 | timestep = env.step( 57 | timestep, wrap_action(action, env.batch_cfg.action_type), rng_step 58 | ) 59 | 60 | t_counter += 1 61 | 62 | # COLLECT OBSERVATION AFTER STEP (includes soil mechanics changes) 63 | obs_seq.append(timestep.observation) 64 | state_seq.append(timestep.state) 65 | 66 | if t_counter <= 3: 67 | action_map = timestep.observation['action_map'] 68 | state_action_map = timestep.state.world.action_map.map 69 | 70 | # Compare first environment 71 | obs_dirt = action_map[0][action_map[0] > 0] if action_map.shape[0] > 0 else [] 72 | state_dirt = state_action_map[0][state_action_map[0] > 0] if state_action_map.shape[0] > 0 else [] 73 | 74 | 75 | 76 | reward_seq.append(timestep.reward) 77 | print(t_counter, timestep.reward, action, timestep.done) 78 | print(10 * "=") 79 | 80 | if jnp.all(timestep.done).item() or t_counter == max_frames: 81 | break 82 | print(f"Terra - Steps: {t_counter}, Return: {np.sum(reward_seq)}") 83 | return obs_seq, np.cumsum(reward_seq), state_seq 84 | 85 | 86 | def update_render(seq, env: TerraEnvBatch, frame): 87 | obs = {k: v[:, frame] for k, v in seq.items()} 88 | return env.terra_env.render_obs(obs, mode="gif") 89 | 90 | 91 | if __name__ == "__main__": 92 | import argparse 93 | 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument( 96 | "-run", 97 | "--run_name", 98 | type=str, 99 | default="ppo_2023_05_09_10_01_23", 100 | help="es/ppo trained agent.", 101 | ) 102 | parser.add_argument( 103 | "-env", 104 | "--env_name", 105 | type=str, 106 | default="Terra", 107 | help="Environment name.", 108 | ) 109 | parser.add_argument( 110 | "-nx", 111 | "--n_envs_x", 112 | type=int, 113 | default=3, 114 | help="Number of environments on x.", 115 | ) 116 | parser.add_argument( 117 | "-ny", 118 | "--n_envs_y", 119 | type=int, 120 | default=3, 121 | help="Number of environments on y.", 122 | ) 123 | parser.add_argument( 124 | "-steps", 125 | "--n_steps", 126 | type=int, 127 | default=100, 128 | help="Number of steps.", 129 | ) 130 | parser.add_argument( 131 | "-o", 132 | "--out_path", 133 | type=str, 134 | default="./visualize.gif", 135 | help="Output path.", 136 | ) 137 | parser.add_argument( 138 | "-s", 139 | "--seed", 140 | type=int, 141 | default=0, 142 | help="Random seed for the environment.", 143 | ) 144 | args, _ = parser.parse_known_args() 145 | n_envs = args.n_envs_x * args.n_envs_y 146 | 147 | log = load_pkl_object(f"{args.run_name}") 148 | config = log["train_config"] 149 | config.num_test_rollouts = n_envs 150 | config.num_devices = 1 151 | 152 | # curriculum = Curriculum(rl_config=config, n_devices=n_devices) 153 | # env_cfgs, dofs_count_dict = curriculum.get_cfgs_eval() 154 | env_cfgs = log["env_config"] 155 | env_cfgs = jax.tree_map( 156 | lambda x: x[0][None, ...].repeat(n_envs, 0), env_cfgs 157 | ) # take first 
config and replicate 158 | shuffle_maps = True 159 | env = TerraEnvBatch( 160 | rendering=True, 161 | n_envs_x_rendering=args.n_envs_x, 162 | n_envs_y_rendering=args.n_envs_y, 163 | display=False, 164 | shuffle_maps=shuffle_maps, 165 | ) 166 | config.num_embeddings_agent_min = 60 # curriculum.get_num_embeddings_agent_min() 167 | 168 | model = load_neural_network(config, env) 169 | model_params = log["model"] 170 | # replicated_params = log['network'] 171 | # model_params = jax.tree_map(lambda x: x[0], replicated_params) 172 | obs_seq, cum_rewards, state_seq = rollout_episode( 173 | env, 174 | model, 175 | model_params, 176 | env_cfgs, 177 | config, 178 | max_frames=args.n_steps, 179 | seed=args.seed, 180 | ) 181 | 182 | for i, o in enumerate(tqdm(obs_seq, desc="Rendering")): 183 | # Try using state action_map instead of observation action_map 184 | if i < len(state_seq): 185 | # Create modified observation with raw state action_map 186 | modified_obs = dict(o) 187 | modified_obs['action_map'] = state_seq[i].world.action_map.map 188 | env.terra_env.render_obs_pygame(modified_obs, generate_gif=True) 189 | else: 190 | env.terra_env.render_obs_pygame(o, generate_gif=True) 191 | 192 | env.terra_env.rendering_engine.create_gif(args.out_path) 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌍🚀 Terra Baselines - Training, Evals, and Checkpoints for Terra 2 | Terra Baselines provides a set of tools to train and evaluate RL policies on the [Terra](https://github.com/leggedrobotics/Terra) environment. This implementation allows you to train an agent capable of planning earthworks in trench and foundation environments in less than 1 minute on 8 Nvidia RTX-4090 GPUs. 3 | 4 | ## Features 5 | - Train on multiple devices using PPO with `train.py` (based on [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid)) 6 | - Generate metrics for your checkpoint with `eval.py` 7 | - Visualize rollouts of your checkpoint with `visualize.py` 8 | - Run a hyperparameter sweep with `sweep.py` (orchestrated with [wandb](https://wandb.ai/)) 9 | 10 | ## Installation 11 | Clone the repo and install the requirements with 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | Clone [Terra](https://github.com/leggedrobotics/Terra) in a different folder and install it with 17 | ``` 18 | pip install -e . 19 | ``` 20 | 21 | Lastly, [install JAX](https://jax.readthedocs.io/en/latest/installation.html). 22 | 23 | ## Train 24 | Set up your training job by configuring the `TrainConfig` 25 | ``` python 26 | @dataclass 27 | class TrainConfig: 28 | name: str 29 | num_devices: int = 0 30 | project: str = "excavator" 31 | group: str = "default" 32 | num_envs_per_device: int = 4096 33 | num_steps: int = 32 34 | update_epochs: int = 3 35 | num_minibatches: int = 32 36 | total_timesteps: int = 3_000_000_000 37 | lr: float = 3e-4 38 | clip_eps: float = 0.5 39 | gamma: float = 0.995 40 | gae_lambda: float = 0.95 41 | ent_coef: float = 0.01 42 | vf_coef: float = 5.0 43 | max_grad_norm: float = 0.5 44 | eval_episodes: int = 100 45 | seed: int = 42 46 | log_train_interval: int = 1 47 | log_eval_interval: int = 50 48 | checkpoint_interval: int = 50 49 | clip_action_maps = True 50 | local_map_normalization_bounds = [-16, 16] 51 | loaded_max = 100 52 | num_rollouts_eval = 300 53 | ``` 54 | Then, set up the curriculum in `config.py` in Terra (making sure the maps are saved to disk).
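The `TrainConfig` above can also be driven programmatically, using the same entry points that `sweep.py` imports from `train.py`. The snippet below is only a sketch: it mirrors the call pattern of `sweep.py`'s `train()` helper, assumes `DATASET_PATH` and `DATASET_SIZE` are exported as described below, and the small override values (`num_envs_per_device`, `total_timesteps`) are purely illustrative, not tuned.
``` python
import jax
import wandb
from dataclasses import asdict

from terra.config import EnvConfig
from train import TrainConfig, make_states, make_train

# Small, untuned overrides for a quick smoke test.
config = TrainConfig(
    name="smoke-test",
    num_envs_per_device=512,
    total_timesteps=10_000_000,
)

# Same logging/bootstrap pattern as sweep.py's train() helper.
run = wandb.init(project=config.project, name=config.name, config=asdict(config))

env_params = EnvConfig()
rng, env, env_params, train_state = make_states(config, env_params)
train_fn = make_train(env, env_params, config)

train_info = jax.block_until_ready(train_fn(rng, train_state))
run.finish()
```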
55 | 56 | Run a training job with 57 | ``` 58 | DATASET_PATH=/path/to/dataset DATASET_SIZE= python train.py -d 59 | ``` 60 | and collect your weights in the `checkpoints/` folder. 61 | 62 | ## Sweep 63 | 64 | You can run a hyperparameter sweep over reward settings using [Weights & Biases Sweeps](https://docs.wandb.ai/guides/sweeps). This allows you to efficiently grid search or random search over reward parameters and compare results. 65 | 66 | ### 1. Define the Sweep 67 | 68 | The sweep configuration is defined in `sweep.py`. It includes a grid over reward parameters such as `existence`, `collision_move`, `move`, etc. The sweep uses the `TrainConfigSweep` dataclass, which extends the standard training config with sweepable reward parameters. 69 | 70 | ### 2. Create the Sweep 71 | 72 | To create a new sweep on wandb, run: 73 | ```bash 74 | python sweep.py create 75 | ``` 76 | This will print a sweep ID (e.g., `abc123xy`). Copy this ID for the next step. 77 | 78 | ### 3. Launch Agents 79 | 80 | You can launch multiple agents (workers) to run experiments in parallel. Each agent will pick up a different configuration from the sweep and start a training run. 81 | 82 | To launch an agent, run: 83 | ```bash 84 | wandb agent 85 | ``` 86 | You can run this command multiple times (e.g., in different terminals, or as background jobs in a cluster script) to parallelize the sweep. 87 | 88 | #### Example: Running Multiple Agents in a Cluster Script 89 | 90 | If you are using a cluster, you can use the provided `sweep_cluster.sh` script. Make sure to set the `SWEEP_ID` variable to your sweep ID: 91 | ```bash 92 | # In sweep_cluster.sh 93 | SWEEP_ID= 94 | wandb agent $SWEEP_ID & 95 | wandb agent $SWEEP_ID & 96 | wandb agent $SWEEP_ID & 97 | wandb agent $SWEEP_ID & 98 | wait 99 | ``` 100 | 101 | ## Eval 102 | Evaluate your checkpoint with standard metrics using 103 | ``` 104 | DATASET_PATH=/path/to/dataset DATASET_SIZE= python eval.py -run -n -steps 105 | ``` 106 | 107 | ## Visualize 108 | Visualize the rollout of your policy with 109 | ``` 110 | DATASET_PATH=/path/to/dataset DATASET_SIZE= python visualize.py -run -nx -ny -steps -o 111 | ``` 112 | 113 | ## Plan Extraction and Analysis 114 | 115 | Extract and analyze terrain modification plans from your trained policies using the plan extraction tools. 116 | 117 | ### Extract Plans 118 | Extract action maps and terrain modifications from policy rollouts: 119 | ```bash 120 | python extract_plan.py -policy -map -steps -o 121 | ``` 122 | 123 | Example: 124 | ```bash 125 | python extract_plan.py -policy checkpoints/tracked-dense.pkl -map plan_maps/foundation -o plan.pkl 126 | ``` 127 | 128 | This captures the robot state and terrain modifications at each DO (dig/dump) action, storing: 129 | - Agent position, orientation, and loaded state 130 | - Terrain change values and modification masks 131 | - Traversability information 132 | 133 | ### Visualize Plans 134 | Create visualizations of the extracted terrain modification plans: 135 | ```bash 136 | python visualize_plan.py 137 | ``` 138 | 139 | This generates multi-panel plots showing: 140 | - Terrain modification masks (where digging/dumping occurred) 141 | - Traversability maps with agent positions 142 | - Terrain change values and action map evolution 143 | - Combined overlays of all modifications 144 | 145 | ## Baselines 146 | We train 2 models capable of solving both foundation and trench type of environments. 
They differentiate themselves based on the type of agent (wheeled or tracked), and the type of curriculum used to train them (dense reward with single level, or sparse reward with curriculum). All models are trained on 64x64 maps and are stored in the `checkpoints/` folder. 147 | 148 | | Checkpoint | Map Type | $C_r$ | $S_p$ | $S_w$ | $Coverage$ | 149 | |----------------------|-----------|-------|-------|-------|------------| 150 | | `tracked-dense.pkl` |Foundations|97%|5.66 (1.51)|19.06 (2.86)|0.99 (0.04)| 151 | | |Trenches |94%|7.09 (5.66)|20.57 (5.26)|0.99 (0.10)| 152 | | `wheeled-dense.pkl` |Foundations|99%|11.43 (8.96)|22.06 (3.65)|1.00 (0.00)| 153 | | |Trenches |89%|15.84 (25.10)|21.12 (5.65)|0.96 (0.14)| 154 | 155 | Where we define the metrics from [Terenzi et al](https://arxiv.org/abs/2308.11478): 156 | 157 | $$ 158 | \begin{equation} 159 | \text{Completion Rate}= C_{r} = \frac{N_{terminated}}{N_{total}} 160 | \end{equation} 161 | $$ 162 | 163 | $$ 164 | \begin{equation} 165 | \text{Path Efficiency}=S_{p}=\sum_{i=0}^{N-1} \frac{\left(x_{B_{i+1}}-x_{B_{i}}\right)}{\sqrt{A_{d}}} 166 | \end{equation} 167 | $$ 168 | 169 | $$ 170 | \begin{equation} 171 | \text{Workspace Efficiency} = S_{w} = \frac{N_{w} \cdot A_{w}}{A_{d}} 172 | \end{equation} 173 | $$ 174 | 175 | $$ 176 | \begin{equation} 177 | \text{Coverage}=\frac{N_{tiles\ dug}}{N_{tiles\ to\ dig}} 178 | \end{equation} 179 | $$ 180 | 181 | ### Model Details 182 | All the models we train share the same structure. We encode the maps with a CNN, and the agent state and local maps with MLPs. The latent features are concatenated and shared by the two MLP heads of the model (value and action). In total, the model has ~130k parameters counting both value and action weights. 183 | 184 | ## Policy Rollouts 😄 185 | Here's a collection of rollouts for the models we trained. 
186 | #### `tracked-dense.pkl` 187 | ![img](assets/tracked-dense.gif) 188 | #### `wheeled-dense.pkl` 189 | ![img](assets/wheeled-dense.gif) 190 | -------------------------------------------------------------------------------- /terra_planner/extract_plan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import argparse 4 | import pickle 5 | import sys 6 | from pathlib import Path 7 | 8 | # Add the parent directory to the path so we can import utils 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | 11 | from utils.models import load_neural_network 12 | from utils.helpers import load_pkl_object 13 | from terra.env import TerraEnvBatch 14 | from terra.actions import TrackedAction, WheeledAction, TrackedActionType, WheeledActionType 15 | import jax.numpy as jnp 16 | from utils.utils_ppo import obs_to_model_input, wrap_action 17 | from tensorflow_probability.substrates import jax as tfp 18 | from train import TrainConfig # needed for unpickling checkpoints 19 | 20 | 21 | def extract_plan(env, model, model_params, env_cfgs, rl_config, max_frames, seed): 22 | """Extract plan by capturing action_map and robot state on DO actions.""" 23 | print(f"Using seed={seed}") 24 | rng = jax.random.PRNGKey(seed) 25 | rng, _rng = jax.random.split(rng) 26 | rng_reset = jax.random.split(_rng, 1) # Just one environment 27 | timestep = env.reset(env_cfgs, rng_reset) 28 | prev_actions = jnp.zeros( 29 | (1, rl_config.num_prev_actions), 30 | dtype=jnp.int32 31 | ) 32 | 33 | # Determine action type and DO action 34 | action_type = env.batch_cfg.action_type 35 | if action_type == TrackedAction: 36 | do_action = TrackedActionType.DO 37 | elif action_type == WheeledAction: 38 | do_action = WheeledActionType.DO 39 | else: 40 | raise ValueError(f"Unknown action type: {action_type}") 41 | 42 | print(f"Action type: {action_type.__name__}, DO action value: {do_action}") 43 | 44 | # Plan storage 45 | plan = [] 46 | 47 | t_counter = 0 48 | 49 | while True: 50 | rng, rng_act, rng_step = jax.random.split(rng, 3) 51 | 52 | # Get action from policy 53 | obs_model = obs_to_model_input(timestep.observation, prev_actions, rl_config) 54 | v, logits_pi = model.apply(model_params, obs_model) 55 | pi = tfp.distributions.Categorical(logits=logits_pi) 56 | action = pi.sample(seed=rng_act) 57 | 58 | # Check if DO action and record state BEFORE executing the action 59 | if action[0] == do_action: 60 | print(f"DO action at step {t_counter}") 61 | action_map_before = jnp.squeeze(timestep.observation["action_map"]).copy() 62 | traversability_mask = jnp.squeeze(timestep.observation["traversability_mask"]).copy() 63 | agent_state_before = jnp.squeeze(timestep.observation["agent_state"]).copy() 64 | loaded_before = jnp.bool_(agent_state_before[5]) 65 | 66 | # Update previous actions 67 | prev_actions = jnp.roll(prev_actions, shift=1, axis=1) 68 | prev_actions = prev_actions.at[:, 0].set(action) 69 | 70 | # Take step in environment 71 | rng_step = jax.random.split(rng_step, 1) 72 | timestep = env.step( 73 | timestep, wrap_action(action, env.batch_cfg.action_type), rng_step 74 | ) 75 | 76 | # Get state AFTER executing the action 77 | action_map_after = jnp.squeeze(timestep.observation["action_map"]).copy() 78 | agent_state_after = jnp.squeeze(timestep.observation["agent_state"]).copy() 79 | loaded_after = jnp.bool_(agent_state_after[5]) 80 | 81 | changed_tiles = action_map_before != action_map_after 82 | terrain_modification_mask = changed_tiles.astype(jnp.bool_) 
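            # Note: terrain_modification_mask flags only the cells changed by this specific DO action,
            # whereas dug_mask / dump_mask below describe the cumulative action-map state after the step.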
83 | 84 | # Define dug and dump masks based on the final action map state 85 | dug_mask = action_map_after < 0 86 | dump_mask = action_map_after > 0 87 | 88 | if loaded_before != loaded_after: 89 | if not loaded_before and loaded_after: 90 | print(f" Digging detected: {jnp.sum(changed_tiles)} tiles modified") 91 | elif loaded_before and not loaded_after: 92 | print(f" Dumping detected: {jnp.sum(changed_tiles)} tiles modified") 93 | else: 94 | # Case 3: loaded state did not change, but still check for modifications 95 | print(f" Loaded state unchanged ({loaded_before}), {jnp.sum(changed_tiles)} tiles modified") 96 | 97 | plan_entry = { 98 | 'step': t_counter, 99 | 'traversability_mask': traversability_mask, 100 | 'agent_state': { 101 | 'pos_base': (agent_state_before[0], agent_state_before[1]), 102 | 'angle_base': agent_state_before[2], 103 | 'angle_cabin': agent_state_before[3], 104 | 'wheel_angle': agent_state_before[4], 105 | }, 106 | 'terrain_modification_mask': terrain_modification_mask, 107 | 'dug_mask': dug_mask.copy(), 108 | 'dump_mask': dump_mask.copy(), 109 | 'loaded_state_change': { 110 | 'before': loaded_before, 111 | 'after': loaded_after, 112 | } 113 | } 114 | plan.append(plan_entry) 115 | else: 116 | # Update previous actions 117 | prev_actions = jnp.roll(prev_actions, shift=1, axis=1) 118 | prev_actions = prev_actions.at[:, 0].set(action) 119 | 120 | # Take step in environment 121 | rng_step = jax.random.split(rng_step, 1) 122 | timestep = env.step( 123 | timestep, wrap_action(action, env.batch_cfg.action_type), rng_step 124 | ) 125 | 126 | t_counter += 1 127 | print(f"Step {t_counter}, Action: {action[0]}") 128 | 129 | # Check if done 130 | if jnp.all(timestep.info["task_done"]).item() or t_counter == max_frames: 131 | break 132 | 133 | return plan 134 | 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser(description="Extract plan from policy") 138 | parser.add_argument( 139 | "-policy", 140 | "--policy_path", 141 | type=str, 142 | required=True, 143 | help="Path to the policy .pkl file" 144 | ) 145 | parser.add_argument( 146 | "-map", 147 | "--map_path", 148 | type=str, 149 | required=True, 150 | help="Path to the map file" 151 | ) 152 | parser.add_argument( 153 | "-steps", 154 | "--n_steps", 155 | type=int, 156 | default=500, 157 | help="Maximum number of steps" 158 | ) 159 | parser.add_argument( 160 | "-o", 161 | "--output_path", 162 | type=str, 163 | default="plan.pkl", 164 | help="Output path for the plan" 165 | ) 166 | parser.add_argument( 167 | "-s", 168 | "--seed", 169 | type=int, 170 | default=0, 171 | help="Random seed" 172 | ) 173 | 174 | args = parser.parse_args() 175 | 176 | # Load policy 177 | log = load_pkl_object(args.policy_path) 178 | config = log["train_config"] 179 | config.num_test_rollouts = 1 # Only one environment 180 | config.num_devices = 1 181 | 182 | # Disable action map clipping to see full terrain state 183 | print(f"Original clip_action_maps setting: {config.clip_action_maps}") 184 | config.clip_action_maps = False 185 | # Add the missing attribute that the model expects when clipping is disabled 186 | config.maps_net_normalization_bounds = [-100, 100] # Reasonable range for terrain heights 187 | print(f"Modified clip_action_maps setting: {config.clip_action_maps}") 188 | print(f"Added maps_net_normalization_bounds: {config.maps_net_normalization_bounds}") 189 | 190 | # Create environment 191 | env_cfgs = log["env_config"] 192 | env_cfgs = jax.tree_map(lambda x: x[0][None, ...], env_cfgs) 193 | env = TerraEnvBatch(rendering=False, 
shuffle_maps=False, single_map_path=args.map_path) 194 | 195 | # Load neural network 196 | model = load_neural_network(config, env) 197 | model_params = log["model"] 198 | 199 | # Extract plan 200 | plan = extract_plan( 201 | env, 202 | model, 203 | model_params, 204 | env_cfgs, 205 | config, 206 | max_frames=args.n_steps, 207 | seed=args.seed 208 | ) 209 | 210 | # Save plan 211 | output_path = Path(args.output_path) 212 | with open(output_path, 'wb') as f: 213 | pickle.dump(plan, f) 214 | 215 | print(f"Plan extracted and saved to {output_path}") 216 | print(f"Total DO actions: {len(plan)}") 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /terra_planner/visualize_plan.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to visualize terrain modification masks from a plan extracted by extract_plan.py 4 | """ 5 | 6 | import pickle 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from pathlib import Path 11 | 12 | 13 | def load_plan(plan_path): 14 | """Load the plan from a pickle file.""" 15 | with open(plan_path, 'rb') as f: 16 | plan = pickle.load(f) 17 | return plan 18 | 19 | 20 | def plot_terrain_modifications(plan, output_dir=None, show_plots=True): 21 | """ 22 | Plot terrain modification masks for each step in the plan. 23 | 24 | Args: 25 | plan: List of plan entries from extract_plan.py (terrain modification actions) 26 | output_dir: Directory to save plots (optional) 27 | show_plots: Whether to display plots interactively 28 | """ 29 | if output_dir: 30 | output_dir = Path(output_dir) 31 | output_dir.mkdir(exist_ok=True) 32 | 33 | print(f"Found {len(plan)} terrain modification actions in the plan") 34 | 35 | # Filter out entries with no terrain modifications 36 | modified_entries = [] 37 | for entry in plan: 38 | terrain_mask = np.array(entry['terrain_modification_mask']) 39 | if np.sum(terrain_mask) > 0: 40 | modified_entries.append(entry) 41 | 42 | print(f"{len(modified_entries)} of these actually modified terrain") 43 | 44 | if len(modified_entries) == 0: 45 | print("No terrain modification actions found in the plan.") 46 | return 47 | 48 | for i, entry in enumerate(modified_entries): 49 | # Check if we have terrain change values for enhanced visualization 50 | if 'terrain_change_values' in entry: 51 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6)) 52 | has_change_values = True 53 | else: 54 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) 55 | has_change_values = False 56 | 57 | # Get data 58 | terrain_mask = np.array(entry['terrain_modification_mask']) 59 | traversability_mask = np.array(entry['traversability_mask']) 60 | agent_state = entry['agent_state'] 61 | step = entry['step'] 62 | 63 | # Plot 1: Binary modification mask 64 | im1 = ax1.imshow(terrain_mask, cmap='Reds', interpolation='nearest') 65 | ax1.set_title(f'Terrain Modification Mask (Binary)\nStep {step} (Entry {i+1}/{len(modified_entries)})') 66 | ax1.set_xlabel('X coordinate') 67 | ax1.set_ylabel('Y coordinate') 68 | 69 | # Add agent position 70 | agent_y, agent_x = agent_state['pos_base'] # Note: Terra coordinates are (y, x) 71 | loaded_before = entry['loaded_state_change']['before'] 72 | loaded_after = entry['loaded_state_change']['after'] 73 | ax1.plot(agent_x, agent_y, 'bo', markersize=8, label=f'Agent (before: {loaded_before}, after: {loaded_after})') 74 | ax1.legend() 75 | 76 | # Add 
colorbar 77 | plt.colorbar(im1, ax=ax1, label='Modified (1) / Unmodified (0)') 78 | 79 | # Plot 2: Traversability mask for context 80 | im2 = ax2.imshow(traversability_mask, cmap='viridis', interpolation='nearest') 81 | ax2.set_title(f'Traversability Mask\nStep {step}') 82 | ax2.set_xlabel('X coordinate') 83 | ax2.set_ylabel('Y coordinate') 84 | 85 | # Add agent position on traversability map too 86 | ax2.plot(agent_x, agent_y, 'ro', markersize=8, label=f'Agent (before: {loaded_before}, after: {loaded_after})') 87 | ax2.legend() 88 | 89 | # Add colorbar 90 | plt.colorbar(im2, ax=ax2, label='Traversable (1) / Non-traversable (0)') 91 | 92 | # Plot 3: Actual change values (if available) 93 | if has_change_values: 94 | terrain_changes = np.array(entry['terrain_change_values']) 95 | # Use a diverging colormap to show positive and negative changes 96 | max_abs_change = np.max(np.abs(terrain_changes)) if np.any(terrain_changes != 0) else 1 97 | im3 = ax3.imshow(terrain_changes, cmap='RdBu_r', interpolation='nearest', 98 | vmin=-max_abs_change, vmax=max_abs_change) 99 | ax3.set_title(f'Terrain Change Values\nStep {step}') 100 | ax3.set_xlabel('X coordinate') 101 | ax3.set_ylabel('Y coordinate') 102 | 103 | # Add agent position 104 | ax3.plot(agent_x, agent_y, 'ko', markersize=8, label=f'Agent (before: {loaded_before}, after: {loaded_after})') 105 | ax3.legend() 106 | 107 | # Add colorbar 108 | plt.colorbar(im3, ax=ax3, label='Change in terrain value') 109 | 110 | # Add info text - determine action type based on loaded state change 111 | loaded_change = entry['loaded_state_change'] 112 | if not loaded_change['before'] and loaded_change['after']: 113 | action_type = "Digging" 114 | elif loaded_change['before'] and not loaded_change['after']: 115 | action_type = "Dumping" 116 | else: 117 | action_type = "Unknown terrain modification" 118 | 119 | num_modified = int(np.sum(terrain_mask)) 120 | 121 | # Add statistics about changes if available 122 | if has_change_values: 123 | terrain_changes = np.array(entry['terrain_change_values']) 124 | changed_values = terrain_changes[terrain_changes != 0] 125 | if len(changed_values) > 0: 126 | stats_text = f' | Min: {np.min(changed_values):.3f}, Max: {np.max(changed_values):.3f}, Mean: {np.mean(np.abs(changed_values)):.3f}' 127 | else: 128 | stats_text = '' 129 | else: 130 | stats_text = '' 131 | 132 | fig.suptitle(f'{action_type} - {num_modified} tiles modified{stats_text}', fontsize=14, fontweight='bold') 133 | 134 | plt.tight_layout() 135 | 136 | # Save plot if output directory specified 137 | if output_dir: 138 | filename = f'digging_step_{step:04d}.png' 139 | filepath = output_dir / filename 140 | plt.savefig(filepath, dpi=150, bbox_inches='tight') 141 | print(f"Saved plot to {filepath}") 142 | 143 | # Show plot if requested 144 | if show_plots: 145 | plt.show() 146 | else: 147 | plt.close() 148 | 149 | 150 | def plot_all_modifications_overlay(plan, output_dir=None, show_plots=True): 151 | """ 152 | Create an overlay plot showing all digging locations across all steps. 
153 | 154 | Args: 155 | plan: List of plan entries from extract_plan.py (only digging actions) 156 | output_dir: Directory to save plots (optional) 157 | show_plots: Whether to display plots interactively 158 | """ 159 | if len(plan) == 0: 160 | print("No plan entries found.") 161 | return 162 | 163 | # Get the shape from the first entry 164 | first_entry = plan[0] 165 | terrain_shape = np.array(first_entry['terrain_modification_mask']).shape 166 | 167 | # Accumulate all modifications 168 | all_modifications = np.zeros(terrain_shape) 169 | modification_count = np.zeros(terrain_shape) 170 | 171 | for entry in plan: 172 | terrain_mask = np.array(entry['terrain_modification_mask']) 173 | all_modifications = np.logical_or(all_modifications, terrain_mask > 0) 174 | modification_count += terrain_mask 175 | 176 | # Create the overlay plot 177 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) 178 | 179 | # Plot binary overlay (any modification) 180 | im1 = ax1.imshow(all_modifications, cmap='Reds', interpolation='nearest') 181 | ax1.set_title('All Digging Locations\n(Binary Overlay)') 182 | ax1.set_xlabel('X coordinate') 183 | ax1.set_ylabel('Y coordinate') 184 | plt.colorbar(im1, ax=ax1, label='Dug (1) / Undug (0)') 185 | 186 | # Plot modification count 187 | im2 = ax2.imshow(modification_count, cmap='hot', interpolation='nearest') 188 | ax2.set_title('Digging Frequency\n(Number of times dug)') 189 | ax2.set_xlabel('X coordinate') 190 | ax2.set_ylabel('Y coordinate') 191 | plt.colorbar(im2, ax=ax2, label='Number of digging actions') 192 | 193 | total_modified_tiles = int(np.sum(all_modifications)) 194 | max_modifications = int(np.max(modification_count)) 195 | fig.suptitle(f'Total: {total_modified_tiles} tiles dug, Max digging actions per tile: {max_modifications}', 196 | fontsize=14, fontweight='bold') 197 | 198 | plt.tight_layout() 199 | 200 | # Save plot if output directory specified 201 | if output_dir: 202 | output_dir = Path(output_dir) 203 | output_dir.mkdir(exist_ok=True) 204 | filename = 'all_digging_locations_overlay.png' 205 | filepath = output_dir / filename 206 | plt.savefig(filepath, dpi=150, bbox_inches='tight') 207 | print(f"Saved overlay plot to {filepath}") 208 | 209 | # Show plot if requested 210 | if show_plots: 211 | plt.show() 212 | else: 213 | plt.close() 214 | 215 | 216 | def main(): 217 | parser = argparse.ArgumentParser(description="Visualize digging actions from extracted plan") 218 | parser.add_argument( 219 | "plan_path", 220 | type=str, 221 | help="Path to the plan .pkl file generated by extract_plan.py" 222 | ) 223 | parser.add_argument( 224 | "-o", "--output_dir", 225 | type=str, 226 | help="Directory to save plots (optional)" 227 | ) 228 | parser.add_argument( 229 | "--no-show", 230 | action="store_true", 231 | help="Don't display plots interactively (useful when saving only)" 232 | ) 233 | parser.add_argument( 234 | "--overlay-only", 235 | action="store_true", 236 | help="Only generate the overlay plot, not individual step plots" 237 | ) 238 | 239 | args = parser.parse_args() 240 | 241 | # Load the plan 242 | print(f"Loading plan from {args.plan_path}") 243 | plan = load_plan(args.plan_path) 244 | 245 | show_plots = not args.no_show 246 | 247 | if not args.overlay_only: 248 | # Plot individual digging actions 249 | print("Generating individual digging action plots...") 250 | plot_terrain_modifications(plan, args.output_dir, show_plots) 251 | 252 | # Plot overlay of all digging locations 253 | print("Generating digging overlay plot...") 254 | 
plot_all_modifications_overlay(plan, args.output_dir, show_plots) 255 | 256 | print("Visualization complete!") 257 | 258 | 259 | if __name__ == "__main__": 260 | main() 261 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import math 4 | from utils.models import load_neural_network 5 | from utils.helpers import load_pkl_object 6 | from terra.env import TerraEnvBatch 7 | from terra.actions import ( 8 | WheeledAction, 9 | TrackedAction, 10 | WheeledActionType, 11 | TrackedActionType, 12 | ) 13 | import jax.numpy as jnp 14 | from utils.utils_ppo import obs_to_model_input, wrap_action 15 | 16 | # from utils.curriculum import Curriculum 17 | from tensorflow_probability.substrates import jax as tfp 18 | from train import TrainConfig # needed for unpickling checkpoints 19 | 20 | 21 | def _append_to_obs(o, obs_log): 22 | if obs_log == {}: 23 | return {k: v[:, None] for k, v in o.items()} 24 | obs_log = { 25 | k: jnp.concatenate((v, o[k][:, None]), axis=1) for k, v in obs_log.items() 26 | } 27 | return obs_log 28 | 29 | 30 | def rollout_episode( 31 | env: TerraEnvBatch, 32 | model, 33 | model_params, 34 | env_cfgs, 35 | rl_config, 36 | max_frames, 37 | deterministic, 38 | seed, 39 | ): 40 | """ 41 | NOTE: this function assumes it's a tracked agent in the way it computes the stats. 42 | """ 43 | print(f"Using {seed=}") 44 | rng = jax.random.PRNGKey(seed) 45 | rng, _rng = jax.random.split(rng) 46 | rng_reset = jax.random.split(_rng, rl_config.num_test_rollouts) 47 | timestep = env.reset(env_cfgs, rng_reset) 48 | prev_actions = jnp.zeros( 49 | (rl_config.num_test_rollouts, rl_config.num_prev_actions), 50 | dtype=jnp.int32 51 | ) 52 | 53 | tile_size = env_cfgs.tile_size[0].item() 54 | move_tiles = env_cfgs.agent.move_tiles[0].item() 55 | 56 | action_type = env.batch_cfg.action_type 57 | if action_type == TrackedAction: 58 | move_actions = (TrackedActionType.FORWARD, TrackedActionType.BACKWARD) 59 | l_actions = () 60 | do_action = TrackedActionType.DO 61 | elif action_type == WheeledAction: 62 | move_actions = (WheeledActionType.FORWARD, WheeledActionType.BACKWARD) 63 | l_actions = (WheeledActionType.WHEELS_LEFT, WheeledActionType.WHEELS_RIGHT) 64 | do_action = WheeledActionType.DO 65 | else: 66 | raise (ValueError(f"{action_type=}")) 67 | 68 | obs = timestep.observation 69 | areas = (obs["target_map"] == -1).sum( 70 | tuple([i for i in range(len(obs["target_map"].shape))][1:]) 71 | ) * (tile_size**2) 72 | target_maps_init = obs["target_map"].copy() 73 | dig_tiles_per_target_map_init = (target_maps_init == -1).sum( 74 | tuple([i for i in range(len(target_maps_init.shape))][1:]) 75 | ) 76 | 77 | t_counter = 0 78 | reward_seq = [] 79 | episode_done_once = None 80 | episode_length = None 81 | move_cumsum = None 82 | do_cumsum = None 83 | obs_seq = {} 84 | while True: 85 | obs_seq = _append_to_obs(obs, obs_seq) 86 | rng, rng_act, rng_step = jax.random.split(rng, 3) 87 | if model is not None: 88 | obs_model = obs_to_model_input(timestep.observation, prev_actions, rl_config) 89 | v, logits_pi = model.apply(model_params, obs_model) 90 | if deterministic: 91 | action = np.argmax(logits_pi, axis=-1) 92 | else: 93 | pi = tfp.distributions.Categorical(logits=logits_pi) 94 | action = pi.sample(seed=rng_act) 95 | prev_actions = jnp.roll(prev_actions, shift=1, axis=1) 96 | prev_actions = prev_actions.at[:, 0].set(action) 97 | else: 98 | raise 
RuntimeError("Model is None!") 99 | rng_step = jax.random.split(rng_step, rl_config.num_test_rollouts) 100 | timestep = env.step( 101 | timestep, wrap_action(action, env.batch_cfg.action_type), rng_step 102 | ) 103 | reward = timestep.reward 104 | next_obs = timestep.observation 105 | done = timestep.info["task_done"] 106 | 107 | reward_seq.append(reward) 108 | print(t_counter) 109 | print(10 * "=") 110 | t_counter += 1 111 | if jnp.all(done).item() or t_counter == max_frames: 112 | break 113 | obs = next_obs 114 | 115 | # Log stats 116 | if episode_done_once is None: 117 | episode_done_once = done 118 | if episode_length is None: 119 | episode_length = jnp.zeros_like(done, dtype=jnp.int32) 120 | if move_cumsum is None: 121 | move_cumsum = jnp.zeros_like(done, dtype=jnp.int32) 122 | if do_cumsum is None: 123 | do_cumsum = jnp.zeros_like(done, dtype=jnp.int32) 124 | 125 | episode_done_once = episode_done_once | done 126 | 127 | episode_length += ~episode_done_once 128 | 129 | move_cumsum_tmp = jnp.zeros_like(done, dtype=jnp.int32) 130 | for move_action in move_actions: 131 | move_mask = (action == move_action) * (~episode_done_once) 132 | move_cumsum_tmp += move_tiles * tile_size * move_mask.astype(jnp.int32) 133 | for l_action in l_actions: 134 | l_mask = (action == l_action) * (~episode_done_once) 135 | move_cumsum_tmp += 2 * move_tiles * tile_size * l_mask.astype(jnp.int32) 136 | move_cumsum += move_cumsum_tmp 137 | 138 | do_cumsum += (action == do_action) * (~episode_done_once) 139 | 140 | # Path efficiency -- only include finished envs 141 | move_cumsum *= episode_done_once 142 | path_efficiency = (move_cumsum / jnp.sqrt(areas))[episode_done_once] 143 | path_efficiency_std = path_efficiency.std() 144 | path_efficiency_mean = path_efficiency.mean() 145 | 146 | # Workspaces efficiency -- only include finished envs 147 | reference_workspace_area = 0.5 * np.pi * (8**2) 148 | n_dig_actions = do_cumsum // 2 149 | workspaces_efficiency = ( 150 | reference_workspace_area 151 | * ((n_dig_actions * episode_done_once) / areas)[episode_done_once] 152 | ) 153 | workspaces_efficiency_mean = workspaces_efficiency.mean() 154 | workspaces_efficiency_std = workspaces_efficiency.std() 155 | 156 | # Coverage scores 157 | dug_tiles_per_action_map = (obs["action_map"] == -1).sum( 158 | tuple([i for i in range(len(obs["action_map"].shape))][1:]) 159 | ) 160 | coverage_ratios = dug_tiles_per_action_map / dig_tiles_per_target_map_init 161 | coverage_scores = episode_done_once + (~episode_done_once) * coverage_ratios 162 | coverage_score_mean = coverage_scores.mean() 163 | coverage_score_std = coverage_scores.std() 164 | 165 | stats = { 166 | "episode_done_once": episode_done_once, 167 | "episode_length": episode_length, 168 | "path_efficiency": { 169 | "mean": path_efficiency_mean, 170 | "std": path_efficiency_std, 171 | }, 172 | "workspaces_efficiency": { 173 | "mean": workspaces_efficiency_mean, 174 | "std": workspaces_efficiency_std, 175 | }, 176 | "coverage": { 177 | "mean": coverage_score_mean, 178 | "std": coverage_score_std, 179 | }, 180 | } 181 | return np.cumsum(reward_seq), stats, obs_seq 182 | 183 | 184 | def print_stats( 185 | stats, 186 | ): 187 | episode_done_once = stats["episode_done_once"] 188 | episode_length = stats["episode_length"] 189 | path_efficiency = stats["path_efficiency"] 190 | workspaces_efficiency = stats["workspaces_efficiency"] 191 | coverage = stats["coverage"] 192 | 193 | completion_rate = 100 * episode_done_once.sum() / len(episode_done_once) 194 | 195 | print("\nStats:\n") 
196 | print(f"Completion: {completion_rate:.2f}%") 197 | # print(f"First episode length average: {episode_length.mean()}") 198 | # print(f"First episode length min: {episode_length.min()}") 199 | # print(f"First episode length max: {episode_length.max()}") 200 | print( 201 | f"Path efficiency: {path_efficiency['mean']:.2f} ({path_efficiency['std']:.2f})" 202 | ) 203 | print( 204 | f"Workspaces efficiency: {workspaces_efficiency['mean']:.2f} ({workspaces_efficiency['std']:.2f})" 205 | ) 206 | print(f"Coverage: {coverage['mean']:.2f} ({coverage['std']:.2f})") 207 | 208 | 209 | if __name__ == "__main__": 210 | import argparse 211 | 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument( 214 | "-run", 215 | "--run_name", 216 | type=str, 217 | default="ppo_2023_05_09_10_01_23", 218 | help="es/ppo trained agent.", 219 | ) 220 | parser.add_argument( 221 | "-env", 222 | "--env_name", 223 | type=str, 224 | default="Terra", 225 | help="Environment name.", 226 | ) 227 | parser.add_argument( 228 | "-n", 229 | "--n_envs", 230 | type=int, 231 | default=1, 232 | help="Number of environments.", 233 | ) 234 | parser.add_argument( 235 | "-steps", 236 | "--n_steps", 237 | type=int, 238 | default=10, 239 | help="Number of steps.", 240 | ) 241 | parser.add_argument( 242 | "-d", 243 | "--deterministic", 244 | type=int, 245 | default=0, 246 | help="Deterministic. 0 for stochastic, 1 for deterministic.", 247 | ) 248 | parser.add_argument( 249 | "-s", 250 | "--seed", 251 | type=int, 252 | default=0, 253 | help="Random seed for the environment.", 254 | ) 255 | args, _ = parser.parse_known_args() 256 | n_envs = args.n_envs 257 | 258 | log = load_pkl_object(f"{args.run_name}") 259 | config = log["train_config"] 260 | # from utils.helpers import load_config 261 | # config = load_config("agents/Terra/ppo.yaml", 22333, 33222, 5e-04, True, "")["train_config"] 262 | 263 | config.num_test_rollouts = n_envs 264 | config.num_devices = 1 265 | 266 | # curriculum = Curriculum(rl_config=config, n_devices=n_devices) 267 | # env_cfgs, dofs_count_dict = curriculum.get_cfgs_eval() 268 | env_cfgs = log["env_config"] 269 | env_cfgs = jax.tree_map( 270 | lambda x: x[0][None, ...].repeat(n_envs, 0), env_cfgs 271 | ) # take first config and replicate 272 | shuffle_maps = True 273 | env = TerraEnvBatch(rendering=False, shuffle_maps=shuffle_maps) 274 | config.num_embeddings_agent_min = 60 275 | 276 | model = load_neural_network(config, env) 277 | model_params = log["model"] 278 | # model_params = jax.tree_map(lambda x: x[0], replicated_params) 279 | deterministic = bool(args.deterministic) 280 | print(f"\nDeterministic = {deterministic}\n") 281 | 282 | cum_rewards, stats, _ = rollout_episode( 283 | env, 284 | model, 285 | model_params, 286 | env_cfgs, 287 | config, 288 | max_frames=args.n_steps, 289 | deterministic=deterministic, 290 | seed=args.seed, 291 | ) 292 | 293 | print_stats(stats) 294 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import Array 4 | import flax.linen as nn 5 | from typing import Sequence, Union 6 | from terra.actions import TrackedAction, WheeledAction 7 | from terra.env import TerraEnvBatch 8 | from functools import partial 9 | 10 | 11 | def get_model_ready(rng, config, env: TerraEnvBatch, speed=False): 12 | """Instantiate a model according to obs shape of environment.""" 13 | num_embeddings_agent = jnp.max( 14 | 
jnp.array( 15 | [ 16 | env.batch_cfg.maps_dims.maps_edge_length, 17 | env.batch_cfg.agent.angles_cabin, 18 | env.batch_cfg.agent.angles_base, 19 | ], 20 | dtype=jnp.int16, 21 | ) 22 | ).item() 23 | jax.debug.print("num_embeddings_agent = {x}", x=num_embeddings_agent) 24 | map_min_max = ( 25 | tuple(config["maps_net_normalization_bounds"]) 26 | if not config["clip_action_maps"] 27 | else (-1, 1) 28 | ) 29 | jax.debug.print("map normalization min max = {x}", x=map_min_max) 30 | model = SimplifiedCoupledCategoricalNet( 31 | num_prev_actions=config["num_prev_actions"], 32 | num_embeddings_agent=num_embeddings_agent, 33 | map_min_max=map_min_max, 34 | local_map_min_max=tuple(config["local_map_normalization_bounds"]), 35 | loaded_max=config["loaded_max"], 36 | action_type=env.batch_cfg.action_type, 37 | ) 38 | 39 | map_width = env.batch_cfg.maps_dims.maps_edge_length 40 | map_height = env.batch_cfg.maps_dims.maps_edge_length 41 | 42 | obs = [ 43 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.num_state_obs)), 44 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 45 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 46 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 47 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 48 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 49 | jnp.zeros((config["num_envs"], env.batch_cfg.agent.angles_cabin)), 50 | jnp.zeros((config["num_envs"], map_width, map_height)), 51 | jnp.zeros((config["num_envs"], map_width, map_height)), 52 | jnp.zeros((config["num_envs"], map_width, map_height)), 53 | jnp.zeros((config["num_envs"], config["num_prev_actions"])), 54 | ] 55 | params = model.init(rng, obs) 56 | 57 | print(f"Model: {sum(x.size for x in jax.tree_leaves(params)):,} parameters") 58 | return model, params 59 | 60 | 61 | def load_neural_network(config, env): 62 | """Load neural network model based on config and environment.""" 63 | rng = jax.random.PRNGKey(0) 64 | model, _ = get_model_ready(rng, config, env) 65 | return model 66 | 67 | 68 | def normalize(x: Array, x_min: Array, x_max: Array) -> Array: 69 | """ 70 | Normalizes to [-1, 1] 71 | """ 72 | return 2.0 * (x - x_min) / (x_max - x_min) - 1.0 73 | 74 | 75 | class MLP(nn.Module): 76 | """ 77 | MLP without activation function at the last layer. 
78 | """ 79 | 80 | hidden_dim_layers: Sequence[int] 81 | use_layer_norm: bool 82 | last_layer_init_scaling: float = 1.0 83 | 84 | def setup(self) -> None: 85 | layer_init = nn.initializers.lecun_normal 86 | last_layer_init = lambda a, b, c: self.last_layer_init_scaling * layer_init()( 87 | a, b, c 88 | ) 89 | self.activation = nn.relu 90 | 91 | if self.use_layer_norm: 92 | self.layers = [ 93 | nn.Sequential( 94 | [ 95 | nn.Dense(self.hidden_dim_layers[i], kernel_init=layer_init()), 96 | nn.LayerNorm(), 97 | ] 98 | ) 99 | for i in range(len(self.hidden_dim_layers) - 1) 100 | ] 101 | self.layers += ( 102 | nn.Dense(self.hidden_dim_layers[-1], kernel_init=last_layer_init), 103 | ) 104 | else: 105 | self.layers = [] 106 | for i, f in enumerate(self.hidden_dim_layers): 107 | if i < len(self.hidden_dim_layers) - 1: 108 | self.layers += (nn.Dense(f, kernel_init=layer_init()),) 109 | else: 110 | self.layers += (nn.Dense(f, kernel_init=last_layer_init),) 111 | 112 | def __call__(self, x): 113 | if self.use_layer_norm: 114 | for i, layer in enumerate(self.layers): 115 | x = layer(x) 116 | if ~(i % 2) and i != len(self.layers) - 1: 117 | x = self.activation(x) 118 | else: 119 | for i, layer in enumerate(self.layers): 120 | x = layer(x) 121 | if i != len(self.layers) - 1: 122 | x = self.activation(x) 123 | return x 124 | 125 | 126 | class AgentStateNet(nn.Module): 127 | """ 128 | Pre-process the agent state features. 129 | """ 130 | 131 | num_embeddings: int 132 | loaded_max: int 133 | mlp_use_layernorm: bool 134 | num_embedding_features: int = 8 135 | hidden_dim_layers_mlp_one_hot: Sequence[int] = (16, 32) 136 | hidden_dim_layers_mlp_continuous: Sequence[int] = (16, 32) 137 | 138 | def setup(self) -> None: 139 | self.embedding = nn.Embed( 140 | num_embeddings=self.num_embeddings, features=self.num_embedding_features 141 | ) 142 | self.mlp_one_hot = MLP( 143 | hidden_dim_layers=self.hidden_dim_layers_mlp_one_hot, 144 | use_layer_norm=self.mlp_use_layernorm, 145 | ) 146 | self.mlp_continuous = MLP( 147 | hidden_dim_layers=self.hidden_dim_layers_mlp_continuous, 148 | use_layer_norm=self.mlp_use_layernorm, 149 | ) 150 | 151 | def __call__(self, obs: dict[str, Array]): 152 | x_one_hot = obs[0][..., :-1].astype(dtype=jnp.int32) 153 | x_loaded = obs[0][..., [-1]].astype(dtype=jnp.int32) 154 | 155 | x_one_hot = self.embedding(x_one_hot) 156 | x_one_hot = self.mlp_one_hot(x_one_hot.reshape(*x_one_hot.shape[:-2], -1)) 157 | 158 | x_loaded = normalize(x_loaded, 0, self.loaded_max) 159 | x_continuous = self.mlp_continuous(x_loaded) 160 | 161 | return jnp.concatenate([x_one_hot, x_continuous], axis=-1) 162 | 163 | 164 | class LocalMapNet(nn.Module): 165 | """ 166 | Pre-process one or multiple maps. 
167 | """ 168 | 169 | map_min_max: Sequence[int] 170 | mlp_use_layernorm: bool 171 | hidden_dim_layers_mlp: Sequence[int] = (256, 32) 172 | 173 | def setup(self) -> None: 174 | self.mlp = MLP( 175 | hidden_dim_layers=self.hidden_dim_layers_mlp, 176 | use_layer_norm=self.mlp_use_layernorm, 177 | ) 178 | 179 | def __call__(self, obs: dict[str, Array]): 180 | """ 181 | obs["agent_state"], 182 | obs["local_map_action_neg"], 183 | obs["local_map_action_pos"], 184 | obs["local_map_target_neg"], 185 | obs["local_map_target_pos"], 186 | obs["local_map_dumpability"], 187 | obs["local_map_obstacles"], 188 | obs["action_map"], 189 | obs["target_map"], 190 | obs["traversability_mask"], 191 | obs["dumpability_mask"], 192 | """ 193 | x_action_neg = normalize(obs[1], self.map_min_max[0], self.map_min_max[1]) 194 | x_action_pos = normalize(obs[2], self.map_min_max[0], self.map_min_max[1]) 195 | x_target_neg = normalize(obs[3], self.map_min_max[0], self.map_min_max[1]) 196 | x_target_pos = normalize(obs[4], self.map_min_max[0], self.map_min_max[1]) 197 | x_dumpability = obs[5] 198 | x_obstacles = obs[6] 199 | x = jnp.concatenate( 200 | ( 201 | x_action_neg[..., None], 202 | x_action_pos[..., None], 203 | x_target_neg[..., None], 204 | x_target_pos[..., None], 205 | x_dumpability[..., None], 206 | x_obstacles[..., None], 207 | ), 208 | -1, 209 | ) 210 | 211 | x = self.mlp(x.reshape(*x.shape[:-2], -1)) 212 | return x 213 | 214 | 215 | class AtariCNN(nn.Module): 216 | """From https://github.com/deepmind/dqn_zoo/blob/master/dqn_zoo/networks.py""" 217 | 218 | @nn.compact 219 | def __call__(self, x): 220 | x = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(x) 221 | x = nn.relu(x) 222 | x = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(x) 223 | x = nn.relu(x) 224 | x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1))(x) 225 | x = nn.relu(x) 226 | x = x.reshape((x.shape[0], -1)) 227 | 228 | x = nn.Dense(features=128)(x) 229 | x = nn.relu(x) 230 | x = nn.Dense(features=32)(x) 231 | return x 232 | 233 | 234 | @jax.jit 235 | def min_pool(x): 236 | pool_fn = partial( 237 | nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)) 238 | ) 239 | return -pool_fn(-x) 240 | 241 | 242 | @jax.jit 243 | def max_pool(x): 244 | pool_fn = partial( 245 | nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)) 246 | ) 247 | return pool_fn(x) 248 | 249 | 250 | @jax.jit 251 | def zero_pool(x): 252 | """ 253 | Given an input x with neg to pos values, 254 | zero_pool pools zeros with priority, then neg, then pos values. 255 | """ 256 | x_pool = min_pool(x) 257 | mask_pool = max_pool(x == 0) 258 | 259 | return jnp.where( 260 | mask_pool, 261 | 0, 262 | x_pool, 263 | ) 264 | 265 | 266 | class MapsNet(nn.Module): 267 | """ 268 | Pre-process one or multiple maps. 
269 | """ 270 | 271 | map_min_max: Sequence[int] 272 | 273 | def setup(self) -> None: 274 | self.cnn = AtariCNN() 275 | 276 | def __call__(self, obs: dict[str, Array]): 277 | """ 278 | obs["agent_state"], 279 | obs["local_map_action_neg"], 280 | obs["local_map_action_pos"], 281 | obs["local_map_target_neg"], 282 | obs["local_map_target_pos"], 283 | obs["local_map_dumpability"], 284 | obs["local_map_obstacles"], 285 | obs["action_map"], 286 | obs["target_map"], 287 | obs["traversability_mask"], 288 | obs["dumpability_mask"], 289 | """ 290 | target_map = obs[7] 291 | traversability_map = obs[8] 292 | dumpability_mask = obs[9] 293 | 294 | x = jnp.concatenate( 295 | ( 296 | traversability_map[..., None], 297 | target_map[..., None], 298 | dumpability_mask[..., None], 299 | ), 300 | axis=-1, 301 | ) 302 | 303 | x = self.cnn(x) 304 | return x 305 | 306 | 307 | class PreviousActionsNet(nn.Module): 308 | """ 309 | Pre-processes the sequence of previous actions. 310 | """ 311 | num_actions: int 312 | mlp_use_layernorm: bool 313 | num_embedding_features: int = 8 314 | hidden_dim_layers_mlp: Sequence[int] = (16, 32) 315 | 316 | def setup(self) -> None: 317 | self.embedding = nn.Embed( 318 | num_embeddings=self.num_actions, 319 | features=self.num_embedding_features 320 | ) 321 | 322 | self.mlp = MLP( 323 | hidden_dim_layers=self.hidden_dim_layers_mlp, 324 | use_layer_norm=self.mlp_use_layernorm, 325 | ) 326 | 327 | self.activation = nn.relu 328 | 329 | def __call__(self, obs: dict[str, Array]): 330 | x_actions = obs[-1].astype(jnp.int32) 331 | x_actions = self.embedding(x_actions) 332 | 333 | x_flattened = x_actions.reshape(*x_actions.shape[:-2], -1) 334 | x_flattened = self.mlp(x_flattened) 335 | 336 | x = self.activation(x_flattened) 337 | return x 338 | 339 | 340 | class SimplifiedCoupledCategoricalNet(nn.Module): 341 | """ 342 | The full net. 
343 | 344 | The obs List follows the following order: 345 | obs["agent_state"], 346 | obs["local_map_action_neg"], 347 | obs["local_map_action_pos"], 348 | obs["local_map_target_neg"], 349 | obs["local_map_target_pos"], 350 | obs["local_map_dumpability"], 351 | obs["local_map_obstacles"], 352 | obs["action_map"], 353 | obs["target_map"], 354 | obs["traversability_mask"], 355 | obs["dumpability_mask"], 356 | """ 357 | 358 | num_prev_actions: int 359 | num_embeddings_agent: int 360 | map_min_max: Sequence[int] 361 | local_map_min_max: Sequence[int] 362 | loaded_max: int 363 | action_type: Union[TrackedAction, WheeledAction] 364 | hidden_dim_pi: Sequence[int] = (128, 32) 365 | hidden_dim_v: Sequence[int] = (128, 32, 1) 366 | mlp_use_layernorm: bool = False 367 | 368 | def setup(self) -> None: 369 | num_actions = self.action_type.get_num_actions() 370 | 371 | self.mlp_v = MLP( 372 | hidden_dim_layers=self.hidden_dim_v, 373 | use_layer_norm=self.mlp_use_layernorm, 374 | last_layer_init_scaling=0.01, 375 | ) 376 | self.mlp_pi = MLP( 377 | hidden_dim_layers=self.hidden_dim_pi + (num_actions,), 378 | use_layer_norm=self.mlp_use_layernorm, 379 | last_layer_init_scaling=0.01, 380 | ) 381 | 382 | self.local_map_net = LocalMapNet( 383 | map_min_max=self.local_map_min_max, mlp_use_layernorm=self.mlp_use_layernorm 384 | ) 385 | 386 | self.agent_state_net = AgentStateNet( 387 | num_embeddings=self.num_embeddings_agent, 388 | loaded_max=self.loaded_max, 389 | mlp_use_layernorm=self.mlp_use_layernorm, 390 | ) 391 | 392 | self.maps_net = MapsNet(self.map_min_max) 393 | 394 | self.actions_net = PreviousActionsNet( 395 | num_actions=num_actions, 396 | mlp_use_layernorm=self.mlp_use_layernorm, 397 | ) 398 | 399 | self.activation = nn.relu 400 | 401 | def __call__(self, obs: Array) -> Array: 402 | x_agent_state = self.agent_state_net(obs) 403 | 404 | x_maps = self.maps_net(obs) 405 | 406 | x_local_map = self.local_map_net(obs) 407 | 408 | x_actions = self.actions_net(obs) 409 | 410 | x = jnp.concatenate((x_agent_state, x_maps, x_local_map, x_actions), axis=-1) 411 | x = self.activation(x) 412 | 413 | v = self.mlp_v(x) 414 | xpi = self.mlp_pi(x) 415 | 416 | return v, xpi 417 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax.tree_util as jtu 4 | from utils.models import get_model_ready 5 | from terra.env import TerraEnvBatch 6 | from terra.config import EnvConfig 7 | from flax.training.train_state import TrainState 8 | import optax 9 | import wandb 10 | import eval_ppo 11 | from datetime import datetime 12 | from dataclasses import asdict, dataclass 13 | import time 14 | from tqdm import tqdm 15 | from functools import partial 16 | from flax.jax_utils import replicate, unreplicate 17 | from flax import struct 18 | import utils.helpers as helpers 19 | from utils.utils_ppo import select_action_ppo, wrap_action, obs_to_model_input, policy 20 | import os 21 | 22 | jax.config.update("jax_threefry_partitionable", True) 23 | 24 | 25 | @dataclass 26 | class TrainConfig: 27 | name: str 28 | num_devices: int = 0 29 | project: str = "main" 30 | group: str = "default" 31 | num_envs_per_device: int = 4096 32 | num_steps: int = 32 33 | update_epochs: int = 5 34 | num_minibatches: int = 32 35 | total_timesteps: int = 30_000_000_000 36 | lr: float = 3e-4 37 | clip_eps: float = 0.5 38 | gamma: float = 0.995 39 | gae_lambda: float = 0.95 40 | 
ent_coef: float = 0.005 41 | vf_coef: float = 5.0 42 | max_grad_norm: float = 0.5 43 | eval_episodes: int = 100 44 | seed: int = 42 45 | log_train_interval: int = 1 # Number of updates between logging train stats 46 | log_eval_interval: int = ( 47 | 50 # Number of updates between running eval and syncing with wandb 48 | ) 49 | checkpoint_interval: int = 50 # Number of updates between checkpoints 50 | # model settings 51 | num_prev_actions = 5 52 | clip_action_maps = True # clips the action maps to [-1, 1] 53 | local_map_normalization_bounds = [-16, 16] 54 | loaded_max = 100 55 | num_rollouts_eval = 500 # max length of an episode in Terra for eval (for training it is in Terra's curriculum) 56 | cache_clear_interval = 1000 # Number of updates between clearing caches 57 | 58 | def __post_init__(self): 59 | self.num_devices = ( 60 | jax.local_device_count() if self.num_devices == 0 else self.num_devices 61 | ) 62 | self.num_envs = self.num_envs_per_device * self.num_devices 63 | self.total_timesteps_per_device = self.total_timesteps // self.num_devices 64 | self.eval_episodes_per_device = self.eval_episodes // self.num_devices 65 | assert ( 66 | self.num_envs % self.num_devices == 0 67 | ), "Number of environments must be divisible by the number of devices." 68 | self.num_updates = ( 69 | self.total_timesteps // (self.num_steps * self.num_envs) 70 | ) // self.num_devices 71 | print(f"Num devices: {self.num_devices}, Num updates: {self.num_updates}") 72 | 73 | # make object subscriptable 74 | def __getitem__(self, key): 75 | return getattr(self, key) 76 | 77 | 78 | def make_states(config: TrainConfig, env_params: EnvConfig = EnvConfig()): 79 | env = TerraEnvBatch() 80 | num_devices = config.num_devices 81 | num_envs_per_device = config.num_envs_per_device 82 | 83 | env_params = jax.tree_map( 84 | lambda x: jnp.array(x)[None, None] 85 | .repeat(num_devices, 0) 86 | .repeat(num_envs_per_device, 1), 87 | env_params, 88 | ) 89 | print(f"{env_params.tile_size.shape=}") 90 | 91 | rng = jax.random.PRNGKey(config.seed) 92 | rng, _rng = jax.random.split(rng) 93 | 94 | network, network_params = get_model_ready(_rng, config, env) 95 | tx = optax.chain( 96 | optax.clip_by_global_norm(config.max_grad_norm), 97 | optax.adam(learning_rate=config.lr, eps=1e-5), 98 | ) 99 | train_state = TrainState.create( 100 | apply_fn=network.apply, params=network_params, tx=tx 101 | ) 102 | 103 | return rng, env, env_params, train_state 104 | 105 | 106 | class Transition(struct.PyTreeNode): 107 | done: jax.Array 108 | action: jax.Array 109 | value: jax.Array 110 | reward: jax.Array 111 | log_prob: jax.Array 112 | obs: jax.Array 113 | # for rnn policy 114 | prev_actions: jax.Array 115 | prev_reward: jax.Array 116 | 117 | 118 | def calculate_gae( 119 | transitions: Transition, 120 | last_val: jax.Array, 121 | gamma: float, 122 | gae_lambda: float, 123 | ) -> tuple[jax.Array, jax.Array]: 124 | # single iteration for the loop 125 | def _get_advantages(gae_and_next_value, transition): 126 | gae, next_value = gae_and_next_value 127 | delta = ( 128 | transition.reward 129 | + gamma * next_value * (1 - transition.done) 130 | - transition.value 131 | ) 132 | gae = delta + gamma * gae_lambda * (1 - transition.done) * gae 133 | return (gae, transition.value), gae 134 | 135 | _, advantages = jax.lax.scan( 136 | _get_advantages, 137 | (jnp.zeros_like(last_val), last_val), 138 | transitions, 139 | reverse=True, 140 | ) 141 | # advantages and values (Q) 142 | return advantages, advantages + transitions.value 143 | 144 | 145 | def 
ppo_update_networks( 146 | train_state: TrainState, 147 | transitions: Transition, 148 | advantages: jax.Array, 149 | targets: jax.Array, 150 | config, 151 | ): 152 | clip_eps = config.clip_eps 153 | vf_coef = config.vf_coef 154 | ent_coef = config.ent_coef 155 | 156 | # NORMALIZE ADVANTAGES 157 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) 158 | 159 | def _loss_fn(params): 160 | # Terra: Reshape 161 | # [minibatch_size, seq_len, ...] -> [minibatch_size * seq_len, ...] 162 | print(f"ppo_update_networks {transitions.obs['agent_state'].shape=}") 163 | print(f"ppo_update_networks {transitions.prev_actions.shape=}") 164 | transitions_obs_reshaped = jax.tree_map( 165 | lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1], *x.shape[2:])), 166 | transitions.obs, 167 | ) 168 | transitions_actions_reshaped = jax.tree_map( 169 | lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1], *x.shape[2:])), 170 | transitions.prev_actions, 171 | ) 172 | print(f"ppo_update_networks {transitions_obs_reshaped['agent_state'].shape=}") 173 | print(f"ppo_update_networks {transitions_actions_reshaped.shape=}") 174 | 175 | # NOTE: can't use select_action_ppo here because it doesn't decouple params from train_state 176 | obs = obs_to_model_input(transitions_obs_reshaped, transitions_actions_reshaped, config) 177 | value, dist = policy(train_state.apply_fn, params, obs) 178 | value = value[:, 0] 179 | # action = dist.sample(seed=rng_model) 180 | transitions_actions_reshaped = jnp.reshape( 181 | transitions.action, (-1, *transitions.action.shape[2:]) 182 | ) 183 | log_prob = dist.log_prob(transitions_actions_reshaped) 184 | 185 | # Terra: Reshape 186 | value = jnp.reshape(value, transitions.value.shape) 187 | log_prob = jnp.reshape(log_prob, transitions.log_prob.shape) 188 | 189 | # CALCULATE VALUE LOSS 190 | value_pred_clipped = transitions.value + (value - transitions.value).clip( 191 | -clip_eps, clip_eps 192 | ) 193 | value_loss = jnp.square(value - targets) 194 | value_loss_clipped = jnp.square(value_pred_clipped - targets) 195 | value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean() 196 | 197 | # CALCULATE ACTOR LOSS 198 | ratio = jnp.exp(log_prob - transitions.log_prob) 199 | actor_loss1 = advantages * ratio 200 | actor_loss2 = advantages * jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) 201 | actor_loss = -jnp.minimum(actor_loss1, actor_loss2).mean() 202 | entropy = dist.entropy().mean() 203 | 204 | total_loss = actor_loss + vf_coef * value_loss - ent_coef * entropy 205 | return total_loss, (value_loss, actor_loss, entropy) 206 | 207 | (loss, (vloss, aloss, entropy)), grads = jax.value_and_grad(_loss_fn, has_aux=True)( 208 | train_state.params 209 | ) 210 | (loss, vloss, aloss, entropy, grads) = jax.lax.pmean( 211 | (loss, vloss, aloss, entropy, grads), axis_name="devices" 212 | ) 213 | train_state = train_state.apply_gradients(grads=grads) 214 | update_info = { 215 | "total_loss": loss, 216 | "value_loss": vloss, 217 | "actor_loss": aloss, 218 | "entropy": entropy, 219 | } 220 | return train_state, update_info 221 | 222 | 223 | def get_curriculum_levels(env_cfg, global_curriculum_levels): 224 | curriculum_stat = {} 225 | curriculum_levels = env_cfg.curriculum.level 226 | for i, global_curriculum_level in enumerate(global_curriculum_levels): 227 | curriculum_stat[f'Level {i}: {global_curriculum_level["maps_path"]}'] = jnp.sum( 228 | curriculum_levels == i 229 | ).item() 230 | return curriculum_stat 231 | 232 | 233 | def make_train( 234 | env: TerraEnvBatch, 235 | env_params: 
EnvConfig, 236 | config: TrainConfig, 237 | ): 238 | def train( 239 | rng: jax.Array, 240 | train_state: TrainState, 241 | ): 242 | # INIT ENV 243 | rng, _rng = jax.random.split(rng) 244 | reset_rng = jax.random.split( 245 | _rng, config.num_envs_per_device * config.num_devices 246 | ) 247 | reset_rng = reset_rng.reshape( 248 | (config.num_devices, config.num_envs_per_device, -1) 249 | ) 250 | 251 | # TERRA: Reset envs 252 | reset_fn_p = jax.pmap(env.reset, axis_name="devices") # vmapped inside 253 | timestep = reset_fn_p(env_params, reset_rng) 254 | prev_actions = jnp.zeros( 255 | (config.num_devices, config.num_envs_per_device, config.num_prev_actions), dtype=jnp.int32 256 | ) 257 | prev_reward = jnp.zeros((config.num_devices, config.num_envs_per_device)) 258 | 259 | # TRAIN LOOP 260 | @partial(jax.pmap, axis_name="devices") 261 | def _update_step(runner_state, _): 262 | """ 263 | Performs a single update step in the training loop. 264 | 265 | This function orchestrates the collection of trajectories from the environment, 266 | calculation of advantages, and updating of the network parameters based on the 267 | collected data. It involves stepping through the environment to collect data, 268 | calculating the advantage estimates for each step, and performing several epochs 269 | of updates on the network parameters using the collected data. 270 | 271 | Parameters: 272 | - runner_state: A tuple containing the current state of the RNG, the training state, 273 | the previous timestep, the previous action, and the previous reward. 274 | - _: Placeholder to match the expected input signature for jax.lax.scan. 275 | 276 | Returns: 277 | - runner_state: Updated runner state after performing the update step. 278 | - loss_info: A dictionary containing information about the loss and other 279 | metrics for this update step. 280 | """ 281 | 282 | # COLLECT TRAJECTORIES 283 | def _env_step(runner_state, _): 284 | """ 285 | Executes a step in the environment for all agents. 286 | 287 | This function takes the current state of the runners (agents), selects an 288 | action for each agent based on the current observation using the PPO 289 | algorithm, and then steps the environment forward using these actions. 290 | The environment returns the next state, reward, and whether the episode 291 | has ended for each agent. These are then used to create a transition tuple 292 | containing the current state, action, reward, and next state, which can 293 | be used for training the model. 294 | 295 | Parameters: 296 | - runner_state: Tuple containing the current rng state, train_state, 297 | previous timestep, previous action, and previous reward. 298 | - _: Placeholder to match the expected input signature for jax.lax.scan. 299 | 300 | Returns: 301 | - runner_state: Updated runner state after stepping the environment. 302 | - transition: A namedtuple containing the transition information 303 | (current state, action, reward, next state) for this step. 
304 | """ 305 | rng, train_state, prev_timestep, prev_actions, prev_reward = runner_state 306 | 307 | # SELECT ACTION 308 | rng, _rng_model, _rng_env = jax.random.split(rng, 3) 309 | action, log_prob, value, _ = select_action_ppo( 310 | train_state, prev_timestep.observation, prev_actions, _rng_model, config 311 | ) 312 | 313 | # STEP ENV 314 | _rng_env = jax.random.split(_rng_env, config.num_envs_per_device) 315 | action_env = wrap_action(action, env.batch_cfg.action_type) 316 | timestep = env.step(prev_timestep, action_env, _rng_env) 317 | transition = Transition( 318 | # done=timestep.last(), 319 | done=timestep.done, 320 | action=action, 321 | value=value, 322 | reward=timestep.reward, 323 | log_prob=log_prob, 324 | obs=prev_timestep.observation, 325 | prev_actions=prev_actions, 326 | prev_reward=prev_reward, 327 | ) 328 | 329 | # UPDATE PREVIOUS ACTIONS 330 | prev_actions = jnp.roll(prev_actions, shift=1, axis=-1) 331 | prev_actions = prev_actions.at[..., 0].set(action) 332 | 333 | runner_state = (rng, train_state, timestep, prev_actions, timestep.reward) 334 | return runner_state, transition 335 | 336 | # transitions: [seq_len, batch_size, ...] 337 | runner_state, transitions = jax.lax.scan( 338 | _env_step, runner_state, None, config.num_steps 339 | ) 340 | 341 | # CALCULATE ADVANTAGE 342 | rng, train_state, timestep, prev_actions, prev_reward = runner_state 343 | rng, _rng = jax.random.split(rng) 344 | _, _, last_val, _ = select_action_ppo( 345 | train_state, timestep.observation, prev_actions, _rng, config 346 | ) 347 | advantages, targets = calculate_gae( 348 | transitions, last_val, config.gamma, config.gae_lambda 349 | ) 350 | 351 | # UPDATE NETWORK 352 | def _update_epoch(update_state, _): 353 | """ 354 | Performs a single epoch of updates on the network parameters. 355 | 356 | This function iterates over minibatches of the collected data and 357 | applies updates to the network parameters based on the PPO algorithm. 358 | It is called multiple times to perform multiple epochs of updates. 359 | 360 | Parameters: 361 | - update_state: A tuple containing the current state of the RNG, 362 | the training state, and the collected transitions, 363 | advantages, and targets. 364 | - _: Placeholder to match the expected input signature for jax.lax.scan. 365 | 366 | Returns: 367 | - update_state: Updated state after performing the epoch of updates. 368 | - update_info: Information about the updates performed in this epoch. 369 | """ 370 | 371 | def _update_minbatch(train_state, batch_info): 372 | """ 373 | Updates the network parameters based on a single minibatch of data. 374 | 375 | This function applies the PPO update rule to the network 376 | parameters using the data from a single minibatch. It is 377 | called for each minibatch in an epoch. 378 | 379 | Parameters: 380 | - train_state: The current training state, including the network parameters. 381 | - batch_info: A tuple containing the transitions, advantages, and targets for the minibatch. 382 | 383 | Returns: 384 | - new_train_state: The training state after applying the updates. 385 | - update_info: Information about the updates performed on this minibatch. 
386 | """ 387 | transitions, advantages, targets = batch_info 388 | new_train_state, update_info = ppo_update_networks( 389 | train_state=train_state, 390 | transitions=transitions, 391 | advantages=advantages, 392 | targets=targets, 393 | config=config, 394 | ) 395 | return new_train_state, update_info 396 | 397 | rng, train_state, transitions, advantages, targets = update_state 398 | 399 | # MINIBATCHES PREPARATION 400 | rng, _rng = jax.random.split(rng) 401 | permutation = jax.random.permutation(_rng, config.num_envs_per_device) 402 | # [seq_len, batch_size, ...] 403 | batch = (transitions, advantages, targets) 404 | # [batch_size, seq_len, ...], as our model assumes 405 | batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch) 406 | 407 | shuffled_batch = jtu.tree_map( 408 | lambda x: jnp.take(x, permutation, axis=0), batch 409 | ) 410 | # [num_minibatches, minibatch_size, seq_len, ...] 411 | minibatches = jtu.tree_map( 412 | lambda x: jnp.reshape( 413 | x, (config.num_minibatches, -1) + x.shape[1:] 414 | ), 415 | shuffled_batch, 416 | ) 417 | train_state, update_info = jax.lax.scan( 418 | _update_minbatch, train_state, minibatches 419 | ) 420 | 421 | update_state = (rng, train_state, transitions, advantages, targets) 422 | return update_state, update_info 423 | 424 | # [seq_len, batch_size, num_layers, hidden_dim] 425 | update_state = (rng, train_state, transitions, advantages, targets) 426 | update_state, loss_info = jax.lax.scan( 427 | _update_epoch, update_state, None, config.update_epochs 428 | ) 429 | 430 | # averaging over minibatches then over epochs 431 | loss_info = jtu.tree_map(lambda x: x.mean(-1).mean(-1), loss_info) 432 | 433 | rng, train_state = update_state[:2] 434 | # EVALUATE AGENT 435 | rng, _rng = jax.random.split(rng) 436 | 437 | runner_state = (rng, train_state, timestep, prev_actions, prev_reward) 438 | return runner_state, loss_info 439 | 440 | # Setup runner state for multiple devices 441 | 442 | rng, rng_rollout = jax.random.split(rng) 443 | rng = jax.random.split(rng, num=config.num_devices) 444 | train_state = replicate(train_state, jax.local_devices()[: config.num_devices]) 445 | runner_state = (rng, train_state, timestep, prev_actions, prev_reward) 446 | # runner_state, loss_info = jax.lax.scan(_update_step, runner_state, None, config.num_updates) 447 | for i in tqdm(range(config.num_updates), desc="Training"): 448 | start_time = time.time() # Start time for measuring iteration speed 449 | runner_state, loss_info = jax.block_until_ready( 450 | _update_step(runner_state, None) 451 | ) 452 | end_time = time.time() 453 | 454 | iteration_duration = end_time - start_time 455 | iterations_per_second = 1 / iteration_duration 456 | steps_per_second = ( 457 | iterations_per_second 458 | * config.num_steps 459 | * config.num_envs 460 | * config.num_devices 461 | ) 462 | 463 | tqdm.write( 464 | f"Steps/s: {steps_per_second:.2f}" 465 | ) # Display steps and iterations per second 466 | 467 | # Use data from the first device for stats and eval 468 | loss_info_single = unreplicate(loss_info) 469 | runner_state_single = unreplicate(runner_state) 470 | _, train_state, timestep, prev_actions = runner_state_single[:4] 471 | env_params_single = timestep.env_cfg 472 | 473 | if i % config.log_train_interval == 0: 474 | curriculum_levels = get_curriculum_levels( 475 | env_params_single, env.batch_cfg.curriculum_global.levels 476 | ) 477 | wandb.log( 478 | { 479 | "performance/steps_per_second": steps_per_second, 480 | "performance/iterations_per_second": iterations_per_second, 
481 | "curriculum_levels": curriculum_levels, 482 | "lr": config.lr, 483 | **loss_info_single, 484 | } 485 | ) 486 | 487 | if i % config.checkpoint_interval == 0: 488 | checkpoint = { 489 | "train_config": config, 490 | "env_config": env_params_single, 491 | "model": runner_state_single[1].params, 492 | "loss_info": loss_info_single, 493 | } 494 | helpers.save_pkl_object(checkpoint, f"checkpoints/{config.name}.pkl") 495 | 496 | if i % config.log_eval_interval == 0: 497 | eval_stats = eval_ppo.rollout( 498 | rng_rollout, 499 | env, 500 | env_params_single, 501 | train_state, 502 | config, 503 | ) 504 | 505 | # eval_stats = jax.lax.pmean(eval_stats, axis_name="devices") 506 | n = config.num_envs_per_device * eval_stats.length 507 | avg_positive_episode_length = jnp.where( 508 | eval_stats.positive_terminations > 0, 509 | eval_stats.positive_terminations_steps / eval_stats.positive_terminations, 510 | jnp.zeros_like(eval_stats.positive_terminations_steps) 511 | ) 512 | loss_info_single.update( 513 | { 514 | "eval/rewards": eval_stats.reward / n, 515 | "eval/max_reward": eval_stats.max_reward, 516 | "eval/min_reward": eval_stats.min_reward, 517 | "eval/lengths": eval_stats.length, 518 | "eval/FORWARD %": eval_stats.action_0 / n, 519 | "eval/BACKWARD %": eval_stats.action_1 / n, 520 | "eval/CLOCK %": eval_stats.action_2 / n, 521 | "eval/ANTICLOCK %": eval_stats.action_3 / n, 522 | "eval/CABIN_CLOCK %": eval_stats.action_4 / n, 523 | "eval/CABIN_ANTICLOCK %": eval_stats.action_5 / n, 524 | "eval/DO": eval_stats.action_6 / n, 525 | "eval/positive_terminations": eval_stats.positive_terminations 526 | / config.num_envs_per_device, 527 | "eval/total_terminations": eval_stats.terminations 528 | / config.num_envs_per_device, 529 | "eval/avg_positive_episode_length": avg_positive_episode_length 530 | } 531 | ) 532 | 533 | wandb.log(loss_info_single) 534 | 535 | # Clear JAX caches and run garbage collection to stabilize memory use 536 | if i % config.cache_clear_interval == 0: 537 | jax.clear_caches() 538 | import gc 539 | gc.collect() 540 | 541 | return {"runner_state": runner_state_single, "loss_info": loss_info_single} 542 | 543 | return train 544 | 545 | 546 | def train(config: TrainConfig): 547 | run = wandb.init( 548 | entity="terra-sp-thesis", 549 | project=config.project, 550 | group=config.group, 551 | name=config.name, 552 | config=asdict(config), 553 | save_code=True, 554 | ) 555 | 556 | # Log config.py and train.py files to wandb 557 | train_py_path = os.path.abspath(__file__) # Path to current train.py file 558 | config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "terra", "terra", "config.py") 559 | 560 | code_artifact = wandb.Artifact(name="source_code", type="code") 561 | 562 | # Add train.py 563 | if os.path.exists(train_py_path): 564 | code_artifact.add_file(train_py_path, name="train.py") 565 | 566 | # Add config.py 567 | if os.path.exists(config_path): 568 | code_artifact.add_file(config_path, name="config.py") 569 | 570 | # Log the artifact if any files were added 571 | if code_artifact.files: 572 | run.log_artifact(code_artifact) 573 | 574 | rng, env, env_params, train_state = make_states(config) 575 | 576 | train_fn = make_train(env, env_params, config) 577 | 578 | print("Training...") 579 | try: # Try block starts here 580 | t = time.time() 581 | train_info = jax.block_until_ready(train_fn(rng, train_state)) 582 | elapsed_time = time.time() - t 583 | print(f"Done in {elapsed_time:.2f}s") 584 | except KeyboardInterrupt: # Catch Ctrl+C 585 | print("Training 
interrupted. Finalizing...") 586 | finally: # Ensure wandb.finish() is called 587 | run.finish() 588 | print("wandb session finished.") 589 | 590 | 591 | if __name__ == "__main__": 592 | DT = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 593 | import argparse 594 | 595 | parser = argparse.ArgumentParser() 596 | parser.add_argument( 597 | "-n", 598 | "--name", 599 | type=str, 600 | default="experiment", 601 | ) 602 | parser.add_argument( 603 | "-m", 604 | "--machine", 605 | type=str, 606 | default="local", 607 | ) 608 | parser.add_argument( 609 | "-d", 610 | "--num_devices", 611 | type=int, 612 | default=0, 613 | help="Number of devices to use. If 0, uses all available devices.", 614 | ) 615 | args, _ = parser.parse_known_args() 616 | 617 | name = f"{args.name}-{args.machine}-{DT}" 618 | train(TrainConfig(name=name, num_devices=args.num_devices)) 619 | --------------------------------------------------------------------------------
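The zero_pool helper in utils/models.py gives zeros priority over negative values, and negative values priority over positive ones, within each 3x3 window. The snippet below is a minimal, self-contained sketch of that behaviour (not repo code): it rebuilds the two pools locally so it only needs jax and flax, and it casts the zero mask to float before pooling, a small deviation from the repo's max_pool(x == 0), to keep reduce_window dtypes consistent.

from functools import partial

import jax.numpy as jnp
import flax.linen as nn

# Same window/stride/padding as min_pool/max_pool in utils/models.py.
pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

def zero_pool_demo(x):
    x_min = -pool(-x)                          # min over each window
    has_zero = pool((x == 0).astype(x.dtype))  # > 0 where a window contains a zero
    return jnp.where(has_zero > 0, 0.0, x_min)

x = jnp.array([[-2., 1., 3., 3.],
               [ 1., 5., 3., 3.],
               [-1., 2., 0., 4.],
               [ 2., 2., 2., 2.]]).reshape(1, 4, 4, 1)  # (batch, H, W, channels)
print(zero_pool_demo(x)[0, ..., 0])
# [[-2.  1.]
#  [-1.  0.]]  -- only the lower-right window contains a 0, so only it pools to 0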
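Both rollout_episode in eval.py and _env_step in train.py keep a fixed-length buffer of the most recent actions by rolling the buffer and writing the newest action at index 0. A small sketch of that update (not repo code), assuming num_prev_actions = 5 and a single environment:

import jax.numpy as jnp

prev_actions = jnp.zeros((1, 5), dtype=jnp.int32)  # (num_envs, num_prev_actions)
for action in [3, 6, 1]:
    prev_actions = jnp.roll(prev_actions, shift=1, axis=1)   # shift history one slot to the right
    prev_actions = prev_actions.at[:, 0].set(action)         # newest action goes into slot 0
print(prev_actions)  # [[1 6 3 0 0]]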
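The jax.lax.scan in calculate_gae (train.py) runs the standard GAE recursion backwards over the collected trajectory. As a readability aid, here is an equivalent plain-NumPy reference loop; this is a sketch, not part of the repo. Shapes follow the training code: rewards, values and dones are [seq_len, num_envs], last_val is [num_envs], and the default gamma/lambda mirror TrainConfig.

import numpy as np

def gae_reference(rewards, values, dones, last_val, gamma=0.995, gae_lambda=0.95):
    """Backward recursion: delta_t = r_t + gamma*V_{t+1}*(1 - done_t) - V_t,
    A_t = delta_t + gamma*lambda*(1 - done_t)*A_{t+1}; targets are A_t + V_t."""
    advantages = np.zeros_like(rewards)
    gae = np.zeros_like(last_val)
    next_value = last_val
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * next_value * (1.0 - dones[t]) - values[t]
        gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages, advantages + values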
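Entry points, as implied by the argparse blocks above: training is launched with, for example, python train.py -n my-experiment -m local -d 0 (with -d 0 all local devices are used, and the run is logged to wandb under the name <name>-<machine>-<timestamp>), while a saved checkpoint is evaluated with python eval.py -run <path/to/checkpoint.pkl> -n 32 -steps 200 -d 1 -s 0, where the checkpoint path and the numeric values are placeholders and -d 1 selects deterministic (argmax) action selection.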