├── .gitattributes ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets ├── conditional_instructions.pickle ├── context.pickle ├── irrelevant_text.pickle ├── multiple_objects.pickle ├── multiple_rearrangements.pickle ├── new_scenes.pickle ├── novel_objects.pickle ├── referring_expressions.pickle ├── rephrasing.pickle ├── spatial_relationships.pickle ├── train.pickle └── train_validation.pickle ├── llarp ├── config │ ├── __init__.py │ ├── baseline │ │ ├── llarp.yaml │ │ ├── policy │ │ │ └── llarp_policy.yaml │ │ └── renderer │ │ │ └── regular.yaml │ ├── default_structured_configs.py │ └── task │ │ ├── dataset_v1.yaml │ │ ├── language_rearrangement.yaml │ │ ├── pddl_domain_replica_cad.yaml │ │ └── task_obs │ │ └── visual.yaml ├── dataset │ ├── __init__.py │ ├── combine_datasets.py │ ├── configs │ │ ├── dataset.yaml │ │ └── instructions.yaml │ ├── create_episodes.py │ ├── dataset_validator.py │ ├── demo_dataset.py │ ├── episodes.py │ ├── generator.py │ └── utils.py ├── policies │ ├── __init__.py │ ├── action_decoders.py │ ├── cores │ │ ├── __init__.py │ │ ├── base_core.py │ │ └── decoder.py │ ├── hl_policy.py │ ├── llama_parallel.py │ ├── llm_policy.py │ ├── transformer_storage.py │ ├── utils.py │ ├── vis_bridge.py │ └── visual_encoders.py ├── run.py ├── task │ ├── __init__.py │ ├── actions.py │ ├── measures.py │ ├── predicate_task.py │ ├── sensors.py │ └── utils.py └── trainer │ ├── __init__.py │ ├── custom_ddp.py │ ├── custom_env_factory.py │ ├── custom_evaluator.py │ ├── env_utils │ └── __init__.py │ ├── test_env_factory.py │ ├── trainer_loop.py │ └── transformer_ppo.py ├── setup.py └── test └── test_llarp.py /.gitattributes: -------------------------------------------------------------------------------- 1 | datasets/train.pickle filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE.
40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large Language Models as Generalizable Policies for Embodied Tasks 2 | 3 | This software project accompanies the research paper, [Large Language Models as Generalizable Policies for Embodied Tasks](https://arxiv.org/abs/2310.17722). See [llm-rl.github.io](https://llm-rl.github.io) for more information. 4 | 5 | **Abstract**: We show that large language models (LLMs) can be adapted to be generalizable policies for embodied visual tasks. Our approach, called Large LAnguage model Reinforcement Learning Policy (LLaRP), adapts a pre-trained frozen LLM to take as input text instructions and visual egocentric observations and output actions directly in the environment. Using reinforcement learning, we train LLaRP to see and act solely through environmental interactions. We show that LLaRP is robust to complex paraphrasings of task instructions and can generalize to new tasks that require novel optimal behavior. In particular, on 1,000 unseen tasks it achieves 42% success rate, 1.7x the success rate of other common learned baselines or zero-shot applications of LLMs. Finally, to aid the community in studying language conditioned, massively multi-task, embodied AI problems we release a novel benchmark, Language Rearrangement, consisting of 150,000 training and 1,000 testing tasks for language-conditioned rearrangement. 6 | 7 | ## Getting Started 8 | 9 | ### Installation 10 | - Set up the Python environment: 11 | - `conda create -n llarp -y python=3.9` 12 | - `conda activate llarp` 13 | - Install [Habitat-Sim](https://github.com/facebookresearch/habitat-sim): `conda install -y habitat-sim==0.3.0 withbullet headless -c conda-forge -c aihabitat` 14 | - Install [Habitat-Lab](https://github.com/facebookresearch/habitat-lab): 15 | - `git clone -b 'v0.3.0' --depth 1 https://github.com/facebookresearch/habitat-lab.git ~/habitat-lab` 16 | - `cd ~/habitat-lab` 17 | - `pip install -e habitat-lab` 18 | - `pip install -e habitat-baselines` 19 | - Install [VC-1](https://eai-vc.github.io): 20 | - `git clone https://github.com/facebookresearch/eai-vc.git ~/eai-vc && cd ~/eai-vc && git checkout 76fe35e87b1937168f1ec4b236e863451883eaf3` 21 | - `git submodule update --init --recursive` 22 | - `pip install -e ./vc_models` 23 | - Install this repository: clone this repo and then run `pip install -e .` in the cloned directory. 24 | - Download the YCB and ReplicaCAD datasets for the Language Rearrangement task. Run the following in this code base's directory. 25 | - `conda install -y -c conda-forge git-lfs` 26 | - `python -m habitat_sim.utils.datasets_download --uids rearrange_task_assets` 27 | - Download LLaMA-1 weights. Instructions on how to do this are [here](https://huggingface.co/docs/transformers/main/model_doc/llama). Place the model weights in `data/hf_llama_7B/` in the `llarp` directory. To verify the download is correct, `ls data/hf_llama_7B` should return `pytorch_model-00001-of-00001.bin`, `config.json`, etc. 28 | 29 | ### Commands 30 | **Training Commands**: 31 | - Train LLaRP on the Language Rearrangement task: `python llarp/run.py --config-name=baseline/llarp.yaml` 32 | - The trainer is built on Habitat-Baselines v0.3.0. Most Habitat-Baselines options also apply to LLaRP training.
See [this README](https://github.com/facebookresearch/habitat-lab/tree/afe4058a7f8aa5ab71a133575cdaa79f0308af6a/habitat-baselines) for more information about how to use Habitat Baselines. 33 | - To run training on multiple GPUs, use [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html). 34 | - Additional configuration options are under `llarp/config/baseline/policy/llarp_policy.yaml`. 35 | 36 | **Evaluation Commands**: First, get the checkpoint generated by the training command. Then run: `python llarp/run.py --config-name=baseline/llarp.yaml habitat_baselines.evaluate=True habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.is_eval_mode=True habitat_baselines.eval_ckpt_path_dir=path/to/checkpoint.pth habitat.dataset.data_path=datasets/DATASET.pickle` where `DATASET` refers to one of the Language Rearrangement evaluation splits under `datasets/`: `train`, `new_scenes`, `rephrasing`, `referring_expressions`, `spatial_relationships`, `context`, `irrelevant_text`, `multiple_rearrangements`, `novel_objects`, `multiple_objects`, or `conditional_instructions`. 37 | 38 | **Dataset Generation Commands**: Since the episodes are procedurally generated, it is possible to generate more episodes beyond the 150,000 in the training dataset. 39 | - Generate a dataset: `python llarp/dataset/create_episodes.py --run --config llarp/dataset/configs/dataset.yaml --num-episodes 100 --out data/datasets/dataset_path.pickle --limit-scene-set scene_train_split --instruct-path llarp/dataset/configs/instructions.yaml --tag v2_train --verbose --n-procs 112 --seed 7 --procs-per-gpu 14` 40 | - Validate that generated episodes are solvable and remove unsolvable ones: `python llarp/dataset/dataset_validator.py --cfg llarp/config/task/language_rearrangement.yaml --n-procs 56 habitat.dataset.data_path=data/datasets/dataset_path_validated.pickle task_obs=all` 41 | 42 | **Other**: 43 | - Run tests: `pytest test`. This checks that LLaRP weights are updated exactly as expected after several training iterations under a variety of training settings. 44 | 45 | ## Documentation 46 | 47 | Code Structure (under `llarp/` directory): 48 | - `config/`: Hydra configuration YAML files. 49 | - `baseline/`: Config files for policies and trainers. 50 | - `task/`: Config files for the Language Rearrangement task. 51 | - `dataset/`: Utilities for generating the Language Rearrangement dataset files. 52 | - `configs/`: Config files for Language Rearrangement dataset generation. In this directory, `instructions.yaml` contains the language instruction templates and `dataset.yaml` defines the possible objects and receptacles. Refer to Appendix Section B of the paper for details on the instruction templates. 53 | - `create_episodes.py` is the entry point for the episode generation. 54 | - `dataset_validator.py` processes an already created dataset and ensures all the episodes are solvable. It deletes unsolvable episodes. 55 | - `policies/`: Defines the LLaRP policy module. 56 | - `cores/decoder.py`: This contains the bulk of the policy code for sampling from and updating the policy. 57 | - `action_decoders.py`: Contains the action decoder head. 58 | - `llm_policy.py`: Entry point for the LLaRP policy. This integrates the LLaRP policy with the Habitat Baselines trainer. 59 | - `transformer_storage.py`: Modified rollout buffer for transformers in PPO. 60 | - `vis_bridge.py`: The observation encoder module. 61 | - `visual_encoders.py`: The visual encoder (VC-1).
62 | - `task/`: Code to set up the Language Rearrangement task in Habitat-Lab. 63 | - `trainer/`: Contains the core RL loop, PPO loss calculation, evaluation, environment creation, and distributed training utilities. 64 | Language Rearrangement episode datasets are included in this repository under the `datasets/` directory. 65 | 66 | ## License 67 | The code is provided under the [Apple Sample Code license](LICENSE). 68 | 69 | ## Citation 70 | ``` 71 | @article{szot2023large, 72 | title={Large Language Models as Generalizable Policies for Embodied Tasks}, 73 | author={Szot, Andrew and Schwarzer, Max and Agrawal, Harsh and Mazoure, Bogdan and Talbott, Walter and Metcalf, Katherine and Mackraz, Natalie and Hjelm, Devon and Toshev, Alexander}, 74 | journal={arXiv preprint arXiv:2310.17722}, 75 | year={2023} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /datasets/conditional_instructions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/conditional_instructions.pickle -------------------------------------------------------------------------------- /datasets/context.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/context.pickle -------------------------------------------------------------------------------- /datasets/irrelevant_text.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/irrelevant_text.pickle -------------------------------------------------------------------------------- /datasets/multiple_objects.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/multiple_objects.pickle -------------------------------------------------------------------------------- /datasets/multiple_rearrangements.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/multiple_rearrangements.pickle -------------------------------------------------------------------------------- /datasets/new_scenes.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/new_scenes.pickle -------------------------------------------------------------------------------- /datasets/novel_objects.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/novel_objects.pickle -------------------------------------------------------------------------------- /datasets/referring_expressions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/referring_expressions.pickle --------------------------------------------------------------------------------
/datasets/rephrasing.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/rephrasing.pickle -------------------------------------------------------------------------------- /datasets/spatial_relationships.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/spatial_relationships.pickle -------------------------------------------------------------------------------- /datasets/train.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:16d7087d2ac67b12a1bf2e10e98764a252193acfd1e2b77336535d1996856625 3 | size 833547264 4 | -------------------------------------------------------------------------------- /datasets/train_validation.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-llarp/be7cbd3af35ce1a3f9a98fbd81505b77012ef745/datasets/train_validation.pickle -------------------------------------------------------------------------------- /llarp/config/__init__.py: -------------------------------------------------------------------------------- 1 | import llarp.config.default_structured_configs 2 | -------------------------------------------------------------------------------- /llarp/config/baseline/llarp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # This is also the base config for HL policies. 3 | 4 | defaults: 5 | - /task: language_rearrangement 6 | - /habitat_baselines: habitat_baselines_rl_config_base 7 | - /baseline/renderer: regular 8 | - /baseline/policy@habitat_baselines.rl.policy.main_agent: llarp_policy 9 | - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor 10 | - _self_ 11 | 12 | 13 | habitat_baselines: 14 | verbose: False 15 | trainer_name: "il_ppo" 16 | updater_name: "TransformerPPO" 17 | distrib_updater_name: "DistributedTransformerPPO" 18 | rollout_storage_name: "TransformerRolloutStorage" 19 | evaluator: 20 | _target_: llarp.trainer.custom_evaluator.CustomHabitatEvaluator 21 | torch_gpu_id: 0 22 | video_fps: 30 23 | eval_ckpt_path_dir: "" 24 | writer_type: 'tb' 25 | num_updates: -1 26 | total_num_steps: 2.0e8 27 | log_interval: 1 28 | num_checkpoints: 20 29 | num_environments: 8 30 | force_torch_single_threaded: True 31 | eval_keys_to_include_in_name: ['predicate_task_success', 'task_progress'] 32 | load_resume_state_config: False 33 | vector_env_factory: 34 | _target_: "llarp.trainer.custom_env_factory.CustomVectorEnvFactory" 35 | on_save_ckpt_callback: 36 | _target_: "llarp.run.on_save_ckpt_callback" 37 | 38 | eval: 39 | video_option: ["disk"] 40 | 41 | rl: 42 | ppo: 43 | # PPO params 44 | clip_param: 0.2 45 | ppo_epoch: 2 46 | num_mini_batch: 8 47 | value_loss_coef: 0.5 48 | entropy_coef: 0.01 49 | lr: 3e-4 50 | eps: 1e-5 51 | max_grad_norm: 0.2 52 | num_steps: 32 53 | use_gae: True 54 | gamma: 0.99 55 | tau: 0.95 56 | use_double_buffered_sampler: False 57 | use_normalized_advantage: False 58 | hidden_size: 4096 59 | 60 | ddppo: 61 | # DD-PPO sync parameter. 
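# In Habitat-Baselines' DD-PPO, sync_frac is the preemption threshold: once this fraction of workers has finished collecting its rollout, the remaining (slower) workers are preempted so the PPO update can start.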
62 | sync_frac: 0.6 63 | # The PyTorch distributed backend to use 64 | distrib_backend: NCCL 65 | # Initialize just the visual encoder backbone with pretrained weights 66 | pretrained_encoder: False 67 | # Whether the visual encoder backbone will be trained. 68 | train_encoder: False 69 | reset_critic: False 70 | -------------------------------------------------------------------------------- /llarp/config/baseline/policy/llarp_policy.yaml: -------------------------------------------------------------------------------- 1 | # @package habitat_baselines.rl.policy.main_agent 2 | defaults: 3 | - _self_ 4 | 5 | name: "LlmPolicy" 6 | hierarchical_policy: 7 | high_level_policy: 8 | name: "" 9 | policy_core_type: "CleanLLMPolicyCore" 10 | normalize_visual_inputs: False 11 | num_rnn_layers: 2 12 | # Superset of all observations input to the policy 13 | policy_input_keys: 14 | # Visual sensors 15 | - head_rgb 16 | - obs_lang 17 | 18 | # LLM model 19 | use_b16: True 20 | load_in_8bit: False 21 | use_rope_scaling: False 22 | 23 | # Must be the same as the number of episode steps and the same as ppo.hidden_size 24 | context_len: 32 25 | rollout_take_ratio: 1.0 26 | is_eval_mode: False 27 | train_visual_encoder: False 28 | train_vis_bridge: True 29 | train_action_decoder: True 30 | # A value of null means don't load. 31 | pretrain_ckpt_path: null 32 | prefix_tokens_obs_k: "vocab_lang_goal" 33 | use_term_action: False 34 | 35 | remove_prompt_pad_toks: False 36 | 37 | set_llm_eval: False 38 | strict_loading: True 39 | num_visual_tokens: 1 40 | model_parallel_factor: 1 41 | use_action_inputs: False 42 | debug_mode: False 43 | 44 | critic: 45 | _target_: "llarp.policies.llm_policy.LinearCriticHead" 46 | use_b16: False 47 | 48 | # No prompting by default. 49 | prompts: 50 | habitat: "" 51 | prompt_suffix: "" 52 | tokenizer_id: "data/hf_llama_7B/" 53 | 54 | 55 | llm_wrapper: 56 | peft: False 57 | peft_full_att_params: False 58 | peft_settings: 59 | r: 8 60 | lora_alpha: 32 61 | lora_dropout: 0.1 62 | _target_: "llarp.policies.utils.DecoderWrapper" 63 | use_b16: True 64 | llm_id: "data/hf_llama_7B/" 65 | model_parallel_factor: 1 66 | debug_mode: False 67 | force_non_causal: False 68 | train_llm: False 69 | model_cfg: 70 | intermediate_size: 64 71 | hidden_size: 64 72 | num_hidden_layers: 2 73 | 74 | visual_bridge: 75 | _target_: "llarp.policies.vis_bridge.MlpVisBridge" 76 | _recursive_: False 77 | hidden_size: 4096 78 | 79 | action_decoder: 80 | _target_: "llarp.policies.action_decoders.MlpDecoder" 81 | hidden_size: 512 82 | 83 | vis_encoder: 84 | _target_: "llarp.policies.visual_encoders.Vc1VisualEncoder" 85 | use_b16: True 86 | classifier_feature: "use_cls_token" 87 | 88 | -------------------------------------------------------------------------------- /llarp/config/baseline/renderer/regular.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | habitat_baselines: 3 | num_environments: 24 4 | habitat: 5 | simulator: 6 | create_renderer: True 7 | concur_render: False 8 | renderer: 9 | enable_batch_renderer: False 10 | habitat_sim_v0: 11 | enable_gfx_replay_save: False 12 | 13 | -------------------------------------------------------------------------------- /llarp/config/task/dataset_v1.yaml: -------------------------------------------------------------------------------- 1 | # @package habitat.dataset 2 | 3 | defaults: 4 | - /habitat/dataset: dataset_config_schema 5 | - _self_ 6 | 7 | type: LangRearrangeDataset-v0 8 | data_path: 
"datasets/train.pickle" 9 | split: train 10 | scenes_dir: "data/replica_cad/" 11 | -------------------------------------------------------------------------------- /llarp/config/task/language_rearrangement.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: llarp_task_config_base 6 | - dataset_v1 7 | - task_obs: visual 8 | - /habitat/task/lab_sensors: 9 | - vocab_lang_goal 10 | - /habitat/task/measurements: 11 | - num_steps 12 | - was_prev_action_invalid 13 | - predicate_task_success 14 | - lang_goal 15 | - task_progress 16 | - subgoal_reward 17 | - num_invalid_actions 18 | - prev_action_name 19 | - /habitat/task/actions: 20 | - pddl_hl_action 21 | - _self_ 22 | 23 | habitat: 24 | environment: 25 | max_episode_steps: 32 26 | 27 | simulator: 28 | type: RearrangeSim-v0 29 | additional_object_paths: 30 | - data/objects/ycb/configs/ 31 | - data/replica_cad/configs/objects/ 32 | needs_markers: False 33 | concur_render: False 34 | 35 | create_renderer: True 36 | agents_order: ['main_agent'] 37 | renderer: 38 | enable_batch_renderer: False 39 | composite_files: 40 | - "data/composite_replica/replica.gltf" 41 | 42 | # The geometric goal position is not used. 43 | debug_render_goal: False 44 | debug_render: False 45 | auto_sleep: True 46 | kinematic_mode: True 47 | ac_freq_ratio: 1 48 | step_physics: False 49 | should_setup_semantic_ids: False 50 | 51 | habitat_sim_v0: 52 | allow_sliding: False 53 | enable_hbao: False 54 | enable_physics: True 55 | enable_gfx_replay_save: False 56 | 57 | task: 58 | type: "RearrangePredicateTask-v0" 59 | constraint_violation_ends_episode: False 60 | constraint_violation_drops_object: True 61 | end_on_success: True 62 | 63 | filter_instructs: null 64 | filter_down_num: null 65 | force_scene_per_worker: False 66 | 67 | tokenizer_name: "data/hf_llama_7B/" 68 | 69 | # PDDL task settings. 70 | task_spec_base_path: "task/" 71 | pddl_domain_def: "pddl_domain_replica_cad" 72 | # Robot randomly spawns 73 | start_template: null 74 | goal_template: null 75 | sample_entities: {} 76 | 77 | success_measure: "predicate_task_success" 78 | 79 | # Sparse reward for reaching each subgoal. 80 | reward_measure: "subgoal_reward" 81 | success_reward: 10.0 82 | # Per timestep reward (negative means penalty). 83 | slack_reward: 0.0 84 | actions: 85 | pddl_hl_action: 86 | allowed_actions: 87 | - nav 88 | - pick 89 | - place 90 | - open_fridge 91 | - close_fridge 92 | - open_cab 93 | - close_cab 94 | measurements: 95 | was_prev_action_invalid: 96 | pddl_action_name: "pddl_hl_action" 97 | prev_action_name: 98 | pddl_action_name: "pddl_hl_action" 99 | predicate_task_success: 100 | must_call_stop: False 101 | subgoal_reward: 102 | invalid_ac_pen: 0.05 103 | progress_reward_factor: 5.0 104 | -------------------------------------------------------------------------------- /llarp/config/task/pddl_domain_replica_cad.yaml: -------------------------------------------------------------------------------- 1 | # Defines the PDDL domain of tasks in ReplicaCAD. 
2 | 3 | types: 4 | static_obj_type: 5 | - recep_type 6 | - movable_entity_type 7 | movable_entity_type: 8 | - stackable_entity_type 9 | recep_type: 10 | - static_receptacle_entity_type 11 | - art_receptacle_entity_type 12 | - nav_receptacle 13 | art_receptacle_entity_type: 14 | - cab_type 15 | - fridge_type 16 | nav_receptacle: 17 | - static_receptacle_entity_type 18 | - art_receptacle_entity_type 19 | static_receptacle_entity_type: 20 | - place_receptacle 21 | 22 | 23 | 24 | constants: 25 | - name: cab_push_point_7 26 | expr_type: cab_type 27 | - name: cab_push_point_6 28 | expr_type: cab_type 29 | - name: cab_push_point_5 30 | expr_type: cab_type 31 | - name: cab_push_point_4 32 | expr_type: cab_type 33 | - name: fridge_push_point 34 | expr_type: fridge_type 35 | - name: robot_0 36 | expr_type: robot_entity_type 37 | 38 | predicates: 39 | - name: in 40 | args: 41 | - name: obj 42 | expr_type: movable_entity_type 43 | - name: receptacle 44 | expr_type: art_receptacle_entity_type 45 | set_state: 46 | obj_states: 47 | obj: receptacle 48 | 49 | # - name: stacked_on_top 50 | # args: 51 | # - name: base_obj 52 | # expr_type: stackable_entity_type 53 | # - name: on_top_obj 54 | # expr_type: stackable_entity_type 55 | # set_state: 56 | # obj_states: 57 | # on_top_obj: base_obj 58 | 59 | - name: on_top 60 | args: 61 | - name: obj 62 | expr_type: movable_entity_type 63 | - name: receptacle 64 | expr_type: static_receptacle_entity_type 65 | set_state: 66 | obj_states: 67 | obj: receptacle 68 | recep_scaling: [1.0, 1.25, 1.0] 69 | 70 | - name: holding 71 | args: 72 | - name: obj 73 | expr_type: movable_entity_type 74 | set_state: 75 | robot_states: 76 | robot_0: 77 | holding: obj 78 | 79 | - name: not_holding 80 | args: [] 81 | set_state: 82 | robot_states: 83 | robot_0: 84 | should_drop: True 85 | 86 | - name: opened_cab 87 | args: 88 | - name: cab_id 89 | expr_type: cab_type 90 | set_state: 91 | art_states: 92 | cab_id: 93 | value: 0.45 94 | cmp: 'greater' 95 | override_thresh: 0.1 96 | 97 | - name: closed_cab 98 | args: 99 | - name: cab_id 100 | expr_type: cab_type 101 | set_state: 102 | art_states: 103 | cab_id: 104 | value: 0.0 105 | cmp: 'close' 106 | 107 | 108 | - name: opened_fridge 109 | args: 110 | - name: fridge_id 111 | expr_type: fridge_type 112 | set_state: 113 | art_states: 114 | fridge_id: 115 | value: 1.22 116 | cmp: 'greater' 117 | 118 | - name: closed_fridge 119 | args: 120 | - name: fridge_id 121 | expr_type: fridge_type 122 | set_state: 123 | art_states: 124 | fridge_id: 125 | value: 0.0 126 | cmp: 'close' 127 | 128 | # Place the robot as close as possible. 
Don't check any collision conditions 129 | - name: robot_at_closest 130 | args: 131 | - name: Y 132 | expr_type: static_obj_type 133 | set_state: 134 | robot_states: 135 | robot_0: 136 | pos: Y 137 | place_at_pos_dist: -1.0 138 | base_angle_noise: 0.0 139 | place_at_angle_thresh: 1.0 140 | 141 | - name: robot_at_obj 142 | args: 143 | - name: Y 144 | expr_type: movable_entity_type 145 | set_state: 146 | robot_states: 147 | robot_0: 148 | pos: Y 149 | place_at_pos_dist: 1.5 150 | base_angle_noise: 0.0 151 | place_at_angle_thresh: 1.0 152 | 153 | - name: robot_at 154 | args: 155 | - name: Y 156 | expr_type: static_obj_type 157 | set_state: 158 | robot_states: 159 | robot_0: 160 | pos: Y 161 | place_at_pos_dist: 2.0 162 | base_angle_noise: 0.0 163 | place_at_angle_thresh: 1.57 164 | 165 | - name: at 166 | args: 167 | - name: obj 168 | expr_type: movable_entity_type 169 | - name: at_entity 170 | expr_type: recep_type 171 | set_state: 172 | obj_states: 173 | obj: at_entity 174 | 175 | actions: 176 | - name: noop 177 | parameters: {} 178 | precondition: null 179 | postcondition: [] 180 | 181 | # Only defined relative to place receptacles. 182 | - name: nav 183 | parameters: 184 | - name: entity 185 | expr_type: nav_receptacle 186 | precondition: null 187 | postcondition: 188 | - robot_at_closest(entity) 189 | 190 | - name: pick 191 | parameters: 192 | - name: obj 193 | expr_type: movable_entity_type 194 | - name: robot 195 | expr_type: robot_entity_type 196 | precondition: 197 | expr_type: AND 198 | sub_exprs: 199 | - not_holding() 200 | - robot_at_obj(obj) 201 | - quantifier: FORALL 202 | inputs: 203 | - name: recep 204 | expr_type: cab_type 205 | expr_type: NAND 206 | sub_exprs: 207 | - in(obj, recep) 208 | - closed_cab(recep) 209 | postcondition: 210 | - holding(obj) 211 | 212 | # - name: place_on_top 213 | # parameters: 214 | # - name: base_obj 215 | # expr_type: stackable_entity_type 216 | # precondition: 217 | # expr_type: AND 218 | # quantifier: EXISTS 219 | # inputs: 220 | # - name: place_obj 221 | # expr_type: stackable_entity_type 222 | # sub_exprs: 223 | # - holding(place_obj) 224 | # - robot_at(base_obj) 225 | # postcondition: 226 | # - not_holding() 227 | # - stacked_on_top(place_obj, base_obj) 228 | 229 | - name: place 230 | parameters: 231 | - name: recep 232 | expr_type: place_receptacle 233 | precondition: 234 | expr_type: AND 235 | quantifier: EXISTS 236 | inputs: 237 | - name: place_obj 238 | expr_type: movable_entity_type 239 | sub_exprs: 240 | - holding(place_obj) 241 | - robot_at(recep) 242 | postcondition: 243 | - not_holding() 244 | - at(place_obj, recep) 245 | 246 | - name: open_fridge 247 | parameters: 248 | - name: fridge_id 249 | expr_type: fridge_type 250 | precondition: 251 | expr_type: AND 252 | sub_exprs: 253 | - robot_at(fridge_id) 254 | - closed_fridge(fridge_id) 255 | postcondition: 256 | - opened_fridge(fridge_id) 257 | 258 | - name: close_fridge 259 | parameters: 260 | - name: fridge_id 261 | expr_type: fridge_type 262 | precondition: 263 | expr_type: AND 264 | sub_exprs: 265 | - robot_at(fridge_id) 266 | - opened_fridge(fridge_id) 267 | postcondition: 268 | - closed_fridge(fridge_id) 269 | 270 | - name: open_cab 271 | parameters: 272 | - name: marker 273 | expr_type: cab_type 274 | precondition: 275 | expr_type: AND 276 | sub_exprs: 277 | - robot_at(marker) 278 | - closed_cab(marker) 279 | postcondition: 280 | - opened_cab(marker) 281 | 282 | - name: close_cab 283 | parameters: 284 | - name: marker 285 | expr_type: cab_type 286 | precondition: 287 | 
expr_type: AND 288 | sub_exprs: 289 | - robot_at(marker) 290 | - opened_cab(marker) 291 | postcondition: 292 | - closed_cab(marker) 293 | -------------------------------------------------------------------------------- /llarp/config/task/task_obs/visual.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /habitat/simulator/sensor_setups@habitat.simulator.agents.main_agent: rgb_head_agent 4 | - /habitat/task/lab_sensors: 5 | - is_holding_sensor 6 | - debug_info 7 | habitat: 8 | gym: 9 | obs_keys: 10 | # Visual sensors 11 | - head_rgb 12 | # Task spec 13 | - vocab_lang_goal 14 | - debug_info 15 | simulator: 16 | agents: 17 | main_agent: 18 | sim_sensors: 19 | head_rgb_sensor: 20 | width: 224 21 | height: 224 22 | radius: 0.3 23 | articulated_agent_urdf: ./data/robots/hab_fetch/robots/hab_suction.urdf 24 | articulated_agent_type: FetchSuctionRobot 25 | joint_start_noise: 0.0 26 | -------------------------------------------------------------------------------- /llarp/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from llarp.dataset.episodes import LangRearrangeDatasetV0, LangRearrangeEpisode 6 | -------------------------------------------------------------------------------- /llarp/dataset/combine_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import argparse 6 | import gzip 7 | import os.path as osp 8 | import pickle 9 | import time 10 | from collections import defaultdict, deque 11 | from dataclasses import dataclass 12 | from threading import Thread 13 | from typing import Dict, List, Optional, Tuple 14 | 15 | import habitat 16 | import hydra 17 | import magnum as mn 18 | import numpy as np 19 | import torch 20 | from habitat import make_dataset 21 | from habitat.config.default_structured_configs import register_hydra_plugin 22 | from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction 23 | from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate 24 | from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( 25 | PddlEntity, SimulatorObjectType) 26 | from habitat_baselines.config.default_structured_configs import \ 27 | HabitatBaselinesConfigPlugin 28 | from PIL import Image 29 | from torch import multiprocessing as mp 30 | from tqdm import tqdm 31 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5Model 32 | 33 | import llarp.config 34 | import llarp.dataset 35 | import llarp.task 36 | from llarp.dataset.create_episodes import summarize_episodes 37 | from llarp.dataset.demo_dataset import DemoDataset, cat_demo_datasets 38 | from llarp.dataset.episodes import LangRearrangeDatasetV0 39 | from llarp.task.utils import PLACABLE_RECEP_TYPE, get_allowed_actions 40 | 41 | 42 | def start(args): 43 | initial_paths = args.paths.split(",") 44 | paths = [] 45 | for path in initial_paths: 46 | if "X" in path: 47 | for i in range(args.num_splits): 48 | paths.append(path.replace("X", str(i))) 49 | else: 50 | paths.append(path) 51 | 52 | print("Combining ") 53 | all_eps = [] 54 | for path in paths: 55 | print(path) 56 | path = osp.join("data/datasets", path) 57 | if not osp.exists(path): 58 | print(f"Could not find {path}") 59 | 
continue 60 | config = habitat.get_config(args.cfg, [f"habitat.dataset.data_path={path}"]) 61 | dataset = make_dataset( 62 | config.habitat.dataset.type, config=config.habitat.dataset 63 | ) 64 | all_eps.extend(dataset.episodes) 65 | print("Total eps", len(all_eps)) 66 | summarize_episodes(all_eps) 67 | 68 | combined_dataset = LangRearrangeDatasetV0(config, all_eps) 69 | with open(args.out_path, "wb") as f: 70 | pickle.dump(combined_dataset.to_binary(), f) 71 | print(f"Dumped to {args.out_path}") 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--cfg", required=True, type=str) 77 | parser.add_argument("--paths", required=True, type=str) 78 | parser.add_argument("--out-path", required=True, type=str) 79 | parser.add_argument("--num-splits", required=True, type=int) 80 | args = parser.parse_args() 81 | start(args) 82 | -------------------------------------------------------------------------------- /llarp/dataset/configs/dataset.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | dataset_path: "data/replica_cad/replicaCAD.scene_dataset_config.json" 3 | correct_unstable_results: False 4 | additional_object_paths: 5 | - "data/objects/ycb/configs" 6 | scene_sets: 7 | - 8 | name: "scene_train_split" 9 | included_substrings: 10 | - "v3_sc0_staging" 11 | - "v3_sc1_staging" 12 | - "v3_sc2_staging" 13 | excluded_substrings: [] 14 | - 15 | name: "scene_val_split" 16 | included_substrings: 17 | - "v3_sc3_staging" 18 | excluded_substrings: [] 19 | - 20 | name: "scene_test_split" 21 | included_substrings: 22 | - "v3_sc4_staging" 23 | excluded_substrings: [] 24 | 25 | category_groups: 26 | "all_cats": 27 | # A list of every object category including both train and evaluation. 28 | included: 29 | - "ball" 30 | - "clamp" 31 | - "hammer" 32 | - "screwdriver" 33 | - "padlock" 34 | - "scissors" 35 | - "block" 36 | - "drill" 37 | - "spatula" 38 | - "knife" 39 | - "spoon" 40 | - "plate" 41 | - "sponge" 42 | - "cleanser" 43 | - "plum" 44 | - "pear" 45 | - "peach" 46 | - "apple" 47 | - "lemon" 48 | - "can" 49 | - "box" 50 | - "banana" 51 | - "strawberry" 52 | - "lego" 53 | - "rubriks cube" 54 | - "book" 55 | - "bowl" 56 | - "cup" 57 | - "mug" 58 | - "orange" 59 | - "lid" 60 | - "toy airplane" 61 | - "wrench" 62 | # - "fork" 63 | 64 | "all_eval_cats": 65 | included: 66 | - "mug" 67 | - "orange" 68 | - "lid" 69 | - "toy airplane" 70 | - "wrench" 71 | 72 | "all_train_cats": 73 | included: 74 | - "ball" 75 | - "clamp" 76 | - "hammer" 77 | - "screwdriver" 78 | - "padlock" 79 | - "scissors" 80 | - "block" 81 | - "drill" 82 | - "spatula" 83 | - "knife" 84 | - "spoon" 85 | # - "fork" 86 | - "plate" 87 | - "sponge" 88 | - "cleanser" 89 | - "plum" 90 | - "pear" 91 | - "peach" 92 | - "apple" 93 | - "lemon" 94 | - "can" 95 | - "box" 96 | - "banana" 97 | - "strawberry" 98 | - "lego" 99 | - "rubriks cube" 100 | - "book" 101 | - "bowl" 102 | - "cup" 103 | 104 | "all_fruit": 105 | included: 106 | - "plum" 107 | - "pear" 108 | - "peach" 109 | - "apple" 110 | - "lemon" 111 | 112 | 113 | 114 | object_sets: 115 | # Must be a 1-1 mapping. 
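# (The CLUTTER_OBJECTS set below appears to enumerate the same object models that are partitioned across the per-category sets further down, so each model should map to exactly one category.)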
116 | - name: "CLUTTER_OBJECTS" 117 | excluded_substrings: [] 118 | included_substrings: 119 | - "053_mini_soccer_ball" 120 | - "054_softball" 121 | - "055_baseball" 122 | - "056_tennis_ball" 123 | - "057_racquetball" 124 | #- "058_golf_ball" 125 | - "050_medium_clamp" 126 | - "051_large_clamp" 127 | - "052_extra_large_clamp" 128 | - "048_hammer" 129 | - "043_phillips_screwdriver" 130 | - "044_flat_screwdriver" 131 | - "042_adjustable_wrench" 132 | - "038_padlock" 133 | - "037_scissors" 134 | - "036_wood_block" 135 | - "070-a_colored_wood_blocks" 136 | - "070-b_colored_wood_blocks" 137 | - "035_power_drill" 138 | - "033_spatula" 139 | - "032_knife" 140 | - "031_spoon" 141 | #- "030_fork" 142 | - "029_plate" 143 | - "028_skillet_lid" 144 | - "026_sponge" 145 | - "021_bleach_cleanser" 146 | - "018_plum" 147 | - "017_orange" 148 | - "016_pear" 149 | - "015_peach" 150 | - "013_apple" 151 | - "014_lemon" 152 | - "002_master_chef_can" 153 | - "005_tomato_soup_can" 154 | - "007_tuna_fish_can" 155 | - "010_potted_meat_can" 156 | - "003_cracker_box" 157 | - "004_sugar_box" 158 | - "008_pudding_box" 159 | - "009_gelatin_box" 160 | - "011_banana" 161 | - "012_strawberry" 162 | - "072-a_toy_airplane" 163 | - "072-b_toy_airplane" 164 | - "072-c_toy_airplane" 165 | - "072-d_toy_airplane" 166 | - "072-e_toy_airplane" 167 | - "073-a_lego_duplo" 168 | - "073-b_lego_duplo" 169 | - "073-c_lego_duplo" 170 | - "073-d_lego_duplo" 171 | - "073-e_lego_duplo" 172 | - "073-f_lego_duplo" 173 | - "073-g_lego_duplo" 174 | - "077_rubiks_cube" 175 | - "frl_apartment_book_01" 176 | - "frl_apartment_book_02" 177 | - "frl_apartment_book_03" 178 | - "frl_apartment_book_04" 179 | - "frl_apartment_book_05" 180 | - "frl_apartment_book_06" 181 | - "frl_apartment_bowl_01" 182 | - "frl_apartment_bowl_02" 183 | - "frl_apartment_bowl_03" 184 | - "frl_apartment_bowl_06" 185 | - "frl_apartment_bowl_07" 186 | - "024_bowl" 187 | - "025_mug" 188 | - "065-a_cups" 189 | - "065-b_cups" 190 | - "065-c_cups" 191 | - "065-d_cups" 192 | - "065-e_cups" 193 | - "065-f_cups" 194 | - "065-g_cups" 195 | - "065-h_cups" 196 | - "065-i_cups" 197 | - "065-j_cups" 198 | - "frl_apartment_cup_01" 199 | - "frl_apartment_cup_02" 200 | - "frl_apartment_cup_03" 201 | - "frl_apartment_cup_05" 202 | 203 | - name: "ball" 204 | excluded_substrings: [] 205 | included_substrings: 206 | - "053_mini_soccer_ball" 207 | - "054_softball" 208 | - "055_baseball" 209 | - "056_tennis_ball" 210 | - "057_racquetball" 211 | # This will roll away. 
212 | #- "058_golf_ball" 213 | 214 | - name: "clamp" 215 | excluded_substrings: [] 216 | included_substrings: 217 | - "050_medium_clamp" 218 | - "051_large_clamp" 219 | - "052_extra_large_clamp" 220 | 221 | - name: "hammer" 222 | excluded_substrings: [] 223 | included_substrings: 224 | - "048_hammer" 225 | 226 | - name: "screwdriver" 227 | excluded_substrings: [] 228 | included_substrings: 229 | - "043_phillips_screwdriver" 230 | - "044_flat_screwdriver" 231 | 232 | - name: "wrench" 233 | excluded_substrings: [] 234 | included_substrings: 235 | - "042_adjustable_wrench" 236 | 237 | - name: "padlock" 238 | excluded_substrings: [] 239 | included_substrings: 240 | - "038_padlock" 241 | 242 | - name: "scissors" 243 | excluded_substrings: [] 244 | included_substrings: 245 | - "037_scissors" 246 | 247 | - name: "block" 248 | excluded_substrings: [] 249 | included_substrings: 250 | - "036_wood_block" 251 | - "070-a_colored_wood_blocks" 252 | - "070-b_colored_wood_blocks" 253 | 254 | - name: "drill" 255 | excluded_substrings: [] 256 | included_substrings: 257 | - "035_power_drill" 258 | 259 | - name: "spatula" 260 | excluded_substrings: [] 261 | included_substrings: 262 | - "033_spatula" 263 | 264 | - name: "knife" 265 | excluded_substrings: [] 266 | included_substrings: 267 | - "032_knife" 268 | 269 | - name: "spoon" 270 | excluded_substrings: [] 271 | included_substrings: 272 | - "031_spoon" 273 | 274 | # - name: "fork" 275 | # excluded_substrings: [] 276 | # included_substrings: 277 | # - "030_fork" 278 | 279 | - name: "plate" 280 | excluded_substrings: [] 281 | included_substrings: 282 | - "029_plate" 283 | 284 | - name: "lid" 285 | excluded_substrings: [] 286 | included_substrings: 287 | - "028_skillet_lid" 288 | 289 | - name: "sponge" 290 | excluded_substrings: [] 291 | included_substrings: 292 | - "026_sponge" 293 | 294 | - name: "cleanser" 295 | excluded_substrings: [] 296 | included_substrings: 297 | - "021_bleach_cleanser" 298 | 299 | - name: "plum" 300 | excluded_substrings: [] 301 | included_substrings: 302 | - "018_plum" 303 | 304 | - name: "orange" 305 | excluded_substrings: [] 306 | included_substrings: 307 | - "017_orange" 308 | 309 | - name: "pear" 310 | excluded_substrings: [] 311 | included_substrings: 312 | - "016_pear" 313 | 314 | - name: "peach" 315 | excluded_substrings: [] 316 | included_substrings: 317 | - "015_peach" 318 | 319 | - name: "apple" 320 | excluded_substrings: [] 321 | included_substrings: 322 | - "013_apple" 323 | 324 | - name: "lemon" 325 | excluded_substrings: [] 326 | included_substrings: 327 | - "014_lemon" 328 | 329 | - name: "can" 330 | excluded_substrings: [] 331 | included_substrings: 332 | - "002_master_chef_can" 333 | - "005_tomato_soup_can" 334 | - "007_tuna_fish_can" 335 | - "010_potted_meat_can" 336 | 337 | - name: "box" 338 | excluded_substrings: [] 339 | included_substrings: 340 | - "003_cracker_box" 341 | - "004_sugar_box" 342 | - "008_pudding_box" 343 | - "009_gelatin_box" 344 | 345 | - name: "banana" 346 | excluded_substrings: [] 347 | included_substrings: 348 | - "011_banana" 349 | 350 | - name: "strawberry" 351 | excluded_substrings: [] 352 | included_substrings: 353 | - "012_strawberry" 354 | 355 | - name: "toy airplane" 356 | excluded_substrings: [] 357 | included_substrings: 358 | - "072-a_toy_airplane" 359 | - "072-b_toy_airplane" 360 | - "072-c_toy_airplane" 361 | - "072-d_toy_airplane" 362 | - "072-e_toy_airplane" 363 | 364 | - name: "lego" 365 | excluded_substrings: [] 366 | included_substrings: 367 | - "073-a_lego_duplo" 368 | - 
"073-b_lego_duplo" 369 | - "073-c_lego_duplo" 370 | - "073-d_lego_duplo" 371 | - "073-e_lego_duplo" 372 | - "073-f_lego_duplo" 373 | - "073-g_lego_duplo" 374 | 375 | - name: "rubriks cube" 376 | excluded_substrings: [] 377 | included_substrings: 378 | - "077_rubiks_cube" 379 | 380 | - name: "book" 381 | excluded_substrings: [] 382 | included_substrings: 383 | - "frl_apartment_book_01" 384 | - "frl_apartment_book_02" 385 | - "frl_apartment_book_03" 386 | - "frl_apartment_book_04" 387 | - "frl_apartment_book_05" 388 | - "frl_apartment_book_06" 389 | 390 | - name: "bowl" 391 | excluded_substrings: [] 392 | included_substrings: 393 | - "frl_apartment_bowl_01" 394 | - "frl_apartment_bowl_02" 395 | - "frl_apartment_bowl_03" 396 | - "frl_apartment_bowl_06" 397 | - "frl_apartment_bowl_07" 398 | - "024_bowl" 399 | 400 | - name: "mug" 401 | excluded_substrings: [] 402 | included_substrings: 403 | - "025_mug" 404 | 405 | - name: "cup" 406 | excluded_substrings: [] 407 | included_substrings: 408 | - "065-a_cups" 409 | - "065-b_cups" 410 | - "065-c_cups" 411 | - "065-d_cups" 412 | - "065-e_cups" 413 | - "065-f_cups" 414 | - "065-g_cups" 415 | - "065-h_cups" 416 | - "065-i_cups" 417 | - "065-j_cups" 418 | - "frl_apartment_cup_01" 419 | - "frl_apartment_cup_02" 420 | - "frl_apartment_cup_03" 421 | - "frl_apartment_cup_05" 422 | 423 | receptacle_sets: 424 | # Keep this as the 0th entry. 425 | - 426 | name: "all_receps" 427 | included_object_substrings: [""] 428 | excluded_object_substrings: [] 429 | excluded_receptacle_substrings: [] 430 | included_receptacle_substrings: 431 | - "receptacle_aabb_Chr1_Top1_frl_apartment_chair_01" 432 | - "receptacle_aabb_Tbl1_Top1_frl_apartment_table_01" 433 | - "receptacle_aabb_Tbl2_Top1_frl_apartment_table_02" 434 | - "receptacle_aabb_TvStnd1_Top1_frl_apartment_tvstand" 435 | - "receptacle_aabb_sink_kitchen_counter" 436 | - "receptacle_aabb_counter_right_kitchen_counter" 437 | - "receptacle_aabb_counter_left_kitchen_counter" 438 | - "receptacle_aabb_Sofa_frl_apartment_sofa" 439 | 440 | # Inside of articulated objects. 
441 | - "receptacle_aabb_middle_topfrl_apartment_refrigerator" 442 | - "receptacle_aabb_drawer_left_top_frl_apartment_kitchen_counter" 443 | - "receptacle_aabb_drawer_right_top_frl_apartment_kitchen_counter" 444 | # - "receptacle_aabb_drawer_middle_top_frl_apartment_kitchen_counter" 445 | 446 | - 447 | name: "open_air_receps" 448 | included_object_substrings: 449 | - "" 450 | excluded_object_substrings: [] 451 | excluded_receptacle_substrings: [] 452 | included_receptacle_substrings: 453 | # - "receptacle_aabb_Chr1_Top1_frl_apartment_chair_01" 454 | - "receptacle_aabb_Tbl1_Top1_frl_apartment_table_01" 455 | - "receptacle_aabb_Tbl2_Top1_frl_apartment_table_02" 456 | - "receptacle_aabb_TvStnd1_Top1_frl_apartment_tvstand" 457 | - "receptacle_aabb_sink_kitchen_counter" 458 | - "receptacle_aabb_counter_right_kitchen_counter" 459 | - "receptacle_aabb_counter_left_kitchen_counter" 460 | - "receptacle_aabb_Sofa_frl_apartment_sofa" 461 | 462 | 463 | max_objects_per_receptacle: 464 | - ["receptacle_aabb_Chr1_Top1_frl_apartment_chair_01", 2] 465 | - ["receptacle_aabb_sink_kitchen_counter", 2] 466 | 467 | scene_sampler: 468 | type: "subset" 469 | params: 470 | scene_sets: ["scene_train_split", "scene_val_split", "scene_test_split"] 471 | 472 | object_samplers: 473 | - 474 | name: "CLUTTER" 475 | type: "uniform" 476 | params: 477 | object_sets: ["CLUTTER_OBJECTS"] 478 | receptacle_sets: ["open_air_receps"] 479 | num_samples: [30, 30] 480 | orientation_sampling: "up" 481 | 482 | object_target_samplers: [] 483 | 484 | ao_state_samplers: 485 | - 486 | name: "open_fridge_cab" 487 | type: "composite" 488 | params: 489 | - 490 | ao_handle: "fridge" 491 | joint_states: 492 | - ["top_door", 1.5, 1.5] 493 | should_sample_all_joints: True 494 | - 495 | ao_handle: "counter" 496 | joint_states: 497 | - ["drawer1_top", 0.5, 0.5] 498 | - ["drawer1_bottom", 0.5, 0.5] 499 | - ["drawer2_top", 0.5, 0.5] 500 | - ["drawer2_middle", 0.5, 0.5] 501 | - ["drawer2_bottom", 0.5, 0.5] 502 | - ["drawer3", 0.5, 0.5] 503 | - ["drawer4", 0.5, 0.5] 504 | 505 | markers: 506 | - name: "cab_push_point_7" 507 | type: "articulated_object" 508 | params: 509 | offset: [0.3,0.0,0] 510 | link: "drawer1_top" 511 | object: "kitchen_counter_:0000" 512 | - name: "cab_push_point_6" 513 | type: "articulated_object" 514 | params: 515 | offset: [0.3,0.0,0] 516 | link: "drawer2_top" 517 | object: "kitchen_counter_:0000" 518 | - name: "cab_push_point_5" 519 | type: "articulated_object" 520 | params: 521 | offset: [0.3,0.0,0] 522 | link: "drawer3" 523 | object: "kitchen_counter_:0000" 524 | - name: "cab_push_point_4" 525 | type: "articulated_object" 526 | params: 527 | offset: [0.3,0.0,0] 528 | link: "drawer4" 529 | object: "kitchen_counter_:0000" 530 | - name: "fridge_push_point" 531 | type: "articulated_object" 532 | params: 533 | offset: [0.10,-0.62,0.2] 534 | link: "top_door" 535 | object: "fridge_:0000" 536 | -------------------------------------------------------------------------------- /llarp/dataset/create_episodes.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import gzip 6 | import os 7 | import os.path as osp 8 | import pickle 9 | import random 10 | from collections import defaultdict 11 | from dataclasses import dataclass, field 12 | from threading import Thread 13 | from typing import TYPE_CHECKING, Any, Dict, List 14 | 15 | import numpy as np 16 | import pandas as pd 17 | from habitat.core.logging import logger 18 | from habitat.datasets.rearrange.rearrange_generator import \ 19 | RearrangeEpisodeGenerator 20 | from habitat.datasets.rearrange.run_episode_generator import ( 21 | RearrangeEpisodeGeneratorConfig, get_arg_parser, get_config_defaults, 22 | print_metadata_mediator) 23 | from habitat.datasets.rearrange.samplers.receptacle import \ 24 | get_all_scenedataset_receptacles 25 | from habitat.utils.common import cull_string_list_by_substrings 26 | from omegaconf import OmegaConf 27 | from torch import multiprocessing as mp 28 | 29 | from llarp.dataset.episodes import LangRearrangeDatasetV0 30 | from llarp.dataset.generator import (LangRearrangeEpisodeGenerator, 31 | generate_all_instructions, 32 | get_flat_eps_split) 33 | from llarp.dataset.utils import get_category_info 34 | from llarp.task.utils import get_parser 35 | 36 | 37 | @dataclass 38 | class LangEpisodeGeneratorConfig(RearrangeEpisodeGeneratorConfig): 39 | category_groups: Dict[str, Any] = field(default_factory=dict) 40 | recep_cat_groups: Dict[str, Any] = field(default_factory=dict) 41 | 42 | 43 | def generate_episodes(args, cfg, conn, iter_eps, proc_idx): 44 | with LangRearrangeEpisodeGenerator( 45 | cfg=cfg, 46 | instruct_path=args.instruct_path, 47 | iter_eps=iter_eps, 48 | debug_visualization=args.debug, 49 | limit_scene_set=args.limit_scene_set, 50 | proc_idx=proc_idx, 51 | ) as ep_gen: 52 | if not osp.isdir(args.db_output): 53 | os.makedirs(args.db_output) 54 | ep_gen.vdb.output_path = osp.abspath(args.db_output) 55 | conn.send(ep_gen.generate_episodes(args.num_episodes, args.verbose)) 56 | conn.close() 57 | 58 | 59 | def summarize_episodes(episodes, show_examples=False, tokenizer_name=None): 60 | if tokenizer_name is not None: 61 | tokenizer = get_parser(tokenizer_name) 62 | instructs = [ep.instruction for ep in episodes] 63 | token_lens = [] 64 | for ep in episodes: 65 | # I intentionally didn't batch this to be sure I am matching the 66 | # way the tokenizer is used in the actual simulation. 67 | tokens = tokenizer( 68 | ep.instruction, 69 | return_tensors="np", 70 | )[ 71 | "input_ids" 72 | ][0] 73 | token_lens.append(tokens.shape[0]) 74 | print("max token len ", max(token_lens)) 75 | instructs = defaultdict(int) 76 | instruct_ids = defaultdict(int) 77 | instruct_id_instructs = defaultdict(set) 78 | scene_ids = defaultdict(int) 79 | for ep in episodes: 80 | instructs[ep.instruction] += 1 81 | instruct_ids[ep.instruct_id] += 1 82 | scene_ids[ep.scene_id] += 1 83 | instruct_id_instructs[ep.instruct_id].add(ep.instruction) 84 | if show_examples: 85 | # Show 25 of the instructions. 
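# (Prints up to 25 distinct instructions per instruction template ID, sampled after shuffling.)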
86 | for k, ex_instructs in instruct_id_instructs.items(): 87 | ex_instructs = list(ex_instructs) 88 | np.random.shuffle(ex_instructs) 89 | print(k) 90 | print("\n".join(instruct for instruct in ex_instructs[:25])) 91 | print("") 92 | print(f"Total instructs: {len(episodes)}") 93 | print(f"Num distinct instructs: {len(instructs)}") 94 | print(f"Num instruct IDS: {len(instruct_ids)}") 95 | print(f"Num scene IDS: {len(scene_ids)}") 96 | 97 | for k, v in instruct_ids.items(): 98 | print(f" {k}: {len(instruct_id_instructs[k])} distinct, {v} total") 99 | 100 | return instructs 101 | 102 | 103 | def get_base_info(cfg, args, conn): 104 | print("Starting rearrange gen") 105 | tmp_ep_gen = RearrangeEpisodeGenerator( 106 | cfg=cfg, 107 | debug_visualization=args.debug, 108 | limit_scene_set=args.limit_scene_set, 109 | ) 110 | print("Done rearrange gen") 111 | tmp_ep_gen.sim.close(destroy=True) 112 | del tmp_ep_gen.sim 113 | conn.send((tmp_ep_gen._obj_sets, tmp_ep_gen._receptacle_sets)) 114 | del tmp_ep_gen 115 | conn.close() 116 | 117 | 118 | def start(args): 119 | # Verify the dataset 120 | get_category_info() 121 | 122 | if args.seed is not None: 123 | random.seed(args.seed) 124 | np.random.seed(args.seed) 125 | 126 | cfg = OmegaConf.create(LangEpisodeGeneratorConfig()) 127 | if args.config is not None: 128 | assert osp.exists( 129 | args.config 130 | ), f"Provided config, '{args.config}', does not exist." 131 | override_config = OmegaConf.load(args.config) 132 | cfg = OmegaConf.merge(cfg, override_config) # type: ignore[assignment] 133 | 134 | dataset = LangRearrangeDatasetV0() 135 | if args.procs_per_gpu is None: 136 | # Default to all processes being on the same GPU. 137 | procs_per_gpu = args.n_procs 138 | else: 139 | procs_per_gpu = args.procs_per_gpu 140 | 141 | if args.instruct_lim is not None: 142 | instruct_lim_id = args.instruct_lim.split(",") 143 | else: 144 | instruct_lim_id = None 145 | 146 | mp_ctx = mp.get_context("forkserver") 147 | 148 | parent_conn, child_conn = mp_ctx.Pipe() 149 | # We need to fetch this info in another process because it is creating a new habsim. 150 | proc = mp_ctx.Process( 151 | target=get_base_info, 152 | args=(cfg, args, child_conn), 153 | ) 154 | proc.daemon = True 155 | proc.start() 156 | obj_sets, recep_sets = parent_conn.recv() 157 | proc.join() 158 | 159 | assert args.seed is not None 160 | rng = np.random.RandomState(args.seed) 161 | 162 | # By default, these are the Hydra list type. 
163 | cat_groups = {k: list(v["included"]) for k, v in cfg.category_groups.items()} 164 | recep_sets = { 165 | k: list(v.included_receptacle_substrings) for k, v in recep_sets.items() 166 | } 167 | all_eps, instruct_samples = generate_all_instructions( 168 | args.instruct_path, 169 | obj_sets, 170 | recep_sets, 171 | instruct_lim_id, 172 | args.phrase, 173 | args.tag, 174 | args.obj_sample_split, 175 | cat_groups, 176 | args.per_instruct_count_lim, 177 | rng, 178 | ) 179 | rng.shuffle(instruct_samples) 180 | ep_keys = sorted(list(all_eps.keys())) 181 | for k in ep_keys: 182 | rng.shuffle(all_eps[k]) 183 | 184 | proc_infos = [] 185 | to_gen_distinct_instructs = defaultdict(lambda: [set(), 0]) 186 | 187 | for i in range(args.n_procs): 188 | use_cfg = cfg.copy() 189 | use_cfg.gpu_device_id = i // procs_per_gpu 190 | parent_conn, child_conn = mp_ctx.Pipe() 191 | iter_eps = get_flat_eps_split( 192 | all_eps, 193 | i, 194 | args.total_take, 195 | args.cur_gen_idx, 196 | args.n_procs, 197 | args.num_episodes, 198 | instruct_samples, 199 | ) 200 | for ep in iter_eps: 201 | to_gen_distinct_instructs[ep.instruct_info.instruct_id][0].add(ep.instruct) 202 | to_gen_distinct_instructs[ep.instruct_info.instruct_id][1] += 1 203 | 204 | if args.proc_debug: 205 | p = Thread( 206 | target=generate_episodes, 207 | args=(args, use_cfg, child_conn, iter_eps, i), 208 | ) 209 | else: 210 | p = mp_ctx.Process( 211 | target=generate_episodes, 212 | args=(args, use_cfg, child_conn, iter_eps, i), 213 | ) 214 | print(f"Starting worker {i}") 215 | p.start() 216 | proc_infos.append((parent_conn, p)) 217 | 218 | total_distinct = sum(len(x[0]) for x in to_gen_distinct_instructs.values()) 219 | total_instructs = sum(x[1] for x in to_gen_distinct_instructs.values()) 220 | print(f"Plan for generation: {total_distinct} distinct, {total_instructs} total.") 221 | for k, v in to_gen_distinct_instructs.items(): 222 | print(f" {k}: {len(v[0])} distinct, {v[1]} total") 223 | print() 224 | 225 | not_collected = list(range((len(proc_infos)))) 226 | for i, (conn, proc) in enumerate(proc_infos): 227 | try: 228 | result = conn.recv() 229 | except EOFError as e: 230 | logger.warning(f"Problem in worker {i}. Could not generate any episodes.") 231 | continue 232 | proc.join() 233 | if result is None: 234 | print("Result is none skipping") 235 | continue 236 | n_eps_collected = len(result) 237 | not_collected.pop(not_collected.index(i)) 238 | print( 239 | f"Collected {n_eps_collected} episodes from worker {i}. Waiting for {not_collected}" 240 | ) 241 | dataset.episodes.extend(result) 242 | if n_eps_collected != args.num_episodes: 243 | logger.warning( 244 | f"Problem collecting episodes from worker {i}. 
Expected {args.num_episodes}, got {n_eps_collected}" 245 | ) 246 | print("Done extending episodes list.") 247 | print("Summarizing the episodes") 248 | summarize_episodes(dataset.episodes) 249 | 250 | output_path = args.out 251 | if not osp.exists(osp.dirname(output_path)) and len(osp.dirname(output_path)) > 0: 252 | os.makedirs(osp.dirname(output_path)) 253 | with open(output_path, "wb") as f: 254 | pickle.dump(dataset.to_binary(), f) 255 | 256 | logger.warning("==============================================================") 257 | logger.warning(f"RearrangeDatasetV0 saved to '{osp.abspath(output_path)}'") 258 | logger.warning("==============================================================") 259 | 260 | 261 | if __name__ == "__main__": 262 | parser = get_arg_parser() 263 | parser.add_argument("--instruct-path", required=True, type=str) 264 | parser.add_argument("--get-instructs", action="store_true") 265 | parser.add_argument( 266 | "--instruct-lim", 267 | type=str, 268 | default=None, 269 | help="Comma separated unique instruction IDs to limit sampling to", 270 | ) 271 | parser.add_argument("--tag", type=str, default=None, required=True) 272 | parser.add_argument("--phrase", type=str, default="train") 273 | parser.add_argument( 274 | "--obj-sample-split", 275 | type=str, 276 | default="train", 277 | help="Either 'train' or 'eval'", 278 | ) 279 | parser.add_argument("--n-procs", type=int, default=1) 280 | parser.add_argument("--procs-per-gpu", type=int, default=None) 281 | parser.add_argument("--cur-gen-idx", type=int, default=0) 282 | parser.add_argument("--total-take", type=int, default=1) 283 | parser.add_argument( 284 | "--per-instruct-count-lim", 285 | type=int, 286 | default=200_000, 287 | help="The maximum number of instructions allowed per instruction type.", 288 | ) 289 | parser.add_argument("--proc-debug", action="store_true") 290 | parser.add_argument( 291 | "--instruct-dir", default="interactive_and_embodied/projects/llarp/instructs" 292 | ) 293 | args, _ = parser.parse_known_args() 294 | 295 | start(args) 296 | -------------------------------------------------------------------------------- /llarp/dataset/demo_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import os.path as osp 7 | from collections import defaultdict 8 | from dataclasses import asdict, dataclass 9 | from typing import Any, Dict, List, Optional 10 | 11 | import numpy as np 12 | import torch 13 | from habitat.core.logging import logger 14 | from torch.nn.utils.rnn import pad_sequence 15 | from torch.utils.data import DataLoader, TensorDataset 16 | from tqdm import tqdm 17 | 18 | from llarp.dataset.utils import LOCAL_DATASETS_PATH, get_name_mappings 19 | 20 | 21 | @dataclass 22 | class TrajInfo: 23 | """ 24 | Let `H` be the length of the trajectory (meaning the number of actions the 25 | agent takes). 26 | 27 | :property actions: Tensor of shape [H, action_dim] 28 | :property obs: Dictionary of shape {k: [H, obs_dim...]}, where entry `t` is 29 | the observation at which the agent took action `actions[t]`. 
30 | """ 31 | 32 | actions: torch.Tensor 33 | obs: Dict[str, torch.Tensor] 34 | infos: Dict[str, Any] 35 | instruction: Optional[str] = None 36 | 37 | 38 | @dataclass 39 | class DemoDataset: 40 | trajs: List[TrajInfo] 41 | 42 | def save(self, filepath): 43 | with open(filepath, "wb") as f: 44 | torch.save(asdict(self), f) 45 | 46 | def __len__(self): 47 | return sum(x.actions.shape[0] for x in self.trajs) 48 | 49 | @staticmethod 50 | def load(filepath): 51 | # We don't have space to move these to GPU. They will be moved to GPU 52 | # on demand. 53 | demos = torch.load(filepath, map_location="cpu") 54 | 55 | trajs = [ 56 | TrajInfo( 57 | actions=d["actions"], 58 | obs=d["obs"], 59 | infos=d["infos"], 60 | instruction=d.get("instruction", None), 61 | ) 62 | for d in demos["trajs"] 63 | ] 64 | logger.info(f"Loaded demo dataset from {filepath}") 65 | np.random.shuffle(trajs) 66 | return DemoDataset( 67 | trajs=trajs, 68 | ) 69 | 70 | 71 | def cat_demo_datasets(demos: List[DemoDataset]) -> DemoDataset: 72 | trajs = [] 73 | for demo in demos: 74 | trajs.extend(demo.trajs) 75 | 76 | return DemoDataset( 77 | trajs=trajs, 78 | ) 79 | 80 | 81 | class DatasetCollector: 82 | def __init__(self, n_envs, flush_interval, dataset_name): 83 | self._n_envs = n_envs 84 | self._cur_traj_actions = [[] for _ in range(n_envs)] 85 | self._cur_traj_infos = [defaultdict(list) for _ in range(n_envs)] 86 | self._cur_traj_obs = [defaultdict(list) for _ in range(n_envs)] 87 | self._all_trajs = [] 88 | self._flush_interval = flush_interval 89 | self._dataset_name = dataset_name 90 | self._save_dir = osp.join(LOCAL_DATASETS_PATH, dataset_name) 91 | self._cur_block = 0 92 | 93 | def collect_action(self, action, infos): 94 | for env_i in range(self._n_envs): 95 | # Only save primitive types. Otherwise serialization will get messed up. 96 | filtered_infos = { 97 | k: v 98 | for k, v in infos[env_i].items() 99 | if isinstance(v, (int, float, str)) 100 | } 101 | self._cur_traj_actions[env_i].append(action[env_i]) 102 | for k, v in filtered_infos.items(): 103 | self._cur_traj_infos[env_i][k].append(v) 104 | 105 | def collect_obs(self, obs): 106 | """ 107 | Add observation to all workers. 108 | """ 109 | 110 | for env_i in range(self._n_envs): 111 | for k, obs_k in obs.items(): 112 | self._cur_traj_obs[env_i][k].append(obs_k[env_i].cpu()) 113 | 114 | def on_ep_done(self, env_i, env_infos, instruction: Optional[str] = None): 115 | # Only save the episode if it was successful. 116 | if env_infos["env_success"]: 117 | traj_info = TrajInfo( 118 | actions=torch.stack(self._cur_traj_actions[env_i], dim=0), 119 | obs={ 120 | k: torch.stack(v, dim=0) 121 | for k, v in self._cur_traj_obs[env_i].items() 122 | }, 123 | infos=dict(self._cur_traj_infos[env_i]), 124 | instruction=instruction, 125 | ) 126 | for k, v in traj_info.obs.items(): 127 | assert ( 128 | v.shape[0] == traj_info.actions.shape[0] 129 | ), f"Traj info shapes don't match {k}: {v.shape}, {traj_info.actions.shape}" 130 | self._all_trajs.append(traj_info) 131 | 132 | n_trajs = len(self._all_trajs) 133 | if n_trajs != 0 and n_trajs % self._flush_interval == 0: 134 | self.flush() 135 | 136 | # Restart demo collection for this worker. 137 | self._cur_traj_obs[env_i] = defaultdict(list) 138 | self._cur_traj_actions[env_i] = [] 139 | self._cur_traj_infos[env_i] = defaultdict(list) 140 | 141 | def flush(self): 142 | if len(self._all_trajs) == 0: 143 | # Nothing to write. 
144 | return 145 | os.makedirs(self._save_dir, exist_ok=True) 146 | block_name = f"block{self._cur_block}.pt" 147 | save_path = osp.join(self._save_dir, block_name) 148 | dataset = DemoDataset(trajs=self._all_trajs) 149 | dataset.save(save_path) 150 | logger.info( 151 | f"Saved {len(dataset)} transitions ({len(self._all_trajs)} trajs) to {save_path}" 152 | ) 153 | self._all_trajs = [] 154 | self._cur_block += 1 155 | -------------------------------------------------------------------------------- /llarp/dataset/episodes.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import json 6 | import pickle 7 | import random 8 | from itertools import groupby 9 | from typing import Any, Dict, List, Optional 10 | 11 | import attr 12 | import numpy as np 13 | from habitat.core.dataset import EpisodeIterator 14 | from habitat.core.logging import logger 15 | from habitat.core.registry import registry 16 | from habitat.core.utils import DatasetFloatJSONEncoder 17 | from habitat.datasets.rearrange.rearrange_dataset import (RearrangeDatasetV0, 18 | RearrangeEpisode) 19 | from habitat.datasets.utils import check_and_gen_physics_config 20 | from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate 21 | 22 | 23 | @attr.s(auto_attribs=True, kw_only=True) 24 | class LangRearrangeEpisode(RearrangeEpisode): 25 | instruction: str 26 | sampled_entities: Dict[str, str] 27 | # str form of predicates 28 | start_preds: List[str] 29 | goal_preds: Dict[str, Any] 30 | instruct_id: str 31 | sampler_info: Dict[str, str] 32 | # str form of predicates 33 | subgoals: List[List[str]] = None 34 | 35 | 36 | @registry.register_dataset(name="LangRearrangeDataset-v0") 37 | class LangRearrangeDatasetV0(RearrangeDatasetV0): 38 | def __init__(self, config=None, preset_eps=None) -> None: 39 | self.config = config 40 | 41 | check_and_gen_physics_config() 42 | 43 | self.episodes = [] 44 | 45 | if config is None: 46 | return 47 | 48 | if preset_eps is None: 49 | datasetfile_path = config.data_path.format(split=config.split) 50 | logger.info(f"Loading from {datasetfile_path}") 51 | with open(datasetfile_path, "rb") as f: 52 | self.from_binary(pickle.load(f), scenes_dir=config.scenes_dir) 53 | 54 | self.episodes = list( 55 | filter(self.build_content_scenes_filter(config), self.episodes) 56 | ) 57 | else: 58 | self.episodes = preset_eps 59 | 60 | def to_json(self) -> str: 61 | result = DatasetFloatJSONEncoder().encode(self) 62 | return result 63 | 64 | def to_binary(self) -> str: 65 | def access_idx(k, name_to_idx): 66 | if len(name_to_idx) == 0: 67 | name_to_idx[k] = 0 68 | if k not in name_to_idx: 69 | name_to_idx[k] = max(name_to_idx.values()) + 1 70 | return name_to_idx[k] 71 | 72 | def encode_name_dict(d, name_to_idx): 73 | ret_d = {} 74 | for k, v in d.items(): 75 | ret_d[access_idx(k, name_to_idx)] = v 76 | return ret_d 77 | 78 | all_transforms = [] 79 | name_to_idx = {} 80 | all_eps = [] 81 | 82 | for ep in self.episodes: 83 | new_ep_data = attr.asdict(ep) 84 | rigid_objs = [] 85 | for name, T in ep.rigid_objs: 86 | rigid_objs.append([access_idx(name, name_to_idx), len(all_transforms)]) 87 | all_transforms.append(T) 88 | 89 | name_to_recep = [] 90 | for name, recep in ep.name_to_receptacle.items(): 91 | name_to_recep.append( 92 | [access_idx(name, name_to_idx), access_idx(recep, name_to_idx)] 93 | ) 94 | new_ep_data["rigid_objs"] = np.array(rigid_objs) 95 | 
new_ep_data["ao_states"] = encode_name_dict(ep.ao_states, name_to_idx) 96 | new_ep_data["name_to_receptacle"] = np.array(name_to_recep) 97 | new_ep_data["additional_obj_config_paths"] = list( 98 | new_ep_data["additional_obj_config_paths"] 99 | ) 100 | del new_ep_data["_shortest_path_cache"] 101 | 102 | new_markers = [] 103 | for marker_data in ep.markers: 104 | new_markers.append( 105 | [ 106 | access_idx(marker_data["name"], name_to_idx), 107 | access_idx(marker_data["type"], name_to_idx), 108 | np.array(marker_data["params"]["offset"]), 109 | access_idx(marker_data["params"]["link"], name_to_idx), 110 | access_idx(marker_data["params"]["object"], name_to_idx), 111 | ] 112 | ) 113 | 114 | new_ep_data["markers"] = new_markers 115 | 116 | all_eps.append(new_ep_data) 117 | 118 | idx_to_name = {} 119 | for k, v in name_to_idx.items(): 120 | # idx_to_name should define a 1-1 mapping between the name and the 121 | # name index. 122 | assert v not in idx_to_name 123 | idx_to_name[v] = k 124 | 125 | return { 126 | "all_transforms": np.array(all_transforms), 127 | "idx_to_name": idx_to_name, 128 | "all_eps": all_eps, 129 | } 130 | 131 | def from_binary( 132 | self, data_dict: Dict[str, Any], scenes_dir: Optional[str] = None 133 | ) -> None: 134 | all_T = data_dict["all_transforms"] 135 | idx_to_name = data_dict["idx_to_name"] 136 | for i, ep in enumerate(data_dict["all_eps"]): 137 | ep["rigid_objs"] = [ 138 | [idx_to_name[ni], all_T[ti]] for ni, ti in ep["rigid_objs"] 139 | ] 140 | ep["ao_states"] = {idx_to_name[ni]: v for ni, v in ep["ao_states"].items()} 141 | ep["name_to_receptacle"] = { 142 | idx_to_name[k]: idx_to_name[v] for k, v in ep["name_to_receptacle"] 143 | } 144 | 145 | new_markers = [] 146 | for name, mtype, offset, link, obj in ep["markers"]: 147 | new_markers.append( 148 | { 149 | "name": idx_to_name[name], 150 | "type": idx_to_name[mtype], 151 | "params": { 152 | "offset": offset, 153 | "link": idx_to_name[link], 154 | "object": idx_to_name[obj], 155 | }, 156 | } 157 | ) 158 | ep["markers"] = new_markers 159 | 160 | rearrangement_episode = LangRearrangeEpisode(**ep) 161 | rearrangement_episode.episode_id = str(i) 162 | self.episodes.append(rearrangement_episode) 163 | 164 | def from_json(self, json_str: str, scenes_dir: Optional[str] = None) -> None: 165 | deserialized = json.loads(json_str) 166 | 167 | for i, episode in enumerate(deserialized["episodes"]): 168 | rearrangement_episode = LangRearrangeEpisode(**episode) 169 | rearrangement_episode.episode_id = str(i) 170 | 171 | self.episodes.append(rearrangement_episode) 172 | 173 | def get_episode_iterator(self, *args, **kwargs): 174 | return CustomEpisodeIterator(self.episodes, *args, **kwargs) 175 | 176 | 177 | class CustomEpisodeIterator(EpisodeIterator): 178 | def __init__( 179 | self, 180 | episodes, 181 | cycle: bool = True, 182 | shuffle: bool = False, 183 | group_by_scene: bool = True, 184 | max_scene_repeat_episodes: int = -1, 185 | max_scene_repeat_steps: int = -1, 186 | num_episode_sample: int = -1, 187 | step_repetition_range: float = 0.2, 188 | seed: int = None, 189 | ) -> None: 190 | if seed: 191 | random.seed(seed) 192 | np.random.seed(seed) 193 | 194 | # sample episodes 195 | if num_episode_sample >= 0: 196 | episodes = np.random.choice( # type: ignore[assignment] 197 | episodes, num_episode_sample, replace=False # type: ignore[arg-type] 198 | ) 199 | 200 | if not isinstance(episodes, list): 201 | episodes = list(episodes) 202 | 203 | self.episodes = episodes 204 | self.cycle = cycle 205 | self.group_by_scene = 
group_by_scene 206 | self.shuffle = shuffle 207 | 208 | if shuffle: 209 | random.shuffle(self.episodes) 210 | 211 | if group_by_scene: 212 | self.episodes = self._group_scenes(self.episodes) 213 | 214 | self.max_scene_repetition_episodes = max_scene_repeat_episodes 215 | self.max_scene_repetition_steps = max_scene_repeat_steps 216 | 217 | self._rep_count = -1 # 0 corresponds to first episode already returned 218 | self._step_count = 0 219 | self._prev_scene_id: Optional[str] = None 220 | 221 | self._iterator = iter(self.episodes) 222 | 223 | self.step_repetition_range = step_repetition_range 224 | self._set_shuffle_intervals() 225 | 226 | def __iter__(self): 227 | return self 228 | 229 | def __next__(self): 230 | self._forced_scene_switch_if() 231 | next_episode = next(self._iterator, None) 232 | if next_episode is None: 233 | if not self.cycle: 234 | raise StopIteration 235 | 236 | self._iterator = iter(self.episodes) 237 | 238 | if self.shuffle: 239 | self._shuffle() 240 | 241 | next_episode = next(self._iterator) 242 | 243 | if ( 244 | self._prev_scene_id != next_episode.scene_id 245 | and self._prev_scene_id is not None 246 | ): 247 | self._rep_count = 0 248 | self._step_count = 0 249 | 250 | self._prev_scene_id = next_episode.scene_id 251 | return next_episode 252 | 253 | def _forced_scene_switch(self) -> None: 254 | grouped_episodes = [ 255 | list(g) for k, g in groupby(self._iterator, key=lambda x: x.scene_id) 256 | ] 257 | 258 | if len(grouped_episodes) > 1: 259 | # Ensure we swap by moving the current group to the end 260 | grouped_episodes = grouped_episodes[1:] + grouped_episodes[0:1] 261 | 262 | self._iterator = iter(sum(grouped_episodes, [])) 263 | 264 | def _shuffle(self) -> None: 265 | assert self.shuffle 266 | episodes = list(self._iterator) 267 | 268 | random.shuffle(episodes) 269 | 270 | if self.group_by_scene: 271 | episodes = self._group_scenes(episodes) 272 | 273 | self._iterator = iter(episodes) 274 | 275 | def _group_scenes(self, episodes): 276 | assert self.group_by_scene 277 | 278 | scene_sort_keys: Dict[str, int] = {} 279 | for e in episodes: 280 | if e.scene_id not in scene_sort_keys: 281 | scene_sort_keys[e.scene_id] = len(scene_sort_keys) 282 | 283 | return sorted(episodes, key=lambda e: scene_sort_keys[e.scene_id]) 284 | 285 | def step_taken(self) -> None: 286 | self._step_count += 1 287 | 288 | @staticmethod 289 | def _randomize_value(value: int, value_range: float) -> int: 290 | return random.randint( 291 | int(value * (1 - value_range)), int(value * (1 + value_range)) 292 | ) 293 | 294 | def _set_shuffle_intervals(self) -> None: 295 | if self.max_scene_repetition_episodes > 0: 296 | self._max_rep_episode = self.max_scene_repetition_episodes 297 | else: 298 | self._max_rep_episode = None 299 | 300 | if self.max_scene_repetition_steps > 0: 301 | self._max_rep_step = self._randomize_value( 302 | self.max_scene_repetition_steps, self.step_repetition_range 303 | ) 304 | else: 305 | self._max_rep_step = None 306 | 307 | def _forced_scene_switch_if(self) -> None: 308 | do_switch = False 309 | self._rep_count += 1 310 | 311 | # Shuffle if a scene has been selected more than _max_rep_episode times in a row 312 | if ( 313 | self._max_rep_episode is not None 314 | and self._rep_count >= self._max_rep_episode 315 | ): 316 | do_switch = True 317 | 318 | # Shuffle if a scene has been used for more than _max_rep_step steps in a row 319 | if self._max_rep_step is not None and self._step_count >= self._max_rep_step: 320 | do_switch = True 321 | 322 | if do_switch: 323 | 
self._forced_scene_switch() 324 | self._set_shuffle_intervals() 325 | -------------------------------------------------------------------------------- /llarp/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import inspect 6 | import os 7 | import os.path as osp 8 | from typing import Dict, List, Tuple 9 | 10 | import yaml 11 | 12 | import llarp.dataset 13 | 14 | ALL_CATS_NAME = "all_cats" 15 | INSTRUCTION_FILE = "instructions.yaml" 16 | 17 | # Standard buckets. 18 | LOCAL_DATASETS_PATH = "data/datasets" 19 | 20 | 21 | def get_instruct_data(): 22 | instructs_path = osp.dirname(inspect.getfile(llarp.dataset)) 23 | instructs_cfg = osp.join(instructs_path, "configs", INSTRUCTION_FILE) 24 | with open(instructs_cfg, "r") as f: 25 | instructs = yaml.load(f, Loader=yaml.FullLoader) 26 | return instructs 27 | 28 | 29 | def get_name_mappings() -> Dict[str, str]: 30 | """ 31 | Gets the friendly name mappings from the instruction file. 32 | """ 33 | instructs = get_instruct_data() 34 | return instructs["name_mappings"] 35 | 36 | 37 | def get_all_instruct_ids(): 38 | instructs = get_instruct_data() 39 | return sorted([str(x) for x in instructs["instructions"].keys()]) 40 | 41 | 42 | def get_category_info(skip_load_receps=False): 43 | """ 44 | Get the list of all categories and a mapping from object name to category. 45 | """ 46 | dataset_path = osp.dirname(inspect.getfile(llarp.dataset)) 47 | dataset_cfg = osp.join(dataset_path, "configs", "dataset.yaml") 48 | 49 | # Load dataset_cfg as a dict 50 | with open(dataset_cfg, "r") as f: 51 | dataset = yaml.load(f, Loader=yaml.FullLoader) 52 | cat_groups = dataset["category_groups"] 53 | all_receps_cat = dataset["receptacle_sets"][0] 54 | assert all_receps_cat["name"] == "all_receps" 55 | all_obj_cats = dataset["category_groups"][ALL_CATS_NAME]["included"] 56 | 57 | all_cats = [] 58 | if not skip_load_receps: 59 | all_cats.extend(all_receps_cat["included_receptacle_substrings"]) 60 | all_cats.extend(all_obj_cats) 61 | 62 | obj_to_cls = {} 63 | for oset in dataset["object_sets"]: 64 | if oset["name"] == "CLUTTER_OBJECTS": 65 | continue 66 | for oname in oset["included_substrings"]: 67 | if oname in obj_to_cls: 68 | raise ValueError(f"Object {oname} is in multiple sets") 69 | obj_to_cls[oname] = oset["name"] 70 | return all_cats, all_obj_cats, obj_to_cls 71 | -------------------------------------------------------------------------------- /llarp/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from llarp.policies.action_decoders import * 6 | from llarp.policies.llm_policy import * 7 | from llarp.policies.transformer_storage import * 8 | from llarp.policies.vis_bridge import * 9 | -------------------------------------------------------------------------------- /llarp/policies/action_decoders.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import abc 6 | from typing import Optional, Tuple 7 | 8 | import gym.spaces as spaces 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange 12 | from habitat_baselines.utils.common import CustomFixedCategorical, CustomNormal 13 | from transformers import BertConfig, BertModel 14 | from transformers.models.bert.modeling_bert import BertEncoder, BertPooler 15 | 16 | from llarp.policies.utils import LlmWrapper 17 | 18 | MAX_N_ACTIONS = 10 19 | 20 | 21 | class FixedCategorical(torch.distributions.Categorical): 22 | def sample(self, sample_shape=torch.Size()): 23 | return super().sample(sample_shape).unsqueeze(-1) 24 | 25 | def log_probs(self, actions): 26 | # return super().log_prob(actions) 27 | if actions.dim() == 2: 28 | return super().log_prob(actions.squeeze(-1)).view(actions.size(0), -1) 29 | elif actions.dim() == 3: 30 | return super().log_prob(actions) 31 | else: 32 | raise ValueError(f"Unrecognized actions shape {actions.shape}") 33 | 34 | def mode(self): 35 | return self.probs.argmax(dim=-1, keepdim=True) 36 | 37 | def entropy(self): 38 | return super().entropy().unsqueeze(-1) 39 | 40 | 41 | class ActionDecoder(nn.Module, abc.ABC): 42 | @abc.abstractmethod 43 | def forward( 44 | self, hidden_state, obs, embed_obs 45 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: 46 | pass 47 | 48 | @property 49 | @abc.abstractmethod 50 | def hidden_size(self) -> int: 51 | pass 52 | 53 | @abc.abstractmethod 54 | def embed_action(self, action, llm): 55 | pass 56 | 57 | def get_distrib(self, logits, **kwargs): 58 | return FixedCategorical(logits=logits.float(), validate_args=False) 59 | 60 | 61 | class MlpDecoder(ActionDecoder): 62 | def __init__( 63 | self, 64 | *args, 65 | input_dim: int, 66 | output_dim: int, 67 | hidden_size: int, 68 | action_space, 69 | use_b16: bool = False, 70 | min_log_std: int = -5, 71 | max_log_std: int = 2, 72 | log_std_init: float = 0.0, 73 | **kwargs, 74 | ): 75 | super().__init__() 76 | self.proj = nn.Sequential( 77 | nn.Linear(input_dim, hidden_size), 78 | nn.LayerNorm(hidden_size), 79 | nn.ReLU(True), 80 | nn.Linear(hidden_size, hidden_size), 81 | nn.LayerNorm(hidden_size), 82 | nn.ReLU(True), 83 | ) 84 | self._is_cont_action = isinstance(action_space, spaces.Box) 85 | 86 | self.linear = nn.Linear(hidden_size, output_dim) 87 | self._hidden_size = hidden_size 88 | 89 | nn.init.orthogonal_(self.linear.weight, gain=0.01) 90 | nn.init.constant_(self.linear.bias, 0) 91 | 92 | self._min_log_std = min_log_std 93 | self._max_log_std = max_log_std 94 | if self._is_cont_action: 95 | self.log_std = torch.nn.parameter.Parameter( 96 | torch.randn(output_dim) * 0.01 + log_std_init 97 | ) 98 | # Project to embedding of continuous action. 99 | self.action_embed = nn.Linear(output_dim, input_dim) 100 | else: 101 | # Embedding for each option. 
102 | self.action_embed = nn.Embedding( 103 | num_embeddings=output_dim, embedding_dim=input_dim 104 | ) 105 | 106 | if use_b16: 107 | self.to(torch.bfloat16) 108 | 109 | @property 110 | def hidden_size(self) -> int: 111 | return self._hidden_size 112 | 113 | def forward(self, hidden_state, obs, embed_obs): 114 | hidden_state = self.proj(hidden_state) 115 | return self.linear(hidden_state), None, hidden_state 116 | 117 | def embed_action(self, action, llm): 118 | if self._is_cont_action: 119 | # Add sequence dim 120 | return self.action_embed(action.unsqueeze(-2)) 121 | else: 122 | return self.action_embed(action) 123 | 124 | def get_distrib(self, logits, **kwargs): 125 | if self._is_cont_action: 126 | if logits.dim() == 4: 127 | logits = rearrange(logits, "b n 1 d -> b n d") 128 | log_std = self.log_std 129 | log_std = torch.clamp(log_std, self._min_log_std, self._max_log_std) 130 | std = torch.exp(log_std) 131 | return CustomNormal(logits.float(), std.float(), validate_args=False) 132 | else: 133 | return FixedCategorical(logits=logits.float(), validate_args=False) 134 | -------------------------------------------------------------------------------- /llarp/policies/cores/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from llarp.policies.cores.base_core import PolicyCore 6 | from llarp.policies.cores.decoder import CleanLLMPolicyCore 7 | -------------------------------------------------------------------------------- /llarp/policies/cores/base_core.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import abc 6 | 7 | import gym.spaces as spaces 8 | import torch.nn as nn 9 | 10 | 11 | class PolicyCore(nn.Module, abc.ABC): 12 | def __init__(self, obs_space, config): 13 | super().__init__() 14 | self._im_obs_space = spaces.Dict( 15 | { 16 | k: v 17 | for k, v in obs_space.items() 18 | if len(v.shape) == 3 and k not in ["third_rgb", "top_down_rgb"] 19 | } 20 | ) 21 | 22 | self._state_obs_space = spaces.Dict( 23 | {k: v for k, v in obs_space.items() if len(v.shape) == 1} 24 | ) 25 | self._config = config 26 | self._is_blind = len(self._im_obs_space) == 0 27 | self._prefix_tokens_obs_k = config.prefix_tokens_obs_k 28 | 29 | @property 30 | @abc.abstractmethod 31 | def rnn_hidden_dim(self): 32 | pass 33 | 34 | @property 35 | def visual_encoder(self): 36 | return None 37 | 38 | @property 39 | def hidden_window_dim(self): 40 | return 512 41 | 42 | @abc.abstractmethod 43 | def get_num_rnn_layers(self): 44 | pass 45 | 46 | @abc.abstractmethod 47 | def get_trainable_params(self): 48 | pass 49 | -------------------------------------------------------------------------------- /llarp/policies/hl_policy.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import logging 6 | from itertools import chain 7 | from typing import Any, Dict, List 8 | 9 | import einops 10 | import gym.spaces as spaces 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | from habitat import logger 15 | from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction 16 | from habitat_baselines.common.baseline_registry import baseline_registry 17 | from habitat_baselines.common.logging import baselines_logger 18 | from habitat_baselines.rl.hrl.hierarchical_policy import HierarchicalPolicy 19 | from habitat_baselines.rl.hrl.hl.fixed_policy import FixedHighLevelPolicy 20 | from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy 21 | from habitat_baselines.rl.hrl.skills import NoopSkillPolicy # noqa: F401. 22 | from habitat_baselines.rl.hrl.skills import ResetArmSkill, WaitSkillPolicy 23 | from habitat_baselines.rl.ppo.policy import CriticHead, PolicyActionData 24 | from habitat_baselines.utils.common import (CategoricalNet, 25 | CustomFixedCategorical) 26 | from habitat_baselines.utils.timing import g_timer 27 | 28 | from llarp.dataset.utils import get_category_info 29 | from llarp.policies.cores import * 30 | from llarp.task.utils import get_allowed_actions, get_pddl 31 | 32 | 33 | def _map_state_dict(state_dict, rename_ckpt_keys): 34 | return { 35 | _map_state_dict_k(state_dict_k, rename_ckpt_keys): v 36 | for state_dict_k, v in state_dict.items() 37 | } 38 | 39 | 40 | def _map_state_dict_k(state_dict_k, rename_ckpt_keys): 41 | for orig_k, replace_k in rename_ckpt_keys.items(): 42 | state_dict_k = state_dict_k.replace(orig_k, replace_k) 43 | return state_dict_k 44 | 45 | 46 | @baseline_registry.register_policy 47 | class CatHierarchicalPolicy(HierarchicalPolicy): 48 | def __init__( 49 | self, 50 | config, 51 | full_config, 52 | observation_space, 53 | action_space, 54 | orig_action_space, 55 | num_envs, 56 | aux_loss_config, 57 | agent_name, 58 | ): 59 | super().__init__( 60 | config, 61 | full_config, 62 | observation_space, 63 | action_space, 64 | orig_action_space, 65 | num_envs, 66 | aux_loss_config, 67 | agent_name, 68 | ) 69 | self._strict_loading = ( 70 | config.hierarchical_policy.high_level_policy.strict_loading 71 | ) 72 | 73 | pretrain_path = config.hierarchical_policy.high_level_policy.pretrain_ckpt_path 74 | if pretrain_path is not None: 75 | logger.info(f"Loading {pretrain_path}") 76 | # Load in pre-trained checkpoint. 
77 | ckpt = torch.load( 78 | pretrain_path, 79 | map_location="cpu", 80 | ) 81 | self.load_state_dict(ckpt["state_dict"]) 82 | 83 | @property 84 | def policy_action_space(self): 85 | return self._high_level_policy.policy_action_space 86 | 87 | @property 88 | def recurrent_hidden_size(self) -> int: 89 | if isinstance(self._high_level_policy, FixedHighLevelPolicy): 90 | # We are using a non-learning based policy 91 | return 0 92 | else: 93 | return self._high_level_policy.recurrent_hidden_size 94 | 95 | def load_state_dict(self, state_dict, **kwargs): 96 | kwargs["strict"] = self._strict_loading 97 | state_dict = _map_state_dict( 98 | state_dict, 99 | { 100 | "_high_level_policy._policy_core.vis_bridge.": "_high_level_policy._policy_core.vis_bridge_net.", 101 | "_high_level_policy._policy_core.action_proj.": "_high_level_policy._policy_core.action_proj_net.", 102 | }, 103 | ) 104 | return super().load_state_dict(state_dict, **kwargs) 105 | 106 | @property 107 | def visual_encoder(self): 108 | return self._high_level_policy.visual_encoder 109 | 110 | def _create_pddl(self, full_config, config): 111 | all_cats, obj_cats, _ = get_category_info() 112 | return get_pddl(full_config.habitat.task, all_cats, obj_cats) 113 | 114 | def _get_hl_policy_cls(self, config): 115 | return eval(config.hierarchical_policy.high_level_policy.name) 116 | 117 | def get_extra(self, action_data, infos, dones) -> List[Dict[str, float]]: 118 | extra = super().get_extra(action_data, infos, dones) 119 | # Also add the value information. 120 | if action_data is not None and action_data.values is not None: 121 | for i, val_pred in enumerate(action_data.values): 122 | extra[i]["value"] = val_pred.item() 123 | return extra 124 | 125 | def _create_skills(self, skills, observation_space, action_space, full_config): 126 | for action in self._pddl.get_ordered_actions(): 127 | if "pick" in action.name and action.name not in skills: 128 | # Duplicate the pick skill config. 129 | skills[action.name] = skills["pick"] 130 | skill_i = 0 131 | for ( 132 | skill_name, 133 | skill_config, 134 | ) in skills.items(): 135 | cls = eval(skill_config.skill_name) 136 | skill_policy = cls.from_config( 137 | skill_config, 138 | observation_space, 139 | action_space, 140 | self._num_envs, 141 | full_config, 142 | ) 143 | skill_policy.set_pddl_problem(self._pddl) 144 | if skill_config.pddl_action_names is None: 145 | action_names = [skill_name] 146 | else: 147 | action_names = skill_config.pddl_action_names 148 | for skill_id in action_names: 149 | self._name_to_idx[skill_id] = skill_i 150 | self._idx_to_name[skill_i] = skill_id 151 | self._skills[skill_i] = skill_policy 152 | skill_i += 1 153 | 154 | 155 | class HlPolicy(HighLevelPolicy): 156 | def __init__(self, *args, action_space, **kwargs): 157 | super().__init__(*args, action_space=action_space, **kwargs) 158 | self._all_actions = get_allowed_actions( 159 | self._pddl_prob, self._config.allowed_actions 160 | ) 161 | self._n_actions = len(self._all_actions) 162 | print(f"Got {self._n_actions} hl actions") 163 | 164 | # Only take the keys that are in the observation space. 
165 | use_obs_space = spaces.Dict( 166 | { 167 | k: self._obs_space.spaces[k] 168 | for k in self._config.policy_input_keys 169 | if k in self._obs_space.spaces 170 | } 171 | ) 172 | # self._prev_sig = None 173 | 174 | self._use_term_action = self._config.use_term_action 175 | if self._use_term_action: 176 | self._n_actions += 1 177 | 178 | policy_core_cls = eval(self._config.policy_core_type) 179 | self._policy_core = policy_core_cls(use_obs_space, action_space, self._config) 180 | 181 | self._policy = CategoricalNet(self._config.ac_hidden_size, self._n_actions) 182 | self._critic = CriticHead(self._config.ac_hidden_size) 183 | 184 | @property 185 | def visual_encoder(self): 186 | return self._policy_core.visual_encoder 187 | 188 | @property 189 | def recurrent_hidden_size(self): 190 | return self._policy_core.rnn_hidden_dim 191 | 192 | @property 193 | def policy_action_space(self): 194 | return spaces.Discrete(self._n_actions) 195 | 196 | def parameters(self): 197 | return chain( 198 | self._policy_core.get_trainable_params(), 199 | self._policy.parameters(), 200 | self._critic.parameters(), 201 | ) 202 | 203 | def get_policy_components(self) -> List[nn.Module]: 204 | return [self._policy_core, self._policy, self._critic] 205 | 206 | def to(self, device): 207 | self._device = device 208 | self._critic.to(device) 209 | self._policy.to(device) 210 | self._policy_core.to(device) 211 | return self 212 | 213 | @property 214 | def should_load_agent_state(self): 215 | return True 216 | 217 | def get_value(self, observations, rnn_hidden_states, prev_actions, masks): 218 | features, _ = self._policy_core.forward(observations, rnn_hidden_states, masks) 219 | return self._critic(features) 220 | 221 | def evaluate_actions( 222 | self, 223 | observations, 224 | rnn_hidden_states, 225 | prev_actions, 226 | masks, 227 | action, 228 | rnn_build_seq_info, 229 | ): 230 | with g_timer.avg_time(f"hl_policy.eval_actions.get_distrib", level=1): 231 | distrib, features, _ = self.get_distrib( 232 | observations, rnn_hidden_states, masks, rnn_build_seq_info 233 | ) 234 | value = self._critic(features) 235 | 236 | if len(action.shape) == 3: 237 | # [batch_size, seq_len, data_dim] 238 | log_probs = distrib.log_prob(einops.rearrange(action, "b t 1 -> b t")) 239 | log_probs = einops.rearrange(log_probs, "b t -> b t 1") 240 | else: 241 | log_probs = distrib.log_probs(action) 242 | 243 | distribution_entropy = distrib.entropy() 244 | 245 | return ( 246 | value, 247 | log_probs, 248 | distribution_entropy, 249 | rnn_hidden_states, 250 | {}, 251 | ) 252 | 253 | def get_distrib( 254 | self, 255 | observations, 256 | rnn_hidden_states, 257 | masks, 258 | rnn_build_seq_info=None, 259 | flatten=False, 260 | ): 261 | features, rnn_hidden_states = self._policy_core.forward( 262 | observations, rnn_hidden_states, masks, rnn_build_seq_info 263 | ) 264 | distrib = self._policy(features) 265 | 266 | return distrib, features, rnn_hidden_states 267 | 268 | def get_next_skill( 269 | self, 270 | observations, 271 | rnn_hidden_states, 272 | prev_actions, 273 | masks, 274 | plan_masks, 275 | deterministic, 276 | log_info, 277 | ): 278 | next_skill = torch.zeros(self._num_envs, dtype=torch.long) 279 | skill_args_data: List[Any] = [None for _ in range(self._num_envs)] 280 | immediate_end = torch.zeros(self._num_envs, dtype=torch.bool) 281 | 282 | with g_timer.avg_time(f"hl_policy.act.get_distrib", level=1): 283 | distrib, features, rnn_hidden_states = self.get_distrib( 284 | observations, rnn_hidden_states, masks 285 | ) 286 | values = 
self._critic(features) 287 | 288 | if deterministic: 289 | skill_sel = distrib.mode() 290 | else: 291 | skill_sel = distrib.sample() 292 | action_log_probs = distrib.log_probs(skill_sel) 293 | 294 | for batch_idx, should_plan in enumerate(plan_masks): 295 | if should_plan != 1.0: 296 | continue 297 | batch_skill_sel = skill_sel[batch_idx] 298 | if batch_skill_sel >= len(self._all_actions): 299 | assert self._use_term_action 300 | immediate_end[batch_idx] = True 301 | log_info[batch_idx]["nn_action"] = "terminate" 302 | 303 | # Choose a random skill because it won't matter since we will terminate the episode. 304 | use_ac = self._all_actions[0] 305 | next_skill[batch_idx] = self._skill_name_to_idx[use_ac.name] 306 | skill_args_data[batch_idx] = [ 307 | entity.name for entity in use_ac.param_values 308 | ] 309 | else: 310 | use_ac = self._all_actions[batch_skill_sel] 311 | 312 | next_skill[batch_idx] = self._skill_name_to_idx[use_ac.name] 313 | skill_args_data[batch_idx] = [ 314 | entity.name for entity in use_ac.param_values 315 | ] 316 | log_info[batch_idx]["nn_action"] = use_ac.compact_str 317 | 318 | return ( 319 | next_skill, 320 | skill_args_data, 321 | immediate_end, 322 | PolicyActionData( 323 | action_log_probs=action_log_probs, 324 | values=values, 325 | actions=skill_sel, 326 | rnn_hidden_states=rnn_hidden_states, 327 | ), 328 | ) 329 | 330 | @property 331 | def num_recurrent_layers(self): 332 | return self._policy_core.get_num_rnn_layers() 333 | -------------------------------------------------------------------------------- /llarp/policies/llama_parallel.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import distributed as distrib 10 | from transformers import LlamaModel 11 | from transformers.modeling_outputs import BaseModelOutputWithPast 12 | from transformers.utils import logging 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | 17 | class LlamaModelParallel(LlamaModel): 18 | def __init__(self, *args, model_parallel_factor, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self._model_parallel_factor = model_parallel_factor 21 | self._model_blocks = None 22 | 23 | def to(self, device): 24 | if not isinstance(device, torch.device): 25 | return super().to(device) 26 | 27 | torch.cuda.set_device(device) 28 | 29 | self.embed_tokens.to(device) 30 | self.norm.to(device) 31 | 32 | if self._model_parallel_factor > 1: 33 | block_len = len(self.layers) // self._model_parallel_factor 34 | self._model_blocks = {} 35 | nranks = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) 36 | for block_i in range(self._model_parallel_factor): 37 | block_start = block_i * block_len 38 | block_end = (block_i + 1) * block_len 39 | device_idx = device.index + (block_i * nranks) 40 | 41 | use_device = torch.device(type="cuda", index=device_idx) 42 | for layer_idx in range(block_start, block_end): 43 | self._model_blocks[layer_idx] = use_device 44 | self.layers[layer_idx].to(use_device) 45 | 46 | elif self._model_parallel_factor == 1: 47 | self.layers.to(device) 48 | self._model_blocks = None 49 | else: 50 | raise ValueError(f"Invalid value for {self._model_parallel_factor=}") 51 | return self 52 | 53 | def forward( 54 | self, 55 | input_ids: torch.LongTensor = None, 56 | attention_mask: Optional[torch.Tensor] = None, 57 | position_ids: Optional[torch.LongTensor] = None, 58 | past_key_values: Optional[List[torch.FloatTensor]] = None, 59 | inputs_embeds: Optional[torch.FloatTensor] = None, 60 | use_cache: Optional[bool] = None, 61 | output_attentions: Optional[bool] = None, 62 | output_hidden_states: Optional[bool] = None, 63 | return_dict: Optional[bool] = None, 64 | ) -> Union[Tuple, BaseModelOutputWithPast]: 65 | output_attentions = ( 66 | output_attentions 67 | if output_attentions is not None 68 | else self.config.output_attentions 69 | ) 70 | output_hidden_states = ( 71 | output_hidden_states 72 | if output_hidden_states is not None 73 | else self.config.output_hidden_states 74 | ) 75 | use_cache = use_cache if use_cache is not None else self.config.use_cache 76 | 77 | return_dict = ( 78 | return_dict if return_dict is not None else self.config.use_return_dict 79 | ) 80 | 81 | # retrieve input_ids and inputs_embeds 82 | if input_ids is not None and inputs_embeds is not None: 83 | raise ValueError( 84 | "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" 85 | ) 86 | elif input_ids is not None: 87 | batch_size, seq_length = input_ids.shape 88 | elif inputs_embeds is not None: 89 | batch_size, seq_length, _ = inputs_embeds.shape 90 | else: 91 | raise ValueError( 92 | "You have to specify either decoder_input_ids or decoder_inputs_embeds" 93 | ) 94 | 95 | seq_length_with_past = seq_length 96 | past_key_values_length = 0 97 | 98 | if past_key_values is not None: 99 | past_key_values_length = past_key_values[0][0].shape[2] 100 | seq_length_with_past = seq_length_with_past + past_key_values_length 101 | 102 | device = input_ids.device if input_ids is not None else inputs_embeds.device 103 | if position_ids is None: 104 | position_ids = torch.arange( 105 | 
past_key_values_length, 106 | seq_length + past_key_values_length, 107 | dtype=torch.long, 108 | device=device, 109 | ) 110 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 111 | else: 112 | position_ids = position_ids.view(-1, seq_length).long() 113 | 114 | if inputs_embeds is None: 115 | inputs_embeds = self.embed_tokens(input_ids) 116 | # embed positions 117 | if attention_mask is None: 118 | attention_mask = torch.ones( 119 | (batch_size, seq_length_with_past), 120 | dtype=torch.bool, 121 | device=inputs_embeds.device, 122 | ) 123 | attention_mask = self._prepare_decoder_attention_mask( 124 | attention_mask, 125 | (batch_size, seq_length), 126 | inputs_embeds, 127 | past_key_values_length, 128 | ) 129 | 130 | hidden_states = inputs_embeds 131 | 132 | if self.gradient_checkpointing and self.training: 133 | if use_cache: 134 | logger.warning_once( 135 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 136 | ) 137 | use_cache = False 138 | 139 | # decoder layers 140 | all_hidden_states = () if output_hidden_states else None 141 | all_self_attns = () if output_attentions else None 142 | next_decoder_cache = () if use_cache else None 143 | 144 | orig_device = hidden_states.device 145 | 146 | prev_device = None 147 | for idx, decoder_layer in enumerate(self.layers): 148 | if output_hidden_states: 149 | all_hidden_states += (hidden_states,) 150 | 151 | if self._model_blocks is not None: 152 | use_device = self._model_blocks[idx] 153 | else: 154 | use_device = orig_device 155 | 156 | if past_key_values is not None: 157 | past_key_value = tuple( 158 | past_key_values[cache_i][idx].to(use_device) 159 | for cache_i in range(len(past_key_values)) 160 | ) 161 | else: 162 | past_key_value = None 163 | 164 | if self.gradient_checkpointing and self.training: 165 | 166 | def create_custom_forward(module): 167 | def custom_forward(*inputs): 168 | # None for past_key_value 169 | return module(*inputs, output_attentions, None) 170 | 171 | return custom_forward 172 | 173 | layer_outputs = torch.utils.checkpoint.checkpoint( 174 | create_custom_forward(decoder_layer), 175 | hidden_states, 176 | attention_mask, 177 | position_ids, 178 | None, 179 | ) 180 | else: 181 | layer_outputs = decoder_layer( 182 | hidden_states.to(use_device), 183 | attention_mask=attention_mask.to(use_device), 184 | position_ids=position_ids.to(use_device), 185 | past_key_value=past_key_value, 186 | output_attentions=output_attentions, 187 | use_cache=use_cache, 188 | ) 189 | 190 | hidden_states = layer_outputs[0] 191 | 192 | if use_cache: 193 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 194 | 195 | if output_attentions: 196 | all_self_attns += (layer_outputs[1],) 197 | 198 | # Cast back to the original GPU external systems expect. 
199 | hidden_states = hidden_states.to(device) 200 | 201 | hidden_states = self.norm(hidden_states) 202 | 203 | # add hidden states from the last decoder layer 204 | if output_hidden_states: 205 | all_hidden_states += (hidden_states,) 206 | 207 | next_cache = next_decoder_cache if use_cache else None 208 | if not return_dict: 209 | return tuple( 210 | v 211 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 212 | if v is not None 213 | ) 214 | return BaseModelOutputWithPast( 215 | last_hidden_state=hidden_states, 216 | past_key_values=next_cache, 217 | hidden_states=all_hidden_states, 218 | attentions=all_self_attns, 219 | ) 220 | -------------------------------------------------------------------------------- /llarp/policies/llm_policy.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import logging 6 | from collections import defaultdict 7 | from itertools import chain 8 | from typing import Any, Dict, List 9 | 10 | import gym.spaces as spaces 11 | import hydra 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | from einops import rearrange 16 | from habitat import logger 17 | from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction 18 | from habitat_baselines.common.baseline_registry import baseline_registry 19 | from habitat_baselines.common.logging import baselines_logger 20 | from habitat_baselines.rl.hrl.hierarchical_policy import HierarchicalPolicy 21 | from habitat_baselines.rl.hrl.hl.fixed_policy import FixedHighLevelPolicy 22 | from habitat_baselines.rl.hrl.hl.high_level_policy import HighLevelPolicy 23 | from habitat_baselines.rl.hrl.skills import NoopSkillPolicy # noqa: F401. 24 | from habitat_baselines.rl.hrl.skills import ResetArmSkill, WaitSkillPolicy 25 | from habitat_baselines.rl.ppo.policy import (CriticHead, Policy, 26 | PolicyActionData) 27 | from habitat_baselines.rl.ppo.ppo_trainer import get_device 28 | from habitat_baselines.utils.timing import g_timer 29 | 30 | from llarp.dataset.utils import get_category_info 31 | from llarp.policies.cores import * 32 | from llarp.task.utils import get_allowed_actions, get_pddl 33 | 34 | EPS_PPO = 1e-5 35 | 36 | 37 | @baseline_registry.register_policy 38 | class LlmPolicy(nn.Module, Policy): 39 | def __init__(self, config, obs_space, action_space, **kwargs): 40 | Policy.__init__(self, action_space) 41 | nn.Module.__init__(self) 42 | self._obs_space = obs_space 43 | 44 | policy_cfg = ( 45 | config.habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy 46 | ) 47 | device = get_device(config) 48 | # Only take the keys that are in the observation space. 
49 | use_obs_space = spaces.Dict( 50 | { 51 | k: self._obs_space.spaces[k] 52 | for k in policy_cfg.policy_input_keys 53 | if k in self._obs_space.spaces 54 | } 55 | ) 56 | 57 | self._use_term_action = policy_cfg.use_term_action 58 | 59 | policy_core_cls = eval(policy_cfg.policy_core_type) 60 | self._policy_core = policy_core_cls( 61 | use_obs_space, action_space, policy_cfg, device 62 | ) 63 | 64 | self._critic = hydra.utils.instantiate( 65 | policy_cfg.critic, 66 | input_size=self._policy_core.action_decoder_net.hidden_size, 67 | n_envs=config.habitat_baselines.num_environments, 68 | device=device, 69 | obs_space=self._obs_space, 70 | ) 71 | self._resume_path = config.habitat_baselines.resume_ckpt_path 72 | 73 | @property 74 | def policy_core(self): 75 | return self._policy_core 76 | 77 | @property 78 | def visual_encoder(self): 79 | return self._policy_core.visual_encoder 80 | 81 | @property 82 | def recurrent_hidden_size(self): 83 | return self._policy_core.rnn_hidden_dim 84 | 85 | def parameters(self): 86 | return chain( 87 | self._policy_core.get_trainable_params(), 88 | self._critic.parameters(), 89 | ) 90 | 91 | def _get_policy_components(self) -> List[nn.Module]: 92 | return [self._policy_core, self._critic] 93 | 94 | def to(self, device): 95 | self._device = device 96 | self._critic.to(device) 97 | self._policy_core.to(device) 98 | if self._resume_path is not None: 99 | ckpt = torch.load(self._resume_path, map_location="cpu") 100 | self.load_state_dict(ckpt["state_dict"]) 101 | logger.info(f"Loaded from step {ckpt['extra_state']['step']:,}") 102 | self._resume_path = None 103 | return self 104 | 105 | @property 106 | def should_load_agent_state(self): 107 | return True 108 | 109 | def get_value(self, observations, hidden_states, prev_actions, masks): 110 | ( 111 | _, 112 | _, 113 | features, 114 | _, 115 | ) = self._policy_core.rollout_actions(observations, masks, only_features=True) 116 | return self._critic.forward(features, observations) 117 | 118 | def evaluate_actions( 119 | self, 120 | observations, 121 | action, 122 | ): 123 | ( 124 | log_probs, 125 | features, 126 | distribution_entropy, 127 | self.policy_core_log, 128 | ) = self._policy_core.update_actions(observations, action) 129 | value = self._critic.forward_norm(features, observations) 130 | 131 | return ( 132 | value, 133 | log_probs, 134 | distribution_entropy, 135 | {}, 136 | ) 137 | 138 | def act( 139 | self, 140 | observations, 141 | rnn_hidden_states, 142 | prev_actions, 143 | masks, 144 | deterministic=False, 145 | ): 146 | ( 147 | actions, 148 | log_probs, 149 | features, 150 | hidden_states, 151 | ) = self._policy_core.rollout_actions(observations, masks, deterministic) 152 | values = self._critic.forward(features, observations) 153 | 154 | policy_info = None 155 | if self._policy_core._is_eval_mode: 156 | sel_acs = [self._policy_core._tokenizer.decode(ac) for ac in actions] 157 | policy_info = [{"pred_ac": sel_ac} for sel_ac in sel_acs] 158 | 159 | return PolicyActionData( 160 | action_log_probs=log_probs, 161 | values=values, 162 | actions=actions, 163 | rnn_hidden_states=hidden_states, 164 | policy_info=policy_info, 165 | ) 166 | 167 | @property 168 | def num_recurrent_layers(self): 169 | return self._policy_core.get_num_rnn_layers() 170 | 171 | @classmethod 172 | def from_config(cls, config, observation_space, action_space, **kwargs): 173 | return cls(config, observation_space, action_space, **kwargs) 174 | 175 | 176 | class LinearCriticHead(nn.Module): 177 | def __init__(self, input_size, use_b16: 
bool, **kwargs): 178 | super().__init__() 179 | 180 | self.fc = nn.Linear(input_size, 1) 181 | nn.init.orthogonal_(self.fc.weight) 182 | nn.init.constant_(self.fc.bias, 0) 183 | if use_b16: 184 | self._use_type = torch.bfloat16 185 | self.to(torch.bfloat16) 186 | else: 187 | self._use_type = torch.float32 188 | 189 | def post_process_returns(self, returns, obs): 190 | return returns 191 | 192 | def forward_norm(self, x, obs): 193 | return self.forward(x, obs) 194 | 195 | def update_stats(self, returns, obs): 196 | pass 197 | 198 | def forward(self, x, obs): 199 | return self.fc(x.to(self._use_type)) 200 | 201 | 202 | class DeepCriticHead(nn.Module): 203 | def __init__( 204 | self, 205 | input_size, 206 | hidden_size, 207 | n_envs: int, 208 | device, 209 | obs_space: spaces.Dict, 210 | use_b16: bool = True, 211 | **kwargs, 212 | ): 213 | super().__init__() 214 | 215 | self.proj = nn.Sequential( 216 | nn.Linear(input_size, hidden_size), 217 | nn.LayerNorm(hidden_size), 218 | nn.ReLU(True), 219 | nn.Linear(hidden_size, hidden_size), 220 | nn.LayerNorm(hidden_size), 221 | nn.ReLU(True), 222 | ) 223 | self.fc = nn.Linear(hidden_size, 1) 224 | nn.init.orthogonal_(self.fc.weight) 225 | nn.init.constant_(self.fc.bias, 0) 226 | if use_b16: 227 | self._use_type = torch.bfloat16 228 | self.to(torch.bfloat16) 229 | else: 230 | self._use_type = torch.float32 231 | 232 | def update_stats(self, returns, obs): 233 | return 234 | 235 | def post_process_returns(self, returns, obs): 236 | return returns 237 | 238 | def forward(self, x, obs): 239 | return self.forward_norm(x, obs) 240 | 241 | def forward_norm(self, x, obs): 242 | return self.fc(self.proj(x.to(self._use_type))) 243 | -------------------------------------------------------------------------------- /llarp/policies/transformer_storage.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import warnings 6 | from typing import Any, Dict, Iterator, Optional 7 | 8 | import gym.spaces as spaces 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from habitat_baselines.common.baseline_registry import baseline_registry 14 | from habitat_baselines.common.rollout_storage import RolloutStorage 15 | from habitat_baselines.common.storage import Storage 16 | from habitat_baselines.common.tensor_dict import DictTree, TensorDict 17 | from habitat_baselines.rl.ddppo.policy import PointNavResNetNet 18 | from habitat_baselines.rl.models.rnn_state_encoder import ( 19 | build_pack_info_from_dones, build_rnn_build_seq_info) 20 | from habitat_baselines.utils.common import get_action_space_info 21 | from habitat_baselines.utils.timing import g_timer 22 | from torch import Tensor 23 | from torch.nn.utils.rnn import pad_sequence 24 | 25 | ATT_MASK_K = "att_mask" 26 | HIDDEN_WINDOW_K = "hidden_window" 27 | START_HIDDEN_WINDOW_K = "start_hidden_window" 28 | START_ATT_MASK_K = "start_att_mask" 29 | FETCH_BEFORE_COUNTS_K = "fetch_before_counts" 30 | 31 | 32 | def transpose_stack_pad_dicts(dicts_i): 33 | res = {} 34 | for k in dicts_i[0].keys(): 35 | if isinstance(dicts_i[0][k], dict): 36 | res[k] = transpose_stack_pad_dicts([d[k] for d in dicts_i]) 37 | else: 38 | res[k] = pad_sequence( 39 | [d[k] for d in dicts_i], batch_first=True, padding_value=0.0 40 | ) 41 | 42 | return res 43 | 44 | 45 | @baseline_registry.register_storage 46 | class TransformerRolloutStorage(RolloutStorage): 47 | def __init__( 48 | self, 49 | *args, 50 | numsteps, 51 | num_envs, 52 | observation_space, 53 | actor_critic, 54 | **kwargs, 55 | ): 56 | self._frozen_visual = ( 57 | PointNavResNetNet.PRETRAINED_VISUAL_FEATURES_KEY in observation_space.spaces 58 | ) 59 | if self._frozen_visual: 60 | # Remove the head RGB camera because we have the visual information in the visual features 61 | observation_space = spaces.Dict( 62 | {k: v for k, v in observation_space.spaces.items() if k != "head_rgb"} 63 | ) 64 | 65 | super().__init__( 66 | *args, 67 | numsteps=numsteps, 68 | num_envs=num_envs, 69 | observation_space=observation_space, 70 | actor_critic=actor_critic, 71 | **kwargs, 72 | ) 73 | 74 | core = actor_critic.policy_core 75 | 76 | self._context_len = core.context_len 77 | self._debug_mode = core._debug_mode 78 | self._rollout_take_ratio = core.rollout_take_ratio 79 | 80 | self.hidden_window = torch.zeros( 81 | num_envs, 82 | numsteps + self._context_len, 83 | *core.hidden_window_dim, 84 | ) 85 | self.att_masks = torch.zeros(num_envs, self._context_len, dtype=torch.bool) 86 | 87 | # The att masks BEFORE this rollout. 
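# (Kept so that the first chunk of the next rollout can still attend to the
# hidden states carried over from the previous window; see `after_update`
# below and `flatten_trajs`.)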
88 | self.before_start_att_masks = torch.zeros( 89 | num_envs, self._context_len, dtype=torch.bool 90 | ) 91 | 92 | def to(self, device): 93 | super().to(device) 94 | self.hidden_window = self.hidden_window.to(device) 95 | self.att_masks = self.att_masks.to(device) 96 | 97 | @g_timer.avg_time("rollout_storage.insert", level=1) 98 | def insert( 99 | self, 100 | next_observations=None, 101 | next_recurrent_hidden_states=None, 102 | actions=None, 103 | action_log_probs=None, 104 | value_preds=None, 105 | rewards=None, 106 | next_masks=None, 107 | buffer_index: int = 0, 108 | **kwargs, 109 | ): 110 | if self._frozen_visual and next_observations is not None: 111 | next_observations = { 112 | k: v for k, v in next_observations.items() if k != "head_rgb" 113 | } 114 | super().insert( 115 | next_observations=next_observations, 116 | next_recurrent_hidden_states=None, 117 | actions=actions, 118 | action_log_probs=action_log_probs, 119 | value_preds=value_preds, 120 | rewards=rewards, 121 | next_masks=next_masks, 122 | buffer_index=buffer_index, 123 | **kwargs, 124 | ) 125 | 126 | if next_masks is not None: 127 | step = self.current_rollout_step_idxs[buffer_index] 128 | # Shift the tensor over to the left 1. 129 | self.att_masks = self.att_masks.roll(shifts=-1, dims=1) 130 | # Start over if we are at a new episode. 131 | self.att_masks *= next_masks.to(self.device) 132 | self.att_masks[:, -1] = True 133 | 134 | if next_recurrent_hidden_states is not None: 135 | write_step = ( 136 | self.current_rollout_step_idxs[buffer_index] + self._context_len - 1 137 | ) 138 | env_slice = slice( 139 | int(buffer_index * self._num_envs / self._nbuffers), 140 | int((buffer_index + 1) * self._num_envs / self._nbuffers), 141 | ) 142 | 143 | self.hidden_window[env_slice, write_step] = next_recurrent_hidden_states 144 | 145 | def insert_first_observations(self, batch): 146 | if self._frozen_visual: 147 | batch = {k: v for k, v in batch.items() if k != "head_rgb"} 148 | super().insert_first_observations(batch) 149 | self.att_masks[:, -1] = True 150 | 151 | def after_update(self): 152 | """ 153 | Copy over information from the end of the current rollout buffer into the rollout buffer start. 
154 | """ 155 | super().after_update() 156 | self.hidden_window[:, : self._context_len] = ( 157 | self.hidden_window[:, -self._context_len :].detach().clone() 158 | ) 159 | self.before_start_att_masks.copy_(self.att_masks) 160 | 161 | # Clear the rest to write the next rollout 162 | # self.hidden_window[:, -(self.num_steps - 1) :] = 0.0 163 | self.hidden_window[:, self._context_len - 1 :] = 0.0 164 | 165 | self.hidden_window[:, : self._context_len - 1] *= rearrange( 166 | self.att_masks[:, :-1], "b n -> b n 1 1" 167 | ) 168 | 169 | @g_timer.avg_time("rollout_storage.compute_returns", level=1) 170 | def compute_returns(self, next_value, use_gae, gamma, tau, clip_rewards: int = 0.0): 171 | rewards = self.buffers["rewards"] 172 | if clip_rewards > 0: 173 | rewards = torch.clamp(rewards, -clip_rewards, clip_rewards) 174 | 175 | if use_gae: 176 | assert isinstance(self.buffers["value_preds"], torch.Tensor) 177 | self.buffers["value_preds"][self.current_rollout_step_idx] = next_value 178 | gae = 0.0 179 | for step in reversed(range(self.current_rollout_step_idx)): 180 | delta = ( 181 | rewards[step] 182 | + gamma 183 | * self.buffers["value_preds"][step + 1] 184 | * self.buffers["masks"][step + 1] 185 | - self.buffers["value_preds"][step] 186 | ) 187 | gae = delta + gamma * tau * gae * self.buffers["masks"][step + 1] 188 | self.buffers["returns"][step] = ( # type: ignore 189 | gae + self.buffers["value_preds"][step] # type: ignore 190 | ) 191 | 192 | else: 193 | self.buffers["returns"][self.current_rollout_step_idx] = next_value 194 | for step in reversed(range(self.current_rollout_step_idx)): 195 | self.buffers["returns"][step] = ( 196 | gamma 197 | * self.buffers["returns"][step + 1] 198 | * self.buffers["masks"][step + 1] 199 | + self.buffers["rewards"][step] 200 | ) 201 | 202 | @g_timer.avg_time("ts.data_generator") 203 | def data_generator( 204 | self, 205 | advantages: Optional[torch.Tensor], 206 | num_mini_batch: int, 207 | norm_returns, 208 | norm_values, 209 | ) -> Iterator[DictTree]: 210 | assert isinstance(self.buffers["returns"], torch.Tensor) 211 | num_environments = self.buffers["returns"].size(1) 212 | assert num_environments >= num_mini_batch 213 | if num_environments % num_mini_batch != 0: 214 | warnings.warn( 215 | "Number of environments ({}) is not a multiple of the" 216 | " number of mini batches ({}). This results in mini batches" 217 | " of different sizes, which can harm training performance.".format( 218 | num_environments, num_mini_batch 219 | ) 220 | ) 221 | 222 | yield from flatten_trajs( 223 | advantages, 224 | self.buffers, 225 | self._context_len, 226 | self.current_rollout_step_idx, 227 | self.hidden_window, 228 | self.before_start_att_masks, 229 | inds=torch.arange(num_environments), 230 | batch_size=num_environments // num_mini_batch, 231 | max_num_batches=num_mini_batch, 232 | ) 233 | 234 | 235 | def flatten_trajs( 236 | advantages, 237 | buffers, 238 | context_len: int, 239 | current_rollout_step_idx, 240 | hidden_window, 241 | before_start_att_masks, 242 | inds, 243 | batch_size: int, 244 | max_num_batches: int, 245 | ) -> TensorDict: 246 | max_data_window_size = current_rollout_step_idx 247 | device = advantages.device 248 | curr_slice = ( 249 | slice(0, max_data_window_size), 250 | inds, 251 | ) 252 | 253 | batch = buffers[curr_slice] 254 | batch["advantages"] = advantages[curr_slice] 255 | 256 | # Find where episode starts (when masks = False). 
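# The loop below splits each environment's rollout into chunks that never
# cross an episode boundary and are at most `context_len` steps long, by
# inserting extra chunk-start indices where needed. For example, with
# context_len=4, a 10-step rollout, and an episode boundary at step 6, the
# chunk starts for that environment become [0, 4, 6], i.e. chunks
# [0:4], [4:6], [6:10].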
257 | dones = ~batch["masks"] 258 | seq_len = batch["masks"].shape[0] 259 | inserted_start = {} 260 | ep_starts = torch.nonzero(~batch["masks"])[:, :2] 261 | ep_starts_cpu = ep_starts.cpu().numpy().tolist() 262 | for batch_idx in range(len(inds)): 263 | if [0, batch_idx] not in ep_starts_cpu: 264 | inserted_start[(0, batch_idx)] = True 265 | ep_starts_cpu.insert(0, [0, batch_idx]) 266 | current_index = 0 267 | while current_index < seq_len - context_len: 268 | ctx_dones_range = dones[ 269 | current_index + 1 : current_index + context_len, batch_idx 270 | ] 271 | if torch.any(ctx_dones_range): 272 | first_start = torch.nonzero(ctx_dones_range)[0].min() 273 | first_start = int(first_start.item()) 274 | current_index += first_start + 1 275 | if [current_index, batch_idx] not in ep_starts_cpu: 276 | ep_starts_cpu.append([current_index, batch_idx]) 277 | else: 278 | current_index += context_len 279 | if [current_index, batch_idx] not in ep_starts_cpu: 280 | ep_starts_cpu.append([current_index, batch_idx]) 281 | 282 | ep_starts_cpu = np.array(ep_starts_cpu) 283 | # Track the next start episode. 284 | batch_to_next_ep_starts = {} 285 | for batch_idx in range(len(inds)): 286 | batch_ep_starts = ep_starts_cpu[ep_starts_cpu[:, 1] == batch_idx][:, 0] 287 | batch_ep_starts = np.sort(batch_ep_starts) 288 | 289 | batch_to_next_ep_starts[batch_idx] = {} 290 | for i in range(len(batch_ep_starts) - 1): 291 | batch_to_next_ep_starts[batch_idx][batch_ep_starts[i]] = batch_ep_starts[ 292 | i + 1 293 | ] 294 | 295 | chunked_ep_start_indices = torch.randperm(len(ep_starts_cpu)).numpy().tolist() 296 | num_valid_before_steps = before_start_att_masks.sum(-1).cpu().tolist() 297 | 298 | for batch_index in range(max_num_batches): 299 | eps = [] 300 | fetch_before_counts = [] 301 | while len(eps) < batch_size: 302 | if not chunked_ep_start_indices: 303 | chunked_ep_start_indices = ( 304 | torch.randperm(len(ep_starts_cpu)).numpy().tolist() 305 | ) 306 | step_idx, batch_idx = ep_starts_cpu[chunked_ep_start_indices[0]] 307 | chunked_ep_start_indices = chunked_ep_start_indices[1:] 308 | 309 | next_ep_starts = batch_to_next_ep_starts[batch_idx] 310 | step_idx_end = next_ep_starts.get(step_idx, max_data_window_size) 311 | 312 | add_batch = batch.map(lambda x: x[step_idx:step_idx_end, batch_idx]) 313 | 314 | did_insert_start = inserted_start.get((step_idx, batch_idx), False) 315 | if did_insert_start: 316 | fetch_before_counts.append( 317 | max(num_valid_before_steps[batch_idx] - 1, 0) 318 | ) 319 | else: 320 | fetch_before_counts.append(0) 321 | 322 | add_batch["observations"][ATT_MASK_K] = torch.ones( 323 | # The att mask needs to have same window size as the data batch. 324 | (step_idx_end - step_idx, 1), 325 | dtype=bool, 326 | device=device, 327 | ) 328 | add_batch["observations"][START_HIDDEN_WINDOW_K] = hidden_window[ 329 | batch_idx, : (context_len - 1) 330 | ] 331 | add_batch["observations"][START_ATT_MASK_K] = before_start_att_masks[ 332 | batch_idx, :-1 333 | ].unsqueeze(-1) 334 | 335 | eps.append(add_batch) 336 | 337 | assert ( 338 | len(eps) > 0 339 | ), "Collected no episodes from rollout, ensure episode horizon is shorter than rollout length." 340 | ret_batch = transpose_stack_pad_dicts(eps) 341 | 342 | ret_batch["observations"][FETCH_BEFORE_COUNTS_K] = torch.tensor( 343 | fetch_before_counts, device=device 344 | ) 345 | 346 | yield ret_batch 347 | 348 | 349 | def check_extracted_batch(add_batch, batch_idx): 350 | """ 351 | Checks a batch is good. 
Note that the entire batch should be valid as we 352 | are going to attend over the entire sequence. Only the later stack pad 353 | operation will result in False attention masks. 354 | """ 355 | 356 | debug_info = add_batch["observations"]["debug_info"] 357 | batch_step_info = debug_info[:, 0] 358 | batch_ep_info = debug_info[:, 1] 359 | # Each step should be increasing by 1. 360 | if not ((batch_step_info[:-1] + 1) == batch_step_info[1:]).all(): 361 | raise ValueError(f"Episode steps not incrementing {batch_step_info}") 362 | 363 | # Steps should belong to the same episode. 364 | if not (batch_ep_info[0] == batch_ep_info).all(): 365 | raise ValueError(f"Episode IDs not consistent") 366 | -------------------------------------------------------------------------------- /llarp/policies/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import abc 6 | import os.path as osp 7 | import warnings 8 | from typing import Any, Dict, Optional 9 | 10 | import torch 11 | import torch.nn as nn 12 | from habitat import logger 13 | from peft import LoraConfig, TaskType, get_peft_model 14 | from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, 15 | AutoModelForSeq2SeqLM, AutoTokenizer, LlamaConfig, 16 | LlamaForCausalLM, LlamaModel, LlamaTokenizer, 17 | T5Model) 18 | 19 | from llarp.policies.llama_parallel import LlamaModelParallel 20 | from llarp.task.utils import get_parser 21 | 22 | 23 | def soft_update_params(net, target_net, tau): 24 | for param, target_param in zip(net.parameters(), target_net.parameters()): 25 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 26 | 27 | 28 | class LlmWrapper(nn.Module, abc.ABC): 29 | @abc.abstractmethod 30 | def decode( 31 | self, 32 | non_causal_tokens, 33 | causal_embeds, 34 | att_masks, 35 | kv_cache=None, 36 | use_cache=False, 37 | n_non_causal_tokens=None, 38 | ): 39 | pass 40 | 41 | @property 42 | @abc.abstractmethod 43 | def base_model(self) -> nn.Module: 44 | pass 45 | 46 | @property 47 | @abc.abstractmethod 48 | def llm_head(self) -> nn.Module: 49 | pass 50 | 51 | @property 52 | @abc.abstractmethod 53 | def d_model(self) -> int: 54 | pass 55 | 56 | def get_vision_output_shape(self): 57 | return None 58 | 59 | def get_visual_pooler(self): 60 | return None 61 | 62 | def get_visual_encoder(self): 63 | return None 64 | 65 | def get_image_processor(self): 66 | return None 67 | 68 | @property 69 | @abc.abstractmethod 70 | def config(self): 71 | pass 72 | 73 | 74 | def freeze_params(module): 75 | for param in module.parameters(): 76 | param.requires_grad = False 77 | 78 | 79 | class DecoderWrapper(LlmWrapper): 80 | def __init__( 81 | self, 82 | llm_id, 83 | use_b16: bool, 84 | model_cfg: Dict[str, Any], 85 | peft_settings, 86 | peft, 87 | peft_full_att_params, 88 | train_llm: bool = False, 89 | load_in_8bit=False, 90 | use_rope_scaling=False, 91 | debug_mode=False, 92 | model_parallel_factor: int = 1, 93 | force_non_causal: bool = False, 94 | **kwargs, 95 | ): 96 | """ 97 | :param force_non_causal: Legacy option to load models trained without 98 | the CausalLM HF model type. 
99 | """ 100 | 101 | super().__init__() 102 | self._debug_mode = debug_mode 103 | if self._debug_mode: 104 | self._tokenizer = get_parser(llm_id) 105 | else: 106 | self._tokenizer = None 107 | 108 | logger.info(f"Loading LLM from {llm_id}") 109 | 110 | kwargs = {} 111 | if use_b16: 112 | kwargs = {"torch_dtype": torch.bfloat16} 113 | 114 | cmp_llm_id = llm_id.lower() 115 | if llm_id is None or llm_id == "": 116 | # Load with a custom config. 117 | self.llm = LlamaForCausalLM(LlamaConfig(**model_cfg)) 118 | elif "llama" in cmp_llm_id: 119 | rope_scaling = ( 120 | {"type": "dynamic", "factor": 2} if use_rope_scaling else None 121 | ) 122 | self.llm = LlamaForCausalLM.from_pretrained( 123 | llm_id, 124 | load_in_8bit=load_in_8bit, 125 | rope_scaling=rope_scaling, 126 | **kwargs, 127 | ) 128 | else: 129 | raise ValueError(f"Unsupported LLM type {llm_id}") 130 | 131 | if use_b16: 132 | logger.info(f"Setting model data type to bfloat16.") 133 | create_type = torch.bfloat16 134 | self.llm.base_model.to(create_type) 135 | else: 136 | create_type = torch.float32 137 | 138 | if force_non_causal: 139 | self.llm = self.llm.base_model 140 | 141 | self._create_type = create_type 142 | self._model_parallel_factor = model_parallel_factor 143 | self._non_causal_mask = None 144 | self.tmp_skip_debug = False 145 | self._train_llm = train_llm 146 | self._peft = peft 147 | 148 | if self._peft: 149 | self.llm = setup_peft_module(self.llm, peft_full_att_params, peft_settings) 150 | elif not self._train_llm: 151 | for param in self.llm.parameters(): 152 | param.requires_grad = False 153 | logger.info(f"Done loading LLM") 154 | 155 | def parameters(self): 156 | if self._peft or self._train_llm: 157 | return self.llm.parameters() 158 | else: 159 | return [] 160 | 161 | @property 162 | def config(self): 163 | return self.llm.config 164 | 165 | @property 166 | def base_model(self): 167 | if self._peft: 168 | # Skip past the PEFT wrapper. 169 | return self.llm.base_model.base_model 170 | elif hasattr(self.llm, "base_model"): 171 | return self.llm.base_model 172 | else: 173 | return self.llm 174 | 175 | @property 176 | def llm_head(self) -> nn.Module: 177 | return self.llm.lm_head 178 | 179 | @property 180 | def d_model(self): 181 | return self.llm.config.hidden_size 182 | 183 | def _verify_input(self, non_causal_tokens, causal_embeds, att_masks): 184 | if self.tmp_skip_debug: 185 | return 186 | debug_info = causal_embeds[..., -2:] 187 | # Check each env individually. 
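# Within the attended window of every environment, the step indices must
# increase by exactly 1 and all steps must belong to the same episode;
# otherwise the attention mask does not match the episode structure.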
188 | for env_prefix_tokens, env_debug_info, env_att_masks in zip( 189 | non_causal_tokens, debug_info, att_masks 190 | ): 191 | env_att_masks = env_att_masks.view(-1) 192 | valid_debug_info = env_debug_info[env_att_masks].int() 193 | step_idx = valid_debug_info[:, 0] 194 | ep_idx = valid_debug_info[:, 1] 195 | 196 | if (ep_idx[0] != ep_idx).any(): 197 | raise ValueError(f"The episode indices don't match {ep_idx}") 198 | 199 | n_attend = env_att_masks.sum() 200 | if step_idx.max() + 1 != n_attend: 201 | raise ValueError( 202 | f"Attention mask is wrong {env_att_masks}, compare to steps {step_idx}" 203 | ) 204 | step_diff = step_idx[1:] - step_idx[:-1] 205 | if not (len(step_diff) == 0 or (step_diff == 1).all()): 206 | raise ValueError(f"Steps are inconsistent {step_diff}") 207 | 208 | text = self._tokenizer.decode(env_prefix_tokens) 209 | 210 | def embed_tokens(self, tokens) -> torch.Tensor: 211 | return self.base_model.embed_tokens(tokens) 212 | 213 | def decode( 214 | self, 215 | non_causal_tokens, 216 | causal_embeds, 217 | att_masks, 218 | kv_cache=None, 219 | use_cache=False, 220 | n_non_causal_tokens=None, 221 | ): 222 | if self._debug_mode: 223 | self._verify_input(non_causal_tokens, causal_embeds, att_masks) 224 | # Remove the debug info. 225 | causal_embeds = causal_embeds[..., :-2] 226 | 227 | if non_causal_tokens is None: 228 | all_embeds = causal_embeds 229 | else: 230 | causal_embeds = causal_embeds.to(self._create_type) 231 | non_causal_embeds = self.embed_tokens(non_causal_tokens) 232 | all_embeds = torch.cat([non_causal_embeds, causal_embeds], dim=1) 233 | 234 | non_causal_mask = self._get_non_causal_mask(non_causal_tokens.shape[1]) 235 | 236 | if n_non_causal_tokens is None: 237 | n_non_causal_tokens = non_causal_tokens.shape[1] 238 | non_causal_mask = self._get_non_causal_mask(n_non_causal_tokens) 239 | non_causal_mask = non_causal_mask.expand(all_embeds.shape[0], -1) 240 | att_masks = att_masks.view(*att_masks.shape[:2]) 241 | att_masks = torch.cat([non_causal_mask, att_masks], dim=-1) 242 | 243 | assert len(att_masks.shape) == 2 244 | 245 | if use_cache: 246 | kwargs = dict( 247 | use_cache=use_cache, 248 | past_key_values=kv_cache, 249 | ) 250 | else: 251 | kwargs = {} 252 | 253 | seq_out = self.base_model.forward( 254 | inputs_embeds=all_embeds, attention_mask=att_masks.int(), **kwargs 255 | ) 256 | 257 | # Ignore the part of the sequence from the prefix 258 | context_len = causal_embeds.shape[1] 259 | return ( 260 | seq_out.last_hidden_state[:, -context_len:].to(torch.float32), 261 | seq_out.get("past_key_values", None), 262 | ) 263 | 264 | def to(self, device): 265 | self.llm.to(device) 266 | self._device = device 267 | return self 268 | 269 | def _get_non_causal_mask(self, seq_len): 270 | # Recalculate the non-causal mask if the input shape changes from the 271 | # last iteration or if the non-causal mask hasn't been created yet. 272 | if self._non_causal_mask is None or self._non_causal_mask.shape[1] != seq_len: 273 | # Cache this mask so we don't have to reallocate. 274 | self._non_causal_mask = torch.ones( 275 | # Only take the token length, we expand to fit the batch sizes 276 | # later. 
277 | (1, seq_len), 278 | device=self._device, 279 | dtype=torch.bool, 280 | ) 281 | return self._non_causal_mask 282 | 283 | def generate(self, non_causal_tokens, causal_tokens, tokenizer, num_new_tokens): 284 | if len(causal_tokens) == 0: 285 | full_input = non_causal_tokens 286 | else: 287 | full_input = torch.cat([non_causal_tokens, causal_tokens], dim=-1) 288 | # Causal model. 289 | gen_tokens = self.llm.generate(full_input, max_new_tokens=num_new_tokens) 290 | return tokenizer.batch_decode( 291 | gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False 292 | ) 293 | 294 | 295 | def setup_peft_module(model, peft_full_att_params, peft_settings): 296 | # To figure out the possible names: 297 | # print([(n, type(m)) for n, m in self.vlm.named_modules()]) 298 | if peft_full_att_params: 299 | target_modules = r"model\.layers\.\d+\.self_attn\.(q_proj|v_proj|k_proj|o_proj)" 300 | else: 301 | target_modules = r"model\.layers\.\d+\.self_attn\.(q_proj|v_proj)" 302 | peft_config = LoraConfig( 303 | target_modules=target_modules, 304 | modules_to_save=["lm_head"], 305 | **peft_settings, 306 | ) 307 | model = get_peft_model(model, peft_config) 308 | model.print_trainable_parameters() 309 | return model 310 | -------------------------------------------------------------------------------- /llarp/policies/vis_bridge.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import abc 6 | 7 | import hydra 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from einops import rearrange, repeat 12 | from flamingo_pytorch import PerceiverResampler 13 | 14 | 15 | class VisBridge(abc.ABC, nn.Module): 16 | @abc.abstractmethod 17 | def forward(self, vis_features, obs): 18 | pass 19 | 20 | @property 21 | @abc.abstractmethod 22 | def num_tokens(self) -> int: 23 | pass 24 | 25 | 26 | class ResamplerVisBridge(VisBridge): 27 | def __init__( 28 | self, 29 | vis_encoder_net, 30 | llm, 31 | resampler_depth, 32 | resampler_dim_head, 33 | resampler_heads, 34 | num_output_latents, 35 | use_b16: bool, 36 | **kwargs, 37 | ): 38 | super().__init__() 39 | if use_b16: 40 | self._create_type = torch.bfloat16 41 | else: 42 | self._create_type = torch.float32 43 | 44 | num_visual_tokens, vis_token_dim = vis_encoder_net.output_shape 45 | 46 | self.token_resampler = PerceiverResampler( 47 | dim=vis_token_dim, 48 | num_media_embeds=num_visual_tokens, 49 | num_latents=num_output_latents, 50 | depth=resampler_depth, 51 | dim_head=resampler_dim_head, 52 | heads=resampler_heads, 53 | ) 54 | self.up_proj = nn.Linear(vis_token_dim, llm.d_model) 55 | self.token_resampler.to(self._create_type) 56 | self._num_output_tokens = num_output_latents 57 | 58 | def forward(self, vis_features, obs): 59 | """ 60 | Always returns float32 data type regardless of net internal type. 61 | """ 62 | 63 | if len(vis_features.shape) == 4: 64 | orig_batch_size = vis_features.shape[0] 65 | vis_features = rearrange(vis_features, "b n t d -> (b n) t d") 66 | else: 67 | orig_batch_size = None 68 | 69 | vis_features = vis_features.to(self._create_type) 70 | embeds = self.token_resampler(vis_features) 71 | # The token resampler outputs another dimension for some reason... 
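# (Presumably the singleton axis is the media/frame dimension of the
# PerceiverResampler; only one set of visual tokens is passed per sample, so
# it is squeezed out below.)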
72 | embeds = rearrange(embeds, "b 1 t d -> b t d") 73 | embeds = self.up_proj(embeds) 74 | 75 | if orig_batch_size is not None: 76 | embeds = rearrange(embeds, "(b n) t d -> b n t d", b=orig_batch_size) 77 | return embeds.to(torch.float32) 78 | 79 | @property 80 | def num_tokens(self) -> int: 81 | return self._num_output_tokens 82 | 83 | 84 | class MlpVisBridge(VisBridge): 85 | """ 86 | For operating over single token inputs. 87 | """ 88 | 89 | def __init__( 90 | self, 91 | vis_encoder_net, 92 | llm, 93 | state_obs_space, 94 | hidden_size, 95 | cfg, 96 | **kwargs, 97 | ): 98 | super().__init__() 99 | llm_input_size = llm.d_model 100 | if not hasattr(vis_encoder_net, "embd_size"): 101 | if hasattr(vis_encoder_net, "output_shape"): 102 | input_size = np.prod(vis_encoder_net.output_shape) 103 | else: 104 | raise ValueError("Visual encoder must specify output size.") 105 | else: 106 | input_size = vis_encoder_net.embd_size 107 | 108 | self.visual_fc = nn.Sequential( 109 | nn.Flatten(), 110 | nn.Linear( 111 | input_size, 112 | hidden_size, 113 | ), 114 | nn.ReLU(True), 115 | ) 116 | self._state_obs_space = state_obs_space 117 | visual_dim = hidden_size 118 | input_dim = visual_dim + sum( 119 | space.shape[0] for space in state_obs_space.spaces.values() 120 | ) 121 | self.state_token_proj = nn.Linear(input_dim, llm_input_size) 122 | 123 | def forward(self, vis_features, obs): 124 | # There is only 1 visual token. Extract this token and expand. 125 | 126 | if len(vis_features.shape) == 4: 127 | # Operate on the only visual token. 128 | assert vis_features.shape[2] == 1 129 | 130 | batch_size = vis_features.shape[0] 131 | # Flatten and remove #token dim. 132 | vis_features = rearrange(vis_features, "b r 1 d -> (b r) d") 133 | vis_features = self.visual_fc(vis_features) 134 | vis_features = rearrange(vis_features, "(b r) d -> b r d", b=batch_size) 135 | else: 136 | assert vis_features.shape[1] == 1 137 | vis_features = vis_features[:, 0] 138 | 139 | vis_features = self.visual_fc(vis_features) 140 | 141 | state_features = [obs[k] for k in self._state_obs_space.keys()] 142 | 143 | if vis_features is None: 144 | hidden_window = torch.cat(state_features, dim=-1) 145 | elif len(state_features) == 0: 146 | hidden_window = vis_features 147 | else: 148 | hidden_window = torch.cat([vis_features, *state_features], dim=-1) 149 | 150 | hidden_window = self.state_token_proj(hidden_window) 151 | return hidden_window.unsqueeze(-2) 152 | 153 | @property 154 | def num_tokens(self) -> int: 155 | return 1 156 | -------------------------------------------------------------------------------- /llarp/policies/visual_encoders.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import abc 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | from habitat_baselines.utils.timing import g_timer 12 | from torch import Tensor 13 | from vc_models.models.vit import model_utils 14 | 15 | 16 | class VisualEncoderWrapper(nn.Module, abc.ABC): 17 | @abc.abstractmethod 18 | def forward(self, obs): 19 | """ 20 | Must return shape [batch_size, # visual embeds, token embed dim] 21 | """ 22 | 23 | @property 24 | @abc.abstractmethod 25 | def output_shape(self): 26 | pass 27 | 28 | 29 | class Vc1VisualEncoder(VisualEncoderWrapper): 30 | """ 31 | Wrapper for the VC1 visual encoder using the CLS encoder. 
32 | 33 | :param classifier_feature: Either "use_cls_token" or "reshape_embedding". 34 | """ 35 | 36 | def __init__(self, im_obs_space, use_b16: bool, classifier_feature: str, **kwargs): 37 | super().__init__() 38 | 39 | ( 40 | self.net, 41 | self.embd_size, 42 | self.model_transforms, 43 | model_info, 44 | ) = model_utils.load_model(model_utils.VC1_BASE_NAME) 45 | self.net.classifier_feature = classifier_feature 46 | 47 | self._image_obs_keys = im_obs_space.spaces.keys() 48 | 49 | if use_b16: 50 | self._use_type = torch.bfloat16 51 | else: 52 | self._use_type = torch.float32 53 | 54 | self.to(self._use_type) 55 | 56 | def forward(self, obs): 57 | img = torch.cat( 58 | [v for k, v in obs.items() if k in self._image_obs_keys], dim=-1 59 | ) 60 | img = img.to(self._use_type) 61 | 62 | # Image encoder expects shape [batch_size, img_width, img_height, img_channels] 63 | if len(img.shape) == 5: 64 | # We have a sequence dimension as well. Flatten that into the batch dimension 65 | expand_shape = img.shape[:2] 66 | img = rearrange(img, "b t w h c -> (b t) w h c") 67 | else: 68 | expand_shape = None 69 | 70 | img = self.model_transforms(img.permute(0, 3, 1, 2) / 255.0) 71 | ret = self.net(img) 72 | 73 | if self.net.classifier_feature == "reshape_embedding": 74 | # Flatten the spatial tokens since the PPO storage only stores flat vectors. 75 | ret = rearrange(ret, "b d w h -> b (d w h)") 76 | else: 77 | ret = rearrange(ret, "b d -> b 1 d") 78 | assert ret.shape[1:] == self.output_shape 79 | 80 | # Re-expand the sequence dimension if it is there. 81 | if expand_shape is not None: 82 | ret = rearrange( 83 | ret, "(b t) d -> b t d", b=expand_shape[0], t=expand_shape[1] 84 | ) 85 | 86 | return ret.to(torch.float32) 87 | 88 | @property 89 | def output_shape(self): 90 | if self.net.classifier_feature == "reshape_embedding": 91 | return (np.prod(self.net.patch_embed.grid_size), self.embd_size) 92 | else: 93 | return (1, self.embd_size) 94 | -------------------------------------------------------------------------------- /llarp/run.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ 6 | Script to launch Habitat Baselines trainer. 7 | """ 8 | 9 | import os 10 | import os.path as osp 11 | import random 12 | 13 | import gym 14 | import hydra 15 | import numpy as np 16 | import torch 17 | from habitat.config.default import patch_config 18 | from habitat.config.default_structured_configs import register_hydra_plugin 19 | from habitat_baselines.config.default_structured_configs import \ 20 | HabitatBaselinesConfigPlugin 21 | 22 | import llarp.policies 23 | import llarp.task 24 | import llarp.trainer 25 | from llarp.config import default_structured_configs 26 | 27 | # Suppress gym import warnings. 28 | gym.logger.set_level(40) 29 | 30 | 31 | def on_save_ckpt_callback(save_file_path: str) -> None: 32 | """ 33 | Process the saved checkpoints here if desired. 34 | """ 35 | 36 | pass 37 | 38 | 39 | @hydra.main( 40 | version_base=None, 41 | config_path="config", 42 | # The default is overridden in the launch command. 
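# e.g. (hypothetical invocation, assuming the baseline config shipped with
# the repo): python llarp/run.py --config-name=baseline/llarp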
43 | config_name="pointnav/ppo_pointnav_example", 44 | ) 45 | def main(cfg): 46 | cfg = patch_config(cfg) 47 | random.seed(cfg.habitat.seed) 48 | np.random.seed(cfg.habitat.seed) 49 | torch.manual_seed(cfg.habitat.seed) 50 | 51 | if cfg.habitat_baselines.force_torch_single_threaded and torch.cuda.is_available(): 52 | torch.set_num_threads(1) 53 | 54 | from habitat_baselines.common.baseline_registry import baseline_registry 55 | 56 | trainer_init = baseline_registry.get_trainer(cfg.habitat_baselines.trainer_name) 57 | trainer = trainer_init(cfg) 58 | 59 | if cfg.habitat_baselines.evaluate: 60 | trainer.eval() 61 | else: 62 | trainer.train() 63 | 64 | 65 | if __name__ == "__main__": 66 | register_hydra_plugin(HabitatBaselinesConfigPlugin) 67 | main() 68 | -------------------------------------------------------------------------------- /llarp/task/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os.path as osp 6 | 7 | from habitat.gym.gym_definitions import _try_register 8 | 9 | import llarp.config 10 | from llarp.task.actions import * 11 | from llarp.task.measures import * 12 | from llarp.task.predicate_task import RearrangePredicateTask 13 | from llarp.task.sensors import * 14 | 15 | 16 | def easy_register(gym_id, local_path, overrides=None): 17 | if overrides is None: 18 | overrides = {} 19 | cfg_dir = osp.dirname(osp.dirname(osp.abspath(__file__))) 20 | full_path = osp.join(cfg_dir, "config", "task", local_path) 21 | _try_register( 22 | id_name=gym_id, 23 | entry_point="habitat.gym.gym_definitions:_make_habitat_gym_env", 24 | kwargs={"cfg_file_path": full_path, **overrides}, 25 | ) 26 | 27 | 28 | easy_register("HabitatLanguageTask-v0", "lang_cond.yaml") 29 | easy_register("HabitatVisLanguageTask-v0", "lang_cond.yaml", {"use_render_mode": True}) 30 | -------------------------------------------------------------------------------- /llarp/task/actions.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Any, Optional 6 | 7 | import numpy as np 8 | from gym import spaces 9 | from habitat.core.registry import registry 10 | from habitat.tasks.rearrange.actions.articulated_agent_action import \ 11 | ArticulatedAgentAction 12 | from habitat.tasks.rearrange.actions.grip_actions import ( 13 | GazeGraspAction, GripSimulatorTaskAction, MagicGraspAction, 14 | SuctionGraspAction) 15 | from habitat.tasks.rearrange.rearrange_sim import RearrangeSim 16 | 17 | from llarp.task.utils import get_allowed_actions 18 | 19 | 20 | @registry.register_task_action 21 | class CustomArmAbsJointAction(ArticulatedAgentAction): 22 | @property 23 | def action_space(self): 24 | return spaces.Dict( 25 | { 26 | "arm_joint_action": spaces.Box( 27 | shape=(self._config.arm_joint_dimensionality,), 28 | low=-255.0, 29 | high=255.0, 30 | dtype=np.float32, 31 | ) 32 | } 33 | ) 34 | 35 | def step(self, *args, arm_joint_action, **kwargs): 36 | if np.sum(arm_joint_action) < 1e-3: 37 | # Don't change the arm position if this action wasn't invoked (set to 0). 38 | return 39 | # No clipping because the arm is being set to exactly where it needs to 40 | # go. 
41 | self.cur_articulated_agent.arm_joint_pos = arm_joint_action 42 | 43 | 44 | @registry.register_task_action 45 | class SafeSuctionGraspAction(MagicGraspAction): 46 | def __init__(self, *args, config, sim: RearrangeSim, **kwargs): 47 | super().__init__(*args, config=config, sim=sim, **kwargs) 48 | self._sim: RearrangeSim = sim 49 | self._prevent_change_duration = config.prevent_gripper_change_duration 50 | 51 | def reset(self, *args: Any, **kwargs: Any) -> None: 52 | self._step_count = 0 53 | super().reset(*args, **kwargs) 54 | 55 | def step(self, grip_action, should_step=True, *args, **kwargs): 56 | self._step_count += 1 57 | if grip_action is None: 58 | return 59 | 60 | if self._step_count < self._prevent_change_duration: 61 | return 62 | 63 | if grip_action >= 0 and not self.cur_grasp_mgr.is_grasped: 64 | self._grasp() 65 | elif grip_action < 0 and self.cur_grasp_mgr.is_grasped: 66 | self._ungrasp() 67 | 68 | 69 | @registry.register_task_action 70 | class KinematicArmEEAction(ArticulatedAgentAction): 71 | """Uses inverse kinematics (requires pybullet) to apply end-effector position control for the articulated_agent's arm.""" 72 | 73 | def __init__(self, *args, sim: RearrangeSim, **kwargs): 74 | self.ee_target: Optional[np.ndarray] = None 75 | self.ee_index: Optional[int] = 0 76 | super().__init__(*args, sim=sim, **kwargs) 77 | self._sim: RearrangeSim = sim 78 | self._render_ee_target = False # self._config.get("render_ee_target", False) 79 | self._ee_ctrl_lim = self._config.ee_ctrl_lim 80 | 81 | def reset(self, *args, **kwargs): 82 | super().reset() 83 | cur_ee = self._ik_helper.calc_fk( 84 | np.array(self._sim.articulated_agent.arm_joint_pos) 85 | ) 86 | 87 | self.ee_target = cur_ee 88 | 89 | @property 90 | def action_space(self): 91 | return spaces.Box(shape=(3,), low=-255.0, high=255.0, dtype=np.float32) 92 | 93 | def apply_ee_constraints(self): 94 | self.ee_target = np.clip( 95 | self.ee_target, 96 | self._sim.articulated_agent.params.ee_constraint[self.ee_index, :, 0], 97 | self._sim.articulated_agent.params.ee_constraint[self.ee_index, :, 1], 98 | ) 99 | 100 | def set_desired_ee_pos(self, ee_delta: np.ndarray) -> None: 101 | self.ee_target += np.array(ee_delta) 102 | 103 | self.apply_ee_constraints() 104 | 105 | joint_pos = np.array(self._sim.articulated_agent.arm_joint_pos) 106 | joint_vel = np.zeros(joint_pos.shape) 107 | 108 | self._ik_helper.set_arm_state(joint_pos, joint_vel) 109 | 110 | des_joint_pos = self._ik_helper.calc_ik(self.ee_target) 111 | des_joint_pos = list(des_joint_pos) 112 | self._sim.articulated_agent.arm_joint_pos = des_joint_pos 113 | 114 | def step(self, ee_delta, **kwargs): 115 | speed = np.linalg.norm(ee_delta) 116 | if speed == 0.0: 117 | # Only act when called. 118 | return 119 | if speed > self._ee_ctrl_lim: 120 | # Clip norm. 
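# Rescale the delta so its norm never exceeds `ee_ctrl_lim` while keeping its
# direction.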
121 | ee_delta *= self._ee_ctrl_lim / speed 122 | self.set_desired_ee_pos(ee_delta) 123 | 124 | if self._render_ee_target: 125 | global_pos = ( 126 | self._sim.articulated_agent.base_transformation.transform_point( 127 | self.ee_target 128 | ) 129 | ) 130 | self._sim.viz_ids["ee_target"] = self._sim.visualize_position( 131 | global_pos, self._sim.viz_ids["ee_target"] 132 | ) 133 | 134 | 135 | @registry.register_task_action 136 | class PddlHlAction(ArticulatedAgentAction): 137 | def __init__(self, *args, config, task, **kwargs): 138 | actions = get_allowed_actions(task.pddl_problem, config.allowed_actions) 139 | 140 | self._action_datas = [] 141 | for action in actions: 142 | self._action_datas.append( 143 | (action.name, [p.name for p in action.param_values]) 144 | ) 145 | 146 | super().__init__(*args, config=config, task=task, **kwargs) 147 | 148 | @property 149 | def action_space(self): 150 | return spaces.Discrete(len(self._action_datas)) 151 | 152 | @property 153 | def was_prev_action_invalid(self): 154 | return self._was_prev_action_invalid 155 | 156 | def reset(self, *args, task, **kwargs): 157 | self._was_prev_action_invalid = False 158 | self._prev_action = None 159 | 160 | def step(self, *args, sel, task, **kwargs): 161 | pddl = task.pddl_problem 162 | action_name, param_names = self._action_datas[sel] 163 | 164 | # Get the current up to date PDDL action. 165 | param_values = [] 166 | missing_entity = False 167 | for name in param_names: 168 | if name not in pddl.all_entities: 169 | missing_entity = True 170 | break 171 | param_values.append(pddl.all_entities[name]) 172 | 173 | if missing_entity: 174 | # self._was_prev_action_invalid = True 175 | # return 176 | raise ValueError("MISSING ENTITY. THIS SHOULDNT HAPPEN") 177 | 178 | apply_action = task.pddl_problem.actions[action_name].clone() 179 | apply_action.set_param_values(param_values) 180 | 181 | self._was_prev_action_invalid = not apply_action.apply_if_true( 182 | task.pddl_problem.sim_info 183 | ) 184 | self._prev_action = apply_action 185 | -------------------------------------------------------------------------------- /llarp/task/predicate_task.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import inspect 6 | import os.path as osp 7 | import random 8 | import time 9 | from typing import Any, Dict, List 10 | 11 | import hydra 12 | import magnum as mn 13 | import numpy as np 14 | from habitat.core.registry import registry 15 | from habitat.datasets.rearrange.rearrange_dataset import RearrangeEpisode 16 | from habitat.tasks.rearrange.multi_task.pddl_domain import PddlDomain 17 | from habitat.tasks.rearrange.multi_task.pddl_logical_expr import LogicalExpr 18 | from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate 19 | from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( 20 | ExprType, PddlEntity, SimulatorObjectType) 21 | from habitat.tasks.rearrange.rearrange_task import RearrangeTask 22 | from habitat.tasks.rearrange.utils import add_perf_timing_func 23 | from omegaconf import DictConfig, ListConfig 24 | from PIL import Image 25 | from transformers import AutoTokenizer 26 | 27 | import llarp.config 28 | from llarp.dataset.episodes import LangRearrangeEpisode 29 | from llarp.dataset.utils import get_category_info 30 | from llarp.task.actions import KinematicArmEEAction 31 | from llarp.task.utils import PLACABLE_RECEP_TYPE, get_parser, get_pddl 32 | 33 | 34 | @registry.register_task(name="RearrangePredicateTask-v0") 35 | class RearrangePredicateTask(RearrangeTask): 36 | def __init__(self, *args, sim, config, dataset, **kwargs): 37 | print(f"Num episodes {len(dataset.episodes)}") 38 | 39 | self._all_cls, obj_cats, self._name_to_cls = get_category_info( 40 | config.skip_load_receps 41 | ) 42 | 43 | self.pddl = get_pddl(config, self._all_cls, obj_cats) 44 | self._fix_agent_pos = config.fix_agent_pos 45 | 46 | super().__init__( 47 | *args, 48 | sim=sim, 49 | config=config, 50 | dataset=dataset, 51 | should_place_articulated_agent=not self._fix_agent_pos, 52 | **kwargs, 53 | ) 54 | self._tokenizer = get_parser(config.tokenizer_name) 55 | 56 | self._start_template = self._config.start_template 57 | self._goal_template = self._config.goal_template 58 | self._sample_entities = self._config.sample_entities 59 | self._sample_entities_use_constant_sampling = ( 60 | self._config.sample_entities_use_constant_sampling 61 | ) 62 | self._force_scene_per_worker = self._config.force_scene_per_worker 63 | self._goal_expr = None 64 | self._is_first_reset = True 65 | self._is_freeform = False 66 | 67 | @property 68 | def tokenizer(self): 69 | return self._tokenizer 70 | 71 | @add_perf_timing_func() 72 | def _load_start_goal(self, episode): 73 | """ 74 | Setup the start and goal PDDL conditions. Will change the simulator 75 | state to set the start state. 
76 | """ 77 | 78 | pddl_entities = self.pddl.all_entities 79 | self.pddl.bind_to_instance(self._sim, self._dataset, self, episode) 80 | 81 | self._setup_pddl_entities(episode) 82 | 83 | if self._is_first_reset or not self._force_scene_per_worker: 84 | self.pddl.bind_actions() 85 | self._is_first_reset = False 86 | 87 | self._sim.internal_step(-1) 88 | self._load_sampled_names() 89 | 90 | self._goal_expr = self._load_goal_preds(episode) 91 | if self._goal_expr is not None: 92 | t_start = time.time() 93 | self._goal_expr, _ = self.pddl.expand_quantifiers(self._goal_expr) 94 | self._sim.add_perf_timing("goal_expand_quantifiers", t_start) 95 | self._load_start_preds(episode) 96 | 97 | def _load_sampled_names(self): 98 | t_start = time.time() 99 | self.new_entities: Dict[str, PddlEntity] = {} 100 | for entity_name, entity_conds in self._sample_entities.items(): 101 | match_type = self.pddl.expr_types[entity_conds["type"]] 102 | matches = list(self.pddl.find_entities(match_type)) 103 | 104 | if entity_conds.get("ignore_articulated_receptacles", False): 105 | matches = [ 106 | m 107 | for m in matches 108 | if m.name 109 | not in [ 110 | "receptacle_aabb_middle_topfrl_apartment_refrigerator", 111 | "receptacle_aabb_drawer_left_top_frl_apartment_kitchen_counter", 112 | "receptacle_aabb_drawer_right_top_frl_apartment_kitchen_counter", 113 | ] 114 | ] 115 | 116 | pred_conds = entity_conds.get("pred_conds", []) 117 | matches = _filter_matching_pred_conds( 118 | matches, pred_conds, self.pddl, self.new_entities, entity_name 119 | ) 120 | 121 | if len(matches) == 0: 122 | obj_ref, obj_name = list(self.new_entities.items())[0] 123 | on_top_pred = f"on_top({obj_ref},{entity_name})" 124 | cmp_pred_cond = pred_conds[0].replace(" ", "") 125 | if len(pred_conds) == 1 and cmp_pred_cond == on_top_pred: 126 | use_entity_name: str = self._sim.ep_info.name_to_receptacle[ 127 | obj_name.name 128 | ] 129 | entity = self.pddl.all_entities[use_entity_name] 130 | else: 131 | raise ValueError( 132 | f"Could not find match for {entity_name}: {entity_conds}" 133 | ) 134 | elif self._sample_entities_use_constant_sampling: 135 | entity = matches[0] 136 | else: 137 | entity = random.choice(matches) 138 | 139 | self.new_entities[entity_name] = entity 140 | self._sampled_names = list(self.new_entities.keys()) 141 | self._sim.add_perf_timing("find_entities", t_start) 142 | 143 | @property 144 | def pddl_problem(self): 145 | return self.pddl 146 | 147 | @property 148 | def all_cls(self): 149 | return self._all_cls 150 | 151 | def _get_cls_for_obj_name(self, obj_name, episode): 152 | return self._name_to_cls[obj_name] 153 | 154 | @add_perf_timing_func() 155 | def _setup_pddl_entities(self, episode): 156 | # Register the specific objects in this scene as simulator entities. 157 | for obj_name in self.pddl.sim_info.obj_ids: 158 | asset_name = _strip_instance_id(obj_name) 159 | if asset_name not in self._name_to_cls: 160 | # This is a goal or target object indicator from a Hab2 dataset. 161 | # We don't use these geo-goals, so we can ignore. 162 | continue 163 | cls_name = self._name_to_cls[asset_name] 164 | 165 | entity_type = self.pddl.expr_types[cls_name] 166 | self.pddl.register_episode_entity(PddlEntity(obj_name, entity_type)) 167 | 168 | @add_perf_timing_func() 169 | def _load_goal_preds(self, episode): 170 | if self._goal_template is None: 171 | # Load from the episode. 172 | return self.pddl.parse_only_logical_expr( 173 | episode.goal_preds, self.pddl.all_entities 174 | ) 175 | else: 176 | # Load from the config. 
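# The configured goal template is first resolved by substituting the sampled
# entity names into it, and only then parsed into a logical expression.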
177 | goal_d = dict(self._goal_template) 178 | goal_d = _recur_dict_replace(goal_d, self.new_entities) 179 | return self.pddl.parse_only_logical_expr(goal_d, self.pddl.all_entities) 180 | 181 | @add_perf_timing_func() 182 | def _load_start_preds(self, episode): 183 | if self._start_template is None: 184 | # Load form the episode data. 185 | for pred in episode.start_preds: 186 | pred = self.pddl.parse_predicate(pred, self.pddl.all_entities) 187 | pred.set_state(self.pddl.sim_info) 188 | else: 189 | # Load from the config. 190 | start_preds = self._start_template[:] 191 | for pred in start_preds: 192 | for k, entity in self.new_entities.items(): 193 | pred = pred.replace(k, entity.name) 194 | pred = self.pddl.parse_predicate(pred, self.pddl.all_entities) 195 | pred.set_state(self.pddl.sim_info) 196 | 197 | @property 198 | def goal_expr(self) -> LogicalExpr: 199 | return self._goal_expr 200 | 201 | def set_is_freeform(self, is_freeform): 202 | self._is_freeform = is_freeform 203 | 204 | @property 205 | def subgoals(self): 206 | if self._is_freeform: 207 | return [] 208 | else: 209 | return self._subgoals 210 | 211 | def is_goal_satisfied(self): 212 | if self._is_freeform: 213 | return False 214 | if self._goal_expr is None: 215 | return False 216 | ret = self.pddl.is_expr_true(self._goal_expr) 217 | return ret 218 | 219 | @add_perf_timing_func() 220 | def _get_subgoals(self, episode) -> List[List[Predicate]]: 221 | if episode.subgoals is None: 222 | return [] 223 | # start_preds = self.pddl.get_true_predicates() 224 | ret_subgoals = [] 225 | for subgoal in episode.subgoals: 226 | subgoal_preds = [] 227 | for subgoal_predicate in subgoal: 228 | pred = self.pddl.parse_predicate( 229 | subgoal_predicate, self.pddl.all_entities 230 | ) 231 | subgoal_preds.append(pred) 232 | if len(subgoal_preds) != 0: 233 | ret_subgoals.append(subgoal_preds) 234 | return ret_subgoals 235 | 236 | @add_perf_timing_func() 237 | def step(self, *args, action, **kwargs): 238 | self.pddl.sim_info.reset_pred_truth_cache() 239 | fix_top_down_cam_pos(self._sim) 240 | self.num_steps += 1 241 | if "action_args" not in action: 242 | # This won't be added to discrete action spaces, but RearrangeTask 243 | # expects it. 244 | action["action_args"] = {"sel": action["action"]} 245 | action["action"] = 0 246 | 247 | return super().step(*args, action=action, **kwargs) 248 | 249 | @add_perf_timing_func() 250 | def reset(self, episode): 251 | self.num_steps = 0 252 | 253 | # Generate a random episode ID that can be used for debugging purpose. This should be different than the dataset episode ID, because multiple workers could be operating on the same episode. 254 | self.episode_rollout_id = random.randint(0, 9999999) 255 | 256 | super().reset(episode, fetch_observations=False) 257 | if type(episode) == RearrangeEpisode: 258 | # Hab2 style episode. 259 | episode = convert_rearrange_ep_to_lang_rearrange(episode) 260 | self.lang_goal = episode.instruction 261 | self.instruct_id = episode.instruct_id 262 | self.sampler_info = episode.sampler_info 263 | self._load_start_goal(episode) 264 | 265 | if self._fix_agent_pos: 266 | # Fix the agent position to 0,0 and the rotation to 0. 
267 | agent = self._sim.get_agent_data(0).articulated_agent 268 | agent.base_pos = self._sim.pathfinder.snap_point( 269 | mn.Vector3(0.0, agent.base_pos[1], 0.0) 270 | ) 271 | agent.base_rot = 0.0 272 | 273 | self._subgoals = list(self._get_subgoals(episode)) 274 | 275 | self._sim.draw_bb_objs.clear() 276 | fix_top_down_cam_pos(self._sim) 277 | 278 | self._sim.maybe_update_articulated_agent() 279 | return self._get_observations(episode) 280 | 281 | def get_sampled(self) -> List[PddlEntity]: 282 | return [self.new_entities[k] for k in self._sampled_names] 283 | 284 | 285 | def _strip_instance_id(instance_id: str) -> str: 286 | # Strip off the unique instance ID of the object and only return the asset 287 | # name. 288 | return "_".join(instance_id.split("_")[:-1]) 289 | 290 | 291 | def _recur_dict_replace(d: Any, replaces: Dict[str, PddlEntity]) -> Any: 292 | """ 293 | Replace all string entries in `d` with the replace name to PDDL entity 294 | mapping in replaces. 295 | """ 296 | if isinstance(d, ListConfig): 297 | d = list(d) 298 | if isinstance(d, DictConfig): 299 | d = dict(d) 300 | 301 | if isinstance(d, str): 302 | for name, entity in replaces.items(): 303 | d = d.replace(f"{name}.type", entity.expr_type.name) 304 | d = d.replace(name, entity.name) 305 | elif isinstance(d, list): 306 | for i, v in enumerate(d): 307 | d[i] = _recur_dict_replace(v, replaces) 308 | elif isinstance(d, dict): 309 | for k, v in d.items(): 310 | d[k] = _recur_dict_replace(d[k], replaces) 311 | return d 312 | 313 | 314 | def _filter_matching_pred_conds( 315 | matches: List[PddlEntity], 316 | pred_conds: List[str], 317 | pddl: PddlDomain, 318 | other_entities: Dict[str, PddlEntity], 319 | entity_name: str, 320 | ) -> List[PddlEntity]: 321 | """ 322 | Filters elements from matches based on a set of PDDL predicate strings. 
323 | """ 324 | 325 | ret = [] 326 | for match in matches: 327 | should_include = True 328 | for pred_cond in pred_conds: 329 | subbed_pred_cond = pred_cond.replace(entity_name, match.name) 330 | for other_entity_name, other_entity in other_entities.items(): 331 | subbed_pred_cond = subbed_pred_cond.replace( 332 | other_entity_name, other_entity.name 333 | ) 334 | pred = pddl.parse_predicate(subbed_pred_cond, pddl.all_entities) 335 | 336 | if not pred.is_true(pddl.sim_info): 337 | should_include = False 338 | break 339 | if not should_include: 340 | continue 341 | ret.append(match) 342 | 343 | return ret 344 | 345 | 346 | def fix_top_down_cam_pos(sim): 347 | if "top_down_rgb" not in sim._sensors: 348 | return 349 | middle_of_scene = np.array(sim.pathfinder.get_bounds()).mean(0) 350 | look_at_mat = mn.Matrix4.look_at( 351 | mn.Vector3(middle_of_scene[0], 6.5, middle_of_scene[0]), 352 | middle_of_scene, 353 | mn.Vector3(0.0, 1.0, 0.0), 354 | ) 355 | rot_mat = mn.Matrix4.from_( 356 | mn.Matrix4.rotation(mn.Rad(np.pi / 2), mn.Vector3(0.0, 1.0, 0.0)).rotation(), 357 | # mn.Vector3(2.5, 0.0, 1.0), 358 | mn.Vector3(0.0, 0.0, 4.0), 359 | ) 360 | sim._sensors["top_down_rgb"]._sensor_object.node.transformation = ( 361 | rot_mat @ look_at_mat 362 | ) 363 | 364 | 365 | def convert_rearrange_ep_to_lang_rearrange( 366 | ep: RearrangeEpisode, 367 | ) -> LangRearrangeEpisode: 368 | return LangRearrangeEpisode( 369 | instruction="", 370 | sampled_entities={}, 371 | start_preds=[], 372 | goal_preds={}, 373 | instruct_id="", 374 | sampler_info={}, 375 | subgoals=[], 376 | **ep.__getstate__(), 377 | ) 378 | -------------------------------------------------------------------------------- /llarp/task/sensors.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os.path as osp 6 | from typing import List 7 | 8 | import gym.spaces as spaces 9 | import numpy as np 10 | import torch 11 | from habitat.core.registry import registry 12 | from habitat.core.simulator import Sensor, SensorTypes 13 | from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate 14 | from habitat.tasks.rearrange.multi_task.rearrange_pddl import \ 15 | SimulatorObjectType 16 | from habitat.tasks.rearrange.utils import batch_transform_point 17 | 18 | from llarp.dataset.utils import get_category_info, get_name_mappings 19 | from llarp.task.utils import get_parser 20 | 21 | 22 | @registry.register_sensor 23 | class OneHotTargetSensor(Sensor): 24 | def __init__(self, *args, task, **kwargs): 25 | self._task = task 26 | self._n_cls = len(task.all_cls) 27 | self._obs = np.zeros((self._n_cls,)) 28 | super().__init__(*args, **kwargs) 29 | 30 | def _get_uuid(self, *args, **kwargs): 31 | return "one_hot_target_sensor" 32 | 33 | def _get_sensor_type(self, *args, **kwargs): 34 | return SensorTypes.TENSOR 35 | 36 | def _get_observation_space(self, *args, config, **kwargs): 37 | return spaces.Box(shape=(self._n_cls,), low=0, high=1, dtype=np.float32) 38 | 39 | def get_observation(self, *args, **kwargs): 40 | cur_target = self._task.get_sampled()[self.config.sampled_idx] 41 | 42 | # For receptacles the name will not be a class but the name directly. 
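# Default to the entity's category (expr_type) name; if the instance name
# itself appears in `all_cls` (as receptacle names do), use it instead below.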
43 | use_name = cur_target.expr_type.name 44 | if cur_target.name in self._task.all_cls: 45 | use_name = cur_target.name 46 | set_i = self._task.all_cls.index(use_name) 47 | self._obs *= 0.0 48 | self._obs[set_i] = 1.0 49 | 50 | return self._obs 51 | 52 | 53 | @registry.register_sensor 54 | class SimpleTargetSensor(Sensor): 55 | uuid = "simple_lang_goal" 56 | 57 | def __init__(self, *args, config, **kwargs): 58 | self._max_len = config.max_len 59 | self._add_special_tokens = config.add_special_tokens 60 | self._name_mappings = get_name_mappings() 61 | self.cur_lang_goal = "" 62 | super().__init__(*args, config=config, **kwargs) 63 | 64 | def _get_uuid(self, *args, **kwargs): 65 | return SimpleTargetSensor.uuid 66 | 67 | def _get_sensor_type(self, *args, **kwargs): 68 | return SensorTypes.TENSOR 69 | 70 | def _get_observation_space(self, *args, config, **kwargs): 71 | return spaces.Box( 72 | shape=(self._max_len,), 73 | low=0, 74 | high=1, 75 | dtype=np.int64, 76 | ) 77 | 78 | def get_observation(self, *args, task, **kwargs): 79 | cur_target = task.get_sampled()[0] 80 | if cur_target.expr_type.parent.name == SimulatorObjectType.MOVABLE_ENTITY.value: 81 | # This is an object. 82 | target_name = cur_target.expr_type.name 83 | else: 84 | # This is a receptacle. 85 | target_name = cur_target.name 86 | 87 | target_name = self._name_mappings.get(target_name, target_name) 88 | 89 | tokens = task.tokenizer( 90 | target_name, 91 | return_tensors="np", 92 | padding="max_length", 93 | max_length=self._max_len, 94 | add_special_tokens=self._add_special_tokens, 95 | )["input_ids"][0] 96 | 97 | assert ( 98 | tokens.shape[0] == self._max_len 99 | ), f"Instruction is too many tokens at {tokens.shape[0]} but the max is {self._max_len} for instruction {target_name}" 100 | self.cur_lang_goal = target_name 101 | return tokens 102 | 103 | 104 | @registry.register_sensor 105 | class WindowDebugSensor(Sensor): 106 | uuid = "debug_info" 107 | 108 | def _get_uuid(self, *args, **kwargs): 109 | return WindowDebugSensor.uuid 110 | 111 | def _get_sensor_type(self, *args, **kwargs): 112 | return SensorTypes.TENSOR 113 | 114 | def _get_observation_space(self, *args, config, **kwargs): 115 | return spaces.Box( 116 | shape=(2,), 117 | low=np.finfo(np.float32).min, 118 | high=np.finfo(np.float32).max, 119 | dtype=np.float32, 120 | ) 121 | 122 | def get_observation(self, *args, task, **kwargs): 123 | return np.array([task.num_steps + 1, task.episode_rollout_id], dtype=np.float32) 124 | 125 | 126 | @registry.register_sensor 127 | class StepCountSensor(Sensor): 128 | uuid = "step_count" 129 | 130 | def _get_uuid(self, *args, **kwargs): 131 | return StepCountSensor.uuid 132 | 133 | def _get_sensor_type(self, *args, **kwargs): 134 | return SensorTypes.TENSOR 135 | 136 | def _get_observation_space(self, *args, config, **kwargs): 137 | return spaces.Box( 138 | shape=(1,), 139 | low=np.finfo(np.float32).min, 140 | high=np.finfo(np.float32).max, 141 | dtype=np.float32, 142 | ) 143 | 144 | def get_observation(self, *args, task, **kwargs): 145 | return np.array([task.num_steps], dtype=np.float32) 146 | 147 | 148 | @registry.register_sensor 149 | class VocabLangGoalSensor(Sensor): 150 | uuid = "vocab_lang_goal" 151 | 152 | def __init__(self, *args, config, **kwargs): 153 | self._max_len = config.max_len 154 | self._add_special_tokens = config.add_special_tokens 155 | super().__init__(*args, config=config, **kwargs) 156 | 157 | def _get_uuid(self, *args, **kwargs): 158 | return VocabLangGoalSensor.uuid 159 | 160 | def 
_get_sensor_type(self, *args, **kwargs): 161 | return SensorTypes.TENSOR 162 | 163 | def _get_observation_space(self, *args, config, **kwargs): 164 | return spaces.Box( 165 | shape=(self._max_len,), 166 | low=0, 167 | high=1, 168 | dtype=np.int64, 169 | ) 170 | 171 | def get_observation(self, *args, task, **kwargs): 172 | tokens = task.tokenizer( 173 | task.lang_goal, 174 | return_tensors="np", 175 | padding="max_length", 176 | max_length=self._max_len, 177 | add_special_tokens=self._add_special_tokens, 178 | )["input_ids"][0] 179 | assert ( 180 | tokens.shape[0] == self._max_len 181 | ), f"Instruction is too many tokens at {tokens.shape[0]} but the max is {self._max_len}" 182 | return tokens 183 | 184 | 185 | @registry.register_sensor 186 | class T5VocabLangGoalSensor(Sensor): 187 | """ 188 | Always outputs the T5 tokenization. 189 | """ 190 | 191 | uuid = "t5_vocab_lang_goal" 192 | 193 | def __init__(self, *args, config, **kwargs): 194 | self._max_len = config.max_len 195 | self._tokenizer = get_parser("google/flan-t5-small") 196 | super().__init__(*args, config=config, **kwargs) 197 | 198 | def _get_uuid(self, *args, **kwargs): 199 | return T5VocabLangGoalSensor.uuid 200 | 201 | def _get_sensor_type(self, *args, **kwargs): 202 | return SensorTypes.TENSOR 203 | 204 | def _get_observation_space(self, *args, **kwargs): 205 | return spaces.Box( 206 | shape=(self._max_len,), 207 | low=0, 208 | high=1, 209 | dtype=np.int64, 210 | ) 211 | 212 | def get_observation(self, *args, task, **kwargs): 213 | tokens = self._tokenizer( 214 | task.lang_goal, 215 | return_tensors="np", 216 | padding="max_length", 217 | max_length=self._max_len, 218 | )["input_ids"][0] 219 | assert tokens.shape[0] == self._max_len 220 | return tokens 221 | 222 | 223 | @registry.register_sensor 224 | class LlamaVocabLangGoalSensor(Sensor): 225 | """ 226 | Always outputs the llama tokenization. 227 | """ 228 | 229 | uuid = "llama_vocab_lang_goal" 230 | 231 | def __init__(self, *args, config, **kwargs): 232 | self._max_len = config.max_len 233 | self._tokenizer = get_parser(config.tokenizer_name) 234 | super().__init__(*args, config=config, **kwargs) 235 | 236 | def _get_uuid(self, *args, **kwargs): 237 | return LlamaVocabLangGoalSensor.uuid 238 | 239 | def _get_sensor_type(self, *args, **kwargs): 240 | return SensorTypes.TENSOR 241 | 242 | def _get_observation_space(self, *args, **kwargs): 243 | return spaces.Box( 244 | shape=(self._max_len,), 245 | low=0, 246 | high=1, 247 | dtype=np.int64, 248 | ) 249 | 250 | def get_observation(self, *args, task, **kwargs): 251 | tokens = self._tokenizer( 252 | task.lang_goal, 253 | return_tensors="np", 254 | padding="max_length", 255 | max_length=self._max_len, 256 | )["input_ids"][0] 257 | assert tokens.shape[0] == self._max_len 258 | return tokens 259 | 260 | 261 | @registry.register_sensor 262 | class ObsLangSensor(Sensor): 263 | uuid = "obs_lang" 264 | 265 | def __init__(self, *args, config, task, **kwargs): 266 | self._max_len = config.max_len 267 | self._task = task 268 | self._predicates_list = None 269 | super().__init__(*args, config=config, task=task, **kwargs) 270 | 271 | def _get_uuid(self, *args, **kwargs): 272 | return ObsLangSensor.uuid 273 | 274 | def _get_sensor_type(self, *args, **kwargs): 275 | return SensorTypes.TENSOR 276 | 277 | @property 278 | def predicates_list(self) -> List[Predicate]: 279 | """ 280 | Returns all possible predicate combinations in the environment. 
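The list is built once on first access and cached in self._predicates_list
for subsequent calls.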
281 | """ 282 | 283 | if self._predicates_list is None: 284 | self._predicates_list = self._task.pddl_problem.get_possible_predicates() 285 | return self._predicates_list 286 | 287 | def _get_observation_space(self, *args, config, **kwargs): 288 | return spaces.Box( 289 | shape=(self._max_len,), 290 | low=0, 291 | high=1, 292 | dtype=np.int64, 293 | ) 294 | 295 | def get_observation(self, *args, **kwargs): 296 | # Fetch the predicates that are true in the current simulator step. 297 | sim_info = self._task.pddl_problem.sim_info 298 | true_preds: List[Predicate] = [ 299 | p for p in self.predicates_list if p.is_true(sim_info) 300 | ] 301 | 302 | # Conver the predicates to a string representation. 303 | true_preds_s: List[str] = [p.compact_str for p in true_preds] 304 | 305 | # Join all the predicate strings by sentences. 306 | state_s = ". ".join(true_preds_s) 307 | 308 | # Return the tokenized version. 309 | return self._task.tokenizer( 310 | state_s, 311 | return_tensors="np", 312 | padding="max_length", 313 | max_length=self._max_len, 314 | truncation=True, 315 | )["input_ids"][0] 316 | 317 | 318 | @registry.register_sensor 319 | class VocabEmbedSensor(Sensor): 320 | def __init__(self, *args, config, **kwargs): 321 | embed_dat = torch.load(config.embed_path) 322 | self._embeddings = embed_dat["hxs"] 323 | self._ep_idx_to_idx = embed_dat["ep_idx_to_hxs_idx"] 324 | super().__init__(*args, config=config, **kwargs) 325 | 326 | def _get_uuid(self, *args, **kwargs): 327 | return "vocab_embed_goal" 328 | 329 | def _get_sensor_type(self, *args, **kwargs): 330 | return SensorTypes.TENSOR 331 | 332 | def _get_observation_space(self, *args, config, **kwargs): 333 | return spaces.Box( 334 | shape=(self.config.hidden_dim,), 335 | low=np.finfo(np.float32).min, 336 | high=np.finfo(np.float32).max, 337 | dtype=np.float32, 338 | ) 339 | 340 | def get_observation(self, *args, task, **kwargs): 341 | pt_idx = self._ep_idx_to_idx[task._episode_id] 342 | return self._embeddings[pt_idx] 343 | 344 | 345 | @registry.register_sensor 346 | class ClosestTargetObjectPosSensor(Sensor): 347 | """ 348 | Gets the distance between the EE and the closest target object. 
349 | """ 350 | 351 | @staticmethod 352 | def _get_uuid(*args, **kwargs): 353 | return "closest_targ_obj_pos_sensor" 354 | 355 | def _get_sensor_type(self, *args, **kwargs): 356 | return SensorTypes.TENSOR 357 | 358 | def _get_observation_space(self, *args, config, **kwargs): 359 | return spaces.Box( 360 | shape=(3,), 361 | low=np.finfo(np.float32).min, 362 | high=np.finfo(np.float32).max, 363 | dtype=np.float32, 364 | ) 365 | 366 | def get_observation(self, *args, task, **kwargs): 367 | targ_type = task.get_sampled()[0] 368 | matches = list(task.pddl.find_entities(targ_type.expr_type)) 369 | obj_pos = np.array( 370 | [task.pddl.sim_info.get_entity_pos(match) for match in matches] 371 | ) 372 | 373 | ee_pos = task._sim.articulated_agent.ee_transform().translation 374 | 375 | dists = np.linalg.norm(obj_pos - ee_pos, ord=2, axis=-1) 376 | closest_id = np.argmin(dists) 377 | return obj_pos[closest_id] 378 | 379 | 380 | @registry.register_sensor 381 | class AllObjectPositionsSensor(Sensor): 382 | def __init__(self, *args, task, **kwargs): 383 | self._all_cats, _, _ = get_category_info() 384 | super().__init__(*args, **kwargs) 385 | 386 | def _get_uuid(self, *args, **kwargs): 387 | return "all_obj_pos_sensor" 388 | 389 | def _get_sensor_type(self, *args, **kwargs): 390 | return SensorTypes.TENSOR 391 | 392 | def _get_observation_space(self, *args, config, **kwargs): 393 | return spaces.Box( 394 | shape=(3,), 395 | low=np.finfo(np.float32).min, 396 | high=np.finfo(np.float32).max, 397 | dtype=np.float32, 398 | ) 399 | 400 | def get_observation(self, *args, task, **kwargs): 401 | pddl = task.pddl 402 | entities = pddl.get_ordered_entities_list() 403 | 404 | entity_pos = [pddl.sim_info.get_entity_pos(entity) for entity in entities] 405 | 406 | base_T = task._sim.articulated_agent.base_transformation 407 | entity_pos = batch_transform_point(entity_pos, base_T.inverted(), np.float32) 408 | 409 | entity_type_idxs = [] 410 | for entity in entities: 411 | if entity.name in self._all_cats: 412 | entity_idx = self._all_cats.index(entity.name) 413 | elif entity.expr_type.name in self._all_cats: 414 | entity_idx = self._all_cats.index(entity.expr_type.name) 415 | else: 416 | entity_idx = -1 417 | entity_type_idxs.append(entity_idx) 418 | 419 | return np.array( 420 | [[*pos, idx] for pos, idx in zip(entity_pos, entity_type_idxs)], 421 | dtype=np.float32, 422 | ).reshape(-1) 423 | -------------------------------------------------------------------------------- /llarp/task/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
4 | # 5 | import inspect 6 | import os.path as osp 7 | from typing import Dict, List, Tuple 8 | 9 | import yaml 10 | from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction 11 | from habitat.tasks.rearrange.multi_task.pddl_domain import PddlDomain 12 | from habitat.tasks.rearrange.multi_task.pddl_logical_expr import ( 13 | LogicalExpr, LogicalQuantifierType) 14 | from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( 15 | ExprType, PddlEntity, SimulatorObjectType) 16 | from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, 17 | LlamaForCausalLM, LlamaModel, LlamaTokenizer, 18 | T5Model) 19 | 20 | import llarp.config 21 | import llarp.dataset 22 | 23 | # Also defined in the PDDL 24 | PLACABLE_RECEP_TYPE = "place_receptacle" 25 | 26 | 27 | def get_allowed_actions(pddl, allowed_substrings): 28 | all_actions = pddl.get_possible_actions() 29 | 30 | def matches_any(s, allowed_acs): 31 | # returns if the string starts with any strings from allowed_acs 32 | return any(s.name.startswith(ac) for ac in allowed_acs) 33 | 34 | return [ac for ac in all_actions if matches_any(ac, allowed_substrings)] 35 | 36 | 37 | def _recur_replace(expr, search, replace): 38 | if isinstance(expr, LogicalExpr): 39 | for subexpr in expr.sub_exprs: 40 | _recur_replace(subexpr, search, replace) 41 | else: 42 | for i, arg_val in enumerate(expr._arg_values): 43 | if arg_val == search: 44 | expr._arg_values[i] = replace 45 | 46 | 47 | def flatten_actions(pddl: PddlDomain, obj_cats): 48 | new_acs = {} 49 | for ac_name, action in pddl.actions.items(): 50 | found_i = -1 51 | # TODO: Currently this is a hack for only the pick action. This will 52 | # not work for other PDDL actions. 53 | 54 | for i, param in enumerate(action.params): 55 | if param.expr_type.name == SimulatorObjectType.MOVABLE_ENTITY.value: 56 | found_i = i 57 | break 58 | 59 | if found_i == -1: 60 | new_acs[ac_name] = action 61 | continue 62 | 63 | param = action.params[found_i] 64 | del action.params[found_i] 65 | for obj_cat in obj_cats: 66 | precond = action.precond.clone() 67 | assert len(precond.inputs) == 0, precond.quantifier is None 68 | 69 | obj_cat_type = pddl.expr_types[obj_cat] 70 | at_entity = PddlEntity(name="DYN_OBJ", expr_type=obj_cat_type) 71 | inputs = [at_entity] 72 | 73 | # Ignore the first expression which was about the robot position. 74 | precond = precond.sub_in({param: at_entity}) 75 | 76 | postcond_pred = pddl.parse_predicate( 77 | f"holding({at_entity.name})", {at_entity.name: at_entity} 78 | ) 79 | obj_action = PddlAction( 80 | f"{action.name}_{obj_cat}", 81 | action._params, 82 | pre_cond=LogicalExpr( 83 | precond.expr_type, 84 | precond.sub_exprs, 85 | inputs, 86 | LogicalQuantifierType.EXISTS, 87 | ), 88 | post_cond=[postcond_pred], 89 | ) 90 | 91 | new_acs[obj_action.name] = obj_action 92 | pddl.set_actions(new_acs) 93 | return pddl 94 | 95 | 96 | def get_obj_type(cat, pddl): 97 | return pddl.expr_types[SimulatorObjectType.MOVABLE_ENTITY.value] 98 | 99 | 100 | def get_pddl(task_config, all_cats, obj_cats) -> PddlDomain: 101 | config_path = osp.dirname(inspect.getfile(llarp.config)) 102 | domain_file_path = osp.join( 103 | config_path, 104 | task_config.task_spec_base_path, 105 | task_config.pddl_domain_def + ".yaml", 106 | ) 107 | pddl = PddlDomain( 108 | domain_file_path, 109 | task_config, 110 | ) 111 | 112 | # Add permanent entity types. (Object types and receptacles). 
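# Categories listed in obj_cats are added as new expression types that subtype
# the movable-object type, while every other category is registered as a
# constant entity of the place_receptacle type, i.e. a fixed receptacle shared
# across episodes rather than sampled per episode.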
113 | for cat in all_cats: 114 | if cat in obj_cats: 115 | obj_type = get_obj_type(cat, pddl) 116 | entity_type = ExprType(cat, obj_type) 117 | pddl._expr_types[cat] = entity_type 118 | else: 119 | # Assume this is a receptacle in the scene. Permanently place, not per episode. 120 | pddl._constants[cat] = PddlEntity(cat, pddl.expr_types[PLACABLE_RECEP_TYPE]) 121 | 122 | return flatten_actions(pddl, obj_cats) 123 | 124 | 125 | def get_parser(llm_id): 126 | if "llama" in llm_id.lower(): 127 | tokenizer = LlamaTokenizer.from_pretrained(llm_id) 128 | # llama has no pad token by default. As per this thread: 129 | # https://github.com/huggingface/transformers/issues/22312 we should 130 | # set pad token manually. 131 | tokenizer.pad_token = "[PAD]" 132 | return tokenizer 133 | else: 134 | tokenizer = AutoTokenizer.from_pretrained(llm_id) 135 | tokenizer.pad_token = "[PAD]" 136 | return tokenizer 137 | -------------------------------------------------------------------------------- /llarp/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from llarp.trainer.trainer_loop import ILPPOTrainer 6 | from llarp.trainer.transformer_ppo import (DistributedTransformerPPO, 7 | TransformerPPO) 8 | -------------------------------------------------------------------------------- /llarp/trainer/custom_ddp.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import contextlib 6 | import functools 7 | import io 8 | import os 9 | import pickle 10 | import signal 11 | import socket 12 | import subprocess 13 | import threading 14 | from os import path as osp 15 | from typing import (Any, Callable, List, Optional, Tuple, TypeVar, Union, 16 | overload) 17 | 18 | import ifcfg 19 | import numpy as np 20 | import torch 21 | from habitat import logger 22 | from habitat_baselines.rl.ddppo.ddp_utils import (DEFAULT_MAIN_ADDR, 23 | DEFAULT_PORT, 24 | DEFAULT_PORT_RANGE, 25 | SLURM_JOBID, 26 | get_distrib_size, get_ifname) 27 | from omegaconf import DictConfig 28 | from torch import distributed as distrib 29 | 30 | 31 | def is_multi_node(): 32 | return int(os.environ.get("NUM_NODES", 1)) > 1 33 | 34 | 35 | def get_main_addr() -> str: 36 | if is_multi_node(): 37 | k = "MASTER_ADDR" 38 | else: 39 | k = "MAIN_ADDR" 40 | return os.environ.get(k, DEFAULT_MAIN_ADDR) 41 | 42 | 43 | def init_distrib_slurm( 44 | backend: str = "nccl", 45 | ) -> Tuple[int, torch.distributed.TCPStore]: # type: ignore 46 | assert torch.distributed.is_available(), "torch.distributed must be available" 47 | 48 | if "GLOO_SOCKET_IFNAME" not in os.environ: 49 | os.environ["GLOO_SOCKET_IFNAME"] = get_ifname() 50 | 51 | if "NCCL_SOCKET_IFNAME" not in os.environ: 52 | os.environ["NCCL_SOCKET_IFNAME"] = get_ifname() 53 | 54 | local_rank, world_rank, world_size = get_distrib_size() 55 | 56 | if is_multi_node(): 57 | k = "MASTER_PORT" 58 | else: 59 | k = "MAIN_PORT" 60 | main_port = int(os.environ.get(k, DEFAULT_PORT)) 61 | if SLURM_JOBID is not None: 62 | main_port += int(SLURM_JOBID) % int( 63 | os.environ.get("MAIN_PORT_RANGE", DEFAULT_PORT_RANGE) 64 | ) 65 | main_addr = get_main_addr() 66 | 67 | print( 68 | f"Setting up TCP store with {main_addr}, {main_port}, {world_size}, {world_rank} on backend {backend}" 69 | ) 70 | tcp_store = 
distrib.TCPStore( # type: ignore 71 | main_addr, main_port, world_size, world_rank == 0 72 | ) 73 | distrib.init_process_group( 74 | backend, store=tcp_store, rank=world_rank, world_size=world_size 75 | ) 76 | 77 | return local_rank, tcp_store 78 | -------------------------------------------------------------------------------- /llarp/trainer/custom_env_factory.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | from collections import defaultdict 7 | from typing import TYPE_CHECKING, Any, List, Optional, Type 8 | 9 | import torch 10 | from habitat import ThreadedVectorEnv, VectorEnv, logger, make_dataset 11 | from habitat.config import read_write 12 | from habitat.gym import make_gym_from_config 13 | from habitat_baselines.common.env_factory import VectorEnvFactory 14 | 15 | from llarp.dataset.episodes import LangRearrangeDatasetV0, LangRearrangeEpisode 16 | 17 | if TYPE_CHECKING: 18 | from omegaconf import DictConfig 19 | 20 | 21 | def split_scenes(scenes, n_envs): 22 | # Mapping to env to scenes allowed in this env. 23 | scene_splits: List[List[str]] = [[] for _ in range(n_envs)] 24 | n_scenes = len(scenes) 25 | 26 | # Map all the scenes into an env 27 | for scene_i, scene in enumerate(scenes): 28 | scene_splits[scene_i % n_envs].append(scene) 29 | 30 | # Handle left over envs that didn't get any scenes. 31 | for env_i in range(n_scenes, n_envs): 32 | scene_splits[env_i].append(scenes[env_i % n_scenes]) 33 | return scene_splits 34 | 35 | 36 | def habitat_create_dataset( 37 | config, 38 | is_first_rank, 39 | num_environments, 40 | force_scene_per_worker: bool, 41 | filter_down_num: Optional[int], 42 | filter_instructs: Optional[List[str]] = None, 43 | ): 44 | configs = [] 45 | dataset = make_dataset(config.habitat.dataset.type, config=config.habitat.dataset) 46 | # Limit episodes 47 | if filter_down_num is not None: 48 | dataset.episodes = dataset.episodes[:filter_down_num] 49 | 50 | scenes = dataset.scene_ids 51 | 52 | scene_splits = split_scenes(scenes, num_environments) 53 | 54 | rank = 0 55 | if torch.distributed.is_initialized(): 56 | rank = torch.distributed.get_rank() 57 | 58 | if force_scene_per_worker: 59 | for i in range(num_environments): 60 | n_scenes_in_split = len(scene_splits[i]) 61 | # Offset which scene we are taking by the rank. 62 | scene_splits[i] = [scene_splits[i][rank % n_scenes_in_split]] 63 | 64 | all_env_episodes = defaultdict(list) 65 | ignored_count = 0 66 | worker_to_ep_counts = defaultdict(lambda: defaultdict(int)) 67 | 68 | for ep in dataset.episodes: 69 | match_env_idx = None 70 | for i, scene_split in enumerate(scene_splits): 71 | if ep.scene_id in scene_split: 72 | match_env_idx = i 73 | break 74 | if match_env_idx is None: 75 | ignored_count += 1 76 | continue 77 | all_env_episodes[match_env_idx].append(ep) 78 | if isinstance(ep, LangRearrangeEpisode): 79 | worker_to_ep_counts[match_env_idx][ep.instruct_id] += 1 80 | 81 | if force_scene_per_worker: 82 | n_distinct = len(set(x for scene_split in scene_splits for x in scene_split)) 83 | print(f"Worker using {n_distinct}/{len(set(scenes))} scenes") 84 | # Ensure the objects match per scene. 
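# With a single scene pinned to each worker, every episode assigned to that
# worker should list the same rigid objects; a mismatch likely means the
# episodes were generated inconsistently, so it is raised as an error below.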
85 | for env_idx in range(num_environments): 86 | objs = None 87 | for ep in all_env_episodes[env_idx]: 88 | ep_objs = [x[0] for x in ep.rigid_objs] 89 | if objs is None: 90 | objs = ep_objs 91 | elif ep_objs != objs: 92 | raise ValueError("Objects are different within a scene!") 93 | for k, v in worker_to_ep_counts[env_idx].items(): 94 | print(f" - {env_idx}, {k}: {v}") 95 | 96 | env_datasets = [] 97 | unfound_idxs = [] 98 | found_idxs = [] 99 | for env_index in range(num_environments): 100 | print(f"Rank {rank}, Env {env_index}: {scene_splits[env_index]}") 101 | proc_config = config.copy() 102 | with read_write(proc_config): 103 | task_config = proc_config.habitat 104 | task_config.seed = task_config.seed + env_index 105 | remove_measure_names = [] 106 | if not is_first_rank: 107 | # Filter out non rank0_measure from the task config if we are not on rank0. 108 | remove_measure_names.extend(task_config.task.rank0_measure_names) 109 | if (env_index != 0) or not is_first_rank: 110 | # Filter out non-rank0_env0 measures from the task config if we 111 | # are not on rank0 env0. 112 | remove_measure_names.extend(task_config.task.rank0_env0_measure_names) 113 | 114 | task_config.task.measurements = { 115 | k: v 116 | for k, v in task_config.task.measurements.items() 117 | if k not in remove_measure_names 118 | } 119 | 120 | configs.append(proc_config) 121 | 122 | env_episodes = all_env_episodes[env_index] 123 | if len(env_episodes) == 0: 124 | assert env_index != 0 125 | env_episodes = all_env_episodes[env_index - 1] 126 | # raise ValueError("Found no initial episodes for scene") 127 | 128 | # Potentially filter episodes. 129 | if filter_instructs is not None: 130 | env_episodes = [ 131 | ep for ep in env_episodes if ep.instruct_id in filter_instructs 132 | ] 133 | if len(env_episodes) == 0: 134 | # Filtered to no episodes for this scene. Just grab the episodes from the previous worker. 135 | unfound_idxs.append(env_index) 136 | else: 137 | found_idxs.append(env_index) 138 | 139 | env_datasets.append( 140 | LangRearrangeDatasetV0( 141 | config=config.habitat.dataset, preset_eps=env_episodes 142 | ) 143 | ) 144 | 145 | for env_index in unfound_idxs: 146 | found_env_index = found_idxs[env_index % len(found_idxs)] 147 | env_datasets[env_index].episodes = env_datasets[found_env_index].episodes 148 | return env_datasets, configs 149 | 150 | 151 | class CustomVectorEnvFactory(VectorEnvFactory): 152 | def __init__(self): 153 | pass 154 | 155 | def construct_envs( 156 | self, 157 | config: "DictConfig", 158 | workers_ignore_signals: bool = False, 159 | enforce_scenes_greater_eq_environments: bool = False, 160 | is_first_rank: bool = True, 161 | ) -> VectorEnv: 162 | env_datasets, configs = habitat_create_dataset( 163 | config, 164 | is_first_rank, 165 | config.habitat_baselines.num_environments, 166 | config.habitat.task.force_scene_per_worker, 167 | config.habitat.task.filter_down_num, 168 | config.habitat.task.filter_instructs, 169 | ) 170 | 171 | if int(os.environ.get("HABITAT_ENV_DEBUG", 0)): 172 | logger.warn( 173 | "Using the debug Vector environment interface. Expect slower performance." 
174 | ) 175 | vector_env_cls = ThreadedVectorEnv 176 | else: 177 | vector_env_cls = VectorEnv 178 | 179 | envs = vector_env_cls( 180 | make_env_fn=make_gym_from_config, 181 | env_fn_args=tuple( 182 | (c, env_dataset) for c, env_dataset in zip(configs, env_datasets) 183 | ), 184 | workers_ignore_signals=workers_ignore_signals, 185 | ) 186 | 187 | if config.habitat.simulator.renderer.enable_batch_renderer: 188 | envs.initialize_batch_renderer(config) 189 | 190 | return envs 191 | -------------------------------------------------------------------------------- /llarp/trainer/env_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import List 6 | 7 | GOAL_K = "lang_goal" 8 | VIS_OBS_K = "rgb_cam" 9 | SUCCESS_K = "env_success" 10 | REWARD_K = "accum_reward" 11 | -------------------------------------------------------------------------------- /llarp/trainer/test_env_factory.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | import random 6 | from collections import defaultdict 7 | 8 | import gym.spaces as spaces 9 | import numpy as np 10 | import torch 11 | from habitat_baselines.common.env_factory import VectorEnvFactory 12 | from habitat_baselines.common.tensor_dict import TensorDict 13 | 14 | 15 | class TestBatchedEnv: 16 | def __init__(self, n_envs, n_max_steps, prob_end, rgb_name, lang_name): 17 | # self._im_shape = (226, 226, 3) 18 | self._im_shape = (336, 336, 3) 19 | self._vocab_shape = (30,) 20 | self._n_max_steps = n_max_steps 21 | self._prob_end = prob_end 22 | self._rgb_name = rgb_name 23 | self._lang_name = lang_name 24 | self.observation_space = spaces.Dict( 25 | { 26 | rgb_name: spaces.Box( 27 | shape=self._im_shape, dtype=np.uint8, low=0, high=255 28 | ), 29 | "debug_info": spaces.Box(shape=(2,), dtype=int, low=0, high=100000000), 30 | lang_name: spaces.Box( 31 | shape=self._vocab_shape, 32 | low=0, 33 | high=1, 34 | dtype=np.int64, 35 | ), 36 | } 37 | ) 38 | 39 | self.action_space = spaces.Discrete(70) 40 | 41 | self.observation_spaces = [self.observation_space for _ in range(n_envs)] 42 | self.action_spaces = [self.action_space for _ in range(n_envs)] 43 | self.orig_action_spaces = self.action_spaces 44 | self._n_envs = n_envs 45 | self._n_steps = defaultdict(lambda: 1) 46 | self._episode_rollout_id = defaultdict(lambda: random.randint(0, 9999999)) 47 | 48 | self.num_envs = self._n_envs 49 | self._obs = {} 50 | self._cur_goal = {} 51 | 52 | def reset(self): 53 | self.num_steps = 1 54 | 55 | ret = [ 56 | { 57 | self._rgb_name: torch.randint( 58 | low=0, high=255, size=self._im_shape, dtype=torch.uint8 59 | ), 60 | "debug_info": np.array( 61 | [self._n_steps[env_i], self._episode_rollout_id[env_i]], 62 | dtype=np.float32, 63 | ), 64 | self._lang_name: torch.randint( 65 | low=0, 66 | high=5_000, 67 | size=self._vocab_shape, 68 | dtype=torch.int, 69 | ), 70 | } 71 | for env_i in range(self._n_envs) 72 | ] 73 | 74 | for i in range(self._n_envs): 75 | self._cur_goal[i] = torch.randint( 76 | low=0, 77 | high=5_000, 78 | size=self._vocab_shape, 79 | dtype=torch.int, 80 | ) 81 | ret[i][self._lang_name] = self._cur_goal[i] 82 | 83 | return ret 84 | 85 | def post_step(self, obs): 86 | return obs 87 | 88 | def async_step_at(self, index_env, action): 89 | 
self._n_steps[index_env] += 1 90 | self._obs[index_env] = { 91 | self._rgb_name: torch.randint( 92 | low=0, high=255, size=self._im_shape, dtype=torch.uint8 93 | ) 94 | + action, 95 | } 96 | 97 | def wait_step_at(self, index_env): 98 | done = torch.rand(1).item() < self._prob_end 99 | if self._n_steps[index_env] >= self._n_max_steps: 100 | done = True 101 | info = {} 102 | reward = torch.randn(1).item() 103 | if done: 104 | self._n_steps[index_env] = 1 105 | self._episode_rollout_id[index_env] = random.randint(0, 9999999) 106 | self._cur_goal[index_env] = torch.randint( 107 | low=0, 108 | high=5_000, 109 | size=self._vocab_shape, 110 | dtype=torch.int, 111 | ) 112 | self._obs[index_env][self._lang_name] = self._cur_goal[index_env] 113 | 114 | self._obs[index_env]["debug_info"] = np.array( 115 | [self._n_steps[index_env], self._episode_rollout_id[index_env]], 116 | dtype=np.float32, 117 | ) 118 | 119 | return self._obs[index_env], reward, done, info 120 | 121 | def close(self): 122 | return 123 | 124 | 125 | class TestVectorEnvFactory(VectorEnvFactory): 126 | def construct_envs( 127 | self, 128 | config, 129 | workers_ignore_signals: bool = False, 130 | enforce_scenes_greater_eq_environments: bool = False, 131 | is_first_rank: bool = True, 132 | ): 133 | return TestBatchedEnv( 134 | 16, 135 | config.habitat_baselines.test_env_n_max_env_steps, 136 | config.habitat_baselines.test_env_prob_end, 137 | config.habitat_baselines.test_env_rgb_name, 138 | config.habitat_baselines.test_env_lang_name, 139 | ) 140 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="llarp", 5 | packages=setuptools.find_packages(), 6 | version="0.1", 7 | install_requires=[ 8 | "torch==2.0.1", 9 | "transformers==4.31.0", 10 | "einops==0.7.0", 11 | "gym==0.23.0", 12 | "wandb==0.13.1", 13 | "flamingo-pytorch==0.1.2", 14 | "peft==0.4.0", 15 | "sentencepiece", 16 | ], 17 | ) 18 | -------------------------------------------------------------------------------- /test/test_llarp.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 4 | # 5 | """ 6 | Tests to ensure that LLaRP is training correctly. 
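Each test runs a short training loop and reduces the resulting policy weights
to a single checksum (the sum of every parameter tensor), which is compared
against a stored reference value so that regressions in the training path
surface as a hash mismatch.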
7 | """ 8 | 9 | import glob 10 | import os 11 | import random 12 | 13 | import hydra 14 | import numpy as np 15 | import pytest 16 | import torch 17 | from habitat.config import read_write 18 | from habitat.config.default import get_agent_config 19 | from habitat_baselines.common.baseline_registry import baseline_registry 20 | 21 | import llarp.dataset 22 | import llarp.policies 23 | import llarp.task 24 | import llarp.trainer 25 | from llarp.task.predicate_task import RearrangePredicateTask 26 | from llarp.trainer.test_env_factory import TestBatchedEnv 27 | 28 | 29 | def pre_config_setup(): 30 | os.environ["HABITAT_ENV_DEBUG"] = "1" 31 | 32 | hydra.core.global_hydra.GlobalHydra.instance().clear() 33 | hydra.initialize(config_path="../llarp/config/", version_base=None) 34 | 35 | # Remove the checkpoints from previous tests 36 | for f in glob.glob("data/test_checkpoints/test_training/*"): 37 | os.remove(f) 38 | 39 | 40 | def post_config_setup(config): 41 | random.seed(config.habitat.seed) 42 | np.random.seed(config.habitat.seed) 43 | torch.manual_seed(config.habitat.seed) 44 | torch.cuda.manual_seed(config.habitat.seed) 45 | torch.backends.cudnn.deterministic = True 46 | if ( 47 | config.habitat_baselines.force_torch_single_threaded 48 | and torch.cuda.is_available() 49 | ): 50 | torch.set_num_threads(1) 51 | 52 | 53 | @pytest.mark.parametrize( 54 | "num_updates,expected_hash,add_opts", 55 | [ 56 | ( 57 | 1, 58 | 69515.362529343, 59 | [ 60 | "habitat.task.measurements.predicate_task_success.sanity_end_task=True", 61 | ], 62 | ), 63 | ( 64 | 5, 65 | 69498.58239580243, 66 | [ 67 | "habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.debug_mode=True", 68 | "habitat.task.measurements.predicate_task_success.sanity_end_task=True", 69 | "habitat_baselines.rl.ppo.ppo_epoch=1", 70 | ], 71 | ), 72 | # Where the context length is greater than the rollout length. 73 | ( 74 | 10, 75 | 69479.58418758967, 76 | [ 77 | "habitat_baselines.rl.ppo.num_steps=5", 78 | "habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.debug_mode=True", 79 | "habitat_baselines.rl.ppo.ppo_epoch=1", 80 | ], 81 | ), 82 | # Extended training. 
83 | ( 84 | 25, 85 | 69448.26612490416, 86 | [ 87 | "habitat.task.measurements.predicate_task_success.sanity_end_task=True", 88 | ], 89 | ), 90 | ], 91 | ) 92 | def test_model_training(num_updates, expected_hash, add_opts): 93 | pre_config_setup() 94 | # Setup the training 95 | config = hydra.compose( 96 | "baseline/llarp.yaml", 97 | [ 98 | f"habitat_baselines.num_updates={num_updates}", 99 | "habitat_baselines.num_checkpoints=-1", 100 | "habitat_baselines.checkpoint_interval=100000000", 101 | "habitat_baselines.total_num_steps=-1", 102 | "habitat_baselines.checkpoint_folder=data/test_checkpoints/test_training", 103 | "habitat.task.force_scene_per_worker=True", 104 | "habitat_baselines.num_environments=1", 105 | "habitat_baselines.writer_type=tb", 106 | "habitat_baselines.log_interval=1", 107 | "habitat_baselines.rl.ppo.num_mini_batch=1", 108 | "habitat.dataset.data_path=datasets/train_validation.pickle", 109 | "habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.llm_wrapper.llm_id=''", 110 | "habitat.simulator.agents.main_agent.joint_start_override=null", 111 | *add_opts, 112 | ], 113 | ) 114 | 115 | with read_write(config): 116 | agent_config = get_agent_config(config.habitat.simulator) 117 | # Changing the visual observation size for speed 118 | for sim_sensor_config in agent_config.sim_sensors.values(): 119 | sim_sensor_config.update({"height": 64, "width": 64}) 120 | 121 | post_config_setup(config) 122 | 123 | trainer_init = baseline_registry.get_trainer(config.habitat_baselines.trainer_name) 124 | trainer = trainer_init(config) 125 | 126 | # Train 127 | trainer.train() 128 | train_hash = sum( 129 | v.sum().item() for v in trainer._agent.actor_critic.state_dict().values() 130 | ) 131 | print(f"Got train hash {train_hash}") 132 | 133 | # With the new context window. 134 | assert train_hash == expected_hash 135 | 136 | 137 | @pytest.mark.parametrize( 138 | "cfg_name,num_updates,expected_hash,add_opts,num_mini_batches", 139 | [ 140 | ("baseline/llarp.yaml", 2, 69518.70178928264, [], 1), 141 | # Longer training. 142 | ("baseline/llarp.yaml", 10, 69522.72576177015, [], 1), 143 | # Longer training with minibatches. 
144 | ("baseline/llarp.yaml", 10, 69478.0996894521, [], 4), 145 | ], 146 | ) 147 | def test_model_training_batched( 148 | cfg_name, num_updates, expected_hash, add_opts, num_mini_batches 149 | ): 150 | pre_config_setup() 151 | # Setup the training 152 | config = hydra.compose( 153 | cfg_name, 154 | [ 155 | f"habitat_baselines.num_updates={num_updates}", 156 | "habitat_baselines.num_checkpoints=-1", 157 | "habitat_baselines.checkpoint_interval=100000000", 158 | "habitat_baselines.total_num_steps=-1", 159 | "habitat_baselines.checkpoint_folder=data/test_checkpoints/test_training", 160 | "habitat.task.force_scene_per_worker=True", 161 | "habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.debug_mode=True", 162 | "habitat_baselines.num_environments=1", 163 | "habitat_baselines.writer_type=tb", 164 | "habitat_baselines.log_interval=1", 165 | f"habitat_baselines.rl.ppo.num_mini_batch={num_mini_batches}", 166 | "habitat.dataset.data_path=datasets/train_validation.pickle", 167 | "habitat_baselines.rl.policy.main_agent.hierarchical_policy.high_level_policy.llm_wrapper.llm_id=''", 168 | "habitat_baselines.vector_env_factory._target_='llarp.trainer.test_env_factory.TestVectorEnvFactory'", 169 | *add_opts, 170 | ], 171 | ) 172 | 173 | post_config_setup(config) 174 | 175 | trainer_init = baseline_registry.get_trainer(config.habitat_baselines.trainer_name) 176 | trainer = trainer_init(config) 177 | 178 | # Train 179 | trainer.train() 180 | train_hash = sum( 181 | v.sum().item() for v in trainer._agent.actor_critic.state_dict().values() 182 | ) 183 | print(f"Got train hash {train_hash}") 184 | 185 | # With the new context window. 186 | assert train_hash == expected_hash 187 | --------------------------------------------------------------------------------