├── .gitignore ├── LICENSE ├── README.md ├── annotator ├── canny │ └── __init__.py ├── render_images.py └── util.py ├── app.py ├── cldm ├── cldm.py ├── ddim_hacked.py ├── hack.py ├── logger.py └── model.py ├── clip_score.py ├── configs ├── config.yaml ├── config_ema.yaml └── train_configs │ ├── laion_glyph_glyphcontrol_train.yaml │ └── textcaps_glyphcontrol_ablation.yaml ├── data └── README.md ├── environment.yaml ├── fonts ├── AlumniSans.ttf ├── DejaVuSans.ttf ├── NotoSans.ttf ├── ZCOOLXiaoWei.ttf └── calibri.ttf ├── glyph_instructions.yaml ├── inference.py ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── laion_glyph_control.py │ ├── simple.py │ ├── textcaps_control.py │ └── util.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ └── utils_image.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py └── util.py ├── main.py ├── ocr_acc.py ├── pretrained_models └── .md ├── readme_files ├── architecture-n.png ├── interface-clean.png ├── interface.png └── teaser_6.png ├── requirements.txt ├── scripts └── rendertext_tool.py ├── setup.py ├── test_on_benchmark.py └── text_prompts ├── paper ├── CreativeBench │ ├── all_unigram_100000_plus_100_prompt_file_GlyphDraw_origin_remove_render_words.txt │ ├── all_unigram_10000_100000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt │ ├── all_unigram_1000_10000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt │ └── all_unigram_top_1000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt └── SimpleBench │ ├── all_unigram_100000_plus_100_1_gram.txt │ ├── all_unigram_10000_100000_100_1_gram.txt │ ├── all_unigram_1000_10000_100_1_gram.txt │ └── all_unigram_top_1000_100_1_gram.txt └── raw ├── CreativeBench └── GlyphDraw_origin_remove_render_words.txt └── SimpleBench ├── all_unigram_100000_plus_100.txt ├── all_unigram_10000_100000_100.txt ├── all_unigram_1000_10000_100.txt └── all_unigram_top_1000_100.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.pyc 3 | amlt/ 4 | .amltconfig 5 | deepspeed/* 6 | *_bp.* 7 | amlt* 8 | src/clip* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AIGText 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above 
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | class CannyDetector: 5 | def __call__(self, img, low_threshold, high_threshold): 6 | return cv2.Canny(img, low_threshold, high_threshold) 7 | -------------------------------------------------------------------------------- /annotator/render_images.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFont, ImageDraw 2 | import random 3 | import numpy as np 4 | 5 | # resize height to image_height first, then shrink or pad to image_width 6 | def resize_and_pad_image(pil_image, image_size): 7 | 8 | if isinstance(image_size, (tuple, list)) and len(image_size) == 2: 9 | image_width, image_height = image_size 10 | elif isinstance(image_size, int): 11 | image_width = image_height = image_size 12 | else: 13 | raise ValueError(f"Image size should be int or list/tuple of int not {image_size}") 14 | 15 | while pil_image.size[1] >= 2 * image_height: 16 | pil_image = pil_image.resize( 17 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 18 | ) 19 | 20 | scale = image_height / pil_image.size[1] 21 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size),resample=Image.BICUBIC) 22 | 23 | # shrink 24 | if pil_image.size[0] > image_width: 25 | pil_image = pil_image.resize((image_width, image_height),resample=Image.BICUBIC) 26 | 27 | # padding 28 | if pil_image.size[0] < image_width: 29 | img = Image.new(mode="RGB",size=(image_width,image_height), color="white") 30 | width, _ = pil_image.size 31 | img.paste(pil_image,((image_width - width)//2, 0)) 32 | pil_image = img 33 | 34 | return pil_image 35 | 36 | def render_text_image_custom(image_size, bboxes, rendered_txt_values, num_rows_values, font_name="calibri", align = "center"): 37 | # aligns = ["center", "left", "right"] 38 | ''' 39 | Render text image based on the glyph instructions, i.e., the list of tuples (text, bbox, num_rows). 40 | Currently we just use Calibri font to render glyph images. 
41 | ''' 42 | print(image_size, bboxes, rendered_txt_values, num_rows_values, align) 43 | background = Image.new("RGB", image_size, "white") 44 | font = ImageFont.truetype(f"fonts/{font_name}.ttf", encoding='utf-8', size=512) 45 | 46 | for text, bbox, num_rows in zip(rendered_txt_values, bboxes, num_rows_values): 47 | 48 | if len(text) == 0: 49 | continue 50 | 51 | text = text.strip() 52 | if num_rows != 1: 53 | word_tokens = text.split() 54 | num_tokens = len(word_tokens) 55 | index_list = range(1, num_tokens + 1) 56 | if num_tokens > num_rows: 57 | index_list = random.sample(index_list, num_rows) 58 | index_list.sort() 59 | line_list = [] 60 | start_idx = 0 61 | for index in index_list: 62 | line_list.append( 63 | " ".join(word_tokens 64 | [start_idx: index] 65 | ) 66 | ) 67 | start_idx = index 68 | text = "\n".join(line_list) 69 | 70 | if 'ratio' not in bbox or bbox['ratio'] == 0 or bbox['ratio'] < 1e-4: 71 | image4ratio = Image.new("RGB", (512, 512), "white") 72 | draw = ImageDraw.Draw(image4ratio) 73 | _, _ , w, h = draw.textbbox(xy=(0,0),text = text, font=font) 74 | ratio = w / h 75 | else: 76 | ratio = bbox['ratio'] 77 | 78 | width = int(bbox['width'] * image_size[1]) 79 | height = int(width / ratio) 80 | top_left_x = int(bbox['top_left_x'] * image_size[0]) 81 | top_left_y = int(bbox['top_left_y'] * image_size[1]) 82 | yaw = bbox['yaw'] 83 | 84 | text_image = Image.new("RGB", (512, 512), "white") 85 | draw = ImageDraw.Draw(text_image) 86 | x,y,w,h = draw.textbbox(xy=(0,0),text = text, font=font) 87 | text_image = Image.new("RGB", (w, h), "white") 88 | draw = ImageDraw.Draw(text_image) 89 | draw.text((-x/2,-y/2), text, "black", font=font, align=align) 90 | text_image = resize_and_pad_image(text_image, (width, height)) 91 | text_image = text_image.rotate(angle=-yaw, expand=True, fillcolor="white") 92 | # image = Image.new("RGB", (w, h), "white") 93 | # draw = ImageDraw.Draw(image) 94 | 95 | background.paste(text_image, (top_left_x, top_left_y)) 96 | 97 | return background 98 | 99 | def render_text_image_laionglyph(image_size, ocrinfo, confidence_threshold=0.5): 100 | ''' 101 | Render the glyph image according to the ocr information for the samples in the LAIONGlyph Dataset 102 | ''' 103 | font = ImageFont.truetype("calibri.ttf", encoding='utf-8', size=512) 104 | background = Image.new("RGB", image_size, "white") 105 | 106 | for sub_ocr_info in ocrinfo: 107 | 108 | bbox, text, confidence = sub_ocr_info 109 | 110 | if confidence < confidence_threshold: 111 | continue 112 | 113 | # print(bbox, text, confidence) 114 | # Calculate the real size 115 | real_width = int(bbox[1][0] - bbox[0][0]) 116 | real_height = int(bbox[3][1] - bbox[0][1]) 117 | # Calculate the rotation parameter 118 | bbox_center = [(bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2] 119 | angle = np.arctan2(bbox[1][1] - bbox_center[1], bbox[1][0] - bbox_center[0]) * 180 / np.pi 120 | 121 | text_image = Image.new("RGB", (512, 512), "white") 122 | draw = ImageDraw.Draw(text_image) 123 | x,y,w,h = draw.textbbox(xy=(0,0),text = text, font=font) 124 | text_image = Image.new("RGB", (w, h), "white") 125 | draw = ImageDraw.Draw(text_image) 126 | draw.text((-x/2,-y/2), text, "black", font=font, align="center") 127 | 128 | text_image = resize_and_pad_image(text_image, (real_width, real_height)) 129 | text_image = text_image.rotate(angle=-angle, expand=True, fillcolor="white") 130 | background.paste(text_image, (int(bbox[0][0]), int(bbox[0][1]))) 131 | 132 | return background 
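# ------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): how
# render_text_image_custom above could be called with a single glyph
# instruction. The bbox keys and values mirror glyph_instructions.yaml;
# the output file name is an assumption.
# ------------------------------------------------------------------
if __name__ == "__main__":
    example_bbox = {
        "width": 0.3,        # relative width of the text box (acts as the font size)
        "ratio": 0,          # 0 -> fall back to the text's natural width/height ratio
        "top_left_x": 0.35,  # relative x coordinate of the top-left corner
        "top_left_y": 0.4,   # relative y coordinate of the top-left corner
        "yaw": 0,            # rotation angle in degrees
    }
    glyph_image = render_text_image_custom(
        image_size=(512, 512),
        bboxes=[example_bbox],
        rendered_txt_values=["APPLE"],
        num_rows_values=[1],
    )
    glyph_image.save("glyph_preview.png")  # assumed output path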
-------------------------------------------------------------------------------- /annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | -------------------------------------------------------------------------------- /cldm/hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | 4 | import ldm.modules.encoders.modules 5 | import ldm.modules.attention 6 | 7 | from transformers import logging 8 | from ldm.modules.attention import default 9 | 10 | 11 | def disable_verbosity(): 12 | logging.set_verbosity_error() 13 | print('logging improved.') 14 | return 15 | 16 | 17 | def enable_sliced_attention(): 18 | ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward 19 | print('Enabled sliced_attention.') 20 | return 21 | 22 | 23 | def hack_everything(clip_skip=0): 24 | disable_verbosity() 25 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward 26 | ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip 27 | print('Enabled clip hacks.') 28 | return 29 | 30 | 31 | # Written by Lvmin 32 | def _hacked_clip_forward(self, text): 33 | PAD = self.tokenizer.pad_token_id 34 | EOS = self.tokenizer.eos_token_id 35 | BOS = self.tokenizer.bos_token_id 36 | 37 | def tokenize(t): 38 | return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"] 39 | 40 | def transformer_encode(t): 41 | if self.clip_skip > 1: 42 | rt = self.transformer(input_ids=t, output_hidden_states=True) 43 | return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip]) 44 | else: 45 | return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state 46 | 47 | def split(x): 48 | return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3] 49 | 50 | def pad(x, p, i): 51 | return x[:i] if len(x) >= i else x + [p] * (i - len(x)) 52 | 53 | raw_tokens_list = tokenize(text) 54 | tokens_list = [] 55 | 56 | for raw_tokens in raw_tokens_list: 57 | raw_tokens_123 = split(raw_tokens) 58 | raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123] 59 | raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123] 60 | tokens_list.append(raw_tokens_123) 61 | 62 | tokens_list = torch.IntTensor(tokens_list).to(self.device) 63 | 64 | feed = einops.rearrange(tokens_list, 'b f i -> (b f) i') 65 | y = transformer_encode(feed) 66 | z = einops.rearrange(y, '(b f) i c -> b (f 
i) c', f=3) 67 | 68 | return z 69 | 70 | 71 | # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py 72 | def _hacked_sliced_attentin_forward(self, x, context=None, mask=None): 73 | h = self.heads 74 | 75 | q = self.to_q(x) 76 | context = default(context, x) 77 | k = self.to_k(context) 78 | v = self.to_v(context) 79 | del context, x 80 | 81 | q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 82 | 83 | limit = k.shape[0] 84 | att_step = 1 85 | q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0)) 86 | k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0)) 87 | v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0)) 88 | 89 | q_chunks.reverse() 90 | k_chunks.reverse() 91 | v_chunks.reverse() 92 | sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) 93 | del k, q, v 94 | for i in range(0, limit, att_step): 95 | q_buffer = q_chunks.pop() 96 | k_buffer = k_chunks.pop() 97 | v_buffer = v_chunks.pop() 98 | sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale 99 | 100 | del k_buffer, q_buffer 101 | # attention, what we cannot get enough of, by chunks 102 | 103 | sim_buffer = sim_buffer.softmax(dim=-1) 104 | 105 | sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) 106 | del v_buffer 107 | sim[i:i + att_step, :, :] = sim_buffer 108 | 109 | del sim_buffer 110 | sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h) 111 | return self.to_out(sim) 112 | -------------------------------------------------------------------------------- /cldm/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.utilities.distributed import rank_zero_only 9 | 10 | 11 | class ImageLogger(Callback): 12 | def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True, 13 | rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, 14 | log_images_kwargs=None): 15 | super().__init__() 16 | self.rescale = rescale 17 | self.batch_freq = batch_frequency 18 | self.max_images = max_images 19 | if not increase_log_steps: 20 | self.log_steps = [self.batch_freq] 21 | self.clamp = clamp 22 | self.disabled = disabled 23 | self.log_on_batch_idx = log_on_batch_idx 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | self.log_first_step = log_first_step 26 | 27 | @rank_zero_only 28 | def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): 29 | root = os.path.join(save_dir, "image_log", split) 30 | for k in images: 31 | grid = torchvision.utils.make_grid(images[k], nrow=4) 32 | if self.rescale: 33 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 34 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 35 | grid = grid.numpy() 36 | grid = (grid * 255).astype(np.uint8) 37 | filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) 38 | path = os.path.join(root, filename) 39 | os.makedirs(os.path.split(path)[0], exist_ok=True) 40 | Image.fromarray(grid).save(path) 41 | 42 | def log_img(self, pl_module, batch, batch_idx, split="train"): 43 | check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step 44 | if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 45 | 
hasattr(pl_module, "log_images") and 46 | callable(pl_module.log_images) and 47 | self.max_images > 0): 48 | logger = type(pl_module.logger) 49 | 50 | is_train = pl_module.training 51 | if is_train: 52 | pl_module.eval() 53 | 54 | with torch.no_grad(): 55 | images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) 56 | 57 | for k in images: 58 | N = min(images[k].shape[0], self.max_images) 59 | images[k] = images[k][:N] 60 | if isinstance(images[k], torch.Tensor): 61 | images[k] = images[k].detach().cpu() 62 | if self.clamp: 63 | images[k] = torch.clamp(images[k], -1., 1.) 64 | 65 | self.log_local(pl_module.logger.save_dir, split, images, 66 | pl_module.global_step, pl_module.current_epoch, batch_idx) 67 | 68 | if is_train: 69 | pl_module.train() 70 | 71 | def check_frequency(self, check_idx): 72 | return check_idx % self.batch_freq == 0 73 | 74 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 75 | if not self.disabled: 76 | self.log_img(pl_module, batch, batch_idx, split="train") 77 | -------------------------------------------------------------------------------- /cldm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from omegaconf import OmegaConf 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | def get_state_dict(d): 9 | return d.get('state_dict', d) 10 | 11 | 12 | def load_state_dict(ckpt_path, location='cpu'): 13 | _, extension = os.path.splitext(ckpt_path) 14 | if extension.lower() == ".safetensors": 15 | import safetensors.torch 16 | state_dict = safetensors.torch.load_file(ckpt_path, device=location) 17 | else: 18 | state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) 19 | state_dict = get_state_dict(state_dict) 20 | print(f'Loaded state_dict from [{ckpt_path}]') 21 | return state_dict 22 | 23 | 24 | def create_model(config_path): 25 | config = OmegaConf.load(config_path) 26 | model = instantiate_from_config(config.model).cpu() 27 | print(f'Loaded model config from [{config_path}]') 28 | return model 29 | -------------------------------------------------------------------------------- /clip_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the CLIP Scores 2 | 3 | The CLIP model is a contrasitively learned language-image model. There is 4 | an image encoder and a text encoder. It is believed that the CLIP model could 5 | measure the similarity of cross modalities. Please find more information from 6 | https://github.com/openai/CLIP. 7 | 8 | The CLIP Score measures the Cosine Similarity between two embedded features. 9 | This repository utilizes the pretrained CLIP Model to calculate 10 | the mean average of cosine similarities. 
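Concretely, for N image-text pairs the value reported below is

    CLIPScore = (1 / N) * sum_i  s * cos(f_img_i, f_txt_i)

where f_img_i and f_txt_i are the L2-normalized CLIP image and text embeddings
and s is the model's learned logit scale (see calculate_clip_score).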
11 | """ 12 | 13 | import os 14 | from PIL import Image 15 | import clip 16 | import torch 17 | from PIL import Image 18 | from torch.utils.data import Dataset, DataLoader 19 | from tqdm import tqdm 20 | from argparse import ArgumentParser 21 | 22 | IMAGE_EXTENSIONS = 'jpg' 23 | PROMPT_TYPE = {'Sign', 'GlyphDraw'} # "Sign": SimpleBench; "GlyphDraw": CreativeBench 24 | 25 | def parser_fn(): 26 | 27 | parser = ArgumentParser() 28 | parser.add_argument('--batch-size', type=int, default=16, 29 | help='Batch size to use') 30 | parser.add_argument('--clip-model', type=str, default='ViT-B/32', 31 | help='CLIP model to use') 32 | parser.add_argument('--num-workers', type=int, default=1, 33 | help='Number of processes') 34 | parser.add_argument('--prompt_type', type=str, default=None, 35 | help='Sign or GlyphDraw') 36 | parser.add_argument('--img_path', type=str, 37 | help='Image folder path') 38 | parser.add_argument('--img_path_multi', type=str, default=None, 39 | help='The path including multiple Image folder paths') 40 | parser.add_argument('--ckpt_name', type=str, default=None, 41 | help='The checkpoint name') 42 | return parser 43 | 44 | class DummyDataset(Dataset): 45 | 46 | def __init__(self, img_path, prompt_type, 47 | transform = None, 48 | tokenizer = None) -> None: 49 | super().__init__() 50 | 51 | if prompt_type is None: 52 | if "GlyphDraw" in img_path: 53 | prompt_type = 'GlyphDraw' 54 | else: 55 | prompt_type = 'Sign' 56 | 57 | assert prompt_type in PROMPT_TYPE 58 | self.img_path = img_path 59 | self.prompt_type = prompt_type 60 | self.transform = transform 61 | self.tokenizer = tokenizer 62 | if prompt_type == 'Sign': 63 | self._prepare_sign(img_path) 64 | if prompt_type == 'GlyphDraw': 65 | self._prepare_glyphdraw(img_path) 66 | print(f"{len(self.img_path_list)} image paths") 67 | print(f"First example:\n Image Path: {self.img_path_list[:1]}\n Prompt:{self.text_list[:1]}") 68 | 69 | assert len(self.img_path_list) == len(self.text_list) 70 | 71 | def _prepare_sign(self, img_path): 72 | 73 | self.img_path_list = [] 74 | self.text_list = [] 75 | 76 | for item in [i for i in os.listdir(img_path) if "." not in i]: 77 | path = os.path.join(img_path, item) 78 | for sub_item in [i for i in os.listdir(path) if IMAGE_EXTENSIONS in i and item in i and "glyph" not in i]: 79 | sub_path = os.path.join(path, sub_item) 80 | 81 | self.img_path_list.append(sub_path) 82 | self.text_list.append(f'A sign that says "{item}"') 83 | 84 | def _prepare_glyphdraw(self, img_path): 85 | 86 | self.img_path_list = [] 87 | self.text_list = [] 88 | 89 | for item in [i for i in os.listdir(img_path) if "." not in i]: 90 | path = os.path.join(img_path, item) 91 | 92 | with open(os.path.join(path, "prompt.txt"), 'r') as fp: 93 | prompt = fp.readline() 94 | 95 | for sub_item in [i for i in os.listdir(path) if IMAGE_EXTENSIONS in i and item in i and "glyph" not in i]: 96 | sub_path = os.path.join(path, sub_item) 97 | 98 | self.img_path_list.append(sub_path) 99 | self.text_list.append(prompt) 100 | 101 | def __len__(self): 102 | return len(self.img_path_list) 103 | 104 | def __getitem__(self, index): 105 | 106 | img_path = self.img_path_list[index] 107 | text = self.text_list[index] 108 | image = Image.open(img_path) 109 | 110 | if self.transform: 111 | image = self.transform(image) 112 | 113 | if self.tokenizer: 114 | text = self.tokenizer(text).squeeze() 115 | 116 | return image, text 117 | 118 | @torch.no_grad() 119 | def calculate_clip_score(dataloader, model, device): 120 | score_acc = 0. 121 | sample_num = 0. 
122 | logit_scale = model.logit_scale.exp() 123 | print(f"Clip Model logit_scale is:{logit_scale}") 124 | for image, text in dataloader: 125 | 126 | image_features = model.encode_image(image.to(device)) 127 | 128 | text_features = model.encode_text(text.to(device)) 129 | 130 | # normalize features 131 | image_features = image_features / image_features.norm(dim=1, keepdim=True).to(torch.float32) 132 | text_features = text_features / text_features.norm(dim=1, keepdim=True).to(torch.float32) 133 | 134 | # calculate scores 135 | score = logit_scale * (image_features * text_features).sum() 136 | score_acc += score 137 | sample_num += image.shape[0] 138 | 139 | return score_acc / sample_num 140 | 141 | def main(img_path, args): 142 | 143 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 144 | print(device) 145 | num_workers = args.num_workers 146 | print("--------------------------") 147 | print(f"Evaluating on the {img_path}") 148 | print('Loading CLIP model: {}'.format(args.clip_model)) 149 | model, preprocess = clip.load(args.clip_model, device=device) 150 | 151 | dataset = DummyDataset(img_path, args.prompt_type, 152 | transform=preprocess, tokenizer=clip.tokenize) 153 | 154 | if len(dataset) != 400: 155 | return 156 | dataloader = DataLoader(dataset, args.batch_size, 157 | num_workers=num_workers, pin_memory=True) 158 | dataloader = tqdm(dataloader) 159 | 160 | print('Calculating CLIP Score:') 161 | clip_score = calculate_clip_score(dataloader, model, device) 162 | clip_score = clip_score.cpu().item() 163 | print('CLIP Score: ', clip_score) 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | args = parser_fn().parse_args() 169 | if args.img_path_multi is not None: 170 | from glob import glob 171 | img_paths = glob(args.img_path_multi + "/*") 172 | for img_path in img_paths: 173 | if not os.path.isdir(img_path): 174 | print(img_path, "is not a directory") 175 | continue 176 | if args.ckpt_name is not None: 177 | img_path = os.path.join(img_path, args.ckpt_name) 178 | main(img_path, args) 179 | else: 180 | main(args.img_path, args) -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: cldm.cldm.ControlLDM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | control_key: "hint" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: #val/loss_simple_ema 18 | scale_factor: 0.18215 19 | only_mid_control: False 20 | sd_locked: True 21 | use_ema: False 22 | 23 | control_stage_config: 24 | target: cldm.cldm.ControlNet 25 | params: 26 | use_checkpoint: True 27 | image_size: 32 # unused 28 | in_channels: 4 29 | hint_channels: 3 30 | model_channels: 320 31 | attention_resolutions: [ 4, 2, 1 ] 32 | num_res_blocks: 2 33 | channel_mult: [ 1, 2, 4, 4 ] 34 | num_head_channels: 64 # need to fix for flash-attn 35 | use_spatial_transformer: True 36 | use_linear_in_transformer: True 37 | transformer_depth: 1 38 | context_dim: 1024 39 | legacy: False 40 | 41 | unet_config: 42 | target: cldm.cldm.ControlledUnetModel 43 | params: 44 | use_checkpoint: True 45 | image_size: 32 # unused 46 | in_channels: 4 47 | out_channels: 4 48 | model_channels: 320 49 | attention_resolutions: [ 4, 2, 1 ] 50 | num_res_blocks: 2 
51 | channel_mult: [ 1, 2, 4, 4 ] 52 | num_head_channels: 64 # need to fix for flash-attn 53 | use_spatial_transformer: True 54 | use_linear_in_transformer: True 55 | transformer_depth: 1 56 | context_dim: 1024 57 | legacy: False 58 | 59 | first_stage_config: 60 | target: ldm.models.autoencoder.AutoencoderKL 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | #attn_type: "vanilla-xformers" 66 | double_z: true 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: 73 | - 1 74 | - 2 75 | - 4 76 | - 4 77 | num_res_blocks: 2 78 | attn_resolutions: [] 79 | dropout: 0.0 80 | lossconfig: 81 | target: torch.nn.Identity 82 | 83 | cond_stage_config: 84 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 85 | params: 86 | freeze: True 87 | layer: "penultimate" 88 | # device: "cpu" 89 | -------------------------------------------------------------------------------- /configs/config_ema.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: cldm.cldm.ControlLDM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | control_key: "hint" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: #val/loss_simple_ema 18 | scale_factor: 0.18215 19 | only_mid_control: False 20 | sd_locked: True 21 | use_ema: True 22 | 23 | control_stage_config: 24 | target: cldm.cldm.ControlNet 25 | params: 26 | use_checkpoint: True 27 | image_size: 32 # unused 28 | in_channels: 4 29 | hint_channels: 3 30 | model_channels: 320 31 | attention_resolutions: [ 4, 2, 1 ] 32 | num_res_blocks: 2 33 | channel_mult: [ 1, 2, 4, 4 ] 34 | num_head_channels: 64 # need to fix for flash-attn 35 | use_spatial_transformer: True 36 | use_linear_in_transformer: True 37 | transformer_depth: 1 38 | context_dim: 1024 39 | legacy: False 40 | 41 | unet_config: 42 | target: cldm.cldm.ControlledUnetModel 43 | params: 44 | use_checkpoint: True 45 | image_size: 32 # unused 46 | in_channels: 4 47 | out_channels: 4 48 | model_channels: 320 49 | attention_resolutions: [ 4, 2, 1 ] 50 | num_res_blocks: 2 51 | channel_mult: [ 1, 2, 4, 4 ] 52 | num_head_channels: 64 # need to fix for flash-attn 53 | use_spatial_transformer: True 54 | use_linear_in_transformer: True 55 | transformer_depth: 1 56 | context_dim: 1024 57 | legacy: False 58 | 59 | first_stage_config: 60 | target: ldm.models.autoencoder.AutoencoderKL 61 | params: 62 | embed_dim: 4 63 | monitor: val/rec_loss 64 | ddconfig: 65 | #attn_type: "vanilla-xformers" 66 | double_z: true 67 | z_channels: 4 68 | resolution: 256 69 | in_channels: 3 70 | out_ch: 3 71 | ch: 128 72 | ch_mult: 73 | - 1 74 | - 2 75 | - 4 76 | - 4 77 | num_res_blocks: 2 78 | attn_resolutions: [] 79 | dropout: 0.0 80 | lossconfig: 81 | target: torch.nn.Identity 82 | 83 | cond_stage_config: 84 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 85 | params: 86 | freeze: True 87 | layer: "penultimate" 88 | # device: "cpu" 89 | -------------------------------------------------------------------------------- /configs/train_configs/laion_glyph_glyphcontrol_train.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: cldm.cldm.ControlLDM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 
0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | control_key: "hint" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: #val/loss_simple_ema 18 | scale_factor: 0.18215 19 | only_mid_control: False 20 | sd_locked: True #False 21 | use_ema: True #False 22 | ckpt_path: "pretrained_models/control_sd20_ini.ckpt" 23 | reset_ema: True 24 | reset_num_ema_updates: false 25 | keep_num_ema_updates: false 26 | only_model: false 27 | log_all_grad_norm: True 28 | sep_lr: True 29 | decoder_lr: 1.0e-4 30 | sep_cond_txt: True 31 | exchange_cond_txt: False 32 | 33 | scheduler_config: 34 | target: ldm.lr_scheduler.LambdaLinearScheduler 35 | params: 36 | warm_up_steps: [ 1 ] # NOTE 1 for resuming. use 10000 if starting from scratch 37 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 38 | f_start: [ 1.e-6 ] 39 | f_max: [ 1. ] 40 | f_min: [ 1. ] 41 | 42 | control_stage_config: 43 | target: cldm.cldm.ControlNet 44 | params: 45 | use_checkpoint: True 46 | image_size: 32 # unused 47 | in_channels: 4 48 | hint_channels: 3 49 | model_channels: 320 50 | attention_resolutions: [ 4, 2, 1 ] 51 | num_res_blocks: 2 52 | channel_mult: [ 1, 2, 4, 4 ] 53 | num_head_channels: 64 # need to fix for flash-attn 54 | use_spatial_transformer: True 55 | use_linear_in_transformer: True 56 | transformer_depth: 1 57 | context_dim: 1024 58 | legacy: False 59 | 60 | unet_config: 61 | target: cldm.cldm.ControlledUnetModel 62 | params: 63 | use_checkpoint: True 64 | # use_fp16: True # False 65 | image_size: 32 # unused 66 | in_channels: 4 67 | out_channels: 4 68 | model_channels: 320 69 | attention_resolutions: [ 4, 2, 1 ] 70 | num_res_blocks: 2 71 | channel_mult: [ 1, 2, 4, 4 ] 72 | num_head_channels: 64 # need to fix for flash-attn 73 | use_spatial_transformer: True 74 | use_linear_in_transformer: True 75 | transformer_depth: 1 76 | context_dim: 1024 77 | legacy: False 78 | 79 | first_stage_config: 80 | target: ldm.models.autoencoder.AutoencoderKL 81 | params: 82 | embed_dim: 4 83 | monitor: val/rec_loss 84 | ddconfig: 85 | #attn_type: "vanilla-xformers" 86 | double_z: true 87 | z_channels: 4 88 | resolution: 256 89 | in_channels: 3 90 | out_ch: 3 91 | ch: 128 92 | ch_mult: 93 | - 1 94 | - 2 95 | - 4 96 | - 4 97 | num_res_blocks: 2 98 | attn_resolutions: [] 99 | dropout: 0.0 100 | lossconfig: 101 | target: torch.nn.Identity 102 | 103 | cond_stage_config: 104 | target: ldm.modules.encoders.modules.FrozenOpenCLIPSepEncoder 105 | params: 106 | freeze: True 107 | layer: "penultimate" 108 | data: 109 | target: main.DataModuleFromConfig 110 | params: 111 | batch_size: 16 # (halve gradually if OOM for large batch size) 112 | num_workers: 2 113 | wrap: False 114 | custom_collate: True 115 | train: 116 | target: ldm.data.laion_glyph_control.LAIONGlyphCLDataset 117 | params: 118 | control_key: "hint" 119 | no_hint: False 120 | BLIP_caption: True 121 | new_proc_config: 122 | target: ldm.data.util.new_process_im_base 123 | params: 124 | size: 512 125 | interpolation: 3 126 | do_flip: False 127 | hint_range_m11: False 128 | rendered_txt_in_caption: False 129 | caption_choices: ["original", "original"] 130 | caption_drop_rates: [0.1, 0.5] 131 | rm_text_from_cp: False 132 | replace_token: "" 133 | 134 | 135 | lightning: 136 | callbacks: 137 | metrics_over_trainsteps_checkpoint: 138 | target: pytorch_lightning.callbacks.ModelCheckpoint 139 | params: 140 | 
every_n_train_steps: 1000 141 | 142 | modelcheckpoint: 143 | target: pytorch_lightning.callbacks.ModelCheckpoint 144 | params: 145 | save_top_k: -1 146 | trainer: 147 | benchmark: True 148 | # max_epochs: 100 149 | accumulate_grad_batches: 2 # (set > 1 if OOM for large batch size) 150 | deterministic: True 151 | profiler: "simple" 152 | log_every_n_steps: 3 153 | # num_nodes: 8 -------------------------------------------------------------------------------- /configs/train_configs/textcaps_glyphcontrol_ablation.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: cldm.cldm.ControlLDM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | control_key: "hint" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: false 16 | conditioning_key: crossattn 17 | monitor: #val/loss_simple_ema 18 | scale_factor: 0.18215 19 | only_mid_control: False 20 | sd_locked: False # unlock the original U-Net decoder 21 | use_ema: True #False 22 | # ckpt_path: "pretrained_models/control_sd20_ini.ckpt" 23 | ckpt_path: {laionglyph_pretrained_model} 24 | reset_ema: True 25 | reset_num_ema_updates: false 26 | keep_num_ema_updates: false #True 27 | only_model: false 28 | log_all_grad_norm: True 29 | sep_lr: True 30 | decoder_lr: 1.0e-4 31 | sep_cond_txt: True 32 | exchange_cond_txt: False 33 | 34 | scheduler_config: 35 | target: ldm.lr_scheduler.LambdaLinearScheduler 36 | params: 37 | warm_up_steps: [ 1 ] # NOTE 1 for resuming. use 10000 if starting from scratch 38 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 39 | f_start: [ 1.e-6 ] 40 | f_max: [ 1. ] 41 | f_min: [ 1. 
] 42 | 43 | control_stage_config: 44 | target: cldm.cldm.ControlNet 45 | params: 46 | use_checkpoint: True 47 | image_size: 32 # unused 48 | in_channels: 4 49 | hint_channels: 3 50 | model_channels: 320 51 | attention_resolutions: [ 4, 2, 1 ] 52 | num_res_blocks: 2 53 | channel_mult: [ 1, 2, 4, 4 ] 54 | num_head_channels: 64 # need to fix for flash-attn 55 | use_spatial_transformer: True 56 | use_linear_in_transformer: True 57 | transformer_depth: 1 58 | context_dim: 1024 59 | legacy: False 60 | 61 | unet_config: 62 | target: cldm.cldm.ControlledUnetModel 63 | params: 64 | use_checkpoint: True 65 | image_size: 32 # unused 66 | in_channels: 4 67 | out_channels: 4 68 | model_channels: 320 69 | attention_resolutions: [ 4, 2, 1 ] 70 | num_res_blocks: 2 71 | channel_mult: [ 1, 2, 4, 4 ] 72 | num_head_channels: 64 # need to fix for flash-attn 73 | use_spatial_transformer: True 74 | use_linear_in_transformer: True 75 | transformer_depth: 1 76 | context_dim: 1024 77 | legacy: False 78 | 79 | first_stage_config: 80 | target: ldm.models.autoencoder.AutoencoderKL 81 | params: 82 | embed_dim: 4 83 | monitor: val/rec_loss 84 | ddconfig: 85 | #attn_type: "vanilla-xformers" 86 | double_z: true 87 | z_channels: 4 88 | resolution: 256 89 | in_channels: 3 90 | out_ch: 3 91 | ch: 128 92 | ch_mult: 93 | - 1 94 | - 2 95 | - 4 96 | - 4 97 | num_res_blocks: 2 98 | attn_resolutions: [] 99 | dropout: 0.0 100 | lossconfig: 101 | target: torch.nn.Identity 102 | 103 | cond_stage_config: 104 | target: ldm.modules.encoders.modules.FrozenOpenCLIPSepEncoder 105 | params: 106 | freeze: True 107 | layer: "penultimate" 108 | data: 109 | target: main.DataModuleFromConfig 110 | params: 111 | batch_size: 4 112 | num_workers: 2 113 | wrap: False 114 | custom_collate: True 115 | train: 116 | target: ldm.data.textcaps_control.TextCapsCLDataset 117 | params: 118 | control_key: "hint" 119 | no_hint: False 120 | filter_data: False 121 | filter_words: ["sign", "poster", "book"] 122 | OneCapPerImage: True 123 | filter_token_num: False 124 | max_token_num: 5 125 | do_new_proc: True 126 | new_proc_config: 127 | target: ldm.data.util.new_process_im 128 | params: 129 | size: 512 130 | interpolation: 3 131 | do_flip: False 132 | hint_range_m11: False 133 | imagenet_proc: False 134 | new_ocr_info: True 135 | rendered_txt_in_caption: False 136 | caption_choices: ["original", "original"] 137 | caption_drop_rates: [0.1, 0.5] 138 | 139 | 140 | lightning: 141 | callbacks: 142 | metrics_over_trainsteps_checkpoint: 143 | target: pytorch_lightning.callbacks.ModelCheckpoint 144 | params: 145 | every_n_train_steps: 300 #1000 146 | 147 | modelcheckpoint: 148 | target: pytorch_lightning.callbacks.ModelCheckpoint 149 | params: 150 | save_top_k: -1 151 | 152 | 153 | trainer: 154 | benchmark: True 155 | # max_epochs: 100 156 | accumulate_grad_batches: 1 #2 157 | deterministic: True 158 | profiler: "simple" 159 | log_every_n_steps: 3 -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # LAION-Glyph Dataset 2 | 3 | * **LAION-Glyph 1M** 4 | 5 | File name: ```LAION-Glyph-1M.json```. 
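A record can be read with a minimal sketch like the one below (the local path is an assumption; the field names follow the schema that follows):

```python
import base64
import json
from io import BytesIO

from PIL import Image

with open("LAION-Glyph-1M.json") as f:   # assumed local path
    sample = json.load(f)[0]             # one record of the list documented below

# decode the base64-encoded image
image = Image.open(BytesIO(base64.b64decode(sample["img_code"])))
print(sample["caption_blip"])            # caption generated by BLIP-2

# each OCR entry: [[top_left, top_right, lower_right, lower_left], [text, confidence]]
for bbox, (text, confidence) in sample["ocr_info"]:
    print(text, confidence, bbox)
```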
6 | 7 | ``` 8 | [ 9 | { 10 | "img_id": sample id with '\t' seprating two parts,e.g., "part-00012 00002014175" 11 | 12 | "img_code": the base64 code of the image, use Image.open(BytesIO(base64.b64decode(img_code))) to decode the original image 13 | 14 | "caption_origin": original caption provided by LAION dataset 15 | 16 | "caption_blip": the caption generated by BLIP-2 17 | 18 | "ocr_info": the information of multiple detected OCR bounding boxes, the format for each box: [ 19 | [top_left, top_right, lower_right, lower_left], 20 | [text, confidence] 21 | ] 22 | e.g:[ 23 | [[[102.0, 36.0], [250.0, 36.0], [250.0, 67.0], [102.0, 67.0]], ['BALTIMORE', 0.9966500401496887]], 24 | [[[31.0, 75.0], [321.0, 75.0], [321.0, 102.0], [31.0, 102.0]], ['BUSINESSJOURNAL', 0.9743010997772217]] 25 | ] 26 | }, 27 | ... 28 | ] 29 | ``` 30 | * **LAION-Glyph 10M** 31 | 32 | There are 10 files in total. Each contains 1M samples with the same format like LAION-Glyph 1M. 33 | File name: ```LAION-Glyph-10M_x.json```. (x = 0-9) 34 | 35 | [**Notes**] 36 | * Since each json file has large size (~100GB), it would be better to split each json file into multiple (e.g., 10 or 100) json files with smaller size. 37 | 38 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: GlyphControl 2 | channels: 3 | - pytorch 4 | # - defaults 5 | - conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 11 | - nvidia/label/cuda-11.3.1 12 | - xformers 13 | # - xformers/label/dev 14 | dependencies: #TODO: reformat and recheck 15 | - python=3.9 16 | - pip #=20.3 17 | - cudatoolkit=11.3 18 | - cuda-nvcc=11.3 19 | - pytorch=1.12.1 20 | - torchvision=0.13.1 21 | - torchaudio=0.12.1 22 | - numpy=1.23.1 23 | - xformers 24 | - pip: 25 | - albumentations==1.3.0 26 | - opencv-python==4.6.0.66 # 27 | - deepspeed # TODO 28 | - imageio==2.9.0 29 | - imageio-ffmpeg==0.4.2 30 | - pytorch-lightning==1.6.5 31 | - omegaconf==2.1.1 32 | - test-tube>=0.7.5 33 | - einops==0.3.0 34 | - transformers==4.24.0 35 | - open_clip_torch==2.0.2 36 | - torchmetrics==0.6.0 37 | - timm 38 | - gradio 39 | - wandb 40 | - tqdm 41 | - easyocr 42 | - triton==2.0.0.dev20221120 43 | - Levenshtein 44 | - py-cpuinfo 45 | - hjson 46 | - git+https://github.com/openai/CLIP.git 47 | - -i https://pypi.tuna.tsinghua.edu.cn/simple 48 | # - -e . 
49 | -------------------------------------------------------------------------------- /fonts/AlumniSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/fonts/AlumniSans.ttf -------------------------------------------------------------------------------- /fonts/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/fonts/DejaVuSans.ttf -------------------------------------------------------------------------------- /fonts/NotoSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/fonts/NotoSans.ttf -------------------------------------------------------------------------------- /fonts/ZCOOLXiaoWei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/fonts/ZCOOLXiaoWei.ttf -------------------------------------------------------------------------------- /fonts/calibri.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/fonts/calibri.ttf -------------------------------------------------------------------------------- /glyph_instructions.yaml: -------------------------------------------------------------------------------- 1 | Instructions: 2 | rendered_txt_values: ["APPLE"] 3 | # the width of the OCR box (i.e., the font size) 4 | width_values: [0.3] 5 | # the width-height ratio of the OCR box, if the ratio == 0, the ratio will be set as optimal ratio 6 | ratio_values: [0] 7 | # the relative coordinates of the top left corner of the OCR box 8 | top_left_x_values: [0.35] 9 | top_left_y_values: [0.4] 10 | # the yaw rotation angle of the OCR box ([-20, 20]) 11 | yaw_values: [0] 12 | # the number of rows where the text will be placed 13 | num_rows_values: [1] -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # GlyphControl: Glyph Conditional Control for Visual Text Generation 3 | # Paper Link: https://arxiv.org/pdf/2305.18259 4 | # Code Link: https://github.com/AIGText/GlyphControl-release 5 | # This script is used for inference. 
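# A hypothetical invocation (flags and defaults mirror parse_args below; paths are assumptions):
#   python inference.py --cfg configs/config.yaml \
#       --ckpt checkpoints/laion10M_epoch_6_model_ema_only.ckpt \
#       --glyph_instructions glyph_instructions.yaml \
#       --prompt "A sign that says 'APPLE'" --num_samples 4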
6 | # ------------------------------------------ 7 | 8 | 9 | import torch 10 | import time 11 | from PIL import Image 12 | from cldm.hack import disable_verbosity, enable_sliced_attention 13 | from scripts.rendertext_tool import Render_Text, load_model_from_config 14 | from omegaconf import OmegaConf 15 | import argparse 16 | import os 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--cfg", 21 | type=str, 22 | default="configs/config.yaml", 23 | help="path to model config", 24 | ) 25 | parser.add_argument( 26 | "--ckpt", 27 | type=str, 28 | default="checkpoints/laion10M_epoch_6_model_ema_only.ckpt", 29 | help="path to checkpoint of model", 30 | ) 31 | parser.add_argument( 32 | "--save_path", 33 | type=str, 34 | default="generated_images", 35 | help="where to save images" 36 | ) 37 | parser.add_argument( 38 | "--save_memory", 39 | type=str, 40 | default="whether to save memory by transferring some unused parts of models to the cpu device during inference", 41 | help="path to checkpoint of model", 42 | ) 43 | # specify the inference settings 44 | parser.add_argument( 45 | "--glyph_instructions", 46 | type=str, 47 | default=None, #"glyph_instructions.yaml", 48 | help="path to glyph instructions", 49 | ) 50 | parser.add_argument( 51 | "--prompt", 52 | type=str, 53 | nargs="?", 54 | default="A sign that says 'APPLE'", 55 | help="the prompt" 56 | ) 57 | parser.add_argument( 58 | "--num_samples", 59 | type=int, 60 | default=4, 61 | help="how many samples to produce for each given prompt. A.k.a batch size", 62 | ) 63 | parser.add_argument( 64 | "--a_prompt", 65 | type=str, 66 | default='4K, dslr, best quality, extremely detailed', 67 | help="additional prompt" 68 | ) 69 | parser.add_argument( 70 | "--n_prompt", 71 | type=str, 72 | default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality', 73 | help="negative prompt" 74 | ) 75 | parser.add_argument( 76 | "--image_resolution", 77 | type=int, 78 | default=512, 79 | help="image resolution", 80 | ) 81 | parser.add_argument( 82 | "--strength", 83 | type=float, 84 | default=1, 85 | help="control strength", 86 | ) 87 | parser.add_argument( 88 | "--scale", 89 | type=float, 90 | default=9.0, 91 | help="classifier-free guidance scale", 92 | ) 93 | parser.add_argument( 94 | "--ddim_steps", 95 | type=int, 96 | default=20, 97 | help="ddim steps", 98 | ) 99 | parser.add_argument( 100 | "--seed", 101 | type=int, 102 | default=0, 103 | help="seed", 104 | ) 105 | parser.add_argument( 106 | "--guess_mode", 107 | action="store_true", 108 | help="whether use guess mode", 109 | ) 110 | parser.add_argument( 111 | "--eta", 112 | type=float, 113 | default=0, 114 | help="eta", 115 | ) 116 | args = parser.parse_args() 117 | return args 118 | 119 | if __name__ == "__main__": 120 | args = parse_args() 121 | disable_verbosity() 122 | if args.save_memory: 123 | enable_sliced_attention() 124 | 125 | try: 126 | # Glyph Instructions 127 | glyph_instructions = OmegaConf.load(args.glyph_instructions).Instructions 128 | # print(glyph_instructions) 129 | rendered_txt_values = glyph_instructions.rendered_txt_values 130 | width_values = glyph_instructions.width_values 131 | ratio_values = glyph_instructions.ratio_values 132 | top_left_x_values = glyph_instructions.top_left_x_values 133 | top_left_y_values = glyph_instructions.top_left_y_values 134 | yaw_values = glyph_instructions.yaw_values 135 | num_rows_values = glyph_instructions.num_rows_values 136 | # 
print(rendered_txt_values, width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values, num_rows_values) 137 | except Exception as e: 138 | print(e) 139 | rendered_txt_values = [""] 140 | width_values, ratio_values, top_left_x_values, top_left_y_values, yaw_values, num_rows_values = [None] * 6 141 | 142 | cfg = OmegaConf.load(args.cfg) 143 | model = load_model_from_config(cfg, args.ckpt, verbose=True) 144 | render_tool = Render_Text(model, save_memory = args.save_memory) 145 | 146 | # Render glyph images and generate corresponding visual text 147 | # print(args.prompt) 148 | results = render_tool.process_multi(rendered_txt_values, args.prompt, 149 | width_values, ratio_values, 150 | top_left_x_values, top_left_y_values, 151 | yaw_values, num_rows_values, 152 | args.num_samples, args.image_resolution, 153 | args.ddim_steps, args.guess_mode, 154 | args.strength, args.scale, args.seed, 155 | args.eta, args.a_prompt, args.n_prompt) 156 | 157 | 158 | result_path = os.path.join(args.save_path, args.prompt) 159 | os.makedirs(result_path, exist_ok=True) 160 | render_none = len([1 for rendered_txt in rendered_txt_values if rendered_txt != ""]) == 0 161 | if render_none: 162 | for idx, result in enumerate(results): 163 | result_im = Image.fromarray(result) 164 | result_im.save(os.path.join(result_path, f"{idx}.jpg")) 165 | else: 166 | rendered_txt_join = "_".join(rendered_txt_values) 167 | results[0].save(os.path.join(result_path, f"{rendered_txt_join}_glyph_image.jpg")) 168 | for idx, result in enumerate(results[1:]): 169 | result_im = Image.fromarray(result) 170 | result_im.save(os.path.join(result_path, f"{rendered_txt_join}_{idx}.jpg")) 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | from PIL import Image, ImageFile 4 | from pathlib import Path 5 | from functools import partial 6 | from torchvision import transforms as T, utils 7 | from torch import nn 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def cycle(dl): 13 | while True: 14 | for data in dl: 15 | yield data 16 | 17 | def convert_image_to(img_type, image): 18 | if image.mode != img_type: 19 | return image.convert(img_type) 20 | return image 21 | 22 | class Txt2ImgIterableBaseDataset(IterableDataset): 23 | ''' 24 | Define an interface to make the IterableDatasets for text2img data chainable 25 | ''' 26 | def __init__(self, num_records=0, valid_ids=None, size=256): 27 | super().__init__() 28 | self.num_records = num_records 29 | self.valid_ids = valid_ids 30 | self.sample_ids = valid_ids 31 | self.size = size 32 | 33 | # print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 34 | 35 | # def __len__(self): 36 | # return self.num_records 37 | 38 | @abstractmethod 39 | def __iter__(self): 40 | pass 41 | 42 | class BaseDataset(Dataset): 43 | def __init__( 44 | self, 45 | folder, 46 | image_size, 47 | exts = ['jpg', 'jpeg', 'png', 'tiff'], 48 | convert_image_to_type = None 49 | ): 50 | super().__init__() 
51 | self.folder = folder 52 | self.image_size = image_size 53 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 54 | 55 | convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity() 56 | 57 | self.transform = T.Compose([ 58 | T.Lambda(convert_fn), 59 | T.Resize(image_size), 60 | T.RandomHorizontalFlip(), 61 | T.CenterCrop(image_size), 62 | T.ToTensor() 63 | ]) 64 | 65 | def __len__(self): 66 | return len(self.paths) 67 | 68 | def __getitem__(self, index): 69 | path = self.paths[index] 70 | img = Image.open(path) 71 | return self.transform(img) -------------------------------------------------------------------------------- /ldm/data/laion_glyph_control.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | from omegaconf import DictConfig, ListConfig 4 | import torch 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import json 8 | from PIL import Image 9 | from torchvision import transforms 10 | from einops import rearrange 11 | from ldm.util import instantiate_from_config 12 | # from datasets import load_dataset 13 | import os 14 | from collections import defaultdict 15 | import cv2 16 | import albumentations 17 | import random 18 | from ldm.data.util import new_process_im_base #, imagenet_process_im 19 | from glob import glob 20 | import random 21 | import base64 22 | from io import BytesIO 23 | from annotator.render_images import render_text_image_laionglyph 24 | 25 | 26 | class LAIONGlyphCLDataset(Dataset): 27 | 28 | ''' 29 | data class for LAIONGlyph dataset 30 | 31 | Input: 32 | data_folder: the folder storing the data json files. 33 | data_info_file: the tsv file recording the location of each sample 34 | The file for 10M dataset should look like this: 35 | LAION-Glyph-10M_0.json\t0 36 | LAION-Glyph-10M_0.json\t1 37 | ... 38 | LAION-Glyph-10M_1.json\t0 39 | LAION-Glyph-10M_1.json\t1 40 | ... 
41 | 42 | ''' 43 | def __init__(self, 44 | 45 | data_folder, 46 | data_info_file, 47 | 48 | max_num_samples = -1, 49 | no_hint = False, 50 | 51 | first_stage_key = "jpg", 52 | cond_stage_key = "txt", 53 | control_key = "hint", 54 | BLIP_caption = False, #True, 55 | ocr_threshold = 0.5, 56 | 57 | rendered_txt_in_caption = False, 58 | caption_choices = ["original", "w_rend_text", "wo_rend_text"], 59 | caption_drop_rates = [0.1, 0.5, 0.1], 60 | 61 | postprocess=None, 62 | new_proc_config = None, 63 | rm_text_from_cp = False, 64 | replace_token = "", 65 | ) -> None: 66 | with open(data_info_file, "r") as f: 67 | data_infos = f.readlines() 68 | if max_num_samples > 0: 69 | data_infos = random.sample(data_infos, max_num_samples) 70 | self.data_infos = data_infos 71 | self.data_folder = data_folder 72 | 73 | self.ocr_threshold = ocr_threshold # the threshold of OCR recognition confidence 74 | self.no_hint = no_hint 75 | 76 | self.caption_choices = caption_choices 77 | self.caption_drop_rates = caption_drop_rates # random drop caption 78 | self.rendered_txt_in_caption = rendered_txt_in_caption 79 | self.BLIP_caption = BLIP_caption # whether to use the captions generated by BLIP-2 80 | 81 | self.first_stage_key = first_stage_key 82 | self.cond_stage_key = cond_stage_key 83 | self.control_key = control_key 84 | 85 | # postprocess 86 | if isinstance(postprocess, DictConfig): 87 | postprocess = instantiate_from_config(postprocess) 88 | self.postprocess = postprocess 89 | 90 | # image transform 91 | if new_proc_config is not None: 92 | self.new_proc_func = instantiate_from_config(new_proc_config) 93 | else: 94 | self.new_proc_func = new_process_im_base() 95 | 96 | self.rm_text_from_cp = rm_text_from_cp 97 | self.replace_token = replace_token 98 | 99 | 100 | def __len__(self): 101 | return len(self.data_infos) 102 | 103 | def __getitem__(self, index): 104 | data = {} 105 | # data info 106 | data_info = self.data_infos[index] 107 | filename, idx_in_file = data_info.split("\t")[:] 108 | idx_in_file = int(idx_in_file.strip()) 109 | with open(os.path.join(self.data_folder, filename), "r") as f: 110 | ori_data = json.load(f)[idx_in_file] 111 | img_id = ori_data["img_id"] 112 | 113 | # 1. Load the original image 114 | img_code = ori_data["img_code"] 115 | try: 116 | ori_img = Image.open(BytesIO(base64.b64decode(img_code))) 117 | except: 118 | print("can't open original image: {}".format(img_id)) 119 | return self.__getitem__(np.random.choice(self.__len__())) 120 | 121 | # 2. Load the caption 122 | if self.BLIP_caption: 123 | caption_ori = ori_data["caption_blip"] 124 | else: 125 | caption_ori = ori_data["caption_origin"] 126 | img_size = ori_img.size 127 | 128 | # 3. 
Load ocr info 129 | ocr_info = data["ocr_info"] 130 | 131 | pos_info_list = [] 132 | pos_info_tuples = [] 133 | for info in ocr_info: 134 | bbox, (text, confidence) = info 135 | if confidence > self.ocr_threshold: 136 | xy_info = np.array(bbox) 137 | min_x, min_y = np.min(xy_info, axis = 0).astype(int) 138 | max_x, max_y = np.max(xy_info, axis = 0).astype(int) 139 | pos_info_list.append( 140 | [min_x, min_y, max_x, max_y] 141 | ) 142 | mean_xy = (xy_info[0] + xy_info[2]) / 2 143 | lf = xy_info[0, 0] # min_x 144 | pos_info_tuples.append((text, 0.2 * lf + mean_xy[1])) #0.15 145 | ocr_txt = info[1] 146 | 147 | pos_info_list = np.array(pos_info_list) 148 | all_lf, all_up = np.min(pos_info_list[:, :2], axis = 0) 149 | all_rg, all_dn = np.max(pos_info_list[:, 2:], axis = 0) 150 | all_pos_info = [all_lf, all_up, all_rg, all_dn] 151 | 152 | # hint glyph image 153 | if not self.no_hint: 154 | try: 155 | hint_img = render_text_image_laionglyph( 156 | img_size, ocr_info, self.ocr_threshold 157 | ) 158 | except: 159 | print("can't render hint image: {}".format(img_id)) 160 | return self.__getitem__(np.random.choice(self.__len__())) 161 | else: 162 | hint_img = None 163 | 164 | assert all_pos_info 165 | im, im_hint = self.new_proc_func(ori_img, all_pos_info, hint_img) 166 | 167 | if not self.no_hint: 168 | assert im_hint is not None 169 | data[self.control_key] = im_hint 170 | data[self.first_stage_key] = im 171 | 172 | caption_wr_text = None 173 | arrange_tokens = [item[0] for item in (sorted(pos_info_tuples, key=lambda x: x[1]))] 174 | if self.rendered_txt_in_caption: 175 | valid_words = " ".join(arrange_tokens) 176 | caption_wr_text = caption_ori + '. Words in the image: "{}"'.format(valid_words) 177 | 178 | # process the ori 179 | caption_wo_text = None # 180 | if self.rm_text_from_cp and self.BLIP_caption: 181 | # [Default: False] remove the rendered words from the caption while using BLIP captions 182 | caption_items = caption_ori.split(" ") 183 | lower_arrange_tokens = [tk.lower() for tk in arrange_tokens] 184 | caption_wo_text = [] 185 | for cp_item in caption_items: 186 | if cp_item.lower() in lower_arrange_tokens: 187 | if self.replace_token != "": 188 | caption_wo_text.append(self.replace_token) 189 | else: 190 | caption_wo_text.append(cp_item) 191 | caption_wo_text = " ".join(caption_wo_text) 192 | prompt_list = [] 193 | for i in range(len(self.caption_choices)): 194 | cc = self.caption_choices[i] 195 | if cc == "original": 196 | caption = caption_ori 197 | elif cc == "w_rend_text": 198 | caption = caption_wr_text if caption_wr_text is not None else caption_ori 199 | elif cc == "wo_rend_text": 200 | caption = caption_wo_text if caption_wo_text is not None else caption_ori 201 | 202 | if torch.rand(1) < self.caption_drop_rates[i]: 203 | caption = "" 204 | prompt_list.append(caption) 205 | 206 | data[self.cond_stage_key] = prompt_list if len(prompt_list) > 1 else prompt_list[0] 207 | 208 | if self.postprocess is not None: 209 | data = self.postprocess(data) 210 | 211 | return data 212 | -------------------------------------------------------------------------------- /ldm/data/simple.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | from omegaconf import DictConfig, ListConfig 4 | import torch 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import json 8 | from PIL import Image 9 | from torchvision import transforms 10 | from einops import rearrange 11 | from ldm.util import 
instantiate_from_config 12 | from datasets import load_dataset 13 | 14 | def make_multi_folder_data(paths, caption_files=None, **kwargs): 15 | """Make a concat dataset from multiple folders 16 | Don't support captions yet 17 | 18 | If paths is a list, that's ok, if it's a Dict interpret it as: 19 | k=folder v=n_times to repeat that 20 | """ 21 | list_of_paths = [] 22 | if isinstance(paths, (Dict, DictConfig)): 23 | assert caption_files is None, \ 24 | "Caption files not yet supported for repeats" 25 | for folder_path, repeats in paths.items(): 26 | list_of_paths.extend([folder_path]*repeats) 27 | paths = list_of_paths 28 | 29 | if caption_files is not None: 30 | datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] 31 | else: 32 | datasets = [FolderData(p, **kwargs) for p in paths] 33 | return torch.utils.data.ConcatDataset(datasets) 34 | 35 | class FolderData(Dataset): 36 | def __init__(self, 37 | root_dir, 38 | caption_file=None, 39 | image_transforms=[], 40 | ext="jpg", 41 | default_caption="", 42 | postprocess=None, 43 | return_paths=False, 44 | ) -> None: 45 | """Create a dataset from a folder of images. 46 | If you pass in a root directory it will be searched for images 47 | ending in ext (ext can be a list) 48 | """ 49 | self.root_dir = Path(root_dir) 50 | self.default_caption = default_caption 51 | self.return_paths = return_paths 52 | if isinstance(postprocess, DictConfig): 53 | postprocess = instantiate_from_config(postprocess) 54 | self.postprocess = postprocess 55 | if caption_file is not None: 56 | with open(caption_file, "rt") as f: 57 | ext = Path(caption_file).suffix.lower() 58 | if ext == ".json": 59 | captions = json.load(f) 60 | elif ext == ".jsonl": 61 | lines = f.readlines() 62 | lines = [json.loads(x) for x in lines] 63 | captions = {x["file_name"]: x["text"].strip("\n") for x in lines} 64 | else: 65 | raise ValueError(f"Unrecognised format: {ext}") 66 | self.captions = captions 67 | else: 68 | self.captions = None 69 | 70 | if not isinstance(ext, (tuple, list, ListConfig)): 71 | ext = [ext] 72 | 73 | # Only used if there is no caption file 74 | self.paths = [] 75 | for e in ext: 76 | self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) 77 | if isinstance(image_transforms, ListConfig): 78 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 79 | image_transforms.extend([transforms.ToTensor(), 80 | transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) 81 | image_transforms = transforms.Compose(image_transforms) 82 | self.tform = image_transforms 83 | 84 | 85 | def __len__(self): 86 | if self.captions is not None: 87 | return len(self.captions.keys()) 88 | else: 89 | return len(self.paths) 90 | 91 | def __getitem__(self, index): 92 | data = {} 93 | if self.captions is not None: 94 | chosen = list(self.captions.keys())[index] 95 | caption = self.captions.get(chosen, None) 96 | if caption is None: 97 | caption = self.default_caption 98 | filename = self.root_dir/chosen 99 | else: 100 | filename = self.paths[index] 101 | 102 | if self.return_paths: 103 | data["path"] = str(filename) 104 | 105 | im = Image.open(filename) 106 | im = self.process_im(im) 107 | data["image"] = im 108 | 109 | if self.captions is not None: 110 | data["txt"] = caption 111 | else: 112 | data["txt"] = self.default_caption 113 | 114 | if self.postprocess is not None: 115 | data = self.postprocess(data) 116 | 117 | return data 118 | 119 | def process_im(self, im): 120 | im = im.convert("RGB") 121 | return self.tform(im) 122 | 123 | def hf_dataset( 124 | name, 125 | image_transforms=[], 126 | image_column="image", 127 | text_column="text", 128 | split='train', 129 | image_key='image', 130 | caption_key='txt', 131 | ): 132 | """Make huggingface dataset with appropriate list of transforms applied 133 | """ 134 | ds = load_dataset(name, split=split) 135 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 136 | image_transforms.extend([transforms.ToTensor(), 137 | transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) 138 | tform = transforms.Compose(image_transforms) 139 | 140 | assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" 141 | assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" 142 | 143 | def pre_process(examples): 144 | processed = {} 145 | processed[image_key] = [tform(im) for im in examples[image_column]] 146 | processed[caption_key] = examples[text_column] 147 | return processed 148 | 149 | ds.set_transform(pre_process) 150 | return ds 151 | 152 | class TextOnly(Dataset): 153 | def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): 154 | """Returns only captions with dummy images""" 155 | self.output_size = output_size 156 | self.image_key = image_key 157 | self.caption_key = caption_key 158 | if isinstance(captions, Path): 159 | self.captions = self._load_caption_file(captions) 160 | else: 161 | self.captions = captions 162 | 163 | if n_gpus > 1: 164 | # hack to make sure that all the captions appear on each gpu 165 | repeated = [n_gpus*[x] for x in self.captions] 166 | self.captions = [] 167 | [self.captions.extend(x) for x in repeated] 168 | 169 | def __len__(self): 170 | return len(self.captions) 171 | 172 | def __getitem__(self, index): 173 | dummy_im = torch.zeros(3, self.output_size, self.output_size) 174 | dummy_im = rearrange(dummy_im * 2. 
- 1., 'c h w -> h w c') 175 | return {self.image_key: dummy_im, self.caption_key: self.captions[index]} 176 | 177 | def _load_caption_file(self, filename): 178 | with open(filename, 'rt') as f: 179 | captions = f.readlines() 180 | return [x.strip('\n') for x in captions] -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | import albumentations 5 | from torchvision import transforms 6 | from PIL import Image 7 | import numpy as np 8 | from einops import rearrange 9 | import cv2 10 | from ldm.util import instantiate_from_config 11 | from omegaconf import ListConfig 12 | from open_clip.transform import ResizeMaxSize 13 | class AddMiDaS(object): 14 | def __init__(self, model_type): 15 | super().__init__() 16 | self.transform = load_midas_transform(model_type) 17 | 18 | def pt2np(self, x): 19 | x = ((x + 1.0) * .5).detach().cpu().numpy() 20 | return x 21 | 22 | def np2pt(self, x): 23 | x = torch.from_numpy(x) * 2 - 1. 24 | return x 25 | 26 | def __call__(self, sample): 27 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 28 | x = self.pt2np(sample['jpg']) 29 | x = self.transform({"image": x})["image"] 30 | sample['midas_in'] = x 31 | return sample 32 | 33 | class new_process_im_base: 34 | def __init__(self, 35 | size = 512, 36 | interpolation = 3, 37 | do_flip = True, 38 | flip_p = 0.5, 39 | hint_range_m11 = False, 40 | ): 41 | self.do_flip = do_flip 42 | self.flip_p = flip_p 43 | self.rescale = transforms.Resize(size=size, interpolation=interpolation) 44 | if self.do_flip: 45 | self.flip = transforms.functional.hflip 46 | # base_tf [-1, 1] 47 | base_tf_m11 = [ transforms.ToTensor(), # to be checked 48 | transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))] 49 | self.base_tf_m11 = transforms.Compose(base_tf_m11) 50 | # base_tf [0, 1] 51 | base_tf_01 = [ transforms.ToTensor(), # to be checked 52 | transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c'))] 53 | self.base_tf_01 = transforms.Compose(base_tf_01) 54 | self.hint_range_m11 = hint_range_m11 55 | 56 | def __call__(self, im, pos_info, im_hint = None): 57 | # im = Image.open(filename) 58 | im = im.convert("RGB") 59 | # crop 60 | size = im.size 61 | crop_size = min(size) 62 | crop_axis = size.index(crop_size) 63 | lf, up, rg, dn = pos_info 64 | if crop_axis == 0: 65 | # width 66 | box_up, box_dn = self.generate_range(up, dn, size[1], size[0]) 67 | box_lf, box_rg = 0, size[0] 68 | else: 69 | box_lf, box_rg = self.generate_range(lf, rg, size[0], size[1]) 70 | box_up, box_dn = 0, size[1] 71 | im = im.crop((box_lf, box_up, box_rg, box_dn)) 72 | # rescale 73 | im = self.rescale(im) 74 | # 75 | flip_img = False 76 | if self.do_flip: 77 | if torch.rand(1) < self.flip_p: 78 | im = self.flip(im) 79 | flip_img = True 80 | im = self.base_tf_m11(im) 81 | # im_hint = None 82 | # if hint_filename is not None: 83 | # im_hint = Image.open(hint_filename) 84 | if im_hint is not None: 85 | im_hint = im_hint.convert("RGB") 86 | im_hint = im_hint.crop((box_lf, box_up, box_rg, box_dn)) 87 | im_hint = self.rescale(im_hint) 88 | if flip_img: 89 | im_hint = self.flip(im_hint) 90 | im_hint = self.base_tf_m11(im_hint) if self.hint_range_m11 else self.base_tf_01(im_hint) 91 | return im, im_hint 92 | 93 | def generate_range(self, low, high, len_max, len_min): 94 | mid = (low + high) / 2 * (len_max if high <= 1 else 1) 95 | max_range = min(mid + len_min / 2, len_max) 96 | min_range = min( 97 | max(mid - len_min / 2, 0 ), 98 | max(max_range - len_min, 0) 99 | ) 100 | return int(min_range), int(min_range + len_min) 101 | 102 | class new_process_im(new_process_im_base): 103 | def __call__(self, filename, pos_info, hint_filename = None): 104 | im = Image.open(filename) 105 | if hint_filename is not None: 106 | im_hint = Image.open(hint_filename) 107 | else: 108 | im_hint = None 109 | return super().__call__(im, pos_info, im_hint) 110 | 111 | class imagenet_process_im: 112 | def __init__(self, 113 | size = 512, 114 | do_flip = False, 115 | min_crop_f=0.5, 116 | max_crop_f=1., 117 | flip_p=0.5, 118 | random_crop=False 119 | ): 120 | 121 | self.do_flip = do_flip 122 | if self.do_flip: 123 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 124 | # self.base = self.get_base() 125 | 126 | # self.size = size 127 | self.min_crop_f = min_crop_f 128 | self.max_crop_f = max_crop_f 129 | assert(max_crop_f <= 1.) 130 | self.center_crop = not random_crop 131 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 132 | self.size = size 133 | 134 | def __call__(self, im): 135 | im = im.convert("RGB") 136 | image = np.array(im).astype(np.uint8) 137 | # if image.shape[0] < self.size or image.shape[1] < self.size: 138 | # return None 139 | # crop 140 | min_side_len = min(image.shape[:2]) 141 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) 142 | crop_side_len = int(crop_side_len) 143 | if self.center_crop: 144 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) 145 | else: 146 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) 147 | image = self.cropper(image=image)["image"] # ? 
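# Note: crop_side_len is re-sampled on every call from [min_crop_f, max_crop_f] * the shorter image side, so each sample gets a differently sized square crop before the SmallestMaxSize rescale below.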
148 | # rescale 149 | image = self.image_rescaler(image=image)["image"] 150 | # flip 151 | if self.do_flip: 152 | image = self.flip(Image.fromarray(image)) 153 | image = np.array(image).astype(np.uint8) 154 | return (image/127.5 - 1.0).astype(np.float32) 155 | 156 | # used for CLIP image encoder 157 | class process_wb_im: 158 | def __init__(self, 159 | size = 224, 160 | # do_padding = True, 161 | image_transforms=[], 162 | use_clip_resize=False, 163 | image_mean = None, 164 | image_std = None, 165 | exchange_channel = True, 166 | ): 167 | self.image_rescaler = albumentations.LongestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 168 | self.image_size = size 169 | # self.do_padding = do_padding 170 | self.pad = albumentations.PadIfNeeded(min_height= self.image_size, min_width=self.image_size, 171 | border_mode=cv2.BORDER_CONSTANT, value= (255, 255, 255), 172 | ) 173 | if isinstance(image_transforms, ListConfig): 174 | image_transforms = [instantiate_from_config(tt) for tt in image_transforms] 175 | image_transforms.extend([ 176 | transforms.ToTensor(), 177 | transforms.Normalize( 178 | mean= image_mean if image_mean is not None else (0.48145466, 0.4578275, 0.40821073), 179 | std= image_std if image_std is not None else (0.26862954, 0.26130258, 0.27577711) 180 | ), 181 | ]) 182 | # transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c')) 183 | # ]) # transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) 184 | if exchange_channel: 185 | image_transforms.append( 186 | transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c')) 187 | ) 188 | image_transforms = transforms.Compose(image_transforms) 189 | self.tform = image_transforms 190 | self.use_clip_resize = use_clip_resize 191 | self.clip_resize = ResizeMaxSize(max_size = self.image_size, interpolation=transforms.InterpolationMode.BICUBIC, fill=(255, 255, 255)) 192 | 193 | def __call__(self, im): 194 | im = im.convert("RGB") 195 | # if self.do_padding: 196 | # im = self.padding_image(im) 197 | if self.use_clip_resize: 198 | im = self.clip_resize(im) 199 | else: 200 | im = self.padding_image(im) 201 | return self.tform(im) 202 | 203 | 204 | def padding_image(self, im): 205 | # resize 206 | im = np.array(im).astype(np.uint8) 207 | im_rescaled = self.image_rescaler(image=im)["image"] 208 | # padding 209 | im_padded = self.pad(image=im_rescaled)["image"] 210 | return im_padded 211 | 212 | # use for VQ-GAN 213 | class vqgan_process_im: 214 | def __init__(self, size=384, random_crop=False, augment=False, ori_preprocessor = False, to_tensor=False): 215 | self.size = size 216 | self.random_crop = random_crop 217 | self.augment = augment 218 | assert self.size is not None and self.size > 0 219 | if ori_preprocessor: 220 | # if self.size is not None and self.size > 0: 221 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 222 | if not self.random_crop: # train 223 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 224 | else: # test 225 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 226 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 227 | # else: 228 | # self.preprocessor = lambda **kwargs: kwargs 229 | else: 230 | self.rescaler = albumentations.LongestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 231 | self.pad = albumentations.PadIfNeeded(min_height= self.size, min_width=self.size, 232 | border_mode=cv2.BORDER_CONSTANT, value= (255, 255, 255), 233 | ) 234 | self.preprocessor = 
albumentations.Compose([self.rescaler, self.pad]) 235 | 236 | if self.augment: # train 237 | # Add data aug transformations 238 | self.data_augmentation = albumentations.Compose([ 239 | albumentations.GaussianBlur(p=0.1), 240 | albumentations.OneOf([ 241 | albumentations.HueSaturationValue (p=0.3), 242 | albumentations.ToGray(p=0.3), 243 | albumentations.ChannelShuffle(p=0.3) 244 | ], p=0.3) 245 | ]) 246 | 247 | if to_tensor: 248 | self.tform = transforms.ToTensor() 249 | self.to_tensor = to_tensor 250 | 251 | # if exchange_channel: 252 | # self.exchange_channel = transforms.Lambda(lambda x: rearrange(x, 'c h w -> h w c')) 253 | 254 | def __call__(self, image): 255 | image = image.convert("RGB") 256 | image = np.array(image).astype(np.uint8) 257 | image = self.preprocessor(image=image)["image"] 258 | if self.augment: 259 | image = self.data_augmentation(image=image)['image'] 260 | image = (image/127.5 - 1.0).astype(np.float32) 261 | if self.to_tensor: 262 | image = self.tform(image) 263 | return image -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
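# last_f keeps the most recently returned multiplier so the periodic logging in schedule() can report it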
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 
| score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
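e.g. a tensor of shape (B,) with target_dims=4 is returned with shape (B, 1, 1, 1).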
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. - betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
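# This file collects the schedule and layer helpers shared by the diffusion modules: beta schedules, DDIM timestep/sigma selection, gradient checkpointing, sinusoidal timestep embeddings, and small factories such as conv_nd, linear, avg_pool_nd and normalization.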
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 126 | "dtype": torch.get_autocast_gpu_dtype(), 127 | "cache_enabled": torch.is_autocast_cache_enabled()} 128 | with torch.no_grad(): 129 | output_tensors = ctx.run_function(*ctx.input_tensors) 130 | return output_tensors 131 | 132 | @staticmethod 133 | def backward(ctx, *output_grads): 134 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 135 | with torch.enable_grad(), \ 136 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 137 | # Fixes a bug where the first op in run_function modifies the 138 | # Tensor storage in place, which is not allowed for detach()'d 139 | # Tensors. 140 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 141 | output_tensors = ctx.run_function(*shallow_copies) 142 | input_grads = torch.autograd.grad( 143 | output_tensors, 144 | ctx.input_tensors + ctx.input_params, 145 | output_grads, 146 | allow_unused=True, 147 | ) 148 | del ctx.input_tensors 149 | del ctx.input_params 150 | del output_tensors 151 | return (None, None) + input_grads 152 | 153 | 154 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 155 | """ 156 | Create sinusoidal timestep embeddings. 157 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 158 | These may be fractional. 159 | :param dim: the dimension of the output. 160 | :param max_period: controls the minimum frequency of the embeddings. 161 | :return: an [N x dim] Tensor of positional embeddings. 162 | """ 163 | if not repeat_only: 164 | half = dim // 2 165 | freqs = torch.exp( 166 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 167 | ).to(device=timesteps.device) 168 | args = timesteps[:, None].float() * freqs[None] 169 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 170 | if dim % 2: 171 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 172 | else: 173 | embedding = repeat(timesteps, 'b -> b d', d=dim) 174 | return embedding 175 | 176 | 177 | def zero_module(module): 178 | """ 179 | Zero out the parameters of a module and return it. 
180 | """ 181 | for p in module.parameters(): 182 | p.detach().zero_() 183 | return module 184 | 185 | def identity_init_fc(module): 186 | """ 187 | initial weights of a fc module as 1 and bias as 0. 188 | """ 189 | nn.init.eye_(module.weight) 190 | nn.init.constant(module.bias, 0) 191 | # for p in module.parameters(): 192 | # nn.init.ones_(p) 193 | return module 194 | 195 | def scale_module(module, scale): 196 | """ 197 | Scale the parameters of a module and return it. 198 | """ 199 | for p in module.parameters(): 200 | p.detach().mul_(scale) 201 | return module 202 | 203 | 204 | def mean_flat(tensor): 205 | """ 206 | Take the mean over all non-batch dimensions. 207 | """ 208 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 209 | 210 | 211 | def normalization(channels): 212 | """ 213 | Make a standard normalization layer. 214 | :param channels: number of input channels. 215 | :return: an nn.Module for normalization. 216 | """ 217 | return GroupNorm32(32, channels) 218 | 219 | 220 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 221 | class SiLU(nn.Module): 222 | def forward(self, x): 223 | return x * torch.sigmoid(x) 224 | 225 | 226 | class GroupNorm32(nn.GroupNorm): 227 | def forward(self, x): 228 | return super().forward(x.float()).type(x.dtype) 229 | 230 | def conv_nd(dims, *args, **kwargs): 231 | """ 232 | Create a 1D, 2D, or 3D convolution module. 233 | """ 234 | if dims == 1: 235 | return nn.Conv1d(*args, **kwargs) 236 | elif dims == 2: 237 | return nn.Conv2d(*args, **kwargs) 238 | elif dims == 3: 239 | return nn.Conv3d(*args, **kwargs) 240 | raise ValueError(f"unsupported dimensions: {dims}") 241 | 242 | 243 | def linear(*args, **kwargs): 244 | """ 245 | Create a linear module. 246 | """ 247 | return nn.Linear(*args, **kwargs) 248 | 249 | 250 | def avg_pool_nd(dims, *args, **kwargs): 251 | """ 252 | Create a 1D, 2D, or 3D average pooling module. 
253 | """ 254 | if dims == 1: 255 | return nn.AvgPool1d(*args, **kwargs) 256 | elif dims == 2: 257 | return nn.AvgPool2d(*args, **kwargs) 258 | elif dims == 3: 259 | return nn.AvgPool3d(*args, **kwargs) 260 | raise ValueError(f"unsupported dimensions: {dims}") 261 | 262 | 263 | class HybridConditioner(nn.Module): 264 | 265 | def __init__(self, c_concat_config, c_crossattn_config): 266 | super().__init__() 267 | self.concat_conditioner = instantiate_from_config(c_concat_config) 268 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 269 | 270 | def forward(self, c_concat, c_crossattn): 271 | c_concat = self.concat_conditioner(c_concat) 272 | c_crossattn = self.crossattn_conditioner(c_crossattn) 273 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 274 | 275 | 276 | def noise_like(shape, device, repeat=False): 277 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 278 | noise = lambda: torch.randn(shape, device=device) 279 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, init_num_updates = 0, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(init_num_updates, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) # 0 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | if decay == self.decay: 37 | print("ema_num_updates: ", self.num_updates, "decay: ", decay) 38 | 39 | one_minus_decay = 1.0 - decay 40 | 41 | with torch.no_grad(): 42 | m_param = dict(model.named_parameters()) 43 | shadow_params = dict(self.named_buffers()) 44 | 45 | for key in m_param: 46 | if m_param[key].requires_grad: 47 | sname = self.m_name2s_name[key] 48 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 49 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 50 | else: 51 | assert not key in self.m_name2s_name 52 | 53 | def copy_to(self, model): 54 | m_param = dict(model.named_parameters()) 55 | shadow_params = dict(self.named_buffers()) 56 | for key in m_param: 57 | if m_param[key].requires_grad: 58 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 59 | else: 60 | assert not key in self.m_name2s_name 61 | 62 | def store(self, parameters): 63 | """ 64 | Save the current parameters for restoring later. 65 | Args: 66 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 67 | temporarily stored. 
68 | """ 69 | self.collected_params = [param.clone() for param in parameters] 70 | 71 | def restore(self, parameters): 72 | """ 73 | Restore the parameters stored with the `store` method. 74 | Useful to validate the model with EMA parameters without affecting the 75 | original optimization process. Store the parameters before the 76 | `copy_to` method. After validation (or model saving), use this to 77 | restore the former parameters. 78 | Args: 79 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 80 | updated with the stored parameters. 81 | """ 82 | for c_param, param in zip(self.collected_params, parameters): 83 | param.data.copy_(c_param.data) 84 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type == "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 
47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 
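# x: (B, 3, H, W); the raw MiDaS prediction is (B, H, W), unsqueezed to (B, 1, H, W) and interpolated back to the input resolution below (see the shape assertion).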
160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | class BaseModel(torch.nn.Module): 6 | def load(self, path): 7 | """Load model from file. 8 | 9 | Args: 10 | path (str): file path 11 | """ 12 | parameters = torch.load(path, map_location=torch.device('cpu')) 13 | 14 | if "optimizer" in parameters: 15 | parameters = parameters["model"] 16 | 17 | self.load_state_dict(parameters) 18 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | 
out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 | 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 
179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): path to file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write(("PF\n" if color else "Pf\n").encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy).
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.dtype) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 20 | nc = int(40 * (wh[0] / 256)) 21 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 22 | 23 | try: 24 | draw.text((0, 0), lines, fill="black", font=font) 25 | except UnicodeEncodeError: 26 | print("Can't encode string for logging. Skipping.") 27 | 28 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 29 | txts.append(txt) 30 | txts = np.stack(txts) 31 | txts = torch.tensor(txts) 32 | return txts 33 | 34 | 35 | def ismap(x): 36 | if not isinstance(x, torch.Tensor): 37 | return False 38 | return (len(x.shape) == 4) and (x.shape[1] > 3) 39 | 40 | 41 | def isimage(x): 42 | if not isinstance(x,torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 45 | 46 | 47 | def exists(x): 48 | return x is not None 49 | 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if isfunction(d) else d 55 | 56 | 57 | def mean_flat(tensor): 58 | """ 59 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 60 | Take the mean over all non-batch dimensions.
61 | """ 62 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 63 | 64 | 65 | def count_params(model, verbose=False): 66 | total_params = sum(p.numel() for p in model.parameters()) 67 | if verbose: 68 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 69 | return total_params 70 | 71 | 72 | def instantiate_from_config(config): 73 | if not "target" in config: 74 | if config == '__is_first_stage__': 75 | return None 76 | elif config == "__is_unconditional__": 77 | return None 78 | raise KeyError("Expected key `target` to instantiate.") 79 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 80 | 81 | 82 | def get_obj_from_str(string, reload=False): 83 | module, cls = string.rsplit(".", 1) 84 | if reload: 85 | module_imp = importlib.import_module(module) 86 | importlib.reload(module_imp) 87 | return getattr(importlib.import_module(module, package=None), cls) 88 | 89 | 90 | class AdamWwithEMAandWings(optim.Optimizer): 91 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 92 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 93 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 94 | ema_power=1., param_names=()): 95 | """AdamW that saves EMA versions of the parameters.""" 96 | if not 0.0 <= lr: 97 | raise ValueError("Invalid learning rate: {}".format(lr)) 98 | if not 0.0 <= eps: 99 | raise ValueError("Invalid epsilon value: {}".format(eps)) 100 | if not 0.0 <= betas[0] < 1.0: 101 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 102 | if not 0.0 <= betas[1] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 104 | if not 0.0 <= weight_decay: 105 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 106 | if not 0.0 <= ema_decay <= 1.0: 107 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 108 | defaults = dict(lr=lr, betas=betas, eps=eps, 109 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 110 | ema_power=ema_power, param_names=param_names) 111 | super().__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super().__setstate__(state) 115 | for group in self.param_groups: 116 | group.setdefault('amsgrad', False) 117 | 118 | @torch.no_grad() 119 | def step(self, closure=None): 120 | """Performs a single optimization step. 121 | Args: 122 | closure (callable, optional): A closure that reevaluates the model 123 | and returns the loss. 
124 | """ 125 | loss = None 126 | if closure is not None: 127 | with torch.enable_grad(): 128 | loss = closure() 129 | 130 | for group in self.param_groups: 131 | params_with_grad = [] 132 | grads = [] 133 | exp_avgs = [] 134 | exp_avg_sqs = [] 135 | ema_params_with_grad = [] 136 | state_sums = [] 137 | max_exp_avg_sqs = [] 138 | state_steps = [] 139 | amsgrad = group['amsgrad'] 140 | beta1, beta2 = group['betas'] 141 | ema_decay = group['ema_decay'] 142 | ema_power = group['ema_power'] 143 | 144 | for p in group['params']: 145 | if p.grad is None: 146 | continue 147 | params_with_grad.append(p) 148 | if p.grad.is_sparse: 149 | raise RuntimeError('AdamW does not support sparse gradients') 150 | grads.append(p.grad) 151 | 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | state['step'] = 0 157 | # Exponential moving average of gradient values 158 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 159 | # Exponential moving average of squared gradient values 160 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 161 | if amsgrad: 162 | # Maintains max of all exp. moving avg. of sq. grad. values 163 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 164 | # Exponential moving average of parameter values 165 | state['param_exp_avg'] = p.detach().float().clone() 166 | 167 | exp_avgs.append(state['exp_avg']) 168 | exp_avg_sqs.append(state['exp_avg_sq']) 169 | ema_params_with_grad.append(state['param_exp_avg']) 170 | 171 | if amsgrad: 172 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 173 | 174 | # update the steps for each param group update 175 | state['step'] += 1 176 | # record the step after step update 177 | state_steps.append(state['step']) 178 | 179 | optim._functional.adamw(params_with_grad, 180 | grads, 181 | exp_avgs, 182 | exp_avg_sqs, 183 | max_exp_avg_sqs, 184 | state_steps, 185 | amsgrad=amsgrad, 186 | beta1=beta1, 187 | beta2=beta2, 188 | lr=group['lr'], 189 | weight_decay=group['weight_decay'], 190 | eps=group['eps'], 191 | maximize=False) 192 | 193 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 194 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 195 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 196 | 197 | return loss 198 | 199 | def islistortuple(item): 200 | if isinstance(item, list) or isinstance(item, tuple): 201 | return True 202 | else: 203 | return False -------------------------------------------------------------------------------- /ocr_acc.py: -------------------------------------------------------------------------------- 1 | import easyocr 2 | import os 3 | import argparse 4 | from PIL import Image 5 | import numpy as np 6 | import Levenshtein as lev 7 | 8 | class AverageMeter(object): 9 | '''Computes and stores the average and current value''' 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def __repr__(self) -> str: 26 | return str(self.avg) 27 | 28 | class OCR_EM_Counter(object): 29 | '''Computes and stores the OCR Exactly Match Accuracy.''' 30 | def __init__(self): 31 | self.reset() 32 | 33 | def reset(self): 34 | self.ocr_acc_em = {} 35 | self.ocr_acc_em_rate = 0 36 | 37 | def add_text(self, text): 38 | if text 
not in self.ocr_acc_em: 39 | self.ocr_acc_em[text] = AverageMeter() 40 | 41 | def update(self, text, ocr_result): 42 | ocr_texts = [item[1] for item in ocr_result] 43 | self.ocr_acc_em[text].update(text in ocr_texts) 44 | self.ocr_acc_em_rate = sum([value.sum for value in self.ocr_acc_em.values()]) / sum([value.count for value in self.ocr_acc_em.values()]) 45 | 46 | def __repr__(self) -> str: 47 | ocr_str = ",".join([f"{key}:{repr(value)}" for key, value in self.ocr_acc_em.items()]) 48 | return f"OCR Accuracy is {ocr_str}.\nOCR EM Accuracy is {self.ocr_acc_em_rate}." 49 | # return f"OCR EM Accuracy is {self.ocr_acc_em_rate}." 50 | 51 | class OCR_EM_without_capitalization_Counter(object): 52 | '''Computes and stores the OCR Exactly Match Accuracy.''' 53 | def __init__(self): 54 | self.reset() 55 | 56 | def reset(self): 57 | self.ocr_acc_em = {} 58 | self.ocr_acc_em_rate = 0 59 | 60 | def add_text(self, text): 61 | if text not in self.ocr_acc_em: 62 | self.ocr_acc_em[text] = AverageMeter() 63 | 64 | def update(self, text, ocr_result): 65 | ocr_texts = [item[1].lower() for item in ocr_result] 66 | self.ocr_acc_em[text].update(text.lower() in ocr_texts) 67 | self.ocr_acc_em_rate = sum([value.sum for value in self.ocr_acc_em.values()]) / sum([value.count for value in self.ocr_acc_em.values()]) 68 | 69 | def __repr__(self) -> str: 70 | ocr_str = ",".join([f"{key}:{repr(value)}" for key, value in self.ocr_acc_em.items()]) 71 | return f"OCR without capitalization Accuracy is {ocr_str}.\nOCR EM without capitalization Accuracy is {self.ocr_acc_em_rate}." 72 | 73 | class OCR_Levenshtein_Distance(object): 74 | '''Computes and stores the OCR Levenshtein Distance Accuracy.''' 75 | def __init__(self): 76 | self.reset() 77 | 78 | def reset(self): 79 | self.ocr_lev = {} 80 | self.ocr_lev_avg = 0 81 | 82 | def add_text(self, text): 83 | if text not in self.ocr_lev: 84 | self.ocr_lev[text] = AverageMeter() 85 | 86 | def update(self, text, ocr_result): 87 | ocr_texts = [item[1] for item in ocr_result] 88 | lev_distance = [lev.distance(text, ocr_text) for ocr_text in ocr_texts] 89 | if lev_distance: 90 | self.ocr_lev[text].update(min(lev_distance)) 91 | self.ocr_lev_avg = sum([value.sum for value in self.ocr_lev.values()]) / sum([value.count for value in self.ocr_lev.values()]) 92 | 93 | def __repr__(self) -> str: 94 | return f"The Average Levenshtein Distance between Groundtruth and OCR result is {self.ocr_lev_avg}." 
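# The three metric classes above share the same protocol: add_text() registers a
# ground-truth string, update() folds in one EasyOCR result, and __repr__ reports
# the aggregate. A minimal sketch (assuming a local "word.png" exists):
#
#     reader = easyocr.Reader(['en'])
#     em = OCR_EM_Counter()
#     em.add_text("word")
#     em.update("word", reader.readtext(np.array(Image.open("word.png"))))
#     print(em)
#
# The driver below applies the same loop over a directory laid out as
# <path>/<ground-truth word>/<image>.png, scoring up to --num images per word.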
95 | 96 | if __name__ == "__main__": 97 | 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('--path', type=str, default = "evaluate/images/stablediffusion_DrawText_Spelling_0.01_0.1_random", help='data file path') 100 | parser.add_argument('--num', type=int, default = 4, help='num per words') 101 | args = parser.parse_args() 102 | 103 | reader = easyocr.Reader(['en']) 104 | print(f"Evaluate on {args.path}.") 105 | ocr_em_counter = OCR_EM_Counter() 106 | ocr_em_wc_counter = OCR_EM_without_capitalization_Counter() 107 | ocr_lev = OCR_Levenshtein_Distance() 108 | for item in os.listdir(args.path): 109 | text = item 110 | path = os.path.join(args.path, item) 111 | ocr_em_counter.add_text(text) 112 | ocr_em_wc_counter.add_text(text) 113 | ocr_lev.add_text(text) 114 | for sub_item in [item for item in os.listdir(path) if ".png" in item][:args.num]: 115 | sub_path = os.path.join(path, sub_item) 116 | try: 117 | image = Image.open(sub_path) 118 | except: 119 | continue 120 | image_array = np.array(image) 121 | ocr_result = reader.readtext(image_array) 122 | ocr_em_counter.update(text, ocr_result) 123 | ocr_em_wc_counter.update(text, ocr_result) 124 | ocr_lev.update(text, ocr_result) 125 | 126 | print(ocr_em_counter) 127 | print(ocr_em_wc_counter) 128 | print(ocr_lev) -------------------------------------------------------------------------------- /pretrained_models/.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/pretrained_models/.md -------------------------------------------------------------------------------- /readme_files/architecture-n.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/readme_files/architecture-n.png -------------------------------------------------------------------------------- /readme_files/interface-clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/readme_files/interface-clean.png -------------------------------------------------------------------------------- /readme_files/interface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/readme_files/interface.png -------------------------------------------------------------------------------- /readme_files/teaser_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIGText/GlyphControl-release/5891d9b245ac9209c22e3d6d5e933b95d428a230/readme_files/teaser_6.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip #=20.3 2 | torch 3 | torchvision 4 | numpy==1.23.1 5 | albumentations==1.3.0 6 | opencv-python 7 | imageio==2.9.0 8 | imageio-ffmpeg==0.4.2 9 | pytorch-lightning==1.6.5 10 | omegaconf==2.1.1 11 | test-tube>=0.7.5 12 | einops==0.3.0 13 | transformers==4.24.0 14 | open_clip_torch==2.0.2 15 | torchmetrics==0.6.0 16 | timm 17 | gradio 18 | wandb 19 | tqdm 20 | easyocr 21 | triton==2.0.0.dev20221120 22 | Levenshtein 23 | xformers 24 | # 
==0.0.21.dev541 25 | # new 26 | py-cpuinfo 27 | hjson 28 | pydantic<2.0.0 29 | git+https://github.com/openai/CLIP.git 30 | # pip uninstall nvidia_cublas_cu11 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='glyphcontrol', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /text_prompts/paper/CreativeBench/all_unigram_100000_plus_100_prompt_file_GlyphDraw_origin_remove_render_words.txt: -------------------------------------------------------------------------------- 1 | Black and white sign with the words "activeware" on a white background, wireframe, generative art. 2 | The slogan "aftnerne" is written on the schoolbag. 3 | Black and white sign with the words "alvoni" on a white background, wireframe, generative art. 4 | This car has a sign that reads "amprofon". 5 | "apazine" is written on the battery. 6 | Little bee holding a sign that says "bailup". 7 | This thermos has the slogan "baipeung" written on it. 8 | A detailed drawing with the text "bamore", alphabetism, thick gauge filigree. 9 | There is a sign "barilga" in the elevator. 10 | There is a book on the table with the title "bellsaint". 11 | In class, the teacher wrote the phrase "brahmaic" on the blackboard. 12 | At the airport, a sign that says "burwardeslyn". 13 | In a hospital, a sign that says "capitularis". 14 | A hand-drawn blueprint of a time machine titled "caranobe". 15 | A mouse with a flashlight says "challannain". 16 | A dark forest with only one light in the distance and the text "chapultepee". 17 | In class, the teacher wrote the phrase "chelonitis" on the blackboard. 18 | A yellow saxophone in rainbow colored smoke with the words "chinnakannan" looking like musical smoke. 19 | This cosmetic bottle says "churchilll". 20 | In the park, there is a sign "clirimtare". 21 | A lizard sitting on home plate at a baseball field with the words "copden" in a speech bubble. 22 | Studio shot of a pair of shoe sculptures made from colored wires and the text "cormont". 23 | A cartoon of a cat with a thought bubble that says "cubelo". 24 | There is a "danakadan" sign in the hotel. 25 | The city of Toronto seen from the plane, with a huge tower in the center of the frame, with the text "edsvikens" in the cartoon. 26 | There is a notice board next to the train station that reads "elitei". 27 | A detailed drawing with the text "ellisberg", alphabetism, thick gauge filigree. 28 | Small snail holding a sign that says "ensco". 29 | At the train station, a sign that says "galysheva". 30 | In the exhibition hall, a sign that reads "gamengame". 31 | Minimal sculpture of the word "garoff", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 32 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "ginzuishou". 33 | An image of a powerful looking car that looks like it was built for off-roading with the text "giregi". 34 | In the scenic spot, a sign that reads "hallatar". 35 | There is a sign saying "hammerclaw" in the mall. 36 | The slogan "harnetty" is posted on the bus stop. 37 | "heeresbekleidungsamt" slogan printed on school bus. 
38 | A lizard sitting on home plate at a baseball field with the words "heyyyy" in a speech bubble. 39 | A picture of the Earth with the words "iccg". 40 | Text "iggesund" sculpture photo booth made of thin colored lines. 41 | This suitcase says "ivha". 42 | Wooden Giraffe Toothbrush with "jialingjang" lettering in rainbow colors. 43 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "jiaogulan". 44 | In a supermarket, a sign that says "josanas". 45 | Close up of a toothpaste tube figurine, 3D rendering, candy pastels, with the text "kangun" on the tube. 46 | A poster titled "kastlander" showing different species of quail. 47 | In the game lobby, the game console displays "keohan". 48 | A shot of a vine with the text "kokachin" sprouting, centered. 49 | The words "kornephoros" are written on the paper towel. 50 | There is a notice in the supermarket that says "kravchunovsky". 51 | An antique bottle labeled "lajtai". 52 | Little frog holding a sign that says "linneas". 53 | Minimal sculpture of the word "logonov", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 54 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "maduropeptin". 55 | Pillow in the shape of "malaefone", Alphabetism, fun jumbled letters, [close-up] bread, author unknown, graphic art. 56 | "managat" reminder posted on the bus. 57 | A cartoon of a dog wearing a chef's hat with a thought bubble that says "manoylo". 58 | Photos with "mediashopping" sign. 59 | A pumpkin with a beard, a monocle and a top hat with the text "minioudaki" in a speech bubble. 60 | Hubble and the Milky Way with the text "monaut". 61 | Billboard with "muthialpettah". 62 | A photo of a dandelion field with the caption "nalitabari". 63 | An art installation of a chair with "nizhnelomovskoye" engraved on the back. 64 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "odnoin". 65 | Blueprint of a house with a triangular roof, square walls and rectangular floor with the message "ookstores". 66 | A plane flies over the city with the words "ormelle" written in smoke trails. 67 | "otoparts" generative art, sticky smoke made of dots, rivers, graphic design and white background. 68 | The slogan of "paikary" is written in the lottery station. 69 | A photo of an aquarium with fish in it, with the words "procdump". 70 | "quercusglobulus" reminder posted on the bus. 71 | A newspaper headline read "qurniyya" and a photo showed a half-eaten pumpkin. 72 | The promotional video of "reviewcentre" is played in the screening hall. 73 | "rimydis" reminder posted in the restaurant. 74 | There is a notice in the supermarket that says "saukrieg". 75 | Professionally designed logo for a bakery called "scilliano". 76 | A globe with the words "secciones" written in bold letters, with continents in bright colors. 77 | At the train station, a sign that says "serranito". 78 | Professionally designed logo for a bakery called "siyaya". 79 | In a hospital, a sign that says "slevyas". 80 | Piggy holding a sign that says "spissoy". 81 | Photo of two hands holding a heart in one and a lightning bolt with the words "stotnikar". 82 | Grow in a pretty pot with a "sukhaybarat" sign. 83 | "switzerlands" written on the door. 84 | "syretzk" written on the door. 85 | Scrabble board showing the words "taktai". 86 | A hastily handwritten note saying "thangundi" posted on the fridge. 
87 | Photo of a helicopter with "thellaru" written on the side, landing on a helipad in a valley with river, trees and mountains in the background. 88 | "theresianos" generative art, sticky smoke made of dots, rivers, graphic design and white background. 89 | A shot of a vine with the text "thornborg" sprouting, centered. 90 | A photo of a rabbit drinking coffee and reading a book titled "tikaw" is visible. 91 | Studio shot of a pair of shoe sculptures made from colored wires and the text "tiruvanaikaval". 92 | This suitcase says "tuolanshan". 93 | Black and white sign with the words "undpko" on a white background, wireframe, generative art. 94 | Billboard with "unveiledness". 95 | "vayalil" sign with home decor. 96 | "vernetois" warning printed on the entrance of the Public Security Bureau. 97 | At the train station, a sign that says "vogeding". 98 | A street sign on the street reads "wowomen". 99 | The sign "zavydovytska" hangs beside the swimming pool. 100 | A t-shirt that says "zwangsanstalt". 101 | -------------------------------------------------------------------------------- /text_prompts/paper/CreativeBench/all_unigram_10000_100000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt: -------------------------------------------------------------------------------- 1 | This suitcase says "adelbert". 2 | An antique bottle labeled "architecturally". 3 | The words "artyom" are written on the paper towel. 4 | At the airport, a sign that says "ashlee". 5 | A hastily handwritten note saying "barrel-vaulted" posted on the fridge. 6 | A pumpkin with a beard, a monocle and a top hat with the text "benazir" in a speech bubble. 7 | A lizard sitting on home plate at a baseball field with the words "beria" in a speech bubble. 8 | Text "bosanska" sculpture photo booth made of thin colored lines. 9 | "braschi" written on the door. 10 | A shot of a vine with the text "broaden" sprouting, centered. 11 | Blueprint of a house with a triangular roof, square walls and rectangular floor with the message "buen". 12 | Black and white sign with the words "capsize" on a white background, wireframe, generative art. 13 | In a supermarket, a sign that says "carboxylase". 14 | Wooden Giraffe Toothbrush with "carden" lettering in rainbow colors. 15 | At the train station, a sign that says "carney". 16 | Professionally designed logo for a bakery called "cayuse". 17 | Little frog holding a sign that says "chuen". 18 | The slogan "chatman" is posted on the bus stop. 19 | Studio shot of a pair of shoe sculptures made from colored wires and the text "chiu". 20 | Scrabble board showing the words "coleoptera". 21 | The slogan of "complexion" is written in the lottery station. 22 | A globe with the words "contrarian" written in bold letters, with continents in bright colors. 23 | A hand-drawn blueprint of a time machine titled "devereaux". 24 | A newspaper headline read "drogba" and a photo showed a half-eaten pumpkin. 25 | Photo of a helicopter with "elegans" written on the side, landing on a helipad in a valley with river, trees and mountains in the background. 26 | There is a notice in the supermarket that says "enoch". 27 | This car has a sign that reads "eupithecia". 28 | "galatians" written on the door. 29 | A photo of a rabbit drinking coffee and reading a book titled "goldstone" is visible. 30 | Billboard with "groot". 31 | "gunboat" reminder posted in the restaurant. 32 | Piggy holding a sign that says "handsome". 
33 | Minimal sculpture of the word "hata", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 34 | There is a sign "holdout" in the elevator. 35 | In the game lobby, the game console displays "hypodermic". 36 | There is a "incrementally" sign in the hotel. 37 | A plane flies over the city with the words "inhabitants" written in smoke trails. 38 | In a hospital, a sign that says "italics". 39 | In class, the teacher wrote the phrase "jasdf" on the blackboard. 40 | A mouse with a flashlight says "kahungunu". 41 | "lauper" generative art, sticky smoke made of dots, rivers, graphic design and white background. 42 | An image of a powerful looking car that looks like it was built for off-roading with the text "leffler". 43 | There is a book on the table with the title "limits". 44 | Professionally designed logo for a bakery called "lucha". 45 | "luchino" slogan printed on school bus. 46 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "lupu". 47 | This thermos has the slogan "maggs" written on it. 48 | A cartoon of a dog wearing a chef's hat with a thought bubble that says "marra". 49 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "mauricio". 50 | A photo of a dandelion field with the caption "millersville". 51 | In the park, there is a sign "nebel". 52 | At the train station, a sign that says "nepean". 53 | "obsidian" generative art, sticky smoke made of dots, rivers, graphic design and white background. 54 | A photo of an aquarium with fish in it, with the words "painstakingly". 55 | "phosphor" sign with home decor. 56 | In the exhibition hall, a sign that reads "polls". 57 | The city of Toronto seen from the plane, with a huge tower in the center of the frame, with the text "poona" in the cartoon. 58 | Black and white sign with the words "raimundo" on a white background, wireframe, generative art. 59 | The promotional video of "redemptorist" is played in the screening hall. 60 | A t-shirt that says "reigns". 61 | Minimal sculpture of the word "ribbentrop", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 62 | Photo of two hands holding a heart in one and a lightning bolt with the words "ricketts". 63 | Studio shot of a pair of shoe sculptures made from colored wires and the text "rte". 64 | An art installation of a chair with "rtve" engraved on the back. 65 | In class, the teacher wrote the phrase "scabies" on the blackboard. 66 | "sein" warning printed on the entrance of the Public Security Bureau. 67 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "seismologist". 68 | "sen" reminder posted on the bus. 69 | There is a notice in the supermarket that says "smirnoff". 70 | "solti" reminder posted on the bus. 71 | There is a sign saying "sounds" in the mall. 72 | Pillow in the shape of "speaks", Alphabetism, fun jumbled letters, [close-up] bread, author unknown, graphic art. 73 | This cosmetic bottle says "spellman". 74 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "spillover". 75 | Little bee holding a sign that says "sponsors". 76 | In a hospital, a sign that says "spotify". 77 | The sign "strindberg" hangs beside the swimming pool. 78 | Hubble and the Milky Way with the text "subsidised". 79 | There is a notice board next to the train station that reads "tabbed". 
80 | A cartoon of a cat with a thought bubble that says "tablas". 81 | A dark forest with only one light in the distance and the text "takraw". 82 | Small snail holding a sign that says "taxus". 83 | Close up of a toothpaste tube figurine, 3D rendering, candy pastels, with the text "teatre" on the tube. 84 | In the scenic spot, a sign that reads "throttle". 85 | "tirtha" is written on the battery. 86 | This suitcase says "toba". 87 | A detailed drawing with the text "tozawa", alphabetism, thick gauge filigree. 88 | A shot of a vine with the text "trappers" sprouting, centered. 89 | A yellow saxophone in rainbow colored smoke with the words "umaga" looking like musical smoke. 90 | Photos with "unchanging" sign. 91 | Billboard with "undergrad". 92 | A street sign on the street reads "union-tribune". 93 | The slogan "virgins" is written on the schoolbag. 94 | At the train station, a sign that says "visconti". 95 | Grow in a pretty pot with a "vladimirovich" sign. 96 | A detailed drawing with the text "wace", alphabetism, thick gauge filigree. 97 | A lizard sitting on home plate at a baseball field with the words "wemyss" in a speech bubble. 98 | Black and white sign with the words "whoop" on a white background, wireframe, generative art. 99 | A picture of the Earth with the words "wukong". 100 | A poster titled "yemi" showing different species of quail. 101 | -------------------------------------------------------------------------------- /text_prompts/paper/CreativeBench/all_unigram_1000_10000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt: -------------------------------------------------------------------------------- 1 | A shot of a vine with the text "abdomen" sprouting, centered. 2 | "accommodate" reminder posted in the restaurant. 3 | The sign "activate" hangs beside the swimming pool. 4 | A picture of the Earth with the words "aka". 5 | There is a sign saying "ammunition" in the mall. 6 | The city of Toronto seen from the plane, with a huge tower in the center of the frame, with the text "apostle" in the cartoon. 7 | Billboard with "appalachian". 8 | "associates" reminder posted on the bus. 9 | Grow in a pretty pot with a "bangalore" sign. 10 | There is a notice in the supermarket that says "beaver". 11 | In the exhibition hall, a sign that reads "beetle". 12 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "bernard". 13 | "broadway" slogan printed on school bus. 14 | Piggy holding a sign that says "brussels". 15 | The promotional video of "buildings" is played in the screening hall. 16 | Photos with "butcher" sign. 17 | A mouse with a flashlight says "caesar". 18 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "caucus". 19 | A globe with the words "children" written in bold letters, with continents in bright colors. 20 | A photo of a dandelion field with the caption "churchyard". 21 | There is a "cloud" sign in the hotel. 22 | The slogan "coleman" is written on the schoolbag. 23 | A pumpkin with a beard, a monocle and a top hat with the text "column" in a speech bubble. 24 | This suitcase says "countess". 25 | A photo of an aquarium with fish in it, with the words "dartmouth". 26 | Wooden Giraffe Toothbrush with "defunct" lettering in rainbow colors. 27 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "dense". 28 | Small snail holding a sign that says "derivative". 29 | Close up of a toothpaste tube figurine, 3D rendering, candy pastels, with the text "desk" on the tube. 
30 | A hastily handwritten note saying "devotion" posted on the fridge. 31 | In a supermarket, a sign that says "divine". 32 | A hand-drawn blueprint of a time machine titled "doyle". 33 | Little bee holding a sign that says "drink". 34 | A cartoon of a cat with a thought bubble that says "dune". 35 | Minimal sculpture of the word "elimination", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 36 | "emit" warning printed on the entrance of the Public Security Bureau. 37 | At the train station, a sign that says "enable". 38 | A poster titled "entry" showing different species of quail. 39 | At the airport, a sign that says "evans". 40 | A yellow saxophone in rainbow colored smoke with the words "ford" looking like musical smoke. 41 | The words "fountain" are written on the paper towel. 42 | A shot of a vine with the text "gear" sprouting, centered. 43 | This cosmetic bottle says "grouping". 44 | A photo of a rabbit drinking coffee and reading a book titled "growth" is visible. 45 | "handball" written on the door. 46 | This car has a sign that reads "heaven". 47 | Billboard with "insert". 48 | Professionally designed logo for a bakery called "intended". 49 | An art installation of a chair with "inventor" engraved on the back. 50 | Text "islands" sculpture photo booth made of thin colored lines. 51 | Pillow in the shape of "kate", Alphabetism, fun jumbled letters, [close-up] bread, author unknown, graphic art. 52 | Hubble and the Milky Way with the text "larva". 53 | This suitcase says "lean". 54 | "lighting" generative art, sticky smoke made of dots, rivers, graphic design and white background. 55 | Blueprint of a house with a triangular roof, square walls and rectangular floor with the message "literally". 56 | A dark forest with only one light in the distance and the text "lounge". 57 | An antique bottle labeled "mad". 58 | There is a notice in the supermarket that says "magical". 59 | A lizard sitting on home plate at a baseball field with the words "marcel" in a speech bubble. 60 | A detailed drawing with the text "maybe", alphabetism, thick gauge filigree. 61 | "melbourne" written on the door. 62 | "membership" generative art, sticky smoke made of dots, rivers, graphic design and white background. 63 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "memoir". 64 | There is a notice board next to the train station that reads "monroe". 65 | In the game lobby, the game console displays "neighboring". 66 | A detailed drawing with the text "noon", alphabetism, thick gauge filigree. 67 | The slogan "palestinian" is posted on the bus stop. 68 | In class, the teacher wrote the phrase "philharmonic" on the blackboard. 69 | "physically" reminder posted on the bus. 70 | The slogan of "physiology" is written in the lottery station. 71 | An image of a powerful looking car that looks like it was built for off-roading with the text "pile". 72 | In a hospital, a sign that says "pocket". 73 | At the train station, a sign that says "popularly". 74 | Photo of a helicopter with "postwar" written on the side, landing on a helipad in a valley with river, trees and mountains in the background. 75 | Photo of two hands holding a heart in one and a lightning bolt with the words "precision". 76 | A newspaper headline read "provoke" and a photo showed a half-eaten pumpkin. 77 | Little frog holding a sign that says "rainy". 78 | A lizard sitting on home plate at a baseball field with the words "raw" in a speech bubble. 
79 | Minimal sculpture of the word "reduced", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 80 | Black and white sign with the words "relative" on a white background, wireframe, generative art. 81 | Professionally designed logo for a bakery called "rotary". 82 | In the scenic spot, a sign that reads "salzburg". 83 | Black and white sign with the words "satisfy" on a white background, wireframe, generative art. 84 | A t-shirt that says "seminary". 85 | "soup" sign with home decor. 86 | In the park, there is a sign "spectacular". 87 | There is a sign "suburban" in the elevator. 88 | "superstar" is written on the battery. 89 | A cartoon of a dog wearing a chef's hat with a thought bubble that says "surgeon". 90 | Studio shot of a pair of shoe sculptures made from colored wires and the text "thoroughly". 91 | A plane flies over the city with the words "tornado" written in smoke trails. 92 | Black and white sign with the words "towers" on a white background, wireframe, generative art. 93 | Studio shot of a pair of shoe sculptures made from colored wires and the text "translate". 94 | In class, the teacher wrote the phrase "treaty" on the blackboard. 95 | In a hospital, a sign that says "tyre". 96 | A street sign on the street reads "unrest". 97 | This thermos has the slogan "urban" written on it. 98 | Scrabble board showing the words "walt". 99 | There is a book on the table with the title "wisdom". 100 | At the train station, a sign that says "withdraw". 101 | -------------------------------------------------------------------------------- /text_prompts/paper/CreativeBench/all_unigram_top_1000_100_prompt_file_GlyphDraw_origin_remove_render_words.txt: -------------------------------------------------------------------------------- 1 | There is a sign saying "accept" in the mall. 2 | "acquire" warning printed on the entrance of the Public Security Bureau. 3 | A hastily handwritten note saying "african" posted on the fridge. 4 | "american" generative art, sticky smoke made of dots, rivers, graphic design and white background. 5 | In a hospital, a sign that says "angeles". 6 | "appearance" written on the door. 7 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "arm". 8 | Blueprint of a house with a triangular roof, square walls and rectangular floor with the message "arrive". 9 | Close up of a toothpaste tube figurine, 3D rendering, candy pastels, with the text "assembly" on the tube. 10 | Photos with "attempt" sign. 11 | The city of Toronto seen from the plane, with a huge tower in the center of the frame, with the text "away" in the cartoon. 12 | A detailed drawing with the text "band", alphabetism, thick gauge filigree. 13 | This suitcase says "bay". 14 | A dark forest with only one light in the distance and the text "boy". 15 | A photo of a dandelion field with the caption "canada". 16 | Minimal sculpture of the word "cast", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 17 | Little bee holding a sign that says "cell". 18 | A poster titled "church" showing different species of quail. 19 | Professionally designed logo for a bakery called "club". 20 | A hand-drawn blueprint of a time machine titled "commercial". 21 | Photo of two hands holding a heart in one and a lightning bolt with the words "connect". 22 | A photo of an aquarium with fish in it, with the words "construct". 
23 | Black and white sign with the words "construction" on a white background, wireframe, generative art. 24 | Text "dance" sculpture photo booth made of thin colored lines. 25 | A lizard sitting on home plate at a baseball field with the words "debut" in a speech bubble. 26 | In the scenic spot, a sign that reads "decision". 27 | There is a "department" sign in the hotel. 28 | Black and white sign with the words "direct" on a white background, wireframe, generative art. 29 | Studio shot of a pair of shoe sculptures made from colored wires and the text "drop". 30 | Professionally designed logo for a bakery called "effort". 31 | Billboard with "energy". 32 | A picture of the Earth with the words "facility". 33 | A t-shirt that says "feature". 34 | A mouse with a flashlight says "final". 35 | There is a notice in the supermarket that says "financial". 36 | In a supermarket, a sign that says "fly". 37 | In the exhibition hall, a sign that reads "french". 38 | There is a notice board next to the train station that reads "greek". 39 | There is a notice in the supermarket that says "help". 40 | A street sign on the street reads "hospital". 41 | An art installation of a chair with "independent" engraved on the back. 42 | Scrabble board showing the words "individual". 43 | In the game lobby, the game console displays "king". 44 | There is a book on the table with the title "lack". 45 | At the train station, a sign that says "lake". 46 | The promotional video of "late" is played in the screening hall. 47 | The words "leave" are written on the paper towel. 48 | A cartoon of a cat with a thought bubble that says "life". 49 | This thermos has the slogan "link" written on it. 50 | A yellow saxophone in rainbow colored smoke with the words "london" looking like musical smoke. 51 | Black and white sign with the words "look" on a white background, wireframe, generative art. 52 | A detailed drawing with the text "lord", alphabetism, thick gauge filigree. 53 | A globe with the words "main" written in bold letters, with continents in bright colors. 54 | A cartoon of a dog wearing a chef's hat with a thought bubble that says "man". 55 | Little frog holding a sign that says "measure". 56 | "medium" is written on the battery. 57 | Hubble and the Milky Way with the text "mile". 58 | In the park, there is a sign "month". 59 | Wooden Giraffe Toothbrush with "mountain" lettering in rainbow colors. 60 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "northern". 61 | Piggy holding a sign that says "notable". 62 | "october" generative art, sticky smoke made of dots, rivers, graphic design and white background. 63 | The slogan of "official" is written in the lottery station. 64 | Small snail holding a sign that says "operate". 65 | "party" reminder posted on the bus. 66 | "pass" reminder posted on the bus. 67 | Minimal sculpture of the word "person", made of light metallic iridescent chrome thin lines, 3D rendering, isometric perspective, super detailed, dark background. 68 | An antique bottle labeled "police". 69 | In class, the teacher wrote the phrase "popular" on the blackboard. 70 | Studio shot of a pair of shoe sculptures made from colored wires and the text "produce". 71 | At the train station, a sign that says "professional". 72 | The slogan "program" is posted on the bus stop. 73 | Photo of a helicopter with "public" written on the side, landing on a helipad in a valley with river, trees and mountains in the background. 
74 | Photo illustration of Earth being struck by multiple lightning bolts merging, titled "purpose". 75 | An image of a powerful looking car that looks like it was built for off-roading with the text "range". 76 | The sign "reach" hangs beside the swimming pool. 77 | This car has a sign that reads "records". 78 | Billboard with "red". 79 | "refer" sign with home decor. 80 | At the airport, a sign that says "regular". 81 | Pillow in the shape of "religious", Alphabetism, fun jumbled letters, [close-up] bread, author unknown, graphic art. 82 | This suitcase says "similar". 83 | A lizard sitting on home plate at a baseball field with the words "social" in a speech bubble. 84 | A newspaper headline read "stone" and a photo showed a half-eaten pumpkin. 85 | A pumpkin with a beard, a monocle and a top hat with the text "subsequently" in a speech bubble. 86 | In a hospital, a sign that says "survive". 87 | A shot of a vine with the text "teach" sprouting, centered. 88 | "traditional" reminder posted in the restaurant. 89 | "travel" written on the door. 90 | This cosmetic bottle says "type". 91 | The slogan "upper" is written on the schoolbag. 92 | A parrot on a pirate ship, a parrot wearing a pirate hat with the text "van". 93 | There is a sign "variety" in the elevator. 94 | "vehicle" slogan printed on school bus. 95 | Grow in a pretty pot with a "victory" sign. 96 | A shot of a vine with the text "video" sprouting, centered. 97 | In class, the teacher wrote the phrase "wife" on the blackboard. 98 | A photo of a rabbit drinking coffee and reading a book titled "william" is visible. 99 | At the train station, a sign that says "year". 100 | A plane flies over the city with the words "young" written in smoke trails. 101 | -------------------------------------------------------------------------------- /text_prompts/paper/SimpleBench/all_unigram_100000_plus_100_1_gram.txt: -------------------------------------------------------------------------------- 1 | A sign that says "activeware". 2 | A sign that says "aftnerne". 3 | A sign that says "alvoni". 4 | A sign that says "amprofon". 5 | A sign that says "apazine". 6 | A sign that says "bailup". 7 | A sign that says "baipeung". 8 | A sign that says "bamore". 9 | A sign that says "barilga". 10 | A sign that says "bellsaint". 11 | A sign that says "brahmaic". 12 | A sign that says "burwardeslyn". 13 | A sign that says "capitularis". 14 | A sign that says "challannain". 15 | A sign that says "caranobe". 16 | A sign that says "chapultepee". 17 | A sign that says "chelonitis". 18 | A sign that says "chinnakannan". 19 | A sign that says "churchilll". 20 | A sign that says "clirimtare". 21 | A sign that says "cormont". 22 | A sign that says "copden". 23 | A sign that says "cubelo". 24 | A sign that says "danakadan". 25 | A sign that says "edsvikens". 26 | A sign that says "elitei". 27 | A sign that says "ellisberg". 28 | A sign that says "ensco". 29 | A sign that says "galysheva". 30 | A sign that says "gamengame". 31 | A sign that says "garoff". 32 | A sign that says "ginzuishou". 33 | A sign that says "giregi". 34 | A sign that says "hallatar". 35 | A sign that says "hammerclaw". 36 | A sign that says "harnetty". 37 | A sign that says "heeresbekleidungsamt". 38 | A sign that says "heyyyy". 39 | A sign that says "iccg". 40 | A sign that says "iggesund". 41 | A sign that says "ivha". 42 | A sign that says "jialingjang". 43 | A sign that says "jiaogulan". 44 | A sign that says "josanas". 45 | A sign that says "kangun". 46 | A sign that says "kastlander". 
47 | A sign that says "keohan". 48 | A sign that says "kokachin". 49 | A sign that says "kornephoros". 50 | A sign that says "kravchunovsky". 51 | A sign that says "lajtai". 52 | A sign that says "linneas". 53 | A sign that says "logonov". 54 | A sign that says "maduropeptin". 55 | A sign that says "malaefone". 56 | A sign that says "managat". 57 | A sign that says "manoylo". 58 | A sign that says "mediashopping". 59 | A sign that says "minioudaki". 60 | A sign that says "monaut". 61 | A sign that says "muthialpettah". 62 | A sign that says "nalitabari". 63 | A sign that says "nizhnelomovskoye". 64 | A sign that says "odnoin". 65 | A sign that says "ookstores". 66 | A sign that says "ormelle". 67 | A sign that says "otoparts". 68 | A sign that says "paikary". 69 | A sign that says "procdump". 70 | A sign that says "quercusglobulus". 71 | A sign that says "qurniyya". 72 | A sign that says "reviewcentre". 73 | A sign that says "rimydis". 74 | A sign that says "saukrieg". 75 | A sign that says "scilliano". 76 | A sign that says "secciones". 77 | A sign that says "serranito". 78 | A sign that says "siyaya". 79 | A sign that says "slevyas". 80 | A sign that says "spissoy". 81 | A sign that says "stotnikar". 82 | A sign that says "sukhaybarat". 83 | A sign that says "switzerlands". 84 | A sign that says "syretzk". 85 | A sign that says "taktai". 86 | A sign that says "thangundi". 87 | A sign that says "thellaru". 88 | A sign that says "theresianos". 89 | A sign that says "thornborg". 90 | A sign that says "tikaw". 91 | A sign that says "tiruvanaikaval". 92 | A sign that says "tuolanshan". 93 | A sign that says "undpko". 94 | A sign that says "unveiledness". 95 | A sign that says "vayalil". 96 | A sign that says "vernetois". 97 | A sign that says "vogeding". 98 | A sign that says "wowomen". 99 | A sign that says "zavydovytska". 100 | A sign that says "zwangsanstalt". 101 | -------------------------------------------------------------------------------- /text_prompts/paper/SimpleBench/all_unigram_10000_100000_100_1_gram.txt: -------------------------------------------------------------------------------- 1 | A sign that says "adelbert". 2 | A sign that says "architecturally". 3 | A sign that says "artyom". 4 | A sign that says "ashlee". 5 | A sign that says "barrel-vaulted". 6 | A sign that says "benazir". 7 | A sign that says "beria". 8 | A sign that says "bosanska". 9 | A sign that says "braschi". 10 | A sign that says "broaden". 11 | A sign that says "buen". 12 | A sign that says "capsize". 13 | A sign that says "carboxylase". 14 | A sign that says "carden". 15 | A sign that says "carney". 16 | A sign that says "cayuse". 17 | A sign that says "chatman". 18 | A sign that says "chiu". 19 | A sign that says "chuen". 20 | A sign that says "coleoptera". 21 | A sign that says "complexion". 22 | A sign that says "contrarian". 23 | A sign that says "devereaux". 24 | A sign that says "drogba". 25 | A sign that says "elegans". 26 | A sign that says "enoch". 27 | A sign that says "eupithecia". 28 | A sign that says "galatians". 29 | A sign that says "goldstone". 30 | A sign that says "groot". 31 | A sign that says "gunboat". 32 | A sign that says "handsome". 33 | A sign that says "hata". 34 | A sign that says "holdout". 35 | A sign that says "hypodermic". 36 | A sign that says "incrementally". 37 | A sign that says "inhabitants". 38 | A sign that says "italics". 39 | A sign that says "jasdf". 40 | A sign that says "kahungunu". 41 | A sign that says "lauper". 42 | A sign that says "leffler". 
43 | A sign that says "limits". 44 | A sign that says "lucha". 45 | A sign that says "luchino". 46 | A sign that says "lupu". 47 | A sign that says "maggs". 48 | A sign that says "marra". 49 | A sign that says "mauricio". 50 | A sign that says "millersville". 51 | A sign that says "nebel". 52 | A sign that says "nepean". 53 | A sign that says "obsidian". 54 | A sign that says "painstakingly". 55 | A sign that says "phosphor". 56 | A sign that says "polls". 57 | A sign that says "poona". 58 | A sign that says "raimundo". 59 | A sign that says "redemptorist". 60 | A sign that says "reigns". 61 | A sign that says "ribbentrop". 62 | A sign that says "ricketts". 63 | A sign that says "rte". 64 | A sign that says "rtve". 65 | A sign that says "scabies". 66 | A sign that says "sein". 67 | A sign that says "seismologist". 68 | A sign that says "sen". 69 | A sign that says "smirnoff". 70 | A sign that says "solti". 71 | A sign that says "sounds". 72 | A sign that says "speaks". 73 | A sign that says "spellman". 74 | A sign that says "spillover". 75 | A sign that says "sponsors". 76 | A sign that says "spotify". 77 | A sign that says "strindberg". 78 | A sign that says "subsidised". 79 | A sign that says "tabbed". 80 | A sign that says "tablas". 81 | A sign that says "takraw". 82 | A sign that says "taxus". 83 | A sign that says "teatre". 84 | A sign that says "throttle". 85 | A sign that says "tirtha". 86 | A sign that says "toba". 87 | A sign that says "tozawa". 88 | A sign that says "trappers". 89 | A sign that says "umaga". 90 | A sign that says "unchanging". 91 | A sign that says "undergrad". 92 | A sign that says "union-tribune". 93 | A sign that says "virgins". 94 | A sign that says "visconti". 95 | A sign that says "vladimirovich". 96 | A sign that says "wace". 97 | A sign that says "wemyss". 98 | A sign that says "whoop". 99 | A sign that says "wukong". 100 | A sign that says "yemi". 101 | -------------------------------------------------------------------------------- /text_prompts/paper/SimpleBench/all_unigram_1000_10000_100_1_gram.txt: -------------------------------------------------------------------------------- 1 | A sign that says "abdomen". 2 | A sign that says "accommodate". 3 | A sign that says "activate". 4 | A sign that says "aka". 5 | A sign that says "ammunition". 6 | A sign that says "apostle". 7 | A sign that says "appalachian". 8 | A sign that says "associates". 9 | A sign that says "bangalore". 10 | A sign that says "beaver". 11 | A sign that says "beetle". 12 | A sign that says "bernard". 13 | A sign that says "broadway". 14 | A sign that says "brussels". 15 | A sign that says "buildings". 16 | A sign that says "butcher". 17 | A sign that says "caesar". 18 | A sign that says "caucus". 19 | A sign that says "children". 20 | A sign that says "churchyard". 21 | A sign that says "coleman". 22 | A sign that says "cloud". 23 | A sign that says "column". 24 | A sign that says "countess". 25 | A sign that says "dartmouth". 26 | A sign that says "defunct". 27 | A sign that says "dense". 28 | A sign that says "derivative". 29 | A sign that says "desk". 30 | A sign that says "devotion". 31 | A sign that says "divine". 32 | A sign that says "doyle". 33 | A sign that says "drink". 34 | A sign that says "dune". 35 | A sign that says "elimination". 36 | A sign that says "emit". 37 | A sign that says "enable". 38 | A sign that says "entry". 39 | A sign that says "evans". 40 | A sign that says "ford". 41 | A sign that says "fountain". 42 | A sign that says "gear". 
43 | A sign that says "grouping". 44 | A sign that says "growth". 45 | A sign that says "handball". 46 | A sign that says "heaven". 47 | A sign that says "insert". 48 | A sign that says "intended". 49 | A sign that says "inventor". 50 | A sign that says "islands". 51 | A sign that says "kate". 52 | A sign that says "larva". 53 | A sign that says "lean". 54 | A sign that says "lighting". 55 | A sign that says "literally". 56 | A sign that says "lounge". 57 | A sign that says "mad". 58 | A sign that says "magical". 59 | A sign that says "marcel". 60 | A sign that says "maybe". 61 | A sign that says "melbourne". 62 | A sign that says "membership". 63 | A sign that says "memoir". 64 | A sign that says "monroe". 65 | A sign that says "neighboring". 66 | A sign that says "noon". 67 | A sign that says "palestinian". 68 | A sign that says "philharmonic". 69 | A sign that says "physically". 70 | A sign that says "physiology". 71 | A sign that says "pile". 72 | A sign that says "pocket". 73 | A sign that says "popularly". 74 | A sign that says "postwar". 75 | A sign that says "precision". 76 | A sign that says "provoke". 77 | A sign that says "rainy". 78 | A sign that says "raw". 79 | A sign that says "reduced". 80 | A sign that says "relative". 81 | A sign that says "rotary". 82 | A sign that says "salzburg". 83 | A sign that says "satisfy". 84 | A sign that says "seminary". 85 | A sign that says "soup". 86 | A sign that says "spectacular". 87 | A sign that says "suburban". 88 | A sign that says "superstar". 89 | A sign that says "surgeon". 90 | A sign that says "thoroughly". 91 | A sign that says "tornado". 92 | A sign that says "towers". 93 | A sign that says "translate". 94 | A sign that says "treaty". 95 | A sign that says "tyre". 96 | A sign that says "unrest". 97 | A sign that says "urban". 98 | A sign that says "walt". 99 | A sign that says "wisdom". 100 | A sign that says "withdraw". 101 | -------------------------------------------------------------------------------- /text_prompts/paper/SimpleBench/all_unigram_top_1000_100_1_gram.txt: -------------------------------------------------------------------------------- 1 | A sign that says "accept". 2 | A sign that says "acquire". 3 | A sign that says "african". 4 | A sign that says "american". 5 | A sign that says "angeles". 6 | A sign that says "appearance". 7 | A sign that says "arm". 8 | A sign that says "arrive". 9 | A sign that says "assembly". 10 | A sign that says "attempt". 11 | A sign that says "away". 12 | A sign that says "band". 13 | A sign that says "bay". 14 | A sign that says "boy". 15 | A sign that says "canada". 16 | A sign that says "cast". 17 | A sign that says "cell". 18 | A sign that says "church". 19 | A sign that says "club". 20 | A sign that says "commercial". 21 | A sign that says "connect". 22 | A sign that says "construct". 23 | A sign that says "construction". 24 | A sign that says "dance". 25 | A sign that says "debut". 26 | A sign that says "decision". 27 | A sign that says "department". 28 | A sign that says "direct". 29 | A sign that says "drop". 30 | A sign that says "effort". 31 | A sign that says "energy". 32 | A sign that says "facility". 33 | A sign that says "feature". 34 | A sign that says "final". 35 | A sign that says "financial". 36 | A sign that says "fly". 37 | A sign that says "french". 38 | A sign that says "greek". 39 | A sign that says "help". 40 | A sign that says "hospital". 41 | A sign that says "independent". 42 | A sign that says "individual". 43 | A sign that says "king". 
44 | A sign that says "lack". 45 | A sign that says "lake". 46 | A sign that says "late". 47 | A sign that says "leave". 48 | A sign that says "life". 49 | A sign that says "link". 50 | A sign that says "london". 51 | A sign that says "look". 52 | A sign that says "lord". 53 | A sign that says "main". 54 | A sign that says "man". 55 | A sign that says "measure". 56 | A sign that says "medium". 57 | A sign that says "mile". 58 | A sign that says "month". 59 | A sign that says "mountain". 60 | A sign that says "northern". 61 | A sign that says "notable". 62 | A sign that says "october". 63 | A sign that says "official". 64 | A sign that says "operate". 65 | A sign that says "party". 66 | A sign that says "pass". 67 | A sign that says "person". 68 | A sign that says "police". 69 | A sign that says "popular". 70 | A sign that says "produce". 71 | A sign that says "professional". 72 | A sign that says "program". 73 | A sign that says "public". 74 | A sign that says "purpose". 75 | A sign that says "range". 76 | A sign that says "reach". 77 | A sign that says "records". 78 | A sign that says "red". 79 | A sign that says "refer". 80 | A sign that says "regular". 81 | A sign that says "religious". 82 | A sign that says "similar". 83 | A sign that says "social". 84 | A sign that says "stone". 85 | A sign that says "subsequently". 86 | A sign that says "survive". 87 | A sign that says "teach". 88 | A sign that says "traditional". 89 | A sign that says "travel". 90 | A sign that says "type". 91 | A sign that says "upper". 92 | A sign that says "van". 93 | A sign that says "variety". 94 | A sign that says "vehicle". 95 | A sign that says "victory". 96 | A sign that says "video". 97 | A sign that says "wife". 98 | A sign that says "william". 99 | A sign that says "year". 100 | A sign that says "young". 
101 | -------------------------------------------------------------------------------- /text_prompts/raw/SimpleBench/all_unigram_100000_plus_100.txt: -------------------------------------------------------------------------------- 1 | kangun 2 | harnetty 3 | tikaw 4 | reviewcentre 5 | linneas 6 | switzerlands 7 | cormont 8 | lajtai 9 | vernetois 10 | copden 11 | procdump 12 | managat 13 | ellisberg 14 | clirimtare 15 | qurniyya 16 | churchilll 17 | iggesund 18 | vogeding 19 | hammerclaw 20 | galysheva 21 | challannain 22 | gamengame 23 | minioudaki 24 | scilliano 25 | aftnerne 26 | nalitabari 27 | undpko 28 | sukhaybarat 29 | thangundi 30 | burwardeslyn 31 | zavydovytska 32 | slevyas 33 | kokachin 34 | spissoy 35 | caranobe 36 | nizhnelomovskoye 37 | vayalil 38 | taktai 39 | giregi 40 | capitularis 41 | heeresbekleidungsamt 42 | secciones 43 | elitei 44 | muthialpettah 45 | monaut 46 | ivha 47 | otoparts 48 | bamore 49 | jialingjang 50 | chinnakannan 51 | syretzk 52 | bailup 53 | malaefone 54 | jiaogulan 55 | bellsaint 56 | brahmaic 57 | chapultepee 58 | zwangsanstalt 59 | paikary 60 | stotnikar 61 | maduropeptin 62 | heyyyy 63 | edsvikens 64 | ookstores 65 | wowomen 66 | unveiledness 67 | mediashopping 68 | ginzuishou 69 | saukrieg 70 | kornephoros 71 | thornborg 72 | barilga 73 | activeware 74 | ensco 75 | amprofon 76 | manoylo 77 | baipeung 78 | apazine 79 | logonov 80 | josanas 81 | garoff 82 | cubelo 83 | ormelle 84 | chelonitis 85 | keohan 86 | danakadan 87 | tiruvanaikaval 88 | thellaru 89 | serranito 90 | quercusglobulus 91 | theresianos 92 | rimydis 93 | alvoni 94 | hallatar 95 | odnoin 96 | kravchunovsky 97 | iccg 98 | kastlander 99 | tuolanshan 100 | siyaya -------------------------------------------------------------------------------- /text_prompts/raw/SimpleBench/all_unigram_10000_100000_100.txt: -------------------------------------------------------------------------------- 1 | teatre 2 | chatman 3 | goldstone 4 | redemptorist 5 | chuen 6 | braschi 7 | chiu 8 | architecturally 9 | sein 10 | beria 11 | painstakingly 12 | solti 13 | tozawa 14 | nebel 15 | drogba 16 | spellman 17 | bosanska 18 | visconti 19 | sounds 20 | carney 21 | kahungunu 22 | polls 23 | benazir 24 | lucha 25 | virgins 26 | millersville 27 | capsize 28 | vladimirovich 29 | barrel-vaulted 30 | ashlee 31 | strindberg 32 | italics 33 | broaden 34 | handsome 35 | devereaux 36 | rtve 37 | phosphor 38 | coleoptera 39 | leffler 40 | spotify 41 | luchino 42 | contrarian 43 | tabbed 44 | undergrad 45 | subsidised 46 | toba 47 | lauper 48 | wace 49 | carden 50 | umaga 51 | galatians 52 | sponsors 53 | speaks 54 | lupu 55 | limits 56 | scabies 57 | takraw 58 | reigns 59 | complexion 60 | ricketts 61 | seismologist 62 | wemyss 63 | poona 64 | buen 65 | union-tribune 66 | groot 67 | unchanging 68 | mauricio 69 | enoch 70 | artyom 71 | trappers 72 | holdout 73 | whoop 74 | taxus 75 | eupithecia 76 | marra 77 | maggs 78 | tirtha 79 | hata 80 | carboxylase 81 | ribbentrop 82 | tablas 83 | inhabitants 84 | jasdf 85 | hypodermic 86 | incrementally 87 | rte 88 | elegans 89 | nepean 90 | sen 91 | obsidian 92 | gunboat 93 | raimundo 94 | throttle 95 | spillover 96 | smirnoff 97 | wukong 98 | yemi 99 | adelbert 100 | cayuse -------------------------------------------------------------------------------- /text_prompts/raw/SimpleBench/all_unigram_1000_10000_100.txt: -------------------------------------------------------------------------------- 1 | desk 2 | palestinian 3 | growth 4 | buildings 5 | rainy 6 | handball 7 | 
translate 8 | mad 9 | emit 10 | marcel 11 | dartmouth 12 | physically 13 | noon 14 | spectacular 15 | provoke 16 | grouping 17 | islands 18 | enable 19 | ammunition 20 | popularly 21 | caesar 22 | beetle 23 | column 24 | rotary 25 | coleman 26 | churchyard 27 | satisfy 28 | bangalore 29 | devotion 30 | evans 31 | activate 32 | tyre 33 | abdomen 34 | brussels 35 | doyle 36 | inventor 37 | soup 38 | walt 39 | pile 40 | pocket 41 | broadway 42 | children 43 | monroe 44 | insert 45 | larva 46 | countess 47 | membership 48 | maybe 49 | defunct 50 | ford 51 | melbourne 52 | drink 53 | kate 54 | memoir 55 | wisdom 56 | treaty 57 | lounge 58 | seminary 59 | physiology 60 | precision 61 | bernard 62 | raw 63 | apostle 64 | literally 65 | unrest 66 | appalachian 67 | butcher 68 | caucus 69 | magical 70 | fountain 71 | gear 72 | suburban 73 | towers 74 | derivative 75 | heaven 76 | surgeon 77 | urban 78 | superstar 79 | reduced 80 | divine 81 | elimination 82 | dune 83 | tornado 84 | philharmonic 85 | neighboring 86 | cloud 87 | thoroughly 88 | postwar 89 | withdraw 90 | associates 91 | lighting 92 | accommodate 93 | relative 94 | salzburg 95 | dense 96 | beaver 97 | aka 98 | entry 99 | lean 100 | intended -------------------------------------------------------------------------------- /text_prompts/raw/SimpleBench/all_unigram_top_1000_100.txt: -------------------------------------------------------------------------------- 1 | assembly 2 | program 3 | william 4 | late 5 | measure 6 | travel 7 | drop 8 | police 9 | acquire 10 | social 11 | construct 12 | party 13 | lord 14 | month 15 | stone 16 | type 17 | dance 18 | year 19 | accept 20 | lake 21 | final 22 | french 23 | subsequently 24 | club 25 | upper 26 | canada 27 | direct 28 | victory 29 | african 30 | regular 31 | reach 32 | survive 33 | video 34 | notable 35 | commercial 36 | independent 37 | refer 38 | individual 39 | range 40 | angeles 41 | vehicle 42 | main 43 | greek 44 | red 45 | mile 46 | bay 47 | american 48 | band 49 | mountain 50 | london 51 | appearance 52 | cell 53 | religious 54 | van 55 | lack 56 | wife 57 | boy 58 | feature 59 | official 60 | connect 61 | arm 62 | debut 63 | away 64 | arrive 65 | hospital 66 | energy 67 | attempt 68 | purpose 69 | financial 70 | leave 71 | teach 72 | variety 73 | construction 74 | operate 75 | records 76 | man 77 | link 78 | medium 79 | person 80 | fly 81 | cast 82 | life 83 | young 84 | popular 85 | king 86 | department 87 | produce 88 | public 89 | professional 90 | pass 91 | october 92 | traditional 93 | look 94 | decision 95 | northern 96 | help 97 | facility 98 | church 99 | similar 100 | effort --------------------------------------------------------------------------------