├── .gitignore ├── README.md ├── cog.yaml ├── example_images ├── OCR │   ├── stanford.jpg │   └── welcome_sign.jpg ├── arithmetics │   ├── king2.jpg │   ├── man2.jpg │   └── woman2.jpg ├── captions │   ├── COCO_val2014_000000008775.jpg │   ├── COCO_val2014_000000097017.jpg │   ├── COCO_val2014_000000406395.jpg │   └── COCO_val2014_000000557731.jpg └── real_world │   ├── london.jpg │   ├── simpsons.jpg │   └── trump.jpg ├── forbidden_tokens.npy ├── git_images ├── Architecture.jpg ├── relations.jpg └── teaser.jpg ├── model ├── ZeroCLIP.py ├── ZeroCLIP_batched.py └── __init__.py ├── predict.py ├── predict_arithmetic.py ├── requirements.txt ├── run.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of [Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic](https://arxiv.org/abs/2111.14447) [CVPR 2022] 2 | ### Check out our follow-up work - [Zero-Shot Video Captioning with Evolving Pseudo-Tokens](https://github.com/YoadTew/zero-shot-video-to-text)! 3 | [[Paper]](https://arxiv.org/abs/2111.14447) [[Notebook]](https://www.kaggle.com/yoavstau/zero-shot-image-to-text/notebook) [[Caption Demo]](https://replicate.com/yoadtew/zero-shot-image-to-text) [[Arithmetic Demo]](https://replicate.com/yoadtew/arithmetic) [[Visual Relations Dataset]](https://drive.google.com/file/d/1hf5_zPI3hfMLNMTllZtWXcjf6ZoSTGcI) 4 | 5 | ⭐ ***New:*** Run the captioning configuration in the [browser](https://replicate.com/yoadtew/zero-shot-image-to-text) using the replicate.ai UI. 6 | 7 | ## Approach 8 | ![](git_images/Architecture.jpg) 9 | 10 | ## Examples of capabilities 11 | ![](git_images/teaser.jpg) 12 | 13 | ## Examples of Visual-Semantic Arithmetic 14 | ![](git_images/relations.jpg) 15 | 16 | ## Usage 17 | 18 | ### To run captioning on a single image: 19 | 20 | ```bash 21 | $ python run.py 22 | --reset_context_delta 23 | --caption_img_path "example_images/captions/COCO_val2014_000000097017.jpg" 24 | ``` 25 | 26 | ### To run the model on visual arithmetic: 27 | 28 | ```bash 29 | $ python run.py 30 | --reset_context_delta 31 | --end_factor 1.06 32 | --fusion_factor 0.95 33 | --grad_norm_factor 0.95 34 | --run_type arithmetics 35 | --arithmetics_imgs "example_images/arithmetics/woman2.jpg" "example_images/arithmetics/king2.jpg" "example_images/arithmetics/man2.jpg" 36 | --arithmetics_weights 1 1 -1 37 | ``` 38 | 39 | ### To run the model on real-world knowledge: 40 | 41 | ```bash 42 | $ python run.py 43 | --reset_context_delta --cond_text "Image of" 44 | --end_factor 1.04 45 | --caption_img_path "example_images/real_world/simpsons.jpg" 46 | ``` 47 | 48 | ### To run the model on OCR: 49 | 50 | ```bash 51 | $ python run.py 52 | --reset_context_delta --cond_text "Image of text that says" 53 | --end_factor 1.04 54 | --caption_img_path "example_images/OCR/welcome_sign.jpg" 55 | ``` 56 | 57 | ### For a runtime speedup using multiple GPUs, use the --multi_gpu flag: 58 | 59 | ```bash 60 | $ CUDA_VISIBLE_DEVICES=0,1,2,3,4 python run.py 61 | --reset_context_delta 62 | --caption_img_path "example_images/captions/COCO_val2014_000000097017.jpg" 63 | --multi_gpu 64 | ``` 65 | 66 | ## Citation 67 | Please cite our work if you use it in your research: 68 | ``` 69 | @article{tewel2021zero, 70 | title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic}, 71 | author={Tewel, Yoad and Shalev, 
Yoav and Schwartz, Idan and Wolf, Lior}, 72 | journal={arXiv preprint arXiv:2111.14447}, 73 | year={2021} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | python_packages: 8 | - "git+https://github.com/openai/CLIP.git" 9 | - "git+https://github.com/YoadTew/zero-shot-image-to-text.git" 10 | 11 | predict: "predict.py:Predictor" 12 | #predict: "predict_arithmetic.py:Predictor" -------------------------------------------------------------------------------- /example_images/OCR/stanford.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/OCR/stanford.jpg -------------------------------------------------------------------------------- /example_images/OCR/welcome_sign.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/OCR/welcome_sign.jpg -------------------------------------------------------------------------------- /example_images/arithmetics/king2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/arithmetics/king2.jpg -------------------------------------------------------------------------------- /example_images/arithmetics/man2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/arithmetics/man2.jpg -------------------------------------------------------------------------------- /example_images/arithmetics/woman2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/arithmetics/woman2.jpg -------------------------------------------------------------------------------- /example_images/captions/COCO_val2014_000000008775.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/captions/COCO_val2014_000000008775.jpg -------------------------------------------------------------------------------- /example_images/captions/COCO_val2014_000000097017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/captions/COCO_val2014_000000097017.jpg -------------------------------------------------------------------------------- /example_images/captions/COCO_val2014_000000406395.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/captions/COCO_val2014_000000406395.jpg 
-------------------------------------------------------------------------------- /example_images/captions/COCO_val2014_000000557731.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/captions/COCO_val2014_000000557731.jpg -------------------------------------------------------------------------------- /example_images/real_world/london.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/real_world/london.jpg -------------------------------------------------------------------------------- /example_images/real_world/simpsons.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/real_world/simpsons.jpg -------------------------------------------------------------------------------- /example_images/real_world/trump.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/example_images/real_world/trump.jpg -------------------------------------------------------------------------------- /forbidden_tokens.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/forbidden_tokens.npy -------------------------------------------------------------------------------- /git_images/Architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/git_images/Architecture.jpg -------------------------------------------------------------------------------- /git_images/relations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/git_images/relations.jpg -------------------------------------------------------------------------------- /git_images/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/git_images/teaser.jpg -------------------------------------------------------------------------------- /model/ZeroCLIP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer 4 | from transformers.models.gpt_neo import GPTNeoForCausalLM 5 | import torch 6 | import clip 7 | from PIL import Image 8 | from datetime import datetime 9 | import sys 10 | 11 | 12 | def log_info(text, verbose=True): 13 | if verbose: 14 | dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") 15 | print(f'{dt_string} | {text}') 16 | sys.stdout.flush() 17 | 18 | 19 | def add_context(x, y): 20 | return (x[0] + y[0], x[1] + y[1]) 21 | 22 | 23 | def convert_models_to_fp32(model): 24 | for p in model.parameters(): 25 | p.data = p.data.float() 26 | 27 | 28 
| class CLIPTextGenerator: 29 | def __init__(self, 30 | seed=0, 31 | lm_model='gpt-2', 32 | forbidden_tokens_file_path='./forbidden_tokens.npy', 33 | clip_checkpoints='./clip_checkpoints', 34 | target_seq_length=15, 35 | reset_context_delta=True, 36 | num_iterations=5, 37 | clip_loss_temperature=0.01, 38 | clip_scale=1., 39 | ce_scale=0.2, 40 | stepsize=0.3, 41 | grad_norm_factor=0.9, 42 | fusion_factor=0.99, 43 | repetition_penalty=1., 44 | end_token='.', 45 | end_factor=1.01, 46 | forbidden_factor=20, 47 | **kwargs): 48 | 49 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 50 | 51 | # set Random seed 52 | torch.manual_seed(seed) 53 | np.random.seed(seed) 54 | 55 | # Initialize Language model 56 | self.context_prefix = '' 57 | 58 | if lm_model == 'gpt-neo': 59 | self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') 60 | self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) 61 | elif lm_model == 'gpt-2': 62 | self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') 63 | self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) 64 | self.context_prefix = self.lm_tokenizer.bos_token 65 | 66 | self.lm_model.to(self.device) 67 | self.lm_model.eval() 68 | 69 | self.forbidden_tokens = np.load(forbidden_tokens_file_path) 70 | self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if 71 | (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] 72 | 73 | # Freeze LM weights 74 | for param in self.lm_model.parameters(): 75 | param.requires_grad = False 76 | 77 | # Initialize CLIP 78 | self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, 79 | download_root=clip_checkpoints, jit=False) 80 | # convert_models_to_fp32(self.clip) 81 | 82 | # Init arguments 83 | self.target_seq_length = target_seq_length 84 | self.reset_context_delta = reset_context_delta 85 | self.num_iterations = num_iterations 86 | self.clip_loss_temperature = clip_loss_temperature 87 | self.clip_scale = clip_scale 88 | self.ce_scale = ce_scale 89 | self.stepsize = stepsize 90 | self.grad_norm_factor = grad_norm_factor 91 | self.fusion_factor = fusion_factor 92 | self.repetition_penalty = repetition_penalty 93 | self.end_token = self.lm_tokenizer.encode(end_token)[0] 94 | self.end_factor = end_factor 95 | self.ef_idx = 1 96 | self.forbidden_factor = forbidden_factor 97 | 98 | def get_img_feature(self, img_path, weights): 99 | imgs = [Image.open(x) for x in img_path] 100 | clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] 101 | 102 | with torch.no_grad(): 103 | image_fts = [self.clip.encode_image(x) for x in clip_imgs] 104 | 105 | if weights is not None: 106 | image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) 107 | else: 108 | image_features = sum(image_fts) 109 | 110 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 111 | return image_features.detach() 112 | 113 | def get_txt_features(self, text): 114 | clip_texts = clip.tokenize(text).to(self.device) 115 | 116 | with torch.no_grad(): 117 | text_features = self.clip.encode_text(clip_texts) 118 | 119 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 120 | return text_features.detach() 121 | 122 | def get_combined_feature(self, img_path, texts, weights_i, weights_t): 123 | imgs = [Image.open(x) for x in img_path] 124 | clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] 
125 | clip_texts = [clip.tokenize(x).to(self.device) for x in texts] 126 | 127 | with torch.no_grad(): 128 | image_fts = [self.clip.encode_image(x) for x in clip_imgs] 129 | text_fts = [self.clip.encode_text(x) for x in clip_texts] 130 | 131 | features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) 132 | if weights_t is not None: 133 | features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) 134 | 135 | features = features / features.norm(dim=-1, keepdim=True) 136 | return features.detach() 137 | 138 | def run(self, image_features, cond_text, beam_size): 139 | self.image_features = image_features 140 | 141 | context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) 142 | 143 | output_tokens, output_text = self.generate_text(context_tokens, beam_size) 144 | 145 | return output_text 146 | 147 | def generate_text(self, context_tokens, beam_size): 148 | context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) 149 | 150 | gen_tokens = None 151 | scores = None 152 | seq_lengths = torch.ones(beam_size, device=self.device) 153 | is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) 154 | 155 | for i in range(self.target_seq_length): 156 | probs = self.get_next_probs(i, context_tokens) 157 | logits = probs.log() 158 | 159 | if scores is None: 160 | scores, next_tokens = logits.topk(beam_size, -1) 161 | context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) 162 | next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) 163 | 164 | if gen_tokens is None: 165 | gen_tokens = next_tokens 166 | else: 167 | gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) 168 | gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) 169 | else: 170 | logits[is_stopped] = -float(np.inf) 171 | logits[is_stopped, 0] = 0 172 | scores_sum = scores[:, None] + logits 173 | seq_lengths[~is_stopped] += 1 174 | scores_sum_average = scores_sum / seq_lengths[:, None] 175 | scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( 176 | beam_size, -1) 177 | next_tokens_source = next_tokens // scores_sum.shape[1] 178 | seq_lengths = seq_lengths[next_tokens_source] 179 | next_tokens = next_tokens % scores_sum.shape[1] 180 | next_tokens = next_tokens.unsqueeze(1) 181 | gen_tokens = gen_tokens[next_tokens_source] 182 | gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) 183 | context_tokens = context_tokens[next_tokens_source] 184 | scores = scores_sum_average * seq_lengths 185 | is_stopped = is_stopped[next_tokens_source] 186 | 187 | context_tokens = torch.cat((context_tokens, next_tokens), dim=1) 188 | is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() 189 | 190 | #### 191 | tmp_scores = scores / seq_lengths 192 | tmp_output_list = gen_tokens.cpu().numpy() 193 | tmp_output_texts = [ 194 | self.lm_tokenizer.decode(tmp_output) 195 | for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) 196 | ] 197 | tmp_order = tmp_scores.argsort(descending=True) 198 | tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] 199 | log_info(tmp_output_texts, verbose=True) 200 | #### 201 | 202 | if is_stopped.all(): 203 | break 204 | 205 | scores = scores / seq_lengths 206 | output_list = gen_tokens.cpu().numpy() 207 | output_texts = [ 208 | self.lm_tokenizer.decode(output[: int(length)]) 209 | for output, length in zip(output_list, seq_lengths) 210 | ] 211 | order = scores.argsort(descending=True) 212 | output_texts = 
[output_texts[i] for i in order] 213 | 214 | return context_tokens, output_texts 215 | 216 | def get_next_probs(self, i, context_tokens): 217 | last_token = context_tokens[:, -1:] 218 | 219 | if self.reset_context_delta and context_tokens.size(1) > 1: 220 | context = self.lm_model(context_tokens[:, :-1])["past_key_values"] 221 | 222 | # Logits of LM with unshifted context 223 | logits_before_shift = self.lm_model(context_tokens)["logits"] 224 | logits_before_shift = logits_before_shift[:, -1, :] 225 | probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) 226 | 227 | if context: 228 | context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) 229 | 230 | lm_output = self.lm_model(last_token, past_key_values=context) 231 | logits, past = ( 232 | lm_output["logits"], 233 | lm_output["past_key_values"], 234 | ) 235 | logits = logits[:, -1, :] 236 | 237 | logits = self.update_special_tokens_logits(context_tokens, i, logits) 238 | 239 | probs = nn.functional.softmax(logits, dim=-1) 240 | probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) 241 | probs = probs / probs.sum() 242 | 243 | return probs 244 | 245 | def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): 246 | context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] 247 | 248 | window_mask = torch.ones_like(context[0][0]).to(self.device) 249 | 250 | for i in range(self.num_iterations): 251 | curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in 252 | context_delta] 253 | 254 | for p0, p1 in curr_shift: 255 | p0.retain_grad() 256 | p1.retain_grad() 257 | 258 | shifted_context = list(map(add_context, context, curr_shift)) 259 | 260 | shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) 261 | logits = shifted_outputs["logits"][:, -1, :] 262 | probs = nn.functional.softmax(logits, dim=-1) 263 | 264 | loss = 0.0 265 | 266 | # CLIP LOSS 267 | clip_loss, clip_losses = self.clip_loss(probs, context_tokens) 268 | loss += self.clip_scale * clip_loss 269 | 270 | # CE/Fluency loss 271 | ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) 272 | loss += ce_loss.sum() 273 | 274 | loss.backward() 275 | 276 | # ---------- Weights ---------- 277 | combined_scores_k = -(ce_loss) 278 | combined_scores_c = -(self.clip_scale * torch.stack(clip_losses)) 279 | 280 | # minmax 281 | if combined_scores_k.shape[0] == 1: 282 | tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device) 283 | else: 284 | tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / ( 285 | combined_scores_k.max() - combined_scores_k.min()) 286 | tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / ( 287 | combined_scores_c.max() - combined_scores_c.min()) 288 | 289 | tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c 290 | tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1) 291 | 292 | factor = 1 293 | 294 | # --------- Specific Gen --------- 295 | sep_grads = None 296 | 297 | for b in range(context_tokens.shape[0]): 298 | tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] 299 | for p_ in curr_shift] 300 | 301 | # normalize gradients 302 | tmp_grad = [tuple([-self.stepsize * factor * ( 303 | x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ 304 | j] ** self.grad_norm_factor).data.cpu().numpy() 305 | for 
j, x in enumerate(p_)]) 306 | for i, p_ in enumerate(curr_shift)] 307 | if sep_grads is None: 308 | sep_grads = tmp_grad 309 | else: 310 | for l_index in range(len(sep_grads)): 311 | sep_grads[l_index] = list(sep_grads[l_index]) 312 | for k_index in range(len(sep_grads[0])): 313 | sep_grads[l_index][k_index] = np.concatenate( 314 | (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) 315 | sep_grads[l_index] = tuple(sep_grads[l_index]) 316 | final_grads = sep_grads 317 | 318 | # --------- update context --------- 319 | context_delta = list(map(add_context, final_grads, context_delta)) 320 | 321 | for p0, p1 in curr_shift: 322 | p0.grad.data.zero_() 323 | p1.grad.data.zero_() 324 | 325 | new_context = [] 326 | for p0, p1 in context: 327 | new_context.append((p0.detach(), p1.detach())) 328 | context = new_context 329 | 330 | context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) 331 | for p_ in context_delta] 332 | context = list(map(add_context, context, context_delta)) 333 | 334 | new_context = [] 335 | for p0, p1 in context: 336 | new_context.append((p0.detach(), p1.detach())) 337 | context = new_context 338 | 339 | return context 340 | 341 | def update_special_tokens_logits(self, context_tokens, i, logits): 342 | for beam_id in range(context_tokens.shape[0]): 343 | for token_idx in set(context_tokens[beam_id][-4:].tolist()): 344 | factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) 345 | logits[beam_id, token_idx] /= factor 346 | 347 | if i >= self.ef_idx: 348 | factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) 349 | logits[beam_id, self.end_token] *= factor 350 | if i == 0: 351 | start_factor = 1.6 352 | factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) 353 | logits[beam_id, self.end_token] /= factor 354 | 355 | for token_idx in list(self.forbidden_tokens): 356 | factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) 357 | logits[beam_id, token_idx] /= factor 358 | 359 | return logits 360 | 361 | def clip_loss(self, probs, context_tokens): 362 | for p_ in self.clip.transformer.parameters(): 363 | if p_.grad is not None: 364 | p_.grad.data.zero_() 365 | 366 | top_size = 512 367 | _, top_indices = probs.topk(top_size, -1) 368 | 369 | prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] 370 | 371 | clip_loss = 0 372 | losses = [] 373 | for idx_p in range(probs.shape[0]): 374 | top_texts = [] 375 | prefix_text = prefix_texts[idx_p] 376 | for x in top_indices[idx_p]: 377 | top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) 378 | text_features = self.get_txt_features(top_texts) 379 | 380 | with torch.no_grad(): 381 | similiraties = (self.image_features @ text_features.T) 382 | target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() 383 | target_probs = target_probs.type(torch.float32) 384 | 385 | target = torch.zeros_like(probs[idx_p]) 386 | target[top_indices[idx_p]] = target_probs[0] 387 | target = target.unsqueeze(0) 388 | cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) 389 | 390 | clip_loss += cur_clip_loss 391 | losses.append(cur_clip_loss) 392 | 393 | return clip_loss, losses -------------------------------------------------------------------------------- /model/ZeroCLIP_batched.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer 4 | from transformers.models.gpt_neo import GPTNeoForCausalLM 5 | import torch 6 | import clip 7 | from PIL import Image 8 | from datetime import datetime 9 | import sys 10 | 11 | class TextCLIP(nn.Module): 12 | def __init__(self, model): 13 | super(TextCLIP, self).__init__() 14 | self.model = model 15 | 16 | def forward(self, text): 17 | return self.model.encode_text(text) 18 | 19 | 20 | class ImageCLIP(nn.Module): 21 | def __init__(self, model): 22 | super(ImageCLIP, self).__init__() 23 | self.model = model 24 | 25 | def forward(self, image): 26 | return self.model.encode_image(image) 27 | 28 | def log_info(text, verbose=True): 29 | if verbose: 30 | dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") 31 | print(f'{dt_string} | {text}') 32 | sys.stdout.flush() 33 | 34 | 35 | def add_context(x, y): 36 | return (x[0] + y[0], x[1] + y[1]) 37 | 38 | 39 | def convert_models_to_fp32(model): 40 | for p in model.parameters(): 41 | p.data = p.data.float() 42 | 43 | 44 | class CLIPTextGenerator: 45 | def __init__(self, 46 | seed=0, 47 | lm_model='gpt-2', 48 | forbidden_tokens_file_path='./forbidden_tokens.npy', 49 | clip_checkpoints='./clip_checkpoints', 50 | target_seq_length=15, 51 | reset_context_delta=True, 52 | num_iterations=5, 53 | clip_loss_temperature=0.01, 54 | clip_scale=1., 55 | ce_scale=0.2, 56 | stepsize=0.3, 57 | grad_norm_factor=0.9, 58 | fusion_factor=0.99, 59 | repetition_penalty=1., 60 | end_token='.', 61 | end_factor=1.01, 62 | forbidden_factor=20, 63 | **kwargs): 64 | 65 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 66 | 67 | # set Random seed 68 | torch.manual_seed(seed) 69 | np.random.seed(seed) 70 | 71 | # Initialize Language model 72 | self.context_prefix = '' 73 | 74 | if lm_model == 'gpt-neo': 75 | self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') 76 | self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) 77 | elif lm_model == 'gpt-2': 78 | self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') 79 | self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) 80 | self.context_prefix = self.lm_tokenizer.bos_token 81 | 82 | self.lm_model.to(self.device) 83 | self.lm_model.eval() 84 | 85 | self.forbidden_tokens = np.load(forbidden_tokens_file_path) 86 | self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if 87 | (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] 88 | 89 | # Freeze LM weights 90 | for param in self.lm_model.parameters(): 91 | param.requires_grad = False 92 | 93 | # Initialize CLIP 94 | self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, 95 | download_root=clip_checkpoints, jit=False) 96 | self.clip_image = ImageCLIP(self.clip) 97 | self.clip_image = torch.nn.DataParallel(self.clip_image) 98 | self.clip_text = TextCLIP(self.clip) 99 | self.clip_text = torch.nn.DataParallel(self.clip_text) 100 | 101 | # Init arguments 102 | self.target_seq_length = target_seq_length 103 | self.reset_context_delta = reset_context_delta 104 | self.num_iterations = num_iterations 105 | self.clip_loss_temperature = clip_loss_temperature 106 | self.clip_scale = clip_scale 107 | self.ce_scale = ce_scale 108 | self.stepsize = stepsize 109 | self.grad_norm_factor = grad_norm_factor 
110 | self.fusion_factor = fusion_factor 111 | self.repetition_penalty = repetition_penalty 112 | self.end_token = self.lm_tokenizer.encode(end_token)[0] 113 | self.end_factor = end_factor 114 | self.ef_idx = 1 115 | self.forbidden_factor = forbidden_factor 116 | 117 | def get_img_feature(self, img_path, weights): 118 | imgs = [Image.open(x) for x in img_path] 119 | clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] 120 | 121 | with torch.no_grad(): 122 | image_fts = [self.clip_image(x) for x in clip_imgs] 123 | 124 | if weights is not None: 125 | image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) 126 | else: 127 | image_features = sum(image_fts) 128 | 129 | image_features = torch.nn.functional.normalize(image_features, dim=-1) 130 | return image_features.detach() 131 | 132 | def get_txt_features(self, text): 133 | clip_texts = clip.tokenize(text).to(self.device) 134 | 135 | with torch.no_grad(): 136 | text_features = self.clip_text(clip_texts) 137 | 138 | text_features = torch.nn.functional.normalize(text_features, dim=-1) 139 | return text_features.detach() 140 | 141 | def get_combined_feature(self, img_path, texts, weights_i, weights_t): 142 | imgs = [Image.open(x) for x in img_path] 143 | clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] 144 | clip_texts = [clip.tokenize(x).to(self.device) for x in texts] 145 | 146 | with torch.no_grad(): 147 | image_fts = [self.clip.encode_image(x) for x in clip_imgs] 148 | text_fts = [self.clip.encode_text(x) for x in clip_texts] 149 | 150 | features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) 151 | if weights_t is not None: 152 | features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) 153 | 154 | features = features / features.norm(dim=-1, keepdim=True) 155 | return features.detach() 156 | 157 | def run(self, image_features, cond_text, beam_size): 158 | self.image_features = image_features 159 | 160 | context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) 161 | 162 | output_tokens, output_text = self.generate_text(context_tokens, beam_size) 163 | 164 | return output_text 165 | 166 | def generate_text(self, context_tokens, beam_size): 167 | context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) 168 | 169 | gen_tokens = None 170 | scores = None 171 | seq_lengths = torch.ones(beam_size, device=self.device) 172 | is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) 173 | 174 | for i in range(self.target_seq_length): 175 | probs = self.get_next_probs(i, context_tokens) 176 | logits = probs.log() 177 | 178 | if scores is None: 179 | scores, next_tokens = logits.topk(beam_size, -1) 180 | context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) 181 | next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) 182 | 183 | if gen_tokens is None: 184 | gen_tokens = next_tokens 185 | else: 186 | gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) 187 | gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) 188 | else: 189 | logits[is_stopped] = -float(np.inf) 190 | logits[is_stopped, 0] = 0 191 | scores_sum = scores[:, None] + logits 192 | seq_lengths[~is_stopped] += 1 193 | scores_sum_average = scores_sum / seq_lengths[:, None] 194 | scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( 195 | beam_size, -1) 196 | next_tokens_source = next_tokens // scores_sum.shape[1] 197 | seq_lengths = 
seq_lengths[next_tokens_source] 198 | next_tokens = next_tokens % scores_sum.shape[1] 199 | next_tokens = next_tokens.unsqueeze(1) 200 | gen_tokens = gen_tokens[next_tokens_source] 201 | gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) 202 | context_tokens = context_tokens[next_tokens_source] 203 | scores = scores_sum_average * seq_lengths 204 | is_stopped = is_stopped[next_tokens_source] 205 | 206 | context_tokens = torch.cat((context_tokens, next_tokens), dim=1) 207 | is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() 208 | 209 | #### 210 | tmp_scores = scores / seq_lengths 211 | tmp_output_list = gen_tokens.cpu().numpy() 212 | tmp_output_texts = [ 213 | self.lm_tokenizer.decode(tmp_output) 214 | for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) 215 | ] 216 | tmp_order = tmp_scores.argsort(descending=True) 217 | tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] 218 | log_info(tmp_output_texts, verbose=True) 219 | #### 220 | 221 | if is_stopped.all(): 222 | break 223 | 224 | scores = scores / seq_lengths 225 | output_list = gen_tokens.cpu().numpy() 226 | output_texts = [ 227 | self.lm_tokenizer.decode(output[: int(length)]) 228 | for output, length in zip(output_list, seq_lengths) 229 | ] 230 | order = scores.argsort(descending=True) 231 | output_texts = [output_texts[i] for i in order] 232 | 233 | return context_tokens, output_texts 234 | 235 | def get_next_probs(self, i, context_tokens): 236 | last_token = context_tokens[:, -1:] 237 | 238 | if self.reset_context_delta and context_tokens.size(1) > 1: 239 | context = self.lm_model(context_tokens[:, :-1])["past_key_values"] 240 | 241 | # Logits of LM with unshifted context 242 | logits_before_shift = self.lm_model(context_tokens)["logits"] 243 | logits_before_shift = logits_before_shift[:, -1, :] 244 | probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) 245 | 246 | if context: 247 | context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) 248 | 249 | lm_output = self.lm_model(last_token, past_key_values=context) 250 | logits, past = ( 251 | lm_output["logits"], 252 | lm_output["past_key_values"], 253 | ) 254 | logits = logits[:, -1, :] 255 | 256 | logits = self.update_special_tokens_logits(context_tokens, i, logits) 257 | 258 | probs = nn.functional.softmax(logits, dim=-1) 259 | probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) 260 | probs = probs / probs.sum() 261 | 262 | return probs 263 | 264 | def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): 265 | context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] 266 | 267 | for i in range(self.num_iterations): 268 | curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in 269 | context_delta] 270 | 271 | for p0, p1 in curr_shift: 272 | p0.retain_grad() 273 | p1.retain_grad() 274 | 275 | shifted_context = list(map(add_context, context, curr_shift)) 276 | 277 | shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) 278 | logits = shifted_outputs["logits"][:, -1, :] 279 | probs = nn.functional.softmax(logits, dim=-1) 280 | 281 | loss = 0.0 282 | 283 | # CLIP LOSS 284 | clip_loss, clip_losses = self.clip_loss(probs, context_tokens) 285 | loss += self.clip_scale * clip_loss 286 | 287 | # CE/Fluency loss 288 | ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * 
probs_before_shift.log())).sum(-1) 289 | loss += ce_loss.sum() 290 | 291 | loss.backward() 292 | 293 | # --------- Specific Gen --------- 294 | final_grads = self.norm_grad(context, context_tokens, curr_shift) 295 | 296 | # --------- update context --------- 297 | context_delta = list(map(add_context, final_grads, context_delta)) 298 | 299 | for p0, p1 in curr_shift: 300 | p0.grad.data.zero_() 301 | p1.grad.data.zero_() 302 | 303 | new_context = [] 304 | for p0, p1 in context: 305 | new_context.append((p0.detach(), p1.detach())) 306 | context = new_context 307 | 308 | context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) 309 | for p_ in context_delta] 310 | context = list(map(add_context, context, context_delta)) 311 | 312 | new_context = [] 313 | for p0, p1 in context: 314 | new_context.append((p0.detach(), p1.detach())) 315 | context = new_context 316 | 317 | return context 318 | 319 | def norm_grad(self, context, context_tokens, curr_shift, ): 320 | factor = 1 321 | sep_grads = None 322 | window_mask = torch.ones_like(context[0][0]).to(self.device) 323 | 324 | for b in range(context_tokens.shape[0]): 325 | tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] 326 | for p_ in curr_shift] 327 | 328 | # normalize gradients 329 | tmp_grad = [tuple([-self.stepsize * factor * ( 330 | x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ 331 | j] ** self.grad_norm_factor).data.cpu().numpy() 332 | for j, x in enumerate(p_)]) 333 | for i, p_ in enumerate(curr_shift)] 334 | if sep_grads is None: 335 | sep_grads = tmp_grad 336 | else: 337 | for l_index in range(len(sep_grads)): 338 | sep_grads[l_index] = list(sep_grads[l_index]) 339 | for k_index in range(len(sep_grads[0])): 340 | sep_grads[l_index][k_index] = np.concatenate( 341 | (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) 342 | sep_grads[l_index] = tuple(sep_grads[l_index]) 343 | final_grads = sep_grads 344 | 345 | return final_grads 346 | 347 | def update_special_tokens_logits(self, context_tokens, i, logits): 348 | for beam_id in range(context_tokens.shape[0]): 349 | for token_idx in set(context_tokens[beam_id][-4:].tolist()): 350 | factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) 351 | logits[beam_id, token_idx] /= factor 352 | 353 | if i >= self.ef_idx: 354 | factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) 355 | logits[beam_id, self.end_token] *= factor 356 | if i == 0: 357 | start_factor = 1.6 358 | factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) 359 | logits[beam_id, self.end_token] /= factor 360 | 361 | for token_idx in list(self.forbidden_tokens): 362 | factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) 363 | logits[beam_id, token_idx] /= factor 364 | 365 | return logits 366 | 367 | def clip_loss(self, probs, context_tokens): 368 | for p_ in self.clip.transformer.parameters(): 369 | if p_.grad is not None: 370 | p_.grad.data.zero_() 371 | 372 | top_size = 512 373 | top_probs, top_indices = probs.topk(top_size, -1) 374 | 375 | prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] 376 | 377 | clip_loss = 0 378 | losses = [] 379 | 380 | top_texts = [] 381 | for idx_p in range(probs.shape[0]): 382 | prefix_text = prefix_texts[idx_p] 383 | for x in top_indices[idx_p]: 384 | top_texts.append(prefix_text + 
self.lm_tokenizer.decode(x)) 385 | 386 | text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) 387 | 388 | with torch.no_grad(): 389 | similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) 390 | similiraties = similiraties.reshape(probs.size(0), -1) 391 | target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() 392 | target_probs = target_probs.type(torch.float32) 393 | 394 | clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) 395 | # for idx_p in range(probs.shape[0]): 396 | # top_texts = [] 397 | # prefix_text = prefix_texts[idx_p] 398 | # for x in top_indices[idx_p]: 399 | # top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) 400 | # text_features = self.get_txt_features(top_texts) 401 | # 402 | # with torch.no_grad(): 403 | # similiraties = (self.image_features @ text_features.T) 404 | # target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() 405 | # target_probs = target_probs.type(torch.float32) 406 | # 407 | # target = torch.zeros_like(probs[idx_p]) 408 | # target[top_indices[idx_p]] = target_probs[0] 409 | # target = target.unsqueeze(0) 410 | # cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) 411 | # 412 | # clip_loss += cur_clip_loss 413 | # losses.append(cur_clip_loss) 414 | 415 | return clip_loss, losses 416 | 417 | def clip_loss_old(self, probs, context_tokens): 418 | for p_ in self.clip.transformer.parameters(): 419 | if p_.grad is not None: 420 | p_.grad.data.zero_() 421 | 422 | top_size = 512 423 | _, top_indices = probs.topk(top_size, -1) 424 | 425 | prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] 426 | 427 | clip_loss = 0 428 | losses = [] 429 | for idx_p in range(probs.shape[0]): 430 | top_texts = [] 431 | prefix_text = prefix_texts[idx_p] 432 | for x in top_indices[idx_p]: 433 | top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) 434 | text_features = self.get_txt_features(top_texts) 435 | 436 | with torch.no_grad(): 437 | similiraties = (self.image_features @ text_features.T) 438 | target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() 439 | target_probs = target_probs.type(torch.float32) 440 | 441 | target = torch.zeros_like(probs[idx_p]) 442 | target[top_indices[idx_p]] = target_probs[0] 443 | target = target.unsqueeze(0) 444 | cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) 445 | 446 | clip_loss += cur_clip_loss 447 | losses.append(cur_clip_loss) 448 | 449 | return clip_loss, losses -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YoadTew/zero-shot-image-to-text/b99f0af59aa98f9447c5e691e702fd5c9cccd5fb/model/__init__.py -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import sys 4 | sys.path.append('CLIP') 5 | from pathlib import Path 6 | import cog 7 | import argparse 8 | import torch 9 | import clip 10 | from model.ZeroCLIP import CLIPTextGenerator 11 | 12 | def perplexity_score(text, lm_model, lm_tokenizer, device): 13 | encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') 14 | input_ids = 
encodings.input_ids.to(device) 15 | target_ids = input_ids.clone() 16 | 17 | outputs = lm_model(input_ids, labels=target_ids) 18 | log_likelihood = outputs[0] 19 | ll = log_likelihood.item() 20 | 21 | return ll 22 | 23 | class Predictor(cog.Predictor): 24 | def setup(self): 25 | self.args = get_args() 26 | self.args.reset_context_delta = True 27 | self.text_generator = CLIPTextGenerator(**vars(self.args)) 28 | 29 | @cog.input( 30 | "image", 31 | type=Path, 32 | help="input image" 33 | ) 34 | @cog.input( 35 | "cond_text", 36 | type=str, 37 | default='Image of a', 38 | help="conditional text", 39 | ) 40 | @cog.input( 41 | "beam_size", 42 | type=int, 43 | default=5, min=1, max=10, 44 | help="Number of beams to use", 45 | ) 46 | @cog.input( 47 | "end_factor", 48 | type=float, 49 | default=1.01, min=1.0, max=1.10, 50 | help="Higher value for shorter captions", 51 | ) 52 | @cog.input( 53 | "max_seq_length", 54 | type=int, 55 | default=15, min=1, max=20, 56 | help="Maximum number of tokens to generate", 57 | ) 58 | @cog.input( 59 | "ce_loss_scale", 60 | type=float, 61 | default=0.2, min=0.0, max=0.6, 62 | help="Scale of cross-entropy loss with un-shifted language model", 63 | ) 64 | def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale): 65 | self.args.cond_text = cond_text 66 | self.text_generator.end_factor = end_factor 67 | self.text_generator.target_seq_length = max_seq_length 68 | self.text_generator.ce_scale = ce_loss_scale 69 | 70 | image_features = self.text_generator.get_img_feature([str(image)], None) 71 | captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) 72 | 73 | # CLIP SCORE 74 | encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) 75 | for c in captions] 76 | encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] 77 | best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() 78 | 79 | # Perplexity SCORE 80 | ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] 81 | best_ppl_index = torch.tensor(ppl_scores).argmin().item() 82 | 83 | best_clip_caption = self.args.cond_text + captions[best_clip_idx] 84 | best_mixed = self.args.cond_text + captions[0] 85 | best_PPL = self.args.cond_text + captions[best_ppl_index] 86 | 87 | final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' 88 | 89 | return final 90 | # return self.args.cond_text + captions[best_clip_idx] 91 | 92 | 93 | def get_args(): 94 | parser = argparse.ArgumentParser() 95 | 96 | parser.add_argument("--seed", type=int, default=0) 97 | parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") 98 | parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") 99 | parser.add_argument("--target_seq_length", type=int, default=15) 100 | parser.add_argument("--cond_text", type=str, default="Image of a") 101 | parser.add_argument("--reset_context_delta", action="store_true", 102 | help="Should we reset the context at each token gen") 103 | parser.add_argument("--num_iterations", type=int, default=5) 104 | parser.add_argument("--clip_loss_temperature", type=float, default=0.01) 105 | parser.add_argument("--clip_scale", type=float, default=1) 106 | parser.add_argument("--ce_scale", type=float, default=0.2) 107 | parser.add_argument("--stepsize", 
type=float, default=0.3) 108 | parser.add_argument("--grad_norm_factor", type=float, default=0.9) 109 | parser.add_argument("--fusion_factor", type=float, default=0.99) 110 | parser.add_argument("--repetition_penalty", type=float, default=1) 111 | parser.add_argument("--end_token", type=str, default=".", help="Token to end text") 112 | parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") 113 | parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") 114 | parser.add_argument("--beam_size", type=int, default=5) 115 | 116 | args = parser.parse_args('') 117 | return args 118 | -------------------------------------------------------------------------------- /predict_arithmetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import sys 4 | sys.path.append('CLIP') 5 | from pathlib import Path 6 | import cog 7 | import argparse 8 | import torch 9 | import clip 10 | from model.ZeroCLIP import CLIPTextGenerator 11 | 12 | def perplexity_score(text, lm_model, lm_tokenizer, device): 13 | encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') 14 | input_ids = encodings.input_ids.to(device) 15 | target_ids = input_ids.clone() 16 | 17 | outputs = lm_model(input_ids, labels=target_ids) 18 | log_likelihood = outputs[0] 19 | ll = log_likelihood.item() 20 | 21 | return ll 22 | 23 | class Predictor(cog.Predictor): 24 | def setup(self): 25 | self.args = get_args() 26 | self.args.reset_context_delta = True 27 | self.text_generator = CLIPTextGenerator(**vars(self.args)) 28 | 29 | @cog.input( 30 | "image1", 31 | type=Path, 32 | help="Final result will be: image1 + (image2 - image3)" 33 | ) 34 | @cog.input( 35 | "image2", 36 | type=Path, 37 | help="Final result will be: image1 + (image2 - image3)" 38 | ) 39 | @cog.input( 40 | "image3", 41 | type=Path, 42 | help="Final result will be: image1 + (image2 - image3)" 43 | ) 44 | @cog.input( 45 | "cond_text", 46 | type=str, 47 | default='Image of a', 48 | help="conditional text", 49 | ) 50 | @cog.input( 51 | "beam_size", 52 | type=int, 53 | default=3, min=1, max=10, 54 | help="Number of beams to use", 55 | ) 56 | @cog.input( 57 | "end_factors", 58 | type=float, 59 | default=1.06, min=1.0, max=1.10, 60 | help="Higher value for shorter captions", 61 | ) 62 | @cog.input( 63 | "max_seq_lengths", 64 | type=int, 65 | default=3, min=1, max=20, 66 | help="Maximum number of tokens to generate", 67 | ) 68 | @cog.input( 69 | "ce_loss_scale", 70 | type=float, 71 | default=0.2, min=0.0, max=0.6, 72 | help="Scale of cross-entropy loss with un-shifted language model", 73 | ) 74 | def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale): 75 | self.args.cond_text = cond_text 76 | self.text_generator.end_factor = end_factors 77 | self.text_generator.target_seq_length = max_seq_lengths 78 | self.text_generator.ce_scale = ce_loss_scale 79 | self.text_generator.fusion_factor = 0.95 80 | self.text_generator.grad_norm_factor = 0.95 81 | 82 | image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None) 83 | captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) 84 | 85 | # CLIP SCORE 86 | encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) 87 | for c in captions] 88 | encoded_captions = [x / 
x.norm(dim=-1, keepdim=True) for x in encoded_captions] 89 | best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() 90 | 91 | # Perplexity SCORE 92 | ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] 93 | best_ppl_index = torch.tensor(ppl_scores).argmin().item() 94 | 95 | best_clip_caption = self.args.cond_text + captions[best_clip_idx] 96 | best_mixed = self.args.cond_text + captions[0] 97 | best_PPL = self.args.cond_text + captions[best_ppl_index] 98 | 99 | final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' 100 | 101 | return final 102 | # return self.args.cond_text + captions[best_clip_idx] 103 | 104 | 105 | def get_args(): 106 | parser = argparse.ArgumentParser() 107 | 108 | parser.add_argument("--seed", type=int, default=0) 109 | parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") 110 | parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") 111 | parser.add_argument("--target_seq_length", type=int, default=15) 112 | parser.add_argument("--cond_text", type=str, default="Image of a") 113 | parser.add_argument("--reset_context_delta", action="store_true", 114 | help="Should we reset the context at each token gen") 115 | parser.add_argument("--num_iterations", type=int, default=5) 116 | parser.add_argument("--clip_loss_temperature", type=float, default=0.01) 117 | parser.add_argument("--clip_scale", type=float, default=1) 118 | parser.add_argument("--ce_scale", type=float, default=0.2) 119 | parser.add_argument("--stepsize", type=float, default=0.3) 120 | parser.add_argument("--grad_norm_factor", type=float, default=0.95) 121 | parser.add_argument("--fusion_factor", type=float, default=0.95) 122 | parser.add_argument("--repetition_penalty", type=float, default=1) 123 | parser.add_argument("--end_token", type=str, default=".", help="Token to end text") 124 | parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") 125 | parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") 126 | parser.add_argument("--beam_size", type=int, default=5) 127 | 128 | args = parser.parse_args('') 129 | return args 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch==1.9 5 | torchvision 6 | transformers==4.11.2 7 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import clip 4 | from model.ZeroCLIP import CLIPTextGenerator 5 | from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--seed", type=int, default=0) 11 | parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") 12 | parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") 13 | parser.add_argument("--target_seq_length", type=int, default=15) 14 | parser.add_argument("--cond_text", type=str, default="Image of a") 15 | parser.add_argument("--reset_context_delta", action="store_true", 
16 | help="Should we reset the context at each token gen") 17 | parser.add_argument("--num_iterations", type=int, default=5) 18 | parser.add_argument("--clip_loss_temperature", type=float, default=0.01) 19 | parser.add_argument("--clip_scale", type=float, default=1) 20 | parser.add_argument("--ce_scale", type=float, default=0.2) 21 | parser.add_argument("--stepsize", type=float, default=0.3) 22 | parser.add_argument("--grad_norm_factor", type=float, default=0.9) 23 | parser.add_argument("--fusion_factor", type=float, default=0.99) 24 | parser.add_argument("--repetition_penalty", type=float, default=1) 25 | parser.add_argument("--end_token", type=str, default=".", help="Token to end text") 26 | parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") 27 | parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") 28 | parser.add_argument("--beam_size", type=int, default=5) 29 | 30 | parser.add_argument("--multi_gpu", action="store_true") 31 | 32 | parser.add_argument('--run_type', 33 | default='caption', 34 | nargs='?', 35 | choices=['caption', 'arithmetics']) 36 | 37 | parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg', 38 | help="Path to image for captioning") 39 | 40 | parser.add_argument("--arithmetics_imgs", nargs="+", 41 | default=['example_images/arithmetics/woman2.jpg', 42 | 'example_images/arithmetics/king2.jpg', 43 | 'example_images/arithmetics/man2.jpg']) 44 | parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1]) 45 | 46 | args = parser.parse_args() 47 | 48 | return args 49 | 50 | def run(args, img_path): 51 | if args.multi_gpu: 52 | text_generator = CLIPTextGenerator_multigpu(**vars(args)) 53 | else: 54 | text_generator = CLIPTextGenerator(**vars(args)) 55 | 56 | image_features = text_generator.get_img_feature([img_path], None) 57 | captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) 58 | 59 | encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] 60 | encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] 61 | best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() 62 | 63 | print(captions) 64 | print('best clip:', args.cond_text + captions[best_clip_idx]) 65 | 66 | def run_arithmetic(args, imgs_path, img_weights): 67 | if args.multi_gpu: 68 | text_generator = CLIPTextGenerator_multigpu(**vars(args)) 69 | else: 70 | text_generator = CLIPTextGenerator(**vars(args)) 71 | 72 | image_features = text_generator.get_combined_feature(imgs_path, [], img_weights, None) 73 | captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) 74 | 75 | encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] 76 | encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] 77 | best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() 78 | 79 | print(captions) 80 | print('best clip:', args.cond_text + captions[best_clip_idx]) 81 | 82 | if __name__ == "__main__": 83 | args = get_args() 84 | 85 | if args.run_type == 'caption': 86 | run(args, img_path=args.caption_img_path) 87 | elif args.run_type == 'arithmetics': 88 | args.arithmetics_weights = [float(x) for x in args.arithmetics_weights] 89 | run_arithmetic(args, 
imgs_path=args.arithmetics_imgs, img_weights=args.arithmetics_weights) 90 | else: 91 | raise Exception('run_type must be caption or arithmetics!') -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="zero-shot-image-to-text", 8 | py_modules=["zero-shot-image-to-text"], 9 | version="1.0", 10 | description="", 11 | packages=find_packages(), 12 | install_requires=[ 13 | str(r) 14 | for r in pkg_resources.parse_requirements( 15 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 16 | ) 17 | ], 18 | include_package_data=True 19 | ) --------------------------------------------------------------------------------
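For reference, the captioning pipeline that run.py wires together can also be driven directly from Python. The sketch below is not a file in the repository; it simply mirrors the run() function of run.py using the CLIPTextGenerator API from model/ZeroCLIP.py, and it assumes the package and its requirements are installed, that the CLIP checkpoint can be downloaded, and that forbidden_tokens.npy is present in the working directory.

```python
# Minimal captioning sketch distilled from run.py (illustrative, not part of the repo).
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator

generator = CLIPTextGenerator(reset_context_delta=True)

# Encode the image with CLIP; weights=None sums the features of the given image paths.
image_features = generator.get_img_feature(
    ['example_images/captions/COCO_val2014_000000097017.jpg'], None)

# CLIP-guided beam search, starting from the conditional prefix text.
captions = generator.run(image_features, 'Image of a', beam_size=5)

# Re-rank the beams by CLIP similarity, exactly as run.py does.
encoded = [generator.clip.encode_text(clip.tokenize(c).to(generator.device))
           for c in captions]
encoded = [x / x.norm(dim=-1, keepdim=True) for x in encoded]
best_idx = (torch.cat(encoded) @ image_features.t()).squeeze().argmax().item()
print('Image of a' + captions[best_idx])
```

On multi-GPU machines, run.py instead selects the CLIPTextGenerator variant from model/ZeroCLIP_batched.py (behind the --multi_gpu flag); it exposes the same get_img_feature/run interface, so the sketch above applies unchanged.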