# Repository layout (digest of the FastSDXL project):
# ├── FastSDXL
# │   ├── BlockUNet.py
# │   ├── OLSS.py
# │   ├── StateDictConverter.py
# │   └── Styler.py
# ├── LICENSE
# ├── README.md
# ├── launch_with_olss.py
# ├── launch_without_olss.py
# ├── models
# │   └── olss_scheduler.bin
# └── run_olss.py

# --------------------------------------------------------------------------------
# /FastSDXL/BlockUNet.py
# --------------------------------------------------------------------------------
import math

import torch
from safetensors import safe_open
from diffusers import ModelMixin
from diffusers.configuration_utils import FrozenDict
from .StateDictConverter import convert_state_dict_civitai_diffusers


def _load_safetensors(safetensor_path):
    """Read every tensor stored in a ``.safetensors`` file into a dict on the CPU."""
    state_dict = {}
    with safe_open(safetensor_path, framework="pt", device="cpu") as f:
        for name in f.keys():
            state_dict[name] = f.get_tensor(name)
    return state_dict


class Timesteps(torch.nn.Module):
    """Sinusoidal timestep embedding, cosine half first and sine half second.

    Args:
        num_channels: target embedding width; the output's last dimension is
            ``2 * (num_channels // 2)``.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        """Append an embedding dimension to ``timesteps`` (computed in float32)."""
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) / half_dim
        timesteps = timesteps.unsqueeze(-1)
        emb = timesteps.float() * torch.exp(exponent)
        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
        return emb


class GEGLU(torch.nn.Module):
    """Gated GELU: project to ``2 * dim_out`` features, return ``a * gelu(b)``."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * torch.nn.functional.gelu(gate)


class Attention(torch.nn.Module):
    """Multi-head (self- or cross-) attention built on ``scaled_dot_product_attention``.

    Args:
        query_dim: feature size of ``hidden_states``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        cross_attention_dim: feature size of ``encoder_hidden_states``;
            defaults to ``query_dim`` (self-attention).
    """

    def __init__(self, query_dim, heads, dim_head, cross_attention_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = torch.nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = torch.nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_out = torch.nn.Linear(inner_dim, query_dim, bias=True)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states`` (or to itself)."""
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # (batch, tokens, heads * dim_head) -> (batch, heads, tokens, dim_head)
        query = query.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        # After transpose the tensor is generally non-contiguous (always so on the
        # math/CPU backend), so `.view` can fail; `.reshape` copies only if needed.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * self.dim_head)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states


class BasicTransformerBlock(torch.nn.Module):
    """Pre-norm transformer layer: self-attention, cross-attention, GEGLU feed-forward."""

    def __init__(self, dim, num_attention_heads, attention_head_dim, cross_attention_dim):
        super().__init__()

        # 1. Self-Attn
        self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn1 = Attention(query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim)

        # 2. Cross-Attn
        self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn2 = Attention(query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim)

        # 3. Feed-forward
        self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.act_fn = GEGLU(dim, dim * 4)
        self.ff = torch.nn.Linear(dim * 4, dim)

    def forward(self, hidden_states, encoder_hidden_states):
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        norm_hidden_states = self.norm2(hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.act_fn(norm_hidden_states)
        ff_output = self.ff(ff_output)
        hidden_states = ff_output + hidden_states

        return hidden_states


# Every block below shares the signature
#   forward(hidden_states, time_emb, text_emb, res_stack) -> same 4-tuple
# so BlockUNet can run them as one flat sequence.


class DownSampler(torch.nn.Module):
    """Stride-2 3x3 convolution that halves the spatial resolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack


class UpSampler(torch.nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack


class ResnetBlock(torch.nn.Module):
    """Residual block (GroupNorm -> SiLU -> conv, twice) conditioned on ``time_emb``.

    A 1x1 ``conv_shortcut`` projects the residual when channel counts differ.
    """

    def __init__(self, in_channels, out_channels, temb_channels, groups=32, eps=1e-5):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = torch.nn.SiLU()
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        x = hidden_states
        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)
        # Broadcast the (batch, channels) time embedding over the spatial dims.
        emb = self.nonlinearity(time_emb)
        emb = self.time_emb_proj(emb)[:, :, None, None]
        x = x + emb
        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)
        if self.conv_shortcut is not None:
            hidden_states = self.conv_shortcut(hidden_states)
        hidden_states = hidden_states + x
        return hidden_states, time_emb, text_emb, res_stack


class AttentionBlock(torch.nn.Module):
    """Spatial transformer: flatten the feature map to tokens and run transformer layers.

    Args:
        num_attention_heads: heads per attention layer.
        attention_head_dim: feature size per head.
        in_channels: channels of the incoming feature map.
        num_layers: number of stacked ``BasicTransformerBlock`` s.
        cross_attention_dim: feature size of the text embedding.
        norm_num_groups: groups for the input GroupNorm.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = torch.nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = torch.nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                cross_attention_dim=cross_attention_dim
            )
            for _ in range(num_layers)
        ])

        self.proj_out = torch.nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        channels = hidden_states.shape[1]
        # (batch, channels, H, W) -> (batch, H * W, channels)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=text_emb
            )

        # proj_out maps back to in_channels, so `channels` is valid again here.
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class PushBlock(torch.nn.Module):
    """Save the current hidden states on ``res_stack`` for a later skip connection."""

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        res_stack.append(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack


class PopBlock(torch.nn.Module):
    """Pop the most recent skip tensor and concatenate it along the channel dim."""

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        res_hidden_states = res_stack.pop()
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
        return hidden_states, time_emb, text_emb, res_stack


class BlockUNet(ModelMixin):
    """SDXL UNet expressed as one flat list of blocks.

    Skip connections are handled explicitly by Push/Pop blocks, and the model
    exposes a diffusers-compatible ``config`` so it can stand in for
    ``UNet2DConditionModel`` inside a diffusers pipeline.
    """

    def __init__(self):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.add_time_proj = Timesteps(256)
        self.add_time_embedding = torch.nn.Sequential(
            torch.nn.Linear(2816, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownBlock2D
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            ResnetBlock(1280, 1280, 1280),
            # CrossAttnUpBlock2D
            PopBlock(),
            ResnetBlock(2560, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PopBlock(),
            ResnetBlock(2560, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            PopBlock(),
            ResnetBlock(1920, 1280, 1280),
            AttentionBlock(20, 64, 1280, 10, 2048),
            UpSampler(1280),
            # CrossAttnUpBlock2D
            PopBlock(),
            ResnetBlock(1920, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PopBlock(),
            ResnetBlock(1280, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            PopBlock(),
            ResnetBlock(960, 640, 1280),
            AttentionBlock(10, 64, 640, 2, 2048),
            UpSampler(640),
            # UpBlock2D
            PopBlock(),
            ResnetBlock(960, 320, 1280),
            PopBlock(),
            ResnetBlock(640, 320, 1280),
            PopBlock(),
            ResnetBlock(640, 320, 1280)
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)

        # For diffusers
        self.config = FrozenDict([
            ('sample_size', 128), ('in_channels', 4), ('out_channels', 4), ('center_input_sample', False), ('flip_sin_to_cos', True),
            ('freq_shift', 0), ('down_block_types', ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D']),
            ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D']),
            ('only_cross_attention', False), ('block_out_channels', [320, 640, 1280]), ('layers_per_block', 2), ('downsample_padding', 1),
            ('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 1e-05), ('cross_attention_dim', 2048),
            ('transformer_layers_per_block', [1, 2, 10]), ('encoder_hid_dim', None), ('encoder_hid_dim_type', None), ('attention_head_dim', [5, 10, 20]),
            ('num_attention_heads', None), ('dual_cross_attention', False), ('use_linear_projection', True), ('class_embed_type', None),
            ('addition_embed_type', 'text_time'), ('addition_time_embed_dim', 256), ('num_class_embeds', None), ('upcast_attention', None),
            ('resnet_time_scale_shift', 'default'), ('resnet_skip_time_act', False), ('resnet_out_scale_factor', 1.0), ('time_embedding_type', 'positional'),
            ('time_embedding_dim', None), ('time_embedding_act_fn', None), ('timestep_post_act', None), ('time_cond_proj_dim', None),
            ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim', 2816), ('attention_type', 'default'),
            ('class_embeddings_concat', False), ('mid_block_only_cross_attention', None), ('cross_attention_norm', None),
            ('addition_embed_type_num_heads', 64), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.20.2'),
            ('_name_or_path', 'models/stabilityai/stable-diffusion-xl-base-1.0\\unet')])
        self.add_embedding = FrozenDict([("linear_1", FrozenDict([("in_features", 2816)]))])

    def from_diffusers(self, safetensor_path=None, state_dict=None):
        """Load weights stored in the diffusers ``UNet2DConditionModel`` naming scheme.

        Args:
            safetensor_path: optional ``.safetensors`` file; takes precedence
                over ``state_dict`` when given.
            state_dict: mapping of diffusers parameter names to tensors.

        Raises:
            ValueError: if neither source is given, or a parameter name is not
                recognised.
        """
        if safetensor_path is not None:
            state_dict = _load_safetensors(safetensor_path)
        if state_dict is None:
            raise ValueError("Either `safetensor_path` or `state_dict` must be provided.")

        # Class name of each entry of self.blocks; used to map diffusers'
        # nested (down/mid/up, resnets/attentions/...) indices onto flat ones.
        block_types = [block.__class__.__name__ for block in self.blocks]

        # Rename each parameter. Names are visited in sorted order so that
        # down_blocks -> mid_block -> up_blocks line up with self.blocks.
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in sorted(state_dict):
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
                pass
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                # linear_1 / linear_2 are entries 0 / 2 of the Sequential (1 is SiLU).
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    # Give the mid block the same "<group>.<index>" shape as the others.
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                # A new diffusers sub-block starts: advance to the next flat block of that type.
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    # ff.net.0 (GEGLU) -> act_fn, ff.net.2 (Linear) -> ff
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index + 3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index + 3:]
                if "to_out" in names:
                    # diffusers wraps to_out in a ModuleList ("to_out.0"); here it is a bare Linear.
                    names.pop(names.index("to_out") + 1)
            else:
                raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {rename_dict[name]: param for name, param in state_dict.items()}
        self.load_state_dict(state_dict_)

    def from_civitai(self, safetensor_path=None, state_dict=None):
        """Load a civitai-format checkpoint by converting it to diffusers names first.

        Raises:
            ValueError: if neither ``safetensor_path`` nor ``state_dict`` is given.
        """
        if safetensor_path is not None:
            state_dict = _load_safetensors(safetensor_path)
        if state_dict is None:
            raise ValueError("Either `safetensor_path` or `state_dict` must be provided.")

        state_dict = convert_state_dict_civitai_diffusers(state_dict)
        self.from_diffusers(state_dict=state_dict)

    def process(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        added_cond_kwargs,
        **kwargs
    ):
        """Run the UNet on a single sample (batch dimension of 1).

        Args:
            sample: latent input, batch of one.
            timestep: a single timestep (Python number, 0-d or one-element tensor).
            encoder_hidden_states: text-encoder hidden states for the sample.
            added_cond_kwargs: dict with ``"text_embeds"`` and ``"time_ids"``.

        Returns:
            The predicted output with the same batch size as ``sample``.
        """
        # 1. time
        # Accept Python numbers and keep the timestep on the model input's device.
        timestep = torch.as_tensor(timestep, device=sample.device)
        t_emb = self.time_proj(timestep.reshape(1)).to(sample.dtype)
        t_emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(added_cond_kwargs["time_ids"])
        time_embeds = time_embeds.reshape((time_embeds.shape[0], -1))
        add_embeds = torch.concat([added_cond_kwargs["text_embeds"], time_embeds], dim=-1)
        add_embeds = add_embeds.to(sample.dtype)
        add_embeds = self.add_time_embedding(add_embeds)

        time_emb = t_emb + add_embeds

        # 2. pre-process
        hidden_states = self.conv_in(sample)
        text_emb = encoder_hidden_states
        res_stack = [hidden_states]

        # 3. blocks
        for block in self.blocks:
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        added_cond_kwargs,
        **kwargs
    ):
        """diffusers-compatible entry point; processes the batch one sample at a time.

        ``timestep`` is either shared by the whole batch or, when it is a 1-D
        tensor with one entry per sample, applied per sample.

        Returns:
            A 1-tuple holding the concatenated outputs, like diffusers'
            ``return_dict=False`` convention.
        """
        batch_size = sample.shape[0]
        per_sample_timestep = (
            isinstance(timestep, torch.Tensor)
            and timestep.dim() == 1
            and timestep.shape[0] == batch_size
        )
        hidden_states = []
        for i in range(batch_size):
            added_cond_kwargs_ = {
                "text_embeds": added_cond_kwargs["text_embeds"][i:i + 1],
                "time_ids": added_cond_kwargs["time_ids"][i:i + 1],
            }
            timestep_ = timestep[i] if per_sample_timestep else timestep
            hidden_states.append(self.process(sample[i:i + 1], timestep_, encoder_hidden_states[i:i + 1], added_cond_kwargs_))
        hidden_states = torch.concat(hidden_states, dim=0)
        return (hidden_states,)


# --------------------------------------------------------------------------------
# /FastSDXL/OLSS.py
# --------------------------------------------------------------------------------
import torch
from tqdm import tqdm


class OLSSSchedulerModel(torch.nn.Module):
    """Learned linear sampler.

    Step ``t`` predicts ``x = wx[t] * xT + sum_j we[t, j] * e_prev[j]``, where
    ``e_prev`` holds the model outputs of steps ``0..t``.

    Args:
        wx: 1-D tensor of length T (one coefficient per step).
        we: T x T tensor; row ``t`` holds the coefficients of ``e_prev``.

    Raises:
        ValueError: if the shapes are inconsistent.
    """

    def __init__(self, wx, we):
        super(OLSSSchedulerModel, self).__init__()
        if len(wx.shape) != 1 or len(we.shape) != 2:
            raise ValueError(f"expected 1-D wx and 2-D we, got {tuple(wx.shape)} and {tuple(we.shape)}")
        T = wx.shape[0]
        if T != we.shape[0] or T != we.shape[1]:
            raise ValueError(f"we must be {T}x{T} to match wx, got {tuple(we.shape)}")
        self.register_parameter("wx", torch.nn.Parameter(wx))
        self.register_parameter("we", torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        # Step t consumes exactly t + 1 model outputs.
        assert t - len(e_prev) + 1 == 0
        x = xT * self.wx[t]
        for e, we in zip(e_prev, self.we[t]):
            x += e * we
        return x.to(xT.dtype)


class OLSSScheduler():
    """Minimal diffusers-style scheduler that replays a solved OLSS model.

    It is stateful: it stores ``x_T`` and every model output of the current
    trajectory, and resets after the last timestep or on ``set_timesteps``.
    """

    def __init__(self, timesteps, model):
        self.timesteps = timesteps
        self.model = model
        self.init_noise_sigma = 1.0
        self.order = 1

    @staticmethod
    def load(path):
        """Load ``(timesteps, wx, we)`` saved by :meth:`save`."""
        # NOTE: torch.load unpickles the file; only load trusted scheduler files.
        timesteps, wx, we = torch.load(path, map_location="cpu")
        model = OLSSSchedulerModel(wx, we)
        return OLSSScheduler(timesteps, model)

    def save(self, path):
        timesteps, wx, we = self.timesteps, self.model.wx, self.model.we
        torch.save((timesteps, wx, we), path)

    def set_timesteps(self, num_inference_steps, device="cuda"):
        """Reset trajectory state and move to ``device``.

        ``num_inference_steps`` is accepted for API compatibility but ignored:
        the number of steps is fixed by the solved timesteps.
        """
        self.xT = None
        self.e_prev = []
        self.t_prev = -1
        self.model = self.model.to(device)
        self.timesteps = self.timesteps.to(device)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs):
        return sample

    @torch.no_grad()
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        *args, **kwargs
    ):
        """Advance one step; timesteps must be visited in order, starting at the first.

        Raises:
            ValueError: if ``timestep`` is not one of ``self.timesteps``.
            RuntimeError: if steps are skipped or repeated.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.item()
        t = self.timesteps.tolist().index(timestep)
        if not (self.t_prev == -1 or t == self.t_prev + 1):
            raise RuntimeError(f"OLSS steps must be consecutive: got step {t} after step {self.t_prev}")
        if self.t_prev == -1:
            # First step of a trajectory: remember the initial latent.
            self.xT = sample
        self.e_prev.append(model_output)
        x = self.model(t, self.xT, self.e_prev)
        if t + 1 == len(self.timesteps):
            # Trajectory finished: reset for the next generation.
            self.xT = None
            self.e_prev = []
            self.t_prev = -1
        else:
            self.t_prev = t
        return (x,)


class OLSSSolver:
    """Fits OLSS coefficients from a recorded trajectory by ordinary least squares."""

    def __init__(self):
        pass

    def solve_linear_regression(self, X, Y):
        """Return the least-squares solution of ``X @ coef ≈ Y`` in float64."""
        X = X.to(torch.float64)
        Y = Y.to(torch.float64)
        coef = torch.linalg.lstsq(X, Y).solution
        return coef

    def solve_scheduer_parameters(self, xT, e_prev, x):
        """Fit ``x ≈ wx * xT + sum_j we[j] * e_prev[j]``.

        Returns:
            ``(wx, we, err)``: ``wx`` (length 1), ``we`` (one per ``e_prev``
            row) and the mean-squared error of the fit as a Python float.
        """
        # prepare: one regressor per row of [xT; e_prev]
        xe_prev = torch.concat([xT, e_prev], dim=0)
        xe_prev = xe_prev.reshape(xe_prev.shape[0], -1)
        x = x.flatten().to(torch.float64)
        # solve the ordinary least squares problem
        coef = self.solve_linear_regression(xe_prev.T, x)
        # split the parameters
        wx, we = coef[:1], coef[1:]
        # error
        x_pred = torch.matmul(coef.unsqueeze(0), xe_prev.to(torch.float64)).squeeze(0)
        err = torch.nn.functional.mse_loss(x_pred, x).tolist()
        return wx, we, err

    @torch.no_grad()
    def resolve_diffusion_process(self,
                                  steps_accelerate,
                                  t_path,
                                  x_path,
                                  e_path,
                                  i_path=None):
        """Fit coefficients for the sub-path ``i_path`` of a recorded trajectory.

        Args:
            steps_accelerate: requested number of steps; used only when
                ``i_path`` is None (uniformly spaced path).
            t_path: recorded timesteps (length N).
            x_path: recorded latents, N + 1 entries (final latent last).
            e_path: recorded model outputs (length N).
            i_path: indices into the trajectory to keep.

        Returns:
            ``(timesteps, wx, we)``; their length is ``len(i_path)``.
        """
        steps_inference = t_path.shape[0]
        # accelerate path
        if i_path is None:
            stride = max(steps_inference // steps_accelerate, 1)
            i_path = torch.arange(0, steps_inference, stride)[:steps_accelerate]
        # The path may be shorter than requested (the searched path can reach
        # the end early, or steps_accelerate may exceed the recorded steps).
        steps_accelerate = len(i_path)
        t_path = t_path[i_path]
        x_path = torch.concat([x_path[i_path], x_path[-1:]])
        e_path = e_path[i_path]
        # parameters
        wx = torch.zeros(steps_accelerate, dtype=torch.float64)
        we = torch.zeros((steps_accelerate, steps_accelerate), dtype=torch.float64)
        for i in range(steps_accelerate):
            x = x_path[i + 1]
            xT = x_path[0:1]
            e_prev = e_path[:i + 1]
            wx[i], we[i, :i + 1], _ = self.solve_scheduer_parameters(xT, e_prev, x)
        return t_path, wx, we

    def search_next_step_with_error_limit(self, x_prev, e_prev, x_flat, i_lowerbound, max_error):
        """Binary-search the furthest index whose fit error is <= ``max_error``.

        Returns ``i_lowerbound`` if no further index satisfies the limit.
        """
        i_upperbound = len(x_flat) - 1
        while i_upperbound > i_lowerbound:
            i_next = (i_lowerbound + i_upperbound + 1) // 2
            x_goal = x_flat[i_next]
            _, _, err_step = self.solve_scheduer_parameters(x_prev, e_prev, x_goal)
            if err_step > max_error:
                i_upperbound = i_next - 1
            else:
                i_lowerbound = i_next
        return i_lowerbound

    def search_path_with_error_limit(self,
                                     max_steps,
                                     t_path,
                                     x_path,
                                     e_path,
                                     max_error):
        """Greedily build a path of at most ``max_steps`` indices within ``max_error``.

        Returns:
            The list of trajectory indices, or None if the end of the
            trajectory cannot be reached in ``max_steps`` steps.
        """
        # prepare for calculation
        num_inference_steps = t_path.shape[0]
        x_flat = x_path.reshape(num_inference_steps + 1, -1)
        e_flat = e_path.reshape(num_inference_steps, -1)
        # search (greedy)
        i_path_acc = [0]
        for step in range(max_steps):
            # NOTE(review): this regresses on the latent at the current path point,
            # whereas resolve_diffusion_process / OLSSSchedulerModel regress on x_T;
            # confirm the asymmetry is intended.
            x_prev = x_flat[i_path_acc[step:step + 1]]
            e_prev = e_flat[i_path_acc]
            i_lowerbound = i_path_acc[step] + 1
            i_next = self.search_next_step_with_error_limit(x_prev, e_prev, x_flat, i_lowerbound, max_error)
            if i_next == num_inference_steps:
                return i_path_acc
            i_path_acc.append(i_next)
        return None

    @torch.no_grad()
    def resolve_diffusion_process_graph(self,
                                        num_accelerate_steps,
                                        t_path,
                                        x_path,
                                        e_path,
                                        max_iter=30,
                                        verbose=0):
        """Bisect on the error limit to find a short path, then fit its coefficients."""
        error_l, error_r = 0.0, 10.0
        for _ in tqdm(range(max_iter), desc="OLSS is solving the parameters"):
            error_m = (error_l + error_r) / 2
            path = self.search_path_with_error_limit(num_accelerate_steps, t_path, x_path, e_path, error_m)
            if path is None:
                error_l = error_m
            else:
                error_r = error_m
            if verbose > 0:
                print(f"search for path with maximum error: {error_m}")
                if path is None:
                    print("    cannot find such path")
                else:
                    print(f"    find a path with length {len(path)}: {path}")
        path = self.search_path_with_error_limit(num_accelerate_steps, t_path, x_path, e_path, error_r)
        timesteps, wx, we = self.resolve_diffusion_process(num_accelerate_steps, t_path, x_path, e_path, i_path=path)
        return timesteps, wx, we


class SchedulerWrapper:
    """Wraps a diffusers scheduler: records its trajectory, then can switch to OLSS.

    Until :meth:`prepare_olss` is called, calls go to the base scheduler and
    every step's input latent, model output and result are recorded per timestep.
    Afterwards, ``set_timesteps``/``step`` are served by the fitted OLSS scheduler.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        self.catch_x, self.catch_e, self.catch_x_ = {}, {}, {}
        self.olss_scheduler = None

    def set_timesteps(self, num_inference_steps, **kwargs):
        active = self.scheduler if self.olss_scheduler is None else self.olss_scheduler
        result = active.set_timesteps(num_inference_steps, **kwargs)
        self.timesteps = active.timesteps
        # Always taken from the base scheduler, presumably because the OLSS
        # path was recorded with the base scheduler's initial noise scaling.
        self.init_noise_sigma = self.scheduler.init_noise_sigma
        self.order = self.scheduler.order
        return result

    def step(self, model_output, timestep, sample, **kwargs):
        if self.olss_scheduler is not None:
            return self.olss_scheduler.step(model_output, timestep, sample, **kwargs)
        result = self.scheduler.step(model_output, timestep, sample, **kwargs)
        key = timestep.tolist() if isinstance(timestep, torch.Tensor) else timestep
        if key not in self.catch_x:
            self.catch_x[key] = []
            self.catch_e[key] = []
            self.catch_x_[key] = []
        self.catch_x[key].append(sample.clone().detach().cpu())
        self.catch_e[key].append(model_output.clone().detach().cpu())
        self.catch_x_[key].append(result[0].clone().detach().cpu())
        return result

    def scale_model_input(self, sample, timestep):
        # NOTE(review): identity scaling; only correct for base schedulers whose
        # own scale_model_input is also the identity — confirm for the one in use.
        return sample

    def add_noise(self, original_samples, noise, timesteps):
        result = self.scheduler.add_noise(original_samples, noise, timesteps)
        return result

    def get_path(self):
        """Stack the recorded trajectory, ordered by decreasing timestep.

        Returns:
            ``(t_path, x_path, e_path)``; ``x_path`` has one more entry than
            ``t_path`` (the output of the final step is appended).

        Raises:
            RuntimeError: if nothing has been recorded yet.
        """
        if not self.catch_x:
            raise RuntimeError("No trajectory recorded yet; run the pipeline with the base scheduler first.")
        t_path = sorted(self.catch_x, reverse=True)
        x_path = [torch.cat(self.catch_x[t], dim=0) for t in t_path]
        e_path = [torch.cat(self.catch_e[t], dim=0) for t in t_path]
        x_path.append(torch.cat(self.catch_x_[t_path[-1]], dim=0))
        t_path = torch.tensor(t_path, dtype=torch.int32)
        x_path = torch.stack(x_path)
        e_path = torch.stack(e_path)
        return t_path, x_path, e_path

    def prepare_olss(self, num_accelerate_steps):
        """Fit an OLSS scheduler with at most ``num_accelerate_steps`` steps and switch to it."""
        solver = OLSSSolver()
        t_path, x_path, e_path = self.get_path()
        timesteps, wx, we = solver.resolve_diffusion_process_graph(
            num_accelerate_steps, t_path, x_path, e_path)
        self.olss_model = OLSSSchedulerModel(wx, we)
        self.olss_scheduler = OLSSScheduler(timesteps, self.olss_model)


# --------------------------------------------------------------------------------
# /FastSDXL/Styler.py:
-------------------------------------------------------------------------------- 1 | # https://github.com/lllyasviel/Fooocus/blob/main/modules/sdxl_styles.py 2 | styles = [ 3 | { 4 | "name": "Default (Slightly Cinematic)", 5 | "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 6 | "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 7 | }, 8 | { 9 | "name": "sai-3d-model", 10 | "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting", 11 | "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting" 12 | }, 13 | { 14 | "name": "sai-analog film", 15 | "prompt": "analog film photo {prompt} . faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage", 16 | "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" 17 | }, 18 | { 19 | "name": "sai-anime", 20 | "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed", 21 | "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast" 22 | }, 23 | { 24 | "name": "sai-cinematic", 25 | "prompt": "cinematic film still {prompt} . shallow depth of field, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 26 | "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 27 | }, 28 | { 29 | "name": "sai-comic book", 30 | "prompt": "comic {prompt} . 
graphic illustration, comic art, graphic novel art, vibrant, highly detailed", 31 | "negative_prompt": "photograph, deformed, glitch, noisy, realistic, stock photo" 32 | }, 33 | { 34 | "name": "sai-craft clay", 35 | "prompt": "play-doh style {prompt} . sculpture, clay art, centered composition, Claymation", 36 | "negative_prompt": "sloppy, messy, grainy, highly detailed, ultra textured, photo" 37 | }, 38 | { 39 | "name": "sai-digital art", 40 | "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed", 41 | "negative_prompt": "photo, photorealistic, realism, ugly" 42 | }, 43 | { 44 | "name": "sai-enhance", 45 | "prompt": "breathtaking {prompt} . award-winning, professional, highly detailed", 46 | "negative_prompt": "ugly, deformed, noisy, blurry, distorted, grainy" 47 | }, 48 | { 49 | "name": "sai-fantasy art", 50 | "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy", 51 | "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white" 52 | }, 53 | { 54 | "name": "sai-isometric", 55 | "prompt": "isometric style {prompt} . vibrant, beautiful, crisp, detailed, ultra detailed, intricate", 56 | "negative_prompt": "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy, realistic, photographic" 57 | }, 58 | { 59 | "name": "sai-line art", 60 | "prompt": "line art drawing {prompt} . 
professional, sleek, modern, minimalist, graphic, line art, vector graphics", 61 | "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic" 62 | }, 63 | { 64 | "name": "sai-lowpoly", 65 | "prompt": "low-poly style {prompt} . low-poly game art, polygon mesh, jagged, blocky, wireframe edges, centered composition", 66 | "negative_prompt": "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo" 67 | }, 68 | { 69 | "name": "sai-neonpunk", 70 | "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional", 71 | "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" 72 | }, 73 | { 74 | "name": "sai-origami", 75 | "prompt": "origami style {prompt} . paper art, pleated paper, folded, origami art, pleats, cut and fold, centered composition", 76 | "negative_prompt": "noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo" 77 | }, 78 | { 79 | "name": "sai-photographic", 80 | "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", 81 | "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly" 82 | }, 83 | { 84 | "name": "sai-pixel art", 85 | "prompt": "pixel-art {prompt} . 
low-res, blocky, pixel art style, 8-bit graphics", 86 | "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic" 87 | }, 88 | { 89 | "name": "sai-texture", 90 | "prompt": "texture {prompt} top down close-up", 91 | "negative_prompt": "ugly, deformed, noisy, blurry" 92 | }, 93 | { 94 | "name": "ads-advertising", 95 | "prompt": "advertising poster style {prompt} . Professional, modern, product-focused, commercial, eye-catching, highly detailed", 96 | "negative_prompt": "noisy, blurry, amateurish, sloppy, unattractive" 97 | }, 98 | { 99 | "name": "ads-automotive", 100 | "prompt": "automotive advertisement style {prompt} . sleek, dynamic, professional, commercial, vehicle-focused, high-resolution, highly detailed", 101 | "negative_prompt": "noisy, blurry, unattractive, sloppy, unprofessional" 102 | }, 103 | { 104 | "name": "ads-corporate", 105 | "prompt": "corporate branding style {prompt} . professional, clean, modern, sleek, minimalist, business-oriented, highly detailed", 106 | "negative_prompt": "noisy, blurry, grungy, sloppy, cluttered, disorganized" 107 | }, 108 | { 109 | "name": "ads-fashion editorial", 110 | "prompt": "fashion editorial style {prompt} . high fashion, trendy, stylish, editorial, magazine style, professional, highly detailed", 111 | "negative_prompt": "outdated, blurry, noisy, unattractive, sloppy" 112 | }, 113 | { 114 | "name": "ads-food photography", 115 | "prompt": "food photography style {prompt} . appetizing, professional, culinary, high-resolution, commercial, highly detailed", 116 | "negative_prompt": "unappetizing, sloppy, unprofessional, noisy, blurry" 117 | }, 118 | { 119 | "name": "ads-gourmet food photography", 120 | "prompt": "gourmet food photo of {prompt} . 
soft natural lighting, macro details, vibrant colors, fresh ingredients, glistening textures, bokeh background, styled plating, wooden tabletop, garnished, tantalizing, editorial quality", 121 | "negative_prompt": "cartoon, anime, sketch, grayscale, dull, overexposed, cluttered, messy plate, deformed" 122 | }, 123 | { 124 | "name": "ads-luxury", 125 | "prompt": "luxury product style {prompt} . elegant, sophisticated, high-end, luxurious, professional, highly detailed", 126 | "negative_prompt": "cheap, noisy, blurry, unattractive, amateurish" 127 | }, 128 | { 129 | "name": "ads-real estate", 130 | "prompt": "real estate photography style {prompt} . professional, inviting, well-lit, high-resolution, property-focused, commercial, highly detailed", 131 | "negative_prompt": "dark, blurry, unappealing, noisy, unprofessional" 132 | }, 133 | { 134 | "name": "ads-retail", 135 | "prompt": "retail packaging style {prompt} . vibrant, enticing, commercial, product-focused, eye-catching, professional, highly detailed", 136 | "negative_prompt": "noisy, blurry, amateurish, sloppy, unattractive" 137 | }, 138 | { 139 | "name": "artstyle-abstract", 140 | "prompt": "abstract style {prompt} . non-representational, colors and shapes, expression of feelings, imaginative, highly detailed", 141 | "negative_prompt": "realistic, photographic, figurative, concrete" 142 | }, 143 | { 144 | "name": "artstyle-abstract expressionism", 145 | "prompt": "abstract expressionist painting {prompt} . energetic brushwork, bold colors, abstract forms, expressive, emotional", 146 | "negative_prompt": "realistic, photorealistic, low contrast, plain, simple, monochrome" 147 | }, 148 | { 149 | "name": "artstyle-art deco", 150 | "prompt": "art deco style {prompt} . 
geometric shapes, bold colors, luxurious, elegant, decorative, symmetrical, ornate, detailed", 151 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, modernist, minimalist" 152 | }, 153 | { 154 | "name": "artstyle-art nouveau", 155 | "prompt": "art nouveau style {prompt} . elegant, decorative, curvilinear forms, nature-inspired, ornate, detailed", 156 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, modernist, minimalist" 157 | }, 158 | { 159 | "name": "artstyle-constructivist", 160 | "prompt": "constructivist style {prompt} . geometric shapes, bold colors, dynamic composition, propaganda art style", 161 | "negative_prompt": "realistic, photorealistic, low contrast, plain, simple, abstract expressionism" 162 | }, 163 | { 164 | "name": "artstyle-cubist", 165 | "prompt": "cubist artwork {prompt} . geometric shapes, abstract, innovative, revolutionary", 166 | "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy" 167 | }, 168 | { 169 | "name": "artstyle-expressionist", 170 | "prompt": "expressionist {prompt} . raw, emotional, dynamic, distortion for emotional effect, vibrant, use of unusual colors, detailed", 171 | "negative_prompt": "realism, symmetry, quiet, calm, photo" 172 | }, 173 | { 174 | "name": "artstyle-graffiti", 175 | "prompt": "graffiti style {prompt} . street art, vibrant, urban, detailed, tag, mural", 176 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic" 177 | }, 178 | { 179 | "name": "artstyle-hyperrealism", 180 | "prompt": "hyperrealistic art {prompt} . extremely high-resolution details, photographic, realism pushed to extreme, fine texture, incredibly lifelike", 181 | "negative_prompt": "simplified, abstract, unrealistic, impressionistic, low resolution" 182 | }, 183 | { 184 | "name": "artstyle-impressionist", 185 | "prompt": "impressionist painting {prompt} . 
loose brushwork, vibrant color, light and shadow play, captures feeling over form", 186 | "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy" 187 | }, 188 | { 189 | "name": "artstyle-pointillism", 190 | "prompt": "pointillism style {prompt} . composed entirely of small, distinct dots of color, vibrant, highly detailed", 191 | "negative_prompt": "line drawing, smooth shading, large color fields, simplistic" 192 | }, 193 | { 194 | "name": "artstyle-pop art", 195 | "prompt": "pop Art style {prompt} . bright colors, bold outlines, popular culture themes, ironic or kitsch", 196 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, minimalist" 197 | }, 198 | { 199 | "name": "artstyle-psychedelic", 200 | "prompt": "psychedelic style {prompt} . vibrant colors, swirling patterns, abstract forms, surreal, trippy", 201 | "negative_prompt": "monochrome, black and white, low contrast, realistic, photorealistic, plain, simple" 202 | }, 203 | { 204 | "name": "artstyle-renaissance", 205 | "prompt": "renaissance style {prompt} . realistic, perspective, light and shadow, religious or mythological themes, highly detailed", 206 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, modernist, minimalist, abstract" 207 | }, 208 | { 209 | "name": "artstyle-steampunk", 210 | "prompt": "steampunk style {prompt} . antique, mechanical, brass and copper tones, gears, intricate, detailed", 211 | "negative_prompt": "deformed, glitch, noisy, low contrast, anime, photorealistic" 212 | }, 213 | { 214 | "name": "artstyle-surrealist", 215 | "prompt": "surrealist art {prompt} . dreamlike, mysterious, provocative, symbolic, intricate, detailed", 216 | "negative_prompt": "anime, photorealistic, realistic, deformed, glitch, noisy, low contrast" 217 | }, 218 | { 219 | "name": "artstyle-typography", 220 | "prompt": "typographic art {prompt} . 
stylized, intricate, detailed, artistic, text-based", 221 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic" 222 | }, 223 | { 224 | "name": "artstyle-watercolor", 225 | "prompt": "watercolor painting {prompt} . vibrant, beautiful, painterly, detailed, textural, artistic", 226 | "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy" 227 | }, 228 | { 229 | "name": "futuristic-biomechanical", 230 | "prompt": "biomechanical style {prompt} . blend of organic and mechanical elements, futuristic, cybernetic, detailed, intricate", 231 | "negative_prompt": "natural, rustic, primitive, organic, simplistic" 232 | }, 233 | { 234 | "name": "futuristic-biomechanical cyberpunk", 235 | "prompt": "biomechanical cyberpunk {prompt} . cybernetics, human-machine fusion, dystopian, organic meets artificial, dark, intricate, highly detailed", 236 | "negative_prompt": "natural, colorful, deformed, sketch, low contrast, watercolor" 237 | }, 238 | { 239 | "name": "futuristic-cybernetic", 240 | "prompt": "cybernetic style {prompt} . futuristic, technological, cybernetic enhancements, robotics, artificial intelligence themes", 241 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, historical, medieval" 242 | }, 243 | { 244 | "name": "futuristic-cybernetic robot", 245 | "prompt": "cybernetic robot {prompt} . android, AI, machine, metal, wires, tech, futuristic, highly detailed", 246 | "negative_prompt": "organic, natural, human, sketch, watercolor, low contrast" 247 | }, 248 | { 249 | "name": "futuristic-cyberpunk cityscape", 250 | "prompt": "cyberpunk cityscape {prompt} . neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed", 251 | "negative_prompt": "natural, rural, deformed, low contrast, black and white, sketch, watercolor" 252 | }, 253 | { 254 | "name": "futuristic-futuristic", 255 | "prompt": "futuristic style {prompt} . 
sleek, modern, ultramodern, high tech, detailed", 256 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vintage, antique" 257 | }, 258 | { 259 | "name": "futuristic-retro cyberpunk", 260 | "prompt": "retro cyberpunk {prompt} . 80's inspired, synthwave, neon, vibrant, detailed, retro futurism", 261 | "negative_prompt": "modern, desaturated, black and white, realism, low contrast" 262 | }, 263 | { 264 | "name": "futuristic-retro futurism", 265 | "prompt": "retro-futuristic {prompt} . vintage sci-fi, 50s and 60s style, atomic age, vibrant, highly detailed", 266 | "negative_prompt": "contemporary, realistic, rustic, primitive" 267 | }, 268 | { 269 | "name": "futuristic-sci-fi", 270 | "prompt": "sci-fi style {prompt} . futuristic, technological, alien worlds, space themes, advanced civilizations", 271 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, historical, medieval" 272 | }, 273 | { 274 | "name": "futuristic-vaporwave", 275 | "prompt": "vaporwave style {prompt} . retro aesthetic, cyberpunk, vibrant, neon colors, vintage 80s and 90s style, highly detailed", 276 | "negative_prompt": "monochrome, muted colors, realism, rustic, minimalist, dark" 277 | }, 278 | { 279 | "name": "game-bubble bobble", 280 | "prompt": "Bubble Bobble style {prompt} . 8-bit, cute, pixelated, fantasy, vibrant, reminiscent of Bubble Bobble game", 281 | "negative_prompt": "realistic, modern, photorealistic, violent, horror" 282 | }, 283 | { 284 | "name": "game-cyberpunk game", 285 | "prompt": "cyberpunk game style {prompt} . neon, dystopian, futuristic, digital, vibrant, detailed, high contrast, reminiscent of cyberpunk genre video games", 286 | "negative_prompt": "historical, natural, rustic, low detailed" 287 | }, 288 | { 289 | "name": "game-fighting game", 290 | "prompt": "fighting game style {prompt} . 
dynamic, vibrant, action-packed, detailed character design, reminiscent of fighting video games", 291 | "negative_prompt": "peaceful, calm, minimalist, photorealistic" 292 | }, 293 | { 294 | "name": "game-gta", 295 | "prompt": "GTA-style artwork {prompt} . satirical, exaggerated, pop art style, vibrant colors, iconic characters, action-packed", 296 | "negative_prompt": "realistic, black and white, low contrast, impressionist, cubist, noisy, blurry, deformed" 297 | }, 298 | { 299 | "name": "game-mario", 300 | "prompt": "Super Mario style {prompt} . vibrant, cute, cartoony, fantasy, playful, reminiscent of Super Mario series", 301 | "negative_prompt": "realistic, modern, horror, dystopian, violent" 302 | }, 303 | { 304 | "name": "game-minecraft", 305 | "prompt": "Minecraft style {prompt} . blocky, pixelated, vibrant colors, recognizable characters and objects, game assets", 306 | "negative_prompt": "smooth, realistic, detailed, photorealistic, noise, blurry, deformed" 307 | }, 308 | { 309 | "name": "game-pokemon", 310 | "prompt": "Pokémon style {prompt} . vibrant, cute, anime, fantasy, reminiscent of Pokémon series", 311 | "negative_prompt": "realistic, modern, horror, dystopian, violent" 312 | }, 313 | { 314 | "name": "game-retro arcade", 315 | "prompt": "retro arcade style {prompt} . 8-bit, pixelated, vibrant, classic video game, old school gaming, reminiscent of 80s and 90s arcade games", 316 | "negative_prompt": "modern, ultra-high resolution, photorealistic, 3D" 317 | }, 318 | { 319 | "name": "game-retro game", 320 | "prompt": "retro game art {prompt} . 16-bit, vibrant colors, pixelated, nostalgic, charming, fun", 321 | "negative_prompt": "realistic, photorealistic, 35mm film, deformed, glitch, low contrast, noisy" 322 | }, 323 | { 324 | "name": "game-rpg fantasy game", 325 | "prompt": "role-playing game (RPG) style fantasy {prompt} . 
detailed, vibrant, immersive, reminiscent of high fantasy RPG games", 326 | "negative_prompt": "sci-fi, modern, urban, futuristic, low detailed" 327 | }, 328 | { 329 | "name": "game-strategy game", 330 | "prompt": "strategy game style {prompt} . overhead view, detailed map, units, reminiscent of real-time strategy video games", 331 | "negative_prompt": "first-person view, modern, photorealistic" 332 | }, 333 | { 334 | "name": "game-streetfighter", 335 | "prompt": "Street Fighter style {prompt} . vibrant, dynamic, arcade, 2D fighting game, highly detailed, reminiscent of Street Fighter series", 336 | "negative_prompt": "3D, realistic, modern, photorealistic, turn-based strategy" 337 | }, 338 | { 339 | "name": "game-zelda", 340 | "prompt": "Legend of Zelda style {prompt} . vibrant, fantasy, detailed, epic, heroic, reminiscent of The Legend of Zelda series", 341 | "negative_prompt": "sci-fi, modern, realistic, horror" 342 | }, 343 | { 344 | "name": "misc-architectural", 345 | "prompt": "architectural style {prompt} . clean lines, geometric shapes, minimalist, modern, architectural drawing, highly detailed", 346 | "negative_prompt": "curved lines, ornate, baroque, abstract, grunge" 347 | }, 348 | { 349 | "name": "misc-disco", 350 | "prompt": "disco-themed {prompt} . vibrant, groovy, retro 70s style, shiny disco balls, neon lights, dance floor, highly detailed", 351 | "negative_prompt": "minimalist, rustic, monochrome, contemporary, simplistic" 352 | }, 353 | { 354 | "name": "misc-dreamscape", 355 | "prompt": "dreamscape {prompt} . surreal, ethereal, dreamy, mysterious, fantasy, highly detailed", 356 | "negative_prompt": "realistic, concrete, ordinary, mundane" 357 | }, 358 | { 359 | "name": "misc-dystopian", 360 | "prompt": "dystopian style {prompt} . 
bleak, post-apocalyptic, somber, dramatic, highly detailed", 361 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, cheerful, optimistic, vibrant, colorful" 362 | }, 363 | { 364 | "name": "misc-fairy tale", 365 | "prompt": "fairy tale {prompt} . magical, fantastical, enchanting, storybook style, highly detailed", 366 | "negative_prompt": "realistic, modern, ordinary, mundane" 367 | }, 368 | { 369 | "name": "misc-gothic", 370 | "prompt": "gothic style {prompt} . dark, mysterious, haunting, dramatic, ornate, detailed", 371 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, cheerful, optimistic" 372 | }, 373 | { 374 | "name": "misc-grunge", 375 | "prompt": "grunge style {prompt} . textured, distressed, vintage, edgy, punk rock vibe, dirty, noisy", 376 | "negative_prompt": "smooth, clean, minimalist, sleek, modern, photorealistic" 377 | }, 378 | { 379 | "name": "misc-horror", 380 | "prompt": "horror-themed {prompt} . eerie, unsettling, dark, spooky, suspenseful, grim, highly detailed", 381 | "negative_prompt": "cheerful, bright, vibrant, light-hearted, cute" 382 | }, 383 | { 384 | "name": "misc-kawaii", 385 | "prompt": "kawaii style {prompt} . cute, adorable, brightly colored, cheerful, anime influence, highly detailed", 386 | "negative_prompt": "dark, scary, realistic, monochrome, abstract" 387 | }, 388 | { 389 | "name": "misc-lovecraftian", 390 | "prompt": "lovecraftian horror {prompt} . eldritch, cosmic horror, unknown, mysterious, surreal, highly detailed", 391 | "negative_prompt": "light-hearted, mundane, familiar, simplistic, realistic" 392 | }, 393 | { 394 | "name": "misc-macabre", 395 | "prompt": "macabre style {prompt} . dark, gothic, grim, haunting, highly detailed", 396 | "negative_prompt": "bright, cheerful, light-hearted, cartoonish, cute" 397 | }, 398 | { 399 | "name": "misc-manga", 400 | "prompt": "manga style {prompt} . 
vibrant, high-energy, detailed, iconic, Japanese comic style", 401 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style" 402 | }, 403 | { 404 | "name": "misc-metropolis", 405 | "prompt": "metropolis-themed {prompt} . urban, cityscape, skyscrapers, modern, futuristic, highly detailed", 406 | "negative_prompt": "rural, natural, rustic, historical, simple" 407 | }, 408 | { 409 | "name": "misc-minimalist", 410 | "prompt": "minimalist style {prompt} . simple, clean, uncluttered, modern, elegant", 411 | "negative_prompt": "ornate, complicated, highly detailed, cluttered, disordered, messy, noisy" 412 | }, 413 | { 414 | "name": "misc-monochrome", 415 | "prompt": "monochrome {prompt} . black and white, contrast, tone, texture, detailed", 416 | "negative_prompt": "colorful, vibrant, noisy, blurry, deformed" 417 | }, 418 | { 419 | "name": "misc-nautical", 420 | "prompt": "nautical-themed {prompt} . sea, ocean, ships, maritime, beach, marine life, highly detailed", 421 | "negative_prompt": "landlocked, desert, mountains, urban, rustic" 422 | }, 423 | { 424 | "name": "misc-space", 425 | "prompt": "space-themed {prompt} . cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed", 426 | "negative_prompt": "earthly, mundane, ground-based, realism" 427 | }, 428 | { 429 | "name": "misc-stained glass", 430 | "prompt": "stained glass style {prompt} . vibrant, beautiful, translucent, intricate, detailed", 431 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic" 432 | }, 433 | { 434 | "name": "misc-techwear fashion", 435 | "prompt": "techwear fashion {prompt} . futuristic, cyberpunk, urban, tactical, sleek, dark, highly detailed", 436 | "negative_prompt": "vintage, rural, colorful, low contrast, realism, sketch, watercolor" 437 | }, 438 | { 439 | "name": "misc-tribal", 440 | "prompt": "tribal style {prompt} . 
indigenous, ethnic, traditional patterns, bold, natural colors, highly detailed", 441 | "negative_prompt": "modern, futuristic, minimalist, pastel" 442 | }, 443 | { 444 | "name": "misc-zentangle", 445 | "prompt": "zentangle {prompt} . intricate, abstract, monochrome, patterns, meditative, highly detailed", 446 | "negative_prompt": "colorful, representative, simplistic, large fields of color" 447 | }, 448 | { 449 | "name": "papercraft-collage", 450 | "prompt": "collage style {prompt} . mixed media, layered, textural, detailed, artistic", 451 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic" 452 | }, 453 | { 454 | "name": "papercraft-flat papercut", 455 | "prompt": "flat papercut style {prompt} . silhouette, clean cuts, paper, sharp edges, minimalist, color block", 456 | "negative_prompt": "3D, high detail, noise, grainy, blurry, painting, drawing, photo, disfigured" 457 | }, 458 | { 459 | "name": "papercraft-kirigami", 460 | "prompt": "kirigami representation of {prompt} . 3D, paper folding, paper cutting, Japanese, intricate, symmetrical, precision, clean lines", 461 | "negative_prompt": "painting, drawing, 2D, noisy, blurry, deformed" 462 | }, 463 | { 464 | "name": "papercraft-paper mache", 465 | "prompt": "paper mache representation of {prompt} . 3D, sculptural, textured, handmade, vibrant, fun", 466 | "negative_prompt": "2D, flat, photo, sketch, digital art, deformed, noisy, blurry" 467 | }, 468 | { 469 | "name": "papercraft-paper quilling", 470 | "prompt": "paper quilling art of {prompt} . intricate, delicate, curling, rolling, shaping, coiling, loops, 3D, dimensional, ornamental", 471 | "negative_prompt": "photo, painting, drawing, 2D, flat, deformed, noisy, blurry" 472 | }, 473 | { 474 | "name": "papercraft-papercut collage", 475 | "prompt": "papercut collage of {prompt} . 
mixed media, textured paper, overlapping, asymmetrical, abstract, vibrant", 476 | "negative_prompt": "photo, 3D, realistic, drawing, painting, high detail, disfigured" 477 | }, 478 | { 479 | "name": "papercraft-papercut shadow box", 480 | "prompt": "3D papercut shadow box of {prompt} . layered, dimensional, depth, silhouette, shadow, papercut, handmade, high contrast", 481 | "negative_prompt": "painting, drawing, photo, 2D, flat, high detail, blurry, noisy, disfigured" 482 | }, 483 | { 484 | "name": "papercraft-stacked papercut", 485 | "prompt": "stacked papercut art of {prompt} . 3D, layered, dimensional, depth, precision cut, stacked layers, papercut, high contrast", 486 | "negative_prompt": "2D, flat, noisy, blurry, painting, drawing, photo, deformed" 487 | }, 488 | { 489 | "name": "papercraft-thick layered papercut", 490 | "prompt": "thick layered papercut art of {prompt} . deep 3D, volumetric, dimensional, depth, thick paper, high stack, heavy texture, tangible layers", 491 | "negative_prompt": "2D, flat, thin paper, low stack, smooth texture, painting, drawing, photo, deformed" 492 | }, 493 | { 494 | "name": "photo-alien", 495 | "prompt": "alien-themed {prompt} . extraterrestrial, cosmic, otherworldly, mysterious, sci-fi, highly detailed", 496 | "negative_prompt": "earthly, mundane, common, realistic, simple" 497 | }, 498 | { 499 | "name": "photo-film noir", 500 | "prompt": "film noir style {prompt} . monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic", 501 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful" 502 | }, 503 | { 504 | "name": "photo-glamour", 505 | "prompt": "glamorous photo {prompt} . 
high fashion, luxurious, extravagant, stylish, sensual, opulent, elegance, stunning beauty, professional, high contrast, detailed", 506 | "negative_prompt": "ugly, deformed, noisy, blurry, distorted, grainy, sketch, low contrast, dull, plain, modest" 507 | }, 508 | { 509 | "name": "photo-hdr", 510 | "prompt": "HDR photo of {prompt} . High dynamic range, vivid, rich details, clear shadows and highlights, realistic, intense, enhanced contrast, highly detailed", 511 | "negative_prompt": "flat, low contrast, oversaturated, underexposed, overexposed, blurred, noisy" 512 | }, 513 | { 514 | "name": "photo-iphone photographic", 515 | "prompt": "iphone photo {prompt} . large depth of field, deep depth of field, highly detailed", 516 | "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, shallow depth of field, bokeh" 517 | }, 518 | { 519 | "name": "photo-long exposure", 520 | "prompt": "long exposure photo of {prompt} . Blurred motion, streaks of light, surreal, dreamy, ghosting effect, highly detailed", 521 | "negative_prompt": "static, noisy, deformed, shaky, abrupt, flat, low contrast" 522 | }, 523 | { 524 | "name": "photo-neon noir", 525 | "prompt": "neon noir {prompt} . cyberpunk, dark, rainy streets, neon signs, high contrast, low light, vibrant, highly detailed", 526 | "negative_prompt": "bright, sunny, daytime, low contrast, black and white, sketch, watercolor" 527 | }, 528 | { 529 | "name": "photo-silhouette", 530 | "prompt": "silhouette style {prompt} . high contrast, minimalistic, black and white, stark, dramatic", 531 | "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, color, realism, photorealistic" 532 | }, 533 | { 534 | "name": "photo-tilt-shift", 535 | "prompt": "tilt-shift photo of {prompt} . 
selective focus, miniature effect, blurred background, highly detailed, vibrant, perspective control", 536 | "negative_prompt": "blurry, noisy, deformed, flat, low contrast, unrealistic, oversaturated, underexposed" 537 | } 538 | ] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] [Zhongjie Duan] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastSDXL 2 | 3 | This is an efficient implementation of Stable-Diffusion-XL. I made the following improvements: 4 | 5 | * Reconstruct the architecture of the UNet. This UNet implementation is faster than other implementations ([Diffusers](https://github.com/huggingface/diffusers), [Fooocus](https://github.com/lllyasviel/Fooocus), etc.). If you are interested in this implementation, please see `FastSDXL/BlockUNet.py`. The source code of this UNet is short and easy to understand. You can also use this component in your own projects. 6 | * Use a trainable scheduler named [OLSS](https://arxiv.org/abs/2305.14677). The implementation of OLSS is based on [this project](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler). I find that this scheduler improves the quality of images generated with a given number of steps, but it must be trained first (training takes only a few minutes). To synthesize images in a specific style, OLSS is a good choice.
7 | 8 | ## Usage 9 | 10 | The code is headless. I developed this project based on `diffusers==0.21.3`. If it does not run with another version of `diffusers`, please open an issue. 11 | 12 | ``` 13 | pip install diffusers safetensors torch gradio 14 | ``` 15 | 16 | To launch a web UI without the OLSS scheduler, run the following command. 17 | 18 | ``` 19 | python launch_without_olss.py 20 | ``` 21 | 22 | To train an OLSS scheduler, see `run_olss.py` for more details. 23 | 24 | I trained an OLSS scheduler for the "Slightly Cinematic" style. You can use it by running the following command. 25 | 26 | ``` 27 | python launch_with_olss.py 28 | ``` 29 | 30 | ## Efficiency of the reconstructed UNet 31 | 32 | I tested the code on an NVIDIA 3060 laptop GPU (6 GB, 85 W). The resolution is 1024x1024, and the model is converted to float16. 33 | 34 | * Diffusers: CUDA out of memory 35 | * Fooocus: 1.78s/it 36 | * FastSDXL (ours): 1.17s/it 37 | 38 | ## Performance of the OLSS scheduler 39 | 40 | Here are some examples comparing OLSS with DDIM and DPM-Solver. 41 | 42 | Prompt template: cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy 43 | Negative prompt: anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured 44 | 45 | * a young girl, black hair, white clothes, in a garden, the background is red and white flowers 46 | 47 | In the image generated by OLSS, the details of the flowers on the right are more realistic. 48 | 49 | |DDIM|DPM|OLSS| 50 | |-|-|-| 51 | |![ddim](https://github.com/Artiprocher/FastSDXL/assets/35051019/942ccd69-1e7a-43ea-80f7-747a756d1cfb)|![dpmsolver](https://github.com/Artiprocher/FastSDXL/assets/35051019/a33302d8-2744-4e13-8bbd-66d13fcb832f)|![olss](https://github.com/Artiprocher/FastSDXL/assets/35051019/06446a51-8e41-4e3e-b37d-21ed9370d998)| 52 | 53 | * a forest in spring, birds 54 | 55 | OLSS generates more birds.
56 | 57 | |DDIM|DPM|OLSS| 58 | |-|-|-| 59 | |![ddim](https://github.com/Artiprocher/FastSDXL/assets/35051019/4a13362f-1d82-4400-8d96-cafd2eff4907)|![dpmsolver](https://github.com/Artiprocher/FastSDXL/assets/35051019/9174b534-46bf-41da-b914-0f3259bfe1b3)|![olss](https://github.com/Artiprocher/FastSDXL/assets/35051019/223667d6-c3d4-4b0c-987f-b687754876e3)| 60 | 61 | * an orange cat and a pink ball on a white sofa 62 | 63 | In this example, the image generated by OLSS is completely different from the others. 64 | 65 | |DDIM|DPM|OLSS| 66 | |-|-|-| 67 | |![ddim](https://github.com/Artiprocher/FastSDXL/assets/35051019/97d26ea9-cde2-4ed6-a0e9-b08e012cad07)|![dpmsolver](https://github.com/Artiprocher/FastSDXL/assets/35051019/655a4d59-12a0-4a1e-a5bd-fc773f1c64ae)|![olss](https://github.com/Artiprocher/FastSDXL/assets/35051019/a201bf44-a343-40bc-a4ed-58226d34af64)| 68 | 69 | * a robot, with blue lightsaber, in a city 70 | 71 | Sometimes OLSS can change the composition and fix composition errors. 72 | 73 | |DDIM|DPM|OLSS| 74 | |-|-|-| 75 | |![ddim](https://github.com/Artiprocher/FastSDXL/assets/35051019/deae9978-f2c2-40e2-bb3b-7cbc118d8780)|![dpmsolver](https://github.com/Artiprocher/FastSDXL/assets/35051019/c68e9dd1-bba5-49c1-8592-c66016d89cff)|![olss](https://github.com/Artiprocher/FastSDXL/assets/35051019/e9a0337a-9b69-447f-80ab-5021fdff63ad)| 76 | -------------------------------------------------------------------------------- /launch_with_olss.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import gradio as gr 4 | from diffusers import DiffusionPipeline 5 | import torch 6 | from FastSDXL.BlockUNet import BlockUNet 7 | from FastSDXL.Styler import styles 8 | from FastSDXL.OLSS import OLSSScheduler 9 | 10 | 11 | pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True,
variant="fp16") 12 | block_unet = BlockUNet().half().to("cuda") 13 | block_unet.from_diffusers(state_dict=pipe.unet.state_dict()) 14 | pipe.unet = block_unet 15 | pipe.scheduler = OLSSScheduler.load("models/olss_scheduler.bin") 16 | pipe.enable_model_cpu_offload() 17 | 18 | 19 | def generate_image(prompt, negative_prompt, height, width, style_name="Default (Slightly Cinematic)", denoising_steps=30): 20 | height = (height + 63) // 64 * 64 21 | width = (width + 63) // 64 * 64 22 | for style in styles: 23 | if style["name"] == style_name: 24 | prompt = style["prompt"].replace("{prompt}", prompt) 25 | negative_prompt = style["negative_prompt"] + negative_prompt 26 | break 27 | print("Prompt:", prompt) 28 | print("Negative prompt:", negative_prompt) 29 | image = pipe(prompt=prompt, height=height, width=width, num_inference_steps=denoising_steps).images[0] 30 | return image 31 | 32 | with gr.Blocks() as demo: 33 | with gr.Row(): 34 | with gr.Column(): 35 | prompt = gr.Textbox(label="Prompt") 36 | negative_prompt = gr.Textbox(label="Negative prompt") 37 | height = gr.Slider(label="Height", minimum=512, maximum=2048, value=1024, step=64) 38 | width = gr.Slider(label="Width", minimum=512, maximum=2048, value=1024, step=64) 39 | button = gr.Button(label="Generate") 40 | with gr.Column(): 41 | image = gr.Image(label="Generated image") 42 | button.click(fn=generate_image, inputs=[prompt, negative_prompt, height, width], outputs=[image]) 43 | 44 | demo.queue() 45 | demo.launch() 46 | -------------------------------------------------------------------------------- /launch_without_olss.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import gradio as gr 4 | from diffusers import DiffusionPipeline 5 | import torch 6 | from FastSDXL.BlockUNet import BlockUNet 7 | from FastSDXL.Styler import styles 8 | 9 | 10 | pipe = 
DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16") 11 | block_unet = BlockUNet().half().to("cuda") 12 | block_unet.from_diffusers(state_dict=pipe.unet.state_dict()) 13 | pipe.unet = block_unet 14 | pipe.enable_model_cpu_offload() 15 | 16 | 17 | def generate_image(prompt, negative_prompt, style_name, height, width, denoising_steps): 18 | height = (height + 63) // 64 * 64 19 | width = (width + 63) // 64 * 64 20 | for style in styles: 21 | if style["name"] == style_name: 22 | prompt = style["prompt"].replace("{prompt}", prompt) 23 | negative_prompt = style["negative_prompt"] + negative_prompt 24 | break 25 | print("Prompt:", prompt) 26 | print("Negative prompt:", negative_prompt) 27 | image = pipe(prompt=prompt, height=height, width=width, num_inference_steps=denoising_steps).images[0] 28 | return image 29 | 30 | with gr.Blocks() as demo: 31 | with gr.Row(): 32 | with gr.Column(): 33 | prompt = gr.Textbox(label="Prompt") 34 | negative_prompt = gr.Textbox(label="Negative prompt") 35 | style = gr.Dropdown(label="Style", choices=["None"] + [style["name"] for style in styles], value="None") 36 | height = gr.Slider(label="Height", minimum=512, maximum=2048, value=1024, step=64) 37 | width = gr.Slider(label="Width", minimum=512, maximum=2048, value=1024, step=64) 38 | denoising_steps = gr.Slider(label="Denoising steps", minimum=10, maximum=100, value=30, step=1) 39 | button = gr.Button(label="Generate") 40 | with gr.Column(): 41 | image = gr.Image(label="Generated image") 42 | button.click(fn=generate_image, inputs=[prompt, negative_prompt, style, height, width, denoising_steps], outputs=[image]) 43 | 44 | demo.queue() 45 | demo.launch() 46 | -------------------------------------------------------------------------------- /models/olss_scheduler.bin: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Artiprocher/FastSDXL/7d561b74c66682b6f2d6391f1eb667e16836cc20/models/olss_scheduler.bin -------------------------------------------------------------------------------- /run_olss.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | import torch 3 | from FastSDXL.BlockUNet import BlockUNet 4 | from FastSDXL.OLSS import SchedulerWrapper, OLSSScheduler 5 | from diffusers import DDIMScheduler, DPMSolverMultistepScheduler 6 | 7 | 8 | pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16") 9 | block_unet = BlockUNet().half() 10 | block_unet.from_diffusers(state_dict=pipe.unet.state_dict()) 11 | pipe.unet = block_unet 12 | pipe.enable_model_cpu_offload() 13 | 14 | 15 | # Train 16 | train_steps = 300 17 | inference_steps = 30 18 | pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config)) 19 | pipe( 20 | prompt="cinematic still a dog. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 21 | negative_prompt="anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", 22 | height=1024, width=1024, num_inference_steps=train_steps 23 | ) 24 | pipe( 25 | prompt="cinematic still a cat. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 26 | negative_prompt="anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", 27 | height=1024, width=1024, num_inference_steps=train_steps 28 | ) 29 | pipe( 30 | prompt="cinematic still a woman. 
emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 31 | negative_prompt="anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", 32 | height=1024, width=1024, num_inference_steps=train_steps 33 | ) 34 | pipe( 35 | prompt="cinematic still a car. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", 36 | negative_prompt="anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", 37 | height=1024, width=1024, num_inference_steps=train_steps 38 | ) 39 | pipe.scheduler.prepare_olss(inference_steps) 40 | pipe.scheduler.olss_scheduler.save("models/olss_scheduler.bin") 41 | 42 | 43 | # Test 44 | prompt = "cinematic still a forest in spring, birds. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy" 45 | negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured" 46 | 47 | torch.manual_seed(0) 48 | pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0/scheduler") 49 | image = pipe(prompt=prompt, negative_prompt=negative_prompt, height=1024, width=1024, num_inference_steps=inference_steps).images[0] 50 | image.save(f"dpmsolver.png") 51 | 52 | torch.manual_seed(0) 53 | pipe.scheduler = OLSSScheduler.load("models/olss_scheduler.bin") 54 | image = pipe(prompt=prompt, negative_prompt=negative_prompt, height=1024, width=1024, num_inference_steps=inference_steps).images[0] 55 | image.save(f"olss.png") 56 | 57 | torch.manual_seed(0) 58 | pipe.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0/scheduler") 59 | image = pipe(prompt=prompt, negative_prompt=negative_prompt, height=1024, width=1024, 
num_inference_steps=inference_steps).images[0] 60 | image.save(f"ddim.png") 61 | --------------------------------------------------------------------------------