├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── conda_env.yaml ├── franka-env ├── franka_env │ ├── __init__.py │ └── envs │ │ ├── __init__.py │ │ └── franka_env.py └── setup.py ├── instructions ├── code.md └── installation_and_data_collection.md ├── point_policy ├── agent │ ├── baku.py │ ├── mtpi.py │ ├── networks │ │ ├── __init__.py │ │ ├── dit.py │ │ ├── gpt.py │ │ ├── mlp.py │ │ ├── policy_head.py │ │ ├── rgb_modules.py │ │ └── utils │ │ │ └── diffusion_policy.py │ ├── p3po.py │ └── point_policy.py ├── cfgs │ ├── agent │ │ ├── baku.yaml │ │ ├── mtpi.yaml │ │ ├── p3po.yaml │ │ └── point_policy.yaml │ ├── config.yaml │ ├── config_eval.yaml │ ├── dataloader │ │ ├── baku.yaml │ │ ├── mtpi.yaml │ │ ├── p3po.yaml │ │ └── point_policy.yaml │ └── suite │ │ ├── baku.yaml │ │ ├── mtpi.yaml │ │ ├── p3po.yaml │ │ ├── point_policy.yaml │ │ ├── points_cfg.yaml │ │ └── task │ │ ├── franka_env.yaml │ │ └── franka_env │ │ ├── bottle_on_rack.yaml │ │ ├── bottle_upright.yaml │ │ ├── bowl_in_oven.yaml │ │ ├── bread_on_plate.yaml │ │ ├── close_oven.yaml │ │ ├── drawer_close.yaml │ │ ├── fold_towel.yaml │ │ └── sweep_broom.yaml ├── eval.py ├── eval_point_track.py ├── logger.py ├── point_utils │ ├── correspondence.py │ ├── depth.py │ └── points_class.py ├── read_data │ ├── __init__.py │ ├── baku.py │ ├── mtpi.py │ ├── p3po.py │ └── point_policy.py ├── replay_buffer.py ├── robot_utils │ └── franka │ │ ├── calibration │ │ ├── constants.py │ │ └── generate_r2c_extrinsic.py │ │ ├── convert_pkl_human_to_robot.py │ │ ├── convert_to_pkl_human.py │ │ ├── convert_to_pkl_robot.py │ │ ├── gripper_points.py │ │ ├── label_points.ipynb │ │ ├── process_data_human.py │ │ ├── process_data_robot.py │ │ ├── save_video.py │ │ └── utils.py ├── suite │ ├── baku.py │ ├── mtpi.py │ ├── p3po.py │ └── point_policy.py ├── train.py ├── utils.py └── video.py └── setup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # ignore 174 | */__pycache__/ 175 | */*/__pycache__/ 176 | */*/*/__pycache__/ 177 | */exp_local/ 178 | */*/exp_local/ 179 | expert_demos/ 180 | coordinates*/ 181 | calib/ 182 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "co-tracker"] 2 | path = co-tracker 3 | url = git@github.com:mlevy2525/co-tracker.git 4 | [submodule "dift"] 5 | path = dift 6 | url = git@github.com:mlevy2525/dift.git 7 | [submodule "Franka-Teach"] 8 | path = Franka-Teach 9 | url = git@github.com:NYU-robot-learning/Franka-Teach.git 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 23.7.0 10 | hooks: 11 | - id: black 12 | language_version: python3.8 13 | - repo: https://github.com/hadialqattan/pycln 14 | rev: v2.2.2 15 | hooks: 16 | - id: pycln 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Siddhant Haldar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point Policy: Unifying Observations and Actions with Key Points for Robot Manipulation 2 | 3 | This is a repository containing the code for the paper [Point Policy: Unifying Observations and Actions with Key Points for Robot Manipulation](https://arxiv.org/abs/2502.20391). 4 | 5 | ![Image](https://github.com/user-attachments/assets/a03066bc-16fc-4ce5-b1eb-e3c8dc329a2d) 6 | 7 | ## Instructions 8 | 9 | We have provided the instructions for [installation and data collection](instructions/installation_and_data_collection.md) and [code execution](instructions/code.md) at the respective links. 
10 | 11 | ## Bibtex 12 | 13 | If you find this work useful, please cite the paper using the following bibtex: 14 | 15 | ``` 16 | @article{haldar2025point, 17 | title={Point Policy: Unifying Observations and Actions with Key Points for Robot Manipulation}, 18 | author={Haldar, Siddhant and Pinto, Lerrel}, 19 | journal={arXiv preprint arXiv:2502.20391}, 20 | year={2025} 21 | } 22 | ``` 23 | 24 | ## Queries/Comments/Discussions 25 | 26 | We welcome any queries, comments or discussions on the paper. Please feel free to open an issue on this repository or reach out to siddhanthaldar@nyu.edu. 27 | -------------------------------------------------------------------------------- /conda_env.yaml: -------------------------------------------------------------------------------- 1 | name: point-policy 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | - numpy==2.1.3 8 | - absl-py=0.13.0 9 | - pyparsing=2.4.7 10 | - jupyterlab=3.0.14 11 | - scikit-image 12 | - nvidia::cudatoolkit 13 | - pytorch::pytorch 14 | - pytorch::torchvision 15 | - pytorch::torchaudio 16 | - pip: 17 | - pre-commit 18 | - black 19 | - PyOpenGL-accelerate 20 | - protobuf==3.20.1 21 | - termcolor==1.1.0 22 | - gym==0.22.0 23 | - gymnasium 24 | - tb-nightly 25 | - imageio==2.9.0 26 | - imageio-ffmpeg==0.4.4 27 | - hydra-core==1.3.2 28 | - hydra-submitit-launcher==1.2.0 29 | - pandas==2.2.3 30 | - ipdb==0.13.9 31 | - yapf==0.31.0 32 | - sklearn==0.0 33 | - matplotlib==3.10.0 34 | - opencv-python==4.11.0.86 35 | - sentence-transformers 36 | - einops==0.7.0 37 | - decord 38 | - mujoco 39 | - dm_control 40 | - diffusers==0.15.0 41 | - h5py 42 | - timm 43 | - blosc 44 | - transformers==4.45.2 45 | - ray==2.37.0 46 | - PyOpenGL==3.1.7 47 | - PyOpenGL-accelerate==3.1.7 48 | - robomimic==0.3.0 49 | - robosuite==1.4.1 50 | - huggingface_hub==0.25.2 51 | -------------------------------------------------------------------------------- /franka-env/franka_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="Franka-v1", 5 | entry_point="franka_env.envs:FrankaEnv", 6 | max_episode_steps=400, 7 | ) 8 | -------------------------------------------------------------------------------- /franka-env/franka_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from franka_env.envs.franka_env import FrankaEnv 2 | -------------------------------------------------------------------------------- /franka-env/franka_env/envs/franka_env.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gym 3 | import numpy as np 4 | import time 5 | import pickle 6 | 7 | from frankateach.constants import ( 8 | CAM_PORT, 9 | GRIPPER_OPEN, 10 | HOST, 11 | CONTROL_PORT, 12 | ) 13 | from frankateach.messages import FrankaAction, FrankaState 14 | from frankateach.network import ( 15 | ZMQCameraSubscriber, 16 | create_request_socket, 17 | ) 18 | 19 | 20 | class FrankaEnv(gym.Env): 21 | def __init__( 22 | self, 23 | width=640, 24 | height=480, 25 | use_robot=True, 26 | use_gt_depth=False, 27 | crop_h=None, 28 | crop_w=None, 29 | ): 30 | super(FrankaEnv, self).__init__() 31 | self.width = width 32 | self.height = height 33 | self.crop_h = crop_h 34 | self.crop_w = crop_w 35 | 36 | self.channels = 3 37 | self.feature_dim = 8 38 | self.action_dim = 7 # (pos, axis angle, gripper) 39 | 40 | self.use_robot = use_robot 41 | 
self.use_gt_depth = use_gt_depth 42 | 43 | self.n_channels = 3 44 | self.reward = 0 45 | 46 | self.franka_state = None 47 | self.curr_images = None 48 | 49 | self.action_space = gym.spaces.Box( 50 | low=-float("inf"), high=float("inf"), shape=(self.action_dim,) 51 | ) 52 | self.observation_space = gym.spaces.Box( 53 | low=0, high=255, shape=(height, width, self.n_channels), dtype=np.uint8 54 | ) 55 | 56 | if self.use_robot: 57 | self.cam_ids = [1, 2] 58 | self.image_subscribers = {} 59 | if self.use_gt_depth: 60 | self.depth_subscribers = {} 61 | for cam_idx in self.cam_ids: 62 | port = CAM_PORT + cam_idx 63 | self.image_subscribers[cam_idx] = ZMQCameraSubscriber( 64 | host=HOST, 65 | port=port, 66 | topic_type="RGB", 67 | ) 68 | 69 | if self.use_gt_depth: 70 | depth_port = CAM_PORT + cam_idx + 1000 # depth offset =1000 71 | self.depth_subscribers[cam_idx] = ZMQCameraSubscriber( 72 | host=HOST, 73 | port=depth_port, 74 | topic_type="Depth", 75 | ) 76 | 77 | self.action_request_socket = create_request_socket(HOST, CONTROL_PORT) 78 | 79 | def get_state(self): 80 | self.action_request_socket.send(b"get_state") 81 | franka_state: FrankaState = pickle.loads(self.action_request_socket.recv()) 82 | self.franka_state = franka_state 83 | return franka_state 84 | 85 | def step(self, abs_action): 86 | """ 87 | Step the environment with an absolute action. 88 | 89 | Args: 90 | abs_action (_type_): absolute action in the format (position, orientation in quaternion, gripper) 91 | 92 | Returns: 93 | obs (dict): observation dictionary containing features and images 94 | reward (float): reward 95 | done (bool): whether the episode is done 96 | info (dict): additional information 97 | """ 98 | pos = abs_action[:3] 99 | quat = abs_action[3:7] 100 | gripper = abs_action[-1] 101 | 102 | # Send action to the robot 103 | franka_action = FrankaAction( 104 | pos=pos, 105 | quat=quat, 106 | gripper=gripper, 107 | reset=False, 108 | timestamp=time.time(), 109 | ) 110 | 111 | self.action_request_socket.send(bytes(pickle.dumps(franka_action, protocol=-1))) 112 | franka_state: FrankaState = pickle.loads(self.action_request_socket.recv()) 113 | self.franka_state = franka_state 114 | 115 | image_list = {} 116 | for cam_idx, subscriber in self.image_subscribers.items(): 117 | image, _ = subscriber.recv_rgb_image() 118 | 119 | # crop the image 120 | if self.crop_h is not None and self.crop_w is not None: 121 | h, w, _ = image.shape 122 | image = image[ 123 | int(h * self.crop_h[0]) : int(h * self.crop_h[1]), 124 | int(w * self.crop_w[0]) : int(w * self.crop_w[1]), 125 | ] 126 | 127 | image_list[cam_idx] = image 128 | 129 | if self.use_gt_depth: 130 | depth_list = {} 131 | for cam_idx, subscriber in self.depth_subscribers.items(): 132 | depth, _ = subscriber.recv_depth_image() 133 | 134 | if self.crop_h is not None and self.crop_w is not None: 135 | h, w = depth.shape 136 | depth = depth[ 137 | int(h * self.crop_h[0]) : int(h * self.crop_h[1]), 138 | int(w * self.crop_w[0]) : int(w * self.crop_w[1]), 139 | ] 140 | 141 | depth_list[cam_idx] = depth 142 | 143 | self.curr_images = image_list 144 | 145 | obs = { 146 | "features": np.concatenate( 147 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 148 | ), 149 | } 150 | 151 | for cam_idx, image in image_list.items(): 152 | obs[f"pixels{cam_idx}"] = cv2.resize(image, (self.width, self.height)) 153 | if self.use_gt_depth: 154 | for cam_idx, depth in depth_list.items(): 155 | obs[f"depth{cam_idx}"] = cv2.resize(depth, (self.width, self.height)) 156 | 157 | return 
obs, self.reward, False, None 158 | 159 | def reset(self): 160 | if self.use_robot: 161 | print("resetting") 162 | franka_action = FrankaAction( 163 | pos=np.zeros(3), 164 | quat=np.zeros(4), 165 | gripper=GRIPPER_OPEN, 166 | reset=True, 167 | timestamp=time.time(), 168 | ) 169 | 170 | self.action_request_socket.send( 171 | bytes(pickle.dumps(franka_action, protocol=-1)) 172 | ) 173 | franka_state: FrankaState = pickle.loads(self.action_request_socket.recv()) 174 | self.franka_state = franka_state 175 | print("reset done: ", franka_state) 176 | 177 | image_list = {} 178 | for cam_idx, subscriber in self.image_subscribers.items(): 179 | image, _ = subscriber.recv_rgb_image() 180 | 181 | # crop the image 182 | if self.crop_h is not None and self.crop_w is not None: 183 | h, w, _ = image.shape 184 | image = image[ 185 | int(h * self.crop_h[0]) : int(h * self.crop_h[1]), 186 | int(w * self.crop_w[0]) : int(w * self.crop_w[1]), 187 | ] 188 | 189 | image_list[cam_idx] = image 190 | 191 | if self.use_gt_depth: 192 | depth_list = {} 193 | for cam_idx, subscriber in self.depth_subscribers.items(): 194 | depth, _ = subscriber.recv_depth_image() 195 | 196 | if self.crop_h is not None and self.crop_w is not None: 197 | h, w = depth.shape 198 | depth = depth[ 199 | int(h * self.crop_h[0]) : int(h * self.crop_h[1]), 200 | int(w * self.crop_w[0]) : int(w * self.crop_w[1]), 201 | ] 202 | 203 | depth_list[cam_idx] = depth 204 | 205 | self.curr_images = image_list 206 | 207 | obs = { 208 | "features": np.concatenate( 209 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 210 | ), 211 | } 212 | for cam_idx, image in image_list.items(): 213 | obs[f"pixels{cam_idx}"] = cv2.resize(image, (self.width, self.height)) 214 | if self.use_gt_depth: 215 | for cam_idx, depth in depth_list.items(): 216 | obs[f"depth{cam_idx}"] = cv2.resize( 217 | depth, (self.width, self.height) 218 | ) 219 | 220 | return obs 221 | 222 | else: 223 | obs = {} 224 | obs["features"] = np.zeros(self.feature_dim) 225 | obs["pixels"] = np.zeros((self.height, self.width, self.n_channels)) 226 | if self.use_gt_depth: 227 | obs["depth"] = np.zeros((self.height, self.width)) 228 | 229 | return obs 230 | 231 | def render(self, mode="rgb_array", width=640, height=480): 232 | assert self.curr_images is not None, "Must call reset() before render()" 233 | if mode == "rgb_array": 234 | image_list = [] 235 | for key, im in self.curr_images.items(): 236 | image_list.append(cv2.resize(im, (width, height))) 237 | 238 | return np.concatenate(image_list, axis=1) 239 | else: 240 | raise NotImplementedError 241 | -------------------------------------------------------------------------------- /franka-env/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="franka_env", 5 | version="0.0.1", 6 | packages=["franka_env"], 7 | install_requires=["gym"], 8 | ) 9 | -------------------------------------------------------------------------------- /instructions/code.md: -------------------------------------------------------------------------------- 1 | NOTE: All commands must be run from inside the `point_policy` directory. In `point_policy/cfgs/config.yaml`, set `root_dir` to `path/to/repo` and `data_dir` to `path/to/data/expert_demos`. Also, set `root_dir` in `point_policy/cfgs/suite/points_cfg.yaml` to `path/to/repo`.
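For reference, a minimal sketch of what these entries look like (only the two path-related keys named above are shown; the concrete paths are placeholders to be replaced with your own):

```
# point_policy/cfgs/config.yaml
root_dir: /path/to/Point-Policy
data_dir: /path/to/data/expert_demos

# point_policy/cfgs/suite/points_cfg.yaml
root_dir: /path/to/Point-Policy
```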
2 | 3 | ## Training 4 | 5 | - BC 6 | ``` 7 | python train.py agent=baku suite=baku dataloader=baku eval=false suite/task/franka_env=<task_name> experiment=bc 8 | ``` 9 | 10 | - BC w/ Depth 11 | ``` 12 | python train.py agent=baku suite=baku dataloader=baku eval=false suite/task/franka_env=<task_name> suite.gt_depth=true experiment=bc_with_depth 13 | ``` 14 | 15 | - MT-π 16 | ``` 17 | python train.py agent=mtpi suite=mtpi dataloader=mtpi eval=false suite.use_robot_points=true suite.use_object_points=true suite/task/franka_env=<task_name> experiment=mtpi 18 | ``` 19 | 20 | - P3PO 21 | ``` 22 | python train.py agent=p3po suite=p3po dataloader=p3po eval=false suite.use_robot_points=true suite.use_object_points=true suite/task/franka_env=<task_name> experiment=p3po 23 | ``` 24 | 25 | - Point Policy 26 | ``` 27 | python train.py agent=point_policy suite=point_policy dataloader=point_policy eval=false suite.use_robot_points=true suite.use_object_points=true suite/task/franka_env=<task_name> experiment=point_policy 28 | ``` 29 | 30 | ## Inference 31 | 32 | - BC 33 | ``` 34 | python eval.py agent=baku suite=baku dataloader=baku eval=true experiment=eval_bc suite/task/franka_env=<task_name> bc_weight=path/to/bc/weight 35 | ``` 36 | 37 | - BC w/ Depth 38 | ``` 39 | python eval.py agent=baku suite=baku dataloader=baku eval=true experiment=eval_bc_with_depth suite.gt_depth=true suite/task/franka_env=<task_name> bc_weight=path/to/bc/weight 40 | ``` 41 | 42 | - MT-π 43 | ``` 44 | python eval_point_track.py agent=mtpi suite=mtpi dataloader=mtpi eval=true suite.use_robot_points=true suite.use_object_points=false experiment=eval_mtpi suite/task/franka_env=<task_name> bc_weight=path/to/bc/weight 45 | ``` 46 | 47 | - P3PO 48 | ``` 49 | python eval_point_track.py agent=p3po suite=p3po dataloader=p3po eval=true suite.use_robot_points=true suite.use_object_points=true experiment=eval_p3po suite/task/franka_env=<task_name> bc_weight=path/to/bc/weight 50 | ``` 51 | 52 | - Point Policy 53 | ``` 54 | python eval_point_track.py agent=point_policy suite=point_policy dataloader=point_policy eval=true suite.use_robot_points=true suite.use_object_points=true experiment=eval_point_policy suite/task/franka_env=<task_name> bc_weight=path/to/bc/weight 55 | ``` 56 | -------------------------------------------------------------------------------- /instructions/installation_and_data_collection.md: -------------------------------------------------------------------------------- 1 | ## Installation Instructions 2 | 3 | - Clone the repository with the submodules. 4 | ``` 5 | git clone git@github.com:siddhanthaldar/Point-Policy.git --recurse-submodules 6 | ``` 7 | - Create a conda environment using the provided `conda_env.yaml` file. 8 | ``` 9 | conda env create -f conda_env.yaml 10 | ``` 11 | - Activate the environment using `conda activate point-policy`. 12 | - Install Franka Teach using the instructions provided in the submodule. If you only want to perform training runs, you can skip the full Franka Teach setup and simply run the following. 13 | ``` 14 | cd Franka-Teach 15 | pip install -e . 16 | cd .. 17 | ``` 18 | - Install the Franka environment using the following command. 19 | ``` 20 | cd franka-env 21 | pip install -e . 22 | cd .. 23 | ``` 24 | - Download and install the co-tracker and dift submodules and the relevant packages by running the `setup.sh` script. Make sure to run it from the repository root, or the models may be installed in the wrong location.
25 | ``` 26 | sudo chmod 777 setup.sh # make executable 27 | ./setup.sh 28 | ``` 29 | 30 | 31 | ## Data Collection Instructions 32 | - Instructions for data collection are provided in the [Franka Teach submodule](Franka-Teach/README.md). This is a fork of [Open Teach](https://github.com/aadhithya14/Open-Teach) modified to work only with Franka robots. 33 | 34 | ## Data Preprocessing Instructions 35 | 36 | NOTE: Set `root_dir` in `point_policy/cfgs/suite/points_cfg.yaml` to `path/to/repo`. 37 | 38 | - Go to the robot utils Franka directory. 39 | ``` 40 | cd point_policy/robot_utils/franka 41 | ``` 42 | - Once you have collected the human data using Franka Teach, process it to remove pauses and save it in a cleaner format. 43 | ``` 44 | python process_data_human.py --data_dir path/to/data --task_names <task_name> --process_depth 45 | ``` 46 | - Convert the data to a pickle file (without processing key points first). 47 | ``` 48 | python convert_to_pkl_human.py --data_dir path/to/data --calib_path path/to/calib_file --task_names <task_name> 49 | ``` 50 | - NOTE: Before generating task data, we first need to generate the calibration file. 51 | - For calibration, generate the pkl file without points for the calibration data (collected using Franka Teach), and make sure to first set `PATH_DATA_PKL` to the data pickle file for the calibration data. 52 | - Next, generate the calib file using the following command. 53 | ``` 54 | cd calibration 55 | python generate_r2c_extrinsic.py 56 | cd .. 57 | ``` 58 | - This will generate the calib file in `point_policy/calib/calib.npy`. 59 | 60 | - Label semantically meaningful points for each task following the commands in `point_policy/robot_utils/franka/label_points.ipynb`. 61 | - Save pickle data with key point labels, both for the human hand and for object points obtained through human annotations. 62 | ``` 63 | python convert_to_pkl_human.py --data_dir path/to/data --calib_path path/to/calib_file --task_names <task_name> --process_points 64 | ``` 65 | - Convert human hand poses to robot actions in the data. 66 | ``` 67 | python convert_pkl_human_to_robot.py --data_dir path/to/data --calib_path path/to/calib_file --task_name <task_name> 68 | ``` 69 | 70 | NOTE: 71 | - The `calib_path` must be set to `<root_dir>/calib/calib.npy`, where `<root_dir>` is the root directory of the repository. The `data_dir` must be set to the directory where the data is stored during teleoperation (`path/to/data`). 72 | - The generated pkl files with robot actions will be stored in `path/to/data/expert_demos/franka_env`. The variable `data_dir` in `config.yaml` and `config_eval.yaml` in `point_policy/cfgs` must be set to `path/to/data/expert_demos`.
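Putting the preprocessing steps above together, the per-task sequence looks roughly like the sketch below. This is only a consolidated example: it assumes the calibration file has already been generated as described above, uses `drawer_close` (one of the task configs listed under `point_policy/cfgs/suite/task/franka_env/`) as an example task name, and keeps the `path/to/...` placeholders, all of which must be replaced with your actual paths.

```
# run from point_policy/robot_utils/franka with the point-policy conda env active

# 1. remove pauses from the raw teleoperated human demonstrations
python process_data_human.py --data_dir path/to/data --task_names drawer_close --process_depth

# 2. pickle the demonstrations without key points
python convert_to_pkl_human.py --data_dir path/to/data --calib_path path/to/calib_file --task_names drawer_close

# 3. label task-specific points in label_points.ipynb, then re-pickle with key points
python convert_to_pkl_human.py --data_dir path/to/data --calib_path path/to/calib_file --task_names drawer_close --process_points

# 4. convert the human hand poses to robot actions
python convert_pkl_human_to_robot.py --data_dir path/to/data --calib_path path/to/calib_file --task_name drawer_close

# the resulting pkl files are written to path/to/data/expert_demos/franka_env
```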
73 | -------------------------------------------------------------------------------- /point_policy/agent/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siddhanthaldar/Point-Policy/6f318852b324283d354446bfba0628890f78a116/point_policy/agent/networks/__init__.py -------------------------------------------------------------------------------- /point_policy/agent/networks/dit.py: -------------------------------------------------------------------------------- 1 | """ 2 | DiT like transformer for point-tracks 3 | Adapted from: https://github.com/facebookresearch/DiT/blob/main/models.py 4 | Inspired by: https://github.com/homangab/Track-2-Act/blob/main/single_script.py 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | from timm.models.vision_transformer import Attention, Mlp 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def modulate(x, shift, scale): 18 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 19 | 20 | 21 | ################################################################################# 22 | # Core DiT Model # 23 | ################################################################################# 24 | 25 | 26 | class DiTBlock(nn.Module): 27 | """ 28 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 29 | """ 30 | 31 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): 32 | super().__init__() 33 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 34 | self.attn = Attention( 35 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs 36 | ) 37 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 38 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 39 | approx_gelu = lambda: nn.GELU(approximate="tanh") 40 | self.mlp = Mlp( 41 | in_features=hidden_size, 42 | hidden_features=mlp_hidden_dim, 43 | act_layer=approx_gelu, 44 | drop=0, 45 | ) 46 | self.adaLN_modulation = nn.Sequential( 47 | nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) 48 | ) 49 | 50 | def forward(self, x, c): 51 | ( 52 | shift_msa, 53 | scale_msa, 54 | gate_msa, 55 | shift_mlp, 56 | scale_mlp, 57 | gate_mlp, 58 | ) = self.adaLN_modulation(c).chunk(6, dim=1) 59 | x = x + gate_msa.unsqueeze(1) * self.attn( 60 | modulate(self.norm1(x), shift_msa, scale_msa) 61 | ) 62 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 63 | modulate(self.norm2(x), shift_mlp, scale_mlp) 64 | ) 65 | return x 66 | 67 | 68 | class FinalLayer(nn.Module): 69 | """ 70 | The final layer of DiT. 71 | """ 72 | 73 | def __init__(self, hidden_size, patch_size, out_channels): 74 | super().__init__() 75 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 76 | self.linear = nn.Linear( 77 | hidden_size, patch_size * patch_size * out_channels, bias=True 78 | ) 79 | self.adaLN_modulation = nn.Sequential( 80 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 81 | ) 82 | 83 | def forward(self, x, c): 84 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 85 | x = modulate(self.norm_final(x), shift, scale) 86 | x = self.linear(x) 87 | return x 88 | 89 | 90 | class DiT(nn.Module): 91 | """ 92 | Diffusion model with a Transformer backbone. 
93 | """ 94 | 95 | def __init__( 96 | self, 97 | horizon=8, 98 | hidden_size=1152, 99 | depth=28, 100 | num_heads=16, 101 | mlp_ratio=4.0, 102 | learn_sigma=False, 103 | cond_dim=512, ## dim of image encodings # ResNet18 has output dim of 512 104 | num_points=25, ## remove if pos emb is not used 105 | with_pos_emb=True, 106 | num_conds=1, 107 | ): 108 | super().__init__() 109 | self.learn_sigma = learn_sigma 110 | self.in_channels = ( 111 | horizon # history of points converted to a vector using an encoder 112 | ) 113 | self.out_channels = hidden_size 114 | self.patch_size = 1 115 | self.num_heads = num_heads 116 | self.num_points = num_points 117 | self.with_pos_emb = with_pos_emb 118 | 119 | self.x_embedder = nn.Linear(self.in_channels, hidden_size) 120 | 121 | self.y_embedder = nn.Linear(num_conds * cond_dim, hidden_size) 122 | 123 | if self.with_pos_emb: 124 | # Will use fixed sin-cos embedding: 125 | self.pos_embed = nn.Parameter( 126 | torch.zeros(1, self.num_points, hidden_size), requires_grad=False 127 | ) 128 | 129 | self.blocks = nn.ModuleList( 130 | [ 131 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) 132 | for _ in range(depth) 133 | ] 134 | ) 135 | 136 | patch_size = 1 ### because patches are points 137 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 138 | 139 | ## resnet is initialized by get_resnet() above 140 | ## x_embedder, y_embedder has default initilization 141 | self.initialize_weights() 142 | 143 | def initialize_weights(self): 144 | # Initialize transformer layers: 145 | def _basic_init(module): 146 | if isinstance(module, nn.Linear): 147 | torch.nn.init.xavier_uniform_(module.weight) 148 | if module.bias is not None: 149 | nn.init.constant_(module.bias, 0) 150 | 151 | self.apply(_basic_init) 152 | 153 | if self.with_pos_emb: 154 | # Initialize (and freeze) pos_embed by sin-cos embedding: 155 | pos_embed = get_2d_sincos_pos_embed( 156 | self.pos_embed.shape[-1], int(self.num_points**0.5) 157 | ) 158 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 159 | 160 | # Zero-out adaLN modulation layers in DiT blocks: 161 | for block in self.blocks: 162 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 163 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 164 | 165 | # Zero-out output layers: 166 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 167 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 168 | nn.init.constant_(self.final_layer.linear.weight, 0) 169 | nn.init.constant_(self.final_layer.linear.bias, 0) 170 | 171 | def unpatchify(self, x): 172 | """ 173 | x: (N, T, patch_size**2 * C) 174 | imgs: (N, H, W, C) 175 | """ 176 | c = self.out_channels 177 | p = self.num_points 178 | 179 | x = torch.einsum("npc->ncp", x) 180 | return x 181 | 182 | def forward(self, x, y): 183 | """ 184 | Forward pass of DiT. 
185 | # x: (N, C, P) tensor of point tracks, C = 2*T where T=8 horizon 186 | x: (N, P, C) tensor of point tracks, C = 2*T where T=8 horizon; P = num_points 187 | # y: (N,2,3,96,96) tensor of initial and goal images 188 | y: (N,2,512) tensor of initial and goal image encodings 189 | """ 190 | x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size ** 2 191 | if self.with_pos_emb: 192 | # x = x + self.pos_embed 193 | shape = x.shape 194 | x = x + self.pos_embed[:, : shape[1]] 195 | 196 | y = y.flatten(start_dim=1) # (N, 2*512) 197 | y = self.y_embedder(y) # (N, D) 198 | 199 | c = y 200 | for block in self.blocks: 201 | x = block(x, c) # (N, T, D) 202 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 203 | return x 204 | 205 | 206 | ################################################################################# 207 | # Sine/Cosine Positional Embedding Functions # 208 | ################################################################################# 209 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 210 | 211 | 212 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 213 | """ 214 | grid_size: int of the grid height and width 215 | return: 216 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 217 | """ 218 | grid_h = np.arange(grid_size, dtype=np.float32) 219 | grid_w = np.arange(grid_size, dtype=np.float32) 220 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 221 | grid = np.stack(grid, axis=0) 222 | 223 | grid = grid.reshape([2, 1, grid_size, grid_size]) 224 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 225 | if cls_token and extra_tokens > 0: 226 | pos_embed = np.concatenate( 227 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 228 | ) 229 | return pos_embed 230 | 231 | 232 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 233 | assert embed_dim % 2 == 0 234 | 235 | # use half of dimensions to encode grid_h 236 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 237 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 238 | 239 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 240 | return emb 241 | 242 | 243 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 244 | """ 245 | embed_dim: output dimension for each position 246 | pos: a list of positions to be encoded: size (M,) 247 | out: (M, D) 248 | """ 249 | assert embed_dim % 2 == 0 250 | omega = np.arange(embed_dim // 2, dtype=np.float64) 251 | omega /= embed_dim / 2.0 252 | omega = 1.0 / 10000**omega # (D/2,) 253 | 254 | pos = pos.reshape(-1) # (M,) 255 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 256 | 257 | emb_sin = np.sin(out) # (M, D/2) 258 | emb_cos = np.cos(out) # (M, D/2) 259 | 260 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 261 | return emb 262 | -------------------------------------------------------------------------------- /point_policy/agent/networks/gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. 
3 | Original source: https://github.com/karpathy/nanoGPT 4 | 5 | Original License: 6 | MIT License 7 | 8 | Copyright (c) 2022 Andrej Karpathy 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | Original comments: 29 | Full definition of a GPT Language Model, all of it in this single file. 30 | References: 31 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 32 | https://github.com/openai/gpt-2/blob/master/src/model.py 33 | 2) huggingface/transformers PyTorch implementation: 34 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 35 | """ 36 | 37 | import math 38 | from dataclasses import dataclass 39 | 40 | import torch 41 | import torch.nn as nn 42 | from torch.nn import functional as F 43 | 44 | 45 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 46 | def new_gelu(x): 47 | """ 48 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
49 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 50 | """ 51 | return ( 52 | 0.5 53 | * x 54 | * ( 55 | 1.0 56 | + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))) 57 | ) 58 | ) 59 | 60 | 61 | class CausalSelfAttention(nn.Module): 62 | def __init__(self, config): 63 | super().__init__() 64 | assert config.n_embd % config.n_head == 0 65 | # key, query, value projections for all heads, but in a batch 66 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 67 | # output projection 68 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 69 | # regularization 70 | self.attn_dropout = nn.Dropout(config.dropout) 71 | self.resid_dropout = nn.Dropout(config.dropout) 72 | # causal mask to ensure that attention is only applied to the left in the input sequence 73 | self.register_buffer( 74 | "bias", 75 | # torch.ones(1, 1, config.block_size, config.block_size), 76 | torch.tril(torch.ones(config.block_size, config.block_size)).view( 77 | 1, 1, config.block_size, config.block_size 78 | ) 79 | if config.causal 80 | else torch.ones(1, 1, config.block_size, config.block_size), 81 | ) 82 | self.n_head = config.n_head 83 | self.n_embd = config.n_embd 84 | 85 | def forward(self, x): 86 | ( 87 | B, 88 | T, 89 | C, 90 | ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 91 | 92 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 93 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 94 | k = k.view(B, T, self.n_head, C // self.n_head).transpose( 95 | 1, 2 96 | ) # (B, nh, T, hs) 97 | q = q.view(B, T, self.n_head, C // self.n_head).transpose( 98 | 1, 2 99 | ) # (B, nh, T, hs) 100 | v = v.view(B, T, self.n_head, C // self.n_head).transpose( 101 | 1, 2 102 | ) # (B, nh, T, hs) 103 | 104 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 105 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 106 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 107 | att = F.softmax(att, dim=-1) 108 | att = self.attn_dropout(att) 109 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 110 | y = ( 111 | y.transpose(1, 2).contiguous().view(B, T, C) 112 | ) # re-assemble all head outputs side by side 113 | 114 | # output projection 115 | y = self.resid_dropout(self.c_proj(y)) 116 | return y 117 | 118 | 119 | class MLP(nn.Module): 120 | def __init__(self, config): 121 | super().__init__() 122 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) 123 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) 124 | self.dropout = nn.Dropout(config.dropout) 125 | 126 | def forward(self, x): 127 | x = self.c_fc(x) 128 | x = new_gelu(x) 129 | x = self.c_proj(x) 130 | x = self.dropout(x) 131 | return x 132 | 133 | 134 | class Block(nn.Module): 135 | def __init__(self, config): 136 | super().__init__() 137 | self.ln_1 = nn.LayerNorm(config.n_embd) 138 | self.attn = CausalSelfAttention(config) 139 | self.ln_2 = nn.LayerNorm(config.n_embd) 140 | self.mlp = MLP(config) 141 | 142 | def forward(self, x): 143 | x = x + self.attn(self.ln_1(x)) 144 | x = x + self.mlp(self.ln_2(x)) 145 | return x 146 | 147 | 148 | @dataclass 149 | class GPTConfig: 150 | block_size: int = 1024 151 | input_dim: int = 256 152 | output_dim: int = 256 153 | n_layer: int = 12 154 | n_head: int = 12 155 | n_embd: int = 768 156 | dropout: float = 0.1 157 | causal: bool = False 158 | 159 | 160 | class GPT(nn.Module): 161 | def 
__init__(self, config): 162 | super().__init__() 163 | assert config.input_dim is not None 164 | assert config.output_dim is not None 165 | assert config.block_size is not None 166 | self.config = config 167 | 168 | self.transformer = nn.ModuleDict( 169 | dict( 170 | wte=nn.Linear(config.input_dim, config.n_embd), 171 | wpe=nn.Embedding(config.block_size, config.n_embd), 172 | drop=nn.Dropout(config.dropout), 173 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 174 | ln_f=nn.LayerNorm(config.n_embd), 175 | ) 176 | ) 177 | self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False) 178 | # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper 179 | self.apply(self._init_weights) 180 | for pn, p in self.named_parameters(): 181 | if pn.endswith("c_proj.weight"): 182 | torch.nn.init.normal_( 183 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) 184 | ) 185 | 186 | # report number of parameters 187 | n_params = sum(p.numel() for p in self.parameters()) 188 | print("number of parameters in GPT: %.2fM" % (n_params / 1e6,)) 189 | 190 | def forward(self, input, targets=None): 191 | device = input.device 192 | b, t, d = input.size() 193 | assert ( 194 | t <= self.config.block_size 195 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 196 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( 197 | 0 198 | ) # shape (1, t) 199 | 200 | # forward the GPT model itself 201 | tok_emb = self.transformer.wte( 202 | input 203 | ) # token embeddings of shape (b, t, n_embd) 204 | pos_emb = self.transformer.wpe( 205 | pos 206 | ) # position embeddings of shape (1, t, n_embd) 207 | x = self.transformer.drop(tok_emb + pos_emb) 208 | for block in self.transformer.h: 209 | x = block(x) 210 | x = self.transformer.ln_f(x) 211 | logits = self.lm_head(x) 212 | return logits 213 | 214 | def _init_weights(self, module): 215 | if isinstance(module, nn.Linear): 216 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 217 | if module.bias is not None: 218 | torch.nn.init.zeros_(module.bias) 219 | elif isinstance(module, nn.Embedding): 220 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 221 | elif isinstance(module, nn.LayerNorm): 222 | torch.nn.init.zeros_(module.bias) 223 | torch.nn.init.ones_(module.weight) 224 | 225 | def crop_block_size(self, block_size): 226 | assert block_size <= self.config.block_size 227 | self.config.block_size = block_size 228 | self.transformer.wpe.weight = nn.Parameter( 229 | self.transformer.wpe.weight[:block_size] 230 | ) 231 | for block in self.transformer.h: 232 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] 233 | 234 | def configure_optimizers(self, weight_decay, learning_rate, betas): 235 | """ 236 | This long function is unfortunately doing something very simple and is being very defensive: 237 | We are separating out all parameters of the model into two buckets: those that will experience 238 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 239 | We are then returning the PyTorch optimizer object. 
240 | """ 241 | 242 | # separate out all parameters to those that will and won't experience regularizing weight decay 243 | decay = set() 244 | no_decay = set() 245 | whitelist_weight_modules = (torch.nn.Linear,) 246 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 247 | for mn, m in self.named_modules(): 248 | for pn, p in m.named_parameters(): 249 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 250 | if pn.endswith("bias"): 251 | # all biases will not be decayed 252 | no_decay.add(fpn) 253 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 254 | # weights of whitelist modules will be weight decayed 255 | decay.add(fpn) 256 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 257 | # weights of blacklist modules will NOT be weight decayed 258 | no_decay.add(fpn) 259 | 260 | # validate that we considered every parameter 261 | param_dict = {pn: p for pn, p in self.named_parameters()} 262 | inter_params = decay & no_decay 263 | union_params = decay | no_decay 264 | assert ( 265 | len(inter_params) == 0 266 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 267 | assert ( 268 | len(param_dict.keys() - union_params) == 0 269 | ), "parameters %s were not separated into either decay/no_decay set!" % ( 270 | str(param_dict.keys() - union_params), 271 | ) 272 | 273 | # create the pytorch optimizer object 274 | optim_groups = [ 275 | { 276 | "params": [param_dict[pn] for pn in sorted(list(decay))], 277 | "weight_decay": weight_decay, 278 | }, 279 | { 280 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 281 | "weight_decay": 0.0, 282 | }, 283 | ] 284 | # optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) 285 | optimizer = torch.optim.Adam(optim_groups, lr=learning_rate, betas=betas) 286 | return optimizer 287 | -------------------------------------------------------------------------------- /point_policy/agent/networks/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Callable, List, Optional 3 | 4 | 5 | class MLP(torch.nn.Sequential): 6 | """This block implements the multi-layer perceptron (MLP) module. 7 | Adapted for backward compatibility from the torchvision library: 8 | https://pytorch.org/vision/0.14/generated/torchvision.ops.MLP.html 9 | 10 | LICENSE: 11 | 12 | From PyTorch: 13 | 14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 19 | Copyright (c) 2011-2013 NYU (Clement Farabet) 20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 23 | 24 | From Caffe2: 25 | 26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 27 | 28 | All contributions by Facebook: 29 | Copyright (c) 2016 Facebook Inc. 30 | 31 | All contributions by Google: 32 | Copyright (c) 2015 Google Inc. 33 | All rights reserved. 34 | 35 | All contributions by Yangqing Jia: 36 | Copyright (c) 2015 Yangqing Jia 37 | All rights reserved. 
38 | 39 | All contributions by Kakao Brain: 40 | Copyright 2019-2020 Kakao Brain 41 | 42 | All contributions by Cruise LLC: 43 | Copyright (c) 2022 Cruise LLC. 44 | All rights reserved. 45 | 46 | All contributions from Caffe: 47 | Copyright(c) 2013, 2014, 2015, the respective contributors 48 | All rights reserved. 49 | 50 | All other contributions: 51 | Copyright(c) 2015, 2016 the respective contributors 52 | All rights reserved. 53 | 54 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 55 | copyright over their contributions to Caffe2. The project versioning records 56 | all such contribution and copyright details. If a contributor wants to further 57 | mark their specific copyright on a particular contribution, they should 58 | indicate their copyright solely in the commit message of the change when it is 59 | committed. 60 | 61 | All rights reserved. 62 | 63 | Redistribution and use in source and binary forms, with or without 64 | modification, are permitted provided that the following conditions are met: 65 | 66 | 1. Redistributions of source code must retain the above copyright 67 | notice, this list of conditions and the following disclaimer. 68 | 69 | 2. Redistributions in binary form must reproduce the above copyright 70 | notice, this list of conditions and the following disclaimer in the 71 | documentation and/or other materials provided with the distribution. 72 | 73 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 74 | and IDIAP Research Institute nor the names of its contributors may be 75 | used to endorse or promote products derived from this software without 76 | specific prior written permission. 77 | 78 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 79 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 80 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 81 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 82 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 83 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 84 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 85 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 86 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 87 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 88 | POSSIBILITY OF SUCH DAMAGE. 89 | 90 | 91 | Args: 92 | in_channels (int): Number of channels of the input 93 | hidden_channels (List[int]): List of the hidden channel dimensions 94 | norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None`` 95 | activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` 96 | inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. 97 | Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer. 98 | bias (bool): Whether to use bias in the linear layer. Default ``True`` 99 | dropout (float): The probability for the dropout layer. 
Default: 0.0 100 | """ 101 | 102 | def __init__( 103 | self, 104 | in_channels: int, 105 | hidden_channels: List[int], 106 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, 107 | inplace: Optional[bool] = None, 108 | bias: bool = True, 109 | dropout: float = 0.0, 110 | ): 111 | params = {} if inplace is None else {"inplace": inplace} 112 | 113 | layers = [] 114 | in_dim = in_channels 115 | for hidden_dim in hidden_channels[:-1]: 116 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) 117 | layers.append(activation_layer(**params)) 118 | layers.append(torch.nn.Dropout(dropout, **params)) 119 | in_dim = hidden_dim 120 | 121 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) 122 | layers.append(torch.nn.Dropout(dropout, **params)) 123 | 124 | super().__init__(*layers) 125 | -------------------------------------------------------------------------------- /point_policy/agent/networks/policy_head.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import robomimic.utils.tensor_utils as TensorUtils 3 | import torch 4 | import torch.distributions as D 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import utils 9 | from agent.networks.utils.diffusion_policy import DiffusionPolicy 10 | from agent.networks.mlp import MLP 11 | 12 | ######################################### Deterministic Head ######################################### 13 | 14 | 15 | class DeterministicHead(nn.Module): 16 | def __init__( 17 | self, 18 | input_size, 19 | output_size, 20 | hidden_size=1024, 21 | num_layers=2, 22 | action_squash=False, 23 | loss_coef=1.0, 24 | ): 25 | super().__init__() 26 | self.loss_coef = loss_coef 27 | 28 | sizes = [input_size] + [hidden_size] * num_layers + [output_size] 29 | layers = [] 30 | for i in range(num_layers): 31 | layers += [nn.Linear(sizes[i], sizes[i + 1]), nn.ReLU()] 32 | layers += [nn.Linear(sizes[-2], sizes[-1])] 33 | 34 | if action_squash: 35 | layers += [nn.Tanh()] 36 | 37 | self.net = nn.Sequential(*layers) 38 | 39 | def forward(self, x, stddev=None, ret_action_value=False, **kwargs): 40 | mu = self.net(x) 41 | std = stddev if stddev is not None else 0.1 42 | std = torch.ones_like(mu) * std 43 | dist = utils.Normal(mu, std) 44 | if ret_action_value: 45 | return dist.mean 46 | else: 47 | return dist 48 | 49 | def loss_fn(self, dist, target, mask=None, reduction="mean", **kwargs): 50 | log_probs = dist.log_prob(target) 51 | if mask is not None: 52 | log_probs = log_probs * mask[..., None] / mask.mean() 53 | loss = -log_probs 54 | 55 | if reduction == "mean": 56 | loss = loss.mean() * self.loss_coef 57 | elif reduction == "none": 58 | loss = loss * self.loss_coef 59 | elif reduction == "sum": 60 | loss = loss.sum() * self.loss_coef 61 | else: 62 | raise NotImplementedError 63 | 64 | return { 65 | "actor_loss": loss, 66 | } 67 | 68 | def pred_loss_fn(self, pred, target, reduction="mean", **kwargs): 69 | dist = utils.TruncatedNormal(pred, 0.1) 70 | log_probs = dist.log_prob(target) 71 | loss = -log_probs 72 | 73 | if reduction == "mean": 74 | loss = loss.mean() * self.loss_coef 75 | elif reduction == "none": 76 | loss = loss * self.loss_coef 77 | elif reduction == "sum": 78 | loss = loss.sum() * self.loss_coef 79 | else: 80 | raise NotImplementedError 81 | 82 | return { 83 | "actor_loss": loss, 84 | } 85 | 86 | 87 | ######################################### Diffusion Head ######################################### 88 | 89 | 90 | class 
DiffusionHead(nn.Module): 91 | def __init__( 92 | self, 93 | input_size, 94 | output_size, 95 | obs_horizon, 96 | pred_horizon, 97 | hidden_size=1024, 98 | num_layers=2, 99 | device="cpu", 100 | loss_coef=100.0, 101 | ): 102 | super().__init__() 103 | 104 | self.net = DiffusionPolicy( 105 | obs_dim=input_size, 106 | act_dim=output_size, 107 | obs_horizon=obs_horizon, 108 | pred_horizon=pred_horizon, 109 | hidden_dim=hidden_size, 110 | num_layers=num_layers, 111 | policy_type="transformer", 112 | device=device, 113 | ) 114 | 115 | self.loss_coef = loss_coef 116 | 117 | def forward(self, x, stddev=None, ret_action_value=False, **kwargs): 118 | out = self.net(x, kwargs.get("action_seq", None)) 119 | return out[0] if ret_action_value else out 120 | 121 | def loss_fn(self, out, target, mask=None, reduction="mean", **kwargs): 122 | noise_pred = out["noise_pred"] 123 | noise = out["noise"] 124 | 125 | loss = F.mse_loss(noise_pred, noise, reduction="none") 126 | if mask is not None: 127 | loss = loss * mask[..., None] / mask.mean() 128 | loss = loss.mean() 129 | 130 | return { 131 | "actor_loss": loss * self.loss_coef, 132 | } 133 | -------------------------------------------------------------------------------- /point_policy/agent/networks/rgb_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from: https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/lifelong/models/modules/rgb_modules.py 3 | 4 | This file contains all neural modules related to encoding the spatial 5 | information of obs_t, i.e., the abstracted knowledge of the current visual 6 | input conditioned on the language. 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision 12 | 13 | import utils 14 | 15 | 16 | ############################################################################### 17 | # 18 | # Modules related to encoding visual information (can conditioned on language) 19 | # 20 | ############################################################################### 21 | 22 | 23 | class SpatialSoftmax(nn.Module): 24 | """ 25 | The spatial softmax layer (https://rll.berkeley.edu/dsae/dsae.pdf) 26 | """ 27 | 28 | def __init__(self, in_c, in_h, in_w, num_kp=None): 29 | super().__init__() 30 | self._spatial_conv = nn.Conv2d(in_c, num_kp, kernel_size=1) 31 | 32 | pos_x, pos_y = torch.meshgrid( 33 | torch.linspace(-1, 1, in_w).float(), 34 | torch.linspace(-1, 1, in_h).float(), 35 | ) 36 | 37 | pos_x = pos_x.reshape(1, in_w * in_h) 38 | pos_y = pos_y.reshape(1, in_w * in_h) 39 | self.register_buffer("pos_x", pos_x) 40 | self.register_buffer("pos_y", pos_y) 41 | 42 | if num_kp is None: 43 | self._num_kp = in_c 44 | else: 45 | self._num_kp = num_kp 46 | 47 | self._in_c = in_c 48 | self._in_w = in_w 49 | self._in_h = in_h 50 | 51 | def forward(self, x): 52 | assert x.shape[1] == self._in_c 53 | assert x.shape[2] == self._in_h 54 | assert x.shape[3] == self._in_w 55 | 56 | h = x 57 | if self._num_kp != self._in_c: 58 | h = self._spatial_conv(h) 59 | h = h.contiguous().view(-1, self._in_h * self._in_w) 60 | 61 | attention = F.softmax(h, dim=-1) 62 | keypoint_x = ( 63 | (self.pos_x * attention).sum(1, keepdims=True).view(-1, self._num_kp) 64 | ) 65 | keypoint_y = ( 66 | (self.pos_y * attention).sum(1, keepdims=True).view(-1, self._num_kp) 67 | ) 68 | keypoints = torch.cat([keypoint_x, keypoint_y], dim=1) 69 | return keypoints 70 | 71 | 72 | class SpatialProjection(nn.Module): 73 | def __init__(self, input_shape, 
out_dim): 74 | super().__init__() 75 | 76 | assert ( 77 | len(input_shape) == 3 78 | ), "[error] spatial projection: input shape is not a 3-tuple" 79 | in_c, in_h, in_w = input_shape 80 | num_kp = out_dim // 2 81 | self.out_dim = out_dim 82 | self.spatial_softmax = SpatialSoftmax(in_c, in_h, in_w, num_kp=num_kp) 83 | self.projection = nn.Linear(num_kp * 2, out_dim) 84 | 85 | def forward(self, x): 86 | out = self.spatial_softmax(x) 87 | out = self.projection(out) 88 | return out 89 | 90 | def output_shape(self, input_shape): 91 | return input_shape[:-3] + (self.out_dim,) 92 | 93 | 94 | class ResnetEncoder(nn.Module): 95 | """ 96 | A Resnet-18-based encoder for mapping an image to a latent vector 97 | 98 | Encode (f) an image into a latent vector. 99 | 100 | y = f(x), where 101 | x: (B, C, H, W) 102 | y: (B, H_out) 103 | 104 | Args: 105 | input_shape: (C, H, W), the shape of the image 106 | output_size: H_out, the latent vector size 107 | pretrained: whether use pretrained resnet 108 | freeze: whether freeze the pretrained resnet 109 | remove_layer_num: remove the top # layers 110 | no_stride: do not use striding 111 | """ 112 | 113 | def __init__( 114 | self, 115 | input_shape, 116 | output_size, 117 | pretrained=False, 118 | freeze=False, 119 | remove_layer_num=2, 120 | no_stride=False, 121 | language_dim=768, 122 | language_fusion="film", 123 | ): 124 | super().__init__() 125 | 126 | ### 1. encode input (images) using convolutional layers 127 | assert remove_layer_num <= 5, "[error] please only remove <=5 layers" 128 | layers = list(torchvision.models.resnet18(pretrained=pretrained).children())[ 129 | :-remove_layer_num 130 | ] 131 | self.remove_layer_num = remove_layer_num 132 | 133 | assert ( 134 | len(input_shape) == 3 135 | ), "[error] input shape of resnet should be (C, H, W)" 136 | 137 | in_channels = input_shape[0] 138 | if in_channels != 3: # has eye_in_hand, increase channel size 139 | conv0 = nn.Conv2d( 140 | in_channels=in_channels, 141 | out_channels=64, 142 | kernel_size=(7, 7), 143 | stride=(2, 2), 144 | padding=(3, 3), 145 | bias=False, 146 | ) 147 | layers[0] = conv0 148 | 149 | self.no_stride = no_stride 150 | if self.no_stride: 151 | layers[0].stride = (1, 1) 152 | layers[3].stride = 1 153 | 154 | self.resnet18_base = nn.Sequential(*layers[:4]) 155 | self.block_1 = layers[4][0] 156 | self.block_2 = layers[4][1] 157 | self.block_3 = layers[5][0] 158 | self.block_4 = layers[5][1] 159 | 160 | self.language_fusion = language_fusion 161 | if language_fusion != "none": 162 | self.lang_proj1 = nn.Linear(language_dim, 64 * 2) 163 | self.lang_proj2 = nn.Linear(language_dim, 64 * 2) 164 | self.lang_proj3 = nn.Linear(language_dim, 128 * 2) 165 | self.lang_proj4 = nn.Linear(language_dim, 128 * 2) 166 | 167 | if freeze: 168 | if in_channels != 3: 169 | raise Exception( 170 | "[error] cannot freeze pretrained " 171 | + "resnet with the extra eye_in_hand input" 172 | ) 173 | for param in self.resnet18_embeddings.parameters(): 174 | param.requires_grad = False 175 | 176 | ### 2. 
project the encoded input to a latent space
177 |         x = torch.zeros(1, *input_shape)
178 |         y = self.block_4(
179 |             self.block_3(self.block_2(self.block_1(self.resnet18_base(x))))
180 |         )
181 |         output_shape = y.shape  # compute the out dim
182 |         self.projection_layer = SpatialProjection(output_shape[1:], output_size)
183 |         self.output_shape = self.projection_layer(y).shape
184 | 
185 |         # Replace BatchNorm layers with GroupNorm
186 |         self.resnet18_base = utils.batch_norm_to_group_norm(self.resnet18_base)
187 |         self.block_1 = utils.batch_norm_to_group_norm(self.block_1)
188 |         self.block_2 = utils.batch_norm_to_group_norm(self.block_2)
189 |         self.block_3 = utils.batch_norm_to_group_norm(self.block_3)
190 |         self.block_4 = utils.batch_norm_to_group_norm(self.block_4)
191 | 
192 |         # self.normlayer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
193 | 
194 |     def forward(self, x, lang=None, return_intermediate=False):
195 |         # # preprocess
196 |         # preprocess = nn.Sequential(self.normlayer)
197 |         # x = preprocess(x)
198 | 
199 |         h = self.resnet18_base(x)
200 | 
201 |         h = self.block_1(h)
202 |         if lang is not None and self.language_fusion != "none":  # FiLM layer
203 |             B, C, H, W = h.shape
204 |             beta, gamma = torch.split(
205 |                 self.lang_proj1(lang).reshape(B, C * 2, 1, 1), [C, C], 1
206 |             )
207 |             h = (1 + gamma) * h + beta
208 | 
209 |         h = self.block_2(h)
210 |         if lang is not None and self.language_fusion != "none":  # FiLM layer
211 |             B, C, H, W = h.shape
212 |             beta, gamma = torch.split(
213 |                 self.lang_proj2(lang).reshape(B, C * 2, 1, 1), [C, C], 1
214 |             )
215 |             h = (1 + gamma) * h + beta
216 | 
217 |         h = self.block_3(h)
218 |         if lang is not None and self.language_fusion != "none":  # FiLM layer
219 |             B, C, H, W = h.shape
220 |             beta, gamma = torch.split(
221 |                 self.lang_proj3(lang).reshape(B, C * 2, 1, 1), [C, C], 1
222 |             )
223 |             h = (1 + gamma) * h + beta
224 | 
225 |         h = self.block_4(h)
226 |         if lang is not None and self.language_fusion != "none":  # FiLM layer
227 |             B, C, H, W = h.shape
228 |             beta, gamma = torch.split(
229 |                 self.lang_proj4(lang).reshape(B, C * 2, 1, 1), [C, C], 1
230 |             )
231 |             h = (1 + gamma) * h + beta
232 | 
233 |         if not return_intermediate:
234 |             h = self.projection_layer(h)
235 |         return h
236 | 
237 |     def output_shape(self):
238 |         return self.output_shape
239 | 
--------------------------------------------------------------------------------
/point_policy/agent/p3po.py:
--------------------------------------------------------------------------------
1 | """
2 | Model that divides image into patches and performs cross attention with patches around points
3 | to predict future point tracks.
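
Concretely, the BCAgent below flattens each frame's point tracks, projects them with an MLP to a
512-dimensional embedding, and feeds the resulting sequence to a causal GPT trunk; a deterministic
or diffusion head then predicts the (optionally chunked) action. A minimal usage sketch of the
Actor (shapes and argument values are illustrative assumptions, not the values used in training):

    actor = Actor(repr_dim=512, act_dim=9, history_len=10, hidden_dim=256)
    tracks = torch.randn(1, 10, 512)   # (batch, history, repr_dim) after the point projector
    dist = actor(tracks, stddev=0.1)   # utils.Normal from the deterministic head
    action = dist.mean                 # (batch, history, act_dim)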
4 | """ 5 | 6 | import einops 7 | import numpy as np 8 | from collections import deque 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import utils 14 | from agent.networks.policy_head import ( 15 | DeterministicHead, 16 | DiffusionHead, 17 | ) 18 | 19 | from agent.networks.mlp import MLP 20 | from agent.networks.gpt import GPT, GPTConfig 21 | 22 | 23 | class Actor(nn.Module): 24 | def __init__( 25 | self, 26 | repr_dim, 27 | act_dim, 28 | history_len, 29 | hidden_dim, 30 | policy_head="deterministic", 31 | device="cuda", 32 | ): 33 | super().__init__() 34 | 35 | self._policy_head = policy_head 36 | self._repr_dim = repr_dim 37 | self._act_dim = act_dim 38 | 39 | self._policy = GPT( 40 | GPTConfig( 41 | block_size=20, 42 | input_dim=repr_dim, 43 | output_dim=hidden_dim, 44 | n_layer=4, 45 | n_head=2, 46 | n_embd=hidden_dim, 47 | dropout=0.1, 48 | causal=True, 49 | ) 50 | ) 51 | 52 | if policy_head == "deterministic": 53 | self._action_head = DeterministicHead( 54 | hidden_dim, self._act_dim, hidden_size=hidden_dim, num_layers=2 55 | ) 56 | elif policy_head == "diffusion": 57 | obs_horizon = history_len 58 | pred_horizon = history_len 59 | self._action_head = DiffusionHead( 60 | input_size=hidden_dim, 61 | output_size=self._act_dim, 62 | obs_horizon=obs_horizon, 63 | pred_horizon=pred_horizon, 64 | hidden_size=hidden_dim, 65 | num_layers=2, 66 | device=device, 67 | ) 68 | 69 | self.apply(utils.weight_init) 70 | 71 | def forward( 72 | self, 73 | past_tracks, 74 | stddev, 75 | action=None, 76 | ): 77 | features = self._policy(past_tracks) 78 | 79 | pred_action = self._action_head( 80 | features, 81 | stddev, 82 | **{ 83 | "action_seq": action if action is not None else None, 84 | }, 85 | ) 86 | 87 | if action is None: 88 | return pred_action 89 | else: 90 | mask = torch.ones(action.shape[0], action.shape[1]).to(action.device) 91 | loss = self._action_head.loss_fn( 92 | pred_action, 93 | action, 94 | mask, 95 | reduction="mean", 96 | ) 97 | return pred_action, loss[0] if isinstance(loss, tuple) else loss 98 | 99 | 100 | class BCAgent: 101 | def __init__( 102 | self, 103 | obs_shape, 104 | action_shape, 105 | device, 106 | lr, 107 | hidden_dim, 108 | stddev_schedule, 109 | use_tb, 110 | policy_head, 111 | pixel_keys, 112 | history, 113 | history_len, 114 | eval_history_len, 115 | temporal_agg, 116 | max_episode_len, 117 | num_queries, 118 | use_robot_points, 119 | num_robot_points, 120 | use_object_points, 121 | num_object_points, 122 | point_dim, 123 | ): 124 | assert point_dim in [2, 3], "Only 2D or 3D points are supported" 125 | 126 | self.device = device 127 | self.lr = lr 128 | self.hidden_dim = hidden_dim 129 | self.stddev_schedule = stddev_schedule 130 | self.use_tb = use_tb 131 | self.policy_head = policy_head 132 | self.history_len = history_len if history else 1 133 | self.eval_history_len = eval_history_len if history else 1 134 | 135 | self._use_robot_points = use_robot_points 136 | self._num_robot_points = num_robot_points 137 | self._use_object_points = use_object_points 138 | self._num_object_points = num_object_points 139 | self.num_track_points = (num_robot_points if use_robot_points else 0) + ( 140 | num_object_points if use_object_points else 0 141 | ) 142 | 143 | # keys 144 | self.pixel_keys = pixel_keys 145 | 146 | # action chunking params 147 | self.temporal_agg = temporal_agg 148 | self.max_episode_len = max_episode_len 149 | self.num_queries = num_queries if self.temporal_agg else 1 150 | 151 | # observation params 152 | self._obs_dim = point_dim * 
self.num_track_points 153 | self.repr_dim = 512 154 | obs_shape = obs_shape[self.pixel_keys[0]] 155 | 156 | # actor parameters 157 | # Action dim was 6 with 3D position and 3D rotation. 158 | # Now it is 9D with 3D position and 6D rotation. 159 | self._act_dim = action_shape[0] + 3 160 | 161 | # Track model size 162 | model_size = 0 163 | 164 | # projector for points and patches 165 | self.point_projector = MLP(self._obs_dim, hidden_channels=[self.repr_dim]).to( 166 | device 167 | ) 168 | self.point_projector.apply(utils.weight_init) 169 | model_size += sum( 170 | p.numel() for p in self.point_projector.parameters() if p.requires_grad 171 | ) 172 | 173 | # actor 174 | action_dim = ( 175 | self._act_dim * self.num_queries if self.temporal_agg else self._act_dim 176 | ) 177 | self.actor = Actor( 178 | self.repr_dim, 179 | action_dim, 180 | self.history_len, 181 | hidden_dim, 182 | self.policy_head, 183 | device, 184 | ).to(device) 185 | model_size += sum(p.numel() for p in self.actor.parameters() if p.requires_grad) 186 | 187 | # optimizers 188 | # point projector 189 | params = list(self.point_projector.parameters()) 190 | self.point_opt = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4) 191 | # actor 192 | self.actor_opt = torch.optim.AdamW( 193 | self.actor.parameters(), lr=lr, weight_decay=1e-4 194 | ) 195 | 196 | self.train() 197 | self.buffer_reset() 198 | 199 | def __repr__(self): 200 | return "bc" 201 | 202 | def train(self, training=True): 203 | self.training = training 204 | if training: 205 | self.point_projector.train(training) 206 | self.actor.train(training) 207 | else: 208 | self.point_projector.eval() 209 | self.actor.eval() 210 | 211 | def buffer_reset(self): 212 | self.observation_buffer = {} 213 | for key in self.pixel_keys: 214 | self.observation_buffer[f"past_tracks_{key}"] = deque( 215 | maxlen=self.eval_history_len 216 | ) # since point track history concatenated 217 | if self.temporal_agg: 218 | self.all_time_actions = torch.zeros( 219 | [ 220 | self.max_episode_len, 221 | self.max_episode_len + self.num_queries, 222 | self._act_dim, 223 | ] 224 | ).to(self.device) 225 | 226 | def clear_buffers(self): 227 | del self.observation_buffer 228 | if self.temporal_agg: 229 | del self.all_time_actions 230 | 231 | def act(self, obs, norm_stats, step, global_step, eval_mode=False, **kwargs): 232 | if norm_stats is not None: 233 | preprocess = { 234 | "past_tracks": lambda x: (x - norm_stats["past_tracks"]["min"]) 235 | / ( 236 | norm_stats["past_tracks"]["max"] 237 | - norm_stats["past_tracks"]["min"] 238 | + 1e-5 239 | ), 240 | } 241 | post_process = { 242 | "actions": lambda a: a 243 | * (norm_stats["actions"]["max"] - norm_stats["actions"]["min"]) 244 | + norm_stats["actions"]["min"], 245 | } 246 | 247 | past_tracks = [] 248 | for key in self.pixel_keys: 249 | point_tracks = preprocess["past_tracks"](obs[f"point_tracks_{key}"]) 250 | self.observation_buffer[f"past_tracks_{key}"].append(point_tracks) 251 | while len(self.observation_buffer[f"past_tracks_{key}"]) < self.history_len: 252 | self.observation_buffer[f"past_tracks_{key}"].append(point_tracks) 253 | past_tracks.append( 254 | np.stack(self.observation_buffer[f"past_tracks_{key}"], axis=0) 255 | ) 256 | 257 | # convert to tensor 258 | past_tracks = torch.as_tensor(np.array(past_tracks), device=self.device).float() 259 | 260 | # reshape past_tracks 261 | past_tracks = einops.rearrange(past_tracks, "n t p d-> n t (p d)") 262 | 263 | # encode past tracks 264 | past_tracks = self.point_projector(past_tracks) 265 | 266 
| stddev = 0.1 267 | action = self.actor(past_tracks, stddev) 268 | 269 | if self.policy_head == "deterministic": 270 | action = action.mean 271 | elif self.policy_head == "diffusion": 272 | action = action[0] 273 | 274 | if self.temporal_agg: 275 | action = action.view(-1, self.num_queries, self._act_dim) 276 | self.all_time_actions[[step], step : step + self.num_queries] = action[-1:] 277 | actions_for_curr_step = self.all_time_actions[:, step] 278 | actions_populated = torch.all(actions_for_curr_step != 0, axis=1) 279 | actions_for_curr_step = actions_for_curr_step[actions_populated] 280 | k = 0.01 281 | exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) 282 | exp_weights = exp_weights / exp_weights.sum() 283 | exp_weights = torch.from_numpy(exp_weights).to(self.device).unsqueeze(dim=1) 284 | action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) 285 | if norm_stats is not None: 286 | return post_process["actions"](action.cpu().numpy()[0]) 287 | return action.cpu().numpy()[0] 288 | else: 289 | if norm_stats is not None: 290 | return post_process["actions"](action.cpu().numpy()[0, -1]) 291 | return action.cpu().numpy()[0, -1, :] 292 | 293 | def update(self, expert_replay_iter, step, **kwargs): 294 | metrics = dict() 295 | 296 | batch = next(expert_replay_iter) 297 | data = utils.to_torch(batch, self.device) 298 | 299 | past_tracks = data["past_tracks"].float() 300 | action = data["actions"].float() 301 | 302 | # reshape for training 303 | past_tracks = einops.rearrange(past_tracks, "n t p d-> n t (p d)") 304 | 305 | # encode past tracks 306 | past_tracks = self.point_projector(past_tracks) 307 | 308 | # rearrange action 309 | if self.temporal_agg: 310 | action = einops.rearrange(action, "b t1 t2 d -> b t1 (t2 d)") 311 | 312 | # actor loss 313 | stddev = utils.schedule(self.stddev_schedule, step) 314 | pred_action, actor_loss = self.actor(past_tracks, stddev, action, **kwargs) 315 | 316 | # optimize 317 | self.point_opt.zero_grad(set_to_none=True) 318 | self.actor_opt.zero_grad(set_to_none=True) 319 | actor_loss["actor_loss"].backward() 320 | self.point_opt.step() 321 | self.actor_opt.step() 322 | 323 | if self.policy_head == "diffusion" and step % 10 == 0: 324 | self.actor._action_head.net.ema_step() 325 | 326 | if self.use_tb: 327 | for key, value in actor_loss.items(): 328 | metrics[key] = value.item() 329 | 330 | return metrics 331 | 332 | def save_snapshot(self): 333 | model_keys = ["actor", "point_projector"] 334 | opt_keys = ["actor_opt", "point_opt"] 335 | # models 336 | payload = {k: self.__dict__[k].state_dict() for k in model_keys} 337 | # optimizers 338 | payload.update({k: self.__dict__[k] for k in opt_keys}) 339 | 340 | others = ["max_episode_len"] 341 | payload.update({k: self.__dict__[k] for k in others}) 342 | return payload 343 | 344 | def load_snapshot(self, payload, eval=False): 345 | # models 346 | model_keys = ["actor", "point_projector"] 347 | for k in model_keys: 348 | self.__dict__[k].load_state_dict(payload[k]) 349 | 350 | if eval: 351 | self.train(False) 352 | return 353 | -------------------------------------------------------------------------------- /point_policy/cfgs/agent/baku.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.baku.BCAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? 
# to be specified later 5 | device: ${device} 6 | lr: 1e-4 7 | hidden_dim: ${suite.hidden_dim} 8 | stddev_schedule: 0.1 9 | use_tb: ${use_tb} 10 | policy_head: ${policy_head} 11 | pixel_keys: ${suite.pixel_keys} 12 | proprio_key: ${suite.proprio_key} 13 | use_proprio: ${use_proprio} 14 | history: ${suite.history} 15 | history_len: ${suite.history_len} 16 | eval_history_len: ${suite.eval_history_len} 17 | temporal_agg: ${temporal_agg} 18 | max_episode_len: ${suite.task_make_fn.max_episode_len} 19 | num_queries: ${num_queries} 20 | use_depth: ${suite.gt_depth} 21 | -------------------------------------------------------------------------------- /point_policy/cfgs/agent/mtpi.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.mtpi.BCAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? # to be specified later 5 | device: ${device} 6 | lr: 1e-4 7 | hidden_dim: ${suite.hidden_dim} 8 | stddev_schedule: 0.1 9 | use_tb: ${use_tb} 10 | policy_head: ${policy_head} 11 | pixel_keys: ${suite.pixel_keys} 12 | history: ${suite.history} 13 | history_len: ${suite.history_len} 14 | eval_history_len: ${suite.eval_history_len} 15 | temporal_agg: ${temporal_agg} 16 | max_episode_len: ${suite.task_make_fn.max_episode_len} 17 | num_queries: ${num_queries} 18 | num_robot_points: ${suite.num_robot_points} 19 | use_object_points: ${suite.use_object_points} 20 | num_object_points: ${suite.num_object_points} 21 | pred_gripper: true 22 | -------------------------------------------------------------------------------- /point_policy/cfgs/agent/p3po.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.p3po.BCAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? # to be specified later 5 | device: ${device} 6 | lr: 1e-4 7 | hidden_dim: ${suite.hidden_dim} 8 | stddev_schedule: 0.1 9 | use_tb: ${use_tb} 10 | policy_head: ${policy_head} 11 | pixel_keys: ${suite.pixel_keys} 12 | history: ${suite.history} 13 | history_len: ${suite.history_len} 14 | eval_history_len: ${suite.eval_history_len} 15 | temporal_agg: ${temporal_agg} 16 | max_episode_len: ${suite.task_make_fn.max_episode_len} 17 | num_queries: ${num_queries} 18 | use_robot_points: ${suite.use_robot_points} 19 | num_robot_points: ${suite.num_robot_points} 20 | use_object_points: ${suite.use_object_points} 21 | num_object_points: ${suite.num_object_points} 22 | point_dim: ${suite.point_dim} 23 | -------------------------------------------------------------------------------- /point_policy/cfgs/agent/point_policy.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.point_policy.BCAgent 3 | obs_shape: ??? # to be specified later 4 | action_shape: ??? 
# to be specified later 5 | device: ${device} 6 | lr: 1e-4 7 | hidden_dim: ${suite.hidden_dim} 8 | stddev_schedule: 0.1 9 | use_tb: ${use_tb} 10 | policy_head: ${policy_head} 11 | pixel_keys: ${suite.pixel_keys} 12 | history: ${suite.history} 13 | history_len: ${suite.history_len} 14 | eval_history_len: ${suite.eval_history_len} 15 | temporal_agg: ${temporal_agg} 16 | max_episode_len: ${suite.task_make_fn.max_episode_len} 17 | num_queries: ${num_queries} 18 | use_robot_points: ${suite.use_robot_points} 19 | num_robot_points: ${suite.num_robot_points} 20 | use_object_points: ${suite.use_object_points} 21 | num_object_points: ${suite.num_object_points} 22 | point_dim: ${suite.point_dim} 23 | pred_gripper: true 24 | -------------------------------------------------------------------------------- /point_policy/cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - agent: point_policy 4 | - suite: point_policy 5 | - dataloader: point_policy 6 | - override hydra/launcher: submitit_local 7 | 8 | # Dir 9 | root_dir: /path/to/dir 10 | data_dir: /path/to/expert_demos 11 | 12 | # misc 13 | seed: 2 14 | device: cuda 15 | save_video: true 16 | use_tb: true 17 | batch_size: 64 18 | 19 | # experiment 20 | num_demos_per_task: 100 21 | policy_head: deterministic 22 | use_proprio: false 23 | eval: false 24 | experiment: train 25 | experiment_label: ${policy_head} 26 | 27 | # action chunking 28 | temporal_agg: true # aggregate actions over time 29 | num_queries: 20 30 | 31 | # expert dataset 32 | expert_dataset: ${dataloader.bc_dataset} 33 | 34 | # Load weights 35 | load_bc: false 36 | bc_weight: path/to/weight 37 | 38 | hydra: 39 | run: 40 | dir: ./exp_local/${now:%Y.%m.%d}/${experiment}/${experiment_label}/${now:%H%M%S}_hidden_dim_${suite.hidden_dim} 41 | sweep: 42 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 43 | subdir: ${hydra.job.num} 44 | launcher: 45 | tasks_per_node: 1 46 | nodes: 1 47 | submitit_folder: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}/.slurm 48 | -------------------------------------------------------------------------------- /point_policy/cfgs/config_eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - agent: point_policy 4 | - suite: point_policy 5 | - dataloader: point_policy 6 | - override hydra/launcher: submitit_local 7 | 8 | # Dir 9 | root_dir: /path/to/dir 10 | data_dir: /path/to/expert_demos 11 | 12 | # misc 13 | seed: 2 14 | device: cuda 15 | save_video: true 16 | use_tb: true 17 | batch_size: 64 18 | 19 | # experiment 20 | num_demos_per_task: 100 21 | policy_head: deterministic 22 | use_proprio: false 23 | eval: true 24 | experiment: eval 25 | experiment_label: ${policy_head} 26 | 27 | # action chunking 28 | temporal_agg: true # aggregate actions over time 29 | num_queries: 20 30 | 31 | # expert dataset 32 | expert_dataset: ${dataloader.bc_dataset} 33 | 34 | # Load weights 35 | bc_weight: path/to/weight 36 | 37 | hydra: 38 | run: 39 | dir: ./exp_local/eval/${now:%Y.%m.%d}_${experiment}/${experiment_label}/${now:%H%M%S}_hidden_dim_${suite.hidden_dim} 40 | sweep: 41 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S} 42 | subdir: ${hydra.job.num} 43 | launcher: 44 | tasks_per_node: 1 45 | nodes: 1 46 | submitit_folder: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}/.slurm 47 | -------------------------------------------------------------------------------- /point_policy/cfgs/dataloader/baku.yaml: 
-------------------------------------------------------------------------------- 1 | bc_dataset: 2 | _target_: read_data.baku.BCDataset 3 | path: ${data_dir}/franka_env 4 | tasks: ${suite.task.task_name} 5 | num_demos_per_task: ${num_demos_per_task} 6 | history: ${suite.history} 7 | history_len: ${suite.history_len} 8 | temporal_agg: ${temporal_agg} 9 | num_queries: ${num_queries} 10 | img_size: ${suite.img_size} 11 | action_after_steps: ${suite.action_after_steps} 12 | pixel_keys: ${suite.pixel_keys} 13 | subsample: 3 14 | skip_first_n: 0 15 | action_type: ${suite.action_type} 16 | gt_depth: ${suite.gt_depth} 17 | -------------------------------------------------------------------------------- /point_policy/cfgs/dataloader/mtpi.yaml: -------------------------------------------------------------------------------- 1 | bc_dataset: 2 | _target_: read_data.mtpi.BCDataset 3 | path: ${data_dir}/franka_env 4 | tasks: ${suite.task.task_name} 5 | num_demos_per_task: ${num_demos_per_task} 6 | history: ${suite.history} 7 | history_len: ${suite.history_len} 8 | temporal_agg: ${temporal_agg} 9 | num_queries: ${num_queries} 10 | img_size: ${suite.img_size} 11 | action_after_steps: ${suite.action_after_steps} 12 | use_robot_points: ${suite.use_robot_points} 13 | num_robot_points: ${suite.num_robot_points} 14 | use_object_points: ${suite.use_object_points} 15 | num_object_points: ${suite.num_object_points} 16 | point_dim: ${suite.point_dim} 17 | pixel_keys: ${suite.pixel_keys} 18 | subsample: 3 19 | skip_first_n: 0 20 | -------------------------------------------------------------------------------- /point_policy/cfgs/dataloader/p3po.yaml: -------------------------------------------------------------------------------- 1 | bc_dataset: 2 | _target_: read_data.p3po.BCDataset 3 | path: ${data_dir}/franka_env 4 | tasks: ${suite.task.task_name} 5 | num_demos_per_task: ${num_demos_per_task} 6 | history: ${suite.history} 7 | history_len: ${suite.history_len} 8 | temporal_agg: ${temporal_agg} 9 | num_queries: ${num_queries} 10 | img_size: ${suite.img_size} 11 | action_after_steps: ${suite.action_after_steps} 12 | use_robot_points: ${suite.use_robot_points} 13 | num_robot_points: ${suite.num_robot_points} 14 | use_object_points: ${suite.use_object_points} 15 | num_object_points: ${suite.num_object_points} 16 | point_dim: ${suite.point_dim} 17 | pixel_keys: ${suite.pixel_keys} 18 | action_type: ${suite.action_type} 19 | subsample: 3 20 | skip_first_n: 0 21 | gt_depth: ${suite.gt_depth} 22 | -------------------------------------------------------------------------------- /point_policy/cfgs/dataloader/point_policy.yaml: -------------------------------------------------------------------------------- 1 | bc_dataset: 2 | _target_: read_data.point_policy.BCDataset 3 | path: ${data_dir}/franka_env 4 | tasks: ${suite.task.task_name} 5 | num_demos_per_task: ${num_demos_per_task} 6 | history: ${suite.history} 7 | history_len: ${suite.history_len} 8 | temporal_agg: ${temporal_agg} 9 | num_queries: ${num_queries} 10 | img_size: ${suite.img_size} 11 | action_after_steps: ${suite.action_after_steps} 12 | use_robot_points: ${suite.use_robot_points} 13 | num_robot_points: ${suite.num_robot_points} 14 | use_object_points: ${suite.use_object_points} 15 | num_object_points: ${suite.num_object_points} 16 | point_dim: ${suite.point_dim} 17 | pixel_keys: ${suite.pixel_keys} 18 | subsample: 3 19 | skip_first_n: 0 20 | gt_depth: ${suite.gt_depth} 21 | -------------------------------------------------------------------------------- 
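The dataloader configs above are composed with the agent configs above and the suite configs below
by Hydra when train.py or eval.py is launched. A minimal sketch of inspecting the composed
configuration from the point_policy/ directory; the overrides and placeholder paths are
illustrative, and obs_shape/action_shape stay "???" until the training script fills them in from
the environment:

    from hydra import compose, initialize
    from omegaconf import OmegaConf

    with initialize(config_path="cfgs"):
        cfg = compose(
            config_name="config",
            overrides=[
                "agent=p3po",
                "suite=p3po",
                "dataloader=p3po",
                "suite/task/franka_env=bottle_on_rack",
                "root_dir=/path/to/repo",          # placeholder
                "data_dir=/path/to/expert_demos",  # placeholder
            ],
        )
    # interpolations such as ${suite.hidden_dim} are kept unresolved for display
    print(OmegaConf.to_yaml(cfg.agent, resolve=False))

Selecting the task group explicitly (suite/task/franka_env=bottle_on_rack) also avoids relying on
the dated default (0121_bottle_on_rack) referenced in cfgs/suite/task/franka_env.yaml, which does
not appear among the task files in this tree.
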
/point_policy/cfgs/suite/baku.yaml: -------------------------------------------------------------------------------- 1 | # @package suite 2 | defaults: 3 | - _self_ 4 | - task: franka_env 5 | 6 | suite: franka_env 7 | name: "franka_env" 8 | 9 | # obs dims 10 | img_size: 256 11 | gt_depth: false 12 | 13 | # action compute 14 | action_type: "absolute" # absolute, delta 15 | 16 | # task settings 17 | action_repeat: 1 18 | hidden_dim: 256 19 | 20 | # train settings 21 | num_train_steps: 100010 22 | log_every_steps: 100 23 | save_every_steps: 10000 24 | history: false 25 | history_len: 10 26 | 27 | # eval 28 | eval_every_steps: 200000 29 | num_eval_episodes: 5 30 | eval_history_len: 10 31 | 32 | # data loading 33 | action_after_steps: 1 34 | 35 | # obs_keys 36 | pixel_keys: ["pixels1", "pixels2"] 37 | proprio_key: "proprioceptive" 38 | feature_key: "features" 39 | 40 | # snapshot 41 | save_snapshot: true 42 | 43 | task_make_fn: 44 | _target_: suite.baku.make 45 | action_repeat: ${suite.action_repeat} 46 | seed: ${seed} 47 | height: ${suite.img_size} 48 | width: ${suite.img_size} 49 | max_episode_len: ??? # to be specified later 50 | max_state_dim: ??? # to be specified later 51 | pixel_keys: ${suite.pixel_keys} 52 | eval: ${eval} # eval true mean use robot 53 | action_type: ${suite.action_type} 54 | use_gt_depth: ${suite.gt_depth} 55 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/mtpi.yaml: -------------------------------------------------------------------------------- 1 | # @package suite 2 | defaults: 3 | - _self_ 4 | - task: franka_env 5 | 6 | suite: franka_env 7 | name: "franka_env" 8 | 9 | # obs dims 10 | img_size: [256, 256] # (width, height) 11 | calib_image_size: [640, 480] # (width, height) 12 | use_robot_points: true 13 | num_robot_points: 9 14 | 15 | # action compute 16 | point_dim: 2 # 2 or 3 17 | 18 | # object points 19 | use_object_points: true 20 | num_object_points: ${suite.task.num_object_points} 21 | 22 | # task settings 23 | action_repeat: 1 24 | hidden_dim: 256 25 | 26 | # train settings 27 | num_train_steps: 100100 28 | log_every_steps: 100 29 | save_every_steps: 10000 30 | history: true 31 | history_len: 10 32 | 33 | # eval 34 | eval_every_steps: 200000 35 | num_eval_episodes: 5 36 | eval_history_len: 10 37 | 38 | # data loading 39 | action_after_steps: 1 40 | 41 | # obs_keys 42 | pixel_keys: ["pixels1", "pixels2"] 43 | proprio_key: "proprioceptive" 44 | feature_key: "features" 45 | 46 | # snapshot 47 | save_snapshot: true 48 | 49 | task_make_fn: 50 | _target_: suite.mtpi.make 51 | task_name: ${suite.task.task_name} 52 | object_labels: ${suite.task.object_labels} 53 | action_repeat: ${suite.action_repeat} 54 | height: ${suite.img_size[1]} 55 | width: ${suite.img_size[0]} 56 | calib_height: ${suite.calib_image_size[1]} 57 | calib_width: ${suite.calib_image_size[0]} 58 | max_episode_len: ??? # to be specified later 59 | max_state_dim: ??? 
# to be specified later 60 | calib_path: ${root_dir}/calib/calib.npy 61 | eval: ${eval} # eval true mean use robot 62 | pixel_keys: ${suite.pixel_keys} 63 | use_robot_points: ${suite.use_robot_points} 64 | num_robot_points: ${suite.num_robot_points} 65 | use_object_points: ${suite.use_object_points} 66 | num_object_points: ${suite.num_object_points} 67 | points_cfg: null 68 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/p3po.yaml: -------------------------------------------------------------------------------- 1 | # @package suite 2 | defaults: 3 | - _self_ 4 | - task: franka_env 5 | 6 | suite: franka_env 7 | name: "franka_env" 8 | 9 | # obs dims 10 | img_size: [640, 480] # (width, height) 11 | use_robot_points: true 12 | num_robot_points: 9 13 | gt_depth: false 14 | 15 | # action compute 16 | point_dim: 3 # 2 or 3 17 | action_type: "absolute" # absolute, delta 18 | 19 | # object points 20 | use_object_points: true 21 | num_object_points: ${suite.task.num_object_points} 22 | 23 | # task settings 24 | action_repeat: 1 25 | hidden_dim: 256 26 | 27 | # train settings 28 | num_train_steps: 100100 29 | log_every_steps: 100 30 | save_every_steps: 10000 31 | history: true 32 | history_len: 10 33 | 34 | # eval 35 | eval_every_steps: 200000 36 | num_eval_episodes: 5 37 | eval_history_len: 10 38 | 39 | # data loading 40 | action_after_steps: 1 41 | 42 | # obs_keys 43 | pixel_keys: ["pixels1", "pixels2"] 44 | proprio_key: "proprioceptive" 45 | feature_key: "features" 46 | 47 | # snapshot 48 | save_snapshot: true 49 | 50 | task_make_fn: 51 | _target_: suite.p3po.make 52 | task_name: ${suite.task.task_name} 53 | object_labels: ${suite.task.object_labels} 54 | action_repeat: ${suite.action_repeat} 55 | height: ${suite.img_size[1]} 56 | width: ${suite.img_size[0]} 57 | max_episode_len: ??? # to be specified later 58 | max_state_dim: ??? # to be specified later 59 | calib_path: ${root_dir}/calib/calib.npy 60 | eval: ${eval} # eval true mean use robot 61 | pixel_keys: ${suite.pixel_keys} 62 | use_robot_points: ${suite.use_robot_points} 63 | num_robot_points: ${suite.num_robot_points} 64 | use_object_points: ${suite.use_object_points} 65 | num_object_points: ${suite.num_object_points} 66 | action_type: ${suite.action_type} 67 | points_cfg: ??? 
# to be specified later 68 | use_gt_depth: ${suite.gt_depth} 69 | point_dim: ${suite.point_dim} 70 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/point_policy.yaml: -------------------------------------------------------------------------------- 1 | # @package suite 2 | defaults: 3 | - _self_ 4 | - task: franka_env 5 | 6 | suite: franka_env 7 | name: "franka_env" 8 | 9 | # obs dims 10 | img_size: [640, 480] # (width, height) 11 | use_robot_points: true 12 | num_robot_points: 9 13 | gt_depth: false 14 | 15 | # action compute 16 | point_dim: 3 # 2 or 3 17 | 18 | # object points 19 | use_object_points: true 20 | num_object_points: ${suite.task.num_object_points} 21 | 22 | # task settings 23 | action_repeat: 1 24 | hidden_dim: 256 25 | 26 | # train settings 27 | num_train_steps: 100100 28 | log_every_steps: 100 29 | save_every_steps: 10000 30 | history: true 31 | history_len: 10 32 | 33 | # eval 34 | eval_every_steps: 200000 35 | num_eval_episodes: 5 36 | eval_history_len: 10 37 | 38 | # data loading 39 | action_after_steps: 1 40 | 41 | # obs_keys 42 | pixel_keys: ["pixels1", "pixels2"] 43 | proprio_key: "proprioceptive" 44 | feature_key: "features" 45 | 46 | # snapshot 47 | save_snapshot: true 48 | 49 | task_make_fn: 50 | _target_: suite.point_policy.make 51 | task_name: ${suite.task.task_name} 52 | object_labels: ${suite.task.object_labels} 53 | action_repeat: ${suite.action_repeat} 54 | height: ${suite.img_size[1]} 55 | width: ${suite.img_size[0]} 56 | max_episode_len: ??? # to be specified later 57 | max_state_dim: ??? # to be specified later 58 | calib_path: ${root_dir}/calib/calib.npy 59 | eval: ${eval} # eval true mean use robot 60 | pixel_keys: ${suite.pixel_keys} 61 | use_robot_points: ${suite.use_robot_points} 62 | num_robot_points: ${suite.num_robot_points} 63 | use_object_points: ${suite.use_object_points} 64 | num_object_points: ${suite.num_object_points} 65 | points_cfg: ??? 
# to be specified later 66 | use_gt_depth: ${suite.gt_depth} 67 | point_dim: ${suite.point_dim} 68 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/points_cfg.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: p3po 4 | 5 | root_dir: /path/to/repo 6 | dift_path: dift 7 | cotracker_checkpoint: co-tracker/checkpoints/scaled_online.pth 8 | 9 | task_name: null 10 | pixel_keys: null 11 | device: "cuda" 12 | 13 | width: -1 14 | height: -1 15 | image_size_multiplier: 1 16 | 17 | ensemble_size: 8 18 | dift_layer: 1 19 | dift_steps: 50 20 | 21 | num_points: -1 22 | 23 | object_labels: null 24 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - franka_env: 0121_bottle_on_rack 4 | 5 | suite: franka_env 6 | 7 | task_name: ${suite.task.franka_env.task_name} 8 | object_labels: ${suite.task.franka_env.object_labels} 9 | num_object_points: ${suite.task.franka_env.num_object_points} 10 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/bottle_on_rack.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: bottle_on_rack 5 | object_labels: [objects] 6 | num_object_points: 7 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/bottle_upright.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: bottle_upright 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/bowl_in_oven.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: bowl_in_oven 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/bread_on_plate.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: bread_on_plate 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/close_oven.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: close_oven 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/drawer_close.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: drawer_close 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/fold_towel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: fold_towel 5 | object_labels: 
[objects] 6 | num_object_points: 6 7 | -------------------------------------------------------------------------------- /point_policy/cfgs/suite/task/franka_env/sweep_broom.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | task_name: sweep_broom 5 | object_labels: [objects] 6 | num_object_points: 8 7 | -------------------------------------------------------------------------------- /point_policy/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | import os 5 | 6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 7 | os.environ["MUJOCO_GL"] = "egl" 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | from pathlib import Path 10 | 11 | import hydra 12 | import torch 13 | import numpy as np 14 | 15 | import utils 16 | from logger import Logger 17 | from replay_buffer import make_expert_replay_loader 18 | from video import VideoRecorder 19 | 20 | warnings.filterwarnings("ignore", category=DeprecationWarning) 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def make_agent(obs_spec, action_spec, cfg): 25 | obs_shape = {} 26 | for key in cfg.suite.pixel_keys: 27 | obs_shape[key] = obs_spec[key].shape 28 | if cfg.use_proprio: 29 | obs_shape[cfg.suite.proprio_key] = obs_spec[cfg.suite.proprio_key].shape 30 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape 31 | cfg.agent.obs_shape = obs_shape 32 | cfg.agent.action_shape = action_spec.shape 33 | return hydra.utils.instantiate(cfg.agent) 34 | 35 | 36 | class WorkspaceIL: 37 | def __init__(self, cfg): 38 | self.work_dir = Path.cwd() 39 | print(f"workspace: {self.work_dir}") 40 | 41 | self.cfg = cfg 42 | utils.set_seed_everywhere(cfg.seed) 43 | self.device = torch.device(cfg.device) 44 | 45 | # load data 46 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset) 47 | self.expert_replay_loader = make_expert_replay_loader( 48 | dataset_iterable, self.cfg.batch_size 49 | ) 50 | self.expert_replay_iter = iter(self.expert_replay_loader) 51 | 52 | # create logger 53 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 54 | # create envs 55 | self.cfg.suite.task_make_fn.max_episode_len = 400 56 | self.cfg.suite.task_make_fn.max_state_dim = ( 57 | self.expert_replay_loader.dataset._max_state_dim 58 | ) 59 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn) 60 | 61 | # create agent 62 | self.agent = make_agent( 63 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg 64 | ) 65 | 66 | self.envs_till_idx = len(self.env) 67 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx 68 | self.expert_replay_iter = iter(self.expert_replay_loader) 69 | 70 | self.timer = utils.Timer() 71 | self._global_step = 0 72 | self._global_episode = 0 73 | 74 | self.video_recorder = VideoRecorder( 75 | self.work_dir if self.cfg.save_video else None 76 | ) 77 | 78 | @property 79 | def global_step(self): 80 | return self._global_step 81 | 82 | @property 83 | def global_episode(self): 84 | return self._global_episode 85 | 86 | @property 87 | def global_frame(self): 88 | return self.global_step * self.cfg.suite.action_repeat 89 | 90 | def eval(self): 91 | self.agent.train(False) 92 | episode_rewards = [] 93 | successes = [] 94 | for env_idx in range(self.envs_till_idx): 95 | print(f"evaluating env {env_idx}") 96 | episode, total_reward = 0, 0 97 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes) 98 | 
success = [] 99 | 100 | while eval_until_episode(episode): 101 | time_step = self.env[env_idx].reset() 102 | self.agent.buffer_reset() 103 | step = 0 104 | 105 | if episode == 0: 106 | self.video_recorder.init(self.env[env_idx], enabled=True) 107 | 108 | while not time_step.last(): 109 | with torch.no_grad(), utils.eval_mode(self.agent): 110 | action = self.agent.act( 111 | time_step.observation, 112 | self.expert_replay_loader.dataset.stats, 113 | step, 114 | self.global_step, 115 | eval_mode=True, 116 | ) 117 | time_step = self.env[env_idx].step(action) 118 | self.video_recorder.record(self.env[env_idx]) 119 | total_reward += time_step.reward 120 | step += 1 121 | 122 | episode += 1 123 | success.append(time_step.observation["goal_achieved"]) 124 | self.video_recorder.save(f"{self.global_frame}_env{env_idx}.mp4") 125 | episode_rewards.append(total_reward / episode) 126 | successes.append(np.mean(success)) 127 | 128 | for _ in range(len(self.env) - self.envs_till_idx): 129 | episode_rewards.append(0) 130 | successes.append(0) 131 | 132 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 133 | for env_idx, reward in enumerate(episode_rewards): 134 | log(f"episode_reward_env{env_idx}", reward) 135 | log(f"success_env{env_idx}", successes[env_idx]) 136 | log("episode_reward", np.mean(episode_rewards[: self.envs_till_idx])) 137 | log("success", np.mean(successes)) 138 | log("episode_length", step * self.cfg.suite.action_repeat / episode) 139 | log("episode", self.global_episode) 140 | log("step", self.global_step) 141 | 142 | self.agent.train(True) 143 | 144 | def save_snapshot(self): 145 | snapshot = self.work_dir / "snapshot.pt" 146 | self.agent.clear_buffers() 147 | keys_to_save = ["timer", "_global_step", "_global_episode"] 148 | payload = {k: self.__dict__[k] for k in keys_to_save} 149 | payload.update(self.agent.save_snapshot()) 150 | with snapshot.open("wb") as f: 151 | torch.save(payload, f) 152 | 153 | self.agent.buffer_reset() 154 | 155 | def load_snapshot(self, snapshots): 156 | # bc 157 | with snapshots["bc"].open("rb") as f: 158 | payload = torch.load(f) 159 | agent_payload = {} 160 | for k, v in payload.items(): 161 | if k not in self.__dict__: 162 | agent_payload[k] = v 163 | self.agent.load_snapshot(agent_payload, eval=True) 164 | 165 | 166 | @hydra.main(config_path="cfgs", config_name="config_eval") 167 | def main(cfg): 168 | from eval import WorkspaceIL as W 169 | 170 | workspace = W(cfg) 171 | 172 | # Load weights 173 | snapshots = {} 174 | # bc 175 | bc_snapshot = Path(cfg.bc_weight) 176 | if not bc_snapshot.exists(): 177 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}") 178 | print(f"loading bc weight: {bc_snapshot}") 179 | snapshots["bc"] = bc_snapshot 180 | workspace.load_snapshot(snapshots) 181 | 182 | workspace.eval() 183 | 184 | 185 | if __name__ == "__main__": 186 | main() 187 | -------------------------------------------------------------------------------- /point_policy/eval_point_track.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | import os 5 | 6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 7 | os.environ["MUJOCO_GL"] = "egl" 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | from pathlib import Path 10 | 11 | import hydra 12 | import torch 13 | import numpy as np 14 | 15 | import utils 16 | from logger import Logger 17 | from replay_buffer import make_expert_replay_loader 18 | from video import VideoRecorder 19 | 20 | 
warnings.filterwarnings("ignore", category=DeprecationWarning) 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def make_agent(obs_spec, action_spec, cfg): 25 | obs_shape = {} 26 | for key in cfg.suite.pixel_keys: 27 | obs_shape[key] = obs_spec[key].shape 28 | if cfg.use_proprio: 29 | obs_shape[cfg.suite.proprio_key] = obs_spec[cfg.suite.proprio_key].shape 30 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape 31 | cfg.agent.obs_shape = obs_shape 32 | cfg.agent.action_shape = action_spec.shape 33 | return hydra.utils.instantiate(cfg.agent) 34 | 35 | 36 | class Workspace: 37 | def __init__(self, cfg): 38 | self.work_dir = Path.cwd() 39 | print(f"workspace: {self.work_dir}") 40 | 41 | self.cfg = cfg 42 | utils.set_seed_everywhere(cfg.seed) 43 | self.device = torch.device(cfg.device) 44 | 45 | # load data 46 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset) 47 | self.expert_replay_loader = make_expert_replay_loader( 48 | dataset_iterable, self.cfg.batch_size 49 | ) 50 | self.expert_replay_iter = iter(self.expert_replay_loader) 51 | 52 | # create logger 53 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 54 | # create envs 55 | self.cfg.suite.task_make_fn.max_episode_len = ( 56 | self.expert_replay_loader.dataset._max_episode_len 57 | ) 58 | self.cfg.suite.task_make_fn.max_state_dim = ( 59 | self.expert_replay_loader.dataset._max_state_dim 60 | ) 61 | 62 | try: 63 | if self.cfg.suite.use_object_points: 64 | import yaml 65 | 66 | cfg_path = f"{cfg.root_dir}/point_policy/cfgs/suite/points_cfg.yaml" 67 | with open(cfg_path) as stream: 68 | try: 69 | points_cfg = yaml.safe_load(stream) 70 | except yaml.YAMLError as exc: 71 | print(exc) 72 | root_dir, dift_path, cotracker_checkpoint = ( 73 | points_cfg["root_dir"], 74 | points_cfg["dift_path"], 75 | points_cfg["cotracker_checkpoint"], 76 | ) 77 | points_cfg["dift_path"] = f"{root_dir}/{dift_path}" 78 | points_cfg[ 79 | "cotracker_checkpoint" 80 | ] = f"{root_dir}/{cotracker_checkpoint}" 81 | self.cfg.suite.task_make_fn.points_cfg = points_cfg 82 | except: 83 | pass 84 | 85 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn) 86 | 87 | # create agent 88 | self.agent = make_agent( 89 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg 90 | ) 91 | 92 | self.envs_till_idx = len(self.env) 93 | self.expert_replay_loader.dataset.envs_till_idx = self.envs_till_idx 94 | self.expert_replay_iter = iter(self.expert_replay_loader) 95 | 96 | self.timer = utils.Timer() 97 | self._global_step = 0 98 | self._global_episode = 0 99 | 100 | self.video_recorder = VideoRecorder( 101 | self.work_dir if self.cfg.save_video else None 102 | ) 103 | 104 | @property 105 | def global_step(self): 106 | return self._global_step 107 | 108 | @property 109 | def global_episode(self): 110 | return self._global_episode 111 | 112 | @property 113 | def global_frame(self): 114 | return self.global_step * self.cfg.suite.action_repeat 115 | 116 | def eval(self): 117 | self.agent.train(False) 118 | episode_rewards = [] 119 | successes = [] 120 | for env_idx in range(self.envs_till_idx): 121 | print(f"evaluating env {env_idx}") 122 | episode, total_reward = 0, 0 123 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes) 124 | success = [] 125 | 126 | while eval_until_episode(episode): 127 | print(episode) 128 | time_step = self.env[env_idx].reset() 129 | self.agent.buffer_reset() 130 | step = 0 131 | 132 | if episode == 0: 133 | self.video_recorder.init(self.env[env_idx], 
enabled=True) 134 | 135 | while not time_step.last(): 136 | with torch.no_grad(), utils.eval_mode(self.agent): 137 | action = self.agent.act( 138 | time_step.observation, 139 | self.expert_replay_loader.dataset.stats, 140 | step, 141 | self.global_step, 142 | eval_mode=True, 143 | ) 144 | 145 | time_step = self.env[env_idx].step(action) 146 | self.video_recorder.record(self.env[env_idx]) 147 | total_reward += time_step.reward 148 | step += 1 149 | 150 | episode += 1 151 | success.append(time_step.observation["goal_achieved"]) 152 | self.video_recorder.save(f"{self.global_frame}_env{env_idx}.mp4") 153 | episode_rewards.append(total_reward / episode) 154 | successes.append(np.mean(success)) 155 | 156 | for _ in range(len(self.env) - self.envs_till_idx): 157 | episode_rewards.append(0) 158 | successes.append(0) 159 | 160 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 161 | for env_idx, reward in enumerate(episode_rewards): 162 | log(f"episode_reward_env{env_idx}", reward) 163 | log(f"success_env{env_idx}", successes[env_idx]) 164 | log("episode_reward", np.mean(episode_rewards[: self.envs_till_idx])) 165 | log("success", np.mean(successes)) 166 | log("episode_length", step * self.cfg.suite.action_repeat / episode) 167 | log("episode", self.global_episode) 168 | log("step", self.global_step) 169 | 170 | self.agent.train(True) 171 | 172 | def save_snapshot(self): 173 | snapshot = self.work_dir / "snapshot.pt" 174 | self.agent.clear_buffers() 175 | keys_to_save = ["timer", "_global_step", "_global_episode"] 176 | payload = {k: self.__dict__[k] for k in keys_to_save} 177 | payload.update(self.agent.save_snapshot()) 178 | with snapshot.open("wb") as f: 179 | torch.save(payload, f) 180 | 181 | self.agent.buffer_reset() 182 | 183 | def load_snapshot(self, snapshots): 184 | # bc 185 | with snapshots["bc"].open("rb") as f: 186 | payload = torch.load(f) 187 | agent_payload = {} 188 | for k, v in payload.items(): 189 | if k not in self.__dict__: 190 | agent_payload[k] = v 191 | self.agent.load_snapshot(agent_payload, eval=True) 192 | 193 | 194 | @hydra.main(config_path="cfgs", config_name="config_eval") 195 | def main(cfg): 196 | workspace = Workspace(cfg) 197 | 198 | # Load weights 199 | snapshots = {} 200 | # bc 201 | bc_snapshot = Path(cfg.bc_weight) 202 | if not bc_snapshot.exists(): 203 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}") 204 | print(f"loading bc weight: {bc_snapshot}") 205 | snapshots["bc"] = bc_snapshot 206 | workspace.load_snapshot(snapshots) 207 | 208 | workspace.eval() 209 | 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /point_policy/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from termcolor import colored 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | BC_TRAIN_FORMAT = [ 12 | ("step", "S", "int"), 13 | ("actor_loss", "L", "float"), 14 | ("total_time", "T", "time"), 15 | ] 16 | BC_EVAL_FORMAT = [ 17 | ("frame", "F", "int"), 18 | ("step", "S", "int"), 19 | ("episode", "E", "int"), 20 | ("episode_length", "L", "int"), 21 | ("episode_reward", "R", "float"), 22 | ("imitation_reward", "R_i", "float"), 23 | ("total_time", "T", "time"), 24 | ] 25 | 26 | 27 | class AverageMeter(object): 28 | def __init__(self): 29 | self._sum = 0 30 | self._count = 
0 31 | 32 | def update(self, value, n=1): 33 | self._sum += value 34 | self._count += n 35 | 36 | def value(self): 37 | return self._sum / max(1, self._count) 38 | 39 | 40 | class MetersGroup(object): 41 | def __init__(self, csv_file_name, formating): 42 | self._csv_file_name = csv_file_name 43 | self._formating = formating 44 | self._meters = defaultdict(AverageMeter) 45 | self._csv_file = None 46 | self._csv_writer = None 47 | 48 | def log(self, key, value, n=1): 49 | self._meters[key].update(value, n) 50 | 51 | def _prime_meters(self): 52 | data = dict() 53 | for key, meter in self._meters.items(): 54 | if key.startswith("train_vq"): 55 | key = key[len("train_vq") + 1 :] 56 | elif key.startswith("train"): 57 | key = key[len("train") + 1 :] 58 | else: 59 | key = key[len("eval") + 1 :] 60 | key = key.replace("/", "_") 61 | data[key] = meter.value() 62 | return data 63 | 64 | def _remove_old_entries(self, data): 65 | rows = [] 66 | with self._csv_file_name.open("r") as f: 67 | reader = csv.DictReader(f) 68 | for row in reader: 69 | if float(row["episode"]) >= data["episode"]: 70 | break 71 | rows.append(row) 72 | with self._csv_file_name.open("w") as f: 73 | writer = csv.DictWriter(f, fieldnames=sorted(data.keys()), restval=0.0) 74 | writer.writeheader() 75 | for row in rows: 76 | writer.writerow(row) 77 | 78 | def _dump_to_csv(self, data): 79 | if self._csv_writer is None: 80 | should_write_header = True 81 | if self._csv_file_name.exists(): 82 | self._remove_old_entries(data) 83 | should_write_header = False 84 | 85 | self._csv_file = self._csv_file_name.open("a") 86 | self._csv_writer = csv.DictWriter( 87 | self._csv_file, fieldnames=sorted(data.keys()), restval=0.0 88 | ) 89 | if should_write_header: 90 | self._csv_writer.writeheader() 91 | 92 | self._csv_writer.writerow(data) 93 | self._csv_file.flush() 94 | 95 | def _format(self, key, value, ty): 96 | if ty == "int": 97 | value = int(value) 98 | return f"{key}: {value}" 99 | elif ty == "float": 100 | return f"{key}: {value:.04f}" 101 | elif ty == "time": 102 | value = str(datetime.timedelta(seconds=int(value))) 103 | return f"{key}: {value}" 104 | else: 105 | raise f"invalid format type: {ty}" 106 | 107 | def _dump_to_console(self, data, prefix): 108 | prefix = colored( 109 | prefix, "yellow" if prefix in ["train", "train_vq"] else "green" 110 | ) 111 | pieces = [f"| {prefix: <14}"] 112 | for key, disp_key, ty in self._formating: 113 | value = data.get(key, 0) 114 | pieces.append(self._format(disp_key, value, ty)) 115 | print(" | ".join(pieces)) 116 | 117 | def dump(self, step, prefix): 118 | if len(self._meters) == 0: 119 | return 120 | data = self._prime_meters() 121 | data["frame"] = step 122 | self._dump_to_csv(data) 123 | self._dump_to_console(data, prefix) 124 | self._meters.clear() 125 | 126 | 127 | class Logger(object): 128 | def __init__(self, log_dir, use_tb): 129 | """ 130 | mode: bc, ssl 131 | """ 132 | self._log_dir = log_dir 133 | self._train_mg = MetersGroup(log_dir / "train.csv", formating=BC_TRAIN_FORMAT) 134 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=BC_EVAL_FORMAT) 135 | 136 | if use_tb: 137 | self._sw = SummaryWriter(str(log_dir / "tb")) 138 | else: 139 | self._sw = None 140 | 141 | def _try_sw_log(self, key, value, step): 142 | if self._sw is not None: 143 | self._sw.add_scalar(key, value, step) 144 | 145 | def log(self, key, value, step): 146 | assert key.startswith("train") or key.startswith("eval") 147 | if type(value) == torch.Tensor: 148 | value = value.item() 149 | self._try_sw_log(key, 
value, step) 150 | if key.startswith("train_vq"): 151 | mg = self._train_vq_mg 152 | else: 153 | mg = self._train_mg if key.startswith("train") else self._eval_mg 154 | mg.log(key, value) 155 | 156 | def log_metrics(self, metrics, step, ty): 157 | for key, value in metrics.items(): 158 | self.log(f"{ty}/{key}", value, step) 159 | 160 | def dump(self, step, ty=None): 161 | if ty is None or ty == "eval": 162 | self._eval_mg.dump(step, "eval") 163 | if ty is None or ty == "train": 164 | self._train_mg.dump(step, "train") 165 | if ty is None or ty == "train_vq": 166 | self._train_vq_mg.dump(step, "train_vq") 167 | 168 | def log_and_dump_ctx(self, step, ty): 169 | return LogAndDumpCtx(self, step, ty) 170 | 171 | 172 | class LogAndDumpCtx: 173 | def __init__(self, logger, step, ty): 174 | self._logger = logger 175 | self._step = step 176 | self._ty = ty 177 | 178 | def __enter__(self): 179 | return self 180 | 181 | def __call__(self, key, value): 182 | self._logger.log(f"{self._ty}/{key}", value, self._step) 183 | 184 | def __exit__(self, *args): 185 | self._logger.dump(self._step, self._ty) 186 | -------------------------------------------------------------------------------- /point_policy/point_utils/correspondence.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torchvision import transforms 4 | import numpy as np 5 | import PIL 6 | 7 | 8 | class Correspondence: 9 | def __init__( 10 | self, 11 | device, 12 | dift_path, 13 | width, 14 | height, 15 | image_size_multiplier, 16 | ensemble_size, 17 | dift_layer, 18 | dift_steps, 19 | ): 20 | """ 21 | Initialize the Correspondence class. 22 | 23 | Parameters: 24 | ----------- 25 | device : str 26 | The device to use for computation, either 'cpu' or 'cuda' (for GPU). If you need to put the dift model on a different device to 27 | save space you can set this to cuda:1 28 | 29 | dift_path : str 30 | The file path to the directory containing the DIFT model source code. 31 | 32 | width : int 33 | The width that should be used in the correspondence model. 34 | 35 | height : int 36 | The height that should be used in the correspondence model. 37 | 38 | image_size_multiplier : int 39 | The multiplier to use for the image size in the DIFT model if height and weight are -1. 40 | 41 | ensemble_size : int 42 | The size of the ensemble for the DIFT model. 43 | 44 | dift_layer : int 45 | The specific layer of the DIFT model to use for feature extraction. 46 | 47 | dift_steps : int 48 | The number of steps/iterations for the DIFT model to use in feature extraction. 
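        Note:
        -----
        The matching entries in cfgs/suite/points_cfg.yaml are ensemble_size=8,
        dift_layer=1 and dift_steps=50, with width=-1, height=-1 and
        image_size_multiplier=1, i.e. the expert image's own resolution is used
        for feature extraction.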
49 | 50 | """ 51 | sys.path.append(dift_path) 52 | from src.models.dift_sd import SDFeaturizer 53 | 54 | self.dift = SDFeaturizer(device=device) 55 | 56 | self.device = device 57 | self.width = width 58 | self.height = height 59 | self.image_size_multiplier = image_size_multiplier 60 | self.ensemble_size = ensemble_size 61 | self.dift_layer = dift_layer 62 | self.dift_steps = dift_steps 63 | 64 | # Get the feature map from the DIFT model for the expert image to compare with the first frame of each episode later on 65 | def set_expert_correspondence(self, expert_image, pixel_key, object_label=""): 66 | with torch.no_grad(): 67 | # Use a null prompt 68 | self.prompt = "" 69 | self.original_size = expert_image.size 70 | 71 | if self.width == -1 or self.height == -1: 72 | self.width = expert_image.size[0] * self.image_size_multiplier 73 | self.height = expert_image.size[1] * self.image_size_multiplier 74 | 75 | expert_image = expert_image.resize( 76 | (self.width, self.height), resample=PIL.Image.BILINEAR 77 | ) 78 | expert_image = (transforms.PILToTensor()(expert_image) / 255.0 - 0.5) * 2 79 | expert_image = expert_image.cuda(self.device) 80 | 81 | expert_img_features = self.dift.forward( 82 | expert_image, 83 | prompt=self.prompt, 84 | ensemble_size=self.ensemble_size, 85 | up_ft_index=self.dift_layer, 86 | t=self.dift_steps, 87 | ) 88 | expert_img_features = expert_img_features.to(self.device) 89 | 90 | return expert_img_features 91 | 92 | def find_correspondence( 93 | self, expert_img_features, current_image, coords, pixel_key, object_label 94 | ): 95 | """ 96 | Find the corresponding points between the expert image and the current image 97 | 98 | Parameters: 99 | ----------- 100 | expert_img_features : torch.Tensor 101 | The feature map from the DIFT model for the expert image. 102 | 103 | current_image : torch.Tensor 104 | The current image to compare with the expert image. 105 | 106 | coords : list 107 | The coordinates of the points to find correspondence between the expert image and the current image. 
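pixel_key : str
    Key identifying the camera view the image comes from; not used inside this method.

object_label : str
    Optional label for the tracked object; not used inside this method.

Returns:
--------
out_coords : torch.Tensor
    Tensor with the same shape as `coords`; for each point, entries 1 and 2 hold the
    matched (x, y) location scaled back to the current image's original resolution,
    while entry 0 is left at zero.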
108 | """ 109 | 110 | with torch.no_grad(): 111 | curr_image_shape = (current_image.shape[2], current_image.shape[1]) 112 | current_image = transforms.Resize((self.height, self.width))(current_image) 113 | current_image_features = self.dift.forward( 114 | ((current_image - 0.5) * 2), 115 | prompt=self.prompt, 116 | ensemble_size=self.ensemble_size, 117 | up_ft_index=self.dift_layer, 118 | t=self.dift_steps, 119 | ) 120 | 121 | ft = torch.cat([expert_img_features, current_image_features]) 122 | num_channel = ft.size(1) 123 | src_ft = ft[0].unsqueeze(0) 124 | src_ft = torch.nn.Upsample( 125 | size=(self.height, self.width), mode="bilinear", align_corners=True 126 | )(src_ft) 127 | 128 | out_coords = torch.zeros(coords.shape) 129 | for idx, coord in enumerate(coords): 130 | x, y = int(coord[1] * self.width / self.original_size[0]), int( 131 | coord[2] * self.height / self.original_size[1] 132 | ) 133 | 134 | src_vec = src_ft[0, :, y, x].view(1, num_channel).clone() 135 | trg_ft = torch.nn.Upsample( 136 | size=(self.height, self.width), mode="bilinear", align_corners=True 137 | )(ft[1:]) 138 | trg_vec = trg_ft.view(1, num_channel, -1) # N, C, HW 139 | 140 | src_vec = torch.nn.functional.normalize(src_vec) # 1, C 141 | trg_vec = torch.nn.functional.normalize(trg_vec) # N, C, HW 142 | cos_map = ( 143 | torch.matmul(src_vec, trg_vec) 144 | .view(1, self.height, self.width) 145 | .cpu() 146 | .numpy() 147 | ) 148 | 149 | max_yx = np.unravel_index(cos_map[0].argmax(), cos_map[0].shape) 150 | out_coords[idx, 1], out_coords[idx, 2] = int( 151 | max_yx[1] * curr_image_shape[0] / self.width 152 | ), int(max_yx[0] * curr_image_shape[1] / self.height) 153 | 154 | return out_coords 155 | -------------------------------------------------------------------------------- /point_policy/point_utils/depth.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import cv2 4 | 5 | 6 | class Depth: 7 | def __init__(self, depth_path, device): 8 | """ 9 | Initialize the Depth class for finding depth maps from images. 10 | 11 | Parameters: 12 | ----------- 13 | depth_path : str 14 | The file path to the directory containing the Depth Anything source code. 15 | 16 | device : str 17 | The device to use for computation, either 'cpu' or 'cuda' (for GPU acceleration). 18 | """ 19 | sys.path.append(depth_path) 20 | from metric_depth.depth_anything_v2.dpt import DepthAnythingV2 21 | 22 | # Initialize the Depth model 23 | model_configs = { 24 | "vits": { 25 | "encoder": "vits", 26 | "features": 64, 27 | "out_channels": [48, 96, 192, 384], 28 | }, 29 | "vitb": { 30 | "encoder": "vitb", 31 | "features": 128, 32 | "out_channels": [96, 192, 384, 768], 33 | }, 34 | "vitl": { 35 | "encoder": "vitl", 36 | "features": 256, 37 | "out_channels": [256, 512, 1024, 1024], 38 | }, 39 | } 40 | encoder = "vitl" 41 | dataset = "hypersim" 42 | max_depth = 20 43 | self.depth = DepthAnythingV2( 44 | **{**model_configs[encoder], "max_depth": max_depth} 45 | ) 46 | self.depth.load_state_dict( 47 | torch.load( 48 | f"{depth_path}/checkpoints/depth_anything_v2_metric_{dataset}_{encoder}.pth", 49 | map_location="cpu", 50 | weights_only=False, 51 | ) 52 | ) 53 | self.depth = self.depth.to(device).eval() 54 | 55 | self.device = device 56 | 57 | def get_depth(self, image): 58 | """ 59 | Get the depth map from an image. 60 | 61 | Parameters: 62 | ----------- 63 | image : np.ndarray 64 | The image to find the depth map for. The image should be in RGB format. 
65 | 66 | Returns: 67 | -------- 68 | depth : np.ndarray 69 | The depth map for the image. 70 | """ 71 | bgr_array = image[:, :, ::-1] 72 | depth = self.depth.infer_image(bgr_array) 73 | return depth 74 | -------------------------------------------------------------------------------- /point_policy/read_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siddhanthaldar/Point-Policy/6f318852b324283d354446bfba0628890f78a116/point_policy/read_data/__init__.py -------------------------------------------------------------------------------- /point_policy/read_data/p3po.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import random 3 | import numpy as np 4 | import pickle as pkl 5 | from pathlib import Path 6 | 7 | from torch.utils.data import IterableDataset 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | from robot_utils.franka.utils import matrix_to_rotation_6d 11 | 12 | 13 | def get_relative_action(actions, action_after_steps): 14 | """ 15 | Convert absolute axis angle actions to relative axis angle actions 16 | Action has both position and orientation. Convert to transformation matrix, get 17 | relative transformation matrix, convert back to axis angle 18 | """ 19 | 20 | relative_actions = [] 21 | for i in range(len(actions)): 22 | ####### Get relative transformation matrix ####### 23 | # previous pose 24 | pos_prev = actions[i, :3] 25 | ori_prev = actions[i, 3:6] 26 | r_prev = R.from_rotvec(ori_prev).as_matrix() 27 | matrix_prev = np.eye(4) 28 | matrix_prev[:3, :3] = r_prev 29 | matrix_prev[:3, 3] = pos_prev 30 | # current pose 31 | next_idx = min(i + action_after_steps, len(actions) - 1) 32 | pos = actions[next_idx, :3] 33 | ori = actions[next_idx, 3:6] 34 | gripper = actions[next_idx, 6:] 35 | r = R.from_rotvec(ori).as_matrix() 36 | matrix = np.eye(4) 37 | matrix[:3, :3] = r 38 | matrix[:3, 3] = pos 39 | # relative transformation 40 | matrix_rel = np.linalg.inv(matrix_prev) @ matrix 41 | # relative pose 42 | pos_rel = pos - pos_prev 43 | r_rel = R.from_matrix(matrix_rel[:3, :3]).as_rotvec() 44 | # add to list 45 | relative_actions.append(np.concatenate([pos_rel, r_rel, gripper])) 46 | 47 | # last action 48 | last_action = np.zeros_like(actions[-1]) 49 | last_action[-1] = actions[-1][-1] 50 | while len(relative_actions) < len(actions): 51 | relative_actions.append(last_action) 52 | return np.array(relative_actions, dtype=np.float32) 53 | 54 | 55 | class BCDataset(IterableDataset): 56 | def __init__( 57 | self, 58 | path, 59 | tasks, 60 | num_demos_per_task, 61 | history, 62 | history_len, 63 | temporal_agg, 64 | num_queries, 65 | img_size, 66 | action_after_steps, 67 | use_robot_points, 68 | num_robot_points, 69 | use_object_points, 70 | num_object_points, 71 | point_dim, 72 | pixel_keys, 73 | action_type, 74 | subsample, 75 | skip_first_n, 76 | gt_depth, 77 | ): 78 | tasks = [tasks] # NOTE: single task for now 79 | 80 | self._history = history 81 | self._history_len = history_len if history else 1 82 | self._img_size = np.array(img_size) 83 | self._action_after_steps = action_after_steps 84 | self._pixel_keys = pixel_keys 85 | self._action_type = action_type 86 | self._subsample = subsample 87 | 88 | # track points 89 | self._use_robot_points = use_robot_points 90 | self._num_robot_points = num_robot_points 91 | self._use_object_points = use_object_points 92 | self._num_object_points = num_object_points 93 | self._point_dim = point_dim 
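# point_dim selects which stored tracks are used: 2 loads the pixel-space tracks
# ("robot_tracks_*" / "object_tracks_*"), while 3 loads the corresponding 3D tracks
# ("robot_tracks_3d_*" / "object_tracks_3d_*").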
94 | assert self._point_dim in [2, 3], "Point dimension must be 2 or 3" 95 | self._robot_points_key = ( 96 | "robot_tracks" if self._point_dim == 2 else "robot_tracks_3d" 97 | ) 98 | self._object_points_key = ( 99 | "object_tracks" if self._point_dim == 2 else "object_tracks_3d" 100 | ) 101 | 102 | # temporal aggregation 103 | self._temporal_agg = temporal_agg 104 | self._num_queries = num_queries if temporal_agg else 1 105 | 106 | # get data paths 107 | self._paths = [] 108 | for task in tasks: 109 | if gt_depth: 110 | self._paths.extend([Path(path) / f"{task}_gt_depth.pkl"]) 111 | else: 112 | self._paths.extend([Path(path) / f"{task}.pkl"]) 113 | if self._use_object_points: 114 | self._object_pt_paths = {} 115 | 116 | paths = {} 117 | idx = 0 118 | for path, task in zip(self._paths, tasks): 119 | paths[idx] = path 120 | idx += 1 121 | del self._paths 122 | self._paths = paths 123 | 124 | # read data 125 | self._episodes = {} 126 | self._num_demos = {} 127 | self._max_episode_len = 0 128 | self._max_state_dim = 0 129 | self._num_samples = 0 130 | min_act, max_act = None, None 131 | min_track, max_track = None, None 132 | for _path_idx in self._paths: 133 | print(f"Loading {str(self._paths[_path_idx])}") 134 | # read 135 | data = pkl.load(open(str(self._paths[_path_idx]), "rb")) 136 | observations = data["observations"] 137 | 138 | # store 139 | self._episodes[_path_idx] = [] 140 | self._num_demos[_path_idx] = min(num_demos_per_task, len(observations)) 141 | for i in range(min(num_demos_per_task, len(observations))): 142 | # absolute actions 143 | action_key = ( 144 | "human_poses" 145 | if "human_poses" in observations[i].keys() 146 | else "cartesian_states" 147 | ) 148 | actions = np.concatenate( 149 | [ 150 | observations[i][action_key], 151 | observations[i]["gripper_states"][:, None], 152 | ], 153 | axis=1, 154 | ) 155 | if len(actions) == 0: 156 | continue 157 | 158 | # skip first n 159 | if skip_first_n is not None: 160 | for key in observations[i].keys(): 161 | observations[i][key] = observations[i][key][skip_first_n:] 162 | actions = actions[skip_first_n:] 163 | 164 | # action after steps 165 | if self._action_type == "absolute": 166 | actions = actions[self._action_after_steps :] 167 | else: 168 | actions = get_relative_action(actions, self._subsample) 169 | 170 | # convert orientation to 6D rotation representation 171 | if self._action_type == "absolute": 172 | gripper = actions[:, -1:] 173 | # orientation represented as 6D rotation 174 | pos = actions[:, :3] 175 | rot = actions[:, 3:6] 176 | rot = [R.from_rotvec(rot[i]).as_matrix() for i in range(len(rot))] 177 | rot = np.array(rot) 178 | rot = matrix_to_rotation_6d(rot) 179 | actions = np.concatenate([pos, rot], axis=-1) 180 | actions = np.concatenate([actions, gripper], axis=-1) 181 | 182 | # Repeat the last timestep of each observation history_len times 183 | for key in observations[i].keys(): 184 | observations[i][key] = np.concatenate( 185 | [ 186 | observations[i][key], 187 | [observations[i][key][-1]] * self._history_len, 188 | ], 189 | axis=0, 190 | ) 191 | # Repeat last action for history_len times 192 | remaining_actions = actions[-1] 193 | if self._action_type == "relative": 194 | pos = remaining_actions[:-1] 195 | ori_gripper = remaining_actions[-1:] 196 | remaining_actions = np.concatenate( 197 | [np.zeros_like(pos), ori_gripper] 198 | ) 199 | actions = np.concatenate( 200 | [ 201 | actions, 202 | [remaining_actions] * self._history_len, 203 | ], 204 | axis=0, 205 | ) 206 | 207 | # store 208 | episode = dict( 209 |
observation=observations[i], 210 | action=actions, 211 | ) 212 | self._episodes[_path_idx].append(episode) 213 | self._max_episode_len = max( 214 | self._max_episode_len, 215 | ( 216 | len(observations[i]) 217 | if not isinstance(observations[i], dict) 218 | else len(observations[i][self._pixel_keys[0]]) 219 | ), 220 | ) 221 | self._max_state_dim = self._num_robot_points * self._point_dim 222 | self._num_samples += len(observations[i][self._pixel_keys[0]]) 223 | 224 | # max, min action 225 | if min_act is None: 226 | min_act = np.min(actions, axis=0) 227 | max_act = np.max(actions, axis=0) 228 | else: 229 | min_act = np.minimum(min_act, np.min(actions, axis=0)) 230 | max_act = np.maximum(max_act, np.max(actions, axis=0)) 231 | 232 | # min, max track 233 | for pixel_key in self._pixel_keys: 234 | if self._use_robot_points: 235 | track_key = f"{self._robot_points_key}_{pixel_key}" 236 | track = observations[i][track_key] 237 | track = einops.rearrange(track, "t n d -> (t n) d") 238 | min_track = ( 239 | np.minimum(min_track, np.min(track, axis=0)) 240 | if min_track is not None 241 | else np.min(track, axis=0) 242 | ) 243 | max_track = ( 244 | np.maximum(max_track, np.max(track, axis=0)) 245 | if max_track is not None 246 | else np.max(track, axis=0) 247 | ) 248 | if self._use_object_points: 249 | track_key = f"{self._object_points_key}_{pixel_key}" 250 | track = observations[i][track_key] 251 | track = einops.rearrange(track, "t n d -> (t n) d") 252 | min_track = ( 253 | np.minimum(min_track, np.min(track, axis=0)) 254 | if min_track is not None 255 | else np.min(track, axis=0) 256 | ) 257 | max_track = ( 258 | np.maximum(max_track, np.max(track, axis=0)) 259 | if max_track is not None 260 | else np.max(track, axis=0) 261 | ) 262 | 263 | self.stats = { 264 | "actions": { 265 | "min": min_act, 266 | "max": max_act, 267 | }, 268 | "past_tracks": { 269 | "min": min_track, 270 | "max": max_track, 271 | }, 272 | } 273 | 274 | self.preprocess = { 275 | "actions": lambda x: (x - self.stats["actions"]["min"]) 276 | / (self.stats["actions"]["max"] - self.stats["actions"]["min"] + 1e-5), 277 | "past_tracks": lambda x: (x - self.stats["past_tracks"]["min"]) 278 | / ( 279 | self.stats["past_tracks"]["max"] 280 | - self.stats["past_tracks"]["min"] 281 | + 1e-5 282 | ), 283 | } 284 | 285 | # Samples from envs 286 | self.envs_till_idx = len(self._episodes) 287 | 288 | def _sample_episode(self, env_idx=None): 289 | if env_idx is not None: 290 | idx = env_idx 291 | else: 292 | idx = np.random.choice(list(self._episodes.keys())) 293 | 294 | episode = random.choice(self._episodes[idx]) 295 | return (episode, idx) if env_idx is None else episode 296 | 297 | def _sample(self): 298 | episodes, env_idx = self._sample_episode() 299 | observations = episodes["observation"] 300 | actions = episodes["action"] 301 | 302 | # Sample obs, action 303 | sample_idx = np.random.randint( 304 | 0, len(observations[self._pixel_keys[0]]) - self._history_len 305 | ) 306 | pixel_key = np.random.choice(self._pixel_keys) 307 | 308 | if self._temporal_agg: 309 | # arrange sampled action to be of shape (history_len, num_queries, action_dim) 310 | action = np.zeros((self._history_len, self._num_queries, actions.shape[-1])) 311 | num_actions = ( 312 | self._history_len + self._num_queries - 1 313 | ) # -1 since its num_queries including the last action of the history 314 | act = np.zeros((num_actions, actions.shape[-1])) 315 | act[: min(len(actions), sample_idx + num_actions) - sample_idx] = actions[ 316 | sample_idx : sample_idx + 
num_actions 317 | ] 318 | if len(actions) < sample_idx + num_actions: 319 | act[len(actions) - sample_idx :] = actions[-1] 320 | action = np.lib.stride_tricks.sliding_window_view( 321 | act, (self._num_queries, actions.shape[-1]) 322 | ) 323 | action = action[:, 0] 324 | else: 325 | action = actions[sample_idx : sample_idx + self._history_len] 326 | 327 | past_tracks = [] 328 | if self._use_robot_points: 329 | track_key = f"{self._robot_points_key}_{pixel_key}" 330 | num_points = self._num_robot_points 331 | robot_points = observations[track_key][ 332 | max( 333 | 0, 334 | sample_idx - self._history_len * self._subsample + self._subsample, 335 | ) : sample_idx 336 | + 1 : self._subsample 337 | ][:, -num_points:] 338 | if len(robot_points) < self._history_len: 339 | prior = np.array( 340 | [robot_points[0]] * (self._history_len - len(robot_points)) 341 | ) 342 | robot_points = np.concatenate([prior, robot_points], axis=0) 343 | past_tracks.append(robot_points) 344 | 345 | if self._use_object_points: 346 | object_points = observations[f"{self._object_points_key}_{pixel_key}"][ 347 | max( 348 | 0, 349 | sample_idx - self._history_len * self._subsample + self._subsample, 350 | ) : sample_idx 351 | + 1 : self._subsample 352 | ] 353 | if len(object_points) < self._history_len: 354 | prior = np.array( 355 | [object_points[0]] * (self._history_len - len(object_points)) 356 | ) 357 | object_points = np.concatenate([prior, object_points], axis=0) 358 | past_tracks.append(object_points) 359 | 360 | past_tracks = np.concatenate(past_tracks, axis=1) 361 | 362 | return_dict = { 363 | "past_tracks": self.preprocess["past_tracks"](past_tracks), 364 | "actions": self.preprocess["actions"](action), 365 | } 366 | 367 | return return_dict 368 | 369 | def sample_actions(self, env_idx): 370 | episode = self._sample_episode(env_idx) 371 | return episode["action"] 372 | 373 | def __iter__(self): 374 | while True: 375 | yield self._sample() 376 | 377 | def __len__(self): 378 | return self._num_samples 379 | -------------------------------------------------------------------------------- /point_policy/read_data/point_policy.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import random 3 | import numpy as np 4 | import pickle as pkl 5 | from pathlib import Path 6 | 7 | from torch.utils.data import IterableDataset 8 | from scipy.spatial.transform import Rotation as R 9 | 10 | 11 | class BCDataset(IterableDataset): 12 | def __init__( 13 | self, 14 | path, 15 | tasks, 16 | num_demos_per_task, 17 | history, 18 | history_len, 19 | temporal_agg, 20 | num_queries, 21 | img_size, 22 | action_after_steps, 23 | use_robot_points, 24 | num_robot_points, 25 | use_object_points, 26 | num_object_points, 27 | point_dim, 28 | pixel_keys, 29 | subsample, 30 | skip_first_n, 31 | gt_depth, 32 | ): 33 | tasks = [tasks] # NOTE: single task for now 34 | 35 | self._history = history 36 | self._history_len = history_len if history else 1 37 | self._img_size = np.array(img_size) 38 | self._action_after_steps = action_after_steps 39 | self._pixel_keys = pixel_keys 40 | self._subsample = subsample 41 | 42 | # track points 43 | self._use_robot_points = use_robot_points 44 | self._num_robot_points = num_robot_points 45 | self._use_object_points = use_object_points 46 | self._num_object_points = num_object_points 47 | self._point_dim = point_dim 48 | assert self._point_dim in [2, 3], "Point dimension must be 2 or 3" 49 | self._robot_points_key = ( 50 | "robot_tracks" if self._point_dim == 
2 else "robot_tracks_3d" 51 | ) 52 | self._object_points_key = ( 53 | "object_tracks" if self._point_dim == 2 else "object_tracks_3d" 54 | ) 55 | 56 | # temporal aggregation 57 | self._temporal_agg = temporal_agg 58 | self._num_queries = num_queries if temporal_agg else 1 59 | 60 | # get data paths 61 | self._paths = [] 62 | for task in tasks: 63 | if gt_depth: 64 | self._paths.extend([Path(path) / f"{task}_gt_depth.pkl"]) 65 | else: 66 | self._paths.extend([Path(path) / f"{task}.pkl"]) 67 | 68 | paths = {} 69 | idx = 0 70 | for path, task in zip(self._paths, tasks): 71 | paths[idx] = path 72 | idx += 1 73 | del self._paths 74 | self._paths = paths 75 | 76 | # read data 77 | self._episodes = {} 78 | self._num_demos = {} 79 | self._max_episode_len = 0 80 | self._max_state_dim = 0 81 | self._num_samples = 0 82 | min_track, max_track = None, None 83 | for _path_idx in self._paths: 84 | print(f"Loading {str(self._paths[_path_idx])}") 85 | # read 86 | data = pkl.load(open(str(self._paths[_path_idx]), "rb")) 87 | observations = data["observations"] 88 | 89 | # store 90 | self._episodes[_path_idx] = [] 91 | self._num_demos[_path_idx] = min(num_demos_per_task, len(observations)) 92 | for i in range(min(num_demos_per_task, len(observations))): 93 | # skip first n 94 | if skip_first_n is not None: 95 | for key in observations[i].keys(): 96 | observations[i][key] = observations[i][key][skip_first_n:] 97 | 98 | # Repeat last dimension of each observation for history_len times 99 | for key in observations[i].keys(): 100 | observations[i][key] = np.concatenate( 101 | [ 102 | observations[i][key], 103 | [observations[i][key][-1]] * self._history_len, 104 | ], 105 | axis=0, 106 | ) 107 | 108 | # store 109 | episode = dict( 110 | observation=observations[i], 111 | ) 112 | self._episodes[_path_idx].append(episode) 113 | self._max_episode_len = max( 114 | self._max_episode_len, 115 | ( 116 | len(observations[i]) 117 | if not isinstance(observations[i], dict) 118 | else len(observations[i][self._pixel_keys[0]]) 119 | ), 120 | ) 121 | self._max_state_dim = self._num_robot_points * self._point_dim 122 | self._num_samples += len(observations[i][self._pixel_keys[0]]) 123 | 124 | # min, max track 125 | for pixel_key in self._pixel_keys: 126 | if self._use_robot_points: 127 | track_key = f"{self._robot_points_key}_{pixel_key}" 128 | track = observations[i][track_key] 129 | track = einops.rearrange(track, "t n d -> (t n) d") 130 | min_track = ( 131 | np.minimum(min_track, np.min(track, axis=0)) 132 | if min_track is not None 133 | else np.min(track, axis=0) 134 | ) 135 | max_track = ( 136 | np.maximum(max_track, np.max(track, axis=0)) 137 | if max_track is not None 138 | else np.max(track, axis=0) 139 | ) 140 | if self._use_object_points: 141 | track_key = f"{self._object_points_key}_{pixel_key}" 142 | track = observations[i][track_key] 143 | track = einops.rearrange(track, "t n d -> (t n) d") 144 | min_track = ( 145 | np.minimum(min_track, np.min(track, axis=0)) 146 | if min_track is not None 147 | else np.min(track, axis=0) 148 | ) 149 | max_track = ( 150 | np.maximum(max_track, np.max(track, axis=0)) 151 | if max_track is not None 152 | else np.max(track, axis=0) 153 | ) 154 | 155 | self.stats = { 156 | "past_tracks": { 157 | "min": min_track, 158 | "max": max_track, 159 | }, 160 | "future_tracks": { 161 | "min": np.concatenate( 162 | [min_track for _ in range(self._num_queries)], axis=0 163 | ), 164 | "max": np.concatenate( 165 | [max_track for _ in range(self._num_queries)], axis=0 166 | ), 167 | }, 168 | 
"gripper_states": { 169 | "min": -2.0, 170 | "max": 2.0, 171 | }, 172 | } 173 | 174 | self.preprocess = { 175 | "past_tracks": lambda x: (x - self.stats["past_tracks"]["min"]) 176 | / ( 177 | self.stats["past_tracks"]["max"] 178 | - self.stats["past_tracks"]["min"] 179 | + 1e-5 180 | ), 181 | "future_tracks": lambda x: (x - self.stats["future_tracks"]["min"]) 182 | / ( 183 | self.stats["future_tracks"]["max"] 184 | - self.stats["future_tracks"]["min"] 185 | + 1e-5 186 | ), 187 | "gripper_states": lambda x: (x - self.stats["gripper_states"]["min"]) 188 | / ( 189 | self.stats["gripper_states"]["max"] 190 | - self.stats["gripper_states"]["min"] 191 | + 1e-5 192 | ), 193 | } 194 | 195 | # Samples from envs 196 | self.envs_till_idx = len(self._episodes) 197 | 198 | def _sample_episode(self, env_idx=None): 199 | if env_idx is not None: 200 | idx = env_idx 201 | else: 202 | idx = np.random.choice(list(self._episodes.keys())) 203 | 204 | episode = random.choice(self._episodes[idx]) 205 | return (episode, idx) if env_idx is None else episode 206 | 207 | def _sample(self): 208 | episodes, env_idx = self._sample_episode() 209 | observations = episodes["observation"] 210 | traj_len = len(observations[self._pixel_keys[0]]) 211 | 212 | # Sample obs, action 213 | sample_idx = np.random.randint( 214 | 0, len(observations[self._pixel_keys[0]]) - self._history_len 215 | ) 216 | pixel_key = np.random.choice(self._pixel_keys) 217 | 218 | # action mask to only apply loss for robot or hand points 219 | action_mask = [] 220 | 221 | past_tracks = [] 222 | 223 | if self._use_robot_points: 224 | track_key = f"{self._robot_points_key}_{pixel_key}" 225 | num_points = self._num_robot_points 226 | robot_points = observations[track_key][ 227 | max( 228 | 0, 229 | sample_idx - self._history_len * self._subsample + self._subsample, 230 | ) : sample_idx 231 | + 1 : self._subsample 232 | ][:, -num_points:] 233 | if len(robot_points) < self._history_len: 234 | prior = np.array( 235 | [robot_points[0]] * (self._history_len - len(robot_points)) 236 | ) 237 | robot_points = np.concatenate([prior, robot_points], axis=0) 238 | past_tracks.append(robot_points) 239 | action_mask.extend([1] * num_points) 240 | 241 | if self._use_object_points: 242 | object_points = observations[f"{self._object_points_key}_{pixel_key}"][ 243 | max( 244 | 0, 245 | sample_idx 246 | - self._history_len * self._subsample 247 | + self._subsample, # 1 248 | ) : sample_idx 249 | + 1 : self._subsample 250 | ] 251 | if len(object_points) < self._history_len: 252 | prior = np.array( 253 | [object_points[0]] * (self._history_len - len(object_points)) 254 | ) 255 | object_points = np.concatenate([prior, object_points], axis=0) 256 | past_tracks.append(object_points) 257 | action_mask.extend([0] * self._num_object_points) 258 | 259 | past_tracks = np.concatenate(past_tracks, axis=1) 260 | action_mask = np.array(action_mask) 261 | 262 | # past gripper_states 263 | past_gripper_states = observations[f"gripper_states"][ 264 | max( 265 | 0, 266 | sample_idx - self._history_len * self._subsample + self._subsample, # 1 267 | ) : sample_idx 268 | + 1 : self._subsample 269 | ] 270 | if len(past_gripper_states) < self._history_len: 271 | prior = np.array( 272 | [past_gripper_states[0]] 273 | * (self._history_len - len(past_gripper_states)) 274 | ) 275 | past_gripper_states = np.concatenate([prior, past_gripper_states], axis=0) 276 | 277 | future_tracks = [] 278 | num_future_tracks = self._history_len + self._num_queries - 1 279 | 280 | # for action sampling 281 | 
start_idx = min(sample_idx + 1, traj_len - 1) 282 | end_idx = min(start_idx + num_future_tracks * self._subsample, traj_len) 283 | 284 | if self._use_robot_points: 285 | track_key = f"{self._robot_points_key}_{pixel_key}" 286 | num_points = self._num_robot_points 287 | ft = observations[track_key][start_idx : end_idx : self._subsample][ 288 | :, -num_points: 289 | ] 290 | if len(ft) < num_future_tracks: 291 | post = np.array([ft[-1]] * (num_future_tracks - len(ft))) 292 | ft = np.concatenate([ft, post], axis=0) 293 | # ft is of shape (T, N, D) 294 | ft = ft.transpose( 295 | 1, 0, 2 296 | ) # (N, T, D) where T=history_len+num_queries-1=H+Q-1 297 | ft = np.lib.stride_tricks.sliding_window_view( 298 | ft, self._num_queries, 1 299 | ) # (N, H, D, Q) 300 | ft = ft.transpose(1, 0, 3, 2) # (H, N, Q, D) 301 | ft = einops.rearrange(ft, "h n q d -> h n (q d)") 302 | future_tracks.append(ft) 303 | 304 | if self._use_object_points: 305 | ft = observations[f"{self._object_points_key}_{pixel_key}"][ 306 | start_idx : end_idx : self._subsample 307 | ] 308 | if len(ft) < num_future_tracks: 309 | post = np.array([ft[-1]] * (num_future_tracks - len(ft))) 310 | ft = np.concatenate([ft, post], axis=0) 311 | # ft is of shape (T, N, D) 312 | ft = ft.transpose( 313 | 1, 0, 2 314 | ) # (N, T, D) where T=history_len+num_queries-1=H+Q-1 315 | ft = np.lib.stride_tricks.sliding_window_view( 316 | ft, self._num_queries, 1 317 | ) # (N, H, D, Q) 318 | ft = ft.transpose(1, 0, 3, 2) # (H, N, Q, D) 319 | ft = einops.rearrange(ft, "h n q d -> h n (q d)") 320 | future_tracks.append(ft) 321 | 322 | future_tracks = np.concatenate(future_tracks, axis=1) 323 | 324 | # future gripper_states 325 | future_gripper_states = observations[f"gripper_states"][ 326 | start_idx : end_idx : self._subsample 327 | ] 328 | if len(future_gripper_states) < num_future_tracks: 329 | post = np.array( 330 | [future_gripper_states[-1]] 331 | * (num_future_tracks - len(future_gripper_states)) 332 | ) 333 | future_gripper_states = np.concatenate( 334 | [future_gripper_states, post], axis=0 335 | ) 336 | future_gripper_states = future_gripper_states.reshape( 337 | future_gripper_states.shape[0] 338 | ) 339 | future_gripper_states = np.lib.stride_tricks.sliding_window_view( 340 | future_gripper_states, self._num_queries 341 | ) 342 | 343 | return_dict = { 344 | "past_tracks": self.preprocess["past_tracks"](past_tracks), 345 | "past_gripper_states": self.preprocess["gripper_states"]( 346 | past_gripper_states 347 | ), 348 | "future_tracks": self.preprocess["future_tracks"](future_tracks), 349 | "future_gripper_states": self.preprocess["gripper_states"]( 350 | future_gripper_states 351 | ), 352 | "action_mask": action_mask, 353 | } 354 | 355 | return return_dict 356 | 357 | def sample_actions(self, env_idx): 358 | episode = self._sample_episode(env_idx) 359 | actions = [] 360 | for i in range( 361 | 0, 362 | len(episode["observation"][f"point_tracks_{self._pixel_keys[0]}"]), 363 | self._subsample, 364 | ): 365 | action = {} 366 | for pixel_key in self._pixel_keys: 367 | action[f"future_tracks_{pixel_key}"] = episode["observation"][ 368 | f"point_tracks_{pixel_key}" 369 | ][i] 370 | actions.append(action) 371 | return actions 372 | 373 | def __iter__(self): 374 | while True: 375 | yield self._sample() 376 | 377 | def __len__(self): 378 | return self._num_samples 379 | -------------------------------------------------------------------------------- /point_policy/replay_buffer.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def _worker_init_fn(worker_id): 7 | seed = np.random.get_state()[1][0] + worker_id 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | 11 | 12 | def make_expert_replay_loader(iterable, batch_size): 13 | loader = torch.utils.data.DataLoader( 14 | iterable, 15 | batch_size=batch_size, 16 | num_workers=2, 17 | pin_memory=True, 18 | worker_init_fn=_worker_init_fn, 19 | ) 20 | return loader 21 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/calibration/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # From pyrealsense2 4 | # Command: rs-enumerate-devices -c 5 | # 640x480 6 | CAMERA_MATRICES = { 7 | "cam_1": np.array([[604.97, 0, 314.83], [0.0, 604.79, 249.03], [0, 0, 1]]), 8 | "cam_2": np.array([[609.41, 0, 314.85], [0.0, 609.65, 240.52], [0, 0, 1]]), 9 | } 10 | 11 | DISTORTION_COEFFICIENTS = { 12 | "cam_1": np.zeros((5)), 13 | "cam_2": np.zeros((5)), 14 | } 15 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/calibration/generate_r2c_extrinsic.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script which, given camera intrinsics, computes the robot-to-camera transformation 3 | for each camera and saves it as the camera extrinsics in a calib.npy file 4 | """ 5 | 6 | import cv2 7 | from cv2 import aruco 8 | import numpy as np 9 | import pickle as pkl 10 | from pathlib import Path 11 | from scipy.spatial.transform import Rotation as R 12 | 13 | PATH_DATA_PKL = Path("/path/to/data/processed_data_pkl/calib.pkl") 14 | PATH_INTRINSICS = None 15 | SAVE_DIR = Path("../../../calib") 16 | PATH_SAVE_CALIB = SAVE_DIR / "calib.npy" 17 | CAM_IDS = [1, 2] 18 | R2C_TRAJ_IDX = 0 19 | FRAME_FREQ = 1 # consider every Nth frame 20 | 21 | 22 | with open(PATH_DATA_PKL, "rb") as f: 23 | observations = pkl.load(f)["observations"] 24 | 25 | if PATH_INTRINSICS is not None and PATH_INTRINSICS.exists(): 26 | print("Using intrinsics from file") 27 | with open(PATH_INTRINSICS, "rb") as f: 28 | intrinsics = pkl.load(f) 29 | else: 30 | print("Using intrinsics from constants") 31 | from constants import CAMERA_MATRICES, DISTORTION_COEFFICIENTS 32 | 33 | intrinsics = { 34 | "camera_matrices": {}, 35 | "distortion_coefficients": {}, 36 | } 37 | for cam_id in CAM_IDS: 38 | intrinsics["camera_matrices"][f"cam_{cam_id}"] = CAMERA_MATRICES[ 39 | f"cam_{cam_id}" 40 | ] 41 | intrinsics["distortion_coefficients"][ 42 | f"cam_{cam_id}" 43 | ] = DISTORTION_COEFFICIENTS[f"cam_{cam_id}"] 44 | 45 | SAVE_DIR.mkdir(exist_ok=True, parents=True) 46 | 47 | ################################# compute the robot to camera transformation ################################# 48 | 49 | T_ci_b = {} 50 | 51 | for cam_id in CAM_IDS: 52 | pixels = observations[R2C_TRAJ_IDX][f"pixels{cam_id}"][..., ::-1] 53 | 54 | # object point transformations 55 | object_pos = observations[R2C_TRAJ_IDX]["cartesian_states"][:, :3] 56 | object_aa = observations[R2C_TRAJ_IDX]["cartesian_states"][:, 3:] 57 | object_rot_mat = R.from_rotvec(object_aa).as_matrix() 58 | object_trans = np.zeros( 59 | (object_pos.shape[0], 4, 4) 60 | ) # pose of gripper in robot base frame 61 | object_trans[:, :3, :3] = object_rot_mat 62 | object_trans[:, :3, 3] = object_pos 63 | 64 | # compute object
points 65 | T_a_g = np.array([[1, 0, 0, 0.0025], [0, 1, 0, 0], [0, 0, 1, 0.0625], [0, 0, 0, 1]]) 66 | 67 | # Franka axis is rotated by 45 degrees 68 | # See pg 52 of https://download.franka.de/documents/Product%20Manual%20Franka%20Research%203_R02210_1.0_EN.pdf 69 | angle = -np.pi / 4 70 | R_z = np.array( 71 | [ 72 | [np.cos(angle), -np.sin(angle), 0], 73 | [np.sin(angle), np.cos(angle), 0], 74 | [0, 0, 1], 75 | ] 76 | ) 77 | T_a_g[:3, :3] = R_z @ T_a_g[:3, :3] 78 | 79 | object_pts = [T @ T_a_g for T in object_trans] 80 | object_pts = np.array(object_pts) 81 | object_points = object_pts[:, :3, 3] 82 | 83 | # Aruco marker detection with Cv2 on pixels 84 | aruco_dict = aruco.getPredefinedDictionary(aruco.DICT_4X4_50) 85 | parameters = aruco.DetectorParameters() 86 | detector = aruco.ArucoDetector(aruco_dict, parameters) 87 | image_points = [] 88 | invalid_indices = [] 89 | idx = 0 90 | for i in range(0, len(pixels), FRAME_FREQ): 91 | corners, ids, rejectedImgPoints = detector.detectMarkers(pixels[i]) 92 | if corners: 93 | center_img = corners[0].mean(axis=1).flatten() 94 | image_points.append(center_img) 95 | else: 96 | invalid_indices.append(idx) 97 | print("error") 98 | idx += 1 99 | 100 | # remove invalid indices from subsampled object points 101 | object_points = object_points[::FRAME_FREQ] 102 | object_points = [ 103 | object_points[i] for i in range(len(object_points)) if i not in invalid_indices 104 | ] 105 | 106 | # convert to numpy float arrays 107 | object_points = np.array(object_points).astype(np.float32) 108 | image_points = np.array(image_points).astype(np.float32) 109 | 110 | # get T_ci_b 111 | camera_matrix = intrinsics["camera_matrices"][f"cam_{cam_id}"] 112 | dist_coeffs = intrinsics["distortion_coefficients"][f"cam_{cam_id}"] 113 | ret, rvec, tvec = cv2.solvePnP( 114 | object_points, image_points, camera_matrix, dist_coeffs 115 | ) 116 | rot = cv2.Rodrigues(rvec)[0] 117 | T_ci_b[f"cam_{cam_id}"] = np.eye(4) 118 | T_ci_b[f"cam_{cam_id}"][:3, :3] = rot 119 | T_ci_b[f"cam_{cam_id}"][ 120 | :3, 3 121 | ] = tvec.flatten() # these are extrinsics (world in camera frame) 122 | 123 | # save intrinsics and extrinsics in a dictionary 124 | calibration_dict = {} 125 | for cam_id in CAM_IDS: 126 | calibration_dict[f"cam_{cam_id}"] = { 127 | "int": intrinsics["camera_matrices"][f"cam_{cam_id}"], 128 | "dist_coeff": intrinsics["distortion_coefficients"][f"cam_{cam_id}"], 129 | "ext": T_ci_b[f"cam_{cam_id}"], 130 | } 131 | np.save(PATH_SAVE_CALIB, calibration_dict) 132 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/convert_pkl_human_to_robot.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import numpy as np 4 | import pickle as pkl 5 | from pathlib import Path 6 | from scipy.spatial.transform import Rotation as R 7 | from scipy.ndimage import zoom 8 | 9 | from gripper_points import extrapoints, Tshift 10 | from utils import camera2pixelkey, rigid_transform_3D 11 | 12 | 13 | def resize_depth_image(depth_image, new_size): 14 | # Calculate zoom factors 15 | zoom_factors = ( 16 | new_size[0] / depth_image.shape[0], 17 | new_size[1] / depth_image.shape[1], 18 | ) 19 | # Use scipy's zoom function with order=1 for bilinear interpolation 20 | resized_depth = zoom(depth_image, zoom_factors, order=1) 21 | return resized_depth 22 | 23 | 24 | # Create the parser 25 | parser = argparse.ArgumentParser( 26 | description="Convert human key points in pkl file to 
robot key points" 27 | ) 28 | 29 | # Add the arguments 30 | parser.add_argument("--data_dir", type=str, help="Path to the data directory") 31 | parser.add_argument("--calib_path", type=str, help="Path to the calibration file") 32 | parser.add_argument("--task_name", type=str, help="List of task names") 33 | parser.add_argument( 34 | "--use_gt_depth", action="store_true", help="Use ground truth depth" 35 | ) 36 | 37 | args = parser.parse_args() 38 | DATA_DIR = Path(args.data_dir) 39 | CALIB_PATH = Path(args.calib_path) 40 | task_name = args.task_name 41 | use_gt_depth = args.use_gt_depth 42 | 43 | camera_indices = [1, 2] 44 | image_size = (640, 480) 45 | save_image_size = (256, 256) 46 | num_hand_points = 9 47 | index_finger_indices = [3, 4] 48 | thumb_indices = [7, 8] 49 | 50 | if use_gt_depth: 51 | task_name += "_gt_depth" 52 | 53 | # orientation of the robot at the 0th step 54 | robot_base_orientation = R.from_rotvec([np.pi, 0, 0]).as_matrix() 55 | 56 | DATA_DIR = DATA_DIR / "processed_data_pkl" 57 | SAVE_DIR = DATA_DIR / "expert_demos" / "franka_env" 58 | 59 | calibration_data = np.load(CALIB_PATH, allow_pickle=True).item() 60 | DATA = pkl.load(open(DATA_DIR / f"{task_name}.pkl", "rb")) 61 | 62 | # make sure SAVE_DIR exists 63 | SAVE_DIR.mkdir(parents=True, exist_ok=True) 64 | 65 | observations = DATA["observations"] 66 | 67 | # find all pairs of indices with first index as index_finger and second index as thumb 68 | index_finger_thumb_pairs = [ 69 | (idx1, idx2) for idx1 in index_finger_indices for idx2 in thumb_indices 70 | ] 71 | 72 | observations = [] 73 | for obs_idx, observation in enumerate(DATA["observations"]): 74 | print(f"Processing observation {obs_idx}") 75 | 76 | for cam_idx in camera_indices: 77 | camera_name = f"cam_{cam_idx}" 78 | pixel_key = camera2pixelkey[camera_name] 79 | 80 | pixels = observation[pixel_key] 81 | pixels = [cv2.resize(p, save_image_size) for p in pixels] 82 | observation[pixel_key] = np.array(pixels) 83 | 84 | if use_gt_depth: 85 | depth = observation[f"depth_{pixel_key}"] 86 | depth = [resize_depth_image(d, save_image_size) for d in depth] 87 | observation[f"depth_{pixel_key}"] = np.array(depth) 88 | 89 | human_tracks_3d = observation[f"human_tracks_3d_{pixel_key}"] 90 | 91 | hand_points = human_tracks_3d[:, :num_hand_points] 92 | object_points = human_tracks_3d[:, num_hand_points:] 93 | 94 | robot_points, gripper_states = [], [] 95 | human_poses = [] 96 | for idx, hand_point in enumerate(hand_points): 97 | index_finger_thumb_dists = [ 98 | np.linalg.norm(hand_point[idx1] - hand_point[idx2]) 99 | for idx1, idx2 in index_finger_thumb_pairs 100 | ] 101 | index_finger_thumb_mindist = np.min(index_finger_thumb_dists) 102 | index_finger_thumb_mindist_idx = np.argmin(index_finger_thumb_dists) 103 | index_finger_idx, thumb_idx = index_finger_thumb_pairs[ 104 | index_finger_thumb_mindist_idx 105 | ] 106 | robot_pos = (hand_point[index_finger_idx] + hand_point[thumb_idx]) / 2 107 | 108 | if idx == 0: 109 | robot_ori = robot_base_orientation 110 | base_hand_points = hand_point.copy() 111 | else: 112 | current_hand_points = hand_point.copy() 113 | # find the rotation matrix between the base hand points and the current hand points 114 | rot, pos = rigid_transform_3D(base_hand_points, current_hand_points) 115 | 116 | robot_ori = rot @ robot_base_orientation 117 | 118 | # store human pose 119 | human_poses.append( 120 | np.concatenate([robot_pos, R.from_matrix(robot_ori).as_rotvec()]) 121 | ) 122 | 123 | # pos and orientation of gripper in robot base frame 124 | 
T_g_b = np.eye(4) 125 | T_g_b[:3, :3] = robot_ori 126 | T_g_b[:3, 3] = robot_pos 127 | 128 | # shift the point 129 | T_g_b = T_g_b @ Tshift 130 | 131 | # add extra points 132 | points3d = [T_g_b[:3, 3]] 133 | gripper_state = -1 # -1: open, 1: closed 134 | for idx, Tp in enumerate(extrapoints): 135 | if index_finger_thumb_mindist < 0.07 and idx in [0, 1]: 136 | Tp = Tp.copy() 137 | Tp[1, 3] = 0.015 if idx == 0 else -0.015 138 | gripper_state = 1 139 | pt = T_g_b @ Tp 140 | pt = pt[:3, 3] 141 | points3d.append(pt) 142 | points3d = np.array(points3d) 143 | 144 | robot_points.append(points3d) 145 | gripper_states.append(gripper_state) 146 | 147 | observation[f"robot_tracks_3d_{pixel_key}"] = np.array(robot_points) 148 | observation[f"object_tracks_3d_{pixel_key}"] = np.array(object_points) 149 | observation[f"gripper_states"] = np.array(gripper_states) 150 | observation[f"human_poses"] = np.array(human_poses) 151 | 152 | # get 2d robot tracks from 3d robot tracks 153 | P = calibration_data[camera_name]["ext"] 154 | K = calibration_data[camera_name]["int"] 155 | D = calibration_data[camera_name]["dist_coeff"] 156 | r, t = P[:3, :3], P[:3, 3] 157 | r, _ = cv2.Rodrigues(r) 158 | # robot points 159 | robot_points_2d = [] 160 | for points3d in robot_points: 161 | points3d = points3d[:, :3] 162 | points2d = cv2.projectPoints(points3d, r, t, K, D)[0].squeeze() 163 | robot_points_2d.append(points2d) 164 | robot_points_2d = np.array(robot_points_2d) 165 | observation[f"robot_tracks_{pixel_key}"] = robot_points_2d 166 | # object points 167 | object_points_2d = [] 168 | for points3d in object_points: 169 | points3d = points3d[:, :3] 170 | points2d = cv2.projectPoints(points3d, r, t, K, D)[0].squeeze() 171 | object_points_2d.append(points2d) 172 | object_points_2d = np.array(object_points_2d) 173 | observation[f"object_tracks_{pixel_key}"] = object_points_2d 174 | 175 | observations.append(observation) 176 | 177 | DATA["observations"] = observations 178 | 179 | # save data 180 | pkl.dump(DATA, open(SAVE_DIR / f"{task_name}.pkl", "wb")) 181 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/gripper_points.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # For the Franka, while operating with end effector put EEF position 4 | # at the center of the gripper finger with is 12.7cm below arm tip. 
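# Tshift moves the EEF frame 12.7 cm along -z so the reference point sits between the
# fingertips. extrapoints lists fixed keypoint offsets (4x4 homogeneous transforms) in
# this shifted gripper frame: two fingertip points at z = 0.16 m and two rows of three
# points each at z = 0.08 m and z = 0.04 m. convert_pkl_human_to_robot.py maps them into
# the robot base frame via T_g_b @ Tshift @ Tp and narrows the fingertip offsets to
# +/-1.5 cm when the hand is detected as closed.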
5 | Tshift = np.array([[1, 0, 0, 0.0], [0, 1, 0, 0.0], [0, 0, 1, -0.127], [0, 0, 0, 1]]) 6 | 7 | extrapoints = [ 8 | # gripper points 9 | np.array( 10 | [ 11 | [1.0, 0.0, 0.0, 0.0], 12 | [0.0, 1.0, 0.0, 0.04], 13 | [0.0, 0.0, 1.0, 0.16], 14 | [0.0, 0.0, 0.0, 1.0], 15 | ] 16 | ), # 1 17 | np.array( 18 | [ 19 | [1.0, 0.0, 0.0, 0.0], 20 | [0.0, 1.0, 0.0, -0.04], 21 | [0.0, 0.0, 1.0, 0.16], 22 | [0.0, 0.0, 0.0, 1.0], 23 | ] 24 | ), # 2 25 | # First horizontal line 26 | np.array( 27 | [ 28 | [1.0, 0.0, 0.0, 0.0], 29 | [0.0, 1.0, 0.0, 0.0], 30 | [0.0, 0.0, 1.0, 0.08], 31 | [0.0, 0.0, 0.0, 1.0], 32 | ] 33 | ), # 3 34 | np.array( 35 | [ 36 | [1.0, 0.0, 0.0, 0.0], 37 | [0.0, 1.0, 0.0, 0.05], 38 | [0.0, 0.0, 1.0, 0.08], 39 | [0.0, 0.0, 0.0, 1.0], 40 | ] 41 | ), # 4 42 | np.array( 43 | [ 44 | [1.0, 0.0, 0.0, 0.0], 45 | [0.0, 1.0, 0.0, -0.05], 46 | [0.0, 0.0, 1.0, 0.08], 47 | [0.0, 0.0, 0.0, 1.0], 48 | ] 49 | ), # 5 50 | # Second horizontal line 51 | np.array( 52 | [ 53 | [1.0, 0.0, 0.0, 0.0], 54 | [0.0, 1.0, 0.0, 0.0], 55 | [0.0, 0.0, 1.0, 0.04], 56 | [0.0, 0.0, 0.0, 1.0], 57 | ] 58 | ), # 6 59 | np.array( 60 | [ 61 | [1.0, 0.0, 0.0, 0.0], 62 | [0.0, 1.0, 0.0, 0.05], 63 | [0.0, 0.0, 1.0, 0.04], 64 | [0.0, 0.0, 0.0, 1.0], 65 | ] 66 | ), # 7 67 | np.array( 68 | [ 69 | [1.0, 0.0, 0.0, 0.0], 70 | [0.0, 1.0, 0.0, -0.05], 71 | [0.0, 0.0, 1.0, 0.04], 72 | [0.0, 0.0, 0.0, 1.0], 73 | ] 74 | ), # 8 75 | ] 76 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/label_points.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Script to label object points.\n", 8 | "\n", 9 | "Adapted from P3PO - [link](https://github.com/mlevy2525/P3PO/blob/main/p3po/data_generation/label_points.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Run this the first time you use this notebook for point annotation\n", 19 | "# This will install the necessary dependencies\n", 20 | "!pip install ipywidgets ipycanvas" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 7, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import cv2\n", 30 | "import pickle\n", 31 | "from pathlib import Path\n", 32 | "\n", 33 | "#TODO: Set the task name here -- this will be used to save the output\n", 34 | "task_name = \"close_oven\"\n", 35 | "object_name = \"objects\"\n", 36 | "\n", 37 | "# If the image that shows at the bottom is bgr set original_bgr to True\n", 38 | "pickle_path = f\"/path/to/data/processed_data_pkl/{task_name}.pkl\"\n", 39 | "traj_idx = 0\n", 40 | "original_bgr = True\n", 41 | "\n", 42 | "#TODO: If its hard to see the image, you can increase the size_multiplier, this won't affect the selected coordinates\n", 43 | "size_multiplier = 1\n", 44 | "\n", 45 | "coordinates_path = f\"../../../coordinates/{task_name}\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 8, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "%gui asyncio\n", 55 | "\n", 56 | "import os\n", 57 | "from PIL import Image\n", 58 | "\n", 59 | "import numpy as np\n", 60 | "from IPython.display import display, Javascript\n", 61 | "import ipywidgets as widgets\n", 62 | "from ipycanvas import Canvas, hold_canvas\n", 63 | "import pickle\n", 64 | "\n", 65 | "import io\n", 66 | "import asyncio\n", 67 | "import 
logging\n", 68 | "\n", 69 | "# Define an async function to wait for button click\n", 70 | "async def wait_for_click(button):\n", 71 | " # Create a future object\n", 72 | " future = asyncio.Future()\n", 73 | " # Define the click event handler\n", 74 | " def on_button_clicked(b):\n", 75 | " future.set_result(None)\n", 76 | " # Attach the event handler to the button\n", 77 | " button.on_click(on_button_clicked)\n", 78 | " # Wait until the future is set\n", 79 | " await future\n", 80 | "\n", 81 | "class Points():\n", 82 | " def __init__(self, pixel_key, img, coordinates_path, size_multiplier=1):\n", 83 | " logging.getLogger().setLevel(logging.DEBUG)\n", 84 | " logging.info(\"Starting the Points class\")\n", 85 | " self.img = img\n", 86 | " self.size_multiplier = size_multiplier\n", 87 | " self.coordinates_path = coordinates_path\n", 88 | " self.pixel_key = pixel_key\n", 89 | "\n", 90 | " # Save the image to a bytes buffer\n", 91 | " image = Image.fromarray(self.img)\n", 92 | " size = img.shape\n", 93 | " image = image.resize((size[1] * self.size_multiplier, size[0] * self.size_multiplier))\n", 94 | " buffer = io.BytesIO()\n", 95 | " image.save(buffer, format='PNG')\n", 96 | " buffer.seek(0)\n", 97 | "\n", 98 | " # Create an IPyWidgets Image widget\n", 99 | " self.canvas = Canvas(width=size[1] * self.size_multiplier, height=size[0] * self.size_multiplier)\n", 100 | " # Define the size of each cell\n", 101 | "\n", 102 | " self.canvas.put_image_data(np.array(image), 0, 0)\n", 103 | "\n", 104 | " # Display coordinates\n", 105 | " coords_label = widgets.Label(value=\"Click on the image to select the coordinates\")\n", 106 | "\n", 107 | " # Define the click event handler\n", 108 | " self.coords = []\n", 109 | " def on_click(x, y):\n", 110 | " coords_label.value = f\"Coordinates: ({x}, {y})\"\n", 111 | " self.coords.append((0, x, y))\n", 112 | "\n", 113 | " with hold_canvas(self.canvas):\n", 114 | " self.canvas.put_image_data(np.array(image), 0, 0) # Redraw the original image\n", 115 | "\n", 116 | " self.canvas.fill_style = 'red'\n", 117 | " for coord in self.coords:\n", 118 | " x, y = coord[1] // self.size_multiplier, coord[2] // self.size_multiplier\n", 119 | " self.canvas.fill_circle(x, y, 2)\n", 120 | "\n", 121 | " # Connect the click event to the handler\n", 122 | " self.canvas.on_mouse_down(on_click)\n", 123 | "\n", 124 | " self.button = widgets.Button(description=\"Save Points\")\n", 125 | "\n", 126 | " # Display the widgets\n", 127 | " self.vbox = widgets.VBox([self.canvas, coords_label, self.button])\n", 128 | "\n", 129 | " # # Display the widget\n", 130 | " display(self.vbox)\n", 131 | "\n", 132 | " def on_done(self):\n", 133 | " logging.info(\"saving\")\n", 134 | " Path(self.coordinates_path + \"/coords/\").mkdir(parents=True, exist_ok=True)\n", 135 | " with open(self.coordinates_path + \"/coords/\" + f\"{self.pixel_key}_{object_name}\" + \".pkl\", 'wb') as f:\n", 136 | " try:\n", 137 | " pickle.dump(self.coords, f)\n", 138 | " except Exception as e:\n", 139 | " logging.info(e)\n", 140 | " Path(self.coordinates_path + \"/images/\").mkdir(parents=True, exist_ok=True)\n", 141 | " with open(self.coordinates_path + \"/images/\" + f\"{self.pixel_key}\" + \".png\", 'wb') as f:\n", 142 | " try:\n", 143 | " image = Image.fromarray(self.img)\n", 144 | " image.save(f)\n", 145 | " except Exception as e:\n", 146 | " logging.info(e)\n", 147 | " logging.info(\"saved\")" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | 
"source": [ 156 | "# NOTE: Label points for each pixel key. Make sure the order \n", 157 | "# of points is the same across pixel keys.\n", 158 | "pixel_key = \"pixels2\"\n", 159 | " \n", 160 | "with open(pickle_path, 'rb') as f:\n", 161 | " data = pickle.load(f)\n", 162 | "img = data['observations'][traj_idx][pixel_key][0]\n", 163 | "use_video = False\n", 164 | "if original_bgr:\n", 165 | " img = img[:,:,::-1] \n", 166 | "\n", 167 | "async def f():\n", 168 | " point = Points(pixel_key, img, coordinates_path, size_multiplier)\n", 169 | " x = await wait_for_click(point.button)\n", 170 | " point.vbox.close()\n", 171 | " point.canvas.close()\n", 172 | " point.on_done()\n", 173 | "asyncio.ensure_future(f())" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "p3po", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.16" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/save_video.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | import imageio 4 | from pathlib import Path 5 | import cv2 6 | 7 | DATA_DIR = Path("/path/to/expert_demos") 8 | TASK_NAME = "close_oven" 9 | plot_pts = True 10 | 11 | DATA_PATH = DATA_DIR / f"{TASK_NAME}.pkl" 12 | SAVE_DIR = Path(f"./videos/{TASK_NAME}") 13 | pixel_keys = ["pixels1", "pixels2"] 14 | original_image_size = (640, 480) 15 | k = 1 # number of track points to plot per frame 16 | traj_indices = None 17 | 18 | SAVE_DIR.mkdir(parents=True, exist_ok=True) 19 | 20 | # Read data 21 | with open(DATA_PATH, "rb") as f: 22 | data = pkl.load(f) 23 | 24 | if traj_indices is None: 25 | traj_indices = [i for i in range(len(data["observations"]))] 26 | 27 | for traj_idx in traj_indices: 28 | print(f"Processing traj_idx: {traj_idx}") 29 | for pixel_key in pixel_keys: 30 | point_track_key = ( 31 | f"robot_tracks_{pixel_key}" 32 | if "human" not in TASK_NAME 33 | else f"human_tracks_{pixel_key}" 34 | ) 35 | object_track_key = f"object_tracks_{pixel_key}" 36 | 37 | # Extract images and point tracks 38 | frames = data["observations"][traj_idx][pixel_key] 39 | frames = np.array(frames) 40 | 41 | if plot_pts and pixel_key != "pixels51": 42 | point_tracks = data["observations"][traj_idx][point_track_key] 43 | point_tracks = np.array(point_tracks) 44 | object_tracks = data["observations"][traj_idx][object_track_key] 45 | object_tracks = np.array(object_tracks) 46 | point_tracks = np.concatenate([point_tracks, object_tracks], axis=1) 47 | 48 | # Color for each point 49 | num_points = point_tracks.shape[1] 50 | colors = np.zeros((num_points, 3)) 51 | third = num_points // 3 52 | colors[:third, 0] = 255 53 | colors[third : 2 * third, 1] = 255 54 | colors[2 * third :, 2] = 255 55 | 56 | save_frames = [] 57 | for i, frame in enumerate(frames): 58 | frame = frame[..., [2, 1, 0]].copy() 59 | if plot_pts and pixel_key != "pixels51": 60 | for j, points in enumerate(point_tracks[max(0, i - k) : i + 
1]): 61 | # points = points[3:4] 62 | for l, point in enumerate(points): 63 | point = point.astype(int) 64 | point[0] = int( 65 | point[0] * frame.shape[1] / original_image_size[0] 66 | ) 67 | point[1] = int( 68 | point[1] * frame.shape[0] / original_image_size[1] 69 | ) 70 | frame = cv2.circle( 71 | frame, tuple(point), 2, colors[l].tolist(), -1 72 | ) 73 | save_frames.append(frame) 74 | 75 | # Save the video 76 | save_frames = np.array(save_frames).astype(np.uint8) 77 | save_path = SAVE_DIR / f"{TASK_NAME}_traj{traj_idx}_{pixel_key}.mp4" 78 | imageio.mimwrite(save_path, save_frames, fps=20) 79 | -------------------------------------------------------------------------------- /point_policy/robot_utils/franka/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | camera2pixelkey = { 5 | "cam_1": "pixels1", 6 | "cam_2": "pixels2", 7 | "cam_51": "pixels51", 8 | } 9 | pixelkey2camera = {v: k for k, v in camera2pixelkey.items()} 10 | 11 | 12 | def pixel2d_to_3d_torch(points2d, depths, intrinsic_matrix, extrinsic_matrix): 13 | intrinsic_matrix = torch.tensor(intrinsic_matrix).float().to(depths.device) 14 | extrinsic_matrix = torch.tensor(extrinsic_matrix).float().to(depths.device) 15 | fx = intrinsic_matrix[0, 0] 16 | fy = intrinsic_matrix[1, 1] 17 | cx = intrinsic_matrix[0, 2] 18 | cy = intrinsic_matrix[1, 2] 19 | x = (points2d[:, 0] - cx) / fx 20 | y = (points2d[:, 1] - cy) / fy 21 | points3d = torch.stack((x * depths, y * depths, depths), dim=1) # in camera frame 22 | points3d = torch.cat( 23 | (points3d, torch.ones((len(points2d), 1)).to(depths.device)), dim=1 24 | ) 25 | points3d = (torch.linalg.inv(extrinsic_matrix) @ points3d.T).T # world frame 26 | return points3d[..., :3] 27 | 28 | 29 | def pixel2d_to_3d(points2d, depths, intrinsic_matrix, extrinsic_matrix): 30 | points2d = np.array(points2d) 31 | fx = intrinsic_matrix[0, 0] 32 | fy = intrinsic_matrix[1, 1] 33 | cx = intrinsic_matrix[0, 2] 34 | cy = intrinsic_matrix[1, 2] 35 | x = (points2d[:, 0] - cx) / fx 36 | y = (points2d[:, 1] - cy) / fy 37 | points_3d = np.column_stack((x * depths, y * depths, depths)) # in camera frame 38 | points_3d = np.concatenate([points_3d, np.ones((len(points2d), 1))], axis=1) 39 | points_3d = (np.linalg.inv(extrinsic_matrix) @ points_3d.T).T # world frame 40 | return points_3d[..., :3] 41 | 42 | 43 | def pixel3d_to_2d(points3d, intrinsic_matrix, camera_projection_matrix): 44 | points3d = np.array(points3d) 45 | points3d = np.concatenate([points3d, np.ones((len(points3d), 1))], axis=1) 46 | points3d = (camera_projection_matrix @ points3d.T).T # camera frame 47 | depth = points3d[:, 2] 48 | points2d = (intrinsic_matrix @ points3d.T).T 49 | points2d = points2d / points2d[:, 2][:, None] 50 | return points2d[..., :2], depth 51 | 52 | 53 | def triangulate_points(P, points): 54 | """ 55 | Triangulate a batch of points from a variable number of camera views. 
56 | 57 | Parameters: 58 | P: list of 3x4 projection matrices for each camera (currently world2camera transform) 59 | points: list of Nx2 arrays of normalized image coordinates for each camera 60 | 61 | Returns: 62 | Nx4 array of homogeneous 3D points 63 | """ 64 | num_views = len(P) 65 | assert num_views > 1, "At least 2 cameras are required for triangulation" 66 | 67 | num_points = points[0].shape[0] 68 | A = np.zeros((num_points, num_views * 2, 4)) 69 | 70 | for idx in range(num_views): 71 | # Set up the linear system for each point 72 | A[:, idx * 2] = points[idx][:, 0, np.newaxis] * P[idx][2] - P[idx][0] 73 | A[:, idx * 2 + 1] = points[idx][:, 1, np.newaxis] * P[idx][2] - P[idx][1] 74 | 75 | # Solve the system using SVD 76 | _, _, Vt = np.linalg.svd(A) 77 | X = Vt[:, -1, :] 78 | 79 | # Normalize the homogeneous coordinates 80 | X = X / X[:, 3:] 81 | 82 | return X 83 | 84 | 85 | def rigid_transform_3D(A, B): 86 | assert A.shape == B.shape 87 | 88 | num_rows, num_cols = A.shape 89 | if num_cols != 3: 90 | raise Exception(f"matrix A is not Nx3, it is {num_rows}x{num_cols}") 91 | 92 | num_rows, num_cols = B.shape 93 | if num_cols != 3: 94 | raise Exception(f"matrix B is not Nx3, it is {num_rows}x{num_cols}") 95 | 96 | # find mean column wise 97 | centroid_A = np.mean(A, axis=0) 98 | centroid_B = np.mean(B, axis=0) 99 | 100 | # subtract mean 101 | Am = A - centroid_A 102 | Bm = B - centroid_B 103 | 104 | H = Am.T @ Bm 105 | 106 | # find rotation 107 | U, S, Vt = np.linalg.svd(H) 108 | R = Vt.T @ U.T 109 | 110 | # special reflection case 111 | if np.linalg.det(R) < 0: 112 | print("det(R) < R, reflection detected!, correcting for it ...") 113 | Vt[2, :] *= -1 114 | R = Vt.T @ U.T 115 | 116 | t = -R @ centroid_A.T + centroid_B.T 117 | 118 | return R, t 119 | 120 | 121 | def rotation_6d_to_matrix(d6: np.ndarray) -> np.ndarray: 122 | """ 123 | Converts 6D rotation representation to rotation matrix 124 | using Gram-Schmidt orthogonalization. 125 | 126 | Args: 127 | d6: 6D rotation representation, of shape (..., 6) 128 | 129 | Returns: 130 | Batch of rotation matrices of shape (..., 3, 3) 131 | """ 132 | a1, a2 = d6[..., :3], d6[..., 3:] 133 | 134 | b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True) 135 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1 136 | b2 = b2 / np.linalg.norm(b2, axis=-1, keepdims=True) 137 | b3 = np.cross(b1, b2, axis=-1) 138 | 139 | return np.stack((b1, b2, b3), axis=-2) 140 | 141 | 142 | def matrix_to_rotation_6d(matrix: np.ndarray) -> np.ndarray: 143 | """ 144 | Converts rotation matrices to 6D rotation representation 145 | by dropping the last row. 
146 | 147 | Args: 148 | matrix: Batch of rotation matrices of shape (..., 3, 3) 149 | 150 | Returns: 151 | 6D rotation representation, of shape (..., 6) 152 | """ 153 | batch_dim = matrix.shape[:-2] 154 | return matrix[..., :2, :].reshape(batch_dim + (6,)) 155 | -------------------------------------------------------------------------------- /point_policy/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | import os 5 | 6 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 7 | os.environ["MUJOCO_GL"] = "egl" 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | from pathlib import Path 10 | 11 | import hydra 12 | import torch 13 | import numpy as np 14 | 15 | import utils 16 | from logger import Logger 17 | from replay_buffer import make_expert_replay_loader 18 | from video import VideoRecorder 19 | 20 | warnings.filterwarnings("ignore", category=DeprecationWarning) 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def make_agent(obs_spec, action_spec, cfg): 25 | obs_shape = {} 26 | for key in cfg.suite.pixel_keys: 27 | obs_shape[key] = obs_spec[key].shape 28 | if cfg.use_proprio: 29 | obs_shape[cfg.suite.proprio_key] = obs_spec[cfg.suite.proprio_key].shape 30 | obs_shape[cfg.suite.feature_key] = obs_spec[cfg.suite.feature_key].shape 31 | cfg.agent.obs_shape = obs_shape 32 | cfg.agent.action_shape = action_spec.shape 33 | return hydra.utils.instantiate(cfg.agent) 34 | 35 | 36 | class WorkspaceIL: 37 | def __init__(self, cfg): 38 | self.work_dir = Path.cwd() 39 | print(f"workspace: {self.work_dir}") 40 | 41 | self.cfg = cfg 42 | utils.set_seed_everywhere(cfg.seed) 43 | self.device = torch.device(cfg.device) 44 | 45 | # load data 46 | dataset_iterable = hydra.utils.call(self.cfg.expert_dataset) 47 | self.expert_replay_loader = make_expert_replay_loader( 48 | dataset_iterable, self.cfg.batch_size 49 | ) 50 | self.expert_replay_iter = iter(self.expert_replay_loader) 51 | self.stats = self.expert_replay_loader.dataset.stats 52 | 53 | # create logger 54 | self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb) 55 | # create envs 56 | self.cfg.suite.task_make_fn.max_episode_len = ( 57 | self.expert_replay_loader.dataset._max_episode_len 58 | ) 59 | self.cfg.suite.task_make_fn.max_state_dim = ( 60 | self.expert_replay_loader.dataset._max_state_dim 61 | ) 62 | if self.cfg.suite.name == "dmc": 63 | self.cfg.suite.task_make_fn.max_action_dim = ( 64 | self.expert_replay_loader.dataset._max_action_dim 65 | ) 66 | 67 | # load points cfg if using object points 68 | # try-except since baku doesn't have use_object_points 69 | try: 70 | if self.cfg.suite.use_object_points: 71 | import yaml 72 | 73 | cfg_path = f"{cfg.root_dir}/point_policy/cfgs/suite/points_cfg.yaml" 74 | with open(cfg_path) as stream: 75 | try: 76 | points_cfg = yaml.safe_load(stream) 77 | except yaml.YAMLError as exc: 78 | print(exc) 79 | root_dir, dift_path, cotracker_checkpoint = ( 80 | points_cfg["root_dir"], 81 | points_cfg["dift_path"], 82 | points_cfg["cotracker_checkpoint"], 83 | ) 84 | points_cfg["dift_path"] = f"{root_dir}/{dift_path}" 85 | points_cfg[ 86 | "cotracker_checkpoint" 87 | ] = f"{root_dir}/{cotracker_checkpoint}" 88 | self.cfg.suite.task_make_fn.points_cfg = points_cfg 89 | except: 90 | pass 91 | 92 | self.env, self.task_descriptions = hydra.utils.call(self.cfg.suite.task_make_fn) 93 | 94 | # create agent 95 | self.agent = make_agent( 96 | self.env[0].observation_spec(), self.env[0].action_spec(), cfg 97 | ) 98 | 99 | 
self.envs_till_idx = self.expert_replay_loader.dataset.envs_till_idx 100 | 101 | self.timer = utils.Timer() 102 | self._global_step = 0 103 | self._global_episode = 0 104 | 105 | self.video_recorder = VideoRecorder( 106 | self.work_dir if self.cfg.save_video else None 107 | ) 108 | 109 | @property 110 | def global_step(self): 111 | return self._global_step 112 | 113 | @property 114 | def global_episode(self): 115 | return self._global_episode 116 | 117 | @property 118 | def global_frame(self): 119 | return self.global_step * self.cfg.suite.action_repeat 120 | 121 | def eval(self): 122 | self.agent.train(False) 123 | episode_rewards = [] 124 | successes = [] 125 | 126 | num_envs = self.envs_till_idx 127 | 128 | for env_idx in range(num_envs): 129 | print(f"evaluating env {env_idx}") 130 | episode, total_reward = 0, 0 131 | eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes) 132 | success = [] 133 | 134 | while eval_until_episode(episode): 135 | time_step = self.env[env_idx].reset() 136 | self.agent.buffer_reset() 137 | step = 0 138 | 139 | if episode == 0: 140 | self.video_recorder.init(self.env[env_idx], enabled=True) 141 | 142 | # plot obs with cv2 143 | while not time_step.last(): 144 | with torch.no_grad(), utils.eval_mode(self.agent): 145 | action = self.agent.act( 146 | time_step.observation, 147 | self.stats, 148 | step, 149 | self.global_step, 150 | eval_mode=True, 151 | ) 152 | time_step = self.env[env_idx].step(action) 153 | self.video_recorder.record(self.env[env_idx]) 154 | total_reward += time_step.reward 155 | step += 1 156 | 157 | episode += 1 158 | success.append(time_step.observation["goal_achieved"]) 159 | self.video_recorder.save(f"{self.global_step}_env{env_idx}.mp4") 160 | episode_rewards.append(total_reward / episode) 161 | successes.append(np.mean(success)) 162 | 163 | for _ in range(len(self.env) - num_envs): 164 | episode_rewards.append(0) 165 | successes.append(0) 166 | 167 | with self.logger.log_and_dump_ctx(self.global_step, ty="eval") as log: 168 | for env_idx, reward in enumerate(episode_rewards): 169 | log(f"episode_reward_env{env_idx}", reward) 170 | log(f"success_env{env_idx}", successes[env_idx]) 171 | log("episode_reward", np.mean(episode_rewards[:num_envs])) 172 | log("success", np.mean(successes)) 173 | log("episode_length", step * self.cfg.suite.action_repeat / episode) 174 | log("episode", self.global_episode) 175 | log("step", self.global_step) 176 | 177 | self.agent.train(True) 178 | 179 | def train(self): 180 | # predicates 181 | train_until_step = utils.Until(self.cfg.suite.num_train_steps, 1) 182 | log_every_step = utils.Every(self.cfg.suite.log_every_steps, 1) 183 | eval_every_step = utils.Every(self.cfg.suite.eval_every_steps, 1) 184 | save_every_step = utils.Every(self.cfg.suite.save_every_steps, 1) 185 | 186 | metrics = None 187 | while train_until_step(self.global_step): 188 | # try to evaluate 189 | if ( 190 | self.cfg.eval 191 | and eval_every_step(self.global_step) 192 | and self.global_step > 0 193 | ): 194 | self.logger.log( 195 | "eval_total_time", self.timer.total_time(), self.global_frame 196 | ) 197 | self.eval() 198 | 199 | # update 200 | metrics = self.agent.update( 201 | self.expert_replay_iter, 202 | self.global_step, 203 | ) 204 | self.logger.log_metrics(metrics, self.global_frame, ty="train") 205 | 206 | # log 207 | if log_every_step(self.global_step): 208 | elapsed_time, total_time = self.timer.reset() 209 | with self.logger.log_and_dump_ctx(self.global_frame, ty="train") as log: 210 | log("total_time", 
total_time) 211 | log("actor_loss", metrics["actor_loss"]) 212 | log("step", self.global_step) 213 | 214 | # save snapshot 215 | if save_every_step(self.global_step): 216 | self.save_snapshot() 217 | 218 | self._global_step += 1 219 | 220 | def save_snapshot(self): 221 | snapshot_dir = self.work_dir / "snapshot" 222 | snapshot_dir.mkdir(exist_ok=True) 223 | snapshot = snapshot_dir / f"{self.global_step}.pt" 224 | self.agent.clear_buffers() 225 | keys_to_save = ["timer", "_global_step", "_global_episode", "stats"] 226 | payload = {k: self.__dict__[k] for k in keys_to_save} 227 | payload.update(self.agent.save_snapshot()) 228 | with snapshot.open("wb") as f: 229 | torch.save(payload, f) 230 | 231 | self.agent.buffer_reset() 232 | 233 | def load_snapshot(self, snapshots): 234 | # bc 235 | with snapshots["bc"].open("rb") as f: 236 | payload = torch.load(f) 237 | agent_payload = {} 238 | for k, v in payload.items(): 239 | if k not in self.__dict__: 240 | agent_payload[k] = v 241 | self.agent.load_snapshot(agent_payload, eval=False) 242 | 243 | 244 | @hydra.main(config_path="cfgs", config_name="config") 245 | def main(cfg): 246 | from train import WorkspaceIL as W 247 | 248 | workspace = W(cfg) 249 | 250 | # Load weights 251 | if cfg.load_bc: 252 | snapshots = {} 253 | bc_snapshot = Path(cfg.bc_weight) 254 | if not bc_snapshot.exists(): 255 | raise FileNotFoundError(f"bc weight not found: {bc_snapshot}") 256 | print(f"loading bc weight: {bc_snapshot}") 257 | snapshots["bc"] = bc_snapshot 258 | workspace.load_snapshot(snapshots) 259 | 260 | workspace.train() 261 | 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /point_policy/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from omegaconf import OmegaConf 10 | from torch import distributions as pyd 11 | from torch.distributions.utils import _standard_normal 12 | from scipy import linalg 13 | 14 | 15 | class eval_mode: 16 | def __init__(self, *models): 17 | self.models = models 18 | 19 | def __enter__(self): 20 | self.prev_states = [] 21 | for model in self.models: 22 | self.prev_states.append(model.training) 23 | model.train(False) 24 | 25 | def __exit__(self, *args): 26 | for model, state in zip(self.models, self.prev_states): 27 | model.train(state) 28 | return False 29 | 30 | 31 | def set_seed_everywhere(seed): 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | 38 | 39 | def soft_update_params(net, target_net, tau): 40 | for param, target_param in zip(net.parameters(), target_net.parameters()): 41 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 42 | 43 | 44 | # def to_torch(xs, device): 45 | # return tuple(torch.as_tensor(x, device=device) for x in xs) 46 | def to_torch(xs, device): 47 | for key, value in xs.items(): 48 | xs[key] = torch.as_tensor(value, device=device) 49 | return xs 50 | 51 | 52 | def weight_init(m): 53 | if isinstance(m, nn.Linear): 54 | nn.init.orthogonal_(m.weight.data) 55 | if hasattr(m.bias, "data"): 56 | m.bias.data.fill_(0.0) 57 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 58 | gain = nn.init.calculate_gain("relu") 59 | nn.init.orthogonal_(m.weight.data, gain) 60 | if 
hasattr(m.bias, "data"): 61 | m.bias.data.fill_(0.0) 62 | 63 | 64 | class Until: 65 | def __init__(self, until, action_repeat=1): 66 | self._until = until 67 | self._action_repeat = action_repeat 68 | 69 | def __call__(self, step): 70 | if self._until is None: 71 | return True 72 | until = self._until // self._action_repeat 73 | return step < until 74 | 75 | 76 | class Every: 77 | def __init__(self, every, action_repeat=1): 78 | self._every = every 79 | self._action_repeat = action_repeat 80 | 81 | def __call__(self, step): 82 | if self._every is None: 83 | return False 84 | every = self._every // self._action_repeat 85 | if step % every == 0: 86 | return True 87 | return False 88 | 89 | 90 | class Timer: 91 | def __init__(self): 92 | self._start_time = time.time() 93 | self._last_time = time.time() 94 | # Keep track of evaluation time so that total time only includes train time 95 | self._eval_start_time = 0 96 | self._eval_time = 0 97 | self._eval_flag = False 98 | 99 | def reset(self): 100 | elapsed_time = time.time() - self._last_time 101 | self._last_time = time.time() 102 | total_time = time.time() - self._start_time - self._eval_time 103 | return elapsed_time, total_time 104 | 105 | def eval(self): 106 | if not self._eval_flag: 107 | self._eval_flag = True 108 | self._eval_start_time = time.time() 109 | else: 110 | self._eval_time += time.time() - self._eval_start_time 111 | self._eval_flag = False 112 | self._eval_start_time = 0 113 | 114 | def total_time(self): 115 | return time.time() - self._start_time - self._eval_time 116 | 117 | 118 | class TruncatedNormal(pyd.Normal): 119 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 120 | super().__init__(loc, scale, validate_args=False) 121 | self.low = low 122 | self.high = high 123 | self.eps = eps 124 | 125 | def _clamp(self, x): 126 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 127 | x = x - x.detach() + clamped_x.detach() 128 | return x 129 | 130 | def sample(self, clip=None, sample_shape=torch.Size()): 131 | shape = self._extended_shape(sample_shape) 132 | eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) 133 | eps *= self.scale 134 | if clip is not None: 135 | eps = torch.clamp(eps, -clip, clip) 136 | x = self.loc + eps 137 | return self._clamp(x) 138 | 139 | 140 | class Normal(pyd.Normal): 141 | def __init__(self, loc, scale, eps=1e-6): 142 | super().__init__(loc, scale, validate_args=False) 143 | self.eps = eps 144 | 145 | def sample(self, clip=None, sample_shape=torch.Size()): 146 | shape = self._extended_shape(sample_shape) 147 | eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) 148 | eps *= self.scale 149 | x = self.loc + eps 150 | return x 151 | 152 | 153 | def schedule(schdl, step): 154 | try: 155 | return float(schdl) 156 | except ValueError: 157 | match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) 158 | if match: 159 | init, final, duration = [float(g) for g in match.groups()] 160 | mix = np.clip(step / duration, 0.0, 1.0) 161 | return (1.0 - mix) * init + mix * final 162 | match = re.match(r"step_linear\((.+),(.+),(.+),(.+),(.+)\)", schdl) 163 | if match: 164 | init, final1, duration1, final2, duration2 = [ 165 | float(g) for g in match.groups() 166 | ] 167 | if step <= duration1: 168 | mix = np.clip(step / duration1, 0.0, 1.0) 169 | return (1.0 - mix) * init + mix * final1 170 | else: 171 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 172 | return (1.0 - mix) * final1 + mix * final2 173 | raise 
NotImplementedError(schdl) 174 | 175 | 176 | class RandomShiftsAug(nn.Module): 177 | def __init__(self, pad): 178 | super().__init__() 179 | self.pad = pad 180 | 181 | def forward(self, x): 182 | n, c, h, w = x.size() 183 | assert h == w 184 | padding = tuple([self.pad] * 4) 185 | x = F.pad(x, padding, "replicate") 186 | eps = 1.0 / (h + 2 * self.pad) 187 | arange = torch.linspace( 188 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 189 | )[:h] 190 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 191 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 192 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 193 | 194 | shift = torch.randint( 195 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype 196 | ) 197 | shift *= 2.0 / (h + 2 * self.pad) 198 | 199 | grid = base_grid + shift 200 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 201 | 202 | 203 | class TorchRunningMeanStd: 204 | def __init__(self, epsilon=1e-4, shape=(), device=None): 205 | self.mean = torch.zeros(shape, device=device) 206 | self.var = torch.ones(shape, device=device) 207 | self.count = epsilon 208 | 209 | def update(self, x): 210 | with torch.no_grad(): 211 | batch_mean = torch.mean(x, axis=0) 212 | batch_var = torch.var(x, axis=0) 213 | batch_count = x.shape[0] 214 | self.update_from_moments(batch_mean, batch_var, batch_count) 215 | 216 | def update_from_moments(self, batch_mean, batch_var, batch_count): 217 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 218 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 219 | ) 220 | 221 | @property 222 | def std(self): 223 | return torch.sqrt(self.var) 224 | 225 | 226 | def update_mean_var_count_from_moments( 227 | mean, var, count, batch_mean, batch_var, batch_count 228 | ): 229 | delta = batch_mean - mean 230 | tot_count = count + batch_count 231 | 232 | new_mean = mean + delta * batch_count / tot_count 233 | m_a = var * count 234 | m_b = batch_var * batch_count 235 | M2 = m_a + m_b + torch.pow(delta, 2) * count * batch_count / tot_count 236 | new_var = M2 / tot_count 237 | new_count = tot_count 238 | 239 | return new_mean, new_var, new_count 240 | 241 | 242 | def batch_norm_to_group_norm(layer): 243 | """Iterates over a whole model (or layer of a model) and replaces every batch norm 2D with a group norm 244 | 245 | Args: 246 | layer: model or one layer of a model like resnet34.layer1 or Sequential(), ... 247 | """ 248 | 249 | # num_channels: num_groups 250 | GROUP_NORM_LOOKUP = { 251 | 16: 2, # -> channels per group: 8 252 | 32: 4, # -> channels per group: 8 253 | 64: 8, # -> channels per group: 8 254 | 128: 8, # -> channels per group: 16 255 | 256: 16, # -> channels per group: 16 256 | 512: 32, # -> channels per group: 16 257 | 1024: 32, # -> channels per group: 32 258 | 2048: 32, # -> channels per group: 64 259 | } 260 | 261 | for name, module in layer.named_modules(): 262 | if name: 263 | try: 264 | # name might be something like: model.layer1.sequential.0.conv1 --> this won't work, so catch this case below 265 | sub_layer = getattr(layer, name) 266 | if isinstance(sub_layer, torch.nn.BatchNorm2d): 267 | num_channels = sub_layer.num_features 268 | # first level of current layer or model contains a batch norm --> replacing.
269 | layer._modules[name] = torch.nn.GroupNorm( 270 | GROUP_NORM_LOOKUP[num_channels], num_channels 271 | ) 272 | except AttributeError: 273 | # go deeper: set name to layer1, getattr will return layer1 --> call this func again 274 | name = name.split(".")[0] 275 | sub_layer = getattr(layer, name) 276 | sub_layer = batch_norm_to_group_norm(sub_layer) 277 | layer.__setattr__(name=name, value=sub_layer) 278 | return layer 279 | -------------------------------------------------------------------------------- /point_policy/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | 5 | 6 | class VideoRecorder: 7 | def __init__(self, root_dir, render_size=256, fps=20): 8 | if root_dir is not None: 9 | self.save_dir = root_dir / "eval_video" 10 | self.save_dir.mkdir(exist_ok=True) 11 | else: 12 | self.save_dir = None 13 | 14 | self.render_size = render_size 15 | self.fps = fps 16 | self.frames = [] 17 | 18 | def init(self, env, enabled=True): 19 | self.frames = [] 20 | self.enabled = self.save_dir is not None and enabled 21 | self.record(env) 22 | 23 | def record(self, env): 24 | if self.enabled: 25 | if hasattr(env, "physics"): 26 | frame = env.physics.render( 27 | height=self.render_size, width=self.render_size, camera_id=0 28 | ) 29 | else: 30 | frame = env.render() 31 | self.frames.append(frame) 32 | 33 | def save(self, file_name): 34 | if self.enabled: 35 | path = self.save_dir / file_name 36 | imageio.mimsave(str(path), self.frames, fps=self.fps) 37 | 38 | 39 | class TrainVideoRecorder: 40 | def __init__(self, root_dir, render_size=256, fps=20): 41 | if root_dir is not None: 42 | self.save_dir = root_dir / "train_video" 43 | self.save_dir.mkdir(exist_ok=True) 44 | else: 45 | self.save_dir = None 46 | 47 | self.render_size = render_size 48 | self.fps = fps 49 | self.frames = [] 50 | 51 | def init(self, obs, enabled=True): 52 | self.frames = [] 53 | self.enabled = self.save_dir is not None and enabled 54 | self.record(obs) 55 | 56 | def record(self, obs): 57 | if self.enabled: 58 | frame = cv2.resize( 59 | obs[-3:].transpose(1, 2, 0), 60 | dsize=(self.render_size, self.render_size), 61 | interpolation=cv2.INTER_CUBIC, 62 | ) 63 | self.frames.append(frame) 64 | 65 | def save(self, file_name): 66 | if self.enabled: 67 | path = self.save_dir / file_name 68 | imageio.mimsave(str(path), self.frames, fps=self.fps) 69 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | git submodule update --init --recursive 2 | 3 | cd co-tracker 4 | git checkout main 5 | pip install -e . 6 | pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg] 7 | mkdir -p checkpoints 8 | cd checkpoints 9 | wget https://huggingface.co/facebook/cotracker3/resolve/main/scaled_online.pth 10 | cd ../../ 11 | 12 | cd dift 13 | pip install xformers==0.0.29.post1 14 | git checkout main 15 | cd ../ 16 | 17 | pip install torchvision==0.20.0 18 | pip install mediapipe==0.10.11 19 | pip install --force-reinstall transformers==4.45.2 20 | pip install --force-reinstall huggingface_hub==0.25.2 21 | pip install --force-reinstall numpy==2.1.3 22 | pip install --force-reinstall torch==2.5.1 23 | --------------------------------------------------------------------------------
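As a usage note (not a file in the repository): a minimal, self-contained sketch of the pinhole back-projection convention that `pixel2d_to_3d` in `point_policy/robot_utils/franka/utils.py` appears to assume — a 3x3 intrinsic matrix together with a 4x4 world-to-camera extrinsic — written as a round-trip check on synthetic values. The camera parameters and helper names below are illustrative assumptions, and the `reproject` helper is an inline inverse written for this check rather than the repo's own projection function.

```python
import numpy as np

# Hypothetical camera parameters (illustration only, not from the repo).
K = np.array([[600.0, 0.0, 320.0],
              [0.0, 600.0, 240.0],
              [0.0, 0.0, 1.0]])          # 3x3 intrinsics
E = np.eye(4)                            # 4x4 world -> camera extrinsic
E[:3, 3] = [0.10, -0.05, 0.50]           # small camera offset

def backproject(points2d, depths, K, E):
    """Pixels + depths -> 3D points in the world frame (same convention as pixel2d_to_3d)."""
    x = (points2d[:, 0] - K[0, 2]) / K[0, 0]
    y = (points2d[:, 1] - K[1, 2]) / K[1, 1]
    cam = np.column_stack((x * depths, y * depths, depths, np.ones_like(depths)))
    world = (np.linalg.inv(E) @ cam.T).T
    return world[:, :3]

def reproject(points3d, K, E):
    """World-frame 3D points -> pixels + depths (inverse of backproject, for checking)."""
    homo = np.column_stack((points3d, np.ones(len(points3d))))
    cam = (E @ homo.T).T[:, :3]
    uvw = (K @ cam.T).T
    return uvw[:, :2] / uvw[:, 2:3], cam[:, 2]

pts2d = np.array([[320.0, 240.0], [400.0, 200.0]])
depths = np.array([0.6, 0.8])
world = backproject(pts2d, depths, K, E)
pts2d_back, depths_back = reproject(world, K, E)
assert np.allclose(pts2d, pts2d_back) and np.allclose(depths, depths_back)
```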
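Similarly, a small standalone sketch of the 6D rotation representation handled by `rotation_6d_to_matrix` / `matrix_to_rotation_6d`: the first two rows of a rotation matrix are kept, and the full matrix is recovered by Gram-Schmidt plus a cross product. The toy rotation below is synthetic, and the reconstruction is re-implemented inline rather than imported from the repo.

```python
import numpy as np

def rot_z(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_y(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

R = rot_z(0.3) @ rot_y(-0.7)      # an arbitrary valid rotation matrix

# matrix_to_rotation_6d: keep the first two rows
d6 = R[:2, :].reshape(6)

# rotation_6d_to_matrix: Gram-Schmidt the two rows, third row via cross product
a1, a2 = d6[:3], d6[3:]
b1 = a1 / np.linalg.norm(a1)
b2 = a2 - np.dot(b1, a2) * b1
b2 = b2 / np.linalg.norm(b2)
b3 = np.cross(b1, b2)
R_recovered = np.stack((b1, b2, b3))

assert np.allclose(R, R_recovered)
```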
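And a quick sanity check of the parallel mean/variance combination rule (Chan et al.) used by `update_mean_var_count_from_moments` in `point_policy/utils.py`: folding a second batch into running statistics should reproduce the statistics of the concatenated data. The data and the inline re-implementation below are illustrative, not part of the repo.

```python
import numpy as np

def update_moments(mean, var, count, b_mean, b_var, b_count):
    # Chan et al. parallel combination of (mean, variance, count) pairs,
    # mirroring update_mean_var_count_from_moments.
    delta = b_mean - mean
    tot = count + b_count
    new_mean = mean + delta * b_count / tot
    m2 = var * count + b_var * b_count + delta ** 2 * count * b_count / tot
    return new_mean, m2 / tot, tot

rng = np.random.default_rng(0)
a = rng.normal(size=(100, 3))
b = rng.normal(loc=2.0, size=(50, 3))

mean, var, count = a.mean(axis=0), a.var(axis=0), len(a)
mean, var, count = update_moments(mean, var, count, b.mean(axis=0), b.var(axis=0), len(b))

full = np.concatenate([a, b])
assert np.allclose(mean, full.mean(axis=0)) and np.allclose(var, full.var(axis=0))
```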