├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── docs └── franka_walkthrough.md ├── examples ├── __init__.py ├── data_util.py ├── experiments │ ├── __init__.py │ ├── config.py │ ├── mappings.py │ ├── resnet10_params.pkl │ └── task1_pick_banana │ │ ├── config.py │ │ ├── run_actor_conrft.sh │ │ ├── run_actor_hilserl.sh │ │ ├── run_learner_conrft.sh │ │ ├── run_learner_conrft_pretrain.sh │ │ ├── run_learner_hilserl.sh │ │ └── wrapper.py ├── record_demos.py ├── record_demos_octo.py ├── record_success_fail.py ├── train_conrft_octo.py └── train_reward_classifier.py ├── serl_launcher ├── .gitignore ├── README.md ├── requirements.txt ├── serl_launcher │ ├── __init__.py │ ├── agents │ │ ├── __init__.py │ │ └── continuous │ │ │ ├── bc.py │ │ │ ├── conrft_single_octo_cp.py │ │ │ ├── ddpm_bc.py │ │ │ ├── sac.py │ │ │ └── sac_single.py │ ├── common │ │ ├── common.py │ │ ├── encoding.py │ │ ├── evaluation.py │ │ ├── optimizers.py │ │ ├── typing.py │ │ └── wandb.py │ ├── data │ │ ├── __init__.py │ │ ├── data_store.py │ │ ├── dataset.py │ │ ├── memory_efficient_replay_buffer.py │ │ └── replay_buffer.py │ ├── networks │ │ ├── actor_critic_nets.py │ │ ├── classifier.py │ │ ├── diffusion_nets.py │ │ ├── lagrange.py │ │ ├── mlp.py │ │ └── reward_classifier.py │ ├── utils │ │ ├── __init__.py │ │ ├── jax_utils.py │ │ ├── launcher.py │ │ ├── logging_utils.py │ │ ├── timer_utils.py │ │ ├── tools.py │ │ └── train_utils.py │ ├── vision │ │ ├── __init__.py │ │ ├── data_augmentations.py │ │ ├── film_conditioning_layer.py │ │ ├── resnet_v1.py │ │ └── spatial.py │ └── wrappers │ │ ├── __init__.py │ │ ├── chunking.py │ │ ├── front_camera_wrapper.py │ │ ├── norm.py │ │ ├── remap.py │ │ ├── serl_obs_wrappers.py │ │ ├── video_recorder.py │ │ └── video_wrapper.py └── setup.py └── serl_robot_infra ├── README.md ├── franka_env ├── camera │ ├── __init__.py │ ├── multi_video_capture.py │ ├── rs_capture.py │ └── video_capture.py ├── envs │ ├── __init__.py │ ├── dual_franka_env.py │ ├── franka_env.py │ ├── franka_wrench_env.py │ ├── relative_env.py │ └── wrappers.py ├── spacemouse │ ├── __init__.py │ ├── pyspacemouse.py │ ├── spacemouse_expert.py │ └── spacemouse_test.py └── utils │ ├── __init__.py │ ├── rotations.py │ ├── transform_absolute_actions_and_obs.py │ └── transformations.py ├── robot_servers ├── __init__.py ├── franka_eggflip_server.py ├── franka_gripper_server.py ├── franka_server.py ├── gripper_server.py ├── launch_left_server.sh ├── launch_right_eggflip_server.sh ├── launch_right_server.sh └── robotiq_gripper_server.py └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tests/debug_dataset/bridge_dataset/1.0.0/dataset_statistics_*.json 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | octo/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the 
exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | .ipynb_checkpoints/ 164 | wandb 165 | *.png 166 | *.sif 167 | .vscode 168 | .idea 169 | datasets/debug_dataset/bridge_dataset/1.0.0/action_* 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConRFT: A Reinforced Fine-tuning Method for VLA Models via Consistency Policy 2 | 3 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://cccedric.github.io/conrft/) 5 | 6 | We provide examples to fine-tune Octo, on the top of [HIL-SERL](https://github.com/rail-berkeley/hil-serl) that provides the base environment to perform robotic manipulation tasks with human interventions. The following sections describe how to use our code. 7 | 8 | 9 | **Table of Contents** 10 | - [ConRFT: A Reinforced Fine-tuning Method for VLA Models via Consistency Policy](#conrft-a-reinforced-fine-tuning-method-for-vla-models-via-consistency-policy) 11 | - [🛠️ Installation Instructions](#️-installation-instructions) 12 | - [💻 Overview and Code Structure](#-overview-and-code-structure) 13 | - [✉️ Contact](#️-contact) 14 | - [🙏 Acknowledgement](#-acknowledgement) 15 | - [📝 Citation](#-citation) 16 | 17 | ## 🛠️ Installation Instructions 18 | 1. **Setup Conda Environment:** 19 | create an environment with 20 | ```bash 21 | conda create -n conrft python=3.10 22 | ``` 23 | 24 | 2. **Install Jax as follows:** 25 | - For CPU (not recommended): 26 | ```bash 27 | pip install --upgrade "jax[cpu]" 28 | ``` 29 | 30 | - For GPU: 31 | ```bash 32 | pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 33 | ``` 34 | - See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax. 35 | 36 | 3. **Install the Octo** 37 | ```bash 38 | git clone git@github.com:cccedric/octo.git 39 | cd octo 40 | pip install -e . 41 | pip install -r requirements.txt 42 | ``` 43 | **Note**: This is a personalized fork of Octo, adding custom functions while preserving its core capabilities for general-purpose robotic manipulation. 44 | 45 | 4. **Install the serl_launcher** 46 | ```bash 47 | cd serl_launcher 48 | pip install -e . 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | 5. **Install for serl_robot_infra** 53 | 54 | Please refer to the [README](./serl_robot_infra/README.md) in the `serl_robot_infra` directory for installation instructions and details on operating the Franka robot arm. This document includes guidance on setting up the impedance-based [serl_franka_controllers](https://github.com/rail-berkeley/serl_franka_controllers). After completing the installation, you should be able to start the robot server and interact with the `franka_env` gym for hardware control. 55 | 56 | 57 | ## 💻 Overview and Code Structure 58 | 59 | We offers a set of code for fine-tuning Octo in robotic manipulation tasks. The approach's pipeline consists of an actor thread and a learner thread, both of which interact with the robot gym environment. These two threads operate asynchronously, with data transmitted from the actor to the learner node over the network using [agentlace](https://github.com/youliangtan/agentlace). The learner thread periodically updates the policy and syncs it with the actor. 
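The sketch below is only a conceptual illustration of this actor/learner split, using Python threads and an in-process queue in place of the two separate processes and the agentlace transport used by the actual training scripts; all names in it (`actor`, `learner`, `transition_queue`, `policy_version`) are illustrative and do not appear in this repository.

```python
# Conceptual sketch only (not the ConRFT/agentlace implementation): an "actor" rolls out
# a policy and streams transitions to a "learner", which consumes them, performs updates,
# and publishes a new policy version. All names here are illustrative.
import queue
import threading
import time

transition_queue = queue.Queue()   # stands in for the agentlace network transport
policy_version = 0                 # stands in for the learner's published policy weights


def actor(num_steps: int = 200) -> None:
    """Roll out the current policy and ship transitions to the learner."""
    for step in range(num_steps):
        transition = {"obs": step, "action": 0.0, "reward": 0.0, "policy_version": policy_version}
        transition_queue.put(transition)
        time.sleep(0.005)          # stands in for one real environment step


def learner(num_updates: int = 200) -> None:
    """Consume transitions, update the policy, and sync it back to the actor."""
    global policy_version
    replay_buffer = []
    for _ in range(num_updates):
        replay_buffer.append(transition_queue.get())  # blocking receive from the actor
        policy_version += 1                           # stands in for a gradient step + weight sync


threads = [threading.Thread(target=actor), threading.Thread(target=learner)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("final policy version:", policy_version)
```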
60 | 61 | **Table for code structure** 62 | 63 | | Code Directory | Description | 64 | | --- | --- | 65 | | examples | Scripts for policy training, demonstration data collection, reward classifier training | 66 | | serl_launcher | Main code for Agent Training | 67 | | serl_launcher.agents | Agent Policies (e.g. SAC, BC) | 68 | | serl_launcher.wrappers | Gym env wrappers | 69 | | serl_launcher.data | Replay buffer and data store | 70 | | serl_launcher.vision | Vision related models and utils | 71 | | serl_robot_infra | Robot infra for running with real robots | 72 | | serl_robot_infra.robot_servers | Flask server for sending commands to robot via ROS | 73 | | serl_robot_infra.franka_env | Gym env for Franka robot | 74 | 75 | We provide a step-by-step guide in [franka_walkthrough](/docs/franka_walkthrough.md) to fine-tune a VLA model with ConRFT on a Franka robot. 76 | 77 | ## ✉️ Contact 78 | For any questions, please feel free to email [chenyuhui2022@ia.ac.cn](mailto:chenyuhui2022@ia.ac.cn). 79 | 80 | ## 🙏 Acknowledgement 81 | Our code is built upon [CPQL](https://github.com/cccedric/cpql/), [Octo](https://github.com/octo-models/octo), [HIL-SERL](https://github.com/rail-berkeley/hil-serl). We thank all these authors for their nicely open-sourced code and their great contributions to the community. 82 | 83 | ## 📝 Citation 84 | 85 | If you find our research helpful and would like to reference it in your work, please consider the following citations: 86 | 87 | ```bibtex 88 | @article{chen2025conrft, 89 | title={ConRFT: A Reinforced Fine-tuning Method for VLA Models via Consistency Policy}, 90 | author={Chen, Yuhui and Tian, Shuai and Liu, Shugao and Zhou, Yingting and Li, Haoran and Zhao, Dongbin}, 91 | journal={arXiv preprint arXiv:2502.05450}, 92 | year={2025} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /docs/franka_walkthrough.md: -------------------------------------------------------------------------------- 1 | # Training on Franka Arm Walkthrough 2 | 3 | We demonstrate how to use our code on a real robot manipulator with a representative task from the paper: Pick Banana. We provide detailed instructions and tips for the entire pipeline for using ConRFT to fine-tune Octo in a real-world environment. 4 | 5 | ## Pick Banana 6 | ### Procedure 7 | #### Setup Franka Arm Control Server 8 | 1. To set up the Python environment and install the Franka controllers, please refer to [README.md](../README.md). 9 | 10 | 2. To set up the workspace, please refer to the image of our workspace setup in our paper. 11 | 12 | 3. Adjust for the weight of the wrist cameras by editing `Desk > Settings > End-effector > Mechanical Data > Mass`. 13 | 14 | 4. Unlock the robot and activate FCI in the Franka Desk. The `franka_server` launch file is found at [serl_robot_infra/robot_servers/launch_right_server.sh](../serl_robot_infra/robot_servers/launch_right_server.sh). You will need to edit the `setup.bash` path as well as the flags for the `python franka_server.py` command. You can refer to the [README.md](../serl_robot_infra/README.md) for `serl_robot_infra` for instructions on setting these flags. To launch the server, run: 15 | 16 | ```bash 17 | bash serl_robot_infra/robot_servers/launch_right_server.sh 18 | ``` 19 | 20 | #### Specify Training Configuration for Your Workspace 21 | For each task, we create a folder in the experiments folder to store data (i.e.
task demonstrations, reward classifier data, training run checkpoints), launch scripts, and training configurations (see [experiments/task1_pick_banana](../examples/experiments/task1_pick_banana/)). Next, we will walk through all of the changes you need to make to the training configuration in [experiments/task1_pick_banana/config.py](../examples/experiments/task1_pick_banana/config.py) to begin training: 22 | 23 | 1. First, in the `EnvConfig` class, change `SERVER_URL` to the URL of the running Franka server. 24 | 25 | 2. Next, we need to configure the cameras. For this task, we used one wrist camera and one side camera. All cameras used for a task (both for the reward classifier and policy training) are listed in `REALSENSE_CAMERAS` and their corresponding image crops are set in `IMAGE_CROP` in the `EnvConfig` class. The camera keys used for policy training and for the reward classifier are listed in the `TrainConfig` class in `image_keys` and `classifier_keys` respectively. Change the serial numbers in `REALSENSE_CAMERAS` to the serial numbers of the cameras in your setup (this can be found in the RealSense Viewer application). To adjust the image crops (and potentially the exposure), you can run the reward classifier data collection script (see the Training Reward Classifier section) or the demonstration data collection script (see the Recording Demonstrations section) to visualize the camera inputs. 26 | 27 | 3. Finally, we need to collect some poses for the training process. For this task, `TARGET_POSE` refers to the arm pose when placing the banana on the plate, and `RESET_POSE` refers to the arm pose to reset to. `ABS_POSE_LIMIT_HIGH` and `ABS_POSE_LIMIT_LOW` determine the bounding box for the policy. We have `RANDOM_RESET` enabled, meaning there is randomization around the `RESET_POSE` for every reset (`RANDOM_XY_RANGE` and `RANDOM_RZ_RANGE` control the amount of randomization). You should recollect `TARGET_POSE`, and ensure the bounding box is set for safe exploration. To collect the current pose of the Franka arm, you can run: 28 | ```bash 29 | curl -X POST http://<franka_server_ip>:5000/getpos_euler 30 | ``` 31 | 32 | #### Training Reward Classifier 33 | The reward for this task is given via a reward classifier trained on camera images. For this task, the images specified in `classifier_keys` are used to train the reward classifier. The following steps go through collecting classifier data and training the reward classifier. 34 | 35 | 1. First, we need to collect training data for the classifier. Navigate into the examples folder and run: 36 | ```bash 37 | cd examples 38 | python record_success_fail.py --exp_name task1_pick_banana --successes_needed 200 39 | ``` 40 | While the script is running, all transitions recorded are marked as negative (or no reward) by default. If the space bar is held during a transition, that transition will be marked as positive. The script will terminate when enough positive transitions have been collected (defaults to 200, but can be set via the successes_needed flag). For this task, you should collect negative transitions with the banana held in various locations in the workspace and during the pick-and-place process, and press the space bar once the banana has been placed on the plate and released. The classifier data will be saved to the folder `experiments/task1_pick_banana/classifier_data`. 41 | 42 | > **TIP**: To train a classifier robust against false positives (this is important for training a successful policy), we've found it helpful to collect 2-3x more negative transitions than positive transitions to cover all failure modes. 43 | 44 | 2.
To train the reward classifier, navigate to this task's experiment folder and run: 45 | ```bash 46 | cd experiments/task1_pick_banana 47 | python ../../train_reward_classifier.py --exp_name task1_pick_banana 48 | ``` 49 | The reward classifier will be trained on the camera images specified by the classifier keys in the training config. The trained classifier will be saved to the folder `experiments/task1_pick_banana/classifier_ckpt`. 50 | 51 | #### Recording Demonstrations 52 | A small number of human demonstrations is crucial for stage I (Cal-ConRFT), and for this task, we use 30 demonstrations. 53 | 54 | 1. To record the 30 demonstrations with the spacemouse, run: 55 | ```bash 56 | python ../../record_demos_octo.py --exp_name task1_pick_banana --successes_needed 30 57 | ``` 58 | Once the episode is deemed successful by the reward classifier or the episode times out, the robot will reset. The script will terminate once 30 successful demonstrations have been collected, which will be saved to the folder `experiments/task1_pick_banana/demo_data`. 59 | 60 | > **TIP**: During the demo data collection process, you may notice the reward classifier outputting false positives (episode terminating with reward given without a successful placement) or false negatives (no reward given despite a successful placement). In that case, you should collect additional classifier data to target the classifier failure modes observed (e.g., if the classifier is giving false positives while the banana is held in the air, you should collect more negative data of that occurring). Alternatively, you can also adjust the reward classifier threshold, although we strongly recommend collecting additional classifier data (or even adding more classifier cameras/images if needed) before doing this. 61 | 62 | #### Policy Training 63 | Policy training is done asynchronously via an actor thread, which is responsible for rolling out the policy in the environment and sending the collected transitions to the learner, and a learner thread, which is responsible for training the policy and sending the updated policy back to the actor. Both the actor and the learner should be running during policy training. 64 | 65 | 1. Inside the folder corresponding to the Pick Banana experiment ([experiments/task1_pick_banana](../examples/experiments/task1_pick_banana/)), you will find `run_actor_conrft.sh`, `run_learner_conrft_pretrain.sh` and `run_learner_conrft.sh`. In all three scripts, edit `checkpoint_path` to point to the folder where checkpoints and other data generated during training will be saved, and in `run_learner_conrft_pretrain.sh` and `run_learner_conrft.sh`, edit `demo_path` to point to the path of the recorded demonstrations (if there are multiple demonstration files, you can provide multiple `demo_path` flags). First, begin stage I (Cal-ConRFT): 66 | ```bash 67 | bash run_learner_conrft_pretrain.sh 68 | ``` 69 | 70 | Then, to begin stage II (HIL-ConRFT), launch both threads: 71 | ```bash 72 | bash run_actor_conrft.sh 73 | bash run_learner_conrft.sh 74 | ``` 75 | 76 | 2. During online training, you should give some interventions as necessary with the spacemouse to speed up the training run, particularly closer to the beginning of the run or when the policy is exploring an incorrect behavior repeatedly. For reference, with the randomizations on and giving occasional interventions, the policy took around 1 hour to converge to 100% success rate. 77 | 78 | 3.
To evaluate the trained policy, add the flags `--eval_checkpoint_step=CHECKPOINT_NUMBER_TO_EVAL` and `--eval_n_trajs=N_TIMES_TO_EVAL` to `run_actor_conrft.sh`. Then, launch the actor: 79 | ```bash 80 | bash run_actor_conrft.sh 81 | ``` 82 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/examples/__init__.py -------------------------------------------------------------------------------- /examples/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | 4 | 5 | def calc_return_to_go(rewards, terminals, gamma, reward_scale, reward_bias, reward_neg, is_sparse_reward): 6 | """ 7 | Compute the discounted return-to-go for every step of a trajectory. 8 | """ 9 | if len(rewards) == 0: 10 | return np.array([]) 11 | 12 | if is_sparse_reward: 13 | reward_neg = reward_neg * reward_scale + reward_bias 14 | else: 15 | assert not is_sparse_reward, "If you want to try on a sparse reward env, please add the reward_neg value in the ENV_CONFIG dict." 16 | 17 | if is_sparse_reward and np.all(np.array(rewards) == reward_neg): 18 | """ 19 | If the env has sparse reward and the trajectory is all negative rewards, 20 | we use r / (1-gamma) as return to go. 21 | For example, if gamma = 0.99 and the rewards = [-1, -1, -1], 22 | then return_to_go = [-100, -100, -100] 23 | """ 24 | return_to_go = [float(reward_neg / (1-gamma))] * len(rewards) 25 | else: 26 | return_to_go = [0] * len(rewards) 27 | prev_return = 0 28 | for i in range(len(rewards)): 29 | return_to_go[-i-1] = rewards[-i-1] + gamma * \ 30 | prev_return * (1 - terminals[-i-1]) 31 | prev_return = return_to_go[-i-1] 32 | 33 | return np.array(return_to_go, dtype=np.float32) 34 | 35 | 36 | def add_mc_returns_to_trajectory(trajectory, gamma, reward_scale, reward_bias, reward_neg, is_sparse_reward): 37 | """ 38 | Update every transition in the trajectory and add mc_returns; 39 | return the updated trajectory 40 | """ 41 | rewards = [t['rewards'] for t in trajectory] 42 | terminals = [t['dones'] for t in trajectory] 43 | 44 | mc_returns = calc_return_to_go( 45 | rewards=rewards, 46 | terminals=terminals, 47 | gamma=gamma, 48 | reward_scale=reward_scale, 49 | reward_bias=reward_bias, 50 | reward_neg=reward_neg, 51 | is_sparse_reward=is_sparse_reward, 52 | ) 53 | 54 | for i, transition in enumerate(trajectory): 55 | transition['mc_returns'] = mc_returns[i] 56 | 57 | return trajectory 58 | 59 | 60 | def add_embeddings_to_trajectory(trajectory, model, tasks): 61 | """ 62 | Update every transition in the trajectory and add embeddings; 63 | return the updated trajectory 64 | """ 65 | for i in range(len(trajectory)): 66 | observation = trajectory[i]['observations'] 67 | 68 | image_primary = observation["side_policy_256"] 69 | image_wrist = observation["wrist_1"] 70 | # Add batch dimension 71 | image_primary = image_primary[np.newaxis, ...] 72 | image_wrist = image_wrist[np.newaxis, ...]
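# NOTE (added comment, not in the original source): the observations here come from
# ChunkingWrapper (obs_horizon=stack_obs_num, set to 2 when recording demos), so each
# image already carries a window of two stacked timesteps. The [[True, True]] pad mask
# created below tells the Octo transformer that both timesteps in this window are real
# observations rather than padding.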
73 | timestep_pad_mask = np.array([[True, True]]) 74 | 75 | observation = {"image_primary": image_primary, 76 | "image_wrist": image_wrist, 77 | "timestep_pad_mask": timestep_pad_mask, 78 | } 79 | 80 | action_embeddings = model.sample_transformer(observation, tasks,) 81 | # Now, action_embeddings is (batch_size, window_size, embedding_size) 82 | 83 | # remove window_size dimension 84 | action_embeddings = action_embeddings[:, -1, :] 85 | 86 | trajectory[i]['embeddings'] = action_embeddings 87 | 88 | return trajectory 89 | 90 | 91 | def add_next_embeddings_to_trajectory(trajectory): 92 | """ 93 | undate every transition in the trajectory and add next_embeddings 94 | return the updated trajectory 95 | """ 96 | for i in range(len(trajectory)): 97 | if i == len(trajectory) - 1: 98 | trajectory[i]['next_embeddings'] = trajectory[i]['embeddings'] 99 | else: 100 | trajectory[i]['next_embeddings'] = trajectory[i+1]['embeddings'] 101 | 102 | return trajectory 103 | -------------------------------------------------------------------------------- /examples/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/examples/experiments/__init__.py -------------------------------------------------------------------------------- /examples/experiments/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List 3 | 4 | class DefaultTrainingConfig: 5 | """Default training configuration. """ 6 | 7 | agent: str = "drq" 8 | max_traj_length: int = 100 9 | batch_size: int = 256 # 256 10 | cta_ratio: int = 2 11 | discount: float = 0.97 12 | 13 | max_steps: int = 1000000 14 | replay_buffer_capacity: int = 200000 15 | 16 | random_steps: int = 0 17 | training_starts: int = 100 18 | steps_per_update: int = 50 19 | 20 | log_period: int = 10 21 | eval_period: int = 2000 22 | 23 | # "resnet" for ResNet10 from scratch and "resnet-pretrained" for frozen ResNet10 with pretrained weights 24 | encoder_type: str = "resnet-pretrained" 25 | demo_path: str = None 26 | checkpoint_period: int = 0 27 | buffer_period: int = 0 28 | 29 | eval_checkpoint_step: int = 0 30 | eval_n_trajs: int = 5 31 | 32 | image_keys: List[str] = None 33 | classifier_keys: List[str] = None 34 | proprio_keys: List[str] = None 35 | 36 | # "single-arm-learned-gripper", "dual-arm-learned-gripper" for with learned gripper, 37 | # "single-arm-fixed-gripper", "dual-arm-fixed-gripper" for without learned gripper (i.e. 
pregrasped) 38 | setup_mode: str = "single-arm-fixed-gripper" 39 | 40 | @abstractmethod 41 | def get_environment(self, fake_env=False, save_video=False, classifier=False): 42 | raise NotImplementedError 43 | 44 | @abstractmethod 45 | def process_demos(self, demo): 46 | raise NotImplementedError 47 | 48 | -------------------------------------------------------------------------------- /examples/experiments/mappings.py: -------------------------------------------------------------------------------- 1 | from experiments.task1_pick_banana.config import TrainConfig as PickBananaTrainConfig 2 | 3 | CONFIG_MAPPING = { 4 | "task1_pick_banana": PickBananaTrainConfig, 5 | } 6 | -------------------------------------------------------------------------------- /examples/experiments/resnet10_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/examples/experiments/resnet10_params.pkl -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import numpy as np 4 | import jax.numpy as jnp 5 | 6 | from franka_env.envs.wrappers import ( 7 | Quat2EulerWrapper, 8 | SpacemouseIntervention, 9 | MultiCameraBinaryRewardClassifierWrapper, 10 | ) 11 | from franka_env.envs.relative_env import RelativeFrame 12 | from franka_env.envs.franka_env import DefaultEnvConfig 13 | from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper 14 | from serl_launcher.wrappers.chunking import ChunkingWrapper 15 | from serl_launcher.networks.reward_classifier import load_classifier_func 16 | 17 | from experiments.config import DefaultTrainingConfig 18 | from experiments.task1_pick_banana.wrapper import PickBananaEnv, GripperPenaltyWrapper 19 | 20 | 21 | class EnvConfig(DefaultEnvConfig): 22 | SERVER_URL: str = "http://127.0.0.2:5000/" 23 | REALSENSE_CAMERAS = { 24 | "wrist_1": { 25 | "serial_number": "115222071051", 26 | "dim": (1280, 720), 27 | "exposure": 10500, 28 | }, 29 | "side_policy_256": { 30 | "serial_number": "242422305075", 31 | "dim": (1280, 720), 32 | "exposure": 13000, 33 | }, 34 | "side_classifier": { 35 | "serial_number": "242422305075", 36 | "dim": (1280, 720), 37 | "exposure": 13000, 38 | }, 39 | "demo": { 40 | "serial_number": "242422305075", 41 | "dim": (1280, 720), 42 | "exposure": 13000, 43 | }, 44 | } 45 | IMAGE_CROP = {"wrist_1": lambda img: img, 46 | "side_policy_256": lambda img: img[250:-150, 400:-500], 47 | "side_classifier": lambda img: img[390:-150, 420:-700], 48 | "demo": lambda img: img[50:-150, 400:-400]} 49 | 50 | TARGET_POSE = np.array([0.33, -0.15, 0.20, np.pi, 0, 0]) 51 | RESET_POSE = np.array([0.61, -0.17, 0.22, np.pi, 0, 0]) 52 | ACTION_SCALE = np.array([0.08, 0.2, 1]) 53 | RANDOM_RESET = True 54 | DISPLAY_IMAGE = True 55 | RANDOM_XY_RANGE = 0.02 56 | RANDOM_RZ_RANGE = 0.03 57 | ABS_POSE_LIMIT_HIGH = TARGET_POSE + np.array([0.3, 0.03, 0.02, 0.01, 0.01, 0.3]) 58 | ABS_POSE_LIMIT_LOW = TARGET_POSE - np.array([0.03, 0.05, 0.05, 0.01, 0.01, 0.3]) 59 | COMPLIANCE_PARAM = { 60 | "translational_stiffness": 2000, 61 | "translational_damping": 89, 62 | "rotational_stiffness": 150, 63 | "rotational_damping": 7, 64 | "translational_Ki": 0, 65 | "translational_clip_x": 0.008, 66 | "translational_clip_y": 0.005, 67 | "translational_clip_z": 0.005, 68 | "translational_clip_neg_x": 
0.008, 69 | "translational_clip_neg_y": 0.005, 70 | "translational_clip_neg_z": 0.005, 71 | "rotational_clip_x": 0.02, 72 | "rotational_clip_y": 0.02, 73 | "rotational_clip_z": 0.02, 74 | "rotational_clip_neg_x": 0.02, 75 | "rotational_clip_neg_y": 0.02, 76 | "rotational_clip_neg_z": 0.02, 77 | "rotational_Ki": 0, 78 | } # for normal operation other than reset procedure 79 | PRECISION_PARAM = { 80 | "translational_stiffness": 2000, 81 | "translational_damping": 89, 82 | "rotational_stiffness": 150, 83 | "rotational_damping": 7, 84 | "translational_Ki": 0.0, 85 | "translational_clip_x": 0.01, 86 | "translational_clip_y": 0.01, 87 | "translational_clip_z": 0.01, 88 | "translational_clip_neg_x": 0.01, 89 | "translational_clip_neg_y": 0.01, 90 | "translational_clip_neg_z": 0.01, 91 | "rotational_clip_x": 0.03, 92 | "rotational_clip_y": 0.03, 93 | "rotational_clip_z": 0.03, 94 | "rotational_clip_neg_x": 0.03, 95 | "rotational_clip_neg_y": 0.03, 96 | "rotational_clip_neg_z": 0.03, 97 | "rotational_Ki": 0.0, 98 | } # only for reset procedure 99 | MAX_EPISODE_LENGTH = 100 100 | 101 | 102 | class TrainConfig(DefaultTrainingConfig): 103 | image_keys = ["side_policy_256", "wrist_1"] 104 | classifier_keys = ["side_classifier"] 105 | proprio_keys = ["tcp_pose", "tcp_vel", "tcp_force", "tcp_torque", "gripper_pose"] 106 | checkpoint_period = 2000 107 | cta_ratio = 2 108 | random_steps = 0 109 | discount = 0.98 110 | buffer_period = 1000 111 | encoder_type = "resnet-pretrained" 112 | setup_mode = "single-arm-learned-gripper" 113 | reward_neg = -0.05 114 | task_desc = "Put the yellow banana to the green plate" 115 | octo_path = "/root/online_rl/octo_model/octo-small" 116 | 117 | def get_environment(self, fake_env=False, save_video=False, classifier=False, stack_obs_num=1): 118 | env = PickBananaEnv( fake_env=fake_env, save_video=save_video, config=EnvConfig()) 119 | if not fake_env: 120 | env = SpacemouseIntervention(env) 121 | env = RelativeFrame(env) 122 | env = Quat2EulerWrapper(env) 123 | env = SERLObsWrapper(env, proprio_keys=self.proprio_keys) 124 | env = ChunkingWrapper(env, obs_horizon=stack_obs_num, act_exec_horizon=None) 125 | if classifier: 126 | classifier = load_classifier_func( 127 | key=jax.random.PRNGKey(0), 128 | sample=env.observation_space.sample(), 129 | image_keys=self.classifier_keys, 130 | checkpoint_path=os.path.abspath("classifier_ckpt/"), 131 | ) 132 | 133 | def reward_func(obs): 134 | def sigmoid(x): return 1 / (1 + jnp.exp(-x)) 135 | # Should open the gripper and pull up after putting the banana 136 | if sigmoid(classifier(obs)[0]) > 0.9 and env.curr_gripper_pos > 0.5 and env.currpos[2] > 0.16: 137 | return 10.0 138 | else: 139 | return self.reward_neg 140 | 141 | env = MultiCameraBinaryRewardClassifierWrapper(env, reward_func) 142 | env = GripperPenaltyWrapper(env, penalty=-0.2) 143 | return env 144 | -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/run_actor_conrft.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ 2 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ 3 | python ../../train_conrft_octo.py "$@" \ 4 | --exp_name=task1_pick_banana \ 5 | --checkpoint_path=/root/online_rl/conrft/examples/experiments/task1_pick_banana/conrft \ 6 | --actor \ 7 | # --eval_checkpoint_step=26000 \ -------------------------------------------------------------------------------- 
/examples/experiments/task1_pick_banana/run_actor_hilserl.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ 2 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ 3 | python ../../train_rlpd.py "$@" \ 4 | --exp_name=task1_pick_banana \ 5 | --checkpoint_path=/root/online_rl/conrft/examples/experiments/task1_pick_banana/debug_hilserl \ 6 | --actor \ 7 | --eval_checkpoint_step=26000 \ -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/run_learner_conrft.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ 2 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ 3 | python ../../train_conrft_octo.py "$@" \ 4 | --exp_name=task1_pick_banana \ 5 | --checkpoint_path=/root/online_rl/conrft/examples/experiments/task1_pick_banana/conrft\ 6 | --q_weight=1.0 \ 7 | --bc_weight=0.1 \ 8 | --demo_path=./demo_data/task1_pick_banana_30_demos.pkl \ 9 | --pretrain_steps=20000 \ 10 | --debug=False \ 11 | --learner \ -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/run_learner_conrft_pretrain.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ 2 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ 3 | python ../../train_conrft_octo.py "$@" \ 4 | --exp_name=task1_pick_banana \ 5 | --checkpoint_path=/root/online_rl/conrft/examples/experiments/task1_pick_banana/conrft\ 6 | --q_weight=0.0 \ 7 | --bc_weight=1.0 \ 8 | --demo_path=./demo_data/task1_pick_banana_30_demos.pkl \ 9 | --pretrain_steps=20000 \ 10 | --debug=False \ 11 | --learner \ -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/run_learner_hilserl.sh: -------------------------------------------------------------------------------- 1 | export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ 2 | export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ 3 | python ../../train_rlpd.py "$@" \ 4 | --exp_name=task1_pick_banana \ 5 | --checkpoint_path=/root/online_rl/conrft/examples/experiments/task1_pick_banana/debug_hilserl \ 6 | --demo_path=./demo_data/task1_pick_banana_30_demos.pkl \ 7 | --learner \ -------------------------------------------------------------------------------- /examples/experiments/task1_pick_banana/wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import OrderedDict 2 | from franka_env.camera.rs_capture import RSCapture 3 | from franka_env.camera.video_capture import VideoCapture 4 | from franka_env.utils.rotations import euler_2_quat 5 | import numpy as np 6 | import requests 7 | import copy 8 | import gymnasium as gym 9 | import time 10 | from franka_env.envs.franka_env import FrankaEnv 11 | 12 | class PickBananaEnv(FrankaEnv): 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def init_cameras(self, name_serial_dict=None): 17 | """Init both wrist cameras.""" 18 | if self.cap is not None: # close cameras if they are already open 19 | self.close_cameras() 20 | 21 | self.cap = OrderedDict() 22 | for cam_name, kwargs in name_serial_dict.items(): 23 | if cam_name == "side_classifier": 24 | self.cap["side_classifier"] = self.cap["side_policy_256"] 25 | elif cam_name == "demo": 26 | self.cap["demo"] = self.cap["side_policy_256"] 27 | 
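# NOTE (added comment, not in the original source): "side_classifier" and "demo" share
# the same physical RealSense as "side_policy_256" (identical serial numbers in
# EnvConfig.REALSENSE_CAMERAS), so they reuse the already-opened capture object instead
# of opening a second stream; only genuinely distinct cameras fall through to the
# else-branch below.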
else: 28 | cap = VideoCapture(RSCapture(name=cam_name, **kwargs)) 29 | self.cap[cam_name] = cap 30 | 31 | 32 | def reset(self, **kwargs): 33 | self._recover() 34 | self._update_currpos() 35 | self._send_pos_command(self.currpos) 36 | time.sleep(0.3) 37 | requests.post(self.url + "update_param", json=self.config.PRECISION_PARAM) 38 | # Move above the target pose 39 | target = copy.deepcopy(self.currpos) 40 | target[2] = self.config.TARGET_POSE[2] + 0.05 41 | # target[2] = self.config.RESET_POSE[2] 42 | self.interpolate_move(target, timeout=1) 43 | time.sleep(0.5) 44 | 45 | obs, info = super().reset(**kwargs) 46 | self._send_gripper_command(1.0) 47 | time.sleep(1) 48 | self.success = False 49 | self._update_currpos() 50 | obs = self._get_obs() 51 | return obs, info 52 | 53 | def interpolate_move(self, goal: np.ndarray, timeout: float): 54 | """Move the robot to the goal position with linear interpolation.""" 55 | if goal.shape == (6,): 56 | goal = np.concatenate([goal[:3], euler_2_quat(goal[3:])]) 57 | self._send_pos_command(goal) 58 | time.sleep(timeout) 59 | self._update_currpos() 60 | 61 | def go_to_reset(self, joint_reset=False): 62 | """ 63 | The concrete steps to perform a reset should be 64 | implemented by each subclass for the specific task. 65 | Override this method if a custom reset procedure is needed. 66 | """ 67 | 68 | # Perform joint reset if needed 69 | if joint_reset: 70 | print("JOINT RESET") 71 | requests.post(self.url + "jointreset") 72 | time.sleep(0.5) 73 | 74 | # Perform Cartesian reset 75 | if self.randomreset: # randomize reset position in xy plane 76 | reset_pose = self.resetpos.copy() 77 | reset_pose[:2] += np.random.uniform( 78 | -self.random_xy_range, self.random_xy_range, (2,) 79 | ) 80 | euler_random = self._RESET_POSE[3:].copy() 81 | euler_random[-1] += np.random.uniform( 82 | -self.random_rz_range, self.random_rz_range 83 | ) 84 | reset_pose[3:] = euler_2_quat(euler_random) 85 | self.interpolate_move(reset_pose, timeout=1) 86 | else: 87 | reset_pose = self.resetpos.copy() 88 | self.interpolate_move(reset_pose, timeout=1) 89 | time.sleep(1.0) 90 | 91 | # Change to compliance mode 92 | requests.post(self.url + "update_param", json=self.config.COMPLIANCE_PARAM) 93 | 94 | class GripperPenaltyWrapper(gym.Wrapper): 95 | def __init__(self, env, penalty=-0.05): 96 | super().__init__(env) 97 | assert env.action_space.shape == (7,) 98 | self.penalty = penalty 99 | self.last_gripper_pos = None 100 | 101 | def reset(self, **kwargs): 102 | obs, info = self.env.reset(**kwargs) 103 | self.last_gripper_pos = obs["state"][0, 0] 104 | return obs, info 105 | 106 | def step(self, action): 107 | """Discretize the gripper action and add a grasp penalty to `info` when the gripper state flips.""" 108 | action = copy.deepcopy(action) 109 | grasp_action = action[..., -1] 110 | 111 | grasp_action = np.where(grasp_action > 0.5, 1, np.where(grasp_action < -0.5, -1, 0) ) 112 | action[..., -1] = grasp_action 113 | 114 | observation, reward, terminated, truncated, info = self.env.step(action) 115 | if "intervene_action" in info: 116 | action = info["intervene_action"] 117 | 118 | if (action[-1] < -0.5 and self.last_gripper_pos > 0.7) or ( 119 | action[-1] > 0.5 and self.last_gripper_pos < 0.7 120 | ): 121 | info["grasp_penalty"] = self.penalty 122 | else: 123 | info["grasp_penalty"] = 0.0 124 | 125 | self.last_gripper_pos = observation["state"][0, 0] 126 | return observation, reward, terminated, truncated, info 127 | 128 | --------------------------------------------------------------------------------
/examples/record_demos.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import copy 5 | import pickle as pkl 6 | import datetime 7 | from absl import app, flags 8 | import time 9 | 10 | from experiments.mappings import CONFIG_MAPPING 11 | from data_util import add_mc_returns_to_trajectory 12 | 13 | FLAGS = flags.FLAGS 14 | flags.DEFINE_string("exp_name", None, "Name of experiment corresponding to folder.") 15 | flags.DEFINE_integer("successes_needed", 20, "Number of successful demos to collect.") 16 | flags.DEFINE_float("gamma", 0.95, "return discount") 17 | flags.DEFINE_float("reward_neg", 0.0, "reward_neg for spase reward envs") 18 | flags.DEFINE_float("reward_scale", 1.0, "reward_scale ") 19 | flags.DEFINE_float("reward_bias", 0.0, "reward_bias") 20 | 21 | 22 | def main(_): 23 | assert FLAGS.exp_name in CONFIG_MAPPING, 'Experiment folder not found.' 24 | config = CONFIG_MAPPING[FLAGS.exp_name]() 25 | env = config.get_environment(fake_env=False, save_video=False, classifier=True) 26 | FLAGS.reward_neg = config.reward_neg 27 | 28 | obs, info = env.reset() 29 | print("Reset done") 30 | transitions = [] 31 | success_count = 0 32 | success_needed = FLAGS.successes_needed 33 | pbar = tqdm(total=success_needed) 34 | trajectory = [] 35 | returns = 0 36 | 37 | while success_count < success_needed: 38 | actions = np.zeros(env.action_space.sample().shape) 39 | next_obs, rew, done, truncated, info = env.step(actions) 40 | returns += rew 41 | if "intervene_action" in info: 42 | actions = info["intervene_action"] 43 | transition = copy.deepcopy( 44 | dict( 45 | observations=obs, 46 | actions=actions, 47 | next_observations=next_obs, 48 | rewards=rew, 49 | masks=1.0 - done, 50 | dones=done, 51 | infos=info, 52 | ) 53 | ) 54 | trajectory.append(transition) 55 | 56 | pbar.set_description(f"Return: {returns}") 57 | 58 | obs = next_obs 59 | if done: 60 | if info["succeed"]: 61 | trajectory = add_mc_returns_to_trajectory(trajectory, FLAGS.gamma, FLAGS.reward_scale, FLAGS.reward_bias, FLAGS.reward_neg, is_sparse_reward=True) 62 | for transition in trajectory: 63 | transitions.append(copy.deepcopy(transition)) 64 | success_count += 1 65 | pbar.update(1) 66 | trajectory = [] 67 | returns = 0 68 | obs, info = env.reset() 69 | time.sleep(2.0) 70 | 71 | if not os.path.exists("./demo_data"): 72 | os.makedirs("./demo_data") 73 | uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 74 | file_name = f"./demo_data/{FLAGS.exp_name}_{success_needed}_demos_{uuid}.pkl" 75 | with open(file_name, "wb") as f: 76 | pkl.dump(transitions, f) 77 | print(f"saved {success_needed} demos to {file_name}") 78 | 79 | if __name__ == "__main__": 80 | app.run(main) -------------------------------------------------------------------------------- /examples/record_demos_octo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import copy 5 | import pickle as pkl 6 | import datetime 7 | from absl import app, flags 8 | import time 9 | 10 | from experiments.mappings import CONFIG_MAPPING 11 | from data_util import add_mc_returns_to_trajectory, add_embeddings_to_trajectory, add_next_embeddings_to_trajectory 12 | 13 | from octo.model.octo_model import OctoModel 14 | 15 | FLAGS = flags.FLAGS 16 | flags.DEFINE_string( 17 | "exp_name", None, "Name of experiment corresponding to folder.") 18 | flags.DEFINE_integer("successes_needed", 20, 19 | 
"Number of successful demos to collect.") 20 | flags.DEFINE_float("reward_scale", 1.0, "reward_scale ") 21 | flags.DEFINE_float("reward_bias", 0.0, "reward_bias") 22 | 23 | 24 | def main(_): 25 | assert FLAGS.exp_name in CONFIG_MAPPING, 'Experiment folder not found.' 26 | config = CONFIG_MAPPING[FLAGS.exp_name]() 27 | env = config.get_environment( 28 | fake_env=False, save_video=False, classifier=True, stack_obs_num=2) 29 | 30 | model = OctoModel.load_pretrained(config.octo_path) 31 | tasks = model.create_tasks(texts=[config.task_desc]) 32 | # model = None 33 | # tasks = None 34 | 35 | obs, info = env.reset() 36 | print(obs.keys()) 37 | print("Reset done") 38 | 39 | transitions = [] 40 | success_count = 0 41 | success_needed = FLAGS.successes_needed 42 | pbar = tqdm(total=success_needed) 43 | trajectory = [] 44 | returns = 0 45 | 46 | while success_count < success_needed: 47 | actions = np.zeros(env.action_space.sample().shape) 48 | next_obs, rew, done, truncated, info = env.step(actions) 49 | returns += rew 50 | if "intervene_action" in info: 51 | actions = info["intervene_action"] 52 | transition = copy.deepcopy( 53 | dict( 54 | observations=obs, 55 | actions=actions, 56 | next_observations=next_obs, 57 | rewards=rew, 58 | masks=1.0 - done, 59 | dones=done, 60 | infos=info, 61 | ) 62 | ) 63 | trajectory.append(transition) 64 | 65 | pbar.set_description(f"Return: {returns:.2f}") 66 | 67 | obs = next_obs 68 | if done: 69 | if info["succeed"]: 70 | trajectory = add_mc_returns_to_trajectory( 71 | trajectory, config.discount, FLAGS.reward_scale, FLAGS.reward_bias, config.reward_neg, is_sparse_reward=True) 72 | trajectory = add_embeddings_to_trajectory( 73 | trajectory, model, tasks=tasks) 74 | trajectory = add_next_embeddings_to_trajectory(trajectory) 75 | for transition in trajectory: 76 | transitions.append(copy.deepcopy(transition)) 77 | success_count += 1 78 | pbar.update(1) 79 | trajectory = [] 80 | returns = 0 81 | obs, info = env.reset() 82 | time.sleep(2.0) 83 | 84 | if not os.path.exists("./demo_data"): 85 | os.makedirs("./demo_data") 86 | uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 87 | file_name = f"./demo_data/{FLAGS.exp_name}_{success_needed}_demos_{uuid}.pkl" 88 | with open(file_name, "wb") as f: 89 | pkl.dump(transitions, f) 90 | print(f"saved {success_needed} demos to {file_name}") 91 | 92 | 93 | if __name__ == "__main__": 94 | app.run(main) 95 | -------------------------------------------------------------------------------- /examples/record_success_fail.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | import pickle as pkl 6 | import datetime 7 | from absl import app, flags 8 | from pynput import keyboard 9 | 10 | from experiments.mappings import CONFIG_MAPPING 11 | 12 | FLAGS = flags.FLAGS 13 | flags.DEFINE_string("exp_name", None, "Name of experiment corresponding to folder.") 14 | flags.DEFINE_integer("successes_needed", 200, "Number of successful transistions to collect.") 15 | 16 | 17 | success_key = False 18 | def on_press(key): 19 | global success_key 20 | try: 21 | if str(key) == 'Key.space': 22 | success_key = True 23 | except AttributeError: 24 | pass 25 | 26 | def main(_): 27 | global success_key 28 | listener = keyboard.Listener(on_press=on_press) 29 | listener.start() 30 | assert FLAGS.exp_name in CONFIG_MAPPING, 'Experiment folder not found.' 
31 | config = CONFIG_MAPPING[FLAGS.exp_name]() 32 | env = config.get_environment(fake_env=False, save_video=False, classifier=False, stack_obs_num=2) 33 | 34 | obs, _ = env.reset() 35 | successes = [] 36 | failures = [] 37 | success_needed = FLAGS.successes_needed 38 | pbar = tqdm(total=success_needed) 39 | 40 | while len(successes) < success_needed: 41 | actions = np.zeros(env.action_space.sample().shape) 42 | next_obs, rew, done, truncated, info = env.step(actions) 43 | if "intervene_action" in info: 44 | actions = info["intervene_action"] 45 | # print(actions) 46 | 47 | transition = copy.deepcopy( 48 | dict( 49 | observations=obs, 50 | actions=actions, 51 | next_observations=next_obs, 52 | rewards=rew, 53 | masks=1.0 - done, 54 | dones=done, 55 | ) 56 | ) 57 | obs = next_obs 58 | if success_key: 59 | successes.append(transition) 60 | pbar.update(1) 61 | success_key = False 62 | # obs, _ = env.reset() 63 | else: 64 | failures.append(transition) 65 | 66 | if done or truncated: 67 | obs, _ = env.reset() 68 | 69 | if not os.path.exists("./classifier_data"): 70 | os.makedirs("./classifier_data") 71 | uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 72 | file_name = f"./classifier_data/{FLAGS.exp_name}_{success_needed}_success_images_{uuid}.pkl" 73 | with open(file_name, "wb") as f: 74 | pkl.dump(successes, f) 75 | print(f"saved {success_needed} successful transitions to {file_name}") 76 | 77 | file_name = f"./classifier_data/{FLAGS.exp_name}_failure_images_{uuid}.pkl" 78 | with open(file_name, "wb") as f: 79 | pkl.dump(failures, f) 80 | print(f"saved {len(failures)} failure transitions to {file_name}") 81 | 82 | if __name__ == "__main__": 83 | app.run(main) 84 | -------------------------------------------------------------------------------- /examples/train_reward_classifier.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle as pkl 4 | import jax 5 | from jax import numpy as jnp 6 | import flax.linen as nn 7 | from flax.training import checkpoints 8 | import numpy as np 9 | import optax 10 | from tqdm import tqdm 11 | from absl import app, flags 12 | 13 | from serl_launcher.data.data_store import ReplayBuffer 14 | from serl_launcher.utils.train_utils import concat_batches 15 | from serl_launcher.vision.data_augmentations import batched_random_crop 16 | from serl_launcher.networks.reward_classifier import create_classifier 17 | 18 | from experiments.mappings import CONFIG_MAPPING 19 | 20 | 21 | FLAGS = flags.FLAGS 22 | flags.DEFINE_string("exp_name", None, "Name of experiment corresponding to folder.") 23 | flags.DEFINE_integer("num_epochs", 150, "Number of training epochs.") 24 | flags.DEFINE_integer("batch_size", 256, "Batch size.") 25 | 26 | 27 | def main(_): 28 | assert FLAGS.exp_name in CONFIG_MAPPING, 'Experiment folder not found.' 
29 | config = CONFIG_MAPPING[FLAGS.exp_name]() 30 | env = config.get_environment(fake_env=True, save_video=False, classifier=False, stack_obs_num=2) 31 | 32 | devices = jax.local_devices() 33 | sharding = jax.sharding.PositionalSharding(devices) 34 | 35 | # Create buffer for positive transitions 36 | pos_buffer = ReplayBuffer(env.observation_space, env.action_space, capacity=20000, include_label=True,) 37 | 38 | success_paths = glob.glob(os.path.join(os.getcwd(), "classifier_data", "*success*.pkl")) 39 | for path in success_paths: 40 | success_data = pkl.load(open(path, "rb")) 41 | for trans in success_data: 42 | if "images" in trans['observations'].keys(): 43 | continue 44 | trans["labels"] = 1 45 | trans['actions'] = env.action_space.sample() 46 | pos_buffer.insert(trans) 47 | 48 | pos_iterator = pos_buffer.get_iterator(sample_args={"batch_size": FLAGS.batch_size // 2,}, device=sharding.replicate(),) 49 | 50 | # Create buffer for negative transitions 51 | neg_buffer = ReplayBuffer( env.observation_space, env.action_space, capacity=50000, include_label=True,) 52 | failure_paths = glob.glob(os.path.join(os.getcwd(), "classifier_data", "*failure*.pkl")) 53 | for path in failure_paths: 54 | failure_data = pkl.load( open(path, "rb")) 55 | for trans in failure_data: 56 | if "images" in trans['observations'].keys(): 57 | continue 58 | trans["labels"] = 0 59 | trans['actions'] = env.action_space.sample() 60 | neg_buffer.insert(trans) 61 | 62 | neg_iterator = neg_buffer.get_iterator(sample_args={"batch_size": FLAGS.batch_size // 2,}, device=sharding.replicate(),) 63 | 64 | print(f"failed buffer size: {len(neg_buffer)}") 65 | print(f"success buffer size: {len(pos_buffer)}") 66 | 67 | rng = jax.random.PRNGKey(0) 68 | rng, key = jax.random.split(rng) 69 | pos_sample = next(pos_iterator) 70 | neg_sample = next(neg_iterator) 71 | sample = concat_batches(pos_sample, neg_sample, axis=0) 72 | 73 | rng, key = jax.random.split(rng) 74 | classifier = create_classifier(key, sample["observations"], config.classifier_keys,) 75 | 76 | def data_augmentation_fn(rng, observations): 77 | for pixel_key in config.classifier_keys: 78 | observations = observations.copy( 79 | add_or_replace={ 80 | pixel_key: batched_random_crop( 81 | observations[pixel_key], rng, padding=4, num_batch_dims=2 82 | ) 83 | } 84 | ) 85 | return observations 86 | 87 | @jax.jit 88 | def train_step(state, batch, key): 89 | def loss_fn(params): 90 | logits = state.apply_fn({"params": params}, batch["observations"], rngs={"dropout": key}, train=True) 91 | return optax.sigmoid_binary_cross_entropy(logits, batch["labels"]).mean() 92 | 93 | grad_fn = jax.value_and_grad(loss_fn) 94 | loss, grads = grad_fn(state.params) 95 | logits = state.apply_fn({"params": state.params}, batch["observations"], train=False, rngs={"dropout": key}) 96 | train_accuracy = jnp.mean((nn.sigmoid(logits) >= 0.5) == batch["labels"]) 97 | 98 | return state.apply_gradients(grads=grads), loss, train_accuracy 99 | 100 | for epoch in tqdm(range(FLAGS.num_epochs)): 101 | # Sample equal number of positive and negative examples 102 | pos_sample = next(pos_iterator) 103 | neg_sample = next(neg_iterator) 104 | # Merge and create labels 105 | batch = concat_batches(pos_sample, neg_sample, axis=0) 106 | rng, key = jax.random.split(rng) 107 | obs = data_augmentation_fn(key, batch["observations"]) 108 | batch = batch.copy( 109 | add_or_replace={ 110 | "observations": obs, 111 | "labels": batch["labels"][..., None], 112 | } 113 | ) 114 | 115 | rng, key = jax.random.split(rng) 116 | 
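# NOTE (added comment, not in the original source): a fresh PRNG key is split each epoch
# so the random crops and dropout masks differ between updates; the jitted train_step
# returns the updated train state together with the loss and accuracy computed on this
# balanced positive/negative batch.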
classifier, train_loss, train_accuracy = train_step(classifier, batch, key) 117 | 118 | print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}") 119 | 120 | checkpoints.save_checkpoint(os.path.join(os.getcwd(), "classifier_ckpt/"), classifier, step=FLAGS.num_epochs, overwrite=True,) 121 | 122 | 123 | if __name__ == "__main__": 124 | app.run(main) -------------------------------------------------------------------------------- /serl_launcher/.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | build 3 | __pycache__ 4 | *.py[cod] 5 | *.so 6 | *.pyc 7 | wandb/ 8 | -------------------------------------------------------------------------------- /serl_launcher/README.md: -------------------------------------------------------------------------------- 1 | # Serl Launcher 2 | 3 | - Dependencies: `jax`, `agentlace` 4 | 5 | Code and scripts are modified from [jaxrl_m](https://github.com/dibyaghosh/jaxrl_m) or [jaxrl_m private].(https://github.com/rail-berkeley/jaxrl_minimal) 6 | -------------------------------------------------------------------------------- /serl_launcher/requirements.txt: -------------------------------------------------------------------------------- 1 | gym >= 0.26 2 | numpy>=1.24.3 3 | flax>=0.8.0 4 | distrax>=0.1.2 5 | ml_collections >= 0.1.0 6 | tqdm >= 4.60.0 7 | chex>=0.1.85 8 | optax>=0.1.5 9 | absl-py >= 0.12.0 10 | scipy==1.11.4 11 | wandb >= 0.12.14 12 | tensorflow>=2.15.0 13 | tensorflow_probability>=0.23.0 14 | einops >= 0.6.1 15 | imageio >= 2.31.1 16 | moviepy >= 1.0.3 17 | pre-commit==3.3.3 18 | gymnasium==0.29.1 19 | tf-keras 20 | pynput 21 | natsort 22 | matplotlib -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_launcher/serl_launcher/__init__.py -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .continuous.bc import BCAgent 2 | from .continuous.sac import SACAgent 3 | from .continuous.sac_hybrid_single import SACAgentHybridSingleArm 4 | from .continuous.sac_hybrid_dual import SACAgentHybridDualArm 5 | 6 | agents = { 7 | "bc": BCAgent, 8 | "sac": SACAgent, 9 | "sac_hybrid_single": SACAgentHybridSingleArm, 10 | "sac_hybrid_dual": SACAgentHybridDualArm, 11 | } 12 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/agents/continuous/bc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Iterable, Optional 3 | 4 | import flax 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | from flax.core import FrozenDict 11 | 12 | from serl_launcher.common.common import JaxRLTrainState, ModuleDict, nonpytree_field 13 | from serl_launcher.common.encoding import EncodingWrapper 14 | from serl_launcher.common.typing import Batch, PRNGKey 15 | from serl_launcher.networks.actor_critic_nets import Policy 16 | from serl_launcher.networks.mlp import MLP 17 | from serl_launcher.utils.train_utils import _unpack 18 | from serl_launcher.vision.data_augmentations 
import batched_random_crop 19 | 20 | 21 | class BCAgent(flax.struct.PyTreeNode): 22 | state: JaxRLTrainState 23 | config: dict = nonpytree_field() 24 | 25 | def data_augmentation_fn(self, rng, observations): 26 | for pixel_key in self.config["image_keys"]: 27 | observations = observations.copy( 28 | add_or_replace={ 29 | pixel_key: batched_random_crop( 30 | observations[pixel_key], rng, padding=4, num_batch_dims=2 31 | ) 32 | } 33 | ) 34 | return observations 35 | 36 | @partial(jax.jit, static_argnames="pmap_axis") 37 | def update(self, batch: Batch, pmap_axis: str = None): 38 | if self.config["image_keys"][0] not in batch["next_observations"]: 39 | batch = _unpack(batch) 40 | 41 | rng, aug_rng = jax.random.split(self.state.rng) 42 | if "augmentation_function" in self.config.keys() and self.config["augmentation_function"] is not None: 43 | batch = self.config["augmentation_function"](batch, aug_rng) 44 | 45 | def loss_fn(params, rng): 46 | rng, key = jax.random.split(rng) 47 | dist = self.state.apply_fn( 48 | {"params": params}, 49 | batch["observations"], 50 | temperature=1.0, 51 | train=True, 52 | rngs={"dropout": key}, 53 | name="actor", 54 | ) 55 | pi_actions = dist.mode() 56 | log_probs = dist.log_prob(batch["actions"]) 57 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 58 | actor_loss = -(log_probs).mean() 59 | 60 | return actor_loss, { 61 | "actor_loss": actor_loss, 62 | "mse": mse.mean(), 63 | } 64 | 65 | # compute gradients and update params 66 | new_state, info = self.state.apply_loss_fns( 67 | loss_fn, pmap_axis=pmap_axis, has_aux=True 68 | ) 69 | 70 | return self.replace(state=new_state), info 71 | 72 | def forward_policy(self, observations: np.ndarray, *, temperature: float = 1.0, non_squash_distribution: bool = False): 73 | dist = self.state.apply_fn( 74 | {"params": self.state.params}, 75 | observations, 76 | train=False, 77 | temperature=temperature, 78 | name="actor", 79 | non_squash_distribution=non_squash_distribution 80 | ) 81 | return dist 82 | 83 | @partial(jax.jit, static_argnames="argmax") 84 | def sample_actions( 85 | self, 86 | observations: np.ndarray, 87 | *, 88 | seed: Optional[PRNGKey] = None, 89 | temperature: float = 1.0, 90 | argmax=False, 91 | ) -> jnp.ndarray: 92 | dist = self.state.apply_fn( 93 | {"params": self.state.params}, 94 | observations, 95 | temperature=temperature, 96 | name="actor", 97 | ) 98 | if argmax: 99 | actions = dist.mode() 100 | else: 101 | actions = dist.sample(seed=seed) 102 | return actions 103 | 104 | @jax.jit 105 | def get_debug_metrics(self, batch, **kwargs): 106 | dist = self.state.apply_fn( 107 | {"params": self.state.params}, 108 | batch["observations"], 109 | temperature=1.0, 110 | name="actor", 111 | ) 112 | pi_actions = dist.mode() 113 | log_probs = dist.log_prob(batch["actions"]) 114 | mse = ((pi_actions - batch["actions"]) ** 2).sum(-1) 115 | 116 | return { 117 | "mse": mse, 118 | "log_probs": log_probs, 119 | "pi_actions": pi_actions, 120 | } 121 | 122 | @classmethod 123 | def create( 124 | cls, 125 | rng: PRNGKey, 126 | observations: FrozenDict, 127 | actions: jnp.ndarray, 128 | # Model architecture 129 | encoder_type: str = "resnet-pretrained", 130 | image_keys: Iterable[str] = ("image",), 131 | use_proprio: bool = False, 132 | network_kwargs: dict = { 133 | "hidden_dims": [256, 256], 134 | }, 135 | policy_kwargs: dict = { 136 | "tanh_squash_distribution": False, 137 | }, 138 | # Optimizer 139 | learning_rate: float = 3e-4, 140 | augmentation_function: Optional[callable] = None, 141 | ): 142 | if encoder_type 
== "resnet": 143 | from serl_launcher.vision.resnet_v1 import resnetv1_configs 144 | 145 | encoders = { 146 | image_key: resnetv1_configs["resnetv1-10"]( 147 | pooling_method="spatial_learned_embeddings", 148 | num_spatial_blocks=8, 149 | bottleneck_dim=256, 150 | name=f"encoder_{image_key}", 151 | ) 152 | for image_key in image_keys 153 | } 154 | elif encoder_type == "resnet-pretrained": 155 | from serl_launcher.vision.resnet_v1 import ( 156 | PreTrainedResNetEncoder, 157 | resnetv1_configs, 158 | ) 159 | 160 | pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"]( 161 | pre_pooling=True, 162 | name="pretrained_encoder", 163 | ) 164 | encoders = { 165 | image_key: PreTrainedResNetEncoder( 166 | pooling_method="spatial_learned_embeddings", 167 | num_spatial_blocks=8, 168 | bottleneck_dim=256, 169 | pretrained_encoder=pretrained_encoder, 170 | name=f"encoder_{image_key}", 171 | ) 172 | for image_key in image_keys 173 | } 174 | else: 175 | raise NotImplementedError(f"Unknown encoder type: {encoder_type}") 176 | 177 | encoder_def = EncodingWrapper( 178 | encoder=encoders, 179 | use_proprio=use_proprio, 180 | enable_stacking=True, 181 | image_keys=image_keys, 182 | ) 183 | 184 | network_kwargs["activate_final"] = True 185 | networks = { 186 | "actor": Policy( 187 | encoder_def, 188 | MLP(**network_kwargs), 189 | action_dim=actions.shape[-1], 190 | **policy_kwargs, 191 | ) 192 | } 193 | 194 | model_def = ModuleDict(networks) 195 | 196 | tx = optax.adam(learning_rate) 197 | 198 | rng, init_rng = jax.random.split(rng) 199 | params = model_def.init(init_rng, actor=[observations])["params"] 200 | 201 | rng, create_rng = jax.random.split(rng) 202 | state = JaxRLTrainState.create( 203 | apply_fn=model_def.apply, 204 | params=params, 205 | txs=tx, 206 | target_params=params, 207 | rng=create_rng, 208 | ) 209 | config = dict( 210 | image_keys=image_keys, 211 | augmentation_function=augmentation_function 212 | ) 213 | 214 | agent = cls(state, config) 215 | 216 | if encoder_type == "resnet-pretrained": # load pretrained weights for ResNet-10 217 | from serl_launcher.utils.train_utils import load_resnet10_params 218 | agent = load_resnet10_params(agent, image_keys) 219 | 220 | return agent 221 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/common/encoding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, Optional, Tuple 2 | 3 | import flax 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | from einops import rearrange, repeat 8 | from octo.model.octo_module import OctoTransformer 9 | from octo.utils.typing import Config, Data, Params, PRNGKey, Sequence 10 | 11 | class EncodingWrapper(nn.Module): 12 | """ 13 | Encodes observations into a single flat encoding, adding additional 14 | functionality for adding proprioception and stopping the gradient. 15 | 16 | Args: 17 | encoder: The encoder network. 18 | use_proprio: Whether to concatenate proprioception (after encoding). 
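        proprio_latent_dim: Size of the learned projection applied to the proprio state.
        enable_stacking: Whether stacked frames are folded into the channel dimension.
        image_keys: Observation keys treated as images and passed through the encoder.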
19 | """ 20 | 21 | encoder: nn.Module 22 | use_proprio: bool 23 | proprio_latent_dim: int = 64 24 | enable_stacking: bool = False 25 | image_keys: Iterable[str] = ("image",) 26 | 27 | @nn.compact 28 | def __call__( 29 | self, 30 | observations: Dict[str, jnp.ndarray], 31 | train=False, 32 | stop_gradient=False, 33 | is_encoded=False, 34 | ) -> jnp.ndarray: 35 | # encode images with encoder 36 | encoded = [] 37 | for image_key in self.image_keys: 38 | image = observations[image_key] 39 | if not is_encoded: 40 | if self.enable_stacking: 41 | # Combine stacking and channels into a single dimension 42 | if len(image.shape) == 4: 43 | T = image.shape[0] 44 | if T > 1: 45 | image = image[-1:] # for stacked images, only use the last frame 46 | image = rearrange(image, "T H W C -> H W (T C)") 47 | if len(image.shape) == 5: 48 | T = image.shape[1] 49 | if T > 1: 50 | image = image[:, -1:] # for stacked images, only use the last frame 51 | image = rearrange(image, "B T H W C -> B H W (T C)") 52 | 53 | image = self.encoder[image_key](image, train=train, encode=not is_encoded) 54 | 55 | if stop_gradient: 56 | image = jax.lax.stop_gradient(image) 57 | 58 | encoded.append(image) 59 | 60 | encoded = jnp.concatenate(encoded, axis=-1) 61 | 62 | if self.use_proprio: 63 | # project state to embeddings as well 64 | state = observations["state"] 65 | if self.enable_stacking: 66 | # Combine stacking and channels into a single dimension 67 | if len(state.shape) == 2: 68 | state = rearrange(state, "T C -> (T C)") 69 | encoded = encoded.reshape(-1) 70 | if len(state.shape) == 3: 71 | state = rearrange(state, "B T C -> B (T C)") 72 | state = nn.Dense( 73 | self.proprio_latent_dim, kernel_init=nn.initializers.xavier_uniform() 74 | )(state) 75 | state = nn.LayerNorm()(state) 76 | state = nn.tanh(state) 77 | encoded = jnp.concatenate([encoded, state], axis=-1) 78 | 79 | return encoded 80 | 81 | class OctoEncodingWrapper(nn.Module): 82 | """ 83 | Encodes observations into a single flat encoding, adding additional 84 | functionality for adding proprioception and stopping the gradient. 85 | 86 | Args: 87 | encoder: The encoder network. 88 | use_proprio: Whether to concatenate proprioception (after encoding). 89 | """ 90 | 91 | encoder: OctoTransformer 92 | use_proprio: bool 93 | proprio_latent_dim: int = 64 94 | enable_stacking: bool = False 95 | image_keys: Iterable[str] = ("image",) 96 | 97 | @nn.compact 98 | def __call__( 99 | self, 100 | observations: Dict[str, jnp.ndarray], 101 | tasks: Data=None, 102 | action_embeddings: jnp.ndarray=None, 103 | train=True, 104 | stop_gradient=False, 105 | ) -> jnp.ndarray: 106 | if action_embeddings is None: 107 | image_primary = observations["side_policy_256"] 108 | image_wrist = observations["wrist_1"] 109 | if image_primary.ndim == 4: 110 | image_primary = image_primary[jnp.newaxis, ...] 111 | image_wrist = image_wrist[jnp.newaxis, ...] 
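            # images are expected as (batch, window, H, W, C); the newaxis above adds the
            # batch dimension when a single unbatched window is provided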
112 | batch_size = image_primary.shape[0] 113 | window_size = image_primary.shape[1] 114 | timestep_pad_mask = jnp.ones((batch_size, window_size), dtype=bool) 115 | 116 | if not stop_gradient: 117 | def mask_image(image, mask_flag): 118 | return jax.lax.cond( 119 | mask_flag, 120 | lambda _: jnp.zeros_like(image), 121 | lambda _: image, 122 | operand=None) 123 | 124 | mask_flags = jax.random.bernoulli(self.make_rng('mask_wrist'), p=0.2, shape=(batch_size,)) 125 | image_wrist = jax.vmap(mask_image)(image_wrist, mask_flags) 126 | 127 | observation_octo = {"image_primary": image_primary, 128 | "image_wrist": image_wrist, 129 | "timestep_pad_mask": timestep_pad_mask, 130 | } 131 | 132 | transformer_outputs = self.encoder(observation_octo, tasks, timestep_pad_mask, train=not stop_gradient) 133 | token_group = transformer_outputs["readout_action"] 134 | action_embeddings = token_group.tokens.mean(axis=-2) 135 | 136 | action_embeddings = action_embeddings[:, -1, :] # remove window_size dimension 137 | else: 138 | action_embeddings = action_embeddings 139 | 140 | if stop_gradient: 141 | action_embeddings = jax.lax.stop_gradient(action_embeddings) 142 | 143 | encoded = action_embeddings 144 | if self.use_proprio: 145 | # project state to embeddings as well 146 | state = observations["state"] 147 | if self.enable_stacking: 148 | # Combine stacking and channels into a single dimension 149 | if len(state.shape) == 2: 150 | state = rearrange(state, "T C -> (T C)") 151 | encoded = encoded.reshape(-1) 152 | if len(state.shape) == 3: 153 | state = rearrange(state, "B T C -> B (T C)") 154 | state = nn.Dense( 155 | self.proprio_latent_dim, kernel_init=nn.initializers.xavier_uniform() 156 | )(state) 157 | state = nn.LayerNorm()(state) 158 | state = nn.tanh(state) 159 | encoded = jnp.concatenate([encoded, state], axis=-1) 160 | 161 | return encoded, action_embeddings 162 | 163 | 164 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/common/evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | from typing import Dict 4 | 5 | import gymnasium as gym 6 | import jax 7 | import numpy as np 8 | 9 | 10 | def supply_rng(f, rng=jax.random.PRNGKey(0)): 11 | def wrapped(*args, **kwargs): 12 | nonlocal rng 13 | rng, key = jax.random.split(rng) 14 | return f(*args, seed=key, **kwargs) 15 | 16 | return wrapped 17 | 18 | 19 | def flatten(d, parent_key="", sep="."): 20 | items = [] 21 | for k, v in d.items(): 22 | new_key = parent_key + sep + k if parent_key else k 23 | if hasattr(v, "items"): 24 | items.extend(flatten(v, new_key, sep=sep).items()) 25 | else: 26 | items.append((new_key, v)) 27 | return dict(items) 28 | 29 | 30 | def filter_info(info): 31 | filter_keys = [ 32 | "object_names", 33 | "target_object", 34 | "initial_positions", 35 | "target_position", 36 | "goal", 37 | ] 38 | for k in filter_keys: 39 | if k in info: 40 | del info[k] 41 | return info 42 | 43 | 44 | def add_to(dict_of_lists, single_dict): 45 | for k, v in single_dict.items(): 46 | dict_of_lists[k].append(v) 47 | 48 | 49 | def evaluate(policy_fn, env: gym.Env, num_episodes: int) -> Dict[str, float]: 50 | stats = defaultdict(list) 51 | for _ in range(num_episodes): 52 | observation, info = env.reset() 53 | add_to(stats, flatten(info)) 54 | done = False 55 | while not done: 56 | action = policy_fn(observation) 57 | observation, _, terminated, truncated, info = env.step(action) 58 | done = 
terminated or truncated 59 | add_to(stats, flatten(info)) 60 | add_to(stats, flatten(info, parent_key="final")) 61 | 62 | for k, v in stats.items(): 63 | stats[k] = np.mean(v) 64 | return stats 65 | 66 | 67 | def evaluate_with_trajectories( 68 | policy_fn, env: gym.Env, num_episodes: int 69 | ) -> Dict[str, float]: 70 | trajectories = [] 71 | stats = defaultdict(list) 72 | 73 | for _ in range(num_episodes): 74 | trajectory = defaultdict(list) 75 | observation, info = env.reset() 76 | add_to(stats, flatten(info)) 77 | done = False 78 | while not done: 79 | action = policy_fn(observation) 80 | next_observation, r, terminated, truncated, info = env.step(action) 81 | done = terminated or truncated 82 | transition = dict( 83 | observation=observation, 84 | next_observation=next_observation, 85 | action=action, 86 | reward=r, 87 | done=done, 88 | info=info, 89 | ) 90 | add_to(trajectory, transition) 91 | add_to(stats, flatten(info)) 92 | observation = next_observation 93 | add_to(stats, flatten(info, parent_key="final")) 94 | trajectories.append(trajectory) 95 | 96 | for k, v in stats.items(): 97 | stats[k] = np.mean(v) 98 | return stats, trajectories 99 | 100 | 101 | def bootstrap_std(arr, f=np.mean, n=30): 102 | arr = np.array(arr) 103 | return np.std([f(arr[np.random.choice(len(arr), len(arr))]) for _ in range(n)]) 104 | 105 | 106 | def parallel_evaluate(policy_fn, eval_envs, num_eval, verbose=True): 107 | n_envs = len(eval_envs.reset()) 108 | eval_episode_rewards = [] 109 | eval_episode_time_rewards = [] 110 | counter = np.zeros(n_envs) 111 | 112 | obs = eval_envs.reset() 113 | if verbose: 114 | print("Evaluating Envs") 115 | n_per = int(math.ceil(num_eval / n_envs)) 116 | n_to_eval = n_per * n_envs 117 | while len(eval_episode_rewards) < n_to_eval: 118 | action = policy_fn(obs) 119 | 120 | # Observe reward and next obs 121 | obs, _, done, infos = eval_envs.step(action) 122 | 123 | for n, info in enumerate(infos): 124 | if "episode" in info.keys() and counter[n] < n_per: 125 | eval_episode_rewards.append(info["episode"]["r"]) 126 | eval_episode_time_rewards.append(info["episode"]["time_r"]) 127 | counter[n] += 1 128 | if verbose: 129 | print( 130 | f"Evaluation using {len(eval_episode_rewards)} episodes: mean reward {np.mean(eval_episode_rewards):.5f} +- {bootstrap_std(eval_episode_rewards):.5f} \n" 131 | ) 132 | return eval_episode_rewards, eval_episode_time_rewards 133 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/common/optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import optax 4 | 5 | 6 | def make_optimizer( 7 | learning_rate: float = 3e-4, 8 | warmup_steps: int = 0, 9 | cosine_decay_steps: Optional[int] = None, 10 | weight_decay: Optional[float] = None, 11 | clip_grad_norm: Optional[float] = None, 12 | return_lr_schedule: bool = False, 13 | ) -> optax.GradientTransformation: 14 | if cosine_decay_steps is not None: 15 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 16 | init_value=0.0, 17 | peak_value=learning_rate, 18 | warmup_steps=warmup_steps, 19 | decay_steps=cosine_decay_steps, 20 | end_value=0.0, 21 | ) 22 | else: 23 | learning_rate_schedule = optax.join_schedules( 24 | [ 25 | optax.linear_schedule(0.0, learning_rate, warmup_steps), 26 | optax.constant_schedule(learning_rate), 27 | ], 28 | [warmup_steps], 29 | ) 30 | 31 | # Define optimizers 32 | @optax.inject_hyperparams 33 | def optimizer(learning_rate: float, 
weight_decay: Optional[float]): 34 | optimizer_stages = [] 35 | 36 | if clip_grad_norm is not None: 37 | optimizer_stages.append(optax.clip_by_global_norm(clip_grad_norm)) 38 | 39 | if weight_decay is not None: 40 | optimizer_stages.append( 41 | optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay) 42 | ) 43 | else: 44 | optimizer_stages.append(optax.adam(learning_rate=learning_rate)) 45 | 46 | return optax.chain(*optimizer_stages) 47 | 48 | if return_lr_schedule: 49 | return ( 50 | optimizer(learning_rate=learning_rate_schedule, weight_decay=weight_decay), 51 | learning_rate_schedule, 52 | ) 53 | else: 54 | return optimizer( 55 | learning_rate=learning_rate_schedule, weight_decay=weight_decay 56 | ) 57 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/common/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Sequence, Union 2 | 3 | import flax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | PRNGKey = Any 9 | Params = flax.core.FrozenDict[str, Any] 10 | Shape = Sequence[int] 11 | Dtype = Any # this could be a real type? 12 | InfoDict = Dict[str, float] 13 | Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] 14 | Data = Union[Array, Dict[str, "Data"]] 15 | Batch = Dict[str, Data] 16 | # A method to be passed into TrainState.__call__ 17 | ModuleMethod = Union[str, Callable, None] 18 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/common/wandb.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import tempfile 3 | from copy import copy 4 | from socket import gethostname 5 | 6 | import absl.flags as flags 7 | import ml_collections 8 | import wandb 9 | 10 | 11 | def _recursive_flatten_dict(d: dict): 12 | keys, values = [], [] 13 | for key, value in d.items(): 14 | if isinstance(value, dict): 15 | sub_keys, sub_values = _recursive_flatten_dict(value) 16 | keys += [f"{key}/{k}" for k in sub_keys] 17 | values += sub_values 18 | else: 19 | keys.append(key) 20 | values.append(value) 21 | return keys, values 22 | 23 | 24 | class WandBLogger(object): 25 | @staticmethod 26 | def get_default_config(): 27 | config = ml_collections.ConfigDict() 28 | config.project = "serl_launcher" # WandB Project Name 29 | config.entity = ml_collections.config_dict.FieldReference(None, field_type=str) 30 | # Which entity to log as (default: your own user) 31 | config.exp_descriptor = "" # Run name (doesn't have to be unique) 32 | # Unique identifier for run (will be automatically generated unless 33 | # provided) 34 | config.unique_identifier = "" 35 | config.group = None 36 | return config 37 | 38 | def __init__( 39 | self, 40 | wandb_config, 41 | variant, 42 | wandb_output_dir=None, 43 | debug=False, 44 | ): 45 | self.config = wandb_config 46 | if self.config.unique_identifier == "": 47 | self.config.unique_identifier = datetime.datetime.now().strftime( 48 | "%Y%m%d_%H%M%S" 49 | ) 50 | 51 | self.config.experiment_id = ( 52 | self.experiment_id 53 | ) = f"{self.config.exp_descriptor}_{self.config.unique_identifier}" # NOQA 54 | 55 | print(self.config) 56 | 57 | if wandb_output_dir is None: 58 | wandb_output_dir = tempfile.mkdtemp() 59 | 60 | self._variant = copy(variant) 61 | 62 | if "hostname" not in self._variant: 63 | self._variant["hostname"] = gethostname() 64 | 65 | if debug: 66 | mode = "disabled" 
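            # debug runs disable wandb entirely; otherwise the run is synced online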
67 | else: 68 | mode = "online" 69 | 70 | self.run = wandb.init( 71 | config=self._variant, 72 | project=self.config.project, 73 | entity=self.config.entity, 74 | group=self.config.group, 75 | tags=self.config.tag, 76 | dir=wandb_output_dir, 77 | id=self.config.experiment_id, 78 | save_code=True, 79 | mode=mode, 80 | ) 81 | 82 | if flags.FLAGS.is_parsed(): 83 | flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS} 84 | else: 85 | flag_dict = {} 86 | for k in flag_dict: 87 | if isinstance(flag_dict[k], ml_collections.ConfigDict): 88 | flag_dict[k] = flag_dict[k].to_dict() 89 | wandb.config.update(flag_dict) 90 | 91 | def log(self, data: dict, step: int = None): 92 | data_flat = _recursive_flatten_dict(data) 93 | data = {k: v for k, v in zip(*data_flat)} 94 | wandb.log(data, step=step) 95 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_launcher/serl_launcher/data/__init__.py -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/data/data_store.py: -------------------------------------------------------------------------------- 1 | from threading import Lock 2 | from typing import Union, Iterable 3 | 4 | import gymnasium as gym 5 | import jax 6 | from serl_launcher.data.replay_buffer import ReplayBuffer 7 | from serl_launcher.data.memory_efficient_replay_buffer import ( 8 | MemoryEfficientReplayBuffer, 9 | ) 10 | 11 | from agentlace.data.data_store import DataStoreBase 12 | 13 | 14 | class ReplayBufferDataStore(ReplayBuffer, DataStoreBase): 15 | def __init__( 16 | self, 17 | observation_space: gym.Space, 18 | action_space: gym.Space, 19 | capacity: int, 20 | ): 21 | ReplayBuffer.__init__(self, observation_space, action_space, capacity) 22 | DataStoreBase.__init__(self, capacity) 23 | self._lock = Lock() 24 | 25 | # ensure thread safety 26 | def insert(self, *args, **kwargs): 27 | with self._lock: 28 | super(ReplayBufferDataStore, self).insert(*args, **kwargs) 29 | 30 | # ensure thread safety 31 | def sample(self, *args, **kwargs): 32 | with self._lock: 33 | return super(ReplayBufferDataStore, self).sample(*args, **kwargs) 34 | 35 | # NOTE: method for DataStoreBase 36 | def latest_data_id(self): 37 | return self._insert_index 38 | 39 | # NOTE: method for DataStoreBase 40 | def get_latest_data(self, from_id: int): 41 | raise NotImplementedError # TODO 42 | 43 | 44 | class MemoryEfficientReplayBufferDataStore(MemoryEfficientReplayBuffer, DataStoreBase): 45 | def __init__( 46 | self, 47 | observation_space: gym.Space, 48 | action_space: gym.Space, 49 | capacity: int, 50 | image_keys: Iterable[str] = ("image",), 51 | **kwargs, 52 | ): 53 | MemoryEfficientReplayBuffer.__init__( 54 | self, observation_space, action_space, capacity, pixel_keys=image_keys, **kwargs 55 | ) 56 | DataStoreBase.__init__(self, capacity) 57 | self._lock = Lock() 58 | 59 | # ensure thread safety 60 | def insert(self, *args, **kwargs): 61 | with self._lock: 62 | super(MemoryEfficientReplayBufferDataStore, self).insert(*args, **kwargs) 63 | 64 | # ensure thread safety 65 | def sample(self, *args, **kwargs): 66 | with self._lock: 67 | return super(MemoryEfficientReplayBufferDataStore, self).sample( 68 | *args, **kwargs 69 | ) 70 | 71 | # NOTE: method for DataStoreBase 72 | def latest_data_id(self): 73 | return 
self._insert_index 74 | 75 | # NOTE: method for DataStoreBase 76 | def get_latest_data(self, from_id: int): 77 | raise NotImplementedError # TODO 78 | 79 | 80 | def populate_data_store( 81 | data_store: DataStoreBase, 82 | demos_path: str, 83 | ): 84 | """ 85 | Utility function to populate demonstrations data into data_store. 86 | :return data_store 87 | """ 88 | import pickle as pkl 89 | import numpy as np 90 | from copy import deepcopy 91 | 92 | for demo_path in demos_path: 93 | with open(demo_path, "rb") as f: 94 | demo = pkl.load(f) 95 | for transition in demo: 96 | data_store.insert(transition) 97 | print(f"Loaded {len(data_store)} transitions.") 98 | return data_store 99 | 100 | 101 | def populate_data_store_with_z_axis_only( 102 | data_store: DataStoreBase, 103 | demos_path: str, 104 | ): 105 | """ 106 | Utility function to populate demonstrations data into data_store. 107 | This will remove the x and y cartesian coordinates from the state. 108 | :return data_store 109 | """ 110 | import pickle as pkl 111 | import numpy as np 112 | from copy import deepcopy 113 | 114 | for demo_path in demos_path: 115 | with open(demo_path, "rb") as f: 116 | demo = pkl.load(f) 117 | for transition in demo: 118 | tmp = deepcopy(transition) 119 | tmp["observations"]["state"] = np.concatenate( 120 | ( 121 | tmp["observations"]["state"][:, :4], 122 | tmp["observations"]["state"][:, 6][None, ...], 123 | tmp["observations"]["state"][:, 10:], 124 | ), 125 | axis=-1, 126 | ) 127 | tmp["next_observations"]["state"] = np.concatenate( 128 | ( 129 | tmp["next_observations"]["state"][:, :4], 130 | tmp["next_observations"]["state"][:, 6][None, ...], 131 | tmp["next_observations"]["state"][:, 10:], 132 | ), 133 | axis=-1, 134 | ) 135 | data_store.insert(tmp) 136 | print(f"Loaded {len(data_store)} transitions.") 137 | return data_store 138 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/data/dataset.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Dict, Iterable, Optional, Tuple, Union 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from flax.core import frozen_dict 8 | from gymnasium.utils import seeding 9 | 10 | DataType = Union[np.ndarray, Dict[str, "DataType"]] 11 | DatasetDict = Dict[str, DataType] 12 | 13 | 14 | def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: 15 | for v in dataset_dict.values(): 16 | if isinstance(v, dict): 17 | dataset_len = dataset_len or _check_lengths(v, dataset_len) 18 | elif isinstance(v, np.ndarray): 19 | item_len = len(v) 20 | dataset_len = dataset_len or item_len 21 | assert dataset_len == item_len, "Inconsistent item lengths in the dataset." 
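            # dict values recurse above; every array leaf must share the same leading
            # length, and any other value type is rejected below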
22 | else: 23 | raise TypeError("Unsupported type.") 24 | return dataset_len 25 | 26 | 27 | def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: 28 | new_dataset_dict = {} 29 | for k, v in dataset_dict.items(): 30 | if isinstance(v, dict): 31 | new_v = _subselect(v, index) 32 | elif isinstance(v, np.ndarray): 33 | new_v = v[index] 34 | else: 35 | raise TypeError("Unsupported type.") 36 | new_dataset_dict[k] = new_v 37 | return new_dataset_dict 38 | 39 | 40 | def _sample( 41 | dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray 42 | ) -> DatasetDict: 43 | if isinstance(dataset_dict, np.ndarray): 44 | return dataset_dict[indx] 45 | elif isinstance(dataset_dict, dict): 46 | batch = {} 47 | for k, v in dataset_dict.items(): 48 | batch[k] = _sample(v, indx) 49 | else: 50 | raise TypeError("Unsupported type.") 51 | return batch 52 | 53 | 54 | class Dataset(object): 55 | def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): 56 | self.dataset_dict = dataset_dict 57 | self.dataset_len = _check_lengths(dataset_dict) 58 | 59 | # Seeding similar to OpenAI Gym: 60 | # https://github.com/openai/gym/blob/master/gym/spaces/space.py#L46 61 | self._np_random = None 62 | self._seed = None 63 | if seed is not None: 64 | self.seed(seed) 65 | 66 | @property 67 | def np_random(self) -> np.random.RandomState: 68 | if self._np_random is None: 69 | self.seed() 70 | return self._np_random 71 | 72 | def seed(self, seed: Optional[int] = None) -> list: 73 | self._np_random, self._seed = seeding.np_random(seed) 74 | return [self._seed] 75 | 76 | def __len__(self) -> int: 77 | return self.dataset_len 78 | 79 | def sample( 80 | self, 81 | batch_size: int, 82 | keys: Optional[Iterable[str]] = None, 83 | indx: Optional[np.ndarray] = None, 84 | ) -> frozen_dict.FrozenDict: 85 | if indx is None: 86 | if hasattr(self.np_random, "integers"): 87 | indx = self.np_random.integers(len(self), size=batch_size) 88 | else: 89 | indx = self.np_random.randint(len(self), size=batch_size) 90 | 91 | batch = dict() 92 | 93 | if keys is None: 94 | keys = self.dataset_dict.keys() 95 | 96 | for k in keys: 97 | if isinstance(self.dataset_dict[k], dict): 98 | batch[k] = _sample(self.dataset_dict[k], indx) 99 | else: 100 | batch[k] = self.dataset_dict[k][indx] 101 | 102 | return frozen_dict.freeze(batch) 103 | 104 | def sample_jax(self, batch_size: int, keys: Optional[Iterable[str]] = None): 105 | if not hasattr(self, "rng"): 106 | self.rng = jax.random.PRNGKey(self._seed or 42) 107 | 108 | if keys is None: 109 | keys = self.dataset_dict.keys() 110 | 111 | # jax_dataset_dict = {k: self.dataset_dict[k] for k in keys} 112 | # jax_dataset_dict = jax.device_put(jax_dataset_dict) 113 | 114 | @jax.jit 115 | def _sample_jax(rng, src, max_indx: int): 116 | key, rng = jax.random.split(rng) 117 | indx = jax.random.randint(key, (batch_size,), minval=0, maxval=max_indx) 118 | return ( 119 | rng, 120 | indx.max(), 121 | jax.tree_map(lambda d: jnp.take(d, indx, axis=0), src), 122 | ) 123 | 124 | self._sample_jax = _sample_jax 125 | 126 | self.rng, indx_max, sample = self._sample_jax( 127 | self.rng, self.dataset_dict, len(self) 128 | ) 129 | return indx_max, sample 130 | 131 | def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: 132 | assert 0 < ratio and ratio < 1 133 | train_index = np.index_exp[: int(self.dataset_len * ratio)] 134 | test_index = np.index_exp[int(self.dataset_len * ratio) :] 135 | 136 | index = np.arange(len(self), dtype=np.int32) 137 | self.np_random.shuffle(index) 138 | 
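        # the shuffled indices below overwrite the contiguous np.index_exp slices above,
        # so the train/test split is random rather than a prefix/suffix of the buffer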
train_index = index[: int(self.dataset_len * ratio)] 139 | test_index = index[int(self.dataset_len * ratio) :] 140 | 141 | train_dataset_dict = _subselect(self.dataset_dict, train_index) 142 | test_dataset_dict = _subselect(self.dataset_dict, test_index) 143 | return Dataset(train_dataset_dict), Dataset(test_dataset_dict) 144 | 145 | def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: 146 | episode_starts = [0] 147 | episode_ends = [] 148 | 149 | episode_return = 0 150 | episode_returns = [] 151 | 152 | for i in range(len(self)): 153 | episode_return += self.dataset_dict["rewards"][i] 154 | 155 | if self.dataset_dict["dones"][i]: 156 | episode_returns.append(episode_return) 157 | episode_ends.append(i + 1) 158 | if i + 1 < len(self): 159 | episode_starts.append(i + 1) 160 | episode_return = 0.0 161 | 162 | return episode_starts, episode_ends, episode_returns 163 | 164 | def filter( 165 | self, take_top: Optional[float] = None, threshold: Optional[float] = None 166 | ): 167 | assert (take_top is None and threshold is not None) or ( 168 | take_top is not None and threshold is None 169 | ) 170 | 171 | ( 172 | episode_starts, 173 | episode_ends, 174 | episode_returns, 175 | ) = self._trajectory_boundaries_and_returns() 176 | 177 | if take_top is not None: 178 | threshold = np.percentile(episode_returns, 100 - take_top) 179 | 180 | bool_indx = np.full((len(self),), False, dtype=bool) 181 | 182 | for i in range(len(episode_returns)): 183 | if episode_returns[i] >= threshold: 184 | bool_indx[episode_starts[i] : episode_ends[i]] = True 185 | 186 | self.dataset_dict = _subselect(self.dataset_dict, bool_indx) 187 | 188 | self.dataset_len = _check_lengths(self.dataset_dict) 189 | 190 | def normalize_returns(self, scaling: float = 1000): 191 | (_, _, episode_returns) = self._trajectory_boundaries_and_returns() 192 | self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min( 193 | episode_returns 194 | ) 195 | self.dataset_dict["rewards"] *= scaling 196 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterable, Optional, Tuple 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | from serl_launcher.data.dataset import DatasetDict, _sample 7 | from serl_launcher.data.replay_buffer import ReplayBuffer 8 | from flax.core import frozen_dict 9 | from gymnasium.spaces import Box 10 | 11 | 12 | class MemoryEfficientReplayBuffer(ReplayBuffer): 13 | def __init__( 14 | self, 15 | observation_space: gym.Space, 16 | action_space: gym.Space, 17 | capacity: int, 18 | pixel_keys: Tuple[str, ...] 
= ("pixels",), 19 | include_next_actions: Optional[bool] = False, 20 | include_grasp_penalty: Optional[bool] = False, 21 | include_octo_embeddings: Optional[bool] = False, 22 | include_mc_returns: Optional[bool] = False, 23 | ): 24 | self.pixel_keys = pixel_keys 25 | 26 | observation_space = copy.deepcopy(observation_space) 27 | self._num_stack = None 28 | for pixel_key in self.pixel_keys: 29 | pixel_obs_space = observation_space.spaces[pixel_key] 30 | if self._num_stack is None: 31 | self._num_stack = pixel_obs_space.shape[0] 32 | else: 33 | assert self._num_stack == pixel_obs_space.shape[0] 34 | self._unstacked_dim_size = pixel_obs_space.shape[-1] 35 | low = pixel_obs_space.low[0] 36 | high = pixel_obs_space.high[0] 37 | unstacked_pixel_obs_space = Box( 38 | low=low, high=high, dtype=pixel_obs_space.dtype 39 | ) 40 | observation_space.spaces[pixel_key] = unstacked_pixel_obs_space 41 | 42 | next_observation_space_dict = copy.deepcopy(observation_space.spaces) 43 | for pixel_key in self.pixel_keys: 44 | next_observation_space_dict.pop(pixel_key) 45 | next_observation_space = gym.spaces.Dict(next_observation_space_dict) 46 | 47 | self._first = True 48 | self._is_correct_index = np.full(capacity, False, dtype=bool) 49 | 50 | super().__init__( 51 | observation_space, 52 | action_space, 53 | capacity, 54 | next_observation_space=next_observation_space, 55 | include_next_actions=include_next_actions, 56 | include_grasp_penalty=include_grasp_penalty, 57 | include_octo_embeddings=include_octo_embeddings, 58 | include_mc_returns=include_mc_returns, 59 | ) 60 | 61 | def insert(self, data_dict: DatasetDict): 62 | if self._insert_index == 0 and self._capacity == len(self) and not self._first: 63 | indxs = np.arange(len(self) - self._num_stack, len(self)) 64 | for indx in indxs: 65 | element = super().sample(1, indx=indx) 66 | self._is_correct_index[self._insert_index] = False 67 | super().insert(element) 68 | 69 | data_dict = data_dict.copy() 70 | data_dict["observations"] = data_dict["observations"].copy() 71 | data_dict["next_observations"] = data_dict["next_observations"].copy() 72 | 73 | obs_pixels = {} 74 | next_obs_pixels = {} 75 | for pixel_key in self.pixel_keys: 76 | obs_pixels[pixel_key] = data_dict["observations"].pop(pixel_key) 77 | next_obs_pixels[pixel_key] = data_dict["next_observations"].pop(pixel_key) 78 | 79 | if self._first: 80 | for i in range(self._num_stack): 81 | for pixel_key in self.pixel_keys: 82 | data_dict["observations"][pixel_key] = obs_pixels[pixel_key][i] 83 | 84 | self._is_correct_index[self._insert_index] = False 85 | super().insert(data_dict) 86 | 87 | for pixel_key in self.pixel_keys: 88 | data_dict["observations"][pixel_key] = next_obs_pixels[pixel_key][-1] 89 | 90 | self._first = data_dict["dones"] 91 | 92 | self._is_correct_index[self._insert_index] = True 93 | 94 | super().insert(data_dict) 95 | 96 | for i in range(self._num_stack): 97 | indx = (self._insert_index + i) % len(self) 98 | self._is_correct_index[indx] = False 99 | 100 | def sample( 101 | self, 102 | batch_size: int, 103 | keys: Optional[Iterable[str]] = None, 104 | indx: Optional[np.ndarray] = None, 105 | pack_obs: bool = False, 106 | ) -> frozen_dict.FrozenDict: 107 | """Samples from the replay buffer. 108 | 109 | Args: 110 | batch_size: Minibatch size. 111 | keys: Keys to sample. 112 | indx: Take indices instead of sampling. 113 | pack_obs: whether to pack img and next_img into one image. 114 | It's useful when they have overlapping frames. 115 | 116 | Returns: 117 | A frozen dictionary. 
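            Image observations are rebuilt from the unstacked frames with a sliding
            window over the buffer, so full frame stacks are never stored twice.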
118 | """ 119 | 120 | if indx is None: 121 | if hasattr(self.np_random, "integers"): 122 | indx = self.np_random.integers(len(self), size=batch_size) 123 | else: 124 | indx = self.np_random.randint(len(self), size=batch_size) 125 | 126 | for i in range(batch_size): 127 | while not self._is_correct_index[indx[i]]: 128 | if hasattr(self.np_random, "integers"): 129 | indx[i] = self.np_random.integers(len(self)) 130 | else: 131 | indx[i] = self.np_random.randint(len(self)) 132 | else: 133 | raise NotImplementedError() 134 | 135 | if keys is None: 136 | keys = self.dataset_dict.keys() 137 | else: 138 | assert "observations" in keys 139 | 140 | keys = list(keys) 141 | keys.remove("observations") 142 | batch = super().sample(batch_size, keys, indx) 143 | batch = batch.unfreeze() 144 | 145 | obs_keys = self.dataset_dict["observations"].keys() 146 | obs_keys = list(obs_keys) 147 | for pixel_key in self.pixel_keys: 148 | obs_keys.remove(pixel_key) 149 | 150 | batch["observations"] = {} 151 | for k in obs_keys: 152 | batch["observations"][k] = _sample(self.dataset_dict["observations"][k], indx) 153 | 154 | for pixel_key in self.pixel_keys: 155 | obs_pixels = self.dataset_dict["observations"][pixel_key] 156 | obs_pixels = np.lib.stride_tricks.sliding_window_view(obs_pixels, self._num_stack + 1, axis=0) 157 | obs_pixels = obs_pixels[indx - self._num_stack] 158 | # transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention 159 | obs_pixels = obs_pixels.transpose((0, 4, 1, 2, 3)) 160 | 161 | if pack_obs: 162 | batch["observations"][pixel_key] = obs_pixels 163 | else: 164 | batch["observations"][pixel_key] = obs_pixels[:, :-1, ...] 165 | if "next_observations" in keys: 166 | batch["next_observations"][pixel_key] = obs_pixels[:, 1:, ...] 167 | 168 | return frozen_dict.freeze(batch) 169 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/data/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Any, Iterator, Optional, Sequence, Tuple, Union 3 | 4 | import gymnasium as gym 5 | import jax 6 | import numpy as np 7 | from serl_launcher.data.dataset import Dataset, DatasetDict 8 | 9 | 10 | def _init_replay_dict( 11 | obs_space: gym.Space, capacity: int 12 | ) -> Union[np.ndarray, DatasetDict]: 13 | if isinstance(obs_space, gym.spaces.Box): 14 | return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype) 15 | elif isinstance(obs_space, gym.spaces.Dict): 16 | data_dict = {} 17 | for k, v in obs_space.spaces.items(): 18 | data_dict[k] = _init_replay_dict(v, capacity) 19 | return data_dict 20 | else: 21 | raise TypeError() 22 | 23 | 24 | def _insert_recursively( 25 | dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int 26 | ): 27 | if isinstance(dataset_dict, np.ndarray): 28 | dataset_dict[insert_index] = data_dict 29 | elif isinstance(dataset_dict, dict): 30 | for k in dataset_dict.keys(): 31 | _insert_recursively(dataset_dict[k], data_dict[k], insert_index) 32 | else: 33 | raise TypeError() 34 | 35 | 36 | class ReplayBuffer(Dataset): 37 | def __init__( 38 | self, 39 | observation_space: gym.Space, 40 | action_space: gym.Space, 41 | capacity: int, 42 | next_observation_space: Optional[gym.Space] = None, 43 | include_next_actions: Optional[bool] = False, 44 | include_label: Optional[bool] = False, 45 | include_grasp_penalty: Optional[bool] = False, 46 | include_octo_embeddings: Optional[bool] = False, 47 | 
include_mc_returns: Optional[bool] = False, 48 | ): 49 | if next_observation_space is None: 50 | next_observation_space = observation_space 51 | 52 | observation_data = _init_replay_dict(observation_space, capacity) 53 | next_observation_data = _init_replay_dict(next_observation_space, capacity) 54 | dataset_dict = dict( 55 | observations=observation_data, 56 | next_observations=next_observation_data, 57 | actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype), 58 | rewards=np.empty((capacity,), dtype=np.float32), 59 | masks=np.empty((capacity,), dtype=np.float32), 60 | dones=np.empty((capacity,), dtype=bool), 61 | ) 62 | 63 | if include_mc_returns: 64 | dataset_dict['mc_returns'] = np.empty((capacity,), dtype=np.float32) 65 | 66 | if include_octo_embeddings: 67 | dataset_dict['embeddings'] = np.empty((capacity, 384), dtype=np.float32) 68 | dataset_dict['next_embeddings'] = np.empty((capacity, 384), dtype=np.float32) 69 | 70 | if include_next_actions: 71 | dataset_dict['next_actions'] = np.empty((capacity, *action_space.shape), dtype=action_space.dtype) 72 | dataset_dict['next_intvn'] = np.empty((capacity,), dtype=bool) 73 | 74 | if include_label: 75 | dataset_dict['labels'] = np.empty((capacity,), dtype=int) 76 | 77 | if include_grasp_penalty: 78 | dataset_dict['grasp_penalty'] = np.empty((capacity,), dtype=np.float32) 79 | 80 | super().__init__(dataset_dict) 81 | 82 | self._size = 0 83 | self._capacity = capacity 84 | self._insert_index = 0 85 | 86 | def __len__(self) -> int: 87 | return self._size 88 | 89 | def insert(self, data_dict: DatasetDict): 90 | _insert_recursively(self.dataset_dict, data_dict, self._insert_index) 91 | 92 | self._insert_index = (self._insert_index + 1) % self._capacity 93 | self._size = min(self._size + 1, self._capacity) 94 | 95 | def get_iterator(self, queue_size: int = 2, sample_args: dict = {}, device=None): 96 | # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device 97 | # queue_size = 2 should be ok for one GPU. 
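        # enqueue() below draws batches via self.sample(**sample_args) and stages them on
        # `device`, keeping `queue_size` batches prefetched. Usage sketch (mirroring the
        # reward-classifier script; the batch size value is illustrative):
        #   it = buffer.get_iterator(sample_args={"batch_size": 256}, device=sharding.replicate())
        #   batch = next(it)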
98 | queue = collections.deque() 99 | 100 | def enqueue(n): 101 | for _ in range(n): 102 | data = self.sample(**sample_args) 103 | queue.append(jax.device_put(data, device=device)) 104 | 105 | enqueue(queue_size) 106 | while queue: 107 | yield queue.popleft() 108 | enqueue(1) 109 | 110 | def download(self, from_idx: int, to_idx: int): 111 | indices = np.arange(from_idx, to_idx) 112 | data_dict = self.sample(batch_size=len(indices), indx=indices) 113 | return to_idx, data_dict 114 | 115 | def get_download_iterator(self): 116 | last_idx = 0 117 | while True: 118 | if last_idx >= self._size: 119 | raise RuntimeError( 120 | f"last_idx {last_idx} >= self._size {self._size}") 121 | last_idx, batch = self.download(last_idx, self._size) 122 | yield batch 123 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/networks/classifier.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from einops import rearrange 3 | 4 | 5 | class BinaryClassifier(nn.Module): 6 | pretrained_encoder: nn.Module 7 | encoder: nn.Module 8 | network: nn.Module 9 | enable_stacking: bool = False 10 | 11 | @nn.compact 12 | def __call__(self, x, train=False, return_encoded=False, classify_encoded=False): 13 | if return_encoded: 14 | if self.enable_stacking: 15 | # Combine stacking and channels into a single dimension 16 | if len(x.shape) == 4: 17 | x = rearrange(x, "T H W C -> H W (T C)") 18 | if len(x.shape) == 5: 19 | x = rearrange(x, "B T H W C -> B H W (T C)") 20 | x = self.pretrained_encoder(x, train=train) 21 | return x 22 | 23 | x = self.encoder(x, train=train, is_encoded=classify_encoded) 24 | x = self.network(x, train=train) 25 | x = nn.Dense(1)(x).squeeze() 26 | return x 27 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/networks/diffusion_nets.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | 6 | def cosine_beta_schedule(timesteps, s=0.008): 7 | """ 8 | cosine schedule 9 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 10 | """ 11 | steps = timesteps + 1 12 | t = jnp.linspace(0, timesteps, steps) / timesteps 13 | alphas_cumprod = jnp.cos((t + s) / (1 + s) * jnp.pi * 0.5) ** 2 14 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 15 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 16 | return jnp.clip(betas, 0, 0.999) 17 | 18 | 19 | def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2): 20 | betas = jnp.linspace(beta_start, beta_end, timesteps) 21 | return betas 22 | 23 | 24 | def vp_beta_schedule(timesteps): 25 | t = jnp.arange(1, timesteps + 1) 26 | T = timesteps 27 | b_max = 10.0 28 | b_min = 0.1 29 | alpha = jnp.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T**2) 30 | betas = 1 - alpha 31 | return betas 32 | 33 | 34 | class ScoreActor(nn.Module): 35 | encoder: nn.Module 36 | time_preprocess: nn.Module 37 | cond_encoder: nn.Module 38 | reverse_network: nn.Module 39 | 40 | def __call__(self, observations, actions, time, train=False): 41 | t_ff = self.time_preprocess(time) 42 | cond_enc = self.cond_encoder(t_ff, train=train) 43 | obs_enc = self.encoder(observations, train=train) 44 | if obs_enc.shape[1] == 1: 45 | obs_enc = obs_enc[:, 0] 46 | reverse_input = jnp.concatenate([cond_enc, obs_enc, actions], axis=-1) 47 | eps_pred = self.reverse_network(reverse_input, train=train) 
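        # the reverse network predicts eps (the noise estimate) conditioned on the
        # observation encoding, the noised actions, and the timestep embedding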
48 | 49 | # un-flatten pred sequence 50 | return eps_pred.reshape(actions.shape) 51 | 52 | 53 | class FourierFeatures(nn.Module): 54 | output_size: int 55 | learnable: bool = True 56 | 57 | @nn.compact 58 | def __call__(self, x: jnp.ndarray): 59 | if self.learnable: 60 | w = self.param( 61 | "kernel", 62 | nn.initializers.normal(0.2), 63 | (self.output_size // 2, x.shape[-1]), 64 | jnp.float32, 65 | ) 66 | f = 2 * jnp.pi * x @ w.T 67 | else: 68 | half_dim = self.output_size // 2 69 | f = jnp.log(10000) / (half_dim - 1) 70 | f = jnp.exp(jnp.arange(half_dim) * -f) 71 | f = x * f 72 | return jnp.concatenate([jnp.cos(f), jnp.sin(f)], axis=-1) 73 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/networks/lagrange.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Optional, Sequence 3 | 4 | import chex 5 | import flax.linen as nn 6 | import jax.numpy as jnp 7 | 8 | 9 | class LagrangeMultiplier(nn.Module): 10 | init_value: float = 1.0 11 | constraint_shape: Sequence[int] = () 12 | constraint_type: str = "eq" # One of ("eq", "leq", "geq") 13 | parameterization: Optional[ 14 | str 15 | ] = None # One of ("softplus", "exp"), or None for equality constraints 16 | 17 | @nn.compact 18 | def __call__( 19 | self, *, lhs: Optional[jnp.ndarray] = None, rhs: Optional[jnp.ndarray] = None 20 | ) -> jnp.ndarray: 21 | init_value = self.init_value 22 | 23 | if self.constraint_type != "eq": 24 | assert ( 25 | init_value > 0 26 | ), "Inequality constraints must have non-negative initial multiplier values" 27 | 28 | if self.parameterization == "softplus": 29 | init_value = jnp.log(jnp.exp(init_value) - 1) 30 | elif self.parameterization == "exp": 31 | init_value = jnp.log(init_value) 32 | elif self.parameterization == "none": 33 | pass 34 | else: 35 | raise ValueError( 36 | f"Invalid multiplier parameterization {self.parameterization}" 37 | ) 38 | else: 39 | assert ( 40 | self.parameterization is None 41 | ), "Equality constraints must have no parameterization" 42 | 43 | multiplier = self.param( 44 | "lagrange", 45 | lambda _, shape: jnp.full(shape, init_value), 46 | self.constraint_shape, 47 | ) 48 | 49 | if self.constraint_type != "eq": 50 | if self.parameterization == "softplus": 51 | multiplier = nn.softplus(multiplier) 52 | elif self.parameterization == "exp": 53 | multiplier = jnp.exp(multiplier) 54 | elif self.parameterization == "none": 55 | pass 56 | else: 57 | raise ValueError( 58 | f"Invalid multiplier parameterization {self.parameterization}" 59 | ) 60 | 61 | # Return the raw multiplier 62 | if lhs is None: 63 | return multiplier 64 | 65 | # Use the multiplier to compute the Lagrange penalty 66 | if rhs is None: 67 | rhs = jnp.zeros_like(lhs) 68 | 69 | diff = lhs - rhs 70 | 71 | chex.assert_equal_shape([diff, multiplier]) 72 | 73 | if self.constraint_type == "eq": 74 | return multiplier * diff 75 | elif self.constraint_type == "geq": 76 | return multiplier * diff 77 | elif self.constraint_type == "leq": 78 | return -multiplier * diff 79 | 80 | 81 | GeqLagrangeMultiplier = partial( 82 | LagrangeMultiplier, constraint_type="geq", parameterization="softplus" 83 | ) 84 | 85 | LeqLagrangeMultiplier = partial( 86 | LagrangeMultiplier, constraint_type="leq", parameterization="softplus" 87 | ) 88 | 89 | BetterLeqLagrangeMultiplier = partial( 90 | LagrangeMultiplier, constraint_type="leq", parameterization="none" 91 | ) 92 | 
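# Illustrative usage sketch (not part of the original module; the variable names and
# values below are assumptions):
#   temp = GeqLagrangeMultiplier(init_value=1.0)
#   variables = temp.init(jax.random.PRNGKey(0))   # creates the raw "lagrange" parameter
#   penalty = temp.apply(variables, lhs=entropy, rhs=target_entropy)
#   # with constraint_type="geq" and softplus parameterization this evaluates to
#   # softplus(raw_multiplier) * (entropy - target_entropy), which is added to the loss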
-------------------------------------------------------------------------------- /serl_launcher/serl_launcher/networks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from serl_launcher.common.common import default_init 8 | 9 | 10 | class SinusoidalPosEmb(nn.Module): 11 | dim: int 12 | 13 | @nn.compact 14 | def __call__(self, time): 15 | half_dim = self.dim // 2 16 | embeddings = jnp.log(10000) / (half_dim - 1) 17 | embeddings = jnp.exp(jnp.arange(half_dim) * -embeddings) 18 | embeddings = time[:, None] * embeddings 19 | return jnp.concatenate([jnp.sin(embeddings), jnp.cos(embeddings)], axis=-1) 20 | 21 | class timeMLP(nn.Module): 22 | t_dim: Sequence[int] 23 | activations: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.swish 24 | 25 | @nn.compact 26 | def __call__(self, t: jnp.ndarray) -> jnp.ndarray: 27 | activations = self.activations 28 | if isinstance(activations, str): 29 | activations = getattr(nn, activations) 30 | 31 | t = SinusoidalPosEmb(self.t_dim)(t) 32 | t = nn.Dense(self.t_dim * 2, kernel_init=default_init())(t) 33 | t = activations(t) 34 | t = nn.Dense(self.t_dim, kernel_init=default_init())(t) 35 | 36 | return t 37 | 38 | class MLP(nn.Module): 39 | hidden_dims: Sequence[int] 40 | activations: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.swish 41 | activate_final: bool = False 42 | use_layer_norm: bool = False 43 | dropout_rate: Optional[float] = None 44 | 45 | @nn.compact 46 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 47 | activations = self.activations 48 | if isinstance(activations, str): 49 | activations = getattr(nn, activations) 50 | 51 | for i, size in enumerate(self.hidden_dims): 52 | x = nn.Dense(size, kernel_init=default_init())(x) 53 | 54 | if i + 1 < len(self.hidden_dims) or self.activate_final: 55 | if self.dropout_rate is not None and self.dropout_rate > 0: 56 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 57 | if self.use_layer_norm: 58 | x = nn.LayerNorm()(x) 59 | x = activations(x) 60 | return x 61 | 62 | 63 | class MLPResNetBlock(nn.Module): 64 | features: int 65 | act: Callable 66 | dropout_rate: float = None 67 | use_layer_norm: bool = False 68 | 69 | @nn.compact 70 | def __call__(self, x, train: bool = False): 71 | residual = x 72 | if self.dropout_rate is not None and self.dropout_rate > 0: 73 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 74 | if self.use_layer_norm: 75 | x = nn.LayerNorm()(x) 76 | x = nn.Dense(self.features * 4)(x) 77 | x = self.act(x) 78 | x = nn.Dense(self.features)(x) 79 | 80 | if residual.shape != x.shape: 81 | residual = nn.Dense(self.features)(residual) 82 | 83 | return residual + x 84 | 85 | 86 | class MLPResNet(nn.Module): 87 | num_blocks: int 88 | out_dim: int 89 | dropout_rate: float = None 90 | use_layer_norm: bool = False 91 | hidden_dim: int = 256 92 | activations: Callable = nn.swish 93 | 94 | @nn.compact 95 | def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray: 96 | x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x) 97 | for _ in range(self.num_blocks): 98 | x = MLPResNetBlock( 99 | self.hidden_dim, 100 | act=self.activations, 101 | use_layer_norm=self.use_layer_norm, 102 | dropout_rate=self.dropout_rate, 103 | )(x, train=train) 104 | 105 | x = self.activations(x) 106 | x = nn.Dense(self.out_dim, kernel_init=default_init())(x) 107 | return x 
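# MLPResNet applies an input projection to `hidden_dim`, then `num_blocks` residual MLP
# blocks (each widening to 4 * features internally before projecting back), a final
# activation, and a Dense projection to `out_dim`.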
108 | 109 | 110 | class Scalar(nn.Module): 111 | init_value: float 112 | 113 | def setup(self): 114 | self.value = self.param("value", lambda x: self.init_value) 115 | 116 | def __call__(self): 117 | return self.value 118 | 119 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/networks/reward_classifier.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import jax 3 | from jax import numpy as jnp 4 | import flax.linen as nn 5 | from flax.training.train_state import TrainState 6 | from flax.training import checkpoints 7 | import optax 8 | from typing import Callable, Dict, List 9 | 10 | 11 | from serl_launcher.vision.resnet_v1 import resnetv1_configs, PreTrainedResNetEncoder 12 | from serl_launcher.common.encoding import EncodingWrapper 13 | 14 | 15 | class BinaryClassifier(nn.Module): 16 | encoder_def: nn.Module 17 | hidden_dim: int = 256 18 | 19 | @nn.compact 20 | def __call__(self, x, train=False): 21 | x = self.encoder_def(x, train=train) 22 | x = nn.Dense(self.hidden_dim)(x) 23 | x = nn.Dropout(0.1)(x, deterministic=not train) 24 | x = nn.LayerNorm()(x) 25 | x = nn.relu(x) 26 | x = nn.Dense(1)(x) 27 | return x 28 | 29 | class NWayClassifier(nn.Module): 30 | encoder_def: nn.Module 31 | hidden_dim: int = 256 32 | n_way: int = 3 33 | 34 | @nn.compact 35 | def __call__(self, x, train=False): 36 | x = self.encoder_def(x, train=train) 37 | x = nn.Dense(self.hidden_dim)(x) 38 | x = nn.Dropout(0.1)(x, deterministic=not train) 39 | x = nn.LayerNorm()(x) 40 | x = nn.relu(x) 41 | x = nn.Dense(self.n_way)(x) 42 | return x 43 | 44 | 45 | def create_classifier( 46 | key: jnp.ndarray, 47 | sample: Dict, 48 | image_keys: List[str], 49 | pretrained_encoder_path: str = "../resnet10_params.pkl", 50 | n_way: int = 2, 51 | ): 52 | pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"]( 53 | pre_pooling=True, 54 | name="pretrained_encoder", 55 | ) 56 | encoders = { 57 | image_key: PreTrainedResNetEncoder( 58 | pooling_method="spatial_learned_embeddings", 59 | num_spatial_blocks=8, 60 | bottleneck_dim=256, 61 | pretrained_encoder=pretrained_encoder, 62 | name=f"encoder_{image_key}", 63 | ) 64 | for image_key in image_keys 65 | } 66 | encoder_def = EncodingWrapper( 67 | encoder=encoders, 68 | use_proprio=False, 69 | enable_stacking=True, 70 | image_keys=image_keys, 71 | ) 72 | if n_way == 2: 73 | classifier_def = BinaryClassifier(encoder_def=encoder_def) 74 | else: 75 | classifier_def = NWayClassifier(encoder_def=encoder_def, n_way=n_way) 76 | params = classifier_def.init(key, sample)["params"] 77 | classifier = TrainState.create( 78 | apply_fn=classifier_def.apply, 79 | params=params, 80 | tx=optax.adam(learning_rate=1e-4), 81 | ) 82 | 83 | with open(pretrained_encoder_path, "rb") as f: 84 | encoder_params = pkl.load(f) 85 | param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) 86 | print( 87 | f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" 88 | ) 89 | new_params = classifier.params 90 | for image_key in image_keys: 91 | if "pretrained_encoder" in new_params["encoder_def"][f"encoder_{image_key}"]: 92 | for k in new_params["encoder_def"][f"encoder_{image_key}"][ 93 | "pretrained_encoder" 94 | ]: 95 | if k in encoder_params: 96 | new_params["encoder_def"][f"encoder_{image_key}"][ 97 | "pretrained_encoder" 98 | ][k] = encoder_params[k] 99 | print(f"replaced {k} in encoder_{image_key}") 100 | 101 | classifier = 
classifier.replace(params=new_params) 102 | return classifier 103 | 104 | 105 | def load_classifier_func( 106 | key: jnp.ndarray, 107 | sample: Dict, 108 | image_keys: List[str], 109 | checkpoint_path: str, 110 | n_way: int = 2, 111 | ) -> Callable[[Dict], jnp.ndarray]: 112 | """ 113 | Return: a function that takes in an observation 114 | and returns the logits of the classifier. 115 | """ 116 | classifier = create_classifier(key, sample, image_keys, n_way=n_way) 117 | classifier = checkpoints.restore_checkpoint( 118 | checkpoint_path, 119 | target=classifier, 120 | ) 121 | func = lambda obs: classifier.apply_fn( 122 | {"params": classifier.params}, obs, train=False 123 | ) 124 | func = jax.jit(func) 125 | return func 126 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_launcher/serl_launcher/utils/__init__.py -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/jax_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax.core.frozen_dict import FrozenDict 4 | 5 | 6 | @jax.jit 7 | def batch_to_jax(batch): 8 | return jax.tree_util.tree_map(jax.device_put, batch) 9 | 10 | 11 | class JaxRNG(object): 12 | """A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside 13 | pure function. 14 | """ 15 | 16 | @classmethod 17 | def from_seed(cls, seed): 18 | return cls(jax.random.PRNGKey(seed)) 19 | 20 | def __init__(self, rng): 21 | self.rng = rng 22 | 23 | def __call__(self, keys=None): 24 | if keys is None: 25 | self.rng, split_rng = jax.random.split(self.rng) 26 | return split_rng 27 | elif isinstance(keys, int): 28 | split_rngs = jax.random.split(self.rng, num=keys + 1) 29 | self.rng = split_rngs[0] 30 | return tuple(split_rngs[1:]) 31 | else: 32 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 33 | self.rng = split_rngs[0] 34 | return {key: val for key, val in zip(keys, split_rngs[1:])} 35 | 36 | 37 | def wrap_function_with_rng(rng): 38 | """To be used as decorator, automatically bookkeep a RNG for the wrapped function.""" 39 | 40 | def wrap_function(function): 41 | def wrapped(*args, **kwargs): 42 | nonlocal rng 43 | rng, split_rng = jax.random.split(rng) 44 | return function(split_rng, *args, **kwargs) 45 | 46 | return wrapped 47 | 48 | return wrap_function 49 | 50 | 51 | def init_rng(seed): 52 | global jax_utils_rng 53 | jax_utils_rng = JaxRNG.from_seed(seed) 54 | 55 | 56 | def next_rng(*args, **kwargs): 57 | global jax_utils_rng 58 | return jax_utils_rng(*args, **kwargs) 59 | 60 | 61 | def extend_and_repeat(tensor, axis, repeat): 62 | return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis) 63 | 64 | 65 | def mse_loss(val, target): 66 | return jnp.mean(jnp.square(val - target)) 67 | 68 | def append_zero(x): 69 | return jnp.concatenate([x, jnp.zeros((1,), dtype=x.dtype)]) 70 | 71 | def append_dims(x, target_dims): 72 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 73 | dims_to_append = target_dims - x.ndim 74 | if dims_to_append < 0: 75 | raise ValueError( 76 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 77 | ) 78 | return x[(...,) + (None,) * dims_to_append] 79 | 80 | def 
mean_flat(tensor): 81 | """ 82 | Take the mean over all non-batch dimensions. 83 | """ 84 | return jnp.mean(tensor, axis=tuple(range(1, len(tensor.shape)))) 85 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/launcher.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python3 2 | 3 | import jax 4 | from jax import nn 5 | import jax.numpy as jnp 6 | 7 | from agentlace.trainer import TrainerConfig 8 | 9 | from serl_launcher.common.typing import Batch, PRNGKey 10 | from serl_launcher.common.wandb import WandBLogger 11 | from serl_launcher.agents.continuous.bc import BCAgent 12 | from serl_launcher.agents.continuous.sac import SACAgent 13 | from serl_launcher.agents.continuous.sac_single import SACAgentSingleArm 14 | from serl_launcher.serl_launcher.agents.continuous.conrft_single_octo_cp import ConrftCPOctoAgentSingleArm 15 | from serl_launcher.vision.data_augmentations import batched_random_crop 16 | 17 | ############################################################################## 18 | 19 | 20 | def make_bc_agent( 21 | seed, 22 | sample_obs, 23 | sample_action, 24 | image_keys=("image",), 25 | encoder_type="resnet-pretrained" 26 | ): 27 | return BCAgent.create( 28 | jax.random.PRNGKey(seed), 29 | sample_obs, 30 | sample_action, 31 | network_kwargs={ 32 | "activations": nn.tanh, 33 | "use_layer_norm": True, 34 | "hidden_dims": [512, 512, 512], 35 | "dropout_rate": 0.25, 36 | }, 37 | policy_kwargs={ 38 | "tanh_squash_distribution": False, 39 | "std_parameterization": "exp", 40 | "std_min": 1e-5, 41 | "std_max": 5, 42 | }, 43 | use_proprio=True, 44 | encoder_type=encoder_type, 45 | image_keys=image_keys, 46 | augmentation_function=make_batch_augmentation_func(image_keys), 47 | ) 48 | 49 | 50 | def make_sac_pixel_agent( 51 | seed, 52 | sample_obs, 53 | sample_action, 54 | image_keys=("image",), 55 | encoder_type="resnet-pretrained", 56 | reward_bias=0.0, 57 | target_entropy=None, 58 | discount=0.97, 59 | fix_gripper: bool = False, 60 | ): 61 | agent = SACAgent.create_pixels( 62 | jax.random.PRNGKey(seed), 63 | sample_obs, 64 | sample_action, 65 | encoder_type=encoder_type, 66 | use_proprio=True, 67 | image_keys=image_keys, 68 | policy_kwargs={ 69 | "tanh_squash_distribution": True, 70 | "std_parameterization": "exp", 71 | "std_min": 1e-5, 72 | "std_max": 5, 73 | }, 74 | critic_network_kwargs={ 75 | "activations": nn.tanh, 76 | "use_layer_norm": True, 77 | "hidden_dims": [256, 256], 78 | }, 79 | policy_network_kwargs={ 80 | "activations": nn.tanh, 81 | "use_layer_norm": True, 82 | "hidden_dims": [256, 256], 83 | }, 84 | temperature_init=1e-2, 85 | discount=discount, 86 | fix_gripper=fix_gripper, 87 | backup_entropy=False, 88 | critic_ensemble_size=2, 89 | critic_subsample_size=None, 90 | reward_bias=reward_bias, 91 | target_entropy=target_entropy, 92 | augmentation_function=make_batch_augmentation_func(image_keys), 93 | ) 94 | return agent 95 | 96 | 97 | def make_sac_pixel_agent_single_arm( 98 | seed, 99 | sample_obs, 100 | sample_action, 101 | image_keys=("image",), 102 | encoder_type="resnet-pretrained", 103 | reward_bias=0.0, 104 | target_entropy=None, 105 | discount=0.97, 106 | ): 107 | agent = SACAgentSingleArm.create_pixels( 108 | jax.random.PRNGKey(seed), 109 | sample_obs, 110 | sample_action, 111 | encoder_type=encoder_type, 112 | use_proprio=True, 113 | image_keys=image_keys, 114 | policy_kwargs={ 115 | "tanh_squash_distribution": False, 116 | 
"std_parameterization": "exp", 117 | "std_min": 1e-5, 118 | "std_max": 5, 119 | }, 120 | critic_network_kwargs={ 121 | "activations": nn.tanh, 122 | "use_layer_norm": True, 123 | "hidden_dims": [256, 256], 124 | }, 125 | policy_network_kwargs={ 126 | "activations": nn.tanh, 127 | "use_layer_norm": True, 128 | "hidden_dims": [256, 256], 129 | }, 130 | temperature_init=1e-2, 131 | discount=discount, 132 | backup_entropy=False, 133 | critic_ensemble_size=2, 134 | critic_subsample_size=None, 135 | reward_bias=reward_bias, 136 | target_entropy=target_entropy, 137 | augmentation_function=make_batch_augmentation_func(image_keys), 138 | ) 139 | return agent 140 | 141 | 142 | def make_conrft_octo_cp_pixel_agent_single_arm( 143 | seed, 144 | sample_obs, 145 | sample_action, 146 | sample_tasks, 147 | octo_model, 148 | encoder_type="resnet-pretrained", 149 | image_keys=("image",), 150 | reward_bias=0.0, 151 | target_entropy=None, 152 | discount=0.97, 153 | num_scales=40, 154 | sigma_data: float = 0.5, 155 | sigma_min: float = 0.002, 156 | sigma_max: float = 80.0, 157 | rho: float = 7.0, 158 | fix_gripper: bool = False, 159 | q_weight: float = 0.1, 160 | bc_weight: float = 1.0, 161 | ): 162 | agent = ConrftCPOctoAgentSingleArm.create_pixels( 163 | jax.random.PRNGKey(seed), 164 | sample_obs, 165 | sample_action, 166 | sample_tasks, 167 | encoder_type=encoder_type, 168 | use_proprio=True, 169 | octo_model=octo_model, 170 | image_keys=image_keys, 171 | fix_gripper=fix_gripper, 172 | policy_kwargs={ 173 | "sigma_data": sigma_data, 174 | "sigma_max": sigma_max, 175 | "sigma_min": sigma_min, 176 | "rho": rho, 177 | "steps": num_scales, 178 | "clip_denoised": True, 179 | }, 180 | critic_network_kwargs={ 181 | "activations": nn.tanh, 182 | "use_layer_norm": True, 183 | "hidden_dims": [256, 256], 184 | }, 185 | policy_network_kwargs={ 186 | "activations": nn.tanh, 187 | "use_layer_norm": True, 188 | "hidden_dims": [256, 256], 189 | }, 190 | policy_t_network_kwargs={ 191 | "t_dim": 16, 192 | "activations": nn.tanh, 193 | }, 194 | num_scales=num_scales, 195 | sigma_min=sigma_min, 196 | sigma_max=sigma_max, 197 | sigma_data=sigma_data, 198 | rho=rho, 199 | discount=discount, 200 | reward_bias=reward_bias, 201 | target_entropy=target_entropy, 202 | critic_ensemble_size=2, 203 | critic_subsample_size=None, 204 | augmentation_function=make_batch_augmentation_func(image_keys), 205 | q_weight=q_weight, 206 | bc_weight=bc_weight, 207 | ) 208 | return agent 209 | 210 | 211 | def linear_schedule(step): 212 | init_value = 10.0 213 | end_value = 50.0 214 | decay_steps = 15_000 215 | 216 | linear_step = jnp.minimum(step, decay_steps) 217 | decayed_value = init_value + \ 218 | (end_value - init_value) * (linear_step / decay_steps) 219 | return decayed_value 220 | 221 | 222 | def make_batch_augmentation_func(image_keys) -> callable: 223 | 224 | def data_augmentation_fn(rng, observations): 225 | for pixel_key in image_keys: 226 | observations = observations.copy( 227 | add_or_replace={ 228 | pixel_key: batched_random_crop( 229 | observations[pixel_key], rng, padding=4, num_batch_dims=2 230 | ) 231 | } 232 | ) 233 | return observations 234 | 235 | def augment_batch(batch: Batch, rng: PRNGKey) -> Batch: 236 | rng, obs_rng, next_obs_rng = jax.random.split(rng, 3) 237 | obs = data_augmentation_fn(obs_rng, batch["observations"]) 238 | next_obs = data_augmentation_fn( 239 | next_obs_rng, batch["next_observations"]) 240 | batch = batch.copy( 241 | add_or_replace={ 242 | "observations": obs, 243 | "next_observations": next_obs, 244 | } 
245 | ) 246 | return batch 247 | 248 | return augment_batch 249 | 250 | 251 | def make_trainer_config(port_number: int = 3333, broadcast_port: int = 3334): 252 | return TrainerConfig( 253 | port_number=port_number, 254 | broadcast_port=broadcast_port, 255 | request_types=["send-stats"], 256 | ) 257 | 258 | 259 | def make_wandb_logger( 260 | project: str = "conrft", 261 | description: str = "serl_launcher", 262 | debug: bool = False, 263 | ): 264 | wandb_config = WandBLogger.get_default_config() 265 | wandb_config.update( 266 | { 267 | "project": project, 268 | "exp_descriptor": description, 269 | "tag": description, 270 | } 271 | ) 272 | wandb_logger = WandBLogger( 273 | wandb_config=wandb_config, 274 | variant={}, 275 | debug=debug, 276 | ) 277 | return wandb_logger 278 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from ml_collections import ConfigDict 3 | from ml_collections.config_flags import config_flags 4 | from ml_collections.config_dict import config_dict 5 | import random 6 | import pprint 7 | import time 8 | import uuid 9 | import tempfile 10 | import os 11 | from copy import copy 12 | from socket import gethostname 13 | import cloudpickle as pickle 14 | import time 15 | from collections import deque 16 | from typing import Optional 17 | 18 | import numpy as np 19 | 20 | import gymnasium as gym 21 | 22 | class WandBLogger(object): 23 | 24 | @staticmethod 25 | def get_default_config(updates=None): 26 | config = ConfigDict() 27 | config.online = False 28 | config.prefix = 'JaxCQL' 29 | config.project = '' 30 | config.output_dir = '/tmp/JaxCQL' 31 | config.random_delay = 0.0 32 | config.experiment_id = config_dict.placeholder(str) 33 | config.anonymous = config_dict.placeholder(str) 34 | config.notes = config_dict.placeholder(str) 35 | config.entity = config_dict.placeholder(str) 36 | 37 | if updates is not None: 38 | config.update(ConfigDict(updates).copy_and_resolve_references()) 39 | return config 40 | 41 | def __init__(self, config, variant): 42 | self.config = self.get_default_config(config) 43 | 44 | if self.config.experiment_id is None: 45 | self.config.experiment_id = uuid.uuid4().hex 46 | 47 | if self.config.prefix != '': 48 | self.config.project = '{}--{}'.format(self.config.prefix, self.config.project) 49 | 50 | if self.config.output_dir == '': 51 | self.config.output_dir = tempfile.mkdtemp() 52 | else: 53 | self.config.output_dir = os.path.join(self.config.output_dir, self.config.experiment_id) 54 | os.makedirs(self.config.output_dir, exist_ok=True) 55 | 56 | self._variant = copy(variant) 57 | 58 | if 'hostname' not in self._variant: 59 | self._variant['hostname'] = gethostname() 60 | 61 | if self.config.random_delay > 0: 62 | time.sleep(np.random.uniform(0, self.config.random_delay)) 63 | 64 | self.run = wandb.init( 65 | reinit=True, 66 | config=self._variant, 67 | project=self.config.project, 68 | dir=self.config.output_dir, 69 | entity=config.entity, 70 | id=self.config.experiment_id, 71 | anonymous=self.config.anonymous, 72 | notes=self.config.notes, 73 | settings=wandb.Settings( 74 | start_method="thread", 75 | _disable_stats=True, 76 | ), 77 | mode='online' if self.config.online else 'offline', 78 | ) 79 | 80 | def log(self, *args, **kwargs): 81 | self.run.log(*args, **kwargs) 82 | 83 | def save_pickle(self, obj, filename): 84 | with open(os.path.join(self.config.output_dir, 
filename), 'wb') as fout: 85 | pickle.dump(obj, fout) 86 | 87 | @property 88 | def experiment_id(self): 89 | return self.config.experiment_id 90 | 91 | @property 92 | def variant(self): 93 | return self.config.variant 94 | 95 | @property 96 | def output_dir(self): 97 | return self.config.output_dir 98 | 99 | 100 | 101 | 102 | """Wrapper that tracks the cumulative rewards and episode lengths.""" 103 | class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs): 104 | """This wrapper will keep track of cumulative rewards and episode lengths. 105 | 106 | At the end of an episode, the statistics of the episode will be added to ``info`` 107 | using the key ``episode``. If using a vectorized environment also the key 108 | ``_episode`` is used which indicates whether the env at the respective index has 109 | the episode statistics. 110 | 111 | After the completion of an episode, ``info`` will look like this:: 112 | 113 | >>> info = { 114 | ... "episode": { 115 | ... "r": "", 116 | ... "l": "", 117 | ... "t": "" 118 | ... }, 119 | ... } 120 | 121 | For a vectorized environments the output will be in the form of:: 122 | 123 | >>> infos = { 124 | ... "final_observation": "", 125 | ... "_final_observation": "", 126 | ... "final_info": "", 127 | ... "_final_info": "", 128 | ... "episode": { 129 | ... "r": "", 130 | ... "l": "", 131 | ... "t": "" 132 | ... }, 133 | ... "_episode": "" 134 | ... } 135 | 136 | Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via 137 | :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively. 138 | 139 | Attributes: 140 | return_queue: The cumulative rewards of the last ``deque_size``-many episodes 141 | length_queue: The lengths of the last ``deque_size``-many episodes 142 | """ 143 | 144 | def __init__(self, env: gym.Env, deque_size: int = 100): 145 | """This wrapper will keep track of cumulative rewards and episode lengths. 
146 | 147 | Args: 148 | env (Env): The environment to apply the wrapper 149 | deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` 150 | """ 151 | gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size) 152 | super().__init__(env) 153 | 154 | try: 155 | self.num_envs = self.get_wrapper_attr("num_envs") 156 | self.is_vector_env = self.get_wrapper_attr("is_vector_env") 157 | except AttributeError: 158 | self.num_envs = 1 159 | self.is_vector_env = False 160 | 161 | self.episode_count = 0 162 | self.episode_start_times: np.ndarray = None 163 | self.episode_returns: Optional[np.ndarray] = None 164 | self.episode_lengths: Optional[np.ndarray] = None 165 | self.return_queue = deque(maxlen=deque_size) 166 | self.length_queue = deque(maxlen=deque_size) 167 | 168 | def reset(self, **kwargs): 169 | """Resets the environment using kwargs and resets the episode returns and lengths.""" 170 | obs, info = self.env.reset(**kwargs) 171 | self.episode_start_times = np.full( 172 | self.num_envs, time.perf_counter(), dtype=np.float32 173 | ) 174 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 175 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 176 | return obs, info 177 | 178 | def step(self, action): 179 | """Steps through the environment, recording the episode statistics.""" 180 | ( 181 | observations, 182 | rewards, 183 | terminations, 184 | truncations, 185 | infos, 186 | ) = self.env.step(action) 187 | assert isinstance( 188 | infos, dict 189 | ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." 190 | self.episode_returns += rewards 191 | self.episode_lengths += 1 192 | dones = np.logical_or(terminations, truncations) 193 | num_dones = np.sum(dones) 194 | if num_dones: 195 | if "episode" in infos or "_episode" in infos: 196 | raise ValueError( 197 | "Attempted to add episode stats when they already exist" 198 | ) 199 | else: 200 | infos["episode"] = { 201 | "r": np.where(dones, self.episode_returns, 0.0), 202 | "l": np.where(dones, self.episode_lengths, 0), 203 | "t": np.where( 204 | dones, 205 | np.round(time.perf_counter() - self.episode_start_times, 6), 206 | 0.0, 207 | ), 208 | } 209 | if self.is_vector_env: 210 | infos["_episode"] = np.where(dones, True, False) 211 | self.return_queue.extend(self.episode_returns[dones]) 212 | self.length_queue.extend(self.episode_lengths[dones]) 213 | self.episode_count += num_dones 214 | self.episode_lengths[dones] = 0 215 | self.episode_returns[dones] = 0 216 | self.episode_start_times[dones] = time.perf_counter() 217 | return ( 218 | observations, 219 | rewards, 220 | terminations, 221 | truncations, 222 | infos, 223 | ) 224 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/timer_utils.py: -------------------------------------------------------------------------------- 1 | """Timer utility.""" 2 | 3 | import time 4 | from collections import defaultdict 5 | 6 | 7 | class _TimerContextManager: 8 | def __init__(self, timer: "Timer", key: str): 9 | self.timer = timer 10 | self.key = key 11 | 12 | def __enter__(self): 13 | self.timer.tick(self.key) 14 | 15 | def __exit__(self, exc_type, exc_value, exc_traceback): 16 | self.timer.tock(self.key) 17 | 18 | 19 | class Timer: 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.counts = defaultdict(int) 25 | self.times = defaultdict(float) 26 | self.start_times = {} 27 | 28 
| def tick(self, key): 29 | if key in self.start_times: 30 | raise ValueError(f"Timer is already ticking for key: {key}") 31 | self.start_times[key] = time.time() 32 | 33 | def tock(self, key): 34 | if key not in self.start_times: 35 | raise ValueError(f"Timer is not ticking for key: {key}") 36 | self.counts[key] += 1 37 | self.times[key] += time.time() - self.start_times[key] 38 | del self.start_times[key] 39 | 40 | def context(self, key): 41 | """ 42 | Use this like: 43 | 44 | with timer.context("key"): 45 | # do stuff 46 | 47 | Then timer.tock("key") will be called automatically. 48 | """ 49 | return _TimerContextManager(self, key) 50 | 51 | def get_average_times(self, reset=True): 52 | ret = {key: self.times[key] / self.counts[key] for key in self.counts} 53 | if reset: 54 | self.reset() 55 | return ret 56 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def ema(series, alpha=0.5): 5 | """ 6 | Exponential moving average 7 | :param series: the input series 8 | :param alpha: the smoothing factor 9 | :return: the smoothed series 10 | """ 11 | smoothed = np.zeros_like(series, dtype=float) 12 | smoothed[0] = series[0] 13 | for i in range(1, len(series)): 14 | smoothed[i] = alpha * series[i] + (1 - alpha) * smoothed[i - 1] 15 | return smoothed 16 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import requests 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | 8 | import imageio 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import tensorflow as tf 13 | import wandb 14 | from flax.core import frozen_dict 15 | from flax.training import checkpoints 16 | 17 | 18 | def ask_for_frame(images_dict): 19 | # Create a new figure 20 | fig, axes = plt.subplots(5, 5, figsize=(15, 20)) 21 | 22 | # Flatten the axes array for easier indexing 23 | axes = axes.flatten() 24 | for i, (idx, img) in enumerate(images_dict.items()): 25 | # Display the image 26 | axes[i].imshow(img) 27 | 28 | # Remove axis ticks 29 | axes[i].set_xticks([]) 30 | axes[i].set_yticks([]) 31 | 32 | # Overlay the index number 33 | axes[i].text(10, 30, str(idx), color='white', fontsize=12, 34 | bbox=dict(facecolor='black', alpha=0.7)) 35 | 36 | plt.tight_layout() 37 | plt.show(block=False) 38 | 39 | while True: 40 | try: 41 | first_success = int(input("First success frame number: ")) 42 | assert first_success in images_dict.keys() 43 | break 44 | except: 45 | continue 46 | 47 | plt.close(fig) 48 | 49 | return first_success 50 | 51 | 52 | def concat_batches(offline_batch, online_batch, axis=1): 53 | batch = defaultdict(list) 54 | 55 | if not isinstance(offline_batch, dict): 56 | offline_batch = offline_batch.unfreeze() 57 | 58 | if not isinstance(online_batch, dict): 59 | online_batch = online_batch.unfreeze() 60 | 61 | for k, v in offline_batch.items(): 62 | if type(v) is dict: 63 | batch[k] = concat_batches( 64 | offline_batch[k], online_batch[k], axis=axis) 65 | else: 66 | batch[k] = jnp.concatenate( 67 | (offline_batch[k], online_batch[k]), axis=axis) 68 | 69 | return frozen_dict.freeze(batch) 70 | 71 | 72 | def load_recorded_video( 73 | video_path: str, 74 | 
): 75 | with tf.io.gfile.GFile(video_path, "rb") as f: 76 | video = np.array(imageio.mimread(f, "MP4")).transpose((0, 3, 1, 2)) 77 | assert video.shape[1] == 3, "Numpy array should be (T, C, H, W)" 78 | 79 | return wandb.Video(video, fps=20) 80 | 81 | 82 | def _unpack(batch): 83 | """ 84 | Helps to minimize CPU to GPU transfer. 85 | Assuming that if next_observation is missing, it's combined with observation: 86 | 87 | :param batch: a batch of data from the replay buffer, a dataset dict 88 | :return: a batch of unpacked data, a dataset dict 89 | """ 90 | 91 | for pixel_key in batch["observations"].keys(): 92 | if pixel_key not in batch["next_observations"]: 93 | obs_pixels = batch["observations"][pixel_key][:, :-1, ...] 94 | next_obs_pixels = batch["observations"][pixel_key][:, 1:, ...] 95 | 96 | obs = batch["observations"].copy(add_or_replace={pixel_key: obs_pixels}) 97 | next_obs = batch["next_observations"].copy(add_or_replace={pixel_key: next_obs_pixels}) 98 | batch = batch.copy(add_or_replace={"observations": obs, "next_observations": next_obs}) 99 | 100 | return batch 101 | 102 | 103 | def load_resnet10_params(agent, image_keys=("image",), public=True): 104 | """ 105 | Load pretrained resnet10 params from github release to an agent. 106 | :return: agent with pretrained resnet10 params 107 | """ 108 | file_name = "resnet10_params.pkl" 109 | if not public: # if github repo is not public, load from local file 110 | with open(file_name, "rb") as f: 111 | encoder_params = pkl.load(f) 112 | else: # when repo is released, download from url 113 | # Construct the full path to the file 114 | file_path = os.path.expanduser("~/.serl/") 115 | if not os.path.exists(file_path): 116 | os.makedirs(file_path) 117 | file_path = os.path.join(file_path, file_name) 118 | # Check if the file exists 119 | if os.path.exists(file_path): 120 | print(f"The ResNet-10 weights already exist at '{file_path}'.") 121 | else: 122 | url = f"https://github.com/rail-berkeley/serl/releases/download/resnet10/{file_name}" 123 | print(f"Downloading file from {url}") 124 | 125 | # Streaming download with progress bar 126 | try: 127 | response = requests.get(url, stream=True) 128 | total_size = int(response.headers.get("content-length", 0)) 129 | block_size = 1024 # 1 Kibibyte 130 | t = tqdm(total=total_size, unit="iB", unit_scale=True) 131 | with open(file_path, "wb") as f: 132 | for data in response.iter_content(block_size): 133 | t.update(len(data)) 134 | f.write(data) 135 | t.close() 136 | if total_size != 0 and t.n != total_size: 137 | raise Exception( 138 | "Error, something went wrong with the download") 139 | except Exception as e: 140 | raise RuntimeError(e) 141 | print("Download complete!") 142 | 143 | with open(file_path, "rb") as f: 144 | encoder_params = pkl.load(f) 145 | 146 | param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) 147 | print( 148 | f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" 149 | ) 150 | 151 | new_params = agent.state.params 152 | 153 | for image_key in image_keys: 154 | if "modules_calql_actor" in new_params.keys(): 155 | new_encoder_params = new_params["modules_calql_actor"]["encoder"][ 156 | f"encoder_{image_key}" 157 | ] 158 | else: 159 | if "modules_critic" not in new_params.keys(): 160 | new_encoder_params = new_params["modules_actor"]["encoder"][ 161 | f"encoder_{image_key}" 162 | ] 163 | else: 164 | if "encoder" not in new_params["modules_critic"].keys(): 165 | new_encoder_params = new_params["modules_actor"]["encoder"][ 166 | 
f"encoder_{image_key}" 167 | ] 168 | else: 169 | new_encoder_params = new_params["modules_critic"]["encoder"][ 170 | f"encoder_{image_key}" 171 | ] 172 | 173 | if "pretrained_encoder" in new_encoder_params: 174 | new_encoder_params = new_encoder_params["pretrained_encoder"] 175 | for k in new_encoder_params: 176 | if k in encoder_params: 177 | new_encoder_params[k] = encoder_params[k] 178 | print(f"replaced {k} in pretrained_encoder") 179 | 180 | agent = agent.replace(state=agent.state.replace(params=new_params)) 181 | return agent 182 | 183 | 184 | def get_weightings(weight_schedule, snrs, sigma_data): 185 | if weight_schedule == "snr": 186 | weightings = snrs 187 | elif weight_schedule == "snr+1": 188 | weightings = snrs + 1 189 | elif weight_schedule == "karras": 190 | weightings = snrs + 1.0 / sigma_data**2 191 | elif weight_schedule == "uniform": 192 | weightings = jnp.ones_like(snrs) 193 | else: 194 | raise NotImplementedError() 195 | return weightings 196 | 197 | 198 | def get_snr(sigmas): 199 | return sigmas**-2 200 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from serl_launcher.vision.resnet_v1 import resnetv1_configs 2 | 3 | encoders = dict() 4 | encoders.update(resnetv1_configs) 5 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/vision/film_conditioning_layer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py 2 | import flax.linen as nn 3 | import jax.numpy as jnp 4 | 5 | 6 | class FilmConditioning(nn.Module): 7 | @nn.compact 8 | def __call__(self, conv_filters: jnp.ndarray, conditioning: jnp.ndarray): 9 | """Applies FiLM conditioning to a convolutional feature map. 10 | 11 | Args: 12 | conv_filters: A tensor of shape [batch_size, height, width, channels]. 13 | conditioning: A tensor of shape [batch_size, conditioning_size]. 14 | 15 | Returns: 16 | A tensor of shape [batch_size, height, width, channels]. 
17 | """ 18 | projected_cond_add = nn.Dense( 19 | features=conv_filters.shape[-1], 20 | kernel_init=nn.initializers.zeros, 21 | bias_init=nn.initializers.zeros, 22 | )(conditioning) 23 | projected_cond_mult = nn.Dense( 24 | features=conv_filters.shape[-1], 25 | kernel_init=nn.initializers.zeros, 26 | bias_init=nn.initializers.zeros, 27 | )(conditioning) 28 | 29 | projected_cond_add = projected_cond_add[..., None, None, :] 30 | projected_cond_mult = projected_cond_mult[..., None, None, :] 31 | 32 | return conv_filters * (1 + projected_cond_add) + projected_cond_mult 33 | 34 | 35 | if __name__ == "__main__": 36 | import jax 37 | import jax.numpy as jnp 38 | 39 | key = jax.random.PRNGKey(0) 40 | key, subkey = jax.random.split(key) 41 | x = jax.random.normal(subkey, (1, 32, 32, 3)) 42 | x = jnp.array(x) 43 | 44 | z = jnp.ones((1, 64)) 45 | film = FilmConditioning() 46 | params = film.init(key, x, z) 47 | y = film.apply(params, x, z) 48 | 49 | print(y.shape) 50 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/vision/spatial.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Callable 2 | import flax.linen as nn 3 | import jax.numpy as jnp 4 | 5 | 6 | class SpatialLearnedEmbeddings(nn.Module): 7 | height: int 8 | width: int 9 | channel: int 10 | num_features: int = 5 11 | param_dtype: jnp.dtype = jnp.float32 12 | 13 | @nn.compact 14 | def __call__(self, features): 15 | """ 16 | features is B x H x W X C 17 | """ 18 | squeeze = False 19 | if len(features.shape) == 3: 20 | features = jnp.expand_dims(features, 0) 21 | squeeze = True 22 | 23 | kernel = self.param( 24 | "kernel", 25 | nn.initializers.lecun_normal(), 26 | (self.height, self.width, self.channel, self.num_features), 27 | self.param_dtype, 28 | ) 29 | 30 | batch_size = features.shape[0] 31 | features = jnp.sum( 32 | jnp.expand_dims(features, -1) * jnp.expand_dims(kernel, 0), axis=(1, 2) 33 | ) 34 | features = jnp.reshape(features, [batch_size, -1]) 35 | if squeeze: 36 | features = jnp.squeeze(features, 0) 37 | return features 38 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_launcher/serl_launcher/wrappers/__init__.py -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/chunking.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Optional 3 | 4 | import gymnasium as gym 5 | import gymnasium.spaces 6 | import jax 7 | import numpy as np 8 | 9 | 10 | def stack_obs(obs): 11 | dict_list = {k: [dic[k] for dic in obs] for k in obs[0]} 12 | return jax.tree_map( 13 | lambda x: np.stack(x), dict_list, is_leaf=lambda x: isinstance(x, list) 14 | ) 15 | 16 | 17 | def space_stack(space: gym.Space, repeat: int): 18 | if isinstance(space, gym.spaces.Box): 19 | return gym.spaces.Box( 20 | low=np.repeat(space.low[None], repeat, axis=0), 21 | high=np.repeat(space.high[None], repeat, axis=0), 22 | dtype=space.dtype, 23 | ) 24 | elif isinstance(space, gym.spaces.Discrete): 25 | return gym.spaces.MultiDiscrete([space.n] * repeat) 26 | elif isinstance(space, gym.spaces.Dict): 27 | return gym.spaces.Dict( 28 | {k: 
space_stack(v, repeat) for k, v in space.spaces.items()} 29 | ) 30 | else: 31 | raise TypeError() 32 | 33 | 34 | class ChunkingWrapper(gym.Wrapper): 35 | """ 36 | Enables observation histories and receding horizon control. 37 | 38 | Accumulates observations into obs_horizon size chunks. Starts by repeating the first obs. 39 | 40 | Executes act_exec_horizon actions in the environment. 41 | """ 42 | 43 | def __init__(self, env: gym.Env, obs_horizon: int, act_exec_horizon: Optional[int]): 44 | super().__init__(env) 45 | self.env = env 46 | self.obs_horizon = obs_horizon 47 | self.act_exec_horizon = act_exec_horizon 48 | 49 | self.current_obs = deque(maxlen=self.obs_horizon) 50 | 51 | self.observation_space = space_stack( 52 | self.env.observation_space, self.obs_horizon 53 | ) 54 | if self.act_exec_horizon is None: 55 | self.action_space = self.env.action_space 56 | else: 57 | self.action_space = space_stack( 58 | self.env.action_space, self.act_exec_horizon 59 | ) 60 | 61 | def step(self, action, *args): 62 | act_exec_horizon = self.act_exec_horizon 63 | if act_exec_horizon is None: 64 | action = [action] 65 | act_exec_horizon = 1 66 | 67 | assert len(action) >= act_exec_horizon 68 | 69 | for i in range(act_exec_horizon): 70 | obs, reward, done, trunc, info = self.env.step(action[i], *args) 71 | self.current_obs.append(obs) 72 | return (stack_obs(self.current_obs), reward, done, trunc, info) 73 | 74 | def reset(self, **kwargs): 75 | obs, info = self.env.reset(**kwargs) 76 | self.current_obs.extend([obs] * self.obs_horizon) 77 | return stack_obs(self.current_obs), info 78 | 79 | 80 | def post_stack_obs(obs, obs_horizon=1): 81 | if obs_horizon != 1: 82 | # TODO: Support proper stacking 83 | raise NotImplementedError("Only obs_horizon=1 is supported for now") 84 | obs = {k: v[None] for k, v in obs.items()} 85 | return obs -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/front_camera_wrapper.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | import gymnasium as gym 4 | from gymnasium.core import Env 5 | from copy import deepcopy 6 | from franka_env.camera.video_capture import VideoCapture 7 | from franka_env.camera.rs_capture import RSCapture 8 | import cv2 9 | 10 | class FrontCameraWrapper(gym.ObservationWrapper): 11 | def __init__(self, env: Env): 12 | super().__init__(env) 13 | front_obs_space = { 14 | k: space for k, space in self.observation_space.items() if "wrist" not in k 15 | } 16 | 17 | self.front_observation_space = gym.spaces.Dict(front_obs_space) 18 | # self.observation_space = gym.spaces.Dict(new_obs_space) 19 | self.front_obs = None 20 | 21 | def observation(self, observation): 22 | # cache a copy of observation with only the front camera image 23 | new_obs = deepcopy(observation) 24 | new_obs.pop("wrist_1") 25 | self.front_obs = new_obs 26 | 27 | return observation 28 | 29 | def get_front_cam_obs(self): 30 | return self.front_obs 31 | 32 | class NewFrontCameraWrapper(gym.ObservationWrapper): 33 | """ 34 | This wrapper is using front camera to train classifier only. The wrapped env 35 | should have a front camera as part of the observation space. The resultant env 36 | should not have a front camera as part of the observation space. The front camera 37 | image should be saved and retrieved by get_front_cam_obs method. 
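Assumed usage sketch: step the wrapped env as usual, then call `env.get_front_cam_obs()` to fetch the cached side-camera crop, a (1, 128, 128, 3) array, and feed it to the reward classifier (the camera name "side" and the crop window are hard-coded in `observation` below).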
38 | """ 39 | def __init__(self, env: Env): 40 | super().__init__(env) 41 | # self.observation_space = gym.spaces.Dict({ 42 | # k: space for k, space in self.observation_space.items() if "side" not in k 43 | # }) 44 | self.front_obs = None 45 | self.cap = VideoCapture( 46 | RSCapture(name="side", serial_number="128422272758", depth=False) 47 | ) 48 | # self.img_queue = queue.Queue() 49 | # self.displayer = ImageDisplayer(self.img_queue, "reward_image") 50 | # self.displayer.start() 51 | 52 | def observation(self, observation): 53 | # cache a copy of observation with only the front camera image 54 | image = self.cap.read() 55 | image = image[180:300, 170:280] 56 | image = cv2.resize(image, (128, 128))[None, ...] 57 | self.front_obs = deepcopy(image) 58 | # observation.pop("side") 59 | return observation 60 | 61 | def get_front_cam_obs(self): 62 | return self.front_obs 63 | 64 | 65 | # class ImageDisplayer(threading.Thread): 66 | # def __init__(self, queue, name): 67 | # threading.Thread.__init__(self) 68 | # self.queue = queue 69 | # self.daemon = True # make this a daemon thread 70 | # self.name = name 71 | 72 | # def run(self): 73 | # while True: 74 | # img_array = self.queue.get() # retrieve an image from the queue 75 | # if img_array is None: # None is our signal to exit 76 | # break 77 | # print(img_array.sum()) 78 | # cv2.imshow(self.name, img_array) 79 | # cv2.waitKey(1) -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/norm.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | 4 | class UnnormalizeActionProprio(gym.ActionWrapper, gym.ObservationWrapper): 5 | """ 6 | Un-normalizes the action and proprio. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | env: gym.Env, 12 | action_proprio_metadata: dict, 13 | normalization_type: str = "normal", 14 | ): 15 | self.action_proprio_metadata = action_proprio_metadata 16 | self.normalization_type = normalization_type 17 | super().__init__(env) 18 | 19 | def unnormalize(self, data, metadata): 20 | if self.normalization_type == "normal": 21 | return (data * metadata["std"]) + metadata["mean"] 22 | elif self.normalization_type == "bounds": 23 | return (data * (metadata["max"] - metadata["min"])) + metadata["min"] 24 | else: 25 | raise ValueError( 26 | f"Unknown action/proprio normalization type: {self.normalization_type}" 27 | ) 28 | 29 | def action(self, action): 30 | return self.unnormalize(action, self.action_proprio_metadata["action"]) 31 | 32 | def observation(self, obs): 33 | obs["proprio"] = self.unnormalize( 34 | obs["proprio"], self.action_proprio_metadata["proprio"] 35 | ) 36 | return obs 37 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/remap.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import gymnasium as gym 4 | import gymnasium.spaces 5 | import jax 6 | 7 | 8 | class RemapWrapper(gym.ObservationWrapper): 9 | def __init__(self, env: gym.Env, new_structure: Any): 10 | """ 11 | Remap a dictionary observation space to some other flat structure specified by keys. 12 | 13 | Params: 14 | env: Environment to wrap. 15 | new_structure: A tuple/dictionary/singleton where leaves are keys in the original observation space. 
16 | """ 17 | super().__init__(env) 18 | self.new_structure = new_structure 19 | 20 | if isinstance(new_structure, tuple): 21 | self.observation_space = gym.spaces.Tuple( 22 | [env.observation_space[v] for v in new_structure] 23 | ) 24 | elif isinstance(new_structure, dict): 25 | self.observation_space = gym.spaces.Dict( 26 | {k: env.observation_space[v] for k, v in new_structure.items()} 27 | ) 28 | elif isinstance(new_structure, str): 29 | self.observation_space = env.observation_space[new_structure] 30 | else: 31 | raise TypeError(f"Unsupported type {type(new_structure)}") 32 | 33 | def observation(self, observation): 34 | return jax.tree_map(lambda x: observation[x], self.new_structure) 35 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from gymnasium.spaces import flatten_space, flatten 3 | 4 | 5 | class SERLObsWrapper(gym.ObservationWrapper): 6 | """ 7 | This observation wrapper treat the observation space as a dictionary 8 | of a flattened state space and the images. 9 | """ 10 | 11 | def __init__(self, env, proprio_keys=None): 12 | super().__init__(env) 13 | self.proprio_keys = proprio_keys 14 | if self.proprio_keys is None: 15 | self.proprio_keys = list(self.env.observation_space["state"].keys()) 16 | 17 | self.proprio_space = gym.spaces.Dict( 18 | {key: self.env.observation_space["state"][key] for key in self.proprio_keys} 19 | ) 20 | 21 | self.observation_space = gym.spaces.Dict( 22 | { 23 | "state": flatten_space(self.proprio_space), 24 | **(self.env.observation_space["images"]), 25 | } 26 | ) 27 | 28 | def observation(self, obs): 29 | obs = { 30 | "state": flatten( 31 | self.proprio_space, 32 | {key: obs["state"][key] for key in self.proprio_keys}, 33 | ), 34 | **(obs["images"]), 35 | } 36 | return obs 37 | 38 | def reset(self, **kwargs): 39 | obs, info = self.env.reset(**kwargs) 40 | return self.observation(obs), info 41 | 42 | def flatten_observations(obs, proprio_space, proprio_keys): 43 | obs = { 44 | "state": flatten( 45 | proprio_space, 46 | {key: obs["state"][key] for key in proprio_keys}, 47 | ), 48 | **(obs["images"]), 49 | } 50 | return obs -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import gymnasium as gym 5 | import imageio 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | # Take from 10 | # https://github.com/denisyarats/pytorch_sac/ 11 | 12 | 13 | def compose_frames( 14 | all_frames: List[np.ndarray], 15 | num_videos_per_row: int, 16 | margin: int = 4, 17 | ): 18 | num_episodes = len(all_frames) 19 | 20 | if num_videos_per_row is None: 21 | num_videos_per_row = num_episodes 22 | 23 | t = 0 24 | end_of_all_epidoes = False 25 | frames_to_save = [] 26 | while not end_of_all_epidoes: 27 | frames_t = [] 28 | 29 | for i in range(num_episodes): 30 | # If the episode is shorter, repeat the last frame. 31 | t_ = min(t, len(all_frames[i]) - 1) 32 | frame_i_t = all_frames[i][t_] 33 | 34 | # Add the lines. 
35 | frame_i_t = np.pad( 36 | frame_i_t, 37 | [[margin, margin], [margin, margin], [0, 0]], 38 | "constant", 39 | constant_values=0, 40 | ) 41 | 42 | frames_t.append(frame_i_t) 43 | 44 | # Arrange the videos based on num_videos_per_row. 45 | frame_t = None 46 | while len(frames_t) >= num_videos_per_row: 47 | frames_t_this_row = frames_t[:num_videos_per_row] 48 | frames_t = frames_t[num_videos_per_row:] 49 | 50 | frame_t_this_row = np.concatenate(frames_t_this_row, axis=1) 51 | if frame_t is None: 52 | frame_t = frame_t_this_row 53 | else: 54 | frame_t = np.concatenate([frame_t, frame_t_this_row], axis=0) 55 | 56 | frames_to_save.append(frame_t) 57 | t += 1 58 | end_of_all_epidoes = all([len(all_frames[i]) <= t for i in range(num_episodes)]) 59 | 60 | return frames_to_save 61 | 62 | 63 | class VideoRecorder(gym.Wrapper): 64 | def __init__( 65 | self, 66 | env: gym.Env, 67 | save_folder: str = "", 68 | save_prefix: str = None, 69 | height: int = 128, 70 | width: int = 128, 71 | fps: int = 30, 72 | camera_id: int = 0, 73 | goal_conditioned: bool = False, 74 | ): 75 | super().__init__(env) 76 | 77 | self.save_folder = save_folder 78 | self.save_prefix = save_prefix 79 | self.height = height 80 | self.width = width 81 | self.fps = fps 82 | self.camera_id = camera_id 83 | self.frames = [] 84 | self.goal_conditioned = goal_conditioned 85 | 86 | if not tf.io.gfile.exists(save_folder): 87 | tf.io.gfile.makedirs(save_folder) 88 | 89 | self.num_record_episodes = -1 90 | 91 | self.num_videos = 0 92 | 93 | # self.all_save_paths = None 94 | self.current_save_path = None 95 | 96 | def start_recording(self, num_episodes: int = None, num_videos_per_row: int = None): 97 | if num_videos_per_row is not None and num_episodes is not None: 98 | assert num_episodes >= num_videos_per_row 99 | 100 | self.num_record_episodes = num_episodes 101 | self.num_videos_per_row = num_videos_per_row 102 | 103 | # self.all_save_paths = [] 104 | self.all_frames = [] 105 | 106 | def stop_recording(self): 107 | self.num_record_episodes = None 108 | 109 | def step(self, action: np.ndarray): # NOQA 110 | 111 | if self.num_record_episodes is None or self.num_record_episodes == 0: 112 | observation, reward, terminated, truncated, info = self.env.step(action) 113 | 114 | elif self.num_record_episodes > 0: 115 | frame = self.env.render( 116 | height=self.height, width=self.width, camera_id=self.camera_id 117 | ) 118 | 119 | if frame is None: 120 | try: 121 | frame = self.sim.render( 122 | width=self.width, height=self.height, mode="offscreen" 123 | ) 124 | frame = np.flipud(frame) 125 | except Exception: 126 | raise NotImplementedError("Rendering is not implemented.") 127 | 128 | self.frames.append(frame.astype(np.uint8)) 129 | 130 | observation, reward, terminated, truncated, info = self.env.step(action) 131 | 132 | if terminated or truncated: 133 | if self.goal_conditioned: 134 | frames = [ 135 | np.concatenate([self.env.current_goal["image"], frame], axis=0) 136 | for frame in self.frames 137 | ] 138 | else: 139 | frames = self.frames 140 | 141 | self.all_frames.append(frames) 142 | 143 | if self.num_record_episodes > 0: 144 | self.num_record_episodes -= 1 145 | 146 | if self.num_record_episodes is None: 147 | # Plot one episode per file. 148 | frames_to_save = frames 149 | should_save = True 150 | elif self.num_record_episodes == 0: 151 | # Plot all episodes in one file. 
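# compose_frames tiles all recorded episodes into a grid with
# `num_videos_per_row` columns, repeating each episode's final frame so
# shorter episodes remain visible until the longest one ends.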
152 | frames_to_save = compose_frames( 153 | self.all_frames, self.num_videos_per_row 154 | ) 155 | should_save = True 156 | else: 157 | should_save = False 158 | 159 | if should_save: 160 | filename = "%08d.mp4" % (self.num_videos) 161 | if self.save_prefix is not None and self.save_prefix != "": 162 | filename = f"{self.save_prefix}_{filename}" 163 | self.current_save_path = tf.io.gfile.join( 164 | self.save_folder, filename 165 | ) 166 | 167 | with tf.io.gfile.GFile(self.current_save_path, "wb") as f: 168 | imageio.mimsave(f, frames_to_save, "MP4", fps=self.fps) 169 | 170 | self.num_videos += 1 171 | 172 | self.frames = [] 173 | 174 | else: 175 | raise ValueError("Do not forget to call start_recording.") 176 | 177 | return observation, reward, terminated, truncated, info 178 | -------------------------------------------------------------------------------- /serl_launcher/serl_launcher/wrappers/video_wrapper.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import gym 3 | import numpy as np 4 | 5 | 6 | class VideoWrapper(gym.Wrapper): 7 | def __init__( 8 | self, 9 | env: gym.Env, 10 | name: str = "video", 11 | ): 12 | super().__init__(env) 13 | self._name = name 14 | self._video = OrderedDict() 15 | self.image_keys = [k for k in self.observation_space.keys() if k != "state"] 16 | 17 | def get_obs_frames(self, keys=None): 18 | if keys is None: 19 | video = {k: np.array(v) for k, v in self._video.items()} 20 | else: 21 | video = {k: np.array(v) for k, v in self._video.items() if k in keys} 22 | return video 23 | 24 | def get_rendered_video(self): 25 | frames = [] 26 | for i in range(len(self._video[self.image_keys[0]])): 27 | frame = [] 28 | for k in self.image_keys: 29 | frame.append(self._video[k][i]) 30 | frames.append(np.concatenate(frame, axis=1)) 31 | return np.concatenate(frames, axis=0) 32 | 33 | def _add_frame(self, obs): 34 | img = [] 35 | for k in self.image_keys: 36 | if k in obs: 37 | if k in self._video: 38 | self._video[k].append(obs[k]) 39 | else: 40 | self._video[k] = [obs[k]] 41 | 42 | def reset(self, **kwargs): 43 | self._video.clear() 44 | obs, info = super().reset(**kwargs) 45 | self._add_frame(obs) 46 | return obs, info 47 | 48 | def step(self, action: np.ndarray): 49 | obs, reward, done, truncate, info = super().step(action) 50 | self._add_frame(obs) 51 | return obs, reward, done, truncate, info 52 | -------------------------------------------------------------------------------- /serl_launcher/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="serl_launcher", 5 | version="0.1.2", 6 | description="library for rl experiments", 7 | url="https://github.com/rail-berkeley/serl", 8 | author="auth", 9 | license="MIT", 10 | install_requires=[ 11 | "zmq", 12 | "typing", 13 | "typing_extensions", 14 | "opencv-python", 15 | "lz4", 16 | "agentlace@git+https://github.com/youliangtan/agentlace.git@cf2c337c5e3694cdbfc14831b239bd657bc4894d", 17 | ], 18 | packages=find_packages(), 19 | zip_safe=False, 20 | ) 21 | -------------------------------------------------------------------------------- /serl_robot_infra/README.md: -------------------------------------------------------------------------------- 1 | # SERL Robot Infra 2 | ![](../docs/images/robot_infra_interfaces.png) 3 | 4 | All robot code is structured as follows: 5 | There is a Flask server which sends commands to the robot via ROS. 
There is a gym env for the robot which communicates with the Flask server via post requests. 6 | 7 | - `robot_server`: hosts a Flask server which sends commands to the robot via ROS 8 | - `franka_env`: gym env for the robot which communicates with the Flask server via post requests 9 | 10 | 11 | ### Installation 12 | 13 | 1. Install `libfranka` and `franka_ros` following the instructions [here](https://frankaemika.github.io/docs/requirements.html). 14 | 15 | 2. Then install `serl_franka_controllers` from https://github.com/rail-berkeley/serl_franka_controllers 16 | 17 | 3. Then install this package and its dependencies. 18 | ```bash 19 | conda activate hilserl 20 | pip install -e . 21 | ``` 22 | 23 | ### Usage 24 | 25 | **Robot Server** 26 | 27 | To start using the robot, first power on the robot (small switch on the back of the robot control box on the floor). Unlock the robot from the browser interface by going to the robot's IP address in your browser, then press the black and white button to put the robot in FCI control mode (blue light). 28 | 29 | From there you should be able to navigate to `serl_robot_infra` and then simply run the franka server. This must be run inside a ROS environment. 30 | 31 | ```bash 32 | conda activate hilserl 33 | 34 | # script to start http server and ros controller 35 | python serl_robot_infra/robot_servers/franka_server.py \ 36 | --gripper_type= 37 | --robot_ip= 38 | --gripper_ip=<[Optional] Robotiq_gripper_IP> 39 | --reset_joint_target=<[Optional] robot_joints_when_robot_resets> 40 | ``` 41 | 42 | This should start the ROS impedance controller node and the HTTP server. You can test that things are running by trying to move the end effector around; if the impedance controller is running, it should be compliant. 43 | 44 | The HTTP server is used to communicate between the ROS controller and gym environments. Possible HTTP requests include: 45 | 46 | | Request | Description | 47 | | --- | --- | 48 | | startimp | Start the impedance controller | 49 | | stopimp | Stop the impedance controller | 50 | | pose | Command the robot to go to a desired end-effector pose given in the base frame (xyz+quaternion) | 51 | | getpos | Return current end-effector pose in robot base frame (xyz+rpy) | 52 | | getvel | Return current end-effector velocity in robot base frame | 53 | | getforce | Return estimated force on end-effector in stiffness frame | 54 | | gettorque | Return estimated torque on end-effector in stiffness frame | 55 | | getq | Return current joint position | 56 | | getdq | Return current joint velocity | 57 | | getjacobian | Return current zero-jacobian | 58 | | getstate | Return all robot states | 59 | | jointreset | Perform joint reset | 60 | | activate_gripper | Activate the gripper (Robotiq only) | 61 | | reset_gripper | Reset the gripper (Robotiq only) | 62 | | get_gripper | Return current gripper position | 63 | | close_gripper | Close the gripper completely | 64 | | open_gripper | Open the gripper completely | 65 | | move_gripper | Move the gripper to a given position | 66 | | clearerr | Clear errors | 67 | | update_param | Update the impedance controller parameters | 68 | 69 | These commands can also be called from the terminal. 
Useful ones include: 70 | ```bash 71 | curl -X POST http://127.0.0.2:5000/activate_gripper # Activate gripper 72 | curl -X POST http://127.0.0.2:5000/close_gripper # Close gripper 73 | curl -X POST http://127.0.0.2:5000/open_gripper # Open gripper 74 | curl -X POST http://127.0.0.1:5000/getpos # Print current end-effector pose 75 | curl -X POST http://127.0.0.1:5000/jointreset # Perform joint reset 76 | curl -X POST http://127.0.0.1:5000/stopimp # Stop the impedance controller 77 | curl -X POST http://127.0.0.1:5000/startimp # Start the impedance controller (**Only run this after stopimp**) 78 | curl -X POST http://127.0.0.1:5000/getpos_euler # Get Euler pose 79 | ``` -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/camera/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_robot_infra/franka_env/camera/__init__.py -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/camera/multi_video_capture.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | import time 4 | import numpy as np 5 | from collections import OrderedDict 6 | 7 | class MultiVideoCapture: 8 | def __init__(self, caps): 9 | self.caps = caps 10 | self.queue = queue.Queue() 11 | self.t = threading.Thread(target=self._reader) 12 | self.t.daemon = False 13 | self.enable = True 14 | self.t.start() 15 | 16 | def _reader(self): 17 | while self.enable: 18 | frames = OrderedDict() 19 | for name, cap in self.caps.items(): 20 | ret, frame = cap.read() 21 | if ret: 22 | frames[name] = frame 23 | 24 | if frames: 25 | if not self.queue.empty(): 26 | try: 27 | self.queue.get_nowait() # discard previous (unprocessed) frame 28 | except queue.Empty: 29 | pass 30 | self.queue.put(frames) 31 | 32 | def read(self): 33 | return self.queue.get(timeout=5) 34 | 35 | def close(self): 36 | self.enable = False 37 | self.t.join() 38 | for cap in self.caps.values(): 39 | cap.close() 40 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/camera/rs_capture.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyrealsense2 as rs # Intel RealSense cross-platform open-source API 3 | 4 | 5 | class RSCapture: 6 | def get_device_serial_numbers(self): 7 | devices = rs.context().devices 8 | return [d.get_info(rs.camera_info.serial_number) for d in devices] 9 | 10 | def __init__(self, name, serial_number, dim=(640, 480), fps=15, depth=False, exposure=40000): 11 | self.name = name 12 | print(serial_number) 13 | print(self.get_device_serial_numbers()) 14 | assert serial_number in self.get_device_serial_numbers() 15 | self.serial_number = serial_number 16 | self.depth = depth 17 | self.pipe = rs.pipeline() 18 | self.cfg = rs.config() 19 | self.cfg.enable_device(self.serial_number) 20 | self.cfg.enable_stream(rs.stream.color, dim[0], dim[1], rs.format.bgr8, fps) 21 | if self.depth: 22 | self.cfg.enable_stream(rs.stream.depth, dim[0], dim[1], rs.format.z16, fps) 23 | self.profile = self.pipe.start(self.cfg) 24 | self.s = self.profile.get_device().query_sensors()[0] 25 | self.s.set_option(rs.option.exposure, exposure) 26 | 27 | # Create an align object 28 | # rs.align allows us to perform alignment of depth frames to others 
frames 29 | # The "align_to" is the stream type to which we plan to align depth frames. 30 | align_to = rs.stream.color 31 | self.align = rs.align(align_to) 32 | 33 | def read(self): 34 | frames = self.pipe.wait_for_frames() 35 | aligned_frames = self.align.process(frames) 36 | color_frame = aligned_frames.get_color_frame() 37 | if self.depth: 38 | depth_frame = aligned_frames.get_depth_frame() 39 | 40 | if color_frame.is_video_frame(): 41 | image = np.asarray(color_frame.get_data()) 42 | if self.depth and depth_frame.is_depth_frame(): 43 | depth = np.expand_dims(np.asarray(depth_frame.get_data()), axis=2) 44 | return True, np.concatenate((image, depth), axis=-1) 45 | else: 46 | return True, image 47 | else: 48 | return False, None 49 | 50 | def close(self): 51 | self.pipe.stop() 52 | self.cfg.disable_all_streams() 53 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/camera/video_capture.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | import time 4 | import numpy as np 5 | 6 | class VideoCapture: 7 | def __init__(self, cap, name=None): 8 | if name is None: 9 | name = cap.name 10 | self.name = name 11 | self.q = queue.Queue() 12 | self.cap = cap 13 | self.t = threading.Thread(target=self._reader) 14 | self.t.daemon = False 15 | self.enable = True 16 | self.t.start() 17 | 18 | def _reader(self): 19 | while self.enable: 20 | ret, frame = self.cap.read() 21 | if not ret: 22 | break 23 | if not self.q.empty(): 24 | try: 25 | self.q.get_nowait() # discard previous (unprocessed) frame 26 | except queue.Empty: 27 | pass 28 | self.q.put(frame) 29 | 30 | def read(self): 31 | # print(self.name, self.q.qsize()) 32 | return self.q.get(timeout=5) 33 | 34 | def close(self): 35 | self.enable = False 36 | self.t.join() 37 | self.cap.close() 38 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from franka_env.envs.franka_env import FrankaEnv, DefaultEnvConfig 2 | from franka_env.envs.franka_wrench_env import FrankaWrenchEnv -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/envs/dual_franka_env.py: -------------------------------------------------------------------------------- 1 | """Gym Interface for Franka""" 2 | import queue 3 | import threading 4 | import time 5 | import numpy as np 6 | import gymnasium as gym 7 | import cv2 8 | class ImageDisplayer(threading.Thread): 9 | def __init__(self, queue): 10 | threading.Thread.__init__(self) 11 | self.queue = queue 12 | self.daemon = True # make this a daemon thread 13 | 14 | def run(self): 15 | while True: 16 | img_array = self.queue.get() # retrieve an image from the queue 17 | if img_array is None: # None is our signal to exit 18 | break 19 | 20 | left_frame = np.concatenate( 21 | [cv2.resize(v, (256, 256)) for k, v in img_array.items() if "left" in k], axis=1 22 | ) 23 | right_frame = np.concatenate( 24 | [cv2.resize(v, (256, 256)) for k, v in img_array.items() if "right" in k], axis=1 25 | ) 26 | frame = np.concatenate([left_frame, right_frame], axis=1) 27 | 28 | cv2.imshow('Image', frame[..., ::-1]) 29 | cv2.waitKey(1) 30 | 31 | 32 | class DualFrankaEnv(gym.Env): 33 | def __init__( 34 | self, 35 | env_left, 36 | env_right, 37 | display_images=True, 38 | ): 39 | 40 | self.env_left = env_left 41 | 
self.env_right = env_right 42 | 43 | # Action/Observation Space 44 | action_dim = len(self.env_left.action_space.low) + len(self.env_right.action_space.low) 45 | self.action_space = gym.spaces.Box( 46 | np.ones((action_dim,), dtype=np.float32) * -1, 47 | np.ones((action_dim,), dtype=np.float32), 48 | ) 49 | image_dict = ({f"left/{key}": self.env_left.observation_space["images"][key] for key in self.env_left.observation_space["images"].keys()} | 50 | {f"right/{key}": self.env_right.observation_space["images"][key] for key in self.env_right.observation_space["images"].keys()}) 51 | 52 | state_dict = ({f"left/{key}": self.env_left.observation_space["state"][key] for key in self.env_left.observation_space["state"].keys()} | 53 | {f"right/{key}": self.env_right.observation_space["state"][key] for key in self.env_right.observation_space["state"].keys()}) 54 | 55 | self.observation_space = gym.spaces.Dict( 56 | { 57 | "state": gym.spaces.Dict(state_dict), 58 | "images": gym.spaces.Dict(image_dict) 59 | } 60 | ) 61 | self.display_images = display_images 62 | if self.display_images: 63 | self.img_queue = queue.Queue() 64 | self.displayer = ImageDisplayer(self.img_queue) 65 | self.displayer.start() 66 | 67 | def step(self, action: np.ndarray) -> tuple: 68 | action_left = action[:len(action)//2] 69 | action_right = action[len(action)//2:] 70 | def step_env_left(): 71 | global ob_left, reward_left, done_left 72 | ob_left, reward_left, done_left, _, _ = self.env_left.step(action_left) 73 | 74 | def step_env_right(): 75 | global ob_right, reward_right, done_right 76 | ob_right, reward_right, done_right, _, _ = self.env_right.step(action_right) 77 | 78 | # Create threads for each function 79 | thread_left = threading.Thread(target=step_env_left) 80 | thread_right = threading.Thread(target=step_env_right) 81 | 82 | # Start the threads 83 | thread_left.start() 84 | thread_right.start() 85 | 86 | # Wait for both threads to complete 87 | thread_left.join() 88 | thread_right.join() 89 | ob = self.combine_obs(ob_left, ob_right) 90 | if self.display_images: 91 | self.img_queue.put(ob['images']) 92 | return ob, int(reward_left and reward_right), done_left or done_right, False, {} 93 | 94 | 95 | def reset(self, **kwargs): 96 | def reset_env_left(): 97 | global ob_left 98 | ob_left, _ = self.env_left.reset(**kwargs) 99 | 100 | def reset_env_right(): 101 | global ob_right 102 | ob_right, _ = self.env_right.reset(**kwargs) 103 | 104 | thread_left = threading.Thread(target=reset_env_left) 105 | thread_right = threading.Thread(target=reset_env_right) 106 | thread_left.start() 107 | thread_right.start() 108 | thread_left.join() 109 | thread_right.join() 110 | 111 | ob = self.combine_obs(ob_left, ob_right) 112 | return ob, {} 113 | 114 | def combine_obs(self, ob_left, ob_right): 115 | left_images = {f"left/{key}": ob_left["images"][key] for key in ob_left["images"].keys()} 116 | right_images = {f"right/{key}": ob_right["images"][key] for key in ob_right["images"].keys()} 117 | left_state = {f"left/{key}": ob_left["state"][key] for key in ob_left["state"].keys()} 118 | right_state = {f"right/{key}": ob_right["state"][key] for key in ob_right["state"].keys()} 119 | ob = { 120 | "state": left_state | right_state, 121 | "images": left_images | right_images 122 | } 123 | return ob -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/envs/relative_env.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from 
scipy.spatial.transform import Rotation as R 3 | import gymnasium as gym 4 | import numpy as np 5 | from gym import Env 6 | from franka_env.utils.transformations import ( 7 | construct_adjoint_matrix, 8 | construct_homogeneous_matrix, 9 | ) 10 | 11 | 12 | class RelativeFrame(gym.Wrapper): 13 | """ 14 | This wrapper transforms the observation and action to be expressed in the end-effector frame. 15 | Optionally, it can transform the tcp_pose into a relative frame defined as the reset pose. 16 | 17 | This wrapper is expected to be used on top of the base Franka environment, which has the following 18 | observation space: 19 | { 20 | "state": spaces.Dict( 21 | { 22 | "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,)), # xyz + quat 23 | ...... 24 | } 25 | ), 26 | ...... 27 | }, and at least 6 DoF action space with (x, y, z, rx, ry, rz, ...) 28 | """ 29 | 30 | def __init__(self, env: Env, include_relative_pose=True): 31 | super().__init__(env) 32 | self.adjoint_matrix = np.zeros((6, 6)) 33 | 34 | self.include_relative_pose = include_relative_pose 35 | if self.include_relative_pose: 36 | # Homogeneous transformation matrix from reset pose's relative frame to base frame 37 | self.T_r_o_inv = np.zeros((4, 4)) 38 | 39 | def step(self, action: np.ndarray): 40 | # action is assumed to be (x, y, z, rx, ry, rz, gripper) 41 | # Transform action from end-effector frame to base frame 42 | transformed_action = self.transform_action(action) 43 | obs, reward, done, truncated, info = self.env.step(transformed_action) 44 | info['original_state_obs'] = copy.deepcopy(obs['state']) 45 | 46 | # this is to convert the spacemouse intervention action 47 | if "intervene_action" in info: 48 | info["intervene_action"] = self.transform_action_inv(info["intervene_action"]) 49 | 50 | # Update adjoint matrix 51 | self.adjoint_matrix = construct_adjoint_matrix(obs["state"]["tcp_pose"]) 52 | 53 | # Transform observation to spatial frame 54 | transformed_obs = self.transform_observation(obs) 55 | return transformed_obs, reward, done, truncated, info 56 | 57 | def reset(self, **kwargs): 58 | obs, info = self.env.reset(**kwargs) 59 | info['original_state_obs'] = copy.deepcopy(obs['state']) 60 | 61 | # Update adjoint matrix 62 | self.adjoint_matrix = construct_adjoint_matrix(obs["state"]["tcp_pose"]) 63 | if self.include_relative_pose: 64 | # Update transformation matrix from the reset pose's relative frame to base frame 65 | self.T_r_o_inv = np.linalg.inv( 66 | construct_homogeneous_matrix(obs["state"]["tcp_pose"]) 67 | ) 68 | 69 | # Transform observation to spatial frame 70 | return self.transform_observation(obs), info 71 | 72 | def transform_observation(self, obs): 73 | """ 74 | Transform observations from spatial(base) frame into body(end-effector) frame 75 | using the adjoint matrix 76 | """ 77 | adjoint_inv = np.linalg.inv(self.adjoint_matrix) 78 | obs["state"]["tcp_vel"] = adjoint_inv @ obs["state"]["tcp_vel"] 79 | 80 | if self.include_relative_pose: 81 | T_b_o = construct_homogeneous_matrix(obs["state"]["tcp_pose"]) 82 | T_b_r = self.T_r_o_inv @ T_b_o 83 | 84 | # Reconstruct transformed tcp_pose vector 85 | p_b_r = T_b_r[:3, 3] 86 | theta_b_r = R.from_matrix(T_b_r[:3, :3]).as_quat() 87 | obs["state"]["tcp_pose"] = np.concatenate((p_b_r, theta_b_r)) 88 | 89 | return obs 90 | 91 | def transform_action(self, action: np.ndarray): 92 | """ 93 | Transform action from body(end-effector) frame into into spatial(base) frame 94 | using the adjoint matrix. 
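# --- Illustrative snippet: the relative-pose bookkeeping RelativeFrame performs ---
# Standalone sketch, runnable without a robot; the two tcp_pose values are hypothetical.
# At reset() the wrapper caches T_r_o_inv; at every step() it re-expresses the current
# pose in that reset frame, exactly as transform_observation does above.
import numpy as np
from scipy.spatial.transform import Rotation as R
from franka_env.utils.transformations import construct_homogeneous_matrix

reset_pose   = np.array([0.4, 0.0, 0.3, 0.0, 0.0, 0.0, 1.0])    # xyz + quat, hypothetical
current_pose = np.array([0.45, 0.05, 0.3, 0.0, 0.0, 0.0, 1.0])  # hypothetical later pose

T_r_o_inv = np.linalg.inv(construct_homogeneous_matrix(reset_pose))  # cached at reset()
T_b_o = construct_homogeneous_matrix(current_pose)                   # rebuilt every step()
T_b_r = T_r_o_inv @ T_b_o                                            # pose relative to reset frame

relative_tcp_pose = np.concatenate((T_b_r[:3, 3], R.from_matrix(T_b_r[:3, :3]).as_quat()))
print(relative_tcp_pose)  # -> [0.05, 0.05, 0.0, 0.0, 0.0, 0.0, 1.0] for these poses
# --- end illustrative snippet ---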
95 | """ 96 | action = np.array(action) # in case action is a jax read-only array 97 | action[:6] = self.adjoint_matrix @ action[:6] 98 | return action 99 | 100 | def transform_action_inv(self, action: np.ndarray): 101 | """ 102 | Transform action from spatial(base) frame into body(end-effector) frame 103 | using the adjoint matrix. 104 | """ 105 | action = np.array(action) 106 | action[:6] = np.linalg.inv(self.adjoint_matrix) @ action[:6] 107 | return action 108 | 109 | 110 | class DualRelativeFrame(gym.Wrapper): 111 | """ 112 | This wrapper transforms the observation and action to be expressed in the end-effector frame. 113 | Optionally, it can transform the tcp_pose into a relative frame defined as the reset pose. 114 | 115 | This wrapper is expected to be used on top of the base Franka environment, which has the following 116 | observation space: 117 | { 118 | "state": spaces.Dict( 119 | { 120 | "left/tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,)), # xyz + quat 121 | ... 122 | "right/tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,)), # xyz + quat 123 | ... 124 | } 125 | ), 126 | ...... 127 | }, and at least 12 DoF action space 128 | """ 129 | 130 | def __init__(self, env: Env, include_relative_pose=True): 131 | super().__init__(env) 132 | self.left_adjoint_matrix = np.zeros((6, 6)) 133 | self.right_adjoint_matrix = np.zeros((6, 6)) 134 | 135 | self.include_relative_pose = include_relative_pose 136 | if self.include_relative_pose: 137 | # Homogeneous transformation matrix from reset pose's relative frame to base frame 138 | self.left_T_r_o_inv = np.zeros((4, 4)) 139 | self.right_T_r_o_inv = np.zeros((4, 4)) 140 | 141 | def step(self, action: np.ndarray): 142 | # action is assumed to be (x, y, z, rx, ry, rz, gripper) 143 | # Transform action from end-effector frame to base frame 144 | transformed_action = self.transform_action(action) 145 | obs, reward, done, truncated, info = self.env.step(transformed_action) 146 | 147 | # this is to convert the spacemouse intervention action 148 | if "intervene_action" in info: 149 | info["intervene_action"] = self.transform_action_inv(info["intervene_action"]) 150 | 151 | # Update adjoint matrix 152 | self.left_adjoint_matrix = construct_adjoint_matrix(obs["state"]["left/tcp_pose"]) 153 | self.right_adjoint_matrix = construct_adjoint_matrix(obs["state"]["right/tcp_pose"]) 154 | 155 | # Transform observation to spatial frame 156 | transformed_obs = self.transform_observation(obs) 157 | return transformed_obs, reward, done, truncated, info 158 | 159 | def reset(self, **kwargs): 160 | obs, info = self.env.reset(**kwargs) 161 | 162 | # Update adjoint matrix 163 | self.left_adjoint_matrix = construct_adjoint_matrix(obs["state"]["left/tcp_pose"]) 164 | self.right_adjoint_matrix = construct_adjoint_matrix(obs["state"]["right/tcp_pose"]) 165 | 166 | if self.include_relative_pose: 167 | # Update transformation matrix from the reset pose's relative frame to base frame 168 | self.left_T_r_o_inv = np.linalg.inv( 169 | construct_homogeneous_matrix(obs["state"]["left/tcp_pose"]) 170 | ) 171 | self.right_T_r_o_inv = np.linalg.inv( 172 | construct_homogeneous_matrix(obs["state"]["right/tcp_pose"]) 173 | ) 174 | # Transform observation to spatial frame 175 | return self.transform_observation(obs), info 176 | 177 | def transform_observation(self, obs): 178 | """ 179 | Transform observations from spatial(base) frame into body(end-effector) frame 180 | using the adjoint matrix 181 | """ 182 | left_adjoint_inv = np.linalg.inv(self.left_adjoint_matrix) 183 | 
obs["state"]["left/tcp_vel"] = left_adjoint_inv @ obs["state"]["left/tcp_vel"] 184 | 185 | right_adjoint_inv = np.linalg.inv(self.right_adjoint_matrix) 186 | obs["state"]["right/tcp_vel"] = right_adjoint_inv @ obs["state"]["right/tcp_vel"] 187 | 188 | if self.include_relative_pose: 189 | left_T_b_o = construct_homogeneous_matrix(obs["state"]["left/tcp_pose"]) 190 | left_T_b_r = self.left_T_r_o_inv @ left_T_b_o 191 | 192 | # Reconstruct transformed tcp_pose vector 193 | left_p_b_r = left_T_b_r[:3, 3] 194 | left_theta_b_r = R.from_matrix(left_T_b_r[:3, :3]).as_quat() 195 | obs["state"]["left/tcp_pose"] = np.concatenate((left_p_b_r, left_theta_b_r)) 196 | 197 | right_T_b_o = construct_homogeneous_matrix(obs["state"]["right/tcp_pose"]) 198 | right_T_b_r = self.right_T_r_o_inv @ right_T_b_o 199 | 200 | # Reconstruct transformed tcp_pose vector 201 | right_p_b_r = right_T_b_r[:3, 3] 202 | right_theta_b_r = R.from_matrix(right_T_b_r[:3, :3]).as_quat() 203 | obs["state"]["right/tcp_pose"] = np.concatenate((right_p_b_r, right_theta_b_r)) 204 | 205 | 206 | return obs 207 | 208 | def transform_action(self, action: np.ndarray): 209 | """ 210 | Transform action from body(end-effector) frame into into spatial(base) frame 211 | using the adjoint matrix 212 | """ 213 | action = np.array(action) # in case action is a jax read-only array 214 | if len(action) == 12: 215 | action[:6] = self.left_adjoint_matrix @ action[:6] 216 | action[6:] = self.right_adjoint_matrix @ action[6:] 217 | elif len(action) == 14: 218 | action[:6] = self.left_adjoint_matrix @ action[:6] 219 | action[7:13] = self.right_adjoint_matrix @ action[7:13] 220 | else: 221 | raise ValueError("Action dimension not supported") 222 | return action 223 | 224 | def transform_action_inv(self, action: np.ndarray): 225 | """ 226 | Transform action from spatial(base) frame into body(end-effector) frame 227 | using the adjoint matrix. 228 | """ 229 | action = np.array(action) 230 | if len(action) == 12: 231 | action[:6] = np.linalg.inv(self.left_adjoint_matrix) @ action[:6] 232 | action[6:] = np.linalg.inv(self.right_adjoint_matrix) @ action[6:] 233 | elif len(action) == 14: 234 | action[:6] = np.linalg.inv(self.left_adjoint_matrix) @ action[:6] 235 | action[7:13] = np.linalg.inv(self.right_adjoint_matrix) @ action[7:13] 236 | else: 237 | raise ValueError("Action dimension not supported") 238 | return action -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/spacemouse/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_robot_infra/franka_env/spacemouse/__init__.py -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/spacemouse/spacemouse_expert.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import numpy as np 3 | from franka_env.spacemouse import pyspacemouse 4 | from typing import Tuple 5 | 6 | 7 | class SpaceMouseExpert: 8 | """ 9 | This class provides an interface to the SpaceMouse. 10 | It continuously reads the SpaceMouse state and provides 11 | a "get_action" method to get the latest action and button state. 
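# --- Illustrative snippet: action layouts accepted by DualRelativeFrame.transform_action ---
# 12-dim: [left twist (6) | right twist (6)]             -> slices [:6] and [6:]
# 14-dim: [left twist (6), left extra (1) | right twist (6), right extra (1)]
#         -> slices [:6] and [7:13]; indices 6 and 13 pass through unchanged
#         (presumably the per-arm gripper commands; this layout is inferred from the slices, not stated).
import numpy as np

action_14 = np.zeros(14)
action_14[:6]   = [0.01, 0.0, 0.0, 0.0, 0.0, 0.0]   # left end-effector twist (hypothetical values)
action_14[6]    = 1.0                                # left gripper command, not rotated by the adjoint
action_14[7:13] = [0.0, -0.01, 0.0, 0.0, 0.0, 0.0]  # right end-effector twist
action_14[13]   = -1.0                               # right gripper command, not rotated
# --- end illustrative snippet ---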
12 | """ 13 | 14 | def __init__(self): 15 | pyspacemouse.open() 16 | 17 | # Manager to handle shared state between processes 18 | self.manager = multiprocessing.Manager() 19 | self.latest_data = self.manager.dict() 20 | self.latest_data["action"] = [0.0] * 6 # Using lists for compatibility 21 | self.latest_data["buttons"] = [0, 0] 22 | 23 | # Start a process to continuously read the SpaceMouse state 24 | self.process = multiprocessing.Process(target=self._read_spacemouse) 25 | self.process.daemon = True 26 | self.process.start() 27 | 28 | def _read_spacemouse(self): 29 | while True: 30 | state = pyspacemouse.read_all() 31 | action = [0.0] * 6 32 | buttons = [0, 0] 33 | # print(len(state)) 34 | if len(state) == 2: 35 | action = [ 36 | -state[0].y, state[0].x, state[0].z, 37 | -state[0].roll, -state[0].pitch, -state[0].yaw, 38 | -state[1].y, state[1].x, state[1].z, 39 | -state[1].roll, -state[1].pitch, -state[1].yaw 40 | ] 41 | buttons = state[0].buttons + state[1].buttons 42 | elif len(state) == 1: 43 | action = [ 44 | -state[0].y, state[0].x, state[0].z, 45 | -state[0].roll, -state[0].pitch, -state[0].yaw 46 | ] 47 | buttons = state[0].buttons 48 | # print(buttons) 49 | 50 | # Update the shared state 51 | self.latest_data["action"] = action 52 | self.latest_data["buttons"] = buttons 53 | 54 | def get_action(self) -> Tuple[np.ndarray, list]: 55 | """Returns the latest action and button state of the SpaceMouse.""" 56 | action = self.latest_data["action"] 57 | buttons = self.latest_data["buttons"] 58 | return np.array(action), buttons 59 | 60 | def close(self): 61 | # pyspacemouse.close() 62 | self.process.terminate() 63 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/spacemouse/spacemouse_test.py: -------------------------------------------------------------------------------- 1 | """ Test the spacemouse output. """ 2 | import time 3 | import numpy as np 4 | from franka_env.spacemouse.spacemouse_expert import SpaceMouseExpert 5 | 6 | 7 | def test_spacemouse(): 8 | """Test the SpaceMouseExpert class. 9 | 10 | This interactive test prints the action and buttons of the spacemouse at a rate of 10Hz. 11 | The user is expected to move the spacemouse and press its buttons while the test is running. 12 | It keeps running until the user stops it. 
13 | 14 | """ 15 | spacemouse0 = SpaceMouseExpert() 16 | with np.printoptions(precision=3, suppress=True): 17 | while True: 18 | action, buttons = spacemouse0.get_action() 19 | print(f"Spacemouse action: {action}, buttons: {buttons}") 20 | time.sleep(0.1) 21 | 22 | 23 | def main(): 24 | """Call spacemouse test.""" 25 | test_spacemouse() 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_robot_infra/franka_env/utils/__init__.py -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/utils/rotations.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.transform import Rotation as R 2 | import numpy as np 3 | from pyquaternion import Quaternion 4 | 5 | 6 | def quat_2_euler(quat): 7 | """calculates and returns: yaw, pitch, roll from given quaternion""" 8 | return R.from_quat(quat).as_euler("xyz") 9 | 10 | 11 | def euler_2_quat(xyz): 12 | yaw, pitch, roll = xyz 13 | yaw = np.pi - yaw 14 | yaw_matrix = np.array( 15 | [ 16 | [np.cos(yaw), -np.sin(yaw), 0.0], 17 | [np.sin(yaw), np.cos(yaw), 0.0], 18 | [0, 0, 1.0], 19 | ] 20 | ) 21 | pitch_matrix = np.array( 22 | [ 23 | [np.cos(pitch), 0.0, np.sin(pitch)], 24 | [0.0, 1.0, 0.0], 25 | [-np.sin(pitch), 0, np.cos(pitch)], 26 | ] 27 | ) 28 | roll_matrix = np.array( 29 | [ 30 | [1.0, 0, 0], 31 | [0, np.cos(roll), -np.sin(roll)], 32 | [0, np.sin(roll), np.cos(roll)], 33 | ] 34 | ) 35 | rot_mat = yaw_matrix.dot(pitch_matrix.dot(roll_matrix)) 36 | return Quaternion(matrix=rot_mat).elements 37 | 38 | def new_euler_2_quat(xyz): 39 | return R.from_euler("xyz", xyz).as_quat() 40 | -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/utils/transform_absolute_actions_and_obs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_robot_infra/franka_env/utils/transform_absolute_actions_and_obs.py -------------------------------------------------------------------------------- /serl_robot_infra/franka_env/utils/transformations.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.transform import Rotation as R 2 | import numpy as np 3 | 4 | 5 | def construct_adjoint_matrix(tcp_pose): 6 | """ 7 | Construct the adjoint matrix for a spatial velocity vector 8 | :args: tcp_pose: (x, y, z, qx, qy, qz, qw) 9 | """ 10 | rotation = R.from_quat(tcp_pose[3:]).as_matrix() 11 | translation = np.array(tcp_pose[:3]) 12 | skew_matrix = np.array( 13 | [ 14 | [0, -translation[2], translation[1]], 15 | [translation[2], 0, -translation[0]], 16 | [-translation[1], translation[0], 0], 17 | ] 18 | ) 19 | adjoint_matrix = np.zeros((6, 6)) 20 | adjoint_matrix[:3, :3] = rotation 21 | adjoint_matrix[3:, 3:] = rotation 22 | adjoint_matrix[3:, :3] = skew_matrix @ rotation 23 | return adjoint_matrix 24 | 25 | 26 | def construct_homogeneous_matrix(tcp_pose): 27 | """ 28 | Construct the homogeneous transformation matrix from given pose. 
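# --- Illustrative snippet: block structure built by construct_adjoint_matrix above ---
# Quick numeric check that Ad = [[R, 0], [[p]x @ R, R]] for tcp_pose = (x, y, z, qx, qy, qz, qw);
# the pose used here is hypothetical (identity rotation, so [p]x @ R reduces to [p]x).
import numpy as np
from franka_env.utils.transformations import construct_adjoint_matrix

pose = np.array([0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0])
adj = construct_adjoint_matrix(pose)

p = pose[:3]
skew = np.array([[0, -p[2], p[1]], [p[2], 0, -p[0]], [-p[1], p[0], 0]])
assert np.allclose(adj[:3, :3], np.eye(3))   # top-left block: R
assert np.allclose(adj[3:, 3:], np.eye(3))   # bottom-right block: R
assert np.allclose(adj[:3, 3:], 0.0)         # top-right block: zeros
assert np.allclose(adj[3:, :3], skew)        # bottom-left block: [p]x @ R
# --- end illustrative snippet ---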
29 | args: tcp_pose: (x, y, z, qx, qy, qz, qw) 30 | """ 31 | rotation = R.from_quat(tcp_pose[3:]).as_matrix() 32 | translation = np.array(tcp_pose[:3]) 33 | T = np.zeros((4, 4)) 34 | T[:3, :3] = rotation 35 | T[:3, 3] = translation 36 | T[3, 3] = 1 37 | return T 38 | 39 | def construct_adjoint_matrix_from_euler(tcp_pose): 40 | """ 41 | Construct the adjoint matrix for a spatial velocity vector 42 | :args: tcp_pose: (x, y, z, qx, qy, qz, qw) 43 | """ 44 | rotation = R.from_euler("xyz", tcp_pose[3:]).as_matrix() 45 | translation = np.array(tcp_pose[:3]) 46 | skew_matrix = np.array( 47 | [ 48 | [0, -translation[2], translation[1]], 49 | [translation[2], 0, -translation[0]], 50 | [-translation[1], translation[0], 0], 51 | ] 52 | ) 53 | adjoint_matrix = np.zeros((6, 6)) 54 | adjoint_matrix[:3, :3] = rotation 55 | adjoint_matrix[3:, 3:] = rotation 56 | adjoint_matrix[3:, :3] = skew_matrix @ rotation 57 | return adjoint_matrix 58 | 59 | 60 | def construct_homogeneous_matrix_from_euler(tcp_pose): 61 | """ 62 | Construct the homogeneous transformation matrix from given pose. 63 | args: tcp_pose: (x, y, z, qx, qy, qz, qw) 64 | """ 65 | rotation = R.from_euler("xyz", tcp_pose[3:]).as_matrix() 66 | translation = np.array(tcp_pose[:3]) 67 | T = np.zeros((4, 4)) 68 | T[:3, :3] = rotation 69 | T[:3, 3] = translation 70 | T[3, 3] = 1 71 | return T 72 | -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cccedric/conrft/3bacb6f73538ff02524e276690bf320fa6b6d323/serl_robot_infra/robot_servers/__init__.py -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/franka_gripper_server.py: -------------------------------------------------------------------------------- 1 | import rospy 2 | from franka_gripper.msg import GraspActionGoal, MoveActionGoal 3 | from sensor_msgs.msg import JointState 4 | import numpy as np 5 | 6 | from robot_servers.gripper_server import GripperServer 7 | 8 | 9 | class FrankaGripperServer(GripperServer): 10 | def __init__(self): 11 | super().__init__() 12 | self.grippermovepub = rospy.Publisher( 13 | "/franka_gripper/move/goal", MoveActionGoal, queue_size=1 14 | ) 15 | self.grippergrasppub = rospy.Publisher( 16 | "/franka_gripper/grasp/goal", GraspActionGoal, queue_size=1 17 | ) 18 | self.gripper_sub = rospy.Subscriber( 19 | "/franka_gripper/joint_states", JointState, self._update_gripper 20 | ) 21 | self.binary_gripper_pose = 0 22 | 23 | def open(self): 24 | if self.binary_gripper_pose == 0: 25 | return 26 | msg = MoveActionGoal() 27 | # msg.goal.width = 0.025 28 | msg.goal.width = 0.09 29 | msg.goal.speed = 0.3 30 | self.grippermovepub.publish(msg) 31 | self.binary_gripper_pose = 0 32 | 33 | def close(self): 34 | if self.binary_gripper_pose == 1: 35 | return 36 | msg = GraspActionGoal() 37 | msg.goal.width = 0.01 38 | msg.goal.speed = 0.3 39 | msg.goal.epsilon.inner = 1 40 | msg.goal.epsilon.outer = 1 41 | msg.goal.force = 1 42 | self.grippergrasppub.publish(msg) 43 | self.binary_gripper_pose = 1 44 | 45 | def close_slow(self): 46 | if self.binary_gripper_pose == 1: 47 | return 48 | msg = GraspActionGoal() 49 | msg.goal.width = 0.01 50 | msg.goal.speed = 0.1 51 | msg.goal.epsilon.inner = 1 52 | msg.goal.epsilon.outer = 1 53 | msg.goal.force = 1 54 | self.grippergrasppub.publish(msg) 55 | self.binary_gripper_pose = 1 56 | 57 | def move(self, 
position: int): 58 | """Move the gripper to a specific position in range [0, 255]""" 59 | msg = MoveActionGoal() 60 | msg.goal.width = float(position / (255 * 10)) # width in [0, 0.1]m 61 | msg.goal.speed = 0.3 62 | self.grippermovepub.publish(msg) 63 | 64 | def _update_gripper(self, msg): 65 | """internal callback to get the latest gripper position.""" 66 | self.gripper_pos = np.sum(msg.position) / 0.08 67 | -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/gripper_server.py: -------------------------------------------------------------------------------- 1 | class GripperServer: 2 | def __init__(self): 3 | self.gripper_pos = 0 4 | 5 | def open(self): 6 | pass 7 | 8 | def close(self): 9 | pass 10 | 11 | def move(self, position: int): 12 | pass 13 | 14 | def activate_gripper(self): 15 | pass 16 | 17 | def reset_gripper(self): 18 | pass 19 | -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/launch_left_server.sh: -------------------------------------------------------------------------------- 1 | # Source the setup.bash file for the first ROS workspace 2 | source ~/code/catkin_ws_FER/devel/setup.bash 3 | 4 | # Set ROS master URI to localhost 5 | export ROS_MASTER_URI=http://localhost:11311 6 | 7 | # Run the first instance of franka_server.py in the background 8 | python franka_server.py \ 9 | --robot_ip=173.16.0.2 \ 10 | --gripper_type=Franka \ 11 | --reset_joint_target=0,0,0,-1.9,-0,2,0 \ 12 | --flask_url=127.0.0.1 \ 13 | --ros_port=11311 -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/launch_right_eggflip_server.sh: -------------------------------------------------------------------------------- 1 | # Source the setup.bash file for the second ROS workspace 2 | source /home/undergrad/code/catkin_ws/devel/setup.bash 3 | 4 | # Change the ROS master URI to a different port 5 | export ROS_MASTER_URI=http://localhost:11511 6 | 7 | # Run the second instance of franka_server.py in the background 8 | python franka_eggflip_server.py \ 9 | --robot_ip=172.16.0.2 \ 10 | --gripper_type=None \ 11 | --flask_url=127.0.0.2 \ 12 | --ros_port=11511 -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/launch_right_server.sh: -------------------------------------------------------------------------------- 1 | # Source the setup.bash file for the second ROS workspace 2 | source /root/online_rl/catkin_ws/devel/setup.bash 3 | 4 | # Change the ROS master URI to a different port 5 | export ROS_MASTER_URI=http://localhost:11511 6 | 7 | # Run the second instance of franka_server.py in the background 8 | python franka_server.py \ 9 | --robot_ip=192.168.1.221 \ 10 | --gripper_type=Franka \ 11 | --flask_url=127.0.0.2 \ 12 | --ros_port=11511 -------------------------------------------------------------------------------- /serl_robot_infra/robot_servers/robotiq_gripper_server.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import rospy 3 | from robotiq_2f_gripper_control.msg import _Robotiq2FGripper_robot_output as outputMsg 4 | from robotiq_2f_gripper_control.msg import _Robotiq2FGripper_robot_input as inputMsg 5 | 6 | from robot_servers.gripper_server import GripperServer 7 | 8 | 9 | class RobotiqGripperServer(GripperServer): 10 | def __init__(self, gripper_ip): 11 | super().__init__() 12 | self.gripper = 
subprocess.Popen( 13 | [ 14 | "rosrun", 15 | "robotiq_2f_gripper_control", 16 | "Robotiq2FGripperTcpNode.py", 17 | gripper_ip, 18 | ], 19 | stdout=subprocess.PIPE, 20 | ) 21 | self.gripper_state_sub = rospy.Subscriber( 22 | "Robotiq2FGripperRobotInput", 23 | inputMsg.Robotiq2FGripper_robot_input, 24 | self._update_gripper, 25 | queue_size=1, 26 | ) 27 | self.gripperpub = rospy.Publisher( 28 | "Robotiq2FGripperRobotOutput", 29 | outputMsg.Robotiq2FGripper_robot_output, 30 | queue_size=1, 31 | ) 32 | self.gripper_command = outputMsg.Robotiq2FGripper_robot_output() 33 | 34 | def activate_gripper(self): 35 | self.gripper_command = self._generate_gripper_command("a", self.gripper_command) 36 | self.gripperpub.publish(self.gripper_command) 37 | 38 | def reset_gripper(self): 39 | self.gripper_command = self._generate_gripper_command("r", self.gripper_command) 40 | self.gripperpub.publish(self.gripper_command) 41 | self.activate_gripper() 42 | 43 | def open(self): 44 | self.gripper_command = self._generate_gripper_command("o", self.gripper_command) 45 | self.gripperpub.publish(self.gripper_command) 46 | 47 | def close(self): 48 | self.gripper_command = self._generate_gripper_command("c", self.gripper_command) 49 | self.gripperpub.publish(self.gripper_command) 50 | 51 | def move(self, position): 52 | self.gripper_command = self._generate_gripper_command(position, self.gripper_command) 53 | self.gripperpub.publish(self.gripper_command) 54 | 55 | def close_slow(self): 56 | self.gripper_command = self._generate_gripper_command("cs", self.gripper_command) 57 | self.gripperpub.publish(self.gripper_command) 58 | 59 | def _update_gripper(self, msg): 60 | """internal callback to get the latest gripper position.""" 61 | self.gripper_pos = 1 - msg.gPO / 255 62 | 63 | def _generate_gripper_command(self, char, command): 64 | """Update the gripper command according to the character entered by the user.""" 65 | if char == "a": 66 | command = outputMsg.Robotiq2FGripper_robot_output() 67 | command.rACT = 1 68 | command.rGTO = 1 69 | command.rSP = 255 70 | command.rFR = 30 71 | 72 | elif char == "r": 73 | command = outputMsg.Robotiq2FGripper_robot_output() 74 | command.rACT = 0 75 | command.rSP = 255 76 | 77 | elif char == "c": 78 | command.rPR = 255 79 | command.rSP = 255 80 | 81 | elif char == "cs": 82 | command.rPR = 255 83 | command.rSP = 50 84 | 85 | elif char == "o": 86 | command.rPR = 175 87 | command.rSP = 255 88 | 89 | # If the command entered is a int, assign this value to rPR 90 | # (i.e., move to this position) 91 | try: 92 | command.rPR = int(char) 93 | if command.rPR > 255: 94 | command.rPR = 255 95 | if command.rPR < 0: 96 | command.rPR = 0 97 | except ValueError: 98 | pass 99 | return command 100 | -------------------------------------------------------------------------------- /serl_robot_infra/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="serl_robot_infra", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "gymnasium", 9 | "pyrealsense2", 10 | "pymodbus==2.5.3", 11 | "opencv-python", 12 | "pyquaternion", 13 | "pyspacemouse", 14 | "hidapi", 15 | "pyyaml", 16 | "rospkg", 17 | "scipy", 18 | "requests", 19 | "flask", 20 | "defusedxml", 21 | ], 22 | ) 23 | --------------------------------------------------------------------------------
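The command characters handled by RobotiqGripperServer._generate_gripper_command map onto Robotiq register fields as sketched below. This is an illustrative, pure-Python restatement that runs without rospy; FakeRobotiqOutput is a hypothetical stand-in for Robotiq2FGripper_robot_output and is not part of the repository.

# --- Illustrative snippet: Robotiq command mapping, runnable without ROS ---
from dataclasses import dataclass

@dataclass
class FakeRobotiqOutput:  # hypothetical stand-in for Robotiq2FGripper_robot_output
    rACT: int = 0  # activation
    rGTO: int = 0  # go-to-position enable
    rSP: int = 0   # speed
    rFR: int = 0   # force
    rPR: int = 0   # requested position (0 = open, 255 = closed)

def generate_gripper_command(char, cmd):
    """Mirror of RobotiqGripperServer._generate_gripper_command, for illustration only."""
    if char == "a":                       # activate
        cmd = FakeRobotiqOutput(rACT=1, rGTO=1, rSP=255, rFR=30)
    elif char == "r":                     # reset
        cmd = FakeRobotiqOutput(rACT=0, rSP=255)
    elif char == "c":                     # close at full speed
        cmd.rPR, cmd.rSP = 255, 255
    elif char == "cs":                    # close slowly
        cmd.rPR, cmd.rSP = 255, 50
    elif char == "o":                     # open (to position 175)
        cmd.rPR, cmd.rSP = 175, 255
    try:                                  # integer input: move to that position, clamped to [0, 255]
        cmd.rPR = min(max(int(char), 0), 255)
    except (ValueError, TypeError):
        pass
    return cmd

cmd = generate_gripper_command("a", FakeRobotiqOutput())  # activate
cmd = generate_gripper_command(128, cmd)                  # then move to a mid position
# --- end illustrative snippet ---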