├── .gitignore ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs └── metaworld.py ├── experiments ├── __init__.py ├── run_oracle.py ├── train_furl.py ├── train_liv.py ├── train_relay.py └── train_sac.py ├── main.py ├── models ├── __init__.py ├── common.py ├── furl.py ├── liv.py ├── projection.py ├── sac.py └── vlm.py ├── requirements.txt ├── scripts └── run.sh └── utils ├── __init__.py ├── buffer_utils.py ├── env_utils.py ├── liv_utils.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | data 4 | logs* 5 | imgs* 6 | saved_videos* 7 | saved_models* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FuRL 2 | 3 | ## Environment Setup 4 | 5 | Install the conda environment via: 6 | 7 | ```shell 8 | conda create --name furl python==3.11 9 | conda activate furl 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Training 14 | 15 | ### Generating Expert Dataset 16 | 17 | An optional setting in FuRL is to use a goal image to accelerate exploration before the first successful trajectory has been collected. Generate the oracle trajectory via: 18 | 19 | ```shell 20 | python main.py --config.env_name=door-open-v2-goal-hidden --config.exp_name=oracle 21 | ``` 22 | 23 | The oracle trajectory data will be saved in `data/oracle`. 24 | 25 | ### Example on Fixed-goal Task 26 | 27 | ```shell 28 | python main.py --config.env_name=door-open-v2-goal-hidden --config.exp_name=furl 29 | ``` 30 | 31 | ### Example on Random-goal Task 32 | 33 | ```shell 34 | python main.py --config.env_name=door-open-v2-goal-observable --config.exp_name=furl 35 | ``` 36 | 37 | ## Paper 38 | 39 | [**FuRL: Visual-Language Models as Fuzzy Rewards for Reinforcement Learning**](https://arxiv.org/pdf/2406.00645) 40 | 41 | Yuwei Fu, Haichao Zhang, Di Wu, Wei Xu, Benoit Boulet 42 | 43 | *International Conference on Machine Learning* (ICML), 2024 44 | 45 | ## Cite 46 | 47 | Please cite our work if you find it useful: 48 | 49 | ```txt 50 | @InProceedings{fu2024, 51 | title = {FuRL: Visual-Language Models as Fuzzy Rewards for Reinforcement Learning}, 52 | author = {Yuwei Fu and Haichao Zhang and Di Wu and Wei Xu and Benoit Boulet}, 53 | booktitle = {Proceedings of the 41st International Conference on Machine Learning}, 54 | year = {2024} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuyw/FuRL/e3c4c45cf674960c8fb8bcdfcc567189455e260f/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import 
build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC, antialias=None), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP 
models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 95 | jit: bool = False, download_root: str = None, scratch=False): 96 | """Load a CLIP model 97 | 98 | Parameters 99 | ---------- 100 | name : str 101 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 102 | 103 | device : Union[str, torch.device] 104 | The device to put the loaded model 105 | 106 | jit : bool 107 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 108 | 109 | download_root: str 110 | path to download the model files; by default, it uses "~/.cache/clip" 111 | 112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | 117 | preprocess : Callable[[PIL.Image], torch.Tensor] 118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 119 | """ 120 | if name in _MODELS: 121 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 122 | elif os.path.isfile(name): 123 | model_path = name 124 | else: 125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 126 | 127 | with open(model_path, 'rb') as opened_file: 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 136 | jit = False 137 | state_dict = torch.load(opened_file, map_location="cpu") 138 | 139 | if not jit: 140 | model = build_model(state_dict or model.state_dict(), scratch=scratch).to(device) 141 | if str(device) == "cpu": 142 | model.float() 143 | return model, _transform(model.visual.input_resolution) 144 | 145 | # patch the device names 146 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 147 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 148 | 149 | def patch_device(module): 150 | try: 151 | graphs = [module.graph] if hasattr(module, "graph") else [] 152 | except RuntimeError: 153 | graphs = [] 154 | 155 | if hasattr(module, "forward1"): 156 | graphs.append(module.forward1.graph) 157 | 158 | for graph in graphs: 159 | for node in graph.findAllNodes("prim::Constant"): 160 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 161 | node.copyAttributes(device_node) 162 | 163 | model.apply(patch_device) 164 | patch_device(model.encode_image) 165 | patch_device(model.encode_text) 166 | 167 | # patch dtype to float32 on CPU 168 | if str(device) == "cpu": 169 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 170 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 171 | float_node = float_input.node() 172 | 173 | def patch_float(module): 174 | try: 175 | graphs = [module.graph] if hasattr(module, "graph") else [] 176 | except RuntimeError: 177 | graphs = [] 178 | 179 | if hasattr(module, "forward1"): 180 | graphs.append(module.forward1.graph) 181 | 182 | for graph in graphs: 183 | for node in graph.findAllNodes("aten::to"): 184 | inputs = list(node.inputs()) 185 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 186 | if 
inputs[i].node()["value"] == 5: 187 | inputs[i].node().copyAttributes(float_node) 188 | 189 | model.apply(patch_float) 190 | patch_float(model.encode_image) 191 | patch_float(model.encode_text) 192 | 193 | model.float() 194 | 195 | return model, _transform(model.input_resolution.item()) 196 | 197 | 198 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 199 | """ 200 | Returns the tokenized representation of given input string(s) 201 | 202 | Parameters 203 | ---------- 204 | texts : Union[str, List[str]] 205 | An input string or a list of input strings to tokenize 206 | 207 | context_length : int 208 | The context length to use; all CLIP models use 77 as the context length 209 | 210 | truncate: bool 211 | Whether to truncate the text in case its encoding is longer than the context length 212 | 213 | Returns 214 | ------- 215 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 216 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 217 | """ 218 | if isinstance(texts, str): 219 | texts = [texts] 220 | 221 | sot_token = _tokenizer.encoder["<|startoftext|>"] 222 | eot_token = _tokenizer.encoder["<|endoftext|>"] 223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 224 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 225 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 226 | else: 227 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 228 | 229 | for i, tokens in enumerate(all_tokens): 230 | if len(tokens) > context_length: 231 | if truncate: 232 | tokens = tokens[:context_length] 233 | tokens[-1] = eot_token 234 | else: 235 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 236 | result[i, :len(tokens)] = torch.tensor(tokens) 237 | 238 | return result 239 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | out = self.relu1(self.bn1(self.conv1(x))) 45 | out = self.relu2(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu3(out) 54 | return out 55 | 56 | 57 | class AttentionPool2d(nn.Module): 58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 59 | super().__init__() 60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 61 | self.k_proj = nn.Linear(embed_dim, embed_dim) 62 | self.q_proj = nn.Linear(embed_dim, embed_dim) 63 | self.v_proj = nn.Linear(embed_dim, embed_dim) 64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 65 | self.num_heads = num_heads 66 | 67 | def forward(self, x): 68 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 71 | x, _ = F.multi_head_attention_forward( 72 | query=x[:1], key=x, value=x, 73 | embed_dim_to_check=x.shape[-1], 74 | num_heads=self.num_heads, 75 | q_proj_weight=self.q_proj.weight, 76 | k_proj_weight=self.k_proj.weight, 77 | v_proj_weight=self.v_proj.weight, 78 | in_proj_weight=None, 79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 80 | bias_k=None, 81 | bias_v=None, 82 | add_zero_attn=False, 83 | dropout_p=0, 84 | out_proj_weight=self.c_proj.weight, 85 | out_proj_bias=self.c_proj.bias, 86 | use_separate_proj_weight=True, 87 | training=self.training, 88 | need_weights=False 89 | ) 90 | return x.squeeze(0) 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.relu1 = nn.ReLU(inplace=True) 110 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(width // 2) 112 | self.relu2 = nn.ReLU(inplace=True) 113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(width) 115 | self.relu3 = nn.ReLU(inplace=True) 116 | self.avgpool = nn.AvgPool2d(2) 117 | 118 | # residual layers 119 | self._inplanes = width # this is a *mutable* variable used during construction 120 | self.layer1 = self._make_layer(width, layers[0]) 121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 124 | 125 | embed_dim = width * 32 # the ResNet feature dimension 126 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 127 | 128 | def _make_layer(self, planes, blocks, stride=1): 129 | layers = [Bottleneck(self._inplanes, planes, stride)] 130 | 131 | self._inplanes = planes * Bottleneck.expansion 132 | for _ in range(1, blocks): 133 | layers.append(Bottleneck(self._inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | def stem(x): 139 | x = self.relu1(self.bn1(self.conv1(x))) 140 | x = self.relu2(self.bn2(self.conv2(x))) 141 | x = self.relu3(self.bn3(self.conv3(x))) 142 | x = self.avgpool(x) 143 | return x 144 | 145 | x = x.type(self.conv1.weight.dtype) 146 | x = stem(x) 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x = self.layer4(x) 151 | x = self.attnpool(x) 152 | 153 | return x 154 | 155 | 156 | class LayerNorm(nn.LayerNorm): 157 | """Subclass torch's LayerNorm to handle fp16.""" 158 | 159 | def forward(self, x: torch.Tensor): 160 | orig_type = x.dtype 161 | ret = super().forward(x.type(torch.float32)) 162 | return ret.type(orig_type) 163 | 164 | 165 | class QuickGELU(nn.Module): 166 | def forward(self, x: torch.Tensor): 167 | return x * torch.sigmoid(1.702 * x) 168 | 169 | 170 | class ResidualAttentionBlock(nn.Module): 171 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 172 | super().__init__() 173 | 174 | self.attn = nn.MultiheadAttention(d_model, n_head) 175 | self.ln_1 = LayerNorm(d_model) 176 | self.mlp = nn.Sequential(OrderedDict([ 177 | ("c_fc", nn.Linear(d_model, d_model * 4)), 178 | ("gelu", QuickGELU()), 179 | ("c_proj", nn.Linear(d_model * 4, d_model)) 180 | ])) 181 | self.ln_2 = LayerNorm(d_model) 182 | self.attn_mask = attn_mask 183 | 184 | def attention(self, x: torch.Tensor): 185 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 186 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 187 | 188 | def forward(self, x: torch.Tensor): 189 | x = x + self.attention(self.ln_1(x)) 190 | x = x + 
self.mlp(self.ln_2(x)) 191 | return x 192 | 193 | 194 | class Transformer(nn.Module): 195 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 196 | super().__init__() 197 | self.width = width 198 | self.layers = layers 199 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 200 | 201 | def forward(self, x: torch.Tensor): 202 | return self.resblocks(x) 203 | 204 | 205 | class VisionTransformer(nn.Module): 206 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 207 | super().__init__() 208 | self.input_resolution = input_resolution 209 | self.output_dim = output_dim 210 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 211 | 212 | scale = width ** -0.5 213 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 214 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 215 | self.ln_pre = LayerNorm(width) 216 | 217 | self.transformer = Transformer(width, layers, heads) 218 | 219 | self.ln_post = LayerNorm(width) 220 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 221 | 222 | def forward(self, x: torch.Tensor): 223 | x = self.conv1(x) # shape = [*, width, grid, grid] 224 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 225 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 226 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 227 | x = x + self.positional_embedding.to(x.dtype) 228 | x = self.ln_pre(x) 229 | 230 | x = x.permute(1, 0, 2) # NLD -> LND 231 | x = self.transformer(x) 232 | x = x.permute(1, 0, 2) # LND -> NLD 233 | 234 | x = self.ln_post(x[:, 0, :]) 235 | 236 | if self.proj is not None: 237 | x = x @ self.proj 238 | 239 | return x 240 | 241 | 242 | class CLIP(nn.Module): 243 | def __init__(self, 244 | embed_dim: int, 245 | # vision 246 | image_resolution: int, 247 | vision_layers: Union[Tuple[int, int, int, int], int], 248 | vision_width: int, 249 | vision_patch_size: int, 250 | # text 251 | context_length: int, 252 | vocab_size: int, 253 | transformer_width: int, 254 | transformer_heads: int, 255 | transformer_layers: int 256 | ): 257 | super().__init__() 258 | 259 | self.context_length = context_length 260 | 261 | if isinstance(vision_layers, (tuple, list)): 262 | vision_heads = vision_width * 32 // 64 263 | self.visual = ModifiedResNet( 264 | layers=vision_layers, 265 | output_dim=embed_dim, 266 | heads=vision_heads, 267 | input_resolution=image_resolution, 268 | width=vision_width 269 | ) 270 | else: 271 | vision_heads = vision_width // 64 272 | self.visual = VisionTransformer( 273 | input_resolution=image_resolution, 274 | patch_size=vision_patch_size, 275 | width=vision_width, 276 | layers=vision_layers, 277 | heads=vision_heads, 278 | output_dim=embed_dim 279 | ) 280 | 281 | self.transformer = Transformer( 282 | width=transformer_width, 283 | layers=transformer_layers, 284 | heads=transformer_heads, 285 | attn_mask=self.build_attention_mask() 286 | ) 287 | 288 | self.vocab_size = vocab_size 289 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 290 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 291 | self.ln_final = LayerNorm(transformer_width) 292 | 
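# text_projection (below) maps the final text-transformer features into the joint image-text embedding space; logit_scale is a learnable temperature, initialized to log(1/0.07), whose exp() scales the cosine-similarity logits in forward().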
293 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 294 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 295 | 296 | self.initialize_parameters() 297 | 298 | def initialize_parameters(self): 299 | nn.init.normal_(self.token_embedding.weight, std=0.02) 300 | nn.init.normal_(self.positional_embedding, std=0.01) 301 | 302 | if isinstance(self.visual, ModifiedResNet): 303 | if self.visual.attnpool is not None: 304 | std = self.visual.attnpool.c_proj.in_features ** -0.5 305 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 306 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 309 | 310 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 311 | for name, param in resnet_block.named_parameters(): 312 | if name.endswith("bn3.weight"): 313 | nn.init.zeros_(param) 314 | 315 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 316 | attn_std = self.transformer.width ** -0.5 317 | fc_std = (2 * self.transformer.width) ** -0.5 318 | for block in self.transformer.resblocks: 319 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 320 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 321 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 322 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 323 | 324 | if self.text_projection is not None: 325 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 326 | 327 | def build_attention_mask(self): 328 | # lazily create causal attention mask, with full attention between the vision tokens 329 | # pytorch uses additive attention mask; fill with -inf 330 | mask = torch.empty(self.context_length, self.context_length) 331 | mask.fill_(float("-inf")) 332 | mask.triu_(1) # zero out the lower diagonal 333 | return mask 334 | 335 | @property 336 | def dtype(self): 337 | return self.visual.conv1.weight.dtype 338 | 339 | def encode_image(self, image): 340 | return self.visual(image.type(self.dtype)) 341 | 342 | def encode_text(self, text): 343 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 344 | x = x + self.positional_embedding.type(self.dtype) 345 | x = x.permute(1, 0, 2) # NLD -> LND 346 | x = self.transformer(x) 347 | x = x.permute(1, 0, 2) # LND -> NLD 348 | x = self.ln_final(x).type(self.dtype) 349 | 350 | # x.shape = [batch_size, n_ctx, transformer.width] 351 | # take features from the eot embedding (eot_token is the highest number in each sequence) 352 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 353 | return x 354 | 355 | def forward(self, image, text): 356 | image_features = self.encode_image(image) 357 | text_features = self.encode_text(text) 358 | 359 | # normalized features 360 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 361 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 362 | 363 | # cosine similarity as logits 364 | logit_scale = self.logit_scale.exp() 365 | logits_per_image = logit_scale * image_features @ text_features.t() 366 | logits_per_text = logits_per_image.t() 367 | 368 | # shape = [global_batch_size, global_batch_size] 369 | return logits_per_image, logits_per_text 370 | 371 | 372 | def convert_weights(model: nn.Module): 373 | """Convert applicable model parameters to fp16""" 374 | 375 
| def _convert_weights_to_fp16(l): 376 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 377 | l.weight.data = l.weight.data.half() 378 | if l.bias is not None: 379 | l.bias.data = l.bias.data.half() 380 | 381 | if isinstance(l, nn.MultiheadAttention): 382 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 383 | tensor = getattr(l, attr) 384 | if tensor is not None: 385 | tensor.data = tensor.data.half() 386 | 387 | for name in ["text_projection", "proj"]: 388 | if hasattr(l, name): 389 | attr = getattr(l, name) 390 | if attr is not None: 391 | attr.data = attr.data.half() 392 | 393 | model.apply(_convert_weights_to_fp16) 394 | 395 | 396 | def build_model(state_dict: dict, scratch=False): 397 | vit = "visual.proj" in state_dict 398 | 399 | if vit: 400 | vision_width = state_dict["visual.conv1.weight"].shape[0] 401 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 402 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 403 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 404 | image_resolution = vision_patch_size * grid_size 405 | else: 406 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 407 | vision_layers = tuple(counts) 408 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 409 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 410 | vision_patch_size = None 411 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 412 | image_resolution = output_width * 32 413 | 414 | embed_dim = state_dict["text_projection"].shape[1] 415 | context_length = state_dict["positional_embedding"].shape[0] 416 | vocab_size = state_dict["token_embedding.weight"].shape[0] 417 | transformer_width = state_dict["ln_final.weight"].shape[0] 418 | transformer_heads = transformer_width // 64 419 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 420 | 421 | model = CLIP( 422 | embed_dim, 423 | image_resolution, vision_layers, vision_width, vision_patch_size, 424 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 425 | ) 426 | 427 | for key in ["input_resolution", "context_length", "vocab_size"]: 428 | if key in state_dict: 429 | del state_dict[key] 430 | 431 | convert_weights(model) 432 | 433 | if not scratch: 434 | model.load_state_dict(state_dict) 435 | return model.eval() 436 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] 
for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/metaworld.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | config.env_name = "peg-insert-side-v2-goal-observable" 8 | config.camera_id = 2 9 | config.residual = False 10 | config.eval_episodes = 100 11 | config.start_timesteps = 10000 12 | config.max_timesteps = int(1e6) 13 | config.decay_timesteps = int(7.5e5) 14 | config.eval_freq = config.max_timesteps // 10 15 | config.log_freq = config.max_timesteps // 100 16 | config.ckpt_freq = config.max_timesteps // 10 17 | config.lr = 1e-4 18 | config.seed = 0 19 | config.tau = 0.01 20 | config.gamma = 0.99 21 | config.batch_size = 256 22 | config.hidden_dims = (256, 256) 23 | config.initializer = "orthogonal" 24 | config.exp_name = "furl" 25 | 26 | # relay 27 | config.relay_threshold = 2500 28 | config.expl_noise = 0.2 29 | 30 | # fine-tune 31 | config.rho = 0.05 32 | config.gap = 10 33 | config.crop = False 34 | config.l2_margin = 0.25 35 | config.cosine_margin = 0.25 36 | config.embed_buffer_size = 20000 37 | 38 | return config 39 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from . import run_oracle 2 | from . import train_liv 3 | from . import train_sac 4 | from . import train_relay 5 | from . import train_furl -------------------------------------------------------------------------------- /experiments/run_oracle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import imageio 5 | import numpy as np 6 | import gymnasium as gym 7 | import ml_collections 8 | from utils import make_env, load_liv, TASKS 9 | 10 | 11 | from metaworld.policies import (SawyerButtonPressTopdownV2Policy, 12 | SawyerDrawerOpenV2Policy, 13 | SawyerDrawerCloseV2Policy, 14 | SawyerReachV2Policy, 15 | SawyerDoorOpenV2Policy, 16 | SawyerPushV2Policy, 17 | SawyerWindowOpenV2Policy, 18 | SawyerWindowCloseV2Policy, 19 | SawyerBasketballV2Policy, 20 | SawyerPegInsertionSideV2Policy, 21 | SawyerPickPlaceV2Policy, 22 | SawyerSweepV2Policy) 23 | 24 | 25 | ORACLE_POLICY = { 26 | "button-press-topdown-v2-goal": SawyerButtonPressTopdownV2Policy, 27 | "drawer-open-v2-goal": SawyerDrawerOpenV2Policy, 28 | "drawer-close-v2-goal": SawyerDrawerCloseV2Policy, 29 | "reach-v2-goal": SawyerReachV2Policy, 30 | "door-open-v2-goal": SawyerDoorOpenV2Policy, 31 | "push-v2-goal": SawyerPushV2Policy, 32 | "window-open-v2-goal": SawyerWindowOpenV2Policy, 33 | "window-close-v2-goal": SawyerWindowCloseV2Policy, 34 | "basketball-v2-goal": SawyerBasketballV2Policy, 35 | "peg-insert-side-v2-goal": SawyerPegInsertionSideV2Policy, 36 | "pick-place-v2-goal": SawyerPickPlaceV2Policy, 37 | "sweep-v2-goal": SawyerSweepV2Policy, 38 | } 39 | 40 | 41 | def eval_policy(policy, 42 | env, 43 | camera_id: int = 2, 44 | eval_episodes: int = 1, 45 | video_dir: str = None, 46 | traj_dir: str = None): 47 | t1 = time.time() 48 | eval_reward, eval_success, avg_step = 0, 0, 0 49 | frames, success, states = [], [], [] 50 | for i in range(1, eval_episodes + 1): 51 | obs, _ = env.reset() 52 | states = [obs] 53 | if 
video_dir and i == eval_episodes: 54 | frame = env.mujoco_renderer.render(render_mode="rgb_array", 55 | camera_id=camera_id) 56 | frames.append(frame[::-1]) 57 | success.append(0.0) 58 | while True: 59 | avg_step += 1 60 | action = policy.get_action(obs) 61 | obs, reward, terminated, truncated, info = env.step(action) 62 | states.append(obs) 63 | eval_success += info["success"] 64 | success.append(info["success"]) 65 | if video_dir and i == eval_episodes: 66 | frame = env.mujoco_renderer.render(render_mode="rgb_array", 67 | camera_id=camera_id) 68 | frames.append(frame[::-1]) 69 | eval_reward += reward 70 | if terminated or truncated: 71 | break 72 | eval_reward /= eval_episodes 73 | eval_success /= eval_episodes 74 | 75 | if video_dir: 76 | imageio.mimsave(f"{video_dir}_{eval_reward:.0f}.mp4", frames, fps=60) 77 | 78 | success = np.array(success) 79 | images = np.array(frames, dtype=np.uint8) 80 | np.savez(traj_dir, images=images, success=success) 81 | 82 | return eval_reward, eval_success, avg_step, time.time() - t1, images 83 | 84 | 85 | def evaluate(config: ml_collections.ConfigDict): 86 | start_time = time.time() 87 | 88 | # initialize the environment 89 | traj_dir = f"data/oracle/{config.env_name}/s{config.seed}_c{config.camera_id}" 90 | video_dir = f"saved_videos/oracle/{config.env_name}" 91 | os.makedirs(f"saved_videos/oracle", exist_ok=True) 92 | os.makedirs(f"data/oracle/{config.env_name}", exist_ok=True) 93 | 94 | config.unlock() 95 | if config.env_name == "reach-v2-goal-hidden": 96 | config.env_name = "reach-v2-goal-observable" 97 | elif config.env_name == "push-v2-goal-hidden": 98 | config.env_name = "push-v2-goal-observable" 99 | elif config.env_name == "pick-place-v2-goal-hidden": 100 | config.env_name = "pick-place-v2-goal-observable" 101 | elif config.env_name == "peg-insert-side-v2-goal-hidden": 102 | config.env_name = "peg-insert-side-v2-goal-observable" 103 | env = make_env(config.env_name, 104 | seed=config.seed, 105 | image_size=480, 106 | camera_id=config.camera_id) 107 | 108 | policy = ORACLE_POLICY["-".join(config.env_name.split("-")[:-1])]() 109 | eval_reward, eval_success, _, _, images = eval_policy(policy, 110 | env, 111 | camera_id=config.camera_id, 112 | video_dir=video_dir, 113 | traj_dir=traj_dir) 114 | print(f"{config.env_name}: eval_reward={eval_reward:.2f}, eval_success={eval_success:.0f}") 115 | -------------------------------------------------------------------------------- /experiments/train_furl.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".4" 3 | 4 | import cv2 5 | import time 6 | import clip 7 | import optax 8 | import imageio 9 | import ml_collections 10 | import gymnasium as gym 11 | import numpy as np 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | 15 | import torch 16 | import torchvision.transforms as T 17 | 18 | from tqdm import trange 19 | from models import SACAgent, FuRLAgent, RewardModel 20 | from utils import (TASKS, DistanceBuffer, EmbeddingBuffer, log_git, 21 | get_logger, make_env, load_liv) 22 | 23 | 24 | ################### 25 | # Utils Functions # 26 | ################### 27 | def crop_center(config, image): 28 | x1, x2, y1, y2 = 32, 224, 32, 224 29 | return image[x1:x2, y1:y2, :] 30 | 31 | 32 | def eval_policy(agent: SACAgent, 33 | env: gym.Env, 34 | eval_episodes: int = 10): 35 | t1 = time.time() 36 | eval_reward, eval_success, avg_step = 0, 0, 0 37 | for i in range(1, eval_episodes + 1): 38 | 
obs, _ = env.reset() 39 | while True: 40 | avg_step += 1 41 | action = agent.sample_action(obs, eval_mode=True) 42 | obs, reward, terminated, truncated, info = env.step(action) 43 | eval_reward += reward 44 | if terminated or truncated: 45 | eval_success += info["success"] 46 | break 47 | 48 | eval_reward /= eval_episodes 49 | eval_success /= eval_episodes 50 | 51 | return eval_reward, eval_success, avg_step, time.time() - t1 52 | 53 | 54 | def setup_logging(config): 55 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 56 | 57 | # logging 58 | exp_prefix = f"furl_rho{config.rho}" 59 | exp_name = f"{exp_prefix}/{config.env_name}/s{config.seed}_{timestamp}" 60 | os.makedirs(f"logs/{exp_prefix}/{config.env_name}", exist_ok=True) 61 | exp_info = f"# Running experiment for: {exp_name} #" 62 | print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info)) 63 | logger = get_logger(f"logs/{exp_name}.log") 64 | 65 | # add git commit info 66 | log_git(config) 67 | logger.info(f"Config:\n{config}\n") 68 | 69 | # set random seed 70 | np.random.seed(config.seed) 71 | torch.manual_seed(config.seed) 72 | 73 | return exp_name, logger 74 | 75 | 76 | def setup_exp(config): 77 | # liv 78 | transform = T.Compose([T.ToTensor()]) 79 | liv = load_liv() 80 | 81 | # task description embedding 82 | with torch.no_grad(): 83 | token = clip.tokenize([TASKS[config.env_name]]) 84 | text_embedding = liv(input=token, modality="text") 85 | text_embedding = text_embedding.detach().cpu().numpy() 86 | data = np.load(f"data/oracle/{config.env_name}/s0_c{config.camera_id}.npz") 87 | 88 | # goal_embedding / text_embedding 89 | oracle_images = data["images"] 90 | oracle_success = data["success"] 91 | oracle_traj_len = np.where(oracle_success)[0][0] + 1 # 84 92 | 93 | # initialize the environment 94 | env = make_env(config.env_name, 95 | seed=config.seed, 96 | camera_id=config.camera_id) 97 | eval_seed = config.seed if "hidden" in config.env_name else config.seed+100 98 | eval_env = make_env(config.env_name, 99 | seed=eval_seed, 100 | image_size=480, 101 | camera_id=config.camera_id) 102 | 103 | # environment parameter 104 | obs_dim = env.observation_space.shape[0] 105 | act_dim = env.action_space.shape[0] 106 | max_action = env.action_space.high[0] 107 | goal_image = data["images"][oracle_traj_len-1] 108 | goal_image = crop_center(config, goal_image) 109 | processed_goal_image = cv2.cvtColor(goal_image, cv2.COLOR_RGB2BGR) 110 | processed_goal_image = transform(processed_goal_image) 111 | goal_embedding = liv(input=processed_goal_image.to("cuda")[None], modality="vision") 112 | goal_embedding = goal_embedding.detach().cpu().numpy() 113 | 114 | # fixed LIV representation projection 115 | vlm_agent = FuRLAgent(obs_dim=obs_dim, 116 | act_dim=act_dim, 117 | max_action=max_action, 118 | seed=config.seed, 119 | tau=config.tau, 120 | rho=config.rho, 121 | margin=config.cosine_margin, 122 | gamma=config.gamma, 123 | lr=config.lr, 124 | text_embedding=text_embedding, 125 | goal_embedding=goal_embedding, 126 | hidden_dims=config.hidden_dims) 127 | 128 | # SAC agent 129 | sac_agent = SACAgent(obs_dim=obs_dim, 130 | act_dim=act_dim, 131 | max_action=max_action, 132 | seed=config.seed, 133 | tau=config.tau, 134 | gamma=config.gamma, 135 | lr=config.lr, 136 | hidden_dims=config.hidden_dims) 137 | 138 | # Initialize the reward model 139 | reward_model = RewardModel(seed=config.seed, 140 | text_embedding=text_embedding, 141 | goal_embedding=goal_embedding) 142 | 143 | # Replay buffer 144 | replay_buffer = 
DistanceBuffer(obs_dim=obs_dim, 145 | act_dim=act_dim, 146 | max_size=int(5e5)) 147 | 148 | return ( 149 | transform, 150 | liv, 151 | env, 152 | eval_env, 153 | vlm_agent, 154 | sac_agent, 155 | reward_model, 156 | replay_buffer, 157 | goal_image, 158 | ) 159 | 160 | 161 | ################# 162 | # Main Function # 163 | ################# 164 | def train_and_evaluate(config: ml_collections.ConfigDict): 165 | start_time = time.time() 166 | 167 | # logging setup 168 | exp_name, logger = setup_logging(config) 169 | 170 | # experiment setup 171 | (transform, 172 | liv, 173 | env, 174 | eval_env, 175 | vlm_agent, 176 | sac_agent, 177 | reward_model, 178 | replay_buffer, 179 | goal_image) = setup_exp(config) 180 | 181 | # reward for untrained agent 182 | eval_episodes = 1 if "hidden" in config.env_name else 10 183 | eval_reward, eval_success, _, _ = eval_policy(vlm_agent, 184 | eval_env, 185 | eval_episodes) 186 | logs = [{ 187 | "step": 0, 188 | "eval_reward": eval_reward, 189 | "eval_success": eval_success, 190 | }] 191 | 192 | first_success_step = 0 193 | 194 | # trajectory embedding 195 | embedding_buffer = EmbeddingBuffer(emb_dim=1024, 196 | gap=config.gap, 197 | max_size=config.embed_buffer_size) 198 | traj_embeddings = np.zeros((500, 1024)) 199 | traj_success = np.zeros(500) 200 | 201 | # relay freqs 202 | relay_freqs = [50, 100, 150, 200] 203 | relay_freq = np.random.choice(relay_freqs) 204 | logger.info(f"Relay freqs: {relay_freqs}\n") 205 | 206 | # start training 207 | obs, _ = env.reset() 208 | goal = obs[-3:] 209 | reward, ep_task_reward, ep_vlm_reward = 0, 0, 0 210 | success_cnt, ep_num, ep_step = 0, 0, 0 211 | lst_ep_step, lst_ep_task_reward, lst_ep_vlm_reward = 0, 0, 0 212 | sac_step, vlm_step = 0, 0 213 | lst_sac_step, lst_vlm_step = 0, 0 214 | policies = ["vlm", "sac"] 215 | use_relay = True 216 | pos_cosine = neg_cosine = lag_cosine = 0 217 | pos_cosine_max = neg_cosine_max = lag_cosine_max = 0 218 | pos_cosine_min = neg_cosine_min = lag_cosine_min = 0 219 | neg_num = neg_loss = neg_loss_max = 0 220 | pos_num = pos_loss = pos_loss_max = 0 221 | for t in trange(1, config.max_timesteps + 1): 222 | if t <= config.start_timesteps: 223 | action = env.action_space.sample() 224 | else: 225 | if use_relay: 226 | if policies[(ep_step//relay_freq)%2] == "vlm": 227 | vlm_step += 1 228 | action = vlm_agent.sample_action(obs) 229 | else: 230 | sac_step += 1 231 | action = sac_agent.sample_action(obs) 232 | action_noise = np.random.normal( 233 | 0, sac_agent.max_action*config.expl_noise, size=sac_agent.act_dim) 234 | action = (action + action_noise).clip( 235 | -sac_agent.max_action, sac_agent.max_action) 236 | else: 237 | vlm_step += 1 238 | action = vlm_agent.sample_action(obs) 239 | next_obs, task_reward, terminated, truncated, info = env.step(action) 240 | 241 | # vision language model reward 242 | image = env.mujoco_renderer.render( 243 | render_mode="rgb_array", 244 | camera_id=config.camera_id).copy() 245 | image = image[::-1] 246 | image = crop_center(config, image) 247 | processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 248 | processed_image = transform(processed_image) 249 | with torch.no_grad(): 250 | image_embedding = liv(input=processed_image.to("cuda")[None], modality="vision") 251 | image_embedding = image_embedding.detach().cpu().numpy() 252 | l2_distance = np.square(image_embedding - vlm_agent.goal_embedding).sum(-1)**0.5 253 | vlm_reward = reward_model.get_vlm_reward(reward_model.proj_state, image_embedding).item() 254 | 255 | reward = int(info["success"]) 256 | 
success_cnt += reward 257 | 258 | traj_embeddings[ep_step] = image_embedding 259 | traj_success[ep_step] = reward 260 | ep_step += 1 261 | 262 | if first_success_step == 0 and reward: 263 | first_success_step = ep_step 264 | 265 | # add to buffer 266 | replay_buffer.add(obs, 267 | action, 268 | next_obs, 269 | reward-1, 270 | terminated, 271 | image_embedding, 272 | l2_distance) 273 | obs = next_obs 274 | ep_vlm_reward += vlm_reward 275 | ep_task_reward += task_reward 276 | 277 | # start a new trajectory 278 | if terminated or truncated: 279 | obs, _ = env.reset() 280 | goal = obs[-3:] 281 | lst_ep_step = ep_step 282 | lst_ep_task_reward = ep_task_reward 283 | lst_ep_vlm_reward = ep_vlm_reward 284 | lst_sac_step = sac_step 285 | lst_vlm_step = vlm_step 286 | ep_vlm_reward = 0 287 | ep_task_reward = 0 288 | sac_step = 0 289 | vlm_step = 0 290 | policies = policies[::-1] 291 | relay_freq = np.random.choice(relay_freqs) 292 | 293 | # save embedding 294 | if first_success_step == 0: 295 | for j in range(ep_step): 296 | embedding_buffer.add(embedding=traj_embeddings[j], 297 | success=False) 298 | else: 299 | for j in range(first_success_step): 300 | embedding_buffer.add(embedding=traj_embeddings[j], 301 | success=True, 302 | valid=j>=config.gap) 303 | 304 | for j in range(first_success_step, ep_step): 305 | if traj_success[j]: 306 | embedding_buffer.add(embedding=traj_embeddings[j], 307 | success=True, 308 | valid=j>=config.gap) 309 | else: 310 | break 311 | 312 | ep_step = 0 313 | ep_num += 1 314 | first_success_step = 0 315 | 316 | if use_relay and embedding_buffer.pos_size >= config.relay_threshold: 317 | use_relay = False 318 | 319 | # training 320 | if t > config.start_timesteps: 321 | if (success_cnt > 0) and (embedding_buffer.valid_size > 0): 322 | batch = replay_buffer.sample(config.batch_size) 323 | embedding_batch = embedding_buffer.sample(config.batch_size) 324 | batch_vlm_rewards = reward_model.get_vlm_reward(reward_model.proj_state, 325 | batch.embeddings) 326 | proj_log_info = reward_model.update_pos(embedding_batch) 327 | log_info = vlm_agent.update(batch, batch_vlm_rewards) 328 | pos_cosine = proj_log_info["pos_cosine"] 329 | neg_cosine = proj_log_info["neg_cosine"] 330 | lag_cosine = proj_log_info["lag_cosine"] 331 | pos_cosine_max = proj_log_info["pos_cosine_max"] 332 | neg_cosine_max = proj_log_info["neg_cosine_max"] 333 | lag_cosine_max = proj_log_info["lag_cosine_max"] 334 | pos_cosine_min = proj_log_info["pos_cosine_min"] 335 | neg_cosine_min = proj_log_info["neg_cosine_min"] 336 | lag_cosine_min = proj_log_info["lag_cosine_min"] 337 | neg_num = proj_log_info["neg_num"] 338 | neg_loss = proj_log_info["neg_loss"] 339 | neg_loss_max = proj_log_info["neg_loss_max"] 340 | pos_num = proj_log_info["pos_num"] 341 | pos_loss = proj_log_info["pos_loss"] 342 | pos_loss_max = proj_log_info["pos_loss_max"] 343 | 344 | # collected zero successful trajectory 345 | else: 346 | batch = replay_buffer.sample_with_mask(config.batch_size, config.l2_margin) 347 | proj_log_info = reward_model.update_neg(batch) 348 | batch_vlm_rewards = proj_log_info.pop("vlm_rewards") 349 | log_info = vlm_agent.update(batch, batch_vlm_rewards) 350 | pos_loss = proj_log_info["pos_loss"] 351 | 352 | # update SAC agent 353 | if use_relay: _ = sac_agent.update(batch) 354 | 355 | # eval 356 | if t % config.eval_freq == 0: 357 | eval_reward, eval_success, _, _ = eval_policy(vlm_agent, 358 | eval_env, 359 | eval_episodes) 360 | 361 | # logging 362 | if t % config.log_freq == 0: 363 | if t > config.start_timesteps: 364 
| log_info.update({ 365 | "step": t, 366 | "success": reward, 367 | "task_reward": lst_ep_task_reward, 368 | "vlm_reward": lst_ep_vlm_reward, 369 | "eval_reward": eval_reward, 370 | "eval_success": eval_success, 371 | "batch_reward": batch.rewards.mean(), 372 | "batch_reward_max": batch.rewards.max(), 373 | "batch_reward_min": batch.rewards.min(), 374 | "batch_vlm_reward": batch_vlm_rewards.mean(), 375 | "batch_vlm_reward_max": batch_vlm_rewards.max(), 376 | "batch_vlm_reward_min": batch_vlm_rewards.min(), 377 | "time": (time.time() - start_time) / 60 378 | }) 379 | logger.info( 380 | f"\n[T {t//1000}K][{log_info['time']:.2f} min] " 381 | f"task_reward: {lst_ep_task_reward:.2f}, " 382 | f"vlm_reward: {lst_ep_vlm_reward:.2f}\n" 383 | f"\tvlm_reward: {eval_reward:.2f}, vlm_success: {eval_success:.0f}, " 384 | f"vlm_step: {lst_vlm_step}\n" 385 | f"\tq_loss: {log_info['critic_loss']:.3f}, " 386 | f"a_loss: {log_info['alpha_loss']:.3f}, " 387 | f"q: {log_info['q']:.3f}, q_max: {log_info['q_max']:.3f}\n" 388 | f"\tR: {log_info['batch_reward']:.3f}, " 389 | f"Rmax: {log_info['batch_reward_max']:.3f}, " 390 | f"Rmin: {log_info['batch_reward_min']:.3f}\n" 391 | f"\tvlm_R: {log_info['batch_vlm_reward']:.3f}, " 392 | f"vlm_Rmax: {log_info['batch_vlm_reward_max']:.3f}, " 393 | f"vlm_Rmin: {log_info['batch_vlm_reward_min']:.3f}\n" 394 | f"\tep_num: {ep_num}, success_cnt: {success_cnt}, " 395 | f"success: {reward}\n" 396 | f"\tpos_ptr: {embedding_buffer.pos_ptr}, " 397 | f"valid_ptr: {embedding_buffer.valid_ptr}, " 398 | f"neg_ptr: {embedding_buffer.neg_ptr}\n" 399 | f"\tpNum: {pos_num}, pLoss: {pos_loss:.3f}, pLossMax: {pos_loss_max:.3f}\n" 400 | f"\tnNum: {neg_num}, nLoss: {neg_loss:.3f}, nLossMax: {neg_loss_max:.3f}\n" 401 | f"\tpCos: {pos_cosine:.2f}, pCos_max: {pos_cosine_max:.2f}, " 402 | f"pCos_min: {pos_cosine_min:.2f}\n" 403 | f"\tnCos: {neg_cosine:.2f}, nCos_max: {neg_cosine_max:.2f}, " 404 | f"nCos_min: {neg_cosine_min:.2f}\n" 405 | f"\tlCos: {lag_cosine:.2f}, lCos_max: {lag_cosine_max:.2f}, " 406 | f"lCos_min: {lag_cosine_min:.2f}\n" 407 | f"\tgoal: ({goal[0]:.3f}, {goal[1]:.3f}, {goal[2]:.3f}), " 408 | f"rvlm_R: {log_info['rvlm_reward']:.4f}\n" 409 | ) 410 | logs.append(log_info) 411 | else: 412 | logs.append({ 413 | "step": t, 414 | "task_reward": lst_ep_task_reward, 415 | "vlm_reward": lst_ep_vlm_reward, 416 | "eval_reward": eval_reward, 417 | "eval_success": eval_success, 418 | "time": (time.time() - start_time) / 60, 419 | }) 420 | logger.info( 421 | f"\n[T {t//1000}K][{logs[-1]['time']:.2f} min] " 422 | f"task_reward: {lst_ep_task_reward:.2f}, " 423 | f"vlm_reward: {lst_ep_vlm_reward:.2f}\n" 424 | ) 425 | 426 | 427 | # save logs 428 | log_df = pd.DataFrame(logs) 429 | log_df.to_csv(f"logs/{exp_name}.csv") 430 | 431 | # close env 432 | env.close() 433 | eval_env.close() 434 | -------------------------------------------------------------------------------- /experiments/train_liv.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".4" 3 | 4 | import cv2 5 | import time 6 | import clip 7 | import git 8 | import logging 9 | import imageio 10 | import gymnasium as gym 11 | import ml_collections 12 | import numpy as np 13 | import pandas as pd 14 | 15 | import torch 16 | import torchvision.transforms as T 17 | 18 | from tqdm import trange 19 | from models import VLMAgent 20 | from utils import (TASKS, VLMBuffer, log_git, get_logger, make_env, load_liv) 21 | 22 | 23 | ################### 24 | # 
Utils Functions # 25 | ################### 26 | def crop_center(config, image): 27 | x1, x2, y1, y2 = 32, 224, 32, 224 28 | return image[x1:x2, y1:y2, :] 29 | 30 | 31 | def eval_policy(agent: VLMAgent, 32 | env: gym.Env, 33 | eval_episodes: int = 10, 34 | hidden_env: bool = True): 35 | t1 = time.time() 36 | eval_reward, eval_success, avg_step = 0, 0, 0 37 | for i in range(1, eval_episodes + 1): 38 | obs, _ = env.reset() 39 | while True: 40 | avg_step += 1 41 | action = agent.sample_action(obs, eval_mode=True) 42 | obs, reward, terminated, truncated, info = env.step(action) 43 | eval_reward += reward 44 | if terminated or truncated: 45 | eval_success += info["success"] 46 | break 47 | 48 | eval_reward /= eval_episodes 49 | eval_success /= eval_episodes 50 | 51 | return eval_reward, eval_success, avg_step, time.time() - t1 52 | 53 | 54 | def setup_logging(config): 55 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 56 | 57 | # logging 58 | exp_prefix = f"liv_rho{config.rho}" 59 | exp_name = f"{exp_prefix}/{config.env_name}/s{config.seed}_{timestamp}" 60 | os.makedirs(f"logs/{exp_prefix}/{config.env_name}", exist_ok=True) 61 | exp_info = f"# Running experiment for: {exp_name} #" 62 | print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info)) 63 | logger = get_logger(f"logs/{exp_name}.log") 64 | 65 | # add git commit info 66 | log_git(config) 67 | logger.info(f"Config:\n{config}\n") 68 | 69 | # set random seed 70 | np.random.seed(config.seed) 71 | torch.manual_seed(config.seed) 72 | 73 | return exp_name, logger 74 | 75 | 76 | def setup_exp(config): 77 | # liv 78 | transform = T.Compose([T.ToTensor()]) 79 | liv = load_liv() 80 | liv.eval() 81 | 82 | # task description embedding 83 | with torch.no_grad(): 84 | token = clip.tokenize([TASKS[config.env_name]]) 85 | text_embedding = liv(input=token, modality="text") 86 | text_embedding = text_embedding.detach().cpu().numpy() 87 | 88 | # initialize the environment 89 | env = make_env(config.env_name, seed=config.seed) 90 | eval_seed = config.seed if "hidden" in config.env_name else config.seed+100 91 | eval_env = make_env(config.env_name, 92 | seed=eval_seed, 93 | image_size=480, 94 | camera_id=config.camera_id) 95 | 96 | # environment parameter 97 | obs_dim = env.observation_space.shape[0] 98 | act_dim = env.action_space.shape[0] 99 | max_action = env.action_space.high[0] 100 | 101 | # fixed LIV representation projection 102 | vlm_agent = VLMAgent(obs_dim=obs_dim, 103 | act_dim=act_dim, 104 | max_action=max_action, 105 | seed=config.seed, 106 | tau=config.tau, 107 | rho=config.rho, 108 | gamma=config.gamma, 109 | lr=config.lr, 110 | text_embedding=text_embedding, 111 | hidden_dims=config.hidden_dims) 112 | 113 | # Replay buffer 114 | replay_buffer = VLMBuffer(obs_dim=obs_dim, act_dim=act_dim) 115 | 116 | return transform, liv, env, eval_env, vlm_agent, replay_buffer 117 | 118 | 119 | ################# 120 | # Main Function # 121 | ################# 122 | def train_and_evaluate(config: ml_collections.ConfigDict): 123 | start_time = time.time() 124 | 125 | # logging setup 126 | exp_name, logger = setup_logging(config) 127 | 128 | # experiment setup 129 | (transform, 130 | liv, 131 | env, 132 | eval_env, 133 | vlm_agent, 134 | replay_buffer) = setup_exp(config) 135 | 136 | # reward for untrained agent 137 | eval_episodes = 1 if "hidden" in config.env_name else 10 138 | eval_reward, eval_success, _, _ = eval_policy(vlm_agent, 139 | eval_env, 140 | eval_episodes) 141 | logs = [{ 142 | "step": 0, 143 | "eval_reward": eval_reward, 
144 | "eval_success": eval_success 145 | }] 146 | 147 | # start training 148 | obs, _ = env.reset() 149 | success_cnt = 0 150 | ep_num, ep_step, success = 0, 0, 0 151 | ep_task_reward, ep_vlm_reward = 0, 0 152 | lst_ep_step, lst_ep_task_reward, lst_ep_vlm_reward = 0, 0, 0 153 | for t in trange(1, config.max_timesteps + 1): 154 | if t <= config.start_timesteps: 155 | action = env.action_space.sample() 156 | else: 157 | action = vlm_agent.sample_action(obs) 158 | next_obs, task_reward, terminated, truncated, info = env.step(action) 159 | 160 | # vision language model reward 161 | image = env.mujoco_renderer.render( 162 | render_mode="rgb_array", 163 | camera_id=config.camera_id).copy() 164 | image = image[::-1] 165 | image = crop_center(config, image) 166 | processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 167 | processed_image = transform(processed_image) 168 | with torch.no_grad(): 169 | image_embedding = liv(input=processed_image.to("cuda")[None], modality="vision") 170 | image_embedding = image_embedding.detach().cpu().numpy() 171 | vlm_reward = vlm_agent.get_reward(image_embedding).item() # (1, 1024) 172 | reward = info["success"] 173 | success_cnt += info["success"] 174 | 175 | # add to buffer 176 | replay_buffer.add(obs, 177 | action, 178 | next_obs, 179 | vlm_reward, 180 | reward-1, 181 | terminated) 182 | obs = next_obs 183 | ep_vlm_reward += vlm_reward 184 | ep_task_reward += task_reward 185 | ep_step += 1 186 | 187 | # start a new trajectory 188 | if terminated or truncated: 189 | obs, _ = env.reset() 190 | success = info["success"] 191 | lst_ep_step = ep_step 192 | lst_ep_task_reward = ep_task_reward 193 | lst_ep_vlm_reward = ep_vlm_reward 194 | ep_vlm_reward = 0 195 | ep_task_reward = 0 196 | ep_step = 0 197 | ep_num += 1 198 | 199 | # training 200 | if t > config.start_timesteps: 201 | batch = replay_buffer.sample(config.batch_size) 202 | log_info = vlm_agent.update(batch) 203 | 204 | # eval 205 | if t % config.eval_freq == 0: 206 | eval_reward, eval_success, _, _ = eval_policy( 207 | agent, eval_env, eval_episodes=eval_episodes) 208 | 209 | # logging 210 | if t % config.log_freq == 0: 211 | if t > config.start_timesteps: 212 | log_info.update({ 213 | "step": t, 214 | "success": success, 215 | "task_reward": lst_ep_task_reward, 216 | "vlm_reward": lst_ep_vlm_reward, 217 | "eval_reward": eval_reward, 218 | "eval_success": eval_success, 219 | "batch_reward": batch.rewards.mean(), 220 | "batch_reward_max": batch.rewards.max(), 221 | "batch_reward_min": batch.rewards.min(), 222 | "time": (time.time() - start_time) / 60 223 | }) 224 | logger.info( 225 | f"\n[T {t//1000}K][{log_info['time']:.2f} min] " 226 | f"task_reward: {lst_ep_task_reward:.2f}, " 227 | f"vlm_reward: {lst_ep_vlm_reward:.2f}\n" 228 | f"\tq_loss: {log_info['critic_loss']:.3f}, " 229 | f"a_loss: {log_info['alpha_loss']:.3f}, " 230 | f"q: {log_info['q']:.3f}, q_max: {log_info['q_max']:.3f}\n" 231 | f"\tR: {log_info['batch_reward']:.3f}, " 232 | f"Rmax: {log_info['batch_reward_max']:.3f}, " 233 | f"Rmin: {log_info['batch_reward_min']:.3f}\n" 234 | f"\tep_num: {ep_num}, success_cnt: {success_cnt}, " 235 | f"success: {success}\n" 236 | ) 237 | logs.append(log_info) 238 | else: 239 | logs.append({ 240 | "step": t, 241 | "task_reward": lst_ep_task_reward, 242 | "vlm_reward": lst_ep_vlm_reward, 243 | "eval_reward": eval_reward, 244 | "eval_success": eval_success, 245 | "time": (time.time() - start_time) / 60, 246 | }) 247 | logger.info( 248 | f"\n[T {t//1000}K][{logs[-1]['time']:.2f} min] " 249 | f"task_reward: 
{lst_ep_task_reward:.2f}, " 250 | f"vlm_reward: {lst_ep_vlm_reward:.2f}\n" 251 | ) 252 | 253 | # save logs 254 | log_df = pd.DataFrame(logs) 255 | log_df.to_csv(f"logs/{exp_name}.csv") 256 | 257 | # close env 258 | env.close() 259 | eval_env.close() 260 | -------------------------------------------------------------------------------- /experiments/train_relay.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".4" 3 | 4 | import cv2 5 | import time 6 | import clip 7 | import logging 8 | import imageio 9 | import gymnasium as gym 10 | import ml_collections 11 | import numpy as np 12 | import pandas as pd 13 | 14 | import torch 15 | import torchvision.transforms as T 16 | 17 | from tqdm import trange 18 | from models import SACAgent, VLMAgent 19 | from utils import (TASKS, VLMBuffer, log_git, get_logger, make_env, load_liv) 20 | 21 | 22 | ################### 23 | # Utils Functions # 24 | ################### 25 | def eval_policy(agent: SACAgent, 26 | env: gym.Env, 27 | logger: logging.Logger = None, 28 | eval_episodes: int = 10): 29 | t1 = time.time() 30 | eval_reward, eval_success, avg_step = 0, 0, 0 31 | for i in range(1, eval_episodes + 1): 32 | obs, _ = env.reset() 33 | while True: 34 | avg_step += 1 35 | action = agent.sample_action(obs, eval_mode=True) 36 | obs, reward, terminated, truncated, info = env.step(action) 37 | eval_reward += reward 38 | if terminated or truncated: 39 | eval_success += info["success"] 40 | break 41 | 42 | eval_reward /= eval_episodes 43 | eval_success /= eval_episodes 44 | 45 | if logger: 46 | logger.info(f"Eval obs = ({obs[-3]:.4f}, {obs[-2]:.4f}, {obs[-1]:.4f})") 47 | 48 | return eval_reward, eval_success, avg_step, time.time() - t1 49 | 50 | 51 | def setup_logging(config): 52 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 53 | 54 | # logging 55 | exp_prefix = f"relay_rho{config.rho}" 56 | exp_name = f"{exp_prefix}/{config.env_name}/s{config.seed}_{timestamp}" 57 | os.makedirs(f"logs/{exp_prefix}/{config.env_name}", exist_ok=True) 58 | exp_info = f"# Running experiment for: {exp_name} #" 59 | print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info)) 60 | logger = get_logger(f"logs/{exp_name}.log") 61 | 62 | # add git commit info 63 | log_git(config) 64 | logger.info(f"Config:\n{config}\n") 65 | 66 | # set random seed 67 | np.random.seed(config.seed) 68 | torch.manual_seed(config.seed) 69 | 70 | return exp_name, logger 71 | 72 | 73 | def setup_exp(config): 74 | # liv 75 | transform = T.Compose([T.ToTensor()]) 76 | liv = load_liv() 77 | 78 | # task description embedding 79 | with torch.no_grad(): 80 | token = clip.tokenize([TASKS[config.env_name]]) 81 | text_embedding = liv(input=token, modality="text") 82 | text_embedding = text_embedding.detach().cpu().numpy() 83 | 84 | # initialize the environment 85 | env = make_env(config.env_name, seed=config.seed) 86 | eval_seed = config.seed if "hidden" in config.env_name else config.seed+100 87 | eval_env = make_env(config.env_name, 88 | seed=eval_seed, 89 | image_size=480, 90 | camera_id=config.camera_id) 91 | 92 | # environment parameter 93 | obs_dim = env.observation_space.shape[0] 94 | act_dim = env.action_space.shape[0] 95 | max_action = env.action_space.high[0] 96 | 97 | # fixed LIV representation without fine-tuning 98 | vlm_agent = VLMAgent(obs_dim=obs_dim, 99 | act_dim=act_dim, 100 | seed=config.seed, 101 | tau=config.tau, 102 | rho=config.rho, 103 | gamma=config.gamma, 104 | 
lr=config.lr, 105 | text_embedding=text_embedding, 106 | hidden_dims=config.hidden_dims) 107 | 108 | # SAC agent 109 | sac_agent = SACAgent(obs_dim=obs_dim, 110 | act_dim=act_dim, 111 | max_action=max_action, 112 | seed=config.seed, 113 | tau=config.tau, 114 | gamma=config.gamma, 115 | lr=config.lr, 116 | hidden_dims=config.hidden_dims) 117 | 118 | # Replay buffer 119 | replay_buffer = VLMBuffer(obs_dim=obs_dim, act_dim=act_dim) 120 | 121 | return transform, liv, env, eval_env, vlm_agent, sac_agent, replay_buffer 122 | 123 | 124 | def train_and_evaluate(config: ml_collections.ConfigDict): 125 | start_time = time.time() 126 | 127 | # logging setup 128 | exp_name, logger = setup_logging(config) 129 | 130 | # experiment setup 131 | (transform, 132 | liv, 133 | env, 134 | eval_env, 135 | vlm_agent, 136 | sac_agent, 137 | replay_buffer) = setup_exp(config) 138 | 139 | # reward for untrained agent 140 | eval_episodes = 1 if "hidden" in config.env_name else 10 141 | eval_reward, eval_success, _, _ = eval_policy(vlm_agent, 142 | eval_env, 143 | eval_episodes=eval_episodes) 144 | logs = [{ 145 | "step": 0, 146 | "eval_reward": eval_reward, 147 | "eval_success": eval_success, 148 | }] 149 | 150 | # relay freqs 151 | relay_freqs = [50, 100, 150, 200] 152 | relay_freq = np.random.choice(relay_freqs) 153 | logger.info(f"Relay freqs: {relay_freqs}\n") 154 | 155 | # start training 156 | obs, _ = env.reset() 157 | reward, ep_task_reward, ep_vlm_reward = 0, 0, 0 158 | success_cnt, ep_step = 0, 0 159 | lst_ep_step, lst_ep_task_reward, lst_ep_vlm_reward = 0, 0, 0 160 | sac_step, vlm_step = 0, 0 161 | lst_sac_step, lst_vlm_step = 0, 0 162 | policies = ["vlm", "sac"] 163 | use_relay = True 164 | 165 | for t in trange(1, config.max_timesteps + 1): 166 | if t <= config.start_timesteps: 167 | action = env.action_space.sample() 168 | else: 169 | if use_relay: 170 | if policies[(ep_step//relay_freq)%2] == "vlm": 171 | vlm_step += 1 172 | action = vlm_agent.sample_action(obs) 173 | else: 174 | sac_step += 1 175 | action = sac_agent.sample_action(obs) 176 | action_noise = np.random.normal( 177 | 0, sac_agent.max_action*config.expl_noise, size=sac_agent.act_dim) 178 | action = (action + action_noise).clip( 179 | -sac_agent.max_action, sac_agent.max_action) 180 | else: 181 | vlm_step += 1 182 | action = vlm_agent.sample_action(obs) 183 | next_obs, task_reward, terminated, truncated, info = env.step(action) 184 | 185 | # vision language model reward 186 | image = env.mujoco_renderer.render( 187 | render_mode="rgb_array", 188 | camera_id=config.camera_id).copy() 189 | image = image[::-1] 190 | processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 191 | processed_image = transform(processed_image) 192 | 193 | with torch.no_grad(): 194 | image_embedding = liv(input=processed_image.to("cuda")[None], modality="vision") 195 | image_embedding = image_embedding.detach().cpu().numpy() 196 | vlm_reward = vlm_agent.get_reward(image_embedding).item() 197 | reward = int(info["success"]) 198 | success_cnt += reward 199 | ep_step += 1 200 | 201 | # add to buffer 202 | replay_buffer.add(obs, 203 | action, 204 | next_obs, 205 | vlm_reward, 206 | reward-1, # constant reward shifting 207 | terminated) 208 | obs = next_obs 209 | ep_vlm_reward += vlm_reward 210 | ep_task_reward += task_reward 211 | 212 | # start a new trajectory 213 | if terminated or truncated: 214 | obs, _ = env.reset() 215 | lst_ep_step = ep_step 216 | lst_ep_task_reward = ep_task_reward 217 | lst_ep_vlm_reward = ep_vlm_reward 218 | lst_sac_step = sac_step 219 | 
lst_vlm_step = vlm_step 220 | ep_vlm_reward = 0 221 | ep_task_reward = 0 222 | sac_step = 0 223 | vlm_step = 0 224 | policies = policies[::-1] 225 | relay_freq = np.random.choice(relay_freqs) 226 | 227 | ep_step = 0 228 | first_success_step = 0 229 | 230 | if success_cnt >= config.relay_threshold: 231 | use_relay = False 232 | 233 | # training 234 | if t > config.start_timesteps: 235 | batch = replay_buffer.sample(config.batch_size) 236 | log_info = vlm_agent.update(batch) 237 | if use_relay: _ = sac_agent.update(batch) 238 | 239 | # eval 240 | if t % config.eval_freq == 0: 241 | eval_reward, eval_success, _, _ = eval_policy(vlm_agent, 242 | eval_env, 243 | eval_episodes=eval_episodes) 244 | 245 | # logging 246 | if t % config.log_freq == 0: 247 | if t > config.start_timesteps: 248 | log_info.update({ 249 | "step": t, 250 | "success": reward, 251 | "task_reward": lst_ep_task_reward, 252 | "vlm_reward": lst_ep_vlm_reward, 253 | "eval_reward": eval_reward, 254 | "eval_success": eval_success, 255 | "batch_reward": batch.rewards.mean(), 256 | "batch_reward_max": batch.rewards.max(), 257 | "batch_reward_min": batch.rewards.min(), 258 | "batch_vlm_reward": batch.vlm_rewards.mean(), 259 | "batch_vlm_reward_max": batch.vlm_rewards.max(), 260 | "batch_vlm_reward_min": batch.vlm_rewards.min(), 261 | "time": (time.time() - start_time) / 60 262 | }) 263 | logger.info( 264 | f"\n[T {t//1000}K][{log_info['time']:.2f} min] " 265 | f"task_reward: {lst_ep_task_reward:.2f}, " 266 | f"vlm_reward: {lst_ep_vlm_reward:.2f}\n" 267 | f"\tvlm_reward: {eval_reward:.2f}, vlm_success: {eval_success:.0f}, " 268 | f"vlm_step: {lst_vlm_step}\n" 269 | f"\tq_loss: {log_info['critic_loss']:.3f}, " 270 | f"a_loss: {log_info['alpha_loss']:.3f}, " 271 | f"q: {log_info['q']:.3f}, q_max: {log_info['q_max']:.3f}\n" 272 | f"\tR: {log_info['batch_reward']:.3f}, " 273 | f"Rmax: {log_info['batch_reward_max']:.3f}, " 274 | f"success_cnt: {success_cnt}, success: {reward}\n" 275 | f"\tvlm_R: {log_info['batch_vlm_reward']:.3f}, " 276 | f"vlm_Rmax: {log_info['batch_vlm_reward_max']:.3f}, " 277 | f"vlm_Rmin: {log_info['batch_vlm_reward_min']:.3f}\n" 278 | ) 279 | logs.append(log_info) 280 | else: 281 | logs.append({ 282 | "step": t, 283 | "task_reward": lst_ep_task_reward, 284 | "vlm_reward": lst_ep_vlm_reward, 285 | "eval_reward": eval_reward, 286 | "eval_success": eval_success, 287 | "time": (time.time() - start_time) / 60, 288 | }) 289 | logger.info( 290 | f"\n[T {t//1000}K][{logs[-1]['time']:.2f} min] " 291 | f"task_reward: {lst_ep_task_reward:.2f}\n" 292 | ) 293 | 294 | # save logs 295 | log_df = pd.DataFrame(logs) 296 | log_df.to_csv(f"logs/{exp_name}.csv") 297 | 298 | # close env 299 | env.close() 300 | eval_env.close() 301 | -------------------------------------------------------------------------------- /experiments/train_sac.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2" 3 | 4 | import time 5 | import gymnasium as gym 6 | import ml_collections 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from tqdm import trange 11 | from models import SACAgent 12 | from utils import ReplayBuffer, log_git, get_logger, make_env 13 | 14 | 15 | ################### 16 | # Utils Functions # 17 | ################### 18 | def eval_policy(agent: SACAgent, 19 | env: gym.Env, 20 | eval_episodes: int = 10): 21 | t1 = time.time() 22 | eval_reward, eval_success, avg_step = 0, 0, 0 23 | for i in range(1, eval_episodes + 1): 24 | obs, _ = 
env.reset() 25 | while True: 26 | avg_step += 1 27 | action = agent.sample_action(obs, eval_mode=True) 28 | obs, reward, terminated, truncated, info = env.step(action) 29 | eval_reward += reward 30 | if terminated or truncated: 31 | eval_success += info["success"] 32 | break 33 | 34 | eval_reward /= eval_episodes 35 | eval_success /= eval_episodes 36 | 37 | return eval_reward, eval_success, avg_step, time.time() - t1 38 | 39 | 40 | def setup_logging(config): 41 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 42 | 43 | # logging 44 | exp_prefix = "sac" 45 | exp_name = f"{exp_prefix}/{config.env_name}/s{config.seed}_{timestamp}" 46 | os.makedirs(f"logs/{exp_prefix}/{config.env_name}", exist_ok=True) 47 | exp_info = f"# Running experiment for: {exp_name} #" 48 | print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info)) 49 | logger = get_logger(f"logs/{exp_name}.log") 50 | 51 | # add git commit info 52 | log_git(config) 53 | logger.info(f"Config:\n{config}\n") 54 | 55 | # set random seed 56 | np.random.seed(config.seed) 57 | 58 | return exp_name, logger 59 | 60 | 61 | def setup_exp(config): 62 | # initialize the environment 63 | env = make_env(config.env_name, 64 | image_size=480, 65 | seed=config.seed) 66 | eval_seed = config.seed if "hidden" in config.env_name else config.seed+100 67 | eval_env = make_env(config.env_name, 68 | seed=eval_seed, 69 | image_size=480, 70 | camera_id=config.camera_id) 71 | 72 | # environment parameter 73 | obs_dim = env.observation_space.shape[0] 74 | act_dim = env.action_space.shape[0] 75 | max_action = env.action_space.high[0] 76 | 77 | # SAC agent 78 | agent = SACAgent(obs_dim=obs_dim, 79 | act_dim=act_dim, 80 | max_action=max_action, 81 | seed=config.seed, 82 | tau=config.tau, 83 | gamma=config.gamma, 84 | lr=config.lr, 85 | hidden_dims=config.hidden_dims) 86 | 87 | # Replay buffer 88 | replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim) 89 | 90 | return env, eval_env, agent, replay_buffer 91 | 92 | 93 | ################# 94 | # Main Function # 95 | ################# 96 | def train_and_evaluate(config: ml_collections.ConfigDict): 97 | start_time = time.time() 98 | 99 | # logging 100 | exp_name, logger = setup_logging(config) 101 | 102 | # experiment setup 103 | (env, 104 | eval_env, 105 | agent, 106 | replay_buffer) = setup_exp(config) 107 | 108 | # reward for untrained agent 109 | eval_episodes = 1 if "hidden" in config.env_name else 10 110 | eval_reward, eval_success, _, _ = eval_policy(agent, 111 | eval_env, 112 | eval_episodes=eval_episodes) 113 | logs = [{ 114 | "step": 0, 115 | "eval_reward": eval_reward, 116 | "eval_success": eval_success 117 | }] 118 | 119 | # start training 120 | obs, _ = env.reset() 121 | success, cum_success, ep_step = 0, 0, 0 122 | ep_task_reward, ep_reward = 0, 0 123 | lst_ep_task_reward, lst_ep_reward = 0, 0 124 | for t in trange(1, config.max_timesteps + 1): 125 | ep_step += 1 126 | if t <= config.start_timesteps: 127 | action = env.action_space.sample() 128 | else: 129 | action = agent.sample_action(obs) 130 | next_obs, task_reward, terminated, truncated, info = env.step(action) 131 | cum_success += info["success"] 132 | 133 | replay_buffer.add(obs, 134 | action, 135 | next_obs, 136 | info["success"]-1, 137 | terminated) 138 | obs = next_obs 139 | ep_reward += info["success"] 140 | ep_task_reward += task_reward 141 | 142 | # start a new trajectory 143 | if terminated or truncated: 144 | obs, _ = env.reset() 145 | success = info["success"] 146 | lst_ep_task_reward = ep_task_reward 147 | 
lst_ep_reward = ep_reward 148 | ep_task_reward = 0 149 | ep_reward = 0 150 | ep_step = 0 151 | 152 | # training 153 | if t > config.start_timesteps: 154 | batch = replay_buffer.sample(config.batch_size) 155 | log_info = agent.update(batch) 156 | 157 | # eval 158 | if t % config.eval_freq == 0: 159 | eval_reward, eval_success, _, _ = eval_policy(agent, 160 | eval_env, 161 | eval_episodes=eval_episodes) 162 | 163 | # logging 164 | if t % config.log_freq == 0: 165 | if t > config.start_timesteps: 166 | log_info.update({ 167 | "step": t, 168 | "success": success, 169 | "reward": lst_ep_reward, 170 | "task_reward": lst_ep_task_reward, 171 | "eval_reward": eval_reward, 172 | "eval_success": eval_success, 173 | "batch_reward": batch.rewards.mean(), 174 | "batch_reward_max": batch.rewards.max(), 175 | "batch_reward_min": batch.rewards.min(), 176 | "time": (time.time() - start_time) / 60 177 | }) 178 | logger.info( 179 | f"\n[T {t//1000}K][{log_info['time']:.2f} min] " 180 | f"task_R: {lst_ep_task_reward:.2f}, " 181 | f"ep_R: {lst_ep_reward:.0f}\n" 182 | f"\tq_loss: {log_info['critic_loss']:.3f}, " 183 | f"a_loss: {log_info['alpha_loss']:.3f}, " 184 | f"q: {log_info['q']:.2f}, " 185 | f"q_max: {log_info['q_max']:.2f}\n" 186 | f"\tR: {log_info['batch_reward']:.3f}, " 187 | f"Rmax: {log_info['batch_reward_max']:.1f}, " 188 | f"Rmin: {log_info['batch_reward_min']:.1f}, " 189 | f"success: {success:.0f}, " 190 | f"cum_success: {cum_success:.0f}\n") 191 | logs.append(log_info) 192 | else: 193 | logs.append({ 194 | "step": t, 195 | "reward": lst_ep_reward, 196 | "task_reward": lst_ep_task_reward, 197 | "eval_reward": eval_reward, 198 | "eval_success": eval_success, 199 | "time": (time.time() - start_time) / 60, 200 | }) 201 | logger.info( 202 | f"\n[T {t//1000}K][{logs[-1]['time']:.2f} min] " 203 | f"task_reward: {lst_ep_task_reward:.2f}, " 204 | f"ep_reward: {lst_ep_reward:.2f}\n" 205 | ) 206 | 207 | # save logs 208 | log_df = pd.DataFrame(logs) 209 | log_df.to_csv(f"logs/{exp_name}.csv") 210 | 211 | # close env 212 | env.close() 213 | eval_env.close() 214 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from absl import app, flags 2 | from ml_collections import config_flags 3 | import os 4 | import experiments 5 | 6 | config_flags.DEFINE_config_file("config", default="configs/metaworld.py") 7 | FLAGS = flags.FLAGS 8 | 9 | 10 | def main(argv): 11 | config = FLAGS.config 12 | 13 | try: 14 | if config.exp_name == "oracle": 15 | experiments.run_oracle.evaluate(config) 16 | 17 | elif config.exp_name == "sac": 18 | experiments.train_sac.train_and_evaluate(config) 19 | 20 | elif config.exp_name == "liv": 21 | experiments.train_liv.train_and_evaluate(config) 22 | 23 | elif config.exp_name == "relay": 24 | experiments.train_relay.train_and_evaluate(config) 25 | 26 | elif config.exp_name == "furl": 27 | experiments.train_furl.train_and_evaluate(config) 28 | 29 | except KeyboardInterrupt as e: 30 | print("Skip to the next experiment.") 31 | 32 | 33 | if __name__ == '__main__': 34 | app.run(main) 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .liv import LIV 3 | from .sac import SACAgent 4 | from .vlm import VLMAgent 5 | from .furl import FuRLAgent 6 | from .projection import Projection, RewardModel 
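
For orientation, the agents exported above share a small construct / sample_action / update interface. The following is a minimal usage sketch, not part of the repository: it assumes `make_env` and `ReplayBuffer` from `utils` behave as they are used in `experiments/train_sac.py`, omits the random-action warm-up, evaluation, and logging, and uses placeholder values for the horizon, warm-up length, and batch size.

```python
# Minimal sketch of the SACAgent interface exported by models/__init__.py.
# Assumes utils.make_env and utils.ReplayBuffer as used in experiments/train_sac.py.
from models import SACAgent
from utils import ReplayBuffer, make_env

env = make_env("door-open-v2-goal-hidden", seed=0)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

agent = SACAgent(obs_dim=obs_dim,
                 act_dim=act_dim,
                 max_action=env.action_space.high[0],
                 seed=0)
buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim)

obs, _ = env.reset()
for t in range(1, 1001):                  # placeholder horizon
    action = agent.sample_action(obs)     # stochastic action (eval_mode=False)
    next_obs, reward, terminated, truncated, info = env.step(action)
    # sparse, shifted reward as in experiments/train_sac.py: success - 1
    buffer.add(obs, action, next_obs, info["success"] - 1, terminated)
    obs = next_obs
    if terminated or truncated:
        obs, _ = env.reset()
    if t > 500:                           # placeholder warm-up length
        log_info = agent.update(buffer.sample(256))  # placeholder batch size
env.close()
```
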
-------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Sequence 2 | 3 | import distrax 4 | import jax 5 | import jax.numpy as jnp 6 | from flax import linen as nn 7 | 8 | 9 | ################### 10 | # Utils Functions # 11 | ################### 12 | class MLP(nn.Module): 13 | hidden_dims: Sequence[int] = (256, 256) 14 | activate_final: bool = True 15 | 16 | @nn.compact 17 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 18 | for i, size in enumerate(self.hidden_dims): 19 | x = nn.Dense(size)(x) 20 | if i + 1 < len(self.hidden_dims) or self.activate_final: 21 | x = nn.relu(x) 22 | return x 23 | 24 | 25 | class Scalar(nn.Module): 26 | init_value: float 27 | 28 | def setup(self): 29 | self.value = self.param("value", lambda x: self.init_value) 30 | 31 | def __call__(self): 32 | return self.value 33 | 34 | 35 | ################ 36 | # Actor Critic # 37 | ################ 38 | class Actor(nn.Module): 39 | act_dim: int 40 | max_action: float = 1.0 41 | hidden_dims: Sequence[int] = (256, 256) 42 | log_std_min: Optional[float] = None 43 | log_std_max: Optional[float] = None 44 | min_scale: float = 1e-3 45 | 46 | def setup(self): 47 | self.net = MLP(self.hidden_dims, activate_final=True) 48 | self.mu_layer = nn.Dense(self.act_dim) 49 | self.std_layer = nn.Dense(self.act_dim) 50 | 51 | def __call__(self, rng: Any, observation: jnp.ndarray): 52 | x = self.net(observation) 53 | mu = self.mu_layer(x) 54 | mean_action = nn.tanh(mu) 55 | 56 | std = self.std_layer(x) 57 | std = jax.nn.softplus(std) + self.min_scale 58 | 59 | action_distribution = distrax.Transformed( 60 | distrax.MultivariateNormalDiag(mu, std), 61 | distrax.Block(distrax.Tanh(), ndims=1)) 62 | sampled_action, logp = action_distribution.sample_and_log_prob( 63 | seed=rng) 64 | 65 | return mean_action * self.max_action, sampled_action * self.max_action, logp 66 | 67 | def get_logprob(self, observation, action): 68 | x = self.net(observation) 69 | mu = self.mu_layer(x) 70 | mean_action = nn.tanh(mu) 71 | 72 | std = self.std_layer(x) 73 | std = jax.nn.softplus(std) + self.min_scale 74 | 75 | action_distribution = distrax.Normal(mu, std) 76 | raw_action = atanh(action) 77 | log_prob = action_distribution.log_prob(raw_action).sum(-1) 78 | log_prob -= 2 * (jnp.log(2) - raw_action - 79 | jax.nn.softplus(-2 * raw_action)).sum(-1) 80 | return log_prob, mu, std 81 | 82 | 83 | class Critic(nn.Module): 84 | hidden_dims: Sequence[int] = (256, 256) 85 | output_dim: int = 1 86 | 87 | def setup(self): 88 | self.net = MLP(self.hidden_dims, activate_final=True) 89 | self.out_layer = nn.Dense(self.output_dim) 90 | 91 | def __call__(self, 92 | observations: jnp.ndarray, 93 | actions: jnp.ndarray) -> jnp.ndarray: 94 | x = jnp.concatenate([observations, actions], axis=-1) 95 | x = self.net(x) 96 | q = self.out_layer(x) 97 | return q.squeeze() 98 | 99 | 100 | class DoubleCritic(nn.Module): 101 | hidden_dims: Sequence[int] = (256, 256) 102 | output_dim: int = 1 103 | num_qs: int = 2 104 | 105 | @nn.compact 106 | def __call__(self, observations, actions): 107 | VmapCritic = nn.vmap(Critic, 108 | variable_axes={"params": 0}, 109 | split_rngs={"params": True}, 110 | in_axes=None, 111 | out_axes=0, 112 | axis_size=self.num_qs) 113 | qs = VmapCritic(self.hidden_dims, self.output_dim)(observations, 114 | actions) 115 | return qs 116 | 117 | 118 | #################### 119 | # Vectorized Agent # 
120 | #################### 121 | class EnsembleDense(nn.Module): 122 | ensemble_num: int 123 | features: int 124 | use_bias: bool = True 125 | dtype: Any = jnp.float32 126 | precision: Any = None 127 | kernel_init: Callable = nn.initializers.lecun_normal() 128 | bias_init: Callable = nn.initializers.zeros 129 | 130 | @nn.compact 131 | def __call__(self, inputs: jnp.array) -> jnp.array: 132 | inputs = jnp.asarray(inputs, self.dtype) 133 | kernel = self.param( 134 | "kernel", self.kernel_init, 135 | (self.ensemble_num, inputs.shape[-1], self.features)) 136 | kernel = jnp.asarray(kernel, self.dtype) 137 | y = jnp.einsum("ij,ijk->ik", inputs, kernel) 138 | if self.use_bias: 139 | bias = self.param("bias", self.bias_init, 140 | (self.ensemble_num, self.features)) 141 | bias = jnp.asarray(bias, self.dtype) 142 | y += bias 143 | return y 144 | 145 | 146 | class EnsembleCritic(nn.Module): 147 | ensemble_num: int 148 | hid_dim: int = 256 149 | 150 | def setup(self): 151 | self.l1 = EnsembleDense(ensemble_num=self.ensemble_num, 152 | features=self.hid_dim, 153 | name="fc1") 154 | self.l2 = EnsembleDense(ensemble_num=self.ensemble_num, 155 | features=self.hid_dim, 156 | name="fc2") 157 | self.l3 = EnsembleDense(ensemble_num=self.ensemble_num, 158 | features=1, 159 | name="fc3") 160 | 161 | def __call__(self, observations, actions): 162 | x = jnp.concatenate([observations, actions], axis=-1) 163 | x = nn.relu(self.l1(x)) 164 | x = nn.relu(self.l2(x)) 165 | x = self.l3(x) 166 | return x.squeeze(-1) 167 | 168 | 169 | class EnsembleDoubleCritic(nn.Module): 170 | ensemble_num: int 171 | hid_dim: int = 256 172 | 173 | def setup(self): 174 | self.q1 = EnsembleCritic(self.ensemble_num, self.hid_dim) 175 | self.q2 = EnsembleCritic(self.ensemble_num, self.hid_dim) 176 | 177 | def __call__(self, observations, actions): 178 | q1 = self.q1(observations, actions) 179 | q2 = self.q2(observations, actions) 180 | return q1, q2 181 | 182 | 183 | class EnsembleActor(nn.Module): 184 | ensemble_num: int 185 | act_dim: int 186 | hid_dim: int = 256 187 | max_action: float = 1.0 188 | min_scale: float = 1e-3 189 | 190 | def setup(self): 191 | self.l1 = EnsembleDense(ensemble_num=self.ensemble_num, 192 | features=self.hid_dim, 193 | name="fc1") 194 | self.l2 = EnsembleDense(ensemble_num=self.ensemble_num, 195 | features=self.hid_dim, 196 | name="fc2") 197 | self.mu_layer = EnsembleDense(ensemble_num=self.ensemble_num, 198 | features=self.act_dim, 199 | name="mu") 200 | self.std_layer = EnsembleDense(ensemble_num=self.ensemble_num, 201 | features=self.act_dim, 202 | name="std") 203 | 204 | def __call__(self, observation: jnp.ndarray): 205 | x = nn.relu(self.l1(observation)) 206 | x = nn.relu(self.l2(x)) 207 | mu = self.mu_layer(x) 208 | mean_action = nn.tanh(mu) 209 | 210 | std = self.std_layer(x) 211 | std = jax.nn.softplus(std) + self.min_scale 212 | 213 | action_distribution = distrax.Transformed( 214 | distrax.MultivariateNormalDiag(mu, std), 215 | distrax.Block(distrax.Tanh(), ndims=1)) 216 | return mean_action, action_distribution 217 | 218 | def get_logprob(self, observation, action): 219 | x = nn.relu(self.l1(observation)) 220 | x = nn.relu(self.l2(x)) 221 | mu = self.mu_layer(x) 222 | mean_action = nn.tanh(mu) 223 | 224 | std = self.std_layer(x) 225 | std = jax.nn.softplus(std) + self.min_scale 226 | 227 | action_distribution = distrax.Normal(mu, std) 228 | raw_action = atanh(action) 229 | log_prob = action_distribution.log_prob(raw_action).sum(-1) 230 | log_prob -= 2 * (jnp.log(2) - raw_action - 231 | 
jax.nn.softplus(-2 * raw_action)).sum(-1) 232 | return log_prob 233 | 234 | 235 | class EnsembleScalar(nn.Module): 236 | init_value: jnp.ndarray 237 | 238 | def setup(self): 239 | self.value = self.param("value", lambda x: jnp.array(self.init_value)) 240 | 241 | def __call__(self): 242 | return self.value 243 | 244 | 245 | class EnsembleMLP(nn.Module): 246 | ensemble_num: int 247 | hid_dim: int = 256 248 | 249 | def setup(self): 250 | self.l1 = EnsembleDense(ensemble_num=self.ensemble_num, 251 | features=self.hid_dim, 252 | name="fc1") 253 | self.l2 = EnsembleDense(ensemble_num=self.ensemble_num, 254 | features=self.hid_dim, 255 | name="fc2") 256 | 257 | def __call__(self, x): 258 | x = nn.relu(self.l1(x)) 259 | x = self.l2(x) 260 | return x 261 | -------------------------------------------------------------------------------- /models/furl.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Sequence 2 | import os 3 | import functools 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import optax 8 | import orbax.checkpoint as ocp 9 | from flax import linen as nn 10 | from flax.core import FrozenDict 11 | from flax.training import train_state 12 | from utils import Batch, target_update 13 | from models import Actor, DoubleCritic, Scalar, MLP 14 | 15 | 16 | class FuRLAgent: 17 | def __init__(self, 18 | obs_dim: int, 19 | act_dim: int, 20 | max_action: float = 1.0, 21 | seed: int = 42, 22 | tau: float = 0.005, 23 | rho: float = 0.1, 24 | margin: float = 0.1, 25 | gamma: float = 0.99, 26 | lr: float = 3e-4, 27 | ckpt_dir: str = None, 28 | text_embedding: jnp.ndarray = None, 29 | goal_embedding: jnp.ndarray = None, 30 | hidden_dims: Sequence[int] = (256, 256)): 31 | self.lr = lr 32 | self.tau = tau 33 | self.rho = rho 34 | self.gamma = gamma 35 | self.margin = margin 36 | self.max_action = max_action 37 | self.target_entropy = -act_dim / 2 38 | self.rng = jax.random.PRNGKey(seed) 39 | self.rng, actor_key, critic_key, proj_key = jax.random.split(self.rng, 4) 40 | self.goal_embedding = goal_embedding 41 | self.text_embedding = text_embedding 42 | 43 | # Dummy inputs 44 | dummy_obs = jnp.ones([1, obs_dim], dtype=jnp.float32) 45 | dummy_act = jnp.ones([1, act_dim], dtype=jnp.float32) 46 | 47 | # Create the optimizer 48 | actor_tx = optax.adam(lr) 49 | critic_tx = optax.adam(lr) 50 | 51 | # Initialize the Actor 52 | self.actor = Actor(act_dim=act_dim, 53 | max_action=max_action, 54 | hidden_dims=hidden_dims) 55 | actor_params = self.actor.init(actor_key, 56 | actor_key, 57 | dummy_obs)["params"] 58 | self.actor_state = train_state.TrainState.create( 59 | apply_fn=self.actor.apply, params=actor_params, tx=actor_tx) 60 | 61 | # Initialize the Critic 62 | self.critic = DoubleCritic(hidden_dims=hidden_dims) 63 | critic_params = self.critic.init(critic_key, 64 | dummy_obs, 65 | dummy_act)["params"] 66 | self.critic_target_params = critic_params 67 | self.critic_state = train_state.TrainState.create( 68 | apply_fn=self.critic.apply, params=critic_params, tx=critic_tx) 69 | 70 | # Entropy tuning 71 | self.rng, alpha_key = jax.random.split(self.rng, 2) 72 | self.log_alpha = Scalar(0.0) 73 | self.alpha_state = train_state.TrainState.create( 74 | apply_fn=None, 75 | params=self.log_alpha.init(alpha_key)["params"], 76 | tx=optax.adam(lr)) 77 | 78 | # Checkpoint 79 | if ckpt_dir is not None: 80 | self.ckpt_dir = ckpt_dir 81 | self.checkpointer = ocp.StandardCheckpointer() 82 | 83 | @functools.partial(jax.jit, 
static_argnames=("self")) 84 | def _sample_action(self, 85 | params: FrozenDict, 86 | rng: Any, 87 | observation: np.ndarray) -> jnp.ndarray: 88 | mean_action, sampled_action, _ = self.actor.apply({"params": params}, 89 | rng, 90 | observation) 91 | return mean_action, sampled_action 92 | 93 | def sample_action(self, 94 | observation: np.ndarray, 95 | eval_mode: bool = False) -> np.ndarray: 96 | self.rng, sample_rng = jax.random.split(self.rng) 97 | mean_action, sampled_action = self._sample_action( 98 | self.actor_state.params, sample_rng, observation) 99 | action = mean_action if eval_mode else sampled_action 100 | action = np.asarray(action) 101 | return action.clip(-self.max_action, self.max_action) 102 | 103 | def actor_alpha_train_step(self, batch: Batch, key: Any, 104 | alpha_state: train_state.TrainState, 105 | actor_state: train_state.TrainState, 106 | critic_state: train_state.TrainState): 107 | 108 | frozen_critic_params = critic_state.params 109 | 110 | def loss_fn(alpha_params: FrozenDict, actor_params: FrozenDict, 111 | rng: Any, observation: jnp.ndarray): 112 | # sample action with actor 113 | _, sampled_action, logp = self.actor.apply( 114 | {"params": actor_params}, rng, observation) 115 | 116 | # compute alpha loss 117 | log_alpha = self.log_alpha.apply({"params": alpha_params}) 118 | alpha = jnp.exp(log_alpha) 119 | alpha_loss = -alpha * jax.lax.stop_gradient( 120 | logp + self.target_entropy) 121 | 122 | # stop alpha gradient 123 | alpha = jax.lax.stop_gradient(alpha) 124 | 125 | # We use frozen_params so that gradients can flow back to the actor without being used to update the critic. 126 | sampled_q1, sampled_q2 = self.critic.apply( 127 | {"params": frozen_critic_params}, observation, sampled_action) 128 | sampled_q = jnp.minimum(sampled_q1, sampled_q2) 129 | 130 | # Actor loss 131 | actor_loss = alpha * logp - sampled_q 132 | 133 | # return info 134 | actor_alpha_loss = actor_loss + alpha_loss 135 | log_info = { 136 | "actor_loss": actor_loss, 137 | "alpha_loss": alpha_loss, 138 | "alpha": alpha, 139 | "logp": logp 140 | } 141 | return actor_alpha_loss, log_info 142 | 143 | # compute gradient with vmap 144 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, 145 | argnums=(0, 1), 146 | has_aux=True), 147 | in_axes=(None, None, 0, 0)) 148 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 149 | 150 | (_, log_info), grads = grad_fn(alpha_state.params, actor_state.params, 151 | keys, batch.observations) 152 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), grads) 153 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), log_info) 154 | 155 | # Update TrainState 156 | alpha_grads, actor_grads = grads 157 | new_alpha_state = alpha_state.apply_gradients(grads=alpha_grads) 158 | new_actor_state = actor_state.apply_gradients(grads=actor_grads) 159 | return new_alpha_state, new_actor_state, log_info 160 | 161 | def critic_train_step(self, 162 | batch: Batch, 163 | vlm_rewards: jnp.ndarray, 164 | key: Any, 165 | alpha: float, 166 | actor_state: train_state.TrainState, 167 | critic_state: train_state.TrainState, 168 | critic_target_params: FrozenDict): 169 | 170 | frozen_actor_params = actor_state.params 171 | 172 | def loss_fn(params: FrozenDict, rng: Any, observation: jnp.ndarray, 173 | action: jnp.ndarray, reward: jnp.ndarray, 174 | next_observation: jnp.ndarray, discount: jnp.ndarray): 175 | 176 | # current q value 177 | q1, q2 = self.critic.apply({"params": params}, observation, action) 178 | 179 | # next q value 180 | 
_, next_action, logp_next_action = self.actor.apply( 181 | {"params": frozen_actor_params}, rng, next_observation) 182 | next_q1, next_q2 = self.critic.apply( 183 | {"params": critic_target_params}, next_observation, next_action) 184 | next_q = jnp.minimum(next_q1, next_q2) - alpha * logp_next_action 185 | 186 | # target q value 187 | target_q = reward + self.gamma * discount * next_q 188 | 189 | # td error 190 | critic_loss1 = (q1 - target_q)**2 191 | critic_loss2 = (q2 - target_q)**2 192 | critic_loss = critic_loss1 + critic_loss2 193 | log_info = { 194 | "critic_loss": critic_loss, 195 | "q": q1, 196 | } 197 | return critic_loss, log_info 198 | 199 | # compute gradient with vmap 200 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, has_aux=True), 201 | in_axes=(None, 0, 0, 0, 0, 0, 0)) 202 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 203 | 204 | # reward shaping 205 | rewards = batch.rewards + self.rho * vlm_rewards 206 | (_, log_info), grads = grad_fn(critic_state.params, 207 | keys, 208 | batch.observations, 209 | batch.actions, 210 | rewards, 211 | batch.next_observations, 212 | batch.discounts) 213 | extra_log_info = {"q_max": log_info["q"].max(), "rvlm_reward": rewards.mean()} 214 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), grads) 215 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), log_info) 216 | log_info.update(extra_log_info) 217 | 218 | # Update TrainState 219 | new_critic_state = critic_state.apply_gradients(grads=grads) 220 | new_critic_target_params = target_update(new_critic_state.params, 221 | critic_target_params, 222 | self.tau) 223 | return new_critic_state, new_critic_target_params, log_info 224 | 225 | @functools.partial(jax.jit, static_argnames=("self")) 226 | def train_step(self, 227 | batch: Batch, 228 | vlm_rewards: jnp.ndarray, 229 | key: Any, 230 | alpha_state: train_state.TrainState, 231 | actor_state: train_state.TrainState, 232 | critic_state: train_state.TrainState, 233 | critic_target_params: FrozenDict): 234 | 235 | key1, key2 = jax.random.split(key) 236 | new_alpha_state, new_actor_state, actor_log_info = self.actor_alpha_train_step( 237 | batch, key1, alpha_state, actor_state, critic_state) 238 | alpha = actor_log_info["alpha"] 239 | new_critic_state, new_critic_target_params, critic_log_info = self.critic_train_step( 240 | batch, vlm_rewards, key2, alpha, actor_state, critic_state, 241 | critic_target_params) 242 | log_info = {**actor_log_info, **critic_log_info} 243 | return new_alpha_state, new_actor_state, new_critic_state, \ 244 | new_critic_target_params, log_info 245 | 246 | def update(self, batch: Batch, vlm_rewards: jnp.ndarray): 247 | self.rng, key = jax.random.split(self.rng, 2) 248 | (self.alpha_state, 249 | self.actor_state, 250 | self.critic_state, 251 | self.critic_target_params, 252 | log_info) = self.train_step(batch, 253 | vlm_rewards, 254 | key, 255 | self.alpha_state, 256 | self.actor_state, 257 | self.critic_state, 258 | self.critic_target_params) 259 | return log_info 260 | 261 | def save(self, cnt: int = 0): 262 | params = {"actor": self.actor_state.params, 263 | "critic": self.critic_state.params} 264 | self.checkpointer.save(f"{self.ckpt_dir}/{cnt}", 265 | params, 266 | force=True) 267 | 268 | def load(self, ckpt_dir: str, cnt: int = 0): 269 | raw_restored = self.checkpointer.restore(f"{ckpt_dir}/{cnt}") 270 | actor_params = raw_restored["actor"] 271 | critic_params = raw_restored["critic"] 272 | 273 | self.actor_state = train_state.TrainState.create( 274 
| apply_fn=self.actor.apply, 275 | params=actor_params, 276 | tx=optax.adam(self.lr)) 277 | self.critic_state = train_state.TrainState.create( 278 | apply_fn=self.critic.apply, 279 | params=critic_params, 280 | tx=optax.adam(self.lr)) 281 | self.critic_target_params = critic_params 282 | -------------------------------------------------------------------------------- /models/liv.py: -------------------------------------------------------------------------------- 1 | import clip 2 | from clip.model import CLIP 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torchvision.transforms as T 7 | 8 | 9 | class LIV(nn.Module): 10 | def __init__(self, 11 | modelid="RN50", 12 | device="cuda", 13 | lr=1e-5, 14 | weight_decay=0.001, 15 | visionweight=1.0, 16 | langweight=1.0, 17 | clipweight=1.0, 18 | gamma=0.98, 19 | metric="cos", 20 | num_negatives=0, 21 | grad_text=True, 22 | scratch=False): 23 | super().__init__() 24 | 25 | self.modelid = modelid 26 | self.device = device 27 | self.visionweight = visionweight 28 | self.langweight = langweight 29 | self.clipweight = clipweight 30 | 31 | self.gamma = gamma 32 | self.num_negatives = num_negatives 33 | self.metric = metric 34 | self.grad_text = grad_text 35 | 36 | # Load CLIP model and transform 37 | model, cliptransforms = clip.load(modelid, device=self.device, scratch=scratch, jit=False) 38 | 39 | # CLIP precision 40 | if device == "cpu": 41 | model.float() 42 | else : 43 | clip.model.convert_weights(model) 44 | 45 | self.model = model 46 | self.model.train() 47 | self.transforms = cliptransforms 48 | 49 | self.transforms_tensor = nn.Sequential( 50 | T.Resize(self.model.visual.input_resolution, antialias=None), 51 | T.CenterCrop(self.model.visual.input_resolution), 52 | T.Normalize((0.48145466, 0.4578275, 0.40821073), 53 | (0.26862954, 0.26130258, 0.27577711)) 54 | ) 55 | 56 | self.output_dim = self.model.visual.output_dim 57 | 58 | # Optimizer 59 | self.encoder_opt = torch.optim.Adam(list(self.model.parameters()), 60 | lr=lr, betas=(0.9,0.98),eps=1e-6, weight_decay=weight_decay) 61 | 62 | ## Forward Call (im --> representation) 63 | def forward(self, input, modality="vision", normalize=True): 64 | if modality == "vision": 65 | if torch.max(input) > 10.0: input = input / 255.0 66 | input = self.transforms_tensor(input).to(self.device) 67 | features = self.model.encode_image(input) 68 | elif modality == "text": 69 | b_token = input 70 | if self.grad_text: 71 | features = self.model.encode_text(b_token) 72 | else: 73 | with torch.no_grad(): 74 | features = self.model.encode_text(b_token) 75 | else: 76 | raise NotImplementedError 77 | 78 | return features 79 | 80 | def sim(self, tensor1, tensor2): 81 | if type(tensor1) == np.ndarray: 82 | tensor1 = torch.from_numpy(tensor1).to(self.device) 83 | tensor2 = torch.from_numpy(tensor2).to(self.device) 84 | if self.metric == 'l2': 85 | d = -torch.linalg.norm(tensor1 - tensor2, dim = -1) 86 | elif self.metric == 'cos': 87 | tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True) 88 | tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True) 89 | d = torch.nn.CosineSimilarity(-1)(tensor1, tensor2) 90 | else: 91 | raise NotImplementedError 92 | return d 93 | 94 | def get_reward(self, e0, es, le, encoded=True): 95 | assert encoded == True 96 | return self.sim(es, le) 97 | -------------------------------------------------------------------------------- /models/projection.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy 
as jnp 3 | import functools 4 | import optax 5 | import orbax.checkpoint as ocp 6 | 7 | from flax import linen as nn 8 | from flax.training import train_state 9 | from models import MLP 10 | 11 | 12 | class Projection(nn.Module): 13 | def setup(self): 14 | self.text_encoder = MLP(hidden_dims=(256, 64), activate_final=False) 15 | self.image_encoder = MLP(hidden_dims=(256, 64), activate_final=False) 16 | 17 | def __call__(self, text_embedding, image_embedding): 18 | proj_text_embedding = self.text_encoder(text_embedding) 19 | proj_image_embedding = self.image_encoder(image_embedding) 20 | return proj_text_embedding, proj_image_embedding 21 | 22 | def encode_image(self, image_embeddings): 23 | return self.image_encoder(image_embeddings) 24 | 25 | def encode_text(self, text_embedding): 26 | return self.text_encoder(text_embedding) 27 | 28 | 29 | class RewardModel: 30 | def __init__(self, 31 | seed: int = 42, 32 | lr: float = 1e-4, 33 | margin: float = 0.1, 34 | emb_dim: int = 1024, 35 | ckpt_dir: str = None, 36 | text_embedding: jnp.ndarray = None, 37 | goal_embedding: jnp.ndarray = None): 38 | self.lr = lr 39 | self.margin = margin 40 | self.text_embedding = text_embedding 41 | self.goal_embedding = goal_embedding 42 | self.rng = jax.random.PRNGKey(seed) 43 | self.rng, key = jax.random.split(self.rng, 2) 44 | dummy_emb = jnp.ones([1, emb_dim], dtype=jnp.float32) 45 | 46 | self.proj = Projection() 47 | proj_params = self.proj.init(key, 48 | jnp.ones([1, 1024], dtype=jnp.float32), 49 | dummy_emb)["params"] 50 | self.proj_state = train_state.TrainState.create( 51 | apply_fn=self.proj.apply, 52 | params=proj_params, 53 | tx=optax.adam(lr)) 54 | 55 | if ckpt_dir is not None: 56 | self.ckpt_dir = ckpt_dir 57 | self.checkpointer = ocp.StandardCheckpointer() 58 | 59 | @functools.partial(jax.jit, static_argnames=("self")) 60 | def get_vlm_reward(self, proj_state, img_embeddings): 61 | proj_img_embeddings = self.proj.apply( 62 | {"params": proj_state.params}, img_embeddings, 63 | method=self.proj.encode_image) 64 | proj_text_embedding = self.proj.apply( 65 | {"params": proj_state.params}, self.text_embedding, 66 | method=self.proj.encode_text) 67 | cosine_similarity = optax.cosine_similarity(proj_img_embeddings, 68 | proj_text_embedding) 69 | return cosine_similarity 70 | 71 | @functools.partial(jax.jit, static_argnames=("self")) 72 | def train_pos_step(self, 73 | pos_embeddings, 74 | neg_embeddings, 75 | lag_embeddings, 76 | proj_state): 77 | def loss_fn(params): 78 | proj_text_embedding = self.proj.apply( 79 | {"params": params}, self.text_embedding, 80 | method=self.proj.encode_text) 81 | 82 | proj_pos_embeddings = self.proj.apply( 83 | {"params": params}, pos_embeddings, 84 | method=self.proj.encode_image) 85 | proj_neg_embeddings = self.proj.apply( 86 | {"params": params}, neg_embeddings, 87 | method=self.proj.encode_image) 88 | proj_lag_embeddings = self.proj.apply( 89 | {"params": params}, lag_embeddings, 90 | method=self.proj.encode_image) 91 | 92 | pos_cosine = optax.cosine_similarity(proj_text_embedding, 93 | proj_pos_embeddings) 94 | neg_cosine = optax.cosine_similarity(proj_text_embedding, 95 | proj_neg_embeddings) 96 | lag_cosine = optax.cosine_similarity(proj_text_embedding, 97 | proj_lag_embeddings) 98 | 99 | # pos-neg: pos_cosine > lag_cosine > negative_cosine 100 | neg_mask = (neg_cosine - pos_cosine + self.margin) > 0 101 | neg_loss = neg_mask * (neg_cosine - pos_cosine) 102 | 103 | # pos-pos: pos_cosine > lag_cosine 104 | pos_mask = (lag_cosine - pos_cosine + self.margin) > 0 105 | 
pos_loss = pos_mask * (lag_cosine - pos_cosine) 106 | total_loss = pos_loss.mean() + neg_loss.mean() 107 | log_info = { 108 | "pos_cosine": pos_cosine.mean(), 109 | "pos_cosine_max": pos_cosine.max(), 110 | "pos_cosine_min": pos_cosine.min(), 111 | 112 | "neg_cosine": neg_cosine.mean(), 113 | "neg_cosine_max": neg_cosine.max(), 114 | "neg_cosine_min": neg_cosine.min(), 115 | 116 | "lag_cosine": lag_cosine.mean(), 117 | "lag_cosine_max": lag_cosine.max(), 118 | "lag_cosine_min": lag_cosine.min(), 119 | 120 | "neg_num": neg_mask.sum(), 121 | "neg_loss": neg_loss.mean(), 122 | "neg_loss_max": neg_loss.max(), 123 | 124 | "pos_num": pos_mask.sum(), 125 | "pos_loss": pos_loss.mean(), 126 | "pos_loss_max": pos_loss.max(), 127 | } 128 | return total_loss, log_info 129 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 130 | (_, log_info), grad = grad_fn(proj_state.params) 131 | new_proj_state = proj_state.apply_gradients(grads=grad) 132 | return new_proj_state, log_info 133 | 134 | @functools.partial(jax.jit, static_argnames=("self")) 135 | def train_neg_step(self, 136 | batch, 137 | proj_state): 138 | def loss_fn(params): 139 | proj_text_embedding = self.proj.apply( 140 | {"params": params}, self.text_embedding, 141 | method=self.proj.encode_text) 142 | 143 | proj_embeddings = self.proj.apply( 144 | {"params": params}, batch.embeddings, 145 | method=self.proj.encode_image) 146 | 147 | # cosine similarity 148 | cosine = optax.cosine_similarity(proj_text_embedding, proj_embeddings) 149 | cosine_delta = cosine.reshape(-1, 1) - cosine.reshape(1, -1) 150 | 151 | loss = (nn.relu(-cosine_delta + self.margin) * batch.masks).sum(-1).mean() 152 | log_info = {"pos_loss": loss, "vlm_rewards": cosine} 153 | return loss, log_info 154 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 155 | (_, log_info), grad = grad_fn(proj_state.params) 156 | new_proj_state = proj_state.apply_gradients(grads=grad) 157 | return new_proj_state, log_info 158 | 159 | def update_neg(self, batch): 160 | self.proj_state, log_info = self.train_neg_step(batch, self.proj_state) 161 | return log_info 162 | 163 | def update_pos(self, batch): 164 | self.proj_state, log_info = self.train_pos_step(batch.pos_embeddings, 165 | batch.neg_embeddings, 166 | batch.lag_embeddings, 167 | self.proj_state) 168 | return log_info 169 | 170 | def save(self, cnt): 171 | self.checkpointer.save(f"{self.ckpt_dir}/{cnt}", 172 | {"proj": self.proj_state.params}, 173 | force=True) 174 | 175 | def load(self, ckpt_dir: str, cnt: int = 0): 176 | raw_restored = self.checkpointer.restore(f"{ckpt_dir}/{cnt}") 177 | proj_params = raw_restored["proj"] 178 | self.proj_state = train_state.TrainState.create( 179 | apply_fn=self.proj.apply, 180 | params=proj_params, 181 | tx=optax.adam(self.lr)) 182 | -------------------------------------------------------------------------------- /models/sac.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Sequence 2 | import os 3 | import functools 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import optax 8 | import orbax.checkpoint as ocp 9 | from flax.core import FrozenDict 10 | from flax.training import train_state 11 | from utils import Batch, target_update 12 | 13 | from models import Actor, DoubleCritic, Scalar 14 | 15 | 16 | ##################### 17 | # Soft Actor-Critic # 18 | ##################### 19 | class SACAgent: 20 | 21 | def __init__(self, 22 | obs_dim: int, 23 | act_dim: int, 24 | max_action: float = 1.0, 25 | seed: int 
= 42, 26 | tau: float = 0.005, 27 | gamma: float = 0.99, 28 | lr: float = 3e-4, 29 | hidden_dims: Sequence[int] = (256, 256), 30 | ckpt_dir: str = None): 31 | 32 | self.lr = lr 33 | self.tau = tau 34 | self.gamma = gamma 35 | self.act_dim = act_dim 36 | self.max_action = max_action 37 | self.target_entropy = -act_dim / 2 38 | 39 | self.rng = jax.random.PRNGKey(seed) 40 | self.rng, actor_key, critic_key = jax.random.split(self.rng, 3) 41 | 42 | # Dummy inputs 43 | self.dummy_obs = jnp.ones([1, obs_dim], dtype=jnp.float32) 44 | dummy_act = jnp.ones([1, act_dim], dtype=jnp.float32) 45 | 46 | # Create the optimizer 47 | actor_tx = optax.adam(lr) 48 | critic_tx = optax.adam(lr) 49 | 50 | # Initialize the Actor 51 | self.actor = Actor(act_dim=act_dim, 52 | max_action=max_action, 53 | hidden_dims=hidden_dims) 54 | actor_params = self.actor.init(actor_key, actor_key, self.dummy_obs)["params"] 55 | self.actor_state = train_state.TrainState.create( 56 | apply_fn=self.actor.apply, params=actor_params, tx=actor_tx) 57 | 58 | # Initialize the Critic 59 | self.critic = DoubleCritic(hidden_dims=hidden_dims) 60 | critic_params = self.critic.init(critic_key, self.dummy_obs, dummy_act)["params"] 61 | self.critic_target_params = critic_params 62 | self.critic_state = train_state.TrainState.create( 63 | apply_fn=self.critic.apply, params=critic_params, tx=critic_tx) 64 | 65 | # Entropy tuning 66 | self.rng, alpha_key = jax.random.split(self.rng, 2) 67 | self.log_alpha = Scalar(0.0) 68 | self.alpha_state = train_state.TrainState.create( 69 | apply_fn=None, 70 | params=self.log_alpha.init(alpha_key)["params"], 71 | tx=optax.adam(lr)) 72 | 73 | # Checkpoint 74 | if ckpt_dir is not None: 75 | self.ckpt_dir = ckpt_dir 76 | self.checkpointer = ocp.StandardCheckpointer() 77 | 78 | @functools.partial(jax.jit, static_argnames=("self")) 79 | def _sample_action(self, params: FrozenDict, rng: Any, 80 | observation: np.ndarray) -> jnp.ndarray: 81 | mean_action, sampled_action, _ = self.actor.apply({"params": params}, 82 | rng, observation) 83 | return mean_action, sampled_action 84 | 85 | def sample_action(self, 86 | observation: np.ndarray, 87 | eval_mode: bool = False) -> np.ndarray: 88 | # for deterministic result 89 | if eval_mode: 90 | sample_rng = self.rng 91 | else: 92 | self.rng, sample_rng = jax.random.split(self.rng) 93 | mean_action, sampled_action = self._sample_action( 94 | self.actor_state.params, sample_rng, observation) 95 | action = mean_action if eval_mode else sampled_action 96 | action = np.asarray(action) 97 | return action.clip(-self.max_action, self.max_action) 98 | 99 | def actor_alpha_train_step(self, 100 | batch: Batch, 101 | key: Any, 102 | alpha_state: train_state.TrainState, 103 | actor_state: train_state.TrainState, 104 | critic_state: train_state.TrainState): 105 | 106 | frozen_critic_params = critic_state.params 107 | 108 | def loss_fn(alpha_params: FrozenDict, 109 | actor_params: FrozenDict, 110 | rng: Any, 111 | observation: jnp.ndarray): 112 | # sample action with actor 113 | _, sampled_action, logp = self.actor.apply( 114 | {"params": actor_params}, rng, observation) 115 | 116 | # compute alpha loss 117 | log_alpha = self.log_alpha.apply({"params": alpha_params}) 118 | alpha = jnp.exp(log_alpha) 119 | alpha_loss = -alpha * jax.lax.stop_gradient(logp + 120 | self.target_entropy) 121 | 122 | # stop alpha gradient 123 | alpha = jax.lax.stop_gradient(alpha) 124 | 125 | # We use frozen_params so that gradients can flow back to the actor without being used to update the critic. 
126 | sampled_q1, sampled_q2 = self.critic.apply( 127 | {"params": frozen_critic_params}, observation, sampled_action) 128 | sampled_q = jnp.minimum(sampled_q1, sampled_q2) 129 | 130 | # Actor loss 131 | actor_loss = alpha * logp - sampled_q 132 | 133 | # return info 134 | actor_alpha_loss = actor_loss + alpha_loss 135 | log_info = { 136 | "actor_loss": actor_loss, 137 | "alpha_loss": alpha_loss, 138 | "alpha": alpha, 139 | "logp": logp 140 | } 141 | return actor_alpha_loss, log_info 142 | 143 | # compute gradient with vmap 144 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, 145 | argnums=(0, 1), 146 | has_aux=True), 147 | in_axes=(None, None, 0, 0)) 148 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 149 | 150 | (_, log_info), grads = grad_fn(alpha_state.params, actor_state.params, 151 | keys, batch.observations) 152 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), 153 | grads) 154 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), 155 | log_info) 156 | 157 | # Update TrainState 158 | alpha_grads, actor_grads = grads 159 | new_alpha_state = alpha_state.apply_gradients(grads=alpha_grads) 160 | new_actor_state = actor_state.apply_gradients(grads=actor_grads) 161 | return new_alpha_state, new_actor_state, log_info 162 | 163 | def critic_train_step(self, 164 | batch: Batch, 165 | key: Any, 166 | alpha: float, 167 | actor_state: train_state.TrainState, 168 | critic_state: train_state.TrainState, 169 | critic_target_params: FrozenDict): 170 | 171 | frozen_actor_params = actor_state.params 172 | 173 | def loss_fn(params: FrozenDict, 174 | rng: Any, 175 | observation: jnp.ndarray, 176 | action: jnp.ndarray, 177 | reward: jnp.ndarray, 178 | next_observation: jnp.ndarray, 179 | discount: jnp.ndarray): 180 | 181 | # current q value 182 | q1, q2 = self.critic.apply({"params": params}, observation, action) 183 | 184 | # next q value 185 | _, next_action, logp_next_action = self.actor.apply( 186 | {"params": frozen_actor_params}, rng, next_observation) 187 | next_q1, next_q2 = self.critic.apply( 188 | {"params": critic_target_params}, next_observation, next_action) 189 | next_q = jnp.minimum(next_q1, next_q2) - alpha * logp_next_action 190 | 191 | # target q value 192 | target_q = reward + self.gamma * discount * next_q 193 | 194 | # td error 195 | critic_loss1 = (q1 - target_q)**2 196 | critic_loss2 = (q2 - target_q)**2 197 | critic_loss = critic_loss1 + critic_loss2 198 | log_info = { 199 | "critic_loss": critic_loss, 200 | "q": q1, 201 | } 202 | return critic_loss, log_info 203 | 204 | # compute gradient with vmap 205 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, has_aux=True), 206 | in_axes=(None, 0, 0, 0, 0, 0, 0)) 207 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 208 | 209 | (_, log_info), grads = grad_fn(critic_state.params, 210 | keys, 211 | batch.observations, 212 | batch.actions, 213 | batch.rewards, 214 | batch.next_observations, 215 | batch.discounts) 216 | extra_log_info = {"q_max": log_info["q"].max()} 217 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), grads) 218 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), log_info) 219 | log_info.update(extra_log_info) 220 | 221 | # Update TrainState 222 | new_critic_state = critic_state.apply_gradients(grads=grads) 223 | new_critic_target_params = target_update(new_critic_state.params, 224 | critic_target_params, 225 | self.tau) 226 | return new_critic_state, new_critic_target_params, log_info 227 | 228 | 
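    # Note on the update structure: actor_alpha_train_step and critic_train_step
    # above compute per-sample losses and gradients by vmapping jax.value_and_grad
    # over the batch (one PRNG key per transition via jax.random.split), then
    # average the gradient and metric pytrees with
    # jax.tree_util.tree_map(jnp.mean, ...) before applying them.
    # train_step below chains the two sub-steps inside a single jitted call, and
    # the critic target parameters are moved towards the online critic with
    # target_update(..., self.tau).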
@functools.partial(jax.jit, static_argnames=("self")) 229 | def train_step(self, 230 | batch: Batch, 231 | key: Any, 232 | alpha_state: train_state.TrainState, 233 | actor_state: train_state.TrainState, 234 | critic_state: train_state.TrainState, 235 | critic_target_params: FrozenDict): 236 | key1, key2 = jax.random.split(key) 237 | new_alpha_state, new_actor_state, actor_log_info = self.actor_alpha_train_step( 238 | batch, key1, alpha_state, actor_state, critic_state) 239 | alpha = actor_log_info["alpha"] 240 | new_critic_state, new_critic_target_params, critic_log_info = self.critic_train_step( 241 | batch, key2, alpha, actor_state, critic_state, critic_target_params) 242 | log_info = {**actor_log_info, **critic_log_info} 243 | return new_alpha_state, new_actor_state, new_critic_state, new_critic_target_params, log_info 244 | 245 | def update(self, batch: Batch): 246 | self.rng, key = jax.random.split(self.rng, 2) 247 | (self.alpha_state, 248 | self.actor_state, 249 | self.critic_state, 250 | self.critic_target_params, 251 | log_info) = self.train_step(batch, 252 | key, 253 | self.alpha_state, 254 | self.actor_state, 255 | self.critic_state, 256 | self.critic_target_params) 257 | return log_info 258 | 259 | def save(self, cnt: int = 0): 260 | params = {"actor": self.actor_state.params, 261 | "critic": self.critic_state.params} 262 | self.checkpointer.save(f"{self.ckpt_dir}/{cnt}", 263 | params, force=True) 264 | 265 | def load(self, ckpt_dir: str, cnt: int = 0): 266 | raw_restored = self.checkpointer.restore(f"{ckpt_dir}/{cnt}") 267 | actor_params = raw_restored["actor"] 268 | critic_params = raw_restored["critic"] 269 | 270 | self.actor_state = train_state.TrainState.create( 271 | apply_fn=self.actor.apply, 272 | params=actor_params, 273 | tx=optax.adam(self.lr)) 274 | self.critic_state = train_state.TrainState.create( 275 | apply_fn=self.critic.apply, 276 | params=critic_params, 277 | tx=optax.adam(self.lr)) 278 | self.critic_target_params = critic_params 279 | -------------------------------------------------------------------------------- /models/vlm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Sequence 2 | import os 3 | import functools 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import optax 8 | import orbax.checkpoint as ocp 9 | import matplotlib.pyplot as plt 10 | from flax import linen as nn 11 | from flax.core import FrozenDict 12 | from flax.training import train_state, orbax_utils 13 | from utils import Batch, target_update 14 | from models import Actor, DoubleCritic, Scalar 15 | 16 | 17 | ####################### 18 | # Without Fine-tuning # 19 | ####################### 20 | class VLMAgent: 21 | 22 | def __init__(self, 23 | obs_dim: int, 24 | act_dim: int, 25 | max_action: float = 1.0, 26 | seed: int = 42, 27 | tau: float = 0.005, 28 | rho: float = 0.1, 29 | gamma: float = 0.99, 30 | lr: float = 3e-4, 31 | text_embedding: jnp.ndarray = None, 32 | hidden_dims: Sequence[int] = (256, 256), 33 | ckpt_dir: str = None): 34 | 35 | self.lr = lr 36 | self.tau = tau 37 | self.rho = rho 38 | self.gamma = gamma 39 | self.max_action = max_action 40 | self.target_entropy = -act_dim / 2 41 | 42 | self.rng = jax.random.PRNGKey(seed) 43 | self.rng, actor_key, critic_key, W_key = jax.random.split(self.rng, 4) 44 | 45 | self.text_embedding = text_embedding 46 | 47 | # Dummy inputs 48 | dummy_obs = jnp.ones([1, obs_dim], dtype=jnp.float32) 49 | dummy_act = jnp.ones([1, act_dim], dtype=jnp.float32) 
50 | 51 | # Create the optimizer 52 | actor_tx = optax.adam(lr) 53 | critic_tx = optax.adam(lr) 54 | 55 | # Initialize the Actor 56 | self.actor = Actor(act_dim=act_dim, 57 | max_action=max_action, 58 | hidden_dims=hidden_dims) 59 | actor_params = self.actor.init(actor_key, 60 | actor_key, 61 | dummy_obs)["params"] 62 | self.actor_state = train_state.TrainState.create( 63 | apply_fn=self.actor.apply, params=actor_params, tx=actor_tx) 64 | 65 | # Initialize the Critic 66 | self.critic = DoubleCritic(hidden_dims=hidden_dims) 67 | critic_params = self.critic.init(critic_key, 68 | dummy_obs, 69 | dummy_act)["params"] 70 | self.critic_target_params = critic_params 71 | self.critic_state = train_state.TrainState.create( 72 | apply_fn=self.critic.apply, params=critic_params, tx=critic_tx) 73 | 74 | # Entropy tuning 75 | self.rng, alpha_key = jax.random.split(self.rng, 2) 76 | self.log_alpha = Scalar(0.0) 77 | self.alpha_state = train_state.TrainState.create( 78 | apply_fn=None, 79 | params=self.log_alpha.init(alpha_key)["params"], 80 | tx=optax.adam(lr)) 81 | 82 | # Checkpoint 83 | if ckpt_dir is not None: 84 | self.ckpt_dir = ckpt_dir 85 | self.checkpointer = ocp.StandardCheckpointer() 86 | 87 | @functools.partial(jax.jit, static_argnames=("self")) 88 | def get_reward(self, image_embedding): 89 | reward = optax.cosine_similarity(image_embedding, self.text_embedding) 90 | return reward 91 | 92 | @functools.partial(jax.jit, static_argnames=("self")) 93 | def _sample_action(self, params: FrozenDict, rng: Any, 94 | observation: np.ndarray) -> jnp.ndarray: 95 | mean_action, sampled_action, _ = self.actor.apply({"params": params}, 96 | rng, observation) 97 | return mean_action, sampled_action 98 | 99 | def sample_action(self, 100 | observation: np.ndarray, 101 | eval_mode: bool = False) -> np.ndarray: 102 | self.rng, sample_rng = jax.random.split(self.rng) 103 | mean_action, sampled_action = self._sample_action( 104 | self.actor_state.params, sample_rng, observation) 105 | action = mean_action if eval_mode else sampled_action 106 | action = np.asarray(action) 107 | return action.clip(-self.max_action, self.max_action) 108 | 109 | def actor_alpha_train_step(self, batch: Batch, key: Any, 110 | alpha_state: train_state.TrainState, 111 | actor_state: train_state.TrainState, 112 | critic_state: train_state.TrainState): 113 | 114 | frozen_critic_params = critic_state.params 115 | 116 | def loss_fn(alpha_params: FrozenDict, actor_params: FrozenDict, 117 | rng: Any, observation: jnp.ndarray): 118 | # sample action with actor 119 | _, sampled_action, logp = self.actor.apply( 120 | {"params": actor_params}, rng, observation) 121 | 122 | # compute alpha loss 123 | log_alpha = self.log_alpha.apply({"params": alpha_params}) 124 | alpha = jnp.exp(log_alpha) 125 | alpha_loss = -alpha * jax.lax.stop_gradient( 126 | logp + self.target_entropy) 127 | 128 | # stop alpha gradient 129 | alpha = jax.lax.stop_gradient(alpha) 130 | 131 | # We use frozen_params so that gradients can flow back to the actor without being used to update the critic. 
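# The block below is the standard SAC actor objective: query both critics on the freshly sampled action, take the element-wise minimum (clipped double Q-learning), and minimize alpha * logp - min_q, trading off expected return against policy entropy.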
132 | sampled_q1, sampled_q2 = self.critic.apply( 133 | {"params": frozen_critic_params}, observation, sampled_action) 134 | sampled_q = jnp.minimum(sampled_q1, sampled_q2) 135 | 136 | # Actor loss 137 | actor_loss = alpha * logp - sampled_q 138 | 139 | # return info 140 | actor_alpha_loss = actor_loss + alpha_loss 141 | log_info = { 142 | "actor_loss": actor_loss, 143 | "alpha_loss": alpha_loss, 144 | "alpha": alpha, 145 | "logp": logp 146 | } 147 | return actor_alpha_loss, log_info 148 | 149 | # compute gradient with vmap 150 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, 151 | argnums=(0, 1), 152 | has_aux=True), 153 | in_axes=(None, None, 0, 0)) 154 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 155 | 156 | (_, log_info), grads = grad_fn(alpha_state.params, actor_state.params, 157 | keys, batch.observations) 158 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), grads) 159 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), log_info) 160 | 161 | # Update TrainState 162 | alpha_grads, actor_grads = grads 163 | new_alpha_state = alpha_state.apply_gradients(grads=alpha_grads) 164 | new_actor_state = actor_state.apply_gradients(grads=actor_grads) 165 | return new_alpha_state, new_actor_state, log_info 166 | 167 | def critic_train_step(self, batch: Batch, key: Any, alpha: float, 168 | actor_state: train_state.TrainState, 169 | critic_state: train_state.TrainState, 170 | critic_target_params: FrozenDict): 171 | 172 | frozen_actor_params = actor_state.params 173 | 174 | def loss_fn(params: FrozenDict, rng: Any, observation: jnp.ndarray, 175 | action: jnp.ndarray, reward: jnp.ndarray, 176 | next_observation: jnp.ndarray, discount: jnp.ndarray): 177 | 178 | # current q value 179 | q1, q2 = self.critic.apply({"params": params}, observation, action) 180 | 181 | # next q value 182 | _, next_action, logp_next_action = self.actor.apply( 183 | {"params": frozen_actor_params}, rng, next_observation) 184 | next_q1, next_q2 = self.critic.apply( 185 | {"params": critic_target_params}, next_observation, next_action) 186 | next_q = jnp.minimum(next_q1, next_q2) - alpha * logp_next_action 187 | 188 | # target q value 189 | target_q = reward + self.gamma * discount * next_q 190 | 191 | # td error 192 | critic_loss1 = (q1 - target_q)**2 193 | critic_loss2 = (q2 - target_q)**2 194 | critic_loss = critic_loss1 + critic_loss2 195 | log_info = { 196 | "critic_loss": critic_loss, 197 | "q": q1, 198 | } 199 | return critic_loss, log_info 200 | 201 | # compute gradient with vmap 202 | grad_fn = jax.vmap(jax.value_and_grad(loss_fn, has_aux=True), 203 | in_axes=(None, 0, 0, 0, 0, 0, 0)) 204 | keys = jnp.stack(jax.random.split(key, num=batch.actions.shape[0])) 205 | 206 | # reward shaping 207 | rewards = batch.rewards + self.rho * batch.vlm_rewards 208 | (_, log_info), grads = grad_fn(critic_state.params, 209 | keys, 210 | batch.observations, 211 | batch.actions, 212 | rewards, 213 | batch.next_observations, 214 | batch.discounts) 215 | extra_log_info = {"q_max": log_info["q"].max()} 216 | grads = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), grads) 217 | log_info = jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), log_info) 218 | log_info.update(extra_log_info) 219 | 220 | # Update TrainState 221 | new_critic_state = critic_state.apply_gradients(grads=grads) 222 | new_critic_target_params = target_update(new_critic_state.params, 223 | critic_target_params, 224 | self.tau) 225 | return new_critic_state, new_critic_target_params, 
log_info 226 | 227 | @functools.partial(jax.jit, static_argnames=("self")) 228 | def train_step(self, batch: Batch, key: Any, 229 | alpha_state: train_state.TrainState, 230 | actor_state: train_state.TrainState, 231 | critic_state: train_state.TrainState, 232 | critic_target_params: FrozenDict): 233 | key1, key2 = jax.random.split(key) 234 | new_alpha_state, new_actor_state, actor_log_info = self.actor_alpha_train_step( 235 | batch, key1, alpha_state, actor_state, critic_state) 236 | alpha = actor_log_info["alpha"] 237 | new_critic_state, new_critic_target_params, critic_log_info = self.critic_train_step( 238 | batch, key2, alpha, actor_state, critic_state, 239 | critic_target_params) 240 | log_info = {**actor_log_info, **critic_log_info} 241 | return new_alpha_state, new_actor_state, new_critic_state, new_critic_target_params, log_info 242 | 243 | def update(self, batch: Batch): 244 | self.rng, key = jax.random.split(self.rng, 2) 245 | (self.alpha_state, 246 | self.actor_state, 247 | self.critic_state, 248 | self.critic_target_params, 249 | log_info) = self.train_step(batch, 250 | key, 251 | self.alpha_state, 252 | self.actor_state, 253 | self.critic_state, 254 | self.critic_target_params) 255 | return log_info 256 | 257 | def save(self, cnt: int = 0): 258 | params = {"actor": self.actor_state.params, 259 | "critic": self.critic_state.params} 260 | self.checkpointer.save(f"{self.ckpt_dir}/{cnt}", params, force=True) 261 | 262 | def load(self, ckpt_dir: str, cnt: int = 0): 263 | raw_restored = self.checkpointer.restore(f"{ckpt_dir}/{cnt}") 264 | actor_params = raw_restored["actor"] 265 | critic_params = raw_restored["critic"] 266 | self.actor_state = train_state.TrainState.create( 267 | apply_fn=self.actor.apply, 268 | params=actor_params, 269 | tx=optax.adam(self.lr)) 270 | self.critic_state = train_state.TrainState.create( 271 | apply_fn=self.critic.apply, 272 | params=critic_params, 273 | tx=optax.adam(self.lr)) 274 | self.critic_target_params = critic_params 275 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | antlr4-python3-runtime==4.9.3 3 | asttokens==2.4.1 4 | beautifulsoup4==4.12.3 5 | certifi==2024.2.2 6 | charset-normalizer==3.3.2 7 | chex==0.1.85 8 | cloudpickle==3.0.0 9 | contextlib2==21.6.0 10 | contourpy==1.2.0 11 | cycler==0.12.1 12 | decorator==4.4.2 13 | distrax==0.1.5 14 | dm-tree==0.1.8 15 | etils==1.7.0 16 | executing==2.0.1 17 | Farama-Notifications==0.0.4 18 | filelock==3.13.1 19 | flax==0.7.4 20 | fonttools==4.49.0 21 | fsspec==2024.2.0 22 | ftfy==6.1.3 23 | gast==0.5.4 24 | gdown==5.1.0 25 | gitdb==4.0.11 26 | GitPython==3.1.42 27 | glfw==2.7.0 28 | gymnasium==0.29.1 29 | gymnasium-robotics==1.2.4 30 | huggingface-hub==0.21.4 31 | hydra-core==1.3.2 32 | idna==3.6 33 | imageio==2.34.0 34 | imageio-ffmpeg==0.4.9 35 | importlib_resources==6.1.3 36 | ipython==8.22.2 37 | jax==0.4.16 38 | jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.16+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl#sha256=065f61497f1c2d75ff53e05a6d0d92060b9f4ac487ee16211a46121c6c245767 39 | jedi==0.19.1 40 | Jinja2==3.1.3 41 | kiwisolver==1.4.5 42 | markdown-it-py==3.0.0 43 | MarkupSafe==2.1.5 44 | matplotlib==3.8.3 45 | matplotlib-inline==0.1.6 46 | mdurl==0.1.2 47 | metaworld @ git+https://github.com/Farama-Foundation/Metaworld.git@c822f28f582ba1ad49eb5dcf61016566f28003ba 48 | ml-collections==0.1.1 49 | ml-dtypes==0.3.2 50 | moviepy==1.0.3 51 | mpmath==1.3.0 52 | 
msgpack==1.0.8 53 | mujoco==2.3.7 54 | nest-asyncio==1.6.0 55 | networkx==3.2.1 56 | numpy==1.26.4 57 | nvidia-cublas-cu12==12.1.3.1 58 | nvidia-cuda-cupti-cu12==12.1.105 59 | nvidia-cuda-nvrtc-cu12==12.1.105 60 | nvidia-cuda-runtime-cu12==12.1.105 61 | nvidia-cudnn-cu12==8.9.2.26 62 | nvidia-cufft-cu12==11.0.2.54 63 | nvidia-curand-cu12==10.3.2.106 64 | nvidia-cusolver-cu12==11.4.5.107 65 | nvidia-cusparse-cu12==12.1.0.106 66 | nvidia-nccl-cu12==2.19.3 67 | nvidia-nvjitlink-cu12==12.4.99 68 | nvidia-nvtx-cu12==12.1.105 69 | omegaconf==2.3.0 70 | open-clip-torch==2.24.0 71 | opencv-python==4.9.0.80 72 | opt-einsum==3.3.0 73 | optax==0.2.1 74 | orbax-checkpoint==0.5.3 75 | packaging==23.2 76 | pandas==2.2.1 77 | parso==0.8.3 78 | pettingzoo==1.24.3 79 | pexpect==4.9.0 80 | pillow==10.2.0 81 | proglog==0.1.10 82 | prompt-toolkit==3.0.43 83 | protobuf==4.25.3 84 | ptyprocess==0.7.0 85 | pure-eval==0.2.2 86 | Pygments==2.17.2 87 | PyOpenGL==3.1.7 88 | pyparsing==3.1.2 89 | PySocks==1.7.1 90 | python-dateutil==2.9.0.post0 91 | pytz==2024.1 92 | PyYAML==6.0.1 93 | regex==2023.12.25 94 | requests==2.31.0 95 | rich==13.7.1 96 | safetensors==0.4.2 97 | scipy==1.12.0 98 | sentencepiece==0.2.0 99 | six==1.16.0 100 | smmap==5.0.1 101 | soupsieve==2.5 102 | stack-data==0.6.3 103 | sympy==1.12 104 | tensorflow-probability==0.23.0 105 | tensorstore==0.1.54 106 | timm==0.9.16 107 | toolz==0.12.1 108 | torch==2.2.1 109 | torchvision==0.17.1 110 | tqdm==4.66.2 111 | traitlets==5.14.1 112 | triton==2.2.0 113 | typing_extensions==4.10.0 114 | tzdata==2024.1 115 | urllib3==2.2.1 116 | wcwidth==0.2.13 117 | zipp==3.17.0 -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Script to reproduce results 5 | envs=( 6 | "door-open-v2-goal-observable" 7 | # "button-press-topdown-v2-goal-observable" 8 | # "window-close-v2-goal-observable" 9 | # "drawer-open-v2-goal-observable" 10 | # "reach-v2-goal-observable" 11 | 12 | # "door-open-v2-goal-hidden" 13 | # "button-press-topdown-v2-goal-hidden" 14 | # "window-close-v2-goal-hidden" 15 | # "drawer-open-v2-goal-hidden" 16 | # "reach-v2-goal-hidden" 17 | ) 18 | 19 | 20 | for seed in 42 21 | do 22 | for env in ${envs[*]} 23 | do 24 | python main.py \ 25 | --config.env_name=$env \ 26 | --config.exp_name=furl \ 27 | --config.seed=$seed \ 28 | --config.rho=0.05 \ 29 | --config.gap=10 \ 30 | --config.expl_noise=0.2 \ 31 | --config.embed_buffer_size=20000 32 | done 33 | done 34 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .env_utils import TASKS, make_env 2 | from .liv_utils import load_liv 3 | from .train_utils import get_logger, target_update, log_git 4 | from .buffer_utils import Batch, ReplayBuffer, VLMBuffer, DistanceBuffer, EmbeddingBuffer 5 | -------------------------------------------------------------------------------- /utils/buffer_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import jax 3 | import logging 4 | import numpy as np 5 | from flax.core import FrozenDict 6 | 7 | 8 | # basic batch 9 | Batch = collections.namedtuple( 10 | "Batch", 11 | ["observations", "actions", "rewards", "discounts", "next_observations"]) 12 | 13 | 14 | FinetuneBatch = collections.namedtuple( 15 | "FinetuneBatch", 16 | 
["observations", "actions", "rewards", "discounts", "next_observations", "embeddings"]) 17 | 18 | 19 | MaskBatch = collections.namedtuple( 20 | "MaskBatch", 21 | ["observations", "actions", "rewards", "discounts", "next_observations", "embeddings", "masks"]) 22 | 23 | 24 | VLMBatch = collections.namedtuple( 25 | "VLMBatch", 26 | ["observations", "actions", "rewards", "vlm_rewards", "discounts", "next_observations"]) 27 | 28 | 29 | EmbeddingBatch = collections.namedtuple( 30 | "EmbeddingBatch", 31 | ["pos_embeddings", "neg_embeddings", "lag_embeddings"]) 32 | 33 | 34 | class ReplayBuffer: 35 | 36 | def __init__(self, obs_dim: int, act_dim: int, max_size: int = int(1e6)): 37 | self.max_size = max_size 38 | self.ptr = 0 39 | self.size = 0 40 | 41 | self.observations = np.zeros((max_size, obs_dim)) 42 | self.actions = np.zeros((max_size, act_dim)) 43 | self.next_observations = np.zeros((max_size, obs_dim)) 44 | self.rewards = np.zeros(max_size) 45 | self.discounts = np.zeros(max_size) 46 | 47 | def add(self, 48 | observation: np.ndarray, 49 | action: np.ndarray, 50 | next_observation: np.ndarray, 51 | reward: float, 52 | done: float): 53 | self.observations[self.ptr] = observation 54 | self.actions[self.ptr] = action 55 | self.next_observations[self.ptr] = next_observation 56 | self.rewards[self.ptr] = reward 57 | self.discounts[self.ptr] = 1 - done 58 | 59 | self.ptr = (self.ptr + 1) % self.max_size 60 | self.size = min(self.size + 1, self.max_size) 61 | 62 | def sample(self, batch_size: int) -> Batch: 63 | idx = np.random.randint(0, self.size, size=batch_size) 64 | batch = Batch(observations=self.observations[idx], 65 | actions=self.actions[idx], 66 | rewards=self.rewards[idx], 67 | discounts=self.discounts[idx], 68 | next_observations=self.next_observations[idx]) 69 | return batch 70 | 71 | def save(self, fname: str): 72 | np.savez(fname, 73 | observations=self.observations, 74 | actions=self.actions, 75 | next_observations=self.next_observations, 76 | rewards=self.rewards, 77 | discounts=self.discounts) 78 | 79 | 80 | class VLMBuffer: 81 | 82 | def __init__(self, 83 | obs_dim: int, 84 | act_dim: int, 85 | max_size: int = int(1e6)): 86 | self.max_size = max_size 87 | self.ptr = 0 88 | self.size = 0 89 | 90 | self.observations = np.zeros((max_size, obs_dim)) 91 | self.actions = np.zeros((max_size, act_dim)) 92 | self.next_observations = np.zeros((max_size, obs_dim)) 93 | self.vlm_rewards = np.zeros(max_size) 94 | self.rewards = np.zeros(max_size) 95 | self.discounts = np.zeros(max_size) 96 | 97 | def add(self, 98 | observation: np.ndarray, 99 | action: np.ndarray, 100 | next_observation: np.ndarray, 101 | vlm_reward: float, 102 | reward: float, 103 | done: float): 104 | self.observations[self.ptr] = observation 105 | self.actions[self.ptr] = action 106 | self.next_observations[self.ptr] = next_observation 107 | self.vlm_rewards[self.ptr] = vlm_reward 108 | self.rewards[self.ptr] = reward 109 | self.discounts[self.ptr] = 1 - done 110 | 111 | self.ptr = (self.ptr + 1) % self.max_size 112 | self.size = min(self.size + 1, self.max_size) 113 | 114 | def sample(self, batch_size: int) -> Batch: 115 | idx = np.random.randint(0, self.size, size=batch_size) 116 | rewards = self.rewards[idx] 117 | vlm_rewards = self.vlm_rewards[idx] 118 | batch = VLMBatch(observations=self.observations[idx], 119 | actions=self.actions[idx], 120 | rewards=rewards, 121 | vlm_rewards=vlm_rewards, 122 | discounts=self.discounts[idx], 123 | next_observations=self.next_observations[idx]) 124 | return batch 125 | 126 | 127 | 
class DistanceBuffer: 128 | 129 | def __init__(self, 130 | obs_dim: int, 131 | act_dim: int, 132 | emb_dim: int = 1024, 133 | max_size: int = int(1e6)): 134 | self.max_size = max_size 135 | self.ptr = 0 136 | self.size = 0 137 | 138 | self.observations = np.zeros((max_size, obs_dim)) 139 | self.actions = np.zeros((max_size, act_dim)) 140 | self.rewards = np.zeros(max_size) 141 | self.next_observations = np.zeros((max_size, obs_dim)) 142 | self.discounts = np.zeros(max_size) 143 | self.embeddings = np.zeros((max_size, emb_dim)) 144 | self.distances = np.zeros((max_size)) 145 | 146 | def add(self, 147 | observation: np.ndarray, 148 | action: np.ndarray, 149 | next_observation: np.ndarray, 150 | reward: float, 151 | done: float, 152 | embedding: np.ndarray, 153 | distance: float = 0): 154 | 155 | self.observations[self.ptr] = observation 156 | self.actions[self.ptr] = action 157 | self.next_observations[self.ptr] = next_observation 158 | self.rewards[self.ptr] = reward 159 | self.discounts[self.ptr] = 1 - done 160 | self.embeddings[self.ptr] = embedding 161 | self.distances[self.ptr] = distance 162 | 163 | self.ptr = (self.ptr + 1) % self.max_size 164 | self.size = min(self.size + 1, self.max_size) 165 | 166 | def sample_with_mask(self, batch_size: int, l2_margin: float = 0.05) -> Batch: 167 | idx = np.random.randint(0, self.size, size=batch_size) 168 | distance = self.distances[idx] 169 | 170 | l2_delta = distance.reshape(-1, 1) - distance.reshape(1, -1) 171 | masks = (l2_delta < -l2_margin).astype(np.float32) 172 | 173 | batch = MaskBatch(observations=self.observations[idx], 174 | actions=self.actions[idx], 175 | rewards=self.rewards[idx], 176 | discounts=self.discounts[idx], 177 | next_observations=self.next_observations[idx], 178 | embeddings=self.embeddings[idx], 179 | masks=masks) 180 | 181 | return batch 182 | 183 | def sample(self, batch_size: int) -> Batch: 184 | idx = np.random.randint(0, self.size, size=batch_size) 185 | batch = FinetuneBatch(observations=self.observations[idx], 186 | actions=self.actions[idx], 187 | rewards=self.rewards[idx], 188 | discounts=self.discounts[idx], 189 | next_observations=self.next_observations[idx], 190 | embeddings=self.embeddings[idx]) 191 | 192 | return batch 193 | 194 | 195 | class EmbeddingBuffer: 196 | def __init__(self, 197 | emb_dim: int, 198 | gap: int = 10, 199 | max_size: int = int(1e5)): 200 | self.gap = gap 201 | self.max_size = max_size 202 | 203 | self.pos_ptr = 0 204 | self.pos_size = 0 205 | self.pos_embeddings = np.zeros((max_size, emb_dim)) 206 | 207 | self.neg_ptr = 0 208 | self.neg_size = 0 209 | self.neg_embeddings = np.zeros((max_size, emb_dim)) 210 | 211 | self.valid_ptr = 0 212 | self.valid_size = 0 213 | self.valid_idxes = np.zeros(max_size, dtype=np.int32) 214 | 215 | def add(self, 216 | embedding: np.ndarray, 217 | success: bool = False, 218 | valid: bool = False): 219 | if success: 220 | self.pos_embeddings[self.pos_ptr] = embedding 221 | if valid: 222 | self.valid_idxes[self.valid_ptr] = self.pos_ptr 223 | self.valid_ptr = (self.valid_ptr + 1) % self.max_size 224 | self.valid_size = min(self.valid_size + 1, self.max_size) 225 | self.pos_ptr = (self.pos_ptr + 1) % self.max_size 226 | self.pos_size = min(self.pos_size + 1, self.max_size) 227 | else: 228 | self.neg_embeddings[self.neg_ptr] = embedding 229 | self.neg_ptr = (self.neg_ptr + 1) % self.max_size 230 | self.neg_size = min(self.neg_size + 1, self.max_size) 231 | 232 | def sample(self, batch_size): 233 | neg_idx = np.random.randint(0, self.neg_size, 
size=batch_size) 234 | valid_idx = np.random.randint(0, self.valid_size, size=batch_size) 235 | pos_idx = self.valid_idxes[valid_idx] 236 | lag_idx = (pos_idx - self.gap) % self.valid_size 237 | 238 | pos_embeddings = self.pos_embeddings[pos_idx] 239 | lag_embeddings = self.pos_embeddings[lag_idx] 240 | neg_embeddings = self.neg_embeddings[neg_idx] 241 | return EmbeddingBatch(pos_embeddings=pos_embeddings, 242 | lag_embeddings=lag_embeddings, 243 | neg_embeddings=neg_embeddings) 244 | 245 | def save(self, fdir): 246 | np.savez(fdir, 247 | pos_embeddings=self.pos_embeddings, 248 | neg_embeddings=self.neg_embeddings, 249 | pos_ptr=self.pos_ptr, 250 | pos_size=self.pos_size, 251 | neg_ptr=self.neg_ptr, 252 | neg_size=self.neg_size, 253 | valid_ptr=self.valid_ptr, 254 | valid_size=self.valid_size, 255 | valid_idxes=self.valid_idxes) 256 | -------------------------------------------------------------------------------- /utils/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import imageio 4 | import numpy as np 5 | import gymnasium as gym 6 | from gymnasium.spaces import Box 7 | from metaworld.envs import ( 8 | ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, 9 | ALL_V2_ENVIRONMENTS_GOAL_HIDDEN 10 | ) 11 | 12 | 13 | TASKS = { 14 | "window-open-v2-goal-hidden": "push and open a window", 15 | "window-close-v2-goal-hidden": "push and close a window", 16 | "button-press-topdown-v2-goal-hidden": "press a button from the top", 17 | "door-open-v2-goal-hidden": "open a door with a revolving joint", 18 | "drawer-close-v2-goal-hidden": "push and close a drawer", 19 | "drawer-open-v2-goal-hidden": "open a drawer", 20 | "push-v2-goal-hidden": "push the puck to a goal", 21 | "reach-v2-goal-hidden": "reach a goal position", 22 | "window-open-v2-goal-observable": "push and open a window", 23 | "window-close-v2-goal-observable": "push and close a window", 24 | "button-press-topdown-v2-goal-observable": "press a button from the top", 25 | "door-open-v2-goal-observable": "open a door with a revolving joint", 26 | "drawer-close-v2-goal-observable": "push and close a drawer", 27 | "drawer-open-v2-goal-observable": "open a drawer", 28 | "push-v2-goal-observable": "push the puck to a goal", 29 | "reach-v2-goal-observable": "reach a goal position", 30 | "pick-place-v2-goal-observable": "pick and place a puck to a goal", 31 | "peg-insert-side-v2-goal-observable": "insert a peg sideways", 32 | } 33 | 34 | 35 | ################### 36 | # Utils Functions # 37 | ################### 38 | def randomize_initial_state(env, num_step=50): 39 | dx = np.random.uniform(-0.5, 0.5) 40 | dy = np.random.uniform(-0.5, 0) 41 | dz = np.random.uniform(0., 0.05) 42 | actions = [np.array([dx, 0., 0., 0.]), 43 | np.array([0., dy, 0., 0.]), 44 | np.array([0., 0., dz, 0.])] 45 | for _ in range(num_step): 46 | action = actions[np.random.randint(3)] 47 | _ = env.step(action) 48 | env.curr_path_length = 0 49 | 50 | 51 | ################ 52 | # Env Wrappers # 53 | ################ 54 | class PixelObservationWrapper(gym.ObservationWrapper): 55 | def __init__(self, 56 | env: gym.Env, 57 | image_size: int, 58 | camera_id: int=1): 59 | super().__init__(env) 60 | self.observation_space = Box(low=0, 61 | high=255, 62 | shape=(image_size, image_size, 3), 63 | dtype=np.uint8) 64 | self.viewer = self.env.unwrapped.mujoco_renderer._get_viewer( 65 | render_mode="rgb_array") 66 | self.camera_id = camera_id 67 | self.image_size = image_size 68 | 69 | def get_image(self): 70 | img = 
self.unwrapped.mujoco_renderer.render( 71 | render_mode="rgb_array", camera_id=self.camera_id) 72 | return img[::-1] 73 | 74 | def observation(self, observation): 75 | return self.get_image() 76 | 77 | def render_img(self, render_image_size: int = 256): 78 | self.viewer.viewport.width = render_image_size 79 | self.viewer.viewport.height = render_image_size 80 | frame = self.env.render() 81 | self.viewer.viewport.width = self.image_size 82 | self.viewer.viewport.height = self.image_size 83 | return frame[::-1] 84 | 85 | 86 | class RepeatAction(gym.Wrapper): 87 | def __init__(self, env: gym.Env, action_repeat: int=4): 88 | super().__init__(env) 89 | self._action_repeat = action_repeat 90 | 91 | def step(self, action: np.ndarray): 92 | total_reward = 0.0 93 | combined_info = {} 94 | 95 | for _ in range(self._action_repeat): 96 | obs, reward, terminated, truncated, info = self.env.step(action) 97 | total_reward += reward 98 | combined_info.update(info) 99 | if terminated or truncated: 100 | break 101 | 102 | return obs, total_reward, terminated, truncated, combined_info 103 | 104 | 105 | ################# 106 | # Main Function # 107 | ################# 108 | def make_env(env_name: str = "drawer-open-v2-goal-hidden", 109 | seed: int = 42, 110 | camera_id: int = 1, 111 | render_image_size: int = 256, 112 | image_size: int = 256, # 84 113 | use_pixel: bool = False, 114 | action_repeat: int = 1, 115 | render_mode: str = "rgb_array"): 116 | 117 | if "hidden" in env_name: 118 | env = ALL_V2_ENVIRONMENTS_GOAL_HIDDEN[env_name](seed=seed) 119 | else: 120 | env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_name](seed=seed, render_mode=render_mode) 121 | env._freeze_rand_vec = False 122 | env.camera_id = camera_id 123 | viewer = env.unwrapped.mujoco_renderer._get_viewer(render_mode=render_mode) 124 | viewer.viewport.width = image_size 125 | viewer.viewport.height = image_size 126 | 127 | # sticky actions 128 | if action_repeat > 1: 129 | env = RepeatAction(env, action_repeat) 130 | 131 | # use pixel-based obs 132 | if use_pixel: 133 | env = PixelObservationWrapper(env, image_size, camera_id) 134 | 135 | # set random seed 136 | env.reset(seed=seed) 137 | env.action_space.seed(seed=seed) 138 | 139 | return env 140 | -------------------------------------------------------------------------------- /utils/liv_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from os.path import expanduser 4 | import omegaconf 5 | import gdown 6 | import hydra 7 | import torch 8 | from huggingface_hub import hf_hub_download 9 | from torch.hub import load_state_dict_from_url 10 | 11 | 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | 14 | 15 | VALID_ARGS = ["_target_", "device", "lr", "hidden_dim", "size", 16 | "l2weight", "l1weight", "num_negatives"] 17 | 18 | 19 | def cleanup_config(cfg): 20 | config = copy.deepcopy(cfg) 21 | keys = config.agent.keys() 22 | for key in list(keys): 23 | if key not in VALID_ARGS: 24 | del config.agent[key] 25 | config.agent["_target_"] = "models.LIV" 26 | config["device"] = device 27 | return config.agent 28 | 29 | 30 | def load_liv(model_id="resnet50"): 31 | """ 32 | model_config = { 33 | 'save_snapshot': True, 34 | 'load_snap': '', 35 | 'dataset': 'ego4d', 36 | 'num_workers': 10, 37 | 'batch_size': 32, 38 | 'train_steps': 2000000, 39 | 'eval_freq': 20000, 40 | 'seed': 1, 41 | 'device': 'cuda', 42 | 'lr': 0.0001, 43 | 'wandbproject': None, 44 | 'wandbuser': None, 45 | 'doaug': 'rctraj', 46 | 'agent': { 47 | 
'_target_': 'models.model_vip.VIP', 48 | 'device': '${device}', 49 | 'lr': '${lr}', 50 | 'hidden_dim': 1024, 51 | 'size': 50, 52 | 'l2weight': 0.001, 53 | 'l1weight': 0.001, 54 | 'gamma': 0.98, 55 | 'bs': '${batch_size}' 56 | } 57 | } 58 | 59 | clean_config = { 60 | '_target_': 'vip.VIP', 61 | 'device': '${device}', 62 | 'lr': '${lr}', 63 | 'hidden_dim': 1024, 64 | 'size': 50, 65 | 'l2weight': 0.001, 66 | 'l1weight': 0.001 67 | } 68 | """ 69 | base_dir = os.path.join(expanduser("~"), ".liv") 70 | os.makedirs(os.path.join(base_dir, model_id), exist_ok=True) 71 | 72 | folder_dir = os.path.join(base_dir, model_id) 73 | model_dir = os.path.join(base_dir, model_id, "model.pt") 74 | config_dir = os.path.join(base_dir, model_id, "config.yaml") 75 | 76 | try: 77 | hf_hub_download(repo_id="jasonyma/LIV", 78 | filename="model.pt", 79 | local_dir=folder_dir) 80 | hf_hub_download(repo_id="jasonyma/LIV", 81 | filename="config.yaml", 82 | local_dir=folder_dir) 83 | except: 84 | model_url = "https://drive.google.com/uc?id=1l1ufzVLxpE5BK7JY6ZnVBljVzmK5c4P3" 85 | config_url = "https://drive.google.com/uc?id=1GWA5oSJDuHGB2WEdyZZmkro83FNmtaWl" 86 | if not os.path.exists(model_dir): 87 | gdown.download(model_url, model_dir, quiet=False) 88 | gdown.download(config_url, config_dir, quiet=False) 89 | else: 90 | load_state_dict_from_url(model_url, 91 | folder_dir, 92 | map_location=torch.device(device)) 93 | load_state_dict_from_url(config_url, folder_dir) 94 | 95 | model_config = omegaconf.OmegaConf.load(config_dir) 96 | clean_config = cleanup_config(model_config) 97 | 98 | liv_model = hydra.utils.instantiate(clean_config) 99 | liv_model = torch.nn.DataParallel(liv_model) 100 | liv_state_dict = torch.load(model_dir, map_location=torch.device(device))["liv"] 101 | liv_model.load_state_dict(liv_state_dict) 102 | liv_model.eval() 103 | return liv_model 104 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import git 3 | import jax 4 | import numpy as np 5 | from flax.core import FrozenDict 6 | 7 | 8 | def get_logger(fname: str) -> logging.Logger: 9 | logging.basicConfig(level=logging.INFO, 10 | format='%(asctime)s - %(message)s', 11 | datefmt='%Y-%m-%d %H:%M:%S', 12 | filename=fname, 13 | filemode='w', 14 | force=True) 15 | logger = logging.getLogger() 16 | return logger 17 | 18 | 19 | def log_git(config): 20 | config.unlock() 21 | repo = git.Repo(search_parent_directories=True) 22 | sha = repo.head.object.hexsha 23 | config["commit"] = sha 24 | 25 | 26 | def target_update(params: FrozenDict, 27 | target_params: FrozenDict, 28 | tau: float) -> FrozenDict: 29 | 30 | def _update(param: FrozenDict, target_param: FrozenDict): 31 | return tau * param + (1 - tau) * target_param 32 | 33 | updated_params = jax.tree_util.tree_map(_update, params, target_params) 34 | return updated_params 35 | --------------------------------------------------------------------------------
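As a closing illustration (not part of the repository), the sketch below shows how `target_update` from `utils/train_utils.py` applies the Polyak rule `tau * param + (1 - tau) * target_param` leaf by leaf to arbitrary parameter pytrees; the toy dictionaries are hypothetical stand-ins for the Flax `TrainState` params used by the agents, and the snippet assumes the packages in `requirements.txt` are installed.

```python
import jax.numpy as jnp

from utils.train_utils import target_update  # also re-exported by utils/__init__.py

# Hypothetical toy pytrees standing in for online/target critic parameters.
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones(2)}}
target_params = {"dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros(2)}}

# One soft update with tau=0.005: every leaf moves 0.5% of the way toward the
# online parameters, mirroring the critic-target update inside the agents.
new_target = target_update(params, target_params, 0.005)
print(new_target["dense"]["kernel"])  # all entries equal 0.005
```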