├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── storydalle.iml └── vcs.xml ├── DEMO.MD ├── LICENSE ├── MODEL_CARD.MD ├── README.MD ├── assets ├── demo.pdf ├── demo.png ├── demo_pororo_good.png ├── pororo_characters.png ├── story_dalle.png └── story_dalle_predictions.png ├── mega-story-dalle ├── didemo_dataloader.py ├── flintstones_dataloader.py ├── min_dalle │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── min_dalle.cpython-38.pyc │ │ └── text_tokenizer.cpython-38.pyc │ ├── min_dalle.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dalle_bart_decoder.cpython-38.pyc │ │ │ ├── dalle_bart_encoder.cpython-38.pyc │ │ │ └── vqgan_detokenizer.cpython-38.pyc │ │ ├── dalle_bart_decoder.py │ │ ├── dalle_bart_encoder.py │ │ └── vqgan_detokenizer.py │ └── text_tokenizer.py ├── pororo_dataloader.py ├── setup.py ├── tkinter_ui.py ├── train_story.sh └── train_t2i.py └── story-dalle ├── 1.3B └── config.yaml ├── __init__.py ├── __pycache__ └── pororo_dataloader.cpython-38.pyc ├── dalle ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ └── trainer_prefix.cpython-38.pyc ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── prefix_tuning_model.cpython-38.pyc │ │ └── tokenizer.cpython-38.pyc │ ├── stage1 │ │ ├── __pycache__ │ │ │ ├── layers.cpython-37.pyc │ │ │ ├── layers.cpython-38.pyc │ │ │ ├── vqgan.cpython-37.pyc │ │ │ └── vqgan.cpython-38.pyc │ │ ├── layers.py │ │ └── vqgan.py │ ├── stage2 │ │ ├── __pycache__ │ │ │ ├── layers.cpython-37.pyc │ │ │ ├── layers.cpython-38.pyc │ │ │ ├── transformer.cpython-37.pyc │ │ │ └── transformer.cpython-38.pyc │ │ ├── layers.py │ │ └── transformer.py │ └── tokenizer.py ├── trainer_prefix.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── config.cpython-38.pyc │ ├── sampling.cpython-38.pyc │ ├── utils.cpython-37.pyc │ └── utils.cpython-38.pyc │ ├── config.py │ ├── sampling.py │ └── utils.py ├── didemo_dataloader.py ├── eval_char_clf.py ├── eval_char_clf.sh ├── eval_fid.py ├── eval_fid.sh ├── flintstones_dataloader.py ├── get_use_embeddings.py ├── infer_story.sh ├── infer_t2i.py ├── pororo_dataloader.py ├── train_story.sh ├── train_t2i.py ├── utils.py └── vfid ├── __pycache__ ├── fid_score.cpython-38.pyc └── inception.cpython-38.pyc ├── fid_score.py └── inception.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
-------------------------------------------------------------------------------- /.idea/storydalle.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /DEMO.MD: -------------------------------------------------------------------------------- 1 | ## Demo Coming Soon! 2 | (as soon as porting from Gradio Local Demo to HuggingFace Spaces is complete) 3 | Meanwhile, see the snapshot of our demo functionalities below! 4 | 5 | ![image](./assets/demo.png) 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Adyasha Maharana, Darryl Hannan and Mohit Bansal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MODEL_CARD.MD: -------------------------------------------------------------------------------- 1 | ### Model Description 2 | 3 | StoryDALL-E \[1\] is a model trained for the task of Story Visualization \[2\]. 4 | The model receives a sequence of captions as input and generates a corresponding sequence of images which form a visual story depicting the narrative in the captions. 5 | We modify this task to enable the model to receive an initial scene as input, which can be used as a cue for the setting of the story and also for generating unseen or low-resource visual elements. We refer to this task as Story Continuation \[1\]. 6 | StoryDALL-E is based on the [dalle](https://github.com/kakaobrain/minDALL-E) model. **This model has been developed for academic purposes only.** 7 | 8 | \[[paper](https://arxiv.org/abs/2209.06192)\] \[[code](https://github.com/adymaharana/storydalle/)\] \[[demo]()\] 9 | 10 | ### Dataset 11 | 12 | This model has been trained using the Pororo story visualization dataset \[2\]. 13 | The data was adapted from the popular cartoon series *Pororo the Little Penguin* and originally released by \[3\]. 14 | The Pororo dataset contains 9 recurring characters, as shown below, in the decreasing order of their frequency in the training data. 15 |

16 | ![image](./assets/pororo_characters.png) 17 |

18 | The training split contains nearly 10,000 samples. Most of the scenes occur in a snowy village, surrounded by hills, trees and houses. A few episodes are located in gardens or water bodies. All the captions are in English and predominantly contain verbs in the present tense. 19 |
20 | Additionally, the training of this model starts from the pretrained checkpoint of mega-dalle, which is trained on 15 million images from the Conceptual Captions dataset \[4\] that has been scraped and filtered from billions of webpages. 21 | 22 | ### Intended Use 23 | This model is intended for generating visual stories containing the 9 characters in the Pororo dataset. This version of the StoryDALL-E model performs reasonably well in the following scenarios: 24 | * Frames containing a single character. 25 | * Overtly visual actions such as *making cookies*, *walking*, *reading a book*, *sitting*. 26 | * Scenes taking place in snowy settings, indoors and gardens. 27 | * Visual stories containing 1-3 characters across all frames. 28 | * Scene transitions, e.g. from day to night. 29 | 30 | Here are some examples of generated visual stories for the above-mentioned settings. 31 |
![image](./assets/demo_pororo_good.png)
32 | 33 |

34 | 35 | Due to the small size of the training dataset for story visualization, the model generalizes poorly to some unseen settings. The model struggles to generate coherent images in the following scenarios: 36 | * Multiple characters in a frame. 37 | * Non-visual actions such as *compliment*. 38 | * Characters that are infrequent in the training dataset, e.g. *Rody*, *Harry*. 39 | * Background locations that are not found in the cartoon, e.g. *a busy city*. 40 | * Color-based descriptions for objects. 41 | * Completely new characters based on textual descriptions. 42 | 43 | In summary, we find that the model performs well at visualizing stories with up to three characters across all frames and struggles to generate coherent visuals for more than three characters. 44 | The model copies visual elements from the source image to each of the generated frames in the story, hence maintaining a continuous flow in the narration by virtue of conditioning on an initial scene. 45 | StoryDALL-E performs best at generating overtly visual actions and is capable of generating semantic concepts that do not appear in the story continuation dataset, such as *doughnut* and *lion*, by leveraging the pretrained knowledge of DALL-E Mega when possible. 46 | Most of the scenes in the Pororo dataset occur within the setting of a snowy village with wooden houses surrounded by trees and snow. Hence, the model usually generates scenes with similar visual elements. 47 | 48 | ### Ethical Considerations 49 | Our experimental results are specific to the task of story continuation. 50 | By using cartoon images in our task, we avoid the egregious ethical issues associated with real-world usage of image generation, such as DeepFakes. 51 | We focus not on generating realistic images, but on improved multi-modal understanding in the context of story visualization. 52 | 53 | ### Citation: 54 | ``` 55 | @inproceedings{maharana2022storydalle, 56 | title={StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation}, 57 | author={Maharana, Adyasha and Hannan, Darryl and Bansal, Mohit}, 58 | booktitle={ECCV}, 59 | year={2022} 60 | } 61 | ``` 62 | Send questions, feedback or comments to adyasha@cs.unc.edu. 63 | 64 | ### References 65 | 66 | \[1\] Maharana, Adyasha, et al. "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation." ECCV. 2022. 67 | 68 | \[2\] Li, Yitong, et al. "StoryGAN: A sequential conditional GAN for story visualization." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. 69 | 70 | \[3\] Kim, Kyung-Min, et al. "DeepStory: Video story QA by deep embedded memory networks." Proceedings of the 26th International Joint Conference on Artificial Intelligence. 2017. 71 | 72 | \[4\] Sharma, Piyush, et al. "Conceptual Captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018. 73 | 74 | \[5\] Mitchell, Margaret, et al. "Model cards for model reporting." Proceedings of the Conference on Fairness, Accountability, and Transparency. 2019.
75 | 76 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | ## StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation 2 | 3 | PyTorch code for the ECCV 2022 paper "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation". 4 | 5 | \[[Paper](https://arxiv.org/abs/2209.06192)\] \[[Model Card](https://github.com/adymaharana/storydalle/blob/main/MODEL_CARD.MD)\] \[[Spaces Demo](https://huggingface.co/spaces/ECCV2022/storydalle)\] \[[Replicate Demo](https://replicate.com/adymaharana/story-dalle)\] 6 | 7 | ![image](./assets/story_dalle_predictions.png) 8 | 9 | ![image](./assets/story_dalle.png) 10 | 11 | ### Training 12 | 13 | #### Prepare Repository: 14 | Download the PororoSV dataset and associated files from [here](https://drive.google.com/file/d/11Io1_BufAayJ1BpdxxV2uJUvCcirbrNc/view?usp=sharing) (updated) and save it as ```./data/pororo/```.
15 | Download the FlintstonesSV dataset and associated files from [here](https://drive.google.com/file/d/1kG4esNwabJQPWqadSDaugrlF4dRaV33_/view?usp=sharing) and save it as ```./data/flintstones```
16 | Download the DiDeMoSV dataset and associated files from [here](https://drive.google.com/file/d/1zgj_bpE6Woyi-G76axF0nO-yzQaLBayc/view?usp=sharing) and save it as ```./data/didemo```
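For reference, a minimal sketch of one way to unpack the downloads so that they land in the default locations listed above; the archive names are placeholders for whatever the Google Drive links save on your machine:

```bash
# Hypothetical archive names -- substitute the files actually downloaded above.
mkdir -p data
unzip pororo.zip      -d ./data/pororo
unzip flintstones.zip -d ./data/flintstones
unzip didemo.zip      -d ./data/didemo
```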
17 | 18 | This repository contains separate folders for training StoryDALL-E based on the [minDALL-E](https://github.com/kakaobrain/minDALL-E) and [DALL-E Mega](https://github.com/kuprel/min-dalle) models, i.e. the ```./story-dalle/``` and ```./mega-story-dalle/``` folders, respectively. 19 | 20 | #### Training StoryDALL-E based on minDALL-E: 21 | 22 | 1. To finetune the minDALL-E model for story continuation, first migrate to the corresponding folder:\ 23 | ```cd story-dalle```
24 | 2. Set the environment variables in ```train_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$LOG_DIR``` if different from the default locations. 25 | 3. Download the pretrained checkpoint from [here](https://github.com/kakaobrain/minDALL-E) and save it in ```./1.3B``` 26 | 4. Run the following command: 27 | ```bash train_story.sh ``` 28 | 29 | 30 | #### Training StoryDALL-E based on DALL-E Mega: 31 | 32 | 1. To finetune the DALL-E Mega model for story continuation, first migrate to the corresponding folder:\ 33 | ```cd mega-story-dalle```
34 | 2. Set the environment variables in ```train_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$LOG_DIR``` if different from the default locations. 35 | 3. Pretrained checkpoints for the generative model and the VQGAN detokenizer are automatically downloaded upon initialization. Download the pretrained weights for the VQGAN tokenizer from [here](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) and place them in the same folder as the VQGAN detokenizer. 36 | 4. Run the following command: 37 | ```bash train_story.sh ``` 38 | 39 | ### Inference 40 | Pretrained checkpoints for the minDALL-E-based StoryDALL-E can be downloaded from here: [PororoSV](https://drive.google.com/file/d/1lJ6zMZ6qTvFu6H35-VEdFlN13MMslivJ/view?usp=sharing)
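The shell scripts in both folders are driven by a small set of path variables. A hypothetical example of how they might be filled in for PororoSV inference is sketched below; all paths are placeholders, and the actual assignments to edit live in ```infer_story.sh``` as described in the steps that follow:

```bash
# Hypothetical values for the variables referenced in infer_story.sh;
# edit the script to match your local layout before running it.
DATA_DIR=./data/pororo                     # PororoSV data from "Prepare Repository"
OUTPUT_ROOT=./out/pororo                   # where generated stories are written
MODEL_CKPT=./ckpt/pororo_story_dalle.pth   # checkpoint downloaded from the link above
```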
41 | 42 | For a demo of inference using cog, check out [this repo](https://github.com/daanelson/story-dalle-cog). 43 | 44 | #### Inferring from StoryDALL-E based on minDALL-E: 45 | 46 | 1. To infer from the minDALL-E model for story continuation, first migrate to the corresponding folder:\ 47 | ```cd story-dalle```
48 | 2. Set the environment variables in ```infer_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$MODEL_CKPT``` if different from the default locations. 49 | 3. Run the following command: 50 | ```bash infer_story.sh ``` 51 | 52 | #### Memory Requirements for Inference: 53 | 54 | For double-precision inference, the StoryDALL-E model requires nearly 40 GB of memory. The memory requirement can be reduced to 20 GB by performing mixed-precision inference in the autoregressive decoder (included in the codebase; see line 1095 in story-dalle/dalle/models/__init__.py). Note that the VQGAN model needs to operate at full precision to retain the high quality of the generated images. 55 | 56 | 57 | ### Acknowledgements 58 | Thanks to the fantastic folks at Kakao Brain and HuggingFace for their work on the open-sourced versions of minDALL-E and DALL-E Mega. Much of this codebase has been adapted from [here](https://github.com/kakaobrain/minDALL-E) and [here](https://github.com/kuprel/min-dalle). -------------------------------------------------------------------------------- /assets/demo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo.pdf -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo.png -------------------------------------------------------------------------------- /assets/demo_pororo_good.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo_pororo_good.png -------------------------------------------------------------------------------- /assets/pororo_characters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/pororo_characters.png -------------------------------------------------------------------------------- /assets/story_dalle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/story_dalle.png -------------------------------------------------------------------------------- /assets/story_dalle_predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/story_dalle_predictions.png -------------------------------------------------------------------------------- /mega-story-dalle/flintstones_dataloader.py: -------------------------------------------------------------------------------- 1 | import os, pickle 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch.utils.data 5 | import PIL 6 | from random import randrange 7 | import json 8 | from torchvision import transforms 9 | from PIL import Image 10 | 11 | unique_characters = ["Wilma", "Fred", "Betty", "Barney", "Dino", "Pebbles", "Mr Slate"] 12 | 13 | class ImageDataset(torch.utils.data.Dataset): 14 | def __init__(self, dir_path,
tokenizer, preprocess, mode='train'): 15 | self.dir_path = dir_path 16 | 17 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r')) 18 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"] 19 | 20 | if mode == 'train': 21 | self.orders = train_id 22 | elif mode =='val': 23 | self.orders = val_id 24 | elif mode == 'test': 25 | self.orders = test_id 26 | else: 27 | raise ValueError 28 | print("Total number of clips {}".format(len(self.orders))) 29 | 30 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'))) 31 | self.descriptions = {} 32 | for sample in annotations: 33 | self.descriptions[sample["globalID"]] = sample["description"] 34 | 35 | self.preprocess = preprocess 36 | self.tokenizer = tokenizer 37 | 38 | def __getitem__(self, item): 39 | 40 | # single image input 41 | globalID = self.orders[item] 42 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy') 43 | arr = np.load(path) 44 | n_frames = arr.shape[0] 45 | random_range = randrange(n_frames) 46 | im = arr[random_range] 47 | image = np.array(im) 48 | image = PIL.Image.fromarray(image.astype('uint8'), 'RGB') 49 | text = self.descriptions[globalID] 50 | tokens = self.tokenizer.encode(text.lower()) 51 | tokens = torch.LongTensor(tokens.ids) 52 | image = self.preprocess(image) 53 | 54 | return image, tokens 55 | 56 | def __len__(self): 57 | return len(self.orders) 58 | 59 | 60 | class StoryImageDataset(torch.utils.data.Dataset): 61 | def __init__(self, dir_path, tokenizer, transform=None, mode='train', im_input_size=128, out_img_folder='', return_labels=False): 62 | self.dir_path = dir_path 63 | 64 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r')) 65 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"] 66 | 67 | min_len = 4 68 | if os.path.exists(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl')): 69 | self.followings = pickle.load(open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'rb')) 70 | else: 71 | print("Cache does not exist") 72 | all_clips = train_id + val_id + test_id 73 | all_clips.sort() 74 | for idx, clip in enumerate(tqdm(all_clips, desc="Counting total number of frames")): 75 | season, episode = int(clip.split('_')[1]), int(clip.split('_')[3]) 76 | has_frames = True 77 | for c in all_clips[idx+1:idx+min_len+1]: 78 | s_c, e_c = int(c.split('_')[1]), int(c.split('_')[3]) 79 | if s_c != season or e_c != episode: 80 | has_frames = False 81 | break 82 | if has_frames: 83 | self.followings[clip] = all_clips[idx+1:idx+min_len+1] 84 | else: 85 | continue 86 | pickle.dump(self.followings, open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'wb')) 87 | 88 | train_id = [tid for tid in train_id if tid in self.followings] 89 | val_id = [vid for vid in val_id if vid in self.followings] 90 | test_id = [tid for tid in test_id if tid in self.followings] 91 | 92 | self.labels = pickle.load(open(os.path.join(dir_path, 'labels.pkl'), 'rb')) 93 | 94 | if mode == 'train': 95 | self.orders = train_id 96 | elif mode =='val': 97 | val_id = [vid for vid in val_id if len(self.followings[vid]) == 4] 98 | self.orders = val_id 99 | elif mode == 'test': 100 | test_id = [vid for vid in test_id if len(self.followings[vid]) == 4] 101 | self.orders = test_id[:1900] 102 | else: 103 | raise ValueError 104 | print("Total number of clips {}".format(len(self.orders))) 105 | 106 | annotations = 
json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'))) 107 | self.descriptions = {} 108 | for sample in annotations: 109 | self.descriptions[sample["globalID"]] = sample["description"] 110 | 111 | self.embeds = np.load(os.path.join(self.dir_path, "flintstones_use_embeddings.npy")) 112 | self.sent2idx = pickle.load(open(os.path.join(self.dir_path, 'flintstones_use_embed_idxs.pkl'), 'rb')) 113 | 114 | if mode == 'train': 115 | if transform: 116 | self.transform = transform 117 | else: 118 | self.transform = transforms.Compose([ 119 | transforms.RandomResizedCrop(im_input_size), 120 | transforms.RandomHorizontalFlip(), 121 | transforms.ToTensor(), 122 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 123 | ]) 124 | else: 125 | if transform: 126 | self.transform = transform 127 | else: 128 | self.transform = transforms.Compose([ 129 | transforms.Resize(im_input_size), 130 | transforms.CenterCrop(im_input_size), 131 | transforms.ToTensor(), 132 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 133 | ]) 134 | 135 | self.tokenizer = tokenizer 136 | self.return_labels = return_labels 137 | self.out_img_folder = out_img_folder 138 | 139 | def __getitem__(self, item): 140 | 141 | # single image input 142 | globalIDs = [self.orders[item]] + self.followings[self.orders[item]] 143 | tokens = [] 144 | images = [] 145 | for idx, globalID in enumerate(globalIDs): 146 | if self.out_img_folder and idx != 0: 147 | image = Image.open(os.path.join(self.out_img_folder, 'gen_sample_%s_%s.png' % (item, idx-1))).convert('RGB') 148 | else: 149 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy') 150 | arr = np.load(path) 151 | n_frames = arr.shape[0] 152 | random_range = randrange(n_frames) 153 | im = arr[random_range] 154 | # image = np.array(im) 155 | image = PIL.Image.fromarray(im.astype('uint8'), 'RGB') 156 | images.append(image) 157 | text = self.descriptions[globalID] 158 | if idx != 0: 159 | if self.tokenizer is not None: 160 | tokens.append(self.tokenizer.encode(text.lower())) 161 | else: 162 | tokens.append(text) 163 | if self.tokenizer is not None: 164 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens]) 165 | 166 | sent_embeds = [torch.tensor(self.embeds[self.sent2idx[globalID]]) for globalID in globalIDs[1:]] 167 | 168 | if self.return_labels: 169 | labels = [torch.tensor(self.labels[globalID]) for globalID in globalIDs[1:]] 170 | return torch.stack([self.transform(im) for im in images[1:]]), torch.stack(labels), tokens, self.transform( 171 | images[0]), torch.stack(sent_embeds) 172 | else: 173 | return torch.stack([self.transform(im) for im in images[1:]]), tokens, self.transform(images[0]), torch.stack(sent_embeds) 174 | 175 | def __len__(self): 176 | return len(self.orders) 177 | 178 | 179 | # if __name__ == "__main__": 180 | # 181 | # dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/flintstones', None, None, 'val') 182 | # for item in range(len(dataset)): 183 | # texts = [] 184 | # globalIDs = [dataset.orders[item]] + dataset.followings[dataset.orders[item]] 185 | # for idx, globalID in enumerate(globalIDs): 186 | # text = dataset.descriptions[globalID] 187 | # texts.append(text) 188 | # if len(texts) != 5: 189 | # print(item, globalIDs) 190 | 191 | -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/__init__.py: -------------------------------------------------------------------------------- 1 | from .min_dalle import 
MinDalle, PipelineParallelMinDalle -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/__pycache__/min_dalle.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/min_dalle.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/__pycache__/text_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/text_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dalle_bart_encoder import DalleBartEncoder 2 | from .dalle_bart_decoder import DalleBartDecoder 3 | from .vqgan_detokenizer import VQGanDetokenizer, VQGanTokenizer -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/__pycache__/vqgan_detokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/vqgan_detokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/dalle_bart_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import torch 3 | from torch 
import nn, LongTensor, FloatTensor, BoolTensor 4 | from .dalle_bart_encoder import GLU, AttentionBase 5 | 6 | IMAGE_TOKEN_COUNT = 256 7 | 8 | 9 | class DecoderCrossAttention(AttentionBase): 10 | def forward( 11 | self, 12 | decoder_state: FloatTensor, 13 | encoder_state: FloatTensor, 14 | attention_mask: BoolTensor 15 | ) -> FloatTensor: 16 | keys = self.k_proj.forward(encoder_state) 17 | values = self.v_proj.forward(encoder_state) 18 | queries = self.q_proj.forward(decoder_state) 19 | return super().forward(keys, values, queries, attention_mask) 20 | 21 | 22 | class DecoderSelfAttention(AttentionBase): 23 | def __init__(self, head_count: int, embed_count: int): 24 | super().__init__(head_count, embed_count) 25 | 26 | def forward( 27 | self, 28 | decoder_state: FloatTensor, 29 | attention_state: FloatTensor, 30 | attention_mask: BoolTensor, 31 | token_index: LongTensor = None 32 | ) -> Tuple[FloatTensor, FloatTensor]: 33 | keys = self.k_proj.forward(decoder_state) 34 | values = self.v_proj.forward(decoder_state) 35 | queries = self.q_proj.forward(decoder_state) 36 | 37 | # TODO: refer to the cache process in minDALLE to fix this 38 | if token_index is not None: 39 | token_count = token_index.shape[1] 40 | if token_count == 1: 41 | batch_count = decoder_state.shape[0] 42 | attn_state_new = torch.cat([keys, values]).to(attention_state.dtype) 43 | attention_state[:, token_index[0]] = attn_state_new 44 | keys = attention_state[:batch_count] 45 | values = attention_state[batch_count:] 46 | 47 | decoder_state = super().forward(keys, values, queries, attention_mask) 48 | return decoder_state, attention_state 49 | 50 | 51 | class DecoderLayer(nn.Module): 52 | def __init__( 53 | self, 54 | head_count: int, 55 | embed_count: int, 56 | glu_embed_count: int, 57 | device: str, 58 | condition: bool = False 59 | ): 60 | super().__init__() 61 | self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) 62 | self.self_attn = DecoderSelfAttention(head_count, embed_count) 63 | self.self_attn_layer_norm = nn.LayerNorm(embed_count) 64 | self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) 65 | self.encoder_attn = DecoderCrossAttention(head_count, embed_count) 66 | self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) 67 | self.glu = GLU(embed_count, glu_embed_count) 68 | self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device) 69 | self.head_count = head_count 70 | 71 | self.condition = condition 72 | if condition: 73 | self.pre_condition_attn_layer_norm = nn.LayerNorm(embed_count) 74 | self.condition_attn = DecoderCrossAttention(head_count, embed_count) 75 | self.condition_attn_layer_norm = nn.LayerNorm(embed_count) 76 | 77 | def sample( 78 | self, 79 | decoder_state: FloatTensor, 80 | encoder_state: FloatTensor, 81 | attention_state: FloatTensor, 82 | attention_mask: BoolTensor, 83 | token_index: LongTensor, 84 | condition_state: FloatTensor = None 85 | ) -> Tuple[FloatTensor, FloatTensor]: 86 | # Self Attention 87 | token_count = token_index.shape[1] 88 | if token_count == 1: 89 | # print(self.token_indices.device, token_index.device) 90 | self_attn_mask = self.token_indices <= token_index 91 | self_attn_mask = self_attn_mask[:, None, None, :] 92 | else: 93 | self_attn_mask = ( 94 | self.token_indices[None, None, :token_count] <= 95 | token_index[:, :, None] 96 | ) 97 | self_attn_mask = self_attn_mask[:, None, :, :] 98 | 99 | # TODO: Fix self-attention mask 100 | 101 | residual = decoder_state 102 | decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) 103 | 
decoder_state, attention_state = self.self_attn.forward( 104 | decoder_state=decoder_state, 105 | attention_state=attention_state, 106 | attention_mask=self_attn_mask, 107 | token_index=token_index 108 | ) 109 | decoder_state = self.self_attn_layer_norm.forward(decoder_state) 110 | decoder_state = residual + decoder_state 111 | 112 | # Cross Attention 113 | residual = decoder_state 114 | decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state) 115 | decoder_state = self.encoder_attn.forward( 116 | decoder_state=decoder_state, 117 | encoder_state=encoder_state, 118 | attention_mask=attention_mask 119 | ) 120 | decoder_state = self.encoder_attn_layer_norm.forward(decoder_state) 121 | decoder_state = residual + decoder_state 122 | 123 | # Cross-Attention Over Image Condition 124 | if self.condition: 125 | assert condition_state is not None 126 | residual = decoder_state 127 | decoder_state = self.pre_condition_attn_layer_norm.forward(decoder_state) 128 | decoder_state = self.condition_attn.forward( 129 | decoder_state=decoder_state, 130 | encoder_state=encoder_state, 131 | attention_mask=attention_mask 132 | ) 133 | decoder_state = self.condition_attn_layer_norm.forward(decoder_state) 134 | decoder_state = residual + decoder_state 135 | 136 | # Feed forward 137 | residual = decoder_state 138 | decoder_state = self.glu.forward(decoder_state) 139 | decoder_state = residual + decoder_state 140 | 141 | return decoder_state, attention_state 142 | 143 | 144 | def forward( 145 | self, 146 | decoder_state: FloatTensor, 147 | encoder_state: FloatTensor, 148 | attention_state: FloatTensor, 149 | attention_mask: BoolTensor, 150 | condition_state: FloatTensor = None 151 | ) -> Tuple[FloatTensor, FloatTensor]: 152 | # Self Attention 153 | # token_count = token_index.shape[1] 154 | # if token_count == 1: 155 | # self_attn_mask = self.token_indices <= token_index 156 | # self_attn_mask = self_attn_mask[:, None, None, :] 157 | # else: 158 | # self_attn_mask = ( 159 | # self.token_indices[None, None, :token_count] <= 160 | # token_index[:, :, None] 161 | # ) 162 | # self_attn_mask = self_attn_mask[:, None, :, :] 163 | 164 | # TODO: Fix self-attention mask 165 | B, N = decoder_state.shape[:2] 166 | self_attn_mask = torch.tril(torch.ones(size=(N, N), device=decoder_state.device)).view(1, 1, N, N).repeat(B, self.head_count, 1, 1) 167 | # print("Self-attention mask shape: ", self_attn_mask.shape) 168 | 169 | residual = decoder_state 170 | decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) 171 | decoder_state, attention_state = self.self_attn.forward( 172 | decoder_state=decoder_state, 173 | attention_state=attention_state, 174 | attention_mask=self_attn_mask, 175 | # token_index=token_index 176 | ) 177 | decoder_state = self.self_attn_layer_norm.forward(decoder_state) 178 | decoder_state = residual + decoder_state 179 | 180 | # Cross Attention 181 | residual = decoder_state 182 | decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state) 183 | decoder_state = self.encoder_attn.forward( 184 | decoder_state=decoder_state, 185 | encoder_state=encoder_state, 186 | attention_mask=attention_mask 187 | ) 188 | decoder_state = self.encoder_attn_layer_norm.forward(decoder_state) 189 | decoder_state = residual + decoder_state 190 | 191 | # Cross-Attention Over Image Condition 192 | if self.condition: 193 | assert condition_state is not None 194 | residual = decoder_state 195 | decoder_state = self.pre_condition_attn_layer_norm.forward(decoder_state) 196 | decoder_state = 
self.condition_attn.forward( 197 | decoder_state=decoder_state, 198 | encoder_state=encoder_state, 199 | attention_mask=attention_mask 200 | ) 201 | decoder_state = self.condition_attn_layer_norm.forward(decoder_state) 202 | decoder_state = residual + decoder_state 203 | 204 | # Feed forward 205 | residual = decoder_state 206 | decoder_state = self.glu.forward(decoder_state) 207 | decoder_state = residual + decoder_state 208 | 209 | return decoder_state, attention_state 210 | 211 | 212 | class DalleBartDecoder(nn.Module): 213 | def __init__( 214 | self, 215 | image_vocab_count: int, 216 | embed_count: int, 217 | attention_head_count: int, 218 | glu_embed_count: int, 219 | layer_count: int, 220 | device: str, 221 | condition: bool = False 222 | ): 223 | super().__init__() 224 | self.layer_count = layer_count 225 | self.embed_count = embed_count 226 | self.image_vocab_count = image_vocab_count 227 | self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) 228 | self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count) 229 | self.layers: List[DecoderLayer] = nn.ModuleList([ 230 | DecoderLayer( 231 | head_count=attention_head_count, 232 | embed_count=embed_count, 233 | glu_embed_count=glu_embed_count, 234 | device=device, 235 | condition = (i+1)%3 == 0 if condition else False 236 | ) 237 | for i in range(layer_count) 238 | ]) 239 | self.condition = condition 240 | self.layernorm_embedding = nn.LayerNorm(embed_count) 241 | self.final_ln = nn.LayerNorm(embed_count) 242 | self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False) 243 | self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device) 244 | 245 | if self.condition: 246 | print("Initialized %s condition attention layers" % sum([(i+1)%3 == 0 for i in range(layer_count)])) 247 | 248 | 249 | def forward( 250 | self, 251 | attention_mask: BoolTensor, 252 | encoder_state: FloatTensor, 253 | attention_state: FloatTensor, 254 | prev_tokens: LongTensor, 255 | condition_state: FloatTensor = None 256 | ) -> Tuple[FloatTensor, FloatTensor]: 257 | decoder_state = self.embed_tokens.forward(prev_tokens) 258 | B, N = prev_tokens.shape 259 | pos_enc_tokens = torch.arange(N, device=prev_tokens.device).repeat((B, 1)) 260 | decoder_state += self.embed_positions.forward(pos_enc_tokens) 261 | decoder_state = self.layernorm_embedding.forward(decoder_state) 262 | 263 | if condition_state is not None: 264 | condition_state = self.embed_tokens.forward(condition_state) 265 | B_c, N_c = condition_state.shape[:2] 266 | pos_enc_tokens = torch.arange(N_c, device=condition_state.device).repeat((B_c, 1)) 267 | # print(condition_state.shape, pos_enc_tokens.shape) 268 | condition_state += self.embed_positions.forward(pos_enc_tokens) 269 | # print(condition_state.shape) 270 | condition_state = condition_state.repeat_interleave(int(B/B_c), dim=0) 271 | # print(condition_state.shape) 272 | 273 | for i in range(self.layer_count): 274 | decoder_state, attention_state[i] = self.layers[i].forward( 275 | decoder_state, 276 | encoder_state, 277 | attention_state[i], 278 | attention_mask, 279 | condition_state=condition_state if self.condition and (i+1)%3 == 0 else None 280 | ) 281 | decoder_state = self.final_ln(decoder_state) 282 | logits = self.lm_head(decoder_state) 283 | return logits, attention_state 284 | 285 | def sample( 286 | self, 287 | attention_mask: BoolTensor, 288 | encoder_state: FloatTensor, 289 | attention_state: FloatTensor, 290 | prev_tokens: LongTensor, 291 | token_index: LongTensor, 292 | condition_state: 
FloatTensor = None, 293 | supercondition: bool = False 294 | ) -> Tuple[FloatTensor, FloatTensor]: 295 | image_count = encoder_state.shape[0] // 2 296 | token_index = token_index.unsqueeze(0).repeat(image_count * 2, 1) 297 | if supercondition: 298 | prev_tokens = prev_tokens.repeat(2, 1) 299 | decoder_state = self.embed_tokens.forward(prev_tokens) 300 | decoder_state += self.embed_positions.forward(token_index) 301 | decoder_state = self.layernorm_embedding.forward(decoder_state) 302 | for i in range(self.layer_count): 303 | decoder_state, attention_state[i] = self.layers[i].sample( 304 | decoder_state, 305 | encoder_state, 306 | attention_state[i], 307 | attention_mask, 308 | token_index, 309 | condition_state=condition_state if self.condition and (i + 1) % 3 == 0 else None 310 | ) 311 | decoder_state = self.final_ln(decoder_state) 312 | logits = self.lm_head(decoder_state) 313 | return logits, attention_state 314 | 315 | 316 | def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]: 317 | logits, attention_state = self.sample(supercondition=settings[2] != 1, **kwargs) 318 | image_count = logits.shape[0] // 2 319 | temperature = settings[[0]] 320 | top_k = settings[[1]].to(torch.long) 321 | supercondition_factor = settings[[2]] 322 | 323 | logits = logits[:, -1, : 2 ** 14] 324 | if supercondition_factor != 1: 325 | logits: FloatTensor = ( 326 | logits[:image_count] * (1 - supercondition_factor) + 327 | logits[image_count:] * supercondition_factor 328 | ) 329 | else: 330 | # logits: FloatTensor = ( 331 | # logits[:image_count] * 0 + 332 | # logits[image_count:] * 1 333 | # ) 334 | # print(logits.shape) 335 | pass 336 | 337 | # print(logits.shape) 338 | logits_sorted, _ = logits.sort(descending=True) 339 | # print(logits_sorted.shape) 340 | is_kept = logits >= logits_sorted[:, top_k - 1] 341 | if len(is_kept.shape) == 3: 342 | is_kept = logits >= logits_sorted[:, [top_k - 1]] 343 | assert len(is_kept.shape) == 2 344 | # print(logits_sorted[:, [0]]) 345 | logits -= logits_sorted[:, [0]] 346 | # print(logits.shape) 347 | logits /= temperature 348 | # print(logits.shape, temperature) 349 | logits.exp_() 350 | logits *= is_kept.to(torch.float32) 351 | image_tokens = torch.multinomial(logits, 1)[:, 0] 352 | return image_tokens, attention_state -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/dalle_bart_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | from torch import nn, BoolTensor, FloatTensor, LongTensor 4 | 5 | 6 | class GLU(nn.Module): 7 | def __init__(self, count_in_out: int, count_middle: int): 8 | super().__init__() 9 | self.gelu = nn.GELU() 10 | self.ln0 = nn.LayerNorm(count_in_out) 11 | self.ln1 = nn.LayerNorm(count_middle) 12 | self.fc0 = nn.Linear(count_in_out, count_middle, bias=False) 13 | self.fc1 = nn.Linear(count_in_out, count_middle, bias=False) 14 | self.fc2 = nn.Linear(count_middle, count_in_out, bias=False) 15 | 16 | def forward(self, z: FloatTensor) -> FloatTensor: 17 | z = self.ln0.forward(z) 18 | w = self.fc0.forward(z) 19 | w = self.gelu.forward(w) 20 | v = self.fc1.forward(z) 21 | z = self.ln1.forward(w * v) 22 | z = self.fc2.forward(z) 23 | return z 24 | 25 | 26 | class AttentionBase(nn.Module): 27 | def __init__(self, head_count: int, embed_count: int): 28 | super().__init__() 29 | self.head_count = head_count 30 | self.embed_count = embed_count 31 | 32 | self.k_proj = 
nn.Linear(embed_count, embed_count, bias=False) 33 | self.v_proj = nn.Linear(embed_count, embed_count, bias=False) 34 | self.q_proj = nn.Linear(embed_count, embed_count, bias=False) 35 | self.out_proj = nn.Linear(embed_count, embed_count, bias=False) 36 | 37 | def forward( 38 | self, 39 | keys: FloatTensor, 40 | values: FloatTensor, 41 | queries: FloatTensor, 42 | attention_mask: BoolTensor 43 | ) -> FloatTensor: 44 | keys = keys.reshape(keys.shape[:2] + (self.head_count, -1)) 45 | values = values.reshape(values.shape[:2] + (self.head_count, -1)) 46 | queries = queries.reshape(queries.shape[:2] + (self.head_count, -1)) 47 | queries /= queries.shape[-1] ** 0.5 48 | attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12 49 | attention_weights: FloatTensor = torch.einsum( 50 | 'bqhc,bkhc->bhqk', 51 | queries, 52 | keys 53 | ) 54 | 55 | # print(keys.shape, values.shape, queries.shape) 56 | # print("Attention weights shape: ", attention_weights.shape) 57 | # print("Attention bias shape: ", attention_bias.shape) 58 | 59 | attention_weights += attention_bias 60 | attention_weights = torch.softmax(attention_weights, -1) 61 | attention_output: FloatTensor = torch.einsum( 62 | "bhqk,bkhc->bqhc", 63 | attention_weights, 64 | values 65 | ) 66 | shape = attention_output.shape[:2] + (self.embed_count,) 67 | attention_output = attention_output.reshape(shape) 68 | attention_output = self.out_proj.forward(attention_output) 69 | return attention_output 70 | 71 | 72 | class EncoderSelfAttention(AttentionBase): 73 | def forward( 74 | self, 75 | encoder_state: FloatTensor, 76 | attention_mask: BoolTensor 77 | ) -> FloatTensor: 78 | keys = self.k_proj.forward(encoder_state) 79 | values = self.v_proj.forward(encoder_state) 80 | queries = self.q_proj.forward(encoder_state) 81 | return super().forward(keys, values, queries, attention_mask) 82 | 83 | 84 | class EncoderLayer(nn.Module): 85 | def __init__(self, embed_count: int, head_count: int, glu_embed_count: int): 86 | super().__init__() 87 | self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) 88 | self.self_attn = EncoderSelfAttention(head_count, embed_count) 89 | self.self_attn_layer_norm = nn.LayerNorm(embed_count) 90 | self.glu = GLU(embed_count, glu_embed_count) 91 | 92 | def forward( 93 | self, 94 | encoder_state: FloatTensor, 95 | attention_mask: BoolTensor 96 | ) -> FloatTensor: 97 | residual = encoder_state 98 | encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state) 99 | encoder_state = self.self_attn.forward(encoder_state, attention_mask) 100 | encoder_state = self.self_attn_layer_norm.forward(encoder_state) 101 | encoder_state = residual + encoder_state 102 | residual = encoder_state 103 | encoder_state = self.glu.forward(encoder_state) 104 | encoder_state = residual + encoder_state 105 | return encoder_state 106 | 107 | 108 | class DalleBartEncoder(nn.Module): 109 | def __init__( 110 | self, 111 | layer_count: int, 112 | embed_count: int, 113 | attention_head_count: int, 114 | text_vocab_count: int, 115 | text_token_count: int, 116 | glu_embed_count: int, 117 | device: str 118 | ): 119 | super().__init__() 120 | self.text_vocab_count = text_vocab_count 121 | self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) 122 | self.embed_positions = nn.Embedding(text_token_count, embed_count) 123 | self.layers: List[EncoderLayer] = nn.ModuleList([ 124 | EncoderLayer( 125 | embed_count = embed_count, 126 | head_count = attention_head_count, 127 | glu_embed_count = glu_embed_count 128 | ) 129 | for _ in range(layer_count) 
130 | ]) 131 | self.layernorm_embedding = nn.LayerNorm(embed_count) 132 | self.final_ln = nn.LayerNorm(embed_count) 133 | token_indices = torch.arange(text_token_count, device=device) 134 | self.pose_tokens = torch.stack([token_indices] * 2) 135 | 136 | 137 | def _init_weights(self, module: nn.Module) -> None: 138 | if isinstance(module, (nn.Linear, nn.Embedding)): 139 | module.weight.data.normal_(mean=0.0, std=0.02) 140 | if isinstance(module, nn.Linear) and module.bias is not None: 141 | module.bias.data.zero_() 142 | elif isinstance(module, nn.LayerNorm): 143 | module.bias.data.zero_() 144 | module.weight.data.fill_(1.0) 145 | 146 | 147 | def resize_token_embeddings(self, new_num_tokens): 148 | 149 | old_num_tokens, old_embedding_dim = self.embed_tokens.weight.size() 150 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) 151 | new_embeddings.to(self.embed_tokens.weight.device, dtype=self.embed_tokens.weight.dtype) 152 | self._init_weights(new_embeddings) 153 | # numbers of tokens to copy 154 | n = min(old_num_tokens, new_num_tokens) 155 | new_embeddings.weight.data[:n, :] = self.embed_tokens.weight.data[:n, :] 156 | self.embed_tokens = new_embeddings 157 | 158 | return new_embeddings 159 | 160 | 161 | def forward(self, text_tokens: LongTensor) -> FloatTensor: 162 | attention_mask = text_tokens.not_equal(1)[:, None, None, :] 163 | # encoder_state = ( 164 | # self.embed_tokens.forward(text_tokens) + 165 | # self.embed_positions.forward(self.pose_tokens) 166 | # ) 167 | # print(text_tokens.shape) 168 | B, L = text_tokens.shape 169 | encoder_state = ( 170 | self.embed_tokens.forward(text_tokens) + 171 | self.embed_positions.forward(torch.stack([torch.arange(L, device=text_tokens.device)]*B)) 172 | ) 173 | # print(encoder_state.requires_grad, encoder_state.grad_fn) 174 | encoder_state = self.layernorm_embedding.forward(encoder_state) 175 | # print(encoder_state.requires_grad, encoder_state.grad_fn) 176 | for layer in self.layers: 177 | encoder_state = layer.forward(encoder_state, attention_mask) 178 | # print(encoder_state.requires_grad, encoder_state.grad_fn) 179 | encoder_state = self.final_ln.forward(encoder_state) 180 | # print(encoder_state.requires_grad, encoder_state.grad_fn) 181 | return encoder_state -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/models/vqgan_detokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import FloatTensor, LongTensor 4 | from math import sqrt 5 | from einops import rearrange 6 | 7 | class ResnetBlock(nn.Module): 8 | def __init__(self, log2_count_in: int, log2_count_out: int): 9 | super().__init__() 10 | m, n = 2 ** log2_count_in, 2 ** log2_count_out 11 | self.is_middle = m == n 12 | self.norm1 = nn.GroupNorm(2 ** 5, m) 13 | self.conv1 = nn.Conv2d(m, n, 3, padding=1) 14 | self.norm2 = nn.GroupNorm(2 ** 5, n) 15 | self.conv2 = nn.Conv2d(n, n, 3, padding=1) 16 | if not self.is_middle: 17 | self.nin_shortcut = nn.Conv2d(m, n, 1) 18 | 19 | def forward(self, x: FloatTensor) -> FloatTensor: 20 | h = x 21 | h = self.norm1.forward(h) 22 | h *= torch.sigmoid(h) 23 | h = self.conv1.forward(h) 24 | h = self.norm2.forward(h) 25 | h *= torch.sigmoid(h) 26 | h = self.conv2(h) 27 | if not self.is_middle: 28 | x = self.nin_shortcut.forward(x) 29 | return x + h 30 | 31 | 32 | class AttentionBlock(nn.Module): 33 | def __init__(self): 34 | super().__init__() 35 | n = 2 ** 9 36 | self.norm = 
nn.GroupNorm(2 ** 5, n) 37 | self.q = nn.Conv2d(n, n, 1) 38 | self.k = nn.Conv2d(n, n, 1) 39 | self.v = nn.Conv2d(n, n, 1) 40 | self.proj_out = nn.Conv2d(n, n, 1) 41 | 42 | def forward(self, x: FloatTensor) -> FloatTensor: 43 | n, m = 2 ** 9, x.shape[0] 44 | h = x 45 | h = self.norm(h) 46 | k = self.k.forward(h) 47 | v = self.v.forward(h) 48 | q = self.q.forward(h) 49 | k = k.reshape(m, n, -1) 50 | v = v.reshape(m, n, -1) 51 | q = q.reshape(m, n, -1) 52 | q = q.permute(0, 2, 1) 53 | w = torch.bmm(q, k) 54 | w /= n ** 0.5 55 | w = torch.softmax(w, dim=2) 56 | w = w.permute(0, 2, 1) 57 | h = torch.bmm(v, w) 58 | token_count = int(sqrt(h.shape[-1])) 59 | h = h.reshape(m, n, token_count, token_count) 60 | h = self.proj_out.forward(h) 61 | return x + h 62 | 63 | 64 | class MiddleLayer(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | self.block_1 = ResnetBlock(9, 9) 68 | self.attn_1 = AttentionBlock() 69 | self.block_2 = ResnetBlock(9, 9) 70 | 71 | def forward(self, h: FloatTensor) -> FloatTensor: 72 | h = self.block_1.forward(h) 73 | h = self.attn_1.forward(h) 74 | h = self.block_2.forward(h) 75 | return h 76 | 77 | 78 | class Upsample(nn.Module): 79 | def __init__(self, log2_count): 80 | super().__init__() 81 | n = 2 ** log2_count 82 | self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) 83 | self.conv = nn.Conv2d(n, n, 3, padding=1) 84 | 85 | def forward(self, x: FloatTensor) -> FloatTensor: 86 | x = self.upsample.forward(x.to(torch.float32)) 87 | x = self.conv.forward(x) 88 | return x 89 | 90 | 91 | class Downsample(nn.Module): 92 | def __init__(self, log2_count): 93 | super().__init__() 94 | n = 2 ** log2_count 95 | self.conv = torch.nn.Conv2d(n, n, kernel_size=3, stride=2, padding=0) 96 | 97 | def forward(self, x): 98 | pad = (0,1,0,1) 99 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 100 | x = self.conv(x) 101 | return x 102 | 103 | 104 | class DownsampleBlock(nn.Module): 105 | def __init__( 106 | self, 107 | log2_count_in: int, 108 | log2_count_out: int, 109 | has_attention: bool, 110 | has_downsample: bool 111 | ): 112 | super().__init__() 113 | self.has_attention = has_attention 114 | self.has_downsample = has_downsample 115 | 116 | self.block = nn.ModuleList([ 117 | ResnetBlock(log2_count_in, log2_count_out), 118 | ResnetBlock(log2_count_out, log2_count_out), 119 | ]) 120 | 121 | if has_attention: 122 | self.attn = nn.ModuleList([ 123 | AttentionBlock(), 124 | AttentionBlock(), 125 | ]) 126 | 127 | if has_downsample: 128 | self.downsample = Downsample(log2_count_out) 129 | 130 | def forward(self, h: FloatTensor) -> FloatTensor: 131 | for j in range(2): 132 | h = self.block[j].forward(h) 133 | if self.has_attention: 134 | h = self.attn[j].forward(h) 135 | if self.has_downsample: 136 | h = self.downsample.forward(h) 137 | return h 138 | 139 | 140 | 141 | class UpsampleBlock(nn.Module): 142 | def __init__( 143 | self, 144 | log2_count_in: int, 145 | log2_count_out: int, 146 | has_attention: bool, 147 | has_upsample: bool 148 | ): 149 | super().__init__() 150 | self.has_attention = has_attention 151 | self.has_upsample = has_upsample 152 | 153 | self.block = nn.ModuleList([ 154 | ResnetBlock(log2_count_in, log2_count_out), 155 | ResnetBlock(log2_count_out, log2_count_out), 156 | ResnetBlock(log2_count_out, log2_count_out), 157 | ]) 158 | 159 | if has_attention: 160 | self.attn = nn.ModuleList([ 161 | AttentionBlock(), 162 | AttentionBlock(), 163 | AttentionBlock() 164 | ]) 165 | 166 | if has_upsample: 167 | self.upsample = Upsample(log2_count_out) 
168 | 169 | 170 | def forward(self, h: FloatTensor) -> FloatTensor: 171 | for j in range(3): 172 | h = self.block[j].forward(h) 173 | if self.has_attention: 174 | h = self.attn[j].forward(h) 175 | if self.has_upsample: 176 | h = self.upsample.forward(h) 177 | return h 178 | 179 | 180 | class Encoder(nn.Module): 181 | def __init__(self): 182 | super().__init__() 183 | 184 | # downsampling 185 | self.conv_in = torch.nn.Conv2d(3, 2 ** 7, 3, stride=1, padding=1) 186 | 187 | self.down = nn.ModuleList([ 188 | DownsampleBlock(7, 7, False, True), 189 | DownsampleBlock(7, 7, False, True), 190 | DownsampleBlock(7, 8, False, True), 191 | DownsampleBlock(8, 8, False, True), 192 | DownsampleBlock(8, 9, True, False), 193 | ]) 194 | 195 | # middle 196 | self.mid = MiddleLayer() 197 | 198 | # end 199 | self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 9) 200 | self.conv_out = nn.Conv2d(2 ** 9, 2 ** 8, 3, padding=1) 201 | 202 | def forward(self, z: FloatTensor) -> FloatTensor: 203 | z = self.conv_in.forward(z) 204 | for i in range(5): 205 | z = self.down[i].forward(z) 206 | z = self.mid.forward(z) 207 | 208 | z = self.norm_out.forward(z) 209 | z *= torch.sigmoid(z) 210 | z = self.conv_out.forward(z) 211 | return z 212 | 213 | 214 | class Decoder(nn.Module): 215 | def __init__(self): 216 | super().__init__() 217 | 218 | self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1) 219 | self.mid = MiddleLayer() 220 | 221 | self.up = nn.ModuleList([ 222 | UpsampleBlock(7, 7, False, False), 223 | UpsampleBlock(8, 7, False, True), 224 | UpsampleBlock(8, 8, False, True), 225 | UpsampleBlock(9, 8, False, True), 226 | UpsampleBlock(9, 9, True, True) 227 | ]) 228 | 229 | self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7) 230 | self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1) 231 | 232 | def forward(self, z: FloatTensor) -> FloatTensor: 233 | z = self.conv_in.forward(z) 234 | z = self.mid.forward(z) 235 | 236 | for i in reversed(range(5)): 237 | z = self.up[i].forward(z) 238 | 239 | z = self.norm_out.forward(z) 240 | z *= torch.sigmoid(z) 241 | z = self.conv_out.forward(z) 242 | return z 243 | 244 | class VectorQuantizer2(nn.Module): 245 | """ 246 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 247 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 248 | """ 249 | # NOTE: due to a bug the beta term was applied to the wrong term. for 250 | # backwards compatibility we use the buggy version by default, but you can 251 | # specify legacy=False to fix it. 
252 | def __init__(self, n_e=16384, e_dim=256, beta=0.25, 253 | sane_index_shape=False, legacy=True): 254 | super().__init__() 255 | self.n_e = n_e 256 | self.e_dim = e_dim 257 | self.beta = beta 258 | self.legacy = legacy 259 | 260 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 261 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 262 | 263 | self.re_embed = n_e 264 | self.remap = None 265 | 266 | self.sane_index_shape = sane_index_shape 267 | 268 | def remap_to_used(self, inds): 269 | ishape = inds.shape 270 | assert len(ishape)>1 271 | inds = inds.reshape(ishape[0],-1) 272 | used = self.used.to(inds) 273 | match = (inds[:,:,None]==used[None,None,...]).long() 274 | new = match.argmax(-1) 275 | unknown = match.sum(2)<1 276 | if self.unknown_index == "random": 277 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 278 | else: 279 | new[unknown] = self.unknown_index 280 | return new.reshape(ishape) 281 | 282 | def unmap_to_all(self, inds): 283 | ishape = inds.shape 284 | assert len(ishape)>1 285 | inds = inds.reshape(ishape[0],-1) 286 | used = self.used.to(inds) 287 | if self.re_embed > self.used.shape[0]: # extra token 288 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 289 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 290 | return back.reshape(ishape) 291 | 292 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 293 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 294 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 295 | assert return_logits==False, "Only for interface compatible with Gumbel" 296 | # reshape z -> (batch, height, width, channel) and flatten 297 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 298 | z_flattened = z.view(-1, self.e_dim) 299 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 300 | 301 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 302 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 303 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 304 | 305 | min_encoding_indices = torch.argmin(d, dim=1) 306 | z_q = self.embedding(min_encoding_indices).view(z.shape) 307 | perplexity = None 308 | min_encodings = None 309 | 310 | # compute loss for embedding 311 | if not self.legacy: 312 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 313 | torch.mean((z_q - z.detach()) ** 2) 314 | else: 315 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 316 | torch.mean((z_q - z.detach()) ** 2) 317 | 318 | # preserve gradients 319 | z_q = z + (z_q - z).detach() 320 | 321 | # reshape back to match original input shape 322 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 323 | 324 | if self.remap is not None: 325 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 326 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 327 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 328 | 329 | if self.sane_index_shape: 330 | min_encoding_indices = min_encoding_indices.reshape( 331 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 332 | 333 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 334 | 335 | def get_codebook_entry(self, indices, shape): 336 | # shape specifying (batch, height, width, channel) 337 | if self.remap is not None: 338 | indices = indices.reshape(shape[0],-1) # add batch axis 339 | indices = 
self.unmap_to_all(indices) 340 | indices = indices.reshape(-1) # flatten again 341 | 342 | # get quantized latent vectors 343 | z_q = self.embedding(indices) 344 | 345 | if shape is not None: 346 | z_q = z_q.view(shape) 347 | # reshape back to match original input shape 348 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 349 | 350 | return z_q 351 | 352 | 353 | class VQGanDetokenizer(nn.Module): 354 | def __init__(self): 355 | super().__init__() 356 | vocab_count, embed_count = 2 ** 14, 2 ** 8 357 | self.vocab_count = vocab_count 358 | self.embedding = nn.Embedding(vocab_count, embed_count) 359 | self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1) 360 | self.decoder = Decoder() 361 | 362 | def forward(self, is_seamless: bool, z: LongTensor, grid: bool = False) -> FloatTensor: 363 | image_count = z.shape[0] 364 | grid_size = int(sqrt(z.shape[0])) 365 | token_count = grid_size * 2 ** 4 366 | 367 | if is_seamless: 368 | z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4]) 369 | z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2) 370 | z = z.flatten().unsqueeze(1) 371 | z = self.embedding.forward(z) 372 | z = z.view((1, token_count, token_count, 2 ** 8)) 373 | else: 374 | z = self.embedding.forward(z) 375 | z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) 376 | 377 | z = z.permute(0, 3, 1, 2).contiguous() 378 | z = self.post_quant_conv.forward(z) 379 | z = self.decoder.forward(z) 380 | z = z.permute(0, 2, 3, 1) 381 | z = z.clip(0.0, 1.0) * 255 382 | 383 | if is_seamless: 384 | z = z[0] 385 | else: 386 | if grid: 387 | z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3]) 388 | z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2) 389 | else: 390 | z = z.view([image_count, 2 ** 8, 2 ** 8, 3]) 391 | 392 | return z 393 | 394 | 395 | class VQGanTokenizer(nn.Module): 396 | def __init__(self): 397 | super().__init__() 398 | vocab_count, embed_count = 2 ** 14, 2 ** 8 399 | self.vocab_count = vocab_count 400 | self.quant_conv = nn.Conv2d(embed_count, embed_count, 1) 401 | self.encoder = Encoder() 402 | self.quantize = VectorQuantizer2(sane_index_shape=True) 403 | 404 | def forward(self, x: LongTensor): 405 | 406 | h = self.encoder(x) 407 | # print('VQGan encoder output shape', h.shape) 408 | h = self.quant_conv(h) 409 | # print("Post-quant shape", h.shape) 410 | quant, emb_loss, info = self.quantize(h) 411 | # print(quant.shape, info[-1].shape) 412 | return quant, emb_loss, info -------------------------------------------------------------------------------- /mega-story-dalle/min_dalle/text_tokenizer.py: -------------------------------------------------------------------------------- 1 | from math import inf 2 | from typing import List, Tuple 3 | from emoji import demojize 4 | 5 | class TextTokenizer: 6 | def __init__(self, vocab: dict, merges: List[str]): 7 | self.token_from_subword = vocab 8 | pairs = [tuple(pair.split()) for pair in merges] 9 | self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) 10 | self.new_tokens = [] 11 | 12 | def add_tokens(self, new_tokens): 13 | 14 | self.new_tokens = new_tokens 15 | original_length = len(self.token_from_subword) 16 | for t in new_tokens: 17 | self.token_from_subword[t] = len(self.token_from_subword) 18 | print("Increased vocabulary from %s to %s" % (original_length, len(self.token_from_subword))) 19 | 20 | def tokenize(self, text: str, is_verbose: bool = False) -> List[int]: 21 | sep_token = self.token_from_subword[''] 22 | cls_token = self.token_from_subword[''] 23 | unk_token = self.token_from_subword[''] 24 | text = demojize(text, delimiters=['', 
'']) 25 | text = text.lower().encode("ascii", errors="ignore").decode() 26 | 27 | # tokens = [ 28 | # self.token_from_subword.get(subword, unk_token) 29 | # for word in text.split(" ") if len(word) > 0 30 | # for subword in self.get_byte_pair_encoding(word, is_verbose) 31 | # ] 32 | # print([subword for word in text.split(" ") if len(word) > 0 33 | # for subword in self.get_byte_pair_encoding(word, is_verbose)]) 34 | 35 | sub_words = [] 36 | tokens = [] 37 | for word in text.split(" "): 38 | if len(word) > 0: 39 | if word in self.new_tokens: 40 | tokens.append(self.token_from_subword[word]) 41 | sub_words.append(word) 42 | else: 43 | for subword in self.get_byte_pair_encoding(word, is_verbose): 44 | tokens.append(self.token_from_subword.get(subword, unk_token)) 45 | sub_words.append(subword) 46 | # print(sub_words) 47 | 48 | return [cls_token] + tokens + [sep_token] 49 | 50 | def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]: 51 | def get_pair_rank(pair: Tuple[str, str]) -> int: 52 | return self.rank_from_pair.get(pair, inf) 53 | 54 | subwords = [chr(ord(" ") + 256)] + list(word) 55 | while len(subwords) > 1: 56 | pairs = list(zip(subwords[:-1], subwords[1:])) 57 | pair_to_merge = min(pairs, key=get_pair_rank) 58 | if pair_to_merge not in self.rank_from_pair: break 59 | i = pairs.index(pair_to_merge) 60 | subwords = ( 61 | (subwords[:i] if i > 0 else []) + 62 | [subwords[i] + subwords[i + 1]] + 63 | (subwords[i + 2:] if i + 2 < len(subwords) else []) 64 | ) 65 | 66 | if is_verbose: print(subwords) 67 | return subwords -------------------------------------------------------------------------------- /mega-story-dalle/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | # from pathlib import Path 3 | 4 | setuptools.setup( 5 | name='min-dalle', 6 | description = 'min(DALL·E)', 7 | # long_description=(Path(__file__).parent / "README.rst").read_text(), 8 | version='0.4.11', 9 | author='Brett Kuprel', 10 | author_email='brkuprel@gmail.com', 11 | url='https://github.com/kuprel/min-dalle', 12 | packages=[ 13 | 'min_dalle', 14 | 'min_dalle.models' 15 | ], 16 | license='MIT', 17 | install_requires=[ 18 | 'torch>=1.11', 19 | 'typing_extensions>=4.1', 20 | 'numpy>=1.21', 21 | 'pillow>=7.1', 22 | 'requests>=2.23', 23 | 'emoji' 24 | ], 25 | keywords = [ 26 | 'artificial intelligence', 27 | 'deep learning', 28 | 'text-to-image', 29 | 'pytorch' 30 | ] 31 | ) -------------------------------------------------------------------------------- /mega-story-dalle/tkinter_ui.py: -------------------------------------------------------------------------------- 1 | from min_dalle import MinDalle 2 | import sys 3 | import PIL 4 | import PIL.Image 5 | import PIL.ImageTk 6 | import tkinter 7 | from tkinter import ttk 8 | 9 | def regen_root(): 10 | global root 11 | global blank_image 12 | global padding_image 13 | 14 | root = tkinter.Tk() 15 | root.wm_resizable(False, False) 16 | 17 | blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 2, 256 * 2), mode="RGB")) 18 | padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA")) 19 | 20 | regen_root() 21 | 22 | is_mega = None 23 | def set_mega_true_and_destroy(): 24 | global is_mega 25 | is_mega = True 26 | root.destroy() 27 | def set_mega_false_and_destroy(): 28 | global is_mega 29 | is_mega = False 30 | root.destroy() 31 | 32 | frm = ttk.Frame(root, padding=16) 33 | frm.grid() 34 | ttk.Button(frm, text="Mega", 
command=set_mega_true_and_destroy).grid(column=0, row=0) 35 | ttk.Label(frm, image=padding_image).grid(column=1, row=0) 36 | ttk.Button(frm, text="Mini", command=set_mega_false_and_destroy).grid(column=2, row=0) 37 | root.mainloop() 38 | 39 | if is_mega is None: 40 | print("no option selected") 41 | sys.exit(0) 42 | 43 | print("is_mega", is_mega) 44 | 45 | model = MinDalle( 46 | models_root="./pretrained", 47 | is_mega=is_mega, 48 | is_reusable=True, 49 | is_verbose=True 50 | ) 51 | 52 | regen_root() 53 | 54 | label_image_content = blank_image 55 | 56 | sv_prompt = tkinter.StringVar(value="artificial intelligence") 57 | sv_temperature = tkinter.StringVar(value="1") 58 | sv_topk = tkinter.StringVar(value="128") 59 | sv_supercond = tkinter.StringVar(value="16") 60 | bv_seamless = tkinter.BooleanVar(value=False) 61 | 62 | def generate(): 63 | # check fields 64 | try: 65 | temperature = float(sv_temperature.get()) 66 | except: 67 | sv_temperature.set("ERROR") 68 | return 69 | try: 70 | topk = int(sv_topk.get()) 71 | except: 72 | sv_topk.set("ERROR") 73 | return 74 | try: 75 | supercond = int(sv_supercond.get()) 76 | except: 77 | sv_supercond.set("ERROR") 78 | return 79 | try: 80 | is_seamless = bool(bv_seamless.get()) 81 | except: 82 | return 83 | # and continue 84 | global label_image_content 85 | image_stream = model.generate_image_stream( 86 | sv_prompt.get(), 87 | grid_size=2, 88 | seed=-1, 89 | progressive_outputs=True, 90 | is_seamless=is_seamless, 91 | temperature=temperature, 92 | top_k=topk, 93 | supercondition_factor=supercond, 94 | is_verbose=True 95 | ) 96 | for image in image_stream: 97 | global final_image 98 | final_image = image 99 | label_image_content = PIL.ImageTk.PhotoImage(image) 100 | label_image.configure(image=label_image_content) 101 | label_image.update() 102 | 103 | def save(): 104 | final_image.save('generated/out.png') 105 | 106 | frm = ttk.Frame(root, padding=16) 107 | frm.grid() 108 | 109 | props = ttk.Frame(frm) 110 | 111 | # outer structure (hbox) 112 | label_image = ttk.Label(frm, image=blank_image) 113 | label_image.grid(column=0, row=0) 114 | ttk.Label(frm, image=padding_image).grid(column=1, row=0) 115 | props.grid(column=2, row=0) 116 | 117 | # inner structure (properties and shit) 118 | # prompt field 119 | ttk.Label(props, text="Prompt:").grid(column=0, row=0) 120 | ttk.Entry(props, textvariable=sv_prompt).grid(column=1, row=0) 121 | # 122 | ttk.Label(props, image=padding_image).grid(column=0, row=1) 123 | # temperature field 124 | ttk.Label(props, text="Temperature:").grid(column=0, row=2) 125 | ttk.Entry(props, textvariable=sv_temperature).grid(column=1, row=2) 126 | # 127 | ttk.Label(props, image=padding_image).grid(column=0, row=3) 128 | # topk field 129 | ttk.Label(props, text="Top-K:").grid(column=0, row=4) 130 | ttk.Entry(props, textvariable=sv_topk).grid(column=1, row=4) 131 | # 132 | ttk.Label(props, image=padding_image).grid(column=0, row=5) 133 | # superconditioning field 134 | ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6) 135 | ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6) 136 | # 137 | ttk.Label(props, image=padding_image).grid(column=0, row=7) 138 | # seamless 139 | ttk.Label(props, text="Seamless:").grid(column=0, row=8) 140 | ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8) 141 | # 142 | ttk.Label(props, image=padding_image).grid(column=0, row=9) 143 | # buttons 144 | ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10) 145 | ttk.Button(props, 
text="Quit", command=root.destroy).grid(column=1, row=10) 146 | ttk.Button(props, text="Save", command=save).grid(column=2, row=10) 147 | 148 | root.mainloop() -------------------------------------------------------------------------------- /mega-story-dalle/train_story.sh: -------------------------------------------------------------------------------- 1 | if [ "$1" = "pororo" ]; then 2 | echo "Training on Pororo" 3 | DATA_DIR=../data/pororo/ 4 | OUTPUT_ROOT=./out/pororo 5 | SENT_EMBED=512 6 | STORY_LEN=4 7 | LR=1e-4 8 | TRAIN_BS=1 9 | GRAD_ACC=4 10 | DATA_DIR=../data/flintstones 11 | OUTPUT_ROOT=./out/flintstones 12 | SENT_EMBED=512 13 | STORY_LEN=4 14 | LR=1e-5 15 | TRAIN_BS=1 16 | GRAD_ACC=4 17 | elif [ "$1" = "didemo" ]; then 18 | echo "Training on DiDeMo" 19 | DATA_DIR=../data/didemo 20 | OUTPUT_ROOT=./out/didemo 21 | SENT_EMBED=512 22 | STORY_LEN=2 23 | TRAIN_BS=1 24 | GRAD_ACC=8 25 | fi 26 | 27 | #--prefix_model_name_or_path './1.3B/' \ 28 | #--model_name_or_path './1.3B/' \ 29 | 30 | python ./train_t2i.py \ 31 | --model_name_or_path './pretrained/' \ 32 | --tuning_mode 'story' \ 33 | --dataset_name $1 \ 34 | --story_len $STORY_LEN \ 35 | --sent_embed $SENT_EMBED \ 36 | --prefix_dropout 0.2 \ 37 | --data_dir $DATA_DIR \ 38 | --dataloader_num_workers 4 \ 39 | --output_dir $OUTPUT_ROOT \ 40 | --log_dir /nas-ssd/adyasha/runs/ \ 41 | --do_train --do_eval \ 42 | --per_gpu_train_batch_size $TRAIN_BS \ 43 | --per_gpu_eval_batch_size 2 \ 44 | --num_train_epochs 50 \ 45 | --gradient_accumulation_steps $GRAD_ACC \ 46 | --learning_rate $LR \ 47 | --logging_steps 50 \ 48 | --eval_steps 500 \ 49 | --generate_steps 1000 \ 50 | --is_mega -------------------------------------------------------------------------------- /story-dalle/1.3B/config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | tokenizer_type: CharBPE 3 | context_length: 64 4 | image_resolution: 256 5 | 6 | stage1: 7 | type: vqgan 8 | embed_dim: 256 9 | n_embed: 16384 10 | hparams: 11 | double_z: False 12 | z_channels: 256 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: [1, 1, 2, 2, 4] 18 | num_res_blocks: 2 19 | attn_resolutions: [16] 20 | pdrop: 0.0 21 | 22 | stage2: 23 | type: transformer1d 24 | vocab_size_txt: 16384 25 | vocab_size_img: 16384 26 | hparams: 27 | embed_dim: 1536 28 | n_layers: 42 29 | n_heads: 24 30 | n_dense_layers: 42 31 | ctx_len_img: 256 32 | ctx_len_txt: 64 33 | embd_pdrop: 0.0 34 | resid_pdrop: 0.0 35 | attn_pdrop: 0.0 36 | mlp_bias: True 37 | attn_bias: True 38 | gelu_use_approx: False 39 | -------------------------------------------------------------------------------- /story-dalle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/__init__.py -------------------------------------------------------------------------------- /story-dalle/__pycache__/pororo_dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/__pycache__/pororo_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__init__.py -------------------------------------------------------------------------------- /story-dalle/dalle/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/__pycache__/trainer_prefix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/trainer_prefix.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/__pycache__/layers.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from typing import Tuple, Optional 9 | 10 | 11 | def nonlinearity(x): 12 | # swish 13 | return x*torch.sigmoid(x) 14 | 15 | 16 | def Normalize(in_channels): 17 | return torch.nn.GroupNorm(num_groups=32, 18 | num_channels=in_channels, 19 | eps=1e-6, 20 | affine=True) 21 | 22 | 23 | class Upsample(nn.Module): 24 | def __init__(self, in_channels, with_conv): 25 | super().__init__() 26 | self.with_conv = with_conv 27 | if self.with_conv: 28 | self.conv = torch.nn.Conv2d(in_channels, 29 | in_channels, 30 | kernel_size=3, 31 | stride=1, 32 | padding=1) 33 | 34 | def forward(self, x): 35 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 36 | if self.with_conv: 37 | x = self.conv(x) 38 | return x 39 | 40 | 41 | class Downsample(nn.Module): 42 | def __init__(self, in_channels, with_conv): 43 | super().__init__() 44 | self.with_conv = with_conv 45 | if self.with_conv: 46 | # no asymmetric padding in torch conv, must do it ourselves 47 | self.conv = torch.nn.Conv2d(in_channels, 48 | in_channels, 49 | kernel_size=3, 50 | stride=2, 51 | padding=0) 52 | 53 | def forward(self, x): 54 | if self.with_conv: 55 | pad = (0, 1, 0, 1) 56 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 57 | x = self.conv(x) 58 | else: 59 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 60 | return x 61 | 62 | 63 | class ResnetBlock(nn.Module): 64 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 65 | dropout, temb_channels=512): 66 | assert temb_channels == 0 67 | super().__init__() 68 | self.in_channels = in_channels 69 | out_channels = in_channels if out_channels is None else out_channels 70 | self.out_channels = out_channels 71 | self.use_conv_shortcut = conv_shortcut 72 | 73 | self.norm1 = Normalize(in_channels) 74 | self.conv1 = torch.nn.Conv2d(in_channels, 75 | out_channels, 76 | kernel_size=3, 77 | stride=1, 78 | padding=1) 79 | self.norm2 = Normalize(out_channels) 80 | self.dropout = torch.nn.Dropout(dropout) 81 
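        # (each residual block is GroupNorm, swish, conv, twice, with dropout before the
        #  second conv; a learned shortcut below, 1x1 by default, matches channels when
        #  in_channels != out_channels)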
| self.conv2 = torch.nn.Conv2d(out_channels, 82 | out_channels, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1) 86 | if self.in_channels != self.out_channels: 87 | if self.use_conv_shortcut: 88 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 89 | out_channels, 90 | kernel_size=3, 91 | stride=1, 92 | padding=1) 93 | else: 94 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 95 | out_channels, 96 | kernel_size=1, 97 | stride=1, 98 | padding=0) 99 | 100 | def forward(self, x, temb=None): 101 | assert temb is None 102 | 103 | h = x 104 | h = self.norm1(h) 105 | h = nonlinearity(h) 106 | h = self.conv1(h) 107 | 108 | h = self.norm2(h) 109 | h = nonlinearity(h) 110 | h = self.dropout(h) 111 | h = self.conv2(h) 112 | 113 | if self.in_channels != self.out_channels: 114 | if self.use_conv_shortcut: 115 | x = self.conv_shortcut(x) 116 | else: 117 | x = self.nin_shortcut(x) 118 | return x+h 119 | 120 | 121 | class AttnBlock(nn.Module): 122 | def __init__(self, in_channels): 123 | super().__init__() 124 | self.in_channels = in_channels 125 | 126 | self.norm = Normalize(in_channels) 127 | self.q = torch.nn.Conv2d(in_channels, 128 | in_channels, 129 | kernel_size=1, 130 | stride=1, 131 | padding=0) 132 | self.k = torch.nn.Conv2d(in_channels, 133 | in_channels, 134 | kernel_size=1, 135 | stride=1, 136 | padding=0) 137 | self.v = torch.nn.Conv2d(in_channels, 138 | in_channels, 139 | kernel_size=1, 140 | stride=1, 141 | padding=0) 142 | self.proj_out = torch.nn.Conv2d(in_channels, 143 | in_channels, 144 | kernel_size=1, 145 | stride=1, 146 | padding=0) 147 | 148 | def forward(self, x): 149 | h_ = x 150 | h_ = self.norm(h_) 151 | q = self.q(h_) 152 | k = self.k(h_) 153 | v = self.v(h_) 154 | 155 | # compute attention 156 | b, c, h, w = q.shape 157 | q = q.reshape(b, c, h*w) 158 | q = q.permute(0, 2, 1) # b,hw,c 159 | k = k.reshape(b, c, h*w) # b,c,hw 160 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 161 | w_ = w_ * (int(c)**(-0.5)) 162 | w_ = torch.nn.functional.softmax(w_, dim=2) 163 | 164 | # attend to values 165 | v = v.reshape(b, c, h*w) 166 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 167 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 168 | h_ = h_.reshape(b, c, h, w) 169 | 170 | h_ = self.proj_out(h_) 171 | return x+h_ 172 | 173 | 174 | class Encoder(nn.Module): 175 | def __init__(self, 176 | *, # forced to use named arguments 177 | ch: int, 178 | out_ch: int, 179 | ch_mult: Tuple[int] = (1, 2, 4, 8), 180 | num_res_blocks: int, 181 | attn_resolutions: Tuple[int], 182 | pdrop: float = 0.0, 183 | resamp_with_conv: bool = True, 184 | in_channels: int, 185 | resolution: int, 186 | z_channels: int, 187 | double_z: Optional[bool] = None) -> None: 188 | super().__init__() 189 | self.ch = ch 190 | self.temb_ch = 0 191 | self.num_resolutions = len(ch_mult) 192 | self.num_res_blocks = num_res_blocks 193 | self.resolution = resolution 194 | self.in_channels = in_channels 195 | 196 | # downsampling 197 | self.conv_in = torch.nn.Conv2d(in_channels, 198 | self.ch, 199 | kernel_size=3, 200 | stride=1, 201 | padding=1) 202 | 203 | curr_res = resolution 204 | in_ch_mult = (1,)+tuple(ch_mult) 205 | self.down = nn.ModuleList() 206 | for i_level in range(self.num_resolutions): 207 | block = nn.ModuleList() 208 | attn = nn.ModuleList() 209 | block_in = ch*in_ch_mult[i_level] 210 | block_out = ch*ch_mult[i_level] 211 | for i_block in range(self.num_res_blocks): 212 | block.append(ResnetBlock(in_channels=block_in, 213 | out_channels=block_out, 
214 | temb_channels=self.temb_ch, 215 | dropout=pdrop)) 216 | block_in = block_out 217 | if curr_res in attn_resolutions: 218 | attn.append(AttnBlock(block_in)) 219 | down = nn.Module() 220 | down.block = block 221 | down.attn = attn 222 | if i_level != self.num_resolutions-1: 223 | down.downsample = Downsample(block_in, resamp_with_conv) 224 | curr_res = curr_res // 2 225 | self.down.append(down) 226 | 227 | # middle 228 | self.mid = nn.Module() 229 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 230 | out_channels=block_in, 231 | temb_channels=self.temb_ch, 232 | dropout=pdrop) 233 | self.mid.attn_1 = AttnBlock(block_in) 234 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 235 | out_channels=block_in, 236 | temb_channels=self.temb_ch, 237 | dropout=pdrop) 238 | 239 | # end 240 | self.norm_out = Normalize(block_in) 241 | self.conv_out = torch.nn.Conv2d(block_in, 242 | 2*z_channels if double_z else z_channels, 243 | kernel_size=3, 244 | stride=1, 245 | padding=1) 246 | 247 | def forward(self, x): 248 | assert x.shape[2] == x.shape[3] == self.resolution, \ 249 | "{}, {}".format(x.shape, self.resolution) 250 | 251 | # downsampling 252 | h = self.conv_in(x) 253 | for i_level in range(self.num_resolutions): 254 | for i_block in range(self.num_res_blocks): 255 | h = self.down[i_level].block[i_block](h) 256 | if len(self.down[i_level].attn) > 0: 257 | h = self.down[i_level].attn[i_block](h) 258 | if i_level != self.num_resolutions-1: 259 | h = self.down[i_level].downsample(h) 260 | 261 | # middle 262 | h = self.mid.block_1(h) 263 | h = self.mid.attn_1(h) 264 | h = self.mid.block_2(h) 265 | 266 | # end 267 | h = self.norm_out(h) 268 | h = nonlinearity(h) 269 | h = self.conv_out(h) 270 | return h 271 | 272 | 273 | class Decoder(nn.Module): 274 | def __init__(self, 275 | *, # forced to use named arguments 276 | ch: int, 277 | out_ch: int, 278 | ch_mult: Tuple[int] = (1, 2, 4, 8), 279 | num_res_blocks: int, 280 | attn_resolutions: Tuple[int], 281 | pdrop: float = 0.0, 282 | resamp_with_conv: bool = True, 283 | in_channels: int, 284 | resolution: int, 285 | z_channels: int, 286 | double_z: bool) -> None: 287 | super().__init__() 288 | self.ch = ch 289 | self.temb_ch = 0 290 | self.num_resolutions = len(ch_mult) 291 | self.num_res_blocks = num_res_blocks 292 | self.resolution = resolution 293 | self.in_channels = in_channels 294 | 295 | # compute in_ch_mult, block_in and curr_res at lowest res 296 | block_in = ch*ch_mult[self.num_resolutions-1] 297 | curr_res = resolution // 2**(self.num_resolutions-1) 298 | self.z_shape = (1, z_channels, curr_res, curr_res) 299 | 300 | # z to block_in 301 | self.conv_in = torch.nn.Conv2d(z_channels, 302 | block_in, 303 | kernel_size=3, 304 | stride=1, 305 | padding=1) 306 | 307 | # middle 308 | self.mid = nn.Module() 309 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 310 | out_channels=block_in, 311 | temb_channels=self.temb_ch, 312 | dropout=pdrop) 313 | self.mid.attn_1 = AttnBlock(block_in) 314 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 315 | out_channels=block_in, 316 | temb_channels=self.temb_ch, 317 | dropout=pdrop) 318 | 319 | # upsampling 320 | self.up = nn.ModuleList() 321 | for i_level in reversed(range(self.num_resolutions)): 322 | block = nn.ModuleList() 323 | attn = nn.ModuleList() 324 | block_out = ch*ch_mult[i_level] 325 | for i_block in range(self.num_res_blocks+1): 326 | block.append(ResnetBlock(in_channels=block_in, 327 | out_channels=block_out, 328 | temb_channels=self.temb_ch, 329 | dropout=pdrop)) 330 | block_in = 
block_out 331 | if curr_res in attn_resolutions: 332 | attn.append(AttnBlock(block_in)) 333 | up = nn.Module() 334 | up.block = block 335 | up.attn = attn 336 | if i_level != 0: 337 | up.upsample = Upsample(block_in, resamp_with_conv) 338 | curr_res = curr_res * 2 339 | self.up.insert(0, up) # prepend to get consistent order 340 | 341 | # end 342 | self.norm_out = Normalize(block_in) 343 | self.conv_out = torch.nn.Conv2d(block_in, 344 | out_ch, 345 | kernel_size=3, 346 | stride=1, 347 | padding=1) 348 | 349 | def forward(self, z): 350 | assert z.shape[1:] == self.z_shape[1:] 351 | self.last_z_shape = z.shape 352 | 353 | # z to block_in 354 | h = self.conv_in(z) 355 | 356 | # middle 357 | h = self.mid.block_1(h) 358 | h = self.mid.attn_1(h) 359 | h = self.mid.block_2(h) 360 | 361 | # upsampling 362 | for i_level in reversed(range(self.num_resolutions)): 363 | for i_block in range(self.num_res_blocks+1): 364 | h = self.up[i_level].block[i_block](h) 365 | if len(self.up[i_level].attn) > 0: 366 | h = self.up[i_level].attn[i_block](h) 367 | if i_level != 0: 368 | h = self.up[i_level].upsample(h) 369 | 370 | h = self.norm_out(h) 371 | h = nonlinearity(h) 372 | h = self.conv_out(h) 373 | return h 374 | -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage1/vqgan.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers) 3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved. 4 | # ------------------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from typing import List, Tuple, Optional 9 | from einops import rearrange 10 | from omegaconf import OmegaConf 11 | from .layers import Encoder, Decoder 12 | 13 | 14 | class VectorQuantizer(nn.Module): 15 | """ 16 | Simplified VectorQuantizer in the original VQGAN repository 17 | by removing unncessary modules for sampling 18 | """ 19 | def __init__(self, dim: int, n_embed: int, beta: float) -> None: 20 | super().__init__() 21 | self.n_embed = n_embed 22 | self.dim = dim 23 | self.beta = beta 24 | 25 | self.embedding = nn.Embedding(self.n_embed, self.dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed) 27 | 28 | def forward(self, 29 | z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: 30 | z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C] 31 | z_flattened = z.view(-1, self.dim) 32 | 33 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 34 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 35 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 36 | 37 | min_encoding_indices = torch.argmin(d, dim=1) 38 | z_q = self.embedding(min_encoding_indices).view(z.shape) 39 | return z_q, min_encoding_indices 40 | 41 | def get_codebook_entry(self, 42 | indices: torch.LongTensor, 43 | shape: Optional[List[int]] = None) -> torch.FloatTensor: 44 | z_q = self.embedding(indices) 45 | if shape is not None: 46 | z_q = z_q.view(shape) 47 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 48 | return z_q 49 | 50 | 51 | class VQGAN(nn.Module): 52 | def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None: 53 | super().__init__() 54 | self.encoder = Encoder(**hparams) 55 | self.decoder = 
Decoder(**hparams) 56 | self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25) 57 | self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1) 58 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1) 59 | self.latent_dim = hparams.attn_resolutions[0] 60 | 61 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 62 | quant = self.encode(x) 63 | dec = self.decode(quant) 64 | return dec 65 | 66 | def encode(self, x: torch.FloatTensor) -> torch.FloatTensor: 67 | h = self.encoder(x) 68 | h = self.quant_conv(h) 69 | quant = self.quantize(h)[0] 70 | quant = rearrange(quant, 'b h w c -> b c h w').contiguous() 71 | return quant 72 | 73 | def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor: 74 | quant = self.post_quant_conv(quant) 75 | dec = self.decoder(quant) 76 | return dec 77 | 78 | def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor: 79 | quant = self.quantize.get_codebook_entry(code) 80 | quant = quant.permute(0, 3, 1, 2) 81 | dec = self.decode(quant) 82 | return dec 83 | 84 | def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor: 85 | h = self.encoder(x) 86 | h = self.quant_conv(h) 87 | codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2) 88 | return codes 89 | 90 | def from_ckpt(self, path: str, strict: bool = True) -> None: 91 | ckpt = torch.load(path, map_location='cpu')['state_dict'] 92 | self.load_state_dict(ckpt, strict=strict) 93 | print(f'{path} successfully restored..') 94 | -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage2/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage2/__pycache__/layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/models/stage2/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | # Modified from minGPT (https://github.com/karpathy/minGPT) 7 | # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved. 8 | # ------------------------------------------------------------------------------------ 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | 15 | 16 | class GELU(nn.Module): 17 | def __init__(self, use_approx=False): 18 | super().__init__() 19 | self.use_approx = use_approx 20 | 21 | def forward(self, x): 22 | if self.use_approx: 23 | return x * torch.sigmoid(1.702 * x) 24 | else: 25 | return F.gelu(x) 26 | 27 | 28 | class MultiHeadSelfAttention(nn.Module): 29 | 30 | def __init__(self, 31 | ctx_len: int, 32 | embed_dim: int, 33 | n_heads: int, 34 | resid_pdrop: float, 35 | attn_pdrop: float, 36 | attn_bias: bool, 37 | use_mask: bool = True): 38 | super().__init__() 39 | assert embed_dim % n_heads == 0 40 | 41 | # key, query, value projections for all heads 42 | self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 43 | self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 44 | self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias) 45 | 46 | # regularization 47 | self.attn_drop = nn.Dropout(attn_pdrop) 48 | self.resid_drop = nn.Dropout(resid_pdrop) 49 | 50 | # output projection 51 | self.proj = nn.Linear(embed_dim, embed_dim, attn_bias) 52 | 53 | self.n_heads = n_heads 54 | self.ctx_len = ctx_len 55 | self.use_mask = use_mask 56 | if self.use_mask: 57 | self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False) 58 | self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len) 59 | 60 | def forward(self, x, use_cache=False, layer_past=None): 61 | B, T, C = x.shape 62 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C) 63 | 64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 65 | k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 66 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 67 | v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 68 | 69 | if use_cache: 70 | present = torch.stack([k, v]) 71 | 72 | if layer_past is not None: 73 | # print(layer_past.shape, k.shape, v.shape, q.shape) 74 | # print("LayerPast shape", layer_past.shape) 75 | past_key, past_value = layer_past 76 | 77 | if len(past_key.shape) == 4: 78 | _, _, seq_len, dim = past_key.shape 79 | k = torch.cat([past_key.reshape(-1, seq_len, dim), k], dim=-2) 80 | v = torch.cat([past_value.reshape(-1, seq_len, dim), v], dim=-2) 81 | elif len(past_key.shape) == 3: 82 | past_key, past_value = layer_past 83 | k = torch.cat([past_key, k], dim=-2) 84 | v = torch.cat([past_value, v], dim=-2) 85 | else: 86 | raise ValueError 87 | 88 | if use_cache and layer_past is not None: 89 | # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K) 90 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 91 | att = F.softmax(att, dim=-1) 92 | att = self.attn_drop(att) 93 | y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs) 94 | else: 95 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T) 96 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 97 | if self.use_mask: 98 | # TODO : Flip 
when not prompt tunign 99 | # mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T] 100 | if T == self.ctx_len: 101 | mask = self.mask 102 | else: 103 | mask = torch.tril(torch.ones(T, T)).view(1, T, T).to(att.device) 104 | att = att.masked_fill(mask == 0, float('-inf')) 105 | att = F.softmax(att, dim=-1) 106 | att = self.attn_drop(att) 107 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs) 108 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side 109 | 110 | # output projection 111 | y = self.resid_drop(self.proj(y)) 112 | if use_cache: 113 | return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C) 114 | else: 115 | return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C) 116 | 117 | def forward_with_context(self, x, context, mask=None): 118 | B, T, C = x.shape 119 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C) 120 | 121 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 122 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 123 | 124 | B, T_c, C = context.shape 125 | k = self.key(context).view(T_c, B * self.n_heads, C // self.n_heads).transpose(0, 1) # (B*nh, T, hs) 126 | v = self.value(context).view(T_c, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs) 127 | 128 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, Tc) -> (B * nh, T, Tc) 129 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))) 130 | att = F.softmax(att, dim=-1) 131 | att = self.attn_drop(att) 132 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs) 133 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side 134 | 135 | # output projection 136 | y = self.resid_drop(self.proj(y)).transpose(0, 1).contiguous() 137 | if mask is not None: 138 | y = y.masked_fill(mask == 0, float('0.0')) 139 | return y # (T, B, C) -> (B, T, C) 140 | 141 | 142 | class Block(nn.Module): 143 | 144 | def __init__(self, 145 | ctx_len: int, 146 | embed_dim: int, 147 | n_heads: int, 148 | mlp_bias: bool, 149 | attn_bias: bool, 150 | resid_pdrop: bool, 151 | attn_pdrop: bool, 152 | gelu_use_approx: bool): 153 | super().__init__() 154 | self.ln1 = nn.LayerNorm(embed_dim) 155 | self.ln2 = nn.LayerNorm(embed_dim) 156 | 157 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len, 158 | embed_dim=embed_dim, 159 | n_heads=n_heads, 160 | attn_pdrop=attn_pdrop, 161 | resid_pdrop=resid_pdrop, 162 | attn_bias=attn_bias, 163 | use_mask=True) 164 | self.mlp = nn.Sequential( 165 | nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias), 166 | GELU(gelu_use_approx), 167 | nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias), 168 | nn.Dropout(resid_pdrop), 169 | ) 170 | 171 | def forward(self, x, layer_past=None): 172 | x = x + self.attn(self.ln1(x), layer_past=layer_past) 173 | x = x + self.mlp(self.ln2(x)) 174 | return x 175 | 176 | def sample(self, x, layer_past=None): 177 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past) 178 | x = x + attn 179 | x = x + self.mlp(self.ln2(x)) 180 | return x, present 181 | 182 | def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None): 183 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past) 184 | x = x + attn 185 | c_attn = cross_attn_layer(x, context, context_mask) 186 | x = x + c_attn 187 | x = x + self.mlp(self.ln2(x)) 188 | return x, present 189 | 190 | 191 
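# In sample_with_context above, each decoding step runs causal self-attention (reusing the
# cached keys/values in `layer_past`), adds cross-attention over the embedded source-frame
# tokens via the CrossAttentionLayer defined just below, and then applies the MLP. A minimal
# sketch with illustrative names and the 1.3B hyperparameters (embed_dim=1536, n_heads=24);
# `block`, `xattn`, `x` and `src` are stand-ins, not defined in this file:
#
#   block = Block(ctx_len=64 + 256, embed_dim=1536, n_heads=24, mlp_bias=True,
#                 attn_bias=True, resid_pdrop=0.0, attn_pdrop=0.0, gelu_use_approx=False)
#   xattn = CrossAttentionLayer(ctx_len=256, embed_dim=1536, n_heads=24,
#                               attn_bias=True, resid_pdrop=0.0, attn_pdrop=0.0)
#   x   = torch.randn(2, 1, 1536)    # hidden state of the current image token (B, T, C)
#   src = torch.randn(2, 256, 1536)  # embedded tokens of the conditioning source frame
#   x, present = block.sample_with_context(x, src, None, xattn, layer_past=None)
#   # `present` is this layer's stacked (key, value) cache for the next decoding step.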
| class CrossAttentionLayer(nn.Module): 192 | 193 | def __init__(self, 194 | ctx_len: int, 195 | embed_dim: int, 196 | n_heads: int, 197 | attn_bias: bool, 198 | resid_pdrop: bool, 199 | attn_pdrop: bool): 200 | super().__init__() 201 | 202 | self.ln1 = nn.LayerNorm(embed_dim) 203 | self.ln2 = nn.LayerNorm(embed_dim) 204 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len, 205 | embed_dim=embed_dim, 206 | n_heads=n_heads, 207 | attn_pdrop=attn_pdrop, 208 | resid_pdrop=resid_pdrop, 209 | attn_bias=attn_bias, 210 | use_mask=False) 211 | 212 | def forward(self, x, context, context_mask=None): 213 | attn = self.attn.forward_with_context(self.ln1(x), self.ln2(context), context_mask) 214 | # x = x + attn 215 | # return x 216 | return attn -------------------------------------------------------------------------------- /story-dalle/dalle/models/tokenizer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | from functools import partial 9 | from tokenizers import CharBPETokenizer 10 | 11 | 12 | def build_tokenizer(path: str, 13 | context_length: int = 64, 14 | *args, 15 | **kwargs): 16 | try: 17 | from_file = partial(CharBPETokenizer.from_file, 18 | vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'), 19 | merges_filename=os.path.join(path, 'bpe-16k-merges.txt'), 20 | unk_token='[UNK]') 21 | tokenizer = from_file(*args, **kwargs) 22 | except: 23 | from_file = partial(CharBPETokenizer.from_file, 24 | vocab_filename=os.path.join(path, 'vocab.json'), 25 | merges_filename=os.path.join(path, 'merges.txt'), 26 | unk_token='[UNK]') 27 | tokenizer = from_file(*args, **kwargs) 28 | 29 | # tokenizer = from_file(*args, **kwargs) 30 | tokenizer.add_special_tokens(['[PAD]']) 31 | tokenizer.enable_padding(length=context_length, 32 | pad_id=tokenizer.token_to_id('[PAD]')) 33 | tokenizer.enable_truncation(max_length=context_length) 34 | print(f'{path} successfully restored..') 35 | return tokenizer 36 | -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .config import * 3 | from .sampling import * -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/config.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/sampling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/sampling.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /story-dalle/dalle/utils/config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | from typing import Optional, List 8 | from dataclasses import dataclass, field 9 | from omegaconf import OmegaConf 10 | 11 | 12 | @dataclass 13 | class DataConfig: 14 | dataset: Optional[str] = None 15 | tokenizer_type: str = 'CharBPE' 16 | context_length: int = 64 17 | image_resolution: int = 256 18 | transforms: str = 'dalle-vqvae' 19 | bpe_pdrop: Optional[float] = None 20 | 21 | 22 | @dataclass 23 | class Stage1Hparams: 24 | double_z: bool = False 25 | z_channels: int = 256 26 | resolution: int = 256 27 | in_channels: int = 3 28 | out_ch: int = 3 29 | ch: int = 128 30 | ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) 31 | num_res_blocks: int = 2 32 | attn_resolutions: List[int] = field(default_factory=lambda: [16]) 33 | pdrop: float = 0.0 34 | 35 | 36 | @dataclass 37 | class Stage2Hparams: 38 | embed_dim: int = 1536 39 | n_layers: int = 42 40 | n_heads: int = 24 41 | n_dense_layers: int = 42 42 | ctx_len_img: int = 256 43 | ctx_len_txt: int = 64 44 | embd_pdrop: float = 0.0 45 | resid_pdrop: float = 0.0 46 | attn_pdrop: float = 0.0 47 | mlp_bias: bool = True 48 | attn_bias: bool = True 49 | gelu_use_approx: bool = False 50 | use_head_txt: bool = True 51 | n_classes: Optional[int] = None 52 | 53 | 54 | @dataclass 55 | class Stage1Config: 56 | type: str = 'vqgan' 57 | embed_dim: int = 256 58 | n_embed: int = 16384 59 | hparams: Stage1Hparams = Stage1Hparams() 60 | 61 | 62 | @dataclass 63 | class Stage2Config: 64 | type: str = 'transformer1d' 65 | vocab_size_txt: int = 16384 66 | vocab_size_img: int = 16384 67 | use_cls_cond: Optional[bool] = None 68 | hparams: Stage2Hparams = 
Stage2Hparams() 69 | 70 | 71 | @dataclass 72 | class WarmupConfig: 73 | epoch: int = 1 74 | multiplier: int = 1 75 | buffer_epoch: int = 0 76 | min_lr: float = 0.0 77 | mode: str = 'fix' 78 | peak_lr: float = 1e-4 79 | start_from_zero: bool = True 80 | 81 | 82 | @dataclass 83 | class OptConfig: 84 | opt_type: str = 'adamW' 85 | learning_rate: float = 5e-5 86 | weight_decay: float = 1e-4 87 | betas: List[float] = field(default_factory=lambda: [0.9, 0.99]) 88 | grad_clip_norm: float = 1.0 89 | 90 | sched_type: str = 'cosine' 91 | max_steps: int = 0 92 | min_lr: float = 1e-6 93 | 94 | 95 | @dataclass 96 | class ExpConfig: 97 | per_gpu_train_batch_size: int = 4 98 | per_gpu_eval_batch_size: int = 32 99 | num_train_epochs: int = 10 100 | save_ckpt_freq: int = 1 101 | test_freq: int = 10 102 | use_amp: bool = True 103 | 104 | 105 | @dataclass 106 | class PrefixModelConfig: 107 | model_name_or_path: Optional[str] = '' 108 | prefix_model_name_or_path: str = '' 109 | prefix_mode: str = 'activation' 110 | tuning_mode: str = 'finetune' 111 | top_k_layers: int = 2 112 | parameterize_mode: str = 'mlp' 113 | optim_prefix: bool = False 114 | preseqlen: int = 10 115 | prefix_dropout: float = 0.1 116 | init_random: bool = False 117 | hidden_dim_prefix: int = 512 118 | lowdata: bool = False 119 | lowdata_token: str = '' 120 | init_shallow: bool = False 121 | init_shallow_word: bool = False 122 | teacher_dropout: float = 0.1 123 | gumbel: bool = False 124 | replay_buffer: bool = False 125 | 126 | 127 | @dataclass 128 | class PromptModelConfig: 129 | model_name_or_path: Optional[str] = '' 130 | prefix_model_name_or_path: str = '' 131 | tuning_mode: str = 'prompt' 132 | preseqlen: int = 10 133 | prefix_dropout: float = 0.1 134 | 135 | 136 | @dataclass 137 | class StoryModelConfig: 138 | model_name_or_path: Optional[str] = '' 139 | prefix_model_name_or_path: str = '' 140 | tuning_mode: str = 'story' 141 | preseqlen: int = 10 142 | prefix_dropout: float = 0.1 143 | prompt: bool = False 144 | story_len: int = 4 145 | sent_embed: int = 256 146 | condition: bool = False 147 | clip_embed: bool = False 148 | 149 | 150 | @dataclass 151 | class DefaultConfig: 152 | dataset: DataConfig = DataConfig() 153 | stage1: Stage1Config = Stage1Config() 154 | stage2: Stage2Config = Stage2Config() 155 | 156 | 157 | @dataclass 158 | class FineTuningConfig: 159 | dataset: DataConfig = DataConfig() 160 | stage1: Stage1Config = Stage1Config() 161 | stage2: Stage2Config = Stage2Config() 162 | optimizer: OptConfig = OptConfig() 163 | experiment: ExpConfig = ExpConfig() 164 | 165 | 166 | @dataclass 167 | class PrefixTuningConfig: 168 | dataset: DataConfig = DataConfig() 169 | stage1: Stage1Config = Stage1Config() 170 | stage2: Stage2Config = Stage2Config() 171 | prefix: PrefixModelConfig = PrefixModelConfig() 172 | optimizer: OptConfig = OptConfig() 173 | experiment: ExpConfig = ExpConfig() 174 | 175 | 176 | @dataclass 177 | class PromptTuningConfig: 178 | dataset: DataConfig = DataConfig() 179 | stage1: Stage1Config = Stage1Config() 180 | stage2: Stage2Config = Stage2Config() 181 | prompt: PromptModelConfig = PromptModelConfig() 182 | optimizer: OptConfig = OptConfig() 183 | experiment: ExpConfig = ExpConfig() 184 | 185 | 186 | @dataclass 187 | class StoryConfig: 188 | dataset: DataConfig = DataConfig() 189 | stage1: Stage1Config = Stage1Config() 190 | stage2: Stage2Config = Stage2Config() 191 | story: StoryModelConfig = StoryModelConfig() 192 | optimizer: OptConfig = OptConfig() 193 | experiment: ExpConfig = ExpConfig() 194 | 195 | 196 
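# get_base_config below wraps the dataclasses above into typed OmegaConf configs; a checkpoint
# YAML such as story-dalle/1.3B/config.yaml can then be merged over these defaults. A minimal
# sketch, with the path and override values purely illustrative:
#
#   from omegaconf import OmegaConf
#   config = get_base_config('story')                                  # StoryConfig defaults
#   config = OmegaConf.merge(config, OmegaConf.load('1.3B/config.yaml'))
#   config.story.story_len = 4                                         # frames per story
#   config.story.sent_embed = 512                                      # sentence-embedding size
#   print(OmegaConf.to_yaml(config.stage2.hparams))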
| def get_base_config(mode): 197 | if mode == 'default': 198 | return OmegaConf.structured(DefaultConfig) 199 | elif mode == 'finetuning': 200 | return OmegaConf.structured(FineTuningConfig) 201 | elif mode == 'prefixtuning': 202 | return OmegaConf.structured(PrefixTuningConfig) 203 | elif mode == 'prompt_tuning': 204 | return OmegaConf.structured(PromptTuningConfig) 205 | elif mode == 'story': 206 | return OmegaConf.structured(StoryConfig) 207 | else: 208 | raise ValueError 209 | # return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig) 210 | -------------------------------------------------------------------------------- /story-dalle/dalle/utils/sampling.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import torch 8 | from typing import Optional 9 | from tqdm import tqdm 10 | from torch.nn import functional as F 11 | 12 | 13 | torch.set_printoptions(precision=2, threshold=10) 14 | def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor: 15 | if k is None: 16 | return logits 17 | else: 18 | v, ix = torch.topk(logits, k) 19 | out = logits.clone() 20 | out[out < v[:, [-1]]] = -float('Inf') 21 | return out 22 | 23 | 24 | def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor: 25 | if p is None: 26 | return probs 27 | else: 28 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) 29 | cum_probs = torch.cumsum(sorted_probs, dim=-1) 30 | 31 | sorted_idx_remove_cond = cum_probs >= p 32 | 33 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone() 34 | sorted_idx_remove_cond[..., 0] = 0 35 | 36 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond) 37 | probs = probs.masked_fill(indices_to_remove, 0.0) 38 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True) 39 | return norm_probs 40 | 41 | 42 | def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor: 43 | device = inputs.device 44 | if mode == '1d': 45 | B, N = inputs.shape 46 | xs_pos = torch.arange(N, device=device).repeat((B, 1)) 47 | elif mode == '2d': 48 | B, H, W = inputs.shape 49 | xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2) 50 | xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1) 51 | xs_pos = (xs_pos_h, xs_pos_w) 52 | else: 53 | raise ValueError('%s positional encoding invalid' % mode) 54 | return xs_pos 55 | 56 | 57 | @torch.no_grad() 58 | def sampling(model: torch.nn.Module, 59 | tokens: torch.LongTensor, 60 | top_k: Optional[float] = None, 61 | top_p: Optional[float] = None, 62 | softmax_temperature: float = 1.0, 63 | is_tqdm: bool = True, 64 | use_fp16: bool = True, 65 | max_seq_len: int = 256, 66 | prompt: Optional[torch.tensor] = None, 67 | pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor: 68 | 69 | code = None 70 | past = None 71 | 72 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len) 73 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d') 74 | 75 | for cnt, h in enumerate(pbar): 76 | if code is None: 77 | code_ = None 78 | pos_enc_code_ = None 79 | else: 80 | 
code_ = code.clone().detach() 81 | pos_enc_code_ = get_positional_encoding(code_, mode='1d') 82 | code_ = code_[:, cnt-1].unsqueeze(-1) 83 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1) 84 | 85 | logits, present = model.sampling(images=code_, 86 | texts=tokens, 87 | pos_images=pos_enc_code_, 88 | pos_texts=pos_enc_tokens, 89 | use_fp16=use_fp16, 90 | past=past, 91 | prompt=prompt, 92 | pos_prompt=pos_prompt) 93 | 94 | logits = logits.to(dtype=torch.float32) 95 | logits = logits / softmax_temperature 96 | 97 | # print(len(present), present[0].shape) 98 | present = torch.stack(present).clone().detach() 99 | if past is None: 100 | past = [present] 101 | else: 102 | past.append(present) 103 | 104 | logits = cutoff_topk_logits(logits, top_k) 105 | probs = F.softmax(logits, dim=-1) 106 | probs = cutoff_topp_probs(probs, top_p) 107 | # print(probs[0]) 108 | 109 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 110 | # print(idx) 111 | code = idx if code is None else torch.cat([code, idx], axis=1) 112 | 113 | del past 114 | return code 115 | 116 | 117 | @torch.no_grad() 118 | def sampling_prefix(model: torch.nn.Module, 119 | tokens: torch.LongTensor, 120 | past: torch.FloatTensor, 121 | top_k: Optional[float] = None, 122 | top_p: Optional[float] = None, 123 | softmax_temperature: float = 1.0, 124 | is_tqdm: bool = True, 125 | use_fp16: bool = True, 126 | max_seq_len: int = 256, 127 | labels = None) -> torch.LongTensor: 128 | code = None 129 | 130 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len) 131 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d') 132 | 133 | # print("Entering sampling_prefix; ", past.shape) 134 | if past is not None: 135 | past = [past] 136 | 137 | for cnt, h in enumerate(pbar): 138 | if code is None: 139 | code_ = None 140 | pos_enc_code_ = None 141 | else: 142 | code_ = code.clone().detach() 143 | pos_enc_code_ = get_positional_encoding(code_, mode='1d') 144 | code_ = code_[:, cnt-1].unsqueeze(-1) 145 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1) 146 | 147 | # print("Looop enter") 148 | # print(cnt, past[0].shape) 149 | # print("-------------------") 150 | logits, present = model.sampling(images=code_, 151 | texts=tokens, 152 | pos_images=pos_enc_code_, 153 | pos_texts=pos_enc_tokens, 154 | use_fp16=use_fp16, 155 | past=past) 156 | logits = logits.to(dtype=torch.float32) 157 | logits = logits / softmax_temperature 158 | 159 | present = torch.stack(present).clone().detach() 160 | 161 | # print('Present', present.shape) 162 | 163 | if past is None: 164 | past = [present] 165 | else: 166 | # print("Loop end") 167 | # print(present.shape) 168 | # print("-----------------") 169 | 170 | # n_layers, temp, _, seq_len, n_dim = present.shape 171 | # _, _, bs, n_heads, pre_seq_len, n_dim = past[0].shape 172 | # assert temp == 2 173 | # past.append(present.view(n_layers, temp, bs, n_heads, seq_len, n_dim)) 174 | 175 | past.append(present) 176 | 177 | logits = cutoff_topk_logits(logits, top_k) 178 | probs = F.softmax(logits, dim=-1) 179 | probs = cutoff_topp_probs(probs, top_p) 180 | print(torch.topk(probs, 5, dim=-1)) 181 | if labels is not None: 182 | print(labels[cnt]) 183 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 184 | # print(idx) 185 | code = idx if code is None else torch.cat([code, idx], axis=1) 186 | 187 | del past 188 | return code 189 | 190 | 191 | @torch.no_grad() 192 | def sampling_prefix_new(model: torch.nn.Module, 193 | tokens: torch.LongTensor, 194 | past: 
torch.FloatTensor, 195 | top_k: Optional[float] = None, 196 | top_p: Optional[float] = None, 197 | softmax_temperature: float = 1.0, 198 | is_tqdm: bool = True, 199 | use_fp16: bool = True, 200 | max_seq_len: int = 256) -> torch.LongTensor: 201 | code = None 202 | 203 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len) 204 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d') 205 | 206 | # print("Entering sampling_prefix; ", past.shape) 207 | if past is not None: 208 | past = [past] 209 | 210 | for cnt, h in enumerate(pbar): 211 | if code is None: 212 | code_ = None 213 | pos_enc_code_ = None 214 | else: 215 | code_ = code.clone().detach() 216 | pos_enc_code_ = get_positional_encoding(code_, mode='1d') 217 | # code_ = code_[:, cnt-1].unsqueeze(-1) 218 | # pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1) 219 | 220 | # print("Looop enter") 221 | # print(cnt, past[0].shape) 222 | # print("-------------------") 223 | 224 | if cnt == 0: 225 | logits, present = model.sampling(images=code_, 226 | texts=tokens, 227 | pos_images=pos_enc_code_, 228 | pos_texts=pos_enc_tokens, 229 | use_fp16=use_fp16, 230 | past=past) 231 | logits = logits.to(dtype=torch.float32) 232 | logits = logits / softmax_temperature 233 | 234 | present = torch.stack(present).clone().detach() 235 | 236 | # print('Present', present.shape) 237 | 238 | if past is None: 239 | past = [present] 240 | else: 241 | pass 242 | 243 | logits = cutoff_topk_logits(logits, top_k) 244 | probs = F.softmax(logits, dim=-1) 245 | probs = cutoff_topp_probs(probs, top_p) 246 | # print(torch.topk(probs[0], 5)) 247 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 248 | # print(idx) 249 | code = idx if code is None else torch.cat([code, idx], axis=1) 250 | 251 | else: 252 | pass 253 | 254 | 255 | del past 256 | return code 257 | 258 | @torch.no_grad() 259 | def sampling_conditional(model: torch.nn.Module, 260 | cross_attention_idxs, 261 | cross_attention_layers, 262 | tokens: torch.LongTensor, 263 | src_codes: torch.FloatTensor, 264 | top_k: Optional[float] = None, 265 | top_p: Optional[float] = None, 266 | softmax_temperature: float = 1.0, 267 | is_tqdm: bool = True, 268 | use_fp16: bool = True, 269 | max_seq_len: int = 256, 270 | prompt: Optional[torch.tensor] = None, 271 | pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor: 272 | 273 | code = None 274 | past = None 275 | 276 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len) 277 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d') 278 | 279 | src_pos_tokens = get_positional_encoding(src_codes, mode='1d') 280 | src_tokens = model.tok_emb_img(src_codes) 281 | src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens) 282 | 283 | for cnt, h in enumerate(pbar): 284 | if code is None: 285 | code_ = None 286 | pos_enc_code_ = None 287 | else: 288 | code_ = code.clone().detach() 289 | pos_enc_code_ = get_positional_encoding(code_, mode='1d') 290 | code_ = code_[:, cnt-1].unsqueeze(-1) 291 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1) 292 | 293 | logits, present = model.sampling_with_context(images=code_, 294 | cross_attention_idxs=cross_attention_idxs, 295 | cross_attention_layers=cross_attention_layers, 296 | texts=tokens, 297 | pos_images=pos_enc_code_, 298 | pos_texts=pos_enc_tokens, 299 | source_image=src_tokens, 300 | use_fp16=use_fp16, 301 | past=past, 302 | prompt=prompt, 303 | pos_prompt=pos_prompt) 304 | logits = logits.to(dtype=torch.float32) 305 | logits = logits 
/ softmax_temperature 306 | 307 | present = torch.stack(present).clone().detach() 308 | if past is None: 309 | past = [present] 310 | else: 311 | past.append(present) 312 | 313 | logits = cutoff_topk_logits(logits, top_k) 314 | probs = F.softmax(logits, dim=-1) 315 | probs = cutoff_topp_probs(probs, top_p) 316 | 317 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 318 | code = idx if code is None else torch.cat([code, idx], axis=1) 319 | 320 | del past 321 | return code 322 | 323 | 324 | @torch.no_grad() 325 | def sampling_igpt(model: torch.nn.Module, 326 | sos: torch.FloatTensor, 327 | top_k: Optional[float] = None, 328 | top_p: Optional[float] = None, 329 | softmax_temperature: float = 1.0, 330 | is_tqdm: bool = True, 331 | use_fp16: bool = True, 332 | max_seq_len: int = 256) -> torch.LongTensor: 333 | code = None 334 | past = None 335 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len) 336 | 337 | for cnt, h in enumerate(pbar): 338 | if code is None: 339 | code_ = None 340 | pos_enc_code_ = None 341 | else: 342 | code_ = code.clone().detach() 343 | pos_enc_code_ = get_positional_encoding(code_, mode='1d') 344 | code_ = code_[:, cnt-1].unsqueeze(-1) 345 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1) 346 | 347 | logits, present = model.sampling(sos=sos, 348 | codes=code_, 349 | pos_codes=pos_enc_code_, 350 | use_fp16=use_fp16, 351 | past=past) 352 | logits = logits.to(dtype=torch.float32) 353 | logits = logits / softmax_temperature 354 | 355 | present = torch.stack(present).clone().detach() 356 | if past is None: 357 | past = [present] 358 | else: 359 | past.append(present) 360 | 361 | logits = cutoff_topk_logits(logits, top_k) 362 | probs = F.softmax(logits, dim=-1) 363 | probs = cutoff_topp_probs(probs, top_p) 364 | 365 | idx = torch.multinomial(probs, num_samples=1).clone().detach() 366 | code = idx if code is None else torch.cat([code, idx], axis=1) 367 | 368 | del past 369 | return code 370 | -------------------------------------------------------------------------------- /story-dalle/dalle/utils/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------ 2 | # Minimal DALL-E 3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------ 6 | 7 | import os 8 | import random 9 | import urllib 10 | import hashlib 11 | import tarfile 12 | import torch 13 | import clip 14 | import numpy as np 15 | from PIL import Image 16 | from torch.nn import functional as F 17 | from tqdm import tqdm 18 | import torchvision.utils as vutils 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | def set_seed(seed: int): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | 29 | @torch.no_grad() 30 | def clip_score(prompt: str, 31 | images: np.ndarray, 32 | model_clip: torch.nn.Module, 33 | preprocess_clip, 34 | device: str) -> np.ndarray: 35 | images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images] 36 | images = torch.stack(images, dim=0).to(device=device) 37 | texts = clip.tokenize(prompt).to(device=device) 38 | texts = torch.repeat_interleave(texts, images.shape[0], dim=0) 39 | 40 | image_features = model_clip.encode_image(images) 41 | text_features = model_clip.encode_text(texts) 42 | 43 | scores = F.cosine_similarity(image_features, text_features).squeeze() 44 | rank = torch.argsort(scores, descending=True).cpu().numpy() 45 | return rank 46 | 47 | 48 | def download(url: str, root: str) -> str: 49 | os.makedirs(root, exist_ok=True) 50 | filename = os.path.basename(url) 51 | pathname = filename[:-len('.tar.gz')] 52 | 53 | expected_md5 = url.split("/")[-2] 54 | download_target = os.path.join(root, filename) 55 | result_path = os.path.join(root, pathname) 56 | 57 | if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)): 58 | return result_path 59 | 60 | with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output: 61 | with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True, 62 | unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5: 72 | raise RuntimeError(f'Model has been downloaded but the md5 checksum does not not match') 73 | 74 | with tarfile.open(download_target, 'r:gz') as f: 75 | pbar = tqdm(f.getmembers(), total=len(f.getmembers())) 76 | for member in pbar: 77 | pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)') 78 | f.extract(member=member, path=root) 79 | 80 | return result_path 81 | 82 | 83 | def realpath_url_or_path(url_or_path: str, root: str = None) -> str: 84 | if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'): 85 | return download(url_or_path, root) 86 | return url_or_path 87 | 88 | 89 | def images_to_numpy(tensor): 90 | generated = tensor.data.cpu().numpy().transpose(1,2,0) 91 | generated[generated < -1] = -1 92 | generated[generated > 1] = 1 93 | generated = (generated + 1) / 2 * 255 94 | return generated.astype('uint8') 95 | 96 | 97 | def save_image(ground_truth, images, out_dir, batch_idx): 98 | 99 | for i, im in enumerate(images): 100 | if len(im.shape) == 3: 101 | plt.imsave(os.path.join(out_dir, 'test_%s_%s.png' % (batch_idx, i)), im) 102 | else: 103 | bs = im.shape[0] 104 | # plt.imsave() 105 | for j in range(bs): 106 | plt.imsave(os.path.join(out_dir, 'test_%s_%s_%s.png' % 
(batch_idx, i, j)), im[j]) 107 | 108 | 109 | # print("Ground truth Images shape: ", ground_truth.shape, len(images)) 110 | 111 | # images = vutils.make_grid(images, nrow=ground_truth.shape[0]) 112 | # images = images_to_numpy(images) 113 | # 114 | # if ground_truth is not None: 115 | # ground_truth = vutils.make_grid(ground_truth, 5) 116 | # ground_truth = images_to_numpy(ground_truth) 117 | # print("Ground Truth shape, Generated Images shape: ", ground_truth.shape, images.shape) 118 | # images = np.concatenate([ground_truth, images], axis=0) 119 | # 120 | # output = Image.fromarray(images) 121 | # output.save('%s/fake_samples_epoch_%03d.png' % (out_dir, batch_idx)) 122 | 123 | # if texts is not None: 124 | # fid = open('%s/fake_samples_epoch_%03d_%03d.txt' % (image_dir, epoch, idx), 'w') 125 | # for idx in range(images.shape[0]): 126 | # fid.write(str(idx) + '--------------------------------------------------------\n') 127 | # for i in range(len(texts)): 128 | # fid.write(texts[i][idx] + '\n') 129 | # fid.write('\n\n') 130 | # fid.close() 131 | return -------------------------------------------------------------------------------- /story-dalle/didemo_dataloader.py: -------------------------------------------------------------------------------- 1 | import os, json, pickle 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch.utils.data 5 | from torchvision import transforms 6 | from collections import Counter 7 | import nltk 8 | from PIL import Image 9 | 10 | class ImageDataset(torch.utils.data.Dataset): 11 | def __init__(self, img_folder, tokenizer, preprocess, mode='train'): 12 | 13 | self.lengths = [] 14 | self.followings = [] 15 | self.dir_path = img_folder 16 | if mode == 'train': 17 | self.file_path = os.path.join(img_folder, 'train.json') 18 | elif mode == 'val': 19 | self.file_path = os.path.join(img_folder, 'val.json') 20 | else: 21 | self.file_path = os.path.join(img_folder, 'test.json') 22 | 23 | min_len = 4 24 | self.total_frames = 0 25 | self.images = [] 26 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists( 27 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')): 28 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1') 29 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')) 30 | else: 31 | print("Building image list cache") 32 | data = json.load(open(self.file_path)) 33 | for ex in data: 34 | # Set the first image as the frame in the first description 35 | dont_use = False 36 | for tup in ex['desc_to_frame']: 37 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])): 38 | dont_use = True 39 | if dont_use: 40 | continue 41 | self.images.append(ex['desc_to_frame'][0][1]) 42 | # Set remaining images to the rest of the images 43 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]] 44 | self.followings.append(following_imgs) 45 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images) 46 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings) 47 | print("Total number of clips {}".format(len(self.images))) 48 | 49 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']} 50 | self.preprocess = preprocess 51 | self.tokenizer = tokenizer 52 | 
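    # Usage sketch (illustrative only -- the data path and tokenizer below are assumptions,
    # not fixed by this file): the dataset expects a tokenizer whose .encode() returns an
    # object with an .ids attribute (e.g. a `tokenizers` BPE tokenizer) and a
    # torchvision-style transform, e.g.
    #
    #   preprocess = transforms.Compose([transforms.Resize(256),
    #                                    transforms.CenterCrop(256),
    #                                    transforms.ToTensor()])
    #   dataset = ImageDataset('../data/didemo', tokenizer, preprocess, mode='val')
    #   src_image, tokens = dataset[0]   # preprocessed frame, LongTensor of token ids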
53 | def __getitem__(self, item): 54 | 55 | src_img_path = self.images[item] 56 | 57 | src_image_raw = Image.open(os.path.join(os.path.dirname(self.dir_path), src_img_path)).convert('RGB') 58 | src_image = self.preprocess(src_image_raw) 59 | # open the target image and caption 60 | caption = self.descriptions_original[src_img_path] 61 | tokens = self.tokenizer.encode(caption) 62 | tokens = torch.LongTensor(tokens.ids) 63 | 64 | return src_image, tokens 65 | 66 | def __len__(self): 67 | return len(self.images) 68 | 69 | 70 | class StoryImageDataset(torch.utils.data.Dataset): 71 | def __init__(self, img_folder, tokenizer, transform=None, mode='train'): 72 | 73 | self.lengths = [] 74 | self.followings = [] 75 | self.dir_path = img_folder 76 | if mode == 'train': 77 | self.file_path = os.path.join(img_folder, 'train.json') 78 | elif mode == 'val': 79 | self.file_path = os.path.join(img_folder, 'val.json') 80 | else: 81 | self.file_path = os.path.join(img_folder, 'test.json') 82 | 83 | min_len = 2 84 | self.total_frames = 0 85 | self.images = [] 86 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists( 87 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')): 88 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1') 89 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')) 90 | else: 91 | print("Building image list cache") 92 | data = json.load(open(self.file_path)) 93 | for ex in data: 94 | # Set the first image as the frame in the first description 95 | dont_use = False 96 | for tup in ex['desc_to_frame']: 97 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])): 98 | dont_use = True 99 | if dont_use: 100 | continue 101 | if len(ex['desc_to_frame']) < min_len+1: 102 | continue 103 | self.images.append(ex['desc_to_frame'][0][1]) 104 | # Set remaining images to the rest of the images 105 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:1+min_len]] 106 | self.followings.append(following_imgs) 107 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images) 108 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings) 109 | print("Total number of clips {}".format(len(self.images))) 110 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']} 111 | 112 | if mode == 'train': 113 | if transform: 114 | self.transform = transform 115 | else: 116 | self.transform = transforms.Compose([ 117 | transforms.RandomResizedCrop(256), 118 | transforms.RandomHorizontalFlip(), 119 | transforms.ToTensor(), 120 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 121 | ]) 122 | else: 123 | if transform: 124 | self.transform = transform 125 | else: 126 | self.transform = transforms.Compose([ 127 | transforms.Resize(256), 128 | transforms.CenterCrop(256), 129 | transforms.ToTensor(), 130 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 131 | ]) 132 | 133 | self.descriptions_vec = pickle.load(open(os.path.join(img_folder, 'didemo_use_embeds.pkl'), 'rb')) 134 | self.tokenizer = tokenizer 135 | 136 | 137 | def __getitem__(self, item): 138 | 139 | frame_path_list = [self.images[item]] 140 | for i in range(len(self.followings[item])): 141 | 
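            # Append the stored continuation-frame paths after the source frame collected
            # above; captions and sentence embeddings are later returned only for these
            # continuation frames, while the source frame is returned separately as the
            # conditioning image.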
frame_path_list.append(str(self.followings[item][i])) 142 | 143 | images = [] 144 | tokens = [] 145 | 146 | for img_path in frame_path_list: 147 | im = self.transform(Image.open(os.path.join(os.path.dirname(os.path.normpath(self.dir_path)), img_path))) 148 | images.append(im) 149 | if self.tokenizer is not None: 150 | tokens.append(self.tokenizer.encode(self.descriptions_original[img_path])) 151 | else: 152 | tokens.append(self.descriptions_original[img_path]) 153 | 154 | if self.tokenizer is not None: 155 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens[1:]]) 156 | 157 | sent_embeds = [torch.tensor(self.descriptions_vec[frame_path]) for frame_path in frame_path_list[1:]] 158 | return torch.stack(images[1:]), tokens, images[0], torch.stack(sent_embeds) 159 | 160 | 161 | def __len__(self): 162 | return len(self.images) 163 | 164 | 165 | class CopyImageDataset(torch.utils.data.Dataset): 166 | def __init__(self, img_folder, tokenizer, preprocess, mode='train', video_len=2): 167 | 168 | self.lengths = [] 169 | self.followings = [] 170 | self.dir_path = img_folder 171 | if mode == 'train': 172 | self.file_path = os.path.join(img_folder, 'train.json') 173 | elif mode == 'val': 174 | self.file_path = os.path.join(img_folder, 'val.json') 175 | else: 176 | self.file_path = os.path.join(img_folder, 'test.json') 177 | self.video_len = video_len 178 | 179 | min_len = 4 180 | self.total_frames = 0 181 | self.images = [] 182 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists( 183 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')): 184 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1') 185 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')) 186 | else: 187 | print("Building image list cache") 188 | data = json.load(open(self.file_path)) 189 | for ex in data: 190 | # Set the first image as the frame in the first description 191 | dont_use = False 192 | for tup in ex['desc_to_frame']: 193 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])): 194 | dont_use = True 195 | if dont_use: 196 | continue 197 | self.images.append(ex['desc_to_frame'][0][1]) 198 | # Set remaining images to the rest of the images 199 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]] 200 | self.followings.append(following_imgs) 201 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images) 202 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings) 203 | print("Total number of clips {}".format(len(self.images))) 204 | 205 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']} 206 | self.preprocess = preprocess 207 | self.tokenizer = tokenizer 208 | 209 | def __getitem__(self, item): 210 | 211 | src_img_path = self.images[item] 212 | src_image_raw = Image.open(os.path.join(os.path.dirname(self.dir_path), src_img_path)).convert('RGB') 213 | src_image = self.preprocess(src_image_raw) 214 | 215 | tgt_img_paths = [str(self.followings[item][i]) for i in range(self.video_len)] 216 | # open the target image and caption 217 | tgt_images = [self.preprocess(Image.open(os.path.join(os.path.dirname(self.dir_path), tgt_img_path)).convert('RGB')) for tgt_img_path in tgt_img_paths] 218 | 
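        # For each target frame, look up its caption by frame path and encode it into a
        # LongTensor of token ids; the stacked target images and captions are returned
        # together with the separately preprocessed source image used for conditioning.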
219 | captions = [self.descriptions_original[tgt_img_path] for tgt_img_path in tgt_img_paths] 220 | tokens = [self.tokenizer.encode(caption) for caption in captions] 221 | tokens = [torch.LongTensor(token.ids) for token in tokens] 222 | 223 | return torch.stack(tgt_images), torch.stack(tokens), src_image 224 | 225 | def __len__(self): 226 | return len(self.images) 227 | 228 | 229 | class CopyStoryDataset(torch.utils.data.Dataset): 230 | def __init__(self, img_folder, preprocess, mode='train', max_t_len=72, video_len=5, resnet=False, condition_seq_len=0): 231 | 232 | self.lengths = [] 233 | self.followings = [] 234 | self.video_len = video_len 235 | min_len = video_len - 1 236 | 237 | if mode == 'train': 238 | self.file_path = os.path.join(img_folder, 'train.json') 239 | elif mode == 'val': 240 | self.file_path = os.path.join(img_folder, 'val.json') 241 | else: 242 | self.file_path = os.path.join(img_folder, 'test.json') 243 | 244 | self.dir_path = img_folder 245 | 246 | if os.path.exists(os.path.join(self.dir_path, 'dalle_vocab.pkl')): 247 | vocab_from_file = True 248 | vocab_file = os.path.join(self.dir_path, 'dalle_vocab.pkl') 249 | else: 250 | vocab_from_file = False 251 | vocab_file = os.path.join(self.dir_path, 'dalle_vocab.pkl') 252 | 253 | self.vocab = Vocabulary(vocab_threshold=5, 254 | vocab_file=vocab_file, 255 | annotations_file=os.path.join(self.dir_path, 'train.json'), 256 | vocab_from_file=vocab_from_file) 257 | 258 | print("Length of Vocabulary ", len(self.vocab)) 259 | 260 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']} 261 | 262 | self.total_frames = 0 263 | self.images = [] 264 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists( 265 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')): 266 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1') 267 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')) 268 | else: 269 | print("Building image list cache") 270 | data = json.load(open(self.file_path)) 271 | for ex in data: 272 | # Set the first image as the frame in the first description 273 | dont_use = False 274 | for tup in ex['desc_to_frame']: 275 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])): 276 | dont_use = True 277 | if dont_use: 278 | continue 279 | self.images.append(ex['desc_to_frame'][0][1]) 280 | # Set remaining images to the rest of the images 281 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]] 282 | self.followings.append(following_imgs) 283 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images) 284 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings) 285 | print("Total number of clips {}".format(len(self.images))) 286 | 287 | self.preprocess = preprocess 288 | self.max_t_len = max_t_len 289 | 290 | self.resnet = resnet 291 | im_input_size = 299 292 | if mode == 'train': 293 | self.transform = transforms.Compose([ 294 | transforms.Resize((im_input_size, im_input_size)), 295 | transforms.RandomHorizontalFlip(), 296 | transforms.ToTensor(), 297 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 298 | ]) 299 | else: 300 | self.transform = transforms.Compose([ 301 | 
transforms.Resize((im_input_size, im_input_size)), 302 | transforms.RandomHorizontalFlip(), 303 | transforms.ToTensor(), 304 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 305 | ]) 306 | 307 | self.condition_seq_len = condition_seq_len 308 | 309 | def __getitem__(self, item): 310 | 311 | # source image 312 | src_img_path = self.images[item].replace('didemo/', '') 313 | src_image_raw = Image.open(os.path.join(self.dir_path, src_img_path)).convert('RGB') 314 | # open the source images 315 | if self.resnet: 316 | src_image = self.transform(src_image_raw) 317 | else: 318 | src_image = self.preprocess(src_image_raw) 319 | 320 | # source caption 321 | src_caption = self.descriptions_original['didemo/' + src_img_path] 322 | src_text_tokens, src_text_mask = self.vocab._tokenize_pad_sentence(str(src_caption).lower(), self.max_t_len, 323 | condition=self.condition_seq_len) 324 | 325 | tgt_images = [] 326 | tgt_text_tokens = [src_text_tokens] 327 | tgt_text_masks = [src_text_mask] 328 | for i in range(0, self.video_len-1): 329 | tgt_img_path = str(self.followings[item][i]).replace('didemo/', '') 330 | # open the target image and caption 331 | caption = self.descriptions_original['didemo/' + tgt_img_path] 332 | tgt_image = self.preprocess(Image.open(os.path.join(self.dir_path, tgt_img_path)).convert('RGB')) 333 | tgt_images.append(tgt_image.unsqueeze(0)) 334 | # image = Image.open(os.path.join(self.out_dir, 'img-' + str(item) + '.png')).convert('RGB') 335 | text_tokens, text_mask = self.vocab._tokenize_pad_sentence(str(caption).lower(), self.max_t_len, condition=self.condition_seq_len) 336 | tgt_text_tokens.append(text_tokens) 337 | tgt_text_masks.append(text_mask) 338 | 339 | return torch.cat(tgt_images, dim=0), torch.LongTensor(tgt_text_tokens), torch.LongTensor(tgt_text_masks), src_image 340 | 341 | def __len__(self): 342 | return len(self.images) 343 | 344 | 345 | if __name__ == "__main__": 346 | 347 | dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/didemo', None, None, 'val') 348 | 349 | all_captions = {} 350 | for item in range(len(dataset)): 351 | 352 | frame_path_list = [dataset.images[item]] 353 | for i in range(len(dataset.followings[item])): 354 | frame_path_list.append(str(dataset.followings[item][i])) 355 | captions = [dataset.descriptions_original[img_path] for img_path in frame_path_list] 356 | all_captions[item] = captions 357 | 358 | with open(os.path.join('/nas-ssd/adyasha/datasets/didemo', 'all_captions_val.json'), 'w') as f: 359 | json.dump(all_captions, f, indent=4) -------------------------------------------------------------------------------- /story-dalle/eval_char_clf.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch.nn as nn 4 | from torchvision import models 5 | import torch 6 | from tqdm import tqdm 7 | import numpy as np 8 | from scipy.stats import entropy 9 | import os 10 | import PIL 11 | import torchvision.utils as vutils 12 | import argparse 13 | from sklearn.metrics import classification_report, accuracy_score 14 | from torchvision import transforms 15 | 16 | epsilon = 1e-7 17 | 18 | 19 | 20 | def set_parameter_requires_grad(model, feature_extracting): 21 | if feature_extracting: 22 | for param in model.parameters(): 23 | param.requires_grad = False 24 | 25 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True): 26 | # Initialize these variables which will be set in this if 
statement. Each of these 27 | # variables is model specific. 28 | model_ft = None 29 | input_size = 0 30 | 31 | if model_name == "resnet": 32 | """ Resnet18 33 | """ 34 | model_ft = models.resnet18(pretrained=use_pretrained) 35 | set_parameter_requires_grad(model_ft, feature_extract) 36 | num_ftrs = model_ft.fc.in_features 37 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 38 | input_size = 224 39 | 40 | if model_name == "resnet50": 41 | """ Resnet50 42 | """ 43 | model_ft = models.resnet50(pretrained=use_pretrained) 44 | set_parameter_requires_grad(model_ft, feature_extract) 45 | num_ftrs = model_ft.fc.in_features 46 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 47 | input_size = 224 48 | 49 | elif model_name == "resnet101": 50 | """ Resnet50 51 | """ 52 | model_ft = models.resnet101(pretrained=use_pretrained) 53 | set_parameter_requires_grad(model_ft, feature_extract) 54 | num_ftrs = model_ft.fc.in_features 55 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 56 | input_size = 224 57 | 58 | elif model_name == "alexnet": 59 | """ Alexnet 60 | """ 61 | model_ft = models.alexnet(pretrained=use_pretrained) 62 | set_parameter_requires_grad(model_ft, feature_extract) 63 | num_ftrs = model_ft.classifier[6].in_features 64 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 65 | input_size = 224 66 | 67 | elif model_name == "vgg": 68 | """ VGG11_bn 69 | """ 70 | model_ft = models.vgg11_bn(pretrained=use_pretrained) 71 | set_parameter_requires_grad(model_ft, feature_extract) 72 | num_ftrs = model_ft.classifier[6].in_features 73 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 74 | input_size = 224 75 | 76 | elif model_name == "squeezenet": 77 | """ Squeezenet 78 | """ 79 | model_ft = models.squeezenet1_0(pretrained=use_pretrained) 80 | set_parameter_requires_grad(model_ft, feature_extract) 81 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 82 | model_ft.num_classes = num_classes 83 | input_size = 224 84 | 85 | 86 | elif model_name == "densenet": 87 | """ Densenet 88 | """ 89 | model_ft = models.densenet121(pretrained=use_pretrained) 90 | set_parameter_requires_grad(model_ft, feature_extract) 91 | num_ftrs = model_ft.classifier.in_features 92 | model_ft.classifier = nn.Linear(num_ftrs, num_classes) 93 | input_size = 224 94 | 95 | elif model_name == "inception": 96 | """ Inception v3 97 | Be careful, expects (299,299) sized images and has auxiliary output 98 | """ 99 | model_ft = models.inception_v3(pretrained=use_pretrained) 100 | set_parameter_requires_grad(model_ft, feature_extract) 101 | # Handle the auxilary net 102 | num_ftrs = model_ft.AuxLogits.fc.in_features 103 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes) 104 | # Handle the primary net 105 | num_ftrs = model_ft.fc.in_features 106 | model_ft.fc = nn.Linear(num_ftrs,num_classes) 107 | input_size = 299 108 | 109 | else: 110 | print("Invalid model name, exiting...") 111 | exit() 112 | 113 | return model_ft, input_size 114 | 115 | def images_to_numpy(tensor): 116 | generated = tensor.data.cpu().numpy().transpose(1, 2, 0) 117 | generated[generated < -1] = -1 118 | generated[generated > 1] = 1 119 | generated = (generated + 1) / 2 * 255 120 | return generated.astype('uint8') 121 | 122 | 123 | def evaluate_gt(root_image_dir, model_name, model_path): 124 | 125 | if args.dataset == 'pororo': 126 | from pororo_dataloader import ImageDataset, StoryImageDataset 127 | elif args.dataset == 'flintstones': 128 | from flintstones_dataloader import ImageDataset, StoryImageDataset 129 | 
else: 130 | raise ValueError 131 | 132 | # Number of classes in the dataset 133 | num_classes = 9 134 | # when True we only update the reshaped layer params 135 | feature_extract = False 136 | video_len = 5 137 | n_channels = 3 138 | 139 | running_corrects = 0 140 | running_recalls = 0 141 | total_positives = 0 142 | 143 | phase = 'eval' 144 | is_inception = True if model_name == 'inception' else False 145 | 146 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=False) 147 | model_ft.load_state_dict(torch.load(model_path)) 148 | 149 | # Detect if we have a GPU available 150 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 151 | # Send the model to GPU 152 | model_ft = model_ft.to(device) 153 | model_ft.eval() # Set model to evaluate mode 154 | image_dataset = ImageDataset(root_image_dir, input_size, mode='val') 155 | print("Number of samples in evaluation set: %s" % len(image_dataset)) 156 | batch_size = 32 157 | 158 | # Create validation dataloaders 159 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 160 | 161 | print("Number of batches in evaluation dataloader: %s" % len(dataloader)) 162 | 163 | all_predictions = [] 164 | all_labels = [] 165 | story_accuracy = 0 166 | image_accuracy = 0 167 | 168 | # Iterate over data. 169 | for i, (inputs, labels) in tqdm(enumerate(dataloader)): 170 | 171 | inputs = inputs.to(device) 172 | labels = labels.to(device) 173 | 174 | with torch.set_grad_enabled(phase == 'train'): 175 | 176 | outputs = model_ft(inputs) 177 | if model_name == 'imgD': 178 | outputs = model_ft.cate_classify(outputs).squeeze() 179 | preds = torch.round(nn.functional.sigmoid(outputs)) 180 | all_predictions.append(preds.cpu().numpy()) 181 | all_labels.append(labels.cpu().numpy()) 182 | 183 | # statistics 184 | iter_corrects = torch.sum(preds == labels.float().data) 185 | xidxs, yidxs = torch.where(labels.data == 1) 186 | # print(xidxs, yidxs) 187 | # print([labels.data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]) 188 | iter_recalls = sum( 189 | [x.item() for x in 190 | [labels.float().data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]]) 191 | total_positives += xidxs.size(0) 192 | 193 | for l, p in zip(labels, preds): 194 | if torch.all(torch.eq(l.float().data, p)): 195 | image_accuracy += 1 196 | 197 | running_corrects += iter_corrects 198 | running_recalls += iter_recalls 199 | 200 | epoch_acc = running_corrects * 100 / (len(image_dataset) * num_classes) 201 | epoch_recall = running_recalls * 100 / total_positives 202 | print('{} Acc: {:.4f} Recall: {:.4f}%'.format(phase, epoch_acc, epoch_recall)) 203 | print('{} Story Exact Match Acc: {:.4f}%'.format(phase, float(story_accuracy) * 100 / len(image_dataset))) 204 | print('{} Image Exact Match Acc: {:.4f}%'.format(phase, float(image_accuracy) * 100 / len(image_dataset))) 205 | 206 | all_predictions = np.concatenate(all_predictions, axis=0) 207 | all_labels = np.concatenate(all_labels, axis=0) 208 | print(all_predictions.shape, all_labels.shape, image_accuracy, len(image_dataset)) 209 | preds = np.round(1 / (1 + np.exp(-all_predictions))) 210 | print(classification_report(all_labels, preds, digits=4)) 211 | 212 | # for i in range(0, 9): 213 | # print("Character %s" % i) 214 | # print(classification_report(all_labels[:, i], preds[:, i])) 215 | 216 | # Inception Score 217 | # all_predictions = all_predictions + epsilon 218 | # py = 
np.mean(all_predictions, axis=0) 219 | # print(py, py.shape) 220 | # split_scores = [] 221 | # splits = 10 222 | # N = all_predictions.shape[0] 223 | # for k in range(splits): 224 | # part = all_predictions[k * (N // splits): (k + 1) * (N // splits), :] 225 | # py = np.mean(part, axis=0) 226 | # scores = [] 227 | # 228 | # for i in range(part.shape[0]): 229 | # pyx = part[i, :] 230 | # scores.append(entropy(pyx, py)) 231 | # split_scores.append(np.exp(np.mean(scores))) 232 | # print("InceptionScore", np.mean(split_scores), np.std(split_scores)) 233 | 234 | 235 | def evaluate(args): 236 | 237 | root_image_dir, model_name, model_path = args.img_ref_dir, args.model_name, args.model_path 238 | 239 | if args.dataset == 'pororo': 240 | from pororo_dataloader import ImageDataset, StoryImageDataset 241 | elif args.dataset == 'flintstones': 242 | from flintstones_dataloader import ImageDataset, StoryImageDataset 243 | else: 244 | raise ValueError 245 | 246 | # when True we only update the reshaped layer params 247 | feature_extract = False 248 | video_len = 5 249 | n_channels = 3 250 | 251 | phase = 'eval' 252 | 253 | model_ft, input_size = initialize_model(model_name, args.num_classes, feature_extract, use_pretrained=False) 254 | model_ft.load_state_dict(torch.load(model_path)) 255 | 256 | img_transform = transforms.Compose([ 257 | # Image.fromarray, 258 | transforms.Resize(input_size), 259 | transforms.CenterCrop(input_size), 260 | transforms.ToTensor(), 261 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 262 | ]) 263 | 264 | # Detect if we have a GPU available 265 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 266 | # Send the model to GPU 267 | model_ft = model_ft.to(device) 268 | model_ft.eval() # Set model to evaluate mode 269 | 270 | # Create training and validation datasets 271 | try: 272 | image_dataset = StoryImageDataset(args.img_ref_dir, None, preprocess=img_transform, 273 | mode=args.mode, 274 | out_img_folder=args.img_gen_dir, 275 | return_labels=True) 276 | except TypeError: 277 | image_dataset = StoryImageDataset(args.img_ref_dir, None, transform=img_transform, 278 | mode=args.mode, 279 | out_img_folder=args.img_gen_dir, 280 | return_labels=True) 281 | print("Number of samples in evaluation set: %s" % len(image_dataset)) 282 | batch_size = 20 283 | 284 | # Create validation dataloaders 285 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1) 286 | 287 | print("Number of batches in evaluation dataloader: %s" % len(dataloader)) 288 | 289 | all_predictions = [] 290 | all_labels = [] 291 | story_accuracy = 0 292 | image_accuracy = 0 293 | 294 | running_corrects = 0 295 | running_recalls = 0 296 | total_positives = 0 297 | 298 | # Iterate over data. 299 | no_char_images = 0 300 | for i, batch in tqdm(enumerate(dataloader)): 301 | 302 | inputs = batch[0] 303 | labels = batch[1] 304 | 305 | inputs = inputs.view(-1, n_channels, inputs.shape[-2], inputs.shape[-1]) 306 | labels = labels.view(-1, labels.shape[-1]) 307 | assert inputs.shape[0] == labels.shape[0] 308 | inputs = inputs.to(device) 309 | labels = labels.to(device) 310 | 311 | # forward 312 | # track history if only in train 313 | with torch.no_grad(): 314 | # Get model outputs and calculate loss 315 | # Special case for inception because in training it has an auxiliary output. 
In train 316 | # mode we calculate the loss by summing the final output and the auxiliary output 317 | # but in testing we only consider the final output. 318 | outputs = model_ft(inputs) 319 | preds = torch.round(nn.functional.sigmoid(outputs)) 320 | all_predictions.append(preds.cpu().numpy()) 321 | all_labels.append(labels.cpu().numpy()) 322 | 323 | # statistics 324 | iter_corrects = torch.sum(preds == labels.float().data) 325 | xidxs, yidxs = torch.where(labels.data == 1) 326 | # print(xidxs, yidxs) 327 | # print([labels.data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]) 328 | iter_recalls = sum( 329 | [x.item() for x in [labels.float().data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]]) 330 | total_positives += xidxs.size(0) 331 | 332 | labels = labels.view(-1, labels.shape[-1]) 333 | preds = preds.view(-1, labels.shape[-1]) 334 | assert labels.shape[0] == preds.shape[0] 335 | 336 | 337 | for label, pred in zip(labels, preds): 338 | if not torch.any(label): 339 | no_char_images += 1 340 | if torch.all(torch.eq(label.float().data, pred)): 341 | image_accuracy += 1 342 | 343 | running_corrects += iter_corrects 344 | running_recalls += iter_recalls 345 | 346 | print("Frames with no images: ", no_char_images) 347 | 348 | all_predictions = np.concatenate(all_predictions, axis=0) 349 | all_labels = np.concatenate(all_labels, axis=0) 350 | print(all_predictions.shape, all_labels.shape, image_accuracy, len(image_dataset)) 351 | # preds = np.round(1 / (1 + np.exp(-all_predictions))) 352 | print(classification_report(all_labels, all_predictions, digits=4)) 353 | print("Accuracy: ", accuracy_score(all_labels, all_predictions)) 354 | 355 | epoch_acc = float(running_corrects) * 100 / (all_labels.shape[0] * all_labels.shape[1]) 356 | epoch_recall = float(running_recalls) * 100 / total_positives 357 | print('Manually calculated accuracy: ', epoch_acc) 358 | print('{} Acc: {:.4f} Recall: {:.4f}%'.format(phase, accuracy_score(all_labels, all_predictions), epoch_recall)) 359 | print('{} Image Exact Match (Frame) Acc: {:.4f}%'.format(phase, image_accuracy * 100 / all_labels.shape[0])) 360 | 361 | 362 | if __name__ == "__main__": 363 | 364 | parser = argparse.ArgumentParser(description='Evaluate for Character Recall & InceptionScore') 365 | parser.add_argument('--dataset', type=str, default='pororo') 366 | parser.add_argument('--img_ref_dir', type=str, required=True) 367 | parser.add_argument('--img_gen_dir', type=str, required=True) 368 | parser.add_argument('--num_classes', type=int, required=True) 369 | parser.add_argument('--model_path', type=str, required=True) 370 | parser.add_argument('--model_name', type=str, required=True) 371 | parser.add_argument('--mode', type=str, required=True) 372 | parser.add_argument('--ground_truth', action='store_true') 373 | args = parser.parse_args() 374 | 375 | if args.ground_truth: 376 | evaluate_gt(args) 377 | else: 378 | evaluate(args) 379 | 380 | # numpy_to_img(os.path.join(args.image_dir, 'images-epoch-%s.npy' % args.epoch_start), 381 | # os.path.join(args.image_dir, 'images-epoch-%s/' % args.epoch_start), 299) 382 | 383 | -------------------------------------------------------------------------------- /story-dalle/eval_char_clf.sh: -------------------------------------------------------------------------------- 1 | #python eval_char_clf.py --dataset pororo --img_ref_dir /nas-ssd/adyasha/datasets/pororo_png/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/pororo/test_images/images/ --model_name inception 
--model_path ./out/pororo-epoch-10.pt --mode test --num_classes 9 2 | 3 | python eval_char_clf.py --dataset flintstones --img_ref_dir /nas-ssd/adyasha/datasets/flintstones/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/flintstones/test_images/images/ --model_name inception --model_path /playpen-ssd/adyasha/projects/StoryGAN/classifier/models/inception_32_1e-05/best.pt --mode test --num_classes 7 4 | -------------------------------------------------------------------------------- /story-dalle/eval_fid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import PIL 4 | import argparse 5 | import functools 6 | import os 7 | from vfid.fid_score import fid_score 8 | import torchvision.datasets as datasets 9 | 10 | def main(args): 11 | 12 | 13 | image_transforms = transforms.Compose([ 14 | transforms.Resize((args.imsize, args.imsize)), 15 | # transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 18 | 19 | if args.task == 'pororo': 20 | import pororo_dataloader as data 21 | elif args.task == 'flintstones': 22 | import flintstones_dataloader as data 23 | elif args.task == 'didemo': 24 | import didemo_dataloader as data 25 | else: 26 | raise ValueError 27 | 28 | try: 29 | ref_dataset = data.StoryImageDataset(args.img_ref_dir, 30 | None, 31 | preprocess=image_transforms, 32 | mode=args.mode) 33 | except TypeError: 34 | ref_dataset = data.StoryImageDataset(args.img_ref_dir, 35 | None, 36 | transform=image_transforms, 37 | mode=args.mode) 38 | 39 | gen_dataset = datasets.ImageFolder(root=args.img_gen_dir, transform=image_transforms) 40 | 41 | 42 | fid = fid_score(ref_dataset, gen_dataset, cuda=True, normalize=True, r_cache=os.path.join(args.img_ref_dir, 'fid_cache_%s.npz' % args.mode), batch_size=1) 43 | print('Frechet Image Distance: ', fid) 44 | 45 | 46 | if __name__ == "__main__": 47 | 48 | parser = argparse.ArgumentParser(description='Evaluate Frechet Story and Image distance') 49 | parser.add_argument('--img_ref_dir', type=str, required=True) 50 | parser.add_argument('--img_gen_dir', type=str, required=True) 51 | parser.add_argument('--mode', type=str, default='test') 52 | parser.add_argument('--task', type=str, default='pororo') 53 | parser.add_argument('--imsize', type=int, default=64) 54 | args = parser.parse_args() 55 | 56 | print(args) 57 | main(args) 58 | -------------------------------------------------------------------------------- /story-dalle/eval_fid.sh: -------------------------------------------------------------------------------- 1 | python eval_fid.py --task pororo --img_ref_dir /nas-ssd/adyasha/datasets/pororo/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/pororo/test_images/ --mode test 2 | 3 | python eval_fid.py --task didemo --img_ref_dir /nas-ssd/adyasha/datasets/didemo/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/didemo/test_images/ --mode test -------------------------------------------------------------------------------- /story-dalle/flintstones_dataloader.py: -------------------------------------------------------------------------------- 1 | import os, pickle 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch.utils.data 5 | import PIL 6 | from random import randrange 7 | import json 8 | from torchvision import transforms 9 | from PIL import Image 10 | 11 | unique_characters = ["Wilma", "Fred", "Betty", "Barney", "Dino", "Pebbles", "Mr Slate"] 12 | 13 | class 
ImageDataset(torch.utils.data.Dataset): 14 | def __init__(self, dir_path, tokenizer, preprocess, mode='train'): 15 | self.dir_path = dir_path 16 | 17 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r')) 18 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"] 19 | 20 | if mode == 'train': 21 | self.orders = train_id 22 | elif mode =='val': 23 | self.orders = val_id 24 | elif mode == 'test': 25 | self.orders = test_id 26 | else: 27 | raise ValueError 28 | print("Total number of clips {}".format(len(self.orders))) 29 | 30 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'))) 31 | self.descriptions = {} 32 | for sample in annotations: 33 | self.descriptions[sample["globalID"]] = sample["description"] 34 | 35 | self.preprocess = preprocess 36 | self.tokenizer = tokenizer 37 | 38 | def __getitem__(self, item): 39 | 40 | # single image input 41 | globalID = self.orders[item] 42 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy') 43 | arr = np.load(path) 44 | n_frames = arr.shape[0] 45 | random_range = randrange(n_frames) 46 | im = arr[random_range] 47 | image = np.array(im) 48 | image = PIL.Image.fromarray(image.astype('uint8'), 'RGB') 49 | text = self.descriptions[globalID] 50 | tokens = self.tokenizer.encode(text.lower()) 51 | tokens = torch.LongTensor(tokens.ids) 52 | image = self.preprocess(image) 53 | 54 | return image, tokens 55 | 56 | def __len__(self): 57 | return len(self.orders) 58 | 59 | 60 | class StoryImageDataset(torch.utils.data.Dataset): 61 | def __init__(self, dir_path, tokenizer, transform=None, mode='train', im_input_size=128, out_img_folder='', return_labels=False): 62 | self.dir_path = dir_path 63 | 64 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r')) 65 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"] 66 | 67 | min_len = 4 68 | if os.path.exists(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl')): 69 | self.followings = pickle.load(open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'rb')) 70 | else: 71 | print("Cache does not exist") 72 | all_clips = train_id + val_id + test_id 73 | all_clips.sort() 74 | for idx, clip in enumerate(tqdm(all_clips, desc="Counting total number of frames")): 75 | season, episode = int(clip.split('_')[1]), int(clip.split('_')[3]) 76 | has_frames = True 77 | for c in all_clips[idx+1:idx+min_len+1]: 78 | s_c, e_c = int(c.split('_')[1]), int(c.split('_')[3]) 79 | if s_c != season or e_c != episode: 80 | has_frames = False 81 | break 82 | if has_frames: 83 | self.followings[clip] = all_clips[idx+1:idx+min_len+1] 84 | else: 85 | continue 86 | pickle.dump(self.followings, open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'wb')) 87 | 88 | train_id = [tid for tid in train_id if tid in self.followings] 89 | val_id = [vid for vid in val_id if vid in self.followings] 90 | test_id = [tid for tid in test_id if tid in self.followings] 91 | 92 | self.labels = pickle.load(open(os.path.join(dir_path, 'labels.pkl'), 'rb')) 93 | 94 | if mode == 'train': 95 | self.orders = train_id 96 | elif mode =='val': 97 | val_id = [vid for vid in val_id if len(self.followings[vid]) == 4] 98 | self.orders = val_id 99 | elif mode == 'test': 100 | test_id = [vid for vid in test_id if len(self.followings[vid]) == 4] 101 | self.orders = test_id[:1900] 102 | else: 103 | raise ValueError 104 | print("Total number of 
clips {}".format(len(self.orders))) 105 | 106 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json'))) 107 | self.descriptions = {} 108 | for sample in annotations: 109 | self.descriptions[sample["globalID"]] = sample["description"] 110 | 111 | self.embeds = np.load(os.path.join(self.dir_path, "flintstones_use_embeddings.npy")) 112 | self.sent2idx = pickle.load(open(os.path.join(self.dir_path, 'flintstones_use_embed_idxs.pkl'), 'rb')) 113 | 114 | if mode == 'train': 115 | if transform: 116 | self.transform = transform 117 | else: 118 | self.transform = transforms.Compose([ 119 | transforms.RandomResizedCrop(im_input_size), 120 | transforms.RandomHorizontalFlip(), 121 | transforms.ToTensor(), 122 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 123 | ]) 124 | else: 125 | if transform: 126 | self.transform = transform 127 | else: 128 | self.transform = transforms.Compose([ 129 | transforms.Resize(im_input_size), 130 | transforms.CenterCrop(im_input_size), 131 | transforms.ToTensor(), 132 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 133 | ]) 134 | 135 | self.tokenizer = tokenizer 136 | self.return_labels = return_labels 137 | self.out_img_folder = out_img_folder 138 | 139 | def __getitem__(self, item): 140 | 141 | # single image input 142 | globalIDs = [self.orders[item]] + self.followings[self.orders[item]] 143 | tokens = [] 144 | images = [] 145 | for idx, globalID in enumerate(globalIDs): 146 | if self.out_img_folder and idx != 0: 147 | image = Image.open(os.path.join(self.out_img_folder, 'gen_sample_%s_%s.png' % (item, idx-1))).convert('RGB') 148 | else: 149 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy') 150 | arr = np.load(path) 151 | n_frames = arr.shape[0] 152 | random_range = randrange(n_frames) 153 | im = arr[random_range] 154 | # image = np.array(im) 155 | image = PIL.Image.fromarray(im.astype('uint8'), 'RGB') 156 | images.append(image) 157 | text = self.descriptions[globalID] 158 | if idx != 0: 159 | if self.tokenizer is not None: 160 | tokens.append(self.tokenizer.encode(text.lower())) 161 | else: 162 | tokens.append(text) 163 | if self.tokenizer is not None: 164 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens]) 165 | 166 | sent_embeds = [torch.tensor(self.embeds[self.sent2idx[globalID]]) for globalID in globalIDs[1:]] 167 | 168 | if self.return_labels: 169 | labels = [torch.tensor(self.labels[globalID]) for globalID in globalIDs[1:]] 170 | return torch.stack([self.transform(im) for im in images[1:]]), torch.stack(labels), tokens, self.transform( 171 | images[0]), torch.stack(sent_embeds) 172 | else: 173 | return torch.stack([self.transform(im) for im in images[1:]]), tokens, self.transform(images[0]), torch.stack(sent_embeds) 174 | 175 | def __len__(self): 176 | return len(self.orders) 177 | 178 | 179 | # if __name__ == "__main__": 180 | # 181 | # dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/flintstones', None, None, 'val') 182 | # for item in range(len(dataset)): 183 | # texts = [] 184 | # globalIDs = [dataset.orders[item]] + dataset.followings[dataset.orders[item]] 185 | # for idx, globalID in enumerate(globalIDs): 186 | # text = dataset.descriptions[globalID] 187 | # texts.append(text) 188 | # if len(texts) != 5: 189 | # print(item, globalIDs) 190 | 191 | -------------------------------------------------------------------------------- /story-dalle/get_use_embeddings.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import pickle 4 | from tqdm import tqdm 5 | import numpy as np 6 | import sys 7 | import os 8 | 9 | def get_embeddings(dataset='didemo', data_dir = ''): 10 | 11 | 12 | if dataset == 'flintstones': 13 | annotations = json.load(open('../flintstones/flintstones_annotations_v1-0.json', 'r')) 14 | globalIDs = [s["globalID"] for s in annotations] 15 | descriptions = [s["description"] for s in annotations] 16 | elif dataset == 'didemo': 17 | descriptions = {} 18 | for filepath in ['train.json', 'val.json', 'test.json']: 19 | d = {tup[1]: tup[0] for ex in json.load(open(os.path.join(data_dir, filepath))) for tup in ex['desc_to_frame']} 20 | descriptions.update(d) 21 | elif dataset == 'mpii': 22 | all_keys = [] 23 | descriptions = {} 24 | for filepath in ['train.json', 'val.json', 'test.json']: 25 | data = json.load(open(os.path.join(data_dir, filepath))) 26 | print(len(data)) 27 | for ex in tqdm(data): 28 | for tup in ex['desc_to_frame']: 29 | k = tup[1].replace('/ssd-playpen/dhannan/StoryDatasets/mpii/', '') 30 | all_keys.append(k) 31 | descriptions[k] = tup[0] 32 | print(len(descriptions), len(all_keys), len(set(all_keys))) 33 | print(set(all_keys)) 34 | else: 35 | raise ValueError 36 | sys.exit() 37 | 38 | import tensorflow_hub as hub 39 | embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5") 40 | all_embeddings = None 41 | bs = 128 42 | description_keys = list(descriptions.keys()) 43 | for i in tqdm(range(0, math.ceil(len(description_keys)/bs)), desc="Extraction seed embeddings"): 44 | # embeddings = embed(descriptions[i*bs:(i+1)*bs]).numpy() 45 | embeddings = embed([descriptions[k] for k in description_keys[i * bs:(i + 1) * bs]]).numpy() 46 | if all_embeddings is None: 47 | all_embeddings = embeddings 48 | else: 49 | all_embeddings = np.concatenate([all_embeddings, embeddings], axis=0) 50 | print(all_embeddings.shape, len(description_keys)) 51 | embeddings = {k: v for v, k in zip(all_embeddings, description_keys)} 52 | pickle.dump(embeddings, open(os.path.join(data_dir, '%s_use_embed_idxs.pkl' % dataset), 'wb')) 53 | 54 | # np.save(os.path.join(data_dir, '%s_use_embeddings.npy' % dataset), all_embeddings) 55 | # pickle.dump({key: val for val, key in enumerate(globalIDs)}, open(os.path.join(data_dir, '%s_use_embed_idxs.pkl'), 'wb')) 56 | 57 | 58 | # get_embeddings('didemo', '/nas-ssd/adyasha/datasets/didemo') 59 | get_embeddings('mpii', '/nas-ssd/adyasha/datasets/mpii') -------------------------------------------------------------------------------- /story-dalle/infer_story.sh: -------------------------------------------------------------------------------- 1 | if [ "$1" = "pororo" ]; then 2 | echo "Evaluating on Pororo" 3 | DATA_DIR=../data/pororo 4 | OUTPUT_ROOT=../out/pororo 5 | MODEL_CKPT='' 6 | SENT_EMBED=512 7 | STORY_LEN=4 8 | elif [ "$1" = "flintstones" ]; then 9 | echo "Evaluating on Flintstones" 10 | DATA_DIR=../data/flintstones 11 | OUTPUT_ROOT=./out/flintstones 12 | MODEL_CKPT='' 13 | SENT_EMBED=512 14 | STORY_LEN=4 15 | elif [ "$1" = "didemo" ]; then 16 | echo "Evaluating on DiDeMo" 17 | DATA_DIR=../data/didemo 18 | OUTPUT_ROOT=./out/didemo 19 | MODEL_CKPT='' 20 | SENT_EMBED=512 21 | STORY_LEN=2 22 | fi 23 | 24 | 25 | python ./infer_t2i.py \ 26 | --model_name_or_path $MODEL_CKPT \ 27 | --prefix_model_name_or_path './1.3B/' \ 28 | --dataset_name $1 \ 29 | --tuning_mode story \ 30 | --dataset_name $1 \ 31 | --preseqlen 32 \ 32 | --condition \ 
33 | --story_len $STORY_LEN \ 34 | --sent_embed $SENT_EMBED \ 35 | --prefix_dropout 0.2 \ 36 | --data_dir $DATA_DIR \ 37 | --dataloader_num_workers 1 \ 38 | --do_eval \ 39 | --per_gpu_eval_batch_size 16 \ 40 | --output_dir $OUTPUT_ROOT \ 41 | --mode $2 42 | -------------------------------------------------------------------------------- /story-dalle/train_story.sh: -------------------------------------------------------------------------------- 1 | if [ "$1" = "pororo" ]; then 2 | echo "Training on Pororo" 3 | DATA_DIR=../data/pororo/ 4 | OUTPUT_ROOT=./out/pororo 5 | SENT_EMBED=512 6 | STORY_LEN=4 7 | LR=1e-4 8 | TRAIN_BS=1 9 | GRAD_ACC=4 10 | elif [ "$1" = "flintstones" ]; then 11 | echo "Training on Flintstones" 12 | DATA_DIR=../data/flintstones 13 | OUTPUT_ROOT=./out/flintstones 14 | SENT_EMBED=512 15 | STORY_LEN=4 16 | LR=1e-5 17 | TRAIN_BS=1 18 | GRAD_ACC=4 19 | elif [ "$1" = "didemo" ]; then 20 | echo "Training on DiDeMo" 21 | DATA_DIR=../data/didemo 22 | OUTPUT_ROOT=./out/didemo 23 | SENT_EMBED=512 24 | STORY_LEN=2 25 | TRAIN_BS=1 26 | GRAD_ACC=8 27 | fi 28 | 29 | LOG_DIR=../runs/ 30 | 31 | python ./train_t2i.py \ 32 | --prefix_model_name_or_path './1.3B/' \ 33 | --tuning_mode story \ 34 | --dataset_name $1 \ 35 | --preseqlen 32 \ 36 | --condition \ 37 | --story_len $STORY_LEN \ 38 | --sent_embed $SENT_EMBED \ 39 | --prefix_dropout 0.2 \ 40 | --data_dir $DATA_DIR \ 41 | --dataloader_num_workers 4 \ 42 | --output_dir $OUTPUT_ROOT \ 43 | --log_dir $LOG_DIR \ 44 | --do_train --do_eval \ 45 | --per_gpu_train_batch_size $TRAIN_BS \ 46 | --per_gpu_eval_batch_size 2 \ 47 | --num_train_epochs 50 \ 48 | --gradient_accumulation_steps $GRAD_ACC \ 49 | --learning_rate $LR \ 50 | --logging_steps 50 \ 51 | --eval_steps 500 \ 52 | --generate_steps 1000 53 | 54 | -------------------------------------------------------------------------------- /story-dalle/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.utils as vutils 4 | from torchvision.utils import save_image 5 | from tqdm import tqdm 6 | import numpy as np 7 | import PIL 8 | 9 | def save_as_image(tensor, out_dir, suffix): 10 | img = vutils.make_grid(tensor, video_len=1, padding=0) 11 | save_image(img, '%s/gen_sample_%s.png' % (out_dir, suffix)) 12 | 13 | def acc_tensors_to_images(tensor_dir, key, out_dir): 14 | 15 | files = [f for f in os.listdir(tensor_dir) if f.endswith('pt') and key in f] 16 | sorted_files = sorted(files, key=lambda x: int(x[:-3].split('_')[-1])) 17 | print(files[:10]) 18 | print(sorted_files[:20]) 19 | all_tensors = [] 20 | for f in tqdm(files, desc="eading tensors"): 21 | t = torch.load(os.path.join(tensor_dir, f)) 22 | # print(t[0].shape) 23 | # print(t.shape) 24 | all_tensors.append(t) 25 | all_tensors = torch.cat(all_tensors, dim=0) 26 | print(all_tensors.shape) 27 | 28 | torch.save(all_tensors, os.path.join(tensor_dir, 'sdalle_story_%s.pt' % key)) 29 | 30 | # for i in tqdm(range(0, all_tensors.shape[0]), desc='Preapring images'): 31 | # for j in range(0, all_tensors.shape[1]): 32 | # save_as_image(all_tensors[i, j], out_dir, '%s_%s' % (i, j)) 33 | 34 | def images_to_numpy(tensor): 35 | generated = tensor.data.cpu().numpy().transpose(1, 2, 0) 36 | generated[generated < -1] = -1 37 | generated[generated > 1] = 1 38 | generated = (generated + 1) / 2 * 255 39 | return generated.astype('uint8') 40 | 41 | def numpy_to_img(numpy_file, outdir, img_size): 42 | 43 | if not os.path.exists(outdir): 44 | os.makedirs(outdir) 45 | 46 | x 
47 |     print("Numpy image file shape: ", x.shape)
48 |     for i in tqdm(range(x.shape[0])):
49 |         frames = x[i, :, :, :, :]
50 |         frames = np.swapaxes(frames, 0, 1)
51 |         # frames = torch.Tensor(frames).view(-1, 3, 64, 64)
52 |         # frames = torch.nn.functional.upsample(frames, size=(img_size, img_size), mode="bilinear")
53 | 
54 |         # vutils.save_image(vutils.make_grid(torch.Tensor(frames).view(-1, 3, 64, 64), 1, padding=0), 'sequence-2.png')
55 |         all_images = images_to_numpy(vutils.make_grid(torch.Tensor(frames).view(-1, 3, 64, 64), 1, padding=0))
56 |         # all_images = images_to_numpy(vutils.make_grid(frames, 1, padding=0))
57 |         # print(all_images.shape)
58 |         for j, idx in enumerate(range(64, all_images.shape[0] + 1, 64)):
59 |             output = PIL.Image.fromarray(all_images[idx-64: idx, :, :])
60 |             output.save(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
61 |             img = PIL.Image.open(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
62 |             if img_size != 64:
63 |                 img = img.resize((img_size, img_size,))
64 |                 img.save(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
65 | 
66 | if __name__ == "__main__":
67 | 
68 |     acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/pororo/', 'test', '/nas-ssd/adyasha/out/minDALLEs/pororo/test_images')
69 |     # acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/didemo/', 'test', '/nas-ssd/adyasha/out/minDALLEs/didemo/test_images/images')
70 |     # acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/flintstones/', 'test', '/nas-ssd/adyasha/out/minDALLEs/flintstones/test_images/images')
71 | 
72 |     # numpy_to_img('/nas-ssd/adyasha/out/SGANc/didemo/val-images-epoch-120.npy', '/nas-ssd/adyasha/out/SGANc/didemo/val_images/', 299)
--------------------------------------------------------------------------------
/story-dalle/vfid/__pycache__/fid_score.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/vfid/__pycache__/fid_score.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/vfid/__pycache__/inception.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/vfid/__pycache__/inception.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/vfid/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate a Video GAN
3 | 
4 | The difference in this GAN is that the original encoder is replaced with a residual (2+1)D encoder
5 | 
6 | The FID metric calculates the distance between two distributions of images.
7 | Typically, we have summary statistics (mean & covariance matrix) of one
8 | of these distributions, while the 2nd distribution is given by a GAN.
9 | 
10 | When run as a stand-alone program, it compares the distribution of
11 | images that are stored as PNG/JPEG at a specified location with a
12 | distribution given by summary statistics (in pickle format).
13 | 
14 | The FID is calculated by assuming that X_1 and X_2 are the activations of
15 | the pool_3 layer of the inception net for generated samples and real world
16 | samples respectively.
17 | 
18 | See --help to see further details.
19 | 
20 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
21 | of Tensorflow
22 | 
23 | Copyright 2018 Institute of Bioinformatics, JKU Linz
24 | 
25 | Licensed under the Apache License, Version 2.0 (the "License");
26 | you may not use this file except in compliance with the License.
27 | You may obtain a copy of the License at
28 | 
29 |     http://www.apache.org/licenses/LICENSE-2.0
30 | 
31 | Unless required by applicable law or agreed to in writing, software
32 | distributed under the License is distributed on an "AS IS" BASIS,
33 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34 | See the License for the specific language governing permissions and
35 | limitations under the License.
36 | """
37 | import os
38 | import numpy as np
39 | import torch
40 | import torch.nn.functional as F
41 | from torch.utils.data import DataLoader
42 | from tqdm import tqdm
43 | from scipy import linalg
44 | import PIL
45 | import functools
46 | from .inception import InceptionV3
47 | 
48 | def calculate_activation_statistics(imgs, model, batch_size=32, dims=2048,
49 |                                     cuda=False, normalize=False, verbose=0, is_ref=False):
50 |     """Calculates the activations of the pool_3 layer for all images.
51 | 
52 |     Params:
53 |         imgs: image dataset
54 |         model: Instance of inception model
55 |         batch_size: Batch size of images for the model to process at once.
56 |             Make sure that the number of samples is a multiple of the batch
57 |             size, otherwise some samples are ignored. This behavior is retained
58 |             to match the original FID score implementation.
59 |         cuda: If set to True, use GPU
60 |         normalize: If the value range of imgs is [-1, 1], set to True to
61 |             shift value range to [0, 1].
62 |         verbose: If verbose > 0, show progressbar during evaluation
63 |     Returns:
64 |         mu: The mean over samples of the activations of the pool_3 layer of
65 |             the inception model.
66 |         sigma: The covariance matrix of the activations of the pool_3 layer of
67 |             the inception model.
68 |     """
69 |     model.eval()
70 |     if cuda:
71 |         device = torch.device('cuda')
72 |     else:
73 |         device = torch.device('cpu')
74 |     model.to(device)
75 | 
76 |     with torch.no_grad():
77 |         features = []
78 |         features_cache = []
79 |         dataloader = DataLoader(
80 |             imgs, batch_size=batch_size, num_workers=4 if is_ref else 0, drop_last=True, shuffle=False)
81 |         if verbose > 0:
82 |             iter_dataset = tqdm(dataloader, dynamic_ncols=True)
83 |         else:
84 |             iter_dataset = dataloader
85 |         for batch in iter_dataset:  # iter_dataset is already a tqdm iterator when verbose > 0
86 |             images = batch[0]
87 |             if len(images.shape) == 5:
88 |                 video_len, n_channels, h, w = images.size(-4), images.size(-3), images.size(-2), images.size(-1)
89 |                 images = images.view(batch_size * video_len, n_channels, h, w)
90 |             # print(images.shape)
91 |             images = images.type(torch.FloatTensor).to(device)
92 | 
93 |             # print(images.shape)
94 |             if normalize:
95 |                 images = (images + 1) / 2  # [-1, 1] -> [0, 1]
96 |             if images.size(3) != 299:
97 |                 images = F.interpolate(images, size=(299, 299),
98 |                                        mode='bilinear', align_corners=False)
99 |             pred = model(images)[0]
100 | 
101 |             # If model output is not scalar, apply global spatial average
102 |             # pooling. This happens if you choose a dimensionality not equal
103 |             # 2048.
104 |             if pred.shape[2] != 1 or pred.shape[3] != 1:
105 |                 pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
106 |             features.append(pred.cpu().numpy().reshape(-1, dims))
107 |             features_cache.append(np.expand_dims(pred.cpu().numpy().reshape(-1, dims), axis=0))
108 | 
109 |         features = np.concatenate(features, axis=0)
110 |         mu = np.mean(features, axis=0)
111 |         sigma = np.cov(features, rowvar=False)
112 | 
113 |     return mu, sigma, np.concatenate(features_cache, axis=0)
114 | 
115 | 
116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
117 |     """Numpy implementation of the Frechet Distance.
118 |     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
119 |     and X_2 ~ N(mu_2, C_2) is
120 |     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
121 | 
122 |     Stable version by Dougal J. Sutherland.
123 | 
124 |     Params:
125 |         mu1: Numpy array containing the activations of a layer of the
126 |             inception net (like returned by the function 'get_predictions')
127 |             for generated samples.
128 |         mu2: The sample mean over activations, precalculated on a
129 |             representative data set.
130 |         sigma1: The covariance matrix over activations for generated samples.
131 |         sigma2: The covariance matrix over activations, precalculated on a
132 |             representative data set.
133 | 
134 |     Returns:
135 |         The Frechet Distance.
136 |     """
137 | 
138 |     mu1 = np.atleast_1d(mu1)
139 |     mu2 = np.atleast_1d(mu2)
140 | 
141 |     sigma1 = np.atleast_2d(sigma1)
142 |     sigma2 = np.atleast_2d(sigma2)
143 | 
144 |     assert mu1.shape == mu2.shape, \
145 |         'Training and test mean vectors have different lengths'
146 |     assert sigma1.shape == sigma2.shape, \
147 |         'Training and test covariances have different dimensions'
148 | 
149 |     diff = mu1 - mu2
150 | 
151 |     # Product might be almost singular
152 |     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
153 |     if not np.isfinite(covmean).all():
154 |         print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
155 |         offset = np.eye(sigma1.shape[0]) * eps
156 |         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
157 | 
158 |     # Numerical error might give slight imaginary component
159 |     if np.iscomplexobj(covmean):
160 |         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
161 |             m = np.max(np.abs(covmean.imag))
162 |             raise ValueError('Imaginary component {}'.format(m))
163 |         covmean = covmean.real
164 | 
165 |     return (diff.dot(diff) +
166 |             np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))
167 | 
168 | 
169 | def fid_score(r_imgs, g_imgs, batch_size=32, dims=2048, cuda=False,
170 |               normalize=False, r_cache=None, verbose=0):
171 |     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
172 |     model = InceptionV3([block_idx])
173 | 
174 |     # cache real dataset
175 |     if r_cache and not r_cache.endswith('.npz'):
176 |         r_cache = r_cache + '.npz'
177 |     if r_cache and os.path.exists(r_cache):
178 |         data = np.load(r_cache)
179 |         m1, s1 = data['m1'], data['s1']
180 |     else:
181 |         m1, s1, f1 = calculate_activation_statistics(r_imgs, model, batch_size,
182 |                                                      dims, cuda, normalize)
183 |         if r_cache is not None:
184 |             # os.makedirs(os.path.dirname(r_cache), exist_ok=True)
185 |             np.savez(r_cache, m1=m1, s1=s1)
186 |             np.save(r_cache.replace('.npz', '.npy'), f1)
187 |     # compute generative image dataset
188 |     m2, s2, f2 = calculate_activation_statistics(g_imgs, model, batch_size, dims,
189 |                                                  cuda, normalize)
190 |     if r_cache is not None: np.save(r_cache.replace('.npz', '_gen.npy'), f2)  # only save when a cache path was given
191 |     fid_value = calculate_frechet_distance(m1, s1, m2, s2)
192 | 
193 | return fid_value -------------------------------------------------------------------------------- /story-dalle/vfid/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 189 | inception.load_state_dict(state_dict) 190 | return inception 191 | 192 | 193 | class FIDInceptionA(models.inception.InceptionA): 194 | """InceptionA block patched for FID computation""" 195 | def __init__(self, in_channels, pool_features): 196 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 197 | 198 | def forward(self, x): 199 | branch1x1 = self.branch1x1(x) 200 | 201 | branch5x5 = self.branch5x5_1(x) 202 | branch5x5 = self.branch5x5_2(branch5x5) 203 | 204 | branch3x3dbl = self.branch3x3dbl_1(x) 205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 207 | 208 | # Patch: Tensorflow's average pool does not use the padded zero's in 209 | # its average calculation 210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 211 | count_include_pad=False) 212 | branch_pool = self.branch_pool(branch_pool) 213 | 214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 215 | return torch.cat(outputs, 1) 216 | 217 | 218 | class FIDInceptionC(models.inception.InceptionC): 219 | """InceptionC block patched for FID computation""" 220 | def __init__(self, in_channels, channels_7x7): 221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 222 | 223 | def forward(self, x): 224 | branch1x1 = self.branch1x1(x) 225 | 226 | branch7x7 = self.branch7x7_1(x) 227 | branch7x7 = self.branch7x7_2(branch7x7) 228 | branch7x7 = self.branch7x7_3(branch7x7) 229 | 230 | branch7x7dbl = self.branch7x7dbl_1(x) 231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 235 | 236 | # Patch: Tensorflow's average pool does not use the padded zero's in 237 | # its average calculation 238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 239 | count_include_pad=False) 240 | branch_pool = self.branch_pool(branch_pool) 241 | 242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class FIDInceptionE_1(models.inception.InceptionE): 247 | """First InceptionE block patched for FID computation""" 248 | def __init__(self, in_channels): 249 | super(FIDInceptionE_1, self).__init__(in_channels) 250 | 251 | def forward(self, x): 252 | branch1x1 = self.branch1x1(x) 253 | 254 | branch3x3 = self.branch3x3_1(x) 255 | branch3x3 = [ 256 | self.branch3x3_2a(branch3x3), 257 | self.branch3x3_2b(branch3x3), 258 | ] 259 | branch3x3 = torch.cat(branch3x3, 1) 260 | 261 | branch3x3dbl = self.branch3x3dbl_1(x) 262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 263 | branch3x3dbl = [ 264 | self.branch3x3dbl_3a(branch3x3dbl), 265 | self.branch3x3dbl_3b(branch3x3dbl), 266 | ] 
267 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 268 | 269 | # Patch: Tensorflow's average pool does not use the padded zero's in 270 | # its average calculation 271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 272 | count_include_pad=False) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | 278 | 279 | class FIDInceptionE_2(models.inception.InceptionE): 280 | """Second InceptionE block patched for FID computation""" 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) --------------------------------------------------------------------------------
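
Usage note (added for illustration; this is not a file in the repository): the sketch below shows one way the fid_score utility defined in story-dalle/vfid/fid_score.py might be called directly on two small sets of frames. The tensor shapes, the TensorDataset wrapping, the dims=64 setting, and the package-style import are assumptions made for this example; the repository's own evaluation path is story-dalle/eval_fid.py driven by eval_fid.sh.

# Illustrative sketch only -- every name and value below is an assumption, not repository code.
# Run from the story-dalle/ directory so that vfid resolves as a package.
import torch
from torch.utils.data import TensorDataset

from vfid.fid_score import fid_score

# Hypothetical reference and generated frames in the [-1, 1] range, shaped [N, C, H, W].
real_frames = torch.rand(256, 3, 64, 64) * 2 - 1
generated_frames = torch.rand(256, 3, 64, 64) * 2 - 1

# TensorDataset yields (images,) tuples, which matches the batch[0] access inside
# calculate_activation_statistics; frames are resized to 299x299 internally.
real_ds = TensorDataset(real_frames)
gen_ds = TensorDataset(generated_frames)

# dims=64 (first max-pool features) keeps this smoke test small and numerically stable;
# real evaluations use the default dims=2048 over the full test split.
score = fid_score(real_ds, gen_ds,
                  batch_size=32,
                  dims=64,
                  cuda=torch.cuda.is_available(),
                  normalize=True,   # inputs above are in [-1, 1]
                  r_cache=None,     # optional path (in an existing directory) to cache real-set statistics
                  verbose=1)
print('FID (smoke test):', score)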