├── LICENSE ├── POS_classifier.py ├── README.md ├── __pycache__ └── utils.cpython-38.pyc ├── app.py ├── clip ├── build_text_index.py ├── clip.py └── clipretrieval.py ├── compute_n_div.py ├── control_gen_utils.py ├── demo.py ├── examples ├── Gosh.jpeg ├── cat.png ├── girl.jpg └── horse.png ├── gen_utils.py ├── paper_images ├── diversecaptioning.jpg ├── framework.gif ├── framework.jpg ├── gibbs_bert.gif ├── gibbs_bert_mask.gif ├── lengthcontrol.jpg ├── moreimagestyles.jpg ├── poscontrol.jpg ├── sentimentcontrol.jpg └── style_examples.pdf ├── requirements.txt ├── run.py ├── sentiments_classifer.py ├── stop_words.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zequn Zeng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /POS_classifier.py: -------------------------------------------------------------------------------- 1 | from nltk.tokenize import word_tokenize 2 | from nltk import pos_tag 3 | import torch 4 | import json 5 | 6 | def batch_texts_POS_analysis(batch_texts, pos_templete, device="cuda"): 7 | batch_size = len(batch_texts) 8 | pos_tags = [] 9 | pos_scores = torch.zeros(batch_size) 10 | 11 | for b_id in range(batch_size): 12 | text = batch_texts[b_id] 13 | words = word_tokenize(text) 14 | word_tag = pos_tag(words, tagset="universal") 15 | res_tag = [tag[1] for tag in word_tag] 16 | total_num = len(pos_templete) 17 | correct = 0 18 | if len(res_tag) <= total_num: 19 | cur_tag = res_tag + [""] * (len(pos_templete)-len(res_tag)) 20 | else: 21 | cur_tag = res_tag[:total_num] 22 | for word_id in range(len(cur_tag)): 23 | if pos_templete[word_id]=="": 24 | correct += 1 25 | elif cur_tag[word_id] in pos_templete[word_id]: 26 | correct +=1 27 | acc = correct/total_num 28 | pos_tags.append(res_tag) 29 | pos_scores[b_id] = acc 30 | 31 | return pos_tags, pos_scores 32 | 33 | def text_POS_analysis(text): 34 | words = word_tokenize(text) 35 | word_tag = pos_tag(words, tagset="universal") 36 | res_tag = [tag[1] for tag in word_tag] 37 | 38 | return res_tag 39 | 40 | if __name__=="__main__": 41 | batch_texts = ["A cat sitting in the bed.", 42 | "Two men in a nice hotel room one playing a video game with a remote control.", 43 | "The man sitting in the chair feels like an invisible,dead man."] 44 | pos_templete = ['DET', 'NOUN', 'ADP', 'ADJ', 'NOUN', '.', 'NOUN', 'CONJ', 'NOUN', 'ADP', 'PRON', '.'] 45 | 46 | batch_texts_POS_analysis(batch_texts, pos_templete, device="cuda") 47 | cur_path = "iter_15.json" 48 | all_caption = [] 49 | 50 | with open(cur_path, "r") as cur_json_file: 51 | all_res = list(json.load(cur_json_file).values()) 52 | for res in all_res: 53 | if isinstance(res, list): 54 | all_caption += res 55 | else: 56 | all_caption.append(res) 57 | pos_tags, pos_scores = batch_texts_POS_analysis(all_caption, pos_templete, device="cuda") 58 | word_id = 12 59 | pos_dict = {"ADJ": 0, "ADP": 0, "ADV": 0, 60 | "CONJ": 0, "DET": 0, "NOUN": 0,"X":0, 61 | "NUM": 0, "PRT": 0, "PRON": 0, "VERB": 0, ".": 0} 62 | for pos_tag in pos_tags: 63 | if word_id < len(pos_tag): 64 | pos_dict[pos_tag[word_id]] += 1 65 | print(1) 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConZIC 2 | **[CVPR 2023][ConZIC: Controllable Zero-shot Image Captioning by Sampling-Based Polishing](https://arxiv.org/abs/2303.02437)** 3 |
4 | [Zequn Zeng](https://joeyz0z.github.io/), 5 | [Hao Zhang](https://scholar.google.com/citations?user=Eo8e5icAAAAJ), 6 | [Zhengjue Wang](https://scholar.google.com/citations?user=qTQj_I4AAAAJ), 7 | [Ruiying Lu](https://ieeexplore.ieee.org/author/37088439713), 8 | [Dongsheng Wang](https://wds2014.github.io/), 9 | [Bo Chen](https://scholar.google.com/citations?user=uv16_-UAAAAJ) 10 |
11 | 12 | 13 | 14 | [comment]: <> ([![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://tuneavideo.github.io/))
15 | [![arXiv](https://img.shields.io/badge/arXiv-2303.02437-b31b1b.svg)](https://arxiv.org/abs/2303.02437)
16 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/jiaqingj/ConZIC)
17 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MyjEAuygQblYwTjK67ATo7XALuHkkiLa?usp=sharing)
18 | 19 | ### News
20 | * [2023/4] Added demos on Hugging Face Spaces and Colab!
21 | * [2023/3] ConZIC is publicly released!
22 | 23 | ***
24 | ### Framework
25 | ![](paper_images/framework.gif)
26 | 27 | #### Gibbs-BERT
28 | ![](paper_images/gibbs_bert_mask.gif)
29 | 30 | #### Example of sentiment control
31 | ![](paper_images/sentimentcontrol.jpg)
32 | 33 | 34 | 35 | 36 | ## DEMO
37 | 38 | ### Preparation
39 | Please download [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) and [BERT](https://huggingface.co/bert-base-uncased) from the Hugging Face Hub.
40 | 41 | The SketchyCOCOcaption benchmark used in our work is available [here](https://drive.google.com/file/d/1WBaq8OdvyyXpbYtmuFIvko6855rESwHE/view?usp=share_link).
42 | 43 | Environment setup:
44 | ``` 45 | pip install -r requirements.txt 46 | ``` 47 |
48 | ### To run zero-shot captioning on images:
49 | ConZIC supports arbitrary generation orders via the **order** argument. Each candidate token is scored by a weighted combination of the BERT probability and the CLIP image-text similarity, so you can increase **alpha** for more fluency and **beta** for more image content. Notably, there is a trade-off between fluency and image-matching degree.
50 | **Sequential**: update tokens in the classical left-to-right order; at each iteration, the whole sentence is updated.
51 | ``` 52 | python demo.py --run_type "caption" --order "sequential" --sentence_len 10 --caption_img_path "./examples/girl.jpg" --samples_num 1
53 | --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32"
54 | --alpha 0.02 --beta 2.0
55 | ```
56 | **Shuffled**: update tokens in a randomly shuffled order; different orders yield different captions.
57 | ``` 58 | python demo.py --run_type "caption" --order "shuffle" --sentence_len 10 --caption_img_path "./examples/girl.jpg" --samples_num 3
59 | --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32"
60 | --alpha 0.02 --beta 2.0
61 | ```
62 | **Random**: at each iteration, randomly select a single position and update only that token; the added randomness yields high diversity.
63 | ``` 64 | python demo.py --run_type "caption" --order "random" --sentence_len 10 --caption_img_path "./examples/girl.jpg" --samples_num 3
65 | --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32"
66 | --alpha 0.02 --beta 2.0
67 | ```
68 | 69 | ### To run controllable zero-shot captioning on images:
70 | ConZIC supports several text-related control signals. For example:
71 | **Sentiment (positive/negative)**: increase **gamma** for a stronger control signal; as with **alpha** and **beta**, there is a trade-off.
72 | ``` 73 | python demo.py
74 | --run_type "controllable" --control_type "sentiment" --sentiment_type "positive"
75 | --order "sequential" --sentence_len 10 --caption_img_path "./examples/girl.jpg" --samples_num 1
76 | --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32"
77 | --alpha 0.02 --beta 2.0 --gamma 5.0
78 | ```
79 | **Part-of-speech (POS)**: the generated caption will match the predefined POS template as closely as possible; an example template is shown below.
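The POS template is a list with one entry per token position; each entry gives the universal POS tags (as produced by NLTK's `pos_tag` with `tagset="universal"`) allowed at that position, and an empty string entry matches any tag. For reference, the default template defined in `demo.py` and `app.py` targets a 12-token caption:

```python
# Default POS template from demo.py / app.py: one entry per generated token,
# each entry listing the universal POS tags allowed at that position.
pos_type = [['DET'], ['ADJ', 'NOUN'], ['NOUN'],
            ['VERB'], ['VERB'], ['ADV'], ['ADP'],
            ['DET', 'NOUN'], ['NOUN'], ['NOUN', '.'],
            ['.', 'NOUN'], ['.', 'NOUN']]
```

The POS score counts the positions whose predicted tag falls inside the allowed set (see `POS_classifier.py`), so the caption is steered toward the template rather than strictly constrained to it. To run POS-controlled captioning: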
80 | ``` 81 | python demo.py 82 | --run_type "controllable" --control_type "pos" --order "sequential" 83 | --pos_type "your predefined POS templete" 84 | --sentence_len 10 --caption_img_path "./examples/girl.jpg" --samples_num 1 85 | --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32" 86 | --alpha 0.02 --beta 2.0 --gamma 5.0 87 | ``` 88 | **Length**: change **sentence_len**. 89 | 90 | ## Gradio Demo 91 | We highly recommend to use the following **WebUI** demo in your browser from the local url: http://127.0.0.1:7860. 92 | ``` 93 | pip install gradio 94 | python app.py --lm_model "bert-base-uncased" --match_model "openai/clip-vit-base-patch32" 95 | ``` 96 | You can also use the **demo.launch()** function to create a public link used by anyone to access the demo from their browser by setting share=True. 97 | 98 | **** 99 | ### Citation 100 | Please cite our work if you use it in your research: 101 | ``` 102 | @article{zeng2023conzic, 103 | title={ConZIC: Controllable Zero-shot Image Captioning by Sampling-Based Polishing}, 104 | author={Zeng, Zequn and Zhang, Hao and Wang, Zhengjue and Lu, Ruiying and Wang, Dongsheng and Chen, Bo}, 105 | journal={arXiv preprint arXiv:2303.02437}, 106 | year={2023} 107 | } 108 | ``` 109 | 110 | ### Contact 111 | If you have any questions, please contact zzequn99@163.com or zhanghao_xidian@163.com. 112 | 113 | 114 | ### Acknowledgment 115 | This code is based on the [bert-gen](https://github.com/nyu-dl/bert-gen) and [MAGIC](https://github.com/yxuansu/MAGIC). 116 | 117 | Thanks for [Jiaqing Jiang](https://github.com/blre6) providing huggingface and Colab demo. 118 | 119 | 120 | -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from utils import create_logger, set_seed, format_output 2 | import os 3 | import time 4 | import argparse 5 | import json 6 | from PIL import Image 7 | import torch 8 | import gradio as gr 9 | import nltk 10 | 11 | from clip.clip import CLIP 12 | from gen_utils import generate_caption 13 | from control_gen_utils import control_generate_caption 14 | from transformers import AutoModelForMaskedLM, AutoTokenizer 15 | 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument("--seed", type=int, default=42) 21 | parser.add_argument("--batch_size", type=int, default=1, help = "Only supports batch_size=1 currently.") 22 | parser.add_argument("--device", type=str, 23 | default='cuda',choices=['cuda','cpu']) 24 | 25 | ## Generation and Controllable Type 26 | parser.add_argument('--run_type', 27 | default='caption', 28 | nargs='?', 29 | choices=['caption', 'controllable']) 30 | parser.add_argument('--prompt', 31 | default='Image of a',type=str) 32 | parser.add_argument('--order', 33 | default='shuffle', 34 | nargs='?', 35 | choices=['sequential', 'shuffle', 'span', 'random'], 36 | help="Generation order of text") 37 | parser.add_argument('--control_type', 38 | default='sentiment', 39 | nargs='?', 40 | choices=["sentiment","pos"], 41 | help="which controllable task to conduct") 42 | parser.add_argument('--pos_type', type=list, 43 | default=[['DET'], 
['ADJ','NOUN'], ['NOUN'], 44 | ['VERB'], ['VERB'],['ADV'], ['ADP'], 45 | ['DET','NOUN'], ['NOUN'], ['NOUN','.'], 46 | ['.','NOUN'],['.','NOUN']], 47 | help="predefined part-of-speech templete") 48 | parser.add_argument('--sentiment_type', 49 | default="positive", 50 | nargs='?', 51 | choices=["positive", "negative"]) 52 | parser.add_argument('--samples_num', 53 | default=2,type=int) 54 | 55 | ## Hyperparameters 56 | parser.add_argument("--sentence_len", type=int, default=10) 57 | parser.add_argument("--candidate_k", type=int, default=200) 58 | parser.add_argument("--alpha", type=float, default=0.02, help="weight for fluency") 59 | parser.add_argument("--beta", type=float, default=2.0, help="weight for image-matching degree") 60 | parser.add_argument("--gamma", type=float, default=5.0, help="weight for controllable degree") 61 | parser.add_argument("--lm_temperature", type=float, default=0.1) 62 | parser.add_argument("--num_iterations", type=int, default=1, help="predefined iterations for Gibbs Sampling") 63 | 64 | ## Models and Paths 65 | parser.add_argument("--lm_model", type=str, default='bert-base-uncased', 66 | help="Path to language model") # bert,roberta 67 | parser.add_argument("--match_model", type=str, default='clip-vit-base-patch32', 68 | help="Path to Image-Text model") # clip,align 69 | parser.add_argument("--caption_img_path", type=str, default='./examples/girl.jpg', 70 | help="file path of the image for captioning") 71 | parser.add_argument("--stop_words_path", type=str, default='stop_words.txt', 72 | help="Path to stop_words.txt") 73 | parser.add_argument("--add_extra_stopwords", type=list, default=[], 74 | help="you can add some extra stop words") 75 | 76 | args = parser.parse_args() 77 | 78 | return args 79 | 80 | def run_caption(args, image, lm_model, lm_tokenizer, clip, token_mask, logger): 81 | FinalCaptionList = [] 82 | BestCaptionList = [] 83 | img_name = ['Your image'] 84 | image_instance = image.convert("RGB") 85 | for sample_id in range(args.samples_num): 86 | logger.info(f"Sample {sample_id}: ") 87 | gen_texts, clip_scores = generate_caption(img_name, lm_model, clip, lm_tokenizer, image_instance, token_mask, logger, 88 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 89 | top_k=args.candidate_k, temperature=args.lm_temperature, 90 | max_iter=args.num_iterations,alpha=args.alpha,beta=args.beta, 91 | generate_order = args.order) 92 | FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2][0] 93 | BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1][0] 94 | FinalCaptionList.append(FinalCaptionStr) 95 | BestCaptionList.append(BestCaptionStr) 96 | return FinalCaptionList, BestCaptionList 97 | 98 | 99 | 100 | def run_control(run_type, args, image, lm_model, lm_tokenizer, clip, token_mask, logger): 101 | FinalCaptionList = [] 102 | BestCaptionList = [] 103 | img_name = ['Your image'] 104 | image_instance = image.convert("RGB") 105 | for sample_id in range(args.samples_num): 106 | logger.info(f"Sample {sample_id}: ") 107 | gen_texts, clip_scores = control_generate_caption(img_name, lm_model, clip, lm_tokenizer, image_instance, token_mask, logger, 108 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 109 | top_k=args.candidate_k, temperature=args.lm_temperature, 110 | max_iter=args.num_iterations, alpha=args.alpha, 111 | beta=args.beta, gamma=args.gamma, 112 | ctl_type = args.control_type, style_type=args.sentiment_type,pos_type=args.pos_type, generate_order=args.order) 113 | 
FinalCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-2][0] 114 | BestCaptionStr = "Sample {}: ".format(sample_id + 1) + gen_texts[-1][0] 115 | FinalCaptionList.append(FinalCaptionStr) 116 | BestCaptionList.append(BestCaptionStr) 117 | return FinalCaptionList, BestCaptionList 118 | 119 | def Demo(RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img): 120 | args = get_args() 121 | set_seed(args.seed) 122 | 123 | args.num_iterations = NumIterations 124 | args.sentence_len = Length 125 | args.run_type = RunType 126 | args.control_type = ControlType 127 | args.sentiment_type = SentimentType 128 | args.alpha = Alpha 129 | args.beta = Beta 130 | args.gamma = Gamma 131 | args.samples_num = SamplesNum 132 | args.order = Order 133 | img = Img 134 | 135 | run_type = "caption" if args.run_type=="caption" else args.control_type 136 | if run_type=="sentiment": 137 | run_type = args.sentiment_type 138 | 139 | if os.path.exists("logger")== False: 140 | os.mkdir("logger") 141 | logger = create_logger( 142 | "logger",'demo_{}_{}_len{}_topk{}_alpha{}_beta{}_gamma{}_lmtemp{}_{}.log'.format( 143 | run_type, args.order,args.sentence_len, 144 | args.candidate_k, args.alpha,args.beta,args.gamma,args.lm_temperature, 145 | time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))) 146 | 147 | logger.info(f"Generating order:{args.order}") 148 | logger.info(f"Run type:{run_type}") 149 | logger.info(args) 150 | 151 | # Load pre-trained model (weights) 152 | lm_model = AutoModelForMaskedLM.from_pretrained(args.lm_model) 153 | lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model) 154 | lm_model.eval() 155 | clip = CLIP(args.match_model) 156 | clip.eval() 157 | 158 | lm_model = lm_model.to(args.device) 159 | clip = clip.to(args.device) 160 | 161 | ## Remove stop words, token mask 162 | with open(args.stop_words_path,'r',encoding='utf-8') as stop_words_file: 163 | stop_words = stop_words_file.readlines() 164 | stop_words_ = [stop_word.rstrip('\n') for stop_word in stop_words] 165 | stop_words_ += args.add_extra_stopwords 166 | stop_ids = lm_tokenizer.convert_tokens_to_ids(stop_words_) 167 | token_mask = torch.ones((1,lm_tokenizer.vocab_size)) 168 | for stop_id in stop_ids: 169 | token_mask[0,stop_id]=0 170 | token_mask = token_mask.to(args.device) 171 | with torch.no_grad(): 172 | if args.run_type == 'caption': 173 | FinalCaption, BestCaption = run_caption(args, img, lm_model, lm_tokenizer, clip, token_mask, logger) 174 | elif args.run_type == 'controllable': 175 | FinalCaption, BestCaption = run_control(run_type, args, img, lm_model, lm_tokenizer, clip, token_mask, logger) 176 | else: 177 | raise Exception('run_type must be caption or controllable!') 178 | 179 | logger.handlers = [] 180 | 181 | FinalCaptionFormat, BestCaptionFormat = format_output(SamplesNum, FinalCaption, BestCaption) 182 | return FinalCaptionFormat, BestCaptionFormat 183 | 184 | 185 | def RunTypeChange(choice): 186 | if choice == "caption": 187 | return gr.update(visible=False) 188 | elif choice == "controllable": 189 | return gr.update(visible=True) 190 | 191 | 192 | def ControlTypeChange(choice): 193 | if choice == "pos": 194 | return gr.update(visible=False) 195 | elif choice == "sentiment": 196 | return gr.update(visible=True) 197 | 198 | with gr.Blocks() as demo: 199 | 200 | gr.Markdown(""" 201 | # ConZIC 202 | ### Controllable Zero-shot Image Captioning by Sampling-Based Polishing 203 | """) 204 | 205 | with gr.Row(): 206 | with gr.Column(): 207 | RunType = gr.Radio( 208 | 
["caption", "controllable"], value="caption", label="Run Type", info="Select the Run Type" 209 | ) 210 | ControlType = gr.Radio( 211 | ["sentiment", "pos"], value="sentiment", label="Control Type", info="Select the Control Type", 212 | visible=False, interactive=True 213 | ) 214 | SentimentType = gr.Radio( 215 | ["positive", "negative"], value="positive", label="Sentiment Type", info="Select the Sentiment Type", 216 | visible=False, interactive=True 217 | ) 218 | Order = gr.Radio( 219 | ["sequential", "shuffle", "random"], value="shuffle", label="Order", info="Generation order of text" 220 | ) 221 | 222 | RunType.change(fn = RunTypeChange, inputs = RunType, outputs = SentimentType) 223 | RunType.change(fn = RunTypeChange, inputs = RunType, outputs = ControlType) 224 | ControlType.change(fn = ControlTypeChange, inputs = ControlType, outputs = SentimentType) 225 | 226 | with gr.Row(): 227 | Length = gr.Slider( 228 | 5, 15, value=10, label="Sentence Length", info="Choose betwen 5 and 15", step=1 229 | ) 230 | NumIterations = gr.Slider( 231 | 1, 15, value=10, label="Num Iterations", info="predefined iterations for Gibbs Sampling", step=1 232 | ) 233 | with gr.Row(): 234 | SamplesNum = gr.Slider( 235 | 1, 5, value=2, label="Samples Num", step=1 236 | ) 237 | Alpha = gr.Slider( 238 | 0, 1, value=0.02, label="Alpha", info="Weight for fluency", step=0.01 239 | ) 240 | with gr.Row(): 241 | Beta = gr.Slider( 242 | 1, 5, value=2, label="Beta", info="Weight for image-matching degree", step=0.5 243 | ) 244 | Gamma = gr.Slider( 245 | 1, 10, value=5, label="Gamma", info="weight for controllable degree", step=0.5 246 | ) 247 | with gr.Column(): 248 | 249 | Img = gr.Image(label="Upload Picture", type = "pil") 250 | 251 | FinalCaption = gr.Textbox(label="Final Caption", lines=5, placeholder="Final Caption") 252 | BestCaption = gr.Textbox(label="Best Caption", lines=5, placeholder="Best Caption") 253 | with gr.Row(): 254 | gen_button = gr.Button("Submit") 255 | clear_button = gr.Button("Reset") 256 | 257 | gen_button.click( 258 | fn = Demo, 259 | inputs = [ 260 | RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma, Img 261 | ], 262 | outputs = [ 263 | FinalCaption, BestCaption 264 | ] 265 | ) 266 | clear_button.click( 267 | fn = lambda : [gr.Radio.update(value = 'caption'), gr.Radio.update(value = 'pos'), gr.Radio.update(value = 'positive'), 268 | gr.Radio.update(value = 'shuffle'), gr.Slider.update(value = 10), gr.Slider.update(value = 10), 269 | gr.Slider.update(value = 2), gr.Slider.update(value = 0.02), gr.Slider.update(value = 2), 270 | gr.Slider.update(value = 5) 271 | ], 272 | inputs = [ 273 | ], 274 | outputs = [ 275 | RunType, ControlType, SentimentType, Order, Length, NumIterations, SamplesNum, Alpha, Beta, Gamma 276 | ] 277 | ) 278 | if __name__ == "__main__": 279 | 280 | # nltk.download('wordnet') 281 | # nltk.download('punkt') 282 | # nltk.download('averaged_perceptron_tagger') 283 | # nltk.download('sentiwordnet') 284 | 285 | demo.launch() 286 | -------------------------------------------------------------------------------- /clip/build_text_index.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import progressbar 5 | import os 6 | 7 | def parse_config(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32") 10 | parser.add_argument("--text_file_path", type=str) 11 | # save 
configuration 12 | parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
13 | parser.add_argument("--save_index_name", type=str)
14 | parser.add_argument("--save_mapping_dict_name", type=str,
15 | help="a json file that stores a dictionary. The dictionary contains the mapping between mips index and caption text")
16 | # inference configuration
17 | parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
18 | return parser.parse_args()
19 | 20 | def load_batch_text(text_file_path, batch_size):
21 | import json
22 | with open(text_file_path) as f:
23 | item_list = json.load(f)
24 | 25 | text_list = []
26 | for item in item_list:
27 | captions = item["captions"]
28 | for cap in captions:
29 | text_list.append(cap)
30 | print ('Number of text instances is {}'.format(len(text_list)))
31 | 32 | data_num = len(text_list)
33 | batch_num = data_num // batch_size
34 | batch_text_list = []
35 | s_idx, e_idx = 0, batch_size
36 | for p_idx in range(batch_num):
37 | one_batch_text_list = []
38 | for idx in range(s_idx, e_idx):
39 | one_batch_text_list.append(text_list[idx])
40 | batch_text_list.append(one_batch_text_list)
s_idx += batch_size  # advance the batch window so each batch covers a new slice of text_list
e_idx += batch_size
41 | return batch_text_list
42 | 43 | 44 | import argparse
45 | if __name__ == '__main__':
46 | if torch.cuda.is_available():
47 | print ('Cuda is available.')
48 | cuda_available = torch.cuda.is_available()
49 | args = parse_config()
50 | device = torch.device('cuda')
51 | 52 | import os
53 | if os.path.exists(args.save_index_prefix):
54 | pass
55 | else: # recursively construct directory
56 | os.makedirs(args.save_index_prefix, exist_ok=True)
57 | 58 | print ('Loading CLIP...')
59 | from clip import CLIP
60 | model = CLIP(args.clip_name)
61 | if cuda_available:
62 | model = model.cuda(device)
63 | model.eval()
64 | print ('CLIP loaded!')
65 | 66 | print ('Loading text data...')
67 | batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
68 | print ('Text data loaded.')
69 | 70 | res_text_vec_list, res_text_list = [], []
71 | batch_num = len(batch_text_list)
72 | print ('Number of batches is {}'.format(batch_num))
73 | print ('Start inference...')
74 | p = progressbar.ProgressBar(batch_num)
75 | p.start()
76 | with torch.no_grad():
77 | for p_idx in range(batch_num):
78 | p.update(p_idx)
79 | one_text_batch = batch_text_list[p_idx]
80 | one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
81 | one_batch_vec_list = one_batch_vec.unbind(dim=0)
82 | bsz = len(one_batch_vec_list)
83 | for k in range(bsz):
84 | res_text_vec_list.append(one_batch_vec_list[k].numpy())
85 | res_text_list.append(one_text_batch[k])
86 | p.finish()
87 | assert len(res_text_vec_list) == len(res_text_list)
88 | print ('Inference completed!')
89 | 90 | index_text_mapping_dict = {}
91 | for k in range(len(res_text_list)):
92 | index_text_mapping_dict[k] = res_text_list[k]
93 | mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
94 | import json
95 | with open(mapping_list_save_path, 'w') as outfile:
96 | json.dump(index_text_mapping_dict, outfile, indent=4)
97 | print ('Mapping dictionary saved!')
98 | 99 | print ('Start building index...')
100 | index_save_path = args.save_index_prefix + '/' + args.save_index_name
101 | with open(index_save_path, 'w', encoding = 'utf8') as o:
102 | for vec in res_text_vec_list:
103 | one_text = ' '.join([str(num) for num in vec]).strip()
104 | o.writelines(one_text + '\n')
105 | print ('Index completed!')
106 |
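# Example invocation (a sketch; the file paths below are placeholders, and the flags are those defined in parse_config above).
# The --text_file_path JSON is expected to be a list of items, each containing a "captions" field (see load_batch_text);
# note that only data_num // batch_size full batches are encoded, so any remainder is skipped.
#
#   python clip/build_text_index.py \
#       --clip_name openai/clip-vit-base-patch32 \
#       --text_file_path /path/to/captions.json \
#       --save_index_prefix ./text_index \
#       --save_index_name index_matrix.txt \
#       --save_mapping_dict_name mapping_dict.json \
#       --batch_size 128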
-------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | from torch import nn 4 | from PIL import Image 5 | 6 | class CLIP(nn.Module): 7 | def __init__(self, model_name): 8 | super(CLIP, self).__init__() 9 | # model name: e.g. openai/clip-vit-base-patch32 10 | print ('Initializing CLIP model...') 11 | from transformers import CLIPProcessor, CLIPModel 12 | self.model = CLIPModel.from_pretrained(model_name) 13 | self.model.eval() 14 | self.processor = CLIPProcessor.from_pretrained(model_name) 15 | from transformers import CLIPTokenizer 16 | self.tokenizer = CLIPTokenizer.from_pretrained(model_name) 17 | self.cuda_has_been_checked = False 18 | print ('CLIP model initialized.') 19 | 20 | def check_cuda(self): 21 | self.cuda_available = next(self.model.parameters()).is_cuda 22 | self.device = next(self.model.parameters()).get_device() 23 | if self.cuda_available: 24 | print ('Cuda is available.') 25 | print ('Device is {}'.format(self.device)) 26 | else: 27 | print ('Cuda is not available.') 28 | print ('Device is {}'.format(self.device)) 29 | 30 | @torch.no_grad() 31 | def compute_image_representation_from_image_path(self, image_path): 32 | if not self.cuda_has_been_checked: 33 | self.check_cuda() 34 | self.cuda_has_been_checked = True 35 | else: 36 | pass 37 | # image_path: the path of the image 38 | image = Image.open(image_path) 39 | inputs = self.processor(images=image, return_tensors="pt") 40 | pixel_values = inputs['pixel_values'] 41 | if self.cuda_available: 42 | pixel_values = pixel_values.cuda(self.device) 43 | visual_outputs = self.model.vision_model(pixel_values=pixel_values) 44 | image_embeds = visual_outputs[1] 45 | image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] 46 | return image_embeds 47 | 48 | def compute_image_representation_from_image_instance(self, image): 49 | if not self.cuda_has_been_checked: 50 | self.check_cuda() 51 | self.cuda_has_been_checked = True 52 | else: 53 | pass 54 | # image_path: the path of the image 55 | inputs = self.processor(images=image, return_tensors="pt") 56 | pixel_values = inputs['pixel_values'] 57 | if self.cuda_available: 58 | pixel_values = pixel_values.cuda(self.device) 59 | visual_outputs = self.model.vision_model(pixel_values=pixel_values) 60 | image_embeds = visual_outputs[1] 61 | image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] 62 | return image_embeds 63 | 64 | def compute_text_representation(self, text_list): 65 | if not self.cuda_has_been_checked: 66 | self.check_cuda() 67 | self.cuda_has_been_checked = True 68 | else: 69 | pass 70 | # text_list: a list of text 71 | text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", 72 | max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True) 73 | # self.tokenizer.max_len_single_sentence + 2 = 77 74 | input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] 75 | if self.cuda_available: 76 | input_ids = input_ids.cuda(self.device) 77 | attention_mask = attention_mask.cuda(self.device) 78 | text_outputs = self.model.text_model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask 81 | ) 82 | text_embeds = text_outputs[1] 83 | text_embeds = self.model.text_projection(text_embeds) 84 | return text_embeds 85 | 86 | def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds): 87 | ''' 88 | 
image_embeds: batch x embed_dim 89 | text_embeds: batch x len(text_list) x embed_dim 90 | ''' 91 | text_embeds = text_embeds.view(image_embeds.shape[0], -1, text_embeds.shape[-1]) 92 | image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) 93 | text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) 94 | image_embeds = image_embeds.unsqueeze(-1) 95 | logit_scale = self.model.logit_scale.exp() 96 | logits_per_text = torch.matmul(text_embeds, image_embeds) * logit_scale 97 | logits_per_image = logits_per_text.squeeze(-1) 98 | return logits_per_image.softmax(dim=1), logits_per_image/logit_scale # batch x len(text_list) 99 | 100 | def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list): 101 | text_embeds = self.compute_text_representation(text_list) 102 | return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds) 103 | 104 | ### -------------------- functions for building index ---------------------- ### 105 | def compute_batch_index_image_features(self, image_list): 106 | ''' 107 | # list of image instances 108 | ''' 109 | if not self.cuda_has_been_checked: 110 | self.check_cuda() 111 | self.cuda_has_been_checked = True 112 | else: 113 | pass 114 | # image_path: the path of the image 115 | inputs = self.processor(images=image_list, return_tensors="pt") 116 | pixel_values = inputs['pixel_values'] 117 | if self.cuda_available: 118 | pixel_values = pixel_values.cuda(self.device) 119 | visual_outputs = self.model.vision_model(pixel_values=pixel_values) 120 | image_embeds = visual_outputs[1] 121 | image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] 122 | return image_embeds # len(image_list) x embed_dim 123 | 124 | def compute_batch_index_text_representation(self, text_list): 125 | if not self.cuda_has_been_checked: 126 | self.check_cuda() 127 | self.cuda_has_been_checked = True 128 | else: 129 | pass 130 | # text_list: a list of text 131 | #text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt") 132 | text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", 133 | max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True) 134 | input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] 135 | if self.cuda_available: 136 | input_ids = input_ids.cuda(self.device) 137 | attention_mask = attention_mask.cuda(self.device) 138 | text_outputs = self.model.text_model( 139 | input_ids=input_ids, 140 | attention_mask=attention_mask 141 | ) 142 | text_embeds = text_outputs[1] 143 | text_embeds = self.model.text_projection(text_embeds) 144 | return text_embeds 145 | #logit_scale = self.model.logit_scale.exp() 146 | #text_embeds = text_embeds * logit_scale 147 | #return text_embeds 148 | 149 | -------------------------------------------------------------------------------- /clip/clipretrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import torch 4 | import progressbar 5 | import numpy as np 6 | from PIL import Image 7 | 8 | class CLIPIndex: 9 | def __init__(self, index_matrix_path, mapping_dict_path, clip): 10 | ''' 11 | index_path: the pre-trained index 12 | mapping_dict_path: the pre-indexed mapping dictionary 13 | clip: the pre-trained clip model 14 | ''' 15 | print ('Loading index...') 16 | self.index_matrix = self.normalization(self.load_matrix(index_matrix_path)) 17 | print ('Index loaded.') 18 | print (self.index_matrix.shape) 19 | with 
open(mapping_dict_path) as f: 20 | self.mapping_dict = json.load(f) 21 | self.clip = clip 22 | 23 | def load_matrix(self, in_f): 24 | matrix_list = [] 25 | with open(in_f, 'r', encoding = 'utf8') as i: 26 | lines = i.readlines() 27 | for l in lines: 28 | one_vec = [float(num) for num in l.strip('\n').split()] 29 | matrix_list.append(one_vec) 30 | return np.array(matrix_list) 31 | 32 | def normalization(self, matrix): 33 | ''' 34 | matrix: num_instance x num_feature 35 | ''' 36 | return matrix / np.linalg.norm(matrix, axis=1, keepdims=True) 37 | 38 | def get_image_representation(self, image_path): 39 | image_instance = Image.open(image_path) 40 | image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy() 41 | image_vec = self.normalization(image_vec) 42 | return image_vec 43 | 44 | def search_text(self, image_path): 45 | image_vec = self.get_image_representation(image_path) 46 | sort_idx_list = np.matmul(image_vec, self.index_matrix.transpose())[0].argsort()[::-1] 47 | top_idx = sort_idx_list[0] 48 | return self.mapping_dict[str(top_idx)] 49 | 50 | 51 | def parse_config(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--clip_name", type=str) 54 | parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") 55 | parser.add_argument("--test_path", type=str) 56 | # index configuration 57 | parser.add_argument("--index_matrix_path", type=str) 58 | parser.add_argument("--mapping_dict_path", type=str) 59 | # save configuration 60 | parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") 61 | parser.add_argument("--save_name", type=str, help="the name of the saved file") 62 | return parser.parse_args() 63 | 64 | import argparse 65 | if __name__ == '__main__': 66 | if torch.cuda.is_available(): 67 | print ('Cuda is available.') 68 | cuda_available = torch.cuda.is_available() 69 | args = parse_config() 70 | device = torch.device('cuda') 71 | 72 | save_path_prefix = args.save_path_prefix 73 | import os 74 | if os.path.exists(save_path_prefix): 75 | pass 76 | else: # recursively construct directory 77 | os.makedirs(save_path_prefix, exist_ok=True) 78 | # parse save name 79 | save_name = args.save_name 80 | full_save_path = save_path_prefix + '/' + save_name 81 | print ('full save path is {}'.format(full_save_path)) 82 | 83 | print ('Loading CLIP...') 84 | from clip import CLIP 85 | clip = CLIP(args.clip_name) 86 | if cuda_available: 87 | clip = clip.cuda(device) 88 | clip.eval() 89 | print ('CLIP loaded!') 90 | 91 | clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip) 92 | 93 | print ('Loading data...') 94 | import json 95 | with open(args.test_path) as f: 96 | item_list = json.load(f) 97 | print ('Data loaded.') 98 | print ('Number of test instances is {}'.format(len(item_list))) 99 | 100 | result_list = [] 101 | invalid_num = 0 102 | print ('----------------------------------------------------------------') 103 | with torch.no_grad(): 104 | test_num = len(item_list) 105 | #test_num = 10 106 | print ('Number of inference instances is {}'.format(test_num)) 107 | p = progressbar.ProgressBar(test_num) 108 | p.start() 109 | for p_idx in range(test_num): 110 | p.update(p_idx) 111 | one_test_dict = item_list[p_idx] 112 | 113 | one_res_dict = { 114 | 'split':one_test_dict['split'], 115 | 'image_name':one_test_dict['image_name'], 116 | #'file_path':one_test_dict['file_path'], 117 | 'captions':one_test_dict['captions'] 118 | } 119 | 120 | 
image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] 121 | try: 122 | output_text = clipindex.search_text(image_full_path) 123 | one_res_dict['prediction'] = output_text 124 | result_list.append(one_res_dict) 125 | except: 126 | invalid_num += 1 127 | print ('invalid number is {}'.format(invalid_num)) 128 | continue 129 | p.finish() 130 | print ('Inference completed!') 131 | 132 | import json 133 | with open(full_save_path, 'w') as outfile: 134 | json.dump(result_list, outfile, indent=4) 135 | 136 | -------------------------------------------------------------------------------- /compute_n_div.py: -------------------------------------------------------------------------------- 1 | from nltk.tokenize import word_tokenize 2 | from collections import defaultdict 3 | import json 4 | 5 | def calc_diversity(predicts,vocab): 6 | tokens = [0.0, 0.0] 7 | types = [defaultdict(int), defaultdict(int)] 8 | for gg in predicts: 9 | g = word_tokenize(gg.lower()) 10 | # g = gg.rstrip().lower().rstrip(".").split() 11 | for word in g: 12 | if word not in vocab: 13 | vocab.append(word) 14 | for n in range(2): 15 | for idx in range(len(g)-n): 16 | ngram = ' '.join(g[idx:idx+n+1]) 17 | types[n][ngram] = 1 18 | tokens[n] += 1 19 | div1 = len(types[0].keys())/tokens[0] 20 | div2 = len(types[1].keys())/tokens[1] 21 | return [div1, div2], vocab 22 | 23 | def calc_vocab_num(predicts): 24 | vocab = [] 25 | for sentence in predicts: 26 | g = word_tokenize(sentence.lower()) 27 | for word in g: 28 | if word not in vocab: 29 | vocab.append(word) 30 | return vocab 31 | 32 | div1 = 0 33 | div2 = 0 34 | json_path = "diversity_formal.json" 35 | 36 | vocab = [] 37 | with open(json_path,"r") as cur_json_file: 38 | cur_res = json.load(cur_json_file) 39 | for item in cur_res: 40 | div_n, vocab = calc_diversity(item["captions"],vocab) 41 | div1 += div_n[0] 42 | div2 += div_n[1] 43 | div1 /= len(cur_res) 44 | div2 /= len(cur_res) 45 | with open("stop_words.txt",'r') as stop_word_file: 46 | stop_words = stop_word_file.readlines() 47 | stop_words = [word.rstrip() for word in stop_words] 48 | vocab = [word for word in vocab if (word not in stop_words and "unused" not in word)] 49 | print("vocab_len:",len(set(vocab))) 50 | print("div_1:",div1) 51 | print("div_2:",div2) 52 | 53 | -------------------------------------------------------------------------------- /control_gen_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | from utils import get_init_text, update_token_mask 6 | from sentiments_classifer import batch_texts_POS_Sentiments_analysis 7 | from POS_classifier import batch_texts_POS_analysis 8 | 9 | import time 10 | 11 | 12 | def generate_caption_step(out, gen_idx, mask, temperature=None, top_k=0): 13 | """ Generate a word from out[gen_idx] 14 | 15 | args: 16 | - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size 17 | - gen_idx (int): location for which to generate for 18 | - top_k (int): if >0, only sample from the top k most probable words 19 | """ 20 | logits = out[:, gen_idx] 21 | if temperature is not None: 22 | logits = logits / temperature 23 | 24 | probs = F.softmax(logits, dim=-1) 25 | probs *= (mask) 26 | top_k_probs, top_k_ids = probs.topk(top_k, dim=-1) 27 | 28 | return top_k_probs, top_k_ids 29 | 30 | def sentiment_sequential_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 31 | 
max_len=15, top_k=0,temperature=None, alpha=0.7,beta=1, 32 | max_iters=20,batch_size=1, 33 | verbose=True,gamma=5, ctl_signal="positive"): 34 | """ Generate one word at a time, in L->R order """ 35 | seed_len = len(prompt.split())+1 36 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 37 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 38 | clip_score_sequence = [] 39 | best_clip_score_list = [0] * batch_size 40 | best_caption_list = ['None'] * batch_size 41 | inp = torch.tensor(batch).to(image_embeds.device) 42 | gen_texts_list = [] 43 | for iter_num in range(max_iters): 44 | for ii in range(max_len): 45 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 46 | inp[:,seed_len + ii] = tokenizer.mask_token_id 47 | inp_ = inp.clone().detach() 48 | out = model(inp).logits 49 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii,mask=token_mask, top_k=top_k, temperature=temperature) 50 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 51 | idxs_ = (idxs * token_mask[0][idxs]).long() 52 | topk_inp[:,:,ii + seed_len] = idxs_ 53 | repeats = ((idxs_[:,:, None] == topk_inp).float().sum(2) - 1) 54 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 55 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 56 | sentiment_probs_batch, sentiment_scores_batch, pos_tags, wordnet_pos_tags = batch_texts_POS_Sentiments_analysis( 57 | batch_text_list, 1, topk_inp.device, sentiment_ctl=ctl_signal, batch_size_image = batch_size) 58 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 59 | final_score = alpha * probs + beta * clip_score + gamma * sentiment_probs_batch + 0.1 * (1-torch.exp(repeats)) 60 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 61 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 62 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 63 | current_senti_score = sentiment_scores_batch.gather(1, best_clip_id).squeeze(-1) 64 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 65 | senti_score_sequence_batch = current_senti_score.cpu().detach().numpy().tolist() 66 | if verbose and np.mod(iter_num + 1, 1) == 0: 67 | for_print_batch = tokenizer.batch_decode(inp) 68 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 69 | for jj in range(batch_size): 70 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 71 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 72 | best_caption_list[jj] = cur_text_batch[jj] 73 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}, clip score {clip_score_sequence_batch[jj]:.3f}" 74 | f", ctl score {senti_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 75 | gen_texts_list.append(cur_text_batch) 76 | clip_score_sequence.append(clip_score_sequence_batch) 77 | gen_texts_list.append(best_caption_list) 78 | clip_score_sequence.append(best_clip_score_list) 79 | 80 | return gen_texts_list, clip_score_sequence 81 | 82 | def sentiment_shuffle_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 83 | max_len=15, top_k=0,temperature=None, alpha=0.7,beta=1, 84 | max_iters=20,batch_size=1, 85 | verbose=True,gamma=5, ctl_signal="positive"): 86 | """ Generate one word at a time, in random generation order """ 87 | seed_len = len(prompt.split())+1 88 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 89 | image_embeds = 
clip.compute_image_representation_from_image_instance(image_instance) 90 | inp = torch.tensor(batch).to(image_embeds.device) 91 | clip_score_sequence = [] 92 | best_clip_score_list = [0] * batch_size 93 | best_caption_list = ['None'] * batch_size 94 | random_lst = list(range(max_len)) 95 | random.shuffle(random_lst) 96 | logger.info(f"Order_list:{random_lst}") 97 | gen_texts_list = [] 98 | for iter_num in range(max_iters): 99 | for ii in random_lst: 100 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 101 | inp[:,seed_len + ii] = tokenizer.mask_token_id 102 | inp_ = inp.clone().detach() 103 | out = model(inp).logits 104 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii,mask=token_mask, top_k=top_k, temperature=temperature) 105 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 106 | idxs_ = (idxs * token_mask[0][idxs]).long() 107 | topk_inp[:,:,ii + seed_len] = idxs_ 108 | repeats = ((idxs_[:,:, None] == topk_inp).float().sum(2) - 1) 109 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 110 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 111 | sentiment_probs_batch, sentiment_scores_batch, pos_tags, wordnet_pos_tags = batch_texts_POS_Sentiments_analysis( 112 | batch_text_list, 1, topk_inp.device, sentiment_ctl=ctl_signal, batch_size_image = batch_size) 113 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 114 | final_score = alpha * probs + beta * clip_score + gamma * sentiment_probs_batch + 0.1 * (1-torch.exp(repeats)) 115 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 116 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 117 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 118 | current_senti_score = sentiment_scores_batch.gather(1, best_clip_id).squeeze(-1) 119 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 120 | senti_score_sequence_batch = current_senti_score.cpu().detach().numpy().tolist() 121 | if verbose and np.mod(iter_num + 1, 1) == 0: 122 | for_print_batch = tokenizer.batch_decode(inp) 123 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 124 | for jj in range(batch_size): 125 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 126 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 127 | best_caption_list[jj] = cur_text_batch[jj] 128 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}, clip score {clip_score_sequence_batch[jj]:.3f}" 129 | f", ctl score {senti_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 130 | gen_texts_list.append(cur_text_batch) 131 | clip_score_sequence.append(clip_score_sequence_batch) 132 | gen_texts_list.append(best_caption_list) 133 | clip_score_sequence.append(best_clip_score_list) 134 | return gen_texts_list, clip_score_sequence 135 | 136 | def POS_sequential_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 137 | max_len=15, top_k=0,temperature=None, alpha=0.7,beta=1,gamma=0.1, 138 | max_iters=20,batch_size=1,ctl_signal=["DET"], 139 | verbose=True): 140 | """ Generate one word at a time, in L->R order """ 141 | 142 | seed_len = len(prompt.split())+1 143 | templete = False 144 | logger.info(ctl_signal) 145 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 146 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 147 | clip_score_sequence = [] 148 | best_clip_score_list = [0] * batch_size 149 | 
best_ctl_score_list = [0] * batch_size 150 | best_caption_list = ['None'] * batch_size 151 | inp = torch.tensor(batch).to(image_embeds.device) 152 | gen_texts_list= [] 153 | for iter_num in range(max_iters): 154 | for ii in range(max_len): 155 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 156 | inp[:,seed_len + ii] = tokenizer.mask_token_id 157 | inp_ = inp.clone().detach() 158 | out = model(inp).logits 159 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii,mask=token_mask, top_k=top_k, temperature=temperature) 160 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 161 | idxs_ = (idxs * token_mask[0][idxs]).long() 162 | topk_inp[:,:,ii + seed_len] = idxs_ 163 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 164 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 165 | pos_tags, pos_scores = batch_texts_POS_analysis(batch_text_list, ctl_signal, device=idxs_.device) 166 | pos_scores_batch = pos_scores.view([batch_size, -1]) 167 | pos_probs = torch.softmax(pos_scores_batch/0.1, dim=-1).to(idxs_.device) 168 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 169 | final_score = alpha * probs + beta * clip_score + gamma * pos_probs 170 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 171 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 172 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 173 | current_ctl_score = pos_scores_batch.gather(1,best_clip_id).squeeze(-1) 174 | be_clip_id_batch = best_clip_id.reshape(-1).cpu() 175 | pos_tags_sequence_batch = [] 176 | for i in range(batch_size): 177 | pos_tags_sequence_batch.append(pos_tags[be_clip_id_batch[i]+i*top_k]) 178 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 179 | ctl_score_sequence_batch = current_ctl_score.cpu().detach().numpy().tolist() 180 | if verbose and np.mod(iter_num + 1, 1) == 0: 181 | for_print_batch = tokenizer.batch_decode(inp) 182 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 183 | for jj in range(batch_size): 184 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 185 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 186 | best_ctl_score_list[jj] = ctl_score_sequence_batch[jj] 187 | best_caption_list[jj] = cur_text_batch[jj] 188 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}, clip score {clip_score_sequence_batch[jj]:.3f}" 189 | f", ctl score {ctl_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 190 | logger.info(pos_tags_sequence_batch[jj]) 191 | gen_texts_list.append(cur_text_batch) 192 | clip_score_sequence.append(clip_score_sequence_batch) 193 | gen_texts_list.append(best_caption_list) 194 | clip_score_sequence.append(best_clip_score_list) 195 | return gen_texts_list, clip_score_sequence 196 | 197 | def control_generate_caption(img_name, model, clip, tokenizer,image_instance,token_mask,logger, 198 | prompt="", batch_size=10, max_len=25, 199 | top_k=100, temperature=1.0, max_iter=500,alpha=0.7,beta=1,gamma=5, 200 | ctl_type="sentiment", style_type="positive",pos_type=None,generate_order="sequential"): 201 | # controllable funcitions to call 202 | start_time = time.time() 203 | if ctl_type=="sentiment": # sentiment control 204 | if generate_order=="sequential": 205 | generate_texts, clip_scores = sentiment_sequential_generation(img_name, model, clip, tokenizer, image_instance, token_mask, prompt, logger, 206 | batch_size=batch_size, max_len=max_len, 
top_k=top_k, 207 | alpha=alpha,beta=beta,gamma=gamma,temperature=temperature, 208 | max_iters=max_iter, ctl_signal=style_type) 209 | else: 210 | generate_texts, clip_scores = sentiment_shuffle_generation(img_name, model, clip, tokenizer, image_instance, 211 | token_mask, prompt, logger, 212 | batch_size=batch_size, max_len=max_len, 213 | top_k=top_k, 214 | alpha=alpha, beta=beta, gamma=gamma, 215 | temperature=temperature, 216 | max_iters=max_iter, 217 | ctl_signal=style_type) 218 | 219 | else: # POS control 220 | generate_texts, clip_scores = POS_sequential_generation(img_name, model, clip, tokenizer, image_instance, token_mask, prompt, logger, 221 | batch_size=batch_size, max_len=max_len, top_k=top_k, 222 | alpha=alpha,beta=beta,gamma=gamma,temperature=temperature, ctl_signal=pos_type, 223 | max_iters=max_iter) 224 | 225 | logger.info("Finished in %.3fs" % (time.time() - start_time)) 226 | final_caption = generate_texts[-2] 227 | best_caption = generate_texts[-1] 228 | for i in range(batch_size): 229 | logger.info(f"The {i+1}-th image: {img_name[i]}") 230 | logger.info(f"final caption: {final_caption[i]}") 231 | logger.info(f"best caption: {best_caption[i]}") 232 | return generate_texts, clip_scores -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from utils import create_logger,set_seed 2 | import os 3 | import time 4 | import argparse 5 | import json 6 | from PIL import Image 7 | import torch 8 | 9 | from clip.clip import CLIP 10 | from gen_utils import generate_caption 11 | from control_gen_utils import control_generate_caption 12 | from transformers import AutoModelForMaskedLM, AutoTokenizer 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("--seed", type=int, default=42) 19 | parser.add_argument("--batch_size", type=int, default=1, help = "Only supports batch_size=1 currently.") 20 | parser.add_argument("--device", type=str, 21 | default='cuda',choices=['cuda','cpu']) 22 | 23 | ## Generation and Controllable Type 24 | parser.add_argument('--run_type', 25 | default='controllable', 26 | nargs='?', 27 | choices=['caption', 'controllable']) 28 | parser.add_argument('--prompt', 29 | default='Image of a',type=str) 30 | parser.add_argument('--order', 31 | default='shuffle', 32 | nargs='?', 33 | choices=['sequential', 'shuffle', 'span', 'random'], 34 | help="Generation order of text") 35 | parser.add_argument('--control_type', 36 | default='sentiment', 37 | nargs='?', 38 | choices=["sentiment","pos"], 39 | help="which controllable task to conduct") 40 | parser.add_argument('--pos_type', type=list, 41 | default=[['DET'], ['ADJ','NOUN'], ['NOUN'], 42 | ['VERB'], ['VERB'],['ADV'], ['ADP'], 43 | ['DET','NOUN'], ['NOUN'], ['NOUN','.'], 44 | ['.','NOUN'],['.','NOUN']], 45 | help="predefined part-of-speech templete") 46 | parser.add_argument('--sentiment_type', 47 | default="positive", 48 | nargs='?', 49 | choices=["positive", "negative"]) 50 | parser.add_argument('--samples_num', 51 | default=2,type=int) 52 | 53 | ## Hyperparameters 54 | parser.add_argument("--sentence_len", type=int, default=10) 55 | parser.add_argument("--candidate_k", type=int, default=200) 56 | parser.add_argument("--alpha", type=float, default=0.02, help="weight for fluency") 57 | parser.add_argument("--beta", type=float, default=2.0, help="weight for image-matching degree") 58 | parser.add_argument("--gamma", type=float, default=5.0, help="weight 
for controllable degree") 59 | parser.add_argument("--lm_temperature", type=float, default=0.1) 60 | parser.add_argument("--num_iterations", type=int, default=10, help="predefined iterations for Gibbs Sampling") 61 | 62 | ## Models and Paths 63 | parser.add_argument("--lm_model", type=str, default='bert-base-uncased', 64 | help="Path to language model") # bert,roberta 65 | parser.add_argument("--match_model", type=str, default='openai/clip-vit-base-patch32', 66 | help="Path to Image-Text model") # clip,align 67 | parser.add_argument("--caption_img_path", type=str, default='./examples/girl.jpg', 68 | help="file path of the image for captioning") 69 | parser.add_argument("--stop_words_path", type=str, default='stop_words.txt', 70 | help="Path to stop_words.txt") 71 | parser.add_argument("--add_extra_stopwords", type=list, default=[], 72 | help="you can add some extra stop words") 73 | 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | def run_caption(args, image_path, lm_model, lm_tokenizer, clip, token_mask, logger): 79 | 80 | logger.info(f"Processing: {image_path}") 81 | image_instance = Image.open(image_path).convert("RGB") 82 | img_name = [image_path.split("/")[-1]] 83 | for sample_id in range(args.samples_num): 84 | logger.info(f"Sample {sample_id}: ") 85 | gen_texts, clip_scores = generate_caption(img_name,lm_model, clip, lm_tokenizer, image_instance, token_mask, logger, 86 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 87 | top_k=args.candidate_k, temperature=args.lm_temperature, 88 | max_iter=args.num_iterations,alpha=args.alpha,beta=args.beta, 89 | generate_order = args.order) 90 | 91 | def run_control(run_type, args, image_path, lm_model, lm_tokenizer, clip, token_mask, logger): 92 | 93 | logger.info(f"Processing: {image_path}") 94 | image_instance = Image.open(image_path).convert("RGB") 95 | img_name = [image_path.split("/")[-1]] 96 | for sample_id in range(args.samples_num): 97 | logger.info(f"Sample {sample_id}: ") 98 | gen_texts, clip_scores = control_generate_caption(img_name,lm_model, clip, lm_tokenizer, image_instance, token_mask, logger, 99 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 100 | top_k=args.candidate_k, temperature=args.lm_temperature, 101 | max_iter=args.num_iterations, alpha=args.alpha, 102 | beta=args.beta, gamma=args.gamma, 103 | ctl_type = args.control_type, style_type=args.sentiment_type,pos_type=args.pos_type, generate_order=args.order) 104 | 105 | if __name__ == "__main__": 106 | args = get_args() 107 | set_seed(args.seed) 108 | run_type = "caption" if args.run_type=="caption" else args.control_type 109 | if run_type=="sentiment": 110 | run_type = args.sentiment_type 111 | 112 | if os.path.exists("logger")== False: 113 | os.mkdir("logger") 114 | logger = create_logger( 115 | "logger",'demo_{}_{}_len{}_topk{}_alpha{}_beta{}_gamma{}_lmtemp{}_{}.log'.format( 116 | run_type, args.order,args.sentence_len, 117 | args.candidate_k, args.alpha,args.beta,args.gamma,args.lm_temperature, 118 | time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))) 119 | 120 | logger.info(f"Generating order:{args.order}") 121 | logger.info(f"Run type:{run_type}") 122 | logger.info(args) 123 | 124 | # Load pre-trained model (weights) 125 | lm_model = AutoModelForMaskedLM.from_pretrained(args.lm_model) 126 | lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model) 127 | lm_model.eval() 128 | clip = CLIP(args.match_model) 129 | clip.eval() 130 | 131 | lm_model = lm_model.to(args.device) 132 | clip = 
clip.to(args.device) 133 | 134 | ## Remove stop words, token mask 135 | with open(args.stop_words_path,'r',encoding='utf-8') as stop_words_file: 136 | stop_words = stop_words_file.readlines() 137 | stop_words_ = [stop_word.rstrip('\n') for stop_word in stop_words] 138 | stop_words_ += args.add_extra_stopwords 139 | stop_ids = lm_tokenizer.convert_tokens_to_ids(stop_words_) 140 | token_mask = torch.ones((1,lm_tokenizer.vocab_size)) 141 | for stop_id in stop_ids: 142 | token_mask[0,stop_id]=0 143 | token_mask = token_mask.to(args.device) 144 | 145 | img_path = args.caption_img_path 146 | with torch.no_grad(): 147 | if args.run_type == 'caption': 148 | run_caption(args, img_path, lm_model, lm_tokenizer, clip, token_mask, logger) 149 | elif args.run_type == 'controllable': 150 | run_control(run_type, args, img_path, lm_model, lm_tokenizer, clip, token_mask, logger) 151 | else: 152 | raise Exception('run_type must be caption or controllable!') 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /examples/Gosh.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/examples/Gosh.jpeg -------------------------------------------------------------------------------- /examples/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/examples/cat.png -------------------------------------------------------------------------------- /examples/girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/examples/girl.jpg -------------------------------------------------------------------------------- /examples/horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/examples/horse.png -------------------------------------------------------------------------------- /gen_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import random 5 | from utils import get_init_text, update_token_mask 6 | import time 7 | 8 | 9 | 10 | def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True): 11 | """ Generate a word from out[gen_idx] 12 | 13 | args: 14 | - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size 15 | - gen_idx (int): location for which to generate for 16 | - top_k (int): if >0, only sample from the top k most probable words 17 | - sample (Bool): if True, sample from full distribution. 
Overridden by top_k 18 | """ 19 | logits = out[:, gen_idx] 20 | if temperature is not None: 21 | logits = logits / temperature 22 | if top_k > 0: 23 | kth_vals, kth_idx = logits.topk(top_k, dim=-1) 24 | dist = torch.distributions.categorical.Categorical(logits=kth_vals) 25 | idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1) 26 | elif sample: 27 | dist = torch.distributions.categorical.Categorical(logits=logits) 28 | idx = dist.sample().squeeze(-1) 29 | else: 30 | idx = torch.argmax(logits, dim=-1) 31 | return idx.tolist() if return_list else idx 32 | 33 | def generate_caption_step(out, gen_idx, mask, temperature=None, top_k=100): 34 | # out, gen_idx=seed_len + ii, mask=token_mask, top_k=top_k, temperature=temperature 35 | """ Generate a word from out[gen_idx] 36 | args: 37 | - out (torch.Tensor): tensor of logits of size (batch_size, seq_len, vocab_size) 38 | - gen_idx (int): location for which to generate for 39 | - mask (torch.Tensor): (1, vocab_size) 40 | - top_k (int): candidate k 41 | """ 42 | logits = out[:, gen_idx] 43 | if temperature is not None: 44 | logits = logits / temperature 45 | probs = F.softmax(logits, dim=-1) 46 | probs *= (mask) 47 | top_k_probs, top_k_ids = probs.topk(top_k, dim=-1) 48 | 49 | return top_k_probs, top_k_ids 50 | 51 | def sequential_generation(img_name, model, clip, tokenizer, image_instance,token_mask, prompt, logger, 52 | max_len=15, top_k=100,temperature=None, alpha=0.7,beta=1, 53 | max_iters=20,batch_size=1, verbose=True): 54 | """ Generate one word at a time, in L->R order """ 55 | 56 | seed_len = len(prompt.split())+1 57 | batch = get_init_text(tokenizer, prompt, max_len, batch_size) 58 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 59 | clip_score_sequence = [] 60 | best_clip_score_list = [0] * batch_size 61 | best_caption_list = ['None'] * batch_size 62 | inp = torch.tensor(batch).to(image_embeds.device) 63 | gen_texts_list = [] 64 | for iter_num in range(max_iters): 65 | for ii in range(max_len): 66 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 67 | inp[:,seed_len + ii] = tokenizer.mask_token_id 68 | inp_ = inp.clone().detach() 69 | out = model(inp).logits 70 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii, mask=token_mask, top_k=top_k, temperature=temperature) 71 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 72 | idxs_ = (idxs * token_mask[0][idxs]).long() 73 | topk_inp[:,:,ii + seed_len] = idxs_ 74 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 75 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 76 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 77 | final_score = alpha * probs + beta * clip_score 78 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 79 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 80 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 81 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 82 | if verbose and np.mod(iter_num + 1, 1) == 0: 83 | for_print_batch = tokenizer.batch_decode(inp) 84 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 85 | for jj in range(batch_size): 86 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 87 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 88 | best_caption_list[jj] = cur_text_batch[jj] 89 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}," 90 | f"clip 
score {clip_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 91 | gen_texts_list.append(cur_text_batch) 92 | clip_score_sequence.append(clip_score_sequence_batch) 93 | gen_texts_list.append(best_caption_list) 94 | clip_score_sequence.append(best_clip_score_list) 95 | 96 | return gen_texts_list, clip_score_sequence 97 | 98 | def shuffle_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 99 | max_len=15, top_k=0,temperature=None, alpha=0.7,beta=1, 100 | max_iters=20,batch_size=1, 101 | verbose=True): 102 | """ Generate one word at a time, in random generation order """ 103 | seed_len = len(prompt.split())+1 104 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 105 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 106 | inp = torch.tensor(batch).to(image_embeds.device) 107 | clip_score_sequence = [] 108 | best_clip_score_list = [0] * batch_size 109 | best_caption_list = ['None'] * batch_size 110 | random_lst = list(range(max_len)) 111 | random.shuffle(random_lst) 112 | logger.info(f"Order_list:{random_lst}") 113 | gen_texts_list = [] 114 | for iter_num in range(max_iters): 115 | for ii in random_lst: 116 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 117 | inp[:,seed_len + ii] = tokenizer.mask_token_id 118 | inp_ = inp.clone().detach() 119 | out = model(inp).logits 120 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii,mask=token_mask, top_k=top_k, temperature=temperature) 121 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 122 | idxs_ = (idxs * token_mask[0][idxs]).long() 123 | topk_inp[:,:,ii + seed_len] = idxs_ 124 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 125 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 126 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 127 | final_score = alpha * probs + beta * clip_score 128 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 129 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 130 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 131 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 132 | if verbose and np.mod(iter_num + 1, 1) == 0: 133 | for_print_batch = tokenizer.batch_decode(inp) 134 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 135 | for jj in range(batch_size): 136 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 137 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 138 | best_caption_list[jj] = cur_text_batch[jj] 139 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}," 140 | f"clip score {clip_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 141 | gen_texts_list.append(cur_text_batch) 142 | clip_score_sequence.append(clip_score_sequence_batch) 143 | gen_texts_list.append(best_caption_list) 144 | clip_score_sequence.append(best_clip_score_list) 145 | 146 | return gen_texts_list, clip_score_sequence 147 | 148 | def span_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 149 | max_len=15, top_k=0,temperature=None, alpha=0.7,beta=1, 150 | max_iters=20,batch_size=1,verbose=True): 151 | """ Generate multiple words at a time (span generation), in L->R order """ 152 | seed_len = len(prompt.split())+1 153 | span_len = 2 154 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 155 | image_embeds = 
clip.compute_image_representation_from_image_instance(image_instance) 156 | clip_score_sequence = [] 157 | best_clip_score_list = [0] * batch_size 158 | best_caption_list = ['None'] * batch_size 159 | inp = torch.tensor(batch).to(image_embeds.device) 160 | gen_texts_list= [] 161 | for iter_num in range(max_iters): 162 | for span_start in range(0,max_len,span_len): 163 | span_end = min(span_start+span_len,max_len) 164 | inp[:,seed_len + span_start: seed_len + span_end] = tokenizer.mask_token_id 165 | out = model(inp).logits 166 | for ii in range(span_start,span_end): 167 | token_mask = update_token_mask(tokenizer, token_mask, max_len, ii) 168 | inp_ = inp.clone().detach() 169 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + ii, mask=token_mask, top_k=top_k, 170 | temperature=temperature) 171 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 172 | idxs_ = (idxs * token_mask[0][idxs]).long() 173 | topk_inp[:,:,ii + seed_len] = idxs_ 174 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 175 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 176 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 177 | final_score = alpha * probs + beta * clip_score 178 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 179 | inp[:,seed_len + ii] = idxs_.gather(1, best_clip_id).squeeze(-1) 180 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 181 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 182 | if verbose and np.mod(iter_num + 1, 1) == 0: 183 | for_print_batch = tokenizer.batch_decode(inp) 184 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 185 | for jj in range(batch_size): 186 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 187 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 188 | best_caption_list[jj] = cur_text_batch[jj] 189 | logger.info(f"iter {iter_num + 1}, The {jj+1}-th image: {img_name[jj]}," 190 | f"clip score {clip_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 191 | gen_texts_list.append(cur_text_batch) 192 | clip_score_sequence.append(clip_score_sequence_batch) 193 | gen_texts_list.append(best_caption_list) 194 | clip_score_sequence.append(best_clip_score_list) 195 | return gen_texts_list, clip_score_sequence 196 | 197 | def random_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 198 | max_len=15, top_k=0, temperature=None,alpha=0.7,beta=2,max_iters=300,print_every=10,batch_size=1,verbose=True): 199 | """ Generate for one random position at a timestep""" 200 | 201 | seed_len = len(prompt.split())+1 202 | batch = get_init_text(tokenizer, prompt, max_len, batch_size) 203 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 204 | clip_score_sequence = [] 205 | best_clip_score_list = [0] * batch_size 206 | best_caption_list = ['None'] * batch_size 207 | inp = torch.tensor(batch).to(image_embeds.device) 208 | gen_texts_list = [] 209 | for ii in range(max_iters): 210 | kk = np.random.randint(0, max_len) 211 | token_mask = update_token_mask(tokenizer, token_mask, max_len, kk) 212 | inp[:,seed_len + kk] = tokenizer.mask_token_id 213 | inp_ = inp.clone().detach() 214 | out = model(inp).logits 215 | probs, idxs = generate_caption_step(out,gen_idx=seed_len + kk,mask=token_mask, top_k=top_k, temperature=temperature) 216 | topk_inp = inp_.unsqueeze(1).repeat(1,top_k,1) 217 | idxs_ = (idxs * token_mask[0][idxs]).long() 218 | 
topk_inp[:,:,kk + seed_len] = idxs_ 219 | topk_inp_batch = topk_inp.view(-1,topk_inp.shape[-1]) 220 | batch_text_list= tokenizer.batch_decode(topk_inp_batch , skip_special_tokens=True) 221 | clip_score, clip_ref = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) 222 | final_score = alpha * probs + beta * clip_score 223 | best_clip_id = final_score.argmax(dim=1).view(-1,1) 224 | inp[:,seed_len + kk] = idxs_.gather(1, best_clip_id).squeeze(-1) 225 | current_clip_score = clip_ref.gather(1,best_clip_id).squeeze(-1) 226 | clip_score_sequence_batch = current_clip_score.cpu().detach().numpy().tolist() 227 | cur_text_batch= tokenizer.batch_decode(inp,skip_special_tokens=True) 228 | for jj in range(batch_size): 229 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 230 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 231 | best_caption_list[jj] = cur_text_batch[jj] 232 | if verbose and np.mod(ii + 1, print_every) == 0: 233 | for_print_batch = tokenizer.batch_decode(inp) 234 | for jj in range(batch_size): 235 | logger.info(f"iter {ii + 1}, The {jj+1}-th image: {img_name[jj]}," 236 | f"clip score {clip_score_sequence_batch[jj]:.3f}: "+ for_print_batch[jj]) 237 | gen_texts_list.append(cur_text_batch) 238 | clip_score_sequence.append(clip_score_sequence_batch) 239 | gen_texts_list.append(best_caption_list) 240 | clip_score_sequence.append(best_clip_score_list) 241 | 242 | return gen_texts_list, clip_score_sequence 243 | 244 | def parallel_generation(img_name, model, clip, tokenizer,image_instance,token_mask, prompt, logger, 245 | max_len=15, top_k=0, temperature=None, alpha=0.1, beta=1, 246 | max_iters=300,batch_size=1,print_every=1, verbose=True): 247 | """ Generate for all positions at a time step """ 248 | seed_len = len(prompt.split())+1 249 | batch = get_init_text(tokenizer,prompt, max_len, batch_size) 250 | image_embeds = clip.compute_image_representation_from_image_instance(image_instance) 251 | clip_score_sequence = [] 252 | best_clip_score_list = [0] * batch_size 253 | best_caption_list = ['None'] * batch_size 254 | inp = torch.tensor(batch).to(image_embeds.device) 255 | gen_texts_list = [] 256 | for ii in range(max_iters): 257 | inp_ = inp.clone().detach() 258 | out = model(inp).logits 259 | gen_texts = [] 260 | for kk in range(max_len): 261 | probs, idxs = generate_caption_step(out, gen_idx=seed_len + kk,mask=token_mask, top_k=top_k, temperature=temperature) 262 | clip_score_sequence_batch = [] 263 | for jj in range(batch_size): 264 | topk_inp = inp_.unsqueeze(0).repeat(top_k,1,1) 265 | topk_inp[:, jj, kk + seed_len] = (idxs[jj] * token_mask[0][idxs[jj]]).long() 266 | batch_text_list = tokenizer.batch_decode(topk_inp[:,jj,:], skip_special_tokens=True) 267 | single_image_embeds = image_embeds[jj].unsqueeze(0) 268 | clip_score,clip_ref = clip.compute_image_text_similarity_via_raw_text(single_image_embeds, batch_text_list) 269 | final_score = alpha * probs[jj,:] + beta * clip_score 270 | best_clip_id = final_score.argmax() 271 | inp[jj][seed_len + kk] = idxs[jj][best_clip_id] 272 | current_clip_score = clip_ref[0][best_clip_id] 273 | clip_score_sequence_batch.append(current_clip_score.cpu().item()) 274 | if verbose and np.mod(ii, 1) == 0: 275 | for jj in range(batch_size): 276 | for_print = tokenizer.decode(inp[jj]) 277 | cur_text = tokenizer.decode(inp[jj],skip_special_tokens=True) 278 | if best_clip_score_list[jj] < clip_score_sequence_batch[jj]: 279 | best_clip_score_list[jj] = clip_score_sequence_batch[jj] 280 | best_caption_list[jj] = 
cur_text 281 | gen_texts.append(cur_text) 282 | logger.info(f"iter {ii + 1}, The {jj+1}-th image: {img_name[jj]}, clip score {clip_score_sequence_batch[jj]:.3f}: "+ for_print) 283 | gen_texts_list.append(gen_texts) 284 | clip_score_sequence.append(clip_score_sequence_batch) 285 | gen_texts_list.append(best_caption_list) 286 | clip_score_sequence.append(best_clip_score_list) 287 | return gen_texts_list, clip_score_sequence 288 | 289 | def generate_caption(img_name, model, clip, tokenizer,image_instance,token_mask,logger, 290 | prompt="", batch_size=1, max_len=15, 291 | top_k=100, temperature=1.0, max_iter=500,alpha=0.7,beta=1, 292 | generate_order="sequential"): 293 | # main generation functions to call 294 | start_time = time.time() 295 | 296 | if generate_order=="sequential": 297 | generate_texts, clip_scores = sequential_generation(img_name, model, clip, tokenizer, image_instance, token_mask, prompt, logger, 298 | batch_size=batch_size, max_len=max_len, top_k=top_k, 299 | alpha=alpha,beta=beta,temperature=temperature, 300 | max_iters=max_iter) 301 | 302 | elif generate_order=="shuffle": 303 | # max_iter = 15 304 | generate_texts, clip_scores = shuffle_generation(img_name, model, clip, tokenizer,image_instance,token_mask,prompt, logger, 305 | batch_size=batch_size, max_len=max_len, top_k=top_k, 306 | alpha=alpha,beta=beta,temperature=temperature,max_iters=max_iter) 307 | 308 | elif generate_order=="random": 309 | max_iter *= max_len 310 | print_every = max_len 311 | generate_texts, clip_scores = random_generation(img_name, model, clip, tokenizer,image_instance,token_mask,prompt,logger, 312 | max_len=max_len, top_k=top_k,alpha=alpha,beta=beta,print_every=print_every, 313 | temperature=temperature, batch_size=batch_size, max_iters=max_iter,verbose=True) 314 | 315 | elif generate_order=="span": 316 | max_iter = max_iter 317 | generate_texts, clip_scores = span_generation(img_name, model, clip, tokenizer, image_instance, token_mask, prompt, logger, 318 | batch_size=batch_size, max_len=max_len, top_k=top_k, 319 | alpha=alpha,beta=beta,temperature=temperature, max_iters=max_iter) 320 | 321 | elif generate_order=="parallel": 322 | generate_texts, clip_scores = parallel_generation(img_name, model, clip, tokenizer,image_instance,token_mask,prompt, logger, 323 | max_len=max_len, temperature=temperature, top_k=top_k, alpha=alpha,beta=beta, 324 | max_iters=max_iter, batch_size=batch_size, verbose=True) 325 | 326 | logger.info("Finished in %.3fs" % (time.time() - start_time)) 327 | final_caption = generate_texts[-2] 328 | best_caption = generate_texts[-1] 329 | for i in range(batch_size): 330 | logger.info(f"The {i+1}-th image: {img_name[i]}") 331 | logger.info(f"final caption: {final_caption[i]}") 332 | logger.info(f"best caption: {best_caption[i]}") 333 | return generate_texts, clip_scores -------------------------------------------------------------------------------- /paper_images/diversecaptioning.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/diversecaptioning.jpg -------------------------------------------------------------------------------- /paper_images/framework.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/framework.gif -------------------------------------------------------------------------------- 
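Note: the following is a minimal usage sketch for the generate_caption entry point of gen_utils.py dumped above; it is not part of the repository. It mirrors the setup in demo.py with the repo's default checkpoints (bert-base-uncased and openai/clip-vit-base-patch32), but it assumes a CUDA device, a hypothetical log file name, and an all-ones token mask instead of the stop-word filtering that demo.py and run.py perform.

import os
import torch
from PIL import Image
from transformers import AutoModelForMaskedLM, AutoTokenizer

from clip.clip import CLIP
from gen_utils import generate_caption
from utils import create_logger

device = "cuda"  # assumption: a CUDA device is available
lm_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
lm_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").eval().to(device)
clip = CLIP("openai/clip-vit-base-patch32").eval().to(device)

# demo.py zeroes out stop-word ids in this mask; an all-ones mask simply keeps the full vocabulary.
token_mask = torch.ones((1, lm_tokenizer.vocab_size), device=device)

os.makedirs("logger", exist_ok=True)
logger = create_logger("logger", "usage_sketch.log")  # hypothetical log file name
image = Image.open("./examples/girl.jpg").convert("RGB")

with torch.no_grad():
    gen_texts, clip_scores = generate_caption(
        ["girl.jpg"], lm_model, clip, lm_tokenizer, image, token_mask, logger,
        prompt="Image of a", batch_size=1, max_len=10,
        top_k=200, temperature=0.1, max_iter=10,
        alpha=0.02, beta=2.0, generate_order="shuffle")

# gen_texts[-2] holds the last-iteration captions, gen_texts[-1] the best-CLIP-score captions.
print(gen_texts[-1][0])
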
/paper_images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/framework.jpg -------------------------------------------------------------------------------- /paper_images/gibbs_bert.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/gibbs_bert.gif -------------------------------------------------------------------------------- /paper_images/gibbs_bert_mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/gibbs_bert_mask.gif -------------------------------------------------------------------------------- /paper_images/lengthcontrol.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/lengthcontrol.jpg -------------------------------------------------------------------------------- /paper_images/moreimagestyles.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/moreimagestyles.jpg -------------------------------------------------------------------------------- /paper_images/poscontrol.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/poscontrol.jpg -------------------------------------------------------------------------------- /paper_images/sentimentcontrol.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/sentimentcontrol.jpg -------------------------------------------------------------------------------- /paper_images/style_examples.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joeyz0z/ConZIC/a30fe62cf225f79f9570a25055cab2ab4180010f/paper_images/style_examples.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorlog 2 | nltk 3 | transformers 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from utils import create_logger,set_seed 2 | import os 3 | import time 4 | import argparse 5 | import json 6 | from PIL import Image 7 | import torch 8 | 9 | from clip.clip import CLIP 10 | from gen_utils import generate_caption 11 | from control_gen_utils import control_generate_caption 12 | from transformers import AutoModelForMaskedLM, AutoTokenizer 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("--seed", type=int, default=42) 19 | parser.add_argument("--batch_size", type=int, default=2, help = "support batch_size>1 currently.") 20 | parser.add_argument("--device", 
type=str, 21 | default='cuda',choices=['cuda','cpu']) 22 | 23 | ## Generation and Controllable Type 24 | parser.add_argument('--run_type', 25 | default='controllable', 26 | nargs='?', 27 | choices=['caption', 'controllable']) 28 | parser.add_argument('--prompt', 29 | default='Image of a',type=str) 30 | parser.add_argument('--order', 31 | default='shuffle', 32 | nargs='?', 33 | choices=['sequential', 'shuffle', 'span', 'random'], 34 | help="Generation order of text") 35 | parser.add_argument('--control_type', 36 | default='sentiment', 37 | nargs='?', 38 | choices=["sentiment","pos"], 39 | help="which controllable task to conduct") 40 | parser.add_argument('--pos_type', type=list, 41 | default=[['DET'], ['ADJ','NOUN'], ['NOUN'], 42 | ['VERB'], ['VERB'],['ADV'], ['ADP'], 43 | ['DET','NOUN'], ['NOUN'], ['NOUN','.'], 44 | ['.','NOUN'],['.','NOUN']], 45 | help="predefined part-of-speech template") 46 | parser.add_argument('--sentiment_type', 47 | default="positive", 48 | nargs='?', 49 | choices=["positive", "negative"]) 50 | parser.add_argument('--samples_num', 51 | default=2,type=int) 52 | 53 | ## Hyperparameters 54 | parser.add_argument("--sentence_len", type=int, default=10) 55 | parser.add_argument("--candidate_k", type=int, default=200) 56 | parser.add_argument("--alpha", type=float, default=0.02, help="weight for fluency") 57 | parser.add_argument("--beta", type=float, default=2.0, help="weight for image-matching degree") 58 | parser.add_argument("--gamma", type=float, default=5.0, help="weight for controllable degree") 59 | parser.add_argument("--lm_temperature", type=float, default=0.1) 60 | parser.add_argument("--num_iterations", type=int, default=10, help="predefined iterations for Gibbs Sampling") 61 | 62 | ## Models and Paths 63 | parser.add_argument("--lm_model", type=str, default='bert-base-uncased', 64 | help="Path to language model") # bert,roberta 65 | parser.add_argument("--match_model", type=str, default='clip-vit-base-patch32', 66 | help="Path to Image-Text model") # clip,align 67 | parser.add_argument("--caption_img_path", type=str, default='./examples/', 68 | help="file path of images for captioning") 69 | parser.add_argument("--stop_words_path", type=str, default='stop_words.txt', 70 | help="Path to stop_words.txt") 71 | parser.add_argument("--add_extra_stopwords", type=list, default=[], 72 | help="you can add some extra stop words") 73 | 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | def run_caption(args, img_name, img_pil_list, lm_model, lm_tokenizer, clip, token_mask, logger, all_results): 79 | 80 | image_instance = img_pil_list 81 | gen_texts, clip_scores = generate_caption(img_name, lm_model, clip, lm_tokenizer, image_instance, token_mask, logger, 82 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 83 | top_k=args.candidate_k, temperature=args.lm_temperature, 84 | max_iter=args.num_iterations,alpha=args.alpha,beta=args.beta, 85 | generate_order = args.order) 86 | for iter_id, gen_text_list in enumerate(gen_texts): 87 | for jj in range(len(gen_text_list)): 88 | image_id = img_name[jj].split(".")[0] 89 | if all_results[iter_id]==None: 90 | all_results[iter_id] = {image_id: gen_text_list[jj]} 91 | else: 92 | all_results[iter_id][image_id] = gen_text_list[jj] 93 | return all_results 94 | 95 | def run_control(run_type, args, img_name, img_pil_list, lm_model, lm_tokenizer, clip, token_mask, logger, all_results): 96 | 97 | image_instance = img_pil_list 98 | gen_texts, clip_scores = control_generate_caption(img_name, lm_model, 
clip, lm_tokenizer, image_instance, token_mask, logger, 99 | prompt=args.prompt, batch_size=args.batch_size, max_len=args.sentence_len, 100 | top_k=args.candidate_k, temperature=args.lm_temperature, 101 | max_iter=args.num_iterations, alpha=args.alpha, 102 | beta=args.beta, gamma=args.gamma, 103 | ctl_type = args.control_type, style_type=args.sentiment_type,pos_type=args.pos_type, generate_order=args.order) 104 | 105 | for iter_id, gen_text_list in enumerate(gen_texts): 106 | for jj in range(len(gen_text_list)): 107 | image_id = img_name[jj].split(".")[0] 108 | if all_results[iter_id]==None: 109 | all_results[iter_id] = {image_id: gen_text_list[jj]} 110 | else: 111 | all_results[iter_id][image_id] = gen_text_list[jj] 112 | return all_results 113 | 114 | if __name__ == "__main__": 115 | args = get_args() 116 | set_seed(args.seed) 117 | run_type = "caption" if args.run_type=="caption" else args.control_type 118 | if run_type=="sentiment": 119 | run_type = args.sentiment_type 120 | 121 | if os.path.exists("logger")== False: 122 | os.mkdir("logger") 123 | logger = create_logger( 124 | "logger",'{}_{}_len{}_topk{}_alpha{}_beta{}_gamma{}_lmtemp{}_{}.log'.format( 125 | run_type, args.order,args.sentence_len, 126 | args.candidate_k, args.alpha,args.beta,args.gamma,args.lm_temperature, 127 | time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))) 128 | 129 | logger.info(f"Generating order:{args.order}") 130 | logger.info(f"Run type:{run_type}") 131 | logger.info(args) 132 | 133 | # Load pre-trained model (weights) 134 | lm_model = AutoModelForMaskedLM.from_pretrained(args.lm_model) 135 | lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model) 136 | lm_model.eval() 137 | clip = CLIP(args.match_model) 138 | clip.eval() 139 | 140 | lm_model = lm_model.to(args.device) 141 | clip = clip.to(args.device) 142 | 143 | ## Remove stop words, token mask 144 | with open(args.stop_words_path,'r',encoding='utf-8') as stop_words_file: 145 | stop_words = stop_words_file.readlines() 146 | stop_words_ = [stop_word.rstrip('\n') for stop_word in stop_words] 147 | stop_words_ += args.add_extra_stopwords 148 | stop_ids = lm_tokenizer.convert_tokens_to_ids(stop_words_) 149 | token_mask = torch.ones((1,lm_tokenizer.vocab_size)) 150 | for stop_id in stop_ids: 151 | token_mask[0,stop_id]=0 152 | token_mask = token_mask.to(args.device) 153 | 154 | img_dir = args.caption_img_path 155 | 156 | class Imgdata(Dataset): 157 | def __init__(self, dir_path): 158 | self.dir_path = dir_path 159 | self.img_name_list = os.listdir(dir_path) 160 | 161 | def __getitem__(self, idx): 162 | img_name = self.img_name_list[idx] 163 | img_item_path = os.path.join(self.dir_path,img_name) 164 | img = Image.open(img_item_path).convert("RGB") 165 | return img, img_name 166 | def __len__(self): 167 | return len(self.img_name_list) 168 | 169 | def collate_img(batch_data): 170 | img_path_batch_list = list() 171 | name_batch_list = list() 172 | for unit in batch_data: 173 | img_path_batch_list.append(unit[0]) 174 | name_batch_list.append(unit[1]) 175 | return img_path_batch_list,name_batch_list 176 | 177 | img_data = Imgdata(img_dir) 178 | train_loader = DataLoader(img_data, batch_size=args.batch_size, collate_fn=collate_img, shuffle=False, drop_last=True) 179 | 180 | for sample_id in range(args.samples_num): 181 | all_results = [None] * (args.num_iterations+1) 182 | logger.info(f"Sample {sample_id+1}: ") 183 | 184 | for batch_idx, (img_batch_pil_list, name_batch_list) in enumerate(train_loader): 185 | logger.info(f"The {batch_idx+1}-th batch:") 186 | 
with torch.no_grad(): 187 | if args.run_type == 'caption': 188 | all_results = run_caption(args, name_batch_list, img_batch_pil_list, lm_model, lm_tokenizer, clip, token_mask, logger, all_results) 189 | elif args.run_type == 'controllable': 190 | all_results = run_control(run_type, args, name_batch_list, img_batch_pil_list,lm_model, lm_tokenizer, clip, token_mask, logger, all_results) 191 | else: 192 | raise Exception('run_type must be caption or controllable!') 193 | 194 | if args.run_type == 'caption': 195 | # save results 196 | save_dir = "results/caption_%s_len%d_topk%d_alpha%.3f_beta%.3f_gamma%.3f_lmTemp%.3f/sample_%d" % ( 197 | args.order,args.sentence_len, args.candidate_k, args.alpha, args.beta,args.gamma,args.lm_temperature,sample_id) 198 | if os.path.exists(save_dir) == False: 199 | os.makedirs(save_dir) 200 | for iter_id in range(len(all_results)): 201 | if iter_id!=len(all_results)-1: 202 | cur_json_file = os.path.join(save_dir,f"iter_{iter_id}.json") 203 | with open(cur_json_file,'w') as _json: 204 | json.dump(all_results[iter_id], _json) 205 | else: 206 | cur_json_file = os.path.join(save_dir,f"best_clipscore.json") 207 | with open(cur_json_file,'w') as _json: 208 | json.dump(all_results[iter_id], _json) 209 | elif args.run_type == 'controllable': 210 | save_dir = "results/%s_%s_len%d_topk%d_alpha%.3f_beta%.3f_gamma%.3f_lmTemp%.3f/sample_%d" % ( 211 | run_type,args.order,args.sentence_len, args.candidate_k, args.alpha, args.beta,args.gamma,args.lm_temperature, sample_id) 212 | if os.path.exists(save_dir) == False: 213 | os.makedirs(save_dir) 214 | for iter_id in range(len(all_results)): 215 | if iter_id!=len(all_results)-1: 216 | cur_json_file = os.path.join(save_dir,f"iter_{iter_id}.json") 217 | with open(cur_json_file,'w') as _json: 218 | json.dump(all_results[iter_id], _json) 219 | else: 220 | cur_json_file = os.path.join(save_dir,f"best_clipscore.json") 221 | with open(cur_json_file,'w') as _json: 222 | json.dump(all_results[iter_id], _json) 223 | 224 | 225 | -------------------------------------------------------------------------------- /sentiments_classifer.py: -------------------------------------------------------------------------------- 1 | from nltk.tokenize import word_tokenize 2 | from nltk import pos_tag 3 | from nltk.corpus import sentiwordnet 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | 9 | def text_POS_Sentiments_analysis(text,sentiment_ctl=None): 10 | """ 11 | id: 0,1,2,3,4 12 | pos:none,n,v,a,r 13 | """ 14 | words = word_tokenize(text) 15 | 16 | word_tag = pos_tag(words) 17 | res_tag = [tag[1] for tag in word_tag] 18 | tag_map = {'NN': 'n', 'NNP': 'n', 'NNPS': 'n', 'NNS': 'n', 'UH': 'n', \ 19 | 'VB': 'v', 'VBD': 'v', 'VBG': 'v', 'VBN': 'v', 'VBP': 'v', 'VBZ': 'v', \ 20 | 'JJ': 'a', 'JJR': 'a', 'JJS': 'a', \ 21 | 'RB': 'r', 'RBR': 'r', 'RBS': 'r', 'RP': 'r', 'WRB': 'r'} 22 | 23 | word_tag = [(t[0], tag_map[t[1]]) if t[1] in tag_map else (t[0], '') for t in word_tag] 24 | 25 | wordnet_tag = [tag[1] for tag in word_tag] 26 | sentiment_synsets = [list(sentiwordnet.senti_synsets(t[0], t[1])) for t in word_tag] 27 | 28 | if sentiment_ctl is None: 29 | return 0, res_tag, wordnet_tag 30 | score = sum(sum([x.pos_score() - x.neg_score() for x in s]) / len(s) for s in sentiment_synsets if len(s) != 0) 31 | if sentiment_ctl=="negative": 32 | score = -score 33 | return score, res_tag, wordnet_tag 34 | 35 | def batch_texts_POS_Sentiments_analysis(batch_texts, temperature,device,sentiment_ctl=None, batch_size_image=1): 36 | batch_size = len(batch_texts) 37 | 
senti_scores = torch.zeros(batch_size) 38 | pos_tags = [] 39 | wordnet_pos_tags = [] 40 | for b_id in range(batch_size): 41 | text = batch_texts[b_id] 42 | score, cur_tag, cur_word_tag = text_POS_Sentiments_analysis(text,sentiment_ctl=sentiment_ctl) 43 | senti_scores[b_id] = score 44 | pos_tags.append(cur_tag) 45 | wordnet_pos_tags.append(cur_word_tag) 46 | senti_scores_batch = senti_scores.view(batch_size_image, -1).to(device) 47 | senti_probs_batch = F.softmax(senti_scores_batch / temperature,dim=1).to(device) 48 | return senti_probs_batch, senti_scores_batch, pos_tags, wordnet_pos_tags 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /stop_words.txt: -------------------------------------------------------------------------------- 1 | ... 2 | [unused0] 3 | [unused1] 4 | [unused2] 5 | [unused3] 6 | [unused4] 7 | [unused5] 8 | [unused6] 9 | [unused7] 10 | [unused8] 11 | [unused9] 12 | [unused10] 13 | [unused11] 14 | [unused12] 15 | [unused13] 16 | [unused14] 17 | [unused15] 18 | [unused16] 19 | [unused17] 20 | [unused18] 21 | [unused19] 22 | [unused20] 23 | [unused21] 24 | [unused22] 25 | [unused23] 26 | [unused24] 27 | [unused25] 28 | [unused26] 29 | [unused27] 30 | [unused28] 31 | [unused29] 32 | [unused30] 33 | [unused31] 34 | [unused32] 35 | [unused33] 36 | [unused34] 37 | [unused35] 38 | [unused36] 39 | [unused37] 40 | [unused38] 41 | [unused39] 42 | [unused40] 43 | [unused41] 44 | [unused42] 45 | [unused43] 46 | [unused44] 47 | [unused45] 48 | [unused46] 49 | [unused47] 50 | [unused48] 51 | [unused49] 52 | [unused50] 53 | [unused51] 54 | [unused52] 55 | [unused53] 56 | [unused54] 57 | [unused55] 58 | [unused56] 59 | [unused57] 60 | [unused58] 61 | [unused59] 62 | [unused60] 63 | [unused61] 64 | [unused62] 65 | [unused63] 66 | [unused64] 67 | [unused65] 68 | [unused66] 69 | [unused67] 70 | [unused68] 71 | [unused69] 72 | [unused70] 73 | [unused71] 74 | [unused72] 75 | [unused73] 76 | [unused74] 77 | [unused75] 78 | [unused76] 79 | [unused77] 80 | [unused78] 81 | [unused79] 82 | [unused80] 83 | [unused81] 84 | [unused82] 85 | [unused83] 86 | [unused84] 87 | [unused85] 88 | [unused86] 89 | [unused87] 90 | [unused88] 91 | [unused89] 92 | [unused90] 93 | [unused91] 94 | [unused92] 95 | [unused93] 96 | [unused94] 97 | [unused95] 98 | [unused96] 99 | [unused97] 100 | [unused98] 101 | [unused99] 102 | [unused100] 103 | [unused101] 104 | [unused102] 105 | [unused103] 106 | [unused104] 107 | [unused105] 108 | [unused106] 109 | [unused107] 110 | [unused108] 111 | [unused109] 112 | [unused110] 113 | [unused111] 114 | [unused112] 115 | [unused113] 116 | [unused114] 117 | [unused115] 118 | [unused116] 119 | [unused117] 120 | [unused118] 121 | [unused119] 122 | [unused120] 123 | [unused121] 124 | [unused122] 125 | [unused123] 126 | [unused124] 127 | [unused125] 128 | [unused126] 129 | [unused127] 130 | [unused128] 131 | [unused129] 132 | [unused130] 133 | [unused131] 134 | [unused132] 135 | [unused133] 136 | [unused134] 137 | [unused135] 138 | [unused136] 139 | [unused137] 140 | [unused138] 141 | [unused139] 142 | [unused140] 143 | [unused141] 144 | [unused142] 145 | [unused143] 146 | [unused144] 147 | [unused145] 148 | [unused146] 149 | [unused147] 150 | [unused148] 151 | [unused149] 152 | [unused150] 153 | [unused151] 154 | [unused152] 155 | [unused153] 156 | [unused154] 157 | [unused155] 158 | [unused156] 159 | [unused157] 160 | [unused158] 161 | [unused159] 162 | [unused160] 163 | [unused161] 164 | [unused162] 165 | [unused163] 166 | 
[unused164] 167 | [unused165] 168 | [unused166] 169 | [unused167] 170 | [unused168] 171 | [unused169] 172 | [unused170] 173 | [unused171] 174 | [unused172] 175 | [unused173] 176 | [unused174] 177 | [unused175] 178 | [unused176] 179 | [unused177] 180 | [unused178] 181 | [unused179] 182 | [unused180] 183 | [unused181] 184 | [unused182] 185 | [unused183] 186 | [unused184] 187 | [unused185] 188 | [unused186] 189 | [unused187] 190 | [unused188] 191 | [unused189] 192 | [unused190] 193 | [unused191] 194 | [unused192] 195 | [unused193] 196 | [unused194] 197 | [unused195] 198 | [unused196] 199 | [unused197] 200 | [unused198] 201 | [unused199] 202 | [unused200] 203 | [unused201] 204 | [unused202] 205 | [unused203] 206 | [unused204] 207 | [unused205] 208 | [unused206] 209 | [unused207] 210 | [unused208] 211 | [unused209] 212 | [unused210] 213 | [unused211] 214 | [unused212] 215 | [unused213] 216 | [unused214] 217 | [unused215] 218 | [unused216] 219 | [unused217] 220 | [unused218] 221 | [unused219] 222 | [unused220] 223 | [unused221] 224 | [unused222] 225 | [unused223] 226 | [unused224] 227 | [unused225] 228 | [unused226] 229 | [unused227] 230 | [unused228] 231 | [unused229] 232 | [unused230] 233 | [unused231] 234 | [unused232] 235 | [unused233] 236 | [unused234] 237 | [unused235] 238 | [unused236] 239 | [unused237] 240 | [unused238] 241 | [unused239] 242 | [unused240] 243 | [unused241] 244 | [unused242] 245 | [unused243] 246 | [unused244] 247 | [unused245] 248 | [unused246] 249 | [unused247] 250 | [unused248] 251 | [unused249] 252 | [unused250] 253 | [unused251] 254 | [unused252] 255 | [unused253] 256 | [unused254] 257 | [unused255] 258 | [unused256] 259 | [unused257] 260 | [unused258] 261 | [unused259] 262 | [unused260] 263 | [unused261] 264 | [unused262] 265 | [unused263] 266 | [unused264] 267 | [unused265] 268 | [unused266] 269 | [unused267] 270 | [unused268] 271 | [unused269] 272 | [unused270] 273 | [unused271] 274 | [unused272] 275 | [unused273] 276 | [unused274] 277 | [unused275] 278 | [unused276] 279 | [unused277] 280 | [unused278] 281 | [unused279] 282 | [unused280] 283 | [unused281] 284 | [unused282] 285 | [unused283] 286 | [unused284] 287 | [unused285] 288 | [unused286] 289 | [unused287] 290 | [unused288] 291 | [unused289] 292 | [unused290] 293 | [unused291] 294 | [unused292] 295 | [unused293] 296 | [unused294] 297 | [unused295] 298 | [unused296] 299 | [unused297] 300 | [unused298] 301 | [unused299] 302 | [unused300] 303 | [unused301] 304 | [unused302] 305 | [unused303] 306 | [unused304] 307 | [unused305] 308 | [unused306] 309 | [unused307] 310 | [unused308] 311 | [unused309] 312 | [unused310] 313 | [unused311] 314 | [unused312] 315 | [unused313] 316 | [unused314] 317 | [unused315] 318 | [unused316] 319 | [unused317] 320 | [unused318] 321 | [unused319] 322 | [unused320] 323 | [unused321] 324 | [unused322] 325 | [unused323] 326 | [unused324] 327 | [unused325] 328 | [unused326] 329 | [unused327] 330 | [unused328] 331 | [unused329] 332 | [unused330] 333 | [unused331] 334 | [unused332] 335 | [unused333] 336 | [unused334] 337 | [unused335] 338 | [unused336] 339 | [unused337] 340 | [unused338] 341 | [unused339] 342 | [unused340] 343 | [unused341] 344 | [unused342] 345 | [unused343] 346 | [unused344] 347 | [unused345] 348 | [unused346] 349 | [unused347] 350 | [unused348] 351 | [unused349] 352 | [unused350] 353 | [unused351] 354 | [unused352] 355 | [unused353] 356 | [unused354] 357 | [unused355] 358 | [unused356] 359 | [unused357] 360 | [unused358] 361 | [unused359] 362 | [unused360] 363 | 
[unused361] 364 | [unused362] 365 | [unused363] 366 | [unused364] 367 | [unused365] 368 | [unused366] 369 | [unused367] 370 | [unused368] 371 | [unused369] 372 | [unused370] 373 | [unused371] 374 | [unused372] 375 | [unused373] 376 | [unused374] 377 | [unused375] 378 | [unused376] 379 | [unused377] 380 | [unused378] 381 | [unused379] 382 | [unused380] 383 | [unused381] 384 | [unused382] 385 | [unused383] 386 | [unused384] 387 | [unused385] 388 | [unused386] 389 | [unused387] 390 | [unused388] 391 | [unused389] 392 | [unused390] 393 | [unused391] 394 | [unused392] 395 | [unused393] 396 | [unused394] 397 | [unused395] 398 | [unused396] 399 | [unused397] 400 | [unused398] 401 | [unused399] 402 | [unused400] 403 | [unused401] 404 | [unused402] 405 | [unused403] 406 | [unused404] 407 | [unused405] 408 | [unused406] 409 | [unused407] 410 | [unused408] 411 | [unused409] 412 | [unused410] 413 | [unused411] 414 | [unused412] 415 | [unused413] 416 | [unused414] 417 | [unused415] 418 | [unused416] 419 | [unused417] 420 | [unused418] 421 | [unused419] 422 | [unused420] 423 | [unused421] 424 | [unused422] 425 | [unused423] 426 | [unused424] 427 | [unused425] 428 | [unused426] 429 | [unused427] 430 | [unused428] 431 | [unused429] 432 | [unused430] 433 | [unused431] 434 | [unused432] 435 | [unused433] 436 | [unused434] 437 | [unused435] 438 | [unused436] 439 | [unused437] 440 | [unused438] 441 | [unused439] 442 | [unused440] 443 | [unused441] 444 | [unused442] 445 | [unused443] 446 | [unused444] 447 | [unused445] 448 | [unused446] 449 | [unused447] 450 | [unused448] 451 | [unused449] 452 | [unused450] 453 | [unused451] 454 | [unused452] 455 | [unused453] 456 | [unused454] 457 | [unused455] 458 | [unused456] 459 | [unused457] 460 | [unused458] 461 | [unused459] 462 | [unused460] 463 | [unused461] 464 | [unused462] 465 | [unused463] 466 | [unused464] 467 | [unused465] 468 | [unused466] 469 | [unused467] 470 | [unused468] 471 | [unused469] 472 | [unused470] 473 | [unused471] 474 | [unused472] 475 | [unused473] 476 | [unused474] 477 | [unused475] 478 | [unused476] 479 | [unused477] 480 | [unused478] 481 | [unused479] 482 | [unused480] 483 | [unused481] 484 | [unused482] 485 | [unused483] 486 | [unused484] 487 | [unused485] 488 | [unused486] 489 | [unused487] 490 | [unused488] 491 | [unused489] 492 | [unused490] 493 | [unused491] 494 | [unused492] 495 | [unused493] 496 | [unused494] 497 | [unused495] 498 | [unused496] 499 | [unused497] 500 | [unused498] 501 | [unused499] 502 | [unused500] 503 | [unused501] 504 | [unused502] 505 | [unused503] 506 | [unused504] 507 | [unused505] 508 | [unused506] 509 | [unused507] 510 | [unused508] 511 | [unused509] 512 | [unused510] 513 | [unused511] 514 | [unused512] 515 | [unused513] 516 | [unused514] 517 | [unused515] 518 | [unused516] 519 | [unused517] 520 | [unused518] 521 | [unused519] 522 | [unused520] 523 | [unused521] 524 | [unused522] 525 | [unused523] 526 | [unused524] 527 | [unused525] 528 | [unused526] 529 | [unused527] 530 | [unused528] 531 | [unused529] 532 | [unused530] 533 | [unused531] 534 | [unused532] 535 | [unused533] 536 | [unused534] 537 | [unused535] 538 | [unused536] 539 | [unused537] 540 | [unused538] 541 | [unused539] 542 | [unused540] 543 | [unused541] 544 | [unused542] 545 | [unused543] 546 | [unused544] 547 | [unused545] 548 | [unused546] 549 | [unused547] 550 | [unused548] 551 | [unused549] 552 | [unused550] 553 | [unused551] 554 | [unused552] 555 | [unused553] 556 | [unused554] 557 | [unused555] 558 | [unused556] 559 | [unused557] 560 | 
[unused558] 561 | [unused559] 562 | [unused560] 563 | [unused561] 564 | [unused562] 565 | [unused563] 566 | [unused564] 567 | [unused565] 568 | [unused566] 569 | [unused567] 570 | [unused568] 571 | [unused569] 572 | [unused570] 573 | [unused571] 574 | [unused572] 575 | [unused573] 576 | [unused574] 577 | [unused575] 578 | [unused576] 579 | [unused577] 580 | [unused578] 581 | [unused579] 582 | [unused580] 583 | [unused581] 584 | [unused582] 585 | [unused583] 586 | [unused584] 587 | [unused585] 588 | [unused586] 589 | [unused587] 590 | [unused588] 591 | [unused589] 592 | [unused590] 593 | [unused591] 594 | [unused592] 595 | [unused593] 596 | [unused594] 597 | [unused595] 598 | [unused596] 599 | [unused597] 600 | [unused598] 601 | [unused599] 602 | [unused600] 603 | [unused601] 604 | [unused602] 605 | [unused603] 606 | [unused604] 607 | [unused605] 608 | [unused606] 609 | [unused607] 610 | [unused608] 611 | [unused609] 612 | [unused610] 613 | [unused611] 614 | [unused612] 615 | [unused613] 616 | [unused614] 617 | [unused615] 618 | [unused616] 619 | [unused617] 620 | [unused618] 621 | [unused619] 622 | [unused620] 623 | [unused621] 624 | [unused622] 625 | [unused623] 626 | [unused624] 627 | [unused625] 628 | [unused626] 629 | [unused627] 630 | [unused628] 631 | [unused629] 632 | [unused630] 633 | [unused631] 634 | [unused632] 635 | [unused633] 636 | [unused634] 637 | [unused635] 638 | [unused636] 639 | [unused637] 640 | [unused638] 641 | [unused639] 642 | [unused640] 643 | [unused641] 644 | [unused642] 645 | [unused643] 646 | [unused644] 647 | [unused645] 648 | [unused646] 649 | [unused647] 650 | [unused648] 651 | [unused649] 652 | [unused650] 653 | [unused651] 654 | [unused652] 655 | [unused653] 656 | [unused654] 657 | [unused655] 658 | [unused656] 659 | [unused657] 660 | [unused658] 661 | [unused659] 662 | [unused660] 663 | [unused661] 664 | [unused662] 665 | [unused663] 666 | [unused664] 667 | [unused665] 668 | [unused666] 669 | [unused667] 670 | [unused668] 671 | [unused669] 672 | [unused670] 673 | [unused671] 674 | [unused672] 675 | [unused673] 676 | [unused674] 677 | [unused675] 678 | [unused676] 679 | [unused677] 680 | [unused678] 681 | [unused679] 682 | [unused680] 683 | [unused681] 684 | [unused682] 685 | [unused683] 686 | [unused684] 687 | [unused685] 688 | [unused686] 689 | [unused687] 690 | [unused688] 691 | [unused689] 692 | [unused690] 693 | [unused691] 694 | [unused692] 695 | [unused693] 696 | [unused694] 697 | [unused695] 698 | [unused696] 699 | [unused697] 700 | [unused698] 701 | [unused699] 702 | [unused700] 703 | [unused701] 704 | [unused702] 705 | [unused703] 706 | [unused704] 707 | [unused705] 708 | [unused706] 709 | [unused707] 710 | [unused708] 711 | [unused709] 712 | [unused710] 713 | [unused711] 714 | [unused712] 715 | [unused713] 716 | [unused714] 717 | [unused715] 718 | [unused716] 719 | [unused717] 720 | [unused718] 721 | [unused719] 722 | [unused720] 723 | [unused721] 724 | [unused722] 725 | [unused723] 726 | [unused724] 727 | [unused725] 728 | [unused726] 729 | [unused727] 730 | [unused728] 731 | [unused729] 732 | [unused730] 733 | [unused731] 734 | [unused732] 735 | [unused733] 736 | [unused734] 737 | [unused735] 738 | [unused736] 739 | [unused737] 740 | [unused738] 741 | [unused739] 742 | [unused740] 743 | [unused741] 744 | [unused742] 745 | [unused743] 746 | [unused744] 747 | [unused745] 748 | [unused746] 749 | [unused747] 750 | [unused748] 751 | [unused749] 752 | [unused750] 753 | [unused751] 754 | [unused752] 755 | [unused753] 756 | [unused754] 757 | 
[unused755] 758 | [unused756] 759 | [unused757] 760 | [unused758] 761 | [unused759] 762 | [unused760] 763 | [unused761] 764 | [unused762] 765 | [unused763] 766 | [unused764] 767 | [unused765] 768 | [unused766] 769 | [unused767] 770 | [unused768] 771 | [unused769] 772 | [unused770] 773 | [unused771] 774 | [unused772] 775 | [unused773] 776 | [unused774] 777 | [unused775] 778 | [unused776] 779 | [unused777] 780 | [unused778] 781 | [unused779] 782 | [unused780] 783 | [unused781] 784 | [unused782] 785 | [unused783] 786 | [unused784] 787 | [unused785] 788 | [unused786] 789 | [unused787] 790 | [unused788] 791 | [unused789] 792 | [unused790] 793 | [unused791] 794 | [unused792] 795 | [unused793] 796 | [unused794] 797 | [unused795] 798 | [unused796] 799 | [unused797] 800 | [unused798] 801 | [unused799] 802 | [unused800] 803 | [unused801] 804 | [unused802] 805 | [unused803] 806 | [unused804] 807 | [unused805] 808 | [unused806] 809 | [unused807] 810 | [unused808] 811 | [unused809] 812 | [unused810] 813 | [unused811] 814 | [unused812] 815 | [unused813] 816 | [unused814] 817 | [unused815] 818 | [unused816] 819 | [unused817] 820 | [unused818] 821 | [unused819] 822 | [unused820] 823 | [unused821] 824 | [unused822] 825 | [unused823] 826 | [unused824] 827 | [unused825] 828 | [unused826] 829 | [unused827] 830 | [unused828] 831 | [unused829] 832 | [unused830] 833 | [unused831] 834 | [unused832] 835 | [unused833] 836 | [unused834] 837 | [unused835] 838 | [unused836] 839 | [unused837] 840 | [unused838] 841 | [unused839] 842 | [unused840] 843 | [unused841] 844 | [unused842] 845 | [unused843] 846 | [unused844] 847 | [unused845] 848 | [unused846] 849 | [unused847] 850 | [unused848] 851 | [unused849] 852 | [unused850] 853 | [unused851] 854 | [unused852] 855 | [unused853] 856 | [unused854] 857 | [unused855] 858 | [unused856] 859 | [unused857] 860 | [unused858] 861 | [unused859] 862 | [unused860] 863 | [unused861] 864 | [unused862] 865 | [unused863] 866 | [unused864] 867 | [unused865] 868 | [unused866] 869 | [unused867] 870 | [unused868] 871 | [unused869] 872 | [unused870] 873 | [unused871] 874 | [unused872] 875 | [unused873] 876 | [unused874] 877 | [unused875] 878 | [unused876] 879 | [unused877] 880 | [unused878] 881 | [unused879] 882 | [unused880] 883 | [unused881] 884 | [unused882] 885 | [unused883] 886 | [unused884] 887 | [unused885] 888 | [unused886] 889 | [unused887] 890 | [unused888] 891 | [unused889] 892 | [unused890] 893 | [unused891] 894 | [unused892] 895 | [unused893] 896 | [unused894] 897 | [unused895] 898 | [unused896] 899 | [unused897] 900 | [unused898] 901 | [unused899] 902 | [unused900] 903 | [unused901] 904 | [unused902] 905 | [unused903] 906 | [unused904] 907 | [unused905] 908 | [unused906] 909 | [unused907] 910 | [unused908] 911 | [unused909] 912 | [unused910] 913 | [unused911] 914 | [unused912] 915 | [unused913] 916 | [unused914] 917 | [unused915] 918 | [unused916] 919 | [unused917] 920 | [unused918] 921 | [unused919] 922 | [unused920] 923 | [unused921] 924 | [unused922] 925 | [unused923] 926 | [unused924] 927 | [unused925] 928 | [unused926] 929 | [unused927] 930 | [unused928] 931 | [unused929] 932 | [unused930] 933 | [unused931] 934 | [unused932] 935 | [unused933] 936 | [unused934] 937 | [unused935] 938 | [unused936] 939 | [unused937] 940 | [unused938] 941 | [unused939] 942 | [unused940] 943 | [unused941] 944 | [unused942] 945 | [unused943] 946 | [unused944] 947 | [unused945] 948 | [unused946] 949 | [unused947] 950 | [unused948] 951 | [unused949] 952 | [unused950] 953 | [unused951] 954 | 
[unused952] 955 | [unused953] 956 | [unused954] 957 | [unused955] 958 | [unused956] 959 | [unused957] 960 | [unused958] 961 | [unused959] 962 | [unused960] 963 | [unused961] 964 | [unused962] 965 | [unused963] 966 | [unused964] 967 | [unused965] 968 | [unused966] 969 | [unused967] 970 | [unused968] 971 | [unused969] 972 | [unused970] 973 | [unused971] 974 | [unused972] 975 | [unused973] 976 | [unused974] 977 | [unused975] 978 | [unused976] 979 | [unused977] 980 | [unused978] 981 | [unused979] 982 | [unused980] 983 | [unused981] 984 | [unused982] 985 | [unused983] 986 | [unused984] 987 | [unused985] 988 | [unused986] 989 | [unused987] 990 | [unused988] 991 | [unused989] 992 | [unused990] 993 | [unused991] 994 | [unused992] 995 | [unused993] 996 | ! 997 | " 998 | # 999 | $ 1000 | % 1001 | & 1002 | ' 1003 | ( 1004 | ) 1005 | * 1006 | + 1007 | , 1008 | - 1009 | / 1010 | : 1011 | ; 1012 | < 1013 | = 1014 | > 1015 | ? 1016 | @ 1017 | [ 1018 | \ 1019 | ] 1020 | ^ 1021 | _ 1022 | ` 1023 | { 1024 | | 1025 | } 1026 | ~ 1027 | ¡ 1028 | ¢ 1029 | £ 1030 | ¤ 1031 | ¥ 1032 | ¦ 1033 | § 1034 | ¨ 1035 | © 1036 | ª 1037 | « 1038 | ¬ 1039 | ® 1040 | ° 1041 | ± 1042 | ² 1043 | ³ 1044 | ´ 1045 | µ 1046 | ¶ 1047 | · 1048 | ¹ 1049 | º 1050 | » 1051 | ¼ 1052 | ½ 1053 | ¾ 1054 | ¿ 1055 | × 1056 | ß 1057 | æ 1058 | ð 1059 | ÷ 1060 | ø 1061 | þ 1062 | đ 1063 | ħ 1064 | ı 1065 | ł 1066 | ŋ 1067 | œ 1068 | ƒ 1069 | ɐ 1070 | ɑ 1071 | ɒ 1072 | ɔ 1073 | ɕ 1074 | ə 1075 | ɛ 1076 | ɡ 1077 | ɣ 1078 | ɨ 1079 | ɪ 1080 | ɫ 1081 | ɬ 1082 | ɯ 1083 | ɲ 1084 | ɴ 1085 | ɹ 1086 | ɾ 1087 | ʀ 1088 | ʁ 1089 | ʂ 1090 | ʃ 1091 | ʉ 1092 | ʊ 1093 | ʋ 1094 | ʌ 1095 | ʎ 1096 | ʐ 1097 | ʑ 1098 | ʒ 1099 | ʔ 1100 | ʰ 1101 | ʲ 1102 | ʳ 1103 | ʷ 1104 | ʸ 1105 | ʻ 1106 | ʼ 1107 | ʾ 1108 | ʿ 1109 | ˈ 1110 | ː 1111 | ˡ 1112 | ˢ 1113 | ˣ 1114 | ˤ 1115 | α 1116 | β 1117 | γ 1118 | δ 1119 | ε 1120 | ζ 1121 | η 1122 | θ 1123 | ι 1124 | κ 1125 | λ 1126 | μ 1127 | ν 1128 | ξ 1129 | ο 1130 | π 1131 | ρ 1132 | ς 1133 | σ 1134 | τ 1135 | υ 1136 | φ 1137 | χ 1138 | ψ 1139 | ω 1140 | а 1141 | б 1142 | в 1143 | г 1144 | д 1145 | е 1146 | ж 1147 | з 1148 | и 1149 | к 1150 | л 1151 | м 1152 | н 1153 | о 1154 | п 1155 | р 1156 | с 1157 | т 1158 | у 1159 | ф 1160 | х 1161 | ц 1162 | ч 1163 | ш 1164 | щ 1165 | ъ 1166 | ы 1167 | ь 1168 | э 1169 | ю 1170 | я 1171 | ђ 1172 | є 1173 | і 1174 | ј 1175 | љ 1176 | њ 1177 | ћ 1178 | ӏ 1179 | ա 1180 | բ 1181 | գ 1182 | դ 1183 | ե 1184 | թ 1185 | ի 1186 | լ 1187 | կ 1188 | հ 1189 | մ 1190 | յ 1191 | ն 1192 | ո 1193 | պ 1194 | ս 1195 | վ 1196 | տ 1197 | ր 1198 | ւ 1199 | ք 1200 | ־ 1201 | א 1202 | ב 1203 | ג 1204 | ד 1205 | ה 1206 | ו 1207 | ז 1208 | ח 1209 | ט 1210 | י 1211 | ך 1212 | כ 1213 | ל 1214 | ם 1215 | מ 1216 | ן 1217 | נ 1218 | ס 1219 | ע 1220 | ף 1221 | פ 1222 | ץ 1223 | צ 1224 | ק 1225 | ר 1226 | ש 1227 | ת 1228 | ، 1229 | ء 1230 | ا 1231 | ب 1232 | ة 1233 | ت 1234 | ث 1235 | ج 1236 | ح 1237 | خ 1238 | د 1239 | ذ 1240 | ر 1241 | ز 1242 | س 1243 | ش 1244 | ص 1245 | ض 1246 | ط 1247 | ظ 1248 | ع 1249 | غ 1250 | ـ 1251 | ف 1252 | ق 1253 | ك 1254 | ل 1255 | م 1256 | ن 1257 | ه 1258 | و 1259 | ى 1260 | ي 1261 | ٹ 1262 | پ 1263 | چ 1264 | ک 1265 | گ 1266 | ں 1267 | ھ 1268 | ہ 1269 | ی 1270 | ے 1271 | अ 1272 | आ 1273 | उ 1274 | ए 1275 | क 1276 | ख 1277 | ग 1278 | च 1279 | ज 1280 | ट 1281 | ड 1282 | ण 1283 | त 1284 | थ 1285 | द 1286 | ध 1287 | न 1288 | प 1289 | ब 1290 | भ 1291 | म 1292 | य 1293 | र 1294 | ल 1295 | व 1296 | श 1297 | ष 1298 | स 1299 | ह 1300 | ा 1301 | ि 1302 | ी 1303 | ो 1304 | । 1305 | ॥ 1306 | ং 1307 | অ 
1308 | আ 1309 | ই 1310 | উ 1311 | এ 1312 | ও 1313 | ক 1314 | খ 1315 | গ 1316 | চ 1317 | ছ 1318 | জ 1319 | ট 1320 | ড 1321 | ণ 1322 | ত 1323 | থ 1324 | দ 1325 | ধ 1326 | ন 1327 | প 1328 | ব 1329 | ভ 1330 | ম 1331 | য 1332 | র 1333 | ল 1334 | শ 1335 | ষ 1336 | স 1337 | হ 1338 | া 1339 | ি 1340 | ী 1341 | ে 1342 | க 1343 | ச 1344 | ட 1345 | த 1346 | ந 1347 | ன 1348 | ப 1349 | ம 1350 | ய 1351 | ர 1352 | ல 1353 | ள 1354 | வ 1355 | ா 1356 | ி 1357 | ு 1358 | ே 1359 | ை 1360 | ನ 1361 | ರ 1362 | ಾ 1363 | ක 1364 | ය 1365 | ර 1366 | ල 1367 | ව 1368 | ා 1369 | ก 1370 | ง 1371 | ต 1372 | ท 1373 | น 1374 | พ 1375 | ม 1376 | ย 1377 | ร 1378 | ล 1379 | ว 1380 | ส 1381 | อ 1382 | า 1383 | เ 1384 | ་ 1385 | ། 1386 | ག 1387 | ང 1388 | ད 1389 | ན 1390 | པ 1391 | བ 1392 | མ 1393 | འ 1394 | ར 1395 | ལ 1396 | ས 1397 | မ 1398 | ა 1399 | ბ 1400 | გ 1401 | დ 1402 | ე 1403 | ვ 1404 | თ 1405 | ი 1406 | კ 1407 | ლ 1408 | მ 1409 | ნ 1410 | ო 1411 | რ 1412 | ს 1413 | ტ 1414 | უ 1415 | ᄀ 1416 | ᄂ 1417 | ᄃ 1418 | ᄅ 1419 | ᄆ 1420 | ᄇ 1421 | ᄉ 1422 | ᄊ 1423 | ᄋ 1424 | ᄌ 1425 | ᄎ 1426 | ᄏ 1427 | ᄐ 1428 | ᄑ 1429 | ᄒ 1430 | ᅡ 1431 | ᅢ 1432 | ᅥ 1433 | ᅦ 1434 | ᅧ 1435 | ᅩ 1436 | ᅪ 1437 | ᅭ 1438 | ᅮ 1439 | ᅯ 1440 | ᅲ 1441 | ᅳ 1442 | ᅴ 1443 | ᅵ 1444 | ᆨ 1445 | ᆫ 1446 | ᆯ 1447 | ᆷ 1448 | ᆸ 1449 | ᆼ 1450 | ᴬ 1451 | ᴮ 1452 | ᴰ 1453 | ᴵ 1454 | ᴺ 1455 | ᵀ 1456 | ᵃ 1457 | ᵇ 1458 | ᵈ 1459 | ᵉ 1460 | ᵍ 1461 | ᵏ 1462 | ᵐ 1463 | ᵒ 1464 | ᵖ 1465 | ᵗ 1466 | ᵘ 1467 | ᵢ 1468 | ᵣ 1469 | ᵤ 1470 | ᵥ 1471 | ᶜ 1472 | ᶠ 1473 | ‐ 1474 | ‑ 1475 | ‒ 1476 | – 1477 | — 1478 | ― 1479 | ‖ 1480 | ‘ 1481 | ’ 1482 | ‚ 1483 | “ 1484 | ” 1485 | „ 1486 | † 1487 | ‡ 1488 | • 1489 | … 1490 | ‰ 1491 | ′ 1492 | ″ 1493 | › 1494 | ‿ 1495 | ⁄ 1496 | ⁰ 1497 | ⁱ 1498 | ⁴ 1499 | ⁵ 1500 | ⁶ 1501 | ⁷ 1502 | ⁸ 1503 | ⁹ 1504 | ⁺ 1505 | ⁻ 1506 | ⁿ 1507 | ₀ 1508 | ₁ 1509 | ₂ 1510 | ₃ 1511 | ₄ 1512 | ₅ 1513 | ₆ 1514 | ₇ 1515 | ₈ 1516 | ₉ 1517 | ₊ 1518 | ₍ 1519 | ₎ 1520 | ₐ 1521 | ₑ 1522 | ₒ 1523 | ₓ 1524 | ₕ 1525 | ₖ 1526 | ₗ 1527 | ₘ 1528 | ₙ 1529 | ₚ 1530 | ₛ 1531 | ₜ 1532 | ₤ 1533 | ₩ 1534 | € 1535 | ₱ 1536 | ₹ 1537 | ℓ 1538 | № 1539 | ℝ 1540 | ™ 1541 | ⅓ 1542 | ⅔ 1543 | ← 1544 | ↑ 1545 | → 1546 | ↓ 1547 | ↔ 1548 | ↦ 1549 | ⇄ 1550 | ⇌ 1551 | ⇒ 1552 | ∂ 1553 | ∅ 1554 | ∆ 1555 | ∇ 1556 | ∈ 1557 | − 1558 | ∗ 1559 | ∘ 1560 | √ 1561 | ∞ 1562 | ∧ 1563 | ∨ 1564 | ∩ 1565 | ∪ 1566 | ≈ 1567 | ≡ 1568 | ≤ 1569 | ≥ 1570 | ⊂ 1571 | ⊆ 1572 | ⊕ 1573 | ⊗ 1574 | ⋅ 1575 | ─ 1576 | │ 1577 | ■ 1578 | ▪ 1579 | ● 1580 | ★ 1581 | ☆ 1582 | ☉ 1583 | ♠ 1584 | ♣ 1585 | ♥ 1586 | ♦ 1587 | ♭ 1588 | ♯ 1589 | ⟨ 1590 | ⟩ 1591 | ⱼ 1592 | ⺩ 1593 | ⺼ 1594 | ⽥ 1595 | 、 1596 | 。 1597 | 〈 1598 | 〉 1599 | 《 1600 | 》 1601 | 「 1602 | 」 1603 | 『 1604 | 』 1605 | 〜 1606 | あ 1607 | い 1608 | う 1609 | え 1610 | お 1611 | か 1612 | き 1613 | く 1614 | け 1615 | こ 1616 | さ 1617 | し 1618 | す 1619 | せ 1620 | そ 1621 | た 1622 | ち 1623 | っ 1624 | つ 1625 | て 1626 | と 1627 | な 1628 | に 1629 | ぬ 1630 | ね 1631 | の 1632 | は 1633 | ひ 1634 | ふ 1635 | へ 1636 | ほ 1637 | ま 1638 | み 1639 | む 1640 | め 1641 | も 1642 | や 1643 | ゆ 1644 | よ 1645 | ら 1646 | り 1647 | る 1648 | れ 1649 | ろ 1650 | を 1651 | ん 1652 | ァ 1653 | ア 1654 | ィ 1655 | イ 1656 | ウ 1657 | ェ 1658 | エ 1659 | オ 1660 | カ 1661 | キ 1662 | ク 1663 | ケ 1664 | コ 1665 | サ 1666 | シ 1667 | ス 1668 | セ 1669 | タ 1670 | チ 1671 | ッ 1672 | ツ 1673 | テ 1674 | ト 1675 | ナ 1676 | ニ 1677 | ノ 1678 | ハ 1679 | ヒ 1680 | フ 1681 | ヘ 1682 | ホ 1683 | マ 1684 | ミ 1685 | ム 1686 | メ 1687 | モ 1688 | ャ 1689 | ュ 1690 | ョ 1691 | ラ 1692 | リ 1693 | ル 1694 | レ 1695 | ロ 1696 | ワ 1697 | ン 1698 | ・ 1699 | ー 1700 | 一 1701 | 三 1702 | 上 
1703 | 下 1704 | 不 1705 | 世 1706 | 中 1707 | 主 1708 | 久 1709 | 之 1710 | 也 1711 | 事 1712 | 二 1713 | 五 1714 | 井 1715 | 京 1716 | 人 1717 | 亻 1718 | 仁 1719 | 介 1720 | 代 1721 | 仮 1722 | 伊 1723 | 会 1724 | 佐 1725 | 侍 1726 | 保 1727 | 信 1728 | 健 1729 | 元 1730 | 光 1731 | 八 1732 | 公 1733 | 内 1734 | 出 1735 | 分 1736 | 前 1737 | 劉 1738 | 力 1739 | 加 1740 | 勝 1741 | 北 1742 | 区 1743 | 十 1744 | 千 1745 | 南 1746 | 博 1747 | 原 1748 | 口 1749 | 古 1750 | 史 1751 | 司 1752 | 合 1753 | 吉 1754 | 同 1755 | 名 1756 | 和 1757 | 囗 1758 | 四 1759 | 国 1760 | 國 1761 | 土 1762 | 地 1763 | 坂 1764 | 城 1765 | 堂 1766 | 場 1767 | 士 1768 | 夏 1769 | 外 1770 | 大 1771 | 天 1772 | 太 1773 | 夫 1774 | 奈 1775 | 女 1776 | 子 1777 | 学 1778 | 宀 1779 | 宇 1780 | 安 1781 | 宗 1782 | 定 1783 | 宣 1784 | 宮 1785 | 家 1786 | 宿 1787 | 寺 1788 | 將 1789 | 小 1790 | 尚 1791 | 山 1792 | 岡 1793 | 島 1794 | 崎 1795 | 川 1796 | 州 1797 | 巿 1798 | 帝 1799 | 平 1800 | 年 1801 | 幸 1802 | 广 1803 | 弘 1804 | 張 1805 | 彳 1806 | 後 1807 | 御 1808 | 德 1809 | 心 1810 | 忄 1811 | 志 1812 | 忠 1813 | 愛 1814 | 成 1815 | 我 1816 | 戦 1817 | 戸 1818 | 手 1819 | 扌 1820 | 政 1821 | 文 1822 | 新 1823 | 方 1824 | 日 1825 | 明 1826 | 星 1827 | 春 1828 | 昭 1829 | 智 1830 | 曲 1831 | 書 1832 | 月 1833 | 有 1834 | 朝 1835 | 木 1836 | 本 1837 | 李 1838 | 村 1839 | 東 1840 | 松 1841 | 林 1842 | 森 1843 | 楊 1844 | 樹 1845 | 橋 1846 | 歌 1847 | 止 1848 | 正 1849 | 武 1850 | 比 1851 | 氏 1852 | 民 1853 | 水 1854 | 氵 1855 | 氷 1856 | 永 1857 | 江 1858 | 沢 1859 | 河 1860 | 治 1861 | 法 1862 | 海 1863 | 清 1864 | 漢 1865 | 瀬 1866 | 火 1867 | 版 1868 | 犬 1869 | 王 1870 | 生 1871 | 田 1872 | 男 1873 | 疒 1874 | 発 1875 | 白 1876 | 的 1877 | 皇 1878 | 目 1879 | 相 1880 | 省 1881 | 真 1882 | 石 1883 | 示 1884 | 社 1885 | 神 1886 | 福 1887 | 禾 1888 | 秀 1889 | 秋 1890 | 空 1891 | 立 1892 | 章 1893 | 竹 1894 | 糹 1895 | 美 1896 | 義 1897 | 耳 1898 | 良 1899 | 艹 1900 | 花 1901 | 英 1902 | 華 1903 | 葉 1904 | 藤 1905 | 行 1906 | 街 1907 | 西 1908 | 見 1909 | 訁 1910 | 語 1911 | 谷 1912 | 貝 1913 | 貴 1914 | 車 1915 | 軍 1916 | 辶 1917 | 道 1918 | 郎 1919 | 郡 1920 | 部 1921 | 都 1922 | 里 1923 | 野 1924 | 金 1925 | 鈴 1926 | 镇 1927 | 長 1928 | 門 1929 | 間 1930 | 阝 1931 | 阿 1932 | 陳 1933 | 陽 1934 | 雄 1935 | 青 1936 | 面 1937 | 風 1938 | 食 1939 | 香 1940 | 馬 1941 | 高 1942 | 龍 1943 | 龸 1944 | fi 1945 | fl 1946 | ! 1947 | ( 1948 | ) 1949 | - 1950 | . 1951 | / 1952 | : 1953 | ? 
1954 | ~ 1955 | 0 1956 | 1 1957 | 2 1958 | 3 1959 | 4 1960 | 5 1961 | 6 1962 | 7 1963 | 8 1964 | 9 1965 | ² 1966 | ³ 1967 | ¹ 1968 | ⁰ 1969 | ⁴ 1970 | ⁵ 1971 | ⁶ 1972 | ⁷ 1973 | ⁸ 1974 | ⁹ 1975 | ₀ 1976 | ₁ 1977 | ₂ 1978 | ₃ 1979 | ₄ 1980 | ₅ 1981 | ₆ 1982 | ₇ 1983 | ₈ 1984 | ₉ 1985 | 10 1986 | 000 1987 | 2010 1988 | 2011 1989 | 12 1990 | 2012 1991 | 2008 1992 | 2009 1993 | 2013 1994 | 2007 1995 | 2006 1996 | 2014 1997 | 15 1998 | 20 1999 | 18 2000 | 2015 2001 | 11 2002 | 2016 2003 | 30 2004 | 2005 2005 | 16 2006 | 14 2007 | 13 2008 | 2017 2009 | 25 2010 | 2004 2011 | 2000 2012 | 17 2013 | 24 2014 | 2003 2015 | 2002 2016 | 100 2017 | 21 2018 | 19 2019 | 2001 2020 | 22 2021 | 23 2022 | 1999 2023 | 28 2024 | 26 2025 | 27 2026 | 1998 2027 | 1997 2028 | 1996 2029 | 50 2030 | 29 2031 | 2018 2032 | 1995 2033 | 1994 2034 | 1992 2035 | 1993 2036 | 31 2037 | 40 2038 | 1991 2039 | 1990 2040 | 1989 2041 | 1988 2042 | 1987 2043 | 1986 2044 | 1985 2045 | 1984 2046 | 1980 2047 | 500 2048 | 1983 2049 | 1982 2050 | 1979 2051 | 1981 2052 | 200 2053 | 1972 2054 | 1976 2055 | 1978 2056 | 1974 2057 | 1975 2058 | 1977 2059 | 1970 2060 | 1968 2061 | 1973 2062 | 1945 2063 | 1971 2064 | 45 2065 | 60 2066 | 1969 2067 | 1967 2068 | 35 2069 | 65 2070 | 1964 2071 | 1966 2072 | 1965 2073 | 32 2074 | 1960 2075 | 1944 2076 | 1963 2077 | 1962 2078 | 1942 2079 | 80 2080 | 1961 2081 | 1943 2082 | 1956 2083 | 1958 2084 | 1959 2085 | 1941 2086 | 1940 2087 | 1948 2088 | 1957 2089 | 1939 2090 | 1946 2091 | 1950 2092 | 90 2093 | 33 2094 | 70 2095 | 1955 2096 | 300 2097 | 1952 2098 | 00 2099 | 1947 2100 | 44 2101 | 36 2102 | 1954 2103 | 1953 2104 | 1949 2105 | 34 2106 | 1951 2107 | 64 2108 | 38 2109 | 1938 2110 | 37 2111 | 1936 2112 | 1918 2113 | 400 2114 | 75 2115 | 1937 2116 | 42 2117 | 1935 2118 | 1920 2119 | 39 2120 | 48 2121 | 1930 2122 | 1919 2123 | 1933 2124 | 1914 2125 | 1934 2126 | 55 2127 | 1917 2128 | 41 2129 | 1929 2130 | 1928 2131 | 1932 2132 | 47 2133 | 52 2134 | 43 2135 | 1931 2136 | 49 2137 | 1927 2138 | 1922 2139 | 46 2140 | 1924 2141 | 1925 2142 | 51 2143 | 1912 2144 | 1926 2145 | 1921 2146 | 978 2147 | 1923 2148 | 1915 2149 | 1916 2150 | 1910 2151 | 150 2152 | 1913 2153 | 54 2154 | 1900 2155 | 600 2156 | 56 2157 | 1911 2158 | 53 2159 | 1908 2160 | 95 2161 | 59 2162 | 800 2163 | 58 2164 | 57 2165 | 1905 2166 | 08 2167 | 1906 2168 | 1907 2169 | 250 2170 | 1909 2171 | 99 2172 | 85 2173 | 09 2174 | 1904 2175 | 05 2176 | 07 2177 | 06 2178 | 66 2179 | 1902 2180 | 1901 2181 | 1903 2182 | 62 2183 | 98 2184 | 72 2185 | 04 2186 | 01 2187 | 96 2188 | 97 2189 | 03 2190 | 120 2191 | 1898 2192 | 88 2193 | 61 2194 | 93 2195 | 76 2196 | 67 2197 | 1899 2198 | 02 2199 | 63 2200 | 1890 2201 | 91 2202 | 92 2203 | 77 2204 | 68 2205 | 78 2206 | 81 2207 | 1895 2208 | 1896 2209 | 1897 2210 | 700 2211 | 69 2212 | 74 2213 | 94 2214 | 71 2215 | 84 2216 | 73 2217 | 82 2218 | 1889 2219 | 89 2220 | 1893 2221 | 1892 2222 | 79 2223 | 1894 2224 | 86 2225 | 1885 2226 | 87 2227 | 1891 2228 | 83 2229 | 1888 2230 | 1000 2231 | 1864 2232 | 1865 2233 | 1880 2234 | 1887 2235 | 1861 2236 | 1862 2237 | 1863 2238 | 1886 2239 | 1870 2240 | 1884 2241 | 1881 2242 | 1882 2243 | 1883 2244 | 1878 2245 | 110 2246 | 1860 2247 | 1876 2248 | 1871 2249 | 1879 2250 | 1875 2251 | 1867 2252 | 1877 2253 | 130 2254 | 1872 2255 | 1868 2256 | 1874 2257 | 1873 2258 | 1866 2259 | 900 2260 | 1869 2261 | 101 2262 | 1850 2263 | 1848 2264 | 160 2265 | 1859 2266 | 1857 2267 | 180 2268 | 1854 2269 | 1855 2270 | 1858 2271 | 140 2272 | 350 2273 | 1856 2274 | 125 2275 | 105 2276 | 
1852 2277 | 1851 2278 | 1840 2279 | 1853 2280 | 1849 2281 | 1847 2282 | 1846 2283 | 102 2284 | 360 2285 | 1830 2286 | 1845 2287 | 104 2288 | 750 2289 | 1837 2290 | 1844 2291 | 103 2292 | 1800 2293 | 1841 2294 | 1812 2295 | 1838 2296 | 1842 2297 | 1839 2298 | 1843 2299 | 1836 2300 | 106 2301 | 1835 2302 | 1832 2303 | 450 2304 | 1500 2305 | 2019 2306 | 220 2307 | 107 2308 | 115 2309 | 1815 2310 | 1834 2311 | 108 2312 | 170 2313 | 1831 2314 | 1814 2315 | 1833 2316 | 1820 2317 | 111 2318 | 112 2319 | 240 2320 | 1825 2321 | 135 2322 | 1828 2323 | 109 2324 | 1829 2325 | 1824 2326 | 1821 2327 | 1810 2328 | 230 2329 | 190 2330 | 128 2331 | 3000 2332 | 1826 2333 | 1818 2334 | 113 2335 | 1813 2336 | 1822 2337 | 1827 2338 | 1816 2339 | 1793 2340 | 1801 2341 | 114 2342 | 1806 2343 | 1823 2344 | 1817 2345 | 1819 2346 | 117 2347 | 121 2348 | 2020 2349 | 1803 2350 | 1809 2351 | 175 2352 | 210 2353 | 116 2354 | 118 2355 | 127 2356 | 1798 2357 | 1808 2358 | 1811 2359 | 122 2360 | 1805 2361 | 123 2362 | 1804 2363 | 1794 2364 | 1807 2365 | 550 2366 | 119 2367 | 1790 2368 | 1795 2369 | 124 2370 | 1792 2371 | 280 2372 | 5000 2373 | 1802 2374 | 260 2375 | 320 2376 | 1789 2377 | 145 2378 | 270 2379 | 650 2380 | 1799 2381 | 1796 2382 | 165 2383 | 1776 2384 | 126 2385 | 132 2386 | 1797 2387 | 155 2388 | 330 2389 | 1775 2390 | 1791 2391 | 129 2392 | 133 2393 | 131 2394 | 144 2395 | 1200 2396 | 1600 2397 | 137 2398 | 225 2399 | 152 2400 | 138 2401 | 1780 2402 | 134 2403 | 1783 2404 | 185 2405 | 136 2406 | 141 2407 | 1788 2408 | 850 2409 | 340 2410 | 1787 2411 | 143 2412 | 142 2413 | 1777 2414 | 501 2415 | 205 2416 | 1778 2417 | 146 2418 | 201 2419 | 370 2420 | 148 2421 | 147 2422 | 1784 2423 | 151 2424 | 1700 2425 | 139 2426 | 154 2427 | 153 2428 | 156 2429 | 167 2430 | 1781 2431 | 202 2432 | 1758 2433 | 1782 2434 | 168 2435 | 380 2436 | 310 2437 | 290 2438 | 1785 2439 | 460 2440 | 256 2441 | 480 2442 | 195 2443 | 149 2444 | 161 2445 | 157 2446 | 215 2447 | 440 2448 | 1786 2449 | 420 2450 | 1772 2451 | 275 2452 | 1774 2453 | 192 2454 | 1779 2455 | 182 2456 | 158 2457 | 1770 2458 | 235 2459 | 162 2460 | 163 2461 | 164 2462 | 1660 2463 | 375 2464 | 177 2465 | 212 2466 | 1750 2467 | 171 2468 | 172 2469 | 1763 2470 | 208 2471 | 203 2472 | 176 2473 | 169 2474 | 181 2475 | 166 2476 | 183 2477 | 206 2478 | 159 2479 | 222 2480 | 1760 2481 | 188 2482 | 301 2483 | 410 2484 | 211 2485 | 178 2486 | 365 2487 | 209 2488 | 173 2489 | 187 2490 | 174 2491 | 1300 2492 | 430 2493 | 221 2494 | 186 2495 | 520 2496 | 204 2497 | 325 2498 | 184 2499 | 224 2500 | 640 2501 | 1768 2502 | 610 2503 | 207 2504 | 191 2505 | 213 2506 | 1773 2507 | 214 2508 | 194 2509 | 197 2510 | 193 2511 | 303 2512 | 911 2513 | 198 2514 | 390 2515 | 196 2516 | 4000 2517 | 540 2518 | 216 2519 | 231 2520 | 179 2521 | 950 2522 | 217 2523 | 305 2524 | 189 2525 | 265 2526 | 219 2527 | 255 2528 | 1400 2529 | 1769 2530 | 232 2531 | 1771 2532 | 199 2533 | 218 2534 | 1765 2535 | 223 2536 | 1762 2537 | 660 2538 | 245 2539 | 226 2540 | 312 2541 | 470 2542 | 333 2543 | 560 2544 | 1761 2545 | 1766 2546 | 1755 2547 | 1764 2548 | 227 2549 | 1767 2550 | 1640 2551 | 264 2552 | 1759 2553 | 295 2554 | 1740 2555 | 285 2556 | 1745 2557 | 1650 2558 | 262 2559 | 234 2560 | 238 2561 | 302 2562 | 737 2563 | 1100 2564 | 233 2565 | 254 2566 | 228 2567 | 490 2568 | 241 2569 | 1756 2570 | 246 2571 | 242 2572 | 1648 2573 | 251 2574 | 1754 2575 | 1715 2576 | 1757 2577 | 401 2578 | 1689 2579 | 229 2580 | 625 2581 | 720 2582 | 243 2583 | 252 2584 | 315 2585 | 281 2586 | 313 2587 | 287 2588 | 253 
2589 | 1730 2590 | 425 2591 | 237 2592 | 247 2593 | 510 2594 | 1644 2595 | 530 2596 | 311 2597 | 1720 2598 | 236 2599 | 630 2600 | 620 2601 | 249 2602 | 239 2603 | 580 2604 | 322 2605 | 345 2606 | 1753 2607 | 1710 2608 | 304 2609 | 802 2610 | 680 2611 | 316 2612 | 405 2613 | 321 2614 | 1661 2615 | 1642 2616 | 1688 2617 | 435 2618 | 244 2619 | 272 2620 | 308 2621 | 1620 2622 | 257 2623 | 258 2624 | 512 2625 | 335 2626 | 385 2627 | 1751 2628 | 261 2629 | 1748 2630 | 1746 2631 | 1747 2632 | 307 2633 | 248 2634 | 1680 2635 | 306 2636 | 760 2637 | 395 2638 | 415 2639 | 1749 2640 | 278 2641 | 1752 2642 | 1690 2643 | 404 2644 | 288 2645 | 570 2646 | 286 2647 | 1630 2648 | 1707 2649 | 309 2650 | 1685 2651 | 271 2652 | 2500 2653 | 276 2654 | 268 2655 | 266 2656 | 590 2657 | 259 2658 | 980 2659 | 1714 2660 | 263 2661 | 328 2662 | 1741 2663 | 1727 2664 | 273 2665 | 747 2666 | 323 2667 | 267 2668 | 283 2669 | 1643 2670 | 670 2671 | 277 2672 | 274 2673 | 001 2674 | 1743 2675 | 525 2676 | 1603 2677 | 1725 2678 | 2021 2679 | 1641 2680 | 1742 2681 | 269 2682 | 279 2683 | 292 2684 | 1610 2685 | 1739 2686 | 740 2687 | 1744 2688 | 412 2689 | 999 2690 | 1662 2691 | 299 2692 | 6000 2693 | 1701 2694 | 1735 2695 | 1645 2696 | 357 2697 | 1550 2698 | 1670 2699 | 314 2700 | 1625 2701 | 282 2702 | 355 2703 | 1724 2704 | 319 2705 | 1649 2706 | 1723 2707 | 317 2708 | 960 2709 | 820 2710 | 1722 2711 | 1737 2712 | 1702 2713 | 1728 2714 | 880 2715 | 284 2716 | 293 2717 | 521 2718 | 1718 2719 | 318 2720 | 1713 2721 | 1621 2722 | 289 2723 | 291 2724 | 1675 2725 | 296 2726 | 1733 2727 | 324 2728 | 298 2729 | 1672 2730 | 1708 2731 | 1734 2732 | 1666 2733 | 1683 2734 | 1635 2735 | 406 2736 | 1654 2737 | 1638 2738 | 297 2739 | 356 2740 | 411 2741 | 417 2742 | 1717 2743 | 331 2744 | 1540 2745 | 1732 2746 | 1667 2747 | 875 2748 | 710 2749 | 1665 2750 | 1721 2751 | 910 2752 | 1704 2753 | 343 2754 | 354 2755 | 1629 2756 | 338 2757 | 1679 2758 | 336 2759 | 730 2760 | 1738 2761 | 441 2762 | 402 2763 | 1609 2764 | 690 2765 | 840 2766 | 1622 2767 | 294 2768 | 451 2769 | 1719 2770 | 326 2771 | 1736 2772 | 1086 2773 | 1605 2774 | 403 2775 | 1716 2776 | 1632 2777 | 475 2778 | 1580 2779 | 1659 2780 | 1726 2781 | 341 2782 | 1703 2783 | 1656 2784 | 1655 2785 | 1731 2786 | 1729 2787 | 1711 2788 | 1712 2789 | 327 2790 | 351 2791 | 1664 2792 | 337 2793 | 1634 2794 | 1624 2795 | 780 2796 | 1692 2797 | 1628 2798 | 1697 2799 | 1016 2800 | 050 2801 | 1699 2802 | 1604 2803 | 1611 2804 | 1646 2805 | 1626 2806 | 1652 2807 | 870 2808 | 1570 2809 | 352 2810 | 407 2811 | 1658 2812 | 505 2813 | 1709 2814 | 339 2815 | 1663 2816 | 1618 2817 | 1623 2818 | 770 2819 | 1651 2820 | 1695 2821 | 1560 2822 | 1612 2823 | 422 2824 | 495 2825 | 1653 2826 | 1705 2827 | 332 2828 | 381 2829 | 930 2830 | 344 2831 | 421 2832 | 1682 2833 | 555 2834 | 334 2835 | 329 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import colorlog 4 | import random 5 | import torch 6 | 7 | 8 | def create_logger(folder, filename): 9 | log_colors = { 10 | 'DEBUG': 'blue', 11 | 'INFO': 'white', 12 | 'WARNING': 'green', 13 | 'ERROR': 'red', 14 | 'CRITICAL': 'yellow', 15 | } 16 | 17 | import logging 18 | logger = logging.getLogger('ConZIC') 19 | # %(filename)s$RESET:%(lineno)d 20 | # LOGFORMAT = "%(log_color)s%(asctime)s [%(log_color)s%(filename)s:%(lineno)d] | %(log_color)s%(message)s%(reset)s |" 21 | LOGFORMAT = "" 22 | LOG_LEVEL = 
logging.DEBUG
23 |     logging.root.setLevel(LOG_LEVEL)
24 |     stream = logging.StreamHandler()
25 |     stream.setLevel(LOG_LEVEL)
26 |     stream.setFormatter(colorlog.ColoredFormatter(LOGFORMAT, datefmt='%d %H:%M', log_colors=log_colors))
27 | 
28 |     # print to log file
29 |     hdlr = logging.FileHandler(os.path.join(folder, filename))
30 |     hdlr.setLevel(LOG_LEVEL)
31 |     # hdlr.setFormatter(logging.Formatter("[%(asctime)s] %(message)s"))
32 |     hdlr.setFormatter(logging.Formatter("%(message)s"))
33 |     logger.addHandler(hdlr)
34 |     logger.addHandler(stream)
35 |     return logger
36 | 
37 | def set_seed(seed):
38 |     random.seed(seed)
39 |     np.random.seed(seed)
40 |     torch.manual_seed(seed)
41 |     torch.cuda.manual_seed(seed)
42 |     torch.cuda.manual_seed_all(seed)
43 |     torch.backends.cudnn.deterministic = True
44 |     torch.backends.cudnn.benchmark = False
45 | 
46 | def get_init_text(tokenizer, seed_text, max_len, batch_size=1):
47 |     """ Get initial sentence by padding seed_text with [mask] words to max_len """
48 |     text = seed_text + tokenizer.mask_token * max_len
49 |     ids = tokenizer.encode(text)
50 |     batch = [ids] * batch_size
51 |     return batch
52 | 
53 | def update_token_mask(tokenizer, token_mask, max_len, index):
54 |     """ '.' (full stop) is only allowed in the last token position """
55 |     if index == max_len - 1:
56 |         token_mask[:, tokenizer.vocab['.']] = 1
57 |     else:
58 |         token_mask[:, tokenizer.vocab['.']] = 0
59 |     return token_mask
60 | 
61 | def format_output(sample_num, FinalCaption, BestCaption):
62 |     # Join the requested number of captions (the demo shows at most five)
63 |     # into newline-separated strings for display.
64 |     num = sample_num if 1 <= sample_num <= 5 else 5
65 |     return "\n".join(FinalCaption[:num]), "\n".join(BestCaption[:num])
--------------------------------------------------------------------------------
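
A minimal usage sketch for the helpers in utils.py. This snippet is not part of the repository: the checkpoint name ("bert-base-uncased"), the log folder/filename, and the max_len/batch_size values are illustrative assumptions, and it only shows the intended call order of set_seed, create_logger, get_init_text, and update_token_mask.

import os
import torch
from transformers import BertTokenizer  # assumes the Hugging Face transformers package
from utils import set_seed, create_logger, get_init_text, update_token_mask

set_seed(42)

# Hypothetical log destination; create_logger expects the folder to exist.
os.makedirs("logs", exist_ok=True)
logger = create_logger("logs", "run.log")
logger.info("sampling started")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
max_len, batch_size = 12, 2  # illustrative values, not taken from the repo

# Captions start as an (optionally empty) prompt padded with [MASK] tokens up to max_len.
init_batch = get_init_text(tokenizer, "", max_len, batch_size=batch_size)

# A (1, vocab_size) mask of allowed tokens; '.' is only permitted at the final
# position, so it is toggled as the position being sampled changes.
token_mask = torch.ones(1, tokenizer.vocab_size)
for position in range(max_len):
    token_mask = update_token_mask(tokenizer, token_mask, max_len, position)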