├── .gitignore
├── LICENSE
├── README.md
├── configs
│   └── base.py
├── requirements.txt
├── scripts
│   ├── policy_ckpt_to_hlo.py
│   ├── robot
│   │   ├── eval_diffusion.py
│   │   ├── eval_gcbc.py
│   │   └── eval_lcbc.py
│   └── train.py
├── setup.py
└── susie
    ├── data
    │   ├── __init__.py
    │   ├── datasets.py
    │   └── goal_relabeling.py
    ├── jax_utils.py
    ├── model.py
    ├── sampling.py
    └── scheduling.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/vim,visualstudiocode,python,macos,linux
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,visualstudiocode,python,macos,linux
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### macOS ###
20 | # General
21 | .DS_Store
22 | .AppleDouble
23 | .LSOverride
24 |
25 | # Icon must end with two \r
26 | Icon
27 |
28 |
29 | # Thumbnails
30 | ._*
31 |
32 | # Files that might appear in the root of a volume
33 | .DocumentRevisions-V100
34 | .fseventsd
35 | .Spotlight-V100
36 | .TemporaryItems
37 | .Trashes
38 | .VolumeIcon.icns
39 | .com.apple.timemachine.donotpresent
40 |
41 | # Directories potentially created on remote AFP share
42 | .AppleDB
43 | .AppleDesktop
44 | Network Trash Folder
45 | Temporary Items
46 | .apdisk
47 |
48 | ### macOS Patch ###
49 | # iCloud generated files
50 | *.icloud
51 |
52 | ### Python ###
53 | # Byte-compiled / optimized / DLL files
54 | __pycache__/
55 | *.py[cod]
56 | *$py.class
57 |
58 | # C extensions
59 | *.so
60 |
61 | # Distribution / packaging
62 | .Python
63 | build/
64 | develop-eggs/
65 | dist/
66 | downloads/
67 | eggs/
68 | .eggs/
69 | lib/
70 | lib64/
71 | parts/
72 | sdist/
73 | var/
74 | wheels/
75 | share/python-wheels/
76 | *.egg-info/
77 | .installed.cfg
78 | *.egg
79 | MANIFEST
80 |
81 | # PyInstaller
82 | # Usually these files are written by a python script from a template
83 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
84 | *.manifest
85 | *.spec
86 |
87 | # Installer logs
88 | pip-log.txt
89 | pip-delete-this-directory.txt
90 |
91 | # Unit test / coverage reports
92 | htmlcov/
93 | .tox/
94 | .nox/
95 | .coverage
96 | .coverage.*
97 | .cache
98 | nosetests.xml
99 | coverage.xml
100 | *.cover
101 | *.py,cover
102 | .hypothesis/
103 | .pytest_cache/
104 | cover/
105 |
106 | # Translations
107 | *.mo
108 | *.pot
109 |
110 | # Django stuff:
111 | *.log
112 | local_settings.py
113 | db.sqlite3
114 | db.sqlite3-journal
115 |
116 | # Flask stuff:
117 | instance/
118 | .webassets-cache
119 |
120 | # Scrapy stuff:
121 | .scrapy
122 |
123 | # Sphinx documentation
124 | docs/_build/
125 |
126 | # PyBuilder
127 | .pybuilder/
128 | target/
129 |
130 | # Jupyter Notebook
131 | .ipynb_checkpoints
132 |
133 | # IPython
134 | profile_default/
135 | ipython_config.py
136 |
137 | # pyenv
138 | # For a library or package, you might want to ignore these files since the code is
139 | # intended to run in multiple environments; otherwise, check them in:
140 | # .python-version
141 |
142 | # pipenv
143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 146 | # install all needed dependencies. 147 | #Pipfile.lock 148 | 149 | # poetry 150 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 151 | # This is especially recommended for binary packages to ensure reproducibility, and is more 152 | # commonly ignored for libraries. 153 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 154 | #poetry.lock 155 | 156 | # pdm 157 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 158 | #pdm.lock 159 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 160 | # in version control. 161 | # https://pdm.fming.dev/#use-with-ide 162 | .pdm.toml 163 | 164 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 165 | __pypackages__/ 166 | 167 | # Celery stuff 168 | celerybeat-schedule 169 | celerybeat.pid 170 | 171 | # SageMath parsed files 172 | *.sage.py 173 | 174 | # Environments 175 | .env 176 | .venv 177 | env/ 178 | venv/ 179 | ENV/ 180 | env.bak/ 181 | venv.bak/ 182 | 183 | # Spyder project settings 184 | .spyderproject 185 | .spyproject 186 | 187 | # Rope project settings 188 | .ropeproject 189 | 190 | # mkdocs documentation 191 | /site 192 | 193 | # mypy 194 | .mypy_cache/ 195 | .dmypy.json 196 | dmypy.json 197 | 198 | # Pyre type checker 199 | .pyre/ 200 | 201 | # pytype static type analyzer 202 | .pytype/ 203 | 204 | # Cython debug symbols 205 | cython_debug/ 206 | 207 | # PyCharm 208 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 209 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 210 | # and can be added to the global gitignore or merged into this file. For a more nuclear 211 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
212 | #.idea/ 213 | 214 | ### Python Patch ### 215 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 216 | poetry.toml 217 | 218 | # ruff 219 | .ruff_cache/ 220 | 221 | # LSP config files 222 | pyrightconfig.json 223 | 224 | ### Vim ### 225 | # Swap 226 | [._]*.s[a-v][a-z] 227 | !*.svg # comment out if you don't need vector files 228 | [._]*.sw[a-p] 229 | [._]s[a-rt-v][a-z] 230 | [._]ss[a-gi-z] 231 | [._]sw[a-p] 232 | 233 | # Session 234 | Session.vim 235 | Sessionx.vim 236 | 237 | # Temporary 238 | .netrwhist 239 | # Auto-generated tag files 240 | tags 241 | # Persistent undo 242 | [._]*.un~ 243 | 244 | ### VisualStudioCode ### 245 | .vscode/* 246 | !.vscode/settings.json 247 | !.vscode/tasks.json 248 | !.vscode/launch.json 249 | !.vscode/extensions.json 250 | !.vscode/*.code-snippets 251 | 252 | # Local History for Visual Studio Code 253 | .history/ 254 | 255 | # Built Visual Studio Code Extensions 256 | *.vsix 257 | 258 | ### VisualStudioCode Patch ### 259 | # Ignore all local history of files 260 | .history 261 | .ionide 262 | 263 | # End of https://www.toptal.com/developers/gitignore/api/vim,visualstudiocode,python,macos,linux -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Kevin Black 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # susie 2 | Code for the paper [Zero-Shot Robotic Manipulation With Pretrained Image-Editing Diffusion Models](https://rail-berkeley.github.io/susie/). 3 | 4 | This repository contains the code for training the high-level image-editing diffusion model on video data. For training the low-level policy, head over to the [BridgeData V2](https://github.com/rail-berkeley/bridge_data_v2) repository --- we use the `gc_ddpm_bc` agent, unmodified, with an action prediction horizon of 4 and the `delta_goals` relabeling strategy. 
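For reference, those two policy settings would look roughly like the following in the BridgeData V2 training config. This is a sketch: `act_pred_horizon` is the key this repo reads back from wandb in `scripts/policy_ckpt_to_hlo.py`, but the exact name of the relabeling key is an assumption and should be checked against that repo's config files.

```python
# hypothetical excerpt of a bridge_data_v2 dataset config for the gc_ddpm_bc agent
dataset_kwargs = dict(
    goal_relabeling_strategy="delta_goals",  # relabel goals a fixed delta ahead (assumed key name)
    act_pred_horizon=4,  # predict chunks of 4 actions
)
```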
5 | 6 | For integration with the CALVIN simulator and reproducing our simulated results, see [our fork of the calvin-sim repo](https://github.com/pranavatreya/calvin-sim) and the [corresponding documentation in the BridgeData V2 repository](https://github.com/rail-berkeley/bridge_data_v2/tree/main/experiments/susie/calvin). 7 | 8 | - **Creating datasets**: this repo uses [dlimp](https://github.com/kvablack/dlimp) for dataloading. Check out the `scripts/` directory inside dlimp for creating TFRecords in a compatible format. 9 | - **Installation**: `pip install -r requirements.txt` to install the versions of required packages confirmed to be working with this codebase. Then, `pip install -e .`. Only tested with Python 3.10. You'll also have to manually install Jax for your platform (see the [Jax installation instructions](https://jax.readthedocs.io/en/latest/installation.html)). Make sure you have the Jax version specified in `requirements.txt` (rather than using `--upgrade` as suggested in the Jax docs). 10 | - **Training**: once the missing dataset paths have been filled in inside `base.py`, you can start training by running `python scripts/train.py --config configs/base.py:base`. 11 | - **Evaluation**: robot evaluation scripts are provided in the `scripts/robot` directory. You probably won't be able to run them, since you don't have our robot setup, but they are there for reference. See `create_sample_fn` in `susie/model.py` for canonical sampling code. 12 | 13 | ## Model Weights 14 | The UNet weights for our best-performing model, trained on BridgeData and Something-Something for 40k steps, are hosted [on HuggingFace](https://huggingface.co/kvablack/susie). They can be loaded using `FlaxUNet2DConditionModel.from_pretrained("kvablack/susie", subfolder="unet")`. Use with the standard Stable Diffusion v1-5 VAE and text encoder. 
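If you want to assemble the components yourself rather than going through `create_sample_fn`, loading might look like the following sketch. The VAE/text encoder locations mirror the `runwayml/stable-diffusion-v1-5:flax` references in `configs/base.py` (where `:flax` names the HuggingFace revision); treat the exact revision handling as an assumption.

```python
# a minimal sketch of loading the released UNet plus the SD v1-5 Flax VAE,
# tokenizer, and text encoder; diffusers' Flax from_pretrained returns
# (module, params) tuples
from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from transformers import CLIPTokenizer, FlaxCLIPTextModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "kvablack/susie", subfolder="unet"
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", revision="flax"
)
tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", revision="flax"
)
```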
15 | 16 | Here's a quickstart for getting out-of-the-box subgoals using this repo: 17 | ```python 18 | from susie.model import create_sample_fn 19 | from susie.jax_utils import initialize_compilation_cache 20 | import requests 21 | import numpy as np 22 | from PIL import Image 23 | 24 | initialize_compilation_cache() 25 | 26 | IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg" 27 | 28 | sample_fn = create_sample_fn("kvablack/susie") 29 | image = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256))) 30 | image_out = sample_fn(image, "open the drawer") 31 | 32 | # to display the images if you're in a Jupyter notebook 33 | display(Image.fromarray(image)) 34 | display(Image.fromarray(image_out)) 35 | ``` 36 | -------------------------------------------------------------------------------- /configs/base.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from ml_collections import ConfigDict 4 | 5 | 6 | def base(): 7 | config = ConfigDict() 8 | 9 | # top-level stuff 10 | config.seed = 42 11 | config.wandb_project = "susie" 12 | config.run_name = "" 13 | config.logdir = "logs" 14 | config.num_steps = 40000 15 | config.log_interval = 100 16 | config.save_interval = 5000 17 | config.val_interval = 2500 18 | config.sample_interval = 2500 19 | config.num_val_batches = 128 20 | config.goal_drop_rate = 1.0 21 | config.curr_drop_rate = 0.0 22 | config.prompt_drop_rate = 0.0 23 | config.mesh = [-1, 1] # dp, fsdp 24 | 25 | config.wandb_resume_id = None 26 | 27 | config.vae = "runwayml/stable-diffusion-v1-5:flax" 28 | config.text_encoder = "runwayml/stable-diffusion-v1-5:flax" 29 | 30 | # ema 31 | config.ema = ema = ConfigDict() 32 | ema.max_decay = 0.999 33 | ema.min_decay = 0.999 34 | ema.update_every = 1 35 | ema.start_step = 0 36 | ema.inv_gamma = 1.0 37 | ema.power = 3 / 4 38 | 39 | # optim 40 | config.optim = optim = ConfigDict() 41 | optim.optimizer = "adamw" 42 | optim.lr = 1e-4 43 | optim.warmup_steps = 800 # linear warmup steps 44 | optim.decay_steps = 1e9 # cosine decay total steps (reaches 0 at this number) 45 | optim.weight_decay = ( 46 | 1e-2 # adamw weight decay -- pytorch default (which instructpix2pix and SD use) 47 | ) 48 | optim.beta1 = 0.9 49 | optim.beta2 = 0.999 50 | optim.epsilon = 1e-8 51 | optim.max_grad_norm = 1.0 52 | optim.accumulation_steps = 1 53 | 54 | # scheduling 55 | config.scheduling = scheduling = ConfigDict() 56 | scheduling.noise_schedule = "scaled_linear" 57 | 58 | # sampling 59 | config.sample = sample = ConfigDict() 60 | sample.num_contexts = 8 61 | sample.num_samples_per_context = 8 62 | sample.num_steps = 50 63 | sample.context_w = 2.5 64 | sample.prompt_w = 7.5 65 | sample.eta = 0.0 66 | 67 | # data 68 | config.data = ConfigDict() 69 | config.data.batch_size = 128 70 | 71 | data_base = ConfigDict() 72 | data_base.image_size = 256 73 | data_base.shuffle_buffer_size = 100000 74 | data_base.augment_kwargs = dict( 75 | random_resized_crop=dict(scale=[0.85, 1.0], ratio=[0.95, 1.05]), 76 | random_brightness=[0.05], 77 | random_contrast=[0.95, 1.05], 78 | random_saturation=[0.95, 1.05], 79 | random_hue=[0.025], 80 | augment_order=[ 81 | # "random_flip_left_right", 82 | # "random_resized_crop", 83 | # "random_brightness", 84 | # "random_contrast", 85 | # "random_saturation", 86 | # "random_hue", 87 | # "random_flip_up_down", 88 | # 
"random_rot90", 89 | ], 90 | ) 91 | 92 | # config.data.ego4d = ego4d = deepcopy(data_base) 93 | # ego4d.weight = 70.0 94 | # ego4d.data_path = "" 95 | # ego4d.goal_relabeling_fn = "subgoal_only" 96 | # ego4d.goal_relabeling_kwargs = dict( 97 | # subgoal_delta=(30, 60), 98 | # truncate=True, 99 | # ) 100 | 101 | config.data.bridge = bridge = deepcopy(data_base) 102 | bridge.weight = 45.0 103 | bridge.data_path = "" 104 | bridge.goal_relabeling_fn = "subgoal_only" 105 | bridge.goal_relabeling_kwargs = dict( 106 | subgoal_delta=(11, 14), 107 | truncate=False, 108 | ) 109 | 110 | # config.data.calvin = calvin = deepcopy(data_base) 111 | # calvin.weight = 15.0 112 | # calvin.data_path = "" 113 | # calvin.goal_relabeling_fn = "subgoal_only" 114 | # calvin.goal_relabeling_kwargs = dict( 115 | # subgoal_delta=(20, 21), 116 | # truncate=False, 117 | # ) 118 | 119 | config.data.somethingsomething = somethingsomething = deepcopy(data_base) 120 | somethingsomething.weight = 75.0 121 | somethingsomething.data_path = "" 122 | somethingsomething.goal_relabeling_fn = "subgoal_only" 123 | somethingsomething.goal_relabeling_kwargs = dict( 124 | subgoal_delta=(11, 14), 125 | truncate=False, 126 | ) 127 | 128 | # model 129 | config.model = model = ConfigDict() 130 | config.model.pretrained = "kvablack/instruct-pix2pix-flax" 131 | 132 | return config 133 | 134 | 135 | def debug(): 136 | config = base() 137 | config.logdir = "logs" 138 | config.log_interval = 150 139 | config.save_interval = 150 140 | config.val_interval = 150 141 | config.sample_interval = 10 142 | config.num_val_batches = 4 143 | 144 | config.vae = "runwayml/stable-diffusion-v1-5:flax" 145 | config.text_encoder = "runwayml/stable-diffusion-v1-5:flax" 146 | 147 | config.sample.num_contexts = 4 148 | config.sample.num_samples_per_context = 4 149 | config.sample.num_steps = 20 150 | config.sample.w = 1.0 151 | 152 | config.data.batch_size = 16 153 | for data in [d for d in config.data.values() if isinstance(d, ConfigDict)]: 154 | data.shuffle_buffer_size = 100 155 | # data.image_size = 32 156 | 157 | config.model.pretrained = None 158 | config.model.block_out_channels = (32, 32) 159 | config.model.down_block_types = ( 160 | "DownBlock2D", 161 | "DownBlock2D", 162 | # "DownBlock2D", 163 | # "CrossAttnDownBlock2D", 164 | # "DownBlock2D", 165 | ) 166 | config.model.up_block_types = ( 167 | "UpBlock2D", 168 | "UpBlock2D", 169 | # "CrossAttnUpBlock2D", 170 | # "UpBlock2D", 171 | # "UpBlock2D", 172 | ) 173 | config.model.layers_per_block = 1 174 | config.model.attention_head_dim = 1 175 | 176 | return config 177 | 178 | 179 | def get_config(name): 180 | return globals()[name]() 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | distrax==0.1.2 2 | flax==0.7.0 3 | transformers==4.33.1 4 | dlimp @ git+https://github.com/kvablack/dlimp@166a1b271b19434c3a04b7cfacce0aa0f27bf413 5 | einops==0.6.1 6 | wandb==0.15.5 7 | tensorflow==2.13.0 8 | ml-collections==0.1.0 9 | jax==0.4.11 10 | optax==0.1.5 11 | diffusers==0.18.2 12 | ml-dtypes==0.2.0 13 | -------------------------------------------------------------------------------- /scripts/policy_ckpt_to_hlo.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import jax 4 | import numpy as np 5 | import orbax.checkpoint 6 | import tensorflow as tf 7 | from absl import app, flags 8 | from jaxrl_m.agents import agents 9 
| from jaxrl_m.vision import encoders 10 | 11 | import wandb 12 | from susie.jax_utils import serialize_jax_fn 13 | 14 | FLAGS = flags.FLAGS 15 | 16 | flags.DEFINE_string("checkpoint_path", None, "Path to checkpoint file", required=True) 17 | 18 | flags.DEFINE_string( 19 | "wandb_run_name", None, "Name of wandb run to get config from.", required=True 20 | ) 21 | 22 | flags.DEFINE_string( 23 | "outpath", None, "Path to save serialized policy to.", required=True 24 | ) 25 | 26 | flags.DEFINE_integer( 27 | "im_size", 256, "Image size, which was unfortunately not saved in config" 28 | ) 29 | 30 | 31 | def load_policy_checkpoint(path, wandb_run_name): 32 | assert tf.io.gfile.exists(path) 33 | 34 | # load information from wandb 35 | api = wandb.Api() 36 | run = api.run(wandb_run_name) 37 | config = run.config 38 | 39 | # create encoder from wandb config 40 | encoder_def = encoders[config["encoder"]](**config["encoder_kwargs"]) 41 | 42 | act_pred_horizon = run.config["dataset_kwargs"].get("act_pred_horizon") 43 | obs_horizon = run.config.get("obs_horizon") or run.config["dataset_kwargs"].get( 44 | "obs_horizon" 45 | ) 46 | 47 | if act_pred_horizon is not None: 48 | example_actions = np.zeros((1, act_pred_horizon, 7), dtype=np.float32) 49 | else: 50 | example_actions = np.zeros((1, 7), dtype=np.float32) 51 | 52 | if obs_horizon is not None: 53 | example_obs = { 54 | "image": np.zeros( 55 | (1, obs_horizon, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8 56 | ) 57 | } 58 | else: 59 | example_obs = { 60 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 61 | } 62 | 63 | example_goal = { 64 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 65 | } 66 | 67 | example_batch = { 68 | "observations": example_obs, 69 | "actions": example_actions, 70 | "goals": example_goal, 71 | } 72 | 73 | # create agent from wandb config 74 | agent = agents[config["agent"]].create( 75 | rng=jax.random.PRNGKey(0), 76 | encoder_def=encoder_def, 77 | observations=example_batch["observations"], 78 | goals=example_batch["goals"], 79 | actions=example_batch["actions"], 80 | **config["agent_kwargs"], 81 | ) 82 | 83 | # load action metadata from wandb 84 | # action_metadata = config["bridgedata_config"]["action_metadata"] 85 | # action_mean = np.array(action_metadata["mean"]) 86 | # action_std = np.array(action_metadata["std"]) 87 | 88 | # load action metadata from wandb 89 | action_proprio_metadata = run.config["bridgedata_config"]["action_proprio_metadata"] 90 | action_mean = np.array(action_proprio_metadata["action"]["mean"]) 91 | action_std = np.array(action_proprio_metadata["action"]["std"]) 92 | 93 | # hydrate agent with parameters from checkpoint 94 | agent = orbax.checkpoint.PyTreeCheckpointer().restore( 95 | path, 96 | item=agent, 97 | ) 98 | 99 | def get_action(rng, obs_image, goal_image): 100 | obs = {"image": obs_image} 101 | goal_obs = {"image": goal_image} 102 | # some agents (e.g. 
DDPM) don't have argmax 103 | if inspect.signature(agent.sample_actions).parameters.get("argmax"): 104 | action = agent.sample_actions(obs, goal_obs, seed=rng, argmax=True) 105 | else: 106 | action = agent.sample_actions(obs, goal_obs, seed=rng) 107 | action = action * action_std + action_mean 108 | return action 109 | 110 | serialized = serialize_jax_fn( 111 | get_action, 112 | jax.random.PRNGKey(0), 113 | example_obs["image"][0], 114 | example_goal["image"][0], 115 | ) 116 | 117 | return serialized 118 | 119 | 120 | def main(_): 121 | serialized = load_policy_checkpoint(FLAGS.checkpoint_path, FLAGS.wandb_run_name) 122 | 123 | with open(FLAGS.outpath, "wb") as f: 124 | f.write(serialized) 125 | 126 | 127 | if __name__ == "__main__": 128 | app.run(main) 129 | -------------------------------------------------------------------------------- /scripts/robot/eval_diffusion.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from collections import deque 4 | 5 | from susie.model import create_sample_fn 6 | 7 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 8 | import time 9 | from typing import Callable, List, Tuple 10 | 11 | import imageio 12 | import jax 13 | import numpy as np 14 | from absl import app, flags 15 | from pyquaternion import Quaternion 16 | 17 | # bridge_data_robot imports 18 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs 19 | 20 | from susie.jax_utils import ( 21 | deserialize_jax_fn, 22 | initialize_compilation_cache, 23 | ) 24 | 25 | ############################################################################## 26 | 27 | STEP_DURATION = 0.2 28 | NO_PITCH_ROLL = False 29 | NO_YAW = False 30 | STICKY_GRIPPER_NUM_STEPS = 2 31 | ENV_PARAMS = { 32 | "camera_topics": [{"name": "/blue/image_raw", "flip": False}], 33 | # forward, left, up 34 | # wallpaper 35 | # "override_workspace_boundaries": [ 36 | # [0.1, -0.15, 0.0, -1.57, 0], 37 | # [0.60, 0.25, 0.18, 1.57, 0], 38 | # ], 39 | # toysink2 40 | # "override_workspace_boundaries": [ 41 | # [0.21, -0.13, 0.06, -1.57, 0], 42 | # [0.36, 0.25, 0.18, 1.57, 0], 43 | # ], 44 | # microwave 45 | "override_workspace_boundaries": [ 46 | [0.1, -0.15, 0.05, -1.57, 0], 47 | [0.35, 0.25, 0.23, 1.57, 0], 48 | ], 49 | "move_duration": STEP_DURATION, 50 | } 51 | 52 | FIXED_STD = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 53 | 54 | 55 | ############################################################################## 56 | 57 | np.set_printoptions(suppress=True) 58 | 59 | FLAGS = flags.FLAGS 60 | 61 | flags.DEFINE_string( 62 | "policy_checkpoint", None, "Path to policy checkpoint", required=True 63 | ) 64 | flags.DEFINE_string( 65 | "diffusion_checkpoint", None, "Path to diffusion checkpoint", required=True 66 | ) 67 | flags.DEFINE_string( 68 | "diffusion_wandb", 69 | None, 70 | "Name of wandb run to get diffusion config from.", 71 | required=True, 72 | ) 73 | flags.DEFINE_integer("diffusion_num_steps", 50, "Number of diffusion steps") 74 | flags.DEFINE_string( 75 | "diffusion_pretrained_path", 76 | None, 77 | "Path to pretrained model to get text encoder + VAE from.", 78 | required=True, 79 | ) 80 | 81 | flags.DEFINE_float("prompt_w", 1.0, "CFG weight to use for diffusion sampler") 82 | flags.DEFINE_float("context_w", 1.0, "CFG weight to use for diffusion sampler") 83 | 84 | flags.DEFINE_string("video_save_path", None, "Path to save video") 85 | 86 | flags.DEFINE_integer("num_timesteps", 40, "num timesteps") 87 | flags.DEFINE_bool("blocking", True, 
"Use the blocking controller") 88 | 89 | flags.DEFINE_spaceseplist("initial_eep", None, "Initial position", required=True) 90 | 91 | flags.DEFINE_string("ip", "localhost", "IP address of the robot") 92 | flags.DEFINE_integer("port", 5556, "Port of the robot") 93 | 94 | 95 | def state_to_eep(xyz_coor, zangle: float): 96 | """ 97 | Implement the state to eep function. 98 | Refered to `bridge_data_robot`'s `widowx_controller/widowx_controller.py` 99 | return a 4x4 matrix 100 | """ 101 | assert len(xyz_coor) == 3 102 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]]) 103 | new_pose = np.eye(4) 104 | new_pose[:3, -1] = xyz_coor 105 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion( 106 | matrix=DEFAULT_ROTATION 107 | ) 108 | new_pose[:3, :3] = new_quat.rotation_matrix 109 | # yaw, pitch, roll = quat.yaw_pitch_roll 110 | return new_pose 111 | 112 | 113 | def rollout_subgoal( 114 | widowx_client: WidowXClient, 115 | get_action: Callable[[np.ndarray, np.ndarray], np.ndarray], 116 | goal_obs: np.ndarray, 117 | num_timesteps: int, 118 | obs_horizon: int, 119 | is_gripper_closed: bool = False, 120 | ) -> Tuple[List[np.ndarray], List[np.ndarray], bool]: 121 | num_consecutive_gripper_change_actions = 0 122 | 123 | last_tstep = time.time() 124 | images = [] 125 | full_images = [] 126 | t = 0 127 | actions = None 128 | rng = jax.random.PRNGKey(int(time.time())) 129 | if obs_horizon is not None: 130 | obs_hist = deque(maxlen=obs_horizon) 131 | try: 132 | while t < num_timesteps: 133 | if time.time() > last_tstep + STEP_DURATION or FLAGS.blocking: 134 | obs = widowx_client.get_observation() 135 | if obs is None: 136 | print("WARNING retrying to get observation...") 137 | continue 138 | 139 | obs = ( 140 | obs["image"] 141 | .reshape(3, goal_obs.shape[0], goal_obs.shape[1]) 142 | .transpose(1, 2, 0) 143 | * 255 144 | ).astype(np.uint8) 145 | images.append(obs) 146 | 147 | # deal with obs history 148 | if obs_horizon is not None: 149 | if len(obs_hist) == 0: 150 | obs_hist.extend([obs] * obs_horizon) 151 | else: 152 | obs_hist.append(obs) 153 | obs = np.stack(obs_hist) 154 | 155 | last_tstep = time.time() 156 | 157 | # deal with mutli-action prediction 158 | rng, key = jax.random.split(rng) 159 | pred_actions = jax.device_get(get_action(key, obs, goal_obs)) 160 | if len(pred_actions.shape) == 1: 161 | pred_actions = pred_actions[None] 162 | if actions is None: 163 | actions = np.zeros_like(pred_actions) 164 | weights = 1 / (np.arange(len(pred_actions)) + 1) 165 | else: 166 | actions = np.concatenate([actions[1:], np.zeros_like(actions[-1:])]) 167 | weights = np.concatenate([weights[1:], [1 / len(weights)]]) 168 | actions += pred_actions * weights[:, None] 169 | 170 | action = actions[0] 171 | 172 | # sticky gripper logic 173 | if (action[-1] < 0.5) != is_gripper_closed: 174 | num_consecutive_gripper_change_actions += 1 175 | else: 176 | num_consecutive_gripper_change_actions = 0 177 | 178 | if num_consecutive_gripper_change_actions >= STICKY_GRIPPER_NUM_STEPS: 179 | is_gripper_closed = not is_gripper_closed 180 | num_consecutive_gripper_change_actions = 0 181 | 182 | action[-1] = 0.0 if is_gripper_closed else 1.0 183 | 184 | # remove degrees of freedom 185 | if NO_PITCH_ROLL: 186 | action[3] = 0 187 | action[4] = 0 188 | if NO_YAW: 189 | action[5] = 0 190 | 191 | action_norm = np.linalg.norm(action[:3]) 192 | 193 | print( 194 | f"Timestep {t}, action norm: {action_norm * 100:.1f}cm, gripper state: {action[-1]}" 195 | ) 196 | 
widowx_client.step_action(action, blocking=FLAGS.blocking) 197 | 198 | t += 1 199 | except KeyboardInterrupt: 200 | return images, full_images, is_gripper_closed, True 201 | return images, full_images, is_gripper_closed, False 202 | 203 | 204 | def main(_): 205 | initialize_compilation_cache() 206 | get_action = deserialize_jax_fn(FLAGS.policy_checkpoint) 207 | 208 | obs_horizon = get_action.args_info[0][1].aval.shape[0] 209 | im_size = get_action.args_info[0][1].aval.shape[1] 210 | 211 | diffusion_sample = create_sample_fn( 212 | FLAGS.diffusion_checkpoint, 213 | FLAGS.diffusion_wandb, 214 | FLAGS.diffusion_num_steps, 215 | FLAGS.prompt_w, 216 | FLAGS.context_w, 217 | 0.0, 218 | FLAGS.diffusion_pretrained_path, 219 | ) 220 | 221 | print(f"obs horizon: {obs_horizon}, im size: {im_size}") 222 | 223 | # init environment 224 | env_params = WidowXConfigs.DefaultEnvParams.copy() 225 | env_params.update(ENV_PARAMS) 226 | widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port) 227 | 228 | # goal sampling loop 229 | prompt = None 230 | is_gripper_closed = False # track gripper state between subgoals 231 | while True: 232 | # ask for new goal 233 | if prompt is None or input("New prompt? [y/n]") == "y": 234 | prompt = input("Enter prompt: ") 235 | 236 | widowx_client.init(env_params) 237 | 238 | assert isinstance(FLAGS.initial_eep, list) 239 | initial_eep = [float(e) for e in FLAGS.initial_eep] 240 | widowx_client.move_gripper(1.0) # open gripper 241 | widowx_client.move_gripper(1.0) # open gripper 242 | 243 | print(f"Moving to position {initial_eep}") 244 | print(widowx_client.move(state_to_eep(initial_eep, 0), blocking=True)) 245 | time.sleep(2.0) 246 | print(widowx_client.move(state_to_eep(initial_eep, 0), blocking=True)) 247 | 248 | input("Press [Enter] to start.") 249 | 250 | # take image 251 | obs = widowx_client.get_observation() 252 | while obs is None: 253 | print("WARNING retrying to get observation...") 254 | obs = widowx_client.get_observation() 255 | time.sleep(1) 256 | 257 | image_obs = ( 258 | obs["image"].reshape(3, im_size, im_size).transpose(1, 2, 0) * 255 259 | ).astype(np.uint8) 260 | 261 | images = [] 262 | goals = [] 263 | full_images = [] 264 | done = False 265 | n = 0 266 | while not done: 267 | # sample goal 268 | print(f"Sampling goal {n}...") 269 | imageio.imwrite("start.png", image_obs) 270 | print( 271 | image_obs.shape, image_obs.dtype, np.max(image_obs), np.min(image_obs) 272 | ) 273 | print(f"'{prompt}'") 274 | image_goal = diffusion_sample(image_obs, prompt) 275 | imageio.imwrite("goal.png", image_goal) 276 | 277 | # do rollout 278 | ( 279 | rollout_images, 280 | rollout_full_images, 281 | is_gripper_closed, 282 | done, 283 | ) = rollout_subgoal( 284 | widowx_client, 285 | get_action, 286 | image_goal, 287 | FLAGS.num_timesteps, 288 | obs_horizon, 289 | is_gripper_closed, 290 | ) 291 | images.extend(rollout_images) 292 | full_images.extend(rollout_full_images) 293 | goals.extend([image_goal] * len(rollout_images)) 294 | 295 | image_obs = rollout_images[-1] 296 | 297 | n += 1 298 | 299 | if FLAGS.video_save_path is not None: 300 | save_path = os.path.join( 301 | FLAGS.video_save_path, 302 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S.mp4"), 303 | ) 304 | video = np.concatenate([goals, images], axis=1) 305 | imageio.mimsave( 306 | save_path, 307 | video, 308 | fps=1.0 / STEP_DURATION * 3, 309 | ) 310 | with open(save_path.replace(".mp4", "_prompt.txt"), "w") as f: 311 | f.write(prompt) 312 | 313 | 314 | if __name__ == "__main__": 315 | app.run(main) 316 | 
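# Example invocation (a sketch; paths, the wandb run name, and the position are
# placeholders -- the policy checkpoint is the serialized function produced by
# scripts/policy_ckpt_to_hlo.py, and the CFG weights mirror configs/base.py):
#
#   python scripts/robot/eval_diffusion.py \
#     --policy_checkpoint /path/to/serialized_policy \
#     --diffusion_checkpoint kvablack/susie \
#     --diffusion_wandb <entity>/<project>/<run_id> \
#     --diffusion_pretrained_path runwayml/stable-diffusion-v1-5:flax \
#     --prompt_w 7.5 \
#     --context_w 2.5 \
#     --initial_eep "0.3 0.0 0.15" \
#     --video_save_path /tmp/videos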
-------------------------------------------------------------------------------- /scripts/robot/eval_gcbc.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | from collections import deque 5 | 6 | import imageio 7 | import jax 8 | import numpy as np 9 | from absl import app, flags 10 | from pyquaternion import Quaternion 11 | 12 | # bridge_data_robot imports 13 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs 14 | 15 | import wandb 16 | from susie.jax_utils import ( 17 | deserialize_jax_fn, 18 | initialize_compilation_cache, 19 | ) 20 | 21 | ############################################################################## 22 | 23 | STEP_DURATION = 0.2 24 | NO_PITCH_ROLL = False 25 | NO_YAW = False 26 | STICKY_GRIPPER_NUM_STEPS = 2 27 | ENV_PARAMS = { 28 | "camera_topics": [{"name": "/blue/image_raw", "flip": False}], 29 | "return_full_image": False, 30 | # forward, left, up 31 | # wallpaper 32 | # "override_workspace_boundaries": [ 33 | # [0.1, -0.15, 0.0, -1.57, 0], 34 | # [0.60, 0.25, 0.18, 1.57, 0], 35 | # ], 36 | # toysink2 37 | # "override_workspace_boundaries": [ 38 | # [0.21, -0.13, 0.06, -1.57, 0], 39 | # [0.36, 0.25, 0.18, 1.57, 0], 40 | # ], 41 | # microwave 42 | "override_workspace_boundaries": [ 43 | [0.1, -0.15, 0.05, -1.57, 0], 44 | [0.31, 0.25, 0.23, 1.57, 0], 45 | ], 46 | } 47 | 48 | FIXED_STD = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 49 | 50 | 51 | ############################################################################## 52 | 53 | np.set_printoptions(suppress=True) 54 | 55 | FLAGS = flags.FLAGS 56 | 57 | flags.DEFINE_string( 58 | "policy_checkpoint", None, "Path to policy checkpoint", required=True 59 | ) 60 | flags.DEFINE_string("video_save_path", None, "Path to save video") 61 | 62 | flags.DEFINE_integer("num_timesteps", 120, "num timesteps") 63 | flags.DEFINE_bool("blocking", True, "Use the blocking controller") 64 | 65 | flags.DEFINE_spaceseplist("goal_eep", None, "Goal position") 66 | flags.DEFINE_spaceseplist("initial_eep", None, "Initial position") 67 | 68 | flags.DEFINE_string("ip", "localhost", "IP address of the robot") 69 | flags.DEFINE_integer("port", 5556, "Port of the robot") 70 | 71 | 72 | def state_to_eep(xyz_coor, zangle: float): 73 | """ 74 | Implement the state to eep function. 
75 |     Adapted from `bridge_data_robot`'s `widowx_controller/widowx_controller.py`.
76 |     Returns a 4x4 homogeneous transform matrix.
77 |     """
78 |     assert len(xyz_coor) == 3
79 |     DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]])
80 |     new_pose = np.eye(4)
81 |     new_pose[:3, -1] = xyz_coor
82 |     new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion(
83 |         matrix=DEFAULT_ROTATION
84 |     )
85 |     new_pose[:3, :3] = new_quat.rotation_matrix
86 |     # yaw, pitch, roll = quat.yaw_pitch_roll
87 |     return new_pose
88 |
89 |
90 | def rollout_subgoal(
91 |     widowx_client, get_action, goal_obs, num_timesteps, obs_horizon, im_size
92 | ):
93 |     is_gripper_closed = False
94 |     num_consecutive_gripper_change_actions = 0
95 |
96 |     last_tstep = time.time()
97 |     images = []
98 |     full_images = []
99 |     t = 0
100 |     actions = None
101 |     rng = jax.random.PRNGKey(int(time.time()))
102 |     if obs_horizon is not None:
103 |         obs_hist = deque(maxlen=obs_horizon)
104 |     try:
105 |         while t < num_timesteps:
106 |             if time.time() > last_tstep + STEP_DURATION or FLAGS.blocking:
107 |                 obs = widowx_client.get_observation()
108 |                 if obs is None:
109 |                     print("WARNING retrying to get observation...")
110 |                     continue
111 |
112 |                 full_images.append(obs["full_image"])
113 |
114 |                 obs = (
115 |                     obs["image"].reshape(3, im_size, im_size).transpose(1, 2, 0) * 255
116 |                 ).astype(np.uint8)
117 |                 images.append(obs)
118 |
119 |                 # deal with obs history
120 |                 if obs_horizon is not None:
121 |                     if len(obs_hist) == 0:
122 |                         obs_hist.extend([obs] * obs_horizon)
123 |                     else:
124 |                         obs_hist.append(obs)
125 |                     obs = np.stack(obs_hist)
126 |
127 |                 last_tstep = time.time()
128 |
129 |                 # deal with multi-action prediction (average overlapping action chunks)
130 |                 rng, key = jax.random.split(rng)
131 |                 pred_actions = jax.device_get(get_action(key, obs, goal_obs))
132 |                 if len(pred_actions.shape) == 1:
133 |                     pred_actions = pred_actions[None]
134 |                 if actions is None:
135 |                     actions = np.zeros_like(pred_actions)
136 |                     weights = 1 / (np.arange(len(pred_actions)) + 1)
137 |                 else:
138 |                     actions = np.concatenate([actions[1:], np.zeros_like(actions[-1:])])
139 |                     weights = np.concatenate([weights[1:], [1 / len(weights)]])
140 |                 actions += pred_actions * weights[:, None]
141 |
142 |                 action = actions[0]
143 |
144 |                 # sticky gripper logic
145 |                 if (action[-1] < 0.5) != is_gripper_closed:
146 |                     num_consecutive_gripper_change_actions += 1
147 |                 else:
148 |                     num_consecutive_gripper_change_actions = 0
149 |
150 |                 if num_consecutive_gripper_change_actions >= STICKY_GRIPPER_NUM_STEPS:
151 |                     is_gripper_closed = not is_gripper_closed
152 |                     num_consecutive_gripper_change_actions = 0
153 |
154 |                 action[-1] = 0.0 if is_gripper_closed else 1.0
155 |
156 |                 # remove degrees of freedom
157 |                 if NO_PITCH_ROLL:
158 |                     action[3] = 0
159 |                     action[4] = 0
160 |                 if NO_YAW:
161 |                     action[5] = 0
162 |
163 |                 action_norm = np.linalg.norm(action[:3])
164 |
165 |                 print(
166 |                     f"Timestep {t}, action norm: {action_norm * 100:.1f}cm, gripper state: {action[-1]}"
167 |                 )
168 |                 widowx_client.step_action(action, blocking=FLAGS.blocking)
169 |
170 |                 t += 1
171 |     except KeyboardInterrupt:
172 |         return images, full_images, True
173 |     return images, full_images, False
174 |
175 |
176 | def main(_):
177 |     initialize_compilation_cache()
178 |     get_action = deserialize_jax_fn(FLAGS.policy_checkpoint)
179 |
180 |     obs_horizon = get_action.args_info[0][1].aval.shape[0]
181 |     im_size = get_action.args_info[0][1].aval.shape[1]
182 |
183 |     print(f"obs horizon: {obs_horizon}, im size: {im_size}")
184 |
185 |     # init environment
186 |     env_params =
WidowXConfigs.DefaultEnvParams.copy() 187 | env_params.update(ENV_PARAMS) 188 | widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port) 189 | 190 | # goal sampling loop 191 | image_goal = None 192 | while True: 193 | widowx_client.init(env_params) 194 | 195 | # ask for new goal 196 | if image_goal is None: 197 | print("Taking a new goal...") 198 | ch = "y" 199 | else: 200 | ch = input("Taking a new goal? [y/n]") 201 | if ch == "y": 202 | if FLAGS.goal_eep is not None: 203 | assert isinstance(FLAGS.goal_eep, list) 204 | goal_eep = [float(e) for e in FLAGS.goal_eep] 205 | else: 206 | # pick random goal eep 207 | low_bound = [0.24, -0.1, 0.05, -1.57, 0] 208 | high_bound = [0.4, 0.20, 0.15, 1.57, 0] 209 | goal_eep = np.random.uniform(low_bound[:3], high_bound[:3]) 210 | widowx_client.move_gripper(1.0) # open gripper 211 | widowx_client.move_gripper(1.0) # open gripper 212 | 213 | print(f"Moving to goal position {goal_eep}") 214 | widowx_client.move(state_to_eep(goal_eep, 0), blocking=True) 215 | input("Press [Enter] when ready for taking the goal image. ") 216 | 217 | # take goal image 218 | obs = widowx_client.get_observation() 219 | while obs is None: 220 | print("WARNING retrying to get observation...") 221 | obs = widowx_client.get_observation() 222 | time.sleep(1) 223 | 224 | image_goal = ( 225 | obs["image"].reshape(3, im_size, im_size).transpose(1, 2, 0) * 255 226 | ).astype(np.uint8) 227 | 228 | # move to initial position 229 | if FLAGS.initial_eep is not None: 230 | assert isinstance(FLAGS.initial_eep, list) 231 | initial_eep = [float(e) for e in FLAGS.initial_eep] 232 | print(f"Moving to initial position {initial_eep}") 233 | widowx_client.move_gripper(1.0) # open gripper 234 | widowx_client.move_gripper(1.0) # open gripper 235 | widowx_client.move(state_to_eep(initial_eep, 0), blocking=True) 236 | 237 | input("Press [Enter] to start.") 238 | 239 | # do rollout 240 | images, full_images, done = rollout_subgoal( 241 | widowx_client, 242 | get_action, 243 | image_goal, 244 | FLAGS.num_timesteps, 245 | obs_horizon, 246 | im_size, 247 | ) 248 | 249 | if FLAGS.video_save_path is not None: 250 | save_path = os.path.join( 251 | FLAGS.video_save_path, 252 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S.mp4"), 253 | ) 254 | images = np.array(images) 255 | video = np.concatenate( 256 | [np.broadcast_to(image_goal[None], images.shape), images], axis=1 257 | ) 258 | imageio.mimsave( 259 | save_path, 260 | video, 261 | fps=3.0 / STEP_DURATION, 262 | ) 263 | 264 | 265 | if __name__ == "__main__": 266 | app.run(main) 267 | -------------------------------------------------------------------------------- /scripts/robot/eval_lcbc.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | from functools import partial 5 | 6 | import cv2 7 | import imageio 8 | import jax 9 | import numpy as np 10 | import orbax.checkpoint 11 | from absl import app, flags 12 | from jaxrl_m.agents import agents 13 | from jaxrl_m.data.text_processing import text_processors 14 | from jaxrl_m.vision import encoders 15 | from pyquaternion import Quaternion 16 | 17 | # bridge_data_robot imports 18 | from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs 19 | 20 | import wandb 21 | 22 | ############################################################################## 23 | 24 | STEP_DURATION = 0.2 25 | NO_PITCH_ROLL = False 26 | NO_YAW = False 27 | STICKY_GRIPPER_NUM_STEPS = 2 28 | ENV_PARAMS = { 29 | "camera_topics": 
[{"name": "/blue/image_raw", "flip": False}], 30 | "return_full_image": False, 31 | # toysink2 32 | "override_workspace_boundaries": [ 33 | [0.21, -0.13, 0.06, -1.57, 0], 34 | [0.37, 0.25, 0.18, 1.57, 0], 35 | ], 36 | } 37 | 38 | FIXED_STD = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) 39 | 40 | 41 | ############################################################################## 42 | 43 | np.set_printoptions(suppress=True) 44 | 45 | FLAGS = flags.FLAGS 46 | 47 | flags.DEFINE_string( 48 | "policy_checkpoint", None, "Path to policy checkpoint", required=True 49 | ) 50 | flags.DEFINE_string( 51 | "policy_wandb", None, "Policy checkpoint wandb run name", required=True 52 | ) 53 | 54 | flags.DEFINE_integer("im_size", None, "Image size", required=True) 55 | 56 | flags.DEFINE_string("video_save_path", None, "Path to save video") 57 | 58 | flags.DEFINE_integer("num_timesteps", 120, "num timesteps") 59 | flags.DEFINE_bool("blocking", True, "Use the blocking controller") 60 | 61 | flags.DEFINE_spaceseplist("goal_eep", None, "Goal position") 62 | flags.DEFINE_spaceseplist("initial_eep", None, "Initial position") 63 | 64 | flags.DEFINE_string("ip", "localhost", "IP address of the robot") 65 | flags.DEFINE_integer("port", 5556, "Port of the robot") 66 | 67 | 68 | def state_to_eep(xyz_coor, zangle: float): 69 | """ 70 | Implement the state to eep function. 71 | Refered to `bridge_data_robot`'s `widowx_controller/widowx_controller.py` 72 | return a 4x4 matrix 73 | """ 74 | assert len(xyz_coor) == 3 75 | DEFAULT_ROTATION = np.array([[0, 0, 1.0], [0, 1.0, 0], [-1.0, 0, 0]]) 76 | new_pose = np.eye(4) 77 | new_pose[:3, -1] = xyz_coor 78 | new_quat = Quaternion(axis=np.array([0.0, 0.0, 1.0]), angle=zangle) * Quaternion( 79 | matrix=DEFAULT_ROTATION 80 | ) 81 | new_pose[:3, :3] = new_quat.rotation_matrix 82 | # yaw, pitch, roll = quat.yaw_pitch_roll 83 | return new_pose 84 | 85 | 86 | def load_policy_checkpoint(path, wandb_run_name): 87 | # load information from wandb 88 | api = wandb.Api() 89 | run = api.run(wandb_run_name) 90 | config = run.config 91 | 92 | # create encoder from wandb config 93 | encoder_def = encoders[config["encoder"]](**config["encoder_kwargs"]) 94 | 95 | example_actions = np.zeros((1, 7), dtype=np.float32) 96 | 97 | example_obs = { 98 | "image": np.zeros((1, FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8) 99 | } 100 | 101 | example_batch = { 102 | "observations": example_obs, 103 | "goals": { 104 | "language": np.zeros( 105 | ( 106 | 1, 107 | 512, 108 | ), 109 | dtype=np.float32, 110 | ), 111 | }, 112 | "actions": example_actions, 113 | } 114 | 115 | # create agent from wandb config 116 | agent = jax.eval_shape( 117 | partial( 118 | agents[config["agent"]].create, 119 | rng=jax.random.PRNGKey(0), 120 | encoder_def=encoder_def, 121 | observations=example_batch["observations"], 122 | goals=example_batch["goals"], 123 | actions=example_batch["actions"], 124 | **config["agent_kwargs"], 125 | ), 126 | ) 127 | 128 | # load action metadata from wandb 129 | action_metadata = config["bridgedata_config"]["action_proprio_metadata"]["action"] 130 | action_mean = np.array(action_metadata["mean"]) 131 | action_std = np.array(action_metadata["std"]) 132 | 133 | # load text processor 134 | text_processor = text_processors[config["text_processor"]]( 135 | **config["text_processor_kwargs"] 136 | ) 137 | 138 | # hydrate agent with parameters from checkpoint 139 | agent = orbax.checkpoint.PyTreeCheckpointer().restore( 140 | path, 141 | item=agent, 142 | ) 143 | 144 | rng = jax.random.PRNGKey(0) 145 | 146 
| def get_action(obs, goal_obs): 147 | nonlocal rng 148 | if "128" in path: 149 | obs["image"] = cv2.resize(obs["image"], (128, 128)) 150 | goal_obs["image"] = cv2.resize(goal_obs["image"], (128, 128)) 151 | rng, key = jax.random.split(rng) 152 | action = jax.device_get( 153 | agent.sample_actions(obs, goal_obs, seed=key, argmax=True) 154 | ) 155 | action = action * action_std + action_mean 156 | action += np.random.normal(0, FIXED_STD) 157 | return action 158 | 159 | return get_action, text_processor 160 | 161 | 162 | def rollout_subgoal(widowx_client, get_action, prompt_embed, num_timesteps): 163 | goal_obs = { 164 | "language": prompt_embed, 165 | } 166 | 167 | is_gripper_closed = False 168 | num_consecutive_gripper_change_actions = 0 169 | 170 | last_tstep = time.time() 171 | images = [] 172 | full_images = [] 173 | t = 0 174 | try: 175 | while t < num_timesteps: 176 | if time.time() > last_tstep + STEP_DURATION or FLAGS.blocking: 177 | obs = widowx_client.get_observation() 178 | if obs is None: 179 | print("WARNING retrying to get observation...") 180 | continue 181 | 182 | image_obs = ( 183 | obs["image"] 184 | .reshape(3, FLAGS.im_size, FLAGS.im_size) 185 | .transpose(1, 2, 0) 186 | * 255 187 | ).astype(np.uint8) 188 | images.append(image_obs) 189 | obs = {"image": image_obs, "proprio": obs["state"]} 190 | 191 | last_tstep = time.time() 192 | 193 | action = get_action(obs, goal_obs)[0] 194 | 195 | # sticky gripper logic 196 | if (action[-1] < 0.5) != is_gripper_closed: 197 | num_consecutive_gripper_change_actions += 1 198 | else: 199 | num_consecutive_gripper_change_actions = 0 200 | 201 | if num_consecutive_gripper_change_actions >= STICKY_GRIPPER_NUM_STEPS: 202 | is_gripper_closed = not is_gripper_closed 203 | num_consecutive_gripper_change_actions = 0 204 | 205 | action[-1] = 0.0 if is_gripper_closed else 1.0 206 | 207 | ### Preprocess action ### 208 | if NO_PITCH_ROLL: 209 | action[3] = 0 210 | action[4] = 0 211 | if NO_YAW: 212 | action[5] = 0 213 | 214 | print( 215 | f"Timestep {t}, action norm: {np.linalg.norm(action[:3] * 100):.1f}cm, gripper state: {action[-1]}" 216 | ) 217 | widowx_client.step_action(action) 218 | 219 | t += 1 220 | except KeyboardInterrupt: 221 | return images, full_images, True 222 | return images, full_images, False 223 | 224 | 225 | def main(_): 226 | get_action, text_processor = load_policy_checkpoint( 227 | FLAGS.policy_checkpoint, FLAGS.policy_wandb 228 | ) 229 | 230 | # init environment 231 | env_params = WidowXConfigs.DefaultEnvParams.copy() 232 | env_params.update(ENV_PARAMS) 233 | widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port) 234 | widowx_client.init(env_params) 235 | 236 | # wait for server to initialize 237 | print("Waiting for environment to start...") 238 | while widowx_client.get_observation() is None: 239 | time.sleep(0.5) 240 | 241 | # goal sampling loop 242 | prompt_embed = None 243 | done = False 244 | while True: 245 | # ask for new goal 246 | if prompt_embed is None: 247 | ch = "y" 248 | else: 249 | ch = input("New instruction? 
[y/n]") 250 | if ch == "y": 251 | prompt = input("Enter Prompt: ") 252 | prompt_embed = text_processor.encode(prompt) 253 | 254 | # move to initial position 255 | if FLAGS.initial_eep is not None: 256 | assert isinstance(FLAGS.initial_eep, list) 257 | initial_eep = [float(e) for e in FLAGS.initial_eep] 258 | print(f"Moving to initial position {initial_eep}") 259 | widowx_client.move_gripper(1.0) # open gripper 260 | widowx_client.move_gripper(1.0) # open gripper 261 | widowx_client.move(state_to_eep(initial_eep, 0)) 262 | 263 | input("Press [Enter] to start.") 264 | 265 | # do rollout 266 | images, full_images, done = rollout_subgoal( 267 | widowx_client, get_action, prompt_embed, FLAGS.num_timesteps 268 | ) 269 | 270 | if FLAGS.video_save_path is not None: 271 | save_path = os.path.join( 272 | FLAGS.video_save_path, 273 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S.mp4"), 274 | ) 275 | video = np.array(images) 276 | imageio.mimsave( 277 | save_path, 278 | video, 279 | fps=3.0 / STEP_DURATION, 280 | ) 281 | with open(save_path.replace(".mp4", "_prompt.txt"), "w") as f: 282 | f.write(prompt) 283 | 284 | 285 | if __name__ == "__main__": 286 | app.run(main) 287 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import functools 3 | import logging 4 | import os 5 | import tempfile 6 | import time 7 | from collections import defaultdict 8 | from functools import partial 9 | 10 | import einops as eo 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import orbax.checkpoint 16 | import tensorflow as tf 17 | 18 | import tqdm 19 | from absl import app, flags 20 | from flax.training import orbax_utils 21 | from jax.experimental import multihost_utils 22 | from jax.lax import with_sharding_constraint as wsc 23 | from jax.sharding import NamedSharding 24 | from jax.sharding import PartitionSpec as P 25 | from ml_collections import ConfigDict, config_flags 26 | from PIL import Image 27 | 28 | import wandb 29 | from susie import sampling, scheduling 30 | from susie.data.datasets import get_data_loader 31 | from susie.jax_utils import ( 32 | host_broadcast_str, 33 | initialize_compilation_cache, 34 | ) 35 | from susie.model import ( 36 | EmaTrainState, 37 | create_model_def, 38 | load_pretrained_unet, 39 | load_text_encoder, 40 | load_vae, 41 | ) 42 | 43 | if jax.process_count() > 1: 44 | jax.distributed.initialize() 45 | 46 | tqdm = partial(tqdm.tqdm, dynamic_ncols=True) 47 | 48 | try: 49 | from jax_smi import initialise_tracking # type: ignore 50 | 51 | initialise_tracking() 52 | except ImportError: 53 | pass 54 | 55 | 56 | def fsdp_sharding(mesh: jax.sharding.Mesh, array: jax.ShapeDtypeStruct): 57 | if array.ndim < 2: 58 | # replicate scalar and vector arrays 59 | return NamedSharding(mesh, P()) 60 | 61 | # shard matrices and larger tensors across the fsdp dimension. the conv kernels are a little tricky because they 62 | # vary in which axis is a power of 2, so I'll just search for the first one that works. 
63 | l = [] 64 | for n in array.shape: 65 | if n % mesh.shape["fsdp"] == 0: 66 | l.append("fsdp") 67 | return NamedSharding(mesh, P(*l)) 68 | l.append(None) 69 | 70 | logging.warning( 71 | f"Could not find a valid sharding for array of shape {array.shape} with mesh of shape {mesh.shape}" 72 | ) 73 | return NamedSharding(mesh, P()) 74 | 75 | 76 | def train_step( 77 | rng, 78 | state, 79 | batch, 80 | # static args 81 | log_snr_fn, 82 | uncond_prompt_embed, 83 | text_encode_fn, 84 | vae_encode_fn, 85 | curr_drop_rate=0.0, 86 | goal_drop_rate=0.0, 87 | prompt_drop_rate=0.0, 88 | eval_only=False, 89 | ): 90 | batch_size = batch["subgoals"].shape[0] 91 | 92 | # encode stuff 93 | for key in {"curr", "goals", "subgoals"}.intersection(batch.keys()): 94 | # VERY IMPORTANT: for some godforsaken reason, the context latents are 95 | # NOT scaled in InstructPix2Pix 96 | scale = key == "subgoals" 97 | rng, encode_rng = jax.random.split(rng) 98 | batch[key] = vae_encode_fn(encode_rng, batch[key], scale=scale) 99 | prompt_embeds = text_encode_fn(batch["prompt_ids"]) 100 | 101 | if goal_drop_rate == 1.0: 102 | batch["goals"] = jnp.zeros( 103 | batch["subgoals"].shape[:-1] + (0,), batch["subgoals"].dtype 104 | ) 105 | elif goal_drop_rate > 0: 106 | rng, mask_rng = jax.random.split(rng) 107 | batch["goals"] = jnp.where( 108 | jax.random.uniform(mask_rng, shape=(batch_size, 1, 1, 1)) < goal_drop_rate, 109 | 0, 110 | batch["goals"], 111 | ) 112 | 113 | if curr_drop_rate > 0: 114 | rng, mask_rng = jax.random.split(rng) 115 | batch["curr"] = jnp.where( 116 | jax.random.uniform(mask_rng, shape=(batch_size, 1, 1, 1)) < curr_drop_rate, 117 | 0, 118 | batch["curr"], 119 | ) 120 | 121 | if prompt_drop_rate > 0: 122 | rng, mask_rng = jax.random.split(rng) 123 | prompt_embeds = jnp.where( 124 | jax.random.uniform(mask_rng, shape=(batch_size, 1, 1)) < prompt_drop_rate, 125 | uncond_prompt_embed, 126 | prompt_embeds, 127 | ) 128 | 129 | x = batch["subgoals"] # the generation target 130 | y = jnp.concatenate( 131 | [batch["curr"], batch["goals"]], axis=-1 132 | ) # the conditioning image(s) 133 | 134 | # sample batch of timesteps from t ~ U[0, num_train_timesteps) 135 | rng, t_rng = jax.random.split(rng) 136 | t = jax.random.uniform(t_rng, shape=(batch_size,), dtype=jnp.float32) 137 | 138 | # sample noise (epsilon) from N(0, I) 139 | rng, noise_rng = jax.random.split(rng) 140 | noise = jax.random.normal(noise_rng, x.shape) 141 | 142 | log_snr = log_snr_fn(t) 143 | 144 | # generate the noised image from q(x_t | x_0, y) 145 | x_t = sampling.q_sample(x, log_snr, noise) 146 | 147 | input = jnp.concatenate([x_t, y], axis=-1) 148 | 149 | # seems like remat is actually enabled by default -- this disables it 150 | # @partial(jax.checkpoint, policy=jax.checkpoint_policies.everything_saveable) 151 | def loss_fn(params, rng): 152 | pred = state.apply_fn( 153 | {"params": params}, 154 | input, 155 | t * 1000, 156 | prompt_embeds, 157 | train=not eval_only, 158 | rngs={"dropout": rng}, 159 | ) 160 | assert pred.shape == noise.shape 161 | loss = (pred - noise) ** 2 162 | return jnp.mean(loss) 163 | 164 | info = {} 165 | if not eval_only: 166 | grad_fn = jax.value_and_grad(loss_fn) 167 | rng, dropout_rng = jax.random.split(rng) 168 | info["loss"], grads = grad_fn(state.params, dropout_rng) 169 | info["grad_norm"] = optax.global_norm(grads) 170 | 171 | new_state = state.apply_gradients(grads=grads) 172 | else: 173 | rng, dropout_rng = jax.random.split(rng) 174 | info["loss"] = loss_fn(state.params, dropout_rng) 175 | rng, dropout_rng = 
jax.random.split(rng) 176 | info["loss_ema"] = loss_fn(state.params_ema, dropout_rng) 177 | new_state = state 178 | 179 | return new_state, info 180 | 181 | 182 | FLAGS = flags.FLAGS 183 | 184 | config_flags.DEFINE_config_file( 185 | "config", 186 | None, 187 | "File path to the hyperparameter configuration.", 188 | lock_config=False, 189 | ) 190 | 191 | 192 | def main(_): 193 | config = FLAGS.config 194 | 195 | assert config.sample.num_contexts % 4 == 0 196 | 197 | # prevent tensorflow from using GPUs 198 | tf.config.experimental.set_visible_devices([], "GPU") 199 | tf.random.set_seed(config.seed + jax.process_index()) 200 | 201 | # get jax devices 202 | logging.info(f"JAX process: {jax.process_index()} of {jax.process_count()}") 203 | logging.info(f"Local devices: {jax.local_device_count()}") 204 | logging.info(f"Global devices: {jax.device_count()}") 205 | 206 | mesh = jax.sharding.Mesh( 207 | # create_device_mesh([32, 1]), # can't make contiguous meshes for the v4-64 pod for some reason 208 | np.array(jax.devices()).reshape(*config.mesh), 209 | axis_names=["dp", "fsdp"], 210 | ) 211 | replicated_sharding = NamedSharding(mesh, P()) 212 | # data gets sharded over both dp and fsdp logical axes 213 | data_sharding = NamedSharding(mesh, P(["dp", "fsdp"])) 214 | 215 | # initial rng 216 | rng = jax.random.PRNGKey(config.seed + jax.process_index()) 217 | 218 | # set up wandb run 219 | if config.wandb_resume_id is not None: 220 | run = wandb.Api().run(config.wandb_resume_id) 221 | old_num_steps = config.num_steps 222 | config = ConfigDict(run.config) 223 | config.num_steps = old_num_steps 224 | config.wandb_resume_id = run.id 225 | logdir = tf.io.gfile.join(config.logdir, run.name) 226 | 227 | if jax.process_index() == 0: 228 | wandb.init( 229 | project=run.project, 230 | id=run.id, 231 | resume="must", 232 | ) 233 | else: 234 | unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S") 235 | unique_id = host_broadcast_str(unique_id) 236 | 237 | if not config.run_name: 238 | config.run_name = unique_id 239 | else: 240 | config.run_name += "_" + unique_id 241 | 242 | logdir = tf.io.gfile.join(config.logdir, config.run_name) 243 | 244 | if jax.process_index() == 0: 245 | tf.io.gfile.makedirs(logdir) 246 | wandb.init( 247 | project=config.wandb_project, 248 | name=config.run_name, 249 | config=config.to_dict(), 250 | ) 251 | 252 | checkpointer = orbax.checkpoint.CheckpointManager( 253 | logdir, 254 | checkpointers={ 255 | "state": orbax.checkpoint.PyTreeCheckpointer(), 256 | "params_ema": orbax.checkpoint.PyTreeCheckpointer(), 257 | }, 258 | ) 259 | 260 | log_snr_fn = scheduling.create_log_snr_fn(config.scheduling) 261 | ema_decay_fn = scheduling.create_ema_decay_fn(config.ema) 262 | 263 | # load vae 264 | if config.vae is not None: 265 | vae_encode, vae_decode = load_vae(config.vae) 266 | 267 | # load text encoder 268 | tokenize, untokenize, text_encode = load_text_encoder(config.text_encoder) 269 | uncond_prompt_embed = jax.device_get(text_encode(tokenize([""]))) # (1, 77, 768) 270 | 271 | def tokenize_fn(batch): 272 | lang = [s.decode("utf-8") for s in batch.pop("lang")] 273 | assert all(s != "" for s in lang) 274 | batch["prompt_ids"] = tokenize(lang) 275 | return batch 276 | 277 | # load pretrained model 278 | if config.model.get("pretrained", None): 279 | pretrained_model_def, pretrained_params = load_pretrained_unet( 280 | config.model.pretrained, in_channels=12 if config.goal_drop_rate < 1 else 8 281 | ) 282 | pretrained_config = ConfigDict(pretrained_model_def.config) 283 | del 
config.model.pretrained 284 | if config.model.keys(): 285 | logging.warning(f"Overriding pretrained config keys: {config.model.keys()}") 286 | pretrained_config.update(config.model) 287 | config.model = pretrained_config 288 | else: 289 | pretrained_params = None 290 | 291 | # create model def 292 | config.model.out_channels = 4 if config.vae else 3 293 | model_def = create_model_def(config.model) 294 | 295 | # create optimizer 296 | learning_rate_fn = optax.warmup_cosine_decay_schedule( 297 | init_value=0.0, 298 | peak_value=config.optim.lr, 299 | warmup_steps=config.optim.warmup_steps, 300 | decay_steps=config.optim.decay_steps, 301 | end_value=0.0, 302 | ) 303 | tx = optax.adamw( 304 | learning_rate=learning_rate_fn, 305 | b1=config.optim.beta1, 306 | b2=config.optim.beta2, 307 | eps=config.optim.epsilon, 308 | weight_decay=config.optim.weight_decay, 309 | mu_dtype=jnp.bfloat16, 310 | ) 311 | tx = optax.chain( 312 | optax.clip_by_global_norm(config.optim.max_grad_norm), 313 | tx, 314 | ) 315 | 316 | if config.optim.accumulation_steps > 1: 317 | tx = optax.MultiSteps(tx, config.optim.accumulation_steps) 318 | 319 | # create data loader 320 | train_loader, val_loader, num_datasets = get_data_loader( 321 | config.data, tokenize_fn, mesh 322 | ) 323 | # warm up loaders 324 | logging.info("Warming up data loaders...") 325 | next(train_loader), next(val_loader) 326 | 327 | # initialize parameters 328 | if pretrained_params is None or config.wandb_resume_id is not None: 329 | example_batch = next(train_loader) 330 | print("--------------------------------") 331 | print(example_batch["subgoals"].shape) 332 | 333 | def init_fn(init_rng): 334 | if config.goal_drop_rate == 1.0: 335 | example_batch["goals"] = jnp.zeros( 336 | example_batch["subgoals"].shape[:-1] + (0,), 337 | example_batch["subgoals"].dtype, 338 | ) 339 | example_input = jnp.concatenate( 340 | [ 341 | example_batch["subgoals"], 342 | example_batch["curr"], 343 | example_batch["goals"], 344 | ], 345 | axis=-1, 346 | ) 347 | example_timesteps = jnp.zeros(example_input.shape[:1], example_input.dtype) 348 | example_prompt_embeds = jnp.zeros( 349 | [example_input.shape[0], 77, 768], example_input.dtype 350 | ) 351 | 352 | if config.vae: 353 | example_input = vae_encode(rng, example_input) 354 | 355 | params = model_def.init( 356 | init_rng, example_input, example_timesteps, example_prompt_embeds 357 | )["params"] 358 | state = EmaTrainState.create( 359 | apply_fn=model_def.apply, params=params, params_ema=params, tx=tx 360 | ) 361 | return state 362 | 363 | state_shape = jax.eval_shape( 364 | init_fn, rng 365 | ) # pytree of ShapeDtypeStructs for the TrainState 366 | state_sharding = jax.tree_map( 367 | lambda x: fsdp_sharding(mesh, x), state_shape 368 | ) # pytree of NamedShardings 369 | 370 | if config.wandb_resume_id is None: 371 | # initialize sharded TrainState 372 | rng, init_rng = jax.random.split(rng) 373 | state = jax.jit(init_fn, out_shardings=state_sharding)(init_rng) 374 | else: 375 | # restore from checkpoint 376 | state = checkpointer.restore( 377 | checkpointer.latest_step(), 378 | items={ 379 | "state": state_shape, 380 | "params_ema": None, 381 | }, 382 | )["state"] 383 | state = jax.tree_map( 384 | lambda arr, sharding: jax.make_array_from_callback( 385 | arr.shape, sharding, lambda index: arr[index] 386 | ), 387 | state, 388 | state_sharding, 389 | ) 390 | else: 391 | assert pretrained_params is not None 392 | state = EmaTrainState.create( 393 | apply_fn=model_def.apply, 394 | params=pretrained_params, 395 | 
params_ema=pretrained_params, 396 | tx=tx, 397 | ) 398 | state = jax.tree_map(np.array, state) 399 | state_sharding = jax.tree_map(lambda x: fsdp_sharding(mesh, x), state) 400 | state = jax.tree_map( 401 | lambda arr, sharding: jax.make_array_from_callback( 402 | arr.shape, sharding, lambda index: arr[index] 403 | ), 404 | state, 405 | state_sharding, 406 | ) 407 | 408 | # create train and eval step 409 | train_step_configured = partial( 410 | train_step, 411 | log_snr_fn=log_snr_fn, 412 | uncond_prompt_embed=uncond_prompt_embed, 413 | text_encode_fn=text_encode, 414 | vae_encode_fn=vae_encode if config.vae else lambda rng, x, *_, **__: x, 415 | curr_drop_rate=config.curr_drop_rate, 416 | goal_drop_rate=config.goal_drop_rate, 417 | prompt_drop_rate=config.prompt_drop_rate, 418 | ) 419 | train_in_shardings = ( 420 | replicated_sharding, # rng 421 | state_sharding, # state 422 | data_sharding, # batch 423 | ) 424 | train_out_shardings = ( 425 | state_sharding, # new_state 426 | replicated_sharding, # info 427 | ) 428 | train_step_jit = jax.jit( 429 | partial(train_step_configured, eval_only=False), 430 | in_shardings=train_in_shardings, 431 | out_shardings=train_out_shardings, 432 | donate_argnums=1, 433 | ) 434 | eval_step_jit = jax.jit( 435 | partial(train_step_configured, eval_only=True), 436 | in_shardings=train_in_shardings, 437 | out_shardings=train_out_shardings, 438 | donate_argnums=1, 439 | ) 440 | 441 | # shard ema decay 442 | EmaTrainState.apply_ema_decay = jax.jit( 443 | EmaTrainState.apply_ema_decay, 444 | in_shardings=(state_sharding, replicated_sharding), # state, ema_decay 445 | out_shardings=state_sharding, # new_state 446 | donate_argnums=0, # donate state (have to respecify; it doesn't carry over from the inner jit) 447 | ) 448 | 449 | # shard sample loop 450 | sample_loop_configured = partial( 451 | sampling.sample_loop, 452 | log_snr_fn=log_snr_fn, 453 | num_timesteps=config.sample.num_steps, 454 | context_w=config.sample.context_w, 455 | prompt_w=config.sample.prompt_w, 456 | eta=config.sample.eta, 457 | ) 458 | sample_loop_jit = jax.jit( 459 | sample_loop_configured, 460 | in_shardings=( 461 | replicated_sharding, # rng 462 | state_sharding, # state 463 | replicated_sharding, # y 464 | replicated_sharding, # prompt_embeds 465 | replicated_sharding, # uncond_y 466 | replicated_sharding, # uncond_prompt_embeds 467 | ), 468 | out_shardings=replicated_sharding, # returned samples 469 | ) 470 | 471 | train_metrics = defaultdict(list) 472 | last_t = time.time() 473 | 474 | start_step = int(jax.device_get(state.step)) 475 | pbar = tqdm(range(start_step, config.num_steps)) 476 | for step in pbar: 477 | batch = next(train_loader) 478 | 479 | rng, train_step_rng = jax.random.split(rng) 480 | state, info = train_step_jit(train_step_rng, state, batch) 481 | pbar.set_postfix_str(f"loss: {info['loss']:.6f}") 482 | for k, v in info.items(): 483 | train_metrics[k].append(v) 484 | 485 | # update ema params 486 | if (step + 1) <= config.ema.start_step: 487 | state = state.apply_ema_decay(0.0) 488 | if (step + 1) % config.ema.update_every == 0: 489 | ema_decay = ema_decay_fn(step) 490 | state = state.apply_ema_decay(ema_decay) 491 | 492 | # train logging 493 | if (step + 1) % config.log_interval == 0 and jax.process_index() == 0: 494 | summary = {f"train/{k}": np.mean(v) for k, v in train_metrics.items()} 495 | summary["time/seconds_per_step"] = ( 496 | time.time() - last_t 497 | ) / config.log_interval 498 | 499 | train_metrics = defaultdict(list) 500 | last_t = time.time() 501 | 502 | 
summary["train/ema_decay"] = jax.device_get(ema_decay) 503 | summary["train/lr"] = jax.device_get(learning_rate_fn(step)) 504 | 505 | wandb.log(summary, step=step + 1) 506 | 507 | if (step + 1) % config.val_interval == 0: 508 | # compute and log validation metrics 509 | val_metrics = defaultdict(list) 510 | for _ in tqdm(range(config.num_val_batches), desc="val", position=1): 511 | batch = next(val_loader) 512 | rng, val_step_rng = jax.random.split(rng) 513 | state, info = eval_step_jit(val_step_rng, state, batch) 514 | for k, v in info.items(): 515 | val_metrics[k].append(v) 516 | if jax.process_index() == 0: 517 | summary = {f"val/{k}": np.mean(v) for k, v in val_metrics.items()} 518 | wandb.log(summary, step=step + 1) 519 | 520 | if (step + 1) % config.sample_interval == 0: 521 | pbar.set_postfix_str("sampling") 522 | 523 | data = defaultdict(list) 524 | while not data or len(data["curr"]) < config.sample.num_contexts: 525 | batch = next(val_loader) 526 | batch = multihost_utils.process_allgather(batch) 527 | for key in {"curr", "goals", "prompt_ids"}.intersection(batch.keys()): 528 | data[key].extend(batch[key]) 529 | data = {k: np.array(v) for k, v in data.items()} 530 | 531 | data = jax.tree_map(lambda x: x[: config.sample.num_contexts], data) 532 | 533 | # get rid of goals if we're not using them 534 | if config.goal_drop_rate == 1.0: 535 | data["goals"] = np.zeros( 536 | data["curr"].shape[:-1] + (0,), data["curr"].dtype 537 | ) 538 | else: 539 | # make the first half have no prompt 540 | # data["prompt_ids"][: config.sample.num_contexts // 2] = uncond_prompt_id 541 | # make the second half have no goal 542 | # data["goals"][config.sample.num_contexts // 2 :] = 0 543 | pass 544 | 545 | # concatenate to make context 546 | contexts = np.concatenate([data["curr"], data["goals"]], axis=-1) 547 | 548 | # encode stuff 549 | if config.vae: 550 | rng, encode_rng = jax.random.split(rng) 551 | contexts = jax.device_get(vae_encode(encode_rng, contexts)) 552 | prompt_embeds = jax.device_get(text_encode(data["prompt_ids"])) 553 | 554 | # repeat 555 | contexts_repeated = eo.repeat( 556 | contexts, "n ... -> (n r) ...", r=config.sample.num_samples_per_context 557 | ) 558 | prompt_embeds_repeated = eo.repeat( 559 | prompt_embeds, 560 | "n ... 
-> (n r) ...", 561 | r=config.sample.num_samples_per_context, 562 | ) 563 | 564 | # run sample loop 565 | rng, sample_rng = jax.random.split(rng) 566 | samples = sample_loop_jit( 567 | sample_rng, 568 | state, 569 | contexts_repeated, 570 | prompt_embeds_repeated, 571 | jnp.zeros_like(contexts_repeated), 572 | jnp.broadcast_to(uncond_prompt_embed, prompt_embeds_repeated.shape), 573 | ) # (num_contexts * num_samples_per_context, h, w, c) 574 | 575 | if config.vae: 576 | samples = jax.device_get(vae_decode(samples, scale=True)) 577 | contexts = jax.device_get(vae_decode(contexts, scale=False)) 578 | 579 | right = eo.rearrange( 580 | samples, 581 | "(n r) h w c -> (n h) (r w) c", 582 | r=config.sample.num_samples_per_context, 583 | ) 584 | left = eo.rearrange(contexts, "n h w (x c) -> (n h) (x w) c", c=3) 585 | 586 | final_image = np.concatenate([left, right], axis=1) 587 | final_image = np.clip(np.round(final_image * 127.5 + 127.5), 0, 255).astype( 588 | np.uint8 589 | ) 590 | 591 | if jax.process_index() == 0: 592 | prompts = untokenize(data["prompt_ids"]) 593 | prompt_str = "; ".join(prompts) 594 | pil = Image.fromarray(final_image) 595 | with tf.io.gfile.GFile( 596 | tf.io.gfile.join(logdir, f"{step + 1}.jpg"), "wb" 597 | ) as f: 598 | pil.save(f, format="jpeg", quality=95) 599 | with tf.io.gfile.GFile( 600 | tf.io.gfile.join(logdir, f"{step + 1}.txt"), "w" 601 | ) as f: 602 | f.write(prompt_str) 603 | 604 | with tempfile.TemporaryDirectory() as tmpdir: 605 | pil.save(os.path.join(tmpdir, "image.jpg"), quality=95) 606 | wandb.log( 607 | { 608 | "samples": wandb.Image( 609 | os.path.join(tmpdir, "image.jpg"), caption=prompt_str 610 | ) 611 | }, 612 | step=step + 1, 613 | ) 614 | 615 | if (step + 1) % config.save_interval == 0: 616 | checkpointer.save( 617 | step + 1, 618 | {"state": state, "params_ema": state.params_ema}, 619 | { 620 | "state": { 621 | "save_args": orbax_utils.save_args_from_target(state), 622 | }, 623 | "params_ema": { 624 | "save_args": orbax_utils.save_args_from_target( 625 | state.params_ema 626 | ), 627 | }, 628 | }, 629 | ) 630 | 631 | 632 | if __name__ == "__main__": 633 | app.run(main) 634 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="susie", 5 | packages=find_packages(), 6 | version="0.0.1", 7 | install_requires=[ 8 | "absl-py", 9 | "diffusers[flax]", 10 | "ml_collections", 11 | "tensorflow", 12 | "wandb", 13 | "einops", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /susie/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import make_dataset 2 | -------------------------------------------------------------------------------- /susie/data/datasets.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Dict, List 3 | 4 | import dlimp as dl 5 | import numpy as np 6 | import tensorflow as tf 7 | from jax.experimental import multihost_utils 8 | from jax.sharding import PartitionSpec as P 9 | 10 | from . 
import goal_relabeling 11 | 12 | 13 | class Transforms: 14 | """Trajectory-level transforms for each dataset""" 15 | 16 | @staticmethod 17 | def ego4d(x: Dict[str, Any]) -> Dict[str, Any]: 18 | return x 19 | 20 | @staticmethod 21 | def bridge(x: Dict[str, Any]) -> Dict[str, Any]: 22 | CAMERA_VIEWS = {"images0", "images1", "images2"} 23 | # pick a random camera view 24 | views = tf.stack([x["obs"][k] for k in CAMERA_VIEWS]) 25 | lengths = tf.stack([tf.strings.length(x["obs"][k][0]) for k in CAMERA_VIEWS]) 26 | views = views[lengths > 0] 27 | idx = tf.random.uniform([], minval=0, maxval=tf.shape(views)[0], dtype=tf.int32) 28 | x["obs"] = views[idx] 29 | # x["obs"] = x["obs"]["images0"] 30 | 31 | del x["actions"] 32 | return x 33 | 34 | @staticmethod 35 | def calvin(x: Dict[str, Any]) -> Dict[str, Any]: 36 | x["obs"] = x.pop("image_states") 37 | x["lang"] = x.pop("language_annotation") 38 | 39 | del x["actions"] 40 | del x["proprioceptive_states"] 41 | 42 | return x 43 | 44 | @staticmethod 45 | def somethingsomething(x: Dict[str, Any]) -> Dict[str, Any]: 46 | return x 47 | 48 | 49 | class GetPaths: 50 | """Retrieves paths to TFRecord files for each dataset""" 51 | 52 | @staticmethod 53 | def ego4d(data_path: str, train: bool) -> str: 54 | return f"{data_path}/{'train' if train else 'val'}" 55 | 56 | @staticmethod 57 | def bridge(data_path: str, train: bool) -> str: 58 | return f"{data_path}/{'train' if train else 'val'}" 59 | 60 | @staticmethod 61 | def somethingsomething(data_path: str, train: bool) -> str: 62 | return f"{data_path}/{'train' if train else 'val'}" 63 | 64 | @staticmethod 65 | def calvin(data_path: str, train: bool) -> List[str]: 66 | if train: 67 | return ( 68 | tf.io.gfile.glob(f"{data_path}/training/A/*") 69 | + tf.io.gfile.glob(f"{data_path}/training/B/*") 70 | + tf.io.gfile.glob(f"{data_path}/training/C/*") 71 | ) 72 | else: 73 | return tf.io.gfile.glob(f"{data_path}/validation/D/*") 74 | 75 | 76 | def make_dataset( 77 | name: str, 78 | data_path: str, 79 | image_size: int, 80 | shuffle_buffer_size: int, 81 | train: bool, 82 | goal_relabeling_fn: str, 83 | goal_relabeling_kwargs: dict = {}, 84 | augment_kwargs: dict = {}, 85 | ) -> dl.DLataset: 86 | paths = getattr(GetPaths, name)(data_path, train) 87 | 88 | dataset = ( 89 | dl.DLataset.from_tfrecords(paths) 90 | .map(dl.transforms.unflatten_dict) 91 | .map(getattr(Transforms, name)) 92 | .filter(lambda x: tf.math.reduce_all(x["lang"] != "")) 93 | .apply( 94 | partial( 95 | getattr(goal_relabeling, goal_relabeling_fn), **goal_relabeling_kwargs 96 | ), 97 | ) 98 | .unbatch() 99 | .shuffle(shuffle_buffer_size) 100 | ) 101 | 102 | dataset = dataset.map( 103 | partial(dl.transforms.decode_images, match=["curr", "goals", "subgoals"]) 104 | ).map( 105 | partial( 106 | dl.transforms.resize_images, 107 | match=["curr", "goals", "subgoals"], 108 | size=(image_size, image_size), 109 | ) 110 | ) 111 | 112 | if train: 113 | dataset = dataset.map( 114 | partial( 115 | dl.transforms.augment, 116 | traj_identical=False, 117 | keys_identical=True, 118 | match=["curr", "goals", "subgoals"], 119 | augment_kwargs=augment_kwargs, 120 | ) 121 | ) 122 | 123 | # normalize images to [-1, 1] 124 | dataset = dataset.map( 125 | partial( 126 | dl.transforms.selective_tree_map, 127 | match=["curr", "goals", "subgoals"], 128 | map_fn=lambda v: v / 127.5 - 1.0, 129 | ) 130 | ) 131 | 132 | return dataset.repeat() 133 | 134 | 135 | def get_data_loader(data_config, tokenize_fn, mesh=None): 136 | data_config = dict(data_config) 137 | batch_size = 
data_config.pop("batch_size") 138 | 139 | train_datasets = [] 140 | val_datasets = [] 141 | weights = [] 142 | for data_name, data_kwargs in data_config.items(): 143 | data_kwargs = dict(data_kwargs) 144 | weights.append(float(data_kwargs.pop("weight"))) 145 | train_datasets.append(make_dataset(data_name, train=True, **data_kwargs)) 146 | val_datasets.append(make_dataset(data_name, train=False, **data_kwargs)) 147 | 148 | train = dl.DLataset.sample_from_datasets( 149 | train_datasets, weights=weights, stop_on_empty_dataset=True 150 | ).batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE) 151 | val = dl.DLataset.sample_from_datasets( 152 | val_datasets, weights=weights, stop_on_empty_dataset=True 153 | ).batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE) 154 | 155 | def shard(batch): 156 | return multihost_utils.host_local_array_to_global_array( 157 | batch, 158 | mesh, 159 | P(("dp", "fsdp")), 160 | ) 161 | 162 | # WARNING: for some reason any amount of prefetching is also a total no-go in terms of memory usage... 163 | train = map(tokenize_fn, train.as_numpy_iterator()) 164 | val = map(tokenize_fn, val.as_numpy_iterator()) 165 | 166 | if mesh: 167 | return map(shard, train), map(shard, val), len(train_datasets) 168 | else: 169 | return train, val, len(train_datasets) 170 | -------------------------------------------------------------------------------- /susie/data/goal_relabeling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def tail_goals(ds, *, tail_proportion, subgoal_delta): 5 | """ 6 | Relabels for subgoal training. Removes the `obs` key and adds `subgoals`, `curr`, and `goals` keys. 7 | 8 | The "goal" is selected from the last `tail_proportion` proportion of the trajectory. The "current obs" is selected 9 | from [0, len * (1 - tail_proportion) - subgoal_delta[0]). The "subgoal" is selected from [curr + subgoal_delta[0], 10 | min{curr + subgoal_delta[1], goal}). 
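For example (illustrative numbers, not from any config in this repo): with len = 100, tail_proportion = 0.1, and subgoal_delta = (5, 20), the goal is drawn from frames [90, 100), the current obs from [0, 85), and the subgoal from [curr + 5, min{curr + 20, goal}). 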
11 | """ 12 | assert len(subgoal_delta) == 2 13 | 14 | def filter_fn(traj): 15 | num_frames = tf.shape(traj["obs"])[0] 16 | n = tf.cast( 17 | tf.math.ceil(tf.cast(num_frames, tf.float32) * tail_proportion), 18 | tf.int32, 19 | ) 20 | return num_frames > n + subgoal_delta[0] 21 | 22 | def map_fn(traj): 23 | num_frames = tf.shape(traj["obs"])[0] 24 | 25 | n = tf.cast( 26 | tf.math.ceil(tf.cast(num_frames, tf.float32) * tail_proportion), 27 | tf.int32, 28 | ) 29 | 30 | # select the last n transitions to be goals: [len - n, len) 31 | goal_idxs = tf.range(num_frames - n, num_frames) 32 | goals = tf.gather(traj["obs"], goal_idxs, name="tail_1") 33 | 34 | # for each goal, select a random state from [0, len - n - subgoal_delta[0]) 35 | rand = tf.random.uniform([n]) 36 | high = tf.cast(num_frames - n - subgoal_delta[0], tf.float32) 37 | curr_idxs = tf.cast(rand * high, tf.int32) 38 | curr = tf.gather(traj["obs"], curr_idxs, name="tail_2") 39 | 40 | # for each (curr, goal) pair, select a random subgoal from [curr + subgoal_delta[0], min{curr + 41 | # subgoal_delta[1], goal}) 42 | rand = tf.random.uniform([n]) 43 | low = tf.cast(curr_idxs + subgoal_delta[0], tf.float32) 44 | high = tf.cast(tf.minimum(curr_idxs + subgoal_delta[1], goal_idxs), tf.float32) 45 | subgoal_idxs = tf.cast(low + rand * (high - low), tf.int32) 46 | subgoals = tf.gather(traj["obs"], subgoal_idxs, name="tail_3") 47 | 48 | extras = {k: v[-n:] for k, v in traj.items() if k != "obs"} 49 | 50 | return {"subgoals": subgoals, "curr": curr, "goals": goals, **extras} 51 | 52 | return ds.filter(filter_fn).map(map_fn) 53 | 54 | 55 | def delta_goals(ds, *, goal_delta, subgoal_delta): 56 | """ 57 | Relabels for subgoal training. Removes the `obs` key and adds `subgoals`, `curr`, and `goals` keys. 58 | 59 | The "current obs" is selected from [0, len - goal_delta[0]). The "goal" is then selected from [curr + 60 | goal_delta[0], min{curr + goal_delta[1], len}). 61 | 62 | The "subgoal" is selected from [curr + subgoal_delta[0], min{curr + subgoal_delta[1], goal}). 
63 | """ 64 | assert len(subgoal_delta) == 2 65 | assert len(goal_delta) == 2 66 | 67 | def filter_fn(traj): 68 | num_frames = tf.shape(traj["obs"])[0] 69 | n = num_frames - goal_delta[0] 70 | return n >= 1 71 | 72 | def map_fn(traj): 73 | num_frames = tf.shape(traj["obs"])[0] 74 | n = num_frames - goal_delta[0] 75 | 76 | # select [0, len - goal_delta[0]) to be the current obs 77 | curr_idxs = tf.range(n) 78 | curr = tf.gather(traj["obs"], curr_idxs, name="delta_1") 79 | 80 | # for each current obs, select a random goal from [curr + goal_delta[0], min{curr + goal_delta[1], len}) 81 | rand = tf.random.uniform([n]) 82 | low = tf.cast(curr_idxs + goal_delta[0], tf.float32) 83 | high = tf.cast(tf.minimum(curr_idxs + goal_delta[1], num_frames), tf.float32) 84 | goal_idxs = tf.cast(low + rand * (high - low), tf.int32) 85 | goal_idxs = tf.clip_by_value(goal_idxs, 0, num_frames - 1) 86 | goals = tf.gather(traj["obs"], goal_idxs, name="delta_2") 87 | 88 | # for each (curr, goal) pair, select a random subgoal from [curr + subgoal_delta[0], min{curr + subgoal_delta[1], goal}) 89 | rand = tf.random.uniform([n]) 90 | low = tf.cast(curr_idxs + subgoal_delta[0], tf.float32) 91 | high = tf.cast(tf.minimum(curr_idxs + subgoal_delta[1], goal_idxs), tf.float32) 92 | subgoal_idxs = tf.cast(low + rand * (high - low), tf.int32) 93 | subgoals = tf.gather(traj["obs"], subgoal_idxs, name="delta_3") 94 | 95 | extras = {k: v[:n] for k, v in traj.items() if k != "obs"} 96 | 97 | return {"subgoals": subgoals, "curr": curr, "goals": goals, **extras} 98 | 99 | return ds.filter(filter_fn).map(map_fn) 100 | 101 | 102 | def subgoal_only(ds, *, subgoal_delta, truncate=False): 103 | """ 104 | Relabels for subgoal training. Removes the `obs` key and adds `subgoals` and `curr` keys. 105 | 106 | If truncate == False: 107 | The "current obs" is selected from [0, len). The "subgoal" is then selected 108 | from [min{curr + subgoal_delta[0], len - 1}, min{curr + subgoal_delta[1], len}). 109 | else: 110 | The "current obs" is selected from [0, len - subgoal_delta[0]). The "subgoal" is then selected 111 | from [curr + subgoal_delta[0], min{curr + subgoal_delta[1], len}). 
112 | 113 | """ 114 | assert len(subgoal_delta) == 2 115 | 116 | def filter_fn(traj): 117 | num_frames = tf.shape(traj["obs"])[0] 118 | n = num_frames - subgoal_delta[0] 119 | return n >= 1 120 | 121 | if truncate: 122 | 123 | def map_fn(traj): 124 | num_frames = tf.shape(traj["obs"])[0] 125 | n = num_frames - subgoal_delta[0] 126 | 127 | # select [0, len - subgoal_delta[0]) to be the current obs 128 | curr_idxs = tf.range(n) 129 | curr = tf.gather(traj["obs"], curr_idxs, name="subdelta_1") 130 | 131 | # for each current obs, select a random subgoal from [curr + subgoal_delta[0], min{curr + subgoal_delta[1], 132 | # len}) 133 | rand = tf.random.uniform([n]) 134 | low = tf.cast(curr_idxs + subgoal_delta[0], tf.float32) 135 | high = tf.cast( 136 | tf.minimum(curr_idxs + subgoal_delta[1], num_frames), tf.float32 137 | ) 138 | subgoal_idxs = tf.cast(low + rand * (high - low), tf.int32) 139 | subgoal_idxs = tf.clip_by_value(subgoal_idxs, 0, num_frames - 1) 140 | subgoals = tf.gather(traj["obs"], subgoal_idxs, name="subdelta_2") 141 | 142 | extras = {k: v[:n] for k, v in traj.items() if k != "obs"} 143 | 144 | return {"subgoals": subgoals, "curr": curr, **extras} 145 | 146 | else: 147 | 148 | def map_fn(traj): 149 | num_frames = tf.shape(traj["obs"])[0] 150 | 151 | # select [0, len) to be the current obs 152 | curr_idxs = tf.range(num_frames) 153 | curr = traj["obs"] 154 | 155 | # for each current obs, select a random subgoal from [min{curr + 156 | # subgoal_delta[0], len - 1}, min{curr + subgoal_delta[1], len}) 157 | rand = tf.random.uniform([num_frames]) 158 | low = tf.cast( 159 | tf.minimum(curr_idxs + subgoal_delta[0], num_frames - 1), tf.float32 160 | ) 161 | high = tf.cast( 162 | tf.minimum(curr_idxs + subgoal_delta[1], num_frames), tf.float32 163 | ) 164 | subgoal_idxs = tf.cast(low + rand * (high - low), tf.int32) 165 | subgoal_idxs = tf.clip_by_value(subgoal_idxs, 0, num_frames - 1) 166 | subgoals = tf.gather(traj["obs"], subgoal_idxs, name="subdelta_2") 167 | 168 | extras = {k: v for k, v in traj.items() if k != "obs"} 169 | 170 | return {"subgoals": subgoals, "curr": curr, **extras} 171 | 172 | return ds.filter(filter_fn).map(map_fn) 173 | 174 | 175 | def uniform(ds, *, dist_norm): 176 | """ 177 | Relabels with a true uniform distribution over future states. 
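Concretely, for each index i the goal index is drawn uniformly from [i + 1, len), and the distance to the goal is rescaled to roughly [-1, 1] as dists = 2 * (goal_idx - i) / dist_norm - 1. 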
178 | """ 179 | 180 | def map_fn(traj): 181 | traj_len = tf.shape(tf.nest.flatten(traj["obs"])[0])[0] 182 | 183 | # select a random future index for each transition i in the range [i + 1, traj_len) 184 | rand = tf.random.uniform([traj_len]) 185 | low = tf.cast(tf.range(traj_len) + 1, tf.float32) 186 | high = tf.cast(traj_len, tf.float32) 187 | goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) 188 | 189 | # sometimes there are floating-point errors that cause an out-of-bounds 190 | goal_idxs = tf.minimum(goal_idxs, traj_len - 1) 191 | 192 | traj["goals"] = tf.gather(traj["obs"], goal_idxs, name="uniform_1") 193 | traj["dists"] = goal_idxs - tf.range(traj_len) 194 | traj["curr"] = traj.pop("obs") 195 | 196 | traj["dists"] = 2 * tf.cast(traj["dists"], tf.float32) / dist_norm - 1 197 | 198 | return traj 199 | 200 | return ds.map(map_fn) 201 | -------------------------------------------------------------------------------- /susie/jax_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | import os 4 | import pickle 5 | from copy import deepcopy 6 | from typing import Any, Callable, Sequence, Union 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from jax._src import xla_bridge as xb 12 | from jax.experimental import multihost_utils 13 | from jax.experimental.compilation_cache import compilation_cache 14 | from jax.stages import Compiled 15 | from jaxlib.mlir import ir 16 | from jaxlib.mlir.dialects import chlo, stablehlo 17 | 18 | 19 | def host_broadcast_str(x: str) -> str: 20 | """Broadcast_one_to_all, but with a string. Strings should all be the same length.""" 21 | multihost_utils.assert_equal( 22 | len(x), f"String lengths are not equal: got {len(x)} for {jax.process_index()}" 23 | ) 24 | encoded = np.array([ord(c) for c in x], dtype=np.uint8) 25 | encoded = multihost_utils.broadcast_one_to_all(encoded) 26 | return "".join([chr(u) for u in encoded]) 27 | 28 | 29 | def shard_along_first_axis(x: Any, devices: Sequence[jax.Device]) -> jax.Array: 30 | """ 31 | Shard an array along the first axis, putting it on device in the process. 32 | Works in multi-host setting. 33 | """ 34 | sharding = jax.sharding.NamedSharding( 35 | jax.sharding.Mesh(devices, "x"), jax.sharding.PartitionSpec("x") 36 | ) 37 | x = jax.tree_map(jnp.array, x) 38 | return jax.tree_map( 39 | lambda arr: jax.make_array_from_callback( 40 | arr.shape, sharding, lambda index: arr[index] 41 | ), 42 | x, 43 | ) 44 | 45 | 46 | def replicate(x: Any, devices: Sequence[jax.Device]) -> jax.Array: 47 | """Replicate an array across devices. 
48 | sharding = jax.sharding.PositionalSharding(devices).replicate() 49 | x = jax.tree_map(jnp.array, x) 50 | return jax.tree_map( 51 | lambda arr: jax.make_array_from_callback( 52 | arr.shape, sharding, lambda index: arr[index] 53 | ), 54 | x, 55 | ) 56 | 57 | 58 | def initialize_compilation_cache(path=os.path.expanduser("~/.jax_compilation_cache")): 59 | """Initializes the Jax persistent compilation cache.""" 60 | compilation_cache.initialize_cache(path) 61 | for logger in [logging.getLogger(name) for name in logging.root.manager.loggerDict]: 62 | logger.addFilter( 63 | lambda record: "Not writing persistent cache entry for" 64 | not in record.getMessage() 65 | and "Persistent compilation cache hit for" not in record.getMessage() 66 | and "to persistent compilation cache with key" not in record.getMessage() 67 | ) 68 | 69 | 70 | def serialize_jax_fn(fn: Callable, *args, **kwargs) -> bytes: 71 | """ 72 | Serializes a Jax function using StableHLO and pickle. Only supports trivial 73 | shardings (no cross-device communication). 74 | """ 75 | lowered = jax.jit(fn, backend="cpu").lower(*args, **kwargs) 76 | 77 | output = io.BytesIO() 78 | lowered.compiler_ir("stablehlo").operation.write_bytecode(file=output) 79 | hlo = output.getvalue() 80 | 81 | n_invals = len(lowered._lowering.compile_args["in_shardings"]) 82 | n_outvals = len(lowered._lowering.compile_args["out_shardings"]) 83 | 84 | objs_to_skip = [lowered._lowering._hlo] 85 | for k in ["backend", "in_shardings", "out_shardings", "device_assignment"]: 86 | objs_to_skip.append(lowered._lowering.compile_args[k]) 87 | lowered = deepcopy(lowered, {id(obj): None for obj in objs_to_skip}) 88 | 89 | return pickle.dumps((lowered, hlo, n_invals, n_outvals)) 90 | 91 | 92 | def deserialize_jax_fn( 93 | serialized: Union[bytes, str], 94 | device: jax.Device = jax.devices()[0], 95 | ) -> Compiled: 96 | """ 97 | Deserializes and compiles a Jax function serialized using 98 | `serialize_jax_fn`. Forces computation onto a single device. 
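Accepts either the raw bytes or a path to the pickled file on disk; e.g. (sketch, hypothetical path) deserialize_jax_fn("/path/to/fn.pkl") returns a Compiled object that can be called directly. 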
99 | """ 100 | if isinstance(serialized, str): 101 | with open(serialized, "rb") as f: 102 | serialized = f.read() 103 | lowered, hlo, n_invals, n_outvals = pickle.loads(serialized) 104 | 105 | with ir.Context() as context: 106 | stablehlo.register_dialect(context) 107 | chlo.register_chlo_dialect(context) 108 | hlo = ir.Module.parse(hlo) 109 | 110 | lowered._lowering._hlo = hlo 111 | 112 | lowered._lowering.compile_args["backend"] = xb.get_device_backend(device) 113 | lowered._lowering.compile_args["device_assignment"] = [device] 114 | lowered._lowering.compile_args["in_shardings"] = [ 115 | jax.sharding.PositionalSharding([device]) 116 | ] * n_invals 117 | lowered._lowering.compile_args["out_shardings"] = [ 118 | jax.sharding.PositionalSharding([device]) 119 | ] * n_outvals 120 | 121 | return lowered.compile() 122 | -------------------------------------------------------------------------------- /susie/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Tuple 5 | 6 | import einops as eo 7 | import jax 8 | import jax.numpy as jnp 9 | import ml_collections 10 | import numpy as np 11 | import orbax.checkpoint 12 | from absl import logging 13 | from diffusers.models import FlaxAutoencoderKL, FlaxUNet2DConditionModel 14 | from flax.core.frozen_dict import FrozenDict 15 | from flax.training.train_state import TrainState 16 | from jax.lax import with_sharding_constraint as wsc 17 | from transformers import CLIPTokenizer, FlaxCLIPTextModel 18 | 19 | import wandb 20 | from susie import sampling, scheduling 21 | from susie.jax_utils import replicate 22 | 23 | 24 | class EmaTrainState(TrainState): 25 | params_ema: FrozenDict[str, Any] 26 | 27 | @partial(jax.jit, donate_argnums=0) 28 | def apply_ema_decay(self, ema_decay): 29 | params_ema = jax.tree_map( 30 | lambda p_ema, p: p_ema * ema_decay + p * (1.0 - ema_decay), 31 | self.params_ema, 32 | self.params, 33 | ) 34 | return self.replace(params_ema=params_ema) 35 | 36 | 37 | def create_model_def(config: dict) -> FlaxUNet2DConditionModel: 38 | model, unused_kwargs = FlaxUNet2DConditionModel.from_config( 39 | dict(config), return_unused_kwargs=True 40 | ) 41 | if unused_kwargs: 42 | logging.warning(f"FlaxUNet2DConditionModel unused kwargs: {unused_kwargs}") 43 | # monkey-patch __call__ to use channels-last 44 | model.__call__ = lambda self, sample, *args, **kwargs: eo.rearrange( 45 | FlaxUNet2DConditionModel.__call__( 46 | self, eo.rearrange(sample, "b h w c -> b c h w"), *args, **kwargs 47 | ).sample, 48 | "b c h w -> b h w c", 49 | ) 50 | return model 51 | 52 | 53 | def load_vae( 54 | path: str, 55 | ) -> Tuple[ 56 | Callable[[jax.Array, jax.Array, bool], jax.Array], 57 | Callable[[jax.Array, bool], jax.Array], 58 | ]: 59 | if ":" in path: 60 | path, revision = path.split(":") 61 | else: 62 | revision = None 63 | vae, vae_params = FlaxAutoencoderKL.from_pretrained( 64 | path, subfolder="vae", revision=revision 65 | ) 66 | # monkey-patch encode to use channels-last (it returns a FlaxDiagonalGaussianDistribution object, which is already 67 | # channels-last) 68 | vae.encode = lambda self, sample, *args, **kwargs: FlaxAutoencoderKL.encode( 69 | self, eo.rearrange(sample, "b h w c -> b c h w"), *args, **kwargs 70 | ).latent_dist 71 | 72 | # monkey-patch decode to use channels-last (it already accepts channels-last input) 73 | vae.decode = lambda self, latents, *args, **kwargs: eo.rearrange( 74 | 
FlaxAutoencoderKL.decode(self, latents, *args, **kwargs).sample, 75 | "b c h w -> b h w c", 76 | ) 77 | 78 | # HuggingFace returns vae_params already committed onto the CPU -_- 79 | # this one took me a while to figure out... 80 | vae_params = jax.device_get(vae_params) 81 | 82 | @jax.jit 83 | def vae_encode(vae_params, key, sample, scale=False): 84 | # handle the case where `sample` is multiple images stacked 85 | batch_size = sample.shape[0] 86 | sample = eo.rearrange(sample, "n h w (x c) -> (n x) h w c", c=3) 87 | latents = vae.apply({"params": vae_params}, sample, method=vae.encode).sample( 88 | key 89 | ) 90 | latents = eo.rearrange(latents, "(n x) h w c -> n h w (x c)", n=batch_size) 91 | latents = jax.lax.cond( 92 | scale, lambda: latents * vae.config.scaling_factor, lambda: latents 93 | ) 94 | return latents 95 | 96 | @jax.jit 97 | def vae_decode(vae_params, latents, scale=True): 98 | # handle the case where `latents` is multiple images stacked 99 | batch_size = latents.shape[0] 100 | latents = eo.rearrange( 101 | latents, "n h w (x c) -> (n x) h w c", c=vae.config.latent_channels 102 | ) 103 | latents = jax.lax.cond( 104 | scale, lambda: latents / vae.config.scaling_factor, lambda: latents 105 | ) 106 | sample = vae.apply({"params": vae_params}, latents, method=vae.decode) 107 | sample = eo.rearrange(sample, "(n x) h w c -> n h w (x c)", n=batch_size) 108 | return sample 109 | 110 | return partial(vae_encode, vae_params), partial(vae_decode, vae_params) 111 | 112 | 113 | def load_text_encoder( 114 | path: str, 115 | ) -> Tuple[ 116 | Callable[[List[str]], np.ndarray], 117 | Callable[[np.ndarray], List[str]], 118 | Callable[[jax.Array], jax.Array], 119 | ]: 120 | if ":" in path: 121 | path, revision = path.split(":") 122 | else: 123 | revision = None 124 | text_encoder = FlaxCLIPTextModel.from_pretrained( 125 | path, subfolder="text_encoder", revision=revision 126 | ) 127 | tokenizer = CLIPTokenizer.from_pretrained( 128 | path, subfolder="tokenizer", revision=revision 129 | ) 130 | 131 | def tokenize(s: List[str]) -> np.ndarray: 132 | return tokenizer(s, padding="max_length", return_tensors="np").input_ids 133 | 134 | untokenize = partial(tokenizer.batch_decode, skip_special_tokens=True) 135 | 136 | @jax.jit 137 | def text_encode(params, prompt_ids): 138 | return text_encoder(prompt_ids, params=params)[0] 139 | 140 | return tokenize, untokenize, partial(text_encode, text_encoder.params) 141 | 142 | 143 | def load_pretrained_unet( 144 | path: str, in_channels: int 145 | ) -> Tuple[FlaxUNet2DConditionModel, dict]: 146 | model_def, params = FlaxUNet2DConditionModel.from_pretrained( 147 | path, dtype=np.float32, subfolder="unet" 148 | ) 149 | 150 | # same issue, they commit the params to the CPU, which totally messes stuff 151 | # up downstream... 
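# (jax.device_get below converts them to plain host numpy arrays, which is what lets the conv_in kernel be resized and assigned in place a few lines down.)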
152 | params = jax.device_get(params) 153 | 154 | # add extra parameters to conv_in if necessary 155 | old_conv_in = params["conv_in"]["kernel"] 156 | h, w, cin, cout = old_conv_in.shape 157 | logging.info(f"Adding {in_channels - cin} channels to conv_in") 158 | params["conv_in"]["kernel"] = np.zeros( 159 | (h, w, in_channels, cout), dtype=old_conv_in.dtype 160 | ) 161 | params["conv_in"]["kernel"][:, :, :cin, :] = old_conv_in 162 | 163 | # monkey-patch __call__ to use channels-last 164 | model_def.__call__ = lambda self, sample, *args, **kwargs: eo.rearrange( 165 | FlaxUNet2DConditionModel.__call__( 166 | self, eo.rearrange(sample, "b h w c -> b c h w"), *args, **kwargs 167 | ).sample, 168 | "b c h w -> b h w c", 169 | ) 170 | 171 | return model_def, params 172 | 173 | 174 | def create_sample_fn( 175 | path: str, 176 | wandb_run_name: Optional[str] = None, 177 | num_timesteps: int = 50, 178 | prompt_w: float = 7.5, 179 | context_w: float = 2.5, 180 | eta: float = 0.0, 181 | pretrained_path: str = "runwayml/stable-diffusion-v1-5:flax", 182 | ) -> Callable[[np.ndarray, str], np.ndarray]: 183 | if ( 184 | os.path.exists(path) 185 | and os.path.isdir(path) 186 | and "checkpoint" in os.listdir(path) 187 | ): 188 | # this is an orbax checkpoint 189 | assert wandb_run_name is not None 190 | # load config from wandb 191 | api = wandb.Api() 192 | run = api.run(wandb_run_name) 193 | config = ml_collections.ConfigDict(run.config) 194 | 195 | # load params 196 | params = orbax.checkpoint.PyTreeCheckpointer().restore(path, item=None) 197 | assert "params_ema" not in params 198 | 199 | # load model 200 | model_def = create_model_def(config.model) 201 | else: 202 | # assume this is in HuggingFace format 203 | model_def, params = load_pretrained_unet(path, in_channels=8) 204 | 205 | # hardcode scheduling config to be "scaled_linear" (used by Stable Diffusion) 206 | config = {"scheduling": {"noise_schedule": "scaled_linear"}} 207 | 208 | state = EmaTrainState( 209 | step=0, 210 | apply_fn=model_def.apply, 211 | params=None, 212 | params_ema=params, 213 | tx=None, 214 | opt_state=None, 215 | ) 216 | del params 217 | 218 | # load encoders 219 | vae_encode, vae_decode = load_vae(pretrained_path) 220 | tokenize, untokenize, text_encode = load_text_encoder(pretrained_path) 221 | uncond_prompt_embed = text_encode(tokenize([""])) # (1, 77, 768) 222 | 223 | log_snr_fn = scheduling.create_log_snr_fn(config["scheduling"]) 224 | sample_loop = partial(sampling.sample_loop, log_snr_fn=log_snr_fn) 225 | 226 | rng = jax.random.PRNGKey(int(time.time())) 227 | 228 | def sample(image, prompt, prompt_w=prompt_w, context_w=context_w): 229 | nonlocal rng 230 | 231 | image = image / 127.5 - 1.0 232 | image = image[None] 233 | assert image.shape == (1, 256, 256, 3) 234 | 235 | prompt_embeds = text_encode(tokenize([prompt])) 236 | 237 | # encode stuff 238 | rng, encode_rng = jax.random.split(rng) 239 | contexts = vae_encode(encode_rng, image, scale=False) 240 | 241 | rng, sample_rng = jax.random.split(rng) 242 | samples = sample_loop( 243 | sample_rng, 244 | state, 245 | contexts, 246 | prompt_embeds, 247 | num_timesteps=num_timesteps, 248 | prompt_w=prompt_w, 249 | context_w=context_w, 250 | eta=eta, 251 | uncond_y=jnp.zeros_like(contexts), 252 | uncond_prompt_embeds=uncond_prompt_embed, 253 | ) 254 | samples = vae_decode(samples) 255 | samples = jnp.clip(jnp.round(samples * 127.5 + 127.5), 0, 255).astype(jnp.uint8) 256 | 257 | return jax.device_get(samples[0]) 258 | 259 | return sample 260 | 
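# Example usage (a sketch; the checkpoint path, wandb run path, and prompt below are
# hypothetical placeholders, not from this repo):
#
#     sample_fn = create_sample_fn(
#         "path/to/orbax/checkpoint", wandb_run_name="entity/project/run_id"
#     )
#     image = np.zeros((256, 256, 3), dtype=np.uint8)  # a 256x256 RGB observation
#     subgoal = sample_fn(image, "put the spoon in the sink")  # -> (256, 256, 3) uint8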
-------------------------------------------------------------------------------- /susie/sampling.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | @jax.jit 8 | def q_sample(x_0, log_snr, noise): 9 | """Samples from the diffusion process at a given timestep. 10 | 11 | Args: 12 | x_0: the start image. 13 | log_snr: the log SNR (lambda) for the timestep t. 14 | noise: the noise, typically N(0, I) or output of the UNet. 15 | Returns: 16 | the resulting sample from q(x_t | x_0). 17 | """ 18 | # based on A.4 of VDM paper 19 | # the signs are flipped bc log_snr is monotonic decreasing 20 | # (rather than increasing as is their learned gamma in the paper) 21 | alpha = jnp.sqrt(jax.nn.sigmoid(log_snr))[:, None, None, None] 22 | sigma = jnp.sqrt(jax.nn.sigmoid(-log_snr))[:, None, None, None] 23 | return alpha * x_0 + sigma * noise 24 | 25 | 26 | @partial(jax.jit, static_argnames="use_ema") 27 | def model_predict(state, x, y, prompt_embeds, t, use_ema=True): 28 | """Runs forward inference of the model. 29 | 30 | Args: 31 | state: an EmaTrainState instance. 32 | x: the input image. 33 | y: the image context. 34 | prompt_embeds: the prompt embeddings. 35 | t: the current timestep in the range [0, 1]. 36 | use_ema: whether to use the exponential moving average of the parameters. 37 | Returns: 38 | the raw output of the UNet. 39 | """ 40 | if use_ema: 41 | variables = {"params": state.params_ema} 42 | else: 43 | variables = {"params": state.params} 44 | 45 | input = jnp.concatenate([x, y], axis=-1) 46 | 47 | return state.apply_fn(variables, input, t * 1000, prompt_embeds, train=False) 48 | 49 | 50 | @partial(jax.jit, static_argnames="log_snr_fn") 51 | def sample_step( 52 | rng, 53 | state, 54 | x, 55 | y, 56 | prompt_embeds, 57 | uncond_y, 58 | uncond_prompt_embeds, 59 | t, 60 | t_next, 61 | log_snr_fn, 62 | context_w, 63 | prompt_w, 64 | eta, 65 | ): 66 | """Runs a sampling step. 67 | 68 | Derived from a combination of 69 | https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py, 70 | the appendix of the VDM paper, the appendix of the Imagen paper, and the DDIM 71 | paper. 72 | 73 | Args: 74 | rng: a JAX PRNGKey. 75 | state: an EmaTrainState instance. 76 | x: the input image (batched). 77 | y: the image context (batched); uncond_y is its unconditional (all-zero) counterpart. 78 | prompt_embeds: the prompt embeddings (batched); uncond_prompt_embeds are the embeddings of the empty prompt. 79 | t: the current timestep (t). 80 | t_next: the next timestep (s, with s < t). 81 | log_snr_fn: a function that takes a timestep t ~ [0, 1] and returns the log SNR. 82 | context_w: the classifier-free guidance weight for the image context. 83 | prompt_w: the classifier-free guidance weight for the text prompt. 1.0 for both uses the conditional model only. 84 | eta: the DDIM eta parameter. 0.0 is fully deterministic, 1.0 is ancestral 85 | sampling. 
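The three predictions (unconditional, context-conditioned, and fully conditioned) are combined as pred_eps = uncond_pred + context_w * (context_pred - uncond_pred) + prompt_w * (prompt_pred - context_pred), so context_w = prompt_w = 1.0 recovers the fully conditional prediction. 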
86 | """ 87 | assert len(x.shape) == 4 88 | assert len(y.shape) == 4 89 | 90 | batched_t = jnp.full(x.shape[0], t) 91 | batched_t_next = jnp.full(x.shape[0], t_next) 92 | 93 | uncond_pred = model_predict(state, x, uncond_y, uncond_prompt_embeds, batched_t) 94 | context_pred = model_predict(state, x, y, uncond_prompt_embeds, batched_t) 95 | prompt_pred = model_predict(state, x, y, prompt_embeds, batched_t) 96 | 97 | pred_eps = ( 98 | uncond_pred 99 | + context_w * (context_pred - uncond_pred) 100 | + prompt_w * (prompt_pred - context_pred) 101 | ) 102 | 103 | # compute log snr 104 | log_snr = log_snr_fn(batched_t) 105 | log_snr_next = log_snr_fn(batched_t_next) 106 | 107 | # signs are flipped from VDM paper, see q_sample above 108 | alpha_t_sq = jax.nn.sigmoid(log_snr)[:, None, None, None] 109 | alpha_s_sq = jax.nn.sigmoid(log_snr_next)[:, None, None, None] 110 | sigma_t_sq = jax.nn.sigmoid(-log_snr)[:, None, None, None] 111 | sigma_s_sq = jax.nn.sigmoid(-log_snr_next)[:, None, None, None] 112 | 113 | # this constant from A.4 of the VDM paper is equal to sigma_{t|s}^2 / sigma_t^2 114 | c = -jnp.expm1(log_snr - log_snr_next)[:, None, None, None] 115 | 116 | # this is equivalent to the posterior stddev in ancestral sampling 117 | d = jnp.sqrt(sigma_s_sq * c) 118 | # DDIM scales this by eta 119 | d = eta * d 120 | 121 | # fresh noise 122 | noise = jax.random.normal(rng, x.shape) 123 | 124 | # get predicted x0 125 | x_0 = (x - jnp.sqrt(sigma_t_sq) * pred_eps) / jnp.sqrt(alpha_t_sq) 126 | 127 | # clip it -- removed bc latent space is not bounded 128 | # x_0 = jnp.clip(x_0, -1, 1) 129 | 130 | # compute x_s using DDIM formula 131 | x_s = ( 132 | jnp.sqrt(alpha_s_sq) * x_0 133 | + jnp.sqrt(sigma_s_sq - d**2) * pred_eps 134 | + d * noise 135 | ) 136 | return x_s, x_0 137 | 138 | 139 | @partial(jax.jit, static_argnames=("num_timesteps", "log_snr_fn")) 140 | def sample_loop( 141 | rng, 142 | state, 143 | y, 144 | prompt_embeds, 145 | uncond_y, 146 | uncond_prompt_embeds, 147 | *, 148 | log_snr_fn, 149 | num_timesteps, 150 | context_w=1.0, 151 | prompt_w=1.0, 152 | eta=0.0, 153 | ): 154 | """Runs the full sampling loop. 155 | 156 | Implements the following loop using a scan: 157 | 158 | ``` 159 | for t, t_next in reversed(zip(t_seq, t_seq_next)): 160 | rng, step_rng = jax.random.split(rng) 161 | x, x0 = sample_step(x, x0, ...) 162 | return x0 163 | ``` 164 | 165 | Args: 166 | rng: a JAX PRNGKey. 167 | state: an EmaTrainState instance. 168 | y: the image context (batched). 169 | prompt_embeds: the text prompt embeddings (batched). 170 | log_snr_fn: a function that takes a timestep t ~ [0, 1] and returns the log SNR. 171 | num_timesteps: the number of timesteps to run. 172 | w: the weight to use for classifier-free guidance. 1.0 (default) uses the 173 | conditional model only. 174 | eta: the DDIM eta parameter. 0.0 (default) is full deterministic, 1.0 is 175 | ancestral sampling. 
176 | """ 177 | assert len(y.shape) == 4 178 | 179 | def scan_fn(carry, t_combined): 180 | rng, x, x0 = carry 181 | t, t_next = t_combined 182 | rng, step_rng = jax.random.split(rng) 183 | x, x0 = sample_step( 184 | step_rng, 185 | state, 186 | x, 187 | y, 188 | prompt_embeds, 189 | uncond_y, 190 | uncond_prompt_embeds, 191 | t, 192 | t_next, 193 | log_snr_fn, 194 | context_w, 195 | prompt_w, 196 | eta, 197 | ) 198 | return (rng, x, x0), None 199 | 200 | if y.shape[-1] % 4 == 0: 201 | # vae-encoded 202 | channel_dim = 4 203 | elif y.shape[-1] % 3 == 0: 204 | # full images 205 | channel_dim = 3 206 | else: 207 | raise ValueError(f"Invalid channel dimension {y.shape[-1]}") 208 | 209 | rng, init_rng = jax.random.split(rng) 210 | x = jax.random.normal( 211 | init_rng, y.shape[:-1] + (channel_dim,) 212 | ) # initial image (pure noise) 213 | x0 = jnp.zeros_like(x) # unused 214 | 215 | # evenly spaced sequence of timesteps 216 | t_seq = jnp.linspace(0, 1, num=num_timesteps, endpoint=False, dtype=jnp.float32) 217 | t_seq_cur = t_seq[1:] 218 | t_seq_next = t_seq[:-1] 219 | t_seq_combined = jnp.stack([t_seq_cur, t_seq_next], axis=-1) 220 | 221 | rng, scan_rng = jax.random.split(rng) 222 | (_, _, final_x0), _ = jax.lax.scan( 223 | scan_fn, (scan_rng, x, x0), t_seq_combined[::-1], unroll=1 224 | ) 225 | 226 | return final_x0 227 | -------------------------------------------------------------------------------- /susie/scheduling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | # continuous timestep scheduling adapted from 5 | # https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 6 | # 7 | # this code follows the conventions and notations from the Variation Diffusion Models (VDM) 8 | # paper, where the timestep t is a continuous variable in [0, 1] that is mapped to a log SNR 9 | # lambda using a schedule. the log SNR is what the UNet is conditioned on, and is also used 10 | # to compute the mean and stdev (alpha and sigma) of the diffusion process. I believe 11 | # most of the particular formulas (e.g. the usage of expm1) are from the appendix of the 12 | # VDM paper. 13 | 14 | 15 | def lnpoch(a, b): 16 | # Computes the log of the rising factorial (a)_b in a numerically stable way when a >> b. 
17 | # From https://stackoverflow.com/questions/21228076/the-precision-of-scipy-special-gammaln 18 | # lmao 19 | return ( 20 | (b**4 / 12 - b**3 / 6 + b**2 / 12) / a**3 21 | + (-(b**3) / 6 + b**2 / 4 - b / 12) / a**2 22 | + (b**2 / 2 - b / 2) / a 23 | - b * jnp.log(1 / a) 24 | ) 25 | 26 | 27 | def linear_log_snr(t, *, beta_start=0.001, beta_end=0.02, num_timesteps=1000): 28 | """Computes log SNR from t ~ [0, 1] for a linear beta schedule.""" 29 | m = (beta_end - beta_start) / num_timesteps 30 | b = 1 - beta_start 31 | n = t * num_timesteps 32 | 33 | log_alpha_sq = (n + 1) * jnp.log(m) + lnpoch(b / m - n, n + 1) 34 | return jax.scipy.special.logit(jnp.exp(log_alpha_sq)) 35 | 36 | 37 | def scaled_linear_log_snr(t, *, beta_start=0.00085, beta_end=0.012, num_timesteps=1000): 38 | """Computes log SNR from t ~ [0, 1] for a scaled (sqrt) linear beta schedule, as used in stable diffusion.""" 39 | m = (beta_end**0.5 - beta_start**0.5) / num_timesteps 40 | b = beta_start**0.5 41 | n = t * num_timesteps 42 | 43 | fact = lnpoch((1 - b) / m - n, n + 1) + lnpoch((1 + b) / m, n + 1) 44 | pow = 2 * (n + 1) * jnp.log(m) 45 | alpha_sq = jnp.exp(fact + pow) 46 | return jax.scipy.special.logit(alpha_sq) 47 | 48 | 49 | def cosine_log_snr(t, s: float = 0.008, d: float = 0.008): 50 | """Computes log SNR from t ~ [0, 1] for a cosine beta schedule. 51 | 52 | In the original Improved DDPM paper, they add an offset of s=0.008 on the 53 | *left* side of the schedule, because they found it was hard for the NN to 54 | predict very small amounts of noise. Without this offset we would have 55 | alpha=1 and sigma=0 at t=0 and hence log_snr=+inf. However, they leave the 56 | singularity on the *right* side of the schedule: i.e. at t=1, alpha=0 and 57 | sigma=1, so log_snr=-inf. They deal with this singularity by clipping beta 58 | to a maximum value of 0.999. The problem is that in this formulation we 59 | don't directly calculate alpha or beta -- instead we define the schedule in 60 | terms of log_snr and calculate all other relevant quantities from that. So, 61 | to deal with the singularity at t=1, I'm adding a symmetrical offset of 62 | d=0.008 on the right side of the schedule, so that log_snr is finite at t=1. 63 | I've never seen this anywhere, but hopefully it works :). 64 | """ 65 | return -jnp.log((jnp.cos(((t / (1 + d)) + s) / (1 + s) * jnp.pi * 0.5) ** -2) - 1) 66 | 67 | 68 | def create_log_snr_fn(config): 69 | """ 70 | Returns a function that maps from t ~ [0, 1] to lambda (log SNR). The log SNR is 71 | used to condition the neural network as well as compute the mean and stdev of the 72 | diffusion process. 73 | """ 74 | schedule_name = config["noise_schedule"] 75 | 76 | if schedule_name == "linear": 77 | log_snr_fn = linear_log_snr 78 | elif schedule_name == "cosine": 79 | log_snr_fn = cosine_log_snr 80 | elif schedule_name == "scaled_linear": 81 | log_snr_fn = scaled_linear_log_snr 82 | else: 83 | raise ValueError(f"unknown noise schedule {schedule_name}") 84 | 85 | return log_snr_fn 86 | 87 | 88 | def create_ema_decay_fn(config): 89 | def ema_decay_schedule(step): 90 | count = jnp.clip(step - config.start_step - 1, a_min=0.0) 91 | value = 1 - (1 + count / config.inv_gamma) ** -config.power 92 | ema_rate = jnp.clip(value, a_min=config.min_decay, a_max=config.max_decay) 93 | return ema_rate 94 | 95 | return ema_decay_schedule 96 | --------------------------------------------------------------------------------