├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── storydalle.iml
│   └── vcs.xml
├── DEMO.MD
├── LICENSE
├── MODEL_CARD.MD
├── README.MD
├── assets
│   ├── demo.pdf
│   ├── demo.png
│   ├── demo_pororo_good.png
│   ├── pororo_characters.png
│   ├── story_dalle.png
│   └── story_dalle_predictions.png
├── mega-story-dalle
│   ├── didemo_dataloader.py
│   ├── flintstones_dataloader.py
│   ├── min_dalle
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── min_dalle.cpython-38.pyc
│   │   │   └── text_tokenizer.cpython-38.pyc
│   │   ├── min_dalle.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── dalle_bart_decoder.cpython-38.pyc
│   │   │   │   ├── dalle_bart_encoder.cpython-38.pyc
│   │   │   │   └── vqgan_detokenizer.cpython-38.pyc
│   │   │   ├── dalle_bart_decoder.py
│   │   │   ├── dalle_bart_encoder.py
│   │   │   └── vqgan_detokenizer.py
│   │   └── text_tokenizer.py
│   ├── pororo_dataloader.py
│   ├── setup.py
│   ├── tkinter_ui.py
│   ├── train_story.sh
│   └── train_t2i.py
└── story-dalle
    ├── 1.3B
    │   └── config.yaml
    ├── __init__.py
    ├── __pycache__
    │   └── pororo_dataloader.cpython-38.pyc
    ├── dalle
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── __init__.cpython-38.pyc
    │   │   └── trainer_prefix.cpython-38.pyc
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── __init__.cpython-37.pyc
    │   │   │   ├── __init__.cpython-38.pyc
    │   │   │   ├── prefix_tuning_model.cpython-38.pyc
    │   │   │   └── tokenizer.cpython-38.pyc
    │   │   ├── stage1
    │   │   │   ├── __pycache__
    │   │   │   │   ├── layers.cpython-37.pyc
    │   │   │   │   ├── layers.cpython-38.pyc
    │   │   │   │   ├── vqgan.cpython-37.pyc
    │   │   │   │   └── vqgan.cpython-38.pyc
    │   │   │   ├── layers.py
    │   │   │   └── vqgan.py
    │   │   ├── stage2
    │   │   │   ├── __pycache__
    │   │   │   │   ├── layers.cpython-37.pyc
    │   │   │   │   ├── layers.cpython-38.pyc
    │   │   │   │   ├── transformer.cpython-37.pyc
    │   │   │   │   └── transformer.cpython-38.pyc
    │   │   │   ├── layers.py
    │   │   │   └── transformer.py
    │   │   └── tokenizer.py
    │   ├── trainer_prefix.py
    │   └── utils
    │       ├── __init__.py
    │       ├── __pycache__
    │       │   ├── __init__.cpython-37.pyc
    │       │   ├── __init__.cpython-38.pyc
    │       │   ├── config.cpython-38.pyc
    │       │   ├── sampling.cpython-38.pyc
    │       │   ├── utils.cpython-37.pyc
    │       │   └── utils.cpython-38.pyc
    │       ├── config.py
    │       ├── sampling.py
    │       └── utils.py
    ├── didemo_dataloader.py
    ├── eval_char_clf.py
    ├── eval_char_clf.sh
    ├── eval_fid.py
    ├── eval_fid.sh
    ├── flintstones_dataloader.py
    ├── get_use_embeddings.py
    ├── infer_story.sh
    ├── infer_t2i.py
    ├── pororo_dataloader.py
    ├── train_story.sh
    ├── train_t2i.py
    ├── utils.py
    └── vfid
        ├── __pycache__
        │   ├── fid_score.cpython-38.pyc
        │   └── inception.cpython-38.pyc
        ├── fid_score.py
        └── inception.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/storydalle.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/DEMO.MD:
--------------------------------------------------------------------------------
1 | ## Demo Coming Soon!
2 | (as soon as the port from the local Gradio demo to HuggingFace Spaces is complete)
3 | Meanwhile, see a snapshot of our demo's functionality below!
4 |
5 | 
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Adyasha Maharana, Darryl Hannan and Mohit Bansal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MODEL_CARD.MD:
--------------------------------------------------------------------------------
1 | ### Model Description
2 |
3 | StoryDALL-E \[1\] is a model trained for the task of Story Visualization \[2\].
4 | The model receives a sequence of captions as input and generates a corresponding sequence of images which form a visual story depicting the narrative in the captions.
5 | We modify this task to enable the model to receive an initial scene as input, which can be used as a cue for the setting of the story and also for generating unseen or low-resource visual elements. We refer to this task as Story Continuation \[1\].
6 | StoryDALL-E is based on the [dalle](https://github.com/kakaobrain/minDALL-E) model. **This model has been developed for academic purposes only.**
7 |
8 | \[[paper](https://arxiv.org/abs/2209.06192)\] \[[code](https://github.com/adymaharana/storydalle/)\] \[[demo]()\]
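For intuition, the input/output of the story-continuation setting described above can be sketched as follows (illustrative only; `model` and `generate_frame` are hypothetical names, not the repository's actual API):

```python
from typing import List
from PIL import Image

def continue_story(model, captions: List[str], source_frame: Image.Image) -> List[Image.Image]:
    # Story continuation: generate one frame per caption, conditioned on the initial scene.
    return [model.generate_frame(caption, source_frame) for caption in captions]
```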
9 |
10 | ### Dataset
11 |
12 | This model has been trained using the Pororo story visualization dataset \[2\].
13 | The data was adapted from the popular cartoon series *Pororo the Little Penguin* and originally released by \[3\].
14 | The Pororo dataset contains 9 recurring characters, as shown below, in the decreasing order of their frequency in the training data.
15 |
16 |
17 |
18 | The training set contains nearly 10,000 samples. Most of the scenes occur in a snowy village, surrounded by hills, trees and houses. A few episodes are located in gardens or near water bodies. All captions are in English and predominantly contain verbs in the present tense.
19 |
20 | Additionally, training of this model starts from the pretrained checkpoint of DALL-E Mega, which is trained on 15 million images from the Conceptual Captions dataset \[4\], a dataset scraped and filtered from billions of webpages.
21 |
22 | ### Intended Use
23 | This model is intended for generating visual stories containing the 9 characters in the Pororo dataset. This version of the StoryDALL-E model performs reasonably well in the following scenarios:
24 | * Frames containing a single character.
25 | * Overtly visual actions such as *making cookies*, *walking*, *reading a book*, *sitting*.
26 | * Scenes taking place in snowy settings, indoors and gardens.
27 | * Visual stories containing 1-3 characters across all frames.
28 | * Scene transitions e.g. from day to night.
29 |
30 | Here are some examples of generated visual stories for the above-mentioned settings.
31 |
32 |
33 |
34 |
35 | Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
36 | * Multiple characters in a frame.
37 | * Non-visual actions such as *compliment*.
38 | * Characters that are infrequent in the training dataset e.g. *Rody*, *Harry*.
39 | * Background locations that are not found in the cartoon e.g. *a busy city*.
40 | * Color-based descriptions of objects.
41 | * Completely new characters based on textual descriptions.
42 |
43 | In summary, we find that the model performs well at visualizing stories with up to three characters across all frames and struggles at generating coherent visuals for more than three characters.
44 | The model copies visual elements from the source image into each of the generated frames in the story, maintaining a continuous flow in the narration by virtue of conditioning on an initial scene.
45 | StoryDALL-E performs best at generating overtly visual actions and is capable of generating semantic concepts that do not appear in the story continuation dataset, such as *doughnut* and *lion*, by leveraging the pretrained knowledge of DALL-E Mega when possible.
46 | Most of the scenes in the Pororo dataset occur within the setting of a snowy village with wooden houses surrounded by trees and snow. Hence, the model usually generates scenes with similar visual elements.
47 |
48 | ### Ethical Considerations
49 | Our experimental results are specific to the task of story continuation.
50 | By using cartoon images in our task, we avoid the egregious ethical issues associated with real-world usage of image generation such as DeepFakes.
51 | We focus not on generating realistic images, but on improved multi-modal understanding in the context of story visualization.
52 |
53 | ### Citation:
54 | ```
55 | @inproceedings{maharana2022storydalle,
56 | title={StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation},
57 | author={Maharana, Adyasha and Hannan, Darryl and Bansal, Mohit},
58 | booktitle={ECCV},
59 | year={2022}
60 | }
61 | ```
62 | Send questions, feedback or comments to adyasha@cs.unc.edu.
63 |
64 | ### References
65 |
66 | \[1\] Maharana, Adyasha, et al. "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation." ECCV. 2022.
67 |
68 | \[2\] Li, Yitong, et al. "Storygan: A sequential conditional gan for story visualization." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
69 |
70 | \[3\] Kim, Kyung-Min, et al. "DeepStory: video story QA by deep embedded memory networks." Proceedings of the 26th International Joint Conference on Artificial Intelligence. 2017.
71 |
72 | \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
73 |
74 | \[5\] Mitchell, Margaret, et al. "Model cards for model reporting." Proceedings of the conference on fairness, accountability, and transparency. 2019.
75 |
76 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | ## StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation
2 |
3 | PyTorch code for the ECCV 2022 paper "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation".
4 |
5 | \[[Paper](https://arxiv.org/abs/2209.06192)\] \[[Model Card](https://github.com/adymaharana/storydalle/blob/main/MODEL_CARD.MD)\] \[[Spaces Demo](https://huggingface.co/spaces/ECCV2022/storydalle)\] \[[Replicate Demo](https://replicate.com/adymaharana/story-dalle)\]
6 |
7 | 
8 |
9 | 
10 |
11 | ### Training
12 |
13 | #### Prepare Repository:
14 | Download the PororoSV dataset and associated files from [here](https://drive.google.com/file/d/11Io1_BufAayJ1BpdxxV2uJUvCcirbrNc/view?usp=sharing) (updated) and save it as ```./data/pororo/```.
15 | Download the FlintstonesSV dataset and associated files from [here](https://drive.google.com/file/d/1kG4esNwabJQPWqadSDaugrlF4dRaV33_/view?usp=sharing) and save it as ```./data/flintstones/```.
16 | Download the DiDeMoSV dataset and associated files from [here](https://drive.google.com/file/d/1zgj_bpE6Woyi-G76axF0nO-yzQaLBayc/view?usp=sharing) and save it as ```./data/didemo/```.
17 |
18 | This repository contains separate folders for training StoryDALL-E based on the [minDALL-E](https://github.com/kakaobrain/minDALL-E) and [DALL-E Mega](https://github.com/kuprel/min-dalle) models, i.e. the ```./story-dalle/``` and ```./mega-story-dalle/``` folders, respectively.
19 |
20 | #### Training StoryDALL-E based on minDALL-E:
21 |
22 | 1. To finetune the minDALL-E model for story continuation, first navigate to the corresponding folder:\
23 | ```cd story-dalle```
24 | 2. Set the environment variables in ```train_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$LOG_DIR``` if different from the default locations.
25 | 3. Download the pretrained checkpoint from [here](https://github.com/kakaobrain/minDALL-E) and save it in ```./1.3B```
26 | 4. Run the following command:
27 | ```bash train_story.sh ```
28 |
29 |
30 | #### Training StoryDALL-E based on DALL-E Mega:
31 |
32 | 1. To finetune the DALL-E Mega model for story continuation, first navigate to the corresponding folder:\
33 | ```cd mega-story-dalle```
34 | 2. Set the environment variables in ```train_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$LOG_DIR``` if different from the default locations.
35 | 3. Pretrained checkpoints for the generative model and the VQGAN detokenizer are downloaded automatically upon initialization. Download the pretrained weights for the VQGAN tokenizer from [here](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) and place them in the same folder as the VQGAN detokenizer.
36 | 4. Run the following command:
37 | ```bash train_story.sh ```
38 |
39 | ### Inference
40 | Pretrained checkpoints for minDALL-E based StoryDALL-E can be downloaded from here: [PororoSV](https://drive.google.com/file/d/1lJ6zMZ6qTvFu6H35-VEdFlN13MMslivJ/view?usp=sharing)
41 |
42 | For a demo of inference using cog, check out [this repo](https://github.com/daanelson/story-dalle-cog).
43 |
44 | #### Inferring from StoryDALL-E based on minDALL-E:
45 |
46 | 1. To infer from the minDALL-E model for story continuation, first navigate to the corresponding folder:\
47 | ```cd story-dalle```
48 | 2. Set the environment variables in ```infer_story.sh``` to point to the right locations in your system. Specifically, change the ```$DATA_DIR```, ```$OUTPUT_ROOT``` and ```$MODEL_CKPT``` if different from the default locations.
49 | 3. Run the following command:
50 | ```bash infer_story.sh ```
51 |
52 | #### Memory Requirements for Inference:
53 |
54 | For double-precision inference, the StoryDALL-E model requires nearly 40 GB of memory. The memory requirement can be reduced to 20 GB by performing mixed-precision inference in the autoregressive decoder (included in the codebase; see line 1095 in story-dalle/dalle/models/__init__.py). Note that the VQGAN model needs to operate at full precision to retain the high quality of the generated images.
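A minimal sketch of this kind of precision split (illustrative only; `transformer.sample` and `vqgan.decode` are hypothetical handles for the sampling model and the detokenizer, not the exact API of this repository):

```python
import torch

@torch.no_grad()
def sample_story_frames(transformer, vqgan, text_tokens, source_image_tokens):
    # Run the memory-heavy autoregressive token sampling in half precision.
    with torch.cuda.amp.autocast():
        image_tokens = transformer.sample(text_tokens, source_image_tokens)  # hypothetical call
    # Decode tokens to pixels with the VQGAN kept in full precision for image quality.
    return vqgan.decode(image_tokens.long()).float()  # hypothetical call
```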
55 |
56 |
57 | ### Acknowledgements
58 | Thanks to the fantastic folks at Kakao Brain and HuggingFace for their work on the open-sourced versions of minDALL-E and DALL-E Mega. Much of this codebase has been adapted from [here](https://github.com/kakaobrain/minDALL-E) and [here](https://github.com/kuprel/min-dalle).
--------------------------------------------------------------------------------
/assets/demo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo.pdf
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo.png
--------------------------------------------------------------------------------
/assets/demo_pororo_good.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/demo_pororo_good.png
--------------------------------------------------------------------------------
/assets/pororo_characters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/pororo_characters.png
--------------------------------------------------------------------------------
/assets/story_dalle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/story_dalle.png
--------------------------------------------------------------------------------
/assets/story_dalle_predictions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/assets/story_dalle_predictions.png
--------------------------------------------------------------------------------
/mega-story-dalle/flintstones_dataloader.py:
--------------------------------------------------------------------------------
1 | import os, pickle
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch.utils.data
5 | import PIL
6 | from random import randrange
7 | import json
8 | from torchvision import transforms
9 | from PIL import Image
10 |
11 | unique_characters = ["Wilma", "Fred", "Betty", "Barney", "Dino", "Pebbles", "Mr Slate"]
12 |
13 | class ImageDataset(torch.utils.data.Dataset):
14 | def __init__(self, dir_path, tokenizer, preprocess, mode='train'):
15 | self.dir_path = dir_path
16 |
17 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r'))
18 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"]
19 |
20 | if mode == 'train':
21 | self.orders = train_id
22 | elif mode =='val':
23 | self.orders = val_id
24 | elif mode == 'test':
25 | self.orders = test_id
26 | else:
27 | raise ValueError
28 | print("Total number of clips {}".format(len(self.orders)))
29 |
30 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')))
31 | self.descriptions = {}
32 | for sample in annotations:
33 | self.descriptions[sample["globalID"]] = sample["description"]
34 |
35 | self.preprocess = preprocess
36 | self.tokenizer = tokenizer
37 |
38 | def __getitem__(self, item):
39 |
40 | # single image input
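        # Each .npy file stores several sampled video frames for the clip; one frame is picked at random below.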
41 | globalID = self.orders[item]
42 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy')
43 | arr = np.load(path)
44 | n_frames = arr.shape[0]
45 | random_range = randrange(n_frames)
46 | im = arr[random_range]
47 | image = np.array(im)
48 | image = PIL.Image.fromarray(image.astype('uint8'), 'RGB')
49 | text = self.descriptions[globalID]
50 | tokens = self.tokenizer.encode(text.lower())
51 | tokens = torch.LongTensor(tokens.ids)
52 | image = self.preprocess(image)
53 |
54 | return image, tokens
55 |
56 | def __len__(self):
57 | return len(self.orders)
58 |
59 |
60 | class StoryImageDataset(torch.utils.data.Dataset):
61 | def __init__(self, dir_path, tokenizer, transform=None, mode='train', im_input_size=128, out_img_folder='', return_labels=False):
62 | self.dir_path = dir_path
63 |
64 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r'))
65 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"]
66 |
67 | min_len = 4
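        # Build (or load) the "followings" cache: for each anchor clip, the ids of the next
        # `min_len` clips from the same season/episode, which together form a 5-frame story.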
68 | if os.path.exists(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl')):
69 | self.followings = pickle.load(open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'rb'))
70 | else:
 71 |             print("Cache does not exist")
    |             self.followings = {}  # initialize the cache dict; the loop below populates it
72 | all_clips = train_id + val_id + test_id
73 | all_clips.sort()
74 | for idx, clip in enumerate(tqdm(all_clips, desc="Counting total number of frames")):
75 | season, episode = int(clip.split('_')[1]), int(clip.split('_')[3])
76 | has_frames = True
77 | for c in all_clips[idx+1:idx+min_len+1]:
78 | s_c, e_c = int(c.split('_')[1]), int(c.split('_')[3])
79 | if s_c != season or e_c != episode:
80 | has_frames = False
81 | break
82 | if has_frames:
83 | self.followings[clip] = all_clips[idx+1:idx+min_len+1]
84 | else:
85 | continue
86 | pickle.dump(self.followings, open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'wb'))
87 |
88 | train_id = [tid for tid in train_id if tid in self.followings]
89 | val_id = [vid for vid in val_id if vid in self.followings]
90 | test_id = [tid for tid in test_id if tid in self.followings]
91 |
92 | self.labels = pickle.load(open(os.path.join(dir_path, 'labels.pkl'), 'rb'))
93 |
94 | if mode == 'train':
95 | self.orders = train_id
96 | elif mode =='val':
97 | val_id = [vid for vid in val_id if len(self.followings[vid]) == 4]
98 | self.orders = val_id
99 | elif mode == 'test':
100 | test_id = [vid for vid in test_id if len(self.followings[vid]) == 4]
101 | self.orders = test_id[:1900]
102 | else:
103 | raise ValueError
104 | print("Total number of clips {}".format(len(self.orders)))
105 |
106 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')))
107 | self.descriptions = {}
108 | for sample in annotations:
109 | self.descriptions[sample["globalID"]] = sample["description"]
110 |
111 | self.embeds = np.load(os.path.join(self.dir_path, "flintstones_use_embeddings.npy"))
112 | self.sent2idx = pickle.load(open(os.path.join(self.dir_path, 'flintstones_use_embed_idxs.pkl'), 'rb'))
113 |
114 | if mode == 'train':
115 | if transform:
116 | self.transform = transform
117 | else:
118 | self.transform = transforms.Compose([
119 | transforms.RandomResizedCrop(im_input_size),
120 | transforms.RandomHorizontalFlip(),
121 | transforms.ToTensor(),
122 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
123 | ])
124 | else:
125 | if transform:
126 | self.transform = transform
127 | else:
128 | self.transform = transforms.Compose([
129 | transforms.Resize(im_input_size),
130 | transforms.CenterCrop(im_input_size),
131 | transforms.ToTensor(),
132 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
133 | ])
134 |
135 | self.tokenizer = tokenizer
136 | self.return_labels = return_labels
137 | self.out_img_folder = out_img_folder
138 |
139 | def __getitem__(self, item):
140 |
 141 |         # story sample: the anchor clip (source frame) plus its 4 following clips
142 | globalIDs = [self.orders[item]] + self.followings[self.orders[item]]
143 | tokens = []
144 | images = []
145 | for idx, globalID in enumerate(globalIDs):
146 | if self.out_img_folder and idx != 0:
147 | image = Image.open(os.path.join(self.out_img_folder, 'gen_sample_%s_%s.png' % (item, idx-1))).convert('RGB')
148 | else:
149 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy')
150 | arr = np.load(path)
151 | n_frames = arr.shape[0]
152 | random_range = randrange(n_frames)
153 | im = arr[random_range]
154 | # image = np.array(im)
155 | image = PIL.Image.fromarray(im.astype('uint8'), 'RGB')
156 | images.append(image)
157 | text = self.descriptions[globalID]
158 | if idx != 0:
159 | if self.tokenizer is not None:
160 | tokens.append(self.tokenizer.encode(text.lower()))
161 | else:
162 | tokens.append(text)
163 | if self.tokenizer is not None:
164 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens])
165 |
166 | sent_embeds = [torch.tensor(self.embeds[self.sent2idx[globalID]]) for globalID in globalIDs[1:]]
167 |
168 | if self.return_labels:
169 | labels = [torch.tensor(self.labels[globalID]) for globalID in globalIDs[1:]]
170 | return torch.stack([self.transform(im) for im in images[1:]]), torch.stack(labels), tokens, self.transform(
171 | images[0]), torch.stack(sent_embeds)
172 | else:
173 | return torch.stack([self.transform(im) for im in images[1:]]), tokens, self.transform(images[0]), torch.stack(sent_embeds)
174 |
175 | def __len__(self):
176 | return len(self.orders)
177 |
178 |
179 | # if __name__ == "__main__":
180 | #
181 | # dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/flintstones', None, None, 'val')
182 | # for item in range(len(dataset)):
183 | # texts = []
184 | # globalIDs = [dataset.orders[item]] + dataset.followings[dataset.orders[item]]
185 | # for idx, globalID in enumerate(globalIDs):
186 | # text = dataset.descriptions[globalID]
187 | # texts.append(text)
188 | # if len(texts) != 5:
189 | # print(item, globalIDs)
190 |
191 |
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/__init__.py:
--------------------------------------------------------------------------------
1 | from .min_dalle import MinDalle, PipelineParallelMinDalle
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/__pycache__/min_dalle.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/min_dalle.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/__pycache__/text_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/__pycache__/text_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .dalle_bart_encoder import DalleBartEncoder
2 | from .dalle_bart_decoder import DalleBartDecoder
3 | from .vqgan_detokenizer import VQGanDetokenizer, VQGanTokenizer
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/dalle_bart_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/__pycache__/vqgan_detokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/mega-story-dalle/min_dalle/models/__pycache__/vqgan_detokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/dalle_bart_decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | import torch
3 | from torch import nn, LongTensor, FloatTensor, BoolTensor
4 | from .dalle_bart_encoder import GLU, AttentionBase
5 |
6 | IMAGE_TOKEN_COUNT = 256
7 |
8 |
9 | class DecoderCrossAttention(AttentionBase):
10 | def forward(
11 | self,
12 | decoder_state: FloatTensor,
13 | encoder_state: FloatTensor,
14 | attention_mask: BoolTensor
15 | ) -> FloatTensor:
16 | keys = self.k_proj.forward(encoder_state)
17 | values = self.v_proj.forward(encoder_state)
18 | queries = self.q_proj.forward(decoder_state)
19 | return super().forward(keys, values, queries, attention_mask)
20 |
21 |
22 | class DecoderSelfAttention(AttentionBase):
23 | def __init__(self, head_count: int, embed_count: int):
24 | super().__init__(head_count, embed_count)
25 |
26 | def forward(
27 | self,
28 | decoder_state: FloatTensor,
29 | attention_state: FloatTensor,
30 | attention_mask: BoolTensor,
31 | token_index: LongTensor = None
32 | ) -> Tuple[FloatTensor, FloatTensor]:
33 | keys = self.k_proj.forward(decoder_state)
34 | values = self.v_proj.forward(decoder_state)
35 | queries = self.q_proj.forward(decoder_state)
36 |
37 | # TODO: refer to the cache process in minDALLE to fix this
38 | if token_index is not None:
39 | token_count = token_index.shape[1]
40 | if token_count == 1:
41 | batch_count = decoder_state.shape[0]
42 | attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
43 | attention_state[:, token_index[0]] = attn_state_new
44 | keys = attention_state[:batch_count]
45 | values = attention_state[batch_count:]
46 |
47 | decoder_state = super().forward(keys, values, queries, attention_mask)
48 | return decoder_state, attention_state
49 |
50 |
51 | class DecoderLayer(nn.Module):
52 | def __init__(
53 | self,
54 | head_count: int,
55 | embed_count: int,
56 | glu_embed_count: int,
57 | device: str,
58 | condition: bool = False
59 | ):
60 | super().__init__()
61 | self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
62 | self.self_attn = DecoderSelfAttention(head_count, embed_count)
63 | self.self_attn_layer_norm = nn.LayerNorm(embed_count)
64 | self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
65 | self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
66 | self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
67 | self.glu = GLU(embed_count, glu_embed_count)
68 | self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
69 | self.head_count = head_count
70 |
71 | self.condition = condition
72 | if condition:
73 | self.pre_condition_attn_layer_norm = nn.LayerNorm(embed_count)
74 | self.condition_attn = DecoderCrossAttention(head_count, embed_count)
75 | self.condition_attn_layer_norm = nn.LayerNorm(embed_count)
76 |
77 | def sample(
78 | self,
79 | decoder_state: FloatTensor,
80 | encoder_state: FloatTensor,
81 | attention_state: FloatTensor,
82 | attention_mask: BoolTensor,
83 | token_index: LongTensor,
84 | condition_state: FloatTensor = None
85 | ) -> Tuple[FloatTensor, FloatTensor]:
86 | # Self Attention
87 | token_count = token_index.shape[1]
88 | if token_count == 1:
89 | # print(self.token_indices.device, token_index.device)
90 | self_attn_mask = self.token_indices <= token_index
91 | self_attn_mask = self_attn_mask[:, None, None, :]
92 | else:
93 | self_attn_mask = (
94 | self.token_indices[None, None, :token_count] <=
95 | token_index[:, :, None]
96 | )
97 | self_attn_mask = self_attn_mask[:, None, :, :]
98 |
99 | # TODO: Fix self-attention mask
100 |
101 | residual = decoder_state
102 | decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
103 | decoder_state, attention_state = self.self_attn.forward(
104 | decoder_state=decoder_state,
105 | attention_state=attention_state,
106 | attention_mask=self_attn_mask,
107 | token_index=token_index
108 | )
109 | decoder_state = self.self_attn_layer_norm.forward(decoder_state)
110 | decoder_state = residual + decoder_state
111 |
112 | # Cross Attention
113 | residual = decoder_state
114 | decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
115 | decoder_state = self.encoder_attn.forward(
116 | decoder_state=decoder_state,
117 | encoder_state=encoder_state,
118 | attention_mask=attention_mask
119 | )
120 | decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
121 | decoder_state = residual + decoder_state
122 |
123 | # Cross-Attention Over Image Condition
124 | if self.condition:
125 | assert condition_state is not None
126 | residual = decoder_state
127 | decoder_state = self.pre_condition_attn_layer_norm.forward(decoder_state)
128 | decoder_state = self.condition_attn.forward(
129 | decoder_state=decoder_state,
130 | encoder_state=encoder_state,
131 | attention_mask=attention_mask
132 | )
133 | decoder_state = self.condition_attn_layer_norm.forward(decoder_state)
134 | decoder_state = residual + decoder_state
135 |
136 | # Feed forward
137 | residual = decoder_state
138 | decoder_state = self.glu.forward(decoder_state)
139 | decoder_state = residual + decoder_state
140 |
141 | return decoder_state, attention_state
142 |
143 |
144 | def forward(
145 | self,
146 | decoder_state: FloatTensor,
147 | encoder_state: FloatTensor,
148 | attention_state: FloatTensor,
149 | attention_mask: BoolTensor,
150 | condition_state: FloatTensor = None
151 | ) -> Tuple[FloatTensor, FloatTensor]:
152 | # Self Attention
153 | # token_count = token_index.shape[1]
154 | # if token_count == 1:
155 | # self_attn_mask = self.token_indices <= token_index
156 | # self_attn_mask = self_attn_mask[:, None, None, :]
157 | # else:
158 | # self_attn_mask = (
159 | # self.token_indices[None, None, :token_count] <=
160 | # token_index[:, :, None]
161 | # )
162 | # self_attn_mask = self_attn_mask[:, None, :, :]
163 |
164 | # TODO: Fix self-attention mask
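        # Training-time forward pass over the full token sequence: build a causal
        # (lower-triangular) self-attention mask of shape (B, head_count, N, N) so that
        # each image token attends only to itself and to earlier tokens.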
165 | B, N = decoder_state.shape[:2]
166 | self_attn_mask = torch.tril(torch.ones(size=(N, N), device=decoder_state.device)).view(1, 1, N, N).repeat(B, self.head_count, 1, 1)
167 | # print("Self-attention mask shape: ", self_attn_mask.shape)
168 |
169 | residual = decoder_state
170 | decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
171 | decoder_state, attention_state = self.self_attn.forward(
172 | decoder_state=decoder_state,
173 | attention_state=attention_state,
174 | attention_mask=self_attn_mask,
175 | # token_index=token_index
176 | )
177 | decoder_state = self.self_attn_layer_norm.forward(decoder_state)
178 | decoder_state = residual + decoder_state
179 |
180 | # Cross Attention
181 | residual = decoder_state
182 | decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
183 | decoder_state = self.encoder_attn.forward(
184 | decoder_state=decoder_state,
185 | encoder_state=encoder_state,
186 | attention_mask=attention_mask
187 | )
188 | decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
189 | decoder_state = residual + decoder_state
190 |
191 | # Cross-Attention Over Image Condition
192 | if self.condition:
193 | assert condition_state is not None
194 | residual = decoder_state
195 | decoder_state = self.pre_condition_attn_layer_norm.forward(decoder_state)
196 | decoder_state = self.condition_attn.forward(
197 | decoder_state=decoder_state,
198 | encoder_state=encoder_state,
199 | attention_mask=attention_mask
200 | )
201 | decoder_state = self.condition_attn_layer_norm.forward(decoder_state)
202 | decoder_state = residual + decoder_state
203 |
204 | # Feed forward
205 | residual = decoder_state
206 | decoder_state = self.glu.forward(decoder_state)
207 | decoder_state = residual + decoder_state
208 |
209 | return decoder_state, attention_state
210 |
211 |
212 | class DalleBartDecoder(nn.Module):
213 | def __init__(
214 | self,
215 | image_vocab_count: int,
216 | embed_count: int,
217 | attention_head_count: int,
218 | glu_embed_count: int,
219 | layer_count: int,
220 | device: str,
221 | condition: bool = False
222 | ):
223 | super().__init__()
224 | self.layer_count = layer_count
225 | self.embed_count = embed_count
226 | self.image_vocab_count = image_vocab_count
227 | self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
228 | self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
229 | self.layers: List[DecoderLayer] = nn.ModuleList([
230 | DecoderLayer(
231 | head_count=attention_head_count,
232 | embed_count=embed_count,
233 | glu_embed_count=glu_embed_count,
234 | device=device,
235 | condition = (i+1)%3 == 0 if condition else False
236 | )
237 | for i in range(layer_count)
238 | ])
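        # When image conditioning is enabled, every third decoder layer ((i + 1) % 3 == 0)
        # carries an extra cross-attention block for the source-image condition tokens.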
239 | self.condition = condition
240 | self.layernorm_embedding = nn.LayerNorm(embed_count)
241 | self.final_ln = nn.LayerNorm(embed_count)
242 | self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
243 | self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
244 |
245 | if self.condition:
246 | print("Initialized %s condition attention layers" % sum([(i+1)%3 == 0 for i in range(layer_count)]))
247 |
248 |
249 | def forward(
250 | self,
251 | attention_mask: BoolTensor,
252 | encoder_state: FloatTensor,
253 | attention_state: FloatTensor,
254 | prev_tokens: LongTensor,
255 | condition_state: FloatTensor = None
256 | ) -> Tuple[FloatTensor, FloatTensor]:
257 | decoder_state = self.embed_tokens.forward(prev_tokens)
258 | B, N = prev_tokens.shape
259 | pos_enc_tokens = torch.arange(N, device=prev_tokens.device).repeat((B, 1))
260 | decoder_state += self.embed_positions.forward(pos_enc_tokens)
261 | decoder_state = self.layernorm_embedding.forward(decoder_state)
262 |
263 | if condition_state is not None:
264 | condition_state = self.embed_tokens.forward(condition_state)
265 | B_c, N_c = condition_state.shape[:2]
266 | pos_enc_tokens = torch.arange(N_c, device=condition_state.device).repeat((B_c, 1))
267 | # print(condition_state.shape, pos_enc_tokens.shape)
268 | condition_state += self.embed_positions.forward(pos_enc_tokens)
269 | # print(condition_state.shape)
270 | condition_state = condition_state.repeat_interleave(int(B/B_c), dim=0)
271 | # print(condition_state.shape)
272 |
273 | for i in range(self.layer_count):
274 | decoder_state, attention_state[i] = self.layers[i].forward(
275 | decoder_state,
276 | encoder_state,
277 | attention_state[i],
278 | attention_mask,
279 | condition_state=condition_state if self.condition and (i+1)%3 == 0 else None
280 | )
281 | decoder_state = self.final_ln(decoder_state)
282 | logits = self.lm_head(decoder_state)
283 | return logits, attention_state
284 |
285 | def sample(
286 | self,
287 | attention_mask: BoolTensor,
288 | encoder_state: FloatTensor,
289 | attention_state: FloatTensor,
290 | prev_tokens: LongTensor,
291 | token_index: LongTensor,
292 | condition_state: FloatTensor = None,
293 | supercondition: bool = False
294 | ) -> Tuple[FloatTensor, FloatTensor]:
295 | image_count = encoder_state.shape[0] // 2
296 | token_index = token_index.unsqueeze(0).repeat(image_count * 2, 1)
297 | if supercondition:
298 | prev_tokens = prev_tokens.repeat(2, 1)
299 | decoder_state = self.embed_tokens.forward(prev_tokens)
300 | decoder_state += self.embed_positions.forward(token_index)
301 | decoder_state = self.layernorm_embedding.forward(decoder_state)
302 | for i in range(self.layer_count):
303 | decoder_state, attention_state[i] = self.layers[i].sample(
304 | decoder_state,
305 | encoder_state,
306 | attention_state[i],
307 | attention_mask,
308 | token_index,
309 | condition_state=condition_state if self.condition and (i + 1) % 3 == 0 else None
310 | )
311 | decoder_state = self.final_ln(decoder_state)
312 | logits = self.lm_head(decoder_state)
313 | return logits, attention_state
314 |
315 |
316 | def sample_tokens(self, settings, **kwargs) -> Tuple[LongTensor, FloatTensor]:
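        # settings = (temperature, top_k, supercondition_factor). With superconditioning,
        # logits are computed for two copies of the batch (with and without the text prompt)
        # and mixed with weights (1 - factor) and factor; the mixed logits are then restricted
        # to the image vocabulary, top-k filtered, temperature-scaled, and sampled.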
317 | logits, attention_state = self.sample(supercondition=settings[2] != 1, **kwargs)
318 | image_count = logits.shape[0] // 2
319 | temperature = settings[[0]]
320 | top_k = settings[[1]].to(torch.long)
321 | supercondition_factor = settings[[2]]
322 |
323 | logits = logits[:, -1, : 2 ** 14]
324 | if supercondition_factor != 1:
325 | logits: FloatTensor = (
326 | logits[:image_count] * (1 - supercondition_factor) +
327 | logits[image_count:] * supercondition_factor
328 | )
329 | else:
330 | # logits: FloatTensor = (
331 | # logits[:image_count] * 0 +
332 | # logits[image_count:] * 1
333 | # )
334 | # print(logits.shape)
335 | pass
336 |
337 | # print(logits.shape)
338 | logits_sorted, _ = logits.sort(descending=True)
339 | # print(logits_sorted.shape)
340 | is_kept = logits >= logits_sorted[:, top_k - 1]
341 | if len(is_kept.shape) == 3:
342 | is_kept = logits >= logits_sorted[:, [top_k - 1]]
343 | assert len(is_kept.shape) == 2
344 | # print(logits_sorted[:, [0]])
345 | logits -= logits_sorted[:, [0]]
346 | # print(logits.shape)
347 | logits /= temperature
348 | # print(logits.shape, temperature)
349 | logits.exp_()
350 | logits *= is_kept.to(torch.float32)
351 | image_tokens = torch.multinomial(logits, 1)[:, 0]
352 | return image_tokens, attention_state
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/dalle_bart_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | from torch import nn, BoolTensor, FloatTensor, LongTensor
4 |
5 |
6 | class GLU(nn.Module):
7 | def __init__(self, count_in_out: int, count_middle: int):
8 | super().__init__()
9 | self.gelu = nn.GELU()
10 | self.ln0 = nn.LayerNorm(count_in_out)
11 | self.ln1 = nn.LayerNorm(count_middle)
12 | self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)
13 | self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)
14 | self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)
15 |
16 | def forward(self, z: FloatTensor) -> FloatTensor:
17 | z = self.ln0.forward(z)
18 | w = self.fc0.forward(z)
19 | w = self.gelu.forward(w)
20 | v = self.fc1.forward(z)
21 | z = self.ln1.forward(w * v)
22 | z = self.fc2.forward(z)
23 | return z
24 |
25 |
26 | class AttentionBase(nn.Module):
27 | def __init__(self, head_count: int, embed_count: int):
28 | super().__init__()
29 | self.head_count = head_count
30 | self.embed_count = embed_count
31 |
32 | self.k_proj = nn.Linear(embed_count, embed_count, bias=False)
33 | self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
34 | self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
35 | self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
36 |
37 | def forward(
38 | self,
39 | keys: FloatTensor,
40 | values: FloatTensor,
41 | queries: FloatTensor,
42 | attention_mask: BoolTensor
43 | ) -> FloatTensor:
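        # Split keys/values/queries into heads, scale the queries by 1/sqrt(head_dim),
        # and turn the boolean attention mask into a large negative additive bias so that
        # masked positions receive (near-)zero weight after the softmax.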
44 | keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
45 | values = values.reshape(values.shape[:2] + (self.head_count, -1))
46 | queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
47 | queries /= queries.shape[-1] ** 0.5
48 | attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
49 | attention_weights: FloatTensor = torch.einsum(
50 | 'bqhc,bkhc->bhqk',
51 | queries,
52 | keys
53 | )
54 |
55 | # print(keys.shape, values.shape, queries.shape)
56 | # print("Attention weights shape: ", attention_weights.shape)
57 | # print("Attention bias shape: ", attention_bias.shape)
58 |
59 | attention_weights += attention_bias
60 | attention_weights = torch.softmax(attention_weights, -1)
61 | attention_output: FloatTensor = torch.einsum(
62 | "bhqk,bkhc->bqhc",
63 | attention_weights,
64 | values
65 | )
66 | shape = attention_output.shape[:2] + (self.embed_count,)
67 | attention_output = attention_output.reshape(shape)
68 | attention_output = self.out_proj.forward(attention_output)
69 | return attention_output
70 |
71 |
72 | class EncoderSelfAttention(AttentionBase):
73 | def forward(
74 | self,
75 | encoder_state: FloatTensor,
76 | attention_mask: BoolTensor
77 | ) -> FloatTensor:
78 | keys = self.k_proj.forward(encoder_state)
79 | values = self.v_proj.forward(encoder_state)
80 | queries = self.q_proj.forward(encoder_state)
81 | return super().forward(keys, values, queries, attention_mask)
82 |
83 |
84 | class EncoderLayer(nn.Module):
85 | def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
86 | super().__init__()
87 | self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
88 | self.self_attn = EncoderSelfAttention(head_count, embed_count)
89 | self.self_attn_layer_norm = nn.LayerNorm(embed_count)
90 | self.glu = GLU(embed_count, glu_embed_count)
91 |
92 | def forward(
93 | self,
94 | encoder_state: FloatTensor,
95 | attention_mask: BoolTensor
96 | ) -> FloatTensor:
97 | residual = encoder_state
98 | encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state)
99 | encoder_state = self.self_attn.forward(encoder_state, attention_mask)
100 | encoder_state = self.self_attn_layer_norm.forward(encoder_state)
101 | encoder_state = residual + encoder_state
102 | residual = encoder_state
103 | encoder_state = self.glu.forward(encoder_state)
104 | encoder_state = residual + encoder_state
105 | return encoder_state
106 |
107 |
108 | class DalleBartEncoder(nn.Module):
109 | def __init__(
110 | self,
111 | layer_count: int,
112 | embed_count: int,
113 | attention_head_count: int,
114 | text_vocab_count: int,
115 | text_token_count: int,
116 | glu_embed_count: int,
117 | device: str
118 | ):
119 | super().__init__()
120 | self.text_vocab_count = text_vocab_count
121 | self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
122 | self.embed_positions = nn.Embedding(text_token_count, embed_count)
123 | self.layers: List[EncoderLayer] = nn.ModuleList([
124 | EncoderLayer(
125 | embed_count = embed_count,
126 | head_count = attention_head_count,
127 | glu_embed_count = glu_embed_count
128 | )
129 | for _ in range(layer_count)
130 | ])
131 | self.layernorm_embedding = nn.LayerNorm(embed_count)
132 | self.final_ln = nn.LayerNorm(embed_count)
133 | token_indices = torch.arange(text_token_count, device=device)
134 | self.pose_tokens = torch.stack([token_indices] * 2)
135 |
136 |
137 | def _init_weights(self, module: nn.Module) -> None:
138 | if isinstance(module, (nn.Linear, nn.Embedding)):
139 | module.weight.data.normal_(mean=0.0, std=0.02)
140 | if isinstance(module, nn.Linear) and module.bias is not None:
141 | module.bias.data.zero_()
142 | elif isinstance(module, nn.LayerNorm):
143 | module.bias.data.zero_()
144 | module.weight.data.fill_(1.0)
145 |
146 |
147 | def resize_token_embeddings(self, new_num_tokens):
148 |
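        # Replace the text-token embedding table with one of the requested size,
        # copying the first min(old, new) rows from the existing table.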
149 | old_num_tokens, old_embedding_dim = self.embed_tokens.weight.size()
150 | new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
151 | new_embeddings.to(self.embed_tokens.weight.device, dtype=self.embed_tokens.weight.dtype)
152 | self._init_weights(new_embeddings)
153 | # numbers of tokens to copy
154 | n = min(old_num_tokens, new_num_tokens)
155 | new_embeddings.weight.data[:n, :] = self.embed_tokens.weight.data[:n, :]
156 | self.embed_tokens = new_embeddings
157 |
158 | return new_embeddings
159 |
160 |
161 | def forward(self, text_tokens: LongTensor) -> FloatTensor:
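        # Token id 1 is treated as padding and masked out of attention; position embeddings
        # are built for the actual sequence length of the batch rather than the fixed-size
        # pose_tokens buffer (see the commented-out lines below).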
162 | attention_mask = text_tokens.not_equal(1)[:, None, None, :]
163 | # encoder_state = (
164 | # self.embed_tokens.forward(text_tokens) +
165 | # self.embed_positions.forward(self.pose_tokens)
166 | # )
167 | # print(text_tokens.shape)
168 | B, L = text_tokens.shape
169 | encoder_state = (
170 | self.embed_tokens.forward(text_tokens) +
171 | self.embed_positions.forward(torch.stack([torch.arange(L, device=text_tokens.device)]*B))
172 | )
173 | # print(encoder_state.requires_grad, encoder_state.grad_fn)
174 | encoder_state = self.layernorm_embedding.forward(encoder_state)
175 | # print(encoder_state.requires_grad, encoder_state.grad_fn)
176 | for layer in self.layers:
177 | encoder_state = layer.forward(encoder_state, attention_mask)
178 | # print(encoder_state.requires_grad, encoder_state.grad_fn)
179 | encoder_state = self.final_ln.forward(encoder_state)
180 | # print(encoder_state.requires_grad, encoder_state.grad_fn)
181 | return encoder_state
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/models/vqgan_detokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import FloatTensor, LongTensor
4 | from math import sqrt
5 | from einops import rearrange
6 |
7 | class ResnetBlock(nn.Module):
8 | def __init__(self, log2_count_in: int, log2_count_out: int):
9 | super().__init__()
10 | m, n = 2 ** log2_count_in, 2 ** log2_count_out
11 | self.is_middle = m == n
12 | self.norm1 = nn.GroupNorm(2 ** 5, m)
13 | self.conv1 = nn.Conv2d(m, n, 3, padding=1)
14 | self.norm2 = nn.GroupNorm(2 ** 5, n)
15 | self.conv2 = nn.Conv2d(n, n, 3, padding=1)
16 | if not self.is_middle:
17 | self.nin_shortcut = nn.Conv2d(m, n, 1)
18 |
19 | def forward(self, x: FloatTensor) -> FloatTensor:
20 | h = x
21 | h = self.norm1.forward(h)
22 | h *= torch.sigmoid(h)
23 | h = self.conv1.forward(h)
24 | h = self.norm2.forward(h)
25 | h *= torch.sigmoid(h)
26 | h = self.conv2(h)
27 | if not self.is_middle:
28 | x = self.nin_shortcut.forward(x)
29 | return x + h
30 |
31 |
32 | class AttentionBlock(nn.Module):
33 | def __init__(self):
34 | super().__init__()
35 | n = 2 ** 9
36 | self.norm = nn.GroupNorm(2 ** 5, n)
37 | self.q = nn.Conv2d(n, n, 1)
38 | self.k = nn.Conv2d(n, n, 1)
39 | self.v = nn.Conv2d(n, n, 1)
40 | self.proj_out = nn.Conv2d(n, n, 1)
41 |
42 | def forward(self, x: FloatTensor) -> FloatTensor:
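        # Single-head self-attention over spatial positions: 1x1 convolutions produce q/k/v,
        # attention is computed over the flattened HxW grid, and the output is added back
        # to the input as a residual.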
43 | n, m = 2 ** 9, x.shape[0]
44 | h = x
45 | h = self.norm(h)
46 | k = self.k.forward(h)
47 | v = self.v.forward(h)
48 | q = self.q.forward(h)
49 | k = k.reshape(m, n, -1)
50 | v = v.reshape(m, n, -1)
51 | q = q.reshape(m, n, -1)
52 | q = q.permute(0, 2, 1)
53 | w = torch.bmm(q, k)
54 | w /= n ** 0.5
55 | w = torch.softmax(w, dim=2)
56 | w = w.permute(0, 2, 1)
57 | h = torch.bmm(v, w)
58 | token_count = int(sqrt(h.shape[-1]))
59 | h = h.reshape(m, n, token_count, token_count)
60 | h = self.proj_out.forward(h)
61 | return x + h
62 |
63 |
64 | class MiddleLayer(nn.Module):
65 | def __init__(self):
66 | super().__init__()
67 | self.block_1 = ResnetBlock(9, 9)
68 | self.attn_1 = AttentionBlock()
69 | self.block_2 = ResnetBlock(9, 9)
70 |
71 | def forward(self, h: FloatTensor) -> FloatTensor:
72 | h = self.block_1.forward(h)
73 | h = self.attn_1.forward(h)
74 | h = self.block_2.forward(h)
75 | return h
76 |
77 |
78 | class Upsample(nn.Module):
79 | def __init__(self, log2_count):
80 | super().__init__()
81 | n = 2 ** log2_count
82 | self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
83 | self.conv = nn.Conv2d(n, n, 3, padding=1)
84 |
85 | def forward(self, x: FloatTensor) -> FloatTensor:
86 | x = self.upsample.forward(x.to(torch.float32))
87 | x = self.conv.forward(x)
88 | return x
89 |
90 |
91 | class Downsample(nn.Module):
92 | def __init__(self, log2_count):
93 | super().__init__()
94 | n = 2 ** log2_count
95 | self.conv = torch.nn.Conv2d(n, n, kernel_size=3, stride=2, padding=0)
96 |
97 | def forward(self, x):
98 | pad = (0,1,0,1)
99 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
100 | x = self.conv(x)
101 | return x
102 |
103 |
104 | class DownsampleBlock(nn.Module):
105 | def __init__(
106 | self,
107 | log2_count_in: int,
108 | log2_count_out: int,
109 | has_attention: bool,
110 | has_downsample: bool
111 | ):
112 | super().__init__()
113 | self.has_attention = has_attention
114 | self.has_downsample = has_downsample
115 |
116 | self.block = nn.ModuleList([
117 | ResnetBlock(log2_count_in, log2_count_out),
118 | ResnetBlock(log2_count_out, log2_count_out),
119 | ])
120 |
121 | if has_attention:
122 | self.attn = nn.ModuleList([
123 | AttentionBlock(),
124 | AttentionBlock(),
125 | ])
126 |
127 | if has_downsample:
128 | self.downsample = Downsample(log2_count_out)
129 |
130 | def forward(self, h: FloatTensor) -> FloatTensor:
131 | for j in range(2):
132 | h = self.block[j].forward(h)
133 | if self.has_attention:
134 | h = self.attn[j].forward(h)
135 | if self.has_downsample:
136 | h = self.downsample.forward(h)
137 | return h
138 |
139 |
140 |
141 | class UpsampleBlock(nn.Module):
142 | def __init__(
143 | self,
144 | log2_count_in: int,
145 | log2_count_out: int,
146 | has_attention: bool,
147 | has_upsample: bool
148 | ):
149 | super().__init__()
150 | self.has_attention = has_attention
151 | self.has_upsample = has_upsample
152 |
153 | self.block = nn.ModuleList([
154 | ResnetBlock(log2_count_in, log2_count_out),
155 | ResnetBlock(log2_count_out, log2_count_out),
156 | ResnetBlock(log2_count_out, log2_count_out),
157 | ])
158 |
159 | if has_attention:
160 | self.attn = nn.ModuleList([
161 | AttentionBlock(),
162 | AttentionBlock(),
163 | AttentionBlock()
164 | ])
165 |
166 | if has_upsample:
167 | self.upsample = Upsample(log2_count_out)
168 |
169 |
170 | def forward(self, h: FloatTensor) -> FloatTensor:
171 | for j in range(3):
172 | h = self.block[j].forward(h)
173 | if self.has_attention:
174 | h = self.attn[j].forward(h)
175 | if self.has_upsample:
176 | h = self.upsample.forward(h)
177 | return h
178 |
179 |
180 | class Encoder(nn.Module):
181 | def __init__(self):
182 | super().__init__()
183 |
184 | # downsampling
185 | self.conv_in = torch.nn.Conv2d(3, 2 ** 7, 3, stride=1, padding=1)
186 |
187 | self.down = nn.ModuleList([
188 | DownsampleBlock(7, 7, False, True),
189 | DownsampleBlock(7, 7, False, True),
190 | DownsampleBlock(7, 8, False, True),
191 | DownsampleBlock(8, 8, False, True),
192 | DownsampleBlock(8, 9, True, False),
193 | ])
194 |
195 | # middle
196 | self.mid = MiddleLayer()
197 |
198 | # end
199 | self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 9)
200 | self.conv_out = nn.Conv2d(2 ** 9, 2 ** 8, 3, padding=1)
201 |
202 | def forward(self, z: FloatTensor) -> FloatTensor:
203 | z = self.conv_in.forward(z)
204 | for i in range(5):
205 | z = self.down[i].forward(z)
206 | z = self.mid.forward(z)
207 |
208 | z = self.norm_out.forward(z)
209 | z *= torch.sigmoid(z)
210 | z = self.conv_out.forward(z)
211 | return z
212 |
213 |
214 | class Decoder(nn.Module):
215 | def __init__(self):
216 | super().__init__()
217 |
218 | self.conv_in = nn.Conv2d(2 ** 8, 2 ** 9, 3, padding=1)
219 | self.mid = MiddleLayer()
220 |
221 | self.up = nn.ModuleList([
222 | UpsampleBlock(7, 7, False, False),
223 | UpsampleBlock(8, 7, False, True),
224 | UpsampleBlock(8, 8, False, True),
225 | UpsampleBlock(9, 8, False, True),
226 | UpsampleBlock(9, 9, True, True)
227 | ])
228 |
229 | self.norm_out = nn.GroupNorm(2 ** 5, 2 ** 7)
230 | self.conv_out = nn.Conv2d(2 ** 7, 3, 3, padding=1)
231 |
232 | def forward(self, z: FloatTensor) -> FloatTensor:
233 | z = self.conv_in.forward(z)
234 | z = self.mid.forward(z)
235 |
236 | for i in reversed(range(5)):
237 | z = self.up[i].forward(z)
238 |
239 | z = self.norm_out.forward(z)
240 | z *= torch.sigmoid(z)
241 | z = self.conv_out.forward(z)
242 | return z
243 |
244 | class VectorQuantizer2(nn.Module):
245 | """
246 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
247 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
248 | """
249 | # NOTE: due to a bug the beta term was applied to the wrong term. for
250 | # backwards compatibility we use the buggy version by default, but you can
251 | # specify legacy=False to fix it.
252 | def __init__(self, n_e=16384, e_dim=256, beta=0.25,
253 | sane_index_shape=False, legacy=True):
254 | super().__init__()
255 | self.n_e = n_e
256 | self.e_dim = e_dim
257 | self.beta = beta
258 | self.legacy = legacy
259 |
260 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
261 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
262 |
263 | self.re_embed = n_e
264 | self.remap = None
265 |
266 | self.sane_index_shape = sane_index_shape
267 |
268 | def remap_to_used(self, inds):
269 | ishape = inds.shape
270 | assert len(ishape)>1
271 | inds = inds.reshape(ishape[0],-1)
272 | used = self.used.to(inds)
273 | match = (inds[:,:,None]==used[None,None,...]).long()
274 | new = match.argmax(-1)
275 | unknown = match.sum(2)<1
276 | if self.unknown_index == "random":
277 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
278 | else:
279 | new[unknown] = self.unknown_index
280 | return new.reshape(ishape)
281 |
282 | def unmap_to_all(self, inds):
283 | ishape = inds.shape
284 | assert len(ishape)>1
285 | inds = inds.reshape(ishape[0],-1)
286 | used = self.used.to(inds)
287 | if self.re_embed > self.used.shape[0]: # extra token
288 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
289 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
290 | return back.reshape(ishape)
291 |
292 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
293 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
294 | assert rescale_logits==False, "Only for interface compatible with Gumbel"
295 | assert return_logits==False, "Only for interface compatible with Gumbel"
296 | # reshape z -> (batch, height, width, channel) and flatten
297 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
298 | z_flattened = z.view(-1, self.e_dim)
299 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
300 |
301 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
302 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
303 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
304 |
305 | min_encoding_indices = torch.argmin(d, dim=1)
306 | z_q = self.embedding(min_encoding_indices).view(z.shape)
307 | perplexity = None
308 | min_encodings = None
309 |
310 | # compute loss for embedding
311 | if not self.legacy:
312 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
313 | torch.mean((z_q - z.detach()) ** 2)
314 | else:
315 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
316 | torch.mean((z_q - z.detach()) ** 2)
317 |
318 | # preserve gradients
319 | z_q = z + (z_q - z).detach()
320 |
321 | # reshape back to match original input shape
322 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
323 |
324 | if self.remap is not None:
325 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
326 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
327 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
328 |
329 | if self.sane_index_shape:
330 | min_encoding_indices = min_encoding_indices.reshape(
331 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
332 |
333 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
334 |
335 | def get_codebook_entry(self, indices, shape):
336 | # shape specifying (batch, height, width, channel)
337 | if self.remap is not None:
338 | indices = indices.reshape(shape[0],-1) # add batch axis
339 | indices = self.unmap_to_all(indices)
340 | indices = indices.reshape(-1) # flatten again
341 |
342 | # get quantized latent vectors
343 | z_q = self.embedding(indices)
344 |
345 | if shape is not None:
346 | z_q = z_q.view(shape)
347 | # reshape back to match original input shape
348 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
349 |
350 | return z_q
351 |
352 |
353 | class VQGanDetokenizer(nn.Module):
354 | def __init__(self):
355 | super().__init__()
356 | vocab_count, embed_count = 2 ** 14, 2 ** 8
357 | self.vocab_count = vocab_count
358 | self.embedding = nn.Embedding(vocab_count, embed_count)
359 | self.post_quant_conv = nn.Conv2d(embed_count, embed_count, 1)
360 | self.decoder = Decoder()
361 |
362 | def forward(self, is_seamless: bool, z: LongTensor, grid: bool = False) -> FloatTensor:
363 | image_count = z.shape[0]
364 | grid_size = int(sqrt(z.shape[0]))
365 | token_count = grid_size * 2 ** 4
366 |
367 | if is_seamless:
368 | z = z.view([grid_size, grid_size, 2 ** 4, 2 ** 4])
369 | z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
370 | z = z.flatten().unsqueeze(1)
371 | z = self.embedding.forward(z)
372 | z = z.view((1, token_count, token_count, 2 ** 8))
373 | else:
374 | z = self.embedding.forward(z)
375 | z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
376 |
377 | z = z.permute(0, 3, 1, 2).contiguous()
378 | z = self.post_quant_conv.forward(z)
379 | z = self.decoder.forward(z)
380 | z = z.permute(0, 2, 3, 1)
381 | z = z.clip(0.0, 1.0) * 255
382 |
383 | if is_seamless:
384 | z = z[0]
385 | else:
386 | if grid:
387 | z = z.view([grid_size, grid_size, 2 ** 8, 2 ** 8, 3])
388 | z = z.flatten(1, 2).transpose(1, 0).flatten(1, 2)
389 | else:
390 | z = z.view([image_count, 2 ** 8, 2 ** 8, 3])
391 |
392 | return z
393 |
394 |
395 | class VQGanTokenizer(nn.Module):
396 | def __init__(self):
397 | super().__init__()
398 | vocab_count, embed_count = 2 ** 14, 2 ** 8
399 | self.vocab_count = vocab_count
400 | self.quant_conv = nn.Conv2d(embed_count, embed_count, 1)
401 | self.encoder = Encoder()
402 | self.quantize = VectorQuantizer2(sane_index_shape=True)
403 |
404 | def forward(self, x: FloatTensor):
405 |
406 | h = self.encoder(x)
407 | # print('VQGan encoder output shape', h.shape)
408 | h = self.quant_conv(h)
409 | # print("Post-quant shape", h.shape)
410 | quant, emb_loss, info = self.quantize(h)
411 | # print(quant.shape, info[-1].shape)
412 | return quant, emb_loss, info
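413 |
414 | # --- Illustrative sketch (not part of the original file) ---------------------
415 | # A minimal round trip through the tokenizer/detokenizer pair, assuming a batch
416 | # of RGB images scaled to [0, 1] at 256x256 resolution:
417 | #
418 | #   tokenizer = VQGanTokenizer()
419 | #   detokenizer = VQGanDetokenizer()
420 | #   images = torch.rand(1, 3, 256, 256)
421 | #   quant, emb_loss, (_, _, codes) = tokenizer(images)   # codes: (1, 16, 16)
422 | #   pixels = detokenizer(False, codes.view(1, -1))       # (1, 256, 256, 3), values in [0, 255]
423 | # ------------------------------------------------------------------------------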
--------------------------------------------------------------------------------
/mega-story-dalle/min_dalle/text_tokenizer.py:
--------------------------------------------------------------------------------
1 | from math import inf
2 | from typing import List, Tuple
3 | from emoji import demojize
4 |
5 | class TextTokenizer:
6 | def __init__(self, vocab: dict, merges: List[str]):
7 | self.token_from_subword = vocab
8 | pairs = [tuple(pair.split()) for pair in merges]
9 | self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
10 | self.new_tokens = []
11 |
12 | def add_tokens(self, new_tokens):
13 |
14 | self.new_tokens = new_tokens
15 | original_length = len(self.token_from_subword)
16 | for t in new_tokens:
17 | self.token_from_subword[t] = len(self.token_from_subword)
18 | print("Increased vocabulary from %s to %s" % (original_length, len(self.token_from_subword)))
19 |
20 | def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
21 | sep_token = self.token_from_subword['</s>']
22 | cls_token = self.token_from_subword['<s>']
23 | unk_token = self.token_from_subword['<unk>']
24 | text = demojize(text, delimiters=['', ''])
25 | text = text.lower().encode("ascii", errors="ignore").decode()
26 |
27 | # tokens = [
28 | # self.token_from_subword.get(subword, unk_token)
29 | # for word in text.split(" ") if len(word) > 0
30 | # for subword in self.get_byte_pair_encoding(word, is_verbose)
31 | # ]
32 | # print([subword for word in text.split(" ") if len(word) > 0
33 | # for subword in self.get_byte_pair_encoding(word, is_verbose)])
34 |
35 | sub_words = []
36 | tokens = []
37 | for word in text.split(" "):
38 | if len(word) > 0:
39 | if word in self.new_tokens:
40 | tokens.append(self.token_from_subword[word])
41 | sub_words.append(word)
42 | else:
43 | for subword in self.get_byte_pair_encoding(word, is_verbose):
44 | tokens.append(self.token_from_subword.get(subword, unk_token))
45 | sub_words.append(subword)
46 | # print(sub_words)
47 |
48 | return [cls_token] + tokens + [sep_token]
49 |
50 | def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
51 | def get_pair_rank(pair: Tuple[str, str]) -> int:
52 | return self.rank_from_pair.get(pair, inf)
53 |
54 | subwords = [chr(ord(" ") + 256)] + list(word)
55 | while len(subwords) > 1:
56 | pairs = list(zip(subwords[:-1], subwords[1:]))
57 | pair_to_merge = min(pairs, key=get_pair_rank)
58 | if pair_to_merge not in self.rank_from_pair: break
59 | i = pairs.index(pair_to_merge)
60 | subwords = (
61 | (subwords[:i] if i > 0 else []) +
62 | [subwords[i] + subwords[i + 1]] +
63 | (subwords[i + 2:] if i + 2 < len(subwords) else [])
64 | )
65 |
66 | if is_verbose: print(subwords)
67 | return subwords
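68 |
69 | # --- Illustrative sketch (not part of the original file) ----------------------
70 | # Assuming `vocab` and `merges` are loaded from the pretrained checkpoint files,
71 | # usage might look like:
72 | #
73 | #   tokenizer = TextTokenizer(vocab, merges)
74 | #   token_ids = tokenizer.tokenize("a cat riding a bicycle")
75 | #
76 | # Each word is prefixed with chr(ord(" ") + 256) as a word-boundary marker and then
77 | # greedily merged using the lowest-ranked adjacent pair until no known merge remains.
78 | # -------------------------------------------------------------------------------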
--------------------------------------------------------------------------------
/mega-story-dalle/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | # from pathlib import Path
3 |
4 | setuptools.setup(
5 | name='min-dalle',
6 | description = 'min(DALL·E)',
7 | # long_description=(Path(__file__).parent / "README.rst").read_text(),
8 | version='0.4.11',
9 | author='Brett Kuprel',
10 | author_email='brkuprel@gmail.com',
11 | url='https://github.com/kuprel/min-dalle',
12 | packages=[
13 | 'min_dalle',
14 | 'min_dalle.models'
15 | ],
16 | license='MIT',
17 | install_requires=[
18 | 'torch>=1.11',
19 | 'typing_extensions>=4.1',
20 | 'numpy>=1.21',
21 | 'pillow>=7.1',
22 | 'requests>=2.23',
23 | 'emoji'
24 | ],
25 | keywords = [
26 | 'artificial intelligence',
27 | 'deep learning',
28 | 'text-to-image',
29 | 'pytorch'
30 | ]
31 | )
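32 |
33 | # Illustrative note (not part of the original file): with this setup.py the package
34 | # can be installed in editable mode from this directory, e.g.
35 | #   python -m pip install -e .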
--------------------------------------------------------------------------------
/mega-story-dalle/tkinter_ui.py:
--------------------------------------------------------------------------------
1 | from min_dalle import MinDalle
2 | import sys
3 | import PIL
4 | import PIL.Image
5 | import PIL.ImageTk
6 | import tkinter
7 | from tkinter import ttk
8 |
9 | def regen_root():
10 | global root
11 | global blank_image
12 | global padding_image
13 |
14 | root = tkinter.Tk()
15 | root.wm_resizable(False, False)
16 |
17 | blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 2, 256 * 2), mode="RGB"))
18 | padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA"))
19 |
20 | regen_root()
21 |
22 | is_mega = None
23 | def set_mega_true_and_destroy():
24 | global is_mega
25 | is_mega = True
26 | root.destroy()
27 | def set_mega_false_and_destroy():
28 | global is_mega
29 | is_mega = False
30 | root.destroy()
31 |
32 | frm = ttk.Frame(root, padding=16)
33 | frm.grid()
34 | ttk.Button(frm, text="Mega", command=set_mega_true_and_destroy).grid(column=0, row=0)
35 | ttk.Label(frm, image=padding_image).grid(column=1, row=0)
36 | ttk.Button(frm, text="Mini", command=set_mega_false_and_destroy).grid(column=2, row=0)
37 | root.mainloop()
38 |
39 | if is_mega is None:
40 | print("no option selected")
41 | sys.exit(0)
42 |
43 | print("is_mega", is_mega)
44 |
45 | model = MinDalle(
46 | models_root="./pretrained",
47 | is_mega=is_mega,
48 | is_reusable=True,
49 | is_verbose=True
50 | )
51 |
52 | regen_root()
53 |
54 | label_image_content = blank_image
55 |
56 | sv_prompt = tkinter.StringVar(value="artificial intelligence")
57 | sv_temperature = tkinter.StringVar(value="1")
58 | sv_topk = tkinter.StringVar(value="128")
59 | sv_supercond = tkinter.StringVar(value="16")
60 | bv_seamless = tkinter.BooleanVar(value=False)
61 |
62 | def generate():
63 | # check fields
64 | try:
65 | temperature = float(sv_temperature.get())
66 | except ValueError:
67 | sv_temperature.set("ERROR")
68 | return
69 | try:
70 | topk = int(sv_topk.get())
71 | except ValueError:
72 | sv_topk.set("ERROR")
73 | return
74 | try:
75 | supercond = int(sv_supercond.get())
76 | except:
77 | sv_supercond.set("ERROR")
78 | return
79 | try:
80 | is_seamless = bool(bv_seamless.get())
81 | except:
82 | return
83 | # and continue
84 | global label_image_content
85 | image_stream = model.generate_image_stream(
86 | sv_prompt.get(),
87 | grid_size=2,
88 | seed=-1,
89 | progressive_outputs=True,
90 | is_seamless=is_seamless,
91 | temperature=temperature,
92 | top_k=topk,
93 | supercondition_factor=supercond,
94 | is_verbose=True
95 | )
96 | for image in image_stream:
97 | global final_image
98 | final_image = image
99 | label_image_content = PIL.ImageTk.PhotoImage(image)
100 | label_image.configure(image=label_image_content)
101 | label_image.update()
102 |
103 | def save():
104 | final_image.save('generated/out.png')
105 |
106 | frm = ttk.Frame(root, padding=16)
107 | frm.grid()
108 |
109 | props = ttk.Frame(frm)
110 |
111 | # outer structure (hbox)
112 | label_image = ttk.Label(frm, image=blank_image)
113 | label_image.grid(column=0, row=0)
114 | ttk.Label(frm, image=padding_image).grid(column=1, row=0)
115 | props.grid(column=2, row=0)
116 |
117 | # inner structure (properties and controls)
118 | # prompt field
119 | ttk.Label(props, text="Prompt:").grid(column=0, row=0)
120 | ttk.Entry(props, textvariable=sv_prompt).grid(column=1, row=0)
121 | #
122 | ttk.Label(props, image=padding_image).grid(column=0, row=1)
123 | # temperature field
124 | ttk.Label(props, text="Temperature:").grid(column=0, row=2)
125 | ttk.Entry(props, textvariable=sv_temperature).grid(column=1, row=2)
126 | #
127 | ttk.Label(props, image=padding_image).grid(column=0, row=3)
128 | # topk field
129 | ttk.Label(props, text="Top-K:").grid(column=0, row=4)
130 | ttk.Entry(props, textvariable=sv_topk).grid(column=1, row=4)
131 | #
132 | ttk.Label(props, image=padding_image).grid(column=0, row=5)
133 | # superconditioning field
134 | ttk.Label(props, text="Supercondition Factor:").grid(column=0, row=6)
135 | ttk.Entry(props, textvariable=sv_supercond).grid(column=1, row=6)
136 | #
137 | ttk.Label(props, image=padding_image).grid(column=0, row=7)
138 | # seamless
139 | ttk.Label(props, text="Seamless:").grid(column=0, row=8)
140 | ttk.Checkbutton(props, variable=bv_seamless).grid(column=1, row=8)
141 | #
142 | ttk.Label(props, image=padding_image).grid(column=0, row=9)
143 | # buttons
144 | ttk.Button(props, text="Generate", command=generate).grid(column=0, row=10)
145 | ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=10)
146 | ttk.Button(props, text="Save", command=save).grid(column=2, row=10)
147 |
148 | root.mainloop()
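149 |
150 | # Illustrative note (not part of the original file): the UI is launched with
151 | # `python tkinter_ui.py` from this directory; MinDalle(models_root='./pretrained', ...)
152 | # is expected to fetch the mini/mega weights into ./pretrained on first use.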
--------------------------------------------------------------------------------
/mega-story-dalle/train_story.sh:
--------------------------------------------------------------------------------
1 | if [ "$1" = "pororo" ]; then
2 | echo "Training on Pororo"
3 | DATA_DIR=../data/pororo/
4 | OUTPUT_ROOT=./out/pororo
5 | SENT_EMBED=512
6 | STORY_LEN=4
7 | LR=1e-4
8 | TRAIN_BS=1
9 | GRAD_ACC=4
10 | elif [ "$1" = "flintstones" ]; then echo "Training on Flintstones"; DATA_DIR=../data/flintstones
11 | OUTPUT_ROOT=./out/flintstones
12 | SENT_EMBED=512
13 | STORY_LEN=4
14 | LR=1e-5
15 | TRAIN_BS=1
16 | GRAD_ACC=4
17 | elif [ "$1" = "didemo" ]; then
18 | echo "Training on DiDeMo"
19 | DATA_DIR=../data/didemo
20 | OUTPUT_ROOT=./out/didemo
21 | SENT_EMBED=512
22 | STORY_LEN=2
23 | TRAIN_BS=1
24 | GRAD_ACC=8
25 | fi
26 |
27 | #--prefix_model_name_or_path './1.3B/' \
28 | #--model_name_or_path './1.3B/' \
29 |
30 | python ./train_t2i.py \
31 | --model_name_or_path './pretrained/' \
32 | --tuning_mode 'story' \
33 | --dataset_name $1 \
34 | --story_len $STORY_LEN \
35 | --sent_embed $SENT_EMBED \
36 | --prefix_dropout 0.2 \
37 | --data_dir $DATA_DIR \
38 | --dataloader_num_workers 4 \
39 | --output_dir $OUTPUT_ROOT \
40 | --log_dir /nas-ssd/adyasha/runs/ \
41 | --do_train --do_eval \
42 | --per_gpu_train_batch_size $TRAIN_BS \
43 | --per_gpu_eval_batch_size 2 \
44 | --num_train_epochs 50 \
45 | --gradient_accumulation_steps $GRAD_ACC \
46 | --learning_rate $LR \
47 | --logging_steps 50 \
48 | --eval_steps 500 \
49 | --generate_steps 1000 \
50 | --is_mega
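51 |
52 | # Illustrative usage (not part of the original script); the dataset name is passed
53 | # as the first positional argument, e.g.:
54 | #   bash train_story.sh pororo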
--------------------------------------------------------------------------------
/story-dalle/1.3B/config.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | tokenizer_type: CharBPE
3 | context_length: 64
4 | image_resolution: 256
5 |
6 | stage1:
7 | type: vqgan
8 | embed_dim: 256
9 | n_embed: 16384
10 | hparams:
11 | double_z: False
12 | z_channels: 256
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult: [1, 1, 2, 2, 4]
18 | num_res_blocks: 2
19 | attn_resolutions: [16]
20 | pdrop: 0.0
21 |
22 | stage2:
23 | type: transformer1d
24 | vocab_size_txt: 16384
25 | vocab_size_img: 16384
26 | hparams:
27 | embed_dim: 1536
28 | n_layers: 42
29 | n_heads: 24
30 | n_dense_layers: 42
31 | ctx_len_img: 256
32 | ctx_len_txt: 64
33 | embd_pdrop: 0.0
34 | resid_pdrop: 0.0
35 | attn_pdrop: 0.0
36 | mlp_bias: True
37 | attn_bias: True
38 | gelu_use_approx: False
39 |
--------------------------------------------------------------------------------
/story-dalle/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/__init__.py
--------------------------------------------------------------------------------
/story-dalle/__pycache__/pororo_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/__pycache__/pororo_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__init__.py
--------------------------------------------------------------------------------
/story-dalle/dalle/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/__pycache__/trainer_prefix.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/__pycache__/trainer_prefix.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/__pycache__/tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/__pycache__/tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/layers.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import torch
7 | import torch.nn as nn
8 | from typing import Tuple, Optional
9 |
10 |
11 | def nonlinearity(x):
12 | # swish
13 | return x*torch.sigmoid(x)
14 |
15 |
16 | def Normalize(in_channels):
17 | return torch.nn.GroupNorm(num_groups=32,
18 | num_channels=in_channels,
19 | eps=1e-6,
20 | affine=True)
21 |
22 |
23 | class Upsample(nn.Module):
24 | def __init__(self, in_channels, with_conv):
25 | super().__init__()
26 | self.with_conv = with_conv
27 | if self.with_conv:
28 | self.conv = torch.nn.Conv2d(in_channels,
29 | in_channels,
30 | kernel_size=3,
31 | stride=1,
32 | padding=1)
33 |
34 | def forward(self, x):
35 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
36 | if self.with_conv:
37 | x = self.conv(x)
38 | return x
39 |
40 |
41 | class Downsample(nn.Module):
42 | def __init__(self, in_channels, with_conv):
43 | super().__init__()
44 | self.with_conv = with_conv
45 | if self.with_conv:
46 | # no asymmetric padding in torch conv, must do it ourselves
47 | self.conv = torch.nn.Conv2d(in_channels,
48 | in_channels,
49 | kernel_size=3,
50 | stride=2,
51 | padding=0)
52 |
53 | def forward(self, x):
54 | if self.with_conv:
55 | pad = (0, 1, 0, 1)
56 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
57 | x = self.conv(x)
58 | else:
59 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
60 | return x
61 |
62 |
63 | class ResnetBlock(nn.Module):
64 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
65 | dropout, temb_channels=512):
66 | assert temb_channels == 0
67 | super().__init__()
68 | self.in_channels = in_channels
69 | out_channels = in_channels if out_channels is None else out_channels
70 | self.out_channels = out_channels
71 | self.use_conv_shortcut = conv_shortcut
72 |
73 | self.norm1 = Normalize(in_channels)
74 | self.conv1 = torch.nn.Conv2d(in_channels,
75 | out_channels,
76 | kernel_size=3,
77 | stride=1,
78 | padding=1)
79 | self.norm2 = Normalize(out_channels)
80 | self.dropout = torch.nn.Dropout(dropout)
81 | self.conv2 = torch.nn.Conv2d(out_channels,
82 | out_channels,
83 | kernel_size=3,
84 | stride=1,
85 | padding=1)
86 | if self.in_channels != self.out_channels:
87 | if self.use_conv_shortcut:
88 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
89 | out_channels,
90 | kernel_size=3,
91 | stride=1,
92 | padding=1)
93 | else:
94 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
95 | out_channels,
96 | kernel_size=1,
97 | stride=1,
98 | padding=0)
99 |
100 | def forward(self, x, temb=None):
101 | assert temb is None
102 |
103 | h = x
104 | h = self.norm1(h)
105 | h = nonlinearity(h)
106 | h = self.conv1(h)
107 |
108 | h = self.norm2(h)
109 | h = nonlinearity(h)
110 | h = self.dropout(h)
111 | h = self.conv2(h)
112 |
113 | if self.in_channels != self.out_channels:
114 | if self.use_conv_shortcut:
115 | x = self.conv_shortcut(x)
116 | else:
117 | x = self.nin_shortcut(x)
118 | return x+h
119 |
120 |
121 | class AttnBlock(nn.Module):
122 | def __init__(self, in_channels):
123 | super().__init__()
124 | self.in_channels = in_channels
125 |
126 | self.norm = Normalize(in_channels)
127 | self.q = torch.nn.Conv2d(in_channels,
128 | in_channels,
129 | kernel_size=1,
130 | stride=1,
131 | padding=0)
132 | self.k = torch.nn.Conv2d(in_channels,
133 | in_channels,
134 | kernel_size=1,
135 | stride=1,
136 | padding=0)
137 | self.v = torch.nn.Conv2d(in_channels,
138 | in_channels,
139 | kernel_size=1,
140 | stride=1,
141 | padding=0)
142 | self.proj_out = torch.nn.Conv2d(in_channels,
143 | in_channels,
144 | kernel_size=1,
145 | stride=1,
146 | padding=0)
147 |
148 | def forward(self, x):
149 | h_ = x
150 | h_ = self.norm(h_)
151 | q = self.q(h_)
152 | k = self.k(h_)
153 | v = self.v(h_)
154 |
155 | # compute attention
156 | b, c, h, w = q.shape
157 | q = q.reshape(b, c, h*w)
158 | q = q.permute(0, 2, 1) # b,hw,c
159 | k = k.reshape(b, c, h*w) # b,c,hw
160 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
161 | w_ = w_ * (int(c)**(-0.5))
162 | w_ = torch.nn.functional.softmax(w_, dim=2)
163 |
164 | # attend to values
165 | v = v.reshape(b, c, h*w)
166 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
167 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
168 | h_ = h_.reshape(b, c, h, w)
169 |
170 | h_ = self.proj_out(h_)
171 | return x+h_
172 |
173 |
174 | class Encoder(nn.Module):
175 | def __init__(self,
176 | *, # forced to use named arguments
177 | ch: int,
178 | out_ch: int,
179 | ch_mult: Tuple[int] = (1, 2, 4, 8),
180 | num_res_blocks: int,
181 | attn_resolutions: Tuple[int],
182 | pdrop: float = 0.0,
183 | resamp_with_conv: bool = True,
184 | in_channels: int,
185 | resolution: int,
186 | z_channels: int,
187 | double_z: Optional[bool] = None) -> None:
188 | super().__init__()
189 | self.ch = ch
190 | self.temb_ch = 0
191 | self.num_resolutions = len(ch_mult)
192 | self.num_res_blocks = num_res_blocks
193 | self.resolution = resolution
194 | self.in_channels = in_channels
195 |
196 | # downsampling
197 | self.conv_in = torch.nn.Conv2d(in_channels,
198 | self.ch,
199 | kernel_size=3,
200 | stride=1,
201 | padding=1)
202 |
203 | curr_res = resolution
204 | in_ch_mult = (1,)+tuple(ch_mult)
205 | self.down = nn.ModuleList()
206 | for i_level in range(self.num_resolutions):
207 | block = nn.ModuleList()
208 | attn = nn.ModuleList()
209 | block_in = ch*in_ch_mult[i_level]
210 | block_out = ch*ch_mult[i_level]
211 | for i_block in range(self.num_res_blocks):
212 | block.append(ResnetBlock(in_channels=block_in,
213 | out_channels=block_out,
214 | temb_channels=self.temb_ch,
215 | dropout=pdrop))
216 | block_in = block_out
217 | if curr_res in attn_resolutions:
218 | attn.append(AttnBlock(block_in))
219 | down = nn.Module()
220 | down.block = block
221 | down.attn = attn
222 | if i_level != self.num_resolutions-1:
223 | down.downsample = Downsample(block_in, resamp_with_conv)
224 | curr_res = curr_res // 2
225 | self.down.append(down)
226 |
227 | # middle
228 | self.mid = nn.Module()
229 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
230 | out_channels=block_in,
231 | temb_channels=self.temb_ch,
232 | dropout=pdrop)
233 | self.mid.attn_1 = AttnBlock(block_in)
234 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
235 | out_channels=block_in,
236 | temb_channels=self.temb_ch,
237 | dropout=pdrop)
238 |
239 | # end
240 | self.norm_out = Normalize(block_in)
241 | self.conv_out = torch.nn.Conv2d(block_in,
242 | 2*z_channels if double_z else z_channels,
243 | kernel_size=3,
244 | stride=1,
245 | padding=1)
246 |
247 | def forward(self, x):
248 | assert x.shape[2] == x.shape[3] == self.resolution, \
249 | "{}, {}".format(x.shape, self.resolution)
250 |
251 | # downsampling
252 | h = self.conv_in(x)
253 | for i_level in range(self.num_resolutions):
254 | for i_block in range(self.num_res_blocks):
255 | h = self.down[i_level].block[i_block](h)
256 | if len(self.down[i_level].attn) > 0:
257 | h = self.down[i_level].attn[i_block](h)
258 | if i_level != self.num_resolutions-1:
259 | h = self.down[i_level].downsample(h)
260 |
261 | # middle
262 | h = self.mid.block_1(h)
263 | h = self.mid.attn_1(h)
264 | h = self.mid.block_2(h)
265 |
266 | # end
267 | h = self.norm_out(h)
268 | h = nonlinearity(h)
269 | h = self.conv_out(h)
270 | return h
271 |
272 |
273 | class Decoder(nn.Module):
274 | def __init__(self,
275 | *, # forced to use named arguments
276 | ch: int,
277 | out_ch: int,
278 | ch_mult: Tuple[int] = (1, 2, 4, 8),
279 | num_res_blocks: int,
280 | attn_resolutions: Tuple[int],
281 | pdrop: float = 0.0,
282 | resamp_with_conv: bool = True,
283 | in_channels: int,
284 | resolution: int,
285 | z_channels: int,
286 | double_z: bool) -> None:
287 | super().__init__()
288 | self.ch = ch
289 | self.temb_ch = 0
290 | self.num_resolutions = len(ch_mult)
291 | self.num_res_blocks = num_res_blocks
292 | self.resolution = resolution
293 | self.in_channels = in_channels
294 |
295 | # compute in_ch_mult, block_in and curr_res at lowest res
296 | block_in = ch*ch_mult[self.num_resolutions-1]
297 | curr_res = resolution // 2**(self.num_resolutions-1)
298 | self.z_shape = (1, z_channels, curr_res, curr_res)
299 |
300 | # z to block_in
301 | self.conv_in = torch.nn.Conv2d(z_channels,
302 | block_in,
303 | kernel_size=3,
304 | stride=1,
305 | padding=1)
306 |
307 | # middle
308 | self.mid = nn.Module()
309 | self.mid.block_1 = ResnetBlock(in_channels=block_in,
310 | out_channels=block_in,
311 | temb_channels=self.temb_ch,
312 | dropout=pdrop)
313 | self.mid.attn_1 = AttnBlock(block_in)
314 | self.mid.block_2 = ResnetBlock(in_channels=block_in,
315 | out_channels=block_in,
316 | temb_channels=self.temb_ch,
317 | dropout=pdrop)
318 |
319 | # upsampling
320 | self.up = nn.ModuleList()
321 | for i_level in reversed(range(self.num_resolutions)):
322 | block = nn.ModuleList()
323 | attn = nn.ModuleList()
324 | block_out = ch*ch_mult[i_level]
325 | for i_block in range(self.num_res_blocks+1):
326 | block.append(ResnetBlock(in_channels=block_in,
327 | out_channels=block_out,
328 | temb_channels=self.temb_ch,
329 | dropout=pdrop))
330 | block_in = block_out
331 | if curr_res in attn_resolutions:
332 | attn.append(AttnBlock(block_in))
333 | up = nn.Module()
334 | up.block = block
335 | up.attn = attn
336 | if i_level != 0:
337 | up.upsample = Upsample(block_in, resamp_with_conv)
338 | curr_res = curr_res * 2
339 | self.up.insert(0, up) # prepend to get consistent order
340 |
341 | # end
342 | self.norm_out = Normalize(block_in)
343 | self.conv_out = torch.nn.Conv2d(block_in,
344 | out_ch,
345 | kernel_size=3,
346 | stride=1,
347 | padding=1)
348 |
349 | def forward(self, z):
350 | assert z.shape[1:] == self.z_shape[1:]
351 | self.last_z_shape = z.shape
352 |
353 | # z to block_in
354 | h = self.conv_in(z)
355 |
356 | # middle
357 | h = self.mid.block_1(h)
358 | h = self.mid.attn_1(h)
359 | h = self.mid.block_2(h)
360 |
361 | # upsampling
362 | for i_level in reversed(range(self.num_resolutions)):
363 | for i_block in range(self.num_res_blocks+1):
364 | h = self.up[i_level].block[i_block](h)
365 | if len(self.up[i_level].attn) > 0:
366 | h = self.up[i_level].attn[i_block](h)
367 | if i_level != 0:
368 | h = self.up[i_level].upsample(h)
369 |
370 | h = self.norm_out(h)
371 | h = nonlinearity(h)
372 | h = self.conv_out(h)
373 | return h
374 |
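375 | # --- Illustrative sketch (not part of the original file) ----------------------
376 | # Encoder wired with the stage1 hparams from 1.3B/config.yaml; the shapes below
377 | # are assumptions worked out from ch_mult=[1, 1, 2, 2, 4] and resolution=256:
378 | #
379 | #   enc = Encoder(ch=128, out_ch=3, ch_mult=[1, 1, 2, 2, 4], num_res_blocks=2,
380 | #                 attn_resolutions=[16], pdrop=0.0, in_channels=3,
381 | #                 resolution=256, z_channels=256, double_z=False)
382 | #   z = enc(torch.rand(1, 3, 256, 256))   # -> (1, 256, 16, 16), a 16x16 latent grid
383 | # -------------------------------------------------------------------------------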
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage1/vqgan.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3 | # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4 | # ------------------------------------------------------------------------------------
5 |
6 | import torch
7 | import torch.nn as nn
8 | from typing import List, Tuple, Optional
9 | from einops import rearrange
10 | from omegaconf import OmegaConf
11 | from .layers import Encoder, Decoder
12 |
13 |
14 | class VectorQuantizer(nn.Module):
15 | """
16 | Simplified version of the VectorQuantizer in the original VQGAN repository,
17 | with modules unnecessary for sampling removed
18 | """
19 | def __init__(self, dim: int, n_embed: int, beta: float) -> None:
20 | super().__init__()
21 | self.n_embed = n_embed
22 | self.dim = dim
23 | self.beta = beta
24 |
25 | self.embedding = nn.Embedding(self.n_embed, self.dim)
26 | self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
27 |
28 | def forward(self,
29 | z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
30 | z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C]
31 | z_flattened = z.view(-1, self.dim)
32 |
33 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
34 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
35 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
36 |
37 | min_encoding_indices = torch.argmin(d, dim=1)
38 | z_q = self.embedding(min_encoding_indices).view(z.shape)
39 | return z_q, min_encoding_indices
40 |
41 | def get_codebook_entry(self,
42 | indices: torch.LongTensor,
43 | shape: Optional[List[int]] = None) -> torch.FloatTensor:
44 | z_q = self.embedding(indices)
45 | if shape is not None:
46 | z_q = z_q.view(shape)
47 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
48 | return z_q
49 |
50 |
51 | class VQGAN(nn.Module):
52 | def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
53 | super().__init__()
54 | self.encoder = Encoder(**hparams)
55 | self.decoder = Decoder(**hparams)
56 | self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
57 | self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
58 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
59 | self.latent_dim = hparams.attn_resolutions[0]
60 |
61 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
62 | quant = self.encode(x)
63 | dec = self.decode(quant)
64 | return dec
65 |
66 | def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
67 | h = self.encoder(x)
68 | h = self.quant_conv(h)
69 | quant = self.quantize(h)[0]
70 | quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
71 | return quant
72 |
73 | def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
74 | quant = self.post_quant_conv(quant)
75 | dec = self.decoder(quant)
76 | return dec
77 |
78 | def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
79 | quant = self.quantize.get_codebook_entry(code)
80 | quant = quant.permute(0, 3, 1, 2)
81 | dec = self.decode(quant)
82 | return dec
83 |
84 | def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
85 | h = self.encoder(x)
86 | h = self.quant_conv(h)
87 | codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
88 | return codes
89 |
90 | def from_ckpt(self, path: str, strict: bool = True) -> None:
91 | ckpt = torch.load(path, map_location='cpu')['state_dict']
92 | self.load_state_dict(ckpt, strict=strict)
93 | print(f'{path} successfully restored..')
94 |
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/layers.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/models/stage2/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/models/stage2/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Minimal DALL-E
3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 | # Modified from minGPT (https://github.com/karpathy/minGPT)
7 | # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8 | # ------------------------------------------------------------------------------------
9 |
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import functional as F
14 |
15 |
16 | class GELU(nn.Module):
17 | def __init__(self, use_approx=False):
18 | super().__init__()
19 | self.use_approx = use_approx
20 |
21 | def forward(self, x):
22 | if self.use_approx:
23 | return x * torch.sigmoid(1.702 * x)
24 | else:
25 | return F.gelu(x)
26 |
27 |
28 | class MultiHeadSelfAttention(nn.Module):
29 |
30 | def __init__(self,
31 | ctx_len: int,
32 | embed_dim: int,
33 | n_heads: int,
34 | resid_pdrop: float,
35 | attn_pdrop: float,
36 | attn_bias: bool,
37 | use_mask: bool = True):
38 | super().__init__()
39 | assert embed_dim % n_heads == 0
40 |
41 | # key, query, value projections for all heads
42 | self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
43 | self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
44 | self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
45 |
46 | # regularization
47 | self.attn_drop = nn.Dropout(attn_pdrop)
48 | self.resid_drop = nn.Dropout(resid_pdrop)
49 |
50 | # output projection
51 | self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
52 |
53 | self.n_heads = n_heads
54 | self.ctx_len = ctx_len
55 | self.use_mask = use_mask
56 | if self.use_mask:
57 | self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
58 | self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
59 |
60 | def forward(self, x, use_cache=False, layer_past=None):
61 | B, T, C = x.shape
62 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
63 |
64 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65 | k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
66 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
67 | v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
68 |
69 | if use_cache:
70 | present = torch.stack([k, v])
71 |
72 | if layer_past is not None:
73 | # print(layer_past.shape, k.shape, v.shape, q.shape)
74 | # print("LayerPast shape", layer_past.shape)
75 | past_key, past_value = layer_past
76 |
77 | if len(past_key.shape) == 4:
78 | _, _, seq_len, dim = past_key.shape
79 | k = torch.cat([past_key.reshape(-1, seq_len, dim), k], dim=-2)
80 | v = torch.cat([past_value.reshape(-1, seq_len, dim), v], dim=-2)
81 | elif len(past_key.shape) == 3:
82 | past_key, past_value = layer_past
83 | k = torch.cat([past_key, k], dim=-2)
84 | v = torch.cat([past_value, v], dim=-2)
85 | else:
86 | raise ValueError
87 |
88 | if use_cache and layer_past is not None:
89 | # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
90 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
91 | att = F.softmax(att, dim=-1)
92 | att = self.attn_drop(att)
93 | y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
94 | else:
95 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
96 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
97 | if self.use_mask:
98 | # TODO: Flip when not prompt tuning
99 | # mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
100 | if T == self.ctx_len:
101 | mask = self.mask
102 | else:
103 | mask = torch.tril(torch.ones(T, T)).view(1, T, T).to(att.device)
104 | att = att.masked_fill(mask == 0, float('-inf'))
105 | att = F.softmax(att, dim=-1)
106 | att = self.attn_drop(att)
107 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
108 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
109 |
110 | # output projection
111 | y = self.resid_drop(self.proj(y))
112 | if use_cache:
113 | return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
114 | else:
115 | return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
116 |
117 | def forward_with_context(self, x, context, mask=None):
118 | B, T, C = x.shape
119 | x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
120 |
121 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
122 | q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
123 |
124 | B, T_c, C = context.shape
125 | k = self.key(context).view(T_c, B * self.n_heads, C // self.n_heads).transpose(0, 1) # (B*nh, T, hs)
126 | v = self.value(context).view(T_c, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
127 |
128 | # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, Tc) -> (B * nh, T, Tc)
129 | att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
130 | att = F.softmax(att, dim=-1)
131 | att = self.attn_drop(att)
132 | y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
133 | y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
134 |
135 | # output projection
136 | y = self.resid_drop(self.proj(y)).transpose(0, 1).contiguous()
137 | if mask is not None:
138 | y = y.masked_fill(mask == 0, float('0.0'))
139 | return y # (T, B, C) -> (B, T, C)
140 |
141 |
142 | class Block(nn.Module):
143 |
144 | def __init__(self,
145 | ctx_len: int,
146 | embed_dim: int,
147 | n_heads: int,
148 | mlp_bias: bool,
149 | attn_bias: bool,
150 | resid_pdrop: bool,
151 | attn_pdrop: bool,
152 | gelu_use_approx: bool):
153 | super().__init__()
154 | self.ln1 = nn.LayerNorm(embed_dim)
155 | self.ln2 = nn.LayerNorm(embed_dim)
156 |
157 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
158 | embed_dim=embed_dim,
159 | n_heads=n_heads,
160 | attn_pdrop=attn_pdrop,
161 | resid_pdrop=resid_pdrop,
162 | attn_bias=attn_bias,
163 | use_mask=True)
164 | self.mlp = nn.Sequential(
165 | nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
166 | GELU(gelu_use_approx),
167 | nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
168 | nn.Dropout(resid_pdrop),
169 | )
170 |
171 | def forward(self, x, layer_past=None):
172 | x = x + self.attn(self.ln1(x), layer_past=layer_past)
173 | x = x + self.mlp(self.ln2(x))
174 | return x
175 |
176 | def sample(self, x, layer_past=None):
177 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
178 | x = x + attn
179 | x = x + self.mlp(self.ln2(x))
180 | return x, present
181 |
182 | def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
183 | attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
184 | x = x + attn
185 | c_attn = cross_attn_layer(x, context, context_mask)
186 | x = x + c_attn
187 | x = x + self.mlp(self.ln2(x))
188 | return x, present
189 |
190 |
191 | class CrossAttentionLayer(nn.Module):
192 |
193 | def __init__(self,
194 | ctx_len: int,
195 | embed_dim: int,
196 | n_heads: int,
197 | attn_bias: bool,
198 | resid_pdrop: bool,
199 | attn_pdrop: bool):
200 | super().__init__()
201 |
202 | self.ln1 = nn.LayerNorm(embed_dim)
203 | self.ln2 = nn.LayerNorm(embed_dim)
204 | self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
205 | embed_dim=embed_dim,
206 | n_heads=n_heads,
207 | attn_pdrop=attn_pdrop,
208 | resid_pdrop=resid_pdrop,
209 | attn_bias=attn_bias,
210 | use_mask=False)
211 |
212 | def forward(self, x, context, context_mask=None):
213 | attn = self.attn.forward_with_context(self.ln1(x), self.ln2(context), context_mask)
214 | # x = x + attn
215 | # return x
216 | return attn
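217 |
218 | # --- Illustrative sketch (not part of the original file) ----------------------
219 | # A Block wired with the stage2 hparams from 1.3B/config.yaml; the context length
220 | # of 320 (= 64 text + 256 image tokens) is an assumption for clarity:
221 | #
222 | #   blk = Block(ctx_len=320, embed_dim=1536, n_heads=24, mlp_bias=True,
223 | #               attn_bias=True, resid_pdrop=0.0, attn_pdrop=0.0,
224 | #               gelu_use_approx=False)
225 | #   x = torch.zeros(2, 320, 1536)   # (batch, tokens, embed_dim)
226 | #   y = blk(x)                      # same shape; self-attention is causally masked
227 | # -------------------------------------------------------------------------------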
--------------------------------------------------------------------------------
/story-dalle/dalle/models/tokenizer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Minimal DALL-E
3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import os
8 | from functools import partial
9 | from tokenizers import CharBPETokenizer
10 |
11 |
12 | def build_tokenizer(path: str,
13 | context_length: int = 64,
14 | *args,
15 | **kwargs):
16 | try:
17 | from_file = partial(CharBPETokenizer.from_file,
18 | vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
19 | merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
20 | unk_token='[UNK]')
21 | tokenizer = from_file(*args, **kwargs)
22 | except Exception:
23 | from_file = partial(CharBPETokenizer.from_file,
24 | vocab_filename=os.path.join(path, 'vocab.json'),
25 | merges_filename=os.path.join(path, 'merges.txt'),
26 | unk_token='[UNK]')
27 | tokenizer = from_file(*args, **kwargs)
28 |
29 | # tokenizer = from_file(*args, **kwargs)
30 | tokenizer.add_special_tokens(['[PAD]'])
31 | tokenizer.enable_padding(length=context_length,
32 | pad_id=tokenizer.token_to_id('[PAD]'))
33 | tokenizer.enable_truncation(max_length=context_length)
34 | print(f'{path} successfully restored..')
35 | return tokenizer
36 |
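37 | # --- Illustrative sketch (not part of the original file) ----------------------
38 | # Assuming the BPE vocab/merges files live in the checkpoint directory (the path is
39 | # hypothetical):
40 | #
41 | #   tokenizer = build_tokenizer('1.3B/tokenizer', context_length=64)
42 | #   ids = tokenizer.encode('pororo and crong play in the snow').ids
43 | #   # `ids` is padded/truncated to exactly 64 token ids
44 | # -------------------------------------------------------------------------------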
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .config import *
3 | from .sampling import *
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/sampling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/sampling.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/dalle/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Minimal DALL-E
3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | from typing import Optional, List
8 | from dataclasses import dataclass, field
9 | from omegaconf import OmegaConf
10 |
11 |
12 | @dataclass
13 | class DataConfig:
14 | dataset: Optional[str] = None
15 | tokenizer_type: str = 'CharBPE'
16 | context_length: int = 64
17 | image_resolution: int = 256
18 | transforms: str = 'dalle-vqvae'
19 | bpe_pdrop: Optional[float] = None
20 |
21 |
22 | @dataclass
23 | class Stage1Hparams:
24 | double_z: bool = False
25 | z_channels: int = 256
26 | resolution: int = 256
27 | in_channels: int = 3
28 | out_ch: int = 3
29 | ch: int = 128
30 | ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
31 | num_res_blocks: int = 2
32 | attn_resolutions: List[int] = field(default_factory=lambda: [16])
33 | pdrop: float = 0.0
34 |
35 |
36 | @dataclass
37 | class Stage2Hparams:
38 | embed_dim: int = 1536
39 | n_layers: int = 42
40 | n_heads: int = 24
41 | n_dense_layers: int = 42
42 | ctx_len_img: int = 256
43 | ctx_len_txt: int = 64
44 | embd_pdrop: float = 0.0
45 | resid_pdrop: float = 0.0
46 | attn_pdrop: float = 0.0
47 | mlp_bias: bool = True
48 | attn_bias: bool = True
49 | gelu_use_approx: bool = False
50 | use_head_txt: bool = True
51 | n_classes: Optional[int] = None
52 |
53 |
54 | @dataclass
55 | class Stage1Config:
56 | type: str = 'vqgan'
57 | embed_dim: int = 256
58 | n_embed: int = 16384
59 | hparams: Stage1Hparams = Stage1Hparams()
60 |
61 |
62 | @dataclass
63 | class Stage2Config:
64 | type: str = 'transformer1d'
65 | vocab_size_txt: int = 16384
66 | vocab_size_img: int = 16384
67 | use_cls_cond: Optional[bool] = None
68 | hparams: Stage2Hparams = Stage2Hparams()
69 |
70 |
71 | @dataclass
72 | class WarmupConfig:
73 | epoch: int = 1
74 | multiplier: int = 1
75 | buffer_epoch: int = 0
76 | min_lr: float = 0.0
77 | mode: str = 'fix'
78 | peak_lr: float = 1e-4
79 | start_from_zero: bool = True
80 |
81 |
82 | @dataclass
83 | class OptConfig:
84 | opt_type: str = 'adamW'
85 | learning_rate: float = 5e-5
86 | weight_decay: float = 1e-4
87 | betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
88 | grad_clip_norm: float = 1.0
89 |
90 | sched_type: str = 'cosine'
91 | max_steps: int = 0
92 | min_lr: float = 1e-6
93 |
94 |
95 | @dataclass
96 | class ExpConfig:
97 | per_gpu_train_batch_size: int = 4
98 | per_gpu_eval_batch_size: int = 32
99 | num_train_epochs: int = 10
100 | save_ckpt_freq: int = 1
101 | test_freq: int = 10
102 | use_amp: bool = True
103 |
104 |
105 | @dataclass
106 | class PrefixModelConfig:
107 | model_name_or_path: Optional[str] = ''
108 | prefix_model_name_or_path: str = ''
109 | prefix_mode: str = 'activation'
110 | tuning_mode: str = 'finetune'
111 | top_k_layers: int = 2
112 | parameterize_mode: str = 'mlp'
113 | optim_prefix: bool = False
114 | preseqlen: int = 10
115 | prefix_dropout: float = 0.1
116 | init_random: bool = False
117 | hidden_dim_prefix: int = 512
118 | lowdata: bool = False
119 | lowdata_token: str = ''
120 | init_shallow: bool = False
121 | init_shallow_word: bool = False
122 | teacher_dropout: float = 0.1
123 | gumbel: bool = False
124 | replay_buffer: bool = False
125 |
126 |
127 | @dataclass
128 | class PromptModelConfig:
129 | model_name_or_path: Optional[str] = ''
130 | prefix_model_name_or_path: str = ''
131 | tuning_mode: str = 'prompt'
132 | preseqlen: int = 10
133 | prefix_dropout: float = 0.1
134 |
135 |
136 | @dataclass
137 | class StoryModelConfig:
138 | model_name_or_path: Optional[str] = ''
139 | prefix_model_name_or_path: str = ''
140 | tuning_mode: str = 'story'
141 | preseqlen: int = 10
142 | prefix_dropout: float = 0.1
143 | prompt: bool = False
144 | story_len: int = 4
145 | sent_embed: int = 256
146 | condition: bool = False
147 | clip_embed: bool = False
148 |
149 |
150 | @dataclass
151 | class DefaultConfig:
152 | dataset: DataConfig = DataConfig()
153 | stage1: Stage1Config = Stage1Config()
154 | stage2: Stage2Config = Stage2Config()
155 |
156 |
157 | @dataclass
158 | class FineTuningConfig:
159 | dataset: DataConfig = DataConfig()
160 | stage1: Stage1Config = Stage1Config()
161 | stage2: Stage2Config = Stage2Config()
162 | optimizer: OptConfig = OptConfig()
163 | experiment: ExpConfig = ExpConfig()
164 |
165 |
166 | @dataclass
167 | class PrefixTuningConfig:
168 | dataset: DataConfig = DataConfig()
169 | stage1: Stage1Config = Stage1Config()
170 | stage2: Stage2Config = Stage2Config()
171 | prefix: PrefixModelConfig = PrefixModelConfig()
172 | optimizer: OptConfig = OptConfig()
173 | experiment: ExpConfig = ExpConfig()
174 |
175 |
176 | @dataclass
177 | class PromptTuningConfig:
178 | dataset: DataConfig = DataConfig()
179 | stage1: Stage1Config = Stage1Config()
180 | stage2: Stage2Config = Stage2Config()
181 | prompt: PromptModelConfig = PromptModelConfig()
182 | optimizer: OptConfig = OptConfig()
183 | experiment: ExpConfig = ExpConfig()
184 |
185 |
186 | @dataclass
187 | class StoryConfig:
188 | dataset: DataConfig = DataConfig()
189 | stage1: Stage1Config = Stage1Config()
190 | stage2: Stage2Config = Stage2Config()
191 | story: StoryModelConfig = StoryModelConfig()
192 | optimizer: OptConfig = OptConfig()
193 | experiment: ExpConfig = ExpConfig()
194 |
195 |
196 | def get_base_config(mode):
197 | if mode == 'default':
198 | return OmegaConf.structured(DefaultConfig)
199 | elif mode == 'finetuning':
200 | return OmegaConf.structured(FineTuningConfig)
201 | elif mode == 'prefixtuning':
202 | return OmegaConf.structured(PrefixTuningConfig)
203 | elif mode == 'prompt_tuning':
204 | return OmegaConf.structured(PromptTuningConfig)
205 | elif mode == 'story':
206 | return OmegaConf.structured(StoryConfig)
207 | else:
208 | raise ValueError
209 | # return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
210 |
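211 | # --- Illustrative sketch (not part of the original file) ----------------------
212 | # A typical way to combine these structured defaults with the YAML shipped under
213 | # 1.3B/ (the path is an assumption):
214 | #
215 | #   base = get_base_config('story')
216 | #   config = OmegaConf.merge(base, OmegaConf.load('1.3B/config.yaml'))
217 | #   print(config.stage2.hparams.n_layers)   # 42
218 | # -------------------------------------------------------------------------------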
--------------------------------------------------------------------------------
/story-dalle/dalle/utils/sampling.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Minimal DALL-E
3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import torch
8 | from typing import Optional
9 | from tqdm import tqdm
10 | from torch.nn import functional as F
11 |
12 |
13 | torch.set_printoptions(precision=2, threshold=10)
14 | def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
15 | if k is None:
16 | return logits
17 | else:
18 | v, ix = torch.topk(logits, k)
19 | out = logits.clone()
20 | out[out < v[:, [-1]]] = -float('Inf')
21 | return out
22 |
23 |
24 | def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
25 | if p is None:
26 | return probs
27 | else:
28 | sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
29 | cum_probs = torch.cumsum(sorted_probs, dim=-1)
30 |
31 | sorted_idx_remove_cond = cum_probs >= p
32 |
33 | sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
34 | sorted_idx_remove_cond[..., 0] = 0
35 |
36 | indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
37 | probs = probs.masked_fill(indices_to_remove, 0.0)
38 | norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
39 | return norm_probs
40 |
41 |
42 | def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
43 | device = inputs.device
44 | if mode == '1d':
45 | B, N = inputs.shape
46 | xs_pos = torch.arange(N, device=device).repeat((B, 1))
47 | elif mode == '2d':
48 | B, H, W = inputs.shape
49 | xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
50 | xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
51 | xs_pos = (xs_pos_h, xs_pos_w)
52 | else:
53 | raise ValueError('%s positional encoding invalid' % mode)
54 | return xs_pos
55 |
56 |
57 | @torch.no_grad()
58 | def sampling(model: torch.nn.Module,
59 | tokens: torch.LongTensor,
60 | top_k: Optional[float] = None,
61 | top_p: Optional[float] = None,
62 | softmax_temperature: float = 1.0,
63 | is_tqdm: bool = True,
64 | use_fp16: bool = True,
65 | max_seq_len: int = 256,
66 |              prompt: Optional[torch.Tensor] = None,
67 | pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
68 |
69 | code = None
70 | past = None
71 |
72 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
73 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
74 |
75 | for cnt, h in enumerate(pbar):
76 | if code is None:
77 | code_ = None
78 | pos_enc_code_ = None
79 | else:
80 | code_ = code.clone().detach()
81 | pos_enc_code_ = get_positional_encoding(code_, mode='1d')
82 | code_ = code_[:, cnt-1].unsqueeze(-1)
83 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
84 |
85 | logits, present = model.sampling(images=code_,
86 | texts=tokens,
87 | pos_images=pos_enc_code_,
88 | pos_texts=pos_enc_tokens,
89 | use_fp16=use_fp16,
90 | past=past,
91 | prompt=prompt,
92 | pos_prompt=pos_prompt)
93 |
94 | logits = logits.to(dtype=torch.float32)
95 | logits = logits / softmax_temperature
96 |
97 | # print(len(present), present[0].shape)
98 | present = torch.stack(present).clone().detach()
99 | if past is None:
100 | past = [present]
101 | else:
102 | past.append(present)
103 |
104 | logits = cutoff_topk_logits(logits, top_k)
105 | probs = F.softmax(logits, dim=-1)
106 | probs = cutoff_topp_probs(probs, top_p)
107 | # print(probs[0])
108 |
109 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
110 | # print(idx)
111 | code = idx if code is None else torch.cat([code, idx], axis=1)
112 |
113 | del past
114 | return code
115 |
116 |
117 | @torch.no_grad()
118 | def sampling_prefix(model: torch.nn.Module,
119 | tokens: torch.LongTensor,
120 | past: torch.FloatTensor,
121 | top_k: Optional[float] = None,
122 | top_p: Optional[float] = None,
123 | softmax_temperature: float = 1.0,
124 | is_tqdm: bool = True,
125 | use_fp16: bool = True,
126 | max_seq_len: int = 256,
127 | labels = None) -> torch.LongTensor:
128 | code = None
129 |
130 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
131 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
132 |
133 | # print("Entering sampling_prefix; ", past.shape)
134 | if past is not None:
135 | past = [past]
136 |
137 | for cnt, h in enumerate(pbar):
138 | if code is None:
139 | code_ = None
140 | pos_enc_code_ = None
141 | else:
142 | code_ = code.clone().detach()
143 | pos_enc_code_ = get_positional_encoding(code_, mode='1d')
144 | code_ = code_[:, cnt-1].unsqueeze(-1)
145 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
146 |
147 | # print("Looop enter")
148 | # print(cnt, past[0].shape)
149 | # print("-------------------")
150 | logits, present = model.sampling(images=code_,
151 | texts=tokens,
152 | pos_images=pos_enc_code_,
153 | pos_texts=pos_enc_tokens,
154 | use_fp16=use_fp16,
155 | past=past)
156 | logits = logits.to(dtype=torch.float32)
157 | logits = logits / softmax_temperature
158 |
159 | present = torch.stack(present).clone().detach()
160 |
161 | # print('Present', present.shape)
162 |
163 | if past is None:
164 | past = [present]
165 | else:
166 | # print("Loop end")
167 | # print(present.shape)
168 | # print("-----------------")
169 |
170 | # n_layers, temp, _, seq_len, n_dim = present.shape
171 | # _, _, bs, n_heads, pre_seq_len, n_dim = past[0].shape
172 | # assert temp == 2
173 | # past.append(present.view(n_layers, temp, bs, n_heads, seq_len, n_dim))
174 |
175 | past.append(present)
176 |
177 | logits = cutoff_topk_logits(logits, top_k)
178 | probs = F.softmax(logits, dim=-1)
179 | probs = cutoff_topp_probs(probs, top_p)
180 | print(torch.topk(probs, 5, dim=-1))
181 | if labels is not None:
182 | print(labels[cnt])
183 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
184 | # print(idx)
185 | code = idx if code is None else torch.cat([code, idx], axis=1)
186 |
187 | del past
188 | return code
189 |
190 |
191 | @torch.no_grad()
192 | def sampling_prefix_new(model: torch.nn.Module,
193 | tokens: torch.LongTensor,
194 | past: torch.FloatTensor,
195 | top_k: Optional[float] = None,
196 | top_p: Optional[float] = None,
197 | softmax_temperature: float = 1.0,
198 | is_tqdm: bool = True,
199 | use_fp16: bool = True,
200 | max_seq_len: int = 256) -> torch.LongTensor:
201 | code = None
202 |
203 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
204 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
205 |
206 | # print("Entering sampling_prefix; ", past.shape)
207 | if past is not None:
208 | past = [past]
209 |
210 | for cnt, h in enumerate(pbar):
211 | if code is None:
212 | code_ = None
213 | pos_enc_code_ = None
214 | else:
215 | code_ = code.clone().detach()
216 | pos_enc_code_ = get_positional_encoding(code_, mode='1d')
217 | # code_ = code_[:, cnt-1].unsqueeze(-1)
218 | # pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
219 |
220 | # print("Looop enter")
221 | # print(cnt, past[0].shape)
222 | # print("-------------------")
223 |
224 | if cnt == 0:
225 | logits, present = model.sampling(images=code_,
226 | texts=tokens,
227 | pos_images=pos_enc_code_,
228 | pos_texts=pos_enc_tokens,
229 | use_fp16=use_fp16,
230 | past=past)
231 | logits = logits.to(dtype=torch.float32)
232 | logits = logits / softmax_temperature
233 |
234 | present = torch.stack(present).clone().detach()
235 |
236 | # print('Present', present.shape)
237 |
238 | if past is None:
239 | past = [present]
240 | else:
241 | pass
242 |
243 | logits = cutoff_topk_logits(logits, top_k)
244 | probs = F.softmax(logits, dim=-1)
245 | probs = cutoff_topp_probs(probs, top_p)
246 | # print(torch.topk(probs[0], 5))
247 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
248 | # print(idx)
249 | code = idx if code is None else torch.cat([code, idx], axis=1)
250 |
251 | else:
252 |                 pass  # steps after the first image token are not implemented in this experimental variant
253 |
254 |
255 | del past
256 | return code
257 |
258 | @torch.no_grad()
259 | def sampling_conditional(model: torch.nn.Module,
260 | cross_attention_idxs,
261 | cross_attention_layers,
262 | tokens: torch.LongTensor,
263 | src_codes: torch.FloatTensor,
264 | top_k: Optional[float] = None,
265 | top_p: Optional[float] = None,
266 | softmax_temperature: float = 1.0,
267 | is_tqdm: bool = True,
268 | use_fp16: bool = True,
269 | max_seq_len: int = 256,
270 |                          prompt: Optional[torch.Tensor] = None,
271 | pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
272 |
273 | code = None
274 | past = None
275 |
276 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
277 | pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
278 |
279 | src_pos_tokens = get_positional_encoding(src_codes, mode='1d')
280 | src_tokens = model.tok_emb_img(src_codes)
281 | src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens)
282 |
283 | for cnt, h in enumerate(pbar):
284 | if code is None:
285 | code_ = None
286 | pos_enc_code_ = None
287 | else:
288 | code_ = code.clone().detach()
289 | pos_enc_code_ = get_positional_encoding(code_, mode='1d')
290 | code_ = code_[:, cnt-1].unsqueeze(-1)
291 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
292 |
293 | logits, present = model.sampling_with_context(images=code_,
294 | cross_attention_idxs=cross_attention_idxs,
295 | cross_attention_layers=cross_attention_layers,
296 | texts=tokens,
297 | pos_images=pos_enc_code_,
298 | pos_texts=pos_enc_tokens,
299 | source_image=src_tokens,
300 | use_fp16=use_fp16,
301 | past=past,
302 | prompt=prompt,
303 | pos_prompt=pos_prompt)
304 | logits = logits.to(dtype=torch.float32)
305 | logits = logits / softmax_temperature
306 |
307 | present = torch.stack(present).clone().detach()
308 | if past is None:
309 | past = [present]
310 | else:
311 | past.append(present)
312 |
313 | logits = cutoff_topk_logits(logits, top_k)
314 | probs = F.softmax(logits, dim=-1)
315 | probs = cutoff_topp_probs(probs, top_p)
316 |
317 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
318 | code = idx if code is None else torch.cat([code, idx], axis=1)
319 |
320 | del past
321 | return code
322 |
323 |
324 | @torch.no_grad()
325 | def sampling_igpt(model: torch.nn.Module,
326 | sos: torch.FloatTensor,
327 | top_k: Optional[float] = None,
328 | top_p: Optional[float] = None,
329 | softmax_temperature: float = 1.0,
330 | is_tqdm: bool = True,
331 | use_fp16: bool = True,
332 | max_seq_len: int = 256) -> torch.LongTensor:
333 | code = None
334 | past = None
335 | pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
336 |
337 | for cnt, h in enumerate(pbar):
338 | if code is None:
339 | code_ = None
340 | pos_enc_code_ = None
341 | else:
342 | code_ = code.clone().detach()
343 | pos_enc_code_ = get_positional_encoding(code_, mode='1d')
344 | code_ = code_[:, cnt-1].unsqueeze(-1)
345 | pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
346 |
347 | logits, present = model.sampling(sos=sos,
348 | codes=code_,
349 | pos_codes=pos_enc_code_,
350 | use_fp16=use_fp16,
351 | past=past)
352 | logits = logits.to(dtype=torch.float32)
353 | logits = logits / softmax_temperature
354 |
355 | present = torch.stack(present).clone().detach()
356 | if past is None:
357 | past = [present]
358 | else:
359 | past.append(present)
360 |
361 | logits = cutoff_topk_logits(logits, top_k)
362 | probs = F.softmax(logits, dim=-1)
363 | probs = cutoff_topp_probs(probs, top_p)
364 |
365 | idx = torch.multinomial(probs, num_samples=1).clone().detach()
366 | code = idx if code is None else torch.cat([code, idx], axis=1)
367 |
368 | del past
369 | return code
370 |
--------------------------------------------------------------------------------
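Illustrative sketch (editorial note, not part of the repository) of the top-k / nucleus filtering helpers above on toy logits; the import path is an assumption.

import torch
from torch.nn import functional as F
from dalle.utils.sampling import cutoff_topk_logits, cutoff_topp_probs  # assumed import path

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # one sequence, toy vocabulary of 4 tokens
logits = cutoff_topk_logits(logits, 2)                # everything outside the top-2 becomes -inf
probs = F.softmax(logits, dim=-1)
probs = cutoff_topp_probs(probs, 0.9)                 # drop the low-probability tail, then renormalize
next_token = torch.multinomial(probs, num_samples=1)  # sample one image-token id, as in sampling()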
/story-dalle/dalle/utils/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------
2 | # Minimal DALL-E
3 | # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------
6 |
7 | import os
8 | import random
9 | import urllib
10 | import hashlib
11 | import tarfile
12 | import torch
13 | import clip
14 | import numpy as np
15 | from PIL import Image
16 | from torch.nn import functional as F
17 | from tqdm import tqdm
18 | import torchvision.utils as vutils
19 | import matplotlib.pyplot as plt
20 |
21 |
22 | def set_seed(seed: int):
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
27 |
28 |
29 | @torch.no_grad()
30 | def clip_score(prompt: str,
31 | images: np.ndarray,
32 | model_clip: torch.nn.Module,
33 | preprocess_clip,
34 | device: str) -> np.ndarray:
35 | images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
36 | images = torch.stack(images, dim=0).to(device=device)
37 | texts = clip.tokenize(prompt).to(device=device)
38 | texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
39 |
40 | image_features = model_clip.encode_image(images)
41 | text_features = model_clip.encode_text(texts)
42 |
43 | scores = F.cosine_similarity(image_features, text_features).squeeze()
44 | rank = torch.argsort(scores, descending=True).cpu().numpy()
45 | return rank
46 |
47 |
48 | def download(url: str, root: str) -> str:
49 | os.makedirs(root, exist_ok=True)
50 | filename = os.path.basename(url)
51 | pathname = filename[:-len('.tar.gz')]
52 |
53 | expected_md5 = url.split("/")[-2]
54 | download_target = os.path.join(root, filename)
55 | result_path = os.path.join(root, pathname)
56 |
57 | if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
58 | return result_path
59 |
60 | with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
61 | with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
62 | unit_divisor=1024) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
72 |         raise RuntimeError('Model has been downloaded but the md5 checksum does not match')
73 |
74 | with tarfile.open(download_target, 'r:gz') as f:
75 | pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
76 | for member in pbar:
77 | pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
78 | f.extract(member=member, path=root)
79 |
80 | return result_path
81 |
82 |
83 | def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
84 | if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
85 | return download(url_or_path, root)
86 | return url_or_path
87 |
88 |
89 | def images_to_numpy(tensor):
90 | generated = tensor.data.cpu().numpy().transpose(1,2,0)
91 | generated[generated < -1] = -1
92 | generated[generated > 1] = 1
93 | generated = (generated + 1) / 2 * 255
94 | return generated.astype('uint8')
95 |
96 |
97 | def save_image(ground_truth, images, out_dir, batch_idx):
98 |
99 | for i, im in enumerate(images):
100 | if len(im.shape) == 3:
101 | plt.imsave(os.path.join(out_dir, 'test_%s_%s.png' % (batch_idx, i)), im)
102 | else:
103 | bs = im.shape[0]
104 | # plt.imsave()
105 | for j in range(bs):
106 | plt.imsave(os.path.join(out_dir, 'test_%s_%s_%s.png' % (batch_idx, i, j)), im[j])
107 |
108 |
109 | # print("Ground truth Images shape: ", ground_truth.shape, len(images))
110 |
111 | # images = vutils.make_grid(images, nrow=ground_truth.shape[0])
112 | # images = images_to_numpy(images)
113 | #
114 | # if ground_truth is not None:
115 | # ground_truth = vutils.make_grid(ground_truth, 5)
116 | # ground_truth = images_to_numpy(ground_truth)
117 | # print("Ground Truth shape, Generated Images shape: ", ground_truth.shape, images.shape)
118 | # images = np.concatenate([ground_truth, images], axis=0)
119 | #
120 | # output = Image.fromarray(images)
121 | # output.save('%s/fake_samples_epoch_%03d.png' % (out_dir, batch_idx))
122 |
123 | # if texts is not None:
124 | # fid = open('%s/fake_samples_epoch_%03d_%03d.txt' % (image_dir, epoch, idx), 'w')
125 | # for idx in range(images.shape[0]):
126 | # fid.write(str(idx) + '--------------------------------------------------------\n')
127 | # for i in range(len(texts)):
128 | # fid.write(texts[i][idx] + '\n')
129 | # fid.write('\n\n')
130 | # fid.close()
131 | return
--------------------------------------------------------------------------------
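Usage sketch (editorial note, not part of the repository) for the seeding and CLIP re-ranking helpers above; the CLIP model name and array shapes are illustrative, and the import path is an assumption.

import numpy as np
import torch
import clip
from dalle.utils.utils import set_seed, clip_score  # assumed import path

set_seed(42)  # seeds python's random, numpy and torch for reproducible sampling

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_clip, preprocess_clip = clip.load('ViT-B/32', device=device)
images = np.random.rand(4, 256, 256, 3)              # stand-in for generated samples scaled to [0, 1]
rank = clip_score('a red bus on the street', images, model_clip, preprocess_clip, device)
best_image = images[rank[0]]                         # highest CLIP image-text similarity comes first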
/story-dalle/didemo_dataloader.py:
--------------------------------------------------------------------------------
1 | import os, json, pickle
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch.utils.data
5 | from torchvision import transforms
6 | from collections import Counter
7 | import nltk
8 | from PIL import Image
9 |
10 | class ImageDataset(torch.utils.data.Dataset):
11 | def __init__(self, img_folder, tokenizer, preprocess, mode='train'):
12 |
13 | self.lengths = []
14 | self.followings = []
15 | self.dir_path = img_folder
16 | if mode == 'train':
17 | self.file_path = os.path.join(img_folder, 'train.json')
18 | elif mode == 'val':
19 | self.file_path = os.path.join(img_folder, 'val.json')
20 | else:
21 | self.file_path = os.path.join(img_folder, 'test.json')
22 |
23 | min_len = 4
24 | self.total_frames = 0
25 | self.images = []
26 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists(
27 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')):
28 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1')
29 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'))
30 | else:
31 | print("Building image list cache")
32 | data = json.load(open(self.file_path))
33 | for ex in data:
34 | # Set the first image as the frame in the first description
35 | dont_use = False
36 | for tup in ex['desc_to_frame']:
37 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])):
38 | dont_use = True
39 | if dont_use:
40 | continue
41 | self.images.append(ex['desc_to_frame'][0][1])
42 | # Set remaining images to the rest of the images
43 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]]
44 | self.followings.append(following_imgs)
45 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images)
46 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings)
47 | print("Total number of clips {}".format(len(self.images)))
48 |
49 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']}
50 | self.preprocess = preprocess
51 | self.tokenizer = tokenizer
52 |
53 | def __getitem__(self, item):
54 |
55 | src_img_path = self.images[item]
56 |
57 | src_image_raw = Image.open(os.path.join(os.path.dirname(self.dir_path), src_img_path)).convert('RGB')
58 | src_image = self.preprocess(src_image_raw)
59 | # open the target image and caption
60 | caption = self.descriptions_original[src_img_path]
61 | tokens = self.tokenizer.encode(caption)
62 | tokens = torch.LongTensor(tokens.ids)
63 |
64 | return src_image, tokens
65 |
66 | def __len__(self):
67 | return len(self.images)
68 |
69 |
70 | class StoryImageDataset(torch.utils.data.Dataset):
71 | def __init__(self, img_folder, tokenizer, transform=None, mode='train'):
72 |
73 | self.lengths = []
74 | self.followings = []
75 | self.dir_path = img_folder
76 | if mode == 'train':
77 | self.file_path = os.path.join(img_folder, 'train.json')
78 | elif mode == 'val':
79 | self.file_path = os.path.join(img_folder, 'val.json')
80 | else:
81 | self.file_path = os.path.join(img_folder, 'test.json')
82 |
83 | min_len = 2
84 | self.total_frames = 0
85 | self.images = []
86 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists(
87 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')):
88 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1')
89 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'))
90 | else:
91 | print("Building image list cache")
92 | data = json.load(open(self.file_path))
93 | for ex in data:
94 | # Set the first image as the frame in the first description
95 | dont_use = False
96 | for tup in ex['desc_to_frame']:
97 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])):
98 | dont_use = True
99 | if dont_use:
100 | continue
101 | if len(ex['desc_to_frame']) < min_len+1:
102 | continue
103 | self.images.append(ex['desc_to_frame'][0][1])
104 | # Set remaining images to the rest of the images
105 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:1+min_len]]
106 | self.followings.append(following_imgs)
107 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images)
108 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings)
109 | print("Total number of clips {}".format(len(self.images)))
110 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']}
111 |
112 | if mode == 'train':
113 | if transform:
114 | self.transform = transform
115 | else:
116 | self.transform = transforms.Compose([
117 | transforms.RandomResizedCrop(256),
118 | transforms.RandomHorizontalFlip(),
119 | transforms.ToTensor(),
120 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
121 | ])
122 | else:
123 | if transform:
124 | self.transform = transform
125 | else:
126 | self.transform = transforms.Compose([
127 | transforms.Resize(256),
128 | transforms.CenterCrop(256),
129 | transforms.ToTensor(),
130 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
131 | ])
132 |
133 | self.descriptions_vec = pickle.load(open(os.path.join(img_folder, 'didemo_use_embeds.pkl'), 'rb'))
134 | self.tokenizer = tokenizer
135 |
136 |
137 | def __getitem__(self, item):
138 |
139 | frame_path_list = [self.images[item]]
140 | for i in range(len(self.followings[item])):
141 | frame_path_list.append(str(self.followings[item][i]))
142 |
143 | images = []
144 | tokens = []
145 |
146 | for img_path in frame_path_list:
147 | im = self.transform(Image.open(os.path.join(os.path.dirname(os.path.normpath(self.dir_path)), img_path)))
148 | images.append(im)
149 | if self.tokenizer is not None:
150 | tokens.append(self.tokenizer.encode(self.descriptions_original[img_path]))
151 | else:
152 | tokens.append(self.descriptions_original[img_path])
153 |
154 | if self.tokenizer is not None:
155 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens[1:]])
156 |
157 | sent_embeds = [torch.tensor(self.descriptions_vec[frame_path]) for frame_path in frame_path_list[1:]]
158 | return torch.stack(images[1:]), tokens, images[0], torch.stack(sent_embeds)
159 |
160 |
161 | def __len__(self):
162 | return len(self.images)
163 |
164 |
165 | class CopyImageDataset(torch.utils.data.Dataset):
166 | def __init__(self, img_folder, tokenizer, preprocess, mode='train', video_len=2):
167 |
168 | self.lengths = []
169 | self.followings = []
170 | self.dir_path = img_folder
171 | if mode == 'train':
172 | self.file_path = os.path.join(img_folder, 'train.json')
173 | elif mode == 'val':
174 | self.file_path = os.path.join(img_folder, 'val.json')
175 | else:
176 | self.file_path = os.path.join(img_folder, 'test.json')
177 | self.video_len = video_len
178 |
179 | min_len = 4
180 | self.total_frames = 0
181 | self.images = []
182 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists(
183 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')):
184 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1')
185 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'))
186 | else:
187 | print("Building image list cache")
188 | data = json.load(open(self.file_path))
189 | for ex in data:
190 | # Set the first image as the frame in the first description
191 | dont_use = False
192 | for tup in ex['desc_to_frame']:
193 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])):
194 | dont_use = True
195 | if dont_use:
196 | continue
197 | self.images.append(ex['desc_to_frame'][0][1])
198 | # Set remaining images to the rest of the images
199 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]]
200 | self.followings.append(following_imgs)
201 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images)
202 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings)
203 | print("Total number of clips {}".format(len(self.images)))
204 |
205 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']}
206 | self.preprocess = preprocess
207 | self.tokenizer = tokenizer
208 |
209 | def __getitem__(self, item):
210 |
211 | src_img_path = self.images[item]
212 | src_image_raw = Image.open(os.path.join(os.path.dirname(self.dir_path), src_img_path)).convert('RGB')
213 | src_image = self.preprocess(src_image_raw)
214 |
215 | tgt_img_paths = [str(self.followings[item][i]) for i in range(self.video_len)]
216 | # open the target image and caption
217 | tgt_images = [self.preprocess(Image.open(os.path.join(os.path.dirname(self.dir_path), tgt_img_path)).convert('RGB')) for tgt_img_path in tgt_img_paths]
218 |
219 | captions = [self.descriptions_original[tgt_img_path] for tgt_img_path in tgt_img_paths]
220 | tokens = [self.tokenizer.encode(caption) for caption in captions]
221 | tokens = [torch.LongTensor(token.ids) for token in tokens]
222 |
223 | return torch.stack(tgt_images), torch.stack(tokens), src_image
224 |
225 | def __len__(self):
226 | return len(self.images)
227 |
228 |
229 | class CopyStoryDataset(torch.utils.data.Dataset):
230 | def __init__(self, img_folder, preprocess, mode='train', max_t_len=72, video_len=5, resnet=False, condition_seq_len=0):
231 |
232 | self.lengths = []
233 | self.followings = []
234 | self.video_len = video_len
235 | min_len = video_len - 1
236 |
237 | if mode == 'train':
238 | self.file_path = os.path.join(img_folder, 'train.json')
239 | elif mode == 'val':
240 | self.file_path = os.path.join(img_folder, 'val.json')
241 | else:
242 | self.file_path = os.path.join(img_folder, 'test.json')
243 |
244 | self.dir_path = img_folder
245 |
246 | if os.path.exists(os.path.join(self.dir_path, 'dalle_vocab.pkl')):
247 | vocab_from_file = True
248 | vocab_file = os.path.join(self.dir_path, 'dalle_vocab.pkl')
249 | else:
250 | vocab_from_file = False
251 | vocab_file = os.path.join(self.dir_path, 'dalle_vocab.pkl')
252 |         # NOTE: Vocabulary is not defined or imported in this file; it is assumed to be provided elsewhere in the project.
253 | self.vocab = Vocabulary(vocab_threshold=5,
254 | vocab_file=vocab_file,
255 | annotations_file=os.path.join(self.dir_path, 'train.json'),
256 | vocab_from_file=vocab_from_file)
257 |
258 | print("Length of Vocabulary ", len(self.vocab))
259 |
260 | self.descriptions_original = {tup[1]: tup[0] for ex in json.load(open(self.file_path)) for tup in ex['desc_to_frame']}
261 |
262 | self.total_frames = 0
263 | self.images = []
264 | if os.path.exists(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy')) and os.path.exists(
265 | os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy')):
266 | self.images = np.load(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), encoding='latin1')
267 | self.followings = np.load(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'))
268 | else:
269 | print("Building image list cache")
270 | data = json.load(open(self.file_path))
271 | for ex in data:
272 | # Set the first image as the frame in the first description
273 | dont_use = False
274 | for tup in ex['desc_to_frame']:
275 | if not os.path.exists(os.path.join('/'.join(list(os.path.abspath(img_folder).split('/'))[:-1]), tup[1])):
276 | dont_use = True
277 | if dont_use:
278 | continue
279 | self.images.append(ex['desc_to_frame'][0][1])
280 | # Set remaining images to the rest of the images
281 | following_imgs = [tup[1] for tup in ex['desc_to_frame'][1:]]
282 | self.followings.append(following_imgs)
283 | np.save(os.path.join(img_folder, 'img_cache' + str(min_len) + '_' + mode + '.npy'), self.images)
284 | np.save(os.path.join(img_folder, 'following_cache' + str(min_len) + '_' + mode + '.npy'), self.followings)
285 | print("Total number of clips {}".format(len(self.images)))
286 |
287 | self.preprocess = preprocess
288 | self.max_t_len = max_t_len
289 |
290 | self.resnet = resnet
291 | im_input_size = 299
292 | if mode == 'train':
293 | self.transform = transforms.Compose([
294 | transforms.Resize((im_input_size, im_input_size)),
295 | transforms.RandomHorizontalFlip(),
296 | transforms.ToTensor(),
297 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
298 | ])
299 | else:
300 | self.transform = transforms.Compose([
301 | transforms.Resize((im_input_size, im_input_size)),
302 | transforms.RandomHorizontalFlip(),
303 | transforms.ToTensor(),
304 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
305 | ])
306 |
307 | self.condition_seq_len = condition_seq_len
308 |
309 | def __getitem__(self, item):
310 |
311 | # source image
312 | src_img_path = self.images[item].replace('didemo/', '')
313 | src_image_raw = Image.open(os.path.join(self.dir_path, src_img_path)).convert('RGB')
314 | # open the source images
315 | if self.resnet:
316 | src_image = self.transform(src_image_raw)
317 | else:
318 | src_image = self.preprocess(src_image_raw)
319 |
320 | # source caption
321 | src_caption = self.descriptions_original['didemo/' + src_img_path]
322 | src_text_tokens, src_text_mask = self.vocab._tokenize_pad_sentence(str(src_caption).lower(), self.max_t_len,
323 | condition=self.condition_seq_len)
324 |
325 | tgt_images = []
326 | tgt_text_tokens = [src_text_tokens]
327 | tgt_text_masks = [src_text_mask]
328 | for i in range(0, self.video_len-1):
329 | tgt_img_path = str(self.followings[item][i]).replace('didemo/', '')
330 | # open the target image and caption
331 | caption = self.descriptions_original['didemo/' + tgt_img_path]
332 | tgt_image = self.preprocess(Image.open(os.path.join(self.dir_path, tgt_img_path)).convert('RGB'))
333 | tgt_images.append(tgt_image.unsqueeze(0))
334 | # image = Image.open(os.path.join(self.out_dir, 'img-' + str(item) + '.png')).convert('RGB')
335 | text_tokens, text_mask = self.vocab._tokenize_pad_sentence(str(caption).lower(), self.max_t_len, condition=self.condition_seq_len)
336 | tgt_text_tokens.append(text_tokens)
337 | tgt_text_masks.append(text_mask)
338 |
339 | return torch.cat(tgt_images, dim=0), torch.LongTensor(tgt_text_tokens), torch.LongTensor(tgt_text_masks), src_image
340 |
341 | def __len__(self):
342 | return len(self.images)
343 |
344 |
345 | if __name__ == "__main__":
346 |
347 | dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/didemo', None, None, 'val')
348 |
349 | all_captions = {}
350 | for item in range(len(dataset)):
351 |
352 | frame_path_list = [dataset.images[item]]
353 | for i in range(len(dataset.followings[item])):
354 | frame_path_list.append(str(dataset.followings[item][i]))
355 | captions = [dataset.descriptions_original[img_path] for img_path in frame_path_list]
356 | all_captions[item] = captions
357 |
358 | with open(os.path.join('/nas-ssd/adyasha/datasets/didemo', 'all_captions_val.json'), 'w') as f:
359 | json.dump(all_captions, f, indent=4)
--------------------------------------------------------------------------------
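Loading sketch (editorial note, not part of the repository): wrapping the StoryImageDataset above in a DataLoader. The dataset root is hypothetical and must contain the train/val/test JSONs plus didemo_use_embeds.pkl.

import torch
from torchvision import transforms
from didemo_dataloader import StoryImageDataset  # assumed import, relative to story-dalle/

transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
dataset = StoryImageDataset('/path/to/didemo', tokenizer=None, transform=transform, mode='val')
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)
# With tokenizer=None the captions come back as raw strings instead of token-id tensors.
images, captions, source_image, sent_embeds = next(iter(loader))
print(images.shape, source_image.shape, sent_embeds.shape)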
/story-dalle/eval_char_clf.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | import torch.nn as nn
4 | from torchvision import models
5 | import torch
6 | from tqdm import tqdm
7 | import numpy as np
8 | from scipy.stats import entropy
9 | import os
10 | import PIL
11 | import torchvision.utils as vutils
12 | import argparse
13 | from sklearn.metrics import classification_report, accuracy_score
14 | from torchvision import transforms
15 |
16 | epsilon = 1e-7
17 |
18 |
19 |
20 | def set_parameter_requires_grad(model, feature_extracting):
21 | if feature_extracting:
22 | for param in model.parameters():
23 | param.requires_grad = False
24 |
25 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
26 | # Initialize these variables which will be set in this if statement. Each of these
27 | # variables is model specific.
28 | model_ft = None
29 | input_size = 0
30 |
31 | if model_name == "resnet":
32 | """ Resnet18
33 | """
34 | model_ft = models.resnet18(pretrained=use_pretrained)
35 | set_parameter_requires_grad(model_ft, feature_extract)
36 | num_ftrs = model_ft.fc.in_features
37 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
38 | input_size = 224
39 |
40 | if model_name == "resnet50":
41 | """ Resnet50
42 | """
43 | model_ft = models.resnet50(pretrained=use_pretrained)
44 | set_parameter_requires_grad(model_ft, feature_extract)
45 | num_ftrs = model_ft.fc.in_features
46 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
47 | input_size = 224
48 |
49 | elif model_name == "resnet101":
50 |         """ Resnet101
51 | """
52 | model_ft = models.resnet101(pretrained=use_pretrained)
53 | set_parameter_requires_grad(model_ft, feature_extract)
54 | num_ftrs = model_ft.fc.in_features
55 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
56 | input_size = 224
57 |
58 | elif model_name == "alexnet":
59 | """ Alexnet
60 | """
61 | model_ft = models.alexnet(pretrained=use_pretrained)
62 | set_parameter_requires_grad(model_ft, feature_extract)
63 | num_ftrs = model_ft.classifier[6].in_features
64 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
65 | input_size = 224
66 |
67 | elif model_name == "vgg":
68 | """ VGG11_bn
69 | """
70 | model_ft = models.vgg11_bn(pretrained=use_pretrained)
71 | set_parameter_requires_grad(model_ft, feature_extract)
72 | num_ftrs = model_ft.classifier[6].in_features
73 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
74 | input_size = 224
75 |
76 | elif model_name == "squeezenet":
77 | """ Squeezenet
78 | """
79 | model_ft = models.squeezenet1_0(pretrained=use_pretrained)
80 | set_parameter_requires_grad(model_ft, feature_extract)
81 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
82 | model_ft.num_classes = num_classes
83 | input_size = 224
84 |
85 |
86 | elif model_name == "densenet":
87 | """ Densenet
88 | """
89 | model_ft = models.densenet121(pretrained=use_pretrained)
90 | set_parameter_requires_grad(model_ft, feature_extract)
91 | num_ftrs = model_ft.classifier.in_features
92 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
93 | input_size = 224
94 |
95 | elif model_name == "inception":
96 | """ Inception v3
97 | Be careful, expects (299,299) sized images and has auxiliary output
98 | """
99 | model_ft = models.inception_v3(pretrained=use_pretrained)
100 | set_parameter_requires_grad(model_ft, feature_extract)
101 | # Handle the auxilary net
102 | num_ftrs = model_ft.AuxLogits.fc.in_features
103 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
104 | # Handle the primary net
105 | num_ftrs = model_ft.fc.in_features
106 | model_ft.fc = nn.Linear(num_ftrs,num_classes)
107 | input_size = 299
108 |
109 | else:
110 | print("Invalid model name, exiting...")
111 | exit()
112 |
113 | return model_ft, input_size
114 |
115 | def images_to_numpy(tensor):
116 | generated = tensor.data.cpu().numpy().transpose(1, 2, 0)
117 | generated[generated < -1] = -1
118 | generated[generated > 1] = 1
119 | generated = (generated + 1) / 2 * 255
120 | return generated.astype('uint8')
121 |
122 |
123 | def evaluate_gt(args):
124 |     root_image_dir, model_name, model_path = args.img_ref_dir, args.model_name, args.model_path
125 | if args.dataset == 'pororo':
126 | from pororo_dataloader import ImageDataset, StoryImageDataset
127 | elif args.dataset == 'flintstones':
128 | from flintstones_dataloader import ImageDataset, StoryImageDataset
129 | else:
130 | raise ValueError
131 |
132 | # Number of classes in the dataset
133 | num_classes = 9
134 | # when True we only update the reshaped layer params
135 | feature_extract = False
136 | video_len = 5
137 | n_channels = 3
138 |
139 | running_corrects = 0
140 | running_recalls = 0
141 | total_positives = 0
142 |
143 | phase = 'eval'
144 | is_inception = True if model_name == 'inception' else False
145 |
146 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=False)
147 | model_ft.load_state_dict(torch.load(model_path))
148 |
149 | # Detect if we have a GPU available
150 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
151 | # Send the model to GPU
152 | model_ft = model_ft.to(device)
153 | model_ft.eval() # Set model to evaluate mode
154 | image_dataset = ImageDataset(root_image_dir, input_size, mode='val')
155 | print("Number of samples in evaluation set: %s" % len(image_dataset))
156 | batch_size = 32
157 |
158 | # Create validation dataloaders
159 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
160 |
161 | print("Number of batches in evaluation dataloader: %s" % len(dataloader))
162 |
163 | all_predictions = []
164 | all_labels = []
165 | story_accuracy = 0
166 | image_accuracy = 0
167 |
168 | # Iterate over data.
169 | for i, (inputs, labels) in tqdm(enumerate(dataloader)):
170 |
171 | inputs = inputs.to(device)
172 | labels = labels.to(device)
173 |
174 | with torch.set_grad_enabled(phase == 'train'):
175 |
176 | outputs = model_ft(inputs)
177 | if model_name == 'imgD':
178 | outputs = model_ft.cate_classify(outputs).squeeze()
179 |             preds = torch.round(torch.sigmoid(outputs))
180 | all_predictions.append(preds.cpu().numpy())
181 | all_labels.append(labels.cpu().numpy())
182 |
183 | # statistics
184 | iter_corrects = torch.sum(preds == labels.float().data)
185 | xidxs, yidxs = torch.where(labels.data == 1)
186 | # print(xidxs, yidxs)
187 | # print([labels.data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)])
188 | iter_recalls = sum(
189 | [x.item() for x in
190 | [labels.float().data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]])
191 | total_positives += xidxs.size(0)
192 |
193 | for l, p in zip(labels, preds):
194 | if torch.all(torch.eq(l.float().data, p)):
195 | image_accuracy += 1
196 |
197 | running_corrects += iter_corrects
198 | running_recalls += iter_recalls
199 |
200 | epoch_acc = running_corrects * 100 / (len(image_dataset) * num_classes)
201 | epoch_recall = running_recalls * 100 / total_positives
202 | print('{} Acc: {:.4f} Recall: {:.4f}%'.format(phase, epoch_acc, epoch_recall))
203 | print('{} Story Exact Match Acc: {:.4f}%'.format(phase, float(story_accuracy) * 100 / len(image_dataset)))
204 | print('{} Image Exact Match Acc: {:.4f}%'.format(phase, float(image_accuracy) * 100 / len(image_dataset)))
205 |
206 | all_predictions = np.concatenate(all_predictions, axis=0)
207 | all_labels = np.concatenate(all_labels, axis=0)
208 | print(all_predictions.shape, all_labels.shape, image_accuracy, len(image_dataset))
209 |     # all_predictions were already binarized with sigmoid + round above; report them directly
210 |     print(classification_report(all_labels, all_predictions, digits=4))
211 |
212 | # for i in range(0, 9):
213 | # print("Character %s" % i)
214 | # print(classification_report(all_labels[:, i], preds[:, i]))
215 |
216 | # Inception Score
217 | # all_predictions = all_predictions + epsilon
218 | # py = np.mean(all_predictions, axis=0)
219 | # print(py, py.shape)
220 | # split_scores = []
221 | # splits = 10
222 | # N = all_predictions.shape[0]
223 | # for k in range(splits):
224 | # part = all_predictions[k * (N // splits): (k + 1) * (N // splits), :]
225 | # py = np.mean(part, axis=0)
226 | # scores = []
227 | #
228 | # for i in range(part.shape[0]):
229 | # pyx = part[i, :]
230 | # scores.append(entropy(pyx, py))
231 | # split_scores.append(np.exp(np.mean(scores)))
232 | # print("InceptionScore", np.mean(split_scores), np.std(split_scores))
233 |
234 |
235 | def evaluate(args):
236 |
237 | root_image_dir, model_name, model_path = args.img_ref_dir, args.model_name, args.model_path
238 |
239 | if args.dataset == 'pororo':
240 | from pororo_dataloader import ImageDataset, StoryImageDataset
241 | elif args.dataset == 'flintstones':
242 | from flintstones_dataloader import ImageDataset, StoryImageDataset
243 | else:
244 | raise ValueError
245 |
246 | # when True we only update the reshaped layer params
247 | feature_extract = False
248 | video_len = 5
249 | n_channels = 3
250 |
251 | phase = 'eval'
252 |
253 | model_ft, input_size = initialize_model(model_name, args.num_classes, feature_extract, use_pretrained=False)
254 | model_ft.load_state_dict(torch.load(model_path))
255 |
256 | img_transform = transforms.Compose([
257 | # Image.fromarray,
258 | transforms.Resize(input_size),
259 | transforms.CenterCrop(input_size),
260 | transforms.ToTensor(),
261 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
262 | ])
263 |
264 | # Detect if we have a GPU available
265 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
266 | # Send the model to GPU
267 | model_ft = model_ft.to(device)
268 | model_ft.eval() # Set model to evaluate mode
269 |
270 | # Create training and validation datasets
271 | try:
272 | image_dataset = StoryImageDataset(args.img_ref_dir, None, preprocess=img_transform,
273 | mode=args.mode,
274 | out_img_folder=args.img_gen_dir,
275 | return_labels=True)
276 | except TypeError:
277 | image_dataset = StoryImageDataset(args.img_ref_dir, None, transform=img_transform,
278 | mode=args.mode,
279 | out_img_folder=args.img_gen_dir,
280 | return_labels=True)
281 | print("Number of samples in evaluation set: %s" % len(image_dataset))
282 | batch_size = 20
283 |
284 | # Create validation dataloaders
285 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)
286 |
287 | print("Number of batches in evaluation dataloader: %s" % len(dataloader))
288 |
289 | all_predictions = []
290 | all_labels = []
291 | story_accuracy = 0
292 | image_accuracy = 0
293 |
294 | running_corrects = 0
295 | running_recalls = 0
296 | total_positives = 0
297 |
298 | # Iterate over data.
299 | no_char_images = 0
300 | for i, batch in tqdm(enumerate(dataloader)):
301 |
302 | inputs = batch[0]
303 | labels = batch[1]
304 |
305 | inputs = inputs.view(-1, n_channels, inputs.shape[-2], inputs.shape[-1])
306 | labels = labels.view(-1, labels.shape[-1])
307 | assert inputs.shape[0] == labels.shape[0]
308 | inputs = inputs.to(device)
309 | labels = labels.to(device)
310 |
311 | # forward
312 | # track history if only in train
313 | with torch.no_grad():
314 | # Get model outputs and calculate loss
315 | # Special case for inception because in training it has an auxiliary output. In train
316 | # mode we calculate the loss by summing the final output and the auxiliary output
317 | # but in testing we only consider the final output.
318 | outputs = model_ft(inputs)
319 |             preds = torch.round(torch.sigmoid(outputs))
320 | all_predictions.append(preds.cpu().numpy())
321 | all_labels.append(labels.cpu().numpy())
322 |
323 | # statistics
324 | iter_corrects = torch.sum(preds == labels.float().data)
325 | xidxs, yidxs = torch.where(labels.data == 1)
326 | # print(xidxs, yidxs)
327 | # print([labels.data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)])
328 | iter_recalls = sum(
329 | [x.item() for x in [labels.float().data[xidx, yidx] == preds[xidx, yidx] for xidx, yidx in zip(xidxs, yidxs)]])
330 | total_positives += xidxs.size(0)
331 |
332 | labels = labels.view(-1, labels.shape[-1])
333 | preds = preds.view(-1, labels.shape[-1])
334 | assert labels.shape[0] == preds.shape[0]
335 |
336 |
337 | for label, pred in zip(labels, preds):
338 | if not torch.any(label):
339 | no_char_images += 1
340 | if torch.all(torch.eq(label.float().data, pred)):
341 | image_accuracy += 1
342 |
343 | running_corrects += iter_corrects
344 | running_recalls += iter_recalls
345 |
346 |     print("Frames with no characters: ", no_char_images)
347 |
348 | all_predictions = np.concatenate(all_predictions, axis=0)
349 | all_labels = np.concatenate(all_labels, axis=0)
350 | print(all_predictions.shape, all_labels.shape, image_accuracy, len(image_dataset))
351 | # preds = np.round(1 / (1 + np.exp(-all_predictions)))
352 | print(classification_report(all_labels, all_predictions, digits=4))
353 | print("Accuracy: ", accuracy_score(all_labels, all_predictions))
354 |
355 | epoch_acc = float(running_corrects) * 100 / (all_labels.shape[0] * all_labels.shape[1])
356 | epoch_recall = float(running_recalls) * 100 / total_positives
357 | print('Manually calculated accuracy: ', epoch_acc)
358 | print('{} Acc: {:.4f} Recall: {:.4f}%'.format(phase, accuracy_score(all_labels, all_predictions), epoch_recall))
359 | print('{} Image Exact Match (Frame) Acc: {:.4f}%'.format(phase, image_accuracy * 100 / all_labels.shape[0]))
360 |
361 |
362 | if __name__ == "__main__":
363 |
364 | parser = argparse.ArgumentParser(description='Evaluate for Character Recall & InceptionScore')
365 | parser.add_argument('--dataset', type=str, default='pororo')
366 | parser.add_argument('--img_ref_dir', type=str, required=True)
367 | parser.add_argument('--img_gen_dir', type=str, required=True)
368 | parser.add_argument('--num_classes', type=int, required=True)
369 | parser.add_argument('--model_path', type=str, required=True)
370 | parser.add_argument('--model_name', type=str, required=True)
371 | parser.add_argument('--mode', type=str, required=True)
372 | parser.add_argument('--ground_truth', action='store_true')
373 | args = parser.parse_args()
374 |
375 | if args.ground_truth:
376 | evaluate_gt(args)
377 | else:
378 | evaluate(args)
379 |
380 | # numpy_to_img(os.path.join(args.image_dir, 'images-epoch-%s.npy' % args.epoch_start),
381 | # os.path.join(args.image_dir, 'images-epoch-%s/' % args.epoch_start), 299)
382 |
383 |
--------------------------------------------------------------------------------
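Inference sketch (editorial note, not part of the repository) for the classifier builder above; the dummy batch and the 9-class Pororo setting are illustrative.

import torch
from eval_char_clf import initialize_model  # assumed import, relative to story-dalle/

# Inception-v3 backbone with a 9-way multi-label head (one logit per Pororo character).
model, input_size = initialize_model('inception', 9, feature_extract=False, use_pretrained=True)
model.eval()
images = torch.randn(2, 3, input_size, input_size)   # dummy batch at the expected 299x299 resolution
with torch.no_grad():
    logits = model(images)
preds = torch.round(torch.sigmoid(logits))            # per-character presence predictions, as in evaluate()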
/story-dalle/eval_char_clf.sh:
--------------------------------------------------------------------------------
1 | #python eval_char_clf.py --dataset pororo --img_ref_dir /nas-ssd/adyasha/datasets/pororo_png/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/pororo/test_images/images/ --model_name inception --model_path ./out/pororo-epoch-10.pt --mode test --num_classes 9
2 |
3 | python eval_char_clf.py --dataset flintstones --img_ref_dir /nas-ssd/adyasha/datasets/flintstones/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/flintstones/test_images/images/ --model_name inception --model_path /playpen-ssd/adyasha/projects/StoryGAN/classifier/models/inception_32_1e-05/best.pt --mode test --num_classes 7
4 |
--------------------------------------------------------------------------------
/story-dalle/eval_fid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import PIL
4 | import argparse
5 | import functools
6 | import os
7 | from vfid.fid_score import fid_score
8 | import torchvision.datasets as datasets
9 |
10 | def main(args):
11 |
12 |
13 | image_transforms = transforms.Compose([
14 | transforms.Resize((args.imsize, args.imsize)),
15 | # transforms.RandomHorizontalFlip(),
16 | transforms.ToTensor(),
17 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
18 |
19 | if args.task == 'pororo':
20 | import pororo_dataloader as data
21 | elif args.task == 'flintstones':
22 | import flintstones_dataloader as data
23 | elif args.task == 'didemo':
24 | import didemo_dataloader as data
25 | else:
26 | raise ValueError
27 |
28 | try:
29 | ref_dataset = data.StoryImageDataset(args.img_ref_dir,
30 | None,
31 | preprocess=image_transforms,
32 | mode=args.mode)
33 | except TypeError:
34 | ref_dataset = data.StoryImageDataset(args.img_ref_dir,
35 | None,
36 | transform=image_transforms,
37 | mode=args.mode)
38 |
39 | gen_dataset = datasets.ImageFolder(root=args.img_gen_dir, transform=image_transforms)
40 |
41 |
42 | fid = fid_score(ref_dataset, gen_dataset, cuda=True, normalize=True, r_cache=os.path.join(args.img_ref_dir, 'fid_cache_%s.npz' % args.mode), batch_size=1)
43 | print('Frechet Image Distance: ', fid)
44 |
45 |
46 | if __name__ == "__main__":
47 |
48 | parser = argparse.ArgumentParser(description='Evaluate Frechet Story and Image distance')
49 | parser.add_argument('--img_ref_dir', type=str, required=True)
50 | parser.add_argument('--img_gen_dir', type=str, required=True)
51 | parser.add_argument('--mode', type=str, default='test')
52 | parser.add_argument('--task', type=str, default='pororo')
53 | parser.add_argument('--imsize', type=int, default=64)
54 | args = parser.parse_args()
55 |
56 | print(args)
57 | main(args)
58 |
--------------------------------------------------------------------------------
/story-dalle/eval_fid.sh:
--------------------------------------------------------------------------------
1 | python eval_fid.py --task pororo --img_ref_dir /nas-ssd/adyasha/datasets/pororo/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/pororo/test_images/ --mode test
2 |
3 | python eval_fid.py --task didemo --img_ref_dir /nas-ssd/adyasha/datasets/didemo/ --img_gen_dir /nas-ssd/adyasha/out/minDALLEs/didemo/test_images/ --mode test
--------------------------------------------------------------------------------
/story-dalle/flintstones_dataloader.py:
--------------------------------------------------------------------------------
1 | import os, pickle
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch.utils.data
5 | import PIL
6 | from random import randrange
7 | import json
8 | from torchvision import transforms
9 | from PIL import Image
10 |
11 | unique_characters = ["Wilma", "Fred", "Betty", "Barney", "Dino", "Pebbles", "Mr Slate"]
12 |
13 | class ImageDataset(torch.utils.data.Dataset):
14 | def __init__(self, dir_path, tokenizer, preprocess, mode='train'):
15 | self.dir_path = dir_path
16 |
17 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r'))
18 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"]
19 |
20 | if mode == 'train':
21 | self.orders = train_id
22 | elif mode =='val':
23 | self.orders = val_id
24 | elif mode == 'test':
25 | self.orders = test_id
26 | else:
27 | raise ValueError
28 | print("Total number of clips {}".format(len(self.orders)))
29 |
30 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')))
31 | self.descriptions = {}
32 | for sample in annotations:
33 | self.descriptions[sample["globalID"]] = sample["description"]
34 |
35 | self.preprocess = preprocess
36 | self.tokenizer = tokenizer
37 |
38 | def __getitem__(self, item):
39 |
40 | # single image input
41 | globalID = self.orders[item]
42 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy')
43 | arr = np.load(path)
44 | n_frames = arr.shape[0]
45 | random_range = randrange(n_frames)
46 | im = arr[random_range]
47 | image = np.array(im)
48 | image = PIL.Image.fromarray(image.astype('uint8'), 'RGB')
49 | text = self.descriptions[globalID]
50 | tokens = self.tokenizer.encode(text.lower())
51 | tokens = torch.LongTensor(tokens.ids)
52 | image = self.preprocess(image)
53 |
54 | return image, tokens
55 |
56 | def __len__(self):
57 | return len(self.orders)
58 |
59 |
60 | class StoryImageDataset(torch.utils.data.Dataset):
61 | def __init__(self, dir_path, tokenizer, transform=None, mode='train', im_input_size=128, out_img_folder='', return_labels=False):
62 |         self.dir_path = dir_path
63 |         self.followings = {}
64 | splits = json.load(open(os.path.join(self.dir_path, 'train-val-test_split.json'), 'r'))
65 | train_id, val_id, test_id = splits["train"], splits["val"], splits["test"]
66 |
67 | min_len = 4
68 | if os.path.exists(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl')):
69 | self.followings = pickle.load(open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'rb'))
70 | else:
71 | print("Cache does not exist")
72 | all_clips = train_id + val_id + test_id
73 | all_clips.sort()
74 | for idx, clip in enumerate(tqdm(all_clips, desc="Counting total number of frames")):
75 | season, episode = int(clip.split('_')[1]), int(clip.split('_')[3])
76 | has_frames = True
77 | for c in all_clips[idx+1:idx+min_len+1]:
78 | s_c, e_c = int(c.split('_')[1]), int(c.split('_')[3])
79 | if s_c != season or e_c != episode:
80 | has_frames = False
81 | break
82 | if has_frames:
83 | self.followings[clip] = all_clips[idx+1:idx+min_len+1]
84 | else:
85 | continue
86 | pickle.dump(self.followings, open(os.path.join(self.dir_path, 'following_cache' + str(min_len) + '.pkl'), 'wb'))
87 |
88 | train_id = [tid for tid in train_id if tid in self.followings]
89 | val_id = [vid for vid in val_id if vid in self.followings]
90 | test_id = [tid for tid in test_id if tid in self.followings]
91 |
92 | self.labels = pickle.load(open(os.path.join(dir_path, 'labels.pkl'), 'rb'))
93 |
94 | if mode == 'train':
95 | self.orders = train_id
96 | elif mode =='val':
97 | val_id = [vid for vid in val_id if len(self.followings[vid]) == 4]
98 | self.orders = val_id
99 | elif mode == 'test':
100 | test_id = [vid for vid in test_id if len(self.followings[vid]) == 4]
101 | self.orders = test_id[:1900]
102 | else:
103 | raise ValueError
104 | print("Total number of clips {}".format(len(self.orders)))
105 |
106 | annotations = json.load(open(os.path.join(self.dir_path, 'flintstones_annotations_v1-0.json')))
107 | self.descriptions = {}
108 | for sample in annotations:
109 | self.descriptions[sample["globalID"]] = sample["description"]
110 |
111 | self.embeds = np.load(os.path.join(self.dir_path, "flintstones_use_embeddings.npy"))
112 | self.sent2idx = pickle.load(open(os.path.join(self.dir_path, 'flintstones_use_embed_idxs.pkl'), 'rb'))
113 |
114 | if mode == 'train':
115 | if transform:
116 | self.transform = transform
117 | else:
118 | self.transform = transforms.Compose([
119 | transforms.RandomResizedCrop(im_input_size),
120 | transforms.RandomHorizontalFlip(),
121 | transforms.ToTensor(),
122 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
123 | ])
124 | else:
125 | if transform:
126 | self.transform = transform
127 | else:
128 | self.transform = transforms.Compose([
129 | transforms.Resize(im_input_size),
130 | transforms.CenterCrop(im_input_size),
131 | transforms.ToTensor(),
132 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
133 | ])
134 |
135 | self.tokenizer = tokenizer
136 | self.return_labels = return_labels
137 | self.out_img_folder = out_img_folder
138 |
139 | def __getitem__(self, item):
140 |
141 | # single image input
142 | globalIDs = [self.orders[item]] + self.followings[self.orders[item]]
143 | tokens = []
144 | images = []
145 | for idx, globalID in enumerate(globalIDs):
146 | if self.out_img_folder and idx != 0:
147 | image = Image.open(os.path.join(self.out_img_folder, 'gen_sample_%s_%s.png' % (item, idx-1))).convert('RGB')
148 | else:
149 | path = os.path.join(self.dir_path, 'video_frames_sampled', globalID + '.npy')
150 | arr = np.load(path)
151 | n_frames = arr.shape[0]
152 | random_range = randrange(n_frames)
153 | im = arr[random_range]
154 | # image = np.array(im)
155 | image = PIL.Image.fromarray(im.astype('uint8'), 'RGB')
156 | images.append(image)
157 | text = self.descriptions[globalID]
158 | if idx != 0:
159 | if self.tokenizer is not None:
160 | tokens.append(self.tokenizer.encode(text.lower()))
161 | else:
162 | tokens.append(text)
163 | if self.tokenizer is not None:
164 | tokens = torch.stack([torch.LongTensor(token.ids) for token in tokens])
165 |
166 | sent_embeds = [torch.tensor(self.embeds[self.sent2idx[globalID]]) for globalID in globalIDs[1:]]
167 |
168 | if self.return_labels:
169 | labels = [torch.tensor(self.labels[globalID]) for globalID in globalIDs[1:]]
170 | return torch.stack([self.transform(im) for im in images[1:]]), torch.stack(labels), tokens, self.transform(
171 | images[0]), torch.stack(sent_embeds)
172 | else:
173 | return torch.stack([self.transform(im) for im in images[1:]]), tokens, self.transform(images[0]), torch.stack(sent_embeds)
174 |
175 | def __len__(self):
176 | return len(self.orders)
177 |
178 |
179 | # if __name__ == "__main__":
180 | #
181 | # dataset = StoryImageDataset('/nas-ssd/adyasha/datasets/flintstones', None, None, 'val')
182 | # for item in range(len(dataset)):
183 | # texts = []
184 | # globalIDs = [dataset.orders[item]] + dataset.followings[dataset.orders[item]]
185 | # for idx, globalID in enumerate(globalIDs):
186 | # text = dataset.descriptions[globalID]
187 | # texts.append(text)
188 | # if len(texts) != 5:
189 | # print(item, globalIDs)
190 |
191 |
--------------------------------------------------------------------------------
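A minimal sanity check for the StoryImageDataset above (a sketch mirroring its commented-out __main__ block; the data directory is a placeholder and it assumes the Flintstones files referenced in __init__ are present locally):

    from flintstones_dataloader import StoryImageDataset

    # '<flintstones-data-dir>' is a placeholder for the local dataset path;
    # the two None arguments reuse the defaults shown in the commented-out block
    dataset = StoryImageDataset('<flintstones-data-dir>', None, None, 'val')
    for item in range(len(dataset)):
        # each 'val' entry is a source clip plus its 4 continuation clips
        globalIDs = [dataset.orders[item]] + dataset.followings[dataset.orders[item]]
        texts = [dataset.descriptions[gid] for gid in globalIDs]
        assert len(texts) == 5, (item, globalIDs)
--------------------------------------------------------------------------------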
/story-dalle/get_use_embeddings.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import pickle
4 | from tqdm import tqdm
5 | import numpy as np
6 | import sys
7 | import os
8 |
9 | def get_embeddings(dataset='didemo', data_dir=''):
10 |
11 |
12 | if dataset == 'flintstones':
13 | annotations = json.load(open('../flintstones/flintstones_annotations_v1-0.json', 'r'))
14 | globalIDs = [s["globalID"] for s in annotations]
15 | descriptions = [s["description"] for s in annotations]
16 | elif dataset == 'didemo':
17 | descriptions = {}
18 | for filepath in ['train.json', 'val.json', 'test.json']:
19 | d = {tup[1]: tup[0] for ex in json.load(open(os.path.join(data_dir, filepath))) for tup in ex['desc_to_frame']}
20 | descriptions.update(d)
21 | elif dataset == 'mpii':
22 | all_keys = []
23 | descriptions = {}
24 | for filepath in ['train.json', 'val.json', 'test.json']:
25 | data = json.load(open(os.path.join(data_dir, filepath)))
26 | print(len(data))
27 | for ex in tqdm(data):
28 | for tup in ex['desc_to_frame']:
29 | k = tup[1].replace('/ssd-playpen/dhannan/StoryDatasets/mpii/', '')
30 | all_keys.append(k)
31 | descriptions[k] = tup[0]
32 | print(len(descriptions), len(all_keys), len(set(all_keys)))
33 | print(set(all_keys))
34 | else:
35 | raise ValueError
36 |     sys.exit()  # NOTE: exits before the embedding extraction below runs; remove this line to compute and save embeddings
37 |
38 | import tensorflow_hub as hub
39 | embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
40 | all_embeddings = None
41 | bs = 128
42 | description_keys = list(descriptions.keys())
43 |     for i in tqdm(range(0, math.ceil(len(description_keys)/bs)), desc="Extracting USE embeddings"):
44 | # embeddings = embed(descriptions[i*bs:(i+1)*bs]).numpy()
45 | embeddings = embed([descriptions[k] for k in description_keys[i * bs:(i + 1) * bs]]).numpy()
46 | if all_embeddings is None:
47 | all_embeddings = embeddings
48 | else:
49 | all_embeddings = np.concatenate([all_embeddings, embeddings], axis=0)
50 | print(all_embeddings.shape, len(description_keys))
51 | embeddings = {k: v for v, k in zip(all_embeddings, description_keys)}
52 | pickle.dump(embeddings, open(os.path.join(data_dir, '%s_use_embed_idxs.pkl' % dataset), 'wb'))
53 |
54 | # np.save(os.path.join(data_dir, '%s_use_embeddings.npy' % dataset), all_embeddings)
55 | # pickle.dump({key: val for val, key in enumerate(globalIDs)}, open(os.path.join(data_dir, '%s_use_embed_idxs.pkl'), 'wb'))
56 |
57 |
58 | # get_embeddings('didemo', '/nas-ssd/adyasha/datasets/didemo')
59 | get_embeddings('mpii', '/nas-ssd/adyasha/datasets/mpii')
--------------------------------------------------------------------------------
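A quick check of the dumped embeddings (a sketch; the path is a placeholder and assumes get_embeddings was run for 'didemo' with data_dir pointing at the dataset). universal-sentence-encoder-large/5 produces 512-dimensional vectors, which matches SENT_EMBED=512 in the shell scripts below:

    import pickle

    # placeholder path; the file name follows the pickle.dump call above
    embeddings = pickle.load(open('<didemo-data-dir>/didemo_use_embed_idxs.pkl', 'rb'))
    key = next(iter(embeddings))
    print(len(embeddings), embeddings[key].shape)   # each value should be a (512,) array
--------------------------------------------------------------------------------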
/story-dalle/infer_story.sh:
--------------------------------------------------------------------------------
1 | if [ "$1" = "pororo" ]; then
2 | echo "Evaluating on Pororo"
3 | DATA_DIR=../data/pororo
4 | OUTPUT_ROOT=../out/pororo
5 | MODEL_CKPT=''
6 | SENT_EMBED=512
7 | STORY_LEN=4
8 | elif [ "$1" = "flintstones" ]; then
9 | echo "Evaluating on Flintstones"
10 | DATA_DIR=../data/flintstones
11 | OUTPUT_ROOT=./out/flintstones
12 | MODEL_CKPT=''
13 | SENT_EMBED=512
14 | STORY_LEN=4
15 | elif [ "$1" = "didemo" ]; then
16 | echo "Evaluating on DiDeMo"
17 | DATA_DIR=../data/didemo
18 | OUTPUT_ROOT=./out/didemo
19 | MODEL_CKPT=''
20 | SENT_EMBED=512
21 | STORY_LEN=2
22 | fi
23 |
24 |
25 | python ./infer_t2i.py \
26 | --model_name_or_path $MODEL_CKPT \
27 | --prefix_model_name_or_path './1.3B/' \
28 | --dataset_name $1 \
29 | --tuning_mode story \
30 | --dataset_name $1 \
31 | --preseqlen 32 \
32 | --condition \
33 | --story_len $STORY_LEN \
34 | --sent_embed $SENT_EMBED \
35 | --prefix_dropout 0.2 \
36 | --data_dir $DATA_DIR \
37 | --dataloader_num_workers 1 \
38 | --do_eval \
39 | --per_gpu_eval_batch_size 16 \
40 | --output_dir $OUTPUT_ROOT \
41 | --mode $2
42 |
--------------------------------------------------------------------------------
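Usage note: the script takes the dataset name as its first argument and forwards the second argument to infer_t2i.py as --mode, e.g. bash infer_story.sh pororo <mode>; MODEL_CKPT is left empty above and must be set to a trained checkpoint path before running.
--------------------------------------------------------------------------------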
/story-dalle/train_story.sh:
--------------------------------------------------------------------------------
1 | if [ "$1" = "pororo" ]; then
2 | echo "Training on Pororo"
3 | DATA_DIR=../data/pororo/
4 | OUTPUT_ROOT=./out/pororo
5 | SENT_EMBED=512
6 | STORY_LEN=4
7 | LR=1e-4
8 | TRAIN_BS=1
9 | GRAD_ACC=4
10 | elif [ "$1" = "flintstones" ]; then
11 | echo "Training on Flintstones"
12 | DATA_DIR=../data/flintstones
13 | OUTPUT_ROOT=./out/flintstones
14 | SENT_EMBED=512
15 | STORY_LEN=4
16 | LR=1e-5
17 | TRAIN_BS=1
18 | GRAD_ACC=4
19 | elif [ "$1" = "didemo" ]; then
20 | echo "Training on DiDeMo"
21 | DATA_DIR=../data/didemo
22 | OUTPUT_ROOT=./out/didemo
23 | SENT_EMBED=512
24 | STORY_LEN=2
25 | TRAIN_BS=1
26 | GRAD_ACC=8
27 | fi
28 |
29 | LOG_DIR=../runs/
30 |
31 | python ./train_t2i.py \
32 | --prefix_model_name_or_path './1.3B/' \
33 | --tuning_mode story \
34 | --dataset_name $1 \
35 | --preseqlen 32 \
36 | --condition \
37 | --story_len $STORY_LEN \
38 | --sent_embed $SENT_EMBED \
39 | --prefix_dropout 0.2 \
40 | --data_dir $DATA_DIR \
41 | --dataloader_num_workers 4 \
42 | --output_dir $OUTPUT_ROOT \
43 | --log_dir $LOG_DIR \
44 | --do_train --do_eval \
45 | --per_gpu_train_batch_size $TRAIN_BS \
46 | --per_gpu_eval_batch_size 2 \
47 | --num_train_epochs 50 \
48 | --gradient_accumulation_steps $GRAD_ACC \
49 | --learning_rate $LR \
50 | --logging_steps 50 \
51 | --eval_steps 500 \
52 | --generate_steps 1000
53 |
54 |
--------------------------------------------------------------------------------
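Usage note: bash train_story.sh <pororo|flintstones|didemo>. The pororo and flintstones branches set LR to 1e-4 and 1e-5 respectively; the didemo branch leaves LR unset, so --learning_rate receives an empty value unless one is added.
--------------------------------------------------------------------------------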
/story-dalle/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.utils as vutils
4 | from torchvision.utils import save_image
5 | from tqdm import tqdm
6 | import numpy as np
7 | import PIL
8 |
9 | def save_as_image(tensor, out_dir, suffix):
10 |     img = vutils.make_grid(tensor, nrow=1, padding=0)  # make_grid takes nrow, not video_len
11 | save_image(img, '%s/gen_sample_%s.png' % (out_dir, suffix))
12 |
13 | def acc_tensors_to_images(tensor_dir, key, out_dir):
14 |
15 | files = [f for f in os.listdir(tensor_dir) if f.endswith('pt') and key in f]
16 | sorted_files = sorted(files, key=lambda x: int(x[:-3].split('_')[-1]))
17 | print(files[:10])
18 | print(sorted_files[:20])
19 | all_tensors = []
20 |     for f in tqdm(sorted_files, desc="Reading tensors"):
21 | t = torch.load(os.path.join(tensor_dir, f))
22 | # print(t[0].shape)
23 | # print(t.shape)
24 | all_tensors.append(t)
25 | all_tensors = torch.cat(all_tensors, dim=0)
26 | print(all_tensors.shape)
27 |
28 | torch.save(all_tensors, os.path.join(tensor_dir, 'sdalle_story_%s.pt' % key))
29 |
30 |     # for i in tqdm(range(0, all_tensors.shape[0]), desc='Preparing images'):
31 | # for j in range(0, all_tensors.shape[1]):
32 | # save_as_image(all_tensors[i, j], out_dir, '%s_%s' % (i, j))
33 |
34 | def images_to_numpy(tensor):
35 | generated = tensor.data.cpu().numpy().transpose(1, 2, 0)
36 | generated[generated < -1] = -1
37 | generated[generated > 1] = 1
38 | generated = (generated + 1) / 2 * 255
39 | return generated.astype('uint8')
40 |
41 | def numpy_to_img(numpy_file, outdir, img_size):
42 |
43 | if not os.path.exists(outdir):
44 | os.makedirs(outdir)
45 |
46 | x = np.load(numpy_file)
47 | print("Numpy image file shape: ", x.shape)
48 | for i in tqdm(range(x.shape[0])):
49 | frames = x[i, :, :, :, :]
50 | frames = np.swapaxes(frames, 0, 1)
51 | # frames = torch.Tensor(frames).view(-1, 3, 64, 64)
52 | # frames = torch.nn.functional.upsample(frames, size=(img_size, img_size), mode="bilinear")
53 |
54 | # vutils.save_image(vutils.make_grid(torch.Tensor(frames).view(-1, 3, 64, 64), 1, padding=0), 'sequence-2.png')
55 | all_images = images_to_numpy(vutils.make_grid(torch.Tensor(frames).view(-1, 3, 64, 64), 1, padding=0))
56 | # all_images = images_to_numpy(vutils.make_grid(frames, 1, padding=0))
57 | # print(all_images.shape)
58 | for j, idx in enumerate(range(64, all_images.shape[0] + 1, 64)):
59 | output = PIL.Image.fromarray(all_images[idx-64: idx, :, :])
60 | output.save(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
61 | img = PIL.Image.open(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
62 | if img_size != 64:
63 | img = img.resize((img_size, img_size,))
64 | img.save(os.path.join(outdir, 'img-%s-%s.png' % (i, j)))
65 |
66 | if __name__ == "__main__":
67 |
68 | acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/pororo/', 'test', '/nas-ssd/adyasha/out/minDALLEs/pororo/test_images')
69 | # acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/didemo/', 'test', '/nas-ssd/adyasha/out/minDALLEs/didemo/test_images/images')
70 | # acc_tensors_to_images('/nas-ssd/adyasha/out/minDALLEs/flintstones/', 'test', '/nas-ssd/adyasha/out/minDALLEs/flintstones/test_images/images')
71 |
72 | # numpy_to_img('/nas-ssd/adyasha/out/SGANc/didemo/val-images-epoch-120.npy', '/nas-ssd/adyasha/out/SGANc/didemo/val_images/', 299)
--------------------------------------------------------------------------------
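A small sketch for numpy_to_img above (run from the story-dalle directory; the paths and the dummy array are placeholders): it expects a .npy file of shape (num_stories, 3, video_len, 64, 64) with values in [-1, 1] and writes one resized PNG per frame.

    import numpy as np
    from utils import numpy_to_img

    # dummy stories: 2 stories x 4 frames of 3x64x64, values in [-1, 1]
    dummy = (np.random.rand(2, 3, 4, 64, 64) * 2 - 1).astype('float32')
    np.save('/tmp/dummy_stories.npy', dummy)
    # writes img-<story>-<frame>.png files resized to 299x299 (the FID input size)
    numpy_to_img('/tmp/dummy_stories.npy', '/tmp/dummy_frames', 299)
--------------------------------------------------------------------------------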
/story-dalle/vfid/__pycache__/fid_score.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/vfid/__pycache__/fid_score.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/vfid/__pycache__/inception.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adymaharana/storydalle/132dd19f7277dae36c16c5630792deb12fa5a09f/story-dalle/vfid/__pycache__/inception.cpython-38.pyc
--------------------------------------------------------------------------------
/story-dalle/vfid/fid_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate a Video GAN
3 |
4 | This variant replaces the original image encoder with a residual (2+1)D encoder.
5 |
6 | The FID metric calculates the distance between two distributions of images.
7 | Typically, we have summary statistics (mean & covariance matrix) of one
8 | of these distributions, while the 2nd distribution is given by a GAN.
9 |
10 | When run as a stand-alone program, it compares the distribution of
11 | images that are stored as PNG/JPEG at a specified location with a
12 | distribution given by summary statistics (in pickle format).
13 |
14 | The FID is calculated by assuming that X_1 and X_2 are the activations of
15 | the pool_3 layer of the inception net for generated samples and real world
16 | samples respectively.
17 |
18 | See --help to see further details.
19 |
20 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
21 | of Tensorflow
22 |
23 | Copyright 2018 Institute of Bioinformatics, JKU Linz
24 |
25 | Licensed under the Apache License, Version 2.0 (the "License");
26 | you may not use this file except in compliance with the License.
27 | You may obtain a copy of the License at
28 |
29 | http://www.apache.org/licenses/LICENSE-2.0
30 |
31 | Unless required by applicable law or agreed to in writing, software
32 | distributed under the License is distributed on an "AS IS" BASIS,
33 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
34 | See the License for the specific language governing permissions and
35 | limitations under the License.
36 | """
37 | import os
38 | import numpy as np
39 | import torch
40 | import torch.nn.functional as F
41 | from torch.utils.data import DataLoader
42 | from tqdm import tqdm
43 | from scipy import linalg
44 | import PIL
45 | import functools
46 | from .inception import InceptionV3
47 |
48 | def calculate_activation_statistics(imgs, model, batch_size=32, dims=2048,
49 | cuda=False, normalize=False, verbose=0, is_ref=False):
50 | """Calculates the activations of the pool_3 layer for all images.
51 |
52 | Params:
53 | imgs: image dataset
54 | model: Instance of inception model
55 | batch_size: Batch size of images for the model to process at once.
56 | Make sure that the number of samples is a multiple of the batch
57 | size, otherwise some samples are ignored. This behavior is retained
58 | to match the original FID score implementation.
59 | cuda: If set to True, use GPU
60 | normalize: If the value range of imgs is [-1, 1], set to True to
61 | shift value range to [0, 1].
62 | verbose: If verbose > 0, show progressbar during evaluation
63 | Returns:
64 | mu: The mean over samples of the activations of the pool_3 layer of
65 | the inception model.
66 | sigma: The covariance matrix of the activations of the pool_3 layer of
67 | the inception model.
68 | """
69 | model.eval()
70 | if cuda:
71 | device = torch.device('cuda')
72 | else:
73 | device = torch.device('cpu')
74 | model.to(device)
75 |
76 | with torch.no_grad():
77 | features = []
78 | features_cache = []
79 | dataloader = DataLoader(
80 | imgs, batch_size=batch_size, num_workers=4 if is_ref else 0, drop_last=True, shuffle=False)
81 | if verbose > 0:
82 | iter_dataset = tqdm(dataloader, dynamic_ncols=True)
83 | else:
84 | iter_dataset = dataloader
85 |         for batch in iter_dataset:  # already a tqdm iterator when verbose > 0
86 | images = batch[0]
87 | if len(images.shape) == 5:
88 | video_len, n_channels, h, w = images.size(-4), images.size(-3), images.size(-2), images.size(-1)
89 | images = images.view(batch_size * video_len, n_channels, h, w)
90 | # print(images.shape)
91 | images = images.type(torch.FloatTensor).to(device)
92 |
93 | # print(images.shape)
94 | if normalize:
95 | images = (images + 1) / 2 # [-1, 1] -> [0, 1]
96 | if images.size(3) != 299:
97 | images = F.interpolate(images, size=(299, 299),
98 | mode='bilinear', align_corners=False)
99 | pred = model(images)[0]
100 |
101 | # If model output is not scalar, apply global spatial average
102 | # pooling. This happens if you choose a dimensionality not equal
103 | # 2048.
104 | if pred.shape[2] != 1 or pred.shape[3] != 1:
105 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
106 | features.append(pred.cpu().numpy().reshape(-1, dims))
107 | features_cache.append(np.expand_dims(pred.cpu().numpy().reshape(-1, dims), axis=0))
108 |
109 | features = np.concatenate(features, axis=0)
110 | mu = np.mean(features, axis=0)
111 | sigma = np.cov(features, rowvar=False)
112 |
113 | return mu, sigma, np.concatenate(features_cache, axis=0)
114 |
115 |
116 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
117 | """Numpy implementation of the Frechet Distance.
118 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
119 | and X_2 ~ N(mu_2, C_2) is
120 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
121 |
122 | Stable version by Dougal J. Sutherland.
123 |
124 | Params:
125 | mu1: Numpy array containing the activations of a layer of the
126 | inception net (like returned by the function 'get_predictions')
127 | for generated samples.
128 |         mu2: The sample mean over activations, precalculated on a
129 | representative data set.
130 | sigma1: The covariance matrix over activations for generated samples.
131 |         sigma2: The covariance matrix over activations, precalculated on a
132 | representative data set.
133 |
134 | Returns:
135 | The Frechet Distance.
136 | """
137 |
138 | mu1 = np.atleast_1d(mu1)
139 | mu2 = np.atleast_1d(mu2)
140 |
141 | sigma1 = np.atleast_2d(sigma1)
142 | sigma2 = np.atleast_2d(sigma2)
143 |
144 | assert mu1.shape == mu2.shape, \
145 | 'Training and test mean vectors have different lengths'
146 | assert sigma1.shape == sigma2.shape, \
147 | 'Training and test covariances have different dimensions'
148 |
149 | diff = mu1 - mu2
150 |
151 | # Product might be almost singular
152 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
153 | if not np.isfinite(covmean).all():
154 | print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
155 | offset = np.eye(sigma1.shape[0]) * eps
156 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
157 |
158 | # Numerical error might give slight imaginary component
159 | if np.iscomplexobj(covmean):
160 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
161 | m = np.max(np.abs(covmean.imag))
162 | raise ValueError('Imaginary component {}'.format(m))
163 | covmean = covmean.real
164 |
165 | return (diff.dot(diff) +
166 | np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))
167 |
168 |
169 | def fid_score(r_imgs, g_imgs, batch_size=32, dims=2048, cuda=False,
170 | normalize=False, r_cache=None, verbose=0):
171 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
172 | model = InceptionV3([block_idx])
173 |
174 | # cache real dataset
175 | if r_cache and not r_cache.endswith('.npz'):
176 | r_cache = r_cache + '.npz'
177 | if r_cache and os.path.exists(r_cache):
178 | data = np.load(r_cache)
179 | m1, s1 = data['m1'], data['s1']
180 | else:
181 | m1, s1, f1 = calculate_activation_statistics(r_imgs, model, batch_size,
182 | dims, cuda, normalize)
183 | if r_cache is not None:
184 | # os.makedirs(os.path.dirname(r_cache), exist_ok=True)
185 | np.savez(r_cache, m1=m1, s1=s1)
186 | np.save(r_cache.replace('.npz', '.npy'), f1)
187 | # compute generative image dataset
188 | m2, s2, f2 = calculate_activation_statistics(g_imgs, model, batch_size, dims,
189 | cuda, normalize)
190 |     if r_cache is not None: np.save(r_cache.replace('.npz', '_gen.npy'), f2)  # guard: r_cache may be None
191 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
192 |
193 | return fid_value
--------------------------------------------------------------------------------
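A CPU-only smoke test for fid_score above (a sketch with random tensors, run from the story-dalle directory so vfid is importable; the cache path is a placeholder and the resulting value is meaningless, it only exercises the pipeline). The first call downloads the FID Inception weights:

    import torch
    from torch.utils.data import TensorDataset
    from vfid.fid_score import fid_score

    # datasets must yield tuples whose first element is an image tensor in [0, 1]
    real = TensorDataset(torch.rand(64, 3, 299, 299))
    fake = TensorDataset(torch.rand(64, 3, 299, 299))
    score = fid_score(real, fake, batch_size=32, cuda=False,
                      normalize=False, r_cache='/tmp/fid_ref_cache')
    print(score)
--------------------------------------------------------------------------------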
/story-dalle/vfid/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | try:
7 | from torchvision.models.utils import load_state_dict_from_url
8 | except ImportError:
9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
10 |
11 | # Inception weights ported to Pytorch from
12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14 |
15 |
16 | class InceptionV3(nn.Module):
17 | """Pretrained InceptionV3 network returning feature maps"""
18 |
19 | # Index of default block of inception to return,
20 | # corresponds to output of final average pooling
21 | DEFAULT_BLOCK_INDEX = 3
22 |
23 | # Maps feature dimensionality to their output blocks indices
24 | BLOCK_INDEX_BY_DIM = {
25 | 64: 0, # First max pooling features
26 |         192: 1,  # Second max pooling features
27 | 768: 2, # Pre-aux classifier features
28 | 2048: 3 # Final average pooling features
29 | }
30 |
31 | def __init__(self,
32 | output_blocks=[DEFAULT_BLOCK_INDEX],
33 | resize_input=True,
34 | normalize_input=True,
35 | requires_grad=False,
36 | use_fid_inception=True):
37 | """Build pretrained InceptionV3
38 |
39 | Parameters
40 | ----------
41 | output_blocks : list of int
42 | Indices of blocks to return features of. Possible values are:
43 | - 0: corresponds to output of first max pooling
44 | - 1: corresponds to output of second max pooling
45 | - 2: corresponds to output which is fed to aux classifier
46 | - 3: corresponds to output of final average pooling
47 | resize_input : bool
48 | If true, bilinearly resizes input to width and height 299 before
49 | feeding input to model. As the network without fully connected
50 | layers is fully convolutional, it should be able to handle inputs
51 | of arbitrary size, so resizing might not be strictly needed
52 | normalize_input : bool
53 | If true, scales the input from range (0, 1) to the range the
54 | pretrained Inception network expects, namely (-1, 1)
55 | requires_grad : bool
56 | If true, parameters of the model require gradients. Possibly useful
57 | for finetuning the network
58 | use_fid_inception : bool
59 | If true, uses the pretrained Inception model used in Tensorflow's
60 | FID implementation. If false, uses the pretrained Inception model
61 | available in torchvision. The FID Inception model has different
62 | weights and a slightly different structure from torchvision's
63 | Inception model. If you want to compute FID scores, you are
64 | strongly advised to set this parameter to true to get comparable
65 | results.
66 | """
67 | super(InceptionV3, self).__init__()
68 |
69 | self.resize_input = resize_input
70 | self.normalize_input = normalize_input
71 | self.output_blocks = sorted(output_blocks)
72 | self.last_needed_block = max(output_blocks)
73 |
74 | assert self.last_needed_block <= 3, \
75 | 'Last possible output block index is 3'
76 |
77 | self.blocks = nn.ModuleList()
78 |
79 | if use_fid_inception:
80 | inception = fid_inception_v3()
81 | else:
82 | inception = models.inception_v3(pretrained=True)
83 |
84 | # Block 0: input to maxpool1
85 | block0 = [
86 | inception.Conv2d_1a_3x3,
87 | inception.Conv2d_2a_3x3,
88 | inception.Conv2d_2b_3x3,
89 | nn.MaxPool2d(kernel_size=3, stride=2)
90 | ]
91 | self.blocks.append(nn.Sequential(*block0))
92 |
93 | # Block 1: maxpool1 to maxpool2
94 | if self.last_needed_block >= 1:
95 | block1 = [
96 | inception.Conv2d_3b_1x1,
97 | inception.Conv2d_4a_3x3,
98 | nn.MaxPool2d(kernel_size=3, stride=2)
99 | ]
100 | self.blocks.append(nn.Sequential(*block1))
101 |
102 | # Block 2: maxpool2 to aux classifier
103 | if self.last_needed_block >= 2:
104 | block2 = [
105 | inception.Mixed_5b,
106 | inception.Mixed_5c,
107 | inception.Mixed_5d,
108 | inception.Mixed_6a,
109 | inception.Mixed_6b,
110 | inception.Mixed_6c,
111 | inception.Mixed_6d,
112 | inception.Mixed_6e,
113 | ]
114 | self.blocks.append(nn.Sequential(*block2))
115 |
116 | # Block 3: aux classifier to final avgpool
117 | if self.last_needed_block >= 3:
118 | block3 = [
119 | inception.Mixed_7a,
120 | inception.Mixed_7b,
121 | inception.Mixed_7c,
122 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
123 | ]
124 | self.blocks.append(nn.Sequential(*block3))
125 |
126 | for param in self.parameters():
127 | param.requires_grad = requires_grad
128 |
129 | def forward(self, inp):
130 | """Get Inception feature maps
131 |
132 | Parameters
133 | ----------
134 | inp : torch.autograd.Variable
135 | Input tensor of shape Bx3xHxW. Values are expected to be in
136 | range (0, 1)
137 |
138 | Returns
139 | -------
140 | List of torch.autograd.Variable, corresponding to the selected output
141 | block, sorted ascending by index
142 | """
143 | outp = []
144 | x = inp
145 |
146 | if self.resize_input:
147 | x = F.interpolate(x,
148 | size=(299, 299),
149 | mode='bilinear',
150 | align_corners=False)
151 |
152 | if self.normalize_input:
153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154 |
155 | for idx, block in enumerate(self.blocks):
156 | x = block(x)
157 | if idx in self.output_blocks:
158 | outp.append(x)
159 |
160 | if idx == self.last_needed_block:
161 | break
162 |
163 | return outp
164 |
165 |
166 | def fid_inception_v3():
167 | """Build pretrained Inception model for FID computation
168 |
169 | The Inception model for FID computation uses a different set of weights
170 | and has a slightly different structure than torchvision's Inception.
171 |
172 | This method first constructs torchvision's Inception and then patches the
173 | necessary parts that are different in the FID Inception model.
174 | """
175 | inception = models.inception_v3(num_classes=1008,
176 | aux_logits=False,
177 | pretrained=False)
178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
185 | inception.Mixed_7b = FIDInceptionE_1(1280)
186 | inception.Mixed_7c = FIDInceptionE_2(2048)
187 |
188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
189 | inception.load_state_dict(state_dict)
190 | return inception
191 |
192 |
193 | class FIDInceptionA(models.inception.InceptionA):
194 | """InceptionA block patched for FID computation"""
195 | def __init__(self, in_channels, pool_features):
196 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
197 |
198 | def forward(self, x):
199 | branch1x1 = self.branch1x1(x)
200 |
201 | branch5x5 = self.branch5x5_1(x)
202 | branch5x5 = self.branch5x5_2(branch5x5)
203 |
204 | branch3x3dbl = self.branch3x3dbl_1(x)
205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
207 |
208 |         # Patch: Tensorflow's average pool does not use the padded zeros in
209 | # its average calculation
210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
211 | count_include_pad=False)
212 | branch_pool = self.branch_pool(branch_pool)
213 |
214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
215 | return torch.cat(outputs, 1)
216 |
217 |
218 | class FIDInceptionC(models.inception.InceptionC):
219 | """InceptionC block patched for FID computation"""
220 | def __init__(self, in_channels, channels_7x7):
221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
222 |
223 | def forward(self, x):
224 | branch1x1 = self.branch1x1(x)
225 |
226 | branch7x7 = self.branch7x7_1(x)
227 | branch7x7 = self.branch7x7_2(branch7x7)
228 | branch7x7 = self.branch7x7_3(branch7x7)
229 |
230 | branch7x7dbl = self.branch7x7dbl_1(x)
231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
234 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
235 |
236 |         # Patch: Tensorflow's average pool does not use the padded zeros in
237 | # its average calculation
238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
239 | count_include_pad=False)
240 | branch_pool = self.branch_pool(branch_pool)
241 |
242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
243 | return torch.cat(outputs, 1)
244 |
245 |
246 | class FIDInceptionE_1(models.inception.InceptionE):
247 | """First InceptionE block patched for FID computation"""
248 | def __init__(self, in_channels):
249 | super(FIDInceptionE_1, self).__init__(in_channels)
250 |
251 | def forward(self, x):
252 | branch1x1 = self.branch1x1(x)
253 |
254 | branch3x3 = self.branch3x3_1(x)
255 | branch3x3 = [
256 | self.branch3x3_2a(branch3x3),
257 | self.branch3x3_2b(branch3x3),
258 | ]
259 | branch3x3 = torch.cat(branch3x3, 1)
260 |
261 | branch3x3dbl = self.branch3x3dbl_1(x)
262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
263 | branch3x3dbl = [
264 | self.branch3x3dbl_3a(branch3x3dbl),
265 | self.branch3x3dbl_3b(branch3x3dbl),
266 | ]
267 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
268 |
269 |         # Patch: Tensorflow's average pool does not use the padded zeros in
270 | # its average calculation
271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
272 | count_include_pad=False)
273 | branch_pool = self.branch_pool(branch_pool)
274 |
275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276 | return torch.cat(outputs, 1)
277 |
278 |
279 | class FIDInceptionE_2(models.inception.InceptionE):
280 | """Second InceptionE block patched for FID computation"""
281 | def __init__(self, in_channels):
282 | super(FIDInceptionE_2, self).__init__(in_channels)
283 |
284 | def forward(self, x):
285 | branch1x1 = self.branch1x1(x)
286 |
287 | branch3x3 = self.branch3x3_1(x)
288 | branch3x3 = [
289 | self.branch3x3_2a(branch3x3),
290 | self.branch3x3_2b(branch3x3),
291 | ]
292 | branch3x3 = torch.cat(branch3x3, 1)
293 |
294 | branch3x3dbl = self.branch3x3dbl_1(x)
295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296 | branch3x3dbl = [
297 | self.branch3x3dbl_3a(branch3x3dbl),
298 | self.branch3x3dbl_3b(branch3x3dbl),
299 | ]
300 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
301 |
302 | # Patch: The FID Inception model uses max pooling instead of average
303 | # pooling. This is likely an error in this specific Inception
304 | # implementation, as other Inception models use average pooling here
305 | # (which matches the description in the paper).
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
307 | branch_pool = self.branch_pool(branch_pool)
308 |
309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
310 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
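A short feature-extraction sketch for InceptionV3 above (run from the story-dalle directory with a torchvision version compatible with the code above; the FID Inception weights are downloaded on first use):

    import torch
    from vfid.inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # final average pooling block
    model = InceptionV3([block_idx]).eval()

    with torch.no_grad():
        images = torch.rand(4, 3, 299, 299)            # values expected in (0, 1)
        pool3 = model(images)[0]                       # -> (4, 2048, 1, 1)
    print(pool3.squeeze(-1).squeeze(-1).shape)         # features used for FID
--------------------------------------------------------------------------------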