├── .gitignore
├── .gitkeep
├── README.md
├── config
│   ├── __init__.py
│   └── default.yaml
├── data
│   ├── .gitkeep
│   ├── test
│   │   └── keep
│   ├── textures
│   │   └── keep
│   └── train
│       └── keep
├── dataset
│   ├── __init__.py
│   └── mnist_color_texture_dataset.py
├── model
│   ├── __init__.py
│   ├── dalle.py
│   ├── decoder.py
│   ├── discrete_vae.py
│   ├── encoder.py
│   ├── mingpt.py
│   └── quantizer.py
├── requirements.txt
└── tools
    ├── __init__.py
    ├── generate_image.py
    ├── infer_dvae.py
    ├── train_dalle.py
    └── train_dvae.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 |
15 | # Ignore checkpoints
16 | *.pth
17 |
18 | # Ignore pickle files
19 | *.pkl
--------------------------------------------------------------------------------
/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | DallE Implementation in PyTorch with generation using minGPT
2 | ========
3 |
4 | This repository implements DallE-1 [Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092) on a synthetic dataset of colored MNIST digits on texture/solid backgrounds.
5 |
6 |
7 | ## DallE Tutorial Video
8 |
9 |
11 |
12 |
13 | ## Sample from dataset
14 |
15 |
16 |
17 |
18 |
19 | Large parts of the implementation are taken from the two repositories below:
20 | 1. GPT from - https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
21 | 2. Parts of the DallE implementation from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch .
22 |
23 | I have only kept a minimal version of DallE which allows us to get decent results (on this dataset) and play around with it. If you are looking for a much more efficient and complete implementation, please use the above repo.
24 |
25 | ## Data preparation
26 | For setting up the mnist dataset:
27 | Follow - https://github.com/explainingai-code/Pytorch-VAE#data-preparation
28 |
29 | Download Quarter RGB resolution texture data from [ALOT Homepage](https://aloi.science.uva.nl/public_alot/)
30 |
31 | If you are facing issues then use `curl`
32 |
33 | `curl -O https://aloi.science.uva.nl/public_alot/tars/alot_png4.tar`
34 |
35 |
36 | In case you want to train on a higher resolution, you can download that as well, but you would have to create new train.json and test.json files.
37 | The rest of the code should work fine as long as you create valid json files.
38 |
39 | Download train.json and test.json from [Drive](https://drive.google.com/drive/folders/1DSpNaM6hk8VNFVKHs-VK97AlP_8ynRKC?usp=sharing)
40 | Verify the data directory has the following structure after downloading the textures
41 | ```
42 | DallE/data/textures/{texture_number}
43 | *.png
44 | DallE/data/train/images/{0/1/.../9}
45 | *.png
46 | DallE/data/test/images/{0/1/.../9}
47 | *.png
48 | DallE/data/train.json
49 | DallE/data/test.json
50 | ```
51 |
52 | # Quickstart
53 | * Create a new conda environment with python 3.8 and then run the below commands
54 | * ```git clone https://github.com/explainingai-code/DallE.git```
55 | * ```cd DallE```
56 | * ```pip install -r requirements.txt```
57 | * For training and inference of the discrete vae and gpt, use the below commands, passing the desired configuration file as the ```--config``` argument in case you want to play with it.
58 | * ```python -m tools.train_dvae``` for training discrete vae
59 | * ```python -m tools.infer_dvae``` for generating reconstructions
60 | * ```python -m tools.train_dalle``` for training minimal version of DallE
61 | * ```python -m tools.generate_image``` for using the trained DallE to generate images
62 |
63 | ## Configuration
64 | * ```config/default.yaml``` - Allows you to configure the different components of the discrete vae as well as DallE and experiment with modifications
65 |
66 |
67 | ## Output
68 | Outputs will be saved according to the configuration present in yaml files.
69 |
70 | For every run, a directory named after the ```task_name``` key in the config will be created and all outputs will be saved inside it.
71 |
72 | During training of the Discrete VAE and DallE the following outputs will be saved
73 | * Best model checkpoints (DVAE and DallE) in the ```task_name``` directory
74 |
75 | During inference the following outputs will be saved
76 | * Reconstructions for a sample of the test set in ```task_name/dvae_reconstructions.png```
77 | * Generated images from the trained DallE as individual ```*.png``` files in the ```task_name``` directory
78 |
79 |
80 | ## Sample Output for DallE
81 |
82 | Running the default config for DiscreteVAE should give you reconstructions like the ones below (left - input | right - reconstruction)
83 |
84 |
85 |
86 | Sample generation output after 40 epochs with 4 layers, a hidden dimension of 512 and 8 attention heads
87 |
88 | Generate 0 in blue and solid background of olive
89 |
90 | Generate 1 in cyan and texture background of cracker
91 |
92 | Generate 6 in pink and texture background of stones
93 |
94 | Generate 8 in red and texture background of lego
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 | ## Citations
104 |
105 | ```
106 | @misc{ramesh2021zeroshot,
107 | title={Zero-Shot Text-to-Image Generation},
108 | author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
109 | year={2021},
110 | eprint={2102.12092},
111 | archivePrefix={arXiv},
112 | primaryClass={cs.CV}
113 | }
114 | ```
115 |
116 |
117 |
--------------------------------------------------------------------------------
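
The dataset class in ```dataset/mnist_color_texture_dataset.py``` reads the fields ```digit_name```, ```digit_color```, ```digit_image``` and either ```background_color``` (solid background) or ```texture_name```/```texture_image``` (texture background) from every entry of train.json / test.json. The sketch below is purely illustrative: the field names come from the dataset code, but the values and file paths are made up, so use the json files from the Drive link in the README rather than writing your own.

```python
# Hypothetical train.json entries, expressed as Python dicts for illustration only.
# Field names are taken from the dataset code; values and paths are invented.
solid_entry = {
    "digit_name": "0",
    "digit_color": "blue",
    "digit_image": "train/images/0/00000.png",   # resolved relative to dataset_params['root_dir']
    "background_color": "olive",                  # presence of this key means a solid background
}
texture_entry = {
    "digit_name": "6",
    "digit_color": "pink",
    "digit_image": "train/images/6/00042.png",
    "texture_name": "stones",                     # texture background instead of a solid color
    "texture_image": "textures/123/example.png",
}
```
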
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/config/__init__.py
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | root_dir: 'data'
3 | image_size: 112
4 | drop_background_prob: 0.1
5 | drop_color_prob: 0.1
6 |
7 | model_params:
8 | vae_num_embeddings: 2048
9 | vae_embedding_dim : 1024
10 | # Will be 112/8 as we have 3 downsamples
11 | dalle_image_size: 14
12 |
13 | gpt_config:
14 | embd_pdrop: 0.1
15 | resid_pdrop: 0.1
16 | attn_pdrop: 0.1
17 | n_layer: 4
18 | n_head: 8
19 | n_embd: 512
20 |
21 | train_params:
22 | task_name: 'default'
23 | batch_size: 64
24 | dalle_batch_size: 256
25 | num_epochs: 40
26 | num_epochs_dalle: 50
27 | dalle_image_loss: 10
28 | kl_weight: 0
29 | lr: 0.001
30 | crit: 'l1'
31 | seed: 1111
32 | save_vae_training_image: True
33 | vae_ckpt_name: 'vae_ckpt.pth'
34 | dalle_ckpt_name: 'dalle_ckpt.pth'
35 |
--------------------------------------------------------------------------------
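
The scripts under ```tools/``` load this file with ```yaml.safe_load``` and index into the ```dataset_params```, ```model_params```, ```gpt_config``` and ```train_params``` sections. A minimal sketch of that access pattern:

```python
# Minimal sketch of how the tools scripts consume config/default.yaml.
import yaml

with open('config/default.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config['model_params']['vae_num_embeddings'])  # 2048
print(config['gpt_config']['n_layer'])               # 4
print(config['train_params']['task_name'])           # 'default'
```
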
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/.gitkeep
--------------------------------------------------------------------------------
/data/test/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/test/keep
--------------------------------------------------------------------------------
/data/textures/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/textures/keep
--------------------------------------------------------------------------------
/data/train/keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/data/train/keep
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/mnist_color_texture_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import random
5 | import matplotlib.colors as mcolors
6 | import torch
7 | import json
8 | from tqdm import tqdm
9 | from torch.utils.data.dataset import Dataset
10 | from torch.utils.data.dataloader import DataLoader
11 |
12 |
13 | def get_square_crop(image):
14 | h,w = image.shape[:2]
15 | if h > w:
16 | return image[(h - w)//2:-(h - w)//2, :, :]
17 | else:
18 | return image[:, (w - h) // 2:-(w - h) // 2, :]
19 |
20 |
21 | class MnistVisualLanguageDataset(Dataset):
22 | r"""
23 | Minimal visual language dataset class which auto generates fixed format caption
24 | for each dataset point
25 | """
26 | def __init__(self, split, config):
27 | self.split = split
28 | self.db_root = config['root_dir']
29 | self.im_size = config['image_size']
30 |
31 | # Probability of randomly dropping background info
32 | self.drop_background_info_prob = config['drop_background_prob']
33 | # Probability for dropping font color and background color
34 | self.drop_font_color_info_prob = config['drop_color_prob']
35 |
36 | # Auto generated caption formats
37 | self.generation_text_format_tokens = ' generate image of {} in {} and a {} background of {} '
38 | self.generation_text_format_tokens_drop_bg = ' generate image of {} in {} '
39 | self.generation_text_format_tokens_drop_color = ' generate image of {} '
40 |
41 | # Validate right amount of padding and ensure all are same length
42 | assert (len(self.generation_text_format_tokens.split(' ')) == len(self.generation_text_format_tokens_drop_bg.split(' '))
43 | == len(self.generation_text_format_tokens_drop_color.split(' ')))
44 | self.max_token_len = len(self.generation_text_format_tokens.split(' '))
45 | self.visual_language_db = json.load(open(os.path.join(self.db_root, self.split + '.json')))
46 |
47 | self.vocab_idx_to_word, self.vocab_word_to_idx = self.build_vocab()
48 |
49 | def build_vocab(self):
50 | r"""
51 | Method to get dictionary of word to indexes and
52 | indexes to word to be used for tokenizing
53 | and for generation purposes
54 | :return:
55 | """
56 | vocab_generation_tokens = [word for word in self.generation_text_format_tokens.split(' ') if word != '{}']
57 | vocab_generation_tokens += [word for word in self.generation_text_format_tokens_drop_bg.split(' ') if word != '{}']
58 | vocab_generation_tokens += [word for word in self.generation_text_format_tokens_drop_color.split(' ') if word != '{}']
59 | vocab_preset = set(vocab_generation_tokens)
60 | for db_entry in self.visual_language_db:
61 | if 'texture_name' in db_entry:
62 | vocab_preset.add(db_entry['texture_name'])
63 | vocab_preset.add('texture')
64 | if 'background_color' in db_entry:
65 | vocab_preset.add(db_entry['background_color'])
66 | vocab_preset.add('solid')
67 | vocab_preset.add(db_entry['digit_name'])
68 | vocab_preset.add(db_entry['digit_color'])
69 | vocab_tokens = sorted(list(vocab_preset))
70 | vocab_word_to_idx = { k:v for (k,v) in zip(vocab_tokens, range(len(vocab_tokens)))}
71 | vocab_idx_to_word = { v:k for (k,v) in zip(vocab_tokens, range(len(vocab_tokens)))}
72 | return vocab_idx_to_word, vocab_word_to_idx
73 |
74 | def __len__(self):
75 | return len(self.visual_language_db)
76 |
77 | def __getitem__(self, index):
78 | entry = self.visual_language_db[index]
79 | background_type = 'solid' if 'background_color' in entry else 'texture'
80 | drop_type = random.choices(['no_drop','drop_bg','drop_bg_and_color'],
81 | weights=[1-self.drop_background_info_prob-self.drop_font_color_info_prob,
82 | self.drop_background_info_prob,
83 | self.drop_font_color_info_prob])[0]
84 | if drop_type == 'no_drop':
85 | text = self.generation_text_format_tokens.format(entry['digit_name'],
86 | entry['digit_color'],
87 | background_type,
88 | entry['background_color'] if background_type == 'solid' else
89 | entry['texture_name'])
90 | elif drop_type == 'drop_bg':
91 | text = self.generation_text_format_tokens_drop_bg.format(entry['digit_name'],
92 | entry['digit_color'])
93 | else:
94 | text = self.generation_text_format_tokens_drop_color.format(entry['digit_name'])
95 |
96 | text_tokens = [self.vocab_word_to_idx[word] for word in text.split(' ')]
97 | text_tokens = torch.LongTensor(text_tokens)
98 |
99 | digit_im = cv2.imread(os.path.join(self.db_root, entry['digit_image']))
100 | digit_im = cv2.cvtColor(digit_im, cv2.COLOR_BGR2RGB)
101 | digit_im = cv2.resize(digit_im, (self.im_size, self.im_size))
102 |
103 | # Binarize mnist digit pixels to be either 0 or 255
104 | digit_im[digit_im > 50] = 255
105 | digit_im[digit_im <= 50] = 0
106 | mask_val = (digit_im > 0).astype(np.float32)
107 | color_scale = mcolors.hex2color('tab:{}'.format(entry['digit_color']))
108 | digit_im = np.concatenate((digit_im[:, :, 0][..., None] * color_scale[0],
109 | digit_im[:, :, 1][..., None] * color_scale[1],
110 | digit_im[:, :, 2][..., None] * color_scale[2]), axis=-1)
111 | if background_type == 'texture':
112 | im = cv2.imread(os.path.join(self.db_root, entry['texture_image']))
113 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
114 | im = get_square_crop(im)
115 | im = cv2.resize(im, (self.im_size, self.im_size))
116 | else:
117 | im = np.ones((self.im_size, self.im_size, 3))
118 | back_color_scale = mcolors.hex2color('tab:{}'.format(entry['background_color']))
119 | im[:, :, 0] = 255*back_color_scale[0]
120 | im[:, :, 1] = 255*back_color_scale[1]
121 | im[:, :, 2] = 255*back_color_scale[2]
122 | out_im = mask_val * digit_im + (1 - mask_val) * im
123 | im_tensor = torch.from_numpy(out_im).permute((2, 0, 1))
124 | im_tensor = 2 * (im_tensor / 255) - 1
125 | return {
126 | "image" : im_tensor,
127 | "text_tokens" : text_tokens,
128 | "text" : text,
129 | }
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
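
A quick usage sketch for the dataset class, assuming the ```data/``` directory has been prepared as described in the README; the ```dataset_params``` dict mirrors ```config/default.yaml```:

```python
# Sanity-check sketch for MnistVisualLanguageDataset (needs the prepared data/ directory).
from torch.utils.data import DataLoader
from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset

dataset_params = {
    'root_dir': 'data',
    'image_size': 112,
    'drop_background_prob': 0.1,
    'drop_color_prob': 0.1,
}
mnist = MnistVisualLanguageDataset('train', dataset_params)
sample = mnist[0]
print(sample['image'].shape)        # torch.Size([3, 112, 112]), values in [-1, 1]
print(sample['text'])               # auto-generated caption for this sample
print(sample['text_tokens'].shape)  # torch.Size([max_token_len])

loader = DataLoader(mnist, batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['text_tokens'].shape)
```
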
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/model/__init__.py
--------------------------------------------------------------------------------
/model/dalle.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.mingpt import GPT, DallEGPTConfig
4 |
5 | class DallE(nn.Module):
6 | r"""
7 | Class handling the logic for DallE
8 | Calls the vae and passes the text and image tokens
9 | together with target to gpt
10 | """
11 | def __init__(self, vae, num_words, image_size, max_text_len, image_vocab_size, gpt_config):
12 | super(DallE, self).__init__()
13 | self.vae = vae
14 |
15 | # Text Vocab size
16 | self.num_words = num_words
17 | # Number of Image tokens
18 | self.image_size = image_size
19 | # Maximum Text Sequence Length
20 | self.max_text_len = max_text_len
21 |
22 | # Image tokens vocabulary size (num_of_embeddings)
23 | image_vocab_size = image_vocab_size
24 |
25 | # Length of largest sequence so that we tell gpt
26 | # to have that as the context size
27 | max_sequence_len = max_text_len + image_size*image_size
28 | config = DallEGPTConfig(text_vocab_size=num_words,
29 | image_vocab_size=image_vocab_size,
30 | max_sequence_len=max_sequence_len,
31 | im_size=image_size,
32 | **gpt_config)
33 | self.gpt = GPT(config)
34 |
35 | def forward(self, im, text):
36 | # Call Discrete vae
37 | image_tokens = self.vae.get_codebook_indices(im).reshape(im.size(0), -1)
38 |
39 | # Shift the target image tokens as image tokens + text vocab size
40 | # Last fc layer will predict (num_words + num_embeddings) output probabilities
41 | # We formulate the target such that indexes 0 to num_words-1 are text token probabilities
42 | # and num_words to num_words+num_embeddings-1 are image token probabilities
43 | target_image_tokens = image_tokens + self.num_words
44 | labels = None
45 |
46 | if self.training:
47 | # Pass one position shifted tokens as targets only in training
48 | labels = torch.cat((text[:, 1:], target_image_tokens), dim=1)
49 | # Loss of text and Loss image separately so that we can get better images
50 | logits, loss_text, loss_image = self.gpt(image_tokens, text, targets=labels)
51 | return logits, loss_text, loss_image
52 |
53 |
--------------------------------------------------------------------------------
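
A shape-check sketch for wiring ```DiscreteVAE``` into ```DallE``` with the ```default.yaml``` sizes. The text vocabulary size (30) and maximum text length (10) below are stand-ins; the real values come from the dataset vocabulary:

```python
# Shape-check sketch (toy text vocab; real values come from the dataset and config).
import torch
from model.discrete_vae import DiscreteVAE
from model.dalle import DallE

gpt_config = {'embd_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1,
              'n_layer': 4, 'n_head': 8, 'n_embd': 512}
vae = DiscreteVAE(num_embeddings=2048, embedding_dim=1024)
model = DallE(vae=vae, num_words=30, image_size=14, max_text_len=10,
              image_vocab_size=2048, gpt_config=gpt_config)
model.train()

im = torch.randn(2, 3, 112, 112)      # images normalized to [-1, 1]
text = torch.randint(0, 30, (2, 10))  # tokenized captions
logits, loss_text, loss_image = model(im, text)
print(logits.shape)                   # (2, 205, 30 + 2048) after the one-position shift
print(loss_text.item(), loss_image.item())
```
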
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Decoder(nn.Module):
6 | r"""
7 | Decoder with couple of residual blocks
8 | followed by conv transpose relu layers
9 | """
10 | def __init__(self, embedding_dim):
11 | super(Decoder, self).__init__()
12 |
13 | self.decoder_layers = nn.ModuleList([
14 | nn.ConvTranspose2d(64, 64, 4, 2, 1),
15 | nn.ReLU(),
16 | nn.ConvTranspose2d(64, 32, 4, 2, 1),
17 | nn.ReLU(),
18 | nn.ConvTranspose2d(32, 16, 4, 2, 1),
19 | nn.ReLU(),
20 | nn.Conv2d(16, 3, 1),
21 | nn.Tanh()
22 | ])
23 |
24 | self.residuals = nn.ModuleList([
25 | nn.Sequential(
26 | nn.Conv2d(64, 64, 3, padding=1),
27 | nn.ReLU(),
28 | nn.Conv2d(64, 64, 3, padding=1),
29 | nn.ReLU()),
30 | nn.Sequential(
31 | nn.Conv2d(64, 64, 3, padding=1),
32 | nn.ReLU(),
33 | nn.Conv2d(64, 64, 3, padding=1),
34 | nn.ReLU())
35 | ])
36 |
37 | self.decoder_quant_conv = nn.Conv2d(embedding_dim, 64, 1)
38 |
39 |
40 | def forward(self, x):
41 | out = self.decoder_quant_conv(x)
42 | for layer in self.residuals:
43 | out = layer(out)+out
44 | for idx, layer in enumerate(self.decoder_layers):
45 | out = layer(out)
46 | return out
47 |
48 |
49 |
50 | if __name__ == '__main__':
51 | # Quick shape check for the decoder.
52 | # decoder_quant_conv expects embedding_dim input channels,
53 | # so construct the decoder with embedding_dim=64 to match the input below.
54 | decoder = Decoder(embedding_dim=64)
55 |
56 | out = decoder(torch.rand((3, 64, 14, 14)))
57 | print(out.shape)
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/model/discrete_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.encoder import Encoder
4 | from model.decoder import Decoder
5 | from model.quantizer import Quantizer
6 |
7 | class DiscreteVAE(nn.Module):
8 | def __init__(self, num_embeddings=1024, embedding_dim=512):
9 | super(DiscreteVAE, self).__init__()
10 | self.encoder = Encoder(num_embeddings=num_embeddings)
11 | self.quantizer = Quantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
12 | self.decoder = Decoder(embedding_dim=embedding_dim)
13 |
14 |
15 | def get_codebook_indices(self, x):
16 | # x.shape = B,C,H,W
17 | enc_logits = self.encoder(x)
18 | # enc_logits.shape = B,C,H,W
19 | indices = torch.argmax(enc_logits, dim=1)
20 | return indices
21 |
22 | def decode_from_codebook_indices(self, indices):
23 | quantized_indices = self.quantizer.quantize_indices(indices)
24 | return self.decoder(quantized_indices)
25 |
26 | def forward(self, x):
27 | enc = self.encoder(x)
28 | quant_output, kl, logits, log_qy = self.quantizer(enc)
29 | out = self.decoder(quant_output)
30 | return out, kl, log_qy
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
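
A small sketch of a DiscreteVAE forward pass and a codebook round trip with the ```default.yaml``` sizes; the one-hot / rearrange step mirrors what ```tools/generate_image.py``` does before decoding:

```python
# DiscreteVAE forward pass and codebook round trip (sizes from default.yaml).
import torch
import torch.nn.functional as F
from einops import rearrange
from model.discrete_vae import DiscreteVAE

vae = DiscreteVAE(num_embeddings=2048, embedding_dim=1024)
im = torch.randn(2, 3, 112, 112)

recon, kl, log_qy = vae(im)             # reconstruction in [-1, 1] (tanh output)
indices = vae.get_codebook_indices(im)  # (2, 14, 14) grid of discrete token ids

one_hot = F.one_hot(indices, num_classes=2048)
one_hot = rearrange(one_hot, 'b h w n -> b n h w').float()
decoded = vae.decode_from_codebook_indices(one_hot)
print(recon.shape, indices.shape, decoded.shape)  # (2,3,112,112) (2,14,14) (2,3,112,112)
```
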
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Encoder(nn.Module):
6 | r"""
7 | Encoder is conv relu blocks
8 | followed by couple of residual blocks.
9 | Last 1x1 conv converts to logits with
10 | num_embeddings as output size
11 | """
12 | def __init__(self, num_embeddings):
13 | super(Encoder, self).__init__()
14 | # Encoder is just Conv relu blocks
15 | self.encoder_layers = nn.ModuleList([
16 | nn.Conv2d(3, 32, 4, 2, 1),
17 | nn.ReLU(),
18 | nn.Conv2d(32, 64, 4, 2, 1),
19 | nn.ReLU(),
20 | nn.Conv2d(64, 64, 4, 2, 1),
21 | nn.ReLU(),
22 | ])
23 | self.residuals = nn.ModuleList([
24 | nn.Sequential(
25 | nn.Conv2d(64, 64, 3, padding = 1),
26 | nn.ReLU(),
27 | nn.Conv2d(64, 64, 3, padding = 1),
28 | nn.ReLU()),
29 | nn.Sequential(
30 | nn.Conv2d(64, 64, 3, padding = 1),
31 | nn.ReLU(),
32 | nn.Conv2d(64, 64, 3, padding=1),
33 | nn.ReLU())
34 | ])
35 | self.encoder_quant_conv = nn.Sequential(
36 | nn.Conv2d(64, num_embeddings, 1))
37 |
38 |
39 | def forward(self, x):
40 | out = x
41 | for layer in self.encoder_layers:
42 | out = layer(out)
43 | for layer in self.residuals:
44 | out = out + layer(out)
45 | out = self.encoder_quant_conv(out)
46 | return out
--------------------------------------------------------------------------------
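
With three stride-2 convolutions the encoder maps a 112x112 image to a 14x14 grid of logits over the codebook, which is where ```dalle_image_size: 14``` in the config comes from. A quick shape check:

```python
# Encoder shape check: three stride-2 convs give 112 -> 56 -> 28 -> 14.
import torch
from model.encoder import Encoder

enc = Encoder(num_embeddings=2048)
logits = enc(torch.randn(1, 3, 112, 112))
print(logits.shape)  # torch.Size([1, 2048, 14, 14])
```
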
/model/mingpt.py:
--------------------------------------------------------------------------------
1 | """
2 | GPT model:
3 | - the initial stem consists of a combination of token encoding and a positional encoding
4 | - the meat of it is a uniform sequence of Transformer blocks
5 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
6 | - all blocks feed into a central residual pathway similar to resnets
7 | - the final decoder is a linear projection into a vanilla Softmax classifier
8 | """
9 |
10 | import math
11 | import logging
12 |
13 | import torch
14 | import torch.nn as nn
15 | from torch.nn import functional as F
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | class DallEGPTConfig:
20 | r"""
21 | Minimal DallE config holding all fields required for
22 | training gpt for both text and image tokens
23 | """
24 | def __init__(self, text_vocab_size,
25 | image_vocab_size,
26 | max_sequence_len, im_size, **kwargs):
27 | self.text_vocab_size = text_vocab_size
28 | self.image_vocab_size = image_vocab_size
29 | # Fixing block size to maximum sequence length we have seen
30 | self.block_size = max_sequence_len
31 | self.im_size = im_size
32 | self.num_text_tokens = max_sequence_len - im_size*im_size
33 | for k,v in kwargs.items():
34 | setattr(self, k, v)
35 |
36 |
37 | class CausalSelfAttention(nn.Module):
38 | """
39 | A vanilla multi-head masked self-attention layer with a projection at the end.
40 | It is possible to use torch.nn.MultiheadAttention here but I am including an
41 | explicit implementation here to show that there is nothing too scary here.
42 | """
43 |
44 | def __init__(self, config):
45 | super().__init__()
46 | assert config.n_embd % config.n_head == 0
47 | # key, query, value projections for all heads
48 | self.key = nn.Linear(config.n_embd, config.n_embd)
49 | self.query = nn.Linear(config.n_embd, config.n_embd)
50 | self.value = nn.Linear(config.n_embd, config.n_embd)
51 | # regularization
52 | self.attn_drop = nn.Dropout(config.attn_pdrop)
53 | self.resid_drop = nn.Dropout(config.resid_pdrop)
54 | # output projection
55 | self.proj = nn.Linear(config.n_embd, config.n_embd)
56 | # causal mask to ensure that attention is only applied to the left in the input sequence
57 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
58 | .view(1, 1, config.block_size, config.block_size))
59 | self.n_head = config.n_head
60 |
61 | def forward(self, x, layer_past=None):
62 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
63 |
64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
66 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
67 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68 |
69 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
70 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
71 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
72 |
73 | att = F.softmax(att, dim=-1)
74 | att = self.attn_drop(att)
75 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
76 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
77 |
78 | # output projection
79 | y = self.resid_drop(self.proj(y))
80 | return y
81 |
82 | class Block(nn.Module):
83 | """ an unassuming Transformer block """
84 |
85 | def __init__(self, config):
86 | super().__init__()
87 | self.ln1 = nn.LayerNorm(config.n_embd)
88 | self.ln2 = nn.LayerNorm(config.n_embd)
89 | self.attn = CausalSelfAttention(config)
90 | self.mlp = nn.Sequential(
91 | nn.Linear(config.n_embd, 4 * config.n_embd),
92 | nn.GELU(),
93 | nn.Linear(4 * config.n_embd, config.n_embd),
94 | nn.Dropout(config.resid_pdrop),
95 | )
96 |
97 | def forward(self, x):
98 | y = self.attn(self.ln1(x))
99 | x = x + y
100 | x = x + self.mlp(self.ln2(x))
101 | return x
102 |
103 | class GPT(nn.Module):
104 | """ the full GPT language model, with a context size of block_size """
105 |
106 | def __init__(self, config):
107 | super().__init__()
108 |
109 | # input embedding stem
110 | self.text_tok_emb = nn.Embedding(config.text_vocab_size, config.n_embd)
111 | self.image_tok_emb = nn.Embedding(config.image_vocab_size, config.n_embd)
112 |
113 | self.text_pos_emb = nn.Parameter(torch.zeros(1, config.num_text_tokens, config.n_embd))
114 | self.image_pos_emb = nn.Parameter(torch.zeros(1, config.im_size ** 2, config.n_embd))
115 | self.drop = nn.Dropout(config.embd_pdrop)
116 | # transformer
117 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
118 | # decoder head
119 | self.ln_f = nn.LayerNorm(config.n_embd)
120 | self.head = nn.Linear(config.n_embd, config.text_vocab_size + config.image_vocab_size, bias=False)
121 | self.config = config
122 | self.block_size = config.block_size
123 | self.apply(self._init_weights)
124 |
125 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
126 |
127 | def get_block_size(self):
128 | return self.block_size
129 |
130 | def _init_weights(self, module):
131 | if isinstance(module, (nn.Linear, nn.Embedding)):
132 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
133 | if isinstance(module, nn.Linear) and module.bias is not None:
134 | torch.nn.init.zeros_(module.bias)
135 | elif isinstance(module, nn.LayerNorm):
136 | torch.nn.init.zeros_(module.bias)
137 | torch.nn.init.ones_(module.weight)
138 | elif isinstance(module, GPT):
139 | torch.nn.init.normal_(module.text_pos_emb, mean=0.0, std=0.02)
140 | # Simply do the same thing for image as for text
141 | torch.nn.init.normal_(module.image_pos_emb, mean=0.0, std=0.02)
142 |
143 |
144 | def forward(self, image_tokens, text_tokens, targets=None,):
145 | b, im_t = image_tokens.size()
146 | b, text_t = text_tokens.size()
147 | assert im_t + text_t <= self.block_size, "Cannot forward, model block size is exhausted."
148 |
149 | text_emb = self.text_tok_emb(text_tokens)
150 | text_pos = self.text_pos_emb[:, :text_t, :]
151 | text_token_embeddings = self.drop(text_emb + text_pos)
152 | x = text_token_embeddings
153 |
154 | # Add image tokens for input sequence if needed.
155 | # Won't be needed for first pixel generation
156 | if im_t > 0:
157 | image_emb = self.image_tok_emb(image_tokens)
158 | image_pos = self.image_pos_emb[:, :im_t, :]
159 | image_token_embeddings = self.drop(image_emb + image_pos)
160 | x = torch.cat([x, image_token_embeddings], dim=1)
161 |
162 | x = self.blocks(x)
163 | x = self.ln_f(x)
164 | logits = self.head(x)
165 |
166 | # if we are given some desired targets also calculate the loss
167 | # Separate text and image loss
168 | loss_text = None
169 | loss_image = None
170 | if targets is not None:
171 | logits = logits[:, :-1, :]
172 |
173 | # Separate text and image token loss computation
174 | text_logits = logits[:, :text_t - 1, :].permute((0, 2, 1))
175 | image_logits = logits[:, text_t - 1:, :].permute((0, 2, 1))
176 |
177 | # For now just mask logits of image tokens for text targets
178 | # And mask out text token logits for image targets
179 | # Don't want gpt to gain points by simply decreasing scores for indexes of the other type
180 | # And anyway at inference you would always sample image token when generating image
181 | text_logits[:, self.config.text_vocab_size:, :] = -torch.finfo(logits.dtype).max
182 | image_logits[:, :self.config.text_vocab_size, :] = -torch.finfo(logits.dtype).max
183 | loss_text = F.cross_entropy(text_logits, targets[:, :text_t-1])
184 | loss_image = F.cross_entropy(image_logits, targets[:, text_t-1:])
185 | return logits, loss_text, loss_image
--------------------------------------------------------------------------------
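
A standalone forward-pass sketch for the GPT over a joint text + image token sequence. The toy sizes mimic the defaults; the text vocabulary size and text length are assumptions, since the real values come from the dataset:

```python
# Standalone GPT forward pass over text + image tokens (toy text vocab of 30 words).
import torch
from model.mingpt import GPT, DallEGPTConfig

config = DallEGPTConfig(text_vocab_size=30, image_vocab_size=2048,
                        max_sequence_len=10 + 14 * 14, im_size=14,
                        embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1,
                        n_layer=2, n_head=8, n_embd=512)
gpt = GPT(config)

text_tokens = torch.randint(0, 30, (2, 10))
image_tokens = torch.randint(0, 2048, (2, 14 * 14))
logits, loss_text, loss_image = gpt(image_tokens, text_tokens)  # no targets -> losses are None
print(logits.shape)  # (2, 206, 30 + 2048)
```
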
/model/quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import einsum, rearrange
4 |
5 |
6 | class Quantizer(nn.Module):
7 | def __init__(self, num_embeddings, embedding_dim):
8 | super(Quantizer, self).__init__()
9 |
10 | self.num_embeddings = num_embeddings
11 | self.embedding = nn.Embedding(self.num_embeddings, embedding_dim)
12 |
13 | def forward(self, x):
14 | B, C, H, W = x.shape
15 | one_hot = torch.nn.functional.gumbel_softmax(x, tau=0.9, dim=1, hard=False)
16 | sampled = einsum(one_hot, self.embedding.weight, 'b n h w, n d -> b d h w')
17 |
18 | # Compute kl loss
19 | logits = rearrange(x, 'b n h w -> b (h w) n')
20 | log_qy = torch.nn.functional.log_softmax(logits, dim=-1)
21 | log_uniform = torch.log(torch.tensor([1. / self.num_embeddings], device=torch.device(x.device)))
22 | kl_div = torch.nn.functional.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)
23 | return sampled, kl_div, logits, log_qy
24 |
25 | def quantize_indices(self, indices):
26 | return einsum(indices, self.embedding.weight, 'b n h w, n d -> b d h w')
27 |
--------------------------------------------------------------------------------
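
A shape walkthrough for the Gumbel-Softmax quantizer with the default sizes; the encoder output plays the role of per-position logits over the codebook:

```python
# Gumbel-Softmax quantizer shape check (sizes from default.yaml).
import torch
from model.quantizer import Quantizer

quantizer = Quantizer(num_embeddings=2048, embedding_dim=1024)
enc_logits = torch.randn(2, 2048, 14, 14)  # one logit per codebook entry per spatial position
sampled, kl_div, logits, log_qy = quantizer(enc_logits)
print(sampled.shape)  # (2, 1024, 14, 14): soft codebook lookup that is fed to the decoder
print(kl_div.item())  # KL regularizer between q(y|x) and a uniform prior over codebook entries
```
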
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.1
2 | matplotlib==3.7.2
3 | numpy==1.23.5
4 | opencv_python==4.8.0.74
5 | PyYAML==6.0
6 | torch==1.11.0
7 | torchvision==0.12.0
8 | tqdm==4.65.0
9 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/Dalle-Pytorch/c54c23cd7a0e0d62e89a9c9f36060873272e6d6e/tools/__init__.py
--------------------------------------------------------------------------------
/tools/generate_image.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import os
6 | import torchvision
7 | import numpy as np
8 | from einops import rearrange
9 | from tqdm import tqdm
10 | from model.discrete_vae import DiscreteVAE
11 | from model.dalle import DallE
12 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset
13 | from torchvision.utils import make_grid
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
18 | def infer(args):
19 | ######## Read the config file #######
20 | with open(args.config_path, 'r') as file:
21 | try:
22 | config = yaml.safe_load(file)
23 | except yaml.YAMLError as exc:
24 | print(exc)
25 | print(config)
26 |
27 | ######## Set the desired seed value #######
28 | # Ignoring the fixed seed value
29 | seed = np.random.randint(0, 1000)
30 | torch.manual_seed(seed)
31 | np.random.seed(seed)
32 | random.seed(seed)
33 | if device == 'cuda':
34 | torch.cuda.manual_seed_all(seed)
35 |
36 | if not os.path.exists(config['train_params']['task_name']):
37 | os.mkdir(config['train_params']['task_name'])
38 |
39 | # Create db to fetch the configuration values like vocab size (should do something better)
40 | mnist = MnistVisualLanguageDataset('train', config['dataset_params'])
41 |
42 | ###### Load Discrete VAE#####
43 |
44 | vae = DiscreteVAE(
45 | num_embeddings=config['model_params']['vae_num_embeddings'],
46 | embedding_dim=config['model_params']['vae_embedding_dim']
47 | )
48 | vae.to(device)
49 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
50 | config['train_params']['vae_ckpt_name'])):
51 | print('Found checkpoint... Taking vae from that')
52 | vae.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
53 | config['train_params']['vae_ckpt_name']), map_location=device))
54 | else:
55 | print('No checkpoint found at {}/{}... Exiting'.format(config['train_params']['task_name'],
56 | config['train_params']['vae_ckpt_name']))
57 | print('Train vae first')
58 | return
59 | vae.eval()
60 | vae.requires_grad_(False)
61 |
62 | ##############################
63 |
64 |
65 | ########### Load DallE ##########
66 | model = DallE(vae=vae,
67 | num_words=len(mnist.vocab_word_to_idx),
68 | image_size=config['model_params']['dalle_image_size'],
69 | max_text_len=mnist.max_token_len,
70 | image_vocab_size=config['model_params']['vae_num_embeddings'],
71 | gpt_config=config['gpt_config'])
72 | model.to(device)
73 | model.eval()
74 | model.requires_grad_(False)
75 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
76 | config['train_params']['dalle_ckpt_name'])):
77 | print('Found checkpoint... Loading dalle from that')
78 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
79 | config['train_params']['dalle_ckpt_name']), map_location=device))
80 | else:
81 | print('No checkpoint found for dalle at {}/{}... Exiting'.format(config['train_params']['task_name'],
82 | config['train_params']['dalle_ckpt_name']))
83 |
84 | return
85 | #################################
86 |
87 | im_tokens_len = config['model_params']['dalle_image_size'] * config['model_params']['dalle_image_size']
88 | colors = ['red', 'blue', 'pink', 'green', 'cyan']
89 | textures = ['lego', 'stones', 'wool', 'cracker', 'peas']
90 | solids = ['orange', 'olive', 'purple']
91 | numbers = list(range(10))
92 |
93 | #### Generate 10 random images ######
94 | vae_inputs = []
95 | fnames = []
96 | for _ in tqdm(range(10)):
97 | color = random.choice(colors)
98 | number = random.choice(numbers)
99 |
100 | ######## Set the desired seed value #######
101 | seed = np.random.randint(0, 1000)
102 | torch.manual_seed(seed)
103 | np.random.seed(seed)
104 | random.seed(seed)
105 | if device == 'cuda':
106 | torch.cuda.manual_seed_all(seed)
107 |
108 | if random.random() < 0.1:
109 | solid = random.choice(solids)
110 | sent = ('generate image of {} in {} and a solid background of {}'
111 | .format(number, color, solid).split(' '))
112 | fnames.append('{}_{}_{}.png'.format(number, color, solid))
113 | else:
114 | texture = random.choice(textures)
115 | sent = ('generate image of {} in {} and a texture background of {}'.
116 | format(number, color, texture).split(' '))
117 | fnames.append('{}_{}_{}.png'.format(number, color, texture))
118 | sent = [''] + sent + ['']
119 | text_tokens = torch.LongTensor([mnist.vocab_word_to_idx[word] for word in sent]).to(device).unsqueeze(0)
120 | random_im_tokens = torch.randint(0, config['model_params']['vae_num_embeddings'],
121 | (model.image_size * model.image_size,)).to(device)
122 |
123 | #### Generate pixels one by one #####
124 | im_tokens = torch.LongTensor([]).to(device)
125 | for tok_idx in range(im_tokens_len):
126 | logits, _, _ = model.gpt(im_tokens.unsqueeze(0), text_tokens)
127 | logits = logits[:, -1, :]
128 |
129 | # Ignore logits of all non-image tokens
130 | logits[:, :len(mnist.vocab_word_to_idx)] = -torch.finfo(logits.dtype).max
131 |
132 | # Keep only the top-k logits and sample from them
133 | val, ind = torch.topk(logits, 3)
134 | probs = torch.full_like(logits, -torch.finfo(logits.dtype).max)
135 | probs.scatter_(1, ind, val)
136 | probs = torch.nn.functional.softmax(probs, dim=-1)
137 | sample = torch.multinomial(probs, num_samples=1)[0].to(device)
138 |
139 | # Reduce predicted output by text vocab size to get vae token index
140 | sample -= model.num_words
141 |
142 | im_tokens = torch.cat((im_tokens, sample), dim=-1)
143 | random_im_tokens[:tok_idx + 1] = im_tokens
144 |
145 |
146 | vae_input = random_im_tokens.reshape((model.image_size, model.image_size))
147 | vae_inputs.append(vae_input.unsqueeze(0))
148 |
149 | # Pass predicted discrete sequence to vae
150 | vae_inputs = torch.cat(vae_inputs, dim=0)
151 | z = torch.nn.functional.one_hot(vae_inputs, num_classes=config['model_params']['vae_num_embeddings'])
152 | z = rearrange(z, 'b h w c -> b c h w').float()
153 | output = vae.decode_from_codebook_indices(z)
154 | output = (output + 1) / 2
155 | for idx in range((output.size(0))):
156 | img = torchvision.transforms.ToPILImage()(output[idx].detach().cpu())
157 | img.save(os.path.join(config['train_params']['task_name'],
158 | fnames[idx]))
159 |
160 |
161 |
162 |
163 |
164 | if __name__ == '__main__':
165 | parser = argparse.ArgumentParser(description='Arguments for generating outputs')
166 | parser.add_argument('--config', dest='config_path',
167 | default='config/default.yaml', type=str)
168 | args = parser.parse_args()
169 | infer(args)
--------------------------------------------------------------------------------
/tools/infer_dvae.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import yaml
3 | import argparse
4 | import torch
5 | import os
6 | import torchvision
7 | from model.discrete_vae import DiscreteVAE
8 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset
9 | from torchvision.utils import make_grid
10 | from einops import rearrange
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | def inference(args):
16 | r"""
17 | Method to infer discrete vae and get
18 | reconstructions
19 | :param args:
20 | :return:
21 | """
22 | with open(args.config_path, 'r') as file:
23 | try:
24 | config = yaml.safe_load(file)
25 | except yaml.YAMLError as exc:
26 | print(exc)
27 | print(config)
28 |
29 | model = DiscreteVAE(
30 | num_embeddings=config['model_params']['vae_num_embeddings'],
31 | embedding_dim=config['model_params']['vae_embedding_dim']
32 | )
33 | model.to(device)
34 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
35 | config['train_params']['vae_ckpt_name'])):
36 | print('Found checkpoint... Inferring from that')
37 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
38 | config['train_params']['vae_ckpt_name']), map_location=device))
39 | else:
40 | print('No checkpoint found at {}/{}... Exiting'.format(config['train_params']['task_name'],
41 | config['train_params']['vae_ckpt_name']))
42 | return
43 | model.eval()
44 | mnist = MnistVisualLanguageDataset('test', config['dataset_params'])
45 |
46 | # Generate reconstructions for 25 random test samples
47 | idxs = torch.randint(0, len(mnist) - 1, (25,))
48 | ims = torch.cat([mnist[idx]['image'][None, :] for idx in idxs]).float().to(device)
49 | output = model(ims)
50 | generated_im = output[0]
51 |
52 | # Dataset generates -1 to 1 we convert it to 0-1
53 | ims = (ims + 1) / 2
54 | generated_im = (generated_im + 1) / 2
55 | out = torch.hstack([ims, generated_im])
56 | output = rearrange(out, 'b (c d) h w -> b (d) h (c w)', c=2, d=3)
57 | grid = make_grid(output, nrow=5)
58 | img = torchvision.transforms.ToPILImage()(grid.detach().cpu())
59 | img.save(os.path.join(config['train_params']['task_name'],
60 | 'dvae_reconstructions.png'))
61 |
62 |
63 | if __name__ == '__main__':
64 | parser = argparse.ArgumentParser(description='Arguments for discrete vae inference')
65 | parser.add_argument('--config', dest='config_path',
66 | default='config/default.yaml', type=str)
67 | args = parser.parse_args()
68 | inference(args)
69 |
--------------------------------------------------------------------------------
/tools/train_dalle.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import random
5 | import os
6 | import numpy as np
7 | from tqdm import tqdm
8 | from model.discrete_vae import DiscreteVAE
9 | from model.dalle import DallE
10 | from torch.utils.data.dataloader import DataLoader
11 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset
12 | from torch.optim import Adam
13 | from torch.optim.lr_scheduler import ReduceLROnPlateau
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | plt_counts = 0
17 |
18 |
19 | def train_for_one_epoch(epoch_idx, model, loader, optimizer, config):
20 | r"""
21 | Method to run the training for one epoch.
22 | :param epoch_idx: iteration number of current epoch
23 | :param model: Dalle model
24 | :param loader: Data loader
25 | :param optimizer: optimizer to be used, taken from config
26 | :param config: configuration for the current run
28 | :return:
29 | """
30 | losses = []
31 | for data in tqdm(loader):
32 | im = data['image']
33 | text_tokens = data['text_tokens']
34 | im = im.float().to(device)
35 | text = text_tokens.long().to(device)
36 | optimizer.zero_grad()
37 |
38 | _, loss_text, loss_image = model(im, text)
39 | loss = (loss_text*1 + loss_image*config['train_params']['dalle_image_loss']) / (1+config['train_params']['dalle_image_loss'])
40 | losses.append(loss.item())
41 | loss.backward()
42 | optimizer.step()
43 | print('Finished epoch: {} | Modelling Loss : {:.4f} '.
44 | format(epoch_idx + 1,
45 | np.mean(losses)))
46 | return np.mean(losses)
47 |
48 |
49 | def train(args):
50 | ######## Read the config file #######
51 | with open(args.config_path, 'r') as file:
52 | try:
53 | config = yaml.safe_load(file)
54 | except yaml.YAMLError as exc:
55 | print(exc)
56 | print(config)
57 |
58 | #######################################
59 |
60 | ######## Set the desired seed value #######
61 | seed = config['train_params']['seed']
62 | torch.manual_seed(seed)
63 | np.random.seed(seed)
64 | random.seed(seed)
65 | if device == 'cuda':
66 | torch.cuda.manual_seed_all(seed)
67 |
68 | if not os.path.exists(config['train_params']['task_name']):
69 | os.mkdir(config['train_params']['task_name'])
70 |
71 | ######## Create the model and dataset ##########
72 | num_epochs = config['train_params']['num_epochs_dalle']
73 | mnist = MnistVisualLanguageDataset('train', config['dataset_params'])
74 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['dalle_batch_size'],
75 | shuffle=True, num_workers=4)
76 | vae = DiscreteVAE(
77 | num_embeddings=config['model_params']['vae_num_embeddings'],
78 | embedding_dim=config['model_params']['vae_embedding_dim']
79 | )
80 | vae.to(device)
81 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
82 | config['train_params']['vae_ckpt_name'])):
83 | print('Found checkpoint... Taking vae from that')
84 | vae.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
85 | config['train_params']['vae_ckpt_name']),map_location=device))
86 | else:
87 | print('No checkpoint found at {}/{}... Exiting'.format(config['train_params']['task_name'],
88 | config['train_params']['vae_ckpt_name']))
89 | print('Train vae first')
90 | return
91 | vae.eval()
92 | vae.requires_grad_(False)
93 |
94 |
95 | model = DallE(vae=vae,
96 | num_words=len(mnist.vocab_word_to_idx),
97 | image_size=config['model_params']['dalle_image_size'],
98 | max_text_len=mnist.max_token_len,
99 | image_vocab_size=config['model_params']['vae_num_embeddings'],
100 | gpt_config=config['gpt_config'])
101 | model.to(device)
102 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
103 | config['train_params']['dalle_ckpt_name'])):
104 | print('Found checkpoint... Starting training from that')
105 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
106 | config['train_params']['dalle_ckpt_name']),map_location=device))
107 |
108 | ####### Training Parameters ############
109 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
110 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True)
111 |
112 | best_loss = np.inf
113 | for epoch_idx in range(num_epochs):
114 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, config)
115 | scheduler.step(mean_loss)
116 | # Simply update checkpoint if found better version
117 | if mean_loss < best_loss:
118 | print('Improved Loss from {:.4f} to {:.4f} .... Saving Model'.format(best_loss, mean_loss))
119 | torch.save(model.state_dict(), '{}/{}'.format(config['train_params']['task_name'],
120 | config['train_params']['dalle_ckpt_name']))
121 | best_loss = mean_loss
122 | else:
123 | print('No Loss Improvement. Best Loss : {:.4f}'.format(best_loss))
124 |
125 |
126 | if __name__ == '__main__':
127 | parser = argparse.ArgumentParser(description='Arguments for dalle training')
128 | parser.add_argument('--config', dest='config_path',
129 | default='config/default.yaml', type=str)
130 | args = parser.parse_args()
131 | train(args)
132 |
--------------------------------------------------------------------------------
/tools/train_dvae.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | import torch
4 | import cv2
5 | import random
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from model.discrete_vae import DiscreteVAE
10 | from torch.utils.data.dataloader import DataLoader
11 | from dataset.mnist_color_texture_dataset import MnistVisualLanguageDataset
12 | from torch.optim import Adam
13 | from torch.optim.lr_scheduler import ReduceLROnPlateau
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
18 | def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config):
19 | losses = []
20 | count = 0
21 | for data in tqdm(mnist_loader):
22 | # For vae we only need images
23 | im = data['image']
24 | im = im.float().to(device)
25 | optimizer.zero_grad()
26 |
27 | output, kl, log_qy = model(im)
28 | if config['train_params']['save_vae_training_image'] and count % 25 == 0:
29 | im_input = cv2.cvtColor((255 * (im.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0],
30 | cv2.COLOR_RGB2BGR)
31 | im_output = cv2.cvtColor((255 * (output.detach() + 1) / 2).cpu().permute((0, 2, 3, 1)).numpy()[0],
32 | cv2.COLOR_RGB2BGR)
33 | cv2.imwrite('{}/input.jpeg'.format(config['train_params']['task_name']), im_input)
34 | cv2.imwrite('{}/output.jpeg'.format(config['train_params']['task_name']), im_output)
35 |
36 | loss = (criterion(output, im) + config['train_params']['kl_weight']*kl)/(1+config['train_params']['kl_weight'])
37 | losses.append(loss.item())
38 | loss.backward()
39 | optimizer.step()
40 | count += 1
41 |
42 | print('Finished epoch: {} | Loss : {:.4f} '.
43 | format(epoch_idx + 1,
44 | np.mean(losses)))
45 | return np.mean(losses)
46 |
47 |
48 | def train(args):
49 | ######## Read the config file #######
50 | with open(args.config_path, 'r') as file:
51 | try:
52 | config = yaml.safe_load(file)
53 | except yaml.YAMLError as exc:
54 | print(exc)
55 | print(config)
56 | #######################################
57 |
58 | ######## Set the desired seed value #######
59 | seed = config['train_params']['seed']
60 | torch.manual_seed(seed)
61 | np.random.seed(seed)
62 | random.seed(seed)
63 | if device == 'cuda':
64 | torch.cuda.manual_seed_all(seed)
65 |
66 | if not os.path.exists(config['train_params']['task_name']):
67 | os.mkdir(config['train_params']['task_name'])
68 |
69 |
70 | #######################################
71 | # Create the model and dataset
72 | num_epochs = config['train_params']['num_epochs']
73 | model = DiscreteVAE(
74 | num_embeddings=config['model_params']['vae_num_embeddings'],
75 | embedding_dim=config['model_params']['vae_embedding_dim']
76 | )
77 | model.to(device)
78 |
79 | if os.path.exists('{}/{}'.format(config['train_params']['task_name'],
80 | config['train_params']['vae_ckpt_name'])):
81 | print('Found checkpoint... Starting training from that')
82 | model.load_state_dict(torch.load('{}/{}'.format(config['train_params']['task_name'],
83 | config['train_params']['vae_ckpt_name'])))
84 | mnist = MnistVisualLanguageDataset('train', config['dataset_params'])
85 | mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'],
86 | shuffle=True, num_workers=4)
87 |
88 | optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
89 | scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True)
90 | criterion = {
91 | 'l1': torch.nn.SmoothL1Loss(beta=0.1),
92 | 'l2': torch.nn.MSELoss()
93 | }.get(config['train_params']['crit'])
94 |
95 |
96 | if not os.path.exists(config['train_params']['task_name']):
97 | os.mkdir(config['train_params']['task_name'])
98 |
99 |
100 | best_loss = np.inf
101 | for epoch_idx in range(num_epochs):
102 | mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config)
103 | scheduler.step(mean_loss)
104 | # Simply update checkpoint if found better version
105 | if mean_loss < best_loss:
106 | print('Improved Loss from {:.4f} to {:.4f} .... Saving Model'.format(best_loss, mean_loss))
107 | torch.save(model.state_dict(), '{}/{}'.format(config['train_params']['task_name'],
108 | config['train_params']['vae_ckpt_name']))
109 | best_loss = mean_loss
110 | else:
111 | print('No Loss Improvement. Best Loss : {:.4f}'.format(best_loss))
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser(description='Arguments for vae training')
116 | parser.add_argument('--config', dest='config_path',
117 | default='config/default.yaml', type=str)
118 | args = parser.parse_args()
119 | train(args)
120 |
--------------------------------------------------------------------------------