├── .gitignore
├── dino
│ ├── vqkd_teacher
│ │ ├── __init__.py
│ │ ├── clip
│ │ │ ├── __init__.py
│ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ ├── simple_tokenizer.py
│ │ │ ├── clip.py
│ │ │ └── model.py
│ │ └── dino.py
│ ├── save_dino_codes.py
│ ├── norm_ema_quantizer.py
│ ├── train_vq_dino_voc.py
│ └── modeling_vqkd.py
├── requirements.txt
├── utils
│ ├── video.py
│ ├── network_utils.py
│ ├── dmc.py
│ ├── utils.py
│ ├── drqv2.py
│ └── numpy_replay_buffer.py
├── README.md
├── musik
│ ├── musik_model.py
│ └── train_vq_musik_voc.py
└── vae
  ├── save_vq_codes.py
  ├── train_vq_vae_voc.py
  └── train_pixel_vq_vae_voc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/taming
2 | **/wandb
3 | **/__pycache__
4 | **.pth
--------------------------------------------------------------------------------
/dino/vqkd_teacher/__init__.py:
--------------------------------------------------------------------------------
1 | from .dino import *
2 | from .clip import *
--------------------------------------------------------------------------------
/dino/vqkd_teacher/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | from .model import *
3 |
--------------------------------------------------------------------------------
/dino/vqkd_teacher/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manantomar/video-occupancy-models/HEAD/dino/vqkd_teacher/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym==0.17.2+computecanada
2 | h5py==3.8.0+computecanada
3 | matplotlib==3.7.0+computecanada
4 | mujoco-py==2.0.2.10
5 | numpy==1.23.0+computecanada
6 | opencv-python==4.5.1.48+computecanada
7 | pyarrow==5.0.0
8 | seaborn==0.13.2+computecanada
9 | -e git+https://github.com/manantomar/taming-transformers.git@3d8c8ac03d12db0ed361ec74250022cc09af8cbc#egg=taming_transformers
10 | timm==0.9.16+computecanada
11 | torch==2.0.1+cu118
12 | torchaudio==2.0.2+cu118
13 | torchmetrics==1.2.1+computecanada
14 | torchvision==0.11.1+computecanada
15 | tqdm==4.66.1+computecanada
16 | transformers==4.33.3+computecanada
17 | typed-argument-parser==1.9.0
18 | vector_quantize_pytorch==1.12.12
19 | vit-pytorch==1.6.5
20 | wandb==0.15.9
21 | zipp==3.16.2+computecanada
22 |
--------------------------------------------------------------------------------
/utils/video.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import cv2
6 | import imageio
7 | import numpy as np
8 |
9 |
10 | class VideoRecorder:
11 | def __init__(self, root_dir, render_size=256, fps=20):
12 | if root_dir is not None:
13 | self.save_dir = root_dir / 'eval_video'
14 | self.save_dir.mkdir(exist_ok=True)
15 | else:
16 | self.save_dir = None
17 |
18 | self.render_size = render_size
19 | self.fps = fps
20 | self.frames = []
21 |
22 | def init(self, env, enabled=True):
23 | self.frames = []
24 | self.enabled = self.save_dir is not None and enabled
25 | self.record(env)
26 |
27 | def record(self, env):
28 | if self.enabled:
29 | if hasattr(env, 'physics'):
30 | frame = env.physics.render(height=self.render_size,
31 | width=self.render_size,
32 | camera_id=0)
33 | else:
34 | frame = env.render()
35 | self.frames.append(frame)
36 |
37 | def save(self, file_name):
38 | if self.enabled:
39 | path = self.save_dir / file_name
40 | imageio.mimsave(str(path), self.frames, fps=self.fps)
41 |
42 |
43 | class TrainVideoRecorder:
44 | def __init__(self, root_dir, render_size=256, fps=20):
45 | if root_dir is not None:
46 | self.save_dir = root_dir / 'train_video'
47 | self.save_dir.mkdir(exist_ok=True)
48 | else:
49 | self.save_dir = None
50 |
51 | self.render_size = render_size
52 | self.fps = fps
53 | self.frames = []
54 |
55 | def init(self, obs, enabled=True):
56 | self.frames = []
57 | self.enabled = self.save_dir is not None and enabled
58 | self.record(obs)
59 |
60 | def record(self, obs):
61 | if self.enabled:
62 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
63 | dsize=(self.render_size, self.render_size),
64 | interpolation=cv2.INTER_CUBIC)
65 | self.frames.append(frame)
66 |
67 | def save(self, file_name):
68 | if self.enabled:
69 | path = self.save_dir / file_name
70 | imageio.mimsave(str(path), self.frames, fps=self.fps)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Video Occupancy Models
2 |
3 | Code for the paper `Video Occupancy Models`. The repo includes three ways of quantizing the input video frames -- `vae`, which uses a VQ-VAE; `dino`, which uses quantized DINO; and `musik`, which uses quantized multi-step inverse dynamics.
4 |
5 |
6 |
7 | This is a PyTorch/GPU implementation of the paper [Video Occupancy Models](https://arxiv.org/pdf/2407.09533):
8 | ```
9 | @Article{VideoOccupancyModels2024,
10 | author = {Manan Tomar and Philippe Hansen-Estruch and Philip Bachman and Alex Lamb and John Langford and Matthew E. Taylor and Sergey Levine},
11 | journal = {arXiv:2407.09533},
12 | title = {Video Occupancy Models},
13 | year = {2024},
14 | }
15 | ```
16 | ### Installation
17 |
18 | The main packages are listed in the `requirements.txt` file. This code has been tested in a Python 3.8 virtual environment with the package versions given there.
19 |
20 | ### Model Checkpoints and Datasets
21 |
22 | The following table provides the pre-trained model checkpoints and datasets used in the paper:
23 |
24 |
25 | |  | Cheetah | Walker |
26 | | --- | --- | --- |
27 | | VQ-VAE fine-tuned model checkpoint | download | download |
28 | | DINO latent datasets | link | |
29 | | VQ-VAE latent datasets | link | link |
30 |
43 | ### VQ-VAE VOC
44 |
45 | You need to download the contents of this [folder](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) and place them one directory above this repo. The folder contains the config and checkpoint files for a VQ-VAE model from the [taming-transformers](https://github.com/CompVis/taming-transformers?tab=readme-ov-file) codebase.
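
For reference, here is a minimal sketch of how that folder is consumed, mirroring the setup in [save_vq_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/save_vq_codes.py) (the relative paths assume the folder sits one directory above the repo):

```
import os
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

taming_path = "../../"  # one directory above this repo
config = OmegaConf.load(os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml"))
model = VQModel(**config.model.params).to("cuda")

# Either start from the ImageNet-pretrained checkpoint ...
sd = torch.load(os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt"), map_location="cuda")["state_dict"]
model.load_state_dict(sd, strict=False)
# ... or load one of the fine-tuned checkpoints from the table above, e.g.
# model.load_state_dict(torch.load("../../walker_vq_model.pth"))
```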
46 |
47 | Run [train_vq_vae_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/train_vq_vae_voc.py) to train a VOC model on stored VQ-VAE latents. If you want to train both the VQ-VAE and the VOC model on pixel data, run [train_pixel_vq_vae_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/train_pixel_vq_vae_voc.py). If you want to create your own latents by training a VQ-VAE on a custom dataset, use the `collect_latents()` and `train_vq_latents()` methods in [save_vq_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/vae/save_vq_codes.py).
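
A rough usage sketch (the HDF5, taming, and checkpoint paths are hard-coded at the top of `VQAgent.__init__` and need to point at your own files):

```
from save_vq_codes import VQAgent

agent = VQAgent()           # builds the VQ-VAE from the taming config and loads a checkpoint
agent.collect_latents()     # writes a new HDF5 file that adds a 'vae_latents' dataset
# agent.train_vq_latents()  # optionally fine-tune the VQ-VAE on your own frames first
```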
48 |
49 | ### DINO VOC
50 |
51 | We use a quantized version of [DINO](https://arxiv.org/abs/2104.14294) from [BEiT-v2](https://github.com/microsoft/unilm/tree/master/beit2). You need to download this [dino model file](https://github.com/addf400/files/releases/download/BEiT-v2/vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth) and place it one directory above this repo.
52 |
53 | Run [train_vq_dino_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/dino/train_vq_dino_voc.py) to train a VOC model on stored DINO latents. Again, if you want to create your own latents by running a quantized version of DINO on a custom dataset, use the `collect_latents()` method in [save_dino_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/dino/save_dino_codes.py).
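
A minimal sketch of tokenizing a batch of frames with the quantized DINO encoder, following [save_dino_codes.py](https://github.com/manantomar/video-occupancy-models/blob/master/dino/save_dino_codes.py) (run from the `dino/` directory; the checkpoint path assumes the model file sits one directory above the repo, and the full preprocessing, e.g. augmentation and normalization, lives in `save_dino_codes.py`):

```
import torch
import torch.nn.functional as F
from timm.models import create_model

import modeling_vqkd  # registers the vqkd_* architectures with timm

model = create_model(
    'vqkd_encoder_base_decoder_1x768x12_dino',
    pretrained=True,
    pretrained_weight='../../vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth',
    as_tokenzer=True,  # argument name as spelled in the BEiT-v2 codebase
).to('cuda').eval()

frames = torch.rand(8, 3, 84, 84, device='cuda')  # dummy frame batch in [0, 1]
frames = F.interpolate(frames, size=224)          # the tokenizer expects 224x224 inputs
with torch.no_grad():
    quant, codes, _ = model.encode(frames)        # `codes` holds the discrete token indices
```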
54 |
55 | ### MUSIK VOC
56 |
57 | When action data is also available, we use a quantized multi-step inverse kinematics (MUSIK) objective to train the representation.
58 |
59 | Run [train_vq_musik_voc.py](https://github.com/manantomar/video-occupancy-models/blob/master/musik/train_vq_musik_voc.py) to train a VOC model along with the [MUSIK](https://arxiv.org/pdf/2211.00164) objective on pixel data.
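
Conceptually, the MUSIK objective encodes two frames that are `k` steps apart and predicts the action taken at the first frame from their concatenated embeddings. A rough sketch using `VQMUSIKModel` from [musik_model.py](https://github.com/manantomar/video-occupancy-models/blob/master/musik/musik_model.py); the regression loss here is only a stand-in for illustration, and the exact objective lives in the training script:

```
import torch
import torch.nn.functional as F

def musik_step(model, obs_t, obs_tk, action_t):
    # model: a VQMUSIKModel built from a taming-transformers ddconfig
    quant_t, _, _ = model.encode(obs_t)      # quantized features of frame t
    quant_tk, _, _ = model.encode(obs_tk)    # quantized features of frame t+k
    z_t = model.decode_linear(quant_t)       # flatten and project to embed_dim
    z_tk = model.decode_linear(quant_tk)
    pred = model.decoder(torch.cat([z_t, z_tk], dim=-1), eval=True)
    return F.mse_loss(pred, action_t)
```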
60 |
--------------------------------------------------------------------------------
/dino/vqkd_teacher/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/musik/musik_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pytorch_lightning as pl
4 |
5 | from taming.modules.diffusionmodules.model import Encoder, Decoder
6 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7 |
8 | from network_utils import MUSIKPredictor
9 |
10 | class VQMUSIKModel(pl.LightningModule):
11 | def __init__(self,
12 | ddconfig,
13 | lossconfig,
14 | n_embed,
15 | embed_dim,
16 | ckpt_path=None,
17 | ignore_keys=[],
18 | image_key="image",
19 | colorize_nlabels=None,
20 | monitor=None,
21 | remap=None,
22 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
23 | ):
24 | super().__init__()
25 | self.image_key = image_key
26 | self.encoder = Encoder(**ddconfig)
27 | self.decoder = MUSIKPredictor(2 * embed_dim, 6, 1.0).to('cuda')
28 | self.quant_linear = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(256 * 25, embed_dim))
29 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
30 | remap=remap, sane_index_shape=sane_index_shape)
31 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
32 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
33 | if ckpt_path is not None:
34 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
35 | self.image_key = image_key
36 | if colorize_nlabels is not None:
37 | assert type(colorize_nlabels)==int
38 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39 | if monitor is not None:
40 | self.monitor = monitor
41 |
42 | def init_from_ckpt(self, path, ignore_keys=list()):
43 | sd = torch.load(path, map_location="cpu")["state_dict"]
44 | keys = list(sd.keys())
45 | for k in keys:
46 | for ik in ignore_keys:
47 | if k.startswith(ik):
48 | print("Deleting key {} from state_dict.".format(k))
49 | del sd[k]
50 | self.load_state_dict(sd, strict=False)
51 | print(f"Restored from {path}")
52 |
53 | def encode(self, x, return_prequant=False):
54 | h = self.encoder(x)
55 | h = self.quant_conv(h)
56 | quant, emb_loss, info = self.quantize(h)
57 | if return_prequant:
58 | return quant, emb_loss, info, h
59 | else:
60 | return quant, emb_loss, info
61 |
62 | def decode(self, quant):
63 | quant = self.post_quant_conv(quant)
64 | dec = self.decoder(quant)
65 | return dec
66 |
67 | def decode_linear(self, quant):
68 | quant = self.post_quant_conv(quant)
69 | quant = quant.reshape(quant.shape[0], -1)
70 | dec = self.quant_linear(quant)
71 | return dec
72 |
73 | def decode_code(self, code_b):
74 | quant_b = self.quantize.embed_code(code_b)
75 | dec = self.decode(quant_b)
76 | return dec
77 |
78 | def forward(self, input):
79 | quant, diff, _ = self.encode(input)
80 | dec = self.decode(quant)
81 | return dec, diff
82 |
83 | def get_input(self, batch, k):
84 | x = batch[k]
85 | if len(x.shape) == 3:
86 | x = x[..., None]
87 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
88 | return x.float()
89 |
90 | def configure_optimizers(self):
91 | lr = self.learning_rate
92 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
93 | list(self.decoder.parameters())+
94 | list(self.quantize.parameters())+
95 | list(self.quant_conv.parameters())+
96 | list(self.post_quant_conv.parameters()),
97 | lr=lr, betas=(0.5, 0.9))
98 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
99 | lr=lr, betas=(0.5, 0.9))
100 | return [opt_ae, opt_disc], []
101 |
102 | def get_last_layer(self):
103 | return self.decoder.conv_out.weight
104 |
105 | def log_images(self, batch, **kwargs):
106 | log = dict()
107 | x = self.get_input(batch, self.image_key)
108 | x = x.to(self.device)
109 | xrec, _ = self(x)
110 | if x.shape[1] > 3:
111 | # colorize with random projection
112 | assert xrec.shape[1] > 3
113 | x = self.to_rgb(x)
114 | xrec = self.to_rgb(xrec)
115 | log["inputs"] = x
116 | log["reconstructions"] = xrec
117 | return log
118 |
119 | def to_rgb(self, x):
120 | assert self.image_key == "segmentation"
121 | if not hasattr(self, "colorize"):
122 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
123 | x = F.conv2d(x, weight=self.colorize)
124 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
125 | return x
--------------------------------------------------------------------------------
/dino/save_dino_codes.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '../utils/')
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from pathlib import Path
10 | from matplotlib import pyplot as plt
11 |
12 | import random
13 | import utils
14 | import os
15 | import wandb
16 | import math
17 |
18 | from torchvision.utils import make_grid
19 | from drqv2 import RandomShiftsAug
20 |
21 | from dm_env import specs
22 | import dmc
23 |
24 | import h5py
25 | from omegaconf import OmegaConf
26 |
27 | from torchvision import transforms as pth_transforms
28 | from timm.models import create_model
29 |
30 | import modeling_vqkd
31 |
32 | def get_parameter_names(model, forbidden_layer_types):
33 | """
34 | Returns the names of the model parameters that are not inside a forbidden layer.
35 | """
36 | result = []
37 | for name, child in model.named_children():
38 | result += [
39 | f"{name}.{n}"
40 | for n in get_parameter_names(child, forbidden_layer_types)
41 | if not isinstance(child, tuple(forbidden_layer_types))
42 | ]
43 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
44 | result += list(model._parameters.keys())
45 | return result
46 |
47 | class DINOAgent:
48 | def __init__(self, augmentation=RandomShiftsAug(pad=4)):
49 |
50 | wandb.init(project="Video Occupancy Models",
51 | entity=None, dir=os.getcwd())
52 |
53 | self.hdf5_file_path = "../../cheetah_train/org/3_cheetah_run_random.hdf5"
54 | self.save_path = "../../cheetah_train/dino_latents/3_random_dino_latents.hdf5"
55 | self.dino_path = '../../vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth'
56 |
57 | self.model = create_model(
58 | 'vqkd_encoder_base_decoder_1x768x12_dino',
59 | pretrained=True,
60 | pretrained_weight=self.dino_path,
61 | as_tokenzer=True,
62 | ).to('cuda').eval()
63 |
64 | # vq vae optimizer
65 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-4)
66 |
67 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda')
68 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda')
69 |
70 | self.device = 'cuda'
71 | # data augmentation
72 | self.aug = augmentation
73 |
74 | self.train()
75 |
76 | def train(self, training=True):
77 | self.training = training
78 |
79 | def process_obs(self, obs):
80 | obs = self.aug(obs.float())
81 | obs = F.interpolate(obs, size=224)
82 |
83 | obs_shape = obs.shape
84 |
85 | obs = torch.einsum('nchw->nhwc', obs / 255.) - self.imagenet_mean / self.imagenet_std
86 | obs = torch.einsum('nhwc->nchw', obs).reshape((obs_shape[0], 3, *obs_shape[2:]))
87 | return obs, obs_shape
88 |
89 | def collect_latents(self, step=0):
90 | with h5py.File(self.hdf5_file_path, 'r') as hf:
91 | fobs = hf['observation'][()]
92 | faction = hf['action'][()]
93 | fdiscount = hf['discount'][()]
94 | freward = hf['reward'][()]
95 | fstep_type = hf['step_type'][()]
96 |
97 | fobs = fobs
98 | batch_size = 25
99 |
100 | assert fobs.shape[0] % batch_size == 0
101 | iters = fobs.shape[0] / batch_size
102 |
103 | print("Total Obs are {}, Batch Size is {}, Total Iterations is {}".format(fobs.shape[0], batch_size, iters))
104 |
105 | dino_latents = []
106 | for i in range(int(iters)):
107 | obs = fobs[i*batch_size:(i + 1)*batch_size]
108 | obs = utils.to_torch([obs], self.device)[0]
109 |
110 | obs, obs_shape = self.process_obs(obs)
111 |
112 | # dino embed
113 | quant_context, y_context, context_loss = self.model.encode(obs)
114 |
115 | y_context = y_context.reshape((obs_shape[0], -1)).detach()
116 |
117 | # collect discrete vq indices
118 | dino_latents.append(y_context)
119 |
120 | dataset_names = ['action', 'discount', 'observation', 'reward', 'step_type', 'dino_latents']
121 | data_arrays = [faction, fdiscount, fobs, freward, fstep_type, torch.stack(dino_latents).reshape((batch_size * int(iters), -1)).cpu().long()]
122 | self.create_hdf5_file(self.save_path, dataset_names, data_arrays)
123 |
124 | def create_hdf5_file(self, file_path, dataset_names, data_arrays):
125 | """
126 | Create an HDF5 file and store multiple arrays in it.
127 |
128 | Parameters:
129 | - file_path: Path to the HDF5 file.
130 | - dataset_names: List of dataset names.
131 | - data_arrays: List of NumPy arrays to be stored in the HDF5 file.
132 | """
133 | # org dataset ['action', 'discount', 'observation', 'reward', 'step_type']
134 | with h5py.File(file_path, 'w') as hf:
135 | for dataset_name, data_array in zip(dataset_names, data_arrays):
136 | # Create a dataset in the HDF5 file
137 | hf.create_dataset(dataset_name, data=data_array)
138 |
139 |
140 | if __name__ == "__main__":
141 |
142 | agent = DINOAgent()
143 | agent.collect_latents()
144 |
--------------------------------------------------------------------------------
/utils/network_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import torch.distributed as dist
9 | import utils
10 |
11 | def weight_init(m):
12 | if isinstance(m, nn.Linear):
13 | nn.init.orthogonal_(m.weight.data)
14 | if hasattr(m.bias, 'data'):
15 | m.bias.data.fill_(0.0)
16 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
17 | gain = nn.init.calculate_gain('relu')
18 | nn.init.orthogonal_(m.weight.data, gain)
19 | if hasattr(m.bias, 'data'):
20 | m.bias.data.fill_(0.0)
21 |
22 | class RandomShiftsAug(nn.Module):
23 | def __init__(self, pad=4):
24 | super().__init__()
25 | self.pad = pad
26 |
27 | def forward(self, x):
28 | # x = T.Resize((x.shape[0], x.shape[1], 64, 64))
29 | n, c, h, w = x.size()
30 | assert h == w
31 | padding = tuple([self.pad] * 4)
32 | x = F.pad(x, padding, 'replicate')
33 | eps = 1.0 / (h + 2 * self.pad)
34 | arange = torch.linspace(-1.0 + eps,
35 | 1.0 - eps,
36 | h + 2 * self.pad,
37 | device=x.device,
38 | dtype=x.dtype)[:h]
39 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
40 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
41 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
42 |
43 | shift = torch.randint(0,
44 | 2 * self.pad + 1,
45 | size=(n, 1, 1, 2),
46 | device=x.device,
47 | dtype=x.dtype)
48 | shift *= 2.0 / (h + 2 * self.pad)
49 |
50 | grid = base_grid + shift
51 | return F.grid_sample(x,
52 | grid,
53 | padding_mode='zeros',
54 | align_corners=False)
55 |
56 |
57 | def get_parameter_names(model, forbidden_layer_types):
58 | """
59 | Returns the names of the model parameters that are not inside a forbidden layer.
60 | """
61 | result = []
62 | for name, child in model.named_children():
63 | result += [
64 | f"{name}.{n}"
65 | for n in get_parameter_names(child, forbidden_layer_types)
66 | if not isinstance(child, tuple(forbidden_layer_types))
67 | ]
68 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
69 | result += list(model._parameters.keys())
70 | return result
71 |
72 | class Predictor(nn.Module):
73 | def __init__(self, feature_dim):
74 | super(Predictor, self).__init__()
75 |
76 | self.l1 = nn.Linear(feature_dim, 256)
77 | self.l2 = nn.Linear(256, 256)
78 | self.l3 = nn.Linear(256, 1)
79 |
80 | def forward(self, state):
81 | a = F.relu(self.l1(state))
82 | a = F.relu(self.l2(a))
83 | return self.l3(a)
84 |
85 |
86 | class EMA():
87 | def __init__(self, beta):
88 | super().__init__()
89 | self.beta = beta
90 |
91 | def update_average(self, old, new):
92 | if old is None:
93 | return new
94 | return old * self.beta + (1 - self.beta) * new
95 |
96 |
97 | def update_moving_average(ema_updater, ma_model, current_model):
98 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
99 | old_weight, up_weight = ma_params.data, current_params.data
100 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
101 |
102 |
103 | class ConvPredictor(nn.Module):
104 | def __init__(self, obs_shape):
105 | super().__init__()
106 |
107 | assert len(obs_shape) == 3
108 | self.repr_dim = 32 * 35 * 35
109 | feature_dim = 50
110 | hidden_dim = 1024
111 |
112 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
113 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
114 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
115 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
116 | nn.ReLU())
117 |
118 | self.trunk = nn.Sequential(nn.Linear(self.repr_dim, feature_dim),
119 | nn.LayerNorm(feature_dim), nn.Tanh())
120 |
121 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
122 | nn.ReLU(inplace=True),
123 | nn.Linear(hidden_dim, hidden_dim),
124 | nn.ReLU(inplace=True),
125 | nn.Linear(hidden_dim, 21))
126 |
127 | # self.linear = nn.Sequential(nn.Linear(self.repr_dim, 256), nn.ReLU(),
128 | # nn.Linear(256, 256), nn.ReLU(),
129 | # nn.Linear(256, 21), nn.Tanh())
130 |
131 | self.apply(utils.weight_init)
132 |
133 | def forward(self, obs, std=0.1, eval=False):
134 | obs = obs / 255.0 - 0.5
135 | h = self.convnet(obs)
136 | h = h.view(h.shape[0], -1)
137 | h = self.trunk(h)
138 | mu = self.policy(h)
139 |
140 | std = torch.ones_like(mu) * std
141 |
142 | dist = utils.TruncatedNormal(mu, std)
143 | if eval:
144 | action = dist.mean
145 | else:
146 | action = dist.sample(clip=0.3)
147 | return action
148 |
149 | class projection_MLP(nn.Module):
150 | def __init__(self, in_dim, hidden_dim=256, out_dim=50): #256):
151 | super().__init__()
152 | # hidden_dim = in_dim
153 | self.layer1 = nn.Sequential(
154 | nn.Linear(in_dim, hidden_dim, bias=False),
155 | nn.BatchNorm1d(hidden_dim),
156 | nn.ReLU(inplace=True)
157 | )
158 | self.layer2 = nn.Linear(hidden_dim, out_dim)
159 | def forward(self, x):
160 | x = self.layer1(x)
161 | x = self.layer2(x)
162 | return x
163 |
164 | class InfoNCE(nn.Module):
165 | def __init__(self, feature_dim, action_dim, num_actions=1):
166 | super().__init__()
167 |
168 | self.train_samples = 256
169 | self.action_dim = action_dim
170 |
171 | self.projector = projection_MLP(feature_dim, 256, 1)
172 |
173 | # self.apply(weight_init)
174 |
175 | def forward(self, x1, x2, action, return_logits=False):
176 | self.device = x1.device
177 | # Generate N negatives, one for each element in the batch: (B, N, D).
178 | negatives = self.sample(x1.size(0), action.size(1))
179 |
180 | # Merge target and negatives: (B, N+1, D).
181 | targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1)
182 |
183 | # Generate a random permutation of the positives and negatives.
184 | permutation = torch.rand(targets.size(0), targets.size(1)).argsort(dim=1)
185 | targets = targets[torch.arange(targets.size(0)).unsqueeze(-1), permutation]
186 |
187 | # Get the original index of the positive. This will serve as the class label
188 | # for the loss.
189 | ground_truth = (permutation == 0).nonzero()[:, 1].to(self.device)
190 |
191 | # For every element in the mini-batch, there is 1 positive for which the EBM
192 | # should output a low energy value, and N negatives for which the EBM should
193 | # output high energy values.
194 | fused = torch.cat([x1.unsqueeze(1).expand(-1, targets.size(1), -1), x2.unsqueeze(1).expand(-1, targets.size(1), -1), targets], dim=-1)
195 | B, N, D = fused.size()
196 | fused = fused.reshape(B * N, D)
197 | out = self.projector(fused)
198 | energy = out.view(B, N)
199 |
200 | # Interpreting the energy as a negative logit, we can apply a cross entropy loss
201 | # to train the EBM.
202 | logits = -1.0 * energy
203 | loss = F.cross_entropy(logits, ground_truth.detach())
204 |
205 | if return_logits:
206 | return logits
207 |
208 | return loss
209 |
210 | def _sample(self, num_samples: int, action_size: int) -> torch.Tensor:
211 | """Helper method for drawing samples from the uniform random distribution."""
212 | size = (num_samples, action_size)
213 | samples = np.random.uniform(-1, 1, size=size)
214 | return torch.as_tensor(samples, dtype=torch.float32, device=self.device)
215 |
216 | def sample(self, batch_size: int, action_size: int) -> torch.Tensor:
217 | samples = self._sample(batch_size * self.train_samples, action_size)
218 | return samples.reshape(batch_size, self.train_samples, -1)
219 |
220 | class MUSIKPredictor(nn.Module):
221 | def __init__(self, state_dim, action_dim, max_action):
222 | super(MUSIKPredictor, self).__init__()
223 |
224 | feature_dim = 50
225 | hidden_dim = 1024
226 |
227 | self.trunk = nn.Sequential(nn.Linear(state_dim, feature_dim),
228 | nn.LayerNorm(feature_dim), nn.Tanh())
229 |
230 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
231 | nn.ReLU(inplace=True),
232 | nn.Linear(hidden_dim, hidden_dim),
233 | nn.ReLU(inplace=True),
234 | nn.Linear(hidden_dim, action_dim))
235 |
236 | self.max_action = max_action
237 | self.apply(utils.weight_init)
238 |
239 | def forward(self, state, std=0.1, eval=False):
240 | h = self.trunk(state)
241 | mu = self.policy(h)
242 | mu = torch.tanh(mu)
243 |
244 | std = torch.ones_like(mu) * std
245 |
246 | dist = utils.TruncatedNormal(mu, std)
247 | if eval:
248 | action = dist.mean
249 | else:
250 | action = dist.sample(clip=0.3)
251 | return action
252 |
253 |
254 |
--------------------------------------------------------------------------------
/vae/save_vq_codes.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '../utils/')
3 |
4 | from taming.models.vqgan import VQModel
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from pathlib import Path
12 | from matplotlib import pyplot as plt
13 |
14 | import random
15 | import utils
16 | import os
17 | import wandb
18 | import math
19 |
20 | from torchvision.utils import make_grid
21 | from drqv2 import RandomShiftsAug
22 |
23 | from dm_env import specs
24 | import dmc
25 |
26 | import h5py
27 | from omegaconf import OmegaConf
28 |
29 | def get_parameter_names(model, forbidden_layer_types):
30 | """
31 | Returns the names of the model parameters that are not inside a forbidden layer.
32 | """
33 | result = []
34 | for name, child in model.named_children():
35 | result += [
36 | f"{name}.{n}"
37 | for n in get_parameter_names(child, forbidden_layer_types)
38 | if not isinstance(child, tuple(forbidden_layer_types))
39 | ]
40 | # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
41 | result += list(model._parameters.keys())
42 | return result
43 |
44 | class VQAgent:
45 | def __init__(self, augmentation=RandomShiftsAug(pad=4)):
46 |
47 | wandb.init(project="Video Occupancy Models",
48 | entity=None, dir=os.getcwd())
49 |
50 | self.hdf5_dir_path = "../../walker_train/org/"
51 | self.hdf5_file_path = "../../walker_train/org/3_walker_walk_medium.hdf5"
52 | self.save_path = "../../walker_train/vq_latents/3_medium_vq_latents.hdf5"
53 | self.taming_path = "../../"
54 | self.vq_model_path = "../../walker_vq_model.pth"
55 |
56 | ################################################################################
57 | # #
58 | # VQ VAE Setup #
59 | # #
60 | ################################################################################
61 | config_path = os.path.join(self.taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml")
62 | config = OmegaConf.load(config_path)
63 |
64 | self.model = VQModel(**config.model.params).to('cuda')
65 |
66 | self.from_imagenet = False
67 | if self.from_imagenet:
68 | ckpt_path = os.path.join(self.taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt")
69 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"]
70 | missing, unexpected = self.model.load_state_dict(sd, strict=False)
71 | else:
72 | print("Loading fine-tuned VQ model...")
73 | self.model.load_state_dict(torch.load(self.vq_model_path))
74 |
75 | # vq vae optimizer
76 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-4)
77 |
78 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda')
79 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda')
80 |
81 | self.device = 'cuda'
82 | # data augmentation
83 | self.aug = augmentation
84 |
85 | self.train()
86 |
87 | def train(self, training=True):
88 | self.training = training
89 |
90 | def process_obs(self, obs):
91 | obs = self.aug(obs.float())
92 | obs = F.interpolate(obs, size=80)
93 |
94 | obs_shape = obs.shape
95 |
96 | obs = torch.einsum('nchw->nhwc', obs / 255.) - self.imagenet_mean / self.imagenet_std
97 | obs = torch.einsum('nhwc->nchw', obs).reshape((obs_shape[0], 3, *obs_shape[2:]))
98 | return obs, obs_shape
99 |
100 | def collect_latents(self, step=0):
101 | with h5py.File(self.hdf5_file_path, 'r') as hf:
102 | fobs = hf['observation'][()]
103 | faction = hf['action'][()]
104 | fdiscount = hf['discount'][()]
105 | freward = hf['reward'][()]
106 | fstep_type = hf['step_type'][()]
107 |
108 | fobs = fobs
109 | batch_size = 25
110 |
111 | assert fobs.shape[0] % batch_size == 0
112 | iters = fobs.shape[0] / batch_size
113 |
114 | print("Total Obs are {}, Batch Size is {}, Total Iterations is {}".format(fobs.shape[0], batch_size, iters))
115 |
116 | vae_latents = []
117 | for i in range(int(iters)):
118 | obs = fobs[i*batch_size:(i + 1)*batch_size]
119 | obs = utils.to_torch([obs], self.device)[0]
120 |
121 | obs, obs_shape = self.process_obs(obs)
122 |
123 | # vq embed
124 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
125 |
126 | # collect discrete vq indices
127 | y_context = info_context[2].view(obs_shape[0], -1).detach()
128 | vae_latents.append(y_context)
129 |
130 | dataset_names = ['action', 'discount', 'observation', 'reward', 'step_type', 'vae_latents']
131 | data_arrays = [faction, fdiscount, fobs, freward, fstep_type, torch.stack(vae_latents).reshape((batch_size * int(iters), -1)).cpu().long()]
132 | self.create_hdf5_file(self.save_path, dataset_names, data_arrays)
133 |
134 | def train_vq_latents(self):
135 | hdf5_files = [f for f in os.listdir(self.hdf5_dir_path) if f.endswith('.hdf5')]
136 |
137 | fobs = []
138 | # Loop through each file and read data
139 | for file in hdf5_files:
140 | file_path = os.path.join(self.hdf5_dir_path, file)
141 | with h5py.File(file_path, 'r') as hf:
142 | fobs.append(hf['observation'][()])
143 | fobs = np.stack(fobs)
144 | fobs = fobs.reshape((-1, *fobs.shape[2:]))
145 |
146 | batch_size = 64
147 |
148 | vae_latents = []
149 |
150 | for step in range(50000): # Num of updates
151 | idx = np.random.choice(fobs.shape[0], size=batch_size)
152 | obs = fobs[idx]
153 | obs = utils.to_torch([obs], self.device)[0]
154 |
155 | obs, obs_shape = self.process_obs(obs)
156 |
157 | # vq embed
158 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
159 |
160 | xrec, qloss = self.model.decode(quant_context), emb_loss_context
161 | vae_loss, log_dict_ae = self.model.loss(qloss, obs, xrec, 0, step, last_layer=self.model.get_last_layer(), split="train")
162 |
163 | # collect discrete vq indices
164 | y_context = info_context[2].view(obs_shape[0], -1).detach()
165 |
166 | if step % 100 == 0:
167 | with torch.no_grad():
168 | print("vae loss", vae_loss)
169 | viz_imgs = []
170 | viz_imgs.append(xrec)
171 | viz_imgs.append(obs)
172 |
173 | viz_imgs = torch.stack(viz_imgs)[:, :5]
174 | t, n, c, h, w = viz_imgs.shape
175 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs)
176 | viz_imgs = viz_imgs.reshape(t*n, c, h, w)
177 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True)
178 |
179 | img = wandb.Image(viz_img)
180 | wandb.log({f"Gamma Pred": img}, step=step)
181 |
182 | loss = vae_loss
183 | loss.backward()
184 |
185 | if step % 1 == 0:
186 | self.optimizer.step()
187 | self.optimizer.zero_grad()
188 |
189 | if step % 2000 == 0:
190 | self.save_vq_weights(step)
191 |
192 | def create_hdf5_file(self, file_path, dataset_names, data_arrays):
193 | """
194 | Create an HDF5 file and store multiple arrays in it.
195 |
196 | Parameters:
197 | - file_path: Path to the HDF5 file.
198 | - dataset_names: List of dataset names.
199 | - data_arrays: List of NumPy arrays to be stored in the HDF5 file.
200 | """
201 | # org dataset ['action', 'discount', 'observation', 'reward', 'step_type']
202 | with h5py.File(file_path, 'w') as hf:
203 | for dataset_name, data_array in zip(dataset_names, data_arrays):
204 | # Create a dataset in the HDF5 file
205 | hf.create_dataset(dataset_name, data=data_array)
206 |
207 | def save_vq_weights(self, step):
208 | torch.save(self.model.state_dict(), self.vq_model_path)
209 |
210 |
211 | if __name__ == "__main__":
212 |
213 | agent = VQAgent()
214 | agent.collect_latents()
215 |
216 | # agent.train_vq_latents()
217 |
--------------------------------------------------------------------------------
/dino/norm_ema_quantizer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2
4 | # Copyright (c) 2022 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Zhiliang Peng
7 | # Based on VQGAN code bases
8 | # https://github.com/CompVis/taming-transformers
9 | # --------------------------------------------------------'
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.distributed as distributed
15 | from einops import rearrange, repeat
16 |
17 |
18 | def l2norm(t):
19 | return F.normalize(t, p = 2, dim = -1)
20 |
21 | def ema_inplace(moving_avg, new, decay):
22 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
23 |
24 | def sample_vectors(samples, num):
25 | num_samples, device = samples.shape[0], samples.device
26 |
27 | if num_samples >= num:
28 | indices = torch.randperm(num_samples, device = device)[:num]
29 | else:
30 | indices = torch.randint(0, num_samples, (num,), device = device)
31 |
32 | return samples[indices]
33 |
34 | def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
35 | dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
36 |
37 | means = sample_vectors(samples, num_clusters)
38 |
39 | for _ in range(num_iters):
40 | if use_cosine_sim:
41 | dists = samples @ means.t()
42 | else:
43 | diffs = rearrange(samples, 'n d -> n () d') \
44 | - rearrange(means, 'c d -> () c d')
45 | dists = -(diffs ** 2).sum(dim = -1)
46 |
47 | buckets = dists.max(dim = -1).indices
48 | bins = torch.bincount(buckets, minlength = num_clusters)
49 | zero_mask = bins == 0
50 | bins_min_clamped = bins.masked_fill(zero_mask, 1)
51 |
52 | new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
53 | new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
54 | new_means = new_means / bins_min_clamped[..., None]
55 |
56 | if use_cosine_sim:
57 | new_means = l2norm(new_means)
58 |
59 | means = torch.where(zero_mask[..., None], means, new_means)
60 |
61 | return means, bins
62 |
63 |
64 | class EmbeddingEMA(nn.Module):
65 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
66 | super().__init__()
67 | self.num_tokens = num_tokens
68 | self.codebook_dim = codebook_dim
69 | self.decay = decay
70 | self.eps = eps
71 | if codebook_init_path == '':
72 | if not kmeans_init:
73 | weight = torch.randn(num_tokens, codebook_dim)
74 | weight = l2norm(weight)
75 | else:
76 | weight = torch.zeros(num_tokens, codebook_dim)
77 | self.register_buffer('initted', torch.Tensor([not kmeans_init]))
78 | else:
79 | print(f"load init codebook weight from {codebook_init_path}")
80 | codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
81 | weight = codebook_ckpt_weight.clone()
82 | self.register_buffer('initted', torch.Tensor([True]))
83 |
84 | self.weight = nn.Parameter(weight, requires_grad = False)
85 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
86 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
87 | # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
88 | self.update = True
89 |
90 | @torch.jit.ignore
91 | def init_embed_(self, data):
92 | if self.initted:
93 | return
94 | print("Performing Kemans init for codebook")
95 | embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim = True)
96 | self.weight.data.copy_(embed)
97 | self.cluster_size.data.copy_(cluster_size)
98 | self.initted.data.copy_(torch.Tensor([True]))
99 |
100 | def forward(self, embed_id):
101 | return F.embedding(embed_id, self.weight)
102 |
103 | def cluster_size_ema_update(self, new_cluster_size):
104 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
105 |
106 | def embed_avg_ema_update(self, new_embed_avg):
107 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
108 |
109 | def weight_update(self, num_tokens):
110 | n = self.cluster_size.sum()
111 | smoothed_cluster_size = (
112 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
113 | )
114 | #normalize embedding average with smoothed cluster size
115 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
116 | # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
117 | self.weight.data.copy_(embed_normalized)
118 |
119 | def norm_ema_inplace(moving_avg, new, decay):
120 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
121 | moving_avg.data.copy_(l2norm(moving_avg.data))
122 |
123 | class NormEMAVectorQuantizer(nn.Module):
124 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
125 | statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
126 | super().__init__()
127 | self.codebook_dim = embedding_dim
128 | self.num_tokens = n_embed
129 | self.beta = beta
130 | self.decay = decay
131 |
132 | # learnable = True if orthogonal_reg_weight > 0 else False
133 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
134 |
135 | self.statistic_code_usage = statistic_code_usage
136 | if statistic_code_usage:
137 | self.register_buffer('cluster_size', torch.zeros(n_embed))
138 | if distributed.is_available() and distributed.is_initialized():
139 | print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
140 | self.all_reduce_fn = distributed.all_reduce
141 | else:
142 | self.all_reduce_fn = nn.Identity()
143 |
144 | def reset_cluster_size(self, device):
145 | if self.statistic_code_usage:
146 | self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
147 | self.cluster_size = self.cluster_size.to(device)
148 |
149 | def forward(self, z):
150 | # reshape z -> (batch, height, width, channel) and flatten
151 | #z, 'b c h w -> b h w c'
152 | z = rearrange(z, 'b c h w -> b h w c')
153 | z = l2norm(z)
154 | z_flattened = z.reshape(-1, self.codebook_dim)
155 |
156 | self.embedding.init_embed_(z_flattened)
157 |
158 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
159 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \
160 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
161 |
162 | encoding_indices = torch.argmin(d, dim=1)
163 |
164 | z_q = self.embedding(encoding_indices).view(z.shape)
165 |
166 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
167 |
168 | if not self.training:
169 | with torch.no_grad():
170 | cluster_size = encodings.sum(0)
171 | self.all_reduce_fn(cluster_size)
172 | ema_inplace(self.cluster_size, cluster_size, self.decay)
173 |
174 | if self.training and self.embedding.update:
175 | #EMA cluster size
176 |
177 | bins = encodings.sum(0)
178 | self.all_reduce_fn(bins)
179 |
180 | # self.embedding.cluster_size_ema_update(bins)
181 | ema_inplace(self.cluster_size, bins, self.decay)
182 |
183 | zero_mask = (bins == 0)
184 | bins = bins.masked_fill(zero_mask, 1.)
185 |
186 | embed_sum = z_flattened.t() @ encodings
187 | self.all_reduce_fn(embed_sum)
188 |
189 | embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
190 | embed_normalized = l2norm(embed_normalized)
191 |
192 | embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
193 | embed_normalized)
194 | norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
195 |
196 | # compute loss for embedding
197 | loss = self.beta * F.mse_loss(z_q.detach(), z)
198 |
199 | # preserve gradients
200 | z_q = z + (z_q - z).detach()
201 |
202 | # reshape back to match original input shape
203 | #z_q, 'b h w c -> b c h w'
204 | z_q = rearrange(z_q, 'b h w c -> b c h w')
205 | return z_q, loss, encoding_indices
206 |
207 |
--------------------------------------------------------------------------------
/dino/vqkd_teacher/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
39 | }
40 |
41 |
42 | def _download(url: str, root: str):
43 | os.makedirs(root, exist_ok=True)
44 | filename = os.path.basename(url)
45 |
46 | expected_sha256 = url.split("/")[-2]
47 | download_target = os.path.join(root, filename)
48 |
49 | if os.path.exists(download_target) and not os.path.isfile(download_target):
50 | raise RuntimeError(f"{download_target} exists and is not a regular file")
51 |
52 | if os.path.isfile(download_target):
53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54 | return download_target
55 | else:
56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57 |
58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
60 | while True:
61 | buffer = source.read(8192)
62 | if not buffer:
63 | break
64 |
65 | output.write(buffer)
66 | loop.update(len(buffer))
67 |
68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
70 |
71 | return download_target
72 |
73 |
74 | def _convert_image_to_rgb(image):
75 | return image.convert("RGB")
76 |
77 |
78 | def _transform(n_px):
79 | return Compose([
80 | Resize(n_px, interpolation=BICUBIC),
81 | CenterCrop(n_px),
82 | _convert_image_to_rgb,
83 | ToTensor(),
84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
85 | ])
86 |
87 |
88 | def available_models() -> List[str]:
89 | """Returns the names of available CLIP models"""
90 | return list(_MODELS.keys())
91 |
92 |
93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
94 | """Load a CLIP model
95 |
96 | Parameters
97 | ----------
98 | name : str
99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
100 |
101 | device : Union[str, torch.device]
102 | The device to put the loaded model
103 |
104 | jit : bool
105 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
106 |
107 | download_root: str
108 | path to download the model files; by default, it uses "~/.cache/clip"
109 |
110 | Returns
111 | -------
112 | model : torch.nn.Module
113 | The CLIP model
114 |
115 | preprocess : Callable[[PIL.Image], torch.Tensor]
116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
117 | """
118 | if name in _MODELS:
119 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
120 | elif os.path.isfile(name):
121 | model_path = name
122 | else:
123 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
124 |
125 | try:
126 | # loading JIT archive
127 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
128 | state_dict = None
129 | except RuntimeError:
130 | # loading saved state dict
131 | if jit:
132 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
133 | jit = False
134 | state_dict = torch.load(model_path, map_location="cpu")
135 |
136 | if not jit:
137 | model = build_model(state_dict or model.state_dict()).to(device)
138 | if str(device) == "cpu":
139 | model.float()
140 | return model, _transform(model.visual.input_resolution)
141 |
142 | # patch the device names
143 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
144 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
145 |
146 | def patch_device(module):
147 | try:
148 | graphs = [module.graph] if hasattr(module, "graph") else []
149 | except RuntimeError:
150 | graphs = []
151 |
152 | if hasattr(module, "forward1"):
153 | graphs.append(module.forward1.graph)
154 |
155 | for graph in graphs:
156 | for node in graph.findAllNodes("prim::Constant"):
157 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
158 | node.copyAttributes(device_node)
159 |
160 | model.apply(patch_device)
161 | patch_device(model.encode_image)
162 | patch_device(model.encode_text)
163 |
164 | # patch dtype to float32 on CPU
165 | if str(device) == "cpu":
166 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
167 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
168 | float_node = float_input.node()
169 |
170 | def patch_float(module):
171 | try:
172 | graphs = [module.graph] if hasattr(module, "graph") else []
173 | except RuntimeError:
174 | graphs = []
175 |
176 | if hasattr(module, "forward1"):
177 | graphs.append(module.forward1.graph)
178 |
179 | for graph in graphs:
180 | for node in graph.findAllNodes("aten::to"):
181 | inputs = list(node.inputs())
182 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
183 | if inputs[i].node()["value"] == 5:
184 | inputs[i].node().copyAttributes(float_node)
185 |
186 | model.apply(patch_float)
187 | patch_float(model.encode_image)
188 | patch_float(model.encode_text)
189 |
190 | model.float()
191 |
192 | return model, _transform(model.input_resolution.item())
193 |
194 |
195 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
196 | """
197 | Returns the tokenized representation of given input string(s)
198 |
199 | Parameters
200 | ----------
201 | texts : Union[str, List[str]]
202 | An input string or a list of input strings to tokenize
203 |
204 | context_length : int
205 | The context length to use; all CLIP models use 77 as the context length
206 |
207 | truncate: bool
208 | Whether to truncate the text in case its encoding is longer than the context length
209 |
210 | Returns
211 | -------
212 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
213 | """
214 | if isinstance(texts, str):
215 | texts = [texts]
216 |
217 | sot_token = _tokenizer.encoder["<|startoftext|>"]
218 | eot_token = _tokenizer.encoder["<|endoftext|>"]
219 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
220 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
221 |
222 | for i, tokens in enumerate(all_tokens):
223 | if len(tokens) > context_length:
224 | if truncate:
225 | tokens = tokens[:context_length]
226 | tokens[-1] = eot_token
227 | else:
228 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
229 | result[i, :len(tokens)] = torch.tensor(tokens)
230 |
231 | return result
232 |
--------------------------------------------------------------------------------
/utils/dmc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | from collections import deque
6 | from typing import Any, NamedTuple
7 |
8 | import dm_env
9 | import numpy as np
10 | from dm_control import manipulation, suite
11 | from dm_control.suite.wrappers import action_scale, pixels
12 | from dm_env import StepType, specs
13 |
14 | # from envs.distracting_control.suite import distracting_wrapper
15 | #import envs.fb_mtenv_dmc as fb_mtenv_dmc
16 |
17 |
18 | def get_unique_int(difficulty: str) -> int:
19 | return int.from_bytes(f'{difficulty}_0'.encode(), 'little') % (2 ** 31)
20 |
21 |
22 | distracting_kwargs_lookup = {
23 | 'easy': {'difficulty': 'easy', 'fixed_distraction': False},
24 | 'medium': {'difficulty': 'medium', 'fixed_distraction': False},
25 | 'hard': {'difficulty': 'hard', 'fixed_distraction': False},
26 | 'fixed_easy': {'difficulty': 'easy', 'fixed_distraction': True, 'color_seed': get_unique_int('easy'),
27 | 'background_seed': get_unique_int('easy'), 'camera_seed': get_unique_int('easy')},
28 | 'fixed_medium': {'difficulty': 'medium', 'fixed_distraction': True, 'color_seed': get_unique_int('medium'),
29 | 'background_seed': get_unique_int('medium'), 'camera_seed': get_unique_int('medium')},
30 | 'fixed_hard': {'difficulty': 'hard', 'fixed_distraction': True, 'color_seed': get_unique_int('hard'),
31 | 'background_seed': get_unique_int('hard'), 'camera_seed': get_unique_int('hard')},
32 | }
33 |
34 | multitask_modes = [f'len_{i}' for i in range(1, 11, 1)]
35 |
36 |
37 | class ExtendedTimeStep(NamedTuple):
38 | step_type: Any
39 | reward: Any
40 | discount: Any
41 | observation: Any
42 | pixel_observation: Any
43 | action: Any
44 | latent: Any
45 | imp_action: Any
46 | k_step: Any
47 |
48 | def first(self):
49 | return self.step_type == StepType.FIRST
50 |
51 | def mid(self):
52 | return self.step_type == StepType.MID
53 |
54 | def last(self):
55 | return self.step_type == StepType.LAST
56 |
57 | #def __getitem__(self, attr):
58 | # return getattr(self, str(attr))
59 |
60 | class ExtendedTimeStepEval(NamedTuple):
61 | step_type: Any
62 | reward: Any
63 | discount: Any
64 | observation: Any
65 | action: Any
66 | latent: Any
67 | imp_action: Any
68 | k_step: Any
69 |
70 | def first(self):
71 | return self.step_type == StepType.FIRST
72 |
73 | def mid(self):
74 | return self.step_type == StepType.MID
75 |
76 | def last(self):
77 | return self.step_type == StepType.LAST
78 |
79 | #def __getitem__(self, attr):
80 | # return getattr(self, str(attr))
81 |
82 | class ActionRepeatWrapper(dm_env.Environment):
83 | def __init__(self, env, num_repeats):
84 | self._env = env
85 | self._num_repeats = num_repeats
86 |
87 | def step(self, action):
88 | reward = 0.0
89 | discount = 1.0
90 | for i in range(self._num_repeats):
91 | time_step = self._env.step(action)
92 | reward += (time_step.reward or 0.0) * discount
93 | discount *= time_step.discount
94 | if time_step.last():
95 | break
96 |
97 | return time_step._replace(reward=reward, discount=discount)
98 |
99 | def observation_spec(self):
100 | return self._env.observation_spec()
101 |
102 | def action_spec(self):
103 | return self._env.action_spec()
104 |
105 | def reset(self):
106 | return self._env.reset()
107 |
108 | def __getattr__(self, name):
109 | return getattr(self._env, name)
110 |
111 |
112 | class FrameStackWrapper(dm_env.Environment):
113 | def __init__(self, env, num_frames, pixels_key='pixels'):
114 | self._env = env
115 | self._num_frames = num_frames
116 | self._frames = deque([], maxlen=num_frames)
117 | self._pixels_key = pixels_key
118 |
119 | wrapped_obs_spec = env.observation_spec()
120 | assert pixels_key in wrapped_obs_spec
121 |
122 | pixels_shape = wrapped_obs_spec[pixels_key].shape
123 | # remove batch dim
124 | if len(pixels_shape) == 4:
125 | pixels_shape = pixels_shape[1:]
126 | self._obs_spec = specs.BoundedArray(shape=np.concatenate(
127 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
128 | dtype=np.uint8,
129 | minimum=0,
130 | maximum=255,
131 | name='observation')
132 |
133 | def _transform_observation(self, time_step):
134 | assert len(self._frames) == self._num_frames
135 | obs = np.concatenate(list(self._frames), axis=0)
136 | return time_step._replace(observation=obs)
137 |
138 | def _extract_pixels(self, time_step):
139 | try:
140 | pos = time_step.observation['position']
141 |         except (KeyError, IndexError, TypeError):
142 | pos = None
143 | pixels = time_step.observation[self._pixels_key]
144 | # remove batch dim
145 | if len(pixels.shape) == 4:
146 | pixels = pixels[0]
147 | return pixels.transpose(2, 0, 1).copy(), pos
148 |
149 | def reset(self):
150 | time_step = self._env.reset()
151 | pixels, pos = self._extract_pixels(time_step)
152 | for _ in range(self._num_frames):
153 | self._frames.append(pixels)
154 | return self._transform_observation(time_step), pos
155 |
156 | def step(self, action):
157 | time_step = self._env.step(action)
158 | pixels, pos = self._extract_pixels(time_step)
159 | self._frames.append(pixels)
160 | return self._transform_observation(time_step), pos
161 |
162 | def observation_spec(self):
163 | return self._obs_spec
164 |
165 | def action_spec(self):
166 | return self._env.action_spec()
167 |
168 | def __getattr__(self, name):
169 | return getattr(self._env, name)
170 |
171 |
172 | class ActionDTypeWrapper(dm_env.Environment):
173 | def __init__(self, env, dtype):
174 | self._env = env
175 | wrapped_action_spec = env.action_spec()
176 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
177 | dtype,
178 | wrapped_action_spec.minimum,
179 | wrapped_action_spec.maximum,
180 | 'action')
181 |
182 | def step(self, action):
183 | action = action.astype(self._env.action_spec().dtype)
184 | return self._env.step(action)
185 |
186 | def observation_spec(self):
187 | return self._env.observation_spec()
188 |
189 | def action_spec(self):
190 | return self._action_spec
191 |
192 | def reset(self):
193 | return self._env.reset()
194 |
195 | def __getattr__(self, name):
196 | return getattr(self._env, name)
197 |
198 |
199 | class ExtendedTimeStepWrapper(dm_env.Environment):
200 | def __init__(self, env):
201 | self._env = env
202 |
203 | def reset(self):
204 | time_step, pos = self._env.reset()
205 | return self._augment_time_step(time_step), pos
206 |
207 | def step(self, action):
208 | time_step, pos = self._env.step(action)
209 | return self._augment_time_step(time_step, action), pos
210 |
211 | def _augment_time_step(self, time_step, action=None):
212 | if action is None:
213 | action_spec = self.action_spec()
214 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
215 | return ExtendedTimeStepEval(observation=time_step.observation,
216 | step_type=time_step.step_type,
217 | action=action,
218 | reward=time_step.reward or 0.0,
219 | discount=time_step.discount or 1.0,
220 | latent=np.zeros(256),
221 | imp_action=np.zeros(84*84*1),
222 | k_step=0)
223 |
224 | def observation_spec(self):
225 | return self._env.observation_spec()
226 |
227 | def action_spec(self):
228 | return self._env.action_spec()
229 |
230 | def __getattr__(self, name):
231 | return getattr(self._env, name)
232 |
233 |
234 | def make(name, frame_stack, action_repeat, seed, distracting_mode: str = None, multitask_mode: str = None):
235 | pixel_hw = 84
236 | if 'offline' in name:
237 | name = '_'.join(name.split('_')[1:3])
238 | domain, task = name.split('_', 1)
239 | # overwrite cup to ball_in_cup
240 | domain = dict(cup='ball_in_cup').get(domain, domain)
241 |
242 | # make sure reward is not visualized
243 | if multitask_mode is None:
244 | if (domain, task) in suite.ALL_TASKS:
245 | env = suite.load(domain,
246 | task,
247 | task_kwargs={'random': seed},
248 | visualize_reward=False)
249 | pixels_key = 'pixels'
250 | else:
251 | name = f'{domain}_{task}_vision'
252 | env = manipulation.load(name, seed=seed)
253 | pixels_key = 'front_close'
254 | else:
255 | assert multitask_mode in multitask_modes, 'Unrecognised length setting'
256 | idx = multitask_mode.split('_', 1)[1]
257 |
258 | if domain == 'walker' and task == 'walk':
259 | xml = f'len_{idx}'
260 | elif domain == 'cheetah' and task == 'run':
261 | xml = f'torso_length_{idx}'
262 | else:
263 |             raise NotImplementedError(f'No multitask variant for {domain}_{task}')
264 |
265 | env = fb_mtenv_dmc.load(
266 | domain_name=domain,
267 | task_name=task,
268 | task_kwargs={'xml_file_id': xml, 'random': seed},
269 | visualize_reward=False,
270 | )
271 | pixels_key = 'pixels'
272 |
273 | # add wrappers
274 | env = ActionDTypeWrapper(env, np.float32)
275 | env = ActionRepeatWrapper(env, action_repeat)
276 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
277 |     # add renderings for classical tasks
278 | if (domain, task) in suite.ALL_TASKS:
279 | # zoom in camera for quadruped
280 | camera_id = dict(quadruped=2).get(domain, 0)
281 | render_kwargs = dict(height=pixel_hw, width=pixel_hw, camera_id=camera_id)
282 | if distracting_mode is not None:
283 | assert distracting_mode in distracting_kwargs_lookup, 'Unrecognised distraction'
284 | kwargs = distracting_kwargs_lookup[distracting_mode]
285 | kwargs['pixels_only'] = False
286 | kwargs['render_kwargs'] = render_kwargs
287 | kwargs['background_dataset_path'] = "/home/manant/scratch/DAVIS/JPEGImages/480p/" #"DAVIS/JPEGImages/480p/"
288 | env = distracting_wrapper(
289 | env,
290 | domain,
291 | **kwargs
292 | )
293 | else:
294 | env = pixels.Wrapper(env,
295 | pixels_only=False,
296 | render_kwargs=render_kwargs)
297 | # stack several frames
298 | env = FrameStackWrapper(env, frame_stack, pixels_key)
299 | env = ExtendedTimeStepWrapper(env)
300 | return env
301 |
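302 | # Hedged usage sketch for make() above (requires dm_control / MuJoCo rendering):
303 | # note that reset() and step() return a (time_step, pos) tuple because of
304 | # FrameStackWrapper / ExtendedTimeStepWrapper.
305 | if __name__ == "__main__":
306 |     env = make('cheetah_run', frame_stack=3, action_repeat=2, seed=0)
307 |     time_step, pos = env.reset()
308 |     assert time_step.observation.shape == (9, 84, 84)  # 3 stacked RGB frames
309 |     action = np.zeros(env.action_spec().shape, dtype=np.float32)
310 |     time_step, pos = env.step(action)
311 |     print(time_step.reward, time_step.observation.dtype)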
--------------------------------------------------------------------------------
/dino/train_vq_dino_voc.py:
--------------------------------------------------------------------------------
1 | # from torchvision import transforms as pth_transforms
2 | from timm.models import create_model
3 | import modeling_vqkd
4 |
5 | import sys
6 | sys.path.insert(0, '../utils/')
7 |
8 | import os
9 | import math
10 | import copy
11 | import wandb
12 | import random
13 | import utils
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from omegaconf import OmegaConf
20 |
21 | from pathlib import Path
22 | from matplotlib import pyplot as plt
23 | from torchvision.utils import make_grid
24 |
25 | from numpy_replay_buffer import EfficientReplayBuffer
26 | from utils import load_offline_dataset_into_buffer
27 |
28 | # import dmc
29 | from dm_env import specs
30 |
31 | from transformers.optimization import get_scheduler
32 | from transformers import GPT2LMHeadModel, GPT2Config
33 |
34 | from network_utils import get_parameter_names, update_moving_average, Predictor, EMA, RandomShiftsAug
35 |
36 | ##################
37 | obs_resize = 224
38 |
39 | offline_dir = "../../cheetah_train/dino_latents/"
40 | dino_path = '../../vqkd_encoder_base_decoder_1x768x12_dino-663c55d7.pth'
41 | save_dir_path = "./"
42 |
43 | batch_size = 32
44 | frame_stack = 2
45 | device = 'cuda'
46 | ##################
47 | # - Pretrained DINO + Quantization
48 | ##################
49 |
50 | class TFAgent:
51 | def __init__(self, discount=0.8, augmentation=RandomShiftsAug(pad=4)):
52 |
53 | wandb.init(project="Video Occupancy Models",
54 | id="voc_dino_gamma_{}".format(discount),
55 | entity=None, dir=os.getcwd())
56 |
57 | data_specs = (specs.BoundedArray(shape=(9, 84, 84), dtype=np.uint8, name='observation', minimum=0, maximum=255),
58 | specs.BoundedArray(shape=(6,), dtype=np.float32, name='action', minimum=-1.0, maximum=1.0),
59 | specs.Array((1,), np.float32, 'reward'),
60 | specs.Array((1,), np.float32, 'discount'))
61 |
62 | self.discount = discount
63 | self.batch_size = batch_size
64 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps
65 | self.codebook_size = 8192
66 |
67 | ######### train on VQ latents #########
68 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, frame_stack, False, data_specs)
69 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, frame_stack, 1000000, latent_style="dino")
70 | ##########################
71 |
72 | self.model = create_model(
73 | 'vqkd_encoder_base_decoder_1x768x12_dino',
74 | pretrained=True,
75 | pretrained_weight=dino_path,
76 | as_tokenzer=True,
77 | ) #.to('cuda').eval()
78 | self.model.to(device).eval()
79 |
80 | ######### gpt setup #########
81 |         configuration = GPT2Config(vocab_size=8192, n_positions=196 * frame_stack * 2, n_layer=4, n_head=4, n_embd=128) # nano-gpt (GPT2Config's keyword is n_embd)
82 | self.gpt = GPT2LMHeadModel(configuration).to(device)
83 | self.gpt_target = copy.deepcopy(self.gpt)
84 | self.gpt_target.generation_config.output_scores = True
85 | self.target_ema_updater = EMA(0.9)
86 | ##########################
87 |
88 | ######### optimizer setup #########
89 | self.reward_predictor = Predictor(32 * 14 * 14 * frame_stack).to(device)
90 | self.gpt_optimizer = torch.optim.AdamW(list(self.get_grouped_params(self.gpt)), lr=3e-4)
91 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()), lr=3e-4)
92 |
93 | num_training_steps = 100000
94 | self.warmup_ratio = 0.05
95 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio)
96 | self.lr_scheduler = get_scheduler(
97 | "cosine",
98 | optimizer=self.gpt_optimizer,
99 | num_warmup_steps=warmup_steps,
100 | num_training_steps=num_training_steps,
101 | )
102 | ##########################
103 |
104 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
105 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to(device)
106 |
107 | self.device = device
108 | self.aug = augmentation
109 |
110 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000]
111 | self.train()
112 |
113 | def get_grouped_params(self, model):
114 | decay_parameters = get_parameter_names(model, [nn.LayerNorm])
115 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
116 | optimizer_grouped_parameters = [
117 | {
118 | "params": [
119 | p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)
120 | ],
121 | "weight_decay": 0.1,
122 | },
123 | {
124 | "params": [
125 | p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)
126 | ],
127 | "weight_decay": 0.0,
128 | },
129 | ]
130 | return optimizer_grouped_parameters
131 |
132 | def train(self, training=True):
133 | self.training = training
134 |
135 | def preprocess_obs(self, obs):
136 | obs = F.interpolate(obs, size=obs_resize)
137 |
138 | if len(obs.shape) == 3:
139 | obs = obs.unsqueeze(0)
140 |
141 |         try:
142 |             assert len(obs.shape) == 4 # B x C x H x W
143 |             org_obs_shape = obs.shape
144 |             # ImageNet-normalize each stacked frame: (x - mean) / std
145 |             obs = torch.stack([(torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean) / self.imagenet_std for i in range(frame_stack)])
146 |             obs = torch.einsum('snhwc->nschw', obs).reshape((org_obs_shape[0] * frame_stack, 3, *org_obs_shape[2:]))
147 |         except AssertionError:
148 |             assert len(obs.shape) == 5 # T x B x C x H x W
149 |             org_obs_shape = t, b, c, h, w = obs.shape
150 |             obs = torch.stack([(torch.einsum('tnchw->tnhwc', obs[:, :, i*3:3+i*3] / 255.) - self.imagenet_mean) / self.imagenet_std for i in range(frame_stack)])
151 |             obs = torch.einsum('stnhwc->tnschw', obs).reshape((t, b * frame_stack, 3, h, w))
152 | return obs, org_obs_shape
153 |
154 | def update(self, step=0):
155 | metrics = dict()
156 |
157 | batch, indices = next(self.replay_buffer)
158 | obs, action, reward, discount, next_obs, _, _, _, obs_k = utils.to_torch(batch, self.device)
159 |
160 | y_context = obs.long()
161 | y_target = obs_k.long()
162 | obs_shape = obs.shape
163 |
164 | quant_target = self.model.get_codebook_entry(y_target.reshape(obs_shape[0]*frame_stack, -1))
165 |
166 | quant_context = self.model.get_codebook_entry(y_context.reshape(obs_shape[0]*frame_stack, -1))
167 |
168 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1))
169 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean()
170 |
171 | # generate target
172 | with torch.no_grad():
173 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100)
174 | p_t = p_t[:, -y_target.shape[-1]:]
175 |
176 | # gamma sampling
177 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device)
178 | prob = torch.bernoulli(gamma)
179 | p_target = torch.zeros_like(y_target)
180 |
181 | # with prob 1-gamma, sample from next state
182 | p_c_idx = torch.nonzero(1 - prob)
183 | p_target[p_c_idx] = y_target[p_c_idx]
184 |
185 | # with prob gamma, sample from bootstrapped model
186 | p_t_idx = torch.nonzero(prob)
187 | p_target[p_t_idx] = p_t[p_t_idx]
188 |
189 | # gpt predictions
190 | inp = torch.cat([y_context, p_target], dim=1)
191 | # mask_ids = torch.cat([context_mask_ids, target_mask_ids], dim=1)
192 | outputs = self.gpt(inp, labels=inp)
193 | gpt_loss = outputs.loss
194 |
195 | loss = gpt_loss + reward_loss
196 | loss.backward()
197 |
198 |         # gradient accumulation (modulus of 1 means the optimizers step every update)
199 | if step % 1 == 0:
200 | self.optimizer.step()
201 | self.gpt_optimizer.step()
202 |
203 | self.optimizer.zero_grad()
204 | self.gpt_optimizer.zero_grad()
205 |
206 | self.lr_scheduler.step()
207 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt)
208 |
209 | # visualize predictions
210 | if step % 200 == 0:
211 | with torch.no_grad():
212 | # sample a batch of traj and corresponding values
213 | batch, indices = self.replay_buffer.sample_spr(jumps=self.nsteps)
214 | _, _, _, _, _, all_obs, all_pixel_obs, _, values = utils.to_torch(batch, self.device)
215 |
216 | # preprocess first obs from traj
217 | obs = all_obs[0]
218 | obs_shape = obs.shape
219 |
220 | # embed first obs
221 | y_context = obs.long()
222 |
223 | value_loss = self.get_value_estimates(y_context, values, obs_shape)
224 | wandb.log({"value loss": value_loss}, step=step)
225 |
226 | wandb.log({"gpt loss": gpt_loss}, step=step)
227 | wandb.log({"reward loss": reward_loss}, step=step)
228 |
229 | # save gpt model
230 | if step in self.saving_iter:
231 | print("saving gpt weights...")
232 | self.save_gpt_weights(step)
233 |
234 | return metrics
235 |
236 | def save_gpt_weights(self, step):
237 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "dino_nanogpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step)))
238 |
239 | def get_value_estimates(self, y_context, values, obs_shape):
240 | # Take a state, get samples from the gamma distribution,
241 | # Run the reward predictor through these to get value estimates
242 | # Get ground truth value estimates by simply taking discounted sum of rewards
243 | # Compare these for different states
244 |
245 | num_gamma_samples = 100
246 | values_pred = []
247 |
248 | for i in range(num_gamma_samples):
249 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100) #, kwargs={'token_type_ids': context_mask_ids})
250 | p_t = outputs.sequences[:, -y_context.shape[-1]:]
251 |
252 | # quant = self.model.quantize.get_codebook_entry(p_t, None)
253 | # quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
254 | quant = self.model.get_codebook_entry(p_t.reshape(obs_shape[0]*frame_stack, -1))
255 |
256 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1))
257 |
258 |         values_pred = torch.stack(values_pred).sum(0) / (num_gamma_samples * (1 - self.discount))
259 |
260 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean()
261 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5])
262 |
263 | return value_estimation_loss
264 |
265 |
266 | import argparse
267 |
268 | if __name__ == "__main__":
269 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.')
270 |
271 | parser.add_argument('--discount', type=float, default=0.8, help='discount')
272 | args = parser.parse_args()
273 |
274 | agent = TFAgent(discount=args.discount)
275 |
276 | agent.optimizer.zero_grad()
277 | agent.gpt_optimizer.zero_grad()
278 |
279 | for step in range(100000):
280 | agent.update(step)
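281 |
282 | # Hedged standalone sketch of the gamma-sampling step inside TFAgent.update():
283 | # with probability (1 - gamma) the GPT target is the next state's tokens and
284 | # with probability gamma it is a sample from the bootstrapped target model,
285 | # i.e. a geometric mixture over future states. `mix_gamma_targets` is an
286 | # illustrative helper, not used by the training loop above.
287 | def mix_gamma_targets(y_target, p_t, gamma):
288 |     # y_target, p_t: (B, L) long tensors of codebook indices; gamma in [0, 1)
289 |     prob = torch.bernoulli(gamma * torch.ones(y_target.shape[0], device=y_target.device))
290 |     # rows where prob == 1 take the bootstrapped sample, the rest keep y_target
291 |     return torch.where(prob.bool().unsqueeze(1), p_t, y_target)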
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import random
6 | import re
7 | import time
8 |
9 | import numpy as np
10 | import h5py
11 | from collections import deque
12 | import dmc
13 | from dm_env import StepType
14 | from numpy_replay_buffer import EfficientReplayBuffer
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torch import distributions as pyd
19 | from torch.distributions.utils import _standard_normal
20 | import torchvision.transforms as T
21 |
22 | class eval_mode:
23 | def __init__(self, *models):
24 | self.models = models
25 |
26 | def __enter__(self):
27 | self.prev_states = []
28 | for model in self.models:
29 | self.prev_states.append(model.training)
30 | model.train(False)
31 |
32 | def __exit__(self, *args):
33 | for model, state in zip(self.models, self.prev_states):
34 | model.train(state)
35 | return False
36 |
37 |
38 | def set_seed_everywhere(seed):
39 | torch.manual_seed(seed)
40 | if torch.cuda.is_available():
41 | torch.cuda.manual_seed_all(seed)
42 | np.random.seed(seed)
43 | random.seed(seed)
44 |
45 |
46 | def soft_update_params(net, target_net, tau):
47 | for param, target_param in zip(net.parameters(), target_net.parameters()):
48 | target_param.data.copy_(tau * param.data +
49 | (1 - tau) * target_param.data)
50 |
51 |
52 | def to_torch(xs, device):
53 | return tuple(torch.as_tensor(x, device=device) for x in xs)
54 |
55 |
56 | def weight_init(m):
57 | if isinstance(m, nn.Linear):
58 | nn.init.orthogonal_(m.weight.data)
59 | if hasattr(m.bias, 'data'):
60 | m.bias.data.fill_(0.0)
61 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
62 | gain = nn.init.calculate_gain('relu')
63 | nn.init.orthogonal_(m.weight.data, gain)
64 | if hasattr(m.bias, 'data'):
65 | m.bias.data.fill_(0.0)
66 |
67 |
68 | class Until:
69 | def __init__(self, until, action_repeat=1):
70 | self._until = until
71 | self._action_repeat = action_repeat
72 |
73 | def __call__(self, step):
74 | if self._until is None:
75 | return True
76 | until = self._until // self._action_repeat
77 | return step < until
78 |
79 |
80 | class Every:
81 | def __init__(self, every, action_repeat=1):
82 | self._every = every
83 | self._action_repeat = action_repeat
84 |
85 | def __call__(self, step):
86 | if self._every is None:
87 | return False
88 | every = self._every // self._action_repeat
89 | if step % every == 0:
90 | return True
91 | return False
92 |
93 |
94 | class Timer:
95 | def __init__(self):
96 | self._start_time = time.time()
97 | self._last_time = time.time()
98 |
99 | def reset(self):
100 | elapsed_time = time.time() - self._last_time
101 | self._last_time = time.time()
102 | total_time = time.time() - self._start_time
103 | return elapsed_time, total_time
104 |
105 | def total_time(self):
106 | return time.time() - self._start_time
107 |
108 |
109 | class TruncatedNormal(pyd.Normal):
110 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
111 | super().__init__(loc, scale, validate_args=False)
112 | self.low = low
113 | self.high = high
114 | self.eps = eps
115 |
116 | def _clamp(self, x):
117 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
118 | x = x - x.detach() + clamped_x.detach()
119 | return x
120 |
121 | def sample(self, clip=None, sample_shape=torch.Size()):
122 | shape = self._extended_shape(sample_shape)
123 | eps = _standard_normal(shape,
124 | dtype=self.loc.dtype,
125 | device=self.loc.device)
126 | eps *= self.scale
127 | if clip is not None:
128 | eps = torch.clamp(eps, -clip, clip)
129 | x = self.loc + eps
130 | return self._clamp(x)
131 |
132 |
133 | def schedule(schdl, step):
134 | try:
135 | return float(schdl)
136 | except ValueError:
137 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
138 | if match:
139 | init, final, duration = [float(g) for g in match.groups()]
140 | mix = np.clip(step / duration, 0.0, 1.0)
141 | return (1.0 - mix) * init + mix * final
142 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
143 | if match:
144 | init, final1, duration1, final2, duration2 = [
145 | float(g) for g in match.groups()
146 | ]
147 | if step <= duration1:
148 | mix = np.clip(step / duration1, 0.0, 1.0)
149 | return (1.0 - mix) * init + mix * final1
150 | else:
151 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
152 | return (1.0 - mix) * final1 + mix * final2
153 | raise NotImplementedError(schdl)
154 |
155 |
156 | step_type_lookup = {
157 | 0: StepType.FIRST,
158 | 1: StepType.MID,
159 | 2: StepType.LAST
160 | }
161 |
162 |
163 | def load_offline_dataset_into_buffer(offline_dir, replay_buffer, agent, frame_stack, replay_buffer_size, future_sampling_steps=2, latent_style="vae"):
164 | filenames = sorted(offline_dir.glob('*.hdf5'))
165 | num_steps = 0
166 |     print("Loading offline files from", offline_dir, ":", [f.name for f in filenames])
167 | for filename in filenames:
168 | #try:
169 | episodes = h5py.File(filename, 'r')
170 | episodes = {k: episodes[k][:] for k in episodes.keys()}
171 | add_offline_data_to_buffer(episodes, replay_buffer, agent, framestack=frame_stack, future_sampling_steps=future_sampling_steps, latent_style=latent_style)
172 | length = episodes['reward'].shape[0]
173 | num_steps += length
174 | #except Exception as e:
175 | # print(f'Could not load episode {str(filename)}: {e}')
176 | # continue
177 | print("Loaded {} offline timesteps so far...".format(int(num_steps)))
178 | if num_steps >= replay_buffer_size:
179 | break
180 | print("Finished, loaded {} timesteps.".format(int(num_steps)))
181 |
182 |
183 | def add_offline_data_to_buffer(offline_data: dict, replay_buffer: EfficientReplayBuffer, agent, framestack: int = 3, future_sampling_steps: int = 2, latent_style: str = "vae"):
184 | offline_data_length = offline_data['reward'].shape[0]
185 | for v in offline_data.values():
186 | assert v.shape[0] == offline_data_length
187 | done_list = np.argwhere(offline_data['step_type']==2)
188 | assert len(done_list) > 1
189 | interval = done_list[1] - done_list[0]
190 | now = -1
191 | max_k = future_sampling_steps #15
192 |
193 | resize = T.Compose([
194 | T.ToPILImage(),
195 | T.Resize((64, 64)),
196 | # T.ToTensor()
197 | ])
198 |
199 | for idx in range(offline_data_length):
200 | time_step = get_timestep_from_idx(offline_data, idx, latent_style)
201 | if not time_step.first():
202 | now += 1
203 | # stacked_frames.append(np.asarray(resize(time_step.observation.reshape(84, 84, -1))).reshape(-1, 64, 64))
204 | stacked_frames.append(time_step.observation)
205 | stacked_pixel_frames.append(time_step.pixel_observation)
206 | time_step_stack = time_step._replace(observation=np.concatenate(stacked_frames, axis=0), pixel_observation=np.concatenate(stacked_pixel_frames, axis=0))
207 | with torch.no_grad(): #, eval_mode(agent):
208 | ob = torch.as_tensor(np.concatenate(stacked_frames, axis=0), device='cuda')
209 | pixel_ob = torch.as_tensor(np.concatenate(stacked_pixel_frames, axis=0), device='cuda')
210 | # imp_action = torch.abs(agent.actor(agent.encoder(ob.unsqueeze(0)).squeeze(0)))
211 | #act = torch.as_tensor(imp_action.reshape(9, 84, 84), device=agent.device)
212 | #new_ob = torch.clamp(torch.sqrt(act) * ob + torch.sqrt(1 - act) * 255 * torch.randn(9, 84, 84, device=agent.device), min=0, max=255).type(torch.int64)
213 | # latent = agent.get_latent(ob, imp_action).squeeze(0) #agent.encoder(new_ob.unsqueeze(0)).squeeze(0) #agent.get_latent(state, action, latent=No)
214 | # time_step_stack = time_step_stack._replace(latent=latent.cpu().detach().numpy())
215 | # time_step_stack = time_step_stack._replace(imp_action=imp_action.cpu().numpy())
216 | rindex = min(interval-1, now+max_k)
217 | rindex = rindex - now
218 | time_step_stack = time_step_stack._replace(k_step=rindex)
219 | replay_buffer.add(time_step_stack)
220 | else:
221 | now = -1
222 | stacked_frames = deque(maxlen=framestack)
223 | stacked_pixel_frames = deque(maxlen=framestack)
224 | while len(stacked_frames) < framestack:
225 | # stacked_frames.append(np.asarray(resize(time_step.observation.reshape(84, 84, -1))).reshape(-1, 64, 64))
226 | stacked_frames.append(time_step.observation)
227 | stacked_pixel_frames.append(time_step.pixel_observation)
228 | time_step_stack = time_step._replace(observation=np.concatenate(stacked_frames, axis=0), pixel_observation=np.concatenate(stacked_pixel_frames, axis=0))
229 | with torch.no_grad(): #, eval_mode(agent):
230 | ob = torch.as_tensor(np.concatenate(stacked_frames, axis=0), device='cuda')
231 | pixel_ob = torch.as_tensor(np.concatenate(stacked_pixel_frames, axis=0), device='cuda')
232 | # imp_action = torch.abs(agent.actor(agent.encoder(ob.unsqueeze(0)).squeeze(0)))
233 | #act = torch.as_tensor(imp_action.reshape(9, 84, 84), device=agent.device)
234 | #new_ob = torch.clamp(torch.sqrt(act) * ob + torch.sqrt(1 - act) * 255 * torch.randn(9, 84, 84, device=agent.device), min=0, max=255).type(torch.int64)
235 | # latent = agent.get_latent(ob, imp_action).squeeze(0) #agent.get_latent(state, action, latent=No)
236 | #imp_action = torch.abs(agent.actor(latent))
237 | # time_step_stack = time_step_stack._replace(latent=latent.cpu().detach().numpy())
238 | # time_step_stack = time_step_stack._replace(imp_action=imp_action.cpu().numpy())
239 | rindex = min(interval-1, now+max_k) #random.randint(now+1, min(interval-1, now+max_k))
240 | rindex = rindex - now
241 | time_step_stack = time_step_stack._replace(k_step=rindex)
242 | replay_buffer.add(time_step_stack)
243 |
244 |
245 | ## TODO: Add dummy values for storing 'mae_latents' and 'mae_ids_restore' so that a common replay buffer can be used for both, OR have two separate replay buffer codes
246 | def get_timestep_from_idx(offline_data: dict, idx: int, latent_style: str):
247 | if latent_style == "vae":
248 | return dmc.ExtendedTimeStep(
249 | step_type=step_type_lookup[offline_data['step_type'][idx]],
250 | reward=offline_data['reward'][idx],
251 | pixel_observation=offline_data['observation'][idx],
252 | observation=offline_data['vae_latents'][idx],
253 | discount=offline_data['discount'][idx],
254 | action=offline_data['action'][idx],
255 | latent=np.zeros(256),
256 | imp_action=np.zeros(84*84*1),
257 | k_step = idx
258 | )
259 | elif latent_style == "dino":
260 | return dmc.ExtendedTimeStep(
261 | step_type=step_type_lookup[offline_data['step_type'][idx]],
262 | reward=offline_data['reward'][idx],
263 | pixel_observation=offline_data['observation'][idx],
264 | observation=offline_data['dino_latents'][idx],
265 | discount=offline_data['discount'][idx],
266 | action=offline_data['action'][idx],
267 | latent=np.zeros(256),
268 | imp_action=np.zeros(84*84*1),
269 | k_step = idx
270 | )
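271 |
272 |
273 | # Hedged sanity-check sketch for schedule() above: a bare float string is
274 | # returned unchanged, while 'linear(init,final,duration)' interpolates
275 | # linearly in `step` and clamps at the endpoints.
276 | if __name__ == "__main__":
277 |     assert schedule('0.3', step=0) == 0.3
278 |     assert schedule('linear(1.0,0.1,100000)', step=0) == 1.0
279 |     assert abs(schedule('linear(1.0,0.1,100000)', step=50000) - 0.55) < 1e-6
280 |     assert schedule('linear(1.0,0.1,100000)', step=200000) == 0.1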
--------------------------------------------------------------------------------
/dino/vqkd_teacher/dino.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2
4 | # Copyright (c) 2022 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Zhiliang Peng
7 | # Based on DINO code bases
8 | # https://github.com/facebookresearch/dino
9 | # --------------------------------------------------------
10 |
11 | import math
12 | from functools import partial
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from functools import partial, reduce
18 | from collections import OrderedDict
19 |
20 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
21 |
22 | import pdb
23 |
24 | # https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth
25 |
26 | def drop_path(x, drop_prob: float = 0., training: bool = False):
27 | if drop_prob == 0. or not training:
28 | return x
29 | keep_prob = 1 - drop_prob
30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
32 | random_tensor.floor_() # binarize
33 | output = x.div(keep_prob) * random_tensor
34 | return output
35 |
36 |
37 | class DropPath(nn.Module):
38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
39 | """
40 | def __init__(self, drop_prob=None):
41 | super(DropPath, self).__init__()
42 | self.drop_prob = drop_prob
43 |
44 | def forward(self, x):
45 | return drop_path(x, self.drop_prob, self.training)
46 |
47 |
48 | class Mlp(nn.Module):
49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50 | super().__init__()
51 | out_features = out_features or in_features
52 | hidden_features = hidden_features or in_features
53 | self.fc1 = nn.Linear(in_features, hidden_features)
54 | self.act = act_layer()
55 | self.fc2 = nn.Linear(hidden_features, out_features)
56 | self.drop = nn.Dropout(drop)
57 |
58 | def forward(self, x):
59 | x = self.fc1(x)
60 | x = self.act(x)
61 | x = self.drop(x)
62 | x = self.fc2(x)
63 | x = self.drop(x)
64 | return x
65 |
66 |
67 | class Attention(nn.Module):
68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
69 | super().__init__()
70 | self.num_heads = num_heads
71 | head_dim = dim // num_heads
72 | self.scale = qk_scale or head_dim ** -0.5
73 |
74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
82 | q, k, v = qkv[0], qkv[1], qkv[2]
83 |
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 |
88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 | return x, attn
92 |
93 |
94 | class Block(nn.Module):
95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
97 | super().__init__()
98 | self.norm1 = norm_layer(dim)
99 | self.attn = Attention(
100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
102 | self.norm2 = norm_layer(dim)
103 | mlp_hidden_dim = int(dim * mlp_ratio)
104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
105 |
106 | def forward(self, x, return_attention=False):
107 | y, attn = self.attn(self.norm1(x))
108 | if return_attention:
109 | return attn
110 | x = x + self.drop_path(y)
111 | x = x + self.drop_path(self.mlp(self.norm2(x)))
112 | return x
113 |
114 |
115 | class PatchEmbed(nn.Module):
116 | """ Image to Patch Embedding
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
119 | super().__init__()
120 | num_patches = (img_size // patch_size) * (img_size // patch_size)
121 | self.img_size = img_size
122 | self.patch_size = patch_size
123 | self.num_patches = num_patches
124 |
125 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
126 |
127 | def forward(self, x):
128 | B, C, H, W = x.shape
129 | x = self.proj(x).flatten(2).transpose(1, 2)
130 | return x
131 |
132 |
133 | class VisionTransformer(nn.Module):
134 | """ Vision Transformer """
135 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
136 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
137 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
138 | super().__init__()
139 | self.num_features = self.embed_dim = embed_dim
140 |
141 | self.patch_embed = PatchEmbed(
142 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
143 | num_patches = self.patch_embed.num_patches
144 |
145 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147 | self.pos_drop = nn.Dropout(p=drop_rate)
148 |
149 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150 | self.blocks = nn.ModuleList([
151 | Block(
152 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
154 | for i in range(depth)])
155 | self.norm = norm_layer(embed_dim)
156 |
157 | # Classifier head
158 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
159 |
160 | trunc_normal_(self.pos_embed, std=.02)
161 | trunc_normal_(self.cls_token, std=.02)
162 | self.apply(self._init_weights)
163 |
164 | if kwargs.get('pretrained', True):
165 | self.load_from_pretrained('https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth')
166 | if not kwargs.get('requires_grad', False):
167 | for param in self.parameters():
168 | param.requires_grad = False
169 |
170 | def _init_weights(self, m):
171 | if isinstance(m, nn.Linear):
172 | trunc_normal_(m.weight, std=.02)
173 | if isinstance(m, nn.Linear) and m.bias is not None:
174 | nn.init.constant_(m.bias, 0)
175 | elif isinstance(m, nn.LayerNorm):
176 | nn.init.constant_(m.bias, 0)
177 | nn.init.constant_(m.weight, 1.0)
178 |
179 | def load_from_pretrained(self, ckpt_path):
180 | if ckpt_path.startswith('https'):
181 | sd = torch.hub.load_state_dict_from_url(ckpt_path, map_location='cpu', check_hash=True)
182 | else:
183 | sd = torch.load(ckpt_path, map_location='cpu')
184 |
185 | missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
186 | print(f"Load weight for dino model: {ckpt_path}")
187 | print(f"missing_keys: {missing_keys}")
188 | print(f"unexpected_keys: {unexpected_keys}")
189 |
190 | def interpolate_pos_encoding(self, x, w, h):
191 | npatch = x.shape[1] - 1
192 | N = self.pos_embed.shape[1] - 1
193 | if npatch == N and w == h:
194 | return self.pos_embed
195 | class_pos_embed = self.pos_embed[:, 0]
196 | patch_pos_embed = self.pos_embed[:, 1:]
197 | dim = x.shape[-1]
198 | w0 = w // self.patch_embed.patch_size
199 | h0 = h // self.patch_embed.patch_size
200 | # we add a small number to avoid floating point error in the interpolation
201 | # see discussion at https://github.com/facebookresearch/dino/issues/8
202 | w0, h0 = w0 + 0.1, h0 + 0.1
203 | patch_pos_embed = nn.functional.interpolate(
204 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
205 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
206 | mode='bicubic',
207 | )
208 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
211 |
212 | def prepare_tokens(self, x):
213 | B, nc, w, h = x.shape
214 | x = self.patch_embed(x) # patch linear embedding
215 |
216 | # add the [CLS] token to the embed patch tokens
217 | cls_tokens = self.cls_token.expand(B, -1, -1)
218 | x = torch.cat((cls_tokens, x), dim=1)
219 |
220 | # add positional encoding to each token
221 | x = x + self.interpolate_pos_encoding(x, w, h)
222 |
223 | return self.pos_drop(x)
224 |
225 | def forward(self, x, return_patch_tokens=False, return_all_tokens=False):
226 | x = self.prepare_tokens(x)
227 | for blk in self.blocks:
228 | x = blk(x)
229 | x = self.norm(x)
230 | if return_all_tokens:
231 | return x
232 | elif return_patch_tokens:
233 | return x[:, 1:]
234 | else:
235 | return x[:, 0]
236 |
237 | def get_last_selfattention(self, x):
238 | x = self.prepare_tokens(x)
239 | for i, blk in enumerate(self.blocks):
240 | if i < len(self.blocks) - 1:
241 | x = blk(x)
242 | else:
243 | # return attention of the last block
244 | return blk(x, return_attention=True)
245 |
246 | def get_intermediate_layers(self, x, n=1):
247 | x = self.prepare_tokens(x)
248 | # we return the output tokens from the `n` last blocks
249 | output = []
250 | for i, blk in enumerate(self.blocks):
251 | x = blk(x)
252 | if len(self.blocks) - i <= n:
253 | output.append(self.norm(x))
254 | return output
255 |
256 | def forward_intermediate(self, x, layer_id=12):
257 | x = self.prepare_tokens(x)
258 |
259 | if isinstance(layer_id, list):
260 | output_list = []
261 | for l, blk in enumerate(self.blocks):
262 | x = blk(x)
263 | if l in layer_id:
264 | output_list.append(x[:, 1:])
265 | # output_list.append(self.norm(x))
266 | return output_list
267 | elif isinstance(layer_id, int):
268 | for l, blk in enumerate(self.blocks):
269 | if l < layer_id:
270 | x = blk(x)
271 | elif l == layer_id:
272 | # pdb.set_trace()
273 | x = blk.norm1(x)
274 | else:
275 | break
276 | return x[:, 1:]
277 |
278 |
279 | def vit_tiny(patch_size=16, **kwargs):
280 | model = VisionTransformer(
281 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
282 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
283 | return model
284 |
285 |
286 | def vit_small(patch_size=16, **kwargs):
287 | model = VisionTransformer(
288 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
289 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
290 | return model
291 |
292 |
293 | def vit_base(patch_size=16, **kwargs):
294 | model = VisionTransformer(
295 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
296 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
297 | return model
298 |
299 | def get_dino_vit_base():
300 | return vit_base(pretrained=True, requires_grad=False)
301 |
302 |
303 | class DINOHead(nn.Module):
304 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
305 | super().__init__()
306 | nlayers = max(nlayers, 1)
307 | if nlayers == 1:
308 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
309 | else:
310 | layers = [nn.Linear(in_dim, hidden_dim)]
311 | if use_bn:
312 | layers.append(nn.BatchNorm1d(hidden_dim))
313 | layers.append(nn.GELU())
314 | for _ in range(nlayers - 2):
315 | layers.append(nn.Linear(hidden_dim, hidden_dim))
316 | if use_bn:
317 | layers.append(nn.BatchNorm1d(hidden_dim))
318 | layers.append(nn.GELU())
319 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
320 | self.mlp = nn.Sequential(*layers)
321 | self.apply(self._init_weights)
322 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
323 | self.last_layer.weight_g.data.fill_(1)
324 | if norm_last_layer:
325 | self.last_layer.weight_g.requires_grad = False
326 |
327 | def _init_weights(self, m):
328 | if isinstance(m, nn.Linear):
329 | trunc_normal_(m.weight, std=.02)
330 | if isinstance(m, nn.Linear) and m.bias is not None:
331 | nn.init.constant_(m.bias, 0)
332 |
333 | def forward(self, x):
334 | x = self.mlp(x)
335 | x = nn.functional.normalize(x, dim=-1, p=2)
336 | x = self.last_layer(x)
337 | return x
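338 |
339 |
340 | # Hedged usage sketch for the VisionTransformer above: a ViT-B/16 forward pass
341 | # on a random 224x224 image. pretrained=False is passed so this sketch does not
342 | # download the DINO checkpoint; use get_dino_vit_base() for the frozen
343 | # pretrained teacher.
344 | if __name__ == "__main__":
345 |     model = vit_base(pretrained=False)
346 |     x = torch.randn(1, 3, 224, 224)
347 |     cls_embedding = model(x)                           # (1, 768) CLS token
348 |     patch_tokens = model(x, return_patch_tokens=True)  # (1, 196, 768) patches
349 |     assert cls_embedding.shape == (1, 768) and patch_tokens.shape == (1, 196, 768)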
--------------------------------------------------------------------------------
/utils/drqv2.py:
--------------------------------------------------------------------------------
1 | from transformers import GPT2LMHeadModel, GPT2Config
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision.transforms as T
9 |
10 | import utils
11 | from pathlib import Path
12 |
13 | from numpy_replay_buffer import EfficientReplayBuffer
14 | from utils import load_offline_dataset_into_buffer
15 |
16 | from dm_env import specs
17 | import dmc
18 |
19 |
20 | def k3s1p0(x):
21 | return x - 2
22 |
23 |
24 | def k4s2p0(x):
25 | assert ((x % 2) == 0)
26 | return (x // 2) - 1
27 |
28 |
29 | class RandomShiftsAug(nn.Module):
30 | def __init__(self, pad=4):
31 | super().__init__()
32 | self.pad = pad
33 |
34 | def forward(self, x):
35 | # x = T.Resize((x.shape[0], x.shape[1], 64, 64))
36 | n, c, h, w = x.size()
37 | assert h == w
38 | padding = tuple([self.pad] * 4)
39 | x = F.pad(x, padding, 'replicate')
40 | eps = 1.0 / (h + 2 * self.pad)
41 | arange = torch.linspace(-1.0 + eps,
42 | 1.0 - eps,
43 | h + 2 * self.pad,
44 | device=x.device,
45 | dtype=x.dtype)[:h]
46 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
47 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
48 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
49 |
50 | shift = torch.randint(0,
51 | 2 * self.pad + 1,
52 | size=(n, 1, 1, 2),
53 | device=x.device,
54 | dtype=x.dtype)
55 | shift *= 2.0 / (h + 2 * self.pad)
56 |
57 | grid = base_grid + shift
58 | return F.grid_sample(x,
59 | grid,
60 | padding_mode='zeros',
61 | align_corners=False)
62 |
63 |
64 | class NoShiftAug(nn.Module):
65 | def __init__(self):
66 | super().__init__()
67 |
68 | def forward(self, x):
69 | return x
70 |
71 | class LayerNorm(nn.Module):
72 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
73 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
74 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
75 | with shape (batch_size, channels, height, width).
76 | """
77 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
78 | super().__init__()
79 | self.weight = nn.Parameter(torch.ones(normalized_shape))
80 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
81 | self.eps = eps
82 | self.data_format = data_format
83 | if self.data_format not in ["channels_last", "channels_first"]:
84 | raise NotImplementedError
85 | self.normalized_shape = (normalized_shape, )
86 |
87 | def forward(self, x):
88 | if self.data_format == "channels_last":
89 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
90 | elif self.data_format == "channels_first":
91 | u = x.mean(1, keepdim=True)
92 | s = (x - u).pow(2).mean(1, keepdim=True)
93 | x = (x - u) / torch.sqrt(s + self.eps)
94 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
95 | return x
96 |
97 | class Encoder(nn.Module):
98 | def __init__(self, obs_shape, feature_dim):
99 | super().__init__()
100 |
101 | assert len(obs_shape) == 1
102 |
103 | action_dim = 6
104 | self.repr_dim = 75
105 |
106 | self.linear = nn.Sequential(nn.Linear(self.repr_dim, feature_dim), nn.BatchNorm1d(feature_dim),
107 | nn.ReLU())
108 |
109 | self.apply(utils.weight_init)
110 |
111 | def forward(self, obs):
112 | h = self.linear(obs)
113 | return h
114 |
115 | class Actor(nn.Module):
116 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
117 | super().__init__()
118 |
119 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
120 | nn.LayerNorm(feature_dim), nn.Tanh())
121 |
122 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
123 | nn.ReLU(inplace=True),
124 | nn.Linear(hidden_dim, hidden_dim),
125 | nn.ReLU(inplace=True),
126 | nn.Linear(hidden_dim, action_shape[0]))
127 |
128 | self.apply(utils.weight_init)
129 |
130 | def forward(self, obs, std=None):
131 | # h = self.trunk(obs)
132 |
133 | mu = self.policy(obs)
134 | mu = torch.tanh(mu)
135 | if std is None:
136 | return mu
137 | std = torch.ones_like(mu) * std
138 |
139 | dist = utils.TruncatedNormal(mu, std)
140 | return dist
141 |
142 |
143 | class Critic(nn.Module):
144 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
145 | super().__init__()
146 |
147 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
148 | nn.LayerNorm(feature_dim), nn.Tanh())
149 |
150 | self.Q1 = nn.Sequential(
151 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
152 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
153 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
154 |
155 | self.Q2 = nn.Sequential(
156 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
157 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
158 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
159 |
160 | self.apply(utils.weight_init)
161 |
162 | def forward(self, obs, action):
163 | # h = self.trunk(obs)
164 | h_action = torch.cat([obs, action], dim=-1)
165 | q1 = self.Q1(h_action)
166 | q2 = self.Q2(h_action)
167 |
168 | return q1, q2
169 |
170 |
171 | class DrQV2Agent:
172 | def __init__(self, obs_shape=(75,), action_shape=(6,), device='cuda', lr=3e-4, feature_dim=64,
173 | hidden_dim=256, critic_target_tau=0.005, num_expl_steps=2000,
174 | update_every_steps=2, stddev_schedule='linear(1.0,0.1,100000)',
175 | stddev_clip=0.3, use_tb=False,
176 | offline=True, bc_weight=2.5, augmentation=RandomShiftsAug(pad=4),
177 | use_bc=True):
178 | self.device = device
179 | self.critic_target_tau = critic_target_tau
180 | self.update_every_steps = update_every_steps
181 | self.use_tb = use_tb
182 | self.num_expl_steps = num_expl_steps
183 | self.stddev_schedule = stddev_schedule
184 | self.stddev_clip = stddev_clip
185 | self.offline = offline
186 | self.bc_weight = bc_weight
187 | self.use_bc = use_bc
188 |
189 | # replay buffer
190 | self.train_env = dmc.make("offline_cheetah_run_expert", 3, 2, 0, None)
191 | data_specs = (self.train_env.observation_spec(),
192 | self.train_env.action_spec(),
193 | specs.Array((1,), np.float32, 'reward'),
194 | specs.Array((1,), np.float32, 'discount'))
195 |
196 | self.discount = 0.99
197 | self.replay_buffer = EfficientReplayBuffer(25000, 32, 1, self.discount, 3, False, data_specs)
198 |
199 | offline_dir = "/home/manant/scratch/expert/vq_latents/"
200 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 25000)
201 |
202 | # gpt model
203 | configuration = GPT2Config(vocab_size=1024)
204 | self.gpt = GPT2LMHeadModel(configuration).to('cuda')
205 | self.gpt.load_state_dict(torch.load("/home/manant/scratch/pixel_gamma/checkpoints/vq_gpt_gamma_model_4000.pth"))
206 |
207 | # actor critic models
208 | self.encoder = Encoder(obs_shape, feature_dim).to(device)
209 | self.actor = Actor(feature_dim, action_shape, feature_dim,
210 | hidden_dim).to(device)
211 |
212 | self.critic = Critic(feature_dim, action_shape, feature_dim,
213 | hidden_dim).to(device)
214 | self.critic_target = Critic(feature_dim, action_shape,
215 | feature_dim, hidden_dim).to(device)
216 | self.critic_target.load_state_dict(self.critic.state_dict())
217 |
218 | # optimizers
219 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
220 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
221 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
222 |
223 | # data augmentation
224 | self.aug = augmentation
225 |
226 | self.train()
227 | self.critic_target.train()
228 |
229 | def train(self, training=True):
230 | self.training = training
231 | self.encoder.train(training)
232 | self.actor.train(training)
233 | self.critic.train(training)
234 |
235 | def act(self, obs, latent, step, eval_mode):
236 | obs = torch.as_tensor(obs, device=self.device)
237 | obs = self.encoder(obs.unsqueeze(0))
238 | stddev = utils.schedule(self.stddev_schedule, step)
239 | dist = self.actor(obs, stddev)
240 | if eval_mode:
241 | action = dist.mean
242 | else:
243 | action = dist.sample(clip=None)
244 | if step < self.num_expl_steps:
245 | action.uniform_(-1.0, 1.0)
246 | return action.cpu().numpy()[0], None
247 |
248 | def update_critic(self, obs, action, reward, discount, next_obs, step):
249 | metrics = dict()
250 |
251 | with torch.no_grad():
252 | stddev = utils.schedule(self.stddev_schedule, step)
253 | dist = self.actor(next_obs, stddev)
254 | next_action = dist.sample(clip=self.stddev_clip)
255 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
256 | target_V = torch.min(target_Q1, target_Q2)
257 | target_Q = reward.float() + (discount * target_V)
258 |
259 | Q1, Q2 = self.critic(obs, action)
260 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
261 |
262 | if self.use_tb:
263 | metrics['critic_target_q'] = target_Q.mean().item()
264 | metrics['critic_q1'] = Q1.mean().item()
265 | metrics['critic_q2'] = Q2.mean().item()
266 | metrics['critic_loss'] = critic_loss.item()
267 |
268 | # optimize encoder and critic
269 | self.encoder_opt.zero_grad(set_to_none=True)
270 | self.critic_opt.zero_grad(set_to_none=True)
271 | critic_loss.backward()
272 | self.critic_opt.step()
273 | self.encoder_opt.step()
274 |
275 | if step % 5000 == 0:
276 | print("Critic Loss", critic_loss)
277 |
278 | return metrics
279 |
280 | def update_actor(self, obs, step, behavioural_action=None):
281 | metrics = dict()
282 |
283 | stddev = utils.schedule(self.stddev_schedule, step)
284 | dist = self.actor(obs, stddev)
285 | action = dist.sample(clip=self.stddev_clip)
286 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
287 | Q1, Q2 = self.critic(obs, action)
288 | Q = torch.min(Q1, Q2)
289 |
290 | actor_policy_improvement_loss = -Q.mean()
291 |
292 | actor_loss = actor_policy_improvement_loss
293 |
294 | # offline BC Loss
295 | if self.offline:
296 | actor_bc_loss = F.mse_loss(action, behavioural_action)
297 | # Eq. 5 of arXiv:2106.06860
298 | lam = self.bc_weight / Q.detach().abs().mean()
299 | if self.use_bc:
300 | actor_loss = actor_policy_improvement_loss * lam + actor_bc_loss
301 | else:
302 | actor_loss = actor_policy_improvement_loss #* lam
303 |
304 | # optimize actor
305 | self.actor_opt.zero_grad(set_to_none=True)
306 | actor_loss.backward()
307 | self.actor_opt.step()
308 |
309 | if self.use_tb:
310 | metrics['actor_loss'] = actor_policy_improvement_loss.item()
311 | metrics['actor_logprob'] = log_prob.mean().item()
312 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
313 | if self.offline:
314 | metrics['actor_bc_loss'] = actor_bc_loss.item()
315 |
316 | if step % 5000 == 0:
317 | print("Actor Loss", actor_loss)
318 |
319 | return metrics
320 |
321 | def update(self, step):
322 | metrics = dict()
323 |
324 | if step % self.update_every_steps != 0:
325 | return metrics
326 |
327 | batch, indices = next(self.replay_buffer)
328 | obs, action, reward, discount, next_obs, _, _, _, _ = utils.to_torch(
329 | batch, self.device)
330 |
331 | # augment
332 | obs = obs.float() #self.aug(obs.float())
333 | next_obs = next_obs.float() #self.aug(next_obs.float())
334 | # encode
335 | obs = self.encoder(obs)
336 | with torch.no_grad():
337 | next_obs = self.encoder(next_obs)
338 |
339 | if self.use_tb:
340 | metrics['batch_reward'] = reward.mean().item()
341 |
342 | # update critic
343 | metrics.update(
344 | self.update_critic(obs, action, reward, discount, next_obs, step))
345 |
346 | # update actor
347 | if self.offline:
348 | metrics.update(self.update_actor(obs.detach(), step, action.detach()))
349 | else:
350 | metrics.update(self.update_actor(obs.detach(), step))
351 |
352 | # update critic target
353 | utils.soft_update_params(self.critic, self.critic_target,
354 | self.critic_target_tau)
355 |
356 | if step % 20000 == 0:
357 | torch.save(self.actor.state_dict(), "/home/manant/scratch/pixel_gamma/checkpoints/drq_actor_{}.pth".format(step))
358 | torch.save(self.encoder.state_dict(), "/home/manant/scratch/pixel_gamma/checkpoints/drq_encoder_{}.pth".format(step))
359 |
360 | return metrics
361 |
362 | if __name__ == "__main__":
363 |
364 | agent = DrQV2Agent()
365 |
366 | for step in range(2000000):
367 | agent.update(step)
368 |
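369 | # Hedged sketch of RandomShiftsAug above (the DrQ-v2 image augmentation): it
370 | # replicate-pads an (N, C, H, W) batch by `pad` pixels and bilinearly resamples
371 | # it at a random integer shift per sample, e.g.
372 | #
373 | #     aug = RandomShiftsAug(pad=4)
374 | #     obs = torch.rand(8, 9, 84, 84)   # 3 stacked RGB frames per sample
375 | #     assert aug(obs).shape == obs.shape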
--------------------------------------------------------------------------------
/utils/numpy_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import abc
3 |
4 |
5 | class AbstractReplayBuffer(abc.ABC):
6 | @abc.abstractmethod
7 | def add(self, time_step):
8 | pass
9 |
10 | @abc.abstractmethod
11 | def __next__(self, ):
12 | pass
13 |
14 | @abc.abstractmethod
15 | def __len__(self, ):
16 | pass
17 |
18 |
19 | class EfficientReplayBuffer(AbstractReplayBuffer):
20 | '''Fast + efficient replay buffer implementation in numpy.'''
21 |
22 | def __init__(self, buffer_size, batch_size, nstep, discount, frame_stack, spr_style_buffer,
23 | data_specs=None, pixel_samples=False, sarsa=False):
24 | self.buffer_size = buffer_size
25 | self.data_dict = {}
26 | self.index = -1
27 | self.traj_index = 0
28 | self.frame_stack = frame_stack
29 | self._recorded_frames = frame_stack + 1
30 | self.batch_size = batch_size
31 | self.nstep = nstep
32 | self.discount = discount
33 | self.full = False
34 | self.discount_vec = np.power(discount, np.arange(nstep)) # n_step - first dim should broadcast
35 | self.next_dis = discount ** nstep
36 | self.sarsa = sarsa
37 | self.latent_shape = 256 #50
38 | self.imp_act_shape = 84 * 84 * 1
39 | self.spr_style_buffer = spr_style_buffer
40 | self.pixel_samples = pixel_samples
41 |
42 | def _initial_setup(self, time_step):
43 | self.index = 0
44 | self.obs_shape = list(time_step.observation.shape)
45 | self.pixel_obs_shape = list(time_step.pixel_observation.shape)
46 | self.ims_channels = self.obs_shape[0] // self.frame_stack
47 | self.pixel_ims_channels = self.pixel_obs_shape[0] // self.frame_stack
48 | self.act_shape = time_step.action.shape
49 |
50 | self.obs = np.zeros([self.buffer_size, self.ims_channels, *self.obs_shape[1:]], dtype=np.int32)
51 | self.pixel_obs = np.zeros([self.buffer_size, self.pixel_ims_channels, *self.pixel_obs_shape[1:]], dtype=np.uint8)
52 | self.act = np.zeros([self.buffer_size, *self.act_shape], dtype=np.float32)
53 | self.latent = np.zeros([self.buffer_size, self.latent_shape], dtype=np.float32)
54 | self.imp_act = np.ones([self.buffer_size, self.imp_act_shape], dtype=np.float32)
55 | self.rew = np.zeros([self.buffer_size], dtype=np.float32)
56 | self.dis = np.zeros([self.buffer_size], dtype=np.float32)
57 | self.valid = np.zeros([self.buffer_size], dtype=np.bool_)
58 | self.k_step = np.zeros([self.buffer_size], dtype=np.float32)
59 | self.obs_k = np.zeros([self.buffer_size, self.ims_channels, *self.obs_shape[1:]], dtype=np.int32)
60 |
61 | def add_data_point(self, time_step):
62 | first = time_step.first()
63 | latest_obs = time_step.observation[-self.ims_channels:].astype(np.int32)
64 | latest_pixel_obs = time_step.pixel_observation[-self.pixel_ims_channels:].astype(np.uint8)
65 | if first:
66 | end_index = self.index + self.frame_stack
67 | end_invalid = end_index + self.frame_stack + 1
68 | if end_invalid > self.buffer_size:
69 | if end_index > self.buffer_size:
70 | end_index = end_index % self.buffer_size
71 | self.obs[self.index:self.buffer_size] = latest_obs
72 | self.obs[0:end_index] = latest_obs
73 | self.pixel_obs[self.index:self.buffer_size] = latest_pixel_obs
74 | self.pixel_obs[0:end_index] = latest_pixel_obs
75 | self.full = True
76 | else:
77 | self.obs[self.index:end_index] = latest_obs
78 | self.pixel_obs[self.index:end_index] = latest_pixel_obs
79 | end_invalid = end_invalid % self.buffer_size
80 | self.valid[self.index:self.buffer_size] = False
81 | self.valid[0:end_invalid] = False
82 | else:
83 | self.obs[self.index:end_index] = latest_obs
84 | self.pixel_obs[self.index:end_index] = latest_pixel_obs
85 | self.valid[self.index:end_invalid] = False
86 | self.index = end_index
87 | self.traj_index = 1
88 | else:
89 |             np.copyto(self.obs[self.index], latest_obs) # store only the most recent frame
90 | np.copyto(self.pixel_obs[self.index], latest_pixel_obs)
91 | np.copyto(self.act[self.index], time_step.action)
92 | np.copyto(self.latent[self.index], time_step.latent)
93 | np.copyto(self.imp_act[self.index], time_step.imp_action)
94 | self.rew[self.index] = time_step.reward
95 | self.dis[self.index] = time_step.discount
96 | self.valid[(self.index + self.frame_stack) % self.buffer_size] = False
97 | self.k_step[self.index] = time_step.k_step
98 | if self.traj_index >= self.nstep:
99 | self.valid[(self.index - self.nstep + 1) % self.buffer_size] = True
100 | self.index += 1
101 | self.traj_index += 1
102 | if self.index == self.buffer_size:
103 | self.index = 0
104 | self.full = True
105 |
106 | def add(self, time_step):
107 | if self.index == -1:
108 | self._initial_setup(time_step)
109 | self.add_data_point(time_step)
110 |
111 | def get_stats(self, ):
112 | print("obs shape", self.obs.shape)
113 | mean = np.mean(self.obs, axis=(0, 2, 3))
114 | std = np.std(self.obs, axis=(0, 2, 3))
115 | return mean, std
116 |
117 | def __next__(self):
118 | indices = np.random.choice(self.valid.nonzero()[0] - 8, size=self.batch_size)
119 | # if spr:
120 | # return self.gather_spr_indices(indices), indices
121 | # else:
122 | return self.gather_nstep_indices(indices), indices
123 |
124 | def sample_spr(self, indices=None, jumps=8):
125 | if indices is None:
126 | indices = np.random.choice(self.valid.nonzero()[0] - jumps, size=self.batch_size)
127 | return self.gather_spr_indices(indices, jumps), indices
128 |
129 | def replace_latent(self, indices, latents):
130 | self.latent[indices] = latents
131 |
132 | def replace_action(self, indices, imp_actions):
133 | self.imp_act[indices] = imp_actions
134 |
135 | def sample_previous_latent(self, indices):
136 | return self.latent[indices - 1]
137 |
138 | def gather_nstep_indices(self, indices):
139 | n_samples = indices.shape[0]
140 | all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + self.nstep)
141 | for i in range(n_samples)], axis=0) % self.buffer_size
142 | gather_ranges = all_gather_ranges[:, self.frame_stack:] # bs x nstep
143 | obs_gather_ranges = all_gather_ranges[:, :self.frame_stack]
144 | nobs_gather_ranges = all_gather_ranges[:, -self.frame_stack:]
145 |
146 | all_rewards = self.rew[gather_ranges] #/ np.max(self.rew)
147 |
148 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
149 | rew = np.sum(all_rewards * self.discount_vec, axis=1, keepdims=True) / gather_ranges.shape[1]
150 |
151 | if self.pixel_samples:
152 | # In case we require pixel observation instead of the saved latents
153 | obs = np.reshape(self.pixel_obs[obs_gather_ranges], [n_samples, *self.pixel_obs_shape])
154 | nobs = np.reshape(self.pixel_obs[nobs_gather_ranges], [n_samples, *self.pixel_obs_shape])
155 | else:
156 | obs = np.reshape(self.obs[obs_gather_ranges], [n_samples, *self.obs_shape])
157 | nobs = np.reshape(self.obs[nobs_gather_ranges], [n_samples, *self.obs_shape])
158 |
159 | act = self.act[indices]
160 | latent = self.latent[indices]
161 | imp_act = self.imp_act[indices]
162 | dis = np.expand_dims(self.next_dis * self.dis[nobs_gather_ranges[:, -1]], axis=-1)
163 |
164 | k_step = self.k_step[indices].astype(int)
165 | k_step_rand = []
166 | for each in k_step:
167 | if each > 1:
168 | k_step_rand.append(np.random.randint(low=1, high=each))
169 | else:
170 | k_step_rand.append(1)
171 | # k_step_rand = [np.random.randint(low=1, high=each) for each in k_step]
172 | k_all_gather_ranges = np.stack([np.arange(indices[i] + k_step_rand[i] - self.frame_stack, indices[i] + k_step_rand[i] + self.nstep)
173 | for i in range(n_samples)], axis=0) % self.buffer_size
174 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack]
175 | if self.pixel_samples:
176 | obs_k = np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape])
177 | else:
178 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape])
179 |
180 | # k_all_gather_ranges = np.stack([np.arange(indices[i], indices[i] + k_step_rand[i])
181 | # for i in range(n_samples)], axis=0) % self.buffer_size
182 | act_k = self.act[indices + k_step_rand]
183 |
184 | if self.sarsa:
185 | nact = self.act[indices + self.nstep]
186 | return (obs, act, rew, dis, nobs, nact, latent, imp_act)
187 |
188 | return (obs, act, rew, dis, nobs, latent, act_k, k_step_rand, obs_k)
189 |
190 | def gather_spr_indices(self, indices, jumps=8):
191 | n_samples = indices.shape[0]
192 | all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + self.nstep)
193 | for i in range(n_samples)], axis=0) % self.buffer_size
194 | gather_ranges = all_gather_ranges[:, self.frame_stack:] # bs x nstep
195 | obs_gather_ranges = all_gather_ranges[:, :self.frame_stack]
196 | nobs_gather_ranges = all_gather_ranges[:, -self.frame_stack:]
197 |
198 | all_rewards = self.rew[gather_ranges] #/ np.max(self.rew)
199 |
200 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
201 | rew = np.sum(all_rewards * self.discount_vec, axis=1, keepdims=True) / gather_ranges.shape[1]
202 |
203 | if self.pixel_samples:
204 | # In case we require pixel observation instead of the saved latents
205 | obs = np.reshape(self.pixel_obs[obs_gather_ranges], [n_samples, *self.pixel_obs_shape])
206 | nobs = np.reshape(self.pixel_obs[nobs_gather_ranges], [n_samples, *self.pixel_obs_shape])
207 | else:
208 | obs = np.reshape(self.obs[obs_gather_ranges], [n_samples, *self.obs_shape])
209 | nobs = np.reshape(self.obs[nobs_gather_ranges], [n_samples, *self.obs_shape])
210 |
211 | act = self.act[indices]
212 | latent = self.latent[indices]
213 | imp_act = self.imp_act[indices]
214 | dis = np.expand_dims(self.next_dis * self.dis[nobs_gather_ranges[:, -1]], axis=-1)
215 |
216 | k_all_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i] + jumps)
217 | for i in range(n_samples)], axis=0) % self.buffer_size
218 | all_obs = []
219 | all_pixel_obs = []
220 | all_act = []
221 | all_rew = []
222 | for i in range(jumps+1):
223 | if i == 0:
224 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack]
225 | else:
226 |                 k_obs_gather_ranges = k_all_gather_ranges[:, i:self.frame_stack + i]
227 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape])
228 | pixel_obs_k = np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape])
229 | act_k = self.act[indices + i]
230 | all_obs.append(obs_k)
231 | all_pixel_obs.append(pixel_obs_k)
232 | all_act.append(act_k)
233 |
234 | # all_frames_rewards = self.rew[k_all_gather_ranges[:, self.frame_stack-2:]] / np.max(self.rew)
235 | all_frames_rewards = self.rew[k_all_gather_ranges[:, self.frame_stack:]] #/ np.max(self.rew)
236 | discount_vec = np.power(self.discount, np.arange(all_frames_rewards.shape[1]))
237 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
238 | values = np.sum(all_frames_rewards * discount_vec, axis=1, keepdims=True)
239 |
240 | all_obs = np.stack(all_obs)
241 | all_pixel_obs = np.stack(all_pixel_obs)
242 | all_act = np.stack(all_act)
243 |
244 | k_step = self.k_step[indices].astype(int)
245 | k_step_rand = []
246 | for each in k_step:
247 | if each > 1:
248 | k_step_rand.append(np.random.randint(low=1, high=each))
249 | else:
250 | k_step_rand.append(1)
251 | # k_step_rand = [np.random.randint(low=1, high=each) for each in k_step]
252 | k_all_gather_ranges = np.stack([np.arange(indices[i] + k_step_rand[i] - self.frame_stack, indices[i] + k_step_rand[i] + self.nstep)
253 | for i in range(n_samples)], axis=0) % self.buffer_size
254 | k_obs_gather_ranges = k_all_gather_ranges[:, :self.frame_stack]
255 | if self.pixel_samples:
256 | obs_k = np.reshape(self.pixel_obs[k_obs_gather_ranges], [n_samples, *self.pixel_obs_shape])
257 | else:
258 | obs_k = np.reshape(self.obs[k_obs_gather_ranges], [n_samples, *self.obs_shape])
259 |
260 | if self.sarsa:
261 | nact = self.act[indices + self.nstep]
262 | return (obs, act, rew, dis, nobs, nact, latent, imp_act)
263 |
264 | return (obs, act, rew, dis, obs_k, all_obs, all_pixel_obs, all_act, values)
265 | # return (obs, act, rew, dis, nobs, all_obs, all_pixel_obs, all_act, values)
266 |
267 | def __len__(self):
268 | if self.full:
269 | return self.buffer_size
270 | else:
271 | return self.index
272 |
273 | def get_train_and_val_indices(self, validation_percentage):
274 | all_indices = self.valid.nonzero()[0]
275 | num_indices = all_indices.shape[0]
276 | num_val = int(num_indices * validation_percentage)
277 | np.random.shuffle(all_indices)
278 | val_indices, train_indices = np.split(all_indices,
279 | [num_val])
280 | return train_indices, val_indices
281 |
282 | def get_obs_act_batch(self, indices):
283 | n_samples = indices.shape[0]
284 | obs_gather_ranges = np.stack([np.arange(indices[i] - self.frame_stack, indices[i])
285 | for i in range(n_samples)], axis=0) % self.buffer_size
286 | obs = np.reshape(self.obs[obs_gather_ranges], [n_samples, *self.obs_shape])
287 | act = self.act[indices]
288 | return obs, act
--------------------------------------------------------------------------------
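For orientation, the snippet below is a standalone toy example (not part of the repository) of the discounted reward aggregation performed in gather_nstep_indices and gather_spr_indices above; note the buffer additionally divides the n-step sum by the window length, which is omitted here. All numbers are illustrative.

import numpy as np

# Toy illustration of the discounted n-step return built by the replay buffer:
# a (batch, nstep) window of rewards is weighted by a geometric discount vector
# and summed along the time axis.
nstep, discount = 3, 0.99
rewards = np.array([[1.0, 0.5, 0.25]])                 # (batch=1, nstep), made-up values
discount_vec = np.power(discount, np.arange(nstep))    # [1.0, 0.99, 0.9801]
nstep_return = np.sum(rewards * discount_vec, axis=1, keepdims=True)
print(nstep_return)                                     # [[1.740025]]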
/vae/train_vq_vae_voc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '../utils/')
3 |
4 | from transformers.optimization import get_scheduler
5 | from transformers import GPT2LMHeadModel, GPT2Config
6 |
7 | from taming.models.vqgan import VQModel
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from pathlib import Path
15 | from matplotlib import pyplot as plt
16 |
17 | import random
18 | import utils
19 | import os
20 | import math
21 |
22 | from torchvision.utils import make_grid
23 | from drqv2 import RandomShiftsAug
24 |
25 | from numpy_replay_buffer import EfficientReplayBuffer
26 | from utils import load_offline_dataset_into_buffer
27 |
28 | from dm_env import specs
29 | import dmc
30 |
31 | import copy
32 | import wandb
33 | from omegaconf import OmegaConf
34 |
35 | from torch.distributions.categorical import Categorical
36 |
37 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average
38 |
39 | ######################
40 | batch_size = 32
41 |
42 | taming_path = "../../" # taming models path
43 | offline_dir = "../../walker_train/vq_latents/" # dataset dir path
44 | vq_model_path = "../../walker_vq_model_step.pth" # vqvae model path
45 | save_dir_path = "./" # voc's gpt model save path
46 |
47 | ######################
48 |
49 | class TFAgent:
50 | def __init__(self, discount=0.8, augmentation=RandomShiftsAug(pad=4)):
51 |
52 | wandb.init(project="Video Occupancy Models",
53 | id="voc_vqvae_gamma_{}".format(discount),
54 | entity=None, dir=os.getcwd())
55 |
56 | self.train_env = dmc.make("offline_walker_walk_expert", 3, 2, 0, None)
57 | data_specs = (self.train_env.observation_spec(),
58 | self.train_env.action_spec(),
59 | specs.Array((1,), np.float32, 'reward'),
60 | specs.Array((1,), np.float32, 'discount'))
61 |
62 | self.discount = discount
63 | self.batch_size = batch_size
64 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps
65 | self.codebook_size = 1024
66 |
67 | ######### train on VQ latents #########
68 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs)
69 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000, latent_style="vae")
70 | ##########################
71 |
72 | ######### vq-vae setup #########
73 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml")
74 | config = OmegaConf.load(config_path)
75 | self.model = VQModel(**config.model.params) # Don't put the entire model on device, we only need the decoder
76 | self.quantizer = self.model.quantize.to('cuda')
77 | self.decoder = nn.Sequential(self.model.post_quant_conv, self.model.decoder).to('cuda')
78 |
79 | self.from_imagenet = False
80 | if self.from_imagenet:
81 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt")
82 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"]
83 | missing, unexpected = self.model.load_state_dict(sd, strict=False)
84 | else:
85 | print("Loading fine-tuned VQ model...")
86 | self.model.load_state_dict(torch.load(vq_model_path))
87 | ##########################
88 |
89 | ######### gpt setup #########
90 |         configuration = GPT2Config(vocab_size=1024, n_layer=4, n_head=8, n_embd=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_pdrop=0.2)
91 | self.gpt = GPT2LMHeadModel(configuration).to('cuda')
92 | self.gpt_target = copy.deepcopy(self.gpt)
93 | self.gpt_target.generation_config.output_scores = True
94 | self.target_ema_updater = EMA(0.9)
95 | ##########################
96 |
97 | ######### optimizer setup #########
98 | self.reward_predictor = Predictor(256 * 25 * 3).to('cuda')
99 | self.gpt_optimizer = torch.optim.AdamW(self.get_grouped_params(), lr=3e-4)
100 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()), lr=3e-4)
101 |
102 | num_training_steps = 100000
103 | self.warmup_ratio = 0.05
104 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio)
105 | self.lr_scheduler = get_scheduler(
106 | "cosine",
107 | optimizer=self.gpt_optimizer,
108 | num_warmup_steps=warmup_steps,
109 | num_training_steps=num_training_steps,
110 | )
111 | ##########################
112 |
113 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda')
114 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda')
115 |
116 | self.device = 'cuda'
117 | self.aug = augmentation
118 |
119 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000]
120 | self.train()
121 |
122 | def get_grouped_params(self):
123 | decay_parameters = get_parameter_names(self.gpt, [nn.LayerNorm])
124 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
125 | optimizer_grouped_parameters = [
126 | {
127 | "params": [
128 | p for n, p in self.gpt.named_parameters() if (n in decay_parameters and p.requires_grad)
129 | ],
130 | "weight_decay": 0.1,
131 | },
132 | {
133 | "params": [
134 | p for n, p in self.gpt.named_parameters() if (n not in decay_parameters and p.requires_grad)
135 | ],
136 | "weight_decay": 0.0,
137 | },
138 | ]
139 | return optimizer_grouped_parameters
140 |
141 | def train(self, training=True):
142 | self.training = training
143 |
144 | def update(self, step=0):
145 | metrics = dict()
146 |
147 | batch, indices = next(self.replay_buffer)
148 | obs, action, reward, discount, next_obs, latent, _, _, obs_k = utils.to_torch(batch, self.device)
149 |
150 | # collect discrete vq indices
151 | y_context = obs.long()
152 | y_target = obs_k.long()
153 | obs_shape = obs.shape
154 |
155 | quant_target = self.quantizer.get_codebook_entry(y_target, None)
156 | quant_target = quant_target.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
157 |
158 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1))
159 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean()
160 |
161 | # generate target
162 | with torch.no_grad():
163 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100)
164 | p_t = p_t[:, -y_target.shape[-1]:]
165 |
166 | # gamma sampling
167 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device)
168 | prob = torch.bernoulli(gamma)
169 | p_target = torch.zeros_like(y_target)
170 |
171 | # with prob 1-gamma, sample from next state
172 | p_c_idx = torch.nonzero(1 - prob)
173 | p_target[p_c_idx] = y_target[p_c_idx]
174 |
175 | # with prob gamma, sample from bootstrapped model
176 | p_t_idx = torch.nonzero(prob)
177 | p_target[p_t_idx] = p_t[p_t_idx]
178 |
179 |
180 | # gpt predictions
181 | inp = torch.cat([y_context, p_target], dim=1)
182 | outputs = self.gpt(inp, labels=inp)
183 | gpt_loss = outputs.loss
184 |
185 | loss = gpt_loss + reward_loss
186 |
187 | loss.backward()
188 |
189 | # grad accumulate
190 | if step % 2 == 0:
191 | self.optimizer.step()
192 | self.gpt_optimizer.step()
193 |
194 | self.optimizer.zero_grad()
195 | self.gpt_optimizer.zero_grad()
196 |
197 | self.lr_scheduler.step()
198 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt)
199 |
200 | # visualize predictions
201 | if step % 100 == 0 and step != 0:
202 | with torch.no_grad():
203 | # sample a batch of traj and corresponding values
204 | batch, indices = self.replay_buffer.sample_spr(jumps=self.nsteps)
205 | _, _, _, _, _, all_obs, all_pixel_obs, _, values = utils.to_torch(batch, self.device)
206 |
207 | # preprocess first obs from traj
208 | obs = all_obs[0]
209 | obs_shape = obs.shape
210 | pixel_obs_shape = F.interpolate(all_pixel_obs[0], size=80).shape
211 |
212 | y_context = obs.long()
213 |
214 | # sample target predictions
215 | p_t = self.gpt.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, pad_token_id=-100)[:, -y_context.shape[-1]:]
216 |
217 | # reconstruct sampled prediction
218 | quant = self.quantizer.get_codebook_entry(p_t, None)
219 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
220 |
221 | p_t_pixel_recon = self.decoder(quant)
222 |
223 | viz_imgs = []
224 | viz_imgs.append(p_t_pixel_recon)
225 | for i in range(all_pixel_obs.shape[0]):
226 | obs = all_pixel_obs[i]
227 | obs = F.interpolate(obs, size=80)
228 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
229 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((pixel_obs_shape[0] * 3, 3, *pixel_obs_shape[2:])) #torch.einsum('nhwc->nchw', obs)
230 | viz_imgs.append(obs)
231 |
232 | value_loss = self.get_value_estimates(y_context, values, obs_shape)
233 | wandb.log({"value loss": value_loss}, step=step)
234 | density_value_loss = self.get_density_value_estimates(y_context, all_obs, obs_shape)
235 | wandb.log({"density value loss": density_value_loss}, step=step)
236 |
237 | viz_imgs = torch.stack(viz_imgs)[:, :8]
238 | t, n, c, h, w = viz_imgs.shape
239 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs)
240 | viz_imgs = viz_imgs.reshape(t*n, c, h, w)
241 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True)
242 |
243 | img = wandb.Image(viz_img)
244 | wandb.log({f"Gamma Pred": img}, step=step)
245 |
246 | wandb.log({"reward loss": reward_loss}, step=step)
247 | wandb.log({"gpt loss": gpt_loss}, step=step)
248 |
249 | # save gpt model
250 | if step in self.saving_iter:
251 | print("saving gpt weights...")
252 |             self.save_gpt_weights(step)
253 |
254 | return metrics
255 |
256 | def save_gpt_weights(self, step):
257 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "vqvae_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step)))
258 |
259 | def get_value_estimates(self, y_context, values, obs_shape):
260 | # Take a state, get samples from the gamma distribution,
261 | # Run the reward predictor through these to get value estimates
262 | # Get ground truth value estimates by simply taking discounted sum of rewards
263 | # Compare these for different states
264 |
265 | num_gamma_samples = 100
266 | values_pred = []
267 |
268 | for i in range(num_gamma_samples):
269 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100)
270 | p_t = outputs.sequences[:, -y_context.shape[-1]:]
271 |
272 | quant = self.quantizer.get_codebook_entry(p_t, None)
273 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
274 |
275 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1))
276 |
277 |         values_pred = torch.stack(values_pred).sum(0) / (num_gamma_samples * (1 - self.discount))
278 |
279 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean()
280 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5])
281 |
282 | return value_estimation_loss
283 |
284 | def get_density_value_estimates(self, y_context, all_obs, obs_shape):
285 |
286 | values_pred = []
287 | values_actual = []
288 | for i in range(all_obs.shape[0]-1):
289 | y_target = all_obs[i+1].long()
290 | quant_target = self.quantizer.get_codebook_entry(y_target, None)
291 | quant_target = quant_target.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
292 |
293 | inp = torch.cat([y_context, y_target], dim=1)
294 | outputs = self.gpt(inp, labels=inp)
295 | logits = outputs.logits[:, -y_target.shape[1]-1:-1]
296 | scores = torch.nn.functional.log_softmax(logits, dim=2)
297 |
298 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2))
299 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2))
300 |
301 | input_length = y_target.shape[1]
302 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1)
303 | prob = torch.exp(gathered_scores.sum(1) / output_length)
304 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
305 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
306 |
307 | values_pred = torch.stack(values_pred).sum(0)
308 |
309 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda'))
310 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
311 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0)
312 |
313 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean()
314 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5])
315 |
316 | return value_estimation_loss
317 |
318 | import argparse
319 |
320 | if __name__ == "__main__":
321 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.')
322 |
323 | parser.add_argument('--discount', type=float, default=0.8, help='discount')
324 | args = parser.parse_args()
325 |
326 | agent = TFAgent(discount=args.discount)
327 |
328 | agent.optimizer.zero_grad()
329 | agent.gpt_optimizer.zero_grad()
330 |
331 | for step in range(100000):
332 | agent.update(step)
--------------------------------------------------------------------------------
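The core of TFAgent.update above is the Bernoulli-gamma construction of the GPT training target: with probability 1-gamma keep the next-state tokens, with probability gamma substitute tokens sampled from the bootstrapped target model. The sketch below restates that step in isolation; make_gamma_target, y_next, and model_sample are placeholder names for illustration and not part of the codebase.

import torch

# Minimal, self-contained restatement of the gamma-sampling target used to train
# the occupancy GPT. `y_next` and `model_sample` are placeholder tensors standing
# in for the next-state VQ codes and for gpt_target.generate(...) output.
def make_gamma_target(y_next: torch.Tensor, model_sample: torch.Tensor, gamma: float) -> torch.Tensor:
    prob = torch.bernoulli(gamma * torch.ones(y_next.shape[0], device=y_next.device))
    target = y_next.clone()                      # rows kept with prob 1-gamma
    boot_idx = torch.nonzero(prob).squeeze(-1)   # rows that bootstrap from the model
    target[boot_idx] = model_sample[boot_idx]
    return target

y_next = torch.randint(0, 1024, (4, 75))         # next-state VQ code indices (toy shapes)
model_sample = torch.randint(0, 1024, (4, 75))   # stand-in for target-model samples
p_target = make_gamma_target(y_next, model_sample, gamma=0.8)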
/vae/train_pixel_vq_vae_voc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '../utils/')
3 |
4 | from transformers.optimization import get_scheduler
5 | from transformers import GPT2LMHeadModel, GPT2Config
6 |
7 | from taming.models.vqgan import VQModel
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from pathlib import Path
15 | from matplotlib import pyplot as plt
16 |
17 | import random
18 | import utils
19 | import os
20 | import math
21 |
22 | from torchvision.utils import make_grid
23 | from drqv2 import RandomShiftsAug
24 |
25 | from numpy_replay_buffer import EfficientReplayBuffer
26 | from utils import load_offline_dataset_into_buffer
27 |
28 | from dm_env import specs
29 | import dmc
30 |
31 | import copy
32 | import wandb
33 | from omegaconf import OmegaConf
34 |
35 |
36 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average
37 |
38 | ######################
39 | batch_size = 32
40 |
41 | taming_path = "../../" # taming models path
42 | offline_dir = "../../walker_train/vq_latents/" # dataset dir path
43 | vq_model_path = "../../walker_vq_model_step.pth" # vqvae model path
44 | save_dir_path = "./" # voc's gpt model save path
45 |
46 | ######################
47 |
48 | class TFAgent:
49 | def __init__(self, discount=0.8, codebook_size=1024, augmentation=RandomShiftsAug(pad=4)):
50 |
51 | wandb.init(project="Video Occupancy Models",
52 | id="voc_vqvae_gamma_{}_td_{}".format(discount, codebook_size),
53 | entity=None, dir=os.getcwd())
54 |
55 | self.train_env = dmc.make("offline_walker_walk_expert", 3, 2, 0, None)
56 | data_specs = (self.train_env.observation_spec(),
57 | self.train_env.action_spec(),
58 | specs.Array((1,), np.float32, 'reward'),
59 | specs.Array((1,), np.float32, 'discount'))
60 |
61 | self.discount = discount
62 | self.codebook_size = codebook_size
63 | self.batch_size = batch_size
64 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs, pixel_samples=True)
65 |
66 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000)
67 |
68 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml")
69 | config = OmegaConf.load(config_path)
70 |
71 | config.model.params.n_embed = self.codebook_size
72 | self.model = VQModel(**config.model.params).to('cuda')
73 |
74 | self.from_imagenet = False
75 | if self.from_imagenet:
76 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt")
77 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"]
78 | missing, unexpected = self.model.load_state_dict(sd, strict=False)
79 | else:
80 | print("Loading fine-tuned VQ model...")
81 | self.model.load_state_dict(torch.load(vq_model_path))
82 |
83 |         configuration = GPT2Config(vocab_size=self.codebook_size, n_layer=4, n_head=8, n_embd=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_pdrop=0.2)
84 | self.gpt = GPT2LMHeadModel(configuration).to('cuda')
85 |
86 | self.gpt_target = copy.deepcopy(self.gpt)
87 | self.target_ema_updater = EMA(0.9)
88 |
89 | self.gpt_target.generation_config.output_scores = True
90 |
91 | self.reward_predictor = Predictor(256 * 25 * 3).to('cuda')
92 |
93 | self.gpt_optimizer = torch.optim.AdamW(self.get_grouped_params(), lr=3e-4)
94 | self.optimizer = torch.optim.AdamW(list(self.model.parameters()) + list(self.reward_predictor.parameters()), lr=3e-4)
95 |
96 | num_training_steps = 100000
97 | self.warmup_ratio = 0.05
98 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio)
99 | self.lr_scheduler = get_scheduler(
100 | "cosine",
101 | optimizer=self.gpt_optimizer,
102 | num_warmup_steps=warmup_steps,
103 | num_training_steps=num_training_steps,
104 | )
105 |
106 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda')
107 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda')
108 |
109 | self.device = 'cuda'
110 | self.aug = augmentation
111 |
112 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000]
113 | self.train()
114 |
115 | def get_grouped_params(self):
116 | decay_parameters = get_parameter_names(self.gpt, [nn.LayerNorm])
117 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
118 | optimizer_grouped_parameters = [
119 | {
120 | "params": [
121 | p for n, p in self.gpt.named_parameters() if (n in decay_parameters and p.requires_grad)
122 | ],
123 | "weight_decay": 0.1,
124 | },
125 | {
126 | "params": [
127 | p for n, p in self.gpt.named_parameters() if (n not in decay_parameters and p.requires_grad)
128 | ],
129 | "weight_decay": 0.0,
130 | },
131 | ]
132 | return optimizer_grouped_parameters
133 |
134 | def train(self, training=True):
135 | self.training = training
136 |
137 | def update(self, step=0):
138 | metrics = dict()
139 |
140 | batch, indices = next(self.replay_buffer)
141 | obs, action, reward, discount, next_obs, latent, _, _, obs_k = utils.to_torch(
142 | batch, self.device)
143 |
144 | # augment
145 | obs = self.aug(obs.float())
146 | obs_k = self.aug(next_obs.float())
147 |
148 | # reshape obs
149 | obs = F.interpolate(obs, size=80)
150 | obs_k = F.interpolate(obs_k, size=80)
151 |
152 | obs_shape = obs.shape
153 |
154 | # normalize and preprocess
155 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
156 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:]))
157 |
158 | obs_k = torch.stack([torch.einsum('nchw->nhwc', obs_k[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
159 | obs_k = torch.einsum('tnhwc->ntchw', obs_k).reshape((obs_shape[0] * 3, 3, *obs_shape[2:]))
160 |
161 | # vq embed
162 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
163 | quant_target, emb_loss_target, info_target = self.model.encode(obs_k)
164 |
165 | # collect discrete vq indices
166 | y_context = info_context[2].view(obs_shape[0], -1).detach()
167 | y_target = info_target[2].view(obs_shape[0], -1).detach()
168 |
169 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1))
170 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean()
171 |
172 | # generate target
173 | with torch.no_grad():
174 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100)
175 | p_t = p_t[:, -y_target.shape[-1]:]
176 |
177 | # gamma sampling
178 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device)
179 | prob = torch.bernoulli(gamma)
180 | p_target = torch.zeros_like(y_target)
181 |
182 | # with prob 1-gamma, sample from next state
183 | p_c_idx = torch.nonzero(1 - prob)
184 | p_target[p_c_idx] = y_target[p_c_idx]
185 |
186 | # with prob gamma, sample from bootstrapped model
187 | p_t_idx = torch.nonzero(prob)
188 | p_target[p_t_idx] = p_t[p_t_idx]
189 |
190 | xrec, qloss = self.model.decode(quant_context), emb_loss_context
191 | vae_loss, log_dict_ae = self.model.loss(qloss, obs, xrec, 0, step, last_layer=self.model.get_last_layer(), split="train")
192 |
193 | # gpt predictions
194 | inp = torch.cat([y_context, p_target], dim=1)
195 | outputs = self.gpt(inp, labels=inp)
196 | gpt_loss = outputs.loss
197 |
198 | loss = vae_loss + gpt_loss + reward_loss
199 |
200 | loss.backward()
201 |
202 | # grad accumulate
203 | if step % 2 == 0:
204 | self.optimizer.step()
205 | self.gpt_optimizer.step()
206 |
207 | self.optimizer.zero_grad()
208 | self.gpt_optimizer.zero_grad()
209 |
210 | self.lr_scheduler.step()
211 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt)
212 |
213 | # visualize predictions
214 | if step % 100 == 0:
215 | with torch.no_grad():
216 | # sample a batch of traj and corresponding values
217 | batch, indices = self.replay_buffer.sample_spr()
218 | _, _, _, _, _, _, all_obs, _, values = utils.to_torch(batch, self.device)
219 |
220 | # preprocess first obs from traj
221 | obs = F.interpolate(all_obs[0], size=80)
222 | obs_shape = obs.shape
223 |
224 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
225 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs)
226 |
227 | # vq embed first obs
228 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
229 | y_context = info_context[2].view(obs_shape[0], -1).detach()
230 |
231 | # sample target predictions
232 | p_t = self.gpt.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, pad_token_id=-100)[:, -y_context.shape[-1]:]
233 |
234 |                 # reconstruct the encoded context and the sampled prediction
235 |                 quant = self.model.quantize.get_codebook_entry(y_context, None)
236 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
237 |
238 | y_pixel_recon = self.model.decode(quant)
239 |
240 |                 quant = self.model.quantize.get_codebook_entry(p_t, None)
241 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
242 |
243 | p_t_pixel_recon = self.model.decode(quant)
244 |
245 | viz_imgs = []
246 | viz_imgs.append(p_t_pixel_recon)
247 | viz_imgs.append(y_pixel_recon)
248 | all_obs = self.paint_obs(all_obs)
249 | for i in range(all_obs.shape[0]):
250 | obs = all_obs[i]
251 | obs = F.interpolate(obs, size=80)
252 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
253 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs)
254 | viz_imgs.append(obs)
255 |
256 | value_loss = self.get_value_estimates(y_context, values, obs_shape)
257 | wandb.log({"value loss": value_loss}, step=step)
258 | # density_value_loss = self.get_density_value_estimates(y_context, viz_imgs, obs_shape)
259 | # wandb.log({"density value loss": density_value_loss}, step=step)
260 |
261 | viz_imgs = torch.stack(viz_imgs)[:, :8]
262 | t, n, c, h, w = viz_imgs.shape
263 | viz_imgs = torch.einsum('tnchw->ntchw', viz_imgs)
264 | viz_imgs = viz_imgs.reshape(t*n, c, h, w)
265 | viz_img = make_grid(viz_imgs, nrow=t, normalize=True, scale_each=True)
266 |
267 | img = wandb.Image(viz_img)
268 | wandb.log({f"Gamma Pred": img}, step=step)
269 |
270 | wandb.log({"reward loss": reward_loss}, step=step)
271 |
272 | # if finetuning, save vq model at 2k steps
273 | if step in self.saving_iter:
274 | print("saving gpt weights...")
275 | self.save_vq_weights(step)
276 | self.save_gpt_weights(step)
277 |
278 | return metrics
279 |
280 | def save_vq_weights(self, step):
281 | torch.save(self.model.state_dict(), os.path.join(save_dir_path, "vqvae_model_{}_td_{}_step_{}.pth".format(self.discount, self.codebook_size, step)))
282 |
283 | def save_gpt_weights(self, step):
284 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "pixel_vqvae_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step)))
285 |
286 | def get_value_estimates(self, y_context, values, obs_shape):
287 | # Take a state, get samples from the gamma distribution,
288 | # Run the reward predictor through these to get value estimates
289 | # Get ground truth value estimates by simply taking discounted sum of rewards
290 | # Compare these for different states
291 |
292 | num_gamma_samples = 100
293 | values_pred = []
294 |
295 | for i in range(num_gamma_samples):
296 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100)
297 | p_t = outputs.sequences[:, -y_context.shape[-1]:]
298 |
299 | quant = self.model.quantize.get_codebook_entry(p_t, None)
300 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
301 |
302 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1))
303 |
304 |         values_pred = torch.stack(values_pred).sum(0) / (num_gamma_samples * (1 - self.discount))
305 |
306 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean()
307 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5])
308 |
309 | return value_estimation_loss
310 |
311 | def get_density_value_estimates(self, y_context, all_obs, obs_shape):
312 |
313 | values_pred = []
314 | values_actual = []
315 | for i in range(all_obs.shape[0]-1):
316 | quant_target, _, info_target = self.model.encode(all_obs[i+1])
317 | y_target = info_target[2].view(obs_shape[0], -1).detach()
318 |
319 | inp = torch.cat([y_context, y_target], dim=1)
320 | outputs = self.gpt(inp, labels=inp)
321 | logits = outputs.logits[:, -y_target.shape[1]-1:-1]
322 | scores = torch.nn.functional.log_softmax(logits, dim=2)
323 |
324 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2))
325 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2))
326 |
327 | input_length = y_target.shape[1]
328 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1)
329 | prob = torch.exp(gathered_scores.sum(1) / output_length)
330 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
331 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
332 |
333 | values_pred = torch.stack(values_pred).sum(0)
334 |
335 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda'))
336 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
337 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0)
338 |
339 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean()
340 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5])
341 |
342 | return value_estimation_loss
343 |
344 | def paint_obs(self, all_obs):
345 | for i in range(all_obs.shape[0]):
346 | obs = all_obs[i] # get first of all frame of all batches, B x 9 x 80 x 80
347 |
348 | # Every ith frame is colored one way, for all batches
349 | obs[:, ::3, :5, -5:] = 255.0 * i / all_obs.shape[0] # set top corner for first channel to 255
350 | obs[:, 1::3, :5, -5:] = 0.0 # set top corner for second channel to 0
351 | obs[:, 2::3, :5, -5:] = 0.0 # set top corner for third channel to 0
352 |
353 | all_obs[i] = obs
354 |
355 | return all_obs
356 |
357 |
358 | import argparse
359 |
360 | if __name__ == "__main__":
361 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.')
362 |
363 | parser.add_argument('--discount', type=float, default=0.8, help='discount')
364 | parser.add_argument('--codebook_size', type=int, default=1024, help='codebook size')
365 | args = parser.parse_args()
366 |
367 | agent = TFAgent(discount=args.discount, codebook_size=args.codebook_size)
368 |
369 | agent.optimizer.zero_grad()
370 | agent.gpt_optimizer.zero_grad()
371 |
372 | for step in range(100000):
373 | agent.update(step)
--------------------------------------------------------------------------------
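get_value_estimates in the script above amounts to a Monte-Carlo estimate of V(s) = E[r(s')] / (1 - gamma), with s' drawn from the learned gamma model and r(.) given by the reward predictor. The sketch below isolates that arithmetic on random stand-in rewards; value_from_gamma_samples is a hypothetical helper, not a function from the repository.

import torch

# Illustrative only: sample mean of predicted rewards over gamma-model samples,
# scaled by 1 / (1 - gamma), matching sum(...) / (num_samples * (1 - discount)).
def value_from_gamma_samples(sample_rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    # sample_rewards: (num_samples, batch) predicted rewards of sampled states
    return sample_rewards.mean(dim=0) / (1.0 - gamma)

sample_rewards = torch.rand(100, 32)   # random stand-in for reward_predictor outputs
values = value_from_gamma_samples(sample_rewards, gamma=0.8)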
/musik/train_vq_musik_voc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '../utils/')
3 |
4 | from musik_model import VQMUSIKModel
5 |
6 | import os
7 | import math
8 | import copy
9 | import wandb
10 | import random
11 | import utils
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from omegaconf import OmegaConf
18 |
19 | from pathlib import Path
20 | from matplotlib import pyplot as plt
21 | from torchvision.utils import make_grid
22 |
23 | from numpy_replay_buffer import EfficientReplayBuffer
24 | from utils import load_offline_dataset_into_buffer
25 |
26 | from transformers.optimization import get_scheduler
27 | from transformers import GPT2LMHeadModel, GPT2Config
28 |
29 | import dmc
30 | from dm_env import specs
31 | from drqv2 import RandomShiftsAug
32 |
33 | from network_utils import Predictor, EMA, get_parameter_names, update_moving_average
34 | from network_utils import InfoNCE
35 |
36 |
37 |
38 | ##################
39 | obs_resize = 80
40 | mae_seq_len = 65 # no. of tokens actually processed once masking is done
41 | embed_size = 128
42 |
43 | action_dim = 6
44 |
45 | taming_path = "../../" # taming models path
46 | offline_dir = "../../cheetah_train/vq_latents/"
47 | save_dir_path = "./" # voc's gpt model save path
48 |
49 | batch_size = 32
50 | ##################
51 |
52 | class TFAgent:
53 | def __init__(self, discount=0.8, codebook_size=1024, augmentation=RandomShiftsAug(pad=4)):
54 |
55 | wandb.init(project="Video Occupancy Models",
56 | id="voc_musik_gamma_{}_{}".format(discount, codebook_size),
57 | entity=None, dir=os.getcwd())
58 |
59 | self.train_env = dmc.make("offline_cheetah_run_expert", 3, 2, 0, None)
60 | data_specs = (self.train_env.observation_spec(),
61 | self.train_env.action_spec(),
62 | specs.Array((1,), np.float32, 'reward'),
63 | specs.Array((1,), np.float32, 'discount'))
64 |
65 | self.discount = discount
66 | self.batch_size = batch_size
67 | self.nsteps = min(int(1 / (1 - self.discount)), 10) # only sample max 10 steps
68 | self.codebook_size = codebook_size
69 |
70 | ######### random and mixed data buffers #########
71 | self.replay_buffer = EfficientReplayBuffer(1000000, self.batch_size, 1, self.discount, 3, False, data_specs, pixel_samples=True)
72 | load_offline_dataset_into_buffer(Path(offline_dir), self.replay_buffer, None, 3, 1000000, future_sampling_steps=15)
73 | ##########################
74 |
75 | ######### vq-musik setup #########
76 | config_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/configs/model.yaml")
77 | config = OmegaConf.load(config_path)
78 | # config.model.params.ddconfig.in_channels = 9
79 | config.model.params.n_embed = self.codebook_size
80 | self.model = VQMUSIKModel(**config.model.params).to('cuda')
81 |
82 | self.from_imagenet = False
83 | if self.from_imagenet:
84 | ckpt_path = os.path.join(taming_path, "vqgan_imagenet_f16_1024/ckpts/last.ckpt")
85 | sd = torch.load(ckpt_path, map_location="cuda")["state_dict"]
86 | missing, unexpected = self.model.encoder.load_state_dict(sd, strict=False)
87 |
88 | # self.model.load_state_dict(torch.load("/home/manant/scratch/vq_model_7000.pth"), strict=False)
89 |
90 | # Musik predictor network for training the representation
91 | self.musik_predictor = InfoNCE(1542, action_dim, 1).to('cuda')
92 | ##########################
93 |
94 | ######### gpt setup #########
95 |         configuration = GPT2Config(vocab_size=self.codebook_size, n_layer=4, n_head=8, n_embd=512, resid_pdrop=0.2, embd_pdrop=0.2, attn_pdrop=0.2)
96 | self.gpt = GPT2LMHeadModel(configuration).to('cuda')
97 | self.gpt_target = copy.deepcopy(self.gpt)
98 | self.gpt_target.generation_config.output_scores = True
99 | self.target_ema_updater = EMA(0.9)
100 | self.encoder_target_ema_updater = EMA(0.9)
101 | ##########################
102 |
103 | ######### optimizer setup #########
104 |         self.reward_predictor = Predictor(19200).to('cuda') # quantized dim x spatial tokens x frame_stack = 256 * 25 * 3
105 | self.gpt_optimizer = torch.optim.AdamW(list(self.get_grouped_params(self.gpt)), lr=3e-4)
106 | self.optimizer = torch.optim.AdamW(list(self.reward_predictor.parameters()) + list(self.musik_predictor.parameters()) + list(self.model.parameters()), lr=3e-4)
107 |
108 | num_training_steps = 100000
109 | self.warmup_ratio = 0.05
110 | warmup_steps = math.ceil(num_training_steps * self.warmup_ratio)
111 | self.lr_scheduler = get_scheduler(
112 | "cosine",
113 | optimizer=self.gpt_optimizer,
114 | num_warmup_steps=warmup_steps,
115 | num_training_steps=num_training_steps,
116 | )
117 | ##########################
118 |
119 | self.imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).to('cuda')
120 | self.imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).to('cuda')
121 |
122 | self.device = 'cuda'
123 | self.aug = augmentation
124 |
125 | self.saving_iter = [50, 100, 500, 1000, 2000, 5000, 10000, 50000, 75000, 100000]
126 | self.train()
127 | self.model.train()
128 |
129 | def get_grouped_params(self, model):
130 | decay_parameters = get_parameter_names(model, [nn.LayerNorm])
131 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
132 | optimizer_grouped_parameters = [
133 | {
134 | "params": [
135 | p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)
136 | ],
137 | "weight_decay": 0.1,
138 | },
139 | {
140 | "params": [
141 | p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)
142 | ],
143 | "weight_decay": 0.0,
144 | },
145 | ]
146 | return optimizer_grouped_parameters
147 |
148 | def train(self, training=True):
149 | self.training = training
150 |
151 | def preprocess_obs(self, obs):
152 | obs = F.interpolate(obs, size=obs_resize)
153 |
154 | if len(obs.shape) == 3:
155 | obs = obs.unsqueeze(0)
156 | org_obs_shape = obs.shape
157 |
158 |         if len(obs.shape) == 4: # B x C x H x W
159 |             org_obs_shape = obs.shape
160 |             # normalize and preprocess
161 |             obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
162 |             obs = torch.einsum('snhwc->nschw', obs).reshape((org_obs_shape[0] * 3, 3, *org_obs_shape[2:]))
163 |         else:
164 |             assert len(obs.shape) == 5 # T x B x C x H x W
165 |             org_obs_shape = t, b, c, h, w = obs.shape
166 |             obs = torch.stack([torch.einsum('tnchw->tnhwc', obs[:, :, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std
167 |                                for i in range(3)])
168 |             obs = torch.einsum('stnhwc->tnschw', obs).reshape((t, b * 3, 3, h, w))
169 | return obs, org_obs_shape
170 |
171 | def update(self, step=0):
172 | metrics = dict()
173 |
174 | batch, indices = next(self.replay_buffer)
175 | obs, action, reward, discount, next_obs, _, _, _, obs_k = utils.to_torch(batch, self.device)
176 |
177 | # augment
178 | obs = self.aug(obs.float())
179 | next_obs = self.aug(next_obs.float())
180 | obs_k = self.aug(obs_k.float())
181 |
182 | obs, obs_shape = self.preprocess_obs(obs) # process current obs
183 | next_obs, _ = self.preprocess_obs(next_obs)
184 | obs_k, _ = self.preprocess_obs(obs_k) # process next/future obs
185 |
186 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
187 | quant_future_target, emb_loss_future_target, info_future_target = self.model.encode(obs_k)
188 | quant_target, emb_loss_target, info_target = self.model.encode(next_obs)
189 |
190 | y_context = info_context[2].view(obs_shape[0], -1).detach()
191 | y_target = info_target[2].view(obs_shape[0], -1).detach()
192 |
193 | pred_reward = self.reward_predictor(quant_target.detach().float().reshape(obs_shape[0], -1))
194 | reward_loss = F.mse_loss(pred_reward, reward.float()).mean()
195 |
196 | # generate target
197 | with torch.no_grad():
198 | p_t = self.gpt_target.generate(y_target, max_new_tokens=y_target.shape[-1], do_sample=True, pad_token_id=-100)
199 | p_t = p_t[:, -y_target.shape[-1]:]
200 |
201 | # gamma sampling
202 | gamma = self.discount * torch.ones((y_context.shape[0], ), device=y_context.device)
203 | prob = torch.bernoulli(gamma)
204 | p_target = torch.zeros_like(y_target)
205 |
206 | # with prob 1-gamma, sample from next state
207 | p_c_idx = torch.nonzero(1 - prob)
208 | p_target[p_c_idx] = y_target[p_c_idx]
209 |
210 | # with prob gamma, sample from bootstrapped model
211 | p_t_idx = torch.nonzero(prob)
212 | p_target[p_t_idx] = p_t[p_t_idx]
213 |
214 | # gpt predictions
215 | inp = torch.cat([y_context, p_target], dim=1)
216 | outputs = self.gpt(inp, labels=inp)
217 | gpt_loss = outputs.loss
218 |
219 | enc = self.model.decode_linear(quant_context).reshape((obs_shape[0], -1))
220 | enc_k = self.model.decode_linear(quant_future_target).reshape((obs_shape[0], -1))
221 |
222 | musik_loss = self.musik_predictor(enc, enc_k, action) + emb_loss_context + emb_loss_future_target # musik loss + codebook loss
223 |
224 | loss = reward_loss + gpt_loss + musik_loss
225 | loss.backward()
226 |
227 | # grad accumulate
228 | if step % 2 == 0:
229 | self.optimizer.step()
230 | self.gpt_optimizer.step()
231 |
232 | self.optimizer.zero_grad()
233 | self.gpt_optimizer.zero_grad()
234 |
235 | self.lr_scheduler.step()
236 | update_moving_average(self.target_ema_updater, self.gpt_target, self.gpt)
237 |
238 | # visualize predictions
239 | if step % 200 == 0:
240 | with torch.no_grad():
241 | # sample a batch of traj and corresponding values
242 | batch, indices = self.replay_buffer.sample_spr()
243 | _, _, _, _, _, _, all_obs, _, values = utils.to_torch(batch, self.device)
244 |
245 | # preprocess first obs from traj
246 | obs = F.interpolate(all_obs[0], size=obs_resize)
247 | obs_shape = obs.shape
248 |
249 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
250 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:])) #torch.einsum('nhwc->nchw', obs)
251 | # obs = torch.einsum('snhwc->nschw', obs).reshape((obs_shape[0], 9, *obs_shape[2:]))
252 |
253 | # vq embed first obs
254 | quant_context, emb_loss_context, info_context = self.model.encode(obs)
255 | y_context = info_context[2].view(obs_shape[0], -1).detach()
256 |
257 | value_loss = self.get_value_estimates(y_context, values, obs_shape)
258 | wandb.log({"value loss": value_loss}, step=step)
259 |
260 | density_value_loss = self.get_density_value_estimates(y_context, all_obs, obs_shape)
261 | wandb.log({"density value loss": density_value_loss}, step=step)
262 |
263 | print("losses are", reward_loss, musik_loss, gpt_loss, emb_loss_context)
264 |
265 | wandb.log({"gpt loss": gpt_loss}, step=step)
266 | wandb.log({"rep loss": musik_loss}, step=step)
267 | wandb.log({"reward loss": reward_loss}, step=step)
268 |
269 | # save gpt model
270 | if step in self.saving_iter:
271 | print("saving gpt weights...")
272 | self.save_musik_weights(step)
273 | self.save_gpt_weights(step)
274 |
275 | return metrics
276 |
277 | def save_musik_weights(self, step):
278 | torch.save(self.model.state_dict(), os.path.join(save_dir_path, "vq_musik_model_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step)))
279 |
280 | def save_gpt_weights(self, step):
281 | torch.save(self.gpt.state_dict(), "/home/manant/scratch/pixel_gamma/checkpoints/may_3_runs/musik/vq_conv_musik_microgpt_gamma_{}_{}_{}_model_step_{}.pth".format(self.discount, self.target_style, self.codebook_size, step))
282 | torch.save(self.gpt.state_dict(), os.path.join(save_dir_path, "musik_microgpt_gamma_{}_{}_model_step_{}.pth".format(self.discount, self.codebook_size, step)))
283 |
284 | def get_value_estimates(self, y_context, values, obs_shape):
285 | # Take a state, get samples from the gamma distribution,
286 | # Run the reward predictor through these to get value estimates
287 | # Get ground truth value estimates by simply taking discounted sum of rewards
288 | # Compare these for different states
289 |
290 | num_gamma_samples = 100
291 | values_pred = []
292 |
293 | for i in range(num_gamma_samples):
294 | outputs = self.gpt_target.generate(y_context, max_new_tokens=y_context.shape[-1], do_sample=True, output_scores=True, return_dict_in_generate=True, pad_token_id=-100) #, kwargs={'token_type_ids': context_mask_ids})
295 | p_t = outputs.sequences[:, -y_context.shape[-1]:]
296 |
297 | quant = self.model.quantize.get_codebook_entry(p_t, None)
298 | quant = quant.view(-1, 5, 5, 256).permute(0, 3, 1, 2)
299 |
300 | values_pred.append(self.reward_predictor(quant.float().reshape(obs_shape[0], -1)).squeeze(1))
301 |
302 |         values_pred = torch.stack(values_pred).sum(0) / (num_gamma_samples * (1 - self.discount))
303 |
304 | value_estimation_loss = F.mse_loss(values_pred, values.squeeze(1).float()).mean()
305 | print("val estimation", value_estimation_loss, values_pred[:5], values[:5])
306 |
307 | return value_estimation_loss
308 |
309 | def get_density_value_estimates(self, y_context, all_obs, obs_shape):
310 |
311 | values_pred = []
312 | values_actual = []
313 | for i in range(all_obs.shape[0]-1):
314 | obs = F.interpolate(all_obs[i+1], size=obs_resize)
315 | obs_shape = obs.shape
316 |
317 | obs = torch.stack([torch.einsum('nchw->nhwc', obs[:, i*3:3+i*3] / 255.) - self.imagenet_mean / self.imagenet_std for i in range(3)])
318 | obs = torch.einsum('tnhwc->ntchw', obs).reshape((obs_shape[0] * 3, 3, *obs_shape[2:]))
319 | quant_target, emb_loss_target, info_target = self.model.encode(obs)
320 | y_target = info_target[2].view(obs_shape[0], -1).detach()
321 |
322 | inp = torch.cat([y_context, y_target], dim=1)
323 | outputs = self.gpt(inp, labels=inp)
324 | logits = outputs.logits[:, -y_target.shape[1]-1:-1]
325 | scores = torch.nn.functional.log_softmax(logits, dim=2)
326 |
327 | gathered_scores = torch.gather(scores, dim=2, index=y_target.unsqueeze(2))
328 | gathered_logits = torch.gather(logits, dim=2, index=y_target.unsqueeze(2))
329 |
330 | input_length = y_target.shape[1]
331 | output_length = input_length + torch.sum(gathered_logits < 0, dim=1)
332 | prob = torch.exp(gathered_scores.sum(1) / output_length)
333 | values_pred.append(prob.squeeze(1) * self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
334 | values_actual.append(self.reward_predictor(quant_target.float().reshape(obs_shape[0], -1)).squeeze(1))
335 |
336 | values_pred = torch.stack(values_pred).sum(0)
337 |
338 | discount_vec = torch.pow(self.discount, torch.arange(torch.stack(values_actual).shape[0], device='cuda'))
339 | # Could implement below operation as a matmul in pytorch for marginal additional speed improvement
340 | values_actual = torch.sum(torch.stack(values_actual) * discount_vec.repeat(torch.stack(values_actual).shape[1], 1).T, dim=0)
341 |
342 | value_estimation_loss = F.mse_loss(values_pred, values_actual).mean()
343 | print("density val estimation", value_estimation_loss, values_pred[:5], values_actual[:5])
344 |
345 | return value_estimation_loss
346 |
347 |
348 | import argparse
349 |
350 | if __name__ == "__main__":
351 | parser = argparse.ArgumentParser(description='Create a dictionary with command-line arguments.')
352 |
353 | parser.add_argument('--discount', type=float, default=0.8, help='discount')
354 | parser.add_argument('--codebook_size', type=int, default=1024, help='codebook size')
355 | args = parser.parse_args()
356 |
357 | agent = TFAgent(discount=args.discount, codebook_size=args.codebook_size)
358 |
359 | agent.optimizer.zero_grad()
360 | agent.gpt_optimizer.zero_grad()
361 |
362 | for step in range(100000):
363 | agent.update(step)
--------------------------------------------------------------------------------
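The density-based value estimate above weights predicted rewards by a sequence likelihood read off the GPT logits. The sketch below shows that gather-and-normalize step on placeholder tensors; it uses plain length normalization, whereas the code above also counts negative gathered logits in the denominator, so treat it as an approximation of the idea rather than the exact computation.

import torch
import torch.nn.functional as F

# Rough sketch: gather the log-probabilities of the target tokens from the GPT
# logits at the target positions, length-normalize, and exponentiate to get a
# per-sequence probability. Shapes and tensors are illustrative stand-ins.
def sequence_prob(target_logits: torch.Tensor, y_target: torch.Tensor) -> torch.Tensor:
    scores = F.log_softmax(target_logits, dim=2)                          # (B, T, vocab)
    gathered = torch.gather(scores, dim=2, index=y_target.unsqueeze(2))   # (B, T, 1)
    return torch.exp(gathered.sum(dim=1) / y_target.shape[1])             # (B, 1)

target_logits = torch.randn(4, 75, 1024)      # placeholder logits for the target span
y_target = torch.randint(0, 1024, (4, 75))    # placeholder target VQ codes
prob = sequence_prob(target_logits, y_target)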
/dino/modeling_vqkd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2
4 | # Copyright (c) 2022 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Zhiliang Peng
7 | # Based on VQGAN code bases
8 | # https://github.com/CompVis/taming-transformers
9 | # --------------------------------------------------------'
10 |
11 | import torch
12 | import numpy as np
13 | from torch import nn, einsum
14 | import torch.nn.functional as F
15 | import math
16 | from collections import OrderedDict
17 | from functools import partial, reduce
18 | from einops import rearrange
19 | from timm.models.layers import trunc_normal_
20 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
21 | from timm.models.registry import register_model
22 |
23 | from modeling_finetune import VisionTransformer
24 | from norm_ema_quantizer import NormEMAVectorQuantizer
25 |
26 | from vqkd_teacher import get_dino_vit_base, clip  # clip is needed for the CLIP-teacher path and the *_clip model variants below
27 |
28 | class VQKD(nn.Module):
29 | def __init__(self,
30 | encoder_config,
31 | decoder_config,
32 | n_embed=8192,
33 | embed_dim=32,
34 | decay=0.99,
35 | process_type='default',
36 | quantize_kmeans_init=True,
37 | teacher_model_type='clip',
38 | decoder_out_dim=512,
39 | rec_loss_type='cosine',
40 | **kwargs
41 | ):
42 | super().__init__()
43 | print(kwargs)
44 | if decoder_config['in_chans'] != embed_dim:
45 |             print(f"Rewriting decoder in_chans from {decoder_config['in_chans']} to {embed_dim}")
46 | decoder_config['in_chans'] = embed_dim
47 |
48 | # encoder & decode params
49 | print('Final encoder config', encoder_config)
50 | self.encoder = VisionTransformer(**encoder_config)
51 |
52 | print('Final decoder config', decoder_config)
53 | self.decoder = VisionTransformer(**decoder_config)
54 |
55 | self.quantize = NormEMAVectorQuantizer(
56 | n_embed=n_embed, embedding_dim=embed_dim, beta=1.0, kmeans_init=quantize_kmeans_init, decay=decay,
57 | )
58 |
59 | self.patch_size = encoder_config['patch_size']
60 | self.token_shape = (encoder_config['img_size'] // self.patch_size, encoder_config['img_size'] // self.patch_size)
61 |
62 | ## Teacher model setting
63 | self.teacher_model_type = teacher_model_type
64 | self.decoder_out_dim = decoder_out_dim
65 | if self.teacher_model_type == 'clip':
66 | self.scaling_layer = ScalingLayerForClip()
67 | self.teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False)
68 | self.decoder_out_dim = 512
69 |
70 | elif self.teacher_model_type == 'dino':
71 | self.scaling_layer = ScalingLayerForIM()
72 | self.teacher_model = get_dino_vit_base()
73 | self.decoder_out_dim = 768
74 |
75 | else:
76 | self.teacher_model = None
77 |
78 | if self.teacher_model is not None:
79 | for param in self.teacher_model.parameters():
80 |                 param.requires_grad = False  # freeze the teacher model
81 |
82 | self.teacher_model.eval()
83 | self.teacher_input_size = kwargs.get('teacher_input_size', 224)
84 |
85 | # task layer
86 | self.encode_task_layer = nn.Sequential(
87 | nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
88 | nn.Tanh(),
89 | nn.Linear(encoder_config['embed_dim'], embed_dim) # for quantize
90 | )
91 | self.decode_task_layer = nn.Sequential(
92 | nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
93 | nn.Tanh(),
94 | nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim),
95 | )
96 |
97 | self.rec_loss_type = rec_loss_type
98 |
99 | print(f"process type for VQKD: {process_type}")
100 | self.process_type = process_type # in ['default', 'dall-e']
101 | self.logit_laplace_eps = 0.1
102 | self.kwargs = kwargs
103 |
104 | self.encode_task_layer.apply(self._init_weights)
105 | self.decode_task_layer.apply(self._init_weights)
106 | print("model initialized")
107 |
108 | def _init_weights(self, m):
109 | if isinstance(m, nn.Linear):
110 | trunc_normal_(m.weight, std=.02)
111 | if isinstance(m, nn.Linear) and m.bias is not None:
112 | nn.init.constant_(m.bias, 0)
113 | elif isinstance(m, nn.LayerNorm):
114 | nn.init.constant_(m.bias, 0)
115 | nn.init.constant_(m.weight, 1.0)
116 |
117 | @torch.jit.ignore
118 | def no_weight_decay(self):
119 | return {'quantize.embedding.weight', 'decoder.cls_token', 'decoder.pos_embed',
120 | 'encoder.cls_token', 'encoder.pos_embed'}
121 |
122 | @property
123 | def device(self):
124 | return self.decoder.cls_token.device
125 |
126 | def pre_process(self, data):
127 | print("pre process called", data.shape)
128 | # if self.process_type == 'default':
129 | # # TODO: modify for adapt
130 | # data = data.to(self.device)
131 | # if data.max() <= 1.:
132 | # data = data * 255.
133 | # data = data / 127.5 - 1.0
134 | # elif self.process_type == 'imagenet_norm':
135 | # mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(self.device)[None, :, None, None]
136 | # std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(self.device)[None, :, None, None]
137 | # data = (data - mean) / std
138 | return data
139 |
140 | def get_number_of_tokens(self):
141 | return self.quantize.n_e
142 |
143 | def get_tokens(self, data, **kwargs):
144 |
145 | data = self.pre_process(data)
146 | quantize, embed_ind, loss = self.encode(data)
147 | output = {}
148 | output['token'] = embed_ind.view(data.shape[0], -1)
149 | output['input_img'] = data
150 |
151 | return output
152 |
153 | def get_codebook_entry(self, ind):
154 | ind_shape = ind.shape
155 | ind = ind.reshape(-1)
156 | z = self.quantize.embedding(ind)
157 | z = z.reshape(ind_shape[0], -1)
158 | return z
159 |
160 | def encode(self, x):
161 | encoder_features = self.encoder(x, return_patch_tokens=True)
162 |
163 | with torch.cuda.amp.autocast(enabled=False):
164 | to_quantizer_features = self.encode_task_layer(encoder_features.type_as(self.encode_task_layer[-1].weight))
165 |
166 | N = to_quantizer_features.shape[1]
167 | h, w = int(math.sqrt(N)), int(math.sqrt(N))
168 |
169 | to_quantizer_features = rearrange(to_quantizer_features, 'b (h w) c -> b c h w', h=h, w=w) # reshape for quantizer
170 | quantize, loss, embed_ind = self.quantize(to_quantizer_features)
171 | # print("before quant and after shapes", to_quantizer_features.shape, quantize.shape)
172 |
173 | return quantize, embed_ind, loss
174 |
175 | def decode(self, quantize, **kwargs):
176 | # reshape tokens to feature maps for patch embed in decoder
177 | # quantize = rearrange(quantize, 'b (h w) c -> b c h w', h=self.token_shape[0], w=self.token_shape[1])
178 | decoder_features = self.decoder(quantize, return_patch_tokens=True)
179 | rec = self.decode_task_layer(decoder_features)
180 |
181 | return rec
182 |
183 | def get_codebook_indices(self, x, **kwargs):
184 | # for beit pre-training
185 | return self.get_tokens(x, **kwargs)['token']
186 |
187 | @torch.no_grad()
188 | def get_regress_target(self, x, **kwargs):
189 |
190 | norm_imgs = self.scaling_layer(x)
191 | if self.teacher_model_type == 'clip':
192 | target = self.teacher_model.encode_image(norm_imgs, return_all_tokens=True) @ self.teacher_model.visual.proj
193 | elif self.teacher_model_type == 'dino':
194 | target = self.teacher_model.forward(norm_imgs, return_patch_tokens=True)
195 | else:
196 | raise NotImplementedError
197 |
198 | return target
199 |
200 | def calculate_rec_loss(self, rec, target):
201 | if self.rec_loss_type == 'cosine':
202 | target = target / target.norm(dim=-1, keepdim=True)
203 | rec = rec / rec.norm(dim=-1, keepdim=True)
204 | rec_loss = (1 - (target * rec).sum(-1)).mean()
205 | else:
206 | raise NotImplementedError
207 |
208 | return rec_loss
209 |
210 | def forward(self, x, **kwargs):
211 | """
212 | x: shape [B, 3, H, W] in [0, 1]
213 | """
214 |         x = self.pre_process(x)  # pre-processing is currently a no-op (the rescaling body above is commented out)
215 |
216 | target = self.get_regress_target(x, **kwargs)
217 |
218 | quantize, embed_ind, emb_loss = self.encode(x)
219 | xrec = self.decode(quantize)
220 |
221 | rec_loss = self.calculate_rec_loss(xrec, target)
222 | loss = emb_loss + rec_loss
223 |
224 | log = {}
225 | split="train" if self.training else "val"
226 | log[f'{split}/quant_loss'] = emb_loss.detach().mean()
227 | log[f'{split}/rec_loss'] = rec_loss.detach().mean()
228 | log[f'{split}/total_loss'] = loss.detach().mean()
229 |
230 | return loss, log
231 |
232 | class ScalingLayerForClip(nn.Module):
233 | def __init__(self):
234 | super(ScalingLayerForClip, self).__init__()
235 | self.register_buffer('shift', torch.Tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None])
236 | self.register_buffer('scale', torch.Tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None])
237 |
238 | def forward(self, inp):
239 | inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255. # rescale to [0, 1.]
240 | return (inp - self.shift) / self.scale
241 |
242 | class ScalingLayerForIM(nn.Module):
243 | def __init__(self):
244 | super(ScalingLayerForIM, self).__init__()
245 |         self.register_buffer('shift', torch.Tensor([0.485, 0.456, 0.406])[None, :, None, None]) # ImageNet mean; inputs follow the default process type in [-1, 1]
246 | self.register_buffer('scale', torch.Tensor([0.229, 0.224, 0.225])[None, :, None, None])
247 |
248 | def forward(self, inp):
249 | inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255. # rescale to [0, 1.]
250 | return (inp - self.shift) / self.scale
251 |
252 | def get_model_default_params():
253 | return dict(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12,
254 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
255 | norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=0., use_abs_pos_emb=True,
256 | use_rel_pos_bias=False, use_shared_rel_pos_bias=False, use_mean_pooling=True, init_scale=0.001)
257 |
258 | @register_model
259 | def vqkd_encoder_base_decoder_1x768x12_clip(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224,
260 | n_code=8192, code_dim=32, **kwargs):
261 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params()
262 |
263 | # encoder settings
264 | encoder_config['img_size'] = img_size
265 | encoder_config['num_classes'] = 0
266 | # decoder settings
267 | decoder_config['img_size'] = img_size // decoder_config['patch_size']
268 | decoder_config['patch_size'] = 1
269 | decoder_config['in_chans'] = code_dim
270 | decoder_config['num_classes'] = 0
271 | decoder_config['depth'] = 1
272 | # teacher settings
273 | _ = kwargs.pop("teacher_model_type", "clip")
274 |
275 | teacher_model_type = 'clip' if not as_tokenzer else 'None'
276 | decoder_out_dim = 512
277 |
278 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type,
279 | decoder_out_dim=decoder_out_dim, **kwargs)
280 |
281 | if as_tokenzer:
282 | assert pretrained
283 | assert pretrained_weight is not None
284 |
285 | if pretrained_weight.startswith('https'):
286 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True)
287 | else:
288 | weights = torch.load(pretrained_weight, map_location='cpu')
289 |
290 | if 'model' in weights:
291 | weights = weights['model']
292 | else:
293 | weights = weights["state_dict"]
294 | keys = list(weights.keys())
295 |
296 | for k in keys:
297 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"):
298 | del weights[k]
299 | model.load_state_dict(weights)
300 | return model
301 |
302 | @register_model
303 | def vqkd_encoder_base_decoder_3x768x12_clip(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224,
304 | n_code=8192, code_dim=32, **kwargs):
305 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params()
306 |
307 | # encoder settings
308 | encoder_config['img_size'] = img_size
309 | encoder_config['num_classes'] = 0
310 | # decoder settings
311 | decoder_config['img_size'] = img_size // decoder_config['patch_size']
312 | decoder_config['patch_size'] = 1
313 | decoder_config['in_chans'] = code_dim
314 | decoder_config['num_classes'] = 0
315 | decoder_config['depth'] = 3
316 | # teacher settings
317 | _ = kwargs.pop("teacher_model_type", "clip")
318 |
319 | teacher_model_type = 'clip' if not as_tokenzer else 'None'
320 | decoder_out_dim = 512
321 |
322 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type,
323 | decoder_out_dim=decoder_out_dim, **kwargs)
324 |
325 | if as_tokenzer:
326 | assert pretrained
327 | assert pretrained_weight is not None
328 |
329 | if pretrained_weight.startswith('https'):
330 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True)
331 | else:
332 | weights = torch.load(pretrained_weight, map_location='cpu')
333 |
334 | if 'model' in weights:
335 | weights = weights['model']
336 | else:
337 | weights = weights["state_dict"]
338 | keys = list(weights.keys())
339 |
340 | for k in keys:
341 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"):
342 | del weights[k]
343 | model.load_state_dict(weights)
344 | return model
345 |
346 |
347 | @register_model
348 | def vqkd_encoder_base_decoder_1x768x12_dino(pretrained=False, pretrained_weight=None, as_tokenzer=False, img_size=224,
349 | n_code=8192, code_dim=32, **kwargs):
350 | encoder_config, decoder_config = get_model_default_params(), get_model_default_params()
351 |
352 | # encoder settings
353 | encoder_config['img_size'] = img_size
354 | encoder_config['num_classes'] = 0
355 | # decoder settings
356 | decoder_config['img_size'] = img_size // decoder_config['patch_size']
357 | decoder_config['patch_size'] = 1
358 | decoder_config['in_chans'] = code_dim
359 | decoder_config['num_classes'] = 0
360 | decoder_config['depth'] = 1
361 | # teacher settings
362 | _ = kwargs.pop("teacher_model_type", "dino")
363 |
364 | teacher_model_type = 'dino' if not as_tokenzer else 'None'
365 | decoder_out_dim = 768
366 |
367 | model = VQKD(encoder_config, decoder_config, n_code, code_dim, teacher_model_type=teacher_model_type,
368 | decoder_out_dim=decoder_out_dim, **kwargs)
369 |
370 | if as_tokenzer:
371 | assert pretrained
372 | assert pretrained_weight is not None
373 |
374 | if pretrained_weight.startswith('https'):
375 | weights = torch.hub.load_state_dict_from_url(pretrained_weight, map_location='cpu', check_hash=True)
376 | else:
377 | weights = torch.load(pretrained_weight, map_location='cpu')
378 |
379 | if 'model' in weights:
380 | weights = weights['model']
381 | else:
382 | weights = weights["state_dict"]
383 | keys = list(weights.keys())
384 |
385 | for k in keys:
386 | if k.startswith("loss") or k.startswith("teacher") or k.startswith("scaling"):
387 | del weights[k]
388 | model.load_state_dict(weights)
389 | return model
390 |
391 |
392 | if __name__ == '__main__':
393 | pass
394 |
395 |
396 |
397 |
398 |
399 |
400 |
--------------------------------------------------------------------------------
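
Because the constructors above are registered with timm, a tokenizer can be created by name. A hedged usage sketch, assuming modeling_vqkd and its dependencies are importable and that a trained VQ-KD checkpoint exists at the placeholder path below:

import torch
from timm.models import create_model

import modeling_vqkd  # noqa: F401 -- registers the vqkd_encoder_base_decoder_* entry points

tokenizer = create_model(
    'vqkd_encoder_base_decoder_1x768x12_dino',
    pretrained=True,
    pretrained_weight='vqkd_dino.pth',  # placeholder path, not a file shipped with the repo
    as_tokenzer=True,                   # spelling follows the signatures above
).eval()

with torch.no_grad():
    imgs = torch.rand(2, 3, 224, 224)             # dummy batch in [0, 1]
    codes = tokenizer.get_codebook_indices(imgs)  # [2, 196] discrete code indices
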
/dino/vqkd_teacher/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | import pdb
10 |
11 | class Bottleneck(nn.Module):
12 | expansion = 4
13 |
14 | def __init__(self, inplanes, planes, stride=1):
15 | super().__init__()
16 |
17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25 |
26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28 |
29 | self.relu = nn.ReLU(inplace=True)
30 | self.downsample = None
31 | self.stride = stride
32 |
33 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35 | self.downsample = nn.Sequential(OrderedDict([
36 | ("-1", nn.AvgPool2d(stride)),
37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38 | ("1", nn.BatchNorm2d(planes * self.expansion))
39 | ]))
40 |
41 | def forward(self, x: torch.Tensor):
42 | identity = x
43 |
44 | out = self.relu(self.bn1(self.conv1(x)))
45 | out = self.relu(self.bn2(self.conv2(out)))
46 | out = self.avgpool(out)
47 | out = self.bn3(self.conv3(out))
48 |
49 | if self.downsample is not None:
50 | identity = self.downsample(x)
51 |
52 | out += identity
53 | out = self.relu(out)
54 | return out
55 |
56 |
57 | class AttentionPool2d(nn.Module):
58 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
59 | super().__init__()
60 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
61 | self.k_proj = nn.Linear(embed_dim, embed_dim)
62 | self.q_proj = nn.Linear(embed_dim, embed_dim)
63 | self.v_proj = nn.Linear(embed_dim, embed_dim)
64 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
65 | self.num_heads = num_heads
66 |
67 | def forward(self, x, return_all_tokens=False):
68 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
69 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
70 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
71 | x, _ = F.multi_head_attention_forward(
72 | query=x, key=x, value=x,
73 | embed_dim_to_check=x.shape[-1],
74 | num_heads=self.num_heads,
75 | q_proj_weight=self.q_proj.weight,
76 | k_proj_weight=self.k_proj.weight,
77 | v_proj_weight=self.v_proj.weight,
78 | in_proj_weight=None,
79 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
80 | bias_k=None,
81 | bias_v=None,
82 | add_zero_attn=False,
83 | dropout_p=0,
84 | out_proj_weight=self.c_proj.weight,
85 | out_proj_bias=self.c_proj.bias,
86 | use_separate_proj_weight=True,
87 | training=self.training,
88 | need_weights=False
89 | )
90 | if return_all_tokens:
91 | return x
92 | else:
93 | return x[0]
94 |
95 |
96 | class ModifiedResNet(nn.Module):
97 | """
98 | A ResNet class that is similar to torchvision's but contains the following changes:
99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101 | - The final pooling layer is a QKV attention instead of an average pool
102 | """
103 |
104 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105 | super().__init__()
106 | self.output_dim = output_dim
107 | self.input_resolution = input_resolution
108 |
109 | # the 3-layer stem
110 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111 | self.bn1 = nn.BatchNorm2d(width // 2)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.avgpool = nn.AvgPool2d(2)
117 | self.relu = nn.ReLU(inplace=True)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x, return_side_out=False, return_all_tokens=False):
139 | def stem(x):
140 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
141 | x = self.relu(bn(conv(x)))
142 | x = self.avgpool(x)
143 | return x
144 | out = []
145 | x = x.type(self.conv1.weight.dtype)
146 | x = stem(x)
147 | x = self.layer1(x)
148 | if return_side_out:
149 | out.append(x)
150 | x = self.layer2(x)
151 | if return_side_out:
152 | out.append(x)
153 | x = self.layer3(x)
154 | if return_side_out:
155 | out.append(x)
156 | x = self.layer4(x)
157 | if return_side_out:
158 | out.append(x)
159 | x = self.attnpool(x, return_all_tokens)
160 | out.append(x)
161 | if len(out) == 1:
162 | return x
163 | else:
164 | return out
165 |
166 |
167 | class LayerNorm(nn.LayerNorm):
168 | """Subclass torch's LayerNorm to handle fp16."""
169 |
170 | def forward(self, x: torch.Tensor):
171 | orig_type = x.dtype
172 | ret = super().forward(x.type(torch.float32))
173 | return ret.type(orig_type)
174 |
175 |
176 | class QuickGELU(nn.Module):
177 | def forward(self, x: torch.Tensor):
178 | return x * torch.sigmoid(1.702 * x)
179 |
180 |
181 | class ResidualAttentionBlock(nn.Module):
182 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
183 | super().__init__()
184 |
185 | self.attn = nn.MultiheadAttention(d_model, n_head)
186 | self.ln_1 = LayerNorm(d_model)
187 | self.mlp = nn.Sequential(OrderedDict([
188 | ("c_fc", nn.Linear(d_model, d_model * 4)),
189 | ("gelu", QuickGELU()),
190 | ("c_proj", nn.Linear(d_model * 4, d_model))
191 | ]))
192 | self.ln_2 = LayerNorm(d_model)
193 | self.attn_mask = attn_mask
194 |
195 | def attention(self, x: torch.Tensor):
196 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
197 | # pdb.set_trace()
198 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
199 |
200 | def forward(self, x: torch.Tensor):
201 | x = x + self.attention(self.ln_1(x))
202 | x = x + self.mlp(self.ln_2(x))
203 | return x
204 |
205 |
206 | class Transformer(nn.Module):
207 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
208 | super().__init__()
209 | self.width = width
210 | self.layers = layers
211 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
212 |
213 | def forward(self, x: torch.Tensor, return_intermediate_out: bool = False):
214 | if return_intermediate_out:
215 | output = []
216 | for block in self.resblocks:
217 | x = block(x)
218 | output.append(x)
219 | return output
220 |
221 | return self.resblocks(x)
222 |
223 |
224 | class VisionTransformer(nn.Module):
225 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
226 | super().__init__()
227 | self.input_resolution = input_resolution
228 | self.patch_size = patch_size
229 | self.output_dim = output_dim
230 | self.width = width
231 | self.heads = heads
232 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
233 |
234 | scale = width ** -0.5
235 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
236 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
237 | self.ln_pre = LayerNorm(width)
238 |
239 | self.transformer = Transformer(width, layers, heads)
240 |
241 | self.ln_post = LayerNorm(width)
242 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
243 |
244 | def forward(self, x: torch.Tensor, return_all_tokens=False, return_all_final_tokens=False, **kwargs):
245 |
246 | B, nc, w, h = x.shape
247 |
248 | x = self.conv1(x) # shape = [*, width, grid, grid]
249 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
250 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
251 |
252 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
253 |
254 | if x.shape[1] != self.positional_embedding.shape[0]:
255 | x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
256 | else:
257 | x = x + self.positional_embedding.to(x.dtype)
258 |
259 | x = self.ln_pre(x)
260 |
261 | x = x.permute(1, 0, 2) # NLD -> LND
262 | x = self.transformer(x)
263 | x = x.permute(1, 0, 2) # LND -> NLD
264 |
265 | if return_all_tokens:
266 | x = self.ln_post(x)
267 | return x[:, 1:, :]
268 |
269 | if return_all_final_tokens:
270 | return self.ln_post(x) @ self.proj
271 |
272 | x = self.ln_post(x[:, 0, :])
273 |
274 | if self.proj is not None:
275 | x = x @ self.proj
276 |
277 | return x
278 |
279 | def interpolate_pos_encoding(self, x, w, h):
280 | # pdb.set_trace()
281 | npatch = x.shape[1] - 1
282 | N = self.positional_embedding.shape[0] - 1 # 256 for large
283 | if npatch == N and w == h:
284 | return self.positional_embedding
285 | class_pos_embed = self.positional_embedding[[0]]
286 | patch_pos_embed = self.positional_embedding[1:]
287 | dim = x.shape[-1]
288 | w0 = w // self.patch_size
289 | h0 = h // self.patch_size
290 | # we add a small number to avoid floating point error in the interpolation
291 | # see discussion at https://github.com/facebookresearch/dino/issues/8
292 | w0, h0 = w0 + 0.1, h0 + 0.1
293 | patch_pos_embed = nn.functional.interpolate(
294 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
295 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
296 | mode='bicubic',
297 | )
298 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
299 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
300 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
301 |
302 |
303 | class CLIP(nn.Module):
304 | def __init__(self,
305 | embed_dim: int, # 512
306 | # vision
307 | image_resolution: int, # 224
308 | vision_layers: Union[Tuple[int, int, int, int], int], # 12
309 | vision_width: int, # 768
310 | vision_patch_size: int, # 16
311 | # text
312 | context_length: int, # 77
313 | vocab_size: int, # 49408
314 | transformer_width: int, # 512
315 | transformer_heads: int, # 8
316 | transformer_layers: int # 12
317 | ):
318 | super().__init__()
319 | # pdb.set_trace()
320 | self.context_length = context_length
321 |
322 | if isinstance(vision_layers, (tuple, list)):
323 | vision_heads = vision_width * 32 // 64
324 | self.visual = ModifiedResNet(
325 | layers=vision_layers,
326 | output_dim=embed_dim,
327 | heads=vision_heads,
328 | input_resolution=image_resolution,
329 | width=vision_width
330 | )
331 | else:
332 | vision_heads = vision_width // 64
333 | self.visual = VisionTransformer(
334 | input_resolution=image_resolution,
335 | patch_size=vision_patch_size,
336 | width=vision_width,
337 | layers=vision_layers,
338 | heads=vision_heads,
339 | output_dim=embed_dim
340 | )
341 |
342 | self.transformer = Transformer(
343 | width=transformer_width,
344 | layers=transformer_layers,
345 | heads=transformer_heads,
346 | attn_mask=self.build_attention_mask()
347 | )
348 |
349 | self.vocab_size = vocab_size
350 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
351 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
352 | self.ln_final = LayerNorm(transformer_width)
353 |
354 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
355 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
356 |
357 | self.initialize_parameters()
358 |
359 | def initialize_parameters(self):
360 | nn.init.normal_(self.token_embedding.weight, std=0.02)
361 | nn.init.normal_(self.positional_embedding, std=0.01)
362 |
363 | if isinstance(self.visual, ModifiedResNet):
364 | if self.visual.attnpool is not None:
365 | std = self.visual.attnpool.c_proj.in_features ** -0.5
366 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
367 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
368 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
369 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
370 |
371 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
372 | for name, param in resnet_block.named_parameters():
373 | if name.endswith("bn3.weight"):
374 | nn.init.zeros_(param)
375 |
376 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
377 | attn_std = self.transformer.width ** -0.5
378 | fc_std = (2 * self.transformer.width) ** -0.5
379 | for block in self.transformer.resblocks:
380 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
381 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
382 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
383 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
384 |
385 | if self.text_projection is not None:
386 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
387 |
388 | def build_attention_mask(self):
389 | # lazily create causal attention mask, with full attention between the vision tokens
390 | # pytorch uses additive attention mask; fill with -inf
391 | mask = torch.empty(self.context_length, self.context_length)
392 | mask.fill_(float("-inf"))
393 |         mask.triu_(1)  # zero out the diagonal and below; keep -inf strictly above (causal mask)
394 | return mask
395 |
396 | @property
397 | def dtype(self):
398 | return self.visual.conv1.weight.dtype
399 |
400 | def encode_image(self, image, return_side_out=False, return_all_tokens=False, return_all_final_tokens=False, **kwargs):
401 | return self.visual(image.type(self.dtype), return_all_tokens, return_all_final_tokens, **kwargs)
402 |
403 | def encode_text(self, text, return_all_tokens=False, return_patch_tokens=False):
404 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
405 |
406 | x = x + self.positional_embedding.type(self.dtype)
407 | x = x.permute(1, 0, 2) # NLD -> LND
408 | x = self.transformer(x)
409 | x = x.permute(1, 0, 2) # LND -> NLD
410 | x = self.ln_final(x).type(self.dtype)
411 |
412 | if return_patch_tokens:
413 | return x
414 | # x.shape = [batch_size, n_ctx, transformer.width]
415 | # take features from the eot embedding (eot_token is the highest number in each sequence)
416 | if return_all_tokens:
417 | # pdb.set_trace()
418 | x = x @ self.text_projection
419 | else:
420 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
421 | return x
422 |
423 | def forward(self, image, text):
424 | image_features = self.encode_image(image)
425 | text_features = self.encode_text(text)
426 |
427 | # normalized features
428 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
429 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
430 |
431 | # cosine similarity as logits
432 | logit_scale = self.logit_scale.exp()
433 | logits_per_image = logit_scale * image_features @ text_features.t()
434 | logits_per_text = logits_per_image.t()
435 |
436 | # shape = [global_batch_size, global_batch_size]
437 | return logits_per_image, logits_per_text
438 |
439 |
440 | def convert_weights(model: nn.Module):
441 | """Convert applicable model parameters to fp16"""
442 |
443 | def _convert_weights_to_fp16(l):
444 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
445 | l.weight.data = l.weight.data.half()
446 | if l.bias is not None:
447 | l.bias.data = l.bias.data.half()
448 |
449 | if isinstance(l, nn.MultiheadAttention):
450 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
451 | tensor = getattr(l, attr)
452 | if tensor is not None:
453 | tensor.data = tensor.data.half()
454 |
455 | for name in ["text_projection", "proj"]:
456 | if hasattr(l, name):
457 | attr = getattr(l, name)
458 | if attr is not None:
459 | attr.data = attr.data.half()
460 |
461 | model.apply(_convert_weights_to_fp16)
462 |
463 |
464 | def build_model(state_dict: dict):
465 | vit = "visual.proj" in state_dict
466 |
467 | if vit:
468 | vision_width = state_dict["visual.conv1.weight"].shape[0]
469 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
470 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
471 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
472 | image_resolution = vision_patch_size * grid_size
473 | else:
474 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
475 | vision_layers = tuple(counts)
476 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
477 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
478 | vision_patch_size = None
479 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
480 | image_resolution = output_width * 32
481 |
482 | embed_dim = state_dict["text_projection"].shape[1]
483 | context_length = state_dict["positional_embedding"].shape[0]
484 | vocab_size = state_dict["token_embedding.weight"].shape[0]
485 | transformer_width = state_dict["ln_final.weight"].shape[0]
486 | transformer_heads = transformer_width // 64
487 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
488 |
489 | model = CLIP(
490 | embed_dim,
491 | image_resolution, vision_layers, vision_width, vision_patch_size,
492 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
493 | )
494 |
495 | for key in ["input_resolution", "context_length", "vocab_size"]:
496 | if key in state_dict:
497 | del state_dict[key]
498 |
499 | convert_weights(model)
500 | model.load_state_dict(state_dict)
501 | return model.eval()
502 |
--------------------------------------------------------------------------------
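
For reference, build_model above constructs a CLIP instance (with fp16 weights, in eval mode) from a downloaded OpenAI checkpoint's state dict. A hedged sketch of calling it directly, assuming the vqkd_teacher.clip package is importable and a ViT-B/16 TorchScript archive sits at the placeholder path below:

import torch
from vqkd_teacher.clip.model import build_model

jit_archive = torch.jit.load('ViT-B-16.pt', map_location='cpu')  # placeholder path
clip_model = build_model(jit_archive.state_dict()).float()       # cast back to fp32 for CPU use

with torch.no_grad():
    imgs = torch.rand(1, 3, 224, 224)
    patch_tokens = clip_model.encode_image(imgs, return_all_tokens=True)  # [1, 196, 768]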