├── .DS_Store ├── .gitignore ├── LICENSE ├── NLTKPerceptronPosTagger └── averaged_perceptron_tagger.pickle ├── README.md ├── Word2Vec └── .gitkeep ├── attack ├── .DS_Store ├── attack_lib.py ├── best_img_path │ └── .DS_Store ├── compute_img_sim.py ├── english.py ├── intermediate_img_path │ └── .DS_Store ├── run_attack.py ├── target.png └── word2vec │ ├── .DS_Store │ └── word2vec_embed.py └── target_model ├── .DS_Store └── min_dalle ├── .DS_Store ├── __init__.py ├── image_from_text.py ├── pretrained └── .gitkeep ├── replicate ├── cog.yaml └── predictor.py ├── setup.py └── tkinter_ui.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Computer Security and Privacy Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NLTKPerceptronPosTagger/averaged_perceptron_tagger.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/NLTKPerceptronPosTagger/averaged_perceptron_tagger.pickle -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RIATIG: Reliable and Imperceptible Adversarial Text-to-Image Generation with Natural Prompts 2 | 3 | ## Install 4 | 5 | ### CLIP 6 | Please follow the guidelines in [CLIP Github Repository](https://github.com/openai/CLIP) to install CLIP 7 | 8 | 9 | ### DALL•E Mini 10 | Run the following command to install DALL•E Mini: 11 | ``` 12 | $ pip install min-dalle 13 | ``` 14 | Go into the following folder: 15 | ``` 16 | $ cd /target_model/min_dalle/pretrained 17 | ``` 18 | Download and uncompress the files: [dalle-bart](https://drive.google.com/file/d/1Qq_FARjdZlHra3r_g2ZvMLyDPzsgx6Af/view?usp=sharing) and [vqgan](https://drive.google.com/file/d/1ckxflXZnnWJzRvFHhpzuj11Pxr07Wby0/view?usp=sharing) 19 | 20 | ### Word2Vec 21 | Go into the folder: 22 | ``` 23 | $ cd /Word2Vec 24 | ``` 25 | Download the files: [word2id.pkl](https://drive.google.com/file/d/11kSfFGm1YOo5N08GGytnZy4cMpDTyd0h/view?usp=sharing) and [wordvec.pkl](https://drive.google.com/file/d/1h1hhkyZWZc-JhKqJBPtnJ2riooXMY-e0/view?usp=sharing) 26 | 27 | 28 | ## Run attack 29 | To run our attack: 30 | ``` 31 | python run_attack.py --ori_sent [original sentence] --tar_img_path [target image path] --tar_sent [target sentence] --log_save_path [log save path] --intem_img_path [intermediate results save path] --best_img_path [output best images save path] --mutate_by_impor [whether select the word by importance in mutation] 32 | ``` 33 | 34 | For a quick demo: 35 | ``` 36 | python run_attack.py --ori_sent "a herd of cows that are grazing on the grass" --tar_img_path "./target.png" --tar_sent "a large red and white boat floating on top of a lake" 37 | ``` 38 | 39 | ## Citation 40 | If you find our work useful, please cite: 41 | 42 | ``` 43 | @InProceedings{Liu_2023_CVPR, 44 | author = {Liu, Han and Wu, Yuhao and Zhai, Shixuan and Yuan, Bo and Zhang, Ning}, 45 | title = {RIATIG: Reliable and Imperceptible Adversarial Text-to-Image Generation With Natural Prompts}, 46 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 47 | month = {June}, 48 | year = {2023}, 49 | pages = {20585-20594} 50 | } 51 | ``` 52 | 53 | ## Acknowledements 54 | 55 | Thanks for the open souce code: 56 | #### CLIP: https://github.com/openai/CLIP 57 | #### DALL•E Mini: https://github.com/kuprel/min-dalle 58 | #### OpenAttack: https://github.com/thunlp/OpenAttack -------------------------------------------------------------------------------- /Word2Vec/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/Word2Vec/.gitkeep -------------------------------------------------------------------------------- /attack/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/attack/.DS_Store -------------------------------------------------------------------------------- /attack/attack_lib.py: -------------------------------------------------------------------------------- 1 | 2 | from nltk.tokenize import RegexpTokenizer 3 | from english import ENGLISH_FILTER_WORDS 4 | import random 5 | import sys 6 | from word2vec.word2vec_embed import Word2VecSubstitute 7 | import shutil 8 | import numpy as np 9 | from compute_img_sim import compute_img_sim 10 | 11 | sys.path.append('../target_model/min_dalle') 12 | from image_from_text import dalle_mini_gen_img_from_text 13 | 14 | class Error(Exception): 15 | """Base class for other exceptions""" 16 | pass 17 | 18 | class WordNotInDictionaryException(Error): 19 | """Raised when the input value is too small""" 20 | pass 21 | 22 | class attack(): 23 | def __init__(self, tar_sent): 24 | #nothing to be initialized 25 | 26 | self.count = 0 27 | 28 | tokenizer = RegexpTokenizer(r'\w+') 29 | tokens = tokenizer.tokenize(tar_sent.lower()) 30 | 31 | #filter unimportant word 32 | 33 | target_word_ls = [] 34 | for token in tokens: 35 | if token.lower() in ENGLISH_FILTER_WORDS: 36 | continue 37 | target_word_ls.append(token) 38 | self.target_sent_tokens = target_word_ls 39 | print("tar_sent_tokens: ", self.target_sent_tokens) 40 | 41 | self.Word2vec = Word2VecSubstitute(tar_tokens=self.target_sent_tokens) 42 | 43 | print("initialize attack class.") 44 | 45 | 46 | def selectBug(self, original_word, if_initial=False, word_idx=None, x_prime=None): 47 | bugs = self.generateBugs(original_word, if_initial) 48 | target_num = random.randint(0, len(bugs)-1) 49 | bugs_ls = list(bugs.values()) 50 | #randomly select a bug to return 51 | bug_choice = bugs_ls[target_num] 52 | return bug_choice 53 | 54 | 55 | def replaceWithBug(self, x_prime, word_idx, bug): 56 | return x_prime[:word_idx] + [bug] + x_prime[word_idx + 1:] 57 | 58 | def generateBugs(self, word, if_initial=False, sub_w_enabled=False, typo_enabled=False): 59 | 60 | if if_initial: 61 | bugs = {"insert": word, "sub_W": word, "del_C": word, "sub_tar_W": word} 62 | if len(word) <= 2: 63 | return bugs 64 | bugs["insert"] = self.bug_insert(word) 65 | bugs["sub_W"] = self.bug_sub_W(word) 66 | bugs["del_C"] = self.bug_delete(word) 67 | bugs["sub_tar_W"] = self.bug_sub_tar_W(word) 68 | else: 69 | bugs = {"sub_W": word, "ins_C": word} 70 | if len(word) <= 2: 71 | return bugs 72 | bugs["sub_W"] = self.bug_sub_W(word) 73 | bugs["ins_C"] = self.bug_insert(word) 74 | 75 | return bugs 76 | 77 | def bug_sub_tar_W(self, word): 78 | word_index = random.randint(0, len(self.target_sent_tokens) - 1) 79 | tar_word = self.target_sent_tokens[word_index] 80 | res = self.Word2vec.substitute(tar_word) 81 | if len(res) == 0: 82 | return word 83 | return res[0][0] 84 | 85 | def bug_sub_W(self, word): 86 | try: 87 | res = self.Word2vec.substitute(word) 88 | if len(res) == 0: 89 | return word 90 | return res[0][0] 91 | except WordNotInDictionaryException: 92 | return word 93 | 94 | def bug_insert(self, word): 95 | if len(word) >= 6: 96 | return word 97 | res = word 98 | point = random.randint(1, len(word) - 1) 99 | #insert _ instread " " 100 | res = res[0:point] + "_" + res[point:] 101 | return res 102 | 103 | def bug_delete(self, word): 104 | res = word 105 | point = random.randint(1, len(word) - 2) 106 | res = res[0:point] + res[point + 1:] 107 | return res 108 | 109 | def bug_swap(self, word): 110 | if len(word) <= 4: 111 | return word 112 | res = word 113 | points = random.sample(range(1, len(word) - 1), 2) 114 | a = points[0] 115 | b = points[1] 116 | 117 | res = list(res) 118 | w = res[a] 119 | res[a] = res[b] 120 | res[b] = w 121 | res = ''.join(res) 122 | return res 123 | 124 | def bug_random_sub(self, word): 125 | res = word 126 | point = random.randint(0, len(word)-1) 127 | 128 | choices = "qwertyuiopasdfghjklzxcvbnm" 129 | 130 | subbed_choice = choices[random.randint(0, len(list(choices))-1)] 131 | res = list(res) 132 | res[point] = subbed_choice 133 | res = ''.join(res) 134 | return res 135 | 136 | def bug_convert_to_leet(self, word): 137 | # Dictionary that maps each letter to its leet speak equivalent. 138 | leet_dict = { 139 | 'a': '4', 140 | 'b': '8', 141 | 'e': '3', 142 | 'g': '6', 143 | 'l': '1', 144 | 'o': '0', 145 | 's': '5', 146 | 't': '7' 147 | } 148 | 149 | # Replace each letter in the text with its leet speak equivalent. 150 | res = ''.join(leet_dict.get(c.lower(), c) for c in word) 151 | 152 | return res 153 | 154 | 155 | def bug_sub_C(self, word): 156 | res = word 157 | key_neighbors = self.get_key_neighbors() 158 | point = random.randint(0, len(word) - 1) 159 | 160 | if word[point] not in key_neighbors: 161 | return word 162 | choices = key_neighbors[word[point]] 163 | subbed_choice = choices[random.randint(0, len(choices) - 1)] 164 | res = list(res) 165 | res[point] = subbed_choice 166 | res = ''.join(res) 167 | 168 | return res 169 | 170 | def get_key_neighbors(self): 171 | ## TODO: support other language here 172 | # By keyboard proximity 173 | neighbors = { 174 | "q": "was", "w": "qeasd", "e": "wrsdf", "r": "etdfg", "t": "ryfgh", "y": "tughj", "u": "yihjk", 175 | "i": "uojkl", "o": "ipkl", "p": "ol", 176 | "a": "qwszx", "s": "qweadzx", "d": "wersfxc", "f": "ertdgcv", "g": "rtyfhvb", "h": "tyugjbn", 177 | "j": "yuihknm", "k": "uiojlm", "l": "opk", 178 | "z": "asx", "x": "sdzc", "c": "dfxv", "v": "fgcb", "b": "ghvn", "n": "hjbm", "m": "jkn" 179 | } 180 | # By visual proximity 181 | neighbors['i'] += '1' 182 | neighbors['l'] += '1' 183 | neighbors['z'] += '2' 184 | neighbors['e'] += '3' 185 | neighbors['a'] += '4' 186 | neighbors['s'] += '5' 187 | neighbors['g'] += '6' 188 | neighbors['b'] += '8' 189 | neighbors['g'] += '9' 190 | neighbors['q'] += '9' 191 | neighbors['o'] += '0' 192 | 193 | return neighbors 194 | 195 | def sort_words_by_importance(tokenizer, save_dir, ori_sent, tar_img_path): 196 | tokens = tokenizer.tokenize(ori_sent.lower()) 197 | sim_ls = [] 198 | 199 | for i in range(len(tokens)): 200 | new_tokens = tokens[:i] + tokens[i+1:] 201 | x_prime_sent = " ".join(new_tokens) 202 | 203 | x_img_path = save_dir + "gen.png" 204 | 205 | dalle_mini_gen_img_from_text(x_prime_sent, x_img_path) 206 | 207 | similarity = compute_img_sim(x_img_path, tar_img_path) 208 | sim_ls.append(similarity.item()) 209 | 210 | sim_arr = np.array(sim_ls) 211 | scores_logits = np.exp(sim_arr - sim_arr.max()) 212 | sim_probs = scores_logits / scores_logits.sum() 213 | 214 | return sim_probs 215 | -------------------------------------------------------------------------------- /attack/best_img_path/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/attack/best_img_path/.DS_Store -------------------------------------------------------------------------------- /attack/compute_img_sim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | 5 | def compute_img_sim(img1_path, img2_path): 6 | 7 | device = "cuda" if torch.cuda.is_available() else "cpu" 8 | model, preprocess = clip.load("ViT-B/32", device=device) 9 | 10 | image_1 = preprocess(Image.open(img1_path)).unsqueeze(0).to(device) 11 | image_2 = preprocess(Image.open(img2_path)).unsqueeze(0).to(device) 12 | 13 | with torch.no_grad(): 14 | img1_features = model.encode_image(image_1) 15 | img2_features = model.encode_image(image_2) 16 | 17 | img1_features /= img1_features.norm(dim=-1, keepdim=True) 18 | img2_features /= img2_features.norm(dim=-1, keepdim=True) 19 | 20 | similarity = 100. * (img1_features @ img2_features.T) 21 | 22 | return similarity -------------------------------------------------------------------------------- /attack/english.py: -------------------------------------------------------------------------------- 1 | ENGLISH_FILTER_WORDS = [ 2 | 'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost', 3 | 'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another', 4 | 'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as', 5 | 'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides', 6 | 'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn', 7 | "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere', 8 | 'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for', 9 | 'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence', 10 | 'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his', 11 | 'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's", 12 | 'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn', 13 | "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself', 14 | 'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none', 15 | 'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only', 16 | 'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per', 17 | 'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow', 18 | 'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs', 19 | 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein', 20 | 'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too', 21 | 'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't", 22 | 'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where', 23 | 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while', 24 | 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won', 25 | "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've", 26 | 'your', 'yours', 'yourself', 'yourselves', 'have', 'be' 27 | ] -------------------------------------------------------------------------------- /attack/intermediate_img_path/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/attack/intermediate_img_path/.DS_Store -------------------------------------------------------------------------------- /attack/run_attack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import random 5 | import shutil 6 | from nltk.tokenize import RegexpTokenizer 7 | from english import ENGLISH_FILTER_WORDS 8 | from compute_img_sim import compute_img_sim 9 | from attack_lib import attack 10 | from attack_lib import sort_words_by_importance 11 | sys.path.append('../target_model/min_dalle') 12 | from image_from_text import dalle_mini_gen_img_from_text 13 | import argparse 14 | 15 | 16 | def check_if_contains(tokens): 17 | flag = False 18 | loc = 0 19 | for token in tokens: 20 | if "_" in token: 21 | flag = True 22 | break 23 | loc += 1 24 | return flag, loc 25 | 26 | def check_if_in_list(sent, sent_ls): 27 | flag = False 28 | for tar_sent in sent_ls: 29 | if sent == tar_sent: 30 | flag = True 31 | break 32 | return flag 33 | 34 | 35 | def get_new_pop(elite_pop, elite_pop_scores, pop_size): 36 | scores_logits = np.exp(elite_pop_scores - elite_pop_scores.max()) 37 | elite_pop_probs = scores_logits / scores_logits.sum() 38 | 39 | cand1 = [elite_pop[i] for i in np.random.choice(len(elite_pop), p=elite_pop_probs, size=pop_size)] 40 | cand2 = [elite_pop[i] for i in np.random.choice(len(elite_pop), p=elite_pop_probs, size=pop_size)] 41 | 42 | #exchange two parts randomly 43 | mask = np.random.rand(pop_size, len(elite_pop[0])) < 0.5 44 | 45 | next_pop = [] 46 | pop_index = 0 47 | for pop_flag in mask: 48 | pop = [] 49 | word_index = 0 50 | for word_flag in pop_flag: 51 | if word_flag: 52 | pop.append(cand1[pop_index][word_index]) 53 | else: 54 | pop.append(cand2[pop_index][word_index]) 55 | word_index += 1 56 | next_pop.append(pop) 57 | pop_index += 1 58 | 59 | return next_pop 60 | 61 | 62 | class Genetic(): 63 | 64 | def __init__(self, ori_sent, tar_img_path, tar_sent, log_save_path, intem_img_path, best_img_path, mutate_by_impor): 65 | 66 | self.init_pop_size = 150 67 | self.pop_size = 15 68 | self.elite_size = 8 69 | self.mutation_p = 0.85 70 | self.mu = 0.99 71 | self.alpha = 0.001 72 | self.max_iters = 50 73 | self.store_thres = 80 74 | 75 | self.target_img_path = tar_img_path 76 | self.log_save_path = log_save_path 77 | self.intermediate_path = intem_img_path 78 | self.best_img_path = best_img_path 79 | self.target_sent = tar_sent 80 | self.mutate_by_impor = mutate_by_impor 81 | 82 | #initialize attack class 83 | self.attack_cls = attack(self.target_sent) 84 | 85 | #initialize tokenizer 86 | self.tokenizer = RegexpTokenizer(r'\w+') 87 | tokens = self.tokenizer.tokenize(ori_sent.lower()) 88 | 89 | #generate large initialization corpus 90 | self.pop = self.initial_mutate(tokens, self.init_pop_size) 91 | print("initial pop: ", self.pop) 92 | 93 | def initial_mutate(self, pop, nums): 94 | #random select the pop sentence that will mutate 95 | new_pop = [pop] 96 | new_sent_ls = [" ".join(pop)] 97 | 98 | #append the list until it fills out nums 99 | count = 0 100 | while count < nums-1: 101 | word_idx = np.random.choice(len(pop), size=1) 102 | word = pop[word_idx[0]] 103 | if word.lower() in ENGLISH_FILTER_WORDS: 104 | continue 105 | 106 | bug = self.attack_cls.selectBug(word, if_initial=True) 107 | tokens = self.attack_cls.replaceWithBug(pop, word_idx[0], bug) 108 | #join it into a sentence 109 | x_prime_sent = " ".join(tokens) 110 | if (check_if_in_list(x_prime_sent, new_sent_ls)): 111 | continue 112 | 113 | new_sent_ls.append(x_prime_sent) 114 | new_pop.append(tokens) 115 | count += 1 116 | print("current count: ", count) 117 | 118 | return new_pop 119 | 120 | 121 | def get_fitness_score(self, input_tokens): 122 | #get fitness score of all the sentences 123 | sim_score_ls = [] 124 | 125 | for tokens in input_tokens: 126 | x_prime_sent = " ".join(tokens) 127 | x_prime_sent = x_prime_sent.replace("_", " ") 128 | 129 | x_img_path = self.intermediate_path + "gen.png" 130 | 131 | dalle_mini_gen_img_from_text(x_prime_sent, x_img_path) 132 | 133 | similarity = compute_img_sim(x_img_path, self.target_img_path) 134 | 135 | #if similarity > self.store_thres: 136 | # best_ori_path = self.best_img_path + x_prime_sent + "_score_" + str(similarity.item()) + ".png" 137 | # shutil.copy(x_img_path, best_ori_path) 138 | 139 | sim_score_ls.append(similarity.item()) 140 | 141 | print(f"x_prime_sent: {x_prime_sent}, similarity: {similarity.item()}") 142 | sim_score_arr = np.array(sim_score_ls) 143 | return sim_score_arr 144 | 145 | def mutate_pop(self, pop, mutation_p, mutate_by_impor): 146 | #random select the pop sentence that will mutate 147 | mask = np.random.rand(len(pop)) < mutation_p 148 | new_pop = [] 149 | pop_index = 0 150 | for flag in mask: 151 | if not flag: 152 | new_pop.append(pop[pop_index]) 153 | else: 154 | tokens = pop[pop_index] 155 | 156 | if mutate_by_impor: 157 | x_prime_sent = " ".join(tokens) 158 | sim_probs = sort_words_by_importance(self.tokenizer, self.intermediate_path, x_prime_sent, self.target_img_path) 159 | word_idx = np.random.choice(len(tokens), p=sim_probs, size=1) 160 | else: 161 | word_idx = np.random.choice(len(tokens), size=1) 162 | word = tokens[word_idx[0]] 163 | 164 | if word.lower() in ENGLISH_FILTER_WORDS: 165 | new_pop.append(pop[pop_index]) 166 | continue 167 | 168 | word_slice = word.split("_") 169 | if len(word_slice) > 1: 170 | #randomly choose one 171 | sub_word_idx = np.random.choice(len(word_slice), size=1) 172 | sub_word = word_slice[sub_word_idx[0]] 173 | bug = self.attack_cls.selectBug(sub_word, if_initial=False) 174 | word_slice[sub_word_idx[0]] = bug 175 | final_bug = '_'.join(word_slice) 176 | else: 177 | final_bug = self.attack_cls.selectBug(word, if_initial=False) 178 | 179 | tokens = self.attack_cls.replaceWithBug(tokens, word_idx[0], final_bug) 180 | new_pop.append(tokens) 181 | pop_index += 1 182 | 183 | return new_pop 184 | 185 | def run(self, log=None): 186 | best_save_dir = self.best_img_path 187 | itr = 1 188 | prev_score = None 189 | save_dir = self.intermediate_path 190 | best_score = float("-inf") 191 | if log is not None: 192 | log.write('target phrase: ' + self.target_sent + '\n') 193 | 194 | while itr <= self.max_iters: 195 | 196 | print(f"-----------itr num:{itr}----------------") 197 | log.write("------------- iteration:" + str(itr) + " ---------------\n") 198 | pop_scores = self.get_fitness_score(self.pop) 199 | elite_ind = np.argsort(pop_scores)[-self.elite_size:] 200 | elite_pop = [self.pop[i] for i in elite_ind] 201 | elite_pop_scores = pop_scores[elite_ind] 202 | 203 | print("current best score: ", elite_pop_scores[-1]) 204 | 205 | for i in elite_ind: 206 | if pop_scores[i] > self.store_thres: 207 | x_prime_sent_store = " ".join(self.pop[i]) 208 | x_prime_sent_store = x_prime_sent_store.replace("_", " ") 209 | log.write(str(pop_scores[i]) + " " + x_prime_sent_store + "\n") 210 | 211 | if elite_pop_scores[-1] > best_score: 212 | best_score = elite_pop_scores[-1] 213 | #store the current best image 214 | x_prime_sent = " ".join(elite_pop[-1]) 215 | x_prime_sent = x_prime_sent.replace("_", " ") 216 | 217 | x_img_path = save_dir + "gen.png" 218 | 219 | dalle_mini_gen_img_from_text(x_prime_sent, x_img_path) 220 | 221 | best_ori_path = best_save_dir + "itr_" + str(itr) + "_score_" + str(elite_pop_scores[-1]) + ".png" 222 | shutil.copy(x_img_path, best_ori_path) 223 | 224 | #new best adversarial sentences 225 | log.write("new best adv: " + str(elite_pop_scores[-1]) + " " + x_prime_sent + "\n") 226 | log.flush() 227 | 228 | 229 | if prev_score is not None and prev_score != elite_pop_scores[-1]: 230 | self.mutation_p = self.mu * self.mutation_p + self.alpha / np.abs(elite_pop_scores[-1] - prev_score) 231 | 232 | next_pop = get_new_pop(elite_pop, elite_pop_scores, self.pop_size) 233 | 234 | self.pop = self.mutate_pop(next_pop, self.mutation_p, self.mutate_by_impor) 235 | 236 | prev_score = elite_pop_scores[-1] 237 | itr += 1 238 | 239 | return 240 | 241 | if __name__ == "__main__": 242 | 243 | parser = argparse.ArgumentParser() 244 | parser.add_argument('--ori_sent', type=str, required=True, help='original sentence') 245 | parser.add_argument('--tar_img_path', type=str, required=True, help='target image path') 246 | parser.add_argument('--tar_sent', type=str, required=True, help='target sentence') 247 | parser.add_argument('--log_save_path', type=str, default='run_log.txt', help='path to save log') 248 | parser.add_argument('--intem_img_path', type=str, default='./intermediate_img_path/', help='path to save intermediate imgs') 249 | parser.add_argument('--best_img_path', type=str, default='./best_img_path/', help='path to save best output imgs') 250 | parser.add_argument('--mutate_by_impor', type=bool, default=False, help='whether select word by importance in mutation') 251 | args = parser.parse_args() 252 | 253 | g = Genetic(args.ori_sent, args.tar_img_path, args.tar_sent, args.log_save_path, args.intem_img_path, args.best_img_path, args.mutate_by_impor) 254 | with open(args.log_save_path, 'w') as log: 255 | g.run(log=log) 256 | -------------------------------------------------------------------------------- /attack/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/attack/target.png -------------------------------------------------------------------------------- /attack/word2vec/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/attack/word2vec/.DS_Store -------------------------------------------------------------------------------- /attack/word2vec/word2vec_embed.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import numpy as np 4 | import pickle 5 | import os 6 | import random 7 | import re 8 | import nltk 9 | 10 | 11 | class Error(Exception): 12 | """Base class for other exceptions""" 13 | pass 14 | 15 | 16 | class WordNotInDictionaryException(Error): 17 | """Raised when the input value is too small""" 18 | pass 19 | 20 | 21 | 22 | class WordEmbedding(): 23 | def __init__(self, word2id : Dict[str, int], embedding) -> None: 24 | self.word2id = word2id 25 | self.embedding = embedding 26 | 27 | def transform(self, word, token_unk): 28 | if word in self.word2id: 29 | return self.embedding[ self.word2id[word] ] 30 | else: 31 | if isinstance(token_unk, int): 32 | return self.embedding[ token_unk ] 33 | else: 34 | return self.embedding[ self.word2id[ token_unk ] ] 35 | 36 | 37 | 38 | def LOAD(path): 39 | word2id = pickle.load( open( os.path.join(path, "word2id.pkl"), "rb") ) 40 | wordvec = pickle.load( open( os.path.join(path, "wordvec.pkl"), "rb") ) 41 | return WordEmbedding(word2id, wordvec) 42 | 43 | 44 | def LOAD_perceptron_tagger(path): 45 | ret = __import__("nltk").tag.PerceptronTagger(load=False) 46 | ret.load("file:" + os.path.join(path, "averaged_perceptron_tagger.pickle")) 47 | return ret.tag 48 | 49 | 50 | _POS_MAPPING = { 51 | "JJ": "adj", 52 | "VB": "verb", 53 | "NN": "noun", 54 | "RB": "adv" 55 | } 56 | 57 | def isEnglish(s): 58 | try: 59 | s.encode(encoding='utf-8').decode('ascii') 60 | except UnicodeDecodeError: 61 | return False 62 | else: 63 | return True 64 | 65 | def check_word_in_tokens(word, tokens): 66 | flag = False 67 | for token in tokens: 68 | if token == word: 69 | flag = True 70 | break 71 | return flag 72 | 73 | class Word2VecSubstitute(): 74 | 75 | def __init__(self, tar_tokens=None, cosine=False, k = 100, threshold = 10, device = None): 76 | """ 77 | Embedding based word substitute. 78 | 79 | Args: 80 | word2id: A `dict` maps words to indexes. 81 | embedding: A word embedding matrix. 82 | cosine: If `true` then the cosine distance is used, otherwise the Euclidian distance is used. 83 | threshold: Distance threshold. Default: 0.5 84 | k: Top-k results to return. If k is `None`, all results will be returned. Default: 50 85 | device: A pytocrh device for computing distances. Default: "cpu" 86 | 87 | """ 88 | 89 | if device is None: 90 | device = "cpu" 91 | 92 | #load wordvec 93 | wordvec = LOAD("../Word2Vec/") 94 | 95 | self.tar_tokens = tar_tokens 96 | self.word2id = wordvec.word2id 97 | self.embedding = torch.from_numpy(wordvec.embedding) 98 | self.cosine = cosine 99 | self.k = k 100 | self.threshold = threshold 101 | 102 | self.id2word = { 103 | val: key for key, val in self.word2id.items() 104 | } 105 | 106 | if cosine: 107 | self.embedding = self.embedding / self.embedding.norm(dim=1, keepdim=True) 108 | 109 | self.embedding = self.embedding.to(device) 110 | self.pos_tagger = LOAD_perceptron_tagger("../NLTKPerceptronPosTagger/") 111 | 112 | 113 | def get_pos(self, word, pos_tagging=True): 114 | 115 | tokens = [word] 116 | for word, pos in self.pos_tagger(tokens): 117 | if pos[:2] in _POS_MAPPING: 118 | mapped_pos = _POS_MAPPING[pos[:2]] 119 | else: 120 | mapped_pos = "other" 121 | return mapped_pos 122 | 123 | 124 | def substitute(self, word): 125 | if word not in self.word2id: 126 | return [] 127 | 128 | #get pos of word 129 | ori_pos = self.get_pos(word) 130 | 131 | wdid = self.word2id[word] 132 | wdvec = self.embedding[wdid, :] 133 | if self.cosine: 134 | dis = 1 - (wdvec * self.embedding).sum(dim=1) 135 | else: 136 | dis = (wdvec - self.embedding).norm(dim=1) 137 | 138 | idx = dis.argsort() 139 | 140 | if self.k is not None: 141 | idx = idx[:self.k] 142 | 143 | #filter index dis that are larger than threshold 144 | 145 | output_idx = [] 146 | for i in idx: 147 | if dis[i] < self.threshold and dis[i] != 0: 148 | #print(f"idx: {i}, dis[i]: {dis[i]}") 149 | output_idx.append(i.item()) 150 | 151 | idx = output_idx 152 | new_idx = [] 153 | #filter strange long word 154 | for id_ in idx: 155 | flag = True 156 | pos = self.get_pos(self.id2word[id_]) 157 | 158 | if "_" in self.id2word[id_]: 159 | word_slc = self.id2word[id_].split("_") 160 | 161 | for word_s in word_slc: 162 | if isEnglish(word_s)==False: 163 | flag = False 164 | break 165 | else: 166 | if isEnglish(self.id2word[id_])==False: 167 | flag = False 168 | 169 | #check if in target token list 170 | is_in_target = False 171 | if self.tar_tokens is not None: 172 | is_in_target = check_word_in_tokens(self.id2word[id_], self.tar_tokens) 173 | 174 | if flag and pos == ori_pos: 175 | if not is_in_target: 176 | new_idx.append(id_) 177 | 178 | if len(new_idx) == 0: 179 | return [] 180 | idx = random.choices(new_idx, k=1) 181 | 182 | return [ 183 | (self.id2word[id_], dis[id_].item()) for id_ in idx 184 | ] -------------------------------------------------------------------------------- /target_model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/target_model/.DS_Store -------------------------------------------------------------------------------- /target_model/min_dalle/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/target_model/min_dalle/.DS_Store -------------------------------------------------------------------------------- /target_model/min_dalle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/target_model/min_dalle/__init__.py -------------------------------------------------------------------------------- /target_model/min_dalle/image_from_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | from min_dalle import MinDalle 5 | import torch 6 | 7 | def ascii_from_image(image: Image.Image, size: int = 128) -> str: 8 | gray_pixels = image.resize((size, int(0.55 * size))).convert('L').getdata() 9 | chars = list('.,;/IOX') 10 | chars = [chars[i * len(chars) // 256] for i in gray_pixels] 11 | chars = [chars[i * size: (i + 1) * size] for i in range(size // 2)] 12 | return '\n'.join(''.join(row) for row in chars) 13 | 14 | 15 | def save_image(image: Image.Image, path: str): 16 | if os.path.isdir(path): 17 | path = os.path.join(path, 'generated.png') 18 | elif not path.endswith('.png'): 19 | path += '.png' 20 | image.save(path) 21 | return image 22 | 23 | 24 | def generate_image( 25 | is_mega: bool, 26 | text: str, 27 | seed: int, 28 | grid_size: int, 29 | top_k: int, 30 | image_path: str, 31 | models_root: str, 32 | fp16: bool, 33 | ): 34 | model = MinDalle( 35 | is_mega=is_mega, 36 | models_root=models_root, 37 | is_reusable=False, 38 | is_verbose=True, 39 | dtype=torch.float16 if fp16 else torch.float32 40 | ) 41 | 42 | image = model.generate_image( 43 | text, 44 | seed, 45 | grid_size, 46 | top_k=top_k, 47 | is_verbose=True 48 | ) 49 | save_image(image, image_path) 50 | 51 | def dalle_mini_gen_img_from_text(ori_sent, ori_img_path, seed=-1): 52 | 53 | generate_image( 54 | is_mega=False, 55 | text=ori_sent, 56 | seed=seed, 57 | grid_size=1, 58 | top_k=256, 59 | image_path=ori_img_path, 60 | models_root='pretrained', 61 | fp16=False, 62 | ) 63 | 64 | return 65 | -------------------------------------------------------------------------------- /target_model/min_dalle/pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUSTL-CSPL/RIATIG/10524da55cab6b58ad34831ee661d4be2a5688f4/target_model/min_dalle/pretrained/.gitkeep -------------------------------------------------------------------------------- /target_model/min_dalle/replicate/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "11.5.1" 3 | gpu: true 4 | python_version: "3.10" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "min-dalle==0.4.5" 10 | - "emoji==1.7.0" 11 | run: 12 | - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html 13 | 14 | predict: "predictor.py:ReplicatePredictor" -------------------------------------------------------------------------------- /target_model/min_dalle/replicate/predictor.py: -------------------------------------------------------------------------------- 1 | from min_dalle import MinDalle 2 | import tempfile 3 | import string 4 | import torch, torch.backends.cudnn, torch.backends.cuda 5 | from typing import Iterator 6 | from emoji import demojize 7 | from cog import BasePredictor, Path, Input 8 | 9 | torch.backends.cudnn.deterministic = False 10 | torch.backends.cudnn.benchmark = True 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True 13 | 14 | def filename_from_text(text: str) -> str: 15 | text = demojize(text, delimiters=['', '']) 16 | text = text.lower().encode("ascii", errors="ignore").decode() 17 | allowed_chars = string.ascii_lowercase + ' ' 18 | text = ''.join(i for i in text.lower() if i in allowed_chars) 19 | text = text[:64] 20 | text = '-'.join(text.strip().split()) 21 | if len(text) == 0: text = 'blank' 22 | return text 23 | 24 | class ReplicatePredictor(BasePredictor): 25 | def setup(self): 26 | self.model = MinDalle( 27 | is_mega=True, 28 | is_reusable=True, 29 | dtype=torch.float32, 30 | device='cuda' 31 | ) 32 | 33 | def predict( 34 | self, 35 | text: str = Input(default='Dali painting of WALL·E'), 36 | save_as_png: bool = Input(default=False), 37 | progressive_outputs: bool = Input(default=True), 38 | seamless: bool = Input(default=False), 39 | grid_size: int = Input(ge=1, le=9, default=5), 40 | temperature: float = Input( 41 | ge=0.01, 42 | le=16, 43 | default=4 44 | ), 45 | top_k: int = Input( 46 | choices=[2 ** i for i in range(15)], 47 | default=64, 48 | description='Advanced Setting, see Readme below if interested.' 49 | ), 50 | supercondition_factor: int = Input( 51 | choices=[2 ** i for i in range(2, 7)], 52 | default=16, 53 | description='Advanced Setting, see Readme below if interested.' 54 | ) 55 | ) -> Iterator[Path]: 56 | image_stream = self.model.generate_image_stream( 57 | text = text, 58 | seed = -1, 59 | grid_size = grid_size, 60 | progressive_outputs = progressive_outputs, 61 | is_seamless = seamless, 62 | temperature = temperature, 63 | supercondition_factor = float(supercondition_factor), 64 | top_k = top_k, 65 | is_verbose = True 66 | ) 67 | 68 | i = 0 69 | path = Path(tempfile.mkdtemp()) 70 | for image in image_stream: 71 | i += 1 72 | is_final = i == 8 if progressive_outputs else True 73 | ext = 'png' if is_final and save_as_png else 'jpg' 74 | filename = filename_from_text(text) 75 | filename += '' if is_final else '-iter-{}'.format(i) 76 | image_path = path / '{}.{}'.format(filename, ext) 77 | image.save(str(image_path)) 78 | yield image_path -------------------------------------------------------------------------------- /target_model/min_dalle/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | # from pathlib import Path 3 | 4 | setuptools.setup( 5 | name='min-dalle', 6 | description = 'min(DALL·E)', 7 | # long_description=(Path(__file__).parent / "README.rst").read_text(), 8 | version='0.4.11', 9 | author='Brett Kuprel', 10 | author_email='brkuprel@gmail.com', 11 | url='https://github.com/kuprel/min-dalle', 12 | packages=[ 13 | 'min_dalle', 14 | 'min_dalle.models' 15 | ], 16 | license='MIT', 17 | install_requires=[ 18 | 'torch>=1.11', 19 | 'typing_extensions>=4.1', 20 | 'numpy>=1.21', 21 | 'pillow>=7.1', 22 | 'requests>=2.23', 23 | 'emoji' 24 | ], 25 | keywords = [ 26 | 'artificial intelligence', 27 | 'deep learning', 28 | 'text-to-image', 29 | 'pytorch' 30 | ] 31 | ) -------------------------------------------------------------------------------- /target_model/min_dalle/tkinter_ui.py: -------------------------------------------------------------------------------- 1 | from min_dalle import MinDalle 2 | import sys 3 | import PIL 4 | import PIL.Image 5 | import PIL.ImageTk 6 | import tkinter 7 | from tkinter import ttk 8 | 9 | def regen_root(): 10 | global root 11 | global blank_image 12 | global padding_image 13 | 14 | root = tkinter.Tk() 15 | root.wm_resizable(False, False) 16 | 17 | blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 2, 256 * 2), mode="RGB")) 18 | padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA")) 19 | 20 | regen_root() 21 | 22 | is_mega = None 23 | def set_mega_true_and_destroy(): 24 | global is_mega 25 | is_mega = True 26 | root.destroy() 27 | def set_mega_false_and_destroy(): 28 | global is_mega 29 | is_mega = False 30 | root.destroy() 31 | 32 | frm = ttk.Frame(root, padding=16) 33 | frm.grid() 34 | ttk.Button(frm, text="Mega", command=set_mega_true_and_destroy).grid(column=0, row=0) 35 | ttk.Label(frm, image=padding_image).grid(column=1, row=0) 36 | ttk.Button(frm, text="Mini", command=set_mega_false_and_destroy).grid(column=2, row=0) 37 | root.mainloop() 38 | 39 | if is_mega is None: 40 | print("no option selected") 41 | sys.exit(0) 42 | 43 | #print("is_mega", is_mega) 44 | 45 | model = MinDalle( 46 | models_root="./pretrained", 47 | is_mega=is_mega, 48 | is_reusable=True, 49 | is_verbose=True 50 | ) 51 | 52 | regen_root() 53 | 54 | label_image_content = blank_image 55 | 56 | sv_prompt = tkinter.StringVar(value="artificial intelligence") 57 | sv_temperature = tkinter.StringVar(value="1") 58 | sv_topk = tkinter.StringVar(value="128") 59 | sv_supercond = tkinter.StringVar(value="16") 60 | bv_seamless = tkinter.BooleanVar(value=False) 61 | 62 | def generate(): 63 | # check fields 64 | try: 65 | temperature = float(sv_temperature.get()) 66 | except: 67 | sv_temperature.set("ERROR") 68 | return 69 | try: 70 | topk = int(sv_topk.get()) 71 | except: 72 | sv_topk.set("ERROR") 73 | return 74 | try: 75 | supercond = int(sv_supercond.get()) 76 | except: 77 | sv_supercond.set("ERROR") 78 | return 79 | try: 80 | is_seamless = bool(bv_seamless.get()) 81 | except: 82 | return 83 | # and continue 84 | global label_image_content 85 | image_stream = model.generate_image_stream( 86 | sv_prompt.get(), 87 | grid_size=2, 88 | seed=-1, 89 | progressive_outputs=True, 90 | is_seamless=is_seamless, 91 | temperature=temperature, 92 | top_k=topk, 93 | supercondition_factor=supercond, 94 | is_verbose=True 95 | ) 96 | for image in image_stream: 97 | global final_image 98 | final_image = image 99 | label_image_content = PIL.ImageTk.PhotoImage(image) 100 | label_image.configure(image=label_image_content) 101 | label_image.update() 102 | 103 | def save(): 104 | final_image.save('generated/out.png') 105 | 106 | frm = ttk.Frame(root, padding=16) 107 | frm.grid() 108 | 109 | props = ttk.Frame(frm) 110 | 111 | # outer structure (hbox) 112 | label_image = ttk.Label(frm, image=blank_image) 113 | label_image.grid(column=0, row=0) 114 | ttk.Label(frm, image=padding_image).grid(column=1, row=0) 115 | props.grid(column=2, row=0) 116 | 117 | # inner structure (properties and shit) 118 | # prompt field 119 | ttk.Label(props, text="Prompt:").grid(column=0, row=0) 120 | ttk.Entry(props, textvariable=sv_prompt).grid(column=1, row=0) 121 | # 122 | ttk.Label(props, image=padding_image).grid(column=0, row=1) 123 | # temperature field 124 | ttk.Label(props, text="Temperature:").grid(column=0, row=2) 125 | ttk.Entry(props, textvariable=sv_temperature).grid(column=1, row=2) 126 | # 127 | ttk.Label(props, image=padding_image).grid(column=0, row=3) 128 | # topk field 129 | ttk.Label(props, text="Top-K:").grid(column=0, row=4) 130 | ttk.Entry(props, textvariable=sv_topk).grid(column=1, row=4) 131 | # 132 | ttk.Label(props, image=padding_image).grid(column=0, row=5) 133 | # superconditioning field 134 | ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6) 135 | ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6) 136 | # 137 | ttk.Label(props, image=padding_image).grid(column=0, row=7) 138 | # seamless 139 | ttk.Label(props, text="Seamless:").grid(column=0, row=8) 140 | ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8) 141 | # 142 | ttk.Label(props, image=padding_image).grid(column=0, row=9) 143 | # buttons 144 | ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10) 145 | ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=10) 146 | ttk.Button(props, text="Save", command=save).grid(column=2, row=10) 147 | 148 | root.mainloop() --------------------------------------------------------------------------------