├── .gitignore ├── dino ├── vqkd_teacher │ ├── __init__.py │ ├── clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── simple_tokenizer.py │ │ ├── clip.py │ │ └── model.py │ └── dino.py ├── save_dino_codes.py ├── norm_ema_quantizer.py ├── train_vq_dino_voc.py └── modeling_vqkd.py ├── requirements.txt ├── utils ├── video.py ├── network_utils.py ├── dmc.py ├── utils.py ├── drqv2.py └── numpy_replay_buffer.py ├── README.md ├── musik ├── musik_model.py └── train_vq_musik_voc.py └── vae ├── save_vq_codes.py ├── train_vq_vae_voc.py └── train_pixel_vq_vae_voc.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/taming 2 | **/wandb 3 | **/__pycache__ 4 | **.pth -------------------------------------------------------------------------------- /dino/vqkd_teacher/__init__.py: -------------------------------------------------------------------------------- 1 | from .dino import * 2 | from .clip import * -------------------------------------------------------------------------------- /dino/vqkd_teacher/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .model import * 3 | -------------------------------------------------------------------------------- /dino/vqkd_teacher/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manantomar/video-occupancy-models/HEAD/dino/vqkd_teacher/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.17.2+computecanada 2 | h5py==3.8.0+computecanada 3 | matplotlib==3.7.0+computecanada 4 | mujoco-py==2.0.2.10 5 | numpy==1.23.0+computecanada 6 | opencv-python==4.5.1.48+computecanada 7 | pyarrow==5.0.0 8 | seaborn==0.13.2+computecanada 9 | -e git+https://github.com/manantomar/taming-transformers.git@3d8c8ac03d12db0ed361ec74250022cc09af8cbc#egg=taming_transformers 10 | timm==0.9.16+computecanada 11 | torch==2.0.1+cu118 12 | torchaudio==2.0.2+cu118 13 | torchmetrics==1.2.1+computecanada 14 | torchvision==0.11.1+computecanada 15 | tqdm==4.66.1+computecanada 16 | transformers==4.33.3+computecanada 17 | typed-argument-parser==1.9.0 18 | vector_quantize_pytorch==1.12.12 19 | vit-pytorch==1.6.5 20 | wandb==0.15.9 21 | zipp==3.16.2+computecanada 22 | -------------------------------------------------------------------------------- /utils/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
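# VideoRecorder renders evaluation episodes from the environment (env.physics when available),
# while TrainVideoRecorder resizes raw pixel observations gathered during training; both write
# the collected frames out with imageio.mimsave.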
5 | import cv2 6 | import imageio 7 | import numpy as np 8 | 9 | 10 | class VideoRecorder: 11 | def __init__(self, root_dir, render_size=256, fps=20): 12 | if root_dir is not None: 13 | self.save_dir = root_dir / 'eval_video' 14 | self.save_dir.mkdir(exist_ok=True) 15 | else: 16 | self.save_dir = None 17 | 18 | self.render_size = render_size 19 | self.fps = fps 20 | self.frames = [] 21 | 22 | def init(self, env, enabled=True): 23 | self.frames = [] 24 | self.enabled = self.save_dir is not None and enabled 25 | self.record(env) 26 | 27 | def record(self, env): 28 | if self.enabled: 29 | if hasattr(env, 'physics'): 30 | frame = env.physics.render(height=self.render_size, 31 | width=self.render_size, 32 | camera_id=0) 33 | else: 34 | frame = env.render() 35 | self.frames.append(frame) 36 | 37 | def save(self, file_name): 38 | if self.enabled: 39 | path = self.save_dir / file_name 40 | imageio.mimsave(str(path), self.frames, fps=self.fps) 41 | 42 | 43 | class TrainVideoRecorder: 44 | def __init__(self, root_dir, render_size=256, fps=20): 45 | if root_dir is not None: 46 | self.save_dir = root_dir / 'train_video' 47 | self.save_dir.mkdir(exist_ok=True) 48 | else: 49 | self.save_dir = None 50 | 51 | self.render_size = render_size 52 | self.fps = fps 53 | self.frames = [] 54 | 55 | def init(self, obs, enabled=True): 56 | self.frames = [] 57 | self.enabled = self.save_dir is not None and enabled 58 | self.record(obs) 59 | 60 | def record(self, obs): 61 | if self.enabled: 62 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 63 | dsize=(self.render_size, self.render_size), 64 | interpolation=cv2.INTER_CUBIC) 65 | self.frames.append(frame) 66 | 67 | def save(self, file_name): 68 | if self.enabled: 69 | path = self.save_dir / file_name 70 | imageio.mimsave(str(path), self.frames, fps=self.fps) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Video Occupancy Models 2 | 3 | Code for the paper `Video Occupancy Models`, includes three versions of quantizing the input video frames -- `vae` which uses a VQ-VAE, `dino` which uses quantized DINO, and `musik` which uses quantized Multi-step Inverse Dynamics. 4 | 5 | Screenshot 2024-07-16 at 12 05 30 PM 6 | 7 | This is a PyTorch/GPU implementation of the paper [Video Occupancy Models](https://arxiv.org/pdf/2407.09533): 8 | ``` 9 | @Article{VideoOccupancyModels2024, 10 | author = {Manan Tomar and Philippe Hansen-Estruch and Philip Bachman and Alex Lamb and John Langford and Matthew E. Taylor and Sergey Levine, 11 | journal = {arXiv:2407.09533}, 12 | title = {Video Occupancy Models}, 13 | year = {2024}, 14 | } 15 | ``` 16 | ### Installation 17 | 18 | The main packages are provided in the `requirements.txt` file. This code has been tested on a virtual env with Python-3.8 with the package versions listed in the requirements file. 19 | 20 | ### Model Checkpoints and Datasets 21 | 22 | The following table provides the pre-trained model checkpoints and datasets used in the paper: 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 |
| | Cheetah | Walker |
| --- | --- | --- |
| VQ-VAE fine-tuned model checkpoint | download | download |
| DINO latent datasets | link | |
| VQ-VAE latent datasets | link | link |
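Each latent dataset is an HDF5 file that keeps the original trajectory fields (`observation`, `action`, `reward`, `discount`, `step_type`) and adds the discrete code indices written by `save_vq_codes.py` / `save_dino_codes.py` (`vae_latents` or `dino_latents`). A minimal sketch for inspecting such a file (the path below is a placeholder):

```python
import h5py

# placeholder path to one of the downloaded latent datasets
path = "walker_train/vq_latents/3_medium_vq_latents.hdf5"

with h5py.File(path, "r") as hf:
    print(list(hf.keys()))          # e.g. ['action', 'discount', 'observation', 'reward', 'step_type', 'vae_latents']
    codes = hf["vae_latents"][()]   # integer code indices per frame
    obs = hf["observation"][()]     # raw pixel observations
    print(codes.shape, obs.shape)
```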
42 | 43 | ### VQ-VAE VOC 44 | 45 | You would need to download the contents of this [folder](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) and place them one directory above where this repo is present. This folder contains model descriptions for using a VQ-VAE model from the [taming-transformers](https://github.com/CompVis/taming-transformers?tab=readme-ov-file) codebase. 46 | 47 | Run [train_vq_vae_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/train_vq_vae_voc.py) to train a VOC model on stored VQ-VAE latents. If you want to train both the VQ-VAE and the VOC model on pixel data then run [train_pixel_vq_vae_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/train_pixel_vq_vae_voc.py). In case you want to create your own latents by traning VQ-VAE on a custom dataset use the `collect_latents()` and `train_vq_latents()` methods in [save_vq_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/save_vq_codes.py). 48 | 49 | ### DINO VOC 50 | 51 | We use a quantized verison of [DINO](https://arxiv.org/abs/2104.14294) from [BEiT-v2](https://github.com/microsoft/unilm/tree/master/beit2). You would need to download this [dino model file](https://github.com/addf400/files/releases/download/BEiT-v2/vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth) and place them one directory above where this repo is present. 52 | 53 | Run [train_vq_dino_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/dino/train_vq_dino_voc.py) to train a VOC model on stored DINO latents. Again, in case you want to create your own latents by running a quantized version of DINO on a custom dataset use the `collect_latents()` method in [save_dino_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/dino/save_dino_codes.py). 54 | 55 | ### MUSIK VOC 56 | 57 | In the case, action data is also available, we use a quantized multi-step inverse kinematics (MUSIK) objective to train the representation. 58 | 59 | Run [train_vq_musik_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/musik/train_vq_musik_voc.py) to train a VOC model along with the [MUSIK](https://arxiv.org/pdf/2211.00164) objective on pixel data. 60 | -------------------------------------------------------------------------------- /dino/vqkd_teacher/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- 
/musik/musik_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from taming.modules.diffusionmodules.model import Encoder, Decoder 6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 7 | 8 | from network_utils import MUSIKPredictor 9 | 10 | class VQMUSIKModel(pl.LightningModule): 11 | def __init__(self, 12 | ddconfig, 13 | lossconfig, 14 | n_embed, 15 | embed_dim, 16 | ckpt_path=None, 17 | ignore_keys=[], 18 | image_key="image", 19 | colorize_nlabels=None, 20 | monitor=None, 21 | remap=None, 22 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 23 | ): 24 | super().__init__() 25 | self.image_key = image_key 26 | self.encoder = Encoder(**ddconfig) 27 | self.decoder = MUSIKPredictor(2 * embed_dim, 6, 1.0).to('cuda') 28 | self.quant_linear = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(256 * 25, embed_dim)) 29 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 30 | remap=remap, sane_index_shape=sane_index_shape) 31 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 32 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 33 | if ckpt_path is not None: 34 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 35 | self.image_key = image_key 36 | if colorize_nlabels is not None: 37 | assert type(colorize_nlabels)==int 38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 39 | if monitor is not None: 40 | self.monitor = monitor 41 | 42 | def init_from_ckpt(self, path, ignore_keys=list()): 43 | sd = torch.load(path, map_location="cpu")["state_dict"] 44 | keys = list(sd.keys()) 45 | for k in keys: 46 | for ik in ignore_keys: 47 | if k.startswith(ik): 48 | print("Deleting key {} from state_dict.".format(k)) 49 | del sd[k] 50 | self.load_state_dict(sd, strict=False) 51 | print(f"Restored from {path}") 52 | 53 | def encode(self, x, return_prequant=False): 54 | h = self.encoder(x) 55 | h = self.quant_conv(h) 56 | quant, emb_loss, info = self.quantize(h) 57 | if return_prequant: 58 | return quant, emb_loss, info, h 59 | else: 60 | return quant, emb_loss, info 61 | 62 | def decode(self, quant): 63 | quant = self.post_quant_conv(quant) 64 | dec = self.decoder(quant) 65 | return dec 66 | 67 | def decode_linear(self, quant): 68 | quant = self.post_quant_conv(quant) 69 | quant = quant.reshape(quant.shape[0], -1) 70 | dec = self.quant_linear(quant) 71 | return dec 72 | 73 | def decode_code(self, code_b): 74 | quant_b = self.quantize.embed_code(code_b) 75 | dec = self.decode(quant_b) 76 | return dec 77 | 78 | def forward(self, input): 79 | quant, diff, _ = self.encode(input) 80 | dec = self.decode(quant) 81 | return dec, diff 82 | 83 | def get_input(self, batch, k): 84 | x = batch[k] 85 | if len(x.shape) == 3: 86 | x = x[..., None] 87 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 88 | return x.float() 89 | 90 | def configure_optimizers(self): 91 | lr = self.learning_rate 92 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 93 | list(self.decoder.parameters())+ 94 | list(self.quantize.parameters())+ 95 | list(self.quant_conv.parameters())+ 96 | list(self.post_quant_conv.parameters()), 97 | lr=lr, betas=(0.5, 0.9)) 98 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 99 | lr=lr, betas=(0.5, 0.9)) 100 | return [opt_ae, opt_disc], [] 101 | 102 | def get_last_layer(self): 103 | 
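# NOTE: taming's VQGAN loss uses these last-layer weights to balance the reconstruction and
# adversarial terms; `conv_out` exists on the taming `Decoder`, whereas `self.decoder` here is a
# MUSIKPredictor, so this accessor only applies if a convolutional decoder is swapped back in.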
return self.decoder.conv_out.weight 104 | 105 | def log_images(self, batch, **kwargs): 106 | log = dict() 107 | x = self.get_input(batch, self.image_key) 108 | x = x.to(self.device) 109 | xrec, _ = self(x) 110 | if x.shape[1] > 3: 111 | # colorize with random projection 112 | assert xrec.shape[1] > 3 113 | x = self.to_rgb(x) 114 | xrec = self.to_rgb(xrec) 115 | log["inputs"] = x 116 | log["reconstructions"] = xrec 117 | return log 118 | 119 | def to_rgb(self, x): 120 | assert self.image_key == "segmentation" 121 | if not hasattr(self, "colorize"): 122 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 123 | x = F.conv2d(x, weight=self.colorize) 124 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 125 | return x -------------------------------------------------------------------------------- /dino/save_dino_codes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../utils/') 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from pathlib import Path 10 | from matplotlib import pyplot as plt 11 | 12 | import random 13 | import utils 14 | import os 15 | import wandb 16 | import math 17 | 18 | from torchvision.utils import make_grid 19 | from drqv2 import RandomShiftsAug 20 | 21 | from dm_env import specs 22 | import dmc 23 | 24 | import h5py 25 | from omegaconf import OmegaConf 26 | 27 | from torchvision import transforms as pth_transforms 28 | from timm.models import create_model 29 | 30 | import modeling_vqkd 31 | 32 | def get_parameter_names(model, forbidden_layer_types): 33 | """ 34 | Returns the names of the model parameters that are not inside a forbidden layer. 35 | """ 36 | result = [] 37 | for name, child in model.named_children(): 38 | result += [ 39 | f"{name}.{n}" 40 | for n in get_parameter_names(child, forbidden_layer_types) 41 | if not isinstance(child, tuple(forbidden_layer_types)) 42 | ] 43 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 44 | result += list(model._parameters.keys()) 45 | return result 46 | 47 | class DINOAgent: 48 | def __init__(self, augmentation=RandomShiftsAug(pad=4)): 49 | 50 | wandb.init(project="Video Occupancy Models", 51 | entity=None, dir=os.getcwd()) 52 | 53 | self.hdf5_file_path = "../../cheetah_train/org/3_cheetah_run_random.hdf5" 54 | self.save_path = "../../cheetah_train/dino_latents/3_random_dino_latents.hdf5" 55 | self.dino_path = '../../vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth' 56 | 57 | self.model = create_model( 58 | 'vqkd_encoder_base_decoder_1x768x12_dino', 59 | pretrained=True, 60 | pretrained_weight=self.dino_path, 61 | as_tokenzer=True, 62 | ).to('cuda').eval() 63 | 64 | # vq vae optimizer 65 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-4) 66 | 67 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda') 68 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda') 69 | 70 | self.device = 'cuda' 71 | # data augmentation 72 | self.aug = augmentation 73 | 74 | self.train() 75 | 76 | def train(self, training=True): 77 | self.training = training 78 | 79 | def process_obs(self, obs): 80 | obs = self.aug(obs.float()) 81 | obs = F.interpolate(obs, size=224) 82 | 83 | obs_shape = obs.shape 84 | 85 | obs = torch.einsum('nchw->nhwc', obs / 255.) 
- self.imagenet_mean / self.imagenet_std 86 | obs = torch.einsum('nhwc->nchw', obs).reshape((obs_shape[0], 3, *obs_shape[2:])) 87 | return obs, obs_shape 88 | 89 | def collect_latents(self, step=0): 90 | with h5py.File(self.hdf5_file_path, 'r') as hf: 91 | fobs = hf['observation'][()] 92 | faction = hf['action'][()] 93 | fdiscount = hf['discount'][()] 94 | freward = hf['reward'][()] 95 | fstep_type = hf['step_type'][()] 96 | 97 | fobs = fobs 98 | batch_size = 25 99 | 100 | assert fobs.shape[0] % batch_size == 0 101 | iters = fobs.shape[0] / batch_size 102 | 103 | print("Total Obs are {}, Batch Size is {}, Total Iterations is {}".format(fobs.shape[0], batch_size, iters)) 104 | 105 | dino_latents = [] 106 | for i in range(int(iters)): 107 | obs = fobs[i*batch_size:(i + 1)*batch_size] 108 | obs = utils.to_torch([obs], self.device)[0] 109 | 110 | obs, obs_shape = self.process_obs(obs) 111 | 112 | # dino embed 113 | quant_context, y_context, context_loss = self.model.encode(obs) 114 | 115 | y_context = y_context.reshape((obs_shape[0], -1)).detach() 116 | 117 | # collect discrete vq indices 118 | dino_latents.append(y_context) 119 | 120 | dataset_names = ['action', 'discount', 'observation', 'reward', 'step_type', 'dino_latents'] 121 | data_arrays = [faction, fdiscount, fobs, freward, fstep_type, torch.stack(dino_latents).reshape((batch_size * int(iters), -1)).cpu().long()] 122 | self.create_hdf5_file(self.save_path, dataset_names, data_arrays) 123 | 124 | def create_hdf5_file(self, file_path, dataset_names, data_arrays): 125 | """ 126 | Create an HDF5 file and store multiple arrays in it. 127 | 128 | Parameters: 129 | - file_path: Path to the HDF5 file. 130 | - dataset_names: List of dataset names. 131 | - data_arrays: List of NumPy arrays to be stored in the HDF5 file. 
132 | """ 133 | # org dataset ['action', 'discount', 'observation', 'reward', 'step_type'] 134 | with h5py.File(file_path, 'w') as hf: 135 | for dataset_name, data_array in zip(dataset_names, data_arrays): 136 | # Create a dataset in the HDF5 file 137 | hf.create_dataset(dataset_name, data=data_array) 138 | 139 | 140 | if __name__ == "__main__": 141 | 142 | agent = DINOAgent() 143 | agent.collect_latents() 144 | -------------------------------------------------------------------------------- /utils/network_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import torch.distributed as dist 9 | import utils 10 | 11 | def weight_init(m): 12 | if isinstance(m, nn.Linear): 13 | nn.init.orthogonal_(m.weight.data) 14 | if hasattr(m.bias, 'data'): 15 | m.bias.data.fill_(0.0) 16 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 17 | gain = nn.init.calculate_gain('relu') 18 | nn.init.orthogonal_(m.weight.data, gain) 19 | if hasattr(m.bias, 'data'): 20 | m.bias.data.fill_(0.0) 21 | 22 | class RandomShiftsAug(nn.Module): 23 | def __init__(self, pad=4): 24 | super().__init__() 25 | self.pad = pad 26 | 27 | def forward(self, x): 28 | # x = T.Resize((x.shape[0], x.shape[1], 64, 64)) 29 | n, c, h, w = x.size() 30 | assert h == w 31 | padding = tuple([self.pad] * 4) 32 | x = F.pad(x, padding, 'replicate') 33 | eps = 1.0 / (h + 2 * self.pad) 34 | arange = torch.linspace(-1.0 + eps, 35 | 1.0 - eps, 36 | h + 2 * self.pad, 37 | device=x.device, 38 | dtype=x.dtype)[:h] 39 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 40 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 41 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 42 | 43 | shift = torch.randint(0, 44 | 2 * self.pad + 1, 45 | size=(n, 1, 1, 2), 46 | device=x.device, 47 | dtype=x.dtype) 48 | shift *= 2.0 / (h + 2 * self.pad) 49 | 50 | grid = base_grid + shift 51 | return F.grid_sample(x, 52 | grid, 53 | padding_mode='zeros', 54 | align_corners=False) 55 | 56 | 57 | def get_parameter_names(model, forbidden_layer_types): 58 | """ 59 | Returns the names of the model parameters that are not inside a forbidden layer. 60 | """ 61 | result = [] 62 | for name, child in model.named_children(): 63 | result += [ 64 | f"{name}.{n}" 65 | for n in get_parameter_names(child, forbidden_layer_types) 66 | if not isinstance(child, tuple(forbidden_layer_types)) 67 | ] 68 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 
69 | result += list(model._parameters.keys()) 70 | return result 71 | 72 | class Predictor(nn.Module): 73 | def __init__(self, feature_dim): 74 | super(Predictor, self).__init__() 75 | 76 | self.l1 = nn.Linear(feature_dim, 256) 77 | self.l2 = nn.Linear(256, 256) 78 | self.l3 = nn.Linear(256, 1) 79 | 80 | def forward(self, state): 81 | a = F.relu(self.l1(state)) 82 | a = F.relu(self.l2(a)) 83 | return self.l3(a) 84 | 85 | 86 | class EMA(): 87 | def __init__(self, beta): 88 | super().__init__() 89 | self.beta = beta 90 | 91 | def update_average(self, old, new): 92 | if old is None: 93 | return new 94 | return old * self.beta + (1 - self.beta) * new 95 | 96 | 97 | def update_moving_average(ema_updater, ma_model, current_model): 98 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 99 | old_weight, up_weight = ma_params.data, current_params.data 100 | ma_params.data = ema_updater.update_average(old_weight, up_weight) 101 | 102 | 103 | class ConvPredictor(nn.Module): 104 | def __init__(self, obs_shape): 105 | super().__init__() 106 | 107 | assert len(obs_shape) == 3 108 | self.repr_dim = 32 * 35 * 35 109 | feature_dim = 50 110 | hidden_dim = 1024 111 | 112 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 113 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 114 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 115 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 116 | nn.ReLU()) 117 | 118 | self.trunk = nn.Sequential(nn.Linear(self.repr_dim, feature_dim), 119 | nn.LayerNorm(feature_dim), nn.Tanh()) 120 | 121 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 122 | nn.ReLU(inplace=True), 123 | nn.Linear(hidden_dim, hidden_dim), 124 | nn.ReLU(inplace=True), 125 | nn.Linear(hidden_dim, 21)) 126 | 127 | # self.linear = nn.Sequential(nn.Linear(self.repr_dim, 256), nn.ReLU(), 128 | # nn.Linear(256, 256), nn.ReLU(), 129 | # nn.Linear(256, 21), nn.Tanh()) 130 | 131 | self.apply(utils.weight_init) 132 | 133 | def forward(self, obs, std=0.1, eval=False): 134 | obs = obs / 255.0 - 0.5 135 | h = self.convnet(obs) 136 | h = h.view(h.shape[0], -1) 137 | h = self.trunk(h) 138 | mu = self.policy(h) 139 | 140 | std = torch.ones_like(mu) * std 141 | 142 | dist = utils.TruncatedNormal(mu, std) 143 | if eval: 144 | action = dist.mean 145 | else: 146 | action = dist.sample(clip=0.3) 147 | return action 148 | 149 | class projection_MLP(nn.Module): 150 | def __init__(self, in_dim, hidden_dim=256, out_dim=50): #256): 151 | super().__init__() 152 | # hidden_dim = in_dim 153 | self.layer1 = nn.Sequential( 154 | nn.Linear(in_dim, hidden_dim, bias=False), 155 | nn.BatchNorm1d(hidden_dim), 156 | nn.ReLU(inplace=True) 157 | ) 158 | self.layer2 = nn.Linear(hidden_dim, out_dim) 159 | def forward(self, x): 160 | x = self.layer1(x) 161 | x = self.layer2(x) 162 | return x 163 | 164 | class InfoNCE(nn.Module): 165 | def __init__(self, feature_dim, action_dim, num_actions=1): 166 | super().__init__() 167 | 168 | self.train_samples = 256 169 | self.action_dim = action_dim 170 | 171 | self.projector = projection_MLP(feature_dim, 256, 1) 172 | 173 | # self.apply(weight_init) 174 | 175 | def forward(self, x1, x2, action, return_logits=False): 176 | self.device = x1.device 177 | # Generate N negatives, one for each element in the batch: (B, N, D). 178 | negatives = self.sample(x1.size(0), action.size(1)) 179 | 180 | # Merge target and negatives: (B, N+1, D). 
181 | targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1) 182 | 183 | # Generate a random permutation of the positives and negatives. 184 | permutation = torch.rand(targets.size(0), targets.size(1)).argsort(dim=1) 185 | targets = targets[torch.arange(targets.size(0)).unsqueeze(-1), permutation] 186 | 187 | # Get the original index of the positive. This will serve as the class label 188 | # for the loss. 189 | ground_truth = (permutation == 0).nonzero()[:, 1].to(self.device) 190 | 191 | # For every element in the mini-batch, there is 1 positive for which the EBM 192 | # should output a low energy value, and N negatives for which the EBM should 193 | # output high energy values. 194 | fused = torch.cat([x1.unsqueeze(1).expand(-1, targets.size(1), -1), x2.unsqueeze(1).expand(-1, targets.size(1), -1), targets], dim=-1) 195 | B, N, D = fused.size() 196 | fused = fused.reshape(B * N, D) 197 | out = self.projector(fused) 198 | energy = out.view(B, N) 199 | 200 | # Interpreting the energy as a negative logit, we can apply a cross entropy loss 201 | # to train the EBM. 202 | logits = -1.0 * energy 203 | loss = F.cross_entropy(logits, ground_truth.detach()) 204 | 205 | if return_logits: 206 | return logits 207 | 208 | return loss 209 | 210 | def _sample(self, num_samples: int, action_size: int) -> torch.Tensor: 211 | """Helper method for drawing samples from the uniform random distribution.""" 212 | size = (num_samples, action_size) 213 | samples = np.random.uniform(-1, 1, size=size) 214 | return torch.as_tensor(samples, dtype=torch.float32, device=self.device) 215 | 216 | def sample(self, batch_size: int, action_size: int) -> torch.Tensor: 217 | samples = self._sample(batch_size * self.train_samples, action_size) 218 | return samples.reshape(batch_size, self.train_samples, -1) 219 | 220 | class MUSIKPredictor(nn.Module): 221 | def __init__(self, state_dim, action_dim, max_action): 222 | super(MUSIKPredictor, self).__init__() 223 | 224 | feature_dim = 50 225 | hidden_dim = 1024 226 | 227 | self.trunk = nn.Sequential(nn.Linear(state_dim, feature_dim), 228 | nn.LayerNorm(feature_dim), nn.Tanh()) 229 | 230 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 231 | nn.ReLU(inplace=True), 232 | nn.Linear(hidden_dim, hidden_dim), 233 | nn.ReLU(inplace=True), 234 | nn.Linear(hidden_dim, action_dim)) 235 | 236 | self.max_action = max_action 237 | self.apply(utils.weight_init) 238 | 239 | def forward(self, state, std=0.1, eval=False): 240 | h = self.trunk(state) 241 | mu = self.policy(h) 242 | mu = torch.tanh(mu) 243 | 244 | std = torch.ones_like(mu) * std 245 | 246 | dist = utils.TruncatedNormal(mu, std) 247 | if eval: 248 | action = dist.mean 249 | else: 250 | action = dist.sample(clip=0.3) 251 | return action 252 | 253 | 254 | -------------------------------------------------------------------------------- /vae/save_vq_codes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../utils/') 3 | 4 | from taming.models.vqgan import VQModel 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from pathlib import Path 12 | from matplotlib import pyplot as plt 13 | 14 | import random 15 | import utils 16 | import os 17 | import wandb 18 | import math 19 | 20 | from torchvision.utils import make_grid 21 | from drqv2 import RandomShiftsAug 22 | 23 | from dm_env import specs 24 | import dmc 25 | 26 | import h5py 27 | from omegaconf import OmegaConf 28 | 29 | 
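# save_vq_codes.py: encodes stored observations with an (optionally fine-tuned) taming VQModel and
# writes the resulting discrete code indices ('vae_latents') back to HDF5 alongside the original
# trajectory fields; train_vq_latents() instead fine-tunes the VQ-VAE on the raw observations.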
def get_parameter_names(model, forbidden_layer_types): 30 | """ 31 | Returns the names of the model parameters that are not inside a forbidden layer. 32 | """ 33 | result = [] 34 | for name, child in model.named_children(): 35 | result += [ 36 | f"{name}.{n}" 37 | for n in get_parameter_names(child, forbidden_layer_types) 38 | if not isinstance(child, tuple(forbidden_layer_types)) 39 | ] 40 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child. 41 | result += list(model._parameters.keys()) 42 | return result 43 | 44 | class VQAgent: 45 | def __init__(self, augmentation=RandomShiftsAug(pad=4)): 46 | 47 | wandb.init(project="Video Occupancy Models", 48 | entity=None, dir=os.getcwd()) 49 | 50 | self.hdf5_dir_path = "../../walker_train/org/" 51 | self.hdf5_file_path = "../../walker_train/org/3_walker_walk_medium.hdf5" 52 | self.save_path = "../../walker_train/vq_latents/3_medium_vq_latents.hdf5" 53 | self.taming_path = "../../" 54 | self.vq_model_path = "../../walker_vq_model.pth" 55 | 56 | ################################################################################ 57 | # # 58 | # VQ VAE Setup # 59 | # # 60 | ################################################################################ 61 | config_path = os.path.join(self.taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml") 62 | config = OmegaConf.load(config_path) 63 | 64 | self.model = VQModel(**config.model.params).to('cuda') 65 | 66 | self.from_imagenet = False 67 | if self.from_imagenet: 68 | ckpt_path = os.path.join(self.taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt") 69 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"] 70 | missing, unexpected = self.model.load_state_dict(sd, strict=False) 71 | else: 72 | print("Loading fine-tuned VQ model...") 73 | self.model.load_state_dict(torch.load(self.vq_model_path)) 74 | 75 | # vq vae optimizer 76 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-4) 77 | 78 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda') 79 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda') 80 | 81 | self.device = 'cuda' 82 | # data augmentation 83 | self.aug = augmentation 84 | 85 | self.train() 86 | 87 | def train(self, training=True): 88 | self.training = training 89 | 90 | def process_obs(self, obs): 91 | obs = self.aug(obs.float()) 92 | obs = F.interpolate(obs, size=80) 93 | 94 | obs_shape = obs.shape 95 | 96 | obs = torch.einsum('nchw->nhwc', obs / 255.) 
- self.imagenet_mean / self.imagenet_std 97 | obs = torch.einsum('nhwc->nchw', obs).reshape((obs_shape[0], 3, *obs_shape[2:])) 98 | return obs, obs_shape 99 | 100 | def collect_latents(self, step=0): 101 | with h5py.File(self.hdf5_file_path, 'r') as hf: 102 | fobs = hf['observation'][()] 103 | faction = hf['action'][()] 104 | fdiscount = hf['discount'][()] 105 | freward = hf['reward'][()] 106 | fstep_type = hf['step_type'][()] 107 | 108 | fobs = fobs 109 | batch_size = 25 110 | 111 | assert fobs.shape[0] % batch_size == 0 112 | iters = fobs.shape[0] / batch_size 113 | 114 | print("Total Obs are {}, Batch Size is {}, Total Iterations is {}".format(fobs.shape[0], batch_size, iters)) 115 | 116 | vae_latents = [] 117 | for i in range(int(iters)): 118 | obs = fobs[i*batch_size:(i + 1)*batch_size] 119 | obs = utils.to_torch([obs], self.device)[0] 120 | 121 | obs, obs_shape = self.process_obs(obs) 122 | 123 | # vq embed 124 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 125 | 126 | # collect discrete vq indices 127 | y_context = info_context[2].view(obs_shape[0], -1).detach() 128 | vae_latents.append(y_context) 129 | 130 | dataset_names = ['action', 'discount', 'observation', 'reward', 'step_type', 'vae_latents'] 131 | data_arrays = [faction, fdiscount, fobs, freward, fstep_type, torch.stack(vae_latents).reshape((batch_size * int(iters), -1)).cpu().long()] 132 | self.create_hdf5_file(self.save_path, dataset_names, data_arrays) 133 | 134 | def train_vq_latents(self): 135 | hdf5_files = [f for f in os.listdir(self.hdf5_dir_path) if f.endswith('.hdf5')] 136 | 137 | fobs = [] 138 | # Loop through each file and read data 139 | for file in hdf5_files: 140 | file_path = os.path.join(self.hdf5_dir_path, file) 141 | with h5py.File(file_path, 'r') as hf: 142 | fobs.append(hf['observation'][()]) 143 | fobs = np.stack(fobs) 144 | fobs = fobs.reshape((-1, *fobs.shape[2:])) 145 | 146 | batch_size = 64 147 | 148 | vae_latents = [] 149 | 150 | for step in range(50000): # Num of updates 151 | idx = np.random.choice(fobs.shape[0], size=batch_size) 152 | obs = fobs[idx] 153 | obs = utils.to_torch([obs], self.device)[0] 154 | 155 | obs, obs_shape = self.process_obs(obs) 156 | 157 | # vq embed 158 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 159 | 160 | xrec, qloss = self.model.decode(quant_context), emb_loss_context 161 | vae_loss, log_dict_ae = self.model.loss(qloss, obs, xrec, 0, step, last_layer=self.model.get_last_layer(), split="train") 162 | 163 | # collect discrete vq indices 164 | y_context = info_context[2].view(obs_shape[0], -1).detach() 165 | 166 | if step % 100 == 0: 167 | with torch.no_grad(): 168 | print("vae loss", vae_loss) 169 | viz_imgs = [] 170 | viz_imgs.append(xrec) 171 | viz_imgs.append(obs) 172 | 173 | viz_imgs = torch.stack(viz_imgs)[:, :5] 174 | t, n, c, h, w = viz_imgs.shape 175 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs) 176 | viz_imgs = viz_imgs.reshape(t*n, c, h, w) 177 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True) 178 | 179 | img = wandb.Image(viz_img) 180 | wandb.log({f"Gamma Pred": img}, step=step) 181 | 182 | loss = vae_loss 183 | loss.backward() 184 | 185 | if step % 1 == 0: 186 | self.optimizer.step() 187 | self.optimizer.zero_grad() 188 | 189 | if step % 2000 == 0: 190 | self.save_vq_weights(step) 191 | 192 | def create_hdf5_file(self, file_path, dataset_names, data_arrays): 193 | """ 194 | Create an HDF5 file and store multiple arrays in it. 
195 | 196 | Parameters: 197 | - file_path: Path to the HDF5 file. 198 | - dataset_names: List of dataset names. 199 | - data_arrays: List of NumPy arrays to be stored in the HDF5 file. 200 | """ 201 | # org dataset ['action', 'discount', 'observation', 'reward', 'step_type'] 202 | with h5py.File(file_path, 'w') as hf: 203 | for dataset_name, data_array in zip(dataset_names, data_arrays): 204 | # Create a dataset in the HDF5 file 205 | hf.create_dataset(dataset_name, data=data_array) 206 | 207 | def save_vq_weights(self, step): 208 | torch.save(self.model.state_dict(), self.vq_model_path) 209 | 210 | 211 | if __name__ == "__main__": 212 | 213 | agent = VQAgent() 214 | agent.collect_latents() 215 | 216 | # agent.train_vq_latents() 217 | -------------------------------------------------------------------------------- /dino/norm_ema_quantizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Zhiliang Peng 7 | # Based on VQGAN code bases 8 | # https://github.com/CompVis/taming-transformers 9 | # --------------------------------------------------------' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.distributed as distributed 15 | from einops import rearrange, repeat 16 | 17 | 18 | def l2norm(t): 19 | return F.normalize(t, p = 2, dim = -1) 20 | 21 | def ema_inplace(moving_avg, new, decay): 22 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 23 | 24 | def sample_vectors(samples, num): 25 | num_samples, device = samples.shape[0], samples.device 26 | 27 | if num_samples >= num: 28 | indices = torch.randperm(num_samples, device = device)[:num] 29 | else: 30 | indices = torch.randint(0, num_samples, (num,), device = device) 31 | 32 | return samples[indices] 33 | 34 | def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): 35 | dim, dtype, device = samples.shape[-1], samples.dtype, samples.device 36 | 37 | means = sample_vectors(samples, num_clusters) 38 | 39 | for _ in range(num_iters): 40 | if use_cosine_sim: 41 | dists = samples @ means.t() 42 | else: 43 | diffs = rearrange(samples, 'n d -> n () d') \ 44 | - rearrange(means, 'c d -> () c d') 45 | dists = -(diffs ** 2).sum(dim = -1) 46 | 47 | buckets = dists.max(dim = -1).indices 48 | bins = torch.bincount(buckets, minlength = num_clusters) 49 | zero_mask = bins == 0 50 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 51 | 52 | new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) 53 | new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) 54 | new_means = new_means / bins_min_clamped[..., None] 55 | 56 | if use_cosine_sim: 57 | new_means = l2norm(new_means) 58 | 59 | means = torch.where(zero_mask[..., None], means, new_means) 60 | 61 | return means, bins 62 | 63 | 64 | class EmbeddingEMA(nn.Module): 65 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): 66 | super().__init__() 67 | self.num_tokens = num_tokens 68 | self.codebook_dim = codebook_dim 69 | self.decay = decay 70 | self.eps = eps 71 | if codebook_init_path == '': 72 | if not kmeans_init: 73 | weight = torch.randn(num_tokens, 
codebook_dim) 74 | weight = l2norm(weight) 75 | else: 76 | weight = torch.zeros(num_tokens, codebook_dim) 77 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 78 | else: 79 | print(f"load init codebook weight from {codebook_init_path}") 80 | codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') 81 | weight = codebook_ckpt_weight.clone() 82 | self.register_buffer('initted', torch.Tensor([True])) 83 | 84 | self.weight = nn.Parameter(weight, requires_grad = False) 85 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) 86 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) 87 | # self.register_buffer('initted', torch.Tensor([not kmeans_init])) 88 | self.update = True 89 | 90 | @torch.jit.ignore 91 | def init_embed_(self, data): 92 | if self.initted: 93 | return 94 | print("Performing Kemans init for codebook") 95 | embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim = True) 96 | self.weight.data.copy_(embed) 97 | self.cluster_size.data.copy_(cluster_size) 98 | self.initted.data.copy_(torch.Tensor([True])) 99 | 100 | def forward(self, embed_id): 101 | return F.embedding(embed_id, self.weight) 102 | 103 | def cluster_size_ema_update(self, new_cluster_size): 104 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) 105 | 106 | def embed_avg_ema_update(self, new_embed_avg): 107 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 108 | 109 | def weight_update(self, num_tokens): 110 | n = self.cluster_size.sum() 111 | smoothed_cluster_size = ( 112 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 113 | ) 114 | #normalize embedding average with smoothed cluster size 115 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 116 | # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) 117 | self.weight.data.copy_(embed_normalized) 118 | 119 | def norm_ema_inplace(moving_avg, new, decay): 120 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 121 | moving_avg.data.copy_(l2norm(moving_avg.data)) 122 | 123 | class NormEMAVectorQuantizer(nn.Module): 124 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 125 | statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): 126 | super().__init__() 127 | self.codebook_dim = embedding_dim 128 | self.num_tokens = n_embed 129 | self.beta = beta 130 | self.decay = decay 131 | 132 | # learnable = True if orthogonal_reg_weight > 0 else False 133 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) 134 | 135 | self.statistic_code_usage = statistic_code_usage 136 | if statistic_code_usage: 137 | self.register_buffer('cluster_size', torch.zeros(n_embed)) 138 | if distributed.is_available() and distributed.is_initialized(): 139 | print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") 140 | self.all_reduce_fn = distributed.all_reduce 141 | else: 142 | self.all_reduce_fn = nn.Identity() 143 | 144 | def reset_cluster_size(self, device): 145 | if self.statistic_code_usage: 146 | self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) 147 | self.cluster_size = self.cluster_size.to(device) 148 | 149 | def forward(self, z): 150 | # reshape z -> (batch, height, width, channel) and flatten 151 | #z, 'b c h w -> b h w c' 152 | z = rearrange(z, 'b c h w -> b h w c') 153 | z = l2norm(z) 154 | z_flattened = 
z.reshape(-1, self.codebook_dim) 155 | 156 | self.embedding.init_embed_(z_flattened) 157 | 158 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 159 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 160 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 161 | 162 | encoding_indices = torch.argmin(d, dim=1) 163 | 164 | z_q = self.embedding(encoding_indices).view(z.shape) 165 | 166 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 167 | 168 | if not self.training: 169 | with torch.no_grad(): 170 | cluster_size = encodings.sum(0) 171 | self.all_reduce_fn(cluster_size) 172 | ema_inplace(self.cluster_size, cluster_size, self.decay) 173 | 174 | if self.training and self.embedding.update: 175 | #EMA cluster size 176 | 177 | bins = encodings.sum(0) 178 | self.all_reduce_fn(bins) 179 | 180 | # self.embedding.cluster_size_ema_update(bins) 181 | ema_inplace(self.cluster_size, bins, self.decay) 182 | 183 | zero_mask = (bins == 0) 184 | bins = bins.masked_fill(zero_mask, 1.) 185 | 186 | embed_sum = z_flattened.t() @ encodings 187 | self.all_reduce_fn(embed_sum) 188 | 189 | embed_normalized = (embed_sum / bins.unsqueeze(0)).t() 190 | embed_normalized = l2norm(embed_normalized) 191 | 192 | embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, 193 | embed_normalized) 194 | norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) 195 | 196 | # compute loss for embedding 197 | loss = self.beta * F.mse_loss(z_q.detach(), z) 198 | 199 | # preserve gradients 200 | z_q = z + (z_q - z).detach() 201 | 202 | # reshape back to match original input shape 203 | #z_q, 'b h w c -> b c h w' 204 | z_q = rearrange(z_q, 'b h w c -> b c h w') 205 | return z_q, loss, encoding_indices 206 | 207 | -------------------------------------------------------------------------------- /dino/vqkd_teacher/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | 
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 70 | 71 | return download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 94 | """Load a CLIP model 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | 101 | device : Union[str, torch.device] 102 | The device to put the loaded model 103 | 104 | jit : bool 105 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
106 | 107 | download_root: str 108 | path to download the model files; by default, it uses "~/.cache/clip" 109 | 110 | Returns 111 | ------- 112 | model : torch.nn.Module 113 | The CLIP model 114 | 115 | preprocess : Callable[[PIL.Image], torch.Tensor] 116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 117 | """ 118 | if name in _MODELS: 119 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 120 | elif os.path.isfile(name): 121 | model_path = name 122 | else: 123 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 124 | 125 | try: 126 | # loading JIT archive 127 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 128 | state_dict = None 129 | except RuntimeError: 130 | # loading saved state dict 131 | if jit: 132 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 133 | jit = False 134 | state_dict = torch.load(model_path, map_location="cpu") 135 | 136 | if not jit: 137 | model = build_model(state_dict or model.state_dict()).to(device) 138 | if str(device) == "cpu": 139 | model.float() 140 | return model, _transform(model.visual.input_resolution) 141 | 142 | # patch the device names 143 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 144 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 145 | 146 | def patch_device(module): 147 | try: 148 | graphs = [module.graph] if hasattr(module, "graph") else [] 149 | except RuntimeError: 150 | graphs = [] 151 | 152 | if hasattr(module, "forward1"): 153 | graphs.append(module.forward1.graph) 154 | 155 | for graph in graphs: 156 | for node in graph.findAllNodes("prim::Constant"): 157 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 158 | node.copyAttributes(device_node) 159 | 160 | model.apply(patch_device) 161 | patch_device(model.encode_image) 162 | patch_device(model.encode_text) 163 | 164 | # patch dtype to float32 on CPU 165 | if str(device) == "cpu": 166 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 167 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 168 | float_node = float_input.node() 169 | 170 | def patch_float(module): 171 | try: 172 | graphs = [module.graph] if hasattr(module, "graph") else [] 173 | except RuntimeError: 174 | graphs = [] 175 | 176 | if hasattr(module, "forward1"): 177 | graphs.append(module.forward1.graph) 178 | 179 | for graph in graphs: 180 | for node in graph.findAllNodes("aten::to"): 181 | inputs = list(node.inputs()) 182 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 183 | if inputs[i].node()["value"] == 5: 184 | inputs[i].node().copyAttributes(float_node) 185 | 186 | model.apply(patch_float) 187 | patch_float(model.encode_image) 188 | patch_float(model.encode_text) 189 | 190 | model.float() 191 | 192 | return model, _transform(model.input_resolution.item()) 193 | 194 | 195 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 196 | """ 197 | Returns the tokenized representation of given input string(s) 198 | 199 | Parameters 200 | ---------- 201 | texts : Union[str, List[str]] 202 | An input string or a list of input strings to tokenize 203 | 204 | context_length : int 205 | The context length to use; all CLIP 
models use 77 as the context length 206 | 207 | truncate: bool 208 | Whether to truncate the text in case its encoding is longer than the context length 209 | 210 | Returns 211 | ------- 212 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 213 | """ 214 | if isinstance(texts, str): 215 | texts = [texts] 216 | 217 | sot_token = _tokenizer.encoder["<|startoftext|>"] 218 | eot_token = _tokenizer.encoder["<|endoftext|>"] 219 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 220 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 221 | 222 | for i, tokens in enumerate(all_tokens): 223 | if len(tokens) > context_length: 224 | if truncate: 225 | tokens = tokens[:context_length] 226 | tokens[-1] = eot_token 227 | else: 228 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 229 | result[i, :len(tokens)] = torch.tensor(tokens) 230 | 231 | return result 232 | -------------------------------------------------------------------------------- /utils/dmc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from collections import deque 6 | from typing import Any, NamedTuple 7 | 8 | import dm_env 9 | import numpy as np 10 | from dm_control import manipulation, suite 11 | from dm_control.suite.wrappers import action_scale, pixels 12 | from dm_env import StepType, specs 13 | 14 | # from envs.distracting_control.suite import distracting_wrapper 15 | #import envs.fb_mtenv_dmc as fb_mtenv_dmc 16 | 17 | 18 | def get_unique_int(difficulty: str) -> int: 19 | return int.from_bytes(f'{difficulty}_0'.encode(), 'little') % (2 ** 31) 20 | 21 | 22 | distracting_kwargs_lookup = { 23 | 'easy': {'difficulty': 'easy', 'fixed_distraction': False}, 24 | 'medium': {'difficulty': 'medium', 'fixed_distraction': False}, 25 | 'hard': {'difficulty': 'hard', 'fixed_distraction': False}, 26 | 'fixed_easy': {'difficulty': 'easy', 'fixed_distraction': True, 'color_seed': get_unique_int('easy'), 27 | 'background_seed': get_unique_int('easy'), 'camera_seed': get_unique_int('easy')}, 28 | 'fixed_medium': {'difficulty': 'medium', 'fixed_distraction': True, 'color_seed': get_unique_int('medium'), 29 | 'background_seed': get_unique_int('medium'), 'camera_seed': get_unique_int('medium')}, 30 | 'fixed_hard': {'difficulty': 'hard', 'fixed_distraction': True, 'color_seed': get_unique_int('hard'), 31 | 'background_seed': get_unique_int('hard'), 'camera_seed': get_unique_int('hard')}, 32 | } 33 | 34 | multitask_modes = [f'len_{i}' for i in range(1, 11, 1)] 35 | 36 | 37 | class ExtendedTimeStep(NamedTuple): 38 | step_type: Any 39 | reward: Any 40 | discount: Any 41 | observation: Any 42 | pixel_observation: Any 43 | action: Any 44 | latent: Any 45 | imp_action: Any 46 | k_step: Any 47 | 48 | def first(self): 49 | return self.step_type == StepType.FIRST 50 | 51 | def mid(self): 52 | return self.step_type == StepType.MID 53 | 54 | def last(self): 55 | return self.step_type == StepType.LAST 56 | 57 | #def __getitem__(self, attr): 58 | # return getattr(self, str(attr)) 59 | 60 | class ExtendedTimeStepEval(NamedTuple): 61 | step_type: Any 62 | reward: Any 63 | discount: Any 64 | observation: Any 65 | action: Any 66 | latent: Any 67 | imp_action: Any 68 
| k_step: Any 69 | 70 | def first(self): 71 | return self.step_type == StepType.FIRST 72 | 73 | def mid(self): 74 | return self.step_type == StepType.MID 75 | 76 | def last(self): 77 | return self.step_type == StepType.LAST 78 | 79 | #def __getitem__(self, attr): 80 | # return getattr(self, str(attr)) 81 | 82 | class ActionRepeatWrapper(dm_env.Environment): 83 | def __init__(self, env, num_repeats): 84 | self._env = env 85 | self._num_repeats = num_repeats 86 | 87 | def step(self, action): 88 | reward = 0.0 89 | discount = 1.0 90 | for i in range(self._num_repeats): 91 | time_step = self._env.step(action) 92 | reward += (time_step.reward or 0.0) * discount 93 | discount *= time_step.discount 94 | if time_step.last(): 95 | break 96 | 97 | return time_step._replace(reward=reward, discount=discount) 98 | 99 | def observation_spec(self): 100 | return self._env.observation_spec() 101 | 102 | def action_spec(self): 103 | return self._env.action_spec() 104 | 105 | def reset(self): 106 | return self._env.reset() 107 | 108 | def __getattr__(self, name): 109 | return getattr(self._env, name) 110 | 111 | 112 | class FrameStackWrapper(dm_env.Environment): 113 | def __init__(self, env, num_frames, pixels_key='pixels'): 114 | self._env = env 115 | self._num_frames = num_frames 116 | self._frames = deque([], maxlen=num_frames) 117 | self._pixels_key = pixels_key 118 | 119 | wrapped_obs_spec = env.observation_spec() 120 | assert pixels_key in wrapped_obs_spec 121 | 122 | pixels_shape = wrapped_obs_spec[pixels_key].shape 123 | # remove batch dim 124 | if len(pixels_shape) == 4: 125 | pixels_shape = pixels_shape[1:] 126 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 127 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 128 | dtype=np.uint8, 129 | minimum=0, 130 | maximum=255, 131 | name='observation') 132 | 133 | def _transform_observation(self, time_step): 134 | assert len(self._frames) == self._num_frames 135 | obs = np.concatenate(list(self._frames), axis=0) 136 | return time_step._replace(observation=obs) 137 | 138 | def _extract_pixels(self, time_step): 139 | try: 140 | pos = time_step.observation['position'] 141 | except: 142 | pos = None 143 | pixels = time_step.observation[self._pixels_key] 144 | # remove batch dim 145 | if len(pixels.shape) == 4: 146 | pixels = pixels[0] 147 | return pixels.transpose(2, 0, 1).copy(), pos 148 | 149 | def reset(self): 150 | time_step = self._env.reset() 151 | pixels, pos = self._extract_pixels(time_step) 152 | for _ in range(self._num_frames): 153 | self._frames.append(pixels) 154 | return self._transform_observation(time_step), pos 155 | 156 | def step(self, action): 157 | time_step = self._env.step(action) 158 | pixels, pos = self._extract_pixels(time_step) 159 | self._frames.append(pixels) 160 | return self._transform_observation(time_step), pos 161 | 162 | def observation_spec(self): 163 | return self._obs_spec 164 | 165 | def action_spec(self): 166 | return self._env.action_spec() 167 | 168 | def __getattr__(self, name): 169 | return getattr(self._env, name) 170 | 171 | 172 | class ActionDTypeWrapper(dm_env.Environment): 173 | def __init__(self, env, dtype): 174 | self._env = env 175 | wrapped_action_spec = env.action_spec() 176 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 177 | dtype, 178 | wrapped_action_spec.minimum, 179 | wrapped_action_spec.maximum, 180 | 'action') 181 | 182 | def step(self, action): 183 | action = action.astype(self._env.action_spec().dtype) 184 | return self._env.step(action) 185 | 186 | 
def observation_spec(self): 187 | return self._env.observation_spec() 188 | 189 | def action_spec(self): 190 | return self._action_spec 191 | 192 | def reset(self): 193 | return self._env.reset() 194 | 195 | def __getattr__(self, name): 196 | return getattr(self._env, name) 197 | 198 | 199 | class ExtendedTimeStepWrapper(dm_env.Environment): 200 | def __init__(self, env): 201 | self._env = env 202 | 203 | def reset(self): 204 | time_step, pos = self._env.reset() 205 | return self._augment_time_step(time_step), pos 206 | 207 | def step(self, action): 208 | time_step, pos = self._env.step(action) 209 | return self._augment_time_step(time_step, action), pos 210 | 211 | def _augment_time_step(self, time_step, action=None): 212 | if action is None: 213 | action_spec = self.action_spec() 214 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 215 | return ExtendedTimeStepEval(observation=time_step.observation, 216 | step_type=time_step.step_type, 217 | action=action, 218 | reward=time_step.reward or 0.0, 219 | discount=time_step.discount or 1.0, 220 | latent=np.zeros(256), 221 | imp_action=np.zeros(84*84*1), 222 | k_step=0) 223 | 224 | def observation_spec(self): 225 | return self._env.observation_spec() 226 | 227 | def action_spec(self): 228 | return self._env.action_spec() 229 | 230 | def __getattr__(self, name): 231 | return getattr(self._env, name) 232 | 233 | 234 | def make(name, frame_stack, action_repeat, seed, distracting_mode: str = None, multitask_mode: str = None): 235 | pixel_hw = 84 236 | if 'offline' in name: 237 | name = '_'.join(name.split('_')[1:3]) 238 | domain, task = name.split('_', 1) 239 | # overwrite cup to ball_in_cup 240 | domain = dict(cup='ball_in_cup').get(domain, domain) 241 | 242 | # make sure reward is not visualized 243 | if multitask_mode is None: 244 | if (domain, task) in suite.ALL_TASKS: 245 | env = suite.load(domain, 246 | task, 247 | task_kwargs={'random': seed}, 248 | visualize_reward=False) 249 | pixels_key = 'pixels' 250 | else: 251 | name = f'{domain}_{task}_vision' 252 | env = manipulation.load(name, seed=seed) 253 | pixels_key = 'front_close' 254 | else: 255 | assert multitask_mode in multitask_modes, 'Unrecognised length setting' 256 | idx = multitask_mode.split('_', 1)[1] 257 | 258 | if domain == 'walker' and task == 'walk': 259 | xml = f'len_{idx}' 260 | elif domain == 'cheetah' and task == 'run': 261 | xml = f'torso_length_{idx}' 262 | else: 263 | raise Exception 264 | 265 | env = fb_mtenv_dmc.load( 266 | domain_name=domain, 267 | task_name=task, 268 | task_kwargs={'xml_file_id': xml, 'random': seed}, 269 | visualize_reward=False, 270 | ) 271 | pixels_key = 'pixels' 272 | 273 | # add wrappers 274 | env = ActionDTypeWrapper(env, np.float32) 275 | env = ActionRepeatWrapper(env, action_repeat) 276 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 277 | # add renderings for clasical tasks 278 | if (domain, task) in suite.ALL_TASKS: 279 | # zoom in camera for quadruped 280 | camera_id = dict(quadruped=2).get(domain, 0) 281 | render_kwargs = dict(height=pixel_hw, width=pixel_hw, camera_id=camera_id) 282 | if distracting_mode is not None: 283 | assert distracting_mode in distracting_kwargs_lookup, 'Unrecognised distraction' 284 | kwargs = distracting_kwargs_lookup[distracting_mode] 285 | kwargs['pixels_only'] = False 286 | kwargs['render_kwargs'] = render_kwargs 287 | kwargs['background_dataset_path'] = "/home/manant/scratch/DAVIS/JPEGImages/480p/" #"DAVIS/JPEGImages/480p/" 288 | env = distracting_wrapper( 289 | env, 290 | domain, 
291 | **kwargs 292 | ) 293 | else: 294 | env = pixels.Wrapper(env, 295 | pixels_only=False, 296 | render_kwargs=render_kwargs) 297 | # stack several frames 298 | env = FrameStackWrapper(env, frame_stack, pixels_key) 299 | env = ExtendedTimeStepWrapper(env) 300 | return env 301 | -------------------------------------------------------------------------------- /dino/train_vq_dino_voc.py: -------------------------------------------------------------------------------- 1 | # from torchvision import transforms as pth_transforms 2 | from timm.models import create_model 3 | import modeling_vqkd 4 | 5 | import sys 6 | sys.path.insert(0, '../utils/') 7 | 8 | import os 9 | import math 10 | import copy 11 | import wandb 12 | import random 13 | import utils 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from omegaconf import OmegaConf 20 | 21 | from pathlib import Path 22 | from matplotlib import pyplot as plt 23 | from torchvision.utils import make_grid 24 | 25 | from numpy_replay_buffer import EfficientReplayBuffer 26 | from utils import load_offline_dataset_into_buffer 27 | 28 | # import dmc 29 | from dm_env import specs 30 | 31 | from transformers.optimization import get_scheduler 32 | from transformers import GPT2LMHeadModel, GPT2Config 33 | 34 | from network_utils import get_parameter_names, update_moving_average, Predictor, EMA, RandomShiftsAug 35 | 36 | ################## 37 | obs_resize = 224 38 | 39 | offline_dir = "../../cheetah_train/dino_latents/" 40 | dino_path = '../../vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth' 41 | save_dir_path = "./" 42 | 43 | batch_size = 32 44 | frame_stack = 2 45 | device = 'cuda' 46 | ################## 47 | # - Pretrained DINO + Quantization 48 | ################## 49 | 50 | class TFAgent: 51 | def __init__(self, discount=0.8, augmentation=RandomShiftsAug(pad=4)): 52 | 53 | wandb.init(project="Video Occupancy Models", 54 | id="voc_dino_gamma_{}".format(discount), 55 | entity=None, dir=os.getcwd()) 56 | 57 | data_specs = (specs.BoundedArray(shape=(9, 84, 84), dtype=np.uint8, name='observation', minimum=0, maximum=255), 58 | specs.BoundedArray(shape=(6,), dtype=np.float32, name='action', minimum=-1.0, maximum=1.0), 59 | specs.Array((1,), np.float32, 'reward'), 60 | specs.Array((1,), np.float32, 'discount')) 61 | 62 | self.discount = discount 63 | self.batch_size = batch_size 64 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps 65 | self.codebook_size = 8192 66 | 67 | ######### train on VQ latents ######### 68 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, frame_stack, False, data_specs) 69 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, frame_stack, 1000000, latent_style="dino") 70 | ########################## 71 | 72 | self.model = create_model( 73 | 'vqkd_encoder_base_decoder_1x768x12_dino', 74 | pretrained=True, 75 | pretrained_weight=dino_path, 76 | as_tokenzer=True, 77 | ) #.to('cuda').eval() 78 | self.model.to(device).eval() 79 | 80 | ######### gpt setup ######### 81 | configuration = GPT2Config(vocab_size=8192, n_positions=196 * frame_stack * 2, n_layer=4, n_head=4, n_embed=128) # nano-gpt 82 | self.gpt = GPT2LMHeadModel(configuration).to(device) 83 | self.gpt_target = copy.deepcopy(self.gpt) 84 | self.gpt_target.generation_config.output_scores = True 85 | self.target_ema_updater = EMA(0.9) 86 | ########################## 87 | 88 | ######### optimizer setup ######### 89 | 
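# Sketch of the token bookkeeping behind n_positions = 196 * frame_stack * 2 above
# (this follows from the 224x224 resize and the patch-16 VQ-KD DINO tokenizer used
# in this file; the variable names below are illustrative, not part of the original).
# Each frame is tokenized into a 14x14 grid of discrete codes, i.e. 196 tokens per
# frame; the GPT sees frame_stack frames of "context" codes followed by frame_stack
# frames of "target" codes, which update() concatenates into a single sequence.
tokens_per_frame = (obs_resize // 16) ** 2             # 14 * 14 = 196 codes per frame
gpt_sequence_len = tokens_per_frame * frame_stack * 2  # context codes + target codes
assert gpt_sequence_len == 196 * frame_stack * 2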
self.reward_predictor = Predictor(32 * 14 * 14 * frame_stack).to(device) 90 | self.gpt_optimizer = torch.optim.AdamW(list(self.get_grouped_params(self.gpt)), lr=3e-4) 91 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()), lr=3e-4) 92 | 93 | num_training_steps = 100000 94 | self.warmup_ratio = 0.05 95 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio) 96 | self.lr_scheduler = get_scheduler( 97 | "cosine", 98 | optimizer=self.gpt_optimizer, 99 | num_warmup_steps=warmup_steps, 100 | num_training_steps=num_training_steps, 101 | ) 102 | ########################## 103 | 104 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to(device) 105 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to(device) 106 | 107 | self.device = device 108 | self.aug = augmentation 109 | 110 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000] 111 | self.train() 112 | 113 | def get_grouped_params(self, model): 114 | decay_parameters = get_parameter_names(model, [nn.LayerNorm]) 115 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 116 | optimizer_grouped_parameters = [ 117 | { 118 | "params": [ 119 | p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad) 120 | ], 121 | "weight_decay": 0.1, 122 | }, 123 | { 124 | "params": [ 125 | p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad) 126 | ], 127 | "weight_decay": 0.0, 128 | }, 129 | ] 130 | return optimizer_grouped_parameters 131 | 132 | def train(self, training=True): 133 | self.training = training 134 | 135 | def preprocess_obs(self, obs): 136 | obs = F.interpolate(obs, size=obs_resize) 137 | 138 | if len(obs.shape) == 3: 139 | obs = obs.unsqueeze(0) 140 | 141 | try: 142 | assert len(obs.shape) == 4 # B x C x H x W 143 | org_obs_shape = obs.shape 144 | # normalize and preprocess 145 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(frame_stack)]) 146 | obs = torch.einsum('snhwc->nschw', obs).reshape((org_obs_shape[0] * frame_stack, 3, *org_obs_shape[2:])) 147 | except: 148 | assert len(obs.shape) == 5 # T x B x C x H x W 149 | org_obs_shape = t, b, c, h, w = obs.shape 150 | obs = torch.stack([torch.einsum('tnchw->tnhwc', obs[:, :, i*3:3+i*3] / 255.) 
- self.imagenet_mean / self.imagenet_std for i in range(frame_stack)]) 151 | obs = torch.einsum('stnhwc->tnschw', obs).reshape((t, b * frame_stack, 3, h, w)) 152 | return obs, org_obs_shape 153 | 154 | def update(self, step=0): 155 | metrics = dict() 156 | 157 | batch, indices = next(self.replay_buffer) 158 | obs, action, reward, discount, next_obs, _, _, _, obs_k = utils.to_torch(batch, self.device) 159 | 160 | y_context = obs.long() 161 | y_target = obs_k.long() 162 | obs_shape = obs.shape 163 | 164 | quant_target = self.model.get_codebook_entry(y_target.reshape(obs_shape[0]*frame_stack, -1)) 165 | 166 | quant_context = self.model.get_codebook_entry(y_context.reshape(obs_shape[0]*frame_stack, -1)) 167 | 168 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1)) 169 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean() 170 | 171 | # generate target 172 | with torch.no_grad(): 173 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100) 174 | p_t = p_t[:, -y_target.shape[-1]:] 175 | 176 | # gamma sampling 177 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device) 178 | prob = torch.bernoulli(gamma) 179 | p_target = torch.zeros_like(y_target) 180 | 181 | # with prob 1-gamma, sample from next state 182 | p_c_idx = torch.nonzero(1 - prob) 183 | p_target[p_c_idx] = y_target[p_c_idx] 184 | 185 | # with prob gamma, sample from bootstrapped model 186 | p_t_idx = torch.nonzero(prob) 187 | p_target[p_t_idx] = p_t[p_t_idx] 188 | 189 | # gpt predictions 190 | inp = torch.cat([y_context, p_target], dim=1) 191 | # mask_ids = torch.cat([context_mask_ids, target_mask_ids], dim=1) 192 | outputs = self.gpt(inp, labels=inp) 193 | gpt_loss = outputs.loss 194 | 195 | loss = gpt_loss + reward_loss 196 | loss.backward() 197 | 198 | # grad accumulate 199 | if step % 1 == 0: 200 | self.optimizer.step() 201 | self.gpt_optimizer.step() 202 | 203 | self.optimizer.zero_grad() 204 | self.gpt_optimizer.zero_grad() 205 | 206 | self.lr_scheduler.step() 207 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt) 208 | 209 | # visualize predictions 210 | if step % 200 == 0: 211 | with torch.no_grad(): 212 | # sample a batch of traj and corresponding values 213 | batch, indices = self.replay_buffer.sample_spr(jumps=self.nsteps) 214 | _, _, _, _, _, all_obs, all_pixel_obs, _, values = utils.to_torch(batch, self.device) 215 | 216 | # preprocess first obs from traj 217 | obs = all_obs[0] 218 | obs_shape = obs.shape 219 | 220 | # embed first obs 221 | y_context = obs.long() 222 | 223 | value_loss = self.get_value_estimates(y_context, values, obs_shape) 224 | wandb.log({"value loss": value_loss}, step=step) 225 | 226 | wandb.log({"gpt loss": gpt_loss}, step=step) 227 | wandb.log({"reward loss": reward_loss}, step=step) 228 | 229 | # save gpt model 230 | if step in self.saving_iter: 231 | print("saving gpt weights...") 232 | self.save_gpt_weights(step) 233 | 234 | return metrics 235 | 236 | def save_gpt_weights(self, step): 237 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "dino_nanogpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step))) 238 | 239 | def get_value_estimates(self, y_context, values, obs_shape): 240 | # Take a state, get samples from the gamma distribution, 241 | # Run the reward predictor through these to get value estimates 242 | # Get ground truth value estimates by simply taking discounted sum of rewards 243 | # 
Compare these for different states 244 | 245 | num_gamma_samples = 100 246 | values_pred = [] 247 | 248 | for i in range(num_gamma_samples): 249 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100) #, kwargs={'token_type_ids': context_mask_ids}) 250 | p_t = outputs.sequences[:, -y_context.shape[-1]:] 251 | 252 | # quant = self.model.quantize.get_codebook_entry(p_t, None) 253 | # quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 254 | quant = self.model.get_codebook_entry(p_t.reshape(obs_shape[0]*frame_stack, -1)) 255 | 256 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1)) 257 | 258 | values_pred = torch.stack(values_pred).sum(0) / (100 * (1 - self.discount)) 259 | 260 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean() 261 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5]) 262 | 263 | return value_estimation_loss 264 | 265 | 266 | import argparse 267 | 268 | if __name__ == "__main__": 269 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.') 270 | 271 | parser.add_argument('--discount', type=float, default=0.8, help='discount') 272 | args = parser.parse_args() 273 | 274 | agent = TFAgent(discount=args.discount) 275 | 276 | agent.optimizer.zero_grad() 277 | agent.gpt_optimizer.zero_grad() 278 | 279 | for step in range(100000): 280 | agent.update(step) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
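# Worked sketch of the check performed by get_value_estimates() in
# dino/train_vq_dino_voc.py above (an annotation, not part of utils.py; the symbols
# below are illustrative). The GPT is trained as a gamma-model of the discounted
# occupancy d_gamma(x | s): with probability (1 - gamma) its target is the codes of
# the successor observation, and with probability gamma it is a bootstrap sample from
# the EMA target GPT conditioned on that successor. A value estimate then follows from
#
#     V(s) = E_{x ~ d_gamma(. | s)}[r(x)] / (1 - gamma)
#     V_hat(s) = (1 / (N * (1 - gamma))) * sum_n r_hat(x_n),   x_n ~ GPT(. | s)
#
# which matches `torch.stack(values_pred).sum(0) / (100 * (1 - self.discount))` in that
# routine, with N = 100. Toy numeric illustration:
#
#     gamma, N = 0.8, 100
#     r_hat = [0.5] * N                       # stand-in for reward_predictor outputs
#     v_hat = sum(r_hat) / (N * (1 - gamma))  # = 0.5 / 0.2 = 2.5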
5 | import random 6 | import re 7 | import time 8 | 9 | import numpy as np 10 | import h5py 11 | from collections import deque 12 | import dmc 13 | from dm_env import StepType 14 | from numpy_replay_buffer import EfficientReplayBuffer 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch import distributions as pyd 19 | from torch.distributions.utils import _standard_normal 20 | import torchvision.transforms as T 21 | 22 | class eval_mode: 23 | def __init__(self, *models): 24 | self.models = models 25 | 26 | def __enter__(self): 27 | self.prev_states = [] 28 | for model in self.models: 29 | self.prev_states.append(model.training) 30 | model.train(False) 31 | 32 | def __exit__(self, *args): 33 | for model, state in zip(self.models, self.prev_states): 34 | model.train(state) 35 | return False 36 | 37 | 38 | def set_seed_everywhere(seed): 39 | torch.manual_seed(seed) 40 | if torch.cuda.is_available(): 41 | torch.cuda.manual_seed_all(seed) 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | 45 | 46 | def soft_update_params(net, target_net, tau): 47 | for param, target_param in zip(net.parameters(), target_net.parameters()): 48 | target_param.data.copy_(tau * param.data + 49 | (1 - tau) * target_param.data) 50 | 51 | 52 | def to_torch(xs, device): 53 | return tuple(torch.as_tensor(x, device=device) for x in xs) 54 | 55 | 56 | def weight_init(m): 57 | if isinstance(m, nn.Linear): 58 | nn.init.orthogonal_(m.weight.data) 59 | if hasattr(m.bias, 'data'): 60 | m.bias.data.fill_(0.0) 61 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 62 | gain = nn.init.calculate_gain('relu') 63 | nn.init.orthogonal_(m.weight.data, gain) 64 | if hasattr(m.bias, 'data'): 65 | m.bias.data.fill_(0.0) 66 | 67 | 68 | class Until: 69 | def __init__(self, until, action_repeat=1): 70 | self._until = until 71 | self._action_repeat = action_repeat 72 | 73 | def __call__(self, step): 74 | if self._until is None: 75 | return True 76 | until = self._until // self._action_repeat 77 | return step < until 78 | 79 | 80 | class Every: 81 | def __init__(self, every, action_repeat=1): 82 | self._every = every 83 | self._action_repeat = action_repeat 84 | 85 | def __call__(self, step): 86 | if self._every is None: 87 | return False 88 | every = self._every // self._action_repeat 89 | if step % every == 0: 90 | return True 91 | return False 92 | 93 | 94 | class Timer: 95 | def __init__(self): 96 | self._start_time = time.time() 97 | self._last_time = time.time() 98 | 99 | def reset(self): 100 | elapsed_time = time.time() - self._last_time 101 | self._last_time = time.time() 102 | total_time = time.time() - self._start_time 103 | return elapsed_time, total_time 104 | 105 | def total_time(self): 106 | return time.time() - self._start_time 107 | 108 | 109 | class TruncatedNormal(pyd.Normal): 110 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 111 | super().__init__(loc, scale, validate_args=False) 112 | self.low = low 113 | self.high = high 114 | self.eps = eps 115 | 116 | def _clamp(self, x): 117 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 118 | x = x - x.detach() + clamped_x.detach() 119 | return x 120 | 121 | def sample(self, clip=None, sample_shape=torch.Size()): 122 | shape = self._extended_shape(sample_shape) 123 | eps = _standard_normal(shape, 124 | dtype=self.loc.dtype, 125 | device=self.loc.device) 126 | eps *= self.scale 127 | if clip is not None: 128 | eps = torch.clamp(eps, -clip, clip) 129 | x = self.loc + eps 130 | return self._clamp(x) 131 | 132 | 
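# Usage sketch for TruncatedNormal above (the shapes and the 0.2 stddev are
# illustrative): DrQ-v2-style agents draw exploration noise from this distribution,
# optionally clipping the pre-clamp noise, and take the distribution mean at
# evaluation time. `torch` is imported at the top of this file.
_mu = torch.zeros(4, 6)                    # e.g. a batch of 6-dim action means in [-1, 1]
_dist = TruncatedNormal(_mu, 0.2 * torch.ones_like(_mu))
_a_explore = _dist.sample(clip=0.3)        # noise clipped to [-0.3, 0.3]; samples stay in [-1, 1]
_a_eval = _dist.mean                       # deterministic action (the loc parameter)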
133 | def schedule(schdl, step): 134 | try: 135 | return float(schdl) 136 | except ValueError: 137 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 138 | if match: 139 | init, final, duration = [float(g) for g in match.groups()] 140 | mix = np.clip(step / duration, 0.0, 1.0) 141 | return (1.0 - mix) * init + mix * final 142 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 143 | if match: 144 | init, final1, duration1, final2, duration2 = [ 145 | float(g) for g in match.groups() 146 | ] 147 | if step <= duration1: 148 | mix = np.clip(step / duration1, 0.0, 1.0) 149 | return (1.0 - mix) * init + mix * final1 150 | else: 151 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 152 | return (1.0 - mix) * final1 + mix * final2 153 | raise NotImplementedError(schdl) 154 | 155 | 156 | step_type_lookup = { 157 | 0: StepType.FIRST, 158 | 1: StepType.MID, 159 | 2: StepType.LAST 160 | } 161 | 162 | 163 | def load_offline_dataset_into_buffer(offline_dir, replay_buffer, agent, frame_stack, replay_buffer_size, future_sampling_steps=2, latent_style="vae"): 164 | filenames = sorted(offline_dir.glob('*.hdf5')) 165 | num_steps = 0 166 | print("filename is", filenames, offline_dir, offline_dir.glob('*.hdf5')) 167 | for filename in filenames: 168 | #try: 169 | episodes = h5py.File(filename, 'r') 170 | episodes = {k: episodes[k][:] for k in episodes.keys()} 171 | add_offline_data_to_buffer(episodes, replay_buffer, agent, framestack=frame_stack, future_sampling_steps=future_sampling_steps, latent_style=latent_style) 172 | length = episodes['reward'].shape[0] 173 | num_steps += length 174 | #except Exception as e: 175 | # print(f'Could not load episode {str(filename)}: {e}') 176 | # continue 177 | print("Loaded {} offline timesteps so far...".format(int(num_steps))) 178 | if num_steps >= replay_buffer_size: 179 | break 180 | print("Finished, loaded {} timesteps.".format(int(num_steps))) 181 | 182 | 183 | def add_offline_data_to_buffer(offline_data: dict, replay_buffer: EfficientReplayBuffer, agent, framestack: int = 3, future_sampling_steps: int = 2, latent_style: str = "vae"): 184 | offline_data_length = offline_data['reward'].shape[0] 185 | for v in offline_data.values(): 186 | assert v.shape[0] == offline_data_length 187 | done_list = np.argwhere(offline_data['step_type']==2) 188 | assert len(done_list) > 1 189 | interval = done_list[1] - done_list[0] 190 | now = -1 191 | max_k = future_sampling_steps #15 192 | 193 | resize = T.Compose([ 194 | T.ToPILImage(), 195 | T.Resize((64, 64)), 196 | # T.ToTensor() 197 | ]) 198 | 199 | for idx in range(offline_data_length): 200 | time_step = get_timestep_from_idx(offline_data, idx, latent_style) 201 | if not time_step.first(): 202 | now += 1 203 | # stacked_frames.append(np.asarray(resize(time_step.observation.reshape(84, 84, -1))).reshape(-1, 64, 64)) 204 | stacked_frames.append(time_step.observation) 205 | stacked_pixel_frames.append(time_step.pixel_observation) 206 | time_step_stack = time_step._replace(observation=np.concatenate(stacked_frames, axis=0), pixel_observation=np.concatenate(stacked_pixel_frames, axis=0)) 207 | with torch.no_grad(): #, eval_mode(agent): 208 | ob = torch.as_tensor(np.concatenate(stacked_frames, axis=0), device='cuda') 209 | pixel_ob = torch.as_tensor(np.concatenate(stacked_pixel_frames, axis=0), device='cuda') 210 | # imp_action = torch.abs(agent.actor(agent.encoder(ob.unsqueeze(0)).squeeze(0))) 211 | #act = torch.as_tensor(imp_action.reshape(9, 84, 84), device=agent.device) 212 | #new_ob = 
torch.clamp(torch.sqrt(act) * ob + torch.sqrt(1 - act) * 255 * torch.randn(9, 84, 84, device=agent.device), min=0, max=255).type(torch.int64) 213 | # latent = agent.get_latent(ob, imp_action).squeeze(0) #agent.encoder(new_ob.unsqueeze(0)).squeeze(0) #agent.get_latent(state, action, latent=No) 214 | # time_step_stack = time_step_stack._replace(latent=latent.cpu().detach().numpy()) 215 | # time_step_stack = time_step_stack._replace(imp_action=imp_action.cpu().numpy()) 216 | rindex = min(interval-1, now+max_k) 217 | rindex = rindex - now 218 | time_step_stack = time_step_stack._replace(k_step=rindex) 219 | replay_buffer.add(time_step_stack) 220 | else: 221 | now = -1 222 | stacked_frames = deque(maxlen=framestack) 223 | stacked_pixel_frames = deque(maxlen=framestack) 224 | while len(stacked_frames) < framestack: 225 | # stacked_frames.append(np.asarray(resize(time_step.observation.reshape(84, 84, -1))).reshape(-1, 64, 64)) 226 | stacked_frames.append(time_step.observation) 227 | stacked_pixel_frames.append(time_step.pixel_observation) 228 | time_step_stack = time_step._replace(observation=np.concatenate(stacked_frames, axis=0), pixel_observation=np.concatenate(stacked_pixel_frames, axis=0)) 229 | with torch.no_grad(): #, eval_mode(agent): 230 | ob = torch.as_tensor(np.concatenate(stacked_frames, axis=0), device='cuda') 231 | pixel_ob = torch.as_tensor(np.concatenate(stacked_pixel_frames, axis=0), device='cuda') 232 | # imp_action = torch.abs(agent.actor(agent.encoder(ob.unsqueeze(0)).squeeze(0))) 233 | #act = torch.as_tensor(imp_action.reshape(9, 84, 84), device=agent.device) 234 | #new_ob = torch.clamp(torch.sqrt(act) * ob + torch.sqrt(1 - act) * 255 * torch.randn(9, 84, 84, device=agent.device), min=0, max=255).type(torch.int64) 235 | # latent = agent.get_latent(ob, imp_action).squeeze(0) #agent.get_latent(state, action, latent=No) 236 | #imp_action = torch.abs(agent.actor(latent)) 237 | # time_step_stack = time_step_stack._replace(latent=latent.cpu().detach().numpy()) 238 | # time_step_stack = time_step_stack._replace(imp_action=imp_action.cpu().numpy()) 239 | rindex = min(interval-1, now+max_k) #random.randint(now+1, min(interval-1, now+max_k)) 240 | rindex = rindex - now 241 | time_step_stack = time_step_stack._replace(k_step=rindex) 242 | replay_buffer.add(time_step_stack) 243 | 244 | 245 | ## TODO: Add dummy values for storing 'mae_latents' and 'mae_ids_restore' so that a common replay buffer can be used for both, OR have two separate replay buffer codes 246 | def get_timestep_from_idx(offline_data: dict, idx: int, latent_style: str): 247 | if latent_style == "vae": 248 | return dmc.ExtendedTimeStep( 249 | step_type=step_type_lookup[offline_data['step_type'][idx]], 250 | reward=offline_data['reward'][idx], 251 | pixel_observation=offline_data['observation'][idx], 252 | observation=offline_data['vae_latents'][idx], 253 | discount=offline_data['discount'][idx], 254 | action=offline_data['action'][idx], 255 | latent=np.zeros(256), 256 | imp_action=np.zeros(84*84*1), 257 | k_step = idx 258 | ) 259 | elif latent_style == "dino": 260 | return dmc.ExtendedTimeStep( 261 | step_type=step_type_lookup[offline_data['step_type'][idx]], 262 | reward=offline_data['reward'][idx], 263 | pixel_observation=offline_data['observation'][idx], 264 | observation=offline_data['dino_latents'][idx], 265 | discount=offline_data['discount'][idx], 266 | action=offline_data['action'][idx], 267 | latent=np.zeros(256), 268 | imp_action=np.zeros(84*84*1), 269 | k_step = idx 270 | ) 
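The loaders above (`load_offline_dataset_into_buffer` / `get_timestep_from_idx`) assume each offline episode is a flat HDF5 file of aligned per-step arrays: `step_type`, `reward`, `discount`, `action`, `observation` (raw pixels), and either `vae_latents` or `dino_latents` depending on `latent_style`. A minimal inspection helper, sketched below (the function and file name are hypothetical), makes that expected layout easy to verify before training:

```python
import h5py

def inspect_episode(path: str) -> None:
    """Print the per-step datasets an offline episode file is expected to contain."""
    with h5py.File(path, "r") as f:
        # Keys read by get_timestep_from_idx(): 'step_type', 'reward', 'discount',
        # 'action', 'observation', and 'vae_latents' or 'dino_latents'.
        for key in f.keys():
            print(key, f[key].shape, f[key].dtype)

# Hypothetical file name; real episodes live under the offline_dir passed to
# load_offline_dataset_into_buffer(), e.g. "../../cheetah_train/dino_latents/".
# inspect_episode("episode_0.hdf5")
```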
-------------------------------------------------------------------------------- /dino/vqkd_teacher/dino.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Zhiliang Peng 7 | # Based on DINO code bases 8 | # https://github.com/facebookresearch/dino 9 | # --------------------------------------------------------' 10 | 11 | import math 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from functools import partial, reduce 18 | from collections import OrderedDict 19 | 20 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 21 | 22 | import pdb 23 | 24 | # https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth 25 | 26 | def drop_path(x, drop_prob: float = 0., training: bool = False): 27 | if drop_prob == 0. or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 39 | """ 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 69 | super().__init__() 70 | self.num_heads = num_heads 71 | head_dim = dim // num_heads 72 | self.scale = qk_scale or head_dim ** -0.5 73 | 74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 82 | q, k, v = qkv[0], qkv[1], qkv[2] 83 | 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | return x, attn 92 | 93 | 94 | class Block(nn.Module): 95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 
drop=0., attn_drop=0., 96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 97 | super().__init__() 98 | self.norm1 = norm_layer(dim) 99 | self.attn = Attention( 100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 102 | self.norm2 = norm_layer(dim) 103 | mlp_hidden_dim = int(dim * mlp_ratio) 104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 105 | 106 | def forward(self, x, return_attention=False): 107 | y, attn = self.attn(self.norm1(x)) 108 | if return_attention: 109 | return attn 110 | x = x + self.drop_path(y) 111 | x = x + self.drop_path(self.mlp(self.norm2(x))) 112 | return x 113 | 114 | 115 | class PatchEmbed(nn.Module): 116 | """ Image to Patch Embedding 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 119 | super().__init__() 120 | num_patches = (img_size // patch_size) * (img_size // patch_size) 121 | self.img_size = img_size 122 | self.patch_size = patch_size 123 | self.num_patches = num_patches 124 | 125 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 126 | 127 | def forward(self, x): 128 | B, C, H, W = x.shape 129 | x = self.proj(x).flatten(2).transpose(1, 2) 130 | return x 131 | 132 | 133 | class VisionTransformer(nn.Module): 134 | """ Vision Transformer """ 135 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 136 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 137 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 138 | super().__init__() 139 | self.num_features = self.embed_dim = embed_dim 140 | 141 | self.patch_embed = PatchEmbed( 142 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 143 | num_patches = self.patch_embed.num_patches 144 | 145 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 146 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 147 | self.pos_drop = nn.Dropout(p=drop_rate) 148 | 149 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 150 | self.blocks = nn.ModuleList([ 151 | Block( 152 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 153 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 154 | for i in range(depth)]) 155 | self.norm = norm_layer(embed_dim) 156 | 157 | # Classifier head 158 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 159 | 160 | trunc_normal_(self.pos_embed, std=.02) 161 | trunc_normal_(self.cls_token, std=.02) 162 | self.apply(self._init_weights) 163 | 164 | if kwargs.get('pretrained', True): 165 | self.load_from_pretrained('https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth') 166 | if not kwargs.get('requires_grad', False): 167 | for param in self.parameters(): 168 | param.requires_grad = False 169 | 170 | def _init_weights(self, m): 171 | if isinstance(m, nn.Linear): 172 | trunc_normal_(m.weight, std=.02) 173 | if isinstance(m, nn.Linear) and m.bias is not None: 174 | nn.init.constant_(m.bias, 0) 175 | elif isinstance(m, nn.LayerNorm): 176 | nn.init.constant_(m.bias, 0) 177 | nn.init.constant_(m.weight, 1.0) 178 | 179 | def load_from_pretrained(self, ckpt_path): 180 | if 
ckpt_path.startswith('https'): 181 | sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location='cpu', check_hash=True) 182 | else: 183 | sd = torch.load(ckpt_path, map_location='cpu') 184 | 185 | missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) 186 | print(f"Load weight for dino model: {ckpt_path}") 187 | print(f"missing_keys: {missing_keys}") 188 | print(f"unexpected_keys: {unexpected_keys}") 189 | 190 | def interpolate_pos_encoding(self, x, w, h): 191 | npatch = x.shape[1] - 1 192 | N = self.pos_embed.shape[1] - 1 193 | if npatch == N and w == h: 194 | return self.pos_embed 195 | class_pos_embed = self.pos_embed[:, 0] 196 | patch_pos_embed = self.pos_embed[:, 1:] 197 | dim = x.shape[-1] 198 | w0 = w // self.patch_embed.patch_size 199 | h0 = h // self.patch_embed.patch_size 200 | # we add a small number to avoid floating point error in the interpolation 201 | # see discussion at https://github.com/facebookresearch/dino/issues/8 202 | w0, h0 = w0 + 0.1, h0 + 0.1 203 | patch_pos_embed = nn.functional.interpolate( 204 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 205 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 206 | mode='bicubic', 207 | ) 208 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 211 | 212 | def prepare_tokens(self, x): 213 | B, nc, w, h = x.shape 214 | x = self.patch_embed(x) # patch linear embedding 215 | 216 | # add the [CLS] token to the embed patch tokens 217 | cls_tokens = self.cls_token.expand(B, -1, -1) 218 | x = torch.cat((cls_tokens, x), dim=1) 219 | 220 | # add positional encoding to each token 221 | x = x + self.interpolate_pos_encoding(x, w, h) 222 | 223 | return self.pos_drop(x) 224 | 225 | def forward(self, x, return_patch_tokens=False, return_all_tokens=False): 226 | x = self.prepare_tokens(x) 227 | for blk in self.blocks: 228 | x = blk(x) 229 | x = self.norm(x) 230 | if return_all_tokens: 231 | return x 232 | elif return_patch_tokens: 233 | return x[:, 1:] 234 | else: 235 | return x[:, 0] 236 | 237 | def get_last_selfattention(self, x): 238 | x = self.prepare_tokens(x) 239 | for i, blk in enumerate(self.blocks): 240 | if i < len(self.blocks) - 1: 241 | x = blk(x) 242 | else: 243 | # return attention of the last block 244 | return blk(x, return_attention=True) 245 | 246 | def get_intermediate_layers(self, x, n=1): 247 | x = self.prepare_tokens(x) 248 | # we return the output tokens from the `n` last blocks 249 | output = [] 250 | for i, blk in enumerate(self.blocks): 251 | x = blk(x) 252 | if len(self.blocks) - i <= n: 253 | output.append(self.norm(x)) 254 | return output 255 | 256 | def forward_intermediate(self, x, layer_id=12): 257 | x = self.prepare_tokens(x) 258 | 259 | if isinstance(layer_id, list): 260 | output_list = [] 261 | for l, blk in enumerate(self.blocks): 262 | x = blk(x) 263 | if l in layer_id: 264 | output_list.append(x[:, 1:]) 265 | # output_list.append(self.norm(x)) 266 | return output_list 267 | elif isinstance(layer_id, int): 268 | for l, blk in enumerate(self.blocks): 269 | if l < layer_id: 270 | x = blk(x) 271 | elif l == layer_id: 272 | # pdb.set_trace() 273 | x = blk.norm1(x) 274 | else: 275 | break 276 | return x[:, 1:] 277 | 278 | 279 | def vit_tiny(patch_size=16, **kwargs): 280 | model = VisionTransformer( 281 | patch_size=patch_size, embed_dim=192, depth=12, 
num_heads=3, mlp_ratio=4, 282 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 283 | return model 284 | 285 | 286 | def vit_small(patch_size=16, **kwargs): 287 | model = VisionTransformer( 288 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 289 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 290 | return model 291 | 292 | 293 | def vit_base(patch_size=16, **kwargs): 294 | model = VisionTransformer( 295 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 296 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 297 | return model 298 | 299 | def get_dino_vit_base(): 300 | return vit_base(pretrained=True, requires_grad=False) 301 | 302 | 303 | class DINOHead(nn.Module): 304 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 305 | super().__init__() 306 | nlayers = max(nlayers, 1) 307 | if nlayers == 1: 308 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 309 | else: 310 | layers = [nn.Linear(in_dim, hidden_dim)] 311 | if use_bn: 312 | layers.append(nn.BatchNorm1d(hidden_dim)) 313 | layers.append(nn.GELU()) 314 | for _ in range(nlayers - 2): 315 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 316 | if use_bn: 317 | layers.append(nn.BatchNorm1d(hidden_dim)) 318 | layers.append(nn.GELU()) 319 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 320 | self.mlp = nn.Sequential(*layers) 321 | self.apply(self._init_weights) 322 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 323 | self.last_layer.weight_g.data.fill_(1) 324 | if norm_last_layer: 325 | self.last_layer.weight_g.requires_grad = False 326 | 327 | def _init_weights(self, m): 328 | if isinstance(m, nn.Linear): 329 | trunc_normal_(m.weight, std=.02) 330 | if isinstance(m, nn.Linear) and m.bias is not None: 331 | nn.init.constant_(m.bias, 0) 332 | 333 | def forward(self, x): 334 | x = self.mlp(x) 335 | x = nn.functional.normalize(x, dim=-1, p=2) 336 | x = self.last_layer(x) 337 | return x -------------------------------------------------------------------------------- /utils/drqv2.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2LMHeadModel, GPT2Config 2 | 3 | import hydra 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | 10 | import utils 11 | from pathlib import Path 12 | 13 | from numpy_replay_buffer import EfficientReplayBuffer 14 | from utils import load_offline_dataset_into_buffer 15 | 16 | from dm_env import specs 17 | import dmc 18 | 19 | 20 | def k3s1p0(x): 21 | return x - 2 22 | 23 | 24 | def k4s2p0(x): 25 | assert ((x % 2) == 0) 26 | return (x // 2) - 1 27 | 28 | 29 | class RandomShiftsAug(nn.Module): 30 | def __init__(self, pad=4): 31 | super().__init__() 32 | self.pad = pad 33 | 34 | def forward(self, x): 35 | # x = T.Resize((x.shape[0], x.shape[1], 64, 64)) 36 | n, c, h, w = x.size() 37 | assert h == w 38 | padding = tuple([self.pad] * 4) 39 | x = F.pad(x, padding, 'replicate') 40 | eps = 1.0 / (h + 2 * self.pad) 41 | arange = torch.linspace(-1.0 + eps, 42 | 1.0 - eps, 43 | h + 2 * self.pad, 44 | device=x.device, 45 | dtype=x.dtype)[:h] 46 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 47 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 48 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 49 | 50 | shift = 
torch.randint(0, 51 | 2 * self.pad + 1, 52 | size=(n, 1, 1, 2), 53 | device=x.device, 54 | dtype=x.dtype) 55 | shift *= 2.0 / (h + 2 * self.pad) 56 | 57 | grid = base_grid + shift 58 | return F.grid_sample(x, 59 | grid, 60 | padding_mode='zeros', 61 | align_corners=False) 62 | 63 | 64 | class NoShiftAug(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | 68 | def forward(self, x): 69 | return x 70 | 71 | class LayerNorm(nn.Module): 72 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 73 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 74 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 75 | with shape (batch_size, channels, height, width). 76 | """ 77 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 78 | super().__init__() 79 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 80 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 81 | self.eps = eps 82 | self.data_format = data_format 83 | if self.data_format not in ["channels_last", "channels_first"]: 84 | raise NotImplementedError 85 | self.normalized_shape = (normalized_shape, ) 86 | 87 | def forward(self, x): 88 | if self.data_format == "channels_last": 89 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 90 | elif self.data_format == "channels_first": 91 | u = x.mean(1, keepdim=True) 92 | s = (x - u).pow(2).mean(1, keepdim=True) 93 | x = (x - u) / torch.sqrt(s + self.eps) 94 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 95 | return x 96 | 97 | class Encoder(nn.Module): 98 | def __init__(self, obs_shape, feature_dim): 99 | super().__init__() 100 | 101 | assert len(obs_shape) == 1 102 | 103 | action_dim = 6 104 | self.repr_dim = 75 105 | 106 | self.linear = nn.Sequential(nn.Linear(self.repr_dim, feature_dim), nn.BatchNorm1d(feature_dim), 107 | nn.ReLU()) 108 | 109 | self.apply(utils.weight_init) 110 | 111 | def forward(self, obs): 112 | h = self.linear(obs) 113 | return h 114 | 115 | class Actor(nn.Module): 116 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 117 | super().__init__() 118 | 119 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 120 | nn.LayerNorm(feature_dim), nn.Tanh()) 121 | 122 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 123 | nn.ReLU(inplace=True), 124 | nn.Linear(hidden_dim, hidden_dim), 125 | nn.ReLU(inplace=True), 126 | nn.Linear(hidden_dim, action_shape[0])) 127 | 128 | self.apply(utils.weight_init) 129 | 130 | def forward(self, obs, std=None): 131 | # h = self.trunk(obs) 132 | 133 | mu = self.policy(obs) 134 | mu = torch.tanh(mu) 135 | if std is None: 136 | return mu 137 | std = torch.ones_like(mu) * std 138 | 139 | dist = utils.TruncatedNormal(mu, std) 140 | return dist 141 | 142 | 143 | class Critic(nn.Module): 144 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 145 | super().__init__() 146 | 147 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 148 | nn.LayerNorm(feature_dim), nn.Tanh()) 149 | 150 | self.Q1 = nn.Sequential( 151 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 152 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 153 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 154 | 155 | self.Q2 = nn.Sequential( 156 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 157 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 158 | nn.ReLU(inplace=True), 
nn.Linear(hidden_dim, 1)) 159 | 160 | self.apply(utils.weight_init) 161 | 162 | def forward(self, obs, action): 163 | # h = self.trunk(obs) 164 | h_action = torch.cat([obs, action], dim=-1) 165 | q1 = self.Q1(h_action) 166 | q2 = self.Q2(h_action) 167 | 168 | return q1, q2 169 | 170 | 171 | class DrQV2Agent: 172 | def __init__(self, obs_shape=(75,), action_shape=(6,), device='cuda', lr=3e-4, feature_dim=64, 173 | hidden_dim=256, critic_target_tau=0.005, num_expl_steps=2000, 174 | update_every_steps=2, stddev_schedule='linear(1.0,0.1,100000)', 175 | stddev_clip=0.3, use_tb=False, 176 | offline=True, bc_weight=2.5, augmentation=RandomShiftsAug(pad=4), 177 | use_bc=True): 178 | self.device = device 179 | self.critic_target_tau = critic_target_tau 180 | self.update_every_steps = update_every_steps 181 | self.use_tb = use_tb 182 | self.num_expl_steps = num_expl_steps 183 | self.stddev_schedule = stddev_schedule 184 | self.stddev_clip = stddev_clip 185 | self.offline = offline 186 | self.bc_weight = bc_weight 187 | self.use_bc = use_bc 188 | 189 | # replay buffer 190 | self.train_env = dmc.make("offline_cheetah_run_expert", 3, 2, 0, None) 191 | data_specs = (self.train_env.observation_spec(), 192 | self.train_env.action_spec(), 193 | specs.Array((1,), np.float32, 'reward'), 194 | specs.Array((1,), np.float32, 'discount')) 195 | 196 | self.discount = 0.99 197 | self.replay_buffer = EfficientReplayBuffer(25000, 32, 1, self.discount, 3, False, data_specs) 198 | 199 | offline_dir = "/home/manant/scratch/expert/vq_latents/" 200 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 25000) 201 | 202 | # gpt model 203 | configuration = GPT2Config(vocab_size=1024) 204 | self.gpt = GPT2LMHeadModel(configuration).to('cuda') 205 | self.gpt.load_state_dict(torch.load("/home/manant/scratch/pixel_gamma/checkpoints/vq_gpt_gamma_model_4000.pth")) 206 | 207 | # actor critic models 208 | self.encoder = Encoder(obs_shape, feature_dim).to(device) 209 | self.actor = Actor(feature_dim, action_shape, feature_dim, 210 | hidden_dim).to(device) 211 | 212 | self.critic = Critic(feature_dim, action_shape, feature_dim, 213 | hidden_dim).to(device) 214 | self.critic_target = Critic(feature_dim, action_shape, 215 | feature_dim, hidden_dim).to(device) 216 | self.critic_target.load_state_dict(self.critic.state_dict()) 217 | 218 | # optimizers 219 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr) 220 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 221 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 222 | 223 | # data augmentation 224 | self.aug = augmentation 225 | 226 | self.train() 227 | self.critic_target.train() 228 | 229 | def train(self, training=True): 230 | self.training = training 231 | self.encoder.train(training) 232 | self.actor.train(training) 233 | self.critic.train(training) 234 | 235 | def act(self, obs, latent, step, eval_mode): 236 | obs = torch.as_tensor(obs, device=self.device) 237 | obs = self.encoder(obs.unsqueeze(0)) 238 | stddev = utils.schedule(self.stddev_schedule, step) 239 | dist = self.actor(obs, stddev) 240 | if eval_mode: 241 | action = dist.mean 242 | else: 243 | action = dist.sample(clip=None) 244 | if step < self.num_expl_steps: 245 | action.uniform_(-1.0, 1.0) 246 | return action.cpu().numpy()[0], None 247 | 248 | def update_critic(self, obs, action, reward, discount, next_obs, step): 249 | metrics = dict() 250 | 251 | with torch.no_grad(): 252 | stddev = utils.schedule(self.stddev_schedule, step) 253 | 
dist = self.actor(next_obs, stddev) 254 | next_action = dist.sample(clip=self.stddev_clip) 255 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 256 | target_V = torch.min(target_Q1, target_Q2) 257 | target_Q = reward.float() + (discount * target_V) 258 | 259 | Q1, Q2 = self.critic(obs, action) 260 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 261 | 262 | if self.use_tb: 263 | metrics['critic_target_q'] = target_Q.mean().item() 264 | metrics['critic_q1'] = Q1.mean().item() 265 | metrics['critic_q2'] = Q2.mean().item() 266 | metrics['critic_loss'] = critic_loss.item() 267 | 268 | # optimize encoder and critic 269 | self.encoder_opt.zero_grad(set_to_none=True) 270 | self.critic_opt.zero_grad(set_to_none=True) 271 | critic_loss.backward() 272 | self.critic_opt.step() 273 | self.encoder_opt.step() 274 | 275 | if step % 5000 == 0: 276 | print("Critic Loss", critic_loss) 277 | 278 | return metrics 279 | 280 | def update_actor(self, obs, step, behavioural_action=None): 281 | metrics = dict() 282 | 283 | stddev = utils.schedule(self.stddev_schedule, step) 284 | dist = self.actor(obs, stddev) 285 | action = dist.sample(clip=self.stddev_clip) 286 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 287 | Q1, Q2 = self.critic(obs, action) 288 | Q = torch.min(Q1, Q2) 289 | 290 | actor_policy_improvement_loss = -Q.mean() 291 | 292 | actor_loss = actor_policy_improvement_loss 293 | 294 | # offline BC Loss 295 | if self.offline: 296 | actor_bc_loss = F.mse_loss(action, behavioural_action) 297 | # Eq. 5 of arXiv:2106.06860 298 | lam = self.bc_weight / Q.detach().abs().mean() 299 | if self.use_bc: 300 | actor_loss = actor_policy_improvement_loss * lam + actor_bc_loss 301 | else: 302 | actor_loss = actor_policy_improvement_loss #* lam 303 | 304 | # optimize actor 305 | self.actor_opt.zero_grad(set_to_none=True) 306 | actor_loss.backward() 307 | self.actor_opt.step() 308 | 309 | if self.use_tb: 310 | metrics['actor_loss'] = actor_policy_improvement_loss.item() 311 | metrics['actor_logprob'] = log_prob.mean().item() 312 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item() 313 | if self.offline: 314 | metrics['actor_bc_loss'] = actor_bc_loss.item() 315 | 316 | if step % 5000 == 0: 317 | print("Actor Loss", actor_loss) 318 | 319 | return metrics 320 | 321 | def update(self, step): 322 | metrics = dict() 323 | 324 | if step % self.update_every_steps != 0: 325 | return metrics 326 | 327 | batch, indices = next(self.replay_buffer) 328 | obs, action, reward, discount, next_obs, _, _, _, _ = utils.to_torch( 329 | batch, self.device) 330 | 331 | # augment 332 | obs = obs.float() #self.aug(obs.float()) 333 | next_obs = next_obs.float() #self.aug(next_obs.float()) 334 | # encode 335 | obs = self.encoder(obs) 336 | with torch.no_grad(): 337 | next_obs = self.encoder(next_obs) 338 | 339 | if self.use_tb: 340 | metrics['batch_reward'] = reward.mean().item() 341 | 342 | # update critic 343 | metrics.update( 344 | self.update_critic(obs, action, reward, discount, next_obs, step)) 345 | 346 | # update actor 347 | if self.offline: 348 | metrics.update(self.update_actor(obs.detach(), step, action.detach())) 349 | else: 350 | metrics.update(self.update_actor(obs.detach(), step)) 351 | 352 | # update critic target 353 | utils.soft_update_params(self.critic, self.critic_target, 354 | self.critic_target_tau) 355 | 356 | if step % 20000 == 0: 357 | torch.save(self.actor.state_dict(), "/home/manant/scratch/pixel_gamma/checkpoints/drq_actor_{}.pth".format(step)) 358 | 
torch.save(self.encoder.state_dict(), "/home/manant/scratch/pixel_gamma/checkpoints/drq_encoder_{}.pth".format(step)) 359 | 360 | return metrics 361 | 362 | if __name__ == "__main__": 363 | 364 | agent = DrQV2Agent() 365 | 366 | for step in range(2000000): 367 | agent.update(step) 368 | -------------------------------------------------------------------------------- /utils/numpy_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import abc 3 | 4 | 5 | class AbstractReplayBuffer(abc.ABC): 6 | @abc.abstractmethod 7 | def add(self, time_step): 8 | pass 9 | 10 | @abc.abstractmethod 11 | def __next__(self, ): 12 | pass 13 | 14 | @abc.abstractmethod 15 | def __len__(self, ): 16 | pass 17 | 18 | 19 | class EfficientReplayBuffer(AbstractReplayBuffer): 20 | '''Fast + efficient replay buffer implementation in numpy.''' 21 | 22 | def __init__(self, buffer_size, batch_size, nstep, discount, frame_stack, spr_style_buffer, 23 | data_specs=None, pixel_samples=False, sarsa=False): 24 | self.buffer_size = buffer_size 25 | self.data_dict = {} 26 | self.index = -1 27 | self.traj_index = 0 28 | self.frame_stack = frame_stack 29 | self._recorded_frames = frame_stack + 1 30 | self.batch_size = batch_size 31 | self.nstep = nstep 32 | self.discount = discount 33 | self.full = False 34 | self.discount_vec = np.power(discount, np.arange(nstep)) # n_step - first dim should broadcast 35 | self.next_dis = discount ** nstep 36 | self.sarsa = sarsa 37 | self.latent_shape = 256 #50 38 | self.imp_act_shape = 84 * 84 * 1 39 | self.spr_style_buffer = spr_style_buffer 40 | self.pixel_samples = pixel_samples 41 | 42 | def _initial_setup(self, time_step): 43 | self.index = 0 44 | self.obs_shape = list(time_step.observation.shape) 45 | self.pixel_obs_shape = list(time_step.pixel_observation.shape) 46 | self.ims_channels = self.obs_shape[0] // self.frame_stack 47 | self.pixel_ims_channels = self.pixel_obs_shape[0] // self.frame_stack 48 | self.act_shape = time_step.action.shape 49 | 50 | self.obs = np.zeros([self.buffer_size, self.ims_channels, *self.obs_shape[1:]], dtype=np.int32) 51 | self.pixel_obs = np.zeros([self.buffer_size, self.pixel_ims_channels, *self.pixel_obs_shape[1:]], dtype=np.uint8) 52 | self.act = np.zeros([self.buffer_size, *self.act_shape], dtype=np.float32) 53 | self.latent = np.zeros([self.buffer_size, self.latent_shape], dtype=np.float32) 54 | self.imp_act = np.ones([self.buffer_size, self.imp_act_shape], dtype=np.float32) 55 | self.rew = np.zeros([self.buffer_size], dtype=np.float32) 56 | self.dis = np.zeros([self.buffer_size], dtype=np.float32) 57 | self.valid = np.zeros([self.buffer_size], dtype=np.bool_) 58 | self.k_step = np.zeros([self.buffer_size], dtype=np.float32) 59 | self.obs_k = np.zeros([self.buffer_size, self.ims_channels, *self.obs_shape[1:]], dtype=np.int32) 60 | 61 | def add_data_point(self, time_step): 62 | first = time_step.first() 63 | latest_obs = time_step.observation[-self.ims_channels:].astype(np.int32) 64 | latest_pixel_obs = time_step.pixel_observation[-self.pixel_ims_channels:].astype(np.uint8) 65 | if first: 66 | end_index = self.index + self.frame_stack 67 | end_invalid = end_index + self.frame_stack + 1 68 | if end_invalid > self.buffer_size: 69 | if end_index > self.buffer_size: 70 | end_index = end_index % self.buffer_size 71 | self.obs[self.index:self.buffer_size] = latest_obs 72 | self.obs[0:end_index] = latest_obs 73 | self.pixel_obs[self.index:self.buffer_size] = latest_pixel_obs 74 | 
self.pixel_obs[0:end_index] = latest_pixel_obs 75 | self.full = True 76 | else: 77 | self.obs[self.index:end_index] = latest_obs 78 | self.pixel_obs[self.index:end_index] = latest_pixel_obs 79 | end_invalid = end_invalid % self.buffer_size 80 | self.valid[self.index:self.buffer_size] = False 81 | self.valid[0:end_invalid] = False 82 | else: 83 | self.obs[self.index:end_index] = latest_obs 84 | self.pixel_obs[self.index:end_index] = latest_pixel_obs 85 | self.valid[self.index:end_invalid] = False 86 | self.index = end_index 87 | self.traj_index = 1 88 | else: 89 | np.copyto(self.obs[self.index], latest_obs) # Check most recent image 90 | np.copyto(self.pixel_obs[self.index], latest_pixel_obs) 91 | np.copyto(self.act[self.index], time_step.action) 92 | np.copyto(self.latent[self.index], time_step.latent) 93 | np.copyto(self.imp_act[self.index], time_step.imp_action) 94 | self.rew[self.index] = time_step.reward 95 | self.dis[self.index] = time_step.discount 96 | self.valid[(self.index + self.frame_stack) % self.buffer_size] = False 97 | self.k_step[self.index] = time_step.k_step 98 | if self.traj_index >= self.nstep: 99 | self.valid[(self.index - self.nstep + 1) % self.buffer_size] = True 100 | self.index += 1 101 | self.traj_index += 1 102 | if self.index == self.buffer_size: 103 | self.index = 0 104 | self.full = True 105 | 106 | def add(self, time_step): 107 | if self.index == -1: 108 | self._initial_setup(time_step) 109 | self.add_data_point(time_step) 110 | 111 | def get_stats(self, ): 112 | print("obs shape", self.obs.shape) 113 | mean = np.mean(self.obs, axis=(0, 2, 3)) 114 | std = np.std(self.obs, axis=(0, 2, 3)) 115 | return mean, std 116 | 117 | def __next__(self): 118 | indices = np.random.choice(self.valid.nonzero()[0] - 8, size=self.batch_size) 119 | # if spr: 120 | # return self.gather_spr_indices(indices), indices 121 | # else: 122 | return self.gather_nstep_indices(indices), indices 123 | 124 | def sample_spr(self, indices=None, jumps=8): 125 | if indices is None: 126 | indices = np.random.choice(self.valid.nonzero()[0] - jumps, size=self.batch_size) 127 | return self.gather_spr_indices(indices, jumps), indices 128 | 129 | def replace_latent(self, indices, latents): 130 | self.latent[indices] = latents 131 | 132 | def replace_action(self, indices, imp_actions): 133 | self.imp_act[indices] = imp_actions 134 | 135 | def sample_previous_latent(self, indices): 136 | return self.latent[indices - 1] 137 | 138 | def gather_nstep_indices(self, indices): 139 | n_samples = indices.shape[0] 140 | all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + self.nstep) 141 | for i in range(n_samples)], axis=0) % self.buffer_size 142 | gather_ranges = all_gather_ranges[:, self.frame_stack:] # bs x nstep 143 | obs_gather_ranges = all_gather_ranges[:, :self.frame_stack] 144 | nobs_gather_ranges = all_gather_ranges[:, -self.frame_stack:] 145 | 146 | all_rewards = self.rew[gather_ranges] #/ np.max(self.rew) 147 | 148 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 149 | rew = np.sum(all_rewards * self.discount_vec, axis=1, keepdims=True) / gather_ranges.shape[1] 150 | 151 | if self.pixel_samples: 152 | # In case we require pixel observation instead of the saved latents 153 | obs = np.reshape(self.pixel_obs[obs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 154 | nobs = np.reshape(self.pixel_obs[nobs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 155 | else: 156 | obs = np.reshape(self.obs[obs_gather_ranges], 
[n_samples, *self.obs_shape]) 157 | nobs = np.reshape(self.obs[nobs_gather_ranges], [n_samples, *self.obs_shape]) 158 | 159 | act = self.act[indices] 160 | latent = self.latent[indices] 161 | imp_act = self.imp_act[indices] 162 | dis = np.expand_dims(self.next_dis * self.dis[nobs_gather_ranges[:, -1]], axis=-1) 163 | 164 | k_step = self.k_step[indices].astype(int) 165 | k_step_rand = [] 166 | for each in k_step: 167 | if each > 1: 168 | k_step_rand.append(np.random.randint(low=1, high=each)) 169 | else: 170 | k_step_rand.append(1) 171 | # k_step_rand = [np.random.randint(low=1, high=each) for each in k_step] 172 | k_all_gather_ranges = np.stack([np.arange(indices[i] + k_step_rand[i] - self.frame_stack, indices[i] + k_step_rand[i] + self.nstep) 173 | for i in range(n_samples)], axis=0) % self.buffer_size 174 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack] 175 | if self.pixel_samples: 176 | obs_k = np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 177 | else: 178 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape]) 179 | 180 | # k_all_gather_ranges = np.stack([np.arange(indices[i], indices[i] + k_step_rand[i]) 181 | # for i in range(n_samples)], axis=0) % self.buffer_size 182 | act_k = self.act[indices + k_step_rand] 183 | 184 | if self.sarsa: 185 | nact = self.act[indices + self.nstep] 186 | return (obs, act, rew, dis, nobs, nact, latent, imp_act) 187 | 188 | return (obs, act, rew, dis, nobs, latent, act_k, k_step_rand, obs_k) 189 | 190 | def gather_spr_indices(self, indices, jumps=8): 191 | n_samples = indices.shape[0] 192 | all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + self.nstep) 193 | for i in range(n_samples)], axis=0) % self.buffer_size 194 | gather_ranges = all_gather_ranges[:, self.frame_stack:] # bs x nstep 195 | obs_gather_ranges = all_gather_ranges[:, :self.frame_stack] 196 | nobs_gather_ranges = all_gather_ranges[:, -self.frame_stack:] 197 | 198 | all_rewards = self.rew[gather_ranges] #/ np.max(self.rew) 199 | 200 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 201 | rew = np.sum(all_rewards * self.discount_vec, axis=1, keepdims=True) / gather_ranges.shape[1] 202 | 203 | if self.pixel_samples: 204 | # In case we require pixel observation instead of the saved latents 205 | obs = np.reshape(self.pixel_obs[obs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 206 | nobs = np.reshape(self.pixel_obs[nobs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 207 | else: 208 | obs = np.reshape(self.obs[obs_gather_ranges], [n_samples, *self.obs_shape]) 209 | nobs = np.reshape(self.obs[nobs_gather_ranges], [n_samples, *self.obs_shape]) 210 | 211 | act = self.act[indices] 212 | latent = self.latent[indices] 213 | imp_act = self.imp_act[indices] 214 | dis = np.expand_dims(self.next_dis * self.dis[nobs_gather_ranges[:, -1]], axis=-1) 215 | 216 | k_all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + jumps) 217 | for i in range(n_samples)], axis=0) % self.buffer_size 218 | all_obs = [] 219 | all_pixel_obs = [] 220 | all_act = [] 221 | all_rew = [] 222 | for i in range(jumps+1): 223 | if i == 0: 224 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack] 225 | else: 226 | k_obs_gather_ranges = k_all_gather_ranges[:, self.frame_stack + i - self.frame_stack:self.frame_stack + i] 227 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape]) 228 | pixel_obs_k = 
np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 229 | act_k = self.act[indices + i] 230 | all_obs.append(obs_k) 231 | all_pixel_obs.append(pixel_obs_k) 232 | all_act.append(act_k) 233 | 234 | # all_frames_rewards = self.rew[k_all_gather_ranges[:, self.frame_stack-2:]] / np.max(self.rew) 235 | all_frames_rewards = self.rew[k_all_gather_ranges[:, self.frame_stack:]] #/ np.max(self.rew) 236 | discount_vec = np.power(self.discount, np.arange(all_frames_rewards.shape[1])) 237 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 238 | values = np.sum(all_frames_rewards * discount_vec, axis=1, keepdims=True) 239 | 240 | all_obs = np.stack(all_obs) 241 | all_pixel_obs = np.stack(all_pixel_obs) 242 | all_act = np.stack(all_act) 243 | 244 | k_step = self.k_step[indices].astype(int) 245 | k_step_rand = [] 246 | for each in k_step: 247 | if each > 1: 248 | k_step_rand.append(np.random.randint(low=1, high=each)) 249 | else: 250 | k_step_rand.append(1) 251 | # k_step_rand = [np.random.randint(low=1, high=each) for each in k_step] 252 | k_all_gather_ranges = np.stack([np.arange(indices[i] + k_step_rand[i] - self.frame_stack, indices[i] + k_step_rand[i] + self.nstep) 253 | for i in range(n_samples)], axis=0) % self.buffer_size 254 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack] 255 | if self.pixel_samples: 256 | obs_k = np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape]) 257 | else: 258 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape]) 259 | 260 | if self.sarsa: 261 | nact = self.act[indices + self.nstep] 262 | return (obs, act, rew, dis, nobs, nact, latent, imp_act) 263 | 264 | return (obs, act, rew, dis, obs_k, all_obs, all_pixel_obs, all_act, values) 265 | # return (obs, act, rew, dis, nobs, all_obs, all_pixel_obs, all_act, values) 266 | 267 | def __len__(self): 268 | if self.full: 269 | return self.buffer_size 270 | else: 271 | return self.index 272 | 273 | def get_train_and_val_indices(self, validation_percentage): 274 | all_indices = self.valid.nonzero()[0] 275 | num_indices = all_indices.shape[0] 276 | num_val = int(num_indices * validation_percentage) 277 | np.random.shuffle(all_indices) 278 | val_indices, train_indices = np.split(all_indices, 279 | [num_val]) 280 | return train_indices, val_indices 281 | 282 | def get_obs_act_batch(self, indices): 283 | n_samples = indices.shape[0] 284 | obs_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i]) 285 | for i in range(n_samples)], axis=0) % self.buffer_size 286 | obs = np.reshape(self.obs[obs_gather_ranges], [n_samples, *self.obs_shape]) 287 | act = self.act[indices] 288 | return obs, act -------------------------------------------------------------------------------- /vae/train_vq_vae_voc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../utils/') 3 | 4 | from transformers.optimization import get_scheduler 5 | from transformers import GPT2LMHeadModel, GPT2Config 6 | 7 | from taming.models.vqgan import VQModel 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from pathlib import Path 15 | from matplotlib import pyplot as plt 16 | 17 | import random 18 | import utils 19 | import os 20 | import math 21 | 22 | from torchvision.utils import make_grid 23 | from drqv2 import RandomShiftsAug 24 | 25 | from numpy_replay_buffer 
import EfficientReplayBuffer 26 | from utils import load_offline_dataset_into_buffer 27 | 28 | from dm_env import specs 29 | import dmc 30 | 31 | import copy 32 | import wandb 33 | from omegaconf import OmegaConf 34 | 35 | from torch.distributions.categorical import Categorical 36 | 37 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average 38 | 39 | ###################### 40 | batch_size = 32 41 | 42 | taming_path = "../../" # taming models path 43 | offline_dir = "../../walker_train/vq_latents/" # dataset dir path 44 | vq_model_path = "../../walker_vq_model_step.pth" # vqvae model path 45 | save_dir_path = "./" # voc's gpt model save path 46 | 47 | ###################### 48 | 49 | class TFAgent: 50 | def __init__(self, discount=0.8, augmentation=RandomShiftsAug(pad=4)): 51 | 52 | wandb.init(project="Video Occupancy Models", 53 | id="voc_vqvae_gamma_{}".format(discount), 54 | entity=None, dir=os.getcwd()) 55 | 56 | self.train_env = dmc.make("offline_walker_walk_expert", 3, 2, 0, None) 57 | data_specs = (self.train_env.observation_spec(), 58 | self.train_env.action_spec(), 59 | specs.Array((1,), np.float32, 'reward'), 60 | specs.Array((1,), np.float32, 'discount')) 61 | 62 | self.discount = discount 63 | self.batch_size = batch_size 64 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps 65 | self.codebook_size = 1024 66 | 67 | ######### train on VQ latents ######### 68 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs) 69 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000, latent_style="vae") 70 | ########################## 71 | 72 | ######### vq-vae setup ######### 73 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml") 74 | config = OmegaConf.load(config_path) 75 | self.model = VQModel(**config.model.params) # Don't put the entire model on device, we only need the decoder 76 | self.quantizer = self.model.quantize.to('cuda') 77 | self.decoder = nn.Sequential(self.model.post_quant_conv, self.model.decoder).to('cuda') 78 | 79 | self.from_imagenet = False 80 | if self.from_imagenet: 81 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt") 82 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"] 83 | missing, unexpected = self.model.load_state_dict(sd, strict=False) 84 | else: 85 | print("Loading fine-tuned VQ model...") 86 | self.model.load_state_dict(torch.load(vq_model_path)) 87 | ########################## 88 | 89 | ######### gpt setup ######### 90 | configuration = GPT2Config(vocab_size=1024, n_layer=4, n_head=8, n_embed=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_prdrop=0.2) 91 | self.gpt = GPT2LMHeadModel(configuration).to('cuda') 92 | self.gpt_target = copy.deepcopy(self.gpt) 93 | self.gpt_target.generation_config.output_scores = True 94 | self.target_ema_updater = EMA(0.9) 95 | ########################## 96 | 97 | ######### optimizer setup ######### 98 | self.reward_predictor = Predictor(256 * 25 * 3).to('cuda') 99 | self.gpt_optimizer = torch.optim.AdamW(self.get_grouped_params(), lr=3e-4) 100 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()), lr=3e-4) 101 | 102 | num_training_steps = 100000 103 | self.warmup_ratio = 0.05 104 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio) 105 | self.lr_scheduler = get_scheduler( 106 | "cosine", 107 | optimizer=self.gpt_optimizer, 108 | num_warmup_steps=warmup_steps, 109 | 
num_training_steps=num_training_steps, 110 | ) 111 | ########################## 112 | 113 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda') 114 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda') 115 | 116 | self.device = 'cuda' 117 | self.aug = augmentation 118 | 119 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000] 120 | self.train() 121 | 122 | def get_grouped_params(self): 123 | decay_parameters = get_parameter_names(self.gpt, [nn.LayerNorm]) 124 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 125 | optimizer_grouped_parameters = [ 126 | { 127 | "params": [ 128 | p for n, p in self.gpt.named_parameters() if (n in decay_parameters and p.requires_grad) 129 | ], 130 | "weight_decay": 0.1, 131 | }, 132 | { 133 | "params": [ 134 | p for n, p in self.gpt.named_parameters() if (n not in decay_parameters and p.requires_grad) 135 | ], 136 | "weight_decay": 0.0, 137 | }, 138 | ] 139 | return optimizer_grouped_parameters 140 | 141 | def train(self, training=True): 142 | self.training = training 143 | 144 | def update(self, step=0): 145 | metrics = dict() 146 | 147 | batch, indices = next(self.replay_buffer) 148 | obs, action, reward, discount, next_obs, latent, _, _, obs_k = utils.to_torch(batch, self.device) 149 | 150 | # collect discrete vq indices 151 | y_context = obs.long() 152 | y_target = obs_k.long() 153 | obs_shape = obs.shape 154 | 155 | quant_target = self.quantizer.get_codebook_entry(y_target, None) 156 | quant_target = quant_target.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 157 | 158 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1)) 159 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean() 160 | 161 | # generate target 162 | with torch.no_grad(): 163 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100) 164 | p_t = p_t[:, -y_target.shape[-1]:] 165 | 166 | # gamma sampling 167 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device) 168 | prob = torch.bernoulli(gamma) 169 | p_target = torch.zeros_like(y_target) 170 | 171 | # with prob 1-gamma, sample from next state 172 | p_c_idx = torch.nonzero(1 - prob) 173 | p_target[p_c_idx] = y_target[p_c_idx] 174 | 175 | # with prob gamma, sample from bootstrapped model 176 | p_t_idx = torch.nonzero(prob) 177 | p_target[p_t_idx] = p_t[p_t_idx] 178 | 179 | 180 | # gpt predictions 181 | inp = torch.cat([y_context, p_target], dim=1) 182 | outputs = self.gpt(inp, labels=inp) 183 | gpt_loss = outputs.loss 184 | 185 | loss = gpt_loss + reward_loss 186 | 187 | loss.backward() 188 | 189 | # grad accumulate 190 | if step % 2 == 0: 191 | self.optimizer.step() 192 | self.gpt_optimizer.step() 193 | 194 | self.optimizer.zero_grad() 195 | self.gpt_optimizer.zero_grad() 196 | 197 | self.lr_scheduler.step() 198 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt) 199 | 200 | # visualize predictions 201 | if step % 100 == 0 and step != 0: 202 | with torch.no_grad(): 203 | # sample a batch of traj and corresponding values 204 | batch, indices = self.replay_buffer.sample_spr(jumps=self.nsteps) 205 | _, _, _, _, _, all_obs, all_pixel_obs, _, values = utils.to_torch(batch, self.device) 206 | 207 | # preprocess first obs from traj 208 | obs = all_obs[0] 209 | obs_shape = obs.shape 210 | pixel_obs_shape = F.interpolate(all_pixel_obs[0], size=80).shape 211 | 212 | y_context = obs.long() 213 | 214 | # sample 
target predictions 215 | p_t = self.gpt.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, pad_token_id=-100)[:, -y_context.shape[-1]:] 216 | 217 | # reconstruct sampled prediction 218 | quant = self.quantizer.get_codebook_entry(p_t, None) 219 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 220 | 221 | p_t_pixel_recon = self.decoder(quant) 222 | 223 | viz_imgs = [] 224 | viz_imgs.append(p_t_pixel_recon) 225 | for i in range(all_pixel_obs.shape[0]): 226 | obs = all_pixel_obs[i] 227 | obs = F.interpolate(obs, size=80) 228 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)]) 229 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((pixel_obs_shape[0] * 3, 3, *pixel_obs_shape[2:])) #torch.einsum('nhwc->nchw', obs) 230 | viz_imgs.append(obs) 231 | 232 | value_loss = self.get_value_estimates(y_context, values, obs_shape) 233 | wandb.log({"value loss": value_loss}, step=step) 234 | density_value_loss = self.get_density_value_estimates(y_context, all_obs, obs_shape) 235 | wandb.log({"density value loss": density_value_loss}, step=step) 236 | 237 | viz_imgs = torch.stack(viz_imgs)[:, :8] 238 | t, n, c, h, w = viz_imgs.shape 239 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs) 240 | viz_imgs = viz_imgs.reshape(t*n, c, h, w) 241 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True) 242 | 243 | img = wandb.Image(viz_img) 244 | wandb.log({f"Gamma Pred": img}, step=step) 245 | 246 | wandb.log({"reward loss": reward_loss}, step=step) 247 | wandb.log({"gpt loss": gpt_loss}, step=step) 248 | 249 | # save gpt model 250 | if step in self.saving_iter: 251 | print("saving gpt weights...") 252 | # self.save_gpt_weights(step) 253 | 254 | return metrics 255 | 256 | def save_gpt_weights(self, step): 257 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "vqvae_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step))) 258 | 259 | def get_value_estimates(self, y_context, values, obs_shape): 260 | # Take a state, get samples from the gamma distribution, 261 | # Run the reward predictor through these to get value estimates 262 | # Get ground truth value estimates by simply taking discounted sum of rewards 263 | # Compare these for different states 264 | 265 | num_gamma_samples = 100 266 | values_pred = [] 267 | 268 | for i in range(num_gamma_samples): 269 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100) 270 | p_t = outputs.sequences[:, -y_context.shape[-1]:] 271 | 272 | quant = self.quantizer.get_codebook_entry(p_t, None) 273 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 274 | 275 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1)) 276 | 277 | values_pred = torch.stack(values_pred).sum(0) / (100 * (1 - self.discount)) 278 | 279 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean() 280 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5]) 281 | 282 | return value_estimation_loss 283 | 284 | def get_density_value_estimates(self, y_context, all_obs, obs_shape): 285 | 286 | values_pred = [] 287 | values_actual = [] 288 | for i in range(all_obs.shape[0]-1): 289 | y_target = all_obs[i+1].long() 290 | quant_target = self.quantizer.get_codebook_entry(y_target, None) 291 | quant_target = quant_target.view(-1, 5, 5, 256).permute(0, 
3, 1, 2) 292 | 293 | inp = torch.cat([y_context, y_target], dim=1) 294 | outputs = self.gpt(inp, labels=inp) 295 | logits = outputs.logits[:, -y_target.shape[1]-1:-1] 296 | scores = torch.nn.functional.log_softmax(logits, dim=2) 297 | 298 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2)) 299 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2)) 300 | 301 | input_length = y_target.shape[1] 302 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1) 303 | prob = torch.exp(gathered_scores.sum(1) / output_length) 304 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 305 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 306 | 307 | values_pred = torch.stack(values_pred).sum(0) 308 | 309 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda')) 310 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 311 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0) 312 | 313 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean() 314 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5]) 315 | 316 | return value_estimation_loss 317 | 318 | import argparse 319 | 320 | if __name__ == "__main__": 321 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.') 322 | 323 | parser.add_argument('--discount', type=float, default=0.8, help='discount') 324 | args = parser.parse_args() 325 | 326 | agent = TFAgent(discount=args.discount) 327 | 328 | agent.optimizer.zero_grad() 329 | agent.gpt_optimizer.zero_grad() 330 | 331 | for step in range(100000): 332 | agent.update(step) -------------------------------------------------------------------------------- /vae/train_pixel_vq_vae_voc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../utils/') 3 | 4 | from transformers.optimization import get_scheduler 5 | from transformers import GPT2LMHeadModel, GPT2Config 6 | 7 | from taming.models.vqgan import VQModel 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from pathlib import Path 15 | from matplotlib import pyplot as plt 16 | 17 | import random 18 | import utils 19 | import os 20 | import math 21 | 22 | from torchvision.utils import make_grid 23 | from drqv2 import RandomShiftsAug 24 | 25 | from numpy_replay_buffer import EfficientReplayBuffer 26 | from utils import load_offline_dataset_into_buffer 27 | 28 | from dm_env import specs 29 | import dmc 30 | 31 | import copy 32 | import wandb 33 | from omegaconf import OmegaConf 34 | from torchvision.utils import make_grid 35 | 36 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average 37 | 38 | ###################### 39 | batch_size = 32 40 | 41 | taming_path = "../../" # taming models path 42 | offline_dir = "../../walker_train/vq_latents/" # dataset dir path 43 | vq_model_path = "../../walker_vq_model_step.pth" # vqvae model path 44 | save_dir_path = "./" # voc's gpt model save path 45 | 46 | ###################### 47 | 48 | class TFAgent: 49 | def __init__(self, discount=0.8, codebook_size=1024, augmentation=RandomShiftsAug(pad=4)): 50 | 51 | 
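# Pixel variant: unlike train_vq_vae_voc.py above, which trains the GPT on pre-saved VQ latents,
# this agent samples raw pixels from the buffer (pixel_samples=True) and finetunes the VQ-VAE
# encoder/decoder jointly with the GPT occupancy model and reward predictor (see vae_loss in update()).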
wandb.init(project="Video Occupancy Models", 52 | id="voc_vqvae_gamma_{}_td_{}".format(discount, codebook_size), 53 | entity=None, dir=os.getcwd()) 54 | 55 | self.train_env = dmc.make("offline_walker_walk_expert", 3, 2, 0, None) 56 | data_specs = (self.train_env.observation_spec(), 57 | self.train_env.action_spec(), 58 | specs.Array((1,), np.float32, 'reward'), 59 | specs.Array((1,), np.float32, 'discount')) 60 | 61 | self.discount = discount 62 | self.codebook_size = codebook_size 63 | self.batch_size = batch_size 64 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs, pixel_samples=True) 65 | 66 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000) 67 | 68 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml") 69 | config = OmegaConf.load(config_path) 70 | 71 | config.model.params.n_embed = self.codebook_size 72 | self.model = VQModel(**config.model.params).to('cuda') 73 | 74 | self.from_imagenet = False 75 | if self.from_imagenet: 76 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt") 77 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"] 78 | missing, unexpected = self.model.load_state_dict(sd, strict=False) 79 | else: 80 | print("Loading fine-tuned VQ model...") 81 | self.model.load_state_dict(torch.load(vq_model_path)) 82 | 83 | configuration = GPT2Config(vocab_size=self.codebook_size, n_layer=4, n_head=8, n_embed=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_prdrop=0.2) 84 | self.gpt = GPT2LMHeadModel(configuration).to('cuda') 85 | 86 | self.gpt_target = copy.deepcopy(self.gpt) 87 | self.target_ema_updater = EMA(0.9) 88 | 89 | self.gpt_target.generation_config.output_scores = True 90 | 91 | self.reward_predictor = Predictor(256 * 25 * 3).to('cuda') 92 | 93 | self.gpt_optimizer = torch.optim.AdamW(self.get_grouped_params(), lr=3e-4) 94 | self.optimizer = torch.optim.AdamW(list(self.model.parameters()) + list(self.reward_predictor.parameters()), lr=3e-4) 95 | 96 | num_training_steps = 100000 97 | self.warmup_ratio = 0.05 98 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio) 99 | self.lr_scheduler = get_scheduler( 100 | "cosine", 101 | optimizer=self.gpt_optimizer, 102 | num_warmup_steps=warmup_steps, 103 | num_training_steps=num_training_steps, 104 | ) 105 | 106 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda') 107 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda') 108 | 109 | self.device = 'cuda' 110 | self.aug = augmentation 111 | 112 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000] 113 | self.train() 114 | 115 | def get_grouped_params(self): 116 | decay_parameters = get_parameter_names(self.gpt, [nn.LayerNorm]) 117 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 118 | optimizer_grouped_parameters = [ 119 | { 120 | "params": [ 121 | p for n, p in self.gpt.named_parameters() if (n in decay_parameters and p.requires_grad) 122 | ], 123 | "weight_decay": 0.1, 124 | }, 125 | { 126 | "params": [ 127 | p for n, p in self.gpt.named_parameters() if (n not in decay_parameters and p.requires_grad) 128 | ], 129 | "weight_decay": 0.0, 130 | }, 131 | ] 132 | return optimizer_grouped_parameters 133 | 134 | def train(self, training=True): 135 | self.training = training 136 | 137 | def update(self, step=0): 138 | metrics = dict() 139 | 140 | batch, indices = next(self.replay_buffer) 141 | obs, action, reward, 
discount, next_obs, latent, _, _, obs_k = utils.to_torch( 142 | batch, self.device) 143 | 144 | # augment 145 | obs = self.aug(obs.float()) 146 | obs_k = self.aug(next_obs.float()) 147 | 148 | # reshape obs 149 | obs = F.interpolate(obs, size=80) 150 | obs_k = F.interpolate(obs_k, size=80) 151 | 152 | obs_shape = obs.shape 153 | 154 | # normalize and preprocess 155 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)]) 156 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) 157 | 158 | obs_k = torch.stack([torch.einsum('nchw->nhwc', obs_k[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)]) 159 | obs_k = torch.einsum('tnhwc->ntchw', obs_k).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) 160 | 161 | # vq embed 162 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 163 | quant_target, emb_loss_target, info_target = self.model.encode(obs_k) 164 | 165 | # collect discrete vq indices 166 | y_context = info_context[2].view(obs_shape[0], -1).detach() 167 | y_target = info_target[2].view(obs_shape[0], -1).detach() 168 | 169 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1)) 170 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean() 171 | 172 | # generate target 173 | with torch.no_grad(): 174 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100) 175 | p_t = p_t[:, -y_target.shape[-1]:] 176 | 177 | # gamma sampling 178 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device) 179 | prob = torch.bernoulli(gamma) 180 | p_target = torch.zeros_like(y_target) 181 | 182 | # with prob 1-gamma, sample from next state 183 | p_c_idx = torch.nonzero(1 - prob) 184 | p_target[p_c_idx] = y_target[p_c_idx] 185 | 186 | # with prob gamma, sample from bootstrapped model 187 | p_t_idx = torch.nonzero(prob) 188 | p_target[p_t_idx] = p_t[p_t_idx] 189 | 190 | xrec, qloss = self.model.decode(quant_context), emb_loss_context 191 | vae_loss, log_dict_ae = self.model.loss(qloss, obs, xrec, 0, step, last_layer=self.model.get_last_layer(), split="train") 192 | 193 | # gpt predictions 194 | inp = torch.cat([y_context, p_target], dim=1) 195 | outputs = self.gpt(inp, labels=inp) 196 | gpt_loss = outputs.loss 197 | 198 | loss = vae_loss + gpt_loss + reward_loss 199 | 200 | loss.backward() 201 | 202 | # grad accumulate 203 | if step % 2 == 0: 204 | self.optimizer.step() 205 | self.gpt_optimizer.step() 206 | 207 | self.optimizer.zero_grad() 208 | self.gpt_optimizer.zero_grad() 209 | 210 | self.lr_scheduler.step() 211 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt) 212 | 213 | # visualize predictions 214 | if step % 100 == 0: 215 | with torch.no_grad(): 216 | # sample a batch of traj and corresponding values 217 | batch, indices = self.replay_buffer.sample_spr() 218 | _, _, _, _, _, _, all_obs, _, values = utils.to_torch(batch, self.device) 219 | 220 | # preprocess first obs from traj 221 | obs = F.interpolate(all_obs[0], size=80) 222 | obs_shape = obs.shape 223 | 224 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) 
- self.imagenet_mean / self.imagenet_std for i in range(3)]) 225 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs) 226 | 227 | # vq embed first obs 228 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 229 | y_context = info_context[2].view(obs_shape[0], -1).detach() 230 | 231 | # sample target predictions 232 | p_t = self.gpt.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, pad_token_id=-100)[:, -y_context.shape[-1]:] 233 | 234 | # reconstruct sampled prediction 235 | quant = self.model.quantize.get_codebook_entry(y_context, None) #self.model.quantize.get_codebook_entry(p_t, None) 236 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 237 | 238 | y_pixel_recon = self.model.decode(quant) 239 | 240 | quant = self.model.quantize.get_codebook_entry(p_t, None) #self.model.quantize.get_codebook_entry(p_t, None) 241 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 242 | 243 | p_t_pixel_recon = self.model.decode(quant) 244 | 245 | viz_imgs = [] 246 | viz_imgs.append(p_t_pixel_recon) 247 | viz_imgs.append(y_pixel_recon) 248 | all_obs = self.paint_obs(all_obs) 249 | for i in range(all_obs.shape[0]): 250 | obs = all_obs[i] 251 | obs = F.interpolate(obs, size=80) 252 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)]) 253 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs) 254 | viz_imgs.append(obs) 255 | 256 | value_loss = self.get_value_estimates(y_context, values, obs_shape) 257 | wandb.log({"value loss": value_loss}, step=step) 258 | # density_value_loss = self.get_density_value_estimates(y_context, viz_imgs, obs_shape) 259 | # wandb.log({"density value loss": density_value_loss}, step=step) 260 | 261 | viz_imgs = torch.stack(viz_imgs)[:, :8] 262 | t, n, c, h, w = viz_imgs.shape 263 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs) 264 | viz_imgs = viz_imgs.reshape(t*n, c, h, w) 265 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True) 266 | 267 | img = wandb.Image(viz_img) 268 | wandb.log({f"Gamma Pred": img}, step=step) 269 | 270 | wandb.log({"reward loss": reward_loss}, step=step) 271 | 272 | # if finetuning, save vq model at 2k steps 273 | if step in self.saving_iter: 274 | print("saving gpt weights...") 275 | self.save_vq_weights(step) 276 | self.save_gpt_weights(step) 277 | 278 | return metrics 279 | 280 | def save_vq_weights(self, step): 281 | torch.save(self.model.state_dict(), os.path.join(save_dir_path, "vqvae_model_{}_td_{}_step_{}.pth".format(self.discount, self.codebook_size, step))) 282 | 283 | def save_gpt_weights(self, step): 284 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "pixel_vqvae_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step))) 285 | 286 | def get_value_estimates(self, y_context, values, obs_shape): 287 | # Take a state, get samples from the gamma distribution, 288 | # Run the reward predictor through these to get value estimates 289 | # Get ground truth value estimates by simply taking discounted sum of rewards 290 | # Compare these for different states 291 | 292 | num_gamma_samples = 100 293 | values_pred = [] 294 | 295 | for i in range(num_gamma_samples): 296 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, 
pad_token_id=-100) 297 | p_t = outputs.sequences[:, -y_context.shape[-1]:] 298 | 299 | quant = self.model.quantize.get_codebook_entry(p_t, None) 300 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 301 | 302 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1)) 303 | 304 | values_pred = torch.stack(values_pred).sum(0) / (100 * (1 - self.discount)) 305 | 306 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean() 307 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5]) 308 | 309 | return value_estimation_loss 310 | 311 | def get_density_value_estimates(self, y_context, all_obs, obs_shape): 312 | 313 | values_pred = [] 314 | values_actual = [] 315 | for i in range(all_obs.shape[0]-1): 316 | quant_target, _, info_target = self.model.encode(all_obs[i+1]) 317 | y_target = info_target[2].view(obs_shape[0], -1).detach() 318 | 319 | inp = torch.cat([y_context, y_target], dim=1) 320 | outputs = self.gpt(inp, labels=inp) 321 | logits = outputs.logits[:, -y_target.shape[1]-1:-1] 322 | scores = torch.nn.functional.log_softmax(logits, dim=2) 323 | 324 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2)) 325 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2)) 326 | 327 | input_length = y_target.shape[1] 328 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1) 329 | prob = torch.exp(gathered_scores.sum(1) / output_length) 330 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 331 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 332 | 333 | values_pred = torch.stack(values_pred).sum(0) 334 | 335 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda')) 336 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 337 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0) 338 | 339 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean() 340 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5]) 341 | 342 | return value_estimation_loss 343 | 344 | def paint_obs(self, all_obs): 345 | for i in range(all_obs.shape[0]): 346 | obs = all_obs[i] # get first of all frame of all batches, B x 9 x 80 x 80 347 | 348 | # Every ith frame is colored one way, for all batches 349 | obs[:, ::3, :5, -5:] = 255.0 * i / all_obs.shape[0] # set top corner for first channel to 255 350 | obs[:, 1::3, :5, -5:] = 0.0 # set top corner for second channel to 0 351 | obs[:, 2::3, :5, -5:] = 0.0 # set top corner for third channel to 0 352 | 353 | all_obs[i] = obs 354 | 355 | return all_obs 356 | 357 | 358 | import argparse 359 | 360 | if __name__ == "__main__": 361 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.') 362 | 363 | parser.add_argument('--discount', type=float, default=0.8, help='discount') 364 | parser.add_argument('--codebook_size', type=int, default=1024, help='codebook size') 365 | args = parser.parse_args() 366 | 367 | agent = TFAgent(discount=args.discount, codebook_size=args.codebook_size) 368 | 369 | agent.optimizer.zero_grad() 370 | agent.gpt_optimizer.zero_grad() 371 | 372 | for step in range(100000): 373 | agent.update(step) 
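The "gamma sampling" block shared by the trainers above builds the GPT's target sequence as a geometric mixture: with probability 1 - discount the target is the next state's VQ token indices, and with probability discount it is a sequence sampled from the EMA target GPT (the bootstrapped term). Below is a minimal self-contained sketch of that construction; it assumes the targets are (B, T) LongTensors of codebook indices, and the function name is illustrative rather than part of this repo.

import torch

def gamma_mix_targets(y_next, y_bootstrap, discount):
    # y_next, y_bootstrap: (B, T) LongTensors of VQ code indices for the next state and for a
    # sample drawn from the EMA target model; discount: float in [0, 1).
    use_bootstrap = torch.bernoulli(
        torch.full((y_next.shape[0], 1), discount, device=y_next.device)
    ).bool()  # (B, 1); True -> take the bootstrapped sample for that row
    return torch.where(use_bootstrap, y_bootstrap, y_next)

# usage, with the names used in update() above: p_target = gamma_mix_targets(y_target, p_t, self.discount)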
-------------------------------------------------------------------------------- /musik/train_vq_musik_voc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../utils/') 3 | 4 | from musik_model import VQMUSIKModel 5 | 6 | import os 7 | import math 8 | import copy 9 | import wandb 10 | import random 11 | import utils 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from omegaconf import OmegaConf 18 | 19 | from pathlib import Path 20 | from matplotlib import pyplot as plt 21 | from torchvision.utils import make_grid 22 | 23 | from numpy_replay_buffer import EfficientReplayBuffer 24 | from utils import load_offline_dataset_into_buffer 25 | 26 | from transformers.optimization import get_scheduler 27 | from transformers import GPT2LMHeadModel, GPT2Config 28 | 29 | import dmc 30 | from dm_env import specs 31 | from drqv2 import RandomShiftsAug 32 | 33 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average 34 | from network_utils import InfoNCE 35 | 36 | from omegaconf import OmegaConf 37 | 38 | ################## 39 | obs_resize = 80 40 | mae_seq_len = 65 # no. of tokens actually processed once masking is done 41 | embed_size = 128 42 | 43 | action_dim = 6 44 | 45 | taming_path = "../../" # taming models path 46 | offline_dir = "../../cheetah_train/vq_latents/" 47 | save_dir_path = "./" # voc's gpt model save path 48 | 49 | batch_size = 32 50 | ################## 51 | 52 | class TFAgent: 53 | def __init__(self, discount=0.8, codebook_size=1024, augmentation=RandomShiftsAug(pad=4)): 54 | 55 | wandb.init(project="Video Occupancy Models", 56 | id="voc_musik_gamma_{}_{}".format(discount, codebook_size), 57 | entity=None, dir=os.getcwd()) 58 | 59 | self.train_env = dmc.make("offline_cheetah_run_expert", 3, 2, 0, None) 60 | data_specs = (self.train_env.observation_spec(), 61 | self.train_env.action_spec(), 62 | specs.Array((1,), np.float32, 'reward'), 63 | specs.Array((1,), np.float32, 'discount')) 64 | 65 | self.discount = discount 66 | self.batch_size = batch_size 67 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps 68 | self.codebook_size = codebook_size 69 | 70 | ######### random and mixed data buffers ######### 71 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs, pixel_samples=True) 72 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000, future_sampling_steps=15) 73 | ########################## 74 | 75 | ######### vq-musik setup ######### 76 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml") 77 | config = OmegaConf.load(config_path) 78 | # config.model.params.ddconfig.in_channels = 9 79 | config.model.params.n_embed = self.codebook_size 80 | self.model = VQMUSIKModel(**config.model.params).to('cuda') 81 | 82 | self.from_imagenet = False 83 | if self.from_imagenet: 84 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt") 85 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"] 86 | missing, unexpected = self.model.encoder.load_state_dict(sd, strict=False) 87 | 88 | # self.model.load_state_dict(torch.load("/home/manant/scratch/vq_model_7000.pth"), strict=False) 89 | 90 | # Musik predictor network for training the representation 91 | self.musik_predictor = InfoNCE(1542, action_dim, 1).to('cuda') 92 | 
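# The InfoNCE head above scores the executed action against the encodings of the current frame
# stack and a k-step future frame stack (see musik_loss in update()), i.e. a multi-step
# inverse-dynamics style contrastive objective; the VQ codebook losses are added on top.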
########################## 93 | 94 | ######### gpt setup ######### 95 | configuration = GPT2Config(vocab_size=self.codebook_size, n_layer=4, n_head=8, n_embed=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_prdrop=0.2) 96 | self.gpt = GPT2LMHeadModel(configuration).to('cuda') 97 | self.gpt_target = copy.deepcopy(self.gpt) 98 | self.gpt_target.generation_config.output_scores = True 99 | self.target_ema_updater = EMA(0.9) 100 | self.encoder_target_ema_updater = EMA(0.9) 101 | ########################## 102 | 103 | ######### optimizer setup ######### 104 | self.reward_predictor = Predictor(19200).to('cuda') # embed_size x mae_seq_len x num_codebooks x frame_stack = 256 * 25 * 3 105 | self.gpt_optimizer = torch.optim.AdamW(list(self.get_grouped_params(self.gpt)), lr=3e-4) 106 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()) + list(self.musik_predictor.parameters()) + list(self.model.parameters()), lr=3e-4) 107 | 108 | num_training_steps = 100000 109 | self.warmup_ratio = 0.05 110 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio) 111 | self.lr_scheduler = get_scheduler( 112 | "cosine", 113 | optimizer=self.gpt_optimizer, 114 | num_warmup_steps=warmup_steps, 115 | num_training_steps=num_training_steps, 116 | ) 117 | ########################## 118 | 119 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda') 120 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda') 121 | 122 | self.device = 'cuda' 123 | self.aug = augmentation 124 | 125 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000] 126 | self.train() 127 | self.model.train() 128 | 129 | def get_grouped_params(self, model): 130 | decay_parameters = get_parameter_names(model, [nn.LayerNorm]) 131 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 132 | optimizer_grouped_parameters = [ 133 | { 134 | "params": [ 135 | p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad) 136 | ], 137 | "weight_decay": 0.1, 138 | }, 139 | { 140 | "params": [ 141 | p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad) 142 | ], 143 | "weight_decay": 0.0, 144 | }, 145 | ] 146 | return optimizer_grouped_parameters 147 | 148 | def train(self, training=True): 149 | self.training = training 150 | 151 | def preprocess_obs(self, obs): 152 | obs = F.interpolate(obs, size=obs_resize) 153 | 154 | if len(obs.shape) == 3: 155 | obs = obs.unsqueeze(0) 156 | org_obs_shape = obs.shape 157 | 158 | try: 159 | assert len(obs.shape) == 4 # B x C x H x W 160 | org_obs_shape = obs.shape 161 | # normalize and preprocess 162 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)]) 163 | obs = torch.einsum('snhwc->nschw', obs).reshape((org_obs_shape[0] * 3, 3, *org_obs_shape[2:])) 164 | except: 165 | assert len(obs.shape) == 5 # T x B x C x H x W 166 | org_obs_shape = t, b, c, h, w = obs.shape 167 | obs = torch.stack([torch.einsum('tnchw->tnhwc', obs[:, :, i*3:3+i*3] / 255.) 
- self.imagenet_mean / self.imagenet_std for i in range(3)]) 168 | obs = torch.einsum('stnhwc->tnschw', obs).reshape((t, b * 3, 3, h, w)) 169 | return obs, org_obs_shape 170 | 171 | def update(self, step=0): 172 | metrics = dict() 173 | 174 | batch, indices = next(self.replay_buffer) 175 | obs, action, reward, discount, next_obs, _, _, _, obs_k = utils.to_torch(batch, self.device) 176 | 177 | # augment 178 | obs = self.aug(obs.float()) 179 | next_obs = self.aug(next_obs.float()) 180 | obs_k = self.aug(obs_k.float()) 181 | 182 | obs, obs_shape = self.preprocess_obs(obs) # process current obs 183 | next_obs, _ = self.preprocess_obs(next_obs) 184 | obs_k, _ = self.preprocess_obs(obs_k) # process next/future obs 185 | 186 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 187 | quant_future_target, emb_loss_future_target, info_future_target = self.model.encode(obs_k) 188 | quant_target, emb_loss_target, info_target = self.model.encode(next_obs) 189 | 190 | y_context = info_context[2].view(obs_shape[0], -1).detach() 191 | y_target = info_target[2].view(obs_shape[0], -1).detach() 192 | 193 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1)) 194 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean() 195 | 196 | # generate target 197 | with torch.no_grad(): 198 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100) 199 | p_t = p_t[:, -y_target.shape[-1]:] 200 | 201 | # gamma sampling 202 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device) 203 | prob = torch.bernoulli(gamma) 204 | p_target = torch.zeros_like(y_target) 205 | 206 | # with prob 1-gamma, sample from next state 207 | p_c_idx = torch.nonzero(1 - prob) 208 | p_target[p_c_idx] = y_target[p_c_idx] 209 | 210 | # with prob gamma, sample from bootstrapped model 211 | p_t_idx = torch.nonzero(prob) 212 | p_target[p_t_idx] = p_t[p_t_idx] 213 | 214 | # gpt predictions 215 | inp = torch.cat([y_context, p_target], dim=1) 216 | outputs = self.gpt(inp, labels=inp) 217 | gpt_loss = outputs.loss 218 | 219 | enc = self.model.decode_linear(quant_context).reshape((obs_shape[0], -1)) 220 | enc_k = self.model.decode_linear(quant_future_target).reshape((obs_shape[0], -1)) 221 | 222 | musik_loss = self.musik_predictor(enc, enc_k, action) + emb_loss_context + emb_loss_future_target # musik loss + codebook loss 223 | 224 | loss = reward_loss + gpt_loss + musik_loss 225 | loss.backward() 226 | 227 | # grad accumulate 228 | if step % 2 == 0: 229 | self.optimizer.step() 230 | self.gpt_optimizer.step() 231 | 232 | self.optimizer.zero_grad() 233 | self.gpt_optimizer.zero_grad() 234 | 235 | self.lr_scheduler.step() 236 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt) 237 | 238 | # visualize predictions 239 | if step % 200 == 0: 240 | with torch.no_grad(): 241 | # sample a batch of traj and corresponding values 242 | batch, indices = self.replay_buffer.sample_spr() 243 | _, _, _, _, _, _, all_obs, _, values = utils.to_torch(batch, self.device) 244 | 245 | # preprocess first obs from traj 246 | obs = F.interpolate(all_obs[0], size=obs_resize) 247 | obs_shape = obs.shape 248 | 249 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) 
- self.imagenet_mean / self.imagenet_std for i in range(3)]) 250 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs) 251 | # obs = torch.einsum('snhwc->nschw', obs).reshape((obs_shape[0], 9, *obs_shape[2:])) 252 | 253 | # vq embed first obs 254 | quant_context, emb_loss_context, info_context = self.model.encode(obs) 255 | y_context = info_context[2].view(obs_shape[0], -1).detach() 256 | 257 | value_loss = self.get_value_estimates(y_context, values, obs_shape) 258 | wandb.log({"value loss": value_loss}, step=step) 259 | 260 | density_value_loss = self.get_density_value_estimates(y_context, all_obs, obs_shape) 261 | wandb.log({"density value loss": density_value_loss}, step=step) 262 | 263 | print("losses are", reward_loss, musik_loss, gpt_loss, emb_loss_context) 264 | 265 | wandb.log({"gpt loss": gpt_loss}, step=step) 266 | wandb.log({"rep loss": musik_loss}, step=step) 267 | wandb.log({"reward loss": reward_loss}, step=step) 268 | 269 | # save gpt model 270 | if step in self.saving_iter: 271 | print("saving gpt weights...") 272 | self.save_musik_weights(step) 273 | self.save_gpt_weights(step) 274 | 275 | return metrics 276 | 277 | def save_musik_weights(self, step): 278 | torch.save(self.model.state_dict(), os.path.join(save_dir_path, "vq_musik_model_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step))) 279 | 280 | def save_gpt_weights(self, step): 281 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "musik_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step))) 282 | 283 | 284 | def get_value_estimates(self, y_context, values, obs_shape): 285 | # Take a state, get samples from the gamma distribution, 286 | # Run the reward predictor through these to get value estimates 287 | # Get ground truth value estimates by simply taking discounted sum of rewards 288 | # Compare these for different states 289 | 290 | num_gamma_samples = 100 291 | values_pred = [] 292 | 293 | for i in range(num_gamma_samples): 294 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100) #, kwargs={'token_type_ids': context_mask_ids}) 295 | p_t = outputs.sequences[:, -y_context.shape[-1]:] 296 | 297 | quant = self.model.quantize.get_codebook_entry(p_t, None) 298 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2) 299 | 300 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1)) 301 | 302 | values_pred = torch.stack(values_pred).sum(0) / (100 * (1 - self.discount)) 303 | 304 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean() 305 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5]) 306 | 307 | return value_estimation_loss 308 | 309 | def get_density_value_estimates(self, y_context, all_obs, obs_shape): 310 | 311 | values_pred = [] 312 | values_actual = [] 313 | for i in range(all_obs.shape[0]-1): 314 | obs = F.interpolate(all_obs[i+1], size=obs_resize) 315 | obs_shape = obs.shape 316 | 317 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) 
- self.imagenet_mean / self.imagenet_std for i in range(3)]) 318 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) 319 | quant_target, emb_loss_target, info_target = self.model.encode(obs) 320 | y_target = info_target[2].view(obs_shape[0], -1).detach() 321 | 322 | inp = torch.cat([y_context, y_target], dim=1) 323 | outputs = self.gpt(inp, labels=inp) 324 | logits = outputs.logits[:, -y_target.shape[1]-1:-1] 325 | scores = torch.nn.functional.log_softmax(logits, dim=2) 326 | 327 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2)) 328 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2)) 329 | 330 | input_length = y_target.shape[1] 331 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1) 332 | prob = torch.exp(gathered_scores.sum(1) / output_length) 333 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 334 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1)) 335 | 336 | values_pred = torch.stack(values_pred).sum(0) 337 | 338 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda')) 339 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement 340 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0) 341 | 342 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean() 343 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5]) 344 | 345 | return value_estimation_loss 346 | 347 | 348 | import argparse 349 | 350 | if __name__ == "__main__": 351 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.') 352 | 353 | parser.add_argument('--discount', type=float, default=0.8, help='discount') 354 | parser.add_argument('--codebook_size', type=int, default=1024, help='codebook size') 355 | args = parser.parse_args() 356 | 357 | agent = TFAgent(discount=args.discount, codebook_size=args.codebook_size) 358 | 359 | agent.optimizer.zero_grad() 360 | agent.gpt_optimizer.zero_grad() 361 | 362 | for step in range(100000): 363 | agent.update(step) -------------------------------------------------------------------------------- /dino/modeling_vqkd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Zhiliang Peng 7 | # Based on VQGAN code bases 8 | # https://github.com/CompVis/taming-transformers 9 | # --------------------------------------------------------' 10 | 11 | import torch 12 | import numpy as np 13 | from torch import nn, einsum 14 | import torch.nn.functional as F 15 | import math 16 | from collections import OrderedDict 17 | from functools import partial, reduce 18 | from einops import rearrange 19 | from timm.models.layers import trunc_normal_ 20 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 21 | from timm.models.registry import register_model 22 | 23 | from modeling_finetune import VisionTransformer 24 | 
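# NormEMAVectorQuantizer (norm_ema_quantizer.py) keeps an l2-normalized codebook that is updated
# with an exponential moving average of the encoder outputs assigned to each code instead of by
# gradients; roughly, for code k and EMA decay m:
#   e_k <- normalize(m * e_k + (1 - m) * mean of the features currently assigned to code k)
# (a paraphrase of the usual norm-EMA update, stated as an assumption about that module's implementation).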
from norm_ema_quantizer import NormEMAVectorQuantizer 25 | 26 | from vqkd_teacher import get_dino_vit_base#, clip 27 | 28 | class VQKD(nn.Module): 29 | def __init__(self, 30 | encoder_config, 31 | decoder_config, 32 | n_embed=8192, 33 | embed_dim=32, 34 | decay=0.99, 35 | process_type='default', 36 | quantize_kmeans_init=True, 37 | teacher_model_type='clip', 38 | decoder_out_dim=512, 39 | rec_loss_type='cosine', 40 | **kwargs 41 | ): 42 | super().__init__() 43 | print(kwargs) 44 | if decoder_config['in_chans'] != embed_dim: 45 | print(f"Rewrite the in_chans in decoder from {decoder_config['in_chans']} to {embed_dim}") 46 | decoder_config['in_chans'] = embed_dim 47 | 48 | # encoder & decode params 49 | print('Final encoder config', encoder_config) 50 | self.encoder = VisionTransformer(**encoder_config) 51 | 52 | print('Final decoder config', decoder_config) 53 | self.decoder = VisionTransformer(**decoder_config) 54 | 55 | self.quantize = NormEMAVectorQuantizer( 56 | n_embed=n_embed, embedding_dim=embed_dim, beta=1.0, kmeans_init=quantize_kmeans_init, decay=decay, 57 | ) 58 | 59 | self.patch_size = encoder_config['patch_size'] 60 | self.token_shape = (encoder_config['img_size'] // self.patch_size, encoder_config['img_size'] // self.patch_size) 61 | 62 | ## Teacher model setting 63 | self.teacher_model_type = teacher_model_type 64 | self.decoder_out_dim = decoder_out_dim 65 | if self.teacher_model_type == 'clip': 66 | self.scaling_layer = ScalingLayerForClip() 67 | self.teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False) 68 | self.decoder_out_dim = 512 69 | 70 | elif self.teacher_model_type == 'dino': 71 | self.scaling_layer = ScalingLayerForIM() 72 | self.teacher_model = get_dino_vit_base() 73 | self.decoder_out_dim = 768 74 | 75 | else: 76 | self.teacher_model = None 77 | 78 | if self.teacher_model is not None: 79 | for param in self.teacher_model.parameters(): 80 | param.requires_grad = False # fix teacher_model model 81 | 82 | self.teacher_model.eval() 83 | self.teacher_input_size = kwargs.get('teacher_input_size', 224) 84 | 85 | # task layer 86 | self.encode_task_layer = nn.Sequential( 87 | nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']), 88 | nn.Tanh(), 89 | nn.Linear(encoder_config['embed_dim'], embed_dim) # for quantize 90 | ) 91 | self.decode_task_layer = nn.Sequential( 92 | nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']), 93 | nn.Tanh(), 94 | nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim), 95 | ) 96 | 97 | self.rec_loss_type = rec_loss_type 98 | 99 | print(f"process type for VQKD: {process_type}") 100 | self.process_type = process_type # in ['default', 'dall-e'] 101 | self.logit_laplace_eps = 0.1 102 | self.kwargs = kwargs 103 | 104 | self.encode_task_layer.apply(self._init_weights) 105 | self.decode_task_layer.apply(self._init_weights) 106 | print("model initialized") 107 | 108 | def _init_weights(self, m): 109 | if isinstance(m, nn.Linear): 110 | trunc_normal_(m.weight, std=.02) 111 | if isinstance(m, nn.Linear) and m.bias is not None: 112 | nn.init.constant_(m.bias, 0) 113 | elif isinstance(m, nn.LayerNorm): 114 | nn.init.constant_(m.bias, 0) 115 | nn.init.constant_(m.weight, 1.0) 116 | 117 | @torch.jit.ignore 118 | def no_weight_decay(self): 119 | return {'quantize.embedding.weight', 'decoder.cls_token', 'decoder.pos_embed', 120 | 'encoder.cls_token', 'encoder.pos_embed'} 121 | 122 | @property 123 | def device(self): 124 | return self.decoder.cls_token.device 125 | 126 | def pre_process(self, data): 127 | 
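# Effectively a pass-through in this fork: both normalization branches below are commented out,
# presumably because the calling code already feeds suitably scaled inputs.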
print("pre process called", data.shape) 128 | # if self.process_type == 'default': 129 | # # TODO: modify for adapt 130 | # data = data.to(self.device) 131 | # if data.max() <= 1.: 132 | # data = data * 255. 133 | # data = data / 127.5 - 1.0 134 | # elif self.process_type == 'imagenet_norm': 135 | # mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(self.device)[None, :, None, None] 136 | # std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(self.device)[None, :, None, None] 137 | # data = (data - mean) / std 138 | return data 139 | 140 | def get_number_of_tokens(self): 141 | return self.quantize.n_e 142 | 143 | def get_tokens(self, data, **kwargs): 144 | 145 | data = self.pre_process(data) 146 | quantize, embed_ind, loss = self.encode(data) 147 | output = {} 148 | output['token'] = embed_ind.view(data.shape[0], -1) 149 | output['input_img'] = data 150 | 151 | return output 152 | 153 | def get_codebook_entry(self, ind): 154 | ind_shape = ind.shape 155 | ind = ind.reshape(-1) 156 | z = self.quantize.embedding(ind) 157 | z = z.reshape(ind_shape[0], -1) 158 | return z 159 | 160 | def encode(self, x): 161 | encoder_features = self.encoder(x, return_patch_tokens=True) 162 | 163 | with torch.cuda.amp.autocast(enabled=False): 164 | to_quantizer_features = self.encode_task_layer(encoder_features.type_as(self.encode_task_layer[-1].weight)) 165 | 166 | N = to_quantizer_features.shape[1] 167 | h, w = int(math.sqrt(N)), int(math.sqrt(N)) 168 | 169 | to_quantizer_features = rearrange(to_quantizer_features, 'b (h w) c -> b c h w', h=h, w=w) # reshape for quantizer 170 | quantize, loss, embed_ind = self.quantize(to_quantizer_features) 171 | # print("before quant and after shapes", to_quantizer_features.shape, quantize.shape) 172 | 173 | return quantize, embed_ind, loss 174 | 175 | def decode(self, quantize, **kwargs): 176 | # reshape tokens to feature maps for patch embed in decoder 177 | # quantize = rearrange(quantize, 'b (h w) c -> b c h w', h=self.token_shape[0], w=self.token_shape[1]) 178 | decoder_features = self.decoder(quantize, return_patch_tokens=True) 179 | rec = self.decode_task_layer(decoder_features) 180 | 181 | return rec 182 | 183 | def get_codebook_indices(self, x, **kwargs): 184 | # for beit pre-training 185 | return self.get_tokens(x, **kwargs)['token'] 186 | 187 | @torch.no_grad() 188 | def get_regress_target(self, x, **kwargs): 189 | 190 | norm_imgs = self.scaling_layer(x) 191 | if self.teacher_model_type == 'clip': 192 | target = self.teacher_model.encode_image(norm_imgs, return_all_tokens=True) @ self.teacher_model.visual.proj 193 | elif self.teacher_model_type == 'dino': 194 | target = self.teacher_model.forward(norm_imgs, return_patch_tokens=True) 195 | else: 196 | raise NotImplementedError 197 | 198 | return target 199 | 200 | def calculate_rec_loss(self, rec, target): 201 | if self.rec_loss_type == 'cosine': 202 | target = target / target.norm(dim=-1, keepdim=True) 203 | rec = rec / rec.norm(dim=-1, keepdim=True) 204 | rec_loss = (1 - (target * rec).sum(-1)).mean() 205 | else: 206 | raise NotImplementedError 207 | 208 | return rec_loss 209 | 210 | def forward(self, x, **kwargs): 211 | """ 212 | x: shape [B, 3, H, W] in [0, 1] 213 | """ 214 | x = self.pre_process(x) # rescale to [-1, 1] 215 | 216 | target = self.get_regress_target(x, **kwargs) 217 | 218 | quantize, embed_ind, emb_loss = self.encode(x) 219 | xrec = self.decode(quantize) 220 | 221 | rec_loss = self.calculate_rec_loss(xrec, target) 222 | loss = emb_loss + rec_loss 223 | 224 | log = {} 225 | split="train" if self.training 
else "val" 226 | log[f'{split}/quant_loss'] = emb_loss.detach().mean() 227 | log[f'{split}/rec_loss'] = rec_loss.detach().mean() 228 | log[f'{split}/total_loss'] = loss.detach().mean() 229 | 230 | return loss, log 231 | 232 | class ScalingLayerForClip(nn.Module): 233 | def __init__(self): 234 | super(ScalingLayerForClip, self).__init__() 235 | self.register_buffer('shift', torch.Tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]) 236 | self.register_buffer('scale', torch.Tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]) 237 | 238 | def forward(self, inp): 239 | inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255. # rescale to [0, 1.] 240 | return (inp - self.shift) / self.scale 241 | 242 | class ScalingLayerForIM(nn.Module): 243 | def __init__(self): 244 | super(ScalingLayerForIM, self).__init__() 245 | self.register_buffer('shift', torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None]) # scale for tokenizer with default prosscess type \in [-1, 1] 246 | self.register_buffer('scale', torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None]) 247 | 248 | def forward(self, inp): 249 | inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255. # rescale to [0, 1.] 250 | return (inp - self.shift) / self.scale 251 | 252 | def get_model_default_params(): 253 | return dict(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, 254 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 255 | norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0., use_abs_pos_emb=True, 256 | use_rel_pos_bias=False, use_shared_rel_pos_bias=False, use_mean_pooling=True, init_scale=0.001) 257 | 258 | @register_model 259 | def vqkd_encoder_base_decoder_1x768x12_clip(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224, 260 | n_code=8192, code_dim=32, **kwargs): 261 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params() 262 | 263 | # encoder settings 264 | encoder_config['img_size'] = img_size 265 | encoder_config['num_classes'] = 0 266 | # decoder settings 267 | decoder_config['img_size'] = img_size // decoder_config['patch_size'] 268 | decoder_config['patch_size'] = 1 269 | decoder_config['in_chans'] = code_dim 270 | decoder_config['num_classes'] = 0 271 | decoder_config['depth'] = 1 272 | # teacher settings 273 | _ = kwargs.pop("teacher_model_type", "clip") 274 | 275 | teacher_model_type = 'clip' if not as_tokenzer else 'None' 276 | decoder_out_dim = 512 277 | 278 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type, 279 | decoder_out_dim=decoder_out_dim, **kwargs) 280 | 281 | if as_tokenzer: 282 | assert pretrained 283 | assert pretrained_weight is not None 284 | 285 | if pretrained_weight.startswith('https'): 286 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True) 287 | else: 288 | weights = torch.load(pretrained_weight, map_location='cpu') 289 | 290 | if 'model' in weights: 291 | weights = weights['model'] 292 | else: 293 | weights = weights["state_dict"] 294 | keys = list(weights.keys()) 295 | 296 | for k in keys: 297 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"): 298 | del weights[k] 299 | model.load_state_dict(weights) 300 | return model 301 | 302 | @register_model 303 | def vqkd_encoder_base_decoder_3x768x12_clip(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224, 304 | n_code=8192, 
code_dim=32, **kwargs): 305 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params() 306 | 307 | # encoder settings 308 | encoder_config['img_size'] = img_size 309 | encoder_config['num_classes'] = 0 310 | # decoder settings 311 | decoder_config['img_size'] = img_size // decoder_config['patch_size'] 312 | decoder_config['patch_size'] = 1 313 | decoder_config['in_chans'] = code_dim 314 | decoder_config['num_classes'] = 0 315 | decoder_config['depth'] = 3 316 | # teacher settings 317 | _ = kwargs.pop("teacher_model_type", "clip") 318 | 319 | teacher_model_type = 'clip' if not as_tokenzer else 'None' 320 | decoder_out_dim = 512 321 | 322 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type, 323 | decoder_out_dim=decoder_out_dim, **kwargs) 324 | 325 | if as_tokenzer: 326 | assert pretrained 327 | assert pretrained_weight is not None 328 | 329 | if pretrained_weight.startswith('https'): 330 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True) 331 | else: 332 | weights = torch.load(pretrained_weight, map_location='cpu') 333 | 334 | if 'model' in weights: 335 | weights = weights['model'] 336 | else: 337 | weights = weights["state_dict"] 338 | keys = list(weights.keys()) 339 | 340 | for k in keys: 341 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"): 342 | del weights[k] 343 | model.load_state_dict(weights) 344 | return model 345 | 346 | 347 | @register_model 348 | def vqkd_encoder_base_decoder_1x768x12_dino(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224, 349 | n_code=8192, code_dim=32, **kwargs): 350 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params() 351 | 352 | # encoder settings 353 | encoder_config['img_size'] = img_size 354 | encoder_config['num_classes'] = 0 355 | # decoder settings 356 | decoder_config['img_size'] = img_size // decoder_config['patch_size'] 357 | decoder_config['patch_size'] = 1 358 | decoder_config['in_chans'] = code_dim 359 | decoder_config['num_classes'] = 0 360 | decoder_config['depth'] = 1 361 | # teacher settings 362 | _ = kwargs.pop("teacher_model_type", "dino") 363 | 364 | teacher_model_type = 'dino' if not as_tokenzer else 'None' 365 | decoder_out_dim = 768 366 | 367 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type, 368 | decoder_out_dim=decoder_out_dim, **kwargs) 369 | 370 | if as_tokenzer: 371 | assert pretrained 372 | assert pretrained_weight is not None 373 | 374 | if pretrained_weight.startswith('https'): 375 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True) 376 | else: 377 | weights = torch.load(pretrained_weight, map_location='cpu') 378 | 379 | if 'model' in weights: 380 | weights = weights['model'] 381 | else: 382 | weights = weights["state_dict"] 383 | keys = list(weights.keys()) 384 | 385 | for k in keys: 386 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"): 387 | del weights[k] 388 | model.load_state_dict(weights) 389 | return model 390 | 391 | 392 | if __name__ == '__main__': 393 | pass 394 | 395 | 396 | 397 | 398 | 399 | 400 | -------------------------------------------------------------------------------- /dino/vqkd_teacher/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, 
Union 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | import pdb 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1): 15 | super().__init__() 16 | 17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential(OrderedDict([ 36 | ("-1", nn.AvgPool2d(stride)), 37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 38 | ("1", nn.BatchNorm2d(planes * self.expansion)) 39 | ])) 40 | 41 | def forward(self, x: torch.Tensor): 42 | identity = x 43 | 44 | out = self.relu(self.bn1(self.conv1(x))) 45 | out = self.relu(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu(out) 54 | return out 55 | 56 | 57 | class AttentionPool2d(nn.Module): 58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 59 | super().__init__() 60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 61 | self.k_proj = nn.Linear(embed_dim, embed_dim) 62 | self.q_proj = nn.Linear(embed_dim, embed_dim) 63 | self.v_proj = nn.Linear(embed_dim, embed_dim) 64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 65 | self.num_heads = num_heads 66 | 67 | def forward(self, x, return_all_tokens=False): 68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 71 | x, _ = F.multi_head_attention_forward( 72 | query=x, key=x, value=x, 73 | embed_dim_to_check=x.shape[-1], 74 | num_heads=self.num_heads, 75 | q_proj_weight=self.q_proj.weight, 76 | k_proj_weight=self.k_proj.weight, 77 | v_proj_weight=self.v_proj.weight, 78 | in_proj_weight=None, 79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 80 | bias_k=None, 81 | bias_v=None, 82 | add_zero_attn=False, 83 | dropout_p=0, 84 | out_proj_weight=self.c_proj.weight, 85 | out_proj_bias=self.c_proj.bias, 86 | use_separate_proj_weight=True, 87 | training=self.training, 88 | need_weights=False 89 | ) 90 | if return_all_tokens: 91 | return x 92 | else: 93 | return x[0] 94 | 95 | 96 | class ModifiedResNet(nn.Module): 97 | """ 98 | A ResNet class that is similar to torchvision's but contains the following changes: 99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 101 | - The final pooling layer is a QKV attention instead of an average pool 102 | """ 103 | 104 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 105 | super().__init__() 106 | self.output_dim = output_dim 107 | self.input_resolution = input_resolution 108 | 109 | # the 3-layer stem 110 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(width // 2) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.avgpool = nn.AvgPool2d(2) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x, return_side_out=False, return_all_tokens=False): 139 | def stem(x): 140 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 141 | x = self.relu(bn(conv(x))) 142 | x = self.avgpool(x) 143 | return x 144 | out = [] 145 | x = x.type(self.conv1.weight.dtype) 146 | x = stem(x) 147 | x = self.layer1(x) 148 | if return_side_out: 149 | out.append(x) 150 | x = self.layer2(x) 151 | if return_side_out: 152 | out.append(x) 153 | x = self.layer3(x) 154 | if return_side_out: 155 | out.append(x) 156 | x = self.layer4(x) 157 | if return_side_out: 158 | out.append(x) 159 | x = self.attnpool(x, return_all_tokens) 160 | out.append(x) 161 | if len(out) == 1: 162 | return x 163 | else: 164 | return out 165 | 166 | 167 | class LayerNorm(nn.LayerNorm): 168 | """Subclass torch's LayerNorm to handle fp16.""" 169 | 170 | def forward(self, x: torch.Tensor): 171 | orig_type = x.dtype 172 | ret = super().forward(x.type(torch.float32)) 173 | return ret.type(orig_type) 174 | 175 | 176 | class QuickGELU(nn.Module): 177 | def forward(self, x: torch.Tensor): 178 | return x * torch.sigmoid(1.702 * x) 179 | 180 | 181 | class ResidualAttentionBlock(nn.Module): 182 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 183 | super().__init__() 184 | 185 | self.attn = nn.MultiheadAttention(d_model, n_head) 186 | self.ln_1 = LayerNorm(d_model) 187 | self.mlp = nn.Sequential(OrderedDict([ 188 | ("c_fc", nn.Linear(d_model, d_model * 4)), 189 | ("gelu", QuickGELU()), 190 | ("c_proj", nn.Linear(d_model * 4, d_model)) 191 | ])) 192 | self.ln_2 = LayerNorm(d_model) 193 | self.attn_mask = attn_mask 194 | 195 | def attention(self, x: torch.Tensor): 196 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if 
self.attn_mask is not None else None 197 | # pdb.set_trace() 198 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 199 | 200 | def forward(self, x: torch.Tensor): 201 | x = x + self.attention(self.ln_1(x)) 202 | x = x + self.mlp(self.ln_2(x)) 203 | return x 204 | 205 | 206 | class Transformer(nn.Module): 207 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 208 | super().__init__() 209 | self.width = width 210 | self.layers = layers 211 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 212 | 213 | def forward(self, x: torch.Tensor, return_intermediate_out: bool = False): 214 | if return_intermediate_out: 215 | output = [] 216 | for block in self.resblocks: 217 | x = block(x) 218 | output.append(x) 219 | return output 220 | 221 | return self.resblocks(x) 222 | 223 | 224 | class VisionTransformer(nn.Module): 225 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 226 | super().__init__() 227 | self.input_resolution = input_resolution 228 | self.patch_size = patch_size 229 | self.output_dim = output_dim 230 | self.width = width 231 | self.heads = heads 232 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 233 | 234 | scale = width ** -0.5 235 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 236 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 237 | self.ln_pre = LayerNorm(width) 238 | 239 | self.transformer = Transformer(width, layers, heads) 240 | 241 | self.ln_post = LayerNorm(width) 242 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 243 | 244 | def forward(self, x: torch.Tensor, return_all_tokens=False, return_all_final_tokens=False, **kwargs): 245 | 246 | B, nc, w, h = x.shape 247 | 248 | x = self.conv1(x) # shape = [*, width, grid, grid] 249 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 250 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 251 | 252 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 253 | 254 | if x.shape[1] != self.positional_embedding.shape[0]: 255 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) 256 | else: 257 | x = x + self.positional_embedding.to(x.dtype) 258 | 259 | x = self.ln_pre(x) 260 | 261 | x = x.permute(1, 0, 2) # NLD -> LND 262 | x = self.transformer(x) 263 | x = x.permute(1, 0, 2) # LND -> NLD 264 | 265 | if return_all_tokens: 266 | x = self.ln_post(x) 267 | return x[:, 1:, :] 268 | 269 | if return_all_final_tokens: 270 | return self.ln_post(x) @ self.proj 271 | 272 | x = self.ln_post(x[:, 0, :]) 273 | 274 | if self.proj is not None: 275 | x = x @ self.proj 276 | 277 | return x 278 | 279 | def interpolate_pos_encoding(self, x, w, h): 280 | # pdb.set_trace() 281 | npatch = x.shape[1] - 1 282 | N = self.positional_embedding.shape[0] - 1 # 256 for large 283 | if npatch == N and w == h: 284 | return self.positional_embedding 285 | class_pos_embed = self.positional_embedding[[0]] 286 | patch_pos_embed = self.positional_embedding[1:] 287 | dim = x.shape[-1] 288 | w0 = w // self.patch_size 289 | h0 = h // self.patch_size 290 | # we add a small number to avoid floating point error in the interpolation 291 | # see discussion at 
https://github.com/facebookresearch/dino/issues/8 292 | w0, h0 = w0 + 0.1, h0 + 0.1 293 | patch_pos_embed = nn.functional.interpolate( 294 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 295 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 296 | mode='bicubic', 297 | ) 298 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 299 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 300 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 301 | 302 | 303 | class CLIP(nn.Module): 304 | def __init__(self, 305 | embed_dim: int, # 512 306 | # vision 307 | image_resolution: int, # 224 308 | vision_layers: Union[Tuple[int, int, int, int], int], # 12 309 | vision_width: int, # 768 310 | vision_patch_size: int, # 16 311 | # text 312 | context_length: int, # 77 313 | vocab_size: int, # 49408 314 | transformer_width: int, # 512 315 | transformer_heads: int, # 8 316 | transformer_layers: int # 12 317 | ): 318 | super().__init__() 319 | # pdb.set_trace() 320 | self.context_length = context_length 321 | 322 | if isinstance(vision_layers, (tuple, list)): 323 | vision_heads = vision_width * 32 // 64 324 | self.visual = ModifiedResNet( 325 | layers=vision_layers, 326 | output_dim=embed_dim, 327 | heads=vision_heads, 328 | input_resolution=image_resolution, 329 | width=vision_width 330 | ) 331 | else: 332 | vision_heads = vision_width // 64 333 | self.visual = VisionTransformer( 334 | input_resolution=image_resolution, 335 | patch_size=vision_patch_size, 336 | width=vision_width, 337 | layers=vision_layers, 338 | heads=vision_heads, 339 | output_dim=embed_dim 340 | ) 341 | 342 | self.transformer = Transformer( 343 | width=transformer_width, 344 | layers=transformer_layers, 345 | heads=transformer_heads, 346 | attn_mask=self.build_attention_mask() 347 | ) 348 | 349 | self.vocab_size = vocab_size 350 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 351 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 352 | self.ln_final = LayerNorm(transformer_width) 353 | 354 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 355 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 356 | 357 | self.initialize_parameters() 358 | 359 | def initialize_parameters(self): 360 | nn.init.normal_(self.token_embedding.weight, std=0.02) 361 | nn.init.normal_(self.positional_embedding, std=0.01) 362 | 363 | if isinstance(self.visual, ModifiedResNet): 364 | if self.visual.attnpool is not None: 365 | std = self.visual.attnpool.c_proj.in_features ** -0.5 366 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 367 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 368 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 369 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 370 | 371 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 372 | for name, param in resnet_block.named_parameters(): 373 | if name.endswith("bn3.weight"): 374 | nn.init.zeros_(param) 375 | 376 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 377 | attn_std = self.transformer.width ** -0.5 378 | fc_std = (2 * self.transformer.width) ** -0.5 379 | for block in self.transformer.resblocks: 380 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 381 | nn.init.normal_(block.attn.out_proj.weight, 
std=proj_std) 382 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 383 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 384 | 385 | if self.text_projection is not None: 386 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 387 | 388 | def build_attention_mask(self): 389 | # lazily create causal attention mask, with full attention between the vision tokens 390 | # pytorch uses additive attention mask; fill with -inf 391 | mask = torch.empty(self.context_length, self.context_length) 392 | mask.fill_(float("-inf")) 393 | mask.triu_(1) # zero out the lower diagonal 394 | return mask 395 | 396 | @property 397 | def dtype(self): 398 | return self.visual.conv1.weight.dtype 399 | 400 | def encode_image(self, image, return_side_out=False, return_all_tokens=False, return_all_final_tokens=False, **kwargs): 401 | return self.visual(image.type(self.dtype), return_all_tokens, return_all_final_tokens, **kwargs) 402 | 403 | def encode_text(self, text, return_all_tokens=False, return_patch_tokens=False): 404 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 405 | 406 | x = x + self.positional_embedding.type(self.dtype) 407 | x = x.permute(1, 0, 2) # NLD -> LND 408 | x = self.transformer(x) 409 | x = x.permute(1, 0, 2) # LND -> NLD 410 | x = self.ln_final(x).type(self.dtype) 411 | 412 | if return_patch_tokens: 413 | return x 414 | # x.shape = [batch_size, n_ctx, transformer.width] 415 | # take features from the eot embedding (eot_token is the highest number in each sequence) 416 | if return_all_tokens: 417 | # pdb.set_trace() 418 | x = x @ self.text_projection 419 | else: 420 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 421 | return x 422 | 423 | def forward(self, image, text): 424 | image_features = self.encode_image(image) 425 | text_features = self.encode_text(text) 426 | 427 | # normalized features 428 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 429 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 430 | 431 | # cosine similarity as logits 432 | logit_scale = self.logit_scale.exp() 433 | logits_per_image = logit_scale * image_features @ text_features.t() 434 | logits_per_text = logits_per_image.t() 435 | 436 | # shape = [global_batch_size, global_batch_size] 437 | return logits_per_image, logits_per_text 438 | 439 | 440 | def convert_weights(model: nn.Module): 441 | """Convert applicable model parameters to fp16""" 442 | 443 | def _convert_weights_to_fp16(l): 444 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 445 | l.weight.data = l.weight.data.half() 446 | if l.bias is not None: 447 | l.bias.data = l.bias.data.half() 448 | 449 | if isinstance(l, nn.MultiheadAttention): 450 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 451 | tensor = getattr(l, attr) 452 | if tensor is not None: 453 | tensor.data = tensor.data.half() 454 | 455 | for name in ["text_projection", "proj"]: 456 | if hasattr(l, name): 457 | attr = getattr(l, name) 458 | if attr is not None: 459 | attr.data = attr.data.half() 460 | 461 | model.apply(_convert_weights_to_fp16) 462 | 463 | 464 | def build_model(state_dict: dict): 465 | vit = "visual.proj" in state_dict 466 | 467 | if vit: 468 | vision_width = state_dict["visual.conv1.weight"].shape[0] 469 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 470 | vision_patch_size = 
state_dict["visual.conv1.weight"].shape[-1] 471 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 472 | image_resolution = vision_patch_size * grid_size 473 | else: 474 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 475 | vision_layers = tuple(counts) 476 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 477 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 478 | vision_patch_size = None 479 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 480 | image_resolution = output_width * 32 481 | 482 | embed_dim = state_dict["text_projection"].shape[1] 483 | context_length = state_dict["positional_embedding"].shape[0] 484 | vocab_size = state_dict["token_embedding.weight"].shape[0] 485 | transformer_width = state_dict["ln_final.weight"].shape[0] 486 | transformer_heads = transformer_width // 64 487 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 488 | 489 | model = CLIP( 490 | embed_dim, 491 | image_resolution, vision_layers, vision_width, vision_patch_size, 492 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 493 | ) 494 | 495 | for key in ["input_resolution", "context_length", "vocab_size"]: 496 | if key in state_dict: 497 | del state_dict[key] 498 | 499 | convert_weights(model) 500 | model.load_state_dict(state_dict) 501 | return model.eval() 502 | --------------------------------------------------------------------------------