├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── evaluation ├── README.md ├── dev.py ├── eval.sh ├── r3meval │ ├── __init__.py │ ├── core │ │ ├── config │ │ │ ├── BC_config.yaml │ │ │ ├── eval_config.yaml │ │ │ └── hydra │ │ │ │ ├── launcher │ │ │ │ └── local.yaml │ │ │ │ └── output │ │ │ │ └── local.yaml │ │ ├── eval.sh │ │ ├── eval_loop.py │ │ ├── hydra_eval_launcher.py │ │ ├── hydra_launcher.py │ │ ├── run.sh │ │ └── train_loop.py │ └── utils │ │ ├── __init__.py │ │ ├── behavior_cloning.py │ │ ├── fc_network.py │ │ ├── gaussian_mlp.py │ │ ├── gym_env.py │ │ ├── logger.py │ │ ├── obs_wrappers.py │ │ ├── sampling.py │ │ ├── tensor_utils.py │ │ └── visualizations.py ├── run.sh └── setup.py ├── plots └── mvp_dino_r3m.py ├── r3m ├── __init__.py ├── cfgs │ ├── config_rep.yaml │ └── hydra │ │ ├── launcher │ │ └── local.yaml │ │ └── output │ │ └── local.yaml ├── example.py ├── models │ ├── models_language.py │ └── models_r3m.py ├── r3m_base.yaml ├── train_representation.py ├── trainer.py └── utils │ ├── data_loaders.py │ ├── logger.py │ └── utils.py ├── setup.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | __pycache__/ 3 | .nfs* 4 | *.ipynb* 5 | *.egg* 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to r3m 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. 
Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to r3m, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segmenting-Features 2 | 3 | Code for evaluating different visual pre-training strategies. 4 | 5 | ## Installation 6 | 7 | To install R3M from an existing conda environment, simply run `pip install -e .` from this directory. 8 | 9 | You can alternatively build a fresh conda env from the r3m_base.yaml file [here](https://github.com/facebookresearch/r3m/blob/main/r3m/r3m_base.yaml) and then install from this directory with `pip install -e .` 10 | 11 | ## Running Evaluation 12 | 13 | To train policies on top of each representation: 14 | ``` 15 | cd evaluation/r3meval/core/ 16 | ./run.sh 17 | ``` 18 | 19 | ## Testing Transfer 20 | 21 | To test transfer with kitchen shift: 22 | ``` 23 | cd evaluation/r3meval/core/ 24 | ./eval.sh 25 | ``` 26 | 27 | ## License 28 | 29 | R3M is licensed under the MIT license. 30 | 31 | ## Acknowledgements 32 | 33 | Adapted from the [R3M](https://github.com/facebookresearch/r3m) codebase. 34 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation Code for R3M 2 | 3 | This codebase contains the evaluation code from the paper [R3M: A Universal Visual Representation for Robot Manipulation](https://sites.google.com/view/robot-r3m/). 4 | 5 | It trains policies from pixels with behavior cloning using pre-collected demonstrations, evaluating the policies in the environment at regular intervals.
It allows for selecting different visual representations to use during imitation. 6 | 7 | ## Environment Installation 8 | 9 | The first step to running the code involves installing the evaluation environments. 10 | 11 | For metaworld environments, install the environments by cloning this [fork of the metaworld repo](https://github.com/suraj-nair-1/metaworld) and installing via `pip install -e .` 12 | 13 | In order to install the Franka Kitchen and Adroit environments, first install the `mjrl` repo using instructions [here](https://github.com/aravindr93/mjrl). 14 | 15 | Then, install the `mj_envs` repo as described in [this tag](https://github.com/vikashplus/mj_envs/releases/tag/v0.0.5). 16 | 17 | ## Installing R3M 18 | 19 | To use the R3M model, simply follow the installation process in the parent directory [here](https://github.com/facebookresearch/r3m/tree/eval). 20 | 21 | ## Downloading Demonstration Data 22 | 23 | All demonstrations are located [here](https://drive.google.com/drive/folders/108VW5t5JV8uNtkWvfZxEvY2P2QkC_tsf?usp=sharing). Then change the path [here](https://github.com/facebookresearch/r3m/blob/eval/evaluation/r3meval/core/train_loop.py#L99) to point to where the demonstration data is located. Make sure the data is saved with the same folder structure as on the google drive, e.g. `/final_paths_multiview_meta_200//.pickle`. 24 | 25 | ## Install and Run Eval Code 26 | 27 | If the above was all done correctly, you should be able to simply run `pip install -e .` in this directory. 
28 | 29 | ## Verifying Correct Installation 30 | 31 | While running all experiments can be time consuming, a simple check to make sure things are behaving as expected is to download the demos for the kitchen sliding door task, and run: 32 | 33 | ``` 34 | python hydra_launcher.py hydra/launcher=local hydra/output=local env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true embedding=resnet50 num_demos=5 env_kwargs.load_path=r3m bc_kwargs.finetune=false proprio=9 job_name=r3m_repro seed=125 35 | ``` 36 | and 37 | ``` 38 | python hydra_launcher.py hydra/launcher=local hydra/output=local env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true embedding=resnet50 num_demos=5 env_kwargs.load_path=clip bc_kwargs.finetune=false proprio=9 job_name=r3m_repro seed=125 39 | ``` 40 | 41 | You should see R3M get ~60% success on the first eval, while CLIP will get ~30%. 42 | 43 | 44 | ## Commands for All Experiments 45 | 46 | For running kitchen environments run: 47 | ``` 48 | python hydra_launcher.py --multirun hydra/launcher=local hydra/output=local env=["kitchen_knob1_on-v3","kitchen_light_on-v3","kitchen_sdoor_open-v3","kitchen_ldoor_open-v3","kitchen_micro_open-v3"] camera=["default","left_cap2","right_cap2"] pixel_based=true embedding=resnet50 num_demos=25 env_kwargs.load_path=r3m bc_kwargs.finetune=false proprio=9 job_name=try_r3m 49 | ``` 50 | 51 | For running metaworld environments run: 52 | 53 | ``` 54 | python hydra_launcher.py --multirun hydra/launcher=local hydra/output=local env=["assembly-v2-goal-observable","bin-picking-v2-goal-observable","button-press-topdown-v2-goal-observable","drawer-open-v2-goal-observable","hammer-v2-goal-observable"] camera=["left_cap2","right_cap2","top_cap2"] pixel_based=true embedding=resnet50 num_demos=25 env_kwargs.load_path=r3m bc_kwargs.finetune=false proprio=4 job_name=try_r3m 55 | ``` 56 | 57 | For running the Adroit pen task: 58 | ``` 59 | python hydra_launcher.py --multirun hydra/launcher=local hydra/output=local 
env=pen-v0 camera=["view_1","top","view_4"] pixel_based=true embedding=resnet50 num_demos=25 env_kwargs.load_path=r3m bc_kwargs.finetune=false proprio=24 job_name=try_r3m 60 | ``` 61 | 62 | For running the Adroit relocate task: 63 | ``` 64 | python hydra_launcher.py --multirun hydra/launcher=local hydra/output=local env=relocate-v0 camera=["view_1","top","view_4"] pixel_based=true embedding=resnet50 num_demos=25 env_kwargs.load_path=r3m bc_kwargs.finetune=false proprio=30 job_name=try_r3m 65 | ``` 66 | -------------------------------------------------------------------------------- /evaluation/dev.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | from r3meval.utils.gym_env import GymEnv 5 | from r3meval.utils.obs_wrappers import MuJoCoPixelObs, StateEmbedding 6 | from r3meval.utils.sampling import sample_paths 7 | from r3meval.utils.gaussian_mlp import MLP 8 | from r3meval.utils.behavior_cloning import BC 9 | from r3meval.utils.visualizations import place_attention_heatmap_over_images 10 | from tabulate import tabulate 11 | from tqdm import tqdm 12 | import mj_envs, gym 13 | import numpy as np, time as timer, multiprocessing, pickle, os 14 | 15 | 16 | 17 | import kitchen_shift 18 | 19 | env_name = 'kitchen_knob1_on-v3' 20 | shift = 'none' 21 | render_gpu_id = 0 22 | image_width = 256 23 | image_height = 256 24 | embedding_name = 'resnet50' 25 | load_path = 'r3m' 26 | proprio = 9 27 | device = 'cuda' 28 | 29 | distractors = ['cracker_box', 'medium', 'hard'] 30 | textures_slide = ['wood2', 'metal2', 'tile1'] 31 | textures_hinge = ['wood1', 'metal1', 'marble1'] 32 | textures_floor = ['tile1', 'wood1'] 33 | textures_counter = ['wood2'] 34 | lightings = ['cast_left', 'cast_right', 'brighter', 'darker'] 35 | 36 | for camera_name in ['right_cap2', 'left_cap2', 'default']: 37 | for distractor in distractors: 38 | e = 
gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_distractor_{distractor}.xml') 39 | ## Wrap in pixel observation wrapper 40 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 41 | camera_name=camera_name, device_id=render_gpu_id, 42 | shift=shift) 43 | 44 | ## Wrapper which encodes state in pretrained model 45 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 46 | proprio=proprio, camera_name=camera_name, env_name=env_name) 47 | e = GymEnv(e) 48 | img = e.env.env.get_image() 49 | 50 | import cv2 51 | cv2.imwrite(f'photos/{camera_name}_distractor_{distractor}.png', img[:, :, ::-1]) 52 | 53 | camera_name = 'left_cap2' 54 | for texture in textures_counter: 55 | e = gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_counter_{texture}.xml') 56 | ## Wrap in pixel observation wrapper 57 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 58 | camera_name=camera_name, device_id=render_gpu_id, 59 | shift=shift) 60 | 61 | ## Wrapper which encodes state in pretrained model 62 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 63 | proprio=proprio, camera_name=camera_name, env_name=env_name) 64 | e = GymEnv(e) 65 | img = e.env.env.get_image() 66 | 67 | import cv2 68 | cv2.imwrite(f'photos/counter_{texture}.png', img[:, :, ::-1]) 69 | 70 | 71 | for texture in textures_floor: 72 | e = gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_floor_{texture}.xml') 73 | ## Wrap in pixel observation wrapper 74 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 75 | camera_name=camera_name, device_id=render_gpu_id, 76 | shift=shift) 77 | 78 | ## Wrapper which encodes state in pretrained model 79 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, 
load_path=load_path, 80 | proprio=proprio, camera_name=camera_name, env_name=env_name) 81 | e = GymEnv(e) 82 | img = e.env.env.get_image() 83 | 84 | import cv2 85 | cv2.imwrite(f'photos/floor_{texture}.png', img[:, :, ::-1]) 86 | 87 | for texture in textures_slide: 88 | e = gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_slide_{texture}.xml') 89 | ## Wrap in pixel observation wrapper 90 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 91 | camera_name=camera_name, device_id=render_gpu_id, 92 | shift=shift) 93 | 94 | ## Wrapper which encodes state in pretrained model 95 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 96 | proprio=proprio, camera_name=camera_name, env_name=env_name) 97 | e = GymEnv(e) 98 | img = e.env.env.get_image() 99 | 100 | import cv2 101 | cv2.imwrite(f'photos/slide_{texture}.png', img[:, :, ::-1]) 102 | 103 | for texture in textures_hinge: 104 | e = gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_hinge_{texture}.xml') 105 | ## Wrap in pixel observation wrapper 106 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 107 | camera_name=camera_name, device_id=render_gpu_id, 108 | shift=shift) 109 | 110 | ## Wrapper which encodes state in pretrained model 111 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 112 | proprio=proprio, camera_name=camera_name, env_name=env_name) 113 | e = GymEnv(e) 114 | img = e.env.env.get_image() 115 | 116 | import cv2 117 | cv2.imwrite(f'photos/hinge_{texture}.png', img[:, :, ::-1]) 118 | 119 | for lighting in lightings: 120 | e = gym.make(env_name, model_path=f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_{lighting}.xml') 121 | ## Wrap in pixel observation wrapper 122 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 123 | 
camera_name=camera_name, device_id=render_gpu_id, 124 | shift=shift) 125 | 126 | ## Wrapper which encodes state in pretrained model 127 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 128 | proprio=proprio, camera_name=camera_name, env_name=env_name) 129 | e = GymEnv(e) 130 | img = e.env.env.get_image() 131 | 132 | import cv2 133 | cv2.imwrite(f'photos/lighting_{lighting}.png', img[:, :, ::-1]) 134 | 135 | -------------------------------------------------------------------------------- /evaluation/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=iris-hi 3 | #SBATCH --mem=32G 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --exclude=iris4,iris2,iris-hp-z8 6 | #SBATCH --job-name="fancy new architecture" 7 | #SBATCH --time=3-0:0 8 | 9 | source /sailhome/kayburns/.bashrc 10 | conda activate py3.8_torch1.10.1 11 | cd /iris/u/kayburns/new_arch/r3m/evaluation/ 12 | 13 | # python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 14 | # env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true \ 15 | # embedding=resnet50 num_demos=5 env_kwargs.load_path=r3m \ 16 | # bc_kwargs.finetune=false proprio=9 job_name=r3m_repro seed=125 \ 17 | 18 | for shift in none bottom_left_copy_crop bottom_left_red_rectangle bottom_left_white_rectangle bottom_left_no_blue_rectangle top_right_red_rectangle 19 | do 20 | python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 21 | env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true \ 22 | embedding=resnet50 env_kwargs.load_path=r3m \ 23 | bc_kwargs.finetune=true proprio=9 job_name=eval seed=125 \ 24 | eval.eval=True env_kwargs.shift=$shift eval_num_traj=10 \ 25 | env_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-10_12-45-18/try_r3m/iterations/embedding_533.pickle \ 26 | 
bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-10_12-45-18/try_r3m/iterations/policy_533.pickle 27 | done 28 | 29 | # left_cap2, sdoor 30 | # . dino ft 31 | # env_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-08_03-44-50/try_r3m/iterations/embedding_733.pickle \ 32 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-08_03-44-50/try_r3m/iterations/policy_733.pickle 33 | # . dino 34 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-09_05-52-21/try_r3m/iterations/policy_1066.pickle 35 | # . r3m ft 36 | # env_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-10_12-45-18/try_r3m/iterations/embedding_533.pickle \ 37 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-10_12-45-18/try_r3m/iterations/policy_533.pickle 38 | # . 
r3m 39 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-08_03-29-49/try_r3m/iterations/policy_333.pickle 40 | 41 | 42 | 43 | 44 | 45 | 46 | # env_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-08_03-44-50/try_r3m/iterations/embedding_733.pickle \ #cross embodiment on kitchen left to right 47 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-08_03-44-50/try_r3m/iterations/policy_733.pickle 48 | 49 | # env_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-09_09-04-41/try_r3m/iterations/embedding_999.pickle \ 50 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-09_09-04-41/try_r3m/iterations/policy_999.pickle 51 | 52 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/BC_pretrained_rep/2022-10-10_08-32-11/try_r3m/iterations/policy_333.pickle 53 | 54 | # python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 55 | # env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true \ 56 | # embedding=dino num_demos=5 env_kwargs.load_path="/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-06_14-12-16/r3m_repro_all/iterations/embedding_999.pickle" \ 57 | # bc_kwargs.finetune=false proprio=9 job_name=eval_all seed=125 \ 58 | # eval.eval=True env_kwargs.shift=none \ 59 | # bc_kwargs.load_path="/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-06_14-12-16/r3m_repro_all/iterations/policy_999.pickle" \ 60 | 61 | # all finetuned 62 | # bc_kwargs.load_path="/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-05_18-05-13/r3m_repro_all/iterations/policy_999.pickle" # 63 | 64 | # head 2 65 | # 
bc_kwargs.load_path="/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-04_11-11-50/r3m_repro/iterations/policy_2857.pickle" # 42, 42 66 | # head 3 67 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-09-28_23-47-08/r3m_repro/iterations/policy_2857.pickle # 52, 66 -> 24 68 | # bc_kwargs.load_path=/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-03_17-57-00/r3m_repro/iterations/policy_2857.pickle # 44 69 | # r3m 70 | # bc_kwargs.load_path="/iris/u/kayburns/new_arch/r3m/evaluation/outputs/BC_pretrained_rep/2022-10-04_11-22-39/r3m_repro/iterations/policy_2857.pickle" 71 | -------------------------------------------------------------------------------- /evaluation/r3meval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /evaluation/r3meval/core/config/BC_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | defaults: 6 | - override hydra/launcher: local 7 | - override hydra/output: local 8 | 9 | # general inputs 10 | env : pen-v0 11 | pixel_based : True # pixel based (True) or state based (False) experiment 12 | embedding : resnet50 # choice of embedding network 13 | camera : vil_camera # choice of camera to use for image generation 14 | device : cuda 15 | 16 | # experiment and evaluation 17 | seed : 123 # used as base_seed for rolling out policy for eval in sample_paths 18 | steps : 20000 # number of outer epochs 19 | eval_frequency : 1000 # frequency of epochs for evaluation and logging 20 | eval_num_traj : 50 # number of rollouts to eval 21 | num_cpu : 1 # for rolling out paths when evaluating 22 | num_demos : 200 # path to demo file auto-inferred from other inputs 23 | proprio : 0 24 | 25 | ft_only_last_layer: false 26 | 27 | # environment related kwargs 28 | env_kwargs: 29 | env_name : ${env} 30 | device : ${device} # device to use for representation network (policy clamped to CPU for now) 31 | image_width : 256 32 | image_height : 256 33 | camera_name : ${camera} 34 | embedding_name : ${embedding} 35 | pixel_based : ${pixel_based} 36 | render_gpu_id : 0 37 | load_path : "" 38 | proprio : ${proprio} 39 | lang_cond : False 40 | gc : False 41 | shift : "none" 42 | 43 | # demo reparsing arguments (states -> image embeddings) 44 | reparse_kwargs: 45 | visualize : True # store videos (.mp4) of the trajectory while reparsing 46 | save_frames : True # save the generated images in the trajectory (can increase storage space dramatically) 47 | 48 | # BC agent setup 49 | bc_kwargs: 50 | loss_type : 'MSE' 51 | batch_size : 32 #200 52 | lr : 1e-3 53 | save_logs : False 54 | finetune : False 55 | proprio : ${proprio} 56 | proprio_only : False 57 | load_path : "" 58 | 59 | # eval setup 60 | eval: 61 | eval : False 62 | 63 | # general outputs 64 | job_name : 'R3M' 65 | project : 'main_sweep_noft' 66 | 67 | hydra: 68 | job: 69 | name: main_sweep_noft 70 | 
-------------------------------------------------------------------------------- /evaluation/r3meval/core/config/eval_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | defaults: 6 | - override hydra/launcher: local 7 | - override hydra/output: local 8 | 9 | # general inputs 10 | env : pen-v0 11 | pixel_based : True # pixel based (True) or state based (False) experiment 12 | embedding : resnet50 # choice of embedding network 13 | camera : vil_camera # choice of camera to use for image generation 14 | device : cuda 15 | 16 | # experiment and evaluation 17 | seed : 123 # used as base_seed for rolling out policy for eval in sample_paths 18 | steps : 20000 # number of outer epochs 19 | eval_frequency : 1000 # frequency of epochs for evaluation and logging 20 | eval_num_traj : 20 # number of rollouts to eval 21 | num_cpu : 1 # for rolling out paths when evaluating 22 | num_demos : 200 # path to demo file auto-inferred from other inputs 23 | proprio : 0 24 | 25 | # environment related kwargs 26 | env_kwargs: 27 | env_name : ${env} 28 | device : ${device} # device to use for representation network (policy clamped to CPU for now) 29 | image_width : 256 30 | image_height : 256 31 | camera_name : ${camera} 32 | embedding_name : ${embedding} 33 | pixel_based : ${pixel_based} 34 | render_gpu_id : 0 35 | load_path : "" 36 | proprio : ${proprio} 37 | lang_cond : False 38 | gc : False 39 | shift : "none" 40 | 41 | # demo reparsing arguments (states -> image embeddings) 42 | reparse_kwargs: 43 | visualize : True # store videos (.mp4) of the trajectory while reparsing 44 | save_frames : True # save the generated images in the trajectory (can increase storage space dramatically) 45 | 46 | # BC agent setup 47 | bc_kwargs: 48 | 
loss_type : 'MSE' 49 | batch_size : 32 #200 50 | lr : 1e-3 51 | save_logs : False 52 | finetune : False 53 | proprio : ${proprio} 54 | proprio_only : False 55 | load_path : "" 56 | 57 | # eval setup 58 | eval: 59 | eval : False 60 | 61 | # general outputs 62 | job_name : 'eval' 63 | 64 | hydra: 65 | job: 66 | name: trash 67 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/config/hydra/launcher/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | cpus_per_task: 20 5 | gpus_per_node: 1 6 | tasks_per_node: 1 7 | timeout_min: 600 8 | mem_gb: 64 9 | name: ${hydra.job.name} 10 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher 11 | submitit_folder: ${now:%Y-%m-%d}_${now:%H-%M-%S}/.submitit/%j 12 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/config/hydra/output/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 5 | subdir: ${hydra.job.num}_${now:%Y-%m-%d}_${now:%H-%M-%S} 6 | sweep: 7 | dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num}_${now:%Y-%m-%d}_${now:%H-%M-%S} 9 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=iris-hi 3 | #SBATCH --mem=32G 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --exclude=iris4,iris2,iris-hp-z8,iris3 6 | #SBATCH --job-name="AMPHIBIOUS GAZE IMPROVEMENT" 7 | #SBATCH --time=3-0:0 8 | #SBATCH --account=iris 9 | 10 | export ENV_NAME=${1} 11 | export SEED=${2} 12 | export CAM_NAME=${3} 13 | export EMB_NAME=${4} 14 | export LOAD_PATH=${5} 
15 | export NUM_DEMOS=10 16 | 17 | if [[ "${1}" == *"v2"* ]]; then 18 | echo "Using proprio=4 for Meta-World environment." 19 | export PROPRIO=4 20 | else 21 | echo "Using proprio=9 for FrankaKitchen environment." 22 | export PROPRIO=9 23 | fi 24 | 25 | export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/' 26 | source /sailhome/kayburns/.bashrc 27 | conda activate py3.8_torch1.10.1 28 | cd /iris/u/kayburns/new_arch/r3m/evaluation/ 29 | python r3meval/core/hydra_eval_launcher.py hydra/launcher=local hydra/output=local \ 30 | env=${ENV_NAME} camera=${CAM_NAME} pixel_based=true \ 31 | embedding=${EMB_NAME} num_demos=${NUM_DEMOS} env_kwargs.load_path=${LOAD_PATH} \ 32 | bc_kwargs.finetune=false proprio=${PROPRIO} job_name=try_r3m seed=${SEED} 33 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/eval_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from collections import namedtuple 6 | from r3meval.utils.gym_env import GymEnv 7 | from r3meval.utils.obs_wrappers import MuJoCoPixelObs, StateEmbedding 8 | from r3meval.utils.sampling import sample_paths 9 | from r3meval.utils.gaussian_mlp import MLP 10 | from r3meval.utils.behavior_cloning import BC 11 | from r3meval.utils.visualizations import place_attention_heatmap_over_images 12 | from tabulate import tabulate 13 | from tqdm import tqdm 14 | import mj_envs, gym 15 | import numpy as np, time as timer, multiprocessing, pickle, os 16 | 17 | 18 | 19 | 20 | import metaworld 21 | from metaworld.envs import (ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, 22 | ALL_V2_ENVIRONMENTS_GOAL_HIDDEN) 23 | 24 | 25 | def env_constructor(env_name, device='cuda', image_width=256, image_height=256, 26 | camera_name=None, embedding_name='resnet50', pixel_based=True, 27 | render_gpu_id=0, load_path="", proprio=False, lang_cond=False, 28 | gc=False, model_path=None, shift=None): 29 | 30 | ## If pixel based will wrap in a pixel observation wrapper 31 | if pixel_based: 32 | ## Need to do some special environment config for the metaworld environments 33 | if "v2" in env_name: 34 | e = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](model_name=model_path) 35 | e._freeze_rand_vec = False 36 | e.spec = namedtuple('spec', ['id', 'max_episode_steps']) 37 | e.spec.id = env_name 38 | e.spec.max_episode_steps = 500 39 | else: 40 | if model_path: 41 | e = gym.make(env_name, model_path=model_path) 42 | else: 43 | e = gym.make(env_name) # probably unnecessary 44 | ## Wrap in pixel observation wrapper 45 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 46 | camera_name=camera_name, device_id=render_gpu_id, 47 | shift=shift) 48 | ## Wrapper which encodes state in pretrained model 49 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 50 | proprio=proprio, camera_name=camera_name, env_name=env_name) 51 | e = 
GymEnv(e) 52 | else: 53 | print("Only supports pixel based") 54 | assert(False) 55 | return e 56 | 57 | 58 | def make_bc_agent(env_kwargs:dict, bc_kwargs:dict, demo_paths:list, epochs:int, seed:int, pixel_based=True, model_path=None): 59 | ## Creates environment 60 | env_kwargs['model_path'] = model_path 61 | e = env_constructor(**env_kwargs) 62 | 63 | ## Creates MLP (Where the FC Network has a batchnorm in front of it) 64 | policy = MLP(e.spec, hidden_sizes=(256, 256), seed=seed) 65 | policy.model.proprio_only = False 66 | if bc_kwargs["load_path"]: 67 | policy = pickle.load(open(bc_kwargs["load_path"], 'rb')) 68 | 69 | ## Pass the encoder params to the BC agent (for finetuning) 70 | if pixel_based: 71 | enc_p = e.env.embedding.parameters() 72 | else: 73 | print("Only supports pixel based") 74 | assert(False) 75 | bc_agent = BC(demo_paths, policy=policy, epochs=epochs, set_transforms=False, encoder_params=enc_p, **bc_kwargs) 76 | 77 | ## Pass the environment's observation encoder to the BC agent to encode demo data 78 | if pixel_based: 79 | bc_agent.encodefn = e.env.encode_batch 80 | else: 81 | print("Only supports pixel based") 82 | assert(False) 83 | return e, bc_agent 84 | 85 | 86 | def configure_cluster_GPUs(gpu_logical_id: int) -> int: 87 | # get the correct GPU ID 88 | if "SLURM_STEP_GPUS" in os.environ.keys(): 89 | physical_gpu_ids = os.environ.get('SLURM_STEP_GPUS') 90 | gpu_id = int(physical_gpu_ids.split(',')[gpu_logical_id]) 91 | print("Found slurm-GPUS: {}".format(physical_gpu_ids)) 92 | print("Using GPU {} for logical ID {}".format(gpu_id, gpu_logical_id)) 93 | else: 94 | gpu_id = 0 # base case when no GPUs detected in SLURM 95 | print("No GPUs detected. 
Defaulting to 0 as the device ID") 96 | return gpu_id 97 | 98 | def eval_model_path(job_data, demo_paths, model_path, init=False): 99 | env_kwargs = job_data['env_kwargs'] 100 | if model_path: 101 | shift = model_path.split('_')[-1].split('.')[0] 102 | else: 103 | shift = '' 104 | e, agent = make_bc_agent(env_kwargs=env_kwargs, bc_kwargs=job_data['bc_kwargs'], 105 | demo_paths=demo_paths, epochs=1, seed=job_data['seed'], pixel_based=job_data["pixel_based"], 106 | model_path=model_path) 107 | if init: 108 | agent.logger.init_wb(job_data, project='r3m_shift_eval_noft_final') 109 | 110 | # perform evaluation rollouts every few epochs 111 | agent.policy.model.eval() 112 | if job_data['pixel_based']: 113 | e.env.embedding.eval() 114 | paths = sample_paths(num_traj=job_data['eval_num_traj'], env=e, #env_constructor, 115 | policy=agent.policy, eval_mode=True, horizon=e.horizon, 116 | base_seed=job_data['seed'], num_cpu=job_data['num_cpu'], 117 | env_kwargs=env_kwargs) 118 | 119 | try: 120 | ## Success computation and logging for Adroit and Kitchen 121 | success_percentage = e.env.unwrapped.evaluate_success(paths) 122 | for i, path in enumerate(paths): 123 | if (i < 1) and job_data['pixel_based']: 124 | vid = path['images'] 125 | # filename = f'./iterations/vid_{i}.gif' 126 | from moviepy.editor import ImageSequenceClip 127 | # cl = ImageSequenceClip(vid, fps=20) 128 | # cl.write_gif(filename, fps=20) 129 | 130 | if job_data.embedding == 'dino' or job_data.embedding == 'mvp': 131 | for j in range(6): 132 | heatmap_vid = place_attention_heatmap_over_images(vid, e.env.embedding, job_data.embedding, head=j) 133 | filename = f'./iterations/heatmap_vid_{i}_{j}_{shift}.gif' 134 | cl = ImageSequenceClip(heatmap_vid, fps=20) 135 | cl.write_gif(filename, fps=20) 136 | except: 137 | ## Success computation and logging for MetaWorld 138 | sc = [] 139 | for i, path in enumerate(paths): 140 | sc.append(path['env_infos']['success'][-1]) 141 | # if (i < 3) and job_data['pixel_based']: 142 
| # vid = path['images'] 143 | # filename = f'./iterations/vid_{i}.gif' 144 | # from moviepy.editor import ImageSequenceClip 145 | # cl = ImageSequenceClip(vid, fps=20) 146 | # cl.write_gif(filename, fps=20) 147 | 148 | # for j in range(6): 149 | # heatmap_vid = place_attention_heatmap_over_images(vid, e.env.embedding, head=j) 150 | # filename = f'./iterations/heatmap_vid_{i}_{j}.gif' 151 | # cl = ImageSequenceClip(heatmap_vid, fps=20) 152 | # cl.write_gif(filename, fps=20) 153 | success_percentage = np.mean(sc) * 100 154 | if not model_path: 155 | agent.logger.log_kv('eval_success', success_percentage) 156 | else: 157 | agent.logger.log_kv(f'eval_success{shift}', success_percentage) 158 | agent.logger.save_wb(step=0) 159 | 160 | print_data = sorted(filter(lambda v: np.asarray(v[1]).size == 1, 161 | agent.logger.get_current_log().items())) 162 | print(tabulate(print_data)) 163 | 164 | def eval_loop(job_data:dict) -> None: 165 | 166 | # configure GPUs 167 | os.environ['GPUS'] = os.environ.get('SLURM_STEP_GPUS', '0') 168 | physical_gpu_id = 0 #configure_cluster_GPUs(job_data['env_kwargs']['render_gpu_id']) 169 | job_data['env_kwargs']['render_gpu_id'] = physical_gpu_id 170 | 171 | # Infers the location of the demos 172 | ## V2 is metaworld, V0 adroit, V3 kitchen 173 | data_dir = '/iris/u/kayburns/data/r3m/' 174 | if "v2" in job_data['env_kwargs']['env_name']: 175 | demo_paths_loc = data_dir + 'final_paths_multiview_meta_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 176 | elif "v0" in job_data['env_kwargs']['env_name']: 177 | demo_paths_loc = data_dir + 'final_paths_multiview_adroit_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 178 | else: 179 | demo_paths_loc = data_dir + 'final_paths_multiview_rb_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 180 | 181 | ## Loads the demos 182 | demo_paths = pickle.load(open(demo_paths_loc, 'rb')) 183 | demo_paths = 
demo_paths[:job_data['num_demos']] 184 | print(len(demo_paths)) 185 | demo_score = np.mean([np.sum(p['rewards']) for p in demo_paths]) 186 | print("Demonstration score : %.2f " % demo_score) 187 | 188 | # Make log dir 189 | if os.path.isdir(job_data['job_name']) == False: os.mkdir(job_data['job_name']) 190 | previous_dir = os.getcwd() 191 | os.chdir(job_data['job_name']) # important! we are now in the directory to save data 192 | if os.path.isdir('iterations') == False: os.mkdir('iterations') 193 | if os.path.isdir('logs') == False: os.mkdir('logs') 194 | 195 | # Creates agent and environment 196 | model_paths = None 197 | eval_model_path(job_data, demo_paths, model_path=None, init=True) 198 | 199 | if 'kitchen' in job_data.env: 200 | distractors = ['cracker_box', 'medium', 'hard'] 201 | for distractor in distractors: 202 | model_path = f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_distractor_{distractor}.xml' 203 | eval_model_path(job_data, demo_paths, model_path, init=False) 204 | textures = ['wood2', 'metal2', 'tile1'] 205 | for texture in textures: 206 | model_path = f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_slide_{texture}.xml' 207 | eval_model_path(job_data, demo_paths, model_path, init=False) 208 | lightings = ['cast_left', 'cast_right', 'brighter', 'darker'] 209 | for lighting in lightings: 210 | model_path = f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_{lighting}.xml' 211 | eval_model_path(job_data, demo_paths, model_path, init=False) 212 | elif 'assembly' in job_data.env: 213 | for shift in ['distractor_easy', 'distractor_medium', 'distractor_hard', 'blue-woodtable', 'dark-woodtable', 'darkwoodtable', 'cast_right', 'cast_left', 'darker', 'brighter']: 214 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_assembly_peg_{shift}.xml' 215 | eval_model_path(job_data, demo_paths, model_path, 
init=False) 216 | elif 'bin' in job_data.env: 217 | for shift in ['distractor_easy', 'distractor_medium', 'distractor_hard', 'blue-woodtable', 'dark-woodtable', 'darkwoodtable', 'cast_right', 'cast_left', 'darker', 'brighter']: 218 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_bin_picking_{shift}.xml' 219 | eval_model_path(job_data, demo_paths, model_path, init=False) 220 | elif 'button' in job_data.env: 221 | for shift in ['distractor_easy', 'distractor_medium', 'distractor_hard', 'blue-woodtable', 'dark-woodtable', 'darkwoodtable', 'cast_right', 'cast_left', 'darker', 'brighter']: 222 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_button_press_topdown_{shift}.xml' 223 | eval_model_path(job_data, demo_paths, model_path, init=False) 224 | elif 'drawer' in job_data.env: 225 | for shift in ['distractor_easy', 'distractor_medium', 'distractor_hard', 'blue-woodtable', 'dark-woodtable', 'darkwoodtable', 'cast_right', 'cast_left', 'darker', 'brighter']: 226 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_drawer_{shift}.xml' 227 | eval_model_path(job_data, demo_paths, model_path, init=False) 228 | elif 'hammer' in job_data.env: 229 | for shift in ['distractor_easy', 'distractor_medium', 'distractor_hard', 'blue-woodtable', 'dark-woodtable', 'darkwoodtable', 'cast_right', 'cast_left', 'darker', 'brighter']: 230 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_hammer_{shift}.xml' 231 | eval_model_path(job_data, demo_paths, model_path, init=False) 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/hydra_eval_launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import time as timer 8 | import hydra 9 | import multiprocessing 10 | from omegaconf import DictConfig, OmegaConf 11 | from eval_loop import eval_loop 12 | 13 | cwd = os.getcwd() 14 | 15 | sweep_dir = '/iris/u/kayburns/new_arch/r3m/evaluation/outputs/main_sweep_noft/' 16 | 17 | def is_target_task(target_job_data, query_job_data): 18 | # envs 19 | slide_door = query_job_data.env == 'hammer-v2-goal-observable' 20 | 21 | # cameras 22 | left_cap = query_job_data.camera == 'left_cap2' 23 | 24 | # embeddings 25 | deit_sin = (query_job_data.embedding == 'deit_s_sin') and (query_job_data.env_kwargs.load_path == 'deit_s_sin') 26 | 27 | # num_demos 28 | ten_demos = query_job_data.num_demos == 10 29 | 30 | # seeds 31 | seed_123 = query_job_data.seed == 123 32 | 33 | # if (slide_door and left_cap and deit_sin and ten_demos and seed_123): 34 | # import pdb; pdb.set_trace() 35 | 36 | target_conditions_met = [ 37 | target_job_data.env == query_job_data.env, 38 | target_job_data.camera == query_job_data.camera, 39 | target_job_data.embedding == query_job_data.embedding, 40 | target_job_data.env_kwargs.load_path == query_job_data.env_kwargs.load_path, 41 | target_job_data.bc_kwargs.finetune == query_job_data.bc_kwargs.finetune, 42 | target_job_data.num_demos == query_job_data.num_demos, 43 | target_job_data.seed == query_job_data.seed, 44 | target_job_data.proprio == query_job_data.proprio, 45 | target_job_data.get('ft_only_last_layer', False) == query_job_data.get('ft_only_last_layer', False) 46 | ] 47 | 48 | return all(target_conditions_met) 49 | 50 | # =============================================================================== 51 | # Process Inputs and configure job 52 | # =============================================================================== 53 | @hydra.main(config_name="eval_config", config_path="config") 
54 | def configure_jobs(job_data:dict) -> None: 55 | os.environ['GPUS'] = os.environ.get('SLURM_STEP_GPUS', '0') 56 | 57 | print("========================================") 58 | print("Job Configuration") 59 | print("========================================") 60 | 61 | job_data = OmegaConf.structured(OmegaConf.to_yaml(job_data)) 62 | 63 | job_data['cwd'] = cwd 64 | 65 | run_paths = os.listdir(sweep_dir) 66 | 67 | for run_path in run_paths: 68 | 69 | old_config_path = os.path.join(sweep_dir, run_path, 'job_config.json') 70 | embedding_path = os.path.join(sweep_dir, run_path, 'try_r3m/iterations/embedding_best.pickle') 71 | policy_path = os.path.join(sweep_dir, run_path, 'try_r3m/iterations/policy_best.pickle') 72 | 73 | if not os.path.exists(old_config_path): 74 | continue 75 | 76 | with open(old_config_path, 'r') as fp: 77 | old_job_data = OmegaConf.load(fp) 78 | 79 | if not is_target_task(job_data, old_job_data): 80 | continue 81 | 82 | if not os.path.isfile(policy_path): 83 | print(f'No weights for ' \ 84 | f'{old_job_data.env}, ' \ 85 | f'{old_job_data.bc_kwargs.finetune}, ' \ 86 | f'{old_job_data.num_demos}, ' \ 87 | f'{old_job_data.seed}, ' \ 88 | f'{old_job_data.env_kwargs.load_path}, ' \ 89 | f'{os.path.join(sweep_dir, run_path)}, ' \ 90 | f'{old_job_data.camera}') 91 | continue 92 | 93 | job_data.env_kwargs = old_job_data.env_kwargs 94 | job_data.bc_kwargs.load_path = policy_path 95 | if job_data.bc_kwargs.finetune: 96 | job_data.env_kwargs.load_path = embedding_path 97 | else: 98 | job_data.env_kwargs.load_path = old_job_data.env_kwargs.load_path 99 | 100 | with open('job_config.json', 'w') as fp: 101 | OmegaConf.save(config=job_data, f=fp.name) 102 | print(OmegaConf.to_yaml(job_data)) 103 | eval_loop(job_data) 104 | 105 | if __name__ == "__main__": 106 | multiprocessing.set_start_method('spawn') 107 | configure_jobs() -------------------------------------------------------------------------------- /evaluation/r3meval/core/hydra_launcher.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import time as timer 8 | import hydra 9 | import multiprocessing 10 | from omegaconf import DictConfig, OmegaConf 11 | from train_loop import bc_train_loop, eval_loop 12 | 13 | cwd = os.getcwd() 14 | 15 | # =============================================================================== 16 | # Process Inputs and configure job 17 | # =============================================================================== 18 | @hydra.main(config_name="BC_config", config_path="config") 19 | def configure_jobs(job_data:dict) -> None: 20 | os.environ['GPUS'] = os.environ.get('SLURM_STEP_GPUS', '0') 21 | 22 | print("========================================") 23 | print("Job Configuration") 24 | print("========================================") 25 | 26 | job_data = OmegaConf.structured(OmegaConf.to_yaml(job_data)) 27 | 28 | job_data['cwd'] = cwd 29 | with open('job_config.json', 'w') as fp: 30 | OmegaConf.save(config=job_data, f=fp.name) 31 | print(OmegaConf.to_yaml(job_data)) 32 | if job_data["eval"]["eval"]: 33 | eval_loop(job_data) 34 | else: 35 | bc_train_loop(job_data) 36 | 37 | if __name__ == "__main__": 38 | multiprocessing.set_start_method('spawn') 39 | configure_jobs() -------------------------------------------------------------------------------- /evaluation/r3meval/core/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=iris-hi 3 | #SBATCH --mem=32G 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --exclude=iris4,iris2,iris-hp-z8 6 | #SBATCH --job-name="AMPHIBIOUS GAZE IMPROVEMENT" 7 | #SBATCH --time=3-0:0 8 | #SBATCH --account=iris 9 | 10 | export ENV_NAME=${1} 11 | export SEED=${2} 12 | 
export CAM_NAME=${3} 13 | export EMB_NAME=${4} 14 | export LOAD_PATH=${5} 15 | export NUM_DEMOS=10 16 | 17 | 18 | if [[ "${1}" == *"v2"* ]]; then 19 | echo "Using proprio=4 for Meta-World environment." 20 | export PROPRIO=4 21 | else 22 | echo "Using proprio=9 for FrankaKitchen environment." 23 | export PROPRIO=9 24 | fi 25 | 26 | export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/' 27 | source /sailhome/kayburns/.bashrc 28 | conda activate py3.8_torch1.10.1 29 | cd /iris/u/kayburns/new_arch/r3m/evaluation/ 30 | python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 31 | env=${ENV_NAME} camera=${CAM_NAME} pixel_based=true \ 32 | embedding=${EMB_NAME} num_demos=${NUM_DEMOS} env_kwargs.load_path=${LOAD_PATH} \ 33 | bc_kwargs.finetune=false proprio=${PROPRIO} job_name=try_r3m seed=${SEED} 34 | 35 | # source /sailhome/kayburns/.bashrc 36 | # conda activate py3.8_torch1.10.1 37 | # cd /iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/ 38 | 39 | # # export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/' 40 | # # for env in kitchen_sdoor_open-v3 kitchen_knob1_on-v3 kitchen_light_on-v3 kitchen_ldoor_open-v3 kitchen_micro_open-v3 # kitchen_sdoor_open-v3 kitchen_knob1_on-v3 kitchen_light_on-v3 kitchen_ldoor_open-v3 kitchen_micro_open-v3 41 | # # # for env in assembly-v2-goal-observable # assembly-v2-goal-observable bin-picking-v2-goal-observable button-press-topdown-v2-goal-observable drawer-open-v2-goal-observable hammer-v2-goal-observable 42 | # # do 43 | # # for num_demos in 10 # 5 10 25 44 | # # do 45 | # # for camera in left_cap2 # default left_cap2 right_cap2 46 | # # do 47 | # # for seed in 123 124 125 # 123 124 125 48 | # # do 49 | # # python hydra_launcher.py hydra/launcher=local hydra/output=local \ 50 | # # pixel_based=true embedding=dino_ensemble env_kwargs.load_path=dino_ensemble \ 51 | # # bc_kwargs.finetune=true ft_only_last_layer=true proprio=9 
job_name=try_r3m \ 52 | # # seed=$seed num_demos=$num_demos env=$env camera=$camera 53 | # # done 54 | # # done 55 | # # done 56 | # # done 57 | 58 | # export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/:/iris/u/kayburns/new_arch/' 59 | # for env in kitchen_knob1_on-v3 # kitchen_light_on-v3 kitchen_ldoor_open-v3 kitchen_micro_open-v3 # kitchen_sdoor_open-v3 kitchen_knob1_on-v3 kitchen_light_on-v3 kitchen_ldoor_open-v3 kitchen_micro_open-v3 60 | # do 61 | # for num_demos in 10 # 5 10 25 62 | # do 63 | # for camera in left_cap2 64 | # do 65 | # for seed in 123 124 125 # 123 124 125 66 | # do 67 | # python hydra_launcher.py hydra/launcher=local hydra/output=local \ 68 | # pixel_based=true embedding=keypoints env_kwargs.load_path=keypoints \ 69 | # bc_kwargs.finetune=false proprio=9 job_name=try_r3m \ 70 | # seed=$seed num_demos=$num_demos env=$env camera=$camera 71 | # done 72 | # done 73 | # done 74 | # done 75 | 76 | # # export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/' 77 | # # for env in hammer-v2-goal-observable drawer-open-v2-goal-observable # assembly-v2-goal-observable bin-picking-v2-goal-observable button-press-topdown-v2-goal-observable drawer-open-v2-goal-observable hammer-v2-goal-observable 78 | # # do 79 | # # for num_demos in 10 # 5 10 25 80 | # # do 81 | # # for camera in left_cap2 # default left_cap2 right_cap2 82 | # # do 83 | # # for seed in 123 124 125 # 123 124 125 84 | # # do 85 | # # python hydra_launcher.py hydra/launcher=local hydra/output=local \ 86 | # # pixel_based=true embedding=resnet50_dino env_kwargs.load_path=resnet50_dino \ 87 | # # bc_kwargs.finetune=false proprio=4 job_name=try_r3m \ 88 | # # seed=$seed num_demos=$num_demos env=$env camera=$camera 89 | # # done 90 | # # done 91 | # # done 92 | # # done 93 | 94 | # # export PYTHONPATH='/iris/u/kayburns/new_arch/Intriguing-Properties-of-Vision-Transformers/' 95 | # # for env in assembly-v2-goal-observable 
bin-picking-v2-goal-observable button-press-topdown-v2-goal-observable drawer-open-v2-goal-observable hammer-v2-goal-observable # assembly-v2-goal-observable bin-picking-v2-goal-observable button-press-topdown-v2-goal-observable drawer-open-v2-goal-observable hammer-v2-goal-observable 96 | # # do 97 | # # for num_demos in 10 # 5 10 25 98 | # # do 99 | # # for camera in left_cap2 # default left_cap2 right_cap2 100 | # # do 101 | # # for seed in 123 124 125 # 123 124 125 102 | # # do 103 | # # python hydra_launcher.py hydra/launcher=local hydra/output=local \ 104 | # # pixel_based=true embedding=dino_ensemble env_kwargs.load_path=dino_ensemble \ 105 | # # bc_kwargs.finetune=true ft_only_last_layer=true proprio=4 job_name=try_r3m \ 106 | # # seed=$seed num_demos=$num_demos env=$env camera=$camera 107 | # # done 108 | # # done 109 | # # done 110 | # # done 111 | -------------------------------------------------------------------------------- /evaluation/r3meval/core/train_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | from collections import namedtuple 6 | from r3meval.utils.gym_env import GymEnv 7 | from r3meval.utils.obs_wrappers import MuJoCoPixelObs, StateEmbedding 8 | from r3meval.utils.sampling import sample_paths 9 | from r3meval.utils.gaussian_mlp import MLP 10 | from r3meval.utils.behavior_cloning import BC 11 | from r3meval.utils.visualizations import place_attention_heatmap_over_images 12 | from tabulate import tabulate 13 | from tqdm import tqdm 14 | import mj_envs, gym 15 | import numpy as np, time as timer, multiprocessing, pickle, os 16 | 17 | 18 | 19 | 20 | import metaworld 21 | from metaworld.envs import (ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, 22 | ALL_V2_ENVIRONMENTS_GOAL_HIDDEN) 23 | 24 | 25 | def env_constructor(env_name, device='cuda', image_width=256, image_height=256, 26 | camera_name=None, embedding_name='resnet50', pixel_based=True, 27 | render_gpu_id=0, load_path="", proprio=False, lang_cond=False, 28 | gc=False, shift="none"): 29 | 30 | ## If pixel based will wrap in a pixel observation wrapper 31 | if pixel_based: 32 | ## Need to do some special environment config for the metaworld environments 33 | if "v2" in env_name: 34 | e = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name]() 35 | e._freeze_rand_vec = False 36 | e.spec = namedtuple('spec', ['id', 'max_episode_steps']) 37 | e.spec.id = env_name 38 | e.spec.max_episode_steps = 500 39 | else: 40 | e = gym.make(env_name) 41 | ## Wrap in pixel observation wrapper 42 | e = MuJoCoPixelObs(e, width=image_width, height=image_height, 43 | camera_name=camera_name, device_id=render_gpu_id, 44 | shift=shift) 45 | ## Wrapper which encodes state in pretrained model 46 | e = StateEmbedding(e, embedding_name=embedding_name, device=device, load_path=load_path, 47 | proprio=proprio, camera_name=camera_name, env_name=env_name) 48 | e = GymEnv(e) 49 | else: 50 | print("Only supports pixel based") 51 | assert(False) 52 | return e 53 | 54 | 55 | def 
make_bc_agent(env_kwargs:dict, bc_kwargs:dict, job_data:dict, demo_paths:list, epochs:int, seed:int, pixel_based=True): 56 | ## Creates environment 57 | e = env_constructor(**env_kwargs) 58 | 59 | ## Creates MLP (Where the FC Network has a batchnorm in front of it) 60 | policy = MLP(e.spec, hidden_sizes=(256, 256), seed=seed) 61 | policy.model.proprio_only = False 62 | if bc_kwargs["load_path"]: 63 | policy = pickle.load(open(bc_kwargs["load_path"], 'rb')) 64 | 65 | ## Pass the encoder params to the BC agent (for finetuning) 66 | if pixel_based: 67 | if 'ensemble' in job_data.embedding: 68 | assert job_data.ft_only_last_layer 69 | assert bc_kwargs.finetune 70 | enc_p = [e.env.embedding.blocks[-1].attn.scale] 71 | else: 72 | enc_p = e.env.embedding.parameters() 73 | else: 74 | print("Only supports pixel based") 75 | assert(False) 76 | bc_agent = BC(demo_paths, policy=policy, epochs=epochs, set_transforms=False, encoder_params=enc_p, **bc_kwargs) 77 | 78 | ## Pass the environment's observation encoder to the BC agent to encode demo data 79 | if pixel_based: 80 | bc_agent.encodefn = e.env.encode_batch 81 | else: 82 | print("Only supports pixel based") 83 | assert(False) 84 | return e, bc_agent 85 | 86 | 87 | def configure_cluster_GPUs(gpu_logical_id: int) -> int: 88 | # get the correct GPU ID 89 | if "SLURM_STEP_GPUS" in os.environ.keys(): 90 | physical_gpu_ids = os.environ.get('SLURM_STEP_GPUS') 91 | gpu_id = int(physical_gpu_ids.split(',')[gpu_logical_id]) 92 | print("Found slurm-GPUS: {}".format(physical_gpu_ids)) 93 | print("Using GPU {} for logical ID {}".format(gpu_id, gpu_logical_id)) 94 | else: 95 | gpu_id = 0 # base case when no GPUs detected in SLURM 96 | print("No GPUs detected. 
Defaulting to 0 as the device ID") 97 | return gpu_id 98 | 99 | 100 | def bc_train_loop(job_data:dict) -> None: 101 | 102 | # configure GPUs 103 | os.environ['GPUS'] = os.environ.get('SLURM_STEP_GPUS', '0') 104 | physical_gpu_id = 0 #configure_cluster_GPUs(job_data['env_kwargs']['render_gpu_id']) 105 | job_data['env_kwargs']['render_gpu_id'] = physical_gpu_id 106 | 107 | # Infers the location of the demos 108 | ## V2 is metaworld, V0 adroit, V3 kitchen 109 | data_dir = '/iris/u/kayburns/data/r3m/' 110 | if "v2" in job_data['env_kwargs']['env_name']: 111 | demo_paths_loc = data_dir + 'final_paths_multiview_meta_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 112 | elif "v0" in job_data['env_kwargs']['env_name']: 113 | demo_paths_loc = data_dir + 'final_paths_multiview_adroit_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 114 | else: 115 | demo_paths_loc = data_dir + 'final_paths_multiview_rb_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 116 | 117 | ## Loads the demos 118 | demo_paths = pickle.load(open(demo_paths_loc, 'rb')) 119 | demo_paths = demo_paths[:job_data['num_demos']] 120 | print(len(demo_paths)) 121 | demo_score = np.mean([np.sum(p['rewards']) for p in demo_paths]) 122 | print("Demonstration score : %.2f " % demo_score) 123 | 124 | # Make log dir 125 | if os.path.isdir(job_data['job_name']) == False: os.mkdir(job_data['job_name']) 126 | previous_dir = os.getcwd() 127 | os.chdir(job_data['job_name']) # important! 
we are now in the directory to save data 128 | if os.path.isdir('iterations') == False: os.mkdir('iterations') 129 | if os.path.isdir('logs') == False: os.mkdir('logs') 130 | 131 | ## Creates agent and environment 132 | env_kwargs = job_data['env_kwargs'] 133 | e, agent = make_bc_agent(env_kwargs=env_kwargs, bc_kwargs=job_data['bc_kwargs'], job_data=job_data, 134 | demo_paths=demo_paths, epochs=1, seed=job_data['seed'], pixel_based=job_data["pixel_based"]) 135 | agent.logger.init_wb(job_data, project=job_data.project) 136 | 137 | highest_score = -np.inf 138 | max_success = 0 139 | epoch = 0 140 | while True: 141 | # update policy using one BC epoch 142 | last_step = agent.steps 143 | print("Step", last_step) 144 | agent.policy.model.train() 145 | # If finetuning, wait until 25% of training is done then 146 | ## set embedding to train mode and turn on finetuning 147 | if (job_data['bc_kwargs']['finetune']) and (job_data['pixel_based']) and (job_data['env_kwargs']['load_path'] != "clip"): 148 | if last_step > (job_data['steps'] / 4.0): 149 | e.env.embedding.train() 150 | # freeze all but last layer 151 | if job_data.ft_only_last_layer: 152 | for name, param in e.env.embedding.named_parameters(): 153 | if not ('blocks.11' in name or 'norm.weight' == name or 'norm.bias' == name): 154 | param.requires_grad = False 155 | e.env.start_finetuning() 156 | agent.train(job_data['pixel_based'], suppress_fit_tqdm=True, step = last_step) 157 | 158 | # perform evaluation rollouts every few epochs 159 | if "ensemble" in job_data.embedding: 160 | for i, head_weight in enumerate(e.env.embedding.blocks[-1].attn.scale[0,:,0,0]): 161 | agent.logger.log_kv(f'head_{i}', head_weight.cpu().detach().numpy()) 162 | if ((agent.steps % job_data['eval_frequency']) < (last_step % job_data['eval_frequency'])): 163 | agent.policy.model.eval() 164 | if job_data['pixel_based']: 165 | e.env.embedding.eval() 166 | paths = sample_paths(num_traj=job_data['eval_num_traj'], env=e, #env_constructor, 167 | 
policy=agent.policy, eval_mode=True, horizon=e.horizon, 168 | base_seed=job_data['seed']+epoch, num_cpu=job_data['num_cpu'], 169 | env_kwargs=env_kwargs) 170 | 171 | try: 172 | ## Success computation and logging for Adroit and Kitchen 173 | success_percentage = e.env.unwrapped.evaluate_success(paths) 174 | for i, path in enumerate(paths): 175 | if (i < 3) and job_data['pixel_based']: 176 | vid = path['images'] 177 | filename = f'./iterations/vid_{i}.gif' 178 | from moviepy.editor import ImageSequenceClip 179 | cl = ImageSequenceClip(vid, fps=20) 180 | cl.write_gif(filename, fps=20) 181 | 182 | except: 183 | ## Success computation and logging for MetaWorld 184 | sc = [] 185 | for i, path in enumerate(paths): 186 | sc.append(path['env_infos']['success'][-1]) 187 | # if (i < 10) and job_data['pixel_based']: 188 | # vid = path['images'] 189 | # filename = f'./iterations/vid_{i}.gif' 190 | # from moviepy.editor import ImageSequenceClip 191 | # cl = ImageSequenceClip(vid, fps=20) 192 | # cl.write_gif(filename, fps=20) 193 | success_percentage = np.mean(sc) * 100 194 | agent.logger.log_kv('eval_epoch', epoch) 195 | agent.logger.log_kv('eval_success', success_percentage) 196 | if "ensemble" in job_data.embedding: 197 | for i, head_weight in enumerate(e.env.embedding.blocks[-1].attn.scale[0,:,0,0]): 198 | agent.logger.log_kv(f'head_{i}', head_weight.cpu().detach().numpy()) 199 | 200 | # save policy and logging 201 | if success_percentage >= max_success: 202 | pickle.dump(agent.policy, open('./iterations/policy_best.pickle', 'wb')) 203 | if job_data.bc_kwargs.finetune: 204 | pickle.dump(e.env.embedding, open('./iterations/embedding_best.pickle', 'wb')) 205 | agent.logger.save_log('./logs/') 206 | agent.logger.save_wb(step=agent.steps) 207 | 208 | # Tracking best success over training 209 | max_success = max(max_success, success_percentage) 210 | 211 | print_data = sorted(filter(lambda v: np.asarray(v[1]).size == 1, 212 | agent.logger.get_current_log().items())) 213 | 
print(tabulate(print_data)) 214 | epoch += 1 215 | if agent.steps > job_data['steps']: 216 | break 217 | agent.logger.log_kv('max_success', max_success) 218 | agent.logger.save_wb(step=agent.steps) 219 | 220 | def eval_loop(job_data:dict) -> None: 221 | 222 | # configure GPUs 223 | os.environ['GPUS'] = os.environ.get('SLURM_STEP_GPUS', '0') 224 | physical_gpu_id = 0 #configure_cluster_GPUs(job_data['env_kwargs']['render_gpu_id']) 225 | job_data['env_kwargs']['render_gpu_id'] = physical_gpu_id 226 | 227 | # Infers the location of the demos 228 | ## V2 is metaworld, V0 adroit, V3 kitchen 229 | data_dir = '/iris/u/kayburns/data/r3m/' 230 | if "v2" in job_data['env_kwargs']['env_name']: 231 | demo_paths_loc = data_dir + 'final_paths_multiview_meta_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 232 | elif "v0" in job_data['env_kwargs']['env_name']: 233 | demo_paths_loc = data_dir + 'final_paths_multiview_adroit_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 234 | else: 235 | demo_paths_loc = data_dir + 'final_paths_multiview_rb_200/' + job_data['camera'] + '/' + job_data['env_kwargs']['env_name'] + '.pickle' 236 | 237 | ## Loads the demos 238 | demo_paths = pickle.load(open(demo_paths_loc, 'rb')) 239 | demo_paths = demo_paths[:job_data['num_demos']] 240 | print(len(demo_paths)) 241 | demo_score = np.mean([np.sum(p['rewards']) for p in demo_paths]) 242 | print("Demonstration score : %.2f " % demo_score) 243 | 244 | # Make log dir 245 | if os.path.isdir(job_data['job_name']) == False: os.mkdir(job_data['job_name']) 246 | previous_dir = os.getcwd() 247 | os.chdir(job_data['job_name']) # important! 
we are now in the directory to save data 248 | if os.path.isdir('iterations') == False: os.mkdir('iterations') 249 | if os.path.isdir('logs') == False: os.mkdir('logs') 250 | 251 | ## Creates agent and environment 252 | env_kwargs = job_data['env_kwargs'] 253 | e, agent = make_bc_agent(env_kwargs=env_kwargs, bc_kwargs=job_data['bc_kwargs'], job_data=job_data, 254 | demo_paths=demo_paths, epochs=1, seed=job_data['seed'], pixel_based=job_data["pixel_based"]) 255 | agent.logger.init_wb(job_data) 256 | 257 | # perform evaluation rollouts every few epochs 258 | agent.policy.model.eval() 259 | if job_data['pixel_based']: 260 | e.env.embedding.eval() 261 | paths = sample_paths(num_traj=job_data['eval_num_traj'], env=e, #env_constructor, 262 | policy=agent.policy, eval_mode=True, horizon=e.horizon, 263 | base_seed=job_data['seed'], num_cpu=job_data['num_cpu'], 264 | env_kwargs=env_kwargs) 265 | 266 | try: 267 | ## Success computation and logging for Adroit and Kitchen 268 | success_percentage = e.env.unwrapped.evaluate_success(paths) 269 | # for i, path in enumerate(paths): 270 | # if (i < 3) and job_data['pixel_based']: 271 | # vid = path['images'] 272 | # filename = f'./iterations/vid_{i}.gif' 273 | # from moviepy.editor import ImageSequenceClip 274 | # cl = ImageSequenceClip(vid, fps=20) 275 | # cl.write_gif(filename, fps=20) 276 | 277 | # for j in range(6): 278 | # heatmap_vid = place_attention_heatmap_over_images(vid, e.env.embedding, head=j) 279 | # filename = f'./iterations/heatmap_vid_{i}_{j}.gif' 280 | # cl = ImageSequenceClip(heatmap_vid, fps=20) 281 | # cl.write_gif(filename, fps=20) 282 | except: 283 | ## Success computation and logging for MetaWorld 284 | sc = [] 285 | for i, path in enumerate(paths): 286 | sc.append(path['env_infos']['success'][-1]) 287 | if (i < 3) and job_data['pixel_based']: 288 | vid = path['images'] 289 | filename = f'./iterations/vid_{i}.gif' 290 | from moviepy.editor import ImageSequenceClip 291 | cl = ImageSequenceClip(vid, fps=20) 
292 | cl.write_gif(filename, fps=20) 293 | 294 | for j in range(6): 295 | heatmap_vid = place_attention_heatmap_over_images(vid, e.env.embedding, head=j) 296 | filename = f'./iterations/heatmap_vid_{i}_{j}.gif' 297 | cl = ImageSequenceClip(heatmap_vid, fps=20) 298 | cl.write_gif(filename, fps=20) 299 | success_percentage = np.mean(sc) * 100 300 | agent.logger.log_kv('load_step', agent.steps) 301 | agent.logger.log_kv('zero_shot_success', success_percentage) 302 | agent.logger.save_wb(step=0) 303 | 304 | print_data = sorted(filter(lambda v: np.asarray(v[1]).size == 1, 305 | agent.logger.get_current_log().items())) 306 | print(tabulate(print_data)) 307 | 308 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /evaluation/r3meval/utils/behavior_cloning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | Minimize bc loss (MLE, MSE, RWR etc.) 
with pytorch optimizers 7 | """ 8 | 9 | import logging 10 | logging.disable(logging.CRITICAL) 11 | import numpy as np 12 | import time as timer 13 | import torch 14 | from torch.autograd import Variable 15 | from r3meval.utils.logger import DataLog 16 | from tqdm import tqdm 17 | 18 | 19 | class BC: 20 | def __init__(self, expert_paths, 21 | policy, 22 | epochs = 5, 23 | batch_size = 64, 24 | lr = 1e-3, 25 | optimizer = None, 26 | loss_type = 'MSE', # can be 'MLE' or 'MSE' 27 | save_logs = True, 28 | set_transforms = False, 29 | finetune = False, 30 | proprio = 0, 31 | encoder_params = [], 32 | **kwargs, 33 | ): 34 | 35 | self.policy = policy 36 | self.expert_paths = expert_paths 37 | self.epochs = epochs 38 | self.mb_size = batch_size 39 | self.logger = DataLog() 40 | self.loss_type = loss_type 41 | self.save_logs = save_logs 42 | self.finetune = finetune 43 | self.proprio = proprio 44 | self.steps = 0 45 | 46 | if set_transforms: 47 | in_shift, in_scale, out_shift, out_scale = self.compute_transformations() 48 | self.set_transformations(in_shift, in_scale, out_shift, out_scale) 49 | self.set_variance_with_data(out_scale) 50 | 51 | # construct optimizer 52 | self.optimizer = torch.optim.Adam(list(self.policy.trainable_params) + list(encoder_params), lr=lr) if optimizer is None else optimizer 53 | 54 | # Loss criterion if required 55 | if loss_type == 'MSE': 56 | self.loss_criterion = torch.nn.MSELoss() 57 | 58 | # make logger 59 | if self.save_logs: 60 | self.logger = DataLog() 61 | 62 | def compute_transformations(self): 63 | # get transformations 64 | if self.expert_paths == [] or self.expert_paths is None: 65 | in_shift, in_scale, out_shift, out_scale = None, None, None, None 66 | else: 67 | # observations = np.concatenate([path["observations"] for path in self.expert_paths]) 68 | observations = np.concatenate([path["images"] for path in self.expert_paths]) 69 | observations = self.encodefn(observations, finetune=self.finetune) 70 | actions = 
np.concatenate([path["actions"] for path in self.expert_paths]) 71 | in_shift, in_scale = np.mean(observations, axis=0), np.std(observations, axis=0) 72 | out_shift, out_scale = np.mean(actions, axis=0), np.std(actions, axis=0) 73 | return in_shift, in_scale, out_shift, out_scale 74 | 75 | def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_scale=None): 76 | # set scalings in the target policy 77 | self.policy.model.set_transformations(in_shift, in_scale, out_shift, out_scale) 78 | self.policy.old_model.set_transformations(in_shift, in_scale, out_shift, out_scale) 79 | 80 | def set_variance_with_data(self, out_scale): 81 | # set the variance of gaussian policy based on out_scale 82 | params = self.policy.get_param_values() 83 | params[-self.policy.m:] = np.log(out_scale + 1e-12) 84 | self.policy.set_param_values(params) 85 | 86 | def loss(self, data, idx=None): 87 | if self.loss_type == 'MLE': 88 | return self.mle_loss(data, idx) 89 | elif self.loss_type == 'MSE': 90 | return self.mse_loss(data, idx) 91 | else: 92 | print("Please use valid loss type") 93 | return None 94 | 95 | def mle_loss(self, data, idx): 96 | # use indices if provided (e.g. 
for mini-batching) 97 | # otherwise, use all the data 98 | idx = range(data['observations'].shape[0]) if idx is None else idx 99 | if type(data['observations']) == torch.Tensor: 100 | idx = torch.LongTensor(idx) 101 | obs = data['observations'][idx] 102 | act = data['expert_actions'][idx] 103 | LL, mu, log_std = self.policy.new_dist_info(obs, act) 104 | # minimize negative log likelihood 105 | return -torch.mean(LL) 106 | 107 | def mse_loss(self, data, idx=None): 108 | idx = range(data['observations'].shape[0]) if idx is None else idx 109 | if type(data['observations']) is torch.Tensor: 110 | idx = torch.LongTensor(idx) 111 | obs = data['observations'][idx] 112 | ## Encode images with environments encode function 113 | obs = self.encodefn(obs, finetune=self.finetune) 114 | act_expert = data['expert_actions'][idx] 115 | if type(obs) is not torch.Tensor: 116 | obs = Variable(torch.from_numpy(obs).float(), requires_grad=False).cuda() 117 | 118 | ## Concatenate proprioceptive data 119 | if self.proprio: 120 | proprio= data['proprio'][idx] 121 | if type(proprio) is not torch.Tensor: 122 | proprio = Variable(torch.from_numpy(proprio).float(), requires_grad=False).cuda() 123 | obs = torch.cat([obs, proprio], -1) 124 | if type(act_expert) is not torch.Tensor: 125 | act_expert = Variable(torch.from_numpy(act_expert).float(), requires_grad=False) 126 | act_pi = self.policy.model(obs) 127 | return self.loss_criterion(act_pi, act_expert.detach()) 128 | 129 | def fit(self, data, suppress_fit_tqdm=False, **kwargs): 130 | # data is a dict 131 | # keys should have "observations" and "expert_actions" 132 | validate_keys = all([k in data.keys() for k in ["observations", "expert_actions"]]) 133 | assert validate_keys is True 134 | ts = timer.time() 135 | num_samples = data["observations"].shape[0] 136 | 137 | # log stats before 138 | if self.save_logs: 139 | loss_val = self.loss(data, idx=range(num_samples)).data.numpy().ravel()[0] 140 | self.logger.log_kv('loss_before', loss_val) 
141 | 142 | # train loop 143 | for ep in config_tqdm(range(self.epochs), suppress_fit_tqdm): 144 | for mb in range(int(num_samples / self.mb_size)): 145 | rand_idx = np.random.choice(num_samples, size=self.mb_size) 146 | self.optimizer.zero_grad() 147 | loss = self.loss(data, idx=rand_idx) 148 | loss.backward() 149 | self.optimizer.step() 150 | self.steps += 1 151 | params_after_opt = self.policy.get_param_values() 152 | self.policy.set_param_values(params_after_opt, set_new=True, set_old=True) 153 | # log stats after 154 | if self.save_logs: 155 | self.logger.log_kv('epoch', self.epochs) 156 | loss_val = self.loss(data, idx=range(num_samples)).data.numpy().ravel()[0] 157 | self.logger.log_kv('loss_after', loss_val) 158 | self.logger.log_kv('time', (timer.time()-ts)) 159 | 160 | def train(self, pixel=True, **kwargs): 161 | ## If using proprioception, select the first N elements from the state observation 162 | ## Assumes proprioceptive features are at the front of the state observation 163 | if self.proprio: 164 | proprio = np.concatenate([path["observations"] for path in self.expert_paths]) 165 | proprio = proprio[:, :self.proprio] 166 | else: 167 | proprio = None 168 | 169 | ## Extract images 170 | if pixel: 171 | observations = np.concatenate([path["images"] for path in self.expert_paths]) 172 | else: 173 | observations = np.concatenate([path["observations"] for path in self.expert_paths]) 174 | 175 | ## Extract actions 176 | expert_actions = np.concatenate([path["actions"] for path in self.expert_paths]) 177 | data = dict(observations=observations, proprio=proprio, expert_actions=expert_actions) 178 | self.fit(data, **kwargs) 179 | 180 | 181 | def config_tqdm(range_inp, suppress_tqdm=False): 182 | if suppress_tqdm: 183 | return range_inp 184 | else: 185 | return tqdm(range_inp) -------------------------------------------------------------------------------- /evaluation/r3meval/utils/fc_network.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class FCNetwork(nn.Module): 11 | def __init__(self, obs_dim, act_dim, 12 | hidden_sizes=(64,64), 13 | nonlinearity='tanh', # either 'tanh' or 'relu' 14 | in_shift = None, 15 | in_scale = None, 16 | out_shift = None, 17 | out_scale = None): 18 | super(FCNetwork, self).__init__() 19 | 20 | self.obs_dim = obs_dim 21 | self.act_dim = act_dim 22 | assert type(hidden_sizes) == tuple 23 | self.layer_sizes = (obs_dim, ) + hidden_sizes + (act_dim, ) 24 | self.set_transformations(in_shift, in_scale, out_shift, out_scale) 25 | self.proprio_only = False 26 | 27 | # Batch Norm Layers 28 | self.bn = torch.nn.BatchNorm1d(obs_dim) 29 | 30 | # hidden layers 31 | self.fc_layers = nn.ModuleList([nn.Linear(self.layer_sizes[i], self.layer_sizes[i+1]) \ 32 | for i in range(len(self.layer_sizes) -1)]) 33 | self.nonlinearity = torch.relu if nonlinearity == 'relu' else torch.tanh 34 | 35 | def set_transformations(self, in_shift=None, in_scale=None, out_shift=None, out_scale=None): 36 | # store native scales that can be used for resets 37 | self.transformations = dict(in_shift=in_shift, 38 | in_scale=in_scale, 39 | out_shift=out_shift, 40 | out_scale=out_scale 41 | ) 42 | self.in_shift = torch.from_numpy(np.float32(in_shift)) if in_shift is not None else torch.zeros(self.obs_dim) 43 | self.in_scale = torch.from_numpy(np.float32(in_scale)) if in_scale is not None else torch.ones(self.obs_dim) 44 | self.out_shift = torch.from_numpy(np.float32(out_shift)) if out_shift is not None else torch.zeros(self.act_dim) 45 | self.out_scale = torch.from_numpy(np.float32(out_scale)) if out_scale is not None else torch.ones(self.act_dim) 
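The shift/scale bookkeeping in `set_transformations` above amounts to standardizing inputs and de-standardizing outputs: `forward` emits a roughly unit-scale activation and maps it back into action units via `out * out_scale + out_shift`. A minimal sketch of the output side, using toy action data (not from the repo):

```python
import numpy as np

# Hypothetical demo data: 5 expert actions in a 2-D action space.
actions = np.array([[0.0, 10.0], [1.0, 12.0], [2.0, 8.0], [3.0, 11.0], [4.0, 9.0]])

# The same per-dimension statistics BC.compute_transformations produces.
out_shift, out_scale = actions.mean(axis=0), actions.std(axis=0)

# forward() rescales a unit-scale network output into action units.
unit_output = np.zeros(2)                  # what an untrained network might emit
action = unit_output * out_scale + out_shift

# A zero network output maps to the mean expert action.
assert np.allclose(action, out_shift)
```

This is why `set_variance_with_data` seeds the policy log-std from `out_scale`: both the mean head and the exploration noise start at the scale of the demonstration actions.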
46 | 47 | def forward(self, x): 48 | # Small MLP runs on CPU 49 | # Required for the way the Gaussian MLP class does weight saving and loading. 50 | if x.is_cuda: 51 | out = x.to('cpu') 52 | else: 53 | out = x 54 | 55 | ## BATCHNORM 56 | out = self.bn(out) 57 | for i in range(len(self.fc_layers)-1): 58 | out = self.fc_layers[i](out) 59 | out = self.nonlinearity(out) 60 | out = self.fc_layers[-1](out) 61 | out = out * self.out_scale + self.out_shift 62 | return out 63 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/gaussian_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | from r3meval.utils.fc_network import FCNetwork 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | 11 | class MLP: 12 | def __init__(self, env_spec, 13 | hidden_sizes=(64,64), 14 | min_log_std=-3, 15 | init_log_std=0, 16 | seed=None): 17 | """ 18 | :param env_spec: specifications of the env (see utils/gym_env.py) 19 | :param hidden_sizes: network hidden layer sizes (currently 2 layers only) 20 | :param min_log_std: log_std is clamped at this value and can't go below 21 | :param init_log_std: initial log standard deviation 22 | :param seed: random seed 23 | """ 24 | self.n = env_spec.observation_dim # number of states 25 | self.m = env_spec.action_dim # number of actions 26 | self.min_log_std = min_log_std 27 | 28 | # Set seed 29 | # ------------------------ 30 | if seed is not None: 31 | torch.manual_seed(seed) 32 | np.random.seed(seed) 33 | 34 | # Policy network 35 | # ------------------------ 36 | self.model = FCNetwork(self.n, self.m, hidden_sizes) 37 | # make weights small 38 | for param in list(self.model.parameters())[-2:]: # only last layer 
39 | param.data = 1e-2 * param.data 40 | self.log_std = Variable(torch.ones(self.m) * init_log_std, requires_grad=True) 41 | self.trainable_params = list(self.model.parameters()) + [self.log_std] 42 | 43 | # Old Policy network 44 | # ------------------------ 45 | self.old_model = FCNetwork(self.n, self.m, hidden_sizes) 46 | self.old_log_std = Variable(torch.ones(self.m) * init_log_std) 47 | self.old_params = list(self.old_model.parameters()) + [self.old_log_std] 48 | for idx, param in enumerate(self.old_params): 49 | param.data = self.trainable_params[idx].data.clone() 50 | 51 | # Easy access variables 52 | # ------------------------- 53 | self.log_std_val = np.float64(self.log_std.data.numpy().ravel()) 54 | self.param_shapes = [p.data.numpy().shape for p in self.trainable_params] 55 | self.param_sizes = [p.data.numpy().size for p in self.trainable_params] 56 | self.d = np.sum(self.param_sizes) # total number of params 57 | 58 | # Placeholders 59 | # ------------------------ 60 | self.obs_var = Variable(torch.randn(self.n), requires_grad=False) 61 | 62 | # Utility functions 63 | # ============================================ 64 | def get_param_values(self): 65 | params = np.concatenate([p.contiguous().view(-1).data.numpy() 66 | for p in self.trainable_params]) 67 | return params.copy() 68 | 69 | def set_param_values(self, new_params, set_new=True, set_old=True): 70 | if set_new: 71 | current_idx = 0 72 | for idx, param in enumerate(self.trainable_params): 73 | vals = new_params[current_idx:current_idx + self.param_sizes[idx]] 74 | vals = vals.reshape(self.param_shapes[idx]) 75 | param.data = torch.from_numpy(vals).float() 76 | current_idx += self.param_sizes[idx] 77 | # clip std at minimum value 78 | self.trainable_params[-1].data = \ 79 | torch.clamp(self.trainable_params[-1], self.min_log_std).data 80 | # update log_std_val for sampling 81 | self.log_std_val = np.float64(self.log_std.data.numpy().ravel()) 82 | if set_old: 83 | current_idx = 0 84 | for idx, param 
in enumerate(self.old_params): 85 | vals = new_params[current_idx:current_idx + self.param_sizes[idx]] 86 | vals = vals.reshape(self.param_shapes[idx]) 87 | param.data = torch.from_numpy(vals).float() 88 | current_idx += self.param_sizes[idx] 89 | # clip std at minimum value 90 | self.old_params[-1].data = \ 91 | torch.clamp(self.old_params[-1], self.min_log_std).data 92 | 93 | # Main functions 94 | # ============================================ 95 | def get_action(self, observation): 96 | o = np.float32(observation.reshape(1, -1)) 97 | self.obs_var.data = torch.from_numpy(o) 98 | mean = self.model(self.obs_var).data.numpy().ravel() 99 | noise = np.exp(self.log_std_val) * np.random.randn(self.m) 100 | action = mean + noise 101 | return [action, {'mean': mean, 'log_std': self.log_std_val, 'evaluation': mean}] 102 | 103 | def mean_LL(self, observations, actions, model=None, log_std=None): 104 | model = self.model if model is None else model 105 | log_std = self.log_std if log_std is None else log_std 106 | if type(observations) is not torch.Tensor: 107 | obs_var = Variable(torch.from_numpy(observations).float(), requires_grad=False) 108 | else: 109 | obs_var = observations 110 | if type(actions) is not torch.Tensor: 111 | act_var = Variable(torch.from_numpy(actions).float(), requires_grad=False) 112 | else: 113 | act_var = actions 114 | mean = model(obs_var) 115 | zs = (act_var - mean) / torch.exp(log_std) 116 | LL = - 0.5 * torch.sum(zs ** 2, dim=1) + \ 117 | - torch.sum(log_std) + \ 118 | - 0.5 * self.m * np.log(2 * np.pi) 119 | return mean, LL 120 | 121 | def log_likelihood(self, observations, actions, model=None, log_std=None): 122 | mean, LL = self.mean_LL(observations, actions, model, log_std) 123 | return LL.data.numpy() 124 | 125 | def old_dist_info(self, observations, actions): 126 | mean, LL = self.mean_LL(observations, actions, self.old_model, self.old_log_std) 127 | return [LL, mean, self.old_log_std] 128 | 129 | def new_dist_info(self, observations, 
actions): 130 | mean, LL = self.mean_LL(observations, actions, self.model, self.log_std) 131 | return [LL, mean, self.log_std] 132 | 133 | def likelihood_ratio(self, new_dist_info, old_dist_info): 134 | LL_old = old_dist_info[0] 135 | LL_new = new_dist_info[0] 136 | LR = torch.exp(LL_new - LL_old) 137 | return LR 138 | 139 | def mean_kl(self, new_dist_info, old_dist_info): 140 | old_log_std = old_dist_info[2] 141 | new_log_std = new_dist_info[2] 142 | old_std = torch.exp(old_log_std) 143 | new_std = torch.exp(new_log_std) 144 | old_mean = old_dist_info[1] 145 | new_mean = new_dist_info[1] 146 | Nr = (old_mean - new_mean) ** 2 + old_std ** 2 - new_std ** 2 147 | Dr = 2 * new_std ** 2 + 1e-8 148 | sample_kl = torch.sum(Nr / Dr + new_log_std - old_log_std, dim=1) 149 | return torch.mean(sample_kl) 150 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/gym_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
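The `mean_kl` expression above is the closed-form KL divergence between two diagonal Gaussians: dropping the `1e-8` stabilizer in `Dr`, `Nr / Dr + new_log_std - old_log_std` matches the textbook per-dimension formula term for term. A small NumPy check with arbitrary toy values:

```python
import numpy as np

# Toy diagonal Gaussians (hypothetical values).
old_mean, old_log_std = np.array([0.5, -0.2]), np.array([0.1, 0.3])
new_mean, new_log_std = np.array([0.4,  0.0]), np.array([0.0, 0.2])
old_std, new_std = np.exp(old_log_std), np.exp(new_log_std)

# The repo's expression (without the 1e-8 stabilizer).
Nr = (old_mean - new_mean) ** 2 + old_std ** 2 - new_std ** 2
Dr = 2 * new_std ** 2
repo_kl = np.sum(Nr / Dr + new_log_std - old_log_std)

# Textbook KL(old || new) for diagonal Gaussians, per dimension:
#   log(s_new / s_old) + (s_old^2 + (mu_old - mu_new)^2) / (2 s_new^2) - 1/2
textbook_kl = np.sum(np.log(new_std / old_std)
                     + (old_std ** 2 + (old_mean - new_mean) ** 2) / (2 * new_std ** 2)
                     - 0.5)

assert np.isclose(repo_kl, textbook_kl)
assert repo_kl >= 0.0   # KL is non-negative
```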
5 | """ 6 | Wrapper around a gym env that provides convenience functions 7 | """ 8 | 9 | import gym 10 | import numpy as np 11 | 12 | 13 | class EnvSpec(object): 14 | def __init__(self, obs_dim, act_dim, horizon): 15 | self.observation_dim = obs_dim 16 | self.action_dim = act_dim 17 | self.horizon = horizon 18 | 19 | 20 | class GymEnv(object): 21 | def __init__(self, env, env_kwargs=None, 22 | obs_mask=None, act_repeat=1, 23 | *args, **kwargs): 24 | 25 | # get the correct env behavior 26 | if type(env) == str: 27 | env = gym.make(env) 28 | elif isinstance(env, gym.Env): 29 | env = env 30 | elif callable(env): 31 | env = env(**env_kwargs) 32 | else: 33 | print("Unsupported environment format") 34 | raise AttributeError 35 | 36 | self.env = env 37 | self.env_id = env.unwrapped.spec.id 38 | self.act_repeat = act_repeat 39 | 40 | try: 41 | self._horizon = env.spec.max_episode_steps 42 | except AttributeError: 43 | self._horizon = env.spec._horizon 44 | 45 | assert self._horizon % act_repeat == 0 46 | self._horizon = self._horizon // self.act_repeat 47 | 48 | try: 49 | self._action_dim = self.env.action_space.shape[0] 50 | except AttributeError: 51 | self._action_dim = self.env.unwrapped.action_dim 52 | 53 | try: 54 | self._observation_dim = self.env.observation_space.shape[0] 55 | except AttributeError: 56 | self._observation_dim = self.env.unwrapped.obs_dim 57 | 58 | # Specs 59 | self.spec = EnvSpec(self._observation_dim, self._action_dim, self._horizon) 60 | 61 | # obs mask 62 | self.obs_mask = np.ones(self._observation_dim) if obs_mask is None else obs_mask 63 | 64 | @property 65 | def action_dim(self): 66 | return self._action_dim 67 | 68 | @property 69 | def observation_dim(self): 70 | return self._observation_dim 71 | 72 | @property 73 | def observation_space(self): 74 | return self.env.observation_space 75 | 76 | @property 77 | def action_space(self): 78 | return self.env.action_space 79 | 80 | @property 81 | def horizon(self): 82 | return self._horizon 83 | 84 
| def reset(self, seed=None): 85 | try: 86 | self.env._elapsed_steps = 0 87 | return self.env.unwrapped.reset_model(seed=seed) 88 | except: 89 | if seed is not None: 90 | self.set_seed(seed) 91 | return self.env.reset() 92 | 93 | def reset_model(self, seed=None): 94 | # overloading for legacy code 95 | return self.reset(seed) 96 | 97 | def step(self, action): 98 | action = action.clip(self.action_space.low, self.action_space.high) 99 | if self.act_repeat == 1: 100 | obs, cum_reward, done, ifo = self.env.step(action) 101 | else: 102 | cum_reward = 0.0 103 | for i in range(self.act_repeat): 104 | obs, reward, done, ifo = self.env.step(action) 105 | cum_reward += reward 106 | if done: break 107 | return self.obs_mask * obs, cum_reward, done, ifo 108 | 109 | def render(self): 110 | try: 111 | self.env.unwrapped.mujoco_render_frames = True 112 | self.env.unwrapped.mj_render() 113 | except: 114 | self.env.render() 115 | 116 | def set_seed(self, seed=123): 117 | try: 118 | self.env.seed(seed) 119 | except AttributeError: 120 | self.env._seed(seed) 121 | 122 | def get_obs(self): 123 | try: 124 | return self.obs_mask * self.env.get_obs() 125 | except: 126 | return self.obs_mask * self.env._get_obs() 127 | 128 | def get_env_infos(self): 129 | try: 130 | return self.env.unwrapped.get_env_infos() 131 | except: 132 | return {} 133 | 134 | # =========================================== 135 | # Trajectory optimization related 136 | # Envs should support these functions in case of trajopt 137 | 138 | def get_env_state(self): 139 | try: 140 | return self.env.unwrapped.get_env_state() 141 | except: 142 | raise NotImplementedError 143 | 144 | def set_env_state(self, state_dict): 145 | try: 146 | self.env.unwrapped.set_env_state(state_dict) 147 | except: 148 | self.env.unwrapped.__setstate__(state_dict) 149 | # raise NotImplementedError 150 | 151 | def real_env_step(self, bool_val): 152 | try: 153 | self.env.unwrapped.real_step = bool_val 154 | except: 155 | raise NotImplementedError 
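The `act_repeat` branch of `step` above replays the same action for several low-level steps, sums the rewards, and exits early on termination (with `_horizon` divided by `act_repeat` to compensate). A self-contained sketch of that logic with a stub environment (names here are illustrative only, not part of the repo):

```python
# Stub environment: always reward 1.0, terminates after 5 steps.
class StubEnv:
    def __init__(self):
        self.t = 0
    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 5, {}

def step_with_repeat(env, action, act_repeat):
    # Mirror of the act_repeat > 1 branch: repeat the action,
    # accumulate reward, break as soon as done is signalled.
    cum_reward = 0.0
    for _ in range(act_repeat):
        obs, reward, done, info = env.step(action)
        cum_reward += reward
        if done:
            break
    return obs, cum_reward, done, info

env = StubEnv()
obs, r, done, _ = step_with_repeat(env, action=None, act_repeat=3)
assert (obs, r, done) == (3, 3.0, False)
obs, r, done, _ = step_with_repeat(env, action=None, act_repeat=3)
assert (obs, r, done) == (5, 2.0, True)   # terminated early on done
```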
156 | 157 | # =========================================== 158 | 159 | def visualize_policy(self, policy, horizon=1000, num_episodes=1, mode='exploration'): 160 | try: 161 | self.env.unwrapped.visualize_policy(policy, horizon, num_episodes, mode) 162 | except: 163 | for ep in range(num_episodes): 164 | o = self.reset() 165 | d = False 166 | t = 0 167 | score = 0.0 168 | while t < horizon and d is False: 169 | a = policy.get_action(o)[0] if mode == 'exploration' else policy.get_action(o)[1]['evaluation'] 170 | o, r, d, _ = self.step(a) 171 | score = score + r 172 | self.render() 173 | t = t+1 174 | print("Episode score = %f" % score) 175 | 176 | def evaluate_policy(self, policy, 177 | num_episodes=5, 178 | horizon=None, 179 | gamma=1, 180 | visual=False, 181 | percentile=[], 182 | get_full_dist=False, 183 | mean_action=False, 184 | init_env_state=None, 185 | terminate_at_done=True, 186 | seed=123): 187 | 188 | self.set_seed(seed) 189 | horizon = self._horizon if horizon is None else horizon 190 | mean_eval, std, min_eval, max_eval = 0.0, 0.0, -1e8, -1e8 191 | ep_returns = np.zeros(num_episodes) 192 | 193 | for ep in range(num_episodes): 194 | self.reset() 195 | if init_env_state is not None: 196 | self.set_env_state(init_env_state) 197 | t, done = 0, False 198 | while t < horizon and (done == False or terminate_at_done == False): 199 | self.render() if visual is True else None 200 | o = self.get_obs() 201 | a = policy.get_action(o)[1]['evaluation'] if mean_action is True else policy.get_action(o)[0] 202 | o, r, done, _ = self.step(a) 203 | ep_returns[ep] += (gamma ** t) * r 204 | t += 1 205 | 206 | mean_eval, std = np.mean(ep_returns), np.std(ep_returns) 207 | min_eval, max_eval = np.amin(ep_returns), np.amax(ep_returns) 208 | base_stats = [mean_eval, std, min_eval, max_eval] 209 | 210 | percentile_stats = [] 211 | for p in percentile: 212 | percentile_stats.append(np.percentile(ep_returns, p)) 213 | 214 | full_dist = ep_returns if get_full_dist is True else None 215 
| 216 | return [base_stats, percentile_stats, full_dist] 217 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | import scipy 7 | import pickle 8 | import os 9 | import csv 10 | import wandb 11 | 12 | class DataLog: 13 | 14 | def __init__(self): 15 | self.log = {} 16 | self.max_len = 0 17 | 18 | def init_wb(self, cfg, project='r3mbc_sweep_full'): 19 | print(cfg.keys()) 20 | wandb.init(project=project, entity='burntkayl', name=cfg.job_name, reinit=True) 21 | fullcfg = {**cfg, **cfg.env_kwargs, **cfg.bc_kwargs} 22 | wandb.config.update(fullcfg) 23 | 24 | def log_kv(self, key, value): 25 | # logs the (key, value) pair 26 | 27 | # TODO: This implementation is error-prone: 28 | # it would be NOT aligned if some keys are missing during one iteration. 29 | if key not in self.log: 30 | self.log[key] = [] 31 | self.log[key].append(value) 32 | if len(self.log[key]) > self.max_len: 33 | self.max_len = self.max_len + 1 34 | 35 | def save_wb(self, step): 36 | logs = self.get_current_log() 37 | wandb.log(logs, step = step) 38 | 39 | def save_log(self, save_path): 40 | # TODO: Validate all lengths are the same. 
41 | pickle.dump(self.log, open(save_path + '/log.pickle', 'wb')) 42 | with open(save_path + '/log.csv', 'w') as csv_file: 43 | fieldnames = list(self.log.keys()) 44 | if 'iteration' not in fieldnames: 45 | fieldnames = ['iteration'] + fieldnames 46 | 47 | writer = csv.DictWriter(csv_file, fieldnames=fieldnames) 48 | writer.writeheader() 49 | for row in range(self.max_len): 50 | row_dict = {'iteration': row} 51 | for key in self.log.keys(): 52 | if row < len(self.log[key]): 53 | row_dict[key] = self.log[key][row] 54 | writer.writerow(row_dict) 55 | 56 | def get_current_log(self): 57 | row_dict = {} 58 | for key in self.log.keys(): 59 | # TODO: this is very error-prone (alignment is not guaranteed) 60 | row_dict[key] = self.log[key][-1] 61 | return row_dict 62 | 63 | def shrink_to(self, num_entries): 64 | for key in self.log.keys(): 65 | self.log[key] = self.log[key][:num_entries] 66 | 67 | self.max_len = num_entries 68 | assert min([len(series) for series in self.log.values()]) == \ 69 | max([len(series) for series in self.log.values()]) 70 | 71 | def read_log(self, log_path): 72 | assert log_path.endswith('log.csv') 73 | 74 | with open(log_path) as csv_file: 75 | reader = csv.DictReader(csv_file) 76 | listr = list(reader) 77 | keys = reader.fieldnames 78 | data = {} 79 | for key in keys: 80 | data[key] = [] 81 | for row, row_dict in enumerate(listr): 82 | for key in keys: 83 | try: 84 | data[key].append(eval(row_dict[key])) 85 | except: 86 | print("ERROR on reading key {}: {}".format(key, row_dict[key])) 87 | 88 | if 'iteration' in data and data['iteration'][-1] != row: 89 | raise RuntimeError("Iteration %d mismatch -- possibly corrupted logfile?" 
% row) 90 | 91 | self.log = data 92 | self.max_len = max(len(v) for k, v in self.log.items()) 93 | print("Log read from {}: had {} entries".format(log_path, self.max_len)) 94 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/obs_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | import gym 7 | from gym.spaces.box import Box 8 | import omegaconf 9 | import torch 10 | from torch.utils import model_zoo 11 | import torch.nn as nn 12 | from torch.nn.modules.linear import Identity 13 | import torchvision.models as models 14 | import torchvision.transforms as T 15 | from PIL import Image 16 | from pathlib import Path 17 | import pickle 18 | from torchvision.utils import save_image 19 | import hydra 20 | 21 | 22 | def init(module, weight_init, bias_init, gain=1): 23 | weight_init(module.weight.data, gain=gain) 24 | bias_init(module.bias.data) 25 | return module 26 | 27 | def _get_embedding(embedding_name='resnet34', load_path="", *args, **kwargs): 28 | if load_path == "random": 29 | prt = False 30 | else: 31 | prt = True 32 | if embedding_name == 'resnet34': 33 | model = models.resnet34(pretrained=prt, progress=False) 34 | embedding_dim = 512 35 | elif embedding_name == 'resnet18': 36 | model = models.resnet18(pretrained=prt, progress=False) 37 | embedding_dim = 512 38 | elif 'resnet50' in embedding_name: 39 | model = models.resnet50(pretrained=True, progress=False) 40 | embedding_dim = 2048 41 | else: 42 | print("Requested model not available currently") 43 | raise NotImplementedError 44 | # make FC layers to be identity 45 | # NOTE: This works for ResNet backbones but should check if same 46 | # template applies to other backbone 
architectures 47 | model.fc = Identity() 48 | model = model.eval() 49 | return model, embedding_dim 50 | 51 | def _get_shift(shift): 52 | 53 | def no_shift(img): 54 | return img 55 | 56 | def bottom_left_copy_crop(img): 57 | img[-25:,:120,:] = img[30:55,25:145,:] 58 | return img 59 | 60 | def bottom_left_red_rectangle(img): 61 | img[-25:,:120,0] = 255 # paint the red channel; previously wrote img[-25:,:120,2] = 1, duplicating bottom_left_no_blue_rectangle 62 | return img 63 | 64 | def bottom_left_white_rectangle(img): 65 | img[-25:,:120,:] = 255 66 | return img 67 | 68 | def bottom_left_no_blue_rectangle(img): 69 | img[-25:,:120,2] = 1 70 | return img 71 | 72 | def top_right_red_rectangle(img): 73 | img[:40,125:,0] = 255 74 | return img 75 | 76 | if shift == "none": 77 | return no_shift 78 | elif shift == "bottom_left_copy_crop": 79 | return bottom_left_copy_crop 80 | elif shift == "bottom_left_red_rectangle": 81 | return bottom_left_red_rectangle 82 | elif shift == "bottom_left_white_rectangle": 83 | return bottom_left_white_rectangle 84 | elif shift == "bottom_left_no_blue_rectangle": 85 | return bottom_left_no_blue_rectangle 86 | elif shift == "top_right_red_rectangle": 87 | return top_right_red_rectangle 88 | else: 89 | print("Requested shift not available currently") 90 | raise NotImplementedError 91 | 92 | 93 | class ClipEnc(nn.Module): 94 | def __init__(self, m): 95 | super().__init__() 96 | self.m = m 97 | def forward(self, im): 98 | e = self.m.encode_image(im) 99 | return e 100 | 101 | 102 | class IgnoreEnc(nn.Module): 103 | def __init__(self, m): 104 | super().__init__() 105 | self.m = m 106 | 107 | def forward(self, im): 108 | B = im.shape[0] 109 | return torch.normal(torch.zeros((B, self.m)), torch.ones(B, self.m)) 110 | 111 | class MaskVisionTransformerEnc(nn.Module): 112 | def __init__(self, vit_model): 113 | super().__init__() 114 | self.vit_model = vit_model 115 | 116 | def forward(self, im): 117 | B = im.shape[0] 118 | C = 6 119 | # fetch attention masks 120 | attn = self.vit_model.get_last_selfattention(im) 121 | masks = attn[:,:,0,:].reshape(B, C, -1)
# B, C, H*W 122 | # sum weights from "good" attention masks to reweight patches 123 | masks = masks[:,[1,3,4]].sum(1).unsqueeze(-1) 124 | masks[:,0,:] = 0 125 | masks += 1 126 | 127 | # reweight tokens by masks 128 | x = self.vit_model.prepare_tokens(im) 129 | x = x * masks 130 | for blk in self.vit_model.blocks: 131 | x = blk(x) 132 | x = self.vit_model.norm(x) 133 | return x[:, 0] 134 | 135 | class KeypointsVisionTransformerEnc(nn.Module): 136 | def __init__(self, vit_model): 137 | super().__init__() 138 | self.vit_model = vit_model 139 | self.embed_dim = 6 140 | 141 | def get_last_value(self, im): 142 | x = self.vit_model.prepare_tokens(im) 143 | for i, blk in enumerate(self.vit_model.blocks): 144 | if i < len(self.vit_model.blocks) - 1: 145 | x = blk(x) 146 | else: 147 | # apply norm to input 148 | x = blk.norm1(x) 149 | # apply attention up to value 150 | B, N, C = x.shape 151 | qkv = blk.attn.qkv(x).reshape(B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4) 152 | return qkv[2] 153 | 154 | def forward(self, im): 155 | B = im.shape[0] 156 | C = 6 157 | # fetch attention masks and values 158 | attn = self.vit_model.get_last_selfattention(im) 159 | masks = attn[:,:,0,:].reshape(B, C, -1) # B, C, H*W 160 | values = self.get_last_value(im) 161 | D = values.shape[-1] 162 | # sum weights from "good" attention masks to reweight patches 163 | # masks = masks[:,[1,3,4]].sum(1).unsqueeze(-1) 164 | # ignore CLS token 165 | masks = masks[:,:,1:] 166 | values = values[:,:,1:] 167 | # find max keypoints and index into values 168 | keypoints = masks.argmax(-1) 169 | values_flat = values.reshape(B*C, -1, D) # B*C, H*W, D 170 | kp_flat = keypoints.reshape(-1).long() 171 | values_flat = values_flat[torch.arange(B*C), kp_flat] 172 | values = values_flat.reshape(B, C, D) 173 | # normalize keypoints 174 | keypoints = (keypoints - 98) / 196 175 | # return concatenated values and keypoints 176 | return torch.cat([values, keypoints.unsqueeze(-1)],
-1).reshape(B, -1) 177 | 178 | 179 | class StateEmbedding(gym.ObservationWrapper): 180 | """ 181 | This wrapper places a convolution model over the observation. 182 | 183 | From https://pytorch.org/vision/stable/models.html 184 | All pre-trained models expect input images normalized in the same way, 185 | i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), 186 | where H and W are expected to be at least 224. 187 | 188 | Args: 189 | env (Gym environment): the original environment, 190 | embedding_name (str, 'baseline'): the name of the convolution model, 191 | device (str, 'cuda'): where to allocate the model. 192 | 193 | """ 194 | def __init__(self, env, embedding_name=None, device='cuda', load_path="", proprio=0, camera_name=None, env_name=None, shift="none"): 195 | gym.ObservationWrapper.__init__(self, env) 196 | 197 | self.proprio = proprio 198 | self.load_path = load_path 199 | self.start_finetune = False 200 | self.embedding_name = embedding_name 201 | if load_path == "clip": 202 | import clip 203 | model, cliptransforms = clip.load("RN50", device="cuda") 204 | embedding = ClipEnc(model) 205 | embedding.eval() 206 | embedding_dim = 1024 207 | self.transforms = cliptransforms 208 | elif (load_path == "random") or (load_path == "") or (embedding_name == "resnet50_insup"): 209 | embedding, embedding_dim = _get_embedding(embedding_name=embedding_name, load_path=load_path) 210 | self.transforms = T.Compose([T.Resize(256), 211 | T.CenterCrop(224), 212 | T.ToTensor(), # ToTensor() divides by 255 213 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 214 | elif "resnet50" == embedding_name: 215 | from r3m import load_r3m_reproduce 216 | rep = load_r3m_reproduce("r3m") 217 | rep.eval() 218 | embedding_dim = rep.module.outdim 219 | embedding = rep 220 | self.transforms = T.Compose([T.Resize(256), 221 | T.CenterCrop(224), 222 | T.ToTensor()]) # ToTensor() divides by 255 223 | elif "deit_s" in embedding_name: 224 | import vit_models 225 | if 
embedding_name == "deit_s_sin_dist_cls_feat": 226 | embedding = vit_models.dino_small_dist_cls_feat(patch_size=16, pretrained=False) 227 | state_dict = torch.hub.load_state_dict_from_url(url="https://github.com/Muzammal-Naseer/Intriguing-Properties-of-Vision-Transformers/releases/download/v0/deit_s_sin_dist.pth") 228 | msg = embedding.load_state_dict(state_dict["model"], strict=False) 229 | print(msg) 230 | elif embedding_name == "deit_s_sin_dist_shape_feat": 231 | embedding = vit_models.dino_small_dist_shape_feat(patch_size=16, pretrained=False) 232 | state_dict = torch.hub.load_state_dict_from_url(url="https://github.com/Muzammal-Naseer/Intriguing-Properties-of-Vision-Transformers/releases/download/v0/deit_s_sin_dist.pth") 233 | msg = embedding.load_state_dict(state_dict["model"], strict=False) 234 | print(msg) 235 | elif embedding_name == "deit_s_sin": 236 | embedding = vit_models.dino_small_feat(patch_size=16, pretrained=False) 237 | state_dict = torch.hub.load_state_dict_from_url(url="https://github.com/Muzammal-Naseer/Intriguing-Properties-of-Vision-Transformers/releases/download/v0/deit_s_sin.pth") 238 | msg = embedding.load_state_dict(state_dict["model"], strict=False) 239 | print(msg) 240 | elif embedding_name == "deit_s": 241 | embedding = vit_models.dino_small_feat(patch_size=16, pretrained=True) 242 | elif embedding_name == "deit_s_insup": 243 | embedding = vit_models.dino_small_feat(patch_size=16, pretrained=False) 244 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth") 245 | msg = embedding.load_state_dict(state_dict["model"], strict=False) 246 | print(msg) 247 | 248 | 249 | embedding.eval() 250 | embedding_dim = embedding.embed_dim 251 | self.transforms = T.Compose([T.Resize((224, 224)), 252 | T.ToTensor(), # ToTensor() divides by 255 253 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 254 | 255 | elif "resnet50_sin" == embedding_name: 256 | embedding = 
models.resnet50(pretrained=False) 257 | embedding = embedding.eval() 258 | embedding_dim = 2048 259 | checkpoint = model_zoo.load_url('https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar') 260 | model = torch.nn.DataParallel(embedding) 261 | # state dict is saved with DataParallel, this will change embedding weights 262 | model.load_state_dict(checkpoint["state_dict"]) 263 | embedding.fc = Identity() 264 | embedding = embedding.eval() 265 | self.transforms = T.Compose([ 266 | T.Resize(256), 267 | T.CenterCrop(224), 268 | T.ToTensor()]) 269 | 270 | elif "mvp" == embedding_name and "mvp" == load_path: 271 | import mvp 272 | embedding = mvp.load("vitb-mae-egosoup") 273 | embedding.eval() 274 | embedding_dim = embedding.embed_dim 275 | self.transforms = T.Compose([T.Resize(256), 276 | T.CenterCrop(224), 277 | T.ToTensor(), # ToTensor() divides by 255 278 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 279 | elif "mvp" == embedding_name and "imagenet" == load_path: 280 | import mvp 281 | embedding = mvp.load("vits-sup-in") 282 | embedding.eval() 283 | embedding_dim = embedding.embed_dim 284 | self.transforms = T.Compose([T.Resize(256), 285 | T.CenterCrop(224), 286 | T.ToTensor(), # ToTensor() divides by 255 287 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 288 | elif "pickle" in load_path and embedding_name == 'mvp': 289 | import mvp 290 | embedding = pickle.load(open(load_path, 'rb')).cuda() 291 | embedding.eval() 292 | embedding_dim = embedding.embed_dim 293 | self.transforms = T.Compose([T.Resize(256), 294 | T.CenterCrop(224), 295 | T.ToTensor(), # ToTensor() divides by 255 296 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 297 | elif "ignore_input" == load_path: 298 | self.transforms = T.Compose([T.ToTensor(),T.Resize(224)]) 299 | embedding_dim = 1024 300 | embedding = IgnoreEnc(embedding_dim) 301 | elif "pickle" in 
load_path and 'dino' in embedding_name and embedding_name != 'resnet50_dino' and embedding_name != 'mask_two_pass_dino': # TODO 302 | # get vision transformer by loading original weights 🤪 303 | embedding = torch.hub.load('facebookresearch/dino:main', 304 | 'dino_vits16') 305 | print(f"Loading model from {load_path}") 306 | embedding = pickle.load(open(load_path, 'rb')) 307 | embedding.eval() 308 | embedding_dim = embedding.embed_dim 309 | 310 | arch_args = embedding_name.split("-") 311 | if len(arch_args) > 1: 312 | state_dict = embedding.state_dict() 313 | new_bias = state_dict['blocks.11.attn.qkv.bias'].reshape((3, 6, -1)) 314 | new_weight = state_dict['blocks.11.attn.qkv.weight'].reshape((3, 6, 64, 384)) 315 | unmasked_heads = [int(um_head) for um_head in arch_args[1:]] 316 | for head in range(6): 317 | if head not in unmasked_heads: 318 | print(f"masking out {head}") 319 | # surgically remove some attention maps 320 | # zero out bias 321 | new_bias[:,head,:] = 0 322 | # zero out weight 323 | new_weight[:,head,:] = 0 324 | state_dict['blocks.11.attn.qkv.bias'] = new_bias.reshape(-1) 325 | state_dict['blocks.11.attn.qkv.weight'] = new_weight.reshape((-1,384)) 326 | embedding.load_state_dict(state_dict) 327 | 328 | self.transforms = T.Compose([T.ToTensor(), 329 | T.Resize(224), 330 | T.Normalize((0.485, 0.456, 0.406), 331 | (0.229, 0.224, 0.225))]) 332 | # elif "pickle" in load_path and embedding_name == 'resnet50': 333 | # print(f"Loading model from {load_path}") 334 | # embedding = pickle.load(open(load_path, 'rb')) 335 | # embedding.eval() 336 | # embedding_dim = embedding.module.outdim 337 | # self.transforms = T.Compose([T.Resize(256), 338 | # T.CenterCrop(224), 339 | # T.ToTensor()]) # ToTensor() divides by 255 340 | elif "dino" in embedding_name and embedding_name != 'resnet50_dino' and embedding_name != 'mask_two_pass_dino': 341 | embedding = torch.hub.load('facebookresearch/dino:main', 342 | 'dino_vits16') 343 | embedding.eval() 344 | embedding_dim = 
embedding.embed_dim 345 | if embedding_name == "dino_ensemble": 346 | num_heads = 6 347 | ensemble_weights = torch.FloatTensor(num_heads) # WARNING: Variable not added to module's parameter list 348 | ensemble_weights[:] = embedding.blocks[-1].attn.scale # TODO: add noise, should not suffer from symmetry tho 349 | ensemble_weights = ensemble_weights.reshape((1, -1, 1, 1)).cuda() # reshape to match attn to avoid broadcasting in Attention 350 | embedding.blocks[-1].attn.scale = ensemble_weights 351 | ensemble_weights.requires_grad = True 352 | 353 | arch_args = embedding_name.split("-") 354 | if len(arch_args) > 1: 355 | state_dict = embedding.state_dict() 356 | new_bias = state_dict['blocks.11.attn.qkv.bias'].reshape((3, 6, -1)) 357 | new_weight = state_dict['blocks.11.attn.qkv.weight'].reshape((3, 6, 64, 384)) 358 | unmasked_heads = [int(um_head) for um_head in arch_args[1:]] 359 | for head in range(6): 360 | if head not in unmasked_heads: 361 | print(f"masking out {head}") 362 | # surgically remove some attention maps 363 | # zero out bias 364 | new_bias[:,head,:] = 0 365 | # zero out weight 366 | new_weight[:,head,:] = 0 367 | state_dict['blocks.11.attn.qkv.bias'] = new_bias.reshape(-1) 368 | state_dict['blocks.11.attn.qkv.weight'] = new_weight.reshape((-1,384)) 369 | embedding.load_state_dict(state_dict) 370 | 371 | self.transforms = T.Compose([T.ToTensor(), 372 | T.Resize(224), 373 | T.Normalize((0.485, 0.456, 0.406), 374 | (0.229, 0.224, 0.225))]) 375 | elif embedding_name=='resnet50_dino': 376 | embedding = torch.hub.load('facebookresearch/dino:main', 377 | 'dino_resnet50') 378 | 379 | embedding.eval() 380 | embedding_dim = 2048 381 | 382 | self.transforms = T.Compose([T.Resize(256, interpolation=3), 383 | T.CenterCrop(224), 384 | T.ToTensor(), 385 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),]) 386 | # elif embedding_name=='resnet50_dino' and 'pickle' in load_path: 387 | # try: 388 | # print(f"Loading model from {load_path}, resnet50_dino") 
389 | # embedding = pickle.load(open(load_path, 'rb')) 390 | # except: 391 | # # /iris/u/kayburns/new_arch/r3m/evaluation/r3meval/core/outputs/main_sweep_1/2022-11-01_16-28-13/ 392 | # import pdb; pdb.set_trace() 393 | 394 | # embedding.eval() 395 | # embedding_dim = embedding.embed_dim 396 | 397 | # self.transforms = T.Compose([T.Resize(256, interpolation=3), 398 | # T.CenterCrop(224), 399 | # T.ToTensor(), 400 | # T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),]) 401 | elif embedding_name == 'mask_two_pass_dino': 402 | vit_model = torch.hub.load('facebookresearch/dino:main', 403 | 'dino_vits16') 404 | vit_model.eval() 405 | embedding_dim = vit_model.embed_dim 406 | embedding = MaskVisionTransformerEnc(vit_model) 407 | embedding.eval() 408 | 409 | self.transforms = T.Compose([T.ToTensor(), 410 | T.Resize(224), 411 | T.Normalize((0.485, 0.456, 0.406), 412 | (0.229, 0.224, 0.225))]) 413 | elif embedding_name == 'keypoints': 414 | import dino 415 | vit_model = torch.hub.load('facebookresearch/dino:main', 416 | 'dino_vits16') 417 | vit_model.eval() 418 | embedding_dim = 6*65 419 | embedding = KeypointsVisionTransformerEnc(vit_model) 420 | embedding.eval() 421 | 422 | self.transforms = T.Compose([T.ToTensor(), 423 | T.Resize(224), 424 | T.Normalize((0.485, 0.456, 0.406), 425 | (0.229, 0.224, 0.225))]) 426 | else: 427 | raise NameError("Invalid Model") 428 | embedding.eval() 429 | 430 | if device == 'cuda' and torch.cuda.is_available(): 431 | print('Using CUDA.') 432 | device = torch.device('cuda') 433 | else: 434 | print('Not using CUDA.') 435 | device = torch.device('cpu') 436 | self.device = device 437 | embedding.to(device=device) 438 | 439 | self.embedding, self.embedding_dim = embedding, embedding_dim 440 | self.observation_space = Box( 441 | low=-np.inf, high=np.inf, shape=(self.embedding_dim+self.proprio,)) 442 | 443 | def observation(self, observation): 444 | ### INPUT SHOULD BE [0,255] 445 | if self.embedding is not None: 446 | inp = 
self.transforms(Image.fromarray(observation.astype(np.uint8))).reshape(-1, 3, 224, 224) 447 | if 'VisionTransformer' not in type(self.embedding).__name__: 448 | print("shifting input to 0-255 (should only happen for R3M)") 449 | ## R3M Expects input to be 0-255, preprocess makes 0-1 450 | inp *= 255.0 451 | inp = inp.to(self.device) 452 | with torch.no_grad(): 453 | emb = self.embedding(inp).view(-1, self.embedding_dim).to('cpu').numpy().squeeze() 454 | 455 | ## IF proprioception add it to end of embedding 456 | if self.proprio: 457 | try: 458 | proprio = self.env.unwrapped.get_obs()[:self.proprio] 459 | except: 460 | proprio = self.env.unwrapped._get_obs()[:self.proprio] 461 | emb = np.concatenate([emb, proprio]) 462 | 463 | return emb 464 | else: 465 | return observation 466 | 467 | def encode_batch(self, obs, finetune=False): 468 | ### INPUT SHOULD BE [0,255] 469 | inp = [] 470 | for o in obs: 471 | i = self.transforms(Image.fromarray(o.astype(np.uint8))).reshape(-1, 3, 224, 224) 472 | if (self.embedding_name == 'resnet50') or (self.embedding_name == 'resnet50_insup') or (self.embedding_name == 'resnet50_dino'): # these ResNet50 variants use R3M-style [0, 255] preprocessing 473 | ## R3M Expects input to be 0-255, preprocess makes 0-1 474 | print("shifting input to 0-255 (should only happen for R3M)") 475 | i *= 255.0 476 | inp.append(i) 477 | inp = torch.cat(inp) 478 | inp = inp.to(self.device) 479 | if finetune and self.start_finetune: 480 | emb = self.embedding(inp).view(-1, self.embedding_dim) 481 | else: 482 | with torch.no_grad(): 483 | emb = self.embedding(inp).view(-1, self.embedding_dim).to('cpu').numpy().squeeze() 484 | return emb 485 | 486 | def get_obs(self): 487 | if self.embedding is not None: 488 | return self.observation(self.env.observation(None)) 489 | else: 490
| # returns the state based observations 491 | return self.env.unwrapped.get_obs() 492 | 493 | def start_finetuning(self): 494 | self.start_finetune = True 495 | 496 | 497 | class MuJoCoPixelObs(gym.ObservationWrapper): 498 | def __init__(self, env, width, height, camera_name, device_id=-1, depth=False, shift="none", *args, **kwargs): 499 | gym.ObservationWrapper.__init__(self, env) 500 | self.observation_space = Box(low=0., high=255., shape=(3, width, height)) 501 | self.width = width 502 | self.height = height 503 | self.camera_name = camera_name 504 | self.depth = depth 505 | self.device_id = device_id 506 | self.shift = _get_shift(shift) 507 | if "v2" in env.spec.id: 508 | self.get_obs = env._get_obs 509 | 510 | def get_image(self): 511 | if self.camera_name == "default": 512 | img = self.sim.render(width=self.width, height=self.height, depth=self.depth, 513 | device_id=self.device_id) 514 | else: 515 | img = self.sim.render(width=self.width, height=self.height, depth=self.depth, 516 | camera_name=self.camera_name, device_id=self.device_id) 517 | img = img[::-1,:,:] 518 | img = self.shift(img) 519 | 520 | return img 521 | 522 | def observation(self, observation): 523 | # This function creates observations based on the current state of the environment. 524 | # Argument `observation` is ignored, but `gym.ObservationWrapper` requires it. 525 | return self.get_image() 526 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import logging 6 | import numpy as np 7 | from r3meval.utils.gym_env import GymEnv 8 | from r3meval.utils import tensor_utils 9 | logging.disable(logging.CRITICAL) 10 | import multiprocessing as mp 11 | import time as timer 12 | logging.disable(logging.CRITICAL) 13 | import gc 14 | from collections import namedtuple 15 | 16 | from metaworld.envs import (ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, 17 | ALL_V2_ENVIRONMENTS_GOAL_HIDDEN) 18 | 19 | 20 | # Single core rollout to sample trajectories 21 | # ======================================================= 22 | def do_rollout( 23 | num_traj, 24 | env, 25 | policy, 26 | eval_mode = False, 27 | horizon = 1e6, 28 | base_seed = None, 29 | env_kwargs=None, 30 | ): 31 | """ 32 | :param num_traj: number of trajectories (int) 33 | :param env: environment (env class, str with env_name, or factory function) 34 | :param policy: policy to use for action selection 35 | :param eval_mode: use evaluation mode for action computation (bool) 36 | :param horizon: max horizon length for rollout (<= env.horizon) 37 | :param base_seed: base seed for rollouts (int) 38 | :param env_kwargs: dictionary with parameters, will be passed to env generator 39 | :return: 40 | """ 41 | # get the correct env behavior 42 | print("Evaluating") 43 | if type(env) == str: 44 | ## MetaWorld specific stuff 45 | if "v2" in env: 46 | env_name = env 47 | env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name]() 48 | env._freeze_rand_vec =False 49 | env.horizon = 500 50 | env.spec = namedtuple('spec', ['id', 'max_episode_steps', 'observation_dim', 'action_dim']) 51 | env.spec.id = env_name 52 | env.spec.observation_dim = int(env.observation_space.shape[0]) 53 | env.spec.action_dim = int(env.action_space.shape[0]) 54 | env.spec.max_episode_steps = 500 55 | else: 56 | env = GymEnv(env) 57 | elif isinstance(env, GymEnv): 58 | env = env 59 | elif callable(env): 60 | env = env(**env_kwargs) 61 | else: 62 | # print("Unsupported environment format") 63 | # raise AttributeError 
64 | ## Support passing in one env for everything 65 | env = env 66 | 67 | if base_seed is not None: 68 | try: 69 | env.set_seed(base_seed) 70 | except: 71 | env.seed(base_seed) 72 | np.random.seed(base_seed) 73 | else: 74 | np.random.seed() 75 | # horizon = min(horizon, env.horizon) 76 | paths = [] 77 | 78 | ep = 0 79 | while ep < num_traj: 80 | # seeding 81 | if base_seed is not None: 82 | seed = base_seed + ep 83 | try: 84 | env.set_seed(seed) 85 | except: 86 | env.seed(seed) 87 | np.random.seed(seed) 88 | 89 | observations=[] 90 | actions=[] 91 | rewards=[] 92 | agent_infos = [] 93 | env_infos = [] 94 | 95 | o = env.reset() 96 | done = False 97 | t = 0 98 | ims = [] 99 | try: 100 | ims.append(env.env.env.get_image()) 101 | except: 102 | ## For state based learning 103 | pass 104 | 105 | ## MetaWorld vs. Adroit/Kitchen syntax 106 | try: 107 | init_state = env.__getstate__() 108 | except: 109 | init_state = env.get_env_state() 110 | 111 | while t < horizon and done != True: 112 | a, agent_info = policy.get_action(o) 113 | if eval_mode: 114 | a = agent_info['evaluation'] 115 | 116 | next_o, r, done, env_info_step = env.step(a) 117 | env_info = env_info_step #if env_info_base == {} else env_info_base 118 | observations.append(o) 119 | actions.append(a) 120 | rewards.append(r) 121 | try: 122 | ims.append(env.env.env.get_image()) 123 | except: 124 | pass 125 | agent_infos.append(agent_info) 126 | env_infos.append(env_info) 127 | o = next_o 128 | t += 1 129 | 130 | path = dict( 131 | observations=np.array(observations), 132 | actions=np.array(actions), 133 | rewards=np.array(rewards), 134 | agent_infos=tensor_utils.stack_tensor_dict_list(agent_infos), 135 | env_infos=tensor_utils.stack_tensor_dict_list(env_infos), 136 | terminated=done, 137 | init_state = init_state, 138 | images=ims 139 | ) 140 | 141 | paths.append(path) 142 | ep += 1 143 | 144 | del(env) 145 | gc.collect() 146 | return paths 147 | 148 | 149 | def sample_paths( 150 | num_traj, 151 | env, 152 | 
policy, 153 | eval_mode = False, 154 | horizon = 1e6, 155 | base_seed = None, 156 | num_cpu = 1, 157 | max_process_time=300, 158 | max_timeouts=4, 159 | suppress_print=False, 160 | env_kwargs=None, 161 | ): 162 | 163 | num_cpu = 1 if num_cpu is None else num_cpu 164 | num_cpu = mp.cpu_count() if num_cpu == 'max' else num_cpu 165 | assert type(num_cpu) == int 166 | 167 | if num_cpu == 1: 168 | input_dict = dict(num_traj=num_traj, env=env, policy=policy, 169 | eval_mode=eval_mode, horizon=horizon, base_seed=base_seed, 170 | env_kwargs=env_kwargs) 171 | # dont invoke multiprocessing if not necessary 172 | return do_rollout(**input_dict) 173 | 174 | # do multiprocessing otherwise 175 | paths_per_cpu = int(np.ceil(num_traj/num_cpu)) 176 | input_dict_list= [] 177 | for i in range(num_cpu): 178 | input_dict = dict(num_traj=paths_per_cpu, env=env, policy=policy, 179 | eval_mode=eval_mode, horizon=horizon, 180 | base_seed=base_seed + i * paths_per_cpu, 181 | env_kwargs=env_kwargs) 182 | input_dict_list.append(input_dict) 183 | if suppress_print is False: 184 | start_time = timer.time() 185 | print("####### Gathering Samples #######") 186 | 187 | results = _try_multiprocess(do_rollout, input_dict_list, 188 | num_cpu, max_process_time, max_timeouts) 189 | paths = [] 190 | # result is a paths type and results is list of paths 191 | for result in results: 192 | for path in result: 193 | paths.append(path) 194 | 195 | if suppress_print is False: 196 | print("======= Samples Gathered ======= | >>>> Time taken = %f " %(timer.time()-start_time) ) 197 | 198 | return paths 199 | 200 | def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts): 201 | 202 | # Base case 203 | if max_timeouts == 0: 204 | return None 205 | 206 | pool = mp.Pool(processes=num_cpu, maxtasksperchild=None) 207 | parallel_runs = [pool.apply_async(func, kwds=input_dict) for input_dict in input_dict_list] 208 | try: 209 | results = [p.get(timeout=max_process_time) for p in 
parallel_runs] 210 | except Exception as e: 211 | print(str(e)) 212 | print("Timeout Error raised... Trying again") 213 | pool.close() 214 | pool.terminate() 215 | pool.join() 216 | return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1) 217 | 218 | pool.close() 219 | pool.terminate() 220 | pool.join() 221 | return results 222 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import operator 6 | 7 | import numpy as np 8 | 9 | 10 | def flatten_tensors(tensors): 11 | if len(tensors) > 0: 12 | return np.concatenate([np.reshape(x, [-1]) for x in tensors]) 13 | else: 14 | return np.asarray([]) 15 | 16 | 17 | def unflatten_tensors(flattened, tensor_shapes): 18 | tensor_sizes = list(map(np.prod, tensor_shapes)) 19 | indices = np.cumsum(tensor_sizes)[:-1] 20 | return [np.reshape(pair[0], pair[1]) for pair in zip(np.split(flattened, indices), tensor_shapes)] 21 | 22 | 23 | def pad_tensor(x, max_len, mode='zero'): 24 | padding = np.zeros_like(x[0]) 25 | if mode == 'last': 26 | padding = x[-1] 27 | return np.concatenate([ 28 | x, 29 | np.tile(padding, (max_len - len(x),) + (1,) * np.ndim(x[0])) 30 | ]) 31 | 32 | 33 | def pad_tensor_n(xs, max_len): 34 | ret = np.zeros((len(xs), max_len) + xs[0].shape[1:], dtype=xs[0].dtype) 35 | for idx, x in enumerate(xs): 36 | ret[idx][:len(x)] = x 37 | return ret 38 | 39 | 40 | def pad_tensor_dict(tensor_dict, max_len, mode='zero'): 41 | keys = list(tensor_dict.keys()) 42 | ret = dict() 43 | for k in keys: 44 | if isinstance(tensor_dict[k], dict): 45 | ret[k] = pad_tensor_dict(tensor_dict[k], max_len, mode=mode) 46 | else: 47 | ret[k] = 
pad_tensor(tensor_dict[k], max_len, mode=mode) 48 | return ret 49 | 50 | 51 | def flatten_first_axis_tensor_dict(tensor_dict): 52 | keys = list(tensor_dict.keys()) 53 | ret = dict() 54 | for k in keys: 55 | if isinstance(tensor_dict[k], dict): 56 | ret[k] = flatten_first_axis_tensor_dict(tensor_dict[k]) 57 | else: 58 | old_shape = tensor_dict[k].shape 59 | ret[k] = tensor_dict[k].reshape((-1,) + old_shape[2:]) 60 | return ret 61 | 62 | 63 | def high_res_normalize(probs): 64 | return [x / sum(map(float, probs)) for x in list(map(float, probs))] 65 | 66 | 67 | def stack_tensor_list(tensor_list): 68 | return np.array(tensor_list) 69 | # tensor_shape = np.array(tensor_list[0]).shape 70 | # if tensor_shape is tuple(): 71 | # return np.array(tensor_list) 72 | # return np.vstack(tensor_list) 73 | 74 | 75 | def stack_tensor_dict_list(tensor_dict_list): 76 | """ 77 | Stack a list of dictionaries of {tensors or dictionary of tensors}. 78 | :param tensor_dict_list: a list of dictionaries of {tensors or dictionary of tensors}. 
79 | :return: a dictionary of {stacked tensors or dictionary of stacked tensors} 80 | """ 81 | keys = list(tensor_dict_list[0].keys()) 82 | ret = dict() 83 | for k in keys: 84 | example = tensor_dict_list[0][k] 85 | if isinstance(example, dict): 86 | v = stack_tensor_dict_list([x[k] for x in tensor_dict_list]) 87 | else: 88 | v = stack_tensor_list([x[k] for x in tensor_dict_list]) 89 | ret[k] = v 90 | return ret 91 | 92 | 93 | def concat_tensor_list_subsample(tensor_list, f): 94 | return np.concatenate( 95 | [t[np.random.choice(len(t), int(np.ceil(len(t) * f)), replace=False)] for t in tensor_list], axis=0) 96 | 97 | 98 | def concat_tensor_dict_list_subsample(tensor_dict_list, f): 99 | keys = list(tensor_dict_list[0].keys()) 100 | ret = dict() 101 | for k in keys: 102 | example = tensor_dict_list[0][k] 103 | if isinstance(example, dict): 104 | v = concat_tensor_dict_list_subsample([x[k] for x in tensor_dict_list], f) 105 | else: 106 | v = concat_tensor_list_subsample([x[k] for x in tensor_dict_list], f) 107 | ret[k] = v 108 | return ret 109 | 110 | 111 | def concat_tensor_list(tensor_list): 112 | return np.concatenate(tensor_list, axis=0) 113 | 114 | 115 | def concat_tensor_dict_list(tensor_dict_list): 116 | keys = list(tensor_dict_list[0].keys()) 117 | ret = dict() 118 | for k in keys: 119 | example = tensor_dict_list[0][k] 120 | if isinstance(example, dict): 121 | v = concat_tensor_dict_list([x[k] for x in tensor_dict_list]) 122 | else: 123 | v = concat_tensor_list([x[k] for x in tensor_dict_list]) 124 | ret[k] = v 125 | return ret 126 | 127 | 128 | def split_tensor_dict_list(tensor_dict): 129 | keys = list(tensor_dict.keys()) 130 | ret = None 131 | for k in keys: 132 | vals = tensor_dict[k] 133 | if isinstance(vals, dict): 134 | vals = split_tensor_dict_list(vals) 135 | if ret is None: 136 | ret = [{k: v} for v in vals] 137 | else: 138 | for v, cur_dict in zip(vals, ret): 139 | cur_dict[k] = v 140 | return ret 141 | 142 | 143 | def 
truncate_tensor_list(tensor_list, truncated_len): 144 | return tensor_list[:truncated_len] 145 | 146 | 147 | def truncate_tensor_dict(tensor_dict, truncated_len): 148 | ret = dict() 149 | for k, v in tensor_dict.items(): 150 | if isinstance(v, dict): 151 | ret[k] = truncate_tensor_dict(v, truncated_len) 152 | else: 153 | ret[k] = truncate_tensor_list(v, truncated_len) 154 | return ret 155 | -------------------------------------------------------------------------------- /evaluation/r3meval/utils/visualizations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torchvision import transforms as T 9 | 10 | import mvp 11 | 12 | def place_attention_heatmap_over_images(images, vision_model, model_type, head=1): 13 | 14 | H, W = 224, 224 15 | patch_size = 16 16 | alpha = .4 17 | new_H, new_W = H//patch_size, W//patch_size 18 | 19 | transforms = T.Compose([T.ToTensor(), 20 | T.Resize(H), 21 | T.Normalize((0.485, 0.456, 0.406), 22 | (0.229, 0.224, 0.225))]) 23 | 24 | cmap = plt.get_cmap('jet') 25 | 26 | 27 | heatmap_images = [] 28 | for image in images: # can vectorize? 
29 | 30 | # resize and normalize image to feed into model 31 | image = image.copy() 32 | torch_image = transforms(image) 33 | 34 | # grab the output attention map at the desired attention head 35 | if model_type == 'mvp': 36 | attn = vision_model.forward_attention(torch_image.unsqueeze(0).to('cuda'), layer=11) 37 | elif model_type == 'dino': 38 | attn = vision_model.get_last_selfattention(torch_image.unsqueeze(0).to('cuda')) 39 | else: 40 | raise ValueError(f'Visualization with {model_type} not supported') 41 | attn_map = attn[0,head,0,1:].reshape(1, 1, new_H, new_W) # B, C, H, W 42 | 43 | # interpolate smoothly to create a heatmap 44 | resized_attn_map = F.interpolate(attn_map, scale_factor=patch_size, 45 | mode='bilinear') 46 | resized_attn_map = resized_attn_map.cpu().detach().numpy().squeeze() 47 | 48 | # convert attention scores to heatmap 49 | image = cv2.resize(image, (W, H)) 50 | heatmap = cmap(resized_attn_map/resized_attn_map.max()) 51 | heatmap *= 255 52 | heatmap = heatmap[:,:,:3] 53 | heatmap_image = (.8*image + .2*heatmap).astype(int) 54 | heatmap_image = np.clip(heatmap_image, 0, 255) 55 | heatmap_images.append(heatmap_image) 56 | 57 | return heatmap_images 58 | 59 | -------------------------------------------------------------------------------- /evaluation/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=iris-hi 3 | #SBATCH --mem=32G 4 | #SBATCH --gres=gpu:1 5 | #SBATCH --exclude=iris4,iris2,iris-hp-z8 6 | #SBATCH --job-name="fancy new architecture" 7 | #SBATCH --time=3-0:0 8 | 9 | export ENV_NAME=${1} 10 | export SEED=${2} 11 | export CAM_NAME=${3} 12 | export EMB_NAME=${4} 13 | export LOAD_PATH=${5} 14 | export NUM_DEMOS=10 15 | 16 | 17 | if [[ "${1}" == *"v2"* ]]; then 18 | echo "Using proprio=4 for Meta-World environment." 19 | export PROPRIO=4 20 | else 21 | echo "Using proprio=9 for FrankaKitchen environment." 
22 | export PROPRIO=9 23 | fi 24 | 25 | source /sailhome/kayburns/.bashrc 26 | conda activate py3.8_torch1.10.1 27 | cd /iris/u/kayburns/new_arch/r3m/evaluation/ 28 | python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 29 | env=${ENV_NAME} camera=${CAM_NAME} pixel_based=true \ 30 | embedding=${EMB_NAME} num_demos=${NUM_DEMOS} env_kwargs.load_path=${LOAD_PATH} \ 31 | bc_kwargs.finetune=false proprio=${PROPRIO} job_name=r3m_repro seed=${SEED} 32 | 33 | # fine-tune all heads, last layer 34 | # cd /iris/u/kayburns/new_arch/r3m/evaluation/ 35 | # python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 36 | # env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true \ 37 | # embedding=dino num_demos=5 env_kwargs.load_path=dino \ 38 | # bc_kwargs.finetune=true proprio=9 job_name=r3m_repro_all seed=123 39 | 40 | # blind baseline 41 | # python r3meval/core/hydra_launcher.py hydra/launcher=local hydra/output=local \ 42 | # env="kitchen_sdoor_open-v3" camera="left_cap2" pixel_based=true \ 43 | # embedding=ignore_input num_demos=5 env_kwargs.load_path=ignore_input \ 44 | # bc_kwargs.finetune=false proprio=9 job_name=r3m_repro_random seed=125 -------------------------------------------------------------------------------- /evaluation/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import os 6 | import sys 7 | from setuptools import setup, find_packages 8 | 9 | if sys.version_info.major != 3: 10 | print("This Python is only compatible with Python 3, but you are running " 11 | "Python {}. 
The installation will likely fail.".format(sys.version_info.major)) 12 | 13 | def read(fname): 14 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 15 | 16 | setup( 17 | name='r3meval', 18 | description='Policy learning from pixels using pre-trained representations, used in the R3M paper', 19 | install_requires=[ 20 | ], 21 | ) 22 | -------------------------------------------------------------------------------- /plots/mvp_dino_r3m.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas 5 | 6 | hex_colors = [ 7 | "#FF6150", 8 | "#134E6F", 9 | "#1AC0C6", 10 | "#FFA822", 11 | "#DEE0E6", 12 | "#091A29" 13 | ] 14 | 15 | lighter_hex_colors = [ 16 | "#ff8f82", 17 | "#4d83a3", 18 | "#92d4d6", 19 | "#FFA822", 20 | "#DEE0E6", 21 | "#091A29" 22 | ] 23 | 24 | noft = pandas.read_csv('./run_data/mvp_dino_r3m_no_default.csv') 25 | noft = pandas.read_csv('./run_data/mvp_dino_r3m_metaworld.csv') 26 | ft = pandas.read_csv('./run_data/dino_r3m_ft_no_default.csv') 27 | mvp_ft = pandas.read_csv('./run_data/mvp_ft_sdoor_no_default.csv') 28 | mvp_ft_others = pandas.read_csv('./run_data/mvp_ft_other_tasks_no_default.csv') 29 | mvp_ft = pandas.concat([mvp_ft, mvp_ft_others]) 30 | 31 | # lighting_columns = ['eval_successbrighter', 'eval_successdarker', 'eval_successleft', 'eval_successright'] 32 | # texture_columns = ['eval_successmetal2', 'eval_successtile1', 'eval_successwood2'] 33 | # distractor_columns = ['eval_successbox', 'eval_successmedium', 'eval_successhard'] 34 | 35 | lighting_columns = ['eval_successbrighter', 'eval_successdarker', 'eval_successleft', 'eval_successright'] 36 | texture_columns = ['eval_successblue-woodtable', 'eval_successdark-woodtable', 'eval_successdarkwoodtable'] 37 | distractor_columns = ['eval_successeasy', 'eval_successmedium', 'eval_successhard'] 38 | all_dist_shift_columns = lighting_columns + texture_columns 
+ distractor_columns 39 | 40 | # add lighting, texture, distractor, and all average test columns 41 | noft['Train Dist.'] = noft['eval_success'] 42 | noft['Test (Lighting)'] = noft[lighting_columns].mean(axis=1) 43 | noft['Test (Texture)'] = noft[texture_columns].mean(axis=1) 44 | noft['Test (Distractors)'] = noft[distractor_columns].mean(axis=1) 45 | noft['Test (Avg)'] = noft[all_dist_shift_columns].mean(axis=1) 46 | # ft['Train Dist.'] = ft['eval_success'] 47 | # ft['Test (Lighting)'] = ft[lighting_columns].mean(axis=1) 48 | # ft['Test (Texture)'] = ft[texture_columns].mean(axis=1) 49 | # ft['Test (Distractors)'] = ft[distractor_columns].mean(axis=1) 50 | # ft['Test (Avg)'] = ft[all_dist_shift_columns].mean(axis=1) 51 | mvp_ft['Train Dist.'] = mvp_ft['eval_success'] 52 | mvp_ft['Test (Lighting)'] = mvp_ft[lighting_columns].mean(axis=1) 53 | mvp_ft['Test (Texture)'] = mvp_ft[texture_columns].mean(axis=1) 54 | mvp_ft['Test (Distractors)'] = mvp_ft[distractor_columns].mean(axis=1) 55 | mvp_ft['Test (Avg)'] = mvp_ft[all_dist_shift_columns].mean(axis=1) 56 | 57 | # aggregate results by seed 58 | mvp_noft = noft[noft['embedding'] == 'mvp'].groupby('seed').mean() 59 | r3m_noft = noft[noft['embedding'] == 'resnet50'].groupby('seed').mean() 60 | # r3m_ft = ft[ft['embedding'] == 'resnet50'].groupby('seed').mean() 61 | dino_noft = noft[noft['embedding'] == 'dino'].groupby('seed').mean() 62 | # dino_ft = ft[ft['embedding'] == 'dino'].groupby('seed').mean() 63 | # mvp_ft = mvp_ft.groupby('seed').mean() 64 | # # plot standard error by seed 65 | mvp_noft_mean = mvp_noft.mean() 66 | mvp_noft_err = (mvp_noft.std()*1.96) / np.sqrt(3) 67 | # mvp_ft_mean = mvp_ft.mean() 68 | # mvp_ft_err = (mvp_ft.std()*1.96) / np.sqrt(3) 69 | r3m_noft_mean = r3m_noft.mean() 70 | r3m_noft_err = (r3m_noft.std()*1.96) / np.sqrt(3) 71 | # r3m_ft_mean = r3m_ft.mean() 72 | # r3m_ft_err = (r3m_ft.std()*1.96) / np.sqrt(3) 73 | dino_noft_mean = dino_noft.mean() 74 | dino_noft_err = (dino_noft.std()*1.96) / 
np.sqrt(3) 75 | # dino_ft_mean = dino_ft.mean() 76 | # dino_ft_err = (dino_ft.std()*1.96) / np.sqrt(3) 77 | 78 | # make bar chart (mvp vs dino vs r3m figure) 79 | # labels = ['MVP', 'MVP (FT)', 'R3M', 'R3M (FT)', 'DiNo', 'DiNo (FT)'] 80 | labels = ['Train Dist.', 'Test (Lighting)', 'Test (Texture)', 'Test (Distractors)', 'Test (Avg)'] 81 | labels_pretty = ['Train Dist.', 'Lighting', 'Texture', 'Distractors', 'Zero-Shot Avg'] 82 | mvp_noft_means = [mvp_noft_mean[label] for label in labels] 83 | # mvp_ft_means = [mvp_ft_mean[label] for label in labels] 84 | r3m_noft_means = [r3m_noft_mean[label] for label in labels] 85 | # r3m_ft_means = [r3m_ft_mean[label] for label in labels] 86 | dino_noft_means = [dino_noft_mean[label] for label in labels] 87 | # dino_ft_means = [dino_ft_mean[label] for label in labels] 88 | mvp_noft_errs = [mvp_noft_err[label] for label in labels] 89 | # mvp_ft_errs = [mvp_ft_err[label] for label in labels] 90 | r3m_noft_errs = [r3m_noft_err[label] for label in labels] 91 | # r3m_ft_errs = [r3m_ft_err[label] for label in labels] 92 | dino_noft_errs = [dino_noft_err[label] for label in labels] 93 | # dino_ft_errs = [dino_ft_err[label] for label in labels] 94 | 95 | # ft and noft 96 | # def plot_ft_and_noft(): 97 | # x = np.arange(len(labels))*2.5 # the label locations 98 | # width = 0.35 # the width of the bars 99 | # delta = .01 100 | 101 | # fig, ax = plt.subplots(figsize=(20,6)) 102 | # rects1 = ax.bar(x - width*2.5 - delta, mvp_noft_means, width, yerr=mvp_noft_errs, label='MVP', color=lighter_hex_colors[2]) 103 | # rects2 = ax.bar(x - width*1.5 - delta, r3m_noft_means, width, yerr=r3m_noft_errs, label='R3M', color=lighter_hex_colors[1]) 104 | # rects3 = ax.bar(x - width/2 - delta, dino_noft_means, width, yerr=dino_noft_errs, label='DiNo', color=lighter_hex_colors[0]) 105 | # rects4 = ax.bar(x + width/2 + delta, r3m_ft_means, width, yerr=r3m_ft_errs, label='MVP (FT)', color=hex_colors[2]) 106 | # rects5 = ax.bar(x + width*1.5 + delta, 
r3m_ft_means, width, yerr=r3m_ft_errs, label='R3M (FT)', color=hex_colors[1]) 107 | # rects6 = ax.bar(x + width*2.5 + delta, dino_ft_means, width, yerr=dino_ft_errs, label='DiNo (FT)', color=hex_colors[0]) 108 | 109 | # # Add some text for labels, title and custom x-axis tick labels, etc. 110 | # ax.set_ylabel('Success', fontsize=18) 111 | # # ax.set_title('Zero-Shot Transfer Performance', fontsize=18) 112 | # ax.set_xticks(x, labels_pretty, fontsize=18) 113 | # plt.axvline(x=1.25, color='black', linestyle='--') 114 | # ax.legend(fontsize=18, ncol=2) 115 | # ax.spines[['right', 'top']].set_visible(False) 116 | 117 | # # ax.bar_label(rects4, padding=3) 118 | # # ax.bar_label(rects5, padding=3) 119 | # # ax.bar_label(rects6, padding=3) 120 | 121 | # fig.tight_layout() 122 | # plt.savefig('mvp_dino_r3m_ft_and_noft.png') 123 | # plot_ft_and_noft() 124 | 125 | 126 | # mvp_noft_means, mvp_noft_errs = mvp_ft_means, mvp_ft_errs 127 | # r3m_noft_means, r3m_noft_errs = r3m_ft_means, r3m_ft_errs 128 | # dino_noft_means, dino_noft_errs = dino_ft_means, dino_ft_errs 129 | # just noft 130 | def plot_just_noft(): 131 | matplotlib.rcParams.update({'font.size': 22}) 132 | x = np.arange(len(labels))*1.2 # the label locations 133 | width = 0.3 # the width of the bars 134 | delta = .01 135 | 136 | print(mvp_noft_means[-1], r3m_noft_means[-1], dino_noft_means[-1]) 137 | fig, ax = plt.subplots(figsize=(12,6)) 138 | rects1 = ax.bar(x - width, mvp_noft_means, width, yerr=mvp_noft_errs, label='MVP', color=hex_colors[2]) 139 | rects2 = ax.bar(x, r3m_noft_means, width, yerr=r3m_noft_errs, label='R3M', color=hex_colors[1]) 140 | rects3 = ax.bar(x + width, dino_noft_means, width, yerr=dino_noft_errs, label='DiNo', color=hex_colors[0]) 141 | 142 | # Add some text for labels, title and custom x-axis tick labels, etc. 
143 | ax.set_ylabel('Success', fontsize=22) 144 | ax.set_title('Meta-World', fontsize=22) 145 | ax.set_xticks(x, labels_pretty, fontsize=22) 146 | plt.axvline(x=.6, color='black', linestyle='--') 147 | plt.axvline(x=4.2, color='gray') 148 | ax.legend(fontsize=22) 149 | ax.spines[['right', 'top']].set_visible(False) 150 | 151 | # ax.bar_label(rects4, padding=3) 152 | # ax.bar_label(rects5, padding=3) 153 | # ax.bar_label(rects6, padding=3) 154 | 155 | fig.tight_layout() 156 | plt.savefig('mvp_dino_r3m_metaworld.png') 157 | plot_just_noft() 158 | -------------------------------------------------------------------------------- /r3m/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from r3m.models.models_r3m import R3M 6 | 7 | import os 8 | from os.path import expanduser 9 | import omegaconf 10 | import hydra 11 | import gdown 12 | import torch 13 | import copy 14 | 15 | VALID_ARGS = ["_target_", "device", "lr", "hidden_dim", "size", "l2weight", "l1weight", "langweight", "tcnweight", "l2dist", "bs"] 16 | if torch.cuda.is_available(): 17 | device = "cuda" 18 | else: 19 | device = "cpu" 20 | 21 | def cleanup_config(cfg): 22 | config = copy.deepcopy(cfg) 23 | keys = config.agent.keys() 24 | for key in list(keys): 25 | if key not in VALID_ARGS: 26 | del config.agent[key] 27 | config.agent["_target_"] = "r3m.R3M" 28 | config["device"] = device 29 | 30 | ## Hardcodes to remove the language head 31 | ## Assumes downstream use is as visual representation 32 | config.agent["langweight"] = 0 33 | return config.agent 34 | 35 | def remove_language_head(state_dict): 36 | keys = state_dict.keys() 37 | ## Hardcodes to remove the language head 38 | ## Assumes downstream use is as visual representation 39 | for key in 
list(keys): 40 | if ("lang_enc" in key) or ("lang_rew" in key): 41 | del state_dict[key] 42 | return state_dict 43 | 44 | def load_r3m(modelid): 45 | home = os.path.join(expanduser("~"), ".r3m") 46 | if modelid == "resnet50": 47 | foldername = "r3m_50" 48 | modelurl = 'https://drive.google.com/uc?id=1Xu0ssuG0N1zjZS54wmWzJ7-nb0-7XzbA' 49 | configurl = 'https://drive.google.com/uc?id=10jY2VxrrhfOdNPmsFdES568hjjIoBJx8' 50 | elif modelid == "resnet34": 51 | foldername = "r3m_34" 52 | modelurl = 'https://drive.google.com/uc?id=15bXD3QRhspIRacOKyWPw5y2HpoWUCEnE' 53 | configurl = 'https://drive.google.com/uc?id=1RY0NS-Tl4G7M1Ik_lOym0b5VIBxX9dqW' 54 | elif modelid == "resnet18": 55 | foldername = "r3m_18" 56 | modelurl = 'https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-' 57 | configurl = 'https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6' 58 | else: 59 | raise NameError('Invalid Model ID') 60 | 61 | if not os.path.exists(os.path.join(home, foldername)): 62 | os.makedirs(os.path.join(home, foldername)) 63 | modelpath = os.path.join(home, foldername, "model.pt") 64 | configpath = os.path.join(home, foldername, "config.yaml") 65 | if not os.path.exists(modelpath): 66 | gdown.download(modelurl, modelpath, quiet=False) 67 | gdown.download(configurl, configpath, quiet=False) 68 | 69 | modelcfg = omegaconf.OmegaConf.load(configpath) 70 | cleancfg = cleanup_config(modelcfg) 71 | rep = hydra.utils.instantiate(cleancfg) 72 | rep = torch.nn.DataParallel(rep) 73 | r3m_state_dict = remove_language_head(torch.load(modelpath, map_location=torch.device(device))['r3m']) 74 | rep.load_state_dict(r3m_state_dict) 75 | return rep 76 | 77 | def load_r3m_reproduce(modelid): 78 | home = os.path.join(expanduser("~"), ".r3m") 79 | if modelid == "r3m": 80 | foldername = "original_r3m" 81 | modelurl = 'https://drive.google.com/uc?id=1jLb1yldIMfAcGVwYojSQmMpmRM7vqjp9' 82 | configurl = 'https://drive.google.com/uc?id=1cu-Pb33qcfAieRIUptNlG1AQIMZlAI-q' 83 | elif modelid == 
"r3m_noaug": 84 | foldername = "original_r3m_noaug" 85 | modelurl = 'https://drive.google.com/uc?id=1k_ZlVtvlktoYLtBcfD0aVFnrZcyCNS9D' 86 | configurl = 'https://drive.google.com/uc?id=1hPmJwDiWPkd6GGez6ywSC7UOTIX7NgeS' 87 | elif modelif == "r3m_nol1": 88 | foldername = "original_r3m_nol1" 89 | modelurl = 'https://drive.google.com/uc?id=1LpW3aBMdjoXsjYlkaDnvwx7q22myM_nB' 90 | configurl = 'https://drive.google.com/uc?id=1rZUBrYJZvlF1ReFwRidZsH7-xe7csvab' 91 | elif modelif == "r3m_nolang": 92 | foldername = "original_r3m_nolang" 93 | modelurl = 'https://drive.google.com/uc?id=1FXcniRei2JDaGMJJ_KlVxHaLy0Fs_caV' 94 | configurl = 'https://drive.google.com/uc?id=192G4UkcNJO4EKN46ECujMcH0AQVhnyQe' 95 | else: 96 | raise NameError('Invalid Model ID') 97 | 98 | if not os.path.exists(os.path.join(home, foldername)): 99 | os.makedirs(os.path.join(home, foldername)) 100 | modelpath = os.path.join(home, foldername, "model.pt") 101 | configpath = os.path.join(home, foldername, "config.yaml") 102 | if not os.path.exists(modelpath): 103 | gdown.download(modelurl, modelpath, quiet=False) 104 | gdown.download(configurl, configpath, quiet=False) 105 | 106 | modelcfg = omegaconf.OmegaConf.load(configpath) 107 | cleancfg = cleanup_config(modelcfg) 108 | rep = hydra.utils.instantiate(cleancfg) 109 | rep = torch.nn.DataParallel(rep) 110 | r3m_state_dict = remove_language_head(torch.load(modelpath, map_location=torch.device(device))['r3m']) 111 | 112 | rep.load_state_dict(r3m_state_dict) 113 | return rep 114 | -------------------------------------------------------------------------------- /r3m/cfgs/config_rep.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: local 4 | - override hydra/output: local 5 | 6 | 7 | # snapshot 8 | save_snapshot: false 9 | load_snap: "" 10 | # replay buffer 11 | num_workers: 10 12 | batch_size: 32 #256 13 | train_steps: 2000000 14 | eval_freq: 20000 15 | # misc 16 | 
seed: 1 17 | device: cuda 18 | # experiment 19 | experiment: train_r3m 20 | # agent 21 | lr: 1e-4 22 | # data 23 | alpha: 0.2 24 | dataset: "ego4d" 25 | wandbproject: 26 | wandbuser: 27 | doaug: "none" 28 | datapath: 29 | 30 | agent: 31 | _target_: r3m.R3M 32 | device: ${device} 33 | lr: ${lr} 34 | hidden_dim: 1024 35 | size: 34 36 | l2weight: 0.00001 37 | l1weight: 0.00001 38 | tcnweight: 1.0 39 | langweight: 0.0 40 | l2dist: true 41 | bs: ${batch_size} 42 | -------------------------------------------------------------------------------- /r3m/cfgs/hydra/launcher/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | cpus_per_task: 20 5 | gpus_per_node: 0 6 | tasks_per_node: 1 7 | timeout_min: 600 8 | mem_gb: 64 9 | name: ${hydra.job.name} 10 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher 11 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 12 | -------------------------------------------------------------------------------- /r3m/cfgs/hydra/output/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 5 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} 6 | sweep: 7 | dir: ./r3moutput/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 8 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} -------------------------------------------------------------------------------- /r3m/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import omegaconf 6 | import hydra 7 | import torch 8 | import torchvision.transforms as T 9 | import numpy as np 10 | from PIL import Image 11 | 12 | from r3m import load_r3m 13 | 14 | if torch.cuda.is_available(): 15 | device = "cuda" 16 | else: 17 | device = "cpu" 18 | 19 | r3m = load_r3m("resnet50") # resnet18, resnet34 20 | r3m.eval() 21 | r3m.to(device) 22 | 23 | ## DEFINE PREPROCESSING 24 | transforms = T.Compose([T.Resize(256), 25 | T.CenterCrop(224), 26 | T.ToTensor()]) # ToTensor() divides by 255 27 | 28 | ## ENCODE IMAGE 29 | image = np.random.randint(0, 255, (500, 500, 3)) 30 | preprocessed_image = transforms(Image.fromarray(image.astype(np.uint8))).reshape(-1, 3, 224, 224) 31 | preprocessed_image = preprocessed_image.to(device) # Tensor.to returns a new tensor; the result must be reassigned 32 | with torch.no_grad(): 33 | embedding = r3m(preprocessed_image * 255.0) ## R3M expects image input to be [0-255] 34 | print(embedding.shape) # [1, 2048] -------------------------------------------------------------------------------- /r3m/models/models_language.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import numpy as np 6 | from numpy.core.numeric import full 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.activation import Sigmoid 10 | 11 | epsilon = 1e-8 12 | 13 | class LangEncoder(nn.Module): 14 | def __init__(self, device, finetune = False, scratch=False): 15 | super().__init__() 16 | from transformers import AutoTokenizer, AutoModel, AutoConfig 17 | self.device = device 18 | self.modelname = "distilbert-base-uncased" 19 | self.tokenizer = AutoTokenizer.from_pretrained(self.modelname) 20 | self.model = AutoModel.from_pretrained(self.modelname).to(self.device) 21 | self.lang_size = 768 22 | 23 | def forward(self, langs): 24 | try: 25 | langs = langs.tolist() 26 | except: 27 | pass 28 | 29 | with torch.no_grad(): 30 | encoded_input = self.tokenizer(langs, return_tensors='pt', padding=True) 31 | input_ids = encoded_input['input_ids'].to(self.device) 32 | attention_mask = encoded_input['attention_mask'].to(self.device) 33 | lang_embedding = self.model(input_ids, attention_mask=attention_mask).last_hidden_state 34 | lang_embedding = lang_embedding.mean(1) 35 | return lang_embedding 36 | 37 | class LanguageReward(nn.Module): 38 | def __init__(self, ltype, im_dim, hidden_dim, lang_dim, simfunc=None): 39 | super().__init__() 40 | self.ltype = ltype 41 | self.sim = simfunc 42 | self.sigm = Sigmoid() 43 | self.pred = nn.Sequential(nn.Linear(im_dim * 2 + lang_dim, hidden_dim), 44 | nn.ReLU(inplace=True), 45 | nn.Linear(hidden_dim, hidden_dim), 46 | nn.ReLU(inplace=True), 47 | nn.Linear(hidden_dim, hidden_dim), 48 | nn.ReLU(inplace=True), 49 | nn.Linear(hidden_dim, hidden_dim), 50 | nn.ReLU(inplace=True), 51 | nn.Linear(hidden_dim, 1)) 52 | 53 | def forward(self, e0, eg, le): 54 | info = {} 55 | return self.pred(torch.cat([e0, eg, le], -1)).squeeze(), info -------------------------------------------------------------------------------- /r3m/models/models_r3m.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import numpy as np 6 | from numpy.core.numeric import full 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.activation import Sigmoid 10 | from torch.nn.modules.linear import Identity 11 | import torchvision 12 | from torchvision import transforms 13 | from r3m import utils 14 | from pathlib import Path 15 | from torchvision.utils import save_image 16 | import torchvision.transforms as T 17 | 18 | epsilon = 1e-8 19 | def do_nothing(x): return x 20 | 21 | class R3M(nn.Module): 22 | def __init__(self, device, lr, hidden_dim, size=34, l2weight=1.0, l1weight=1.0, 23 | langweight=1.0, tcnweight=0.0, l2dist=True, bs=16): 24 | super().__init__() 25 | 26 | self.device = device 27 | self.use_tb = False 28 | self.l2weight = l2weight 29 | self.l1weight = l1weight 30 | self.tcnweight = tcnweight ## Weight on TCN loss (states closer in same clip closer in embedding) 31 | self.l2dist = l2dist ## Use -l2 or cosine sim 32 | self.langweight = langweight ## Weight on language reward 33 | self.size = size ## Size ResNet or ViT 34 | self.num_negatives = 3 35 | 36 | ## Distances and Metrics 37 | self.cs = torch.nn.CosineSimilarity(1) 38 | self.bce = nn.BCELoss(reduce=False) 39 | self.sigm = Sigmoid() 40 | 41 | params = [] 42 | ######################################################################## Sub Modules 43 | ## Visual Encoder 44 | if size == 18: 45 | self.outdim = 512 46 | self.convnet = torchvision.models.resnet18(pretrained=False) 47 | elif size == 34: 48 | self.outdim = 512 49 | self.convnet = torchvision.models.resnet34(pretrained=False) 50 | elif size == 50: 51 | self.outdim = 2048 52 | self.convnet = torchvision.models.resnet50(pretrained=False) 53 | elif size == 0: 54 | from transformers import AutoConfig, AutoModel 55 | self.outdim = 768 56 | 
self.convnet = AutoModel.from_config(config = AutoConfig.from_pretrained('google/vit-base-patch32-224-in21k')).to(self.device) 57 | 58 | if self.size == 0: 59 | self.normlayer = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 60 | else: 61 | self.normlayer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 62 | self.convnet.fc = Identity() 63 | self.convnet.train() 64 | params += list(self.convnet.parameters()) 65 | 66 | ## Language Reward 67 | if self.langweight > 0.0: 68 | ## Pretrained DistilBERT Sentence Encoder 69 | from r3m.models.models_language import LangEncoder, LanguageReward 70 | self.lang_enc = LangEncoder(self.device, 0, 0) 71 | self.lang_rew = LanguageReward(None, self.outdim, hidden_dim, self.lang_enc.lang_size, simfunc=self.sim) 72 | params += list(self.lang_rew.parameters()) 73 | ######################################################################## 74 | 75 | ## Optimizer 76 | self.encoder_opt = torch.optim.Adam(params, lr = lr) 77 | 78 | def get_reward(self, e0, es, sentences): 79 | ## Only callable if langweight was set > 0 80 | le = self.lang_enc(sentences) 81 | return self.lang_rew(e0, es, le) 82 | 83 | ## Forward Call (im --> representation) 84 | def forward(self, obs, num_ims = 1, obs_shape = [3, 224, 224]): 85 | if obs_shape != [3, 224, 224]: 86 | preprocess = nn.Sequential( 87 | transforms.Resize(256), 88 | transforms.CenterCrop(224), 89 | self.normlayer, 90 | ) 91 | else: 92 | preprocess = nn.Sequential( 93 | self.normlayer, 94 | ) 95 | 96 | ## Input must be [0, 255], [3,224,224] 97 | obs = obs.float() / 255.0 98 | obs_p = preprocess(obs) 99 | h = self.convnet(obs_p) 100 | return h 101 | 102 | def sim(self, tensor1, tensor2): 103 | if self.l2dist: 104 | d = - torch.linalg.norm(tensor1 - tensor2, dim = -1) 105 | else: 106 | d = self.cs(tensor1, tensor2) 107 | return d 108 | -------------------------------------------------------------------------------- /r3m/r3m_base.yaml: 
-------------------------------------------------------------------------------- 1 | name: r3m_base 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_gnu 8 | - bzip2=1.0.8=h7f98852_4 9 | - ca-certificates=2021.10.8=ha878542_0 10 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 11 | - libffi=3.4.2=h7f98852_5 12 | - libgcc-ng=11.2.0=h1d223b6_13 13 | - libgomp=11.2.0=h1d223b6_13 14 | - libnsl=2.0.0=h7f98852_0 15 | - libuuid=2.32.1=h7f98852_1000 16 | - libzlib=1.2.11=h36c2ea0_1013 17 | - ncurses=6.3=h9c3ff4c_0 18 | - openssl=3.0.0=h7f98852_2 19 | - pip=22.0.4=pyhd8ed1ab_0 20 | - python=3.9.10=hc74c709_2_cpython 21 | - python_abi=3.9=2_cp39 22 | - readline=8.1=h46c0cb4_0 23 | - setuptools=60.9.3=py39hf3d152e_0 24 | - sqlite=3.37.0=h9cd32fc_0 25 | - tk=8.6.12=h27826a3_0 26 | - tzdata=2021e=he74cb21_0 27 | - wheel=0.37.1=pyhd8ed1ab_0 28 | - xz=5.2.5=h516909a_1 29 | - zlib=1.2.11=h36c2ea0_1013 30 | - pip: 31 | - antlr4-python3-runtime==4.8 32 | - beautifulsoup4==4.10.0 33 | - certifi==2021.10.8 34 | - charset-normalizer==2.0.12 35 | - click==8.0.4 36 | - cycler==0.11.0 37 | - filelock==3.6.0 38 | - fonttools==4.30.0 39 | - gdown==4.4.0 40 | - huggingface-hub==0.4.0 41 | - hydra-core==1.1.1 42 | - idna==3.3 43 | - joblib==1.1.0 44 | - kiwisolver==1.3.2 45 | - matplotlib==3.5.1 46 | - numpy==1.22.3 47 | - omegaconf==2.1.1 48 | - packaging==21.3 49 | - pillow==9.0.1 50 | - pyparsing==3.0.7 51 | - pysocks==1.7.1 52 | - python-dateutil==2.8.2 53 | - pyyaml==6.0 54 | - regex==2022.3.2 55 | - requests==2.27.1 56 | - sacremoses==0.0.47 57 | - six==1.16.0 58 | - soupsieve==2.3.1 59 | - tokenizers==0.11.6 60 | - torch==1.7.1 61 | - torchvision==0.8.2 62 | - tqdm==4.63.0 63 | - transformers==4.17.0 64 | - typing-extensions==4.1.1 65 | - urllib3==1.26.8 66 | prefix: /private/home/surajn/.conda/envs/r3m_base 67 | -------------------------------------------------------------------------------- 
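Stepping back from the file listing for a moment: the recursive dict-stacking helpers in evaluation/r3meval/utils/tensor_utils.py (above) can be sanity-checked with a small self-contained sketch. The two helpers are re-declared here (copied from that file) so the snippet runs on its own; the example `paths` data is illustrative, not from the repo.

```python
import numpy as np

def stack_tensor_list(tensor_list):
    # Leaves are stacked along a new leading axis.
    return np.array(tensor_list)

def stack_tensor_dict_list(tensor_dict_list):
    # Same recursion as in tensor_utils.py: nested dicts are handled
    # key-by-key, everything else is stacked into an array.
    keys = list(tensor_dict_list[0].keys())
    ret = dict()
    for k in keys:
        example = tensor_dict_list[0][k]
        if isinstance(example, dict):
            ret[k] = stack_tensor_dict_list([x[k] for x in tensor_dict_list])
        else:
            ret[k] = stack_tensor_list([x[k] for x in tensor_dict_list])
    return ret

# Two hypothetical rollout dicts with a nested sub-dict.
paths = [
    {"obs": np.zeros(3), "info": {"t": 0}},
    {"obs": np.ones(3), "info": {"t": 1}},
]
stacked = stack_tensor_dict_list(paths)
print(stacked["obs"].shape)   # (2, 3)
print(stacked["info"]["t"])   # [0 1]
```

The same recursion pattern underlies `concat_tensor_dict_list` and the subsampling variants, with `np.concatenate` in place of stacking.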
/r3m/train_representation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import warnings 6 | 7 | import torchvision 8 | warnings.filterwarnings('ignore', category=DeprecationWarning) 9 | 10 | import os 11 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 12 | os.environ['MUJOCO_GL'] = 'egl' 13 | 14 | from pathlib import Path 15 | import hydra 16 | import numpy as np 17 | import torch 18 | from r3m.utils import utils 19 | from r3m.trainer import Trainer 20 | from r3m.utils.data_loaders import R3MBuffer 21 | from r3m.utils.logger import Logger 22 | import time 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | 27 | def make_network(cfg): 28 | model = hydra.utils.instantiate(cfg) 29 | print("Let's use", torch.cuda.device_count(), "GPUs!") 30 | model = torch.nn.DataParallel(model) 31 | return model.cuda() 32 | 33 | class Workspace: 34 | def __init__(self, cfg): 35 | self.work_dir = Path.cwd() 36 | print(f'workspace: {self.work_dir}') 37 | 38 | self.cfg = cfg 39 | utils.set_seed_everywhere(cfg.seed) 40 | self.device = torch.device(cfg.device) 41 | self.setup() 42 | 43 | print("Creating Dataloader") 44 | if self.cfg.dataset == "ego4d": 45 | sources = ["ego4d"] 46 | else: 47 | raise NameError('Invalid Dataset') 48 | 49 | train_iterable = R3MBuffer(self.cfg.datapath, self.cfg.num_workers, "train", "train", 50 | alpha = self.cfg.alpha, datasources=sources, doaug = self.cfg.doaug) 51 | val_iterable = R3MBuffer(self.cfg.datapath, self.cfg.num_workers, "val", "validation", 52 | alpha = 0, datasources=sources, doaug = 0) 53 | 54 | self.train_loader = iter(torch.utils.data.DataLoader(train_iterable, 55 | batch_size=self.cfg.batch_size, 56 | num_workers=self.cfg.num_workers, 57 | pin_memory=True)) 58 | self.val_loader = 
iter(torch.utils.data.DataLoader(val_iterable, 59 | batch_size=self.cfg.batch_size, 60 | num_workers=self.cfg.num_workers, 61 | pin_memory=True)) 62 | 63 | 64 | ## Init Model 65 | print("Initializing Model") 66 | self.model = make_network(cfg.agent) 67 | 68 | self.timer = utils.Timer() 69 | self._global_step = 0 70 | 71 | ## If reloading existing model 72 | if cfg.load_snap: 73 | print("LOADING", cfg.load_snap) 74 | self.load_snapshot(cfg.load_snap) 75 | 76 | def setup(self): 77 | # create logger 78 | self.logger = Logger(self.work_dir, use_tb=False, cfg=self.cfg) 79 | 80 | @property 81 | def global_step(self): 82 | return self._global_step 83 | 84 | @property 85 | def global_frame(self): 86 | return self.global_step 87 | 88 | def train(self): 89 | # predicates 90 | train_until_step = utils.Until(self.cfg.train_steps, 91 | 1) 92 | eval_freq = self.cfg.eval_freq 93 | eval_every_step = utils.Every(eval_freq, 94 | 1) 95 | trainer = Trainer(eval_freq) 96 | 97 | ## Training Loop 98 | print("Begin Training") 99 | while train_until_step(self.global_step): 100 | ## Sample Batch 101 | t0 = time.time() 102 | batch_f, batch_langs = next(self.train_loader) 103 | t1 = time.time() 104 | metrics, st = trainer.update(self.model, (batch_f.cuda(), batch_langs), self.global_step) 105 | t2 = time.time() 106 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 107 | 108 | if self.global_step % 10 == 0: 109 | print(self.global_step, metrics) 110 | print(f'Sample time {t1-t0}, Update time {t2-t1}') 111 | print(st) 112 | 113 | if eval_every_step(self.global_step): 114 | with torch.no_grad(): 115 | batch_f, batch_langs = next(self.val_loader) 116 | metrics, st = trainer.update(self.model, (batch_f.cuda(), batch_langs), self.global_step, eval=True) 117 | self.logger.log_metrics(metrics, self.global_frame, ty='eval') 118 | print("EVAL", self.global_step, metrics) 119 | 120 | self.save_snapshot() 121 | self._global_step += 1 122 | 123 | def save_snapshot(self): 124 | snapshot = 
self.work_dir / f'snapshot_{self.global_step}.pt' 125 | global_snapshot = self.work_dir / f'snapshot.pt' 126 | sdict = {} 127 | sdict["r3m"] = self.model.state_dict() 128 | torch.save(sdict, snapshot) 129 | sdict["global_step"] = self._global_step 130 | torch.save(sdict, global_snapshot) 131 | 132 | def load_snapshot(self, snapshot_path): 133 | payload = torch.load(snapshot_path) 134 | self.model.load_state_dict(payload['r3m']) 135 | try: 136 | self._global_step = payload['global_step'] 137 | except KeyError: 138 | print("No global step found") 139 | 140 | @hydra.main(config_path='cfgs', config_name='config_rep') 141 | def main(cfg): 142 | from train_representation import Workspace as W 143 | root_dir = Path.cwd() 144 | workspace = W(cfg) 145 | 146 | snapshot = root_dir / 'snapshot.pt' 147 | if snapshot.exists(): 148 | print(f'resuming: {snapshot}') 149 | workspace.load_snapshot(snapshot) 150 | workspace.train() 151 | 152 | 153 | if __name__ == '__main__': 154 | main() -------------------------------------------------------------------------------- /r3m/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | import hydra 6 | import numpy as np 7 | from numpy.core.numeric import full 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import random 12 | from pathlib import Path 13 | from torchvision.utils import save_image 14 | import time 15 | import copy 16 | import torchvision.transforms as T 17 | 18 | epsilon = 1e-8 19 | def do_nothing(x): return x 20 | 21 | class Trainer(): 22 | def __init__(self, eval_freq): 23 | self.eval_freq = eval_freq 24 | 25 | def update(self, model, batch, step, eval=False): 26 | t0 = time.time() 27 | metrics = dict() 28 | if eval: 29 | model.eval() 30 | else: 31 | model.train() 32 | 33 | t1 = time.time() 34 | ## Batch 35 | b_im, b_lang = batch 36 | t2 = time.time() 37 | 38 | ## Encode Start and End Frames 39 | bs = b_im.shape[0] 40 | b_im_r = b_im.reshape(bs*5, 3, 224, 224) 41 | alles = model(b_im_r) 42 | alle = alles.reshape(bs, 5, -1) 43 | e0 = alle[:, 0] 44 | eg = alle[:, 1] 45 | es0 = alle[:, 2] 46 | es1 = alle[:, 3] 47 | es2 = alle[:, 4] 48 | 49 | full_loss = 0 50 | 51 | ## LP Loss 52 | l2loss = torch.linalg.norm(alles, ord=2, dim=-1).mean() 53 | l1loss = torch.linalg.norm(alles, ord=1, dim=-1).mean() 54 | l0loss = torch.linalg.norm(alles, ord=0, dim=-1).mean() 55 | metrics['l2loss'] = l2loss.item() 56 | metrics['l1loss'] = l1loss.item() 57 | metrics['l0loss'] = l0loss.item() 58 | full_loss += model.module.l2weight * l2loss 59 | full_loss += model.module.l1weight * l1loss 60 | 61 | 62 | t3 = time.time() 63 | ## Language Predictive Loss 64 | if model.module.langweight > 0: 65 | ## Number of negative examples to use for language 66 | num_neg = model.module.num_negatives 67 | 68 | ## Trains to have G(e_0, e_t, l) be higher than G(e_0, e_ 0: 123 | ## Number of negative video examples to use 124 | num_neg_v = model.module.num_negatives 125 | 126 | ## Computing distance from t0-t2, t1-t2, t1-t0 127 | sim_0_2 = model.module.sim(es2, es0) 128 | sim_1_2 = model.module.sim(es2, es1) 129 | sim_0_1 = 
model.module.sim(es1, es0) 130 | 131 | ## For the specified number of negatives from other videos 132 | ## Add it as a negative 133 | neg2 = [] 134 | neg0 = [] 135 | for _ in range(num_neg_v): 136 | es0_shuf = es0[torch.randperm(es0.size()[0])] 137 | es2_shuf = es2[torch.randperm(es2.size()[0])] 138 | neg0.append(model.module.sim(es0, es0_shuf)) 139 | neg2.append(model.module.sim(es2, es2_shuf)) 140 | neg0 = torch.stack(neg0, -1) 141 | neg2 = torch.stack(neg2, -1) 142 | 143 | ## TCN Loss 144 | smoothloss1 = -torch.log(epsilon + (torch.exp(sim_1_2) / (epsilon + torch.exp(sim_0_2) + torch.exp(sim_1_2) + torch.exp(neg2).sum(-1)))) 145 | smoothloss2 = -torch.log(epsilon + (torch.exp(sim_0_1) / (epsilon + torch.exp(sim_0_1) + torch.exp(sim_0_2) + torch.exp(neg0).sum(-1)))) 146 | smoothloss = ((smoothloss1 + smoothloss2) / 2.0).mean() 147 | a_state = ((1.0 * (sim_0_2 < sim_1_2)) * (1.0 * (sim_0_1 > sim_0_2))).mean() 148 | metrics['tcnloss'] = smoothloss.item() 149 | metrics['aligned'] = a_state.item() 150 | full_loss += model.module.tcnweight * smoothloss 151 | 152 | metrics['full_loss'] = full_loss.item() 153 | 154 | t6 = time.time() 155 | if not eval: 156 | model.module.encoder_opt.zero_grad() 157 | full_loss.backward() 158 | model.module.encoder_opt.step() 159 | 160 | t7 = time.time() 161 | st = f"Load time {t1-t0}, Batch time {t2-t1}, Encode and LP time {t3-t2}, Lang time {t5-t3}, TCN time {t6-t5}, Backprop time {t7-t6}" 162 | return metrics, st -------------------------------------------------------------------------------- /r3m/utils/data_loaders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | import warnings 6 | 7 | import torchvision 8 | warnings.filterwarnings('ignore', category=DeprecationWarning) 9 | 10 | import os 11 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 12 | os.environ['MUJOCO_GL'] = 'egl' 13 | 14 | from pathlib import Path 15 | 16 | import hydra 17 | import numpy as np 18 | import torch 19 | from torchvision import transforms 20 | from torch.utils.data import IterableDataset 21 | import pandas as pd 22 | import json 23 | import time 24 | import pickle 25 | from torchvision.utils import save_image 26 | import random 27 | 28 | 29 | 30 | def get_ind(vid, index, ds): 31 | if ds == "ego4d": 32 | return torchvision.io.read_image(f"{vid}/{index:06}.jpg") 33 | else: 34 | raise NameError('Invalid Dataset') 35 | 36 | 37 | ## Data Loader for Ego4D 38 | class R3MBuffer(IterableDataset): 39 | def __init__(self, ego4dpath, num_workers, source1, source2, alpha, datasources, doaug = "none"): 40 | self._num_workers = max(1, num_workers) 41 | self.alpha = alpha 42 | self.curr_same = 0 43 | self.data_sources = datasources 44 | self.doaug = doaug 45 | 46 | # Augmentations 47 | if doaug in ["rc", "rctraj"]: 48 | self.aug = torch.nn.Sequential( 49 | transforms.RandomResizedCrop(224, scale = (0.2, 1.0)), 50 | ) 51 | else: 52 | self.aug = lambda a : a 53 | 54 | # Load Data 55 | if "ego4d" in self.data_sources: 56 | print("Ego4D") 57 | self.manifest = pd.read_csv(f"{ego4dpath}manifest.csv") 58 | print(self.manifest) 59 | self.ego4dlen = len(self.manifest) 60 | else: 61 | raise NameError('Invalid Dataset') 62 | 63 | 64 | def _sample(self): 65 | t0 = time.time() 66 | ds = random.choice(self.data_sources) 67 | 68 | vidid = np.random.randint(0, self.ego4dlen) 69 | m = self.manifest.iloc[vidid] 70 | vidlen = m["len"] 71 | txt = m["txt"] 72 | label = txt[2:] ## Cuts off the "C " part of the text 73 | vid = m["path"] 74 | 75 | start_ind = np.random.randint(1, 2 + int(self.alpha * vidlen)) 76 | end_ind = np.random.randint(int((1-self.alpha) * vidlen)-1, 
vidlen) 77 | s1_ind = np.random.randint(2, vidlen) 78 | s0_ind = np.random.randint(1, s1_ind) 79 | s2_ind = np.random.randint(s1_ind, vidlen+1) 80 | 81 | if self.doaug == "rctraj": 82 | ### Augment each image in the video the same way 83 | im0 = get_ind(vid, start_ind, ds) 84 | img = get_ind(vid, end_ind, ds) 85 | imts0 = get_ind(vid, s0_ind, ds) 86 | imts1 = get_ind(vid, s1_ind, ds) 87 | imts2 = get_ind(vid, s2_ind, ds) 88 | allims = torch.stack([im0, img, imts0, imts1, imts2], 0) 89 | allims_aug = self.aug(allims / 255.0) * 255.0 90 | 91 | im0 = allims_aug[0] 92 | img = allims_aug[1] 93 | imts0 = allims_aug[2] 94 | imts1 = allims_aug[3] 95 | imts2 = allims_aug[4] 96 | else: 97 | ### Augment each image individually 98 | im0 = self.aug(get_ind(vid, start_ind, ds) / 255.0) * 255.0 99 | img = self.aug(get_ind(vid, end_ind, ds) / 255.0) * 255.0 100 | imts0 = self.aug(get_ind(vid, s0_ind, ds) / 255.0) * 255.0 101 | imts1 = self.aug(get_ind(vid, s1_ind, ds) / 255.0) * 255.0 102 | imts2 = self.aug(get_ind(vid, s2_ind, ds) / 255.0) * 255.0 103 | 104 | im = torch.stack([im0, img, imts0, imts1, imts2]) 105 | return (im, label) 106 | 107 | def __iter__(self): 108 | while True: 109 | yield self._sample() -------------------------------------------------------------------------------- /r3m/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | import csv 6 | import datetime 7 | from collections import defaultdict 8 | from termcolor import colored 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | import wandb 13 | from torch.utils.tensorboard import SummaryWriter 14 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 15 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 16 | ('episode_reward', 'R', 'float'), 17 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'), 18 | ('total_time', 'T', 'time')] 19 | 20 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 21 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 22 | ('episode_reward', 'R', 'float'), 23 | ('total_time', 'T', 'time')] 24 | 25 | 26 | class AverageMeter(object): 27 | def __init__(self): 28 | self._sum = 0 29 | self._count = 0 30 | 31 | def update(self, value, n=1): 32 | self._sum += value 33 | self._count += n 34 | 35 | def value(self): 36 | return self._sum / max(1, self._count) 37 | 38 | 39 | class MetersGroup(object): 40 | def __init__(self, csv_file_name, formating): 41 | self._csv_file_name = csv_file_name 42 | self._formating = formating 43 | self._meters = defaultdict(AverageMeter) 44 | self._csv_file = None 45 | self._csv_writer = None 46 | 47 | def log(self, key, value, n=1): 48 | self._meters[key].update(value, n) 49 | 50 | def _prime_meters(self): 51 | data = dict() 52 | for key, meter in self._meters.items(): 53 | if key.startswith('train'): 54 | key = key[len('train') + 1:] 55 | else: 56 | key = key[len('eval') + 1:] 57 | key = key.replace('/', '_') 58 | data[key] = meter.value() 59 | return data 60 | 61 | def _remove_old_entries(self, data): 62 | rows = [] 63 | with self._csv_file_name.open('r') as f: 64 | reader = csv.DictReader(f) 65 | for row in reader: 66 | if float(row['episode']) >= data['episode']: 67 | break 68 | rows.append(row) 69 | with self._csv_file_name.open('w') as f: 70 | writer = csv.DictWriter(f, 71 | fieldnames=sorted(data.keys()), 72 | restval=0.0) 73 | writer.writeheader() 74 | for row in rows: 75 | 
writer.writerow(row) 76 | 77 | def _dump_to_csv(self, data): 78 | if self._csv_writer is None: 79 | should_write_header = True 80 | if self._csv_file_name.exists(): 81 | self._remove_old_entries(data) 82 | should_write_header = False 83 | 84 | self._csv_file = self._csv_file_name.open('a') 85 | self._csv_writer = csv.DictWriter(self._csv_file, 86 | fieldnames=sorted(data.keys()), 87 | restval=0.0) 88 | if should_write_header: 89 | self._csv_writer.writeheader() 90 | 91 | self._csv_writer.writerow(data) 92 | self._csv_file.flush() 93 | 94 | def _format(self, key, value, ty): 95 | if ty == 'int': 96 | value = int(value) 97 | return f'{key}: {value}' 98 | elif ty == 'float': 99 | return f'{key}: {value:.04f}' 100 | elif ty == 'time': 101 | value = str(datetime.timedelta(seconds=int(value))) 102 | return f'{key}: {value}' 103 | else: 104 | raise ValueError(f'invalid format type: {ty}') 105 | 106 | def _dump_to_console(self, data, prefix): 107 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 108 | pieces = [f'| {prefix: <14}'] 109 | for key, disp_key, ty in self._formating: 110 | value = data.get(key, 0) 111 | pieces.append(self._format(disp_key, value, ty)) 112 | print(' | '.join(pieces)) 113 | 114 | def dump(self, step, prefix): 115 | if len(self._meters) == 0: 116 | return 117 | data = self._prime_meters() 118 | data['frame'] = step 119 | self._dump_to_csv(data) 120 | self._dump_to_console(data, prefix) 121 | self._meters.clear() 122 | 123 | 124 | class Logger(object): 125 | def __init__(self, log_dir, use_tb, cfg=None): 126 | self._log_dir = log_dir 127 | self.use_tb = use_tb 128 | self._train_mg = MetersGroup(log_dir / 'train.csv', 129 | formating=COMMON_TRAIN_FORMAT) 130 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 131 | formating=COMMON_EVAL_FORMAT) 132 | if use_tb: 133 | self._sw = SummaryWriter(str(log_dir / 'tb')) 134 | else: 135 | print(cfg.wandbuser) 136 | wandb.init(project=cfg.wandbproject, entity=cfg.wandbuser, name=cfg.experiment) 137 | 
fullcfg = {**cfg, **cfg.agent} 138 | wandb.config.update(fullcfg) 139 | 140 | def _try_sw_log(self, key, value, step): 141 | if self.use_tb: 142 | self._sw.add_scalar(key, value, step) 143 | else: 144 | logs = {} 145 | logs[key] = value 146 | wandb.log(logs, step = step) 147 | 148 | def log(self, key, value, step): 149 | assert key.startswith('train') or key.startswith('eval') 150 | if type(value) == torch.Tensor: 151 | value = value.item() 152 | self._try_sw_log(key, value, step) 153 | mg = self._train_mg if key.startswith('train') else self._eval_mg 154 | mg.log(key, value) 155 | 156 | def log_metrics(self, metrics, step, ty): 157 | for key, value in metrics.items(): 158 | self.log(f'{ty}/{key}', value, step) 159 | 160 | def dump(self, step, ty=None): 161 | if ty is None or ty == 'eval': 162 | self._eval_mg.dump(step, 'eval') 163 | if ty is None or ty == 'train': 164 | self._train_mg.dump(step, 'train') 165 | 166 | def log_and_dump_ctx(self, step, ty): 167 | return LogAndDumpCtx(self, step, ty) 168 | 169 | 170 | class LogAndDumpCtx: 171 | def __init__(self, logger, step, ty): 172 | self._logger = logger 173 | self._step = step 174 | self._ty = ty 175 | 176 | def __enter__(self): 177 | return self 178 | 179 | def __call__(self, key, value): 180 | self._logger.log(f'{self._ty}/{key}', value, self._step) 181 | 182 | def __exit__(self, *args): 183 | self._logger.dump(self._step, self._ty) 184 | -------------------------------------------------------------------------------- /r3m/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import random 6 | import re 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from omegaconf import OmegaConf 14 | from torch import distributions as pyd 15 | from torch.distributions.utils import _standard_normal 16 | 17 | 18 | class eval_mode: 19 | def __init__(self, *models): 20 | self.models = models 21 | 22 | def __enter__(self): 23 | self.prev_states = [] 24 | for model in self.models: 25 | self.prev_states.append(model.training) 26 | model.train(False) 27 | 28 | def __exit__(self, *args): 29 | for model, state in zip(self.models, self.prev_states): 30 | model.train(state) 31 | return False 32 | 33 | 34 | def set_seed_everywhere(seed): 35 | torch.manual_seed(seed) 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed_all(seed) 38 | np.random.seed(seed) 39 | random.seed(seed) 40 | 41 | 42 | def soft_update_params(net, target_net, tau): 43 | for param, target_param in zip(net.parameters(), target_net.parameters()): 44 | target_param.data.copy_(tau * param.data + 45 | (1 - tau) * target_param.data) 46 | 47 | 48 | def to_torch(xs, device): 49 | return tuple(torch.as_tensor(x, device=device) for x in xs) 50 | 51 | 52 | def weight_init(m): 53 | if isinstance(m, nn.Linear): 54 | nn.init.orthogonal_(m.weight.data) 55 | if hasattr(m.bias, 'data'): 56 | m.bias.data.fill_(0.0) 57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 58 | gain = nn.init.calculate_gain('relu') 59 | nn.init.orthogonal_(m.weight.data, gain) 60 | if hasattr(m.bias, 'data'): 61 | m.bias.data.fill_(0.0) 62 | 63 | def accuracy(output, target, topk=(1,)): 64 | """Computes the precision@k for the specified values of k""" 65 | maxk = max(topk) 66 | batch_size = target.size(0) 67 | 68 | _, pred = output.topk(maxk, 1, True, True) 69 | pred = pred.t() 70 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 71 | 72 | res = [] 73 | for k in topk: 74 | correct_k = 
correct[:k].reshape(-1).float().sum(0) 75 | res.append(correct_k.mul_(1.0 / batch_size)) 76 | return res 77 | 78 | class Until: 79 | def __init__(self, until, action_repeat=1): 80 | self._until = until 81 | self._action_repeat = action_repeat 82 | 83 | def __call__(self, step): 84 | if self._until is None: 85 | return True 86 | until = self._until // self._action_repeat 87 | return step < until 88 | 89 | 90 | class Every: 91 | def __init__(self, every, action_repeat=1): 92 | self._every = every 93 | self._action_repeat = action_repeat 94 | 95 | def __call__(self, step): 96 | if self._every is None: 97 | return False 98 | every = self._every // self._action_repeat 99 | if step % every == 0: 100 | return True 101 | return False 102 | 103 | 104 | class Timer: 105 | def __init__(self): 106 | self._start_time = time.time() 107 | self._last_time = time.time() 108 | 109 | def reset(self): 110 | elapsed_time = time.time() - self._last_time 111 | self._last_time = time.time() 112 | total_time = time.time() - self._start_time 113 | return elapsed_time, total_time 114 | 115 | def total_time(self): 116 | return time.time() - self._start_time 117 | 118 | 119 | class TruncatedNormal(pyd.Normal): 120 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 121 | super().__init__(loc, scale, validate_args=False) 122 | self.low = low 123 | self.high = high 124 | self.eps = eps 125 | 126 | def _clamp(self, x): 127 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 128 | x = x - x.detach() + clamped_x.detach() 129 | return x 130 | 131 | def sample(self, clip=None, sample_shape=torch.Size()): 132 | shape = self._extended_shape(sample_shape) 133 | eps = _standard_normal(shape, 134 | dtype=self.loc.dtype, 135 | device=self.loc.device) 136 | eps *= self.scale 137 | if clip is not None: 138 | eps = torch.clamp(eps, -clip, clip) 139 | x = self.loc + eps 140 | return self._clamp(x) 141 | 142 | 143 | def schedule(schdl, step): 144 | try: 145 | return float(schdl) 
146 | except ValueError: 147 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 148 | if match: 149 | init, final, duration = [float(g) for g in match.groups()] 150 | mix = np.clip(step / duration, 0.0, 1.0) 151 | return (1.0 - mix) * init + mix * final 152 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 153 | if match: 154 | init, final1, duration1, final2, duration2 = [ 155 | float(g) for g in match.groups() 156 | ] 157 | if step <= duration1: 158 | mix = np.clip(step / duration1, 0.0, 1.0) 159 | return (1.0 - mix) * init + mix * final1 160 | else: 161 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 162 | return (1.0 - mix) * final1 + mix * final2 163 | raise NotImplementedError(schdl) 164 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import os 6 | import sys 7 | from setuptools import setup, find_packages 8 | 9 | if sys.version_info.major != 3: 10 | print("This Python is only compatible with Python 3, but you are running " 11 | "Python {}. 
The installation will likely fail.".format(sys.version_info.major)) 12 | 13 | def read(fname): 14 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 15 | 16 | setup( 17 | name='r3m', 18 | version='0.0.0', 19 | packages=find_packages(), 20 | description='R3M: Pretrained Reusable Representations for Robot Manipulation from Diverse Human Videos', 21 | long_description=read('README.md'), 22 | author='Suraj Nair (Meta AI)', 23 | install_requires=[ 24 | 'gdown==4.4.0', 25 | 'torch<=1.10.2,>=1.7.1', 26 | 'torchvision<=0.11.3,>=0.8.2', 27 | 'omegaconf==2.1.1', 28 | 'hydra-core==1.1.1', 29 | 'pillow==9.0.1', 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE 3 | from PIL import Image 4 | import torch 5 | 6 | 7 | from r3meval.utils.gym_env import GymEnv 8 | from r3meval.utils.obs_wrappers import MuJoCoPixelObs, StateEmbedding 9 | from r3meval.utils.sampling import sample_paths 10 | from r3meval.utils.gaussian_mlp import MLP 11 | from r3meval.utils.behavior_cloning import BC 12 | from r3meval.utils.visualizations import place_attention_heatmap_over_images 13 | from tabulate import tabulate 14 | from tqdm import tqdm 15 | import mj_envs, gym 16 | import numpy as np, time as timer, multiprocessing, pickle, os 17 | 18 | 19 | 20 | import mvp 21 | 22 | 23 | 24 | 25 | 26 | def visualize_shifts_metaworld(): 27 | env_to_model_name = { 28 | 'assembly-v2-goal-observable':'sawyer_assembly_peg', 29 | 'bin-picking-v2-goal-observable':'sawyer_bin_picking', 30 | 
'button-press-topdown-v2-goal-observable':'sawyer_button_press_topdown', 31 | 'drawer-open-v2-goal-observable':'sawyer_drawer', 32 | 'hammer-v2-goal-observable':'sawyer_hammer', 33 | } 34 | model_path = f'/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/' 35 | # model_path = '/iris/u/kayburns/packages/metaworld_r3m/metaworld/envs/assets_v2/sawyer_xyz/sawyer_assembly_peg_blue-woodtable.xml' 36 | # for shift in ['_distractor_medium', '_granite_table', '_metal1_table', '_cast_right', '_cast_left', '_darker', '_brighter', '']: 37 | for camera_name in ['top_cap2']:#, 'right_cap2', 'left_cap2']: 38 | for env_name in ['assembly-v2-goal-observable', 'bin-picking-v2-goal-observable', 'button-press-topdown-v2-goal-observable', 'drawer-open-v2-goal-observable', 'hammer-v2-goal-observable']: 39 | for shift in ['_distractor_easy', '_distractor_medium', '_distractor_hard', '_blue-woodtable', '_dark-woodtable', '_darkwoodtable', '_cast_right', '_cast_left', '_darker', '_brighter', '']: 40 | model_name = model_path+env_to_model_name[env_name]+shift+'.xml' 41 | # model_name=model_path 42 | e = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](model_name=model_name) 43 | e._freeze_rand_vec = False 44 | e.spec = namedtuple('spec', ['id', 'max_episode_steps']) 45 | e.spec.id = env_name 46 | e.spec.max_episode_steps = 500 47 | e = MuJoCoPixelObs(e, camera_name=camera_name, width=256, height=256) 48 | e = StateEmbedding(e, embedding_name='resnet50', load_path='r3m', 49 | proprio=9, camera_name=camera_name, env_name=env_name) 50 | im = e.render() 51 | image = Image.fromarray(im) 52 | env_name_prefix = env_name.split('-')[0] 53 | image.save(f'photos_of_envs/test_{env_name_prefix}{shift}_{camera_name}.jpg') 54 | 55 | shifts = [ 56 | 'slide_metal2', 'darker', 'distractor_hard' 57 | # 'distractor_cracker_box', \ 58 | # 'distractor_medium', \ 59 | # 'distractor_hard', \ 60 | # 'cast_left', 'cast_right', 'brighter', 'darker' \ 61 | ] 62 | 63 | def 
visualize_heatmap(model='dino', visualize_shift=False): 64 | camera_name = 'left_cap2' 65 | if model == 'dino': 66 | embedding_model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16').cuda() 67 | else: 68 | embedding_model = mvp.load("vitb-mae-egosoup").cuda() 69 | 70 | env_name = 'kitchen_knob1_on-v3' 71 | for shift in shifts: 72 | e = gym.make(env_name, model_path = f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_{shift}.xml') 73 | e = MuJoCoPixelObs(e, camera_name=camera_name, width=256, height=256) 74 | e = StateEmbedding(e, embedding_name='resnet50', load_path='r3m', 75 | proprio=9, camera_name=camera_name, env_name=env_name) 76 | e.reset() 77 | im = e.env.get_image() 78 | image = Image.fromarray(im.astype('uint8'), mode='RGB') 79 | image.save(f'heads_with distractors/train_env_{shift}.png') 80 | 81 | 82 | for head in range(6): 83 | e = gym.make(env_name) 84 | e = MuJoCoPixelObs(e, camera_name=camera_name, width=256, height=256) 85 | e = StateEmbedding(e, embedding_name='resnet50', load_path='r3m', 86 | proprio=9, camera_name=camera_name, env_name=env_name) 87 | e.reset() 88 | im = e.env.get_image() 89 | attention_vis = place_attention_heatmap_over_images([im], embedding_model, model, head=head) 90 | image = Image.fromarray(attention_vis[0].astype('uint8'), mode='RGB') 91 | image.save(f'heads_with distractors/{model}_head_{head}_{camera_name}.png') 92 | 93 | if visualize_shift: 94 | for shift in shifts: 95 | for head in range(6): 96 | e = gym.make(env_name, model_path = f'/iris/u/kayburns/packages/mj_envs/mj_envs/envs/relay_kitchen/assets/franka_kitchen_{shift}.xml') 97 | e = MuJoCoPixelObs(e, camera_name=camera_name, width=256, height=256) 98 | e = StateEmbedding(e, embedding_name='resnet50', load_path='r3m', 99 | proprio=9, camera_name=camera_name, env_name=env_name) 100 | e.reset() 101 | im = e.env.get_image() 102 | attention_vis = place_attention_heatmap_over_images([im], 
embedding_model, model, head=head) 103 | image = Image.fromarray(attention_vis[0].astype('uint8'), mode='RGB') 104 | image.save(f'heads_with distractors/{model}_{shift}_head_{head}.png') 105 | 106 | visualize_heatmap(model='dino', visualize_shift=True) --------------------------------------------------------------------------------
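The time-contrastive (TCN) objective computed in `r3m/trainer.py` is easier to check on scalars than on batched tensors. Below is a minimal, illustrative sketch of the same InfoNCE-style formula in plain Python; `tcn_loss` and its argument names are assumptions for this sketch, not identifiers from the repo:

```python
import math

def tcn_loss(sim_0_2, sim_1_2, sim_0_1, neg2, neg0, eps=1e-8):
    """Scalar sketch of the TCN loss in trainer.py.

    sim_a_b is the similarity between the embeddings of frames a and b,
    where frame indices increase with time; neg2/neg0 hold similarities
    to shuffled (cross-video) negatives. Temporally close pairs (1,2)
    and (0,1) should win the softmax over distant and negative pairs.
    """
    # -log softmax of sim(1,2) against the more distant sim(0,2) and negatives
    denom1 = eps + math.exp(sim_0_2) + math.exp(sim_1_2) + sum(map(math.exp, neg2))
    loss1 = -math.log(eps + math.exp(sim_1_2) / denom1)
    # -log softmax of sim(0,1) against sim(0,2) and negatives
    denom0 = eps + math.exp(sim_0_1) + math.exp(sim_0_2) + sum(map(math.exp, neg0))
    loss2 = -math.log(eps + math.exp(sim_0_1) / denom0)
    return (loss1 + loss2) / 2.0
```

A temporally well-ordered triple (high `sim_1_2` and `sim_0_1`, low `sim_0_2`) yields a lower loss than a misordered one; this ordering is exactly what the `aligned` metric in `trainer.py` tracks.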