├── input ├── src │ ├── 0.png │ ├── config.py │ ├── data.py │ ├── bert_emb.py │ └── dataset_check.py └── README.md ├── examples ├── bird1.jpg ├── bird2.jpg ├── bird3.jpg ├── bird4.jpg ├── bird5.jpg ├── flower1.jpg ├── flower2.jpg ├── flower3.jpg ├── flower4.jpg ├── flower5.jpg ├── framework.jpg ├── framework.png └── captions ├── requirements.txt ├── src ├── environment.yml ├── util.py ├── dataset.py ├── engine.py ├── train.py ├── args.py └── layers.py ├── cfg ├── s2.yml └── s1.yml ├── LICENSE ├── .gitignore └── README.md /input/src/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/input/src/0.png -------------------------------------------------------------------------------- /examples/bird1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird1.jpg -------------------------------------------------------------------------------- /examples/bird2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird2.jpg -------------------------------------------------------------------------------- /examples/bird3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird3.jpg -------------------------------------------------------------------------------- /examples/bird4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird4.jpg -------------------------------------------------------------------------------- /examples/bird5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird5.jpg -------------------------------------------------------------------------------- /examples/flower1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower1.jpg -------------------------------------------------------------------------------- /examples/flower2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower2.jpg -------------------------------------------------------------------------------- /examples/flower3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower3.jpg -------------------------------------------------------------------------------- /examples/flower4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower4.jpg -------------------------------------------------------------------------------- /examples/flower5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower5.jpg -------------------------------------------------------------------------------- /examples/framework.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/framework.jpg -------------------------------------------------------------------------------- /examples/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/framework.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ujson 2 | numpy 3 | spacy 4 | tensorboard 5 | tensorflow 6 | tensorboardX 7 | tqdm 8 | urllib3 9 | torch 10 | torchvision 11 | transformers 12 | torchfile -------------------------------------------------------------------------------- /src/environment.yml: -------------------------------------------------------------------------------- 1 | name: ganctober 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.6 8 | - ujson 9 | - numpy 10 | - pip 11 | - spacy=2.0.16 12 | - tensorboard 13 | - tensorflow 14 | - tensorboardX 15 | - tqdm 16 | - urllib3 17 | - pytorch=1.0.0 18 | - pip: 19 | - torch==1.0.0 20 | -------------------------------------------------------------------------------- /cfg/s2.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: "stage2" 2 | STAGE: 2 3 | train_bs: 4 # 64 4 | train_workers: 1 5 | embedding_type: "bert" 6 | TRAIN_MAX_EPOCH: 600 7 | TRAIN_GEN_LR: 0.0002 # 2e-4 8 | TRAIN_DISC_LR: 0.0002 # 2e-4 9 | TRAIN_LR_DECAY_EPOCH: 100 10 | TRAIN_SNAPSHOT_INTERVAL: 10 # 2000 11 | STAGE1_G_path: "../output/model/stage1_netG_epoch_600.pth" 12 | NET_G_path: "" 13 | NET_D_path: "" 14 | 15 | pretrained_epoch: 600 -------------------------------------------------------------------------------- /cfg/s1.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: "stage1" 2 | STAGE: 1 3 | train_bs: 128 # 128 fits / 64 standard 4 | train_workers: 4 5 | embedding_type: "bert" # cnn-rnn / bert 6 | TRAIN_MAX_EPOCH: 600 7 | TRAIN_GEN_LR: 0.0002 # 2e-4 8 | TRAIN_DISC_LR: 0.0002 # 2e-4 9 | TRAIN_LR_DECAY_EPOCH: 50 10 | TRAIN_SNAPSHOT_INTERVAL: 2000 11 | NET_G_path: "" 12 | NET_D_path: "" 13 | # NET_G_path: "../old_outputs/output_1_0/model/netG_epoch_120.pth" 14 | # NET_D_path: "../old_outputs/output_1_0/model/netD_epoch_120.pth" -------------------------------------------------------------------------------- /input/src/config.py: -------------------------------------------------------------------------------- 1 | """All the paths and arguments. 
2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | # import transformers 8 | 9 | # BERT_PATH = "../data/bert_base_uncased" 10 | ANNOTATIONS = "../data/birds/text_c10" 11 | ANNOTATION_EMB = "../data/birds/embeddings" 12 | IMAGE_DIR = "../data/CUB_200_2011/images" 13 | EMB_LEN = "../data/emb_lens.txt" 14 | SENT_LEN = "../data/sent_len.txt" 15 | 16 | DEVICE = "cuda" 17 | 18 | ANNOTATIONS_URL = "https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE" 19 | CUB_URL = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" -------------------------------------------------------------------------------- /examples/captions: -------------------------------------------------------------------------------- 1 | A small yellow bird with a black crown and a short black pointed beak 2 | This small bird has a white breast, light grey head, and black wings and tail 3 | This bird is white, black, and brown in color, with a brown beak 4 | A white bird with a black crown and yellow beak 5 | This bird is completely red with black wings and pointy beak 6 | 7 | 8 | This flower has overlapping pink pointed petals surrounding a ring of short yellow filaments 9 | This flower has long thin yellow petals and a lot of yellow anthers in the center 10 | This flower is white, pink, and yellow in color, and has petals that are multi colored 11 | This flower is pink, white, and yellow in color, and has petals that are striped 12 | This flower has white petals with a yellow tip and a yellow pistil 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sahil Khose 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /input/README.md: -------------------------------------------------------------------------------- 1 | ## :four_leaf_clover: New 2 | - Run the following to automate the dataset download 3 | ```bash 4 | cd input/src 5 | python3 data.py 6 | ``` 7 | ``` 8 | ganctober 9 | │ LICENSE 10 | │ README.md 11 | │ requirements.txt 12 | │ 13 | └──>cfg 14 | │ 15 | └──>examples 16 | │ 17 | └──>input 18 | │ │ 19 | │ │ 20 | │ └──>data 21 | │ | │ 22 | │ | └──>bert_base_uncased 23 | | | | 24 | │ | └──>birds 25 | │ | | 26 | | | └──>CUB_200_2011 27 | │ │ 28 | │ └──>src 29 | | | bert_emb.py 30 | | | config.py 31 | | | data.py 32 | | | dataset_check.py 33 | | 34 | | 35 | └──>old_outputs 36 | | 37 | └──>output 38 | | 39 | └──>src 40 | │ │ args.py 41 | | | dataset.py 42 | │ │ engine.py 43 | │ │ environment.yml 44 | │ │ layers.py 45 | │ │ train.py 46 | │ │ util.py 47 | ``` 48 | -------------------------------------------------------------------------------------------- 49 | 50 | ## Dataset (Old) 51 | - Download the following and extract/move them to the mentioned directories so that your workspace is similar to the one in the figure. 52 | - Download the Caltech-UCSD Birds-200-2011 (CUB) Dataset from: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html and extract it in `input/data/` to create the `input/data/CUB_200_2011` directory, which contains the `images` directory with the images we need for our task.
53 | - Read the README about the dataset on the website 54 | - Download the text descriptions from: https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE and extract it in `input/data/` to create the `input/data/birds` directory, which contains the `text_c10` directory with all the annotations needed for our task.
55 | - Download bert_base_uncased from: https://www.kaggle.com/abhishek/bert-base-uncased and extract it in `input/data/` to create `input/data/bert_base_uncased` to create annotation bert embeddings -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # input data and models 2 | input/data/* 3 | output/ 4 | old_outputs/ 5 | *.tar 6 | 7 | # data files 8 | *.csv 9 | *.pt 10 | *.npy 11 | *.txt 12 | !requirements.txt 13 | *.h5 14 | *.pkl 15 | *.pth 16 | 17 | ###################################################################### 18 | ###################################################################### 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | # For a library or package, you might want to ignore these files since the code is 105 | # intended to run in multiple environments; otherwise, check them in: 106 | # .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | ganctober.code-workspace 158 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | """Utility classes and methods. 2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | import args 8 | 9 | import matplotlib.pyplot as plt 10 | import os 11 | import pandas as pd 12 | import torch 13 | import torchvision.utils as vutils 14 | 15 | from json import dumps 16 | print("__"*80) 17 | 18 | #TODO fetch saved generated images during training and their corresponding annotations 19 | 20 | def save_img_results(data_img, fake, epoch, args): 21 | num = args.VIS_COUNT 22 | fake = fake[0:num] 23 | # data_img is changed to [0, 1] 24 | if data_img is not None: 25 | data_img = data_img[0:num] 26 | vutils.save_image(data_img, os.path.join(args.image_save_dir, "real_samples.png"), normalize=True) 27 | # fake data is stil [-1, 1] 28 | vutils.save_image(fake.data, os.path.join(args.image_save_dir, f"fake_samples_epoch_{epoch:04}.png"), normalize=True) 29 | else: 30 | vutils.save_image(fake.data, os.path.join(args.image_save_dir, f"lr_fake_samples_epoch_{epoch:04}.png"), normalize=True) 31 | 32 | def save_model(netG, netD, epoch, args): 33 | torch.save(netG.state_dict(), os.path.join(args.model_dir, f"netG_epoch_{epoch}.pth")) 34 | torch.save(netD.state_dict(), os.path.join(args.model_dir, f"netD_epoch_{epoch}.pth")) #? implementation has saved just the last disc? decision? 
35 | print("Save G/D models") 36 | 37 | 38 | def make_dir(path): 39 | if not os.path.exists(path): 40 | os.makedirs(path) 41 | 42 | 43 | def check_dataset(training_set): 44 | t, i, b = training_set[1] 45 | print("Bert emb shape: ", t.shape) 46 | print("bbox: ", b) 47 | plt.imshow(i) 48 | plt.show() 49 | print("__"*80) 50 | 51 | def check_args(): 52 | """ 53 | To test args.py 54 | """ 55 | print("get_data_args:") 56 | data_args = args.get_all_args() 57 | print(f'Args: {dumps(vars(data_args), indent=4, sort_keys=True)}') 58 | print("__"*80) 59 | 60 | print("To fetch arguments:") 61 | print(data_args.device) 62 | print(data_args.images_dir) 63 | print("__"*80) 64 | 65 | ##################################################################################### 66 | print("To fetch images, text embeddings, bert embeddings:") 67 | df_train_filenames = pd.read_pickle(data_args.train_filenames) 68 | # print(df_train_filenames[0]) # List[str] : len train -> 8855, len test -> 2933 69 | image_path = os.path.join(data_args.images_dir, df_train_filenames[0] + ".jpg") 70 | text_path = os.path.join(data_args.annotations_dir, df_train_filenames[0] + ".txt") 71 | bert_path = os.path.join(data_args.bert_annotations_dir, df_train_filenames[0], "0.pt") 72 | 73 | 74 | print("\nBird type: ") 75 | print(df_train_filenames[0].split("/")[0]) 76 | 77 | print("\nAnnotations of the bird image: \n") 78 | [print(f"{idx}: {ele}") for idx, ele in enumerate(open(text_path).read().split("\n")[:-1])] 79 | 80 | print("\nShape of bert embedding of annotation no 0:") 81 | emb = torch.load(bert_path) 82 | print(emb.shape) # (1, 768) 83 | 84 | img = plt.imread(image_path) 85 | plt.imshow(img) 86 | plt.show() 87 | 88 | 89 | 90 | 91 | if __name__ == "__main__": 92 | check_args() 93 | # make_dir("../output/model") 94 | -------------------------------------------------------------------------------- /input/src/data.py: -------------------------------------------------------------------------------- 1 | """Downloads data 2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | import requests 8 | import os 9 | import tqdm 10 | 11 | class GoogleDriveDownloader(object): 12 | """ 13 | Downloading a file stored on Google Drive by its URL. 14 | If the link is pointing to another resource, the redirect chain is being expanded. 15 | Returns the output path. 
16 | """ 17 | 18 | base_url = 'https://docs.google.com/uc?export=download' 19 | chunk_size = 32768 20 | 21 | def __init__(self, url, out_dir): 22 | super().__init__() 23 | 24 | self.out_name = url.rsplit('/', 1)[-1] 25 | self.url = self._get_redirect_url(url) 26 | self.out_dir = out_dir 27 | 28 | @staticmethod 29 | def _get_redirect_url(url): 30 | response = requests.get(url) 31 | if response.url != url and response.url is not None: 32 | redirect_url = response.url 33 | return redirect_url 34 | else: 35 | return url 36 | 37 | @staticmethod 38 | def _get_confirm_token(response): 39 | for key, value in response.cookies.items(): 40 | if key.startswith('download_warning'): 41 | return value 42 | return None 43 | 44 | def _save_response_content(self, response): 45 | with open(self.fpath, 'wb') as f: 46 | bar = tqdm.tqdm(total=None) 47 | progress = 0 48 | for chunk in response.iter_content(self.chunk_size): 49 | if chunk: 50 | f.write(chunk) 51 | progress += len(chunk) 52 | bar.update(progress - bar.n) 53 | bar.close() 54 | 55 | @property 56 | def file_id(self): 57 | return self.url.split('?')[0].split('/')[-2] 58 | 59 | @property 60 | def fpath(self): 61 | return os.path.join(self.out_dir, self.out_name) 62 | 63 | def download(self): 64 | os.makedirs(self.out_dir, exist_ok=True) 65 | 66 | if os.path.isfile(self.fpath): 67 | print('File is downloaded yet:', self.fpath) 68 | else: 69 | session = requests.Session() 70 | response = session.get(self.base_url, params={'id': self.file_id}, stream=True) 71 | token = self._get_confirm_token(response) 72 | 73 | if token: 74 | response = session.get(self.base_url, params={'id': self.file_id, 'confirm': token}, stream=True) 75 | else: 76 | raise RuntimeError() 77 | 78 | self._save_response_content(response) 79 | 80 | return self.fpath 81 | 82 | 83 | def main(): 84 | os.makedirs("../data/", exist_ok=True) 85 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 86 | dl = GoogleDriveDownloader(url, '../data/') 87 | dl.download() 88 | 89 | url_b = 'https://drive.google.com/file/d/1MVfYF0qVgKHTQKFdA7lexGWnRIs7Ax9c/view?usp=sharing' 90 | dl_b = GoogleDriveDownloader(url_b, '../data/') 91 | dl_b.download() 92 | 93 | os.system("unzip ../data/birds.zip -d ../data/") 94 | os.system("tar -xvf ../data/CUB_200_2011.tgz -C ../data/") 95 | 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StackGAN 2 | 3 | - PyTorch implementation of the paper [StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/pdf/1612.03242.pdf) by Han Zhang, Tao Xu, Hongsheng Li, Shaoting Zhang, Xiaogang Wang, Xiaolei Huang, Dimitris Metaxas. 4 | 5 | ## :bulb: What's new? 6 | - We use BERT embeddings for the text description instead of the char-CNN-RNN text embeddings that were used in the paper implementation. 
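For reference, a minimal sketch of how a caption can be turned into such a sentence-level BERT embedding. The repository's own pipeline for this is `input/src/bert_emb.py`, which pools a hidden layer of `bert-base-uncased` into a `(1, 768)` vector; the snippet below is a simplified illustration under that assumption, not the exact script.
```python
# Simplified sketch: mean-pool BERT token states into a single 768-d sentence embedding,
# the conditioning vector that stands in for the 1024-d char-CNN-RNN embedding.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

caption = "A small yellow bird with a black crown and a short black pointed beak"
inputs = tokenizer(caption, return_tensors="pt", truncation=True, max_length=128)

with torch.no_grad():
    token_states = model(**inputs).last_hidden_state  # (1, seq_len, 768)
sentence_emb = token_states.mean(dim=1)                # (1, 768)
```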
7 | 8 | 9 | ## Pretrained model 10 | - [Stage 1](https://drive.google.com/drive/folders/14AyNcu7oZJe2aMevynAbYIpMKN7I3yHT?usp=sharing) trained using BERT embeddings instead of the original char-CNN-RNN text embeddings 11 | - [Stage 2](https://drive.google.com/drive/folders/1Pyndsp9oraE15ssD4MZJBVsyLW1ECCIi?usp=sharing) trained using BERT embeddings instead of the original char-CNN-RNN text embeddings 12 | 13 | ## Paper examples 14 | #### :bird: Examples for birds (char-CNN-RNN embeddings), more on [youtube](https://youtu.be/93yaf_kE0Fg): 15 | ![](examples/bird1.jpg)
16 | ![](examples/bird2.jpg)
17 | ![](examples/bird4.jpg)
18 | ![](examples/bird3.jpg)
19 | 20 | -------------------------------------------------------------------------------------------- 21 | 22 | #### :sunflower: Examples for flowers (char-CNN-RNN embeddings), more on [youtube](https://youtu.be/SuRyL5vhCIM): 23 | ![](examples/flower1.jpg)
24 | ![](examples/flower2.jpg)
25 | ![](examples/flower3.jpg)
26 | ![](examples/flower4.jpg)
27 | 28 | -------------------------------------------------------------------------------------------- 29 | 30 | ## :clipboard: Dependencies 31 | ```bash 32 | git clone https://github.com/sahilkhose/StackGAN-BERT.git 33 | pip3 install -r requirements.txt 34 | ``` 35 | 36 | ## Dataset 37 | Check instructions in `/input/README.md` 38 | ```bash 39 | cd input/src 40 | python3 data.py 41 | ``` 42 | 43 | ## Generating BERT embeddings of annotations 44 | Change the DEVICE to `cpu` in `input/src/config.py` if `cuda` is not available 45 | ```bash 46 | python3 bert_emb.py 47 | ``` 48 | 49 | ## :wrench: Training 50 | ```bash 51 | cd ../../src 52 | ``` 53 | Option 1: CLI args training `src/args.py` 54 | ```bash 55 | python3 train.py --TRAIN_MAX_EPOCH 10 56 | ``` 57 | Option 2: yaml args training `cfg/s1.yml` and `cfg/s2.yml` 58 | ```bash 59 | python3 train.py --conf ../cfg/s1.yml 60 | 61 | mkdir ../old_outputs 62 | mv ../output ../old_outputs/output_stage-1 63 | 64 | python3 train.py --conf ../cfg/s2.yml 65 | 66 | mv ../output ../old_outputs/output_stage-2 67 | ``` 68 | To load the tensorboard 69 | ```bash 70 | tensorboard --logdir=../output 71 | ``` 72 | 73 | -------------------------------------------------------------------------------------------- 74 | 75 | ## :books: Citing StackGAN 76 | If you find StackGAN useful in your research, please consider citing: 77 | 78 | ``` 79 | @inproceedings{han2017stackgan, 80 | Author = {Han Zhang and Tao Xu and Hongsheng Li and Shaoting Zhang and Xiaogang Wang and Xiaolei Huang and Dimitris Metaxas}, 81 | Title = {StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks}, 82 | Year = {2017}, 83 | booktitle = {{ICCV}}, 84 | } 85 | ``` 86 | 87 | **Follow-up work** 88 | 89 | - [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916) 90 | - [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](https://arxiv.org/abs/1711.10485) [[supplementary]](https://1drv.ms/b/s!Aj4exx_cRA4ghK5-kUG-EqH7hgknUA) [[code]](https://github.com/taoxugit/AttnGAN) 91 | 92 | **References** 93 | 94 | - Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016) 95 | - Learning Deep Representations of Fine-grained Visual Descriptions [Paper](https://arxiv.org/abs/1605.05395) [Code](https://github.com/reedscot/cvpr2016) 96 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | """CUB Dataset class. 
2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | import args 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import os 13 | import pandas as pd 14 | import pickle 15 | import PIL 16 | import torch 17 | 18 | from torchvision import transforms 19 | 20 | from PIL import Image 21 | print("__"*80) 22 | 23 | 24 | class CUBDataset(torch.utils.data.Dataset): 25 | def __init__(self, pickl_file, img_dir, bert_emb=None, cnn_emb=None, stage=1): 26 | self.file_names = pd.read_pickle(pickl_file)[:-23] 27 | self.img_dir = img_dir 28 | self.bert_emb = bert_emb 29 | if cnn_emb is not None: 30 | self.cnn_emb = np.array(pickle.load(open(cnn_emb, "rb"), encoding='latin1')) # to switch to cnn-rnn embeddings 31 | else: 32 | self.cnn_emb = cnn_emb 33 | self.stage = stage 34 | self.f_to_bbox = dict_bbox() 35 | 36 | def __len__(self): 37 | # Total number of samples 38 | return len(self.file_names) 39 | 40 | def __getitem__(self, index): 41 | # Select sample: 42 | data_id = str(self.file_names[index]) 43 | 44 | # Fetch text emb, image, bbox: 45 | idx = np.random.randint(0, 9) 46 | if self.cnn_emb is not None: 47 | text_emb = torch.tensor(self.cnn_emb[index][idx], dtype=torch.float) # (1024) 48 | else: 49 | text_emb = torch.load(os.path.join(self.bert_emb, data_id)+f"/{idx}.pt", map_location="cpu") 50 | text_emb = text_emb.squeeze(0) # (768) 51 | 52 | bbox = self.f_to_bbox[data_id] 53 | if self.stage == 1: 54 | image = get_img(img_path=os.path.join(self.img_dir, data_id) + ".jpg", bbox=bbox, image_size=(64, 64)) 55 | else: 56 | image = get_img(img_path=os.path.join(self.img_dir, data_id) + ".jpg", bbox=bbox, image_size=(256, 256)) 57 | # image = torch.tensor(np.array(image), dtype=torch.float) 58 | image = transforms.ToTensor()(image) 59 | 60 | return text_emb, image 61 | 62 | 63 | def dict_bbox(): 64 | """ 65 | returns filename to bbox dict 66 | """ 67 | data_args = args.get_all_args() 68 | 69 | df_bbox = pd.read_csv(data_args.bounding_boxes, delim_whitespace=True, header=None).astype(int) 70 | df_filenames = pd.read_csv(data_args.images_id_file, delim_whitespace=True, header=None) 71 | 72 | filenames = df_filenames[1].tolist() 73 | 74 | filename_bbox = {} 75 | for i in range(len(filenames)): 76 | bbox = df_bbox.iloc[i][1:].tolist() 77 | key = filenames[i].replace(".jpg", "") 78 | filename_bbox[key] = bbox 79 | 80 | return filename_bbox 81 | 82 | def get_img(img_path, bbox, image_size): 83 | """ 84 | Load and resize image 85 | """ 86 | img = Image.open(img_path).convert('RGB') 87 | width, height = img.size 88 | if bbox is not None: 89 | R = int(np.maximum(bbox[2], bbox[3]) * 0.75) 90 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 91 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 92 | y1 = np.maximum(0, center_y - R) 93 | y2 = np.minimum(height, center_y + R) 94 | x1 = np.maximum(0, center_x - R) 95 | x2 = np.minimum(width, center_x + R) 96 | img = img.crop([x1, y1, x2, y2]) 97 | img = img.resize(image_size, PIL.Image.BILINEAR) 98 | return img 99 | 100 | if __name__ == "__main__": 101 | data_args = args.get_all_args() 102 | train_filenames = data_args.train_filenames 103 | test_filenames = data_args.test_filenames 104 | 105 | bert_embed = False 106 | 107 | if bert_embed: 108 | ###* Bert embeddings dataset: 109 | dataset_test = CUBDataset(train_filenames, data_args.images_dir, bert_emb=data_args.bert_annotations_dir) 110 | else: 111 | ###* cnn-rnn embeddings dataset: 112 | # cnn_embeddings = 
np.array(pickle.load(open(data_args.cnn_annotations_emb_train, "rb"), encoding='latin1')) 113 | dataset_test = CUBDataset(train_filenames, data_args.images_dir, cnn_emb=data_args.cnn_annotations_emb_train) 114 | 115 | 116 | t, i = dataset_test[1] 117 | if bert_embed: 118 | print("Bert emb shape: ", t.shape) 119 | else: 120 | print("rnn-cnn emb shape: ", t.shape) 121 | print("Image shape: ", i.shape) 122 | plt.imshow(i.permute(1, 2, 0)) 123 | plt.show() 124 | 125 | ########################################################### 126 | # filename = "001.Black_footed_Albatross/Black_Footed_Albatross_0046_18" 127 | # f_to_bbox = dict_bbox() 128 | # print(f_to_bbox[filename]) 129 | -------------------------------------------------------------------------------- /input/src/bert_emb.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Generates bert embeddings from annotations. 3 | 4 | Authors: 5 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 6 | Sahil Khose (sahilkhose18@gmail.com) 7 | ''' 8 | 9 | import config 10 | 11 | import os 12 | import torch 13 | import logging 14 | logging.basicConfig(level=logging.ERROR) 15 | 16 | # from transformers import BertTokenizer, BertModel 17 | from transformers import AutoTokenizer, AutoModelForMaskedLM 18 | from tqdm import tqdm 19 | 20 | print("__"*80) 21 | print("Imports finished.") 22 | print("Loading BERT Tokenizer and model...") 23 | ############################################################################################### 24 | # tokenizer = BertTokenizer.from_pretrained(config.BERT_PATH, do_lower_case=True) 25 | # model = BertModel.from_pretrained( 26 | # config.BERT_PATH, output_hidden_states=True).to(config.DEVICE) 27 | 28 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 29 | model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True).to(config.DEVICE) 30 | model.eval() 31 | 32 | ############################################################################################### 33 | print("BERT tokenizer and model loaded.") 34 | 35 | def sent_emb(sent): 36 | encoded_dict = tokenizer.encode_plus( 37 | sent, 38 | add_special_tokens=True, 39 | max_length=128, # This is changed. 
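# max_len() below reports the longest tokenized annotation at roughly 80 tokens, so a cap of 128 leaves comfortable headroom.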
40 | pad_to_max_length=True, 41 | return_attention_mask=True, 42 | return_tensors='pt', 43 | truncation=True 44 | ) 45 | 46 | input_ids = encoded_dict['input_ids'] 47 | attention_masks = encoded_dict['attention_mask'] 48 | 49 | with torch.no_grad(): 50 | ### token embeddings: 51 | outputs = model(input_ids.to(config.DEVICE), 52 | attention_masks.to(config.DEVICE)) 53 | hidden_states = outputs[2] 54 | token_embeddings = torch.stack(hidden_states, dim=0) 55 | 56 | ### sentence embeddings: 57 | token_vecs = hidden_states[-2][0] 58 | sentence_embedding = torch.mean(token_vecs, dim=0) 59 | sentence_embedding = sentence_embedding.view(1, -1) 60 | return sentence_embedding 61 | 62 | 63 | def max_len(): 64 | max_len = 0 65 | emb_lens = [] 66 | for bird_type in tqdm(sorted(os.listdir(config.ANNOTATIONS)), total=len(os.listdir(config.ANNOTATIONS))): 67 | for file in sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type))): 68 | text = open(os.path.join(config.ANNOTATIONS, bird_type, file), "r").read().split('\n')[:-1] 69 | for annotation in text: 70 | input_ids = tokenizer.encode(annotation, add_special_tokens=True) 71 | r = len(input_ids) 72 | max_len = max(max_len, r) 73 | emb_lens.append(r) 74 | print(f"Saving emb lens to {config.EMB_LEN}") 75 | file = open(config.EMB_LEN, "w") 76 | for ele in emb_lens: 77 | file.write(str(ele) + "\n") 78 | return max_len 79 | 80 | def make_dir(dir_name): 81 | if not os.path.exists(dir_name): 82 | os.makedirs(dir_name) 83 | 84 | def generate_text_embs(): 85 | print(f"Saving text embeddings to {config.ANNOTATION_EMB}") 86 | try: 87 | make_dir(config.ANNOTATION_EMB) 88 | for bird_type in tqdm(sorted(os.listdir(config.ANNOTATIONS)), total=len(os.listdir(config.ANNOTATIONS))): 89 | make_dir(os.path.join(config.ANNOTATION_EMB, bird_type)) 90 | for file in sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type))): 91 | make_dir(os.path.join(config.ANNOTATION_EMB, bird_type, file.replace(".txt", ""))) 92 | text = open(os.path.join(config.ANNOTATIONS, bird_type, file), "r").read().split('\n')[:-1] 93 | for annotation_id, annotation in enumerate(text): 94 | emb = sent_emb(annotation) 95 | torch.save(emb, os.path.join(config.ANNOTATION_EMB, bird_type, file.replace(".txt", ""), str(annotation_id) + ".pt")) 96 | except Exception as e: 97 | print(f"Error in {bird_type}/{file}") 98 | print(e) 99 | 100 | 101 | if __name__ == "__main__": 102 | # To determine mex_length in tokenizer.encode_plus() 103 | # print("Max length of annotations: ", max_len()) # 80 104 | 105 | # Generate sentence embeddings: 106 | if os.path.exists(config.ANNOTATION_EMB) and len(os.listdir(config.ANNOTATION_EMB)) == len(os.listdir(config.IMAGE_DIR)): # checking if we have all the embeddings folders created already 107 | print("Bert embeddings already exist. Skipping...") 108 | else: 109 | print("Generating bert embs...") 110 | generate_text_embs() -------------------------------------------------------------------------------- /input/src/dataset_check.py: -------------------------------------------------------------------------------- 1 | """Displays the 10 annotations and the corresponding picture. 
2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | import config 8 | 9 | import numpy as np 10 | import os 11 | import matplotlib.pyplot as plt 12 | import pickle 13 | import torch 14 | 15 | from PIL import Image 16 | from scipy.spatial.distance import cosine 17 | from tqdm import tqdm 18 | print("__"*80) 19 | print("Imports finished") 20 | print("__"*80) 21 | 22 | 23 | def display_specific(bird_type_no=0, file_no=0, file_idx=None): 24 | """ 25 | Prints annotations and displays images of a specific bird 26 | """ 27 | bird_type = sorted(os.listdir(config.IMAGE_DIR))[bird_type_no] 28 | file = sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type)))[file_no] 29 | 30 | if file_idx is None: 31 | filename = os.path.join(bird_type, file) 32 | else: 33 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1')) 34 | filename = filenames[file_idx] 35 | filename += ".txt" 36 | 37 | print(f"\nFile: {filename}\n") 38 | 39 | text = open(os.path.join(config.ANNOTATIONS, filename), "r").read().split('\n')[:-1] 40 | [print(f"{idx}: {line}") for idx, line in enumerate(text)] 41 | filename = filename.replace(".txt", ".jpg") 42 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, filename))) 43 | plt.show() 44 | 45 | def compare_bert_emb(file_1, file_2, emb_no=0): 46 | emb_1 = torch.load(os.path.join(config.ANNOTATION_EMB, file_1, f"{emb_no}.pt"), map_location="cpu") 47 | emb_2 = torch.load(os.path.join(config.ANNOTATION_EMB, file_2, f"{emb_no}.pt"), map_location="cpu") 48 | # print(emb_1.shape) # (1, 768) 49 | 50 | bert_sim = 1 - cosine(emb_1, emb_2) 51 | print(f"cosine similarity bert emb: {bert_sim:.2f}") 52 | 53 | def compare_cnn_emb(emb_idx_1, emb_idx_2, emb_no=0): 54 | embeddings = np.array(pickle.load(open("../data/birds/train/char-CNN-RNN-embeddings.pickle", "rb"), encoding='latin1')) 55 | # print(embeddings.shape) # (8855, 10, 1024) 56 | 57 | cnn_sim = 1 - cosine(embeddings[emb_idx_1][emb_no], embeddings[emb_idx_2][emb_no]) 58 | print(f"cosine similarity cnn embs: {cnn_sim:.2f}") 59 | 60 | def compare_embedding_quality(emb_idx_1=0, emb_idx_2=1, emb_no=0): 61 | ###* Filenames to fetch embs: 62 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1')) 63 | # print(filenames.shape) # (8855, ) 64 | 65 | ###* File paths: 66 | file_1 = filenames[emb_idx_1] 67 | file_2 = filenames[emb_idx_2] 68 | 69 | print(f"File 1: {file_1}") 70 | print(f"File 2: {file_2}\n") 71 | 72 | ###* Annotations: 73 | text1 = open(os.path.join(config.ANNOTATIONS, file_1+".txt"), "r").read().split('\n')[:-1] 74 | text2 = open(os.path.join(config.ANNOTATIONS, file_2+".txt"), "r").read().split('\n')[:-1] 75 | print("Annotation 1: ", text1[emb_no]) 76 | print("Annotation 2: ", text2[emb_no]) 77 | print() 78 | 79 | ###* Cosine similarity: 80 | compare_cnn_emb(emb_idx_1, emb_idx_2, emb_no=0) 81 | compare_bert_emb(file_1, file_2, emb_no=0) 82 | 83 | ###* Display images: 84 | fig = plt.figure() 85 | fig.add_subplot(1, 2, 1) 86 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, file_1 + ".jpg"))) 87 | fig.add_subplot(1, 2, 2) 88 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, file_2 + ".jpg"))) 89 | # plt.show() 90 | 91 | def check_model(file_idx, model): 92 | import sys 93 | sys.path.insert(1, "../../src/") 94 | import layers 95 | emb_no = 0 96 | ###* load the models 97 | netG = layers.Stage1Generator().cuda() 98 | netG.load_state_dict(torch.load(model)) 99 | 
netG.eval() 100 | with torch.no_grad(): 101 | ###* load the embeddings 102 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1')) 103 | file_name = filenames[file_idx] 104 | emb = torch.load(os.path.join(config.ANNOTATION_EMB, file_name, f"{emb_no}.pt")) 105 | 106 | ###* Forward pass 107 | print(emb.shape) # (1, 768) 108 | noise = torch.autograd.Variable(torch.FloatTensor(1, 100)).cuda() 109 | noise.data.normal_(0, 1) 110 | _, fake_image, mu, logvar = netG(emb, noise) 111 | fake_image = fake_image.squeeze(0) 112 | print(fake_image.shape) #(3, 64, 64) 113 | 114 | im_save(fake_image, count=0) 115 | return fake_image 116 | 117 | def im_save(fake_img, count=0): 118 | save_name = f"{count}.png" 119 | im = fake_img.cpu().numpy() 120 | im = (im + 1.0) * 127.5 121 | im = im.astype(np.uint8) 122 | # print("im", im.shape) 123 | im = np.transpose(im, (1, 2, 0)) 124 | # print("im", im.shape) 125 | im = Image.fromarray(im) 126 | im.save(save_name) 127 | 128 | if __name__ == "__main__": 129 | # display_specific(bird_type_no=0, file_no=0) # old method 130 | display_specific(file_idx=14) # new method 131 | 132 | print("__"*80) 133 | # compare_embedding_quality(emb_idx_1=0, emb_idx_2=1, emb_no=0) 134 | ###* emb_idx < 8855, emb_no < 10 135 | 136 | 137 | print("__"*80) 138 | check_model(file_idx=14, 139 | model="../../old_outputs/output-3/model/netG_epoch_110.pth") 140 | 141 | plt.show() -------------------------------------------------------------------------------- /src/engine.py: -------------------------------------------------------------------------------- 1 | """Train and Eval functions 2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | import util 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from torch.utils.tensorboard import summary 13 | from torch.utils.tensorboard import FileWriter 14 | from tqdm import tqdm 15 | 16 | def KL_loss(mu, logvar): 17 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 18 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 19 | KLD = torch.mean(KLD_element).mul_(-0.5) 20 | return KLD 21 | 22 | def disc_loss(disc, real_imgs, fake_imgs, real_labels, fake_labels, conditional_vector): 23 | loss_fn = nn.BCELoss() 24 | batch_size = real_imgs.shape[0] 25 | cond = conditional_vector.detach() 26 | fake_imgs = fake_imgs.detach() 27 | 28 | real_loss = loss_fn(disc(cond, real_imgs), real_labels) 29 | fake_loss = loss_fn(disc(cond, fake_imgs), fake_labels) 30 | wrong_loss = loss_fn(disc(cond[1:], real_imgs[:-1]), fake_labels[1:]) 31 | loss = real_loss + (fake_loss+wrong_loss)*0.5 32 | return loss, real_loss, wrong_loss, fake_loss 33 | 34 | def gen_loss(disc, fake_imgs, real_labels, conditional_vector): 35 | loss_fn = nn.BCELoss() 36 | cond = conditional_vector.detach() 37 | fake_loss = loss_fn(disc(cond, fake_imgs), real_labels) 38 | return fake_loss 39 | 40 | def weights_init(m): 41 | classname = m.__class__.__name__ 42 | if classname.find('Conv') != -1: 43 | m.weight.data.normal_(0.0, 0.02) 44 | elif classname.find('BatchNorm') != -1: 45 | m.weight.data.normal_(1.0, 0.02) 46 | m.bias.data.fill_(0) 47 | elif classname.find('Linear') != -1: 48 | m.weight.data.normal_(0.0, 0.02) 49 | if m.bias is not None: 50 | m.bias.data.fill_(0.0) 51 | 52 | 53 | def train_new_fn(data_loader, args, netG, netD, real_labels, fake_labels, noise, fixed_noise, optimizerD, optimizerG, epoch, count, summary_writer): 54 | errD_, errD_real_, errD_wrong_, 
errD_fake_, errG_, kl_loss_ = 0, 0, 0, 0, 0, 0 55 | for batch_id, data in tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Train Epoch {epoch}/{args.TRAIN_MAX_EPOCH}"): 56 | ###* Prepare training data: 57 | text_emb, real_images = data 58 | text_emb = text_emb.to(args.device) 59 | real_images = real_images.to(args.device) 60 | 61 | ###* Generate fake images: 62 | noise.data.normal_(0, 1) 63 | _, fake_images, mu, logvar = netG(text_emb, noise) 64 | 65 | ###* Update D network: 66 | netD.zero_grad() 67 | errD, errD_real, errD_wrong, errD_fake = disc_loss(netD, real_images, fake_images, real_labels, fake_labels, text_emb) 68 | errD.backward() 69 | optimizerD.step() 70 | 71 | ###* Update G network: 72 | netG.zero_grad() 73 | errG = gen_loss(netD, fake_images, real_labels, text_emb) 74 | kl_loss = KL_loss(mu, logvar) 75 | errG_total = errG + kl_loss * args.TRAIN_COEFF_KL 76 | errG_total.backward() 77 | optimizerG.step() 78 | 79 | count += 1 80 | 81 | if batch_id % 10 == 0: 82 | summary_D = summary.scalar("D_loss", errD.data) 83 | summary_D_r = summary.scalar("D_loss_real", errD_real.data) 84 | summary_D_w = summary.scalar("D_loss_wrong", errD_wrong.data) 85 | summary_D_f = summary.scalar("D_loss_fake", errD_fake.data) 86 | summary_G = summary.scalar("G_loss", errG.data) 87 | summary_KL = summary.scalar("KL_loss", kl_loss.data) 88 | 89 | summary_writer.add_summary(summary_D, count) 90 | summary_writer.add_summary(summary_D_r, count) 91 | summary_writer.add_summary(summary_D_w, count) 92 | summary_writer.add_summary(summary_D_f, count) 93 | summary_writer.add_summary(summary_G, count) 94 | summary_writer.add_summary(summary_KL, count) 95 | 96 | ###* save the image result for each epoch: 97 | lr_fake, fake, _, _ = netG(text_emb, fixed_noise) 98 | util.save_img_results(real_images, fake, epoch, args) 99 | if lr_fake is not None: 100 | util.save_img_results(None, lr_fake, epoch, args) 101 | 102 | errD_ += errD 103 | errD_real_ += errD_real 104 | errD_wrong_ += errD_wrong 105 | errD_fake_ += errD_fake 106 | errG_ += errG 107 | kl_loss_ += kl_loss 108 | 109 | errD_ /= len(data_loader) 110 | errD_real_ /= len(data_loader) 111 | errD_wrong_ /= len(data_loader) 112 | errD_fake_ /= len(data_loader) 113 | errG_ /= len(data_loader) 114 | kl_loss_ /= len(data_loader) 115 | 116 | return errD_, errD_real_, errD_wrong_, errD_fake_, errG_, kl_loss_, count 117 | 118 | 119 | 120 | 121 | def eval_fn(data_loader, model, device, epoch): 122 | model.eval() 123 | fin_y = [] 124 | fin_outputs = [] 125 | LOSS = 0. 126 | 127 | with torch.no_grad(): 128 | for batch_id, data in tqdm(enumerate(data_loader), total=len(data_loader)): 129 | text_embs, images = data 130 | 131 | # Loading it to device 132 | text_embs = text_embs.to(device, dtype=torch.float) 133 | images = images.to(device, dtype=torch.float) 134 | 135 | # getting outputs from model and calculating loss 136 | outputs = model(text_embs, images) 137 | loss = loss_fn(outputs, images) # TODO figure this out 138 | LOSS += loss 139 | 140 | # for calculating accuracy and other metrics # TODO figure this out 141 | fin_y.extend(images.view(-1, 1).cpu().detach().numpy().tolist()) 142 | fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist()) 143 | 144 | LOSS /= len(data_loader) 145 | return fin_outputs, fin_y, LOSS 146 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """Test a model and generate submission CSV. 
2 | 3 | > python3 train.py --conf ../cfg/s1.yml 4 | 5 | Usage: 6 | > python train.py --load_path PATH --name NAME 7 | where 8 | > PATH is a path to a checkpoint (e.g., save/train/model-01/best.pth.tar) 9 | > NAME is a name to identify the train run 10 | 11 | Authors: 12 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 13 | Sahil Khose (sahilkhose18@gmail.com) 14 | """ 15 | import args 16 | # import config 17 | import dataset 18 | import engine 19 | import layers 20 | import util 21 | 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | import os 25 | import pandas as pd 26 | import time 27 | import torch 28 | import torchfile 29 | 30 | 31 | from PIL import Image 32 | from sklearn import metrics 33 | from sklearn import model_selection 34 | from torch.utils.tensorboard import summary 35 | from torch.utils.tensorboard import FileWriter 36 | from torch.autograd import Variable 37 | 38 | print("__"*80) 39 | print("Imports Done...") 40 | 41 | 42 | def load_stage1(args): 43 | #* Init models and weights: 44 | from layers import Stage1Generator, Stage1Discriminator 45 | if args.embedding_type == "bert": 46 | netG = Stage1Generator(emb_dim=768) 47 | netD = Stage1Discriminator(emb_dim=768) 48 | else: 49 | netG = Stage1Generator(emb_dim=1024) 50 | netD = Stage1Discriminator(emb_dim=1024) 51 | 52 | netG.apply(engine.weights_init) 53 | netD.apply(engine.weights_init) 54 | 55 | #* Load saved model: 56 | if args.NET_G_path != "": 57 | netG.load_state_dict(torch.load(args.NET_G_path)) 58 | print("__"*80) 59 | print("Generator loaded from: ", args.NET_G_path) 60 | print("__"*80) 61 | if args.NET_D_path != "": 62 | netD.load_state_dict(torch.load(args.NET_D_path)) 63 | print("__"*80) 64 | print("Discriminator loaded from: ", args.NET_D_path) 65 | print("__"*80) 66 | 67 | #* Load on device: 68 | if args.device == "cuda": 69 | netG.cuda() 70 | netD.cuda() 71 | 72 | print("__"*80) 73 | print("GENERATOR:") 74 | print(netG) 75 | print("__"*80) 76 | print("DISCRIMINATOR:") 77 | print(netD) 78 | print("__"*80) 79 | 80 | return netG, netD 81 | 82 | 83 | def load_stage2(args): 84 | #* Init models and weights: 85 | from layers import Stage2Generator, Stage2Discriminator, Stage1Generator 86 | if args.embedding_type == "bert": 87 | Stage1_G = Stage1Generator(emb_dim=768) 88 | netG = Stage2Generator(Stage1_G, emb_dim=768) 89 | netD = Stage2Discriminator(emb_dim=768) 90 | else: 91 | Stage1_G = Stage1Generator(emb_dim=1024) 92 | netG = Stage2Generator(Stage1_G, emb_dim=1024) 93 | netD = Stage2Discriminator(emb_dim=1024) 94 | netG.apply(engine.weights_init) 95 | netD.apply(engine.weights_init) 96 | 97 | #* Load saved model: 98 | if args.NET_G_path != "": 99 | netG.load_state_dict(torch.load(args.NET_G_path)) 100 | print("Generator loaded from: ", args.NET_G_path) 101 | elif args.STAGE1_G_path != "": 102 | netG.stage1_gen.load_state_dict(torch.load(args.STAGE1_G_path)) 103 | print("Generator 1 loaded from: ", args.STAGE1_G_path) 104 | else: 105 | print("Please give the Stage 1 generator path") 106 | return 107 | 108 | if args.NET_D_path != "": 109 | netD.load_state_dict(torch.load(args.NET_D_path)) 110 | print("Discriminator loaded from: ", args.NET_D_path) 111 | 112 | #* Load on device: 113 | if args.device == "cuda": 114 | netG.cuda() 115 | netD.cuda() 116 | 117 | print("__"*80) 118 | print(netG) 119 | print("__"*80) 120 | print(netD) 121 | print("__"*80) 122 | 123 | return netG, netD 124 | 125 | 126 | def run(args): 127 | if args.STAGE == 1: 128 | netG, netD = load_stage1(args) 129 | else: 130 | netG, netD = 
load_stage2(args) 131 | 132 | # Setting up device 133 | device = torch.device(args.device) 134 | 135 | # Load model 136 | netG.to(device) 137 | netD.to(device) 138 | 139 | nz = args.n_z 140 | batch_size = args.train_bs 141 | noise = Variable(torch.FloatTensor(batch_size, nz)).to(device) 142 | with torch.no_grad(): 143 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1)).to(device) # volatile=True 144 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)).to(device) 145 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)).to(device) 146 | 147 | gen_lr = args.TRAIN_GEN_LR 148 | disc_lr = args.TRAIN_DISC_LR 149 | 150 | lr_decay_step = args.TRAIN_LR_DECAY_EPOCH 151 | 152 | optimizerD = torch.optim.Adam(netD.parameters(), lr=args.TRAIN_DISC_LR, betas=(0.5, 0.999)) 153 | 154 | netG_para = [] 155 | for p in netG.parameters(): 156 | if p.requires_grad: 157 | netG_para.append(p) 158 | optimizerG = torch.optim.Adam(netG_para, lr=args.TRAIN_GEN_LR, betas=(0.5, 0.999)) 159 | 160 | count = 0 161 | 162 | if args.embedding_type == "bert": 163 | training_set = dataset.CUBDataset(pickl_file=args.train_filenames, img_dir=args.images_dir, bert_emb=args.bert_annotations_dir, stage=args.STAGE) 164 | testing_set = dataset.CUBDataset(pickl_file=args.test_filenames, img_dir=args.images_dir, bert_emb=args.bert_annotations_dir, stage=args.STAGE) 165 | else: 166 | training_set = dataset.CUBDataset(pickl_file=args.train_filenames, img_dir=args.images_dir, cnn_emb=args.cnn_annotations_emb_train, stage=args.STAGE) 167 | testing_set = dataset.CUBDataset(pickl_file=args.test_filenames, img_dir=args.images_dir, cnn_emb=args.cnn_annotations_emb_test, stage=args.STAGE) 168 | train_data_loader = torch.utils.data.DataLoader(training_set, batch_size=args.train_bs, num_workers=args.train_workers) 169 | test_data_loader = torch.utils.data.DataLoader(testing_set, batch_size=args.test_bs, num_workers=args.test_workers) 170 | # util.check_dataset(training_set) 171 | # util.check_dataset(testing_set) 172 | 173 | 174 | # best_accuracy = 0 175 | 176 | util.make_dir(args.image_save_dir) 177 | util.make_dir(args.model_dir) 178 | util.make_dir(args.log_dir) 179 | summary_writer = FileWriter(args.log_dir) 180 | 181 | for epoch in range(1, args.TRAIN_MAX_EPOCH+1): 182 | print("__"*80) 183 | start_t = time.time() 184 | 185 | if epoch % lr_decay_step == 0 and epoch > 0: 186 | gen_lr *= 0.5 187 | for param_group in optimizerG.param_groups: 188 | param_group["lr"] = gen_lr 189 | disc_lr *= 0.5 190 | for param_group in optimizerD.param_groups: 191 | param_group["lr"] = disc_lr 192 | 193 | errD, errD_real, errD_wrong, errD_fake, errG, kl_loss, count = engine.train_new_fn( 194 | train_data_loader, args, netG, netD, real_labels, fake_labels, 195 | noise, fixed_noise, optimizerD, optimizerG, epoch, count, summary_writer) 196 | 197 | end_t = time.time() 198 | 199 | print(f"[{epoch}/{args.TRAIN_MAX_EPOCH}] Loss_D: {errD:.4f}, Loss_G: {errG:.4f}, Loss_KL: {kl_loss:.4f}, Loss_real: {errD_real:.4f}, Loss_wrong: {errD_wrong:.4f}, Loss_fake: {errD_fake:.4f}, Total Time: {end_t-start_t :.2f} sec") 200 | if epoch % args.TRAIN_SNAPSHOT_INTERVAL == 0 or epoch == 1: 201 | util.save_model(netG, netD, epoch, args) 202 | 203 | util.save_model(netG, netD, args.TRAIN_MAX_EPOCH, args) 204 | summary_writer.close() 205 | 206 | 207 | def sample(args, datapath): 208 | if args.STAGE == 1: 209 | netG, _ = load_stage1(args) 210 | else: 211 | netG, _ = load_stage2(args) 212 | netG.eval() 213 | 214 | ###* Load text embeddings 
generated from the encoder: 215 | t_file = torchfile.load(datapath) 216 | captions_list = t_file.raw_txt 217 | embeddings = np.concatenate(t_file.fea_txt, axis=0) 218 | num_embeddings = len(captions_list) 219 | print(f"Successfully load sentences from: {args.datapath}") 220 | print(f"Total number of sentences: {num_embeddings}") 221 | print(f"Num embeddings: {num_embeddings} {embeddings.shape}") 222 | 223 | ###* Path to save generated samples: 224 | save_dir = args.NET_G[:args.NET_G.find(".pth")] 225 | util.make_dir(save_dir) 226 | 227 | batch_size = np.minimum(num_embeddings, args.train_bs) 228 | nz = args.n_z 229 | noise = Variable(torch.FloatTensor(batch_size, nz)) 230 | noise = noise.to(args.device) 231 | count = 0 232 | while count < num_embeddings: 233 | if count > 3000: 234 | break 235 | iend = count + batch_size 236 | if iend > num_embeddings: 237 | iend = num_embeddings 238 | count = num_embeddings - batch_size 239 | embeddings_batch = embeddings[count:iend] 240 | # captions_batch = captions_list[count:iend] 241 | text_embedding = Variable(torch.FloatTensor(embeddings_batch)) 242 | text_embedding = text_embedding.to(args.device) 243 | 244 | ###* Generate fake images: 245 | noise.data.normal_(0, 1) 246 | _, fake_imgs, mu, logvar = netG(text_embedding, noise) 247 | for i in range(batch_size): 248 | save_name = f"{save_dir}/{count+i}.png" 249 | im = fake_imgs[i].data.cpu().numpy() 250 | im = (im + 1.0) * 127.5 251 | im = im.astype(np.uint8) 252 | # print("im", im.shape) 253 | im = np.transpose(im, (1, 2, 0)) 254 | # print("im", im.shape) 255 | im = Image.fromarray(im) 256 | im.save(save_name) 257 | count += batch_size 258 | 259 | if __name__ == "__main__": 260 | args_ = args.get_all_args() 261 | args.print_args(args_) 262 | run(args_) 263 | # datapath = os.path.join(args_.datapath, "test/val_captions.t7") 264 | # sample(args_, datapath) 265 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | """Command-line arguments. 
2 | 3 | Authors: 4 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 5 | Sahil Khose (sahilkhose18@gmail.com) 6 | """ 7 | # > python3 args.py --conf ../cfg/s1.yml 8 | import argparse 9 | import yaml 10 | from json import dumps 11 | 12 | ########################################### MAIN FUNCTIONS ########################################### 13 | 14 | def get_all_args(): 15 | return parse_args(get_parser()) 16 | 17 | 18 | def print_args(args): 19 | print("__"*80) 20 | print(f'ARGUMENTS: \n{dumps(vars(args), indent=4, sort_keys=True)}') 21 | print("__"*80) 22 | 23 | 24 | def parse_args(parser): 25 | args = parser.parse_args() 26 | if args.config_file: 27 | data = yaml.load(args.config_file, Loader=yaml.Loader) 28 | delattr(args, "config_file") 29 | arg_dict = args.__dict__ 30 | for key, value in data.items(): 31 | if isinstance(value, list): 32 | for v in value: 33 | arg_dict[key].append(v) 34 | else: 35 | arg_dict[key] = value 36 | return args 37 | 38 | 39 | def get_parser(): 40 | """Get all parameters""" 41 | parser = argparse.ArgumentParser("Arguments") 42 | 43 | add_conf_args(parser) 44 | add_train_args(parser) 45 | add_model_args(parser) 46 | add_data_args(parser) 47 | 48 | return parser 49 | 50 | 51 | ########################################### PARSER ARGUMENTS ########################################### 52 | def add_conf_args(parser): 53 | parser.add_argument("--config-file", 54 | type=argparse.FileType(mode="r"), 55 | dest="config_file", 56 | help="config yaml file to pass params") 57 | 58 | 59 | def add_train_args(parser): 60 | """ train.py """ 61 | parser.add_argument("--NET_G_path", 62 | type=str, 63 | default="", 64 | help="Generator model loading path") 65 | parser.add_argument("--NET_D_path", 66 | type=str, 67 | default="", 68 | help="Discriminator model loading path") 69 | parser.add_argument("--STAGE1_G_path", 70 | type=str, 71 | default="../old_outputs/output_0/model/netG_epoch_120.pth", 72 | help="Stage 1 Generator model path for Stage 2 training") 73 | parser.add_argument("--train_bs", 74 | type=int, 75 | default=2, 76 | help="train batch size") 77 | parser.add_argument("--test_bs", 78 | type=int, 79 | default=1, 80 | help="test batch size") 81 | parser.add_argument("--train_workers", 82 | type=int, 83 | default=1, 84 | help="train num_workers") 85 | parser.add_argument("--test_workers", 86 | type=int, 87 | default=1, 88 | help="test num_workers") 89 | parser.add_argument("--TRAIN_GEN_LR", 90 | type=float, 91 | default=2e-4, 92 | help="train generator learning rate") 93 | parser.add_argument("--TRAIN_DISC_LR", 94 | type=float, 95 | default=2e-4, 96 | help="test discriminator learning rate") 97 | parser.add_argument("--TRAIN_LR_DECAY_EPOCH", 98 | type=int, 99 | default=20, 100 | help="train lr decay epoch") 101 | parser.add_argument("--TRAIN_MAX_EPOCH", 102 | type=int, 103 | default=120, 104 | help="train maximum epochs") 105 | parser.add_argument("--TRAIN_SNAPSHOT_INTERVAL", 106 | type=int, 107 | default=5, 108 | help="Snapshot interval") 109 | parser.add_argument("--TRAIN_COEFF_KL", 110 | type=float, 111 | default=2.0, 112 | help="train coefficient KL") 113 | parser.add_argument("--dataset_name", 114 | type=str, 115 | default="birds", 116 | help="birds/flowers: dataset name") 117 | parser.add_argument("--embedding_type", 118 | type=str, 119 | default="bert", 120 | help="bert/cnn-rnn: embedding type") 121 | parser.add_argument("--datapath", 122 | type=str, 123 | default="/gdrive/MyDrive/ganctober_training/output", 124 | help="datapath dir") 125 | 
parser.add_argument("--image_save_dir", 126 | type=str, 127 | default="/gdrive/MyDrive/ganctober_training/output/image/", 128 | help="Image save dir") 129 | parser.add_argument("--model_dir", 130 | type=str, 131 | default="/gdrive/MyDrive/ganctober_training/output/model/", 132 | help="Model save dir") 133 | parser.add_argument("--log_dir", 134 | type=str, 135 | default="/gdrive/MyDrive/ganctober_training/output/log/", 136 | help="Log dir for tensorboard") 137 | parser.add_argument("--VIS_COUNT", 138 | type=int, 139 | default=64, 140 | help="") 141 | parser.add_argument("--STAGE", 142 | type=int, 143 | default=2, 144 | help="Stage to train/eval (1/2)") 145 | parser.add_argument("--device", 146 | type=str, 147 | default="cuda", #! CHANGE THIS TO CUDA BEFORE TRAINING 148 | help="Device type: cuda/cpu") 149 | 150 | 151 | def add_model_args(parser): 152 | """ 153 | Refer to StackGAN paper: https://arxiv.org/pdf/1612.03242.pdf 154 | for parameter names. 155 | """ 156 | parser.add_argument("--n_g", 157 | type=int, 158 | default=128, 159 | help="") 160 | parser.add_argument("--n_z", 161 | type=int, 162 | default=100, 163 | help="") 164 | parser.add_argument("--m_d", 165 | type=int, 166 | default=4, 167 | help="") 168 | parser.add_argument("--m_g", 169 | type=int, 170 | default=16, 171 | help="") 172 | parser.add_argument("--n_d", 173 | type=int, 174 | default=128, 175 | help="") 176 | parser.add_argument("--w_0", 177 | type=int, 178 | default=64, 179 | help="") 180 | parser.add_argument("--h_0", 181 | type=int, 182 | default=256, 183 | help="") 184 | parser.add_argument("--w", 185 | type=int, 186 | default=256, 187 | help="") 188 | parser.add_argument("--h", 189 | type=int, 190 | default=256, 191 | help="") 192 | 193 | 194 | def add_data_args(parser): 195 | '''Get all data paths''' 196 | ###* Directories: 197 | parser.add_argument("--annotations_dir", 198 | type=str, 199 | default="../input/data/birds/text_c10/", 200 | help="Annotations dir path") 201 | parser.add_argument("--bert_annotations_dir", 202 | type=str, 203 | default="../input/data/birds/embeddings/", 204 | help="Annotations BERT embeddings dir path") 205 | parser.add_argument("--bert_path", 206 | type=str, 207 | default="../input/data/bert_base_uncased/", 208 | help="Bert model dir path") 209 | parser.add_argument("--images_dir", 210 | type=str, 211 | default="../input/data/CUB_200_2011/images/", 212 | help="Images dir path") 213 | 214 | ###* Files: 215 | add_birds_file_args(parser) 216 | add_cub_file_args(parser) 217 | 218 | 219 | def add_birds_file_args(parser): 220 | """ 221 | Paths for files under input/data/birds 222 | files: 223 | - filenames.pickle 224 | List[str] {filename DOES NOT contain .jpg (or any) extenstion} 225 | To fetch the image: input/data/CUB_200_2011/images/.jpg == /.jpg 226 | To fetch the text annotation: input/data/birds/text_c10/.txt == /.txt 227 | To fetch the bert annotation: input/data/birds/embeddings//[0-9].pt == //[0-9].pt 228 | There are 2 such files: 229 | input/data/birds/train/filenames.pickle : len = 8855 230 | input/data/birds/test/filenames.pickle : len = 2933 231 | """ 232 | parser.add_argument("--cnn_annotations_emb_train", 233 | type=str, 234 | default="../input/data/birds/train/char-CNN-RNN-embeddings.pickle", 235 | help="char-CNN-RNN-embeddings pickle file for train") 236 | parser.add_argument("--cnn_annotations_emb_test", 237 | type=str, 238 | default="../input/data/birds/test/char-CNN-RNN-embeddings.pickle", 239 | help="char-CNN-RNN-embeddings pickle file for test") 240 | 
parser.add_argument("--train_filenames", 241 | type=str, 242 | default="../input/data/birds/train/filenames.pickle", 243 | help="Pickle file path: filenames for train set") 244 | parser.add_argument("--test_filenames", 245 | type=str, 246 | default="../input/data/birds/test/filenames.pickle", 247 | help="Pickle file path: filenames for test set") 248 | 249 | 250 | def add_cub_file_args(parser): 251 | """ 252 | Paths for files under input/data/CUB_200_2011 253 | files: 254 | - images.txt 255 | {image_name contains .jpg extenstion} 256 | To fetch the image: input/data/CUB_200_2011/images/ 257 | number_of_images = 11788 258 | 259 | - train_test_split.txt 260 | 261 | 1: train, 0: test 262 | train size = 5994 263 | test size = 5794 264 | 265 | - bounding_boxes.txt 266 | 267 | """ 268 | parser.add_argument("--images_id_file", 269 | type=str, 270 | default="../input/data/CUB_200_2011/images.txt", 271 | help="Text file path: mapping image id to image path") # 272 | parser.add_argument("--train_test_split_file", 273 | type=str, 274 | default="../input/data/CUB_200_2011/train_test_split.txt", 275 | help="Text file path: mapping image id to train/test split") # 276 | parser.add_argument("--bounding_boxes", 277 | type=str, 278 | default="../input/data/CUB_200_2011/bounding_boxes.txt", 279 | help="Text file path: mapping image id to train/test split") # 280 | 281 | 282 | if __name__ == "__main__": 283 | args = get_all_args() 284 | print_args(args) 285 | print("args:", args.train_bs) 286 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | """Assortment of layers for use in models.py. 2 | Refer to StackGAN paper: https://arxiv.org/pdf/1612.03242.pdf 3 | for variable names and working. 4 | 5 | Authors: 6 | Abhiraj Tiwari (abhirajtiwari@gmail.com) 7 | Sahil Khose (sahilkhose18@gmail.com) 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | 12 | from torch.autograd import Variable 13 | 14 | 15 | def conv3x3(in_channels, out_channels): 16 | """3x3 conv with same padding""" 17 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) 18 | 19 | 20 | class ResBlock(nn.Module): 21 | def __init__(self, channel_num): 22 | super(ResBlock, self).__init__() 23 | self.block = nn.Sequential( 24 | conv3x3(channel_num, channel_num), 25 | nn.BatchNorm2d(channel_num), 26 | nn.ReLU(True), 27 | conv3x3(channel_num, channel_num), 28 | nn.BatchNorm2d(channel_num)) 29 | self.relu = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | residual = x 33 | out = self.block(x) 34 | out += residual 35 | out = self.relu(out) 36 | return out 37 | 38 | 39 | 40 | def _downsample(in_channels, out_channels): 41 | return nn.Sequential( 42 | nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), 43 | nn.BatchNorm2d(out_channels), 44 | nn.LeakyReLU(0.2, inplace=True) 45 | ) 46 | 47 | def _upsample(in_channels, out_channels): 48 | return nn.Sequential( 49 | nn.Upsample(scale_factor=2, mode='nearest'), 50 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 51 | nn.BatchNorm2d(out_channels), 52 | nn.ReLU(inplace=True) 53 | ) 54 | 55 | class CAug(nn.Module): 56 | """Module for conditional augmentation. 57 | Takes input as bert embeddings of annotations and sends output to Stage 1 and 2 generators. 58 | """ 59 | def __init__ (self, emb_dim=768, n_g=128, device="cuda"): #! 
CHANGE THIS TO CUDA 60 | """ 61 | @param emb_dim (int) : Size of annotation embeddings. 62 | @param n_g (int) : Dimension of mu, epsilon and c_0_hat 63 | @param device (torch.device) : cuda/cpu 64 | """ 65 | super(CAug, self).__init__() 66 | self.emb_dim = emb_dim 67 | self.n_g = n_g 68 | self.fc = nn.Linear(self.emb_dim, self.n_g*2, bias=True) # To split in mu and sigma 69 | self.relu = nn.ReLU() 70 | self.device = device 71 | 72 | def forward(self, text_emb): 73 | """ 74 | @param text_emb (torch.tensor): Text embedding. (batch, emb_dim) 75 | @returns c_0_hat (torch.tensor): Gaussian conditioning variable. (batch, n_g) 76 | """ 77 | enc = self.relu(self.fc(text_emb)).squeeze(1) # (batch, n_g*2) 78 | 79 | mu = enc[:, :self.n_g] # (batch, n_g) 80 | logvar = enc[:, self.n_g:] # (batch, n_g) 81 | 82 | sigma = (logvar * 0.5).exp_() # exp(logvar * 0.5) = exp(log(var^0.5)) = sqrt(var) = std 83 | 84 | epsilon = Variable(torch.FloatTensor(sigma.size()).normal_()) 85 | 86 | c_0_hat = epsilon.to(self.device) * sigma + mu # (batch, n_g) 87 | 88 | return c_0_hat, mu, logvar 89 | 90 | ######################### STAGE 1 ######################### 91 | 92 | class Stage1Generator(nn.Module): 93 | """ 94 | Stage 1 generator. 95 | Takes in input from Conditional Augmentation and outputs 64x64 image to Stage1Discrimantor. 96 | """ 97 | def __init__(self, n_g=128, n_z=100, emb_dim=768): 98 | """ 99 | @param n_g (int) : Dimension of c_0_hat. 100 | @param n_z (int) : Dimension of noise vector. 101 | """ 102 | super(Stage1Generator, self).__init__() 103 | self.n_g = n_g 104 | self.n_z = n_z 105 | self.emb_dim = emb_dim 106 | self.inp_ch = self.n_g*8 107 | 108 | # (batch, bert_size) -> (batch, n_g) 109 | self.caug = CAug(emb_dim=self.emb_dim) 110 | 111 | # (batch, n_g + n_z) -> (batch, inp_ch * 4 * 4) 112 | self.fc = nn.Sequential( 113 | nn.Linear(self.n_g + self.n_z, self.inp_ch * 4 * 4, bias=False), 114 | nn.BatchNorm1d(self.inp_ch * 4 * 4), 115 | nn.ReLU(True) 116 | ) 117 | 118 | # (batch, inp_ch, 4, 4) -> (batch, inp_ch//2, 8, 8) 119 | self.up1 = _upsample(self.inp_ch, self.inp_ch // 2) 120 | # -> (batch, inp_ch//4, 16, 16) 121 | self.up2 = _upsample(self.inp_ch // 2, self.inp_ch // 4) 122 | # -> (batch, inp_ch//8, 32, 32) 123 | self.up3 = _upsample(self.inp_ch // 4, self.inp_ch // 8) 124 | # -> (batch, inp_ch//16, 64, 64) 125 | self.up4 = _upsample(self.inp_ch // 8, self.inp_ch // 16) 126 | 127 | # -> (batch, 3, 64, 64) 128 | self.img = nn.Sequential( 129 | conv3x3(self.inp_ch // 16, 3), 130 | nn.Tanh() 131 | ) 132 | 133 | def forward(self, text_emb, noise): 134 | """ 135 | @param c_0_hat (torch.tensor) : Output of Conditional Augmentation (batch, n_g) 136 | @returns out (torch.tensor) : Generator 1 image output (batch, 3, 64, 64) 137 | """ 138 | c_0_hat, mu, logvar = self.caug(text_emb) 139 | 140 | # -> (batch, n_g + n_z) (batch, 128 + 100) 141 | c_z = torch.cat((c_0_hat, noise), dim=1) 142 | 143 | # -> (batch, 1024 * 4 * 4) 144 | inp = self.fc(c_z) 145 | 146 | # -> (batch, 1024, 4, 4) 147 | inp = inp.view(-1, self.inp_ch, 4, 4) 148 | 149 | inp = self.up1(inp) # (batch, 512, 8, 8) 150 | inp = self.up2(inp) # (batch, 256, 16, 16) 151 | inp = self.up3(inp) # (batch, 128, 32, 32) 152 | inp = self.up4(inp) # (batch, 64, 64, 64) 153 | 154 | fake_img = self.img(inp) # (batch, 3, 64, 64) 155 | return None, fake_img, mu, logvar 156 | 157 | class Stage1Discriminator(nn.Module): 158 | """ 159 | Stage 1 discriminator 160 | """ 161 | def __init__(self, n_d=128, m_d=4, emb_dim=768, img_dim=64): 162 | 
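# CAug.forward above draws c_0_hat = mu + sigma * eps via the reparameterization
# trick, and the (mu, logvar) pair the generators return is intended for a KL
# regularizer that keeps the conditioning distribution close to N(0, I). A
# sketch (not repo code) of that standard closed-form term follows; the training
# code presumably scales it by the TRAIN_COEFF_KL argument from args.py.
import torch

def kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), averaged over the batch."""
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kld.mean()

# e.g. generator_loss = bce_fake_term + args.TRAIN_COEFF_KL * kl_loss(mu, logvar)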
super(Stage1Discriminator, self).__init__() 163 | self.n_d = n_d 164 | self.m_d = m_d 165 | self.emb_dim = emb_dim 166 | 167 | self.fc_for_text = nn.Linear(self.emb_dim, self.n_d) 168 | self.down_sample = nn.Sequential( 169 | # (batch, 3, 64, 64) -> (batch, img_dim, 32, 32) 170 | nn.Conv2d(3, img_dim, kernel_size=4, stride=2, padding=1, bias=False), # (batch, 64, 32, 32) 171 | nn.LeakyReLU(0.2, inplace=True), 172 | # -> (batch, img_dim * 2, 16, 16) 173 | _downsample(img_dim, img_dim*2), # (batch, 128, 16, 16) 174 | # -> (batch, img_dim * 4, 8, 8) 175 | _downsample(img_dim*2, img_dim*4), # (batch, 256, 8, 8) 176 | # -> (batch, img_dim * 8, 4, 4) 177 | _downsample(img_dim*4, img_dim*8) # (batch, 512, 4, 4) 178 | ) 179 | 180 | self.out_logits = nn.Sequential( 181 | # (batch, img_dim*8 + n_d, 4, 4) -> (batch, img_dim*8, 4, 4) 182 | conv3x3(img_dim*8 + self.n_d, img_dim*8), 183 | nn.BatchNorm2d(img_dim*8), 184 | nn.LeakyReLU(0.2, inplace=True), 185 | # -> (batch, 1) 186 | nn.Conv2d(img_dim*8, 1, kernel_size=4, stride=4), 187 | nn.Sigmoid() 188 | ) 189 | 190 | def forward(self, text_emb, img): 191 | # image encode 192 | enc = self.down_sample(img) 193 | 194 | # text emb 195 | compressed = self.fc_for_text(text_emb) 196 | compressed = compressed.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.m_d, self.m_d) 197 | 198 | con = torch.cat((enc, compressed), dim=1) 199 | 200 | output = self.out_logits(con) 201 | return output.view(-1) 202 | 203 | 204 | ######################### STAGE 2 ######################### 205 | class Stage2Generator(nn.Module): 206 | """ 207 | Stage 2 generator. 208 | Takes in input from Conditional Augmentation and outputs 256x256 image to Stage2Discrimantor. 209 | """ 210 | def __init__(self, stage1_gen, n_g=128, n_z=100, ef_size=128, n_res=4, emb_dim=768): 211 | """ 212 | @param n_g (int) : Dimension of c_0_hat. 213 | """ 214 | super(Stage2Generator, self).__init__() 215 | self.n_g = n_g 216 | self.n_z = n_z 217 | self.ef_size = ef_size 218 | self.n_res = n_res 219 | self.emb_dim = emb_dim 220 | 221 | self.stage1_gen = stage1_gen 222 | # Freezing the stage 1 generator: 223 | for param in self.stage1_gen.parameters(): 224 | param.requires_grad = False 225 | 226 | # (batch, bert_size) -> (batch, n_g) 227 | self.caug = CAug(emb_dim=self.emb_dim) 228 | 229 | # -> (batch, n_g*4, 16, 16) 230 | self.encoder = nn.Sequential( 231 | conv3x3(3, n_g), # (batch, 128, 64, 64) 232 | nn.LeakyReLU(0.2, inplace=True), #? 
Paper: leaky, code: relu 233 | _downsample(n_g, n_g*2), # (batch, 256, 32, 32) 234 | _downsample(n_g*2, n_g*4) # (batch, 512, 16, 16) 235 | ) 236 | 237 | # (batch, ef_size + n_g * 4, 16, 16) -> (batch, n_g * 4, 16, 16) 238 | # (batch, 128 + 512, 16, 16) -> (batch, 512, 16, 16) 239 | self.cat_conv = nn.Sequential( 240 | conv3x3(self.ef_size + self.n_g * 4, self.n_g * 4), 241 | nn.BatchNorm2d(self.n_g * 4), 242 | nn.ReLU(inplace=True) 243 | ) 244 | 245 | # -> (batch, n_g * 4, 16, 16) 246 | # (batch, 512, 16, 16) 247 | self.residual = nn.Sequential( 248 | *[ 249 | ResBlock(self.n_g * 4) for _ in range(self.n_res) 250 | ] 251 | ) 252 | 253 | # -> (batch, n_g * 2, 32, 32) 254 | self.up1 = _upsample(n_g * 4, n_g * 2) # (batch, 256, 32, 32) 255 | # -> (batch, n_g, 64, 64) 256 | self.up2 = _upsample(n_g * 2, n_g) # (batch, 128, 64, 64) 257 | # -> (batch, n_g // 2, 128, 128) 258 | self.up3 = _upsample(n_g, n_g // 2) # (batch, 64, 128, 128) 259 | # -> (batch, n_g // 4, 256, 256) 260 | self.up4 = _upsample(n_g // 2, n_g // 4) # (batch, 32, 256, 256) 261 | 262 | # (batch, 3, 256, 256) 263 | self.img = nn.Sequential( 264 | conv3x3(n_g // 4, 3), 265 | nn.Tanh() 266 | ) 267 | 268 | def forward(self, text_emb, noise): 269 | """ 270 | @param c_0_hat (torch.tensor) : Output of Conditional Augmentation (batch, n_g) 271 | @param s1_image (torch.tensor) : Ouput of Stage 1 Generator (batch, 3, 64, 64) 272 | @returns out (torch.tensor) : Generator 2 image output (batch, 3, 256, 256) 273 | """ 274 | _, stage1_img, _, _ = self.stage1_gen(text_emb, noise) 275 | stage1_img = stage1_img.detach() 276 | 277 | encoded_img = self.encoder(stage1_img) 278 | 279 | c_0_hat, mu, logvar = self.caug(text_emb) 280 | c_0_hat = c_0_hat.unsqueeze(2).unsqueeze(3).repeat(1, 1, 16, 16) 281 | 282 | # -> (batch, ef_size + n_g * 4, 16, 16) # (batch, 640, 16, 16) 283 | concat_out = torch.cat((encoded_img, c_0_hat), dim=1) 284 | 285 | # -> (batch, n_g * 4, 16, 16) 286 | h_out = self.cat_conv(concat_out) 287 | h_out = self.residual(h_out) 288 | 289 | h_out = self.up1(h_out) 290 | h_out = self.up2(h_out) 291 | h_out = self.up3(h_out) 292 | # -> (batch, ng // 4, 256, 256) 293 | h_out = self.up4(h_out) 294 | 295 | # -> (batch, 3, 256, 256) 296 | fake_img = self.img(h_out) 297 | 298 | return stage1_img, fake_img, mu, logvar 299 | 300 | 301 | class Stage2Discriminator(nn.Module): 302 | """ 303 | Stage 2 discriminator 304 | """ 305 | 306 | def __init__(self, n_d=128, m_d=4, emb_dim=768, img_dim=256): 307 | super(Stage2Discriminator, self).__init__() 308 | self.n_d = n_d 309 | self.m_d = m_d 310 | self.emb_dim = emb_dim 311 | 312 | self.fc_for_text = nn.Linear(self.emb_dim, self.n_d) 313 | self.down_sample = nn.Sequential( 314 | # (batch, 3, 64, 64) -> (batch, img_dim//4, 128, 128) 315 | nn.Conv2d(3, img_dim//4, kernel_size=4, stride=2, padding=1, bias=False), # (batch, 64, 128, 128) 316 | nn.LeakyReLU(0.2, inplace=True), 317 | # -> (batch, img_dim//2, 64, 64) 318 | _downsample(img_dim//4, img_dim//2), # (batch, 128, 64, 64) 319 | # -> (batch, img_dim, 32, 32) 320 | _downsample(img_dim//2, img_dim), # (batch, 256, 32, 32) 321 | # -> (batch, img_dim*2, 16, 16) 322 | _downsample(img_dim, img_dim*2), # (batch, 512, 16, 16) 323 | # -> (batch, img_dim*4, 8, 8) 324 | _downsample(img_dim*2, img_dim*4), # (batch, 1024, 8, 8) 325 | # -> (batch, img_dim*8, 4, 4) 326 | _downsample(img_dim*4, img_dim*8), # (batch, 2096, 4, 4) 327 | # -> (batch, img_dim*4, 4, 4) 328 | conv3x3(img_dim*8, img_dim*4), # (batch, 1024, 4, 4) 329 | nn.BatchNorm2d(img_dim*4), 330 | 
nn.LeakyReLU(0.2, inplace=True), 331 | # -> (batch, img_dim*2, 4, 4) 332 | conv3x3(img_dim * 4, img_dim * 2), # (batch, 512, 4, 4) 333 | nn.BatchNorm2d(img_dim * 2), 334 | nn.LeakyReLU(0.2, inplace=True) 335 | ) 336 | 337 | self.out_logits = nn.Sequential( 338 | # (batch, img_dim*2 + n_d, 4, 4) -> (batch, img_dim*2, 4, 4) 339 | conv3x3(img_dim*2 + self.n_d, img_dim*2), 340 | nn.BatchNorm2d(img_dim*2), 341 | nn.LeakyReLU(0.2, inplace=True), 342 | # -> (batch, 1) 343 | nn.Conv2d(img_dim*2, 1, kernel_size=4, stride=4), 344 | nn.Sigmoid() 345 | ) 346 | 347 | def forward(self, text_emb, img): 348 | # image encode 349 | enc = self.down_sample(img) 350 | 351 | # text emb 352 | compressed = self.fc_for_text(text_emb) 353 | compressed = compressed.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.m_d, self.m_d) 354 | 355 | con = torch.cat((enc, compressed), dim=1) 356 | 357 | output = self.out_logits(con) 358 | return output.view(-1) 359 | 360 | ######################### ######################### 361 | 362 | 363 | if __name__ == "__main__": 364 | batch_size = 2 365 | n_z = 100 366 | emb_dim = 1024 # 768 367 | emb = torch.randn((batch_size, emb_dim)) 368 | noise = torch.empty((batch_size, n_z)).normal_() 369 | 370 | generator1 = Stage1Generator(emb_dim=emb_dim) 371 | generator2 = Stage2Generator(generator1, emb_dim=emb_dim) 372 | 373 | discriminator1 = Stage1Discriminator(emb_dim=emb_dim) 374 | discriminator2 = Stage2Discriminator(emb_dim=emb_dim) 375 | 376 | 377 | _, gen1, _, _ = generator1(emb, noise) 378 | print("output1 image dimensions :", gen1.size()) # (batch_size, 3, 64, 64) 379 | assert gen1.shape == (batch_size, 3, 64, 64) 380 | print() 381 | 382 | disc1 = discriminator1(emb, gen1) 383 | print("output1 discriminator", disc1.size()) # (batch_size) 384 | # assert disc1.shape == (batch_size) 385 | print() 386 | 387 | _, gen2, _, _ = generator2(emb, noise) 388 | print("output2 image dimensions :", gen2.size()) # (batch_size, 3, 256, 256) 389 | assert gen2.shape == (batch_size, 3, 256, 256) 390 | print() 391 | 392 | disc2 = discriminator2(emb, gen2) 393 | print("output2 discriminator", disc2.size()) # (batch_size) 394 | # assert disc2.shape == (batch_size) 395 | print() 396 | 397 | ca = CAug(emb_dim=emb_dim, n_g=128, device='cpu') 398 | out_ca, _, _ = ca(emb) 399 | print("Conditional Aug output size: ", out_ca.size()) # (batch_size, 128) 400 | assert out_ca.shape == (batch_size, 128) 401 | 402 | 403 | ###* Checking init weights 404 | # import engine 405 | # netG = Stage1Generator() 406 | # netG.apply(engine.weights_init) 407 | pass 408 | --------------------------------------------------------------------------------
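Both discriminators above (and Stage2Generator) condition on text the same way: a linear layer compresses the sentence embedding to n_d channels, the result is tiled across the spatial grid with unsqueeze/repeat, and the tiled block is concatenated channel-wise with the image features before the final convolutions. The following standalone shape check (not repo code) walks through that pattern with the Stage-1 sizes used above, where the image branch produces img_dim*8 = 512 channels on a 4x4 grid:

import torch
import torch.nn as nn

batch, emb_dim, n_d, m_d = 2, 768, 128, 4
text_emb = torch.randn(batch, emb_dim)
img_feat = torch.randn(batch, 512, m_d, m_d)                          # Stage-1 discriminator features

compressed = nn.Linear(emb_dim, n_d)(text_emb)                        # (batch, n_d)
tiled = compressed.unsqueeze(2).unsqueeze(3).repeat(1, 1, m_d, m_d)   # (batch, n_d, 4, 4)
joint = torch.cat((img_feat, tiled), dim=1)                           # (batch, 512 + 128, 4, 4)
assert joint.shape == (batch, 640, m_d, m_d)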
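engine.py is not reproduced in this section, so the following is only a rough, assumption-laden sketch of how the modules above would typically be wired in a StackGAN-style Stage-1 update: the discriminator is pushed toward 1 on (real image, matching text) and toward 0 on generated images, while the generator combines the adversarial term with the KL regularizer weighted by TRAIN_COEFF_KL (default 2.0 in args.py). The mismatched-pair term from the StackGAN paper is omitted here for brevity, and names like real_imgs stand in for a dataloader batch; none of this is the repo's exact logic. Stage-2 training would swap in Stage2Generator, which wraps the frozen Stage-1 generator as the requires_grad=False loop above shows, together with Stage2Discriminator.

import torch
import torch.nn as nn
from layers import Stage1Generator, Stage1Discriminator  # this file, if run from src/

device = "cuda"                                 # CAug defaults to device="cuda", so a GPU is assumed
batch, n_z, emb_dim, coeff_kl = 2, 100, 768, 2.0
netG = Stage1Generator(emb_dim=emb_dim).to(device)
netD = Stage1Discriminator(emb_dim=emb_dim).to(device)
bce = nn.BCELoss()

text_emb = torch.randn(batch, emb_dim, device=device)
real_imgs = torch.randn(batch, 3, 64, 64, device=device)   # stand-in for a dataloader batch
noise = torch.randn(batch, n_z, device=device)

_, fake_imgs, mu, logvar = netG(text_emb, noise)
real_labels = torch.ones(batch, device=device)
fake_labels = torch.zeros(batch, device=device)

# Discriminator: real pairs -> 1, generated images -> 0 (generator detached).
d_loss = bce(netD(text_emb, real_imgs), real_labels) + \
         bce(netD(text_emb, fake_imgs.detach()), fake_labels)

# Generator: fool the discriminator + the same closed-form KL term sketched near CAug.
kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
g_loss = bce(netD(text_emb, fake_imgs), real_labels) + coeff_kl * kl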