├── README.md
└── transformer_cond_2_sample.py

/README.md:
--------------------------------------------------------------------------------
# cond_transformer_2
A CLIP-conditioned Decision Transformer.

--------------------------------------------------------------------------------
/transformer_cond_2_sample.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Samples an image from a CLIP-conditioned Decision Transformer."""

import argparse
from pathlib import Path
import sys

from omegaconf import OmegaConf
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from transformers import top_k_top_p_filtering
from tqdm import trange

sys.path.append('./taming-transformers')

from CLIP import clip
from taming.models import vqgan


def setup_exceptions():
    try:
        from IPython.core.ultratb import FormattedTB
        sys.excepthook = FormattedTB(mode='Plain', color_scheme='Neutral')
    except ImportError:
        pass


class CausalTransformerEncoder(nn.TransformerEncoder):
    def forward(self, src, mask=None, src_key_padding_mask=None, cache=None):
        output = src

        if self.training:
            if cache is not None:
                raise ValueError("cache parameter should be None in training mode")
            for mod in self.layers:
                output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

            if self.norm is not None:
                output = self.norm(output)

            return output

        # Inference: each layer recomputes only the positions that are not yet in the
        # cache, then the cached activations are prepended so the next layer sees the
        # full sequence.
        new_token_cache = []
        compute_len = src.shape[0]
        if cache is not None:
            compute_len -= cache.shape[1]
        for i, mod in enumerate(self.layers):
            output = mod(output, compute_len=compute_len)
            new_token_cache.append(output)
            if cache is not None:
                output = torch.cat([cache[i], output], dim=0)

        if cache is not None:
            new_cache = torch.cat([cache, torch.stack(new_token_cache, dim=0)], dim=1)
        else:
            new_cache = torch.stack(new_token_cache, dim=0)

        return output, new_cache


class CausalTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def forward(self, src, src_mask=None, src_key_padding_mask=None, compute_len=None):
        if self.training:
            return super().forward(src, src_mask, src_key_padding_mask)

        # Inference: attend from only the last compute_len positions (the new tokens)
        # to the full sequence, so cached positions are never recomputed.
        if compute_len is None:
            src_last_tok = src
        else:
            src_last_tok = src[-compute_len:, :, :]

        attn_mask = src_mask if compute_len is None or compute_len > 1 else None
        tmp_src = self.self_attn(src_last_tok, src, src, attn_mask=attn_mask,
                                 key_padding_mask=src_key_padding_mask)[0]
        src_last_tok = src_last_tok + self.dropout1(tmp_src)
        src_last_tok = self.norm1(src_last_tok)

        tmp_src = self.linear2(self.dropout(self.activation(self.linear1(src_last_tok))))
        src_last_tok = src_last_tok + self.dropout2(tmp_src)
        src_last_tok = self.norm2(src_last_tok)
        return src_last_tok


class CLIPToImageTransformer(nn.Module):
    def __init__(self, clip_dim, seq_len, n_toks):
        super().__init__()
        self.clip_dim = clip_dim
        d_model = 1024
        self.clip_in_proj = nn.Linear(clip_dim, d_model, bias=False)
        self.clip_score_in_proj = nn.Linear(1, d_model, bias=False)
        self.in_embed = nn.Embedding(n_toks, d_model)
        self.out_proj = nn.Linear(d_model, n_toks)
        layer = CausalTransformerEncoderLayer(d_model, d_model // 64, d_model * 4,
                                              dropout=0, activation='gelu')
        self.encoder = CausalTransformerEncoder(layer, 24)
        self.pos_emb = nn.Parameter(torch.zeros([seq_len + 1, d_model]))
        self.register_buffer('mask', self._generate_causal_mask(seq_len + 1), persistent=False)

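    # Sequence layout: position 0 is the projected CLIP embedding (a text embedding in
    # this script), position 1 is the projected target CLIP score, and later positions
    # are VQGAN token embeddings. The mask below is the usual causal mask, except that
    # mask[0, 1] = 0 also lets the CLIP embedding position attend to the score position.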
    @staticmethod
    def _generate_causal_mask(size):
        mask = (torch.triu(torch.ones([size, size])) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0))
        mask[0, 1] = 0
        return mask

    def forward(self, clip_embed, clip_score, input=None, cache=None):
        if input is None:
            input = torch.zeros([len(clip_embed), 0], dtype=torch.long, device=clip_embed.device)
        clip_embed_proj = self.clip_in_proj(F.normalize(clip_embed, dim=1) * self.clip_dim**0.5)
        clip_score_proj = self.clip_score_in_proj(clip_score)
        embed = torch.cat([clip_embed_proj.unsqueeze(0),
                           clip_score_proj.unsqueeze(0),
                           self.in_embed(input.T)])
        embed_plus_pos = embed + self.pos_emb[:len(embed)].unsqueeze(1)
        mask = self.mask[:len(embed), :len(embed)]
        out, cache = self.encoder(embed_plus_pos, mask, cache=cache)
        return self.out_proj(out[1:]).transpose(0, 1), cache


def main():
    setup_exceptions()

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the prompt')
    p.add_argument('--batch-size', '-bs', type=int, default=4,
                   help='the batch size')
    p.add_argument('--checkpoint', type=Path, required=True,
                   help='the checkpoint to use')
    p.add_argument('--clip-score', type=float, default=1.,
                   help='the CLIP score to condition on')
    p.add_argument('--device', type=str, default=None,
                   help='the device to use')
    p.add_argument('--half', action='store_true',
                   help='use half precision')
    p.add_argument('-k', type=int, default=1,
                   help='the number of samples to save')
    p.add_argument('-n', type=int, default=1,
                   help='the number of samples to draw')
    p.add_argument('--output', '-o', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--temperature', type=float, default=1.,
                   help='the softmax temperature for sampling')
    p.add_argument('--top-k', type=int, default=0,
                   help='the top-k value for sampling')
    p.add_argument('--top-p', type=float, default=1.,
                   help='the top-p value for sampling')
    p.add_argument('--vqgan-checkpoint', type=Path, required=True,
                   help='the VQGAN checkpoint (.ckpt)')
    p.add_argument('--vqgan-config', type=Path, required=True,
                   help='the VQGAN config (.yaml)')
    args = p.parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    dtype = torch.half if args.half else torch.float

    perceptor = clip.load('ViT-B/32', jit=False)[0].to(device).eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    vqgan_config = OmegaConf.load(args.vqgan_config)
    vqgan_model = vqgan.VQModel(**vqgan_config.model.params).to(device)
    vqgan_model.eval().requires_grad_(False)
    vqgan_model.init_from_ckpt(args.vqgan_checkpoint)
    del vqgan_model.loss

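    # Derived sizes: f is the VQGAN downsampling factor (16 for the widely used f=16
    # configs), so the 384x384 canvas corresponds to a toks_y * toks_x token grid
    # (24 * 24 = 576 with f=16), which must match the seq_len the transformer
    # checkpoint was trained with.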
    clip_dim = perceptor.visual.output_dim
    clip_input_res = perceptor.visual.input_resolution
    e_dim = vqgan_model.quantize.e_dim
    f = 2**(vqgan_model.decoder.num_resolutions - 1)
    n_toks = vqgan_model.quantize.n_e
    size_x, size_y = 384, 384
    toks_x, toks_y = size_x // f, size_y // f

    torch.manual_seed(args.seed)

    text_embed = perceptor.encode_text(clip.tokenize(args.prompt).to(device)).to(dtype)
    text_embed = text_embed.repeat([args.n, 1])
    clip_score = torch.ones([text_embed.shape[0], 1], device=device, dtype=dtype) * args.clip_score

    model = CLIPToImageTransformer(clip_dim, toks_y * toks_x, n_toks)
    ckpt = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(ckpt['model'])
    model = model.to(device, dtype).eval().requires_grad_(False)

    @torch.no_grad()
    def sample(clip_embed, clip_score, temperature=1., top_k=0, top_p=1.):
        tokens = torch.zeros([len(clip_embed), 0], dtype=torch.long, device=device)
        cache = None
        for i in trange(toks_y * toks_x, leave=False):
            logits, cache = model(clip_embed, clip_score, tokens, cache=cache)
            logits = logits[:, -1] / temperature
            logits = top_k_top_p_filtering(logits, top_k, top_p)
            next_token = logits.softmax(1).multinomial(1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens

    def decode(tokens):
        z = vqgan_model.quantize.embedding(tokens).view([-1, toks_y, toks_x, e_dim]).movedim(3, 1)
        return vqgan_model.decode(z).add(1).div(2).clamp(0, 1)

    try:
        out_lst, sim_lst = [], []
        for i in trange(0, len(text_embed), args.batch_size):
            tokens = sample(text_embed[i:i+args.batch_size], clip_score[i:i+args.batch_size],
                            temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
            out = decode(tokens)
            out_lst.append(out)
            out_for_clip = F.interpolate(out, (clip_input_res, clip_input_res),
                                         mode='bilinear', align_corners=False)
            image_embed = perceptor.encode_image(normalize(out_for_clip)).to(dtype)
            sim = torch.cosine_similarity(text_embed[i:i+args.batch_size], image_embed)
            sim_lst.append(sim)
        out = torch.cat(out_lst)
        sim = torch.cat(sim_lst)
        best_values, best_indices = sim.topk(min(args.k, args.n))
        for i, index in enumerate(best_indices):
            TF.to_pil_image(out[index]).save(args.output + f'_{i:03}.png')
            print(f'Actual CLIP score for output {i}: {best_values[i].item():g}')
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
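A minimal invocation sketch (the checkpoint and VQGAN file names below are placeholders;
any taming-transformers VQGAN config/checkpoint pair matching the transformer's training
setup should work):

    python transformer_cond_2_sample.py "a watercolor painting of a lighthouse" \
        --checkpoint cond_transformer_2.pth \
        --vqgan-config vqgan_imagenet_f16_16384.yaml \
        --vqgan-checkpoint vqgan_imagenet_f16_16384.ckpt \
        -n 16 -k 4 --clip-score 1.0 --output out

This draws 16 samples conditioned on the prompt and a target CLIP score of 1.0, re-scores
them with CLIP, and saves the 4 best as out_000.png through out_003.png.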