├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── job_scripts
│   ├── eval_test_multi10.slurm
│   ├── eval_test_multi74.slurm
│   ├── eval_test_peract.slurm
│   ├── eval_val_multi10.slurm
│   ├── eval_val_multi74.slurm
│   ├── eval_val_peract.slurm
│   ├── eval_val_singletask.slurm
│   ├── generate_data.slurm
│   ├── generate_instructions.slurm
│   ├── preprocess_data.slurm
│   ├── train_multitask_bc_10tasks.slurm
│   ├── train_multitask_bc_74tasks.slurm
│   ├── train_multitask_bc_peract.slurm
│   └── train_singletask.slurm
├── overview.png
├── polarnet
│   ├── __init__.py
│   ├── assets
│   │   ├── 10_tasks.csv
│   │   ├── 10_tasks.json
│   │   ├── 10_tasks_var.csv
│   │   ├── 74_tasks.csv
│   │   ├── 74_tasks_per_category.json
│   │   ├── 74_tasks_var.csv
│   │   ├── all_tasks.json
│   │   ├── peract_tasks.csv
│   │   ├── peract_tasks.json
│   │   ├── peract_tasks_var.csv
│   │   ├── tasks_use_table_surface.json
│   │   ├── tasks_with_color.json
│   │   └── taskvar_instructions.jsonl
│   ├── config
│   │   ├── 10tasks.yaml
│   │   ├── 74tasks.yaml
│   │   ├── constants.py
│   │   ├── default.py
│   │   ├── peract.yaml
│   │   └── single_task.yaml
│   ├── core
│   │   ├── actioner.py
│   │   └── environments.py
│   ├── dataloaders
│   │   ├── __init__.py
│   │   ├── keystep_dataset.py
│   │   ├── loader.py
│   │   └── pcd_keystep_dataset.py
│   ├── eval_models.py
│   ├── eval_tst_split.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── network_utils.py
│   │   └── pcd_unet.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adamw.py
│   │   ├── lookahead.py
│   │   ├── misc.py
│   │   ├── radam.py
│   │   ├── ralamb.py
│   │   ├── rangerlars.py
│   │   └── sched.py
│   ├── preprocess
│   │   ├── evaluate_dataset_keysteps.py
│   │   ├── generate_dataset_keysteps.py
│   │   ├── generate_dataset_microsteps.py
│   │   ├── generate_instructions.py
│   │   ├── generate_pcd_dataset_keysteps.py
│   │   └── generate_real_instructions.py
│   ├── summarize_74tasks_tst_results_by_groups.py
│   ├── summarize_peract_official_tst_results.py
│   ├── summarize_tst_results.py
│   ├── summarize_val_results.py
│   ├── train_models.py
│   └── utils
│       ├── __init__.py
│       ├── coord_transforms.py
│       ├── distributed.py
│       ├── keystep_detection.py
│       ├── logger.py
│       ├── misc.py
│       ├── ops.py
│       ├── recorder.py
│       ├── save.py
│       ├── slurm_requeue.py
│       ├── utils.py
│       └── visualize.py
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | *.pyc 4 | __pycache__/ 5 | core-python-* 6 | bug-report-* 7 | 8 | data 9 | notebooks 10 | slurm_logs 11 | 12 | *.lock 13 | .~lock.* 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "openpoints"] 2 | path = openpoints 3 | url = https://github.com/guochengqian/openpoints/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shizhe Chen & Ricardo Garcia-Pinel 4 | Copyright (c) INRIA 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PolarNet: 3D Point Clouds for Language-Guided Robotic Manipulation 2 | 3 | ![Figure 2 from paper](./overview.png) 4 | 5 | > [PolarNet: 3D Point Clouds for Language-Guided Robotic Manipulation](https://openreview.net/forum?id=efaE7iJ2GJv) 6 | > Shizhe Chen*, Ricardo Garcia*, Cordelia Schmid and Ivan Laptev 7 | > **CoRL 2023** 8 | 9 | *Equal Contribution 10 | 11 | ## Prerequisite 12 | 13 | 1. Installation 14 | 15 | Option 1: Use our [pre-built singularity image](https://drive.google.com/file/d/1i1mNppGWhzEgrQz9G7wP7fb-0iA4mGq4/view?usp=sharing). 16 | 17 | Option 2: Install everything from scratch. 18 | ```bash 19 | conda create --name polarnet python=3.9 20 | conda activate polarnet 21 | ``` 22 | 23 | See the instructions in [PyRep](https://github.com/stepjam/PyRep) and [RLBench](https://github.com/stepjam/RLBench) to install the RLBench simulator (with VirtualGL on headless machines). Use our modified version of [RLBench](https://github.com/rjgpinel/RLBench) to support additional tasks. 24 | 25 | Install python packages: 26 | ```bash 27 | conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia 28 | conda install -c huggingface transformers 29 | conda install scipy tqdm 30 | 31 | pip install typed-argument-parser lmdb msgpack-numpy tensorboardX 32 | pip install multimethod shortuuid termcolor easydict 33 | pip install yacs jsonlines einops 34 | ``` 35 | 36 | Install Openpoints: 37 | 38 | ```bash 39 | 40 | git submodule update --init 41 | 42 | # install cuda 43 | sudo apt-get remove --purge '^nvidia-.*' 44 | sudo apt autoremove 45 | 46 | # blacklist nouveau: https://blog.csdn.net/threestooegs/article/details/124582963 47 | wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 48 | sudo sh cuda_11.7.0_515.43.04_linux.run 49 | # If you are a user of the Jean-Zay cluster, run `module load cuda/11.7.1 && module load cudnn/8.5.0.96-11.7-cuda` instead 50 | 51 | pip install open3d==0.16.0 torch-scatter 52 | conda install pytorch-scatter -c pyg 53 | 54 | cd openpoints/cpp/pointnet2_batch 55 | python setup.py install --user 56 | cd ../ 57 | 58 | cd subsampling 59 | python setup.py build_ext --inplace 60 | cd .. 61 | 62 | cd pointops/ 63 | python setup.py install --user 64 | cd .. 65 | 66 | cd chamfer_dist 67 | python setup.py install --user 68 | cd ../emd 69 | python setup.py install --user 70 | cd ../../../ 71 | ``` 72 | 73 | Finally, install polarnet using: 74 | ```bash 75 | pip install -e . 76 | ``` 77 | 
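After the installation, a quick sanity check can catch a broken CUDA or openpoints build before you launch a long SLURM job. The snippet below is a minimal sketch (not a script from this repository) that verifies the key imports:

```python
# Minimal post-install sanity check (illustrative; not part of the repo).
import torch

print(torch.__version__)          # expect 1.13.0
print(torch.cuda.is_available())  # expect True on a GPU machine

# openpoints should import cleanly once its CUDA extensions are built;
# an ImportError here usually means one of the setup.py steps above failed.
import openpoints  # noqa: F401
```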
78 | 2. Dataset Generation 79 | 80 | Option 1: Use our [pre-generated datasets](https://drive.google.com/drive/folders/1WvaopPRbQYDkIf5V_bFuetwISsu2ez9E?usp=drive_link) including the keystep trajectories and instruction embeddings for the three setups studied in our paper, as well as data for the 7 real robot tasks (17 variations). Using these datasets will also help reproducibility. 81 | 82 | We recommend downloading the data using [rclone](https://rclone.org/drive/). 83 | 84 | Option 2: generate the dataset on your own. 85 | ```bash 86 | conda activate polarnet 87 | seed=0 88 | task=put_knife_on_chopping_board 89 | variation=0 90 | variation_count=1 91 | 92 | cd ~/Code/polarnet/ 93 | 94 | # 1. generate microstep demonstrations 95 | python -m polarnet.preprocess.generate_dataset_microsteps \ 96 | --save_path data/train_dataset/microsteps/seed${seed} \ 97 | --all_task_file polarnet/assets/all_tasks.json \ 98 | --image_size 128,128 --renderer opengl \ 99 | --episodes_per_task 100 \ 100 | --tasks ${task} --variations ${variation_count} --offset ${variation} \ 101 | --processes 1 --seed ${seed} 102 | 103 | # 2. generate keystep demonstrations 104 | python -m polarnet.preprocess.generate_dataset_keysteps \ 105 | --microstep_data_dir data/train_dataset/microsteps/seed${seed} \ 106 | --keystep_data_dir data/train_dataset/keysteps/seed${seed} \ 107 | --tasks ${task} 108 | 109 | # 3. (optional) check the correctness of generated keysteps 110 | python -m polarnet.preprocess.evaluate_dataset_keysteps \ 111 | --microstep_data_dir data/train_dataset/microsteps/seed${seed} \ 112 | --keystep_data_dir data/train_dataset/keysteps/seed${seed} \ 113 | --tasks ${task} 114 | 115 | # 4. generate instruction embeddings for the tasks 116 | python -m polarnet.preprocess.generate_instructions \ 117 | --encoder clip \ 118 | --output_file data/taskvar_instrs/clip 119 | 120 | # 5. generate preprocessed keystep demonstrations 121 | python -m polarnet.preprocess.generate_pcd_dataset_keysteps \ 122 | --seed ${seed} \ 123 | --num_cameras 3 \ 124 | --dataset_dir data/train_dataset/ \ 125 | --outname keysteps_pcd 126 | ``` 127 | For slurm users, please check the scripts in `job_scripts`. 128 | 129 | ## Train 130 | 131 | Our code supports distributed training with multiple GPUs on SLURM clusters. 132 | 133 | For slurm users, please use the following command to launch the training script. 134 | ```bash 135 | sbatch job_scripts/train_multitask_bc_10tasks.slurm 136 | ``` 137 | 138 | For non-slurm users, please manually set the environment variables as follows. 139 | 140 | ```bash 141 | export WORLD_SIZE=1 142 | export MASTER_ADDR='localhost' 143 | export MASTER_PORT=10000 144 | 145 | export LOCAL_RANK=0 146 | export RANK=0 147 | export CUDA_VISIBLE_DEVICES=0 148 | 149 | python -m polarnet.train_models --exp-config polarnet/config/10tasks.yaml 150 | ``` 151 | 152 | You can find the PointNeXt pre-trained weights [here](https://drive.google.com/file/d/13qq4QPIlvJF4BwC7zEqG8vj7jcP04DBL/view?usp=drive_link). 153 | 
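The training job scripts pass these weights through the `checkpoint` option together with `checkpoint_strict_load False`, so only the parameters whose names and shapes match the PointNeXt checkpoint are initialized from it and the rest are trained from scratch. Below is a minimal sketch of that non-strict loading pattern in plain PyTorch (illustrative only; `model` stands for the instantiated polarnet policy, and the checkpoint layout is an assumption):

```python
# Non-strict checkpoint loading sketch (illustrative, not the repo's exact code).
import torch

ckpt = torch.load(
    "data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt",
    map_location="cpu",
)
# Some checkpoints nest the weights under a key such as "model"; unwrap if so.
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt

# strict=False keeps training going even when only the encoder matches.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"{len(missing)} params left at random init, "
      f"{len(unexpected)} checkpoint params unused")
```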
154 | ## Evaluation 155 | 156 | For slurm users, please use the following command to launch the evaluation script. 157 | ```bash 158 | sbatch job_scripts/eval_test_multi10.slurm 159 | ``` 160 | 161 | For non-slurm users, run the following commands to evaluate the trained model. 162 | 163 | ```bash 164 | # set outdir to the directory of your trained model 165 | export DISPLAY=:0.0 # on headless machines 166 | 167 | # validation: select the best checkpoint 168 | for step in {50000..200000..10000} 169 | do 170 | python -m polarnet.eval_models \ 171 | --exp_config ${outdir}/logs/training_config.yaml \ 172 | --seed 100 --num_demos 20 \ 173 | checkpoint ${outdir}/ckpts/model_step_${step}.pt 174 | done 175 | 176 | # run the script to summarize the validation results 177 | python -m polarnet.summarize_val_results --result_file ${outdir}/preds/seed100/results.jsonl 178 | 179 | # test: use a different seed from validation 180 | step=300000 181 | python -m polarnet.eval_models \ 182 | --exp_config ${outdir}/logs/training_config.yaml \ 183 | --seed 200 --num_demos 500 \ 184 | checkpoint ${outdir}/ckpts/model_step_${step}.pt 185 | 186 | # run the script to summarize the testing results 187 | python -m polarnet.summarize_tst_results --result_file ${outdir}/preds/seed200/results.jsonl 188 | ``` 189 | 190 | In the same manner, you can use `summarize_peract_official_tst_results.py` and `summarize_74tasks_tst_results_by_groups.py` to summarize the results of the PerAct and 74-task setups. 191 | 
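Conceptually, the summarize scripts average per-episode successes by checkpoint (and task) so that the best step can be picked on validation. A toy sketch of that aggregation is shown below, assuming for illustration that each line of `results.jsonl` carries `checkpoint`, `task` and `sr` fields (the actual schema may differ; use the provided scripts in practice):

```python
# Toy results.jsonl aggregation (illustrative; the field names are assumptions).
import collections
import json

per_ckpt = collections.defaultdict(list)
with open("preds/seed100/results.jsonl") as f:
    for line in f:
        rec = json.loads(line)
        per_ckpt[rec["checkpoint"]].append(rec["sr"])  # per-task success rate

# Pick the checkpoint with the highest mean success rate over tasks.
best = max(per_ckpt, key=lambda c: sum(per_ckpt[c]) / len(per_ckpt[c]))
print("best checkpoint:", best)
```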
192 | ## Pre-trained models 193 | 194 | You can find a checkpoint for the 10-task multi-task setup [here](https://drive.google.com/drive/folders/17bqVpiNyxkXOFzHWqsFEq-Z-VlmARVFp?usp=drive_link). 195 | 196 | | | pick_and_lift | pick_up_cup | put_knife_on_chopping_board | put_money_in_safe | push_button | reach_target | slide_block_to_target | stack_wine | take_money_out_safe | take_umbrella_out_of_umbrella_stand | Avg. | 197 | |:------:|:--------------:|:------------:|:----------------------------:|:------------------:|:------------:|:-------------:|:----------------------:|:-----------:|:--------------------:|:------------------------------------:|:-----:| 198 | | seed=0 | 95.40 | 83.80 | 86.00 | 85.40 | 98.80 | 100.00 | 93.20 | 80.20 | 71.40 | 89.80 | 89.20 | 199 | 200 | Other models coming soon... 201 | 202 | ## BibTex 203 | 204 | ``` 205 | @inproceedings{chen23polarnet, 206 | author = {Shizhe Chen and Ricardo Garcia and Cordelia Schmid and Ivan Laptev}, 207 | title = {PolarNet: 3D Point Clouds for Language-Guided Robotic Manipulation}, 208 | booktitle = {Conference on Robot Learning (CoRL)}, 209 | year = {2023} 210 | } 211 | ``` 212 | 213 | ## Acknowledgements 214 | 215 | The PointNeXt code comes from the [openpoints](https://github.com/guochengqian/openpoints) library. 216 | 217 | 218 | -------------------------------------------------------------------------------- /job_scripts/eval_test_multi10.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-9 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date; 26 | 27 | task_offset=${task_offset:-1} 28 | seed_default=${seed_default:-0} 29 | task=${task:-water_plants} 30 | 31 | code_dir=$WORK/Code/polarnet/ 32 | 33 | task_file=${task_file:-$code_dir/polarnet/assets/10_tasks_var.csv} 34 | 35 | instr_num=6 36 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 37 | num_tasks=$(wc -l < $task_file) 38 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 39 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 40 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 41 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 42 | seed=${seed:-$seed_default} 43 | else 44 | seed=${seed:-$seed_default} 45 | fi 46 | 47 | log_dir=$WORK/logs/ 48 | 49 | mkdir -p $log_dir 50 | 51 | module load singularity 52 | 53 | . $WORK/miniconda3/etc/profile.d/conda.sh 54 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 55 | conda activate polarnet 56 | 57 | export PYTHONPATH="$PYTHONPATH:$code_dir" 58 | 59 | models_dir=exprs/10tasks-multi-model/ 60 | instr_embed_file=data/taskvar_instrs/clip/ 61 | 62 | step=90000 63 | 64 | pushd $code_dir/polarnet 65 | # evaluation on the test split 66 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 67 | singularity exec \ 68 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 69 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 70 | xvfb-run -a python eval_tst_split.py \ 71 | --exp_config ${models_dir}/logs/training_config.yaml \ 72 | --seed 200 --num_demos 500 \ 73 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 74 | --taskvars ${task} \ 75 | --num_workers 1 --instr_embed_file $instr_embed_file 76 | popd 77 | -------------------------------------------------------------------------------- /job_scripts/eval_test_multi74.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-73 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date; 26 | 27 | task_offset=${task_offset:-1} 28 | seed_default=${seed_default:-0} 29 | task=${task:-water_plants} 30 | 31 | code_dir=$WORK/Code/polarnet/ 32 | 
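# Mapping from the SLURM array index to (task, seed), as computed below: with N
# lines in $task_file, array index i selects line (i % N) + task_offset of the
# file and uses seed i / N. E.g. with N=74 and task_offset=1, index 75 evaluates
# the task on line 2 with seed 1, while an array of 0-73 makes a single pass
# over all 74 task variations with seed 0.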
33 | task_file=${task_file:-$code_dir/polarnet/assets/74_tasks_var.csv} 34 | 35 | instr_num=6 36 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 37 | num_tasks=$(wc -l < $task_file) 38 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 39 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 40 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 41 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 42 | seed=${seed:-$seed_default} 43 | else 44 | seed=${seed:-$seed_default} 45 | fi 46 | 47 | log_dir=$WORK/logs/ 48 | 49 | mkdir -p $log_dir 50 | 51 | module load singularity 52 | 53 | . $WORK/miniconda3/etc/profile.d/conda.sh 54 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 55 | conda activate polarnet 56 | 57 | export PYTHONPATH="$PYTHONPATH:$code_dir" 58 | 59 | models_dir=exprs/74tasks-multi-model/ 60 | instr_embed_file=data/taskvar_instrs/clip/ 61 | 62 | step=980000 63 | 64 | pushd $code_dir/polarnet 65 | # evaluation on the test split 66 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 67 | singularity exec \ 68 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 69 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 70 | xvfb-run -a python eval_tst_split.py \ 71 | --exp_config ${models_dir}/logs/training_config.yaml \ 72 | --seed 200 --num_demos 500 \ 73 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 74 | --taskvars ${task} \ 75 | --num_workers 1 --instr_embed_file $instr_embed_file 76 | popd 77 | -------------------------------------------------------------------------------- /job_scripts/eval_test_peract.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-248 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date 26 | 27 | task_offset=${task_offset:-1} 28 | seed_default=${seed_default:-0} 29 | task=${task:-reach_and_drag} 30 | 31 | code_dir=$WORK/Code/polarnet/ 32 | 33 | task_file=${task_file:-$code_dir/polarnet/assets/peract_tasks_var.csv} 34 | 35 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 36 | num_tasks=$(wc -l < $task_file) 37 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 38 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 39 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 40 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 41 | seed=${seed:-$seed_default} 42 | else 43 | seed=${seed:-$seed_default} 44 | fi 45 | 46 | log_dir=$WORK/logs/ 47 | 48 | mkdir -p $log_dir 49 | 50 | module load singularity 51 | 52 | . $WORK/miniconda3/etc/profile.d/conda.sh
53 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 54 | conda activate polarnet 55 | 56 | export PYTHONPATH="$PYTHONPATH:$code_dir" 57 | 58 | models_dir=exprs/peract-multi-model/ 59 | instr_embed_file=data/taskvar_instrs/clip/ 60 | microstep_data_dir=data/peract_data/test/microsteps/ 61 | 62 | step=${step:-590000} 63 | 64 | pushd $code_dir/polarnet 65 | # evaluation on the test split 66 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 67 | singularity exec \ 68 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 69 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 70 | xvfb-run -a python eval_tst_split.py \ 71 | --exp_config ${models_dir}/logs/training_config.yaml \ 72 | --seed 200 --num_demos 500 \ 73 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 74 | --taskvars $taskvar \ 75 | --microstep_data_dir $microstep_data_dir --microstep_outname microsteps_test_video --num_workers 1 --instr_embed_file $instr_embed_file --record_video 76 | popd 77 | -------------------------------------------------------------------------------- /job_scripts/eval_val_multi10.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-9 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date; 26 | 27 | 28 | task_offset=${task_offset:-1} 29 | seed_default=${seed_default:-0} 30 | task=${task:-pick_and_lift} 31 | 32 | code_dir=$WORK/Code/polarnet/ 33 | 34 | task_file=${task_file:-$code_dir/polarnet/assets/10_tasks_var.csv} 35 | 36 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 37 | num_tasks=$(wc -l < $task_file) 38 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 39 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 40 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 41 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 42 | seed=${seed:-$seed_default} 43 | else 44 | seed=${seed:-$seed_default} 45 | fi 46 | 47 | log_dir=$WORK/logs/ 48 | 49 | mkdir -p $log_dir 50 | 51 | module load singularity 52 | 53 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 54 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 55 | conda activate polarnet 56 | 57 | export PYTHONPATH="$PYTHONPATH:$code_dir" 58 | 59 | models_dir=exprs/10tasks-multi-model/ 60 | instr_embed_file=data/taskvar_instrs/clip/ 61 | 62 | init_step=${init_step:-50000} 63 | max_step=${max_step:-200000} 64 | step_jump=10000 65 | 66 | pushd $code_dir/polarnet 67 | # validation: select the best epoch 68 | for step in $( eval echo {${init_step}..${max_step}..${step_jump}} ) 69 | do 70 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 71 | singularity exec \ 72 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 73 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 74 | xvfb-run -a python eval_tst_split.py \ 75 | --exp_config ${models_dir}/logs/training_config.yaml \ 76 | --seed 100 --num_demos 20 \ 77 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 78 | --taskvars ${task} \ 79 | --num_workers 1 --instr_embed_file $instr_embed_file 80 | done 81 | popd 82 | -------------------------------------------------------------------------------- /job_scripts/eval_val_multi74.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | ##SBATCH --array=0-73 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date; 26 | 27 | task_offset=${task_offset:-1} 28 | seed_default=${seed_default:-0} 29 | task=${task:-water_plants} 30 | 31 | code_dir=$WORK/Code/polarnet/ 32 | 33 | task_file=${task_file:-$code_dir/polarnet/assets/74_tasks_var.csv} 34 | 35 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 36 | num_tasks=$(wc -l < $task_file) 37 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 38 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 39 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 40 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 41 | seed=${seed:-$seed_default} 42 | else 43 | seed=${seed:-$seed_default} 44 | fi 45 | 46 | log_dir=$WORK/logs/ 47 | 48 | mkdir -p $log_dir 49 | 50 | module load singularity 51 | 52 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 53 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 54 | conda activate polarnet 55 | 56 | export PYTHONPATH="$PYTHONPATH:$code_dir" 57 | 58 | models_dir=exprs/74tasks-multi-model/ 59 | instr_embed_file=data/taskvar_instrs/clip/ 60 | 61 | init_step=${init_step:-1000000} 62 | max_step=${max_step:-1000000} 63 | step_jump=20000 64 | 65 | pushd $code_dir/polarnet 66 | # validation: select the best epoch 67 | for step in $( eval echo {${init_step}..${max_step}..${step_jump}} ) 68 | do 69 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 70 | singularity exec \ 71 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 72 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 73 | xvfb-run -a python eval_tst_split.py \ 74 | --exp_config ${models_dir}/logs/training_config.yaml \ 75 | --seed 100 --num_demos 20 \ 76 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 77 | --taskvars ${task} \ 78 | --num_workers 1 --instr_embed_file $instr_embed_file 79 | done 80 | popd 81 | -------------------------------------------------------------------------------- /job_scripts/eval_val_peract.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=polarnet_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH -C v100-32g 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-248 14 | 15 | set -x 16 | set -e 17 | 18 | cd ${SLURM_SUBMIT_DIR} 19 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 20 | mkdir $XDG_RUNTIME_DIR 21 | chmod 700 $XDG_RUNTIME_DIR 22 | export PYTHONPATH=/opt/YARR/ 23 | 24 | module purge 25 | pwd; hostname; date 26 | 27 | task_offset=${task_offset:-1} 28 | seed_default=${seed_default:-0} 29 | task=${task:-reach_and_drag} 30 | 31 | code_dir=$WORK/Code/polarnet/ 32 | 33 | task_file=${task_file:-$code_dir/polarnet/assets/peract_tasks_var.csv} 34 | 35 | if [ ! -z $SLURM_ARRAY_TASK_ID ]; then 36 | num_tasks=$(wc -l < $task_file) 37 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 38 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 39 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 40 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 41 | seed=${seed:-$seed_default} 42 | else 43 | seed=${seed:-$seed_default} 44 | fi 45 | 46 | log_dir=$WORK/logs/ 47 | 48 | mkdir -p $log_dir 49 | 50 | module load singularity 51 | 52 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 53 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 54 | conda activate polarnet 55 | 56 | export PYTHONPATH="$PYTHONPATH:$code_dir" 57 | 58 | models_dir=exprs/peract-multi-model/ 59 | instr_embed_file=data/taskvar_instrs/clip/ 60 | microstep_data_dir=data/peract_data/val/microsteps/ 61 | 62 | init_step=${init_step:-50000} 63 | max_step=${max_step:-600000} 64 | step_jump=10000 65 | 66 | pushd $code_dir/polarnet 67 | # validation: select the best epoch 68 | for step in $( eval echo {${init_step}..${max_step}..${step_jump}} ) 69 | do 70 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 71 | singularity exec \ 72 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 73 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 74 | xvfb-run -a python eval_tst_split.py \ 75 | --exp_config ${models_dir}/logs/training_config.yaml \ 76 | --seed 100 --num_demos 20 \ 77 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 78 | --taskvars $taskvar \ 79 | --microstep_data_dir $microstep_data_dir --microstep_outname microsteps_val_bis --num_workers 1 --instr_embed_file $instr_embed_file 80 | done 81 | popd 82 | -------------------------------------------------------------------------------- /job_scripts/eval_val_singletask.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=vlc_eval_val 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH --output=logs/%j.out 11 | #SBATCH --error=logs/%j.out 12 | 13 | set -x 14 | set -e 15 | 16 | cd ${SLURM_SUBMIT_DIR} 17 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 18 | mkdir $XDG_RUNTIME_DIR 19 | chmod 700 $XDG_RUNTIME_DIR 20 | export PYTHONPATH=/opt/YARR/ 21 | 22 | module purge 23 | pwd; hostname; date; 24 | 25 | seed=${seed:-0} 26 | taskvar=${task:-pick_and_lift+0} 27 | 28 | code_dir=$WORK/Code/polarnet/ 29 | 30 | log_dir=$WORK/logs/ 31 | 32 | mkdir -p $log_dir 33 | 34 | module load singularity 35 | 36 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 37 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 38 | conda activate polarnet 39 | 40 | export PYTHONPATH="$PYTHONPATH:$code_dir" 41 | 42 | models_dir=exprs/${taskvar}_model/seed${seed} 43 | instr_embed_file=data/10tasks_data/taskvar_instrs/clip/ 44 | 45 | init_step=${init_step:-50000} 46 | max_step=${max_step:-100000} 47 | step_jump=10000 48 | 49 | pushd $code_dir/polarnet 50 | # validation: select the best epoch 51 | for step in $( eval echo {${init_step}..${max_step}..${step_jump}} ) 52 | do 53 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 54 | singularity exec \ 55 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 56 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 57 | xvfb-run -a python eval_tst_split.py \ 58 | --exp_config ${models_dir}/logs/training_config.yaml \ 59 | --seed 100 --num_demos 20 \ 60 | --checkpoint ${models_dir}/ckpts/model_step_${step}.pt \ 61 | --taskvars ${taskvar} \ 62 | --num_workers 1 --instr_embed_file $instr_embed_file 63 | done 64 | popd 65 | -------------------------------------------------------------------------------- /job_scripts/generate_data.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=generate_data 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | ##SBATCH --qos=qos_gpu-dev 8 | #SBATCH --qos=qos_gpu-t3 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=20:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | #SBATCH --array=0-9 14 | 15 | # This script generates fresh samples 16 | # OOM errors with CPU 17 | 18 | # go into the submission directory 19 | set -x 20 | set -e 21 | 22 | cd ${SLURM_SUBMIT_DIR} 23 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 24 | mkdir $XDG_RUNTIME_DIR 25 | chmod 700 $XDG_RUNTIME_DIR 26 | export PYTHONPATH=/opt/YARR/ 27 | num_episodes=${num_episodes:-100} 28 | task_offset=${task_offset:-1} 29 | seed_default=${seed_default:-0} 30 | task=${task:-pick_and_lift} 31 | 32 | code_dir=$WORK/Code/polarnet/ 33 | task_file=${task_file:-$code_dir/polarnet/assets/'10_tasks.csv'} 34 | 35 | if [ ! 
-z $SLURM_ARRAY_TASK_ID ]; then 36 | num_tasks=$(wc -l < $task_file) 37 | task_id=$(( (${SLURM_ARRAY_TASK_ID} % $num_tasks) + $task_offset )) 38 | taskvar=$(sed -n "${task_id},${task_id}p" $task_file) 39 | task=$(echo $taskvar | awk -F ',' '{ print $1 }') 40 | seed_default=$(( ${SLURM_ARRAY_TASK_ID} / $num_tasks )) 41 | seed=${seed:-$seed_default} 42 | else 43 | seed=${seed:-$seed_default} 44 | fi 45 | 46 | log_dir=$WORK/logs 47 | data_dir=data/10tasks_data/ 48 | 49 | mkdir -p $data_dir 50 | mkdir -p $log_dir 51 | 52 | module load singularity 53 | 54 | img_size=128 55 | 56 | image=polarnet.sif 57 | 58 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 59 | export PYTHONPATH="$PYTHONPATH:$code_dir" 60 | 61 | pushd $code_dir/polarnet 62 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 63 | singularity exec \ 64 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 65 | $SINGULARITY_ALLOWED_DIR/${image} \ 66 | xvfb-run -a -e $log_dir/$SLURM_JOBID.err \ 67 | /usr/bin/python3.9 preprocess/generate_dataset_microsteps.py \ 68 | --save_path $data_dir/train/microsteps/seed${seed} \ 69 | --all_task_file assets/all_tasks.json \ 70 | --image_size $img_size,$img_size --renderer opengl \ 71 | --episodes_per_task $num_episodes \ 72 | --tasks ${task} --variations 1 --offset 0 \ 73 | --processes 1 --seed ${seed} 74 | 75 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 76 | singularity exec \ 77 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 78 | $SINGULARITY_ALLOWED_DIR/${image} \ 79 | xvfb-run -a -e $log_dir/$SLURM_JOBID.err \ 80 | /usr/bin/python3.9 preprocess/generate_dataset_keysteps.py \ 81 | --microstep_data_dir $data_dir/train/microsteps/seed${seed} \ 82 | --keystep_data_dir $data_dir/train/keysteps/seed${seed} \ 83 | --tasks ${task} 84 | 85 | # check the correctness of generated keysteps 86 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 87 | singularity exec \ 88 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 89 | $SINGULARITY_ALLOWED_DIR/${image} \ 90 | xvfb-run -a -e $log_dir/$SLURM_JOBID.err \ 91 | /usr/bin/python3.9 preprocess/evaluate_dataset_keysteps.py \ 92 | --microstep_data_dir $data_dir/train/microsteps/seed${seed} \ 93 | --keystep_data_dir $data_dir/train/keysteps/seed${seed} \ 94 | --image_size $img_size $img_size \ 95 | --tasks ${task} --headless 96 | popd 97 | -------------------------------------------------------------------------------- /job_scripts/generate_instructions.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=generate_instr 4 | #SBATCH --nodes 1 5 | #SBATCH --ntasks-per-node 1 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --cpus-per-task=10 8 | #SBATCH --qos=qos_gpu-dev 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=2:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | 14 | # This script generates fresh samples 15 | # OOM errors with CPU 16 | 17 | # go into the submission directory 18 | set -x 19 | set -e 20 | 21 | cd ${SLURM_SUBMIT_DIR} 22 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 23 | mkdir $XDG_RUNTIME_DIR 24 | chmod 700 $XDG_RUNTIME_DIR 25 | export PYTHONPATH=/opt/YARR/ 26 | 27 | code_dir=$WORK/Code/polarnet/ 28 | data_dir=data/ 29 | log_dir=logs 30 | 31 | mkdir -p $data_dir 32 | mkdir -p $log_dir 33 | 34 | module load singularity 35 | 36 | export PYTHONPATH="$PYTHONPATH:$code_dir" 37 | 38 | pushd $code_dir/polarnet/ 39 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 40
| singularity exec \ 41 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE,$HOME:$HOME --nv \ 42 | $SINGULARITY_ALLOWED_DIR/polarnet2.sif \ 43 | xvfb-run -a -e $log_dir/$SLURM_JOBID.err \ 44 | /usr/bin/python3.9 preprocess/generate_instructions.py \ 45 | --encoder clip \ 46 | --output_file $data_dir/taskvar_instrs/clip \ 47 | --generate_all_instructions --env_file assets/all_tasks.json 48 | popd -------------------------------------------------------------------------------- /job_scripts/preprocess_data.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=generate_data 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | ##SBATCH --qos=qos_gpu-dev 8 | #SBATCH --qos=qos_gpu-t3 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=20:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | 14 | # This script generates fresh samples 15 | # OOM errors with CPU 16 | 17 | # go into the submission directory 18 | set -x 19 | set -e 20 | 21 | cd ${SLURM_SUBMIT_DIR} 22 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 23 | mkdir $XDG_RUNTIME_DIR 24 | chmod 700 $XDG_RUNTIME_DIR 25 | export PYTHONPATH=/opt/YARR/ 26 | 27 | code_dir=$WORK/Code/polarnet/ 28 | 29 | log_dir=$WORK/logs 30 | data_dir=data/10tasks_data/ 31 | 32 | mkdir -p $log_dir 33 | 34 | module load singularity 35 | 36 | image=polarnet.sif 37 | 38 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 39 | export PYTHONPATH="$PYTHONPATH:$code_dir" 40 | 41 | pushd $code_dir/polarnet 42 | # generate preprocessed keysteps 43 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 44 | singularity exec \ 45 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 46 | $SINGULARITY_ALLOWED_DIR/${image} \ 47 | xvfb-run -a -e $log_dir/$SLURM_JOBID.err \ 48 | /usr/bin/python3.9 preprocess/generate_pcd_dataset_keysteps.py \ 49 | --seed ${seed} \ 50 | --dataset_dir $data_dir/train/ \ 51 | --outname $outname 52 | popd 53 | -------------------------------------------------------------------------------- /job_scripts/train_multitask_bc_10tasks.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trainbc 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -C v100-32g 7 | #SBATCH --cpus-per-task=10 8 | #SBATCH --qos=qos_gpu-t3 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=20:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | 14 | set -x 15 | set -e 16 | 17 | module purge 18 | pwd; hostname; date 19 | 20 | code_dir=$WORK/Code/polarnet 21 | export PYTHONPATH="$PYTHONPATH:$code_dir" 22 | 23 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 24 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 25 | 26 | conda activate polarnet 27 | export PYTHONPATH=$PYTHONPATH:$(pwd) 28 | 29 | export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) 30 | export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) 31 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 32 | export MASTER_ADDR=$master_addr 33 | 34 | 35 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 36 | mkdir $XDG_RUNTIME_DIR 37 | chmod 700 $XDG_RUNTIME_DIR 38 | 39 | seed=0 40 | taskvars="pick_and_lift+0,pick_up_cup+0,put_knife_on_chopping_board+0,put_money_in_safe+0,push_button+0,reach_target+0,slide_block_to_target+0,stack_wine+0,take_money_out_safe+0,take_umbrella_out_of_umbrella_stand+0" 41 | checkpoint=data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 42 | 43 | output_dir=exprs/10tasks-multi-model/seed${seed} 44 | data_dir=data/10tasks_data/train_dataset/keysteps_pcd/seed${seed}/ 45 | 46 | config_file=$code_dir/polarnet/config/10tasks.yaml 47 | instr_embed_file=data/taskvar_instrs/clip/ 48 | 49 | pushd $code_dir/polarnet 50 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 51 | python train_models.py \ 52 | --exp-config ${config_file} --restart_epoch 0 \ 53 | output_dir ${output_dir} \ 54 | DATASET.dataset_class pre_pcd_keystep_stepwise \ 55 | DATASET.taskvars ${taskvars} DATASET.in_memory False n_workers 8 \ 56 | DATASET.data_dir $data_dir \ 57 | DATASET.instr_embed_file $instr_embed_file \ 58 | DATASET.exclude_overlength_episodes 20 \ 59 | num_train_steps 200000 save_steps 10000 \ 60 | MODEL.dropout 0.0 DATASET.color_drop 0.0 \ 61 | checkpoint_strict_load False \ 62 | MODEL.num_trans_layers 2 train_batch_size 8 \ 63 | checkpoint $checkpoint \ 64 | MODEL.pcd_encoder_cfg.width 64 \ 65 | MODEL.learnable_step_embedding False \ 66 | MODEL.use_prev_action True MODEL.pcd_encoder_cfg.in_channels 10 67 | popd -------------------------------------------------------------------------------- /job_scripts/train_multitask_bc_74tasks.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trainbc 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -C v100-32g 7 | #SBATCH --cpus-per-task=10 8 | #SBATCH --qos=qos_gpu-t4 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=100:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | 14 | 15 | 16 | set -x 17 | set -e 18 | 19 | module purge 20 | pwd; hostname; date 21 | 22 | code_dir=$WORK/Code/polarnet 23 | export PYTHONPATH="$PYTHONPATH:$code_dir" 24 | 25 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 26 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 27 | 28 | conda activate polarnet 29 | export PYTHONPATH=$PYTHONPATH:$(pwd) 30 | 31 | export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) 32 | export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) 33 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 34 | export MASTER_ADDR=$master_addr 35 | 36 | 37 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 38 | mkdir $XDG_RUNTIME_DIR 39 | chmod 700 $XDG_RUNTIME_DIR 40 | 41 | echo $SINGULARITY_ALLOWED_DIR 42 | export SINGULARITY_ALLOWED_DIR=/gpfsssd/singularity/images/uqn73qm 43 | module load singularity 44 | 45 | seed=0 46 | taskvars=$code_dir/polarnet/assets/74_tasks_var.csv 47 | checkpoint=pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 48 | output_dir=exprs/74tasks-multi-model/seed${seed}/ 49 | config_file=$code_dir/polarnet/config/74tasks.yaml 50 | data_dir=data/74tasks/keysteps_pcd/seed${seed}/ 51 | instr_embed_file=data/taskvar_instrs/clip/ 52 | 53 | pushd $code_dir/polarnet 54 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 55 | singularity exec \ 56 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 57 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 58 | xvfb-run -a python train_models.py \ 59 | --exp-config ${config_file} --restart_epoch 0 \ 60 | output_dir ${output_dir} \ 61 | DATASET.dataset_class pre_pcd_keystep_stepwise \ 62 | DATASET.taskvars $taskvars \ 63 | DATASET.in_memory False n_workers 8 \ 64 | DATASET.data_dir $data_dir \ 65 | DATASET.instr_embed_file $instr_embed_file \ 66 | DATASET.exclude_overlength_episodes 20 \ 67 | num_train_steps 1000000 save_steps 10000 \ 68 | MODEL.dropout 0.0 DATASET.color_drop 0.0 \ 69 | checkpoint_strict_load False \ 70 | MODEL.num_trans_layers 2 train_batch_size 8 \ 71 | checkpoint $checkpoint \ 72 | MODEL.pcd_encoder_cfg.width 64 \ 73 | MODEL.learnable_step_embedding False \ 74 | MODEL.pcd_encoder_cfg.in_channels 10 \ 75 | MODEL.use_prev_action True 76 | popd 77 | -------------------------------------------------------------------------------- /job_scripts/train_multitask_bc_peract.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trainbc 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 4 5 | #SBATCH --gres=gpu:4 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH -C v100-32g 8 | #SBATCH --qos=qos_gpu-t4 9 | #SBATCH --hint=nomultithread 10 | #SBATCH --time=100:00:00 11 | #SBATCH --output=logs/%j.out 12 | #SBATCH --error=logs/%j.out 13 | 14 | set -x 15 | set -e 16 | 17 | module purge 18 | pwd; hostname; date 19 | 20 | code_dir=$WORK/Code/polarnet 21 | export PYTHONPATH="$PYTHONPATH:$code_dir" 22 | 23 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 24 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 25 | 26 | conda activate polarnet 27 | 28 | export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) 29 | export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) 30 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 31 | export MASTER_ADDR=$master_addr 32 | 33 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 34 | mkdir $XDG_RUNTIME_DIR 35 | chmod 700 $XDG_RUNTIME_DIR 36 | 37 | echo $SINGULARITY_ALLOWED_DIR 38 | export SINGULARITY_ALLOWED_DIR=/gpfsssd/singularity/images/uqn73qm 39 | module load singularity 40 | 41 | seed=0 42 | taskvars=$code_dir/polarnet/assets/peract_tasks_var.csv 43 | checkpoint=pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 44 | output_dir=exprs/peract-multi-model_bis/ 45 | config_file=$code_dir/polarnet/config/peract.yaml 46 | data_dir=data/peract/keysteps_pcd/ 47 | instr_embed_file=data/taskvar_instrs/clip/ 48 | 49 | pushd $code_dir/polarnet 50 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 51 | singularity exec \ 52 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 53 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 54 | xvfb-run -a python train_models.py \ 55 | --exp-config ${config_file} --restart_epoch 0 \ 56 | output_dir ${output_dir} \ 57 | DATASET.taskvars $taskvars \ 58 | DATASET.data_dir $data_dir \ 59 | DATASET.instr_embed_file $instr_embed_file \ 60 | num_train_steps 600000 save_steps 10000 \ 61 | MODEL.dropout 0.0 DATASET.color_drop 0.0 \ 62 | checkpoint $checkpoint \ 63 | MODEL.pcd_encoder_cfg.width 64 \ 64 | checkpoint_strict_load False \ 65 | MODEL.max_steps 25 \ 66 | MODEL.num_trans_layers 2 train_batch_size 4 \ 67 | DATASET.max_steps_per_episode 12 \ 68 | DATASET.multi_instruction True \ 69 | MODEL.use_prev_action True \ 70 | MODEL.pcd_encoder_cfg.in_channels 10 71 | popd 72 | -------------------------------------------------------------------------------- /job_scripts/train_singletask.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trainbc 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH --output=logs/%j.out 11 | #SBATCH --error=logs/%j.out 12 | 13 | set -x 14 | set -e 15 | 16 | module purge 17 | pwd; hostname; date 18 | 19 | code_dir=$WORK/Code/polarnet/ 20 | export PYTHONPATH="$PYTHONPATH:$code_dir" 21 | 22 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 23 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 24 | 25 | conda activate polarnet 26 | 27 | export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) 28 | export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) 29 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 30 | export MASTER_ADDR=$master_addr 31 | 32 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 33 | mkdir $XDG_RUNTIME_DIR 34 | chmod 700 $XDG_RUNTIME_DIR 35 | 36 | module load singularity 37 | 38 | seed=0 39 | taskvar=pick_and_lift+0 40 | checkpoint=data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 41 | output_dir=data/exprs/${taskvar}_model/seed${seed} 42 | config_file=$code_dir/polarnet/config/single_task.yaml 43 | instr_embed_file=data/10tasks_data/taskvar_instrs/clip/ 44 | data_dir=data/10tasks_data/train/keysteps_pcd/seed${seed} 45 | 46 | pushd $code_dir/polarnet 47 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 48 | singularity exec \ 49 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 50 | $SINGULARITY_ALLOWED_DIR/polarnet.sif \ 51 | xvfb-run -a python train_models.py \ 52 | --exp-config $config_file \ 53 | output_dir $output_dir \ 54 | DATASET.taskvars ${taskvar} \ 55 | DATASET.data_dir ${data_dir} \ 56 | DATASET.instr_embed_file $instr_embed_file \ 57 | num_train_steps 100000 save_steps 2000 \ 58 | checkpoint $checkpoint 59 | popd 60 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/polarnet/882b7ef5b82ee4c7779cdd0020f58e919e0f8bce/overview.png -------------------------------------------------------------------------------- /polarnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/polarnet/882b7ef5b82ee4c7779cdd0020f58e919e0f8bce/polarnet/__init__.py -------------------------------------------------------------------------------- /polarnet/assets/10_tasks.csv: -------------------------------------------------------------------------------- 1 | pick_and_lift 2 | pick_up_cup 3 | put_knife_on_chopping_board 4 | put_money_in_safe 5 | push_button 6 | reach_target 7 | slide_block_to_target 8 | stack_wine 9 | take_money_out_safe 10 | take_umbrella_out_of_umbrella_stand 11 | -------------------------------------------------------------------------------- /polarnet/assets/10_tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | "pick_and_lift", 3 | "pick_up_cup", 4 | "put_knife_on_chopping_board", 5 | "put_money_in_safe", 6 | "push_button", 7 | "reach_target", 8 | "slide_block_to_target", 9 | "stack_wine", 10 | "take_money_out_safe", 11 | "take_umbrella_out_of_umbrella_stand" 12 | ] -------------------------------------------------------------------------------- /polarnet/assets/10_tasks_var.csv: -------------------------------------------------------------------------------- 1 | pick_and_lift+0 2 | pick_up_cup+0 3 | put_knife_on_chopping_board+0 4 | put_money_in_safe+0 5 | push_button+0 6 | reach_target+0 7 | slide_block_to_target+0 8 | stack_wine+0 9 | take_money_out_safe+0 10 | take_umbrella_out_of_umbrella_stand+0 11 | -------------------------------------------------------------------------------- /polarnet/assets/74_tasks.csv: 
-------------------------------------------------------------------------------- 1 | basketball_in_hoop 2 | beat_the_buzz 3 | change_channel 4 | change_clock 5 | close_box 6 | close_door 7 | close_drawer 8 | close_fridge 9 | close_grill 10 | close_laptop_lid 11 | close_microwave 12 | hang_frame_on_hanger 13 | insert_onto_square_peg 14 | insert_usb_in_computer 15 | lamp_off 16 | lamp_on 17 | lift_numbered_block 18 | meat_off_grill 19 | meat_on_grill 20 | move_hanger 21 | open_box 22 | open_door 23 | open_drawer 24 | open_fridge 25 | open_grill 26 | open_microwave 27 | open_oven 28 | open_window 29 | open_wine_bottle 30 | phone_on_base 31 | pick_and_lift 32 | pick_and_lift_small 33 | pick_up_cup 34 | place_hanger_on_rack 35 | place_shape_in_shape_sorter 36 | play_jenga 37 | plug_charger_in_power_supply 38 | press_switch 39 | push_button 40 | push_buttons 41 | put_books_on_bookshelf 42 | put_knife_on_chopping_board 43 | put_money_in_safe 44 | put_rubbish_in_bin 45 | put_umbrella_in_umbrella_stand 46 | reach_and_drag 47 | reach_target 48 | scoop_with_spatula 49 | screw_nail 50 | setup_checkers 51 | slide_block_to_target 52 | slide_cabinet_open_and_place_cups 53 | stack_blocks 54 | stack_cups 55 | stack_wine 56 | straighten_rope 57 | sweep_to_dustpan 58 | take_frame_off_hanger 59 | take_lid_off_saucepan 60 | take_money_out_safe 61 | take_plate_off_colored_dish_rack 62 | take_shoes_out_of_box 63 | take_toilet_roll_off_stand 64 | take_umbrella_out_of_umbrella_stand 65 | take_usb_out_of_computer 66 | toilet_seat_down 67 | toilet_seat_up 68 | tower3 69 | turn_oven_on 70 | turn_tap 71 | tv_on 72 | unplug_charger 73 | water_plants 74 | wipe_desk 75 | -------------------------------------------------------------------------------- /polarnet/assets/74_tasks_per_category.json: -------------------------------------------------------------------------------- 1 | { 2 | "planning": [ 3 | "basketball_in_hoop", 4 | "put_rubbish_in_bin", 5 | "meat_off_grill", 6 | "meat_on_grill", 7 | "change_channel", 8 | "tv_on", 9 | "tower3", 10 | "push_buttons", 11 | "stack_wine" 12 | ], 13 | "tools": [ 14 | "slide_block_to_target", 15 | "reach_and_drag", 16 | "take_frame_off_hanger", 17 | "water_plants", 18 | "hang_frame_on_hanger", 19 | "scoop_with_spatula", 20 | "place_hanger_on_rack", 21 | "move_hanger", 22 | "sweep_to_dustpan", 23 | "take_plate_off_colored_dish_rack", 24 | "screw_nail" 25 | ], 26 | "long_term": [ 27 | "wipe_desk", 28 | "stack_blocks", 29 | "take_shoes_out_of_box", 30 | "slide_cabinet_open_and_place_cups" 31 | ], 32 | "rotation-invariant": [ 33 | "reach_target", 34 | "push_button", 35 | "lamp_on", 36 | "lamp_off", 37 | "pick_and_lift", 38 | "take_lid_off_saucepan" 39 | ], 40 | "motion-planner": [ 41 | "toilet_seat_down", 42 | "close_laptop_lid", 43 | "open_box", 44 | "open_drawer", 45 | "close_drawer", 46 | "close_box", 47 | "phone_on_base", 48 | "toilet_seat_up", 49 | "put_books_on_bookshelf" 50 | ], 51 | "multimodal": [ 52 | "pick_up_cup", 53 | "turn_tap", 54 | "lift_numbered_block", 55 | "beat_the_buzz", 56 | "stack_cups" 57 | ], 58 | "precision": [ 59 | "take_usb_out_of_computer", 60 | "play_jenga", 61 | "insert_onto_square_peg", 62 | "take_umbrella_out_of_umbrella_stand", 63 | "insert_usb_in_computer", 64 | "straighten_rope", 65 | "pick_and_lift_small", 66 | "put_knife_on_chopping_board", 67 | "place_shape_in_shape_sorter", 68 | "take_toilet_roll_off_stand", 69 | "put_umbrella_in_umbrella_stand", 70 | "setup_checkers" 71 | ], 72 | "screw": [ 73 | "turn_oven_on", 74 | "change_clock", 75 | 
"open_window", 76 | "open_wine_bottle" 77 | ], 78 | "visual_occlusion": [ 79 | "close_microwave", 80 | "close_fridge", 81 | "close_grill", 82 | "open_grill", 83 | "unplug_charger", 84 | "press_switch", 85 | "take_money_out_safe", 86 | "open_microwave", 87 | "put_money_in_safe", 88 | "open_door", 89 | "close_door", 90 | "open_fridge", 91 | "open_oven", 92 | "plug_charger_in_power_supply" 93 | ] 94 | } -------------------------------------------------------------------------------- /polarnet/assets/74_tasks_var.csv: -------------------------------------------------------------------------------- 1 | basketball_in_hoop+0 2 | beat_the_buzz+0 3 | change_channel+0 4 | change_clock+0 5 | close_box+0 6 | close_door+0 7 | close_drawer+0 8 | close_fridge+0 9 | close_grill+0 10 | close_laptop_lid+0 11 | close_microwave+0 12 | hang_frame_on_hanger+0 13 | insert_onto_square_peg+0 14 | insert_usb_in_computer+0 15 | lamp_off+0 16 | lamp_on+0 17 | lift_numbered_block+0 18 | meat_off_grill+0 19 | meat_on_grill+0 20 | move_hanger+0 21 | open_box+0 22 | open_door+0 23 | open_drawer+0 24 | open_fridge+0 25 | open_grill+0 26 | open_microwave+0 27 | open_oven+0 28 | open_window+0 29 | open_wine_bottle+0 30 | phone_on_base+0 31 | pick_and_lift+0 32 | pick_and_lift_small+0 33 | pick_up_cup+0 34 | place_hanger_on_rack+0 35 | place_shape_in_shape_sorter+0 36 | play_jenga+0 37 | plug_charger_in_power_supply+0 38 | press_switch+0 39 | push_button+0 40 | push_buttons+0 41 | put_books_on_bookshelf+0 42 | put_knife_on_chopping_board+0 43 | put_money_in_safe+0 44 | put_rubbish_in_bin+0 45 | put_umbrella_in_umbrella_stand+0 46 | reach_and_drag+0 47 | reach_target+0 48 | scoop_with_spatula+0 49 | screw_nail+0 50 | setup_checkers+0 51 | slide_block_to_target+0 52 | slide_cabinet_open_and_place_cups+0 53 | stack_blocks+0 54 | stack_cups+0 55 | stack_wine+0 56 | straighten_rope+0 57 | sweep_to_dustpan+0 58 | take_frame_off_hanger+0 59 | take_lid_off_saucepan+0 60 | take_money_out_safe+0 61 | take_plate_off_colored_dish_rack+0 62 | take_shoes_out_of_box+0 63 | take_toilet_roll_off_stand+0 64 | take_umbrella_out_of_umbrella_stand+0 65 | take_usb_out_of_computer+0 66 | toilet_seat_down+0 67 | toilet_seat_up+0 68 | tower3+0 69 | turn_oven_on+0 70 | turn_tap+0 71 | tv_on+0 72 | unplug_charger+0 73 | water_plants+0 74 | wipe_desk+0 75 | -------------------------------------------------------------------------------- /polarnet/assets/all_tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | "basketball_in_hoop", 3 | "beat_the_buzz", 4 | "block_pyramid", 5 | "change_channel", 6 | "change_clock", 7 | "close_box", 8 | "close_door", 9 | "close_drawer", 10 | "close_fridge", 11 | "close_grill", 12 | "close_jar", 13 | "close_laptop_lid", 14 | "close_microwave", 15 | "empty_container", 16 | "empty_dishwasher", 17 | "get_ice_from_fridge", 18 | "hang_frame_on_hanger", 19 | "hit_ball_with_queue", 20 | "hockey", 21 | "insert_onto_square_peg", 22 | "insert_usb_in_computer", 23 | "lamp_off", 24 | "lamp_on", 25 | "lift_numbered_block", 26 | "light_bulb_in", 27 | "light_bulb_out", 28 | "meat_off_grill", 29 | "meat_on_grill", 30 | "move_hanger", 31 | "open_box", 32 | "open_door", 33 | "open_drawer", 34 | "open_fridge", 35 | "open_grill", 36 | "open_jar", 37 | "open_microwave", 38 | "open_oven", 39 | "open_washing_machine", 40 | "open_window", 41 | "open_wine_bottle", 42 | "phone_on_base", 43 | "pick_and_lift", 44 | "pick_and_lift_small", 45 | "pick_up_cup", 46 | "place_cups", 47 | 
"place_hanger_on_rack", 48 | "place_shape_in_shape_sorter", 49 | "play_jenga", 50 | "plug_charger_in_power_supply", 51 | "pour_from_cup_to_cup", 52 | "press_switch", 53 | "push_button", 54 | "push_buttons", 55 | "put_all_groceries_in_cupboard", 56 | "put_books_on_bookshelf", 57 | "put_bottle_in_fridge", 58 | "put_groceries_in_cupboard", 59 | "put_item_in_drawer", 60 | "put_knife_in_knife_block", 61 | "put_knife_on_chopping_board", 62 | "put_money_in_safe", 63 | "put_plate_in_colored_dish_rack", 64 | "put_rubbish_in_bin", 65 | "put_shoes_in_box", 66 | "put_toilet_roll_on_stand", 67 | "put_tray_in_oven", 68 | "put_umbrella_in_umbrella_stand", 69 | "reach_and_drag", 70 | "reach_target", 71 | "remove_cups", 72 | "scoop_with_spatula", 73 | "screw_nail", 74 | "set_the_table", 75 | "setup_checkers", 76 | "setup_chess", 77 | "slide_block_to_target", 78 | "slide_cabinet_open", 79 | "slide_cabinet_open_and_place_cups", 80 | "solve_puzzle", 81 | "stack_blocks", 82 | "stack_chairs", 83 | "stack_cups", 84 | "stack_wine", 85 | "straighten_rope", 86 | "sweep_to_dustpan", 87 | "take_cup_out_from_cabinet", 88 | "take_frame_off_hanger", 89 | "take_item_out_of_drawer", 90 | "take_lid_off_saucepan", 91 | "take_money_out_safe", 92 | "take_off_weighing_scales", 93 | "take_plate_off_colored_dish_rack", 94 | "take_shoes_out_of_box", 95 | "take_toilet_roll_off_stand", 96 | "take_tray_out_of_oven", 97 | "take_umbrella_out_of_umbrella_stand", 98 | "take_usb_out_of_computer", 99 | "toilet_seat_down", 100 | "toilet_seat_up", 101 | "tower3", 102 | "turn_oven_on", 103 | "turn_tap", 104 | "tv_on", 105 | "unplug_charger", 106 | "water_plants", 107 | "weighing_scales", 108 | "wipe_desk", 109 | "close_jar_peract", 110 | "insert_onto_square_peg_peract", 111 | "light_bulb_in_peract", 112 | "meat_off_grill_peract", 113 | "open_drawer_peract", 114 | "place_cups_peract", 115 | "place_shape_in_shape_sorter_peract", 116 | "place_wine_at_rack_location_peract", 117 | "push_buttons_peract", 118 | "put_groceries_in_cupboard_peract", 119 | "put_item_in_drawer_peract", 120 | "put_money_in_safe_peract", 121 | "reach_and_drag_peract", 122 | "slide_block_to_color_target_peract", 123 | "stack_blocks_peract", 124 | "stack_cups_peract", 125 | "sweep_to_dustpan_of_size_peract", 126 | "turn_tap_peract" 127 | ] -------------------------------------------------------------------------------- /polarnet/assets/peract_tasks.csv: -------------------------------------------------------------------------------- 1 | open_drawer_peract 2 | slide_block_to_color_target_peract 3 | sweep_to_dustpan_of_size_peract 4 | meat_off_grill_peract 5 | turn_tap_peract 6 | put_item_in_drawer_peract 7 | close_jar_peract 8 | reach_and_drag_peract 9 | stack_blocks_peract 10 | light_bulb_in_peract 11 | put_money_in_safe_peract 12 | place_wine_at_rack_location_peract 13 | put_groceries_in_cupboard_peract 14 | place_shape_in_shape_sorter_peract 15 | push_buttons_peract 16 | insert_onto_square_peg_peract 17 | stack_cups_peract 18 | place_cups_peract 19 | -------------------------------------------------------------------------------- /polarnet/assets/peract_tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | "close_jar_peract", 3 | "insert_onto_square_peg_peract", 4 | "light_bulb_in_peract", 5 | "meat_off_grill_peract", 6 | "open_drawer_peract", 7 | "place_cups_peract", 8 | "place_shape_in_shape_sorter_peract", 9 | "place_wine_at_rack_location_peract", 10 | "push_buttons_peract", 11 | "put_groceries_in_cupboard_peract", 12 | 
"put_item_in_drawer_peract", 13 | "put_money_in_safe_peract", 14 | "reach_and_drag_peract", 15 | "slide_block_to_color_target_peract", 16 | "stack_blocks_peract", 17 | "stack_cups_peract", 18 | "sweep_to_dustpan_of_size_peract", 19 | "turn_tap_peract" 20 | ] 21 | -------------------------------------------------------------------------------- /polarnet/assets/peract_tasks_var.csv: -------------------------------------------------------------------------------- 1 | close_jar_peract+0 2 | close_jar_peract+1 3 | close_jar_peract+2 4 | close_jar_peract+3 5 | close_jar_peract+4 6 | close_jar_peract+5 7 | close_jar_peract+6 8 | close_jar_peract+7 9 | close_jar_peract+8 10 | close_jar_peract+9 11 | close_jar_peract+10 12 | close_jar_peract+11 13 | close_jar_peract+12 14 | close_jar_peract+13 15 | close_jar_peract+14 16 | close_jar_peract+15 17 | close_jar_peract+16 18 | close_jar_peract+17 19 | close_jar_peract+18 20 | close_jar_peract+19 21 | insert_onto_square_peg_peract+0 22 | insert_onto_square_peg_peract+1 23 | insert_onto_square_peg_peract+2 24 | insert_onto_square_peg_peract+3 25 | insert_onto_square_peg_peract+4 26 | insert_onto_square_peg_peract+5 27 | insert_onto_square_peg_peract+6 28 | insert_onto_square_peg_peract+7 29 | insert_onto_square_peg_peract+8 30 | insert_onto_square_peg_peract+9 31 | insert_onto_square_peg_peract+10 32 | insert_onto_square_peg_peract+11 33 | insert_onto_square_peg_peract+12 34 | insert_onto_square_peg_peract+13 35 | insert_onto_square_peg_peract+14 36 | insert_onto_square_peg_peract+15 37 | insert_onto_square_peg_peract+16 38 | insert_onto_square_peg_peract+17 39 | insert_onto_square_peg_peract+18 40 | insert_onto_square_peg_peract+19 41 | light_bulb_in_peract+0 42 | light_bulb_in_peract+1 43 | light_bulb_in_peract+2 44 | light_bulb_in_peract+3 45 | light_bulb_in_peract+4 46 | light_bulb_in_peract+5 47 | light_bulb_in_peract+6 48 | light_bulb_in_peract+7 49 | light_bulb_in_peract+8 50 | light_bulb_in_peract+9 51 | light_bulb_in_peract+10 52 | light_bulb_in_peract+11 53 | light_bulb_in_peract+12 54 | light_bulb_in_peract+13 55 | light_bulb_in_peract+14 56 | light_bulb_in_peract+15 57 | light_bulb_in_peract+16 58 | light_bulb_in_peract+17 59 | light_bulb_in_peract+18 60 | light_bulb_in_peract+19 61 | meat_off_grill_peract+0 62 | meat_off_grill_peract+1 63 | open_drawer_peract+0 64 | open_drawer_peract+1 65 | open_drawer_peract+2 66 | place_cups_peract+0 67 | place_cups_peract+1 68 | place_cups_peract+2 69 | place_shape_in_shape_sorter_peract+0 70 | place_shape_in_shape_sorter_peract+1 71 | place_shape_in_shape_sorter_peract+2 72 | place_shape_in_shape_sorter_peract+3 73 | place_shape_in_shape_sorter_peract+4 74 | place_wine_at_rack_location_peract+0 75 | place_wine_at_rack_location_peract+1 76 | place_wine_at_rack_location_peract+2 77 | push_buttons_peract+0 78 | push_buttons_peract+1 79 | push_buttons_peract+2 80 | push_buttons_peract+3 81 | push_buttons_peract+4 82 | push_buttons_peract+5 83 | push_buttons_peract+6 84 | push_buttons_peract+7 85 | push_buttons_peract+8 86 | push_buttons_peract+9 87 | push_buttons_peract+10 88 | push_buttons_peract+11 89 | push_buttons_peract+12 90 | push_buttons_peract+13 91 | push_buttons_peract+14 92 | push_buttons_peract+15 93 | push_buttons_peract+16 94 | push_buttons_peract+17 95 | push_buttons_peract+18 96 | push_buttons_peract+19 97 | push_buttons_peract+20 98 | push_buttons_peract+21 99 | push_buttons_peract+22 100 | push_buttons_peract+23 101 | push_buttons_peract+24 102 | push_buttons_peract+25 103 | 
push_buttons_peract+26 104 | push_buttons_peract+27 105 | push_buttons_peract+28 106 | push_buttons_peract+29 107 | push_buttons_peract+30 108 | push_buttons_peract+31 109 | push_buttons_peract+32 110 | push_buttons_peract+33 111 | push_buttons_peract+34 112 | push_buttons_peract+35 113 | push_buttons_peract+36 114 | push_buttons_peract+37 115 | push_buttons_peract+38 116 | push_buttons_peract+39 117 | push_buttons_peract+40 118 | push_buttons_peract+41 119 | push_buttons_peract+42 120 | push_buttons_peract+43 121 | push_buttons_peract+44 122 | push_buttons_peract+45 123 | push_buttons_peract+46 124 | push_buttons_peract+47 125 | push_buttons_peract+48 126 | push_buttons_peract+49 127 | put_groceries_in_cupboard_peract+0 128 | put_groceries_in_cupboard_peract+1 129 | put_groceries_in_cupboard_peract+2 130 | put_groceries_in_cupboard_peract+3 131 | put_groceries_in_cupboard_peract+4 132 | put_groceries_in_cupboard_peract+5 133 | put_groceries_in_cupboard_peract+6 134 | put_groceries_in_cupboard_peract+7 135 | put_groceries_in_cupboard_peract+8 136 | put_item_in_drawer_peract+0 137 | put_item_in_drawer_peract+1 138 | put_item_in_drawer_peract+2 139 | put_money_in_safe_peract+0 140 | put_money_in_safe_peract+1 141 | put_money_in_safe_peract+2 142 | reach_and_drag_peract+0 143 | reach_and_drag_peract+1 144 | reach_and_drag_peract+2 145 | reach_and_drag_peract+3 146 | reach_and_drag_peract+4 147 | reach_and_drag_peract+5 148 | reach_and_drag_peract+6 149 | reach_and_drag_peract+7 150 | reach_and_drag_peract+8 151 | reach_and_drag_peract+9 152 | reach_and_drag_peract+10 153 | reach_and_drag_peract+11 154 | reach_and_drag_peract+12 155 | reach_and_drag_peract+13 156 | reach_and_drag_peract+14 157 | reach_and_drag_peract+15 158 | reach_and_drag_peract+16 159 | reach_and_drag_peract+17 160 | reach_and_drag_peract+18 161 | reach_and_drag_peract+19 162 | slide_block_to_color_target_peract+0 163 | slide_block_to_color_target_peract+1 164 | slide_block_to_color_target_peract+2 165 | slide_block_to_color_target_peract+3 166 | stack_blocks_peract+0 167 | stack_blocks_peract+1 168 | stack_blocks_peract+2 169 | stack_blocks_peract+3 170 | stack_blocks_peract+4 171 | stack_blocks_peract+5 172 | stack_blocks_peract+6 173 | stack_blocks_peract+7 174 | stack_blocks_peract+8 175 | stack_blocks_peract+9 176 | stack_blocks_peract+10 177 | stack_blocks_peract+11 178 | stack_blocks_peract+12 179 | stack_blocks_peract+13 180 | stack_blocks_peract+14 181 | stack_blocks_peract+15 182 | stack_blocks_peract+16 183 | stack_blocks_peract+17 184 | stack_blocks_peract+18 185 | stack_blocks_peract+19 186 | stack_blocks_peract+20 187 | stack_blocks_peract+21 188 | stack_blocks_peract+22 189 | stack_blocks_peract+23 190 | stack_blocks_peract+24 191 | stack_blocks_peract+25 192 | stack_blocks_peract+26 193 | stack_blocks_peract+27 194 | stack_blocks_peract+28 195 | stack_blocks_peract+29 196 | stack_blocks_peract+30 197 | stack_blocks_peract+31 198 | stack_blocks_peract+32 199 | stack_blocks_peract+33 200 | stack_blocks_peract+34 201 | stack_blocks_peract+35 202 | stack_blocks_peract+36 203 | stack_blocks_peract+37 204 | stack_blocks_peract+38 205 | stack_blocks_peract+39 206 | stack_blocks_peract+40 207 | stack_blocks_peract+41 208 | stack_blocks_peract+42 209 | stack_blocks_peract+43 210 | stack_blocks_peract+44 211 | stack_blocks_peract+45 212 | stack_blocks_peract+46 213 | stack_blocks_peract+47 214 | stack_blocks_peract+48 215 | stack_blocks_peract+49 216 | stack_blocks_peract+50 217 | stack_blocks_peract+51 218 | 
stack_blocks_peract+52 219 | stack_blocks_peract+53 220 | stack_blocks_peract+54 221 | stack_blocks_peract+55 222 | stack_blocks_peract+56 223 | stack_blocks_peract+57 224 | stack_blocks_peract+58 225 | stack_blocks_peract+59 226 | stack_cups_peract+0 227 | stack_cups_peract+1 228 | stack_cups_peract+2 229 | stack_cups_peract+3 230 | stack_cups_peract+4 231 | stack_cups_peract+5 232 | stack_cups_peract+6 233 | stack_cups_peract+7 234 | stack_cups_peract+8 235 | stack_cups_peract+9 236 | stack_cups_peract+10 237 | stack_cups_peract+11 238 | stack_cups_peract+12 239 | stack_cups_peract+13 240 | stack_cups_peract+14 241 | stack_cups_peract+15 242 | stack_cups_peract+16 243 | stack_cups_peract+17 244 | stack_cups_peract+18 245 | stack_cups_peract+19 246 | sweep_to_dustpan_of_size_peract+0 247 | sweep_to_dustpan_of_size_peract+1 248 | turn_tap_peract+0 249 | turn_tap_peract+1 250 | -------------------------------------------------------------------------------- /polarnet/assets/tasks_use_table_surface.json: -------------------------------------------------------------------------------- 1 | [ 2 | "slide_block_to_target", 3 | "slide_block_to_color_target", 4 | "reach_and_drag", 5 | "stack_blocks", 6 | "straighten_rope", 7 | "tower3", 8 | "wipe_desk", 9 | "real_take_plate" 10 | ] 11 | -------------------------------------------------------------------------------- /polarnet/assets/tasks_with_color.json: -------------------------------------------------------------------------------- 1 | [ 2 | "block_pyramid", 3 | "close_jar", 4 | "insert_onto_square_peg", 5 | "lift_numbered_block", 6 | "light_bulb_in", 7 | "light_bulb_out", 8 | "open_jar", 9 | "pick_and_lift", 10 | "pick_and_lift_small", 11 | "pick_up_cup", 12 | "pour_from_cup_to_cup", 13 | "push_button", 14 | "push_buttons", 15 | "put_plate_in_colored_dish_rack", 16 | "reach_and_drag", 17 | "reach_target", 18 | "slide_block_to_target", 19 | "stack_blocks", 20 | "stack_chairs", 21 | "stack_cups", 22 | "take_off_weighing_scales", 23 | "take_plate_off_colored_dish_rack", 24 | "unplug_charger", 25 | "weighing_scales" 26 | ] -------------------------------------------------------------------------------- /polarnet/config/10tasks.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'exprs/pcd_unet/10tasks' 3 | checkpoint: null 4 | checkpoint: data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 5 | checkpoint_strict_load: false # true, false 6 | resume_training: true 7 | 8 | train_batch_size: 8 9 | gradient_accumulation_steps: 1 10 | num_epochs: null 11 | num_train_steps: 200000 # 100k for single-task, 200k for 10tasks, 1M for 74tasks, 600k for peract 12 | warmup_steps: 2000 13 | log_steps: 1000 14 | save_steps: 5000 15 | 16 | optim: 'adamw' 17 | learning_rate: 5e-4 18 | lr_sched: 'linear' # inverse_sqrt, linear 19 | betas: [0.9, 0.98] 20 | weight_decay: 0.001 21 | grad_norm: 5 22 | n_workers: 0 23 | pin_mem: True 24 | 25 | DATASET: 26 | dataset_class: 'pre_pcd_keystep_stepwise' # pre_pcd_keystep_stepwise 27 | 28 | voxel_size: 0.01 # null, 0.01, 0.005 29 | npoints: 2048 30 | use_color: True 31 | use_normal: True 32 | use_height: True 33 | color_drop: 0.0 34 | only_success: False 35 | multi_instruction: True 36 | 37 | max_steps_per_episode: 12 38 | 39 | use_discrete_rot: False 40 | rot_resolution: 5 # degrees 41 | 42 | aug_shift_pcd: 0.0 43 | aug_rotate_pcd: 0.0 44 | 45 | add_pcd_noises: False 46 | pcd_noises_std: 0.01 47 | remove_pcd_outliers: False 48 | real_robot: False 49 | 50 
| max_demos_per_taskvar: null 51 | exclude_overlength_episodes: null 52 | 53 | pc_space: 'workspace_on_table' # none, workspace, workspace_on_table 54 | pc_center: 'gripper' # point, gripper 55 | pc_radius_norm: True # true (unit ball), false 56 | 57 | data_dir: 'data/train_dataset/keysteps_pcd/seed0' 58 | taskvars: ['assets/10_tasks_var.csv'] 59 | instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 60 | use_instr_embed: 'all' # none, avg, last, all 61 | cameras: ("left_shoulder", "right_shoulder", "wrist") 62 | camera_ids: [0, 1, 2] 63 | gripper_channel: False 64 | is_training: True 65 | in_memory: True 66 | num_workers: 0 67 | 68 | MODEL: 69 | model_class: 'PointCloudUNet' 70 | 71 | use_max_action: False 72 | use_discrete_rot: False 73 | rot_resolution: 5 # degrees 74 | 75 | heatmap_loss: false 76 | heatmap_loss_weight: 1.0 77 | heatmap_distance_weight: 1.0 78 | use_heatmap_max: false 79 | use_pos_loss: true 80 | 81 | num_tasks: 1 82 | max_steps: 25 83 | dropout: 0.0 84 | learnable_step_embedding: false 85 | use_prev_action: true 86 | 87 | use_instr_embed: 'all' # none, avg, last, all 88 | instr_embed_size: 512 89 | txt_attn_type: 'cross' # none, cross 90 | num_trans_layers: 2 91 | trans_hidden_size: 512 92 | cat_global_in_head: False 93 | 94 | heatmap_temp: 0.1 95 | 96 | pcd_encoder_cfg: 97 | blocks: [1, 1, 1, 1, 1] 98 | strides: [1, 2, 2, 2, 2] 99 | width: 64 100 | in_channels: 10 101 | sa_layers: 3 102 | sa_use_res: True 103 | radius: 0.05 104 | radius_scaling: 2.5 105 | nsample: 32 106 | expansion: 4 107 | aggr_args: 108 | feature_type: dp_fj 109 | reduction: max 110 | group_args: 111 | NAME: ballquery 112 | normalize_dp: True 113 | conv_args: 114 | order: conv-norm-act 115 | act_args: 116 | act: relu 117 | inplace: True 118 | norm_args: 119 | norm: bn 120 | 121 | pcd_decoder_cfg: 122 | layers: 2 123 | 124 | -------------------------------------------------------------------------------- /polarnet/config/74tasks.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'exprs/pcd_unet/74tasks' 3 | checkpoint: null 4 | checkpoint: data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 5 | checkpoint_strict_load: false # true, false 6 | resume_training: true 7 | 8 | train_batch_size: 8 9 | gradient_accumulation_steps: 1 10 | num_epochs: null 11 | num_train_steps: 1000000 # 100k for single-task, 200k for 10tasks, 1M for 74tasks, 600k for peract 12 | warmup_steps: 2000 13 | log_steps: 1000 14 | save_steps: 5000 15 | 16 | optim: 'adamw' 17 | learning_rate: 5e-4 18 | lr_sched: 'linear' # inverse_sqrt, linear 19 | betas: [0.9, 0.98] 20 | weight_decay: 0.001 21 | grad_norm: 5 22 | n_workers: 0 23 | pin_mem: True 24 | 25 | DATASET: 26 | dataset_class: 'pre_pcd_keystep_stepwise' # pre_pcd_keystep_stepwise 27 | 28 | voxel_size: 0.01 # null, 0.01, 0.005 29 | npoints: 2048 30 | use_color: True 31 | use_normal: True 32 | use_height: True 33 | color_drop: 0.0 34 | only_success: False 35 | multi_instruction: True 36 | 37 | max_steps_per_episode: 12 38 | 39 | use_discrete_rot: False 40 | rot_resolution: 5 # degrees 41 | 42 | aug_shift_pcd: 0.0 43 | aug_rotate_pcd: 0.0 44 | 45 | add_pcd_noises: False 46 | pcd_noises_std: 0.01 47 | remove_pcd_outliers: False 48 | real_robot: False 49 | 50 | max_demos_per_taskvar: null 51 | exclude_overlength_episodes: null 52 | 53 | pc_space: 'workspace_on_table' # none, workspace, workspace_on_table 54 | pc_center: 'gripper' # point, gripper 55 | pc_radius_norm: True # true (unit ball), false 56 | 
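# Note: with pc_center 'gripper' and pc_radius_norm true, each point cloud is
# re-centered on the gripper and rescaled into a unit ball; the action head
# maps predicted positions back to world coordinates using the stored
# pc_centers / pc_radii.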
57 | data_dir: 'data/train_dataset/keysteps_pcd/seed0' 58 | taskvars: ['assets/74_tasks_var.csv'] 59 | instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 60 | use_instr_embed: 'all' # none, avg, last, all 61 | # cameras: ("left_shoulder", "right_shoulder", "wrist", "front") 62 | cameras: ("left_shoulder", "right_shoulder", "wrist") 63 | camera_ids: [0, 1, 2] 64 | gripper_channel: False 65 | is_training: True 66 | in_memory: True 67 | num_workers: 0 68 | 69 | MODEL: 70 | model_class: 'PointCloudUNet' 71 | 72 | use_max_action: False 73 | use_discrete_rot: False 74 | rot_resolution: 5 # degrees 75 | 76 | heatmap_loss: false 77 | heatmap_loss_weight: 1.0 78 | heatmap_distance_weight: 1.0 79 | use_heatmap_max: false 80 | use_pos_loss: true 81 | 82 | num_tasks: 1 83 | max_steps: 25 84 | dropout: 0.0 85 | learnable_step_embedding: false 86 | use_prev_action: true 87 | 88 | use_instr_embed: 'all' # none, avg, last, all 89 | instr_embed_size: 512 90 | txt_attn_type: 'cross' # none, cross 91 | num_trans_layers: 2 92 | trans_hidden_size: 512 93 | cat_global_in_head: False 94 | 95 | heatmap_temp: 0.1 96 | 97 | pcd_encoder_cfg: 98 | blocks: [1, 1, 1, 1, 1] 99 | strides: [1, 2, 2, 2, 2] 100 | width: 64 101 | in_channels: 10 102 | sa_layers: 3 103 | sa_use_res: True 104 | radius: 0.05 105 | radius_scaling: 2.5 106 | nsample: 32 107 | expansion: 4 108 | aggr_args: 109 | feature_type: dp_fj 110 | reduction: max 111 | group_args: 112 | NAME: ballquery 113 | normalize_dp: True 114 | conv_args: 115 | order: conv-norm-act 116 | act_args: 117 | act: relu 118 | inplace: True 119 | norm_args: 120 | norm: bn 121 | 122 | pcd_decoder_cfg: 123 | layers: 2 124 | 125 | -------------------------------------------------------------------------------- /polarnet/config/constants.py: -------------------------------------------------------------------------------- 1 | 2 | def get_workspace(real_robot=False): 3 | if real_robot: 4 | # ur5 robotics room 5 | TABLE_HEIGHT = 0.01 # meters 6 | 7 | X_BBOX = (-1, 0) # 0 is the robot base 8 | Y_BBOX = (-0.175, 0.4) # 0 is the robot base 9 | Z_BBOX = (0, 0.75) # 0 is the table 10 | else: 11 | # rlbench workspace 12 | TABLE_HEIGHT = 0.76 # meters 13 | 14 | X_BBOX = (-0.5, 1.5) # 0 is the robot base 15 | Y_BBOX = (-1, 1) # 0 is the robot base 16 | Z_BBOX = (0.2, 2) # 0 is the floor 17 | 18 | return { 19 | 'TABLE_HEIGHT': TABLE_HEIGHT, 20 | 'X_BBOX': X_BBOX, 21 | 'Y_BBOX': Y_BBOX, 22 | 'Z_BBOX': Z_BBOX 23 | } 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /polarnet/config/default.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | 5 | import yacs.config 6 | 7 | # Default config node 8 | class Config(yacs.config.CfgNode): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs, new_allowed=True) 11 | 12 | CN = Config 13 | 14 | 15 | CONFIG_FILE_SEPARATOR = ';' 16 | 17 | # ----------------------------------------------------------------------------- 18 | # EXPERIMENT CONFIG 19 | # ----------------------------------------------------------------------------- 20 | _C = CN() 21 | _C.SEED = 2023 22 | _C.CMD_TRAILING_OPTS = [] # store command line options as list of strings 23 | 24 | # ----------------------------------------------------------------------------- 25 | # MODEL 26 | # ----------------------------------------------------------------------------- 27 | _C.MODEL = CN() 28 | 29 | # 
----------------------------------------------------------------------------- 30 | # DATASET 31 | # ----------------------------------------------------------------------------- 32 | _C.DATASET = CN() 33 | 34 | 35 | def get_config( 36 | config_paths: Optional[Union[List[str], str]] = None, 37 | opts: Optional[list] = None, 38 | ) -> CN: 39 | r"""Create a unified config with default values overwritten by values from 40 | :ref:`config_paths` and overwritten by options from :ref:`opts`. 41 | 42 | Args: 43 | config_paths: List of config paths or string that contains comma 44 | separated list of config paths. 45 | opts: Config options (keys, values) in a list (e.g., passed from 46 | command line into the config. For example, ``opts = ['FOO.BAR', 47 | 0.5]``. Argument can be used for parameter sweeping or quick tests. 48 | """ 49 | config = _C.clone() 50 | if config_paths: 51 | if isinstance(config_paths, str): 52 | if CONFIG_FILE_SEPARATOR in config_paths: 53 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) 54 | else: 55 | config_paths = [config_paths] 56 | 57 | for config_path in config_paths: 58 | config.merge_from_file(config_path) 59 | 60 | if opts: 61 | config.CMD_TRAILING_OPTS = config.CMD_TRAILING_OPTS + opts 62 | # FIXME: remove later 63 | for i in range(len(config.CMD_TRAILING_OPTS)): 64 | if config.CMD_TRAILING_OPTS[i] == "DATASET.taskvars": 65 | if type(config.CMD_TRAILING_OPTS[i + 1]) is str: 66 | config.CMD_TRAILING_OPTS[i + 1] = config.CMD_TRAILING_OPTS[i + 1].split(',') 67 | if config.CMD_TRAILING_OPTS[i] == 'DATASET.camera_ids': 68 | if type(config.CMD_TRAILING_OPTS[i + 1]) is str: 69 | config.CMD_TRAILING_OPTS[i + 1] = [int(v) for v in config.CMD_TRAILING_OPTS[i + 1].split(',')] 70 | 71 | config.merge_from_list(config.CMD_TRAILING_OPTS) 72 | 73 | config.freeze() 74 | return config 75 | -------------------------------------------------------------------------------- /polarnet/config/peract.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'exprs/pcd_unet/peract' 3 | checkpoint: null 4 | checkpoint: data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 5 | checkpoint_strict_load: false # true, false 6 | resume_training: true 7 | 8 | train_batch_size: 4 9 | gradient_accumulation_steps: 1 10 | num_epochs: null 11 | num_train_steps: 600000 # 100k for single-task, 200k for 10tasks 12 | warmup_steps: 2000 13 | log_steps: 1000 14 | save_steps: 5000 15 | 16 | optim: 'adamw' 17 | learning_rate: 5e-4 18 | lr_sched: 'linear' # inverse_sqrt, linear 19 | betas: [0.9, 0.98] 20 | weight_decay: 0.001 21 | grad_norm: 5 22 | n_workers: 0 23 | pin_mem: True 24 | 25 | DATASET: 26 | dataset_class: 'pre_pcd_keystep_stepwise' # pre_pcd_keystep_stepwise 27 | 28 | voxel_size: 0.01 # null, 0.01, 0.005 29 | npoints: 2048 30 | use_color: True 31 | use_normal: True 32 | use_height: True 33 | color_drop: 0.0 34 | only_success: False 35 | multi_instruction: True 36 | 37 | max_steps_per_episode: 12 38 | 39 | use_discrete_rot: False 40 | rot_resolution: 5 # degrees 41 | 42 | aug_shift_pcd: 0.0 43 | aug_rotate_pcd: 0.0 44 | 45 | add_pcd_noises: False 46 | pcd_noises_std: 0.01 47 | remove_pcd_outliers: False 48 | real_robot: False 49 | 50 | max_demos_per_taskvar: null 51 | exclude_overlength_episodes: null 52 | 53 | pc_space: 'workspace_on_table' # none, workspace, workspace_on_table 54 | pc_center: 'gripper' # point, gripper 55 | pc_radius_norm: True # true (unit ball), false 56 | 57 | data_dir: 
'data/train_dataset/keysteps_pcd/seed0' 58 | taskvars: ['assets/peract_tasks_var.csv'] 59 | instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 60 | use_instr_embed: 'all' # none, avg, last, all 61 | cameras: ("left_shoulder", "right_shoulder", "wrist", "front") 62 | camera_ids: [0, 1, 2, 3] 63 | gripper_channel: False 64 | is_training: True 65 | in_memory: True 66 | num_workers: 0 67 | 68 | MODEL: 69 | model_class: 'PointCloudUNet' 70 | 71 | use_max_action: False 72 | use_discrete_rot: False 73 | rot_resolution: 5 # degrees 74 | 75 | heatmap_loss: false 76 | heatmap_loss_weight: 1.0 77 | heatmap_distance_weight: 1.0 78 | use_heatmap_max: false 79 | use_pos_loss: true 80 | 81 | num_tasks: 1 82 | max_steps: 25 83 | dropout: 0.0 84 | learnable_step_embedding: false 85 | use_prev_action: true 86 | 87 | use_instr_embed: 'all' # none, avg, last, all 88 | instr_embed_size: 512 89 | txt_attn_type: 'cross' # none, cross 90 | num_trans_layers: 2 91 | trans_hidden_size: 512 92 | cat_global_in_head: False 93 | 94 | heatmap_temp: 0.1 95 | 96 | pcd_encoder_cfg: 97 | blocks: [1, 1, 1, 1, 1] 98 | strides: [1, 2, 2, 2, 2] 99 | width: 64 100 | in_channels: 10 101 | sa_layers: 3 102 | sa_use_res: True 103 | radius: 0.05 104 | radius_scaling: 2.5 105 | nsample: 32 106 | expansion: 4 107 | aggr_args: 108 | feature_type: dp_fj 109 | reduction: max 110 | group_args: 111 | NAME: ballquery 112 | normalize_dp: True 113 | conv_args: 114 | order: conv-norm-act 115 | act_args: 116 | act: relu 117 | inplace: True 118 | norm_args: 119 | norm: bn 120 | 121 | pcd_decoder_cfg: 122 | layers: 2 123 | 124 | -------------------------------------------------------------------------------- /polarnet/config/single_task.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'exprs/pcd_unet/single_task' 3 | checkpoint: null 4 | checkpoint: data/pretrained_models/pointnext-s-c64-enc-dec-sameshape.pt 5 | checkpoint_strict_load: false # true, false 6 | resume_training: true 7 | 8 | train_batch_size: 8 9 | gradient_accumulation_steps: 1 10 | num_epochs: null 11 | num_train_steps: 100000 # 100k for single-task, 200k for 10tasks, 1M for 74tasks, 600k for peract 12 | warmup_steps: 2000 13 | log_steps: 1000 14 | save_steps: 5000 15 | 16 | optim: 'adamw' 17 | learning_rate: 5e-4 18 | lr_sched: 'linear' # inverse_sqrt, linear 19 | betas: [0.9, 0.98] 20 | weight_decay: 0.001 21 | grad_norm: 5 22 | n_workers: 0 23 | pin_mem: True 24 | 25 | DATASET: 26 | dataset_class: 'pre_pcd_keystep_stepwise' # pre_pcd_keystep_stepwise 27 | 28 | voxel_size: 0.01 # null, 0.01, 0.005 29 | npoints: 2048 30 | use_color: True 31 | use_normal: True 32 | use_height: True 33 | color_drop: 0.0 34 | only_success: False 35 | multi_instruction: True 36 | 37 | max_steps_per_episode: 12 38 | 39 | use_discrete_rot: False 40 | rot_resolution: 5 # degrees 41 | 42 | aug_shift_pcd: 0.0 43 | aug_rotate_pcd: 0.0 44 | 45 | add_pcd_noises: False 46 | pcd_noises_std: 0.01 47 | remove_pcd_outliers: False 48 | real_robot: False 49 | 50 | max_demos_per_taskvar: null 51 | exclude_overlength_episodes: null 52 | 53 | pc_space: 'workspace_on_table' # none, workspace, workspace_on_table 54 | pc_center: 'gripper' # point, gripper 55 | pc_radius_norm: True # true (unit ball), false 56 | 57 | data_dir: 'data/train_dataset/keysteps_pcd/seed0' 58 | taskvars: ('pick_and_lift+0', ) 59 | instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 60 | use_instr_embed: 'all' # none, avg, last, all 61 | cameras: ("left_shoulder", 
"right_shoulder", "wrist") 62 | camera_ids: [0, 1, 2] 63 | gripper_channel: False 64 | is_training: True 65 | in_memory: True 66 | num_workers: 0 67 | 68 | MODEL: 69 | model_class: 'PointCloudUNet' 70 | 71 | use_max_action: False 72 | use_discrete_rot: False 73 | rot_resolution: 5 # degrees 74 | 75 | heatmap_loss: false 76 | heatmap_loss_weight: 1.0 77 | heatmap_distance_weight: 1.0 78 | use_heatmap_max: false 79 | use_pos_loss: true 80 | 81 | num_tasks: 1 82 | max_steps: 25 83 | dropout: 0.0 84 | learnable_step_embedding: false 85 | use_prev_action: true 86 | 87 | use_instr_embed: 'all' # none, avg, last, all 88 | instr_embed_size: 512 89 | txt_attn_type: 'cross' # none, cross 90 | num_trans_layers: 2 91 | trans_hidden_size: 512 92 | cat_global_in_head: False 93 | 94 | heatmap_temp: 0.1 95 | 96 | pcd_encoder_cfg: 97 | blocks: [1, 1, 1, 1, 1] 98 | strides: [1, 2, 2, 2, 2] 99 | width: 64 100 | in_channels: 10 101 | sa_layers: 3 102 | sa_use_res: True 103 | radius: 0.05 104 | radius_scaling: 2.5 105 | nsample: 32 106 | expansion: 4 107 | aggr_args: 108 | feature_type: dp_fj 109 | reduction: max 110 | group_args: 111 | NAME: ballquery 112 | normalize_dp: True 113 | conv_args: 114 | order: conv-norm-act 115 | act_args: 116 | act: relu 117 | inplace: True 118 | norm_args: 119 | norm: bn 120 | 121 | pcd_decoder_cfg: 122 | layers: 2 123 | 124 | -------------------------------------------------------------------------------- /polarnet/core/actioner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Sequence, Tuple, TypedDict, Union, Any 2 | 3 | class BaseActioner: 4 | 5 | def reset(self, task_str, variation, instructions, demo_id): 6 | self.task_str = task_str 7 | self.variation = variation 8 | self.instructions = instructions 9 | self.demo_id = demo_id 10 | 11 | self.step_id = 0 12 | self.state_dict = {} 13 | self.history_obs = {} 14 | 15 | def predict(self, *args, **kwargs): 16 | raise NotImplementedError('implete predict function') 17 | -------------------------------------------------------------------------------- /polarnet/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/polarnet/882b7ef5b82ee4c7779cdd0020f58e919e0f8bce/polarnet/dataloaders/__init__.py -------------------------------------------------------------------------------- /polarnet/dataloaders/keystep_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import os 4 | import numpy as np 5 | import einops 6 | import json 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | import torchvision.transforms as transforms 11 | import torchvision.transforms.functional as transforms_f 12 | 13 | import lmdb 14 | import msgpack 15 | import msgpack_numpy 16 | 17 | msgpack_numpy.patch() 18 | 19 | from polarnet.utils.ops import pad_tensors, gen_seq_masks 20 | 21 | 22 | class DataTransform(object): 23 | def __init__(self, scales): 24 | self.scales = scales 25 | 26 | def __call__(self, data) -> Dict[str, torch.Tensor]: 27 | """ 28 | Inputs: 29 | data: dict 30 | - rgb: (T, N, C, H, W), N: num of cameras 31 | - pc: (T, N, C, H, W) 32 | """ 33 | keys = list(data.keys()) 34 | 35 | # Continuous range of scales 36 | sc = np.random.uniform(*self.scales) 37 | 38 | t, n, c, raw_h, raw_w = data[keys[0]].shape 39 | data = {k: v.flatten(0, 1) for k, v in data.items()} # (t*n, h, 
w, c) 40 | resized_size = [int(raw_h * sc), int(raw_w * sc)] 41 | 42 | # Resize based on randomly sampled scale 43 | data = { 44 | k: transforms_f.resize( 45 | v, resized_size, transforms.InterpolationMode.BILINEAR 46 | ) 47 | for k, v in data.items() 48 | } 49 | 50 | # Adding padding if crop size is smaller than the resized size 51 | if raw_h > resized_size[0] or raw_w > resized_size[1]: 52 | right_pad = max(raw_w - resized_size[1], 0) 53 | bottom_pad = max(raw_h - resized_size[0], 0) 54 | data = { 55 | k: transforms_f.pad( 56 | v, 57 | padding=[0, 0, right_pad, bottom_pad], 58 | padding_mode="edge", 59 | ) 60 | for k, v in data.items() 61 | } 62 | 63 | # Random Cropping 64 | i, j, h, w = transforms.RandomCrop.get_params( 65 | data[keys[0]], output_size=(raw_h, raw_w) 66 | ) 67 | 68 | data = {k: transforms_f.crop(v, i, j, h, w) for k, v in data.items()} 69 | 70 | data = { 71 | k: einops.rearrange(v, "(t n) c h w -> t n c h w", t=t) 72 | for k, v in data.items() 73 | } 74 | 75 | return data 76 | 77 | 78 | class KeystepDataset(Dataset): 79 | def __init__( 80 | self, 81 | data_dir, 82 | taskvars, 83 | instr_embed_file=None, 84 | gripper_channel=False, 85 | camera_ids=None, 86 | cameras=("left_shoulder", "right_shoulder", "wrist"), 87 | use_instr_embed="none", 88 | is_training=False, 89 | in_memory=False, 90 | only_success=False, 91 | **kwargs, 92 | ): 93 | """ 94 | - use_instr_embed: 95 | 'none': use task_id; 96 | 'avg': use the average instruction embedding; 97 | 'last': use the last instruction embedding; 98 | 'all': use the embedding of all instruction tokens. 99 | """ 100 | self.data_dir = data_dir 101 | 102 | if len(taskvars) == 1 and os.path.exists(taskvars[0]): 103 | with open(taskvars[0]) as file: 104 | self.taskvars = [taskvar.rstrip() for taskvar in file.readlines()] 105 | self.taskvars.sort() 106 | else: 107 | self.taskvars = taskvars 108 | 109 | self.instr_embed_file = instr_embed_file 110 | self.taskvar_to_id = {x: i for i, x in enumerate(self.taskvars)} 111 | self.use_instr_embed = use_instr_embed 112 | self.gripper_channel = gripper_channel 113 | self.cameras = cameras 114 | if camera_ids is None: 115 | self.camera_ids = np.arange(len(self.cameras)) 116 | else: 117 | self.camera_ids = np.array(camera_ids) 118 | self.in_memory = in_memory 119 | self.is_training = is_training 120 | self.multi_instruction = kwargs.get("multi_instruction", True) 121 | self.max_demos_per_taskvar = kwargs.get("max_demos_per_taskvar", None) 122 | self.exclude_overlength_episodes = kwargs.get( 123 | "exclude_overlength_episodes", None 124 | ) 125 | 126 | self.memory = {} 127 | 128 | self._transform = DataTransform((0.75, 1.25)) 129 | 130 | self.lmdb_envs, self.lmdb_txns = [], [] 131 | self.episode_ids = [] 132 | for i, taskvar in enumerate(self.taskvars): 133 | demo_res_file = os.path.join(data_dir, taskvar, "results.json") 134 | if only_success and os.path.exists(demo_res_file): 135 | demo_results = json.load(open(demo_res_file, "r")) 136 | if not os.path.exists(os.path.join(data_dir, taskvar)): 137 | self.lmdb_envs.append(None) 138 | self.lmdb_txns.append(None) 139 | continue 140 | lmdb_env = lmdb.open( 141 | os.path.join(data_dir, taskvar), readonly=True, lock=False 142 | ) 143 | self.lmdb_envs.append(lmdb_env) 144 | lmdb_txn = lmdb_env.begin() 145 | self.lmdb_txns.append(lmdb_txn) 146 | keys = [ 147 | key.decode("ascii") 148 | for key in list(lmdb_txn.cursor().iternext(values=False)) 149 | ] 150 | self.episode_ids.extend( 151 | [ 152 | (i, key.encode("ascii")) 153 | for key in keys 154 | if 
key.startswith("episode") 155 | and ((not only_success) or demo_results[key]) 156 | ][: self.max_demos_per_taskvar] 157 | ) 158 | if self.in_memory: 159 | self.memory[f"taskvar{i}"] = {} 160 | 161 | if self.use_instr_embed != "none": 162 | assert self.instr_embed_file is not None 163 | self.lmdb_instr_env = lmdb.open( 164 | self.instr_embed_file, readonly=True, lock=False 165 | ) 166 | self.lmdb_instr_txn = self.lmdb_instr_env.begin() 167 | if True: # self.in_memory: 168 | self.memory["instr_embeds"] = {} 169 | else: 170 | self.lmdb_instr_env = None 171 | 172 | def __exit__(self): 173 | for lmdb_env in self.lmdb_envs: 174 | if lmdb_env is not None: 175 | lmdb_env.close() 176 | if self.lmdb_instr_env is not None: 177 | self.lmdb_instr_env.close() 178 | 179 | def __len__(self): 180 | return len(self.episode_ids) 181 | 182 | def get_taskvar_episode(self, taskvar_idx, episode_key): 183 | if self.in_memory: 184 | mem_key = f"taskvar{taskvar_idx}" 185 | if episode_key in self.memory[mem_key]: 186 | return self.memory[mem_key][episode_key] 187 | 188 | value = self.lmdb_txns[taskvar_idx].get(episode_key) 189 | value = msgpack.unpackb(value) 190 | # rgb, pc: (num_steps, num_cameras, height, width, 3) 191 | value["rgb"] = value["rgb"][:, self.camera_ids] 192 | value["pc"] = value["pc"][:, self.camera_ids] 193 | if self.in_memory: 194 | self.memory[mem_key][episode_key] = value 195 | return value 196 | 197 | def get_taskvar_instr_embeds(self, taskvar): 198 | instr_embeds = None 199 | if True: # self.in_memory: 200 | if taskvar in self.memory["instr_embeds"]: 201 | instr_embeds = self.memory["instr_embeds"][taskvar] 202 | 203 | if instr_embeds is None: 204 | instr_embeds = self.lmdb_instr_txn.get(taskvar.encode("ascii")) 205 | instr_embeds = msgpack.unpackb(instr_embeds) 206 | instr_embeds = [torch.from_numpy(x).float() for x in instr_embeds] 207 | if self.in_memory: 208 | self.memory["instr_embeds"][taskvar] = instr_embeds 209 | 210 | # randomly select one instruction for the taskvar 211 | if self.multi_instruction: 212 | ridx = np.random.randint(len(instr_embeds)) 213 | else: 214 | ridx = 0 215 | instr_embeds = instr_embeds[ridx] 216 | 217 | if self.use_instr_embed == "avg": 218 | instr_embeds = torch.mean(instr_embeds, 0, keepdim=True) 219 | elif self.use_instr_embed == "last": 220 | instr_embeds = instr_embeds[-1:] 221 | 222 | return instr_embeds # (num_ttokens, dim) 223 | 224 | def __getitem__(self, idx): 225 | taskvar_idx, episode_key = self.episode_ids[idx] 226 | 227 | value = self.get_taskvar_episode(taskvar_idx, episode_key) 228 | 229 | # The last one is the stop observation 230 | rgbs = ( 231 | torch.from_numpy(value["rgb"][:-1]).float().permute(0, 1, 4, 2, 3) 232 | ) # (T, N, C, H, W) 233 | pcs = torch.from_numpy(value["pc"][:-1]).float().permute(0, 1, 4, 2, 3) 234 | # normalise to [-1, 1] 235 | rgbs = 2 * (rgbs / 255.0 - 0.5) 236 | 237 | num_steps, num_cameras, _, im_height, im_width = rgbs.size() 238 | 239 | if self.gripper_channel: 240 | gripper_imgs = torch.zeros( 241 | num_steps, num_cameras, 1, im_height, im_width, dtype=torch.float32 242 | ) 243 | for t in range(num_steps): 244 | for c, cam in enumerate(self.cameras): 245 | u, v = value["gripper_pose"][t][cam] 246 | if u >= 0 and u < 128 and v >= 0 and v < 128: 247 | gripper_imgs[t, c, 0, v, u] = 1 248 | rgbs = torch.cat([rgbs, gripper_imgs], dim=2) 249 | 250 | # rgb, pcd: (T, N, C, H, W) 251 | outs = {"rgbs": rgbs, "pcds": pcs} 252 | if self.is_training: 253 | outs = self._transform(outs) 254 | 255 | outs["step_ids"] = 
torch.arange(0, num_steps).long() 256 | outs["actions"] = torch.from_numpy(value["action"][1:]) 257 | outs["episode_ids"] = episode_key.decode("ascii") 258 | outs["taskvars"] = self.taskvars[taskvar_idx] 259 | outs["taskvar_ids"] = taskvar_idx 260 | 261 | if self.exclude_overlength_episodes is not None: 262 | for key in ["rgbs", "pcds", "step_ids", "actions"]: 263 | outs[key] = outs[key][: self.exclude_overlength_episodes] 264 | 265 | if self.use_instr_embed != "none": 266 | outs["instr_embeds"] = self.get_taskvar_instr_embeds(outs["taskvars"]) 267 | 268 | return outs 269 | 270 | 271 | def stepwise_collate_fn(data: List[Dict]): 272 | batch = {} 273 | 274 | for key in data[0].keys(): 275 | if key == "taskvar_ids": 276 | batch[key] = [ 277 | torch.LongTensor([v["taskvar_ids"]] * len(v["step_ids"])) for v in data 278 | ] 279 | elif key == "instr_embeds": 280 | batch[key] = sum( 281 | [[v["instr_embeds"]] * len(v["step_ids"]) for v in data], [] 282 | ) 283 | else: 284 | batch[key] = [v[key] for v in data] 285 | 286 | for key in ["rgbs", "pcds", "taskvar_ids", "step_ids", "actions"]: 287 | # e.g. rgbs: (B*T, N, C, H, W) 288 | batch[key] = torch.cat(batch[key], dim=0) 289 | 290 | if "instr_embeds" in batch: 291 | batch["instr_embeds"] = pad_tensors(batch["instr_embeds"]) 292 | 293 | return batch 294 | 295 | 296 | def episode_collate_fn(data: List[Dict]): 297 | batch = {} 298 | 299 | for key in data[0].keys(): 300 | batch[key] = [v[key] for v in data] 301 | 302 | batch["taskvar_ids"] = torch.LongTensor(batch["taskvar_ids"]) 303 | num_steps = [len(x["rgbs"]) for x in data] 304 | if "instr_embeds" in batch: 305 | num_ttokens = [len(x["instr_embeds"]) for x in data] 306 | 307 | for key in ["rgbs", "pcds", "step_ids", "actions"]: 308 | # e.g. rgbs: (B, T, N, C, H, W) 309 | batch[key] = pad_tensors(batch[key], lens=num_steps) 310 | 311 | if "instr_embeds" in batch: 312 | batch["instr_embeds"] = pad_tensors(batch["instr_embeds"], lens=num_ttokens) 313 | batch["txt_masks"] = torch.from_numpy(gen_seq_masks(num_ttokens)) 314 | else: 315 | batch["txt_masks"] = torch.ones(len(num_steps), 1).bool() 316 | 317 | batch["step_masks"] = torch.from_numpy(gen_seq_masks(num_steps)) 318 | 319 | return batch 320 | 321 | 322 | if __name__ == "__main__": 323 | import time 324 | from torch.utils.data import DataLoader 325 | 326 | data_dir = "data/train_dataset/keysteps/seed0" 327 | taskvars = ["pick_up_cup+0"] 328 | cameras = ["left_shoulder", "right_shoulder", "wrist"] 329 | instr_embed_file = None 330 | instr_embed_file = "data/train_dataset/taskvar_instrs/clip" 331 | 332 | dataset = KeystepDataset( 333 | data_dir, 334 | taskvars, 335 | instr_embed_file=instr_embed_file, 336 | use_instr_embed="all", 337 | gripper_channel="attn", 338 | cameras=cameras, 339 | is_training=True, 340 | ) 341 | 342 | data_loader = DataLoader( 343 | dataset, 344 | batch_size=16, 345 | # collate_fn=stepwise_collate_fn 346 | collate_fn=episode_collate_fn, 347 | ) 348 | 349 | print(len(dataset), len(data_loader)) 350 | 351 | st = time.time() 352 | for batch in data_loader: 353 | for k, v in batch.items(): 354 | if isinstance(v, torch.Tensor): 355 | print(k, v.size()) 356 | break 357 | et = time.time() 358 | print("cost time: %.2fs" % (et - st)) 359 | -------------------------------------------------------------------------------- /polarnet/dataloaders/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | A prefetch loader to speedup data loading 6 | Modified from Nvidia Deep Learning Examples 7 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 8 | """ 9 | import random 10 | from typing import List, Dict, Tuple, Union, Iterator 11 | 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | from torch.utils.data.distributed import DistributedSampler 15 | import torch.distributed as dist 16 | 17 | 18 | class MetaLoader: 19 | """wraps multiple data loaders""" 20 | 21 | def __init__( 22 | self, loaders, accum_steps: int = 1, distributed: bool = False, device=None 23 | ): 24 | assert isinstance(loaders, dict) 25 | self.name2loader = {} 26 | self.name2iter = {} 27 | self.name2pre_epoch = {} 28 | self.names: List[str] = [] 29 | ratios: List[int] = [] 30 | for n, l in loaders.items(): 31 | if isinstance(l, tuple): 32 | l, r, p = l 33 | elif isinstance(l, DataLoader): 34 | r = 1 35 | p = lambda e: None 36 | else: 37 | raise ValueError() 38 | self.names.append(n) 39 | self.name2loader[n] = l 40 | self.name2iter[n] = iter(l) 41 | self.name2pre_epoch[n] = p 42 | ratios.append(r) 43 | 44 | self.accum_steps = accum_steps 45 | self.device = device 46 | self.sampling_ratios = torch.tensor(ratios).float().to(self.device) 47 | self.distributed = distributed 48 | self.step = 0 49 | 50 | def __iter__(self) -> Iterator[Tuple]: 51 | """this iterator will run indefinitely""" 52 | task_id = None 53 | epoch_id = 0 54 | while True: 55 | if self.step % self.accum_steps == 0: 56 | task_id = torch.multinomial(self.sampling_ratios, 1) 57 | if self.distributed: 58 | # make sure all process is training same task 59 | dist.broadcast(task_id, 0) 60 | self.step += 1 61 | task = self.names[task_id.cpu().item()] 62 | iter_ = self.name2iter[task] 63 | try: 64 | batch = next(iter_) 65 | except StopIteration: 66 | epoch_id += 1 67 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch 68 | # before creating the DataLoader iterator is necessary to make shuffling work properly 69 | # across multiple epochs. Otherwise, the same ordering will be always used. 
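                # (For plain DataLoaders the registered pre_epoch hook is a no-op;
                # build_dataloader below wires it to DistributedSampler.set_epoch,
                # so every re-created iterator reshuffles with a fresh epoch seed.)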
70 | self.name2pre_epoch[task](epoch_id) 71 | iter_ = iter(self.name2loader[task]) 72 | batch = next(iter_) 73 | self.name2iter[task] = iter_ 74 | 75 | yield task, batch 76 | 77 | 78 | def move_to_cuda(batch: Union[List, Tuple, Dict, torch.Tensor], device: torch.device): 79 | if isinstance(batch, torch.Tensor): 80 | return batch.to(device, non_blocking=True) 81 | elif isinstance(batch, list): 82 | return [move_to_cuda(t, device) for t in batch] 83 | elif isinstance(batch, tuple): 84 | return tuple(move_to_cuda(t, device) for t in batch) 85 | elif isinstance(batch, dict): 86 | return {n: move_to_cuda(t, device) for n, t in batch.items()} 87 | return batch 88 | 89 | 90 | class PrefetchLoader(object): 91 | """ 92 | overlap compute and cuda data transfer 93 | """ 94 | def __init__(self, loader, device: torch.device): 95 | self.loader = loader 96 | self.device = device 97 | 98 | def __iter__(self): 99 | loader_it = iter(self.loader) 100 | self.preload(loader_it) 101 | batch = self.next(loader_it) 102 | while batch is not None: 103 | yield batch 104 | batch = self.next(loader_it) 105 | 106 | def __len__(self): 107 | return len(self.loader) 108 | 109 | def preload(self, it): 110 | try: 111 | self.batch = next(it) 112 | except StopIteration: 113 | self.batch = None 114 | return 115 | self.batch = move_to_cuda(self.batch, self.device) 116 | 117 | def next(self, it): 118 | batch = self.batch 119 | self.preload(it) 120 | return batch 121 | 122 | def __getattr__(self, name): 123 | method = self.loader.__getattribute__(name) 124 | return method 125 | 126 | 127 | def build_dataloader(dataset, collate_fn, is_train: bool, opts): 128 | 129 | batch_size = opts.train_batch_size if is_train else opts.val_batch_size 130 | 131 | if opts.local_rank == -1: 132 | if is_train: 133 | sampler: Union[ 134 | RandomSampler, SequentialSampler, DistributedSampler 135 | ] = RandomSampler(dataset) 136 | else: 137 | sampler = SequentialSampler(dataset) 138 | 139 | size = torch.cuda.device_count() if torch.cuda.is_available() else 1 140 | pre_epoch = lambda e: None 141 | 142 | # DataParallel: scale the batch size by the number of GPUs 143 | if size > 1: 144 | batch_size *= size 145 | 146 | else: 147 | size = dist.get_world_size() 148 | sampler = DistributedSampler( 149 | dataset, num_replicas=size, rank=dist.get_rank(), 150 | shuffle=is_train 151 | ) 152 | pre_epoch = sampler.set_epoch 153 | 154 | loader = DataLoader( 155 | dataset, 156 | sampler=sampler, 157 | batch_size=batch_size, 158 | num_workers=opts.n_workers, 159 | pin_memory=opts.pin_mem, 160 | collate_fn=collate_fn, 161 | drop_last=False, 162 | prefetch_factor=2, 163 | ) 164 | 165 | return loader, pre_epoch 166 | -------------------------------------------------------------------------------- /polarnet/eval_tst_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import jsonlines 4 | import argparse 5 | import multiprocessing as mp 6 | 7 | from polarnet.config.default import get_config 8 | 9 | def work_fn(cmd): 10 | os.system(cmd) 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--exp_config', required=True) 16 | parser.add_argument('--seed', type=int, default=200) 17 | parser.add_argument('--num_demos', type=int, default=25) 18 | parser.add_argument('--microstep_data_dir', type=str, default=None) 19 | parser.add_argument('--microstep_outname', type=str, default='microsteps') 20 | parser.add_argument('--checkpoint', required=True) 21 | 
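    # Illustrative invocation (config/checkpoint paths are examples, not
    # shipped defaults):
    #   python polarnet/eval_tst_split.py --exp_config polarnet/config/10tasks.yaml \
    #     --checkpoint exprs/pcd_unet/10tasks/ckpts/model_step_200000.pt --num_workers 4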
parser.add_argument('--num_workers', type=int, default=1) 22 | parser.add_argument('--record_video', action='store_true', default=False) 23 | parser.add_argument('--taskvars', type=str, default=None) 24 | parser.add_argument('--cam_rand_factor', type=float, default=0.0) 25 | parser.add_argument('--instr_embed_file', type=str, default=None) 26 | parser.add_argument( 27 | "opts", 28 | default=None, 29 | nargs=argparse.REMAINDER, 30 | help="Modify config options from command line", 31 | ) 32 | args = parser.parse_args() 33 | 34 | config = get_config(args.exp_config) 35 | if len(config.DATASET.taskvars) == 1 and os.path.exists(config.DATASET.taskvars[0]): 36 | taskvars = list(json.load(open(config.DATASET.taskvars[0], 'r')).keys()) 37 | taskvars.sort() 38 | else: 39 | taskvars = config.DATASET.taskvars 40 | 41 | if args.taskvars is not None: 42 | taskvars = args.taskvars.split(',') 43 | exist_taskvars = set() 44 | pred_file = os.path.join(config.output_dir, 'preds', f'seed{args.seed}', 'results.jsonl') 45 | if os.path.exists(pred_file): 46 | with jsonlines.open(pred_file, 'r') as f: 47 | for x in f: 48 | if x['checkpoint'] == args.checkpoint: 49 | exist_taskvars.add('%s+%d'%(x['task'], x['variation'])) 50 | 51 | cmds = [] 52 | for taskvar in taskvars: 53 | if taskvar not in exist_taskvars: 54 | cmd = f'python eval_models.py --exp_config {args.exp_config} --headless --seed {args.seed} --num_demos {args.num_demos} --checkpoint {args.checkpoint} --num_workers 1 --taskvars {taskvar} --cam_rand_factor {args.cam_rand_factor} --instr_embed_file {args.instr_embed_file}' 55 | cmd = '%s %s' % (cmd, ' '.join(args.opts)) 56 | if args.microstep_data_dir is not None: 57 | cmd = '%s --microstep_data_dir %s --microstep_outname %s' % (cmd, args.microstep_data_dir, args.microstep_outname) 58 | if args.record_video: 59 | cmd = '%s --record_video --not_include_robot_cameras' % cmd 60 | cmds.append(cmd) 61 | print('num_jobs', len(cmds)) 62 | 63 | if args.num_workers == 1: 64 | for cmd in cmds: 65 | work_fn(cmd) 66 | else: 67 | pool = mp.Pool(processes=args.num_workers) 68 | pool.map(work_fn, cmds) 69 | pool.close() 70 | pool.join() 71 | 72 | if __name__ == '__main__': 73 | main() -------------------------------------------------------------------------------- /polarnet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/polarnet/882b7ef5b82ee4c7779cdd0020f58e919e0f8bce/polarnet/models/__init__.py -------------------------------------------------------------------------------- /polarnet/models/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class BaseModel(nn.Module): 7 | @property 8 | def num_parameters(self): 9 | nweights, nparams = 0, 0 10 | for k, v in self.named_parameters(): 11 | nweights += np.prod(v.size()) 12 | nparams += 1 13 | return nweights, nparams 14 | 15 | @property 16 | def num_trainable_parameters(self): 17 | nweights, nparams = 0, 0 18 | for k, v in self.named_parameters(): 19 | if v.requires_grad: 20 | nweights += np.prod(v.size()) 21 | nparams += 1 22 | return nweights, nparams 23 | 24 | def prepare_batch(self, batch): 25 | device = next(self.parameters()).device 26 | for k, v in batch.items(): 27 | if isinstance(v, torch.Tensor): 28 | batch[k] = v.to(device) 29 | return batch -------------------------------------------------------------------------------- 
/polarnet/models/network_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Literal, Union, List, Dict 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import einops 10 | 11 | 12 | def dense_layer(in_channels, out_channels, apply_activation=True): 13 | layer: List[nn.Module] = [nn.Linear(in_channels, out_channels)] 14 | if apply_activation: 15 | layer += [nn.LeakyReLU(0.02)] 16 | return layer 17 | 18 | 19 | def normalise_quat(x): 20 | return x / x.square().sum(dim=-1).sqrt().unsqueeze(-1) 21 | 22 | 23 | class ActionLoss(object): 24 | def __init__(self, use_discrete_rot: bool = False, rot_resolution: int = 5): 25 | self.use_discrete_rot = use_discrete_rot 26 | if self.use_discrete_rot: 27 | self.rot_resolution = rot_resolution 28 | self.rot_classes = 360 // rot_resolution 29 | 30 | def decompose_actions(self, actions, onehot_rot=False): 31 | pos = actions[..., :3] 32 | if not self.use_discrete_rot: 33 | rot = actions[..., 3:7] 34 | open = actions[..., 7] 35 | else: 36 | if onehot_rot: 37 | rot = actions[..., 3: 6].long() 38 | else: 39 | rot = [ 40 | actions[..., 3: 3 + self.rot_classes], 41 | actions[..., 3 + self.rot_classes: 3 + 2*self.rot_classes], 42 | actions[..., 3 + 2*self.rot_classes: 3 + 3*self.rot_classes], 43 | ] 44 | open = actions[..., -1] 45 | return pos, rot, open 46 | 47 | def compute_loss( 48 | self, preds, targets, masks=None, 49 | heatmap_loss=False, distance_weight=1, heatmap_loss_weight=1, 50 | pred_heatmap_logits=None, pred_offset=None, pcd_xyzs=None, 51 | use_heatmap_max=False, use_pos_loss=True 52 | ) -> Dict[str, torch.Tensor]: 53 | pred_pos, pred_rot, pred_open = self.decompose_actions(preds) 54 | tgt_pos, tgt_rot, tgt_open = self.decompose_actions(targets, onehot_rot=True) 55 | 56 | losses = {} 57 | losses['pos'] = F.mse_loss(pred_pos, tgt_pos) 58 | 59 | if self.use_discrete_rot: 60 | losses['rot'] = (F.cross_entropy(pred_rot[0], tgt_rot[:, 0]) + \ 61 | F.cross_entropy(pred_rot[1], tgt_rot[:, 1]) + \ 62 | F.cross_entropy(pred_rot[2], tgt_rot[:, 2])) / 3 63 | else: 64 | # Automatically matching the closest quaternions (symmetrical solution). 
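            # A quaternion q and its negation -q encode the same rotation, so the
            # loss scores the prediction against both signs and keeps the smaller
            # MSE; e.g. pred (0, 0, 0, -1) vs target (0, 0, 0, 1) costs zero via
            # the negated copy.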
65 | tgt_rot_ = -tgt_rot.clone() 66 | rot_loss = F.mse_loss(pred_rot, tgt_rot, reduction='none').mean(-1) 67 | rot_loss_ = F.mse_loss(pred_rot, tgt_rot_, reduction='none').mean(-1) 68 | select_mask = (rot_loss < rot_loss_).float() 69 | losses['rot'] = (select_mask * rot_loss + (1 - select_mask) * rot_loss_).mean() 70 | 71 | losses['open'] = F.binary_cross_entropy_with_logits(pred_open, tgt_open) 72 | 73 | if use_pos_loss: 74 | losses['total'] = losses['pos'] + losses['rot'] + losses['open'] 75 | else: 76 | losses['total'] = losses['rot'] + losses['open'] 77 | 78 | if heatmap_loss: 79 | # (batch, npoints, 3) 80 | tgt_offset = targets[:, :3].unsqueeze(1) - pcd_xyzs 81 | dists = torch.norm(tgt_offset, dim=-1) 82 | if use_heatmap_max: 83 | tgt_heatmap_index = torch.min(dists, 1)[1] # (b, ) 84 | 85 | losses['xt_heatmap'] = F.cross_entropy( 86 | pred_heatmap_logits, tgt_heatmap_index 87 | ) 88 | losses['total'] += losses['xt_heatmap'] * heatmap_loss_weight 89 | 90 | losses['xt_offset'] = F.mse_loss( 91 | pred_offset.gather( 92 | 2, einops.repeat(tgt_heatmap_index, 'b -> b 3').unsqueeze(2) 93 | ), 94 | tgt_offset.gather( 95 | 1, einops.repeat(tgt_heatmap_index, 'b -> b 3').unsqueeze(1) 96 | ) 97 | ) 98 | losses['total'] += losses['xt_offset'] 99 | 100 | else: 101 | inv_dists = 1 / (1e-12 + dists)**distance_weight 102 | 103 | tgt_heatmap = torch.softmax(inv_dists, dim=1) 104 | tgt_log_heatmap = torch.log_softmax(inv_dists, dim=1) 105 | losses['tgt_heatmap_max'] = torch.mean(tgt_heatmap.max(1)[0]) 106 | 107 | losses['xt_heatmap'] = F.kl_div( 108 | torch.log_softmax(pred_heatmap_logits, dim=-1), tgt_log_heatmap, 109 | reduction='batchmean', log_target=True 110 | ) 111 | losses['total'] += losses['xt_heatmap'] * heatmap_loss_weight 112 | 113 | losses['xt_offset'] = torch.sum(F.mse_loss( 114 | pred_offset.permute(0, 2, 1), tgt_offset, 115 | reduction='none' 116 | ) * tgt_heatmap.unsqueeze(2)) / tgt_offset.size(0) / 3 117 | 118 | losses['total'] += losses['xt_offset'] 119 | 120 | return losses 121 | 122 | 123 | class PositionalEncoding(nn.Module): 124 | ''' 125 | Transformer-style positional encoding with wavelets 126 | ''' 127 | 128 | def __init__(self, dim_embed, max_len=500): 129 | super().__init__() 130 | 131 | pe = torch.zeros(max_len, dim_embed) 132 | position = torch.arange(0, max_len).unsqueeze(1) 133 | div_term = torch.exp((torch.arange(0, dim_embed, 2, dtype=torch.float) * 134 | -(math.log(10000.0) / dim_embed))) 135 | pe[:, 0::2] = torch.sin(position.float() * div_term) 136 | pe[:, 1::2] = torch.cos(position.float() * div_term) 137 | 138 | self.pe = pe # size=(max_len, dim_embed) 139 | self.dim_embed = dim_embed 140 | 141 | def forward(self, step_ids): 142 | if step_ids.device != self.pe.device: 143 | self.pe = self.pe.to(step_ids.device) 144 | return self.pe[step_ids] 145 | -------------------------------------------------------------------------------- /polarnet/models/pcd_unet.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import copy 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import einops 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from openpoints.models.backbone.pointnext import ( 11 | PointNextEncoder, FeaturePropogation 12 | ) 13 | 14 | from polarnet.models.network_utils import ( 15 | dense_layer, normalise_quat, ActionLoss, PositionalEncoding 16 | ) 17 | from polarnet.models.base import BaseModel 18 | 19 | 20 | class PointNextDecoder(nn.Module): 21 | 
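    """PointNet++-style decoder: FeaturePropogation modules upsample the
    coarsest features stage by stage, fusing each stage with the matching
    encoder skip connection to produce per-point features at the input
    resolution."""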
def __init__(self, 22 | encoder_channel_list: List[int], 23 | decoder_layers: int = 2, 24 | ): 25 | super().__init__() 26 | self.decoder_layers = decoder_layers 27 | self.in_channels = encoder_channel_list[-1] 28 | skip_channels = encoder_channel_list[:-1] 29 | fp_channels = encoder_channel_list[:-1] # feature propogation 30 | 31 | n_decoder_stages = len(fp_channels) 32 | decoder = [[] for _ in range(n_decoder_stages)] 33 | for i in range(-1, -n_decoder_stages - 1, -1): 34 | decoder[i] = self._make_dec( 35 | skip_channels[i], fp_channels[i] 36 | ) 37 | self.decoder = nn.Sequential(*decoder) 38 | self.out_channels = fp_channels[-n_decoder_stages] 39 | 40 | def _make_dec(self, skip_channels, fp_channels): 41 | layers = [] 42 | mlp = [skip_channels + self.in_channels] + \ 43 | [fp_channels] * self.decoder_layers 44 | layers.append(FeaturePropogation(mlp)) 45 | self.in_channels = fp_channels 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, p, f, txt_tokens=None, txt_padding_masks=None, return_all_layers=False): 49 | if return_all_layers: 50 | out_per_layer = [] 51 | 52 | for i in range(-1, -len(self.decoder) - 1, -1): 53 | x = self.decoder[i][0]([p[i - 1], f[i - 1]], [p[i], f[i]]) 54 | f[i - 1] = self.decoder[i][1:]([p[i], x])[1] 55 | if return_all_layers: 56 | out_per_layer.append(f[i - 1]) 57 | 58 | if return_all_layers: 59 | return out_per_layer 60 | 61 | out = f[-len(self.decoder) - 1] 62 | return out 63 | 64 | 65 | class ActionHead(nn.Module): 66 | def __init__( 67 | self, dec_channels, heatmap_temp=1, dropout=0, use_max_action=False, 68 | use_discrete_rot:bool=False, rot_resolution:int=5, 69 | ) -> None: 70 | super().__init__() 71 | self.use_discrete_rot = use_discrete_rot 72 | self.rot_resolution = rot_resolution 73 | 74 | if self.use_discrete_rot: 75 | self.rot_decoder = nn.Sequential( 76 | *dense_layer(dec_channels[0], dec_channels[0] // 2), 77 | nn.Dropout(dropout), 78 | *dense_layer(dec_channels[0] // 2, (360 // rot_resolution) * 3 + 1, apply_activation=False), 79 | ) 80 | else: 81 | self.quat_decoder = nn.Sequential( 82 | *dense_layer(dec_channels[0], dec_channels[0] // 2), 83 | nn.Dropout(dropout), 84 | *dense_layer(dec_channels[0] // 2, 4 + 1, apply_activation=False), 85 | ) 86 | 87 | self.maps_to_coord = nn.Sequential( 88 | nn.Dropout(dropout), 89 | nn.Conv1d(dec_channels[-1], 1 + 3, 1) 90 | ) 91 | self.heatmap_temp = heatmap_temp 92 | self.use_max_action = use_max_action 93 | 94 | def forward(self, dec_fts, pcds, pc_centers, pc_radii): 95 | ''' 96 | - dec_fts: [(batch, dec_channels[0], npoints), (batch, dec_channels[-1], npoints)] 97 | - pcds: (batch, 3, npoints) 98 | ''' 99 | # predict the translation of the gripper 100 | xt_fts = self.maps_to_coord(dec_fts[-1]) 101 | xt_heatmap = torch.softmax(xt_fts[:, :1] / self.heatmap_temp, dim=-1) 102 | xt_offset = xt_fts[:, 1:] 103 | if self.use_max_action: 104 | xt = pcds + xt_offset # (b, 3, npoints) 105 | xt = xt.gather( 106 | 2, einops.repeat(torch.max(xt_heatmap, dim=2)[1], 'b 1 -> b 3').unsqueeze(2) 107 | ).squeeze(2) 108 | else: 109 | xt = einops.reduce((pcds + xt_offset) * xt_heatmap, 'b c n -> b c', 'sum') 110 | xt = xt * pc_radii + pc_centers 111 | 112 | # predict the (rotation, openness) of the gripper 113 | xg_fts, _ = torch.max(dec_fts[0], -1) 114 | 115 | if self.use_discrete_rot: 116 | xg = self.rot_decoder(xg_fts) 117 | xr = xg[..., :-1] 118 | else: 119 | xg = self.quat_decoder(xg_fts) 120 | xr = normalise_quat(xg[..., :4]) 121 | 122 | xo = xg[..., -1:] 123 | 124 | actions = torch.cat([xt, xr, xo], dim=-1) 
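        # actions: (batch, 8) = world-frame position (3) + normalised quaternion (4)
        # + openness logit (1); in discrete-rot mode the quaternion slots become
        # 3 * (360 / rot_resolution) Euler-angle class logits instead.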
125 | 126 | return { 127 | 'actions': actions, 128 | 'xt_offset': xt_offset * pc_radii.unsqueeze(2), 129 | 'xt_heatmap': xt_heatmap.squeeze(1), 130 | 'xt_heatmap_logits': xt_fts[:, 0] / self.heatmap_temp, 131 | } 132 | 133 | 134 | class ActionEmbedding(nn.Module): 135 | def __init__(self, hidden_size) -> None: 136 | super().__init__() 137 | 138 | self.open_embedding = nn.Embedding(2, hidden_size) 139 | self.pos_embedding = nn.Linear(3, hidden_size) 140 | self.rot_embedding = nn.Linear(6, hidden_size) 141 | self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12) 142 | 143 | def forward(self, actions): 144 | ''' 145 | actions: (batch_size, 8) 146 | ''' 147 | pos_embeds = self.pos_embedding(actions[..., :3]) 148 | open_embeds = self.open_embedding(actions[..., -1].long()) 149 | 150 | rot_euler_angles = R.from_quat(actions[..., 3:7].data.cpu()).as_euler('xyz') 151 | rot_euler_angles = torch.from_numpy(rot_euler_angles).float().to(actions.device) 152 | rot_inputs = torch.cat( 153 | [torch.sin(rot_euler_angles), torch.cos(rot_euler_angles)], -1 154 | ) 155 | rot_embeds = self.rot_embedding(rot_inputs) 156 | 157 | act_embeds = self.layer_norm( 158 | pos_embeds + rot_embeds + open_embeds 159 | ) 160 | return act_embeds 161 | 162 | 163 | class PointCloudUNet(BaseModel): 164 | def __init__( 165 | self, pcd_encoder_cfg, pcd_decoder_cfg, 166 | num_tasks: int = None, max_steps: int = 20, 167 | use_instr_embed: str = 'none', instr_embed_size: int = None, 168 | txt_attn_type: str = 'none', num_trans_layers: int = 1, 169 | trans_hidden_size: int = 512, 170 | dropout=0.2, heatmap_temp=1, use_prev_action=False, 171 | cat_global_in_head=False, **kwargs 172 | ): 173 | super().__init__() 174 | 175 | self.pcd_encoder_cfg = pcd_encoder_cfg 176 | self.pcd_decoder_cfg = pcd_decoder_cfg 177 | self.num_tasks = num_tasks 178 | self.max_steps = max_steps 179 | self.use_instr_embed = use_instr_embed 180 | self.instr_embed_size = instr_embed_size 181 | self.txt_attn_type = txt_attn_type 182 | self.num_trans_layers = num_trans_layers 183 | self.use_prev_action = use_prev_action 184 | self.cat_global_in_head = cat_global_in_head 185 | self.heatmap_temp = heatmap_temp 186 | self.use_discrete_rot = kwargs.get('use_discrete_rot', False) 187 | self.rot_resolution = kwargs.get('rot_resolution', 5) 188 | self.kwargs = kwargs 189 | 190 | self.pcd_encoder = PointNextEncoder(**pcd_encoder_cfg) 191 | enc_channel_list = self.pcd_encoder.channel_list 192 | self.hidden_size = trans_hidden_size 193 | 194 | self.pcd_decoder = PointNextDecoder( 195 | enc_channel_list[:-1] + [enc_channel_list[-1] + self.hidden_size], pcd_decoder_cfg.layers, 196 | ) 197 | 198 | if self.kwargs.get('learnable_step_embedding', True): 199 | self.step_embedding = nn.Embedding(self.max_steps, self.hidden_size) 200 | else: 201 | self.step_embedding = PositionalEncoding(self.hidden_size, max_len=self.max_steps) 202 | 203 | if self.use_prev_action: 204 | self.prev_action_embedding = ActionEmbedding(self.hidden_size) 205 | 206 | if self.use_instr_embed == 'none': 207 | assert self.num_tasks is not None 208 | self.task_embedding = nn.Embedding(self.num_tasks, self.hidden_size) 209 | else: 210 | assert self.instr_embed_size is not None 211 | self.task_embedding = nn.Linear(self.instr_embed_size, self.hidden_size) 212 | 213 | self.point_pos_embedding = nn.Linear(3, self.hidden_size) 214 | 215 | if self.txt_attn_type == 'cross': 216 | if enc_channel_list[-1] != self.hidden_size: 217 | self.pcd_to_trans_fc = nn.Conv1d( 218 | in_channels=enc_channel_list[-1], 219 | 
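                    # 1x1 convolution projecting the per-point encoder features
                    # to the transformer hidden size used by the cross-attention below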
                    out_channels=self.hidden_size,
220 |                     kernel_size=1, stride=1
221 |                 )
222 |             else:
223 |                 self.pcd_to_trans_fc = None
224 |             trans_layer = nn.TransformerDecoderLayer(
225 |                 d_model=self.hidden_size,
226 |                 nhead=8,
227 |                 dim_feedforward=self.hidden_size*4,
228 |                 dropout=0.1, activation='gelu',
229 |                 layer_norm_eps=1e-12, norm_first=False,
230 |                 batch_first=True,
231 |             )
232 |             self.cross_attention = nn.TransformerDecoder(
233 |                 trans_layer, num_layers=self.num_trans_layers
234 |             )
235 | 
236 |         dec_ft_size = enc_channel_list[0]
237 |         if self.cat_global_in_head:
238 |             dec_ft_size += self.hidden_size
239 |         self.head = ActionHead(
240 |             [enc_channel_list[-1] + self.hidden_size, dec_ft_size],
241 |             heatmap_temp=heatmap_temp, dropout=dropout,
242 |             use_max_action=kwargs.get('use_max_action', False),
243 |             use_discrete_rot=self.use_discrete_rot,
244 |             rot_resolution=self.rot_resolution,
245 |         )
246 | 
247 |         self.loss_fn = ActionLoss(
248 |             use_discrete_rot=self.use_discrete_rot,
249 |             rot_resolution=self.rot_resolution
250 |         )
251 | 
252 |     def forward(self, batch, compute_loss=False):
253 |         batch = self.prepare_batch(batch)
254 | 
255 |         # encode point cloud
256 |         pcd_fts = batch['fts']  # (batch, dim, npoints)
257 |         pcd_poses = pcd_fts[:, :3]
258 | 
259 |         pos_list, ft_list = self.pcd_encoder(
260 |             pcd_poses.permute(0, 2, 1).contiguous(), pcd_fts
261 |         )
262 |         ctx_embeds = ft_list[-1]
263 |         if getattr(self, 'pcd_to_trans_fc', None) is not None:  # attribute only exists when txt_attn_type == 'cross'
264 |             ctx_embeds = self.pcd_to_trans_fc(ctx_embeds)
265 | 
266 |         step_ids = batch['step_ids']
267 |         step_embeds = self.step_embedding(step_ids)
268 |         ctx_embeds = ctx_embeds + step_embeds.unsqueeze(2)
269 |         if self.use_prev_action:
270 |             ctx_embeds = ctx_embeds + self.prev_action_embedding(batch['prev_actions']).unsqueeze(2)
271 |         ctx_embeds = ctx_embeds + self.point_pos_embedding(pos_list[-1]).permute(0, 2, 1)
272 | 
273 |         # conditioned on the task
274 |         taskvar_ids = batch['taskvar_ids']
275 |         instr_embeds = batch.get('instr_embeds', None)
276 |         txt_masks = batch.get('txt_masks', None)
277 | 
278 |         if self.use_instr_embed == 'none':
279 |             task_embeds = self.task_embedding(taskvar_ids).unsqueeze(1)  # (batch, 1, dim)
280 |         else:
281 |             task_embeds = self.task_embedding(instr_embeds)  # (batch, 1/len, dim)
282 | 
283 |         if self.txt_attn_type == 'none':
284 |             assert task_embeds.size(1) == 1
285 |             ctx_embeds = ctx_embeds + task_embeds.permute(0, 2, 1)
286 |         elif self.txt_attn_type == 'cross':
287 |             assert txt_masks is not None
288 |             ctx_embeds = self.cross_attention(
289 |                 ctx_embeds.permute(0, 2, 1), task_embeds,
290 |                 memory_key_padding_mask=txt_masks.logical_not(),
291 |             )
292 |             ctx_embeds = ctx_embeds.permute(0, 2, 1)
293 |         else:
294 |             raise NotImplementedError(f'unsupported txt_attn_type {self.txt_attn_type}')
295 | 
296 |         ft_list[-1] = torch.cat([ft_list[-1], ctx_embeds], dim=1)
297 | 
298 |         # decoding features
299 |         dec_fts = self.pcd_decoder(pos_list, ft_list)
300 | 
301 |         if self.cat_global_in_head:
302 |             global_ctx_embeds, _ = torch.max(ctx_embeds, 2)
303 |             global_ctx_embeds = einops.repeat(global_ctx_embeds, 'b c -> b c n', n=dec_fts.size(2))
304 |             dec_fts = torch.cat([dec_fts, global_ctx_embeds], dim=1)
305 |         outs = self.head(
306 |             (ft_list[-1], dec_fts), pcd_poses,
307 |             batch['pc_centers'], batch['pc_radii']
308 |         )
309 |         actions = outs['actions']
310 | 
311 |         if compute_loss:
312 |             heatmap_loss = self.kwargs.get('heatmap_loss', False)
313 |             heatmap_loss_weight = self.kwargs.get('heatmap_loss_weight', 1)
314 |             distance_weight = self.kwargs.get('heatmap_distance_weight', 1)
315 |             if heatmap_loss:
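                # un-normalize the points back to world coordinates (using the
                # centers/radii saved at preprocessing) so the heatmap/offset
                # targets are computed in metric units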
316 |                 pcd_xyzs = pcd_poses.permute(0, 2, 1) * batch['pc_radii'].unsqueeze(1) + batch['pc_centers'].unsqueeze(1)  # (b, npoints, 3)
317 |             else:
318 |                 pcd_xyzs = None
319 |             losses = self.loss_fn.compute_loss(
320 |                 actions, batch['actions'], heatmap_loss=heatmap_loss,
321 |                 pred_heatmap_logits=outs['xt_heatmap_logits'],
322 |                 pred_offset=outs['xt_offset'],
323 |                 pcd_xyzs=pcd_xyzs, distance_weight=distance_weight,
324 |                 heatmap_loss_weight=heatmap_loss_weight,
325 |                 use_heatmap_max=self.kwargs.get('use_heatmap_max', False),
326 |                 use_pos_loss=self.kwargs.get('use_pos_loss', True)
327 |             )
328 | 
329 |             return losses, actions
330 | 
331 |         return actions
332 | 
--------------------------------------------------------------------------------
/polarnet/optim/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 
5 | """
6 | from .sched import noam_schedule, warmup_linear, get_lr_sched, get_lr_sched_decay_rate
7 | from .adamw import AdamW
8 | 
--------------------------------------------------------------------------------
/polarnet/optim/adamw.py:
--------------------------------------------------------------------------------
1 | """
2 | AdamW optimizer (weight decay fix)
3 | copied from huggingface (https://github.com/huggingface/transformers).
4 | """
5 | 
6 | import math
7 | from typing import Callable, Iterable, Tuple
8 | 
9 | import torch
10 | 
11 | from torch.optim import Optimizer
12 | 
13 | class AdamW(Optimizer):
14 |     """
15 |     Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization
16 |     <https://arxiv.org/abs/1711.05101>`__.
17 | 
18 |     Parameters:
19 |         params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
20 |             Iterable of parameters to optimize or dictionaries defining parameter groups.
21 |         lr (:obj:`float`, `optional`, defaults to 1e-3):
22 |             The learning rate to use.
23 |         betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)):
24 |             Adam's betas parameters (b1, b2).
25 |         eps (:obj:`float`, `optional`, defaults to 1e-6):
26 |             Adam's epsilon for numerical stability.
27 |         weight_decay (:obj:`float`, `optional`, defaults to 0):
28 |             Decoupled weight decay to apply.
29 |         correct_bias (:obj:`bool`, `optional`, defaults to `True`):
30 |             Whether or not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`).
31 |     """
32 | 
33 |     def __init__(
34 |         self,
35 |         params: Iterable[torch.nn.parameter.Parameter],
36 |         lr: float = 1e-3,
37 |         betas: Tuple[float, float] = (0.9, 0.999),
38 |         eps: float = 1e-6,
39 |         weight_decay: float = 0.0,
40 |         correct_bias: bool = True,
41 |     ):
42 |         if lr < 0.0:
43 |             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
44 |         if not 0.0 <= betas[0] < 1.0:
45 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
46 |         if not 0.0 <= betas[1] < 1.0:
47 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
48 |         if not 0.0 <= eps:
49 |             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
50 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
51 |         super().__init__(params, defaults)
52 | 
53 |     def step(self, closure: Callable = None):
54 |         """
55 |         Performs a single optimization step.
56 | 
57 |         Arguments:
58 |             closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
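
        Returns:
            The loss returned by ``closure`` if one was provided, otherwise ``None``.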
59 | """ 60 | loss = None 61 | if closure is not None: 62 | loss = closure() 63 | 64 | for group in self.param_groups: 65 | for p in group["params"]: 66 | if p.grad is None: 67 | continue 68 | grad = p.grad.data 69 | if grad.is_sparse: 70 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 71 | 72 | state = self.state[p] 73 | 74 | # State initialization 75 | if len(state) == 0: 76 | state["step"] = 0 77 | # Exponential moving average of gradient values 78 | state["exp_avg"] = torch.zeros_like(p.data) 79 | # Exponential moving average of squared gradient values 80 | state["exp_avg_sq"] = torch.zeros_like(p.data) 81 | 82 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 83 | beta1, beta2 = group["betas"] 84 | 85 | state["step"] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | # In-place operations to update the averages at the same time 89 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 90 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 91 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 92 | 93 | step_size = group["lr"] 94 | if group["correct_bias"]: # No bias correction for Bert 95 | bias_correction1 = 1.0 - beta1 ** state["step"] 96 | bias_correction2 = 1.0 - beta2 ** state["step"] 97 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 98 | 99 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 100 | 101 | # Just adding the square of the weights to the loss function is *not* 102 | # the correct way of using L2 regularization/weight decay with Adam, 103 | # since that will interact with the m and v parameters in strange ways. 104 | # 105 | # Instead we want to decay the weights in a manner that doesn't interact 106 | # with the m/v parameters. This is equivalent to adding the square 107 | # of the weights to the loss with plain (non-momentum) SGD. 108 | # Add weight decay at the end (fixed version) 109 | if group["weight_decay"] > 0.0: 110 | p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /polarnet/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | # Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 2 | 3 | """ Lookahead Optimizer Wrapper. 
4 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 5 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from torch.optim import Adam 10 | from collections import defaultdict 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | # print(self.k) 47 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 48 | loss = self.base_optimizer.step(closure) 49 | for group in self.param_groups: 50 | group['lookahead_step'] += 1 51 | if group['lookahead_step'] % group['lookahead_k'] == 0: 52 | self.update_slow(group) 53 | return loss 54 | 55 | def state_dict(self): 56 | fast_state_dict = self.base_optimizer.state_dict() 57 | slow_state = { 58 | (id(k) if isinstance(k, torch.Tensor) else k): v 59 | for k, v in self.state.items() 60 | } 61 | fast_state = fast_state_dict['state'] 62 | param_groups = fast_state_dict['param_groups'] 63 | return { 64 | 'state': fast_state, 65 | 'slow_state': slow_state, 66 | 'param_groups': param_groups, 67 | } 68 | 69 | def load_state_dict(self, state_dict): 70 | fast_state_dict = { 71 | 'state': state_dict['state'], 72 | 'param_groups': state_dict['param_groups'], 73 | } 74 | self.base_optimizer.load_state_dict(fast_state_dict) 75 | 76 | # We want to restore the slow state, but share param_groups reference 77 | # with base_optimizer. 
This is a bit redundant but least code 78 | slow_state_new = False 79 | if 'slow_state' not in state_dict: 80 | print('Loading state_dict from optimizer without Lookahead applied.') 81 | state_dict['slow_state'] = defaultdict(dict) 82 | slow_state_new = True 83 | slow_state_dict = { 84 | 'state': state_dict['slow_state'], 85 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 86 | } 87 | super(Lookahead, self).load_state_dict(slow_state_dict) 88 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 89 | if slow_state_new: 90 | # reapply defaults to catch missing lookahead specific ones 91 | for name, default in self.defaults.items(): 92 | for group in self.param_groups: 93 | group.setdefault(name, default) 94 | 95 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 96 | adam = Adam(params, *args, **kwargs) 97 | return Lookahead(adam, alpha, k) 98 | -------------------------------------------------------------------------------- /polarnet/optim/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Misc lr helper 6 | """ 7 | from torch.optim import Adam, Adamax 8 | 9 | from .adamw import AdamW 10 | from .rangerlars import RangerLars 11 | 12 | 13 | def build_optimizer(model, opts): 14 | param_optimizer = list(model.named_parameters()) 15 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 16 | rgb_enc_params, other_params = {}, {} 17 | for n, p in param_optimizer: 18 | if not p.requires_grad: continue 19 | if 'rgb_encoder' in n: 20 | rgb_enc_params[n] = p 21 | else: 22 | other_params[n] = p 23 | 24 | optimizer_grouped_parameters = [] 25 | init_lrs = [] 26 | for ptype, pdict in [('rgb', rgb_enc_params), ('others', other_params)]: 27 | if len(pdict) == 0: 28 | continue 29 | init_lr = opts.learning_rate 30 | if ptype == 'rgb': 31 | init_lr = init_lr * getattr(opts, 'rgb_encoder_lr_multi', 1) 32 | optimizer_grouped_parameters.extend([ 33 | {'params': [p for n, p in pdict.items() 34 | if not any(nd in n for nd in no_decay)], 35 | 'weight_decay': opts.weight_decay, 'lr': init_lr}, 36 | {'params': [p for n, p in pdict.items() 37 | if any(nd in n for nd in no_decay)], 38 | 'weight_decay': 0.0, 'lr': init_lr} 39 | ]) 40 | init_lrs.extend([init_lr] * 2) 41 | 42 | # currently Adam only 43 | if opts.optim == 'adam': 44 | OptimCls = Adam 45 | elif opts.optim == 'adamax': 46 | OptimCls = Adamax 47 | elif opts.optim == 'adamw': 48 | OptimCls = AdamW 49 | elif opts.optim == 'rangerlars': 50 | OptimCls = RangerLars 51 | else: 52 | raise ValueError('invalid optimizer') 53 | optimizer = OptimCls(optimizer_grouped_parameters, 54 | lr=opts.learning_rate, betas=opts.betas) 55 | return optimizer, init_lrs 56 | -------------------------------------------------------------------------------- /polarnet/optim/radam.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam.py 2 | 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | class RAdam(Optimizer): 8 | 9 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 10 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 11 | self.buffer = [[None, None, None] for ind in range(10)] 12 | super(RAdam, self).__init__(params, defaults) 13 | 14 | def __setstate__(self, state): 15 | 
super(RAdam, self).__setstate__(state) 16 | 17 | def step(self, closure=None): 18 | 19 | loss = None 20 | if closure is not None: 21 | loss = closure() 22 | 23 | for group in self.param_groups: 24 | 25 | for p in group['params']: 26 | if p.grad is None: 27 | continue 28 | grad = p.grad.data.float() 29 | if grad.is_sparse: 30 | raise RuntimeError('RAdam does not support sparse gradients') 31 | 32 | p_data_fp32 = p.data.float() 33 | 34 | state = self.state[p] 35 | 36 | if len(state) == 0: 37 | state['step'] = 0 38 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 39 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 40 | else: 41 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 42 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 43 | 44 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 45 | beta1, beta2 = group['betas'] 46 | 47 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 48 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 49 | 50 | state['step'] += 1 51 | buffered = self.buffer[int(state['step'] % 10)] 52 | if state['step'] == buffered[0]: 53 | N_sma, step_size = buffered[1], buffered[2] 54 | else: 55 | buffered[0] = state['step'] 56 | beta2_t = beta2 ** state['step'] 57 | N_sma_max = 2 / (1 - beta2) - 1 58 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 59 | buffered[1] = N_sma 60 | 61 | # more conservative since it's an approximated value 62 | if N_sma >= 5: 63 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 64 | else: 65 | step_size = 1.0 / (1 - beta1 ** state['step']) 66 | buffered[2] = step_size 67 | 68 | if group['weight_decay'] != 0: 69 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 70 | 71 | # more conservative since it's an approximated value 72 | if N_sma >= 5: 73 | denom = exp_avg_sq.sqrt().add_(group['eps']) 74 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 75 | else: 76 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 77 | 78 | p.data.copy_(p_data_fp32) 79 | 80 | return loss 81 | 82 | class PlainRAdam(Optimizer): 83 | 84 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 85 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 86 | 87 | super(PlainRAdam, self).__init__(params, defaults) 88 | 89 | def __setstate__(self, state): 90 | super(PlainRAdam, self).__setstate__(state) 91 | 92 | def step(self, closure=None): 93 | 94 | loss = None 95 | if closure is not None: 96 | loss = closure() 97 | 98 | for group in self.param_groups: 99 | 100 | for p in group['params']: 101 | if p.grad is None: 102 | continue 103 | grad = p.grad.data.float() 104 | if grad.is_sparse: 105 | raise RuntimeError('RAdam does not support sparse gradients') 106 | 107 | p_data_fp32 = p.data.float() 108 | 109 | state = self.state[p] 110 | 111 | if len(state) == 0: 112 | state['step'] = 0 113 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 114 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 115 | else: 116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 117 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 118 | 119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 120 | beta1, beta2 = group['betas'] 121 | 122 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 123 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 124 | 125 | state['step'] += 1 126 | beta2_t = beta2 ** state['step'] 127 | N_sma_max 
= 2 / (1 - beta2) - 1 128 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 129 | 130 | if group['weight_decay'] != 0: 131 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 132 | 133 | # more conservative since it's an approximated value 134 | if N_sma >= 5: 135 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 136 | denom = exp_avg_sq.sqrt().add_(group['eps']) 137 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 138 | else: 139 | step_size = group['lr'] / (1 - beta1 ** state['step']) 140 | p_data_fp32.add_(-step_size, exp_avg) 141 | 142 | p.data.copy_(p_data_fp32) 143 | 144 | return loss 145 | 146 | 147 | class AdamW(Optimizer): 148 | 149 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 150 | defaults = dict(lr=lr, betas=betas, eps=eps, 151 | weight_decay=weight_decay, warmup = warmup) 152 | super(AdamW, self).__init__(params, defaults) 153 | 154 | def __setstate__(self, state): 155 | super(AdamW, self).__setstate__(state) 156 | 157 | def step(self, closure=None): 158 | loss = None 159 | if closure is not None: 160 | loss = closure() 161 | 162 | for group in self.param_groups: 163 | 164 | for p in group['params']: 165 | if p.grad is None: 166 | continue 167 | grad = p.grad.data.float() 168 | if grad.is_sparse: 169 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 170 | 171 | p_data_fp32 = p.data.float() 172 | 173 | state = self.state[p] 174 | 175 | if len(state) == 0: 176 | state['step'] = 0 177 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 178 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 179 | else: 180 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 181 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 182 | 183 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 184 | beta1, beta2 = group['betas'] 185 | 186 | state['step'] += 1 187 | 188 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 189 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 190 | 191 | denom = exp_avg_sq.sqrt().add_(group['eps']) 192 | bias_correction1 = 1 - beta1 ** state['step'] 193 | bias_correction2 = 1 - beta2 ** state['step'] 194 | 195 | if group['warmup'] > state['step']: 196 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 197 | else: 198 | scheduled_lr = group['lr'] 199 | 200 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 201 | 202 | if group['weight_decay'] != 0: 203 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 204 | 205 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 206 | 207 | p.data.copy_(p_data_fp32) 208 | 209 | return loss 210 | -------------------------------------------------------------------------------- /polarnet/optim/ralamb.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | from torch.optim.optimizer import Optimizer 3 | 4 | # RAdam + LARS 5 | class Ralamb(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(Ralamb, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(Ralamb, self).__setstate__(state) 14 | 15 | def step(self, 
closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('Ralamb does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | # Decay the first and second moment running average coefficient 46 | # m_t 47 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 48 | # v_t 49 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 50 | 51 | state['step'] += 1 52 | buffered = self.buffer[int(state['step'] % 10)] 53 | 54 | if state['step'] == buffered[0]: 55 | N_sma, radam_step_size = buffered[1], buffered[2] 56 | else: 57 | buffered[0] = state['step'] 58 | beta2_t = beta2 ** state['step'] 59 | N_sma_max = 2 / (1 - beta2) - 1 60 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 61 | buffered[1] = N_sma 62 | 63 | # more conservative since it's an approximated value 64 | if N_sma >= 5: 65 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 66 | else: 67 | radam_step_size = 1.0 / (1 - beta1 ** state['step']) 68 | buffered[2] = radam_step_size 69 | 70 | if group['weight_decay'] != 0: 71 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 72 | 73 | # more conservative since it's an approximated value 74 | radam_step = p_data_fp32.clone() 75 | if N_sma >= 5: 76 | denom = exp_avg_sq.sqrt().add_(group['eps']) 77 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) 78 | else: 79 | radam_step.add_(exp_avg, alpha=-radam_step_size * group['lr']) 80 | 81 | radam_norm = radam_step.pow(2).sum().sqrt() 82 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 83 | if weight_norm == 0 or radam_norm == 0: 84 | trust_ratio = 1 85 | else: 86 | trust_ratio = weight_norm / radam_norm 87 | 88 | state['weight_norm'] = weight_norm 89 | state['adam_norm'] = radam_norm 90 | state['trust_ratio'] = trust_ratio 91 | 92 | if N_sma >= 5: 93 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) 94 | else: 95 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) 96 | 97 | p.data.copy_(p_data_fp32) 98 | 99 | return loss 100 | -------------------------------------------------------------------------------- /polarnet/optim/rangerlars.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | from torch.optim.optimizer import Optimizer 3 | import itertools as it 4 | from .lookahead import * 5 | from .ralamb import * 6 | 7 | # RAdam + LARS + LookAHead 8 | 9 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 10 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 11 | 12 | def RangerLars(params, alpha=0.5, k=6, *args, **kwargs): 13 | ralamb = Ralamb(params, *args, **kwargs) 14 | 
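    # Ralamb applies the RAdam update rescaled by the LARS trust ratio;
    # wrapping it in Lookahead additionally averages toward slow weights every k steps.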
return Lookahead(ralamb, alpha, k) 15 | -------------------------------------------------------------------------------- /polarnet/optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | optimizer learning rate scheduling helpers 6 | """ 7 | from math import ceil 8 | 9 | 10 | def noam_schedule(step, warmup_step=4000): 11 | """ original Transformer schedule""" 12 | if step <= warmup_step: 13 | return step / warmup_step 14 | return (warmup_step ** 0.5) * (step ** -0.5) 15 | 16 | 17 | def warmup_linear(step, warmup_step, tot_step): 18 | """ BERT schedule """ 19 | if step < warmup_step: 20 | return step / warmup_step 21 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 22 | 23 | def warmup_inverse_sqrt(step, warmup_step, tot_step): 24 | """Decay the LR based on the inverse square root of the update number. 25 | We also support a warmup phase where we linearly increase the learning rate 26 | from some initial learning rate (``--warmup-init-lr``) until the configured 27 | learning rate (``--lr``). Thereafter we decay proportional to the number of 28 | updates, with a decay factor set to align with the configured learning rate. 29 | 30 | During warmup:: 31 | 32 | lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) 33 | lr = lrs[update_num] 34 | 35 | After warmup:: 36 | 37 | decay_factor = cfg.lr * sqrt(cfg.warmup_updates) 38 | lr = decay_factor / sqrt(update_num) 39 | """ 40 | if step < warmup_step: 41 | return step / warmup_step 42 | else: 43 | return warmup_step**0.5 * step**-0.5 44 | 45 | 46 | def get_lr_sched(global_step, opts): 47 | # learning rate scheduling 48 | if opts.lr_sched == 'linear': 49 | func = warmup_linear 50 | elif opts.lr_sched == 'inverse_sqrt': 51 | func = warmup_inverse_sqrt 52 | else: 53 | raise NotImplementedError(f'invalid lr scheduler {opts.lr_sched}') 54 | 55 | lr_this_step = opts.learning_rate * func( 56 | global_step, opts.warmup_steps, opts.num_train_steps 57 | ) 58 | if lr_this_step <= 0: 59 | lr_this_step = 1e-8 60 | return lr_this_step 61 | 62 | def get_lr_sched_decay_rate(global_step, opts): 63 | # learning rate scheduling 64 | if opts.lr_sched == 'linear': 65 | lr_decay_fn = warmup_linear 66 | elif opts.lr_sched == 'inverse_sqrt': 67 | lr_decay_fn = warmup_inverse_sqrt 68 | 69 | lr_decay_rate = lr_decay_fn( 70 | global_step, opts.warmup_steps, opts.num_train_steps 71 | ) 72 | lr_decay_rate = max(lr_decay_rate, 1e-5) 73 | return lr_decay_rate 74 | -------------------------------------------------------------------------------- /polarnet/preprocess/evaluate_dataset_keysteps.py: -------------------------------------------------------------------------------- 1 | from polarnet.core.actioner import BaseActioner 2 | from polarnet.core.environments import RLBenchEnv 3 | from typing import Tuple, Dict, List 4 | 5 | import os 6 | import numpy as np 7 | import random 8 | 9 | import itertools 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | import collections 13 | import tap 14 | import json 15 | 16 | import lmdb 17 | import msgpack 18 | import msgpack_numpy 19 | 20 | msgpack_numpy.patch() 21 | 22 | 23 | class Arguments(tap.Tap): 24 | microstep_data_dir: Path = "data/train_dataset/microsteps/seed0" 25 | keystep_data_dir: Path = "data/train_dataset/keysteps/seed0" 26 | 27 | seed: int = 0 28 | num_demos: int = 100 29 | 30 | tasks: Tuple[str, ...] = ("pick_up_cup",) 31 | cameras: Tuple[str, ...] 
= ("left_shoulder", "right_shoulder", "wrist") 32 | 33 | max_variations: int = 1 34 | offset: int = 0 35 | 36 | headless: bool = False 37 | gripper_pose: str = None 38 | max_tries: int = 10 39 | 40 | log_dir: Path = None 41 | 42 | 43 | class KeystepActioner(BaseActioner): 44 | def __init__(self, keystep_data_dir) -> None: 45 | self.lmdb_env = lmdb.open(str(keystep_data_dir), readonly=True) 46 | self.lmdb_txn = self.lmdb_env.begin() 47 | 48 | def __exit__(self): 49 | self.lmdb_env.close() 50 | 51 | def reset(self, task_str, variation, instructions, demo_id): 52 | super().reset(task_str, variation, instructions, demo_id) 53 | 54 | value = self.lmdb_txn.get(demo_id.encode("ascii")) 55 | value = msgpack.unpackb(value) 56 | self.actions = value["action"][1:] 57 | 58 | def predict(self, taskvar_id, step_id, *args, **kwargs): 59 | out = {} 60 | if step_id < len(self.actions): 61 | out["action"] = self.actions[step_id] 62 | else: 63 | out["action"] = np.zeros((8,), dtype=np.float32) 64 | print(self.demo_id, step_id, len(self.actions)) 65 | return out 66 | 67 | 68 | def evaluate_keysteps(args): 69 | np.random.seed(args.seed) 70 | random.seed(args.seed) 71 | 72 | env = RLBenchEnv( 73 | data_path=args.microstep_data_dir, 74 | apply_rgb=True, 75 | apply_pc=True, 76 | apply_cameras=args.cameras, 77 | headless=args.headless, 78 | gripper_pose=args.gripper_pose, 79 | ) 80 | 81 | variations = range(args.offset, args.max_variations) 82 | 83 | taskvar_id = 0 84 | for task_str in args.tasks: 85 | for variation in variations: 86 | actioner = KeystepActioner( 87 | args.keystep_data_dir / f"{task_str}+{variation}" 88 | ) 89 | episodes_dir = ( 90 | args.microstep_data_dir 91 | / task_str 92 | / f"variation{variation}" 93 | / "episodes" 94 | ) 95 | 96 | result_file = os.path.join( 97 | args.keystep_data_dir, f"{task_str}+{variation}", "results.json" 98 | ) 99 | if os.path.exists(result_file): 100 | continue 101 | 102 | demo_keys, demos = [], [] 103 | if os.path.exists(str(episodes_dir)): 104 | for ep in tqdm(episodes_dir.glob("episode*")): 105 | episode_id = int(ep.stem[7:]) 106 | try: 107 | demo = env.get_demo(task_str, variation, episode_id) 108 | demo_keys.append(f"episode{episode_id}") 109 | demos.append(demo) 110 | except: 111 | print("\tProblem to load demo:", episode_id) 112 | else: 113 | demo_keys = None 114 | demos = None 115 | 116 | success_rate, detail_results = env.evaluate( 117 | taskvar_id, 118 | task_str, 119 | actioner=actioner, 120 | max_episodes=30, # max_step_per_episode 121 | variation=variation, 122 | num_demos=len(demos) if demos is not None else args.num_demos, 123 | demos=demos, 124 | demo_keys=demo_keys, 125 | log_dir=args.log_dir, 126 | max_tries=args.max_tries, 127 | save_image=False, 128 | return_detail_results=True, 129 | skip_demos=1 130 | if demos is None 131 | else 0, # during microstep generate, we skip one demo 132 | ) 133 | 134 | print("Testing Success Rate {}: {:.04f}".format(task_str, success_rate)) 135 | 136 | with open(result_file, "w") as outf: 137 | json.dump(detail_results, outf) 138 | 139 | 140 | if __name__ == "__main__": 141 | args = Arguments().parse_args() 142 | evaluate_keysteps(args) 143 | -------------------------------------------------------------------------------- /polarnet/preprocess/generate_dataset_keysteps.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, List 2 | 3 | import os 4 | import numpy as np 5 | import itertools 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import 
collections 9 | import tap 10 | 11 | import lmdb 12 | import msgpack 13 | import msgpack_numpy 14 | msgpack_numpy.patch() 15 | 16 | from polarnet.utils.keystep_detection import keypoint_discovery 17 | from polarnet.utils.coord_transforms import convert_gripper_pose_world_to_image 18 | 19 | from polarnet.core.environments import RLBenchEnv 20 | from PIL import Image 21 | 22 | 23 | class Arguments(tap.Tap): 24 | microstep_data_dir: Path = "data/train_dataset/microsteps/seed0" 25 | keystep_data_dir: Path = "data/train_dataset/keysteps/seed0" 26 | 27 | tasks: Tuple[str, ...] = ("pick_up_cup",) 28 | cameras: Tuple[str, ...] = ("left_shoulder", "right_shoulder", "wrist") 29 | 30 | max_variations: int = 1 31 | offset: int = 0 32 | 33 | 34 | def get_observation(task_str: str, variation: int, episode: int, env: RLBenchEnv): 35 | demo = env.get_demo(task_str, variation, episode) 36 | 37 | key_frames = keypoint_discovery(demo) 38 | key_frames.insert(0, 0) 39 | 40 | state_dict_ls = collections.defaultdict(list) 41 | for f in key_frames: 42 | state_dict = env.get_observation(demo._observations[f]) 43 | for k, v in state_dict.items(): 44 | if len(v) > 0: 45 | # rgb: (N: num_of_cameras, H, W, C); gripper: (7+1, ) 46 | state_dict_ls[k].append(v) 47 | 48 | for k, v in state_dict_ls.items(): 49 | state_dict_ls[k] = np.stack(v, 0) # (T, N, H, W, C) 50 | 51 | action_ls = state_dict_ls['gripper'] # (T, 7+1) 52 | del state_dict_ls['gripper'] 53 | 54 | return demo, key_frames, state_dict_ls, action_ls 55 | 56 | 57 | def generate_keystep_dataset(args: Arguments): 58 | # load RLBench environment 59 | rlbench_env = RLBenchEnv( 60 | data_path=args.microstep_data_dir, 61 | apply_rgb=True, 62 | apply_pc=True, 63 | apply_cameras=args.cameras, 64 | ) 65 | 66 | tasks = args.tasks 67 | variations = range(args.offset, args.max_variations) 68 | 69 | for task_str, variation in itertools.product(tasks, variations): 70 | episodes_dir = args.microstep_data_dir / task_str / f"variation{variation}" / "episodes" 71 | 72 | output_dir = args.keystep_data_dir / f"{task_str}+{variation}" 73 | output_dir.mkdir(parents=True, exist_ok=True) 74 | 75 | lmdb_env = lmdb.open(str(output_dir), map_size=int(1024**4)) 76 | 77 | for ep in tqdm(episodes_dir.glob('episode*')): 78 | episode = int(ep.stem[7:]) 79 | try: 80 | demo, key_frameids, state_dict_ls, action_ls = get_observation( 81 | task_str, variation, episode, rlbench_env 82 | ) 83 | except (FileNotFoundError, RuntimeError, IndexError) as e: 84 | print(e) 85 | return 86 | 87 | gripper_pose = [] 88 | for key_frameid in key_frameids: 89 | gripper_pose.append({ 90 | cam: convert_gripper_pose_world_to_image(demo[key_frameid], cam) for cam in args.cameras 91 | }) 92 | 93 | outs = { 94 | 'key_frameids': key_frameids, 95 | 'rgb': state_dict_ls['rgb'], # (T, N, H, W, 3) 96 | 'pc': state_dict_ls['pc'], # (T, N, H, W, 3) 97 | 'action': action_ls, # (T, A) 98 | 'gripper_pose': gripper_pose, # [T of dict] 99 | } 100 | 101 | txn = lmdb_env.begin(write=True) 102 | txn.put(f'episode{episode}'.encode('ascii'), msgpack.packb(outs)) 103 | txn.commit() 104 | 105 | lmdb_env.close() 106 | 107 | 108 | if __name__ == "__main__": 109 | args = Arguments().parse_args() 110 | generate_keystep_dataset(args) 111 | -------------------------------------------------------------------------------- /polarnet/preprocess/generate_instructions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate instruction embeddings 3 | """ 4 | import os 5 | import json 6 | import 
jsonlines 7 | from tqdm import tqdm 8 | 9 | import lmdb 10 | import msgpack 11 | import msgpack_numpy 12 | 13 | msgpack_numpy.patch() 14 | 15 | import torch 16 | 17 | from rlbench.action_modes.action_mode import MoveArmThenGripper 18 | from rlbench.action_modes.arm_action_modes import JointVelocity 19 | from rlbench.action_modes.gripper_action_modes import Discrete 20 | from rlbench.environment import Environment 21 | from rlbench.utils import name_to_task_class 22 | 23 | import transformers 24 | 25 | BROKEN_TASKS = set( 26 | [ 27 | "empty_container", 28 | "set_the_table", 29 | ] 30 | ) 31 | 32 | 33 | def generate_all_instructions(env_file: str, instruction_file: str): 34 | if os.path.exists(instruction_file): 35 | exist_tasks = set() 36 | with jsonlines.open(instruction_file) as f: 37 | for item in f: 38 | exist_tasks.add(item["task"]) 39 | print("Exist task", len(exist_tasks)) 40 | else: 41 | exist_tasks = [] 42 | 43 | all_tasks = json.load(open(env_file)) 44 | 45 | action_mode = MoveArmThenGripper( 46 | arm_action_mode=JointVelocity(), gripper_action_mode=Discrete() 47 | ) 48 | env = Environment(action_mode, headless=True) 49 | env.launch() 50 | 51 | outfile = jsonlines.open(instruction_file, "a", flush=True) 52 | 53 | for task in tqdm(all_tasks): 54 | if task in BROKEN_TASKS or task in exist_tasks: 55 | continue 56 | print(task) 57 | outs = {"task": task, "variations": {}} 58 | task_env = env.get_task(name_to_task_class(task)) 59 | num_variations = task_env.variation_count() 60 | for v in tqdm(range(num_variations)): 61 | try: 62 | task_env.set_variation(v) 63 | descriptions, obs = task_env.reset() 64 | outs["variations"][v] = descriptions 65 | except Exception as e: 66 | print("Error", task, v, e) 67 | outfile.write(outs) 68 | 69 | env.shutdown() 70 | outfile.close() 71 | 72 | 73 | def load_all_instructions(instruction_file: str): 74 | data = [] 75 | with jsonlines.open(instruction_file, "r") as f: 76 | for item in f: 77 | data.append(item) 78 | return data 79 | 80 | 81 | def load_text_encoder(encoder: str): 82 | if encoder == "bert": 83 | tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") 84 | model = transformers.BertModel.from_pretrained("bert-base-uncased") 85 | elif encoder == "clip": 86 | model_name = "openai/clip-vit-base-patch32" 87 | tokenizer = transformers.CLIPTokenizer.from_pretrained(model_name) 88 | model = transformers.CLIPTextModel.from_pretrained(model_name) 89 | else: 90 | raise ValueError(f"Unexpected encoder {encoder}") 91 | 92 | return tokenizer, model 93 | 94 | 95 | def main(args): 96 | taskvar_instrs = load_all_instructions(args.instruction_file) 97 | 98 | tokenizer, model = load_text_encoder(args.encoder) 99 | model = model.to(args.device) 100 | 101 | os.makedirs(args.output_file, exist_ok=True) 102 | lmdb_env = lmdb.open(args.output_file, map_size=int(1024**3)) 103 | 104 | for item in tqdm(taskvar_instrs): 105 | task = item["task"] 106 | for variation, instructions in item["variations"].items(): 107 | key = "%s+%s" % (task, variation) 108 | 109 | instr_embeds = [] 110 | for instr in instructions: 111 | tokens = tokenizer(instr, padding=False)["input_ids"] 112 | if len(tokens) > 77: 113 | print("Too long", task, variation, instr) 114 | 115 | tokens = torch.LongTensor(tokens).unsqueeze(0).to(args.device) 116 | with torch.no_grad(): 117 | embed = model(tokens).last_hidden_state.squeeze(0) 118 | instr_embeds.append(embed.data.cpu().numpy()) 119 | 120 | txn = lmdb_env.begin(write=True) 121 | txn.put(key.encode("ascii"), 
msgpack.packb(instr_embeds)) 122 | txn.commit() 123 | 124 | lmdb_env.close() 125 | 126 | 127 | if __name__ == "__main__": 128 | import argparse 129 | 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--encoder", choices=["bert", "clip"], default="clip") 132 | parser.add_argument("--device", default="cuda") 133 | parser.add_argument("--output_file", required=True) 134 | parser.add_argument( 135 | "--generate_all_instructions", action="store_true", default=False 136 | ) 137 | parser.add_argument("--env_file", default="assets/all_tasks.json") 138 | parser.add_argument( 139 | "--instruction_file", default="assets/taskvar_instructions.jsonl" 140 | ) 141 | args = parser.parse_args() 142 | if args.generate_all_instructions: 143 | generate_all_instructions(args.env_file, args.instruction_file) 144 | main(args) 145 | -------------------------------------------------------------------------------- /polarnet/preprocess/generate_pcd_dataset_keysteps.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import os 4 | import argparse 5 | import numpy as np 6 | import collections 7 | import json 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | import open3d as o3d 14 | 15 | import lmdb 16 | import msgpack 17 | import msgpack_numpy 18 | 19 | msgpack_numpy.patch() 20 | 21 | from polarnet.utils.utils import get_assets_dir 22 | from polarnet.config.constants import get_workspace 23 | 24 | 25 | tasks_use_table_surface = json.load( 26 | open(f"{get_assets_dir()}/tasks_use_table_surface.json", "r") 27 | ) 28 | 29 | 30 | def process_point_clouds( 31 | rgb, 32 | pc, 33 | gripper_pose=None, 34 | task=None, 35 | pc_space="workspace_on_table", 36 | pc_center_type="gripper", 37 | pc_radius_norm=True, 38 | voxel_size=0.01, 39 | no_rgb=False, 40 | no_normal=False, 41 | no_height=False, 42 | ): 43 | rgb = rgb.reshape(-1, 3) 44 | pc = pc.reshape(-1, 3) 45 | WORKSPACE = get_workspace(real_robot=False) 46 | 47 | X_BBOX = WORKSPACE["X_BBOX"] 48 | Y_BBOX = WORKSPACE["Y_BBOX"] 49 | Z_BBOX = WORKSPACE["Z_BBOX"] 50 | TABLE_HEIGHT = WORKSPACE["TABLE_HEIGHT"] 51 | 52 | if pc_space in ["workspace", "workspace_on_table"]: 53 | masks = ( 54 | (pc[:, 0] > X_BBOX[0]) 55 | & (pc[:, 0] < X_BBOX[1]) 56 | & (pc[:, 1] > Y_BBOX[0]) 57 | & (pc[:, 1] < Y_BBOX[1]) 58 | & (pc[:, 2] > Z_BBOX[0]) 59 | & (pc[:, 2] < Z_BBOX[1]) 60 | ) 61 | if pc_space == "workspace_on_table" and task not in tasks_use_table_surface: 62 | masks = masks & (pc[:, 2] > TABLE_HEIGHT) 63 | rgb = rgb[masks] 64 | pc = pc[masks] 65 | 66 | # pcd center and radius 67 | if pc_center_type == "point": 68 | pc_center = np.mean(pc, 0) 69 | elif pc_center_type == "gripper": 70 | pc_center = gripper_pose[:3] 71 | 72 | if pc_radius_norm: 73 | pc_radius = np.max(np.sqrt(np.sum((pc - pc_center) ** 2, 1)), keepdims=True) 74 | else: 75 | pc_radius = np.ones((1,), dtype=np.float32) 76 | 77 | pcd = o3d.geometry.PointCloud() 78 | pcd.points = o3d.utility.Vector3dVector(np.copy(pc)) 79 | pcd.colors = o3d.utility.Vector3dVector(np.copy(rgb) / 255.0) 80 | 81 | if voxel_size is not None and voxel_size > 0: 82 | downpcd, _, idxs = pcd.voxel_down_sample_and_trace( 83 | voxel_size, np.min(pc, 0), np.max(pc, 0) 84 | ) 85 | else: 86 | downpcd = pcd 87 | 88 | new_rgb = np.asarray(downpcd.colors) * 2 - 1 # (-1, 1) 89 | new_pos = np.asarray(downpcd.points) 90 | 91 | # normalized point clouds 92 | new_ft = (new_pos - pc_center) / pc_radius 93 | if not no_rgb: 
94 | # use_color 95 | new_ft = np.concatenate([new_ft, new_rgb], axis=-1) 96 | if not no_normal: 97 | # use_normal 98 | downpcd.estimate_normals( 99 | search_param=o3d.geometry.KDTreeSearchParamHybrid( 100 | radius=voxel_size * 2, max_nn=30 101 | ) 102 | ) 103 | new_ft = np.concatenate([new_ft, np.asarray(downpcd.normals)], axis=-1) 104 | if not no_height: 105 | # use_height 106 | heights = np.asarray(downpcd.points)[:, -1] 107 | heights = heights - TABLE_HEIGHT 108 | new_ft = np.concatenate([new_ft, heights[:, None]], axis=-1) 109 | return new_ft, pc_center, pc_radius 110 | 111 | 112 | def main(args): 113 | seed = args.seed 114 | dataset_dir = args.dataset_dir 115 | 116 | if args.seed >= 0: 117 | keystep_dir = os.path.join(dataset_dir, "keysteps", "seed%d" % seed) 118 | out_dir = os.path.join(dataset_dir, args.outname, "seed%d" % seed) 119 | else: 120 | keystep_dir = os.path.join(dataset_dir, "keysteps") 121 | out_dir = os.path.join(dataset_dir, args.outname) 122 | 123 | taskvars = os.listdir(keystep_dir) 124 | taskvars.sort() 125 | print("#taskvars", len(taskvars)) 126 | 127 | os.makedirs(out_dir, exist_ok=True) 128 | 129 | for taskvar in tqdm(taskvars): 130 | task = taskvar.split("+")[0] 131 | 132 | lmdb_env = lmdb.open( 133 | os.path.join(keystep_dir, taskvar), readonly=True, lock=False 134 | ) 135 | lmdb_txn = lmdb_env.begin() 136 | 137 | out_lmdb_env = lmdb.open( 138 | os.path.join(out_dir, taskvar), map_size=int(1024**4) 139 | ) 140 | 141 | pbar = tqdm(total=lmdb_txn.stat()["entries"]) 142 | for episode_key, value in lmdb_txn.cursor(): 143 | episode_key_dec = episode_key.decode("utf-8") 144 | if not episode_key_dec.startswith("episode"): 145 | continue 146 | 147 | value = msgpack.unpackb(value) 148 | rgbs = value["rgb"][ 149 | :, : args.num_cameras 150 | ] # (T, C, H, W, 3) just use the first 3 cameras 151 | pcs = value["pc"][:, : args.num_cameras] # (T, C, H, W, 3) 152 | actions = value["action"] 153 | 154 | outs = collections.defaultdict(list) 155 | for t, rgb in enumerate(rgbs): 156 | pcd_ft, pc_center, pc_radius = process_point_clouds( 157 | rgbs[t], 158 | pcs[t], 159 | gripper_pose=actions[t], 160 | task=task, 161 | pc_space="workspace_on_table", 162 | pc_center_type="gripper", 163 | pc_radius_norm=True, 164 | voxel_size=0.01, 165 | no_rgb=args.no_rgb, 166 | no_normal=args.no_normal, 167 | no_height=args.no_height, 168 | ) 169 | outs["pc_fts"].append(pcd_ft) 170 | outs["pc_centers"].append(pc_center) 171 | outs["pc_radii"].append(pc_radius) 172 | outs["actions"] = actions 173 | 174 | txn = out_lmdb_env.begin(write=True) 175 | txn.put(episode_key, msgpack.packb(outs)) 176 | txn.commit() 177 | 178 | pbar.update(1) 179 | 180 | lmdb_env.close() 181 | out_lmdb_env.close() 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--seed", type=int, default=0) 187 | parser.add_argument("--num_cameras", type=int, default=3) 188 | parser.add_argument("--dataset_dir", type=str, default="data/train_dataset") 189 | parser.add_argument("--outname", type=str, default="keysteps_pcd") 190 | 191 | parser.add_argument("--no_rgb", action="store_true", default=False) 192 | parser.add_argument("--no_normal", action="store_true", default=False) 193 | parser.add_argument("--no_height", action="store_true", default=False) 194 | 195 | args = parser.parse_args() 196 | 197 | main(args) 198 | -------------------------------------------------------------------------------- /polarnet/preprocess/generate_real_instructions.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | import jsonlines
3 | 
4 | from pathlib import Path
5 | 
6 | 
7 | def get_args_parser():
8 |     parser = argparse.ArgumentParser("Generate real instructions", add_help=False)
9 |     parser.add_argument("--output-file", default="assets/real_robot_instructions.jsonl", type=str)
10 |     return parser
11 | 
12 | 
13 | def push_buttons_template(variations):
14 |     instruction = {}
15 |     for var, info in variations.items():
16 |         instruction[f"{var}"] = []
17 |         first_button = info[0]
18 |         rtn0 = 'push the %s button' % first_button
19 |         rtn1 = 'press the %s button' % first_button
20 |         rtn2 = 'push down the button with the %s base' % first_button
21 |         for next_button in info[1:]:
22 |             rtn0 += ', then push the %s button' % next_button
23 |             rtn1 += ', then press the %s button' % next_button
24 |             rtn2 += ', then the %s one' % next_button
25 |         instruction[f"{var}"].append(rtn0)
26 |         instruction[f"{var}"].append(rtn1)
27 |         instruction[f"{var}"].append(rtn2)
28 |     return instruction
29 | 
30 | 
31 | def stack_cup_template(variations):
32 |     instruction = {}
33 |     for var, info in variations.items():
34 |         instruction[f"{var}"] = []
35 |         first_cup = info[0]
36 |         second_cup = info[1]
37 |         rtn0 = f"stack the {first_cup} cup on top of the {second_cup} cup"
38 |         rtn1 = f"place the {first_cup} cup onto the {second_cup} cup"
39 |         rtn2 = f"put the {first_cup} cup on top of the {second_cup} one"
40 |         rtn3 = f"pick up and set the {first_cup} cup down into the {second_cup} cup"
41 |         rtn4 = f"create a stack of cups with the {second_cup} cup as its base and the {first_cup} on top of it"
42 |         rtn5 = f"keeping the {second_cup} cup on the table, stack the {first_cup} one onto it"
43 |         instruction[f"{var}"].append(rtn0)
44 |         instruction[f"{var}"].append(rtn1)
45 |         instruction[f"{var}"].append(rtn2)
46 |         instruction[f"{var}"].append(rtn3)
47 |         instruction[f"{var}"].append(rtn4)
48 |         instruction[f"{var}"].append(rtn5)
49 |     return instruction
50 | 
51 | 
52 | def put_plate_template(variations):
53 |     instruction = {}
54 |     for var, info in variations.items():
55 |         instruction[f"{var}"] = []
56 |         color = info[0]
57 |         rtn0 = f"take the {color} plate to the target place"
58 |         rtn1 = f"place the {color} plate on the target"
59 |         instruction[f"{var}"].append(rtn0)
60 |         instruction[f"{var}"].append(rtn1)
61 |     return instruction
62 | 
63 | def put_item_in_drawer_template(variations):
64 |     instruction = {}
65 |     for var, info in variations.items():
66 |         instruction[f"{var}"] = []
67 |         obj = info[0]
68 |         part = info[1]
69 |         rtn0 = f"put the {obj} in the {part} part of the drawer"
70 |         rtn1 = f"take the {obj} and put it in the {part} compartment of the drawer"
71 |         instruction[f"{var}"].append(rtn0)
72 |         instruction[f"{var}"].append(rtn1)
73 |     return instruction
74 | 
75 | 
76 | def open_drawer_template(variations):
77 |     instruction = {}
78 |     for var, info in variations.items():
79 |         instruction[f"{var}"] = []
80 |         part = info[0]
81 |         rtn0 = f"open {part} drawer"
82 |         rtn1 = f"grip the {part} handle and pull the {part} drawer open"
83 |         rtn2 = f"slide the {part} drawer open"
84 |         instruction[f"{var}"].append(rtn0)
85 |         instruction[f"{var}"].append(rtn1)
86 |         instruction[f"{var}"].append(rtn2)
87 |     return instruction
88 | 
89 | def put_item_in_cabinet_template(variations):
90 |     instruction = {}
91 |     for var, info in variations.items():
92 |         instruction[f"{var}"] = []
93 |         obj = info[0]
94 |         part = info[1]
95 |         rtn0 = f"put the {obj} in the {part} part of the cabinet"
cabinet" 96 | rtn1 = f"take the {obj} and put it in the {part} compartiment of the cabinet" 97 | instruction[f"{var}"].append(rtn0) 98 | instruction[f"{var}"].append(rtn1) 99 | return instruction 100 | 101 | 102 | def put_fruit_in_box_template(variations): 103 | instruction = {} 104 | for var, info in variations.items(): 105 | instruction[f"{var}"] = [] 106 | obj = info[0] 107 | rtn0 = f"put the {obj} in the box" 108 | rtn1 = f"take the {obj} and put it inside the box" 109 | instruction[f"{var}"].append(rtn0) 110 | instruction[f"{var}"].append(rtn1) 111 | return instruction 112 | 113 | 114 | def hang_mug_template(variations): 115 | instruction = {} 116 | for var, info in variations.items(): 117 | instruction[f"{var}"] = [] 118 | color = info[0] 119 | part = info[1] 120 | rtn0 = f"take the {color} mug and put it on the {part} part of the hanger" 121 | rtn1 = f"put the {color} mug on the {part} part of the hanger" 122 | instruction[f"{var}"].append(rtn0) 123 | instruction[f"{var}"].append(rtn1) 124 | return instruction 125 | 126 | 127 | def main(args): 128 | 129 | tasks = ["real_push_buttons", "real_put_plate", "real_stack_cup", "real_put_item_in_drawer", "real_open_drawer", "real_put_item_in_cabinet", "real_put_fruit_in_box", "real_hang_mug"] 130 | instructions = [] 131 | for task in tasks: 132 | if task == "real_push_buttons": 133 | vars = {0: ["red", "green", "yellow"], 1: ["white", "yellow", "black"], 2: ["blue", "black", "red"], 3: ["orange", "pink", "white"], 4: ["green", "cyan"]} 134 | var_inst = push_buttons_template(vars) 135 | elif task == "real_stack_cup": 136 | vars = {0: ["yellow", "pink"], 1: ["navy", "yellow"], 2: ["pink", "cyan"], 3: ["cyan", "navy"], 4: ["pink", "yellow"]} 137 | var_inst = stack_cup_template(vars) 138 | elif task == "real_put_plate": 139 | vars = {0: ["white"], 1: ["blue"], 2: ["red"]} 140 | var_inst = put_plate_template(vars) 141 | elif task == "real_put_item_in_drawer": 142 | vars = {0: ["peach", "top"], 1: ["orange", "middle"], 2: ["strawberry", "top"]} 143 | var_inst = put_item_in_drawer_template(vars) 144 | elif task == "real_open_drawer": 145 | vars = {0: ["top"], 1: ["middle"]} 146 | var_inst = open_drawer_template(vars) 147 | elif task == "real_put_item_in_cabinet": 148 | vars = {0: ["apple", "top"], 1: ["strawberry", "bottom"], 2: ["lemon", "top"]} 149 | var_inst = put_item_in_cabinet_template(vars) 150 | elif task == "real_put_fruit_in_box": 151 | vars = {0: ["strawberry"], 1: ["peach"], 2: ["banana"], 3: ["lemon"]} 152 | var_inst = put_fruit_in_box_template(vars) 153 | elif task == "real_hang_mug": 154 | vars = {0: ["green", "middle"], 1: ["pink", "middle"], 2: ["blue", "top"]} 155 | var_inst = hang_mug_template(vars) 156 | else: 157 | continue 158 | instructions.append({"task": task, "variations": var_inst}) 159 | 160 | print(instructions) 161 | output_file = jsonlines.open(args.output_file, 'a', flush=True) 162 | output_file.write_all(instructions) 163 | print("Instructions succesfully written in", args.output_file) 164 | 165 | 166 | 167 | if __name__ == "__main__": 168 | parser = argparse.ArgumentParser( 169 | "Generate real instructions", parents=[get_args_parser()] 170 | ) 171 | args = parser.parse_args() 172 | main(args) -------------------------------------------------------------------------------- /polarnet/summarize_74tasks_tst_results_by_groups.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonlines 3 | import collections 4 | import numpy as np 5 | import tap 6 | 7 | from 
polarnet.utils.utils import get_assets_dir 8 | 9 | 10 | class Arguments(tap.Tap): 11 | result_file: str 12 | 13 | 14 | def main(args): 15 | task_groups = json.load(open(f"{get_assets_dir()}/74_tasks_per_category.json")) 16 | task2group = {} 17 | for group, tasks in task_groups.items(): 18 | for task in tasks: 19 | task2group[task] = group 20 | group_orders = [ 21 | "planning", 22 | "tools", 23 | "long_term", 24 | "rotation-invariant", 25 | "motion-planner", 26 | "screw", 27 | "multimodal", 28 | "precision", 29 | "visual_occlusion", 30 | ] 31 | 32 | results = collections.defaultdict(dict) 33 | with jsonlines.open(args.result_file, "r") as f: 34 | for item in f: 35 | results[item["checkpoint"]].setdefault(item["task"], []) 36 | results[item["checkpoint"]][item["task"]].append(item["sr"]) 37 | 38 | ckpt_results = collections.defaultdict(list) 39 | for ckpt, res in results.items(): 40 | for task, v in res.items(): 41 | ckpt_results[ckpt].append((task, np.mean(v))) 42 | print("\nnum_tasks", len(ckpt_results[ckpt])) 43 | 44 | for ckpt, res in ckpt_results.items(): 45 | print() 46 | print(ckpt, "num_tasks", len(res)) 47 | group_res = collections.defaultdict(list) 48 | for task, sr in res: 49 | group_res[task2group[task]].append(sr) 50 | 51 | print(",".join(group_orders)) 52 | print(",".join(["%.2f" % (np.mean(group_res[g]) * 100) for g in group_orders])) 53 | print("avg tasks: %.2f" % (np.mean([x[1] for x in res]) * 100)) 54 | 55 | 56 | if __name__ == "__main__": 57 | args = Arguments().parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /polarnet/summarize_peract_official_tst_results.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import jsonlines 4 | import collections 5 | import tap 6 | 7 | 8 | class Arguments(tap.Tap): 9 | result_file: str 10 | 11 | 12 | def main(args): 13 | results = collections.defaultdict(dict) 14 | exists_taskvars = collections.defaultdict(set) 15 | with jsonlines.open(args.result_file, 'r') as f: 16 | for item in f: 17 | taskvar = '%s+%d'%(item['task'], item['variation']) 18 | if taskvar in exists_taskvars[item['checkpoint']]: 19 | continue 20 | exists_taskvars[item['checkpoint']].add(taskvar) 21 | results[item['checkpoint']].setdefault(item['task'], []) 22 | results[item['checkpoint']][item['task']].append((item['sr'] * item['num_demos'], item['num_demos'])) 23 | #.append((item['task'], item['variation'], item['sr'])) 24 | ckpt_results = collections.defaultdict(list) 25 | for ckpt, res in results.items(): 26 | for task, v in res.items(): 27 | if np.sum([x[1] for x in v]) != 25: 28 | print(ckpt, task, np.sum([x[1] for x in v])) 29 | sr = np.sum([x[0] for x in v]) / np.sum([x[1] for x in v]) 30 | ckpt_results[ckpt].append((task, sr)) 31 | 32 | print('\nnum_tasks', len(ckpt_results[ckpt])) 33 | for ckpt, res in ckpt_results.items(): 34 | print() 35 | print(ckpt) 36 | res.sort(key=lambda x: x[0]) 37 | 38 | print(','.join([x[0] for x in res])) 39 | 40 | print(','.join(['%.2f' % (x[1]*100) for x in res])) 41 | 42 | print('average: %.2f' % (np.mean([x[1] for x in res]) * 100)) 43 | 44 | 45 | if __name__ == '__main__': 46 | args = Arguments().parse_args() 47 | main(args) 48 | -------------------------------------------------------------------------------- /polarnet/summarize_tst_results.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import jsonlines 5 | import 
collections 6 | import tap 7 | 8 | 9 | class Arguments(tap.Tap): 10 | result_file: str 11 | 12 | 13 | def main(args): 14 | results = collections.defaultdict(dict) 15 | with jsonlines.open(args.result_file, 'r') as f: 16 | for item in f: 17 | results[item['checkpoint']].setdefault(item['task'], []) 18 | results[item['checkpoint']][item['task']].append(item['sr']) 19 | #.append((item['task'], item['variation'], item['sr'])) 20 | ckpt_results = collections.defaultdict(list) 21 | for ckpt, res in results.items(): 22 | for task, v in res.items(): 23 | ckpt_results[ckpt].append((task, np.mean(v))) 24 | 25 | print('\nnum_tasks', len(ckpt_results[ckpt])) 26 | for ckpt, res in ckpt_results.items(): 27 | print() 28 | print(ckpt) 29 | res.sort(key=lambda x: x[0]) 30 | 31 | print(','.join([x[0] for x in res])) 32 | 33 | print(','.join(['%.2f' % (x[1]*100) for x in res])) 34 | 35 | print('#tasks: %d, average: %.2f' % (len(res), np.mean([x[1] for x in res]) * 100)) 36 | 37 | 38 | if __name__ == '__main__': 39 | args = Arguments().parse_args() 40 | main(args) 41 | -------------------------------------------------------------------------------- /polarnet/summarize_val_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import jsonlines 4 | import collections 5 | import tap 6 | 7 | 8 | class Arguments(tap.Tap): 9 | result_file: str 10 | 11 | 12 | def main(args): 13 | results = collections.defaultdict(list) 14 | with jsonlines.open(args.result_file, 'r') as f: 15 | for item in f: 16 | results[item['checkpoint']].append((item['task'], item['variation'], item['sr'])) 17 | 18 | ckpts = list(results.keys()) 19 | ckpts.sort(key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1])) 20 | 21 | # show task results 22 | tasks = set() 23 | for ckpt in ckpts: 24 | for x in results[ckpt]: 25 | tasks.add(x[0]) 26 | tasks = list(tasks) 27 | tasks.sort() 28 | for task in tasks: 29 | res = [] 30 | for ckpt in ckpts: 31 | ckpt_res = [] 32 | for x in results[ckpt]: 33 | if x[0] == task: 34 | ckpt_res.append(x[-1]) 35 | res.append(np.mean(ckpt_res)) 36 | print('\n', task, len(ckpt_res)) 37 | print(', '.join(['%.2f' % (x*100) for x in res])) 38 | print() 39 | 40 | avg_results = [] 41 | for k in ckpts: 42 | v = results[k] 43 | sr = collections.defaultdict(list) 44 | for x in v: 45 | sr[x[0]].append(x[-1]) 46 | sr = [np.mean(x) for x in sr.values()] 47 | print(k, len(v), np.mean(sr)*100) 48 | avg_results.append((k, np.mean(sr))) 49 | 50 | print() 51 | print('Best checkpoint and SR') 52 | avg_results.sort(key=lambda x: -x[1]) 53 | for x in avg_results: 54 | if x[-1] < avg_results[0][-1]: 55 | break 56 | print((x[0], x[1]*100)) 57 | print('\n') 58 | 59 | 60 | if __name__ == '__main__': 61 | args = Arguments().parse_args() 62 | main(args) 63 | -------------------------------------------------------------------------------- /polarnet/train_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import time 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.distributed as dist 14 | 15 | from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file 16 | from utils.save import ModelSaver, save_training_meta 17 | from utils.misc import NoOp, set_dropout, set_random_seed, set_cuda, wrap_model 
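# NOTE: the bare utils.* imports resolve only when train_models.py is launched from inside the polarnet/ directory; the fully qualified polarnet.* imports below work from any location.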
18 | from utils.distributed import all_gather 19 | 20 | from optim import get_lr_sched, get_lr_sched_decay_rate 21 | from optim.misc import build_optimizer 22 | 23 | from config.default import get_config 24 | from dataloaders.loader import build_dataloader 25 | 26 | from polarnet.dataloaders.pcd_keystep_dataset import ( 27 | PCDKeystepDataset, pcd_stepwise_collate_fn, 28 | ProcessedPCDKeystepDataset 29 | ) 30 | from polarnet.models.pcd_unet import PointCloudUNet 31 | 32 | 33 | import warnings 34 | warnings.filterwarnings("ignore") 35 | 36 | from polarnet.utils.slurm_requeue import init_signal_handler 37 | 38 | 39 | dataset_factory = { 40 | 'pre_pcd_keystep_stepwise': (ProcessedPCDKeystepDataset, pcd_stepwise_collate_fn), 41 | 'pcd_keystep_stepwise': (PCDKeystepDataset, pcd_stepwise_collate_fn), 42 | } 43 | 44 | 45 | 46 | def main(config): 47 | config.defrost() 48 | default_gpu, n_gpu, device = set_cuda(config) 49 | # config.freeze() 50 | 51 | if default_gpu: 52 | LOGGER.info( 53 | 'device: {} n_gpu: {}, distributed training: {}'.format( 54 | device, n_gpu, bool(config.local_rank != -1) 55 | ) 56 | ) 57 | 58 | seed = config.SEED 59 | if config.local_rank != -1: 60 | seed += config.rank 61 | set_random_seed(seed) 62 | 63 | if type(config.DATASET.taskvars) is str: 64 | config.DATASET.taskvars = [config.DATASET.taskvars] 65 | 66 | # load data training set 67 | dataset_class, dataset_collate_fn = dataset_factory[config.DATASET.dataset_class] 68 | 69 | dataset = dataset_class(**config.DATASET) 70 | data_loader, pre_epoch = build_dataloader( 71 | dataset, dataset_collate_fn, True, config 72 | ) 73 | LOGGER.info(f'#num_steps_per_epoch: {len(data_loader)}') 74 | if config.num_train_steps is None: 75 | config.num_train_steps = len(data_loader) * config.num_epochs 76 | else: 77 | assert config.num_epochs is None, 'cannot set num_train_steps and num_epochs at the same time.' 
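# otherwise, derive the epoch count from the requested number of training steps, rounding up so the final partial epoch still runs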
78 | config.num_epochs = int( 79 | np.ceil(config.num_train_steps / len(data_loader))) 80 | 81 | # setup loggers 82 | if default_gpu: 83 | save_training_meta(config) 84 | TB_LOGGER.create(os.path.join(config.output_dir, 'logs')) 85 | model_saver = ModelSaver(os.path.join(config.output_dir, 'ckpts')) 86 | add_log_to_file(os.path.join(config.output_dir, 'logs', 'log.txt')) 87 | else: 88 | LOGGER.disabled = True 89 | model_saver = NoOp() 90 | 91 | # Prepare model 92 | model = PointCloudUNet(**config.MODEL) 93 | # DDP: SyncBN 94 | if int(os.environ['WORLD_SIZE']) > 1: 95 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 96 | 97 | LOGGER.info("Model: nweights %d nparams %d" % (model.num_parameters)) 98 | LOGGER.info("Model: trainable nweights %d nparams %d" % 99 | (model.num_trainable_parameters)) 100 | 101 | config.freeze() 102 | 103 | # Load from checkpoint 104 | model_checkpoint_file = config.checkpoint 105 | optimizer_checkpoint_file = os.path.join( 106 | config.output_dir, 'ckpts', 'train_state_latest.pt' 107 | ) 108 | if os.path.exists(optimizer_checkpoint_file) and config.resume_training: 109 | LOGGER.info('Load the optimizer checkpoint from %s' % optimizer_checkpoint_file) 110 | optimizer_checkpoint = torch.load( 111 | optimizer_checkpoint_file, map_location=lambda storage, loc: storage 112 | ) 113 | lastest_model_checkpoint_file = os.path.join( 114 | config.output_dir, 'ckpts', 'model_step_%d.pt' % optimizer_checkpoint['step'] 115 | ) 116 | if os.path.exists(lastest_model_checkpoint_file): 117 | LOGGER.info('Load the model checkpoint from %s' % lastest_model_checkpoint_file) 118 | model_checkpoint_file = lastest_model_checkpoint_file 119 | global_step = optimizer_checkpoint['step'] 120 | restart_epoch = global_step // len(data_loader) 121 | else: 122 | optimizer_checkpoint = None 123 | # to compute training statistics 124 | restart_epoch = config.restart_epoch 125 | global_step = restart_epoch * len(data_loader) 126 | 127 | if model_checkpoint_file is not None: 128 | checkpoint = torch.load( 129 | model_checkpoint_file, map_location=lambda storage, loc: storage) 130 | model.load_state_dict(checkpoint, strict=config.checkpoint_strict_load) 131 | 132 | model.train() 133 | # set_dropout(model, config.dropout) 134 | model = wrap_model(model, device, config.local_rank) 135 | 136 | # Prepare optimizer 137 | optimizer, init_lrs = build_optimizer(model, config) 138 | if optimizer_checkpoint is not None: 139 | optimizer.load_state_dict(optimizer_checkpoint['optimizer']) 140 | 141 | if default_gpu: 142 | pbar = tqdm(initial=global_step, total=config.num_train_steps) 143 | else: 144 | pbar = NoOp() 145 | 146 | LOGGER.info(f"***** Running training with {config.world_size} GPUs *****") 147 | LOGGER.info(" Batch size = %d", config.train_batch_size if config.local_rank == - 148 | 1 else config.train_batch_size * config.world_size) 149 | LOGGER.info(" Accumulate steps = %d", config.gradient_accumulation_steps) 150 | LOGGER.info(" Num steps = %d", config.num_train_steps) 151 | 152 | start_time = time.time() 153 | # quick hack for amp delay_unscale bug 154 | optimizer.zero_grad() 155 | optimizer.step() 156 | 157 | init_signal_handler() 158 | 159 | for epoch_id in range(restart_epoch, config.num_epochs): 160 | if global_step >= config.num_train_steps: 161 | break 162 | 163 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch 164 | pre_epoch(epoch_id) 165 | 166 | for step, batch in enumerate(data_loader): 167 | # forward pass 168 | losses, logits = model(batch, 
compute_loss=True) 169 | 170 | # backward pass 171 | if config.gradient_accumulation_steps > 1: # average loss 172 | losses['total'] = losses['total'] / \ 173 | config.gradient_accumulation_steps 174 | losses['total'].backward() 175 | 176 | acc = ((logits[..., -1].data.cpu() > 0) 177 | == batch['actions'][..., -1].cpu()).float() 178 | 179 | if 'step_masks' in batch: 180 | acc = torch.sum(acc * batch['step_masks']) / \ 181 | torch.sum(batch['step_masks']).cpu() 182 | else: 183 | acc = acc.mean().cpu() 184 | 185 | for key, value in losses.items(): 186 | TB_LOGGER.add_scalar( 187 | f'step/loss_{key}', value.item(), global_step) 188 | TB_LOGGER.add_scalar('step/acc_open', acc.item(), global_step) 189 | 190 | # optimizer update and logging 191 | if (step + 1) % config.gradient_accumulation_steps == 0: 192 | global_step += 1 193 | 194 | # learning rate scheduling 195 | lr_decay_rate = get_lr_sched_decay_rate(global_step, config) 196 | for kp, param_group in enumerate(optimizer.param_groups): 197 | param_group['lr'] = lr_this_step = init_lrs[kp] * lr_decay_rate 198 | TB_LOGGER.add_scalar('lr', lr_this_step, global_step) 199 | 200 | # log loss 201 | # NOTE: not gathered across GPUs for efficiency 202 | TB_LOGGER.step() 203 | 204 | # update model params 205 | if config.grad_norm != -1: 206 | grad_norm = torch.nn.utils.clip_grad_norm_( 207 | model.parameters(), config.grad_norm 208 | ) 209 | # print(step, name, grad_norm) 210 | # for k, v in model.named_parameters(): 211 | # if v.grad is not None: 212 | # v = torch.norm(v).data.item() 213 | # print(k, v) 214 | TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) 215 | optimizer.step() 216 | optimizer.zero_grad() 217 | pbar.update(1) 218 | 219 | if global_step % config.log_steps == 0: 220 | # monitor training throughput 221 | LOGGER.info( 222 | f'==============Epoch {epoch_id} Step {global_step}===============') 223 | LOGGER.info(', '.join(['%s:%.4f' % ( 224 | lk, lv.item()) for lk, lv in losses.items()] + ['acc:%.2f' % (acc*100)])) 225 | LOGGER.info('===============================================') 226 | 227 | if global_step % config.save_steps == 0: 228 | model_saver.save(model, global_step, optimizer=optimizer, rewrite_optimizer=True) 229 | 230 | if global_step >= config.num_train_steps: 231 | break 232 | 233 | if global_step % config.save_steps != 0: 234 | LOGGER.info( 235 | f'==============Epoch {epoch_id} Step {global_step}===============') 236 | LOGGER.info(', '.join(['%s:%.4f' % (lk, lv.item()) 237 | for lk, lv in losses.items()] + ['acc:%.2f' % (acc*100)])) 238 | LOGGER.info('===============================================') 239 | model_saver.save(model, global_step, optimizer=optimizer, rewrite_optimizer=True) 240 | 241 | 242 | def build_args(): 243 | parser = argparse.ArgumentParser() 244 | parser.add_argument( 245 | "--exp-config", 246 | type=str, 247 | required=True, 248 | help="path to config yaml containing info about experiment", 249 | ) 250 | parser.add_argument('--sleep_time', type=float, default=0, help='hour') 251 | parser.add_argument('--restart_epoch', type=int, default=0) 252 | parser.add_argument( 253 | "opts", 254 | default=None, 255 | nargs=argparse.REMAINDER, 256 | help="Modify config options from command line", 257 | ) 258 | args = parser.parse_args() 259 | 260 | if args.sleep_time > 0: 261 | time.sleep(args.sleep_time * 3600) 262 | 263 | config = get_config(args.exp_config, args.opts) 264 | 265 | config.defrost() 266 | config.restart_epoch = args.restart_epoch 267 | config.freeze() 268 | 269 | for i in 
range(len(config.CMD_TRAILING_OPTS)): 270 | if config.CMD_TRAILING_OPTS[i] == "DATASET.taskvars": 271 | if type(config.CMD_TRAILING_OPTS[i + 1]) is str: 272 | config.CMD_TRAILING_OPTS[i + 273 | 1] = [config.CMD_TRAILING_OPTS[i + 1]] 274 | 275 | if os.path.exists(config.output_dir) and os.listdir(config.output_dir): 276 | LOGGER.warning( 277 | "Output directory ({}) already exists and is not empty.".format( 278 | config.output_dir 279 | ) 280 | ) 281 | 282 | return config 283 | 284 | 285 | if __name__ == '__main__': 286 | config = build_args() 287 | main(config) 288 | -------------------------------------------------------------------------------- /polarnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/polarnet/882b7ef5b82ee4c7779cdd0020f58e919e0f8bce/polarnet/utils/__init__.py -------------------------------------------------------------------------------- /polarnet/utils/coord_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import einops 7 | import json 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | 11 | def convert_gripper_pose_world_to_image(obs, camera: str) -> Tuple[int, int]: 12 | '''Convert the gripper pose from world coordinate system to image coordinate system. 13 | image[v, u] is the gripper location. 14 | ''' 15 | extrinsics_44 = obs.misc[f"{camera}_camera_extrinsics"].astype(np.float32) 16 | extrinsics_44 = np.linalg.inv(extrinsics_44) 17 | 18 | intrinsics_33 = obs.misc[f"{camera}_camera_intrinsics"].astype(np.float32) 19 | intrinsics_34 = np.concatenate([intrinsics_33, np.zeros((3, 1), dtype=np.float32)], 1) 20 | 21 | gripper_pos_31 = obs.gripper_pose[:3].astype(np.float32)[:, None] 22 | gripper_pos_41 = np.concatenate([gripper_pos_31, np.ones((1, 1), dtype=np.float32)], 0) 23 | 24 | points_cam_41 = extrinsics_44 @ gripper_pos_41 25 | 26 | proj_31 = intrinsics_34 @ points_cam_41 27 | proj_3 = proj_31[:, 0] 28 | 29 | u = int((proj_3[0] / proj_3[2]).round()) 30 | v = int((proj_3[1] / proj_3[2]).round()) 31 | 32 | return u, v 33 | 34 | 35 | def quaternion_to_discrete_euler(quaternion, resolution: int): 36 | euler = R.from_quat(quaternion).as_euler('xyz', degrees=True) + 180 37 | assert np.min(euler) >= 0 and np.max(euler) <= 360 38 | disc = np.around((euler / resolution)).astype(int) 39 | disc[disc == int(360 / resolution)] = 0 40 | return disc 41 | 42 | 43 | def discrete_euler_to_quaternion(discrete_euler, resolution: int): 44 | euler = (discrete_euler * resolution) - 180 45 | return R.from_euler('xyz', euler, degrees=True).as_quat() 46 | 47 | 48 | def euler_to_quat(euler, degrees): 49 | rotation = R.from_euler("xyz", euler, degrees=degrees) 50 | return rotation.as_quat() 51 | 52 | 53 | def quat_to_euler(quat, degrees): 54 | rotation = R.from_quat(quat) 55 | return rotation.as_euler("xyz", degrees=degrees) -------------------------------------------------------------------------------- /polarnet/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Distributed tools 3 | """ 4 | import os 5 | from pathlib import Path 6 | from pprint import pformat 7 | import pickle 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def set_local_rank(opts) -> int: 14 | if os.environ.get("LOCAL_RANK", "") != "": 15 | opts.local_rank = int(os.environ["LOCAL_RANK"]) 16 | 
elif os.environ.get("SLURM_LOCALID", "") != "": 17 | opts.local_rank = int(os.environ["SLURM_LOCALID"]) 18 | else: 19 | opts.local_rank = -1 20 | return opts.local_rank 21 | 22 | 23 | def load_init_param(opts): 24 | """ 25 | Load parameters for the rendezvous distributed procedure 26 | """ 27 | # num of gpus per node 28 | # WARNING: this assumes that each node has the same number of GPUs 29 | if os.environ.get("SLURM_NTASKS_PER_NODE", "") != "": 30 | num_gpus = int(os.environ['SLURM_NTASKS_PER_NODE']) 31 | else: 32 | num_gpus = torch.cuda.device_count() 33 | 34 | # world size 35 | if os.environ.get("WORLD_SIZE", "") != "": 36 | world_size = int(os.environ["WORLD_SIZE"]) 37 | elif os.environ.get("SLURM_JOB_NUM_NODES", ""): 38 | num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"]) 39 | world_size = num_nodes * num_gpus 40 | else: 41 | raise RuntimeError("Can't find any world size") 42 | opts.world_size = world_size 43 | 44 | # rank 45 | if os.environ.get("RANK", "") != "": 46 | # pytorch.distributed.launch provides this variable no matter what 47 | opts.rank = int(os.environ["RANK"]) 48 | elif os.environ.get("SLURM_PROCID", "") != "": 49 | opts.rank = int(os.environ["SLURM_PROCID"]) 50 | else: 51 | if os.environ.get("NODE_RANK", "") != "": 52 | opts.node_rank = int(os.environ["NODE_RANK"]) 53 | elif os.environ.get("SLURM_NODEID", "") != "": 54 | opts.node_rank = int(os.environ["SLURM_NODEID"]) 55 | else: 56 | raise RuntimeError("Can't find any rank or node rank") 57 | 58 | opts.rank = opts.local_rank + opts.node_rank * num_gpus # global rank = local rank + node offset 59 | 60 | init_method = "env://" # need to specify MASTER_ADDR and MASTER_PORT 61 | 62 | return { 63 | "backend": "nccl", 64 | "init_method": init_method, 65 | "rank": opts.rank, 66 | "world_size": world_size, 67 | } 68 | 69 | 70 | def init_distributed(opts): 71 | init_param = load_init_param(opts) 72 | rank = init_param["rank"] 73 | print(f"Init distributed {init_param['rank']} - {init_param['world_size']}") 74 | 75 | dist.init_process_group(**init_param) 76 | 77 | 78 | def is_default_gpu(opts) -> bool: 79 | return opts.local_rank == -1 or dist.get_rank() == 0 80 | 81 | 82 | def is_dist_avail_and_initialized(): 83 | if not dist.is_available(): 84 | return False 85 | if not dist.is_initialized(): 86 | return False 87 | return True 88 | 89 | def get_world_size(): 90 | if not is_dist_avail_and_initialized(): 91 | return 1 92 | return dist.get_world_size() 93 | 94 | def all_gather(data): 95 | """ 96 | Run all_gather on arbitrary picklable data (not necessarily tensors) 97 | Args: 98 | data: any picklable object 99 | Returns: 100 | list[data]: list of data gathered from each rank 101 | """ 102 | world_size = get_world_size() 103 | if world_size == 1: 104 | return [data] 105 | 106 | # serialize the data to a byte Tensor 107 | buffer = pickle.dumps(data) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to("cuda") 110 | 111 | # obtain Tensor size of each rank 112 | local_size = torch.tensor([tensor.numel()], device="cuda") 113 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 114 | dist.all_gather(size_list, local_size) 115 | size_list = [int(size.item()) for size in size_list] 116 | max_size = max(size_list) 117 | 118 | # receiving Tensor from all ranks 119 | # we pad the tensor because torch all_gather does not support 120 | # gathering tensors of different shapes 121 | tensor_list = [] 122 | for _ in size_list: 123 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 124 | if local_size != 
max_size: 125 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 126 | tensor = torch.cat((tensor, padding), dim=0) 127 | dist.all_gather(tensor_list, tensor) 128 | 129 | data_list = [] 130 | for size, tensor in zip(size_list, tensor_list): 131 | buffer = tensor.cpu().numpy().tobytes()[:size] 132 | data_list.append(pickle.loads(buffer)) 133 | 134 | return data_list 135 | 136 | 137 | def reduce_dict(input_dict, average=True): 138 | """ 139 | Args: 140 | input_dict (dict): all the values will be reduced 141 | average (bool): whether to do average or sum 142 | Reduce the values in the dictionary from all processes so that all processes 143 | have the averaged results. Returns a dict with the same fields as 144 | input_dict, after reduction. 145 | """ 146 | world_size = get_world_size() 147 | if world_size < 2: 148 | return input_dict 149 | with torch.no_grad(): 150 | names = [] 151 | values = [] 152 | # sort the keys so that they are consistent across processes 153 | for k in sorted(input_dict.keys()): 154 | names.append(k) 155 | values.append(input_dict[k]) 156 | values = torch.stack(values, dim=0) 157 | dist.all_reduce(values) 158 | if average: 159 | values /= world_size 160 | reduced_dict = {k: v for k, v in zip(names, values)} 161 | return reduced_dict 162 | 163 | 164 | -------------------------------------------------------------------------------- /polarnet/utils/keystep_detection.py: -------------------------------------------------------------------------------- 1 | '''Identify way-point in each RLBench Demo 2 | ''' 3 | 4 | from typing import List, Dict, Optional, Sequence, Tuple, TypedDict, Union, Any 5 | 6 | import numpy as np 7 | 8 | from rlbench.demo import Demo 9 | 10 | 11 | def _is_stopped(demo, i, obs, stopped_buffer): 12 | next_is_not_final = (i < (len(demo) - 2)) 13 | gripper_state_no_change = i < (len(demo) - 2) and ( 14 | obs.gripper_open == demo[i + 1].gripper_open 15 | and obs.gripper_open == demo[max(0, i - 1)].gripper_open 16 | and demo[max(0, i - 2)].gripper_open == demo[max(0, i - 1)].gripper_open 17 | ) 18 | small_delta = np.allclose(obs.joint_velocities, 0, atol=0.1) 19 | stopped = ( 20 | stopped_buffer <= 0 21 | and small_delta 22 | and next_is_not_final 23 | and gripper_state_no_change 24 | ) 25 | return stopped 26 | 27 | 28 | def keypoint_discovery(demo: Demo) -> List[int]: 29 | episode_keypoints = [] 30 | prev_gripper_open = demo[0].gripper_open 31 | stopped_buffer = 0 32 | for i, obs in enumerate(demo): 33 | stopped = _is_stopped(demo, i, obs, stopped_buffer) 34 | stopped_buffer = 4 if stopped else stopped_buffer - 1 35 | # If change in gripper, or end of episode. 36 | last = i == (len(demo) - 1) 37 | if i != 0 and (obs.gripper_open != prev_gripper_open or last or stopped): 38 | episode_keypoints.append(i) 39 | prev_gripper_open = obs.gripper_open 40 | if ( 41 | len(episode_keypoints) > 1 42 | and (episode_keypoints[-1] - 1) == episode_keypoints[-2] 43 | ): 44 | episode_keypoints.pop(-2) 45 | 46 | return episode_keypoints 47 | 48 | -------------------------------------------------------------------------------- /polarnet/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | helper for logging 6 | NOTE: loggers are global objects; use with caution 7 | """ 8 | import logging 9 | import math 10 | 11 | import tensorboardX 12 | 13 | 14 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 15 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 16 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 17 | LOGGER = logging.getLogger('__main__') # this is the global logger 18 | 19 | 20 | def add_log_to_file(log_path): 21 | fh = logging.FileHandler(log_path) 22 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 23 | fh.setFormatter(formatter) 24 | LOGGER.addHandler(fh) 25 | 26 | 27 | class TensorboardLogger(object): 28 | def __init__(self): 29 | self._logger = None 30 | self._global_step = 0 31 | 32 | def create(self, path): 33 | self._logger = tensorboardX.SummaryWriter(path) 34 | 35 | def noop(self, *args, **kwargs): 36 | return 37 | 38 | def step(self): 39 | self._global_step += 1 40 | 41 | @property 42 | def global_step(self): 43 | return self._global_step 44 | 45 | def log_scalar_dict(self, log_dict, prefix=''): 46 | """ log a dictionary of scalar values""" 47 | if self._logger is None: 48 | return 49 | if prefix: 50 | prefix = f'{prefix}_' 51 | for name, value in log_dict.items(): 52 | if isinstance(value, dict): 53 | self.log_scalar_dict(value, 54 | prefix=f'{prefix}{name}') 55 | else: 56 | self._logger.add_scalar(f'{prefix}{name}', value, 57 | self._global_step) 58 | 59 | def __getattr__(self, name): 60 | if self._logger is None: 61 | return self.noop 62 | return self._logger.__getattribute__(name) 63 | 64 | 65 | TB_LOGGER = TensorboardLogger() 66 | 67 | 68 | class RunningMeter(object): 69 | """ running meter of a scalar value 70 | (useful for monitoring training loss) 71 | """ 72 | def __init__(self, name, val=None, smooth=0.99): 73 | self._name = name 74 | self._sm = smooth 75 | self._val = val 76 | 77 | def __call__(self, value): 78 | val = (value if self._val is None 79 | else value*(1-self._sm) + self._val*self._sm) 80 | if not math.isnan(val): 81 | self._val = val 82 | 83 | def __str__(self): 84 | return f'{self._name}: {self._val:.4f}' 85 | 86 | @property 87 | def val(self): 88 | if self._val is None: 89 | return 0 90 | return self._val 91 | 92 | @property 93 | def name(self): 94 | return self._name 95 | -------------------------------------------------------------------------------- /polarnet/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from typing import Tuple, Union, Dict, Any 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | 10 | from .distributed import init_distributed, set_local_rank 11 | from .logger import LOGGER 12 | 13 | 14 | def set_random_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | def set_dropout(model, drop_p): 21 | for name, module in model.named_modules(): 22 | # we might want to tune dropout for smaller dataset 23 | if isinstance(module, torch.nn.Dropout): 24 | if module.p != drop_p: 25 | module.p = drop_p 26 | LOGGER.info(f'{name} set to {drop_p}') 27 | 28 | 29 | def set_cuda(opts) -> Tuple[bool, int, torch.device]: 30 | """ 31 | Initialize CUDA for distributed computing 32 | """ 33 | set_local_rank(opts) 34 | 35 | if not torch.cuda.is_available(): 36 | assert opts.local_rank == -1, opts.local_rank 37 | 
return True, 0, torch.device("cpu") 38 | 39 | # get device settings 40 | if opts.local_rank != -1: 41 | init_distributed(opts) 42 | torch.cuda.set_device(opts.local_rank) 43 | device = torch.device("cuda", opts.local_rank) 44 | n_gpu = 1 45 | default_gpu = dist.get_rank() == 0 46 | if default_gpu: 47 | LOGGER.info(f"Found {dist.get_world_size()} GPUs") 48 | else: 49 | default_gpu = True 50 | device = torch.device("cuda") 51 | n_gpu = torch.cuda.device_count() 52 | 53 | return default_gpu, n_gpu, device 54 | 55 | 56 | def wrap_model( 57 | model: torch.nn.Module, device: torch.device, local_rank: int 58 | ) -> torch.nn.Module: 59 | model.to(device) 60 | 61 | if local_rank != -1: 62 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) 63 | # At the time of DDP wrapping, parameters and buffers (i.e., model.state_dict()) 64 | # on rank0 are broadcasted to all other ranks. 65 | elif torch.cuda.device_count() > 1: 66 | LOGGER.info("Using data parallel") 67 | model = torch.nn.DataParallel(model) 68 | 69 | return model 70 | 71 | 72 | class NoOp(object): 73 | """ useful for distributed training No-Ops """ 74 | def __getattr__(self, name): 75 | return self.noop 76 | 77 | def noop(self, *args, **kwargs): 78 | return 79 | 80 | -------------------------------------------------------------------------------- /polarnet/utils/ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def pad_tensors(tensors, lens=None, pad=0): 5 | """B x [T, ...] torch tensors""" 6 | if lens is None: 7 | lens = [t.size(0) for t in tensors] 8 | max_len = max(lens) 9 | bs = len(tensors) 10 | hid = list(tensors[0].size()[1:]) 11 | size = [bs, max_len] + hid 12 | 13 | dtype = tensors[0].dtype 14 | output = torch.zeros(*size, dtype=dtype) 15 | if pad: 16 | output.data.fill_(pad) 17 | for i, (t, l) in enumerate(zip(tensors, lens)): 18 | output.data[i, :l, ...] = t.data 19 | return output 20 | 21 | def pad_tensors_wgrad(tensors, lens=None, value=0): 22 | """B x [T, ...] 
torch tensors""" 23 | if lens is None: 24 | lens = [t.size(0) for t in tensors] 25 | max_len = max(lens) 26 | batch_size = len(tensors) 27 | hid = list(tensors[0].size()[1:]) 28 | 29 | device = tensors[0].device 30 | dtype = tensors[0].dtype 31 | 32 | output = [] 33 | for i in range(batch_size): 34 | if lens[i] < max_len: 35 | tmp = torch.cat( 36 | [tensors[i], torch.zeros([max_len-lens[i]]+hid, dtype=dtype).to(device) + value], 37 | dim=0 38 | ) 39 | else: 40 | tmp = tensors[i] 41 | output.append(tmp) 42 | output = torch.stack(output, 0) 43 | return output 44 | 45 | 46 | def gen_seq_masks(seq_lens, max_len=None): 47 | """ 48 | Args: 49 | seq_lens: list or nparray int, shape=(N, ) 50 | Returns: 51 | masks: nparray, shape=(N, L), padded=0 52 | """ 53 | seq_lens = np.array(seq_lens) 54 | if max_len is None: 55 | max_len = max(seq_lens) 56 | if max_len == 0: 57 | return np.zeros((len(seq_lens), 0), dtype=bool) 58 | batch_size = len(seq_lens) 59 | masks = np.arange(max_len).reshape(-1, max_len).repeat(batch_size, 0) 60 | masks = masks < seq_lens.reshape(-1, 1) 61 | return masks 62 | 63 | 64 | def extend_neg_masks(masks, dtype=None): 65 | """ 66 | mask from (N, L) into (N, 1(H), 1(L), L) and make it negative 67 | """ 68 | if dtype is None: 69 | dtype = torch.float 70 | extended_masks = masks.unsqueeze(1).unsqueeze(2) 71 | extended_masks = extended_masks.to(dtype=dtype) 72 | extended_masks = (1.0 - extended_masks) * -10000.0 73 | return extended_masks -------------------------------------------------------------------------------- /polarnet/utils/recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Type 3 | import numpy as np 4 | 5 | from pathlib import Path 6 | from typing import Tuple, Dict, List 7 | from pyrep.objects.dummy import Dummy 8 | from pyrep.objects.vision_sensor import VisionSensor 9 | 10 | 11 | class CameraMotion(object): 12 | def __init__(self, cam: VisionSensor): 13 | self.cam = cam 14 | 15 | def step(self): 16 | raise NotImplementedError() 17 | 18 | def save_pose(self): 19 | self._prev_pose = self.cam.get_pose() 20 | 21 | def restore_pose(self): 22 | self.cam.set_pose(self._prev_pose) 23 | 24 | 25 | class CircleCameraMotion(CameraMotion): 26 | 27 | def __init__(self, cam: VisionSensor, origin: Dummy, speed: float): 28 | super().__init__(cam) 29 | self.origin = origin 30 | self.speed = speed # in radians 31 | 32 | def step(self): 33 | self.origin.rotate([0, 0, self.speed]) 34 | 35 | 36 | class StaticCameraMotion(CameraMotion): 37 | 38 | def __init__(self, cam: VisionSensor): 39 | super().__init__(cam) 40 | 41 | def step(self): 42 | pass 43 | 44 | class AttachedCameraMotion(CameraMotion): 45 | 46 | def __init__(self, cam: VisionSensor, parent_cam: VisionSensor): 47 | super().__init__(cam) 48 | self.parent_cam = parent_cam 49 | 50 | def step(self): 51 | self.cam.set_pose(self.parent_cam.get_pose()) 52 | 53 | 54 | class TaskRecorder(object): 55 | 56 | def __init__(self, cams_motion: Dict[str, CameraMotion], fps=30): 57 | self._cams_motion = cams_motion 58 | self._fps = fps 59 | self._snaps = {cam_name: [] for cam_name in self._cams_motion.keys()} 60 | 61 | def take_snap(self): 62 | for cam_name, cam_motion in self._cams_motion.items(): 63 | cam_motion.step() 64 | self._snaps[cam_name].append( 65 | (cam_motion.cam.capture_rgb() * 255.).astype(np.uint8)) 66 | 67 | def save(self, path): 68 | print('Converting to video ...') 69 | path = Path(path) 70 | path.mkdir(exist_ok=True) 71 | # OpenCV QT version can 
conflict with PyRep, so import here 72 | import cv2 73 | for cam_name, cam_motion in self._cams_motion.items(): 74 | video = cv2.VideoWriter( 75 | str(path / f"{cam_name}.avi"), cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), self._fps, 76 | tuple(cam_motion.cam.get_resolution())) 77 | for image in self._snaps[cam_name]: 78 | video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 79 | video.release() 80 | 81 | self._snaps = {cam_name: [] for cam_name in self._cams_motion.keys()} 82 | 83 | def clean_buffer(self): 84 | for cam_name, cam_motion in self._cams_motion.items(): 85 | cam_motion.step() 86 | self._snaps[cam_name] = [] 87 | -------------------------------------------------------------------------------- /polarnet/utils/save.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | saving utilities 6 | """ 7 | import json 8 | import os 9 | import torch 10 | 11 | 12 | def save_training_meta(args): 13 | os.makedirs(os.path.join(args.output_dir, 'logs'), exist_ok=True) 14 | os.makedirs(os.path.join(args.output_dir, 'ckpts'), exist_ok=True) 15 | 16 | with open(os.path.join(args.output_dir, 'logs', 'training_config.yaml'), 'w') as writer: 17 | args_str = args.dump() 18 | print(args_str, file=writer) 19 | 20 | class ModelSaver(object): 21 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 22 | self.output_dir = output_dir 23 | self.prefix = prefix 24 | self.suffix = suffix 25 | 26 | def save(self, model, step, optimizer=None, rewrite_optimizer=False): 27 | output_model_file = os.path.join(self.output_dir, 28 | f"{self.prefix}_{step}.{self.suffix}") 29 | state_dict = {} 30 | for k, v in model.state_dict().items(): 31 | if k.startswith('module.'): 32 | k = k[7:] 33 | if isinstance(v, torch.Tensor): 34 | state_dict[k] = v.cpu() 35 | else: 36 | state_dict[k] = v 37 | torch.save(state_dict, output_model_file) 38 | if optimizer is not None: 39 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 40 | if hasattr(optimizer, '_amp_stash'): 41 | pass # TODO fp16 optimizer 42 | if rewrite_optimizer: 43 | torch.save(dump, f'{self.output_dir}/train_state_latest.pt') 44 | else: 45 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 46 | 47 | 48 | -------------------------------------------------------------------------------- /polarnet/utils/slurm_requeue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import signal 4 | import sys 5 | import logging 6 | 7 | from pathlib import Path 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def sig_handler(signum, frame): 12 | logger.warning("Signal handler called with signal " + str(signum)) 13 | prod_id = int(os.environ['SLURM_PROCID']) 14 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 15 | if prod_id == 0: 16 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 17 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 18 | else: 19 | logger.warning("Not the master process, no need to requeue.") 20 | sys.exit(-1) 21 | 22 | 23 | def init_signal_handler(): 24 | """ 25 | Handle signals sent by SLURM for time limit. 
26 | """ 27 | signal.signal(signal.SIGUSR1, sig_handler) 28 | logger.warning("Signal handler installed.") 29 | -------------------------------------------------------------------------------- /polarnet/utils/utils.py: -------------------------------------------------------------------------------- 1 | import polarnet 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import random 7 | import torch 8 | 9 | def set_random_seed(seed): 10 | torch.manual_seed(seed) 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | 14 | def get_assets_dir(): 15 | return str(Path(polarnet.__file__).parent / "assets") 16 | 17 | def get_expr_dirs(output_dir): 18 | log_dir = os.path.join(output_dir, 'logs') 19 | ckpt_dir = os.path.join(output_dir, 'ckpts') 20 | pred_dir = os.path.join(output_dir, 'preds') 21 | 22 | os.makedirs(log_dir, exist_ok=True) 23 | os.makedirs(ckpt_dir, exist_ok=True) 24 | os.makedirs(pred_dir, exist_ok=True) 25 | 26 | return log_dir, ckpt_dir, pred_dir 27 | 28 | -------------------------------------------------------------------------------- /polarnet/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | def plot_attention( 7 | attentions: torch.Tensor, rgbs: torch.Tensor, pcds: torch.Tensor, dest: Path 8 | ) -> plt.Figure: 9 | attentions = attentions.detach().cpu() 10 | rgbs = rgbs.detach().cpu() 11 | pcds = pcds.detach().cpu() 12 | 13 | ep_dir = dest.parent 14 | ep_dir.mkdir(exist_ok=True, parents=True) 15 | name = dest.stem 16 | ext = dest.suffix 17 | 18 | # plt.figure(figsize=(10, 8)) 19 | num_cameras = len(attentions) 20 | for i, (a, rgb, pcd) in enumerate(zip(attentions, rgbs, pcds)): 21 | # plt.subplot(num_cameras, 4, i * 4 + 1) 22 | plt.imshow(a.permute(1, 2, 0).log()) 23 | plt.axis("off") 24 | plt.colorbar() 25 | plt.savefig(ep_dir / f"{name}-{i}-attn{ext}", bbox_inches="tight") 26 | plt.tight_layout() 27 | plt.clf() 28 | 29 | # plt.subplot(num_cameras, 4, i * 4 + 2) 30 | # plt.imshow(a.permute(1, 2, 0)) 31 | # plt.axis('off') 32 | # plt.colorbar() 33 | # plt.tight_layout() 34 | # plt.clf() 35 | 36 | # plt.subplot(num_cameras, 4, i * 4 + 3) 37 | plt.imshow(((rgb + 1) / 2).permute(1, 2, 0)) 38 | plt.axis("off") 39 | plt.savefig(ep_dir / f"{name}-{i}-rgb{ext}", bbox_inches="tight") 40 | plt.tight_layout() 41 | plt.clf() 42 | 43 | pcd_norm = (pcd - pcd.min(0).values) / (pcd.max(0).values - pcd.min(0).values) 44 | # plt.subplot(num_cameras, 4, i * 4 + 4) 45 | plt.imshow(pcd_norm.permute(1, 2, 0)) 46 | plt.axis("off") 47 | plt.savefig(ep_dir / f"{name}-{i}-pcd{ext}", bbox_inches="tight") 48 | plt.tight_layout() 49 | plt.clf() 50 | 51 | return plt.gcf() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | 10 | def read_requirements_file(filename): 11 | req_file_path = path.join(path.dirname(path.realpath(__file__)), filename) 12 | with open(req_file_path) as f: 13 | return [line.strip() for line in f] 14 | 15 | 16 | setup( 17 | name="polarnet", 18 | version="0.0.1", 19 | description="PolarNet: 3D Point Clouds for Language-Guided Robotic Manipulation", 20 | packages=find_packages(), 21 | ) 
) --------------------------------------------------------------------------------
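The `summarize_*` scripts above all consume the same JSON-lines evaluation format: one record per evaluated task and variation carrying `checkpoint`, `task`, `variation` and `sr` (success rate) fields, plus `num_demos` for the PerAct official split. A minimal sketch of that contract, using a made-up record list and the hypothetical path `results.jsonl`, aggregated the same way `summarize_tst_results.py` does:

```python
import collections

import jsonlines
import numpy as np

# Hypothetical result file with made-up numbers; real files are produced
# during evaluation.
records = [
    {"checkpoint": "ckpts/model_step_100000.pt", "task": "real_open_drawer",
     "variation": 0, "sr": 0.8, "num_demos": 5},
    {"checkpoint": "ckpts/model_step_100000.pt", "task": "real_open_drawer",
     "variation": 1, "sr": 0.6, "num_demos": 5},
]
with jsonlines.open("results.jsonl", "w") as f:
    f.write_all(records)

# Mirror summarize_tst_results.py: collect sr per checkpoint and task ...
results = collections.defaultdict(dict)
with jsonlines.open("results.jsonl", "r") as f:
    for item in f:
        results[item["checkpoint"]].setdefault(item["task"], [])
        results[item["checkpoint"]][item["task"]].append(item["sr"])

# ... then average over variations within a task, and over tasks.
for ckpt, res in results.items():
    task_srs = [(task, np.mean(v)) for task, v in res.items()]
    print(ckpt, task_srs)
    print("average: %.2f" % (np.mean([x[1] for x in task_srs]) * 100))
```

Averaging per task before averaging across tasks keeps tasks with many variations from dominating the overall score.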
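The rotation discretization in `polarnet/utils/coord_transforms.py` is invertible up to half a bin per axis: a quaternion is mapped to per-axis Euler bins of `resolution` degrees, and mapping back yields the bin-center rotation. A round-trip sketch (the 5-degree resolution is an illustrative choice, not a value mandated by the code):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

from polarnet.utils.coord_transforms import (
    discrete_euler_to_quaternion,
    quaternion_to_discrete_euler,
)

resolution = 5  # degrees per Euler bin (illustrative)
quat = R.from_euler("xyz", [12.0, -47.0, 133.0], degrees=True).as_quat()

disc = quaternion_to_discrete_euler(quat, resolution)   # three integer bin ids
recon = discrete_euler_to_quaternion(disc, resolution)  # bin-center quaternion

# Quantization error is bounded by resolution / 2 per axis.
err_deg = np.degrees((R.from_quat(recon) * R.from_quat(quat).inv()).magnitude())
print(disc, err_deg)
```

With 5-degree bins each axis yields 72 discrete classes, which is what the `disc[disc == int(360 / resolution)] = 0` wrap-around in the code accounts for.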
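Likewise, `keypoint_discovery` in `polarnet/utils/keystep_detection.py` only needs a sequence of observations exposing `gripper_open` and `joint_velocities`, so the heuristic (a keypoint on every gripper-state change, on the final frame, or when the arm comes to rest with an unchanged gripper, debounced by a 4-frame buffer) can be probed without launching the simulator; RLBench only needs to be importable. A toy sketch with a fabricated demo:

```python
from collections import namedtuple

import numpy as np

from polarnet.utils.keystep_detection import keypoint_discovery

Obs = namedtuple("Obs", ["gripper_open", "joint_velocities"])

# Fabricated trajectory: move, pause, close the gripper, move again, stop.
demo = (
    [Obs(1.0, np.ones(7))] * 5     # moving with the gripper open
    + [Obs(1.0, np.zeros(7))] * 6  # arm at rest -> 'stopped' keypoint
    + [Obs(0.0, np.ones(7))] * 5   # gripper closes -> keypoint
    + [Obs(0.0, np.zeros(7))] * 3  # episode end -> final keypoint
)
print(keypoint_discovery(demo))  # [5, 11, 16, 18]
```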