├── .gitignore ├── .gitkeep ├── README.md ├── config ├── __init__.py └── default.yaml ├── data ├── .gitkeep ├── test │ └── keep ├── textures │ └── keep └── train │ └── keep ├── dataset ├── __init__.py └── mnist_color_texture_dataset.py ├── model ├── __init__.py ├── dalle.py ├── decoder.py ├── discrete_vae.py ├── encoder.py ├── mingpt.py └── quantizer.py ├── requirements.txt └── tools ├── __init__.py ├── generate_image.py ├── infer_dvae.py ├── train_dalle.py └── train_dvae.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore all image files 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | 6 | # Ignore pycharm and system files 7 | .DS_Store 8 | *.idea 9 | __pycache__ 10 | *.zip 11 | 12 | # Ignore dataset files 13 | *.csv 14 | 15 | # Ignore checkpoints 16 | *.pth 17 | 18 | # Ignore pickle files 19 | *.pkl -------------------------------------------------------------------------------- /.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/.gitkeep -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DallE Implementation in PyTorch with generation using minGPT 2 | ======== 3 | 4 | This repository implements DallE-1 [Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092) on a synthetic dataset of colored MNIST digits on texture or solid backgrounds. 5 | 6 | 7 | ## DallE Tutorial Video 8 | 9 | DallE Tutorial 11 | 12 | 13 | ## Sample from dataset 14 | 15 | 16 | 17 | 18 | 19 | Large parts of the implementation are taken from the two repositories below: 20 | 1. GPT from - https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 21 | 2. Parts of the DallE implementation from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch . 22 | 23 | I have kept only a minimal version of DallE, which gives decent results (on this dataset) and is easy to play around with. If you are looking for a much more efficient and complete implementation, please use the repo above. 24 | 25 | ## Data preparation 26 | For setting up the mnist dataset: 27 | Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation 28 | 29 | Download the quarter-resolution RGB texture data from the [ALOT Homepage](https://aloi.science.uva.nl/public_alot/) 30 | 31 | If you are facing issues, use `curl` 32 | 33 | `curl -O https://aloi.science.uva.nl/public_alot/tars/alot_png4.tar` 34 | 35 | 36 | In case you want to train on a higher resolution, you can download that as well, but you would have to create new train.json and test.json files. 37 | The rest of the code should work fine as long as you create valid json files; a sketch of the expected entry format is shown below. 
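For reference, an entry in train.json / test.json only needs the fields that `dataset/mnist_color_texture_dataset.py` reads: `digit_image`, `digit_name`, `digit_color`, plus either `background_color` (solid background) or `texture_image` and `texture_name` (texture background). The sketch below writes a minimal valid file; the paths, digit, color and texture values are purely illustrative.

```
# Minimal sketch of the expected train.json format. Field names are taken from the
# dataset class; every path/value below is an illustrative example, relative to data/.
import json

entries = [
    {   # a digit rendered on a solid colored background
        "digit_image": "train/images/0/00001.png",
        "digit_name": "0",
        "digit_color": "blue",          # matplotlib 'tab:' color name
        "background_color": "olive",
    },
    {   # a digit rendered on an ALOT texture crop
        "digit_image": "train/images/1/00002.png",
        "digit_name": "1",
        "digit_color": "cyan",
        "texture_image": "textures/1/example.png",
        "texture_name": "cracker",
    },
]

with open("data/train.json", "w") as f:
    json.dump(entries, f, indent=2)
```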
38 | 39 | Download train.json and test.json from [Drive](https://drive.google.com/drive/folders/1DSpNaM6hk8VNFVKHs-VK97AlP_8ynRKC?usp=sharing) 40 | Verify that the data directory has the following structure after the textures download 41 | ``` 42 | DallE/data/textures/{texture_number} 43 | *.png 44 | DallE/data/train/images/{0/1/.../9} 45 | *.png 46 | DallE/data/test/images/{0/1/.../9} 47 | *.png 48 | DallE/data/train.json 49 | DallE/data/test.json 50 | ``` 51 | 52 | # Quickstart 53 | * Create a new conda environment with python 3.8, then run the commands below 54 | * ```git clone https://github.com/explainingai-code/DallE.git``` 55 | * ```cd DallE``` 56 | * ```pip install -r requirements.txt``` 57 | * For training/inference of the discrete VAE and GPT, use the commands below, passing the desired configuration file as the config argument in case you want to play with it. 58 | * ```python -m tools.train_dvae``` for training the discrete VAE 59 | * ```python -m tools.infer_dvae``` for generating reconstructions 60 | * ```python -m tools.train_dalle``` for training the minimal version of DallE 61 | * ```python -m tools.generate_image``` for using the trained DallE to generate images 62 | 63 | ## Configuration 64 | * ```config/default.yaml``` - Allows you to configure the different components of the discrete VAE as well as DallE and experiment with modifications 65 | 66 | 67 | ## Output 68 | Outputs will be saved according to the configuration present in the yaml files. 69 | 70 | For every run, a folder named after the ```task_name``` key in the config will be created, and ```output_train_dir``` will be created inside it. 71 | 72 | During training of the Discrete VAE and DallE the following outputs will be saved 73 | * Best model checkpoints (DVAE and DallE) in the ```task_name``` directory 74 | 75 | During inference the following outputs will be saved 76 | * Reconstructions for a sample of the test set in ```task_name/dvae_reconstructions.png``` 77 | * Images generated by DallE, saved as individual pngs in the ```task_name``` directory 78 | 79 | 80 | ## Sample Output for DallE 81 | 82 | Running DiscreteVAE with the default config should give you the reconstructions below (left - input | right - reconstruction) 83 | 84 | 85 | 86 | Sample generation output after 40 epochs with 4 layers, 512 hidden dimensions and 8 attention heads 87 | 88 | Generate 0 in blue and solid background of olive 89 | 90 | Generate 1 in cyan and texture background of cracker 91 | 92 | Generate 6 in pink and texture background of stones 93 | 94 | Generate 8 in red and texture background of lego 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | ## Citations 104 | 105 | ``` 106 | @misc{ramesh2021zeroshot, 107 | title={Zero-Shot Text-to-Image Generation}, 108 | author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever}, 109 | year={2021}, 110 | eprint={2102.12092}, 111 | archivePrefix={arXiv}, 112 | primaryClass={cs.CV} 113 | } 114 | ``` 115 | 116 | 117 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/config/__init__.py -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: 'data' 3 | image_size: 112 4 | drop_background_prob: 0.1 5 | 
drop_color_prob: 0.1 6 | 7 | model_params: 8 | vae_num_embeddings: 2048 9 | vae_embedding_dim : 1024 10 | # Will be 112/8 as we have 3 downsamples 11 | dalle_image_size: 14 12 | 13 | gpt_config: 14 | embd_pdrop: 0.1 15 | resid_pdrop: 0.1 16 | attn_pdrop: 0.1 17 | n_layer: 4 18 | n_head: 8 19 | n_embd: 512 20 | 21 | train_params: 22 | task_name: 'default' 23 | batch_size: 64 24 | dalle_batch_size: 256 25 | num_epochs: 40 26 | num_epochs_dalle: 50 27 | dalle_image_loss: 10 28 | kl_weight: 0 29 | lr: 0.001 30 | crit: 'l1' 31 | seed: 1111 32 | save_vae_training_image: True 33 | vae_ckpt_name: 'vae_ckpt.pth' 34 | dalle_ckpt_name: 'dalle_ckpt.pth' 35 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/.gitkeep -------------------------------------------------------------------------------- /data/test/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/test/keep -------------------------------------------------------------------------------- /data/textures/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/textures/keep -------------------------------------------------------------------------------- /data/train/keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/train/keep -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/mnist_color_texture_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import random 5 | import matplotlib.colors as mcolors 6 | import torch 7 | import json 8 | from tqdm import tqdm 9 | from torch.utils.data.dataset import Dataset 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | 13 | def get_square_crop(image): 14 | h,w = image.shape[:2] 15 | if h > w: 16 | return image[(h - w)//2:-(h - w)//2, :, :] 17 | else: 18 | return image[:, (w - h) // 2:-(w - h) // 2, :] 19 | 20 | 21 | class MnistVisualLanguageDataset(Dataset): 22 | r""" 23 | Minimal visual language dataset class which auto generates fixed format caption 24 | for each dataset point 25 | """ 26 | def __init__(self, split, config): 27 | self.split = split 28 | self.db_root = config['root_dir'] 29 | self.im_size = config['image_size'] 30 | 31 | # Probability of randomly dropping background info 32 | self.drop_background_info_prob = config['drop_background_prob'] 33 | # Probability for dropping font color and background color 34 | self.drop_font_color_info_prob = config['drop_color_prob'] 35 | 36 | # Auto generated caption formats 37 | self.generation_text_format_tokens = ' generate image of {} in 
{} and a {} background of {} ' 38 | self.generation_text_format_tokens_drop_bg = ' generate image of {} in {} ' 39 | self.generation_text_format_tokens_drop_color = ' generate image of {} ' 40 | 41 | # Validate right amount of padding and ensure all are same length 42 | assert (len(self.generation_text_format_tokens.split(' ')) == len(self.generation_text_format_tokens_drop_bg.split(' ')) 43 | == len(self.generation_text_format_tokens_drop_color.split(' '))) 44 | self.max_token_len = len(self.generation_text_format_tokens.split(' ')) 45 | self.visual_language_db = json.load(open(os.path.join(self.db_root, self.split + '.json'))) 46 | 47 | self.vocab_idx_to_word, self.vocab_word_to_idx = self.build_vocab() 48 | 49 | def build_vocab(self): 50 | r""" 51 | Method to get dictionary of word to indexes and 52 | indexes to word to be used for tokenizing 53 | and for generation purposes 54 | :return: 55 | """ 56 | vocab_generation_tokens = [word for word in self.generation_text_format_tokens.split(' ') if word != '{}'] 57 | vocab_generation_tokens += [word for word in self.generation_text_format_tokens_drop_bg.split(' ') if word != '{}'] 58 | vocab_generation_tokens += [word for word in self.generation_text_format_tokens_drop_color.split(' ') if word != '{}'] 59 | vocab_preset = set(vocab_generation_tokens) 60 | for db_entry in self.visual_language_db: 61 | if 'texture_name' in db_entry: 62 | vocab_preset.add(db_entry['texture_name']) 63 | vocab_preset.add('texture') 64 | if 'background_color' in db_entry: 65 | vocab_preset.add(db_entry['background_color']) 66 | vocab_preset.add('solid') 67 | vocab_preset.add(db_entry['digit_name']) 68 | vocab_preset.add(db_entry['digit_color']) 69 | vocab_tokens = sorted(list(vocab_preset)) 70 | vocab_word_to_idx = { k:v for (k,v) in zip(vocab_tokens, range(len(vocab_tokens)))} 71 | vocab_idx_to_word = { v:k for (k,v) in zip(vocab_tokens, range(len(vocab_tokens)))} 72 | return vocab_idx_to_word, vocab_word_to_idx 73 | 74 | def __len__(self): 75 | return len(self.visual_language_db) 76 | 77 | def __getitem__(self, index): 78 | entry = self.visual_language_db[index] 79 | background_type = 'solid' if 'background_color' in entry else 'texture' 80 | drop_type = random.choices(['no_drop','drop_bg','drop_bg_and_color'], 81 | weights=[1-self.drop_background_info_prob-self.drop_font_color_info_prob, 82 | self.drop_background_info_prob, 83 | self.drop_font_color_info_prob])[0] 84 | if drop_type == 'no_drop': 85 | text = self.generation_text_format_tokens.format(entry['digit_name'], 86 | entry['digit_color'], 87 | background_type, 88 | entry['background_color'] if background_type == 'solid' else 89 | entry['texture_name']) 90 | elif drop_type == 'drop_bg': 91 | text = self.generation_text_format_tokens_drop_bg.format(entry['digit_name'], 92 | entry['digit_color']) 93 | else: 94 | text = self.generation_text_format_tokens_drop_color.format(entry['digit_name']) 95 | 96 | text_tokens = [self.vocab_word_to_idx[word] for word in text.split(' ')] 97 | text_tokens = torch.LongTensor(text_tokens) 98 | 99 | digit_im = cv2.imread(os.path.join(self.db_root, entry['digit_image'])) 100 | digit_im = cv2.cvtColor(digit_im, cv2.COLOR_BGR2RGB) 101 | digit_im = cv2.resize(digit_im, (self.im_size, self.im_size)) 102 | 103 | # Discretize mnist images to be either 0 or 1 104 | digit_im[digit_im > 50] = 255 105 | digit_im[digit_im <= 50] = 0 106 | mask_val = (digit_im > 0).astype(np.float32) 107 | color_scale = mcolors.hex2color('tab:{}'.format(entry['digit_color'])) 108 | digit_im = 
np.concatenate((digit_im[:, :, 0][..., None] * color_scale[0], 109 | digit_im[:, :, 1][..., None] * color_scale[1], 110 | digit_im[:, :, 2][..., None] * color_scale[2]), axis=-1) 111 | if background_type == 'texture': 112 | im = cv2.imread(os.path.join(self.db_root, entry['texture_image'])) 113 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 114 | im = get_square_crop(im) 115 | im = cv2.resize(im, (self.im_size, self.im_size)) 116 | else: 117 | im = np.ones((self.im_size, self.im_size, 3)) 118 | back_color_scale = mcolors.hex2color('tab:{}'.format(entry['background_color'])) 119 | im[:, :, 0] = 255*back_color_scale[0] 120 | im[:, :, 1] = 255*back_color_scale[1] 121 | im[:, :, 2] = 255*back_color_scale[2] 122 | out_im = mask_val * digit_im + (1 - mask_val) * im 123 | im_tensor = torch.from_numpy(out_im).permute((2, 0, 1)) 124 | im_tensor = 2 * (im_tensor / 255) - 1 125 | return { 126 | "image" : im_tensor, 127 | "text_tokens" : text_tokens, 128 | "text" : text, 129 | } 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/model/__init__.py -------------------------------------------------------------------------------- /model/dalle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.mingpt import GPT, DallEGPTConfig 4 | 5 | class DallE(nn.Module): 6 | r""" 7 | Class handling the logic for DallE 8 | Calls the vae and passes the text and image tokens 9 | together with target to gpt 10 | """ 11 | def __init__(self, vae, num_words, image_size, max_text_len, image_vocab_size, gpt_config): 12 | super(DallE, self).__init__() 13 | self.vae = vae 14 | 15 | # Text Vocab size 16 | self.num_words = num_words 17 | # Number of Image tokens 18 | self.image_size = image_size 19 | # Maximum Text Sequence Length 20 | self.max_text_len = max_text_len 21 | 22 | # Image tokens vocabulary size (num_of_embeddings) 23 | image_vocab_size = image_vocab_size 24 | 25 | # Length of largest sequence so that we tell gpt 26 | # to have that as the context size 27 | max_sequence_len = max_text_len + image_size*image_size 28 | config = DallEGPTConfig(text_vocab_size=num_words, 29 | image_vocab_size=image_vocab_size, 30 | max_sequence_len=max_sequence_len, 31 | im_size=image_size, 32 | **gpt_config) 33 | self.gpt = GPT(config) 34 | 35 | def forward(self, im, text): 36 | # Call Discrete vae 37 | image_tokens = self.vae.get_codebook_indices(im).reshape(im.size(0), -1) 38 | 39 | # Shift the target image tokens as image tokens + text vocab size 40 | # Last fc layer will predict 0 to (num_words + num_embeddings) output probabilities 41 | # We will formulate the target such first num_words-1 are text token probabilities 42 | # and num_words to num_words+num_embeddings are image token probabilities 43 | target_image_tokens = image_tokens + self.num_words 44 | labels = None 45 | 46 | if self.training: 47 | # Pass one position shifted tokens as targets only in training 48 | labels = torch.cat((text[:, 1:], target_image_tokens), dim=1) 49 | # Loss of text and Loss image separately so that we can get better images 50 | logits, loss_text, loss_image = self.gpt(image_tokens, text, targets=labels) 51 | return logits, loss_text, loss_image 52 | 53 | 
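A quick way to see how `DallE` wires the frozen discrete VAE and the GPT together is to run one forward pass on random tensors. The sketch below mirrors the setup in `tools/train_dalle.py` and `config/default.yaml`; the text vocabulary size and caption length are illustrative stand-ins for whatever the real dataset produces.

```
# Hedged sketch: one DallE training-mode forward pass on random data.
# num_words and max_text_len are made-up values; everything else follows default.yaml.
import torch
from model.discrete_vae import DiscreteVAE
from model.dalle import DallE

num_words, max_text_len = 40, 14
gpt_config = {'embd_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1,
              'n_layer': 4, 'n_head': 8, 'n_embd': 512}

# The vae stays frozen while DallE/GPT is trained
vae = DiscreteVAE(num_embeddings=2048, embedding_dim=1024)
vae.eval().requires_grad_(False)

# A freshly built nn.Module is in training mode, so forward() builds the shifted
# targets and returns both losses
model = DallE(vae=vae, num_words=num_words, image_size=14,
              max_text_len=max_text_len, image_vocab_size=2048,
              gpt_config=gpt_config)

im = torch.randn(2, 3, 112, 112)                        # 112x112 image -> 14x14 = 196 image tokens
text = torch.randint(0, num_words, (2, max_text_len))   # tokenized caption
logits, loss_text, loss_image = model(im, text)

# Same weighting as train_dalle.py with dalle_image_loss = 10 from default.yaml
loss = (loss_text * 1 + loss_image * 10) / (1 + 10)
print(logits.shape, loss.item())
```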
-------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Decoder(nn.Module): 6 | r""" 7 | Decoder with couple of residual blocks 8 | followed by conv transpose relu layers 9 | """ 10 | def __init__(self, embedding_dim): 11 | super(Decoder, self).__init__() 12 | 13 | self.decoder_layers = nn.ModuleList([ 14 | nn.ConvTranspose2d(64, 64, 4, 2, 1), 15 | nn.ReLU(), 16 | nn.ConvTranspose2d(64, 32, 4, 2, 1), 17 | nn.ReLU(), 18 | nn.ConvTranspose2d(32, 16, 4, 2, 1), 19 | nn.ReLU(), 20 | nn.Conv2d(16, 3, 1), 21 | nn.Tanh() 22 | ]) 23 | 24 | self.residuals = nn.ModuleList([ 25 | nn.Sequential( 26 | nn.Conv2d(64, 64, 3, padding=1), 27 | nn.ReLU(), 28 | nn.Conv2d(64, 64, 3, padding=1), 29 | nn.ReLU()), 30 | nn.Sequential( 31 | nn.Conv2d(64, 64, 3, padding=1), 32 | nn.ReLU(), 33 | nn.Conv2d(64, 64, 3, padding=1), 34 | nn.ReLU()) 35 | ]) 36 | 37 | self.decoder_quant_conv = nn.Conv2d(embedding_dim, 64, 1) 38 | 39 | 40 | def forward(self, x): 41 | out = self.decoder_quant_conv(x) 42 | for layer in self.residuals: 43 | out = layer(out)+out 44 | for idx, layer in enumerate(self.decoder_layers): 45 | out = layer(out) 46 | return out 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | import torch 52 | import torch.nn as nn 53 | import yaml 54 | decoder = Decoder() 55 | 56 | out = decoder(torch.rand((3, 64, 14, 14))) 57 | print(out.shape) 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /model/discrete_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.encoder import Encoder 4 | from model.decoder import Decoder 5 | from model.quantizer import Quantizer 6 | 7 | class DiscreteVAE(nn.Module): 8 | def __init__(self, num_embeddings=1024, embedding_dim=512): 9 | super(DiscreteVAE, self).__init__() 10 | self.encoder = Encoder(num_embeddings=num_embeddings) 11 | self.quantizer = Quantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 12 | self.decoder = Decoder(embedding_dim=embedding_dim) 13 | 14 | 15 | def get_codebook_indices(self, x): 16 | # x.shape = B,C,H,W 17 | enc_logits = self.encoder(x) 18 | # enc_logits.shape = B,C,H,W 19 | indices = torch.argmax(enc_logits, dim=1) 20 | return indices 21 | 22 | def decode_from_codebook_indices(self, indices): 23 | quantized_indices = self.quantizer.quantize_indices(indices) 24 | return self.decoder(quantized_indices) 25 | 26 | def forward(self, x): 27 | enc = self.encoder(x) 28 | quant_output, kl, logits, log_qy = self.quantizer(enc) 29 | out = self.decoder(quant_output) 30 | return out, kl, log_qy 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | r""" 7 | Encoder is conv relu blocks 8 | followed by couple of residual blocks. 
9 | Last 1x1 conv converts to logits with 10 | num_embeddings as output size 11 | """ 12 | def __init__(self, num_embeddings): 13 | super(Encoder, self).__init__() 14 | # Encoder is just Conv relu blocks 15 | self.encoder_layers = nn.ModuleList([ 16 | nn.Conv2d(3, 32, 4, 2, 1), 17 | nn.ReLU(), 18 | nn.Conv2d(32, 64, 4, 2, 1), 19 | nn.ReLU(), 20 | nn.Conv2d(64, 64, 4, 2, 1), 21 | nn.ReLU(), 22 | ]) 23 | self.residuals = nn.ModuleList([ 24 | nn.Sequential( 25 | nn.Conv2d(64, 64, 3, padding = 1), 26 | nn.ReLU(), 27 | nn.Conv2d(64, 64, 3, padding = 1), 28 | nn.ReLU()), 29 | nn.Sequential( 30 | nn.Conv2d(64, 64, 3, padding = 1), 31 | nn.ReLU(), 32 | nn.Conv2d(64, 64, 3, padding=1), 33 | nn.ReLU()) 34 | ]) 35 | self.encoder_quant_conv = nn.Sequential( 36 | nn.Conv2d(64, num_embeddings, 1)) 37 | 38 | 39 | def forward(self, x): 40 | out = x 41 | for layer in self.encoder_layers: 42 | out = layer(out) 43 | for layer in self.residuals: 44 | out = out + layer(out) 45 | out = self.encoder_quant_conv(out) 46 | return out -------------------------------------------------------------------------------- /model/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT model: 3 | - the initial stem consists of a combination of token encoding and a positional encoding 4 | - the meat of it is a uniform sequence of Transformer blocks 5 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 6 | - all blocks feed into a central residual pathway similar to resnets 7 | - the final decoder is a linear projection into a vanilla Softmax classifier 8 | """ 9 | 10 | import math 11 | import logging 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class DallEGPTConfig: 20 | r""" 21 | Minimal DallE config holding all fields requored for 22 | training gpt for both text and iamge tokens 23 | """ 24 | def __init__(self, text_vocab_size, 25 | image_vocab_size, 26 | max_sequence_len, im_size, **kwargs): 27 | self.text_vocab_size = text_vocab_size 28 | self.image_vocab_size = image_vocab_size 29 | # Fixing block size to maximum sequence length we have seen 30 | self.block_size = max_sequence_len 31 | self.im_size = im_size 32 | self.num_text_tokens = max_sequence_len - im_size*im_size 33 | for k,v in kwargs.items(): 34 | setattr(self, k, v) 35 | 36 | 37 | class CausalSelfAttention(nn.Module): 38 | """ 39 | A vanilla multi-head masked self-attention layer with a projection at the end. 40 | It is possible to use torch.nn.MultiheadAttention here but I am including an 41 | explicit implementation here to show that there is nothing too scary here. 
42 | """ 43 | 44 | def __init__(self, config): 45 | super().__init__() 46 | assert config.n_embd % config.n_head == 0 47 | # key, query, value projections for all heads 48 | self.key = nn.Linear(config.n_embd, config.n_embd) 49 | self.query = nn.Linear(config.n_embd, config.n_embd) 50 | self.value = nn.Linear(config.n_embd, config.n_embd) 51 | # regularization 52 | self.attn_drop = nn.Dropout(config.attn_pdrop) 53 | self.resid_drop = nn.Dropout(config.resid_pdrop) 54 | # output projection 55 | self.proj = nn.Linear(config.n_embd, config.n_embd) 56 | # causal mask to ensure that attention is only applied to the left in the input sequence 57 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 58 | .view(1, 1, config.block_size, config.block_size)) 59 | self.n_head = config.n_head 60 | 61 | def forward(self, x, layer_past=None): 62 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 63 | 64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 65 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 66 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 67 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | 69 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 70 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 71 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 72 | 73 | att = F.softmax(att, dim=-1) 74 | att = self.attn_drop(att) 75 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 76 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 77 | 78 | # output projection 79 | y = self.resid_drop(self.proj(y)) 80 | return y 81 | 82 | class Block(nn.Module): 83 | """ an unassuming Transformer block """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.ln1 = nn.LayerNorm(config.n_embd) 88 | self.ln2 = nn.LayerNorm(config.n_embd) 89 | self.attn = CausalSelfAttention(config) 90 | self.mlp = nn.Sequential( 91 | nn.Linear(config.n_embd, 4 * config.n_embd), 92 | nn.GELU(), 93 | nn.Linear(4 * config.n_embd, config.n_embd), 94 | nn.Dropout(config.resid_pdrop), 95 | ) 96 | 97 | def forward(self, x): 98 | y = self.attn(self.ln1(x)) 99 | x = x + y 100 | x = x + self.mlp(self.ln2(x)) 101 | return x 102 | 103 | class GPT(nn.Module): 104 | """ the full GPT language model, with a context size of block_size """ 105 | 106 | def __init__(self, config): 107 | super().__init__() 108 | 109 | # input embedding stem 110 | self.text_tok_emb = nn.Embedding(config.text_vocab_size, config.n_embd) 111 | self.image_tok_emb = nn.Embedding(config.image_vocab_size, config.n_embd) 112 | 113 | self.text_pos_emb = nn.Parameter(torch.zeros(1, config.num_text_tokens, config.n_embd)) 114 | self.image_pos_emb = nn.Parameter(torch.zeros(1, config.im_size ** 2, config.n_embd)) 115 | self.drop = nn.Dropout(config.embd_pdrop) 116 | # transformer 117 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 118 | # decoder head 119 | self.ln_f = nn.LayerNorm(config.n_embd) 120 | self.head = nn.Linear(config.n_embd, config.text_vocab_size + config.image_vocab_size, bias=False) 121 | self.config = config 122 | self.block_size = config.block_size 123 | self.apply(self._init_weights) 124 | 125 | logger.info("number of parameters: 
%e", sum(p.numel() for p in self.parameters())) 126 | 127 | def get_block_size(self): 128 | return self.block_size 129 | 130 | def _init_weights(self, module): 131 | if isinstance(module, (nn.Linear, nn.Embedding)): 132 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 133 | if isinstance(module, nn.Linear) and module.bias is not None: 134 | torch.nn.init.zeros_(module.bias) 135 | elif isinstance(module, nn.LayerNorm): 136 | torch.nn.init.zeros_(module.bias) 137 | torch.nn.init.ones_(module.weight) 138 | elif isinstance(module, GPT): 139 | torch.nn.init.normal_(module.text_pos_emb, mean=0.0, std=0.02) 140 | # Simply do the same thing for image as for text 141 | torch.nn.init.normal_(module.image_pos_emb, mean=0.0, std=0.02) 142 | 143 | 144 | def forward(self, image_tokens, text_tokens, targets=None,): 145 | b, im_t = image_tokens.size() 146 | b, text_t = text_tokens.size() 147 | assert im_t + text_t <= self.block_size, "Cannot forward, model block size is exhausted." 148 | 149 | text_emb = self.text_tok_emb(text_tokens) 150 | text_pos = self.text_pos_emb[:, :text_t, :] 151 | text_token_embeddings = self.drop(text_emb + text_pos) 152 | x = text_token_embeddings 153 | 154 | # Add image tokens for input sequence if needed. 155 | # Won't be needed for first pixel generation 156 | if im_t > 0: 157 | image_emb = self.image_tok_emb(image_tokens) 158 | image_pos = self.image_pos_emb[:, :im_t, :] 159 | image_token_embeddings = self.drop(image_emb + image_pos) 160 | x = torch.cat([x, image_token_embeddings], dim=1) 161 | 162 | x = self.blocks(x) 163 | x = self.ln_f(x) 164 | logits = self.head(x) 165 | 166 | # if we are given some desired targets also calculate the loss 167 | # Separate text and image loss 168 | loss_text = None 169 | loss_image = None 170 | if targets is not None: 171 | logits = logits[:, :-1, :] 172 | 173 | # Separate text and image token loss computation 174 | text_logits = logits[:, :text_t - 1, :].permute((0, 2, 1)) 175 | image_logits = logits[:, text_t - 1:, :].permute((0, 2, 1)) 176 | 177 | # For now just mask logits of image tokens for text targets 178 | # And mask out text tokens logits for iamge targets 179 | # Dont want gpt to gain points by simply decreasing scores for indexes of the other type 180 | # And anyway at inference you would always sample image token when generating image 181 | text_logits[:, self.config.text_vocab_size:, :] = -torch.finfo(logits.dtype).max 182 | image_logits[:, :self.config.text_vocab_size, :] = -torch.finfo(logits.dtype).max 183 | loss_text = F.cross_entropy(text_logits, targets[:, :text_t-1]) 184 | loss_image = F.cross_entropy(image_logits, targets[:, text_t-1:]) 185 | return logits, loss_text, loss_image -------------------------------------------------------------------------------- /model/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum, rearrange 4 | 5 | 6 | class Quantizer(nn.Module): 7 | def __init__(self, num_embeddings, embedding_dim): 8 | super(Quantizer, self).__init__() 9 | 10 | self.num_embeddings = num_embeddings 11 | self.embedding = nn.Embedding(self.num_embeddings, embedding_dim) 12 | 13 | def forward(self, x): 14 | B, C, H, W = x.shape 15 | one_hot = torch.nn.functional.gumbel_softmax(x, tau=0.9, dim=1, hard=False) 16 | sampled = einsum(one_hot, self.embedding.weight, 'b n h w, n d -> b d h w') 17 | 18 | # Compute kl loss 19 | logits = rearrange(x, 'b n h w -> b (h w) n') 20 | log_qy = 
torch.nn.functional.log_softmax(logits, dim=-1) 21 | log_uniform = torch.log(torch.tensor([1. / self.num_embeddings], device=torch.device(x.device))) 22 | kl_div = torch.nn.functional.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True) 23 | return sampled, kl_div, logits, log_qy 24 | 25 | def quantize_indices(self, indices): 26 | return einsum(indices, self.embedding.weight, 'b n h w, n d -> b d h w') 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | matplotlib==3.7.2 3 | numpy==1.23.5 4 | opencv_python==4.8.0.74 5 | PyYAML==6.0 6 | torch==1.11.0 7 | torchvision==0.12.0 8 | tqdm==4.65.0 9 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/tools/__init__.py -------------------------------------------------------------------------------- /tools/generate_image.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import os 6 | import torchvision 7 | import numpy as np 8 | from einops import rearrange 9 | from tqdm import tqdm 10 | from model.discrete_vae import DiscreteVAE 11 | from model.dalle import DallE 12 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset 13 | from torchvision.utils import make_grid 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def infer(args): 19 | ######## Read the config file ####### 20 | with open(args.config_path, 'r') as file: 21 | try: 22 | config = yaml.safe_load(file) 23 | except yaml.YAMLError as exc: 24 | print(exc) 25 | print(config) 26 | 27 | ######## Set the desired seed value ####### 28 | # Ignoring the fixed seed value 29 | seed = np.random.randint(0, 1000) 30 | torch.manual_seed(seed) 31 | np.random.seed(seed) 32 | random.seed(seed) 33 | if device == 'cuda': 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | if not os.path.exists(config['train_params']['task_name']): 37 | os.mkdir(config['train_params']['task_name']) 38 | 39 | # Create db to fetch the configuration values like vocab size (should do something better) 40 | mnist = MnistVisualLanguageDataset('train', config['dataset_params']) 41 | 42 | ###### Load Discrete VAE##### 43 | 44 | vae = DiscreteVAE( 45 | num_embeddings=config['model_params']['vae_num_embeddings'], 46 | embedding_dim=config['model_params']['vae_embedding_dim'] 47 | ) 48 | vae.to(device) 49 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 50 | config['train_params']['vae_ckpt_name'])): 51 | print('Found checkpoint... Taking vae from that') 52 | vae.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 53 | config['train_params']['vae_ckpt_name']), map_location=device)) 54 | else: 55 | print('No checkpoint found at {}/{}... 
Exiting'.format(config['train_params']['task_name'], 56 | config['train_params']['vae_ckpt_name'])) 57 | print('Train vae first') 58 | return 59 | vae.eval() 60 | vae.requires_grad_(False) 61 | 62 | ############################## 63 | 64 | 65 | ########### Load DallE ########## 66 | model = DallE(vae=vae, 67 | num_words=len(mnist.vocab_word_to_idx), 68 | image_size=config['model_params']['dalle_image_size'], 69 | max_text_len=mnist.max_token_len, 70 | image_vocab_size=config['model_params']['vae_num_embeddings'], 71 | gpt_config=config['gpt_config']) 72 | model.to(device) 73 | model.eval() 74 | model.requires_grad_(False) 75 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 76 | config['train_params']['dalle_ckpt_name'])): 77 | print('Found checkpoint... Starting training from that') 78 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 79 | config['train_params']['dalle_ckpt_name']), map_location=device)) 80 | else: 81 | print('No checkpoint found for dalle at {}/{}... Exiting'.format(config['train_params']['task_name'], 82 | config['train_params']['dalle_ckpt_name'])) 83 | 84 | return 85 | ################################# 86 | 87 | im_tokens_len = config['model_params']['dalle_image_size'] * config['model_params']['dalle_image_size'] 88 | colors = ['red', 'blue', 'pink', 'green', 'cyan'] 89 | textures = ['lego', 'stones', 'wool', 'cracker', 'peas'] 90 | solids = ['orange', 'olive', 'purple'] 91 | numbers = list(range(10)) 92 | 93 | #### Genrate 10 random images ###### 94 | vae_inputs = [] 95 | fnames = [] 96 | for _ in tqdm(range(10)): 97 | color = random.choice(colors) 98 | number = random.choice(numbers) 99 | 100 | ######## Set the desired seed value ####### 101 | seed = np.random.randint(0, 1000) 102 | torch.manual_seed(seed) 103 | np.random.seed(seed) 104 | random.seed(seed) 105 | if device == 'cuda': 106 | torch.cuda.manual_seed_all(seed) 107 | 108 | if random.random() < 0.1: 109 | solid = random.choice(solids) 110 | sent = ('generate image of {} in {} and a solid background of {}' 111 | .format(number, color, solid).split(' ')) 112 | fnames.append('{}_{}_{}.png'.format(number, color, solid)) 113 | else: 114 | texture = random.choice(textures) 115 | sent = ('generate image of {} in {} and a texture background of {}'. 
116 | format(number, color, texture).split(' ')) 117 | fnames.append('{}_{}_{}.png'.format(number, color, texture)) 118 | sent = [''] + sent + [''] 119 | text_tokens = torch.LongTensor([mnist.vocab_word_to_idx[word] for word in sent]).to(device).unsqueeze(0) 120 | random_im_tokens = torch.randint(0, config['model_params']['vae_num_embeddings'], 121 | (model.image_size * model.image_size,)).to(device) 122 | 123 | #### Generate pixels one by one ##### 124 | im_tokens = torch.LongTensor([]).to(device) 125 | for tok_idx in range(im_tokens_len): 126 | logits, _, _ = model.gpt(im_tokens.unsqueeze(0), text_tokens) 127 | logits = logits[:, -1, :] 128 | 129 | # Ignore logits of all non-image tokens 130 | logits[:, :len(mnist.vocab_word_to_idx)] = -torch.finfo(logits.dtype).max 131 | 132 | # Keep only the top-k logits and sample from them 133 | val, ind = torch.topk(logits, 3) 134 | probs = torch.full_like(logits, -torch.finfo(logits.dtype).max) 135 | probs.scatter_(1, ind, val) 136 | probs = torch.nn.functional.softmax(probs, dim=-1) 137 | sample = torch.multinomial(probs, num_samples=1)[0].to(device) 138 | 139 | # Reduce predicted output by text vocab size to get vae token index 140 | sample -= model.num_words 141 | 142 | im_tokens = torch.cat((im_tokens, sample), dim=-1) 143 | random_im_tokens[:tok_idx + 1] = im_tokens 144 | 145 | 146 | vae_input = random_im_tokens.reshape((model.image_size, model.image_size)) 147 | vae_inputs.append(vae_input.unsqueeze(0)) 148 | 149 | # Pass predicted discrete sequence to vae 150 | vae_inputs = torch.cat(vae_inputs, dim=0) 151 | z = torch.nn.functional.one_hot(vae_inputs, num_classes=config['model_params']['vae_num_embeddings']) 152 | z = rearrange(z, 'b h w c -> b c h w').float() 153 | output = vae.decode_from_codebook_indices(z) 154 | output = (output + 1) / 2 155 | for idx in range(output.size(0)): 156 | img = torchvision.transforms.ToPILImage()(output[idx].detach().cpu()) 157 | img.save(os.path.join(config['train_params']['task_name'], 158 | fnames[idx])) 159 | 160 | 161 | 162 | 163 | 164 | if __name__ == '__main__': 165 | parser = argparse.ArgumentParser(description='Arguments for generating outputs') 166 | parser.add_argument('--config', dest='config_path', 167 | default='config/default.yaml', type=str) 168 | args = parser.parse_args() 169 | infer(args) -------------------------------------------------------------------------------- /tools/infer_dvae.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import yaml 3 | import argparse 4 | import torch 5 | import os 6 | import torchvision 7 | from model.discrete_vae import DiscreteVAE 8 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset 9 | from torchvision.utils import make_grid 10 | from einops import rearrange 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | def inference(args): 16 | r""" 17 | Method to infer discrete vae and get 18 | reconstructions 19 | :param args: 20 | :return: 21 | """ 22 | with open(args.config_path, 'r') as file: 23 | try: 24 | config = yaml.safe_load(file) 25 | except yaml.YAMLError as exc: 26 | print(exc) 27 | print(config) 28 | 29 | model = DiscreteVAE( 30 | num_embeddings=config['model_params']['vae_num_embeddings'], 31 | embedding_dim=config['model_params']['vae_embedding_dim'] 32 | ) 33 | model.to(device) 34 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 35 | config['train_params']['vae_ckpt_name'])): 36 | print('Found checkpoint... 
Inferring from that') 37 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 38 | config['train_params']['vae_ckpt_name']), map_location=device)) 39 | else: 40 | print('No checkpoint found at {}/{}... Exiting'.format(config['train_params']['task_name'], 41 | config['train_params']['vae_ckpt_name'])) 42 | return 43 | model.eval() 44 | mnist = MnistVisualLanguageDataset('test', config['dataset_params']) 45 | 46 | # Generate reconstructions for 100 samples 47 | idxs = torch.randint(0, len(mnist) - 1, (25,)) 48 | ims = torch.cat([mnist[idx]['image'][None, :] for idx in idxs]).float().to(device) 49 | output = model(ims) 50 | generated_im = output[0] 51 | 52 | # Dataset generates -1 to 1 we convert it to 0-1 53 | ims = (ims + 1) / 2 54 | generated_im = (generated_im + 1) / 2 55 | out = torch.hstack([ims, generated_im]) 56 | output = rearrange(out, 'b (c d) h w -> b (d) h (c w)', c=2, d=3) 57 | grid = make_grid(output, nrow=5) 58 | img = torchvision.transforms.ToPILImage()(grid.detach().cpu()) 59 | img.save(os.path.join(config['train_params']['task_name'], 60 | 'dvae_reconstructions.png')) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='Arguments for discrete vae inference') 65 | parser.add_argument('--config', dest='config_path', 66 | default='config/default.yaml', type=str) 67 | args = parser.parse_args() 68 | inference(args) 69 | -------------------------------------------------------------------------------- /tools/train_dalle.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import random 5 | import os 6 | import numpy as np 7 | from tqdm import tqdm 8 | from model.discrete_vae import DiscreteVAE 9 | from model.dalle import DallE 10 | from torch.utils.data.dataloader import DataLoader 11 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset 12 | from torch.optim import Adam 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | plt_counts = 0 17 | 18 | 19 | def train_for_one_epoch(epoch_idx, model, loader, optimizer, config): 20 | r""" 21 | Method to run the training for one epoch. 22 | :param epoch_idx: iteration number of current epoch 23 | :param model: Dalle model 24 | :param mnist_loader: Data loder 25 | :param optimizer: optimzier to be used taken from config 26 | :param crtierion: For computing the loss 27 | :param config: configuration for the current run 28 | :return: 29 | """ 30 | losses = [] 31 | for data in tqdm(loader): 32 | im = data['image'] 33 | text_tokens = data['text_tokens'] 34 | im = im.float().to(device) 35 | text = text_tokens.long().to(device) 36 | optimizer.zero_grad() 37 | 38 | _, loss_text, loss_image = model(im, text) 39 | loss = (loss_text*1 + loss_image*config['train_params']['dalle_image_loss']) / (1+config['train_params']['dalle_image_loss']) 40 | losses.append(loss.item()) 41 | loss.backward() 42 | optimizer.step() 43 | print('Finished epoch: {} | Modelling Loss : {:.4f} '. 
44 | format(epoch_idx + 1, 45 | np.mean(losses))) 46 | return np.mean(losses) 47 | 48 | 49 | def train(args): 50 | ######## Read the config file ####### 51 | with open(args.config_path, 'r') as file: 52 | try: 53 | config = yaml.safe_load(file) 54 | except yaml.YAMLError as exc: 55 | print(exc) 56 | print(config) 57 | 58 | ####################################### 59 | 60 | ######## Set the desired seed value ####### 61 | seed = config['train_params']['seed'] 62 | torch.manual_seed(seed) 63 | np.random.seed(seed) 64 | random.seed(seed) 65 | if device == 'cuda': 66 | torch.cuda.manual_seed_all(seed) 67 | 68 | if not os.path.exists(config['train_params']['task_name']): 69 | os.mkdir(config['train_params']['task_name']) 70 | 71 | ######## Create the model and dataset ########## 72 | num_epochs = config['train_params']['num_epochs_dalle'] 73 | mnist = MnistVisualLanguageDataset('train', config['dataset_params']) 74 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['dalle_batch_size'], 75 | shuffle=True, num_workers=4) 76 | vae = DiscreteVAE( 77 | num_embeddings=config['model_params']['vae_num_embeddings'], 78 | embedding_dim=config['model_params']['vae_embedding_dim'] 79 | ) 80 | vae.to(device) 81 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 82 | config['train_params']['vae_ckpt_name'])): 83 | print('Found checkpoint... Taking vae from that') 84 | vae.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 85 | config['train_params']['vae_ckpt_name']),map_location=device)) 86 | else: 87 | print('No checkpoint found at {}/{}... Exiting'.format(config['train_params']['task_name'], 88 | config['train_params']['vae_ckpt_name'])) 89 | print('Train vae first') 90 | return 91 | vae.eval() 92 | vae.requires_grad_(False) 93 | 94 | 95 | model = DallE(vae=vae, 96 | num_words=len(mnist.vocab_word_to_idx), 97 | image_size=config['model_params']['dalle_image_size'], 98 | max_text_len=mnist.max_token_len, 99 | image_vocab_size=config['model_params']['vae_num_embeddings'], 100 | gpt_config=config['gpt_config']) 101 | model.to(device) 102 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 103 | config['train_params']['dalle_ckpt_name'])): 104 | print('Found checkpoint... Starting training from that') 105 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 106 | config['train_params']['dalle_ckpt_name']),map_location=device)) 107 | 108 | ####### Training Parameters ############ 109 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr']) 110 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True) 111 | 112 | best_loss = np.inf 113 | for epoch_idx in range(num_epochs): 114 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, config) 115 | scheduler.step(mean_loss) 116 | # Simply update checkpoint if found better version 117 | if mean_loss < best_loss: 118 | print('Improved Loss from {:.4f} to {:.4f} .... Saving Model'.format(best_loss, mean_loss)) 119 | torch.save(model.state_dict(), '{}/{}'.format(config['train_params']['task_name'], 120 | config['train_params']['dalle_ckpt_name'])) 121 | best_loss = mean_loss 122 | else: 123 | print('No Loss Improvement. 
Best Loss : {:.4f}'.format(best_loss)) 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser(description='Arguments for dalle training') 128 | parser.add_argument('--config', dest='config_path', 129 | default='config/default.yaml', type=str) 130 | args = parser.parse_args() 131 | train(args) 132 | -------------------------------------------------------------------------------- /tools/train_dvae.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import torch 4 | import cv2 5 | import random 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | from model.discrete_vae import DiscreteVAE 10 | from torch.utils.data.dataloader import DataLoader 11 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset 12 | from torch.optim import Adam 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, crtierion, config): 19 | losses = [] 20 | count = 0 21 | for data in tqdm(mnist_loader): 22 | # For vae we only need images 23 | im = data['image'] 24 | im = im.float().to(device) 25 | optimizer.zero_grad() 26 | 27 | output, kl, log_qy = model(im) 28 | if config['train_params']['save_vae_training_image'] and count % 25 == 0: 29 | im_input = cv2.cvtColor((255 * (im.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0], 30 | cv2.COLOR_RGB2BGR) 31 | im_output = cv2.cvtColor((255 * (output.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0], 32 | cv2.COLOR_RGB2BGR) 33 | cv2.imwrite('{}/input.jpeg'.format(config['train_params']['task_name']), im_input) 34 | cv2.imwrite('{}/output.jpeg'.format(config['train_params']['task_name']), im_output) 35 | 36 | loss = (crtierion(output, im) + config['train_params']['kl_weight']*kl)/(1+config['train_params']['kl_weight']) 37 | losses.append(loss.item()) 38 | loss.backward() 39 | optimizer.step() 40 | count += 1 41 | 42 | print('Finished epoch: {} | Loss : {:.4f} '. 43 | format(epoch_idx + 1, 44 | np.mean(losses))) 45 | return np.mean(losses) 46 | 47 | 48 | def train(args): 49 | ######## Read the config file ####### 50 | with open(args.config_path, 'r') as file: 51 | try: 52 | config = yaml.safe_load(file) 53 | except yaml.YAMLError as exc: 54 | print(exc) 55 | print(config) 56 | ####################################### 57 | 58 | ######## Set the desired seed value ####### 59 | seed = config['train_params']['seed'] 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | if device == 'cuda': 64 | torch.cuda.manual_seed_all(seed) 65 | 66 | if not os.path.exists(config['train_params']['task_name']): 67 | os.mkdir(config['train_params']['task_name']) 68 | 69 | 70 | ####################################### 71 | # Create the model and dataset 72 | num_epochs = config['train_params']['num_epochs'] 73 | model = DiscreteVAE( 74 | num_embeddings=config['model_params']['vae_num_embeddings'], 75 | embedding_dim=config['model_params']['vae_embedding_dim'] 76 | ) 77 | model.to(device) 78 | 79 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'], 80 | config['train_params']['vae_ckpt_name'])): 81 | print('Found checkpoint... 
Starting training from that') 82 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'], 83 | config['train_params']['vae_ckpt_name']))) 84 | mnist = MnistVisualLanguageDataset('train', config['dataset_params']) 85 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], 86 | shuffle=True, num_workers=4) 87 | 88 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr']) 89 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True) 90 | criterion = { 91 | 'l1': torch.nn.SmoothL1Loss(beta=0.1), 92 | 'l2': torch.nn.MSELoss() 93 | }.get(config['train_params']['crit']) 94 | 95 | 96 | if not os.path.exists(config['train_params']['task_name']): 97 | os.mkdir(config['train_params']['task_name']) 98 | 99 | 100 | best_loss = np.inf 101 | for epoch_idx in range(num_epochs): 102 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config) 103 | scheduler.step(mean_loss) 104 | # Simply update checkpoint if found better version 105 | if mean_loss < best_loss: 106 | print('Improved Loss from {:.4f} to {:.4f} .... Saving Model'.format(best_loss, mean_loss)) 107 | torch.save(model.state_dict(), '{}/{}'.format(config['train_params']['task_name'], 108 | config['train_params']['vae_ckpt_name'])) 109 | best_loss = mean_loss 110 | else: 111 | print('No Loss Improvement. Best Loss : {:.4f}'.format(best_loss)) 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description='Arguments for vae training') 116 | parser.add_argument('--config', dest='config_path', 117 | default='config/default.yaml', type=str) 118 | args = parser.parse_args() 119 | train(args) 120 | --------------------------------------------------------------------------------
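To close the loop, here is a hedged smoke test for the discrete VAE trained by the script above. It only checks tensor shapes on random inputs using the sizes from `config/default.yaml` (112x112 images, a codebook of 2048 entries with dimension 1024), so it needs no data or checkpoint.

```
# Hedged sketch: shape check of DiscreteVAE with the default.yaml sizes, on random inputs.
import torch
from model.discrete_vae import DiscreteVAE

vae = DiscreteVAE(num_embeddings=2048, embedding_dim=1024)
vae.eval()

im = 2 * torch.rand(4, 3, 112, 112) - 1        # dataset scales images to [-1, 1]
with torch.no_grad():
    recon, kl, log_qy = vae(im)                # encoder -> gumbel-softmax quantizer -> decoder
    indices = vae.get_codebook_indices(im)     # discrete tokens later consumed by DallE/GPT

print(recon.shape)    # (4, 3, 112, 112), reconstruction in [-1, 1] from the final Tanh
print(indices.shape)  # (4, 14, 14), i.e. 196 codebook indices per image after 3 downsamples
```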