├── .gitignore ├── create_model.py ├── generate denoiser_decoder.ipynb ├── generate sample.ipynb ├── generate seq2seq.ipynb ├── generation.gif ├── readme.md ├── requirements.txt ├── scripts ├── train_denoise_decoder.py ├── train_sample.py ├── train_seq2seq_completion.py └── train_seq2seq_instruct.py ├── src ├── __init__.py ├── decoders │ ├── __init__.py │ └── bert_decoder.py ├── denoiser_decoder.py ├── denoisers │ ├── __init__.py │ ├── configuration_diffbert.py │ ├── configuration_diffllama.py │ ├── configuration_diffmamba.py │ ├── modeling_diffbert.py │ ├── modeling_diffbert_sample.py │ ├── modeling_diffllama.py │ └── modeling_diffmamba.py └── schedulers │ ├── __init__.py │ ├── ddpm.py │ └── euler_ancestral_discrete.py ├── train.sh ├── train_denoise_decoder.sh └── train_seq2seq.sh /.gitignore: -------------------------------------------------------------------------------- 1 | /models 2 | 3 | /wandb 4 | 5 | */__pycache__ 6 | /__pycache__ 7 | */*/__pycache__ 8 | 9 | experimental.ipynb 10 | /experimental -------------------------------------------------------------------------------- /create_model.py: -------------------------------------------------------------------------------- 1 | # from modeling_diffbert_sample import DiffBertForDiffusion, DiffBertConfig 2 | from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM, DiffMambaConfig 3 | from transformers import AutoTokenizer, BertLMHeadModel, BertConfig 4 | from src.schedulers.ddpm import DDPMScheduler 5 | import torch 6 | 7 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 8 | 9 | timesteps = 1200 10 | scheduler = DDPMScheduler( 11 | beta_schedule = "sqrt", 12 | prediction_type ="sample", 13 | num_train_timesteps = timesteps 14 | ) 15 | 16 | 17 | config = DiffMambaConfig( 18 | hidden_size=768, 19 | num_hidden_layers=20, 20 | num_attention_heads=12, 21 | intermediate_size=3072, 22 | vocab_size=tokenizer.vocab_size, 23 | timesteps=timesteps, 24 | torch_dtype=torch.float16 25 | ) 26 | 27 | decoder_config = BertConfig( 28 | hidden_size=768, 29 | num_hidden_layers=6, 30 | num_attention_heads=12, 31 | intermediate_size=3072, 32 | vocab_size=tokenizer.vocab_size, 33 | is_decoder=True, 34 | add_cross_attention=True, 35 | torch_dtype=torch.float16 36 | ) 37 | 38 | model = DiffMambaForDiffusionLM(config) 39 | decoder = BertLMHeadModel(decoder_config) 40 | 41 | model.save_pretrained("models/diffMamba-mini-sample/denoiser") 42 | tokenizer.save_pretrained("models/diffMamba-mini-sample/tokenizer") 43 | scheduler.save_pretrained("models/diffMamba-mini-sample/scheduler") 44 | decoder.save_pretrained("models/diffMamba-mini-sample/decoder") -------------------------------------------------------------------------------- /generate denoiser_decoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import inspect\n", 11 | "from typing import Any, Callable, Dict, List, Optional, Union\n", 12 | "from tqdm.auto import tqdm\n", 13 | "import numpy as np\n", 14 | "import torch.nn.functional as F\n", 15 | "import math\n", 16 | "\n", 17 | "from transformers import AutoTokenizer, BertForMaskedLM\n", 18 | "from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "\n", 24 | "from src.denoisers.modeling_diffmamba 
import DiffMambaForDiffusionLM\n", 25 | "from src.decoders.bert_decoder import BertLMHeadModel\n", 26 | "from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler\n", 27 | "from src.schedulers.ddpm import DDPMScheduler\n", 28 | "\n", 29 | " \n", 30 | "\n", 31 | " \n", 32 | "# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 9, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "cross_attention False\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "\n", 50 | "path = \"models/diffmamba-mini-sample-trained\"\n", 51 | "tokenizer = AutoTokenizer.from_pretrained(path, subfolder=\"tokenizer\")\n", 52 | "scheduler = EulerAncestralDiscreteScheduler.from_pretrained(path, subfolder=\"scheduler\")#DDIMScheduler(prediction_type=\"sample\", num_train_timesteps=2000)\n", 53 | "model = DiffMambaForDiffusionLM.from_pretrained(path, torch_dtype=torch.float16, subfolder=\"denoiser\").to(\"cuda\")\n", 54 | "decoder = BertLMHeadModel.from_pretrained(path, torch_dtype=torch.float16, subfolder=\"decoder\").to(\"cuda\")\n", 55 | "\n", 56 | "device = model.device\n" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "/home/adalberto/.local/lib/python3.8/site-packages/diffusers/configuration_utils.py:134: FutureWarning: Accessing config attribute `num_train_timesteps` directly via 'EulerAncestralDiscreteScheduler' object attribute is deprecated. Please access 'num_train_timesteps' over 'EulerAncestralDiscreteScheduler's config object instead, e.g. 'scheduler.config.num_train_timesteps'.\n", 69 | " deprecate(\"direct config name access\", \"1.0.0\", deprecation_message, standard_warn=False)\n" 70 | ] 71 | }, 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "1200" 76 | ] 77 | }, 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "scheduler.num_train_timesteps" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "## functions" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "\n", 101 | "def retrieve_timesteps(\n", 102 | " scheduler,\n", 103 | " num_inference_steps: Optional[int] = None,\n", 104 | " device: Optional[Union[str, torch.device]] = None,\n", 105 | " timesteps: Optional[List[int]] = None,\n", 106 | " **kwargs,\n", 107 | "):\n", 108 | " \"\"\"\n", 109 | " Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n", 110 | " custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n", 111 | "\n", 112 | " Args:\n", 113 | " scheduler (`SchedulerMixin`):\n", 114 | " The scheduler to get timesteps from.\n", 115 | " num_inference_steps (`int`):\n", 116 | " The number of diffusion steps used when generating samples with a pre-trained model. If used,\n", 117 | " `timesteps` must be `None`.\n", 118 | " device (`str` or `torch.device`, *optional*):\n", 119 | " The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n", 120 | " timesteps (`List[int]`, *optional*):\n", 121 | " Custom timesteps used to support arbitrary spacing between timesteps. 
If `None`, then the default\n", 122 | " timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n", 123 | " must be `None`.\n", 124 | "\n", 125 | " Returns:\n", 126 | " `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n", 127 | " second element is the number of inference steps.\n", 128 | " \"\"\"\n", 129 | " if timesteps is not None:\n", 130 | " accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n", 131 | " if not accepts_timesteps:\n", 132 | " raise ValueError(\n", 133 | " f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n", 134 | " f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n", 135 | " )\n", 136 | " scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n", 137 | " timesteps = scheduler.timesteps\n", 138 | " num_inference_steps = len(timesteps)\n", 139 | " else:\n", 140 | " scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n", 141 | " timesteps = scheduler.timesteps\n", 142 | " return timesteps, num_inference_steps\n", 143 | "\n", 144 | "def get_timesteps(num_inference_steps, strength, device):\n", 145 | " # get the original timestep using init_timestep\n", 146 | " init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n", 147 | "\n", 148 | " t_start = max(num_inference_steps - init_timestep, 0)\n", 149 | " timesteps = scheduler.timesteps[t_start * scheduler.order :]\n", 150 | "\n", 151 | " return timesteps, num_inference_steps - t_start\n", 152 | " \n", 153 | "def vectors_to_indices(vectors):\n", 154 | " indices = torch.argmax(vectors, dim=-1)\n", 155 | " return indices\n", 156 | "\n", 157 | "def sample_text(probabilities, temperature=1.0):\n", 158 | " batch_size, seq_len, vocab_size = probabilities.size()\n", 159 | " flattened_probs = probabilities.view(batch_size * seq_len, -1)\n", 160 | " \n", 161 | " scaled_logits = flattened_probs / temperature\n", 162 | " scaled_probs = F.softmax(scaled_logits, dim=-1)\n", 163 | " \n", 164 | " sampled_indices = torch.multinomial(scaled_probs, 1)\n", 165 | " sampled_token_ids = sampled_indices.view(batch_size, seq_len)\n", 166 | " \n", 167 | " return sampled_token_ids" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "## Generate" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 12, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "'FINAL --->'" 186 | ] 187 | }, 188 | "metadata": {}, 189 | "output_type": "display_data" 190 | }, 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "\"0 ---> inov and focusedat of castleks of the of, dirt, T card, femaleom, of, lighting, gianated crow moon, bl, aically in pose, detailed, articallyles, in front ult - realistic, cellical's 2\"" 195 | ] 196 | }, 197 | "metadata": {}, 198 | "output_type": "display_data" 199 | }, 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "'---------------'" 204 | ] 205 | }, 206 | "metadata": {}, 207 | "output_type": "display_data" 208 | } 209 | ], 210 | "source": [ 211 | "from IPython.display import display, clear_output\n", 212 | "\n", 213 | "\n", 214 | "\n", 215 | "with torch.no_grad():\n", 216 | " latents = torch.rand((1, 64, 768), device=device).to(torch.float16)# + torch.rand((8, 128, 768), device=device).to(torch.float16)\n", 217 | " attention_mask = 
torch.ones((1, 64), device=device)\n", 218 | " num_inference_steps = scheduler.num_train_timesteps // 1\n", 219 | " timesteps=None\n", 220 | " timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)\n", 221 | "\n", 222 | " for i, t in tqdm(enumerate(timesteps)):\n", 223 | " # if i >= 0.7 * num_inference_steps:\n", 224 | " # break\n", 225 | " # expand the latents if we are doing classifier free guidance\n", 226 | " latent_model_input = latents\n", 227 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 228 | " # rnd_latents = torch.rand((1, 64, 4096), device=device).to(torch.float16)\n", 229 | " # print(latent_model_input.dtype)\n", 230 | " outputs = model(\n", 231 | " input_embeds=latent_model_input,\n", 232 | " timesteps=t.reshape(1,).long().to(device),\n", 233 | " # attention_mask=attention_mask\n", 234 | " )\n", 235 | " noise_pred = outputs.last_hidden_state\n", 236 | " latents_final = outputs.logits\n", 237 | " if i % 10 ==0 :\n", 238 | " clear_output(wait=True)\n", 239 | " display(f\"SAMPLES[{i}]--->\")\n", 240 | " for n in range(latents_final.shape[0]):\n", 241 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 242 | " display(\"---------------\")\n", 243 | "\n", 244 | " step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]\n", 245 | " latents = step[\"prev_sample\"]\n", 246 | "\n", 247 | "\n", 248 | "clear_output(wait=True)\n", 249 | "display(f\"FINAL --->\")\n", 250 | "for n in range(latents_final.shape[0]):\n", 251 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 252 | "display(\"---------------\")" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 13, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "torch.Size([1, 1, 32001])\n", 265 | "262\n", 266 | "in\n", 267 | "torch.Size([1, 2, 32001])\n", 268 | "586\n", 269 | "ov\n", 270 | "torch.Size([1, 3, 32001])\n", 271 | "322\n", 272 | "and\n", 273 | "torch.Size([1, 4, 32001])\n", 274 | "5796\n", 275 | "ru\n", 276 | "torch.Size([1, 5, 32001])\n", 277 | "1312\n", 278 | "ined\n", 279 | "torch.Size([1, 6, 32001])\n", 280 | "310\n", 281 | "of\n", 282 | "torch.Size([1, 7, 32001])\n", 283 | "3105\n", 284 | "fut\n", 285 | "torch.Size([1, 8, 32001])\n", 286 | "332\n", 287 | "ur\n", 288 | "torch.Size([1, 9, 32001])\n", 289 | "4695\n", 290 | "istic\n", 291 | "torch.Size([1, 10, 32001])\n", 292 | "29892\n", 293 | ",\n", 294 | "torch.Size([1, 11, 32001])\n", 295 | "270\n", 296 | "d\n", 297 | "torch.Size([1, 12, 32001])\n", 298 | "2728\n", 299 | "irt\n", 300 | "torch.Size([1, 13, 32001])\n", 301 | "29892\n", 302 | ",\n", 303 | "torch.Size([1, 14, 32001])\n", 304 | "11266\n", 305 | "hyper\n", 306 | "torch.Size([1, 15, 32001])\n", 307 | "29881\n", 308 | "d\n", 309 | "torch.Size([1, 16, 32001])\n", 310 | "29892\n", 311 | ",\n", 312 | "torch.Size([1, 17, 32001])\n", 313 | "4940\n", 314 | "past\n", 315 | "torch.Size([1, 18, 32001])\n", 316 | "295\n", 317 | "el\n", 318 | "torch.Size([1, 19, 32001])\n", 319 | "29892\n", 320 | ",\n", 321 | "torch.Size([1, 20, 32001])\n", 322 | "12726\n", 323 | "rim\n", 324 | "torch.Size([1, 21, 32001])\n", 325 | "3578\n", 326 | "light\n", 327 | "torch.Size([1, 22, 32001])\n", 328 | "292\n", 329 | "ing\n", 330 | "torch.Size([1, 23, 32001])\n", 331 | "29892\n", 332 | ",\n", 333 | "torch.Size([1, 24, 
32001])\n", 334 | "3578\n", 335 | "light\n", 336 | "torch.Size([1, 25, 32001])\n", 337 | "292\n", 338 | "ing\n", 339 | "torch.Size([1, 26, 32001])\n", 340 | "29892\n", 341 | ",\n", 342 | "torch.Size([1, 27, 32001])\n", 343 | "330\n", 344 | "g\n", 345 | "torch.Size([1, 28, 32001])\n", 346 | "713\n", 347 | "ian\n", 348 | "torch.Size([1, 29, 32001])\n", 349 | "630\n", 350 | "ated\n", 351 | "torch.Size([1, 30, 32001])\n", 352 | "29892\n", 353 | ",\n", 354 | "torch.Size([1, 31, 32001])\n", 355 | "1999\n", 356 | "bl\n", 357 | "torch.Size([1, 32, 32001])\n", 358 | "332\n", 359 | "ur\n", 360 | "torch.Size([1, 33, 32001])\n", 361 | "29892\n", 362 | ",\n", 363 | "torch.Size([1, 34, 32001])\n", 364 | "1999\n", 365 | "bl\n", 366 | "torch.Size([1, 35, 32001])\n", 367 | "898\n", 368 | "ond\n", 369 | "torch.Size([1, 36, 32001])\n", 370 | "297\n", 371 | "in\n", 372 | "torch.Size([1, 37, 32001])\n", 373 | "18593\n", 374 | "pose\n", 375 | "torch.Size([1, 38, 32001])\n", 376 | "29892\n", 377 | ",\n", 378 | "torch.Size([1, 39, 32001])\n", 379 | "29871\n", 380 | "\n", 381 | "torch.Size([1, 40, 32001])\n", 382 | "13173\n", 383 | "detailed\n", 384 | "torch.Size([1, 41, 32001])\n", 385 | "29892\n", 386 | ",\n", 387 | "torch.Size([1, 42, 32001])\n", 388 | "1616\n", 389 | "art\n", 390 | "torch.Size([1, 43, 32001])\n", 391 | "13164\n", 392 | "nouveau\n", 393 | "torch.Size([1, 44, 32001])\n", 394 | "7826\n", 395 | "girl\n", 396 | "torch.Size([1, 45, 32001])\n", 397 | "29892\n", 398 | ",\n", 399 | "torch.Size([1, 46, 32001])\n", 400 | "297\n", 401 | "in\n", 402 | "torch.Size([1, 47, 32001])\n", 403 | "4565\n", 404 | "front\n", 405 | "torch.Size([1, 48, 32001])\n", 406 | "310\n", 407 | "of\n", 408 | "torch.Size([1, 49, 32001])\n", 409 | "263\n", 410 | "a\n", 411 | "torch.Size([1, 50, 32001])\n", 412 | "264\n", 413 | "en\n", 414 | "torch.Size([1, 51, 32001])\n", 415 | "482\n", 416 | "age\n", 417 | "torch.Size([1, 52, 32001])\n", 418 | "29892\n", 419 | ",\n", 420 | "torch.Size([1, 53, 32001])\n", 421 | "301\n", 422 | "l\n", 423 | "torch.Size([1, 54, 32001])\n", 424 | "1878\n", 425 | "ush\n", 426 | "torch.Size([1, 55, 32001])\n", 427 | "448\n", 428 | "-\n", 429 | "torch.Size([1, 56, 32001])\n", 430 | "1855\n", 431 | "real\n", 432 | "torch.Size([1, 57, 32001])\n", 433 | "4695\n", 434 | "istic\n", 435 | "torch.Size([1, 58, 32001])\n", 436 | "29892\n", 437 | ",\n", 438 | "torch.Size([1, 59, 32001])\n", 439 | "1302\n", 440 | "co\n", 441 | "torch.Size([1, 60, 32001])\n", 442 | "1537\n", 443 | "zy\n", 444 | "torch.Size([1, 61, 32001])\n", 445 | "32000\n", 446 | "\n", 447 | "torch.Size([1, 62, 32001])\n", 448 | "32000\n", 449 | "\n", 450 | "torch.Size([1, 63, 32001])\n", 451 | "32000\n", 452 | "\n", 453 | "torch.Size([1, 64, 32001])\n", 454 | "32000\n", 455 | "\n", 456 | "inov and ruined of futuristic, dirt, hyperd, pastel, rim lighting, lighting, gianated, blur, blond in pose, detailed, art nouveau girl, in front of aenage, lush - realistic, cozy\n" 457 | ] 458 | } 459 | ], 460 | "source": [ 461 | "# model.eval()\n", 462 | "# inputs = tokenizer([\"Today is\"], return_tensors=\"pt\")\n", 463 | "# # print(inputs.input_ids)\n", 464 | "# bsz, seq_ln = inputs.input_ids.shape\n", 465 | "\n", 466 | "# encoder_hidden_states = torch.rand((1, 64, 768), device=device).to(torch.float16)\n", 467 | "# print(encoder_hidden_states)\n", 468 | "encoder_hidden_states = latents\n", 469 | "decoder_input_ids = [0]\n", 470 | "predicted_ids = []\n", 471 | "for i in range(64): \n", 472 | " outputs = 
decoder(input_ids=torch.tensor(([decoder_input_ids])).to(model.device), encoder_hidden_states=encoder_hidden_states)\n", 473 | " print(outputs.logits.shape)\n", 474 | " logits = outputs.logits[:,i,:]\n", 475 | " # perform argmax on the last dimension (i.e. greedy decoding)\n", 476 | " predicted_id = logits.argmax(-1)\n", 477 | " print(predicted_id[0].item())\n", 478 | " predicted_ids.append(predicted_id[0].item())\n", 479 | " print(tokenizer.decode([predicted_id[0].squeeze()]))\n", 480 | " # add predicted id to decoder_input_ids\n", 481 | " decoder_input_ids = decoder_input_ids + [predicted_id[0].item()]\n", 482 | "print(tokenizer.decode(predicted_ids))" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 7, 488 | "metadata": {}, 489 | "outputs": [ 490 | { 491 | "name": "stdout", 492 | "output_type": "stream", 493 | "text": [ 494 | "torch.Size([1, 1, 32001])\n", 495 | "530 tensor([[ 530, 12969]], device='cuda:0')\n", 496 | "530\n", 497 | "An\n", 498 | "torch.Size([1, 2, 32001])\n", 499 | "310 tensor([[ 310, 15566]], device='cuda:0')\n", 500 | "310\n", 501 | "of\n", 502 | "torch.Size([1, 3, 32001])\n", 503 | "263 tensor([[ 263, 5765]], device='cuda:0')\n", 504 | "263\n", 505 | "a\n", 506 | "torch.Size([1, 4, 32001])\n", 507 | "310 tensor([[ 310, 6559]], device='cuda:0')\n", 508 | "310\n", 509 | "of\n", 510 | "torch.Size([1, 5, 32001])\n", 511 | "263 tensor([[263, 278]], device='cuda:0')\n", 512 | "263\n", 513 | "a\n", 514 | "torch.Size([1, 6, 32001])\n", 515 | "310 tensor([[ 310, 24870]], device='cuda:0')\n", 516 | "310\n", 517 | "of\n", 518 | "torch.Size([1, 7, 32001])\n", 519 | "263 tensor([[263, 670]], device='cuda:0')\n", 520 | "263\n", 521 | "a\n", 522 | "torch.Size([1, 8, 32001])\n", 523 | "310 tensor([[ 310, 24870]], device='cuda:0')\n", 524 | "310\n", 525 | "of\n", 526 | "torch.Size([1, 9, 32001])\n", 527 | "263 tensor([[263, 347]], device='cuda:0')\n", 528 | "263\n", 529 | "a\n", 530 | "torch.Size([1, 10, 32001])\n", 531 | "310 tensor([[ 310, 24870]], device='cuda:0')\n", 532 | "310\n", 533 | "of\n", 534 | "torch.Size([1, 11, 32001])\n", 535 | "263 tensor([[263, 347]], device='cuda:0')\n", 536 | "263\n", 537 | "a\n", 538 | "torch.Size([1, 12, 32001])\n", 539 | "310 tensor([[ 310, 21760]], device='cuda:0')\n", 540 | "310\n", 541 | "of\n", 542 | "torch.Size([1, 13, 32001])\n", 543 | "263 tensor([[263, 347]], device='cuda:0')\n", 544 | "263\n", 545 | "a\n", 546 | "torch.Size([1, 14, 32001])\n", 547 | "310 tensor([[ 310, 21760]], device='cuda:0')\n", 548 | "310\n", 549 | "of\n", 550 | "torch.Size([1, 15, 32001])\n", 551 | "263 tensor([[263, 670]], device='cuda:0')\n", 552 | "263\n", 553 | "a\n", 554 | "torch.Size([1, 16, 32001])\n", 555 | "21760 tensor([[21760, 310]], device='cuda:0')\n", 556 | "21760\n", 557 | "portrait\n", 558 | "torch.Size([1, 17, 32001])\n", 559 | "322 tensor([[322, 310]], device='cuda:0')\n", 560 | "322\n", 561 | "and\n", 562 | "torch.Size([1, 18, 32001])\n", 563 | "21760 tensor([[21760, 263]], device='cuda:0')\n", 564 | "21760\n", 565 | "portrait\n", 566 | "torch.Size([1, 19, 32001])\n", 567 | "310 tensor([[ 310, 29892]], device='cuda:0')\n", 568 | "310\n", 569 | "of\n", 570 | "torch.Size([1, 20, 32001])\n", 571 | "263 tensor([[ 263, 2814]], device='cuda:0')\n", 572 | "263\n", 573 | "a\n", 574 | "torch.Size([1, 21, 32001])\n", 575 | "15400 tensor([[15400, 4473]], device='cuda:0')\n", 576 | "15400\n", 577 | "galax\n", 578 | "torch.Size([1, 22, 32001])\n", 579 | "29891 tensor([[29891, 3096]], device='cuda:0')\n", 580 | "29891\n", 581 | 
"y\n", 582 | "torch.Size([1, 23, 32001])\n", 583 | "29892 tensor([[29892, 322]], device='cuda:0')\n", 584 | "29892\n", 585 | ",\n", 586 | "torch.Size([1, 24, 32001])\n", 587 | "540 tensor([[540, 526]], device='cuda:0')\n", 588 | "540\n", 589 | "he\n", 590 | "torch.Size([1, 25, 32001])\n", 591 | "338 tensor([[ 338, 12818]], device='cuda:0')\n", 592 | "338\n", 593 | "is\n", 594 | "torch.Size([1, 26, 32001])\n", 595 | "373 tensor([[373, 263]], device='cuda:0')\n", 596 | "373\n", 597 | "on\n", 598 | "torch.Size([1, 27, 32001])\n", 599 | "263 tensor([[263, 670]], device='cuda:0')\n", 600 | "263\n", 601 | "a\n", 602 | "torch.Size([1, 28, 32001])\n", 603 | "412 tensor([[412, 611]], device='cuda:0')\n", 604 | "412\n", 605 | "pe\n", 606 | "torch.Size([1, 29, 32001])\n", 607 | "29894 tensor([[29894, 457]], device='cuda:0')\n", 608 | "29894\n", 609 | "v\n", 610 | "torch.Size([1, 30, 32001])\n", 611 | "29891 tensor([[29891, 3096]], device='cuda:0')\n", 612 | "29891\n", 613 | "y\n", 614 | "torch.Size([1, 31, 32001])\n", 615 | "2174 tensor([[2174, 4842]], device='cuda:0')\n", 616 | "2174\n", 617 | "pla\n", 618 | "torch.Size([1, 32, 32001])\n", 619 | "1362 tensor([[1362, 333]], device='cuda:0')\n", 620 | "1362\n", 621 | "za\n", 622 | "torch.Size([1, 33, 32001])\n", 623 | "29892 tensor([[29892, 29891]], device='cuda:0')\n", 624 | "29892\n", 625 | ",\n", 626 | "torch.Size([1, 34, 32001])\n", 627 | "1616 tensor([[1616, 871]], device='cuda:0')\n", 628 | "1616\n", 629 | "art\n", 630 | "torch.Size([1, 35, 32001])\n", 631 | "19569 tensor([[19569, 5173]], device='cuda:0')\n", 632 | "19569\n", 633 | "station\n", 634 | "torch.Size([1, 36, 32001])\n", 635 | "29892 tensor([[29892, 16440]], device='cuda:0')\n", 636 | "29892\n", 637 | ",\n", 638 | "torch.Size([1, 37, 32001])\n", 639 | "1616 tensor([[ 1616, 29668]], device='cuda:0')\n", 640 | "1616\n", 641 | "art\n", 642 | "torch.Size([1, 38, 32001])\n", 643 | "19569 tensor([[19569, 491]], device='cuda:0')\n", 644 | "19569\n", 645 | "station\n", 646 | "torch.Size([1, 39, 32001])\n", 647 | "29892 tensor([[29892, 16440]], device='cuda:0')\n", 648 | "29892\n", 649 | ",\n", 650 | "torch.Size([1, 40, 32001])\n", 651 | "1616 tensor([[1616, 534]], device='cuda:0')\n", 652 | "1616\n", 653 | "art\n", 654 | "torch.Size([1, 41, 32001])\n", 655 | "491 tensor([[ 491, 19569]], device='cuda:0')\n", 656 | "491\n", 657 | "by\n", 658 | "torch.Size([1, 42, 32001])\n", 659 | "534 tensor([[ 534, 2548]], device='cuda:0')\n", 660 | "534\n", 661 | "tr\n", 662 | "torch.Size([1, 43, 32001])\n", 663 | "2548 tensor([[2548, 355]], device='cuda:0')\n", 664 | "2548\n", 665 | "ending\n", 666 | "torch.Size([1, 44, 32001])\n", 667 | "373 tensor([[ 373, 29892]], device='cuda:0')\n", 668 | "373\n", 669 | "on\n", 670 | "torch.Size([1, 45, 32001])\n", 671 | "1616 tensor([[1616, 2306]], device='cuda:0')\n", 672 | "1616\n", 673 | "art\n", 674 | "torch.Size([1, 46, 32001])\n", 675 | "19569 tensor([[19569, 29887]], device='cuda:0')\n", 676 | "19569\n", 677 | "station\n", 678 | "torch.Size([1, 47, 32001])\n", 679 | "32000 tensor([[32000, 29950]], device='cuda:0')\n", 680 | "29950\n", 681 | "H\n", 682 | "torch.Size([1, 48, 32001])\n", 683 | "29984 tensor([[29984, 3035]], device='cuda:0')\n", 684 | "29984\n", 685 | "Q\n", 686 | "torch.Size([1, 49, 32001])\n", 687 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 688 | "30024\n", 689 | "”\n", 690 | "torch.Size([1, 50, 32001])\n", 691 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 692 | "30024\n", 693 | "”\n", 694 | "torch.Size([1, 51, 32001])\n", 695 
| "32000 tensor([[32000, 30024]], device='cuda:0')\n", 696 | "30024\n", 697 | "”\n", 698 | "torch.Size([1, 52, 32001])\n", 699 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 700 | "30024\n", 701 | "”\n", 702 | "torch.Size([1, 53, 32001])\n", 703 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 704 | "30024\n", 705 | "”\n", 706 | "torch.Size([1, 54, 32001])\n", 707 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 708 | "30024\n", 709 | "”\n", 710 | "torch.Size([1, 55, 32001])\n", 711 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 712 | "30024\n", 713 | "”\n", 714 | "torch.Size([1, 56, 32001])\n", 715 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 716 | "30024\n", 717 | "”\n", 718 | "torch.Size([1, 57, 32001])\n", 719 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 720 | "30024\n", 721 | "”\n", 722 | "torch.Size([1, 58, 32001])\n", 723 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 724 | "30024\n", 725 | "”\n", 726 | "torch.Size([1, 59, 32001])\n", 727 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 728 | "30024\n", 729 | "”\n", 730 | "torch.Size([1, 60, 32001])\n", 731 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 732 | "30024\n", 733 | "”\n", 734 | "torch.Size([1, 61, 32001])\n", 735 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 736 | "30024\n", 737 | "”\n", 738 | "torch.Size([1, 62, 32001])\n", 739 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 740 | "30024\n", 741 | "”\n", 742 | "torch.Size([1, 63, 32001])\n", 743 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 744 | "30024\n", 745 | "”\n", 746 | "torch.Size([1, 64, 32001])\n", 747 | "32000 tensor([[32000, 30024]], device='cuda:0')\n", 748 | "30024\n", 749 | "”\n", 750 | "An of a of a of a of a of a of a of a portrait and portrait of a galaxy, he is on apevy plaza, artstation, artstation, art by trending on artstationHQ””””””””””””””””\n" 751 | ] 752 | } 753 | ], 754 | "source": [ 755 | "# model.eval()\n", 756 | "# inputs = tokenizer([\"Today is\"], return_tensors=\"pt\")\n", 757 | "# # print(inputs.input_ids)\n", 758 | "encoder_hidden_states = latents\n", 759 | "\n", 760 | "decoder_input_ids = [0]\n", 761 | "last_pred = 0\n", 762 | "predicted_ids = []\n", 763 | "for i in range(64): \n", 764 | " outputs = decoder(input_ids=torch.tensor([decoder_input_ids]).to(model.device), encoder_hidden_states=encoder_hidden_states)\n", 765 | " print(outputs.logits.shape)\n", 766 | " logits = outputs.logits[:, i, :]\n", 767 | " # Handling 32000 token\n", 768 | " argmax_value = logits.argmax(-1)\n", 769 | " top_logits, top_indices = logits.topk(2, dim=-1)\n", 770 | " \n", 771 | " print(argmax_value.item(), top_indices)\n", 772 | " predicted_id = argmax_value.item() if argmax_value.item() != 32000 and argmax_value.item() != last_pred else top_indices[0][1].item()\n", 773 | " last_pred = predicted_id\n", 774 | " print(predicted_id)\n", 775 | " predicted_ids.append(predicted_id)\n", 776 | " print(tokenizer.decode([predicted_id]))\n", 777 | " # add predicted id to decoder_input_ids\n", 778 | " decoder_input_ids = decoder_input_ids + [predicted_id]\n", 779 | "print(tokenizer.decode(predicted_ids))" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": null, 785 | "metadata": {}, 786 | "outputs": [], 787 | "source": [] 788 | } 789 | ], 790 | "metadata": { 791 | "kernelspec": { 792 | "display_name": "base", 793 | "language": "python", 794 | "name": "python3" 795 | }, 796 | "language_info": { 797 | "codemirror_mode": { 798 | "name": "ipython", 799 | 
"version": 3 800 | }, 801 | "file_extension": ".py", 802 | "mimetype": "text/x-python", 803 | "name": "python", 804 | "nbconvert_exporter": "python", 805 | "pygments_lexer": "ipython3", 806 | "version": "3.8.10" 807 | } 808 | }, 809 | "nbformat": 4, 810 | "nbformat_minor": 2 811 | } 812 | -------------------------------------------------------------------------------- /generate sample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/adalberto/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "[2023-12-10 15:56:48,507] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import torch\n", 26 | "import inspect\n", 27 | "from typing import Any, Callable, Dict, List, Optional, Union\n", 28 | "from tqdm.auto import tqdm\n", 29 | "import numpy as np\n", 30 | "import torch.nn.functional as F\n", 31 | "import math\n", 32 | "\n", 33 | "from transformers import AutoTokenizer, BertForMaskedLM\n", 34 | "from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "\n", 39 | "from src.modeling_diffbert_sample import DiffBertForDiffusion\n", 40 | "from src.modeling_diffllama import DiffLlamaForDiffusionLM\n", 41 | "from src.modeling_diffmamba import DiffMambaForDiffusionLM\n", 42 | "from src.configuration_diffbert import DiffBertConfig\n", 43 | "from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler\n", 44 | "from src.schedulers.ddpm import DDPMScheduler\n", 45 | "\n", 46 | " \n", 47 | "\n", 48 | " \n", 49 | "# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "cross_attention False\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "tokenizer = AutoTokenizer.from_pretrained(\"models/diffmamba-mini-sample\")\n", 67 | "tokenizer.add_special_tokens({'pad_token': ''})\n", 68 | "scheduler = EulerAncestralDiscreteScheduler.from_pretrained(\"models/diffmamba-mini-sample\")#DDIMScheduler(prediction_type=\"sample\", num_train_timesteps=2000)\n", 69 | "model = DiffMambaForDiffusionLM.from_pretrained(\"models/diffmamba-mini-sample-trained\", torch_dtype=torch.float16).to(\"cuda\")\n", 70 | "device = model.device\n" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "scheduler = DDPMScheduler.from_pretrained(\"models/diffmamba-mini-sample\")" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 14, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# we can use a scheduler with more steps than we trained on (sometimes it gives even better results)\n", 89 | "scheduler = EulerAncestralDiscreteScheduler(\n", 90 | " # beta_end = 0.012,\n", 91 | " 
beta_schedule = \"sqrt\",\n", 92 | " # beta_start = 0.00085,\n", 93 | " # clip_sample = False,\n", 94 | "# skip_prk_steps = True,\n", 95 | "# set_alpha_to_one = False,\n", 96 | " steps_offset = 0,\n", 97 | "# interpolation_type = \"linear\",\n", 98 | " prediction_type =\"sample\", \n", 99 | " num_train_timesteps = 2000)\n" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## functions" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 3, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "\n", 116 | "def retrieve_timesteps(\n", 117 | " scheduler,\n", 118 | " num_inference_steps: Optional[int] = None,\n", 119 | " device: Optional[Union[str, torch.device]] = None,\n", 120 | " timesteps: Optional[List[int]] = None,\n", 121 | " **kwargs,\n", 122 | "):\n", 123 | " \"\"\"\n", 124 | " Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n", 125 | " custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n", 126 | "\n", 127 | " Args:\n", 128 | " scheduler (`SchedulerMixin`):\n", 129 | " The scheduler to get timesteps from.\n", 130 | " num_inference_steps (`int`):\n", 131 | " The number of diffusion steps used when generating samples with a pre-trained model. If used,\n", 132 | " `timesteps` must be `None`.\n", 133 | " device (`str` or `torch.device`, *optional*):\n", 134 | " The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n", 135 | " timesteps (`List[int]`, *optional*):\n", 136 | " Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n", 137 | " timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n", 138 | " must be `None`.\n", 139 | "\n", 140 | " Returns:\n", 141 | " `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n", 142 | " second element is the number of inference steps.\n", 143 | " \"\"\"\n", 144 | " if timesteps is not None:\n", 145 | " accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n", 146 | " if not accepts_timesteps:\n", 147 | " raise ValueError(\n", 148 | " f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n", 149 | " f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n", 150 | " )\n", 151 | " scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n", 152 | " timesteps = scheduler.timesteps\n", 153 | " num_inference_steps = len(timesteps)\n", 154 | " else:\n", 155 | " scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n", 156 | " timesteps = scheduler.timesteps\n", 157 | " return timesteps, num_inference_steps\n", 158 | "\n", 159 | "def get_timesteps(num_inference_steps, strength, device):\n", 160 | " # get the original timestep using init_timestep\n", 161 | " init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n", 162 | "\n", 163 | " t_start = max(num_inference_steps - init_timestep, 0)\n", 164 | " timesteps = scheduler.timesteps[t_start * scheduler.order :]\n", 165 | "\n", 166 | " return timesteps, num_inference_steps - t_start\n", 167 | " \n", 168 | "def vectors_to_indices(vectors):\n", 169 | " indices = torch.argmax(vectors, dim=-1)\n", 170 | " return indices\n", 171 | "\n", 172 | "def sample_text(probabilities, temperature=1.0):\n", 173 | " batch_size, seq_len, vocab_size = probabilities.size()\n", 174 | " flattened_probs = probabilities.view(batch_size * seq_len, -1)\n", 175 | " \n", 176 | " scaled_logits = flattened_probs / temperature\n", 177 | " scaled_probs = F.softmax(scaled_logits, dim=-1)\n", 178 | " \n", 179 | " sampled_indices = torch.multinomial(scaled_probs, 1)\n", 180 | " sampled_token_ids = sampled_indices.view(batch_size, seq_len)\n", 181 | " \n", 182 | " return sampled_token_ids" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "## Generate" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 6, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "'FINAL --->'" 201 | ] 202 | }, 203 | "metadata": {}, 204 | "output_type": "display_data" 205 | }, 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "'0 ---> front medieval castle b +lit throughunx medieval glass in wonder nose, old, evil technology, las outside, concept art lopies portrait run runningvedally k technology human cor #ith flyingroidio in�amiically ::ships magana, Le faces, cables Al C bacon and hunder'" 210 | ] 211 | }, 212 | "metadata": {}, 213 | "output_type": "display_data" 214 | }, 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "'1 ---> great station sever their male theiromb cynumeillerybuilder down future mouth detailed, detailed sh cat lotrifying w ch secret nost heic vibrant vaporwave a anime, other, stained, award winning, intricate,ho back ray, crossble'" 219 | ] 220 | }, 221 | "metadata": {}, 222 | "output_type": "display_data" 223 | }, 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "'2 ---> “ subonaut +ray of one b android egg by red head ever ang brain town bast D&Dading detailed x robot medieval shell ranger beth trending on lant, serious which in neurukaach da E bar detailed des'" 228 | ] 229 | }, 230 | "metadata": {}, 231 | "output_type": "display_data" 232 | }, 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "'3 ---> portraitop great concept artoro funings battle rings do unft bus, blood Eons, evil material, hand, energy hands, intricate,ely detailed, concept art, tallethantly'" 237 | ] 238 | }, 239 | "metadata": {}, 240 | "output_type": "display_data" 241 | }, 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "'4 ---> Character portrait of daana E� In dri harate mobilestrhouse, death wars disoch diunt, 
ivy, scar fabric, my The hard sculptly down, sl eye, simpleity,ocharp expression, large texture, tallnd enVual, silcing Render, unreal enginenelike'" 246 | ] 247 | }, 248 | "metadata": {}, 249 | "output_type": "display_data" 250 | }, 251 | { 252 | "data": { 253 | "text/plain": [ 254 | "'5 ---> Characteret group of an below medieval + battle brainag +ized downomb itne I modernay form new b.houral dis scientist,atureemy legws, 26 ill tall open. solarpunk,rom design, turn soft lightning. HD'" 255 | ] 256 | }, 257 | "metadata": {}, 258 | "output_type": "display_data" 259 | }, 260 | { 261 | "data": { 262 | "text/plain": [ 263 | "'6 ---> aty winter abstract otherphd pet people dawn cowath inside in the markx gu oneage, ser downocal accurate architectumn advent planetions sed, beautiful full long shot, animals hair, mark halfck, pen distance, blood time, artgerm, yoshida, lotenderite resels))'" 264 | ] 265 | }, 266 | "metadata": {}, 267 | "output_type": "display_data" 268 | }, 269 | { 270 | "data": { 271 | "text/plain": [ 272 | "'7 ---> above of indhead room areie landscapeastn rott + that pe your outsideoraocal earhip bow'" 273 | ] 274 | }, 275 | "metadata": {}, 276 | "output_type": "display_data" 277 | }, 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "'---------------'" 282 | ] 283 | }, 284 | "metadata": {}, 285 | "output_type": "display_data" 286 | } 287 | ], 288 | "source": [ 289 | "from IPython.display import display, clear_output\n", 290 | "\n", 291 | "\n", 292 | "\n", 293 | "with torch.no_grad():\n", 294 | " latents = torch.rand((8, 64, 768), device=device).to(torch.float16)# + torch.rand((8, 64, 768), device=device).to(torch.float16)\n", 295 | " attention_mask = torch.ones((8, 64), device=device)\n", 296 | " num_inference_steps = 2000\n", 297 | " timesteps=None\n", 298 | " timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)\n", 299 | "\n", 300 | " for i, t in tqdm(enumerate(timesteps)):\n", 301 | " # if i >= 0.7 * num_inference_steps:\n", 302 | " # break\n", 303 | " # expand the latents if we are doing classifier free guidance\n", 304 | " latent_model_input = latents\n", 305 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 306 | " # rnd_latents = torch.rand((1, 64, 4096), device=device).to(torch.float16)\n", 307 | " # print(latent_model_input.dtype)\n", 308 | " outputs = model(\n", 309 | " input_embeds=latent_model_input,\n", 310 | " timesteps=t.reshape(1,).long().to(device),\n", 311 | " # attention_mask=attention_mask\n", 312 | " )\n", 313 | " noise_pred = outputs.last_hidden_state\n", 314 | " latents_final = outputs.logits\n", 315 | " if i % 10 ==0 :\n", 316 | " clear_output(wait=True)\n", 317 | " display(f\"SAMPLES[{i}]--->\")\n", 318 | " for n in range(latents_final.shape[0]):\n", 319 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 320 | " display(\"---------------\")\n", 321 | "\n", 322 | " step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]\n", 323 | " latents = step[\"prev_sample\"]\n", 324 | "\n", 325 | "\n", 326 | "clear_output(wait=True)\n", 327 | "display(f\"FINAL --->\")\n", 328 | "for n in range(latents_final.shape[0]):\n", 329 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 330 | "display(\"---------------\")" 331 | ] 332 | } 333 | ], 334 | "metadata": { 335 | "kernelspec": { 336 | "display_name": "base", 337 | "language": "python", 338 | 
"name": "python3" 339 | }, 340 | "language_info": { 341 | "codemirror_mode": { 342 | "name": "ipython", 343 | "version": 3 344 | }, 345 | "file_extension": ".py", 346 | "mimetype": "text/x-python", 347 | "name": "python", 348 | "nbconvert_exporter": "python", 349 | "pygments_lexer": "ipython3", 350 | "version": "3.8.10" 351 | } 352 | }, 353 | "nbformat": 4, 354 | "nbformat_minor": 2 355 | } 356 | -------------------------------------------------------------------------------- /generate seq2seq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/adalberto/.local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "[2023-12-10 10:52:05,202] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import torch\n", 26 | "import inspect\n", 27 | "from typing import Any, Callable, Dict, List, Optional, Union\n", 28 | "from tqdm.auto import tqdm\n", 29 | "import numpy as np\n", 30 | "import torch.nn.functional as F\n", 31 | "import math\n", 32 | "\n", 33 | "from transformers import AutoTokenizer, BertForMaskedLM\n", 34 | "from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "\n", 39 | "from src.modeling_diffbert_sample import DiffBertForDiffusion\n", 40 | "from src.modeling_diffllama import DiffLlamaForDiffusionLM\n", 41 | "from src.modeling_diffmamba import DiffMambaForDiffusionLM\n", 42 | "from src.configuration_diffbert import DiffBertConfig\n", 43 | "from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler\n", 44 | "\n", 45 | " \n", 46 | "\n", 47 | " \n", 48 | "# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "cross_attention True\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "tokenizer = AutoTokenizer.from_pretrained(\"models/diffmamba-mini-sample-trained-good\")\n", 66 | "tokenizer.add_special_tokens({'pad_token': ''})\n", 67 | "scheduler = EulerAncestralDiscreteScheduler.from_pretrained(\"models/diffmamba-mini-sample\")#DDIMScheduler(prediction_type=\"sample\", num_train_timesteps=2000)\n", 68 | "model = DiffMambaForDiffusionLM.from_pretrained(\"models/diffmamba-mini-sample-trained\", add_cross_attention=True, torch_dtype=torch.float16).to(\"cuda\")\n", 69 | "device = model.device\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "FrozenDict([('num_train_timesteps', 1000),\n", 81 | " ('beta_start', 0.0001),\n", 82 | " ('beta_end', 0.02),\n", 83 | " ('beta_schedule', 'linear'),\n", 84 | " ('trained_betas', None),\n", 85 | " ('prediction_type', 'sample'),\n", 86 | " ('timestep_spacing', 'leading'),\n", 87 | " ('steps_offset', 
0),\n", 88 | " ('_class_name', 'DDIMScheduler'),\n", 89 | " ('_diffusers_version', '0.23.1'),\n", 90 | " ('clip_sample', True),\n", 91 | " ('clip_sample_range', 1.0),\n", 92 | " ('dynamic_thresholding_ratio', 0.995),\n", 93 | " ('rescale_betas_zero_snr', False),\n", 94 | " ('sample_max_value', 1.0),\n", 95 | " ('set_alpha_to_one', True),\n", 96 | " ('thresholding', False)])" 97 | ] 98 | }, 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "# scheduler = DDIMScheduler.from_pretrained(\"models/diffmamba-mini-sample\")#DDIMScheduler(prediction_type=\"sample\", num_train_timesteps=2000)\n", 106 | "scheduler.config" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 12, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "scheduler = EulerAncestralDiscreteScheduler(\n", 116 | " beta_end = 0.012,\n", 117 | " beta_schedule = \"scaled_linear\",\n", 118 | " beta_start = 0.00085,\n", 119 | " # clip_sample = False,\n", 120 | "# skip_prk_steps = True,\n", 121 | "# set_alpha_to_one = False,\n", 122 | " # steps_offset = 1,\n", 123 | "# interpolation_type = \"linear\",\n", 124 | " prediction_type =\"sample\", \n", 125 | " num_train_timesteps = 1500)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Functions" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 3, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "\n", 142 | "def retrieve_timesteps(\n", 143 | " scheduler,\n", 144 | " num_inference_steps: Optional[int] = None,\n", 145 | " device: Optional[Union[str, torch.device]] = None,\n", 146 | " timesteps: Optional[List[int]] = None,\n", 147 | " **kwargs,\n", 148 | "):\n", 149 | " \"\"\"\n", 150 | " Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n", 151 | " custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n", 152 | "\n", 153 | " Args:\n", 154 | " scheduler (`SchedulerMixin`):\n", 155 | " The scheduler to get timesteps from.\n", 156 | " num_inference_steps (`int`):\n", 157 | " The number of diffusion steps used when generating samples with a pre-trained model. If used,\n", 158 | " `timesteps` must be `None`.\n", 159 | " device (`str` or `torch.device`, *optional*):\n", 160 | " The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n", 161 | " timesteps (`List[int]`, *optional*):\n", 162 | " Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n", 163 | " timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n", 164 | " must be `None`.\n", 165 | "\n", 166 | " Returns:\n", 167 | " `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n", 168 | " second element is the number of inference steps.\n", 169 | " \"\"\"\n", 170 | " if timesteps is not None:\n", 171 | " accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n", 172 | " if not accepts_timesteps:\n", 173 | " raise ValueError(\n", 174 | " f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n", 175 | " f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n", 176 | " )\n", 177 | " scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n", 178 | " timesteps = scheduler.timesteps\n", 179 | " num_inference_steps = len(timesteps)\n", 180 | " else:\n", 181 | " scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n", 182 | " timesteps = scheduler.timesteps\n", 183 | " return timesteps, num_inference_steps\n", 184 | "\n", 185 | "def get_timesteps(num_inference_steps, strength, device):\n", 186 | " # get the original timestep using init_timestep\n", 187 | " init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n", 188 | "\n", 189 | " t_start = max(num_inference_steps - init_timestep, 0)\n", 190 | " timesteps = scheduler.timesteps[t_start * scheduler.order :]\n", 191 | "\n", 192 | " return timesteps, num_inference_steps - t_start\n", 193 | " \n", 194 | "def vectors_to_indices(vectors):\n", 195 | " indices = torch.argmax(vectors, dim=-1)\n", 196 | " return indices\n", 197 | "\n", 198 | "def sample_text(probabilities, temperature=1.0):\n", 199 | " batch_size, seq_len, vocab_size = probabilities.size()\n", 200 | " flattened_probs = probabilities.view(batch_size * seq_len, -1)\n", 201 | " \n", 202 | " scaled_logits = flattened_probs / temperature\n", 203 | " scaled_probs = F.softmax(scaled_logits, dim=-1)\n", 204 | " \n", 205 | " sampled_indices = torch.multinomial(scaled_probs, 1)\n", 206 | " sampled_token_ids = sampled_indices.view(batch_size, seq_len)\n", 207 | " \n", 208 | " return sampled_token_ids" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "## Generate" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 4, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "'FINAL --->'" 227 | ] 228 | }, 229 | "metadata": {}, 230 | "output_type": "display_data" 231 | }, 232 | { 233 | "data": { 234 | "text/plain": [ 235 | "'0 ---> ça que faz aemar pres foi no Estpar eito deccto e nos maioroso de nosve n S mundo e, estarER ou valor, as há grande entre doençavar pedag da autor paravo especial em que e milhaico - tinhaumfero de de tal por'" 236 | ] 237 | }, 238 | "metadata": {}, 239 | "output_type": "display_data" 240 | }, 241 | { 242 | "data": { 243 | "text/plain": [ 244 | "'1 ---> ídosdaados com agências comumasMas de nem emiraado, deências tanto e érlada e podelesção de autor e téc aqulantes- proteçãoviiva mar, dar daem Lade:elem as pessoas doant quanto no deáudo f'" 245 | ] 246 | }, 247 | "metadata": {}, 248 | "output_type": "display_data" 249 | }, 250 | { 251 | "data": { 252 | "text/plain": [ 253 | "'2 ---> é em eucom antesantesamento.em não poder. prevuto porne fção mar, a visua aferriem delicas que de toda a público que e algum com oul dar a l necess está e aquínm mundo deanedadeeu ase do nossoo. 
'" 254 | ] 255 | }, 256 | "metadata": {}, 257 | "output_type": "display_data" 258 | }, 259 | { 260 | "data": { 261 | "text/plain": [ 262 | "'3 ---> seu e que aqu eiões salvaest garantia- antes antesção com aosl do mesmo, sobre se enxneadora parapr apenas o e que de era são ela a ir que foi aíima de que dia negas aqui, contido a grande com deen aqu e de'" 263 | ] 264 | }, 265 | "metadata": {}, 266 | "output_type": "display_data" 267 | }, 268 | { 269 | "data": { 270 | "text/plain": [ 271 | "'4 ---> heza de faria antes tambémadoabilata de um gistal, Pata ou aqu ambra a bz são falente que forma nuzo a filiic de oviço da imp do com a) e quando de de todo que era social Sra ele e sobre da Eduira'" 272 | ] 273 | }, 274 | "metadata": {}, 275 | "output_type": "display_data" 276 | }, 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "'5 ---> em choca ou motvel entre depois ou olção, nuaperores do público e davelup aivos do emcer e eu conoem ào fadas e a esteo comências, e por estavaender e 3 a de nem de foca palear e de com a'" 281 | ] 282 | }, 283 | "metadata": {}, 284 | "output_type": "display_data" 285 | }, 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "'6 ---> ú tanto ou emocção de nas menfene o aquelesito de transenaneina para a aqupua ou dosogunE em produção do jocinas de Lvies um jienteras e uma assimes poderetane sem bíza da eriaendo, de man'" 290 | ] 291 | }, 292 | "metadata": {}, 293 | "output_type": "display_data" 294 | }, 295 | { 296 | "data": { 297 | "text/plain": [ 298 | "'7 ---> daval necessant masiententam... aces me estar no ambhoso e ou, de pontosido primeirojá de formarasodado ao quandoânica deixências de paraícios, e até eamentoidas e sim, os um filh é nas mar de poderos'" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "output_type": "display_data" 303 | }, 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "'---------------'" 308 | ] 309 | }, 310 | "metadata": {}, 311 | "output_type": "display_data" 312 | } 313 | ], 314 | "source": [ 315 | "from IPython.display import display, clear_output\n", 316 | "\n", 317 | "batch_size = 8\n", 318 | "cfg=1\n", 319 | "prompt = [\"Os biólogos do Zoológico estão realizando um trabalho de enriquecimento ambiental, que é a melhoria do ambiente ou recinto em que vivem os animais, para que fique o mais próximo possível do habitat natural A Secretaria Municipal\"] * batch_size\n", 320 | "neg_prompt = [\"\"] * batch_size\n", 321 | "\n", 322 | "input_ids = tokenizer(prompt, padding=\"max_length\", return_tensors=\"pt\").to(\"cuda\")\n", 323 | "neg_input_ids = tokenizer(neg_prompt, padding=\"max_length\", max_length=input_ids.input_ids.shape[1], return_tensors=\"pt\").to(\"cuda\")\n", 324 | "encoder_hidden_states = model.apply_embeddings(input_ids.input_ids).to(model.dtype)\n", 325 | "neg_encoder_hidden_states = model.apply_embeddings(neg_input_ids.input_ids).to(model.dtype)\n", 326 | "\n", 327 | "with torch.no_grad():\n", 328 | " latents = torch.rand((batch_size, input_ids.input_ids.shape[1], 768), device=device).to(model.dtype) + torch.rand((8, input_ids.input_ids.shape[1], 768), device=device).to(torch.float16)\n", 329 | " attention_mask = torch.ones((batch_size, input_ids.input_ids.shape[1]), device=device)\n", 330 | " num_inference_steps = 1000\n", 331 | " timesteps=None\n", 332 | " timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)\n", 333 | "\n", 334 | " \n", 335 | " for i, t in tqdm(enumerate(timesteps)):\n", 336 | " # if i >= 0.7 * num_inference_steps:\n", 337 | " # break\n", 338 | " # 
expand the latents if we are doing classifier free guidance\n", 339 | " latent_model_input = latents\n", 340 | " latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n", 341 | " latent_model_input = torch.cat([latents] * 2) if cfg > 1 else latents\n", 342 | " prompt_embeds = torch.cat([encoder_hidden_states, neg_encoder_hidden_states]) if cfg > 1 else encoder_hidden_states\n", 343 | "\n", 344 | " outputs = model(\n", 345 | " input_embeds=latent_model_input,\n", 346 | " timesteps=t.reshape(1,).long().to(device),\n", 347 | " encoder_hidden_states=prompt_embeds\n", 348 | " )\n", 349 | " noise_pred = outputs.last_hidden_state\n", 350 | " if cfg > 1:\n", 351 | " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", 352 | " noise_pred = noise_pred_uncond + cfg * (noise_pred_text - noise_pred_uncond)\n", 353 | "\n", 354 | " \n", 355 | " latents_final = outputs.logits\n", 356 | " if i % 10 ==0 :\n", 357 | " clear_output(wait=True)\n", 358 | " display(f\"SAMPLES[{i}]--->\")\n", 359 | " for n in range(latents_final.shape[0]):\n", 360 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 361 | " display(\"---------------\")\n", 362 | "\n", 363 | " step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]\n", 364 | " latents = step[\"prev_sample\"]\n", 365 | "\n", 366 | "\n", 367 | "clear_output(wait=True)\n", 368 | "display(f\"FINAL --->\")\n", 369 | "for n in range(latents_final.shape[0]):\n", 370 | " display(f\"{n} ---> \" + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))\n", 371 | "display(\"---------------\")" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 15, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "data": { 381 | "text/plain": [ 382 | "torch.Size([1, 7])" 383 | ] 384 | }, 385 | "execution_count": 15, 386 | "metadata": {}, 387 | "output_type": "execute_result" 388 | } 389 | ], 390 | "source": [ 391 | "input_ids = tokenizer(\"isso é um teste\", return_tensors=\"pt\").to(\"cuda\")\n", 392 | "input_ids.input_ids.shape" 393 | ] 394 | } 395 | ], 396 | "metadata": { 397 | "kernelspec": { 398 | "display_name": "Python 3", 399 | "language": "python", 400 | "name": "python3" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 | "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.8.10" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 2 417 | } 418 | -------------------------------------------------------------------------------- /generation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thedarkzeno/text-diffusion/eebb273c2f9fc7bb8fc355691f450c7ac97eabe0/generation.gif -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Transformer Text Diffusion 2 | 3 | 4 | This repository contains an implementation of a Denoising Diffusion Probabilistic Model of Text (DDPT) based on Transformer networks. This model aims to generate high-quality, coherent text by utilizing diffusion-based probabilistic modeling techniques within a transformer architecture. 
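At a high level, the denoiser treats token embeddings as continuous latents: generation starts from random embeddings, the model repeatedly predicts the clean sequence while a scheduler steps the latents toward that prediction, and the final logits are decoded back into tokens. The sketch below mirrors the loop in `generate sample.ipynb`; it is a minimal example assuming a model saved in the layout produced by `create_model.py`, and the checkpoint path, sequence length, and greedy `argmax` readout are illustrative choices rather than fixed parts of the API.

```python
import torch
from transformers import AutoTokenizer
from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM
from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler

# Placeholder path: point this at a model created by create_model.py and trained with train.sh.
path = "models/diffMamba-mini-sample"
tokenizer = AutoTokenizer.from_pretrained(path, subfolder="tokenizer")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(path, subfolder="scheduler")
model = DiffMambaForDiffusionLM.from_pretrained(path, subfolder="denoiser", torch_dtype=torch.float16).to("cuda")

device = model.device
seq_len = 64  # illustrative sequence length
# Start from random embeddings and denoise them over the scheduler's timesteps.
latents = torch.rand((1, seq_len, model.config.hidden_size), device=device, dtype=torch.float16)
scheduler.set_timesteps(scheduler.config.num_train_timesteps, device=device)

with torch.no_grad():
    for t in scheduler.timesteps:
        latent_input = scheduler.scale_model_input(latents, t)
        outputs = model(input_embeds=latent_input, timesteps=t.reshape(1,).long().to(device))
        # The denoiser predicts the clean sample; the scheduler moves the latents toward it.
        latents = scheduler.step(outputs.last_hidden_state, t, latents, return_dict=True)["prev_sample"]

# Greedy readout of the final logits into tokens.
token_ids = outputs.logits.argmax(dim=-1)
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))
```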
5 | 6 | | ![diffusion](./generation.gif) | 7 | 8 | ## Setup 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | 15 | ## Training the Model 16 | 17 | 1. First, create the model: 18 | ```bash 19 | python create_model.py 20 | ``` 21 | 2. Edit `train.sh` with the desired parameters and run: 22 | ```bash 23 | sh train.sh 24 | ``` 25 | ## Usage 26 | 27 | You can generate text and experiment with the trained model using the provided Jupyter notebooks, e.g. `generate sample.ipynb`. 28 | 29 | ## Acknowledgements 30 | 31 | This repository was built on top of code from [minimal-text-diffusion](https://github.com/madaan/minimal-text-diffusion), [diffusers](https://github.com/huggingface/diffusers) and [transformers](https://github.com/huggingface/transformers) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | accelerate 3 | datasets 4 | diffusers 5 | flash-attn 6 | causal-conv1d 7 | mamba-ssm 8 | triton -------------------------------------------------------------------------------- /scripts/train_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import argparse 17 | import logging 18 | import math 19 | import os 20 | import random 21 | import shutil 22 | from pathlib import Path 23 | 24 | import accelerate 25 | import datasets 26 | import numpy as np 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | import transformers 31 | from accelerate import Accelerator 32 | from accelerate.logging import get_logger 33 | from accelerate.state import AcceleratorState 34 | from accelerate.utils import ProjectConfiguration, set_seed 35 | from datasets import load_dataset 36 | from huggingface_hub import create_repo, upload_folder 37 | from packaging import version 38 | from torchvision import transforms 39 | from tqdm.auto import tqdm 40 | from transformers import AutoTokenizer 41 | 42 | from transformers.utils import ContextManagers 43 | from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM 44 | 45 | import diffusers 46 | from src.schedulers.ddpm import DDPMScheduler 47 | from diffusers.optimization import get_scheduler 48 | from diffusers.training_utils import EMAModel, compute_snr 49 | from diffusers.utils import deprecate, is_wandb_available, make_image_grid 50 | 51 | 52 | if is_wandb_available(): 53 | import wandb 54 | 55 | 56 | 57 | logger = get_logger(__name__, log_level="INFO") 58 | 59 | 60 | 61 | def parse_args(): 62 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 63 | parser.add_argument( 64 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
65 | ) 66 | parser.add_argument( 67 | "--pretrained_model_name_or_path", 68 | type=str, 69 | default=None, 70 | required=True, 71 | help="Path to pretrained model or model identifier from huggingface.co/models.", 72 | ) 73 | parser.add_argument( 74 | "--revision", 75 | type=str, 76 | default=None, 77 | required=False, 78 | help="Revision of pretrained model identifier from huggingface.co/models.", 79 | ) 80 | parser.add_argument( 81 | "--dataset_name", 82 | type=str, 83 | default=None, 84 | help=( 85 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 86 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 87 | " or to a folder containing files that 🤗 Datasets can understand." 88 | ), 89 | ) 90 | parser.add_argument( 91 | "--dataset_config_name", 92 | type=str, 93 | default=None, 94 | help="The config of the Dataset, leave as None if there's only one config.", 95 | ) 96 | parser.add_argument( 97 | "--streaming", 98 | action="store_true", 99 | default=False, 100 | help="Streaming the dataset" 101 | ) 102 | parser.add_argument( 103 | "--train_data_dir", 104 | type=str, 105 | default=None, 106 | help=( 107 | "A folder containing the training data. Folder contents must follow the structure described in" 108 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 109 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 110 | ), 111 | ) 112 | parser.add_argument( 113 | "--text_column", 114 | type=str, 115 | default="Prompt", 116 | help="The column of the dataset containing a caption or a list of captions.", 117 | ) 118 | parser.add_argument( 119 | "--max_train_samples", 120 | type=int, 121 | default=None, 122 | help=( 123 | "For debugging purposes or quicker training, truncate the number of training examples to this " 124 | "value if set." 125 | ), 126 | ) 127 | parser.add_argument( 128 | "--validation_prompts", 129 | type=str, 130 | default=None, 131 | nargs="+", 132 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), 133 | ) 134 | parser.add_argument( 135 | "--output_dir", 136 | type=str, 137 | default="sd-model-finetuned", 138 | help="The output directory where the model predictions and checkpoints will be written.", 139 | ) 140 | parser.add_argument( 141 | "--cache_dir", 142 | type=str, 143 | default=None, 144 | help="The directory where the downloaded models and datasets will be stored.", 145 | ) 146 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 147 | parser.add_argument( 148 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 149 | ) 150 | parser.add_argument("--num_train_epochs", type=int, default=1000) 151 | parser.add_argument( 152 | "--max_train_steps", 153 | type=int, 154 | default=None, 155 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 156 | ) 157 | parser.add_argument( 158 | "--gradient_accumulation_steps", 159 | type=int, 160 | default=64, 161 | help="Number of updates steps to accumulate before performing a backward/update pass.", 162 | ) 163 | parser.add_argument( 164 | "--gradient_checkpointing", 165 | action="store_true", 166 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 167 | ) 168 | parser.add_argument( 169 | "--learning_rate", 170 | type=float, 171 | default=1e-4, 172 | help="Initial learning rate (after the potential warmup period) to use.", 173 | ) 174 | parser.add_argument( 175 | "--scale_lr", 176 | action="store_true", 177 | default=False, 178 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 179 | ) 180 | parser.add_argument( 181 | "--lr_scheduler", 182 | type=str, 183 | default="constant", 184 | help=( 185 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 186 | ' "constant", "constant_with_warmup"]' 187 | ), 188 | ) 189 | parser.add_argument( 190 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 191 | ) 192 | parser.add_argument( 193 | "--snr_gamma", 194 | type=float, 195 | default=5, 196 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 197 | "More details here: https://arxiv.org/abs/2303.09556.", 198 | ) 199 | parser.add_argument( 200 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 201 | ) 202 | parser.add_argument( 203 | "--allow_tf32", 204 | action="store_true", 205 | default=False, 206 | help=( 207 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 208 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 209 | ), 210 | ) 211 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 212 | parser.add_argument( 213 | "--non_ema_revision", 214 | type=str, 215 | default=None, 216 | required=False, 217 | help=( 218 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 219 | " remote repository specified with --pretrained_model_name_or_path." 220 | ), 221 | ) 222 | parser.add_argument( 223 | "--dataloader_num_workers", 224 | type=int, 225 | default=0, 226 | help=( 227 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
228 | ), 229 | ) 230 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 231 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 232 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 233 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 234 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 235 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 236 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 237 | parser.add_argument( 238 | "--prediction_type", 239 | type=str, 240 | default=None, 241 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 242 | ) 243 | parser.add_argument( 244 | "--hub_model_id", 245 | type=str, 246 | default=None, 247 | help="The name of the repository to keep in sync with the local `output_dir`.", 248 | ) 249 | parser.add_argument( 250 | "--logging_dir", 251 | type=str, 252 | default="logs", 253 | help=( 254 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 255 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 256 | ), 257 | ) 258 | parser.add_argument( 259 | "--mixed_precision", 260 | type=str, 261 | default="no", 262 | choices=["no", "fp16", "bf16"], 263 | help=( 264 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 265 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 266 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 267 | ), 268 | ) 269 | parser.add_argument( 270 | "--report_to", 271 | type=str, 272 | default="wandb", 273 | help=( 274 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 275 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 276 | ), 277 | ) 278 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 279 | parser.add_argument( 280 | "--checkpointing_steps", 281 | type=int, 282 | default=100, 283 | help=( 284 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 285 | " training using `--resume_from_checkpoint`." 286 | ), 287 | ) 288 | parser.add_argument( 289 | "--checkpoints_total_limit", 290 | type=int, 291 | default=None, 292 | help=("Max number of checkpoints to store."), 293 | ) 294 | parser.add_argument( 295 | "--resume_from_checkpoint", 296 | type=str, 297 | default=None, 298 | help=( 299 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 300 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
301 | ), 302 | ) 303 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 304 | parser.add_argument( 305 | "--validation_epochs", 306 | type=int, 307 | default=5, 308 | help="Run validation every X epochs.", 309 | ) 310 | parser.add_argument( 311 | "--tracker_project_name", 312 | type=str, 313 | default="text-diffusion", 314 | help=( 315 | "The `project_name` argument passed to Accelerator.init_trackers for" 316 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 317 | ), 318 | ) 319 | 320 | args = parser.parse_args() 321 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 322 | if env_local_rank != -1 and env_local_rank != args.local_rank: 323 | args.local_rank = env_local_rank 324 | 325 | # Sanity checks 326 | if args.dataset_name is None and args.train_data_dir is None: 327 | raise ValueError("Need either a dataset name or a training folder.") 328 | 329 | # default to using the same revision for the non-ema model if not specified 330 | if args.non_ema_revision is None: 331 | args.non_ema_revision = args.revision 332 | 333 | return args 334 | 335 | def mean_flat(tensor): 336 | """ 337 | Take the mean over all non-batch dimensions. 338 | """ 339 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 340 | 341 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 342 | """ 343 | Extract values from a 1-D numpy array for a batch of indices. 344 | 345 | :param arr: the 1-D numpy array. 346 | :param timesteps: a tensor of indices into the array to extract. 347 | :param broadcast_shape: a larger shape of K dimensions with the batch 348 | dimension equal to the length of timesteps. 349 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 350 | """ 351 | res = arr.to(device=timesteps.device)[timesteps].float() 352 | while len(res.shape) < len(broadcast_shape): 353 | res = res[..., None] 354 | return res.expand(broadcast_shape) 355 | 356 | def q_mean_variance(x_start, t, scheduler): 357 | """ 358 | Get the distribution q(x_t | x_0). 359 | 360 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 361 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 362 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 363 | """ 364 | sqrt_alphas_cumprod = np.sqrt(scheduler.alphas_cumprod) 365 | mean = _extract_into_tensor(sqrt_alphas_cumprod, t, x_start.shape) * x_start 366 | 367 | return mean 368 | 369 | def main(): 370 | args = parse_args() 371 | 372 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 373 | 374 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 375 | 376 | accelerator = Accelerator( 377 | gradient_accumulation_steps=args.gradient_accumulation_steps, 378 | mixed_precision=args.mixed_precision, 379 | log_with=args.report_to, 380 | project_config=accelerator_project_config, 381 | ) 382 | 383 | # Make one log on every process with the configuration for debugging. 
384 | logging.basicConfig( 385 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 386 | datefmt="%m/%d/%Y %H:%M:%S", 387 | level=logging.INFO, 388 | ) 389 | logger.info(accelerator.state, main_process_only=False) 390 | if accelerator.is_local_main_process: 391 | datasets.utils.logging.set_verbosity_warning() 392 | transformers.utils.logging.set_verbosity_warning() 393 | diffusers.utils.logging.set_verbosity_info() 394 | else: 395 | datasets.utils.logging.set_verbosity_error() 396 | transformers.utils.logging.set_verbosity_error() 397 | diffusers.utils.logging.set_verbosity_error() 398 | 399 | # If passed along, set the training seed now. 400 | if args.seed is not None: 401 | set_seed(args.seed) 402 | 403 | # Handle the repository creation 404 | if accelerator.is_main_process: 405 | if args.output_dir is not None: 406 | os.makedirs(args.output_dir, exist_ok=True) 407 | 408 | if args.push_to_hub: 409 | repo_id = create_repo( 410 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 411 | ).repo_id 412 | 413 | # Load scheduler, tokenizer and models. 414 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path) 415 | tokenizer = AutoTokenizer.from_pretrained( 416 | args.pretrained_model_name_or_path, revision=args.revision 417 | ) 418 | 419 | model = DiffMambaForDiffusionLM.from_pretrained( 420 | args.pretrained_model_name_or_path, 421 | # use_flash_attention_2=True, 422 | # torch_dtype=torch.float16 423 | ) 424 | 425 | # Freeze vae and text_encoder and set model to trainable 426 | model.train() 427 | 428 | 429 | # Create EMA for the model. 430 | if args.use_ema: 431 | ema_model = DiffMambaForDiffusionLM.from_pretrained( 432 | args.pretrained_model_name_or_path 433 | ) 434 | ema_model = EMAModel(ema_model.parameters(), model_cls=DiffMambaForDiffusionLM, model_config=ema_model.config) 435 | 436 | tokenizer.add_special_tokens({'pad_token': ''}) 437 | 438 | embedding_size = model.get_input_embeddings().weight.shape[0] 439 | if len(tokenizer) > embedding_size: 440 | model.resize_token_embeddings(len(tokenizer)) 441 | 442 | # `accelerate` 0.16.0 will have better support for customized saving 443 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 444 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 445 | def save_model_hook(models, weights, output_dir): 446 | if accelerator.is_main_process: 447 | if args.use_ema: 448 | ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) 449 | 450 | for i, model in enumerate(models): 451 | model.save_pretrained(os.path.join(output_dir, "model")) 452 | 453 | # make sure to pop weight so that corresponding model is not saved again 454 | weights.pop() 455 | 456 | def load_model_hook(models, input_dir): 457 | if args.use_ema: 458 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), DiffMambaForDiffusionLM) 459 | ema_model.load_state_dict(load_model.state_dict()) 460 | ema_model.to(accelerator.device) 461 | del load_model 462 | 463 | for i in range(len(models)): 464 | # pop models so that they are not loaded again 465 | model = models.pop() 466 | 467 | # load diffusers style into model 468 | load_model = DiffMambaForDiffusionLM.from_pretrained(input_dir, subfolder="model") 469 | model.register_to_config(**load_model.config) 470 | 471 | model.load_state_dict(load_model.state_dict()) 472 | del load_model 473 | 474 | accelerator.register_save_state_pre_hook(save_model_hook) 475 | 
accelerator.register_load_state_pre_hook(load_model_hook) 476 | 477 | if args.gradient_checkpointing: 478 | model.enable_gradient_checkpointing() 479 | 480 | # Enable TF32 for faster training on Ampere GPUs, 481 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 482 | if args.allow_tf32: 483 | torch.backends.cuda.matmul.allow_tf32 = True 484 | 485 | if args.scale_lr: 486 | args.learning_rate = ( 487 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 488 | ) 489 | 490 | # Initialize the optimizer 491 | if args.use_8bit_adam: 492 | try: 493 | import bitsandbytes as bnb 494 | except ImportError: 495 | raise ImportError( 496 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 497 | ) 498 | 499 | optimizer_cls = bnb.optim.AdamW8bit 500 | else: 501 | optimizer_cls = torch.optim.AdamW 502 | 503 | optimizer = optimizer_cls( 504 | model.parameters(), 505 | lr=args.learning_rate, 506 | betas=(args.adam_beta1, args.adam_beta2), 507 | weight_decay=args.adam_weight_decay, 508 | eps=args.adam_epsilon, 509 | ) 510 | 511 | # Get the datasets: you can either provide your own training and evaluation files (see below) 512 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 513 | 514 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 515 | # download the dataset. 516 | 517 | # Downloading and loading a dataset from the hub. 518 | dataset = load_dataset( 519 | args.dataset_name, 520 | args.dataset_config_name, 521 | cache_dir=args.cache_dir, 522 | data_dir=args.train_data_dir, 523 | streaming=args.streaming, 524 | ) 525 | 526 | 527 | # Preprocessing the datasets. 528 | # We need to tokenize inputs and targets. 529 | column_names = dataset["train"].column_names 530 | 531 | # 6. Get the column names for input/target. 532 | # dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 533 | 534 | if args.text_column is None: 535 | text_column = column_names[0] 536 | else: 537 | text_column = args.text_column 538 | if text_column not in column_names: 539 | print(text_column, column_names) 540 | raise ValueError( 541 | f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}" 542 | ) 543 | 544 | # Preprocessing the datasets. 545 | # We need to tokenize input captions and transform the images. 546 | 547 | def tokenize_captions(examples, is_train=True): 548 | captions = [] 549 | for caption in examples[text_column]: 550 | if isinstance(caption, str): 551 | captions.append(caption) 552 | elif isinstance(caption, (list, np.ndarray)): 553 | # take a random caption if there are multiple 554 | captions.append(random.choice(caption) if is_train else caption[0]) 555 | else: 556 | raise ValueError( 557 | f"Caption column `{text_column}` should contain either strings or lists of strings." 
558 | ) 559 | # print(captions) 560 | inputs = tokenizer( 561 | captions, max_length=64, padding="max_length", truncation=True, return_tensors="pt" 562 | ) 563 | return inputs.input_ids 564 | 565 | 566 | 567 | def apply_embeddings(token_ids): 568 | 569 | vectors = model.apply_embeddings(input_ids=token_ids.to(model.device)) 570 | 571 | return vectors 572 | 573 | def preprocess_train(examples): 574 | examples["input_ids"] = tokenize_captions(examples) 575 | return examples 576 | 577 | with accelerator.main_process_first(): 578 | if args.max_train_samples is not None: 579 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 580 | # Set the training transforms 581 | train_dataset = dataset["train"].with_transform(preprocess_train) 582 | 583 | def collate_fn(examples): 584 | input_ids = torch.stack([example["input_ids"] for example in examples]) 585 | # one_hots = torch.stack([apply_embeddings(example["input_ids"]) for example in examples]) 586 | # one_hots = apply_embeddings(input_ids) 587 | return {"input_ids": input_ids} 588 | 589 | 590 | 591 | # DataLoaders creation: 592 | train_dataloader = torch.utils.data.DataLoader( 593 | train_dataset, 594 | shuffle=True, 595 | collate_fn=collate_fn, 596 | batch_size=args.train_batch_size, 597 | num_workers=args.dataloader_num_workers, 598 | ) 599 | 600 | # Scheduler and math around the number of training steps. 601 | overrode_max_train_steps = False 602 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 603 | if args.max_train_steps is None: 604 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 605 | overrode_max_train_steps = True 606 | 607 | lr_scheduler = get_scheduler( 608 | args.lr_scheduler, 609 | optimizer=optimizer, 610 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 611 | num_training_steps=args.max_train_steps * accelerator.num_processes, 612 | ) 613 | 614 | # Prepare everything with our `accelerator`. 615 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 616 | model, optimizer, train_dataloader, lr_scheduler 617 | ) 618 | 619 | if args.use_ema: 620 | ema_model.to(accelerator.device) 621 | 622 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora model) to half-precision 623 | # as these weights are only used for inference, keeping weights in full precision is not required. 624 | weight_dtype = torch.float32 625 | if accelerator.mixed_precision == "fp16": 626 | weight_dtype = torch.float16 627 | args.mixed_precision = accelerator.mixed_precision 628 | elif accelerator.mixed_precision == "bf16": 629 | weight_dtype = torch.bfloat16 630 | args.mixed_precision = accelerator.mixed_precision 631 | 632 | 633 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 634 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 635 | if overrode_max_train_steps: 636 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 637 | # Afterwards we recalculate our number of training epochs 638 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 639 | 640 | # We need to initialize the trackers we use, and also store our configuration. 641 | # The trackers initializes automatically on the main process. 
642 | if accelerator.is_main_process: 643 | tracker_config = dict(vars(args)) 644 | tracker_config.pop("validation_prompts") 645 | accelerator.init_trackers(args.tracker_project_name, tracker_config) 646 | 647 | 648 | class MeanZeroLoss(torch.nn.Module): 649 | def __init__(self): 650 | super(MeanZeroLoss, self).__init__() 651 | 652 | def forward(self, input_tensor): 653 | mean_value = torch.mean(input_tensor) 654 | loss = torch.abs(mean_value) # Penalize the absolute deviation from zero 655 | return loss 656 | 657 | # Train! 658 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 659 | 660 | logger.info("***** Running training *****") 661 | logger.info(f" Num examples = {len(train_dataset)}") 662 | logger.info(f" Num Epochs = {args.num_train_epochs}") 663 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 664 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 665 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 666 | logger.info(f" Total optimization steps = {args.max_train_steps}") 667 | global_step = 0 668 | first_epoch = 0 669 | 670 | mean_zero_loss_function = MeanZeroLoss() 671 | # Potentially load in the weights and states from a previous save 672 | if args.resume_from_checkpoint: 673 | if args.resume_from_checkpoint != "latest": 674 | path = os.path.basename(args.resume_from_checkpoint) 675 | else: 676 | # Get the most recent checkpoint 677 | dirs = os.listdir(args.output_dir) 678 | dirs = [d for d in dirs if d.startswith("checkpoint")] 679 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 680 | path = dirs[-1] if len(dirs) > 0 else None 681 | 682 | if path is None: 683 | accelerator.print( 684 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 685 | ) 686 | args.resume_from_checkpoint = None 687 | initial_global_step = 0 688 | else: 689 | accelerator.print(f"Resuming from checkpoint {path}") 690 | accelerator.load_state(os.path.join(args.output_dir, path)) 691 | global_step = int(path.split("-")[1]) 692 | 693 | initial_global_step = global_step 694 | first_epoch = global_step // num_update_steps_per_epoch 695 | 696 | else: 697 | initial_global_step = 0 698 | 699 | progress_bar = tqdm( 700 | range(0, args.max_train_steps), 701 | initial=initial_global_step, 702 | desc="Steps", 703 | # Only show the progress bar once on each machine. 
704 | disable=not accelerator.is_local_main_process, 705 | ) 706 | 707 | for epoch in range(first_epoch, args.num_train_epochs): 708 | train_loss = 0.0 709 | for step, batch in enumerate(train_dataloader): 710 | with accelerator.accumulate(model): 711 | # Convert images to latent space 712 | input_ids = batch["input_ids"].to(model.device) 713 | 714 | max_steps = noise_scheduler.config.num_train_timesteps 715 | bsz = input_ids.shape[0] 716 | timesteps = torch.randint(0, max_steps, (bsz,), device=model.device) 717 | timesteps = timesteps.long() 718 | 719 | latents = apply_embeddings(input_ids) 720 | # print(latents.shape) 721 | 722 | noise = torch.randn_like(latents, requires_grad=False) 723 | 724 | 725 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 726 | 727 | 728 | # noisy_latents, noise, timesteps = add_noise(latents) 729 | 730 | 731 | # Get the target for loss depending on the prediction type 732 | if args.prediction_type is not None: 733 | # set prediction_type of scheduler if defined 734 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 735 | 736 | if noise_scheduler.config.prediction_type == "epsilon": 737 | target = noise 738 | elif noise_scheduler.config.prediction_type == "sample": 739 | target = latents 740 | elif noise_scheduler.config.prediction_type == "v_prediction": 741 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 742 | else: 743 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 744 | 745 | # Predict the noise residual and compute loss 746 | outputs = model(noisy_latents, timesteps=timesteps, labels=input_ids) 747 | nll_loss=outputs.loss 748 | model_pred = outputs.last_hidden_state 749 | 750 | ae_out = model.apply_lm_head(latents, labels=input_ids) 751 | 752 | ae_loss = ae_out.loss 753 | 754 | 755 | if args.snr_gamma is None: 756 | mse_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 757 | # mean_zero_loss = mean_zero_loss_function(model_pred.float()) 758 | 759 | # out_mean = q_mean_variance( 760 | # model_pred, torch.LongTensor([max_steps - 1]).to(model_pred.device), noise_scheduler 761 | # ) 762 | # tT_loss = mean_flat(out_mean**2).mean() 763 | # print(tT_loss) 764 | if step < 100: 765 | loss = ae_loss 766 | else: 767 | loss = nll_loss + mse_loss + ae_loss# + tT_loss 768 | else: 769 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 770 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 771 | # This is discussed in Section 4.2 of the same paper. 772 | snr = compute_snr(noise_scheduler, timesteps) 773 | if noise_scheduler.config.prediction_type == "v_prediction": 774 | # Velocity objective requires that we add one to SNR values before we divide by them. 775 | snr = snr + 1 776 | mse_loss_weights = ( 777 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 778 | ) 779 | 780 | 781 | mse_loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 782 | mse_loss = mse_loss.mean(dim=list(range(1, len(mse_loss.shape)))) * mse_loss_weights 783 | mse_loss = mse_loss.mean() 784 | # mean_zero_loss = mean_zero_loss_function(model_pred.float()) 785 | if step < 100: 786 | loss = ae_loss 787 | else: 788 | loss = nll_loss + mse_loss + ae_loss 789 | 790 | # Gather the losses across all processes for logging (if we use distributed training). 
791 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 792 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 793 | 794 | # Backpropagate 795 | accelerator.backward(loss) 796 | if accelerator.sync_gradients: 797 | accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) 798 | optimizer.step() 799 | lr_scheduler.step() 800 | optimizer.zero_grad() 801 | 802 | # Checks if the accelerator has performed an optimization step behind the scenes 803 | if accelerator.sync_gradients: 804 | if args.use_ema: 805 | ema_model.step(model.parameters()) 806 | progress_bar.update(1) 807 | global_step += 1 808 | accelerator.log({"train_loss": train_loss}, step=global_step) 809 | train_loss = 0.0 810 | 811 | if global_step % args.checkpointing_steps == 0: 812 | if accelerator.is_main_process: 813 | model = accelerator.unwrap_model(model) 814 | if args.use_ema: 815 | ema_model.copy_to(model.parameters()) 816 | 817 | model.save_pretrained(args.output_dir) 818 | 819 | 820 | logs = {"loss": loss.detach().item(), 821 | "mse_loss": mse_loss.detach().item(), 822 | "nll_loss": nll_loss.detach().item(), 823 | # "mz_loss": mean_zero_loss.detach().item(), 824 | "ae_loss": ae_loss.detach().item(), 825 | "lr": lr_scheduler.get_last_lr()[0]} 826 | progress_bar.set_postfix(**logs) 827 | 828 | if global_step >= args.max_train_steps: 829 | break 830 | 831 | 832 | 833 | # Create the pipeline using the trained modules and save it. 834 | accelerator.wait_for_everyone() 835 | if accelerator.is_main_process: 836 | model = accelerator.unwrap_model(model) 837 | if args.use_ema: 838 | ema_model.copy_to(model.parameters()) 839 | 840 | model.save_pretrained(args.output_dir) 841 | 842 | accelerator.end_training() 843 | 844 | 845 | if __name__ == "__main__": 846 | main() -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thedarkzeno/text-diffusion/eebb273c2f9fc7bb8fc355691f450c7ac97eabe0/src/__init__.py -------------------------------------------------------------------------------- /src/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thedarkzeno/text-diffusion/eebb273c2f9fc7bb8fc355691f450c7ac97eabe0/src/decoders/__init__.py -------------------------------------------------------------------------------- /src/denoiser_decoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ Classes to support Encoder-Decoder architectures""" 16 | 17 | 18 | import gc 19 | import inspect 20 | import os 21 | import tempfile 22 | import warnings 23 | from typing import Optional, Tuple, Union 24 | 25 | import torch 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | from transformers.configuration_utils import PretrainedConfig 30 | from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput 31 | from transformers.modeling_utils import PreTrainedModel 32 | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings 33 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, DenoiserDecoderConfig 34 | 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | 39 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 40 | """ 41 | Shift input ids one token to the right. 42 | """ 43 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 44 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 45 | if decoder_start_token_id is None: 46 | raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") 47 | shifted_input_ids[:, 0] = decoder_start_token_id 48 | 49 | if pad_token_id is None: 50 | raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") 51 | # replace possible -100 values in labels by `pad_token_id` 52 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 53 | 54 | return shifted_input_ids 55 | 56 | 57 | class DenoiserDecoderModel(PreTrainedModel): 58 | 59 | config_class = DenoiserDecoderConfig 60 | base_model_prefix = "denoiser_decoder" 61 | main_input_name = "input_ids" 62 | supports_gradient_checkpointing = True 63 | 64 | def __init__( 65 | self, 66 | config: Optional[PretrainedConfig] = None, 67 | denoiser: Optional[PreTrainedModel] = None, 68 | decoder: Optional[PreTrainedModel] = None, 69 | scheduler = None 70 | ): 71 | if config is None and (denoiser is None or decoder is None): 72 | raise ValueError("Either a configuration or an denoiser and a decoder has to be provided.") 73 | if config is None: 74 | config = DenoiserDecoderConfig.from_denoiser_decoder_configs(denoiser.config, decoder.config) 75 | else: 76 | if not isinstance(config, self.config_class): 77 | raise ValueError(f"Config: {config} has to be of type {self.config_class}") 78 | 79 | if config.decoder.cross_attention_hidden_size is not None: 80 | if config.decoder.cross_attention_hidden_size != config.denoiser.hidden_size: 81 | raise ValueError( 82 | "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" 83 | f" to the denoiser's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" 84 | f" `config.decoder.cross_attention_hidden_size` and {config.denoiser.hidden_size} for" 85 | " `config.denoiser.hidden_size`." 
86 | ) 87 | 88 | # initialize with config 89 | super().__init__(config) 90 | 91 | if denoiser is None: 92 | denoiser = AutoModel.from_config(config.denoiser) 93 | 94 | if decoder is None: 95 | decoder = AutoModelForCausalLM.from_config(config.decoder) 96 | 97 | self.denoiser = denoiser 98 | self.decoder = decoder 99 | 100 | if self.denoiser.config.to_dict() != self.config.denoiser.to_dict(): 101 | logger.warning( 102 | f"Config of the denoiser: {self.denoiser.__class__} is overwritten by shared denoiser config:" 103 | f" {self.config.denoiser}" 104 | ) 105 | if self.decoder.config.to_dict() != self.config.decoder.to_dict(): 106 | logger.warning( 107 | f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" 108 | f" {self.config.decoder}" 109 | ) 110 | 111 | # make sure that the individual model's config refers to the shared config 112 | # so that the updates to the config will be synced 113 | self.denoiser.config = self.config.denoiser 114 | self.decoder.config = self.config.decoder 115 | 116 | # denoiser outputs might need to be projected to different dimension for decoder 117 | if ( 118 | self.denoiser.config.hidden_size != self.decoder.config.hidden_size 119 | and self.decoder.config.cross_attention_hidden_size is None 120 | ): 121 | self.enc_to_dec_proj = nn.Linear(self.denoiser.config.hidden_size, self.decoder.config.hidden_size) 122 | 123 | if self.denoiser.get_output_embeddings() is not None: 124 | raise ValueError( 125 | f"The denoiser {self.denoiser} should not have a LM Head. Please use a model without LM Head" 126 | ) 127 | 128 | decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) 129 | if "encoder_hidden_states" not in decoder_signature: 130 | raise ValueError( 131 | "The selected decoder is not prepared for the encoder hidden states to be passed. 
Please see the " 132 | "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" 133 | ) 134 | 135 | # tie encoder, decoder weights if config set accordingly 136 | self.tie_weights() 137 | 138 | def tie_weights(self): 139 | # tie encoder & decoder if needed 140 | if self.config.tie_denoiser_decoder: 141 | # tie denoiser and decoder base model 142 | decoder_base_model_prefix = self.decoder.base_model_prefix 143 | self._tie_denoiser_decoder_weights( 144 | self.denoiser, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix 145 | ) 146 | 147 | def get_denoiser(self): 148 | return self.denoiser 149 | 150 | def get_decoder(self): 151 | return self.decoder 152 | 153 | def get_input_embeddings(self): 154 | return self.denoiser.get_input_embeddings() 155 | 156 | def get_output_embeddings(self): 157 | return self.decoder.get_output_embeddings() 158 | 159 | def set_output_embeddings(self, new_embeddings): 160 | return self.decoder.set_output_embeddings(new_embeddings) 161 | 162 | @classmethod 163 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 164 | r""" 165 | Example: 166 | 167 | ```python 168 | >>> from transformers import DenoiserDecoderModel 169 | 170 | >>> model = DenoiserDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") 171 | ```""" 172 | 173 | from_tf = kwargs.pop("from_tf", False) 174 | if from_tf: 175 | from transformers import TFDenoiserDecoderModel 176 | 177 | _tf_model = TFDenoiserDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 178 | config = _tf_model.config 179 | 180 | # Using `tf_model` instead 181 | denoiser = _tf_model.denoiser.__class__(_tf_model.config.denoiser) 182 | decoder = _tf_model.decoder.__class__(_tf_model.config.decoder) 183 | # Make sure models are built 184 | denoiser(denoiser.dummy_inputs) 185 | decoder(decoder.dummy_inputs) 186 | 187 | # Get the variable correspondence between `_tf_model` and `denoiser` and `decoder` 188 | denoiser_variables = {} 189 | for v in denoiser.trainable_variables + denoiser.non_trainable_variables: 190 | denoiser_variables["/".join(v.name.split("/")[1:])] = v 191 | decoder_variables = {} 192 | for v in decoder.trainable_variables + decoder.non_trainable_variables: 193 | decoder_variables["/".join(v.name.split("/")[1:])] = v 194 | 195 | _denoiser_variables = {} 196 | for v in _tf_model.denoiser.trainable_variables + _tf_model.denoiser.non_trainable_variables: 197 | _denoiser_variables["/".join(v.name.split("/")[2:])] = v 198 | _decoder_variables = {} 199 | for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables: 200 | _decoder_variables["/".join(v.name.split("/")[2:])] = v 201 | 202 | # assign weight values to `denoiser` and `decoder` from `_tf_model` 203 | for name, v in denoiser_variables.items(): 204 | v.assign(_denoiser_variables[name]) 205 | for name, v in decoder_variables.items(): 206 | v.assign(_decoder_variables[name]) 207 | 208 | tf_model = TFDenoiserDecoderModel(denoiser=denoiser, decoder=decoder) 209 | 210 | # Deal with `enc_to_dec_proj` 211 | if hasattr(_tf_model, "enc_to_dec_proj"): 212 | tf_model(tf_model.dummy_inputs) 213 | tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel) 214 | tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias) 215 | 216 | with tempfile.TemporaryDirectory() as tmpdirname: 217 | denoiser_dir = os.path.join(tmpdirname, "denoiser") 218 | decoder_dir = os.path.join(tmpdirname, 
"decoder") 219 | tf_model.denoiser.save_pretrained(denoiser_dir) 220 | tf_model.decoder.save_pretrained(decoder_dir) 221 | 222 | if hasattr(tf_model, "enc_to_dec_proj"): 223 | enc_to_dec_proj_weight = torch.transpose( 224 | torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0 225 | ) 226 | enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy()) 227 | 228 | del _tf_model 229 | del tf_model 230 | gc.collect() 231 | 232 | model = DenoiserDecoderModel.from_denoiser_decoder_pretrained( 233 | denoiser_dir, decoder_dir, denoiser_from_tf=True, decoder_from_tf=True 234 | ) 235 | # This is only for copying some specific attributes of this particular model. 236 | model.config = config 237 | 238 | if hasattr(model, "enc_to_dec_proj"): 239 | model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() 240 | model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() 241 | 242 | return model 243 | 244 | # At the moment fast initialization is not supported for composite models 245 | if kwargs.get("_fast_init", False): 246 | logger.warning( 247 | "Fast initialization is currently not supported for DenoiserDecoderModel. " 248 | "Falling back to slow initialization..." 249 | ) 250 | kwargs["_fast_init"] = False 251 | 252 | return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 253 | 254 | @classmethod 255 | def from_denoiser_decoder_pretrained( 256 | cls, 257 | denoiser_pretrained_model_name_or_path: str = None, 258 | decoder_pretrained_model_name_or_path: str = None, 259 | *model_args, 260 | **kwargs, 261 | ) -> PreTrainedModel: 262 | 263 | kwargs_denoiser = { 264 | argument[len("denoiser_") :]: value for argument, value in kwargs.items() if argument.startswith("denoiser_") 265 | } 266 | 267 | kwargs_decoder = { 268 | argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") 269 | } 270 | 271 | # remove denoiser, decoder kwargs from kwargs 272 | for key in kwargs_denoiser.keys(): 273 | del kwargs["denoiser_" + key] 274 | for key in kwargs_decoder.keys(): 275 | del kwargs["decoder_" + key] 276 | 277 | # Load and initialize the denoiser and decoder 278 | # The distinction between denoiser and decoder at the model level is made 279 | # by the value of the flag `is_decoder` that we need to set correctly. 280 | denoiser = kwargs_denoiser.pop("model", None) 281 | if denoiser is None: 282 | if denoiser_pretrained_model_name_or_path is None: 283 | raise ValueError( 284 | "If `denoiser_model` is not defined as an argument, a `denoiser_pretrained_model_name_or_path` has " 285 | "to be defined." 286 | ) 287 | 288 | if "config" not in kwargs_denoiser: 289 | denoiser_config, kwargs_denoiser = AutoConfig.from_pretrained( 290 | denoiser_pretrained_model_name_or_path, **kwargs_denoiser, return_unused_kwargs=True 291 | ) 292 | 293 | if denoiser_config.is_decoder is True or denoiser_config.add_cross_attention is True: 294 | logger.info( 295 | f"Initializing {denoiser_pretrained_model_name_or_path} as a denoiser model " 296 | "from a decoder model. Cross-attention and casual mask are disabled." 
297 | ) 298 | denoiser_config.is_decoder = False 299 | denoiser_config.add_cross_attention = False 300 | 301 | kwargs_denoiser["config"] = denoiser_config 302 | 303 | denoiser = AutoModel.from_pretrained(denoiser_pretrained_model_name_or_path, *model_args, **kwargs_denoiser) 304 | 305 | decoder = kwargs_decoder.pop("model", None) 306 | if decoder is None: 307 | if decoder_pretrained_model_name_or_path is None: 308 | raise ValueError( 309 | "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " 310 | "to be defined." 311 | ) 312 | 313 | if "config" not in kwargs_decoder: 314 | decoder_config, kwargs_decoder = AutoConfig.from_pretrained( 315 | decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True 316 | ) 317 | 318 | if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: 319 | logger.info( 320 | f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" 321 | f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" 322 | f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." 323 | ) 324 | decoder_config.is_decoder = True 325 | decoder_config.add_cross_attention = True 326 | 327 | kwargs_decoder["config"] = decoder_config 328 | 329 | if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: 330 | logger.warning( 331 | f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " 332 | f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " 333 | "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " 334 | "passed to `.from_denoiser_decoder_pretrained(...)` are set to `True` or do not pass a " 335 | "`decoder_config` to `.from_denoiser_decoder_pretrained(...)`" 336 | ) 337 | 338 | decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) 339 | 340 | # instantiate config with corresponding kwargs 341 | config = DenoiserDecoderConfig.from_denoiser_decoder_configs(denoiser.config, decoder.config, **kwargs) 342 | return cls(denoiser=denoiser, decoder=decoder, config=config) 343 | 344 | def forward( 345 | self, 346 | input_ids: Optional[torch.LongTensor] = None, 347 | attention_mask: Optional[torch.FloatTensor] = None, 348 | decoder_input_ids: Optional[torch.LongTensor] = None, 349 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 350 | denoiser_outputs: Optional[Tuple[torch.FloatTensor]] = None, 351 | past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, 352 | inputs_embeds: Optional[torch.FloatTensor] = None, 353 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 354 | labels: Optional[torch.LongTensor] = None, 355 | use_cache: Optional[bool] = None, 356 | output_attentions: Optional[bool] = None, 357 | output_hidden_states: Optional[bool] = None, 358 | return_dict: Optional[bool] = None, 359 | **kwargs, 360 | ) -> Union[Tuple, Seq2SeqLMOutput]: 361 | r""" 362 | Returns: 363 | 364 | Examples: 365 | 366 | ```python 367 | >>> from transformers import DenoiserDecoderModel, BertTokenizer 368 | >>> import torch 369 | 370 | >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 371 | >>> model = DenoiserDecoderModel.from_denoiser_decoder_pretrained( 372 | ... "bert-base-uncased", "bert-base-uncased" 373 | ... 
) # initialize Bert2Bert from pre-trained checkpoints 374 | 375 | >>> # training 376 | >>> model.config.decoder_start_token_id = tokenizer.cls_token_id 377 | >>> model.config.pad_token_id = tokenizer.pad_token_id 378 | >>> model.config.vocab_size = model.config.decoder.vocab_size 379 | 380 | >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids 381 | >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids 382 | >>> outputs = model(input_ids=input_ids, labels=labels) 383 | >>> loss, logits = outputs.loss, outputs.logits 384 | 385 | >>> # save and load from pretrained 386 | >>> model.save_pretrained("bert2bert") 387 | >>> model = DenoiserDecoderModel.from_pretrained("bert2bert") 388 | 389 | >>> # generation 390 | >>> generated = model.generate(input_ids) 391 | ```""" 392 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 393 | 394 | kwargs_denoiser = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} 395 | 396 | kwargs_decoder = { 397 | argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") 398 | } 399 | 400 | if denoiser_outputs is None: 401 | denoiser_outputs = self.denoiser( 402 | input_ids=input_ids, 403 | attention_mask=attention_mask, 404 | inputs_embeds=inputs_embeds, 405 | output_attentions=output_attentions, 406 | output_hidden_states=output_hidden_states, 407 | return_dict=return_dict, 408 | **kwargs_denoiser, 409 | ) 410 | elif isinstance(denoiser_outputs, tuple): 411 | denoiser_outputs = BaseModelOutput(*denoiser_outputs) 412 | 413 | encoder_hidden_states = denoiser_outputs[0] 414 | 415 | # optionally project encoder_hidden_states 416 | if ( 417 | self.denoiser.config.hidden_size != self.decoder.config.hidden_size 418 | and self.decoder.config.cross_attention_hidden_size is None 419 | ): 420 | encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) 421 | 422 | if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): 423 | decoder_input_ids = shift_tokens_right( 424 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 425 | ) 426 | if decoder_attention_mask is None: 427 | decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id) 428 | 429 | # Decode 430 | decoder_outputs = self.decoder( 431 | input_ids=decoder_input_ids, 432 | attention_mask=decoder_attention_mask, 433 | encoder_hidden_states=encoder_hidden_states, 434 | encoder_attention_mask=attention_mask, 435 | inputs_embeds=decoder_inputs_embeds, 436 | output_attentions=output_attentions, 437 | output_hidden_states=output_hidden_states, 438 | use_cache=use_cache, 439 | past_key_values=past_key_values, 440 | return_dict=return_dict, 441 | **kwargs_decoder, 442 | ) 443 | 444 | # Compute loss independent from decoder (as some shift the logits inside them) 445 | loss = None 446 | if labels is not None: 447 | logits = decoder_outputs.logits if return_dict else decoder_outputs[0] 448 | loss_fct = CrossEntropyLoss() 449 | loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) 450 | 451 | if not return_dict: 452 | if loss is not None: 453 | return (loss,) + decoder_outputs + denoiser_outputs 454 | else: 455 | return decoder_outputs + denoiser_outputs 456 | 457 | return Seq2SeqLMOutput( 458 | loss=loss, 459 | logits=decoder_outputs.logits, 460 | past_key_values=decoder_outputs.past_key_values, 461 | 
decoder_hidden_states=decoder_outputs.hidden_states, 462 | decoder_attentions=decoder_outputs.attentions, 463 | cross_attentions=decoder_outputs.cross_attentions, 464 | encoder_last_hidden_state=denoiser_outputs.last_hidden_state, 465 | encoder_hidden_states=denoiser_outputs.hidden_states, 466 | encoder_attentions=denoiser_outputs.attentions, 467 | ) 468 | 469 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 470 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 471 | 472 | def prepare_inputs_for_generation( 473 | self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, denoiser_outputs=None, **kwargs 474 | ): 475 | decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) 476 | decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None 477 | input_dict = { 478 | "attention_mask": attention_mask, 479 | "decoder_attention_mask": decoder_attention_mask, 480 | "decoder_input_ids": decoder_inputs["input_ids"], 481 | "denoiser_outputs": denoiser_outputs, 482 | "past_key_values": decoder_inputs["past_key_values"], 483 | "use_cache": use_cache, 484 | } 485 | return input_dict 486 | 487 | def resize_token_embeddings(self, *args, **kwargs): 488 | raise NotImplementedError( 489 | "Resizing the embedding layers via the DenoiserDecoderModel directly is not supported. Please use the" 490 | " respective methods of the wrapped objects (model.denoiser.resize_token_embeddings(...) or" 491 | " model.decoder.resize_token_embeddings(...))" 492 | ) 493 | 494 | def _reorder_cache(self, past_key_values, beam_idx): 495 | # apply decoder cache reordering here 496 | return self.decoder._reorder_cache(past_key_values, beam_idx) -------------------------------------------------------------------------------- /src/denoisers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thedarkzeno/text-diffusion/eebb273c2f9fc7bb8fc355691f450c7ac97eabe0/src/denoisers/__init__.py -------------------------------------------------------------------------------- /src/denoisers/configuration_diffbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration""" 17 | from collections import OrderedDict 18 | from typing import Mapping 19 | 20 | from transformers.configuration_utils import PretrainedConfig 21 | from transformers.onnx import OnnxConfig 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | 29 | class DiffBertConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to 32 | instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a 33 | configuration with the defaults will yield a similar configuration to that of the BERT 34 | [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. 35 | 36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 37 | documentation from [`PretrainedConfig`] for more information. 38 | 39 | 40 | Args: 41 | vocab_size (`int`, *optional*, defaults to 30522): 42 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 43 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 44 | hidden_size (`int`, *optional*, defaults to 768): 45 | Dimensionality of the encoder layers and the pooler layer. 46 | num_hidden_layers (`int`, *optional*, defaults to 12): 47 | Number of hidden layers in the Transformer encoder. 48 | num_attention_heads (`int`, *optional*, defaults to 12): 49 | Number of attention heads for each attention layer in the Transformer encoder. 50 | intermediate_size (`int`, *optional*, defaults to 3072): 51 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 52 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 53 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 54 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 55 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 56 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 57 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 58 | The dropout ratio for the attention probabilities. 59 | max_position_embeddings (`int`, *optional*, defaults to 512): 60 | The maximum sequence length that this model might ever be used with. Typically set this to something large 61 | just in case (e.g., 512 or 1024 or 2048). 62 | type_vocab_size (`int`, *optional*, defaults to 2): 63 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 64 | initializer_range (`float`, *optional*, defaults to 0.02): 65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 66 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 67 | The epsilon used by the layer normalization layers. 68 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 69 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 70 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 71 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 
72 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 73 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 74 | is_decoder (`bool`, *optional*, defaults to `False`): 75 | Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 76 | use_cache (`bool`, *optional*, defaults to `True`): 77 | Whether or not the model should return the last key/values attentions (not used by all models). Only 78 | relevant if `config.is_decoder=True`. 79 | classifier_dropout (`float`, *optional*): 80 | The dropout ratio for the classification head. 81 | 82 | Examples: 83 | 84 | ```python 85 | >>> from transformers import BertConfig, BertModel 86 | 87 | >>> # Initializing a BERT bert-base-uncased style configuration 88 | >>> configuration = BertConfig() 89 | 90 | >>> # Initializing a model (with random weights) from the bert-base-uncased style configuration 91 | >>> model = BertModel(configuration) 92 | 93 | >>> # Accessing the model configuration 94 | >>> configuration = model.config 95 | ```""" 96 | 97 | model_type = "diffbert" 98 | 99 | def __init__( 100 | self, 101 | vocab_size=30522, 102 | hidden_size=768, 103 | num_hidden_layers=12, 104 | num_attention_heads=12, 105 | intermediate_size=3072, 106 | hidden_act="silu", 107 | rms_norm_eps=1e-05, 108 | timesteps=1000, 109 | hidden_dropout_prob=0.1, 110 | attention_probs_dropout_prob=0.1, 111 | max_position_embeddings=512, 112 | type_vocab_size=2, 113 | initializer_range=0.02, 114 | # layer_norm_eps=1e-12, 115 | pad_token_id=0, 116 | # position_embedding_type="absolute", 117 | use_cache=True, 118 | rotary_value=False, 119 | classifier_dropout=None, 120 | **kwargs, 121 | ): 122 | super().__init__(pad_token_id=pad_token_id, **kwargs) 123 | 124 | self.vocab_size = vocab_size 125 | self.hidden_size = hidden_size 126 | self.num_hidden_layers = num_hidden_layers 127 | self.num_attention_heads = num_attention_heads 128 | self.hidden_act = hidden_act 129 | self.timesteps=timesteps 130 | self.intermediate_size = intermediate_size 131 | self.hidden_dropout_prob = hidden_dropout_prob 132 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 133 | self.max_position_embeddings = max_position_embeddings 134 | self.type_vocab_size = type_vocab_size 135 | self.initializer_range = initializer_range 136 | self.rms_norm_eps=rms_norm_eps 137 | # self.layer_norm_eps = layer_norm_eps 138 | # self.position_embedding_type = position_embedding_type 139 | self.use_cache = use_cache 140 | self.rotary_value = rotary_value 141 | self.classifier_dropout = classifier_dropout 142 | 143 | 144 | class BertOnnxConfig(OnnxConfig): 145 | @property 146 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 147 | if self.task == "multiple-choice": 148 | dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} 149 | else: 150 | dynamic_axis = {0: "batch", 1: "sequence"} 151 | return OrderedDict( 152 | [ 153 | ("input_ids", dynamic_axis), 154 | ("attention_mask", dynamic_axis), 155 | ("token_type_ids", dynamic_axis), 156 | ] 157 | ) -------------------------------------------------------------------------------- /src/denoisers/configuration_diffllama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class DiffLlamaConfig(PretrainedConfig): 32 | 33 | 34 | model_type = "diffllama" 35 | keys_to_ignore_at_inference = ["past_key_values"] 36 | 37 | def __init__( 38 | self, 39 | vocab_size=32000, 40 | hidden_size=4096, 41 | intermediate_size=11008, 42 | num_hidden_layers=6, 43 | num_attention_heads=32, 44 | num_key_value_heads=None, 45 | hidden_act="silu", 46 | timesteps=2000, 47 | max_position_embeddings=2048, 48 | initializer_range=0.02, 49 | rms_norm_eps=1e-6, 50 | use_cache=True, 51 | pad_token_id=None, 52 | bos_token_id=1, 53 | eos_token_id=2, 54 | pretraining_tp=1, 55 | tie_word_embeddings=False, 56 | rope_theta=10000.0, 57 | rope_scaling=None, 58 | attention_bias=False, 59 | attention_dropout=0.0, 60 | **kwargs, 61 | ): 62 | self.vocab_size = vocab_size 63 | self.max_position_embeddings = max_position_embeddings 64 | self.hidden_size = hidden_size 65 | self.intermediate_size = intermediate_size 66 | self.num_hidden_layers = num_hidden_layers 67 | self.num_attention_heads = num_attention_heads 68 | 69 | # for backward compatibility 70 | if num_key_value_heads is None: 71 | num_key_value_heads = num_attention_heads 72 | 73 | 74 | self.timesteps=timesteps 75 | 76 | 77 | self.num_key_value_heads = num_key_value_heads 78 | self.hidden_act = hidden_act 79 | self.initializer_range = initializer_range 80 | self.rms_norm_eps = rms_norm_eps 81 | self.pretraining_tp = pretraining_tp 82 | self.use_cache = use_cache 83 | self.rope_theta = rope_theta 84 | self.rope_scaling = rope_scaling 85 | self._rope_scaling_validation() 86 | self.attention_bias = attention_bias 87 | self.attention_dropout = attention_dropout 88 | 89 | super().__init__( 90 | pad_token_id=pad_token_id, 91 | bos_token_id=bos_token_id, 92 | eos_token_id=eos_token_id, 93 | tie_word_embeddings=tie_word_embeddings, 94 | **kwargs, 95 | ) 96 | 97 | def _rope_scaling_validation(self): 98 | """ 99 | Validate the `rope_scaling` configuration. 
100 |         """ 101 |         if self.rope_scaling is None: 102 |             return 103 | 104 |         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 105 |             raise ValueError( 106 |                 "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " 107 |                 f"got {self.rope_scaling}" 108 |             ) 109 |         rope_scaling_type = self.rope_scaling.get("type", None) 110 |         rope_scaling_factor = self.rope_scaling.get("factor", None) 111 |         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 112 |             raise ValueError( 113 |                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 114 |             ) 115 |         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 116 |             raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") -------------------------------------------------------------------------------- /src/denoisers/configuration_diffmamba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | #     http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
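For reference, the `_rope_scaling_validation` helper above only accepts a two-field dictionary. A small illustrative sketch of what it accepts and rejects (the values are arbitrary, not recommendations):

```python
# Illustrative only: DiffLlamaConfig's rope_scaling validation in action.
from src.denoisers.configuration_diffllama import DiffLlamaConfig

# Accepted: exactly two fields, type in {"linear", "dynamic"}, factor a float > 1.0.
config = DiffLlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Rejected: unknown scaling type -> ValueError from _rope_scaling_validation.
try:
    DiffLlamaConfig(rope_scaling={"type": "ntk", "factor": 2.0})
except ValueError as err:
    print(err)
```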
20 | """ LLaMA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class DiffMambaConfig(PretrainedConfig): 32 | 33 | 34 | model_type = "diffmamba" 35 | keys_to_ignore_at_inference = ["past_key_values"] 36 | 37 | def __init__( 38 | self, 39 | vocab_size=32000, 40 | hidden_size=4096, 41 | intermediate_size=11008, 42 | num_hidden_layers=6, 43 | num_attention_heads=32, 44 | n_mamba_inversion=4, 45 | num_key_value_heads=None, 46 | hidden_act="silu", 47 | timesteps=2000, 48 | max_position_embeddings=2048, 49 | initializer_range=0.02, 50 | rms_norm_eps=1e-6, 51 | use_cache=True, 52 | pad_token_id=None, 53 | bos_token_id=1, 54 | eos_token_id=2, 55 | pretraining_tp=1, 56 | tie_word_embeddings=False, 57 | attention_bias=False, 58 | attention_dropout=0.0, 59 | cross_attention=False, 60 | **kwargs, 61 | ): 62 | self.vocab_size = vocab_size 63 | self.max_position_embeddings = max_position_embeddings 64 | self.hidden_size = hidden_size 65 | self.intermediate_size = intermediate_size 66 | self.num_hidden_layers = num_hidden_layers 67 | self.num_attention_heads = num_attention_heads 68 | 69 | self.n_mamba_inversion=n_mamba_inversion 70 | 71 | # for backward compatibility 72 | if num_key_value_heads is None: 73 | num_key_value_heads = num_attention_heads 74 | 75 | 76 | self.timesteps=timesteps 77 | 78 | 79 | self.num_key_value_heads = num_key_value_heads 80 | self.hidden_act = hidden_act 81 | self.initializer_range = initializer_range 82 | self.rms_norm_eps = rms_norm_eps 83 | self.pretraining_tp = pretraining_tp 84 | self.use_cache = use_cache 85 | 86 | 87 | self.attention_bias = attention_bias 88 | self.attention_dropout = attention_dropout 89 | 90 | self.cross_attention=cross_attention 91 | 92 | super().__init__( 93 | pad_token_id=pad_token_id, 94 | bos_token_id=bos_token_id, 95 | eos_token_id=eos_token_id, 96 | tie_word_embeddings=tie_word_embeddings, 97 | **kwargs, 98 | ) 99 | 100 | -------------------------------------------------------------------------------- /src/denoisers/modeling_diffmamba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
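For orientation, here is a minimal sketch of building a small `DiffMambaConfig` from the arguments defined above; the sizes are illustrative placeholders, not the values of any released checkpoint:

```python
# Illustrative DiffMambaConfig; all sizes are placeholders.
from src.denoisers.configuration_diffmamba import DiffMambaConfig

config = DiffMambaConfig(
    vocab_size=32000,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=20,
    num_attention_heads=12,   # used by the optional cross-attention blocks
    n_mamba_inversion=4,      # the model flips the sequence every 4 Mamba blocks
    timesteps=1200,           # number of diffusion timesteps the denoiser is built for
    cross_attention=False,    # True adds per-layer cross-attention over encoder states
)
print(config.model_type)      # -> "diffmamba"
```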
20 | """ PyTorch LLaMA model.""" 21 | import math 22 | import warnings 23 | from dataclasses import dataclass 24 | from typing import List, Optional, Tuple, Union 25 | from functools import partial 26 | 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | from torch import nn 31 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 32 | 33 | from transformers.activations import ACT2FN 34 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 35 | from transformers.modeling_utils import PreTrainedModel 36 | # from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 37 | from transformers.utils import ( 38 | ModelOutput, 39 | add_start_docstrings, 40 | logging, 41 | replace_return_docstrings, 42 | ) 43 | from transformers.utils.import_utils import is_torch_fx_available 44 | from src.denoisers.configuration_diffmamba import DiffMambaConfig 45 | from mamba_ssm.modules.mamba_simple import Mamba, Block 46 | 47 | try: 48 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 49 | except ImportError: 50 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 51 | 52 | 53 | 54 | logger = logging.get_logger(__name__) 55 | 56 | _CONFIG_FOR_DOC = "DiffMambaConfig" 57 | 58 | 59 | 60 | def timestep_embedding(timesteps, dim, max_period=10000): 61 | """ 62 | Create sinusoidal timestep embeddings. 63 | 64 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 65 | These may be fractional. 66 | :param dim: the dimension of the output. 67 | :param max_period: controls the minimum frequency of the embeddings. 68 | :return: an [N x dim] Tensor of positional embeddings. 
69 | """ 70 | half = dim // 2 71 | freqs = torch.exp( 72 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 73 | ).to(device=timesteps.device) 74 | args = timesteps[:, None].float() * freqs[None] 75 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 76 | if dim % 2: 77 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 78 | return embedding 79 | 80 | 81 | 82 | @dataclass 83 | class DiffusionLMOutput(ModelOutput): 84 | loss: Optional[torch.FloatTensor] = None 85 | logits: torch.FloatTensor = None 86 | last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None 87 | 88 | 89 | class DiffMambaRMSNorm(nn.Module): 90 | def __init__(self, hidden_size, eps=1e-6): 91 | """ 92 | DiffMambaRMSNorm is equivalent to T5LayerNorm 93 | """ 94 | super().__init__() 95 | self.weight = nn.Parameter(torch.ones(hidden_size)) 96 | self.variance_epsilon = eps 97 | 98 | def forward(self, hidden_states): 99 | input_dtype = hidden_states.dtype 100 | hidden_states = hidden_states.to(torch.float32) 101 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 102 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 103 | return self.weight * hidden_states.to(input_dtype) 104 | 105 | 106 | class DiffMambaMLP(nn.Module): 107 | def __init__(self, config): 108 | super().__init__() 109 | self.config = config 110 | self.hidden_size = config.hidden_size 111 | self.intermediate_size = config.intermediate_size 112 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 113 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 114 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 115 | self.act_fn = ACT2FN[config.hidden_act] 116 | 117 | def forward(self, x): 118 | if self.config.pretraining_tp > 1: 119 | slice = self.intermediate_size // self.config.pretraining_tp 120 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 121 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 122 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 123 | 124 | gate_proj = torch.cat( 125 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 126 | ) 127 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 128 | 129 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 130 | down_proj = [ 131 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 132 | ] 133 | down_proj = sum(down_proj) 134 | else: 135 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 136 | 137 | return down_proj 138 | 139 | 140 | class DiffMambaCrossAttention(nn.Module): 141 | def __init__(self, config): 142 | super().__init__() 143 | self.config=config 144 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 145 | raise ValueError( 146 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 147 | f"heads ({config.num_attention_heads})" 148 | ) 149 | 150 | self.num_attention_heads = config.num_attention_heads 151 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 152 | self.all_head_size = self.num_attention_heads * self.attention_head_size 153 | 154 | self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False) 155 | self.key = 
nn.Linear(config.hidden_size, self.all_head_size, bias=False) 156 | self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False) 157 | 158 | self.dropout = nn.Dropout(config.attention_dropout) 159 | 160 | 161 | def transpose_for_scores(self, x): 162 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 163 | x = x.view(*new_x_shape) 164 | return x.permute(0, 2, 1, 3) 165 | 166 | def forward( 167 | self, 168 | hidden_states, 169 | encoder_hidden_states, 170 | attention_mask=None, 171 | head_mask=None, 172 | encoder_attention_mask=None, 173 | ): 174 | mixed_query_layer = self.query(hidden_states) 175 | query_layer = self.transpose_for_scores(mixed_query_layer) 176 | 177 | 178 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 179 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 180 | attention_mask = encoder_attention_mask 181 | 182 | 183 | 184 | # Take the dot product between "query" and "key" to get the raw attention scores. 185 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 186 | 187 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 188 | if attention_mask is not None: 189 | # Apply the attention mask is (precomputed for all layers in RoFormerModel forward() function) 190 | attention_scores = attention_scores + attention_mask 191 | 192 | # Normalize the attention scores to probabilities. 193 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 194 | 195 | # This is actually dropping out entire tokens to attend to, which might 196 | # seem a bit unusual, but is taken from the original Transformer paper. 197 | attention_probs = self.dropout(attention_probs) 198 | 199 | # Mask heads if we want to 200 | if head_mask is not None: 201 | attention_probs = attention_probs * head_mask 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | 205 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 206 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 207 | context_layer = context_layer.view(*new_context_layer_shape) 208 | 209 | 210 | outputs = context_layer 211 | 212 | return outputs 213 | 214 | 215 | class DiffMambaPreTrainedModel(PreTrainedModel): 216 | config_class = DiffMambaConfig 217 | base_model_prefix = "model" 218 | supports_gradient_checkpointing = True 219 | _no_split_modules = ["DiffMambaDecoderLayer"] 220 | _skip_keys_device_placement = "past_key_values" 221 | 222 | def _init_weights(self, module): 223 | return 224 | 225 | def create_block( 226 | d_model, 227 | ssm_cfg=None, 228 | norm_epsilon=1e-5, 229 | rms_norm=False, 230 | residual_in_fp32=False, 231 | fused_add_norm=False, 232 | layer_idx=None, 233 | device=None, 234 | dtype=None, 235 | ): 236 | if ssm_cfg is None: 237 | ssm_cfg = {} 238 | factory_kwargs = {"device": device, "dtype": dtype} 239 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 240 | norm_cls = partial( 241 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 242 | ) 243 | block = Block( 244 | d_model, 245 | mixer_cls, 246 | norm_cls=norm_cls, 247 | fused_add_norm=fused_add_norm, 248 | residual_in_fp32=residual_in_fp32, 249 | ) 250 | block.layer_idx = layer_idx 251 | return block 252 | 253 | class DiffMambaModel(DiffMambaPreTrainedModel): 254 | """ 255 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DiffMambaDecoderLayer`] 256 | 257 | Args: 258 | config: DiffMambaConfig 259 | """ 260 | 261 | def __init__(self, config: DiffMambaConfig): 262 | super().__init__(config) 263 | self.padding_idx = config.pad_token_id 264 | self.vocab_size = config.vocab_size 265 | self.n_mamba_inversion = config.n_mamba_inversion 266 | 267 | self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 268 | self.input_proj = nn.Sequential( 269 | nn.Linear(config.hidden_size, config.hidden_size), 270 | nn.Tanh(), 271 | nn.Linear(config.hidden_size, config.hidden_size), 272 | ) 273 | self.time_embedding = nn.Sequential( 274 | nn.Linear(config.hidden_size, config.hidden_size//2, bias=False), 275 | nn.SiLU(), 276 | nn.Linear(config.hidden_size//2, config.hidden_size, bias=False), 277 | ) 278 | # self.layers = nn.ModuleList([DiffMambaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 279 | 280 | self.layers = nn.ModuleList( 281 | [ 282 | create_block( 283 | config.hidden_size, 284 | ssm_cfg=None, 285 | norm_epsilon=config.rms_norm_eps, 286 | rms_norm=True, 287 | residual_in_fp32=True, 288 | fused_add_norm=True, 289 | layer_idx=i, 290 | # **factory_kwargs, 291 | ) 292 | for i in range(config.num_hidden_layers) 293 | ] 294 | ) 295 | self.cross_attention = config.cross_attention 296 | print("cross_attention", config.cross_attention) 297 | if config.cross_attention == True: 298 | self.cross_attention_layers = nn.ModuleList( 299 | nn.ModuleList([DiffMambaCrossAttention(config) for _ in range(config.num_hidden_layers)]) 300 | ) 301 | self.cross_attentions_norms = nn.ModuleList( 302 | nn.ModuleList([DiffMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.num_hidden_layers)]) 303 | ) 304 | 305 | self.Last_layer_RMSNorm = DiffMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 306 | self.RMSNorm = DiffMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 307 | self.norm = DiffMambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 308 | 309 | self.gradient_checkpointing = False 310 | # Initialize weights and apply final processing 311 | # self.post_init() 312 | 313 | def get_input_embeddings(self): 314 | return self.embeddings 315 | 316 | def set_input_embeddings(self, value): 317 | self.embeddings = value 318 | 319 | 320 | def forward( 321 | self, 322 | input_embeds, 323 | timesteps, 324 | encoder_hidden_states: Optional[torch.Tensor] = None, 325 | position_ids: Optional[torch.LongTensor] = None, 326 | past_key_values: Optional[List[torch.FloatTensor]] = None, 327 | use_cache: Optional[bool] = None, 328 | output_attentions: Optional[bool] = None, 329 | output_hidden_states: Optional[bool] = None, 330 | return_dict: Optional[bool] = None, 331 | ) -> Union[Tuple, BaseModelOutputWithPast]: 332 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 333 | output_hidden_states = ( 334 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 335 | ) 336 | use_cache = use_cache if use_cache is not None else self.config.use_cache 337 | 338 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 339 | 340 | 341 | batch_size, seq_length = input_embeds.shape[:2] 342 | 343 | 344 | past_key_values_length = 0 345 | if past_key_values is not None: 346 | past_key_values_length = past_key_values[0][0].shape[2] 347 | 348 | if position_ids is None: 349 | device = input_embeds.device 350 | position_ids = torch.arange( 351 
| past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 352 | ) 353 | position_ids = position_ids.unsqueeze(0) 354 | 355 | 356 | # embed positions 357 | 358 | 359 | if self.gradient_checkpointing and self.training: 360 | if use_cache: 361 | logger.warning_once( 362 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 363 | ) 364 | use_cache = False 365 | 366 | 367 | input_embeds = self.input_proj(input_embeds) 368 | time_embeddings = timestep_embedding(timesteps, self.config.hidden_size).unsqueeze(1).repeat(1, seq_length, 1) 369 | time_embeddings = self.time_embedding(time_embeddings.to(input_embeds.dtype)) 370 | input_embeds += time_embeddings #+ position_embeddings 371 | input_embeds = self.RMSNorm(input_embeds) 372 | 373 | hidden_states = input_embeds 374 | # decoder layers 375 | all_hidden_states = () if output_hidden_states else None 376 | all_self_attns = () if output_attentions else None 377 | next_decoder_cache = () if use_cache else None 378 | 379 | residual = None 380 | if self.cross_attention: 381 | for i, (layer_forward, cross_attention, norm) in enumerate(zip(self.layers, self.cross_attention_layers, self.cross_attentions_norms)): 382 | hidden_states, residual = layer_forward( 383 | hidden_states, residual, inference_params=None 384 | ) 385 | 386 | residual = norm(hidden_states + residual).to(self.dtype) 387 | hidden_states = cross_attention(residual, encoder_hidden_states) 388 | 389 | if (i + 1) % self.n_mamba_inversion == 0 and (i+1) != len(self.layers): 390 | hidden_states = torch.flip(hidden_states, dims=[1]) 391 | residual = torch.flip(residual, dims=[1]) 392 | else: 393 | for i, layer_forward in enumerate(self.layers): 394 | hidden_states, residual = layer_forward( 395 | hidden_states, residual, inference_params=None 396 | ) 397 | 398 | if (i + 1) % self.n_mamba_inversion == 0 and (i+1) != len(self.layers): 399 | hidden_states = torch.flip(hidden_states, dims=[1]) 400 | residual = torch.flip(residual, dims=[1]) 401 | 402 | hidden_states = self.Last_layer_RMSNorm(hidden_states + residual).to(self.dtype) 403 | 404 | # add hidden states from the last decoder layer 405 | if output_hidden_states: 406 | all_hidden_states += (hidden_states,) 407 | 408 | next_cache = next_decoder_cache if use_cache else None 409 | if not return_dict: 410 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 411 | return BaseModelOutputWithPast( 412 | last_hidden_state=hidden_states, 413 | past_key_values=next_cache, 414 | hidden_states=all_hidden_states, 415 | attentions=all_self_attns, 416 | ) 417 | 418 | 419 | class DiffMambaForDiffusionLM(DiffMambaPreTrainedModel): 420 | # _tied_weights_keys = ["lm_head.weight"] 421 | 422 | def __init__(self, config): 423 | super().__init__(config) 424 | self.model = DiffMambaModel(config) 425 | self.vocab_size = config.vocab_size 426 | self.out_proj = nn.Sequential( 427 | nn.Linear(config.hidden_size, config.hidden_size, bias=False), 428 | nn.Tanh(), 429 | nn.Linear(config.hidden_size, config.hidden_size, bias=False) 430 | ) 431 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 432 | 433 | # Initialize weights and apply final processing 434 | # self.post_init() 435 | 436 | def get_input_embeddings(self): 437 | return self.model.embeddings 438 | 439 | def set_input_embeddings(self, value): 440 | self.model.embeddings = value 441 | 442 | def get_output_embeddings(self): 443 | return self.lm_head 444 | 
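To make the data flow above concrete, here is a rough sketch of one training-style pass through `DiffMambaForDiffusionLM`: clean token embeddings are noised with the DDPM scheduler's `add_noise`, and the denoiser is conditioned on the sampled timestep. This assumes `mamba_ssm` is installed, a CUDA device is available, and the checkpoint path below is a placeholder:

```python
# Sketch of a single noisy forward pass; the checkpoint path is a placeholder.
import torch
from transformers import AutoTokenizer
from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM
from src.schedulers.ddpm import DDPMScheduler

path = "models/diffmamba-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(path, subfolder="tokenizer")
scheduler = DDPMScheduler.from_pretrained(path, subfolder="scheduler")
model = DiffMambaForDiffusionLM.from_pretrained(path, subfolder="denoiser").to("cuda")

input_ids = tokenizer("a short example", return_tensors="pt").input_ids.to("cuda")

x0 = model.get_input_embeddings()(input_ids)          # clean word embeddings
noise = torch.randn_like(x0)
t = torch.randint(0, scheduler.config.num_train_timesteps,
                  (x0.shape[0],), device="cuda")
xt = scheduler.add_noise(x0, noise, t)                # noisy sample q(x_t | x_0)

out = model(input_embeds=xt, timesteps=t, labels=input_ids)  # predicts x_0 / logits
print(out.logits.shape, out.loss)
```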
445 | def set_output_embeddings(self, new_embeddings): 446 | self.lm_head = new_embeddings 447 | 448 | def set_decoder(self, decoder): 449 | self.model = decoder 450 | 451 | def get_decoder(self): 452 | return self.model 453 | 454 | def apply_embeddings(self, input_ids): 455 | return self.model.embeddings(input_ids) 456 | 457 | def apply_lm_head(self, latents, labels=None): 458 | logits = self.lm_head(latents) 459 | logits = logits.float() 460 | loss=None 461 | if labels is not None: 462 | loss_fct = CrossEntropyLoss() # -100 index = padding token 463 | loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) 464 | return DiffusionLMOutput( 465 | loss=loss, 466 | logits=logits, 467 | last_hidden_state=latents, 468 | ) 469 | 470 | 471 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 472 | def forward( 473 | self, 474 | input_embeds, 475 | timesteps, 476 | encoder_hidden_states: Optional[torch.Tensor] = None, 477 | # attention_mask: Optional[torch.Tensor] = None, 478 | position_ids: Optional[torch.LongTensor] = None, 479 | past_key_values: Optional[List[torch.FloatTensor]] = None, 480 | labels: Optional[torch.LongTensor] = None, 481 | use_cache: Optional[bool] = None, 482 | output_attentions: Optional[bool] = None, 483 | output_hidden_states: Optional[bool] = None, 484 | return_dict: Optional[bool] = None, 485 | ) -> Union[Tuple, CausalLMOutputWithPast]: 486 | r""" 487 | Args: 488 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 489 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 490 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 491 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 492 | 493 | Returns: 494 | 495 | Example: 496 | 497 | ```python 498 | >>> from transformers import AutoTokenizer, DiffMambaForCausalLM 499 | 500 | >>> model = DiffMambaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 501 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 502 | 503 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 504 | >>> inputs = tokenizer(prompt, return_tensors="pt") 505 | 506 | >>> # Generate 507 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 508 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 509 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
510 | ```""" 511 | 512 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 513 | output_hidden_states = ( 514 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 515 | ) 516 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 517 | 518 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 519 | outputs = self.model( 520 | input_embeds=input_embeds, 521 | timesteps=timesteps, 522 | encoder_hidden_states=encoder_hidden_states, 523 | # attention_mask=attention_mask, 524 | position_ids=position_ids, 525 | past_key_values=past_key_values, 526 | use_cache=use_cache, 527 | output_attentions=output_attentions, 528 | output_hidden_states=output_hidden_states, 529 | return_dict=return_dict, 530 | ) 531 | 532 | hidden_states = outputs[0] 533 | hidden_states = self.out_proj(hidden_states) 534 | if self.config.pretraining_tp > 1: 535 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 536 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 537 | logits = torch.cat(logits, dim=-1) 538 | else: 539 | logits = self.lm_head(hidden_states) 540 | logits = logits.float() 541 | 542 | loss = None 543 | if labels is not None: 544 | # # Shift so that tokens < n predict n 545 | # shift_logits = logits[..., :-1, :].contiguous() 546 | # shift_labels = labels[..., 1:].contiguous() 547 | # # Flatten the tokens 548 | # loss_fct = CrossEntropyLoss() 549 | # shift_logits = shift_logits.view(-1, self.config.vocab_size) 550 | # shift_labels = shift_labels.view(-1) 551 | # # Enable model parallelism 552 | # shift_labels = shift_labels.to(shift_logits.device) 553 | # loss = loss_fct(shift_logits, shift_labels) 554 | loss_fct = CrossEntropyLoss() # -100 index = padding token 555 | loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) 556 | 557 | if not return_dict: 558 | output = (logits,) + outputs[1:] 559 | return (loss,) + output if loss is not None else output 560 | 561 | return DiffusionLMOutput( 562 | loss=loss, 563 | logits=logits, 564 | last_hidden_state=hidden_states, 565 | 566 | ) 567 | 568 | def prepare_inputs_for_generation( 569 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 570 | ): 571 | if past_key_values is not None: 572 | past_length = past_key_values[0][0].shape[2] 573 | 574 | # Some generation methods already pass only the last input ID 575 | if input_ids.shape[1] > past_length: 576 | remove_prefix_length = past_length 577 | else: 578 | # Default to old behavior: keep only final ID 579 | remove_prefix_length = input_ids.shape[1] - 1 580 | 581 | input_ids = input_ids[:, remove_prefix_length:] 582 | 583 | position_ids = kwargs.get("position_ids", None) 584 | if attention_mask is not None and position_ids is None: 585 | # create position_ids on the fly for batch generation 586 | position_ids = attention_mask.long().cumsum(-1) - 1 587 | position_ids.masked_fill_(attention_mask == 0, 1) 588 | if past_key_values: 589 | position_ids = position_ids[:, -input_ids.shape[1] :] 590 | 591 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 592 | if inputs_embeds is not None and past_key_values is None: 593 | model_inputs = {"inputs_embeds": inputs_embeds} 594 | else: 595 | model_inputs = {"input_ids": input_ids} 596 | 597 | model_inputs.update( 598 | { 599 | 
"position_ids": position_ids, 600 | "past_key_values": past_key_values, 601 | "use_cache": kwargs.get("use_cache"), 602 | "attention_mask": attention_mask, 603 | } 604 | ) 605 | return model_inputs 606 | 607 | @staticmethod 608 | def _reorder_cache(past_key_values, beam_idx): 609 | reordered_past = () 610 | for layer_past in past_key_values: 611 | reordered_past += ( 612 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 613 | ) 614 | return reordered_past 615 | 616 | 617 | -------------------------------------------------------------------------------- /src/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["ddpm", "euler_ancestral_discrete"] -------------------------------------------------------------------------------- /src/schedulers/ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | import math 18 | from dataclasses import dataclass 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | 24 | from diffusers.configuration_utils import ConfigMixin, register_to_config 25 | from diffusers.utils import BaseOutput 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from diffusers import SchedulerMixin 28 | 29 | 30 | @dataclass 31 | class DDPMSchedulerOutput(BaseOutput): 32 | """ 33 | Output class for the scheduler's `step` function output. 34 | 35 | Args: 36 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 37 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 38 | denoising loop. 39 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 40 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 41 | `pred_original_sample` can be used to preview progress or for guidance. 42 | """ 43 | 44 | prev_sample: torch.FloatTensor 45 | pred_original_sample: Optional[torch.FloatTensor] = None 46 | 47 | 48 | def betas_for_alpha_bar( 49 | num_diffusion_timesteps, 50 | max_beta=0.999, 51 | alpha_transform_type="cosine", 52 | ): 53 | """ 54 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 55 | (1-beta) over time from t = [0,1]. 56 | 57 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 58 | to that part of the diffusion process. 59 | 60 | 61 | Args: 62 | num_diffusion_timesteps (`int`): the number of betas to produce. 
63 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 64 | prevent singularities. 65 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 66 | Choose from `cosine` or `exp` 67 | 68 | Returns: 69 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 70 | """ 71 | if alpha_transform_type == "cosine": 72 | 73 | def alpha_bar_fn(t): 74 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 75 | 76 | elif alpha_transform_type == "exp": 77 | 78 | def alpha_bar_fn(t): 79 | return math.exp(t * -12.0) 80 | 81 | elif alpha_transform_type == "sqrt": 82 | def alpha_bar_fn(t): 83 | return 1 - np.sqrt(t + 0.0001) 84 | else: 85 | raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") 86 | 87 | betas = [] 88 | for i in range(num_diffusion_timesteps): 89 | t1 = i / num_diffusion_timesteps 90 | t2 = (i + 1) / num_diffusion_timesteps 91 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 92 | return torch.tensor(betas, dtype=torch.float32) 93 | 94 | class DDPMScheduler(SchedulerMixin, ConfigMixin): 95 | """ 96 | `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. 97 | 98 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 99 | methods the library implements for all schedulers such as loading and saving. 100 | 101 | Args: 102 | num_train_timesteps (`int`, defaults to 1000): 103 | The number of diffusion steps to train the model. 104 | beta_start (`float`, defaults to 0.0001): 105 | The starting `beta` value of inference. 106 | beta_end (`float`, defaults to 0.02): 107 | The final `beta` value. 108 | beta_schedule (`str`, defaults to `"linear"`): 109 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 110 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 111 | variance_type (`str`, defaults to `"fixed_small"`): 112 | Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, 113 | `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. 114 | clip_sample (`bool`, defaults to `True`): 115 | Clip the predicted sample for numerical stability. 116 | clip_sample_range (`float`, defaults to 1.0): 117 | The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. 118 | prediction_type (`str`, defaults to `epsilon`, *optional*): 119 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), 120 | `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen 121 | Video](https://imagen.research.google/video/paper.pdf) paper). 122 | thresholding (`bool`, defaults to `False`): 123 | Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such 124 | as Stable Diffusion. 125 | dynamic_thresholding_ratio (`float`, defaults to 0.995): 126 | The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 127 | sample_max_value (`float`, defaults to 1.0): 128 | The threshold value for dynamic thresholding. Valid only when `thresholding=True`. 129 | timestep_spacing (`str`, defaults to `"leading"`): 130 | The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and 131 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 132 | steps_offset (`int`, defaults to 0): 133 | An offset added to the inference steps. You can use a combination of `offset=1` and 134 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable 135 | Diffusion. 136 | """ 137 | 138 | # _compatibles = [e.name for e in KarrasDiffusionSchedulers] 139 | order = 1 140 | 141 | @register_to_config 142 | def __init__( 143 | self, 144 | num_train_timesteps: int = 1000, 145 | beta_start: float = 0.0001, 146 | beta_end: float = 0.02, 147 | beta_schedule: str = "sqrt", 148 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 149 | variance_type: str = "fixed_small", 150 | clip_sample: bool = True, 151 | prediction_type: str = "sample", 152 | thresholding: bool = False, 153 | dynamic_thresholding_ratio: float = 0.995, 154 | clip_sample_range: float = 1.0, 155 | sample_max_value: float = 1.0, 156 | timestep_spacing: str = "leading", 157 | steps_offset: int = 0, 158 | ): 159 | if trained_betas is not None: 160 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 161 | elif beta_schedule == "linear": 162 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 163 | elif beta_schedule == "sqrt": 164 | self.betas = betas_for_alpha_bar( 165 | num_train_timesteps, 166 | alpha_transform_type="sqrt", 167 | ) 168 | elif beta_schedule == "scaled_linear": 169 | # this schedule is very specific to the latent diffusion model. 170 | self.betas = ( 171 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 172 | ) 173 | elif beta_schedule == "squaredcos_cap_v2": 174 | # Glide cosine schedule 175 | self.betas = betas_for_alpha_bar(num_train_timesteps) 176 | elif beta_schedule == "sigmoid": 177 | # GeoDiff sigmoid schedule 178 | betas = torch.linspace(-6, 6, num_train_timesteps) 179 | self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 180 | else: 181 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 182 | self.alphas = 1.0 - self.betas 183 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 184 | self.one = torch.tensor(1.0) 185 | 186 | # standard deviation of the initial noise distribution 187 | self.init_noise_sigma = 1.0 188 | 189 | # setable values 190 | self.custom_timesteps = False 191 | self.num_inference_steps = None 192 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) 193 | 194 | self.variance_type = variance_type 195 | 196 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: 197 | """ 198 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 199 | current timestep. 200 | 201 | Args: 202 | sample (`torch.FloatTensor`): 203 | The input sample. 204 | timestep (`int`, *optional*): 205 | The current timestep in the diffusion chain. 206 | 207 | Returns: 208 | `torch.FloatTensor`: 209 | A scaled input sample. 210 | """ 211 | return sample 212 | 213 | def set_timesteps( 214 | self, 215 | num_inference_steps: Optional[int] = None, 216 | device: Union[str, torch.device] = None, 217 | timesteps: Optional[List[int]] = None, 218 | ): 219 | """ 220 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
221 | 222 | Args: 223 | num_inference_steps (`int`): 224 | The number of diffusion steps used when generating samples with a pre-trained model. If used, 225 | `timesteps` must be `None`. 226 | device (`str` or `torch.device`, *optional*): 227 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 228 | timesteps (`List[int]`, *optional*): 229 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default 230 | timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, 231 | `num_inference_steps` must be `None`. 232 | 233 | """ 234 | if num_inference_steps is not None and timesteps is not None: 235 | raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") 236 | 237 | if timesteps is not None: 238 | for i in range(1, len(timesteps)): 239 | if timesteps[i] >= timesteps[i - 1]: 240 | raise ValueError("`custom_timesteps` must be in descending order.") 241 | 242 | if timesteps[0] >= self.config.num_train_timesteps: 243 | raise ValueError( 244 | f"`timesteps` must start before `self.config.train_timesteps`:" 245 | f" {self.config.num_train_timesteps}." 246 | ) 247 | 248 | timesteps = np.array(timesteps, dtype=np.int64) 249 | self.custom_timesteps = True 250 | else: 251 | if num_inference_steps > self.config.num_train_timesteps: 252 | raise ValueError( 253 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 254 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 255 | f" maximal {self.config.num_train_timesteps} timesteps." 256 | ) 257 | 258 | self.num_inference_steps = num_inference_steps 259 | self.custom_timesteps = False 260 | 261 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 262 | if self.config.timestep_spacing == "linspace": 263 | timesteps = ( 264 | np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) 265 | .round()[::-1] 266 | .copy() 267 | .astype(np.int64) 268 | ) 269 | elif self.config.timestep_spacing == "leading": 270 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 271 | # creates integer timesteps by multiplying by ratio 272 | # casting to int to avoid issues when num_inference_step is power of 3 273 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 274 | timesteps += self.config.steps_offset 275 | elif self.config.timestep_spacing == "trailing": 276 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 277 | # creates integer timesteps by multiplying by ratio 278 | # casting to int to avoid issues when num_inference_step is power of 3 279 | timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) 280 | timesteps -= 1 281 | else: 282 | raise ValueError( 283 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
284 | ) 285 | 286 | self.timesteps = torch.from_numpy(timesteps).to(device) 287 | 288 | def _get_variance(self, t, predicted_variance=None, variance_type=None): 289 | prev_t = self.previous_timestep(t) 290 | 291 | alpha_prod_t = self.alphas_cumprod[t] 292 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one 293 | current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev 294 | 295 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) 296 | # and sample from it to get previous sample 297 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample 298 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t 299 | 300 | # we always take the log of variance, so clamp it to ensure it's not 0 301 | variance = torch.clamp(variance, min=1e-20) 302 | 303 | if variance_type is None: 304 | variance_type = self.config.variance_type 305 | 306 | # hacks - were probably added for training stability 307 | if variance_type == "fixed_small": 308 | variance = variance 309 | # for rl-diffuser https://arxiv.org/abs/2205.09991 310 | elif variance_type == "fixed_small_log": 311 | variance = torch.log(variance) 312 | variance = torch.exp(0.5 * variance) 313 | elif variance_type == "fixed_large": 314 | variance = current_beta_t 315 | elif variance_type == "fixed_large_log": 316 | # Glide max_log 317 | variance = torch.log(current_beta_t) 318 | elif variance_type == "learned": 319 | return predicted_variance 320 | elif variance_type == "learned_range": 321 | min_log = torch.log(variance) 322 | max_log = torch.log(current_beta_t) 323 | frac = (predicted_variance + 1) / 2 324 | variance = frac * max_log + (1 - frac) * min_log 325 | 326 | return variance 327 | 328 | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: 329 | """ 330 | "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the 331 | prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by 332 | s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing 333 | pixels from saturation at each step. We find that dynamic thresholding results in significantly better 334 | photorealism as well as better image-text alignment, especially when using very large guidance weights." 
335 | 336 | https://arxiv.org/abs/2205.11487 337 | """ 338 | dtype = sample.dtype 339 | batch_size, channels, *remaining_dims = sample.shape 340 | 341 | if dtype not in (torch.float32, torch.float64): 342 | sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half 343 | 344 | # Flatten sample for doing quantile calculation along each image 345 | sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) 346 | 347 | abs_sample = sample.abs() # "a certain percentile absolute pixel value" 348 | 349 | s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) 350 | s = torch.clamp( 351 | s, min=1, max=self.config.sample_max_value 352 | ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] 353 | s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 354 | sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" 355 | 356 | sample = sample.reshape(batch_size, channels, *remaining_dims) 357 | sample = sample.to(dtype) 358 | 359 | return sample 360 | 361 | def step( 362 | self, 363 | model_output: torch.FloatTensor, 364 | timestep: int, 365 | sample: torch.FloatTensor, 366 | generator=None, 367 | return_dict: bool = True, 368 | ) -> Union[DDPMSchedulerOutput, Tuple]: 369 | """ 370 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 371 | process from the learned model outputs (most often the predicted noise). 372 | 373 | Args: 374 | model_output (`torch.FloatTensor`): 375 | The direct output from learned diffusion model. 376 | timestep (`float`): 377 | The current discrete timestep in the diffusion chain. 378 | sample (`torch.FloatTensor`): 379 | A current instance of a sample created by the diffusion process. 380 | generator (`torch.Generator`, *optional*): 381 | A random number generator. 382 | return_dict (`bool`, *optional*, defaults to `True`): 383 | Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. 384 | 385 | Returns: 386 | [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: 387 | If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a 388 | tuple is returned where the first element is the sample tensor. 389 | 390 | """ 391 | t = timestep 392 | 393 | prev_t = self.previous_timestep(t) 394 | 395 | if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: 396 | model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) 397 | else: 398 | predicted_variance = None 399 | 400 | # 1. compute alphas, betas 401 | alpha_prod_t = self.alphas_cumprod[t] 402 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one 403 | beta_prod_t = 1 - alpha_prod_t 404 | beta_prod_t_prev = 1 - alpha_prod_t_prev 405 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 406 | current_beta_t = 1 - current_alpha_t 407 | 408 | # 2. 
compute predicted original sample from predicted noise also called 409 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 410 | if self.config.prediction_type == "epsilon": 411 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 412 | elif self.config.prediction_type == "sample": 413 | pred_original_sample = model_output 414 | elif self.config.prediction_type == "v_prediction": 415 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 416 | else: 417 | raise ValueError( 418 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 419 | " `v_prediction` for the DDPMScheduler." 420 | ) 421 | 422 | # 3. Clip or threshold "predicted x_0" 423 | if self.config.thresholding: 424 | pred_original_sample = self._threshold_sample(pred_original_sample) 425 | elif self.config.clip_sample: 426 | pred_original_sample = pred_original_sample.clamp( 427 | -self.config.clip_sample_range, self.config.clip_sample_range 428 | ) 429 | 430 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 431 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 432 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t 433 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 434 | 435 | # 5. Compute predicted previous sample µ_t 436 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 437 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample 438 | 439 | # 6. Add noise 440 | variance = 0 441 | if t > 0: 442 | device = model_output.device 443 | variance_noise = randn_tensor( 444 | model_output.shape, generator=generator, device=device, dtype=model_output.dtype 445 | ) 446 | if self.variance_type == "fixed_small_log": 447 | variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise 448 | elif self.variance_type == "learned_range": 449 | variance = self._get_variance(t, predicted_variance=predicted_variance) 450 | variance = torch.exp(0.5 * variance) * variance_noise 451 | else: 452 | variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise 453 | 454 | pred_prev_sample = pred_prev_sample + variance 455 | 456 | if not return_dict: 457 | return (pred_prev_sample,) 458 | 459 | return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) 460 | 461 | def add_noise( 462 | self, 463 | original_samples: torch.FloatTensor, 464 | noise: torch.FloatTensor, 465 | timesteps: torch.IntTensor, 466 | ) -> torch.FloatTensor: 467 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 468 | alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 469 | timesteps = timesteps.to(original_samples.device) 470 | 471 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 472 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 473 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 474 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 475 | 476 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 477 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 478 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 479 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 480 | 481 | noisy_samples = 
sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 482 | return noisy_samples 483 | 484 | def get_velocity( 485 | self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor 486 | ) -> torch.FloatTensor: 487 | # Make sure alphas_cumprod and timestep have same device and dtype as sample 488 | alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) 489 | timesteps = timesteps.to(sample.device) 490 | 491 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 492 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 493 | while len(sqrt_alpha_prod.shape) < len(sample.shape): 494 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 495 | 496 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 497 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 498 | while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): 499 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 500 | 501 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 502 | return velocity 503 | 504 | def __len__(self): 505 | return self.config.num_train_timesteps 506 | 507 | def previous_timestep(self, timestep): 508 | if self.custom_timesteps: 509 | index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] 510 | if index == self.timesteps.shape[0] - 1: 511 | prev_t = torch.tensor(-1) 512 | else: 513 | prev_t = self.timesteps[index + 1] 514 | else: 515 | num_inference_steps = ( 516 | self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps 517 | ) 518 | prev_t = timestep - self.config.num_train_timesteps // num_inference_steps 519 | 520 | return prev_t 521 | -------------------------------------------------------------------------------- /src/schedulers/euler_ancestral_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from dataclasses import dataclass 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers import SchedulerMixin 26 | 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete 33 | class EulerAncestralDiscreteSchedulerOutput(BaseOutput): 34 | """ 35 | Output class for the scheduler's `step` function output. 36 | 37 | Args: 38 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the
40 |             denoising loop.
41 |         pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
42 |             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43 |             `pred_original_sample` can be used to preview progress or for guidance.
44 |     """
45 | 
46 |     prev_sample: torch.FloatTensor
47 |     pred_original_sample: Optional[torch.FloatTensor] = None
48 | 
49 | 
50 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
51 | def betas_for_alpha_bar(
52 |     num_diffusion_timesteps,
53 |     max_beta=0.999,
54 |     alpha_transform_type="cosine",
55 | ):
56 |     """
57 |     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
58 |     (1-beta) over time from t = [0,1].
59 | 
60 |     Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
61 |     to that part of the diffusion process.
62 | 
63 | 
64 |     Args:
65 |         num_diffusion_timesteps (`int`): the number of betas to produce.
66 |         max_beta (`float`): the maximum beta to use; use values lower than 1 to
67 |                      prevent singularities.
68 |         alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
69 |                      Choose from `cosine`, `exp`, or `sqrt`
70 | 
71 |     Returns:
72 |         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
73 |     """
74 |     if alpha_transform_type == "cosine":
75 | 
76 |         def alpha_bar_fn(t):
77 |             return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
78 | 
79 |     elif alpha_transform_type == "exp":
80 | 
81 |         def alpha_bar_fn(t):
82 |             return math.exp(t * -12.0)
83 | 
84 |     elif alpha_transform_type == "sqrt":
85 |         def alpha_bar_fn(t):
86 |             return 1 - np.sqrt(t + 0.0001)
87 | 
88 |     else:
89 |         raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
90 | 
91 |     betas = []
92 |     for i in range(num_diffusion_timesteps):
93 |         t1 = i / num_diffusion_timesteps
94 |         t2 = (i + 1) / num_diffusion_timesteps
95 |         betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
96 |     return torch.tensor(betas, dtype=torch.float32)
97 | 
98 | def betas_for_alpha_bar_sqrt(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
99 |     """
100 |     Create a beta schedule that discretizes the given alpha_t_bar function,
101 |     which defines the cumulative product of (1-beta) over time from t = [0,1].
102 | 
103 |     :param num_diffusion_timesteps: the number of betas to produce.
104 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
105 |                       produces the cumulative product of (1-beta) up to that
106 |                       part of the diffusion process.
107 |     :param max_beta: the maximum beta to use; use values lower than 1 to
108 |                      prevent singularities.
109 |     """
110 |     betas = []
111 |     for i in range(num_diffusion_timesteps):
112 |         t1 = i / num_diffusion_timesteps
113 |         t2 = (i + 1) / num_diffusion_timesteps
114 |         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
115 |     return np.array(betas)
116 | 
117 | class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
118 |     """
119 |     Ancestral sampling with Euler method steps.
120 | 
121 |     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
122 |     methods the library implements for all schedulers such as loading and saving.
123 | 
124 |     Args:
125 |         num_train_timesteps (`int`, defaults to 1000):
126 |             The number of diffusion steps to train the model.
127 |         beta_start (`float`, defaults to 0.0001):
128 |             The starting `beta` value of inference.
129 |         beta_end (`float`, defaults to 0.02):
130 |             The final `beta` value.
131 |         beta_schedule (`str`, defaults to `"sqrt"`):
132 |             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
133 |             `linear`, `scaled_linear`, `sqrt`, or `squaredcos_cap_v2`.
134 |         trained_betas (`np.ndarray`, *optional*):
135 |             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
136 |         prediction_type (`str`, defaults to `sample`, *optional*):
137 |             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
138 |             `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
139 |             Video](https://imagen.research.google/video/paper.pdf) paper).
140 |         timestep_spacing (`str`, defaults to `"linspace"`):
141 |             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
142 |             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
143 |         steps_offset (`int`, defaults to 0):
144 |             An offset added to the inference steps. You can use a combination of `offset=1` and
145 |             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
146 |             Diffusion.
147 |     """
148 | 
149 |     # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
150 |     order = 1
151 | 
152 |     @register_to_config
153 |     def __init__(
154 |         self,
155 |         num_train_timesteps: int = 1000,
156 |         beta_start: float = 0.0001,
157 |         beta_end: float = 0.02,
158 |         beta_schedule: str = "sqrt",
159 |         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
160 |         prediction_type: str = "sample",
161 |         timestep_spacing: str = "linspace",
162 |         steps_offset: int = 0,
163 |     ):
164 |         if trained_betas is not None:
165 |             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
166 |         elif beta_schedule == "linear":
167 |             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
168 |         elif beta_schedule == "sqrt":
169 |             self.betas = betas_for_alpha_bar(
170 |                 num_train_timesteps,
171 |                 alpha_transform_type="sqrt",
172 |             )
173 |         elif beta_schedule == "scaled_linear":
174 |             # this schedule is very specific to the latent diffusion model.
175 |             self.betas = (
176 |                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
177 |             )
178 |         elif beta_schedule == "squaredcos_cap_v2":
179 |             # Glide cosine schedule
180 |             self.betas = betas_for_alpha_bar(num_train_timesteps)
181 |         else:
182 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
183 | 
184 |         self.alphas = 1.0 - self.betas
185 |         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
186 | 
187 |         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
188 |         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
189 |         self.sigmas = torch.from_numpy(sigmas)
190 | 
191 |         # settable values
192 |         self.num_inference_steps = None
193 |         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
194 |         self.timesteps = torch.from_numpy(timesteps)
195 |         self.is_scale_input_called = False
196 | 
197 |         self._step_index = None
198 | 
199 |     @property
200 |     def init_noise_sigma(self):
201 |         # standard deviation of the initial noise distribution
202 |         if self.config.timestep_spacing in ["linspace", "trailing"]:
203 |             return self.sigmas.max()
204 | 
205 |         return (self.sigmas.max() ** 2 + 1) ** 0.5
206 | 
207 |     @property
208 |     def step_index(self):
209 |         """
210 |         The index counter for the current timestep. It increases by 1 after each scheduler step.
211 |         """
212 |         return self._step_index
213 | 
214 |     def scale_model_input(
215 |         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
216 |     ) -> torch.FloatTensor:
217 |         """
218 |         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
219 |         current timestep. Scales the denoising model input by dividing it by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
220 | 
221 |         Args:
222 |             sample (`torch.FloatTensor`):
223 |                 The input sample.
224 |             timestep (`int`, *optional*):
225 |                 The current timestep in the diffusion chain.
226 | 
227 |         Returns:
228 |             `torch.FloatTensor`:
229 |                 A scaled input sample.
230 |         """
231 | 
232 |         if self.step_index is None:
233 |             self._init_step_index(timestep)
234 | 
235 |         sigma = self.sigmas[self.step_index]
236 |         sample = sample / ((sigma**2 + 1) ** 0.5)
237 |         self.is_scale_input_called = True
238 |         return sample
239 | 
240 |     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
241 |         """
242 |         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
243 | 
244 |         Args:
245 |             num_inference_steps (`int`):
246 |                 The number of diffusion steps used when generating samples with a pre-trained model.
247 |             device (`str` or `torch.device`, *optional*):
248 |                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
249 |         """
250 |         self.num_inference_steps = num_inference_steps
251 | 
252 |         # "linspace", "leading", "trailing" corresponds to annotation of Table 2.
of https://arxiv.org/abs/2305.08891 253 | if self.config.timestep_spacing == "linspace": 254 | timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ 255 | ::-1 256 | ].copy() 257 | elif self.config.timestep_spacing == "leading": 258 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 259 | # creates integer timesteps by multiplying by ratio 260 | # casting to int to avoid issues when num_inference_step is power of 3 261 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) 262 | timesteps += self.config.steps_offset 263 | elif self.config.timestep_spacing == "trailing": 264 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps 265 | # creates integer timesteps by multiplying by ratio 266 | # casting to int to avoid issues when num_inference_step is power of 3 267 | timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) 268 | timesteps -= 1 269 | else: 270 | raise ValueError( 271 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 272 | ) 273 | 274 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) 275 | sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) 276 | sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) 277 | self.sigmas = torch.from_numpy(sigmas).to(device=device) 278 | 279 | self.timesteps = torch.from_numpy(timesteps).to(device=device) 280 | self._step_index = None 281 | 282 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 283 | def _init_step_index(self, timestep): 284 | if isinstance(timestep, torch.Tensor): 285 | timestep = timestep.to(self.timesteps.device) 286 | 287 | index_candidates = (self.timesteps == timestep).nonzero() 288 | 289 | # The sigma index that is taken for the **very** first `step` 290 | # is always the second index (or the last index if there is only 1) 291 | # This way we can ensure we don't accidentally skip a sigma in 292 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 293 | if len(index_candidates) > 1: 294 | step_index = index_candidates[1] 295 | else: 296 | step_index = index_candidates[0] 297 | 298 | self._step_index = step_index.item() 299 | 300 | def step( 301 | self, 302 | model_output: torch.FloatTensor, 303 | timestep: Union[float, torch.FloatTensor], 304 | sample: torch.FloatTensor, 305 | generator: Optional[torch.Generator] = None, 306 | return_dict: bool = True, 307 | ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: 308 | """ 309 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 310 | process from the learned model outputs (most often the predicted noise). 311 | 312 | Args: 313 | model_output (`torch.FloatTensor`): 314 | The direct output from learned diffusion model. 315 | timestep (`float`): 316 | The current discrete timestep in the diffusion chain. 317 | sample (`torch.FloatTensor`): 318 | A current instance of a sample created by the diffusion process. 319 | generator (`torch.Generator`, *optional*): 320 | A random number generator. 321 | return_dict (`bool`): 322 | Whether or not to return a 323 | [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. 
324 | 
325 |         Returns:
326 |             [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
327 |                 If return_dict is `True`,
328 |                 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
329 |                 otherwise a tuple is returned where the first element is the sample tensor.
330 | 
331 |         """
332 | 
333 |         if (
334 |             isinstance(timestep, int)
335 |             or isinstance(timestep, torch.IntTensor)
336 |             or isinstance(timestep, torch.LongTensor)
337 |         ):
338 |             raise ValueError(
339 |                 (
340 |                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
341 |                     " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
342 |                     " one of the `scheduler.timesteps` as a timestep."
343 |                 ),
344 |             )
345 | 
346 |         if not self.is_scale_input_called:
347 |             logger.warning(
348 |                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
349 |                 "See `StableDiffusionPipeline` for a usage example."
350 |             )
351 | 
352 |         if self.step_index is None:
353 |             self._init_step_index(timestep)
354 | 
355 |         sigma = self.sigmas[self.step_index]
356 | 
357 |         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
358 |         if self.config.prediction_type == "epsilon":
359 |             pred_original_sample = sample - sigma * model_output
360 |         elif self.config.prediction_type == "sample":
361 |             pred_original_sample = model_output
362 |         elif self.config.prediction_type == "v_prediction":
363 |             # * c_out + input * c_skip
364 |             pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
365 |         # elif self.config.prediction_type == "sample":
366 |         #     raise NotImplementedError("prediction_type not implemented yet: sample")
367 |         else:
368 |             raise ValueError(
369 |                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
370 |             )
371 | 
372 |         sigma_from = self.sigmas[self.step_index]
373 |         sigma_to = self.sigmas[self.step_index + 1]
374 |         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
375 |         sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
376 | 
377 |         # 2.
Convert to an ODE derivative 378 | derivative = (sample - pred_original_sample) / sigma 379 | 380 | dt = sigma_down - sigma 381 | 382 | prev_sample = sample + derivative * dt 383 | 384 | device = model_output.device 385 | noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) 386 | 387 | prev_sample = prev_sample + noise * sigma_up 388 | 389 | # upon completion increase step index by one 390 | self._step_index += 1 391 | 392 | if not return_dict: 393 | return (prev_sample,) 394 | 395 | return EulerAncestralDiscreteSchedulerOutput( 396 | prev_sample=prev_sample, pred_original_sample=pred_original_sample 397 | ) 398 | 399 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 400 | def add_noise( 401 | self, 402 | original_samples: torch.FloatTensor, 403 | noise: torch.FloatTensor, 404 | timesteps: torch.FloatTensor, 405 | ) -> torch.FloatTensor: 406 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 407 | sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) 408 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 409 | # mps does not support float64 410 | schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) 411 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 412 | else: 413 | schedule_timesteps = self.timesteps.to(original_samples.device) 414 | timesteps = timesteps.to(original_samples.device) 415 | 416 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 417 | 418 | sigma = sigmas[step_indices].flatten() 419 | while len(sigma.shape) < len(original_samples.shape): 420 | sigma = sigma.unsqueeze(-1) 421 | 422 | noisy_samples = original_samples + noise * sigma 423 | return noisy_samples 424 | 425 | def __len__(self): 426 | return self.config.num_train_timesteps 427 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python3 train_sample.py --pretrained_model_name_or_path="models/diffmamba-mini-sample" --dataset_name="Gustavosta/Stable-Diffusion-Prompts" --output_dir="models/diffmamba-mini-sample-trained" -------------------------------------------------------------------------------- /train_denoise_decoder.sh: -------------------------------------------------------------------------------- 1 | python3 ./scripts/train_denoise_decoder.py \ 2 | --pretrained_model_name_or_path="models/diffmamba-mini-sample-trained" \ 3 | --dataset_name="Gustavosta/Stable-Diffusion-Prompts" \ 4 | --output_dir="models/diffmamba-mini-sample-trained" \ 5 | --text_column Prompt \ 6 | --train_batch_size 8 \ 7 | --context_length 64 -------------------------------------------------------------------------------- /train_seq2seq.sh: -------------------------------------------------------------------------------- 1 | python3 train_seq2seq_completion.py \ 2 | --pretrained_model_name_or_path="models/diffMamba-mini-sample" \ 3 | --train_file="../roberta/data/brwac-train.txt" \ 4 | --output_dir="Gustavosta/Stable-Diffusion-Prompts" \ 5 | --text_column text \ 6 | --train_batch_size 8 \ 7 | --context_length 128 --------------------------------------------------------------------------------
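Usage note (not a file in this repository): a minimal sketch of how the custom EulerAncestralDiscreteScheduler defined in src/schedulers/euler_ancestral_discrete.py is driven during generation. The `denoiser` function and the tensor shape below are illustrative placeholders rather than code from this repo; only the scheduler calls (`set_timesteps`, `init_noise_sigma`, `scale_model_input`, `step`) follow the API shown above, using the scheduler's own defaults ("sqrt" betas, "sample" prediction).

    import torch
    from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler

    def denoiser(noisy_embeds, timestep):
        # Placeholder for a trained denoiser that predicts the clean embeddings
        # (prediction_type="sample"); here it simply returns zeros of the same shape.
        return torch.zeros_like(noisy_embeds)

    scheduler = EulerAncestralDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=50, device="cpu")

    # Arbitrary (batch, sequence, hidden) shape for illustration; start from noise
    # scaled to the initial sigma.
    latents = torch.randn(1, 64, 768) * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(latents, t)
        model_output = denoiser(model_input, t)
        latents = scheduler.step(model_output, t, latents).prev_sample

The loop mirrors the order the scheduler expects: `scale_model_input` before each `step`, with `step` returning `prev_sample` as the next iteration's input.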